diff options
author | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-06 13:37:22 +0200 |
---|---|---|
committer | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-06 13:37:22 +0200 |
commit | cb67cd7389b68750518b67c7a58beb7659298352 (patch) | |
tree | af6cb44cb47dacb82127f5bcf586c5b60aa4603a /src | |
parent | 3fe8109102083ef32da9e11eef5dc45dba530333 (diff) | |
download | embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar.bz2 embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar.zst |
cache embeddings by path (should be by hash, but thats for later)
Diffstat (limited to 'src')
-rw-r--r-- | src/embedders.rs | 147 | ||||
-rw-r--r-- | src/main.rs | 104 |
2 files changed, 147 insertions, 104 deletions
diff --git a/src/embedders.rs b/src/embedders.rs index 94fc122..0693b5e 100644 --- a/src/embedders.rs +++ b/src/embedders.rs @@ -1,7 +1,8 @@ use rayon::prelude::*; use std::path::PathBuf; +use serde::{Deserialize, Serialize}; -pub trait MetricElem { +pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> { fn dist(&self, _: &Self) -> f64; } @@ -11,39 +12,55 @@ impl MetricElem for f64 { } } -pub trait EmbedderT { +pub trait EmbedderT: Send + Sync { type Embedding: MetricElem; + const NAME: &'static str; - fn embed(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>; + fn embed(&self, _: &PathBuf) -> Result<Self::Embedding, String>; +} + +pub trait BatchEmbedder: Send + Sync { + type Embedding: MetricElem; + const NAME: &'static str; + + fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>; +} + +impl<T: EmbedderT> BatchEmbedder for T { + type Embedding = T::Embedding; + const NAME: &'static str = T::NAME; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>, String> { + paths.par_iter() + .map(|p| self.embed(p)) + .collect::<Vec<_>>() + .into_iter() + .try_collect() + } } pub struct BrightnessEmbedder; impl EmbedderT for BrightnessEmbedder { type Embedding = f64; + const NAME: &'static str = "Brightness"; - fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<f64>, String> { - paths - .par_iter() - .map(|p| { - let im = image::open(p).map_err(|e| e.to_string())?; - let num_bytes = 3 * (im.height() * im.width()); + fn embed(&self, path: &PathBuf) -> Result<f64, String> { + let im = image::open(path).map_err(|e| e.to_string())?; + let num_bytes = 3 * (im.height() * im.width()); - if num_bytes == 0 { - Err("Encountered NaN brightness, due to an empty image")?; - } + if num_bytes == 0 { + return Err("Encountered NaN brightness, due to an empty image".to_string()); + } - Ok(im.to_rgb8() - .iter() - .map(|e| *e as u64) - .sum::<u64>() as f64 / num_bytes as f64) - }) - .collect::<Vec<_>>() - .into_iter() - .try_collect() + Ok(im.to_rgb8() + .iter() + .map(|e| *e as u64) + .sum::<u64>() as f64 / num_bytes as f64) } } #[repr(transparent)] +#[derive(Serialize, Deserialize)] pub struct Hue (f64); impl MetricElem for Hue { fn dist(&self, b: &Hue) -> f64 { @@ -55,42 +72,36 @@ impl MetricElem for Hue { pub struct HueEmbedder; impl EmbedderT for HueEmbedder { type Embedding = Hue; + const NAME: &'static str = "Hue"; - fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<Hue>, String> { - paths - .par_iter() - .map(|p| { - let im = image::open(p).map_err(|e| e.to_string())?; - let num_pixels = im.height() * im.width(); - let [sr, sg, sb] = im - .to_rgb8() - .pixels() - .fold([0, 0, 0], |[or, og, ob], n| { - let [nr, ng, nb] = n.0; - [or + nr as u64, og + ng as u64, ob + nb as u64] - }) - .map(|e| e as f64 / 255. / num_pixels as f64); + fn embed(&self, path: &PathBuf) -> Result<Hue, String> { + let im = image::open(path).map_err(|e| e.to_string())?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] + }) + .map(|e| e as f64 / 255. / num_pixels as f64); - let hue = - if sr >= sg && sr >= sb { - (sg - sb) / (sr - sg.min(sb)) - } - else if sg >= sb { - 2. + (sb - sr) / (sg - sr.min(sb)) - } - else { - 4. + (sr - sg) / (sb - sr.min(sg)) - }; + let hue = + if sr >= sg && sr >= sb { + (sg - sb) / (sr - sg.min(sb)) + } + else if sg >= sb { + 2. + (sb - sr) / (sg - sr.min(sb)) + } + else { + 4. + (sr - sg) / (sb - sr.min(sg)) + }; - if hue.is_nan() { - Err("Encountered NaN hue, possibly because of a colorless or empty image")?; - } + if hue.is_nan() { + return Err("Encountered NaN hue, possibly because of a colorless or empty image".to_string()); + } - Ok(Hue(hue)) - }) - .collect::<Vec<_>>() - .into_iter() - .try_collect() + Ok(Hue(hue)) } } @@ -105,26 +116,20 @@ impl MetricElem for (f64, f64, f64) { pub struct ColorEmbedder; impl EmbedderT for ColorEmbedder { type Embedding = (f64, f64, f64); + const NAME: &'static str = "Color"; - fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<(f64, f64, f64)>, String> { - paths - .par_iter() - .map(|p| { - let im = image::open(p).map_err(|e| e.to_string())?; - let num_pixels = im.height() * im.width(); - let [sr, sg, sb] = im - .to_rgb8() - .pixels() - .fold([0, 0, 0], |[or, og, ob], n| { - let [nr, ng, nb] = n.0; - [or + nr as u64, og + ng as u64, ob + nb as u64] - }) - .map(|e| e as f64 / num_pixels as f64); - - Ok((sr, sg, sb)) + fn embed(&self, path: &PathBuf) -> Result<(f64, f64, f64), String> { + let im = image::open(path).map_err(|e| e.to_string())?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] }) - .collect::<Vec<_>>() - .into_iter() - .try_collect() + .map(|e| e as f64 / num_pixels as f64); + + Ok((sr, sg, sb)) } } diff --git a/src/main.rs b/src/main.rs index a28885d..c4a0c26 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,20 +25,20 @@ struct Args { images: Vec<PathBuf>, } -//#[derive(Debug)] -//struct Config { -// base_dir: xdg::BaseDirectories, -//} -// -//fn get_config() -> Result<Config, String> { -// let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort") -// .map_err(|_| "oh no")?; -// -// Ok(Config{base_dir: dirs}) -//} +#[derive(Debug)] +struct Config { + base_dirs: xdg::BaseDirectories, +} -fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String> - where E: EmbedderT +fn get_config() -> Result<Config, String> { + let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort") + .map_err(|_| "oh no")?; + + Ok(Config{base_dirs: dirs}) +} + +fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>> + where M: MetricElem { // wrapper struct to // - reverse the ordering @@ -58,31 +58,23 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String> } } - let num_ims = args.images.len(); - if num_ims == 0 { - return Ok(Vec::new()); - } - - let embeds: Vec<_> = e - .embed(&args.images)? - .into_iter() - .collect(); + let num_embeds = embeds.len(); let mut possible_edges = - PriorityQueue::with_capacity((num_ims * num_ims - num_ims) / 2); - let mut mst = HashMap::with_capacity(num_ims); + PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2); + let mut mst = HashMap::with_capacity(num_embeds); // here, we start at 0. // we might get a better result in the end if we started with a vertex next // to the lowest-cost edge, but we don't know which one that is (though we // could compute that without changing our asymptotic complexity) mst.insert(0, Vec::new()); - for i in 1..num_ims { + for i in 1..num_embeds { possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i]))); } // prims algorithm or something like that - while mst.len() < num_ims { + while mst.len() < num_embeds { // find the edge with the least cost that connects us to a new vertex let (new, old) = loop { let ((a, b), _) = possible_edges.pop().unwrap(); @@ -97,7 +89,7 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String> // insert all the new edges we could take mst.entry(old).and_modify(|v|v.push(new)); - for i in 0..num_ims { + for i in 0..num_embeds { // don't consider edges taking us to nodes we already visited if mst.contains_key(&i) { continue; @@ -107,25 +99,71 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String> } } - // find TSP approximation via DFS through the MST + mst +} + +fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> { fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { into.push(cur); t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); } - let mut tsp_path = Vec::with_capacity(num_ims); + let mut tsp_path = Vec::with_capacity(mst.len()); dfs(0, &mst, &mut tsp_path); + tsp_path +} + +fn process_embedder<E>(mut e: E, args: Args, cfg: Config) -> Result<Vec<PathBuf>, String> + where E: EmbedderT +{ + if args.images.len() == 0 { + return Ok(Vec::new()); + } + + let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db") + .map_err(|e| e.to_string())?).map_err(|e| e.to_string())?; + let tree = typed_sled::Tree::<PathBuf, E::Embedding>::open(&db, E::NAME); + + // TODO nicht pfad, sondern hash vom bild als key nehmen + let mut embeds: Vec<Option<_>> = args.images + .iter() + .map(|p| tree.get(p).map_err(|e| e.to_string())) + .try_collect()?; + + let missing_embeds_indices: Vec<_> = embeds + .iter() + .enumerate() + .filter_map(|(i, v)| match v { + None => Some(i), + Some(_) => None, + }).collect(); + let missing_embeds = e.embeds(&missing_embeds_indices + .iter() + .map(|i| args.images[*i].clone()) + .collect::<Vec<_>>())?; + + for (idx, emb) in missing_embeds_indices + .into_iter().zip(missing_embeds.into_iter()) + { + // TODO hier auch hash statt pfad + tree.insert(&args.images[idx], &emb).map_err(|e| e.to_string())?; + embeds[idx] = Some(emb); + } + + let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); + let tsp_path = tsp_from_mst(get_mst(&embeds)); + Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect()) } fn main() -> Result<(), String> { - //let cfg = get_config()?; + let cfg = get_config()?; let args = Args::parse(); let tsp_path = match args.embedder { - Embedder::Brightness => process_embedder(BrightnessEmbedder, args), - Embedder::Hue => process_embedder(HueEmbedder, args), - Embedder::Color => process_embedder(ColorEmbedder, args), + Embedder::Brightness => process_embedder(BrightnessEmbedder, args, cfg), + Embedder::Hue => process_embedder(HueEmbedder, args, cfg), + Embedder::Color => process_embedder(ColorEmbedder, args, cfg), }?; for p in tsp_path { |