Commit a00cb50f authored by Phil Höfer's avatar Phil Höfer

Basic Importer Structure

parent 1e7a4827
Pipeline #10041 passed
use super::{Importer,SearchQuery};

struct FileImporter {
    file: std::fs::File,
}

impl Importer for FileImporter {
    fn fetch_queries(&mut self, n: usize) -> Result<Vec<SearchQuery>, String> {
        let entries: Vec<SearchQuery> = Vec::new();
        if entries.is_empty() {
            Err(String::from("Requested number of entries exceeds available data"))
        } else {
            Ok(entries)
        }
    }
}
\ No newline at end of file
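The committed FileImporter is still a stub: entries is never populated, so fetch_queries always takes the error branch no matter what n is. A minimal sketch of how the body could be filled in, assuming one query per line in the backing file; the rewinding and buffered reading shown here use only the standard library and are not part of this commit:

use std::io::{BufRead, BufReader, Seek, SeekFrom};

use super::{Importer, SearchQuery};

// Same shape as the committed struct.
struct FileImporter {
    file: std::fs::File,
}

impl Importer for FileImporter {
    fn fetch_queries(&mut self, n: usize) -> Result<Vec<SearchQuery>, String> {
        // Rewind so repeated calls always read from the start of the file (assumption).
        self.file.seek(SeekFrom::Start(0)).map_err(|e| e.to_string())?;

        // One query per non-empty line, at most `n` of them.
        let entries: Vec<SearchQuery> = BufReader::new(&self.file)
            .lines()
            .filter_map(|line| line.ok())
            .filter(|line| !line.trim().is_empty())
            .take(n)
            .map(|line| SearchQuery { query: line })
            .collect();

        if entries.len() < n {
            Err(String::from("Requested number of entries exceeds available data"))
        } else {
            Ok(entries)
        }
    }
}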
pub mod file;

#[derive(Debug)]
pub struct SearchQuery {
    pub query: String,
}

pub trait Importer {
    fn fetch_queries(&mut self, n: usize) -> Result<Vec<SearchQuery>, String>;
}
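Because the trait only requires fetch_queries, it can be satisfied by sources other than files. A hypothetical in-memory implementation, handy for tests and not part of this commit:

use crate::importers::{Importer, SearchQuery};

// Illustrative only: an importer backed by a plain Vec of query strings.
struct VecImporter {
    queries: Vec<String>,
}

impl Importer for VecImporter {
    fn fetch_queries(&mut self, n: usize) -> Result<Vec<SearchQuery>, String> {
        if self.queries.len() < n {
            return Err(String::from("Requested number of entries exceeds available data"));
        }
        // Hand out the first n queries and keep the rest for later calls.
        Ok(self.queries
            .drain(..n)
            .map(|q| SearchQuery { query: q })
            .collect())
    }
}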
@@ -18,6 +18,9 @@ use csv::ReaderBuilder;
use tiny_http::{Server, Response};
mod importers;
mod predictors;
use predictors::{basic_markov, basic_set, composite, Predictor};
use predictors::basic_markov::MarkovChain;
@@ -139,7 +142,7 @@ fn read_from_db(config: HashMap<String, String>, mut predictor: CompositePredict
        Ok(rows) => {
            for row in rows {
                let query: &str = row.get(0);
-               predictor.update(query);
+               predictor.update_from_query(query);
                count += 1;
            }
            println!("{} queries read from DB", count);
@@ -2,6 +2,8 @@ use std::{collections::HashMap, fs::File};
use csv::ReaderBuilder;
+use crate::importers::Importer;
use super::Predictor;
pub type MarkovChain = HashMap<String, HashMap<String, usize>>;
@@ -33,7 +35,15 @@ impl Predictor for MarkovChainPredictor {
        }
    }

-   fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
+   fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(),Box<dyn std::error::Error>> {
+       let data = importer.fetch_queries(count)?;
+       for q in data.iter() {
+           self.update_from_query(&q.query.as_str());
+       }
+       Ok(())
+   }
+
+   fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
        let blocklist: Vec<&str> = match self.configuration.get("blocked_words") {
            Some(list) => {
                list.split_whitespace().collect()
@@ -112,7 +122,7 @@ pub fn from_file_path(file_path: &str) -> Result<MarkovChainPredictor, std::io::
    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(5) {
-           markov_chain.update(query);
+           markov_chain.update_from_query(query);
        }
    }
@@ -139,7 +149,7 @@ pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String,
            for result in reader.records() {
                let record = result?;
                if let Some(query) = record.get(5) {
-                   markov_chain.update(query);
+                   markov_chain.update_from_query(query);
                }
            }
        },
use std::{collections::HashMap, f32::consts::E, fs::File};
+use crate::importers::Importer;
use csv::ReaderBuilder;
use super::Predictor;
@@ -31,7 +33,15 @@ impl Predictor for SetPredictor {
        }
    }

-   fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
+   fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(),Box<dyn std::error::Error>> {
+       let data = importer.fetch_queries(count)?;
+       for q in data.iter() {
+           self.update_from_query(&q.query.as_str());
+       }
+       Ok(())
+   }
+
+   fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
        let blocklist: Vec<&str> = match self.configuration.get("blocked_words") {
            Some(list) => {
                list.split_whitespace().collect()
@@ -130,7 +140,7 @@ pub fn from_file_path(file_path: &str) -> Result<SetPredictor, std::io::Error> {
    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(5) {
-           markov_chain.update(query);
+           markov_chain.update_from_query(query);
            count += 1;
        }
    }
@@ -160,7 +170,7 @@ pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String,
            for result in reader.records() {
                let record = result?;
                if let Some(query) = record.get(5) {
-                   markov_chain.update(query);
+                   markov_chain.update_from_query(query);
                    count += 1;
                }
            }
use std::{collections::HashMap, f32::consts::E, fs::File};
+use crate::importers::Importer;
use csv::ReaderBuilder;
use super::{basic_markov::MarkovChainPredictor, basic_set::SetPredictor, Predictor};
@@ -29,9 +31,18 @@ impl Predictor for CompositePredictor {
        }
    }

-   fn update(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
-       self.set_predictor.update(query);
-       self.markov_predictor.update(query);
+   fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(),Box<dyn std::error::Error>> {
+       let data = importer.fetch_queries(count)?;
+       for q in data.iter() {
+           self.update_from_query(&q.query.as_str());
+       }
+       Ok(())
+   }
+
+   fn update_from_query(&mut self, query: &str) -> Result<(), Box<dyn std::error::Error>> {
+       self.set_predictor.update_from_query(query);
+       self.markov_predictor.update_from_query(query);
        Ok(())
    }
@@ -62,7 +73,7 @@ pub fn from_file_path(file_path: &str) -> Result<CompositePredictor, std::io::Er
    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(5) {
-           markov_chain.update(query);
+           markov_chain.update_from_query(query);
            count += 1;
        }
    }
@@ -91,7 +102,7 @@ pub fn from_file_path_and_config(file_paths: Vec<&str>, config: HashMap<String,
    for result in reader.records() {
        let record = result?;
        if let Some(query) = record.get(5) {
-           markov_chain.update(query);
+           markov_chain.update_from_query(query);
            count += 1;
        }
    }
use std::collections::HashMap;
+use super::importers::Importer;
pub mod basic_set;
pub mod basic_markov;
@@ -7,7 +8,8 @@ pub mod composite;

pub trait Predictor {
    fn predict(&self, query: &str, n: usize) -> Vec<String>;
-   fn update(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>;
+   fn update_from_query(&mut self, query: &str) -> Result<(),Box<dyn std::error::Error>>;
+   fn update_from_importer<I: Importer>(&mut self, importer: &mut I, count: usize) -> Result<(),Box<dyn std::error::Error>>;
    fn decay(&mut self) -> ();
    fn new() -> Self where Self: Sized;
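With these trait methods in place, any predictor can be seeded from any importer in one call. A minimal wiring sketch, assuming the MarkovChainPredictor from this repository; the helper name seed_predictor and the batch size passed in are purely illustrative and not part of this commit:

use crate::importers::Importer;
use crate::predictors::{basic_markov::MarkovChainPredictor, Predictor};

// Illustrative helper: build a predictor and fill it with `count` queries
// pulled from whatever importer is passed in.
fn seed_predictor<I: Importer>(importer: &mut I, count: usize) -> Result<MarkovChainPredictor, Box<dyn std::error::Error>> {
    let mut predictor = MarkovChainPredictor::new();
    predictor.update_from_importer(importer, count)?;
    Ok(predictor)
}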