aboutsummaryrefslogtreecommitdiff
path: root/src/embedders/ai.rs
blob: 7d5ae9039043753f32208083e7d34eab5781ec6a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
use anyhow::Result;
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
use std::{
    fs::{remove_file, File},
    io::{copy, BufRead, BufReader, Cursor},
    marker::PhantomData,
    path::PathBuf,
    process::{Command, Stdio},
};

use super::vecmetric::VecMetric;
use crate::{BatchEmbedder, Config};

/// Image-content embedder backed by the Python `imgbeddings` package, which is run as a
/// subprocess inside a dedicated venv.
///
/// `Metric` is the embedding type this embedder produces (see `BatchEmbedder::Embedding`);
/// it is carried only as a type parameter, never stored.
pub(crate) struct ContentEmbedder<'a, Metric> {
    /// Shared configuration; its `base_dirs` locate the venv and the helper script.
    cfg: &'a Config,
    /// Marker tying the embedder to its `Metric` type without holding a value of it.
    _sim: PhantomData<Metric>,
}
impl<'a, Metric> ContentEmbedder<'a, Metric> {
    pub(crate) fn new(cfg: &'a Config) -> Self {
        ContentEmbedder {
            cfg,
            _sim: PhantomData,
        }
    }
}

impl<Metric> Drop for ContentEmbedder<'_, Metric> {
    /// Deletes the `imgbeddings-api.py` helper script from the runtime directory.
    ///
    /// This is best effort. Failing to resolve the runtime path, or failing to delete the file
    /// (for example because it was never written), is deliberately ignored.
    fn drop(&mut self) {
        if let Ok(script) = self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py") {
            let _ = remove_file(script);
        }
    }
}

impl<Metric: VecMetric> ContentEmbedder<'_, Metric> {
    /// Runs `cmd` to completion and turns a non-zero exit status into an error.
    ///
    /// `what` names the step for the error message (e.g. "creating the Python venv").
    fn run_checked(cmd: &mut Command, what: &str) -> Result<()> {
        let status = cmd.status()?;
        anyhow::ensure!(status.success(), "{what} failed ({status})");
        Ok(())
    }

    /// Embeds every image in `paths` by running the bundled `imgbeddings-api.py` script
    /// inside a dedicated Python venv.
    ///
    /// The steps are:
    /// 1. Write the script to the runtime directory.
    /// 2. Run `python3 -m venv` on the `imgbeddings-venv` data directory.
    /// 3. Run `pip3 install imgbeddings` inside that venv.
    /// 4. Run the script with all `paths` as arguments and decode each stdout line as one
    ///    JSON-encoded embedding.
    ///
    /// # Errors
    /// The outer `Err` covers setup failures: creating directories, writing the script,
    /// spawning a process, a non-zero exit from the venv or pip step, or failing to wait for
    /// the embedding process.
    ///
    /// Per-image problems are reported as inner `Err`s. These include an I/O error reading a
    /// line, a line that is not valid JSON for `Metric`, or no line at all for a path.
    fn embeds_or_err(
        &mut self,
        paths: &[PathBuf],
    ) -> Result<Vec<Result<<Self as BatchEmbedder>::Embedding>>> {
        let venv_dir = self
            .cfg
            .base_dirs
            .create_data_directory("imgbeddings-venv")?;
        let script_file = self
            .cfg
            .base_dirs
            .place_runtime_file("imgbeddings-api.py")?;

        let api_prog = include_bytes!("imgbeddings-api.py");
        copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?;

        let bar = ProgressBar::new_spinner();
        bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?);
        bar.enable_steady_tick(std::time::Duration::from_millis(100));

        bar.set_message("Creating venv...");
        // Pass the path as an OsStr so non-UTF-8 directories work instead of panicking.
        Self::run_checked(
            Command::new("python3")
                .args(["-m", "venv"])
                .arg(&venv_dir)
                .stdout(Stdio::null()),
            "creating the Python venv",
        )?;

        bar.set_message("Installing/checking packages...");
        Self::run_checked(
            Command::new(venv_dir.join("bin/pip3"))
                .args(["install", "imgbeddings"])
                .stdout(Stdio::null()),
            "installing the imgbeddings package",
        )?;
        bar.finish();

        let mut child = Command::new(venv_dir.join("bin/python3"))
            .arg(&script_file)
            .args(paths)
            .stderr(Stdio::inherit())
            .stdout(Stdio::piped())
            .spawn()?;
        let stdout = child
            .stdout
            .take()
            .expect("child stdout was configured as Stdio::piped()");

        let st =
            ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
        let bar = ProgressBar::new(paths.len() as u64).with_style(st);
        let mut embeddings: Vec<Result<Metric>> = BufReader::new(stdout)
            .lines()
            .progress_with(bar)
            .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?))
            .collect();

        // Reap the child so it does not linger as a zombie, and keep its status for errors.
        let status = child.wait()?;

        // The script is assumed to print one line per input path, in input order (the
        // progress bar length relies on this too). TODO: confirm this against
        // imgbeddings-api.py.
        //
        // If the script stopped early, fill the missing tail with errors so the result has one
        // entry per path. If the script skips an image in the middle, later results are still
        // shifted; fixing that requires the script to tag each output line with its path.
        let produced = embeddings.len();
        embeddings.extend(paths.iter().skip(produced).map(|p| {
            Err(anyhow::anyhow!(
                "no embedding produced for {} (embedding script finished with {status})",
                p.display()
            ))
        }));

        Ok(embeddings)
    }
}

impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
    type Embedding = Metric;
    const NAME: &'static str = "imgbeddings";

    /// Embeds `paths`, returning the per-image results from `embeds_or_err`.
    ///
    /// If setup itself fails, the single setup error is returned as a one-element vector.
    fn embeds(&mut self, paths: &[PathBuf]) -> Vec<Result<Self::Embedding>> {
        self.embeds_or_err(paths)
            .unwrap_or_else(|setup_err| vec![Err(setup_err)])
    }
}