From a00cb50f92be8d76369eb65015641a6696cf2dbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Phil=20H=C3=B6fer?= <phil.hoefer@suma-ev.de> Date: Fri, 19 Jul 2024 15:51:08 +0200 Subject: [PATCH] Basic Importer Structure --- src/importers/file.rs | 20 ++++++++++++++++++++ src/importers/mod.rs | 11 +++++++++++ src/main.rs | 5 ++++- src/predictors/basic_markov.rs | 16 +++++++++++++--- src/predictors/basic_set.rs | 16 +++++++++++++--- src/predictors/composite.rs | 21 ++++++++++++++++----- src/predictors/mod.rs | 4 +++- 7 files changed, 80 insertions(+), 13 deletions(-) create mode 100644 src/importers/file.rs create mode 100644 src/importers/mod.rs diff --git a/src/importers/file.rs b/src/importers/file.rs new file mode 100644 index 0000000..8736901 --- /dev/null +++ b/src/importers/file.rs @@ -0,0 +1,20 @@ + + +use super::{Importer,SearchQuery}; + +struct FileImporter { + file: std::fs::File, +} + +impl Importer for FileImporter { + fn fetch_queries(&mut self, n: usize) -> Result<Vec<SearchQuery>, String> { + + let entries: Vec<SearchQuery> = Vec::new(); + + if entries.is_empty() { + Err(String::from("Requested number of entries exceeds available data")) + } else { + Ok(entries) + } + } +} \ No newline at end of file diff --git a/src/importers/mod.rs b/src/importers/mod.rs new file mode 100644 index 0000000..a0ec7d6 --- /dev/null +++ b/src/importers/mod.rs @@ -0,0 +1,11 @@ + +pub mod file; + +#[derive(Debug)] +pub struct SearchQuery { + pub query: String, +} + +pub trait Importer { + fn fetch_queries(&mut self, n: usize) -> Result<Vec<SearchQuery>, String>; +} diff --git a/src/main.rs b/src/main.rs index 2804b01..5c081a8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,9 @@ use csv::ReaderBuilder; use tiny_http::{Server, Response}; +mod importers; + + mod predictors; use predictors::{basic_markov, basic_set, composite, Predictor}; use predictors::basic_markov::MarkovChain; @@ -139,7 +142,7 @@ fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredict Ok(rows) => { for row in rows { let query: &str = row.get(0); - predictor.update(query); + predictor.update_from_query(query); count += 1; } println!("{} queries read from DB", count); diff --git a/src/predictors/basic_markov.rs b/src/predictors/basic_markov.rs index 1c2f908..57a673c 100644 --- a/src/predictors/basic_markov.rs +++ b/src/predictors/basic_markov.rs @@ -2,6 +2,8 @@ use std::{collections::HashMap, fs::File}; use csv::ReaderBuilder; +use crate::importers::Importer; + use super::Predictor; pub type MarkovChain = HashMap<String, HashMap<String, usize>>; @@ -33,7 +35,15 @@ impl Predictor for MarkovChainPredictor { } } - fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { + fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(),Box<dyn std::error::Error>> { + let data = importer.fetch_queries(count)?; + for q in data.iter() { + self.update_from_query(&q.query.as_str()); + } + Ok(()) + } + + fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { let blocklist: Vec<&str> = match self.configuration.get("blocked_words") { Some(list) => { list.split_whitespace().collect() @@ -112,7 +122,7 @@ pub fn from_file_path(file_path: &str) -> Result<MarkovChainPredictor, std::io:: for result in reader.records() { let record = result?; if let Some(query) = record.get(5) { - markov_chain.update(query); + markov_chain.update_from_query(query); } } @@ -139,7 +149,7 @@ pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String, for result in reader.records() { let record = result?; if let Some(query) = record.get(5) { - markov_chain.update(query); + markov_chain.update_from_query(query); } } }, diff --git a/src/predictors/basic_set.rs b/src/predictors/basic_set.rs index 3b66c2c..d01527b 100644 --- a/src/predictors/basic_set.rs +++ b/src/predictors/basic_set.rs @@ -1,5 +1,7 @@ use std::{collections::HashMap, f32::consts::E, fs::File}; +use crate::importers::Importer; + use csv::ReaderBuilder; use super::Predictor; @@ -31,7 +33,15 @@ impl Predictor for SetPredictor { } } - fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { + fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(),Box<dyn std::error::Error>> { + let data = importer.fetch_queries(count)?; + for q in data.iter() { + self.update_from_query(&q.query.as_str()); + } + Ok(()) + } + + fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { let blocklist: Vec<&str> = match self.configuration.get("blocked_words") { Some(list) => { list.split_whitespace().collect() @@ -130,7 +140,7 @@ pub fn from_file_path(file_path: &str) -> Result<SetPredictor, std::io::Error> { for result in reader.records() { let record = result?; if let Some(query) = record.get(5) { - markov_chain.update(query); + markov_chain.update_from_query(query); count += 1; } } @@ -160,7 +170,7 @@ pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String, for result in reader.records() { let record = result?; if let Some(query) = record.get(5) { - markov_chain.update(query); + markov_chain.update_from_query(query); count += 1; } } diff --git a/src/predictors/composite.rs b/src/predictors/composite.rs index 97b08a7..ecb20e4 100644 --- a/src/predictors/composite.rs +++ b/src/predictors/composite.rs @@ -1,5 +1,7 @@ use std::{collections::HashMap, f32::consts::E, fs::File}; +use crate::importers::Importer; + use csv::ReaderBuilder; use super::{basic_markov::MarkovChainPredictor, basic_set::SetPredictor, Predictor}; @@ -29,9 +31,18 @@ impl Predictor for CompositePredictor { } } - fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { - self.set_predictor.update(query); - self.markov_predictor.update(query); + + fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(),Box<dyn std::error::Error>> { + let data = importer.fetch_queries(count)?; + for q in data.iter() { + self.update_from_query(&q.query.as_str()); + } + Ok(()) + } + + fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> { + self.set_predictor.update_from_query(query); + self.markov_predictor.update_from_query(query); Ok(()) } @@ -62,7 +73,7 @@ pub fn from_file_path(file_path: &str) -> Result<CompositePredictor, std::io::Er for result in reader.records() { let record = result?; if let Some(query) = record.get(5) { - markov_chain.update(query); + markov_chain.update_from_query(query); count += 1; } } @@ -91,7 +102,7 @@ pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String, for result in reader.records() { let record = result?; if let Some(query) = record.get(5) { - markov_chain.update(query); + markov_chain.update_from_query(query); count += 1; } } diff --git a/src/predictors/mod.rs b/src/predictors/mod.rs index d111600..eba5fef 100644 --- a/src/predictors/mod.rs +++ b/src/predictors/mod.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use super::importers::Importer; pub mod basic_set; pub mod basic_markov; @@ -7,7 +8,8 @@ pub mod composite; pub trait Predictor { fn predict(&self, query: &str, n: usize) -> Vec<String>; - fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>; + fn update_from_query(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>; + fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(),Box<dyn std::error::Error>>; fn decay(&mut self) -> (); fn new() -> Self where Self: Sized; -- GitLab