diff options
author | lialenck <lialenck@noreply.codeberg.org> | 2023-09-20 20:29:11 +0000 |
---|---|---|
committer | lialenck <lialenck@noreply.codeberg.org> | 2023-09-20 20:29:11 +0000 |
commit | c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332 (patch) | |
tree | 6698bf8fb3e3b6c77207533b970badfe6c9b2eb4 /src/embedders/ai.rs | |
parent | fbfee0a2bb436a6205d67f561dbd6284621504d6 (diff) | |
parent | bcfc1b328f10188567ec99720a04418dec728868 (diff) | |
download | embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.bz2 embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.zst |
Merge pull request 'support for different vector metrics' (#1) from metamuffin/embeddings-sort:main into main
Diffstat (limited to 'src/embedders/ai.rs')
-rw-r--r-- | src/embedders/ai.rs | 35 |
1 files changed, 13 insertions, 22 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs index 3848674..120714c 100644 --- a/src/embedders/ai.rs +++ b/src/embedders/ai.rs @@ -1,39 +1,30 @@ use anyhow::Result; use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; -use serde::{Deserialize, Serialize}; use std::{ fs::{remove_file, File}, io::{copy, BufRead, BufReader, Cursor}, + marker::PhantomData, path::PathBuf, process::{Command, Stdio}, }; -use crate::{BatchEmbedder, Config, MetricElem}; +use super::vecmetric::VecMetric; +use crate::{BatchEmbedder, Config}; -#[repr(transparent)] -#[derive(Serialize, Deserialize)] -pub(crate) struct Imgbedding(Vec<f32>); -impl MetricElem for Imgbedding { - fn dist(&self, other: &Self) -> f64 { - self.0 - .iter() - .zip(other.0.iter()) - .map(|(a, b)| (a - b).powf(2.)) - .sum::<f32>() - .sqrt() as f64 - } -} - -pub(crate) struct ContentEmbedder<'a> { +pub(crate) struct ContentEmbedder<'a, Metric> { cfg: &'a Config, + _sim: PhantomData<Metric>, } -impl<'a> ContentEmbedder<'a> { +impl<'a, Metric> ContentEmbedder<'a, Metric> { pub(crate) fn new(cfg: &'a Config) -> Self { - ContentEmbedder { cfg } + ContentEmbedder { + cfg, + _sim: PhantomData, + } } } -impl<'a> Drop for ContentEmbedder<'a> { +impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { fn drop(&mut self) { self.cfg .base_dirs @@ -45,8 +36,8 @@ impl<'a> Drop for ContentEmbedder<'a> { } } -impl BatchEmbedder for ContentEmbedder<'_> { - type Embedding = Imgbedding; +impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> { + type Embedding = Metric; const NAME: &'static str = "imgbeddings"; fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { |