From a6b4d968cbe8a7831c114f40350fb9f89b0e4305 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Fri, 20 Mar 2020 14:42:01 -0400 Subject: [PATCH] Add support for WHITELIST_MODE (#99) When the `WHITELIST_MODE` environmental variable is set, Flodgatt requires users to authenticate with a valid access token before subscribing to any timelines (even those that are typically public). --- Cargo.lock | 2 +- src/config/deployment_cfg.rs | 2 ++ src/config/deployment_cfg_types.rs | 10 ++++++++++ src/main.rs | 8 ++++---- src/parse_client_request/sse.rs | 7 +++++-- src/parse_client_request/subscription/mod.rs | 3 ++- src/parse_client_request/ws.rs | 3 ++- 7 files changed, 26 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1433f7e..60d5ef3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -440,7 +440,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.6.2" +version = "0.6.3" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/src/config/deployment_cfg.rs b/src/config/deployment_cfg.rs index d94a79e..164bca5 100644 --- a/src/config/deployment_cfg.rs +++ b/src/config/deployment_cfg.rs @@ -10,6 +10,7 @@ pub struct DeploymentConfig<'a> { pub cors: Cors<'a>, pub sse_interval: SseInterval, pub ws_interval: WsInterval, + pub whitelist_mode: WhitelistMode, } impl DeploymentConfig<'_> { @@ -22,6 +23,7 @@ impl DeploymentConfig<'_> { unix_socket: Socket::default().maybe_update(env.get("SOCKET")), sse_interval: SseInterval::default().maybe_update(env.get("SSE_FREQ")), ws_interval: WsInterval::default().maybe_update(env.get("WS_FREQ")), + whitelist_mode: WhitelistMode::default().maybe_update(env.get("WHITELIST_MODE")), cors: Cors::default(), }; cfg.env = cfg.env.maybe_update(env.get("RUST_ENV")); diff --git a/src/config/deployment_cfg_types.rs b/src/config/deployment_cfg_types.rs index 86622f9..b09de80 100644 --- a/src/config/deployment_cfg_types.rs +++ b/src/config/deployment_cfg_types.rs @@ -59,6 +59,16 @@ from_env_var!( let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535".to_string()); let from_str = |s| s.parse().ok(); ); +from_env_var!( + /// Enables [WHITELIST_MODE](https://docs.joinmastodon.org/admin/config/#whitelist_mode) + /// + /// This mode prevents non-logged in users from subscribing to any timelines + /// (including otherwise public timelines). + let name = WhitelistMode; + let default: bool = false; + let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false".to_string()); + let from_str = |s| s.parse().ok(); +); /// Permissions for Cross Origin Resource Sharing (CORS) pub struct Cors<'a> { pub allowed_headers: Vec<&'a str>, diff --git a/src/main.rs b/src/main.rs index 1e4fcc0..b964926 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,8 +34,8 @@ fn main() { log::info!("Streaming server initialized and ready to accept connections"); // Server Sent Events - let sse_update_interval = *cfg.ws_interval; - let sse_routes = sse::extract_user_or_reject(pg_pool.clone()) + let (sse_update_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode); + let sse_routes = sse::extract_user_or_reject(pg_pool.clone(), whitelist_mode) .and(warp::sse()) .map( move |subscription: subscription::Subscription, @@ -57,8 +57,8 @@ fn main() { .recover(err::handle_errors); // WebSocket - let ws_update_interval = *cfg.ws_interval; - let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone()) + let (ws_update_interval, whitelist_mode) = (*cfg.ws_interval, *cfg.whitelist_mode); + let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone(), whitelist_mode) .and(warp::ws::ws2()) .map( move |subscription: subscription::Subscription, token: Option, ws: Ws2| { diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse.rs index 279ea7a..b62461c 100644 --- a/src/parse_client_request/sse.rs +++ b/src/parse_client_request/sse.rs @@ -39,7 +39,10 @@ macro_rules! parse_query { .boxed() }; } -pub fn extract_user_or_reject(pg_pool: PgPool) -> BoxedFilter<(Subscription,)> { +pub fn extract_user_or_reject( + pg_pool: PgPool, + whitelist_mode: bool, +) -> BoxedFilter<(Subscription,)> { any_of!( parse_query!( path => "api" / "v1" / "streaming" / "user" / "notification" @@ -67,7 +70,7 @@ pub fn extract_user_or_reject(pg_pool: PgPool) -> BoxedFilter<(Subscription,)> { // parameter, we need to update our Query if the header has a token .and(query::OptionalAccessToken::from_sse_header()) .and_then(Query::update_access_token) - .and_then(move |q| Subscription::from_query(q, pg_pool.clone())) + .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) .boxed() } diff --git a/src/parse_client_request/subscription/mod.rs b/src/parse_client_request/subscription/mod.rs index 9017c81..97661bb 100644 --- a/src/parse_client_request/subscription/mod.rs +++ b/src/parse_client_request/subscription/mod.rs @@ -30,9 +30,10 @@ impl Default for Subscription { } impl Subscription { - pub fn from_query(q: Query, pool: PgPool) -> Result { + pub fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result { let user = match q.access_token.clone() { Some(token) => postgres::select_user(&token, pool.clone())?, + None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?, None => UserData::public(), }; Ok(Subscription { diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs index 1aac83f..aa74c60 100644 --- a/src/parse_client_request/ws.rs +++ b/src/parse_client_request/ws.rs @@ -34,11 +34,12 @@ fn parse_query() -> BoxedFilter<(Query,)> { pub fn extract_user_and_token_or_reject( pg_pool: PgPool, + whitelist_mode: bool, ) -> BoxedFilter<(Subscription, Option)> { parse_query() .and(query::OptionalAccessToken::from_ws_header()) .and_then(Query::update_access_token) - .and_then(move |q| Subscription::from_query(q, pg_pool.clone())) + .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) .and(query::OptionalAccessToken::from_ws_header()) .boxed() }