From fbfee0a2bb436a6205d67f561dbd6284621504d6 Mon Sep 17 00:00:00 2001 From: metamuffin Date: Wed, 20 Sep 2023 16:55:50 +0200 Subject: move embedder to module --- src/ai_embedders.rs | 100 ---------------------------------------------------- 1 file changed, 100 deletions(-) delete mode 100644 src/ai_embedders.rs (limited to 'src/ai_embedders.rs') diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs deleted file mode 100644 index 3848674..0000000 --- a/src/ai_embedders.rs +++ /dev/null @@ -1,100 +0,0 @@ -use anyhow::Result; -use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; -use serde::{Deserialize, Serialize}; -use std::{ - fs::{remove_file, File}, - io::{copy, BufRead, BufReader, Cursor}, - path::PathBuf, - process::{Command, Stdio}, -}; - -use crate::{BatchEmbedder, Config, MetricElem}; - -#[repr(transparent)] -#[derive(Serialize, Deserialize)] -pub(crate) struct Imgbedding(Vec); -impl MetricElem for Imgbedding { - fn dist(&self, other: &Self) -> f64 { - self.0 - .iter() - .zip(other.0.iter()) - .map(|(a, b)| (a - b).powf(2.)) - .sum::() - .sqrt() as f64 - } -} - -pub(crate) struct ContentEmbedder<'a> { - cfg: &'a Config, -} -impl<'a> ContentEmbedder<'a> { - pub(crate) fn new(cfg: &'a Config) -> Self { - ContentEmbedder { cfg } - } -} - -impl<'a> Drop for ContentEmbedder<'a> { - fn drop(&mut self) { - self.cfg - .base_dirs - .place_runtime_file("imgbeddings-api.py") - .iter() - .for_each(|p| { - let _ = remove_file(p); - }); - } -} - -impl BatchEmbedder for ContentEmbedder<'_> { - type Embedding = Imgbedding; - const NAME: &'static str = "imgbeddings"; - - fn embeds(&mut self, paths: &[PathBuf]) -> Result> { - let venv_dir = self - .cfg - .base_dirs - .create_data_directory("imgbeddings-venv")?; - let script_file = self - .cfg - .base_dirs - .place_runtime_file("imgbeddings-api.py")?; - - let api_prog = include_bytes!("imgbeddings-api.py"); - copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?; - - let bar = ProgressBar::new_spinner(); - bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?); - bar.enable_steady_tick(std::time::Duration::from_millis(100)); - - bar.set_message("Creating venv..."); - Command::new("python3") - .args(["-m", "venv", venv_dir.to_str().unwrap()]) - .stdout(Stdio::null()) - .spawn()? - .wait()?; - - bar.set_message("Installing/checking packages..."); - Command::new(venv_dir.join("bin/pip3")) - .args(["install", "imgbeddings"]) - .stdout(Stdio::null()) - .spawn()? - .wait()?; - bar.finish(); - - let child = Command::new(venv_dir.join("bin/python3")) - .arg(script_file) - .args(paths) - .stderr(Stdio::null()) - .stdout(Stdio::piped()) - .spawn()?; - - let st = - ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; - let bar = ProgressBar::new(paths.len() as u64).with_style(st); - BufReader::new(child.stdout.unwrap()) - .lines() - .progress_with(bar) - .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) - .try_collect() - } -} -- cgit v1.2.3-70-g09d2