diff --git a/src/main.rs b/src/main.rs
index 658e545289c1ef144b5aea981700cd980abfd6b8..6fc069cd7d525bdb199ef1b453f059629bfc5ec9 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,7 +1,10 @@
+mod markov;
+
 use std::{fs, path::Path};
 
 use anyhow::Result;
 use clap::{arg, Parser};
+use markov::Model;
 use poem::{listener::TcpListener, Route, Server};
 use poem_openapi::{payload::PlainText, OpenApi, OpenApiService};
 use rand::{thread_rng, Rng};
@@ -9,15 +12,23 @@ use rand::{thread_rng, Rng};
 #[derive(Clone)]
 struct Api {
     quotes: Vec<String>,
+    markov_model: Model
 }
 
 impl Api {
     pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
+
+        let quotes: Vec<_> = fs::read_to_string(path)?
+            .lines()
+            .map(|v| v.to_string())
+            .collect();
+
+        let mut markov_model = Model::new();
+        markov_model.add_messages(&quotes);
+
         Ok(Api {
-            quotes: fs::read_to_string(path)?
-                .lines()
-                .map(|v| v.to_string())
-                .collect(),
+            quotes,
+            markov_model
         })
     }
 }
@@ -33,6 +44,11 @@ impl Api {
     async fn index(&self) -> PlainText<String> {
         PlainText(self.quotes[thread_rng().gen_range(0..self.quotes.len())].clone())
     }
+
+    #[oai(path = "/markov", method = "get")]
+    async fn markov(&self) -> PlainText<String> {
+        PlainText(self.markov_model.generate_message())
+    }
 }
diff --git a/src/markov.rs b/src/markov.rs
new file mode 100644
index 0000000000000000000000000000000000000000..2bdb24ffb531d90cc7e4973916872ffcef7c222f
--- /dev/null
+++ b/src/markov.rs
@@ -0,0 +1,117 @@
+//! A simple first-order Markov chain generator over whitespace-separated words.
+
+use std::{collections::HashMap, hash::Hash};
+
+use rand::Rng;
+
+/// One unit of a tokenized message, with sentinel start/end markers.
+#[derive(Hash, Clone, PartialEq, Eq, Debug)]
+enum Token {
+    Start,
+    Word(String),
+    End
+}
+
+/// Counts observed successor tokens and samples one proportionally to its count.
+#[derive(Clone)]
+struct TokenSampler {
+    /// Total number of tokens recorded; kept in sync with the map below.
+    token_count: usize,
+    token_to_token_count: HashMap<Token, usize>
+}
+
+impl TokenSampler {
+    fn new() -> TokenSampler {
+        TokenSampler {
+            token_count: 0,
+            token_to_token_count: HashMap::new()
+        }
+    }
+
+    /// Records one occurrence of `token` as a successor.
+    fn add_token(&mut self, token: Token) {
+        self.token_count += 1;
+        *self.token_to_token_count.entry(token).or_insert(0) += 1;
+    }
+
+    /// Samples a token with probability proportional to its recorded count.
+    ///
+    /// Panics if nothing was ever added (`token_count == 0`); this cannot
+    /// happen for samplers created through `Model::add_messages`.
+    fn sample(&self) -> Token {
+        let mut remaining = rand::thread_rng().gen_range(0..self.token_count);
+        for (token, count) in &self.token_to_token_count {
+            if remaining < *count {
+                return token.clone();
+            }
+            remaining -= count;
+        }
+        unreachable!("The token count does not match the tokens in the sampler :(");
+    }
+}
+
+/// A first-order Markov model: for each token, a sampler over its successors.
+#[derive(Clone)]
+pub struct Model {
+    token_to_token_sampler: HashMap<Token, TokenSampler>
+}
+
+impl Model {
+    /// Creates an empty model with no recorded transitions.
+    pub fn new() -> Model {
+        Model {
+            token_to_token_sampler: HashMap::new()
+        }
+    }
+
+    /// Feeds every message into the model, recording its word transitions.
+    pub fn add_messages(&mut self, messages: &[String]) {
+        for message in messages {
+            self.add_tokenized_message(&Model::tokenize(message));
+        }
+    }
+
+    /// Splits a message on whitespace and wraps it in Start/End sentinels.
+    fn tokenize(message: &str) -> Vec<Token> {
+        let mut tokens = vec![Token::Start];
+        tokens.extend(message.split_whitespace().map(|word| Token::Word(word.to_owned())));
+        tokens.push(Token::End);
+        tokens
+    }
+
+    /// Records every adjacent token pair as a transition.
+    fn add_tokenized_message(&mut self, tokenized_message: &[Token]) {
+        for window in tokenized_message.windows(2) {
+            let token = &window[0];
+            let next_token = &window[1];
+            self.token_to_token_sampler.entry(token.clone()).or_insert_with(TokenSampler::new).add_token(next_token.clone());
+        }
+    }
+
+    /// Generates a random message by walking transitions from Start to End.
+    ///
+    /// Returns an empty string when the model has no recorded messages.
+    pub fn generate_message(&self) -> String {
+        let mut current_token = Token::Start;
+        let mut message = String::new();
+        loop {
+            // A missing entry only occurs for Start on an empty model;
+            // every other reachable non-End token has recorded successors.
+            let next_token = match self.token_to_token_sampler.get(&current_token) {
+                Some(sampler) => sampler.sample(),
+                None => return message,
+            };
+            match &next_token {
+                Token::Start => unreachable!("Start token should not be reachable inside a message :("),
+                Token::End => return message,
+                Token::Word(word) => {
+                    // Join words with single spaces, avoiding a leading space.
+                    if !message.is_empty() {
+                        message.push(' ');
+                    }
+                    message.push_str(word);
+                }
+            }
+            current_token = next_token;
+        }
+    }
+}