aboutsummaryrefslogtreecommitdiff
path: root/src/embedders/ai.rs
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2023-09-20 17:23:45 +0200
committermetamuffin <metamuffin@disroot.org>2023-09-20 17:23:45 +0200
commit4a17c06f22d3236da6f30c397695ef3771a9d393 (patch)
tree5d8477d40ad4119406d93e01541bf97bdcc85c0f /src/embedders/ai.rs
parentfbfee0a2bb436a6205d67f561dbd6284621504d6 (diff)
downloadembeddings-sort-4a17c06f22d3236da6f30c397695ef3771a9d393.tar
embeddings-sort-4a17c06f22d3236da6f30c397695ef3771a9d393.tar.bz2
embeddings-sort-4a17c06f22d3236da6f30c397695ef3771a9d393.tar.zst
support for different vector metrics
Diffstat (limited to 'src/embedders/ai.rs')
-rw-r--r--src/embedders/ai.rs35
1 files changed, 13 insertions, 22 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index 3848674..e772e4a 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -1,39 +1,30 @@
+use crate::{BatchEmbedder, Config};
use anyhow::Result;
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
-use serde::{Deserialize, Serialize};
use std::{
fs::{remove_file, File},
io::{copy, BufRead, BufReader, Cursor},
+ marker::PhantomData,
path::PathBuf,
process::{Command, Stdio},
};
-use crate::{BatchEmbedder, Config, MetricElem};
+use super::vecmetric::VecMetric;
-#[repr(transparent)]
-#[derive(Serialize, Deserialize)]
-pub(crate) struct Imgbedding(Vec<f32>);
-impl MetricElem for Imgbedding {
- fn dist(&self, other: &Self) -> f64 {
- self.0
- .iter()
- .zip(other.0.iter())
- .map(|(a, b)| (a - b).powf(2.))
- .sum::<f32>()
- .sqrt() as f64
- }
-}
-
-pub(crate) struct ContentEmbedder<'a> {
+pub(crate) struct ContentEmbedder<'a, Metric> {
cfg: &'a Config,
+ _sim: PhantomData<Metric>,
}
-impl<'a> ContentEmbedder<'a> {
+impl<'a, Metric> ContentEmbedder<'a, Metric> {
pub(crate) fn new(cfg: &'a Config) -> Self {
- ContentEmbedder { cfg }
+ ContentEmbedder {
+ cfg,
+ _sim: PhantomData::default(),
+ }
}
}
-impl<'a> Drop for ContentEmbedder<'a> {
+impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> {
fn drop(&mut self) {
self.cfg
.base_dirs
@@ -45,8 +36,8 @@ impl<'a> Drop for ContentEmbedder<'a> {
}
}
-impl BatchEmbedder for ContentEmbedder<'_> {
- type Embedding = Imgbedding;
+impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
+ type Embedding = Metric;
const NAME: &'static str = "imgbeddings";
fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {