diff options
Diffstat (limited to 'src/ai_embedders.rs')
-rw-r--r-- | src/ai_embedders.rs | 40 |
1 files changed, 29 insertions, 11 deletions
diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs index e8e6c8b..3848674 100644 --- a/src/ai_embedders.rs +++ b/src/ai_embedders.rs @@ -1,16 +1,23 @@ use anyhow::Result; use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; use serde::{Deserialize, Serialize}; -use std::{io::{BufRead, BufReader, copy, Cursor}, path::PathBuf, process::{Command, Stdio}}; +use std::{ + fs::{remove_file, File}, + io::{copy, BufRead, BufReader, Cursor}, + path::PathBuf, + process::{Command, Stdio}, +}; -use crate::{Config, BatchEmbedder, MetricElem}; +use crate::{BatchEmbedder, Config, MetricElem}; #[repr(transparent)] #[derive(Serialize, Deserialize)] -pub(crate) struct Imgbedding (Vec<f32>); +pub(crate) struct Imgbedding(Vec<f32>); impl MetricElem for Imgbedding { fn dist(&self, other: &Self) -> f64 { - self.0.iter().zip(other.0.iter()) + self.0 + .iter() + .zip(other.0.iter()) .map(|(a, b)| (a - b).powf(2.)) .sum::<f32>() .sqrt() as f64 @@ -22,15 +29,19 @@ pub(crate) struct ContentEmbedder<'a> { } impl<'a> ContentEmbedder<'a> { pub(crate) fn new(cfg: &'a Config) -> Self { - ContentEmbedder { cfg: cfg } + ContentEmbedder { cfg } } } impl<'a> Drop for ContentEmbedder<'a> { fn drop(&mut self) { - self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py") + self.cfg + .base_dirs + .place_runtime_file("imgbeddings-api.py") .iter() - .for_each(|p| { let _ = std::fs::remove_file(&p); }); + .for_each(|p| { + let _ = remove_file(p); + }); } } @@ -39,13 +50,17 @@ impl BatchEmbedder for ContentEmbedder<'_> { const NAME: &'static str = "imgbeddings"; fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { - let venv_dir = self.cfg.base_dirs + let venv_dir = self + .cfg + .base_dirs .create_data_directory("imgbeddings-venv")?; - let script_file = self.cfg.base_dirs + let script_file = self + .cfg + .base_dirs .place_runtime_file("imgbeddings-api.py")?; let api_prog = include_bytes!("imgbeddings-api.py"); - copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?; + copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?; let bar = ProgressBar::new_spinner(); bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?); @@ -73,9 +88,12 @@ impl BatchEmbedder for ContentEmbedder<'_> { .stdout(Stdio::piped()) .spawn()?; + let st = + ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + let bar = ProgressBar::new(paths.len() as u64).with_style(st); BufReader::new(child.stdout.unwrap()) .lines() - .progress_count(paths.len().try_into().unwrap()) + .progress_with(bar) .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) .try_collect() } |