From 3d63553a56393c8a0466073c5c0387be4d23c867 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Phil=20H=C3=B6fer?= <phil.hoefer@suma-ev.de>
Date: Mon, 22 Jul 2024 13:12:21 +0200
Subject: [PATCH] Cache Blocklist When Importing

---
 src/main.rs                    | 19 ++++++++++++++++++-
 src/predictors/basic_markov.rs | 23 +++++++++++------------
 src/predictors/basic_set.rs    | 22 ++++++++++++----------
 src/predictors/composite.rs    | 30 +++++++++++++++++++++++++-----
 src/predictors/mod.rs          |  2 +-
 5 files changed, 67 insertions(+), 29 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index b5eea6e..007f480 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -148,6 +148,23 @@ fn main() -> Result<(), io::Error> {
 
 
 fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredictor) -> CompositePredictor {
+
+    let mut raw_list = String::new();
+    let blocklist: Vec<&str> = match config.get("blocklist") {
+        Some(list) => {
+            match fs::read_to_string(list) {
+                Ok(l) => {
+                    raw_list = l.clone();
+                    raw_list.lines().collect()
+                },
+                _ => Vec::new()
+            }
+        },
+        _ => Vec::new()
+        
+    };
+
+
     let default_password = &String::from("");
     match (config.get("db_host"), config.get("db_password")) {
         (Some(db_host), Some(db_password)) => {
@@ -159,7 +176,7 @@ fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredict
                         Ok(rows) => {
                             for row in rows {
                                 let query: &str = row.get(0);
-                                predictor.update_from_query(query);
+                                predictor.update_from_query(query, &blocklist);
                                 count += 1;
                             }
                             println!("{} queries read from DB", count);
diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs
index 3f0e400..ed4bcc3 100644
--- a/src/predictors/basic_markov.rs
+++ b/src/predictors/basic_markov.rs
@@ -35,18 +35,7 @@ impl Predictor for MarkovChainPredictor {
         }
     }
 
-    fn update_from_importer<I: Importer>(&mut self, importer:  &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>> {
-        let data = importer.fetch_queries(count)?;
-        let mut count = 0;
-        for q in data.iter() {
-            self.update_from_query(&q.query.as_str());
-            count += 1;
-        }
-        Ok(count)
-    }
-
-    fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
-        let mut raw_list = String::new();
+    fn update_from_importer<I: Importer>(&mut self, importer:  &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>> {let mut raw_list = String::new();
         let blocklist: Vec<&str> = match self.configuration.get("blocklist") {
             Some(list) => {
                 match fs::read_to_string(list) {
@@ -61,6 +50,16 @@ impl Predictor for MarkovChainPredictor {
             
         };
 
+        let data = importer.fetch_queries(count)?;
+        let mut count = 0;
+        for q in data.iter() {
+            self.update_from_query(&q.query.as_str(), &blocklist);
+            count += 1;
+        }
+        Ok(count)
+    }
+
+    fn update_from_query(&mut self, query: &str, blocklist: &Vec<&str>) -> Result<(), Box<dyn std::error::Error>> {
         for word in blocklist.clone() {
             if word.trim().len() > 1 && query.to_lowercase().contains(String::from(word).to_lowercase().as_str()) {
                 return Ok(());
diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs
index a97196b..05edb13 100644
--- a/src/predictors/basic_set.rs
+++ b/src/predictors/basic_set.rs
@@ -34,17 +34,7 @@ impl Predictor for SetPredictor {
     }
 
     fn update_from_importer<I: Importer>(&mut self, importer:  &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>> {
-        let data = importer.fetch_queries(count)?;
-        let mut count = 0;
 
-        for q in data.iter() {
-            self.update_from_query(&q.query.as_str());
-            count += 1;
-        }
-        Ok(count)
-    }
-
-    fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
         let mut raw_list = String::new();
         let blocklist: Vec<&str> = match self.configuration.get("blocklist") {
             Some(list) => {
@@ -60,6 +50,18 @@ impl Predictor for SetPredictor {
             
         };
 
+        let data = importer.fetch_queries(count)?;
+        let mut count = 0;
+
+        for q in data.iter() {
+            self.update_from_query(&q.query.as_str(), &blocklist);
+            count += 1;
+        }
+        Ok(count)
+    }
+
+    fn update_from_query(&mut self, query: &str, blocklist: &Vec<&str>) -> Result<(), Box<dyn std::error::Error>> {
+
         for word in blocklist.clone() {
             if word.trim().len() > 1 && query.to_lowercase().contains(String::from(word).to_lowercase().as_str()) {
                 return Ok(());
diff --git a/src/predictors/composite.rs b/src/predictors/composite.rs
index eef3052..10494ac 100644
--- a/src/predictors/composite.rs
+++ b/src/predictors/composite.rs
@@ -1,4 +1,4 @@
-use std::{collections::HashMap, f32::consts::E, fs::File};
+use std::{collections::HashMap, f32::consts::E, fs::{self, File}};
 
 use crate::importers::Importer;
 
@@ -7,6 +7,7 @@ use csv::ReaderBuilder;
 use super::{basic_markov::MarkovChainPredictor, basic_set::SetPredictor, Predictor};
 
 pub struct CompositePredictor {
+    configuration: HashMap<String, String>, // kept so update_from_importer can look up the "blocklist" entry
     set_predictor: SetPredictor,
     markov_predictor: MarkovChainPredictor
 }
@@ -15,6 +16,7 @@ impl Predictor for CompositePredictor {
     fn new() -> Self
     {
         CompositePredictor {
+            configuration: HashMap::new(),
             set_predictor: SetPredictor::new(),
             markov_predictor: MarkovChainPredictor::new()
         }
@@ -26,6 +28,7 @@ impl Predictor for CompositePredictor {
             configuration.insert(key, value.into());
         }
         CompositePredictor {
+            configuration: configuration.clone(),
             set_predictor: SetPredictor::new_from_config(configuration.clone()),
             markov_predictor: MarkovChainPredictor::new_from_config(configuration)
         }
@@ -33,18 +36,35 @@ impl Predictor for CompositePredictor {
 
 
     fn update_from_importer<I: Importer>(&mut self, importer:  &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>> {
+        // Read the blocklist file once per import and share it across all fetched queries.
+        // A missing "blocklist" entry or an unreadable file results in an empty list.
+        let mut raw_list = String::new(); // owns the file text that `blocklist` borrows from
+        let blocklist: Vec<&str> = match self.configuration.get("blocklist") {
+            Some(list) => {
+                match fs::read_to_string(list) {
+                    Ok(l) => {
+                        raw_list = l; // move the text instead of cloning it; `l` is not used afterwards
+                        raw_list.lines().collect()
+                    },
+                    _ => Vec::new()
+                }
+            },
+            _ => Vec::new()
+
+        };
+        // Pass each fetched query to update_from_query together with the shared blocklist.
         let data = importer.fetch_queries(count)?;
         let mut count = 0;
         for q in data.iter() {
-            self.update_from_query(&q.query.as_str());
+            self.update_from_query(q.query.as_str(), &blocklist)?; // propagate instead of discarding the Result
             count += 1;
         }
         Ok(count)
     }
 
-    fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
-        self.set_predictor.update_from_query(query);
-        self.markov_predictor.update_from_query(query);
+    fn update_from_query(&mut self, query: &str, blocklist: &Vec<&str>) -> Result<(), Box<dyn std::error::Error>> { // hands `query` and the caller-supplied blocklist to both sub-predictors
+        self.set_predictor.update_from_query(query, blocklist); // NOTE(review): the returned Result is dropped — confirm failures may be ignored
+        self.markov_predictor.update_from_query(query, blocklist); // NOTE(review): Result dropped here as well; this method returns Ok(()) regardless
         Ok(())
     }
 
diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs
index 143b969..e8d1def 100644
--- a/src/predictors/mod.rs
+++ b/src/predictors/mod.rs
@@ -8,7 +8,7 @@ pub mod composite;
 pub trait Predictor {
 
     fn predict(&self, query: &str, n: usize) -> Vec<String>;
-    fn update_from_query(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>;
+    fn update_from_query(&mut self, query: &str, blocklist: &Vec<&str>) -> Result<(),Box<dyn std::error::Error>>;
     fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<usize,Box<dyn std::error::Error>>;
     fn decay(&mut self) -> ();
 
-- 
GitLab