aboutsummaryrefslogtreecommitdiff
path: root/src/ai_embedders.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ai_embedders.rs')
-rw-r--r--src/ai_embedders.rs65
1 files changed, 65 insertions, 0 deletions
diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs
new file mode 100644
index 0000000..14e1d8c
--- /dev/null
+++ b/src/ai_embedders.rs
@@ -0,0 +1,65 @@
+use anyhow::Result;
+use serde::{Deserialize, Serialize};
+use std::{io::{copy, Cursor}, path::PathBuf, process::Command};
+
+use crate::{Config, BatchEmbedder, MetricElem};
+
+/// Embedding vector used as [`BatchEmbedder::Embedding`] by `ContentEmbedder`.
+///
+/// Newtype around the raw `f32` components so that [`MetricElem`] can be
+/// implemented for it.
+#[repr(transparent)]
+#[derive(Serialize, Deserialize)]
+pub(crate) struct Imgbedding(Vec<f32>); // TODO: turn this into a const-size slice
+
+impl MetricElem for Imgbedding {
+    /// Euclidean (L2) distance between `self` and `other`.
+    ///
+    /// Components are widened to `f64` *before* subtracting and summing, so
+    /// the accumulation does not suffer `f32` rounding on long vectors.
+    ///
+    /// If the two vectors differ in length, only the common prefix is
+    /// compared (`zip` stops at the shorter one).
+    fn dist(&self, other: &Self) -> f64 {
+        self.0
+            .iter()
+            .zip(&other.0)
+            .map(|(&a, &b)| {
+                let delta = f64::from(a) - f64::from(b);
+                delta * delta
+            })
+            .sum::<f64>()
+            .sqrt()
+    }
+}
+
+/// Embedder backed by the `imgbeddings` Python helper script.
+///
+/// Holds only a borrow of the application [`Config`]; everything it needs on
+/// disk is located through `cfg.base_dirs`.
+pub(crate) struct ContentEmbedder<'a> {
+    /// Configuration borrowed for the lifetime of the embedder.
+    cfg: &'a Config,
+}
+
+impl<'a> ContentEmbedder<'a> {
+    /// Creates an embedder that borrows `cfg`.
+    pub(crate) fn new(cfg: &'a Config) -> Self {
+        Self { cfg }
+    }
+}
+
+impl Drop for ContentEmbedder<'_> {
+    /// Best-effort cleanup of the `imgbeddings-api.py` runtime file.
+    ///
+    /// Both failure to resolve the runtime path and failure to delete the
+    /// file are deliberately ignored: `drop` cannot report errors.
+    fn drop(&mut self) {
+        let located = self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py");
+        if let Ok(script_path) = located {
+            let _ = std::fs::remove_file(&script_path);
+        }
+    }
+}
+
+impl BatchEmbedder for ContentEmbedder<'_> {
+    type Embedding = Imgbedding;
+    const NAME: &'static str = "imgbeddings";
+
+    /// Computes embeddings for `paths` by running the bundled Python script.
+    ///
+    /// On every call this
+    /// 1. writes the compiled-in `imgbeddings-api.py` to the runtime file,
+    /// 2. runs `python3 -m venv` on the `imgbeddings-venv` data directory,
+    /// 3. runs `pip3 install imgbeddings` inside that venv,
+    /// 4. runs the script with `paths` as arguments and parses its stdout
+    ///    as JSON.
+    ///
+    /// The result is whatever the script prints (presumably one embedding
+    /// per input path, in order; confirm against `imgbeddings-api.py`).
+    ///
+    /// # Errors
+    ///
+    /// Fails if a directory or file cannot be created, if any of the three
+    /// child processes cannot be started or exits unsuccessfully, or if the
+    /// script's stdout is not valid JSON for `Vec<Imgbedding>`.
+    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
+        let venv_dir = self.cfg.base_dirs
+            .create_data_directory("imgbeddings-venv")?;
+        let script_file = self.cfg.base_dirs
+            .place_runtime_file("imgbeddings-api.py")?;
+
+        // The helper script is compiled into the binary and materialized on
+        // disk so the venv's interpreter can execute it.
+        let api_prog = include_bytes!("imgbeddings-api.py");
+        copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?;
+
+        // Pass the directory as an OS string: converting it to `&str` would
+        // panic on a non-UTF-8 path.
+        run_setup_step(
+            Command::new("python3").arg("-m").arg("venv").arg(&venv_dir),
+            "creating the imgbeddings virtual environment",
+        )?;
+        run_setup_step(
+            Command::new(venv_dir.join("bin/pip3")).args(["install", "imgbeddings"]),
+            "installing the imgbeddings package",
+        )?;
+
+        let output = Command::new(venv_dir.join("bin/python3"))
+            .arg(&script_file)
+            .args(paths)
+            .output()?;
+        // Without this check a crashed script would only show up as an
+        // opaque JSON parse error on (empty) stdout.
+        anyhow::ensure!(
+            output.status.success(),
+            "imgbeddings-api.py failed ({}): {}",
+            output.status,
+            String::from_utf8_lossy(&output.stderr).trim_end()
+        );
+
+        Ok(serde_json::from_slice(&output.stdout)?)
+    }
+}
+
+/// Runs `cmd` with inherited stdio, waits for it, and reports an
+/// unsuccessful exit status as an error mentioning `what`.
+fn run_setup_step(cmd: &mut Command, what: &str) -> Result<()> {
+    let status = cmd.status()?;
+    anyhow::ensure!(status.success(), "{} failed ({})", what, status);
+    Ok(())
+}