aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Lia Lenckowski <lialenck@protonmail.com> 2024-03-04 20:34:51 +0100
committer Lia Lenckowski <lialenck@protonmail.com> 2024-03-04 20:34:51 +0100
commit 4195f3e742bec7ed0ef1b193a725ea7335a547ef (patch)
tree 646a7b0bb9b8f9a576b7bf045780748adae2eabe
parent 6ef681468c12618dc395a59882b27a5ccef27e30 (diff)
download embeddings-sort-4195f3e742bec7ed0ef1b193a725ea7335a547ef.tar
embeddings-sort-4195f3e742bec7ed0ef1b193a725ea7335a547ef.tar.bz2
embeddings-sort-4195f3e742bec7ed0ef1b193a725ea7335a547ef.tar.zst
variable number of improvement steps
-rw-r--r-- src/main.rs 15
-rw-r--r-- src/tsp_approx.rs 38
2 files changed, 34 insertions, 19 deletions
diff --git a/src/main.rs b/src/main.rs
index 71e5d00..1a520bd 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,16 +25,15 @@ enum Embedder {
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
-enum TspAlg {
+enum TspBaseAlg {
MstDfs,
Christofides,
- ChristofidesRefined,
}
#[derive(Debug, Parser)]
struct Args {
/// Characteristic to sort by
- #[arg(short, long, default_value = "content-euclidean")]
+ #[arg(short, long, default_value = "content-angular-distance")]
embedder: Embedder,
/// Symlink the sorted images into this directory
@@ -57,9 +56,13 @@ struct Args {
#[arg(short = 'b', long)]
benchmark: bool,
- /// Algorithm for TSP approximation. Leave as default if unsure.
+ /// Algorithm for TSP approximation. Leave as default if unsure
#[arg(long, default_value = "christofides")]
- tsp_approx: TspAlg,
+ tsp_approx: TspBaseAlg,
+
+ /// number of 2-Opt refinement steps. Has quickly diminishing returns
+ #[arg(short = 'r', default_value = "3")]
+ refine: usize,
/// Seed for hashing. Random by default.
#[arg(long)]
@@ -142,7 +145,7 @@ where
}
let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
- let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, &args.hash_seed);
+ let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, args.refine, &args.hash_seed);
Ok((
tsp_path.iter().map(|i| args.images[*i].clone()).collect(),
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
index 7c9568d..097fee7 100644
--- a/src/tsp_approx.rs
+++ b/src/tsp_approx.rs
@@ -4,7 +4,7 @@ use mset::MultiSet;
use partitions::partition_vec::PartitionVec;
use std::{cmp::Eq, hash::Hash, mem::swap};
-use crate::{MetricElem, TspAlg};
+use crate::{MetricElem, TspBaseAlg};
struct DistCache(HashMap<(usize, usize), f64>);
impl DistCache {
@@ -285,9 +285,9 @@ fn christofides(
r
}
-fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> Vec<usize> {
+fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> (bool, Vec<usize>) {
if tour.len() < 4 {
- return tour;
+ return (false, tour);
}
// convert the tour into a doubly linked-list. instead of pointers, we use
@@ -327,6 +327,8 @@ fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> Vec<usize> {
}
};
+ let mut improved = false;
+
// for each combination of edges ...
while le.0 != n - 2 {
let mut re = adv(le, &tour);
@@ -355,6 +357,8 @@ fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> Vec<usize> {
swap(&mut le.1, &mut re.0);
+ improved = true;
+
re = adv(re, &tour);
}
@@ -372,7 +376,7 @@ fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> Vec<usize> {
cur = adv(cur, &tour);
}
- tour_flat
+ (improved, tour_flat)
}
fn get_dist_cache<M>(embeds: &[M], hash_seed: &Option<u64>) -> DistCache
@@ -390,7 +394,12 @@ where
DistCache(r)
}
-pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg, hash_seed: &Option<u64>) -> (Vec<usize>, f64)
+pub(crate) fn tsp<M>(
+ embeds: &[M],
+ alg: &TspBaseAlg,
+ refinements: usize,
+ hash_seed: &Option<u64>,
+) -> (Vec<usize>, f64)
where
M: MetricElem,
{
@@ -405,16 +414,19 @@ where
let mst = get_mst(&dc, embeds.len(), hash_seed);
bar.set_message("Finding path...");
- let tour = match alg {
- TspAlg::MstDfs => tsp_from_mst(mst),
- TspAlg::Christofides => christofides(&dc, mst, hash_seed),
- TspAlg::ChristofidesRefined => {
- let tour = christofides(&dc, mst, hash_seed);
- bar.set_message("Refining path...");
- refine_2_opt(&dc, tour)
- }
+ let mut tour = match alg {
+ TspBaseAlg::MstDfs => tsp_from_mst(mst),
+ TspBaseAlg::Christofides => christofides(&dc, mst, hash_seed),
};
+ for _ in 0..refinements {
+ let res = refine_2_opt(&dc, tour);
+ tour = res.1;
+ if !res.0 {
+ break; // stop early at convergence
+ }
+ }
+
let mut total_dist = 0.;
for i in 0..tour.len() - 1 {
let (a, b) = (tour[i], tour[i + 1]);