#![feature(iterator_try_collect, absolute_path)] use anyhow::Result; use clap::Parser; use sha2::{Digest, Sha512_256}; use std::{ fs, io::{self, Write}, path::{self, PathBuf}, }; use embedders::*; use tsp_approx::*; mod embedders; mod tsp_approx; #[derive(Debug, Clone, Copy, clap::ValueEnum)] enum Embedder { Brightness, Hue, Color, ContentEuclidean, ContentCosineSim, ContentManhatten, } #[derive(Debug, Clone, Copy, clap::ValueEnum)] enum TspAlg { MstDfs, Christofides, ChristofidesRefined, } #[derive(Debug, Parser)] struct Args { /// Characteristic to sort by #[arg(short, long, default_value = "content-euclidean")] embedder: Embedder, /// Symlink the sorted images into this directory #[arg(short = 's', long)] symlink_dir: Option, /// Copy the sorted images into this directory. Uses COW when available #[arg(short = 'o', long)] copy_dir: Option, /// Write sorted paths into stdout, one per line #[arg(short = 'c', long)] stdout: bool, /// Write sorted paths into stdout, null-separated. Overrides -c #[arg(short = '0', long)] stdout0: bool, /// Output total tour length to stderr #[arg(short = 'b', long)] benchmark: bool, /// Algorithm for TSP approximation. Leave as default if unsure. #[arg(long, default_value = "christofides")] tsp_approx: TspAlg, images: Vec, } #[derive(Debug)] struct Config { base_dirs: xdg::BaseDirectories, } fn get_config() -> Result { let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort")?; Ok(Config { base_dirs: dirs }) } fn hash_file(p: &PathBuf) -> Result<[u8; 32]> { let mut f = fs::File::open(p)?; let mut hasher = Sha512_256::new(); io::copy(&mut f, &mut hasher)?; Ok(hasher .finalize() .into_iter() .collect::>() .try_into() .unwrap()) } fn process_embedder(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec, f64)> where E: BatchEmbedder, { if args.images.is_empty() { return Ok((Vec::new(), 0.)); } let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?; let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME); let mut embeds: Vec> = args .images .iter() .map(|p| { let h = hash_file(p)?; let r: Result> = tree.get(&h).map_err(|e| e.into()); r }) .try_collect()?; let missing_embeds_indices: Vec<_> = embeds .iter() .enumerate() .filter_map(|(i, v)| match v { None => Some(i), Some(_) => None, }) .collect(); // TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows // for optimizations in the ai embedde (move pip to ::embeds() instead of ::new()) let missing_embeds = if missing_embeds_indices.is_empty() { Vec::new() } else { e.embeds( &missing_embeds_indices .iter() .map(|i| args.images[*i].clone()) .collect::>(), )? }; for (idx, emb) in missing_embeds_indices .into_iter() .zip(missing_embeds.into_iter()) { tree.insert(&hash_file(&args.images[idx])?, &emb)?; embeds[idx] = Some(emb); } let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx); Ok(( tsp_path.iter().map(|i| args.images[*i].clone()).collect(), total_dist, )) } fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> { fs::create_dir_all(target)?; let pad_len = (tsp.len() as f64).log10().ceil() as usize; for (i, p) in tsp.iter().enumerate() { let ext: String = match p.extension() { None => "".to_string(), Some(e) => format!(".{}", e.to_str().unwrap()), }; let tp = target.join(format!("{i:0pad_len$}{ext}")); if use_symlinks { let rel_path = pathdiff::diff_paths(path::absolute(p)?, path::absolute(target)?).unwrap(); let _ = fs::remove_file(&tp); std::os::unix::fs::symlink(rel_path, tp)?; } else { reflink_copy::reflink_or_copy(p, tp)?; } } Ok(()) } fn main() -> Result<()> { let cfg = get_config()?; let args = Args::parse(); let (tsp_path, total_dist) = match args.embedder { Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg), Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg), Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg), Embedder::ContentCosineSim => { process_embedder(ContentEmbedder::::new(&cfg), &args, &cfg) } Embedder::ContentEuclidean => { process_embedder(ContentEmbedder::::new(&cfg), &args, &cfg) } Embedder::ContentManhatten => { process_embedder(ContentEmbedder::::new(&cfg), &args, &cfg) } }?; if args.benchmark { eprintln!("Found tour with length: {}", total_dist); } if let Some(p) = args.symlink_dir { copy_into(&tsp_path, &p, true)? } if let Some(p) = args.copy_dir { copy_into(&tsp_path, &p, false)? } let path_delim = if args.stdout0 { Some(0) } else if args.stdout { Some(b'\n') } else { None }; path_delim.into_iter().try_for_each(|delim| { let mut o = io::BufWriter::new(io::stdout().lock()); for p in &tsp_path { o.write_all(p.as_os_str().to_str().unwrap().as_bytes())?; o.write_all(&[delim])?; } o.flush() })?; Ok(()) }