From 6a2d5e2241b73039ae09b627c30841248d887a79 Mon Sep 17 00:00:00 2001 From: Lia Lenckowski Date: Mon, 25 Nov 2024 01:33:42 +0100 Subject: implement rust side of flag that ignores failing embeddings --- src/embedders/ai.rs | 32 ++++++++++++++++++-------- src/embedders/mod.rs | 10 ++++----- src/main.rs | 63 +++++++++++++++++++++++++++++++++++++--------------- src/tsp_approx.rs | 6 +++++ 4 files changed, 78 insertions(+), 33 deletions(-) (limited to 'src') diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs index 120714c..7d5ae90 100644 --- a/src/embedders/ai.rs +++ b/src/embedders/ai.rs @@ -24,7 +24,7 @@ impl<'a, Metric> ContentEmbedder<'a, Metric> { } } -impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { +impl Drop for ContentEmbedder<'_, Metric> { fn drop(&mut self) { self.cfg .base_dirs @@ -36,11 +36,11 @@ impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { } } -impl BatchEmbedder for ContentEmbedder<'_, Metric> { - type Embedding = Metric; - const NAME: &'static str = "imgbeddings"; - - fn embeds(&mut self, paths: &[PathBuf]) -> Result> { +impl ContentEmbedder<'_, Metric> { + fn embeds_or_err( + &mut self, + paths: &[PathBuf], + ) -> Result::Embedding>>> { let venv_dir = self .cfg .base_dirs @@ -75,17 +75,31 @@ impl BatchEmbedder for ContentEmbedder<'_, Metric> { let child = Command::new(venv_dir.join("bin/python3")) .arg(script_file) .args(paths) - .stderr(Stdio::null()) + .stderr(Stdio::inherit()) .stdout(Stdio::piped()) .spawn()?; + // TODO das ist noch nicht ok... wir geben zb potentiell zu wenig dings zurück. + // python-code muss dafür auch geändert werden xD let st = ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; let bar = ProgressBar::new(paths.len() as u64).with_style(st); - BufReader::new(child.stdout.unwrap()) + Ok(BufReader::new(child.stdout.unwrap()) .lines() .progress_with(bar) .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) - .try_collect() + .collect()) + } +} + +impl BatchEmbedder for ContentEmbedder<'_, Metric> { + type Embedding = Metric; + const NAME: &'static str = "imgbeddings"; + + fn embeds(&mut self, paths: &[PathBuf]) -> Vec> { + match self.embeds_or_err(paths) { + Ok(embeddings) => embeddings, + Err(e) => vec![Err(e)], + } } } diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs index 5ade40d..1a1721d 100644 --- a/src/embedders/mod.rs +++ b/src/embedders/mod.rs @@ -32,22 +32,20 @@ pub trait BatchEmbedder: Send + Sync { type Embedding: MetricElem; const NAME: &'static str; - fn embeds(&mut self, _: &[PathBuf]) -> Result>; + fn embeds(&mut self, _: &[PathBuf]) -> Vec>; } impl BatchEmbedder for T { type Embedding = T::Embedding; const NAME: &'static str = T::NAME; - fn embeds(&mut self, paths: &[PathBuf]) -> Result> { - let st = - ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + fn embeds(&mut self, paths: &[PathBuf]) -> Vec> { + let st = ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...") + .unwrap(); paths .par_iter() .progress_with_style(st) .map(|p| self.embed(p)) .collect::>() - .into_iter() - .try_collect() } } diff --git a/src/main.rs b/src/main.rs index 1a520bd..92f51f1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -#![feature(iterator_try_collect, absolute_path)] +#![feature(iterator_try_collect)] use anyhow::Result; use clap::Parser; @@ -60,10 +60,14 @@ struct Args { #[arg(long, default_value = "christofides")] tsp_approx: TspBaseAlg, - /// number of 2-Opt refinement steps. Has quickly diminishing returns + /// Number of 2-Opt refinement steps. Has quickly diminishing returns #[arg(short = 'r', default_value = "3")] refine: usize, + /// Ignore failed embeddings + #[arg(short = 'i')] + ignore_errors: bool, + /// Seed for hashing. Random by default. #[arg(long)] hash_seed: Option, @@ -98,57 +102,80 @@ fn process_embedder(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec::open(&db, E::NAME); - let mut embeds: Vec> = args + // find cached embeddings + let mut embeds: Vec<_> = args .images .iter() - .map(|p| { - let h = hash_file(p)?; + .map(|path| { + let h = hash_file(path)?; let r: Result> = tree.get(&h).map_err(|e| e.into()); r }) .try_collect()?; + // find indices of missing embeddings let missing_embeds_indices: Vec<_> = embeds .iter() .enumerate() - .filter_map(|(i, v)| match v { + .filter_map(|(i, cached_embedding)| match cached_embedding { None => Some(i), Some(_) => None, }) .collect(); - // TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows - // for optimizations in the ai embedde (move pip to ::embeds() instead of ::new()) + + // calculate missing embeddings let missing_embeds = if missing_embeds_indices.is_empty() { - Vec::new() + vec![] } else { e.embeds( &missing_embeds_indices .iter() .map(|i| args.images[*i].clone()) .collect::>(), - )? + ) }; + // insert successfully changed for (idx, emb) in missing_embeds_indices .into_iter() .zip(missing_embeds.into_iter()) { - tree.insert(&hash_file(&args.images[idx])?, &emb)?; - embeds[idx] = Some(emb); + match emb { + Ok(emb) => { + tree.insert(&hash_file(&args.images[idx])?, &emb)?; + embeds[idx] = Some(emb); + } + Err(e) => { + if !args.ignore_errors { + return Err(e); + } + } + } } - let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); + // filter out images with failed embeddings + let (embeds, images): (Vec<_>, Vec<_>) = embeds + .into_iter() + .zip(args.images.iter()) + .filter_map(|(emb, path)| match emb { + Some(embedding) => Some((embedding, path)), + None => { + if args.ignore_errors { + None + } else { + panic!("Embedding failed for {}", path.display()) + } + } + }) + .unzip(); + let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, args.refine, &args.hash_seed); Ok(( - tsp_path.iter().map(|i| args.images[*i].clone()).collect(), + tsp_path.iter().map(|i| images[*i].clone()).collect(), total_dist, )) } diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index 097fee7..4484b91 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -403,6 +403,12 @@ pub(crate) fn tsp( where M: MetricElem, { + match embeds.len() { + 0 => return (vec![], 0.0), + 1 => return (vec![0], 0.0), + _ => (), + } + let bar = ProgressBar::new_spinner(); bar.set_style(ProgressStyle::with_template("{spinner} {msg}").unwrap()); bar.enable_steady_tick(std::time::Duration::from_millis(100)); -- cgit v1.2.3-70-g09d2