Extract tests to separate files (#113)

This very minor change moves tests from their current location in
submodules within the file under test into submodules in separate
files.  This is a slight deviation from the normal Rust convention
(though only very slight, since the module structure remains the
same).  However, it is justified here since the tests are fairly
verbose and including them in the same file was a bit unwieldy.
This commit is contained in:
Daniel Sockwell 2020-03-31 09:05:51 -04:00 committed by GitHub
parent 5965a514fd
commit 81b454c88c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 500 additions and 627 deletions

2
Cargo.lock generated
View File

@ -440,7 +440,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "flodgatt"
version = "0.6.5"
version = "0.6.6"
dependencies = [
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",

429
src/messages/mod.rs Normal file
View File

@ -0,0 +1,429 @@
use crate::log_fatal;
use serde::{Deserialize, Serialize};
use serde_json;
use std::boxed::Box;
use std::{collections::HashSet, string::String};
#[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)]
#[rustfmt::skip]
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub enum Event {
Update { payload: Status, queued_at: Option<i64> },
Notification { payload: Notification },
Delete { payload: DeletedId },
FiltersChanged,
Announcement { payload: Announcement },
#[serde(rename(serialize = "announcement.reaction", deserialize = "announcement.reaction"))]
AnnouncementReaction { payload: AnnouncementReaction },
#[serde(rename(serialize = "announcement.delete", deserialize = "announcement.delete"))]
AnnouncementDelete { payload: DeletedId },
Conversation { payload: Conversation, queued_at: Option<i64> },
}
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum SendableEvent<'a> {
WithPayload { event: &'a str, payload: String },
NoPayload { event: &'a str },
}
#[rustfmt::skip]
impl Event {
pub fn event_name(&self) -> String {
use Event::*;
match self {
Update { .. } => "update",
Notification { .. } => "notification",
Delete { .. } => "delete",
Announcement { .. } => "announcement",
AnnouncementReaction { .. } => "announcement.reaction",
AnnouncementDelete { .. } => "announcement.delete",
Conversation { .. } => "conversation",
FiltersChanged => "filters_changed",
}
.to_string()
}
pub fn payload(&self) -> Option<String> {
use Event::*;
match self {
Update { payload: status, .. } => Some(escaped(status)),
Notification { payload: notification, .. } => Some(escaped(notification)),
Delete { payload: id, .. } => Some(id.0.clone()),
Announcement { payload: announcement, .. } => Some(escaped(announcement)),
AnnouncementReaction { payload: reaction, .. } => Some(escaped(reaction)),
AnnouncementDelete { payload: id, .. } => Some(id.0.clone()),
Conversation { payload: conversation, ..} => Some(escaped(conversation)),
FiltersChanged => None,
}
}
pub fn to_json_string(&self) -> String {
let event = &self.event_name();
let sendable_event = match self.payload() {
Some(payload) => SendableEvent::WithPayload { event, payload },
None => SendableEvent::NoPayload { event },
};
serde_json::to_string(&sendable_event)
.unwrap_or_else(|_| log_fatal!("Could not serialize `{:?}`", &sendable_event))
}
}
fn escaped<T: Serialize + std::fmt::Debug>(content: T) -> String {
serde_json::to_string(&content)
.unwrap_or_else(|_| log_fatal!("Could not parse Event with: `{:?}`", &content))
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Conversation {
id: String,
accounts: Vec<Account>,
unread: bool,
last_status: Option<Status>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct DeletedId(String);
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Status {
id: String,
uri: String,
created_at: String,
account: Account,
content: String,
visibility: Visibility,
sensitive: bool,
spoiler_text: String,
media_attachments: Vec<Attachment>,
application: Option<Application>, // Should be non-optional?
mentions: Vec<Mention>,
tags: Vec<Tag>,
emojis: Vec<Emoji>,
reblogs_count: i64,
favourites_count: i64,
replies_count: i64,
url: Option<String>,
in_reply_to_id: Option<String>,
in_reply_to_account_id: Option<String>,
reblog: Option<Box<Status>>,
poll: Option<Poll>,
card: Option<Card>,
language: Option<String>,
text: Option<String>,
// ↓↓↓ Only for authorized users
favourited: Option<bool>,
reblogged: Option<bool>,
muted: Option<bool>,
bookmarked: Option<bool>,
pinned: Option<bool>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum Visibility {
Public,
Unlisted,
Private,
Direct,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Account {
id: String,
username: String,
acct: String,
url: String,
display_name: String,
note: String,
avatar: String,
avatar_static: String,
header: String,
header_static: String,
locked: bool,
emojis: Vec<Emoji>,
discoverable: Option<bool>, // Shouldn't be option?
created_at: String,
statuses_count: i64,
followers_count: i64,
following_count: i64,
moved: Option<Box<String>>,
fields: Option<Vec<Field>>,
bot: Option<bool>,
source: Option<Source>,
group: Option<bool>, // undocumented
last_status_at: Option<String>, // undocumented
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Attachment {
id: String,
r#type: AttachmentType,
url: String,
preview_url: String,
remote_url: Option<String>,
text_url: Option<String>,
meta: Option<serde_json::Value>,
description: Option<String>,
blurhash: Option<String>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum AttachmentType {
Unknown,
Image,
Gifv,
Video,
Audio,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Application {
name: String,
website: Option<String>,
vapid_key: Option<String>,
client_id: Option<String>,
client_secret: Option<String>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Emoji {
shortcode: String,
url: String,
static_url: String,
visible_in_picker: bool,
category: Option<String>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Field {
name: String,
value: String,
verified_at: Option<String>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Source {
note: String,
fields: Vec<Field>,
privacy: Option<Visibility>,
sensitive: bool,
language: String,
follow_requests_count: i64,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Mention {
id: String,
username: String,
acct: String,
url: String,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Tag {
name: String,
url: String,
history: Option<Vec<History>>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Poll {
id: String,
expires_at: String,
expired: bool,
multiple: bool,
votes_count: i64,
voters_count: Option<i64>,
voted: Option<bool>,
own_votes: Option<Vec<i64>>,
options: Vec<PollOptions>,
emojis: Vec<Emoji>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct PollOptions {
title: String,
votes_count: Option<i32>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Card {
url: String,
title: String,
description: String,
r#type: CardType,
author_name: Option<String>,
author_url: Option<String>,
provider_name: Option<String>,
provider_url: Option<String>,
html: Option<String>,
width: Option<i64>,
height: Option<i64>,
image: Option<String>,
embed_url: Option<String>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum CardType {
Link,
Photo,
Video,
Rich,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct History {
day: String,
uses: String,
accounts: String,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Notification {
id: String,
r#type: NotificationType,
created_at: String,
account: Account,
status: Option<Status>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum NotificationType {
Follow,
Mention,
Reblog,
Favourite,
Poll,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Announcement {
// Fully undocumented
id: String,
tags: Vec<Tag>,
all_day: bool,
content: String,
emojis: Vec<Emoji>,
starts_at: Option<String>,
ends_at: Option<String>,
published_at: String,
updated_at: String,
mentions: Vec<Mention>,
reactions: Vec<AnnouncementReaction>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct AnnouncementReaction {
#[serde(skip_serializing_if = "Option::is_none")]
announcement_id: Option<String>,
count: i64,
name: String,
}
impl Status {
/// Returns `true` if the status is filtered out based on its language
pub fn language_not_allowed(&self, allowed_langs: &HashSet<String>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let reject_and_maybe_log = |toot_language| {
log::info!("Filtering out toot from `{}`", &self.account.acct);
log::info!("Toot language: `{}`", toot_language);
log::info!("Recipient's allowed languages: `{:?}`", allowed_langs);
REJECT
};
if allowed_langs.is_empty() {
return ALLOW; // listing no allowed_langs results in allowing all languages
}
match self.language.as_ref() {
Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW,
None => ALLOW, // If toot language is unknown, toot is always allowed
Some(empty) if empty == &String::new() => ALLOW,
Some(toot_language) => reject_and_maybe_log(toot_language),
}
}
/// Returns `true` if this toot originated from a domain the User has blocked.
pub fn from_blocked_domain(&self, blocked_domains: &HashSet<String>) -> bool {
let full_username = &self.account.acct;
match full_username.split('@').nth(1) {
Some(originating_domain) => blocked_domains.contains(originating_domain),
None => false, // None means the user is on the local instance, which can't be blocked
}
}
/// Returns `true` if the Status is from an account that has blocked the current user.
pub fn from_blocking_user(&self, blocking_users: &HashSet<i64>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let err = |_| log_fatal!("Could not process `account.id` in {:?}", &self);
if blocking_users.contains(&self.account.id.parse().unwrap_or_else(err)) {
REJECT
} else {
ALLOW
}
}
/// Returns `true` if the User's list of blocked and muted users includes a user
/// involved in this toot.
///
/// A user is involved if they:
/// * Are mentioned in this toot
/// * Wrote this toot
/// * Wrote a toot that this toot is replying to (if any)
/// * Wrote the toot that this toot is boosting (if any)
pub fn involves_blocked_user(&self, blocked_users: &HashSet<i64>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self);
// involved_users = mentioned_users + author + replied-to user + boosted user
let mut involved_users: HashSet<i64> = self
.mentions
.iter()
.map(|mention| mention.id.parse().unwrap_or_else(err))
.collect();
involved_users.insert(self.account.id.parse::<i64>().unwrap_or_else(err));
if let Some(replied_to_account_id) = self.in_reply_to_account_id.clone() {
involved_users.insert(replied_to_account_id.parse().unwrap_or_else(err));
}
if let Some(boosted_status) = self.reblog.clone() {
involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err));
}
if involved_users.is_disjoint(blocked_users) {
ALLOW
} else {
REJECT
}
}
}
#[cfg(test)]
mod test;

View File

@ -1,432 +1,4 @@
use crate::log_fatal;
use serde::{Deserialize, Serialize};
use serde_json;
use std::boxed::Box;
use std::{collections::HashSet, string::String};
#[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)]
#[rustfmt::skip]
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub enum Event {
Update { payload: Status, queued_at: Option<i64> },
Notification { payload: Notification },
Delete { payload: DeletedId },
FiltersChanged,
Announcement { payload: Announcement },
#[serde(rename(serialize = "announcement.reaction", deserialize = "announcement.reaction"))]
AnnouncementReaction { payload: AnnouncementReaction },
#[serde(rename(serialize = "announcement.delete", deserialize = "announcement.delete"))]
AnnouncementDelete { payload: DeletedId },
Conversation { payload: Conversation, queued_at: Option<i64> },
}
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum SendableEvent<'a> {
WithPayload { event: &'a str, payload: String },
NoPayload { event: &'a str },
}
#[rustfmt::skip]
impl Event {
pub fn event_name(&self) -> String {
use Event::*;
match self {
Update { .. } => "update",
Notification { .. } => "notification",
Delete { .. } => "delete",
Announcement { .. } => "announcement",
AnnouncementReaction { .. } => "announcement.reaction",
AnnouncementDelete { .. } => "announcement.delete",
Conversation { .. } => "conversation",
FiltersChanged => "filters_changed",
}
.to_string()
}
pub fn payload(&self) -> Option<String> {
use Event::*;
match self {
Update { payload: status, .. } => Some(escaped(status)),
Notification { payload: notification, .. } => Some(escaped(notification)),
Delete { payload: id, .. } => Some(id.0.clone()),
Announcement { payload: announcement, .. } => Some(escaped(announcement)),
AnnouncementReaction { payload: reaction, .. } => Some(escaped(reaction)),
AnnouncementDelete { payload: id, .. } => Some(id.0.clone()),
Conversation { payload: conversation, ..} => Some(escaped(conversation)),
FiltersChanged => None,
}
}
pub fn to_json_string(&self) -> String {
let event = &self.event_name();
let sendable_event = match self.payload() {
Some(payload) => SendableEvent::WithPayload { event, payload },
None => SendableEvent::NoPayload { event },
};
serde_json::to_string(&sendable_event)
.unwrap_or_else(|_| log_fatal!("Could not serialize `{:?}`", &sendable_event))
}
}
fn escaped<T: Serialize + std::fmt::Debug>(content: T) -> String {
serde_json::to_string(&content)
.unwrap_or_else(|_| log_fatal!("Could not parse Event with: `{:?}`", &content))
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Conversation {
id: String,
accounts: Vec<Account>,
unread: bool,
last_status: Option<Status>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct DeletedId(String);
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Status {
id: String,
uri: String,
created_at: String,
account: Account,
content: String,
visibility: Visibility,
sensitive: bool,
spoiler_text: String,
media_attachments: Vec<Attachment>,
application: Option<Application>, // Should be non-optional?
mentions: Vec<Mention>,
tags: Vec<Tag>,
emojis: Vec<Emoji>,
reblogs_count: i64,
favourites_count: i64,
replies_count: i64,
url: Option<String>,
in_reply_to_id: Option<String>,
in_reply_to_account_id: Option<String>,
reblog: Option<Box<Status>>,
poll: Option<Poll>,
card: Option<Card>,
language: Option<String>,
text: Option<String>,
// ↓↓↓ Only for authorized users
favourited: Option<bool>,
reblogged: Option<bool>,
muted: Option<bool>,
bookmarked: Option<bool>,
pinned: Option<bool>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum Visibility {
Public,
Unlisted,
Private,
Direct,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Account {
id: String,
username: String,
acct: String,
url: String,
display_name: String,
note: String,
avatar: String,
avatar_static: String,
header: String,
header_static: String,
locked: bool,
emojis: Vec<Emoji>,
discoverable: Option<bool>, // Shouldn't be option?
created_at: String,
statuses_count: i64,
followers_count: i64,
following_count: i64,
moved: Option<Box<String>>,
fields: Option<Vec<Field>>,
bot: Option<bool>,
source: Option<Source>,
group: Option<bool>, // undocumented
last_status_at: Option<String>, // undocumented
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Attachment {
id: String,
r#type: AttachmentType,
url: String,
preview_url: String,
remote_url: Option<String>,
text_url: Option<String>,
meta: Option<serde_json::Value>,
description: Option<String>,
blurhash: Option<String>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum AttachmentType {
Unknown,
Image,
Gifv,
Video,
Audio,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Application {
name: String,
website: Option<String>,
vapid_key: Option<String>,
client_id: Option<String>,
client_secret: Option<String>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Emoji {
shortcode: String,
url: String,
static_url: String,
visible_in_picker: bool,
category: Option<String>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Field {
name: String,
value: String,
verified_at: Option<String>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Source {
note: String,
fields: Vec<Field>,
privacy: Option<Visibility>,
sensitive: bool,
language: String,
follow_requests_count: i64,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Mention {
id: String,
username: String,
acct: String,
url: String,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Tag {
name: String,
url: String,
history: Option<Vec<History>>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Poll {
id: String,
expires_at: String,
expired: bool,
multiple: bool,
votes_count: i64,
voters_count: Option<i64>,
voted: Option<bool>,
own_votes: Option<Vec<i64>>,
options: Vec<PollOptions>,
emojis: Vec<Emoji>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct PollOptions {
title: String,
votes_count: Option<i32>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Card {
url: String,
title: String,
description: String,
r#type: CardType,
author_name: Option<String>,
author_url: Option<String>,
provider_name: Option<String>,
provider_url: Option<String>,
html: Option<String>,
width: Option<i64>,
height: Option<i64>,
image: Option<String>,
embed_url: Option<String>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum CardType {
Link,
Photo,
Video,
Rich,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct History {
day: String,
uses: String,
accounts: String,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Notification {
id: String,
r#type: NotificationType,
created_at: String,
account: Account,
status: Option<Status>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum NotificationType {
Follow,
Mention,
Reblog,
Favourite,
Poll,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Announcement {
// Fully undocumented
id: String,
tags: Vec<Tag>,
all_day: bool,
content: String,
emojis: Vec<Emoji>,
starts_at: Option<String>,
ends_at: Option<String>,
published_at: String,
updated_at: String,
mentions: Vec<Mention>,
reactions: Vec<AnnouncementReaction>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct AnnouncementReaction {
#[serde(skip_serializing_if = "Option::is_none")]
announcement_id: Option<String>,
count: i64,
name: String,
}
impl Status {
/// Returns `true` if the status is filtered out based on its language
pub fn language_not_allowed(&self, allowed_langs: &HashSet<String>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let reject_and_maybe_log = |toot_language| {
log::info!("Filtering out toot from `{}`", &self.account.acct);
log::info!("Toot language: `{}`", toot_language);
log::info!("Recipient's allowed languages: `{:?}`", allowed_langs);
REJECT
};
if allowed_langs.is_empty() {
return ALLOW; // listing no allowed_langs results in allowing all languages
}
match self.language.as_ref() {
Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW,
None => ALLOW, // If toot language is unknown, toot is always allowed
Some(empty) if empty == &String::new() => ALLOW,
Some(toot_language) => reject_and_maybe_log(toot_language),
}
}
/// Returns `true` if this toot originated from a domain the User has blocked.
pub fn from_blocked_domain(&self, blocked_domains: &HashSet<String>) -> bool {
let full_username = &self.account.acct;
match full_username.split('@').nth(1) {
Some(originating_domain) => blocked_domains.contains(originating_domain),
None => false, // None means the user is on the local instance, which can't be blocked
}
}
/// Returns `true` if the Status is from an account that has blocked the current user.
pub fn from_blocking_user(&self, blocking_users: &HashSet<i64>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let err = |_| log_fatal!("Could not process `account.id` in {:?}", &self);
if blocking_users.contains(&self.account.id.parse().unwrap_or_else(err)) {
REJECT
} else {
ALLOW
}
}
/// Returns `true` if the User's list of blocked and muted users includes a user
/// involved in this toot.
///
/// A user is involved if they:
/// * Are mentioned in this toot
/// * Wrote this toot
/// * Wrote a toot that this toot is replying to (if any)
/// * Wrote the toot that this toot is boosting (if any)
pub fn involves_blocked_user(&self, blocked_users: &HashSet<i64>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self);
// involved_users = mentioned_users + author + replied-to user + boosted user
let mut involved_users: HashSet<i64> = self
.mentions
.iter()
.map(|mention| mention.id.parse().unwrap_or_else(err))
.collect();
involved_users.insert(self.account.id.parse::<i64>().unwrap_or_else(err));
if let Some(replied_to_account_id) = self.in_reply_to_account_id.clone() {
involved_users.insert(replied_to_account_id.parse().unwrap_or_else(err));
}
if let Some(boosted_status) = self.reblog.clone() {
involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err));
}
if involved_users.is_disjoint(blocked_users) {
ALLOW
} else {
REJECT
}
}
}
// #[cfg(test)]
// mod test {
// TODO: Revise these tests to cover *only* the RedisMessage -> (Timeline, Event) parsing
// use super::*;
// use crate::{
// err::RedisParseErr,
@ -935,5 +507,3 @@ impl Status {
// assert_eq!(rest, String::new());
// Ok(())
// }
// }
// TODO: Revise these tests to cover *only* the RedisMessage -> (Timeline, Event) parsing

View File

@ -1,9 +1,8 @@
//! Parse the client request and return a Subscription
mod postgres;
mod query;
mod sse;
mod subscription;
mod ws;
pub use self::postgres::PgPool;
// TODO consider whether we can remove `Stream` from public API
@ -11,3 +10,8 @@ pub use subscription::{Stream, Subscription, Timeline};
//#[cfg(test)]
pub use subscription::{Content, Reach};
#[cfg(test)]
mod sse_test;
#[cfg(test)]
mod ws_test;

View File

@ -107,19 +107,7 @@ LIMIT 1",
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_users(self, user_id: i64) -> HashSet<i64> {
// "
// SELECT
// 1
// FROM blocks
// WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
// OR (account_id = $2 AND target_account_id = $1)
// UNION SELECT
// 1
// FROM mutes
// WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
// , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
self
.0
self.0
.get()
.unwrap()
.query(
@ -142,8 +130,7 @@ UNION SELECT target_account_id
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(self, user_id: i64) -> HashSet<i64> {
self
.0
self.0
.get()
.unwrap()
.query(
@ -164,8 +151,7 @@ SELECT account_id
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(self, user_id: i64) -> HashSet<String> {
self
.0
self.0
.get()
.unwrap()
.query(

View File

@ -109,6 +109,7 @@ impl Subscription {
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
.boxed()
}
fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result<Self, Rejection> {
let user = match q.access_token.clone() {
Some(token) => pool.clone().select_user(&token)?,
@ -220,97 +221,8 @@ impl Timeline {
// Other endpoints don't exist:
[..] => Err(TimelineErr::InvalidInput)?,
})
// let (stream, reach, content) = if let Some(ns) = namespace {
// match timeline_slice {
// [n, "timeline", "public"] if n == ns => (Public, Federated, All),
// [_, "timeline", "public"]
// | ["timeline", "public"] => Err(RedisNamespaceMismatch)?,
// [n, "timeline", "public", "local"] if ns == n => (Public, Local, All),
// [_, "timeline", "public", "local"]
// | ["timeline", "public", "local"] => Err(RedisNamespaceMismatch)?,
// [n, "timeline", "public", "media"] if ns == n => (Public, Federated, Media),
// [_, "timeline", "public", "media"]
// | ["timeline", "public", "media"] => Err(RedisNamespaceMismatch)?,
// [n, "timeline", "public", "local", "media"] if ns == n => (Public, Local, Media),
// [_, "timeline", "public", "local", "media"]
// | ["timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?,
// [n, "timeline", "hashtag", tag_name] if ns == n => {
// let tag_id = *cache
// .get(&tag_name.to_string())
// .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name));
// (Hashtag(tag_id), Federated, All)
// }
// [_, "timeline", "hashtag", _tag]
// | ["timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?,
// [n, "timeline", "hashtag", _tag, "local"] if ns == n => (Hashtag(0), Local, All),
// [_, "timeline", "hashtag", _tag, "local"]
// | ["timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?,
// [n, "timeline", id] if ns == n => (User(id.parse().unwrap()), Federated, All),
// [_, "timeline", _id]
// | ["timeline", _id] => Err(RedisNamespaceMismatch)?,
// [n, "timeline", id, "notification"] if ns == n =>
// (User(id.parse()?), Federated, Notification),
// [_, "timeline", _id, "notification"]
// | ["timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?,
// [n, "timeline", "list", id] if ns == n => (List(id.parse()?), Federated, All),
// [_, "timeline", "list", _id]
// | ["timeline", "list", _id] => Err(RedisNamespaceMismatch)?,
// [n, "timeline", "direct", id] if ns == n => (Direct(id.parse()?), Federated, All),
// [_, "timeline", "direct", _id]
// | ["timeline", "direct", _id] => Err(RedisNamespaceMismatch)?,
// [..] => log_fatal!("Unexpected channel from Redis: {:?}", timeline_slice),
// }
// } else {
// match timeline_slice {
// ["timeline", "public"] => (Public, Federated, All),
// [_, "timeline", "public"] => Err(RedisNamespaceMismatch)?,
// ["timeline", "public", "local"] => (Public, Local, All),
// [_, "timeline", "public", "local"] => Err(RedisNamespaceMismatch)?,
// ["timeline", "public", "media"] => (Public, Federated, Media),
// [_, "timeline", "public", "media"] => Err(RedisNamespaceMismatch)?,
// ["timeline", "public", "local", "media"] => (Public, Local, Media),
// [_, "timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?,
// ["timeline", "hashtag", _tag] => (Hashtag(0), Federated, All),
// [_, "timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?,
// ["timeline", "hashtag", _tag, "local"] => (Hashtag(0), Local, All),
// [_, "timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?,
// ["timeline", id] => (User(id.parse().unwrap()), Federated, All),
// [_, "timeline", _id] => Err(RedisNamespaceMismatch)?,
// ["timeline", id, "notification"] => {
// (User(id.parse().unwrap()), Federated, Notification)
// }
// [_, "timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?,
// ["timeline", "list", id] => (List(id.parse().unwrap()), Federated, All),
// [_, "timeline", "list", _id] => Err(RedisNamespaceMismatch)?,
// ["timeline", "direct", id] => (Direct(id.parse().unwrap()), Federated, All),
// [_, "timeline", "direct", _id] => Err(RedisNamespaceMismatch)?,
// // Other endpoints don't exist:
// [..] => Err(TimelineErr::InvalidInput)?,
// }
// };
}
fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result<Self, Rejection> {
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag);
@ -353,6 +265,7 @@ impl Timeline {
})
}
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Stream {
User(i64),
@ -362,11 +275,13 @@ pub enum Stream {
Public,
Unset,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Reach {
Local,
Federated,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Content {
All,

View File

@ -167,89 +167,4 @@ impl<'a> TryFrom<RedisStructuredText<'a>> for RedisParseOutput<'a> {
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn parse_redis_subscribe() -> Result<(), RedisParseErr> {
let input = "*3\r\n$9\r\nsubscribe\r\n$15\r\ntimeline:public\r\n:1\r\n";
let r_subscribe = match RedisParseOutput::try_from(input) {
Ok(NonMsg(leftover)) => leftover,
Ok(Msg(msg)) => panic!("unexpectedly got a msg: {:?}", msg),
Err(e) => panic!("Error in parsing subscribe command: {:?}", e),
};
assert!(r_subscribe.is_empty());
Ok(())
}
#[test]
fn parse_redis_detects_non_newline() -> Result<(), RedisParseErr> {
let input =
"*3QQ$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$38\r\n{\"event\":\"delete\",\"payload\":\"1038647\"}\r\n";
match RedisParseOutput::try_from(input) {
Ok(NonMsg(leftover)) => panic!(
"Parsed an invalid msg as a non-msg.\nInput `{}` parsed to NonMsg({:?})",
&input, leftover
),
Ok(Msg(msg)) => panic!(
"Parsed an invalid msg as a msg.\nInput `{:?}` parsed to {:?}",
&input, msg
),
Err(_) => (), // should err
};
Ok(())
}
fn parse_redis_msg() -> Result<(), RedisParseErr> {
let input =
"*3\r\n$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$38\r\n{\"event\":\"delete\",\"payload\":\"1038647\"}\r\n";
let r_msg = match RedisParseOutput::try_from(input) {
Ok(NonMsg(leftover)) => panic!(
"Parsed a msg as a non-msg.\nInput `{}` parsed to NonMsg({:?})",
&input, leftover
),
Ok(Msg(msg)) => msg,
Err(e) => panic!("Error in parsing subscribe command: {:?}", e),
};
assert!(r_msg.leftover_input.is_empty());
assert_eq!(r_msg.timeline_txt, "timeline:308");
assert_eq!(r_msg.event_txt, r#"{"event":"delete","payload":"1038647"}"#);
Ok(())
}
}
// #[derive(Debug, Clone, PartialEq, Copy)]
// pub struct RedisUtf8<'a> {
// pub valid_utf8: &'a str,
// pub leftover_bytes: &'a [u8],
// }
// impl<'a> From<&'a [u8]> for RedisUtf8<'a> {
// fn from(bytes: &'a [u8]) -> Self {
// match str::from_utf8(bytes) {
// Ok(valid_utf8) => Self {
// valid_utf8,
// leftover_bytes: "".as_bytes(),
// },
// Err(e) => {
// let (valid, after_valid) = bytes.split_at(e.valid_up_to());
// Self {
// valid_utf8: str::from_utf8(valid).expect("Guaranteed by `.valid_up_to`"),
// leftover_bytes: after_valid,
// }
// }
// }
// }
// }
// impl<'a> Default for RedisUtf8<'a> {
// fn default() -> Self {
// Self::from("".as_bytes())
// }
// }
mod test;

View File

@ -0,0 +1,54 @@
use super::*;
#[test]
fn parse_redis_subscribe() -> Result<(), RedisParseErr> {
let input = "*3\r\n$9\r\nsubscribe\r\n$15\r\ntimeline:public\r\n:1\r\n";
let r_subscribe = match RedisParseOutput::try_from(input) {
Ok(NonMsg(leftover)) => leftover,
Ok(Msg(msg)) => panic!("unexpectedly got a msg: {:?}", msg),
Err(e) => panic!("Error in parsing subscribe command: {:?}", e),
};
assert!(r_subscribe.is_empty());
Ok(())
}
#[test]
fn parse_redis_detects_non_newline() -> Result<(), RedisParseErr> {
let input =
"*3QQ$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$38\r\n{\"event\":\"delete\",\"payload\":\"1038647\"}\r\n";
match RedisParseOutput::try_from(input) {
Ok(NonMsg(leftover)) => panic!(
"Parsed an invalid msg as a non-msg.\nInput `{}` parsed to NonMsg({:?})",
&input, leftover
),
Ok(Msg(msg)) => panic!(
"Parsed an invalid msg as a msg.\nInput `{:?}` parsed to {:?}",
&input, msg
),
Err(_) => (), // should err
};
Ok(())
}
fn parse_redis_msg() -> Result<(), RedisParseErr> {
let input =
"*3\r\n$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$38\r\n{\"event\":\"delete\",\"payload\":\"1038647\"}\r\n";
let r_msg = match RedisParseOutput::try_from(input) {
Ok(NonMsg(leftover)) => panic!(
"Parsed a msg as a non-msg.\nInput `{}` parsed to NonMsg({:?})",
&input, leftover
),
Ok(Msg(msg)) => msg,
Err(e) => panic!("Error in parsing subscribe command: {:?}", e),
};
assert!(r_msg.leftover_input.is_empty());
assert_eq!(r_msg.timeline_txt, "timeline:308");
assert_eq!(r_msg.event_txt, r#"{"event":"delete","payload":"1038647"}"#);
Ok(())
}