diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index 832b8e32e8e41ca223023ab878d70d819fee697d..be0825772abf37613bbcea7b4d0fa49f134816bf 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -77,17 +77,19 @@ impl Predictor for MarkovChainPredictor { } fn predict(&self, query: &str, n: usize) -> Vec<String> { - if let Some(top_words) = + let split_query: Vec<&str> = query.split_whitespace().collect(); + + if let Some(top_words) = get_top_following_words( &self.chain, - query.split_whitespace().last().unwrap_or(""), + split_query.last().unwrap_or(&""), n, self.configuration.get("term_frequency_threshold").unwrap_or(&String::from("2")).parse::<usize>().unwrap() ) { let predictions: Vec<String> = top_words .into_iter() - .map(|(word, _)| format!("\"{} {}\"", query, word)) + .map(|(word, _)| format!("\"{} {}\"", query.trim(), word.trim())) .collect(); return predictions; }