aboutsummaryrefslogtreecommitdiff
path: root/src/embedders/vecmetric.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/embedders/vecmetric.rs')
-rw-r--r--src/embedders/vecmetric.rs13
1 files changed, 6 insertions, 7 deletions
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs
index 0c63911..1bda3a8 100644
--- a/src/embedders/vecmetric.rs
+++ b/src/embedders/vecmetric.rs
@@ -19,15 +19,14 @@ impl VecMetric for ManhattenDistance {}
impl MetricElem for CosineSimilarity {
fn dist(&self, other: &Self) -> f64 {
- let x = self
- .0
+ let len_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
+ let len_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
+ self.0
.iter()
.zip(other.0.iter())
- .map(|(a, b)| *a * *b)
- .sum::<f32>();
- let mag_a = self.0.iter().map(|x| x.powi(2)).sum::<f32>();
- let mag_b = other.0.iter().map(|x| x.powi(2)).sum::<f32>();
- (x / (mag_a * mag_b).sqrt()) as f64
+ .map(|(a, b)| (*a / len_a - *b / len_b).powi(2))
+ .sum::<f32>()
+ .sqrt() as f64
}
}
impl MetricElem for EuclidianDistance {