aboutsummaryrefslogtreecommitdiff
path: root/src/embedders/ai.rs
blob: 7d31a6b91a76cabcc4051ec4a10cbc54a15c128f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use anyhow::{anyhow, Result};
use fastembed::{ImageEmbedding, ImageInitOptions};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{marker::PhantomData, path::PathBuf};

use super::vecmetric::VecMetric;
use crate::{BatchEmbedder, Config};

/// Image embedder backed by `fastembed`'s `ImageEmbedding` model.
///
/// `Metric` is the embedding type handed back to callers; the `BatchEmbedder`
/// implementation for this struct exposes it as its `Embedding` associated type.
pub(crate) struct ContentEmbedder<'a, Metric> {
    /// Borrowed application configuration; its `cache_dir` is where the model
    /// files are cached (under a `models` sub-directory).
    cfg: &'a Config,
    /// Zero-sized marker tying the struct to `Metric` without storing a value.
    _sim: PhantomData<Metric>,
}
impl<'a, Metric> ContentEmbedder<'a, Metric> {
    pub(crate) fn new(cfg: &'a Config) -> Self {
        ContentEmbedder {
            cfg,
            _sim: PhantomData,
        }
    }
}

impl<Metric: VecMetric> ContentEmbedder<'_, Metric> {
    /// Embeds every image in `paths`, showing a progress bar while doing so.
    ///
    /// # Returns
    ///
    /// On success, a vector holding one entry per input path, in input order:
    /// `Ok(embedding)` for images that could be embedded and `Err(..)` for those
    /// that could not.
    ///
    /// # Errors
    ///
    /// The outer `Result` is `Err` when the embedding model cannot be
    /// initialised or the progress-bar template is rejected; in that case no
    /// image has been processed.
    fn embeds_or_err(
        &mut self,
        paths: &[PathBuf],
    ) -> Result<Vec<Result<<Self as BatchEmbedder>::Embedding>>> {
        /// Number of images handed to the embedder in one call.
        const CHUNK_SIZE: usize = 64;
        /// Batch size passed to the embedder for a whole chunk.
        const BATCH_SIZE: usize = 8;

        let mut options = ImageInitOptions::default();
        options.cache_dir = self.cfg.cache_dir.join("models");
        let embedder = ImageEmbedding::try_new(options)?;

        let bar = ProgressBar::new(paths.len() as u64);
        bar.set_style(ProgressStyle::with_template(
            "{bar:20.cyan/blue} {pos}/{len} {msg}",
        )?);
        bar.enable_steady_tick(std::time::Duration::from_millis(100));
        bar.set_message("Embedding images...");

        let mut res = Vec::with_capacity(paths.len());

        // fastembed supports batched processing, but does not support error reporting on a
        // per-image basis. Thus, we first try embedding a whole chunk at once, and if that
        // fails, fall back to passing the images to fastembed one-by-one, so that we can get
        // all the non-failure results.
        for chunk in paths.chunks(CHUNK_SIZE) {
            match embedder.embed(chunk.iter().collect(), Some(BATCH_SIZE)) {
                // Only trust the batch result when it has exactly one embedding per input;
                // otherwise the results could no longer be matched to their paths.
                Ok(embeds) if embeds.len() == chunk.len() => {
                    res.extend(embeds.into_iter().map(|e| Ok(e.into())));
                }
                // The batch failed (or returned an unexpected number of embeddings):
                // embed one by one to get a result for every single image.
                _ => {
                    let mut embeds = chunk
                        .par_iter()
                        .map(|path| match embedder.embed(vec![path], Some(1)) {
                            Err(e) => Err(e),
                            Ok(mut embed) if embed.len() == 1 => Ok(embed.pop().unwrap().into()),
                            Ok(embed) => {
                                Err(anyhow!("Embedder did not return a single value: {embed:?}"))
                            }
                        })
                        .collect();

                    res.append(&mut embeds);
                }
            }
            // The last chunk may be shorter than CHUNK_SIZE; advance by what was
            // actually processed so the position never overshoots the total.
            bar.inc(chunk.len() as u64);
        }

        Ok(res)
    }
}

impl<Metric: VecMetric> BatchEmbedder for ContentEmbedder<'_, Metric> {
    type Embedding = Metric;
    const NAME: &'static str = "imgbeddings";

    /// Embeds the images at `paths`.
    ///
    /// Delegates to `embeds_or_err`. If that call fails as a whole, the failure
    /// is reported as a vector containing just that one error, whatever the
    /// length of `paths`.
    fn embeds(&mut self, paths: &[PathBuf]) -> Vec<Result<Self::Embedding>> {
        self.embeds_or_err(paths)
            .unwrap_or_else(|setup_err| vec![Err(setup_err)])
    }
}