use anyhow::{anyhow, Result}; use fastembed::{ImageEmbedding, ImageInitOptions}; use indicatif::{ProgressBar, ProgressStyle}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use std::{marker::PhantomData, path::PathBuf}; use super::vecmetric::VecMetric; use crate::{BatchEmbedder, Config}; pub(crate) struct ContentEmbedder<'a, Metric> { cfg: &'a Config, _sim: PhantomData, } impl<'a, Metric> ContentEmbedder<'a, Metric> { pub(crate) fn new(cfg: &'a Config) -> Self { ContentEmbedder { cfg, _sim: PhantomData } } } impl ContentEmbedder<'_, Metric> { fn embeds_or_err( &mut self, paths: &[PathBuf], ) -> Result::Embedding>>> { let mut options = ImageInitOptions::default(); options.cache_dir = self.cfg.base_dirs.get_cache_home(); let embedder = ImageEmbedding::try_new(options)?; let bar = ProgressBar::new(paths.len() as u64); bar.set_style(ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} {msg}")?); bar.enable_steady_tick(std::time::Duration::from_millis(100)); bar.set_message("Embedding images..."); let mut res = Vec::with_capacity(paths.len()); // fastembeds supports batched processing, but does not support error reporting on a // per-image basis. Thus, we first try embedding 64 images at once, and if that fails, fall // back to passing them to fastembeds one-by-one, so that we can get all the non-failure // results. for chunk in paths.chunks(64) { match embedder.embed(chunk.iter().collect(), Some(8)) { Ok(embeds) => res.extend(embeds.into_iter().map(|e| Ok(e.into()))), Err(_) => { // embed one by one let mut embeds = chunk .par_iter() .map(|path| match embedder.embed(vec![path], Some(1)) { Err(e) => Err(e), Ok(mut embed) if embed.len() == 1 => Ok(embed.pop().unwrap().into()), Ok(embed) => { Err(anyhow!("Embedder did not return a single value: {embed:?}")) } }) .collect(); res.append(&mut embeds); } } bar.inc(64); } Ok(res) } } impl BatchEmbedder for ContentEmbedder<'_, Metric> { type Embedding = Metric; const NAME: &'static str = "imgbeddings"; fn embeds(&mut self, paths: &[PathBuf]) -> Vec> { match self.embeds_or_err(paths) { Ok(embeddings) => embeddings, Err(e) => vec![Err(e)], } } }