diff options
Diffstat (limited to 'src/embedders/ai.rs')
-rw-r--r-- | src/embedders/ai.rs | 32 |
1 files changed, 23 insertions, 9 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs index 120714c..7d5ae90 100644 --- a/src/embedders/ai.rs +++ b/src/embedders/ai.rs @@ -24,7 +24,7 @@ impl<'a, Metric> ContentEmbedder<'a, Metric> { } } -impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { +impl<Metric> Drop for ContentEmbedder<'_, Metric> { fn drop(&mut self) { self.cfg .base_dirs @@ -36,11 +36,11 @@ impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { } } -impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> { - type Embedding = Metric; - const NAME: &'static str = "imgbeddings"; - - fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { +impl<Metric: VecMetric> ContentEmbedder<'_, Metric> { + fn embeds_or_err( + &mut self, + paths: &[PathBuf], + ) -> Result<Vec<Result<<Self as BatchEmbedder>::Embedding>>> { let venv_dir = self .cfg .base_dirs @@ -75,17 +75,31 @@ impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> { let child = Command::new(venv_dir.join("bin/python3")) .arg(script_file) .args(paths) - .stderr(Stdio::null()) + .stderr(Stdio::inherit()) .stdout(Stdio::piped()) .spawn()?; + // TODO das ist noch nicht ok... wir geben zb potentiell zu wenig dings zurück. + // python-code muss dafür auch geändert werden xD let st = ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; let bar = ProgressBar::new(paths.len() as u64).with_style(st); - BufReader::new(child.stdout.unwrap()) + Ok(BufReader::new(child.stdout.unwrap()) .lines() .progress_with(bar) .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) - .try_collect() + .collect()) + } +} + +impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> { + type Embedding = Metric; + const NAME: &'static str = "imgbeddings"; + + fn embeds(&mut self, paths: &[PathBuf]) -> Vec<Result<Self::Embedding>> { + match self.embeds_or_err(paths) { + Ok(embeddings) => embeddings, + Err(e) => vec![Err(e)], + } } } |