diff options
author | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-19 15:24:45 +0200 |
---|---|---|
committer | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-19 15:24:45 +0200 |
commit | 5f926dd6e4f884f6d29c88207480a6bd0b97aa2a (patch) | |
tree | d5e1f0d316a325b23739c37b59689262e003aec3 /src/tsp_approx.rs | |
parent | fe82169e0a84692cf161e6210c4c522912e70e72 (diff) | |
download | embeddings-sort-5f926dd6e4f884f6d29c88207480a6bd0b97aa2a.tar embeddings-sort-5f926dd6e4f884f6d29c88207480a6bd0b97aa2a.tar.bz2 embeddings-sort-5f926dd6e4f884f6d29c88207480a6bd0b97aa2a.tar.zst |
christofides algorithm (~8% improvement)
Diffstat (limited to 'src/tsp_approx.rs')
-rw-r--r-- | src/tsp_approx.rs | 228 |
1 files changed, 203 insertions, 25 deletions
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index 648d9ea..1da8212 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -1,31 +1,35 @@ use indicatif::{ProgressBar, ProgressStyle}; +use multiset::HashMultiSet; use priority_queue::PriorityQueue; -use std::{cmp::Ordering, collections::HashMap}; +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, +}; use crate::MetricElem; +// wrapper struct to +// - reverse the ordering +// - implement Ord, even though the type is backed by an f64 +#[repr(transparent)] +#[derive(Debug, PartialEq)] +struct DownOrd(f64); +impl Eq for DownOrd {} +impl PartialOrd for DownOrd { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } +} +impl Ord for DownOrd { + fn cmp(&self, other: &Self) -> Ordering { + self.0.partial_cmp(&other.0).unwrap().reverse() + } +} + fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>> where M: MetricElem, { - // wrapper struct to - // - reverse the ordering - // - implement Ord, even though the type is backed by an f64 - #[repr(transparent)] - #[derive(Debug, PartialEq)] - struct DownOrd(f64); - impl Eq for DownOrd {} - impl PartialOrd for DownOrd { - fn partial_cmp(&self, other: &Self) -> Option<Ordering> { - Some(self.cmp(other)) - } - } - impl Ord for DownOrd { - fn cmp(&self, other: &Self) -> Ordering { - self.0.partial_cmp(&other.0).unwrap().reverse() - } - } - let num_embeds = embeds.len(); let mut possible_edges = @@ -52,7 +56,7 @@ where break (b, a); } }; - mst.insert(new, Vec::new()); + mst.insert(new, vec![old]); // insert all the new edges we could take mst.entry(old).and_modify(|v| v.push(new)); @@ -69,16 +73,20 @@ where mst } -fn tsp_from_mst<M>(embeds: &Vec<M>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) +fn tsp_from_mst<M>(embeds: &[M], mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) where M: MetricElem, { - fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { + fn dfs(cur: usize, prev: usize, t: 
&HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { into.push(cur); - t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); + t.get(&cur).unwrap().iter().for_each(|&c| { + if c != prev { + dfs(c, cur, t, into) + } + }); } let mut tsp_path = Vec::with_capacity(mst.len()); - dfs(0, &mst, &mut tsp_path); + dfs(0, usize::MAX, &mst, &mut tsp_path); let mut total_dist = 0.; for i in 0..tsp_path.len() - 1 { @@ -88,6 +96,175 @@ where (tsp_path, total_dist) } +// TODO this is a non-ideal, greedy algorithm. there are better algorithms for this, +// and i should probably implement one. +/// 'verts' must be an even number of vertices with odd degree +fn min_weight_matching<M>(embeds: &[M], verts: &[usize]) -> HashMap<usize, usize> +where + M: MetricElem, +{ + let num_odd = verts.len(); + assert!(num_odd % 2 == 0); + + let mut possible_edges = PriorityQueue::with_capacity((num_odd * num_odd - num_odd) / 2); + for &x in verts { + for &y in verts { + if x != y { + possible_edges.push((x, y), DownOrd(embeds[x].dist(&embeds[y]))); + } + } + } + + let mut res = HashMap::new(); + + while res.len() != num_odd { + let ((a, b), _) = possible_edges.pop().unwrap(); + + if !res.contains_key(&a) && !res.contains_key(&b) { + res.insert(a, b); + res.insert(b, a); + } + } + + res +} + +fn euler_tour( + mut graph: HashMap<usize, HashMultiSet<usize>>, +) -> (usize, HashMap<usize, Vec<(usize, usize, usize)>>) { + let mut r: HashMap<_, Vec<_>> = HashMap::new(); + let mut partially_explored = HashSet::new(); + + // initial setup: pretend that we have only the path 'INF -> root -> INF' + // for some arbitrary root, and set cur to some node next to root. 
+ // This mimics the state we're in just after a new phase (because it is) + let &root = graph.keys().next().unwrap(); + r.insert(root, vec![(usize::MAX, usize::MAX, usize::MAX)]); + let e = graph.get_mut(&root).unwrap(); + let &next = e.iter().next().unwrap(); + e.remove(&next); + graph.get_mut(&next).unwrap().remove(&root); + + let mut cur = next; + let mut prev = root; + let mut pprev = usize::MAX; + let mut circ_start_edge = cur; + + loop { + let e = graph.get_mut(&cur).unwrap(); + if e.len() <= 1 { + partially_explored.remove(&cur); + } else { + partially_explored.insert(cur); + } + + match e.iter().next() { + Some(&next) => { + e.remove(&next); + // TODO this could perhaps be deduplicated + graph.get_mut(&next).unwrap().remove(&cur); + + r.entry(cur).or_default().push((pprev, prev, next)); + pprev = prev; + prev = cur; + cur = next; + } + None => { + // we got stuck, which means we returned to the starting vertex of + // the current phase. now, we need to join the 2 formed trips + + // pick an arbitrary existing edge-pair going through cur + let cur_vec = r.get_mut(&cur).unwrap(); + let (pp, p, n) = cur_vec[0]; + // reroute + cur_vec[0] = (pp, p, circ_start_edge); + cur_vec.push((pprev, prev, n)); + + // after rerouting, the pprev value of the next node will be wrong + match r.get_mut(&n) { + // should only happen with n == usize::MAX. no need to reroute the + // following node in that case, as there's no such node + None => (), + Some(rerouted_vec) => { + let p = rerouted_vec.iter().position(|&(_, other_p, _)| other_p == cur).unwrap(); + rerouted_vec[p] = (prev, cur, rerouted_vec[p].2); + } + } + + // are there any partially explored vertices left? 
+ match partially_explored.iter().next() { + None => break, // graph fully explored :) + Some(&new_cur) => { + // reset our active point (note that we don't have to delete + // new_cur from partially_explored; that's done the next time we + // get there) + let e = graph.get_mut(&new_cur).unwrap(); + let &next = e.iter().next().unwrap(); + e.remove(&next); + graph.get_mut(&next).unwrap().remove(&new_cur); + + circ_start_edge = next; + pprev = r[&new_cur][0].1; + prev = new_cur; + cur = next; + } + } + } + } + } + + (root, r) +} + +fn christofides<M>(embeds: &[M], mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) +where + M: MetricElem, +{ + let mut mst: HashMap<_, HashMultiSet<_>> = mst + .into_iter() + .map(|(k, v)| (k, v.into_iter().collect())) + .collect(); + + let odd_verts: Vec<_> = mst + .iter() + .filter_map(|(&i, s)| if s.len() % 2 == 0 { None } else { Some(i) }) + .collect(); + + // from here on, 'mst' is a bit of a misnomer, as we're adding more edges such + // that all vertices have even degree + for (a, b) in min_weight_matching(embeds, &odd_verts) { + mst.get_mut(&a).unwrap().insert(b); + } + + let mut r = Vec::new(); + let mut total_dist = 0.; + let mut visited_verts = HashSet::new(); + visited_verts.insert(usize::MAX); + + let mut pprev = usize::MAX; + let mut prev = usize::MAX; + let (mut cur, euler_tour) = euler_tour(mst); + loop { + match euler_tour.get(&cur) { + None => break, // tour complete: this should happen iff 'cur == usize::MAX' + Some(v) => { + let &(_, _, next) = v.iter().find(|(pp, p, _)| *pp == pprev && *p == prev).unwrap(); + + if visited_verts.insert(next) { + // haven't visited 'next' yet + r.push(next); + total_dist += embeds[cur].dist(&embeds[next]); + } + pprev = prev; + prev = cur; + cur = next; + } + } + } + + (r, total_dist) +} + pub fn tsp<M>(embeds: &Vec<M>) -> (Vec<usize>, f64) where M: MetricElem, @@ -97,7 +274,8 @@ where bar.enable_steady_tick(std::time::Duration::from_millis(100)); bar.set_message("Finding path..."); 
- let r = tsp_from_mst(embeds, get_mst(embeds)); + //let r = tsp_from_mst(embeds, get_mst(embeds)); + let r = christofides(embeds, get_mst(embeds)); bar.finish(); |