use std::{collections::HashMap, fs::File}; use csv::ReaderBuilder; use super::Predictor; pub struct SetPredictor { set: HashMap<String,usize>, configuration: HashMap<String, String>, } impl Predictor for SetPredictor { fn new() -> Self { SetPredictor { set: HashMap::new(), configuration: HashMap::new() } } fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized { let mut configuration = HashMap::new(); for (key, value) in config { configuration.insert(key, value.into()); } SetPredictor { set: HashMap::new(), configuration, } } fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { let blocklist: Vec<&str> = match self.configuration.get("blocked_words") { Some(list) => { list.split_whitespace().collect() }, _ => Vec::new() }; //println!("blocklist:{:?}",self.configuration); let lowercase_query = query.to_lowercase(); let words: Vec<&str> = lowercase_query.split_whitespace().filter(|&x| !blocklist.contains(&x)).collect(); for &word in &words { let counter = self.set.entry(word.to_string()).or_insert(0); *counter += 1; } Ok(()) } fn predict(&self, query: &str, n: usize) -> String { if let Some(top_words) = get_top_completions( &self, query.split_whitespace().last().unwrap_or(""), n, self.configuration.get("term_frequency_threshold").unwrap_or(&String::from("2")).parse::<usize>().unwrap() ) { let query_prefix = query.rsplit_once(' ').map_or("", |(head, _)| head).to_string(); let predictions: Vec<String> = top_words .into_iter() .map(|(word, _)| format!("\"{} {}\"", query_prefix, word)) .collect(); return format!("[\"{}\",[{}]]", query, predictions.join(",")); } String::new() } } fn get_top_completions( predictor: &SetPredictor, word: &str, top_n: usize, min_freq: usize ) -> Option<Vec<(String, usize)>> { Some(predictor.set.iter() .filter(|(key, &value)| key.starts_with(word) && value >= min_freq) .map(|(key, &value)| (key.clone(), value)) .take(top_n) .collect()) } pub fn from_file_path(file_path: &str) -> Result<SetPredictor, std::io::Error> { let file = File::open(file_path)?; let mut reader = ReaderBuilder::new().from_reader(file); let mut markov_chain: SetPredictor = SetPredictor::new(); for result in reader.records() { let record = result?; if let Some(query) = record.get(5) { markov_chain.update(query); } } Ok(markov_chain) } pub fn from_file_path_and_config(file_path: &str, config: HashMap<String, impl Into<String>>) -> Result<SetPredictor, std::io::Error> { let mut configuration = HashMap::new(); for (key, value) in config { configuration.insert(key, value.into()); } println!("Trying to open data file at {}",file_path); let file = File::open(file_path)?; println!("Reading data file..."); let mut reader = ReaderBuilder::new().from_reader(file); let mut markov_chain: SetPredictor = SetPredictor::new(); markov_chain.configuration = configuration; for result in reader.records() { let record = result?; if let Some(query) = record.get(5) { markov_chain.update(query); } } Ok(markov_chain) }