diff options
author | metamuffin <metamuffin@disroot.org> | 2023-09-20 16:55:50 +0200 |
---|---|---|
committer | metamuffin <metamuffin@disroot.org> | 2023-09-20 16:55:50 +0200 |
commit | fbfee0a2bb436a6205d67f561dbd6284621504d6 (patch) | |
tree | 8bc973c19f1e3ee12eaa382fa63a20ddf642fdfe /src/ai_embedders.rs | |
parent | f62b4e356a1deecc550a2eba6d7d0caaad1303c1 (diff) | |
download | embeddings-sort-fbfee0a2bb436a6205d67f561dbd6284621504d6.tar embeddings-sort-fbfee0a2bb436a6205d67f561dbd6284621504d6.tar.bz2 embeddings-sort-fbfee0a2bb436a6205d67f561dbd6284621504d6.tar.zst |
move embedder to module
Diffstat (limited to 'src/ai_embedders.rs')
-rw-r--r-- | src/ai_embedders.rs | 100 |
1 files changed, 0 insertions, 100 deletions
diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs deleted file mode 100644 index 3848674..0000000 --- a/src/ai_embedders.rs +++ /dev/null @@ -1,100 +0,0 @@ -use anyhow::Result; -use indicatif::{ProgressBar, ProgressIterator, ProgressStyle}; -use serde::{Deserialize, Serialize}; -use std::{ - fs::{remove_file, File}, - io::{copy, BufRead, BufReader, Cursor}, - path::PathBuf, - process::{Command, Stdio}, -}; - -use crate::{BatchEmbedder, Config, MetricElem}; - -#[repr(transparent)] -#[derive(Serialize, Deserialize)] -pub(crate) struct Imgbedding(Vec<f32>); -impl MetricElem for Imgbedding { - fn dist(&self, other: &Self) -> f64 { - self.0 - .iter() - .zip(other.0.iter()) - .map(|(a, b)| (a - b).powf(2.)) - .sum::<f32>() - .sqrt() as f64 - } -} - -pub(crate) struct ContentEmbedder<'a> { - cfg: &'a Config, -} -impl<'a> ContentEmbedder<'a> { - pub(crate) fn new(cfg: &'a Config) -> Self { - ContentEmbedder { cfg } - } -} - -impl<'a> Drop for ContentEmbedder<'a> { - fn drop(&mut self) { - self.cfg - .base_dirs - .place_runtime_file("imgbeddings-api.py") - .iter() - .for_each(|p| { - let _ = remove_file(p); - }); - } -} - -impl BatchEmbedder for ContentEmbedder<'_> { - type Embedding = Imgbedding; - const NAME: &'static str = "imgbeddings"; - - fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> { - let venv_dir = self - .cfg - .base_dirs - .create_data_directory("imgbeddings-venv")?; - let script_file = self - .cfg - .base_dirs - .place_runtime_file("imgbeddings-api.py")?; - - let api_prog = include_bytes!("imgbeddings-api.py"); - copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?; - - let bar = ProgressBar::new_spinner(); - bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?); - bar.enable_steady_tick(std::time::Duration::from_millis(100)); - - bar.set_message("Creating venv..."); - Command::new("python3") - .args(["-m", "venv", venv_dir.to_str().unwrap()]) - .stdout(Stdio::null()) - .spawn()? - .wait()?; - - bar.set_message("Installing/checking packages..."); - Command::new(venv_dir.join("bin/pip3")) - .args(["install", "imgbeddings"]) - .stdout(Stdio::null()) - .spawn()? - .wait()?; - bar.finish(); - - let child = Command::new(venv_dir.join("bin/python3")) - .arg(script_file) - .args(paths) - .stderr(Stdio::null()) - .stdout(Stdio::piped()) - .spawn()?; - - let st = - ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?; - let bar = ProgressBar::new(paths.len() as u64).with_style(st); - BufReader::new(child.stdout.unwrap()) - .lines() - .progress_with(bar) - .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?)) - .try_collect() - } -} |