From 674f93f03d7851cbd8e6efdc68f9cf0bd1fd9c58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Phil=20H=C3=B6fer?= <phil.hoefer@suma-ev.de> Date: Fri, 12 Jul 2024 14:46:34 +0200 Subject: [PATCH] Add Basic Composite Predictor Implementation --- src/main.rs | 10 +-- src/predictors/basic_markov.rs | 2 +- src/predictors/composite.rs | 108 +++++++++++++++++++++++++++++++++ src/predictors/mod.rs | 1 + 4 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 src/predictors/composite.rs diff --git a/src/main.rs b/src/main.rs index 2d0a018..ec69113 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ use std::str::FromStr; use std::fs; use predictors::basic_set::SetPredictor; +use predictors::composite::CompositePredictor; use toml::Table; use csv::ReaderBuilder; @@ -18,7 +19,7 @@ use csv::ReaderBuilder; use tiny_http::{Server, Response}; mod predictors; -use predictors::{basic_markov,basic_set, Predictor}; +use predictors::{basic_markov, basic_set, composite, Predictor}; use predictors::basic_markov::MarkovChain; use postgres::{Client, NoTls}; @@ -44,12 +45,11 @@ fn main() -> Result<(), io::Error> { } - let mut markov_chain = basic_set::from_file_path_and_config( + let mut markov_chain = composite::from_file_path_and_config( vec!["../../data/data.csv","data/data.csv","data.csv","data_full.csv"],config.clone()) - .unwrap_or(basic_set::SetPredictor::new()); + .unwrap_or(CompositePredictor::new()); markov_chain = read_from_db(config.clone(), markov_chain); - markov_chain.decay(); // let term_frequency_threshold = match config.get("term_frequency_threshold") { // Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize, @@ -122,7 +122,7 @@ fn main() -> Result<(), io::Error> { } -fn read_from_db(config: HashMap<String, String>, mut predictor: SetPredictor) -> SetPredictor { +fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredictor) -> CompositePredictor { let default_password = &String::from(""); match (config.get("db_host"), config.get("db_password")) { (Some(db_host), Some(db_password)) => { diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index fc8afcb..9e51fed 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -61,7 +61,7 @@ impl Predictor for MarkovChainPredictor { } fn decay(&mut self) -> () { - //TODO + self.chain = HashMap::new(); } fn predict(&self, query: &str, n: usize) -> String { diff --git a/src/predictors/composite.rs b/src/predictors/composite.rs new file mode 100644 index 0000000..7e02464 --- /dev/null +++ b/src/predictors/composite.rs @@ -0,0 +1,108 @@ +use std::{collections::HashMap, f32::consts::E, fs::File}; + +use csv::ReaderBuilder; + +use super::{basic_markov::MarkovChainPredictor, basic_set::SetPredictor, Predictor}; + +pub struct CompositePredictor { + set_predictor: SetPredictor, + markov_predictor: MarkovChainPredictor +} + +impl Predictor for CompositePredictor { + fn new() -> Self + { + CompositePredictor { + set_predictor: SetPredictor::new(), + markov_predictor: MarkovChainPredictor::new() + } + } + + fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized { + let mut configuration = HashMap::new(); + for (key, value) in config { + configuration.insert(key, value.into()); + } + CompositePredictor { + set_predictor: SetPredictor::new_from_config(configuration.clone()), + markov_predictor: MarkovChainPredictor::new_from_config(configuration) + } + } + + fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { + self.set_predictor.update(query); + self.markov_predictor.update(query); + Ok(()) + } + + fn decay(&mut self) -> () { + self.set_predictor.decay(); + self.markov_predictor.decay(); + + } + + fn predict(&self, query: &str, n: usize) -> String { + let markov_prediction = self.markov_predictor.predict(query, n); + if markov_prediction.len() > 7+query.len() { + return markov_prediction; + } + let set_prediction = self.set_predictor.predict(query, n); + set_prediction + } +} + + + +pub fn from_file_path(file_path: &str) -> Result<CompositePredictor, std::io::Error> { + let file = File::open(file_path)?; + let mut reader = ReaderBuilder::new().from_reader(file); + let mut markov_chain: CompositePredictor = CompositePredictor::new(); + + let mut count = 0; + for result in reader.records() { + let record = result?; + if let Some(query) = record.get(5) { + markov_chain.update(query); + count += 1; + } + } + println!("{} queries read from file", count); + + Ok(markov_chain) +} + + +pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String, impl Into<String>>) -> Result<CompositePredictor, std::io::Error> { + let mut configuration = HashMap::new(); + for (key, value) in config { + configuration.insert(key, value.into()); + } + + let mut markov_chain: CompositePredictor = CompositePredictor::new_from_config(configuration); + + for path in file_paths { + println!("Trying to open data file at {}",path); + match File::open(path) { + Ok(file) => { + println!("Reading data file..."); + let mut count = 0; + let mut reader = ReaderBuilder::new().from_reader(file); + + for result in reader.records() { + let record = result?; + if let Some(query) = record.get(5) { + markov_chain.update(query); + count += 1; + } + } + println!("{} queries read from file", count); + }, + Err(e) => { + println!("Error while reading: {}",e); + } + } + } + + + Ok(markov_chain) +} \ No newline at end of file diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs index c37124d..471e263 100644 --- a/src/predictors/mod.rs +++ b/src/predictors/mod.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; pub mod basic_set; pub mod basic_markov; +pub mod composite; pub trait Predictor { -- GitLab