diff options
author | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-07 01:37:34 +0200 |
---|---|---|
committer | Lia Lenckowski <lialenck@protonmail.com> | 2023-09-07 01:37:34 +0200 |
commit | c4b03717914e5c907f7f47dc2a85df6b57763c58 (patch) | |
tree | 926acc5e5497a0539ca18a936fc38029561e7417 | |
parent | 51007f5b8ff6d5960ac034854ceae1ab15237b6a (diff) | |
download | embeddings-sort-c4b03717914e5c907f7f47dc2a85df6b57763c58.tar embeddings-sort-c4b03717914e5c907f7f47dc2a85df6b57763c58.tar.bz2 embeddings-sort-c4b03717914e5c907f7f47dc2a85df6b57763c58.tar.zst |
add content embedder
-rw-r--r-- | Cargo.lock | 24 | ||||
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | src/ai_embedders.rs | 65 | ||||
-rw-r--r-- | src/embedders.rs | 96 | ||||
-rw-r--r-- | src/imgbeddings-api.py | 15 | ||||
-rw-r--r-- | src/main.rs | 21 | ||||
-rw-r--r-- | src/pure_embedders.rs | 98 | ||||
-rw-r--r-- | test.py | 11 |
8 files changed, 218 insertions, 113 deletions
@@ -291,6 +291,7 @@ dependencies = [ "priority-queue", "rayon", "serde", + "serde_json", "sha2", "sled", "typed-sled", @@ -494,6 +495,12 @@ dependencies = [ ] [[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + +[[package]] name = "jpeg-decoder" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -758,6 +765,12 @@ dependencies = [ ] [[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -784,6 +797,17 @@ dependencies = [ ] [[package]] +name = "serde_json" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] name = "sha2" version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -15,5 +15,6 @@ indicatif = "0" sled = "0" typed-sled = "0" serde = "1" +serde_json = "1" sha2 = "0" anyhow = "1" diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs new file mode 100644 index 0000000..14e1d8c --- /dev/null +++ b/src/ai_embedders.rs @@ -0,0 +1,65 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::{io::{copy, Cursor}, path::PathBuf, process::Command}; + +use crate::{Config, BatchEmbedder, MetricElem}; + +#[repr(transparent)] +#[derive(Serialize, Deserialize)] +pub(crate) struct Imgbedding (Vec<f32>); // TODO das hier zu einem const size slice machen +impl MetricElem for Imgbedding { + fn dist(&self, other: &Self) -> f64 { + self.0.iter().zip(other.0.iter()) + .map(|(a, b)| (a - b).powf(2.)) + .sum::<f32>() + .sqrt() as f64 + } +} + +pub(crate) struct ContentEmbedder<'a> { + cfg: &'a Config, +} +impl<'a> ContentEmbedder<'a> { + pub(crate) fn new(cfg: &'a Config) -> Self { + ContentEmbedder { cfg: cfg } + } +} + +impl<'a> Drop for ContentEmbedder<'a> { + fn drop(&mut self) { + self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py") + .iter() + .for_each(|p| { let _ = std::fs::remove_file(&p); }); + } +} + +impl BatchEmbedder for ContentEmbedder<'_> { + type Embedding = Imgbedding; + const NAME: &'static str = "imgbeddings"; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { + let venv_dir = self.cfg.base_dirs + .create_data_directory("imgbeddings-venv")?; + let script_file = self.cfg.base_dirs + .place_runtime_file("imgbeddings-api.py")?; + + let api_prog = include_bytes!("imgbeddings-api.py"); + copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?; + + Command::new("python3") + .args(["-m", "venv", venv_dir.to_str().unwrap()]) + .spawn()? + .wait()?; + Command::new(venv_dir.join("bin/pip3")) + .args(["install", "imgbeddings"]) + .spawn()? + .wait()?; + + let output = Command::new(venv_dir.join("bin/python3")) + .arg(script_file) + .args(paths) + .output()?; + + Ok(serde_json::from_slice(&output.stdout)?) + } +} diff --git a/src/embedders.rs b/src/embedders.rs index 8911e95..7257a27 100644 --- a/src/embedders.rs +++ b/src/embedders.rs @@ -1,4 +1,4 @@ -use anyhow::{bail, Result}; +use anyhow::Result; use rayon::prelude::*; use std::path::{Path, PathBuf}; use serde::{Deserialize, Serialize}; @@ -40,97 +40,3 @@ impl<T: EmbedderT> BatchEmbedder for T { } } -pub struct BrightnessEmbedder; -impl EmbedderT for BrightnessEmbedder { - type Embedding = f64; - const NAME: &'static str = "Brightness"; - - fn embed(&self, path: &Path) -> Result<f64> { - let im = image::open(path)?; - let num_bytes = 3 * (im.height() * im.width()); - - if num_bytes == 0 { - bail!("Encountered NaN brightness, due to an empty image"); - } - - Ok(im.to_rgb8() - .iter() - .map(|e| *e as u64) - .sum::<u64>() as f64 / num_bytes as f64) - } -} - -#[repr(transparent)] -#[derive(Serialize, Deserialize)] -pub struct Hue (f64); -impl MetricElem for Hue { - fn dist(&self, b: &Hue) -> f64 { - let d = self.0.dist(&b.0); - d.min(6.-d) - } -} - -pub struct HueEmbedder; -impl EmbedderT for HueEmbedder { - type Embedding = Hue; - const NAME: &'static str = "Hue"; - - fn embed(&self, path: &Path) -> Result<Hue> { - let im = image::open(path)?; - let num_pixels = im.height() * im.width(); - let [sr, sg, sb] = im - .to_rgb8() - .pixels() - .fold([0, 0, 0], |[or, og, ob], n| { - let [nr, ng, nb] = n.0; - [or + nr as u64, og + ng as u64, ob + nb as u64] - }) - .map(|e| e as f64 / 255. / num_pixels as f64); - - let hue = - if sr >= sg && sr >= sb { - (sg - sb) / (sr - sg.min(sb)) - } - else if sg >= sb { - 2. + (sb - sr) / (sg - sr.min(sb)) - } - else { - 4. + (sr - sg) / (sb - sr.min(sg)) - }; - - if hue.is_nan() { - bail!("Encountered NaN hue, possibly because of a colorless or empty image"); - } - - Ok(Hue(hue)) - } -} - -impl MetricElem for (f64, f64, f64) { - fn dist(&self, o: &(f64, f64, f64)) -> f64 { - let (dr, dg, db) = - ((self.0 - o.0), (self.1 - o.1), (self.2 - o.2)); - (dr*dr + dg*dg + db*db).sqrt() - } -} - -pub struct ColorEmbedder; -impl EmbedderT for ColorEmbedder { - type Embedding = (f64, f64, f64); - const NAME: &'static str = "Color"; - - fn embed(&self, path: &Path) -> Result<(f64, f64, f64)> { - let im = image::open(path)?; - let num_pixels = im.height() * im.width(); - let [sr, sg, sb] = im - .to_rgb8() - .pixels() - .fold([0, 0, 0], |[or, og, ob], n| { - let [nr, ng, nb] = n.0; - [or + nr as u64, og + ng as u64, ob + nb as u64] - }) - .map(|e| e as f64 / num_pixels as f64); - - Ok((sr, sg, sb)) - } -} diff --git a/src/imgbeddings-api.py b/src/imgbeddings-api.py new file mode 100644 index 0000000..795c625 --- /dev/null +++ b/src/imgbeddings-api.py @@ -0,0 +1,15 @@ +from PIL import Image +from imgbeddings import imgbeddings +import sys +import json as j +#from itertools import batched # TODO das hier ab python 3.12 + +b = imgbeddings() + +ems = [] + +for f in sys.argv[1:]: # TODO this should be batched for faster ai stuff + im = Image.open(open(f, "rb")) + ems += [b.to_embeddings(im)] + +print(j.dumps([em[0].tolist() for em in ems])) diff --git a/src/main.rs b/src/main.rs index 5caa4ce..6524de7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,14 +7,18 @@ use sha2::{Sha512_256, Digest}; use std::{cmp::Ordering, collections::HashMap, fs, io, path::PathBuf}; use embedders::*; - +use pure_embedders::*; +use ai_embedders::*; mod embedders; +mod pure_embedders; +mod ai_embedders; #[derive(Debug, Clone, Copy, clap::ValueEnum)] enum Embedder { Brightness, Hue, Color, + Content, } #[derive(Debug, Parser)] @@ -33,7 +37,7 @@ struct Config { fn get_config() -> Result<Config> { let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort")?; - Ok(Config{base_dirs: dirs}) + Ok(Config { base_dirs: dirs }) } fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>> @@ -119,8 +123,8 @@ fn hash_file(p: &PathBuf) -> Result<[u8; 32]> { Ok(hasher.finalize().into_iter().collect::<Vec<u8>>().try_into().unwrap()) } -fn process_embedder<E>(mut e: E, args: Args, cfg: Config) -> Result<Vec<PathBuf>> - where E: EmbedderT +fn process_embedder<E>(mut e: E, args: Args, cfg: &Config) -> Result<Vec<PathBuf>> + where E: BatchEmbedder { if args.images.is_empty() { return Ok(Vec::new()); @@ -145,6 +149,8 @@ fn process_embedder<E>(mut e: E, args: Args, cfg: Config) -> Result<Vec<PathBuf> None => Some(i), Some(_) => None, }).collect(); + // TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows + // for optimizations in the ai embedde (move pip to ::embeds() instead of ::new()) let missing_embeds = e.embeds(&missing_embeds_indices .iter() .map(|i| args.images[*i].clone()) @@ -168,9 +174,10 @@ fn main() -> Result<()> { let args = Args::parse(); let tsp_path = match args.embedder { - Embedder::Brightness => process_embedder(BrightnessEmbedder, args, cfg), - Embedder::Hue => process_embedder(HueEmbedder, args, cfg), - Embedder::Color => process_embedder(ColorEmbedder, args, cfg), + Embedder::Brightness => process_embedder(BrightnessEmbedder, args, &cfg), + Embedder::Hue => process_embedder(HueEmbedder, args, &cfg), + Embedder::Color => process_embedder(ColorEmbedder, args, &cfg), + Embedder::Content => process_embedder(ContentEmbedder::new(&cfg), args, &cfg), }?; for p in tsp_path { diff --git a/src/pure_embedders.rs b/src/pure_embedders.rs new file mode 100644 index 0000000..0f0c3ab --- /dev/null +++ b/src/pure_embedders.rs @@ -0,0 +1,98 @@ +use anyhow::{bail, Result}; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +use crate::{MetricElem, EmbedderT}; + +pub(crate) struct BrightnessEmbedder; +impl EmbedderT for BrightnessEmbedder { + type Embedding = f64; + const NAME: &'static str = "Brightness"; + + fn embed(&self, path: &Path) -> Result<f64> { + let im = image::open(path)?; + let num_bytes = 3 * (im.height() * im.width()); + + if num_bytes == 0 { + bail!("Encountered NaN brightness, due to an empty image"); + } + + Ok(im.to_rgb8() + .iter() + .map(|e| *e as u64) + .sum::<u64>() as f64 / num_bytes as f64) + } +} + +#[repr(transparent)] +#[derive(Serialize, Deserialize)] +pub(crate) struct Hue (f64); +impl MetricElem for Hue { + fn dist(&self, b: &Hue) -> f64 { + let d = self.0.dist(&b.0); + d.min(6.-d) + } +} +pub(crate) struct HueEmbedder; +impl EmbedderT for HueEmbedder { + type Embedding = Hue; + const NAME: &'static str = "Hue"; + + fn embed(&self, path: &Path) -> Result<Hue> { + let im = image::open(path)?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] + }) + .map(|e| e as f64 / 255. / num_pixels as f64); + + let hue = + if sr >= sg && sr >= sb { + (sg - sb) / (sr - sg.min(sb)) + } + else if sg >= sb { + 2. + (sb - sr) / (sg - sr.min(sb)) + } + else { + 4. + (sr - sg) / (sb - sr.min(sg)) + }; + + if hue.is_nan() { + bail!("Encountered NaN hue, possibly because of a colorless or empty image"); + } + + Ok(Hue(hue)) + } +} + +impl MetricElem for (f64, f64, f64) { + fn dist(&self, o: &(f64, f64, f64)) -> f64 { + let (dr, dg, db) = + ((self.0 - o.0), (self.1 - o.1), (self.2 - o.2)); + (dr*dr + dg*dg + db*db).sqrt() + } +} +pub(crate) struct ColorEmbedder; +impl EmbedderT for ColorEmbedder { + type Embedding = (f64, f64, f64); + const NAME: &'static str = "Color"; + + fn embed(&self, path: &Path) -> Result<(f64, f64, f64)> { + let im = image::open(path)?; + let num_pixels = im.height() * im.width(); + let [sr, sg, sb] = im + .to_rgb8() + .pixels() + .fold([0, 0, 0], |[or, og, ob], n| { + let [nr, ng, nb] = n.0; + [or + nr as u64, og + ng as u64, ob + nb as u64] + }) + .map(|e| e as f64 / num_pixels as f64); + + Ok((sr, sg, sb)) + } +} diff --git a/test.py b/test.py deleted file mode 100644 index 317eaf0..0000000 --- a/test.py +++ /dev/null @@ -1,11 +0,0 @@ -from PIL import Image -from imgbeddings import imgbeddings -import sys - -b = imgbeddings() -for p in sys.argv[1:]: - f = open(p, "rb") - im = Image.open(f) - em = b.to_embeddings(im) - print(em) -print(f"embedded {len(sys.argv[1:])} images.") |