aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2023-09-08 00:51:27 +0200
committerLia Lenckowski <lialenck@protonmail.com>2023-09-08 00:51:27 +0200
commit2590cda29da4c4a0dd05a6f91271c08c0335497d (patch)
tree84520c3925a74474eae1a4e6dff2ea222e0e30bf
parent7c0b7f51023c3608825774f7a4921fd6a249273e (diff)
downloadembeddings-sort-2590cda29da4c4a0dd05a6f91271c08c0335497d.tar
embeddings-sort-2590cda29da4c4a0dd05a6f91271c08c0335497d.tar.bz2
embeddings-sort-2590cda29da4c4a0dd05a6f91271c08c0335497d.tar.zst
giant refactor
-rw-r--r--src/ai_embedders.rs40
-rw-r--r--src/embedders.rs8
-rw-r--r--src/main.rs168
-rw-r--r--src/pure_embedders.rs33
-rw-r--r--src/tsp_approx.rs97
5 files changed, 197 insertions, 149 deletions
diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs
index e8e6c8b..3848674 100644
--- a/src/ai_embedders.rs
+++ b/src/ai_embedders.rs
@@ -1,16 +1,23 @@
use anyhow::Result;
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
use serde::{Deserialize, Serialize};
-use std::{io::{BufRead, BufReader, copy, Cursor}, path::PathBuf, process::{Command, Stdio}};
+use std::{
+ fs::{remove_file, File},
+ io::{copy, BufRead, BufReader, Cursor},
+ path::PathBuf,
+ process::{Command, Stdio},
+};
-use crate::{Config, BatchEmbedder, MetricElem};
+use crate::{BatchEmbedder, Config, MetricElem};
#[repr(transparent)]
#[derive(Serialize, Deserialize)]
-pub(crate) struct Imgbedding (Vec<f32>);
+pub(crate) struct Imgbedding(Vec<f32>);
impl MetricElem for Imgbedding {
fn dist(&self, other: &Self) -> f64 {
- self.0.iter().zip(other.0.iter())
+ self.0
+ .iter()
+ .zip(other.0.iter())
.map(|(a, b)| (a - b).powf(2.))
.sum::<f32>()
.sqrt() as f64
@@ -22,15 +29,19 @@ pub(crate) struct ContentEmbedder<'a> {
}
impl<'a> ContentEmbedder<'a> {
pub(crate) fn new(cfg: &'a Config) -> Self {
- ContentEmbedder { cfg: cfg }
+ ContentEmbedder { cfg }
}
}
impl<'a> Drop for ContentEmbedder<'a> {
fn drop(&mut self) {
- self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py")
+ self.cfg
+ .base_dirs
+ .place_runtime_file("imgbeddings-api.py")
.iter()
- .for_each(|p| { let _ = std::fs::remove_file(&p); });
+ .for_each(|p| {
+ let _ = remove_file(p);
+ });
}
}
@@ -39,13 +50,17 @@ impl BatchEmbedder for ContentEmbedder<'_> {
const NAME: &'static str = "imgbeddings";
fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
- let venv_dir = self.cfg.base_dirs
+ let venv_dir = self
+ .cfg
+ .base_dirs
.create_data_directory("imgbeddings-venv")?;
- let script_file = self.cfg.base_dirs
+ let script_file = self
+ .cfg
+ .base_dirs
.place_runtime_file("imgbeddings-api.py")?;
let api_prog = include_bytes!("imgbeddings-api.py");
- copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?;
+ copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?;
let bar = ProgressBar::new_spinner();
bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?);
@@ -73,9 +88,12 @@ impl BatchEmbedder for ContentEmbedder<'_> {
.stdout(Stdio::piped())
.spawn()?;
+ let st =
+ ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
+ let bar = ProgressBar::new(paths.len() as u64).with_style(st);
BufReader::new(child.stdout.unwrap())
.lines()
- .progress_count(paths.len().try_into().unwrap())
+ .progress_with(bar)
.map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?))
.try_collect()
}
diff --git a/src/embedders.rs b/src/embedders.rs
index 97cc5ac..a14a8cf 100644
--- a/src/embedders.rs
+++ b/src/embedders.rs
@@ -33,9 +33,10 @@ impl<T: EmbedderT> BatchEmbedder for T {
const NAME: &'static str = T::NAME;
fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
- let st = ProgressStyle::
- with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
- paths.par_iter()
+ let st =
+ ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
+ paths
+ .par_iter()
.progress_with_style(st)
.map(|p| self.embed(p))
.collect::<Vec<_>>()
@@ -43,4 +44,3 @@ impl<T: EmbedderT> BatchEmbedder for T {
.try_collect()
}
}
-
diff --git a/src/main.rs b/src/main.rs
index 6337b14..c8f5bbc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,16 +2,21 @@
use anyhow::Result;
use clap::Parser;
-use priority_queue::PriorityQueue;
-use sha2::{Sha512_256, Digest};
-use std::{cmp::Ordering, collections::HashMap, fs, io, io::Write, path, path::PathBuf};
+use sha2::{Digest, Sha512_256};
+use std::{
+ fs,
+ io::{self, Write},
+ path::{self, PathBuf},
+};
+use ai_embedders::*;
use embedders::*;
use pure_embedders::*;
-use ai_embedders::*;
+use tsp_approx::*;
+mod ai_embedders;
mod embedders;
mod pure_embedders;
-mod ai_embedders;
+mod tsp_approx;
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
enum Embedder {
@@ -57,91 +62,21 @@ fn get_config() -> Result<Config> {
Ok(Config { base_dirs: dirs })
}
-fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
- where M: MetricElem
-{
- // wrapper struct to
- // - reverse the ordering
- // - implement Ord, even though the type is backed by an f64
- #[repr(transparent)]
- #[derive(Debug, PartialEq)]
- struct DownOrd (f64);
- impl Eq for DownOrd {}
- impl PartialOrd for DownOrd {
- fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
- Some(self.cmp(other))
- }
- }
- impl Ord for DownOrd {
- fn cmp(&self, other: &Self) -> Ordering {
- self.0.partial_cmp(&other.0).unwrap().reverse()
- }
- }
-
- let num_embeds = embeds.len();
-
- let mut possible_edges =
- PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2);
- let mut mst = HashMap::with_capacity(num_embeds);
-
- // here, we start at 0.
- // we might get a better result in the end if we started with a vertex next
- // to the lowest-cost edge, but we don't know which one that is (though we
- // could compute that without changing our asymptotic complexity)
- mst.insert(0, Vec::new());
- for i in 1..num_embeds {
- possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i])));
- }
-
- // prims algorithm or something like that
- while mst.len() < num_embeds {
- // find the edge with the least cost that connects us to a new vertex
- let (new, old) = loop {
- let ((a, b), _) = possible_edges.pop().unwrap();
- if !mst.contains_key(&a) {
- break (a, b);
- }
- else if !mst.contains_key(&b) {
- break (b, a);
- }
- };
- mst.insert(new, Vec::new());
-
- // insert all the new edges we could take
- mst.entry(old).and_modify(|v|v.push(new));
- for i in 0..num_embeds {
- // don't consider edges taking us to nodes we already visited
- if mst.contains_key(&i) {
- continue;
- }
-
- possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i])));
- }
- }
-
- mst
-}
-
-fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> {
- fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
- into.push(cur);
- t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into));
- }
- let mut tsp_path = Vec::with_capacity(mst.len());
- dfs(0, &mst, &mut tsp_path);
-
- tsp_path
-}
-
fn hash_file(p: &PathBuf) -> Result<[u8; 32]> {
let mut f = fs::File::open(p)?;
let mut hasher = Sha512_256::new();
io::copy(&mut f, &mut hasher)?;
- Ok(hasher.finalize().into_iter().collect::<Vec<u8>>().try_into().unwrap())
+ Ok(hasher
+ .finalize()
+ .into_iter()
+ .collect::<Vec<u8>>()
+ .try_into()
+ .unwrap())
}
fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBuf>>
- where E: BatchEmbedder
+where
+ E: BatchEmbedder,
{
if args.images.is_empty() {
return Ok(Vec::new());
@@ -150,7 +85,8 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBu
let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?;
let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME);
- let mut embeds: Vec<Option<_>> = args.images
+ let mut embeds: Vec<Option<_>> = args
+ .images
.iter()
.map(|p| {
let h = hash_file(p)?;
@@ -165,28 +101,32 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<Vec<PathBu
.filter_map(|(i, v)| match v {
None => Some(i),
Some(_) => None,
- }).collect();
+ })
+ .collect();
// TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows
// for optimizations in the ai embedde (move pip to ::embeds() instead of ::new())
- let missing_embeds = e.embeds(&missing_embeds_indices
+ let missing_embeds = e.embeds(
+ &missing_embeds_indices
.iter()
.map(|i| args.images[*i].clone())
- .collect::<Vec<_>>())?;
+ .collect::<Vec<_>>(),
+ )?;
for (idx, emb) in missing_embeds_indices
- .into_iter().zip(missing_embeds.into_iter())
+ .into_iter()
+ .zip(missing_embeds.into_iter())
{
tree.insert(&hash_file(&args.images[idx])?, &emb)?;
embeds[idx] = Some(emb);
}
let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
- let tsp_path = tsp_from_mst(get_mst(&embeds));
+ let tsp_path = tsp(&embeds);
Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect())
}
-fn symlink_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> {
+fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> {
fs::create_dir_all(target)?;
let pad_len = (tsp.len() as f64).log10().ceil() as usize;
@@ -195,24 +135,16 @@ fn symlink_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> {
None => "".to_string(),
Some(e) => format!(".{}", e.to_str().unwrap()),
};
- let tp = target.join(format!("{:0pl$}{ext}", i, pl = pad_len, ext = ext));
- let rel_path = pathdiff::diff_paths(path::absolute(p)?,
- path::absolute(target)?).unwrap();
- std::os::unix::fs::symlink(rel_path, tp);
- }
- Ok(())
-}
-fn copy_into(tsp: &[PathBuf], target: &PathBuf) -> Result<()> {
- fs::create_dir_all(target)?;
+ let tp = target.join(format!("{i:0pad_len$}{ext}"));
- let pad_len = (tsp.len() as f64).log10().ceil() as usize;
- for (i, p) in tsp.iter().enumerate() {
- let ext: String = match p.extension() {
- None => "".to_string(),
- Some(e) => format!(".{}", e.to_str().unwrap()),
- };
- let tp = target.join(format!("{:0pl$}{ext}", i, pl = pad_len, ext = ext));
- reflink_copy::reflink_or_copy(p, tp)?;
+ if use_symlinks {
+ let rel_path =
+ pathdiff::diff_paths(path::absolute(p)?, path::absolute(target)?).unwrap();
+ let _ = fs::remove_file(&tp);
+ std::os::unix::fs::symlink(rel_path, tp)?;
+ } else {
+ reflink_copy::reflink_or_copy(p, tp)?;
+ }
}
Ok(())
}
@@ -228,18 +160,26 @@ fn main() -> Result<()> {
Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg),
}?;
- if let Some(p) = args.symlink_dir { symlink_into(&tsp_path, &p)? }
- if let Some(p) = args.copy_dir { copy_into(&tsp_path, &p)? }
+ if let Some(p) = args.symlink_dir {
+ copy_into(&tsp_path, &p, true)?
+ }
+ if let Some(p) = args.copy_dir {
+ copy_into(&tsp_path, &p, false)?
+ }
- let path_delim = if args.stdout0 { Some(0) }
- else if args.stdout { Some(b'\n') }
- else { None };
+ let path_delim = if args.stdout0 {
+ Some(0)
+ } else if args.stdout {
+ Some(b'\n')
+ } else {
+ None
+ };
path_delim.into_iter().try_for_each(|delim| {
let mut o = io::BufWriter::new(io::stdout().lock());
for p in &tsp_path {
- o.write(p.as_os_str().to_str().unwrap().as_bytes())?;
- o.write(&[delim])?;
+ o.write_all(p.as_os_str().to_str().unwrap().as_bytes())?;
+ o.write_all(&[delim])?;
}
o.flush()
})?;
diff --git a/src/pure_embedders.rs b/src/pure_embedders.rs
index 0f0c3ab..09c8321 100644
--- a/src/pure_embedders.rs
+++ b/src/pure_embedders.rs
@@ -2,7 +2,7 @@ use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
-use crate::{MetricElem, EmbedderT};
+use crate::{EmbedderT, MetricElem};
pub(crate) struct BrightnessEmbedder;
impl EmbedderT for BrightnessEmbedder {
@@ -17,20 +17,17 @@ impl EmbedderT for BrightnessEmbedder {
bail!("Encountered NaN brightness, due to an empty image");
}
- Ok(im.to_rgb8()
- .iter()
- .map(|e| *e as u64)
- .sum::<u64>() as f64 / num_bytes as f64)
+ Ok(im.to_rgb8().iter().map(|e| *e as u64).sum::<u64>() as f64 / num_bytes as f64)
}
}
#[repr(transparent)]
#[derive(Serialize, Deserialize)]
-pub(crate) struct Hue (f64);
+pub(crate) struct Hue(f64);
impl MetricElem for Hue {
fn dist(&self, b: &Hue) -> f64 {
let d = self.0.dist(&b.0);
- d.min(6.-d)
+ d.min(6. - d)
}
}
pub(crate) struct HueEmbedder;
@@ -50,16 +47,13 @@ impl EmbedderT for HueEmbedder {
})
.map(|e| e as f64 / 255. / num_pixels as f64);
- let hue =
- if sr >= sg && sr >= sb {
- (sg - sb) / (sr - sg.min(sb))
- }
- else if sg >= sb {
- 2. + (sb - sr) / (sg - sr.min(sb))
- }
- else {
- 4. + (sr - sg) / (sb - sr.min(sg))
- };
+ let hue = if sr >= sg && sr >= sb {
+ (sg - sb) / (sr - sg.min(sb))
+ } else if sg >= sb {
+ 2. + (sb - sr) / (sg - sr.min(sb))
+ } else {
+ 4. + (sr - sg) / (sb - sr.min(sg))
+ };
if hue.is_nan() {
bail!("Encountered NaN hue, possibly because of a colorless or empty image");
@@ -71,9 +65,8 @@ impl EmbedderT for HueEmbedder {
impl MetricElem for (f64, f64, f64) {
fn dist(&self, o: &(f64, f64, f64)) -> f64 {
- let (dr, dg, db) =
- ((self.0 - o.0), (self.1 - o.1), (self.2 - o.2));
- (dr*dr + dg*dg + db*db).sqrt()
+ let (dr, dg, db) = ((self.0 - o.0), (self.1 - o.1), (self.2 - o.2));
+ (dr * dr + dg * dg + db * db).sqrt()
}
}
pub(crate) struct ColorEmbedder;
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
new file mode 100644
index 0000000..c697a25
--- /dev/null
+++ b/src/tsp_approx.rs
@@ -0,0 +1,97 @@
+use indicatif::{ProgressBar, ProgressStyle};
+use priority_queue::PriorityQueue;
+use std::{cmp::Ordering, collections::HashMap};
+
+use crate::MetricElem;
+
+fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
+where
+ M: MetricElem,
+{
+ // wrapper struct to
+ // - reverse the ordering
+ // - implement Ord, even though the type is backed by an f64
+ #[repr(transparent)]
+ #[derive(Debug, PartialEq)]
+ struct DownOrd(f64);
+ impl Eq for DownOrd {}
+ impl PartialOrd for DownOrd {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+ }
+ impl Ord for DownOrd {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.0.partial_cmp(&other.0).unwrap().reverse()
+ }
+ }
+
+ let num_embeds = embeds.len();
+
+ let mut possible_edges =
+ PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2);
+ let mut mst = HashMap::with_capacity(num_embeds);
+
+ // here, we start at 0.
+ // we might get a better result in the end if we started with a vertex next
+ // to the lowest-cost edge, but we don't know which one that is (though we
+ // could compute that without changing our asymptotic complexity)
+ mst.insert(0, Vec::new());
+ for i in 1..num_embeds {
+ possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i])));
+ }
+
+ // prims algorithm or something like that
+ while mst.len() < num_embeds {
+ // find the edge with the least cost that connects us to a new vertex
+ let (new, old) = loop {
+ let ((a, b), _) = possible_edges.pop().unwrap();
+ if !mst.contains_key(&a) {
+ break (a, b);
+ } else if !mst.contains_key(&b) {
+ break (b, a);
+ }
+ };
+ mst.insert(new, Vec::new());
+
+ // insert all the new edges we could take
+ mst.entry(old).and_modify(|v| v.push(new));
+ for i in 0..num_embeds {
+ // don't consider edges taking us to nodes we already visited
+ if mst.contains_key(&i) {
+ continue;
+ }
+
+ possible_edges.push((new, i), DownOrd(embeds[new].dist(&embeds[i])));
+ }
+ }
+
+ mst
+}
+
+fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> {
+ fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
+ into.push(cur);
+ t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into));
+ }
+ let mut tsp_path = Vec::with_capacity(mst.len());
+ dfs(0, &mst, &mut tsp_path);
+
+ tsp_path
+}
+
+pub fn tsp<M>(embeds: &Vec<M>) -> Vec<usize>
+where
+ M: MetricElem,
+{
+ let bar = ProgressBar::new_spinner();
+ bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap());
+ bar.enable_steady_tick(std::time::Duration::from_millis(100));
+ bar.set_message("Finding path...");
+
+ let r = tsp_from_mst(get_mst(embeds));
+
+ bar.finish();
+
+ r
+}