aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2023-09-06 13:37:22 +0200
committerLia Lenckowski <lialenck@protonmail.com>2023-09-06 13:37:22 +0200
commitcb67cd7389b68750518b67c7a58beb7659298352 (patch)
treeaf6cb44cb47dacb82127f5bcf586c5b60aa4603a
parent3fe8109102083ef32da9e11eef5dc45dba530333 (diff)
downloadembeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar
embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar.bz2
embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar.zst
cache embeddings by path (should be by hash, but that's for later)
-rw-r--r--Cargo.lock156
-rw-r--r--Cargo.toml3
-rw-r--r--src/embedders.rs147
-rw-r--r--src/main.rs104
4 files changed, 306 insertions, 104 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 339e4f7..1fab5d7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -63,6 +63,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
+name = "bincode"
+version = "1.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
+dependencies = [
+ "serde",
+]
+
+[[package]]
name = "bit_field"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -236,6 +245,9 @@ dependencies = [
"indicatif",
"priority-queue",
"rayon",
+ "serde",
+ "sled",
+ "typed-sled",
"xdg",
]
@@ -294,6 +306,16 @@ dependencies = [
]
[[package]]
+name = "fs2"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213"
+dependencies = [
+ "libc",
+ "winapi",
+]
+
+[[package]]
name = "futures-core"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -306,6 +328,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e"
[[package]]
+name = "fxhash"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
+dependencies = [
+ "byteorder",
+]
+
+[[package]]
name = "getrandom"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -539,6 +570,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
[[package]]
+name = "parking_lot"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99"
+dependencies = [
+ "instant",
+ "lock_api",
+ "parking_lot_core",
+]
+
+[[package]]
+name = "parking_lot_core"
+version = "0.8.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc"
+dependencies = [
+ "cfg-if",
+ "instant",
+ "libc",
+ "redox_syscall",
+ "smallvec",
+ "winapi",
+]
+
+[[package]]
name = "pin-project"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -637,18 +693,63 @@ dependencies = [
]
[[package]]
+name = "redox_syscall"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
+dependencies = [
+ "bitflags",
+]
+
+[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
+name = "serde"
+version = "1.0.188"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e"
+dependencies = [
+ "serde_derive",
+]
+
+[[package]]
+name = "serde_derive"
+version = "1.0.188"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
name = "simd-adler32"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
[[package]]
+name = "sled"
+version = "0.34.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f96b4737c2ce5987354855aed3797279def4ebf734436c6aa4552cf8e169935"
+dependencies = [
+ "crc32fast",
+ "crossbeam-epoch",
+ "crossbeam-utils",
+ "fs2",
+ "fxhash",
+ "libc",
+ "log",
+ "parking_lot",
+]
+
+[[package]]
name = "smallvec"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -681,6 +782,26 @@ dependencies = [
]
[[package]]
+name = "thiserror"
+version = "1.0.48"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7"
+dependencies = [
+ "thiserror-impl",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "1.0.48"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
name = "tiff"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -692,6 +813,19 @@ dependencies = [
]
[[package]]
+name = "typed-sled"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1060f05a4450ec5b758da60951b04f225a93a62079316630e76cf25c4034500d"
+dependencies = [
+ "bincode",
+ "pin-project",
+ "serde",
+ "sled",
+ "thiserror",
+]
+
+[[package]]
name = "unicode-ident"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -776,6 +910,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9193164d4de03a926d909d3bc7c30543cecb35400c02114792c2cae20d5e2dbb"
[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
+[[package]]
name = "windows-sys"
version = "0.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 3e3ce75..fc80898 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,3 +12,6 @@ clap = {version = "4", features = ["derive"]}
priority-queue = "1"
rayon = "1"
indicatif = "0"
+sled = "0"
+typed-sled = "0"
+serde = "1"
diff --git a/src/embedders.rs b/src/embedders.rs
index 94fc122..0693b5e 100644
--- a/src/embedders.rs
+++ b/src/embedders.rs
@@ -1,7 +1,8 @@
use rayon::prelude::*;
use std::path::PathBuf;
+use serde::{Deserialize, Serialize};
-pub trait MetricElem {
+pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> {
fn dist(&self, _: &Self) -> f64;
}
@@ -11,39 +12,55 @@ impl MetricElem for f64 {
}
}
-pub trait EmbedderT {
+pub trait EmbedderT: Send + Sync {
type Embedding: MetricElem;
+ const NAME: &'static str;
- fn embed(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>;
+ fn embed(&self, _: &PathBuf) -> Result<Self::Embedding, String>;
+}
+
+pub trait BatchEmbedder: Send + Sync {
+ type Embedding: MetricElem;
+ const NAME: &'static str;
+
+ fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>, String>;
+}
+
+impl<T: EmbedderT> BatchEmbedder for T {
+ type Embedding = T::Embedding;
+ const NAME: &'static str = T::NAME;
+
+ fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>, String> {
+ paths.par_iter()
+ .map(|p| self.embed(p))
+ .collect::<Vec<_>>()
+ .into_iter()
+ .try_collect()
+ }
}
pub struct BrightnessEmbedder;
impl EmbedderT for BrightnessEmbedder {
type Embedding = f64;
+ const NAME: &'static str = "Brightness";
- fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<f64>, String> {
- paths
- .par_iter()
- .map(|p| {
- let im = image::open(p).map_err(|e| e.to_string())?;
- let num_bytes = 3 * (im.height() * im.width());
+ fn embed(&self, path: &PathBuf) -> Result<f64, String> {
+ let im = image::open(path).map_err(|e| e.to_string())?;
+ let num_bytes = 3 * (im.height() * im.width());
- if num_bytes == 0 {
- Err("Encountered NaN brightness, due to an empty image")?;
- }
+ if num_bytes == 0 {
+ return Err("Encountered NaN brightness, due to an empty image".to_string());
+ }
- Ok(im.to_rgb8()
- .iter()
- .map(|e| *e as u64)
- .sum::<u64>() as f64 / num_bytes as f64)
- })
- .collect::<Vec<_>>()
- .into_iter()
- .try_collect()
+ Ok(im.to_rgb8()
+ .iter()
+ .map(|e| *e as u64)
+ .sum::<u64>() as f64 / num_bytes as f64)
}
}
#[repr(transparent)]
+#[derive(Serialize, Deserialize)]
pub struct Hue (f64);
impl MetricElem for Hue {
fn dist(&self, b: &Hue) -> f64 {
@@ -55,42 +72,36 @@ impl MetricElem for Hue {
pub struct HueEmbedder;
impl EmbedderT for HueEmbedder {
type Embedding = Hue;
+ const NAME: &'static str = "Hue";
- fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<Hue>, String> {
- paths
- .par_iter()
- .map(|p| {
- let im = image::open(p).map_err(|e| e.to_string())?;
- let num_pixels = im.height() * im.width();
- let [sr, sg, sb] = im
- .to_rgb8()
- .pixels()
- .fold([0, 0, 0], |[or, og, ob], n| {
- let [nr, ng, nb] = n.0;
- [or + nr as u64, og + ng as u64, ob + nb as u64]
- })
- .map(|e| e as f64 / 255. / num_pixels as f64);
+ fn embed(&self, path: &PathBuf) -> Result<Hue, String> {
+ let im = image::open(path).map_err(|e| e.to_string())?;
+ let num_pixels = im.height() * im.width();
+ let [sr, sg, sb] = im
+ .to_rgb8()
+ .pixels()
+ .fold([0, 0, 0], |[or, og, ob], n| {
+ let [nr, ng, nb] = n.0;
+ [or + nr as u64, og + ng as u64, ob + nb as u64]
+ })
+ .map(|e| e as f64 / 255. / num_pixels as f64);
- let hue =
- if sr >= sg && sr >= sb {
- (sg - sb) / (sr - sg.min(sb))
- }
- else if sg >= sb {
- 2. + (sb - sr) / (sg - sr.min(sb))
- }
- else {
- 4. + (sr - sg) / (sb - sr.min(sg))
- };
+ let hue =
+ if sr >= sg && sr >= sb {
+ (sg - sb) / (sr - sg.min(sb))
+ }
+ else if sg >= sb {
+ 2. + (sb - sr) / (sg - sr.min(sb))
+ }
+ else {
+ 4. + (sr - sg) / (sb - sr.min(sg))
+ };
- if hue.is_nan() {
- Err("Encountered NaN hue, possibly because of a colorless or empty image")?;
- }
+ if hue.is_nan() {
+ return Err("Encountered NaN hue, possibly because of a colorless or empty image".to_string());
+ }
- Ok(Hue(hue))
- })
- .collect::<Vec<_>>()
- .into_iter()
- .try_collect()
+ Ok(Hue(hue))
}
}
@@ -105,26 +116,20 @@ impl MetricElem for (f64, f64, f64) {
pub struct ColorEmbedder;
impl EmbedderT for ColorEmbedder {
type Embedding = (f64, f64, f64);
+ const NAME: &'static str = "Color";
- fn embed(&mut self, paths: &[PathBuf]) -> Result<Vec<(f64, f64, f64)>, String> {
- paths
- .par_iter()
- .map(|p| {
- let im = image::open(p).map_err(|e| e.to_string())?;
- let num_pixels = im.height() * im.width();
- let [sr, sg, sb] = im
- .to_rgb8()
- .pixels()
- .fold([0, 0, 0], |[or, og, ob], n| {
- let [nr, ng, nb] = n.0;
- [or + nr as u64, og + ng as u64, ob + nb as u64]
- })
- .map(|e| e as f64 / num_pixels as f64);
-
- Ok((sr, sg, sb))
+ fn embed(&self, path: &PathBuf) -> Result<(f64, f64, f64), String> {
+ let im = image::open(path).map_err(|e| e.to_string())?;
+ let num_pixels = im.height() * im.width();
+ let [sr, sg, sb] = im
+ .to_rgb8()
+ .pixels()
+ .fold([0, 0, 0], |[or, og, ob], n| {
+ let [nr, ng, nb] = n.0;
+ [or + nr as u64, og + ng as u64, ob + nb as u64]
})
- .collect::<Vec<_>>()
- .into_iter()
- .try_collect()
+ .map(|e| e as f64 / num_pixels as f64);
+
+ Ok((sr, sg, sb))
}
}
diff --git a/src/main.rs b/src/main.rs
index a28885d..c4a0c26 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,20 +25,20 @@ struct Args {
images: Vec<PathBuf>,
}
-//#[derive(Debug)]
-//struct Config {
-// base_dir: xdg::BaseDirectories,
-//}
-//
-//fn get_config() -> Result<Config, String> {
-// let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort")
-// .map_err(|_| "oh no")?;
-//
-// Ok(Config{base_dir: dirs})
-//}
+#[derive(Debug)]
+struct Config {
+ base_dirs: xdg::BaseDirectories,
+}
-fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String>
- where E: EmbedderT
+fn get_config() -> Result<Config, String> {
+ let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort")
+ .map_err(|_| "oh no")?;
+
+ Ok(Config{base_dirs: dirs})
+}
+
+fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
+ where M: MetricElem
{
// wrapper struct to
// - reverse the ordering
@@ -58,31 +58,23 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String>
}
}
- let num_ims = args.images.len();
- if num_ims == 0 {
- return Ok(Vec::new());
- }
-
- let embeds: Vec<_> = e
- .embed(&args.images)?
- .into_iter()
- .collect();
+ let num_embeds = embeds.len();
let mut possible_edges =
- PriorityQueue::with_capacity((num_ims * num_ims - num_ims) / 2);
- let mut mst = HashMap::with_capacity(num_ims);
+ PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2);
+ let mut mst = HashMap::with_capacity(num_embeds);
// here, we start at 0.
// we might get a better result in the end if we started with a vertex next
// to the lowest-cost edge, but we don't know which one that is (though we
// could compute that without changing our asymptotic complexity)
mst.insert(0, Vec::new());
- for i in 1..num_ims {
+ for i in 1..num_embeds {
possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i])));
}
// prims algorithm or something like that
- while mst.len() < num_ims {
+ while mst.len() < num_embeds {
// find the edge with the least cost that connects us to a new vertex
let (new, old) = loop {
let ((a, b), _) = possible_edges.pop().unwrap();
@@ -97,7 +89,7 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String>
// insert all the new edges we could take
mst.entry(old).and_modify(|v|v.push(new));
- for i in 0..num_ims {
+ for i in 0..num_embeds {
// don't consider edges taking us to nodes we already visited
if mst.contains_key(&i) {
continue;
@@ -107,25 +99,71 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String>
}
}
- // find TSP approximation via DFS through the MST
+ mst
+}
+
+fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> {
fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
into.push(cur);
t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into));
}
- let mut tsp_path = Vec::with_capacity(num_ims);
+ let mut tsp_path = Vec::with_capacity(mst.len());
dfs(0, &mst, &mut tsp_path);
+ tsp_path
+}
+
+fn process_embedder<E>(mut e: E, args: Args, cfg: Config) -> Result<Vec<PathBuf>, String>
+ where E: EmbedderT
+{
+ if args.images.len() == 0 {
+ return Ok(Vec::new());
+ }
+
+ let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")
+ .map_err(|e| e.to_string())?).map_err(|e| e.to_string())?;
+ let tree = typed_sled::Tree::<PathBuf, E::Embedding>::open(&db, E::NAME);
+
+ // TODO: use a hash of the image as the key instead of the path
+ let mut embeds: Vec<Option<_>> = args.images
+ .iter()
+ .map(|p| tree.get(p).map_err(|e| e.to_string()))
+ .try_collect()?;
+
+ let missing_embeds_indices: Vec<_> = embeds
+ .iter()
+ .enumerate()
+ .filter_map(|(i, v)| match v {
+ None => Some(i),
+ Some(_) => None,
+ }).collect();
+ let missing_embeds = e.embeds(&missing_embeds_indices
+ .iter()
+ .map(|i| args.images[*i].clone())
+ .collect::<Vec<_>>())?;
+
+ for (idx, emb) in missing_embeds_indices
+ .into_iter().zip(missing_embeds.into_iter())
+ {
+ // TODO: use the image hash instead of the path as the key here, too
+ tree.insert(&args.images[idx], &emb).map_err(|e| e.to_string())?;
+ embeds[idx] = Some(emb);
+ }
+
+ let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
+ let tsp_path = tsp_from_mst(get_mst(&embeds));
+
Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect())
}
fn main() -> Result<(), String> {
- //let cfg = get_config()?;
+ let cfg = get_config()?;
let args = Args::parse();
let tsp_path = match args.embedder {
- Embedder::Brightness => process_embedder(BrightnessEmbedder, args),
- Embedder::Hue => process_embedder(HueEmbedder, args),
- Embedder::Color => process_embedder(ColorEmbedder, args),
+ Embedder::Brightness => process_embedder(BrightnessEmbedder, args, cfg),
+ Embedder::Hue => process_embedder(HueEmbedder, args, cfg),
+ Embedder::Color => process_embedder(ColorEmbedder, args, cfg),
}?;
for p in tsp_path {