diff --git a/src/main.rs b/src/main.rs index b5eea6ebab6a4eb91f1dff7f02e2b3f72a83f6af..007f480f9f160959fc804ed99d914e3f44a7aa7d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -148,6 +148,23 @@ fn main() -> Result<(), io::Error> { fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredictor) -> CompositePredictor { + + let mut raw_list = String::new(); + let blocklist: Vec<&str> = match config.get("blocklist") { + Some(list) => { + match fs::read_to_string(list) { + Ok(l) => { + raw_list = l.clone(); + raw_list.lines().collect() + }, + _ => Vec::new() + } + }, + _ => Vec::new() + + }; + + let default_password = &String::from(""); match (config.get("db_host"), config.get("db_password")) { (Some(db_host), Some(db_password)) => { @@ -159,7 +176,7 @@ fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredict Ok(rows) => { for row in rows { let query: &str = row.get(0); - predictor.update_from_query(query); + predictor.update_from_query(query, &blocklist); count += 1; } println!("{} queries read from DB", count); diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index 3f0e400dcaf225abfe6996eda4d4ca8c03cc75fe..ed4bcc3eb8620e58529ab911922eb80dff648910 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -35,18 +35,7 @@ impl Predictor for MarkovChainPredictor { } } - fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>> { - let data = importer.fetch_queries(count)?; - let mut count = 0; - for q in data.iter() { - self.update_from_query(&q.query.as_str()); - count += 1; - } - Ok(count) - } - - fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { - let mut raw_list = String::new(); + fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>> {let mut raw_list = String::new(); let blocklist: Vec<&str> = match self.configuration.get("blocklist") { Some(list) => { match fs::read_to_string(list) { @@ -61,6 +50,16 @@ impl Predictor for MarkovChainPredictor { }; + let data = importer.fetch_queries(count)?; + let mut count = 0; + for q in data.iter() { + self.update_from_query(&q.query.as_str(), &blocklist); + count += 1; + } + Ok(count) + } + + fn update_from_query(&mut self, query: &str, blocklist: &Vec<&str>) -> Result<(), Box<dyn std::error::Error>> { for word in blocklist.clone() { if word.trim().len() > 1 && query.to_lowercase().contains(String::from(word).to_lowercase().as_str()) { return Ok(()); diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs index a97196b38f53e8502d2332cb70cad4f7aa2d329a..05edb13078b99749d49b5057d2906fd32a888761 100644 --- a/src/predictors/basic_set.rs +++ b/src/predictors/basic_set.rs @@ -34,17 +34,7 @@ impl Predictor for SetPredictor { } fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>> { - let data = importer.fetch_queries(count)?; - let mut count = 0; - for q in data.iter() { - self.update_from_query(&q.query.as_str()); - count += 1; - } - Ok(count) - } - - fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { let mut raw_list = String::new(); let blocklist: Vec<&str> = match self.configuration.get("blocklist") { Some(list) => { @@ -60,6 +50,18 @@ impl Predictor for SetPredictor { }; + let data = importer.fetch_queries(count)?; + let mut count = 0; + + for q in data.iter() { + self.update_from_query(&q.query.as_str(), &blocklist); + count += 1; + } + Ok(count) + } + + fn update_from_query(&mut self, query: &str, blocklist: &Vec<&str>) -> Result<(), Box<dyn std::error::Error>> { + for word in blocklist.clone() { if word.trim().len() > 1 && query.to_lowercase().contains(String::from(word).to_lowercase().as_str()) { return Ok(()); diff --git a/src/predictors/composite.rs b/src/predictors/composite.rs index eef3052950d00fd6794b5c335ca8455b497cf0d0..10494acb66c565ae3c884278694abbdfb242fa06 100644 --- a/src/predictors/composite.rs +++ b/src/predictors/composite.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, f32::consts::E, fs::File}; +use std::{collections::HashMap, f32::consts::E, fs::{self, File}}; use crate::importers::Importer; @@ -7,6 +7,7 @@ use csv::ReaderBuilder; use super::{basic_markov::MarkovChainPredictor, basic_set::SetPredictor, Predictor}; pub struct CompositePredictor { + configuration: HashMap<String, String>, set_predictor: SetPredictor, markov_predictor: MarkovChainPredictor } @@ -15,6 +16,7 @@ impl Predictor for CompositePredictor { fn new() -> Self { CompositePredictor { + configuration: HashMap::new(), set_predictor: SetPredictor::new(), markov_predictor: MarkovChainPredictor::new() } @@ -26,6 +28,7 @@ impl Predictor for CompositePredictor { configuration.insert(key, value.into()); } CompositePredictor { + configuration: configuration.clone(), set_predictor: SetPredictor::new_from_config(configuration.clone()), markov_predictor: MarkovChainPredictor::new_from_config(configuration) } @@ -33,18 +36,35 @@ impl Predictor for CompositePredictor { fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>> { + + + let mut raw_list = String::new(); + let blocklist: Vec<&str> = match self.configuration.get("blocklist") { + Some(list) => { + match fs::read_to_string(list) { + Ok(l) => { + raw_list = l.clone(); + raw_list.lines().collect() + }, + _ => Vec::new() + } + }, + _ => Vec::new() + + }; + let data = importer.fetch_queries(count)?; let mut count = 0; for q in data.iter() { - self.update_from_query(&q.query.as_str()); + self.update_from_query(&q.query.as_str(), &blocklist); count += 1; } Ok(count) } - fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { - self.set_predictor.update_from_query(query); - self.markov_predictor.update_from_query(query); + fn update_from_query(&mut self, query: &str, blocklist: &Vec<&str>) -> Result<(), Box<dyn std::error::Error>> { + self.set_predictor.update_from_query(query, blocklist); + self.markov_predictor.update_from_query(query, blocklist); Ok(()) } diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs index 143b9694521d69848f17462665222e5b93dbf297..e8d1def0721d3cefb7625996b7f5c684936d4d3b 100644 --- a/src/predictors/mod.rs +++ b/src/predictors/mod.rs @@ -8,7 +8,7 @@ pub mod composite; pub trait Predictor { fn predict(&self, query: &str, n: usize) -> Vec<String>; - fn update_from_query(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>; + fn update_from_query(&mut self, query: &str, blocklist: &Vec<&str>) -> Result<(),Box<dyn std::error::Error>>; fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>>; fn decay(&mut self) -> ();