Skip to content
Snippets Groups Projects
Commit 6bb15dda authored by Phil Höfer's avatar Phil Höfer
Browse files

Implement Basic Set Predictor

parent 42bde9ef
No related branches found
No related tags found
No related merge requests found
Pipeline #9867 passed
......@@ -16,7 +16,7 @@ use csv::ReaderBuilder;
use tiny_http::{Server, Response};
mod predictors;
use predictors::{basic_markov, Predictor};
use predictors::{basic_markov,basic_set, Predictor};
use predictors::basic_markov::MarkovChain;
......@@ -38,8 +38,8 @@ fn main() -> Result<(), io::Error> {
}
}
let markov_chain = basic_markov::from_file_path_and_config("../../data/data.csv",config.clone())
.unwrap_or(basic_markov::from_file_path_and_config("data.csv",config.clone())
let markov_chain = basic_set::from_file_path_and_config("../../data/data.csv",config.clone())
.unwrap_or(basic_set::from_file_path_and_config("data.csv",config.clone())
.unwrap());
......
use std::{collections::HashMap, fs::File};
use csv::ReaderBuilder;
use super::Predictor;
pub struct SetPredictor {
set: HashMap<String,usize>,
configuration: HashMap<String, String>,
}
impl Predictor for SetPredictor {
fn new() -> Self
{
SetPredictor {
set: HashMap::new(),
configuration: HashMap::new()
}
}
fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized {
let mut configuration = HashMap::new();
for (key, value) in config {
configuration.insert(key, value.into());
}
SetPredictor {
set: HashMap::new(),
configuration,
}
}
fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
let blocklist: Vec<&str> = match self.configuration.get("blocked_words") {
Some(list) => {
list.split_whitespace().collect()
},
_ => Vec::new()
};
//println!("blocklist:{:?}",self.configuration);
let lowercase_query = query.to_lowercase();
let words: Vec<&str> = lowercase_query.split_whitespace().filter(|&x| !blocklist.contains(&x)).collect();
for &word in &words {
let counter = self.set.entry(word.to_string()).or_insert(0);
*counter += 1;
}
Ok(())
}
fn predict(&self, query: &str, n: usize) -> String {
if let Some(top_words) =
get_top_completions(
&self,
query.split_whitespace().last().unwrap_or(""),
n,
self.configuration.get("term_frequency_threshold").unwrap_or(&String::from("2")).parse::<usize>().unwrap()
)
{
let query_prefix = query.rsplit_once(' ').map_or("", |(head, _)| head).to_string();
let predictions: Vec<String> = top_words
.into_iter()
.map(|(word, _)| format!("\"{} {}\"", query_prefix, word))
.collect();
return format!("[\"{}\",[{}]]", query, predictions.join(","));
}
String::new()
}
}
fn get_top_completions(
predictor: &SetPredictor,
word: &str,
top_n: usize,
min_freq: usize
) -> Option<Vec<(String, usize)>> {
Some(predictor.set.iter()
.filter(|(key, &value)| key.starts_with(word) && value >= min_freq)
.map(|(key, &value)| (key.clone(), value))
.take(top_n)
.collect())
}
pub fn from_file_path(file_path: &str) -> Result<SetPredictor, std::io::Error> {
let file = File::open(file_path)?;
let mut reader = ReaderBuilder::new().from_reader(file);
let mut markov_chain: SetPredictor = SetPredictor::new();
for result in reader.records() {
let record = result?;
if let Some(query) = record.get(5) {
markov_chain.update(query);
}
}
Ok(markov_chain)
}
pub fn from_file_path_and_config(file_path: &str, config: HashMap<String, impl Into<String>>) -> Result<SetPredictor, std::io::Error> {
let mut configuration = HashMap::new();
for (key, value) in config {
configuration.insert(key, value.into());
}
let file = File::open(file_path)?;
let mut reader = ReaderBuilder::new().from_reader(file);
let mut markov_chain: SetPredictor = SetPredictor::new();
markov_chain.configuration = configuration;
for result in reader.records() {
let record = result?;
if let Some(query) = record.get(5) {
markov_chain.update(query);
}
}
Ok(markov_chain)
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment