diff options
Diffstat (limited to 'src/embedders/ai.rs')
-rw-r--r-- | src/embedders/ai.rs | 105 |
1 files changed, 38 insertions, 67 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs index 7d5ae90..8c9de11 100644 --- a/src/embedders/ai.rs +++ b/src/embedders/ai.rs @@ -1,12 +1,8 @@ -use anyhow::Result; -use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; -use std::{ - fs::{remove_file, File}, - io::{copy, BufRead, BufReader, Cursor}, - marker::PhantomData, - path::PathBuf, - process::{Command, Stdio}, -}; +use anyhow::{anyhow, Result}; +use fastembed::{ImageEmbedding, ImageInitOptions}; +use indicatif::{ProgressBar, ProgressStyle}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use std::{marker::PhantomData, path::PathBuf}; use super::vecmetric::VecMetric; use crate::{BatchEmbedder, Config}; @@ -17,22 +13,7 @@ pub(crate) struct ContentEmbedder<'a, Metric> { } impl<'a, Metric> ContentEmbedder<'a, Metric> { pub(crate) fn new(cfg: &'a Config) -> Self { - ContentEmbedder { - cfg, - _sim: PhantomData, - } - } -} - -impl<Metric> Drop for ContentEmbedder<'_, Metric> { - fn drop(&mut self) { - self.cfg - .base_dirs - .place_runtime_file("imgbeddings-api.py") - .iter() - .for_each(|p| { - let _ = remove_file(p); - }); + ContentEmbedder { cfg, _sim: PhantomData } } } @@ -41,54 +22,44 @@ impl<Metric: VecMetric> ContentEmbedder<'_, Metric> { &mut self, paths: &[PathBuf], ) -> Result<Vec<Result<<Self as BatchEmbedder>::Embedding>>> { - let venv_dir = self - .cfg - .base_dirs - .create_data_directory("imgbeddings-venv")?; - let script_file = self - .cfg - .base_dirs - .place_runtime_file("imgbeddings-api.py")?; - - let api_prog = include_bytes!("imgbeddings-api.py"); - copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?; + let mut options = ImageInitOptions::default(); + options.cache_dir = self.cfg.base_dirs.get_cache_home(); + let embedder = ImageEmbedding::try_new(options)?; - let bar = ProgressBar::new_spinner(); - bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?); + let bar = ProgressBar::new(paths.len() as u64); + bar.set_style(ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} {msg}")?); bar.enable_steady_tick(std::time::Duration::from_millis(100)); + bar.set_message("Embedding images..."); - bar.set_message("Creating venv..."); - Command::new("python3") - .args(["-m", "venv", venv_dir.to_str().unwrap()]) - .stdout(Stdio::null()) - .spawn()? - .wait()?; + let mut res = Vec::with_capacity(paths.len()); - bar.set_message("Installing/checking packages..."); - Command::new(venv_dir.join("bin/pip3")) - .args(["install", "imgbeddings"]) - .stdout(Stdio::null()) - .spawn()? - .wait()?; - bar.finish(); + // fastembeds supports batched processing, but does not support error reporting on a + // per-image basis. Thus, we first try embedding 64 images at once, and if that fails, fall + // back to passing them to fastembeds one-by-one, so that we can get all the non-failure + // results. + for chunk in paths.chunks(64) { + match embedder.embed(chunk.iter().collect(), Some(8)) { + Ok(embeds) => res.extend(embeds.into_iter().map(|e| Ok(e.into()))), + Err(_) => { + // embed one by one + let mut embeds = chunk + .par_iter() + .map(|path| match embedder.embed(vec![path], Some(1)) { + Err(e) => Err(e), + Ok(mut embed) if embed.len() == 1 => Ok(embed.pop().unwrap().into()), + Ok(embed) => { + Err(anyhow!("Embedder did not return a single value: {embed:?}")) + } + }) + .collect(); - let child = Command::new(venv_dir.join("bin/python3")) - .arg(script_file) - .args(paths) - .stderr(Stdio::inherit()) - .stdout(Stdio::piped()) - .spawn()?; + res.append(&mut embeds); + } + } + bar.inc(64); + } - // TODO das ist noch nicht ok... wir geben zb potentiell zu wenig dings zurück. - // python-code muss dafür auch geändert werden xD - let st = - ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; - let bar = ProgressBar::new(paths.len() as u64).with_style(st); - Ok(BufReader::new(child.stdout.unwrap()) - .lines() - .progress_with(bar) - .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) - .collect()) + Ok(res) } } |