Initial commit
commit
8b43a38229
@ -0,0 +1,3 @@
|
||||
/target
|
||||
.envrc
|
||||
shitty_wizard.db
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "shitty_wizard"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1.20.1", features = ["macros", "rt-multi-thread"] }
|
||||
reqwest = { version = "0.11.11", features = [] }
|
||||
rusqlite = { version = "0.28.0", features = ["bundled"] }
|
||||
serde = { version = "1.0.143", features = ["derive"] }
|
||||
serde_json = "1.0.83"
|
||||
futures = "0.3.23"
|
||||
regex = "1.6.0"
|
||||
once_cell = "1.13.0"
|
||||
num_cpus = "1.13.1"
|
||||
serenity = { version = "0.11.5", default-features = false, features = ["client", "gateway", "rustls_backend", "model"] }
|
||||
rand = "0.8.5"
|
||||
easy_from = { path = "../easy_from" }
|
@ -0,0 +1,11 @@
|
||||
edition = "2021"
|
||||
unstable_features = true
|
||||
struct_field_align_threshold = 30
|
||||
hard_tabs = true
|
||||
max_width = 80
|
||||
imports_granularity = 'Crate'
|
||||
group_imports = 'StdExternalCrate'
|
||||
format_strings = true
|
||||
wrap_comments = true
|
||||
blank_lines_lower_bound = 0
|
||||
blank_lines_upper_bound = 2
|
@ -0,0 +1,472 @@
|
||||
use std::{
|
||||
cmp::min,
|
||||
env,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use easy_from::EasyFrom;
|
||||
use rusqlite::{Connection, Row};
|
||||
use serde::Serialize;
|
||||
use serenity::{
|
||||
async_trait,
|
||||
model::prelude::{
|
||||
command::{Command, CommandOptionType, CommandType},
|
||||
component::InputTextStyle,
|
||||
interaction::{
|
||||
application_command::ApplicationCommandInteraction,
|
||||
autocomplete::AutocompleteInteraction, Interaction,
|
||||
InteractionResponseType,
|
||||
},
|
||||
CommandId, Ready,
|
||||
},
|
||||
prelude::{Context, EventHandler, GatewayIntents},
|
||||
Client,
|
||||
};
|
||||
|
||||
type SharedConnection = Arc<Mutex<Connection>>;
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
enum NameOrId {
|
||||
Name(String),
|
||||
Id(usize),
|
||||
}
|
||||
#[derive(Debug)]
|
||||
struct ResponseCommandInput {
|
||||
hero: Option<NameOrId>,
|
||||
processed_text: String,
|
||||
}
|
||||
impl ResponseCommandInput {
|
||||
fn parse(input: &str) -> Result<Self> {
|
||||
let mut hero = None;
|
||||
let mut pieces = Vec::new();
|
||||
|
||||
let words = input.split_whitespace();
|
||||
|
||||
for word in words {
|
||||
if let Some(name_or_id_str) = word.strip_prefix("@hero:") {
|
||||
match name_or_id_str.parse::<usize>() {
|
||||
Ok(u) => hero = Some(NameOrId::Id(u)),
|
||||
Err(_) => {
|
||||
hero = Some(NameOrId::Name(
|
||||
name_or_id_str
|
||||
.replace(|c: char| !c.is_alphanumeric(), ""),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
let word = word.replace(|c: char| !c.is_alphanumeric(), "");
|
||||
pieces.push(word);
|
||||
}
|
||||
|
||||
let processed_text = pieces.join(" ");
|
||||
|
||||
Ok(Self {
|
||||
hero,
|
||||
processed_text,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_response_command(
|
||||
conn: &SharedConnection,
|
||||
cx: Context,
|
||||
ac: ApplicationCommandInteraction,
|
||||
) -> Result<()> {
|
||||
let input = match ac.data.options.iter().find(|o| o.name == "voice_line") {
|
||||
Some(input) => input,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let input = input.value.as_ref().unwrap();
|
||||
let input = input.as_str().unwrap();
|
||||
let input = ResponseCommandInput::parse(input)?;
|
||||
|
||||
let chosen = {
|
||||
type ResponseData = (HeroName, IconUrl, OriginalText, AudioUrl);
|
||||
|
||||
let results: Vec<ResponseData> = {
|
||||
let conn = conn.lock().unwrap();
|
||||
let map_row = |r: &Row| -> rusqlite::Result<ResponseData> {
|
||||
let hero_name = r.get(0)?;
|
||||
let icon_url = r.get(1)?;
|
||||
let original_text = r.get(2)?;
|
||||
let audio_url = r.get(3)?;
|
||||
|
||||
Ok((hero_name, icon_url, original_text, audio_url))
|
||||
};
|
||||
|
||||
let map_err =
|
||||
|r: rusqlite::Result<ResponseData>| -> Result<ResponseData> {
|
||||
r.map_err(Error::from)
|
||||
};
|
||||
|
||||
match &input.hero {
|
||||
None => {
|
||||
let mut stmt = conn.prepare(
|
||||
r"
|
||||
SELECT name, icon_url, original_text, audio_url
|
||||
FROM responses AS r
|
||||
INNER JOIN heroes AS h ON h.id = r.hero_id
|
||||
WHERE processed_text LIKE ?
|
||||
",
|
||||
)?;
|
||||
|
||||
let results = stmt
|
||||
.query_map((input.processed_text.as_str(),), map_row)?
|
||||
.map(map_err)
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
results
|
||||
}
|
||||
Some(NameOrId::Id(hero_id)) => {
|
||||
let mut stmt = conn.prepare(
|
||||
r"
|
||||
SELECT name, icon_url, original_text, audio_url
|
||||
FROM responses AS r
|
||||
INNER JOIN heroes AS h ON h.id = r.hero_id
|
||||
WHERE
|
||||
processed_text LIKE ?
|
||||
AND h.id = ?
|
||||
",
|
||||
)?;
|
||||
|
||||
let results = stmt
|
||||
.query_map(
|
||||
(input.processed_text.as_str(), hero_id),
|
||||
map_row,
|
||||
)?
|
||||
.map(map_err)
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
results
|
||||
}
|
||||
Some(NameOrId::Name(hero_name)) => {
|
||||
let hero_name = format!("%{}%", hero_name);
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
r"
|
||||
SELECT name, icon_url, original_text, audio_url
|
||||
FROM responses AS r
|
||||
INNER JOIN heroes AS h ON h.id = r.hero_id
|
||||
WHERE
|
||||
processed_text LIKE ?
|
||||
AND h.name LIKE ?
|
||||
",
|
||||
)?;
|
||||
|
||||
let results = stmt
|
||||
.query_map(
|
||||
(input.processed_text.as_str(), hero_name),
|
||||
map_row,
|
||||
)?
|
||||
.map(map_err)
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
results
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if results.is_empty() {
|
||||
ac.create_interaction_response(&cx.http, |ir| {
|
||||
ir.kind(InteractionResponseType::ChannelMessageWithSource)
|
||||
.interaction_response_data(|rd| {
|
||||
rd.content("No responses matched").ephemeral(true)
|
||||
})
|
||||
})
|
||||
.await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
|
||||
let res = results.choose(&mut rand::thread_rng()).unwrap().clone();
|
||||
|
||||
res
|
||||
};
|
||||
|
||||
ac.create_interaction_response(&cx.http, |ir| {
|
||||
ir.kind(InteractionResponseType::ChannelMessageWithSource)
|
||||
.interaction_response_data(|rd| {
|
||||
rd.embed(|e| {
|
||||
e.author(|a| {
|
||||
a.name(&chosen.0).icon_url(chosen.1).url(format!(
|
||||
"https://dota2.fandom.com/wiki/{}",
|
||||
chosen.0.replace(' ', "_"),
|
||||
))
|
||||
})
|
||||
.description(format!("[{}]({})", chosen.2, chosen.3))
|
||||
})
|
||||
})
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_reply_with_response_command(
|
||||
conn: &SharedConnection,
|
||||
cx: Context,
|
||||
ac: ApplicationCommandInteraction,
|
||||
) -> Result<()> {
|
||||
dbg!(&ac);
|
||||
ac.create_interaction_response(&cx.http, |ir| {
|
||||
ir.kind(InteractionResponseType::Modal)
|
||||
.interaction_response_data(|ird| {
|
||||
ird.content("test")
|
||||
.title("Test")
|
||||
.components(|c| {
|
||||
c.create_action_row(|r| {
|
||||
r.create_input_text(|it| {
|
||||
it.placeholder("test")
|
||||
.label("Test")
|
||||
.style(InputTextStyle::Short)
|
||||
.custom_id("test_text")
|
||||
})
|
||||
})
|
||||
})
|
||||
.custom_id("test")
|
||||
})
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_application_command(
|
||||
conn: &SharedConnection,
|
||||
cx: Context,
|
||||
ac: ApplicationCommandInteraction,
|
||||
) -> Result<()> {
|
||||
match ac.data.name.as_str() {
|
||||
"response" => handle_response_command(conn, cx, ac).await,
|
||||
"Reply with response" => {
|
||||
handle_reply_with_response_command(conn, cx, ac).await
|
||||
}
|
||||
_ => {
|
||||
eprintln!("Unknown command: {}", ac.data.name);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
type OriginalText = String;
|
||||
type ProcessedText = String;
|
||||
type HeroName = String;
|
||||
type HeroId = usize;
|
||||
type AudioUrl = String;
|
||||
type IconUrl = String;
|
||||
|
||||
async fn handle_response_autocomplete(
|
||||
conn: &SharedConnection,
|
||||
cx: Context,
|
||||
ac: AutocompleteInteraction,
|
||||
) -> Result<()> {
|
||||
let input = match ac.data.options.iter().find(|o| o.name == "voice_line") {
|
||||
Some(input) => input,
|
||||
None => return Ok(()),
|
||||
};
|
||||
|
||||
let input = input.value.as_ref().unwrap();
|
||||
let input = input.as_str().unwrap();
|
||||
let input = ResponseCommandInput::parse(input)?;
|
||||
|
||||
let fuzzy_input = format!("%{}%", input.processed_text.replace(' ', "%"));
|
||||
|
||||
let suggestions: Vec<Suggestion> = {
|
||||
let map_row = |r: &Row| -> rusqlite::Result<Suggestion> {
|
||||
let original_text: OriginalText = r.get(0)?;
|
||||
let processed_text: ProcessedText = r.get(1)?;
|
||||
let hero_name: HeroName = r.get(2)?;
|
||||
let hero_id: HeroId = r.get(3)?;
|
||||
|
||||
let name = format!("{}: {}", hero_name, original_text);
|
||||
let value = format!("@hero:{} {}", hero_id, processed_text);
|
||||
|
||||
let name = name[0..min(100, name.len())].to_string();
|
||||
let value = value[0..min(100, value.len())].to_string();
|
||||
|
||||
Ok(Suggestion { name, value })
|
||||
};
|
||||
let map_err = |r: rusqlite::Result<Suggestion>| -> Result<Suggestion> {
|
||||
r.map_err(Error::from)
|
||||
};
|
||||
|
||||
let conn = conn.lock().unwrap();
|
||||
match &input.hero {
|
||||
None => {
|
||||
let mut stmt = conn.prepare(
|
||||
r"
|
||||
SELECT original_text, processed_text, name, h.id AS hero_id
|
||||
FROM responses
|
||||
INNER JOIN heroes AS h ON h.id = responses.hero_id
|
||||
WHERE responses.processed_text LIKE ?
|
||||
ORDER BY RANDOM()
|
||||
LIMIT 25
|
||||
",
|
||||
)?;
|
||||
|
||||
let results = stmt
|
||||
.query_map((fuzzy_input.as_str(),), map_row)?
|
||||
.map(map_err)
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
results
|
||||
}
|
||||
Some(NameOrId::Id(hero_id)) => {
|
||||
let mut stmt = conn.prepare(
|
||||
r"
|
||||
SELECT original_text, processed_text, name, h.id AS hero_id
|
||||
FROM responses
|
||||
INNER JOIN heroes AS h ON h.id = responses.hero_id
|
||||
WHERE
|
||||
responses.processed_text LIKE ?
|
||||
AND h.id = ?
|
||||
ORDER BY RANDOM()
|
||||
LIMIT 25
|
||||
",
|
||||
)?;
|
||||
|
||||
let results = stmt
|
||||
.query_map((fuzzy_input.as_str(), hero_id), map_row)?
|
||||
.map(map_err)
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
results
|
||||
}
|
||||
Some(NameOrId::Name(hero_name)) => {
|
||||
let hero_name = format!("%{}%", hero_name);
|
||||
let mut stmt = conn.prepare(
|
||||
r"
|
||||
SELECT original_text, processed_text, name, h.id AS hero_id
|
||||
FROM responses
|
||||
INNER JOIN heroes AS h ON h.id = responses.hero_id
|
||||
WHERE
|
||||
responses.processed_text LIKE ?
|
||||
AND h.name LIKE ?
|
||||
ORDER BY RANDOM()
|
||||
LIMIT 25
|
||||
",
|
||||
)?;
|
||||
|
||||
let results = stmt
|
||||
.query_map((fuzzy_input.as_str(), hero_name), map_row)?
|
||||
.map(map_err)
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
results
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let suggestions = serde_json::to_value(suggestions)?;
|
||||
|
||||
ac.create_autocomplete_response(cx.http, |r| r.set_choices(suggestions))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_autocomplete(
|
||||
conn: &SharedConnection,
|
||||
cx: Context,
|
||||
ac: AutocompleteInteraction,
|
||||
) -> Result<()> {
|
||||
match ac.data.name.as_str() {
|
||||
"response" => handle_response_autocomplete(conn, cx, ac).await,
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
struct Suggestion {
|
||||
name: String,
|
||||
value: String,
|
||||
}
|
||||
|
||||
struct Handler {
|
||||
conn: SharedConnection,
|
||||
}
|
||||
#[async_trait]
|
||||
impl EventHandler for Handler {
|
||||
async fn interaction_create(&self, cx: Context, interaction: Interaction) {
|
||||
let res: Result<()> = async {
|
||||
match interaction {
|
||||
Interaction::ApplicationCommand(ac) => {
|
||||
handle_application_command(&self.conn, cx, ac).await
|
||||
}
|
||||
Interaction::Autocomplete(ac) => {
|
||||
handle_autocomplete(&self.conn, cx, ac).await
|
||||
}
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(e) = res {
|
||||
eprintln!("{:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
async fn ready(&self, cx: Context, ready: Ready) {
|
||||
println!("{} connected", ready.user.name);
|
||||
|
||||
match Command::create_global_application_command(&cx.http, |c| {
|
||||
c.name("response")
|
||||
.description("Send a response voice line")
|
||||
.kind(CommandType::ChatInput)
|
||||
.create_option(|c| {
|
||||
c.kind(CommandOptionType::String)
|
||||
.name("voice_line")
|
||||
.description("the voice line to send")
|
||||
.set_autocomplete(true)
|
||||
.required(true)
|
||||
})
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(_) => println!("Command registered"),
|
||||
Err(e) => eprintln!("Failed to register command: {:?}", e),
|
||||
};
|
||||
|
||||
match Command::create_global_application_command(&cx.http, |c| {
|
||||
c.name("Reply with response").kind(CommandType::Message)
|
||||
})
|
||||
.await
|
||||
{
|
||||
Ok(_) => println!("Command registered"),
|
||||
Err(e) => eprintln!("Failed to register command: {:?}", e),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(EasyFrom, Debug)]
|
||||
pub enum Error {
|
||||
Serenity(#[from] serenity::Error),
|
||||
Rusqlite(#[from] rusqlite::Error),
|
||||
SerdeJson(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = core::result::Result<T, Error>;
|
||||
|
||||
pub async fn run_bot(conn: Connection) -> Result<()> {
|
||||
let token = env::var("DISCORD_TOKEN").expect("missing DISCORD_TOKEN");
|
||||
let intents = GatewayIntents::empty();
|
||||
let mut client = Client::builder(&token, intents)
|
||||
.event_handler(Handler {
|
||||
conn: Arc::new(Mutex::new(conn)),
|
||||
})
|
||||
.await?;
|
||||
|
||||
|
||||
client.start().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
pub mod heroes {
|
||||
#[derive(Debug)]
|
||||
pub struct Id(pub usize);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NameOwned(pub String);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct IconUrlOwned(pub String);
|
||||
}
|
||||
|
||||
pub mod responses {
|
||||
#[derive(Debug)]
|
||||
pub struct Id(pub usize);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct HeroId(pub usize);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ProcessedTextOwned(pub String);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OriginalTextOwned(pub String);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AudioUrlOwned(pub String);
|
||||
}
|
@ -0,0 +1,501 @@
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
};
|
||||
|
||||
use easy_from::EasyFrom;
|
||||
use futures::{
|
||||
future::{join_all, try_join_all},
|
||||
try_join,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
use rusqlite::Connection;
|
||||
use serde::Deserialize;
|
||||
use tokio::sync::Barrier;
|
||||
|
||||
use crate::db_types::{heroes, responses};
|
||||
|
||||
fn check_status(res: reqwest::Response) -> Result<reqwest::Response> {
|
||||
if !res.status().is_success() {
|
||||
return Err(Error::HttpUnsuccessful(res.status()));
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn parse_response_json<R>(res: reqwest::Response) -> Result<R>
|
||||
where
|
||||
R: for<'a> Deserialize<'a>,
|
||||
{
|
||||
let res = check_status(res)?;
|
||||
|
||||
Ok(res.json::<R>().await?)
|
||||
}
|
||||
|
||||
#[derive(Debug, EasyFrom)]
|
||||
pub enum Error {
|
||||
Rusqlite(#[from] rusqlite::Error),
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
HttpUnsuccessful(reqwest::StatusCode),
|
||||
NoIconRegexMatch,
|
||||
}
|
||||
pub type Result<T> = core::result::Result<T, Error>;
|
||||
|
||||
pub fn initialize_db(conn: &mut Connection) -> Result<()> {
|
||||
conn.execute_batch(
|
||||
r"
|
||||
DROP TABLE IF EXISTS responses;
|
||||
DROP TABLE IF EXISTS heroes;
|
||||
|
||||
CREATE TABLE heroes (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
icon_url TEXT NOT NULL
|
||||
) STRICT;
|
||||
|
||||
CREATE TABLE responses (
|
||||
id INTEGER PRIMARY KEY,
|
||||
hero_id INTEGER NOT NULL,
|
||||
processed_text TEXT NOT NULL,
|
||||
original_text TEXT NOT NULL,
|
||||
audio_url TEXT NOT NULL,
|
||||
FOREIGN KEY(hero_id) REFERENCES heroes(id)
|
||||
) STRICT;
|
||||
",
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct RequestLimiter {
|
||||
queue: Arc<Mutex<VecDeque<Arc<Barrier>>>>,
|
||||
in_flight: AtomicU64,
|
||||
max_in_flight: usize,
|
||||
}
|
||||
impl RequestLimiter {
|
||||
fn new(max_in_flight: usize) -> Self {
|
||||
RequestLimiter {
|
||||
queue: Arc::new(Mutex::new(VecDeque::new())),
|
||||
in_flight: AtomicU64::new(0),
|
||||
max_in_flight,
|
||||
}
|
||||
}
|
||||
|
||||
async fn unblock_one(&self) {
|
||||
let next = match self.queue.lock().unwrap().pop_front() {
|
||||
Some(b) => b,
|
||||
None => return,
|
||||
};
|
||||
|
||||
next.wait().await;
|
||||
}
|
||||
|
||||
async fn get(&self, url: String) -> reqwest::Result<reqwest::Response> {
|
||||
loop {
|
||||
let in_flight = self.in_flight.load(Ordering::Relaxed);
|
||||
let next_in_flight = in_flight + 1;
|
||||
if next_in_flight < self.max_in_flight as u64 {
|
||||
if self
|
||||
.in_flight
|
||||
.compare_exchange(
|
||||
in_flight,
|
||||
next_in_flight,
|
||||
Ordering::Relaxed,
|
||||
Ordering::Relaxed,
|
||||
)
|
||||
.is_err()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let res = reqwest::get(url).await;
|
||||
self.in_flight.fetch_sub(1, Ordering::Relaxed);
|
||||
self.unblock_one().await;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
let barrier = Arc::new(Barrier::new(2));
|
||||
self.queue.lock().unwrap().push_back(barrier.clone());
|
||||
|
||||
barrier.wait().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn fetch_and_insert_data(conn: &mut Connection) -> Result<()> {
|
||||
let limiter = RequestLimiter::new(128);
|
||||
|
||||
let hero_pages = fetch_response_page_urls(&limiter).await?;
|
||||
let futures = hero_pages
|
||||
.into_iter()
|
||||
.map(|p| get_heroes(p, &limiter))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let heroes =
|
||||
join_all(futures)
|
||||
.await
|
||||
.into_iter()
|
||||
.flat_map(|(hn, r)| match r {
|
||||
Err(_) => {
|
||||
eprintln!("Failed to fetch {}: {:?}", hn, r);
|
||||
|
||||
None
|
||||
}
|
||||
Ok(r) => Some(r),
|
||||
});
|
||||
|
||||
println!("All fetched, inserting into DB");
|
||||
for h in heroes {
|
||||
let tx = conn.transaction()?;
|
||||
{
|
||||
let mut insert_response = tx.prepare(
|
||||
"INSERT INTO responses (hero_id, processed_text, \
|
||||
original_text, audio_url) VALUES (?, ?, ?, ?)",
|
||||
)?;
|
||||
|
||||
let mut insert_hero_icon = tx.prepare(
|
||||
"INSERT INTO heroes (name, icon_url) VALUES (?, ?) RETURNING \
|
||||
id",
|
||||
)?;
|
||||
|
||||
|
||||
let hero_id: heroes::Id = insert_hero_icon
|
||||
.query_row((h.name.0, h.icon_url.0), |r| {
|
||||
Ok(heroes::Id(r.get(0)?))
|
||||
})?;
|
||||
|
||||
|
||||
for response in h.responses {
|
||||
insert_response.execute((
|
||||
hero_id.0,
|
||||
response.processed_text.0,
|
||||
response.original_text.0,
|
||||
response.audio_url.0,
|
||||
))?;
|
||||
}
|
||||
}
|
||||
|
||||
tx.commit()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct CategoryMember {
|
||||
title: String,
|
||||
}
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct CategoryMembers {
|
||||
categorymembers: Vec<CategoryMember>,
|
||||
}
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct HeroesResponse {
|
||||
query: CategoryMembers,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct HeroUrls {
|
||||
hero_name: String,
|
||||
responses_url: String,
|
||||
hero_page_url: String,
|
||||
}
|
||||
const API_BASE: &str = "https://dota2.fandom.com";
|
||||
async fn fetch_response_page_urls(
|
||||
limiter: &RequestLimiter,
|
||||
) -> Result<Vec<HeroUrls>> {
|
||||
let hero_endpoint = format!(
|
||||
"{}/api.php?action=query&list=categorymembers&cmtitle=Category:Heroes&\
|
||||
cmlimit=max&cmtype=page&format=json",
|
||||
API_BASE
|
||||
);
|
||||
let resp = limiter.get(hero_endpoint).await?;
|
||||
let resp: HeroesResponse = parse_response_json(resp).await?;
|
||||
|
||||
Ok(resp
|
||||
.query
|
||||
.categorymembers
|
||||
.into_iter()
|
||||
.map(|cm| HeroUrls {
|
||||
hero_name: cm.title.clone(),
|
||||
responses_url: format!(
|
||||
"{}/wiki/{}/Responses?action=raw",
|
||||
API_BASE, cm.title
|
||||
),
|
||||
hero_page_url: format!("{}/wiki/{}?action=raw", API_BASE, cm.title),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Response {
|
||||
processed_text: responses::ProcessedTextOwned,
|
||||
original_text: responses::OriginalTextOwned,
|
||||
audio_url: responses::AudioUrlOwned,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Hero {
|
||||
name: heroes::NameOwned,
|
||||
icon_url: heroes::IconUrlOwned,
|
||||
responses: Vec<Response>,
|
||||
}
|
||||
|
||||
async fn get_heroes(
|
||||
urls: HeroUrls,
|
||||
limiter: &RequestLimiter,
|
||||
) -> (String, Result<Hero>) {
|
||||
let res = (|| async {
|
||||
static RESPONSE: Lazy<Regex> = Lazy::new(|| {
|
||||
Regex::new(
|
||||
r"\*.*<sm2>(?P<file>.*)</sm2>(.*\{\{.*\}\})* (?P<text>[^\{\n]+)",
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
println!("Fetching {}", urls.hero_name);
|
||||
|
||||
let get_responses = (|| async {
|
||||
let resp = limiter.get(urls.responses_url).await?;
|
||||
if !resp.status().is_success() {
|
||||
panic!(
|
||||
"Failed to fetch hero responses: {} {}",
|
||||
urls.hero_name,
|
||||
resp.status()
|
||||
);
|
||||
}
|
||||
let body = resp.text().await?;
|
||||
|
||||
let responses: Vec<_> = RESPONSE
|
||||
.captures_iter(body.as_str())
|
||||
.map(|c| {
|
||||
let file = c.name("file").unwrap();
|
||||
let response = c.name("text").unwrap();
|
||||
|
||||
Ok((response.as_str(), OriginalFilename(file.as_str())))
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
|
||||
let filenames: Vec<_> =
|
||||
responses.iter().map(|(_, file)| *file).collect();
|
||||
let mut filenames_to_urls =
|
||||
fetch_all_file_urls(&filenames, limiter).await?;
|
||||
let responses = responses
|
||||
.into_iter()
|
||||
.flat_map(|(response, filename)| {
|
||||
let url = match filenames_to_urls.remove(filename.0) {
|
||||
None => return None,
|
||||
Some(url) => url,
|
||||
};
|
||||
|
||||
Some(Response {
|
||||
processed_text: process_response(response),
|
||||
original_text: responses::OriginalTextOwned(
|
||||
response.to_string(),
|
||||
),
|
||||
audio_url: responses::AudioUrlOwned(url.0),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
println!("Fetched responses for {}", urls.hero_name);
|
||||
|
||||
Result::Ok(responses)
|
||||
})();
|
||||
|
||||
let get_icon_url = (|| async {
|
||||
let res = limiter.get(urls.hero_page_url).await?;
|
||||
let res = check_status(res)?;
|
||||
let body = res.text().await?;
|
||||
|
||||
static ICON: Lazy<Regex> = Lazy::new(|| {
|
||||
Regex::new(r"(?m)^\s*\|\s*icon\s*=\s*(?P<icon>.*)$").unwrap()
|
||||
});
|
||||
|
||||
let icon_filename = ICON
|
||||
.captures_iter(body.as_str())
|
||||
.next()
|
||||
.ok_or(Error::NoIconRegexMatch)?
|
||||
.name("icon")
|
||||
.unwrap()
|
||||
.as_str();
|
||||
let res = fetch_hero_icon_url(icon_filename, limiter).await?;
|
||||
|
||||
println!("Got icon for {}", urls.hero_name);
|
||||
|
||||
Result::Ok(res)
|
||||
})();
|
||||
|
||||
let (responses, icon_url) = try_join!(get_responses, get_icon_url)?;
|
||||
|
||||
Ok(Hero {
|
||||
name: heroes::NameOwned(urls.hero_name.clone()),
|
||||
icon_url: heroes::IconUrlOwned(icon_url),
|
||||
responses,
|
||||
})
|
||||
})()
|
||||
.await;
|
||||
|
||||
(urls.hero_name, res)
|
||||
}
|
||||
|
||||
fn process_response(s: &str) -> responses::ProcessedTextOwned {
|
||||
responses::ProcessedTextOwned(
|
||||
s.to_lowercase()
|
||||
.replace(|c: char| !c.is_whitespace() && !c.is_alphanumeric(), ""),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ImageInfo {
|
||||
url: String,
|
||||
}
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ImageInfoPage {
|
||||
imageinfo: Option<Vec<ImageInfo>>,
|
||||
title: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NormalizedFilename {
|
||||
from: String,
|
||||
to: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ImageInfoQueryWithNormalized {
|
||||
normalized: Vec<NormalizedFilename>,
|
||||
pages: HashMap<String, ImageInfoPage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ImageInfoQuery {
|
||||
pages: HashMap<String, ImageInfoPage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ImageInfoResponseBase<Body> {
|
||||
query: Body,
|
||||
}
|
||||
|
||||
type ImageInfoWithNormalizedResponse =
|
||||
ImageInfoResponseBase<ImageInfoQueryWithNormalized>;
|
||||
type ImageInfoResponse = ImageInfoResponseBase<ImageInfoQuery>;
|
||||
|
||||
#[derive(Clone, Hash, Eq, PartialEq)]
|
||||
struct OriginalFilenameOwned(String);
|
||||
|
||||
#[derive(Clone, Hash, Eq, PartialEq, Copy)]
|
||||
struct OriginalFilename<'a>(&'a str);
|
||||
impl core::borrow::Borrow<str> for OriginalFilenameOwned {
|
||||
fn borrow(&self) -> &str {
|
||||
self.0.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct FileUrlOwned(String);
|
||||
|
||||
async fn fetch_all_file_urls(
|
||||
filenames: &[OriginalFilename<'_>],
|
||||
limiter: &RequestLimiter,
|
||||
) -> Result<HashMap<OriginalFilenameOwned, FileUrlOwned>> {
|
||||
let futures = filenames
|
||||
.chunks(50)
|
||||
.map(|filenames| async {
|
||||
let fnames = filenames
|
||||
.iter()
|
||||
.map(|fname| format!("File:{}", fname.0))
|
||||
.collect::<Vec<_>>()
|
||||
.join("|");
|
||||
let url = format!(
|
||||
"{}/api.php?action=query&prop=imageinfo&iiprop=url&\
|
||||
format=json&titles={}",
|
||||
API_BASE, fnames,
|
||||
);
|
||||
|
||||
let resp = limiter.get(url).await?;
|
||||
let res: ImageInfoWithNormalizedResponse =
|
||||
parse_response_json(resp).await?;
|
||||
|
||||
let normalized_to_original: HashMap<&str, OriginalFilenameOwned> =
|
||||
res.query
|
||||
.normalized
|
||||
.iter()
|
||||
.map(|n| {
|
||||
(
|
||||
n.to.as_str(),
|
||||
OriginalFilenameOwned(
|
||||
n.from.as_str().replace("File:", ""),
|
||||
),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let original_to_url: HashMap<OriginalFilenameOwned, FileUrlOwned> =
|
||||
res.query
|
||||
.pages
|
||||
.into_values()
|
||||
.flat_map(|p| {
|
||||
let title = p.title;
|
||||
let url = match p.imageinfo {
|
||||
Some(ii) => ii.into_iter().next().unwrap().url,
|
||||
None => return None,
|
||||
};
|
||||
|
||||
Some((
|
||||
normalized_to_original
|
||||
.get(title.as_str())
|
||||
.unwrap()
|
||||
.clone(),
|
||||
FileUrlOwned(url),
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Result::Ok(original_to_url)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let maps = try_join_all(futures).await?;
|
||||
let map = maps.into_iter().fold(HashMap::new(), |mut acc, map| {
|
||||
acc.extend(map.into_iter());
|
||||
|
||||
acc
|
||||
});
|
||||
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
async fn fetch_hero_icon_url(
|
||||
filename: &str,
|
||||
limiter: &RequestLimiter,
|
||||
) -> Result<String> {
|
||||
let url = format!(
|
||||
"{}/api.php?action=query&prop=imageinfo&iiprop=url&format=json&\
|
||||
titles={}",
|
||||
API_BASE, filename
|
||||
);
|
||||
let res = limiter.get(url).await?;
|
||||
let res: ImageInfoResponse = parse_response_json(res).await?;
|
||||
|
||||
let url = res
|
||||
.query
|
||||
.pages
|
||||
.into_values()
|
||||
.next()
|
||||
.unwrap()
|
||||
.imageinfo
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
.url;
|
||||
|
||||
Ok(url)
|
||||
}
|
@ -0,0 +1,41 @@
|
||||
#![allow(clippy::never_loop)]
|
||||
#![allow(clippy::let_and_return)]
|
||||
|
||||
mod bot;
|
||||
mod db_types;
|
||||
mod init_db;
|
||||
|
||||
use std::env;
|
||||
|
||||
use bot::run_bot;
|
||||
use init_db::{fetch_and_insert_data, initialize_db};
|
||||
use rusqlite::Connection;
|
||||
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let mut conn = Connection::open("./shitty_wizard.db").unwrap();
|
||||
|
||||
let should_refresh_db = loop {
|
||||
if env::args().any(|a| a == "--refresh-db") {
|
||||
break true;
|
||||
}
|
||||
|
||||
let count: usize =
|
||||
match conn
|
||||
.query_row("SELECT COUNT(*) FROM responses", (), |r| r.get(0))
|
||||
{
|
||||
Ok(count) => count,
|
||||
Err(_) => break true,
|
||||
};
|
||||
|
||||
break count == 0;
|
||||
};
|
||||
|
||||
if should_refresh_db {
|
||||
initialize_db(&mut conn).unwrap();
|
||||
fetch_and_insert_data(&mut conn).await.unwrap();
|
||||
};
|
||||
|
||||
run_bot(conn).await.unwrap();
|
||||
}
|
Loading…
Reference in New Issue