diff options
Diffstat (limited to 'src/embedders')
-rw-r--r-- | src/embedders/ai.rs | 100 | ||||
-rw-r--r-- | src/embedders/imgbeddings-api.py | 18 | ||||
-rw-r--r-- | src/embedders/mod.rs | 51 | ||||
-rw-r--r-- | src/embedders/pure.rs | 91 |
4 files changed, 260 insertions, 0 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs new file mode 100644 index 0000000..3848674 --- /dev/null +++ b/src/embedders/ai.rs @@ -0,0 +1,100 @@ +use anyhow::Result; +use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; +use serde::{Deserialize, Serialize}; +use std::{ + fs::{remove_file, File}, + io::{copy, BufRead, BufReader, Cursor}, + path::PathBuf, + process::{Command, Stdio}, +}; + +use crate::{BatchEmbedder, Config, MetricElem}; + +#[repr(transparent)] +#[derive(Serialize, Deserialize)] +pub(crate) struct Imgbedding(Vec<f32>); +impl MetricElem for Imgbedding { + fn dist(&self, other: &Self) -> f64 { + self.0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (a - b).powf(2.)) + .sum::<f32>() + .sqrt() as f64 + } +} + +pub(crate) struct ContentEmbedder<'a> { + cfg: &'a Config, +} +impl<'a> ContentEmbedder<'a> { + pub(crate) fn new(cfg: &'a Config) -> Self { + ContentEmbedder { cfg } + } +} + +impl<'a> Drop for ContentEmbedder<'a> { + fn drop(&mut self) { + self.cfg + .base_dirs + .place_runtime_file("imgbeddings-api.py") + .iter() + .for_each(|p| { + let _ = remove_file(p); + }); + } +} + +impl BatchEmbedder for ContentEmbedder<'_> { + type Embedding = Imgbedding; + const NAME: &'static str = "imgbeddings"; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { + let venv_dir = self + .cfg + .base_dirs + .create_data_directory("imgbeddings-venv")?; + let script_file = self + .cfg + .base_dirs + .place_runtime_file("imgbeddings-api.py")?; + + let api_prog = include_bytes!("imgbeddings-api.py"); + copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?; + + let bar = ProgressBar::new_spinner(); + bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + + bar.set_message("Creating venv..."); + Command::new("python3") + .args(["-m", "venv", venv_dir.to_str().unwrap()]) + .stdout(Stdio::null()) + .spawn()? + .wait()?; + + bar.set_message("Installing/checking packages..."); + Command::new(venv_dir.join("bin/pip3")) + .args(["install", "imgbeddings"]) + .stdout(Stdio::null()) + .spawn()? + .wait()?; + bar.finish(); + + let child = Command::new(venv_dir.join("bin/python3")) + .arg(script_file) + .args(paths) + .stderr(Stdio::null()) + .stdout(Stdio::piped()) + .spawn()?; + + let st = + ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + let bar = ProgressBar::new(paths.len() as u64).with_style(st); + BufReader::new(child.stdout.unwrap()) + .lines() + .progress_with(bar) + .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) + .try_collect() + } +} diff --git a/src/embedders/imgbeddings-api.py b/src/embedders/imgbeddings-api.py new file mode 100644 index 0000000..0c890b5 --- /dev/null +++ b/src/embedders/imgbeddings-api.py @@ -0,0 +1,18 @@ +from PIL import Image +from imgbeddings import imgbeddings +import sys, os +import json as j + +b = imgbeddings() + +paths = sys.argv[1:] +batch_size = 8 + +for i in range(0, len(paths), batch_size): + fs = paths[i:i+batch_size] + + ims = [Image.open(open(f, "rb")) for f in fs] + for emb in b.to_embeddings(ims).tolist(): + print(j.dumps(emb)) + +sys.stderr.write("\n") diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs new file mode 100644 index 0000000..353222b --- /dev/null +++ b/src/embedders/mod.rs @@ -0,0 +1,51 @@ +pub mod ai; +pub mod pure; +pub(crate) use ai::*; +pub(crate) use pure::*; + +use anyhow::Result; +use indicatif::{ParallelProgressIterator, ProgressStyle}; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; + +pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> { + fn dist(&self, _: &Self) -> f64; +} + +impl MetricElem for f64 { + fn dist(&self, b: &f64) -> f64 { + (self - b).abs() + } +} + +pub trait EmbedderT: Send + Sync { + type Embedding: MetricElem; + const NAME: &'static str; + + fn embed(&self, _: &Path) -> Result<Self::Embedding>; +} + +pub trait BatchEmbedder: Send + Sync { + type Embedding: MetricElem; + const NAME: &'static str; + + fn embeds(&mut self, _: &[PathBuf]) -> Result<Vec<Self::Embedding>>; +} + +impl<T: EmbedderT> BatchEmbedder for T { + type Embedding = T::Embedding; + const NAME: &'static str = T::NAME; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { + let st = + ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + paths + .par_iter() + .progress_with_style(st) + .map(|p| self.embed(p)) + .collect::<Vec<_>>() + .into_iter() + .try_collect() + } +} diff --git a/src/embedders/pure.rs b/src/embedders/pure.rs new file mode 100644 index 0000000..09c8321 --- /dev/null +++ b/src/embedders/pure.rs @@ -0,0 +1,91 @@ +use anyhow::{bail, Result}; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +use crate::{EmbedderT, MetricElem}; + +pub(crate) struct BrightnessEmbedder; +impl EmbedderT for BrightnessEmbedder { + type Embedding = f64; + const NAME: &'static str = "Brightness"; + + fn embed(&self, path: &Path) -> Result<f64> { + let im = image::open(path)?; + let num_bytes = 3 * (im.height() * im.width()); + + if num_bytes == 0 { + bail!("Encountered NaN brightness, due to an empty image"); + } + + Ok(im.to_rgb8().iter().map(|e| *e as u64).sum::<u64>() as f64 / num_bytes as f64) + } +} + +#[repr(transparent)] +#[derive(Serialize, Deserialize)] +pub(crate) struct Hue(f64); +impl MetricElem for Hue { + fn dist(&self, b: &Hue) -> f64 { + let d = self.0.dist(&b.0); + d.min(6. - d) + } +} +pub(crate) struct HueEmbedder; +impl EmbedderT for HueEmbedder { + type Embedding = Hue; + const NAME: &'static str = "Hue"; + + fn embed(&self, path: &Path) -> Result<Hue> { + let im = image::open(path)?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] + }) + .map(|e| e as f64 / 255. / num_pixels as f64); + + let hue = if sr >= sg && sr >= sb { + (sg - sb) / (sr - sg.min(sb)) + } else if sg >= sb { + 2. + (sb - sr) / (sg - sr.min(sb)) + } else { + 4. + (sr - sg) / (sb - sr.min(sg)) + }; + + if hue.is_nan() { + bail!("Encountered NaN hue, possibly because of a colorless or empty image"); + } + + Ok(Hue(hue)) + } +} + +impl MetricElem for (f64, f64, f64) { + fn dist(&self, o: &(f64, f64, f64)) -> f64 { + let (dr, dg, db) = ((self.0 - o.0), (self.1 - o.1), (self.2 - o.2)); + (dr * dr + dg * dg + db * db).sqrt() + } +} +pub(crate) struct ColorEmbedder; +impl EmbedderT for ColorEmbedder { + type Embedding = (f64, f64, f64); + const NAME: &'static str = "Color"; + + fn embed(&self, path: &Path) -> Result<(f64, f64, f64)> { + let im = image::open(path)?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] + }) + .map(|e| e as f64 / num_pixels as f64); + + Ok((sr, sg, sb)) + } +} |