Skip to content
Snippets Groups Projects
main.rs 8.63 KiB
Newer Older
  • Learn to ignore specific revisions
  • // SPDX-FileCopyrightText: 2024 Phil Höfer <phil@suma-ev.de>
    // SPDX-License-Identifier: AGPL-3.0-only
    
    
    Phil Höfer's avatar
    Phil Höfer committed
    use core::str;
    
    Phil Höfer's avatar
    Phil Höfer committed
    use std::collections::HashMap;
    use std::fs::File;
    use std::io::{self, BufRead, BufReader};
    use std::error::Error;
    use std::str::FromStr;
    
    Phil Höfer's avatar
    Phil Höfer committed
    use std::fs;
    
    Phil Höfer's avatar
    Phil Höfer committed
    use std::time::SystemTime;
    
    use importers::Importer;
    use importers::file::FileImporter;
    
    Phil Höfer's avatar
    Phil Höfer committed
    use predictors::basic_set::SetPredictor;
    
    use predictors::composite::CompositePredictor;
    
    Phil Höfer's avatar
    Phil Höfer committed
    use toml::Table;
    
    Phil Höfer's avatar
    Phil Höfer committed
    
    use csv::ReaderBuilder;
    
    
    use tiny_http::{Server, Response};
    
    
    Phil Höfer's avatar
    Phil Höfer committed
    mod importers;
    
    
    
    mod predictors;
    
    use predictors::{basic_markov, basic_set, composite, Predictor};
    
    use predictors::basic_markov::MarkovChain;
    
    Phil Höfer's avatar
    Phil Höfer committed
    
    
    Phil Höfer's avatar
    Phil Höfer committed
    use postgres::{Client, NoTls};
    
    
    Phil Höfer's avatar
    Phil Höfer committed
    
    fn main() -> Result<(), io::Error> {
    
    
        let mut last_update = std::time::SystemTime::now();
    
    Phil Höfer's avatar
    Phil Höfer committed
        let config_toml = read_config("data/config.toml").or(read_config("config.toml")).unwrap().parse::<Table>().unwrap();
    
        let mut config = HashMap::new();
        for (key, value) in config_toml {
            match value {
                toml::Value::String(s) => {
                    config.insert(key, s);
                }
                toml::Value::Integer(i) => {
                    config.insert(key, i.to_string());
                }
                _ => {
                    // Ignore other types
                }
            }
        }
    
        let mut markov_chain = CompositePredictor::new_from_config(config.clone());
    
        for path in vec!["../../data/data.csv","data/data.csv","data.csv","data_full.csv"].into_iter() {
            match &mut FileImporter::new_from_path(path) {
                Ok(imp) => {
    
    Phil Höfer's avatar
    Phil Höfer committed
                    match markov_chain.update_from_importer(imp, 10000) {
                        Ok(count) => {
                            println!("Read {} queries from file {}.",count, path);
                        }
                        Err(e) => {
                            println!("Error while reading from file: {}", e)
                        }
                    };
                },
    
                _ => {}
            }
        }
        // data locations: vec!["../../data/data.csv","data/data.csv","data.csv","data_full.csv"]
    
    Phil Höfer's avatar
    Phil Höfer committed
        read_from_db(config.clone(), &mut markov_chain);
    
        // let term_frequency_threshold = match config.get("term_frequency_threshold") {
        //     Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
        //     _ => 2
        // };
        //let filtered_markov_chain = filter_markov_chain(&markov_chain,term_frequency_threshold);
    
    Phil Höfer's avatar
    Phil Höfer committed
        // Print the Markov Chain for verification
    
        // for (key, values) in &markov_chain {
        //      //println!("{}: {:?}", key, values);
        // }
    
    Phil Höfer's avatar
    Phil Höfer committed
    
        // Test the function
        let word = "wie"; // replace with the word you want to check
    
        let top_words = &markov_chain.predict(&word, 3);
    
        println!("Example prediction for '{}': {:?}", word, top_words);
    
    Phil Höfer's avatar
    Phil Höfer committed
        let server = Server::http("0.0.0.0:8000").unwrap();
    
    Phil Höfer's avatar
    Phil Höfer committed
    
        for request in server.incoming_requests() {
    
    Phil Höfer's avatar
    Phil Höfer committed
            process_request(request, config.clone(), &mut markov_chain, last_update)
        }
    
        Ok(())
    }
    
    /// Handles a single suggestion request.
    ///
    /// Reads the search term from the `q` URL parameter (via `get_query`). When
    /// the config contains an `auth` key, the request's `auth` URL parameter must
    /// equal it; on mismatch an empty response is sent and the function returns.
    /// Otherwise `predict_count` predictions are requested from `markov_chain`
    /// (taken from `max_predict_count` when it parses to a value > 0 — the
    /// fallback for a missing key is not visible in this copy) and sent back as
    /// `["<query>",[<p1>,<p2>,...]]`.
    ///
    /// After responding, if at least 24 h have passed since `last_update`, the
    /// predictor is decayed and re-seeded via `read_from_db`.
    ///
    /// NOTE(review): `last_update` is taken by value, so `last_update = now`
    /// below only changes this function's local copy; the caller never sees it.
    fn process_request(request: tiny_http::Request, config: HashMap<String, String>, markov_chain: &mut CompositePredictor, mut last_update: SystemTime) {
        //   println!("received request! method: {:?}, url: {:?}, headers: {:?}",
    
    Phil Höfer's avatar
    Phil Höfer committed
            //       request.method(),
            //       request.url(),
            //       request.headers()
            //   );
    
    Phil Höfer's avatar
    Phil Höfer committed
            let query = get_query(request.url());
    
    Phil Höfer's avatar
    Phil Höfer committed
            //println!("got query:{}", query.clone().unwrap());
    
    Phil Höfer's avatar
    Phil Höfer committed
            match query {
                Ok(query) => {
    
    
                    // Optional shared-secret check: only enforced when `auth` is configured.
                    match config.get("auth") { 
                        Some(server_auth) => {
                            match get_authkey(request.url().clone()) {
                                Ok(client_auth) => {
                                    if client_auth != server_auth.clone() {
                                        // NOTE(review): this prints the server's secret auth key
                                        // in clear text; consider dropping it from the message.
                                        println!("invalid auth:{}, server auth: {}", client_auth, server_auth);
                                        request.respond(Response::from_string(""));
    
    Phil Höfer's avatar
    Phil Höfer committed
                                        return;
                                        // NOTE(review): the closing braces of this `if`, the
                                        // remaining `get_authkey` / `config.get("auth")` arms and
                                        // the fallback arm of the `predict_count` match below appear
                                        // to be missing from this copy — restore from version control.
    
                    let predict_count = match config.get("max_predict_count") {
    
                        Some(n) if n.parse::<usize>().unwrap_or_default() > 0 => n.parse::<usize>().unwrap_or(5),
    
                    let predictions = &markov_chain.predict(&query, predict_count);
    
    Phil Höfer's avatar
    Phil Höfer committed
                    //println!("Query: {}, Prediction:{}", query, prediction);
    
                    // NOTE(review): the query is inserted without JSON escaping, so a `"` or `\`
                    // in it yields invalid JSON. Whether `predict` returns already-quoted
                    // strings is not visible here — confirm before relying on the output shape.
                    let response = Response::from_string(
                        format!("[\"{}\",[{}]]", query, predictions.join(",")));
    
    Phil Höfer's avatar
    Phil Höfer committed
                    request.respond(response);
    
                    // Periodic refresh, piggy-backed on request handling (no background task).
                    let now = std::time::SystemTime::now();
                    // Panics if the system clock is set back before `last_update`.
                    let elapsed = now.duration_since(last_update).expect("Time went backwards");
                    if elapsed >= std::time::Duration::from_secs(24 * 60 * 60) {
    
    Phil Höfer's avatar
    Phil Höfer committed
                        markov_chain.decay();
    
    Phil Höfer's avatar
    Phil Höfer committed
                        read_from_db(config.clone(), markov_chain);
    
                        // Only updates the local copy (parameter is by value).
                        last_update = now;
                    }
    
    Phil Höfer's avatar
    Phil Höfer committed
                },
                Err(e) => {
                    // `get_query` in this file always returns `Ok`, so this arm is
                    // effectively unreachable.
    
    Phil Höfer's avatar
    Phil Höfer committed
                    //println!("Error: {}",e);
    
    Phil Höfer's avatar
    Phil Höfer committed
                    // NOTE(review): the closing braces of this arm, of `match query` and of
                    // the function appear to be missing from this copy.
    fn read_from_db(config: HashMap<String, String>, predictor: &mut CompositePredictor) {
    
    
        let mut raw_list = String::new();
        let blocklist: Vec<&str> = match config.get("blocklist") {
            Some(list) => {
                match fs::read_to_string(list) {
                    Ok(l) => {
                        raw_list = l.clone();
                        raw_list.lines().collect()
                    },
                    _ => Vec::new()
                }
            },
            _ => Vec::new()
            
        };
    
    
    
    Phil Höfer's avatar
    Phil Höfer committed
        let default_password = &String::from("");
    
        match (config.get("db_host"), config.get("db_password")) {
            (Some(db_host), Some(db_password)) => {
                let mut client = Client::connect(format!("host={} user=metager password={} dbname=postgres",db_host, db_password).as_str(), NoTls);
    
    Phil Höfer's avatar
    Phil Höfer committed
                match client {
                    Ok(mut c) => {
                        let mut count = 0;
    
    Phil Höfer's avatar
    Phil Höfer committed
                        match c.query("SELECT query FROM public.logs_partitioned ORDER BY time DESC LIMIT 100000", &[]) {
    
    Phil Höfer's avatar
    Phil Höfer committed
                            Ok(rows) => {
                                for row in rows {
                                    let query: &str = row.get(0);
    
                                    predictor.update_from_query(query, &blocklist);
    
    Phil Höfer's avatar
    Phil Höfer committed
                                    count += 1;
                                }
                                println!("{} queries read from DB", count);
    
    Phil Höfer's avatar
    Phil Höfer committed
                            },
                            Err(e) => {
                                println!("Error while reading from DB: {}", e);
    
    Phil Höfer's avatar
    Phil Höfer committed
                            }
    
    Phil Höfer's avatar
    Phil Höfer committed
                        }
                    },
                    Err(e) => {
    
    Phil Höfer's avatar
    Phil Höfer committed
                        println!("Error while connecting to DB: {}", e);
    
            _ => {
                println!("No DB config found. Skipping...");
    
    Phil Höfer's avatar
    Phil Höfer committed
        }  
    
    Phil Höfer's avatar
    Phil Höfer committed
    /// Loads the whole configuration file at `file_path` into a string.
    ///
    /// # Errors
    /// Propagates any I/O error, e.g. when the file does not exist or is not
    /// valid UTF-8.
    fn read_config(file_path: &str) -> Result<String, io::Error> {
        let contents = fs::read_to_string(file_path)?;
        Ok(contents)
    }
    
    
    //TODO: unify get_X() functions
    
    Phil Höfer's avatar
    Phil Höfer committed
    fn get_query(request_url: &str) -> Result<String, url::ParseError> {
    
    Phil Höfer's avatar
    Phil Höfer committed
        let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after);
        //println!("parsed_url:{}", parsed_url);
    
    Phil Höfer's avatar
    Phil Höfer committed
        let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes());
        for (key, value) in query_pairs {
    
    Phil Höfer's avatar
    Phil Höfer committed
            //println!("key:{}, value: {}", key, value);
    
    Phil Höfer's avatar
    Phil Höfer committed
            if key == "q" {
                return Ok(value.into_owned());
            }
        }
        Ok(String::from_str("").unwrap())
    }
    
    Phil Höfer's avatar
    Phil Höfer committed
    fn get_authkey(request_url: &str) -> Result<String, url::ParseError> {
        let parsed_url = request_url.split_once('?').map_or(request_url, |(_, after)| after);
        //println!("parsed_url:{}", parsed_url);
        let query_pairs = url::form_urlencoded::parse(parsed_url.as_bytes());
        for (key, value) in query_pairs {
            //println!("key:{}, value: {}", key, value);
            if key == "auth" {
                return Ok(value.into_owned());
            }
        }
        Ok(String::from_str("").unwrap())
    }
    
    Phil Höfer's avatar
    Phil Höfer committed
    
    
    /// Builds a pruned copy of `markov_chain` that keeps only follower entries
    /// whose count is at least `min_count`.
    ///
    /// Words left without any surviving follower are dropped from the result;
    /// the input chain is not modified.
    fn filter_markov_chain(markov_chain: &MarkovChain, min_count: usize) -> MarkovChain {
        markov_chain
            .iter()
            .filter_map(|(word, successors)| {
                let frequent: HashMap<String, usize> = successors
                    .iter()
                    .filter(|&(_, &seen)| seen >= min_count)
                    .map(|(next, &seen)| (next.clone(), seen))
                    .collect();
                if frequent.is_empty() {
                    None
                } else {
                    Some((word.clone(), frequent))
                }
            })
            .collect()
    }