aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2023-09-20 17:23:45 +0200
committermetamuffin <metamuffin@disroot.org>2023-09-20 17:23:45 +0200
commit4a17c06f22d3236da6f30c397695ef3771a9d393 (patch)
tree5d8477d40ad4119406d93e01541bf97bdcc85c0f
parentfbfee0a2bb436a6205d67f561dbd6284621504d6 (diff)
downloadembeddings-sort-4a17c06f22d3236da6f30c397695ef3771a9d393.tar
embeddings-sort-4a17c06f22d3236da6f30c397695ef3771a9d393.tar.bz2
embeddings-sort-4a17c06f22d3236da6f30c397695ef3771a9d393.tar.zst
support for different vector metrics
-rw-r--r--src/embedders/ai.rs35
-rw-r--r--src/embedders/mod.rs2
-rw-r--r--src/embedders/vecmetric.rs43
-rw-r--r--src/main.rs16
4 files changed, 71 insertions, 25 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index 3848674..e772e4a 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -1,39 +1,30 @@
+use crate::{BatchEmbedder, Config};
use anyhow::Result;
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
-use serde::{Deserialize, Serialize};
use std::{
fs::{remove_file, File},
io::{copy, BufRead, BufReader, Cursor},
+ marker::PhantomData,
path::PathBuf,
process::{Command, Stdio},
};
-use crate::{BatchEmbedder, Config, MetricElem};
+use super::vecmetric::VecMetric;
-#[repr(transparent)]
-#[derive(Serialize, Deserialize)]
-pub(crate) struct Imgbedding(Vec<f32>);
-impl MetricElem for Imgbedding {
- fn dist(&self, other: &Self) -> f64 {
- self.0
- .iter()
- .zip(other.0.iter())
- .map(|(a, b)| (a - b).powf(2.))
- .sum::<f32>()
- .sqrt() as f64
- }
-}
-
-pub(crate) struct ContentEmbedder<'a> {
+pub(crate) struct ContentEmbedder<'a, Metric> {
cfg: &'a Config,
+ _sim: PhantomData<Metric>,
}
-impl<'a> ContentEmbedder<'a> {
+impl<'a, Metric> ContentEmbedder<'a, Metric> {
pub(crate) fn new(cfg: &'a Config) -> Self {
- ContentEmbedder { cfg }
+ ContentEmbedder {
+ cfg,
+ _sim: PhantomData::default(),
+ }
}
}
-impl<'a> Drop for ContentEmbedder<'a> {
+impl<'a, Metric> Drop for ContentEmbedder<'a, Metric> {
fn drop(&mut self) {
self.cfg
.base_dirs
@@ -45,8 +36,8 @@ impl<'a> Drop for ContentEmbedder<'a> {
}
}
-impl BatchEmbedder for ContentEmbedder<'_> {
- type Embedding = Imgbedding;
+impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
+ type Embedding = Metric;
const NAME: &'static str = "imgbeddings";
fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs
index 353222b..5ade40d 100644
--- a/src/embedders/mod.rs
+++ b/src/embedders/mod.rs
@@ -1,7 +1,9 @@
pub mod ai;
pub mod pure;
+pub mod vecmetric;
pub(crate) use ai::*;
pub(crate) use pure::*;
+pub(crate) use vecmetric::*;
use anyhow::Result;
use indicatif::{ParallelProgressIterator, ProgressStyle};
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs
new file mode 100644
index 0000000..474a6d0
--- /dev/null
+++ b/src/embedders/vecmetric.rs
@@ -0,0 +1,43 @@
+use super::MetricElem;
+use serde::{Deserialize, Serialize};
+
+pub trait VecMetric: MetricElem + From<Vec<f32>> {}
+
+#[derive(Deserialize, Serialize)]
+pub struct CosineSimilarity(pub Vec<f32>);
+#[derive(Deserialize, Serialize)]
+pub struct EuclidianDistance(pub Vec<f32>);
+#[derive(Deserialize, Serialize)]
+pub struct ManhattenDistance(pub Vec<f32>);
+
+impl VecMetric for CosineSimilarity {}
+impl VecMetric for EuclidianDistance {}
+impl VecMetric for ManhattenDistance {}
+#[rustfmt::skip] impl From<Vec<f32>> for CosineSimilarity { fn from(value: Vec<f32>) -> Self { Self(value) } }
+#[rustfmt::skip] impl From<Vec<f32>> for EuclidianDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
+#[rustfmt::skip] impl From<Vec<f32>> for ManhattenDistance { fn from(value: Vec<f32>) -> Self { Self(value) } }
+
+impl MetricElem for CosineSimilarity {
+ fn dist(&self, _other: &Self) -> f64 {
+ todo!()
+ }
+}
+impl MetricElem for EuclidianDistance {
+ fn dist(&self, other: &Self) -> f64 {
+ self.0
+ .iter()
+ .zip(other.0.iter())
+ .map(|(a, b)| (a - b).powf(2.))
+ .sum::<f32>()
+ .sqrt() as f64
+ }
+}
+impl MetricElem for ManhattenDistance {
+ fn dist(&self, other: &Self) -> f64 {
+ self.0
+ .iter()
+ .zip(other.0.iter())
+ .map(|(a, b)| (a - b).abs())
+ .sum::<f32>() as f64
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index 474532c..2dda2da 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -19,7 +19,9 @@ enum Embedder {
Brightness,
Hue,
Color,
- Content,
+ ContentEuclidean,
+ ContentCosineSim,
+ ContentManhatten,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
@@ -32,7 +34,7 @@ enum TspAlg {
#[derive(Debug, Parser)]
struct Args {
/// Characteristic to sort by
- #[arg(short, long, default_value = "content")]
+ #[arg(short, long, default_value = "content-euclidean")]
embedder: Embedder,
/// Symlink the sorted images into this directory
@@ -175,7 +177,15 @@ fn main() -> Result<()> {
Embedder::Brightness => process_embedder(BrightnessEmbedder, &args, &cfg),
Embedder::Hue => process_embedder(HueEmbedder, &args, &cfg),
Embedder::Color => process_embedder(ColorEmbedder, &args, &cfg),
- Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), &args, &cfg),
+ Embedder::ContentCosineSim => {
+ process_embedder(ContentEmbedder::<CosineSimilarity>::new(&cfg), &args, &cfg)
+ }
+ Embedder::ContentEuclidean => {
+ process_embedder(ContentEmbedder::<EuclidianDistance>::new(&cfg), &args, &cfg)
+ }
+ Embedder::ContentManhatten => {
+ process_embedder(ContentEmbedder::<ManhattenDistance>::new(&cfg), &args, &cfg)
+ }
}?;
if args.benchmark {