diff --git a/src/config.rs b/src/config.rs index 891a9b2..96e32bd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -14,7 +14,6 @@ mod redis_cfg; mod redis_cfg_types; pub fn merge_dotenv() -> Result<(), err::FatalErr> { - // TODO -- should this allow the user to run in a dir without a `.env` file? dotenv::from_filename(match env::var("ENV").ok().as_deref() { Some("production") => ".env.production", Some("development") | None => ".env", diff --git a/src/err.rs b/src/err.rs index 3a61f4b..5386d92 100644 --- a/src/err.rs +++ b/src/err.rs @@ -1,3 +1,4 @@ +use crate::request::RequestErr; use crate::response::ManagerErr; use std::fmt; @@ -6,6 +7,7 @@ pub enum FatalErr { ReceiverErr(ManagerErr), DotEnv(dotenv::Error), Logger(log::SetLoggerError), + Postgres(RequestErr), } impl FatalErr { @@ -33,11 +35,18 @@ impl fmt::Display for FatalErr { ReceiverErr(e) => format!("{}", e), Logger(e) => format!("{}", e), DotEnv(e) => format!("Could not load specified environmental file: {}", e), + Postgres(e) => format!("Could not connect to Postgres: {}", e), } ) } } +impl From for FatalErr { + fn from(e: RequestErr) -> Self { + Self::Postgres(e) + } +} + impl From for FatalErr { fn from(e: dotenv::Error) -> Self { Self::DotEnv(e) diff --git a/src/event.rs b/src/event.rs index 8583b98..8b37d0a 100644 --- a/src/event.rs +++ b/src/event.rs @@ -2,14 +2,14 @@ mod checked_event; mod dynamic_event; mod err; -pub use { - checked_event::{CheckedEvent, Id}, - dynamic_event::{DynEvent, DynStatus, EventKind}, - err::EventErr, -}; +pub use checked_event::{CheckedEvent, Id}; +pub use dynamic_event::{DynEvent, DynStatus, EventKind}; +pub use err::EventErr; use serde::Serialize; -use std::{convert::TryFrom, string::String}; +use std::convert::TryFrom; +use std::string::String; +use warp::sse::ServerSentEvent; #[derive(Debug, Clone)] pub enum Event { @@ -20,15 +20,30 @@ pub enum Event { impl Event { pub fn to_json_string(&self) -> String { - let event = &self.event_name(); - let sendable_event = match self.payload() { - Some(payload) => SendableEvent::WithPayload { event, payload }, - None => SendableEvent::NoPayload { event }, - }; - serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize") + if let Event::Ping = self { + "{}".to_string() + } else { + let event = &self.event_name(); + let sendable_event = match self.payload() { + Some(payload) => SendableEvent::WithPayload { event, payload }, + None => SendableEvent::NoPayload { event }, + }; + serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize") + } } - pub fn event_name(&self) -> String { + pub fn to_warp_reply(&self) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> { + if let Event::Ping = self { + None + } else { + Some(( + warp::sse::event(self.event_name()), + warp::sse::data(self.payload().unwrap_or_else(String::new)), + )) + } + } + + fn event_name(&self) -> String { String::from(match self { Self::TypeSafe(checked) => match checked { CheckedEvent::Update { .. } => "update", @@ -45,11 +60,11 @@ impl Event { .. }) => "update", Self::Dynamic(DynEvent { event, .. }) => event, - Self::Ping => panic!("event_name() called on Ping"), + Self::Ping => unreachable!(), // private method only called above }) } - pub fn payload(&self) -> Option { + fn payload(&self) -> Option { use CheckedEvent::*; match self { Self::TypeSafe(checked) => match checked { @@ -63,7 +78,7 @@ impl Event { FiltersChanged => None, }, Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()), - Self::Ping => panic!("payload() called on Ping"), + Self::Ping => unreachable!(), // private method only called above } } } diff --git a/src/main.rs b/src/main.rs index 7e385ff..77c192c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,7 @@ fn main() -> Result<(), FatalErr> { let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); - let request = Handler::new(postgres_cfg, *cfg.whitelist_mode); + let request = Handler::new(postgres_cfg, *cfg.whitelist_mode)?; let poll_freq = *redis_cfg.polling_interval; let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc(); diff --git a/src/request.rs b/src/request.rs index abeea43..7d1eab3 100644 --- a/src/request.rs +++ b/src/request.rs @@ -3,8 +3,10 @@ mod postgres; mod query; pub mod timeline; +mod err; mod subscription; +pub use self::err::RequestErr; pub use self::postgres::PgPool; // TODO consider whether we can remove `Stream` from public API pub use subscription::{Blocks, Subscription}; @@ -22,6 +24,8 @@ mod sse_test; #[cfg(test)] mod ws_test; +type Result = std::result::Result; + /// Helper macro to match on the first of any of the provided filters macro_rules! any_of { ($filter:expr, $($other_filter:expr),*) => { @@ -56,10 +60,10 @@ pub struct Handler { } impl Handler { - pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self { - Self { - pg_conn: PgPool::new(postgres_cfg, whitelist_mode), - } + pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Result { + Ok(Self { + pg_conn: PgPool::new(postgres_cfg, whitelist_mode)?, + }) } pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> { @@ -113,7 +117,7 @@ impl Handler { warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed() } - pub fn err(r: Rejection) -> Result { + pub fn err(r: Rejection) -> std::result::Result { let json_err = match r.cause() { Some(text) if text.to_string() == "Missing request header 'authorization'" => { warp::reply::json(&"Error: Missing access token".to_string()) diff --git a/src/request/err.rs b/src/request/err.rs new file mode 100644 index 0000000..d1d62ca --- /dev/null +++ b/src/request/err.rs @@ -0,0 +1,25 @@ +use std::fmt; +#[derive(Debug)] +pub enum RequestErr { + Unknown, + PgPool(r2d2::Error), +} + +impl std::error::Error for RequestErr {} + +impl fmt::Display for RequestErr { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + use RequestErr::*; + let msg = match self { + Unknown => "Encountered an unrecoverable error related to handling a request".into(), + PgPool(e) => format!("{}", e), + }; + write!(f, "{}", msg) + } +} + +impl From for RequestErr { + fn from(e: r2d2::Error) -> Self { + Self::PgPool(e) + } +} diff --git a/src/request/postgres.rs b/src/request/postgres.rs index de17ab0..c427c93 100644 --- a/src/request/postgres.rs +++ b/src/request/postgres.rs @@ -1,13 +1,13 @@ //! Postgres queries +use super::err; +use super::timeline::{Scope, UserData}; use crate::config; use crate::event::Id; -use crate::request::timeline::{Scope, UserData}; use ::postgres; use hashbrown::HashSet; use r2d2_postgres::PostgresConnectionManager; use std::convert::TryFrom; -use warp::reject::Rejection; #[derive(Clone, Debug)] pub struct PgPool { @@ -15,8 +15,11 @@ pub struct PgPool { whitelist_mode: bool, } +type Result = std::result::Result; +type Rejectable = std::result::Result; + impl PgPool { - pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self { + pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Result { let mut cfg = postgres::Config::new(); cfg.user(&pg_cfg.user) .host(&*pg_cfg.host.to_string()) @@ -27,18 +30,17 @@ impl PgPool { }; let manager = PostgresConnectionManager::new(cfg, postgres::NoTls); - let pool = r2d2::Pool::builder() - .max_size(10) - .build(manager) - .expect("Can connect to local postgres"); - Self { + let pool = r2d2::Pool::builder().max_size(10).build(manager)?; + + Ok(Self { conn: pool, whitelist_mode, - } + }) } - pub fn select_user(self, token: &Option) -> Result { - let mut conn = self.conn.get().unwrap(); + pub fn select_user(self, token: &Option) -> Rejectable { + let mut conn = self.conn.get().map_err(warp::reject::custom)?; + if let Some(token) = token { let query_rows = conn .query(" @@ -47,9 +49,9 @@ SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_lan INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1", - &[&token.to_owned()], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); + &[&token.to_owned()], + ).map_err(warp::reject::custom)?; + if let Some(result_columns) = query_rows.get(0) { let id = Id(result_columns.get(1)); let allowed_langs = result_columns @@ -85,10 +87,10 @@ LIMIT 1", } } - pub fn select_hashtag_id(self, tag_name: &str) -> Result { - let mut conn = self.conn.get().expect("TODO"); + pub fn select_hashtag_id(self, tag_name: &str) -> Rejectable { + let mut conn = self.conn.get().map_err(warp::reject::custom)?; conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name]) - .expect("Hard-coded query will return Some([0 or more rows])") + .map_err(warp::reject::custom)? .get(0) .map(|row| row.get(0)) .ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist.")) @@ -98,31 +100,31 @@ LIMIT 1", /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. - pub fn select_blocked_users(self, user_id: Id) -> HashSet { - let mut conn = self.conn.get().expect("TODO"); + pub fn select_blocked_users(self, user_id: Id) -> Rejectable> { + let mut conn = self.conn.get().map_err(warp::reject::custom)?; conn.query( "SELECT target_account_id FROM blocks WHERE account_id = $1 UNION SELECT target_account_id FROM mutes WHERE account_id = $1", &[&*user_id], ) - .expect("Hard-coded query will return Some([0 or more rows])") + .map_err(warp::reject::custom)? .iter() - .map(|row| Id(row.get(0))) + .map(|row| Ok(Id(row.get(0)))) .collect() } /// Query Postgres for everyone who has blocked the user /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. - pub fn select_blocking_users(self, user_id: Id) -> HashSet { - let mut conn = self.conn.get().expect("TODO"); + pub fn select_blocking_users(self, user_id: Id) -> Rejectable> { + let mut conn = self.conn.get().map_err(warp::reject::custom)?; conn.query( "SELECT account_id FROM blocks WHERE target_account_id = $1", &[&*user_id], ) - .expect("Hard-coded query will return Some([0 or more rows])") + .map_err(warp::reject::custom)? .iter() - .map(|row| Id(row.get(0))) + .map(|row| Ok(Id(row.get(0)))) .collect() } @@ -130,28 +132,28 @@ LIMIT 1", /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. - pub fn select_blocked_domains(self, user_id: Id) -> HashSet { - let mut conn = self.conn.get().expect("TODO"); + pub fn select_blocked_domains(self, user_id: Id) -> Rejectable> { + let mut conn = self.conn.get().map_err(warp::reject::custom)?; conn.query( "SELECT domain FROM account_domain_blocks WHERE account_id = $1", &[&*user_id], ) - .expect("Hard-coded query will return Some([0 or more rows])") + .map_err(warp::reject::custom)? .iter() - .map(|row| row.get(0)) + .map(|row| Ok(row.get(0))) .collect() } /// Test whether a user owns a list - pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool { - let mut conn = self.conn.get().expect("TODO"); + pub fn user_owns_list(self, user_id: Id, list_id: i64) -> Rejectable { + let mut conn = self.conn.get().map_err(warp::reject::custom)?; // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn .query( "SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", &[&list_id], ) - .expect("Hard-coded query will return Some([0 or more rows])"); - rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id) + .map_err(warp::reject::custom)?; + Ok(rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id)) } } diff --git a/src/request/subscription.rs b/src/request/subscription.rs index 6e9a84b..79fd046 100644 --- a/src/request/subscription.rs +++ b/src/request/subscription.rs @@ -54,7 +54,7 @@ impl Subscription { let tag = pool.select_hashtag_id(&q.hashtag)?; Timeline(Hashtag(tag), reach, stream) } - Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => { + Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id)? => { Err(warp::reject::custom("Error: Missing access token"))? } other_tl => other_tl, @@ -70,9 +70,9 @@ impl Subscription { timeline, allowed_langs: user.allowed_langs, blocks: Blocks { - blocking_users: pool.clone().select_blocking_users(user.id), - blocked_users: pool.clone().select_blocked_users(user.id), - blocked_domains: pool.select_blocked_domains(user.id), + blocking_users: pool.clone().select_blocking_users(user.id)?, + blocked_users: pool.clone().select_blocked_users(user.id)?, + blocked_domains: pool.select_blocked_domains(user.id)?, }, hashtag_name, access_token: q.access_token, diff --git a/src/response/stream/sse.rs b/src/response/stream/sse.rs index cd6fd12..9bcb058 100644 --- a/src/response/stream/sse.rs +++ b/src/response/stream/sse.rs @@ -6,18 +6,11 @@ use log; use std::time::Duration; use tokio::sync::{mpsc, watch}; use warp::reply::Reply; -use warp::sse::{ServerSentEvent, Sse as WarpSse}; +use warp::sse::Sse as WarpSse; pub struct Sse; impl Sse { - fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> { - Some(( - warp::sse::event(event.event_name()), - warp::sse::data(event.payload().unwrap_or_else(String::new)), - )) - } - pub fn send_events( sse: WarpSse, mut unsubscribe_tx: mpsc::UnboundedSender, @@ -40,21 +33,20 @@ impl Sse { TypeSafe(Update { payload, queued_at }) => match timeline { Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None, _ if payload.involves_any(&blocks) => None, - _ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update { - payload, - queued_at, - })), + _ => Event::TypeSafe(CheckedEvent::Update { payload, queued_at }) + .to_warp_reply(), }, - TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)), + TypeSafe(non_update) => Event::TypeSafe(non_update).to_warp_reply(), Dynamic(dyn_event) => { if let EventKind::Update(s) = dyn_event.kind { match timeline { Timeline(Public, _, _) if s.language_not(&allowed_langs) => None, _ if s.involves_any(&blocks) => None, - _ => Self::reply_with(Dynamic(DynEvent { + _ => Dynamic(DynEvent { kind: EventKind::Update(s), ..dyn_event - })), + }) + .to_warp_reply(), } } else { None diff --git a/src/response/stream/ws.rs b/src/response/stream/ws.rs index 3faef99..224cee7 100644 --- a/src/response/stream/ws.rs +++ b/src/response/stream/ws.rs @@ -50,7 +50,7 @@ impl Ws { incoming_events.for_each(move |(tl, event)| { if matches!(event, Event::Ping) { - self.send_ping() + self.send_msg(event) } else if target_timeline == tl { use crate::event::{CheckedEvent::Update, Event::*, EventKind}; use crate::request::Stream::Public; @@ -83,15 +83,8 @@ impl Ws { }) } - fn send_ping(&mut self) -> Result<(), ()> { - self.send_txt("{}") - } - fn send_msg(&mut self, event: Event) -> Result<(), ()> { - self.send_txt(&event.to_json_string()) - } - - fn send_txt(&mut self, txt: &str) -> Result<(), ()> { + let txt = &event.to_json_string(); let tl = self.subscription.timeline; match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) { Ok(_) => Ok(()),