diff options
-rw-r--r-- | src/ai_embedders.rs | 40 | ||||
-rw-r--r-- | src/embedders.rs | 8 | ||||
-rw-r--r-- | src/main.rs | 168 | ||||
-rw-r--r-- | src/pure_embedders.rs | 33 | ||||
-rw-r--r-- | src/tsp_approx.rs | 97 |
5 files changed, 197 insertions, 149 deletions
diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs index e8e6c8b..3848674 100644 --- a/src/ai_embedders.rs +++ b/src/ai_embedders.rs @@ -1,16 +1,23 @@ use anyhow::Result; use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; use serde::{Deserialize, Serialize}; -use std::{io::{BufRead, BufReader, copy, Cursor}, path::PathBuf, process::{Command, Stdio}}; +use std::{ + fs::{remove_file, File}, + io::{copy, BufRead, BufReader, Cursor}, + path::PathBuf, + process::{Command, Stdio}, +}; -use crate::{Config, BatchEmbedder, MetricElem}; +use crate::{BatchEmbedder, Config, MetricElem}; #[repr(transparent)] #[derive(Serialize, Deserialize)] -pub(crate) struct Imgbedding (Vec<f32>); +pub(crate) struct Imgbedding(Vec<f32>); impl MetricElem for Imgbedding { fn dist(&self, other: &Self) -> f64 { - self.0.iter().zip(other.0.iter()) + self.0 + .iter() + .zip(other.0.iter()) .map(|(a, b)| (a - b).powf(2.)) .sum::<f32>() .sqrt() as f64 @@ -22,15 +29,19 @@ pub(crate) struct ContentEmbedder<'a> { } impl<'a> ContentEmbedder<'a> { pub(crate) fn new(cfg: &'a Config) -> Self { - ContentEmbedder { cfg: cfg } + ContentEmbedder { cfg } } } impl<'a> Drop for ContentEmbedder<'a> { fn drop(&mut self) { - self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py") + self.cfg + .base_dirs + .place_runtime_file("imgbeddings-api.py") .iter() - .for_each(|p| { let _ = std::fs::remove_file(&p); }); + .for_each(|p| { + let _ = remove_file(p); + }); } } @@ -39,13 +50,17 @@ impl BatchEmbedder for ContentEmbedder<'_> { const NAME: &'static str = "imgbeddings"; fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { - let venv_dir = self.cfg.base_dirs + let venv_dir = self + .cfg + .base_dirs .create_data_directory("imgbeddings-venv")?; - let script_file = self.cfg.base_dirs + let script_file = self + .cfg + .base_dirs .place_runtime_file("imgbeddings-api.py")?; let api_prog = include_bytes!("imgbeddings-api.py"); - copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?; + copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?; let bar = ProgressBar::new_spinner(); bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?); @@ -73,9 +88,12 @@ impl BatchEmbedder for ContentEmbedder<'_> { .stdout(Stdio::piped()) .spawn()?; + let st = + ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + let bar = ProgressBar::new(paths.len() as u64).with_style(st); BufReader::new(child.stdout.unwrap()) .lines() - .progress_count(paths.len().try_into().unwrap()) + .progress_with(bar) .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) .try_collect() } diff --git a/src/embedders.rs b/src/embedders.rs index 97cc5ac..a14a8cf 100644 --- a/src/embedders.rs +++ b/src/embedders.rs @@ -33,9 +33,10 @@ impl<T: EmbedderT> BatchEmbedder for T { const NAME: &'static str = T::NAME; fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { - let st = ProgressStyle:: - with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; - paths.par_iter() + let st = + ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + paths + .par_iter() .progress_with_style(st) .map(|p| self.embed(p)) .collect::<Vec<_>>() @@ -43,4 +44,3 @@ impl<T: EmbedderT> BatchEmbedder for T { .try_collect() } } - diff --git a/src/main.rs b/src/main.rs index 6337b14..c8f5bbc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,16 +2,21 @@ use anyhow::Result; use clap::Parser; -use priority_queue::PriorityQueue; -use sha2::{Sha512_256, Digest}; -use std::{cmp::Ordering, collections::HashMap, fs, io, io::Write, path, path::PathBuf}; +use sha2::{Digest, Sha512_256}; +use std::{ + fs, + io::{self, Write}, + path::{self, PathBuf}, +}; +use ai_embedders::*; use embedders::*; use pure_embedders::*; -use ai_embedders::*; +use tsp_approx::*; +mod ai_embedders; mod embedders; mod pure_embedders; -mod ai_embedders; +mod tsp_approx; #[derive(Debug, Clone, Copy, clap::ValueEnum)] enum Embedder { @@ -57,91 +62,21 @@ fn get_config() -> Result<Config> { Ok(Config { base_dirs: dirs }) } -fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>> - where M: MetricElem -{ - // wrapper struct to - // - reverse the ordering - // - implement Ord, even though the type is backed by an f64 - #[repr(transparent)] - #[derive(Debug, PartialEq)] - struct DownOrd (f64); - impl Eq for DownOrd {} - impl PartialOrd for DownOrd { - fn partial_cmp(&self, other: &Self) -> Option<Ordering> { - Some(self.cmp(other)) - } - } - impl Ord for DownOrd { - fn cmp(&self, other: &Self) -> Ordering { - self.0.partial_cmp(&other.0).unwrap().reverse() - } - } - - let num_embeds = embeds.len(); - - let mut possible_edges = - PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2); - let mut mst = HashMap::with_capacity(num_embeds); - - // here, we start at 0. - // we might get a better result in the end if we started with a vertex next - // to the lowest-cost edge, but we don't know which one that is (though we - // could compute that without changing our asymptotic complexity) - mst.insert(0, Vec::new()); - for i in 1..num_embeds { - possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i]))); - } - - // prims algorithm or something like that - while mst.len() < num_embeds { - // find the edge with the least cost that connects us to a new vertex - let (new, old) = loop { - let ((a, b), _) = possible_edges.pop().unwrap(); - if !mst.contains_key(&a) { - break (a, b); - } - else if !mst.contains_key(&b) { - break (b, a); - } - }; - mst.insert(new, Vec::new()); - - // insert all the new edges we could take - mst.entry(old).and_modify(|v|v.push(new)); - for i in 0..num_embeds { - // don't consider edges taking us to nodes we already visited - if mst.contains_key(&i) { - continue; - } - - possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i]))); - } - } - - mst -} - -fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> { - fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { - into.push(cur); - t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); - } - let mut tsp_path = Vec::with_capacity(mst.len()); - dfs(0, &mst, &mut tsp_path); - - tsp_path -} - fn hash_file(p: &PathBuf) -> Result<[u8; 32]> { let mut f = fs::File::open(p)?; let mut hasher = Sha512_256::new(); io::copy(&mut f, &mut hasher)?; - Ok(hasher.finalize().into_iter().collect::<Vec<u8>>().try_into().unwrap()) + Ok(hasher + .finalize() + .into_iter() + .collect::<Vec<u8>>() + .try_into() + .unwrap()) } fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBuf>> - where E: BatchEmbedder +where + E: BatchEmbedder, { if args.images.is_empty() { return Ok(Vec::new()); @@ -150,7 +85,8 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBu let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?; let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME); - let mut embeds: Vec<Option<_>> = args.images + let mut embeds: Vec<Option<_>> = args + .images .iter() .map(|p| { let h = hash_file(p)?; @@ -165,28 +101,32 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBu .filter_map(|(i, v)| match v { None => Some(i), Some(_) => None, - }).collect(); + }) + .collect(); // TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows // for optimizations in the ai embedde (move pip to ::embeds() instead of ::new()) - let missing_embeds = e.embeds(&missing_embeds_indices + let missing_embeds = e.embeds( + &missing_embeds_indices .iter() .map(|i| args.images[*i].clone()) - .collect::<Vec<_>>())?; + .collect::<Vec<_>>(), + )?; for (idx, emb) in missing_embeds_indices - .into_iter().zip(missing_embeds.into_iter()) + .into_iter() + .zip(missing_embeds.into_iter()) { tree.insert(&hash_file(&args.images[idx])?, &emb)?; embeds[idx] = Some(emb); } let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); - let tsp_path = tsp_from_mst(get_mst(&embeds)); + let tsp_path = tsp(&embeds); Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect()) } -fn symlink_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> { +fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> { fs::create_dir_all(target)?; let pad_len = (tsp.len() as f64).log10().ceil() as usize; @@ -195,24 +135,16 @@ fn symlink_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> { None => "".to_string(), Some(e) => format!(".{}", e.to_str().unwrap()), }; - let tp = target.join(format!("{:0pl$}{ext}", i, pl = pad_len, ext = ext)); - let rel_path = pathdiff::diff_paths(path::absolute(p)?, - path::absolute(target)?).unwrap(); - std::os::unix::fs::symlink(rel_path, tp); - } - Ok(()) -} -fn copy_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> { - fs::create_dir_all(target)?; + let tp = target.join(format!("{i:0pad_len$}{ext}")); - let pad_len = (tsp.len() as f64).log10().ceil() as usize; - for (i, p) in tsp.iter().enumerate() { - let ext: String = match p.extension() { - None => "".to_string(), - Some(e) => format!(".{}", e.to_str().unwrap()), - }; - let tp = target.join(format!("{:0pl$}{ext}", i, pl = pad_len, ext = ext)); - reflink_copy::reflink_or_copy(p, tp)?; + if use_symlinks { + let rel_path = + pathdiff::diff_paths(path::absolute(p)?, path::absolute(target)?).unwrap(); + let _ = fs::remove_file(&tp); + std::os::unix::fs::symlink(rel_path, tp)?; + } else { + reflink_copy::reflink_or_copy(p, tp)?; + } } Ok(()) } @@ -228,18 +160,26 @@ fn main() -> Result<()> { Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg), }?; - if let Some(p) = args.symlink_dir { symlink_into(&tsp_path, &p)? } - if let Some(p) = args.copy_dir { copy_into(&tsp_path, &p)? } + if let Some(p) = args.symlink_dir { + copy_into(&tsp_path, &p, true)? + } + if let Some(p) = args.copy_dir { + copy_into(&tsp_path, &p, false)? + } - let path_delim = if args.stdout0 { Some(0) } - else if args.stdout { Some(b'\n') } - else { None }; + let path_delim = if args.stdout0 { + Some(0) + } else if args.stdout { + Some(b'\n') + } else { + None + }; path_delim.into_iter().try_for_each(|delim| { let mut o = io::BufWriter::new(io::stdout().lock()); for p in &tsp_path { - o.write(p.as_os_str().to_str().unwrap().as_bytes())?; - o.write(&[delim])?; + o.write_all(p.as_os_str().to_str().unwrap().as_bytes())?; + o.write_all(&[delim])?; } o.flush() })?; diff --git a/src/pure_embedders.rs b/src/pure_embedders.rs index 0f0c3ab..09c8321 100644 --- a/src/pure_embedders.rs +++ b/src/pure_embedders.rs @@ -2,7 +2,7 @@ use anyhow::{bail, Result}; use serde::{Deserialize, Serialize}; use std::path::Path; -use crate::{MetricElem, EmbedderT}; +use crate::{EmbedderT, MetricElem}; pub(crate) struct BrightnessEmbedder; impl EmbedderT for BrightnessEmbedder { @@ -17,20 +17,17 @@ impl EmbedderT for BrightnessEmbedder { bail!("Encountered NaN brightness, due to an empty image"); } - Ok(im.to_rgb8() - .iter() - .map(|e| *e as u64) - .sum::<u64>() as f64 / num_bytes as f64) + Ok(im.to_rgb8().iter().map(|e| *e as u64).sum::<u64>() as f64 / num_bytes as f64) } } #[repr(transparent)] #[derive(Serialize, Deserialize)] -pub(crate) struct Hue (f64); +pub(crate) struct Hue(f64); impl MetricElem for Hue { fn dist(&self, b: &Hue) -> f64 { let d = self.0.dist(&b.0); - d.min(6.-d) + d.min(6. - d) } } pub(crate) struct HueEmbedder; @@ -50,16 +47,13 @@ impl EmbedderT for HueEmbedder { }) .map(|e| e as f64 / 255. / num_pixels as f64); - let hue = - if sr >= sg && sr >= sb { - (sg - sb) / (sr - sg.min(sb)) - } - else if sg >= sb { - 2. + (sb - sr) / (sg - sr.min(sb)) - } - else { - 4. + (sr - sg) / (sb - sr.min(sg)) - }; + let hue = if sr >= sg && sr >= sb { + (sg - sb) / (sr - sg.min(sb)) + } else if sg >= sb { + 2. + (sb - sr) / (sg - sr.min(sb)) + } else { + 4. + (sr - sg) / (sb - sr.min(sg)) + }; if hue.is_nan() { bail!("Encountered NaN hue, possibly because of a colorless or empty image"); @@ -71,9 +65,8 @@ impl EmbedderT for HueEmbedder { impl MetricElem for (f64, f64, f64) { fn dist(&self, o: &(f64, f64, f64)) -> f64 { - let (dr, dg, db) = - ((self.0 - o.0), (self.1 - o.1), (self.2 - o.2)); - (dr*dr + dg*dg + db*db).sqrt() + let (dr, dg, db) = ((self.0 - o.0), (self.1 - o.1), (self.2 - o.2)); + (dr * dr + dg * dg + db * db).sqrt() } } pub(crate) struct ColorEmbedder; diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs new file mode 100644 index 0000000..c697a25 --- /dev/null +++ b/src/tsp_approx.rs @@ -0,0 +1,97 @@ +use indicatif::{ProgressBar, ProgressStyle}; +use priority_queue::PriorityQueue; +use std::{cmp::Ordering, collections::HashMap}; + +use crate::MetricElem; + +fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>> +where + M: MetricElem, +{ + // wrapper struct to + // - reverse the ordering + // - implement Ord, even though the type is backed by an f64 + #[repr(transparent)] + #[derive(Debug, PartialEq)] + struct DownOrd(f64); + impl Eq for DownOrd {} + impl PartialOrd for DownOrd { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + Some(self.cmp(other)) + } + } + impl Ord for DownOrd { + fn cmp(&self, other: &Self) -> Ordering { + self.0.partial_cmp(&other.0).unwrap().reverse() + } + } + + let num_embeds = embeds.len(); + + let mut possible_edges = + PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2); + let mut mst = HashMap::with_capacity(num_embeds); + + // here, we start at 0. + // we might get a better result in the end if we started with a vertex next + // to the lowest-cost edge, but we don't know which one that is (though we + // could compute that without changing our asymptotic complexity) + mst.insert(0, Vec::new()); + for i in 1..num_embeds { + possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i]))); + } + + // prims algorithm or something like that + while mst.len() < num_embeds { + // find the edge with the least cost that connects us to a new vertex + let (new, old) = loop { + let ((a, b), _) = possible_edges.pop().unwrap(); + if !mst.contains_key(&a) { + break (a, b); + } else if !mst.contains_key(&b) { + break (b, a); + } + }; + mst.insert(new, Vec::new()); + + // insert all the new edges we could take + mst.entry(old).and_modify(|v| v.push(new)); + for i in 0..num_embeds { + // don't consider edges taking us to nodes we already visited + if mst.contains_key(&i) { + continue; + } + + possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i]))); + } + } + + mst +} + +fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> { + fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) { + into.push(cur); + t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); + } + let mut tsp_path = Vec::with_capacity(mst.len()); + dfs(0, &mst, &mut tsp_path); + + tsp_path +} + +pub fn tsp<M>(embeds: &Vec<M>) -> Vec<usize> +where + M: MetricElem, +{ + let bar = ProgressBar::new_spinner(); + bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap()); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + bar.set_message("Finding path..."); + + let r = tsp_from_mst(get_mst(embeds)); + + bar.finish(); + + r +} |