Skip to content
Snippets Groups Projects
Commit 42bde9ef authored by Phil Höfer's avatar Phil Höfer
Browse files

Merge branch '2-implement-filters-for-trait-based-predictors' into 'main'

Resolve "Implement Filters for Trait-based Predictors"

Closes #2

See merge request !7
parents cf8c7190 5ebce499
No related branches found
No related tags found
1 merge request!7Resolve "Implement Filters for Trait-based Predictors"
Pipeline #9863 passed
auth = "12345"
term_frequency_threshold = 2
max_predict_count = 5
\ No newline at end of file
max_predict_count = 5
blocked_words = ""
\ No newline at end of file
......@@ -22,27 +22,41 @@ use predictors::basic_markov::MarkovChain;
fn main() -> Result<(), io::Error> {
let config = read_config("config.toml").unwrap().parse::<Table>().unwrap();
let config_toml = read_config("config.toml").unwrap().parse::<Table>().unwrap();
let mut config = HashMap::new();
for (key, value) in config_toml {
match value {
toml::Value::String(s) => {
config.insert(key, s);
}
toml::Value::Integer(i) => {
config.insert(key, i.to_string());
}
_ => {
// Ignore other types
}
}
}
let markov_chain = basic_markov::from_file_path("../../data/data.csv")
.unwrap_or(basic_markov::from_file_path("data.csv")
.unwrap_or_default());
let markov_chain = basic_markov::from_file_path_and_config("../../data/data.csv",config.clone())
.unwrap_or(basic_markov::from_file_path_and_config("data.csv",config.clone())
.unwrap());
let term_frequency_threshold = match config.get("term_frequency_threshold") {
Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
_ => 2
};
let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold);
// let term_frequency_threshold = match config.get("term_frequency_threshold") {
// Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
// _ => 2
// };
//let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold);
// Print the Markov Chain for verification
for (key, values) in &markov_chain {
//println!("{}: {:?}", key, values);
}
// for (key, values) in &markov_chain {
// //println!("{}: {:?}", key, values);
// }
// Test the function
let word = "wie"; // replace with the word you want to check
let top_words = &filtered_markov_chain.predict(&word, 3);
let top_words = &markov_chain.predict(&word, 3);
println!("Example prediction for '{}': {:?}", word, top_words);
......@@ -56,7 +70,7 @@ fn main() -> Result<(), io::Error> {
// request.headers()
// );
match config.get("auth") {
Some(toml::Value::String(server_auth)) => {
Some(server_auth) => {
match get_authkey(request.url().clone()) {
Ok(client_auth) => {
if client_auth != server_auth.clone() {
......@@ -75,10 +89,10 @@ fn main() -> Result<(), io::Error> {
match query {
Ok(query) => {
let predict_count = match config.get("max_predict_count") {
Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
Some(n) if n.parse::<usize>().unwrap_or_default() > 0 => n.parse::<usize>().unwrap_or(5),
_ => 5
};
let prediction = &filtered_markov_chain.predict(&query, predict_count);
let prediction = &markov_chain.predict(&query, predict_count);
//println!("Query: {}, Prediction:{}", query, prediction);
let response = Response::from_string(prediction);
request.respond(response);
......
......@@ -6,22 +6,49 @@ use super::Predictor;
pub type MarkovChain = HashMap<String, HashMap<String, usize>>;
impl Predictor for MarkovChain {
pub struct MarkovChainPredictor {
chain: MarkovChain,
configuration: HashMap<String, String>,
}
impl Predictor for MarkovChainPredictor {
fn new() -> Self
{
HashMap::new()
MarkovChainPredictor {
chain: HashMap::new(),
configuration: HashMap::new()
}
}
fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized {
HashMap::new()
let mut configuration = HashMap::new();
for (key, value) in config {
configuration.insert(key, value.into());
}
MarkovChainPredictor {
chain: HashMap::new(),
configuration,
}
}
fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
let blocklist: Vec<&str> = match self.configuration.get("blocked_words") {
Some(list) => {
list.split_whitespace().collect()
},
_ => Vec::new()
};
//println!("blocklist:{:?}",self.configuration);
let lowercase_query = query.to_lowercase();
let words: Vec<&str> = lowercase_query.split_whitespace().collect();
let words: Vec<&str> = lowercase_query.split_whitespace().filter(|&x| !blocklist.contains(&x)).collect();
for window in words.windows(2) {
if let [first, second] = window {
self
self.chain
.entry(first.to_string().to_lowercase())
.or_default()
.entry(second.to_string())
......@@ -35,7 +62,12 @@ impl Predictor for MarkovChain {
fn predict(&self, query: &str, n: usize) -> String {
if let Some(top_words) =
get_top_following_words(self, query.split_whitespace().last().unwrap_or(""), n)
get_top_following_words(
&self.chain,
query.split_whitespace().last().unwrap_or(""),
n,
self.configuration.get("term_frequency_threshold").unwrap_or(&String::from("2")).parse::<usize>().unwrap()
)
{
let predictions: Vec<String> = top_words
.into_iter()
......@@ -51,6 +83,7 @@ fn get_top_following_words(
markov_chain: &MarkovChain,
word: &str,
top_n: usize,
min_freq: usize
) -> Option<Vec<(String, usize)>> {
let following_words = markov_chain.get(word)?;
......@@ -58,6 +91,7 @@ fn get_top_following_words(
let mut sorted_words: Vec<(String, usize)> = following_words
.iter()
.map(|(word, &count)| (word.clone(), count))
.filter(|&(_, count)| count >= min_freq)
.collect();
sorted_words.sort_by(|a, b| b.1.cmp(&a.1));
......@@ -66,10 +100,31 @@ fn get_top_following_words(
}
pub fn from_file_path(file_path: &str) -> Result<MarkovChain, std::io::Error> {
pub fn from_file_path(file_path: &str) -> Result<MarkovChainPredictor, std::io::Error> {
let file = File::open(file_path)?;
let mut reader = ReaderBuilder::new().from_reader(file);
let mut markov_chain: MarkovChainPredictor = MarkovChainPredictor::new();
for result in reader.records() {
let record = result?;
if let Some(query) = record.get(5) {
markov_chain.update(query);
}
}
Ok(markov_chain)
}
pub fn from_file_path_and_config(file_path: &str, config: HashMap<String, impl Into<String>>) -> Result<MarkovChainPredictor, std::io::Error> {
let mut configuration = HashMap::new();
for (key, value) in config {
configuration.insert(key, value.into());
}
let file = File::open(file_path)?;
let mut reader = ReaderBuilder::new().from_reader(file);
let mut markov_chain: MarkovChain = MarkovChain::new();
let mut markov_chain: MarkovChainPredictor = MarkovChainPredictor::new();
markov_chain.configuration = configuration;
for result in reader.records() {
let record = result?;
......@@ -78,5 +133,6 @@ pub fn from_file_path(file_path: &str) -> Result<MarkovChain, std::io::Error> {
}
}
Ok(markov_chain)
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment