-
Phil Höfer authoredPhil Höfer authored
main.rs 7.20 KiB
// SPDX-FileCopyrightText: 2024 Phil Höfer <phil@suma-ev.de>
// SPDX-License-Identifier: AGPL-3.0-only
use core::str;
use std::collections::HashMap;
use std::fs::File;
use std::io::{self, BufRead, BufReader};
use std::error::Error;
use std::str::FromStr;
use std::fs;
use predictors::basic_set::SetPredictor;
use toml::Table;
use csv::ReaderBuilder;
use tiny_http::{Server, Response};
mod predictors;
use predictors::{basic_markov,basic_set, Predictor};
use predictors::basic_markov::MarkovChain;
use postgres::{Client, NoTls};
fn main() -> Result<(), io::Error> {
let mut last_update = std::time::SystemTime::now();
let config_toml = read_config("data/config.toml").or(read_config("config.toml")).unwrap().parse::<Table>().unwrap();
let mut config = HashMap::new();
for (key, value) in config_toml {
match value {
toml::Value::String(s) => {
config.insert(key, s);
}
toml::Value::Integer(i) => {
config.insert(key, i.to_string());
}
_ => {
// Ignore other types
}
}
}
let mut markov_chain = basic_set::from_file_path_and_config(
vec!["../../data/data.csv","data/data.csv","data.csv","data_full.csv"],config.clone())
.unwrap_or(basic_set::SetPredictor::new());
markov_chain = read_from_db(config.clone(), markov_chain);
markov_chain.decay();
// let term_frequency_threshold = match config.get("term_frequency_threshold") {
// Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
// _ => 2
// };
//let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold);
// Print the Markov Chain for verification
// for (key, values) in &markov_chain {
// //println!("{}: {:?}", key, values);
// }
// Test the function
let word = "wie"; // replace with the word you want to check
let top_words = &markov_chain.predict(&word, 3);
println!("Example prediction for '{}': {:?}", word, top_words);
let server = Server::http("0.0.0.0:8000").unwrap();
for request in server.incoming_requests() {
// println!("received request! method: {:?}, url: {:?}, headers: {:?}",
// request.method(),
// request.url(),
// request.headers()
// );
match config.get("auth") {
Some(server_auth) => {
match get_authkey(request.url().clone()) {
Ok(client_auth) => {
if client_auth != server_auth.clone() {
println!("invalid auth:{}, server auth: {}", client_auth, server_auth);
request.respond(Response::from_string(""));
continue;
}
},
_ => {}
}
},
_ => {}
}
let query = get_query(request.url());
//println!("got query:{}", query.clone().unwrap());
match query {
Ok(query) => {
let predict_count = match config.get("max_predict_count") {
Some(n) if n.parse::<usize>().unwrap_or_default() > 0 => n.parse::<usize>().unwrap_or(5),
_ => 5
};
let prediction = &markov_chain.predict(&query, predict_count);
//println!("Query: {}, Prediction:{}", query, prediction);
let response = Response::from_string(prediction);
request.respond(response);
let now = std::time::SystemTime::now();
let elapsed = now.duration_since(last_update).expect("Time went backwards");
if elapsed >= std::time::Duration::from_secs(24 * 60 * 60) {
markov_chain.decay();
markov_chain = read_from_db(config.clone(), markov_chain);
last_update = now;
}
},
Err(e) => {
//println!("Error: {}",e);
}
}
}
Ok(())
}
fn read_from_db(config: HashMap<String, String>, mut predictor: SetPredictor) -> SetPredictor {
let default_password = &String::from("");
match (config.get("db_host"), config.get("db_password")) {
(Some(db_host), Some(db_password)) => {
let mut client = Client::connect(format!("host={} user=metager password={} dbname=postgres",db_host, db_password).as_str(), NoTls);
match client {
Ok(mut c) => {
let mut count = 0;
match c.query("SELECT query FROM public.logs_partitioned ORDER BY time DESC LIMIT 100000", &[]) {
Ok(rows) => {
for row in rows {
let query: &str = row.get(0);
predictor.update(query);
count += 1;
}
println!("{} queries read from DB", count);
predictor
},
Err(e) => {
println!("Error while reading from DB: {}", e);
predictor
}
}
},
Err(e) => {
println!("Error while connecting to DB: {}", e);
predictor
}
}
},
_ => {
println!("No DB config found. Skipping...");
predictor
}
}
}
/// Reads the entire file at `file_path` into a `String`.
///
/// # Errors
/// Any I/O error from opening or reading the file is returned unchanged,
/// including contents that are not valid UTF-8.
fn read_config(file_path: &str) -> Result<String, io::Error> {
    let contents = fs::read_to_string(file_path)?;
    Ok(contents)
}
//TODO: unify get_X() functions
fn get_query(request_url: &str) -> Result<String, url::ParseError> {
let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after);
//println!("parsed_url:{}", parsed_url);
let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes());
for (key, value) in query_pairs {
//println!("key:{}, value: {}", key, value);
if key == "q" {
return Ok(value.into_owned());
}
}
Ok(String::from_str("").unwrap())
}
fn get_authkey(request_url: &str) -> Result<String, url::ParseError> {
let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after);
//println!("parsed_url:{}", parsed_url);
let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes());
for (key, value) in query_pairs {
//println!("key:{}, value: {}", key, value);
if key == "auth" {
return Ok(value.into_owned());
}
}
Ok(String::from_str("").unwrap())
}
/// Returns a copy of `markov_chain` that keeps only transitions seen at least
/// `min_count` times.
///
/// A word whose followers are all filtered out is dropped from the result, so
/// every remaining entry has at least one follower.
fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovChain {
    markov_chain
        .iter()
        .filter_map(|(word, successors)| {
            let frequent: HashMap<String, usize> = successors
                .iter()
                .filter(|(_, seen)| **seen >= min_count)
                .map(|(next, seen)| (next.clone(), *seen))
                .collect();
            if frequent.is_empty() {
                None
            } else {
                Some((word.clone(), frequent))
            }
        })
        .collect()
}