aboutsummaryrefslogtreecommitdiff
path: root/src/embedders.rs
blob: a14a8cfe834fbb1f6e9f4ef265480733bf6a883a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
use anyhow::{Context, Result};
use indicatif::{ParallelProgressIterator, ProgressStyle};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};

/// An embedding value that can be compared against another of the same type.
///
/// Implementors must be (de)serializable with serde, safe to send and share across
/// threads, and must not borrow non-`'static` data.
pub trait MetricElem
where
    Self: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
{
    /// Returns the distance between `self` and `other` as an `f64`.
    fn dist(&self, other: &Self) -> f64;
}

/// Distance between two scalars: the absolute value of their difference.
impl MetricElem for f64 {
    fn dist(&self, other: &f64) -> f64 {
        let delta = *self - *other;
        delta.abs()
    }
}

/// Computes an embedding for one path at a time.
///
/// Implementors must be `Send + Sync`. Every implementor also gets a [`BatchEmbedder`]
/// implementation through the blanket impl in this module, which calls [`EmbedderT::embed`]
/// on many paths in parallel.
pub trait EmbedderT: Send + Sync {
    /// A static name for this embedder; how it is used is up to callers.
    const NAME: &'static str;
    /// The embedding type produced, comparable via [`MetricElem::dist`].
    type Embedding: MetricElem;

    /// Computes the embedding for the item at `path`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports for `path`.
    fn embed(&self, path: &Path) -> Result<Self::Embedding>;
}

/// Computes embeddings for a whole slice of paths in one call.
///
/// Taking `&mut self` lets implementations mutate internal state while embedding.
pub trait BatchEmbedder: Send + Sync {
    /// A static name for this embedder; how it is used is up to callers.
    const NAME: &'static str;
    /// The embedding type produced, comparable via [`MetricElem::dist`].
    type Embedding: MetricElem;

    /// Embeds every entry of `paths`, returning one embedding per input path.
    ///
    /// # Errors
    ///
    /// Returns an error if the implementation fails to embed the batch.
    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>>;
}

/// Every single-item [`EmbedderT`] is also a [`BatchEmbedder`].
///
/// The batch is processed in parallel on rayon's thread pool, with an indicatif progress bar.
impl<T: EmbedderT> BatchEmbedder for T {
    type Embedding = T::Embedding;
    const NAME: &'static str = T::NAME;

    /// Embeds every path in `paths` in parallel.
    ///
    /// Embeddings are returned in the same order as `paths`.
    ///
    /// # Errors
    ///
    /// Returns an error if the progress-bar template fails to parse, or if embedding any
    /// path fails. A per-path error is annotated with that path. If several paths fail,
    /// which failure is reported is unspecified. Remaining paths may be skipped once a
    /// failure has been observed.
    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
        let st =
            ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
        paths
            .par_iter()
            .progress_with_style(st)
            .map(|p| {
                self.embed(p)
                    .with_context(|| format!("failed to embed {}", p.display()))
            })
            // rayon collects `Result`s directly into `Result<Vec<_>>`. This is stable Rust,
            // unlike `Iterator::try_collect`, needs no intermediate Vec, preserves input
            // order, and stops scheduling new work after the first error.
            .collect()
    }
}