diff options
-rw-r--r-- | src/tsp_approx.rs | 61 |
1 files changed, 33 insertions, 28 deletions
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index 4c886cd..79aca3d 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -5,12 +5,7 @@ use std::collections::{HashMap, HashSet}; use crate::{MetricElem, TspAlg}; -fn get_mst<M>(embeds: &[M]) -> HashMap<usize, Vec<usize>> -where - M: MetricElem, -{ - let num_embeds = embeds.len(); - +fn get_mst(dist_cache: &HashMap<(usize, usize), f64>, num_embeds: usize) -> HashMap<usize, Vec<usize>> { let mut possible_edges = Vec::with_capacity((num_embeds * num_embeds - num_embeds) / 2); let mut mst = HashMap::with_capacity(num_embeds); @@ -18,7 +13,7 @@ where mst.insert(0, Vec::new()); for a in 0..num_embeds { for b in a + 1..num_embeds { - possible_edges.push((embeds[a].dist(&embeds[b]), a, b)); + possible_edges.push((dist_cache[&(a.min(b), a.max(b))], a, b)); } } @@ -40,10 +35,7 @@ where mst } -fn tsp_from_mst<M>(embeds: &[M], mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) -where - M: MetricElem, -{ +fn tsp_from_mst(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) { fn dfs(cur: usize, prev: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { into.push(cur); t.get(&cur).unwrap().iter().for_each(|&c| { @@ -57,7 +49,8 @@ where let mut total_dist = 0.; for i in 0..tsp_path.len() - 1 { - total_dist += embeds[tsp_path[i]].dist(&embeds[tsp_path[i + 1]]); + let (a, b) = (tsp_path[i], tsp_path[i + 1]); + total_dist += dist_cache[&(a.min(b), a.max(b))]; } (tsp_path, total_dist) @@ -66,10 +59,7 @@ where // TODO this is a non-ideal, greedy algorithm. there are better algorithms for this, // and i should probably implement one. /// 'verts' must be an even number of vertices with odd degree -fn min_weight_matching<M>(embeds: &[M], verts: &[usize]) -> HashMap<usize, usize> -where - M: MetricElem, -{ +fn min_weight_matching(dist_cache: &HashMap<(usize, usize), f64>, verts: &[usize]) -> HashMap<usize, usize> { let num_odd = verts.len(); assert!(num_odd % 2 == 0); @@ -77,7 +67,7 @@ where for &x in verts { for &y in verts { if x != y { - possible_edges.push((embeds[x].dist(&embeds[y]), x, y)); + possible_edges.push((dist_cache[&(x.min(y), x.max(y))], x, y)); } } } @@ -188,10 +178,7 @@ fn euler_tour( (root, r) } -fn christofides<M>(embeds: &[M], mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) -where - M: MetricElem, -{ +fn christofides(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) { let mut mst: HashMap<_, HashMultiSet<_>> = mst .into_iter() .map(|(k, v)| (k, v.into_iter().collect())) @@ -204,7 +191,7 @@ where // from here on, 'mst' is a bit of a misnomer, as we're adding more edges such // that all vertices have even degree - for (a, b) in min_weight_matching(embeds, &odd_verts) { + for (a, b) in min_weight_matching(dist_cache, &odd_verts) { mst.get_mut(&a).unwrap().insert(b); } @@ -228,7 +215,7 @@ where if visited_verts.insert(next) { // haven't visited 'next' yet r.push(next); - total_dist += embeds[cur].dist(&embeds[next]); + total_dist += dist_cache[&(cur.min(next), cur.max(next))]; } pprev = prev; prev = cur; @@ -250,6 +237,21 @@ where todo!() } +fn get_dist_cache<M>(embeds: &[M]) -> HashMap<(usize, usize), f64> +where + M: MetricElem, +{ + let n = embeds.len(); + let mut r = HashMap::with_capacity((n * n - n) / 2); + for a in 0..n { + for b in a + 1..n { + r.insert((a, b), embeds[a].dist(&embeds[b])); + } + } + + r +} + pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg) -> (Vec<usize>, f64) where M: MetricElem, @@ -257,16 +259,19 @@ where let bar = ProgressBar::new_spinner(); bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap()); bar.enable_steady_tick(std::time::Duration::from_millis(100)); - bar.set_message("Finding mst..."); - let mst = get_mst(embeds); + bar.set_message("Calculating distances..."); + let dc = get_dist_cache(embeds); + + bar.set_message("Finding mst..."); + let mst = get_mst(&dc, embeds.len()); bar.set_message("Finding path..."); let r = match alg { - TspAlg::MstDfs => tsp_from_mst(embeds, mst), - TspAlg::Christofides => christofides(embeds, mst), + TspAlg::MstDfs => tsp_from_mst(&dc, mst), + TspAlg::Christofides => christofides(&dc, mst), TspAlg::ChristofidesRefined => { - let (p, l) = christofides(embeds, mst); + let (p, l) = christofides(&dc, mst); bar.set_message("Refining path..."); refine(embeds, p, l) } |