about summary refs log tree commit diff
diff options
context:
space:
mode:
author Lia Lenckowski <lialenck@protonmail.com> 2023-09-20 16:12:25 +0200
committer Lia Lenckowski <lialenck@protonmail.com> 2023-09-20 16:12:25 +0200
commit f62b4e356a1deecc550a2eba6d7d0caaad1303c1 (patch)
tree 62d08653b904fa48ae55b73345119d07e0bf3eec
parent 2d127740cf30cfbd3875a406ecc42ef6ebde60e4 (diff)
download embeddings-sort-f62b4e356a1deecc550a2eba6d7d0caaad1303c1.tar
embeddings-sort-f62b4e356a1deecc550a2eba6d7d0caaad1303c1.tar.bz2
embeddings-sort-f62b4e356a1deecc550a2eba6d7d0caaad1303c1.tar.zst
Small performance improvement for the AI embedder: cache the pairwise distances. This makes the other embedders slower, though.
-rw-r--r-- src/tsp_approx.rs 61
1 files changed, 33 insertions, 28 deletions
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
index 4c886cd..79aca3d 100644
--- a/src/tsp_approx.rs
+++ b/src/tsp_approx.rs
@@ -5,12 +5,7 @@ use std::collections::{HashMap, HashSet};
use crate::{MetricElem, TspAlg};
-fn get_mst<M>(embeds: &[M]) -> HashMap<usize, Vec<usize>>
-where
- M: MetricElem,
-{
- let num_embeds = embeds.len();
-
+fn get_mst(dist_cache: &HashMap<(usize, usize), f64>, num_embeds: usize) -> HashMap<usize, Vec<usize>> {
let mut possible_edges = Vec::with_capacity((num_embeds * num_embeds - num_embeds) / 2);
let mut mst = HashMap::with_capacity(num_embeds);
@@ -18,7 +13,7 @@ where
mst.insert(0, Vec::new());
for a in 0..num_embeds {
for b in a + 1..num_embeds {
- possible_edges.push((embeds[a].dist(&embeds[b]), a, b));
+ possible_edges.push((dist_cache[&(a.min(b), a.max(b))], a, b));
}
}
@@ -40,10 +35,7 @@ where
mst
}
-fn tsp_from_mst<M>(embeds: &[M], mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64)
-where
- M: MetricElem,
-{
+fn tsp_from_mst(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) {
fn dfs(cur: usize, prev: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
into.push(cur);
t.get(&cur).unwrap().iter().for_each(|&c| {
@@ -57,7 +49,8 @@ where
let mut total_dist = 0.;
for i in 0..tsp_path.len() - 1 {
- total_dist += embeds[tsp_path[i]].dist(&embeds[tsp_path[i + 1]]);
+ let (a, b) = (tsp_path[i], tsp_path[i + 1]);
+ total_dist += dist_cache[&(a.min(b), a.max(b))];
}
(tsp_path, total_dist)
@@ -66,10 +59,7 @@ where
// TODO this is a non-ideal, greedy algorithm. there are better algorithms for this,
// and i should probably implement one.
/// 'verts' must be an even number of vertices with odd degree
-fn min_weight_matching<M>(embeds: &[M], verts: &[usize]) -> HashMap<usize, usize>
-where
- M: MetricElem,
-{
+fn min_weight_matching(dist_cache: &HashMap<(usize, usize), f64>, verts: &[usize]) -> HashMap<usize, usize> {
let num_odd = verts.len();
assert!(num_odd % 2 == 0);
@@ -77,7 +67,7 @@ where
for &x in verts {
for &y in verts {
if x != y {
- possible_edges.push((embeds[x].dist(&embeds[y]), x, y));
+ possible_edges.push((dist_cache[&(x.min(y), x.max(y))], x, y));
}
}
}
@@ -188,10 +178,7 @@ fn euler_tour(
(root, r)
}
-fn christofides<M>(embeds: &[M], mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64)
-where
- M: MetricElem,
-{
+fn christofides(dist_cache: &HashMap<(usize, usize), f64>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) {
let mut mst: HashMap<_, HashMultiSet<_>> = mst
.into_iter()
.map(|(k, v)| (k, v.into_iter().collect()))
@@ -204,7 +191,7 @@ where
// from here on, 'mst' is a bit of a misnomer, as we're adding more edges such
// that all vertices have even degree
- for (a, b) in min_weight_matching(embeds, &odd_verts) {
+ for (a, b) in min_weight_matching(dist_cache, &odd_verts) {
mst.get_mut(&a).unwrap().insert(b);
}
@@ -228,7 +215,7 @@ where
if visited_verts.insert(next) {
// haven't visited 'next' yet
r.push(next);
- total_dist += embeds[cur].dist(&embeds[next]);
+ total_dist += dist_cache[&(cur.min(next), cur.max(next))];
}
pprev = prev;
prev = cur;
@@ -250,6 +237,21 @@ where
todo!()
}
+fn get_dist_cache<M>(embeds: &[M]) -> HashMap<(usize, usize), f64>
+where
+ M: MetricElem,
+{
+ let n = embeds.len();
+ let mut r = HashMap::with_capacity((n * n - n) / 2);
+ for a in 0..n {
+ for b in a + 1..n {
+ r.insert((a, b), embeds[a].dist(&embeds[b]));
+ }
+ }
+
+ r
+}
+
pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg) -> (Vec<usize>, f64)
where
M: MetricElem,
@@ -257,16 +259,19 @@ where
let bar = ProgressBar::new_spinner();
bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap());
bar.enable_steady_tick(std::time::Duration::from_millis(100));
- bar.set_message("Finding mst...");
- let mst = get_mst(embeds);
+ bar.set_message("Calculating distances...");
+ let dc = get_dist_cache(embeds);
+
+ bar.set_message("Finding mst...");
+ let mst = get_mst(&dc, embeds.len());
bar.set_message("Finding path...");
let r = match alg {
- TspAlg::MstDfs => tsp_from_mst(embeds, mst),
- TspAlg::Christofides => christofides(embeds, mst),
+ TspAlg::MstDfs => tsp_from_mst(&dc, mst),
+ TspAlg::Christofides => christofides(&dc, mst),
TspAlg::ChristofidesRefined => {
- let (p, l) = christofides(embeds, mst);
+ let (p, l) = christofides(&dc, mst);
bar.set_message("Refining path...");
refine(embeds, p, l)
}