diff --git a/src/importers/file.rs b/src/importers/file.rs
new file mode 100644
index 0000000000000000000000000000000000000000..8736901091faeaa612dd1b40ddfc8fd471a3a079
--- /dev/null
+++ b/src/importers/file.rs
@@ -0,0 +1,30 @@
+use std::io::{BufRead, BufReader};
+
+use super::{Importer, SearchQuery};
+
+pub struct FileImporter {
+    reader: BufReader<std::fs::File>,
+}
+
+impl FileImporter {
+    pub fn new(file: std::fs::File) -> Self {
+        FileImporter { reader: BufReader::new(file) }
+    }
+}
+
+impl Importer for FileImporter {
+    fn fetch_queries(&mut self, n: usize) -> Result<Vec<SearchQuery>, String> {
+        // Read up to `n` lines from the file, one query per line.
+        let mut entries: Vec<SearchQuery> = Vec::new();
+        for line in (&mut self.reader).lines().take(n) {
+            let query = line.map_err(|e| e.to_string())?;
+            entries.push(SearchQuery { query });
+        }
+
+        if entries.len() < n {
+            Err(String::from("Requested number of entries exceeds available data"))
+        } else {
+            Ok(entries)
+        }
+    }
+}
diff --git a/src/importers/mod.rs b/src/importers/mod.rs
new file mode 100644
index 0000000000000000000000000000000000000000..a0ec7d6cadac46f41377c7de805960ed1bc9a0c0
--- /dev/null
+++ b/src/importers/mod.rs
@@ -0,0 +1,10 @@
+pub mod file;
+
+#[derive(Debug)]
+pub struct SearchQuery {
+    pub query: String,
+}
+
+pub trait Importer {
+    fn fetch_queries(&mut self, n: usize) -> Result<Vec<SearchQuery>, String>;
+}
diff --git a/src/main.rs b/src/main.rs
index 2804b01ccc30969a919a6d82b5dbb64842e4dd93..5c081a87781d47301c514b870cfbd0409bdf3ce2 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -18,6 +18,8 @@
 use csv::ReaderBuilder;
 use tiny_http::{Server, Response};
 
+mod importers;
+
 mod predictors;
 use predictors::{basic_markov, basic_set, composite, Predictor};
 use predictors::basic_markov::MarkovChain;
@@ -139,7 +141,7 @@ fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredict
         Ok(rows) => {
             for row in rows {
                 let query: &str = row.get(0);
-                predictor.update(query);
+                predictor.update_from_query(query);
                 count += 1;
             }
             println!("{} queries read from DB", count);
diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs
index 1c2f908281e5b8f01acc913ff15b4c8b7ea2cf89..57a673cfbb37dd676120486835df400d5c4459d9 100644
--- a/src/predictors/basic_markov.rs
+++ b/src/predictors/basic_markov.rs
@@ -2,6 +2,8 @@
 use std::{collections::HashMap, fs::File};
 use csv::ReaderBuilder;
 
+use crate::importers::Importer;
+
 use super::Predictor;
 
 pub type MarkovChain = HashMap<String, HashMap<String, usize>>;
@@ -33,7 +35,15 @@
         }
     }
 
-    fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
+    fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(), Box<dyn std::error::Error>> {
+        let data = importer.fetch_queries(count)?;
+        for q in data.iter() {
+            self.update_from_query(q.query.as_str())?;
+        }
+        Ok(())
+    }
+
+    fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
         let blocklist: Vec<&str> = match self.configuration.get("blocked_words") {
             Some(list) => {
                 list.split_whitespace().collect()
@@ -112,7 +122,7 @@ pub fn from_file_path(file_path: &str) -> Result<MarkovChainPredictor, std::io::
     for result in reader.records() {
         let record = result?;
         if let Some(query) = record.get(5) {
-            markov_chain.update(query);
+            markov_chain.update_from_query(query);
         }
     }
 
@@ -139,7 +149,7 @@ pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String,
     for result in reader.records() {
         let record = result?;
         if let Some(query) = record.get(5) {
-            markov_chain.update(query);
+            markov_chain.update_from_query(query);
         }
     }
 },
diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs
index 3b66c2ca02f178657269c22c36a1c3a28d5e51d6..d01527b8f3cf5af91b7e6b37ed37e1773e19c1f4 100644
--- a/src/predictors/basic_set.rs
+++ b/src/predictors/basic_set.rs
@@ -1,5 +1,7 @@
 use std::{collections::HashMap, f32::consts::E, fs::File};
 
+use crate::importers::Importer;
+
 use csv::ReaderBuilder;
 
 use super::Predictor;
@@ -31,7 +33,15 @@
         }
     }
 
-    fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
+    fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(), Box<dyn std::error::Error>> {
+        let data = importer.fetch_queries(count)?;
+        for q in data.iter() {
+            self.update_from_query(q.query.as_str())?;
+        }
+        Ok(())
+    }
+
+    fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
         let blocklist: Vec<&str> = match self.configuration.get("blocked_words") {
             Some(list) => {
                 list.split_whitespace().collect()
@@ -130,7 +140,7 @@ pub fn from_file_path(file_path: &str) -> Result<SetPredictor, std::io::Error> {
     for result in reader.records() {
         let record = result?;
         if let Some(query) = record.get(5) {
-            markov_chain.update(query);
+            markov_chain.update_from_query(query);
             count += 1;
         }
     }
@@ -160,7 +170,7 @@ pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String,
     for result in reader.records() {
         let record = result?;
         if let Some(query) = record.get(5) {
-            markov_chain.update(query);
+            markov_chain.update_from_query(query);
             count += 1;
         }
     }
diff --git a/src/predictors/composite.rs b/src/predictors/composite.rs
index 97b08a7915ca39e11ad35ca3d8a8830c67712f11..ecb20e4b3546a9a0e48c5ad530ce8387d4640719 100644
--- a/src/predictors/composite.rs
+++ b/src/predictors/composite.rs
@@ -1,5 +1,7 @@
 use std::{collections::HashMap, f32::consts::E, fs::File};
 
+use crate::importers::Importer;
+
 use csv::ReaderBuilder;
 
 use super::{basic_markov::MarkovChainPredictor, basic_set::SetPredictor, Predictor};
@@ -29,9 +31,17 @@
         }
     }
 
-    fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
-        self.set_predictor.update(query);
-        self.markov_predictor.update(query);
+    fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(), Box<dyn std::error::Error>> {
+        let data = importer.fetch_queries(count)?;
+        for q in data.iter() {
+            self.update_from_query(q.query.as_str())?;
+        }
+        Ok(())
+    }
+
+    fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
+        self.set_predictor.update_from_query(query)?;
+        self.markov_predictor.update_from_query(query)?;
         Ok(())
     }
 
@@ -62,7 +72,7 @@ pub fn from_file_path(file_path: &str) -> Result<CompositePredictor, std::io::Er
     for result in reader.records() {
         let record = result?;
         if let Some(query) = record.get(5) {
-            markov_chain.update(query);
+            markov_chain.update_from_query(query);
             count += 1;
         }
     }
@@ -91,7 +101,7 @@ pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String,
     for result in reader.records() {
         let record = result?;
         if let Some(query) = record.get(5) {
-            markov_chain.update(query);
+            markov_chain.update_from_query(query);
             count += 1;
         }
     }
diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs
index d111600a8ea24d422c3b7407738f71687bbfbd53..eba5fefd6055b8cc37c566b772b9e780d65aec41 100644
--- a/src/predictors/mod.rs
+++ b/src/predictors/mod.rs
@@ -1,4 +1,5 @@
 use std::collections::HashMap;
+use crate::importers::Importer;
 
 pub mod basic_set;
 pub mod basic_markov;
@@ -7,7 +8,8 @@
 pub mod composite;
 
 pub trait Predictor {
     fn predict(&self, query: &str, n: usize) -> Vec<String>;
-    fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>;
+    fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>>;
+    fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(), Box<dyn std::error::Error>>;
     fn decay(&mut self) -> ();
     fn new() -> Self where Self: Sized;
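
Usage note (not part of the patch): a minimal sketch, under the assumptions of the code above, of how the new importer API is driven from the crate root (main.rs), where `mod importers;` and `mod predictors;` are declared. The `queries.txt` path is a placeholder; the sketch relies on the `FileImporter::new` constructor from src/importers/file.rs and the `update_from_importer` method added to the `Predictor` trait.

    use std::fs::File;

    use importers::file::FileImporter;
    use predictors::{composite::CompositePredictor, Predictor};

    fn seed_from_file() -> Result<(), Box<dyn std::error::Error>> {
        // Placeholder input file: one search query per line.
        let file = File::open("queries.txt")?;
        let mut importer = FileImporter::new(file);

        let mut predictor = CompositePredictor::new();
        // Pull 100 queries through the Importer trait and train both
        // sub-predictors; this errors if the file holds fewer than 100
        // lines, since fetch_queries reports that case as an Err(String).
        predictor.update_from_importer(&mut importer, 100)?;
        Ok(())
    }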