aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2023-09-19 15:24:45 +0200
committerLia Lenckowski <lialenck@protonmail.com>2023-09-19 15:24:45 +0200
commit5f926dd6e4f884f6d29c88207480a6bd0b97aa2a (patch)
treed5e1f0d316a325b23739c37b59689262e003aec3
parentfe82169e0a84692cf161e6210c4c522912e70e72 (diff)
downloadembeddings-sort-5f926dd6e4f884f6d29c88207480a6bd0b97aa2a.tar
embeddings-sort-5f926dd6e4f884f6d29c88207480a6bd0b97aa2a.tar.bz2
embeddings-sort-5f926dd6e4f884f6d29c88207480a6bd0b97aa2a.tar.zst
christofides algorithm (~8% improvement)
-rw-r--r--Cargo.lock7
-rw-r--r--Cargo.toml1
-rw-r--r--src/main.rs5
-rw-r--r--src/tsp_approx.rs228
4 files changed, 215 insertions, 26 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 486a600..3871c28 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -288,6 +288,7 @@ dependencies = [
"clap",
"image",
"indicatif",
+ "multiset",
"pathdiff",
"priority-queue",
"rayon",
@@ -581,6 +582,12 @@ dependencies = [
]
[[package]]
+name = "multiset"
+version = "0.0.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce8738c9ddd350996cb8b8b718192851df960803764bcdaa3afb44a63b1ddb5c"
+
+[[package]]
name = "nanorand"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index b6df99d..20d6077 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,3 +20,4 @@ sha2 = "0"
anyhow = "1"
pathdiff = "0"
reflink-copy = "0"
+multiset = "0"
diff --git a/src/main.rs b/src/main.rs
index 993e490..b9df0a9 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -127,7 +127,10 @@ where
let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
let (tsp_path, total_dist) = tsp(&embeds);
- Ok((tsp_path.iter().map(|i| args.images[*i].clone()).collect(), total_dist))
+ Ok((
+ tsp_path.iter().map(|i| args.images[*i].clone()).collect(),
+ total_dist,
+ ))
}
fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> {
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
index 648d9ea..1da8212 100644
--- a/src/tsp_approx.rs
+++ b/src/tsp_approx.rs
@@ -1,31 +1,35 @@
use indicatif::{ProgressBar, ProgressStyle};
+use multiset::HashMultiSet;
use priority_queue::PriorityQueue;
-use std::{cmp::Ordering, collections::HashMap};
+use std::{
+ cmp::Ordering,
+ collections::{HashMap, HashSet},
+};
use crate::MetricElem;
+// wrapper struct around an f64 to
+// - reverse the ordering, so that the smallest value compares as the greatest
+// - implement Ord, even though the type is backed by an f64
+// comparing two values panics if either of them is NaN, because the 'None' returned
+// by f64::partial_cmp is unwrapped
+#[repr(transparent)]
+#[derive(Debug, PartialEq)]
+struct DownOrd(f64);
+// marker impl only: equality is still the derived comparison of the inner f64
+impl Eq for DownOrd {}
+impl PartialOrd for DownOrd {
+ // delegates to Ord::cmp, so both orderings always agree
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+impl Ord for DownOrd {
+ // natural f64 order, reversed
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.0.partial_cmp(&other.0).unwrap().reverse()
+ }
+}
+
fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
where
M: MetricElem,
{
- // wrapper struct to
- // - reverse the ordering
- // - implement Ord, even though the type is backed by an f64
- #[repr(transparent)]
- #[derive(Debug, PartialEq)]
- struct DownOrd(f64);
- impl Eq for DownOrd {}
- impl PartialOrd for DownOrd {
- fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
- Some(self.cmp(other))
- }
- }
- impl Ord for DownOrd {
- fn cmp(&self, other: &Self) -> Ordering {
- self.0.partial_cmp(&other.0).unwrap().reverse()
- }
- }
-
let num_embeds = embeds.len();
let mut possible_edges =
@@ -52,7 +56,7 @@ where
break (b, a);
}
};
- mst.insert(new, Vec::new());
+ mst.insert(new, vec![old]);
// insert all the new edges we could take
mst.entry(old).and_modify(|v| v.push(new));
@@ -69,16 +73,20 @@ where
mst
}
-fn tsp_from_mst<M>(embeds: &Vec<M>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64)
+fn tsp_from_mst<M>(embeds: &[M], mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64)
where
M: MetricElem,
{
- fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
+ fn dfs(cur: usize, prev: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
into.push(cur);
- t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into));
+ t.get(&cur).unwrap().iter().for_each(|&c| {
+ if c != prev {
+ dfs(c, cur, t, into)
+ }
+ });
}
let mut tsp_path = Vec::with_capacity(mst.len());
- dfs(0, &mst, &mut tsp_path);
+ dfs(0, usize::MAX, &mst, &mut tsp_path);
let mut total_dist = 0.;
for i in 0..tsp_path.len() - 1 {
@@ -88,6 +96,175 @@ where
(tsp_path, total_dist)
}
+// TODO this is a non-ideal, greedy algorithm. there are better algorithms for this,
+// and i should probably implement one.
+/// Greedily pairs up the vertices in 'verts': repeatedly takes the shortest remaining
+/// edge and keeps it if both of its endpoints are still unmatched.
+///
+/// 'verts' must be an even number of vertices with odd degree. The returned map is
+/// symmetric: for every matched pair it contains both 'a -> b' and 'b -> a'.
+///
+/// Panics if 'verts' has odd length.
+fn min_weight_matching<M>(embeds: &[M], verts: &[usize]) -> HashMap<usize, usize>
+where
+    M: MetricElem,
+{
+    let num_odd = verts.len();
+    assert!(num_odd % 2 == 0);
+
+    // Queue every unordered pair exactly once. Queueing (y, x) in addition to (x, y)
+    // doubles the queue without ever changing the matching, since the second copy of
+    // an edge is always rejected once the first has been handled.
+    // NOTE(review): this relies on 'dist' being symmetric, as expected of a metric.
+    let mut possible_edges = PriorityQueue::with_capacity((num_odd * num_odd - num_odd) / 2);
+    for (i, &x) in verts.iter().enumerate() {
+        for &y in &verts[i + 1..] {
+            if x != y {
+                possible_edges.push((x, y), DownOrd(embeds[x].dist(&embeds[y])));
+            }
+        }
+    }
+
+    let mut res = HashMap::with_capacity(num_odd);
+
+    while res.len() != num_odd {
+        // DownOrd reverses the ordering, so the queue yields the shortest edge first
+        let ((a, b), _) = possible_edges
+            .pop()
+            .expect("ran out of candidate edges before every vertex was matched");
+
+        if !res.contains_key(&a) && !res.contains_key(&b) {
+            res.insert(a, b);
+            res.insert(b, a);
+        }
+    }
+
+    res
+}
+
+/// Computes an Euler tour of 'graph', consuming its edges as they are walked.
+///
+/// Returns '(root, r)': 'root' is the vertex the tour starts at, and 'r' maps each
+/// vertex to a list of '(pprev, prev, next)' triples. A triple stored for vertex 'v'
+/// reads: having arrived at 'v' from 'prev' (which was in turn reached from 'pprev'),
+/// continue to 'next'. 'usize::MAX' stands in for "no vertex" at both ends of the tour.
+///
+/// Every walked edge is removed from the multisets of both of its endpoints, so
+/// 'graph' is presumably expected to store each undirected edge in both directions,
+/// with every vertex having even degree (TODO confirm against the caller).
+///
+/// Panics if 'graph' is empty or if the chosen root has no incident edge.
+fn euler_tour(
+ mut graph: HashMap<usize, HashMultiSet<usize>>,
+) -> (usize, HashMap<usize, Vec<(usize, usize, usize)>>) {
+ let mut r: HashMap<_, Vec<_>> = HashMap::new();
+ let mut partially_explored = HashSet::new();
+
+ // initial setup: pretend that we have only the path 'INF -> root -> INF'
+ // for some arbitrary root, and set cur to some node next to root.
+ // This mimics the state we're in just after a new phase (because it is)
+ let &root = graph.keys().next().unwrap();
+ r.insert(root, vec![(usize::MAX, usize::MAX, usize::MAX)]);
+ let e = graph.get_mut(&root).unwrap();
+ let &next = e.iter().next().unwrap();
+ e.remove(&next);
+ graph.get_mut(&next).unwrap().remove(&root);
+
+ let mut cur = next;
+ let mut prev = root;
+ let mut pprev = usize::MAX;
+ // first vertex stepped to in the current phase; used when splicing the phase's
+ // circuit into the existing tour
+ let mut circ_start_edge = cur;
+
+ loop {
+ let e = graph.get_mut(&cur).unwrap();
+ // with at most one edge left, leaving cur below uses it up; otherwise cur
+ // keeps unexplored edges and may serve as the start of a later phase
+ if e.len() <= 1 {
+ partially_explored.remove(&cur);
+ } else {
+ partially_explored.insert(cur);
+ }
+
+ match e.iter().next() {
+ Some(&next) => {
+ e.remove(&next);
+ // TODO this could perhaps be deduplicated
+ graph.get_mut(&next).unwrap().remove(&cur);
+
+ r.entry(cur).or_default().push((pprev, prev, next));
+ pprev = prev;
+ prev = cur;
+ cur = next;
+ }
+ None => {
+ // we got stuck, which means we returned to the starting vertex of
+ // the current phase. now, we need to join the 2 formed trips
+
+ // pick an arbitrary existing edge-pair going through cur
+ let cur_vec = r.get_mut(&cur).unwrap();
+ let (pp, p, n) = cur_vec[0];
+ // reroute
+ cur_vec[0] = (pp, p, circ_start_edge);
+ cur_vec.push((pprev, prev, n));
+
+ // after rerouting, the pprev value of the next node will be wrong
+ match r.get_mut(&n) {
+ // should only happen with n == usize::MAX. no need to reroute the
+ // following node in that case, as there's no such node
+ None => (),
+ Some(rerouted_vec) => {
+ let p = rerouted_vec.iter().position(|&(_, other_p, _)| other_p == cur).unwrap();
+ rerouted_vec[p] = (prev, cur, rerouted_vec[p].2);
+ }
+ }
+
+ // are there any partially explored vertices left?
+ match partially_explored.iter().next() {
+ None => break, // graph fully explored :)
+ Some(&new_cur) => {
+ // reset our active point (note that we don't have to delete
+ // new_cur from partially_explored; that's done the next time we
+ // get there)
+ let e = graph.get_mut(&new_cur).unwrap();
+ let &next = e.iter().next().unwrap();
+ e.remove(&next);
+ graph.get_mut(&next).unwrap().remove(&new_cur);
+
+ circ_start_edge = next;
+ // continue as if we had just arrived at new_cur via the
+ // predecessor stored in its first recorded triple
+ pprev = r[&new_cur][0].1;
+ prev = new_cur;
+ cur = next;
+ }
+ }
+ }
+ }
+ }
+
+ (root, r)
+}
+
+/// Turns a minimum spanning tree into a path that visits every vertex once, following
+/// the Christofides scheme: match up the odd-degree vertices, walk an Euler tour of
+/// the resulting multigraph, and skip every vertex that was already visited.
+///
+/// 'mst' maps each vertex to its neighbours in the tree. Returns the visiting order
+/// together with the summed distance between consecutive vertices of that order.
+fn christofides<M>(embeds: &[M], mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64)
+where
+    M: MetricElem,
+{
+    let mut mst: HashMap<_, HashMultiSet<_>> = mst
+        .into_iter()
+        .map(|(k, v)| (k, v.into_iter().collect()))
+        .collect();
+
+    let odd_verts: Vec<_> = mst
+        .iter()
+        .filter_map(|(&i, s)| if s.len() % 2 == 0 { None } else { Some(i) })
+        .collect();
+
+    // from here on, 'mst' is a bit of a misnomer, as we're adding more edges such
+    // that all vertices have even degree
+    for (a, b) in min_weight_matching(embeds, &odd_verts) {
+        // the matching contains both 'a -> b' and 'b -> a', so inserting one
+        // direction per entry adds the edge to both endpoints
+        mst.get_mut(&a).unwrap().insert(b);
+    }
+
+    let mut r: Vec<usize> = Vec::new();
+    let mut total_dist = 0.;
+    let mut visited_verts = HashSet::new();
+    // the sentinel marking the end of the tour must never end up in the path
+    visited_verts.insert(usize::MAX);
+
+    let mut pprev = usize::MAX;
+    let mut prev = usize::MAX;
+    let (mut cur, euler_tour) = euler_tour(mst);
+    loop {
+        match euler_tour.get(&cur) {
+            None => break, // tour complete: this should happen iff 'cur == usize::MAX'
+            Some(v) => {
+                let &(_, _, next) = v
+                    .iter()
+                    .find(|(pp, p, _)| *pp == pprev && *p == prev)
+                    .unwrap();
+
+                if visited_verts.insert(next) {
+                    // haven't visited 'next' yet. the distance is measured from the
+                    // vertex the path actually comes from, i.e. the last one pushed:
+                    // 'cur' may be an already visited vertex that the path skips
+                    if let Some(&last) = r.last() {
+                        total_dist += embeds[last].dist(&embeds[next]);
+                    }
+                    r.push(next);
+                }
+                pprev = prev;
+                prev = cur;
+                cur = next;
+            }
+        }
+    }
+
+    (r, total_dist)
+}
+
pub fn tsp<M>(embeds: &Vec<M>) -> (Vec<usize>, f64)
where
M: MetricElem,
@@ -97,7 +274,8 @@ where
bar.enable_steady_tick(std::time::Duration::from_millis(100));
bar.set_message("Finding path...");
- let r = tsp_from_mst(embeds, get_mst(embeds));
+ //let r = tsp_from_mst(embeds, get_mst(embeds));
+ let r = christofides(embeds, get_mst(embeds));
bar.finish();