From 674f93f03d7851cbd8e6efdc68f9cf0bd1fd9c58 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Phil=20H=C3=B6fer?= <phil.hoefer@suma-ev.de>
Date: Fri, 12 Jul 2024 14:46:34 +0200
Subject: [PATCH] Add Basic Composite Predictor Implementation

---
 src/main.rs                    |  10 +--
 src/predictors/basic_markov.rs |   2 +-
 src/predictors/composite.rs    | 108 +++++++++++++++++++++++++++++++++
 src/predictors/mod.rs          |   1 +
 4 files changed, 115 insertions(+), 6 deletions(-)
 create mode 100644 src/predictors/composite.rs

diff --git a/src/main.rs b/src/main.rs
index 2d0a018..ec69113 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,6 +10,7 @@ use std::str::FromStr;
 use std::fs;
 
 use predictors::basic_set::SetPredictor;
+use predictors::composite::CompositePredictor;
 use toml::Table;
 
 use csv::ReaderBuilder;
@@ -18,7 +19,7 @@ use csv::ReaderBuilder;
 use tiny_http::{Server, Response};
 
 mod predictors;
-use predictors::{basic_markov,basic_set, Predictor};
+use predictors::{basic_markov, basic_set, composite, Predictor};
 use predictors::basic_markov::MarkovChain;
 
 use postgres::{Client, NoTls};
@@ -44,12 +45,11 @@ fn main() -> Result<(), io::Error> {
     }
 
 
-    let mut markov_chain = basic_set::from_file_path_and_config(
+    let mut markov_chain = composite::from_file_path_and_config(
         vec!["../../data/data.csv","data/data.csv","data.csv","data_full.csv"],config.clone())
-                .unwrap_or(basic_set::SetPredictor::new());
+                .unwrap_or(CompositePredictor::new());
 
     markov_chain = read_from_db(config.clone(), markov_chain);
-    markov_chain.decay();
 
     // let term_frequency_threshold = match config.get("term_frequency_threshold") {
     //     Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
@@ -122,7 +122,7 @@ fn main() -> Result<(), io::Error> {
 }
 
 
-fn read_from_db(config: HashMap<String, String>, mut predictor: SetPredictor) -> SetPredictor {
+fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredictor) -> CompositePredictor {
     let default_password = &String::from("");
     match (config.get("db_host"), config.get("db_password")) {
         (Some(db_host), Some(db_password)) => {
diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs
index fc8afcb..9e51fed 100644
--- a/src/predictors/basic_markov.rs
+++ b/src/predictors/basic_markov.rs
@@ -61,7 +61,7 @@ impl Predictor for MarkovChainPredictor {
     }
 
     fn decay(&mut self) -> () {
-        //TODO
+        self.chain = HashMap::new();
     }
 
     fn predict(&self, query: &str, n: usize) -> String {
diff --git a/src/predictors/composite.rs b/src/predictors/composite.rs
new file mode 100644
index 0000000..7e02464
--- /dev/null
+++ b/src/predictors/composite.rs
@@ -0,0 +1,108 @@
+use std::{collections::HashMap, f32::consts::E, fs::File};
+
+use csv::ReaderBuilder;
+
+use super::{basic_markov::MarkovChainPredictor, basic_set::SetPredictor, Predictor};
+
+pub struct CompositePredictor {
+    set_predictor: SetPredictor,
+    markov_predictor: MarkovChainPredictor
+}
+
+impl Predictor for CompositePredictor {
+    fn new() -> Self
+    {
+        CompositePredictor {
+            set_predictor: SetPredictor::new(),
+            markov_predictor: MarkovChainPredictor::new()
+        }
+    }
+
+    fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized {
+        let mut configuration = HashMap::new();
+        for (key, value) in config {
+            configuration.insert(key, value.into());
+        }
+        CompositePredictor {
+            set_predictor: SetPredictor::new_from_config(configuration.clone()),
+            markov_predictor: MarkovChainPredictor::new_from_config(configuration)
+        }
+    }
+
+    fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
+        self.set_predictor.update(query);
+        self.markov_predictor.update(query);
+        Ok(())
+    }
+
+    fn decay(&mut self) -> () {
+        self.set_predictor.decay();
+        self.markov_predictor.decay();
+        
+    }
+
+    fn predict(&self, query: &str, n: usize) -> String {
+        let markov_prediction = self.markov_predictor.predict(query, n);
+        if markov_prediction.len() > 7+query.len() {
+            return markov_prediction;
+        }
+        let set_prediction = self.set_predictor.predict(query, n);
+        set_prediction
+    }
+}
+
+
+
+pub fn from_file_path(file_path: &str) -> Result<CompositePredictor, std::io::Error> {
+    let file = File::open(file_path)?;
+    let mut reader = ReaderBuilder::new().from_reader(file);
+    let mut markov_chain: CompositePredictor = CompositePredictor::new();
+
+    let mut count = 0;
+    for result in reader.records() {
+        let record = result?;
+        if let Some(query) = record.get(5) {
+            markov_chain.update(query);
+            count += 1;
+        }
+    }
+    println!("{} queries read from file", count);
+
+    Ok(markov_chain)
+}
+
+
+pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String, impl Into<String>>) -> Result<CompositePredictor, std::io::Error> {
+    let mut configuration = HashMap::new();
+    for (key, value) in config {
+        configuration.insert(key, value.into());
+    }
+
+    let mut markov_chain: CompositePredictor = CompositePredictor::new_from_config(configuration);
+
+    for path in file_paths {
+        println!("Trying to open data file at {}",path);
+        match File::open(path) {
+            Ok(file) => {
+                println!("Reading data file...");
+                let mut count = 0;
+                let mut reader = ReaderBuilder::new().from_reader(file);
+
+                for result in reader.records() {
+                    let record = result?;
+                    if let Some(query) = record.get(5) {
+                        markov_chain.update(query);
+                        count += 1;
+                    }
+                }
+                println!("{} queries read from file", count);
+            },
+            Err(e) => {
+                println!("Error while reading: {}",e);
+            }
+        }
+    }
+
+
+    Ok(markov_chain)
+}
\ No newline at end of file
diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs
index c37124d..471e263 100644
--- a/src/predictors/mod.rs
+++ b/src/predictors/mod.rs
@@ -2,6 +2,7 @@ use std::collections::HashMap;
 
 pub mod basic_set;
 pub mod basic_markov;
+pub mod composite;
 
 pub trait Predictor {
 
-- 
GitLab