diff --git a/src/main.rs b/src/main.rs
index 4db7bd5afdf88113d68230ea3fcc8051166906cf..6f8d32146a34ba3f1a0e1f1b23b410c3ddb4c299 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -16,7 +16,7 @@ use csv::ReaderBuilder;
 use tiny_http::{Server, Response};
 
 mod predictors;
-use predictors::{basic_markov, Predictor};
+use predictors::{basic_markov, basic_set, Predictor};
 use predictors::basic_markov::MarkovChain;
 
 
@@ -38,6 +38,6 @@ fn main() -> Result<(), io::Error> {
         }
     }
 
-    let markov_chain = basic_markov::from_file_path_and_config("../../data/data.csv",config.clone())
-        .unwrap_or(basic_markov::from_file_path_and_config("data.csv",config.clone())
-        .unwrap());
+    let markov_chain = basic_set::from_file_path_and_config("../../data/data.csv",config.clone())
+        .unwrap_or_else(|_| basic_set::from_file_path_and_config("data.csv",config.clone())
+        .unwrap());
diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..56bc0bee94764f1ba07df1992681919b36e51f61 100644
--- a/src/predictors/basic_set.rs
+++ b/src/predictors/basic_set.rs
@@ -0,0 +1,133 @@
+use std::{collections::HashMap, fs::File};
+
+use csv::ReaderBuilder;
+
+use super::Predictor;
+
+/// Word-frequency completion predictor: counts word occurrences seen via
+/// `update` and completes the last word of a query from the most frequent matches.
+pub struct SetPredictor {
+    /// word -> number of times it has been seen.
+    set: HashMap<String, usize>,
+    /// Free-form settings (e.g. "blocked_words", "term_frequency_threshold").
+    configuration: HashMap<String, String>,
+}
+
+impl Predictor for SetPredictor {
+    fn new() -> Self {
+        SetPredictor {
+            set: HashMap::new(),
+            configuration: HashMap::new(),
+        }
+    }
+
+    fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self: Sized {
+        let mut configuration = HashMap::new();
+        for (key, value) in config {
+            configuration.insert(key, value.into());
+        }
+
+        SetPredictor {
+            set: HashMap::new(),
+            configuration,
+        }
+    }
+
+    /// Count every lowercased, non-blocklisted word of `query`.
+    fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
+        // "blocked_words" is a whitespace-separated list of words to skip.
+        let blocklist: Vec<&str> = match self.configuration.get("blocked_words") {
+            Some(list) => list.split_whitespace().collect(),
+            _ => Vec::new(),
+        };
+
+        let lowercase_query = query.to_lowercase();
+        for word in lowercase_query
+            .split_whitespace()
+            .filter(|word| !blocklist.contains(word))
+        {
+            *self.set.entry(word.to_string()).or_insert(0) += 1;
+        }
+
+        Ok(())
+    }
+
+    /// Complete the last word of `query`, returning an OpenSearch-style JSON
+    /// suggestion list, or an empty string when nothing qualifies.
+    fn predict(&self, query: &str, n: usize) -> String {
+        // A missing or malformed threshold falls back to 2 instead of panicking.
+        let min_freq = self
+            .configuration
+            .get("term_frequency_threshold")
+            .and_then(|v| v.parse::<usize>().ok())
+            .unwrap_or(2);
+
+        if let Some(top_words) =
+            get_top_completions(self, query.split_whitespace().last().unwrap_or(""), n, min_freq)
+        {
+            let query_prefix = query.rsplit_once(' ').map_or("", |(head, _)| head).to_string();
+            let predictions: Vec<String> = top_words
+                .into_iter()
+                .map(|(word, _)| format!("\"{} {}\"", query_prefix, word))
+                .collect();
+            return format!("[\"{}\",[{}]]", query, predictions.join(","));
+        }
+        String::new()
+    }
+}
+
+/// The `top_n` most frequent words starting with `word` that were seen at
+/// least `min_freq` times, most frequent first.
+fn get_top_completions(
+    predictor: &SetPredictor,
+    word: &str,
+    top_n: usize,
+    min_freq: usize,
+) -> Option<Vec<(String, usize)>> {
+    let mut matches: Vec<(String, usize)> = predictor
+        .set
+        .iter()
+        .filter(|(key, &value)| key.starts_with(word) && value >= min_freq)
+        .map(|(key, &value)| (key.clone(), value))
+        .collect();
+    // HashMap iteration order is arbitrary; sort so "top" really means the
+    // most frequent completions, not whichever n happened to come out first.
+    matches.sort_unstable_by(|a, b| b.1.cmp(&a.1));
+    matches.truncate(top_n);
+    Some(matches)
+}
+
+pub fn from_file_path(file_path: &str) -> Result<SetPredictor, std::io::Error> {
+    let file = File::open(file_path)?;
+    let mut reader = ReaderBuilder::new().from_reader(file);
+    let mut predictor = SetPredictor::new();
+
+    for result in reader.records() {
+        let record = result?;
+        // Column 5 holds the raw query text in the source CSV.
+        if let Some(query) = record.get(5) {
+            // update() currently cannot fail; discard the Result explicitly.
+            let _ = predictor.update(query);
+        }
+    }
+
+    Ok(predictor)
+}
+
+pub fn from_file_path_and_config(file_path: &str, config: HashMap<String, impl Into<String>>) -> Result<SetPredictor, std::io::Error> {
+    let mut configuration = HashMap::new();
+    for (key, value) in config {
+        configuration.insert(key, value.into());
+    }
+    let file = File::open(file_path)?;
+    let mut reader = ReaderBuilder::new().from_reader(file);
+    let mut predictor = SetPredictor::new();
+    predictor.configuration = configuration;
+
+    for result in reader.records() {
+        let record = result?;
+        if let Some(query) = record.get(5) {
+            let _ = predictor.update(query);
+        }
+    }
+
+    Ok(predictor)
+}