aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--Cargo.lock4
-rw-r--r--Cargo.toml2
-rw-r--r--src/embedders/ai.rs32
-rw-r--r--src/embedders/mod.rs10
-rw-r--r--src/main.rs63
-rw-r--r--src/tsp_approx.rs6
7 files changed, 82 insertions, 36 deletions
diff --git a/.gitignore b/.gitignore
index ea8c4bf..10c31e4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
/target
+/native
diff --git a/Cargo.lock b/Cargo.lock
index e20d86e..73eb595 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
-version = 3
+version = 4
[[package]]
name = "adler"
@@ -301,7 +301,7 @@ checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]]
name = "embeddings-sort"
-version = "0.2.0"
+version = "0.3.0"
dependencies = [
"ahash",
"anyhow",
diff --git a/Cargo.toml b/Cargo.toml
index 800529c..641b5a3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "embeddings-sort"
-version = "0.2.0"
+version = "0.3.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index 120714c..7d5ae90 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -24,7 +24,7 @@ impl<'a, Metric> ContentEmbedder<'a, Metric> {
}
}
-impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> {
+impl<Metric> Drop for ContentEmbedder<'_, Metric> {
fn drop(&mut self) {
self.cfg
.base_dirs
@@ -36,11 +36,11 @@ impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> {
}
}
-impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
- type Embedding = Metric;
- const NAME: &'static str = "imgbeddings";
-
- fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
+impl<Metric: VecMetric> ContentEmbedder<'_, Metric> {
+ fn embeds_or_err(
+ &mut self,
+ paths: &[PathBuf],
+ ) -> Result<Vec<Result<<Self as BatchEmbedder>::Embedding>>> {
let venv_dir = self
.cfg
.base_dirs
@@ -75,17 +75,31 @@ impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
let child = Command::new(venv_dir.join("bin/python3"))
.arg(script_file)
.args(paths)
- .stderr(Stdio::null())
+ .stderr(Stdio::inherit())
.stdout(Stdio::piped())
.spawn()?;
+ // TODO das ist noch nicht ok... wir geben zb potentiell zu wenig dings zurück.
+ // python-code muss dafür auch geändert werden xD
let st =
ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
let bar = ProgressBar::new(paths.len() as u64).with_style(st);
- BufReader::new(child.stdout.unwrap())
+ Ok(BufReader::new(child.stdout.unwrap())
.lines()
.progress_with(bar)
.map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?))
- .try_collect()
+ .collect())
+ }
+}
+
+impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
+ type Embedding = Metric;
+ const NAME: &'static str = "imgbeddings";
+
+ fn embeds(&mut self, paths: &[PathBuf]) -> Vec<Result<Self::Embedding>> {
+ match self.embeds_or_err(paths) {
+ Ok(embeddings) => embeddings,
+ Err(e) => vec![Err(e)],
+ }
}
}
diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs
index 5ade40d..1a1721d 100644
--- a/src/embedders/mod.rs
+++ b/src/embedders/mod.rs
@@ -32,22 +32,20 @@ pub trait BatchEmbedder: Send + Sync {
type Embedding: MetricElem;
const NAME: &'static str;
- fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>>;
+ fn embeds(&mut self, _: &[PathBuf]) -> Vec<Result<Self::Embedding>>;
}
impl<T: EmbedderT> BatchEmbedder for T {
type Embedding = T::Embedding;
const NAME: &'static str = T::NAME;
- fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
- let st =
- ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
+ fn embeds(&mut self, paths: &[PathBuf]) -> Vec<Result<Self::Embedding>> {
+ let st = ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")
+ .unwrap();
paths
.par_iter()
.progress_with_style(st)
.map(|p| self.embed(p))
.collect::<Vec<_>>()
- .into_iter()
- .try_collect()
}
}
diff --git a/src/main.rs b/src/main.rs
index 1a520bd..92f51f1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,4 @@
-#![feature(iterator_try_collect, absolute_path)]
+#![feature(iterator_try_collect)]
use anyhow::Result;
use clap::Parser;
@@ -60,10 +60,14 @@ struct Args {
#[arg(long, default_value = "christofides")]
tsp_approx: TspBaseAlg,
- /// number of 2-Opt refinement steps. Has quickly diminishing returns
+ /// Number of 2-Opt refinement steps. Has quickly diminishing returns
#[arg(short = 'r', default_value = "3")]
refine: usize,
+ /// Ignore failed embeddings
+ #[arg(short = 'i')]
+ ignore_errors: bool,
+
/// Seed for hashing. Random by default.
#[arg(long)]
hash_seed: Option<u64>,
@@ -98,57 +102,80 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec<PathB
where
E: BatchEmbedder,
{
- if args.images.is_empty() {
- return Ok((Vec::new(), 0.));
- }
-
let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?;
let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME);
- let mut embeds: Vec<Option<_>> = args
+ // find cached embeddings
+ let mut embeds: Vec<_> = args
.images
.iter()
- .map(|p| {
- let h = hash_file(p)?;
+ .map(|path| {
+ let h = hash_file(path)?;
let r: Result<Option<E::Embedding>> = tree.get(&h).map_err(|e| e.into());
r
})
.try_collect()?;
+ // find indices of missing embeddings
let missing_embeds_indices: Vec<_> = embeds
.iter()
.enumerate()
- .filter_map(|(i, v)| match v {
+ .filter_map(|(i, cached_embedding)| match cached_embedding {
None => Some(i),
Some(_) => None,
})
.collect();
- // TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows
- // for optimizations in the ai embedde (move pip to ::embeds() instead of ::new())
+
+ // calculate missing embeddings
let missing_embeds = if missing_embeds_indices.is_empty() {
- Vec::new()
+ vec![]
} else {
e.embeds(
&missing_embeds_indices
.iter()
.map(|i| args.images[*i].clone())
.collect::<Vec<_>>(),
- )?
+ )
};
+ // insert successfully changed
for (idx, emb) in missing_embeds_indices
.into_iter()
.zip(missing_embeds.into_iter())
{
- tree.insert(&hash_file(&args.images[idx])?, &emb)?;
- embeds[idx] = Some(emb);
+ match emb {
+ Ok(emb) => {
+ tree.insert(&hash_file(&args.images[idx])?, &emb)?;
+ embeds[idx] = Some(emb);
+ }
+ Err(e) => {
+ if !args.ignore_errors {
+ return Err(e);
+ }
+ }
+ }
}
- let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
+ // filter out images with failed embeddings
+ let (embeds, images): (Vec<_>, Vec<_>) = embeds
+ .into_iter()
+ .zip(args.images.iter())
+ .filter_map(|(emb, path)| match emb {
+ Some(embedding) => Some((embedding, path)),
+ None => {
+ if args.ignore_errors {
+ None
+ } else {
+ panic!("Embedding failed for {}", path.display())
+ }
+ }
+ })
+ .unzip();
+
let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, args.refine, &args.hash_seed);
Ok((
- tsp_path.iter().map(|i| args.images[*i].clone()).collect(),
+ tsp_path.iter().map(|i| images[*i].clone()).collect(),
total_dist,
))
}
diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs
index 097fee7..4484b91 100644
--- a/src/tsp_approx.rs
+++ b/src/tsp_approx.rs
@@ -403,6 +403,12 @@ pub(crate) fn tsp<M>(
where
M: MetricElem,
{
+ match embeds.len() {
+ 0 => return (vec![], 0.0),
+ 1 => return (vec![0], 0.0),
+ _ => (),
+ }
+
let bar = ProgressBar::new_spinner();
bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap());
bar.enable_steady_tick(std::time::Duration::from_millis(100));