From fe82169e0a84692cf161e6210c4c522912e70e72 Mon Sep 17 00:00:00 2001 From: Lia Lenckowski Date: Mon, 18 Sep 2023 11:34:32 +0200 Subject: add tsp benchmark flag --- src/tsp_approx.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'src/tsp_approx.rs') diff --git a/src/tsp_approx.rs b/src/tsp_approx.rs index c697a25..648d9ea 100644 --- a/src/tsp_approx.rs +++ b/src/tsp_approx.rs @@ -69,7 +69,10 @@ where mst } -fn tsp_from_mst(mst: HashMap>) -> Vec { +fn tsp_from_mst(embeds: &Vec, mst: HashMap>) -> (Vec, f64) +where + M: MetricElem, +{ fn dfs(cur: usize, t: &HashMap>, into: &mut Vec) { into.push(cur); t.get(&cur).unwrap().iter().for_each(|c| dfs(*c, t, into)); @@ -77,10 +80,15 @@ fn tsp_from_mst(mst: HashMap>) -> Vec { let mut tsp_path = Vec::with_capacity(mst.len()); dfs(0, &mst, &mut tsp_path); - tsp_path + let mut total_dist = 0.; + for i in 0..tsp_path.len() - 1 { + total_dist += embeds[tsp_path[i]].dist(&embeds[tsp_path[i + 1]]); + } + + (tsp_path, total_dist) } -pub fn tsp(embeds: &Vec) -> Vec +pub fn tsp(embeds: &Vec) -> (Vec, f64) where M: MetricElem, { @@ -89,7 +97,7 @@ where bar.enable_steady_tick(std::time::Duration::from_millis(100)); bar.set_message("Finding path..."); - let r = tsp_from_mst(get_mst(embeds)); + let r = tsp_from_mst(embeds, get_mst(embeds)); bar.finish(); -- cgit v1.2.3-70-g09d2