#![feature(iterator_try_collect)] use anyhow::{anyhow, Result}; use clap::Parser; use sha2::{Digest, Sha512_256}; use std::{ fs, io::{self, Write}, path::PathBuf, }; #[cfg(unix)] use std::path::absolute; use embedders::*; use tsp_approx::*; mod embedders; mod tsp_approx; #[derive(Debug, Clone, Copy, clap::ValueEnum)] enum Embedder { Brightness, Hue, Color, ContentEuclidean, ContentAngularDistance, ContentManhatten, } #[derive(Debug, Clone, Copy, clap::ValueEnum)] enum TspBaseAlg { MstDfs, Christofides, } #[derive(Debug, Parser)] struct Args { /// Characteristic to sort by #[arg(short, long, default_value = "content-angular-distance")] embedder: Embedder, /// Symlink the sorted images into this directory #[cfg(unix)] #[arg(short = 's', long)] symlink_dir: Option, /// Copy the sorted images into this directory. Uses COW when available #[arg(short = 'o', long)] copy_dir: Option, /// Write sorted paths into stdout, one per line #[arg(short = 'c', long)] stdout: bool, /// Write sorted paths into stdout, null-separated. Overrides -c #[arg(short = '0', long)] stdout0: bool, /// Output total tour length to stderr #[arg(short = 'b', long)] benchmark: bool, /// Algorithm for TSP approximation. Leave as default if unsure #[arg(long, default_value = "christofides")] tsp_approx: TspBaseAlg, /// Number of 2-Opt refinement steps. Has quickly diminishing returns #[arg(short = 'r', default_value = "3")] refine: usize, /// Ignore failed embeddings #[arg(short = 'i')] ignore_errors: bool, /// Seed for hashing. Random by default. #[arg(long)] hash_seed: Option, images: Vec, } #[derive(Debug)] struct Config { cache_dir: PathBuf, } fn get_config() -> Result { let glob_cache_dir = dirs::cache_dir().ok_or(anyhow!("Could not get cache directory"))?; Ok(Config { cache_dir: glob_cache_dir.join("embeddings-sort"), }) } fn hash_file(p: &PathBuf) -> Result<[u8; 32]> { let mut f = fs::File::open(p)?; let mut hasher = Sha512_256::new(); io::copy(&mut f, &mut hasher)?; Ok(hasher .finalize() .into_iter() .collect::>() .try_into() .unwrap()) } fn process_embedder(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec, f64)> where E: BatchEmbedder, { let db = sled::open(cfg.cache_dir.join("embeddings.db"))?; let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME); // find cached embeddings let mut embeds: Vec<_> = args .images .iter() .map(|path| { let h = hash_file(path)?; let r: Result> = tree.get(&h).map_err(|e| e.into()); r }) .try_collect()?; // find indices of missing embeddings let missing_embeds_indices: Vec<_> = embeds .iter() .enumerate() .filter_map(|(i, cached_embedding)| match cached_embedding { None => Some(i), Some(_) => None, }) .collect(); // calculate missing embeddings let missing_embeds = if missing_embeds_indices.is_empty() { vec![] } else { e.embeds( &missing_embeds_indices .iter() .map(|i| args.images[*i].clone()) .collect::>(), ) }; // insert successfully changed for (idx, emb) in missing_embeds_indices .into_iter() .zip(missing_embeds.into_iter()) { match emb { Ok(emb) => { tree.insert(&hash_file(&args.images[idx])?, &emb)?; embeds[idx] = Some(emb); } Err(e) => { if !args.ignore_errors { return Err(e); } } } } // filter out images with failed embeddings let (embeds, images): (Vec<_>, Vec<_>) = embeds .into_iter() .zip(args.images.iter()) .filter_map(|(emb, path)| match emb { Some(embedding) => Some((embedding, path)), None => { if args.ignore_errors { None } else { panic!("Embedding failed for {}", path.display()) } } }) .unzip(); let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, args.refine, &args.hash_seed); Ok(( tsp_path.iter().map(|i| images[*i].clone()).collect(), total_dist, )) } #[allow(unused_variables)] // use_symlinks on windows fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> { fs::create_dir_all(target)?; let pad_len = (tsp.len() as f64).log10().ceil() as usize; for (i, p) in tsp.iter().enumerate() { let ext: String = match p.extension() { None => "".to_string(), Some(e) => format!(".{}", e.to_str().unwrap()), }; let tp = target.join(format!("{i:0pad_len$}{ext}")); #[cfg(unix)] if use_symlinks { let rel_path = pathdiff::diff_paths(absolute(p)?, absolute(target)?).unwrap(); let _ = fs::remove_file(&tp); std::os::unix::fs::symlink(rel_path, tp)?; } else { reflink_copy::reflink_or_copy(p, tp)?; } #[cfg(not(unix))] reflink_copy::reflink_or_copy(p, tp)?; } Ok(()) } fn main() -> Result<()> { let cfg = get_config()?; let args = Args::parse(); let (tsp_path, total_dist) = match args.embedder { Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg), Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg), Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg), Embedder::ContentAngularDistance => { process_embedder(ContentEmbedder::::new(&cfg), &args, &cfg) } Embedder::ContentEuclidean => { process_embedder(ContentEmbedder::::new(&cfg), &args, &cfg) } Embedder::ContentManhatten => { process_embedder(ContentEmbedder::::new(&cfg), &args, &cfg) } }?; if args.benchmark { eprintln!("Found tour with length: {}", total_dist); } #[cfg(unix)] if let Some(p) = args.symlink_dir { copy_into(&tsp_path, &p, true)? } if let Some(p) = args.copy_dir { copy_into(&tsp_path, &p, false)? } let path_delim = if args.stdout0 { Some(0) } else if args.stdout { Some(b'\n') } else { None }; path_delim.into_iter().try_for_each(|delim| { let mut o = io::BufWriter::new(io::stdout().lock()); for p in &tsp_path { o.write_all(p.as_os_str().to_str().unwrap().as_bytes())?; o.write_all(&[delim])?; } o.flush() })?; Ok(()) }