From f5fb5275f1f78a517d5f9b07903b71f3ecefb382 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Phil=20H=C3=B6fer?= <phil.hoefer@suma-ev.de>
Date: Fri, 12 Jul 2024 13:23:25 +0200
Subject: [PATCH] Add Basic Decay

---
 src/main.rs                    |  4 +++-
 src/predictors/basic_markov.rs |  4 ++++
 src/predictors/basic_set.rs    | 22 ++++++++++++++++++++++
 src/predictors/mod.rs          |  1 +
 4 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/src/main.rs b/src/main.rs
index 7e8c5cd..d148b96 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -49,6 +49,7 @@ fn main() -> Result<(), io::Error> {
                 .unwrap_or(basic_set::SetPredictor::new());
 
     markov_chain = read_from_db(config.clone(), markov_chain);
+    markov_chain.decay();
 
     // let term_frequency_threshold = match config.get("term_frequency_threshold") {
     //     Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize,
@@ -106,6 +107,7 @@ fn main() -> Result<(), io::Error> {
                 let now = std::time::SystemTime::now();
                 let elapsed = now.duration_since(last_update).expect("Time went backwards");
                 if elapsed >= std::time::Duration::from_secs(24 * 60 * 60) {
+                    markov_chain.decay();
                     markov_chain = read_from_db(config.clone(), markov_chain);
                     last_update = now;
                 }
@@ -128,7 +130,7 @@ fn read_from_db(config: HashMap<String, String>, mut predictor: SetPredictor) ->
             match client {
                 Ok(mut c) => {
                     let mut count = 0;
-                    match c.query("SELECT query FROM public.logs_partitioned ORDER BY time DESC LIMIT 5000", &[]) {
+                    match c.query("SELECT query FROM public.logs_partitioned ORDER BY time DESC LIMIT 100000", &[]) {
                         Ok(rows) => {
                             for row in rows {
                                 let query: &str = row.get(0);
diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs
index eebc5aa..fc8afcb 100644
--- a/src/predictors/basic_markov.rs
+++ b/src/predictors/basic_markov.rs
@@ -60,6 +60,10 @@ impl Predictor for MarkovChainPredictor {
         Ok(())
     }
 
+    fn decay(&mut self) -> () {
+        // TODO: decay is not implemented for this predictor yet; the call is currently a no-op.
+    }
+
     fn predict(&self, query: &str, n: usize) -> String {
         if let Some(top_words) =
             get_top_following_words(
diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs
index bf1a101..331e67f 100644
--- a/src/predictors/basic_set.rs
+++ b/src/predictors/basic_set.rs
@@ -61,6 +61,28 @@ impl Predictor for SetPredictor {
         Ok(())
     }
 
+    /// Ages every entry in `self.set` by one step and drops exhausted entries.
+    ///
+    /// Each positive counter is decremented by one. An entry is removed when
+    /// its counter reaches zero through this decrement, or when it was already
+    /// at or below zero before the call. Entries still above zero are kept.
+    ///
+    /// The pass is done in place with `retain`, so no key has to be cloned and
+    /// no temporary list of keys to delete is built.
+    /// NOTE(review): assumes `self.set` is a std map exposing `retain` - confirm.
+    fn decay(&mut self) -> () {
+        self.set.retain(|_, count| {
+            // Guard the decrement so a counter that is already zero is never
+            // pushed below zero (that would underflow if the type is unsigned).
+            if *count > 0 {
+                *count -= 1;
+            }
+            // Keep only entries that still have a positive counter left;
+            // everything else has decayed away and is removed.
+            *count > 0
+        });
+    }
+
     fn predict(&self, query: &str, n: usize) -> String {
         if let Some(top_words) =
             get_top_completions(
diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs
index f4736d9..c37124d 100644
--- a/src/predictors/mod.rs
+++ b/src/predictors/mod.rs
@@ -7,6 +7,7 @@ pub trait Predictor {
 
     fn predict(&self, query: &str, n: usize) -> String;
     fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>;
+    fn decay(&mut self) -> ();
 
     fn new() -> Self where Self: Sized;
     fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized;
-- 
GitLab