// main.rs — authored by Phil Höfer
// SPDX-FileCopyrightText: 2024 Phil Höfer <phil@suma-ev.de>
// SPDX-License-Identifier: AGPL-3.0-only
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead, BufReader};
use std::error::Error;
use std::str::FromStr;
use csv::ReaderBuilder;
type MarkovChain = HashMap<String, HashMap<String, usize>>;
use tiny_http::{Server, Response};
use url::Url;
/// Entry point: builds the bigram model, prints one sample lookup, then serves
/// completions over HTTP on `0.0.0.0:8000`.
///
/// The model is loaded from `../../data/data.csv`, falling back to `data.csv` and
/// finally to an empty chain, so the server starts even without data. Followers seen
/// fewer than 10 times are discarded before serving.
///
/// Every request is answered: the body is the output of `predictn` for the request's
/// `q` parameter (up to 5 completions), or a 400 response if the query cannot be
/// extracted.
///
/// # Errors
/// Returns an `io::Error` if the HTTP server cannot be started.
fn main() -> Result<(), io::Error> {
    // `or_else` keeps the fallback lazy: the second CSV is only read when the first
    // load failed.
    let markov_chain = build_markov_chain("../../data/data.csv")
        .or_else(|_| build_markov_chain("data.csv"))
        .unwrap_or_default();
    let filtered_markov_chain = filter_markov_chain(&markov_chain, 10);
    // The filtered chain is an independent copy; release the unfiltered one before
    // entering the long-running serve loop.
    drop(markov_chain);

    // Startup sanity check: print the most frequent followers of a sample word.
    let word = "wie";
    match get_top_following_words(&filtered_markov_chain, word, 3) {
        Some(top_words) => println!("Top following words for '{}': {:?}", word, top_words),
        None => println!("No following words found for '{}'", word),
    }

    let server = Server::http("0.0.0.0:8000")
        .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

    for request in server.incoming_requests() {
        let response = match get_query(request.url()) {
            Ok(query) => Response::from_string(predictn(&filtered_markov_chain, &query, 5)),
            Err(e) => Response::from_string(format!("invalid query: {}", e)).with_status_code(400),
        };
        // A failed write (e.g. the client went away) must not take the server down,
        // but it should not vanish silently either.
        if let Err(e) = request.respond(response) {
            eprintln!("failed to send response: {}", e);
        }
    }
    Ok(())
}
fn get_query(request_url: &str) -> Result<String, url::ParseError> {
let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after);
//println!("parsed_url:{}", parsed_url);
let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes());
for (key, value) in query_pairs {
//println!("key:{}, value: {}", key, value);
if key == "q" {
return Ok(value.into_owned());
}
}
Ok(String::from_str("").unwrap())
}
fn build_markov_chain(file_path: &str) -> Result<MarkovChain, io::Error> {
let file = File::open(file_path)?;
let mut reader = ReaderBuilder::new().from_reader(file);
let mut markov_chain: MarkovChain = HashMap::new();
for result in reader.records() {
let record = result?;
if let Some(query) = record.get(5) {
let lowercase_query = query.to_lowercase();
let words: Vec<&str> = lowercase_query.split_whitespace().collect();
for window in words.windows(2) {
if let [first, second] = window {
markov_chain
.entry(first.to_string().to_lowercase())
.or_default()
.entry(second.to_string())
.and_modify(|count| *count += 1)
.or_insert(1);
}
}
}
}
Ok(markov_chain)
}
/// Returns a copy of `markov_chain` that keeps only followers whose count is at
/// least `min_count`.
///
/// Words left without any follower after filtering are omitted from the result.
/// The input chain is not modified.
fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovChain {
    markov_chain
        .iter()
        .filter_map(|(word, followers)| {
            let kept: HashMap<String, usize> = followers
                .iter()
                .filter(|(_, count)| **count >= min_count)
                .map(|(follower, count)| (follower.clone(), *count))
                .collect();
            if kept.is_empty() {
                None
            } else {
                Some((word.clone(), kept))
            }
        })
        .collect()
}
/// Returns up to `top_n` followers of `word`, most frequent first.
///
/// Followers with equal counts are ordered alphabetically, so the result does not
/// depend on `HashMap` iteration order and is stable across runs.
///
/// Returns `None` when `word` is not a key of `markov_chain`. Returns `Some` with
/// fewer than `top_n` entries (possibly none) when the word has fewer followers.
fn get_top_following_words(
    markov_chain: &MarkovChain,
    word: &str,
    top_n: usize,
) -> Option<Vec<(String, usize)>> {
    let following_words = markov_chain.get(word)?;

    // Rank borrowed entries and clone only the ones that are actually returned.
    let mut ranked: Vec<(&String, usize)> = following_words
        .iter()
        .map(|(follower, &count)| (follower, count))
        .collect();
    // Descending by count, then ascending by word as a deterministic tie-break.
    ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

    Some(
        ranked
            .into_iter()
            .take(top_n)
            .map(|(follower, count)| (follower.clone(), count))
            .collect(),
    )
}
/// Completes `query` with the single most frequent follower of its last word.
///
/// The last whitespace-separated word of `query` is looked up as typed first and,
/// if that finds nothing, in lower case, because the chain built in this file stores
/// lower-cased words. Trailing whitespace in `query` is dropped before the predicted
/// word is appended, so a query ending in a space does not produce a double space.
///
/// Returns `"<query> <word>"`, or an empty string when no follower is known.
fn predict(markov_chain: &MarkovChain, query: &str) -> String {
    let last_word = query.split_whitespace().last().unwrap_or("");
    get_top_following_words(markov_chain, last_word, 1)
        .or_else(|| get_top_following_words(markov_chain, &last_word.to_lowercase(), 1))
        .and_then(|top_words| top_words.into_iter().next())
        .map(|(predicted_word, _)| format!("{} {}", query.trim_end(), predicted_word))
        .unwrap_or_default()
}
/// Builds a JSON suggestion list with up to `n` completions of `query`.
///
/// The output has the shape `["<query>",["<query> <word1>","<query> <word2>",...]]`.
/// The first element echoes `query` unchanged; each completion appends one follower
/// of the query's last word to the query with trailing whitespace removed.
///
/// The last word is looked up as typed first and, if that finds nothing, in lower
/// case, because the chain built in this file stores lower-cased words. All strings
/// are JSON-escaped, so quotes, backslashes or control characters in the
/// user-supplied query cannot break the document.
///
/// Returns an empty string (not JSON) when the last word has no entry in the chain.
fn predictn(markov_chain: &MarkovChain, query: &str, n: usize) -> String {
    let last_word = query.split_whitespace().last().unwrap_or("");
    let top_words = get_top_following_words(markov_chain, last_word, n)
        .or_else(|| get_top_following_words(markov_chain, &last_word.to_lowercase(), n));

    match top_words {
        Some(top_words) => {
            let prefix = query.trim_end();
            let predictions: Vec<String> = top_words
                .iter()
                .map(|(word, _)| {
                    format!("\"{}\"", escape_json_string(&format!("{} {}", prefix, word)))
                })
                .collect();
            format!(
                "[\"{}\",[{}]]",
                escape_json_string(query),
                predictions.join(",")
            )
        }
        None => String::new(),
    }
}

/// Escapes `raw` for use inside a double-quoted JSON string literal.
fn escape_json_string(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            // Remaining control characters must be written as \u escapes in JSON.
            c if (c as u32) < 0x20 => escaped.push_str(&format!("\\u{:04x}", c as u32)),
            c => escaped.push(c),
        }
    }
    escaped
}