diff options
Diffstat (limited to 'src/embedders.rs')
-rw-r--r-- | src/embedders.rs | 147 |
1 files changed, 76 insertions, 71 deletions
diff --git a/src/embedders.rs b/src/embedders.rs index 94fc122..0693b5e 100644 --- a/src/embedders.rs +++ b/src/embedders.rs @@ -1,7 +1,8 @@ use rayon::prelude::*; use std::path::PathBuf; +use serde::{Deserialize, Serialize}; -pub trait MetricElem { +pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> { fn dist(&self, _: &Self) -> f64; } @@ -11,39 +12,55 @@ impl MetricElem for f64 { } } -pub trait EmbedderT { +pub trait EmbedderT: Send + Sync { type Embedding: MetricElem; + const NAME: &'static str; - fn embed(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>; + fn embed(&self, _: &PathBuf) -> Result<Self::Embedding, String>; +} + +pub trait BatchEmbedder: Send + Sync { + type Embedding: MetricElem; + const NAME: &'static str; + + fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>; +} + +impl<T: EmbedderT> BatchEmbedder for T { + type Embedding = T::Embedding; + const NAME: &'static str = T::NAME; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>, String> { + paths.par_iter() + .map(|p| self.embed(p)) + .collect::<Vec<_>>() + .into_iter() + .try_collect() + } } pub struct BrightnessEmbedder; impl EmbedderT for BrightnessEmbedder { type Embedding = f64; + const NAME: &'static str = "Brightness"; - fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<f64>, String> { - paths - .par_iter() - .map(|p| { - let im = image::open(p).map_err(|e| e.to_string())?; - let num_bytes = 3 * (im.height() * im.width()); + fn embed(&self, path: &PathBuf) -> Result<f64, String> { + let im = image::open(path).map_err(|e| e.to_string())?; + let num_bytes = 3 * (im.height() * im.width()); - if num_bytes == 0 { - Err("Encountered NaN brightness, due to an empty image")?; - } + if num_bytes == 0 { + return Err("Encountered NaN brightness, due to an empty image".to_string()); + } - Ok(im.to_rgb8() - .iter() - .map(|e| *e as u64) - .sum::<u64>() as f64 / num_bytes as f64) - }) - .collect::<Vec<_>>() - .into_iter() - .try_collect() + Ok(im.to_rgb8() + .iter() + .map(|e| *e as u64) + .sum::<u64>() as f64 / num_bytes as f64) } } #[repr(transparent)] +#[derive(Serialize, Deserialize)] pub struct Hue (f64); impl MetricElem for Hue { fn dist(&self, b: &Hue) -> f64 { @@ -55,42 +72,36 @@ impl MetricElem for Hue { pub struct HueEmbedder; impl EmbedderT for HueEmbedder { type Embedding = Hue; + const NAME: &'static str = "Hue"; - fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<Hue>, String> { - paths - .par_iter() - .map(|p| { - let im = image::open(p).map_err(|e| e.to_string())?; - let num_pixels = im.height() * im.width(); - let [sr, sg, sb] = im - .to_rgb8() - .pixels() - .fold([0, 0, 0], |[or, og, ob], n| { - let [nr, ng, nb] = n.0; - [or + nr as u64, og + ng as u64, ob + nb as u64] - }) - .map(|e| e as f64 / 255. / num_pixels as f64); + fn embed(&self, path: &PathBuf) -> Result<Hue, String> { + let im = image::open(path).map_err(|e| e.to_string())?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] + }) + .map(|e| e as f64 / 255. / num_pixels as f64); - let hue = - if sr >= sg && sr >= sb { - (sg - sb) / (sr - sg.min(sb)) - } - else if sg >= sb { - 2. + (sb - sr) / (sg - sr.min(sb)) - } - else { - 4. + (sr - sg) / (sb - sr.min(sg)) - }; + let hue = + if sr >= sg && sr >= sb { + (sg - sb) / (sr - sg.min(sb)) + } + else if sg >= sb { + 2. + (sb - sr) / (sg - sr.min(sb)) + } + else { + 4. + (sr - sg) / (sb - sr.min(sg)) + }; - if hue.is_nan() { - Err("Encountered NaN hue, possibly because of a colorless or empty image")?; - } + if hue.is_nan() { + return Err("Encountered NaN hue, possibly because of a colorless or empty image".to_string()); + } - Ok(Hue(hue)) - }) - .collect::<Vec<_>>() - .into_iter() - .try_collect() + Ok(Hue(hue)) } } @@ -105,26 +116,20 @@ impl MetricElem for (f64, f64, f64) { pub struct ColorEmbedder; impl EmbedderT for ColorEmbedder { type Embedding = (f64, f64, f64); + const NAME: &'static str = "Color"; - fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<(f64, f64, f64)>, String> { - paths - .par_iter() - .map(|p| { - let im = image::open(p).map_err(|e| e.to_string())?; - let num_pixels = im.height() * im.width(); - let [sr, sg, sb] = im - .to_rgb8() - .pixels() - .fold([0, 0, 0], |[or, og, ob], n| { - let [nr, ng, nb] = n.0; - [or + nr as u64, og + ng as u64, ob + nb as u64] - }) - .map(|e| e as f64 / num_pixels as f64); - - Ok((sr, sg, sb)) + fn embed(&self, path: &PathBuf) -> Result<(f64, f64, f64), String> { + let im = image::open(path).map_err(|e| e.to_string())?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] }) - .collect::<Vec<_>>() - .into_iter() - .try_collect() + .map(|e| e as f64 / num_pixels as f64); + + Ok((sr, sg, sb)) } } |