Commit 76b910de authored by JulianBohne

Added Markov model for generating Quotes

parent a1836e60
Branches main
Pipeline #252393 passed
mod markov;
use std::{fs, path::Path};
use anyhow::Result;
use clap::{arg, Parser};
use markov::Model;
use poem::{listener::TcpListener, Route, Server};
use poem_openapi::{payload::PlainText, OpenApi, OpenApiService};
use rand::{thread_rng, Rng};
@@ -9,15 +12,23 @@ use rand::{thread_rng, Rng};
#[derive(Clone)]
struct Api {
    /// Quotes served verbatim by the `/` endpoint.
    quotes: Vec<String>,
    /// Markov chain trained on the quotes, used by the `/markov` endpoint.
    markov_model: Model,
}
impl Api {
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
        // One quote per line of the input file.
        let quotes: Vec<_> = fs::read_to_string(path)?
            .lines()
            .map(|v| v.to_string())
            .collect();

        // Train the Markov model on the same quotes.
        let mut markov_model = Model::new();
        markov_model.add_messages(&quotes);

        Ok(Api {
            quotes,
            markov_model,
        })
    }
}
@@ -33,6 +44,11 @@ impl Api
    async fn index(&self) -> PlainText<String> {
        // Serve one of the stored quotes uniformly at random.
        PlainText(self.quotes[thread_rng().gen_range(0..self.quotes.len())].clone())
    }

    /// Serves a freshly generated, Markov-chain flavoured quote.
    #[oai(path = "/markov", method = "get")]
    async fn markov(&self) -> PlainText<String> {
        PlainText(self.markov_model.generate_message())
    }
}
#[derive(Parser)]
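The clap struct that the #[derive(Parser)] attribute above applies to, along with the actual main function, is collapsed in this diff. Purely as a sketch of how an Api like this is typically mounted with poem-openapi (the quotes path, service title, version, and bind address below are assumptions, not values from this repository):

#[tokio::main]
async fn main() -> Result<()> {
    // Hypothetical wiring sketch, not part of this commit.
    let api = Api::from_file("quotes.txt")?;                  // assumed quotes file path
    let service = OpenApiService::new(api, "Quotes", "1.0");  // assumed title and version
    let app = Route::new().nest("/", service);
    Server::new(TcpListener::bind("127.0.0.1:3000"))          // assumed bind address
        .run(app)
        .await?;
    Ok(())
}

Once such a server is running, GET / returns a stored quote and GET /markov returns a generated one.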
markov.rs (new file)
use std::{collections::HashMap, hash::Hash};
use rand::Rng;
/// A single unit of a message: the virtual start/end markers or a whitespace-separated word.
#[derive(Hash, Clone, PartialEq, Eq, Debug)]
enum Token {
    Start,
    Word(String),
    End,
}
/// Counts how often each follow-up token was observed after one particular token,
/// so follow-ups can be sampled proportionally to their frequency.
#[derive(Clone)]
struct TokenSampler {
    /// Total number of observations (sum of all counts below).
    token_count: usize,
    /// Observation count per follow-up token.
    token_to_token_count: HashMap<Token, usize>,
}
impl TokenSampler {
    fn new() -> TokenSampler {
        TokenSampler {
            token_count: 0,
            token_to_token_count: HashMap::new(),
        }
    }

    /// Records one observation of `token`.
    fn add_token(&mut self, token: Token) {
        self.token_count += 1;
        *self.token_to_token_count.entry(token).or_insert(0) += 1;
    }

    /// Draws a token with probability proportional to how often it was observed.
    fn sample(&self) -> Token {
        // Pick a position among all observations, then walk the per-token counts
        // until that position falls inside one token's bucket.
        let mut random_index = rand::thread_rng().gen_range(0..self.token_count);
        for (token, count) in &self.token_to_token_count {
            if &random_index >= count {
                random_index -= count;
            } else {
                return token.clone();
            }
        }
        unreachable!("The token count does not match the tokens in the sampler :(");
    }
}
/// A first-order Markov chain over words: for every token it stores a sampler
/// of the tokens that followed it in the training messages.
#[derive(Clone)]
pub struct Model {
    token_to_token_sampler: HashMap<Token, TokenSampler>,
}
impl Model {
    pub fn new() -> Model {
        Model {
            token_to_token_sampler: HashMap::new(),
        }
    }

    /// Trains the model on a batch of messages.
    pub fn add_messages(&mut self, messages: &[String]) {
        for message in messages {
            self.add_tokenized_message(&Model::tokenize(message));
        }
    }

    /// Splits a message on whitespace and wraps it in virtual Start/End markers,
    /// e.g. "hello world" becomes [Start, Word("hello"), Word("world"), End].
    fn tokenize(message: &str) -> Vec<Token> {
        let mut tokens = vec![Token::Start];
        tokens.extend(message.split_whitespace().map(|word| Token::Word(word.to_owned())));
        tokens.push(Token::End);
        tokens
    }

    /// Records every adjacent token pair (current token -> next token) of one message.
    fn add_tokenized_message(&mut self, tokenized_message: &[Token]) {
        for window in tokenized_message.windows(2) {
            let token = &window[0];
            let next_token = &window[1];
            self.token_to_token_sampler
                .entry(token.clone())
                .or_insert_with(TokenSampler::new)
                .add_token(next_token.clone());
        }
    }

    /// Generates a new message by walking the chain from Start until End is sampled.
    pub fn generate_message(&self) -> String {
        let mut current_token = Token::Start;
        let mut current_message = String::new();
        loop {
            let next_token = self.token_to_token_sampler[&current_token].sample();
            match &next_token {
                Token::Start => unreachable!("Start token should not be reachable inside a message :("),
                Token::End => return current_message,
                Token::Word(word) => {
                    // Separate words with a single space; avoids the leading space
                    // that plain string concatenation would produce.
                    if !current_message.is_empty() {
                        current_message.push(' ');
                    }
                    current_message.push_str(word);
                }
            }
            current_token = next_token;
        }
    }
}
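Not part of the commit, but a quick sketch of the model's behaviour on a single training message (the test module below is a hypothetical illustration): with only one quote to learn from, every token has exactly one possible successor, so generation is deterministic.

#[cfg(test)]
mod tests {
    use super::*;

    // Illustrative sketch only, not part of this commit.
    #[test]
    fn single_training_message_is_reproduced() {
        let mut model = Model::new();
        model.add_messages(&["hello world".to_owned()]);
        // The only learned chain is Start -> "hello" -> "world" -> End,
        // so the generator can only reproduce the training message.
        assert_eq!(model.generate_message(), "hello world");
    }
}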