diff options
Diffstat (limited to 'src/embedders/mod.rs')
-rw-r--r-- | src/embedders/mod.rs | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs new file mode 100644 index 0000000..353222b --- /dev/null +++ b/src/embedders/mod.rs @@ -0,0 +1,51 @@ +pub mod ai; +pub mod pure; +pub(crate) use ai::*; +pub(crate) use pure::*; + +use anyhow::Result; +use indicatif::{ParallelProgressIterator, ProgressStyle}; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; + +pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> { + fn dist(&self, _: &Self) -> f64; +} + +impl MetricElem for f64 { + fn dist(&self, b: &f64) -> f64 { + (self - b).abs() + } +} + +pub trait EmbedderT: Send + Sync { + type Embedding: MetricElem; + const NAME: &'static str; + + fn embed(&self, _: &Path) -> Result<Self::Embedding>; +} + +pub trait BatchEmbedder: Send + Sync { + type Embedding: MetricElem; + const NAME: &'static str; + + fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>>; +} + +impl<T: EmbedderT> BatchEmbedder for T { + type Embedding = T::Embedding; + const NAME: &'static str = T::NAME; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { + let st = + ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + paths + .par_iter() + .progress_with_style(st) + .map(|p| self.embed(p)) + .collect::<Vec<_>>() + .into_iter() + .try_collect() + } +} |