From 65dab3b631706f8481bdc8d12e338f76a1b1cb6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Phil=20H=C3=B6fer?= <phil.hoefer@suma-ev.de> Date: Fri, 12 Jul 2024 15:10:40 +0200 Subject: [PATCH] Refactor Internal Response API --- src/main.rs | 5 +++-- src/predictors/basic_markov.rs | 6 +++--- src/predictors/basic_set.rs | 6 +++--- src/predictors/composite.rs | 12 ++++++------ src/predictors/mod.rs | 2 +- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/main.rs b/src/main.rs index ec69113..00ea01e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -100,9 +100,10 @@ fn main() -> Result<(), io::Error> { Some(n) if n.parse::<usize>().unwrap_or_default() > 0 => n.parse::<usize>().unwrap_or(5), _ => 5 }; - let prediction = &markov_chain.predict(&query, predict_count); + let predictions = &markov_chain.predict(&query, predict_count); //println!("Query: {}, Prediction:{}", query, prediction); - let response = Response::from_string(prediction); + let response = Response::from_string( + format!("[\"{}\",[{}]]", query, predictions.join(","))); request.respond(response); let now = std::time::SystemTime::now(); let elapsed = now.duration_since(last_update).expect("Time went backwards"); diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index 9e51fed..1c2f908 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -64,7 +64,7 @@ impl Predictor for MarkovChainPredictor { self.chain = HashMap::new(); } - fn predict(&self, query: &str, n: usize) -> String { + fn predict(&self, query: &str, n: usize) -> Vec<String> { if let Some(top_words) = get_top_following_words( &self.chain, @@ -77,9 +77,9 @@ impl Predictor for MarkovChainPredictor { .into_iter() .map(|(word, _)| format!("\"{} {}\"", query, word)) .collect(); - return format!("[\"{}\",[{}]]", query, predictions.join(",")); + return predictions; } - String::new() + Vec::<String>::new() } } diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs index c490b6a..3b66c2c 100644 --- a/src/predictors/basic_set.rs +++ b/src/predictors/basic_set.rs @@ -85,7 +85,7 @@ impl Predictor for SetPredictor { } - fn predict(&self, query: &str, n: usize) -> String { + fn predict(&self, query: &str, n: usize) -> Vec<String> { if let Some(top_words) = get_top_completions( &self, @@ -99,9 +99,9 @@ impl Predictor for SetPredictor { .into_iter() .map(|(word, _)| format!("\"{}\"", format!("{} {}",query_prefix, word).trim())) .collect(); - return format!("[\"{}\",[{}]]", query, predictions.join(",")); + return predictions; } - String::new() + Vec::<String>::new() } } diff --git a/src/predictors/composite.rs b/src/predictors/composite.rs index 7e02464..97b08a7 100644 --- a/src/predictors/composite.rs +++ b/src/predictors/composite.rs @@ -41,13 +41,13 @@ impl Predictor for CompositePredictor { } - fn predict(&self, query: &str, n: usize) -> String { - let markov_prediction = self.markov_predictor.predict(query, n); - if markov_prediction.len() > 7+query.len() { - return markov_prediction; + fn predict(&self, query: &str, n: usize) -> Vec<String> { + let mut prediction = self.markov_predictor.predict(query, n); + if prediction.len() < n { + prediction.append(&mut self.set_predictor.predict(query, n)); + prediction.truncate(n); } - let set_prediction = self.set_predictor.predict(query, n); - set_prediction + prediction } } diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs index 471e263..d111600 100644 --- a/src/predictors/mod.rs +++ b/src/predictors/mod.rs @@ -6,7 +6,7 @@ pub mod composite; pub trait Predictor { - fn predict(&self, query: &str, n: usize) -> String; + fn predict(&self, query: &str, n: usize) -> Vec<String>; fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>; fn decay(&mut self) -> (); -- GitLab