blob: 353222b477e2b3e3a0ca3b1cc6bc4312b2c829b7 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
pub mod ai;
pub mod pure;
pub(crate) use ai::*;
pub(crate) use pure::*;
use anyhow::Result;
use indicatif::{ParallelProgressIterator, ProgressStyle};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// A value that can be shared across threads, (de)serialized, and compared
/// against another value of the same type through a distance function.
pub trait MetricElem:
    Send + Sync + 'static + Serialize + for<'a> Deserialize<'a>
{
    /// Returns the distance between `self` and `other`.
    ///
    /// NOTE(review): callers presumably rely on this being non-negative and
    /// symmetric — confirm that every implementation upholds that.
    fn dist(&self, other: &Self) -> f64;
}
/// Scalar distance: the absolute difference between the two values.
impl MetricElem for f64 {
    fn dist(&self, other: &f64) -> f64 {
        let delta = *self - *other;
        delta.abs()
    }
}
/// Computes an embedding for one path at a time.
///
/// Implementors are `Send + Sync`, so a single instance can be used from
/// several threads at once through `&self`.
pub trait EmbedderT: Send + Sync {
    /// The embedding value produced for each path.
    type Embedding: MetricElem;
    /// Identifier for this embedder.
    const NAME: &'static str;
    /// Computes the embedding of the item at `path`.
    ///
    /// NOTE(review): the progress message in this module suggests `path` is
    /// an image file — confirm for each implementation.
    ///
    /// # Errors
    /// Returns whatever error the implementation reports for this path.
    fn embed(&self, path: &Path) -> Result<Self::Embedding>;
}
/// Computes embeddings for a whole slice of paths in one call.
///
/// NOTE(review): implementations are presumably expected to return exactly
/// one embedding per input path, in input order — confirm for each impl.
pub trait BatchEmbedder: Send + Sync {
    /// The embedding value produced for each path.
    type Embedding: MetricElem;
    /// Identifier for this embedder.
    const NAME: &'static str;
    /// Embeds every path in `paths`.
    ///
    /// Takes `&mut self`, so implementations may update internal state
    /// between batches.
    ///
    /// # Errors
    /// Returns an error if the batch could not be embedded.
    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>>;
}
/// Every single-item [`EmbedderT`] is also a [`BatchEmbedder`]: paths are
/// embedded in parallel with rayon while an indicatif progress bar tracks them.
impl<T: EmbedderT> BatchEmbedder for T {
    type Embedding = T::Embedding;
    const NAME: &'static str = T::NAME;

    /// Embeds every path in `paths` in parallel and returns the embeddings in
    /// the same order as the input.
    ///
    /// All paths are processed before any result is inspected, so one failure
    /// does not stop the remaining embeddings or the progress bar.
    ///
    /// # Errors
    /// Fails if the progress-bar template is invalid. Otherwise it returns the
    /// error of the first path, in input order, whose embedding failed.
    fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
        let style =
            ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
        // The parallel closure only needs shared access (`EmbedderT::embed` takes `&self`).
        let embedder: &T = self;
        // Gather every per-path result first; rayon's indexed collect keeps input order.
        let results: Vec<Result<Self::Embedding>> = paths
            .par_iter()
            .progress_with_style(style)
            .map(|path| embedder.embed(path))
            .collect();
        // The stable `FromIterator for Result` yields the first error in order.
        // It replaces the nightly-only `Iterator::try_collect`.
        results.into_iter().collect()
    }
}
|