// SPDX-FileCopyrightText: 2024 Phil Höfer <phil@suma-ev.de>
// SPDX-License-Identifier: AGPL-3.0-only

use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead, BufReader};
use std::error::Error;
use std::str::FromStr;
use std::fs;

use toml::Table;

use csv::ReaderBuilder;


/// Bigram frequency table: maps a word to the words that were observed directly
/// after it, each paired with the number of times that succession was counted.
type MarkovChain = HashMap<String, HashMap<String, usize>>;

use tiny_http::{Server, Response};



fn main() -> Result<(), io::Error> {

Phil Höfer's avatar
Phil Höfer committed
    let config = read_config("config.toml").unwrap().parse::<Table>().unwrap();

Phil Höfer's avatar
Phil Höfer committed
    let markov_chain = build_markov_chain("../../data/data.csv")
                        .unwrap_or(build_markov_chain("data.csv")
                            .unwrap_or_default());


    let term_frequency_threshold = match config.get("term_frequency_threshold") {
        Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
        _ => 2
    };
    let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold);
Phil Höfer's avatar
Phil Höfer committed
    // Print the Markov Chain for verification
    for (key, values) in &markov_chain {
Phil Höfer's avatar
Phil Höfer committed
         //println!("{}: {:?}", key, values);
Phil Höfer's avatar
Phil Höfer committed
    }

    // Test the function
    let word = "wie"; // replace with the word you want to check
    if let Some(top_words) = get_top_following_words(&filtered_markov_chain, word, 3) {
        println!("Top following words for '{}': {:?}", word, top_words);
    } else {
        println!("No following words found for '{}'", word);
    }



Phil Höfer's avatar
Phil Höfer committed
    let server = Server::http("0.0.0.0:8000").unwrap();
Phil Höfer's avatar
Phil Höfer committed

    for request in server.incoming_requests() {
Phil Höfer's avatar
Phil Höfer committed
        //   println!("received request! method: {:?}, url: {:?}, headers: {:?}",
        //       request.method(),
        //       request.url(),
        //       request.headers()
        //   );
Phil Höfer's avatar
Phil Höfer committed
        let query = get_query(request.url());
Phil Höfer's avatar
Phil Höfer committed
        //println!("got query:{}", query.clone().unwrap());
Phil Höfer's avatar
Phil Höfer committed
        match query {
            Ok(query) => {
                let predict_count = match config.get("max_predict_count") {
                    Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
                    _ => 5
                };
                let prediction = predictn(&filtered_markov_chain, &query, predict_count);
Phil Höfer's avatar
Phil Höfer committed
                //println!("Query: {}, Prediction:{}", query, prediction);
Phil Höfer's avatar
Phil Höfer committed
                let response = Response::from_string(prediction);
                request.respond(response);
            },
            Err(e) => {
Phil Höfer's avatar
Phil Höfer committed
                //println!("Error: {}",e);
Phil Höfer's avatar
Phil Höfer committed
/// Loads the whole configuration file at `file_path` into a `String`.
///
/// # Errors
/// Forwards the `io::Error` reported by the filesystem read (for example when
/// the file is missing or is not valid UTF-8).
fn read_config(file_path: &str) -> Result<String, io::Error> {
    let contents = fs::read_to_string(file_path)?;
    Ok(contents)
}

/// Extracts the value of the `q` parameter from a request URL.
///
/// Everything after the first `?` is treated as the query string; when the URL
/// contains no `?`, the whole input is parsed as if it were the query string.
/// Pairs are decoded by `url::form_urlencoded::parse`, and the value of the
/// first pair whose key is exactly `q` is returned.
///
/// # Returns
/// `Ok` with the decoded value, or `Ok` with an empty string when no `q`
/// parameter is present. The current implementation never returns `Err`; the
/// `Result` type is kept so existing callers keep compiling.
fn get_query(request_url: &str) -> Result<String, url::ParseError> {
    let query_string = request_url
        .split_once('?')
        .map_or(request_url, |(_, after)| after);

    let query = url::form_urlencoded::parse(query_string.as_bytes())
        .find_map(|(key, value)| (key == "q").then(|| value.into_owned()))
        .unwrap_or_default();

    Ok(query)
}

/// Builds the bigram table from the CSV file at `file_path`.
///
/// For every record, the text in column `QUERY_COLUMN` is lowercased and split
/// on whitespace; each pair of adjacent words increments the counter stored
/// under `chain[first][second]`. Records that lack that column are ignored.
///
/// NOTE(review): the reader is created with `ReaderBuilder`'s default settings;
/// per the csv crate documentation that means the first row is treated as a
/// header and is not counted — confirm this matches the data file.
///
/// # Errors
/// Returns an `io::Error` when the file cannot be opened or a record cannot be
/// read (CSV errors are converted by `?`).
fn build_markov_chain(file_path: &str) -> Result<MarkovChain, io::Error> {
    // Zero-based index of the CSV column that holds the query text.
    const QUERY_COLUMN: usize = 5;

    let file = File::open(file_path)?;
    let mut reader = ReaderBuilder::new().from_reader(file);
    let mut markov_chain: MarkovChain = HashMap::new();

    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(QUERY_COLUMN) {
            // Lowercase once for the whole query; the individual words are
            // slices of this string and need no further case conversion.
            let lowercase_query = query.to_lowercase();
            let words: Vec<&str> = lowercase_query.split_whitespace().collect();
            for pair in words.windows(2) {
                if let [first, second] = pair {
                    *markov_chain
                        .entry((*first).to_owned())
                        .or_default()
                        .entry((*second).to_owned())
                        .or_insert(0) += 1;
                }
            }
        }
    }

    Ok(markov_chain)
}

/// Produces a pruned copy of `markov_chain`.
///
/// A follower is kept only when its count is at least `min_count`; a leading
/// word whose followers were all pruned is left out of the result entirely.
/// The input chain is not modified.
fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovChain {
    markov_chain
        .iter()
        .filter_map(|(leading_word, successors)| {
            let kept: HashMap<String, usize> = successors
                .iter()
                .filter(|(_, occurrences)| **occurrences >= min_count)
                .map(|(next_word, &occurrences)| (next_word.clone(), occurrences))
                .collect();

            if kept.is_empty() {
                None
            } else {
                Some((leading_word.clone(), kept))
            }
        })
        .collect()
}

/// Returns up to `top_n` followers of `word`, most frequent first.
///
/// Followers with equal counts are ordered alphabetically so that the result
/// is the same on every call; ordering ties by `HashMap` iteration order would
/// make the output vary between runs.
///
/// # Returns
/// `None` when `word` has no entry in `markov_chain`; otherwise `Some` with at
/// most `top_n` `(follower, count)` pairs (an empty vector when `top_n` is 0).
fn get_top_following_words(
    markov_chain: &MarkovChain,
    word: &str,
    top_n: usize,
) -> Option<Vec<(String, usize)>> {
    let following_words = markov_chain.get(word)?;

    // Rank borrowed entries first so only the survivors have to be cloned.
    let mut ranked: Vec<(&String, usize)> = following_words
        .iter()
        .map(|(follower, &count)| (follower, count))
        .collect();
    ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

    Some(
        ranked
            .into_iter()
            .take(top_n)
            .map(|(follower, count)| (follower.clone(), count))
            .collect(),
    )
}

fn predict(markov_chain: &MarkovChain, query: &str) -> String {
Phil Höfer's avatar
Phil Höfer committed
    if let Some(top_words) = get_top_following_words(markov_chain, query.split_whitespace().last().unwrap_or(""), 1) {
Phil Höfer's avatar
Phil Höfer committed
        if let Some((predicted_word, _)) = top_words.first() {
            return format!("{} {}", query, predicted_word);
        }
    }
    String::new()
}

fn predictn(markov_chain: &MarkovChain, query: &str, n: usize) -> String {
Phil Höfer's avatar
Phil Höfer committed
    if let Some(top_words) = get_top_following_words(markov_chain, query.split_whitespace().last().unwrap_or(""), n) {
Phil Höfer's avatar
Phil Höfer committed
        let predictions: Vec<String> = top_words.into_iter()
            .map(|(word, _)| format!("\"{} {}\"",query, word))
            .collect();
        return format!("[\"{}\",[{}]]",query, predictions.join(","));
    }
    String::new()
}