aboutsummaryrefslogtreecommitdiff
path: root/src/embedders/vecmetric.rs
diff options
context:
space:
mode:
authorlialenck <lialenck@noreply.codeberg.org>2023-09-20 20:29:11 +0000
committerlialenck <lialenck@noreply.codeberg.org>2023-09-20 20:29:11 +0000
commitc65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332 (patch)
tree6698bf8fb3e3b6c77207533b970badfe6c9b2eb4 /src/embedders/vecmetric.rs
parentfbfee0a2bb436a6205d67f561dbd6284621504d6 (diff)
parentbcfc1b328f10188567ec99720a04418dec728868 (diff)
downloadembeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar
embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.bz2
embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.zst
Merge pull request 'support for different vector metrics' (#1) from metamuffin/embeddings-sort:main into main
Diffstat (limited to 'src/embedders/vecmetric.rs')
-rw-r--r--src/embedders/vecmetric.rs52
1 files changed, 52 insertions, 0 deletions
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs
new file mode 100644
index 0000000..9f2f143
--- /dev/null
+++ b/src/embedders/vecmetric.rs
@@ -0,0 +1,52 @@
+use super::MetricElem;
+use serde::{Deserialize, Serialize};
+
+pub trait VecMetric: MetricElem + From<Vec<f32>> {}
+
+#[derive(Deserialize, Serialize)]
+pub struct AngularDistance(pub Vec<f32>);
+#[derive(Deserialize, Serialize)]
+pub struct EuclidianDistance(pub Vec<f32>);
+#[derive(Deserialize, Serialize)]
+pub struct ManhattenDistance(pub Vec<f32>);
+
+impl VecMetric for AngularDistance {}
+impl VecMetric for EuclidianDistance {}
+impl VecMetric for ManhattenDistance {}
+#[rustfmt::skip] impl From<Vec<f32>> for AngularDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
+#[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
+#[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
+
+impl MetricElem for AngularDistance {
+ fn dist(&self, other: &Self) -> f64 {
+ let x = self
+ .0
+ .iter()
+ .zip(other.0.iter())
+ .map(|(a, b)| *a * *b)
+ .sum::<f32>();
+ let mag_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>();
+ let mag_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>();
+ let cossim = x / (mag_a * mag_b).sqrt();
+ cossim.acos() as f64
+ }
+}
+impl MetricElem for EuclidianDistance {
+ fn dist(&self, other: &Self) -> f64 {
+ self.0
+ .iter()
+ .zip(other.0.iter())
+ .map(|(a, b)| (a - b).powf(2.))
+ .sum::<f32>()
+ .sqrt() as f64
+ }
+}
+impl MetricElem for ManhattenDistance {
+ fn dist(&self, other: &Self) -> f64 {
+ self.0
+ .iter()
+ .zip(other.0.iter())
+ .map(|(a, b)| (a - b).abs())
+ .sum::<f32>() as f64
+ }
+}