aboutsummaryrefslogtreecommitdiff
path: root/src/tsp_approx.rs
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2024-03-03 19:40:16 +0100
committerLia Lenckowski <lialenck@protonmail.com>2024-03-03 19:40:16 +0100
commit3b670fe70de4b52b9a3f8614b42fb94fe49b3548 (patch)
tree1119b13ab7119d8dc8bbaaebd77a9ec164eddac0 /src/tsp_approx.rs
parent94185d8b39af5a96a61ffb4df7a2d76dcd7afa49 (diff)
downloadembeddings-sort-3b670fe70de4b52b9a3f8614b42fb94fe49b3548.tar
embeddings-sort-3b670fe70de4b52b9a3f8614b42fb94fe49b3548.tar.bz2
embeddings-sort-3b670fe70de4b52b9a3f8614b42fb94fe49b3548.tar.zst
add hash seed option for reproducible behaviour
Diffstat (limited to 'src/tsp_approx.rs')
-rw-r--r--src/tsp_approx.rs180
1 files changed, 139 insertions, 41 deletions
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
index 79aca3d..914f414 100644
--- a/src/tsp_approx.rs
+++ b/src/tsp_approx.rs
@@ -1,13 +1,54 @@
+use ahash::{random_state::RandomState, HashMap, HashSet};
use indicatif::{ProgressBar, ProgressStyle};
-use multiset::HashMultiSet;
use partitions::partition_vec::PartitionVec;
-use std::collections::{HashMap, HashSet};
+use mset::MultiSet;
+use std::{hash::Hash, cmp::Eq};
use crate::{MetricElem, TspAlg};
-fn get_mst(dist_cache: &HashMap<(usize, usize), f64>, num_embeds: usize) -> HashMap<usize, Vec<usize>> {
+fn get_hashmap<K, V>(hash_seed: &Option<u64>, capacity: Option<usize>) -> HashMap<K, V> {
+ let hasher = match hash_seed {
+ Some(s) => RandomState::with_seeds(*s, 0, 0, 0),
+ None => RandomState::new(),
+ };
+
+ match capacity {
+ Some(sz) => HashMap::with_capacity_and_hasher(sz, hasher),
+ None => HashMap::with_hasher(hasher),
+ }
+}
+
+fn get_hashset<V>(hash_seed: &Option<u64>, capacity: Option<usize>) -> HashSet<V> {
+ let hasher = match hash_seed {
+ Some(s) => RandomState::with_seeds(*s, 0, 0, 0),
+ None => RandomState::new(),
+ };
+
+ match capacity {
+ Some(sz) => HashSet::with_capacity_and_hasher(sz, hasher),
+ None => HashSet::with_hasher(hasher),
+ }
+}
+
+fn get_multiset<V: Hash + Eq>(hash_seed: &Option<u64>, capacity: Option<usize>) -> MultiSet<V, RandomState> {
+ let hasher = match hash_seed {
+ Some(s) => RandomState::with_seeds(*s, 0, 0, 0),
+ None => RandomState::new(),
+ };
+
+ match capacity {
+ Some(sz) => MultiSet::with_capacity_and_hasher(sz, hasher),
+ None => MultiSet::with_hasher(hasher),
+ }
+}
+
+fn get_mst(
+ dist_cache: &HashMap<(usize, usize), f64>,
+ num_embeds: usize,
+ hash_seed: &Option<u64>,
+) -> HashMap<usize, Vec<usize>> {
let mut possible_edges = Vec::with_capacity((num_embeds * num_embeds - num_embeds) / 2);
- let mut mst = HashMap::with_capacity(num_embeds);
+ let mut mst = get_hashmap(hash_seed, Some(num_embeds));
// insert all edges we could ever use
mst.insert(0, Vec::new());
@@ -35,7 +76,10 @@ fn get_mst(dist_cache: &HashMap<(usize, usize), f64>, num_embeds: usize) -> Hash
mst
}
-fn tsp_from_mst(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) {
+fn tsp_from_mst(
+ dist_cache: &HashMap<(usize, usize), f64>,
+ mst: HashMap<usize, Vec<usize>>,
+) -> (Vec<usize>, f64) {
fn dfs(cur: usize, prev: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
into.push(cur);
t.get(&cur).unwrap().iter().for_each(|&c| {
@@ -56,10 +100,14 @@ fn tsp_from_mst(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, V
(tsp_path, total_dist)
}
-// TODO this is a non-ideal, greedy algorithm. there are better algorithms for this,
-// and i should probably implement one.
+// this is a greedy, non-exact implementation. The polynomial time solution would be
+// in O(n^3), which is too large for my taste, while this is O(n^2).
/// 'verts' must be an even number of vertices with odd degree
-fn min_weight_matching(dist_cache: &HashMap<(usize, usize), f64>, verts: &[usize]) -> HashMap<usize, usize> {
+fn min_weight_matching(
+ dist_cache: &HashMap<(usize, usize), f64>,
+ verts: &[usize],
+ hash_seed: &Option<u64>,
+) -> HashMap<usize, usize> {
let num_odd = verts.len();
assert!(num_odd % 2 == 0);
@@ -73,7 +121,7 @@ fn min_weight_matching(dist_cache: &HashMap<(usize, usize), f64>, verts: &[usize
}
possible_edges.sort_unstable_by(|(da, _, _), (db, _, _)| da.partial_cmp(db).unwrap());
- let mut res = HashMap::new();
+ let mut res = get_hashmap(hash_seed, None);
for (_, a, b) in possible_edges.into_iter() {
if res.len() >= num_odd {
@@ -90,18 +138,20 @@ fn min_weight_matching(dist_cache: &HashMap<(usize, usize), f64>, verts: &[usize
}
fn euler_tour(
- mut graph: HashMap<usize, HashMultiSet<usize>>,
+ mut graph: HashMap<usize, MultiSet<usize, RandomState>>,
+ hash_seed: &Option<u64>,
) -> (usize, HashMap<usize, Vec<(usize, usize, usize)>>) {
- let mut r: HashMap<_, Vec<_>> = HashMap::new();
- let mut partially_explored = HashSet::new();
+ let mut r: HashMap<_, Vec<_>> = get_hashmap(hash_seed, None);
+ let mut partially_explored = get_hashset(hash_seed, None);
+ // TODO: this needs fixing; re-read and understand the algorithm before changing it.
// initial setup: pretend that we have only the path 'INF -> root -> INF'
// for some arbitrary root, and set cur to some node next to root.
// This mimics the state we're in just after a new phase (because it is)
let &root = graph.keys().next().unwrap();
r.insert(root, vec![(usize::MAX, usize::MAX, usize::MAX)]);
let e = graph.get_mut(&root).unwrap();
- let &next = e.iter().next().unwrap();
+ let (&next, _) = e.iter().next().unwrap();
e.remove(&next);
graph.get_mut(&next).unwrap().remove(&root);
@@ -119,7 +169,7 @@ fn euler_tour(
}
match e.iter().next() {
- Some(&next) => {
+ Some((&next, _)) => {
e.remove(&next);
graph.get_mut(&next).unwrap().remove(&cur);
@@ -161,7 +211,7 @@ fn euler_tour(
// new_cur from partially_explored; that's done the next time we
// get there)
let e = graph.get_mut(&new_cur).unwrap();
- let &next = e.iter().next().unwrap();
+ let (&next, _) = e.iter().next().unwrap();
e.remove(&next);
graph.get_mut(&next).unwrap().remove(&new_cur);
@@ -178,31 +228,39 @@ fn euler_tour(
(root, r)
}
-fn christofides(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) {
- let mut mst: HashMap<_, HashMultiSet<_>> = mst
- .into_iter()
- .map(|(k, v)| (k, v.into_iter().collect()))
- .collect();
+fn christofides(
+ dist_cache: &HashMap<(usize, usize), f64>,
+ mst: HashMap<usize, Vec<usize>>,
+ hash_seed: &Option<u64>,
+) -> (Vec<usize>, f64) {
+ let mut ext_mst: HashMap<usize, MultiSet<usize, RandomState>> = get_hashmap(hash_seed, Some(mst.len()));
+ for (k, v) in mst.into_iter() {
+ let mut as_mset = get_multiset(hash_seed, Some(v.len()));
+ for l in v.into_iter() {
+ as_mset.insert(l);
+ }
+ ext_mst.insert(k, as_mset);
+ }
- let odd_verts: Vec<_> = mst
+ let odd_verts: Vec<_> = ext_mst
.iter()
.filter_map(|(&i, s)| if s.len() % 2 == 0 { None } else { Some(i) })
.collect();
- // from here on, 'mst' is a bit of a misnomer, as we're adding more edges such
+ // from here on, 'mst' would be a bit of a misnomer, as we're adding more edges such
// that all vertices have even degree
- for (a, b) in min_weight_matching(dist_cache, &odd_verts) {
- mst.get_mut(&a).unwrap().insert(b);
+ for (a, b) in min_weight_matching(dist_cache, &odd_verts, hash_seed) {
+ ext_mst.get_mut(&a).unwrap().insert(b);
}
let mut r = Vec::new();
let mut total_dist = 0.;
- let mut visited_verts = HashSet::new();
+ let mut visited_verts = get_hashset(hash_seed, None);
visited_verts.insert(usize::MAX);
let mut pprev = usize::MAX;
let mut prev = usize::MAX;
- let (mut cur, euler_tour) = euler_tour(mst);
+ let (mut cur, euler_tour) = euler_tour(ext_mst, hash_seed);
loop {
match euler_tour.get(&cur) {
None => break, // tour complete: this should happen iff 'cur == usize::MAX'
@@ -224,25 +282,65 @@ fn christofides(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, V
}
}
+
(r, total_dist)
}
-fn refine<M>(_: &[M], _: Vec<usize>, _: f64) -> (Vec<usize>, f64)
-where
- M: MetricElem,
-{
- // convert the tour into a linked-list representation. instead of pointers, we use
- // an array of indices
- //let tour_ll = Vec::new();
+fn refine_2_opt(
+ dist_cache: &HashMap<(usize, usize), f64>,
+ tour: Vec<usize>,
+ tour_len: f64,
+) -> (Vec<usize>, f64) {
+ if tour.len() < 4 {
+ return (tour, tour_len);
+ }
+
+ // convert the tour into a doubly linked-list. instead of pointers, we use
+ // an array of index pairs.
+ let mut tour: Vec<_> = tour
+ .into_iter()
+ .enumerate()
+ .map(|(i, v)| (i - 1, v, i + 1))
+ .collect();
+ let n = tour.len();
+ // fix boundary
+ tour[0].0 = n - 1;
+ tour[n - 1].2 = 0;
+
+ // this will be an implementation of 2-opt swapping where the swapping takes O(1).
+ // to do this, the links of a node are no longer "backward"/"forward"; instead,
+ // the forward direction will be in the direction of the link you didn't come from.
+ // This implicit direction indicator allows for reversing a sublist in O(1)
+ let mut /* "left edge" */ le = (
+ /* previous tour index */ n - 1,
+ /* tour index */ 0,
+ );
+
+ let adv = |(pi, i): (usize, usize)| {
+ if tour[i].0 == pi {
+ (i, tour[i].2) // forward
+ } else {
+ (i, tour[i].0) // reverse
+ }
+ };
+
+ while le.0 != n - 1 {
+ // TODO: this way we don't include the boundary.
+ let mut re = adv(le);
+ println!("{:?} <-> {:?}", le, re);
+
+ le = adv(le);
+ }
+
todo!()
}
-fn get_dist_cache<M>(embeds: &[M]) -> HashMap<(usize, usize), f64>
+fn get_dist_cache<M>(embeds: &[M], hash_seed: &Option<u64>) -> HashMap<(usize, usize), f64>
where
M: MetricElem,
{
let n = embeds.len();
- let mut r = HashMap::with_capacity((n * n - n) / 2);
+ let mut r = get_hashmap(hash_seed, Some((n * n - n) / 2));
for a in 0..n {
for b in a + 1..n {
r.insert((a, b), embeds[a].dist(&embeds[b]));
@@ -252,7 +350,7 @@ where
r
}
-pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg) -> (Vec<usize>, f64)
+pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg, hash_seed: &Option<u64>) -> (Vec<usize>, f64)
where
M: MetricElem,
{
@@ -261,19 +359,19 @@ where
bar.enable_steady_tick(std::time::Duration::from_millis(100));
bar.set_message("Calculating distances...");
- let dc = get_dist_cache(embeds);
+ let dc = get_dist_cache(embeds, hash_seed);
bar.set_message("Finding mst...");
- let mst = get_mst(&dc, embeds.len());
+ let mst = get_mst(&dc, embeds.len(), hash_seed);
bar.set_message("Finding path...");
let r = match alg {
TspAlg::MstDfs => tsp_from_mst(&dc, mst),
- TspAlg::Christofides => christofides(&dc, mst),
+ TspAlg::Christofides => christofides(&dc, mst, hash_seed),
TspAlg::ChristofidesRefined => {
- let (p, l) = christofides(&dc, mst);
+ let (p, l) = christofides(&dc, mst, hash_seed);
bar.set_message("Refining path...");
- refine(embeds, p, l)
+ refine_2_opt(&dc, p, l)
}
};