diff options
-rw-r--r-- | src/main.rs | 15 | ||||
-rw-r--r-- | src/tsp_approx.rs | 38 |
2 files changed, 34 insertions, 19 deletions
diff --git a/src/main.rs b/src/main.rs index 71e5d00..1a520bd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,16 +25,15 @@ enum Embedder { } #[derive(Debug, Clone, Copy, clap::ValueEnum)] -enum TspAlg { +enum TspBaseAlg { MstDfs, Christofides, - ChristofidesRefined, } #[derive(Debug, Parser)] struct Args { /// Characteristic to sort by - #[arg(short, long, default_value = "content-euclidean")] + #[arg(short, long, default_value = "content-angular-distance")] embedder: Embedder, /// Symlink the sorted images into this directory @@ -57,9 +56,13 @@ struct Args { #[arg(short = 'b', long)] benchmark: bool, - /// Algorithm for TSP approximation. Leave as default if unsure. + /// Algorithm for TSP approximation. Leave as default if unsure #[arg(long, default_value = "christofides")] - tsp_approx: TspAlg, + tsp_approx: TspBaseAlg, + + /// number of 2-Opt refinement steps. Has quickly diminishing returns + #[arg(short = 'r', default_value = "3")] + refine: usize, /// Seed for hashing. Random by default. #[arg(long)] @@ -142,7 +145,7 @@ where } let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); - let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, &args.hash_seed); + let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, args.refine, &args.hash_seed); Ok(( tsp_path.iter().map(|i| args.images[*i].clone()).collect(), diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index 7c9568d..097fee7 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -4,7 +4,7 @@ use mset::MultiSet; use partitions::partition_vec::PartitionVec; use std::{cmp::Eq, hash::Hash, mem::swap}; -use crate::{MetricElem, TspAlg}; +use crate::{MetricElem, TspBaseAlg}; struct DistCache(HashMap<(usize, usize), f64>); impl DistCache { @@ -285,9 +285,9 @@ fn christofides( r } -fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> Vec<usize> { +fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> (bool, Vec<usize>) { if tour.len() < 4 { - return tour; + return (false, tour); } // convert the tour into a doubly linked-list. instead of pointers, we use @@ -327,6 +327,8 @@ fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> Vec<usize> { } }; + let mut improved = false; + // for each combination of edges ... while le.0 != n - 2 { let mut re = adv(le, &tour); @@ -355,6 +357,8 @@ fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> Vec<usize> { swap(&mut le.1, &mut re.0); + improved = true; + re = adv(re, &tour); } @@ -372,7 +376,7 @@ fn refine_2_opt(dist_cache: &DistCache, tour: Vec<usize>) -> Vec<usize> { cur = adv(cur, &tour); } - tour_flat + (improved, tour_flat) } fn get_dist_cache<M>(embeds: &[M], hash_seed: &Option<u64>) -> DistCache @@ -390,7 +394,12 @@ where DistCache(r) } -pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg, hash_seed: &Option<u64>) -> (Vec<usize>, f64) +pub(crate) fn tsp<M>( + embeds: &[M], + alg: &TspBaseAlg, + refinements: usize, + hash_seed: &Option<u64>, +) -> (Vec<usize>, f64) where M: MetricElem, { @@ -405,16 +414,19 @@ where let mst = get_mst(&dc, embeds.len(), hash_seed); bar.set_message("Finding path..."); - let tour = match alg { - TspAlg::MstDfs => tsp_from_mst(mst), - TspAlg::Christofides => christofides(&dc, mst, hash_seed), - TspAlg::ChristofidesRefined => { - let tour = christofides(&dc, mst, hash_seed); - bar.set_message("Refining path..."); - refine_2_opt(&dc, tour) - } + let mut tour = match alg { + TspBaseAlg::MstDfs => tsp_from_mst(mst), + TspBaseAlg::Christofides => christofides(&dc, mst, hash_seed), }; + for _ in 0..refinements { + let res = refine_2_opt(&dc, tour); + tour = res.1; + if !res.0 { + break; // stop early at convergence + } + } + let mut total_dist = 0.; for i in 0..tour.len() - 1 { let (a, b) = (tour[i], tour[i + 1]); |