aboutsummaryrefslogtreecommitdiff
path: root/src/embedders/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/embedders/mod.rs')
-rw-r--r--src/embedders/mod.rs51
1 files changed, 51 insertions, 0 deletions
diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs
new file mode 100644
index 0000000..353222b
--- /dev/null
+++ b/src/embedders/mod.rs
@@ -0,0 +1,51 @@
+pub mod ai;
+pub mod pure;
+pub(crate) use ai::*;
+pub(crate) use pure::*;
+
+use anyhow::Result;
+use indicatif::{ParallelProgressIterator, ProgressStyle};
+use rayon::prelude::*;
+use serde::{Deserialize, Serialize};
+use std::path::{Path, PathBuf};
+
+pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> {
+ fn dist(&self, _: &Self) -> f64;
+}
+
+impl MetricElem for f64 {
+ fn dist(&self, b: &f64) -> f64 {
+ (self - b).abs()
+ }
+}
+
+pub trait EmbedderT: Send + Sync {
+ type Embedding: MetricElem;
+ const NAME: &'static str;
+
+ fn embed(&self, _: &Path) -> Result<Self::Embedding>;
+}
+
+pub trait BatchEmbedder: Send + Sync {
+ type Embedding: MetricElem;
+ const NAME: &'static str;
+
+ fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>>;
+}
+
+impl<T: EmbedderT> BatchEmbedder for T {
+ type Embedding = T::Embedding;
+ const NAME: &'static str = T::NAME;
+
+ fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
+ let st =
+ ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
+ paths
+ .par_iter()
+ .progress_with_style(st)
+ .map(|p| self.embed(p))
+ .collect::<Vec<_>>()
+ .into_iter()
+ .try_collect()
+ }
+}