Newer
Older
// SPDX-FileCopyrightText: 2024 Phil Höfer <phil@suma-ev.de>
// SPDX-License-Identifier: AGPL-3.0-only
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead, BufReader};
use std::error::Error;
use std::str::FromStr;
use csv::ReaderBuilder;
use tiny_http::{Server, Response};
mod predictors;
use predictors::{basic_markov, Predictor};
use predictors::basic_markov::MarkovChain;
let config = read_config("config.toml").unwrap().parse::<Table>().unwrap();
let markov_chain = basic_markov::from_file_path("../../data/data.csv")
.unwrap_or(basic_markov::from_file_path("data.csv")
let term_frequency_threshold = match config.get("term_frequency_threshold") {
Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
_ => 2
};
let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold);
// Print the Markov Chain for verification
for (key, values) in &markov_chain {
}
// Test the function
let word = "wie"; // replace with the word you want to check
let top_words = &filtered_markov_chain.predict(&word, 3);
println!("Example prediction for '{}': {:?}", word, top_words);
// println!("received request! method: {:?}, url: {:?}, headers: {:?}",
// request.method(),
// request.url(),
// request.headers()
// );
match config.get("auth") {
Some(toml::Value::String(server_auth)) => {
match get_authkey(request.url().clone()) {
Ok(client_auth) => {
if client_auth != server_auth.clone() {
println!("invalid auth:{}, server auth: {}", client_auth, server_auth);
request.respond(Response::from_string(""));
continue;
}
},
_ => {}
}
},
_ => {}
}
let predict_count = match config.get("max_predict_count") {
Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
_ => 5
};
let prediction = &filtered_markov_chain.predict(&query, predict_count);
let response = Response::from_string(prediction);
request.respond(response);
},
Err(e) => {
/// Reads the entire configuration file at `file_path` into a `String`.
///
/// # Errors
///
/// Returns the `io::Error` produced by `std::fs::read_to_string`, e.g. when
/// the file does not exist, cannot be read, or is not valid UTF-8.
fn read_config(file_path: &str) -> Result<String, io::Error> {
    // Fully qualified on purpose: only `std::fs::File` is imported at the top
    // of the file, so a bare `fs::` path is not guaranteed to resolve.
    std::fs::read_to_string(file_path)
}
fn get_query(request_url: &str) -> Result<String, url::ParseError> {
let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after);
//println!("parsed_url:{}", parsed_url);
let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes());
for (key, value) in query_pairs {
if key == "q" {
return Ok(value.into_owned());
}
}
Ok(String::from_str("").unwrap())
}
fn get_authkey(request_url: &str) -> Result<String, url::ParseError> {
let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after);
//println!("parsed_url:{}", parsed_url);
let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes());
for (key, value) in query_pairs {
//println!("key:{}, value: {}", key, value);
if key == "auth" {
return Ok(value.into_owned());
}
}
Ok(String::from_str("").unwrap())
}
/// Builds a copy of `markov_chain` that keeps only sufficiently frequent
/// transitions.
///
/// For every key, followers whose count is below `min_count` are dropped.
/// Keys left without any follower are omitted from the result entirely.
/// The input chain is not modified.
fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovChain {
    markov_chain
        .iter()
        .filter_map(|(state, transitions)| {
            let kept: HashMap<String, usize> = transitions
                .iter()
                .filter(|(_, occurrences)| **occurrences >= min_count)
                .map(|(next_word, occurrences)| (next_word.clone(), *occurrences))
                .collect();

            // A state with no surviving follower carries no information.
            if kept.is_empty() {
                None
            } else {
                Some((state.clone(), kept))
            }
        })
        .collect()
}