mirror of https://github.com/mastodon/flodgatt
WIP implementation of Message refactor
This commit is contained in:
parent
c0355827fb
commit
4df364d1ac
|
@ -10,19 +10,6 @@ use super::query::Query;
|
|||
use std::collections::HashSet;
|
||||
use warp::reject::Rejection;
|
||||
|
||||
/// The filters that can be applied to toots after they come from Redis
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum Filter {
|
||||
NoFilter,
|
||||
Language,
|
||||
Notification,
|
||||
}
|
||||
impl Default for Filter {
|
||||
fn default() -> Self {
|
||||
Filter::Language
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct OauthScope {
|
||||
pub all: bool,
|
||||
|
@ -60,9 +47,8 @@ pub struct User {
|
|||
pub access_token: String, // We only need this once (to send back with the WS reply). Cut?
|
||||
pub id: i64,
|
||||
pub scopes: OauthScope,
|
||||
pub langs: Option<Vec<String>>,
|
||||
pub logged_in: bool,
|
||||
pub filter: Filter,
|
||||
pub allowed_langs: HashSet<String>,
|
||||
pub blocks: Blocks,
|
||||
}
|
||||
|
||||
|
@ -73,10 +59,9 @@ impl Default for User {
|
|||
email: "".to_string(),
|
||||
access_token: "".to_string(),
|
||||
scopes: OauthScope::default(),
|
||||
langs: None,
|
||||
logged_in: false,
|
||||
target_timeline: String::new(),
|
||||
filter: Filter::default(),
|
||||
allowed_langs: HashSet::new(),
|
||||
blocks: Blocks::default(),
|
||||
}
|
||||
}
|
||||
|
@ -97,33 +82,30 @@ impl User {
|
|||
Ok(user)
|
||||
}
|
||||
|
||||
fn set_timeline_and_filter(mut self, q: Query, pool: PgPool) -> Result<Self, Rejection> {
|
||||
let read_scope = self.scopes.clone();
|
||||
let timeline = match q.stream.as_ref() {
|
||||
fn set_timeline_and_filter(self, q: Query, pool: PgPool) -> Result<Self, Rejection> {
|
||||
let (read_scope, f) = (self.scopes.clone(), self.allowed_langs.clone());
|
||||
let (filter, target_timeline) = match q.stream.as_ref() {
|
||||
// Public endpoints:
|
||||
tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
|
||||
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
|
||||
tl @ "public" | tl @ "public:local" => tl.to_string(),
|
||||
tl @ "public" | tl @ "public:local" if q.media => (f, format!("{}:media", tl)),
|
||||
tl @ "public:media" | tl @ "public:local:media" => (f, tl.to_string()),
|
||||
tl @ "public" | tl @ "public:local" => (f, tl.to_string()),
|
||||
|
||||
// Hashtag endpoints:
|
||||
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
|
||||
tl @ "hashtag" | tl @ "hashtag:local" => (f, format!("{}:{}", tl, q.hashtag)),
|
||||
// Private endpoints: User:
|
||||
"user" if self.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
self.filter = Filter::NoFilter;
|
||||
format!("{}", self.id)
|
||||
(HashSet::new(), format!("{}", self.id))
|
||||
}
|
||||
"user:notification" if self.logged_in && (read_scope.all || read_scope.notify) => {
|
||||
self.filter = Filter::Notification;
|
||||
format!("{}", self.id)
|
||||
(HashSet::new(), format!("{}", self.id))
|
||||
}
|
||||
// List endpoint:
|
||||
"list" if self.owns_list(q.list, pool) && (read_scope.all || read_scope.lists) => {
|
||||
self.filter = Filter::NoFilter;
|
||||
format!("list:{}", q.list)
|
||||
(HashSet::new(), format!("list:{}", q.list))
|
||||
}
|
||||
// Direct endpoint:
|
||||
"direct" if self.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
self.filter = Filter::NoFilter;
|
||||
"direct".to_string()
|
||||
(HashSet::new(), "direct".to_string())
|
||||
}
|
||||
// Reject unathorized access attempts for private endpoints
|
||||
"user" | "user:notification" | "direct" | "list" => {
|
||||
|
@ -133,7 +115,8 @@ impl User {
|
|||
_ => return Err(warp::reject::custom("Error: Nonexistent endpoint")),
|
||||
};
|
||||
Ok(Self {
|
||||
target_timeline: timeline,
|
||||
target_timeline,
|
||||
allowed_langs: filter,
|
||||
..self
|
||||
})
|
||||
}
|
||||
|
|
|
@ -53,19 +53,27 @@ LIMIT 1",
|
|||
if query_result.is_empty() {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
} else {
|
||||
// TODO: better name than `only_row`
|
||||
let only_row: &postgres::Row = query_result.get(0).unwrap();
|
||||
let scope_vec: Vec<String> = only_row
|
||||
.get::<_, String>(4)
|
||||
.split(' ')
|
||||
.map(|s| s.to_owned())
|
||||
.collect();
|
||||
let mut allowed_langs = HashSet::new();
|
||||
if let Ok(langs_vec) = only_row.try_get::<_, Vec<String>>(3) {
|
||||
for lang in langs_vec {
|
||||
allowed_langs.insert(lang);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(User {
|
||||
id: only_row.get(1),
|
||||
access_token: access_token.to_string(),
|
||||
email: only_row.get(2),
|
||||
logged_in: true,
|
||||
access_token: access_token.to_string(),
|
||||
id: only_row.get(1),
|
||||
scopes: OauthScope::from(scope_vec),
|
||||
langs: only_row.get(3),
|
||||
logged_in: true,
|
||||
allowed_langs,
|
||||
..User::default()
|
||||
})
|
||||
}
|
||||
|
|
|
@ -15,11 +15,10 @@
|
|||
//! Because `StreamManagers` are lightweight data structures that do not directly
|
||||
//! communicate with Redis, it we create a new `ClientAgent` for
|
||||
//! each new client connection (each in its own thread).
|
||||
use super::receiver::Receiver;
|
||||
use super::{message::Message, receiver::Receiver};
|
||||
use crate::{config, parse_client_request::user::User};
|
||||
use futures::{Async, Poll};
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashSet, fmt::Display, sync};
|
||||
use std::sync;
|
||||
use tokio::io::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -71,7 +70,7 @@ impl ClientAgent {
|
|||
|
||||
/// The stream that the `ClientAgent` manages. `Poll` is the only method implemented.
|
||||
impl futures::stream::Stream for ClientAgent {
|
||||
type Item = Toot;
|
||||
type Item = Message;
|
||||
type Error = Error;
|
||||
|
||||
/// Checks for any new messages that should be sent to the client.
|
||||
|
@ -96,135 +95,16 @@ impl futures::stream::Stream for ClientAgent {
|
|||
log::warn!("Polling the Receiver took: {:?}", start_time.elapsed());
|
||||
};
|
||||
|
||||
let (filter, blocks) = (&self.current_user.allowed_langs, &self.current_user.blocks);
|
||||
match result {
|
||||
Ok(Async::Ready(Some(value))) => {
|
||||
let user = &self.current_user;
|
||||
let toot = Toot::from_json(value);
|
||||
toot.filter(&user)
|
||||
}
|
||||
Ok(Async::Ready(Some(json))) => match Message::from_json(json) {
|
||||
Message::Update(status) if status.is_filtered_out(filter) => Ok(Async::NotReady),
|
||||
Message::Update(status) if status.is_blocked(blocks) => Ok(Async::NotReady),
|
||||
no_filtering_needed => Ok(Async::Ready(Some(no_filtering_needed))),
|
||||
},
|
||||
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The message to send to the client (which might not literally be a toot in some cases).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Toot {
|
||||
pub event_type: Event,
|
||||
pub language: Option<String>,
|
||||
pub payload: Value,
|
||||
}
|
||||
|
||||
use std::fmt;
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Event {
|
||||
Update,
|
||||
Delete,
|
||||
}
|
||||
|
||||
impl Display for Event {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Update => write!(f, "update"),
|
||||
Self::Delete => write!(f, "delete"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Toot {
|
||||
/// Construct a `Toot` from well-formed JSON.
|
||||
pub fn from_json(value: Value) -> Self {
|
||||
let payload = value["payload"].clone();
|
||||
match value["event"].as_str().expect("Redis") {
|
||||
"update" => Self {
|
||||
event_type: Event::Update,
|
||||
language: Some(payload["language"].as_str().expect("Redis").into()),
|
||||
payload,
|
||||
},
|
||||
"delete" => Self {
|
||||
event_type: Event::Delete,
|
||||
language: None,
|
||||
payload,
|
||||
},
|
||||
other => panic!("Unknown event type `{}` received.", other),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_originating_domain(&self) -> HashSet<String> {
|
||||
let api = "originating Invariant Violation: JSON value does not conform to Mastdon API";
|
||||
let mut originating_domain = HashSet::new();
|
||||
// TODO: make this log an error instead of panicking.
|
||||
originating_domain.insert(
|
||||
self.payload["account"]["acct"]
|
||||
.as_str()
|
||||
.expect(&api)
|
||||
.split('@')
|
||||
.nth(1)
|
||||
.expect(&api)
|
||||
.to_string(),
|
||||
);
|
||||
originating_domain
|
||||
}
|
||||
|
||||
pub fn get_involved_users(&self) -> HashSet<i64> {
|
||||
let mut involved_users: HashSet<i64> = HashSet::new();
|
||||
let msg = self.payload.clone();
|
||||
|
||||
let api = "Invariant Violation: JSON value does not conform to Mastdon API";
|
||||
involved_users.insert(msg["account"]["id"].str_to_i64().expect(&api));
|
||||
if let Some(mentions) = msg["mentions"].as_array() {
|
||||
for mention in mentions {
|
||||
involved_users.insert(mention["id"].str_to_i64().expect(&api));
|
||||
}
|
||||
}
|
||||
if let Some(replied_to_account) = msg["in_reply_to_account_id"].as_str() {
|
||||
involved_users.insert(replied_to_account.parse().expect(&api));
|
||||
}
|
||||
|
||||
if let Some(reblog) = msg["reblog"].as_object() {
|
||||
involved_users.insert(reblog["account"]["id"].str_to_i64().expect(&api));
|
||||
}
|
||||
involved_users
|
||||
}
|
||||
|
||||
/// Filter out any `Toot`'s that fail the provided filter.
|
||||
pub fn filter(self, user: &User) -> Result<Async<Option<Self>>, Error> {
|
||||
let toot_language = &self.language.clone().expect("Valid lanugage");
|
||||
let event_type = &self.event_type.clone();
|
||||
let (send_msg, skip_msg) = (Ok(Async::Ready(Some(self))), Ok(Async::NotReady));
|
||||
|
||||
match event_type {
|
||||
Event::Update => {
|
||||
use crate::parse_client_request::user::Filter;
|
||||
|
||||
match &user.filter {
|
||||
Filter::NoFilter => send_msg,
|
||||
Filter::Language if user.langs.is_none() => send_msg,
|
||||
Filter::Language if user.langs.clone().expect("").contains(toot_language) => {
|
||||
send_msg
|
||||
}
|
||||
// If not, skip it
|
||||
Filter::Notification => skip_msg,
|
||||
Filter::Language => skip_msg,
|
||||
}
|
||||
}
|
||||
Event::Delete => send_msg,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait ConvertValue {
|
||||
fn str_to_i64(&self) -> Result<i64, Box<dyn std::error::Error>>;
|
||||
}
|
||||
|
||||
impl ConvertValue for Value {
|
||||
fn str_to_i64(&self) -> Result<i64, Box<dyn std::error::Error>> {
|
||||
Ok(self
|
||||
.as_str()
|
||||
.ok_or(format!("{} is not a string", &self))?
|
||||
.parse()
|
||||
.map_err(|_| "Could not parse str")?)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
use crate::parse_client_request::user::Blocks;
|
||||
use serde_json::Value;
|
||||
use std::{collections::HashSet, string::String};
|
||||
use strum_macros::Display;
|
||||
|
||||
#[derive(Debug, Display, Clone)]
|
||||
pub enum Message {
|
||||
Update(Status),
|
||||
Conversation(Value),
|
||||
Notification(Value),
|
||||
Delete(String),
|
||||
FiltersChanged,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Status(pub Value);
|
||||
|
||||
impl Message {
|
||||
pub fn from_json(json: Value) -> Self {
|
||||
match json["event"].as_str().unwrap() {
|
||||
"update" => Self::Update(Status(json["payload"].clone())),
|
||||
"conversation" => Self::Conversation(json["payload"].clone()),
|
||||
"notification" => Self::Notification(json["payload"].clone()),
|
||||
"delete" => Self::Delete(json["payload"].to_string()),
|
||||
"filters_changed" => Self::FiltersChanged,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
pub fn event(&self) -> String {
|
||||
format!("{}", self)
|
||||
}
|
||||
pub fn payload(&self) -> String {
|
||||
match self {
|
||||
Self::Delete(id) => id.clone(),
|
||||
Self::Update(status) => status.0.to_string(),
|
||||
Self::Conversation(value) | Self::Notification(value) => value.to_string(),
|
||||
Self::FiltersChanged => "".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Status {
|
||||
pub fn get_originating_domain(&self) -> HashSet<String> {
|
||||
let api = "originating Invariant Violation: JSON value does not conform to Mastodon API";
|
||||
let mut originating_domain = HashSet::new();
|
||||
// TODO: make this log an error instead of panicking.
|
||||
originating_domain.insert(
|
||||
self.0["account"]["acct"]
|
||||
.as_str()
|
||||
.expect(&api)
|
||||
.split('@')
|
||||
.nth(1)
|
||||
.expect(&api)
|
||||
.to_string(),
|
||||
);
|
||||
originating_domain
|
||||
}
|
||||
|
||||
pub fn get_involved_users(&self) -> HashSet<i64> {
|
||||
let mut involved_users: HashSet<i64> = HashSet::new();
|
||||
let msg = self.0.clone();
|
||||
|
||||
let api = "Invariant Violation: JSON value does not conform to Mastodon API";
|
||||
involved_users.insert(msg["account"]["id"].str_to_i64().expect(&api));
|
||||
if let Some(mentions) = msg["mentions"].as_array() {
|
||||
for mention in mentions {
|
||||
involved_users.insert(mention["id"].str_to_i64().expect(&api));
|
||||
}
|
||||
}
|
||||
if let Some(replied_to_account) = msg["in_reply_to_account_id"].as_str() {
|
||||
involved_users.insert(replied_to_account.parse().expect(&api));
|
||||
}
|
||||
|
||||
if let Some(reblog) = msg["reblog"].as_object() {
|
||||
involved_users.insert(reblog["account"]["id"].str_to_i64().expect(&api));
|
||||
}
|
||||
involved_users
|
||||
}
|
||||
|
||||
pub fn is_filtered_out(&self, permitted_langs: &HashSet<String>) -> bool {
|
||||
// TODO add logging
|
||||
let toot_language = self.0["language"]
|
||||
.as_str()
|
||||
.expect("Valid language")
|
||||
.to_string();
|
||||
!{ permitted_langs.is_empty() || permitted_langs.contains(&toot_language) }
|
||||
}
|
||||
|
||||
/// Returns `true` if the status is blocked by _either_ domain blocks or _user_ blocks
|
||||
pub fn is_blocked(&self, b: &Blocks) -> bool {
|
||||
// TODO add logging
|
||||
!{
|
||||
b.domain_blocks.is_disjoint(&self.get_originating_domain())
|
||||
&& b.user_blocks.is_disjoint(&self.get_involved_users())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait ConvertValue {
|
||||
fn str_to_i64(&self) -> Result<i64, Box<dyn std::error::Error>>;
|
||||
}
|
||||
|
||||
impl ConvertValue for Value {
|
||||
fn str_to_i64(&self) -> Result<i64, Box<dyn std::error::Error>> {
|
||||
Ok(self
|
||||
.as_str()
|
||||
.ok_or(format!("{} is not a string", &self))?
|
||||
.parse()
|
||||
.map_err(|_| "Could not parse str")?)
|
||||
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis.
|
||||
pub mod client_agent;
|
||||
pub mod message;
|
||||
pub mod receiver;
|
||||
pub mod redis;
|
||||
|
||||
|
@ -17,9 +18,9 @@ pub fn send_updates_to_sse(
|
|||
) -> impl warp::reply::Reply {
|
||||
let event_stream = tokio::timer::Interval::new(time::Instant::now(), update_interval)
|
||||
.filter_map(move |_| match client_agent.poll() {
|
||||
Ok(Async::Ready(Some(toot))) => Some((
|
||||
warp::sse::event(toot.event_type),
|
||||
warp::sse::data(toot.payload),
|
||||
Ok(Async::Ready(Some(msg))) => Some((
|
||||
warp::sse::event(msg.event()),
|
||||
warp::sse::data(msg.payload()),
|
||||
)),
|
||||
_ => None,
|
||||
});
|
||||
|
@ -82,32 +83,16 @@ pub fn send_updates_to_ws(
|
|||
|
||||
let mut time = time::Instant::now();
|
||||
|
||||
let (tl, email, id, blocked_users, blocked_domains) = (
|
||||
client_agent.current_user.target_timeline.clone(),
|
||||
client_agent.current_user.email.clone(),
|
||||
client_agent.current_user.id,
|
||||
client_agent.current_user.blocks.user_blocks.clone(),
|
||||
client_agent.current_user.blocks.domain_blocks.clone(),
|
||||
);
|
||||
// Every time you get an event from that stream, send it through the pipe
|
||||
event_stream
|
||||
.for_each(move |_instant| {
|
||||
if let Ok(Async::Ready(Some(toot))) = client_agent.poll() {
|
||||
if blocked_domains.is_disjoint(&toot.get_originating_domain())
|
||||
&& blocked_users.is_disjoint(&toot.get_involved_users())
|
||||
{
|
||||
let txt = &toot.payload["content"];
|
||||
log::warn!("toot: {}\nTL: {}\nUser: {}({})", txt, tl, email, id);
|
||||
|
||||
tx.unbounded_send(warp::ws::Message::text(
|
||||
json!({ "event": toot.event_type.to_string(),
|
||||
"payload": &toot.payload.to_string() })
|
||||
.to_string(),
|
||||
))
|
||||
.expect("No send error");
|
||||
} else {
|
||||
log::info!("Blocked a message to {}", email);
|
||||
}
|
||||
if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
|
||||
tx.unbounded_send(warp::ws::Message::text(
|
||||
json!({ "event": msg.event(),
|
||||
"payload": msg.payload() })
|
||||
.to_string(),
|
||||
))
|
||||
.expect("No send error");
|
||||
};
|
||||
if time.elapsed() > time::Duration::from_secs(30) {
|
||||
tx.unbounded_send(warp::ws::Message::text("{}"))
|
||||
|
|
Loading…
Reference in New Issue