diff options
author | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-06 13:37:22 +0200 |
---|---|---|
committer | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-06 13:37:22 +0200 |
commit | cb67cd7389b68750518b67c7a58beb7659298352 (patch) | |
tree | af6cb44cb47dacb82127f5bcf586c5b60aa4603a | |
parent | 3fe8109102083ef32da9e11eef5dc45dba530333 (diff) | |
download | embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar.bz2 embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar.zst |
cache embeddings by path (should be by hash, but thats for later)
-rw-r--r-- | Cargo.lock | 156 | ||||
-rw-r--r-- | Cargo.toml | 3 | ||||
-rw-r--r-- | src/embedders.rs | 147 | ||||
-rw-r--r-- | src/main.rs | 104 |
4 files changed, 306 insertions, 104 deletions
@@ -63,6 +63,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] name = "bit_field" version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -236,6 +245,9 @@ dependencies = [ "indicatif", "priority-queue", "rayon", + "serde", + "sled", + "typed-sled", "xdg", ] @@ -294,6 +306,16 @@ dependencies = [ ] [[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + +[[package]] name = "futures-core" version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -306,6 +328,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" [[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + +[[package]] name = "getrandom" version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -539,6 +570,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] +name = "parking_lot" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall", + "smallvec", + "winapi", +] + +[[package]] name = "pin-project" version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -637,18 +693,63 @@ dependencies = [ ] [[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + +[[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] +name = "serde" +version = "1.0.188" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.188" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] name = "simd-adler32" version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" [[package]] +name = "sled" +version = "0.34.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935" +dependencies = [ + "crc32fast", + "crossbeam-epoch", + "crossbeam-utils", + "fs2", + "fxhash", + "libc", + "log", + "parking_lot", +] + +[[package]] name = "smallvec" version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -681,6 +782,26 @@ dependencies = [ ] [[package]] +name = "thiserror" +version = "1.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] name = "tiff" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -692,6 +813,19 @@ dependencies = [ ] [[package]] +name = "typed-sled" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1060f05a4450ec5b758da60951b04f225a93a62079316630e76cf25c4034500d" +dependencies = [ + "bincode", + "pin-project", + "serde", + "sled", + "thiserror", +] + +[[package]] name = "unicode-ident" version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -776,6 +910,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb" [[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] name = "windows-sys" version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -12,3 +12,6 @@ clap = {version = "4", features = ["derive"]} priority-queue = "1" rayon = "1" indicatif = "0" +sled = "0" +typed-sled = "0" +serde = "1" diff --git a/src/embedders.rs b/src/embedders.rs index 94fc122..0693b5e 100644 --- a/src/embedders.rs +++ b/src/embedders.rs @@ -1,7 +1,8 @@ use rayon::prelude::*; use std::path::PathBuf; +use serde::{Deserialize, Serialize}; -pub trait MetricElem { +pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> { fn dist(&self, _: &Self) -> f64; } @@ -11,39 +12,55 @@ impl MetricElem for f64 { } } -pub trait EmbedderT { +pub trait EmbedderT: Send + Sync { type Embedding: MetricElem; + const NAME: &'static str; - fn embed(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>; + fn embed(&self, _: &PathBuf) -> Result<Self::Embedding, String>; +} + +pub trait BatchEmbedder: Send + Sync { + type Embedding: MetricElem; + const NAME: &'static str; + + fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>; +} + +impl<T: EmbedderT> BatchEmbedder for T { + type Embedding = T::Embedding; + const NAME: &'static str = T::NAME; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>, String> { + paths.par_iter() + .map(|p| self.embed(p)) + .collect::<Vec<_>>() + .into_iter() + .try_collect() + } } pub struct BrightnessEmbedder; impl EmbedderT for BrightnessEmbedder { type Embedding = f64; + const NAME: &'static str = "Brightness"; - fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<f64>, String> { - paths - .par_iter() - .map(|p| { - let im = image::open(p).map_err(|e| e.to_string())?; - let num_bytes = 3 * (im.height() * im.width()); + fn embed(&self, path: &PathBuf) -> Result<f64, String> { + let im = image::open(path).map_err(|e| e.to_string())?; + let num_bytes = 3 * (im.height() * im.width()); - if num_bytes == 0 { - Err("Encountered NaN brightness, due to an empty image")?; - } + if num_bytes == 0 { + return Err("Encountered NaN brightness, due to an empty image".to_string()); + } - Ok(im.to_rgb8() - .iter() - .map(|e| *e as u64) - .sum::<u64>() as f64 / num_bytes as f64) - }) - .collect::<Vec<_>>() - .into_iter() - .try_collect() + Ok(im.to_rgb8() + .iter() + .map(|e| *e as u64) + .sum::<u64>() as f64 / num_bytes as f64) } } #[repr(transparent)] +#[derive(Serialize, Deserialize)] pub struct Hue (f64); impl MetricElem for Hue { fn dist(&self, b: &Hue) -> f64 { @@ -55,42 +72,36 @@ impl MetricElem for Hue { pub struct HueEmbedder; impl EmbedderT for HueEmbedder { type Embedding = Hue; + const NAME: &'static str = "Hue"; - fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<Hue>, String> { - paths - .par_iter() - .map(|p| { - let im = image::open(p).map_err(|e| e.to_string())?; - let num_pixels = im.height() * im.width(); - let [sr, sg, sb] = im - .to_rgb8() - .pixels() - .fold([0, 0, 0], |[or, og, ob], n| { - let [nr, ng, nb] = n.0; - [or + nr as u64, og + ng as u64, ob + nb as u64] - }) - .map(|e| e as f64 / 255. / num_pixels as f64); + fn embed(&self, path: &PathBuf) -> Result<Hue, String> { + let im = image::open(path).map_err(|e| e.to_string())?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] + }) + .map(|e| e as f64 / 255. / num_pixels as f64); - let hue = - if sr >= sg && sr >= sb { - (sg - sb) / (sr - sg.min(sb)) - } - else if sg >= sb { - 2. + (sb - sr) / (sg - sr.min(sb)) - } - else { - 4. + (sr - sg) / (sb - sr.min(sg)) - }; + let hue = + if sr >= sg && sr >= sb { + (sg - sb) / (sr - sg.min(sb)) + } + else if sg >= sb { + 2. + (sb - sr) / (sg - sr.min(sb)) + } + else { + 4. + (sr - sg) / (sb - sr.min(sg)) + }; - if hue.is_nan() { - Err("Encountered NaN hue, possibly because of a colorless or empty image")?; - } + if hue.is_nan() { + return Err("Encountered NaN hue, possibly because of a colorless or empty image".to_string()); + } - Ok(Hue(hue)) - }) - .collect::<Vec<_>>() - .into_iter() - .try_collect() + Ok(Hue(hue)) } } @@ -105,26 +116,20 @@ impl MetricElem for (f64, f64, f64) { pub struct ColorEmbedder; impl EmbedderT for ColorEmbedder { type Embedding = (f64, f64, f64); + const NAME: &'static str = "Color"; - fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<(f64, f64, f64)>, String> { - paths - .par_iter() - .map(|p| { - let im = image::open(p).map_err(|e| e.to_string())?; - let num_pixels = im.height() * im.width(); - let [sr, sg, sb] = im - .to_rgb8() - .pixels() - .fold([0, 0, 0], |[or, og, ob], n| { - let [nr, ng, nb] = n.0; - [or + nr as u64, og + ng as u64, ob + nb as u64] - }) - .map(|e| e as f64 / num_pixels as f64); - - Ok((sr, sg, sb)) + fn embed(&self, path: &PathBuf) -> Result<(f64, f64, f64), String> { + let im = image::open(path).map_err(|e| e.to_string())?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] }) - .collect::<Vec<_>>() - .into_iter() - .try_collect() + .map(|e| e as f64 / num_pixels as f64); + + Ok((sr, sg, sb)) } } diff --git a/src/main.rs b/src/main.rs index a28885d..c4a0c26 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,20 +25,20 @@ struct Args { images: Vec<PathBuf>, } -//#[derive(Debug)] -//struct Config { -// base_dir: xdg::BaseDirectories, -//} -// -//fn get_config() -> Result<Config, String> { -// let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort") -// .map_err(|_| "oh no")?; -// -// Ok(Config{base_dir: dirs}) -//} +#[derive(Debug)] +struct Config { + base_dirs: xdg::BaseDirectories, +} -fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String> - where E: EmbedderT +fn get_config() -> Result<Config, String> { + let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort") + .map_err(|_| "oh no")?; + + Ok(Config{base_dirs: dirs}) +} + +fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>> + where M: MetricElem { // wrapper struct to // - reverse the ordering @@ -58,31 +58,23 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String> } } - let num_ims = args.images.len(); - if num_ims == 0 { - return Ok(Vec::new()); - } - - let embeds: Vec<_> = e - .embed(&args.images)? - .into_iter() - .collect(); + let num_embeds = embeds.len(); let mut possible_edges = - PriorityQueue::with_capacity((num_ims * num_ims - num_ims) / 2); - let mut mst = HashMap::with_capacity(num_ims); + PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2); + let mut mst = HashMap::with_capacity(num_embeds); // here, we start at 0. // we might get a better result in the end if we started with a vertex next // to the lowest-cost edge, but we don't know which one that is (though we // could compute that without changing our asymptotic complexity) mst.insert(0, Vec::new()); - for i in 1..num_ims { + for i in 1..num_embeds { possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i]))); } // prims algorithm or something like that - while mst.len() < num_ims { + while mst.len() < num_embeds { // find the edge with the least cost that connects us to a new vertex let (new, old) = loop { let ((a, b), _) = possible_edges.pop().unwrap(); @@ -97,7 +89,7 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String> // insert all the new edges we could take mst.entry(old).and_modify(|v|v.push(new)); - for i in 0..num_ims { + for i in 0..num_embeds { // don't consider edges taking us to nodes we already visited if mst.contains_key(&i) { continue; @@ -107,25 +99,71 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String> } } - // find TSP approximation via DFS through the MST + mst +} + +fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> { fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { into.push(cur); t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); } - let mut tsp_path = Vec::with_capacity(num_ims); + let mut tsp_path = Vec::with_capacity(mst.len()); dfs(0, &mst, &mut tsp_path); + tsp_path +} + +fn process_embedder<E>(mut e: E, args: Args, cfg: Config) -> Result<Vec<PathBuf>, String> + where E: EmbedderT +{ + if args.images.len() == 0 { + return Ok(Vec::new()); + } + + let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db") + .map_err(|e| e.to_string())?).map_err(|e| e.to_string())?; + let tree = typed_sled::Tree::<PathBuf, E::Embedding>::open(&db, E::NAME); + + // TODO nicht pfad, sondern hash vom bild als key nehmen + let mut embeds: Vec<Option<_>> = args.images + .iter() + .map(|p| tree.get(p).map_err(|e| e.to_string())) + .try_collect()?; + + let missing_embeds_indices: Vec<_> = embeds + .iter() + .enumerate() + .filter_map(|(i, v)| match v { + None => Some(i), + Some(_) => None, + }).collect(); + let missing_embeds = e.embeds(&missing_embeds_indices + .iter() + .map(|i| args.images[*i].clone()) + .collect::<Vec<_>>())?; + + for (idx, emb) in missing_embeds_indices + .into_iter().zip(missing_embeds.into_iter()) + { + // TODO hier auch hash statt pfad + tree.insert(&args.images[idx], &emb).map_err(|e| e.to_string())?; + embeds[idx] = Some(emb); + } + + let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); + let tsp_path = tsp_from_mst(get_mst(&embeds)); + Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect()) } fn main() -> Result<(), String> { - //let cfg = get_config()?; + let cfg = get_config()?; let args = Args::parse(); let tsp_path = match args.embedder { - Embedder::Brightness => process_embedder(BrightnessEmbedder, args), - Embedder::Hue => process_embedder(HueEmbedder, args), - Embedder::Color => process_embedder(ColorEmbedder, args), + Embedder::Brightness => process_embedder(BrightnessEmbedder, args, cfg), + Embedder::Hue => process_embedder(HueEmbedder, args, cfg), + Embedder::Color => process_embedder(ColorEmbedder, args, cfg), }?; for p in tsp_path { |