Add additional error handling

This commit is contained in:
Daniel Sockwell 2020-04-10 17:06:13 -04:00
parent 62df3a56b1
commit 638364883f
12 changed files with 295 additions and 514 deletions

447
old
View File

@ -1,447 +0,0 @@
use crate::log_fatal;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::boxed::Box;
use std::{collections::HashSet, string::String};
pub enum Event {
TypeSafe(CheckedEvent),
Dynamic(DynamicEvent),
}
impl Event {
pub fn to_json_string(&self) -> String {
let event = &self.event_name();
let sendable_event = match self.payload() {
Some(payload) => SendableEvent::WithPayload { event, payload },
None => SendableEvent::NoPayload { event },
};
serde_json::to_string(&sendable_event)
.unwrap_or_else(|_| log_fatal!("Could not serialize `{:?}`", &sendable_event))
}
pub fn event_name(&self) -> String {
String::from(match self {
Self::TypeSafe(checked) => match checked {
CheckedEvent::Update { .. } => "update",
CheckedEvent::Notification { .. } => "notification",
CheckedEvent::Delete { .. } => "delete",
CheckedEvent::Announcement { .. } => "announcement",
CheckedEvent::AnnouncementReaction { .. } => "announcement.reaction",
CheckedEvent::AnnouncementDelete { .. } => "announcement.delete",
CheckedEvent::Conversation { .. } => "conversation",
CheckedEvent::FiltersChanged => "filters_changed",
},
Self::Dynamic(dyn_event) => &dyn_event.event,
})
}
pub fn payload(&self) -> Option<String> {
use CheckedEvent::*;
match self {
Self::TypeSafe(checked) => match checked {
Update { payload, .. } => Some(escaped(payload)),
Notification { payload, .. } => Some(escaped(payload)),
Delete { payload, .. } => Some(payload.0.clone()),
Announcement { payload, .. } => Some(escaped(payload)),
AnnouncementReaction { payload, .. } => Some(escaped(payload)),
AnnouncementDelete { payload, .. } => Some(payload.0.clone()),
Conversation { payload, .. } => Some(escaped(payload)),
FiltersChanged => None,
},
Self::Dynamic(dyn_event) => Some(dyn_event.payload.to_string()),
}
}
}
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct DynamicEvent {
pub event: String,
payload: Value,
queued_at: Option<i64>,
}
#[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)]
#[rustfmt::skip]
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub enum CheckedEvent {
Update { payload: Status, queued_at: Option<i64> },
Notification { payload: Notification },
Delete { payload: DeletedId },
FiltersChanged,
Announcement { payload: Announcement },
#[serde(rename(serialize = "announcement.reaction", deserialize = "announcement.reaction"))]
AnnouncementReaction { payload: AnnouncementReaction },
#[serde(rename(serialize = "announcement.delete", deserialize = "announcement.delete"))]
AnnouncementDelete { payload: DeletedId },
Conversation { payload: Conversation, queued_at: Option<i64> },
}
#[derive(Serialize, Debug, Clone)]
#[serde(untagged)]
pub enum SendableEvent<'a> {
WithPayload { event: &'a str, payload: String },
NoPayload { event: &'a str },
}
fn escaped<T: Serialize + std::fmt::Debug>(content: T) -> String {
serde_json::to_string(&content)
.unwrap_or_else(|_| log_fatal!("Could not parse Event with: `{:?}`", &content))
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Conversation {
id: String,
accounts: Vec<Account>,
unread: bool,
last_status: Option<Status>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct DeletedId(String);
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Status {
id: String,
uri: String,
created_at: String,
account: Account,
content: String,
visibility: Visibility,
sensitive: bool,
spoiler_text: String,
media_attachments: Vec<Attachment>,
application: Option<Application>, // Should be non-optional?
mentions: Vec<Mention>,
tags: Vec<Tag>,
emojis: Vec<Emoji>,
reblogs_count: i64,
favourites_count: i64,
replies_count: i64,
url: Option<String>,
in_reply_to_id: Option<String>,
in_reply_to_account_id: Option<String>,
reblog: Option<Box<Status>>,
poll: Option<Poll>,
card: Option<Card>,
language: Option<String>,
text: Option<String>,
// ↓↓↓ Only for authorized users
favourited: Option<bool>,
reblogged: Option<bool>,
muted: Option<bool>,
bookmarked: Option<bool>,
pinned: Option<bool>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub enum Visibility {
Public,
Unlisted,
Private,
Direct,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Account {
id: String,
username: String,
acct: String,
url: String,
display_name: String,
note: String,
avatar: String,
avatar_static: String,
header: String,
header_static: String,
locked: bool,
emojis: Vec<Emoji>,
discoverable: Option<bool>, // Shouldn't be option?
created_at: String,
statuses_count: i64,
followers_count: i64,
following_count: i64,
moved: Option<Box<String>>,
fields: Option<Vec<Field>>,
bot: Option<bool>,
source: Option<Source>,
group: Option<bool>, // undocumented
last_status_at: Option<String>, // undocumented
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Attachment {
id: String,
r#type: AttachmentType,
url: String,
preview_url: String,
remote_url: Option<String>,
text_url: Option<String>,
meta: Option<serde_json::Value>,
description: Option<String>,
blurhash: Option<String>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum AttachmentType {
Unknown,
Image,
Gifv,
Video,
Audio,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Application {
name: String,
website: Option<String>,
vapid_key: Option<String>,
client_id: Option<String>,
client_secret: Option<String>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Emoji {
shortcode: String,
url: String,
static_url: String,
visible_in_picker: bool,
category: Option<String>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Field {
name: String,
value: String,
verified_at: Option<String>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Source {
note: String,
fields: Vec<Field>,
privacy: Option<Visibility>,
sensitive: bool,
language: String,
follow_requests_count: i64,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Mention {
id: String,
username: String,
acct: String,
url: String,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Tag {
name: String,
url: String,
history: Option<Vec<History>>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Poll {
id: String,
expires_at: String,
expired: bool,
multiple: bool,
votes_count: i64,
voters_count: Option<i64>,
voted: Option<bool>,
own_votes: Option<Vec<i64>>,
options: Vec<PollOptions>,
emojis: Vec<Emoji>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct PollOptions {
title: String,
votes_count: Option<i32>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct Card {
url: String,
title: String,
description: String,
r#type: CardType,
author_name: Option<String>,
author_url: Option<String>,
provider_name: Option<String>,
provider_url: Option<String>,
html: Option<String>,
width: Option<i64>,
height: Option<i64>,
image: Option<String>,
embed_url: Option<String>,
}
#[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum CardType {
Link,
Photo,
Video,
Rich,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
struct History {
day: String,
uses: String,
accounts: String,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Notification {
id: String,
r#type: NotificationType,
created_at: String,
account: Account,
status: Option<Status>,
}
#[serde(rename_all = "snake_case", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
enum NotificationType {
Follow,
FollowRequest, // Undocumented
Mention,
Reblog,
Favourite,
Poll,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Announcement {
// Fully undocumented
id: String,
tags: Vec<Tag>,
all_day: bool,
content: String,
emojis: Vec<Emoji>,
starts_at: Option<String>,
ends_at: Option<String>,
published_at: String,
updated_at: String,
mentions: Vec<Mention>,
reactions: Vec<AnnouncementReaction>,
}
#[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct AnnouncementReaction {
#[serde(skip_serializing_if = "Option::is_none")]
announcement_id: Option<String>,
count: i64,
name: String,
}
impl Status {
/// Returns `true` if the status is filtered out based on its language
pub fn language_not_allowed(&self, allowed_langs: &HashSet<String>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let reject_and_maybe_log = |toot_language| {
log::info!("Filtering out toot from `{}`", &self.account.acct);
log::info!("Toot language: `{}`", toot_language);
log::info!("Recipient's allowed languages: `{:?}`", allowed_langs);
REJECT
};
if allowed_langs.is_empty() {
return ALLOW; // listing no allowed_langs results in allowing all languages
}
match self.language.as_ref() {
Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW,
None => ALLOW, // If toot language is unknown, toot is always allowed
Some(empty) if empty == &String::new() => ALLOW,
Some(toot_language) => reject_and_maybe_log(toot_language),
}
}
/// Returns `true` if this toot originated from a domain the User has blocked.
pub fn from_blocked_domain(&self, blocked_domains: &HashSet<String>) -> bool {
let full_username = &self.account.acct;
match full_username.split('@').nth(1) {
Some(originating_domain) => blocked_domains.contains(originating_domain),
None => false, // None means the user is on the local instance, which can't be blocked
}
}
/// Returns `true` if the Status is from an account that has blocked the current user.
pub fn from_blocking_user(&self, blocking_users: &HashSet<i64>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let err = |_| log_fatal!("Could not process `account.id` in {:?}", &self);
if blocking_users.contains(&self.account.id.parse().unwrap_or_else(err)) {
REJECT
} else {
ALLOW
}
}
/// Returns `true` if the User's list of blocked and muted users includes a user
/// involved in this toot.
///
/// A user is involved if they:
/// * Are mentioned in this toot
/// * Wrote this toot
/// * Wrote a toot that this toot is replying to (if any)
/// * Wrote the toot that this toot is boosting (if any)
pub fn involves_blocked_user(&self, blocked_users: &HashSet<i64>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self);
// involved_users = mentioned_users + author + replied-to user + boosted user
let mut involved_users: HashSet<i64> = self
.mentions
.iter()
.map(|mention| mention.id.parse().unwrap_or_else(err))
.collect();
involved_users.insert(self.account.id.parse::<i64>().unwrap_or_else(err));
if let Some(replied_to_account_id) = self.in_reply_to_account_id.clone() {
involved_users.insert(replied_to_account_id.parse().unwrap_or_else(err));
}
if let Some(boosted_status) = self.reblog.clone() {
involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err));
}
if involved_users.is_disjoint(blocked_users) {
ALLOW
} else {
REJECT
}
}
}
#[cfg(test)]
mod test;

View File

@ -1,17 +1,42 @@
use super::super::EventErr;
use serde::{
de::{self, Visitor},
Deserialize, Deserializer, Serialize, Serializer,
};
use std::fmt;
use serde_json::Value;
use std::{convert::TryFrom, fmt, num::ParseIntError, str::FromStr};
/// A user ID.
///
/// Internally, Mastodon IDs are i64s, but are sent to clients as string because
/// JavaScript numbers don't support i64s. This newtype serializes to/from a string, but
/// keeps the i64 as the "true" value for internal use.
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct Id(pub i64);
impl TryFrom<&Value> for Id {
type Error = EventErr;
fn try_from(v: &Value) -> Result<Self, Self::Error> {
Ok(v.as_str().ok_or(EventErr::DynParse)?.parse()?)
}
}
impl std::ops::Deref for Id {
type Target = i64;
fn deref(&self) -> &i64 {
&self.0
}
}
impl FromStr for Id {
type Err = ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.parse()?))
}
}
impl Serialize for Id {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
@ -38,6 +63,13 @@ impl<'de> Visitor<'de> for IdVisitor {
formatter.write_str("a string that can be parsed into an i64")
}
fn visit_str<E: de::Error>(self, value: &str) -> Result<Self::Value, E> {
match value.parse() {
Ok(n) => Ok(Id(n)),
Err(e) => Err(E::custom(format!("could not parse: {}", e))),
}
}
fn visit_string<E: de::Error>(self, value: String) -> Result<Self::Value, E> {
match value.parse() {
Ok(n) => Ok(Id(n)),

View File

@ -14,6 +14,7 @@ mod visibility;
pub use announcement::Announcement;
pub(in crate::messages::event) use announcement_reaction::AnnouncementReaction;
pub use conversation::Conversation;
pub use id::Id;
pub use notification::Notification;
pub use status::Status;

View File

@ -92,7 +92,7 @@ impl Status {
blocking_users,
blocked_domains,
} = blocks;
let user_id = &self.account.id.0;
let user_id = &Id(self.account.id.0);
if blocking_users.contains(user_id) || self.involves(blocked_users) {
REJECT
@ -105,20 +105,23 @@ impl Status {
}
}
fn involves(&self, blocked_users: &HashSet<i64>) -> bool {
fn involves(&self, blocked_users: &HashSet<Id>) -> bool {
// involved_users = mentioned_users + author + replied-to user + boosted user
let mut involved_users: HashSet<i64> =
self.mentions.iter().map(|mention| mention.id.0).collect();
let mut involved_users: HashSet<Id> = self
.mentions
.iter()
.map(|mention| Id(mention.id.0))
.collect();
// author
involved_users.insert(self.account.id.0);
involved_users.insert(Id(self.account.id.0));
// replied-to user
if let Some(user_id) = self.in_reply_to_account_id.clone() {
involved_users.insert(user_id.0);
if let Some(user_id) = self.in_reply_to_account_id {
involved_users.insert(Id(user_id.0));
}
// boosted user
if let Some(boosted_status) = self.reblog.clone() {
involved_users.insert(boosted_status.account.id.0);
involved_users.insert(Id(boosted_status.account.id.0));
}
!involved_users.is_disjoint(blocked_users)
}

View File

@ -0,0 +1,136 @@
use super::{EventErr, Id};
use crate::parse_client_request::Blocks;
use std::convert::TryFrom;
use hashbrown::HashSet;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct DynEvent {
#[serde(skip)]
pub kind: EventKind,
pub event: String,
pub payload: Value,
pub queued_at: Option<i64>,
}
#[derive(Debug, Clone, PartialEq)]
pub enum EventKind {
Update(DynStatus),
NonUpdate,
}
impl Default for EventKind {
fn default() -> Self {
Self::NonUpdate
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct DynStatus {
pub id: Id,
pub username: String,
pub language: Option<String>,
pub mentioned_users: HashSet<Id>,
pub replied_to_user: Option<Id>,
pub boosted_user: Option<Id>,
pub payload: Value,
}
type Result<T> = std::result::Result<T, EventErr>; // TODO cut if not used more than once
impl DynEvent {
pub fn set_update(self) -> Result<Self> {
if self.event == "update" {
let kind = EventKind::Update(DynStatus::new(self.payload.clone())?);
Ok(Self { kind, ..self })
} else {
Ok(self)
}
}
}
impl DynStatus {
pub fn new(payload: Value) -> Result<Self> {
use EventErr::*;
Ok(Self {
id: Id::try_from(&payload["account"]["id"])?,
username: payload["account"]["acct"]
.as_str()
.ok_or(DynParse)?
.to_string(),
language: payload["language"].as_str().map(|s| s.to_string()),
mentioned_users: HashSet::new(),
replied_to_user: Id::try_from(&payload["in_reply_to_account_id"]).ok(),
boosted_user: Id::try_from(&payload["reblog"]["account"]["id"]).ok(),
payload,
})
}
/// Returns `true` if the status is filtered out based on its language
pub fn language_not(&self, allowed_langs: &HashSet<String>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
if allowed_langs.is_empty() {
return ALLOW; // listing no allowed_langs results in allowing all languages
}
match self.language.clone() {
Some(toot_language) if allowed_langs.contains(&toot_language) => ALLOW, //
None => ALLOW, // If toot language is unknown, toot is always allowed
Some(empty) if empty == String::new() => ALLOW,
Some(_toot_language) => REJECT,
}
}
/// Returns `true` if the toot contained in this Event originated from a blocked domain,
/// is from an account that has blocked the current user, or if the User's list of
/// blocked/muted users includes a user involved in the toot.
///
/// A user is involved in the toot if they:
/// * Are mentioned in this toot
/// * Wrote this toot
/// * Wrote a toot that this toot is replying to (if any)
/// * Wrote the toot that this toot is boosting (if any)
pub fn involves_any(&self, blocks: &Blocks) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let Blocks {
blocked_users,
blocking_users,
blocked_domains,
} = blocks;
if self.involves(blocked_users) || blocking_users.contains(&self.id) {
REJECT
} else {
match self.username.split('@').nth(1) {
Some(originating_domain) if blocked_domains.contains(originating_domain) => REJECT,
Some(_) | None => ALLOW, // None means the local instance, which can't be blocked
}
}
}
// involved_users = mentioned_users + author + replied-to user + boosted user
fn involves(&self, blocked_users: &HashSet<Id>) -> bool {
// mentions
let mut involved_users: HashSet<Id> = self.mentioned_users.clone();
// author
involved_users.insert(self.id);
// replied-to user
if let Some(user_id) = self.replied_to_user {
involved_users.insert(user_id);
}
// boosted user
if let Some(user_id) = self.boosted_user {
involved_users.insert(user_id);
}
!involved_users.is_disjoint(blocked_users)
}
}

33
src/messages/event/err.rs Normal file
View File

@ -0,0 +1,33 @@
use std::{fmt, num::ParseIntError};
#[derive(Debug)]
pub enum EventErr {
SerdeParse(serde_json::Error),
NonNumId(ParseIntError),
DynParse,
}
impl std::error::Error for EventErr {}
impl fmt::Display for EventErr {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
use EventErr::*;
match self {
SerdeParse(inner) => write!(f, "{}", inner),
NonNumId(inner) => write!(f, "ID could not be parsed: {}", inner),
DynParse => write!(f, "Could not find a required field in input JSON"),
}?;
Ok(())
}
}
impl From<ParseIntError> for EventErr {
fn from(error: ParseIntError) -> Self {
Self::NonNumId(error)
}
}
impl From<serde_json::Error> for EventErr {
fn from(error: serde_json::Error) -> Self {
Self::SerdeParse(error)
}
}

View File

@ -1,17 +1,21 @@
mod checked_event;
mod dynamic_event;
mod err;
pub use {checked_event::CheckedEvent, dynamic_event::DynamicEvent};
pub use {
checked_event::{CheckedEvent, Id},
dynamic_event::{DynEvent, DynStatus, EventKind},
err::EventErr,
};
use crate::log_fatal;
use crate::redis_to_client_stream::ReceiverErr;
use serde::Serialize;
use std::{convert::TryFrom, string::String};
#[derive(Debug, Clone)]
pub enum Event {
TypeSafe(CheckedEvent),
Dynamic(DynamicEvent),
Dynamic(DynEvent),
Ping,
}
@ -38,7 +42,11 @@ impl Event {
CheckedEvent::Conversation { .. } => "conversation",
CheckedEvent::FiltersChanged => "filters_changed",
},
Self::Dynamic(dyn_event) => &dyn_event.event,
Self::Dynamic(DynEvent {
kind: EventKind::Update(_),
..
}) => "update",
Self::Dynamic(DynEvent { event, .. }) => event,
Self::Ping => panic!("event_name() called on EventNotReady"),
})
}
@ -56,21 +64,23 @@ impl Event {
Conversation { payload, .. } => Some(escaped(payload)),
FiltersChanged => None,
},
Self::Dynamic(dyn_event) => Some(dyn_event.payload.to_string()),
Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()),
Self::Ping => panic!("payload() called on EventNotReady"),
}
}
}
impl TryFrom<String> for Event {
type Error = ReceiverErr;
fn try_from(event_txt: String) -> Result<Event, ReceiverErr> {
type Error = EventErr;
fn try_from(event_txt: String) -> Result<Event, Self::Error> {
Event::try_from(event_txt.as_str())
}
}
impl TryFrom<&str> for Event {
type Error = ReceiverErr;
fn try_from(event_txt: &str) -> Result<Event, ReceiverErr> {
type Error = EventErr;
fn try_from(event_txt: &str) -> Result<Event, Self::Error> {
match serde_json::from_str(event_txt) {
Ok(checked_event) => Ok(Event::TypeSafe(checked_event)),
Err(e) => {
@ -80,8 +90,8 @@ impl TryFrom<&str> for Event {
Forwarding Redis payload without type checking it.",
e
);
let dyn_event: DynamicEvent = serde_json::from_str(&event_txt)?;
Ok(Event::Dynamic(dyn_event))
Ok(Event::Dynamic(serde_json::from_str(&event_txt)?))
}
}
}

View File

@ -1,3 +1,3 @@
mod event;
pub use event::{CheckedEvent, DynamicEvent, Event};
pub use event::{CheckedEvent, DynEvent, Event, EventErr, EventKind, Id};

View File

@ -1,6 +1,7 @@
//! Postgres queries
use crate::{
config,
messages::Id,
parse_client_request::subscription::{Scope, UserData},
};
use ::postgres;
@ -28,6 +29,7 @@ impl PgPool {
.expect("Can connect to local postgres");
Self(pool)
}
pub fn select_user(self, token: &str) -> Result<UserData, Rejection> {
let mut conn = self.0.get().unwrap();
let query_rows = conn
@ -45,7 +47,7 @@ LIMIT 1",
)
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) {
let id = result_columns.get(1);
let id = Id(result_columns.get(1));
let allowed_langs = result_columns
.try_get::<_, Vec<_>>(2)
.unwrap_or_else(|_| Vec::new())
@ -96,17 +98,16 @@ LIMIT 1",
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
Some(row) => Ok(row.get(0)),
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
}
rows.get(0)
.map(|row| row.get(0))
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
}
/// Query Postgres for everyone the user has blocked or muted
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_users(self, user_id: i64) -> HashSet<i64> {
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
self.0
.get()
.unwrap()
@ -118,18 +119,18 @@ SELECT target_account_id
UNION SELECT target_account_id
FROM mutes
WHERE account_id = $1",
&[&user_id],
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.map(|row| Id(row.get(0)))
.collect()
}
/// Query Postgres for everyone who has blocked the user
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(self, user_id: i64) -> HashSet<i64> {
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
self.0
.get()
.unwrap()
@ -138,11 +139,11 @@ UNION SELECT target_account_id
SELECT account_id
FROM blocks
WHERE target_account_id = $1",
&[&user_id],
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.map(|row| Id(row.get(0)))
.collect()
}
@ -150,13 +151,13 @@ SELECT account_id
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(self, user_id: i64) -> HashSet<String> {
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
self.0
.get()
.unwrap()
.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&user_id],
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
@ -165,7 +166,7 @@ SELECT account_id
}
/// Test whether a user owns a list
pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool {
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
let mut conn = self.0.get().unwrap();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
@ -181,10 +182,7 @@ LIMIT 1",
match rows.get(0) {
None => false,
Some(row) => {
let list_owner_id: i64 = row.get(1);
list_owner_id == user_id
}
Some(row) => Id(row.get(1)) == user_id,
}
}
}

View File

@ -9,6 +9,7 @@ use super::postgres::PgPool;
use super::query::Query;
use crate::err::TimelineErr;
use crate::log_fatal;
use crate::messages::Id;
use hashbrown::HashSet;
use lru::LruCache;
use uuid::Uuid;
@ -62,8 +63,8 @@ pub struct Subscription {
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Blocks {
pub blocked_domains: HashSet<String>,
pub blocked_users: HashSet<i64>,
pub blocking_users: HashSet<i64>,
pub blocked_users: HashSet<Id>,
pub blocking_users: HashSet<Id>,
}
impl Default for Subscription {
@ -254,11 +255,11 @@ impl Timeline {
"hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All),
"hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All),
"user" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, All),
true => Timeline(User(*user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
"user:notification" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, Notification),
true => Timeline(User(*user.id), Federated, Notification),
false => Err(custom("Error: Missing access token"))?,
},
"list" => match user.scopes.contains(&Lists) && user_owns_list() {
@ -266,7 +267,7 @@ impl Timeline {
false => Err(warp::reject::custom("Error: Missing access token"))?,
},
"direct" => match user.scopes.contains(&Statuses) {
true => Timeline(Direct(user.id), Federated, All),
true => Timeline(Direct(*user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
other => {
@ -309,7 +310,7 @@ pub enum Scope {
}
pub struct UserData {
pub id: i64,
pub id: Id,
pub allowed_langs: HashSet<String>,
pub scopes: HashSet<Scope>,
}
@ -317,7 +318,7 @@ pub struct UserData {
impl UserData {
fn public() -> Self {
Self {
id: -1,
id: Id(-1),
allowed_langs: HashSet::new(),
scopes: HashSet::new(),
}

View File

@ -56,7 +56,7 @@ impl WsStream {
if matches!(event, Event::Ping) {
self.send_ping()
} else if target_timeline == tl {
use crate::messages::{CheckedEvent::Update, Event::*};
use crate::messages::{CheckedEvent::Update, Event::*, EventKind};
use crate::parse_client_request::Stream::Public;
let blocks = &self.subscription.blocks;
let allowed_langs = &self.subscription.allowed_langs;
@ -68,12 +68,17 @@ impl WsStream {
_ => self.send_msg(TypeSafe(Update { payload, queued_at })),
},
TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)),
Dynamic(event) if event.event == "update" => match tl {
Timeline(Public, _, _) if event.language_not(allowed_langs) => Ok(()),
_ if event.involves_any(&blocks) => Ok(()),
_ => self.send_msg(Dynamic(event)),
},
Dynamic(non_update) => self.send_msg(Dynamic(non_update)),
Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind.clone() {
match tl {
Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()),
_ if s.involves_any(&blocks) => Ok(()),
_ => self.send_msg(Dynamic(dyn_event)),
}
} else {
self.send_msg(Dynamic(dyn_event))
}
}
Ping => unreachable!(), // handled pings above
}
} else {
@ -127,7 +132,10 @@ impl SseStream {
let event_stream = sse_rx
.filter_map(move |(timeline, event)| {
if target_timeline == timeline {
use crate::messages::{CheckedEvent, CheckedEvent::Update, Event::*};
use crate::messages::{
CheckedEvent, CheckedEvent::Update, Event::*, EventKind,
};
use crate::parse_client_request::Stream::Public;
match event {
TypeSafe(Update { payload, queued_at }) => match timeline {
@ -139,12 +147,19 @@ impl SseStream {
})),
},
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
Dynamic(event) if event.event == "update" => match timeline {
Timeline(Public, _, _) if event.language_not(&allowed_langs) => None,
_ if event.involves_any(&blocks) => None,
_ => Self::reply_with(Event::Dynamic(event)),
},
Dynamic(non_update) => Self::reply_with(Event::Dynamic(non_update)),
Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind.clone() {
match timeline {
Timeline(Public, _, _) if s.language_not(&allowed_langs) => {
None
}
_ if s.involves_any(&blocks) => None,
_ => Self::reply_with(Dynamic(dyn_event)),
}
} else {
None
}
}
Ping => None, // pings handled automatically
}
} else {

View File

@ -1,16 +1,14 @@
use super::super::redis::{RedisConnErr, RedisParseErr};
use crate::err::TimelineErr;
use crate::messages::Event;
use crate::messages::{Event, EventErr};
use crate::parse_client_request::Timeline;
use serde_json;
use std::fmt;
#[derive(Debug)]
pub enum ReceiverErr {
InvalidId,
TimelineErr(TimelineErr),
EventErr(serde_json::Error),
EventErr(EventErr),
RedisParseErr(RedisParseErr),
RedisConnErr(RedisConnErr),
ChannelSendErr(tokio::sync::watch::error::SendError<(Timeline, Event)>),
@ -35,14 +33,15 @@ impl fmt::Display for ReceiverErr {
Ok(())
}
}
impl From<tokio::sync::watch::error::SendError<(Timeline, Event)>> for ReceiverErr {
fn from(error: tokio::sync::watch::error::SendError<(Timeline, Event)>) -> Self {
Self::ChannelSendErr(error)
}
}
impl From<serde_json::Error> for ReceiverErr {
fn from(error: serde_json::Error) -> Self {
impl From<EventErr> for ReceiverErr {
fn from(error: EventErr) -> Self {
Self::EventErr(error)
}
}