diff --git a/src/main.rs b/src/main.rs index ec69113d3d14f93ff8c1f6cd5245010134cf2a28..00ea01e5066f2cea3e7a85230f99090cb6e72697 100644 --- a/src/main.rs +++ b/src/main.rs @@ -100,9 +100,10 @@ fn main() -> Result<(), io::Error> { Some(n) if n.parse::<usize>().unwrap_or_default() > 0 => n.parse::<usize>().unwrap_or(5), _ => 5 }; - let prediction = &markov_chain.predict(&query, predict_count); + let predictions = &markov_chain.predict(&query, predict_count); //println!("Query: {}, Prediction:{}", query, prediction); - let response = Response::from_string(prediction); + let response = Response::from_string( + format!("[\"{}\",[{}]]", query, predictions.join(","))); request.respond(response); let now = std::time::SystemTime::now(); let elapsed = now.duration_since(last_update).expect("Time went backwards"); diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index 9e51fed22301bc2522c2bb0966c6bf9545b74174..1c2f908281e5b8f01acc913ff15b4c8b7ea2cf89 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -64,7 +64,7 @@ impl Predictor for MarkovChainPredictor { self.chain = HashMap::new(); } - fn predict(&self, query: &str, n: usize) -> String { + fn predict(&self, query: &str, n: usize) -> Vec<String> { if let Some(top_words) = get_top_following_words( &self.chain, @@ -77,9 +77,9 @@ impl Predictor for MarkovChainPredictor { .into_iter() .map(|(word, _)| format!("\"{} {}\"", query, word)) .collect(); - return format!("[\"{}\",[{}]]", query, predictions.join(",")); + return predictions; } - String::new() + Vec::<String>::new() } } diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs index c490b6ace42a770812fc53864ef5d2dbdc9205f7..3b66c2ca02f178657269c22c36a1c3a28d5e51d6 100644 --- a/src/predictors/basic_set.rs +++ b/src/predictors/basic_set.rs @@ -85,7 +85,7 @@ impl Predictor for SetPredictor { } - fn predict(&self, query: &str, n: usize) -> String { + fn predict(&self, query: &str, n: usize) -> Vec<String> { if let Some(top_words) = get_top_completions( &self, @@ -99,9 +99,9 @@ impl Predictor for SetPredictor { .into_iter() .map(|(word, _)| format!("\"{}\"", format!("{} {}",query_prefix, word).trim())) .collect(); - return format!("[\"{}\",[{}]]", query, predictions.join(",")); + return predictions; } - String::new() + Vec::<String>::new() } } diff --git a/src/predictors/composite.rs b/src/predictors/composite.rs index 7e02464878e308c74f3e6d48a2c8c884e4e1cc80..97b08a7915ca39e11ad35ca3d8a8830c67712f11 100644 --- a/src/predictors/composite.rs +++ b/src/predictors/composite.rs @@ -41,13 +41,13 @@ impl Predictor for CompositePredictor { } - fn predict(&self, query: &str, n: usize) -> String { - let markov_prediction = self.markov_predictor.predict(query, n); - if markov_prediction.len() > 7+query.len() { - return markov_prediction; + fn predict(&self, query: &str, n: usize) -> Vec<String> { + let mut prediction = self.markov_predictor.predict(query, n); + if prediction.len() < n { + prediction.append(&mut self.set_predictor.predict(query, n)); + prediction.truncate(n); } - let set_prediction = self.set_predictor.predict(query, n); - set_prediction + prediction } } diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs index 471e2634ec0981b5aab605fd3922af38d25d9b44..d111600a8ea24d422c3b7407738f71687bbfbd53 100644 --- a/src/predictors/mod.rs +++ b/src/predictors/mod.rs @@ -6,7 +6,7 @@ pub mod composite; pub trait Predictor { - fn predict(&self, query: &str, n: usize) -> String; + fn predict(&self, query: &str, n: usize) -> Vec<String>; fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>; fn decay(&mut self) -> ();