From c4b03717914e5c907f7f47dc2a85df6b57763c58 Mon Sep 17 00:00:00 2001 From: Lia Lenckowski Date: Thu, 7 Sep 2023 01:37:34 +0200 Subject: add content embedder --- src/ai_embedders.rs | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 src/ai_embedders.rs (limited to 'src/ai_embedders.rs') diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs new file mode 100644 index 0000000..14e1d8c --- /dev/null +++ b/src/ai_embedders.rs @@ -0,0 +1,65 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::{io::{copy, Cursor}, path::PathBuf, process::Command}; + +use crate::{Config, BatchEmbedder, MetricElem}; + +#[repr(transparent)] +#[derive(Serialize, Deserialize)] +pub(crate) struct Imgbedding (Vec); // TODO das hier zu einem const size slice machen +impl MetricElem for Imgbedding { + fn dist(&self, other: &Self) -> f64 { + self.0.iter().zip(other.0.iter()) + .map(|(a, b)| (a - b).powf(2.)) + .sum::() + .sqrt() as f64 + } +} + +pub(crate) struct ContentEmbedder<'a> { + cfg: &'a Config, +} +impl<'a> ContentEmbedder<'a> { + pub(crate) fn new(cfg: &'a Config) -> Self { + ContentEmbedder { cfg: cfg } + } +} + +impl<'a> Drop for ContentEmbedder<'a> { + fn drop(&mut self) { + self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py") + .iter() + .for_each(|p| { let _ = std::fs::remove_file(&p); }); + } +} + +impl BatchEmbedder for ContentEmbedder<'_> { + type Embedding = Imgbedding; + const NAME: &'static str = "imgbeddings"; + + fn embeds(&mut self, paths: &[PathBuf]) -> Result> { + let venv_dir = self.cfg.base_dirs + .create_data_directory("imgbeddings-venv")?; + let script_file = self.cfg.base_dirs + .place_runtime_file("imgbeddings-api.py")?; + + let api_prog = include_bytes!("imgbeddings-api.py"); + copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?; + + Command::new("python3") + .args(["-m", "venv", venv_dir.to_str().unwrap()]) + .spawn()? + .wait()?; + Command::new(venv_dir.join("bin/pip3")) + .args(["install", "imgbeddings"]) + .spawn()? + .wait()?; + + let output = Command::new(venv_dir.join("bin/python3")) + .arg(script_file) + .args(paths) + .output()?; + + Ok(serde_json::from_slice(&output.stdout)?) + } +} -- cgit v1.2.3-70-g09d2