aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2024-03-04 20:34:51 +0100
committerLia Lenckowski <lialenck@protonmail.com>2024-03-04 20:34:51 +0100
commit4195f3e742bec7ed0ef1b193a725ea7335a547ef (patch)
tree646a7b0bb9b8f9a576b7bf045780748adae2eabe /src/main.rs
parent6ef681468c12618dc395a59882b27a5ccef27e30 (diff)
downloadembeddings-sort-4195f3e742bec7ed0ef1b193a725ea7335a547ef.tar
embeddings-sort-4195f3e742bec7ed0ef1b193a725ea7335a547ef.tar.bz2
embeddings-sort-4195f3e742bec7ed0ef1b193a725ea7335a547ef.tar.zst
variable number of improvement steps
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs15
1 files changed, 9 insertions, 6 deletions
diff --git a/src/main.rs b/src/main.rs
index 71e5d00..1a520bd 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,16 +25,15 @@ enum Embedder {
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
-enum TspAlg {
+enum TspBaseAlg {
MstDfs,
Christofides,
- ChristofidesRefined,
}
#[derive(Debug, Parser)]
struct Args {
/// Characteristic to sort by
- #[arg(short, long, default_value = "content-euclidean")]
+ #[arg(short, long, default_value = "content-angular-distance")]
embedder: Embedder,
/// Symlink the sorted images into this directory
@@ -57,9 +56,13 @@ struct Args {
#[arg(short = 'b', long)]
benchmark: bool,
- /// Algorithm for TSP approximation. Leave as default if unsure.
+ /// Algorithm for TSP approximation. Leave as default if unsure
#[arg(long, default_value = "christofides")]
- tsp_approx: TspAlg,
+ tsp_approx: TspBaseAlg,
+
+ /// number of 2-Opt refinement steps. Has quickly diminishing returns
+ #[arg(short = 'r', default_value = "3")]
+ refine: usize,
/// Seed for hashing. Random by default.
#[arg(long)]
@@ -142,7 +145,7 @@ where
}
let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
- let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, &args.hash_seed);
+ let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, args.refine, &args.hash_seed);
Ok((
tsp_path.iter().map(|i| args.images[*i].clone()).collect(),