// SPDX-FileCopyrightText: 2024 Phil Höfer <phil@suma-ev.de> // SPDX-License-Identifier: AGPL-3.0-only use std::collections::HashMap; use std::fs::File; use std::io::{self, BufRead, BufReader}; use std::error::Error; use std::str::FromStr; use std::fs; use toml::Table; use csv::ReaderBuilder; type MarkovChain = HashMap<String, HashMap<String, usize>>; use tiny_http::{Server, Response}; fn main() -> Result<(), io::Error> { let config = read_config("config.toml").unwrap().parse::<Table>().unwrap(); let markov_chain = build_markov_chain("../../data/data.csv") .unwrap_or(build_markov_chain("data.csv") .unwrap_or_default()); let filtered_markov_chain = filter_markov_chain(&markov_chain,2); // Print the Markov Chain for verification for (key, values) in &markov_chain { //println!("{}: {:?}", key, values); } // Test the function let word = "wie"; // replace with the word you want to check if let Some(top_words) = get_top_following_words(&filtered_markov_chain, word, 3) { println!("Top following words for '{}': {:?}", word, top_words); } else { println!("No following words found for '{}'", word); } let server = Server::http("0.0.0.0:8000").unwrap(); for request in server.incoming_requests() { // println!("received request! method: {:?}, url: {:?}, headers: {:?}", // request.method(), // request.url(), // request.headers() // ); let query = get_query(request.url()); //println!("got query:{}", query.clone().unwrap()); match query { Ok(query) => { let prediction = predictn(&filtered_markov_chain, &query,5); //println!("Query: {}, Prediction:{}", query, prediction); let response = Response::from_string(prediction); request.respond(response); }, Err(e) => { //println!("Error: {}",e); } } } Ok(()) } fn read_config(file_path: &str) -> Result<String, io::Error> { fs::read_to_string(file_path) } fn get_query(request_url: &str) -> Result<String, url::ParseError> { let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after); //println!("parsed_url:{}", parsed_url); let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes()); for (key, value) in query_pairs { //println!("key:{}, value: {}", key, value); if key == "q" { return Ok(value.into_owned()); } } Ok(String::from_str("").unwrap()) } fn build_markov_chain(file_path: &str) -> Result<MarkovChain, io::Error> { let file = File::open(file_path)?; let mut reader = ReaderBuilder::new().from_reader(file); let mut markov_chain: MarkovChain = HashMap::new(); for result in reader.records() { let record = result?; if let Some(query) = record.get(5) { let lowercase_query = query.to_lowercase(); let words: Vec<&str> = lowercase_query.split_whitespace().collect(); for window in words.windows(2) { if let [first, second] = window { markov_chain .entry(first.to_string().to_lowercase()) .or_default() .entry(second.to_string()) .and_modify(|count| *count += 1) .or_insert(1); } } } } Ok(markov_chain) } fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovChain { let mut filtered_chain: MarkovChain = HashMap::new(); for (key, followers) in markov_chain { let filtered_followers: HashMap<String, usize> = followers .iter() .filter(|&(_, &count)| count >= min_count) .map(|(word, &count)| (word.clone(), count)) .collect(); if !filtered_followers.is_empty() { filtered_chain.insert(key.clone(), filtered_followers); } } filtered_chain } fn get_top_following_words( markov_chain: &MarkovChain, word: &str, top_n: usize, ) -> Option<Vec<(String, usize)>> { let following_words = markov_chain.get(word)?; // Collect and sort the words by their counts in descending order let mut sorted_words: Vec<(String, usize)> = following_words.iter() .map(|(word, &count)| (word.clone(), count)) .collect(); sorted_words.sort_by(|a, b| b.1.cmp(&a.1)); // Return the top N words Some(sorted_words.into_iter().take(top_n).collect()) } fn predict(markov_chain: &MarkovChain, query: &str) -> String { if let Some(top_words) = get_top_following_words(markov_chain, query.split_whitespace().last().unwrap_or(""), 1) { if let Some((predicted_word, _)) = top_words.first() { return format!("{} {}", query, predicted_word); } } String::new() } fn predictn(markov_chain: &MarkovChain, query: &str, n: usize) -> String { if let Some(top_words) = get_top_following_words(markov_chain, query.split_whitespace().last().unwrap_or(""), n) { let predictions: Vec<String> = top_words.into_iter() .map(|(word, _)| format!("\"{} {}\"",query, word)) .collect(); return format!("[\"{}\",[{}]]",query, predictions.join(",")); } String::new() }