diff options
author | lialenck <lialenck@noreply.codeberg.org> | 2023-09-20 20:29:11 +0000 |
---|---|---|
committer | lialenck <lialenck@noreply.codeberg.org> | 2023-09-20 20:29:11 +0000 |
commit | c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332 (patch) | |
tree | 6698bf8fb3e3b6c77207533b970badfe6c9b2eb4 | |
parent | fbfee0a2bb436a6205d67f561dbd6284621504d6 (diff) | |
parent | bcfc1b328f10188567ec99720a04418dec728868 (diff) | |
download | embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.bz2 embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.zst |
Merge pull request 'support for different vector metrics' (#1) from metamuffin/embeddings-sort:main into main
-rw-r--r-- | src/embedders/ai.rs | 35 | ||||
-rw-r--r-- | src/embedders/mod.rs | 2 | ||||
-rw-r--r-- | src/embedders/vecmetric.rs | 52 | ||||
-rw-r--r-- | src/main.rs | 16 |
4 files changed, 80 insertions, 25 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs index 3848674..120714c 100644 --- a/src/embedders/ai.rs +++ b/src/embedders/ai.rs @@ -1,39 +1,30 @@ use anyhow::Result; use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; -use serde::{Deserialize, Serialize}; use std::{ fs::{remove_file, File}, io::{copy, BufRead, BufReader, Cursor}, + marker::PhantomData, path::PathBuf, process::{Command, Stdio}, }; -use crate::{BatchEmbedder, Config, MetricElem}; +use super::vecmetric::VecMetric; +use crate::{BatchEmbedder, Config}; -#[repr(transparent)] -#[derive(Serialize, Deserialize)] -pub(crate) struct Imgbedding(Vec<f32>); -impl MetricElem for Imgbedding { - fn dist(&self, other: &Self) -> f64 { - self.0 - .iter() - .zip(other.0.iter()) - .map(|(a, b)| (a - b).powf(2.)) - .sum::<f32>() - .sqrt() as f64 - } -} - -pub(crate) struct ContentEmbedder<'a> { +pub(crate) struct ContentEmbedder<'a, Metric> { cfg: &'a Config, + _sim: PhantomData<Metric>, } -impl<'a> ContentEmbedder<'a> { +impl<'a, Metric> ContentEmbedder<'a, Metric> { pub(crate) fn new(cfg: &'a Config) -> Self { - ContentEmbedder { cfg } + ContentEmbedder { + cfg, + _sim: PhantomData, + } } } -impl<'a> Drop for ContentEmbedder<'a> { +impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> { fn drop(&mut self) { self.cfg .base_dirs @@ -45,8 +36,8 @@ impl<'a> Drop for ContentEmbedder<'a> { } } -impl BatchEmbedder for ContentEmbedder<'_> { - type Embedding = Imgbedding; +impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> { + type Embedding = Metric; const NAME: &'static str = "imgbeddings"; fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs index 353222b..5ade40d 100644 --- a/src/embedders/mod.rs +++ b/src/embedders/mod.rs @@ -1,7 +1,9 @@ pub mod ai; pub mod pure; +pub mod vecmetric; pub(crate) use ai::*; pub(crate) use pure::*; +pub(crate) use vecmetric::*; use anyhow::Result; use indicatif::{ParallelProgressIterator, ProgressStyle}; diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs new file mode 100644 index 0000000..9f2f143 --- /dev/null +++ b/src/embedders/vecmetric.rs @@ -0,0 +1,52 @@ +use super::MetricElem; +use serde::{Deserialize, Serialize}; + +pub trait VecMetric: MetricElem + From<Vec<f32>> {} + +#[derive(Deserialize, Serialize)] +pub struct AngularDistance(pub Vec<f32>); +#[derive(Deserialize, Serialize)] +pub struct EuclidianDistance(pub Vec<f32>); +#[derive(Deserialize, Serialize)] +pub struct ManhattenDistance(pub Vec<f32>); + +impl VecMetric for AngularDistance {} +impl VecMetric for EuclidianDistance {} +impl VecMetric for ManhattenDistance {} +#[rustfmt::skip] impl From<Vec<f32>> for AngularDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } +#[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } +#[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } + +impl MetricElem for AngularDistance { + fn dist(&self, other: &Self) -> f64 { + let x = self + .0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| *a * *b) + .sum::<f32>(); + let mag_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>(); + let mag_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>(); + let cossim = x / (mag_a * mag_b).sqrt(); + cossim.acos() as f64 + } +} +impl MetricElem for EuclidianDistance { + fn dist(&self, other: &Self) -> f64 { + self.0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (a - b).powf(2.)) + .sum::<f32>() + .sqrt() as f64 + } +} +impl MetricElem for ManhattenDistance { + fn dist(&self, other: &Self) -> f64 { + self.0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (a - b).abs()) + .sum::<f32>() as f64 + } +} diff --git a/src/main.rs b/src/main.rs index 474532c..f9bf4fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,7 +19,9 @@ enum Embedder { Brightness, Hue, Color, - Content, + ContentEuclidean, + ContentAngularDistance, + ContentManhatten, } #[derive(Debug, Clone, Copy, clap::ValueEnum)] @@ -32,7 +34,7 @@ enum TspAlg { #[derive(Debug, Parser)] struct Args { /// Characteristic to sort by - #[arg(short, long, default_value = "content")] + #[arg(short, long, default_value = "content-euclidean")] embedder: Embedder, /// Symlink the sorted images into this directory @@ -175,7 +177,15 @@ fn main() -> Result<()> { Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg), Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg), Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg), - Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg), + Embedder::ContentAngularDistance => { + process_embedder(ContentEmbedder::<AngularDistance>::new(&cfg), &args, &cfg) + } + Embedder::ContentEuclidean => { + process_embedder(ContentEmbedder::<EuclidianDistance>::new(&cfg), &args, &cfg) + } + Embedder::ContentManhatten => { + process_embedder(ContentEmbedder::<ManhattenDistance>::new(&cfg), &args, &cfg) + } }?; if args.benchmark { |