diff options
author | metamuffin <metamuffin@disroot.org> | 2023-09-20 22:15:48 +0200 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2023-09-20 22:15:48 +0200 |
commit | 4c7c58f487d0ccb70162421e5a871a4020454022 (patch) | |
tree | c50397cf990c99e61e6a45a7703737faea32f6ca | |
parent | d0ce1d9134968a15d37135622138f6b8b7667454 (diff) | |
download | embeddings-sort-4c7c58f487d0ccb70162421e5a871a4020454022.tar embeddings-sort-4c7c58f487d0ccb70162421e5a871a4020454022.tar.bz2 embeddings-sort-4c7c58f487d0ccb70162421e5a871a4020454022.tar.zst |
use angular distance instead of cossim replacement
-rw-r--r-- | src/embedders/ai.rs | 2 | ||||
-rw-r--r-- | src/embedders/vecmetric.rs | 22 | ||||
-rw-r--r-- | src/main.rs | 6 |
3 files changed, 16 insertions, 14 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index e772e4a..ac27708 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -1,4 +1,3 @@
-use crate::{BatchEmbedder, Config};
 use anyhow::Result;
 use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
 use std::{
@@ -10,6 +9,7 @@ use std::{
 };

 use super::vecmetric::VecMetric;
+use crate::{BatchEmbedder, Config};

 pub(crate) struct ContentEmbedder<'a, Metric> {
     cfg: &'a Config,
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs
index 1bda3a8..9f2f143 100644
--- a/src/embedders/vecmetric.rs
+++ b/src/embedders/vecmetric.rs
@@ -4,29 +4,31 @@ use serde::{Deserialize, Serialize};
 pub trait VecMetric: MetricElem + From<Vec<f32>> {}
 #[derive(Deserialize, Serialize)]
-pub struct CosineSimilarity(pub Vec<f32>);
+pub struct AngularDistance(pub Vec<f32>);
 #[derive(Deserialize, Serialize)]
 pub struct EuclidianDistance(pub Vec<f32>);
 #[derive(Deserialize, Serialize)]
 pub struct ManhattenDistance(pub Vec<f32>);
-impl VecMetric for CosineSimilarity {}
+impl VecMetric for AngularDistance {}
 impl VecMetric for EuclidianDistance {}
 impl VecMetric for ManhattenDistance {}
-#[rustfmt::skip] impl From<Vec<f32>> for CosineSimilarity { fn from(value: Vec<f32>) -> Self { Self(value) } }
+#[rustfmt::skip] impl From<Vec<f32>> for AngularDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
 #[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
 #[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
-impl MetricElem for CosineSimilarity {
+impl MetricElem for AngularDistance {
     fn dist(&self, other: &Self) -> f64 {
-        let len_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
-        let len_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
-        self.0
+        let x = self
+            .0
             .iter()
             .zip(other.0.iter())
-            .map(|(a, b)| (*a / len_a - *b / len_b).powi(2))
-            .sum::<f32>()
-            .sqrt() as f64
+            .map(|(a, b)| *a * *b)
+            .sum::<f32>();
+        let mag_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>();
+        let mag_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>();
+        let cossim = x / (mag_a * mag_b).sqrt();
+        cossim.acos() as f64
     }
 }
 impl MetricElem for EuclidianDistance {
diff --git a/src/main.rs b/src/main.rs
index 2dda2da..f9bf4fc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -20,7 +20,7 @@ enum Embedder {
     Hue,
     Color,
     ContentEuclidean,
-    ContentCosineSim,
+    ContentAngularDistance,
     ContentManhatten,
 }

@@ -177,8 +177,8 @@ fn main() -> Result<()> {
         Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg),
         Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg),
         Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg),
-        Embedder::ContentCosineSim => {
-            process_embedder(ContentEmbedder::<CosineSimilarity>::new(&cfg), &args, &cfg)
+        Embedder::ContentAngularDistance => {
+            process_embedder(ContentEmbedder::<AngularDistance>::new(&cfg), &args, &cfg)
         }
         Embedder::ContentEuclidean => {
             process_embedder(ContentEmbedder::<EuclidianDistance>::new(&cfg), &args, &cfg)