aboutsummaryrefslogtreecommitdiff
path: root/src/embedders/ai.rs
diff options
context:
space:
mode:
authormetamuffin <metamuffin@disroot.org>2023-09-20 16:55:50 +0200
committermetamuffin <metamuffin@disroot.org>2023-09-20 16:55:50 +0200
commitfbfee0a2bb436a6205d67f561dbd6284621504d6 (patch)
tree8bc973c19f1e3ee12eaa382fa63a20ddf642fdfe /src/embedders/ai.rs
parentf62b4e356a1deecc550a2eba6d7d0caaad1303c1 (diff)
downloadembeddings-sort-fbfee0a2bb436a6205d67f561dbd6284621504d6.tar
embeddings-sort-fbfee0a2bb436a6205d67f561dbd6284621504d6.tar.bz2
embeddings-sort-fbfee0a2bb436a6205d67f561dbd6284621504d6.tar.zst
move embedder to module
Diffstat (limited to 'src/embedders/ai.rs')
-rw-r--r--src/embedders/ai.rs100
1 files changed, 100 insertions, 0 deletions
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
new file mode 100644
index 0000000..3848674
--- /dev/null
+++ b/src/embedders/ai.rs
@@ -0,0 +1,100 @@
+use anyhow::Result;
+use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
+use serde::{Deserialize, Serialize};
+use std::{
+ fs::{remove_file, File},
+ io::{copy, BufRead, BufReader, Cursor},
+ path::PathBuf,
+ process::{Command, Stdio},
+};
+
+use crate::{BatchEmbedder, Config, MetricElem};
+
+#[repr(transparent)]
+#[derive(Serialize, Deserialize)]
+pub(crate) struct Imgbedding(Vec<f32>);
+impl MetricElem for Imgbedding {
+ fn dist(&self, other: &Self) -> f64 {
+ self.0
+ .iter()
+ .zip(other.0.iter())
+ .map(|(a, b)| (a - b).powf(2.))
+ .sum::<f32>()
+ .sqrt() as f64
+ }
+}
+
+pub(crate) struct ContentEmbedder<'a> {
+ cfg: &'a Config,
+}
+impl<'a> ContentEmbedder<'a> {
+ pub(crate) fn new(cfg: &'a Config) -> Self {
+ ContentEmbedder { cfg }
+ }
+}
+
+impl<'a> Drop for ContentEmbedder<'a> {
+ fn drop(&mut self) {
+ self.cfg
+ .base_dirs
+ .place_runtime_file("imgbeddings-api.py")
+ .iter()
+ .for_each(|p| {
+ let _ = remove_file(p);
+ });
+ }
+}
+
+impl BatchEmbedder for ContentEmbedder<'_> {
+ type Embedding = Imgbedding;
+ const NAME: &'static str = "imgbeddings";
+
+ fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
+ let venv_dir = self
+ .cfg
+ .base_dirs
+ .create_data_directory("imgbeddings-venv")?;
+ let script_file = self
+ .cfg
+ .base_dirs
+ .place_runtime_file("imgbeddings-api.py")?;
+
+ let api_prog = include_bytes!("imgbeddings-api.py");
+ copy(&mut Cursor::new(api_prog), &mut File::create(&script_file)?)?;
+
+ let bar = ProgressBar::new_spinner();
+ bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?);
+ bar.enable_steady_tick(std::time::Duration::from_millis(100));
+
+ bar.set_message("Creating venv...");
+ Command::new("python3")
+ .args(["-m", "venv", venv_dir.to_str().unwrap()])
+ .stdout(Stdio::null())
+ .spawn()?
+ .wait()?;
+
+ bar.set_message("Installing/checking packages...");
+ Command::new(venv_dir.join("bin/pip3"))
+ .args(["install", "imgbeddings"])
+ .stdout(Stdio::null())
+ .spawn()?
+ .wait()?;
+ bar.finish();
+
+ let child = Command::new(venv_dir.join("bin/python3"))
+ .arg(script_file)
+ .args(paths)
+ .stderr(Stdio::null())
+ .stdout(Stdio::piped())
+ .spawn()?;
+
+ let st =
+ ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
+ let bar = ProgressBar::new(paths.len() as u64).with_style(st);
+ BufReader::new(child.stdout.unwrap())
+ .lines()
+ .progress_with(bar)
+ .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?))
+ .try_collect()
+ }
+}