diff --git a/src/config.rs b/src/config.rs index 96e32bd..891a9b2 100644 --- a/src/config.rs +++ b/src/config.rs @@ -14,6 +14,7 @@ mod redis_cfg; mod redis_cfg_types; pub fn merge_dotenv() -> Result<(), err::FatalErr> { + // TODO -- should this allow the user to run in a dir without a `.env` file? dotenv::from_filename(match env::var("ENV").ok().as_deref() { Some("production") => ".env.production", Some("development") | None => ".env", diff --git a/src/config/deployment_cfg_types.rs b/src/config/deployment_cfg_types.rs index 97d018b..aaa7996 100644 --- a/src/config/deployment_cfg_types.rs +++ b/src/config/deployment_cfg_types.rs @@ -92,7 +92,7 @@ impl fmt::Debug for Cors<'_> { } } -#[derive(EnumString, EnumVariantNames, Debug)] +#[derive(EnumString, EnumVariantNames, Debug, Clone)] #[strum(serialize_all = "snake_case")] pub enum LogLevelInner { Trace, @@ -102,7 +102,7 @@ pub enum LogLevelInner { Error, } -#[derive(EnumString, EnumVariantNames, Debug)] +#[derive(EnumString, EnumVariantNames, Debug, Clone)] #[strum(serialize_all = "snake_case")] pub enum EnvInner { Production, diff --git a/src/config/environmental_variables.rs b/src/config/environmental_variables.rs index 33028b7..6df89f0 100644 --- a/src/config/environmental_variables.rs +++ b/src/config/environmental_variables.rs @@ -1,6 +1,7 @@ use hashbrown::HashMap; use std::fmt; +#[derive(Debug)] pub struct EnvVar(pub HashMap); impl std::ops::Deref for EnvVar { type Target = HashMap; @@ -94,6 +95,7 @@ macro_rules! from_env_var { let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr); let from_str = |$arg:ident| $body:expr; ) => { + #[derive(Clone)] pub struct $name(pub $type); impl std::fmt::Debug for $name { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/src/config/postgres_cfg.rs b/src/config/postgres_cfg.rs index 69779ce..ef26f65 100644 --- a/src/config/postgres_cfg.rs +++ b/src/config/postgres_cfg.rs @@ -2,7 +2,7 @@ use super::{postgres_cfg_types::*, EnvVar}; use url::Url; use urlencoding; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Postgres { pub user: PgUser, pub host: PgHost, diff --git a/src/config/postgres_cfg_types.rs b/src/config/postgres_cfg_types.rs index 7551d1b..e6195dc 100644 --- a/src/config/postgres_cfg_types.rs +++ b/src/config/postgres_cfg_types.rs @@ -49,7 +49,7 @@ from_env_var!( let from_str = |s| PgSslInner::from_str(s).ok(); ); -#[derive(EnumString, EnumVariantNames, Debug)] +#[derive(EnumString, EnumVariantNames, Debug, Clone)] #[strum(serialize_all = "snake_case")] pub enum PgSslInner { Prefer, diff --git a/src/main.rs b/src/main.rs index 0a1e3c3..d0ebf0f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use flodgatt::config; use flodgatt::err::FatalErr; use flodgatt::messages::Event; -use flodgatt::request::{PgPool, Subscription, Timeline}; +use flodgatt::request::{self, Subscription, Timeline}; use flodgatt::response::redis; use flodgatt::response::stream; @@ -27,14 +27,16 @@ fn main() -> Result<(), FatalErr> { let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); - let shared_pg_conn = PgPool::new(postgres_cfg, *cfg.whitelist_mode); + let request_handler = request::Handler::new(postgres_cfg, *cfg.whitelist_mode); let poll_freq = *redis_cfg.polling_interval; let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc(); // Server Sent Events let sse_manager = shared_manager.clone(); let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone()); - let sse = Subscription::from_sse_request(shared_pg_conn.clone()) + + let sse = request_handler + .parse_sse_request() .and(warp::sse()) .map( move |subscription: Subscription, client_conn: warp::sse::Sse| { @@ -56,7 +58,8 @@ fn main() -> Result<(), FatalErr> { // WebSocket let ws_manager = shared_manager.clone(); - let ws = Subscription::from_ws_request(shared_pg_conn) + let ws = request_handler + .parse_ws_request() .and(warp::ws::ws2()) .map(move |subscription: Subscription, ws: Ws2| { log::info!("Incoming websocket request for {:?}", subscription.timeline); diff --git a/src/request.rs b/src/request.rs index 9b5a28a..7e0e963 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,23 +1,53 @@ //! Parse the client request and return a Subscription mod postgres; mod query; +pub mod timeline; mod subscription; pub use self::postgres::PgPool; // TODO consider whether we can remove `Stream` from public API -pub use subscription::{Blocks, Stream, Subscription, Timeline}; -pub use subscription::{Content, Reach}; +pub use subscription::{Blocks, Subscription}; +pub use timeline::{Content, Reach, Stream, Timeline}; use self::query::Query; use crate::config; -use warp::{filters::BoxedFilter, path, reject::Rejection, Filter}; +use warp::{filters::BoxedFilter, path, Filter}; #[cfg(test)] mod sse_test; #[cfg(test)] mod ws_test; +/// Helper macro to match on the first of any of the provided filters +macro_rules! any_of { + ($filter:expr, $($other_filter:expr),*) => { + $filter$(.or($other_filter).unify())*.boxed() + }; +} +macro_rules! parse_sse_query { + (path => $start:tt $(/ $next:tt)* + endpoint => $endpoint:expr) => { + path!($start $(/ $next)*) + .and(query::Auth::to_filter()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .map(|auth: query::Auth, media: query::Media, hashtag: query::Hashtag, list: query::List| { + Query { + access_token: auth.access_token, + stream: $endpoint.to_string(), + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + } + }, + ) + .boxed() + }; +} + +#[derive(Debug, Clone)] pub struct Handler { pg_conn: PgPool, } @@ -29,14 +59,47 @@ impl Handler { } } - pub fn from_ws_request(&self) -> BoxedFilter<(Subscription,)> { + pub fn parse_ws_request(&self) -> BoxedFilter<(Subscription,)> { let pg_conn = self.pg_conn.clone(); parse_ws_query() .and(query::OptionalAccessToken::from_ws_header()) .and_then(Query::update_access_token) - .and_then(move |q| Subscription::from_query(q, pg_conn.clone())) + .and_then(move |q| Subscription::query_postgres(q, pg_conn.clone())) .boxed() } + + pub fn parse_sse_request(&self) -> BoxedFilter<(Subscription,)> { + let pg_conn = self.pg_conn.clone(); + any_of!( + parse_sse_query!( + path => "api" / "v1" / "streaming" / "user" / "notification" + endpoint => "user:notification" ), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "user" + endpoint => "user"), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "public" / "local" + endpoint => "public:local"), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "public" + endpoint => "public"), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "direct" + endpoint => "direct"), + parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local" + endpoint => "hashtag:local"), + parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" + endpoint => "hashtag"), + parse_sse_query!(path => "api" / "v1" / "streaming" / "list" + endpoint => "list") + ) + // because SSE requests place their `access_token` in the header instead of in a query + // parameter, we need to update our Query if the header has a token + .and(query::OptionalAccessToken::from_sse_header()) + .and_then(Query::update_access_token) + .and_then(move |q| Subscription::query_postgres(q, pg_conn.clone())) + .boxed() + } } fn parse_ws_query() -> BoxedFilter<(Query,)> { diff --git a/src/request/postgres.rs b/src/request/postgres.rs index c94b26b..1dd2ce5 100644 --- a/src/request/postgres.rs +++ b/src/request/postgres.rs @@ -1,12 +1,12 @@ //! Postgres queries -use crate::{ - config, - messages::Id, - request::subscription::{Scope, UserData}, -}; +use crate::config; +use crate::messages::Id; +use crate::request::timeline::{Scope, UserData}; + use ::postgres; use hashbrown::HashSet; use r2d2_postgres::PostgresConnectionManager; +use std::convert::TryFrom; use warp::reject::Rejection; #[derive(Clone, Debug)] @@ -14,6 +14,7 @@ pub struct PgPool { pub conn: r2d2::Pool>, whitelist_mode: bool, } + impl PgPool { pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self { let mut cfg = postgres::Config::new(); @@ -40,15 +41,11 @@ impl PgPool { let mut conn = self.conn.get().unwrap(); if let Some(token) = token { let query_rows = conn - .query( - " + .query(" SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes -FROM -oauth_access_tokens -INNER JOIN users ON -oauth_access_tokens.resource_owner_id = users.id -WHERE oauth_access_tokens.token = $1 -AND oauth_access_tokens.revoked_at IS NULL + FROM oauth_access_tokens +INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id + WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1", &[&token.to_owned()], ) @@ -57,29 +54,20 @@ LIMIT 1", let id = Id(result_columns.get(1)); let allowed_langs = result_columns .try_get::<_, Vec<_>>(2) - .unwrap_or_else(|_| Vec::new()) + .unwrap_or_default() .into_iter() .collect(); + let mut scopes: HashSet = result_columns .get::<_, String>(3) .split(' ') - .filter_map(|scope| match scope { - "read" => Some(Scope::Read), - "read:statuses" => Some(Scope::Statuses), - "read:notifications" => Some(Scope::Notifications), - "read:lists" => Some(Scope::Lists), - "write" | "follow" => None, // ignore write scopes - unexpected => { - log::warn!("Ignoring unknown scope `{}`", unexpected); - None - } - }) + .filter_map(|scope| Scope::try_from(scope).ok()) .collect(); // We don't need to separately track read auth - it's just all three others - if scopes.remove(&Scope::Read) { - scopes.insert(Scope::Statuses); - scopes.insert(Scope::Notifications); - scopes.insert(Scope::Lists); + if scopes.contains(&Scope::Read) { + scopes = vec![Scope::Statuses, Scope::Notifications, Scope::Lists] + .into_iter() + .collect() } Ok(UserData { @@ -98,19 +86,10 @@ LIMIT 1", } pub fn select_hashtag_id(self, tag_name: &str) -> Result { - let mut conn = self.conn.get().unwrap(); - let rows = &conn - .query( - " -SELECT id -FROM tags -WHERE name = $1 -LIMIT 1", - &[&tag_name], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - - rows.get(0) + let mut conn = self.conn.get().expect("TODO"); + conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name]) + .expect("Hard-coded query will return Some([0 or more rows])") + .get(0) .map(|row| row.get(0)) .ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist.")) } @@ -120,43 +99,31 @@ LIMIT 1", /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocked_users(self, user_id: Id) -> HashSet { - self.conn - .get() - .unwrap() - .query( - " -SELECT target_account_id - FROM blocks - WHERE account_id = $1 -UNION SELECT target_account_id - FROM mutes - WHERE account_id = $1", - &[&*user_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])") - .iter() - .map(|row| Id(row.get(0))) - .collect() + let mut conn = self.conn.get().expect("TODO"); + conn.query( + "SELECT target_account_id FROM blocks WHERE account_id = $1 + UNION SELECT target_account_id FROM mutes WHERE account_id = $1", + &[&*user_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| Id(row.get(0))) + .collect() } /// Query Postgres for everyone who has blocked the user /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocking_users(self, user_id: Id) -> HashSet { - self.conn - .get() - .unwrap() - .query( - " -SELECT account_id - FROM blocks - WHERE target_account_id = $1", - &[&*user_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])") - .iter() - .map(|row| Id(row.get(0))) - .collect() + let mut conn = self.conn.get().expect("TODO"); + conn.query( + "SELECT account_id FROM blocks WHERE target_account_id = $1", + &[&*user_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| Id(row.get(0))) + .collect() } /// Query Postgres for all current domain blocks @@ -164,37 +131,27 @@ SELECT account_id /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocked_domains(self, user_id: Id) -> HashSet { - self.conn - .get() - .unwrap() - .query( - "SELECT domain FROM account_domain_blocks WHERE account_id = $1", - &[&*user_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])") - .iter() - .map(|row| row.get(0)) - .collect() + let mut conn = self.conn.get().expect("TODO"); + conn.query( + "SELECT domain FROM account_domain_blocks WHERE account_id = $1", + &[&*user_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| row.get(0)) + .collect() } /// Test whether a user owns a list pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool { - let mut conn = self.conn.get().unwrap(); + let mut conn = self.conn.get().expect("TODO"); // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn .query( - " -SELECT id, account_id -FROM lists -WHERE id = $1 -LIMIT 1", + "SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", &[&list_id], ) .expect("Hard-coded query will return Some([0 or more rows])"); - - match rows.get(0) { - None => false, - Some(row) => Id(row.get(1)) == user_id, - } + rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id) } } diff --git a/src/request/subscription.rs b/src/request/subscription.rs index 7fe5810..447faeb 100644 --- a/src/request/subscription.rs +++ b/src/request/subscription.rs @@ -6,47 +6,13 @@ // #[cfg(not(test))] use super::postgres::PgPool; -use super::query; use super::query::Query; -use crate::err::TimelineErr; - +use super::{Content, Reach, Stream, Timeline}; use crate::messages::Id; use hashbrown::HashSet; -use lru::LruCache; -use warp::{filters::BoxedFilter, path, reject::Rejection, Filter}; -/// Helper macro to match on the first of any of the provided filters -macro_rules! any_of { - ($filter:expr, $($other_filter:expr),*) => { - $filter$(.or($other_filter).unify())*.boxed() - }; -} -macro_rules! parse_sse_query { - (path => $start:tt $(/ $next:tt)* - endpoint => $endpoint:expr) => { - path!($start $(/ $next)*) - .and(query::Auth::to_filter()) - .and(query::Media::to_filter()) - .and(query::Hashtag::to_filter()) - .and(query::List::to_filter()) - .map( - |auth: query::Auth, - media: query::Media, - hashtag: query::Hashtag, - list: query::List| { - Query { - access_token: auth.access_token, - stream: $endpoint.to_string(), - media: media.is_truthy(), - hashtag: hashtag.tag, - list: list.list, - } - }, - ) - .boxed() - }; -} +use warp::reject::Rejection; #[derive(Clone, Debug, PartialEq)] pub struct Subscription { @@ -77,49 +43,24 @@ impl Default for Subscription { } impl Subscription { - pub fn from_ws_request(pg_pool: PgPool) -> BoxedFilter<(Subscription,)> { - parse_ws_query() - .and(query::OptionalAccessToken::from_ws_header()) - .and_then(Query::update_access_token) - .and_then(move |q| Subscription::from_query(q, pg_pool.clone())) - .boxed() - } - - pub fn from_sse_request(pg_pool: PgPool) -> BoxedFilter<(Subscription,)> { - any_of!( - parse_sse_query!( - path => "api" / "v1" / "streaming" / "user" / "notification" - endpoint => "user:notification" ), - parse_sse_query!( - path => "api" / "v1" / "streaming" / "user" - endpoint => "user"), - parse_sse_query!( - path => "api" / "v1" / "streaming" / "public" / "local" - endpoint => "public:local"), - parse_sse_query!( - path => "api" / "v1" / "streaming" / "public" - endpoint => "public"), - parse_sse_query!( - path => "api" / "v1" / "streaming" / "direct" - endpoint => "direct"), - parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local" - endpoint => "hashtag:local"), - parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" - endpoint => "hashtag"), - parse_sse_query!(path => "api" / "v1" / "streaming" / "list" - endpoint => "list") - ) - // because SSE requests place their `access_token` in the header instead of in a query - // parameter, we need to update our Query if the header has a token - .and(query::OptionalAccessToken::from_sse_header()) - .and_then(Query::update_access_token) - .and_then(move |q| Subscription::from_query(q, pg_pool.clone())) - .boxed() - } - - pub(super) fn from_query(q: Query, pool: PgPool) -> Result { + pub(super) fn query_postgres(q: Query, pool: PgPool) -> Result { let user = pool.clone().select_user(&q.access_token)?; - let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?; + let timeline = { + let tl = Timeline::from_query_and_user(&q, &user)?; + let pool = pool.clone(); + use Stream::*; + match tl { + Timeline(Hashtag(_), reach, stream) => { + let tag = pool.select_hashtag_id(&q.hashtag)?; + Timeline(Hashtag(tag), reach, stream) + } + Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => { + Err(warp::reject::custom("Error: Missing access token"))? + } + other_tl => other_tl, + } + }; + let hashtag_name = match timeline { Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag), _non_hashtag_timeline => None, @@ -138,179 +79,3 @@ impl Subscription { }) } } - -fn parse_ws_query() -> BoxedFilter<(Query,)> { - path!("api" / "v1" / "streaming") - .and(path::end()) - .and(warp::query()) - .and(query::Auth::to_filter()) - .and(query::Media::to_filter()) - .and(query::Hashtag::to_filter()) - .and(query::List::to_filter()) - .map( - |stream: query::Stream, - auth: query::Auth, - media: query::Media, - hashtag: query::Hashtag, - list: query::List| { - Query { - access_token: auth.access_token, - stream: stream.stream, - media: media.is_truthy(), - hashtag: hashtag.tag, - list: list.list, - } - }, - ) - .boxed() -} - -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub struct Timeline(pub Stream, pub Reach, pub Content); - -impl Timeline { - pub fn empty() -> Self { - use {Content::*, Reach::*, Stream::*}; - Self(Unset, Local, Notification) - } - - pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result { - use {Content::*, Reach::*, Stream::*}; - Ok(match self { - Timeline(Public, Federated, All) => "timeline:public".into(), - Timeline(Public, Local, All) => "timeline:public:local".into(), - Timeline(Public, Federated, Media) => "timeline:public:media".into(), - Timeline(Public, Local, Media) => "timeline:public:local:media".into(), - // TODO -- would `.push_str` be faster here? - Timeline(Hashtag(_id), Federated, All) => format!( - "timeline:hashtag:{}", - hashtag.ok_or_else(|| TimelineErr::MissingHashtag)? - ), - Timeline(Hashtag(_id), Local, All) => format!( - "timeline:hashtag:{}:local", - hashtag.ok_or_else(|| TimelineErr::MissingHashtag)? - ), - Timeline(User(id), Federated, All) => format!("timeline:{}", id), - Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id), - Timeline(List(id), Federated, All) => format!("timeline:list:{}", id), - Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id), - Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?, - }) - } - - pub fn from_redis_text( - timeline: &str, - cache: &mut LruCache, - ) -> Result { - let mut id_from_tag = |tag: &str| match cache.get(&tag.to_string()) { - Some(id) => Ok(*id), - None => Err(TimelineErr::InvalidInput), // TODO more specific - }; - - use {Content::*, Reach::*, Stream::*}; - Ok(match &timeline.split(':').collect::>()[..] { - ["public"] => Timeline(Public, Federated, All), - ["public", "local"] => Timeline(Public, Local, All), - ["public", "media"] => Timeline(Public, Federated, Media), - ["public", "local", "media"] => Timeline(Public, Local, Media), - ["hashtag", tag] => Timeline(Hashtag(id_from_tag(tag)?), Federated, All), - ["hashtag", tag, "local"] => Timeline(Hashtag(id_from_tag(tag)?), Local, All), - [id] => Timeline(User(id.parse()?), Federated, All), - [id, "notification"] => Timeline(User(id.parse()?), Federated, Notification), - ["list", id] => Timeline(List(id.parse()?), Federated, All), - ["direct", id] => Timeline(Direct(id.parse()?), Federated, All), - // Other endpoints don't exist: - [..] => Err(TimelineErr::InvalidInput)?, - }) - } - - fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result { - use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*}; - let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag); - let user_owns_list = || pool.clone().user_owns_list(user.id, q.list); - - Ok(match q.stream.as_ref() { - "public" => match q.media { - true => Timeline(Public, Federated, Media), - false => Timeline(Public, Federated, All), - }, - "public:local" => match q.media { - true => Timeline(Public, Local, Media), - false => Timeline(Public, Local, All), - }, - "public:media" => Timeline(Public, Federated, Media), - "public:local:media" => Timeline(Public, Local, Media), - - "hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All), - "hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All), - "user" => match user.scopes.contains(&Statuses) { - true => Timeline(User(user.id), Federated, All), - false => Err(custom("Error: Missing access token"))?, - }, - "user:notification" => match user.scopes.contains(&Statuses) { - true => Timeline(User(user.id), Federated, Notification), - false => Err(custom("Error: Missing access token"))?, - }, - "list" => match user.scopes.contains(&Lists) && user_owns_list() { - true => Timeline(List(q.list), Federated, All), - false => Err(warp::reject::custom("Error: Missing access token"))?, - }, - "direct" => match user.scopes.contains(&Statuses) { - true => Timeline(Direct(*user.id), Federated, All), - false => Err(custom("Error: Missing access token"))?, - }, - other => { - log::warn!("Request for nonexistent endpoint: `{}`", other); - Err(custom("Error: Nonexistent endpoint"))? - } - }) - } -} - -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Stream { - User(Id), - // TODO consider whether List, Direct, and Hashtag should all be `id::Id`s - List(i64), - Direct(i64), - Hashtag(i64), - Public, - Unset, -} - -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Reach { - Local, - Federated, -} - -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Content { - All, - Media, - Notification, -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum Scope { - Read, - Statuses, - Notifications, - Lists, -} - -pub struct UserData { - pub id: Id, - pub allowed_langs: HashSet, - pub scopes: HashSet, -} - -impl UserData { - pub fn public() -> Self { - Self { - id: Id(-1), - allowed_langs: HashSet::new(), - scopes: HashSet::new(), - } - } -} diff --git a/src/request/timeline.rs b/src/request/timeline.rs new file mode 100644 index 0000000..883a27b --- /dev/null +++ b/src/request/timeline.rs @@ -0,0 +1,174 @@ +use super::query::Query; +use crate::err::TimelineErr; +use crate::messages::Id; + +use hashbrown::HashSet; +use lru::LruCache; +use std::convert::TryFrom; +use warp::reject::Rejection; + +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub struct Timeline(pub Stream, pub Reach, pub Content); + +impl Timeline { + pub fn empty() -> Self { + use {Content::*, Reach::*, Stream::*}; + Self(Unset, Local, Notification) + } + + pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result { + use {Content::*, Reach::*, Stream::*}; + Ok(match self { + Timeline(Public, Federated, All) => "timeline:public".into(), + Timeline(Public, Local, All) => "timeline:public:local".into(), + Timeline(Public, Federated, Media) => "timeline:public:media".into(), + Timeline(Public, Local, Media) => "timeline:public:local:media".into(), + // TODO -- would `.push_str` be faster here? + Timeline(Hashtag(_id), Federated, All) => format!( + "timeline:hashtag:{}", + hashtag.ok_or_else(|| TimelineErr::MissingHashtag)? + ), + Timeline(Hashtag(_id), Local, All) => format!( + "timeline:hashtag:{}:local", + hashtag.ok_or_else(|| TimelineErr::MissingHashtag)? + ), + Timeline(User(id), Federated, All) => format!("timeline:{}", id), + Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id), + Timeline(List(id), Federated, All) => format!("timeline:list:{}", id), + Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id), + Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?, + }) + } + + pub fn from_redis_text( + timeline: &str, + cache: &mut LruCache, + ) -> Result { + let mut id_from_tag = |tag: &str| match cache.get(&tag.to_string()) { + Some(id) => Ok(*id), + None => Err(TimelineErr::InvalidInput), // TODO more specific + }; + + use {Content::*, Reach::*, Stream::*}; + Ok(match &timeline.split(':').collect::>()[..] { + ["public"] => Timeline(Public, Federated, All), + ["public", "local"] => Timeline(Public, Local, All), + ["public", "media"] => Timeline(Public, Federated, Media), + ["public", "local", "media"] => Timeline(Public, Local, Media), + ["hashtag", tag] => Timeline(Hashtag(id_from_tag(tag)?), Federated, All), + ["hashtag", tag, "local"] => Timeline(Hashtag(id_from_tag(tag)?), Local, All), + [id] => Timeline(User(id.parse()?), Federated, All), + [id, "notification"] => Timeline(User(id.parse()?), Federated, Notification), + ["list", id] => Timeline(List(id.parse()?), Federated, All), + ["direct", id] => Timeline(Direct(id.parse()?), Federated, All), + // Other endpoints don't exist: + [..] => Err(TimelineErr::InvalidInput)?, + }) + } + + pub fn from_query_and_user(q: &Query, user: &UserData) -> Result { + use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*}; + + Ok(match q.stream.as_ref() { + "public" => match q.media { + true => Timeline(Public, Federated, Media), + false => Timeline(Public, Federated, All), + }, + "public:local" => match q.media { + true => Timeline(Public, Local, Media), + false => Timeline(Public, Local, All), + }, + "public:media" => Timeline(Public, Federated, Media), + "public:local:media" => Timeline(Public, Local, Media), + + "hashtag" => Timeline(Hashtag(0), Federated, All), + "hashtag:local" => Timeline(Hashtag(0), Local, All), + "user" => match user.scopes.contains(&Statuses) { + true => Timeline(User(user.id), Federated, All), + false => Err(custom("Error: Missing access token"))?, + }, + "user:notification" => match user.scopes.contains(&Statuses) { + true => Timeline(User(user.id), Federated, Notification), + false => Err(custom("Error: Missing access token"))?, + }, + "list" => match user.scopes.contains(&Lists) { + true => Timeline(List(q.list), Federated, All), + false => Err(warp::reject::custom("Error: Missing access token"))?, + }, + "direct" => match user.scopes.contains(&Statuses) { + true => Timeline(Direct(*user.id), Federated, All), + false => Err(custom("Error: Missing access token"))?, + }, + other => { + log::warn!("Request for nonexistent endpoint: `{}`", other); + Err(custom("Error: Nonexistent endpoint"))? + } + }) + } +} + +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Stream { + User(Id), + // TODO consider whether List, Direct, and Hashtag should all be `id::Id`s + List(i64), + Direct(i64), + Hashtag(i64), + Public, + Unset, +} + +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Reach { + Local, + Federated, +} + +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Content { + All, + Media, + Notification, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Scope { + Read, + Statuses, + Notifications, + Lists, +} + +impl TryFrom<&str> for Scope { + type Error = TimelineErr; + + fn try_from(s: &str) -> Result { + match s { + "read" => Ok(Scope::Read), + "read:statuses" => Ok(Scope::Statuses), + "read:notifications" => Ok(Scope::Notifications), + "read:lists" => Ok(Scope::Lists), + "write" | "follow" => Err(TimelineErr::InvalidInput), // ignore write scopes + unexpected => { + log::warn!("Ignoring unknown scope `{}`", unexpected); + Err(TimelineErr::InvalidInput) + } + } + } +} + +pub struct UserData { + pub id: Id, + pub allowed_langs: HashSet, + pub scopes: HashSet, +} + +impl UserData { + pub fn public() -> Self { + Self { + id: Id(-1), + allowed_langs: HashSet::new(), + scopes: HashSet::new(), + } + } +} diff --git a/src/response/redis/connection.rs b/src/response/redis/connection.rs index c463d25..f4900c2 100644 --- a/src/response/redis/connection.rs +++ b/src/response/redis/connection.rs @@ -3,22 +3,16 @@ pub use err::RedisConnErr; use super::msg::{RedisParseErr, RedisParseOutput}; use super::ManagerErr; -use crate::{ - config::Redis, - messages::Event, - request::{Stream, Timeline}, -}; - -use std::{ - convert::{TryFrom, TryInto}, - io::{Read, Write}, - net::TcpStream, - str, - time::Duration, -}; - +use crate::config::Redis; +use crate::messages::Event; +use crate::request::{Stream, Timeline}; use futures::{Async, Poll}; use lru::LruCache; +use std::convert::{TryFrom, TryInto}; +use std::io::{Read, Write}; +use std::net::TcpStream; +use std::str; +use std::time::Duration; type Result = std::result::Result; @@ -46,7 +40,7 @@ impl RedisConn { // TODO: eventually, it might make sense to have Mastodon publish to timelines with // the tag number instead of the tag name. This would save us from dealing // with a cache here and would be consistent with how lists/users are handled. - redis_namespace: redis_cfg.namespace.clone(), + redis_namespace: redis_cfg.namespace.clone().0, redis_input: Vec::new(), }; Ok(redis_conn) @@ -61,14 +55,12 @@ impl RedisConn { self.redis_input.extend_from_slice(&buffer[..n]); break; } - Ok(n) => { - self.redis_input.extend_from_slice(&buffer[..n]); - } + Ok(n) => self.redis_input.extend_from_slice(&buffer[..n]), Err(_) => break, }; if first_read { size = 2000; - buffer = vec![0u8; size]; + buffer = vec![0_u8; size]; first_read = false; } } @@ -117,50 +109,6 @@ impl RedisConn { self.tag_name_cache.put(id, hashtag); } - fn new_connection(addr: &str, pass: Option<&String>) -> Result { - match TcpStream::connect(&addr) { - Ok(mut conn) => { - if let Some(password) = pass { - Self::auth_connection(&mut conn, &addr, password)?; - } - - Self::validate_connection(&mut conn, &addr)?; - conn.set_read_timeout(Some(Duration::from_millis(10))) - .map_err(|e| RedisConnErr::with_addr(&addr, e))?; - Ok(conn) - } - Err(e) => Err(RedisConnErr::with_addr(&addr, e)), - } - } - fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> { - conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes()) - .map_err(|e| RedisConnErr::with_addr(&addr, e))?; - let mut buffer = vec![0u8; 5]; - conn.read_exact(&mut buffer) - .map_err(|e| RedisConnErr::with_addr(&addr, e))?; - let reply = String::from_utf8_lossy(&buffer); - match &*reply { - "+OK\r\n" => (), - _ => Err(RedisConnErr::IncorrectPassword(pass.to_string()))?, - }; - Ok(()) - } - - fn validate_connection(conn: &mut TcpStream, addr: &str) -> Result<()> { - conn.write_all(b"PING\r\n") - .map_err(|e| RedisConnErr::with_addr(&addr, e))?; - let mut buffer = vec![0u8; 7]; - conn.read_exact(&mut buffer) - .map_err(|e| RedisConnErr::with_addr(&addr, e))?; - let reply = String::from_utf8_lossy(&buffer); - match &*reply { - "+PONG\r\n" => Ok(()), - "-NOAUTH" => Err(RedisConnErr::MissingPassword), - "HTTP/1." => Err(RedisConnErr::NotRedis(addr.to_string())), - _ => Err(RedisConnErr::InvalidRedisReply(reply.to_string())), - } - } - pub fn send_cmd(&mut self, cmd: RedisCmd, timeline: &Timeline) -> Result<()> { let hashtag = match timeline { Timeline(Stream::Hashtag(id), _, _) => self.tag_name_cache.get(id), @@ -182,6 +130,44 @@ impl RedisConn { self.secondary.write_all(&secondary_cmd.as_bytes())?; Ok(()) } + + fn new_connection(addr: &str, pass: Option<&String>) -> Result { + let mut conn = TcpStream::connect(&addr)?; + if let Some(password) = pass { + Self::auth_connection(&mut conn, &addr, password)?; + } + + Self::validate_connection(&mut conn, &addr)?; + conn.set_read_timeout(Some(Duration::from_millis(10))) + .map_err(|e| RedisConnErr::with_addr(&addr, e))?; + Ok(conn) + } + fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> { + conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes()) + .map_err(|e| RedisConnErr::with_addr(&addr, e))?; + let mut buffer = vec![0u8; 5]; + conn.read_exact(&mut buffer) + .map_err(|e| RedisConnErr::with_addr(&addr, e))?; + if String::from_utf8_lossy(&buffer) != "+OK\r\n" { + Err(RedisConnErr::IncorrectPassword(pass.to_string()))? + } + Ok(()) + } + + fn validate_connection(conn: &mut TcpStream, addr: &str) -> Result<()> { + conn.write_all(b"PING\r\n") + .map_err(|e| RedisConnErr::with_addr(&addr, e))?; + let mut buffer = vec![0u8; 7]; + conn.read_exact(&mut buffer) + .map_err(|e| RedisConnErr::with_addr(&addr, e))?; + let reply = String::from_utf8_lossy(&buffer); + match &*reply { + "+PONG\r\n" => Ok(()), + "-NOAUTH" => Err(RedisConnErr::MissingPassword), + "HTTP/1." => Err(RedisConnErr::NotRedis(addr.to_string())), + _ => Err(RedisConnErr::InvalidRedisReply(reply.to_string())), + } + } } pub enum RedisCmd { diff --git a/src/response/stream.rs b/src/response/stream.rs index 60a97d4..b270696 100644 --- a/src/response/stream.rs +++ b/src/response/stream.rs @@ -132,43 +132,38 @@ impl Sse { let blocks = subscription.blocks; let event_stream = sse_rx + .filter(move |(timeline, _)| target_timeline == *timeline) .filter_map(move |(timeline, event)| { - if target_timeline == timeline { - use crate::messages::{ - CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind, - }; + use crate::messages::{ + CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind, + }; - use crate::request::Stream::Public; - match event { - TypeSafe(Update { payload, queued_at }) => match timeline { - Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None, - _ if payload.involves_any(&blocks) => None, - _ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update { - payload, - queued_at, - })), - }, - TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)), - Dynamic(dyn_event) => { - if let EventKind::Update(s) = dyn_event.kind { - match timeline { - Timeline(Public, _, _) if s.language_not(&allowed_langs) => { - None - } - _ if s.involves_any(&blocks) => None, - _ => Self::reply_with(Dynamic(DynEvent { - kind: EventKind::Update(s), - ..dyn_event - })), - } - } else { - None + use crate::request::Stream::Public; + match event { + TypeSafe(Update { payload, queued_at }) => match timeline { + Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None, + _ if payload.involves_any(&blocks) => None, + _ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update { + payload, + queued_at, + })), + }, + TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)), + Dynamic(dyn_event) => { + if let EventKind::Update(s) = dyn_event.kind { + match timeline { + Timeline(Public, _, _) if s.language_not(&allowed_langs) => None, + _ if s.involves_any(&blocks) => None, + _ => Self::reply_with(Dynamic(DynEvent { + kind: EventKind::Update(s), + ..dyn_event + })), } + } else { + None } - Ping => None, // pings handled automatically } - } else { - None + Ping => None, // pings handled automatically } }) .then(move |res| { @@ -186,3 +181,4 @@ impl Sse { ) } } +// TODO -- split WS and SSE into separate files and add misc stuff from main.rs here