mirror of https://github.com/mastodon/flodgatt
Improve handling of Postgres errors
This commit is contained in:
parent
45f9d4b9fb
commit
3280c12fe1
|
@ -14,7 +14,6 @@ mod redis_cfg;
|
|||
mod redis_cfg_types;
|
||||
|
||||
pub fn merge_dotenv() -> Result<(), err::FatalErr> {
|
||||
// TODO -- should this allow the user to run in a dir without a `.env` file?
|
||||
dotenv::from_filename(match env::var("ENV").ok().as_deref() {
|
||||
Some("production") => ".env.production",
|
||||
Some("development") | None => ".env",
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use crate::request::RequestErr;
|
||||
use crate::response::ManagerErr;
|
||||
use std::fmt;
|
||||
|
||||
|
@ -6,6 +7,7 @@ pub enum FatalErr {
|
|||
ReceiverErr(ManagerErr),
|
||||
DotEnv(dotenv::Error),
|
||||
Logger(log::SetLoggerError),
|
||||
Postgres(RequestErr),
|
||||
}
|
||||
|
||||
impl FatalErr {
|
||||
|
@ -33,11 +35,18 @@ impl fmt::Display for FatalErr {
|
|||
ReceiverErr(e) => format!("{}", e),
|
||||
Logger(e) => format!("{}", e),
|
||||
DotEnv(e) => format!("Could not load specified environmental file: {}", e),
|
||||
Postgres(e) => format!("Could not connect to Postgres: {}", e),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RequestErr> for FatalErr {
|
||||
fn from(e: RequestErr) -> Self {
|
||||
Self::Postgres(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<dotenv::Error> for FatalErr {
|
||||
fn from(e: dotenv::Error) -> Self {
|
||||
Self::DotEnv(e)
|
||||
|
|
47
src/event.rs
47
src/event.rs
|
@ -2,14 +2,14 @@ mod checked_event;
|
|||
mod dynamic_event;
|
||||
mod err;
|
||||
|
||||
pub use {
|
||||
checked_event::{CheckedEvent, Id},
|
||||
dynamic_event::{DynEvent, DynStatus, EventKind},
|
||||
err::EventErr,
|
||||
};
|
||||
pub use checked_event::{CheckedEvent, Id};
|
||||
pub use dynamic_event::{DynEvent, DynStatus, EventKind};
|
||||
pub use err::EventErr;
|
||||
|
||||
use serde::Serialize;
|
||||
use std::{convert::TryFrom, string::String};
|
||||
use std::convert::TryFrom;
|
||||
use std::string::String;
|
||||
use warp::sse::ServerSentEvent;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Event {
|
||||
|
@ -20,15 +20,30 @@ pub enum Event {
|
|||
|
||||
impl Event {
|
||||
pub fn to_json_string(&self) -> String {
|
||||
let event = &self.event_name();
|
||||
let sendable_event = match self.payload() {
|
||||
Some(payload) => SendableEvent::WithPayload { event, payload },
|
||||
None => SendableEvent::NoPayload { event },
|
||||
};
|
||||
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
|
||||
if let Event::Ping = self {
|
||||
"{}".to_string()
|
||||
} else {
|
||||
let event = &self.event_name();
|
||||
let sendable_event = match self.payload() {
|
||||
Some(payload) => SendableEvent::WithPayload { event, payload },
|
||||
None => SendableEvent::NoPayload { event },
|
||||
};
|
||||
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn event_name(&self) -> String {
|
||||
pub fn to_warp_reply(&self) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
|
||||
if let Event::Ping = self {
|
||||
None
|
||||
} else {
|
||||
Some((
|
||||
warp::sse::event(self.event_name()),
|
||||
warp::sse::data(self.payload().unwrap_or_else(String::new)),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn event_name(&self) -> String {
|
||||
String::from(match self {
|
||||
Self::TypeSafe(checked) => match checked {
|
||||
CheckedEvent::Update { .. } => "update",
|
||||
|
@ -45,11 +60,11 @@ impl Event {
|
|||
..
|
||||
}) => "update",
|
||||
Self::Dynamic(DynEvent { event, .. }) => event,
|
||||
Self::Ping => panic!("event_name() called on Ping"),
|
||||
Self::Ping => unreachable!(), // private method only called above
|
||||
})
|
||||
}
|
||||
|
||||
pub fn payload(&self) -> Option<String> {
|
||||
fn payload(&self) -> Option<String> {
|
||||
use CheckedEvent::*;
|
||||
match self {
|
||||
Self::TypeSafe(checked) => match checked {
|
||||
|
@ -63,7 +78,7 @@ impl Event {
|
|||
FiltersChanged => None,
|
||||
},
|
||||
Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()),
|
||||
Self::Ping => panic!("payload() called on Ping"),
|
||||
Self::Ping => unreachable!(), // private method only called above
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ fn main() -> Result<(), FatalErr> {
|
|||
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
|
||||
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
|
||||
|
||||
let request = Handler::new(postgres_cfg, *cfg.whitelist_mode);
|
||||
let request = Handler::new(postgres_cfg, *cfg.whitelist_mode)?;
|
||||
let poll_freq = *redis_cfg.polling_interval;
|
||||
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
|
||||
|
||||
|
|
|
@ -3,8 +3,10 @@ mod postgres;
|
|||
mod query;
|
||||
pub mod timeline;
|
||||
|
||||
mod err;
|
||||
mod subscription;
|
||||
|
||||
pub use self::err::RequestErr;
|
||||
pub use self::postgres::PgPool;
|
||||
// TODO consider whether we can remove `Stream` from public API
|
||||
pub use subscription::{Blocks, Subscription};
|
||||
|
@ -22,6 +24,8 @@ mod sse_test;
|
|||
#[cfg(test)]
|
||||
mod ws_test;
|
||||
|
||||
type Result<T> = std::result::Result<T, err::RequestErr>;
|
||||
|
||||
/// Helper macro to match on the first of any of the provided filters
|
||||
macro_rules! any_of {
|
||||
($filter:expr, $($other_filter:expr),*) => {
|
||||
|
@ -56,10 +60,10 @@ pub struct Handler {
|
|||
}
|
||||
|
||||
impl Handler {
|
||||
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self {
|
||||
Self {
|
||||
pg_conn: PgPool::new(postgres_cfg, whitelist_mode),
|
||||
}
|
||||
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Result<Self> {
|
||||
Ok(Self {
|
||||
pg_conn: PgPool::new(postgres_cfg, whitelist_mode)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> {
|
||||
|
@ -113,7 +117,7 @@ impl Handler {
|
|||
warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed()
|
||||
}
|
||||
|
||||
pub fn err(r: Rejection) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
pub fn err(r: Rejection) -> std::result::Result<impl warp::Reply, warp::Rejection> {
|
||||
let json_err = match r.cause() {
|
||||
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
|
||||
warp::reply::json(&"Error: Missing access token".to_string())
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
use std::fmt;
|
||||
#[derive(Debug)]
|
||||
pub enum RequestErr {
|
||||
Unknown,
|
||||
PgPool(r2d2::Error),
|
||||
}
|
||||
|
||||
impl std::error::Error for RequestErr {}
|
||||
|
||||
impl fmt::Display for RequestErr {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
|
||||
use RequestErr::*;
|
||||
let msg = match self {
|
||||
Unknown => "Encountered an unrecoverable error related to handling a request".into(),
|
||||
PgPool(e) => format!("{}", e),
|
||||
};
|
||||
write!(f, "{}", msg)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<r2d2::Error> for RequestErr {
|
||||
fn from(e: r2d2::Error) -> Self {
|
||||
Self::PgPool(e)
|
||||
}
|
||||
}
|
|
@ -1,13 +1,13 @@
|
|||
//! Postgres queries
|
||||
use super::err;
|
||||
use super::timeline::{Scope, UserData};
|
||||
use crate::config;
|
||||
use crate::event::Id;
|
||||
use crate::request::timeline::{Scope, UserData};
|
||||
|
||||
use ::postgres;
|
||||
use hashbrown::HashSet;
|
||||
use r2d2_postgres::PostgresConnectionManager;
|
||||
use std::convert::TryFrom;
|
||||
use warp::reject::Rejection;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PgPool {
|
||||
|
@ -15,8 +15,11 @@ pub struct PgPool {
|
|||
whitelist_mode: bool,
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, err::RequestErr>;
|
||||
type Rejectable<T> = std::result::Result<T, warp::Rejection>;
|
||||
|
||||
impl PgPool {
|
||||
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self {
|
||||
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Result<Self> {
|
||||
let mut cfg = postgres::Config::new();
|
||||
cfg.user(&pg_cfg.user)
|
||||
.host(&*pg_cfg.host.to_string())
|
||||
|
@ -27,18 +30,17 @@ impl PgPool {
|
|||
};
|
||||
|
||||
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
|
||||
let pool = r2d2::Pool::builder()
|
||||
.max_size(10)
|
||||
.build(manager)
|
||||
.expect("Can connect to local postgres");
|
||||
Self {
|
||||
let pool = r2d2::Pool::builder().max_size(10).build(manager)?;
|
||||
|
||||
Ok(Self {
|
||||
conn: pool,
|
||||
whitelist_mode,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn select_user(self, token: &Option<String>) -> Result<UserData, Rejection> {
|
||||
let mut conn = self.conn.get().unwrap();
|
||||
pub fn select_user(self, token: &Option<String>) -> Rejectable<UserData> {
|
||||
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||
|
||||
if let Some(token) = token {
|
||||
let query_rows = conn
|
||||
.query("
|
||||
|
@ -47,9 +49,9 @@ SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_lan
|
|||
INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
|
||||
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
|
||||
LIMIT 1",
|
||||
&[&token.to_owned()],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
&[&token.to_owned()],
|
||||
).map_err(warp::reject::custom)?;
|
||||
|
||||
if let Some(result_columns) = query_rows.get(0) {
|
||||
let id = Id(result_columns.get(1));
|
||||
let allowed_langs = result_columns
|
||||
|
@ -85,10 +87,10 @@ LIMIT 1",
|
|||
}
|
||||
}
|
||||
|
||||
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
pub fn select_hashtag_id(self, tag_name: &str) -> Rejectable<i64> {
|
||||
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.map_err(warp::reject::custom)?
|
||||
.get(0)
|
||||
.map(|row| row.get(0))
|
||||
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
|
||||
|
@ -98,31 +100,31 @@ LIMIT 1",
|
|||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
pub fn select_blocked_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
|
||||
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||
conn.query(
|
||||
"SELECT target_account_id FROM blocks WHERE account_id = $1
|
||||
UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
|
||||
&[&*user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.map_err(warp::reject::custom)?
|
||||
.iter()
|
||||
.map(|row| Id(row.get(0)))
|
||||
.map(|row| Ok(Id(row.get(0))))
|
||||
.collect()
|
||||
}
|
||||
/// Query Postgres for everyone who has blocked the user
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
pub fn select_blocking_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
|
||||
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||
conn.query(
|
||||
"SELECT account_id FROM blocks WHERE target_account_id = $1",
|
||||
&[&*user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.map_err(warp::reject::custom)?
|
||||
.iter()
|
||||
.map(|row| Id(row.get(0)))
|
||||
.map(|row| Ok(Id(row.get(0))))
|
||||
.collect()
|
||||
}
|
||||
|
||||
|
@ -130,28 +132,28 @@ LIMIT 1",
|
|||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
pub fn select_blocked_domains(self, user_id: Id) -> Rejectable<HashSet<String>> {
|
||||
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||
conn.query(
|
||||
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
|
||||
&[&*user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.map_err(warp::reject::custom)?
|
||||
.iter()
|
||||
.map(|row| row.get(0))
|
||||
.map(|row| Ok(row.get(0)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Test whether a user owns a list
|
||||
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> Rejectable<bool> {
|
||||
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||
let rows = &conn
|
||||
.query(
|
||||
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
|
||||
&[&list_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id)
|
||||
.map_err(warp::reject::custom)?;
|
||||
Ok(rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,7 +54,7 @@ impl Subscription {
|
|||
let tag = pool.select_hashtag_id(&q.hashtag)?;
|
||||
Timeline(Hashtag(tag), reach, stream)
|
||||
}
|
||||
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => {
|
||||
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id)? => {
|
||||
Err(warp::reject::custom("Error: Missing access token"))?
|
||||
}
|
||||
other_tl => other_tl,
|
||||
|
@ -70,9 +70,9 @@ impl Subscription {
|
|||
timeline,
|
||||
allowed_langs: user.allowed_langs,
|
||||
blocks: Blocks {
|
||||
blocking_users: pool.clone().select_blocking_users(user.id),
|
||||
blocked_users: pool.clone().select_blocked_users(user.id),
|
||||
blocked_domains: pool.select_blocked_domains(user.id),
|
||||
blocking_users: pool.clone().select_blocking_users(user.id)?,
|
||||
blocked_users: pool.clone().select_blocked_users(user.id)?,
|
||||
blocked_domains: pool.select_blocked_domains(user.id)?,
|
||||
},
|
||||
hashtag_name,
|
||||
access_token: q.access_token,
|
||||
|
|
|
@ -6,18 +6,11 @@ use log;
|
|||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use warp::reply::Reply;
|
||||
use warp::sse::{ServerSentEvent, Sse as WarpSse};
|
||||
use warp::sse::Sse as WarpSse;
|
||||
|
||||
pub struct Sse;
|
||||
|
||||
impl Sse {
|
||||
fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
|
||||
Some((
|
||||
warp::sse::event(event.event_name()),
|
||||
warp::sse::data(event.payload().unwrap_or_else(String::new)),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn send_events(
|
||||
sse: WarpSse,
|
||||
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||
|
@ -40,21 +33,20 @@ impl Sse {
|
|||
TypeSafe(Update { payload, queued_at }) => match timeline {
|
||||
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
|
||||
_ if payload.involves_any(&blocks) => None,
|
||||
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
|
||||
payload,
|
||||
queued_at,
|
||||
})),
|
||||
_ => Event::TypeSafe(CheckedEvent::Update { payload, queued_at })
|
||||
.to_warp_reply(),
|
||||
},
|
||||
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
|
||||
TypeSafe(non_update) => Event::TypeSafe(non_update).to_warp_reply(),
|
||||
Dynamic(dyn_event) => {
|
||||
if let EventKind::Update(s) = dyn_event.kind {
|
||||
match timeline {
|
||||
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
|
||||
_ if s.involves_any(&blocks) => None,
|
||||
_ => Self::reply_with(Dynamic(DynEvent {
|
||||
_ => Dynamic(DynEvent {
|
||||
kind: EventKind::Update(s),
|
||||
..dyn_event
|
||||
})),
|
||||
})
|
||||
.to_warp_reply(),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
|
|
|
@ -50,7 +50,7 @@ impl Ws {
|
|||
|
||||
incoming_events.for_each(move |(tl, event)| {
|
||||
if matches!(event, Event::Ping) {
|
||||
self.send_ping()
|
||||
self.send_msg(event)
|
||||
} else if target_timeline == tl {
|
||||
use crate::event::{CheckedEvent::Update, Event::*, EventKind};
|
||||
use crate::request::Stream::Public;
|
||||
|
@ -83,15 +83,8 @@ impl Ws {
|
|||
})
|
||||
}
|
||||
|
||||
fn send_ping(&mut self) -> Result<(), ()> {
|
||||
self.send_txt("{}")
|
||||
}
|
||||
|
||||
fn send_msg(&mut self, event: Event) -> Result<(), ()> {
|
||||
self.send_txt(&event.to_json_string())
|
||||
}
|
||||
|
||||
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
|
||||
let txt = &event.to_json_string();
|
||||
let tl = self.subscription.timeline;
|
||||
match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) {
|
||||
Ok(_) => Ok(()),
|
||||
|
|
Loading…
Reference in New Issue