aboutsummaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
authorLia Lenckowski <lialenck@protonmail.com>2023-09-06 13:37:22 +0200
committerLia Lenckowski <lialenck@protonmail.com>2023-09-06 13:37:22 +0200
commitcb67cd7389b68750518b67c7a58beb7659298352 (patch)
treeaf6cb44cb47dacb82127f5bcf586c5b60aa4603a /src/main.rs
parent3fe8109102083ef32da9e11eef5dc45dba530333 (diff)
downloadembeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar
embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar.bz2
embeddings-sort-cb67cd7389b68750518b67c7a58beb7659298352.tar.zst
cache embeddings by path (should be by hash, but that's for later)
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs104
1 files changed, 71 insertions, 33 deletions
diff --git a/src/main.rs b/src/main.rs
index a28885d..c4a0c26 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,20 +25,20 @@ struct Args {
images: Vec<PathBuf>,
}
-//#[derive(Debug)]
-//struct Config {
-// base_dir: xdg::BaseDirectories,
-//}
-//
-//fn get_config() -> Result<Config, String> {
-// let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort")
-// .map_err(|_| "oh no")?;
-//
-// Ok(Config{base_dir: dirs})
-//}
+#[derive(Debug)]
+struct Config {
+ base_dirs: xdg::BaseDirectories,
+}
-fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String>
- where E: EmbedderT
+fn get_config() -> Result<Config, String> {
+ let dirs = xdg::BaseDirectories::with_prefix("embeddings-sort")
+ .map_err(|_| "oh no")?;
+
+ Ok(Config{base_dirs: dirs})
+}
+
+fn get_mst<M>(embeds: &Vec<M>) -> HashMap<usize, Vec<usize>>
+ where M: MetricElem
{
// wrapper struct to
// - reverse the ordering
@@ -58,31 +58,23 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String>
}
}
- let num_ims = args.images.len();
- if num_ims == 0 {
- return Ok(Vec::new());
- }
-
- let embeds: Vec<_> = e
- .embed(&args.images)?
- .into_iter()
- .collect();
+ let num_embeds = embeds.len();
let mut possible_edges =
- PriorityQueue::with_capacity((num_ims * num_ims - num_ims) / 2);
- let mut mst = HashMap::with_capacity(num_ims);
+ PriorityQueue::with_capacity((num_embeds * num_embeds - num_embeds) / 2);
+ let mut mst = HashMap::with_capacity(num_embeds);
// here, we start at 0.
// we might get a better result in the end if we started with a vertex next
// to the lowest-cost edge, but we don't know which one that is (though we
// could compute that without changing our asymptotic complexity)
mst.insert(0, Vec::new());
- for i in 1..num_ims {
+ for i in 1..num_embeds {
possible_edges.push((0, i), DownOrd(embeds[0].dist(&embeds[i])));
}
// prims algorithm or something like that
- while mst.len() < num_ims {
+ while mst.len() < num_embeds {
// find the edge with the least cost that connects us to a new vertex
let (new, old) = loop {
let ((a, b), _) = possible_edges.pop().unwrap();
@@ -97,7 +89,7 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String>
// insert all the new edges we could take
mst.entry(old).and_modify(|v|v.push(new));
- for i in 0..num_ims {
+ for i in 0..num_embeds {
// don't consider edges taking us to nodes we already visited
if mst.contains_key(&i) {
continue;
@@ -107,25 +99,71 @@ fn process_embedder<E>(mut e: E, args: Args) -> Result<Vec<PathBuf>, String>
}
}
- // find TSP approximation via DFS through the MST
+ mst
+}
+
+fn tsp_from_mst(mst: HashMap<usize, Vec<usize>>) -> Vec<usize> {
fn dfs(cur: usize, t: &HashMap<usize, Vec<usize>>, into: &mut Vec<usize>) {
into.push(cur);
t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into));
}
- let mut tsp_path = Vec::with_capacity(num_ims);
+ let mut tsp_path = Vec::with_capacity(mst.len());
dfs(0, &mst, &mut tsp_path);
+ tsp_path
+}
+
+fn process_embedder<E>(mut e: E, args: Args, cfg: Config) -> Result<Vec<PathBuf>, String>
+ where E: EmbedderT
+{
+ if args.images.len() == 0 {
+ return Ok(Vec::new());
+ }
+
+ let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")
+ .map_err(|e| e.to_string())?).map_err(|e| e.to_string())?;
+ let tree = typed_sled::Tree::<PathBuf, E::Embedding>::open(&db, E::NAME);
+
+ // TODO use the image's hash as the key instead of its path
+ let mut embeds: Vec<Option<_>> = args.images
+ .iter()
+ .map(|p| tree.get(p).map_err(|e| e.to_string()))
+ .try_collect()?;
+
+ let missing_embeds_indices: Vec<_> = embeds
+ .iter()
+ .enumerate()
+ .filter_map(|(i, v)| match v {
+ None => Some(i),
+ Some(_) => None,
+ }).collect();
+ let missing_embeds = e.embeds(&missing_embeds_indices
+ .iter()
+ .map(|i| args.images[*i].clone())
+ .collect::<Vec<_>>())?;
+
+ for (idx, emb) in missing_embeds_indices
+ .into_iter().zip(missing_embeds.into_iter())
+ {
+ // TODO use the hash instead of the path here as well
+ tree.insert(&args.images[idx], &emb).map_err(|e| e.to_string())?;
+ embeds[idx] = Some(emb);
+ }
+
+ let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect();
+ let tsp_path = tsp_from_mst(get_mst(&embeds));
+
Ok(tsp_path.iter().map(|i| args.images[*i].clone()).collect())
}
fn main() -> Result<(), String> {
- //let cfg = get_config()?;
+ let cfg = get_config()?;
let args = Args::parse();
let tsp_path = match args.embedder {
- Embedder::Brightness => process_embedder(BrightnessEmbedder, args),
- Embedder::Hue => process_embedder(HueEmbedder, args),
- Embedder::Color => process_embedder(ColorEmbedder, args),
+ Embedder::Brightness => process_embedder(BrightnessEmbedder, args, cfg),
+ Embedder::Hue => process_embedder(HueEmbedder, args, cfg),
+ Embedder::Color => process_embedder(ColorEmbedder, args, cfg),
}?;
for p in tsp_path {