aboutsummaryrefslogtreecommitdiff
path: root/src/embedders/ai.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/embedders/ai.rs')
-rw-r--r--src/embedders/ai.rs35
1 files changed, 13 insertions, 22 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index 3848674..120714c 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -1,39 +1,30 @@
use anyhow::Result;
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
-use serde::{Deserialize, Serialize};
use std::{
fs::{remove_file, File},
io::{copy, BufRead, BufReader, Cursor},
+ marker::PhantomData,
path::PathBuf,
process::{Command, Stdio},
};
-use crate::{BatchEmbedder, Config, MetricElem};
+use super::vecmetric::VecMetric;
+use crate::{BatchEmbedder, Config};
-#[repr(transparent)]
-#[derive(Serialize, Deserialize)]
-pub(crate) struct Imgbedding(Vec<f32>);
-impl MetricElem for Imgbedding {
- fn dist(&self, other: &Self) -> f64 {
- self.0
- .iter()
- .zip(other.0.iter())
- .map(|(a, b)| (a - b).powf(2.))
- .sum::<f32>()
- .sqrt() as f64
- }
-}
-
-pub(crate) struct ContentEmbedder<'a> {
+pub(crate) struct ContentEmbedder<'a, Metric> {
cfg: &'a Config,
+ _sim: PhantomData<Metric>,
}
-impl<'a> ContentEmbedder<'a> {
+impl<'a, Metric> ContentEmbedder<'a, Metric> {
pub(crate) fn new(cfg: &'a Config) -> Self {
- ContentEmbedder { cfg }
+ ContentEmbedder {
+ cfg,
+ _sim: PhantomData,
+ }
}
}
-impl<'a> Drop for ContentEmbedder<'a> {
+impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> {
fn drop(&mut self) {
self.cfg
.base_dirs
@@ -45,8 +36,8 @@ impl<'a> Drop for ContentEmbedder<'a> {
}
}
-impl BatchEmbedder for ContentEmbedder<'_> {
- type Embedding = Imgbedding;
+impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
+ type Embedding = Metric;
const NAME: &'static str = "imgbeddings";
fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {