From fbfee0a2bb436a6205d67f561dbd6284621504d6 Mon Sep 17 00:00:00 2001 From: metamuffin Date: Wed, 20 Sep 2023 16:55:50 +0200 Subject: move embedder to module --- src/embedders/mod.rs | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 src/embedders/mod.rs (limited to 'src/embedders/mod.rs') diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs new file mode 100644 index 0000000..353222b --- /dev/null +++ b/src/embedders/mod.rs @@ -0,0 +1,51 @@ +pub mod ai; +pub mod pure; +pub(crate) use ai::*; +pub(crate) use pure::*; + +use anyhow::Result; +use indicatif::{ParallelProgressIterator, ProgressStyle}; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; + +pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> { + fn dist(&self, _: &Self) -> f64; +} + +impl MetricElem for f64 { + fn dist(&self, b: &f64) -> f64 { + (self - b).abs() + } +} + +pub trait EmbedderT: Send + Sync { + type Embedding: MetricElem; + const NAME: &'static str; + + fn embed(&self, _: &Path) -> Result; +} + +pub trait BatchEmbedder: Send + Sync { + type Embedding: MetricElem; + const NAME: &'static str; + + fn embeds(&mut self, _: &[PathBuf]) -> Result>; +} + +impl BatchEmbedder for T { + type Embedding = T::Embedding; + const NAME: &'static str = T::NAME; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result> { + let st = + ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + paths + .par_iter() + .progress_with_style(st) + .map(|p| self.embed(p)) + .collect::>() + .into_iter() + .try_collect() + } +} -- cgit v1.2.3-70-g09d2