diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cache.rs | 45 | ||||
-rw-r--r-- | src/embedders/mod.rs | 4 | ||||
-rw-r--r-- | src/embedders/pure.rs | 4 | ||||
-rw-r--r-- | src/embedders/vecmetric.rs | 8 | ||||
-rw-r--r-- | src/main.rs | 23 |
5 files changed, 64 insertions, 20 deletions
diff --git a/src/cache.rs b/src/cache.rs
new file mode 100644
index 0000000..608adb5
--- /dev/null
+++ b/src/cache.rs
@@ -0,0 +1,45 @@
+use crate::{FileHash, MetricElem};
+use anyhow::Result;
+use bincode::config::standard;
+use redb::{Database, TableDefinition};
+use std::path::Path;
+
+const T_ENTRIES: TableDefinition<(&str, FileHash), &[u8]> = TableDefinition::new("entries");
+
+pub struct Cache {
+    db: Database,
+}
+impl Cache {
+    pub fn open(path: &Path) -> Result<Self> {
+        let db = Database::create(path)?;
+        let txn = db.begin_write()?;
+        txn.open_table(T_ENTRIES)?;
+        txn.commit()?;
+        Ok(Self { db })
+    }
+    pub fn get<E: MetricElem>(&self, type_name: &'static str, hash: FileHash) -> Result<Option<E>> {
+        let txn = self.db.begin_read()?;
+        let table = txn.open_table(T_ENTRIES)?;
+        if let Some(e) = table.get((type_name, hash))? {
+            Ok(Some(bincode::decode_from_slice(e.value(), standard())?.0))
+        } else {
+            Ok(None)
+        }
+    }
+    pub fn insert<E: MetricElem>(
+        &self,
+        type_name: &'static str,
+        hash: FileHash,
+        value: &E,
+    ) -> Result<()> {
+        let txn = self.db.begin_write()?;
+        let mut table = txn.open_table(T_ENTRIES)?;
+        table.insert(
+            (type_name, hash),
+            bincode::encode_to_vec(value, standard())?.as_slice(),
+        )?;
+        drop(table);
+        txn.commit()?;
+        Ok(())
+    }
+}
diff --git a/src/embedders/mod.rs b/src/embedders/mod.rs
index 1a1721d..83484a1 100644
--- a/src/embedders/mod.rs
+++ b/src/embedders/mod.rs
@@ -6,12 +6,12 @@ pub(crate) use pure::*;
 pub(crate) use vecmetric::*;
 
 use anyhow::Result;
+use bincode::{Decode, Encode};
 use indicatif::{ParallelProgressIterator, ProgressStyle};
 use rayon::prelude::*;
-use serde::{Deserialize, Serialize};
 use std::path::{Path, PathBuf};
 
-pub trait MetricElem: Send + Sync + 'static + Serialize + for<'a> Deserialize<'a> {
+pub trait MetricElem: Send + Sync + 'static + Encode + Decode {
     fn dist(&self, _: &Self) -> f64;
 }
 
diff --git a/src/embedders/pure.rs b/src/embedders/pure.rs
index 09c8321..531368c 100644
--- a/src/embedders/pure.rs
+++ b/src/embedders/pure.rs
@@ -1,5 +1,5 @@
 use anyhow::{bail, Result};
-use serde::{Deserialize, Serialize};
+use bincode::{Decode, Encode};
 use std::path::Path;
 
 use crate::{EmbedderT, MetricElem};
@@ -22,7 +22,7 @@ impl EmbedderT for BrightnessEmbedder {
 }
 
 #[repr(transparent)]
-#[derive(Serialize, Deserialize)]
+#[derive(Encode, Decode)]
 pub(crate) struct Hue(f64);
 impl MetricElem for Hue {
     fn dist(&self, b: &Hue) -> f64 {
diff --git a/src/embedders/vecmetric.rs b/src/embedders/vecmetric.rs
index 9f2f143..65d71df 100644
--- a/src/embedders/vecmetric.rs
+++ b/src/embedders/vecmetric.rs
@@ -1,13 +1,13 @@
 use super::MetricElem;
-use serde::{Deserialize, Serialize};
+use bincode::{Decode, Encode};
 
 pub trait VecMetric: MetricElem + From<Vec<f32>> {}
 
-#[derive(Deserialize, Serialize)]
+#[derive(Decode, Encode)]
 pub struct AngularDistance(pub Vec<f32>);
-#[derive(Deserialize, Serialize)]
+#[derive(Decode, Encode)]
 pub struct EuclidianDistance(pub Vec<f32>);
-#[derive(Deserialize, Serialize)]
+#[derive(Decode, Encode)]
 pub struct ManhattenDistance(pub Vec<f32>);
 
 impl VecMetric for AngularDistance {}
diff --git a/src/main.rs b/src/main.rs
index 2e63dd8..45621cf 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,5 @@
 use anyhow::{anyhow, Result};
+use cache::Cache;
 use clap::Parser;
 use sha2::{Digest, Sha512_256};
 use std::{
@@ -12,9 +13,12 @@ use std::path::absolute;
 use embedders::*;
 use tsp_approx::*;
 
+pub mod cache;
 mod embedders;
 mod tsp_approx;
 
+pub type FileHash = [u8; 32];
+
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
 enum Embedder {
     Brightness,
@@ -84,12 +88,12 @@ struct Config {
 
 fn get_config() -> Result<Config> {
     let glob_cache_dir = dirs::cache_dir().ok_or(anyhow!("Could not get cache directory"))?;
-    Ok(Config {
-        cache_dir: glob_cache_dir.join("embeddings-sort"),
-    })
+    let cache_dir = glob_cache_dir.join("embeddings-sort");
+    std::fs::create_dir_all(&cache_dir)?;
+    Ok(Config { cache_dir })
 }
 
-fn hash_file(p: &PathBuf) -> Result<[u8; 32]> {
+fn hash_file(p: &PathBuf) -> Result<FileHash> {
     let mut f = fs::File::open(p)?;
     let mut hasher = Sha512_256::new();
     io::copy(&mut f, &mut hasher)?;
@@ -105,18 +109,13 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec<PathB
 where
     E: BatchEmbedder,
 {
-    let db = sled::open(cfg.cache_dir.join("embeddings.db"))?;
-    let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME);
+    let cache = Cache::open(&cfg.cache_dir.join("embeddings.db-v2"))?;
 
     // find cached embeddings
     let mut embeds = args
         .images
         .iter()
-        .map(|path| {
-            let h = hash_file(path)?;
-            let r: Result<Option<E::Embedding>> = tree.get(&h).map_err(|e| e.into());
-            r
-        })
+        .map(|path| cache.get(E::NAME, hash_file(path)?))
         .collect::<Result<Vec<_>>>()?;
 
     // find indices of missing embeddings
@@ -148,7 +147,7 @@ where
     {
         match emb {
             Ok(emb) => {
-                tree.insert(&hash_file(&args.images[idx])?, &emb)?;
+                cache.insert(E::NAME, hash_file(&args.images[idx])?, &emb)?;
                 embeds[idx] = Some(emb);
             }
             Err(e) => {