aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author lialenck <lialenck@noreply.codeberg.org> 2023-09-20 20:29:11 +0000
committer lialenck <lialenck@noreply.codeberg.org> 2023-09-20 20:29:11 +0000
commit c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332 (patch)
tree 6698bf8fb3e3b6c77207533b970badfe6c9b2eb4
parent fbfee0a2bb436a6205d67f561dbd6284621504d6 (diff)
parent bcfc1b328f10188567ec99720a04418dec728868 (diff)
download embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar
embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.bz2
embeddings-sort-c65db2b7bf16f1e1b4f35cf796d3d1d78c1a9332.tar.zst
Merge pull request 'support for different vector metrics' (#1) from metamuffin/embeddings-sort:main into main
-rw-r--r-- src/embedders/ai.rs 35
-rw-r--r-- src/embedders/mod.rs 2
-rw-r--r-- src/embedders/vecmetric.rs 52
-rw-r--r-- src/main.rs 16
4 files changed, 80 insertions, 25 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index 3848674..120714c 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -1,39 +1,30 @@
use anyhow::Result;
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
-use serde::{Deserialize, Serialize};
use std::{
fs::{remove_file, File},
io::{copy, BufRead, BufReader, Cursor},
+ marker::PhantomData,
path::PathBuf,
process::{Command, Stdio},
};
-use crate::{BatchEmbedder, Config, MetricElem};
+use super::vecmetric::VecMetric;
+use crate::{BatchEmbedder, Config};
-#[repr(transparent)]
-#[derive(Serialize, Deserialize)]
-pub(crate) struct Imgbedding(Vec<f32>);
-impl MetricElem for Imgbedding {
- fn dist(&self, other: &Self) -> f64 {
- self.0
- .iter()
- .zip(other.0.iter())
- .map(|(a, b)| (a - b).powf(2.))
- .sum::<f32>()
- .sqrt() as f64
- }
-}
-
-pub(crate) struct ContentEmbedder<'a> {
+pub(crate) struct ContentEmbedder<'a, Metric> {
cfg: &'a Config,
+ _sim: PhantomData<Metric>,
}
-impl<'a> ContentEmbedder<'a> {
+impl<'a, Metric> ContentEmbedder<'a, Metric> {
pub(crate) fn new(cfg: &'a Config) -> Self {
- ContentEmbedder { cfg }
+ ContentEmbedder {
+ cfg,
+ _sim: PhantomData,
+ }
}
}
-impl<'a> Drop for ContentEmbedder<'a> {
+impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> {
fn drop(&mut self) {
self.cfg
.base_dirs
@@ -45,8 +36,8 @@ impl<'a> Drop for ContentEmbedder<'a> {
}
}
-impl BatchEmbedder for ContentEmbedder<'_> {
- type Embedding = Imgbedding;
+impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
+ type Embedding = Metric;
const NAME: &'static str = "imgbeddings";
fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs
index 353222b..5ade40d 100644
--- a/src/embedders/mod.rs
+++ b/src/embedders/mod.rs
@@ -1,7 +1,9 @@
pub mod ai;
pub mod pure;
+pub mod vecmetric;
pub(crate) use ai::*;
pub(crate) use pure::*;
+pub(crate) use vecmetric::*;
use anyhow::Result;
use indicatif::{ParallelProgressIterator, ProgressStyle};
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs
new file mode 100644
index 0000000..9f2f143
--- /dev/null
+++ b/src/embedders/vecmetric.rs
@@ -0,0 +1,52 @@
+use super::MetricElem;
+use serde::{Deserialize, Serialize};
+
+/// Marker trait for embedding wrappers that are both a [`MetricElem`] (so a
+/// distance can be computed between two of them) and constructible from a raw
+/// `Vec<f32>`. It declares no methods of its own.
+pub trait VecMetric: MetricElem + From<Vec<f32>> {}
+
+/// Newtype over a `Vec<f32>` whose [`MetricElem::dist`] is the angle between
+/// the two vectors (arccosine of their cosine similarity).
+#[derive(Deserialize, Serialize)]
+pub struct AngularDistance(pub Vec<f32>);
+/// Newtype over a `Vec<f32>` whose [`MetricElem::dist`] is the square root of
+/// the summed squared component differences (Euclidean / L2 distance).
+#[derive(Deserialize, Serialize)]
+pub struct EuclidianDistance(pub Vec<f32>);
+/// Newtype over a `Vec<f32>` whose [`MetricElem::dist`] is the sum of the
+/// absolute component differences (Manhattan / L1 distance).
+#[derive(Deserialize, Serialize)]
+pub struct ManhattenDistance(pub Vec<f32>);
+
+// Opt each wrapper into `VecMetric`. The supertrait bounds are met by the
+// `From<Vec<f32>>` impls directly below and the `MetricElem` impls in this file.
+impl VecMetric for AngularDistance {}
+impl VecMetric for EuclidianDistance {}
+impl VecMetric for ManhattenDistance {}
+// Conversions that wrap the raw vector unchanged; `rustfmt::skip` keeps each
+// impl on a single line.
+#[rustfmt::skip] impl From<Vec<f32>> for AngularDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
+#[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
+#[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
+
+impl MetricElem for AngularDistance {
+    /// Returns the angle, in radians, between `self` and `other`: the
+    /// arccosine of their cosine similarity. For finite inputs the result
+    /// lies in `[0, π]`.
+    ///
+    /// Only the leading components that both vectors have in common
+    /// contribute to the dot product, because the pairing is done with `zip`.
+    ///
+    /// If either vector has zero magnitude (this includes empty vectors) the
+    /// angle is undefined; the cosine similarity is then taken to be `0`, so
+    /// the result is `π/2` rather than `NaN`.
+    fn dist(&self, other: &Self) -> f64 {
+        // Accumulate in f64: the return type is f64 anyway, and the wider
+        // accumulator reduces rounding error in the sums below.
+        let dot: f64 = self
+            .0
+            .iter()
+            .zip(other.0.iter())
+            .map(|(a, b)| f64::from(*a) * f64::from(*b))
+            .sum();
+        let sq_norm = |v: &[f32]| v.iter().map(|x| f64::from(*x).powi(2)).sum::<f64>();
+        let denom = (sq_norm(&self.0) * sq_norm(&other.0)).sqrt();
+        if denom == 0.0 {
+            // 0/0 would yield NaN; treat a zero-magnitude vector as having
+            // cosine similarity 0 with everything.
+            return std::f64::consts::FRAC_PI_2;
+        }
+        // Rounding can push the quotient marginally outside [-1, 1] for
+        // (nearly) parallel vectors, and `acos` returns NaN there, so clamp.
+        let cossim = (dot / denom).clamp(-1.0, 1.0);
+        cossim.acos()
+    }
+}
+impl MetricElem for EuclidianDistance {
+    /// Euclidean (L2) distance: the square root of the sum of squared
+    /// component-wise differences, computed in `f32` and widened to `f64`.
+    ///
+    /// Components are paired with `zip`, so anything past the end of the
+    /// shorter vector is ignored.
+    fn dist(&self, other: &Self) -> f64 {
+        let squared_diffs = self
+            .0
+            .iter()
+            .zip(&other.0)
+            .map(|(lhs, rhs)| (lhs - rhs).powf(2.));
+        let sum_of_squares: f32 = squared_diffs.sum();
+        sum_of_squares.sqrt() as f64
+    }
+}
+impl MetricElem for ManhattenDistance {
+    /// Manhattan (L1) distance: the sum of absolute component-wise
+    /// differences, computed in `f32` and widened to `f64`.
+    ///
+    /// Components are paired with `zip`, so anything past the end of the
+    /// shorter vector is ignored.
+    fn dist(&self, other: &Self) -> f64 {
+        let abs_diffs = self
+            .0
+            .iter()
+            .zip(&other.0)
+            .map(|(lhs, rhs)| (lhs - rhs).abs());
+        let total: f32 = abs_diffs.sum();
+        f64::from(total)
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 474532c..f9bf4fc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -19,7 +19,9 @@ enum Embedder {
Brightness,
Hue,
Color,
- Content,
+ ContentEuclidean,
+ ContentAngularDistance,
+ ContentManhatten,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -32,7 +34,7 @@ enum TspAlg {
#[derive(Debug, Parser)]
struct Args {
/// Characteristic to sort by
- #[arg(short, long, default_value = "content")]
+ #[arg(short, long, default_value = "content-euclidean")]
embedder: Embedder,
/// Symlink the sorted images into this directory
@@ -175,7 +177,15 @@ fn main() -> Result<()> {
Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg),
Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg),
Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg),
- Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg),
+ Embedder::ContentAngularDistance => {
+ process_embedder(ContentEmbedder::<AngularDistance>::new(&cfg), &args, &cfg)
+ }
+ Embedder::ContentEuclidean => {
+ process_embedder(ContentEmbedder::<EuclidianDistance>::new(&cfg), &args, &cfg)
+ }
+ Embedder::ContentManhatten => {
+ process_embedder(ContentEmbedder::<ManhattenDistance>::new(&cfg), &args, &cfg)
+ }
}?;
if args.benchmark {