diff --git a/src/main.rs b/src/main.rs index 7e8c5cdcf03bd6d4ef53bd5417abf0d95feedbd1..d148b965f2412daee7f108043742dddcd9d193f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -49,6 +49,7 @@ fn main() -> Result<(), io::Error> { .unwrap_or(basic_set::SetPredictor::new()); markov_chain = read_from_db(config.clone(), markov_chain); + markov_chain.decay(); // let term_frequency_threshold = match config.get("term_frequency_threshold") { // Some(toml::Value::Integer(n)) if *n >= 0 => *n as usize, @@ -106,6 +107,7 @@ fn main() -> Result<(), io::Error> { let now = std::time::SystemTime::now(); let elapsed = now.duration_since(last_update).expect("Time went backwards"); if elapsed >= std::time::Duration::from_secs(24 * 60 * 60) { + markov_chain.decay(); markov_chain = read_from_db(config.clone(), markov_chain); last_update = now; } @@ -128,7 +130,7 @@ fn read_from_db(config: HashMap<String, String>, mut predictor: SetPredictor) -> match client { Ok(mut c) => { let mut count = 0; - match c.query("SELECT query FROM public.logs_partitioned ORDER BY time DESC LIMIT 5000", &[]) { + match c.query("SELECT query FROM public.logs_partitioned ORDER BY time DESC LIMIT 100000", &[]) { Ok(rows) => { for row in rows { let query: &str = row.get(0); diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index eebc5aae2b30ed6fa4eec79551e6bbcbd09e65e5..fc8afcb6778304d8789866b7e1073075cab8dc03 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -60,6 +60,10 @@ impl Predictor for MarkovChainPredictor { Ok(()) } + fn decay(&mut self) -> () { + //TODO + } + fn predict(&self, query: &str, n: usize) -> String { if let Some(top_words) = get_top_following_words( diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs index bf1a101ce7cb0f7d349a8d67f23fcf39f3949eb5..331e67fd869f80a02a23203c9cb93248b0af5cd2 100644 --- a/src/predictors/basic_set.rs +++ b/src/predictors/basic_set.rs @@ -61,6 +61,28 @@ impl Predictor for SetPredictor { Ok(()) } + fn decay(&mut self) -> () { + let keys_to_remove: Vec<String> = self.set.iter_mut() + .filter_map(|(key, value)| { + if *value > 0 { + *value -= 1; + if *value < 1 { + Some(key.clone()) + } else { + None + } + } else { + Some(key.clone()) + } + }) + .collect(); + + for key in keys_to_remove { + self.set.remove(&key); + } + + } + fn predict(&self, query: &str, n: usize) -> String { if let Some(top_words) = get_top_completions( diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs index f4736d993b1c7c7e426f922d0249bd6b36f0a8ef..c37124d028ab4b8391a642df8e4c4d987185a19e 100644 --- a/src/predictors/mod.rs +++ b/src/predictors/mod.rs @@ -7,6 +7,7 @@ pub trait Predictor { fn predict(&self, query: &str, n: usize) -> String; fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>; + fn decay(&mut self) -> (); fn new() -> Self where Self: Sized; fn new_from_config(config: HashMap<String, impl Into<String>>) -> Self where Self:Sized;