about summary refs log tree commit diff
path: root/src/embedders
diff options
context:
space:
mode:
Diffstat (limited to 'src/embedders')
-rw-r--r-- src/embedders/ai.rs 32
-rw-r--r-- src/embedders/mod.rs 10
2 files changed, 27 insertions, 15 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index 120714c..7d5ae90 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -24,7 +24,7 @@ impl<'a, Metric> ContentEmbedder<'a, Metric> {
}
}
-impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> {
+impl<Metric> Drop for ContentEmbedder<'_, Metric> {
fn drop(&mut self) {
self.cfg
.base_dirs
@@ -36,11 +36,11 @@ impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> {
}
}
-impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
- type Embedding = Metric;
- const NAME: &'static str = "imgbeddings";
-
- fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
+impl<Metric: VecMetric> ContentEmbedder<'_, Metric> {
+ fn embeds_or_err(
+ &mut self,
+ paths: &[PathBuf],
+ ) -> Result<Vec<Result<<Self as BatchEmbedder>::Embedding>>> {
let venv_dir = self
.cfg
.base_dirs
@@ -75,17 +75,31 @@ impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
let child = Command::new(venv_dir.join("bin/python3"))
.arg(script_file)
.args(paths)
- .stderr(Stdio::null())
+ .stderr(Stdio::inherit())
.stdout(Stdio::piped())
.spawn()?;
+ // TODO: this is not OK yet... e.g. we may return too few results.
+ // The Python code has to be changed for this as well.
let st =
ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
let bar = ProgressBar::new(paths.len() as u64).with_style(st);
- BufReader::new(child.stdout.unwrap())
+ Ok(BufReader::new(child.stdout.unwrap())
.lines()
.progress_with(bar)
.map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?))
- .try_collect()
+ .collect())
+ }
+}
+
+impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
+ type Embedding = Metric;
+ const NAME: &'static str = "imgbeddings";
+
+ fn embeds(&mut self, paths: &[PathBuf]) -> Vec<Result<Self::Embedding>> {
+ match self.embeds_or_err(paths) {
+ Ok(embeddings) => embeddings,
+ Err(e) => vec![Err(e)],
+ }
}
}
diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs
index 5ade40d..1a1721d 100644
--- a/src/embedders/mod.rs
+++ b/src/embedders/mod.rs
@@ -32,22 +32,20 @@ pub trait BatchEmbedder: Send + Sync {
type Embedding: MetricElem;
const NAME: &'static str;
- fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>>;
+ fn embeds(&mut self, _: &[PathBuf]) -> Vec<Result<Self::Embedding>>;
}
impl<T: EmbedderT> BatchEmbedder for T {
type Embedding = T::Embedding;
const NAME: &'static str = T::NAME;
- fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
- let st =
- ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
+ fn embeds(&mut self, paths: &[PathBuf]) -> Vec<Result<Self::Embedding>> {
+ let st = ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")
+ .unwrap();
paths
.par_iter()
.progress_with_style(st)
.map(|p| self.embed(p))
.collect::<Vec<_>>()
- .into_iter()
- .try_collect()
}
}