diff --git a/config.toml b/config.toml index 7bbf4bef306f1e4db181c4f7805daf8cce51bf38..9b3d8cf5dc4abeac8982d0454f700d5a88ecda1f 100644 --- a/config.toml +++ b/config.toml @@ -1,3 +1,4 @@ auth = "12345" term_frequency_threshold = 2 -max_predict_count = 5 \ No newline at end of file +max_predict_count = 5 +blocked_words = "" \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 2af7ebd9f79eb222af1b8799a4716197f306571d..4db7bd5afdf88113d68230ea3fcc8051166906cf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,27 +22,41 @@ use predictors::basic_markov::MarkovChain; fn main() -> Result<(), io::Error> { - let config = read_config("config.toml").unwrap().parse::<Table>().unwrap(); + let config_toml = read_config("config.toml").unwrap().parse::<Table>().unwrap(); + let mut config = HashMap::new(); + for (key, value) in config_toml { + match value { + toml::Value::String(s) => { + config.insert(key, s); + } + toml::Value::Integer(i) => { + config.insert(key, i.to_string()); + } + _ => { + // Ignore other types + } + } + } - let markov_chain = basic_markov::from_file_path("../../data/data.csv") - .unwrap_or(basic_markov::from_file_path("data.csv") - .unwrap_or_default()); + let markov_chain = basic_markov::from_file_path_and_config("../../data/data.csv",config.clone()) + .unwrap_or(basic_markov::from_file_path_and_config("data.csv",config.clone()) + .unwrap()); - let term_frequency_threshold = match config.get("term_frequency_threshold") { - Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize, - _ => 2 - }; - let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold); + // let term_frequency_threshold = match config.get("term_frequency_threshold") { + // Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize, + // _ => 2 + // }; + //let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold); // Print the Markov Chain for verification - for (key, values) in &markov_chain { - //println!("{}: {:?}", key, values); - } + // for (key, values) in &markov_chain { + // //println!("{}: {:?}", key, values); + // } // Test the function let word = "wie"; // replace with the word you want to check - let top_words = &filtered_markov_chain.predict(&word, 3); + let top_words = &markov_chain.predict(&word, 3); println!("Example prediction for '{}': {:?}", word, top_words); @@ -56,7 +70,7 @@ fn main() -> Result<(), io::Error> { // request.headers() // ); match config.get("auth") { - Some(toml::Value::String(server_auth)) => { + Some(server_auth) => { match get_authkey(request.url().clone()) { Ok(client_auth) => { if client_auth != server_auth.clone() { @@ -75,10 +89,10 @@ fn main() -> Result<(), io::Error> { match query { Ok(query) => { let predict_count = match config.get("max_predict_count") { - Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize, + Some(n) if n.parse::<usize>().unwrap_or_default() > 0 => n.parse::<usize>().unwrap_or(5), _ => 5 }; - let prediction = &filtered_markov_chain.predict(&query, predict_count); + let prediction = &markov_chain.predict(&query, predict_count); //println!("Query: {}, Prediction:{}", query, prediction); let response = Response::from_string(prediction); request.respond(response); diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index a3033d07cbc2eaf5a91c9af863b9b94083e49661..7839581a15c4f17504f6cef217695567967dee0b 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -6,22 +6,49 @@ use super::Predictor; pub type MarkovChain = HashMap<String, HashMap<String, usize>>; -impl Predictor for MarkovChain { +pub struct MarkovChainPredictor { + chain: MarkovChain, + configuration: HashMap<String, String>, +} + +impl Predictor for MarkovChainPredictor { fn new() -> Self { - HashMap::new() + MarkovChainPredictor { + chain: HashMap::new(), + configuration: HashMap::new() + } } fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized { - HashMap::new() + let mut configuration = HashMap::new(); + for (key, value) in config { + configuration.insert(key, value.into()); + } + + + MarkovChainPredictor { + chain: HashMap::new(), + configuration, + } } fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { + let blocklist: Vec<&str> = match self.configuration.get("blocked_words") { + Some(list) => { + list.split_whitespace().collect() + }, + _ => Vec::new() + + }; + + //println!("blocklist:{:?}",self.configuration); + let lowercase_query = query.to_lowercase(); - let words: Vec<&str> = lowercase_query.split_whitespace().collect(); + let words: Vec<&str> = lowercase_query.split_whitespace().filter(|&x| !blocklist.contains(&x)).collect(); for window in words.windows(2) { if let [first, second] = window { - self + self.chain .entry(first.to_string().to_lowercase()) .or_default() .entry(second.to_string()) @@ -35,7 +62,12 @@ impl Predictor for MarkovChain { fn predict(&self, query: &str, n: usize) -> String { if let Some(top_words) = - get_top_following_words(self, query.split_whitespace().last().unwrap_or(""), n) + get_top_following_words( + &self.chain, + query.split_whitespace().last().unwrap_or(""), + n, + self.configuration.get("term_frequency_threshold").unwrap_or(&String::from("2")).parse::<usize>().unwrap() + ) { let predictions: Vec<String> = top_words .into_iter() @@ -51,6 +83,7 @@ fn get_top_following_words( markov_chain: &MarkovChain, word: &str, top_n: usize, + min_freq: usize ) -> Option<Vec<(String, usize)>> { let following_words = markov_chain.get(word)?; @@ -58,6 +91,7 @@ fn get_top_following_words( let mut sorted_words: Vec<(String, usize)> = following_words .iter() .map(|(word, &count)| (word.clone(), count)) + .filter(|&(_, count)| count >= min_freq) .collect(); sorted_words.sort_by(|a, b| b.1.cmp(&a.1)); @@ -66,10 +100,31 @@ fn get_top_following_words( } -pub fn from_file_path(file_path: &str) -> Result<MarkovChain, std::io::Error> { +pub fn from_file_path(file_path: &str) -> Result<MarkovChainPredictor, std::io::Error> { + let file = File::open(file_path)?; + let mut reader = ReaderBuilder::new().from_reader(file); + let mut markov_chain: MarkovChainPredictor = MarkovChainPredictor::new(); + + for result in reader.records() { + let record = result?; + if let Some(query) = record.get(5) { + markov_chain.update(query); + } + } + + Ok(markov_chain) +} + + +pub fn from_file_path_and_config(file_path: &str, config: HashMap<String, impl Into<String>>) -> Result<MarkovChainPredictor, std::io::Error> { + let mut configuration = HashMap::new(); + for (key, value) in config { + configuration.insert(key, value.into()); + } let file = File::open(file_path)?; let mut reader = ReaderBuilder::new().from_reader(file); - let mut markov_chain: MarkovChain = MarkovChain::new(); + let mut markov_chain: MarkovChainPredictor = MarkovChainPredictor::new(); + markov_chain.configuration = configuration; for result in reader.records() { let record = result?; @@ -78,5 +133,6 @@ pub fn from_file_path(file_path: &str) -> Result<MarkovChain, std::io::Error> { } } + Ok(markov_chain) } \ No newline at end of file