aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2024-03-03 21:28:19 +0100
committerLia Lenckowski <lialenck@protonmail.com>2024-03-03 21:28:19 +0100
commit919c286140d5c5efd4b57d507461090c5c7fb31a (patch)
treeb227094ec771d35cce4798e2b4b4a517c138dde8
parent3b670fe70de4b52b9a3f8614b42fb94fe49b3548 (diff)
downloadembeddings-sort-919c286140d5c5efd4b57d507461090c5c7fb31a.tar
embeddings-sort-919c286140d5c5efd4b57d507461090c5c7fb31a.tar.bz2
embeddings-sort-919c286140d5c5efd4b57d507461090c5c7fb31a.tar.zst
abstract away the distance cache, calculate tour length at the end to simplify code & improve accuracy, further work on 2-opt
-rw-r--r--src/tsp_approx.rs92
1 files changed, 59 insertions, 33 deletions
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
index 914f414..f79d0a4 100644
--- a/src/tsp_approx.rs
+++ b/src/tsp_approx.rs
@@ -6,6 +6,13 @@ use std::{hash::Hash, cmp::Eq};
use crate::{MetricElem, TspAlg};
+struct DistCache(HashMap<(usize, usize), f64>); // newtype over the pairwise-distance map; keys are index pairs, values are f64 distances
+impl DistCache {
+ fn dist(&self, a: usize, b: usize) -> f64 { // order-independent lookup: dist(a, b) and dist(b, a) read the same entry
+ self.0[&(a.min(b), a.max(b))] // key is normalized to (smaller, larger); map indexing panics if that pair is absent
+ }
+}
+
fn get_hashmap<K, V>(hash_seed: &Option<u64>, capacity: Option<usize>) -> HashMap<K, V> {
let hasher = match hash_seed {
Some(s) => RandomState::with_seeds(*s, 0, 0, 0),
@@ -43,7 +50,7 @@ fn get_multiset<V: Hash + Eq>(hash_seed: &Option<u64>, capacity: Option<usize>)
}
fn get_mst(
- dist_cache: &HashMap<(usize, usize), f64>,
+ dist_cache: &DistCache,
num_embeds: usize,
hash_seed: &Option<u64>,
) -> HashMap<usize, Vec<usize>> {
@@ -54,7 +61,7 @@ fn get_mst(
mst.insert(0, Vec::new());
for a in 0..num_embeds {
for b in a + 1..num_embeds {
- possible_edges.push((dist_cache[&(a.min(b), a.max(b))], a, b));
+ possible_edges.push((dist_cache.dist(a, b), a, b));
}
}
@@ -77,9 +84,8 @@ fn get_mst(
}
fn tsp_from_mst(
- dist_cache: &HashMap<(usize, usize), f64>,
mst: HashMap<usize, Vec<usize>>,
-) -> (Vec<usize>, f64) {
+) -> Vec<usize> {
fn dfs(cur: usize, prev: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
into.push(cur);
t.get(&cur).unwrap().iter().for_each(|&c| {
@@ -91,20 +97,14 @@ fn tsp_from_mst(
let mut tsp_path = Vec::with_capacity(mst.len());
dfs(0, usize::MAX, &mst, &mut tsp_path);
- let mut total_dist = 0.;
- for i in 0..tsp_path.len() - 1 {
- let (a, b) = (tsp_path[i], tsp_path[i + 1]);
- total_dist += dist_cache[&(a.min(b), a.max(b))];
- }
-
- (tsp_path, total_dist)
+ tsp_path
}
// this is a greedy, non-exact implementation. The polynomial time solution would be
// in O(n^3), which is too large for my taste, while this is O(n^2).
/// 'verts' must be an even number of vertices with odd degree
fn min_weight_matching(
- dist_cache: &HashMap<(usize, usize), f64>,
+ dist_cache: &DistCache,
verts: &[usize],
hash_seed: &Option<u64>,
) -> HashMap<usize, usize> {
@@ -115,7 +115,7 @@ fn min_weight_matching(
for &x in verts {
for &y in verts {
if x != y {
- possible_edges.push((dist_cache[&(x.min(y), x.max(y))], x, y));
+ possible_edges.push((dist_cache.dist(x, y), x, y));
}
}
}
@@ -229,10 +229,10 @@ fn euler_tour(
}
fn christofides(
- dist_cache: &HashMap<(usize, usize), f64>,
+ dist_cache: &DistCache,
mst: HashMap<usize, Vec<usize>>,
hash_seed: &Option<u64>,
-) -> (Vec<usize>, f64) {
+) -> Vec<usize> {
let mut ext_mst: HashMap<usize, MultiSet<usize, RandomState>> = get_hashmap(hash_seed, Some(mst.len()));
for (k, v) in mst.into_iter() {
let mut as_mset = get_multiset(hash_seed, Some(v.len()));
@@ -254,7 +254,6 @@ fn christofides(
}
let mut r = Vec::new();
- let mut total_dist = 0.;
let mut visited_verts = get_hashset(hash_seed, None);
visited_verts.insert(usize::MAX);
@@ -273,7 +272,6 @@ fn christofides(
if visited_verts.insert(next) {
// haven't visited 'next' yet
r.push(next);
- total_dist += dist_cache[&(cur.min(next), cur.max(next))];
}
pprev = prev;
prev = cur;
@@ -283,16 +281,15 @@ fn christofides(
}
- (r, total_dist)
+ r
}
fn refine_2_opt(
- dist_cache: &HashMap<(usize, usize), f64>,
+ dist_cache: &DistCache,
tour: Vec<usize>,
- tour_len: f64,
-) -> (Vec<usize>, f64) {
+) -> Vec<usize> {
if tour.len() < 4 {
- return (tour, tour_len);
+ return tour;
}
// convert the tour into a doubly linked-list. instead of pointers, we use
@@ -324,18 +321,41 @@ fn refine_2_opt(
}
};
- while le.0 != n - 1 {
- // TODO so nehmen wir die boundary nicht mit.
+ // TODO: this way we do not include the boundary. NOTE: the flattening assumes
+ // that we do not
+ while le.0 != n - 2 {
let mut re = adv(le);
- println!("{:?} <-> {:?}", le, re);
+ while re.0 != n - 1 {
+ let le_elems = (tour[le.0].1, tour[le.1].1);
+ let re_elems = (tour[re.0].1, tour[re.1].1);
+ let cur_cost = dist_cache.dist(le_elems.0, le_elems.1) + dist_cache.dist(re_elems.0, re_elems.1);
+ let swap_cost = dist_cache.dist(le_elems.0, re_elems.0) + dist_cache.dist(le_elems.1, re_elems.1);
+
+ if cur_cost <= swap_cost {
+ re = adv(re);
+ continue;
+ }
+
+ re = adv(re);
+ }
le = adv(le);
}
- todo!()
+ // calculate tour as vector from linked-list
+ let mut tour_flat = Vec::with_capacity(n);
+ let mut cur = (n - 1, 0);
+
+ tour_flat.push(tour[0].1);
+ while cur.1 != n - 1 {
+ cur = adv(cur);
+ tour_flat.push(tour[cur.1].1);
+ }
+
+ tour_flat
}
-fn get_dist_cache<M>(embeds: &[M], hash_seed: &Option<u64>) -> HashMap<(usize, usize), f64>
+fn get_dist_cache<M>(embeds: &[M], hash_seed: &Option<u64>) -> DistCache
where
M: MetricElem,
{
@@ -347,7 +367,7 @@ where
}
}
- r
+ DistCache(r)
}
pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg, hash_seed: &Option<u64>) -> (Vec<usize>, f64)
@@ -365,17 +385,23 @@ where
let mst = get_mst(&dc, embeds.len(), hash_seed);
bar.set_message("Finding path...");
- let r = match alg {
- TspAlg::MstDfs => tsp_from_mst(&dc, mst),
+ let tour = match alg {
+ TspAlg::MstDfs => tsp_from_mst(mst),
TspAlg::Christofides => christofides(&dc, mst, hash_seed),
TspAlg::ChristofidesRefined => {
- let (p, l) = christofides(&dc, mst, hash_seed);
+ let tour = christofides(&dc, mst, hash_seed);
bar.set_message("Refining path...");
- refine_2_opt(&dc, p, l)
+ refine_2_opt(&dc, tour)
}
};
+ let mut total_dist = 0.;
+ for i in 0..tour.len() - 1 {
+ let (a, b) = (tour[i], tour[i + 1]);
+ total_dist += dc.dist(a, b);
+ }
+
bar.finish();
- r
+ (tour, total_dist)
}