diff options
author | metamuffin <metamuffin@disroot.org> | 2024-12-09 18:19:38 +0100 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2024-12-09 18:19:38 +0100 |
commit | e8ab616a5dabe2ae8d77772466f92410ebee1048 (patch) | |
tree | 80c06556d4c040d34b79824f45026fc855cee482 /src/embedders/vecmetric.rs | |
parent | e71c5d901beec2b15052785893f7250f958f7719 (diff) | |
download | embeddings-sort-e8ab616a5dabe2ae8d77772466f92410ebee1048.tar embeddings-sort-e8ab616a5dabe2ae8d77772466f92410ebee1048.tar.bz2 embeddings-sort-e8ab616a5dabe2ae8d77772466f92410ebee1048.tar.zst |
fix bug with angular distance and floating point error; added cosine distance
Diffstat (limited to 'src/embedders/vecmetric.rs')
-rw-r--r-- | src/embedders/vecmetric.rs | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs index 65d71df..2ebd170 100644 --- a/src/embedders/vecmetric.rs +++ b/src/embedders/vecmetric.rs @@ -6,14 +6,18 @@ pub trait VecMetric: MetricElem + From<Vec<f32>> {} #[derive(Decode, Encode)] pub struct AngularDistance(pub Vec<f32>); #[derive(Decode, Encode)] +pub struct CosineDistance(pub Vec<f32>); +#[derive(Decode, Encode)] pub struct EuclidianDistance(pub Vec<f32>); #[derive(Decode, Encode)] pub struct ManhattenDistance(pub Vec<f32>); impl VecMetric for AngularDistance {} +impl VecMetric for CosineDistance {} impl VecMetric for EuclidianDistance {} impl VecMetric for ManhattenDistance {} #[rustfmt::skip] impl From<Vec<f32>> for AngularDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } +#[rustfmt::skip] impl From<Vec<f32>> for CosineDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } #[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } #[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } @@ -28,7 +32,18 @@ impl MetricElem for AngularDistance { let mag_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>(); let mag_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>(); let cossim = x / (mag_a * mag_b).sqrt(); - cossim.acos() as f64 + // clamp is require because floating point errors + cossim.clamp(-1., 1.).acos() as f64 + } +} +impl MetricElem for CosineDistance { + fn dist(&self, other: &Self) -> f64 { + self.0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (*a - *b) * (*b - *a)) + .sum::<f32>() + .sqrt() as f64 } } impl MetricElem for EuclidianDistance { |