Newer
Older
use std::{collections::HashMap, fs::File};
use csv::ReaderBuilder;
use super::Predictor;
/// Word-transition table: for each word, the words seen immediately after it
/// together with how many times each such pair was observed.
pub type MarkovChain = HashMap<String, HashMap<String, usize>>;

/// Next-word predictor built from adjacent-word (bigram) counts.
pub struct MarkovChainPredictor {
    /// Transition counts accumulated from the queries fed to the predictor.
    chain: MarkovChain,
    /// Free-form string settings, e.g. `term_frequency_threshold`.
    configuration: HashMap<String, String>,
}
impl Predictor for MarkovChainPredictor {
MarkovChainPredictor {
chain: HashMap::new(),
configuration: HashMap::new()
}
}
fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized {
let mut configuration = HashMap::new();
for (key, value) in config {
configuration.insert(key, value.into());
}
MarkovChainPredictor {
chain: HashMap::new(),
configuration,
}
}
fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
let lowercase_query = query.to_lowercase();
let words: Vec<&str> = lowercase_query.split_whitespace().collect();
for window in words.windows(2) {
if let [first, second] = window {
.entry(first.to_string().to_lowercase())
.or_default()
.entry(second.to_string())
.and_modify(|count| *count += 1)
.or_insert(1);
}
}
Ok(())
}
fn predict(&self, query: &str, n: usize) -> String {
if let Some(top_words) =
get_top_following_words(
&self.chain,
query.split_whitespace().last().unwrap_or(""),
n,
self.configuration.get("term_frequency_threshold").unwrap_or(&String::from("2")).parse::<usize>().unwrap()
)
{
let predictions: Vec<String> = top_words
.into_iter()
.map(|(word, _)| format!("\"{} {}\"", query, word))
.collect();
return format!("[\"{}\",[{}]]", query, predictions.join(","));
}
String::new()
}
}
fn get_top_following_words(
markov_chain: &MarkovChain,
word: &str,
top_n: usize,
) -> Option<Vec<(String, usize)>> {
let following_words = markov_chain.get(word)?;
// Collect and sort the words by their counts in descending order
let mut sorted_words: Vec<(String, usize)> = following_words
.iter()
.map(|(word, &count)| (word.clone(), count))
.filter(|&(_, count)| count >= min_freq)
.collect();
sorted_words.sort_by(|a, b| b.1.cmp(&a.1));
// Return the top N words
Some(sorted_words.into_iter().take(top_n).collect())
}
/// Builds a [`MarkovChainPredictor`] from the queries stored in a CSV file.
///
/// The file is parsed with `csv::ReaderBuilder`'s default settings. For every
/// record, the field at index 5 (the sixth column) is fed to the predictor as
/// a query. Records with fewer columns are skipped.
///
/// # Errors
/// Returns an error if the file cannot be opened, a record cannot be parsed,
/// or updating the predictor fails.
pub fn from_file_path(file_path: &str) -> Result<MarkovChainPredictor, std::io::Error> {
    // Zero-based index of the CSV column holding the query text.
    // NOTE(review): tied to the input file's layout — confirm against the data source.
    const QUERY_COLUMN: usize = 5;

    let file = File::open(file_path)?;
    let mut reader = ReaderBuilder::new().from_reader(file);
    let mut markov_chain = MarkovChainPredictor::new();
    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(QUERY_COLUMN) {
            // Surface update failures instead of silently discarding the Result.
            markov_chain
                .update(query)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
        }
    }
    Ok(markov_chain)
}
pub fn from_file_path_and_config(file_path: &str, config: HashMap<String, impl Into<String>>) -> Result<MarkovChainPredictor, std::io::Error> {
let mut markov_chain: MarkovChainPredictor = from_file_path(file_path)?;
let mut configuration = HashMap::new();
for (key, value) in config {
configuration.insert(key, value.into());
}
markov_chain.configuration = configuration;
println!("{}",markov_chain.configuration.get("term_frequency_threshold").unwrap());