diff --git a/Cargo.toml b/Cargo.toml index cbdacc3..90a99c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "flodgatt" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" -version = "0.4.8" +version = "0.5.0" authors = ["Daniel Long Sockwell "] edition = "2018" diff --git a/src/main.rs b/src/main.rs index 843bed3..44177a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,7 +29,7 @@ fn main() { let client_agent_sse = ClientAgent::blank(redis_cfg); let client_agent_ws = client_agent_sse.clone_with_shared_receiver(); - let pg_pool = user::PostgresPool::new(postgres_cfg); + let pg_pool = user::PgPool::new(postgres_cfg); log::warn!("Streaming server initialized and ready to accept connections"); diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse.rs index 981a187..9dc5379 100644 --- a/src/parse_client_request/sse.rs +++ b/src/parse_client_request/sse.rs @@ -1,7 +1,7 @@ //! Filters for all the endpoints accessible for Server Sent Event updates use super::{ query::{self, Query}, - user::{PostgresPool, User}, + user::{PgPool, User}, }; use warp::{filters::BoxedFilter, path, Filter}; #[allow(dead_code)] @@ -39,7 +39,7 @@ macro_rules! parse_query { .boxed() }; } -pub fn extract_user_or_reject(pg_pool: PostgresPool) -> BoxedFilter<(User,)> { +pub fn extract_user_or_reject(pg_pool: PgPool) -> BoxedFilter<(User,)> { any_of!( parse_query!( path => "api" / "v1" / "streaming" / "user" / "notification" @@ -74,7 +74,7 @@ pub fn extract_user_or_reject(pg_pool: PostgresPool) -> BoxedFilter<(User,)> { #[cfg(test)] mod test { use super::*; - use crate::parse_client_request::user::{Filter, OauthScope, PostgresPool}; + use crate::parse_client_request::user::{Blocks, Filter, OauthScope, PgPool}; macro_rules! test_public_endpoint { ($name:ident { @@ -83,7 +83,7 @@ mod test { }) => { #[test] fn $name() { - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); let user = warp::test::request() .path($path) .filter(&extract_user_or_reject(mock_pg_pool)) @@ -101,7 +101,7 @@ mod test { #[test] fn $name() { let path = format!("{}?access_token=TEST_USER", $path); - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); $(let path = format!("{}&{}", path, $query);)* let user = warp::test::request() .path(&path) @@ -127,7 +127,7 @@ mod test { fn $name() { let path = format!("{}?access_token=INVALID", $path); $(let path = format!("{}&{}", path, $query);)* - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); warp::test::request() .path(&path) .filter(&extract_user_or_reject(mock_pg_pool)) @@ -146,7 +146,7 @@ mod test { let path = $path; $(let path = format!("{}?{}", path, $query);)* - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); warp::test::request() .path(&path) .header("Authorization", "Bearer: INVALID") @@ -165,7 +165,7 @@ mod test { fn $name() { let path = $path; $(let path = format!("{}?{}", path, $query);)* - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); warp::test::request() .path(&path) .filter(&extract_user_or_reject(mock_pg_pool)) @@ -180,7 +180,7 @@ mod test { target_timeline: "public:media".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -189,6 +189,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -198,7 +199,7 @@ mod test { target_timeline: "public:media".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -207,6 +208,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -216,7 +218,7 @@ mod test { target_timeline: "public:local".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -225,6 +227,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -234,7 +237,7 @@ mod test { target_timeline: "public:local:media".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -243,6 +246,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -252,7 +256,7 @@ mod test { target_timeline: "public:local:media".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -261,6 +265,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -270,7 +275,7 @@ mod test { target_timeline: "hashtag:a".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -279,6 +284,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -288,7 +294,7 @@ mod test { target_timeline: "hashtag:local:a".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -297,6 +303,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -316,6 +323,7 @@ mod test { lists: false, }, logged_in: true, + blocks: Blocks::default(), filter: Filter::NoFilter, }, }); @@ -334,6 +342,7 @@ mod test { lists: false, }, logged_in: true, + blocks: Blocks::default(), filter: Filter::Notification, }, }); @@ -352,6 +361,7 @@ mod test { lists: false, }, logged_in: true, + blocks: Blocks::default(), filter: Filter::NoFilter, }, }); @@ -372,6 +382,7 @@ mod test { lists: false, }, logged_in: true, + blocks: Blocks::default(), filter: Filter::NoFilter, }, }); @@ -448,7 +459,7 @@ mod test { #[test] #[should_panic(expected = "NotFound")] fn nonexistant_endpoint() { - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); warp::test::request() .path("/api/v1/streaming/DOES_NOT_EXIST") .filter(&extract_user_or_reject(mock_pg_pool)) diff --git a/src/parse_client_request/user/mock_postgres.rs b/src/parse_client_request/user/mock_postgres.rs index 6f4739d..d84d678 100644 --- a/src/parse_client_request/user/mock_postgres.rs +++ b/src/parse_client_request/user/mock_postgres.rs @@ -1,37 +1,40 @@ //! Mock Postgres connection (for use in unit testing) +use super::{OauthScope, User}; +use std::collections::HashSet; #[derive(Clone)] -pub struct PostgresPool; -impl PostgresPool { +pub struct PgPool; +impl PgPool { pub fn new() -> Self { Self } } -pub fn query_for_user_data( - access_token: &str, - _pg_pool: PostgresPool, -) -> (i64, String, Option>, Vec) { - let (user_id, email, lang, scopes) = if access_token == "TEST_USER" { - ( - 1, - "user@example.com".to_string(), - None, - vec![ - "read".to_string(), - "write".to_string(), - "follow".to_string(), - ], - ) - } else { - (-1, "".to_string(), None, Vec::new()) - }; - (user_id, email, lang, scopes) +pub fn select_user(access_token: &str, _pg_pool: PgPool) -> Result { + let mut user = User::default(); + if access_token == "TEST_USER" { + user.id = 1; + user.logged_in = true; + user.access_token = "TEST_USER".to_string(); + user.email = "user@example.com".to_string(); + user.scopes = OauthScope::from(vec![ + "read".to_string(), + "write".to_string(), + "follow".to_string(), + ]); + } else if access_token == "INVALID" { + return Err(warp::reject::custom("Error: Invalid access token")); + } + Ok(user) } -pub fn query_list_owner(list_id: i64, _pg_pool: PostgresPool) -> Option { - match list_id { - 1 => Some(1), - _ => None, - } +pub fn select_user_blocks(_id: i64, _pg_pool: PgPool) -> HashSet { + HashSet::new() +} +pub fn select_domain_blocks(_pg_pool: PgPool) -> HashSet { + HashSet::new() +} + +pub fn user_owns_list(user_id: i64, list_id: i64, _pg_pool: PgPool) -> bool { + user_id == list_id } diff --git a/src/parse_client_request/user/mod.rs b/src/parse_client_request/user/mod.rs index 6d2dedc..8cd2cd4 100644 --- a/src/parse_client_request/user/mod.rs +++ b/src/parse_client_request/user/mod.rs @@ -5,8 +5,9 @@ mod mock_postgres; use mock_postgres as postgres; #[cfg(not(test))] mod postgres; -pub use self::postgres::PostgresPool; +pub use self::postgres::PgPool; use super::query::Query; +use std::collections::HashSet; use warp::reject::Rejection; /// The filters that can be applied to toots after they come from Redis @@ -18,23 +19,10 @@ pub enum Filter { } impl Default for Filter { fn default() -> Self { - Filter::NoFilter + Filter::Language } } -/// The User (with data read from Postgres) -#[derive(Clone, Debug, Default, PartialEq)] -pub struct User { - pub target_timeline: String, - pub email: String, // We only use email for logging; we could cut it for performance - pub id: i64, - pub access_token: String, - pub scopes: OauthScope, - pub langs: Option>, - pub logged_in: bool, - pub filter: Filter, -} - #[derive(Clone, Debug, Default, PartialEq)] pub struct OauthScope { pub all: bool, @@ -58,51 +46,59 @@ impl From> for OauthScope { } } +#[derive(Clone, Default, Debug, PartialEq)] +pub struct Blocks { + pub domain_blocks: HashSet, + pub user_blocks: HashSet, +} + +/// The User (with data read from Postgres) +#[derive(Clone, Debug, PartialEq)] +pub struct User { + pub target_timeline: String, + pub email: String, // We only use email for logging; we could cut it for performance + pub access_token: String, // We only need this once (to send back with the WS reply). Cut? + pub id: i64, + pub scopes: OauthScope, + pub langs: Option>, + pub logged_in: bool, + pub filter: Filter, + pub blocks: Blocks, +} + +impl Default for User { + fn default() -> Self { + Self { + id: -1, + email: "".to_string(), + access_token: "".to_string(), + scopes: OauthScope::default(), + langs: None, + logged_in: false, + target_timeline: String::new(), + filter: Filter::default(), + blocks: Blocks::default(), + } + } +} + impl User { - pub fn from_query(q: Query, pg_pool: PostgresPool) -> Result { - let (id, access_token, email, scopes, langs, logged_in) = match q.access_token.clone() { - None => ( - -1, - "no access token".to_owned(), - "".to_string(), - OauthScope::default(), - None, - false, - ), - Some(token) => { - let (id, email, langs, scope_list) = - postgres::query_for_user_data(&token, pg_pool.clone()); - - if id == -1 { - return Err(warp::reject::custom("Error: Invalid access token")); - } - let scopes = OauthScope::from(scope_list); - (id, token, email, scopes, langs, true) - } - }; - let mut user = User { - id, - email, - target_timeline: "PLACEHOLDER".to_string(), - access_token, - scopes, - langs, - logged_in, - filter: Filter::Language, + pub fn from_query(q: Query, pool: PgPool) -> Result { + println!("Creating user..."); + let mut user: User = match q.access_token.clone() { + None => User::default(), + Some(token) => postgres::select_user(&token, pool.clone())?, }; - user = user.update_timeline_and_filter(q, pg_pool.clone())?; - + user = user.set_timeline_and_filter(q, pool.clone())?; + user.blocks.user_blocks = postgres::select_user_blocks(user.id, pool.clone()); + user.blocks.domain_blocks = postgres::select_domain_blocks(pool.clone()); + dbg!(&user); Ok(user) } - fn update_timeline_and_filter( - mut self, - q: Query, - pg_pool: PostgresPool, - ) -> Result { + fn set_timeline_and_filter(mut self, q: Query, pool: PgPool) -> Result { let read_scope = self.scopes.clone(); - let timeline = match q.stream.as_ref() { // Public endpoints: tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl), @@ -110,7 +106,7 @@ impl User { tl @ "public" | tl @ "public:local" => tl.to_string(), // Hashtag endpoints: tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag), - // Private endpoints: User + // Private endpoints: User: "user" if self.logged_in && (read_scope.all || read_scope.statuses) => { self.filter = Filter::NoFilter; format!("{}", self.id) @@ -120,7 +116,7 @@ impl User { format!("{}", self.id) } // List endpoint: - "list" if self.owns_list(q.list, pg_pool) && (read_scope.all || read_scope.lists) => { + "list" if self.owns_list(q.list, pool) && (read_scope.all || read_scope.lists) => { self.filter = Filter::NoFilter; format!("list:{}", q.list) } @@ -142,11 +138,7 @@ impl User { }) } - /// Determine whether the User is authorised for a specified list - pub fn owns_list(&self, list: i64, pg_pool: PostgresPool) -> bool { - match postgres::query_list_owner(list, pg_pool) { - Some(i) if i == self.id => true, - _ => false, - } + fn owns_list(&self, list: i64, pool: PgPool) -> bool { + postgres::user_owns_list(self.id, list, pool) } } diff --git a/src/parse_client_request/user/postgres.rs b/src/parse_client_request/user/postgres.rs index 6c95ed7..4c87d9f 100644 --- a/src/parse_client_request/user/postgres.rs +++ b/src/parse_client_request/user/postgres.rs @@ -1,11 +1,16 @@ //! Postgres queries -use crate::config; +use crate::{ + config, + parse_client_request::user::{OauthScope, User}, +}; use ::postgres; use r2d2_postgres::PostgresConnectionManager; +use std::collections::HashSet; +use warp::reject::Rejection; #[derive(Clone)] -pub struct PostgresPool(pub r2d2::Pool>); -impl PostgresPool { +pub struct PgPool(pub r2d2::Pool>); +impl PgPool { pub fn new(pg_cfg: config::PostgresConfig) -> Self { let mut cfg = postgres::Config::new(); cfg.user(&pg_cfg.user) @@ -25,12 +30,12 @@ impl PostgresPool { } } -pub fn query_for_user_data( - access_token: &str, - pg_pool: PostgresPool, -) -> (i64, String, Option>, Vec) { +/// Build a user based on the result of querying Postgres with the access token +/// +/// This does _not_ set the timeline, filter, or blocks fields. Use the various `User` +/// methods to do so. In general, this function shouldn't be needed outside `User`. +pub fn select_user(access_token: &str, pg_pool: PgPool) -> Result { let mut conn = pg_pool.0.get().unwrap(); - let query_result = conn .query( " @@ -45,41 +50,70 @@ LIMIT 1", &[&access_token.to_owned()], ) .expect("Hard-coded query will return Some([0 or more rows])"); - if !query_result.is_empty() { + if query_result.is_empty() { + Err(warp::reject::custom("Error: Invalid access token")) + } else { let only_row: &postgres::Row = query_result.get(0).unwrap(); - let id: i64 = only_row.get(1); - let email: String = only_row.get(2); - let scopes = only_row + let scope_vec: Vec = only_row .get::<_, String>(4) .split(' ') .map(|s| s.to_owned()) .collect(); - let langs: Option> = only_row.get(3); - (id, email, langs, scopes) - } else { - (-1, "".to_string(), None, Vec::new()) + Ok(User { + id: only_row.get(1), + access_token: access_token.to_string(), + email: only_row.get(2), + logged_in: true, + scopes: OauthScope::from(scope_vec), + langs: only_row.get(3), + ..User::default() + }) } } -#[cfg(test)] -pub fn query_for_user_data(access_token: &str) -> (i64, Option>, Vec) { - let (user_id, lang, scopes) = if access_token == "TEST_USER" { - ( - 1, - None, - vec![ - "read".to_string(), - "write".to_string(), - "follow".to_string(), - ], +/// Query Postgres for everyone the user has blocked or muted +/// +/// **NOTE**: because we check this when the user connects, it will not include any blocks +/// the user adds until they refresh/reconnect. +pub fn select_user_blocks(user_id: i64, pg_pool: PgPool) -> HashSet { + pg_pool + .0 + .get() + .unwrap() + .query( + " +SELECT target_account_id + FROM blocks + WHERE account_id = $1 +UNION SELECT target_account_id + FROM mutes + WHERE account_id = $1", + &[&user_id], ) - } else { - (-1, None, Vec::new()) - }; - (user_id, lang, scopes) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| row.get(0)) + .collect() } -pub fn query_list_owner(list_id: i64, pg_pool: PostgresPool) -> Option { +/// Query Postgres for all current domain blocks +/// +/// **NOTE**: because we check this when the user connects, it will not include any blocks +/// the user adds until they refresh/reconnect. +pub fn select_domain_blocks(pg_pool: PgPool) -> HashSet { + pg_pool + .0 + .get() + .unwrap() + .query("SELECT domain FROM account_domain_blocks", &[]) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| row.get(0)) + .collect() +} + +/// Test whether a user owns a list +pub fn user_owns_list(user_id: i64, list_id: i64, pg_pool: PgPool) -> bool { let mut conn = pg_pool.0.get().unwrap(); // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn @@ -92,9 +126,12 @@ LIMIT 1", &[&list_id], ) .expect("Hard-coded query will return Some([0 or more rows])"); - if rows.is_empty() { - None - } else { - Some(rows.get(0).unwrap().get(1)) + + match rows.get(0) { + None => false, + Some(row) => { + let list_owner_id: i64 = row.get(1); + list_owner_id == user_id + } } } diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs index 777478c..ef71189 100644 --- a/src/parse_client_request/ws.rs +++ b/src/parse_client_request/ws.rs @@ -1,7 +1,7 @@ //! Filters for the WebSocket endpoint use super::{ query::{self, Query}, - user::{PostgresPool, User}, + user::{PgPool, User}, }; use warp::{filters::BoxedFilter, path, Filter}; @@ -32,7 +32,7 @@ fn parse_query() -> BoxedFilter<(Query,)> { .boxed() } -pub fn extract_user_or_reject(pg_pool: PostgresPool) -> BoxedFilter<(User,)> { +pub fn extract_user_or_reject(pg_pool: PgPool) -> BoxedFilter<(User,)> { parse_query() .and(query::OptionalAccessToken::from_ws_header()) .and_then(Query::update_access_token) @@ -43,7 +43,7 @@ pub fn extract_user_or_reject(pg_pool: PostgresPool) -> BoxedFilter<(User,)> { #[cfg(test)] mod test { use super::*; - use crate::parse_client_request::user::{Filter, OauthScope}; + use crate::parse_client_request::user::{Blocks, Filter, OauthScope}; macro_rules! test_public_endpoint { ($name:ident { @@ -52,7 +52,7 @@ mod test { }) => { #[test] fn $name() { - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); let user = warp::test::request() .path($path) .header("connection", "upgrade") @@ -72,7 +72,7 @@ mod test { }) => { #[test] fn $name() { - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); let path = format!("{}&access_token=TEST_USER", $path); let user = warp::test::request() .path(&path) @@ -96,7 +96,7 @@ mod test { fn $name() { let path = format!("{}&access_token=INVALID", $path); - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); warp::test::request() .path(&path) .filter(&extract_user_or_reject(mock_pg_pool)) @@ -112,7 +112,7 @@ mod test { #[should_panic(expected = "Error: Missing access token")] fn $name() { let path = $path; - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); warp::test::request() .path(&path) .filter(&extract_user_or_reject(mock_pg_pool)) @@ -127,7 +127,7 @@ mod test { target_timeline: "public:media".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -136,6 +136,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -145,7 +146,7 @@ mod test { target_timeline: "public:local".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -154,6 +155,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -163,7 +165,7 @@ mod test { target_timeline: "public:local:media".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -172,6 +174,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -181,7 +184,7 @@ mod test { target_timeline: "hashtag:a".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -190,6 +193,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -199,7 +203,7 @@ mod test { target_timeline: "hashtag:local:a".to_string(), id: -1, email: "".to_string(), - access_token: "no access token".to_string(), + access_token: "".to_string(), langs: None, scopes: OauthScope { all: false, @@ -208,6 +212,7 @@ mod test { lists: false, }, logged_in: false, + blocks: Blocks::default(), filter: Filter::Language, }, }); @@ -227,6 +232,7 @@ mod test { lists: false, }, logged_in: true, + blocks: Blocks::default(), filter: Filter::NoFilter, }, }); @@ -245,6 +251,7 @@ mod test { lists: false, }, logged_in: true, + blocks: Blocks::default(), filter: Filter::Notification, }, }); @@ -263,6 +270,7 @@ mod test { lists: false, }, logged_in: true, + blocks: Blocks::default(), filter: Filter::NoFilter, }, }); @@ -281,6 +289,7 @@ mod test { lists: false, }, logged_in: true, + blocks: Blocks::default(), filter: Filter::NoFilter, }, }); @@ -325,7 +334,7 @@ mod test { #[test] #[should_panic(expected = "NotFound")] fn nonexistant_endpoint() { - let mock_pg_pool = PostgresPool::new(); + let mock_pg_pool = PgPool::new(); warp::test::request() .path("/api/v1/streaming/DOES_NOT_EXIST") .header("connection", "upgrade") diff --git a/src/redis_to_client_stream/client_agent.rs b/src/redis_to_client_stream/client_agent.rs index 00aab56..dae8b6a 100644 --- a/src/redis_to_client_stream/client_agent.rs +++ b/src/redis_to_client_stream/client_agent.rs @@ -19,7 +19,7 @@ use super::receiver::Receiver; use crate::{config, parse_client_request::user::User}; use futures::{Async, Poll}; use serde_json::Value; -use std::sync; +use std::{collections::HashSet, sync}; use tokio::io::Error; use uuid::Uuid; @@ -110,6 +110,7 @@ impl futures::stream::Stream for ClientAgent { } /// The message to send to the client (which might not literally be a toot in some cases). +#[derive(Debug, Clone)] pub struct Toot { pub category: String, pub payload: Value, @@ -118,7 +119,7 @@ pub struct Toot { impl Toot { /// Construct a `Toot` from well-formed JSON. - fn from_json(value: Value) -> Self { + pub fn from_json(value: Value) -> Self { let category = value["event"].as_str().expect("Redis string").to_owned(); let language = if category == "update" { Some(value["payload"]["language"].to_string()) @@ -133,8 +134,44 @@ impl Toot { } } + pub fn get_originating_domain(&self) -> HashSet { + let api = "originating Invariant Violation: JSON value does not conform to Mastdon API"; + let mut originating_domain = HashSet::new(); + originating_domain.insert( + self.payload["account"]["acct"] + .as_str() + .expect(&api) + .split("@") + .nth(1) + .expect(&api) + .to_string(), + ); + originating_domain + } + + pub fn get_involved_users(&self) -> HashSet { + let mut involved_users: HashSet = HashSet::new(); + let msg = self.payload.clone(); + + let api = "Invariant Violation: JSON value does not conform to Mastdon API"; + involved_users.insert(msg["account"]["id"].str_to_i64().expect(&api)); + if let Some(mentions) = msg["mentions"].as_array() { + for mention in mentions { + involved_users.insert(mention["id"].str_to_i64().expect(&api)); + } + } + if let Some(replied_to_account) = msg["in_reply_to_account_id"].as_str() { + involved_users.insert(replied_to_account.parse().expect(&api)); + } + + if let Some(reblog) = msg["reblog"].as_object() { + involved_users.insert(reblog["account"]["id"].str_to_i64().expect(&api)); + } + involved_users + } + /// Filter out any `Toot`'s that fail the provided filter. - fn filter(self, user: &User) -> Result>, Error> { + pub fn filter(self, user: &User) -> Result>, Error> { let toot = self; let category = toot.category.clone(); @@ -161,3 +198,17 @@ impl Toot { } } } + +trait ConvertValue { + fn str_to_i64(&self) -> Result>; +} + +impl ConvertValue for Value { + fn str_to_i64(&self) -> Result> { + Ok(self + .as_str() + .ok_or(format!("{} is not a string", &self))? + .parse() + .map_err(|_| "Could not parse str")?) + } +} diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs index 95b0a64..ee5b0c6 100644 --- a/src/redis_to_client_stream/mod.rs +++ b/src/redis_to_client_stream/mod.rs @@ -82,29 +82,36 @@ pub fn send_updates_to_ws( let mut time = time::Instant::now(); - let (tl, email, id) = ( + let (tl, email, id, blocked_users, blocked_domains) = ( client_agent.current_user.target_timeline.clone(), client_agent.current_user.email.clone(), client_agent.current_user.id, + client_agent.current_user.blocks.user_blocks.clone(), + client_agent.current_user.blocks.domain_blocks.clone(), ); // Every time you get an event from that stream, send it through the pipe event_stream .for_each(move |_instant| { if let Ok(Async::Ready(Some(toot))) = client_agent.poll() { - let txt = &toot.payload["content"]; - log::warn!("toot: {}\n in TL: {}\nuser: {}({})", txt, tl, email, id); + if blocked_domains.is_disjoint(&toot.get_originating_domain()) + && blocked_users.is_disjoint(&toot.get_involved_users()) + { + let txt = &toot.payload["content"]; + log::warn!("toot: {}\nTL: {}\nUser: {}({})", txt, tl, email, id); - let msg = warp::ws::Message::text( - json!({"event": toot.category, - "payload": toot.payload.to_string()}) - .to_string(), - ); - - tx.unbounded_send(msg).expect("No send error"); + tx.unbounded_send(warp::ws::Message::text( + json!({ "event": toot.category, + "payload": &toot.payload.to_string() }) + .to_string(), + )) + .expect("No send error"); + } else { + log::info!("Blocked a message to {}", email); + } }; if time.elapsed() > time::Duration::from_secs(30) { - let msg = warp::ws::Message::text("{}"); - tx.unbounded_send(msg).expect("Can ping"); + tx.unbounded_send(warp::ws::Message::text("{}")) + .expect("Can ping"); time = time::Instant::now(); } Ok(())