diff options
Diffstat (limited to 'src/ai_embedders.rs')
-rw-r--r-- | src/ai_embedders.rs | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs new file mode 100644 index 0000000..14e1d8c --- /dev/null +++ b/src/ai_embedders.rs @@ -0,0 +1,65 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::{io::{copy, Cursor}, path::PathBuf, process::Command}; + +use crate::{Config, BatchEmbedder, MetricElem}; + +#[repr(transparent)] +#[derive(Serialize, Deserialize)] +pub(crate) struct Imgbedding (Vec<f32>); // TODO das hier zu einem const size slice machen +impl MetricElem for Imgbedding { + fn dist(&self, other: &Self) -> f64 { + self.0.iter().zip(other.0.iter()) + .map(|(a, b)| (a - b).powf(2.)) + .sum::<f32>() + .sqrt() as f64 + } +} + +pub(crate) struct ContentEmbedder<'a> { + cfg: &'a Config, +} +impl<'a> ContentEmbedder<'a> { + pub(crate) fn new(cfg: &'a Config) -> Self { + ContentEmbedder { cfg: cfg } + } +} + +impl<'a> Drop for ContentEmbedder<'a> { + fn drop(&mut self) { + self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py") + .iter() + .for_each(|p| { let _ = std::fs::remove_file(&p); }); + } +} + +impl BatchEmbedder for ContentEmbedder<'_> { + type Embedding = Imgbedding; + const NAME: &'static str = "imgbeddings"; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { + let venv_dir = self.cfg.base_dirs + .create_data_directory("imgbeddings-venv")?; + let script_file = self.cfg.base_dirs + .place_runtime_file("imgbeddings-api.py")?; + + let api_prog = include_bytes!("imgbeddings-api.py"); + copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?; + + Command::new("python3") + .args(["-m", "venv", venv_dir.to_str().unwrap()]) + .spawn()? + .wait()?; + Command::new(venv_dir.join("bin/pip3")) + .args(["install", "imgbeddings"]) + .spawn()? + .wait()?; + + let output = Command::new(venv_dir.join("bin/python3")) + .arg(script_file) + .args(paths) + .output()?; + + Ok(serde_json::from_slice(&output.stdout)?) + } +} |