diff --git a/src/config.rs b/src/config.rs index 76edcad..79e3041 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,13 +1,12 @@ -pub(crate) use postgres_cfg::Postgres; -pub(crate) use redis_cfg::Redis; - -use deployment_cfg::Deployment; +pub use self::deployment_cfg::Deployment; +pub use self::postgres_cfg::Postgres; +pub use self::redis_cfg::Redis; use self::environmental_variables::EnvVar; -use super::err::Error; + use hashbrown::HashMap; use std::env; - +use std::fmt; mod deployment_cfg; mod deployment_cfg_types; mod environmental_variables; @@ -58,3 +57,47 @@ pub fn from_env<'a>( Ok((pg_cfg, redis_cfg, deployment_cfg)) } + +#[derive(Debug)] +pub enum Error { + Config(String), + UrlEncoding(urlencoding::FromUrlEncodingError), + UrlParse(url::ParseError), +} + +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), fmt::Error> { + write!( + f, + "{}", + match self { + Self::Config(e) => e.to_string(), + Self::UrlEncoding(e) => format!("could not parse POSTGRES_URL.\n{:7}{:?}", "", e), + Self::UrlParse(e) => format!("could parse Postgres URL.\n{:7}{}", "", e), + } + ) + } +} + +impl Error { + pub fn config(var: T, value: T, allowed_vals: T) -> Self { + Self::Config(format!( + "{0} is set to `{1}`, which is invalid.\n{3:7}{0} must be {2}.", + var, value, allowed_vals, "" + )) + } +} + +impl From for Error { + fn from(e: urlencoding::FromUrlEncodingError) -> Self { + Self::UrlEncoding(e) + } +} + +impl From for Error { + fn from(e: url::ParseError) -> Self { + Self::UrlParse(e) + } +} diff --git a/src/config/deployment_cfg.rs b/src/config/deployment_cfg.rs index efee4a6..db0c502 100644 --- a/src/config/deployment_cfg.rs +++ b/src/config/deployment_cfg.rs @@ -1,5 +1,5 @@ -use super::{deployment_cfg_types::*, EnvVar}; -use crate::err::Error; +use super::deployment_cfg_types::*; +use super::{EnvVar, Error}; #[derive(Debug, Default)] pub struct Deployment<'a> { diff --git a/src/config/environmental_variables.rs b/src/config/environmental_variables.rs index 497fa98..fbac507 100644 --- a/src/config/environmental_variables.rs +++ b/src/config/environmental_variables.rs @@ -61,6 +61,7 @@ impl fmt::Display for EnvVar { } } #[macro_export] +#[doc(hidden)] macro_rules! maybe_update { ($name:ident; $item: tt:$type:ty) => ( pub(crate) fn $name(self, item: Option<$type>) -> Self { @@ -76,7 +77,9 @@ macro_rules! maybe_update { None => Self { ..self } } })} + #[macro_export] +#[doc(hidden)] macro_rules! from_env_var { ($(#[$outer:meta])* let name = $name:ident; @@ -106,15 +109,14 @@ macro_rules! from_env_var { fn inner_from_str($arg: &str) -> Option<$type> { $body } - pub(crate) fn maybe_update( - self, - var: Option<&String>, - ) -> Result { + pub(crate) fn maybe_update(self, var: Option<&String>) -> Result { Ok(match var { Some(empty_string) if empty_string.is_empty() => Self::default(), - Some(value) => Self(Self::inner_from_str(value).ok_or_else(|| { - crate::err::Error::config($env_var, value, $allowed_values) - })?), + Some(value) => { + Self(Self::inner_from_str(value).ok_or_else(|| { + super::Error::config($env_var, value, $allowed_values) + })?) + } None => self, }) } diff --git a/src/config/postgres_cfg.rs b/src/config/postgres_cfg.rs index d94fe21..2422f7d 100644 --- a/src/config/postgres_cfg.rs +++ b/src/config/postgres_cfg.rs @@ -1,17 +1,19 @@ -use super::{postgres_cfg_types::*, EnvVar}; -use crate::err::Error; +use super::postgres_cfg_types::*; +use super::{EnvVar, Error}; use url::Url; use urlencoding; type Result = std::result::Result; +/// Configuration values for Postgres #[derive(Debug, Clone)] pub struct Postgres { pub(crate) user: PgUser, pub(crate) host: PgHost, pub(crate) password: PgPass, - pub(crate) database: PgDatabase, + /// The name of the postgres database to connect to + pub database: PgDatabase, pub(crate) port: PgPort, pub(crate) ssl_mode: PgSslMode, } diff --git a/src/config/redis_cfg.rs b/src/config/redis_cfg.rs index fe79364..2cc10f7 100644 --- a/src/config/redis_cfg.rs +++ b/src/config/redis_cfg.rs @@ -1,6 +1,5 @@ use super::redis_cfg_types::*; -use super::EnvVar; -use crate::err::Error; +use super::{EnvVar, Error}; use url::Url; diff --git a/src/err.rs b/src/err.rs index e756524..bb4cf61 100644 --- a/src/err.rs +++ b/src/err.rs @@ -1,33 +1,26 @@ +use crate::config; use crate::request; use crate::response; + use std::fmt; pub enum Error { - ReceiverErr(response::Error), + Response(response::Error), Logger(log::SetLoggerError), Postgres(request::Error), Unrecoverable, StdIo(std::io::Error), - // config errs - UrlParse(url::ParseError), - UrlEncoding(urlencoding::FromUrlEncodingError), - ConfigErr(String), + Config(config::Error), } impl Error { pub fn log(msg: impl fmt::Display) { eprintln!("{}", msg); } - - pub fn config(var: T, value: T, allowed_vals: T) -> Self { - Self::ConfigErr(format!( - "{0} is set to `{1}`, which is invalid.\n{3:7}{0} must be {2}.", - var, value, allowed_vals, "" - )) - } } impl std::error::Error for Error {} + impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { write!(f, "{}", self) @@ -41,45 +34,46 @@ impl fmt::Display for Error { f, "{}", match self { - ReceiverErr(e) => format!("{}", e), + Response(e) => format!("{}", e), Logger(e) => format!("{}", e), StdIo(e) => format!("{}", e), Postgres(e) => format!("could not connect to Postgres.\n{:7}{}", "", e), - ConfigErr(e) => e.to_string(), - UrlParse(e) => format!("could parse Postgres URL.\n{:7}{}", "", e), - UrlEncoding(e) => format!("could not parse POSTGRES_URL.\n{:7}{:?}", "", e), + Config(e) => format!("{}", e), Unrecoverable => "Flodgatt will now shut down.".into(), } ) } } +#[doc(hidden)] impl From for Error { fn from(e: request::Error) -> Self { Self::Postgres(e) } } +#[doc(hidden)] impl From for Error { fn from(e: response::Error) -> Self { - Self::ReceiverErr(e) + Self::Response(e) } } -impl From for Error { - fn from(e: urlencoding::FromUrlEncodingError) -> Self { - Self::UrlEncoding(e) - } -} -impl From for Error { - fn from(e: url::ParseError) -> Self { - Self::UrlParse(e) + +#[doc(hidden)] +impl From for Error { + fn from(e: config::Error) -> Self { + Self::Config(e) } } + +#[doc(hidden)] impl From for Error { fn from(e: std::io::Error) -> Self { Self::StdIo(e) } } + +#[doc(hidden)] impl From for Error { fn from(e: log::SetLoggerError) -> Self { Self::Logger(e) diff --git a/src/event.rs b/src/event.rs index ce10fb9..b28a1c7 100644 --- a/src/event.rs +++ b/src/event.rs @@ -2,10 +2,12 @@ mod checked_event; mod dynamic_event; mod err; -pub(crate) use checked_event::{CheckedEvent, Id}; -pub(crate) use dynamic_event::{DynEvent, EventKind}; +pub(crate) use checked_event::Id; pub(crate) use err::EventErr; +use self::checked_event::CheckedEvent; +use self::dynamic_event::{DynEvent, EventKind}; + use hashbrown::HashSet; use serde::Serialize; use std::convert::TryFrom; @@ -32,17 +34,6 @@ pub(crate) trait Payload { } impl Event { - pub(crate) fn get_update_payload(&self) -> Option> { - match self { - Event::TypeSafe(CheckedEvent::Update { payload, .. }) => Some(Box::new(payload)), - Event::Dynamic(DynEvent { - kind: EventKind::Update(s), - .. - }) => Some(Box::new(s)), - _ => None, - } - } - pub(crate) fn to_json_string(&self) -> String { if let Event::Ping = self { "{}".to_string() diff --git a/src/event/checked_event.rs b/src/event/checked_event.rs index 9a22db7..0c2d80c 100644 --- a/src/event/checked_event.rs +++ b/src/event/checked_event.rs @@ -11,13 +11,14 @@ mod status; mod tag; mod visibility; -use announcement::Announcement; -pub(in crate::event) use announcement_reaction::AnnouncementReaction; -use conversation::Conversation; +pub(self) use super::Payload; +pub(super) use announcement_reaction::AnnouncementReaction; pub(crate) use id::Id; -use notification::Notification; pub(crate) use status::Status; +use announcement::Announcement; +use conversation::Conversation; +use notification::Notification; use serde::Deserialize; #[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)] diff --git a/src/event/checked_event/status.rs b/src/event/checked_event/status.rs index b8348c4..b14ddb1 100644 --- a/src/event/checked_event/status.rs +++ b/src/event/checked_event/status.rs @@ -3,21 +3,18 @@ mod attachment; mod card; mod poll; -use super::super::Payload; use super::account::Account; use super::emoji::Emoji; use super::id::Id; use super::mention::Mention; use super::tag::Tag; use super::visibility::Visibility; +use super::Payload; use application::Application; use attachment::Attachment; use card::Card; -use poll::Poll; - -//use crate::request::Blocks; - use hashbrown::HashSet; +use poll::Poll; use serde::{Deserialize, Serialize}; use std::boxed::Box; use std::string::String; diff --git a/src/event/dynamic_event.rs b/src/event/dynamic_event.rs index 67e8677..cfd8abf 100644 --- a/src/event/dynamic_event.rs +++ b/src/event/dynamic_event.rs @@ -49,13 +49,6 @@ impl DynEvent { Ok(self) } } - pub(crate) fn update(&self) -> Option { - if let EventKind::Update(status) = self.kind.clone() { - Some(status) - } else { - None - } - } } impl DynStatus { pub(crate) fn new(payload: &Value) -> Result { diff --git a/src/lib.rs b/src/lib.rs index 70234fa..b173980 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,6 @@ pub use err::Error; pub mod config; mod err; -pub mod event; +mod event; pub mod request; pub mod response; diff --git a/src/main.rs b/src/main.rs index 8479523..19d47dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,6 @@ use flodgatt::config; -use flodgatt::event::Event; use flodgatt::request::{Handler, Subscription, Timeline}; -use flodgatt::response::redis::Manager; -use flodgatt::response::stream; +use flodgatt::response::{Event, RedisManager, SseStream, WsStream}; use flodgatt::Error; use futures::{future::lazy, stream::Stream as _}; @@ -27,7 +25,7 @@ fn main() -> Result<(), Error> { let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); let request = Handler::new(&postgres_cfg, *cfg.whitelist_mode)?; - let shared_manager = Manager::try_from(&redis_cfg, event_tx, cmd_rx)?.into_arc(); + let shared_manager = RedisManager::try_from(&redis_cfg, event_tx, cmd_rx)?.into_arc(); // Server Sent Events let sse_manager = shared_manager.clone(); @@ -37,10 +35,10 @@ fn main() -> Result<(), Error> { .and(warp::sse()) .map(move |subscription: Subscription, sse: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); - let mut manager = sse_manager.lock().unwrap_or_else(Manager::recover); + let mut manager = sse_manager.lock().unwrap_or_else(RedisManager::recover); manager.subscribe(&subscription); - stream::Sse::send_events(sse, sse_cmd_tx.clone(), subscription, sse_rx.clone()) + SseStream::send_events(sse, sse_cmd_tx.clone(), subscription, sse_rx.clone()) }) .with(warp::reply::with::header("Connection", "keep-alive")); @@ -51,10 +49,10 @@ fn main() -> Result<(), Error> { .and(warp::ws::ws2()) .map(move |subscription: Subscription, ws: Ws2| { log::info!("Incoming websocket request for {:?}", subscription.timeline); - let mut manager = ws_manager.lock().unwrap_or_else(Manager::recover); + let mut manager = ws_manager.lock().unwrap_or_else(RedisManager::recover); manager.subscribe(&subscription); let token = subscription.access_token.clone().unwrap_or_default(); // token sent for security - let ws_stream = stream::Ws::new(cmd_tx.clone(), event_rx.clone(), subscription); + let ws_stream = WsStream::new(cmd_tx.clone(), event_rx.clone(), subscription); (ws.on_upgrade(move |ws| ws_stream.send_to(ws)), token) }) @@ -66,9 +64,9 @@ fn main() -> Result<(), Error> { let (r1, r3) = (shared_manager.clone(), shared_manager.clone()); request.health().map(|| "OK") .or(request.status() - .map(move || r1.lock().unwrap_or_else(redis::Manager::recover).count())) + .map(move || r1.lock().unwrap_or_else(RedisManager::recover).count())) .or(request.status_per_timeline() - .map(move || r3.lock().unwrap_or_else(redis::Manager::recover).list())) + .map(move || r3.lock().unwrap_or_else(RedisManager::recover).list())) }; #[cfg(not(feature = "stub_status"))] let status = request.health().map(|| "OK"); @@ -83,7 +81,7 @@ fn main() -> Result<(), Error> { let stream = Interval::new(Instant::now(), poll_freq) .map_err(|e| log::error!("{}", e)) .for_each(move |_| { - let mut manager = manager.lock().unwrap_or_else(Manager::recover); + let mut manager = manager.lock().unwrap_or_else(RedisManager::recover); manager.poll_broadcast().map_err(Error::log) }); diff --git a/src/request.rs b/src/request.rs index cf674c0..89e5b95 100644 --- a/src/request.rs +++ b/src/request.rs @@ -7,7 +7,8 @@ pub mod err; mod subscription; pub(crate) use err::Error; -pub use subscription::Subscription; +pub use subscription::{Blocks, Subscription}; +#[doc(hidden)] pub use timeline::Timeline; use timeline::{Content, Reach, Stream}; diff --git a/src/request/sse_test.rs b/src/request/sse_test.rs index 287cd9b..1a585f1 100644 --- a/src/request/sse_test.rs +++ b/src/request/sse_test.rs @@ -3,7 +3,6 @@ // #[cfg(test)] // mod test { // use super::*; -// use crate::parse_client_request::user::{Blocks, Filter, OauthScope, PgPool}; // macro_rules! test_public_endpoint { // ($name:ident { diff --git a/src/request/subscription.rs b/src/request/subscription.rs index 8b105dc..3688166 100644 --- a/src/request/subscription.rs +++ b/src/request/subscription.rs @@ -17,17 +17,19 @@ use warp::reject::Rejection; #[derive(Clone, Debug, PartialEq)] pub struct Subscription { pub timeline: Timeline, - pub(crate) allowed_langs: HashSet, - pub(crate) blocks: Blocks, - pub(crate) hashtag_name: Option, + pub allowed_langs: HashSet, + /// [Blocks](./request/struct.Blocks.html) + pub blocks: Blocks, + pub hashtag_name: Option, pub access_token: Option, } +/// Blocked and muted users and domains #[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct Blocks { - pub(crate) blocked_domains: HashSet, - pub(crate) blocked_users: HashSet, - pub(crate) blocking_users: HashSet, +pub struct Blocks { + pub blocked_domains: HashSet, + pub blocked_users: HashSet, + pub blocking_users: HashSet, } impl Default for Subscription { diff --git a/src/request/ws_test.rs b/src/request/ws_test.rs index e101bed..8b547a4 100644 --- a/src/request/ws_test.rs +++ b/src/request/ws_test.rs @@ -1,14 +1,8 @@ //! Filters for the WebSocket endpoint - - - - - // #[cfg(test)] // mod test { // use super::*; -// use crate::parse_client_request::user::{Blocks, Filter, OauthScope}; // macro_rules! test_public_endpoint { // ($name:ident { diff --git a/src/response.rs b/src/response.rs index 4ce8d1f..10d8fab 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,9 +1,13 @@ //! Stream the updates appropriate for a given `User`/`timeline` pair from Redis. -pub mod redis; -pub mod stream; +pub use crate::event::Event; +pub use redis::Manager as RedisManager; +pub use stream::{Sse as SseStream, Ws as WsStream}; -pub(crate) use redis::Error; +mod redis; +mod stream; + +pub use redis::Error; #[cfg(feature = "bench")] pub use redis::msg::{RedisMsg, RedisParseOutput}; diff --git a/src/response/redis.rs b/src/response/redis.rs index 0d69504..196189c 100644 --- a/src/response/redis.rs +++ b/src/response/redis.rs @@ -3,7 +3,7 @@ mod manager; mod msg; pub(self) use connection::RedisConn; -pub(crate) use manager::Error; +pub use manager::Error; pub use manager::Manager; use connection::RedisConnErr; diff --git a/src/response/redis/manager.rs b/src/response/redis/manager.rs index 9a68f32..1ecdcc4 100644 --- a/src/response/redis/manager.rs +++ b/src/response/redis/manager.rs @@ -2,7 +2,7 @@ //! polled by the correct `ClientAgent`. Also manages sububscriptions and //! unsubscriptions to/from Redis. mod err; -pub(crate) use err::Error; +pub use err::Error; use super::{RedisCmd, RedisConn}; use crate::config; diff --git a/src/response/stream/sse.rs b/src/response/stream/sse.rs index 7319bc0..6522003 100644 --- a/src/response/stream/sse.rs +++ b/src/response/stream/sse.rs @@ -1,4 +1,4 @@ -use crate::event::Event; +use crate::event::{Event, Payload}; use crate::request::{Subscription, Timeline}; use futures::stream::Stream; @@ -18,59 +18,19 @@ impl Sse { sse_rx: watch::Receiver<(Timeline, Event)>, ) -> impl Reply { let target_timeline = subscription.timeline; - let allowed_langs = subscription.allowed_langs; - let blocks = subscription.blocks; let event_stream = sse_rx .filter(move |(timeline, _)| target_timeline == *timeline) - .filter_map(move |(timeline, event)| { - use crate::event::Payload; - use crate::event::{ - CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind, - }; // TODO -- move up - - match event { - TypeSafe(Update { payload, queued_at }) => match timeline { - tl if tl.is_public() - && !payload.language_unset() - && !allowed_langs.is_empty() - && !allowed_langs.contains(&payload.language()) => - { - None - } - _ if blocks.blocked_users.is_disjoint(&payload.involved_users()) => None, - _ if blocks.blocking_users.contains(payload.author()) => None, - _ if blocks.blocked_domains.contains(payload.sent_from()) => None, - - _ => Event::TypeSafe(CheckedEvent::Update { payload, queued_at }) - .to_warp_reply(), - }, - TypeSafe(non_update) => Event::TypeSafe(non_update).to_warp_reply(), - Dynamic(dyn_event) => { - if let EventKind::Update(s) = dyn_event.kind { - match timeline { - tl if tl.is_public() - && !s.language_unset() - && !allowed_langs.is_empty() - && !allowed_langs.contains(&s.language()) => - { - None - } - _ if blocks.blocked_users.is_disjoint(&s.involved_users()) => None, - _ if blocks.blocking_users.contains(s.author()) => None, - _ if blocks.blocked_domains.contains(s.sent_from()) => None, - - _ => Dynamic(DynEvent { - kind: EventKind::Update(s), - ..dyn_event - }) - .to_warp_reply(), - } - } else { - None - } + .filter_map(move |(_timeline, event)| { + match (event.update_payload(), event.dyn_update_payload()) { + (Some(update), _) if Sse::update_not_filtered(subscription.clone(), update) => { + event.to_warp_reply() } - Ping => None, // pings handled automatically + (None, None) => event.to_warp_reply(), // send all non-updates + (_, Some(update)) if Sse::update_not_filtered(subscription.clone(), update) => { + event.to_warp_reply() + } + (_, _) => None, } }) .then(move |res| { @@ -87,4 +47,23 @@ impl Sse { .stream(event_stream), ) } + + fn update_not_filtered(subscription: Subscription, update: &impl Payload) -> bool { + let blocks = &subscription.blocks; + let allowed_langs = &subscription.allowed_langs; + + match subscription.timeline { + tl if tl.is_public() + && !update.language_unset() + && !allowed_langs.is_empty() + && !allowed_langs.contains(&update.language()) => + { + false + } + _ if !blocks.blocked_users.is_disjoint(&update.involved_users()) => false, + _ if blocks.blocking_users.contains(update.author()) => false, + _ if blocks.blocked_domains.contains(update.sent_from()) => false, + _ => true, + } + } } diff --git a/src/response/stream/ws.rs b/src/response/stream/ws.rs index e6541e4..139541e 100644 --- a/src/response/stream/ws.rs +++ b/src/response/stream/ws.rs @@ -1,10 +1,12 @@ -use crate::event::Event; +use crate::event::{Event, Payload}; use crate::request::{Subscription, Timeline}; use futures::{future::Future, stream::Stream}; use tokio::sync::{mpsc, watch}; use warp::ws::{Message, WebSocket}; +type Result = std::result::Result; + pub struct Ws { unsubscribe_tx: mpsc::UnboundedSender, subscription: Subscription, @@ -54,30 +56,36 @@ impl Ws { if matches!(event, Event::Ping) { self.send_msg(&event)? } else if target_timeline == tl { - let blocks = &self.subscription.blocks; - let allowed_langs = &self.subscription.allowed_langs; - - if let Some(update) = event.get_update_payload() { - match tl { - tl if tl.is_public() - && !update.language_unset() - && !allowed_langs.is_empty() - && !allowed_langs.contains(&update.language()) => {} // skip - _ if !blocks.blocked_users.is_disjoint(&update.involved_users()) => {} // skip - _ if blocks.blocking_users.contains(update.author()) => {} // skip - _ if blocks.blocked_domains.contains(update.sent_from()) => {} // skip - _ => self.send_msg(&event)?, - } - } else { - // send all non-updates - self.send_msg(&event)?; + match (event.update_payload(), event.dyn_update_payload()) { + (Some(update), _) => self.send_or_filter(tl, &event, update)?, + (None, None) => self.send_msg(&event)?, // send all non-updates + (_, Some(dyn_update)) => self.send_or_filter(tl, &event, dyn_update)?, } } Ok(()) }) } - fn send_msg(&mut self, event: &Event) -> Result<(), ()> { + fn send_or_filter(&mut self, tl: Timeline, event: &Event, update: &impl Payload) -> Result<()> { + let blocks = &self.subscription.blocks; + let allowed_langs = &self.subscription.allowed_langs; + const SKIP: Result<()> = Ok(()); + match tl { + tl if tl.is_public() + && !update.language_unset() + && !allowed_langs.is_empty() + && !allowed_langs.contains(&update.language()) => + { + SKIP + } + _ if !blocks.blocked_users.is_disjoint(&update.involved_users()) => SKIP, + _ if blocks.blocking_users.contains(update.author()) => SKIP, + _ if blocks.blocked_domains.contains(update.sent_from()) => SKIP, + _ => Ok(self.send_msg(&event)?), + } + } + + fn send_msg(&mut self, event: &Event) -> Result<()> { let txt = &event.to_json_string(); let tl = self.subscription.timeline; let mut channel = self.ws_tx.clone().ok_or(())?;