aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2023-09-18 11:34:32 +0200
committerLia Lenckowski <lialenck@protonmail.com>2023-09-18 11:34:32 +0200
commitfe82169e0a84692cf161e6210c4c522912e70e72 (patch)
tree9769ba6a70f0849181e028a22c95835ca4e7503b /src
parent970e44a28a57c6299ddde7a107db7a024f901fd7 (diff)
downloadembeddings-sort-fe82169e0a84692cf161e6210c4c522912e70e72.tar
embeddings-sort-fe82169e0a84692cf161e6210c4c522912e70e72.tar.bz2
embeddings-sort-fe82169e0a84692cf161e6210c4c522912e70e72.tar.zst
add tsp benchmark flag
Diffstat (limited to 'src')
-rw-r--r--src/main.rs18
-rw-r--r--src/tsp_approx.rs16
2 files changed, 25 insertions, 9 deletions
diff --git a/src/main.rs b/src/main.rs
index c8f5bbc..993e490 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -48,6 +48,10 @@ struct Args {
#[arg(short = '0', long)]
stdout0: bool,
+ /// Output total tour length to stderr
+ #[arg(short = 'b', long)]
+ benchmark: bool,
+
images: Vec<PathBuf>,
}
@@ -74,12 +78,12 @@ fn hash_file(p: &PathBuf) -> Result<[u8; 32]> {
.unwrap())
}
-fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBuf>>
+fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec<PathBuf>, f64)>
where
E: BatchEmbedder,
{
if args.images.is_empty() {
- return Ok(Vec::new());
+ return Ok((Vec::new(), 0.));
}
let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?;
@@ -121,9 +125,9 @@ where
}
let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
- let tsp_path = tsp(&embeds);
+ let (tsp_path, total_dist) = tsp(&embeds);
- Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect())
+ Ok((tsp_path.iter().map(|i| args.images[*i].clone()).collect(), total_dist))
}
fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> {
@@ -153,13 +157,17 @@ fn main() -> Result<()> {
let cfg = get_config()?;
let args = Args::parse();
- let tsp_path = match args.embedder {
+ let (tsp_path, total_dist) = match args.embedder {
Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg),
Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg),
Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg),
Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg),
}?;
+ if args.benchmark {
+ eprintln!("Found tour with length: {}", total_dist);
+ }
+
if let Some(p) = args.symlink_dir {
copy_into(&tsp_path, &p, true)?
}
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
index c697a25..648d9ea 100644
--- a/src/tsp_approx.rs
+++ b/src/tsp_approx.rs
@@ -69,7 +69,10 @@ where
mst
}
-fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> {
+fn tsp_from_mst<M>(embeds: &Vec<M>, mst: HashMap<usize, Vec<usize>>) -> (Vec<usize>, f64)
+where
+ M: MetricElem,
+{
fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
into.push(cur);
t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into));
@@ -77,10 +80,15 @@ fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> {
let mut tsp_path = Vec::with_capacity(mst.len());
dfs(0, &mst, &mut tsp_path);
- tsp_path
+ let mut total_dist = 0.;
+ for i in 0..tsp_path.len() - 1 {
+ total_dist += embeds[tsp_path[i]].dist(&embeds[tsp_path[i + 1]]);
+ }
+
+ (tsp_path, total_dist)
}
-pub fn tsp<M>(embeds: &Vec<M>) -> Vec<usize>
+pub fn tsp<M>(embeds: &Vec<M>) -> (Vec<usize>, f64)
where
M: MetricElem,
{
@@ -89,7 +97,7 @@ where
bar.enable_steady_tick(std::time::Duration::from_millis(100));
bar.set_message("Finding path...");
- let r = tsp_from_mst(get_mst(embeds));
+ let r = tsp_from_mst(embeds, get_mst(embeds));
bar.finish();