diff options
author | metamuffin <metamuffin@disroot.org> | 2023-09-20 17:23:45 +0200 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2023-09-20 17:23:45 +0200 |
commit | 4a17c06f22d3236da6f30c397695ef3771a9d393 (patch) | |
tree | 5d8477d40ad4119406d93e01541bf97bdcc85c0f /src | |
parent | fbfee0a2bb436a6205d67f561dbd6284621504d6 (diff) | |
download | embeddings-sort-4a17c06f22d3236da6f30c397695ef3771a9d393.tar embeddings-sort-4a17c06f22d3236da6f30c397695ef3771a9d393.tar.bz2 embeddings-sort-4a17c06f22d3236da6f30c397695ef3771a9d393.tar.zst |
support for different vector metrics
Diffstat (limited to 'src')
-rw-r--r-- | src/embedders/ai.rs | 35 | ||||
-rw-r--r-- | src/embedders/mod.rs | 2 | ||||
-rw-r--r-- | src/embedders/vecmetric.rs | 43 | ||||
-rw-r--r-- | src/main.rs | 16 |
4 files changed, 71 insertions, 25 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs index 3848674..e772e4a 100644 --- a/src/embedders/ai.rs +++ b/src/embedders/ai.rs @@ -1,39 +1,30 @@ +use crate::{BatchEmbedder, Config}; use anyhow::Result; use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; -use serde::{Deserialize, Serialize}; use std::{ fs::{remove_file, File}, io::{copy, BufRead, BufReader, Cursor}, + marker::PhantomData, path::PathBuf, process::{Command, Stdio}, }; -use crate::{BatchEmbedder, Config, MetricElem}; +use super::vecmetric::VecMetric; -#[repr(transparent)] -#[derive(Serialize, Deserialize)] -pub(crate) struct Imgbedding(Vec<f32>); -impl MetricElem for Imgbedding { - fn dist(&self, other: &Self) -> f64 { - self.0 - .iter() - .zip(other.0.iter()) - .map(|(a, b)| (a - b).powf(2.)) - .sum::<f32>() - .sqrt() as f64 - } -} - -pub(crate) struct ContentEmbedder<'a> { +pub(crate) struct ContentEmbedder<'a, Metric> { cfg: &'a Config, + _sim: PhantomData<Metric>, } -impl<'a> ContentEmbedder<'a> { +impl<'a, Metric> ContentEmbedder<'a, Metric> { pub(crate) fn new(cfg: &'a Config) -> Self { - ContentEmbedder { cfg } + ContentEmbedder { + cfg, + _sim: PhantomData::default(), + } } } -impl<'a> Drop for ContentEmbedder<'a> { +impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { fn drop(&mut self) { self.cfg .base_dirs @@ -45,8 +36,8 @@ impl<'a> Drop for ContentEmbedder<'a> { } } -impl BatchEmbedder for ContentEmbedder<'_> { - type Embedding = Imgbedding; +impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> { + type Embedding = Metric; const NAME: &'static str = "imgbeddings"; fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs index 353222b..5ade40d 100644 --- a/src/embedders/mod.rs +++ b/src/embedders/mod.rs @@ -1,7 +1,9 @@ pub mod ai; pub mod pure; +pub mod vecmetric; pub(crate) use ai::*; pub(crate) use pure::*; +pub(crate) use vecmetric::*; use anyhow::Result; use indicatif::{ParallelProgressIterator, ProgressStyle}; diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs new file mode 100644 index 0000000..474a6d0 --- /dev/null +++ b/src/embedders/vecmetric.rs @@ -0,0 +1,43 @@ +use super::MetricElem; +use serde::{Deserialize, Serialize}; + +pub trait VecMetric: MetricElem + From<Vec<f32>> {} + +#[derive(Deserialize, Serialize)] +pub struct CosineSimilarity(pub Vec<f32>); +#[derive(Deserialize, Serialize)] +pub struct EuclidianDistance(pub Vec<f32>); +#[derive(Deserialize, Serialize)] +pub struct ManhattenDistance(pub Vec<f32>); + +impl VecMetric for CosineSimilarity {} +impl VecMetric for EuclidianDistance {} +impl VecMetric for ManhattenDistance {} +#[rustfmt::skip] impl From<Vec<f32>> for CosineSimilarity { fn from(value: Vec<f32>) -> Self { Self(value) } } +#[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } +#[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } + +impl MetricElem for CosineSimilarity { + fn dist(&self, _other: &Self) -> f64 { + todo!() + } +} +impl MetricElem for EuclidianDistance { + fn dist(&self, other: &Self) -> f64 { + self.0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (a - b).powf(2.)) + .sum::<f32>() + .sqrt() as f64 + } +} +impl MetricElem for ManhattenDistance { + fn dist(&self, other: &Self) -> f64 { + self.0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (a - b).abs()) + .sum::<f32>() as f64 + } +} diff --git a/src/main.rs b/src/main.rs index 474532c..2dda2da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,7 +19,9 @@ enum Embedder { Brightness, Hue, Color, - Content, + ContentEuclidean, + ContentCosineSim, + ContentManhatten, } #[derive(Debug, Clone, Copy, clap::ValueEnum)] @@ -32,7 +34,7 @@ enum TspAlg { #[derive(Debug, Parser)] struct Args { /// Characteristic to sort by - #[arg(short, long, default_value = "content")] + #[arg(short, long, default_value = "content-euclidean")] embedder: Embedder, /// Symlink the sorted images into this directory @@ -175,7 +177,15 @@ fn main() -> Result<()> { Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg), Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg), Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg), - Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg), + Embedder::ContentCosineSim => { + process_embedder(ContentEmbedder::<CosineSimilarity>::new(&cfg), &args, &cfg) + } + Embedder::ContentEuclidean => { + process_embedder(ContentEmbedder::<EuclidianDistance>::new(&cfg), &args, &cfg) + } + Embedder::ContentManhatten => { + process_embedder(ContentEmbedder::<ManhattenDistance>::new(&cfg), &args, &cfg) + } }?; if args.benchmark { |