aboutsummaryrefslogtreecommitdiff
path: root/src/embedders/ai.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/embedders/ai.rs')
-rw-r--r--src/embedders/ai.rs105
1 files changed, 38 insertions, 67 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index 7d5ae90..8c9de11 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -1,12 +1,8 @@
-use anyhow::Result;
-use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
-use std::{
- fs::{remove_file, File},
- io::{copy, BufRead, BufReader, Cursor},
- marker::PhantomData,
- path::PathBuf,
- process::{Command, Stdio},
-};
+use anyhow::{anyhow, Result};
+use fastembed::{ImageEmbedding, ImageInitOptions};
+use indicatif::{ProgressBar, ProgressStyle};
+use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
+use std::{marker::PhantomData, path::PathBuf};
use super::vecmetric::VecMetric;
use crate::{BatchEmbedder, Config};
@@ -17,22 +13,7 @@ pub(crate) struct ContentEmbedder<'a, Metric> {
}
impl<'a, Metric> ContentEmbedder<'a, Metric> {
pub(crate) fn new(cfg: &'a Config) -> Self {
- ContentEmbedder {
- cfg,
- _sim: PhantomData,
- }
- }
-}
-
-impl<Metric> Drop for ContentEmbedder<'_, Metric> {
- fn drop(&mut self) {
- self.cfg
- .base_dirs
- .place_runtime_file("imgbeddings-api.py")
- .iter()
- .for_each(|p| {
- let _ = remove_file(p);
- });
+ ContentEmbedder { cfg, _sim: PhantomData }
}
}
@@ -41,54 +22,44 @@ impl<Metric: VecMetric> ContentEmbedder<'_, Metric> {
&mut self,
paths: &[PathBuf],
) -> Result<Vec<Result<<Self as BatchEmbedder>::Embedding>>> {
- let venv_dir = self
- .cfg
- .base_dirs
- .create_data_directory("imgbeddings-venv")?;
- let script_file = self
- .cfg
- .base_dirs
- .place_runtime_file("imgbeddings-api.py")?;
-
- let api_prog = include_bytes!("imgbeddings-api.py");
- copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?;
+ let mut options = ImageInitOptions::default();
+ options.cache_dir = self.cfg.base_dirs.get_cache_home();
+ let embedder = ImageEmbedding::try_new(options)?;
- let bar = ProgressBar::new_spinner();
- bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?);
+ let bar = ProgressBar::new(paths.len() as u64);
+ bar.set_style(ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} {msg}")?);
bar.enable_steady_tick(std::time::Duration::from_millis(100));
+ bar.set_message("Embedding images...");
- bar.set_message("Creating venv...");
- Command::new("python3")
- .args(["-m", "venv", venv_dir.to_str().unwrap()])
- .stdout(Stdio::null())
- .spawn()?
- .wait()?;
+ let mut res = Vec::with_capacity(paths.len());
- bar.set_message("Installing/checking packages...");
- Command::new(venv_dir.join("bin/pip3"))
- .args(["install", "imgbeddings"])
- .stdout(Stdio::null())
- .spawn()?
- .wait()?;
- bar.finish();
+ // fastembed supports batched processing, but does not support error reporting on a
+ // per-image basis. Thus, we first try embedding 64 images at once, and if that fails, fall
+ // back to passing them to fastembed one-by-one, so that we can get all the non-failure
+ // results.
+ for chunk in paths.chunks(64) {
+ match embedder.embed(chunk.iter().collect(), Some(8)) {
+ Ok(embeds) => res.extend(embeds.into_iter().map(|e| Ok(e.into()))),
+ Err(_) => {
+ // embed one by one
+ let mut embeds = chunk
+ .par_iter()
+ .map(|path| match embedder.embed(vec![path], Some(1)) {
+ Err(e) => Err(e),
+ Ok(mut embed) if embed.len() == 1 => Ok(embed.pop().unwrap().into()),
+ Ok(embed) => {
+ Err(anyhow!("Embedder did not return a single value: {embed:?}"))
+ }
+ })
+ .collect();
- let child = Command::new(venv_dir.join("bin/python3"))
- .arg(script_file)
- .args(paths)
- .stderr(Stdio::inherit())
- .stdout(Stdio::piped())
- .spawn()?;
+ res.append(&mut embeds);
+ }
+ }
+ bar.inc(chunk.len() as u64); // the final chunk may hold fewer than 64 paths
+ }
- // TODO das ist noch nicht ok... wir geben zb potentiell zu wenig dings zurück.
- // python-code muss dafür auch geändert werden xD
- let st =
- ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
- let bar = ProgressBar::new(paths.len() as u64).with_style(st);
- Ok(BufReader::new(child.stdout.unwrap())
- .lines()
- .progress_with(bar)
- .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?))
- .collect())
+ Ok(res)
}
}