aboutsummaryrefslogtreecommitdiff
path: root/src/embedders
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2024-12-09 18:19:38 +0100
committermetamuffin <metamuffin@disroot.org>2024-12-09 18:19:38 +0100
commite8ab616a5dabe2ae8d77772466f92410ebee1048 (patch)
tree80c06556d4c040d34b79824f45026fc855cee482 /src/embedders
parente71c5d901beec2b15052785893f7250f958f7719 (diff)
downloadembeddings-sort-e8ab616a5dabe2ae8d77772466f92410ebee1048.tar
embeddings-sort-e8ab616a5dabe2ae8d77772466f92410ebee1048.tar.bz2
embeddings-sort-e8ab616a5dabe2ae8d77772466f92410ebee1048.tar.zst
fix bug with angular distance and floating point error; added cosine distance
Diffstat (limited to 'src/embedders')
-rw-r--r--src/embedders/vecmetric.rs17
1 files changed, 16 insertions, 1 deletions
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs
index 65d71df..2ebd170 100644
--- a/src/embedders/vecmetric.rs
+++ b/src/embedders/vecmetric.rs
@@ -6,14 +6,18 @@ pub trait VecMetric: MetricElem + From<Vec<f32>> {}
#[derive(Decode, Encode)]
pub struct AngularDistance(pub Vec<f32>);
#[derive(Decode, Encode)]
+pub struct CosineDistance(pub Vec<f32>);
+#[derive(Decode, Encode)]
pub struct EuclidianDistance(pub Vec<f32>);
#[derive(Decode, Encode)]
pub struct ManhattenDistance(pub Vec<f32>);
impl VecMetric for AngularDistance {}
+impl VecMetric for CosineDistance {}
impl VecMetric for EuclidianDistance {}
impl VecMetric for ManhattenDistance {}
#[rustfmt::skip] impl From<Vec<f32>> for AngularDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
+#[rustfmt::skip] impl From<Vec<f32>> for CosineDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
#[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
#[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
@@ -28,7 +32,18 @@ impl MetricElem for AngularDistance {
let mag_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>();
let mag_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>();
let cossim = x / (mag_a * mag_b).sqrt();
- cossim.acos() as f64
+ // clamp is require because floating point errors
+ cossim.clamp(-1., 1.).acos() as f64
+ }
+}
+impl MetricElem for CosineDistance {
+ fn dist(&self, other: &Self) -> f64 {
+ self.0
+ .iter()
+ .zip(other.0.iter())
+ .map(|(a, b)| (*a - *b) * (*b - *a))
+ .sum::<f32>()
+ .sqrt() as f64
}
}
impl MetricElem for EuclidianDistance {