Skip to content
Snippets Groups Projects
basic_markov.rs 3.6 KiB
Newer Older
  • Learn to ignore specific revisions
  • use std::{collections::HashMap, fs::File};
    
    use csv::ReaderBuilder;
    
    use super::Predictor;
    
    /// Word-transition table: maps a word to the words observed immediately
    /// after it, each paired with the number of times that pair was recorded.
    pub type MarkovChain = HashMap<String, HashMap<String, usize>>;
    
    
    /// Next-word predictor backed by a [`MarkovChain`] of word-pair counts.
    #[derive(Debug, Clone, Default)]
    pub struct MarkovChainPredictor {
        /// Word-pair counts accumulated by `update`.
        chain: MarkovChain,
        /// Free-form string settings; `predict` reads `term_frequency_threshold`.
        configuration: HashMap<String, String>,
    }
    
    impl Predictor for MarkovChainPredictor {
        /// Creates a predictor with an empty chain and an empty configuration.
        fn new() -> Self {
            MarkovChainPredictor {
                chain: HashMap::new(),
                configuration: HashMap::new(),
            }
        }

        /// Creates a predictor with an empty chain and the given configuration.
        ///
        /// Every value is converted into an owned `String`. The only key read
        /// by this implementation is `term_frequency_threshold` (in `predict`).
        fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self
        where
            Self: Sized,
        {
            MarkovChainPredictor {
                chain: HashMap::new(),
                configuration: config
                    .into_iter()
                    .map(|(key, value)| (key, value.into()))
                    .collect(),
            }
        }

        /// Records every pair of adjacent words in `query` in the chain.
        ///
        /// The query is lowercased and split on whitespace; for each adjacent
        /// pair `(first, second)` the count stored under `chain[first][second]`
        /// is incremented. Queries with fewer than two words add nothing.
        ///
        /// # Errors
        ///
        /// This implementation always returns `Ok(())`.
        fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
            let lowercase_query = query.to_lowercase();
            let words: Vec<&str> = lowercase_query.split_whitespace().collect();

            for pair in words.windows(2) {
                if let [first, second] = pair {
                    // Both words are already lowercase, so they are stored as-is.
                    *self
                        .chain
                        .entry((*first).to_owned())
                        .or_default()
                        .entry((*second).to_owned())
                        .or_insert(0) += 1;
                }
            }

            Ok(())
        }

        /// Suggests up to `n` continuations for `query`, based on its last word.
        ///
        /// The result has the shape `["<query>",["<query> <word>",...]]`, with
        /// `query` echoed exactly as given. An empty string is returned when the
        /// last word of `query` has no entry in the chain (including when
        /// `query` contains no words).
        ///
        /// Candidates seen fewer than `term_frequency_threshold` times are
        /// dropped. The threshold comes from the configuration and falls back
        /// to 2 when the key is absent or its value is not a valid `usize`.
        fn predict(&self, query: &str, n: usize) -> String {
            const DEFAULT_TERM_FREQUENCY_THRESHOLD: usize = 2;

            // A malformed setting must not panic inside a prediction call.
            let min_freq = self
                .configuration
                .get("term_frequency_threshold")
                .and_then(|raw| raw.parse::<usize>().ok())
                .unwrap_or(DEFAULT_TERM_FREQUENCY_THRESHOLD);

            // `update` stores lowercase keys, so the lookup key has to be
            // lowercased too; otherwise "The" would never match "the".
            let last_word = query
                .split_whitespace()
                .last()
                .unwrap_or("")
                .to_lowercase();

            match get_top_following_words(&self.chain, &last_word, n, min_freq) {
                Some(top_words) => {
                    let predictions: Vec<String> = top_words
                        .into_iter()
                        .map(|(word, _)| format!("\"{} {}\"", query, word))
                        .collect();
                    format!("[\"{}\",[{}]]", query, predictions.join(","))
                }
                None => String::new(),
            }
        }
    }
    
    fn get_top_following_words(
        markov_chain: &MarkovChain,
        word: &str,
        top_n: usize,
    
    ) -> Option<Vec<(String, usize)>> {
        let following_words = markov_chain.get(word)?;
    
        // Collect and sort the words by their counts in descending order
        let mut sorted_words: Vec<(String, usize)> = following_words
            .iter()
            .map(|(word, &count)| (word.clone(), count))
    
            .filter(|&(_, count)| count >= min_freq)
    
            .collect();
        sorted_words.sort_by(|a, b| b.1.cmp(&a.1));
    
        // Return the top N words
        Some(sorted_words.into_iter().take(top_n).collect())
    }
    
    
    
    /// Builds a predictor by feeding it one query per record of a CSV file.
    ///
    /// The file at `file_path` is read with the `csv` crate's default
    /// `ReaderBuilder` settings. For each record, the field at index 5 (the
    /// sixth column) is passed to `update`; records with fewer than six fields
    /// are skipped. The returned predictor has an empty configuration.
    ///
    /// NOTE(review): index 5 presumably matches the query column of the
    /// expected input file — confirm against the data set in use.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened, if a record cannot be
    /// read, or if `update` reports a failure.
    pub fn from_file_path(file_path: &str) -> Result<MarkovChainPredictor, std::io::Error> {
        // Zero-based index of the CSV field that holds the query text.
        const QUERY_COLUMN: usize = 5;

        let file = File::open(file_path)?;
        let mut reader = ReaderBuilder::new().from_reader(file);

        let mut markov_chain: MarkovChainPredictor = MarkovChainPredictor::new();

        for result in reader.records() {
            let record = result?;
            if let Some(query) = record.get(QUERY_COLUMN) {
                // Surface update failures instead of dropping the Result. The
                // boxed error is not Send + Sync, so carry its message across.
                markov_chain.update(query).map_err(|err| {
                    std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
                })?;
            }
        }

        Ok(markov_chain)
    }
    
    
    /// Builds a predictor from a CSV file and attaches a configuration to it.
    ///
    /// The chain is loaded with `from_file_path`; afterwards the predictor's
    /// configuration is replaced by `config`, with every value converted into
    /// an owned `String`. If `term_frequency_threshold` is present in the
    /// configuration its value is printed to standard output.
    ///
    /// # Errors
    ///
    /// Returns whatever error `from_file_path` returns.
    pub fn from_file_path_and_config(file_path: &str, config: HashMap<String, impl Into<String>>) -> Result<MarkovChainPredictor, std::io::Error> {
        let mut markov_chain: MarkovChainPredictor = from_file_path(file_path)?;

        markov_chain.configuration = config
            .into_iter()
            .map(|(key, value)| (key, value.into()))
            .collect();

        // The threshold is optional, so a missing key must not panic here.
        if let Some(threshold) = markov_chain.configuration.get("term_frequency_threshold") {
            println!("{}", threshold);
        }

        Ok(markov_chain)
    }