use indicatif::{ProgressBar, ProgressStyle}; use priority_queue::PriorityQueue; use std::{cmp::Ordering, collections::HashMap}; use crate::MetricElem; fn get_mst(embeds: &Vec) -> HashMap> where M: MetricElem, { // wrapper struct to // - reverse the ordering // - implement Ord, even though the type is backed by an f64 #[repr(transparent)] #[derive(Debug, PartialEq)] struct DownOrd(f64); impl Eq for DownOrd {} impl PartialOrd for DownOrd { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for DownOrd { fn cmp(&self, other: &Self) -> Ordering { self.0.partial_cmp(&other.0).unwrap().reverse() } } let num_embeds = embeds.len(); let mut possible_edges = PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2); let mut mst = HashMap::with_capacity(num_embeds); // here, we start at 0. // we might get a better result in the end if we started with a vertex next // to the lowest-cost edge, but we don't know which one that is (though we // could compute that without changing our asymptotic complexity) mst.insert(0, Vec::new()); for i in 1..num_embeds { possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i]))); } // prims algorithm or something like that while mst.len() < num_embeds { // find the edge with the least cost that connects us to a new vertex let (new, old) = loop { let ((a, b), _) = possible_edges.pop().unwrap(); if !mst.contains_key(&a) { break (a, b); } else if !mst.contains_key(&b) { break (b, a); } }; mst.insert(new, Vec::new()); // insert all the new edges we could take mst.entry(old).and_modify(|v| v.push(new)); for i in 0..num_embeds { // don't consider edges taking us to nodes we already visited if mst.contains_key(&i) { continue; } possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i]))); } } mst } fn tsp_from_mst(embeds: &Vec, mst: HashMap>) -> (Vec, f64) where M: MetricElem, { fn dfs(cur: usize, t: &HashMap>, into: &mut Vec) { into.push(cur); t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); } let mut tsp_path = Vec::with_capacity(mst.len()); dfs(0, &mst, &mut tsp_path); let mut total_dist = 0.; for i in 0..tsp_path.len() - 1 { total_dist += embeds[tsp_path[i]].dist(&embeds[tsp_path[i + 1]]); } (tsp_path, total_dist) } pub fn tsp(embeds: &Vec) -> (Vec, f64) where M: MetricElem, { let bar = ProgressBar::new_spinner(); bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap()); bar.enable_steady_tick(std::time::Duration::from_millis(100)); bar.set_message("Finding path..."); let r = tsp_from_mst(embeds, get_mst(embeds)); bar.finish(); r }