aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2023-09-20 12:59:18 +0200
committerLia Lenckowski <lialenck@protonmail.com>2023-09-20 12:59:18 +0200
commitae445f481e9ae96b41d15ff22592f40ef5432302 (patch)
tree64f81a5be563a79c50d70fb5fa374e155dd7946d
parent5f926dd6e4f884f6d29c88207480a6bd0b97aa2a (diff)
downloadembeddings-sort-ae445f481e9ae96b41d15ff22592f40ef5432302.tar
embeddings-sort-ae445f481e9ae96b41d15ff22592f40ef5432302.tar.bz2
embeddings-sort-ae445f481e9ae96b41d15ff22592f40ef5432302.tar.zst
algorithm selection, big perfect matching performance improvement
-rw-r--r--src/main.rs29
-rw-r--r--src/tsp_approx.rs54
2 files changed, 64 insertions, 19 deletions
diff --git a/src/main.rs b/src/main.rs
index b9df0a9..9a3078a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -26,6 +26,13 @@ enum Embedder {
Content,
}
+#[derive(Debug, Clone, Copy, clap::ValueEnum)]
+enum TspAlg {
+ MstDfs,
+ Christofides,
+ ChristofidesRefined,
+}
+
#[derive(Debug, Parser)]
struct Args {
/// Characteristic to sort by
@@ -52,6 +59,10 @@ struct Args {
#[arg(short = 'b', long)]
benchmark: bool,
+ /// Algorithm for TSP approximation. Leave as default if unsure.
+ #[arg(long, default_value = "christofides")]
+ tsp_approx: TspAlg,
+
images: Vec<PathBuf>,
}
@@ -109,12 +120,16 @@ where
.collect();
// TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows
// for optimizations in the ai embedde (move pip to ::embeds() instead of ::new())
- let missing_embeds = e.embeds(
- &missing_embeds_indices
- .iter()
- .map(|i| args.images[*i].clone())
- .collect::<Vec<_>>(),
- )?;
+ let missing_embeds = if missing_embeds_indices.is_empty() {
+ Vec::new()
+ } else {
+ e.embeds(
+ &missing_embeds_indices
+ .iter()
+ .map(|i| args.images[*i].clone())
+ .collect::<Vec<_>>(),
+ )?
+ };
for (idx, emb) in missing_embeds_indices
.into_iter()
@@ -125,7 +140,7 @@ where
}
let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
- let (tsp_path, total_dist) = tsp(&embeds);
+ let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx);
Ok((
tsp_path.iter().map(|i| args.images[*i].clone()).collect(),
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
index 1da8212..6a19b22 100644
--- a/src/tsp_approx.rs
+++ b/src/tsp_approx.rs
@@ -6,7 +6,7 @@ use std::{
collections::{HashMap, HashSet},
};
-use crate::MetricElem;
+use crate::{MetricElem, TspAlg};
// wrapper struct to
// - reverse the ordering
@@ -26,7 +26,7 @@ impl Ord for DownOrd {
}
}
-fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
+fn get_mst<M>(embeds: &[M]) -> HashMap<usize, Vec<usize>>
where
M: MetricElem,
{
@@ -106,19 +106,23 @@ where
let num_odd = verts.len();
assert!(num_odd % 2 == 0);
- let mut possible_edges = PriorityQueue::with_capacity((num_odd * num_odd - num_odd) / 2);
+ let mut possible_edges = Vec::with_capacity((num_odd * num_odd - num_odd) / 2);
for &x in verts {
for &y in verts {
if x != y {
- possible_edges.push((x, y), DownOrd(embeds[x].dist(&embeds[y])));
+ //possible_edges.push((x, y), DownOrd(embeds[x].dist(&embeds[y])));
+ possible_edges.push((embeds[x].dist(&embeds[y]), x, y));
}
}
}
+ possible_edges.sort_unstable_by(|(da, _, _), (db, _, _)| da.partial_cmp(db).unwrap());
let mut res = HashMap::new();
- while res.len() != num_odd {
- let ((a, b), _) = possible_edges.pop().unwrap();
+ for (_, a, b) in possible_edges.into_iter() {
+ if res.len() >= num_odd {
+ break;
+ }
if !res.contains_key(&a) && !res.contains_key(&b) {
res.insert(a, b);
@@ -186,7 +190,10 @@ fn euler_tour(
// following node in that case, as there's no such node
None => (),
Some(rerouted_vec) => {
- let p = rerouted_vec.iter().position(|&(_, other_p, _)| other_p == cur).unwrap();
+ let p = rerouted_vec
+ .iter()
+ .position(|&(_, other_p, _)| other_p == cur)
+ .unwrap();
rerouted_vec[p] = (prev, cur, rerouted_vec[p].2);
}
}
@@ -248,7 +255,10 @@ where
match euler_tour.get(&cur) {
None => break, // tour complete: this should happen iff 'cur == usize::MAX'
Some(v) => {
- let &(_, _, next) = v.iter().find(|(pp, p, _)| *pp == pprev && *p == prev).unwrap();
+ let &(_, _, next) = v
+ .iter()
+ .find(|(pp, p, _)| *pp == pprev && *p == prev)
+ .unwrap();
if visited_verts.insert(next) {
// haven't visited 'next' yet
@@ -265,17 +275,37 @@ where
(r, total_dist)
}
-pub fn tsp<M>(embeds: &Vec<M>) -> (Vec<usize>, f64)
+fn refine<M>(_: &[M], _: Vec<usize>, _: f64) -> (Vec<usize>, f64)
+where
+ M: MetricElem,
+{
+ // convert the tour into a linked-list representation. instead of pointers, we use
+ // an array of indices
+ //let tour_ll = Vec::new();
+ todo!()
+}
+
+pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg) -> (Vec<usize>, f64)
where
M: MetricElem,
{
let bar = ProgressBar::new_spinner();
bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap());
bar.enable_steady_tick(std::time::Duration::from_millis(100));
- bar.set_message("Finding path...");
+ bar.set_message("Finding mst...");
+
+ let mst = get_mst(embeds);
- //let r = tsp_from_mst(embeds, get_mst(embeds));
- let r = christofides(embeds, get_mst(embeds));
+ bar.set_message("Finding path...");
+ let r = match alg {
+ TspAlg::MstDfs => tsp_from_mst(embeds, mst),
+ TspAlg::Christofides => christofides(embeds, mst),
+ TspAlg::ChristofidesRefined => {
+ let (p, l) = christofides(embeds, mst);
+ bar.set_message("Refining path...");
+ refine(embeds, p, l)
+ }
+ };
bar.finish();