diff options
Diffstat (limited to 'src/tsp_approx.rs')
-rw-r--r-- | src/tsp_approx.rs | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs new file mode 100644 index 0000000..c697a25 --- /dev/null +++ b/src/tsp_approx.rs @@ -0,0 +1,97 @@ +use indicatif::{ProgressBar, ProgressStyle}; +use priority_queue::PriorityQueue; +use std::{cmp::Ordering, collections::HashMap}; + +use crate::MetricElem; + +fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>> +where + M: MetricElem, +{ + // wrapper struct to + // - reverse the ordering + // - implement Ord, even though the type is backed by an f64 + #[repr(transparent)] + #[derive(Debug, PartialEq)] + struct DownOrd(f64); + impl Eq for DownOrd {} + impl PartialOrd for DownOrd { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } + } + impl Ord for DownOrd { + fn cmp(&self, other: &Self) -> Ordering { + self.0.partial_cmp(&other.0).unwrap().reverse() + } + } + + let num_embeds = embeds.len(); + + let mut possible_edges = + PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2); + let mut mst = HashMap::with_capacity(num_embeds); + + // here, we start at 0. + // we might get a better result in the end if we started with a vertex next + // to the lowest-cost edge, but we don't know which one that is (though we + // could compute that without changing our asymptotic complexity) + mst.insert(0, Vec::new()); + for i in 1..num_embeds { + possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i]))); + } + + // prims algorithm or something like that + while mst.len() < num_embeds { + // find the edge with the least cost that connects us to a new vertex + let (new, old) = loop { + let ((a, b), _) = possible_edges.pop().unwrap(); + if !mst.contains_key(&a) { + break (a, b); + } else if !mst.contains_key(&b) { + break (b, a); + } + }; + mst.insert(new, Vec::new()); + + // insert all the new edges we could take + mst.entry(old).and_modify(|v| v.push(new)); + for i in 0..num_embeds { + // don't consider edges taking us to nodes we already visited + if mst.contains_key(&i) { + continue; + } + + possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i]))); + } + } + + mst +} + +fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> { + fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { + into.push(cur); + t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); + } + let mut tsp_path = Vec::with_capacity(mst.len()); + dfs(0, &mst, &mut tsp_path); + + tsp_path +} + +pub fn tsp<M>(embeds: &Vec<M>) -> Vec<usize> +where + M: MetricElem, +{ + let bar = ProgressBar::new_spinner(); + bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap()); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + bar.set_message("Finding path..."); + + let r = tsp_from_mst(get_mst(embeds)); + + bar.finish(); + + r +} |