aboutsummaryrefslogtreecommitdiff
path: root/src/embedders.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/embedders.rs')
-rw-r--r--src/embedders.rs147
1 files changed, 76 insertions, 71 deletions
diff --git a/src/embedders.rs b/src/embedders.rs
index 94fc122..0693b5e 100644
--- a/src/embedders.rs
+++ b/src/embedders.rs
@@ -1,7 +1,8 @@
use rayon::prelude::*;
use std::path::PathBuf;
+use serde::{Deserialize, Serialize};
-pub trait MetricElem {
+pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> {
fn dist(&self, _: &Self) -> f64;
}
@@ -11,39 +12,55 @@ impl MetricElem for f64 {
}
}
-pub trait EmbedderT {
+pub trait EmbedderT: Send + Sync {
type Embedding: MetricElem;
+ const NAME: &'static str;
- fn embed(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>;
+ fn embed(&self, _: &PathBuf) -> Result<Self::Embedding, String>;
+}
+
+pub trait BatchEmbedder: Send + Sync {
+ type Embedding: MetricElem;
+ const NAME: &'static str;
+
+ fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>;
+}
+
+impl<T: EmbedderT> BatchEmbedder for T {
+ type Embedding = T::Embedding;
+ const NAME: &'static str = T::NAME;
+
+ fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>, String> {
+ paths.par_iter()
+ .map(|p| self.embed(p))
+ .collect::<Vec<_>>()
+ .into_iter()
+ .try_collect()
+ }
}
pub struct BrightnessEmbedder;
impl EmbedderT for BrightnessEmbedder {
type Embedding = f64;
+ const NAME: &'static str = "Brightness";
- fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<f64>, String> {
- paths
- .par_iter()
- .map(|p| {
- let im = image::open(p).map_err(|e| e.to_string())?;
- let num_bytes = 3 * (im.height() * im.width());
+ fn embed(&self, path: &PathBuf) -> Result<f64, String> {
+ let im = image::open(path).map_err(|e| e.to_string())?;
+ let num_bytes = 3 * (im.height() * im.width());
- if num_bytes == 0 {
- Err("Encountered NaN brightness, due to an empty image")?;
- }
+ if num_bytes == 0 {
+ return Err("Encountered NaN brightness, due to an empty image".to_string());
+ }
- Ok(im.to_rgb8()
- .iter()
- .map(|e| *e as u64)
- .sum::<u64>() as f64 / num_bytes as f64)
- })
- .collect::<Vec<_>>()
- .into_iter()
- .try_collect()
+ Ok(im.to_rgb8()
+ .iter()
+ .map(|e| *e as u64)
+ .sum::<u64>() as f64 / num_bytes as f64)
}
}
#[repr(transparent)]
+#[derive(Serialize, Deserialize)]
pub struct Hue (f64);
impl MetricElem for Hue {
fn dist(&self, b: &Hue) -> f64 {
@@ -55,42 +72,36 @@ impl MetricElem for Hue {
pub struct HueEmbedder;
impl EmbedderT for HueEmbedder {
type Embedding = Hue;
+ const NAME: &'static str = "Hue";
- fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<Hue>, String> {
- paths
- .par_iter()
- .map(|p| {
- let im = image::open(p).map_err(|e| e.to_string())?;
- let num_pixels = im.height() * im.width();
- let [sr, sg, sb] = im
- .to_rgb8()
- .pixels()
- .fold([0, 0, 0], |[or, og, ob], n| {
- let [nr, ng, nb] = n.0;
- [or + nr as u64, og + ng as u64, ob + nb as u64]
- })
- .map(|e| e as f64 / 255. / num_pixels as f64);
+ fn embed(&self, path: &PathBuf) -> Result<Hue, String> {
+ let im = image::open(path).map_err(|e| e.to_string())?;
+ let num_pixels = im.height() * im.width();
+ let [sr, sg, sb] = im
+ .to_rgb8()
+ .pixels()
+ .fold([0, 0, 0], |[or, og, ob], n| {
+ let [nr, ng, nb] = n.0;
+ [or + nr as u64, og + ng as u64, ob + nb as u64]
+ })
+ .map(|e| e as f64 / 255. / num_pixels as f64);
- let hue =
- if sr >= sg && sr >= sb {
- (sg - sb) / (sr - sg.min(sb))
- }
- else if sg >= sb {
- 2. + (sb - sr) / (sg - sr.min(sb))
- }
- else {
- 4. + (sr - sg) / (sb - sr.min(sg))
- };
+ let hue =
+ if sr >= sg && sr >= sb {
+ (sg - sb) / (sr - sg.min(sb))
+ }
+ else if sg >= sb {
+ 2. + (sb - sr) / (sg - sr.min(sb))
+ }
+ else {
+ 4. + (sr - sg) / (sb - sr.min(sg))
+ };
- if hue.is_nan() {
- Err("Encountered NaN hue, possibly because of a colorless or empty image")?;
- }
+ if hue.is_nan() {
+ return Err("Encountered NaN hue, possibly because of a colorless or empty image".to_string());
+ }
- Ok(Hue(hue))
- })
- .collect::<Vec<_>>()
- .into_iter()
- .try_collect()
+ Ok(Hue(hue))
}
}
@@ -105,26 +116,20 @@ impl MetricElem for (f64, f64, f64) {
pub struct ColorEmbedder;
impl EmbedderT for ColorEmbedder {
type Embedding = (f64, f64, f64);
+ const NAME: &'static str = "Color";
- fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<(f64, f64, f64)>, String> {
- paths
- .par_iter()
- .map(|p| {
- let im = image::open(p).map_err(|e| e.to_string())?;
- let num_pixels = im.height() * im.width();
- let [sr, sg, sb] = im
- .to_rgb8()
- .pixels()
- .fold([0, 0, 0], |[or, og, ob], n| {
- let [nr, ng, nb] = n.0;
- [or + nr as u64, og + ng as u64, ob + nb as u64]
- })
- .map(|e| e as f64 / num_pixels as f64);
-
- Ok((sr, sg, sb))
+ fn embed(&self, path: &PathBuf) -> Result<(f64, f64, f64), String> {
+ let im = image::open(path).map_err(|e| e.to_string())?;
+ let num_pixels = im.height() * im.width();
+ let [sr, sg, sb] = im
+ .to_rgb8()
+ .pixels()
+ .fold([0, 0, 0], |[or, og, ob], n| {
+ let [nr, ng, nb] = n.0;
+ [or + nr as u64, og + ng as u64, ob + nb as u64]
})
- .collect::<Vec<_>>()
- .into_iter()
- .try_collect()
+ .map(|e| e as f64 / num_pixels as f64);
+
+ Ok((sr, sg, sb))
}
}