diff --git a/Cargo.lock b/Cargo.lock index 17df6b6..39c09c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -300,6 +300,27 @@ name = "fixedbitset" version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "flodgatt" +version = "0.2.0" +dependencies = [ + "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", + "futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", + "openssl 0.10.24 (registry+https://github.com/rust-lang/crates.io-index)", + "postgres 0.16.0-rc.2 (git+https://github.com/sfackler/rust-postgres.git)", + "postgres-openssl 0.2.0-rc.1 (git+https://github.com/sfackler/rust-postgres.git)", + "pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)", + "serde 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_derive 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)", + "tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)", + "uuid 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)", + "warp 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "fnv" version = "1.0.6" @@ -877,27 +898,6 @@ dependencies = [ "proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)", ] -[[package]] -name = "ragequit" -version = "0.1.0" -dependencies = [ - "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", - "futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)", - "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", - "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", - "openssl 0.10.24 (registry+https://github.com/rust-lang/crates.io-index)", - "postgres 0.16.0-rc.2 (git+https://github.com/sfackler/rust-postgres.git)", - "postgres-openssl 0.2.0-rc.1 (git+https://github.com/sfackler/rust-postgres.git)", - "pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", - "regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)", - "serde 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_derive 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)", - "serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)", - "tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)", - "uuid 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)", - "warp 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)", -] - [[package]] name = "rand" version = "0.5.6" diff --git a/Cargo.toml b/Cargo.toml index 3cf0d23..dc0b381 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] -name = "ragequit" +name = "flodgatt" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" -version = "0.1.0" +version = "0.2.0" authors = ["Daniel Long Sockwell "] edition = "2018" diff --git a/src/lib.rs b/src/lib.rs index 2c4ec05..de167ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,26 +1,24 @@ //! Streaming server for Mastodon //! //! -//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server -//! 
is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to -//! connect to the server with the API described [in Mastodon's public API +//! This server provides live, streaming updates for Mastodon clients. Specifically, when a +//! server is running this sever, Mastodon clients can use either Server Sent Events or +//! WebSockets to connect to the server with the API described [in Mastodon's public API //! documentation](https://docs.joinmastodon.org/api/streaming/). //! //! # Data Flow -//! * **Parsing the client request** -//! When the client request first comes in, it is parsed based on the endpoint it targets (for -//! server sent events), its query parameters, and its headers (for WebSocket). Based on this -//! data, we authenticate the user, retrieve relevant user data from Postgres, and determine the -//! timeline targeted by the request. Successfully parsing the client request results in generating -//! a `User` and target `timeline` for the request. If any requests are invalid/not authorized, we -//! reject them in this stage. -//! * **Streaming update from Redis to the client**: -//! After the user request is parsed, we pass the `User` and `timeline` data on to the -//! `ClientAgent`. The `ClientAgent` is responsible for communicating the user's request to the -//! `Receiver`, polling the `Receiver` for any updates, and then for wording those updates on to the -//! client. The `Receiver`, in tern, is responsible for managing the Redis subscriptions, -//! periodically polling Redis, and sorting the replies from Redis into queues for when it is polled -//! by the `ClientAgent`. +//! * **Parsing the client request** When the client request first comes in, it is +//! parsed based on the endpoint it targets (for server sent events), its query parameters, +//! and its headers (for WebSocket). Based on this data, we authenticate the user, retrieve +//! relevant user data from Postgres, and determine the timeline targeted by the request. +//! Successfully parsing the client request results in generating a `User` corresponding to +//! the request. If any requests are invalid/not authorized, we reject them in this stage. +//! * **Streaming update from Redis to the client**: After the user request is parsed, we pass +//! the `User` data on to the `ClientAgent`. The `ClientAgent` is responsible for +//! communicating the user's request to the `Receiver`, polling the `Receiver` for any +//! updates, and then for wording those updates on to the client. The `Receiver`, in tern, is +//! responsible for managing the Redis subscriptions, periodically polling Redis, and sorting +//! the replies from Redis into queues for when it is polled by the `ClientAgent`. //! //! # Concurrency //! The `Receiver` is created when the server is first initialized, and there is only one @@ -31,11 +29,10 @@ //! that the `Receiver`'s poll of Redis be fast, since there will only ever be one //! `Receiver`. //! -//! # Configuration -//! By default, the server uses config values from the `config.rs` module; these values can be -//! overwritten with environmental variables or in the `.env` file. The most important settings -//! for performance control the frequency with which the `ClientAgent` polls the `Receiver` and -//! the frequency with which the `Receiver` polls Redis. +//! # Configuration By default, the server uses config values from the `config.rs` module; +//! these values can be overwritten with environmental variables or in the `.env` file. The +//! 
most important settings for performance control the frequency with which the `ClientAgent` +//! polls the `Receiver` and the frequency with which the `Receiver` polls Redis. //! pub mod config; pub mod parse_client_request; diff --git a/src/main.rs b/src/main.rs index fe26068..05e3e15 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,10 @@ -use log::{log_enabled, Level}; -use ragequit::{ +use flodgatt::{ config, parse_client_request::{sse, user, ws}, redis_to_client_stream, redis_to_client_stream::ClientAgent, }; +use log::{log_enabled, Level}; use warp::{ws::Ws2, Filter as WarpFilter}; fn main() { @@ -17,18 +17,14 @@ fn main() { }; // Server Sent Events - // - // For SSE, the API requires users to use different endpoints, so we first filter based on - // the endpoint. Using that endpoint determine the `timeline` the user is requesting, - // the scope for that `timeline`, and authenticate the `User` if they provided a token. - let sse_routes = sse::filter_incomming_request() + let sse_routes = sse::extract_user_or_reject() .and(warp::sse()) .map( - move |timeline: String, user: user::User, sse_connection_to_client: warp::sse::Sse| { + move |user: user::User, sse_connection_to_client: warp::sse::Sse| { // Create a new ClientAgent let mut client_agent = client_agent_sse.clone_with_shared_receiver(); - // Assign that agent to generate a stream of updates for the user/timeline pair - client_agent.init_for_user(&timeline, user); + // Assign ClientAgent to generate stream of updates for the user/timeline pair + client_agent.init_for_user(user); // send the updates through the SSE connection redis_to_client_stream::send_updates_to_sse(client_agent, sse_connection_to_client) }, @@ -37,52 +33,17 @@ fn main() { .recover(config::handle_errors); // WebSocket - // - // For WS, the API specifies a single endpoint, so we extract the User/timeline pair - // directy from the query - let websocket_routes = ws::extract_user_and_query() - .and_then(move |mut user: user::User, q: ws::Query, ws: Ws2| { + let websocket_routes = ws::extract_user_or_reject() + .and(warp::ws::ws2()) + .and_then(move |user: user::User, ws: Ws2| { let token = user.access_token.clone(); - let read_scope = user.scopes.clone(); - - let timeline = match q.stream.as_ref() { - // Public endpoints: - tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl), - tl @ "public:media" | tl @ "public:local:media" => tl.to_string(), - tl @ "public" | tl @ "public:local" => tl.to_string(), - // Hashtag endpoints: - tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag), - // Private endpoints: User - "user" if user.logged_in && (read_scope.all || read_scope.statuses) => { - format!("{}", user.id) - } - "user:notification" if user.logged_in && (read_scope.all || read_scope.notify) => { - user = user.set_filter(user::Filter::Notification); - format!("{}", user.id) - } - // List endpoint: - "list" if user.owns_list(q.list) && (read_scope.all || read_scope.lists) => { - format!("list:{}", q.list) - } - // Direct endpoint: - "direct" if user.logged_in && (read_scope.all || read_scope.statuses) => { - "direct".to_string() - } - // Reject unathorized access attempts for private endpoints - "user" | "user:notification" | "direct" | "list" => { - return Err(warp::reject::custom("Error: Invalid Access Token")) - } - // Other endpoints don't exist: - _ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")), - }; - // Create a new ClientAgent let mut client_agent = 
client_agent_ws.clone_with_shared_receiver(); // Assign that agent to generate a stream of updates for the user/timeline pair - client_agent.init_for_user(&timeline, user); + client_agent.init_for_user(user); // send the updates through the WS connection (along with the User's access_token // which is sent for security) - Ok(( + Ok::<_, warp::Rejection>(( ws.on_upgrade(move |socket| { redis_to_client_stream::send_updates_to_ws(socket, client_agent) }), diff --git a/src/parse_client_request/mod.rs b/src/parse_client_request/mod.rs index d6a0ef7..b55fd8d 100644 --- a/src/parse_client_request/mod.rs +++ b/src/parse_client_request/mod.rs @@ -1,4 +1,4 @@ -//! Parse the client request and return a 'timeline' and a (maybe authenticated) `User` +//! Parse the client request and return a (possibly authenticated) `User` pub mod query; pub mod sse; pub mod user; diff --git a/src/parse_client_request/query.rs b/src/parse_client_request/query.rs index a14d559..211b9a7 100644 --- a/src/parse_client_request/query.rs +++ b/src/parse_client_request/query.rs @@ -3,8 +3,32 @@ use serde_derive::Deserialize; use warp::filters::BoxedFilter; use warp::Filter as WarpFilter; -macro_rules! query { - ($name:tt => $parameter:tt:$type:tt) => { +#[derive(Debug)] +pub struct Query { + pub access_token: Option, + pub stream: String, + pub media: bool, + pub hashtag: String, + pub list: i64, +} + +impl Query { + pub fn update_access_token( + self, + token: Option, + ) -> Result { + match token { + Some(token) => Ok(Self { + access_token: Some(token), + ..self + }), + None => Ok(self), + } + } +} + +macro_rules! make_query_type { + ($name:tt => $parameter:tt:$type:ty) => { #[derive(Deserialize, Debug, Default)] pub struct $name { pub $parameter: $type, @@ -19,16 +43,16 @@ macro_rules! query { } }; } -query!(Media => only_media:String); +make_query_type!(Media => only_media:String); impl Media { pub fn is_truthy(&self) -> bool { self.only_media == "true" || self.only_media == "1" } } -query!(Hashtag => tag: String); -query!(List => list: i64); -query!(Auth => access_token: String); -query!(Stream => stream: String); +make_query_type!(Hashtag => tag: String); +make_query_type!(List => list: i64); +make_query_type!(Auth => access_token: Option); +make_query_type!(Stream => stream: String); impl ToString for Stream { fn to_string(&self) -> String { format!("{:?}", self) @@ -43,3 +67,19 @@ pub fn optional_media_query() -> BoxedFilter<(Media,)> { .unify() .boxed() } + +pub struct OptionalAccessToken; + +impl OptionalAccessToken { + pub fn from_header() -> warp::filters::BoxedFilter<(Option,)> { + let from_header = warp::header::header::("authorization").map(|auth: String| { + match auth.split(' ').nth(1) { + Some(s) => Some(s.to_string()), + None => None, + } + }); + let no_token = warp::any().map(|| None); + + from_header.or(no_token).unify().boxed() + } +} diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse.rs index ff3786b..cf3df0f 100644 --- a/src/parse_client_request/sse.rs +++ b/src/parse_client_request/sse.rs @@ -1,9 +1,5 @@ //! Filters for all the endpoints accessible for Server Sent Event updates -use super::{ - query, - user::{Filter::*, OptionalAccessToken, User}, -}; -use crate::config::CustomError; +use super::{query, query::Query, user::User}; use warp::{filters::BoxedFilter, path, Filter}; #[allow(dead_code)] @@ -12,140 +8,111 @@ type TimelineUser = ((String, User),); /// Helper macro to match on the first of any of the provided filters macro_rules! 
any_of { ($filter:expr, $($other_filter:expr),*) => { - $filter$(.or($other_filter).unify())* + $filter$(.or($other_filter).unify())*.boxed() }; } -pub fn filter_incomming_request() -> BoxedFilter<(String, User)> { +macro_rules! parse_query { + (path => $start:tt $(/ $next:tt)* + endpoint => $endpoint:expr) => { + path!($start $(/ $next)*) + .and(query::Auth::to_filter()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .map( + |auth: query::Auth, + media: query::Media, + hashtag: query::Hashtag, + list: query::List| { + Query { + access_token: auth.access_token, + stream: $endpoint.to_string(), + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + } + }, + ) + .boxed() + }; +} +pub fn extract_user_or_reject() -> BoxedFilter<(User,)> { any_of!( - path!("api" / "v1" / "streaming" / "user" / "notification") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_reject) - .map(|user: User| (user.id.to_string(), user.set_filter(Notification))), - // **NOTE**: This endpoint was present in the node.js server, but not in the - // [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). - // Should it be publicly documented? - path!("api" / "v1" / "streaming" / "user") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_reject) - .map(|user: User| (user.id.to_string(), user)), - path!("api" / "v1" / "streaming" / "public" / "local") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_public_user) - .and(warp::query()) - .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)), - _ => ("public:local".to_owned(), user.set_filter(Language)), - }), - path!("api" / "v1" / "streaming" / "public") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_public_user) - .and(warp::query()) - .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:media".to_owned(), user.set_filter(Language)), - _ => ("public".to_owned(), user.set_filter(Language)), - }), - path!("api" / "v1" / "streaming" / "public" / "local") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_public_user) - .map(|user: User| ("public:local".to_owned(), user.set_filter(Language))), - path!("api" / "v1" / "streaming" / "public") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_public_user) - .map(|user: User| ("public".to_owned(), user.set_filter(Language))), - path!("api" / "v1" / "streaming" / "direct") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_reject) - .map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter))), - // **Note**: Hashtags are *not* filtered on language, right? 
- path!("api" / "v1" / "streaming" / "hashtag" / "local") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_public_user) - .and(warp::query()) - .map(|_, q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())), - path!("api" / "v1" / "streaming" / "hashtag") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_public_user) - .and(warp::query()) - .map(|_, q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())), - path!("api" / "v1" / "streaming" / "list") - .and(OptionalAccessToken::from_header_or_query()) - .and_then(User::from_access_token_or_reject) - .and(warp::query()) - .and_then(|user: User, q: query::List| { - if user.owns_list(q.list) { - (Ok(q.list), Ok(user)) - } else { - (Err(CustomError::unauthorized_list()), Ok(user)) - } - }) - .untuple_one() - .map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter))) + parse_query!( + path => "api" / "v1" / "streaming" / "user" / "notification" + endpoint => "user:notification" ), + parse_query!( + path => "api" / "v1" / "streaming" / "user" + endpoint => "user"), + parse_query!( + path => "api" / "v1" / "streaming" / "public" / "local" + endpoint => "public:local"), + parse_query!( + path => "api" / "v1" / "streaming" / "public" + endpoint => "public"), + parse_query!( + path => "api" / "v1" / "streaming" / "direct" + endpoint => "direct"), + parse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local" + endpoint => "hashtag:local"), + parse_query!(path => "api" / "v1" / "streaming" / "hashtag" + endpoint => "hashtag"), + parse_query!(path => "api" / "v1" / "streaming" / "list" + endpoint => "list") ) - .untuple_one() + // because SSE requests place their `access_token` in the header instead of in a query + // parameter, we need to update our Query if the header has a token + .and(query::OptionalAccessToken::from_header()) + .and_then(Query::update_access_token) + .and_then(User::from_query) .boxed() } #[cfg(test)] mod test { use super::*; - - struct TestUser; - impl TestUser { - fn logged_in() -> User { - User::from_access_token_or_reject(Some("TEST_USER".to_string())).expect("in test") - } - fn public() -> User { - User::from_access_token_or_public_user(None).expect("in test") - } - } + use crate::parse_client_request::user::{Filter, OauthScope}; macro_rules! test_public_endpoint { ($name:ident { endpoint: $path:expr, - timeline: $timeline:expr, user: $user:expr, }) => { #[test] fn $name() { - let (timeline, user) = warp::test::request() + let user = warp::test::request() .path($path) - .filter(&filter_incomming_request()) + .filter(&extract_user_or_reject()) .expect("in test"); - assert_eq!(&timeline, $timeline); assert_eq!(user, $user); } }; } - macro_rules! 
test_private_endpoint { ($name:ident { endpoint: $path:expr, $(query: $query:expr,)* - timeline: $timeline:expr, user: $user:expr, }) => { #[test] fn $name() { let path = format!("{}?access_token=TEST_USER", $path); $(let path = format!("{}&{}", path, $query);)* - let (timeline, user) = warp::test::request() + let user = warp::test::request() .path(&path) - .filter(&filter_incomming_request()) + .filter(&extract_user_or_reject()) .expect("in test"); - assert_eq!(&timeline, $timeline); assert_eq!(user, $user); - let (timeline, user) = warp::test::request() + let user = warp::test::request() .path(&path) .header("Authorization", "Bearer: TEST_USER") - .filter(&filter_incomming_request()) + .filter(&extract_user_or_reject()) .expect("in test"); - assert_eq!(&timeline, $timeline); assert_eq!(user, $user); } }; } - macro_rules! test_bad_auth_token_in_query { ($name: ident { endpoint: $path:expr, @@ -153,19 +120,17 @@ mod test { }) => { #[test] #[should_panic(expected = "Error: Invalid access token")] - fn $name() { let path = format!("{}?access_token=INVALID", $path); $(let path = format!("{}&{}", path, $query);)* dbg!(&path); warp::test::request() .path(&path) - .filter(&filter_incomming_request()) + .filter(&extract_user_or_reject()) .expect("in test"); } }; } - macro_rules! test_bad_auth_token_in_header { ($name: ident { endpoint: $path:expr, @@ -180,7 +145,7 @@ mod test { warp::test::request() .path(&path) .header("Authorization", "Bearer: INVALID") - .filter(&filter_incomming_request()) + .filter(&extract_user_or_reject()) .expect("in test"); } }; @@ -197,7 +162,7 @@ mod test { $(let path = format!("{}?{}", path, $query);)* warp::test::request() .path(&path) - .filter(&filter_incomming_request()) + .filter(&extract_user_or_reject()) .expect("in test"); } }; @@ -205,13 +170,193 @@ mod test { test_public_endpoint!(public_media_true { endpoint: "/api/v1/streaming/public?only_media=true", - timeline: "public:media", - user: TestUser::public().set_filter(Language), + user: User { + target_timeline: "public:media".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, }); test_public_endpoint!(public_media_1 { endpoint: "/api/v1/streaming/public?only_media=1", - timeline: "public:media", - user: TestUser::public().set_filter(Language), + user: User { + target_timeline: "public:media".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + test_public_endpoint!(public_local { + endpoint: "/api/v1/streaming/public/local", + user: User { + target_timeline: "public:local".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + test_public_endpoint!(public_local_media_true { + endpoint: "/api/v1/streaming/public/local?only_media=true", + user: User { + target_timeline: "public:local:media".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + 
test_public_endpoint!(public_local_media_1 { + endpoint: "/api/v1/streaming/public/local?only_media=1", + user: User { + target_timeline: "public:local:media".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + test_public_endpoint!(hashtag { + endpoint: "/api/v1/streaming/hashtag?tag=a", + user: User { + target_timeline: "hashtag:a".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + test_public_endpoint!(hashtag_local { + endpoint: "/api/v1/streaming/hashtag/local?tag=a", + user: User { + target_timeline: "hashtag:local:a".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + + test_private_endpoint!(user { + endpoint: "/api/v1/streaming/user", + user: User { + target_timeline: "1".to_string(), + id: 1, + access_token: "TEST_USER".to_string(), + langs: None, + scopes: OauthScope { + all: true, + statuses: false, + notify: false, + lists: false, + }, + logged_in: true, + filter: Filter::NoFilter, + }, + }); + test_private_endpoint!(user_notification { + endpoint: "/api/v1/streaming/user/notification", + user: User { + target_timeline: "1".to_string(), + id: 1, + access_token: "TEST_USER".to_string(), + langs: None, + scopes: OauthScope { + all: true, + statuses: false, + notify: false, + lists: false, + }, + logged_in: true, + filter: Filter::Notification, + }, + }); + test_private_endpoint!(direct { + endpoint: "/api/v1/streaming/direct", + user: User { + target_timeline: "direct".to_string(), + id: 1, + access_token: "TEST_USER".to_string(), + langs: None, + scopes: OauthScope { + all: true, + statuses: false, + notify: false, + lists: false, + }, + logged_in: true, + filter: Filter::NoFilter, + }, + }); + + test_private_endpoint!(list_valid_list { + endpoint: "/api/v1/streaming/list", + query: "list=1", + user: User { + target_timeline: "list:1".to_string(), + id: 1, + access_token: "TEST_USER".to_string(), + langs: None, + scopes: OauthScope { + all: true, + statuses: false, + notify: false, + lists: false, + }, + logged_in: true, + filter: Filter::NoFilter, + }, }); test_bad_auth_token_in_query!(public_media_true_bad_auth { endpoint: "/api/v1/streaming/public", @@ -221,29 +366,12 @@ mod test { endpoint: "/api/v1/streaming/public", query: "only_media=1", }); - - test_public_endpoint!(public_local { - endpoint: "/api/v1/streaming/public/local", - timeline: "public:local", - user: TestUser::public().set_filter(Language), - }); test_bad_auth_token_in_query!(public_local_bad_auth_in_query { endpoint: "/api/v1/streaming/public/local", }); test_bad_auth_token_in_header!(public_local_bad_auth_in_header { endpoint: "/api/v1/streaming/public/local", }); - - test_public_endpoint!(public_local_media_true { - endpoint: "/api/v1/streaming/public/local?only_media=true", - timeline: "public:local:media", - user: TestUser::public().set_filter(Language), - }); - test_public_endpoint!(public_local_media_1 { - endpoint: "/api/v1/streaming/public/local?only_media=1", - timeline: "public:local:media", - user: TestUser::public().set_filter(Language), - 
}); test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query { endpoint: "/api/v1/streaming/public/local", query: "only_media=1", @@ -252,12 +380,6 @@ mod test { endpoint: "/api/v1/streaming/public/local", query: "only_media=true", }); - - test_public_endpoint!(hashtag { - endpoint: "/api/v1/streaming/hashtag?tag=a", - timeline: "hashtag:a", - user: TestUser::public(), - }); test_bad_auth_token_in_query!(hashtag_bad_auth_in_query { endpoint: "/api/v1/streaming/hashtag", query: "tag=a", @@ -266,26 +388,6 @@ mod test { endpoint: "/api/v1/streaming/hashtag", query: "tag=a", }); - - test_public_endpoint!(hashtag_local { - endpoint: "/api/v1/streaming/hashtag/local?tag=a", - timeline: "hashtag:a:local", - user: TestUser::public(), - }); - test_bad_auth_token_in_query!(hashtag_local_bad_auth_in_query { - endpoint: "/api/v1/streaming/hashtag/local", - query: "tag=a", - }); - test_bad_auth_token_in_header!(hashtag_local_bad_auth_in_header { - endpoint: "/api/v1/streaming/hashtag/local", - query: "tag=a", - }); - - test_private_endpoint!(user { - endpoint: "/api/v1/streaming/user", - timeline: "1", - user: TestUser::logged_in(), - }); test_bad_auth_token_in_query!(user_bad_auth_in_query { endpoint: "/api/v1/streaming/user", }); @@ -295,12 +397,6 @@ mod test { test_missing_auth!(user_missing_auth_token { endpoint: "/api/v1/streaming/user", }); - - test_private_endpoint!(user_notification { - endpoint: "/api/v1/streaming/user/notification", - timeline: "1", - user: TestUser::logged_in().set_filter(Notification), - }); test_bad_auth_token_in_query!(user_notification_bad_auth_in_query { endpoint: "/api/v1/streaming/user/notification", }); @@ -310,12 +406,6 @@ mod test { test_missing_auth!(user_notification_missing_auth_token { endpoint: "/api/v1/streaming/user/notification", }); - - test_private_endpoint!(direct { - endpoint: "/api/v1/streaming/direct", - timeline: "direct:1", - user: TestUser::logged_in(), - }); test_bad_auth_token_in_query!(direct_bad_auth_in_query { endpoint: "/api/v1/streaming/direct", }); @@ -325,13 +415,6 @@ mod test { test_missing_auth!(direct_missing_auth_token { endpoint: "/api/v1/streaming/direct", }); - - test_private_endpoint!(list_valid_list { - endpoint: "/api/v1/streaming/list", - query: "list=1", - timeline: "list:1", - user: TestUser::logged_in(), - }); test_bad_auth_token_in_query!(list_bad_auth_in_query { endpoint: "/api/v1/streaming/list", query: "list=1", @@ -345,4 +428,13 @@ mod test { query: "list=1", }); + #[test] + #[should_panic(expected = "NotFound")] + fn nonexistant_endpoint() { + warp::test::request() + .path("/api/v1/streaming/DOES_NOT_EXIST") + .filter(&extract_user_or_reject()) + .expect("in test"); + } + } diff --git a/src/parse_client_request/user/mod.rs b/src/parse_client_request/user/mod.rs index f2f9ee8..53d3c80 100644 --- a/src/parse_client_request/user/mod.rs +++ b/src/parse_client_request/user/mod.rs @@ -5,16 +5,8 @@ mod mock_postgres; use mock_postgres as postgres; #[cfg(not(test))] mod postgres; -use crate::parse_client_request::query; -use log::info; +use super::query::Query; use warp::reject::Rejection; -use warp::Filter as WarpFilter; - -macro_rules! 
any_of { - ($filter:expr, $($other_filter:expr),*) => { - $filter$(.or($other_filter).unify())* - }; -} /// The filters that can be applied to toots after they come from Redis #[derive(Clone, Debug, PartialEq)] @@ -23,10 +15,16 @@ pub enum Filter { Language, Notification, } +impl Default for Filter { + fn default() -> Self { + Filter::NoFilter + } +} /// The User (with data read from Postgres) -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq)] pub struct User { + pub target_timeline: String, pub id: i64, pub access_token: String, pub scopes: OauthScope, @@ -34,11 +32,7 @@ pub struct User { pub logged_in: bool, pub filter: Filter, } -impl Default for User { - fn default() -> Self { - User::public() - } -} + #[derive(Clone, Debug, Default, PartialEq)] pub struct OauthScope { pub all: bool, @@ -62,72 +56,82 @@ impl From> for OauthScope { } } -/// Create a user based on the supplied path and access scope for the resource -#[macro_export] -macro_rules! user_from_path { - ($($path_item:tt) / *, $scope:expr) => (path!("api" / "v1" / $($path_item) / +) - .and($scope.get_access_token()) - .and_then(|token| User::from_access_token(token, $scope))) -} - impl User { - pub fn from_access_token_or_reject(token: Option) -> Result { - match token { - None => Err(warp::reject::custom("Error: Missing access token")), + pub fn from_query(q: Query) -> Result { + let (id, access_token, scopes, langs, logged_in) = match q.access_token.clone() { + None => ( + -1, + "no access token".to_owned(), + OauthScope::default(), + None, + false, + ), Some(token) => { let (id, langs, scope_list) = postgres::query_for_user_data(&token); if id == -1 { return Err(warp::reject::custom("Error: Invalid access token")); } let scopes = OauthScope::from(scope_list); - - Ok(User { - id, - access_token: token, - scopes, - langs, - logged_in: true, - filter: Filter::NoFilter, - }) + (id, token, scopes, langs, true) } - } + }; + let mut user = User { + id, + target_timeline: "PLACEHOLDER".to_string(), + access_token, + scopes, + langs, + logged_in, + filter: Filter::Language, + }; + + user = user.update_timeline_and_filter(q)?; + + Ok(user) } - pub fn from_access_token_or_public_user(token: Option) -> Result { - match token { - None => Ok(User::public()), - Some(_) => User::from_access_token_or_reject(token), - } - } - /// Create a user from the access token supplied in the header or query paramaters - pub fn from_access_token( - access_token: String, - scope: Scope, - ) -> Result { - let (id, langs, scope_list) = postgres::query_for_user_data(&access_token); - let scopes = OauthScope::from(scope_list); - if id != -1 || scope == Scope::Public { - let (logged_in, log_msg) = match id { - -1 => (false, "Public access to non-authenticated endpoints"), - _ => (true, "Granting logged-in access"), - }; - info!("{}", log_msg); - Ok(User { - id, - access_token, - scopes, - langs, - logged_in, - filter: Filter::NoFilter, - }) - } else { - Err(warp::reject::custom("Error: Invalid access token")) - } - } - /// Set the Notification/Language filter - pub fn set_filter(self, filter: Filter) -> Self { - Self { filter, ..self } + fn update_timeline_and_filter(mut self, q: Query) -> Result { + let read_scope = self.scopes.clone(); + + let timeline = match q.stream.as_ref() { + // Public endpoints: + tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl), + tl @ "public:media" | tl @ "public:local:media" => tl.to_string(), + tl @ "public" | tl @ "public:local" => tl.to_string(), + // Hashtag endpoints: + 
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag), + // Private endpoints: User + "user" if self.logged_in && (read_scope.all || read_scope.statuses) => { + self.filter = Filter::NoFilter; + format!("{}", self.id) + } + "user:notification" if self.logged_in && (read_scope.all || read_scope.notify) => { + self.filter = Filter::Notification; + format!("{}", self.id) + } + // List endpoint: + "list" if self.owns_list(q.list) && (read_scope.all || read_scope.lists) => { + self.filter = Filter::NoFilter; + format!("list:{}", q.list) + } + // Direct endpoint: + "direct" if self.logged_in && (read_scope.all || read_scope.statuses) => { + self.filter = Filter::NoFilter; + "direct".to_string() + } + // Reject unathorized access attempts for private endpoints + "user" | "user:notification" | "direct" | "list" => { + return Err(warp::reject::custom("Error: Missing access token")) + } + // Other endpoints don't exist: + _ => return Err(warp::reject::custom("Error: Nonexistent endpoint")), + }; + Ok(Self { + target_timeline: timeline, + ..self + }) } + /// Determine whether the User is authorised for a specified list pub fn owns_list(&self, list: i64) -> bool { match postgres::query_list_owner(list) { @@ -135,75 +139,4 @@ impl User { _ => false, } } - pub fn public2() -> warp::filters::BoxedFilter<(User,)> { - warp::any() - .map(|| User { - id: -1, - access_token: String::from("no access token"), - scopes: OauthScope::default(), - langs: None, - logged_in: false, - filter: Filter::NoFilter, - }) - .boxed() - } - /// A public (non-authenticated) User - pub fn public() -> Self { - User { - id: -1, - access_token: String::from("no access token"), - scopes: OauthScope::default(), - langs: None, - logged_in: false, - filter: Filter::NoFilter, - } - } -} - -pub struct OptionalAccessToken; - -impl OptionalAccessToken { - pub fn from_header_or_query() -> warp::filters::BoxedFilter<(Option,)> { - let from_header = warp::header::header::("authorization").map(|auth: String| { - match auth.split(' ').nth(1) { - Some(s) => Some(s.to_string()), - None => None, - } - }); - let from_query = warp::query().map(|q: query::Auth| Some(q.access_token)); - let no_token = warp::any().map(|| None); - - any_of!(from_header, from_query, no_token).boxed() - } -} - -/// Whether the endpoint requires authentication or not -#[derive(PartialEq)] -pub enum Scope { - Public, - Private, -} -impl Scope { - pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> { - let token_from_header_http_push = warp::header::header::("authorization") - .map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string()); - let token_from_header_ws = - warp::header::header::("Sec-WebSocket-Protocol").map(|auth: String| auth); - let token_from_query = warp::query().map(|q: query::Auth| q.access_token); - - let private_scopes = any_of!( - token_from_header_http_push, - token_from_header_ws, - token_from_query - ); - - let public = warp::any().map(|| "no access token".to_string()); - - match self { - // if they're trying to access a private scope without an access token, reject the request - Scope::Private => private_scopes.boxed(), - // if they're trying to access a public scope without an access token, proceed - Scope::Public => any_of!(private_scopes, public).boxed(), - } - } } diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs index 77c7c68..1843816 100644 --- a/src/parse_client_request/ws.rs +++ b/src/parse_client_request/ws.rs @@ -1,43 +1,316 @@ //! 
Filters for the WebSocket endpoint -use super::{ - query, - user::{Scope, User}, -}; -use crate::user_from_path; +use super::{query, query::Query, user::User}; use warp::{filters::BoxedFilter, path, Filter}; /// WebSocket filters -pub fn extract_user_and_query() -> BoxedFilter<(User, Query, warp::ws::Ws2)> { - user_from_path!("streaming", Scope::Public) +fn parse_query() -> BoxedFilter<(Query,)> { + path!("api" / "v1" / "streaming") + .and(path::end()) .and(warp::query()) + .and(query::Auth::to_filter()) .and(query::Media::to_filter()) .and(query::Hashtag::to_filter()) .and(query::List::to_filter()) - .and(warp::ws2()) .map( - |user: User, - stream: query::Stream, + |stream: query::Stream, + auth: query::Auth, media: query::Media, hashtag: query::Hashtag, - list: query::List, - ws: warp::ws::Ws2| { - let query = Query { + list: query::List| { + Query { + access_token: auth.access_token, stream: stream.stream, media: media.is_truthy(), hashtag: hashtag.tag, list: list.list, - }; - (user, query, ws) + } }, ) - .untuple_one() .boxed() } -#[derive(Debug)] -pub struct Query { - pub stream: String, - pub media: bool, - pub hashtag: String, - pub list: i64, +pub fn extract_user_or_reject() -> BoxedFilter<(User,)> { + parse_query().and_then(User::from_query).boxed() +} +#[cfg(test)] +mod test { + use super::*; + use crate::parse_client_request::user::{Filter, OauthScope}; + + macro_rules! test_public_endpoint { + ($name:ident { + endpoint: $path:expr, + user: $user:expr, + }) => { + #[test] + fn $name() { + let user = warp::test::request() + .path($path) + .header("connection", "upgrade") + .header("upgrade", "websocket") + .header("sec-websocket-version", "13") + .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==") + .filter(&extract_user_or_reject()) + .expect("in test"); + assert_eq!(user, $user); + } + }; + } + macro_rules! test_private_endpoint { + ($name:ident { + endpoint: $path:expr, + user: $user:expr, + }) => { + #[test] + fn $name() { + let path = format!("{}&access_token=TEST_USER", $path); + let user = warp::test::request() + .path(&path) + .header("connection", "upgrade") + .header("upgrade", "websocket") + .header("sec-websocket-version", "13") + .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==") + .filter(&extract_user_or_reject()) + .expect("in test"); + assert_eq!(user, $user); + } + }; + } + macro_rules! test_bad_auth_token_in_query { + ($name: ident { + endpoint: $path:expr, + + }) => { + #[test] + #[should_panic(expected = "Error: Invalid access token")] + + fn $name() { + let path = format!("{}&access_token=INVALID", $path); + warp::test::request() + .path(&path) + .filter(&extract_user_or_reject()) + .expect("in test"); + } + }; + } + macro_rules! 
test_missing_auth { + ($name: ident { + endpoint: $path:expr, + }) => { + #[test] + #[should_panic(expected = "Error: Missing access token")] + fn $name() { + let path = $path; + warp::test::request() + .path(&path) + .filter(&extract_user_or_reject()) + .expect("in test"); + } + }; + } + + test_public_endpoint!(public_media { + endpoint: "/api/v1/streaming?stream=public:media", + user: User { + target_timeline: "public:media".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + test_public_endpoint!(public_local { + endpoint: "/api/v1/streaming?stream=public:local", + user: User { + target_timeline: "public:local".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + test_public_endpoint!(public_local_media { + endpoint: "/api/v1/streaming?stream=public:local:media", + user: User { + target_timeline: "public:local:media".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + test_public_endpoint!(hashtag { + endpoint: "/api/v1/streaming?stream=hashtag&tag=a", + user: User { + target_timeline: "hashtag:a".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + test_public_endpoint!(hashtag_local { + endpoint: "/api/v1/streaming?stream=hashtag:local&tag=a", + user: User { + target_timeline: "hashtag:local:a".to_string(), + id: -1, + access_token: "no access token".to_string(), + langs: None, + scopes: OauthScope { + all: false, + statuses: false, + notify: false, + lists: false, + }, + logged_in: false, + filter: Filter::Language, + }, + }); + + test_private_endpoint!(user { + endpoint: "/api/v1/streaming?stream=user", + user: User { + target_timeline: "1".to_string(), + id: 1, + access_token: "TEST_USER".to_string(), + langs: None, + scopes: OauthScope { + all: true, + statuses: false, + notify: false, + lists: false, + }, + logged_in: true, + filter: Filter::NoFilter, + }, + }); + test_private_endpoint!(user_notification { + endpoint: "/api/v1/streaming?stream=user:notification", + user: User { + target_timeline: "1".to_string(), + id: 1, + access_token: "TEST_USER".to_string(), + langs: None, + scopes: OauthScope { + all: true, + statuses: false, + notify: false, + lists: false, + }, + logged_in: true, + filter: Filter::Notification, + }, + }); + test_private_endpoint!(direct { + endpoint: "/api/v1/streaming?stream=direct", + user: User { + target_timeline: "direct".to_string(), + id: 1, + access_token: "TEST_USER".to_string(), + langs: None, + scopes: OauthScope { + all: true, + statuses: false, + notify: false, + lists: false, + }, + logged_in: true, + filter: Filter::NoFilter, + }, + }); + test_private_endpoint!(list_valid_list { + endpoint: "/api/v1/streaming?stream=list&list=1", + user: User { + target_timeline: "list:1".to_string(), + id: 1, + access_token: "TEST_USER".to_string(), + langs: None, + scopes: OauthScope { + all: true, + statuses: false, + 
notify: false, + lists: false, + }, + logged_in: true, + filter: Filter::NoFilter, + }, + }); + + test_bad_auth_token_in_query!(public_media_true_bad_auth { + endpoint: "/api/v1/streaming?stream=public:media", + }); + test_bad_auth_token_in_query!(public_local_bad_auth_in_query { + endpoint: "/api/v1/streaming?stream=public:local", + }); + test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query { + endpoint: "/api/v1/streaming?stream=public:local:media", + }); + test_bad_auth_token_in_query!(hashtag_bad_auth_in_query { + endpoint: "/api/v1/streaming?stream=hashtag&tag=a", + }); + test_bad_auth_token_in_query!(user_bad_auth_in_query { + endpoint: "/api/v1/streaming?stream=user", + }); + test_missing_auth!(user_missing_auth_token { + endpoint: "/api/v1/streaming?stream=user", + }); + test_bad_auth_token_in_query!(user_notification_bad_auth_in_query { + endpoint: "/api/v1/streaming?stream=user:notification", + }); + test_missing_auth!(user_notification_missing_auth_token { + endpoint: "/api/v1/streaming?stream=user:notification", + }); + test_bad_auth_token_in_query!(direct_bad_auth_in_query { + endpoint: "/api/v1/streaming?stream=direct", + }); + test_missing_auth!(direct_missing_auth_token { + endpoint: "/api/v1/streaming?stream=direct", + }); + test_bad_auth_token_in_query!(list_bad_auth_in_query { + endpoint: "/api/v1/streaming?stream=list&list=1", + }); + test_missing_auth!(list_missing_auth_token { + endpoint: "/api/v1/streaming?stream=list&list=1", + }); + + #[test] + #[should_panic(expected = "NotFound")] + fn nonexistant_endpoint() { + warp::test::request() + .path("/api/v1/streaming/DOES_NOT_EXIST") + .header("connection", "upgrade") + .header("upgrade", "websocket") + .header("sec-websocket-version", "13") + .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==") + .filter(&extract_user_or_reject()) + .expect("in test"); + } } diff --git a/src/redis_to_client_stream/client_agent.rs b/src/redis_to_client_stream/client_agent.rs index 76e77b2..d43a331 100644 --- a/src/redis_to_client_stream/client_agent.rs +++ b/src/redis_to_client_stream/client_agent.rs @@ -40,7 +40,7 @@ impl ClientAgent { receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())), id: Uuid::default(), target_timeline: String::new(), - current_user: User::public(), + current_user: User::default(), } } @@ -61,12 +61,12 @@ impl ClientAgent { /// a different user, the `Receiver` is responsible for figuring /// that out and avoiding duplicated connections. Thus, it is safe to /// use this method for each new client connection. - pub fn init_for_user(&mut self, target_timeline: &str, user: User) { + pub fn init_for_user(&mut self, user: User) { self.id = Uuid::new_v4(); - self.target_timeline = target_timeline.to_owned(); + self.target_timeline = user.target_timeline.to_owned(); self.current_user = user; let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)"); - receiver.manage_new_timeline(self.id, target_timeline); + receiver.manage_new_timeline(self.id, &self.target_timeline); } }
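
The core of this patch is moving timeline resolution out of `main.rs` and the per-endpoint SSE filters and into `User::from_query` / `User::update_timeline_and_filter`, so that both the SSE and WebSocket routes can simply return a `User` (now carrying a `target_timeline`) from `extract_user_or_reject()`. The standalone sketch below illustrates that resolution step in simplified form; the `Query` struct and `resolve_timeline` function here are trimmed-down stand-ins for the real types (access-scope and notification-filter handling elided), not the crate's actual API.

// Standalone sketch, not part of the patch: a simplified version of the
// stream-name match that the patch centralizes in
// `User::update_timeline_and_filter`.
struct Query {
    stream: String,
    media: bool,
    hashtag: String,
    list: i64,
}

fn resolve_timeline(q: &Query, logged_in: bool, user_id: i64) -> Result<String, String> {
    Ok(match q.stream.as_str() {
        // Public endpoints, optionally narrowed to media-only posts
        tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
        tl @ "public" | tl @ "public:local" => tl.to_string(),
        // Hashtag endpoints embed the tag in the channel name
        tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
        // Private endpoints require an authenticated user (OAuth scope checks elided)
        "user" if logged_in => user_id.to_string(),
        "list" if logged_in => format!("list:{}", q.list),
        "direct" if logged_in => "direct".to_string(),
        "user" | "list" | "direct" => return Err("Error: Missing access token".to_string()),
        other => return Err(format!("Error: Nonexistent endpoint `{}`", other)),
    })
}

fn main() {
    let q = Query {
        stream: "hashtag".to_string(),
        media: false,
        hashtag: "rust".to_string(),
        list: 0,
    };
    // A public hashtag stream resolves without a token.
    assert_eq!(resolve_timeline(&q, false, -1).unwrap(), "hashtag:rust");
    // A private stream without a token is rejected.
    let q = Query { stream: "user".to_string(), ..q };
    assert!(resolve_timeline(&q, false, -1).is_err());
}

Keeping this match in a single place means the SSE and WebSocket paths cannot drift apart in which streams they accept or reject, and the unit tests added to sse.rs and ws.rs in this patch exercise the same logic end to end through the real warp filters.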