diff --git a/benches/parse_redis.rs b/benches/parse_redis.rs index df53fb9..f5f1980 100644 --- a/benches/parse_redis.rs +++ b/benches/parse_redis.rs @@ -22,20 +22,20 @@ fn parse_to_timeline(msg: RedisMsg) -> Timeline { assert_eq!(tl, Timeline(User(Id(1)), Federated, All)); tl } -fn parse_to_checked_event(msg: RedisMsg) -> Event { - Event::TypeSafe(serde_json::from_str(msg.event_txt).unwrap()) +fn parse_to_checked_event(msg: RedisMsg) -> EventKind { + EventKind::TypeSafe(serde_json::from_str(msg.event_txt).unwrap()) } -fn parse_to_dyn_event(msg: RedisMsg) -> Event { - Event::Dynamic(serde_json::from_str(msg.event_txt).unwrap()) +fn parse_to_dyn_event(msg: RedisMsg) -> EventKind { + EventKind::Dynamic(serde_json::from_str(msg.event_txt).unwrap()) } fn redis_msg_to_event_string(msg: RedisMsg) -> String { msg.event_txt.to_string() } -fn string_to_checked_event(event_txt: &String) -> Event { - Event::TypeSafe(serde_json::from_str(event_txt).unwrap()) +fn string_to_checked_event(event_txt: &String) -> EventKind { + EventKind::TypeSafe(serde_json::from_str(event_txt).unwrap()) } fn criterion_benchmark(c: &mut Criterion) { diff --git a/src/config.rs b/src/config.rs index f7c3089..79e3041 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,13 +1,12 @@ -pub(crate) use postgres_cfg::Postgres; -pub(crate) use redis_cfg::Redis; - -use deployment_cfg::Deployment; +pub use self::deployment_cfg::Deployment; +pub use self::postgres_cfg::Postgres; +pub use self::redis_cfg::Redis; use self::environmental_variables::EnvVar; -use super::err::FatalErr; + use hashbrown::HashMap; use std::env; - +use std::fmt; mod deployment_cfg; mod deployment_cfg_types; mod environmental_variables; @@ -16,13 +15,13 @@ mod postgres_cfg_types; mod redis_cfg; mod redis_cfg_types; -type Result = std::result::Result; +type Result = std::result::Result; pub fn merge_dotenv() -> Result<()> { let env_file = match env::var("ENV").ok().as_deref() { Some("production") => ".env.production", Some("development") | None => ".env", - Some(v) => Err(FatalErr::config("ENV", v, "`production` or `development`"))?, + Some(v) => Err(Error::config("ENV", v, "`production` or `development`"))?, }; let res = dotenv::from_filename(env_file); @@ -58,3 +57,47 @@ pub fn from_env<'a>( Ok((pg_cfg, redis_cfg, deployment_cfg)) } + +#[derive(Debug)] +pub enum Error { + Config(String), + UrlEncoding(urlencoding::FromUrlEncodingError), + UrlParse(url::ParseError), +} + +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> std::result::Result<(), fmt::Error> { + write!( + f, + "{}", + match self { + Self::Config(e) => e.to_string(), + Self::UrlEncoding(e) => format!("could not parse POSTGRES_URL.\n{:7}{:?}", "", e), + Self::UrlParse(e) => format!("could parse Postgres URL.\n{:7}{}", "", e), + } + ) + } +} + +impl Error { + pub fn config(var: T, value: T, allowed_vals: T) -> Self { + Self::Config(format!( + "{0} is set to `{1}`, which is invalid.\n{3:7}{0} must be {2}.", + var, value, allowed_vals, "" + )) + } +} + +impl From for Error { + fn from(e: urlencoding::FromUrlEncodingError) -> Self { + Self::UrlEncoding(e) + } +} + +impl From for Error { + fn from(e: url::ParseError) -> Self { + Self::UrlParse(e) + } +} diff --git a/src/config/deployment_cfg.rs b/src/config/deployment_cfg.rs index 379691e..db0c502 100644 --- a/src/config/deployment_cfg.rs +++ b/src/config/deployment_cfg.rs @@ -1,5 +1,5 @@ -use super::{deployment_cfg_types::*, EnvVar}; -use crate::err::FatalErr; +use super::deployment_cfg_types::*; +use super::{EnvVar, Error}; #[derive(Debug, Default)] pub struct Deployment<'a> { @@ -13,7 +13,7 @@ pub struct Deployment<'a> { } impl Deployment<'_> { - pub(crate) fn from_env(env: &EnvVar) -> Result { + pub(crate) fn from_env(env: &EnvVar) -> Result { let mut cfg = Self { env: Env::default().maybe_update(env.get("NODE_ENV"))?, log_level: LogLevel::default().maybe_update(env.get("RUST_LOG"))?, diff --git a/src/config/environmental_variables.rs b/src/config/environmental_variables.rs index a1e08cf..fbac507 100644 --- a/src/config/environmental_variables.rs +++ b/src/config/environmental_variables.rs @@ -61,6 +61,7 @@ impl fmt::Display for EnvVar { } } #[macro_export] +#[doc(hidden)] macro_rules! maybe_update { ($name:ident; $item: tt:$type:ty) => ( pub(crate) fn $name(self, item: Option<$type>) -> Self { @@ -76,7 +77,9 @@ macro_rules! maybe_update { None => Self { ..self } } })} + #[macro_export] +#[doc(hidden)] macro_rules! from_env_var { ($(#[$outer:meta])* let name = $name:ident; @@ -106,15 +109,14 @@ macro_rules! from_env_var { fn inner_from_str($arg: &str) -> Option<$type> { $body } - pub(crate) fn maybe_update( - self, - var: Option<&String>, - ) -> Result { + pub(crate) fn maybe_update(self, var: Option<&String>) -> Result { Ok(match var { Some(empty_string) if empty_string.is_empty() => Self::default(), - Some(value) => Self(Self::inner_from_str(value).ok_or_else(|| { - crate::err::FatalErr::config($env_var, value, $allowed_values) - })?), + Some(value) => { + Self(Self::inner_from_str(value).ok_or_else(|| { + super::Error::config($env_var, value, $allowed_values) + })?) + } None => self, }) } diff --git a/src/config/postgres_cfg.rs b/src/config/postgres_cfg.rs index 2e62ff2..2422f7d 100644 --- a/src/config/postgres_cfg.rs +++ b/src/config/postgres_cfg.rs @@ -1,17 +1,19 @@ -use super::{postgres_cfg_types::*, EnvVar}; -use crate::err::FatalErr; +use super::postgres_cfg_types::*; +use super::{EnvVar, Error}; use url::Url; use urlencoding; -type Result = std::result::Result; +type Result = std::result::Result; +/// Configuration values for Postgres #[derive(Debug, Clone)] pub struct Postgres { pub(crate) user: PgUser, pub(crate) host: PgHost, pub(crate) password: PgPass, - pub(crate) database: PgDatabase, + /// The name of the postgres database to connect to + pub database: PgDatabase, pub(crate) port: PgPort, pub(crate) ssl_mode: PgSslMode, } @@ -27,7 +29,7 @@ impl EnvVar { "password" => self.maybe_add_env_var("DB_PASS", Some(v.to_string())), "host" => self.maybe_add_env_var("DB_HOST", Some(v.to_string())), "sslmode" => self.maybe_add_env_var("DB_SSLMODE", Some(v.to_string())), - _ => Err(FatalErr::config( + _ => Err(Error::config( "POSTGRES_URL", &k, "a URL with parameters `password`, `user`, `host`, and `sslmode` only", diff --git a/src/config/redis_cfg.rs b/src/config/redis_cfg.rs index 5f85362..2cc10f7 100644 --- a/src/config/redis_cfg.rs +++ b/src/config/redis_cfg.rs @@ -1,10 +1,9 @@ use super::redis_cfg_types::*; -use super::EnvVar; -use crate::err::FatalErr; +use super::{EnvVar, Error}; use url::Url; -type Result = std::result::Result; +type Result = std::result::Result; #[derive(Debug, Default)] pub struct Redis { @@ -33,7 +32,7 @@ impl EnvVar { match k.to_string().as_str() { "password" => self.maybe_add_env_var("REDIS_PASSWORD", Some(v.to_string())), "db" => self.maybe_add_env_var("REDIS_DB", Some(v.to_string())), - _ => Err(FatalErr::config( + _ => Err(Error::config( "REDIS_URL", &k, "a URL with parameters `password`, `db`, only", diff --git a/src/config/redis_cfg_types.rs b/src/config/redis_cfg_types.rs index 94e7d1e..6e07063 100644 --- a/src/config/redis_cfg_types.rs +++ b/src/config/redis_cfg_types.rs @@ -1,4 +1,4 @@ -use crate::from_env_var; +use crate::from_env_var; //macro use std::time::Duration; //use std::{fmt, net::IpAddr, os::unix::net::UnixListener, str::FromStr, time::Duration}; //use strum_macros::{EnumString, EnumVariantNames}; diff --git a/src/err.rs b/src/err.rs index 557a2bf..bb4cf61 100644 --- a/src/err.rs +++ b/src/err.rs @@ -1,86 +1,80 @@ -use crate::request::RequestErr; -use crate::response::ManagerErr; +use crate::config; +use crate::request; +use crate::response; + use std::fmt; -pub enum FatalErr { - ReceiverErr(ManagerErr), +pub enum Error { + Response(response::Error), Logger(log::SetLoggerError), - Postgres(RequestErr), + Postgres(request::Error), Unrecoverable, StdIo(std::io::Error), - // config errs - UrlParse(url::ParseError), - UrlEncoding(urlencoding::FromUrlEncodingError), - ConfigErr(String), + Config(config::Error), } -impl FatalErr { +impl Error { pub fn log(msg: impl fmt::Display) { eprintln!("{}", msg); } - - pub fn config(var: T, value: T, allowed_vals: T) -> Self { - Self::ConfigErr(format!( - "{0} is set to `{1}`, which is invalid.\n{3:7}{0} must be {2}.", - var, value, allowed_vals, "" - )) - } } -impl std::error::Error for FatalErr {} -impl fmt::Debug for FatalErr { +impl std::error::Error for Error {} + +impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { write!(f, "{}", self) } } -impl fmt::Display for FatalErr { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - use FatalErr::*; + use Error::*; write!( f, "{}", match self { - ReceiverErr(e) => format!("{}", e), + Response(e) => format!("{}", e), Logger(e) => format!("{}", e), StdIo(e) => format!("{}", e), Postgres(e) => format!("could not connect to Postgres.\n{:7}{}", "", e), - ConfigErr(e) => e.to_string(), - UrlParse(e) => format!("could parse Postgres URL.\n{:7}{}", "", e), - UrlEncoding(e) => format!("could not parse POSTGRES_URL.\n{:7}{:?}", "", e), + Config(e) => format!("{}", e), Unrecoverable => "Flodgatt will now shut down.".into(), } ) } } -impl From for FatalErr { - fn from(e: RequestErr) -> Self { +#[doc(hidden)] +impl From for Error { + fn from(e: request::Error) -> Self { Self::Postgres(e) } } -impl From for FatalErr { - fn from(e: ManagerErr) -> Self { - Self::ReceiverErr(e) +#[doc(hidden)] +impl From for Error { + fn from(e: response::Error) -> Self { + Self::Response(e) } } -impl From for FatalErr { - fn from(e: urlencoding::FromUrlEncodingError) -> Self { - Self::UrlEncoding(e) + +#[doc(hidden)] +impl From for Error { + fn from(e: config::Error) -> Self { + Self::Config(e) } } -impl From for FatalErr { - fn from(e: url::ParseError) -> Self { - Self::UrlParse(e) - } -} -impl From for FatalErr { + +#[doc(hidden)] +impl From for Error { fn from(e: std::io::Error) -> Self { Self::StdIo(e) } } -impl From for FatalErr { + +#[doc(hidden)] +impl From for Error { fn from(e: log::SetLoggerError) -> Self { Self::Logger(e) } diff --git a/src/event/checked_event/status.rs b/src/event/checked_event/status.rs deleted file mode 100644 index 0976106..0000000 --- a/src/event/checked_event/status.rs +++ /dev/null @@ -1,134 +0,0 @@ -mod application; -mod attachment; -mod card; -mod poll; - -use super::account::Account; -use super::emoji::Emoji; -use super::id::Id; -use super::mention::Mention; -use super::tag::Tag; -use super::visibility::Visibility; -use application::Application; -use attachment::Attachment; -use card::Card; -use poll::Poll; - -use crate::request::Blocks; - -use hashbrown::HashSet; -use serde::{Deserialize, Serialize}; -use std::boxed::Box; -use std::string::String; - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Status { - id: Id, - uri: String, - created_at: String, - account: Account, - content: String, - visibility: Visibility, - sensitive: bool, - spoiler_text: String, - media_attachments: Vec, - application: Option, // Should be non-optional? - mentions: Vec, - tags: Vec, - emojis: Vec, - reblogs_count: i64, - favourites_count: i64, - replies_count: i64, - url: Option, - in_reply_to_id: Option, - in_reply_to_account_id: Option, - reblog: Option>, - poll: Option, - card: Option, - language: Option, - - text: Option, - // ↓↓↓ Only for authorized users - favourited: Option, - reblogged: Option, - muted: Option, - bookmarked: Option, - pinned: Option, -} - -impl Status { - /// Returns `true` if the status is filtered out based on its language - pub(crate) fn language_not(&self, allowed_langs: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - - let reject_and_maybe_log = |toot_language| { - log::info!("Filtering out toot from `{}`", &self.account.acct); - log::info!("Toot language: `{}`", toot_language); - log::info!("Recipient's allowed languages: `{:?}`", allowed_langs); - REJECT - }; - if allowed_langs.is_empty() { - return ALLOW; // listing no allowed_langs results in allowing all languages - } - - match self.language.as_ref() { - Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW, - None => ALLOW, // If toot language is unknown, toot is always allowed - Some(empty) if empty == &String::new() => ALLOW, - Some(toot_language) => reject_and_maybe_log(toot_language), - } - } - - /// Returns `true` if the Status originated from a blocked domain, is from an account - /// that has blocked the current user, or if the User's list of blocked/muted users - /// includes a user involved in the Status. - /// - /// A user is involved in the Status/toot if they: - /// * Are mentioned in this toot - /// * Wrote this toot - /// * Wrote a toot that this toot is replying to (if any) - /// * Wrote the toot that this toot is boosting (if any) - pub(crate) fn involves_any(&self, blocks: &Blocks) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - let Blocks { - blocked_users, - blocking_users, - blocked_domains, - } = blocks; - let user_id = &Id(self.account.id.0); - - if blocking_users.contains(user_id) || self.involves(blocked_users) { - REJECT - } else { - let full_username = &self.account.acct; - match full_username.split('@').nth(1) { - Some(originating_domain) if blocked_domains.contains(originating_domain) => REJECT, - Some(_) | None => ALLOW, // None means the local instance, which can't be blocked - } - } - } - - fn involves(&self, blocked_users: &HashSet) -> bool { - // involved_users = mentioned_users + author + replied-to user + boosted user - let mut involved_users: HashSet = self - .mentions - .iter() - .map(|mention| Id(mention.id.0)) - .collect(); - - // author - involved_users.insert(Id(self.account.id.0)); - // replied-to user - if let Some(user_id) = self.in_reply_to_account_id { - involved_users.insert(Id(user_id.0)); - } - // boosted user - if let Some(boosted_status) = self.reblog.clone() { - involved_users.insert(Id(boosted_status.account.id.0)); - } - !involved_users.is_disjoint(blocked_users) - } -} diff --git a/src/lib.rs b/src/lib.rs index f0731ca..3eec3e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,12 +35,22 @@ //! polls the `Receiver` and the frequency with which the `Receiver` polls Redis. //! -//#![warn(clippy::pedantic)] +#![warn(clippy::pedantic)] #![allow(clippy::try_err, clippy::match_bool)] -//#![allow(clippy::large_enum_variant)] +#![allow(clippy::large_enum_variant)] + +pub use err::Error; pub mod config; -pub mod err; -pub mod event; +mod err; pub mod request; pub mod response; + +/// A user ID. +/// +/// Internally, Mastodon IDs are i64s, but are sent to clients as string because +/// JavaScript numbers don't support i64s. This newtype serializes to/from a string, but +/// keeps the i64 as the "true" value for internal use. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[doc(hidden)] +pub struct Id(pub i64); diff --git a/src/main.rs b/src/main.rs index f1f3e3c..19d47dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,7 @@ use flodgatt::config; -use flodgatt::err::FatalErr; -use flodgatt::event::Event; use flodgatt::request::{Handler, Subscription, Timeline}; -use flodgatt::response::redis; -use flodgatt::response::stream; +use flodgatt::response::{Event, RedisManager, SseStream, WsStream}; +use flodgatt::Error; use futures::{future::lazy, stream::Stream as _}; use std::fs; @@ -16,7 +14,7 @@ use tokio::timer::Interval; use warp::ws::Ws2; use warp::Filter; -fn main() -> Result<(), FatalErr> { +fn main() -> Result<(), Error> { config::merge_dotenv()?; pretty_env_logger::try_init()?; let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect())?; @@ -27,7 +25,7 @@ fn main() -> Result<(), FatalErr> { let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); let request = Handler::new(&postgres_cfg, *cfg.whitelist_mode)?; - let shared_manager = redis::Manager::try_from(&redis_cfg, event_tx, cmd_rx)?.into_arc(); + let shared_manager = RedisManager::try_from(&redis_cfg, event_tx, cmd_rx)?.into_arc(); // Server Sent Events let sse_manager = shared_manager.clone(); @@ -37,10 +35,10 @@ fn main() -> Result<(), FatalErr> { .and(warp::sse()) .map(move |subscription: Subscription, sse: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); - let mut manager = sse_manager.lock().unwrap_or_else(redis::Manager::recover); + let mut manager = sse_manager.lock().unwrap_or_else(RedisManager::recover); manager.subscribe(&subscription); - stream::Sse::send_events(sse, sse_cmd_tx.clone(), subscription, sse_rx.clone()) + SseStream::send_events(sse, sse_cmd_tx.clone(), subscription, sse_rx.clone()) }) .with(warp::reply::with::header("Connection", "keep-alive")); @@ -51,10 +49,10 @@ fn main() -> Result<(), FatalErr> { .and(warp::ws::ws2()) .map(move |subscription: Subscription, ws: Ws2| { log::info!("Incoming websocket request for {:?}", subscription.timeline); - let mut manager = ws_manager.lock().unwrap_or_else(redis::Manager::recover); + let mut manager = ws_manager.lock().unwrap_or_else(RedisManager::recover); manager.subscribe(&subscription); let token = subscription.access_token.clone().unwrap_or_default(); // token sent for security - let ws_stream = stream::Ws::new(cmd_tx.clone(), event_rx.clone(), subscription); + let ws_stream = WsStream::new(cmd_tx.clone(), event_rx.clone(), subscription); (ws.on_upgrade(move |ws| ws_stream.send_to(ws)), token) }) @@ -66,9 +64,9 @@ fn main() -> Result<(), FatalErr> { let (r1, r3) = (shared_manager.clone(), shared_manager.clone()); request.health().map(|| "OK") .or(request.status() - .map(move || r1.lock().unwrap_or_else(redis::Manager::recover).count())) + .map(move || r1.lock().unwrap_or_else(RedisManager::recover).count())) .or(request.status_per_timeline() - .map(move || r3.lock().unwrap_or_else(redis::Manager::recover).list())) + .map(move || r3.lock().unwrap_or_else(RedisManager::recover).list())) }; #[cfg(not(feature = "stub_status"))] let status = request.health().map(|| "OK"); @@ -78,22 +76,14 @@ fn main() -> Result<(), FatalErr> { .allow_methods(cfg.cors.allowed_methods) .allow_headers(cfg.cors.allowed_headers); - // use futures::future::Future; let streaming_server = move || { let manager = shared_manager.clone(); let stream = Interval::new(Instant::now(), poll_freq) - // .take(1200) .map_err(|e| log::error!("{}", e)) - .for_each( - move |_| { - let mut manager = manager.lock().unwrap_or_else(redis::Manager::recover); - manager.poll_broadcast().map_err(FatalErr::log) - }, // ).and_then(|_| { - // log::info!("shutting down!"); - // std::process::exit(0); - // futures::future::ok(()) - // } - ); + .for_each(move |_| { + let mut manager = manager.lock().unwrap_or_else(RedisManager::recover); + manager.poll_broadcast().map_err(Error::log) + }); warp::spawn(lazy(move || stream)); warp::serve(ws.or(sse).with(cors).or(status).recover(Handler::err)) @@ -109,5 +99,5 @@ fn main() -> Result<(), FatalErr> { let server_addr = SocketAddr::new(*cfg.address, *cfg.port); tokio::run(lazy(move || streaming_server().bind(server_addr))); } - Err(FatalErr::Unrecoverable) // on get here if there's an unrecoverable error in poll_broadcast. + Err(Error::Unrecoverable) // only get here if there's an unrecoverable error in poll_broadcast. } diff --git a/src/request.rs b/src/request.rs index e265d77..fdae7db 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,21 +1,19 @@ //! Parse the client request and return a Subscription mod postgres; mod query; -pub mod timeline; +mod timeline; mod err; mod subscription; -pub(crate) use self::err::RequestErr; -pub(crate) use self::postgres::PgPool; - -pub(crate) use subscription::Blocks; -pub use subscription::Subscription; +pub use err::{Error, Timeline as TimelineErr}; +pub use subscription::{Blocks, Subscription}; pub use timeline::Timeline; -pub(crate) use timeline::{Content, Reach, Stream, TimelineErr}; +use timeline::{Content, Reach, Stream}; +pub use self::postgres::PgPool; use self::query::Query; -use crate::config; +use crate::config::Postgres; use warp::filters::BoxedFilter; use warp::http::StatusCode; use warp::path; @@ -26,7 +24,7 @@ mod sse_test; #[cfg(test)] mod ws_test; -type Result = std::result::Result; +type Result = std::result::Result; /// Helper macro to match on the first of any of the provided filters macro_rules! any_of { @@ -62,7 +60,7 @@ pub struct Handler { } impl Handler { - pub fn new(postgres_cfg: &config::Postgres, whitelist_mode: bool) -> Result { + pub fn new(postgres_cfg: &Postgres, whitelist_mode: bool) -> Result { Ok(Self { pg_conn: PgPool::new(postgres_cfg, whitelist_mode)?, }) diff --git a/src/request/err.rs b/src/request/err.rs index 0637e30..9050af9 100644 --- a/src/request/err.rs +++ b/src/request/err.rs @@ -1,15 +1,15 @@ use std::fmt; #[derive(Debug)] -pub enum RequestErr { +pub enum Error { PgPool(r2d2::Error), Pg(postgres::Error), } -impl std::error::Error for RequestErr {} +impl std::error::Error for Error {} -impl fmt::Display for RequestErr { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - use RequestErr::*; + use Error::*; let msg = match self { PgPool(e) => format!("{}", e), Pg(e) => format!("{}", e), @@ -18,13 +18,40 @@ impl fmt::Display for RequestErr { } } -impl From for RequestErr { +impl From for Error { fn from(e: r2d2::Error) -> Self { Self::PgPool(e) } } -impl From for RequestErr { +impl From for Error { fn from(e: postgres::Error) -> Self { Self::Pg(e) } } +// TODO make Timeline & TimelineErr their own top-level module +#[derive(Debug)] +pub enum Timeline { + MissingHashtag, + InvalidInput, + BadTag, +} + +impl std::error::Error for Timeline {} + +impl From for Timeline { + fn from(_error: std::num::ParseIntError) -> Self { + Self::InvalidInput + } +} + +impl fmt::Display for Timeline { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + use Timeline::*; + let msg = match self { + InvalidInput => "The timeline text from Redis could not be parsed into a supported timeline. TODO: add incoming timeline text", + MissingHashtag => "Attempted to send a hashtag timeline without supplying a tag name", + BadTag => "No hashtag exists with the specified hashtag ID" + }; + write!(f, "{}", msg) + } +} diff --git a/src/request/postgres.rs b/src/request/postgres.rs index a5ea50c..b6c621b 100644 --- a/src/request/postgres.rs +++ b/src/request/postgres.rs @@ -2,7 +2,7 @@ use super::err; use super::timeline::{Scope, UserData}; use crate::config; -use crate::event::Id; +use crate::Id; use ::postgres; use hashbrown::HashSet; @@ -15,7 +15,7 @@ pub struct PgPool { whitelist_mode: bool, } -type Result = std::result::Result; +type Result = std::result::Result; type Rejectable = std::result::Result; impl PgPool { diff --git a/src/request/sse_test.rs b/src/request/sse_test.rs index 287cd9b..1a585f1 100644 --- a/src/request/sse_test.rs +++ b/src/request/sse_test.rs @@ -3,7 +3,6 @@ // #[cfg(test)] // mod test { // use super::*; -// use crate::parse_client_request::user::{Blocks, Filter, OauthScope, PgPool}; // macro_rules! test_public_endpoint { // ($name:ident { diff --git a/src/request/subscription.rs b/src/request/subscription.rs index 8b105dc..3ee773e 100644 --- a/src/request/subscription.rs +++ b/src/request/subscription.rs @@ -8,7 +8,7 @@ use super::postgres::PgPool; use super::query::Query; use super::{Content, Reach, Stream, Timeline}; -use crate::event::Id; +use crate::Id; use hashbrown::HashSet; @@ -17,17 +17,19 @@ use warp::reject::Rejection; #[derive(Clone, Debug, PartialEq)] pub struct Subscription { pub timeline: Timeline, - pub(crate) allowed_langs: HashSet, - pub(crate) blocks: Blocks, - pub(crate) hashtag_name: Option, + pub allowed_langs: HashSet, + /// [Blocks](./request/struct.Blocks.html) + pub blocks: Blocks, + pub hashtag_name: Option, pub access_token: Option, } +/// Blocked and muted users and domains #[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct Blocks { - pub(crate) blocked_domains: HashSet, - pub(crate) blocked_users: HashSet, - pub(crate) blocking_users: HashSet, +pub struct Blocks { + pub blocked_domains: HashSet, + pub blocked_users: HashSet, + pub blocking_users: HashSet, } impl Default for Subscription { diff --git a/src/request/timeline.rs b/src/request/timeline.rs index b8d221a..154a839 100644 --- a/src/request/timeline.rs +++ b/src/request/timeline.rs @@ -1,5 +1,5 @@ -pub(crate) use self::err::TimelineErr; pub(crate) use self::inner::{Content, Reach, Scope, Stream, UserData}; +use super::err::Timeline as Error; use super::query::Query; use lru::LruCache; @@ -8,7 +8,7 @@ use warp::reject::Rejection; mod err; mod inner; -type Result = std::result::Result; +type Result = std::result::Result; #[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] pub struct Timeline(pub(crate) Stream, pub(crate) Reach, pub(crate) Content); @@ -18,9 +18,25 @@ impl Timeline { Self(Stream::Unset, Reach::Local, Content::Notification) } + pub(crate) fn is_public(&self) -> bool { + if let Self(Stream::Public, _, _) = self { + true + } else { + false + } + } + + pub(crate) fn tag(&self) -> Option { + if let Self(Stream::Hashtag(id), _, _) = self { + Some(*id) + } else { + None + } + } + pub(crate) fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result { // TODO -- does this need to account for namespaces? - use {Content::*, Reach::*, Stream::*, TimelineErr::*}; + use {Content::*, Error::*, Reach::*, Stream::*}; Ok(match self { Timeline(Public, Federated, All) => "timeline:public".to_string(), @@ -42,7 +58,7 @@ impl Timeline { } Timeline(List(id), Federated, All) => ["timeline:list:", &id.to_string()].concat(), Timeline(Direct(id), Federated, All) => ["timeline:direct:", &id.to_string()].concat(), - Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?, + Timeline(_one, _two, _three) => Err(Error::InvalidInput)?, }) } @@ -50,7 +66,7 @@ impl Timeline { timeline: &str, cache: &mut LruCache, ) -> Result { - use {Content::*, Reach::*, Stream::*, TimelineErr::*}; + use {Content::*, Error::*, Reach::*, Stream::*}; let mut tag_id = |t: &str| cache.get(&t.to_string()).map_or(Err(BadTag), |id| Ok(*id)); Ok(match &timeline.split(':').collect::>()[..] { diff --git a/src/request/timeline/err.rs b/src/request/timeline/err.rs index 5dbf660..8b13789 100644 --- a/src/request/timeline/err.rs +++ b/src/request/timeline/err.rs @@ -1,28 +1 @@ -use std::fmt; -#[derive(Debug)] -pub enum TimelineErr { - MissingHashtag, - InvalidInput, - BadTag, -} - -impl std::error::Error for TimelineErr {} - -impl From for TimelineErr { - fn from(_error: std::num::ParseIntError) -> Self { - Self::InvalidInput - } -} - -impl fmt::Display for TimelineErr { - fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - use TimelineErr::*; - let msg = match self { - InvalidInput => "The timeline text from Redis could not be parsed into a supported timeline. TODO: add incoming timeline text", - MissingHashtag => "Attempted to send a hashtag timeline without supplying a tag name", - BadTag => "No hashtag exists with the specified hashtag ID" - }; - write!(f, "{}", msg) - } -} diff --git a/src/request/timeline/inner.rs b/src/request/timeline/inner.rs index fd3a884..f77fa4b 100644 --- a/src/request/timeline/inner.rs +++ b/src/request/timeline/inner.rs @@ -1,5 +1,5 @@ -use super::TimelineErr; -use crate::event::Id; +use super::Error; +use crate::Id; use hashbrown::HashSet; use std::convert::TryFrom; @@ -36,18 +36,18 @@ pub(crate) enum Scope { } impl TryFrom<&str> for Scope { - type Error = TimelineErr; + type Error = Error; - fn try_from(s: &str) -> Result { + fn try_from(s: &str) -> Result { match s { "read" => Ok(Scope::Read), "read:statuses" => Ok(Scope::Statuses), "read:notifications" => Ok(Scope::Notifications), "read:lists" => Ok(Scope::Lists), - "write" | "follow" => Err(TimelineErr::InvalidInput), // ignore write scopes + "write" | "follow" => Err(Error::InvalidInput), // ignore write scopes unexpected => { log::warn!("Ignoring unknown scope `{}`", unexpected); - Err(TimelineErr::InvalidInput) + Err(Error::InvalidInput) } } } diff --git a/src/request/ws_test.rs b/src/request/ws_test.rs index e101bed..8b547a4 100644 --- a/src/request/ws_test.rs +++ b/src/request/ws_test.rs @@ -1,14 +1,8 @@ //! Filters for the WebSocket endpoint - - - - - // #[cfg(test)] // mod test { // use super::*; -// use crate::parse_client_request::user::{Blocks, Filter, OauthScope}; // macro_rules! test_public_endpoint { // ($name:ident { diff --git a/src/response.rs b/src/response.rs index b247dfd..be209c3 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,9 +1,17 @@ //! Stream the updates appropriate for a given `User`/`timeline` pair from Redis. -pub mod redis; -pub mod stream; +pub use event::Event; +pub use redis::Manager as RedisManager; +pub use stream::{Sse as SseStream, Ws as WsStream}; -pub(crate) use redis::ManagerErr; +pub(self) use event::err::Event as EventErr; +pub(self) use event::Payload; + +pub(crate) mod event; +mod redis; +mod stream; + +pub use redis::Error; #[cfg(feature = "bench")] pub use redis::msg::{RedisMsg, RedisParseOutput}; diff --git a/src/event.rs b/src/response/event.rs similarity index 82% rename from src/event.rs rename to src/response/event.rs index a721370..4227124 100644 --- a/src/event.rs +++ b/src/response/event.rs @@ -1,11 +1,12 @@ mod checked_event; mod dynamic_event; -mod err; +pub mod err; -pub(crate) use checked_event::{CheckedEvent, Id}; -pub(crate) use dynamic_event::{DynEvent, EventKind}; -pub(crate) use err::EventErr; +use self::checked_event::CheckedEvent; +use self::dynamic_event::{DynEvent, EventKind}; +use crate::Id; +use hashbrown::HashSet; use serde::Serialize; use std::convert::TryFrom; use std::string::String; @@ -18,6 +19,18 @@ pub enum Event { Ping, } +pub(crate) trait Payload { + fn language_unset(&self) -> bool; + + fn language(&self) -> String; + + fn involved_users(&self) -> HashSet; + + fn author(&self) -> &Id; + + fn sent_from(&self) -> &str; +} + impl Event { pub(crate) fn to_json_string(&self) -> String { if let Event::Ping = self { @@ -43,6 +56,26 @@ impl Event { } } + pub(crate) fn update_payload(&self) -> Option<&checked_event::Status> { + if let Self::TypeSafe(CheckedEvent::Update { payload, .. }) = self { + Some(&payload) + } else { + None + } + } + + pub(crate) fn dyn_update_payload(&self) -> Option<&dynamic_event::DynStatus> { + if let Self::Dynamic(DynEvent { + kind: EventKind::Update(s), + .. + }) = self + { + Some(&s) + } else { + None + } + } + fn event_name(&self) -> String { String::from(match self { Self::TypeSafe(checked) => match checked { @@ -84,14 +117,14 @@ impl Event { } impl TryFrom for Event { - type Error = EventErr; + type Error = err::Event; fn try_from(event_txt: String) -> Result { Event::try_from(event_txt.as_str()) } } impl TryFrom<&str> for Event { - type Error = EventErr; + type Error = err::Event; fn try_from(event_txt: &str) -> Result { match serde_json::from_str(event_txt) { diff --git a/src/event/checked_event.rs b/src/response/event/checked_event.rs similarity index 89% rename from src/event/checked_event.rs rename to src/response/event/checked_event.rs index 5104de9..b763c8b 100644 --- a/src/event/checked_event.rs +++ b/src/response/event/checked_event.rs @@ -11,13 +11,13 @@ mod status; mod tag; mod visibility; -use announcement::Announcement; -pub(in crate::event) use announcement_reaction::AnnouncementReaction; -use conversation::Conversation; -pub(crate) use id::Id; -use notification::Notification; -use status::Status; +pub(self) use super::Payload; +pub(super) use announcement_reaction::AnnouncementReaction; +pub(crate) use status::Status; +use announcement::Announcement; +use conversation::Conversation; +use notification::Notification; use serde::Deserialize; #[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)] diff --git a/src/event/checked_event/account.rs b/src/response/event/checked_event/account.rs similarity index 94% rename from src/event/checked_event/account.rs rename to src/response/event/checked_event/account.rs index 0322cc3..0c00897 100644 --- a/src/event/checked_event/account.rs +++ b/src/response/event/checked_event/account.rs @@ -1,4 +1,5 @@ -use super::{emoji::Emoji, id::Id, visibility::Visibility}; +use super::{emoji::Emoji, visibility::Visibility}; +use crate::Id; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] diff --git a/src/event/checked_event/announcement.rs b/src/response/event/checked_event/announcement.rs similarity index 100% rename from src/event/checked_event/announcement.rs rename to src/response/event/checked_event/announcement.rs diff --git a/src/event/checked_event/announcement_reaction.rs b/src/response/event/checked_event/announcement_reaction.rs similarity index 100% rename from src/event/checked_event/announcement_reaction.rs rename to src/response/event/checked_event/announcement_reaction.rs diff --git a/src/event/checked_event/conversation.rs b/src/response/event/checked_event/conversation.rs similarity index 100% rename from src/event/checked_event/conversation.rs rename to src/response/event/checked_event/conversation.rs diff --git a/src/event/checked_event/emoji.rs b/src/response/event/checked_event/emoji.rs similarity index 100% rename from src/event/checked_event/emoji.rs rename to src/response/event/checked_event/emoji.rs diff --git a/src/event/checked_event/id.rs b/src/response/event/checked_event/id.rs similarity index 80% rename from src/event/checked_event/id.rs rename to src/response/event/checked_event/id.rs index 0226e5d..73f7a0c 100644 --- a/src/event/checked_event/id.rs +++ b/src/response/event/checked_event/id.rs @@ -1,4 +1,5 @@ -use super::super::EventErr; +use super::super::err; +use crate::Id; use serde::{ de::{self, Visitor}, @@ -7,19 +8,11 @@ use serde::{ use serde_json::Value; use std::{convert::TryFrom, fmt, num::ParseIntError, str::FromStr}; -/// A user ID. -/// -/// Internally, Mastodon IDs are i64s, but are sent to clients as string because -/// JavaScript numbers don't support i64s. This newtype serializes to/from a string, but -/// keeps the i64 as the "true" value for internal use. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub struct Id(pub i64); - impl TryFrom<&Value> for Id { - type Error = EventErr; + type Error = err::Event; fn try_from(v: &Value) -> Result { - Ok(v.as_str().ok_or(EventErr::DynParse)?.parse()?) + Ok(v.as_str().ok_or(err::Event::DynParse)?.parse()?) } } diff --git a/src/event/checked_event/mention.rs b/src/response/event/checked_event/mention.rs similarity index 92% rename from src/event/checked_event/mention.rs rename to src/response/event/checked_event/mention.rs index e5e95f8..14c47ad 100644 --- a/src/event/checked_event/mention.rs +++ b/src/response/event/checked_event/mention.rs @@ -1,4 +1,4 @@ -use super::id::Id; +use crate::Id; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] diff --git a/src/event/checked_event/notification.rs b/src/response/event/checked_event/notification.rs similarity index 100% rename from src/event/checked_event/notification.rs rename to src/response/event/checked_event/notification.rs diff --git a/src/response/event/checked_event/status.rs b/src/response/event/checked_event/status.rs new file mode 100644 index 0000000..684cf67 --- /dev/null +++ b/src/response/event/checked_event/status.rs @@ -0,0 +1,103 @@ +mod application; +mod attachment; +mod card; +mod poll; + +use super::account::Account; +use super::emoji::Emoji; +use super::mention::Mention; +use super::tag::Tag; +use super::visibility::Visibility; +use super::Payload; +use crate::Id; +use application::Application; +use attachment::Attachment; +use card::Card; +use hashbrown::HashSet; +use poll::Poll; +use serde::{Deserialize, Serialize}; +use std::boxed::Box; +use std::string::String; + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Status { + id: Id, + uri: String, + created_at: String, + account: Account, + content: String, + visibility: Visibility, + sensitive: bool, + spoiler_text: String, + media_attachments: Vec, + application: Option, // Should be non-optional? + mentions: Vec, + tags: Vec, + emojis: Vec, + reblogs_count: i64, + favourites_count: i64, + replies_count: i64, + url: Option, + in_reply_to_id: Option, + in_reply_to_account_id: Option, + reblog: Option>, + poll: Option, + card: Option, + pub(crate) language: Option, + + text: Option, + // ↓↓↓ Only for authorized users + favourited: Option, + reblogged: Option, + muted: Option, + bookmarked: Option, + pinned: Option, +} + +impl Payload for Status { + fn language_unset(&self) -> bool { + match &self.language { + None => true, + Some(empty) if empty == &String::new() => true, + Some(_language) => false, + } + } + + fn language(&self) -> String { + self.language.clone().unwrap_or_default() + } + + /// Returns all users involved in the `Status`. + /// + /// A user is involved in the Status/toot if they: + /// * Are mentioned in this toot + /// * Wrote this toot + /// * Wrote a toot that this toot is replying to (if any) + /// * Wrote the toot that this toot is boosting (if any) + fn involved_users(&self) -> HashSet { + // involved_users = mentioned_users + author + replied-to user + boosted user + let mut involved_users: HashSet = self.mentions.iter().map(|m| Id(m.id.0)).collect(); + + // author + involved_users.insert(Id(self.account.id.0)); + // replied-to user + if let Some(user_id) = self.in_reply_to_account_id { + involved_users.insert(Id(user_id.0)); + } + // boosted user + if let Some(boosted_status) = self.reblog.clone() { + involved_users.insert(Id(boosted_status.account.id.0)); + } + involved_users + } + + fn author(&self) -> &Id { + &self.account.id + } + + fn sent_from(&self) -> &str { + let sender_username = &self.account.acct; + sender_username.split('@').nth(1).unwrap_or_default() // default occurs when sent from local instance + } +} diff --git a/src/event/checked_event/status/application.rs b/src/response/event/checked_event/status/application.rs similarity index 100% rename from src/event/checked_event/status/application.rs rename to src/response/event/checked_event/status/application.rs diff --git a/src/event/checked_event/status/attachment.rs b/src/response/event/checked_event/status/attachment.rs similarity index 100% rename from src/event/checked_event/status/attachment.rs rename to src/response/event/checked_event/status/attachment.rs diff --git a/src/event/checked_event/status/card.rs b/src/response/event/checked_event/status/card.rs similarity index 100% rename from src/event/checked_event/status/card.rs rename to src/response/event/checked_event/status/card.rs diff --git a/src/event/checked_event/status/poll.rs b/src/response/event/checked_event/status/poll.rs similarity index 100% rename from src/event/checked_event/status/poll.rs rename to src/response/event/checked_event/status/poll.rs diff --git a/src/event/checked_event/tag.rs b/src/response/event/checked_event/tag.rs similarity index 100% rename from src/event/checked_event/tag.rs rename to src/response/event/checked_event/tag.rs diff --git a/src/event/checked_event/visibility.rs b/src/response/event/checked_event/visibility.rs similarity index 100% rename from src/event/checked_event/visibility.rs rename to src/response/event/checked_event/visibility.rs diff --git a/src/event/dynamic_event.rs b/src/response/event/dynamic_event.rs similarity index 52% rename from src/event/dynamic_event.rs rename to src/response/event/dynamic_event.rs index 0ac80f2..9ab9db4 100644 --- a/src/event/dynamic_event.rs +++ b/src/response/event/dynamic_event.rs @@ -1,5 +1,6 @@ -use super::{EventErr, Id}; -use crate::request::Blocks; +use super::err; +use super::Payload; +use crate::Id; use std::convert::TryFrom; @@ -38,7 +39,7 @@ pub(crate) struct DynStatus { pub(crate) boosted_user: Option, } -type Result = std::result::Result; +type Result = std::result::Result; impl DynEvent { pub(crate) fn set_update(self) -> Result { @@ -50,16 +51,13 @@ impl DynEvent { } } } - impl DynStatus { pub(crate) fn new(payload: &Value) -> Result { - use EventErr::*; - Ok(Self { id: Id::try_from(&payload["account"]["id"])?, username: payload["account"]["acct"] .as_str() - .ok_or(DynParse)? + .ok_or(err::Event::DynParse)? .to_string(), language: payload["language"].as_str().map(String::from), mentioned_users: HashSet::new(), @@ -67,68 +65,50 @@ impl DynStatus { boosted_user: Id::try_from(&payload["reblog"]["account"]["id"]).ok(), }) } - /// Returns `true` if the status is filtered out based on its language - pub(crate) fn language_not(&self, allowed_langs: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; +} - if allowed_langs.is_empty() { - return ALLOW; // listing no allowed_langs results in allowing all languages - } - - match self.language.clone() { - Some(toot_language) if allowed_langs.contains(&toot_language) => ALLOW, // - None => ALLOW, // If toot language is unknown, toot is always allowed - Some(empty) if empty == String::new() => ALLOW, - Some(_toot_language) => REJECT, +impl Payload for DynStatus { + fn language_unset(&self) -> bool { + match &self.language { + None => true, + Some(empty) if empty == &String::new() => true, + Some(_language) => false, } } - /// Returns `true` if the toot contained in this Event originated from a blocked domain, - /// is from an account that has blocked the current user, or if the User's list of - /// blocked/muted users includes a user involved in the toot. + fn language(&self) -> String { + self.language.clone().unwrap_or_default() + } + /// Returns all users involved in the `Status`. /// - /// A user is involved in the toot if they: + /// A user is involved in the Status/toot if they: /// * Are mentioned in this toot /// * Wrote this toot /// * Wrote a toot that this toot is replying to (if any) /// * Wrote the toot that this toot is boosting (if any) - pub(crate) fn involves_any(&self, blocks: &Blocks) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - let Blocks { - blocked_users, - blocking_users, - blocked_domains, - } = blocks; - - if self.involves(blocked_users) || blocking_users.contains(&self.id) { - REJECT - } else { - match self.username.split('@').nth(1) { - Some(originating_domain) if blocked_domains.contains(originating_domain) => REJECT, - Some(_) | None => ALLOW, // None means the local instance, which can't be blocked - } - } - } - - fn involves(&self, blocked_users: &HashSet) -> bool { - // mentions + fn involved_users(&self) -> HashSet { + // involved_users = mentioned_users + author + replied-to user + boosted user let mut involved_users: HashSet = self.mentioned_users.clone(); // author involved_users.insert(self.id); - // replied-to user if let Some(user_id) = self.replied_to_user { involved_users.insert(user_id); } - // boosted user - if let Some(user_id) = self.boosted_user { - involved_users.insert(user_id); + if let Some(boosted_status) = self.boosted_user { + involved_users.insert(boosted_status); } + involved_users + } - !involved_users.is_disjoint(blocked_users) + fn author(&self) -> &Id { + &self.id + } + + fn sent_from(&self) -> &str { + let sender_username = &self.username; + sender_username.split('@').nth(1).unwrap_or_default() // default occurs when sent from local instance } } diff --git a/src/event/err.rs b/src/response/event/err.rs similarity index 76% rename from src/event/err.rs rename to src/response/event/err.rs index c51b6a3..0c6ce42 100644 --- a/src/event/err.rs +++ b/src/response/event/err.rs @@ -1,17 +1,17 @@ use std::{fmt, num::ParseIntError}; #[derive(Debug)] -pub enum EventErr { +pub enum Event { SerdeParse(serde_json::Error), NonNumId(ParseIntError), DynParse, } -impl std::error::Error for EventErr {} +impl std::error::Error for Event {} -impl fmt::Display for EventErr { +impl fmt::Display for Event { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - use EventErr::*; + use Event::*; match self { SerdeParse(inner) => write!(f, "{}", inner), NonNumId(inner) => write!(f, "ID could not be parsed: {}", inner), @@ -21,12 +21,12 @@ impl fmt::Display for EventErr { } } -impl From for EventErr { +impl From for Event { fn from(error: ParseIntError) -> Self { Self::NonNumId(error) } } -impl From for EventErr { +impl From for Event { fn from(error: serde_json::Error) -> Self { Self::SerdeParse(error) } diff --git a/src/response/redis.rs b/src/response/redis.rs index da37491..db24126 100644 --- a/src/response/redis.rs +++ b/src/response/redis.rs @@ -2,18 +2,21 @@ mod connection; mod manager; mod msg; -pub(crate) use connection::{RedisConn, RedisConnErr}; +pub(self) use super::{Event, EventErr}; +pub(self) use connection::RedisConn; +pub use manager::Error; pub use manager::Manager; -pub(crate) use manager::ManagerErr; -pub(crate) use msg::RedisParseErr; -pub(crate) enum RedisCmd { +use connection::RedisConnErr; +use msg::RedisParseErr; + +enum RedisCmd { Subscribe, Unsubscribe, } impl RedisCmd { - pub(crate) fn into_sendable(self, tl: &str) -> (Vec, Vec) { + fn into_sendable(self, tl: &str) -> (Vec, Vec) { match self { RedisCmd::Subscribe => ( [ diff --git a/src/response/redis/connection.rs b/src/response/redis/connection.rs index ec0795c..d24eff9 100644 --- a/src/response/redis/connection.rs +++ b/src/response/redis/connection.rs @@ -2,10 +2,11 @@ mod err; pub(crate) use err::RedisConnErr; use super::msg::{RedisParseErr, RedisParseOutput}; -use super::{ManagerErr, RedisCmd}; +use super::Error as ManagerErr; +use super::Event; +use super::RedisCmd; use crate::config::Redis; -use crate::event::Event; -use crate::request::{Stream, Timeline}; +use crate::request::Timeline; use futures::{Async, Poll}; use lru::LruCache; @@ -18,7 +19,7 @@ use std::time::Duration; type Result = std::result::Result; #[derive(Debug)] -pub(crate) struct RedisConn { +pub(super) struct RedisConn { primary: TcpStream, secondary: TcpStream, redis_namespace: Option, @@ -29,7 +30,7 @@ pub(crate) struct RedisConn { } impl RedisConn { - pub(crate) fn new(redis_cfg: &Redis) -> Result { + pub(super) fn new(redis_cfg: &Redis) -> Result { let addr = [&*redis_cfg.host, ":", &*redis_cfg.port.to_string()].concat(); let conn = Self::new_connection(&addr, redis_cfg.password.as_ref())?; @@ -50,7 +51,7 @@ impl RedisConn { Ok(redis_conn) } - pub(crate) fn poll_redis(&mut self) -> Poll, ManagerErr> { + pub(super) fn poll_redis(&mut self) -> Poll, ManagerErr> { loop { match self.primary.read(&mut self.redis_input[self.cursor..]) { Ok(n) => { @@ -111,18 +112,15 @@ impl RedisConn { res } - pub(crate) fn update_cache(&mut self, hashtag: String, id: i64) { + pub(super) fn update_cache(&mut self, hashtag: String, id: i64) { self.tag_id_cache.put(hashtag.clone(), id); self.tag_name_cache.put(id, hashtag); } pub(crate) fn send_cmd(&mut self, cmd: RedisCmd, timeline: &Timeline) -> Result<()> { - let hashtag = match timeline { - Timeline(Stream::Hashtag(id), _, _) => self.tag_name_cache.get(id), - _non_hashtag_timeline => None, - }; - + let hashtag = timeline.tag().and_then(|id| self.tag_name_cache.get(&id)); let tl = timeline.to_redis_raw_timeline(hashtag)?; + let (primary_cmd, secondary_cmd) = cmd.into_sendable(&tl); self.primary.write_all(&primary_cmd)?; self.secondary.write_all(&secondary_cmd)?; diff --git a/src/response/redis/connection/err.rs b/src/response/redis/connection/err.rs index 0707597..893dd30 100644 --- a/src/response/redis/connection/err.rs +++ b/src/response/redis/connection/err.rs @@ -1,4 +1,4 @@ -use crate::request::TimelineErr; +use crate::request; use std::fmt; #[derive(Debug)] @@ -9,11 +9,11 @@ pub enum RedisConnErr { IncorrectPassword(String), MissingPassword, NotRedis(String), - TimelineErr(TimelineErr), + TimelineErr(request::TimelineErr), } impl RedisConnErr { - pub(crate) fn with_addr>(address: T, inner: std::io::Error) -> Self { + pub(super) fn with_addr>(address: T, inner: std::io::Error) -> Self { Self::ConnectionErr { addr: address.as_ref().to_string(), inner, @@ -57,8 +57,8 @@ impl fmt::Display for RedisConnErr { } } -impl From for RedisConnErr { - fn from(e: TimelineErr) -> RedisConnErr { +impl From for RedisConnErr { + fn from(e: request::TimelineErr) -> RedisConnErr { RedisConnErr::TimelineErr(e) } } diff --git a/src/response/redis/manager.rs b/src/response/redis/manager.rs index afb4795..b32eef2 100644 --- a/src/response/redis/manager.rs +++ b/src/response/redis/manager.rs @@ -2,12 +2,14 @@ //! polled by the correct `ClientAgent`. Also manages sububscriptions and //! unsubscriptions to/from Redis. mod err; -pub(crate) use err::ManagerErr; +pub use err::Error; +use super::Event; use super::{RedisCmd, RedisConn}; use crate::config; -use crate::event::Event; -use crate::request::{Stream, Subscription, Timeline}; +use crate::request::{Subscription, Timeline}; + +pub(self) use super::EventErr; use futures::{Async, Stream as _Stream}; use hashbrown::HashMap; @@ -15,7 +17,7 @@ use std::sync::{Arc, Mutex, MutexGuard, PoisonError}; use std::time::{Duration, Instant}; use tokio::sync::{mpsc, watch}; -type Result = std::result::Result; +type Result = std::result::Result; /// The item that streams from Redis and is polled by the `ClientAgent` #[derive(Debug)] @@ -50,7 +52,7 @@ impl Manager { pub fn subscribe(&mut self, subscription: &Subscription) { let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline); - if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (tag, tl) { + if let (Some(hashtag), Some(id)) = (tag, tl.tag()) { self.redis_connection.update_cache(hashtag, id); }; diff --git a/src/response/redis/manager/err.rs b/src/response/redis/manager/err.rs index b543520..8d6d4f2 100644 --- a/src/response/redis/manager/err.rs +++ b/src/response/redis/manager/err.rs @@ -1,10 +1,10 @@ use super::super::{RedisConnErr, RedisParseErr}; -use crate::event::{Event, EventErr}; +use super::{Event, EventErr}; use crate::request::{Timeline, TimelineErr}; use std::fmt; #[derive(Debug)] -pub enum ManagerErr { +pub enum Error { InvalidId, TimelineErr(TimelineErr), EventErr(EventErr), @@ -13,11 +13,11 @@ pub enum ManagerErr { ChannelSendErr(tokio::sync::watch::error::SendError<(Timeline, Event)>), } -impl std::error::Error for ManagerErr {} +impl std::error::Error for Error {} -impl fmt::Display for ManagerErr { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - use ManagerErr::*; + use Error::*; match self { InvalidId => write!( f, @@ -33,31 +33,31 @@ impl fmt::Display for ManagerErr { } } -impl From> for ManagerErr { +impl From> for Error { fn from(error: tokio::sync::watch::error::SendError<(Timeline, Event)>) -> Self { Self::ChannelSendErr(error) } } -impl From for ManagerErr { +impl From for Error { fn from(error: EventErr) -> Self { Self::EventErr(error) } } -impl From for ManagerErr { +impl From for Error { fn from(e: RedisConnErr) -> Self { Self::RedisConnErr(e) } } -impl From for ManagerErr { +impl From for Error { fn from(e: TimelineErr) -> Self { Self::TimelineErr(e) } } -impl From for ManagerErr { +impl From for Error { fn from(e: RedisParseErr) -> Self { Self::RedisParseErr(e) } diff --git a/src/response/stream.rs b/src/response/stream.rs index 0e79589..453fb55 100644 --- a/src/response/stream.rs +++ b/src/response/stream.rs @@ -1,5 +1,7 @@ pub use sse::Sse; pub use ws::Ws; +pub(self) use super::{Event, Payload}; + mod sse; mod ws; diff --git a/src/response/stream/sse.rs b/src/response/stream/sse.rs index 9bcb058..d427440 100644 --- a/src/response/stream/sse.rs +++ b/src/response/stream/sse.rs @@ -1,4 +1,4 @@ -use crate::event::Event; +use super::{Event, Payload}; use crate::request::{Subscription, Timeline}; use futures::stream::Stream; @@ -18,41 +18,19 @@ impl Sse { sse_rx: watch::Receiver<(Timeline, Event)>, ) -> impl Reply { let target_timeline = subscription.timeline; - let allowed_langs = subscription.allowed_langs; - let blocks = subscription.blocks; let event_stream = sse_rx .filter(move |(timeline, _)| target_timeline == *timeline) - .filter_map(move |(timeline, event)| { - use crate::event::{ - CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind, - }; - - use crate::request::Stream::Public; - match event { - TypeSafe(Update { payload, queued_at }) => match timeline { - Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None, - _ if payload.involves_any(&blocks) => None, - _ => Event::TypeSafe(CheckedEvent::Update { payload, queued_at }) - .to_warp_reply(), - }, - TypeSafe(non_update) => Event::TypeSafe(non_update).to_warp_reply(), - Dynamic(dyn_event) => { - if let EventKind::Update(s) = dyn_event.kind { - match timeline { - Timeline(Public, _, _) if s.language_not(&allowed_langs) => None, - _ if s.involves_any(&blocks) => None, - _ => Dynamic(DynEvent { - kind: EventKind::Update(s), - ..dyn_event - }) - .to_warp_reply(), - } - } else { - None - } + .filter_map(move |(_timeline, event)| { + match (event.update_payload(), event.dyn_update_payload()) { + (Some(update), _) if Sse::update_not_filtered(subscription.clone(), update) => { + event.to_warp_reply() } - Ping => None, // pings handled automatically + (None, None) => event.to_warp_reply(), // send all non-updates + (_, Some(update)) if Sse::update_not_filtered(subscription.clone(), update) => { + event.to_warp_reply() + } + (_, _) => None, } }) .then(move |res| { @@ -69,4 +47,23 @@ impl Sse { .stream(event_stream), ) } + + fn update_not_filtered(subscription: Subscription, update: &impl Payload) -> bool { + let blocks = &subscription.blocks; + let allowed_langs = &subscription.allowed_langs; + + match subscription.timeline { + tl if tl.is_public() + && !update.language_unset() + && !allowed_langs.is_empty() + && !allowed_langs.contains(&update.language()) => + { + false + } + _ if !blocks.blocked_users.is_disjoint(&update.involved_users()) => false, + _ if blocks.blocking_users.contains(update.author()) => false, + _ if blocks.blocked_domains.contains(update.sent_from()) => false, + _ => true, + } + } } diff --git a/src/response/stream/ws.rs b/src/response/stream/ws.rs index 354fccf..ed51924 100644 --- a/src/response/stream/ws.rs +++ b/src/response/stream/ws.rs @@ -1,10 +1,12 @@ -use crate::event::Event; +use super::{Event, Payload}; use crate::request::{Subscription, Timeline}; use futures::{future::Future, stream::Stream}; use tokio::sync::{mpsc, watch}; use warp::ws::{Message, WebSocket}; +type Result = std::result::Result; + pub struct Ws { unsubscribe_tx: mpsc::UnboundedSender, subscription: Subscription, @@ -52,40 +54,38 @@ impl Ws { incoming_events.for_each(move |(tl, event)| { if matches!(event, Event::Ping) { - self.send_msg(&event) + self.send_msg(&event)? } else if target_timeline == tl { - use crate::event::{CheckedEvent::Update, Event::*, EventKind}; - use crate::request::Stream::Public; - let blocks = &self.subscription.blocks; - let allowed_langs = &self.subscription.allowed_langs; - - match event { - TypeSafe(Update { payload, queued_at }) => match tl { - Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()), - _ if payload.involves_any(&blocks) => Ok(()), - _ => self.send_msg(&TypeSafe(Update { payload, queued_at })), - }, - TypeSafe(non_update) => self.send_msg(&TypeSafe(non_update)), - Dynamic(dyn_event) => { - if let EventKind::Update(s) = dyn_event.kind.clone() { - match tl { - Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()), - _ if s.involves_any(&blocks) => Ok(()), - _ => self.send_msg(&Dynamic(dyn_event)), - } - } else { - self.send_msg(&Dynamic(dyn_event)) - } - } - Ping => unreachable!(), // handled pings above + match (event.update_payload(), event.dyn_update_payload()) { + (Some(update), _) => self.send_or_filter(tl, &event, update)?, + (None, None) => self.send_msg(&event)?, // send all non-updates + (_, Some(dyn_update)) => self.send_or_filter(tl, &event, dyn_update)?, } - } else { - Ok(()) } + Ok(()) }) } - fn send_msg(&mut self, event: &Event) -> Result<(), ()> { + fn send_or_filter(&mut self, tl: Timeline, event: &Event, update: &impl Payload) -> Result<()> { + let blocks = &self.subscription.blocks; + let allowed_langs = &self.subscription.allowed_langs; + const SKIP: Result<()> = Ok(()); + match tl { + tl if tl.is_public() + && !update.language_unset() + && !allowed_langs.is_empty() + && !allowed_langs.contains(&update.language()) => + { + SKIP + } + _ if !blocks.blocked_users.is_disjoint(&update.involved_users()) => SKIP, + _ if blocks.blocking_users.contains(update.author()) => SKIP, + _ if blocks.blocked_domains.contains(update.sent_from()) => SKIP, + _ => Ok(self.send_msg(&event)?), + } + } + + fn send_msg(&mut self, event: &Event) -> Result<()> { let txt = &event.to_json_string(); let tl = self.subscription.timeline; let mut channel = self.ws_tx.clone().ok_or(())?;