diff options
author | lialenck <lialenck@noreply.codeberg.org> | 2023-09-20 20:29:11 +0000 |
---|---|---|
committer | lialenck <lialenck@noreply.codeberg.org> | 2023-09-20 20:29:11 +0000 |
commit | c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332 (patch) | |
tree | 6698bf8fb3e3b6c77207533b970badfe6c9b2eb4 /src/embedders/vecmetric.rs | |
parent | fbfee0a2bb436a6205d67f561dbd6284621504d6 (diff) | |
parent | bcfc1b328f10188567ec99720a04418dec728868 (diff) | |
download | embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.bz2 embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.zst |
Merge pull request 'support for different vector metrics' (#1) from metamuffin/embeddings-sort:main into main
Diffstat (limited to 'src/embedders/vecmetric.rs')
-rw-r--r-- | src/embedders/vecmetric.rs | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs new file mode 100644 index 0000000..9f2f143 --- /dev/null +++ b/src/embedders/vecmetric.rs @@ -0,0 +1,52 @@ +use super::MetricElem; +use serde::{Deserialize, Serialize}; + +pub trait VecMetric: MetricElem + From<Vec<f32>> {} + +#[derive(Deserialize, Serialize)] +pub struct AngularDistance(pub Vec<f32>); +#[derive(Deserialize, Serialize)] +pub struct EuclidianDistance(pub Vec<f32>); +#[derive(Deserialize, Serialize)] +pub struct ManhattenDistance(pub Vec<f32>); + +impl VecMetric for AngularDistance {} +impl VecMetric for EuclidianDistance {} +impl VecMetric for ManhattenDistance {} +#[rustfmt::skip] impl From<Vec<f32>> for AngularDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } +#[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } +#[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } } + +impl MetricElem for AngularDistance { + fn dist(&self, other: &Self) -> f64 { + let x = self + .0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| *a * *b) + .sum::<f32>(); + let mag_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>(); + let mag_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>(); + let cossim = x / (mag_a * mag_b).sqrt(); + cossim.acos() as f64 + } +} +impl MetricElem for EuclidianDistance { + fn dist(&self, other: &Self) -> f64 { + self.0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (a - b).powf(2.)) + .sum::<f32>() + .sqrt() as f64 + } +} +impl MetricElem for ManhattenDistance { + fn dist(&self, other: &Self) -> f64 { + self.0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (a - b).abs()) + .sum::<f32>() as f64 + } +} |