From 65dab3b631706f8481bdc8d12e338f76a1b1cb6e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Phil=20H=C3=B6fer?= <phil.hoefer@suma-ev.de>
Date: Fri, 12 Jul 2024 15:10:40 +0200
Subject: [PATCH] Refactor Internal Response API

---
 src/main.rs                    |  5 +++--
 src/predictors/basic_markov.rs |  6 +++---
 src/predictors/basic_set.rs    |  6 +++---
 src/predictors/composite.rs    | 12 ++++++------
 src/predictors/mod.rs          |  2 +-
 5 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index ec69113..00ea01e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -100,9 +100,10 @@ fn main() -> Result<(), io::Error> {
                     Some(n) if n.parse::<usize>().unwrap_or_default() > 0 => n.parse::<usize>().unwrap_or(5),
                     _ => 5
                 };
-                let prediction = &markov_chain.predict(&query, predict_count);
+                let predictions = &markov_chain.predict(&query, predict_count);
                 //println!("Query: {}, Prediction:{}", query, prediction);
-                let response = Response::from_string(prediction);
+                let response = Response::from_string(
+                    format!("[\"{}\",[{}]]", query, predictions.join(",")));
                 request.respond(response);
                 let now = std::time::SystemTime::now();
                 let elapsed = now.duration_since(last_update).expect("Time went backwards");
diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs
index 9e51fed..1c2f908 100644
--- a/src/predictors/basic_markov.rs
+++ b/src/predictors/basic_markov.rs
@@ -64,7 +64,7 @@ impl Predictor for MarkovChainPredictor {
         self.chain = HashMap::new();
     }
 
-    fn predict(&self, query: &str, n: usize) -> String {
+    fn predict(&self, query: &str, n: usize) -> Vec<String> {
         if let Some(top_words) =
             get_top_following_words(
                 &self.chain,
@@ -77,9 +77,9 @@ impl Predictor for MarkovChainPredictor {
                 .into_iter()
                 .map(|(word, _)| format!("\"{} {}\"", query, word))
                 .collect();
-            return format!("[\"{}\",[{}]]", query, predictions.join(","));
+            return predictions;
         }
-        String::new()
+        Vec::<String>::new()
     }
 }
 
diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs
index c490b6a..3b66c2c 100644
--- a/src/predictors/basic_set.rs
+++ b/src/predictors/basic_set.rs
@@ -85,7 +85,7 @@ impl Predictor for SetPredictor {
         
     }
 
-    fn predict(&self, query: &str, n: usize) -> String {
+    fn predict(&self, query: &str, n: usize) -> Vec<String> {
         if let Some(top_words) =
             get_top_completions(
                 &self,
@@ -99,9 +99,9 @@ impl Predictor for SetPredictor {
                 .into_iter()
                 .map(|(word, _)| format!("\"{}\"", format!("{} {}",query_prefix, word).trim()))
                 .collect();
-            return format!("[\"{}\",[{}]]", query, predictions.join(","));
+            return predictions;
         }
-        String::new()
+        Vec::<String>::new()
     }
 }
 
diff --git a/src/predictors/composite.rs b/src/predictors/composite.rs
index 7e02464..97b08a7 100644
--- a/src/predictors/composite.rs
+++ b/src/predictors/composite.rs
@@ -41,13 +41,13 @@ impl Predictor for CompositePredictor {
         
     }
 
-    fn predict(&self, query: &str, n: usize) -> String {
-        let markov_prediction = self.markov_predictor.predict(query, n);
-        if markov_prediction.len() > 7+query.len() {
-            return markov_prediction;
+    fn predict(&self, query: &str, n: usize) -> Vec<String> {
+        let mut prediction = self.markov_predictor.predict(query, n);
+        if prediction.len() < n {
+            prediction.append(&mut self.set_predictor.predict(query, n));
+            prediction.truncate(n);
         }
-        let set_prediction = self.set_predictor.predict(query, n);
-        set_prediction
+        prediction
     }
 }
 
diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs
index 471e263..d111600 100644
--- a/src/predictors/mod.rs
+++ b/src/predictors/mod.rs
@@ -6,7 +6,7 @@ pub mod composite;
 
 pub trait Predictor {
 
-    fn predict(&self, query: &str, n: usize) -> String;
+    fn predict(&self, query: &str, n: usize) -> Vec<String>;
     fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>;
     fn decay(&mut self) -> ();
 
-- 
GitLab