aboutsummaryrefslogtreecommitdiff
path: root/src/ai_embedders.rs
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2023-09-20 16:55:50 +0200
committermetamuffin <metamuffin@disroot.org>2023-09-20 16:55:50 +0200
commitfbfee0a2bb436a6205d67f561dbd6284621504d6 (patch)
tree8bc973c19f1e3ee12eaa382fa63a20ddf642fdfe /src/ai_embedders.rs
parentf62b4e356a1deecc550a2eba6d7d0caaad1303c1 (diff)
downloadembeddings-sort-fbfee0a2bb436a6205d67f561dbd6284621504d6.tar
embeddings-sort-fbfee0a2bb436a6205d67f561dbd6284621504d6.tar.bz2
embeddings-sort-fbfee0a2bb436a6205d67f561dbd6284621504d6.tar.zst
move embedder to module
Diffstat (limited to 'src/ai_embedders.rs')
-rw-r--r--src/ai_embedders.rs100
1 files changed, 0 insertions, 100 deletions
diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs
deleted file mode 100644
index 3848674..0000000
--- a/src/ai_embedders.rs
+++ /dev/null
@@ -1,100 +0,0 @@
-use anyhow::Result;
-use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
-use serde::{Deserialize, Serialize};
-use std::{
- fs::{remove_file, File},
- io::{copy, BufRead, BufReader, Cursor},
- path::PathBuf,
- process::{Command, Stdio},
-};
-
-use crate::{BatchEmbedder, Config, MetricElem};
-
-#[repr(transparent)]
-#[derive(Serialize, Deserialize)]
-pub(crate) struct Imgbedding(Vec<f32>);
-impl MetricElem for Imgbedding {
- fn dist(&self, other: &Self) -> f64 {
- self.0
- .iter()
- .zip(other.0.iter())
- .map(|(a, b)| (a - b).powf(2.))
- .sum::<f32>()
- .sqrt() as f64
- }
-}
-
-pub(crate) struct ContentEmbedder<'a> {
- cfg: &'a Config,
-}
-impl<'a> ContentEmbedder<'a> {
- pub(crate) fn new(cfg: &'a Config) -> Self {
- ContentEmbedder { cfg }
- }
-}
-
-impl<'a> Drop for ContentEmbedder<'a> {
- fn drop(&mut self) {
- self.cfg
- .base_dirs
- .place_runtime_file("imgbeddings-api.py")
- .iter()
- .for_each(|p| {
- let _ = remove_file(p);
- });
- }
-}
-
-impl BatchEmbedder for ContentEmbedder<'_> {
- type Embedding = Imgbedding;
- const NAME: &'static str = "imgbeddings";
-
- fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
- let venv_dir = self
- .cfg
- .base_dirs
- .create_data_directory("imgbeddings-venv")?;
- let script_file = self
- .cfg
- .base_dirs
- .place_runtime_file("imgbeddings-api.py")?;
-
- let api_prog = include_bytes!("imgbeddings-api.py");
- copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?;
-
- let bar = ProgressBar::new_spinner();
- bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?);
- bar.enable_steady_tick(std::time::Duration::from_millis(100));
-
- bar.set_message("Creating venv...");
- Command::new("python3")
- .args(["-m", "venv", venv_dir.to_str().unwrap()])
- .stdout(Stdio::null())
- .spawn()?
- .wait()?;
-
- bar.set_message("Installing/checking packages...");
- Command::new(venv_dir.join("bin/pip3"))
- .args(["install", "imgbeddings"])
- .stdout(Stdio::null())
- .spawn()?
- .wait()?;
- bar.finish();
-
- let child = Command::new(venv_dir.join("bin/python3"))
- .arg(script_file)
- .args(paths)
- .stderr(Stdio::null())
- .stdout(Stdio::piped())
- .spawn()?;
-
- let st =
- ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
- let bar = ProgressBar::new(paths.len() as u64).with_style(st);
- BufReader::new(child.stdout.unwrap())
- .lines()
- .progress_with(bar)
- .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?))
- .try_collect()
- }
-}