aboutsummaryrefslogtreecommitdiff
path: root/src/ai_embedders.rs
blob: 14e1d8c93eccf114d1623e6f35a37dbe014028ad (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::{io::{copy, Cursor}, path::PathBuf, process::Command};

use crate::{Config, BatchEmbedder, MetricElem};

/// A single image embedding: the raw vector of `f32` components.
///
/// Values of this type are deserialized from the JSON that `imgbeddings-api.py` prints on
/// stdout (see `ContentEmbedder::embeds`), and serialized back via the same serde derives.
#[derive(Serialize, Deserialize)]
#[repr(transparent)]
pub(crate) struct Imgbedding(
    // TODO: turn this into a fixed-size (const-length) array instead of a heap `Vec`.
    Vec<f32>,
);
impl MetricElem for Imgbedding {
    /// Euclidean (L2) distance between two embeddings.
    ///
    /// Each component is widened to `f64` before subtracting, squaring and summing. Summing
    /// many squared `f32` differences in `f32` silently drops low-order contributions, and the
    /// trait returns `f64` anyway.
    ///
    /// If the two vectors differ in length, only the common prefix is compared (`zip` stops at
    /// the shorter one).
    fn dist(&self, other: &Self) -> f64 {
        self.0
            .iter()
            .zip(&other.0)
            .map(|(&a, &b)| {
                let d = f64::from(a) - f64::from(b);
                d * d
            })
            .sum::<f64>()
            .sqrt()
    }
}

/// Batch embedder that computes image embeddings with the `imgbeddings` Python package.
///
/// It only borrows the application [`Config`]. The config's `base_dirs` are used to locate the
/// Python virtualenv and the temporary `imgbeddings-api.py` helper script.
pub(crate) struct ContentEmbedder<'a> {
    /// Borrowed configuration; provides `base_dirs` for the venv and helper-script locations.
    cfg: &'a Config,
}

impl<'a> ContentEmbedder<'a> {
    /// Creates an embedder that borrows `cfg` for as long as the embedder lives.
    pub(crate) fn new(cfg: &'a Config) -> Self {
        Self { cfg }
    }
}

impl Drop for ContentEmbedder<'_> {
    /// Best-effort cleanup of the `imgbeddings-api.py` helper script that `embeds` writes.
    ///
    /// Failures are ignored, both when resolving the script path and when removing the file,
    /// because a destructor has no way to report them.
    // NOTE(review): if `base_dirs` is `xdg::BaseDirectories`, `place_runtime_file` may create
    // the runtime directory as a side effect; a lookup-only variant might fit better — confirm.
    fn drop(&mut self) {
        if let Ok(script) = self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py") {
            let _ = std::fs::remove_file(script);
        }
    }
}

impl BatchEmbedder for ContentEmbedder<'_> {
    type Embedding = Imgbedding;
    const NAME: &'static str = "imgbeddings";

    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
        let venv_dir = self.cfg.base_dirs
            .create_data_directory("imgbeddings-venv")?;
        let script_file = self.cfg.base_dirs
            .place_runtime_file("imgbeddings-api.py")?;

        let api_prog = include_bytes!("imgbeddings-api.py");
        copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?;

        Command::new("python3")
            .args(["-m", "venv", venv_dir.to_str().unwrap()])
            .spawn()?
            .wait()?;
        Command::new(venv_dir.join("bin/pip3"))
            .args(["install", "imgbeddings"])
            .spawn()?
            .wait()?;

        let output = Command::new(venv_dir.join("bin/python3"))
            .arg(script_file)
            .args(paths)
            .output()?;

        Ok(serde_json::from_slice(&output.stdout)?)
    }
}