aboutsummaryrefslogtreecommitdiff
path: root/src/tsp_approx.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/tsp_approx.rs')
-rw-r--r--src/tsp_approx.rs97
1 files changed, 97 insertions, 0 deletions
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
new file mode 100644
index 0000000..c697a25
--- /dev/null
+++ b/src/tsp_approx.rs
@@ -0,0 +1,97 @@
+use indicatif::{ProgressBar, ProgressStyle};
+use priority_queue::PriorityQueue;
+use std::{cmp::Ordering, collections::HashMap};
+
+use crate::MetricElem;
+
+fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
+where
+ M: MetricElem,
+{
+ // wrapper struct to
+ // - reverse the ordering
+ // - implement Ord, even though the type is backed by an f64
+ #[repr(transparent)]
+ #[derive(Debug, PartialEq)]
+ struct DownOrd(f64);
+ impl Eq for DownOrd {}
+ impl PartialOrd for DownOrd {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+ }
+ impl Ord for DownOrd {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.0.partial_cmp(&other.0).unwrap().reverse()
+ }
+ }
+
+ let num_embeds = embeds.len();
+
+ let mut possible_edges =
+ PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2);
+ let mut mst = HashMap::with_capacity(num_embeds);
+
+ // here, we start at 0.
+ // we might get a better result in the end if we started with a vertex next
+ // to the lowest-cost edge, but we don't know which one that is (though we
+ // could compute that without changing our asymptotic complexity)
+ mst.insert(0, Vec::new());
+ for i in 1..num_embeds {
+ possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i])));
+ }
+
+ // prims algorithm or something like that
+ while mst.len() < num_embeds {
+ // find the edge with the least cost that connects us to a new vertex
+ let (new, old) = loop {
+ let ((a, b), _) = possible_edges.pop().unwrap();
+ if !mst.contains_key(&a) {
+ break (a, b);
+ } else if !mst.contains_key(&b) {
+ break (b, a);
+ }
+ };
+ mst.insert(new, Vec::new());
+
+ // insert all the new edges we could take
+ mst.entry(old).and_modify(|v| v.push(new));
+ for i in 0..num_embeds {
+ // don't consider edges taking us to nodes we already visited
+ if mst.contains_key(&i) {
+ continue;
+ }
+
+ possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i])));
+ }
+ }
+
+ mst
+}
+
+fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> {
+ fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
+ into.push(cur);
+ t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into));
+ }
+ let mut tsp_path = Vec::with_capacity(mst.len());
+ dfs(0, &mst, &mut tsp_path);
+
+ tsp_path
+}
+
+pub fn tsp<M>(embeds: &Vec<M>) -> Vec<usize>
+where
+ M: MetricElem,
+{
+ let bar = ProgressBar::new_spinner();
+ bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap());
+ bar.enable_steady_tick(std::time::Duration::from_millis(100));
+ bar.set_message("Finding path...");
+
+ let r = tsp_from_mst(get_mst(embeds));
+
+ bar.finish();
+
+ r
+}