diff options
Diffstat (limited to 'src/main.rs')
-rw-r--r-- | src/main.rs | 63 |
1 files changed, 45 insertions, 18 deletions
diff --git a/src/main.rs b/src/main.rs index 1a520bd..92f51f1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -#![feature(iterator_try_collect, absolute_path)] +#![feature(iterator_try_collect)] use anyhow::Result; use clap::Parser; @@ -60,10 +60,14 @@ struct Args { #[arg(long, default_value = "christofides")] tsp_approx: TspBaseAlg, - /// number of 2-Opt refinement steps. Has quickly diminishing returns + /// Number of 2-Opt refinement steps. Has quickly diminishing returns #[arg(short = 'r', default_value = "3")] refine: usize, + /// Ignore failed embeddings + #[arg(short = 'i')] + ignore_errors: bool, + /// Seed for hashing. Random by default. #[arg(long)] hash_seed: Option<u64>, @@ -98,57 +102,80 @@ fn process_embedder<E>(mut e: E, args: &Args, cfg: &Config) -> Result<(Vec<PathB where E: BatchEmbedder, { - if args.images.is_empty() { - return Ok((Vec::new(), 0.)); - } - let db = sled::open(cfg.base_dirs.place_cache_file("embeddings.db")?)?; let tree = typed_sled::Tree::<[u8; 32], E::Embedding>::open(&db, E::NAME); - let mut embeds: Vec<Option<_>> = args + // find cached embeddings + let mut embeds: Vec<_> = args .images .iter() - .map(|p| { - let h = hash_file(p)?; + .map(|path| { + let h = hash_file(path)?; let r: Result<Option<E::Embedding>> = tree.get(&h).map_err(|e| e.into()); r }) .try_collect()?; + // find indices of missing embeddings let missing_embeds_indices: Vec<_> = embeds .iter() .enumerate() - .filter_map(|(i, v)| match v { + .filter_map(|(i, cached_embedding)| match cached_embedding { None => Some(i), Some(_) => None, }) .collect(); - // TODO only run e.embeds if !missing_embeds_indices.is_empty(); this allows - // for optimizations in the ai embedde (move pip to ::embeds() instead of ::new()) + + // calculate missing embeddings let missing_embeds = if missing_embeds_indices.is_empty() { - Vec::new() + vec![] } else { e.embeds( &missing_embeds_indices .iter() .map(|i| args.images[*i].clone()) .collect::<Vec<_>>(), - )? + ) }; + // insert successfully changed for (idx, emb) in missing_embeds_indices .into_iter() .zip(missing_embeds.into_iter()) { - tree.insert(&hash_file(&args.images[idx])?, &emb)?; - embeds[idx] = Some(emb); + match emb { + Ok(emb) => { + tree.insert(&hash_file(&args.images[idx])?, &emb)?; + embeds[idx] = Some(emb); + } + Err(e) => { + if !args.ignore_errors { + return Err(e); + } + } + } } - let embeds: Vec<_> = embeds.into_iter().map(|e| e.unwrap()).collect(); + // filter out images with failed embeddings + let (embeds, images): (Vec<_>, Vec<_>) = embeds + .into_iter() + .zip(args.images.iter()) + .filter_map(|(emb, path)| match emb { + Some(embedding) => Some((embedding, path)), + None => { + if args.ignore_errors { + None + } else { + panic!("Embedding failed for {}", path.display()) + } + } + }) + .unzip(); + let (tsp_path, total_dist) = tsp(&embeds, &args.tsp_approx, args.refine, &args.hash_seed); Ok(( - tsp_path.iter().map(|i| args.images[*i].clone()).collect(), + tsp_path.iter().map(|i| images[*i].clone()).collect(), total_dist, )) } |