diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index be0825772abf37613bbcea7b4d0fa49f134816bf..963a7fb42d0c81c6c3fc2128e6bc0f890e2c66a5 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, fs::File}; +use std::{collections::HashMap, fs::{self, File}}; use csv::ReaderBuilder; @@ -46,13 +46,26 @@ impl Predictor for MarkovChainPredictor { } fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { - let blocklist: Vec<&str> = match self.configuration.get("blocked_words") { + let mut raw_list = String::new(); + let blocklist: Vec<&str> = match self.configuration.get("blocklist") { Some(list) => { - list.split_whitespace().collect() + match fs::read_to_string(list) { + Ok(l) => { + raw_list = l.clone(); + raw_list.lines().collect() + }, + _ => Vec::new() + } }, _ => Vec::new() - }; + }; + + for word in blocklist.clone() { + if query.contains(word) { + return Ok(()); + } + } //println!("blocklist:{:?}",self.configuration); diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs index a5eb354833a64a79f6bc1155e52b835e200f60e1..d208039ae489d97c7bda77f7fcb60a2b7fe3ffef 100644 --- a/src/predictors/basic_set.rs +++ b/src/predictors/basic_set.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, f32::consts::E, fs::File}; +use std::{collections::HashMap, f32::consts::E, fs::{self, File}}; use crate::importers::Importer; @@ -45,13 +45,26 @@ impl Predictor for SetPredictor { } fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { - let blocklist: Vec<&str> = match self.configuration.get("blocked_words") { + let mut raw_list = String::new(); + let blocklist: Vec<&str> = match self.configuration.get("blocklist") { Some(list) => { - list.split_whitespace().collect() + match fs::read_to_string(list) { + Ok(l) => { + raw_list = l.clone(); + raw_list.lines().collect() + }, + _ => Vec::new() + } }, _ => Vec::new() - }; + }; + + for word in blocklist.clone() { + if query.contains(word) { + return Ok(()); + } + } //println!("blocklist:{:?}",self.configuration);