aboutsummaryrefslogtreecommitdiff
path: root/src/embedders
diff options
context:
space:
mode:
Diffstat (limited to 'src/embedders')
-rw-r--r--src/embedders/ai.rs100
-rw-r--r--src/embedders/imgbeddings-api.py18
-rw-r--r--src/embedders/mod.rs51
-rw-r--r--src/embedders/pure.rs91
4 files changed, 260 insertions, 0 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
new file mode 100644
index 0000000..3848674
--- /dev/null
+++ b/src/embedders/ai.rs
@@ -0,0 +1,100 @@
+use anyhow::Result;
+use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
+use serde::{Deserialize, Serialize};
+use std::{
+ fs::{remove_file, File},
+ io::{copy, BufRead, BufReader, Cursor},
+ path::PathBuf,
+ process::{Command, Stdio},
+};
+
+use crate::{BatchEmbedder, Config, MetricElem};
+
+/// Embedding vector for a single image, deserialized from one JSON array
+/// printed by the bundled `imgbeddings-api.py` script.
+#[repr(transparent)]
+#[derive(Serialize, Deserialize)]
+pub(crate) struct Imgbedding(Vec<f32>);
+
+impl MetricElem for Imgbedding {
+    /// Euclidean (L2) distance between two embeddings.
+    ///
+    /// Components are paired positionally; if the vectors differ in length the
+    /// surplus components of the longer one are ignored (this is what `zip`
+    /// does).
+    fn dist(&self, other: &Self) -> f64 {
+        self.0
+            .iter()
+            .zip(&other.0)
+            .map(|(&a, &b)| {
+                // Widen before subtracting/squaring: the result type is f64, so
+                // accumulating in f32 would only throw precision away.
+                let delta = f64::from(a) - f64::from(b);
+                delta * delta
+            })
+            .sum::<f64>()
+            .sqrt()
+    }
+}
+
+/// Batch embedder backed by the external `imgbeddings` Python package.
+///
+/// Only borrows the application [`Config`]; the configuration is used to
+/// locate the data and runtime directories when embedding and on drop.
+pub(crate) struct ContentEmbedder<'a> {
+    cfg: &'a Config,
+}
+
+impl<'cfg> ContentEmbedder<'cfg> {
+    /// Wraps the borrowed configuration. No I/O happens here.
+    pub(crate) fn new(cfg: &'cfg Config) -> Self {
+        Self { cfg }
+    }
+}
+
+impl Drop for ContentEmbedder<'_> {
+    /// Best-effort cleanup of the `imgbeddings-api.py` runtime file.
+    ///
+    /// Nothing is reported if the path cannot be resolved or the file cannot
+    /// be removed (for example because it was never written).
+    fn drop(&mut self) {
+        let script = self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py");
+        if let Ok(path) = script {
+            // The result is discarded on purpose: a destructor has no way to
+            // propagate the error.
+            let _ = remove_file(&path);
+        }
+    }
+}
+
+impl BatchEmbedder for ContentEmbedder<'_> {
+    type Embedding = Imgbedding;
+    const NAME: &'static str = "imgbeddings";
+
+    /// Embeds every image in `paths` with the `imgbeddings` Python package.
+    ///
+    /// Steps:
+    /// 1. write the bundled `imgbeddings-api.py` to the runtime directory,
+    /// 2. create (or reuse) a venv under the `imgbeddings-venv` data directory,
+    /// 3. `pip install imgbeddings` into that venv,
+    /// 4. run the script on `paths` and parse one JSON array per output line.
+    ///
+    /// # Errors
+    ///
+    /// Fails if a directory/file cannot be created, if any of the three
+    /// subprocesses cannot be started or exits unsuccessfully, if an output
+    /// line is not a valid embedding, or if the number of embeddings differs
+    /// from `paths.len()`.
+    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
+        let venv_dir = self
+            .cfg
+            .base_dirs
+            .create_data_directory("imgbeddings-venv")?;
+        let script_file = self
+            .cfg
+            .base_dirs
+            .place_runtime_file("imgbeddings-api.py")?;
+
+        let api_prog = include_bytes!("imgbeddings-api.py");
+        copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?;
+
+        let bar = ProgressBar::new_spinner();
+        bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?);
+        bar.enable_steady_tick(std::time::Duration::from_millis(100));
+
+        bar.set_message("Creating venv...");
+        // Pass the directory as an OS string: converting through `&str` would
+        // panic on a non-UTF-8 path.
+        let status = Command::new("python3")
+            .args(["-m", "venv"])
+            .arg(&venv_dir)
+            .stdout(Stdio::null())
+            .status()?;
+        anyhow::ensure!(
+            status.success(),
+            "Creating the Python venv failed ({})",
+            status
+        );
+
+        bar.set_message("Installing/checking packages...");
+        let status = Command::new(venv_dir.join("bin/pip3"))
+            .args(["install", "imgbeddings"])
+            .stdout(Stdio::null())
+            .status()?;
+        anyhow::ensure!(
+            status.success(),
+            "Installing the imgbeddings package failed ({})",
+            status
+        );
+        bar.finish();
+
+        // NOTE(review): all paths travel on the command line, so a very large
+        // batch could exceed the OS argument-length limit.
+        let mut child = Command::new(venv_dir.join("bin/python3"))
+            .arg(&script_file)
+            .args(paths)
+            .stderr(Stdio::null())
+            .stdout(Stdio::piped())
+            .spawn()?;
+        // `take` (rather than moving the field out) keeps `child` intact so it
+        // can be waited on afterwards.
+        let stdout = child
+            .stdout
+            .take()
+            .expect("child stdout was configured as piped");
+
+        let st =
+            ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
+        let bar = ProgressBar::new(paths.len() as u64).with_style(st);
+        let parsed: Result<Vec<Self::Embedding>> = BufReader::new(stdout)
+            .lines()
+            .progress_with(bar)
+            .map(|line| Ok::<_, anyhow::Error>(serde_json::from_str::<Imgbedding>(&line?)?))
+            .collect();
+
+        // The pipe reader has been dropped by now, so the child can neither
+        // block on a full pipe nor be left behind as a zombie.
+        let status = child.wait()?;
+        let embeddings = parsed?;
+        anyhow::ensure!(
+            status.success(),
+            "The imgbeddings script failed ({})",
+            status
+        );
+        // Presumably callers pair results with `paths` by position, so a short
+        // result would silently misattribute embeddings.
+        anyhow::ensure!(
+            embeddings.len() == paths.len(),
+            "Expected {} embeddings but the imgbeddings script produced {}",
+            paths.len(),
+            embeddings.len()
+        );
+
+        Ok(embeddings)
+    }
+}
diff --git a/src/embedders/imgbeddings-api.py b/src/embedders/imgbeddings-api.py
new file mode 100644
index 0000000..0c890b5
--- /dev/null
+++ b/src/embedders/imgbeddings-api.py
@@ -0,0 +1,18 @@
+from PIL import Image
+from imgbeddings import imgbeddings
+import sys, os
+import json as j
+
+b = imgbeddings()
+
+paths = sys.argv[1:]
+batch_size = 8
+
+for i in range(0, len(paths), batch_size):
+ fs = paths[i:i+batch_size]
+
+ ims = [Image.open(open(f, "rb")) for f in fs]
+ for emb in b.to_embeddings(ims).tolist():
+ print(j.dumps(emb))
+
+sys.stderr.write("\n")
diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs
new file mode 100644
index 0000000..353222b
--- /dev/null
+++ b/src/embedders/mod.rs
@@ -0,0 +1,51 @@
+pub mod ai;
+pub mod pure;
+pub(crate) use ai::*;
+pub(crate) use pure::*;
+
+use anyhow::Result;
+use indicatif::{ParallelProgressIterator, ProgressStyle};
+use rayon::prelude::*;
+use serde::{Deserialize, Serialize};
+use std::path::{Path, PathBuf};
+
+/// A value with a pairwise distance that can also be shared across threads
+/// and round-tripped through serde.
+pub trait MetricElem: Send + Sync + 'static + Serialize + for<'de> Deserialize<'de> {
+    /// Distance between `self` and `other`.
+    fn dist(&self, other: &Self) -> f64;
+}
+
+impl MetricElem for f64 {
+    /// Absolute difference between the two numbers.
+    fn dist(&self, other: &Self) -> f64 {
+        let delta = *self - *other;
+        delta.abs()
+    }
+}
+
+/// An embedder that handles one image at a time through a shared reference.
+///
+/// Every implementor automatically gets a [`BatchEmbedder`] implementation.
+pub trait EmbedderT: Send + Sync {
+    /// Value produced for each image.
+    type Embedding: MetricElem;
+    /// Name of this embedder.
+    const NAME: &'static str;
+
+    /// Computes the embedding of the image stored at `path`.
+    fn embed(&self, path: &Path) -> Result<Self::Embedding>;
+}
+
+/// An embedder that processes a whole slice of image paths in one call.
+pub trait BatchEmbedder: Send + Sync {
+    /// Value produced for each image.
+    type Embedding: MetricElem;
+    /// Name of this embedder.
+    const NAME: &'static str;
+
+    /// Computes the embeddings of all images in `paths`.
+    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>>;
+}
+
+/// Every single-image embedder is a batch embedder: the images are embedded
+/// in parallel with rayon while a progress bar is displayed.
+impl<T: EmbedderT> BatchEmbedder for T {
+    type Embedding = T::Embedding;
+    const NAME: &'static str = T::NAME;
+
+    /// Embeds all `paths` in parallel, preserving their order.
+    ///
+    /// # Errors
+    ///
+    /// Returns the error of the first path (in slice order) whose embedding
+    /// failed, or an error if the progress-bar template is rejected.
+    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
+        let st =
+            ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
+        let outcomes: Vec<Result<Self::Embedding>> = paths
+            .par_iter()
+            .progress_with_style(st)
+            .map(|p| self.embed(p))
+            .collect();
+
+        // Stable `collect` into `Result` instead of the nightly-only
+        // `Iterator::try_collect`; going through the ordered Vec keeps the
+        // reported error deterministic.
+        outcomes.into_iter().collect()
+    }
+}
diff --git a/src/embedders/pure.rs b/src/embedders/pure.rs
new file mode 100644
index 0000000..09c8321
--- /dev/null
+++ b/src/embedders/pure.rs
@@ -0,0 +1,91 @@
+use anyhow::{bail, Result};
+use serde::{Deserialize, Serialize};
+use std::path::Path;
+
+use crate::{EmbedderT, MetricElem};
+
+/// Embeds an image as the mean value of all its 8-bit RGB samples.
+pub(crate) struct BrightnessEmbedder;
+
+impl EmbedderT for BrightnessEmbedder {
+    type Embedding = f64;
+    const NAME: &'static str = "Brightness";
+
+    /// Returns the average sample value (0.0 to 255.0) of the image at `path`.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the image cannot be opened/decoded or has no pixels.
+    fn embed(&self, path: &Path) -> Result<f64> {
+        let im = image::open(path)?;
+        // Widen before multiplying: `3 * height * width` does not fit in u32
+        // for very large images.
+        let num_bytes = 3 * u64::from(im.height()) * u64::from(im.width());
+
+        if num_bytes == 0 {
+            bail!("Encountered NaN brightness, due to an empty image");
+        }
+
+        let total: u64 = im.to_rgb8().iter().map(|&e| u64::from(e)).sum();
+        Ok(total as f64 / num_bytes as f64)
+    }
+}
+
+/// A hue value on a circular scale whose period is 6.
+#[repr(transparent)]
+#[derive(Serialize, Deserialize)]
+pub(crate) struct Hue(f64);
+
+impl MetricElem for Hue {
+    /// Shorter of the two ways around the hue circle.
+    fn dist(&self, other: &Hue) -> f64 {
+        let Hue(lhs) = self;
+        let Hue(rhs) = other;
+
+        let direct = lhs.dist(rhs);
+        let wrapped = 6. - direct;
+        direct.min(wrapped)
+    }
+}
+/// Embeds an image as the hue of its mean colour.
+pub(crate) struct HueEmbedder;
+
+impl EmbedderT for HueEmbedder {
+    type Embedding = Hue;
+    const NAME: &'static str = "Hue";
+
+    /// Averages each RGB channel over the image and converts that mean colour
+    /// to a hue (offset 0 when red dominates, 2 for green, 4 for blue).
+    ///
+    /// # Errors
+    ///
+    /// Fails if the image cannot be opened/decoded, or if the hue is NaN,
+    /// which happens when the mean colour is grey or the image has no pixels.
+    fn embed(&self, path: &Path) -> Result<Hue> {
+        let im = image::open(path)?;
+        let num_pixels = im.height() * im.width();
+
+        // Per-channel sums over all pixels.
+        let rgb = im.to_rgb8();
+        let mut totals = [0u64; 3];
+        for px in rgb.pixels() {
+            let [r, g, b] = px.0;
+            totals[0] += r as u64;
+            totals[1] += g as u64;
+            totals[2] += b as u64;
+        }
+        let [mean_r, mean_g, mean_b] = totals.map(|t| t as f64 / 255. / num_pixels as f64);
+
+        let red_is_max = mean_r >= mean_g && mean_r >= mean_b;
+        let hue = if red_is_max {
+            (mean_g - mean_b) / (mean_r - mean_g.min(mean_b))
+        } else if mean_g >= mean_b {
+            2. + (mean_b - mean_r) / (mean_g - mean_r.min(mean_b))
+        } else {
+            4. + (mean_r - mean_g) / (mean_b - mean_r.min(mean_g))
+        };
+
+        if hue.is_nan() {
+            bail!("Encountered NaN hue, possibly because of a colorless or empty image");
+        }
+
+        Ok(Hue(hue))
+    }
+}
+
+impl MetricElem for (f64, f64, f64) {
+    /// Euclidean distance between two triples.
+    fn dist(&self, o: &(f64, f64, f64)) -> f64 {
+        let (r1, g1, b1) = *self;
+        let (r2, g2, b2) = *o;
+
+        let dr = r1 - r2;
+        let dg = g1 - g2;
+        let db = b1 - b2;
+
+        let squared = dr * dr + dg * dg + db * db;
+        squared.sqrt()
+    }
+}
+/// Embeds an image as its mean colour, one `f64` per RGB channel.
+pub(crate) struct ColorEmbedder;
+
+impl EmbedderT for ColorEmbedder {
+    type Embedding = (f64, f64, f64);
+    const NAME: &'static str = "Color";
+
+    /// Returns the per-channel mean (each 0.0 to 255.0) of the image at
+    /// `path` as `(red, green, blue)`.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the image cannot be opened/decoded or has no pixels.
+    fn embed(&self, path: &Path) -> Result<(f64, f64, f64)> {
+        let im = image::open(path)?;
+        // Widen before multiplying so huge images cannot overflow u32.
+        let num_pixels = u64::from(im.height()) * u64::from(im.width());
+
+        // Without this guard the divisions below yield NaN, which would then
+        // poison every distance computed from this embedding.
+        if num_pixels == 0 {
+            bail!("Encountered NaN color, due to an empty image");
+        }
+
+        let [sr, sg, sb] = im
+            .to_rgb8()
+            .pixels()
+            .fold([0u64; 3], |[or, og, ob], n| {
+                let [nr, ng, nb] = n.0;
+                [or + u64::from(nr), og + u64::from(ng), ob + u64::from(nb)]
+            })
+            .map(|e| e as f64 / num_pixels as f64);
+
+        Ok((sr, sg, sb))
+    }
+}