Skip to content
Snippets Groups Projects
basic_markov.rs 2.37 KiB
Newer Older
  • Learn to ignore specific revisions
  • use std::{collections::HashMap, fs::File};
    
    use csv::ReaderBuilder;
    
    use super::Predictor;
    
    /// Successor counts for one word: following word -> number of times it was observed.
    type FollowerCounts = HashMap<String, usize>;

    /// Word-level bigram model: each word maps to the counts of the words seen right after it.
    pub type MarkovChain = HashMap<String, FollowerCounts>;
    
    /// `Predictor` implementation for a word-level bigram Markov chain.
    ///
    /// Each key of the outer map is a lowercased word; the inner map counts how many times
    /// each word was observed immediately after it.
    impl Predictor for MarkovChain {
        /// Creates an empty chain.
        fn new() -> Self {
            HashMap::new()
        }

        /// Creates an empty chain. This predictor has no settings, so `config` is ignored.
        fn new_from_config(_config: HashMap<String, impl Into<String>>) -> Self
        where
            Self: Sized,
        {
            HashMap::new()
        }

        /// Records every adjacent word pair of `query` after lowercasing it and splitting it
        /// on whitespace.
        ///
        /// Queries with fewer than two words contribute nothing. This implementation never
        /// returns `Err`.
        fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
            let lowercase_query = query.to_lowercase();
            let words: Vec<&str> = lowercase_query.split_whitespace().collect();
            for window in words.windows(2) {
                if let [first, second] = window {
                    // `words` is already lowercased, so keys need no further normalization.
                    *self
                        .entry(first.to_string())
                        .or_default()
                        .entry(second.to_string())
                        .or_insert(0) += 1;
                }
            }

            Ok(())
        }

        /// Suggests up to `n` continuations of `query`.
        ///
        /// The output has the form `["<query>",["<query> <w1>","<query> <w2>",...]]`, where the
        /// `wi` are the most frequent successors of the query's last word. That word is
        /// lowercased before lookup, so the lookup matches the keys written by `update`. All
        /// embedded text is JSON-escaped.
        ///
        /// Returns an empty string when the last word has no recorded successors, or when the
        /// query is empty.
        fn predict(&self, query: &str, n: usize) -> String {
            // Escapes `s` for use inside a JSON string literal. The query comes from the caller
            // and the words come from training data, so either may contain quotes, backslashes
            // or control characters.
            fn json_escape(s: &str) -> String {
                let mut out = String::with_capacity(s.len());
                for c in s.chars() {
                    match c {
                        '"' => out.push_str("\\\""),
                        '\\' => out.push_str("\\\\"),
                        '\n' => out.push_str("\\n"),
                        '\r' => out.push_str("\\r"),
                        '\t' => out.push_str("\\t"),
                        c if (c as u32) < 0x20 => {
                            out.push_str(&format!("\\u{:04x}", c as u32))
                        }
                        c => out.push(c),
                    }
                }
                out
            }

            // Keys are stored lowercased by `update`, so normalize the lookup word the same way.
            let last_word = query
                .split_whitespace()
                .last()
                .unwrap_or("")
                .to_lowercase();

            match get_top_following_words(self, &last_word, n) {
                Some(top_words) => {
                    // Trim trailing whitespace so "foo " yields "foo bar", not "foo  bar".
                    let prefix = json_escape(query.trim_end());
                    let predictions: Vec<String> = top_words
                        .into_iter()
                        .map(|(word, _)| format!("\"{} {}\"", prefix, json_escape(&word)))
                        .collect();
                    format!("[\"{}\",[{}]]", json_escape(query), predictions.join(","))
                }
                None => String::new(),
            }
        }
    }
    
    /// Returns up to `top_n` words observed after `word`, most frequent first.
    ///
    /// Words with equal counts are ordered alphabetically, so the result is deterministic
    /// regardless of `HashMap` iteration order.
    ///
    /// Returns `None` when `word` has no entry in `markov_chain`.
    fn get_top_following_words(
        markov_chain: &MarkovChain,
        word: &str,
        top_n: usize,
    ) -> Option<Vec<(String, usize)>> {
        let following_words = markov_chain.get(word)?;

        // Rank borrowed keys first and clone only the survivors of `take`.
        let mut ranked: Vec<(&String, usize)> = following_words
            .iter()
            .map(|(candidate, &count)| (candidate, count))
            .collect();
        // Count descending, then word ascending. Keys are unique, so this ordering is total
        // and an unstable sort is still deterministic.
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

        Some(
            ranked
                .into_iter()
                .take(top_n)
                .map(|(candidate, count)| (candidate.clone(), count))
                .collect(),
        )
    }
    
    
    /// Builds a [`MarkovChain`] from the CSV file at `file_path`.
    ///
    /// The reader uses the `csv` crate's default configuration, so the first row is treated
    /// as a header. For every other row, the field at index `QUERY_COLUMN` (5) is fed to the
    /// chain's `update`. Rows that have no such field are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the file cannot be opened;
    /// - a record cannot be read as CSV;
    /// - updating the chain fails.
    pub fn from_file_path(file_path: &str) -> Result<MarkovChain, std::io::Error> {
        // Index of the query text within each CSV record.
        // NOTE(review): presumably tied to the source export's column layout; confirm.
        const QUERY_COLUMN: usize = 5;

        let file = File::open(file_path)?;
        let mut reader = ReaderBuilder::new().from_reader(file);
        let mut markov_chain: MarkovChain = MarkovChain::new();

        for result in reader.records() {
            // `csv::Error` converts into `std::io::Error`, which is what lets `?` work here.
            let record = result?;
            if let Some(query) = record.get(QUERY_COLUMN) {
                // Propagate failures instead of silently discarding the `Result`.
                markov_chain.update(query).map_err(|e| {
                    std::io::Error::new(std::io::ErrorKind::Other, e.to_string())
                })?;
            }
        }

        Ok(markov_chain)
    }