Improve handling of Postgres errors

This commit is contained in:
Daniel Sockwell 2020-04-13 22:17:53 -04:00
parent 45f9d4b9fb
commit 3280c12fe1
10 changed files with 123 additions and 84 deletions

View File

@ -14,7 +14,6 @@ mod redis_cfg;
mod redis_cfg_types; mod redis_cfg_types;
pub fn merge_dotenv() -> Result<(), err::FatalErr> { pub fn merge_dotenv() -> Result<(), err::FatalErr> {
// TODO -- should this allow the user to run in a dir without a `.env` file?
dotenv::from_filename(match env::var("ENV").ok().as_deref() { dotenv::from_filename(match env::var("ENV").ok().as_deref() {
Some("production") => ".env.production", Some("production") => ".env.production",
Some("development") | None => ".env", Some("development") | None => ".env",

View File

@ -1,3 +1,4 @@
use crate::request::RequestErr;
use crate::response::ManagerErr; use crate::response::ManagerErr;
use std::fmt; use std::fmt;
@ -6,6 +7,7 @@ pub enum FatalErr {
ReceiverErr(ManagerErr), ReceiverErr(ManagerErr),
DotEnv(dotenv::Error), DotEnv(dotenv::Error),
Logger(log::SetLoggerError), Logger(log::SetLoggerError),
Postgres(RequestErr),
} }
impl FatalErr { impl FatalErr {
@ -33,11 +35,18 @@ impl fmt::Display for FatalErr {
ReceiverErr(e) => format!("{}", e), ReceiverErr(e) => format!("{}", e),
Logger(e) => format!("{}", e), Logger(e) => format!("{}", e),
DotEnv(e) => format!("Could not load specified environmental file: {}", e), DotEnv(e) => format!("Could not load specified environmental file: {}", e),
Postgres(e) => format!("Could not connect to Postgres: {}", e),
} }
) )
} }
} }
impl From<RequestErr> for FatalErr {
fn from(e: RequestErr) -> Self {
Self::Postgres(e)
}
}
impl From<dotenv::Error> for FatalErr { impl From<dotenv::Error> for FatalErr {
fn from(e: dotenv::Error) -> Self { fn from(e: dotenv::Error) -> Self {
Self::DotEnv(e) Self::DotEnv(e)

View File

@ -2,14 +2,14 @@ mod checked_event;
mod dynamic_event; mod dynamic_event;
mod err; mod err;
pub use { pub use checked_event::{CheckedEvent, Id};
checked_event::{CheckedEvent, Id}, pub use dynamic_event::{DynEvent, DynStatus, EventKind};
dynamic_event::{DynEvent, DynStatus, EventKind}, pub use err::EventErr;
err::EventErr,
};
use serde::Serialize; use serde::Serialize;
use std::{convert::TryFrom, string::String}; use std::convert::TryFrom;
use std::string::String;
use warp::sse::ServerSentEvent;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Event { pub enum Event {
@ -20,15 +20,30 @@ pub enum Event {
impl Event { impl Event {
pub fn to_json_string(&self) -> String { pub fn to_json_string(&self) -> String {
let event = &self.event_name(); if let Event::Ping = self {
let sendable_event = match self.payload() { "{}".to_string()
Some(payload) => SendableEvent::WithPayload { event, payload }, } else {
None => SendableEvent::NoPayload { event }, let event = &self.event_name();
}; let sendable_event = match self.payload() {
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize") Some(payload) => SendableEvent::WithPayload { event, payload },
None => SendableEvent::NoPayload { event },
};
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
}
} }
pub fn event_name(&self) -> String { pub fn to_warp_reply(&self) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
if let Event::Ping = self {
None
} else {
Some((
warp::sse::event(self.event_name()),
warp::sse::data(self.payload().unwrap_or_else(String::new)),
))
}
}
fn event_name(&self) -> String {
String::from(match self { String::from(match self {
Self::TypeSafe(checked) => match checked { Self::TypeSafe(checked) => match checked {
CheckedEvent::Update { .. } => "update", CheckedEvent::Update { .. } => "update",
@ -45,11 +60,11 @@ impl Event {
.. ..
}) => "update", }) => "update",
Self::Dynamic(DynEvent { event, .. }) => event, Self::Dynamic(DynEvent { event, .. }) => event,
Self::Ping => panic!("event_name() called on Ping"), Self::Ping => unreachable!(), // private method only called above
}) })
} }
pub fn payload(&self) -> Option<String> { fn payload(&self) -> Option<String> {
use CheckedEvent::*; use CheckedEvent::*;
match self { match self {
Self::TypeSafe(checked) => match checked { Self::TypeSafe(checked) => match checked {
@ -63,7 +78,7 @@ impl Event {
FiltersChanged => None, FiltersChanged => None,
}, },
Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()), Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()),
Self::Ping => panic!("payload() called on Ping"), Self::Ping => unreachable!(), // private method only called above
} }
} }
} }

View File

@ -25,7 +25,7 @@ fn main() -> Result<(), FatalErr> {
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
let request = Handler::new(postgres_cfg, *cfg.whitelist_mode); let request = Handler::new(postgres_cfg, *cfg.whitelist_mode)?;
let poll_freq = *redis_cfg.polling_interval; let poll_freq = *redis_cfg.polling_interval;
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc(); let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();

View File

@ -3,8 +3,10 @@ mod postgres;
mod query; mod query;
pub mod timeline; pub mod timeline;
mod err;
mod subscription; mod subscription;
pub use self::err::RequestErr;
pub use self::postgres::PgPool; pub use self::postgres::PgPool;
// TODO consider whether we can remove `Stream` from public API // TODO consider whether we can remove `Stream` from public API
pub use subscription::{Blocks, Subscription}; pub use subscription::{Blocks, Subscription};
@ -22,6 +24,8 @@ mod sse_test;
#[cfg(test)] #[cfg(test)]
mod ws_test; mod ws_test;
type Result<T> = std::result::Result<T, err::RequestErr>;
/// Helper macro to match on the first of any of the provided filters /// Helper macro to match on the first of any of the provided filters
macro_rules! any_of { macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => { ($filter:expr, $($other_filter:expr),*) => {
@ -56,10 +60,10 @@ pub struct Handler {
} }
impl Handler { impl Handler {
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self { pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Result<Self> {
Self { Ok(Self {
pg_conn: PgPool::new(postgres_cfg, whitelist_mode), pg_conn: PgPool::new(postgres_cfg, whitelist_mode)?,
} })
} }
pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> { pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> {
@ -113,7 +117,7 @@ impl Handler {
warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed() warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed()
} }
pub fn err(r: Rejection) -> Result<impl warp::Reply, warp::Rejection> { pub fn err(r: Rejection) -> std::result::Result<impl warp::Reply, warp::Rejection> {
let json_err = match r.cause() { let json_err = match r.cause() {
Some(text) if text.to_string() == "Missing request header 'authorization'" => { Some(text) if text.to_string() == "Missing request header 'authorization'" => {
warp::reply::json(&"Error: Missing access token".to_string()) warp::reply::json(&"Error: Missing access token".to_string())

25
src/request/err.rs Normal file
View File

@ -0,0 +1,25 @@
use std::fmt;
#[derive(Debug)]
pub enum RequestErr {
Unknown,
PgPool(r2d2::Error),
}
impl std::error::Error for RequestErr {}
impl fmt::Display for RequestErr {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
use RequestErr::*;
let msg = match self {
Unknown => "Encountered an unrecoverable error related to handling a request".into(),
PgPool(e) => format!("{}", e),
};
write!(f, "{}", msg)
}
}
impl From<r2d2::Error> for RequestErr {
fn from(e: r2d2::Error) -> Self {
Self::PgPool(e)
}
}

View File

@ -1,13 +1,13 @@
//! Postgres queries //! Postgres queries
use super::err;
use super::timeline::{Scope, UserData};
use crate::config; use crate::config;
use crate::event::Id; use crate::event::Id;
use crate::request::timeline::{Scope, UserData};
use ::postgres; use ::postgres;
use hashbrown::HashSet; use hashbrown::HashSet;
use r2d2_postgres::PostgresConnectionManager; use r2d2_postgres::PostgresConnectionManager;
use std::convert::TryFrom; use std::convert::TryFrom;
use warp::reject::Rejection;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PgPool { pub struct PgPool {
@ -15,8 +15,11 @@ pub struct PgPool {
whitelist_mode: bool, whitelist_mode: bool,
} }
type Result<T> = std::result::Result<T, err::RequestErr>;
type Rejectable<T> = std::result::Result<T, warp::Rejection>;
impl PgPool { impl PgPool {
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self { pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Result<Self> {
let mut cfg = postgres::Config::new(); let mut cfg = postgres::Config::new();
cfg.user(&pg_cfg.user) cfg.user(&pg_cfg.user)
.host(&*pg_cfg.host.to_string()) .host(&*pg_cfg.host.to_string())
@ -27,18 +30,17 @@ impl PgPool {
}; };
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls); let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
let pool = r2d2::Pool::builder() let pool = r2d2::Pool::builder().max_size(10).build(manager)?;
.max_size(10)
.build(manager) Ok(Self {
.expect("Can connect to local postgres");
Self {
conn: pool, conn: pool,
whitelist_mode, whitelist_mode,
} })
} }
pub fn select_user(self, token: &Option<String>) -> Result<UserData, Rejection> { pub fn select_user(self, token: &Option<String>) -> Rejectable<UserData> {
let mut conn = self.conn.get().unwrap(); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
if let Some(token) = token { if let Some(token) = token {
let query_rows = conn let query_rows = conn
.query(" .query("
@ -47,9 +49,9 @@ SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_lan
INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1", LIMIT 1",
&[&token.to_owned()], &[&token.to_owned()],
) ).map_err(warp::reject::custom)?;
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) { if let Some(result_columns) = query_rows.get(0) {
let id = Id(result_columns.get(1)); let id = Id(result_columns.get(1));
let allowed_langs = result_columns let allowed_langs = result_columns
@ -85,10 +87,10 @@ LIMIT 1",
} }
} }
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> { pub fn select_hashtag_id(self, tag_name: &str) -> Rejectable<i64> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name]) conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
.expect("Hard-coded query will return Some([0 or more rows])") .map_err(warp::reject::custom)?
.get(0) .get(0)
.map(|row| row.get(0)) .map(|row| row.get(0))
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist.")) .ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
@ -98,31 +100,31 @@ LIMIT 1",
/// ///
/// **NOTE**: because we check this when the user connects, it will not include any blocks /// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect. /// the user adds until they refresh/reconnect.
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> { pub fn select_blocked_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query( conn.query(
"SELECT target_account_id FROM blocks WHERE account_id = $1 "SELECT target_account_id FROM blocks WHERE account_id = $1
UNION SELECT target_account_id FROM mutes WHERE account_id = $1", UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
&[&*user_id], &[&*user_id],
) )
.expect("Hard-coded query will return Some([0 or more rows])") .map_err(warp::reject::custom)?
.iter() .iter()
.map(|row| Id(row.get(0))) .map(|row| Ok(Id(row.get(0))))
.collect() .collect()
} }
/// Query Postgres for everyone who has blocked the user /// Query Postgres for everyone who has blocked the user
/// ///
/// **NOTE**: because we check this when the user connects, it will not include any blocks /// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect. /// the user adds until they refresh/reconnect.
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> { pub fn select_blocking_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query( conn.query(
"SELECT account_id FROM blocks WHERE target_account_id = $1", "SELECT account_id FROM blocks WHERE target_account_id = $1",
&[&*user_id], &[&*user_id],
) )
.expect("Hard-coded query will return Some([0 or more rows])") .map_err(warp::reject::custom)?
.iter() .iter()
.map(|row| Id(row.get(0))) .map(|row| Ok(Id(row.get(0))))
.collect() .collect()
} }
@ -130,28 +132,28 @@ LIMIT 1",
/// ///
/// **NOTE**: because we check this when the user connects, it will not include any blocks /// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect. /// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> { pub fn select_blocked_domains(self, user_id: Id) -> Rejectable<HashSet<String>> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query( conn.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1", "SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&*user_id], &[&*user_id],
) )
.expect("Hard-coded query will return Some([0 or more rows])") .map_err(warp::reject::custom)?
.iter() .iter()
.map(|row| row.get(0)) .map(|row| Ok(row.get(0)))
.collect() .collect()
} }
/// Test whether a user owns a list /// Test whether a user owns a list
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool { pub fn user_owns_list(self, user_id: Id, list_id: i64) -> Rejectable<bool> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
// For the Postgres query, `id` = list number; `account_id` = user.id // For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn let rows = &conn
.query( .query(
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", "SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_id], &[&list_id],
) )
.expect("Hard-coded query will return Some([0 or more rows])"); .map_err(warp::reject::custom)?;
rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id) Ok(rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id))
} }
} }

View File

@ -54,7 +54,7 @@ impl Subscription {
let tag = pool.select_hashtag_id(&q.hashtag)?; let tag = pool.select_hashtag_id(&q.hashtag)?;
Timeline(Hashtag(tag), reach, stream) Timeline(Hashtag(tag), reach, stream)
} }
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => { Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id)? => {
Err(warp::reject::custom("Error: Missing access token"))? Err(warp::reject::custom("Error: Missing access token"))?
} }
other_tl => other_tl, other_tl => other_tl,
@ -70,9 +70,9 @@ impl Subscription {
timeline, timeline,
allowed_langs: user.allowed_langs, allowed_langs: user.allowed_langs,
blocks: Blocks { blocks: Blocks {
blocking_users: pool.clone().select_blocking_users(user.id), blocking_users: pool.clone().select_blocking_users(user.id)?,
blocked_users: pool.clone().select_blocked_users(user.id), blocked_users: pool.clone().select_blocked_users(user.id)?,
blocked_domains: pool.select_blocked_domains(user.id), blocked_domains: pool.select_blocked_domains(user.id)?,
}, },
hashtag_name, hashtag_name,
access_token: q.access_token, access_token: q.access_token,

View File

@ -6,18 +6,11 @@ use log;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use warp::reply::Reply; use warp::reply::Reply;
use warp::sse::{ServerSentEvent, Sse as WarpSse}; use warp::sse::Sse as WarpSse;
pub struct Sse; pub struct Sse;
impl Sse { impl Sse {
fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
Some((
warp::sse::event(event.event_name()),
warp::sse::data(event.payload().unwrap_or_else(String::new)),
))
}
pub fn send_events( pub fn send_events(
sse: WarpSse, sse: WarpSse,
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>, mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
@ -40,21 +33,20 @@ impl Sse {
TypeSafe(Update { payload, queued_at }) => match timeline { TypeSafe(Update { payload, queued_at }) => match timeline {
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None, Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
_ if payload.involves_any(&blocks) => None, _ if payload.involves_any(&blocks) => None,
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update { _ => Event::TypeSafe(CheckedEvent::Update { payload, queued_at })
payload, .to_warp_reply(),
queued_at,
})),
}, },
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)), TypeSafe(non_update) => Event::TypeSafe(non_update).to_warp_reply(),
Dynamic(dyn_event) => { Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind { if let EventKind::Update(s) = dyn_event.kind {
match timeline { match timeline {
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None, Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
_ if s.involves_any(&blocks) => None, _ if s.involves_any(&blocks) => None,
_ => Self::reply_with(Dynamic(DynEvent { _ => Dynamic(DynEvent {
kind: EventKind::Update(s), kind: EventKind::Update(s),
..dyn_event ..dyn_event
})), })
.to_warp_reply(),
} }
} else { } else {
None None

View File

@ -50,7 +50,7 @@ impl Ws {
incoming_events.for_each(move |(tl, event)| { incoming_events.for_each(move |(tl, event)| {
if matches!(event, Event::Ping) { if matches!(event, Event::Ping) {
self.send_ping() self.send_msg(event)
} else if target_timeline == tl { } else if target_timeline == tl {
use crate::event::{CheckedEvent::Update, Event::*, EventKind}; use crate::event::{CheckedEvent::Update, Event::*, EventKind};
use crate::request::Stream::Public; use crate::request::Stream::Public;
@ -83,15 +83,8 @@ impl Ws {
}) })
} }
fn send_ping(&mut self) -> Result<(), ()> {
self.send_txt("{}")
}
fn send_msg(&mut self, event: Event) -> Result<(), ()> { fn send_msg(&mut self, event: Event) -> Result<(), ()> {
self.send_txt(&event.to_json_string()) let txt = &event.to_json_string();
}
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
let tl = self.subscription.timeline; let tl = self.subscription.timeline;
match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) { match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) {
Ok(_) => Ok(()), Ok(_) => Ok(()),