aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2023-09-07 16:41:02 +0200
committerLia Lenckowski <lialenck@protonmail.com>2023-09-07 16:41:02 +0200
commitad8e831a5cbe5bb65d0567d4bdde95bd3ef3de75 (patch)
tree355c1ed5244495090b914e9136e977ae22956585
parentc4b03717914e5c907f7f47dc2a85df6b57763c58 (diff)
downloadembeddings-sort-ad8e831a5cbe5bb65d0567d4bdde95bd3ef3de75.tar
embeddings-sort-ad8e831a5cbe5bb65d0567d4bdde95bd3ef3de75.tar.bz2
embeddings-sort-ad8e831a5cbe5bb65d0567d4bdde95bd3ef3de75.tar.zst
add progress bars/spinners
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml4
-rw-r--r--src/ai_embedders.rs25
-rw-r--r--src/embedders.rs14
-rw-r--r--src/imgbeddings-api.py17
5 files changed, 43 insertions, 18 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 5a44418..571f48d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -482,6 +482,7 @@ dependencies = [
"instant",
"number_prefix",
"portable-atomic",
+ "rayon",
"unicode-width",
]
diff --git a/Cargo.toml b/Cargo.toml
index 7ae9c2f..7335a70 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,10 +8,10 @@ edition = "2021"
[dependencies]
image = "0"
xdg = "2"
-clap = {version = "4", features = ["derive"]}
+clap = { version = "4", features = ["derive"] }
priority-queue = "1"
rayon = "1"
-indicatif = "0"
+indicatif = { version = "0", features = ["rayon"] }
sled = "0"
typed-sled = "0"
serde = "1"
diff --git a/src/ai_embedders.rs b/src/ai_embedders.rs
index 14e1d8c..b30d890 100644
--- a/src/ai_embedders.rs
+++ b/src/ai_embedders.rs
@@ -1,6 +1,7 @@
use anyhow::Result;
+use indicatif::{ProgressBar, ProgressIterator, ProgressStyle};
use serde::{Deserialize, Serialize};
-use std::{io::{copy, Cursor}, path::PathBuf, process::Command};
+use std::{io::{BufRead, BufReader, copy, Cursor}, path::PathBuf, process::{Command, Stdio}};
use crate::{Config, BatchEmbedder, MetricElem};
@@ -46,20 +47,36 @@ impl BatchEmbedder for ContentEmbedder<'_> {
let api_prog = include_bytes!("imgbeddings-api.py");
copy(&mut Cursor::new(api_prog), &mut std::fs::File::create(&script_file)?)?;
+ let bar = ProgressBar::new_spinner();
+ bar.set_style(ProgressStyle::with_template("{spinner} {msg}")?);
+ bar.enable_steady_tick(std::time::Duration::from_millis(100));
+
+ bar.set_message("Creating venv...");
Command::new("python3")
.args(["-m", "venv", venv_dir.to_str().unwrap()])
+ .stdout(Stdio::null())
.spawn()?
.wait()?;
+
+ bar.set_message("Installing/checking packages...");
Command::new(venv_dir.join("bin/pip3"))
.args(["install", "imgbeddings"])
+ .stdout(Stdio::null())
.spawn()?
.wait()?;
+ bar.finish();
- let output = Command::new(venv_dir.join("bin/python3"))
+ let child = Command::new(venv_dir.join("bin/python3"))
.arg(script_file)
.args(paths)
- .output()?;
+ .stderr(Stdio::null())
+ .stdout(Stdio::piped())
+ .spawn()?;
- Ok(serde_json::from_slice(&output.stdout)?)
+ Ok(BufReader::new(child.stdout.unwrap())
+ .lines()
+ .progress_count(paths.len().try_into().unwrap())
+ .map(|l| Ok::<_, anyhow::Error>(serde_json::from_str(&l?)?))
+ .try_collect()?)
}
}
diff --git a/src/embedders.rs b/src/embedders.rs
index 7257a27..97cc5ac 100644
--- a/src/embedders.rs
+++ b/src/embedders.rs
@@ -1,7 +1,8 @@
use anyhow::Result;
+use indicatif::{ParallelProgressIterator, ProgressStyle};
use rayon::prelude::*;
-use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
+use std::path::{Path, PathBuf};
pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> {
fn dist(&self, _: &Self) -> f64;
@@ -32,11 +33,14 @@ impl<T: EmbedderT> BatchEmbedder for T {
const NAME: &'static str = T::NAME;
fn embeds(&mut self, paths: &[PathBuf]) -> Result<Vec<Self::Embedding>> {
+ let st = ProgressStyle::
+ with_template("{bar:20.cyan/blue} {pos}/{len} Embedding images...")?;
paths.par_iter()
- .map(|p| self.embed(p))
- .collect::<Vec<_>>()
- .into_iter()
- .try_collect()
+ .progress_with_style(st)
+ .map(|p| self.embed(p))
+ .collect::<Vec<_>>()
+ .into_iter()
+ .try_collect()
}
}
diff --git a/src/imgbeddings-api.py b/src/imgbeddings-api.py
index 795c625..0c890b5 100644
--- a/src/imgbeddings-api.py
+++ b/src/imgbeddings-api.py
@@ -1,15 +1,18 @@
from PIL import Image
from imgbeddings import imgbeddings
-import sys
+import sys, os
import json as j
-#from itertools import batched # TODO das hier ab python 3.12
b = imgbeddings()
-ems = []
+paths = sys.argv[1:]
+batch_size = 8
-for f in sys.argv[1:]: # TODO this should be batched for faster ai stuff
- im = Image.open(open(f, "rb"))
- ems += [b.to_embeddings(im)]
+for i in range(0, len(paths), batch_size):
+ fs = paths[i:i+batch_size]
-print(j.dumps([em[0].tolist() for em in ems]))
+ ims = [Image.open(open(f, "rb")) for f in fs]
+ for emb in b.to_embeddings(ims).tolist():
+ print(j.dumps(emb))
+
+sys.stderr.write("\n")