diff options
author | Lia Lenckowski <lialenck@protonmail.com> | 2024-11-25 01:33:42 +0100 |
---|---|---|
committer | Lia Lenckowski <lialenck@protonmail.com> | 2024-11-25 01:33:42 +0100 |
commit | 6a2d5e2241b73039ae09b627c30841248d887a79 (patch) | |
tree | 8242692c4b8b087f4fc2323ad09c7b7a63075c9b | |
parent | 4195f3e742bec7ed0ef1b193a725ea7335a547ef (diff) | |
download | embeddings-sort-6a2d5e2241b73039ae09b627c30841248d887a79.tar embeddings-sort-6a2d5e2241b73039ae09b627c30841248d887a79.tar.bz2 embeddings-sort-6a2d5e2241b73039ae09b627c30841248d887a79.tar.zst |
implement rust side of flag that ignores failing embeddings
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | Cargo.lock | 4 | ||||
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | src/embedders/ai.rs | 32 | ||||
-rw-r--r-- | src/embedders/mod.rs | 10 | ||||
-rw-r--r-- | src/main.rs | 63 | ||||
-rw-r--r-- | src/tsp_approx.rs | 6 |
7 files changed, 82 insertions, 36 deletions
@@ -1 +1,2 @@ /target +/native @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "adler" @@ -301,7 +301,7 @@ checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "embeddings-sort" -version = "0.2.0" +version = "0.3.0" dependencies = [ "ahash", "anyhow", @@ -1,6 +1,6 @@ [package] name = "embeddings-sort" -version = "0.2.0" +version = "0.3.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs index 120714c..7d5ae90 100644 --- a/src/embedders/ai.rs +++ b/src/embedders/ai.rs @@ -24,7 +24,7 @@ impl<'a, Metric> ContentEmbedder<'a, Metric> { } } -impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { +impl<Metric> Drop for ContentEmbedder<'_, Metric> { fn drop(&mut self) { self.cfg .base_dirs @@ -36,11 +36,11 @@ impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { } } -impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> { - type Embedding = Metric; - const NAME: &'static str = "imgbeddings"; - - fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { +impl<Metric: VecMetric> ContentEmbedder<'_, Metric> { + fn embeds_or_err( + &mut self, + paths: &[PathBuf], + ) -> Result<Vec<Result<<Self as BatchEmbedder>::Embedding>>> { let venv_dir = self .cfg .base_dirs @@ -75,17 +75,31 @@ impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> { let child = Command::new(venv_dir.join("bin/python3")) .arg(script_file) .args(paths) - .stderr(Stdio::null()) + .stderr(Stdio::inherit()) .stdout(Stdio::piped()) .spawn()?; + // TODO das ist noch nicht ok... wir geben zb potentiell zu wenig dings zurück. + // python-code muss dafür auch geändert werden xD let st = ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; let bar = ProgressBar::new(paths.len() as u64).with_style(st); - BufReader::new(child.stdout.unwrap()) + Ok(BufReader::new(child.stdout.unwrap()) .lines() .progress_with(bar) .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) - .try_collect() + .collect()) + } +} + +impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> { + type Embedding = Metric; + const NAME: &'static str = "imgbeddings"; + + fn embeds(&mut self, paths: &[PathBuf]) -> Vec<Result<Self::Embedding>> { + match self.embeds_or_err(paths) { + Ok(embeddings) => embeddings, + Err(e) => vec![Err(e)], + } } } diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs index 5ade40d..1a1721d 100644 --- a/src/embedders/mod.rs +++ b/src/embedders/mod.rs @@ -32,22 +32,20 @@ pub trait BatchEmbedder: Send + Sync { type Embedding: MetricElem; const NAME: &'static str; - fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>>; + fn embeds(&mut self, _: &[PathBuf]) -> Vec<Result<Self::Embedding>>; } impl<T: EmbedderT> BatchEmbedder for T { type Embedding = T::Embedding; const NAME: &'static str = T::NAME; - fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { - let st = - ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + fn embeds(&mut self, paths: &[PathBuf]) -> Vec<Result<Self::Embedding>> { + let st = ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...") + .unwrap(); paths .par_iter() .progress_with_style(st) .map(|p| self.embed(p)) .collect::<Vec<_>>() - .into_iter() - .try_collect() } } diff --git a/src/main.rs b/src/main.rs index 1a520bd..92f51f1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -#![feature(iterator_try_collect, absolute_path)] +#![feature(iterator_try_collect)] use anyhow::Result; use clap::Parser; @@ -60,10 +60,14 @@ struct Args { #[arg(long, default_value = "christofides")] tsp_approx: TspBaseAlg, - /// number of 2-Opt refinement steps. Has quickly diminishing returns + /// Number of 2-Opt refinement steps. Has quickly diminishing returns #[arg(short = 'r', default_value = "3")] refine: usize, + /// Ignore failed embeddings + #[arg(short = 'i')] + ignore_errors: bool, + /// Seed for hashing. Random by default. #[arg(long)] hash_seed: Option<u64>, @@ -98,57 +102,80 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec<PathB where E: BatchEmbedder, { - if args.images.is_empty() { - return Ok((Vec::new(), 0.)); - } - let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?; let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME); - let mut embeds: Vec<Option<_>> = args + // find cached embeddings + let mut embeds: Vec<_> = args .images .iter() - .map(|p| { - let h = hash_file(p)?; + .map(|path| { + let h = hash_file(path)?; let r: Result<Option<E::Embedding>> = tree.get(&h).map_err(|e| e.into()); r }) .try_collect()?; + // find indices of missing embeddings let missing_embeds_indices: Vec<_> = embeds .iter() .enumerate() - .filter_map(|(i, v)| match v { + .filter_map(|(i, cached_embedding)| match cached_embedding { None => Some(i), Some(_) => None, }) .collect(); - // TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows - // for optimizations in the ai embedde (move pip to ::embeds() instead of ::new()) + + // calculate missing embeddings let missing_embeds = if missing_embeds_indices.is_empty() { - Vec::new() + vec![] } else { e.embeds( &missing_embeds_indices .iter() .map(|i| args.images[*i].clone()) .collect::<Vec<_>>(), - )? + ) }; + // insert successfully changed for (idx, emb) in missing_embeds_indices .into_iter() .zip(missing_embeds.into_iter()) { - tree.insert(&hash_file(&args.images[idx])?, &emb)?; - embeds[idx] = Some(emb); + match emb { + Ok(emb) => { + tree.insert(&hash_file(&args.images[idx])?, &emb)?; + embeds[idx] = Some(emb); + } + Err(e) => { + if !args.ignore_errors { + return Err(e); + } + } + } } - let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); + // filter out images with failed embeddings + let (embeds, images): (Vec<_>, Vec<_>) = embeds + .into_iter() + .zip(args.images.iter()) + .filter_map(|(emb, path)| match emb { + Some(embedding) => Some((embedding, path)), + None => { + if args.ignore_errors { + None + } else { + panic!("Embedding failed for {}", path.display()) + } + } + }) + .unzip(); + let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, args.refine, &args.hash_seed); Ok(( - tsp_path.iter().map(|i| args.images[*i].clone()).collect(), + tsp_path.iter().map(|i| images[*i].clone()).collect(), total_dist, )) } diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index 097fee7..4484b91 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -403,6 +403,12 @@ pub(crate) fn tsp<M>( where M: MetricElem, { + match embeds.len() { + 0 => return (vec![], 0.0), + 1 => return (vec![0], 0.0), + _ => (), + } + let bar = ProgressBar::new_spinner(); bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap()); bar.enable_steady_tick(std::time::Duration::from_millis(100)); |