diff options
author | Lia Lenckowski <lialenck@protonmail.com> | 2024-03-03 19:40:16 +0100 |
---|---|---|
committer | Lia Lenckowski <lialenck@protonmail.com> | 2024-03-03 19:40:16 +0100 |
commit | 3b670fe70de4b52b9a3f8614b42fb94fe49b3548 (patch) | |
tree | 1119b13ab7119d8dc8bbaaebd77a9ec164eddac0 | |
parent | 94185d8b39af5a96a61ffb4df7a2d76dcd7afa49 (diff) | |
download | embeddings-sort-3b670fe70de4b52b9a3f8614b42fb94fe49b3548.tar embeddings-sort-3b670fe70de4b52b9a3f8614b42fb94fe49b3548.tar.bz2 embeddings-sort-3b670fe70de4b52b9a3f8614b42fb94fe49b3548.tar.zst |
add hash seed option for reproducible behaviour
-rw-r--r-- | Cargo.lock | 42 | ||||
-rw-r--r-- | Cargo.toml | 21 | ||||
-rw-r--r-- | src/main.rs | 6 | ||||
-rw-r--r-- | src/tsp_approx.rs | 180 |
4 files changed, 193 insertions, 56 deletions
@@ -9,6 +9,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] +name = "ahash" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] name = "anstream" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -290,11 +303,12 @@ checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" name = "embeddings-sort" version = "0.2.0" dependencies = [ + "ahash", "anyhow", "clap", "image", "indicatif", - "multiset", + "mset", "partitions", "pathdiff", "rayon", @@ -572,10 +586,10 @@ dependencies = [ ] [[package]] -name = "multiset" -version = "0.0.5" +name = "mset" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce8738c9ddd350996cb8b8b718192851df960803764bcdaa3afb44a63b1ddb5c" +checksum = "26c4d16a3d2b0e89ec6e7d509cf791545fcb48cbc8fc2fb2e96a492defda9140" [[package]] name = "nanorand" @@ -1216,6 +1230,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "213b7324336b53d2414b2db8537e56544d981803139155afa84f76eeebb7a546" [[package]] +name = "zerocopy" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] name = "zune-inflate" version = "0.2.54" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -6,18 +6,19 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -image = "0" -xdg = "2" +ahash = "0" +anyhow = "1" clap = { version = "4", features = ["derive"] } -rayon = "1" +image = "0" indicatif = { version = "0", features = ["rayon"] } -sled = "0" -typed-sled = "0" +mset = "0" +partitions = "0" +pathdiff = "0" +rayon = "1" +reflink-copy = "0" serde = "1" serde_json = "1" sha2 = "0" -anyhow = "1" -pathdiff = "0" -reflink-copy = "0" -multiset = "0" -partitions = "0" +sled = "0" +typed-sled = "0" +xdg = "2" diff --git a/src/main.rs b/src/main.rs index f9bf4fc..71e5d00 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,6 +61,10 @@ struct Args { #[arg(long, default_value = "christofides")] tsp_approx: TspAlg, + /// Seed for hashing. Random by default. + #[arg(long)] + hash_seed: Option<u64>, + images: Vec<PathBuf>, } @@ -138,7 +142,7 @@ where } let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); - let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx); + let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, &args.hash_seed); Ok(( tsp_path.iter().map(|i| args.images[*i].clone()).collect(), diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index 79aca3d..914f414 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -1,13 +1,54 @@ +use ahash::{random_state::RandomState, HashMap, HashSet}; use indicatif::{ProgressBar, ProgressStyle}; -use multiset::HashMultiSet; use partitions::partition_vec::PartitionVec; -use std::collections::{HashMap, HashSet}; +use mset::MultiSet; +use std::{hash::Hash, cmp::Eq}; use crate::{MetricElem, TspAlg}; -fn get_mst(dist_cache: &HashMap<(usize, usize), f64>, num_embeds: usize) -> HashMap<usize, Vec<usize>> { +fn get_hashmap<K, V>(hash_seed: &Option<u64>, capacity: Option<usize>) -> HashMap<K, V> { + let hasher = match hash_seed { + Some(s) => RandomState::with_seeds(*s, 0, 0, 0), + None => RandomState::new(), + }; + + match capacity { + Some(sz) => HashMap::with_capacity_and_hasher(sz, hasher), + None => HashMap::with_hasher(hasher), + } +} + +fn get_hashset<V>(hash_seed: &Option<u64>, capacity: Option<usize>) -> HashSet<V> { + let hasher = match hash_seed { + Some(s) => RandomState::with_seeds(*s, 0, 0, 0), + None => RandomState::new(), + }; + + match capacity { + Some(sz) => HashSet::with_capacity_and_hasher(sz, hasher), + None => HashSet::with_hasher(hasher), + } +} + +fn get_multiset<V: Hash + Eq>(hash_seed: &Option<u64>, capacity: Option<usize>) -> MultiSet<V, RandomState> { + let hasher = match hash_seed { + Some(s) => RandomState::with_seeds(*s, 0, 0, 0), + None => RandomState::new(), + }; + + match capacity { + Some(sz) => MultiSet::with_capacity_and_hasher(sz, hasher), + None => MultiSet::with_hasher(hasher), + } +} + +fn get_mst( + dist_cache: &HashMap<(usize, usize), f64>, + num_embeds: usize, + hash_seed: &Option<u64>, +) -> HashMap<usize, Vec<usize>> { let mut possible_edges = Vec::with_capacity((num_embeds * num_embeds - num_embeds) / 2); - let mut mst = HashMap::with_capacity(num_embeds); + let mut mst = get_hashmap(hash_seed, Some(num_embeds)); // insert all edges we could ever use mst.insert(0, Vec::new()); @@ -35,7 +76,10 @@ fn get_mst(dist_cache: &HashMap<(usize, usize), f64>, num_embeds: usize) -> Hash mst } -fn tsp_from_mst(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) { +fn tsp_from_mst( + dist_cache: &HashMap<(usize, usize), f64>, + mst: HashMap<usize, Vec<usize>>, +) -> (Vec<usize>, f64) { fn dfs(cur: usize, prev: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { into.push(cur); t.get(&cur).unwrap().iter().for_each(|&c| { @@ -56,10 +100,14 @@ fn tsp_from_mst(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, V (tsp_path, total_dist) } -// TODO this is a non-ideal, greedy algorithm. there are better algorithms for this, -// and i should probably implement one. +// this is a greedy, non-exact implementation. The polynomial time solution would be +// in O(n^3), which is too large for my taste, while this is O(n^2). /// 'verts' must be an even number of vertices with odd degree -fn min_weight_matching(dist_cache: &HashMap<(usize, usize), f64>, verts: &[usize]) -> HashMap<usize, usize> { +fn min_weight_matching( + dist_cache: &HashMap<(usize, usize), f64>, + verts: &[usize], + hash_seed: &Option<u64>, +) -> HashMap<usize, usize> { let num_odd = verts.len(); assert!(num_odd % 2 == 0); @@ -73,7 +121,7 @@ fn min_weight_matching(dist_cache: &HashMap<(usize, usize), f64>, verts: &[usize } possible_edges.sort_unstable_by(|(da, _, _), (db, _, _)| da.partial_cmp(db).unwrap()); - let mut res = HashMap::new(); + let mut res = get_hashmap(hash_seed, None); for (_, a, b) in possible_edges.into_iter() { if res.len() >= num_odd { @@ -90,18 +138,20 @@ fn min_weight_matching(dist_cache: &HashMap<(usize, usize), f64>, verts: &[usize } fn euler_tour( - mut graph: HashMap<usize, HashMultiSet<usize>>, + mut graph: HashMap<usize, MultiSet<usize, RandomState>>, + hash_seed: &Option<u64>, ) -> (usize, HashMap<usize, Vec<(usize, usize, usize)>>) { - let mut r: HashMap<_, Vec<_>> = HashMap::new(); - let mut partially_explored = HashSet::new(); + let mut r: HashMap<_, Vec<_>> = get_hashmap(hash_seed, None); + let mut partially_explored = get_hashset(hash_seed, None); + // TODO das hier brauch fixing. bitte nochmal algorithmus verstehen vorher. // initial setup: pretend that we have only the path 'INF -> root -> INF' // for some arbitrary root, and set cur to some node next to root. // This mimicks the state we're in just after a new phase (because it is) let &root = graph.keys().next().unwrap(); r.insert(root, vec![(usize::MAX, usize::MAX, usize::MAX)]); let e = graph.get_mut(&root).unwrap(); - let &next = e.iter().next().unwrap(); + let (&next, _) = e.iter().next().unwrap(); e.remove(&next); graph.get_mut(&next).unwrap().remove(&root); @@ -119,7 +169,7 @@ fn euler_tour( } match e.iter().next() { - Some(&next) => { + Some((&next, _)) => { e.remove(&next); graph.get_mut(&next).unwrap().remove(&cur); @@ -161,7 +211,7 @@ fn euler_tour( // new_cur from partially_explored; that's done the next time we // get there) let e = graph.get_mut(&new_cur).unwrap(); - let &next = e.iter().next().unwrap(); + let (&next, _) = e.iter().next().unwrap(); e.remove(&next); graph.get_mut(&next).unwrap().remove(&new_cur); @@ -178,31 +228,39 @@ fn euler_tour( (root, r) } -fn christofides(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) { - let mut mst: HashMap<_, HashMultiSet<_>> = mst - .into_iter() - .map(|(k, v)| (k, v.into_iter().collect())) - .collect(); +fn christofides( + dist_cache: &HashMap<(usize, usize), f64>, + mst: HashMap<usize, Vec<usize>>, + hash_seed: &Option<u64>, +) -> (Vec<usize>, f64) { + let mut ext_mst: HashMap<usize, MultiSet<usize, RandomState>> = get_hashmap(hash_seed, Some(mst.len())); + for (k, v) in mst.into_iter() { + let mut as_mset = get_multiset(hash_seed, Some(v.len())); + for l in v.into_iter() { + as_mset.insert(l); + } + ext_mst.insert(k, as_mset); + } - let odd_verts: Vec<_> = mst + let odd_verts: Vec<_> = ext_mst .iter() .filter_map(|(&i, s)| if s.len() % 2 == 0 { None } else { Some(i) }) .collect(); - // from here on, 'mst' is a bit of a misnomer, as we're adding more edges such + // from here on, 'mst' would be a bit of a misnomer, as we're adding more edges such // that all vertices have even degree - for (a, b) in min_weight_matching(dist_cache, &odd_verts) { - mst.get_mut(&a).unwrap().insert(b); + for (a, b) in min_weight_matching(dist_cache, &odd_verts, hash_seed) { + ext_mst.get_mut(&a).unwrap().insert(b); } let mut r = Vec::new(); let mut total_dist = 0.; - let mut visited_verts = HashSet::new(); + let mut visited_verts = get_hashset(hash_seed, None); visited_verts.insert(usize::MAX); let mut pprev = usize::MAX; let mut prev = usize::MAX; - let (mut cur, euler_tour) = euler_tour(mst); + let (mut cur, euler_tour) = euler_tour(ext_mst, hash_seed); loop { match euler_tour.get(&cur) { None => break, // tour complete: this should happen iff 'cur == usize::MAX' @@ -224,25 +282,65 @@ fn christofides(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, V } } + (r, total_dist) } -fn refine<M>(_: &[M], _: Vec<usize>, _: f64) -> (Vec<usize>, f64) -where - M: MetricElem, -{ - // convert the tour into a linked-list representation. instead of pointers, we use - // an array of indices - //let tour_ll = Vec::new(); +fn refine_2_opt( + dist_cache: &HashMap<(usize, usize), f64>, + tour: Vec<usize>, + tour_len: f64, +) -> (Vec<usize>, f64) { + if tour.len() < 4 { + return (tour, tour_len); + } + + // convert the tour into a doubly linked-list. instead of pointers, we use + // an array of index pairs. + let mut tour: Vec<_> = tour + .into_iter() + .enumerate() + .map(|(i, v)| (i - 1, v, i + 1)) + .collect(); + let n = tour.len(); + // fix boundary + tour[0].0 = n - 1; + tour[n - 1].2 = 0; + + // this will be an implementation of 2-opt swapping where the swapping takes O(1). + // to do this, the links of a node no are no longer "backward"/"forward"; instead, + // the forward direction will be in the direction of the link you didn't come from. + // This implicit direction indicator allows for reversing a sublist in O(1) + let mut /* "left edge" */ le = ( + /* previous tour index */ n - 1, + /* tour index */ 0, + ); + + let adv = |(pi, i): (usize, usize)| { + if tour[i].0 == pi { + (i, tour[i].2) // forward + } else { + (i, tour[i].0) // reverse + } + }; + + while le.0 != n - 1 { + // TODO so nehmen wir die boundary nicht mit. + let mut re = adv(le); + println!("{:?} <-> {:?}", le, re); + + le = adv(le); + } + todo!() } -fn get_dist_cache<M>(embeds: &[M]) -> HashMap<(usize, usize), f64> +fn get_dist_cache<M>(embeds: &[M], hash_seed: &Option<u64>) -> HashMap<(usize, usize), f64> where M: MetricElem, { let n = embeds.len(); - let mut r = HashMap::with_capacity((n * n - n) / 2); + let mut r = get_hashmap(hash_seed, Some((n * n - n) / 2)); for a in 0..n { for b in a + 1..n { r.insert((a, b), embeds[a].dist(&embeds[b])); @@ -252,7 +350,7 @@ where r } -pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg) -> (Vec<usize>, f64) +pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg, hash_seed: &Option<u64>) -> (Vec<usize>, f64) where M: MetricElem, { @@ -261,19 +359,19 @@ where bar.enable_steady_tick(std::time::Duration::from_millis(100)); bar.set_message("Calculating distances..."); - let dc = get_dist_cache(embeds); + let dc = get_dist_cache(embeds, hash_seed); bar.set_message("Finding mst..."); - let mst = get_mst(&dc, embeds.len()); + let mst = get_mst(&dc, embeds.len(), hash_seed); bar.set_message("Finding path..."); let r = match alg { TspAlg::MstDfs => tsp_from_mst(&dc, mst), - TspAlg::Christofides => christofides(&dc, mst), + TspAlg::Christofides => christofides(&dc, mst, hash_seed), TspAlg::ChristofidesRefined => { - let (p, l) = christofides(&dc, mst); + let (p, l) = christofides(&dc, mst, hash_seed); bar.set_message("Refining path..."); - refine(embeds, p, l) + refine_2_opt(&dc, p, l) } }; |