Improve handling of Postgres errors

This commit is contained in:
Daniel Sockwell 2020-04-13 22:17:53 -04:00
parent 45f9d4b9fb
commit 3280c12fe1
10 changed files with 123 additions and 84 deletions

View File

@ -14,7 +14,6 @@ mod redis_cfg;
mod redis_cfg_types;
pub fn merge_dotenv() -> Result<(), err::FatalErr> {
// TODO -- should this allow the user to run in a dir without a `.env` file?
dotenv::from_filename(match env::var("ENV").ok().as_deref() {
Some("production") => ".env.production",
Some("development") | None => ".env",

View File

@ -1,3 +1,4 @@
use crate::request::RequestErr;
use crate::response::ManagerErr;
use std::fmt;
@ -6,6 +7,7 @@ pub enum FatalErr {
ReceiverErr(ManagerErr),
DotEnv(dotenv::Error),
Logger(log::SetLoggerError),
Postgres(RequestErr),
}
impl FatalErr {
@ -33,11 +35,18 @@ impl fmt::Display for FatalErr {
ReceiverErr(e) => format!("{}", e),
Logger(e) => format!("{}", e),
DotEnv(e) => format!("Could not load specified environmental file: {}", e),
Postgres(e) => format!("Could not connect to Postgres: {}", e),
}
)
}
}
impl From<RequestErr> for FatalErr {
fn from(e: RequestErr) -> Self {
Self::Postgres(e)
}
}
impl From<dotenv::Error> for FatalErr {
fn from(e: dotenv::Error) -> Self {
Self::DotEnv(e)

View File

@ -2,14 +2,14 @@ mod checked_event;
mod dynamic_event;
mod err;
pub use {
checked_event::{CheckedEvent, Id},
dynamic_event::{DynEvent, DynStatus, EventKind},
err::EventErr,
};
pub use checked_event::{CheckedEvent, Id};
pub use dynamic_event::{DynEvent, DynStatus, EventKind};
pub use err::EventErr;
use serde::Serialize;
use std::{convert::TryFrom, string::String};
use std::convert::TryFrom;
use std::string::String;
use warp::sse::ServerSentEvent;
#[derive(Debug, Clone)]
pub enum Event {
@ -20,15 +20,30 @@ pub enum Event {
impl Event {
pub fn to_json_string(&self) -> String {
let event = &self.event_name();
let sendable_event = match self.payload() {
Some(payload) => SendableEvent::WithPayload { event, payload },
None => SendableEvent::NoPayload { event },
};
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
if let Event::Ping = self {
"{}".to_string()
} else {
let event = &self.event_name();
let sendable_event = match self.payload() {
Some(payload) => SendableEvent::WithPayload { event, payload },
None => SendableEvent::NoPayload { event },
};
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
}
}
pub fn event_name(&self) -> String {
pub fn to_warp_reply(&self) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
if let Event::Ping = self {
None
} else {
Some((
warp::sse::event(self.event_name()),
warp::sse::data(self.payload().unwrap_or_else(String::new)),
))
}
}
fn event_name(&self) -> String {
String::from(match self {
Self::TypeSafe(checked) => match checked {
CheckedEvent::Update { .. } => "update",
@ -45,11 +60,11 @@ impl Event {
..
}) => "update",
Self::Dynamic(DynEvent { event, .. }) => event,
Self::Ping => panic!("event_name() called on Ping"),
Self::Ping => unreachable!(), // private method only called above
})
}
pub fn payload(&self) -> Option<String> {
fn payload(&self) -> Option<String> {
use CheckedEvent::*;
match self {
Self::TypeSafe(checked) => match checked {
@ -63,7 +78,7 @@ impl Event {
FiltersChanged => None,
},
Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()),
Self::Ping => panic!("payload() called on Ping"),
Self::Ping => unreachable!(), // private method only called above
}
}
}

View File

@ -25,7 +25,7 @@ fn main() -> Result<(), FatalErr> {
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
let request = Handler::new(postgres_cfg, *cfg.whitelist_mode);
let request = Handler::new(postgres_cfg, *cfg.whitelist_mode)?;
let poll_freq = *redis_cfg.polling_interval;
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();

View File

@ -3,8 +3,10 @@ mod postgres;
mod query;
pub mod timeline;
mod err;
mod subscription;
pub use self::err::RequestErr;
pub use self::postgres::PgPool;
// TODO consider whether we can remove `Stream` from public API
pub use subscription::{Blocks, Subscription};
@ -22,6 +24,8 @@ mod sse_test;
#[cfg(test)]
mod ws_test;
type Result<T> = std::result::Result<T, err::RequestErr>;
/// Helper macro to match on the first of any of the provided filters
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
@ -56,10 +60,10 @@ pub struct Handler {
}
impl Handler {
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self {
Self {
pg_conn: PgPool::new(postgres_cfg, whitelist_mode),
}
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Result<Self> {
Ok(Self {
pg_conn: PgPool::new(postgres_cfg, whitelist_mode)?,
})
}
pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> {
@ -113,7 +117,7 @@ impl Handler {
warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed()
}
pub fn err(r: Rejection) -> Result<impl warp::Reply, warp::Rejection> {
pub fn err(r: Rejection) -> std::result::Result<impl warp::Reply, warp::Rejection> {
let json_err = match r.cause() {
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
warp::reply::json(&"Error: Missing access token".to_string())

25
src/request/err.rs Normal file
View File

@ -0,0 +1,25 @@
use std::fmt;
#[derive(Debug)]
pub enum RequestErr {
Unknown,
PgPool(r2d2::Error),
}
impl std::error::Error for RequestErr {}
impl fmt::Display for RequestErr {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
use RequestErr::*;
let msg = match self {
Unknown => "Encountered an unrecoverable error related to handling a request".into(),
PgPool(e) => format!("{}", e),
};
write!(f, "{}", msg)
}
}
impl From<r2d2::Error> for RequestErr {
fn from(e: r2d2::Error) -> Self {
Self::PgPool(e)
}
}

View File

@ -1,13 +1,13 @@
//! Postgres queries
use super::err;
use super::timeline::{Scope, UserData};
use crate::config;
use crate::event::Id;
use crate::request::timeline::{Scope, UserData};
use ::postgres;
use hashbrown::HashSet;
use r2d2_postgres::PostgresConnectionManager;
use std::convert::TryFrom;
use warp::reject::Rejection;
#[derive(Clone, Debug)]
pub struct PgPool {
@ -15,8 +15,11 @@ pub struct PgPool {
whitelist_mode: bool,
}
type Result<T> = std::result::Result<T, err::RequestErr>;
type Rejectable<T> = std::result::Result<T, warp::Rejection>;
impl PgPool {
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self {
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Result<Self> {
let mut cfg = postgres::Config::new();
cfg.user(&pg_cfg.user)
.host(&*pg_cfg.host.to_string())
@ -27,18 +30,17 @@ impl PgPool {
};
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
let pool = r2d2::Pool::builder()
.max_size(10)
.build(manager)
.expect("Can connect to local postgres");
Self {
let pool = r2d2::Pool::builder().max_size(10).build(manager)?;
Ok(Self {
conn: pool,
whitelist_mode,
}
})
}
pub fn select_user(self, token: &Option<String>) -> Result<UserData, Rejection> {
let mut conn = self.conn.get().unwrap();
pub fn select_user(self, token: &Option<String>) -> Rejectable<UserData> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
if let Some(token) = token {
let query_rows = conn
.query("
@ -47,9 +49,9 @@ SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_lan
INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
&[&token.to_owned()],
).map_err(warp::reject::custom)?;
if let Some(result_columns) = query_rows.get(0) {
let id = Id(result_columns.get(1));
let allowed_langs = result_columns
@ -85,10 +87,10 @@ LIMIT 1",
}
}
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> {
let mut conn = self.conn.get().expect("TODO");
pub fn select_hashtag_id(self, tag_name: &str) -> Rejectable<i64> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
.expect("Hard-coded query will return Some([0 or more rows])")
.map_err(warp::reject::custom)?
.get(0)
.map(|row| row.get(0))
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
@ -98,31 +100,31 @@ LIMIT 1",
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
let mut conn = self.conn.get().expect("TODO");
pub fn select_blocked_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query(
"SELECT target_account_id FROM blocks WHERE account_id = $1
UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.map_err(warp::reject::custom)?
.iter()
.map(|row| Id(row.get(0)))
.map(|row| Ok(Id(row.get(0))))
.collect()
}
/// Query Postgres for everyone who has blocked the user
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
let mut conn = self.conn.get().expect("TODO");
pub fn select_blocking_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query(
"SELECT account_id FROM blocks WHERE target_account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.map_err(warp::reject::custom)?
.iter()
.map(|row| Id(row.get(0)))
.map(|row| Ok(Id(row.get(0))))
.collect()
}
@ -130,28 +132,28 @@ LIMIT 1",
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
let mut conn = self.conn.get().expect("TODO");
pub fn select_blocked_domains(self, user_id: Id) -> Rejectable<HashSet<String>> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.map_err(warp::reject::custom)?
.iter()
.map(|row| row.get(0))
.map(|row| Ok(row.get(0)))
.collect()
}
/// Test whether a user owns a list
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
let mut conn = self.conn.get().expect("TODO");
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> Rejectable<bool> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id)
.map_err(warp::reject::custom)?;
Ok(rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id))
}
}

View File

@ -54,7 +54,7 @@ impl Subscription {
let tag = pool.select_hashtag_id(&q.hashtag)?;
Timeline(Hashtag(tag), reach, stream)
}
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => {
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id)? => {
Err(warp::reject::custom("Error: Missing access token"))?
}
other_tl => other_tl,
@ -70,9 +70,9 @@ impl Subscription {
timeline,
allowed_langs: user.allowed_langs,
blocks: Blocks {
blocking_users: pool.clone().select_blocking_users(user.id),
blocked_users: pool.clone().select_blocked_users(user.id),
blocked_domains: pool.select_blocked_domains(user.id),
blocking_users: pool.clone().select_blocking_users(user.id)?,
blocked_users: pool.clone().select_blocked_users(user.id)?,
blocked_domains: pool.select_blocked_domains(user.id)?,
},
hashtag_name,
access_token: q.access_token,

View File

@ -6,18 +6,11 @@ use log;
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use warp::reply::Reply;
use warp::sse::{ServerSentEvent, Sse as WarpSse};
use warp::sse::Sse as WarpSse;
pub struct Sse;
impl Sse {
fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
Some((
warp::sse::event(event.event_name()),
warp::sse::data(event.payload().unwrap_or_else(String::new)),
))
}
pub fn send_events(
sse: WarpSse,
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
@ -40,21 +33,20 @@ impl Sse {
TypeSafe(Update { payload, queued_at }) => match timeline {
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
_ if payload.involves_any(&blocks) => None,
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
payload,
queued_at,
})),
_ => Event::TypeSafe(CheckedEvent::Update { payload, queued_at })
.to_warp_reply(),
},
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
TypeSafe(non_update) => Event::TypeSafe(non_update).to_warp_reply(),
Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind {
match timeline {
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
_ if s.involves_any(&blocks) => None,
_ => Self::reply_with(Dynamic(DynEvent {
_ => Dynamic(DynEvent {
kind: EventKind::Update(s),
..dyn_event
})),
})
.to_warp_reply(),
}
} else {
None

View File

@ -50,7 +50,7 @@ impl Ws {
incoming_events.for_each(move |(tl, event)| {
if matches!(event, Event::Ping) {
self.send_ping()
self.send_msg(event)
} else if target_timeline == tl {
use crate::event::{CheckedEvent::Update, Event::*, EventKind};
use crate::request::Stream::Public;
@ -83,15 +83,8 @@ impl Ws {
})
}
fn send_ping(&mut self) -> Result<(), ()> {
self.send_txt("{}")
}
fn send_msg(&mut self, event: Event) -> Result<(), ()> {
self.send_txt(&event.to_json_string())
}
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
let txt = &event.to_json_string();
let tl = self.subscription.timeline;
match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) {
Ok(_) => Ok(()),