diff options
author | Lia Lenckowski <lialenck@protonmail.com> | 2024-03-03 21:28:19 +0100 |
---|---|---|
committer | Lia Lenckowski <lialenck@protonmail.com> | 2024-03-03 21:28:19 +0100 |
commit | 919c286140d5c5efd4b57d507461090c5c7fb31a (patch) | |
tree | b227094ec771d35cce4798e2b4b4a517c138dde8 | |
parent | 3b670fe70de4b52b9a3f8614b42fb94fe49b3548 (diff) | |
download | embeddings-sort-919c286140d5c5efd4b57d507461090c5c7fb31a.tar embeddings-sort-919c286140d5c5efd4b57d507461090c5c7fb31a.tar.bz2 embeddings-sort-919c286140d5c5efd4b57d507461090c5c7fb31a.tar.zst |
abstract away the distance cache, calculate tour length at the end to simplify code & improve accuracy, further work on 2-opt
-rw-r--r-- | src/tsp_approx.rs | 92 |
1 files changed, 59 insertions, 33 deletions
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index 914f414..f79d0a4 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -6,6 +6,13 @@ use std::{hash::Hash, cmp::Eq}; use crate::{MetricElem, TspAlg}; +struct DistCache(HashMap<(usize, usize), f64>); +impl DistCache { + fn dist(&self, a: usize, b: usize) -> f64 { + self.0[&(a.min(b), a.max(b))] + } +} + fn get_hashmap<K, V>(hash_seed: &Option<u64>, capacity: Option<usize>) -> HashMap<K, V> { let hasher = match hash_seed { Some(s) => RandomState::with_seeds(*s, 0, 0, 0), @@ -43,7 +50,7 @@ fn get_multiset<V: Hash + Eq>(hash_seed: &Option<u64>, capacity: Option<usize>) } fn get_mst( - dist_cache: &HashMap<(usize, usize), f64>, + dist_cache: &DistCache, num_embeds: usize, hash_seed: &Option<u64>, ) -> HashMap<usize, Vec<usize>> { @@ -54,7 +61,7 @@ fn get_mst( mst.insert(0, Vec::new()); for a in 0..num_embeds { for b in a + 1..num_embeds { - possible_edges.push((dist_cache[&(a.min(b), a.max(b))], a, b)); + possible_edges.push((dist_cache.dist(a, b), a, b)); } } @@ -77,9 +84,8 @@ fn get_mst( } fn tsp_from_mst( - dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, Vec<usize>>, -) -> (Vec<usize>, f64) { +) -> Vec<usize> { fn dfs(cur: usize, prev: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { into.push(cur); t.get(&cur).unwrap().iter().for_each(|&c| { @@ -91,20 +97,14 @@ fn tsp_from_mst( let mut tsp_path = Vec::with_capacity(mst.len()); dfs(0, usize::MAX, &mst, &mut tsp_path); - let mut total_dist = 0.; - for i in 0..tsp_path.len() - 1 { - let (a, b) = (tsp_path[i], tsp_path[i + 1]); - total_dist += dist_cache[&(a.min(b), a.max(b))]; - } - - (tsp_path, total_dist) + tsp_path } // this is a greedy, non-exact implementation. The polynomial time solution would be // in O(n^3), which is too large for my taste, while this is O(n^2). /// 'verts' must be an even number of vertices with odd degree fn min_weight_matching( - dist_cache: &HashMap<(usize, usize), f64>, + dist_cache: &DistCache, verts: &[usize], hash_seed: &Option<u64>, ) -> HashMap<usize, usize> { @@ -115,7 +115,7 @@ fn min_weight_matching( for &x in verts { for &y in verts { if x != y { - possible_edges.push((dist_cache[&(x.min(y), x.max(y))], x, y)); + possible_edges.push((dist_cache.dist(x, y), x, y)); } } } @@ -229,10 +229,10 @@ fn euler_tour( } fn christofides( - dist_cache: &HashMap<(usize, usize), f64>, + dist_cache: &DistCache, mst: HashMap<usize, Vec<usize>>, hash_seed: &Option<u64>, -) -> (Vec<usize>, f64) { +) -> Vec<usize> { let mut ext_mst: HashMap<usize, MultiSet<usize, RandomState>> = get_hashmap(hash_seed, Some(mst.len())); for (k, v) in mst.into_iter() { let mut as_mset = get_multiset(hash_seed, Some(v.len())); @@ -254,7 +254,6 @@ fn christofides( } let mut r = Vec::new(); - let mut total_dist = 0.; let mut visited_verts = get_hashset(hash_seed, None); visited_verts.insert(usize::MAX); @@ -273,7 +272,6 @@ fn christofides( if visited_verts.insert(next) { // haven't visited 'next' yet r.push(next); - total_dist += dist_cache[&(cur.min(next), cur.max(next))]; } pprev = prev; prev = cur; @@ -283,16 +281,15 @@ fn christofides( } - (r, total_dist) + r } fn refine_2_opt( - dist_cache: &HashMap<(usize, usize), f64>, + dist_cache: &DistCache, tour: Vec<usize>, - tour_len: f64, -) -> (Vec<usize>, f64) { +) -> Vec<usize> { if tour.len() < 4 { - return (tour, tour_len); + return tour; } // convert the tour into a doubly linked-list. instead of pointers, we use @@ -324,18 +321,41 @@ fn refine_2_opt( } }; - while le.0 != n - 1 { - // TODO so nehmen wir die boundary nicht mit. + // TODO so nehmen wir die boundary nicht mit. NOTE: das flattening geht davon aus, + // das wir das nicht tuen + while le.0 != n - 2 { let mut re = adv(le); - println!("{:?} <-> {:?}", le, re); + while re.0 != n - 1 { + let le_elems = (tour[le.0].1, tour[le.1].1); + let re_elems = (tour[re.0].1, tour[re.1].1); + let cur_cost = dist_cache.dist(le_elems.0, le_elems.1) + dist_cache.dist(re_elems.0, re_elems.1); + let swap_cost = dist_cache.dist(le_elems.0, re_elems.0) + dist_cache.dist(le_elems.1, re_elems.1); + + if cur_cost <= swap_cost { + re = adv(re); + continue; + } + + re = adv(re); + } le = adv(le); } - todo!() + // calculate tour as vector from linked-list + let mut tour_flat = Vec::with_capacity(n); + let mut cur = (n - 1, 0); + + tour_flat.push(tour[0].1); + while cur.1 != n - 1 { + cur = adv(cur); + tour_flat.push(tour[cur.1].1); + } + + tour_flat } -fn get_dist_cache<M>(embeds: &[M], hash_seed: &Option<u64>) -> HashMap<(usize, usize), f64> +fn get_dist_cache<M>(embeds: &[M], hash_seed: &Option<u64>) -> DistCache where M: MetricElem, { @@ -347,7 +367,7 @@ where } } - r + DistCache(r) } pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg, hash_seed: &Option<u64>) -> (Vec<usize>, f64) @@ -365,17 +385,23 @@ where let mst = get_mst(&dc, embeds.len(), hash_seed); bar.set_message("Finding path..."); - let r = match alg { - TspAlg::MstDfs => tsp_from_mst(&dc, mst), + let tour = match alg { + TspAlg::MstDfs => tsp_from_mst(mst), TspAlg::Christofides => christofides(&dc, mst, hash_seed), TspAlg::ChristofidesRefined => { - let (p, l) = christofides(&dc, mst, hash_seed); + let tour = christofides(&dc, mst, hash_seed); bar.set_message("Refining path..."); - refine_2_opt(&dc, p, l) + refine_2_opt(&dc, tour) } }; + let mut total_dist = 0.; + for i in 0..tour.len() - 1 { + let (a, b) = (tour[i], tour[i + 1]); + total_dist += dc.dist(a, b); + } + bar.finish(); - r + (tour, total_dist) } |