Skip to content
Snippets Groups Projects
Commit 674f93f0 authored by Phil Höfer
Browse files

Add Basic Composite Predictor Implementation

parent ee50ef2f
No related branches found
No related tags found
No related merge requests found
Pipeline #9976 passed
......@@ -10,6 +10,7 @@ use std::str::FromStr;
use std::fs;
use predictors::basic_set::SetPredictor;
use predictors::composite::CompositePredictor;
use toml::Table;
use csv::ReaderBuilder;
......@@ -18,7 +19,7 @@ use csv::ReaderBuilder;
use tiny_http::{Server, Response};
mod predictors;
use predictors::{basic_markov,basic_set, Predictor};
use predictors::{basic_markov, basic_set, composite, Predictor};
use predictors::basic_markov::MarkovChain;
use postgres::{Client, NoTls};
......@@ -44,12 +45,11 @@ fn main() -> Result<(), io::Error> {
}
let mut markov_chain = basic_set::from_file_path_and_config(
let mut markov_chain = composite::from_file_path_and_config(
vec!["../../data/data.csv","data/data.csv","data.csv","data_full.csv"],config.clone())
.unwrap_or(basic_set::SetPredictor::new());
.unwrap_or(CompositePredictor::new());
markov_chain = read_from_db(config.clone(), markov_chain);
markov_chain.decay();
// let term_frequency_threshold = match config.get("term_frequency_threshold") {
// Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
......@@ -122,7 +122,7 @@ fn main() -> Result<(), io::Error> {
}
fn read_from_db(config: HashMap<String, String>, mut predictor: SetPredictor) -> SetPredictor {
fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredictor) -> CompositePredictor {
let default_password = &String::from("");
match (config.get("db_host"), config.get("db_password")) {
(Some(db_host), Some(db_password)) => {
......
......@@ -61,7 +61,7 @@ impl Predictor for MarkovChainPredictor {
}
/// Resets the predictor by replacing `self.chain` with a new, empty `HashMap`.
fn decay(&mut self) -> () {
// NOTE(review): the chain is discarded outright here; presumably a gradual
// decay is still to be implemented (the original marker was a bare TODO) — confirm.
self.chain = HashMap::new();
}
fn predict(&self, query: &str, n: usize) -> String {
......
use std::{collections::HashMap, f32::consts::E, fs::File};
use csv::ReaderBuilder;
use super::{basic_markov::MarkovChainPredictor, basic_set::SetPredictor, Predictor};
/// Predictor that owns one `SetPredictor` and one `MarkovChainPredictor`
/// and forwards updates, decay and predictions to them.
pub struct CompositePredictor {
// Set-based predictor; its prediction is returned when the Markov
// prediction is not long enough (see `predict`).
set_predictor: SetPredictor,
// Markov-chain predictor; asked first in `predict`.
markov_predictor: MarkovChainPredictor
}
/// `Predictor` implementation that delegates to an inner `SetPredictor` and an
/// inner `MarkovChainPredictor`.
impl Predictor for CompositePredictor {
    /// Creates a composite whose inner predictors are both built with their `new()`.
    fn new() -> Self {
        CompositePredictor {
            set_predictor: SetPredictor::new(),
            markov_predictor: MarkovChainPredictor::new(),
        }
    }

    /// Creates a composite whose inner predictors are both built from the same
    /// configuration map.
    ///
    /// The values are converted to owned `String`s first so that the map can be
    /// cloned and handed to each inner predictor.
    fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self
    where
        Self: Sized,
    {
        let configuration: HashMap<String, String> = config
            .into_iter()
            .map(|(key, value)| (key, value.into()))
            .collect();

        CompositePredictor {
            set_predictor: SetPredictor::new_from_config(configuration.clone()),
            markov_predictor: MarkovChainPredictor::new_from_config(configuration),
        }
    }

    /// Feeds `query` to both inner predictors.
    ///
    /// # Errors
    ///
    /// Both predictors are always updated, even if the first update fails.
    /// Afterwards the set predictor's error is returned if there is one,
    /// otherwise the Markov predictor's error.
    fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
        // Run both updates before inspecting the results so that a failure in
        // one predictor does not keep the other from seeing the query.
        let set_result = self.set_predictor.update(query);
        let markov_result = self.markov_predictor.update(query);
        set_result?;
        markov_result?;
        Ok(())
    }

    /// Calls `decay` on both inner predictors.
    fn decay(&mut self) -> () {
        self.set_predictor.decay();
        self.markov_predictor.decay();
    }

    /// Returns a prediction for `query`.
    ///
    /// The Markov prediction is returned when its byte length exceeds the byte
    /// length of `query` by more than seven; otherwise the set predictor's
    /// prediction is returned. `n` is passed through unchanged to whichever
    /// inner predictor is asked.
    fn predict(&self, query: &str, n: usize) -> String {
        // Minimum number of bytes the Markov prediction must add beyond the
        // query before it is preferred over the set predictor.
        const MIN_MARKOV_EXTENSION: usize = 7;

        let markov_prediction = self.markov_predictor.predict(query, n);
        if markov_prediction.len() > MIN_MARKOV_EXTENSION + query.len() {
            markov_prediction
        } else {
            self.set_predictor.predict(query, n)
        }
    }
}
/// Builds a `CompositePredictor` from the CSV file at `file_path`.
///
/// For every record, the field at index 5 is fed to the predictor as a query;
/// records without that field are skipped. A query whose update fails is
/// reported on stdout and not counted, and reading continues.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or a CSV record cannot be read.
pub fn from_file_path(file_path: &str) -> Result<CompositePredictor, std::io::Error> {
    // Index of the CSV field that holds the query text.
    const QUERY_COLUMN: usize = 5;

    let file = File::open(file_path)?;
    let mut reader = ReaderBuilder::new().from_reader(file);
    let mut predictor: CompositePredictor = CompositePredictor::new();
    let mut count = 0;

    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(QUERY_COLUMN) {
            // Only count queries the predictor actually accepted.
            match predictor.update(query) {
                Ok(()) => count += 1,
                Err(e) => println!("Error while updating predictor: {}", e),
            }
        }
    }

    println!("{} queries read from file", count);
    Ok(predictor)
}
/// Builds a `CompositePredictor` from `config` and feeds it the queries found
/// in every readable CSV file listed in `file_paths`.
///
/// Each path is tried in order. A file that cannot be opened is reported on
/// stdout and skipped. For every record of an opened file, the field at index 5
/// is fed to the predictor; records without that field are skipped. A query
/// whose update fails is reported on stdout and not counted.
///
/// # Errors
///
/// Returns an error if a CSV record of an opened file cannot be read.
pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String, impl Into<String>>) -> Result<CompositePredictor, std::io::Error> {
    // Index of the CSV field that holds the query text.
    const QUERY_COLUMN: usize = 5;

    let configuration: HashMap<String, String> = config
        .into_iter()
        .map(|(key, value)| (key, value.into()))
        .collect();
    let mut predictor: CompositePredictor = CompositePredictor::new_from_config(configuration);

    for path in file_paths {
        println!("Trying to open data file at {}",path);
        match File::open(path) {
            Ok(file) => {
                println!("Reading data file...");
                let mut count = 0;
                let mut reader = ReaderBuilder::new().from_reader(file);
                for result in reader.records() {
                    let record = result?;
                    if let Some(query) = record.get(QUERY_COLUMN) {
                        // Only count queries the predictor actually accepted.
                        match predictor.update(query) {
                            Ok(()) => count += 1,
                            Err(e) => println!("Error while updating predictor: {}", e),
                        }
                    }
                }
                println!("{} queries read from file", count);
            },
            Err(e) => {
                // A missing candidate file is expected; move on to the next path.
                println!("Error while reading: {}",e);
            }
        }
    }

    Ok(predictor)
}
\ No newline at end of file
......@@ -2,6 +2,7 @@ use std::collections::HashMap;
pub mod basic_set;
pub mod basic_markov;
pub mod composite;
pub trait Predictor {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment