diff --git a/Cargo.lock b/Cargo.lock index e0d0277..9a0bdc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -453,7 +453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.8.0" +version = "0.8.1" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index 35a728e..881c450 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "flodgatt" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" -version = "0.8.0" +version = "0.8.1" authors = ["Daniel Long Sockwell "] edition = "2018" diff --git a/old b/old deleted file mode 100644 index bc8a053..0000000 --- a/old +++ /dev/null @@ -1,447 +0,0 @@ -use crate::log_fatal; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::boxed::Box; -use std::{collections::HashSet, string::String}; - -pub enum Event { - TypeSafe(CheckedEvent), - Dynamic(DynamicEvent), -} - -impl Event { - pub fn to_json_string(&self) -> String { - let event = &self.event_name(); - let sendable_event = match self.payload() { - Some(payload) => SendableEvent::WithPayload { event, payload }, - None => SendableEvent::NoPayload { event }, - }; - serde_json::to_string(&sendable_event) - .unwrap_or_else(|_| log_fatal!("Could not serialize `{:?}`", &sendable_event)) - } - - pub fn event_name(&self) -> String { - String::from(match self { - Self::TypeSafe(checked) => match checked { - CheckedEvent::Update { .. } => "update", - CheckedEvent::Notification { .. } => "notification", - CheckedEvent::Delete { .. } => "delete", - CheckedEvent::Announcement { .. } => "announcement", - CheckedEvent::AnnouncementReaction { .. } => "announcement.reaction", - CheckedEvent::AnnouncementDelete { .. } => "announcement.delete", - CheckedEvent::Conversation { .. } => "conversation", - CheckedEvent::FiltersChanged => "filters_changed", - }, - Self::Dynamic(dyn_event) => &dyn_event.event, - }) - } - - pub fn payload(&self) -> Option { - use CheckedEvent::*; - match self { - Self::TypeSafe(checked) => match checked { - Update { payload, .. } => Some(escaped(payload)), - Notification { payload, .. } => Some(escaped(payload)), - Delete { payload, .. } => Some(payload.0.clone()), - Announcement { payload, .. } => Some(escaped(payload)), - AnnouncementReaction { payload, .. } => Some(escaped(payload)), - AnnouncementDelete { payload, .. } => Some(payload.0.clone()), - Conversation { payload, .. } => Some(escaped(payload)), - FiltersChanged => None, - }, - Self::Dynamic(dyn_event) => Some(dyn_event.payload.to_string()), - } - } -} - -#[derive(Deserialize, Debug, Clone, PartialEq)] -pub struct DynamicEvent { - pub event: String, - payload: Value, - queued_at: Option, -} - -#[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)] -#[rustfmt::skip] -#[derive(Deserialize, Debug, Clone, PartialEq)] -pub enum CheckedEvent { - Update { payload: Status, queued_at: Option }, - Notification { payload: Notification }, - Delete { payload: DeletedId }, - FiltersChanged, - Announcement { payload: Announcement }, - #[serde(rename(serialize = "announcement.reaction", deserialize = "announcement.reaction"))] - AnnouncementReaction { payload: AnnouncementReaction }, - #[serde(rename(serialize = "announcement.delete", deserialize = "announcement.delete"))] - AnnouncementDelete { payload: DeletedId }, - Conversation { payload: Conversation, queued_at: Option }, -} - -#[derive(Serialize, Debug, Clone)] -#[serde(untagged)] -pub enum SendableEvent<'a> { - WithPayload { event: &'a str, payload: String }, - NoPayload { event: &'a str }, -} - -fn escaped(content: T) -> String { - serde_json::to_string(&content) - .unwrap_or_else(|_| log_fatal!("Could not parse Event with: `{:?}`", &content)) -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Conversation { - id: String, - accounts: Vec, - unread: bool, - last_status: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct DeletedId(String); - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Status { - id: String, - uri: String, - created_at: String, - account: Account, - content: String, - visibility: Visibility, - sensitive: bool, - spoiler_text: String, - media_attachments: Vec, - application: Option, // Should be non-optional? - mentions: Vec, - tags: Vec, - emojis: Vec, - reblogs_count: i64, - favourites_count: i64, - replies_count: i64, - url: Option, - in_reply_to_id: Option, - in_reply_to_account_id: Option, - reblog: Option>, - poll: Option, - card: Option, - language: Option, - text: Option, - // ↓↓↓ Only for authorized users - favourited: Option, - reblogged: Option, - muted: Option, - bookmarked: Option, - pinned: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum Visibility { - Public, - Unlisted, - Private, - Direct, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Account { - id: String, - username: String, - acct: String, - url: String, - display_name: String, - note: String, - avatar: String, - avatar_static: String, - header: String, - header_static: String, - locked: bool, - emojis: Vec, - discoverable: Option, // Shouldn't be option? - created_at: String, - statuses_count: i64, - followers_count: i64, - following_count: i64, - moved: Option>, - fields: Option>, - bot: Option, - source: Option, - group: Option, // undocumented - last_status_at: Option, // undocumented -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Attachment { - id: String, - r#type: AttachmentType, - url: String, - preview_url: String, - remote_url: Option, - text_url: Option, - meta: Option, - description: Option, - blurhash: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -enum AttachmentType { - Unknown, - Image, - Gifv, - Video, - Audio, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Application { - name: String, - website: Option, - vapid_key: Option, - client_id: Option, - client_secret: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Emoji { - shortcode: String, - url: String, - static_url: String, - visible_in_picker: bool, - category: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Field { - name: String, - value: String, - verified_at: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Source { - note: String, - fields: Vec, - privacy: Option, - sensitive: bool, - language: String, - follow_requests_count: i64, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Mention { - id: String, - username: String, - acct: String, - url: String, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Tag { - name: String, - url: String, - history: Option>, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Poll { - id: String, - expires_at: String, - expired: bool, - multiple: bool, - votes_count: i64, - voters_count: Option, - voted: Option, - own_votes: Option>, - options: Vec, - emojis: Vec, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct PollOptions { - title: String, - votes_count: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Card { - url: String, - title: String, - description: String, - r#type: CardType, - author_name: Option, - author_url: Option, - provider_name: Option, - provider_url: Option, - html: Option, - width: Option, - height: Option, - image: Option, - embed_url: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -enum CardType { - Link, - Photo, - Video, - Rich, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct History { - day: String, - uses: String, - accounts: String, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Notification { - id: String, - r#type: NotificationType, - created_at: String, - account: Account, - status: Option, -} - -#[serde(rename_all = "snake_case", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -enum NotificationType { - Follow, - FollowRequest, // Undocumented - Mention, - Reblog, - Favourite, - Poll, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Announcement { - // Fully undocumented - id: String, - tags: Vec, - all_day: bool, - content: String, - emojis: Vec, - starts_at: Option, - ends_at: Option, - published_at: String, - updated_at: String, - mentions: Vec, - reactions: Vec, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct AnnouncementReaction { - #[serde(skip_serializing_if = "Option::is_none")] - announcement_id: Option, - count: i64, - name: String, -} - -impl Status { - /// Returns `true` if the status is filtered out based on its language - pub fn language_not_allowed(&self, allowed_langs: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - - let reject_and_maybe_log = |toot_language| { - log::info!("Filtering out toot from `{}`", &self.account.acct); - log::info!("Toot language: `{}`", toot_language); - log::info!("Recipient's allowed languages: `{:?}`", allowed_langs); - REJECT - }; - if allowed_langs.is_empty() { - return ALLOW; // listing no allowed_langs results in allowing all languages - } - - match self.language.as_ref() { - Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW, - None => ALLOW, // If toot language is unknown, toot is always allowed - Some(empty) if empty == &String::new() => ALLOW, - Some(toot_language) => reject_and_maybe_log(toot_language), - } - } - - /// Returns `true` if this toot originated from a domain the User has blocked. - pub fn from_blocked_domain(&self, blocked_domains: &HashSet) -> bool { - let full_username = &self.account.acct; - - match full_username.split('@').nth(1) { - Some(originating_domain) => blocked_domains.contains(originating_domain), - None => false, // None means the user is on the local instance, which can't be blocked - } - } - - /// Returns `true` if the Status is from an account that has blocked the current user. - pub fn from_blocking_user(&self, blocking_users: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - let err = |_| log_fatal!("Could not process `account.id` in {:?}", &self); - - if blocking_users.contains(&self.account.id.parse().unwrap_or_else(err)) { - REJECT - } else { - ALLOW - } - } - - /// Returns `true` if the User's list of blocked and muted users includes a user - /// involved in this toot. - /// - /// A user is involved if they: - /// * Are mentioned in this toot - /// * Wrote this toot - /// * Wrote a toot that this toot is replying to (if any) - /// * Wrote the toot that this toot is boosting (if any) - pub fn involves_blocked_user(&self, blocked_users: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self); - - // involved_users = mentioned_users + author + replied-to user + boosted user - let mut involved_users: HashSet = self - .mentions - .iter() - .map(|mention| mention.id.parse().unwrap_or_else(err)) - .collect(); - - involved_users.insert(self.account.id.parse::().unwrap_or_else(err)); - - if let Some(replied_to_account_id) = self.in_reply_to_account_id.clone() { - involved_users.insert(replied_to_account_id.parse().unwrap_or_else(err)); - } - - if let Some(boosted_status) = self.reblog.clone() { - involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err)); - } - - if involved_users.is_disjoint(blocked_users) { - ALLOW - } else { - REJECT - } - } -} - -#[cfg(test)] -mod test; diff --git a/src/err/mod.rs b/src/err/mod.rs index dc96bbb..55a8927 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -2,17 +2,42 @@ mod timeline; pub use timeline::TimelineErr; +use crate::redis_to_client_stream::ReceiverErr; use std::fmt; +pub enum FatalErr { + Err, + ReceiverErr(ReceiverErr), +} + +impl FatalErr { + pub fn exit(msg: impl fmt::Display) { + eprintln!("{}", msg); + std::process::exit(1); + } +} + +impl std::error::Error for FatalErr {} +impl fmt::Debug for FatalErr { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!(f, "{}", self) + } +} + +impl fmt::Display for FatalErr { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!(f, "Error message") + } +} + +impl From for FatalErr { + fn from(e: ReceiverErr) -> Self { + Self::ReceiverErr(e) + } +} + +// TODO delete vvvv when postgres_cfg.rs has better error handling pub fn die_with_msg(msg: impl fmt::Display) -> ! { eprintln!("FATAL ERROR: {}", msg); std::process::exit(1); } - -#[macro_export] -macro_rules! log_fatal { - ($str:expr, $var:expr) => {{ - log::error!($str, $var); - panic!(); - };}; -} diff --git a/src/err/timeline.rs b/src/err/timeline.rs index 4ba9f34..6f05f89 100644 --- a/src/err/timeline.rs +++ b/src/err/timeline.rs @@ -2,10 +2,12 @@ use std::fmt; #[derive(Debug)] pub enum TimelineErr { - RedisNamespaceMismatch, + MissingHashtag, InvalidInput, } +impl std::error::Error for TimelineErr {} + impl From for TimelineErr { fn from(_error: std::num::ParseIntError) -> Self { Self::InvalidInput @@ -16,8 +18,8 @@ impl fmt::Display for TimelineErr { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { use TimelineErr::*; let msg = match self { - RedisNamespaceMismatch => "TODO: Cut this error", - InvalidInput => "The timeline text from Redis could not be parsed into a supported timeline. TODO: add incoming timeline text" + InvalidInput => "The timeline text from Redis could not be parsed into a supported timeline. TODO: add incoming timeline text", + MissingHashtag => "Attempted to send a hashtag timeline without supplying a tag name", }; write!(f, "{}", msg) } diff --git a/src/main.rs b/src/main.rs index 079d8ba..fed290e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use flodgatt::{ config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig}, + err::FatalErr, messages::Event, parse_client_request::{PgPool, Subscription, Timeline}, redis_to_client_stream::{Receiver, SseStream, WsStream}, @@ -11,7 +12,7 @@ use tokio::{ }; use warp::{http::StatusCode, path, ws::Ws2, Filter, Rejection}; -fn main() { +fn main() -> Result<(), FatalErr> { dotenv::from_filename(match env::var("ENV").ok().as_deref() { Some("production") => ".env.production", Some("development") | None => ".env", @@ -30,12 +31,7 @@ fn main() { let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); let poll_freq = *redis_cfg.polling_interval; - let receiver = Receiver::try_from(redis_cfg, event_tx, cmd_rx) - .unwrap_or_else(|e| { - log::error!("{}\nFlodgatt shutting down...", e); - std::process::exit(1); - }) - .into_arc(); + let receiver = Receiver::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc(); log::info!("Streaming server initialized and ready to accept connections"); // Server Sent Events @@ -48,19 +44,13 @@ fn main() { move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); { - let mut receiver = sse_receiver.lock().expect("TODO"); + let mut receiver = sse_receiver.lock().unwrap_or_else(Receiver::recover); receiver.subscribe(&subscription).unwrap_or_else(|e| { log::error!("Could not subscribe to the Redis channel: {}", e) }); } let cmd_tx = sse_cmd_tx.clone(); let sse_rx = sse_rx.clone(); - // self.sse.reply( - // warp::sse::keep_alive() - // .interval(Duration::from_secs(30)) - // .text("thump".to_string()) - // .stream(event_stream), - // ) // send the updates through the SSE connection SseStream::send_events(sse_connection_to_client, cmd_tx, subscription, sse_rx) }, @@ -75,7 +65,8 @@ fn main() { .map(move |subscription: Subscription, ws: Ws2| { log::info!("Incoming websocket request for {:?}", subscription.timeline); { - let mut receiver = ws_receiver.lock().expect("TODO"); + let mut receiver = ws_receiver.lock().unwrap_or_else(Receiver::recover); + receiver.subscribe(&subscription).unwrap_or_else(|e| { log::error!("Could not subscribe to the Redis channel: {}", e) }); @@ -107,10 +98,10 @@ fn main() { .map(|| "OK") .or(warp::path!("api" / "v1" / "streaming" / "status") .and(warp::path::end()) - .map(move || r1.lock().expect("TODO").count_connections())) + .map(move || r1.lock().unwrap_or_else(Receiver::recover).count())) .or( warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline") - .map(move || r3.lock().expect("TODO").list_connections()), + .map(move || r3.lock().unwrap_or_else(Receiver::recover).list()), ) }; #[cfg(not(feature = "stub_status"))] @@ -149,12 +140,13 @@ fn main() { tokio::run(lazy(move || { let receiver = receiver.clone(); + warp::spawn(lazy(move || { tokio::timer::Interval::new(Instant::now(), poll_freq) .map_err(|e| log::error!("{}", e)) .for_each(move |_| { - let receiver = receiver.clone(); - receiver.lock().expect("TODO").poll_broadcast(); + let mut receiver = receiver.lock().unwrap_or_else(Receiver::recover); + receiver.poll_broadcast().unwrap_or_else(FatalErr::exit); Ok(()) }) })); @@ -162,4 +154,5 @@ fn main() { warp::serve(ws_routes.or(sse_routes).with(cors).or(status_endpoints)).bind(server_addr) })); }; + Ok(()) } diff --git a/src/messages/event/checked_event/account.rs b/src/messages/event/checked_event/account.rs index f86ef6c..0322cc3 100644 --- a/src/messages/event/checked_event/account.rs +++ b/src/messages/event/checked_event/account.rs @@ -1,10 +1,10 @@ -use super::{emoji::Emoji, visibility::Visibility}; +use super::{emoji::Emoji, id::Id, visibility::Visibility}; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub(super) struct Account { - pub id: String, + pub id: Id, username: String, pub acct: String, url: String, @@ -21,7 +21,7 @@ pub(super) struct Account { statuses_count: i64, followers_count: i64, following_count: i64, - moved: Option>, + moved: Option, fields: Option>, bot: Option, source: Option, diff --git a/src/messages/event/checked_event/id.rs b/src/messages/event/checked_event/id.rs new file mode 100644 index 0000000..0226e5d --- /dev/null +++ b/src/messages/event/checked_event/id.rs @@ -0,0 +1,85 @@ +use super::super::EventErr; + +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; +use serde_json::Value; +use std::{convert::TryFrom, fmt, num::ParseIntError, str::FromStr}; + +/// A user ID. +/// +/// Internally, Mastodon IDs are i64s, but are sent to clients as string because +/// JavaScript numbers don't support i64s. This newtype serializes to/from a string, but +/// keeps the i64 as the "true" value for internal use. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct Id(pub i64); + +impl TryFrom<&Value> for Id { + type Error = EventErr; + + fn try_from(v: &Value) -> Result { + Ok(v.as_str().ok_or(EventErr::DynParse)?.parse()?) + } +} + +impl std::ops::Deref for Id { + type Target = i64; + fn deref(&self) -> &i64 { + &self.0 + } +} +impl FromStr for Id { + type Err = ParseIntError; + + fn from_str(s: &str) -> Result { + Ok(Self(s.parse()?)) + } +} + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!(f, "{}", self.0) + } +} + +impl Serialize for Id { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(&format!("{}", self.0)) + } +} + +impl<'de> Deserialize<'de> for Id { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_string(IdVisitor) + } +} + +struct IdVisitor; +impl<'de> Visitor<'de> for IdVisitor { + type Value = Id; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string that can be parsed into an i64") + } + + fn visit_str(self, value: &str) -> Result { + match value.parse() { + Ok(n) => Ok(Id(n)), + Err(e) => Err(E::custom(format!("could not parse: {}", e))), + } + } + + fn visit_string(self, value: String) -> Result { + match value.parse() { + Ok(n) => Ok(Id(n)), + Err(e) => Err(E::custom(format!("could not parse: {}", e))), + } + } +} diff --git a/src/messages/event/checked_event/mention.rs b/src/messages/event/checked_event/mention.rs index 5f9a876..e5e95f8 100644 --- a/src/messages/event/checked_event/mention.rs +++ b/src/messages/event/checked_event/mention.rs @@ -1,9 +1,10 @@ +use super::id::Id; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub(super) struct Mention { - pub id: String, + pub id: Id, username: String, acct: String, url: String, diff --git a/src/messages/event/checked_event/mod.rs b/src/messages/event/checked_event/mod.rs index 60c7513..ac78809 100644 --- a/src/messages/event/checked_event/mod.rs +++ b/src/messages/event/checked_event/mod.rs @@ -4,6 +4,7 @@ mod announcement; mod announcement_reaction; mod conversation; mod emoji; +mod id; mod mention; mod notification; mod status; @@ -13,6 +14,7 @@ mod visibility; pub use announcement::Announcement; pub(in crate::messages::event) use announcement_reaction::AnnouncementReaction; pub use conversation::Conversation; +pub use id::Id; pub use notification::Notification; pub use status::Status; diff --git a/src/messages/event/checked_event/status/mod.rs b/src/messages/event/checked_event/status/mod.rs index 8a60540..e517870 100644 --- a/src/messages/event/checked_event/status/mod.rs +++ b/src/messages/event/checked_event/status/mod.rs @@ -3,10 +3,11 @@ mod attachment; mod card; mod poll; -use super::{account::Account, emoji::Emoji, mention::Mention, tag::Tag, visibility::Visibility}; +use super::{ + account::Account, emoji::Emoji, id::Id, mention::Mention, tag::Tag, visibility::Visibility, +}; use {application::Application, attachment::Attachment, card::Card, poll::Poll}; -use crate::log_fatal; use crate::parse_client_request::Blocks; use hashbrown::HashSet; @@ -17,7 +18,7 @@ use std::string::String; #[serde(deny_unknown_fields)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Status { - id: String, + id: Id, uri: String, created_at: String, account: Account, @@ -34,8 +35,8 @@ pub struct Status { favourites_count: i64, replies_count: i64, url: Option, - in_reply_to_id: Option, - in_reply_to_account_id: Option, + in_reply_to_id: Option, + in_reply_to_account_id: Option, reblog: Option>, poll: Option, card: Option, @@ -91,7 +92,7 @@ impl Status { blocking_users, blocked_domains, } = blocks; - let user_id = &self.account.id.parse().expect("TODO"); + let user_id = &Id(self.account.id.0); if blocking_users.contains(user_id) || self.involves(blocked_users) { REJECT @@ -104,26 +105,23 @@ impl Status { } } - fn involves(&self, blocked_users: &HashSet) -> bool { - // TODO replace vvvv with error handling - let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self); - + fn involves(&self, blocked_users: &HashSet) -> bool { // involved_users = mentioned_users + author + replied-to user + boosted user - let mut involved_users: HashSet = self + let mut involved_users: HashSet = self .mentions .iter() - .map(|mention| mention.id.parse().unwrap_or_else(err)) + .map(|mention| Id(mention.id.0)) .collect(); // author - involved_users.insert(self.account.id.parse::().unwrap_or_else(err)); + involved_users.insert(Id(self.account.id.0)); // replied-to user - if let Some(user_id) = self.in_reply_to_account_id.clone() { - involved_users.insert(user_id.parse().unwrap_or_else(err)); + if let Some(user_id) = self.in_reply_to_account_id { + involved_users.insert(Id(user_id.0)); } // boosted user if let Some(boosted_status) = self.reblog.clone() { - involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err)); + involved_users.insert(Id(boosted_status.account.id.0)); } !involved_users.is_disjoint(blocked_users) } diff --git a/src/messages/event/dynamic_event.rs b/src/messages/event/dynamic_event.rs deleted file mode 100644 index 6d2ead8..0000000 --- a/src/messages/event/dynamic_event.rs +++ /dev/null @@ -1,87 +0,0 @@ -use crate::parse_client_request::Blocks; -use hashbrown::HashSet; -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] -pub struct DynamicEvent { - pub event: String, - pub payload: Value, - queued_at: Option, -} - -impl DynamicEvent { - /// Returns `true` if the status is filtered out based on its language - pub fn language_not(&self, allowed_langs: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - - if allowed_langs.is_empty() { - return ALLOW; // listing no allowed_langs results in allowing all languages - } - - match self.payload["language"].as_str() { - Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW, - None => ALLOW, // If toot language is unknown, toot is always allowed - Some(empty) if empty == String::new() => ALLOW, - Some(_toot_language) => REJECT, - } - } - /// Returns `true` if the toot contained in this Event originated from a blocked domain, - /// is from an account that has blocked the current user, or if the User's list of - /// blocked/muted users includes a user involved in the toot. - /// - /// A user is involved in the toot if they: - /// * Are mentioned in this toot - /// * Wrote this toot - /// * Wrote a toot that this toot is replying to (if any) - /// * Wrote the toot that this toot is boosting (if any) - pub fn involves_any(&self, blocks: &Blocks) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - let Blocks { - blocked_users, - blocking_users, - blocked_domains, - } = blocks; - - let id = self.payload["account"]["id"].as_str().expect("TODO"); - let username = self.payload["account"]["acct"].as_str().expect("TODO"); - - if self.involves(blocked_users) || blocking_users.contains(&id.parse().expect("TODO")) { - REJECT - } else { - let full_username = &username; - match full_username.split('@').nth(1) { - Some(originating_domain) if blocked_domains.contains(originating_domain) => REJECT, - Some(_) | None => ALLOW, // None means the local instance, which can't be blocked - } - } - } - - // involved_users = mentioned_users + author + replied-to user + boosted user - fn involves(&self, blocked_users: &HashSet) -> bool { - // mentions - let mentions = self.payload["mentions"].as_array().expect("TODO"); - let mut involved_users: HashSet = mentions - .iter() - .map(|mention| mention["id"].as_str().expect("TODO").parse().expect("TODO")) - .collect(); - - // author - let author_id = self.payload["account"]["id"].as_str().expect("TODO"); - involved_users.insert(author_id.parse::().expect("TODO")); - // replied-to user - let replied_to_user = self.payload["in_reply_to_account_id"].as_str(); - if let Some(user_id) = replied_to_user { - involved_users.insert(user_id.parse().expect("TODO")); - } - // boosted user - let id_of_boosted_user = self.payload["reblog"]["account"]["id"] - .as_str() - .expect("TODO"); - involved_users.insert(id_of_boosted_user.parse().expect("TODO")); - - !involved_users.is_disjoint(blocked_users) - } -} diff --git a/src/messages/event/dynamic_event/mod.rs b/src/messages/event/dynamic_event/mod.rs new file mode 100644 index 0000000..fbb8462 --- /dev/null +++ b/src/messages/event/dynamic_event/mod.rs @@ -0,0 +1,135 @@ +use super::{EventErr, Id}; +use crate::parse_client_request::Blocks; + +use std::convert::TryFrom; + +use hashbrown::HashSet; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct DynEvent { + #[serde(skip)] + pub kind: EventKind, + pub event: String, + pub payload: Value, + pub queued_at: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum EventKind { + Update(DynStatus), + NonUpdate, +} + +impl Default for EventKind { + fn default() -> Self { + Self::NonUpdate + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DynStatus { + pub id: Id, + pub username: String, + pub language: Option, + pub mentioned_users: HashSet, + pub replied_to_user: Option, + pub boosted_user: Option, +} + +type Result = std::result::Result; + +impl DynEvent { + pub fn set_update(self) -> Result { + if self.event == "update" { + let kind = EventKind::Update(DynStatus::new(self.payload.clone())?); + Ok(Self { kind, ..self }) + } else { + Ok(self) + } + } +} + +impl DynStatus { + pub fn new(payload: Value) -> Result { + use EventErr::*; + + Ok(Self { + id: Id::try_from(&payload["account"]["id"])?, + username: payload["account"]["acct"] + .as_str() + .ok_or(DynParse)? + .to_string(), + language: payload["language"].as_str().map(|s| s.to_string()), + mentioned_users: HashSet::new(), + replied_to_user: Id::try_from(&payload["in_reply_to_account_id"]).ok(), + boosted_user: Id::try_from(&payload["reblog"]["account"]["id"]).ok(), + }) + } + /// Returns `true` if the status is filtered out based on its language + pub fn language_not(&self, allowed_langs: &HashSet) -> bool { + const ALLOW: bool = false; + const REJECT: bool = true; + + if allowed_langs.is_empty() { + return ALLOW; // listing no allowed_langs results in allowing all languages + } + + match self.language.clone() { + Some(toot_language) if allowed_langs.contains(&toot_language) => ALLOW, // + None => ALLOW, // If toot language is unknown, toot is always allowed + Some(empty) if empty == String::new() => ALLOW, + Some(_toot_language) => REJECT, + } + } + + /// Returns `true` if the toot contained in this Event originated from a blocked domain, + /// is from an account that has blocked the current user, or if the User's list of + /// blocked/muted users includes a user involved in the toot. + /// + /// A user is involved in the toot if they: + /// * Are mentioned in this toot + /// * Wrote this toot + /// * Wrote a toot that this toot is replying to (if any) + /// * Wrote the toot that this toot is boosting (if any) + pub fn involves_any(&self, blocks: &Blocks) -> bool { + const ALLOW: bool = false; + const REJECT: bool = true; + let Blocks { + blocked_users, + blocking_users, + blocked_domains, + } = blocks; + + if self.involves(blocked_users) || blocking_users.contains(&self.id) { + REJECT + } else { + match self.username.split('@').nth(1) { + Some(originating_domain) if blocked_domains.contains(originating_domain) => REJECT, + Some(_) | None => ALLOW, // None means the local instance, which can't be blocked + } + } + } + + // involved_users = mentioned_users + author + replied-to user + boosted user + fn involves(&self, blocked_users: &HashSet) -> bool { + // mentions + let mut involved_users: HashSet = self.mentioned_users.clone(); + + // author + involved_users.insert(self.id); + + // replied-to user + if let Some(user_id) = self.replied_to_user { + involved_users.insert(user_id); + } + + // boosted user + if let Some(user_id) = self.boosted_user { + involved_users.insert(user_id); + } + + !involved_users.is_disjoint(blocked_users) + } +} diff --git a/src/messages/event/err.rs b/src/messages/event/err.rs new file mode 100644 index 0000000..c51b6a3 --- /dev/null +++ b/src/messages/event/err.rs @@ -0,0 +1,33 @@ +use std::{fmt, num::ParseIntError}; + +#[derive(Debug)] +pub enum EventErr { + SerdeParse(serde_json::Error), + NonNumId(ParseIntError), + DynParse, +} + +impl std::error::Error for EventErr {} + +impl fmt::Display for EventErr { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + use EventErr::*; + match self { + SerdeParse(inner) => write!(f, "{}", inner), + NonNumId(inner) => write!(f, "ID could not be parsed: {}", inner), + DynParse => write!(f, "Could not find a required field in input JSON"), + }?; + Ok(()) + } +} + +impl From for EventErr { + fn from(error: ParseIntError) -> Self { + Self::NonNumId(error) + } +} +impl From for EventErr { + fn from(error: serde_json::Error) -> Self { + Self::SerdeParse(error) + } +} diff --git a/src/messages/event/mod.rs b/src/messages/event/mod.rs index 3c70b7b..8583b98 100644 --- a/src/messages/event/mod.rs +++ b/src/messages/event/mod.rs @@ -1,16 +1,20 @@ mod checked_event; mod dynamic_event; +mod err; -pub use {checked_event::CheckedEvent, dynamic_event::DynamicEvent}; +pub use { + checked_event::{CheckedEvent, Id}, + dynamic_event::{DynEvent, DynStatus, EventKind}, + err::EventErr, +}; -use crate::log_fatal; use serde::Serialize; -use std::string::String; +use std::{convert::TryFrom, string::String}; #[derive(Debug, Clone)] pub enum Event { TypeSafe(CheckedEvent), - Dynamic(DynamicEvent), + Dynamic(DynEvent), Ping, } @@ -21,8 +25,7 @@ impl Event { Some(payload) => SendableEvent::WithPayload { event, payload }, None => SendableEvent::NoPayload { event }, }; - serde_json::to_string(&sendable_event) - .unwrap_or_else(|_| log_fatal!("Could not serialize `{:?}`", &sendable_event)) + serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize") } pub fn event_name(&self) -> String { @@ -37,8 +40,12 @@ impl Event { CheckedEvent::Conversation { .. } => "conversation", CheckedEvent::FiltersChanged => "filters_changed", }, - Self::Dynamic(dyn_event) => &dyn_event.event, - Self::Ping => panic!("event_name() called on EventNotReady"), + Self::Dynamic(DynEvent { + kind: EventKind::Update(_), + .. + }) => "update", + Self::Dynamic(DynEvent { event, .. }) => event, + Self::Ping => panic!("event_name() called on Ping"), }) } @@ -55,30 +62,34 @@ impl Event { Conversation { payload, .. } => Some(escaped(payload)), FiltersChanged => None, }, - Self::Dynamic(dyn_event) => Some(dyn_event.payload.to_string()), - Self::Ping => panic!("payload() called on EventNotReady"), + Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()), + Self::Ping => panic!("payload() called on Ping"), } } } -impl From for Event { - fn from(event_txt: String) -> Event { - Event::from(event_txt.as_str()) +impl TryFrom for Event { + type Error = EventErr; + + fn try_from(event_txt: String) -> Result { + Event::try_from(event_txt.as_str()) } } -impl From<&str> for Event { - fn from(event_txt: &str) -> Event { +impl TryFrom<&str> for Event { + type Error = EventErr; + + fn try_from(event_txt: &str) -> Result { match serde_json::from_str(event_txt) { - Ok(checked_event) => Event::TypeSafe(checked_event), + Ok(checked_event) => Ok(Event::TypeSafe(checked_event)), Err(e) => { log::error!( "Error safely parsing Redis input. Mastodon and Flodgatt do not \ - strictly conform to the same version of Mastodon's API.\n{}\ + strictly conform to the same version of Mastodon's API.\n{}\n\ Forwarding Redis payload without type checking it.", e ); - let dyn_event: DynamicEvent = serde_json::from_str(&event_txt).expect("TODO"); - Event::Dynamic(dyn_event) + + Ok(Event::Dynamic(serde_json::from_str(&event_txt)?)) } } } @@ -92,6 +103,5 @@ enum SendableEvent<'a> { } fn escaped(content: T) -> String { - serde_json::to_string(&content) - .unwrap_or_else(|_| log_fatal!("Could not parse Event with: `{:?}`", &content)) + serde_json::to_string(&content).expect("Guaranteed by Serialize trait bound") } diff --git a/src/messages/mod.rs b/src/messages/mod.rs index cab529c..2599a30 100644 --- a/src/messages/mod.rs +++ b/src/messages/mod.rs @@ -1,3 +1,3 @@ mod event; -pub use event::{CheckedEvent, DynamicEvent, Event}; +pub use event::{CheckedEvent, DynEvent, Event, EventErr, EventKind, Id}; diff --git a/src/parse_client_request/postgres.rs b/src/parse_client_request/postgres.rs index d3120cf..a979e68 100644 --- a/src/parse_client_request/postgres.rs +++ b/src/parse_client_request/postgres.rs @@ -1,6 +1,7 @@ //! Postgres queries use crate::{ config, + messages::Id, parse_client_request::subscription::{Scope, UserData}, }; use ::postgres; @@ -28,6 +29,7 @@ impl PgPool { .expect("Can connect to local postgres"); Self(pool) } + pub fn select_user(self, token: &str) -> Result { let mut conn = self.0.get().unwrap(); let query_rows = conn @@ -45,7 +47,7 @@ LIMIT 1", ) .expect("Hard-coded query will return Some([0 or more rows])"); if let Some(result_columns) = query_rows.get(0) { - let id = result_columns.get(1); + let id = Id(result_columns.get(1)); let allowed_langs = result_columns .try_get::<_, Vec<_>>(2) .unwrap_or_else(|_| Vec::new()) @@ -96,17 +98,16 @@ LIMIT 1", ) .expect("Hard-coded query will return Some([0 or more rows])"); - match rows.get(0) { - Some(row) => Ok(row.get(0)), - None => Err(warp::reject::custom("Error: Hashtag does not exist.")), - } + rows.get(0) + .map(|row| row.get(0)) + .ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist.")) } /// Query Postgres for everyone the user has blocked or muted /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. - pub fn select_blocked_users(self, user_id: i64) -> HashSet { + pub fn select_blocked_users(self, user_id: Id) -> HashSet { self.0 .get() .unwrap() @@ -118,18 +119,18 @@ SELECT target_account_id UNION SELECT target_account_id FROM mutes WHERE account_id = $1", - &[&user_id], + &[&*user_id], ) .expect("Hard-coded query will return Some([0 or more rows])") .iter() - .map(|row| row.get(0)) + .map(|row| Id(row.get(0))) .collect() } /// Query Postgres for everyone who has blocked the user /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. - pub fn select_blocking_users(self, user_id: i64) -> HashSet { + pub fn select_blocking_users(self, user_id: Id) -> HashSet { self.0 .get() .unwrap() @@ -138,11 +139,11 @@ UNION SELECT target_account_id SELECT account_id FROM blocks WHERE target_account_id = $1", - &[&user_id], + &[&*user_id], ) .expect("Hard-coded query will return Some([0 or more rows])") .iter() - .map(|row| row.get(0)) + .map(|row| Id(row.get(0))) .collect() } @@ -150,13 +151,13 @@ SELECT account_id /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. - pub fn select_blocked_domains(self, user_id: i64) -> HashSet { + pub fn select_blocked_domains(self, user_id: Id) -> HashSet { self.0 .get() .unwrap() .query( "SELECT domain FROM account_domain_blocks WHERE account_id = $1", - &[&user_id], + &[&*user_id], ) .expect("Hard-coded query will return Some([0 or more rows])") .iter() @@ -165,7 +166,7 @@ SELECT account_id } /// Test whether a user owns a list - pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool { + pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool { let mut conn = self.0.get().unwrap(); // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn @@ -181,10 +182,7 @@ LIMIT 1", match rows.get(0) { None => false, - Some(row) => { - let list_owner_id: i64 = row.get(1); - list_owner_id == user_id - } + Some(row) => Id(row.get(1)) == user_id, } } } diff --git a/src/parse_client_request/subscription.rs b/src/parse_client_request/subscription.rs index 5f796f8..64b2a21 100644 --- a/src/parse_client_request/subscription.rs +++ b/src/parse_client_request/subscription.rs @@ -6,16 +6,15 @@ // #[cfg(not(test))] use super::postgres::PgPool; +use super::query; use super::query::Query; use crate::err::TimelineErr; -use crate::log_fatal; + +use crate::messages::Id; + use hashbrown::HashSet; use lru::LruCache; -use uuid::Uuid; -use warp::reject::Rejection; - -use super::query; -use warp::{filters::BoxedFilter, path, Filter}; +use warp::{filters::BoxedFilter, path, reject::Rejection, Filter}; /// Helper macro to match on the first of any of the provided filters macro_rules! any_of { @@ -51,7 +50,6 @@ macro_rules! parse_sse_query { #[derive(Clone, Debug, PartialEq)] pub struct Subscription { - pub id: Uuid, pub timeline: Timeline, pub allowed_langs: HashSet, pub blocks: Blocks, @@ -62,14 +60,13 @@ pub struct Subscription { #[derive(Clone, Default, Debug, PartialEq)] pub struct Blocks { pub blocked_domains: HashSet, - pub blocked_users: HashSet, - pub blocking_users: HashSet, + pub blocked_users: HashSet, + pub blocking_users: HashSet, } impl Default for Subscription { fn default() -> Self { Self { - id: Uuid::new_v4(), timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification), allowed_langs: HashSet::new(), blocks: Blocks::default(), @@ -133,7 +130,6 @@ impl Subscription { }; Ok(Subscription { - id: Uuid::new_v4(), timeline, allowed_langs: user.allowed_langs, blocks: Blocks { @@ -182,30 +178,28 @@ impl Timeline { Self(Unset, Local, Notification) } - pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> String { + pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result { use {Content::*, Reach::*, Stream::*}; - match self { + Ok(match self { Timeline(Public, Federated, All) => "timeline:public".into(), Timeline(Public, Local, All) => "timeline:public:local".into(), Timeline(Public, Federated, Media) => "timeline:public:media".into(), Timeline(Public, Local, Media) => "timeline:public:local:media".into(), - Timeline(Hashtag(id), Federated, All) => format!( + Timeline(Hashtag(_id), Federated, All) => format!( "timeline:hashtag:{}", - hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) + hashtag.ok_or_else(|| TimelineErr::MissingHashtag)? ), - Timeline(Hashtag(id), Local, All) => format!( + Timeline(Hashtag(_id), Local, All) => format!( "timeline:hashtag:{}:local", - hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) + hashtag.ok_or_else(|| TimelineErr::MissingHashtag)? ), Timeline(User(id), Federated, All) => format!("timeline:{}", id), Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id), Timeline(List(id), Federated, All) => format!("timeline:list:{}", id), Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id), - Timeline(one, _two, _three) => { - log_fatal!("Supposedly impossible timeline reached: {:?}", one) - } - } + Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?, + }) } pub fn from_redis_text( @@ -225,10 +219,10 @@ impl Timeline { ["public", "local", "media"] => Timeline(Public, Local, Media), ["hashtag", tag] => Timeline(Hashtag(id_from_tag(tag)?), Federated, All), ["hashtag", tag, "local"] => Timeline(Hashtag(id_from_tag(tag)?), Local, All), - [id] => Timeline(User(id.parse().unwrap()), Federated, All), - [id, "notification"] => Timeline(User(id.parse().unwrap()), Federated, Notification), - ["list", id] => Timeline(List(id.parse().unwrap()), Federated, All), - ["direct", id] => Timeline(Direct(id.parse().unwrap()), Federated, All), + [id] => Timeline(User(id.parse()?), Federated, All), + [id, "notification"] => Timeline(User(id.parse()?), Federated, Notification), + ["list", id] => Timeline(List(id.parse()?), Federated, All), + ["direct", id] => Timeline(Direct(id.parse()?), Federated, All), // Other endpoints don't exist: [..] => Err(TimelineErr::InvalidInput)?, }) @@ -266,7 +260,7 @@ impl Timeline { false => Err(warp::reject::custom("Error: Missing access token"))?, }, "direct" => match user.scopes.contains(&Statuses) { - true => Timeline(Direct(user.id), Federated, All), + true => Timeline(Direct(*user.id), Federated, All), false => Err(custom("Error: Missing access token"))?, }, other => { @@ -279,7 +273,8 @@ impl Timeline { #[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] pub enum Stream { - User(i64), + User(Id), + // TODO consider whether List, Direct, and Hashtag should all be `id::Id`s List(i64), Direct(i64), Hashtag(i64), @@ -309,7 +304,7 @@ pub enum Scope { } pub struct UserData { - pub id: i64, + pub id: Id, pub allowed_langs: HashSet, pub scopes: HashSet, } @@ -317,7 +312,7 @@ pub struct UserData { impl UserData { fn public() -> Self { Self { - id: -1, + id: Id(-1), allowed_langs: HashSet::new(), scopes: HashSet::new(), } diff --git a/src/redis_to_client_stream/event_stream.rs b/src/redis_to_client_stream/event_stream.rs index 35a86bc..abd13e0 100644 --- a/src/redis_to_client_stream/event_stream.rs +++ b/src/redis_to_client_stream/event_stream.rs @@ -56,7 +56,7 @@ impl WsStream { if matches!(event, Event::Ping) { self.send_ping() } else if target_timeline == tl { - use crate::messages::{CheckedEvent::Update, Event::*}; + use crate::messages::{CheckedEvent::Update, Event::*, EventKind}; use crate::parse_client_request::Stream::Public; let blocks = &self.subscription.blocks; let allowed_langs = &self.subscription.allowed_langs; @@ -68,12 +68,17 @@ impl WsStream { _ => self.send_msg(TypeSafe(Update { payload, queued_at })), }, TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)), - Dynamic(event) if event.event == "update" => match tl { - Timeline(Public, _, _) if event.language_not(allowed_langs) => Ok(()), - _ if event.involves_any(&blocks) => Ok(()), - _ => self.send_msg(Dynamic(event)), - }, - Dynamic(non_update) => self.send_msg(Dynamic(non_update)), + Dynamic(dyn_event) => { + if let EventKind::Update(s) = dyn_event.kind.clone() { + match tl { + Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()), + _ if s.involves_any(&blocks) => Ok(()), + _ => self.send_msg(Dynamic(dyn_event)), + } + } else { + self.send_msg(Dynamic(dyn_event)) + } + } Ping => unreachable!(), // handled pings above } } else { @@ -95,7 +100,9 @@ impl WsStream { match self.ws_tx.try_send(Message::text(txt)) { Ok(_) => Ok(()), Err(_) => { - self.unsubscribe_tx.try_send(tl).expect("TODO"); + self.unsubscribe_tx + .try_send(tl) + .unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e)); Err(()) } } @@ -125,7 +132,10 @@ impl SseStream { let event_stream = sse_rx .filter_map(move |(timeline, event)| { if target_timeline == timeline { - use crate::messages::{CheckedEvent, CheckedEvent::Update, Event::*}; + use crate::messages::{ + CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind, + }; + use crate::parse_client_request::Stream::Public; match event { TypeSafe(Update { payload, queued_at }) => match timeline { @@ -137,12 +147,22 @@ impl SseStream { })), }, TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)), - Dynamic(event) if event.event == "update" => match timeline { - Timeline(Public, _, _) if event.language_not(&allowed_langs) => None, - _ if event.involves_any(&blocks) => None, - _ => Self::reply_with(Event::Dynamic(event)), - }, - Dynamic(non_update) => Self::reply_with(Event::Dynamic(non_update)), + Dynamic(dyn_event) => { + if let EventKind::Update(s) = dyn_event.kind { + match timeline { + Timeline(Public, _, _) if s.language_not(&allowed_langs) => { + None + } + _ if s.involves_any(&blocks) => None, + _ => Self::reply_with(Dynamic(DynEvent { + kind: EventKind::Update(s), + ..dyn_event + })), + } + } else { + None + } + } Ping => None, // pings handled automatically } } else { @@ -150,7 +170,9 @@ impl SseStream { } }) .then(move |res| { - unsubscribe_tx.try_send(target_timeline).expect("TODO"); + unsubscribe_tx + .try_send(target_timeline) + .unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e)); res }); diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs index f8b88b9..19ed558 100644 --- a/src/redis_to_client_stream/mod.rs +++ b/src/redis_to_client_stream/mod.rs @@ -5,7 +5,7 @@ mod redis; pub use { event_stream::{SseStream, WsStream}, - receiver::Receiver, + receiver::{Receiver, ReceiverErr}, }; #[cfg(feature = "bench")] diff --git a/src/redis_to_client_stream/receiver/err.rs b/src/redis_to_client_stream/receiver/err.rs index 4921145..37e9d66 100644 --- a/src/redis_to_client_stream/receiver/err.rs +++ b/src/redis_to_client_stream/receiver/err.rs @@ -1,18 +1,21 @@ use super::super::redis::{RedisConnErr, RedisParseErr}; use crate::err::TimelineErr; +use crate::messages::{Event, EventErr}; +use crate::parse_client_request::Timeline; -use serde_json; use std::fmt; - #[derive(Debug)] pub enum ReceiverErr { InvalidId, TimelineErr(TimelineErr), - EventErr(serde_json::Error), + EventErr(EventErr), RedisParseErr(RedisParseErr), RedisConnErr(RedisConnErr), + ChannelSendErr(tokio::sync::watch::error::SendError<(Timeline, Event)>), } +impl std::error::Error for ReceiverErr {} + impl fmt::Display for ReceiverErr { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { use ReceiverErr::*; @@ -25,13 +28,20 @@ impl fmt::Display for ReceiverErr { RedisParseErr(inner) => write!(f, "{}", inner), RedisConnErr(inner) => write!(f, "{}", inner), TimelineErr(inner) => write!(f, "{}", inner), + ChannelSendErr(inner) => write!(f, "{}", inner), }?; Ok(()) } } -impl From for ReceiverErr { - fn from(error: serde_json::Error) -> Self { +impl From> for ReceiverErr { + fn from(error: tokio::sync::watch::error::SendError<(Timeline, Event)>) -> Self { + Self::ChannelSendErr(error) + } +} + +impl From for ReceiverErr { + fn from(error: EventErr) -> Self { Self::EventErr(error) } } diff --git a/src/redis_to_client_stream/receiver/mod.rs b/src/redis_to_client_stream/receiver/mod.rs index 3d7bed3..cb70387 100644 --- a/src/redis_to_client_stream/receiver/mod.rs +++ b/src/redis_to_client_stream/receiver/mod.rs @@ -18,7 +18,7 @@ use tokio::sync::{mpsc, watch}; use std::{ result, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, MutexGuard, PoisonError}, time::{Duration, Instant}, }; @@ -59,7 +59,6 @@ impl Receiver { pub fn subscribe(&mut self, subscription: &Subscription) -> Result<()> { let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline); - if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (tag, tl) { self.redis_connection.update_cache(hashtag, id); }; @@ -74,7 +73,6 @@ impl Receiver { if *number_of_subscriptions == 1 { self.redis_connection.send_cmd(Subscribe, &tl)? }; - log::info!("Started stream for {:?}", tl); Ok(()) } @@ -99,36 +97,40 @@ impl Receiver { Ok(()) } - pub fn poll_broadcast(&mut self) { + pub fn poll_broadcast(&mut self) -> Result<()> { while let Ok(Async::Ready(Some(tl))) = self.rx.poll() { - self.unsubscribe(tl).expect("TODO"); + self.unsubscribe(tl)? } if self.ping_time.elapsed() > Duration::from_secs(30) { self.ping_time = Instant::now(); - self.tx - .broadcast((Timeline::empty(), Event::Ping)) - .expect("TODO"); + self.tx.broadcast((Timeline::empty(), Event::Ping))? } else { match self.redis_connection.poll_redis() { Ok(Async::NotReady) => (), Ok(Async::Ready(Some((timeline, event)))) => { - self.tx.broadcast((timeline, event)).expect("TODO"); + self.tx.broadcast((timeline, event))? } Ok(Async::Ready(None)) => (), // subscription cmd or msg for other namespace - Err(_err) => panic!("TODO"), + Err(err) => log::error!("{}", err), // drop msg, log err, and proceed } } + Ok(()) } - pub fn count_connections(&self) -> String { + pub fn recover(poisoned: PoisonError>) -> MutexGuard { + log::error!("{}", &poisoned); + poisoned.into_inner() + } + + pub fn count(&self) -> String { format!( "Current connections: {}", self.clients_per_timeline.values().sum::() ) } - pub fn list_connections(&self) -> String { + pub fn list(&self) -> String { let max_len = self .clients_per_timeline .keys() diff --git a/src/redis_to_client_stream/redis/redis_connection/err.rs b/src/redis_to_client_stream/redis/redis_connection/err.rs index ddaf3b4..bb702c2 100644 --- a/src/redis_to_client_stream/redis/redis_connection/err.rs +++ b/src/redis_to_client_stream/redis/redis_connection/err.rs @@ -1,3 +1,4 @@ +use crate::err::TimelineErr; use std::fmt; #[derive(Debug)] @@ -8,6 +9,7 @@ pub enum RedisConnErr { IncorrectPassword(String), MissingPassword, NotRedis(String), + TimelineErr(TimelineErr), } impl RedisConnErr { @@ -49,11 +51,18 @@ impl fmt::Display for RedisConnErr { REDIS_PORT environmental variables and try again.", addr ), + TimelineErr(inner) => format!("{}", inner), }; write!(f, "{}", msg) } } +impl From for RedisConnErr { + fn from(e: TimelineErr) -> RedisConnErr { + RedisConnErr::TimelineErr(e) + } +} + impl From for RedisConnErr { fn from(e: std::io::Error) -> RedisConnErr { RedisConnErr::UnknownRedisErr(e) diff --git a/src/redis_to_client_stream/redis/redis_connection/mod.rs b/src/redis_to_client_stream/redis/redis_connection/mod.rs index 872fac0..a63c03c 100644 --- a/src/redis_to_client_stream/redis/redis_connection/mod.rs +++ b/src/redis_to_client_stream/redis/redis_connection/mod.rs @@ -10,7 +10,7 @@ use crate::{ }; use std::{ - convert::TryFrom, + convert::{TryFrom, TryInto}, io::{Read, Write}, net::TcpStream, str, @@ -92,13 +92,13 @@ impl RedisConn { Some(ns) if msg.timeline_txt.starts_with(&format!("{}:timeline:", ns)) => { let trimmed_tl_txt = &msg.timeline_txt[ns.len() + ":timeline:".len()..]; let tl = Timeline::from_redis_text(trimmed_tl_txt, &mut self.tag_id_cache)?; - let event = msg.event_txt.into(); + let event = msg.event_txt.try_into()?; (Ok(Ready(Some((tl, event)))), msg.leftover_input) } None => { let trimmed_tl_txt = &msg.timeline_txt["timeline:".len()..]; let tl = Timeline::from_redis_text(trimmed_tl_txt, &mut self.tag_id_cache)?; - let event = msg.event_txt.into(); + let event = msg.event_txt.try_into()?; (Ok(Ready(Some((tl, event)))), msg.leftover_input) } Some(_non_matching_namespace) => (Ok(Ready(None)), msg.leftover_input), @@ -166,8 +166,8 @@ impl RedisConn { Timeline(Stream::Hashtag(id), _, _) => self.tag_name_cache.get(id), _non_hashtag_timeline => None, }; - let tl = timeline.to_redis_raw_timeline(hashtag); + let tl = timeline.to_redis_raw_timeline(hashtag)?; let (primary_cmd, secondary_cmd) = match cmd { RedisCmd::Subscribe => ( format!("*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl),