aboutsummaryrefslogtreecommitdiff
path: root/src/ai_embedders.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ai_embedders.rs')
-rw-r--r-- src/ai_embedders.rs 40
1 files changed, 29 insertions, 11 deletions
diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs
index e8e6c8b..3848674 100644
--- a/src/ai_embedders.rs
+++ b/src/ai_embedders.rs
@@ -1,16 +1,23 @@
use anyhow::Result;
use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
use serde::{Deserialize, Serialize};
-use std::{io::{BufRead, BufReader, copy, Cursor}, path::PathBuf, process::{Command, Stdio}};
+use std::{
+ fs::{remove_file, File},
+ io::{copy, BufRead, BufReader, Cursor},
+ path::PathBuf,
+ process::{Command, Stdio},
+};
-use crate::{Config, BatchEmbedder, MetricElem};
+use crate::{BatchEmbedder, Config, MetricElem};
#[repr(transparent)]
#[derive(Serialize, Deserialize)]
-pub(crate) struct Imgbedding (Vec<f32>);
+pub(crate) struct Imgbedding(Vec<f32>);
impl MetricElem for Imgbedding {
fn dist(&self, other: &Self) -> f64 {
- self.0.iter().zip(other.0.iter())
+ self.0
+ .iter()
+ .zip(other.0.iter())
.map(|(a, b)| (a - b).powf(2.))
.sum::<f32>()
.sqrt() as f64
@@ -22,15 +29,19 @@ pub(crate) struct ContentEmbedder<'a> {
}
impl<'a> ContentEmbedder<'a> {
pub(crate) fn new(cfg: &'a Config) -> Self {
- ContentEmbedder { cfg: cfg }
+ ContentEmbedder { cfg }
}
}
impl<'a> Drop for ContentEmbedder<'a> {
fn drop(&mut self) {
- self.cfg.base_dirs.place_runtime_file("imgbeddings-api.py")
+ self.cfg
+ .base_dirs
+ .place_runtime_file("imgbeddings-api.py")
.iter()
- .for_each(|p| { let _ = std::fs::remove_file(&p); });
+ .for_each(|p| {
+ let _ = remove_file(p);
+ });
}
}
@@ -39,13 +50,17 @@ impl BatchEmbedder for ContentEmbedder<'_> {
const NAME: &'static str = "imgbeddings";
fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
- let venv_dir = self.cfg.base_dirs
+ let venv_dir = self
+ .cfg
+ .base_dirs
.create_data_directory("imgbeddings-venv")?;
- let script_file = self.cfg.base_dirs
+ let script_file = self
+ .cfg
+ .base_dirs
.place_runtime_file("imgbeddings-api.py")?;
let api_prog = include_bytes!("imgbeddings-api.py");
- copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?;
+ copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?;
let bar = ProgressBar::new_spinner();
bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?);
@@ -73,9 +88,12 @@ impl BatchEmbedder for ContentEmbedder<'_> {
.stdout(Stdio::piped())
.spawn()?;
+ let st =
+ ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
+ let bar = ProgressBar::new(paths.len() as u64).with_style(st);
BufReader::new(child.stdout.unwrap())
.lines()
- .progress_count(paths.len().try_into().unwrap())
+ .progress_with(bar)
.map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?))
.try_collect()
}