From 81b454c88cb1cc187edf629c0d41abf374fdaad1 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Tue, 31 Mar 2020 09:05:51 -0400 Subject: [PATCH] Extract tests to separate files (#113) This very minor change moves tests from their current location in submodules within the file under test into submodules in separate files. This is a slight deviation from the normal Rust convention (though only very slight, since the module structure remains the same). However, it is justified here since the tests are fairly verbose and including them in the same file was a bit unwieldy. --- Cargo.lock | 2 +- src/messages/mod.rs | 429 +++++++++++++++++ src/{messages.rs => messages/test.rs} | 432 +----------------- src/parse_client_request/mod.rs | 8 +- src/parse_client_request/postgres.rs | 20 +- .../{sse.rs => sse_test.rs} | 0 src/parse_client_request/subscription.rs | 95 +--- .../{ws.rs => ws_test.rs} | 0 .../redis/{redis_msg.rs => redis_msg/mod.rs} | 87 +--- .../redis/redis_msg/test.rs | 54 +++ 10 files changed, 500 insertions(+), 627 deletions(-) create mode 100644 src/messages/mod.rs rename src/{messages.rs => messages/test.rs} (83%) rename src/parse_client_request/{sse.rs => sse_test.rs} (100%) rename src/parse_client_request/{ws.rs => ws_test.rs} (100%) rename src/redis_to_client_stream/redis/{redis_msg.rs => redis_msg/mod.rs} (66%) create mode 100644 src/redis_to_client_stream/redis/redis_msg/test.rs diff --git a/Cargo.lock b/Cargo.lock index 7cb4f27..7c54ed8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -440,7 +440,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.6.5" +version = "0.6.6" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/src/messages/mod.rs b/src/messages/mod.rs new file mode 100644 index 0000000..f162725 --- /dev/null +++ b/src/messages/mod.rs @@ -0,0 +1,429 @@ +use crate::log_fatal; +use serde::{Deserialize, Serialize}; +use serde_json; +use std::boxed::Box; +use std::{collections::HashSet, string::String}; + +#[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)] +#[rustfmt::skip] +#[derive(Deserialize, Debug, Clone, PartialEq)] +pub enum Event { + Update { payload: Status, queued_at: Option }, + Notification { payload: Notification }, + Delete { payload: DeletedId }, + FiltersChanged, + Announcement { payload: Announcement }, + #[serde(rename(serialize = "announcement.reaction", deserialize = "announcement.reaction"))] + AnnouncementReaction { payload: AnnouncementReaction }, + #[serde(rename(serialize = "announcement.delete", deserialize = "announcement.delete"))] + AnnouncementDelete { payload: DeletedId }, + Conversation { payload: Conversation, queued_at: Option }, +} + +#[derive(Serialize, Debug, Clone)] +#[serde(untagged)] +pub enum SendableEvent<'a> { + WithPayload { event: &'a str, payload: String }, + NoPayload { event: &'a str }, +} +#[rustfmt::skip] +impl Event { + pub fn event_name(&self) -> String { + use Event::*; + match self { + Update { .. } => "update", + Notification { .. } => "notification", + Delete { .. } => "delete", + Announcement { .. } => "announcement", + AnnouncementReaction { .. } => "announcement.reaction", + AnnouncementDelete { .. } => "announcement.delete", + Conversation { .. } => "conversation", + FiltersChanged => "filters_changed", + } + .to_string() + } + + + pub fn payload(&self) -> Option { + use Event::*; + match self { + Update { payload: status, .. } => Some(escaped(status)), + Notification { payload: notification, .. } => Some(escaped(notification)), + Delete { payload: id, .. } => Some(id.0.clone()), + Announcement { payload: announcement, .. } => Some(escaped(announcement)), + AnnouncementReaction { payload: reaction, .. } => Some(escaped(reaction)), + AnnouncementDelete { payload: id, .. } => Some(id.0.clone()), + Conversation { payload: conversation, ..} => Some(escaped(conversation)), + FiltersChanged => None, + } + } + pub fn to_json_string(&self) -> String { + let event = &self.event_name(); + let sendable_event = match self.payload() { + Some(payload) => SendableEvent::WithPayload { event, payload }, + None => SendableEvent::NoPayload { event }, + }; + serde_json::to_string(&sendable_event) + .unwrap_or_else(|_| log_fatal!("Could not serialize `{:?}`", &sendable_event)) + } +} + +fn escaped(content: T) -> String { + serde_json::to_string(&content) + .unwrap_or_else(|_| log_fatal!("Could not parse Event with: `{:?}`", &content)) +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Conversation { + id: String, + accounts: Vec, + unread: bool, + last_status: Option, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct DeletedId(String); + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Status { + id: String, + uri: String, + created_at: String, + account: Account, + content: String, + visibility: Visibility, + sensitive: bool, + spoiler_text: String, + media_attachments: Vec, + application: Option, // Should be non-optional? + mentions: Vec, + tags: Vec, + emojis: Vec, + reblogs_count: i64, + favourites_count: i64, + replies_count: i64, + url: Option, + in_reply_to_id: Option, + in_reply_to_account_id: Option, + reblog: Option>, + poll: Option, + card: Option, + language: Option, + text: Option, + // ↓↓↓ Only for authorized users + favourited: Option, + reblogged: Option, + muted: Option, + bookmarked: Option, + pinned: Option, +} + +#[serde(rename_all = "lowercase", deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub enum Visibility { + Public, + Unlisted, + Private, + Direct, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Account { + id: String, + username: String, + acct: String, + url: String, + display_name: String, + note: String, + avatar: String, + avatar_static: String, + header: String, + header_static: String, + locked: bool, + emojis: Vec, + discoverable: Option, // Shouldn't be option? + created_at: String, + statuses_count: i64, + followers_count: i64, + following_count: i64, + moved: Option>, + fields: Option>, + bot: Option, + source: Option, + group: Option, // undocumented + last_status_at: Option, // undocumented +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +struct Attachment { + id: String, + r#type: AttachmentType, + url: String, + preview_url: String, + remote_url: Option, + text_url: Option, + meta: Option, + description: Option, + blurhash: Option, +} + +#[serde(rename_all = "lowercase", deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +enum AttachmentType { + Unknown, + Image, + Gifv, + Video, + Audio, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Application { + name: String, + website: Option, + vapid_key: Option, + client_id: Option, + client_secret: Option, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +struct Emoji { + shortcode: String, + url: String, + static_url: String, + visible_in_picker: bool, + category: Option, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +struct Field { + name: String, + value: String, + verified_at: Option, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +struct Source { + note: String, + fields: Vec, + privacy: Option, + sensitive: bool, + language: String, + follow_requests_count: i64, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Mention { + id: String, + username: String, + acct: String, + url: String, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +struct Tag { + name: String, + url: String, + history: Option>, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +struct Poll { + id: String, + expires_at: String, + expired: bool, + multiple: bool, + votes_count: i64, + voters_count: Option, + voted: Option, + own_votes: Option>, + options: Vec, + emojis: Vec, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +struct PollOptions { + title: String, + votes_count: Option, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +struct Card { + url: String, + title: String, + description: String, + r#type: CardType, + author_name: Option, + author_url: Option, + provider_name: Option, + provider_url: Option, + html: Option, + width: Option, + height: Option, + image: Option, + embed_url: Option, +} + +#[serde(rename_all = "lowercase", deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +enum CardType { + Link, + Photo, + Video, + Rich, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +struct History { + day: String, + uses: String, + accounts: String, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Notification { + id: String, + r#type: NotificationType, + created_at: String, + account: Account, + status: Option, +} + +#[serde(rename_all = "lowercase", deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +enum NotificationType { + Follow, + Mention, + Reblog, + Favourite, + Poll, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Announcement { + // Fully undocumented + id: String, + tags: Vec, + all_day: bool, + content: String, + emojis: Vec, + starts_at: Option, + ends_at: Option, + published_at: String, + updated_at: String, + mentions: Vec, + reactions: Vec, +} + +#[serde(deny_unknown_fields)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct AnnouncementReaction { + #[serde(skip_serializing_if = "Option::is_none")] + announcement_id: Option, + count: i64, + name: String, +} + +impl Status { + /// Returns `true` if the status is filtered out based on its language + pub fn language_not_allowed(&self, allowed_langs: &HashSet) -> bool { + const ALLOW: bool = false; + const REJECT: bool = true; + + let reject_and_maybe_log = |toot_language| { + log::info!("Filtering out toot from `{}`", &self.account.acct); + log::info!("Toot language: `{}`", toot_language); + log::info!("Recipient's allowed languages: `{:?}`", allowed_langs); + REJECT + }; + if allowed_langs.is_empty() { + return ALLOW; // listing no allowed_langs results in allowing all languages + } + + match self.language.as_ref() { + Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW, + None => ALLOW, // If toot language is unknown, toot is always allowed + Some(empty) if empty == &String::new() => ALLOW, + Some(toot_language) => reject_and_maybe_log(toot_language), + } + } + + /// Returns `true` if this toot originated from a domain the User has blocked. + pub fn from_blocked_domain(&self, blocked_domains: &HashSet) -> bool { + let full_username = &self.account.acct; + + match full_username.split('@').nth(1) { + Some(originating_domain) => blocked_domains.contains(originating_domain), + None => false, // None means the user is on the local instance, which can't be blocked + } + } + /// Returns `true` if the Status is from an account that has blocked the current user. + pub fn from_blocking_user(&self, blocking_users: &HashSet) -> bool { + const ALLOW: bool = false; + const REJECT: bool = true; + let err = |_| log_fatal!("Could not process `account.id` in {:?}", &self); + + if blocking_users.contains(&self.account.id.parse().unwrap_or_else(err)) { + REJECT + } else { + ALLOW + } + } + + /// Returns `true` if the User's list of blocked and muted users includes a user + /// involved in this toot. + /// + /// A user is involved if they: + /// * Are mentioned in this toot + /// * Wrote this toot + /// * Wrote a toot that this toot is replying to (if any) + /// * Wrote the toot that this toot is boosting (if any) + pub fn involves_blocked_user(&self, blocked_users: &HashSet) -> bool { + const ALLOW: bool = false; + const REJECT: bool = true; + let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self); + + // involved_users = mentioned_users + author + replied-to user + boosted user + let mut involved_users: HashSet = self + .mentions + .iter() + .map(|mention| mention.id.parse().unwrap_or_else(err)) + .collect(); + + involved_users.insert(self.account.id.parse::().unwrap_or_else(err)); + + if let Some(replied_to_account_id) = self.in_reply_to_account_id.clone() { + involved_users.insert(replied_to_account_id.parse().unwrap_or_else(err)); + } + + if let Some(boosted_status) = self.reblog.clone() { + involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err)); + } + + if involved_users.is_disjoint(blocked_users) { + ALLOW + } else { + REJECT + } + } +} + +#[cfg(test)] +mod test; diff --git a/src/messages.rs b/src/messages/test.rs similarity index 83% rename from src/messages.rs rename to src/messages/test.rs index 152d472..7dd033a 100644 --- a/src/messages.rs +++ b/src/messages/test.rs @@ -1,432 +1,4 @@ -use crate::log_fatal; -use serde::{Deserialize, Serialize}; -use serde_json; -use std::boxed::Box; -use std::{collections::HashSet, string::String}; - -#[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)] -#[rustfmt::skip] -#[derive(Deserialize, Debug, Clone, PartialEq)] -pub enum Event { - Update { payload: Status, queued_at: Option }, - Notification { payload: Notification }, - Delete { payload: DeletedId }, - FiltersChanged, - Announcement { payload: Announcement }, - #[serde(rename(serialize = "announcement.reaction", deserialize = "announcement.reaction"))] - AnnouncementReaction { payload: AnnouncementReaction }, - #[serde(rename(serialize = "announcement.delete", deserialize = "announcement.delete"))] - AnnouncementDelete { payload: DeletedId }, - Conversation { payload: Conversation, queued_at: Option }, -} - -#[derive(Serialize, Debug, Clone)] -#[serde(untagged)] -pub enum SendableEvent<'a> { - WithPayload { event: &'a str, payload: String }, - NoPayload { event: &'a str }, -} -#[rustfmt::skip] -impl Event { - pub fn event_name(&self) -> String { - use Event::*; - match self { - Update { .. } => "update", - Notification { .. } => "notification", - Delete { .. } => "delete", - Announcement { .. } => "announcement", - AnnouncementReaction { .. } => "announcement.reaction", - AnnouncementDelete { .. } => "announcement.delete", - Conversation { .. } => "conversation", - FiltersChanged => "filters_changed", - } - .to_string() - } - - - pub fn payload(&self) -> Option { - use Event::*; - match self { - Update { payload: status, .. } => Some(escaped(status)), - Notification { payload: notification, .. } => Some(escaped(notification)), - Delete { payload: id, .. } => Some(id.0.clone()), - Announcement { payload: announcement, .. } => Some(escaped(announcement)), - AnnouncementReaction { payload: reaction, .. } => Some(escaped(reaction)), - AnnouncementDelete { payload: id, .. } => Some(id.0.clone()), - Conversation { payload: conversation, ..} => Some(escaped(conversation)), - FiltersChanged => None, - } - } - pub fn to_json_string(&self) -> String { - let event = &self.event_name(); - let sendable_event = match self.payload() { - Some(payload) => SendableEvent::WithPayload { event, payload }, - None => SendableEvent::NoPayload { event }, - }; - serde_json::to_string(&sendable_event) - .unwrap_or_else(|_| log_fatal!("Could not serialize `{:?}`", &sendable_event)) - } -} - -fn escaped(content: T) -> String { - serde_json::to_string(&content) - .unwrap_or_else(|_| log_fatal!("Could not parse Event with: `{:?}`", &content)) -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Conversation { - id: String, - accounts: Vec, - unread: bool, - last_status: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct DeletedId(String); - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Status { - id: String, - uri: String, - created_at: String, - account: Account, - content: String, - visibility: Visibility, - sensitive: bool, - spoiler_text: String, - media_attachments: Vec, - application: Option, // Should be non-optional? - mentions: Vec, - tags: Vec, - emojis: Vec, - reblogs_count: i64, - favourites_count: i64, - replies_count: i64, - url: Option, - in_reply_to_id: Option, - in_reply_to_account_id: Option, - reblog: Option>, - poll: Option, - card: Option, - language: Option, - text: Option, - // ↓↓↓ Only for authorized users - favourited: Option, - reblogged: Option, - muted: Option, - bookmarked: Option, - pinned: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub enum Visibility { - Public, - Unlisted, - Private, - Direct, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Account { - id: String, - username: String, - acct: String, - url: String, - display_name: String, - note: String, - avatar: String, - avatar_static: String, - header: String, - header_static: String, - locked: bool, - emojis: Vec, - discoverable: Option, // Shouldn't be option? - created_at: String, - statuses_count: i64, - followers_count: i64, - following_count: i64, - moved: Option>, - fields: Option>, - bot: Option, - source: Option, - group: Option, // undocumented - last_status_at: Option, // undocumented -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Attachment { - id: String, - r#type: AttachmentType, - url: String, - preview_url: String, - remote_url: Option, - text_url: Option, - meta: Option, - description: Option, - blurhash: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -enum AttachmentType { - Unknown, - Image, - Gifv, - Video, - Audio, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Application { - name: String, - website: Option, - vapid_key: Option, - client_id: Option, - client_secret: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Emoji { - shortcode: String, - url: String, - static_url: String, - visible_in_picker: bool, - category: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Field { - name: String, - value: String, - verified_at: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Source { - note: String, - fields: Vec, - privacy: Option, - sensitive: bool, - language: String, - follow_requests_count: i64, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Mention { - id: String, - username: String, - acct: String, - url: String, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Tag { - name: String, - url: String, - history: Option>, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Poll { - id: String, - expires_at: String, - expired: bool, - multiple: bool, - votes_count: i64, - voters_count: Option, - voted: Option, - own_votes: Option>, - options: Vec, - emojis: Vec, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct PollOptions { - title: String, - votes_count: Option, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct Card { - url: String, - title: String, - description: String, - r#type: CardType, - author_name: Option, - author_url: Option, - provider_name: Option, - provider_url: Option, - html: Option, - width: Option, - height: Option, - image: Option, - embed_url: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -enum CardType { - Link, - Photo, - Video, - Rich, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -struct History { - day: String, - uses: String, - accounts: String, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Notification { - id: String, - r#type: NotificationType, - created_at: String, - account: Account, - status: Option, -} - -#[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -enum NotificationType { - Follow, - Mention, - Reblog, - Favourite, - Poll, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct Announcement { - // Fully undocumented - id: String, - tags: Vec, - all_day: bool, - content: String, - emojis: Vec, - starts_at: Option, - ends_at: Option, - published_at: String, - updated_at: String, - mentions: Vec, - reactions: Vec, -} - -#[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] -pub struct AnnouncementReaction { - #[serde(skip_serializing_if = "Option::is_none")] - announcement_id: Option, - count: i64, - name: String, -} - -impl Status { - /// Returns `true` if the status is filtered out based on its language - pub fn language_not_allowed(&self, allowed_langs: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - - let reject_and_maybe_log = |toot_language| { - log::info!("Filtering out toot from `{}`", &self.account.acct); - log::info!("Toot language: `{}`", toot_language); - log::info!("Recipient's allowed languages: `{:?}`", allowed_langs); - REJECT - }; - if allowed_langs.is_empty() { - return ALLOW; // listing no allowed_langs results in allowing all languages - } - - match self.language.as_ref() { - Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW, - None => ALLOW, // If toot language is unknown, toot is always allowed - Some(empty) if empty == &String::new() => ALLOW, - Some(toot_language) => reject_and_maybe_log(toot_language), - } - } - - /// Returns `true` if this toot originated from a domain the User has blocked. - pub fn from_blocked_domain(&self, blocked_domains: &HashSet) -> bool { - let full_username = &self.account.acct; - - match full_username.split('@').nth(1) { - Some(originating_domain) => blocked_domains.contains(originating_domain), - None => false, // None means the user is on the local instance, which can't be blocked - } - } - /// Returns `true` if the Status is from an account that has blocked the current user. - pub fn from_blocking_user(&self, blocking_users: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - let err = |_| log_fatal!("Could not process `account.id` in {:?}", &self); - - if blocking_users.contains(&self.account.id.parse().unwrap_or_else(err)) { - REJECT - } else { - ALLOW - } - } - - /// Returns `true` if the User's list of blocked and muted users includes a user - /// involved in this toot. - /// - /// A user is involved if they: - /// * Are mentioned in this toot - /// * Wrote this toot - /// * Wrote a toot that this toot is replying to (if any) - /// * Wrote the toot that this toot is boosting (if any) - pub fn involves_blocked_user(&self, blocked_users: &HashSet) -> bool { - const ALLOW: bool = false; - const REJECT: bool = true; - let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self); - - // involved_users = mentioned_users + author + replied-to user + boosted user - let mut involved_users: HashSet = self - .mentions - .iter() - .map(|mention| mention.id.parse().unwrap_or_else(err)) - .collect(); - - involved_users.insert(self.account.id.parse::().unwrap_or_else(err)); - - if let Some(replied_to_account_id) = self.in_reply_to_account_id.clone() { - involved_users.insert(replied_to_account_id.parse().unwrap_or_else(err)); - } - - if let Some(boosted_status) = self.reblog.clone() { - involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err)); - } - - if involved_users.is_disjoint(blocked_users) { - ALLOW - } else { - REJECT - } - } -} - -// #[cfg(test)] -// mod test { +// TODO: Revise these tests to cover *only* the RedisMessage -> (Timeline, Event) parsing // use super::*; // use crate::{ // err::RedisParseErr, @@ -935,5 +507,3 @@ impl Status { // assert_eq!(rest, String::new()); // Ok(()) // } -// } -// TODO: Revise these tests to cover *only* the RedisMessage -> (Timeline, Event) parsing diff --git a/src/parse_client_request/mod.rs b/src/parse_client_request/mod.rs index 0ecc5e9..32a5b08 100644 --- a/src/parse_client_request/mod.rs +++ b/src/parse_client_request/mod.rs @@ -1,9 +1,8 @@ //! Parse the client request and return a Subscription mod postgres; mod query; -mod sse; + mod subscription; -mod ws; pub use self::postgres::PgPool; // TODO consider whether we can remove `Stream` from public API @@ -11,3 +10,8 @@ pub use subscription::{Stream, Subscription, Timeline}; //#[cfg(test)] pub use subscription::{Content, Reach}; + +#[cfg(test)] +mod sse_test; +#[cfg(test)] +mod ws_test; diff --git a/src/parse_client_request/postgres.rs b/src/parse_client_request/postgres.rs index 606f583..69650a5 100644 --- a/src/parse_client_request/postgres.rs +++ b/src/parse_client_request/postgres.rs @@ -107,19 +107,7 @@ LIMIT 1", /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocked_users(self, user_id: i64) -> HashSet { - // " - // SELECT - // 1 - // FROM blocks - // WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})) - // OR (account_id = $2 AND target_account_id = $1) - // UNION SELECT - // 1 - // FROM mutes - // WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})` - // , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`" - self - .0 + self.0 .get() .unwrap() .query( @@ -142,8 +130,7 @@ UNION SELECT target_account_id /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocking_users(self, user_id: i64) -> HashSet { - self - .0 + self.0 .get() .unwrap() .query( @@ -164,8 +151,7 @@ SELECT account_id /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocked_domains(self, user_id: i64) -> HashSet { - self - .0 + self.0 .get() .unwrap() .query( diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse_test.rs similarity index 100% rename from src/parse_client_request/sse.rs rename to src/parse_client_request/sse_test.rs diff --git a/src/parse_client_request/subscription.rs b/src/parse_client_request/subscription.rs index dcf8409..db6a94a 100644 --- a/src/parse_client_request/subscription.rs +++ b/src/parse_client_request/subscription.rs @@ -109,6 +109,7 @@ impl Subscription { .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) .boxed() } + fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result { let user = match q.access_token.clone() { Some(token) => pool.clone().select_user(&token)?, @@ -220,97 +221,8 @@ impl Timeline { // Other endpoints don't exist: [..] => Err(TimelineErr::InvalidInput)?, }) - // let (stream, reach, content) = if let Some(ns) = namespace { - // match timeline_slice { - // [n, "timeline", "public"] if n == ns => (Public, Federated, All), - // [_, "timeline", "public"] - // | ["timeline", "public"] => Err(RedisNamespaceMismatch)?, - - // [n, "timeline", "public", "local"] if ns == n => (Public, Local, All), - // [_, "timeline", "public", "local"] - // | ["timeline", "public", "local"] => Err(RedisNamespaceMismatch)?, - - // [n, "timeline", "public", "media"] if ns == n => (Public, Federated, Media), - // [_, "timeline", "public", "media"] - // | ["timeline", "public", "media"] => Err(RedisNamespaceMismatch)?, - - // [n, "timeline", "public", "local", "media"] if ns == n => (Public, Local, Media), - // [_, "timeline", "public", "local", "media"] - // | ["timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?, - - // [n, "timeline", "hashtag", tag_name] if ns == n => { - // let tag_id = *cache - // .get(&tag_name.to_string()) - // .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name)); - // (Hashtag(tag_id), Federated, All) - // } - // [_, "timeline", "hashtag", _tag] - // | ["timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?, - - // [n, "timeline", "hashtag", _tag, "local"] if ns == n => (Hashtag(0), Local, All), - // [_, "timeline", "hashtag", _tag, "local"] - // | ["timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?, - - // [n, "timeline", id] if ns == n => (User(id.parse().unwrap()), Federated, All), - // [_, "timeline", _id] - // | ["timeline", _id] => Err(RedisNamespaceMismatch)?, - - // [n, "timeline", id, "notification"] if ns == n => - // (User(id.parse()?), Federated, Notification), - - // [_, "timeline", _id, "notification"] - // | ["timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?, - - // [n, "timeline", "list", id] if ns == n => (List(id.parse()?), Federated, All), - // [_, "timeline", "list", _id] - // | ["timeline", "list", _id] => Err(RedisNamespaceMismatch)?, - - // [n, "timeline", "direct", id] if ns == n => (Direct(id.parse()?), Federated, All), - // [_, "timeline", "direct", _id] - // | ["timeline", "direct", _id] => Err(RedisNamespaceMismatch)?, - - // [..] => log_fatal!("Unexpected channel from Redis: {:?}", timeline_slice), - // } - // } else { - // match timeline_slice { - // ["timeline", "public"] => (Public, Federated, All), - // [_, "timeline", "public"] => Err(RedisNamespaceMismatch)?, - - // ["timeline", "public", "local"] => (Public, Local, All), - // [_, "timeline", "public", "local"] => Err(RedisNamespaceMismatch)?, - - // ["timeline", "public", "media"] => (Public, Federated, Media), - - // [_, "timeline", "public", "media"] => Err(RedisNamespaceMismatch)?, - - // ["timeline", "public", "local", "media"] => (Public, Local, Media), - // [_, "timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?, - - // ["timeline", "hashtag", _tag] => (Hashtag(0), Federated, All), - // [_, "timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?, - - // ["timeline", "hashtag", _tag, "local"] => (Hashtag(0), Local, All), - // [_, "timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?, - - // ["timeline", id] => (User(id.parse().unwrap()), Federated, All), - // [_, "timeline", _id] => Err(RedisNamespaceMismatch)?, - - // ["timeline", id, "notification"] => { - // (User(id.parse().unwrap()), Federated, Notification) - // } - // [_, "timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?, - - // ["timeline", "list", id] => (List(id.parse().unwrap()), Federated, All), - // [_, "timeline", "list", _id] => Err(RedisNamespaceMismatch)?, - - // ["timeline", "direct", id] => (Direct(id.parse().unwrap()), Federated, All), - // [_, "timeline", "direct", _id] => Err(RedisNamespaceMismatch)?, - - // // Other endpoints don't exist: - // [..] => Err(TimelineErr::InvalidInput)?, - // } - // }; } + fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result { use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*}; let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag); @@ -353,6 +265,7 @@ impl Timeline { }) } } + #[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] pub enum Stream { User(i64), @@ -362,11 +275,13 @@ pub enum Stream { Public, Unset, } + #[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] pub enum Reach { Local, Federated, } + #[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] pub enum Content { All, diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws_test.rs similarity index 100% rename from src/parse_client_request/ws.rs rename to src/parse_client_request/ws_test.rs diff --git a/src/redis_to_client_stream/redis/redis_msg.rs b/src/redis_to_client_stream/redis/redis_msg/mod.rs similarity index 66% rename from src/redis_to_client_stream/redis/redis_msg.rs rename to src/redis_to_client_stream/redis/redis_msg/mod.rs index 3ce0f65..8c3386d 100644 --- a/src/redis_to_client_stream/redis/redis_msg.rs +++ b/src/redis_to_client_stream/redis/redis_msg/mod.rs @@ -167,89 +167,4 @@ impl<'a> TryFrom> for RedisParseOutput<'a> { } #[cfg(test)] -mod test { - use super::*; - - #[test] - fn parse_redis_subscribe() -> Result<(), RedisParseErr> { - let input = "*3\r\n$9\r\nsubscribe\r\n$15\r\ntimeline:public\r\n:1\r\n"; - - let r_subscribe = match RedisParseOutput::try_from(input) { - Ok(NonMsg(leftover)) => leftover, - Ok(Msg(msg)) => panic!("unexpectedly got a msg: {:?}", msg), - Err(e) => panic!("Error in parsing subscribe command: {:?}", e), - }; - assert!(r_subscribe.is_empty()); - - Ok(()) - } - - #[test] - fn parse_redis_detects_non_newline() -> Result<(), RedisParseErr> { - let input = - "*3QQ$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$38\r\n{\"event\":\"delete\",\"payload\":\"1038647\"}\r\n"; - - match RedisParseOutput::try_from(input) { - Ok(NonMsg(leftover)) => panic!( - "Parsed an invalid msg as a non-msg.\nInput `{}` parsed to NonMsg({:?})", - &input, leftover - ), - Ok(Msg(msg)) => panic!( - "Parsed an invalid msg as a msg.\nInput `{:?}` parsed to {:?}", - &input, msg - ), - Err(_) => (), // should err - }; - - Ok(()) - } - - fn parse_redis_msg() -> Result<(), RedisParseErr> { - let input = - "*3\r\n$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$38\r\n{\"event\":\"delete\",\"payload\":\"1038647\"}\r\n"; - - let r_msg = match RedisParseOutput::try_from(input) { - Ok(NonMsg(leftover)) => panic!( - "Parsed a msg as a non-msg.\nInput `{}` parsed to NonMsg({:?})", - &input, leftover - ), - Ok(Msg(msg)) => msg, - Err(e) => panic!("Error in parsing subscribe command: {:?}", e), - }; - - assert!(r_msg.leftover_input.is_empty()); - assert_eq!(r_msg.timeline_txt, "timeline:308"); - assert_eq!(r_msg.event_txt, r#"{"event":"delete","payload":"1038647"}"#); - Ok(()) - } -} - -// #[derive(Debug, Clone, PartialEq, Copy)] -// pub struct RedisUtf8<'a> { -// pub valid_utf8: &'a str, -// pub leftover_bytes: &'a [u8], -// } - -// impl<'a> From<&'a [u8]> for RedisUtf8<'a> { -// fn from(bytes: &'a [u8]) -> Self { -// match str::from_utf8(bytes) { -// Ok(valid_utf8) => Self { -// valid_utf8, -// leftover_bytes: "".as_bytes(), -// }, -// Err(e) => { -// let (valid, after_valid) = bytes.split_at(e.valid_up_to()); -// Self { -// valid_utf8: str::from_utf8(valid).expect("Guaranteed by `.valid_up_to`"), -// leftover_bytes: after_valid, -// } -// } -// } -// } -// } - -// impl<'a> Default for RedisUtf8<'a> { -// fn default() -> Self { -// Self::from("".as_bytes()) -// } -// } +mod test; diff --git a/src/redis_to_client_stream/redis/redis_msg/test.rs b/src/redis_to_client_stream/redis/redis_msg/test.rs new file mode 100644 index 0000000..b59760f --- /dev/null +++ b/src/redis_to_client_stream/redis/redis_msg/test.rs @@ -0,0 +1,54 @@ +use super::*; + +#[test] +fn parse_redis_subscribe() -> Result<(), RedisParseErr> { + let input = "*3\r\n$9\r\nsubscribe\r\n$15\r\ntimeline:public\r\n:1\r\n"; + + let r_subscribe = match RedisParseOutput::try_from(input) { + Ok(NonMsg(leftover)) => leftover, + Ok(Msg(msg)) => panic!("unexpectedly got a msg: {:?}", msg), + Err(e) => panic!("Error in parsing subscribe command: {:?}", e), + }; + assert!(r_subscribe.is_empty()); + + Ok(()) +} + +#[test] +fn parse_redis_detects_non_newline() -> Result<(), RedisParseErr> { + let input = + "*3QQ$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$38\r\n{\"event\":\"delete\",\"payload\":\"1038647\"}\r\n"; + + match RedisParseOutput::try_from(input) { + Ok(NonMsg(leftover)) => panic!( + "Parsed an invalid msg as a non-msg.\nInput `{}` parsed to NonMsg({:?})", + &input, leftover + ), + Ok(Msg(msg)) => panic!( + "Parsed an invalid msg as a msg.\nInput `{:?}` parsed to {:?}", + &input, msg + ), + Err(_) => (), // should err + }; + + Ok(()) +} + +fn parse_redis_msg() -> Result<(), RedisParseErr> { + let input = + "*3\r\n$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$38\r\n{\"event\":\"delete\",\"payload\":\"1038647\"}\r\n"; + + let r_msg = match RedisParseOutput::try_from(input) { + Ok(NonMsg(leftover)) => panic!( + "Parsed a msg as a non-msg.\nInput `{}` parsed to NonMsg({:?})", + &input, leftover + ), + Ok(Msg(msg)) => msg, + Err(e) => panic!("Error in parsing subscribe command: {:?}", e), + }; + + assert!(r_msg.leftover_input.is_empty()); + assert_eq!(r_msg.timeline_txt, "timeline:308"); + assert_eq!(r_msg.event_txt, r#"{"event":"delete","payload":"1038647"}"#); + Ok(()) +}