diff options
-rw-r--r-- | src/main.rs | 29 | ||||
-rw-r--r-- | src/tsp_approx.rs | 54 |
2 files changed, 64 insertions, 19 deletions
diff --git a/src/main.rs b/src/main.rs index b9df0a9..9a3078a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,6 +26,13 @@ enum Embedder { Content, } +#[derive(Debug, Clone, Copy, clap::ValueEnum)] +enum TspAlg { + MstDfs, + Christofides, + ChristofidesRefined, +} + #[derive(Debug, Parser)] struct Args { /// Characteristic to sort by @@ -52,6 +59,10 @@ struct Args { #[arg(short = 'b', long)] benchmark: bool, + /// Algorithm for TSP approximation. Leave as default if unsure. + #[arg(long, default_value = "christofides")] + tsp_approx: TspAlg, + images: Vec<PathBuf>, } @@ -109,12 +120,16 @@ where .collect(); // TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows // for optimizations in the ai embedde (move pip to ::embeds() instead of ::new()) - let missing_embeds = e.embeds( - &missing_embeds_indices - .iter() - .map(|i| args.images[*i].clone()) - .collect::<Vec<_>>(), - )?; + let missing_embeds = if missing_embeds_indices.is_empty() { + Vec::new() + } else { + e.embeds( + &missing_embeds_indices + .iter() + .map(|i| args.images[*i].clone()) + .collect::<Vec<_>>(), + )? + }; for (idx, emb) in missing_embeds_indices .into_iter() @@ -125,7 +140,7 @@ where } let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); - let (tsp_path, total_dist) = tsp(&embeds); + let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx); Ok(( tsp_path.iter().map(|i| args.images[*i].clone()).collect(), diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index 1da8212..6a19b22 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -6,7 +6,7 @@ use std::{ collections::{HashMap, HashSet}, }; -use crate::MetricElem; +use crate::{MetricElem, TspAlg}; // wrapper struct to // - reverse the ordering @@ -26,7 +26,7 @@ impl Ord for DownOrd { } } -fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>> +fn get_mst<M>(embeds: &[M]) -> HashMap<usize, Vec<usize>> where M: MetricElem, { @@ -106,19 +106,23 @@ where let num_odd = verts.len(); assert!(num_odd % 2 == 0); - let mut possible_edges = PriorityQueue::with_capacity((num_odd * num_odd - num_odd) / 2); + let mut possible_edges = Vec::with_capacity((num_odd * num_odd - num_odd) / 2); for &x in verts { for &y in verts { if x != y { - possible_edges.push((x, y), DownOrd(embeds[x].dist(&embeds[y]))); + //possible_edges.push((x, y), DownOrd(embeds[x].dist(&embeds[y]))); + possible_edges.push((embeds[x].dist(&embeds[y]), x, y)); } } } + possible_edges.sort_unstable_by(|(da, _, _), (db, _, _)| da.partial_cmp(db).unwrap()); let mut res = HashMap::new(); - while res.len() != num_odd { - let ((a, b), _) = possible_edges.pop().unwrap(); + for (_, a, b) in possible_edges.into_iter() { + if res.len() >= num_odd { + break; + } if !res.contains_key(&a) && !res.contains_key(&b) { res.insert(a, b); @@ -186,7 +190,10 @@ fn euler_tour( // following node in that case, as there's no such node None => (), Some(rerouted_vec) => { - let p = rerouted_vec.iter().position(|&(_, other_p, _)| other_p == cur).unwrap(); + let p = rerouted_vec + .iter() + .position(|&(_, other_p, _)| other_p == cur) + .unwrap(); rerouted_vec[p] = (prev, cur, rerouted_vec[p].2); } } @@ -248,7 +255,10 @@ where match euler_tour.get(&cur) { None => break, // tour complete: this should happen iff 'cur == usize::MAX' Some(v) => { - let &(_, _, next) = v.iter().find(|(pp, p, _)| *pp == pprev && *p == prev).unwrap(); + let &(_, _, next) = v + .iter() + .find(|(pp, p, _)| *pp == pprev && *p == prev) + .unwrap(); if visited_verts.insert(next) { // haven't visited 'next' yet @@ -265,17 +275,37 @@ where (r, total_dist) } -pub fn tsp<M>(embeds: &Vec<M>) -> (Vec<usize>, f64) +fn refine<M>(_: &[M], _: Vec<usize>, _: f64) -> (Vec<usize>, f64) +where + M: MetricElem, +{ + // convert the tour into a linked-list representation. instead of pointers, we use + // an array of indices + //let tour_ll = Vec::new(); + todo!() +} + +pub(crate) fn tsp<M>(embeds: &[M], alg: &TspAlg) -> (Vec<usize>, f64) where M: MetricElem, { let bar = ProgressBar::new_spinner(); bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap()); bar.enable_steady_tick(std::time::Duration::from_millis(100)); - bar.set_message("Finding path..."); + bar.set_message("Finding mst..."); + + let mst = get_mst(embeds); - //let r = tsp_from_mst(embeds, get_mst(embeds)); - let r = christofides(embeds, get_mst(embeds)); + bar.set_message("Finding path..."); + let r = match alg { + TspAlg::MstDfs => tsp_from_mst(embeds, mst), + TspAlg::Christofides => christofides(embeds, mst), + TspAlg::ChristofidesRefined => { + let (p, l) = christofides(embeds, mst); + bar.set_message("Refining path..."); + refine(embeds, p, l) + } + }; bar.finish(); |