aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs168
1 files changed, 54 insertions, 114 deletions
diff --git a/src/main.rs b/src/main.rs
index 6337b14..c8f5bbc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,16 +2,21 @@
use anyhow::Result;
use clap::Parser;
-use priority_queue::PriorityQueue;
-use sha2::{Sha512_256, Digest};
-use std::{cmp::Ordering, collections::HashMap, fs, io, io::Write, path, path::PathBuf};
+use sha2::{Digest, Sha512_256};
+use std::{
+ fs,
+ io::{self, Write},
+ path::{self, PathBuf},
+};
+use ai_embedders::*;
use embedders::*;
use pure_embedders::*;
-use ai_embedders::*;
+use tsp_approx::*;
+mod ai_embedders;
mod embedders;
mod pure_embedders;
-mod ai_embedders;
+mod tsp_approx;
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
enum Embedder {
@@ -57,91 +62,21 @@ fn get_config() -> Result<Config> {
Ok(Config { base_dirs: dirs })
}
-fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
- where M: MetricElem
-{
- // wrapper struct to
- // - reverse the ordering
- // - implement Ord, even though the type is backed by an f64
- #[repr(transparent)]
- #[derive(Debug, PartialEq)]
- struct DownOrd (f64);
- impl Eq for DownOrd {}
- impl PartialOrd for DownOrd {
- fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
- Some(self.cmp(other))
- }
- }
- impl Ord for DownOrd {
- fn cmp(&self, other: &Self) -> Ordering {
- self.0.partial_cmp(&other.0).unwrap().reverse()
- }
- }
-
- let num_embeds = embeds.len();
-
- let mut possible_edges =
- PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2);
- let mut mst = HashMap::with_capacity(num_embeds);
-
- // here, we start at 0.
- // we might get a better result in the end if we started with a vertex next
- // to the lowest-cost edge, but we don't know which one that is (though we
- // could compute that without changing our asymptotic complexity)
- mst.insert(0, Vec::new());
- for i in 1..num_embeds {
- possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i])));
- }
-
- // prims algorithm or something like that
- while mst.len() < num_embeds {
- // find the edge with the least cost that connects us to a new vertex
- let (new, old) = loop {
- let ((a, b), _) = possible_edges.pop().unwrap();
- if !mst.contains_key(&a) {
- break (a, b);
- }
- else if !mst.contains_key(&b) {
- break (b, a);
- }
- };
- mst.insert(new, Vec::new());
-
- // insert all the new edges we could take
- mst.entry(old).and_modify(|v|v.push(new));
- for i in 0..num_embeds {
- // don't consider edges taking us to nodes we already visited
- if mst.contains_key(&i) {
- continue;
- }
-
- possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i])));
- }
- }
-
- mst
-}
-
-fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> {
- fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
- into.push(cur);
- t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into));
- }
- let mut tsp_path = Vec::with_capacity(mst.len());
- dfs(0, &mst, &mut tsp_path);
-
- tsp_path
-}
-
fn hash_file(p: &PathBuf) -> Result<[u8; 32]> {
let mut f = fs::File::open(p)?;
let mut hasher = Sha512_256::new();
io::copy(&mut f, &mut hasher)?;
- Ok(hasher.finalize().into_iter().collect::<Vec<u8>>().try_into().unwrap())
+ Ok(hasher
+ .finalize()
+ .into_iter()
+ .collect::<Vec<u8>>()
+ .try_into()
+ .unwrap())
}
fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBuf>>
- where E: BatchEmbedder
+where
+ E: BatchEmbedder,
{
if args.images.is_empty() {
return Ok(Vec::new());
@@ -150,7 +85,8 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBu
let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?;
let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME);
- let mut embeds: Vec<Option<_>> = args.images
+ let mut embeds: Vec<Option<_>> = args
+ .images
.iter()
.map(|p| {
let h = hash_file(p)?;
@@ -165,28 +101,32 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBu
.filter_map(|(i, v)| match v {
None => Some(i),
Some(_) => None,
- }).collect();
+ })
+ .collect();
// TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows
// for optimizations in the ai embedde (move pip to ::embeds() instead of ::new())
- let missing_embeds = e.embeds(&missing_embeds_indices
+ let missing_embeds = e.embeds(
+ &missing_embeds_indices
.iter()
.map(|i| args.images[*i].clone())
- .collect::<Vec<_>>())?;
+ .collect::<Vec<_>>(),
+ )?;
for (idx, emb) in missing_embeds_indices
- .into_iter().zip(missing_embeds.into_iter())
+ .into_iter()
+ .zip(missing_embeds.into_iter())
{
tree.insert(&hash_file(&args.images[idx])?, &emb)?;
embeds[idx] = Some(emb);
}
let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
- let tsp_path = tsp_from_mst(get_mst(&embeds));
+ let tsp_path = tsp(&embeds);
Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect())
}
-fn symlink_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> {
+fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> {
fs::create_dir_all(target)?;
let pad_len = (tsp.len() as f64).log10().ceil() as usize;
@@ -195,24 +135,16 @@ fn symlink_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> {
None => "".to_string(),
Some(e) => format!(".{}", e.to_str().unwrap()),
};
- let tp = target.join(format!("{:0pl$}{ext}", i, pl = pad_len, ext = ext));
- let rel_path = pathdiff::diff_paths(path::absolute(p)?,
- path::absolute(target)?).unwrap();
- std::os::unix::fs::symlink(rel_path, tp);
- }
- Ok(())
-}
-fn copy_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> {
- fs::create_dir_all(target)?;
+ let tp = target.join(format!("{i:0pad_len$}{ext}"));
- let pad_len = (tsp.len() as f64).log10().ceil() as usize;
- for (i, p) in tsp.iter().enumerate() {
- let ext: String = match p.extension() {
- None => "".to_string(),
- Some(e) => format!(".{}", e.to_str().unwrap()),
- };
- let tp = target.join(format!("{:0pl$}{ext}", i, pl = pad_len, ext = ext));
- reflink_copy::reflink_or_copy(p, tp)?;
+ if use_symlinks {
+ let rel_path =
+ pathdiff::diff_paths(path::absolute(p)?, path::absolute(target)?).unwrap();
+ let _ = fs::remove_file(&tp);
+ std::os::unix::fs::symlink(rel_path, tp)?;
+ } else {
+ reflink_copy::reflink_or_copy(p, tp)?;
+ }
}
Ok(())
}
@@ -228,18 +160,26 @@ fn main() -> Result<()> {
Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg),
}?;
- if let Some(p) = args.symlink_dir { symlink_into(&tsp_path, &p)? }
- if let Some(p) = args.copy_dir { copy_into(&tsp_path, &p)? }
+ if let Some(p) = args.symlink_dir {
+ copy_into(&tsp_path, &p, true)?
+ }
+ if let Some(p) = args.copy_dir {
+ copy_into(&tsp_path, &p, false)?
+ }
- let path_delim = if args.stdout0 { Some(0) }
- else if args.stdout { Some(b'\n') }
- else { None };
+ let path_delim = if args.stdout0 {
+ Some(0)
+ } else if args.stdout {
+ Some(b'\n')
+ } else {
+ None
+ };
path_delim.into_iter().try_for_each(|delim| {
let mut o = io::BufWriter::new(io::stdout().lock());
for p in &tsp_path {
- o.write(p.as_os_str().to_str().unwrap().as_bytes())?;
- o.write(&[delim])?;
+ o.write_all(p.as_os_str().to_str().unwrap().as_bytes())?;
+ o.write_all(&[delim])?;
}
o.flush()
})?;