diff options
author | Lia Lenckowski <lialenck@protonmail.com> | 2024-03-04 20:34:51 +0100 |
---|---|---|
committer | Lia Lenckowski <lialenck@protonmail.com> | 2024-03-04 20:34:51 +0100 |
commit | 4195f3e742bec7ed0ef1b193a725ea7335a547ef (patch) | |
tree | 646a7b0bb9b8f9a576b7bf045780748adae2eabe /src/main.rs | |
parent | 6ef681468c12618dc395a59882b27a5ccef27e30 (diff) | |
download | embeddings-sort-4195f3e742bec7ed0ef1b193a725ea7335a547ef.tar embeddings-sort-4195f3e742bec7ed0ef1b193a725ea7335a547ef.tar.bz2 embeddings-sort-4195f3e742bec7ed0ef1b193a725ea7335a547ef.tar.zst |
variable number of improvement steps
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/src/main.rs b/src/main.rs index 71e5d00..1a520bd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,16 +25,15 @@ enum Embedder { } #[derive(Debug, Clone, Copy, clap::ValueEnum)] -enum TspAlg { +enum TspBaseAlg { MstDfs, Christofides, - ChristofidesRefined, } #[derive(Debug, Parser)] struct Args { /// Characteristic to sort by - #[arg(short, long, default_value = "content-euclidean")] + #[arg(short, long, default_value = "content-angular-distance")] embedder: Embedder, /// Symlink the sorted images into this directory @@ -57,9 +56,13 @@ struct Args { #[arg(short = 'b', long)] benchmark: bool, - /// Algorithm for TSP approximation. Leave as default if unsure. + /// Algorithm for TSP approximation. Leave as default if unsure #[arg(long, default_value = "christofides")] - tsp_approx: TspAlg, + tsp_approx: TspBaseAlg, + + /// number of 2-Opt refinement steps. Has quickly diminishing returns + #[arg(short = 'r', default_value = "3")] + refine: usize, /// Seed for hashing. Random by default. #[arg(long)] @@ -142,7 +145,7 @@ where } let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); - let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, &args.hash_seed); + let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, args.refine, &args.hash_seed); Ok(( tsp_path.iter().map(|i| args.images[*i].clone()).collect(), |