diff options
author | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-18 11:34:32 +0200 |
---|---|---|
committer | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-18 11:34:32 +0200 |
commit | fe82169e0a84692cf161e6210c4c522912e70e72 (patch) | |
tree | 9769ba6a70f0849181e028a22c95835ca4e7503b /src | |
parent | 970e44a28a57c6299ddde7a107db7a024f901fd7 (diff) | |
download | embeddings-sort-fe82169e0a84692cf161e6210c4c522912e70e72.tar embeddings-sort-fe82169e0a84692cf161e6210c4c522912e70e72.tar.bz2 embeddings-sort-fe82169e0a84692cf161e6210c4c522912e70e72.tar.zst |
add tsp benchmark flag
Diffstat (limited to 'src')
-rw-r--r-- | src/main.rs | 18 | ||||
-rw-r--r-- | src/tsp_approx.rs | 16 |
2 files changed, 25 insertions, 9 deletions
diff --git a/src/main.rs b/src/main.rs index c8f5bbc..993e490 100644 --- a/src/main.rs +++ b/src/main.rs @@ -48,6 +48,10 @@ struct Args { #[arg(short = '0', long)] stdout0: bool, + /// Output total tour length to stderr + #[arg(short = 'b', long)] + benchmark: bool, + images: Vec<PathBuf>, } @@ -74,12 +78,12 @@ fn hash_file(p: &PathBuf) -> Result<[u8; 32]> { .unwrap()) } -fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBuf>> +fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec<PathBuf>, f64)> where E: BatchEmbedder, { if args.images.is_empty() { - return Ok(Vec::new()); + return Ok((Vec::new(), 0.)); } let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?; @@ -121,9 +125,9 @@ where } let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); - let tsp_path = tsp(&embeds); + let (tsp_path, total_dist) = tsp(&embeds); - Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect()) + Ok((tsp_path.iter().map(|i| args.images[*i].clone()).collect(), total_dist)) } fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> { @@ -153,13 +157,17 @@ fn main() -> Result<()> { let cfg = get_config()?; let args = Args::parse(); - let tsp_path = match args.embedder { + let (tsp_path, total_dist) = match args.embedder { Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg), Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg), Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg), Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg), }?; + if args.benchmark { + eprintln!("Found tour with length: {}", total_dist); + } + if let Some(p) = args.symlink_dir { copy_into(&tsp_path, &p, true)? } diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index c697a25..648d9ea 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -69,7 +69,10 @@ where mst } -fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> { +fn tsp_from_mst<M>(embeds: &Vec<M>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64) +where + M: MetricElem, +{ fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { into.push(cur); t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); @@ -77,10 +80,15 @@ fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> { let mut tsp_path = Vec::with_capacity(mst.len()); dfs(0, &mst, &mut tsp_path); - tsp_path + let mut total_dist = 0.; + for i in 0..tsp_path.len() - 1 { + total_dist += embeds[tsp_path[i]].dist(&embeds[tsp_path[i + 1]]); + } + + (tsp_path, total_dist) } -pub fn tsp<M>(embeds: &Vec<M>) -> Vec<usize> +pub fn tsp<M>(embeds: &Vec<M>) -> (Vec<usize>, f64) where M: MetricElem, { @@ -89,7 +97,7 @@ where bar.enable_steady_tick(std::time::Duration::from_millis(100)); bar.set_message("Finding path..."); - let r = tsp_from_mst(get_mst(embeds)); + let r = tsp_from_mst(embeds, get_mst(embeds)); bar.finish(); |