about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
author metamuffin <metamuffin@disroot.org> 2024-11-27 22:34:03 +0000
committer lialenck <lialenck@noreply.codeberg.org> 2024-11-27 22:34:03 +0000
commit 2a17ceac1ab5cdee98d20a928795a1aba06c8be7 (patch)
tree 1c5c31e238870776dc324a91ddface2ae15e050a /src
parent 467674743fb638ea56713aecc719a80505b82a17 (diff)
download embeddings-sort-2a17ceac1ab5cdee98d20a928795a1aba06c8be7.tar
embeddings-sort-2a17ceac1ab5cdee98d20a928795a1aba06c8be7.tar.bz2
embeddings-sort-2a17ceac1ab5cdee98d20a928795a1aba06c8be7.tar.zst
Replace sled with redb (Also replaces serde to bincode.) (#2)
Reviewed-on: https://codeberg.org/lialenck/embeddings-sort/pulls/2
Co-authored-by: metamuffin <metamuffin@disroot.org>
Co-committed-by: metamuffin <metamuffin@disroot.org>
Diffstat (limited to 'src')
-rw-r--r-- src/cache.rs 45
-rw-r--r-- src/embedders/mod.rs 4
-rw-r--r-- src/embedders/pure.rs 4
-rw-r--r-- src/embedders/vecmetric.rs 8
-rw-r--r-- src/main.rs 23
5 files changed, 64 insertions, 20 deletions
diff --git a/src/cache.rs b/src/cache.rs
new file mode 100644
index 0000000..608adb5
--- /dev/null
+++ b/src/cache.rs
@@ -0,0 +1,45 @@
+use crate::{FileHash, MetricElem};
+use anyhow::Result;
+use bincode::config::standard;
+use redb::{Database, TableDefinition};
+use std::path::Path;
+
+const T_ENTRIES: TableDefinition<(&str, FileHash), &[u8]> = TableDefinition::new("entries");
+
+pub struct Cache {
+ db: Database,
+}
+impl Cache {
+ pub fn open(path: &Path) -> Result<Self> {
+ let db = Database::create(path)?;
+ let txn = db.begin_write()?;
+ txn.open_table(T_ENTRIES)?;
+ txn.commit()?;
+ Ok(Self { db })
+ }
+ pub fn get<E: MetricElem>(&self, type_name: &'static str, hash: FileHash) -> Result<Option<E>> {
+ let txn = self.db.begin_read()?;
+ let table = txn.open_table(T_ENTRIES)?;
+ if let Some(e) = table.get((type_name, hash))? {
+ Ok(Some(bincode::decode_from_slice(e.value(), standard())?.0))
+ } else {
+ Ok(None)
+ }
+ }
+ pub fn insert<E: MetricElem>(
+ &self,
+ type_name: &'static str,
+ hash: FileHash,
+ value: &E,
+ ) -> Result<()> {
+ let txn = self.db.begin_write()?;
+ let mut table = txn.open_table(T_ENTRIES)?;
+ table.insert(
+ (type_name, hash),
+ bincode::encode_to_vec(value, standard())?.as_slice(),
+ )?;
+ drop(table);
+ txn.commit()?;
+ Ok(())
+ }
+}
diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs
index 1a1721d..83484a1 100644
--- a/src/embedders/mod.rs
+++ b/src/embedders/mod.rs
@@ -6,12 +6,12 @@ pub(crate) use pure::*;
pub(crate) use vecmetric::*;
use anyhow::Result;
+use bincode::{Decode, Encode};
use indicatif::{ParallelProgressIterator, ProgressStyle};
use rayon::prelude::*;
-use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
-pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> {
+pub trait MetricElem: Send + Sync + 'static + Encode + Decode {
fn dist(&self, _: &Self) -> f64;
}
diff --git a/src/embedders/pure.rs b/src/embedders/pure.rs
index 09c8321..531368c 100644
--- a/src/embedders/pure.rs
+++ b/src/embedders/pure.rs
@@ -1,5 +1,5 @@
use anyhow::{bail, Result};
-use serde::{Deserialize, Serialize};
+use bincode::{Decode, Encode};
use std::path::Path;
use crate::{EmbedderT, MetricElem};
@@ -22,7 +22,7 @@ impl EmbedderT for BrightnessEmbedder {
}
#[repr(transparent)]
-#[derive(Serialize, Deserialize)]
+#[derive(Encode, Decode)]
pub(crate) struct Hue(f64);
impl MetricElem for Hue {
fn dist(&self, b: &Hue) -> f64 {
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs
index 9f2f143..65d71df 100644
--- a/src/embedders/vecmetric.rs
+++ b/src/embedders/vecmetric.rs
@@ -1,13 +1,13 @@
use super::MetricElem;
-use serde::{Deserialize, Serialize};
+use bincode::{Decode, Encode};
pub trait VecMetric: MetricElem + From<Vec<f32>> {}
-#[derive(Deserialize, Serialize)]
+#[derive(Decode, Encode)]
pub struct AngularDistance(pub Vec<f32>);
-#[derive(Deserialize, Serialize)]
+#[derive(Decode, Encode)]
pub struct EuclidianDistance(pub Vec<f32>);
-#[derive(Deserialize, Serialize)]
+#[derive(Decode, Encode)]
pub struct ManhattenDistance(pub Vec<f32>);
impl VecMetric for AngularDistance {}
diff --git a/src/main.rs b/src/main.rs
index 2e63dd8..45621cf 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,5 @@
use anyhow::{anyhow, Result};
+use cache::Cache;
use clap::Parser;
use sha2::{Digest, Sha512_256};
use std::{
@@ -12,9 +13,12 @@ use std::path::absolute;
use embedders::*;
use tsp_approx::*;
+pub mod cache;
mod embedders;
mod tsp_approx;
+pub type FileHash = [u8; 32];
+
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
enum Embedder {
Brightness,
@@ -84,12 +88,12 @@ struct Config {
fn get_config() -> Result<Config> {
let glob_cache_dir = dirs::cache_dir().ok_or(anyhow!("Could not get cache directory"))?;
- Ok(Config {
- cache_dir: glob_cache_dir.join("embeddings-sort"),
- })
+ let cache_dir = glob_cache_dir.join("embeddings-sort");
+ std::fs::create_dir_all(&cache_dir)?;
+ Ok(Config { cache_dir })
}
-fn hash_file(p: &PathBuf) -> Result<[u8; 32]> {
+fn hash_file(p: &PathBuf) -> Result<FileHash> {
let mut f = fs::File::open(p)?;
let mut hasher = Sha512_256::new();
io::copy(&mut f, &mut hasher)?;
@@ -105,18 +109,13 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec<PathB
where
E: BatchEmbedder,
{
- let db = sled::open(cfg.cache_dir.join("embeddings.db"))?;
- let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME);
+ let cache = Cache::open(&cfg.cache_dir.join("embeddings.db-v2"))?;
// find cached embeddings
let mut embeds = args
.images
.iter()
- .map(|path| {
- let h = hash_file(path)?;
- let r: Result<Option<E::Embedding>> = tree.get(&h).map_err(|e| e.into());
- r
- })
+ .map(|path| cache.get(E::NAME, hash_file(path)?))
.collect::<Result<Vec<_>>>()?;
// find indices of missing embeddings
@@ -148,7 +147,7 @@ where
{
match emb {
Ok(emb) => {
- tree.insert(&hash_file(&args.images[idx])?, &emb)?;
+ cache.insert(E::NAME, hash_file(&args.images[idx])?, &emb)?;
embeds[idx] = Some(emb);
}
Err(e) => {