diff options
Diffstat (limited to 'src/embedders/ai.rs')
-rw-r--r-- | src/embedders/ai.rs | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs new file mode 100644 index 0000000..3848674 --- /dev/null +++ b/src/embedders/ai.rs @@ -0,0 +1,100 @@ +use anyhow::Result; +use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; +use serde::{Deserialize, Serialize}; +use std::{ + fs::{remove_file, File}, + io::{copy, BufRead, BufReader, Cursor}, + path::PathBuf, + process::{Command, Stdio}, +}; + +use crate::{BatchEmbedder, Config, MetricElem}; + +#[repr(transparent)] +#[derive(Serialize, Deserialize)] +pub(crate) struct Imgbedding(Vec<f32>); +impl MetricElem for Imgbedding { + fn dist(&self, other: &Self) -> f64 { + self.0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| (a - b).powf(2.)) + .sum::<f32>() + .sqrt() as f64 + } +} + +pub(crate) struct ContentEmbedder<'a> { + cfg: &'a Config, +} +impl<'a> ContentEmbedder<'a> { + pub(crate) fn new(cfg: &'a Config) -> Self { + ContentEmbedder { cfg } + } +} + +impl<'a> Drop for ContentEmbedder<'a> { + fn drop(&mut self) { + self.cfg + .base_dirs + .place_runtime_file("imgbeddings-api.py") + .iter() + .for_each(|p| { + let _ = remove_file(p); + }); + } +} + +impl BatchEmbedder for ContentEmbedder<'_> { + type Embedding = Imgbedding; + const NAME: &'static str = "imgbeddings"; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { + let venv_dir = self + .cfg + .base_dirs + .create_data_directory("imgbeddings-venv")?; + let script_file = self + .cfg + .base_dirs + .place_runtime_file("imgbeddings-api.py")?; + + let api_prog = include_bytes!("imgbeddings-api.py"); + copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?; + + let bar = ProgressBar::new_spinner(); + bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?); + bar.enable_steady_tick(std::time::Duration::from_millis(100)); + + bar.set_message("Creating venv..."); + Command::new("python3") + .args(["-m", "venv", venv_dir.to_str().unwrap()]) + .stdout(Stdio::null()) + .spawn()? + .wait()?; + + bar.set_message("Installing/checking packages..."); + Command::new(venv_dir.join("bin/pip3")) + .args(["install", "imgbeddings"]) + .stdout(Stdio::null()) + .spawn()? + .wait()?; + bar.finish(); + + let child = Command::new(venv_dir.join("bin/python3")) + .arg(script_file) + .args(paths) + .stderr(Stdio::null()) + .stdout(Stdio::piped()) + .spawn()?; + + let st = + ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; + let bar = ProgressBar::new(paths.len() as u64).with_style(st); + BufReader::new(child.stdout.unwrap()) + .lines() + .progress_with(bar) + .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) + .try_collect() + } +} |