use anyhow::Result; use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; use serde::{Deserialize, Serialize}; use std::{io::{BufRead, BufReader, copy, Cursor}, path::PathBuf, process::{Command, Stdio}}; use crate::{Config, BatchEmbedder, MetricElem}; #[repr(transparent)] #[derive(Serialize, Deserialize)] pub(crate) struct Imgbedding (Vec); // TODO das hier zu einem const size slice machen impl MetricElem for Imgbedding { fn dist(&self, other: &Self) -> f64 { self.0.iter().zip(other.0.iter()) .map(|(a, b)| (a - b).powf(2.)) .sum::() .sqrt() as f64 } } pub(crate) struct ContentEmbedder<'a> { cfg: &'a Config, } impl<'a> ContentEmbedder<'a> { pub(crate) fn new(cfg: &'a Config) -> Self { ContentEmbedder { cfg: cfg } } } impl<'a> Drop for ContentEmbedder<'a> { fn drop(&mut self) { self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py") .iter() .for_each(|p| { let _ = std::fs::remove_file(&p); }); } } impl BatchEmbedder for ContentEmbedder<'_> { type Embedding = Imgbedding; const NAME: &'static str = "imgbeddings"; fn embeds(&mut self, paths: &[PathBuf]) -> Result> { let venv_dir = self.cfg.base_dirs .create_data_directory("imgbeddings-venv")?; let script_file = self.cfg.base_dirs .place_runtime_file("imgbeddings-api.py")?; let api_prog = include_bytes!("imgbeddings-api.py"); copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?; let bar = ProgressBar::new_spinner(); bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?); bar.enable_steady_tick(std::time::Duration::from_millis(100)); bar.set_message("Creating venv..."); Command::new("python3") .args(["-m", "venv", venv_dir.to_str().unwrap()]) .stdout(Stdio::null()) .spawn()? .wait()?; bar.set_message("Installing/checking packages..."); Command::new(venv_dir.join("bin/pip3")) .args(["install", "imgbeddings"]) .stdout(Stdio::null()) .spawn()? .wait()?; bar.finish(); let child = Command::new(venv_dir.join("bin/python3")) .arg(script_file) .args(paths) .stderr(Stdio::null()) .stdout(Stdio::piped()) .spawn()?; Ok(BufReader::new(child.stdout.unwrap()) .lines() .progress_count(paths.len().try_into().unwrap()) .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) .try_collect()?) } }