diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 168 |
1 files changed, 54 insertions, 114 deletions
diff --git a/src/main.rs b/src/main.rs index 6337b14..c8f5bbc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,16 +2,21 @@ use anyhow::Result; use clap::Parser; -use priority_queue::PriorityQueue; -use sha2::{Sha512_256, Digest}; -use std::{cmp::Ordering, collections::HashMap, fs, io, io::Write, path, path::PathBuf}; +use sha2::{Digest, Sha512_256}; +use std::{ + fs, + io::{self, Write}, + path::{self, PathBuf}, +}; +use ai_embedders::*; use embedders::*; use pure_embedders::*; -use ai_embedders::*; +use tsp_approx::*; +mod ai_embedders; mod embedders; mod pure_embedders; -mod ai_embedders; +mod tsp_approx; #[derive(Debug, Clone, Copy, clap::ValueEnum)] enum Embedder { @@ -57,91 +62,21 @@ fn get_config() -> Result<Config> { Ok(Config { base_dirs: dirs }) } -fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>> - where M: MetricElem -{ - // wrapper struct to - // - reverse the ordering - // - implement Ord, even though the type is backed by an f64 - #[repr(transparent)] - #[derive(Debug, PartialEq)] - struct DownOrd (f64); - impl Eq for DownOrd {} - impl PartialOrd for DownOrd { - fn partial_cmp(&self, other: &Self) -> Option<Ordering> { - Some(self.cmp(other)) - } - } - impl Ord for DownOrd { - fn cmp(&self, other: &Self) -> Ordering { - self.0.partial_cmp(&other.0).unwrap().reverse() - } - } - - let num_embeds = embeds.len(); - - let mut possible_edges = - PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2); - let mut mst = HashMap::with_capacity(num_embeds); - - // here, we start at 0. - // we might get a better result in the end if we started with a vertex next - // to the lowest-cost edge, but we don't know which one that is (though we - // could compute that without changing our asymptotic complexity) - mst.insert(0, Vec::new()); - for i in 1..num_embeds { - possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i]))); - } - - // prims algorithm or something like that - while mst.len() < num_embeds { - // find the edge with the least cost that connects us to a new vertex - let (new, old) = loop { - let ((a, b), _) = possible_edges.pop().unwrap(); - if !mst.contains_key(&a) { - break (a, b); - } - else if !mst.contains_key(&b) { - break (b, a); - } - }; - mst.insert(new, Vec::new()); - - // insert all the new edges we could take - mst.entry(old).and_modify(|v|v.push(new)); - for i in 0..num_embeds { - // don't consider edges taking us to nodes we already visited - if mst.contains_key(&i) { - continue; - } - - possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i]))); - } - } - - mst -} - -fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> { - fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { - into.push(cur); - t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); - } - let mut tsp_path = Vec::with_capacity(mst.len()); - dfs(0, &mst, &mut tsp_path); - - tsp_path -} - fn hash_file(p: &PathBuf) -> Result<[u8; 32]> { let mut f = fs::File::open(p)?; let mut hasher = Sha512_256::new(); io::copy(&mut f, &mut hasher)?; - Ok(hasher.finalize().into_iter().collect::<Vec<u8>>().try_into().unwrap()) + Ok(hasher + .finalize() + .into_iter() + .collect::<Vec<u8>>() + .try_into() + .unwrap()) } fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBuf>> - where E: BatchEmbedder +where + E: BatchEmbedder, { if args.images.is_empty() { return Ok(Vec::new()); @@ -150,7 +85,8 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBu let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?; let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME); - let mut embeds: Vec<Option<_>> = args.images + let mut embeds: Vec<Option<_>> = args + .images .iter() .map(|p| { let h = hash_file(p)?; @@ -165,28 +101,32 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBu .filter_map(|(i, v)| match v { None => Some(i), Some(_) => None, - }).collect(); + }) + .collect(); // TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows // for optimizations in the ai embedde (move pip to ::embeds() instead of ::new()) - let missing_embeds = e.embeds(&missing_embeds_indices + let missing_embeds = e.embeds( + &missing_embeds_indices .iter() .map(|i| args.images[*i].clone()) - .collect::<Vec<_>>())?; + .collect::<Vec<_>>(), + )?; for (idx, emb) in missing_embeds_indices - .into_iter().zip(missing_embeds.into_iter()) + .into_iter() + .zip(missing_embeds.into_iter()) { tree.insert(&hash_file(&args.images[idx])?, &emb)?; embeds[idx] = Some(emb); } let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); - let tsp_path = tsp_from_mst(get_mst(&embeds)); + let tsp_path = tsp(&embeds); Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect()) } -fn symlink_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> { +fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> { fs::create_dir_all(target)?; let pad_len = (tsp.len() as f64).log10().ceil() as usize; @@ -195,24 +135,16 @@ fn symlink_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> { None => "".to_string(), Some(e) => format!(".{}", e.to_str().unwrap()), }; - let tp = target.join(format!("{:0pl$}{ext}", i, pl = pad_len, ext = ext)); - let rel_path = pathdiff::diff_paths(path::absolute(p)?, - path::absolute(target)?).unwrap(); - std::os::unix::fs::symlink(rel_path, tp); - } - Ok(()) -} -fn copy_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> { - fs::create_dir_all(target)?; + let tp = target.join(format!("{i:0pad_len$}{ext}")); - let pad_len = (tsp.len() as f64).log10().ceil() as usize; - for (i, p) in tsp.iter().enumerate() { - let ext: String = match p.extension() { - None => "".to_string(), - Some(e) => format!(".{}", e.to_str().unwrap()), - }; - let tp = target.join(format!("{:0pl$}{ext}", i, pl = pad_len, ext = ext)); - reflink_copy::reflink_or_copy(p, tp)?; + if use_symlinks { + let rel_path = + pathdiff::diff_paths(path::absolute(p)?, path::absolute(target)?).unwrap(); + let _ = fs::remove_file(&tp); + std::os::unix::fs::symlink(rel_path, tp)?; + } else { + reflink_copy::reflink_or_copy(p, tp)?; + } } Ok(()) } @@ -228,18 +160,26 @@ fn main() -> Result<()> { Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg), }?; - if let Some(p) = args.symlink_dir { symlink_into(&tsp_path, &p)? } - if let Some(p) = args.copy_dir { copy_into(&tsp_path, &p)? } + if let Some(p) = args.symlink_dir { + copy_into(&tsp_path, &p, true)? + } + if let Some(p) = args.copy_dir { + copy_into(&tsp_path, &p, false)? + } - let path_delim = if args.stdout0 { Some(0) } - else if args.stdout { Some(b'\n') } - else { None }; + let path_delim = if args.stdout0 { + Some(0) + } else if args.stdout { + Some(b'\n') + } else { + None + }; path_delim.into_iter().try_for_each(|delim| { let mut o = io::BufWriter::new(io::stdout().lock()); for p in &tsp_path { - o.write(p.as_os_str().to_str().unwrap().as_bytes())?; - o.write(&[delim])?; + o.write_all(p.as_os_str().to_str().unwrap().as_bytes())?; + o.write_all(&[delim])?; } o.flush() })?; |