about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author metamuffin <metamuffin@disroot.org> 2023-09-20 22:15:48 +0200
committer metamuffin <metamuffin@disroot.org> 2023-09-20 22:15:48 +0200
commit 4c7c58f487d0ccb70162421e5a871a4020454022 (patch)
tree c50397cf990c99e61e6a45a7703737faea32f6ca
parent d0ce1d9134968a15d37135622138f6b8b7667454 (diff)
download embeddings-sort-4c7c58f487d0ccb70162421e5a871a4020454022.tar
embeddings-sort-4c7c58f487d0ccb70162421e5a871a4020454022.tar.bz2
embeddings-sort-4c7c58f487d0ccb70162421e5a871a4020454022.tar.zst
use angular distance instead of cossim replacement
-rw-r--r-- src/embedders/ai.rs 2
-rw-r--r-- src/embedders/vecmetric.rs 22
-rw-r--r-- src/main.rs 6
3 files changed, 16 insertions, 14 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index e772e4a..ac27708 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -1,4 +1,3 @@
-use crate::{BatchEmbedder, Config};
use anyhow::Result;
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
use std::{
@@ -10,6 +9,7 @@ use std::{
};
use super::vecmetric::VecMetric;
+use crate::{BatchEmbedder, Config};
pub(crate) struct ContentEmbedder<'a, Metric> {
cfg: &'a Config,
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs
index 1bda3a8..9f2f143 100644
--- a/src/embedders/vecmetric.rs
+++ b/src/embedders/vecmetric.rs
@@ -4,29 +4,31 @@ use serde::{Deserialize, Serialize};
pub trait VecMetric: MetricElem + From<Vec<f32>> {}
#[derive(Deserialize, Serialize)]
-pub struct CosineSimilarity(pub Vec<f32>);
+pub struct AngularDistance(pub Vec<f32>);
#[derive(Deserialize, Serialize)]
pub struct EuclidianDistance(pub Vec<f32>);
#[derive(Deserialize, Serialize)]
pub struct ManhattenDistance(pub Vec<f32>);
-impl VecMetric for CosineSimilarity {}
+impl VecMetric for AngularDistance {}
impl VecMetric for EuclidianDistance {}
impl VecMetric for ManhattenDistance {}
-#[rustfmt::skip] impl From<Vec<f32>> for CosineSimilarity { fn from(value: Vec<f32>) -> Self { Self(value) } }
+#[rustfmt::skip] impl From<Vec<f32>> for AngularDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
#[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
#[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
-impl MetricElem for CosineSimilarity {
+impl MetricElem for AngularDistance {
fn dist(&self, other: &Self) -> f64 {
- let len_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
- let len_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
- self.0
+ let x = self
+ .0
.iter()
.zip(other.0.iter())
- .map(|(a, b)| (*a / len_a - *b / len_b).powi(2))
- .sum::<f32>()
- .sqrt() as f64
+ .map(|(a, b)| *a * *b)
+ .sum::<f32>();
+ let mag_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>();
+ let mag_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>();
+ let cossim = x / (mag_a * mag_b).sqrt();
+ cossim.acos() as f64
}
}
impl MetricElem for EuclidianDistance {
diff --git a/src/main.rs b/src/main.rs
index 2dda2da..f9bf4fc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -20,7 +20,7 @@ enum Embedder {
Hue,
Color,
ContentEuclidean,
- ContentCosineSim,
+ ContentAngularDistance,
ContentManhatten,
}
@@ -177,8 +177,8 @@ fn main() -> Result<()> {
Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg),
Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg),
Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg),
- Embedder::ContentCosineSim => {
- process_embedder(ContentEmbedder::<CosineSimilarity>::new(&cfg), &args, &cfg)
+ Embedder::ContentAngularDistance => {
+ process_embedder(ContentEmbedder::<AngularDistance>::new(&cfg), &args, &cfg)
}
Embedder::ContentEuclidean => {
process_embedder(ContentEmbedder::<EuclidianDistance>::new(&cfg), &args, &cfg)