Skip to content
Snippets Groups Projects
main.rs 4.8 KiB
Newer Older
  • Learn to ignore specific revisions
    // SPDX-FileCopyrightText: 2024 Phil Höfer <phil@suma-ev.de>
    // SPDX-License-Identifier: AGPL-3.0-only
    
    
    Phil Höfer's avatar
    Phil Höfer committed
    use std::collections::HashMap;
    use std::fs::File;
    use std::io::{self, BufRead, BufReader};
    use std::error::Error;
    use std::str::FromStr;
    
    use csv::ReaderBuilder;
    
    
    /// Word-transition table: maps a word to the words observed directly after it,
    /// each paired with the number of times that transition was counted.
    type MarkovChain = HashMap<String, HashMap<String, usize>>;
    
    use tiny_http::{Server, Response};
    use url::Url;
    
    
    
    fn main() -> Result<(), io::Error> {
    
    
    Phil Höfer's avatar
    Phil Höfer committed
        let markov_chain = build_markov_chain("data.csv")?;
    
    Phil Höfer's avatar
    Phil Höfer committed
        let filtered_markov_chain = filter_markov_chain(&markov_chain,1);
     
        // Print the Markov Chain for verification
        for (key, values) in &markov_chain {
            // println!("{}: {:?}", key, values);
        }
    
        // Test the function
        let word = "wie"; // replace with the word you want to check
        if let Some(top_words) = get_top_following_words(&filtered_markov_chain, word, 3) {
            println!("Top following words for '{}': {:?}", word, top_words);
        } else {
            println!("No following words found for '{}'", word);
        }
    
    
    
        let server = Server::http("0.0.0.0:80").unwrap();
    
        for request in server.incoming_requests() {
            //  println!("received request! method: {:?}, url: {:?}, headers: {:?}",
            //      request.method(),
            //      request.url(),
            //      request.headers()
            //  );
            let query = get_query(request.url());
            match query {
                Ok(query) => {
                    let prediction = predictn(&filtered_markov_chain, &query,5);
                    println!("Query: {}, Prediction:{}", query, prediction);
                    let response = Response::from_string(prediction);
                    request.respond(response);
                },
                Err(e) => {
                    println!("Error: {}",e);
                }
            }
        }
    
        Ok(())
    }
    
    fn get_query(request_url: &str) -> Result<String, url::ParseError> {
        let parsed_url = request_url.strip_prefix("/?").unwrap_or(request_url);
        let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes());
        for (key, value) in query_pairs {
            if key == "q" {
                return Ok(value.into_owned());
            }
        }
        Ok(String::from_str("").unwrap())
    }
    
    /// Reads the CSV file at `file_path` and counts, for every pair of adjacent
    /// words in the query column, how often the second word follows the first.
    ///
    /// Queries are lower-cased and split on whitespace before counting. Records
    /// that have no value in the query column are skipped.
    ///
    /// # Errors
    /// Returns an `io::Error` when the file cannot be opened or a record cannot be
    /// read.
    fn build_markov_chain(file_path: &str) -> Result<MarkovChain, io::Error> {
        // Zero-based index of the CSV column that holds the query text.
        const QUERY_COLUMN: usize = 5;

        let file = File::open(file_path)?;
        let mut reader = ReaderBuilder::new().from_reader(file);
        let mut markov_chain: MarkovChain = HashMap::new();

        for result in reader.records() {
            let record = result?;
            let query = match record.get(QUERY_COLUMN) {
                Some(query) => query,
                None => continue,
            };

            // Lower-case once per query; the individual words are then already normalised.
            let lowercase_query = query.to_lowercase();
            let words: Vec<&str> = lowercase_query.split_whitespace().collect();
            for pair in words.windows(2) {
                // `windows(2)` only ever yields slices of exactly two elements.
                let (first, second) = (pair[0], pair[1]);
                *markov_chain
                    .entry(first.to_owned())
                    .or_default()
                    .entry(second.to_owned())
                    .or_insert(0) += 1;
            }
        }

        Ok(markov_chain)
    }
    
    /// Returns a copy of `markov_chain` that keeps only transitions counted at
    /// least `min_count` times.
    ///
    /// Words whose followers are all removed by the threshold are left out of the
    /// result entirely. The input chain is not modified.
    fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovChain {
        markov_chain
            .iter()
            .filter_map(|(word, successors)| {
                let kept: HashMap<String, usize> = successors
                    .iter()
                    .filter(|&(_, &occurrences)| occurrences >= min_count)
                    .map(|(next, &occurrences)| (next.clone(), occurrences))
                    .collect();

                if kept.is_empty() {
                    None
                } else {
                    Some((word.clone(), kept))
                }
            })
            .collect()
    }
    
    /// Returns up to `top_n` words that most often follow `word`, together with
    /// their counts, ordered by descending count.
    ///
    /// Words with equal counts are ordered alphabetically so that the result does
    /// not depend on the hash map's iteration order.
    ///
    /// Returns `None` when `word` has no entry in `markov_chain`.
    fn get_top_following_words(
        markov_chain: &MarkovChain,
        word: &str,
        top_n: usize,
    ) -> Option<Vec<(String, usize)>> {
        let following_words = markov_chain.get(word)?;

        // Rank borrowed entries first so that only the returned words are cloned.
        let mut ranked: Vec<(&String, usize)> = following_words
            .iter()
            .map(|(follower, &count)| (follower, count))
            .collect();
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

        Some(
            ranked
                .into_iter()
                .take(top_n)
                .map(|(follower, count)| (follower.clone(), count))
                .collect(),
        )
    }
    
    /// Returns `query` followed by a space and its most frequent following word.
    ///
    /// Yields an empty string when `query` is unknown to `markov_chain` or has no
    /// following words.
    fn predict(markov_chain: &MarkovChain, query: &str) -> String {
        get_top_following_words(markov_chain, query, 1)
            .and_then(|candidates| candidates.into_iter().next())
            .map(|(next_word, _)| format!("{} {}", query, next_word))
            .unwrap_or_default()
    }
    
    /// Builds a JSON suggestion array of the form `["query",["query next1",...]]`
    /// from the up to `n` most frequent words following `query`.
    ///
    /// The query and every completion are escaped as JSON strings, so quotes,
    /// backslashes and control characters cannot break the structure of the output.
    ///
    /// Returns an empty string when `query` has no entry in `markov_chain`.
    fn predictn(markov_chain: &MarkovChain, query: &str, n: usize) -> String {
        let top_words = match get_top_following_words(markov_chain, query, n) {
            Some(words) => words,
            None => return String::new(),
        };

        let predictions: Vec<String> = top_words
            .iter()
            .map(|(word, _)| {
                format!("\"{}\"", escape_json_string(&format!("{} {}", query, word)))
            })
            .collect();

        format!(
            "[\"{}\",[{}]]",
            escape_json_string(query),
            predictions.join(",")
        )
    }

    /// Escapes `text` for use inside a double-quoted JSON string literal.
    fn escape_json_string(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for ch in text.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                // Remaining control characters must be written as \uXXXX escapes.
                c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
                c => escaped.push(c),
            }
        }
        escaped
    }