aboutsummaryrefslogtreecommitdiff
path: root/src/embedders.rs
blob: 7257a27cced8a9815bad9c99c20a3117885a80e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
use anyhow::Result;
use rayon::prelude::*;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};

/// A value that can be serialized, shared across threads, and compared to
/// another value of the same type by a scalar distance.
///
/// Implementors must be `Send + Sync + 'static` and round-trip through serde
/// (`Serialize` plus an owned `Deserialize` for any lifetime).
///
/// NOTE(review): the name suggests `dist` should behave like a metric
/// (non-negative, symmetric, triangle inequality); nothing here enforces that.
pub trait MetricElem:
    Send + Sync + 'static + Serialize + for<'de> Deserialize<'de>
{
    /// Returns the distance between `self` and `other`.
    fn dist(&self, other: &Self) -> f64;
}

/// Scalar embeddings: the distance is the absolute difference of the values.
impl MetricElem for f64 {
    fn dist(&self, other: &f64) -> f64 {
        let delta = *self - *other;
        delta.abs()
    }
}

/// Produces an embedding for one file at a time.
///
/// `embed` takes `&self`, and implementors are `Send + Sync`, so one embedder
/// can be called concurrently from several threads. The blanket
/// `BatchEmbedder` implementation in this module relies on that.
pub trait EmbedderT: Send + Sync {
    /// The value produced for each embedded file.
    type Embedding: MetricElem;
    /// Static name of this embedder (how callers use it is not visible here).
    const NAME: &'static str;

    /// Computes the embedding of the file at `path`.
    ///
    /// # Errors
    /// Returns an error if embedding fails; the failure modes are defined by
    /// the implementor.
    fn embed(&self, path: &Path) -> Result<Self::Embedding>;
}

/// Produces embeddings for a whole slice of files in one call.
///
/// `embeds` takes `&mut self`, so an implementor is free to mutate internal
/// state while processing a batch.
pub trait BatchEmbedder: Send + Sync {
    /// The value produced for each embedded file.
    type Embedding: MetricElem;
    /// Static name of this embedder (how callers use it is not visible here).
    const NAME: &'static str;

    /// Computes embeddings for every file in `paths`.
    ///
    /// NOTE(review): the blanket impl for `EmbedderT` returns one embedding
    /// per path in input order; other implementors presumably should as well.
    ///
    /// # Errors
    /// Returns an error if the batch cannot be embedded.
    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>>;
}

/// Turns any single-file `EmbedderT` into a `BatchEmbedder` by embedding the
/// paths in parallel with rayon.
impl<T: EmbedderT> BatchEmbedder for T {
    type Embedding = T::Embedding;
    const NAME: &'static str = T::NAME;

    /// Embeds each path with `EmbedderT::embed`, in parallel.
    ///
    /// On success the returned vector holds one embedding per input path, in
    /// the same order as `paths`.
    ///
    /// # Errors
    /// Fails if any single `embed` call fails. Rayon may stop processing
    /// remaining paths once an error is observed. If several paths fail,
    /// which of those errors is returned is not deterministic.
    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
        // `embed` only needs `&self`, and `T: Sync`, so a shared reborrow can
        // be used from every worker thread.
        let embedder: &T = self;
        // Collect straight into `Result<Vec<_>>` rather than going through the
        // nightly-only `Iterator::try_collect`. Rayon's `Result` collector is
        // stable, keeps input order, and skips the intermediate
        // `Vec<Result<_>>`.
        paths
            .par_iter()
            .map(|path| embedder.embed(path))
            .collect()
    }
}