aboutsummaryrefslogtreecommitdiff
path: root/src/tsp_approx.rs
blob: c697a255e3077a8a30adfb6e6603ad1a9bb97c83 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
use indicatif::{ProgressBar, ProgressStyle};
use priority_queue::PriorityQueue;
use std::{cmp::Ordering, collections::HashMap};

use crate::MetricElem;

fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
where
    M: MetricElem,
{
    // wrapper struct to
    // - reverse the ordering
    // - implement Ord, even though the type is backed by an f64
    #[repr(transparent)]
    #[derive(Debug, PartialEq)]
    struct DownOrd(f64);
    impl Eq for DownOrd {}
    impl PartialOrd for DownOrd {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }
    impl Ord for DownOrd {
        fn cmp(&self, other: &Self) -> Ordering {
            self.0.partial_cmp(&other.0).unwrap().reverse()
        }
    }

    let num_embeds = embeds.len();

    let mut possible_edges =
        PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2);
    let mut mst = HashMap::with_capacity(num_embeds);

    // here, we start at 0.
    // we might get a better result in the end if we started with a vertex next
    // to the lowest-cost edge, but we don't know which one that is (though we
    // could compute that without changing our asymptotic complexity)
    mst.insert(0, Vec::new());
    for i in 1..num_embeds {
        possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i])));
    }

    // prims algorithm or something like that
    while mst.len() < num_embeds {
        // find the edge with the least cost that connects us to a new vertex
        let (new, old) = loop {
            let ((a, b), _) = possible_edges.pop().unwrap();
            if !mst.contains_key(&a) {
                break (a, b);
            } else if !mst.contains_key(&b) {
                break (b, a);
            }
        };
        mst.insert(new, Vec::new());

        // insert all the new edges we could take
        mst.entry(old).and_modify(|v| v.push(new));
        for i in 0..num_embeds {
            // don't consider edges taking us to nodes we already visited
            if mst.contains_key(&i) {
                continue;
            }

            possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i])));
        }
    }

    mst
}

/// Lists the vertices of `mst` in depth-first preorder starting at root 0,
/// i.e. the visiting order produced by the MST-walk TSP heuristic.
///
/// `mst` maps each vertex to its children (as returned by `get_mst`).
/// Children are visited in the order they appear in their parent's list.
/// An empty map yields an empty path.
///
/// The walk uses an explicit stack instead of recursion, so deep (path-shaped)
/// trees cannot overflow the call stack.
///
/// # Panics
/// If a non-empty `mst` has no entry for vertex 0, or a listed child has no
/// entry of its own.
fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> {
    if mst.is_empty() {
        return Vec::new();
    }

    let mut tsp_path = Vec::with_capacity(mst.len());
    let mut stack = vec![0];
    while let Some(cur) = stack.pop() {
        tsp_path.push(cur);
        let children = mst
            .get(&cur)
            .expect("every vertex in the tree has an adjacency entry");
        // Push in reverse so the first child is popped (visited) first,
        // reproducing the recursive preorder exactly.
        stack.extend(children.iter().rev().copied());
    }

    tsp_path
}

/// Computes an approximate shortest visiting order over `embeds`, showing a
/// terminal spinner while it works.
///
/// Builds a minimum spanning tree over all pairwise `dist` values and returns
/// the tree's depth-first preorder, starting at index 0. If `dist` satisfies
/// the triangle inequality, this classic heuristic gives a closed tour at most
/// twice the optimal length.
///
/// The result holds indices into `embeds`. An empty input returns an empty
/// `Vec` immediately, without starting the spinner.
pub fn tsp<M>(embeds: &Vec<M>) -> Vec<usize>
where
    M: MetricElem,
{
    // Nothing to order; avoid reporting a phantom index 0.
    if embeds.is_empty() {
        return Vec::new();
    }

    let bar = ProgressBar::new_spinner();
    bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap());
    // Tick on a timer, since the MST build gives no incremental progress.
    bar.enable_steady_tick(std::time::Duration::from_millis(100));
    bar.set_message("Finding path...");

    let r = tsp_from_mst(get_mst(embeds));

    bar.finish();

    r
}