// SPDX-FileCopyrightText: 2024 Phil Höfer <phil@suma-ev.de>
// SPDX-License-Identifier: AGPL-3.0-only

use std::collections::HashMap;
use std::fs;
use std::io;
use std::time::SystemTime;

use postgres::{Client, NoTls};
use tiny_http::{Response, Server};
use toml::Table;

mod importers;
mod predictors;

use importers::file::FileImporter;
use importers::Importer;
use predictors::basic_markov::MarkovChain;
use predictors::composite::CompositePredictor;
use predictors::Predictor;

fn main() -> Result<(), io::Error> {
    let mut last_update = SystemTime::now();

    // Prefer data/config.toml, fall back to a config.toml next to the binary.
    let config_toml = read_config("data/config.toml")
        .or_else(|_| read_config("config.toml"))
        .expect("no readable config.toml found")
        .parse::<Table>()
        .expect("config.toml is not valid TOML");

    // Flatten the TOML table into a string-to-string map; other value types are ignored.
    let mut config = HashMap::new();
    for (key, value) in config_toml {
        match value {
            toml::Value::String(s) => {
                config.insert(key, s);
            }
            toml::Value::Integer(i) => {
                config.insert(key, i.to_string());
            }
            _ => {} // Ignore other types.
        }
    }

    let mut markov_chain = CompositePredictor::new_from_config(config.clone());

    // Seed the predictor from whichever data files exist.
    for path in ["../../data/data.csv", "data/data.csv", "data.csv", "data_full.csv"] {
        if let Ok(mut importer) = FileImporter::new_from_path(path) {
            match markov_chain.update_from_importer(&mut importer, 10000) {
                Ok(count) => println!("Read {} queries from file {}.", count, path),
                Err(e) => println!("Error while reading from file: {}", e),
            }
        }
    }

    read_from_db(&config, &mut markov_chain);

    // let term_frequency_threshold = match config.get("term_frequency_threshold") {
    //     Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
    //     _ => 2,
    // };
    // let filtered_markov_chain = filter_markov_chain(&markov_chain, term_frequency_threshold);

    // Debug helper: dump the chain for verification.
    // for (key, values) in &markov_chain {
    //     println!("{}: {:?}", key, values);
    // }

    // Smoke test: print an example prediction on startup.
    let word = "wie"; // Replace with the word you want to check.
    let top_words = markov_chain.predict(word, 3);
    println!("Example prediction for '{}': {:?}", word, top_words);

    let server = Server::http("0.0.0.0:8000").unwrap();
    for request in server.incoming_requests() {
        process_request(request, &config, &mut markov_chain, &mut last_update);
    }

    Ok(())
}
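// The config lookups in this file assume a config.toml along these lines.
// The key names are taken from the code; the values are only illustrative:
//
//   auth = "some-shared-secret"        # optional request authentication key
//   max_predict_count = 5              # upper bound on returned suggestions
//   blocklist = "data/blocklist.txt"   # newline-separated terms to suppress
//   db_host = "localhost"              # Postgres host holding the query logs
//   db_password = "secret"
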
fn process_request(
    request: tiny_http::Request,
    config: &HashMap<String, String>,
    markov_chain: &mut CompositePredictor,
    last_update: &mut SystemTime,
) {
    // println!(
    //     "received request!\n method: {:?}, url: {:?}, headers: {:?}",
    //     request.method(),
    //     request.url(),
    //     request.headers()
    // );
    match get_query(request.url()) {
        Ok(query) => {
            // If an auth key is configured, reject requests that do not carry it.
            if let Some(server_auth) = config.get("auth") {
                if let Ok(client_auth) = get_authkey(request.url()) {
                    if &client_auth != server_auth {
                        println!("invalid auth: {}, server auth: {}", client_auth, server_auth);
                        let _ = request.respond(Response::from_string(""));
                        return;
                    }
                }
            }

            let predict_count = config
                .get("max_predict_count")
                .and_then(|n| n.parse::<usize>().ok())
                .filter(|&n| n > 0)
                .unwrap_or(5);

            let predictions = markov_chain.predict(&query, predict_count);
            // OpenSearch-style suggestion response; the predictions are assumed
            // to already be JSON-safe strings.
            let response =
                Response::from_string(format!("[\"{}\",[{}]]", query, predictions.join(",")));
            let _ = request.respond(response);

            // Once a day (checked lazily on incoming requests), age out old
            // counts and pull fresh queries from the database.
            let now = SystemTime::now();
            let elapsed = now.duration_since(*last_update).expect("Time went backwards");
            if elapsed >= std::time::Duration::from_secs(24 * 60 * 60) {
                markov_chain.decay();
                read_from_db(config, markov_chain);
                *last_update = now;
            }
        }
        Err(_e) => {
            // println!("Error: {}", _e);
        }
    }
}

fn read_from_db(config: &HashMap<String, String>, predictor: &mut CompositePredictor) {
    // Load the blocklist (one term per line) if one is configured and readable.
    let raw_list = config
        .get("blocklist")
        .and_then(|path| fs::read_to_string(path).ok())
        .unwrap_or_default();
    let blocklist: Vec<&str> = raw_list.lines().collect();

    match (config.get("db_host"), config.get("db_password")) {
        (Some(db_host), Some(db_password)) => {
            let connection_string =
                format!("host={} user=metager password={} dbname=postgres", db_host, db_password);
            match Client::connect(&connection_string, NoTls) {
                Ok(mut client) => {
                    match client.query(
                        "SELECT query FROM public.logs_partitioned ORDER BY time DESC LIMIT 100000",
                        &[],
                    ) {
                        Ok(rows) => {
                            let mut count = 0;
                            for row in rows {
                                let query: &str = row.get(0);
                                predictor.update_from_query(query, &blocklist);
                                count += 1;
                            }
                            println!("{} queries read from DB", count);
                        }
                        Err(e) => println!("Error while reading from DB: {}", e),
                    }
                }
                Err(e) => println!("Error while connecting to DB: {}", e),
            }
        }
        _ => println!("No DB config found.\nSkipping..."),
    }
}
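// The SQL above assumes a Postgres table roughly of this shape. The table and
// column names come from the query in read_from_db(); the types and
// constraints are only a guess:
//
//   CREATE TABLE public.logs_partitioned (
//       query TEXT NOT NULL,
//       time  TIMESTAMPTZ NOT NULL
//   );
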
Skipping..."); } } } fn read_config(file_path: &str) -> Result<String, io::Error> { fs::read_to_string(file_path) } //TODO: unify get_X() functions fn get_query(request_url: &str) -> Result<String, url::ParseError> { let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after); //println!("parsed_url:{}", parsed_url); let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes()); for (key, value) in query_pairs { //println!("key:{}, value: {}", key, value); if key == "q" { return Ok(value.into_owned()); } } Ok(String::from_str("").unwrap()) } fn get_authkey(request_url: &str) -> Result<String, url::ParseError> { let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after); //println!("parsed_url:{}", parsed_url); let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes()); for (key, value) in query_pairs { //println!("key:{}, value: {}", key, value); if key == "auth" { return Ok(value.into_owned()); } } Ok(String::from_str("").unwrap()) } fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovChain { let mut filtered_chain: MarkovChain = HashMap::new(); for (key, followers) in markov_chain { let filtered_followers: HashMap<String, usize> = followers .iter() .filter(|&(_, &count)| count >= min_count) .map(|(word, &count)| (word.clone(), count)) .collect(); if !filtered_followers.is_empty() { filtered_chain.insert(key.clone(), filtered_followers); } } filtered_chain }