Skip to content
Snippets Groups Projects
Commit 9a41cdaf authored by Phil Höfer's avatar Phil Höfer
Browse files

Refactor Predictors into Trait

parent de9a97a7
No related branches found
No related tags found
1 merge request!6Resolve "Refactor Predictors to Allow For Choice and Configuration"
......@@ -13,18 +13,19 @@ use toml::Table;
use csv::ReaderBuilder;
type MarkovChain = HashMap<String, HashMap<String, usize>>;
use tiny_http::{Server, Response};
mod predictors;
use predictors::{basic_markov, Predictor};
use predictors::basic_markov::MarkovChain;
fn main() -> Result<(), io::Error> {
let config = read_config("config.toml").unwrap().parse::<Table>().unwrap();
let markov_chain = build_markov_chain("../../data/data.csv")
.unwrap_or(build_markov_chain("data.csv")
let markov_chain = basic_markov::from_file_path("../../data/data.csv")
.unwrap_or(basic_markov::from_file_path("data.csv")
.unwrap_or_default());
......@@ -41,11 +42,8 @@ fn main() -> Result<(), io::Error> {
// Test the function
let word = "wie"; // replace with the word you want to check
if let Some(top_words) = get_top_following_words(&filtered_markov_chain, word, 3) {
println!("Top following words for '{}': {:?}", word, top_words);
} else {
println!("No following words found for '{}'", word);
}
let top_words = &filtered_markov_chain.predict(&word, 3);
println!("Example prediction for '{}': {:?}", word, top_words);
......@@ -80,7 +78,7 @@ fn main() -> Result<(), io::Error> {
Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
_ => 5
};
let prediction = predictn(&filtered_markov_chain, &query, predict_count);
let prediction = &filtered_markov_chain.predict(&query, predict_count);
//println!("Query: {}, Prediction:{}", query, prediction);
let response = Response::from_string(prediction);
request.respond(response);
......@@ -98,6 +96,7 @@ fn read_config(file_path: &str) -> Result<String, io::Error> {
fs::read_to_string(file_path)
}
//TODO: unify get_X() functions
fn get_query(request_url: &str) -> Result<String, url::ParseError> {
let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after);
//println!("parsed_url:{}", parsed_url);
......@@ -123,31 +122,6 @@ fn get_authkey(request_url: &str) -> Result<String, url::ParseError> {
Ok(String::from_str("").unwrap())
}
/// Builds a Markov chain from the CSV file at `file_path`.
///
/// Each record's sixth column (index 5) is treated as a query string;
/// every adjacent word pair within a query increments the transition
/// count `first -> second`. The whole query is lowercased first.
///
/// # Errors
/// Returns an `io::Error` if the file cannot be opened or a CSV record
/// fails to parse.
fn build_markov_chain(file_path: &str) -> Result<MarkovChain, io::Error> {
    let file = File::open(file_path)?;
    let mut reader = ReaderBuilder::new().from_reader(file);
    let mut markov_chain: MarkovChain = HashMap::new();
    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(5) {
            let lowercase_query = query.to_lowercase();
            let words: Vec<&str> = lowercase_query.split_whitespace().collect();
            for window in words.windows(2) {
                if let [first, second] = window {
                    // `lowercase_query` is already lowercased, so the former
                    // extra `.to_lowercase()` on `first` was redundant work.
                    markov_chain
                        .entry(first.to_string())
                        .or_default()
                        .entry(second.to_string())
                        .and_modify(|count| *count += 1)
                        .or_insert(1);
                }
            }
        }
    }
    Ok(markov_chain)
}
fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovChain {
let mut filtered_chain: MarkovChain = HashMap::new();
......@@ -166,41 +140,3 @@ fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovCh
filtered_chain
}
/// Looks up `word` in the chain and returns up to `top_n` successor words
/// with their occurrence counts, most frequent first.
/// Returns `None` when the chain holds no entry for `word`.
fn get_top_following_words(
    markov_chain: &MarkovChain,
    word: &str,
    top_n: usize,
) -> Option<Vec<(String, usize)>> {
    let followers = markov_chain.get(word)?;
    // Own the (word, count) pairs so they can be sorted and returned.
    let mut ranked: Vec<(String, usize)> = followers
        .iter()
        .map(|(w, &count)| (w.clone(), count))
        .collect();
    // Stable descending order by count.
    ranked.sort_by_key(|&(_, count)| std::cmp::Reverse(count));
    ranked.truncate(top_n);
    Some(ranked)
}
/// Appends the single most likely next word to `query`, predicting from
/// the last whitespace-separated word. Returns an empty string when no
/// prediction is available.
fn predict(markov_chain: &MarkovChain, query: &str) -> String {
    let last_word = query.split_whitespace().last().unwrap_or("");
    match get_top_following_words(markov_chain, last_word, 1)
        .as_deref()
        .and_then(|words| words.first())
    {
        Some((predicted_word, _)) => format!("{} {}", query, predicted_word),
        None => String::new(),
    }
}
/// Builds an OpenSearch-style suggestion response for `query`:
/// `["query",["query w1","query w2",...]]` with up to `n` continuations
/// predicted from the query's last word. Returns an empty string when the
/// last word has no recorded successors.
fn predictn(markov_chain: &MarkovChain, query: &str, n: usize) -> String {
    let last_word = query.split_whitespace().last().unwrap_or("");
    match get_top_following_words(markov_chain, last_word, n) {
        Some(top_words) => {
            let suggestions: Vec<String> = top_words
                .iter()
                .map(|(word, _)| format!("\"{} {}\"", query, word))
                .collect();
            format!("[\"{}\",[{}]]", query, suggestions.join(","))
        }
        None => String::new(),
    }
}
use std::{collections::HashMap, fs::File};
use csv::ReaderBuilder;
use super::Predictor;
pub type MarkovChain = HashMap<String, HashMap<String, usize>>;
impl Predictor for MarkovChain {
    /// Creates an empty Markov chain.
    fn new() -> Self {
        HashMap::new()
    }

    /// Creates an empty Markov chain; the configuration is currently
    /// ignored (underscore-prefixed to silence the unused warning).
    //TODO: honour configuration once configurable options exist
    fn new_from_config(_config: HashMap<String, impl Into<String>>) -> Self
    where
        Self: Sized,
    {
        HashMap::new()
    }

    /// Folds `query` into the chain: lowercases it, splits on whitespace,
    /// and increments the transition count for every adjacent word pair.
    fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
        let lowercase_query = query.to_lowercase();
        let words: Vec<&str> = lowercase_query.split_whitespace().collect();
        for window in words.windows(2) {
            if let [first, second] = window {
                // `lowercase_query` is already lowercased, so the former
                // extra `.to_lowercase()` on `first` was redundant work.
                self.entry(first.to_string())
                    .or_default()
                    .entry(second.to_string())
                    .and_modify(|count| *count += 1)
                    .or_insert(1);
            }
        }
        Ok(())
    }

    /// Returns an OpenSearch-style suggestion JSON string
    /// `["query",["query w1",...]]` with up to `n` continuations predicted
    /// from the last word of `query`, or an empty string when that word has
    /// no recorded successors.
    fn predict(&self, query: &str, n: usize) -> String {
        if let Some(top_words) =
            get_top_following_words(self, query.split_whitespace().last().unwrap_or(""), n)
        {
            let predictions: Vec<String> = top_words
                .into_iter()
                .map(|(word, _)| format!("\"{} {}\"", query, word))
                .collect();
            return format!("[\"{}\",[{}]]", query, predictions.join(","));
        }
        String::new()
    }
}
/// Returns up to `top_n` successors of `word`, most frequent first, or
/// `None` when the chain has no entry for `word`.
fn get_top_following_words(
    markov_chain: &MarkovChain,
    word: &str,
    top_n: usize,
) -> Option<Vec<(String, usize)>> {
    let successors = markov_chain.get(word)?;
    // Materialize owned (word, count) pairs for sorting.
    let mut by_count: Vec<(String, usize)> = successors
        .iter()
        .map(|(w, &n)| (w.clone(), n))
        .collect();
    // Stable sort, highest count first.
    by_count.sort_by(|left, right| right.1.cmp(&left.1));
    by_count.truncate(top_n);
    Some(by_count)
}
/// Builds a Markov chain from the CSV file at `file_path`, feeding the
/// sixth column (index 5) of every record into `Predictor::update`.
///
/// # Errors
/// Returns an `io::Error` if the file cannot be opened, a record cannot
/// be parsed, or an update fails.
pub fn from_file_path(file_path: &str) -> Result<MarkovChain, std::io::Error> {
    let file = File::open(file_path)?;
    let mut reader = ReaderBuilder::new().from_reader(file);
    let mut markov_chain: MarkovChain = MarkovChain::new();
    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(5) {
            // Propagate update failures instead of silently dropping the
            // `Result` (the original ignored it, losing any error).
            markov_chain
                .update(query)
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
        }
    }
    Ok(markov_chain)
}
\ No newline at end of file
use std::collections::HashMap;
pub mod basic_set;
pub mod basic_markov;
/// Common interface for query-completion predictors.
pub trait Predictor {
/// Returns a prediction string for `query` containing up to `n` suggestions.
fn predict(&self, query: &str, n: usize) -> String;
/// Incorporates `query` into the predictor's internal model.
fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>;
/// Creates an empty predictor with default settings.
fn new() -> Self where Self: Sized;
/// Creates a predictor configured from string key/value pairs.
// NOTE(review): implementations currently ignore `config` — confirm intended semantics.
fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized;
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment