From 4a17c06f22d3236da6f30c397695ef3771a9d393 Mon Sep 17 00:00:00 2001 From: metamuffin Date: Wed, 20 Sep 2023 17:23:45 +0200 Subject: support for different vector metrics --- src/embedders/ai.rs | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) (limited to 'src/embedders/ai.rs') diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs index 3848674..e772e4a 100644 --- a/src/embedders/ai.rs +++ b/src/embedders/ai.rs @@ -1,39 +1,30 @@ +use crate::{BatchEmbedder, Config}; use anyhow::Result; use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; -use serde::{Deserialize, Serialize}; use std::{ fs::{remove_file, File}, io::{copy, BufRead, BufReader, Cursor}, + marker::PhantomData, path::PathBuf, process::{Command, Stdio}, }; -use crate::{BatchEmbedder, Config, MetricElem}; +use super::vecmetric::VecMetric; -#[repr(transparent)] -#[derive(Serialize, Deserialize)] -pub(crate) struct Imgbedding(Vec); -impl MetricElem for Imgbedding { - fn dist(&self, other: &Self) -> f64 { - self.0 - .iter() - .zip(other.0.iter()) - .map(|(a, b)| (a - b).powf(2.)) - .sum::() - .sqrt() as f64 - } -} - -pub(crate) struct ContentEmbedder<'a> { +pub(crate) struct ContentEmbedder<'a, Metric> { cfg: &'a Config, + _sim: PhantomData, } -impl<'a> ContentEmbedder<'a> { +impl<'a, Metric> ContentEmbedder<'a, Metric> { pub(crate) fn new(cfg: &'a Config) -> Self { - ContentEmbedder { cfg } + ContentEmbedder { + cfg, + _sim: PhantomData::default(), + } } } -impl<'a> Drop for ContentEmbedder<'a> { +impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { fn drop(&mut self) { self.cfg .base_dirs @@ -45,8 +36,8 @@ impl<'a> Drop for ContentEmbedder<'a> { } } -impl BatchEmbedder for ContentEmbedder<'_> { - type Embedding = Imgbedding; +impl BatchEmbedder for ContentEmbedder<'_, Metric> { + type Embedding = Metric; const NAME: &'static str = "imgbeddings"; fn embeds(&mut self, paths: &[PathBuf]) -> Result> { -- cgit v1.2.3-70-g09d2