1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
|
use indicatif::{ProgressBar, ProgressStyle};
use priority_queue::PriorityQueue;
use std::{cmp::Ordering, collections::HashMap};
use crate::MetricElem;
/// Builds a minimum spanning tree over `embeds` with Prim's algorithm, rooted at index 0.
///
/// The input is treated as a complete graph: the weight of the edge between
/// two elements is the value returned by `dist` for that pair.
///
/// Returns the tree as a map from each vertex index to the list of its
/// children, in the order the children were attached. Every vertex of a
/// non-empty input gets an entry; leaves map to an empty list.
///
/// The root entry `0` is inserted unconditionally, so the returned map
/// contains the key `0` even when `embeds` is empty.
///
/// # Panics
///
/// Panics if the queue has to compare two distances of which one is NaN.
fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
where
    M: MetricElem,
{
    // Wrapper around an `f64` edge weight that
    // - reverses the ordering, so the smallest distance compares as the
    //   greatest priority
    // - implements `Ord`, even though the type is backed by an f64
    #[repr(transparent)]
    #[derive(Debug, PartialEq)]
    struct DownOrd(f64);
    impl Eq for DownOrd {}
    impl PartialOrd for DownOrd {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }
    impl Ord for DownOrd {
        fn cmp(&self, other: &Self) -> Ordering {
            self.0
                .partial_cmp(&other.0)
                .expect("edge weights must be comparable, but a distance was NaN")
                .reverse()
        }
    }

    let num_embeds = embeds.len();
    // A complete graph on n vertices has n * (n - 1) / 2 edges; that is also
    // the total number of edges this function ever pushes.
    let mut frontier = PriorityQueue::with_capacity(num_embeds * num_embeds.saturating_sub(1) / 2);
    let mut tree: HashMap<usize, Vec<usize>> = HashMap::with_capacity(num_embeds);

    // Here, we start at 0.
    // We might get a better result in the end if we started with a vertex next
    // to the lowest-cost edge, but we don't know which one that is (though we
    // could compute that without changing our asymptotic complexity).
    tree.insert(0, Vec::new());
    for i in 1..num_embeds {
        frontier.push((0, i), DownOrd(embeds[0].dist(&embeds[i])));
    }

    // Prim's algorithm: repeatedly attach the cheapest edge leaving the tree.
    while tree.len() < num_embeds {
        // Every queued edge is `(inside, candidate)` where `inside` was already
        // part of the tree when the edge was pushed, so only `candidate` can
        // still be outside. Edges whose candidate has been attached in the
        // meantime are stale and simply dropped.
        let (parent, child) = loop {
            let ((inside, candidate), _) = frontier
                .pop()
                .expect("an unvisited vertex always has a queued edge from the tree");
            if !tree.contains_key(&candidate) {
                break (inside, candidate);
            }
        };

        tree.get_mut(&parent)
            .expect("the inner endpoint of a queued edge is already in the tree")
            .push(child);
        tree.insert(child, Vec::new());

        // Queue all the new edges we could take from the freshly attached vertex.
        for i in 0..num_embeds {
            // Don't consider edges taking us to nodes we already visited.
            if tree.contains_key(&i) {
                continue;
            }
            frontier.push((child, i), DownOrd(embeds[child].dist(&embeds[i])));
        }
    }
    tree
}
/// Turns a spanning tree into a visiting order and measures that order's length.
///
/// The tree `mst` maps each vertex index to its children and is walked in
/// depth-first pre-order starting at vertex 0, visiting children in the order
/// they are stored.
///
/// Returns the visiting order together with the sum of `dist` over every pair
/// of consecutive elements in it. For empty `embeds` the result is an empty
/// path of length `0.0`.
///
/// # Panics
///
/// Panics if `embeds` is non-empty and a vertex reached during the walk has no
/// entry in `mst`, or if `mst` names an index that is out of bounds for
/// `embeds`.
fn tsp_from_mst<M>(embeds: &Vec<M>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64)
where
    M: MetricElem,
{
    // Nothing to visit: never hand back an index that does not exist.
    if embeds.is_empty() {
        return (Vec::new(), 0.0);
    }

    // Pre-order walk with an explicit stack instead of recursion, so a deep
    // (chain-shaped) tree cannot overflow the call stack.
    let mut tsp_path = Vec::with_capacity(mst.len());
    let mut pending = vec![0usize];
    while let Some(cur) = pending.pop() {
        tsp_path.push(cur);
        let children = mst
            .get(&cur)
            .expect("every vertex reached from the root has an entry in the tree");
        // Reverse so that the first child ends up on top of the stack and is
        // visited first, exactly as a recursive walk would.
        pending.extend(children.iter().rev().copied());
    }

    // Accumulate left to right, starting from 0.0.
    let total_dist: f64 = tsp_path
        .windows(2)
        .fold(0.0, |acc, pair| acc + embeds[pair[0]].dist(&embeds[pair[1]]));

    (tsp_path, total_dist)
}
/// Computes a heuristic travelling-salesman path through `embeds`.
///
/// A minimum spanning tree is built over the elements (using `dist` as the
/// edge weight) and then walked depth-first from element 0; the visiting order
/// of that walk is the path.
///
/// Returns the path as a list of indices into `embeds`, together with the sum
/// of the distances between consecutive elements of the path. For empty input
/// the path is empty and the length is `0.0`.
///
/// A spinner with the message "Finding path..." is shown while the path is
/// being computed.
pub fn tsp<M>(embeds: &Vec<M>) -> (Vec<usize>, f64)
where
    M: MetricElem,
{
    let bar = ProgressBar::new_spinner();
    bar.set_style(
        ProgressStyle::with_template("{spinner} {msg}")
            .expect("the hard-coded spinner template is valid"),
    );
    bar.enable_steady_tick(std::time::Duration::from_millis(100));
    bar.set_message("Finding path...");

    // With no elements there is nothing to visit; skip the tree construction,
    // which always creates a root vertex 0 and would otherwise produce a path
    // naming an element that does not exist.
    let result = if embeds.is_empty() {
        (Vec::new(), 0.0)
    } else {
        tsp_from_mst(embeds, get_mst(embeds))
    };

    bar.finish();
    result
}
|