aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock8
-rw-r--r--Cargo.toml2
-rw-r--r--src/embedders/ai.rs11
-rw-r--r--src/main.rs19
4 files changed, 23 insertions, 17 deletions
diff --git a/Cargo.lock b/Cargo.lock
index fcb165e..827a9aa 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -521,6 +521,7 @@ dependencies = [
"ahash",
"anyhow",
"clap",
+ "dirs",
"fastembed",
"image",
"indicatif",
@@ -533,7 +534,6 @@ dependencies = [
"sha2",
"sled",
"typed-sled",
- "xdg",
]
[[package]]
@@ -2712,12 +2712,6 @@ dependencies = [
]
[[package]]
-name = "xdg"
-version = "2.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "213b7324336b53d2414b2db8537e56544d981803139155afa84f76eeebb7a546"
-
-[[package]]
name = "yoke"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index a8cc6f6..b3a2d6d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,6 +9,7 @@ edition = "2021"
ahash = "0"
anyhow = "1"
clap = { version = "4", features = ["derive"] }
+dirs = "5"
fastembed = "4"
image = "0"
indicatif = { version = "0", features = ["rayon"] }
@@ -21,4 +22,3 @@ serde = "1"
sha2 = "0"
sled = "0"
typed-sled = "0"
-xdg = "2"
diff --git a/src/embedders/ai.rs b/src/embedders/ai.rs
index 8c9de11..7d31a6b 100644
--- a/src/embedders/ai.rs
+++ b/src/embedders/ai.rs
@@ -13,7 +13,10 @@ pub(crate) struct ContentEmbedder<'a, Metric> {
}
impl<'a, Metric> ContentEmbedder<'a, Metric> {
pub(crate) fn new(cfg: &'a Config) -> Self {
- ContentEmbedder { cfg, _sim: PhantomData }
+ ContentEmbedder {
+ cfg,
+ _sim: PhantomData,
+ }
}
}
@@ -23,11 +26,13 @@ impl<Metric: VecMetric> ContentEmbedder<'_, Metric> {
paths: &[PathBuf],
) -> Result<Vec<Result<<Self as BatchEmbedder>::Embedding>>> {
let mut options = ImageInitOptions::default();
- options.cache_dir = self.cfg.base_dirs.get_cache_home();
+ options.cache_dir = self.cfg.cache_dir.join("models");
let embedder = ImageEmbedding::try_new(options)?;
let bar = ProgressBar::new(paths.len() as u64);
- bar.set_style(ProgressStyle::with_template("{bar:20.cyan/blue} {pos}/{len} {msg}")?);
+ bar.set_style(ProgressStyle::with_template(
+ "{bar:20.cyan/blue} {pos}/{len} {msg}",
+ )?);
bar.enable_steady_tick(std::time::Duration::from_millis(100));
bar.set_message("Embedding images...");
diff --git a/src/main.rs b/src/main.rs
index 92f51f1..7652166 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,6 @@
#![feature(iterator_try_collect)]
-use anyhow::Result;
+use anyhow::{anyhow, Result};
use clap::Parser;
use sha2::{Digest, Sha512_256};
use std::{
@@ -37,6 +37,7 @@ struct Args {
embedder: Embedder,
/// Symlink the sorted images into this directory
+ #[cfg(unix)]
#[arg(short = 's', long)]
symlink_dir: Option<PathBuf>,
@@ -77,13 +78,14 @@ struct Args {
#[derive(Debug)]
struct Config {
- base_dirs: xdg::BaseDirectories,
+ cache_dir: PathBuf,
}
fn get_config() -> Result<Config> {
- let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort")?;
-
- Ok(Config { base_dirs: dirs })
+ let glob_cache_dir = dirs::cache_dir().ok_or(anyhow!("Could not get cache directory"))?;
+ Ok(Config {
+ cache_dir: glob_cache_dir.join("embeddings-sort"),
+ })
}
fn hash_file(p: &PathBuf) -> Result<[u8; 32]> {
@@ -102,7 +104,7 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec<PathB
where
E: BatchEmbedder,
{
- let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?;
+ let db = sled::open(cfg.cache_dir.join("embeddings.db"))?;
let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME);
// find cached embeddings
@@ -180,6 +182,7 @@ where
))
}
+#[allow(unused_variables)] // use_symlinks on windows
fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()> {
fs::create_dir_all(target)?;
@@ -191,6 +194,7 @@ fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()
};
let tp = target.join(format!("{i:0pad_len$}{ext}"));
+ #[cfg(unix)]
if use_symlinks {
let rel_path =
pathdiff::diff_paths(path::absolute(p)?, path::absolute(target)?).unwrap();
@@ -199,6 +203,8 @@ fn copy_into(tsp: &[PathBuf], target: &PathBuf, use_symlinks: bool) -> Result<()
} else {
reflink_copy::reflink_or_copy(p, tp)?;
}
+ #[cfg(not(unix))]
+ reflink_copy::reflink_or_copy(p, tp)?;
}
Ok(())
}
@@ -226,6 +232,7 @@ fn main() -> Result<()> {
eprintln!("Found tour with length: {}", total_dist);
}
+ #[cfg(unix)]
if let Some(p) = args.symlink_dir {
copy_into(&tsp_path, &p, true)?
}