Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
use std::{collections::HashMap, fs::File};
use csv::ReaderBuilder;
use super::Predictor;
/// First-order word-transition table: maps a word to every word that has been
/// observed directly after it, together with how many times that pair was seen.
pub type MarkovChain = HashMap<String, HashMap<String, usize>>;
impl Predictor for MarkovChain {
    /// Creates an empty chain with no recorded transitions.
    fn new() -> Self {
        HashMap::new()
    }

    /// Creates an empty chain. The supplied configuration is currently ignored.
    fn new_from_config(_config: HashMap<String, impl Into<String>>) -> Self
    where
        Self: Sized,
    {
        HashMap::new()
    }

    /// Records every adjacent word pair of `query` in the chain.
    ///
    /// The query is lowercased and split on whitespace; for each pair of
    /// neighbouring words the counter `chain[first][second]` is incremented.
    /// Queries with fewer than two words leave the chain unchanged.
    ///
    /// # Errors
    ///
    /// This implementation never fails and always returns `Ok(())`.
    fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
        let lowercase_query = query.to_lowercase();
        let words: Vec<&str> = lowercase_query.split_whitespace().collect();
        for window in words.windows(2) {
            if let [first, second] = window {
                // Both words come from the already-lowercased query, so no
                // further case folding is needed here.
                *self
                    .entry((*first).to_owned())
                    .or_default()
                    .entry((*second).to_owned())
                    .or_insert(0) += 1;
            }
        }
        Ok(())
    }

    /// Suggests up to `n` continuations of `query` based on its last word.
    ///
    /// The result is a JSON array of the form
    /// `["<query>",["<query> <word1>","<query> <word2>",...]]`, with the most
    /// frequently observed following words first. An empty string is returned
    /// when the last word of `query` has no entry in the chain.
    fn predict(&self, query: &str, n: usize) -> String {
        // Keys are stored lowercased by `update`, so the lookup word must be
        // lowercased too or mixed-case queries would never match.
        let last_word = query
            .split_whitespace()
            .next_back()
            .unwrap_or("")
            .to_lowercase();

        if let Some(top_words) = get_top_following_words(self, &last_word, n) {
            let predictions: Vec<String> = top_words
                .into_iter()
                .map(|(word, _)| {
                    format!(
                        "\"{}\"",
                        escape_json_string(&format!("{} {}", query, word))
                    )
                })
                .collect();
            return format!(
                "[\"{}\",[{}]]",
                escape_json_string(query),
                predictions.join(",")
            );
        }
        String::new()
    }
}

/// Escapes `text` so that it can be embedded inside a JSON string literal.
///
/// Quotes, backslashes and control characters are replaced by their JSON
/// escape sequences; every other character is copied through unchanged.
fn escape_json_string(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
            c => escaped.push(c),
        }
    }
    escaped
}
/// Returns up to `top_n` words recorded as following `word`, most frequent first.
///
/// Returns `None` when `word` has no entry in `markov_chain`. Words with equal
/// counts are ordered alphabetically so that the result does not depend on the
/// hash map's iteration order.
fn get_top_following_words(
    markov_chain: &MarkovChain,
    word: &str,
    top_n: usize,
) -> Option<Vec<(String, usize)>> {
    let following_words = markov_chain.get(word)?;

    // Rank borrowed keys first; only the entries that are returned get cloned.
    let mut ranked: Vec<(&String, usize)> = following_words
        .iter()
        .map(|(next_word, &count)| (next_word, count))
        .collect();

    // Descending by count, ascending by word on ties. Keys are unique, so this
    // is a total order and the unstable sort is still deterministic.
    ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

    Some(
        ranked
            .into_iter()
            .take(top_n)
            .map(|(next_word, count)| (next_word.clone(), count))
            .collect(),
    )
}
/// Builds a [`MarkovChain`] from the CSV file at `file_path`.
///
/// For every record, the field at column index 5 (the sixth column) is fed to
/// `update` as a query; records that have no such field are skipped.
///
/// NOTE: the reader is created with `ReaderBuilder` defaults, which per the
/// `csv` crate documentation treat the first row as a header row.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, if a CSV record cannot be
/// read, or if updating the chain with a query fails.
pub fn from_file_path(file_path: &str) -> Result<MarkovChain, std::io::Error> {
    // Zero-based index of the CSV column that holds the query text.
    const QUERY_COLUMN: usize = 5;

    let file = File::open(file_path)?;
    let mut reader = ReaderBuilder::new().from_reader(file);
    let mut markov_chain: MarkovChain = MarkovChain::new();
    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(QUERY_COLUMN) {
            // Surface update failures instead of silently dropping them. The
            // boxed error is not `Send + Sync`, so it is carried as its message.
            markov_chain
                .update(query)
                .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?;
        }
    }
    Ok(markov_chain)
}