From 638364883f98b3dd7acf5e29969f6f0d3134eb06 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Fri, 10 Apr 2020 17:06:13 -0400 Subject: [PATCH] Add additional error handling --- old | 447 ------------------ src/messages/event/checked_event/id.rs | 36 +- src/messages/event/checked_event/mod.rs | 1 + .../event/checked_event/status/mod.rs | 19 +- src/messages/event/dynamic_event/mod.rs | 136 ++++++ src/messages/event/err.rs | 33 ++ src/messages/event/mod.rs | 32 +- src/messages/mod.rs | 2 +- src/parse_client_request/postgres.rs | 34 +- src/parse_client_request/subscription.rs | 15 +- src/redis_to_client_stream/event_stream.rs | 43 +- src/redis_to_client_stream/receiver/err.rs | 11 +- 12 files changed, 295 insertions(+), 514 deletions(-) delete mode 100644 old create mode 100644 src/messages/event/dynamic_event/mod.rs create mode 100644 src/messages/event/err.rs diff --git a/old b/old deleted file mode 100644 index bc8a053..0000000 --- a/old +++ /dev/null @@ -1,447 +0,0 @@ -use crate::log_fatal; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::boxed::Box; -use std::{collections::HashSet, string::String}; - -pub enum Event { - TypeSafe(CheckedEvent), - Dynamic(DynamicEvent), -} - -impl Event { - pub fn to_json_string(&self) -> String { - let event = &self.event_name(); - let sendable_event = match self.payload() { - Some(payload) => SendableEvent::WithPayload { event, payload }, - None => SendableEvent::NoPayload { event }, - }; - serde_json::to_string(&sendable_event) - .unwrap_or_else(|_| log_fatal!("Could not serialize `{:?}`", &sendable_event)) - } - - pub fn event_name(&self) -> String { - String::from(match self { - Self::TypeSafe(checked) => match checked { - CheckedEvent::Update { .. } => "update", - CheckedEvent::Notification { .. } => "notification", - CheckedEvent::Delete { .. } => "delete", - CheckedEvent::Announcement { .. } => "announcement", - CheckedEvent::AnnouncementReaction { .. } => "announcement.reaction", - CheckedEvent::AnnouncementDelete { .. } => "announcement.delete", - CheckedEvent::Conversation { .. } => "conversation", - CheckedEvent::FiltersChanged => "filters_changed", - }, - Self::Dynamic(dyn_event) => &dyn_event.event, - }) - } - - pub fn payload(&self) -> Option { - use CheckedEvent::*; - match self { - Self::TypeSafe(checked) => match checked { - Update { payload, .. } => Some(escaped(payload)), - Notification { payload, .. } => Some(escaped(payload)), - Delete { payload, .. } => Some(payload.0.clone()), - Announcement { payload, .. } => Some(escaped(payload)), - AnnouncementReaction { payload, .. } => Some(escaped(payload)), - AnnouncementDelete { payload, .. } => Some(payload.0.clone()), - Conversation { payload, .. } => Some(escaped(payload)), - FiltersChanged => None, - }, - Self::Dynamic(dyn_event) => Some(dyn_event.payload.to_string()), - } - } -} - -#[derive(Deserialize, Debug, Clone, PartialEq)] -pub struct DynamicEvent { - pub event: String, - payload: Value, - queued_at: Option, -} - -#[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)] -#[rustfmt::skip] -#[derive(Deserialize, Debug, Clone, PartialEq)] -pub enum CheckedEvent { - Update { payload: Status, queued_at: Option }, - Notification { payload: Notification }, - Delete { payload: DeletedId }, - FiltersChanged, - Announcement { payload: Announcement }, - #[serde(rename(serialize = "announcement.reaction", deserialize = "announcement.reaction"))] - AnnouncementReaction { payload: AnnouncementReaction }, - #[serde(rename(serialize = "announcement.delete", deserialize = "announcement.delete"))] - AnnouncementDelete { payload: DeletedId }, - Conversation { payload: Conversation, queued_at: Option }, -} - -#[derive(Serialize, Debug, Clone)] -#[serde(untagged)] -pub enum SendableEvent<'a> { - WithPayload { event: &'a str, payload: String }, - NoPayload { event: &'a str }, -} - -fn escaped(content: T) -> String { - serde_json::to_string(&content) - .unwrap_or_else(|_| log_fatal!("Could not parse Event with: `{:?}`", &content)) -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Conversation { - id: String, - accounts: Vec, - unread: bool, - last_status: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct DeletedId(String); - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Status { - id: String, - uri: String, - created_at: String, - account: Account, - content: String, - visibility: Visibility, - sensitive: bool, - spoiler_text: String, - media_attachments: Vec, - application: Option, // Should be non-optional? - mentions: Vec, - tags: Vec, - emojis: Vec, - reblogs_count: i64, - favourites_count: i64, - replies_count: i64, - url: Option, - in_reply_to_id: Option, - in_reply_to_account_id: Option, - reblog: Option>, - poll: Option, - card: Option, - language: Option, - text: Option, - // ↓↓↓ Only for authorized users - favourited: Option, - reblogged: Option, - muted: Option, - bookmarked: Option, - pinned: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum Visibility { - Public, - Unlisted, - Private, - Direct, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Account { - id: String, - username: String, - acct: String, - url: String, - display_name: String, - note: String, - avatar: String, - avatar_static: String, - header: String, - header_static: String, - locked: bool, - emojis: Vec, - discoverable: Option, // Shouldn't be option? - created_at: String, - statuses_count: i64, - followers_count: i64, - following_count: i64, - moved: Option>, - fields: Option>, - bot: Option, - source: Option, - group: Option, // undocumented - last_status_at: Option, // undocumented -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Attachment { - id: String, - r#type: AttachmentType, - url: String, - preview_url: String, - remote_url: Option, - text_url: Option, - meta: Option, - description: Option, - blurhash: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -enum AttachmentType { - Unknown, - Image, - Gifv, - Video, - Audio, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Application { - name: String, - website: Option, - vapid_key: Option, - client_id: Option, - client_secret: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Emoji { - shortcode: String, - url: String, - static_url: String, - visible_in_picker: bool, - category: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Field { - name: String, - value: String, - verified_at: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Source { - note: String, - fields: Vec, - privacy: Option, - sensitive: bool, - language: String, - follow_requests_count: i64, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Mention { - id: String, - username: String, - acct: String, - url: String, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Tag { - name: String, - url: String, - history: Option>, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Poll { - id: String, - expires_at: String, - expired: bool, - multiple: bool, - votes_count: i64, - voters_count: Option, - voted: Option, - own_votes: Option>, - options: Vec, - emojis: Vec, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct PollOptions { - title: String, - votes_count: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Card { - url: String, - title: String, - description: String, - r#type: CardType, - author_name: Option, - author_url: Option, - provider_name: Option, - provider_url: Option, - html: Option, - width: Option, - height: Option, - image: Option, - embed_url: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -enum CardType { - Link, - Photo, - Video, - Rich, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct History { - day: String, - uses: String, - accounts: String, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Notification { - id: String, - r#type: NotificationType, - created_at: String, - account: Account, - status: Option, -} - -#[serde(rename_all = "snake_case", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -enum NotificationType { - Follow, - FollowRequest, // Undocumented - Mention, - Reblog, - Favourite, - Poll, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Announcement { - // Fully undocumented - id: String, - tags: Vec, - all_day: bool, - content: String, - emojis: Vec, - starts_at: Option, - ends_at: Option, - published_at: String, - updated_at: String, - mentions: Vec, - reactions: Vec, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct AnnouncementReaction { - #[serde(skip_serializing_if = "Option::is_none")] - announcement_id: Option, - count: i64, - name: String, -} - -impl Status { - /// Returns `true` if the status is filtered out based on its language - pub fn language_not_allowed(&self, allowed_langs: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - - let reject_and_maybe_log = |toot_language| { - log::info!("Filtering out toot from `{}`", &self.account.acct); - log::info!("Toot language: `{}`", toot_language); - log::info!("Recipient's allowed languages: `{:?}`", allowed_langs); - REJECT - }; - if allowed_langs.is_empty() { - return ALLOW; // listing no allowed_langs results in allowing all languages - } - - match self.language.as_ref() { - Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW, - None => ALLOW, // If toot language is unknown, toot is always allowed - Some(empty) if empty == &String::new() => ALLOW, - Some(toot_language) => reject_and_maybe_log(toot_language), - } - } - - /// Returns `true` if this toot originated from a domain the User has blocked. - pub fn from_blocked_domain(&self, blocked_domains: &HashSet) -> bool { - let full_username = &self.account.acct; - - match full_username.split('@').nth(1) { - Some(originating_domain) => blocked_domains.contains(originating_domain), - None => false, // None means the user is on the local instance, which can't be blocked - } - } - - /// Returns `true` if the Status is from an account that has blocked the current user. - pub fn from_blocking_user(&self, blocking_users: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - let err = |_| log_fatal!("Could not process `account.id` in {:?}", &self); - - if blocking_users.contains(&self.account.id.parse().unwrap_or_else(err)) { - REJECT - } else { - ALLOW - } - } - - /// Returns `true` if the User's list of blocked and muted users includes a user - /// involved in this toot. - /// - /// A user is involved if they: - /// * Are mentioned in this toot - /// * Wrote this toot - /// * Wrote a toot that this toot is replying to (if any) - /// * Wrote the toot that this toot is boosting (if any) - pub fn involves_blocked_user(&self, blocked_users: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self); - - // involved_users = mentioned_users + author + replied-to user + boosted user - let mut involved_users: HashSet = self - .mentions - .iter() - .map(|mention| mention.id.parse().unwrap_or_else(err)) - .collect(); - - involved_users.insert(self.account.id.parse::().unwrap_or_else(err)); - - if let Some(replied_to_account_id) = self.in_reply_to_account_id.clone() { - involved_users.insert(replied_to_account_id.parse().unwrap_or_else(err)); - } - - if let Some(boosted_status) = self.reblog.clone() { - involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err)); - } - - if involved_users.is_disjoint(blocked_users) { - ALLOW - } else { - REJECT - } - } -} - -#[cfg(test)] -mod test; diff --git a/src/messages/event/checked_event/id.rs b/src/messages/event/checked_event/id.rs index e773a7f..8eb031a 100644 --- a/src/messages/event/checked_event/id.rs +++ b/src/messages/event/checked_event/id.rs @@ -1,17 +1,42 @@ +use super::super::EventErr; + use serde::{ de::{self, Visitor}, Deserialize, Deserializer, Serialize, Serializer, }; -use std::fmt; +use serde_json::Value; +use std::{convert::TryFrom, fmt, num::ParseIntError, str::FromStr}; /// A user ID. /// /// Internally, Mastodon IDs are i64s, but are sent to clients as string because /// JavaScript numbers don't support i64s. This newtype serializes to/from a string, but /// keeps the i64 as the "true" value for internal use. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub struct Id(pub i64); +impl TryFrom<&Value> for Id { + type Error = EventErr; + + fn try_from(v: &Value) -> Result { + Ok(v.as_str().ok_or(EventErr::DynParse)?.parse()?) + } +} + +impl std::ops::Deref for Id { + type Target = i64; + fn deref(&self) -> &i64 { + &self.0 + } +} +impl FromStr for Id { + type Err = ParseIntError; + + fn from_str(s: &str) -> Result { + Ok(Self(s.parse()?)) + } +} + impl Serialize for Id { fn serialize(&self, serializer: S) -> Result where @@ -38,6 +63,13 @@ impl<'de> Visitor<'de> for IdVisitor { formatter.write_str("a string that can be parsed into an i64") } + fn visit_str(self, value: &str) -> Result { + match value.parse() { + Ok(n) => Ok(Id(n)), + Err(e) => Err(E::custom(format!("could not parse: {}", e))), + } + } + fn visit_string(self, value: String) -> Result { match value.parse() { Ok(n) => Ok(Id(n)), diff --git a/src/messages/event/checked_event/mod.rs b/src/messages/event/checked_event/mod.rs index 4702cb1..ac78809 100644 --- a/src/messages/event/checked_event/mod.rs +++ b/src/messages/event/checked_event/mod.rs @@ -14,6 +14,7 @@ mod visibility; pub use announcement::Announcement; pub(in crate::messages::event) use announcement_reaction::AnnouncementReaction; pub use conversation::Conversation; +pub use id::Id; pub use notification::Notification; pub use status::Status; diff --git a/src/messages/event/checked_event/status/mod.rs b/src/messages/event/checked_event/status/mod.rs index 9a118e8..e517870 100644 --- a/src/messages/event/checked_event/status/mod.rs +++ b/src/messages/event/checked_event/status/mod.rs @@ -92,7 +92,7 @@ impl Status { blocking_users, blocked_domains, } = blocks; - let user_id = &self.account.id.0; + let user_id = &Id(self.account.id.0); if blocking_users.contains(user_id) || self.involves(blocked_users) { REJECT @@ -105,20 +105,23 @@ impl Status { } } - fn involves(&self, blocked_users: &HashSet) -> bool { + fn involves(&self, blocked_users: &HashSet) -> bool { // involved_users = mentioned_users + author + replied-to user + boosted user - let mut involved_users: HashSet = - self.mentions.iter().map(|mention| mention.id.0).collect(); + let mut involved_users: HashSet = self + .mentions + .iter() + .map(|mention| Id(mention.id.0)) + .collect(); // author - involved_users.insert(self.account.id.0); + involved_users.insert(Id(self.account.id.0)); // replied-to user - if let Some(user_id) = self.in_reply_to_account_id.clone() { - involved_users.insert(user_id.0); + if let Some(user_id) = self.in_reply_to_account_id { + involved_users.insert(Id(user_id.0)); } // boosted user if let Some(boosted_status) = self.reblog.clone() { - involved_users.insert(boosted_status.account.id.0); + involved_users.insert(Id(boosted_status.account.id.0)); } !involved_users.is_disjoint(blocked_users) } diff --git a/src/messages/event/dynamic_event/mod.rs b/src/messages/event/dynamic_event/mod.rs new file mode 100644 index 0000000..adcc150 --- /dev/null +++ b/src/messages/event/dynamic_event/mod.rs @@ -0,0 +1,136 @@ +use super::{EventErr, Id}; +use crate::parse_client_request::Blocks; + +use std::convert::TryFrom; + +use hashbrown::HashSet; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct DynEvent { + #[serde(skip)] + pub kind: EventKind, + pub event: String, + pub payload: Value, + pub queued_at: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum EventKind { + Update(DynStatus), + NonUpdate, +} +impl Default for EventKind { + fn default() -> Self { + Self::NonUpdate + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DynStatus { + pub id: Id, + pub username: String, + pub language: Option, + pub mentioned_users: HashSet, + pub replied_to_user: Option, + pub boosted_user: Option, + pub payload: Value, +} + +type Result = std::result::Result; // TODO cut if not used more than once + +impl DynEvent { + pub fn set_update(self) -> Result { + if self.event == "update" { + let kind = EventKind::Update(DynStatus::new(self.payload.clone())?); + Ok(Self { kind, ..self }) + } else { + Ok(self) + } + } +} + +impl DynStatus { + pub fn new(payload: Value) -> Result { + use EventErr::*; + + Ok(Self { + id: Id::try_from(&payload["account"]["id"])?, + username: payload["account"]["acct"] + .as_str() + .ok_or(DynParse)? + .to_string(), + language: payload["language"].as_str().map(|s| s.to_string()), + mentioned_users: HashSet::new(), + replied_to_user: Id::try_from(&payload["in_reply_to_account_id"]).ok(), + boosted_user: Id::try_from(&payload["reblog"]["account"]["id"]).ok(), + payload, + }) + } + /// Returns `true` if the status is filtered out based on its language + pub fn language_not(&self, allowed_langs: &HashSet) -> bool { + const ALLOW: bool = false; + const REJECT: bool = true; + + if allowed_langs.is_empty() { + return ALLOW; // listing no allowed_langs results in allowing all languages + } + + match self.language.clone() { + Some(toot_language) if allowed_langs.contains(&toot_language) => ALLOW, // + None => ALLOW, // If toot language is unknown, toot is always allowed + Some(empty) if empty == String::new() => ALLOW, + Some(_toot_language) => REJECT, + } + } + + /// Returns `true` if the toot contained in this Event originated from a blocked domain, + /// is from an account that has blocked the current user, or if the User's list of + /// blocked/muted users includes a user involved in the toot. + /// + /// A user is involved in the toot if they: + /// * Are mentioned in this toot + /// * Wrote this toot + /// * Wrote a toot that this toot is replying to (if any) + /// * Wrote the toot that this toot is boosting (if any) + pub fn involves_any(&self, blocks: &Blocks) -> bool { + const ALLOW: bool = false; + const REJECT: bool = true; + let Blocks { + blocked_users, + blocking_users, + blocked_domains, + } = blocks; + + if self.involves(blocked_users) || blocking_users.contains(&self.id) { + REJECT + } else { + match self.username.split('@').nth(1) { + Some(originating_domain) if blocked_domains.contains(originating_domain) => REJECT, + Some(_) | None => ALLOW, // None means the local instance, which can't be blocked + } + } + } + + // involved_users = mentioned_users + author + replied-to user + boosted user + fn involves(&self, blocked_users: &HashSet) -> bool { + // mentions + let mut involved_users: HashSet = self.mentioned_users.clone(); + + // author + involved_users.insert(self.id); + + // replied-to user + if let Some(user_id) = self.replied_to_user { + involved_users.insert(user_id); + } + + // boosted user + if let Some(user_id) = self.boosted_user { + involved_users.insert(user_id); + } + + !involved_users.is_disjoint(blocked_users) + } +} diff --git a/src/messages/event/err.rs b/src/messages/event/err.rs new file mode 100644 index 0000000..c51b6a3 --- /dev/null +++ b/src/messages/event/err.rs @@ -0,0 +1,33 @@ +use std::{fmt, num::ParseIntError}; + +#[derive(Debug)] +pub enum EventErr { + SerdeParse(serde_json::Error), + NonNumId(ParseIntError), + DynParse, +} + +impl std::error::Error for EventErr {} + +impl fmt::Display for EventErr { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + use EventErr::*; + match self { + SerdeParse(inner) => write!(f, "{}", inner), + NonNumId(inner) => write!(f, "ID could not be parsed: {}", inner), + DynParse => write!(f, "Could not find a required field in input JSON"), + }?; + Ok(()) + } +} + +impl From for EventErr { + fn from(error: ParseIntError) -> Self { + Self::NonNumId(error) + } +} +impl From for EventErr { + fn from(error: serde_json::Error) -> Self { + Self::SerdeParse(error) + } +} diff --git a/src/messages/event/mod.rs b/src/messages/event/mod.rs index b244f51..187fe82 100644 --- a/src/messages/event/mod.rs +++ b/src/messages/event/mod.rs @@ -1,17 +1,21 @@ mod checked_event; mod dynamic_event; +mod err; -pub use {checked_event::CheckedEvent, dynamic_event::DynamicEvent}; +pub use { + checked_event::{CheckedEvent, Id}, + dynamic_event::{DynEvent, DynStatus, EventKind}, + err::EventErr, +}; use crate::log_fatal; -use crate::redis_to_client_stream::ReceiverErr; use serde::Serialize; use std::{convert::TryFrom, string::String}; #[derive(Debug, Clone)] pub enum Event { TypeSafe(CheckedEvent), - Dynamic(DynamicEvent), + Dynamic(DynEvent), Ping, } @@ -38,7 +42,11 @@ impl Event { CheckedEvent::Conversation { .. } => "conversation", CheckedEvent::FiltersChanged => "filters_changed", }, - Self::Dynamic(dyn_event) => &dyn_event.event, + Self::Dynamic(DynEvent { + kind: EventKind::Update(_), + .. + }) => "update", + Self::Dynamic(DynEvent { event, .. }) => event, Self::Ping => panic!("event_name() called on EventNotReady"), }) } @@ -56,21 +64,23 @@ impl Event { Conversation { payload, .. } => Some(escaped(payload)), FiltersChanged => None, }, - Self::Dynamic(dyn_event) => Some(dyn_event.payload.to_string()), + Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()), Self::Ping => panic!("payload() called on EventNotReady"), } } } impl TryFrom for Event { - type Error = ReceiverErr; - fn try_from(event_txt: String) -> Result { + type Error = EventErr; + + fn try_from(event_txt: String) -> Result { Event::try_from(event_txt.as_str()) } } impl TryFrom<&str> for Event { - type Error = ReceiverErr; - fn try_from(event_txt: &str) -> Result { + type Error = EventErr; + + fn try_from(event_txt: &str) -> Result { match serde_json::from_str(event_txt) { Ok(checked_event) => Ok(Event::TypeSafe(checked_event)), Err(e) => { @@ -80,8 +90,8 @@ impl TryFrom<&str> for Event { Forwarding Redis payload without type checking it.", e ); - let dyn_event: DynamicEvent = serde_json::from_str(&event_txt)?; - Ok(Event::Dynamic(dyn_event)) + + Ok(Event::Dynamic(serde_json::from_str(&event_txt)?)) } } } diff --git a/src/messages/mod.rs b/src/messages/mod.rs index cab529c..2599a30 100644 --- a/src/messages/mod.rs +++ b/src/messages/mod.rs @@ -1,3 +1,3 @@ mod event; -pub use event::{CheckedEvent, DynamicEvent, Event}; +pub use event::{CheckedEvent, DynEvent, Event, EventErr, EventKind, Id}; diff --git a/src/parse_client_request/postgres.rs b/src/parse_client_request/postgres.rs index d3120cf..a979e68 100644 --- a/src/parse_client_request/postgres.rs +++ b/src/parse_client_request/postgres.rs @@ -1,6 +1,7 @@ //! Postgres queries use crate::{ config, + messages::Id, parse_client_request::subscription::{Scope, UserData}, }; use ::postgres; @@ -28,6 +29,7 @@ impl PgPool { .expect("Can connect to local postgres"); Self(pool) } + pub fn select_user(self, token: &str) -> Result { let mut conn = self.0.get().unwrap(); let query_rows = conn @@ -45,7 +47,7 @@ LIMIT 1", ) .expect("Hard-coded query will return Some([0 or more rows])"); if let Some(result_columns) = query_rows.get(0) { - let id = result_columns.get(1); + let id = Id(result_columns.get(1)); let allowed_langs = result_columns .try_get::<_, Vec<_>>(2) .unwrap_or_else(|_| Vec::new()) @@ -96,17 +98,16 @@ LIMIT 1", ) .expect("Hard-coded query will return Some([0 or more rows])"); - match rows.get(0) { - Some(row) => Ok(row.get(0)), - None => Err(warp::reject::custom("Error: Hashtag does not exist.")), - } + rows.get(0) + .map(|row| row.get(0)) + .ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist.")) } /// Query Postgres for everyone the user has blocked or muted /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. - pub fn select_blocked_users(self, user_id: i64) -> HashSet { + pub fn select_blocked_users(self, user_id: Id) -> HashSet { self.0 .get() .unwrap() @@ -118,18 +119,18 @@ SELECT target_account_id UNION SELECT target_account_id FROM mutes WHERE account_id = $1", - &[&user_id], + &[&*user_id], ) .expect("Hard-coded query will return Some([0 or more rows])") .iter() - .map(|row| row.get(0)) + .map(|row| Id(row.get(0))) .collect() } /// Query Postgres for everyone who has blocked the user /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. - pub fn select_blocking_users(self, user_id: i64) -> HashSet { + pub fn select_blocking_users(self, user_id: Id) -> HashSet { self.0 .get() .unwrap() @@ -138,11 +139,11 @@ UNION SELECT target_account_id SELECT account_id FROM blocks WHERE target_account_id = $1", - &[&user_id], + &[&*user_id], ) .expect("Hard-coded query will return Some([0 or more rows])") .iter() - .map(|row| row.get(0)) + .map(|row| Id(row.get(0))) .collect() } @@ -150,13 +151,13 @@ SELECT account_id /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. - pub fn select_blocked_domains(self, user_id: i64) -> HashSet { + pub fn select_blocked_domains(self, user_id: Id) -> HashSet { self.0 .get() .unwrap() .query( "SELECT domain FROM account_domain_blocks WHERE account_id = $1", - &[&user_id], + &[&*user_id], ) .expect("Hard-coded query will return Some([0 or more rows])") .iter() @@ -165,7 +166,7 @@ SELECT account_id } /// Test whether a user owns a list - pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool { + pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool { let mut conn = self.0.get().unwrap(); // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn @@ -181,10 +182,7 @@ LIMIT 1", match rows.get(0) { None => false, - Some(row) => { - let list_owner_id: i64 = row.get(1); - list_owner_id == user_id - } + Some(row) => Id(row.get(1)) == user_id, } } } diff --git a/src/parse_client_request/subscription.rs b/src/parse_client_request/subscription.rs index 5f796f8..3ad632b 100644 --- a/src/parse_client_request/subscription.rs +++ b/src/parse_client_request/subscription.rs @@ -9,6 +9,7 @@ use super::postgres::PgPool; use super::query::Query; use crate::err::TimelineErr; use crate::log_fatal; +use crate::messages::Id; use hashbrown::HashSet; use lru::LruCache; use uuid::Uuid; @@ -62,8 +63,8 @@ pub struct Subscription { #[derive(Clone, Default, Debug, PartialEq)] pub struct Blocks { pub blocked_domains: HashSet, - pub blocked_users: HashSet, - pub blocking_users: HashSet, + pub blocked_users: HashSet, + pub blocking_users: HashSet, } impl Default for Subscription { @@ -254,11 +255,11 @@ impl Timeline { "hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All), "hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All), "user" => match user.scopes.contains(&Statuses) { - true => Timeline(User(user.id), Federated, All), + true => Timeline(User(*user.id), Federated, All), false => Err(custom("Error: Missing access token"))?, }, "user:notification" => match user.scopes.contains(&Statuses) { - true => Timeline(User(user.id), Federated, Notification), + true => Timeline(User(*user.id), Federated, Notification), false => Err(custom("Error: Missing access token"))?, }, "list" => match user.scopes.contains(&Lists) && user_owns_list() { @@ -266,7 +267,7 @@ impl Timeline { false => Err(warp::reject::custom("Error: Missing access token"))?, }, "direct" => match user.scopes.contains(&Statuses) { - true => Timeline(Direct(user.id), Federated, All), + true => Timeline(Direct(*user.id), Federated, All), false => Err(custom("Error: Missing access token"))?, }, other => { @@ -309,7 +310,7 @@ pub enum Scope { } pub struct UserData { - pub id: i64, + pub id: Id, pub allowed_langs: HashSet, pub scopes: HashSet, } @@ -317,7 +318,7 @@ pub struct UserData { impl UserData { fn public() -> Self { Self { - id: -1, + id: Id(-1), allowed_langs: HashSet::new(), scopes: HashSet::new(), } diff --git a/src/redis_to_client_stream/event_stream.rs b/src/redis_to_client_stream/event_stream.rs index 666194e..6c44e47 100644 --- a/src/redis_to_client_stream/event_stream.rs +++ b/src/redis_to_client_stream/event_stream.rs @@ -56,7 +56,7 @@ impl WsStream { if matches!(event, Event::Ping) { self.send_ping() } else if target_timeline == tl { - use crate::messages::{CheckedEvent::Update, Event::*}; + use crate::messages::{CheckedEvent::Update, Event::*, EventKind}; use crate::parse_client_request::Stream::Public; let blocks = &self.subscription.blocks; let allowed_langs = &self.subscription.allowed_langs; @@ -68,12 +68,17 @@ impl WsStream { _ => self.send_msg(TypeSafe(Update { payload, queued_at })), }, TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)), - Dynamic(event) if event.event == "update" => match tl { - Timeline(Public, _, _) if event.language_not(allowed_langs) => Ok(()), - _ if event.involves_any(&blocks) => Ok(()), - _ => self.send_msg(Dynamic(event)), - }, - Dynamic(non_update) => self.send_msg(Dynamic(non_update)), + Dynamic(dyn_event) => { + if let EventKind::Update(s) = dyn_event.kind.clone() { + match tl { + Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()), + _ if s.involves_any(&blocks) => Ok(()), + _ => self.send_msg(Dynamic(dyn_event)), + } + } else { + self.send_msg(Dynamic(dyn_event)) + } + } Ping => unreachable!(), // handled pings above } } else { @@ -127,7 +132,10 @@ impl SseStream { let event_stream = sse_rx .filter_map(move |(timeline, event)| { if target_timeline == timeline { - use crate::messages::{CheckedEvent, CheckedEvent::Update, Event::*}; + use crate::messages::{ + CheckedEvent, CheckedEvent::Update, Event::*, EventKind, + }; + use crate::parse_client_request::Stream::Public; match event { TypeSafe(Update { payload, queued_at }) => match timeline { @@ -139,12 +147,19 @@ impl SseStream { })), }, TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)), - Dynamic(event) if event.event == "update" => match timeline { - Timeline(Public, _, _) if event.language_not(&allowed_langs) => None, - _ if event.involves_any(&blocks) => None, - _ => Self::reply_with(Event::Dynamic(event)), - }, - Dynamic(non_update) => Self::reply_with(Event::Dynamic(non_update)), + Dynamic(dyn_event) => { + if let EventKind::Update(s) = dyn_event.kind.clone() { + match timeline { + Timeline(Public, _, _) if s.language_not(&allowed_langs) => { + None + } + _ if s.involves_any(&blocks) => None, + _ => Self::reply_with(Dynamic(dyn_event)), + } + } else { + None + } + } Ping => None, // pings handled automatically } } else { diff --git a/src/redis_to_client_stream/receiver/err.rs b/src/redis_to_client_stream/receiver/err.rs index c4e6efe..37e9d66 100644 --- a/src/redis_to_client_stream/receiver/err.rs +++ b/src/redis_to_client_stream/receiver/err.rs @@ -1,16 +1,14 @@ use super::super::redis::{RedisConnErr, RedisParseErr}; use crate::err::TimelineErr; -use crate::messages::Event; +use crate::messages::{Event, EventErr}; use crate::parse_client_request::Timeline; -use serde_json; use std::fmt; - #[derive(Debug)] pub enum ReceiverErr { InvalidId, TimelineErr(TimelineErr), - EventErr(serde_json::Error), + EventErr(EventErr), RedisParseErr(RedisParseErr), RedisConnErr(RedisConnErr), ChannelSendErr(tokio::sync::watch::error::SendError<(Timeline, Event)>), @@ -35,14 +33,15 @@ impl fmt::Display for ReceiverErr { Ok(()) } } + impl From> for ReceiverErr { fn from(error: tokio::sync::watch::error::SendError<(Timeline, Event)>) -> Self { Self::ChannelSendErr(error) } } -impl From for ReceiverErr { - fn from(error: serde_json::Error) -> Self { +impl From for ReceiverErr { + fn from(error: EventErr) -> Self { Self::EventErr(error) } }