From f5fb5275f1f78a517d5f9b07903b71f3ecefb382 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Phil=20H=C3=B6fer?= <phil.hoefer@suma-ev.de> Date: Fri, 12 Jul 2024 13:23:25 +0200 Subject: [PATCH] Add Basic Decay --- src/main.rs | 4 +++- src/predictors/basic_markov.rs | 4 ++++ src/predictors/basic_set.rs | 22 ++++++++++++++++++++++ src/predictors/mod.rs | 1 + 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 7e8c5cd..d148b96 100644 --- a/src/main.rs +++ b/src/main.rs @@ -49,6 +49,7 @@ fn main() -> Result<(), io::Error> { .unwrap_or(basic_set::SetPredictor::new()); markov_chain = read_from_db(config.clone(), markov_chain); + markov_chain.decay(); // let term_frequency_threshold = match config.get("term_frequency_threshold") { // Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize, @@ -106,6 +107,7 @@ fn main() -> Result<(), io::Error> { let now = std::time::SystemTime::now(); let elapsed = now.duration_since(last_update).expect("Time went backwards"); if elapsed >= std::time::Duration::from_secs(24 * 60 * 60) { + markov_chain.decay(); markov_chain = read_from_db(config.clone(), markov_chain); last_update = now; } @@ -128,7 +130,7 @@ fn read_from_db(config: HashMap<String, String>, mut predictor: SetPredictor) -> match client { Ok(mut c) => { let mut count = 0; - match c.query("SELECT query FROM public.logs_partitioned ORDER BY time DESC LIMIT 5000", &[]) { + match c.query("SELECT query FROM public.logs_partitioned ORDER BY time DESC LIMIT 100000", &[]) { Ok(rows) => { for row in rows { let query: &str = row.get(0); diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index eebc5aa..fc8afcb 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -60,6 +60,10 @@ impl Predictor for MarkovChainPredictor { Ok(()) } + fn decay(&mut self) -> () { + //TODO + } + fn predict(&self, query: &str, n: usize) -> String { if let Some(top_words) = get_top_following_words( diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs index bf1a101..331e67f 100644 --- a/src/predictors/basic_set.rs +++ b/src/predictors/basic_set.rs @@ -61,6 +61,28 @@ impl Predictor for SetPredictor { Ok(()) } + fn decay(&mut self) -> () { + let keys_to_remove: Vec<String> = self.set.iter_mut() + .filter_map(|(key, value)| { + if *value > 0 { + *value -= 1; + if *value < 1 { + Some(key.clone()) + } else { + None + } + } else { + Some(key.clone()) + } + }) + .collect(); + + for key in keys_to_remove { + self.set.remove(&key); + } + + } + fn predict(&self, query: &str, n: usize) -> String { if let Some(top_words) = get_top_completions( diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs index f4736d9..c37124d 100644 --- a/src/predictors/mod.rs +++ b/src/predictors/mod.rs @@ -7,6 +7,7 @@ pub trait Predictor { fn predict(&self, query: &str, n: usize) -> String; fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>; + fn decay(&mut self) -> (); fn new() -> Self where Self: Sized; fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized; -- GitLab