Skip to content
Snippets Groups Projects
Commit 1f44ba21 authored by Phil Höfer's avatar Phil Höfer
Browse files

Implement Term Threshold for Trait-Based Markov Predictor

parent cf8c7190
No related branches found
No related tags found
1 merge request!7Resolve "Implement Filters for Trait-based Predictors"
......@@ -22,27 +22,41 @@ use predictors::basic_markov::MarkovChain;
fn main() -> Result<(), io::Error> {
let config = read_config("config.toml").unwrap().parse::<Table>().unwrap();
let config_toml = read_config("config.toml").unwrap().parse::<Table>().unwrap();
let mut config = HashMap::new();
for (key, value) in config_toml {
match value {
toml::Value::String(s) => {
config.insert(key, s);
}
toml::Value::Integer(i) => {
config.insert(key, i.to_string());
}
_ => {
// Ignore other types
}
}
}
let markov_chain = basic_markov::from_file_path("../../data/data.csv")
.unwrap_or(basic_markov::from_file_path("data.csv")
.unwrap_or_default());
let markov_chain = basic_markov::from_file_path_and_config("../../data/data.csv",config.clone())
.unwrap_or(basic_markov::from_file_path_and_config("data.csv",config.clone())
.unwrap());
let term_frequency_threshold = match config.get("term_frequency_threshold") {
Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
_ => 2
};
let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold);
// let term_frequency_threshold = match config.get("term_frequency_threshold") {
// Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
// _ => 2
// };
//let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold);
// Print the Markov Chain for verification
for (key, values) in &markov_chain {
//println!("{}: {:?}", key, values);
}
// for (key, values) in &markov_chain {
// //println!("{}: {:?}", key, values);
// }
// Test the function
let word = "wie"; // replace with the word you want to check
let top_words = &filtered_markov_chain.predict(&word, 3);
let top_words = &markov_chain.predict(&word, 3);
println!("Example prediction for '{}': {:?}", word, top_words);
......@@ -56,7 +70,7 @@ fn main() -> Result<(), io::Error> {
// request.headers()
// );
match config.get("auth") {
Some(toml::Value::String(server_auth)) => {
Some(server_auth) => {
match get_authkey(request.url().clone()) {
Ok(client_auth) => {
if client_auth != server_auth.clone() {
......@@ -75,10 +89,10 @@ fn main() -> Result<(), io::Error> {
match query {
Ok(query) => {
let predict_count = match config.get("max_predict_count") {
Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
Some(n) if n.parse::<usize>().unwrap_or_default() > 0 => n.parse::<usize>().unwrap_or(5),
_ => 5
};
let prediction = &filtered_markov_chain.predict(&query, predict_count);
let prediction = &markov_chain.predict(&query, predict_count);
//println!("Query: {}, Prediction:{}", query, prediction);
let response = Response::from_string(prediction);
request.respond(response);
......
......@@ -6,14 +6,31 @@ use super::Predictor;
pub type MarkovChain = HashMap<String, HashMap<String, usize>>;
impl Predictor for MarkovChain {
/// Predictor that pairs a Markov chain with a set of string-valued settings.
pub struct MarkovChainPredictor {
// Maps a word to the words seen after it, each with a count (see the `MarkovChain` alias).
chain: MarkovChain,
// String key/value settings; `predict` looks up "term_frequency_threshold" here.
configuration: HashMap<String, String>,
}
impl Predictor for MarkovChainPredictor {
fn new() -> Self
{
HashMap::new()
MarkovChainPredictor {
chain: HashMap::new(),
configuration: HashMap::new()
}
}
fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized {
HashMap::new()
let mut configuration = HashMap::new();
for (key, value) in config {
configuration.insert(key, value.into());
}
MarkovChainPredictor {
chain: HashMap::new(),
configuration,
}
}
fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
......@@ -21,7 +38,7 @@ impl Predictor for MarkovChain {
let words: Vec<&str> = lowercase_query.split_whitespace().collect();
for window in words.windows(2) {
if let [first, second] = window {
self
self.chain
.entry(first.to_string().to_lowercase())
.or_default()
.entry(second.to_string())
......@@ -35,7 +52,12 @@ impl Predictor for MarkovChain {
fn predict(&self, query: &str, n: usize) -> String {
if let Some(top_words) =
get_top_following_words(self, query.split_whitespace().last().unwrap_or(""), n)
get_top_following_words(
&self.chain,
query.split_whitespace().last().unwrap_or(""),
n,
self.configuration.get("term_frequency_threshold").unwrap_or(&String::from("2")).parse::<usize>().unwrap()
)
{
let predictions: Vec<String> = top_words
.into_iter()
......@@ -51,6 +73,7 @@ fn get_top_following_words(
markov_chain: &MarkovChain,
word: &str,
top_n: usize,
min_freq: usize
) -> Option<Vec<(String, usize)>> {
let following_words = markov_chain.get(word)?;
......@@ -58,6 +81,7 @@ fn get_top_following_words(
let mut sorted_words: Vec<(String, usize)> = following_words
.iter()
.map(|(word, &count)| (word.clone(), count))
.filter(|&(_, count)| count >= min_freq)
.collect();
sorted_words.sort_by(|a, b| b.1.cmp(&a.1));
......@@ -66,10 +90,10 @@ fn get_top_following_words(
}
pub fn from_file_path(file_path: &str) -> Result<MarkovChain, std::io::Error> {
pub fn from_file_path(file_path: &str) -> Result<MarkovChainPredictor, std::io::Error> {
let file = File::open(file_path)?;
let mut reader = ReaderBuilder::new().from_reader(file);
let mut markov_chain: MarkovChain = MarkovChain::new();
let mut markov_chain: MarkovChainPredictor = MarkovChainPredictor::new();
for result in reader.records() {
let record = result?;
......@@ -78,5 +102,19 @@ pub fn from_file_path(file_path: &str) -> Result<MarkovChain, std::io::Error> {
}
}
Ok(markov_chain)
}
/// Loads a predictor from the CSV file at `file_path` and attaches `config` to it.
///
/// The chain itself is built by [`from_file_path`]. Every value in `config` is then
/// converted into an owned `String` and stored as the predictor's configuration,
/// replacing whatever configuration the freshly loaded predictor had.
///
/// # Errors
///
/// Returns any `std::io::Error` produced by [`from_file_path`].
pub fn from_file_path_and_config(file_path: &str, config: HashMap<String, impl Into<String>>) -> Result<MarkovChainPredictor, std::io::Error> {
    let mut markov_chain: MarkovChainPredictor = from_file_path(file_path)?;
    markov_chain.configuration = config
        .into_iter()
        .map(|(key, value)| (key, value.into()))
        .collect();
    // The threshold is optional (`predict` falls back to a default when it is
    // missing), so only report it when it is actually configured instead of
    // panicking on an absent key.
    if let Some(threshold) = markov_chain.configuration.get("term_frequency_threshold") {
        println!("{}", threshold);
    }
    Ok(markov_chain)
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment