Add tests for websocket routes (#38)

* Refactor organazation of SSE

This commit refactors how SSE requests are handled to bring them into
line with how WS requests are handled and increase consistency.

* Add websocket tests

* Bump version to 0.2.0

Bump version and update name from ragequit to flodgatt.

* Add test for non-existant endpoints

* Update documentation for recent changes``
This commit is contained in:
Daniel Sockwell 2019-09-09 13:06:24 -04:00 committed by GitHub
parent 90602d17ed
commit ecfdda093c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 731 additions and 435 deletions

42
Cargo.lock generated
View File

@ -300,6 +300,27 @@ name = "fixedbitset"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "flodgatt"
version = "0.2.0"
dependencies = [
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
"futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
"openssl 0.10.24 (registry+https://github.com/rust-lang/crates.io-index)",
"postgres 0.16.0-rc.2 (git+https://github.com/sfackler/rust-postgres.git)",
"postgres-openssl 0.2.0-rc.1 (git+https://github.com/sfackler/rust-postgres.git)",
"pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_derive 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)",
"tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)",
"uuid 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)",
"warp 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "fnv"
version = "1.0.6"
@ -877,27 +898,6 @@ dependencies = [
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "ragequit"
version = "0.1.0"
dependencies = [
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
"futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
"openssl 0.10.24 (registry+https://github.com/rust-lang/crates.io-index)",
"postgres 0.16.0-rc.2 (git+https://github.com/sfackler/rust-postgres.git)",
"postgres-openssl 0.2.0-rc.1 (git+https://github.com/sfackler/rust-postgres.git)",
"pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_derive 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)",
"tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)",
"uuid 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)",
"warp 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand"
version = "0.5.6"

View File

@ -1,7 +1,7 @@
[package]
name = "ragequit"
name = "flodgatt"
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
version = "0.1.0"
version = "0.2.0"
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
edition = "2018"

View File

@ -1,26 +1,24 @@
//! Streaming server for Mastodon
//!
//!
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server
//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to
//! connect to the server with the API described [in Mastodon's public API
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a
//! server is running this sever, Mastodon clients can use either Server Sent Events or
//! WebSockets to connect to the server with the API described [in Mastodon's public API
//! documentation](https://docs.joinmastodon.org/api/streaming/).
//!
//! # Data Flow
//! * **Parsing the client request**
//! When the client request first comes in, it is parsed based on the endpoint it targets (for
//! server sent events), its query parameters, and its headers (for WebSocket). Based on this
//! data, we authenticate the user, retrieve relevant user data from Postgres, and determine the
//! timeline targeted by the request. Successfully parsing the client request results in generating
//! a `User` and target `timeline` for the request. If any requests are invalid/not authorized, we
//! reject them in this stage.
//! * **Streaming update from Redis to the client**:
//! After the user request is parsed, we pass the `User` and `timeline` data on to the
//! `ClientAgent`. The `ClientAgent` is responsible for communicating the user's request to the
//! `Receiver`, polling the `Receiver` for any updates, and then for wording those updates on to the
//! client. The `Receiver`, in tern, is responsible for managing the Redis subscriptions,
//! periodically polling Redis, and sorting the replies from Redis into queues for when it is polled
//! by the `ClientAgent`.
//! * **Parsing the client request** When the client request first comes in, it is
//! parsed based on the endpoint it targets (for server sent events), its query parameters,
//! and its headers (for WebSocket). Based on this data, we authenticate the user, retrieve
//! relevant user data from Postgres, and determine the timeline targeted by the request.
//! Successfully parsing the client request results in generating a `User` corresponding to
//! the request. If any requests are invalid/not authorized, we reject them in this stage.
//! * **Streaming update from Redis to the client**: After the user request is parsed, we pass
//! the `User` data on to the `ClientAgent`. The `ClientAgent` is responsible for
//! communicating the user's request to the `Receiver`, polling the `Receiver` for any
//! updates, and then for wording those updates on to the client. The `Receiver`, in tern, is
//! responsible for managing the Redis subscriptions, periodically polling Redis, and sorting
//! the replies from Redis into queues for when it is polled by the `ClientAgent`.
//!
//! # Concurrency
//! The `Receiver` is created when the server is first initialized, and there is only one
@ -31,11 +29,10 @@
//! that the `Receiver`'s poll of Redis be fast, since there will only ever be one
//! `Receiver`.
//!
//! # Configuration
//! By default, the server uses config values from the `config.rs` module; these values can be
//! overwritten with environmental variables or in the `.env` file. The most important settings
//! for performance control the frequency with which the `ClientAgent` polls the `Receiver` and
//! the frequency with which the `Receiver` polls Redis.
//! # Configuration By default, the server uses config values from the `config.rs` module;
//! these values can be overwritten with environmental variables or in the `.env` file. The
//! most important settings for performance control the frequency with which the `ClientAgent`
//! polls the `Receiver` and the frequency with which the `Receiver` polls Redis.
//!
pub mod config;
pub mod parse_client_request;

View File

@ -1,10 +1,10 @@
use log::{log_enabled, Level};
use ragequit::{
use flodgatt::{
config,
parse_client_request::{sse, user, ws},
redis_to_client_stream,
redis_to_client_stream::ClientAgent,
};
use log::{log_enabled, Level};
use warp::{ws::Ws2, Filter as WarpFilter};
fn main() {
@ -17,18 +17,14 @@ fn main() {
};
// Server Sent Events
//
// For SSE, the API requires users to use different endpoints, so we first filter based on
// the endpoint. Using that endpoint determine the `timeline` the user is requesting,
// the scope for that `timeline`, and authenticate the `User` if they provided a token.
let sse_routes = sse::filter_incomming_request()
let sse_routes = sse::extract_user_or_reject()
.and(warp::sse())
.map(
move |timeline: String, user: user::User, sse_connection_to_client: warp::sse::Sse| {
move |user: user::User, sse_connection_to_client: warp::sse::Sse| {
// Create a new ClientAgent
let mut client_agent = client_agent_sse.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(&timeline, user);
// Assign ClientAgent to generate stream of updates for the user/timeline pair
client_agent.init_for_user(user);
// send the updates through the SSE connection
redis_to_client_stream::send_updates_to_sse(client_agent, sse_connection_to_client)
},
@ -37,52 +33,17 @@ fn main() {
.recover(config::handle_errors);
// WebSocket
//
// For WS, the API specifies a single endpoint, so we extract the User/timeline pair
// directy from the query
let websocket_routes = ws::extract_user_and_query()
.and_then(move |mut user: user::User, q: ws::Query, ws: Ws2| {
let websocket_routes = ws::extract_user_or_reject()
.and(warp::ws::ws2())
.and_then(move |user: user::User, ws: Ws2| {
let token = user.access_token.clone();
let read_scope = user.scopes.clone();
let timeline = match q.stream.as_ref() {
// Public endpoints:
tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
tl @ "public" | tl @ "public:local" => tl.to_string(),
// Hashtag endpoints:
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
// Private endpoints: User
"user" if user.logged_in && (read_scope.all || read_scope.statuses) => {
format!("{}", user.id)
}
"user:notification" if user.logged_in && (read_scope.all || read_scope.notify) => {
user = user.set_filter(user::Filter::Notification);
format!("{}", user.id)
}
// List endpoint:
"list" if user.owns_list(q.list) && (read_scope.all || read_scope.lists) => {
format!("list:{}", q.list)
}
// Direct endpoint:
"direct" if user.logged_in && (read_scope.all || read_scope.statuses) => {
"direct".to_string()
}
// Reject unathorized access attempts for private endpoints
"user" | "user:notification" | "direct" | "list" => {
return Err(warp::reject::custom("Error: Invalid Access Token"))
}
// Other endpoints don't exist:
_ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")),
};
// Create a new ClientAgent
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(&timeline, user);
client_agent.init_for_user(user);
// send the updates through the WS connection (along with the User's access_token
// which is sent for security)
Ok((
Ok::<_, warp::Rejection>((
ws.on_upgrade(move |socket| {
redis_to_client_stream::send_updates_to_ws(socket, client_agent)
}),

View File

@ -1,4 +1,4 @@
//! Parse the client request and return a 'timeline' and a (maybe authenticated) `User`
//! Parse the client request and return a (possibly authenticated) `User`
pub mod query;
pub mod sse;
pub mod user;

View File

@ -3,8 +3,32 @@ use serde_derive::Deserialize;
use warp::filters::BoxedFilter;
use warp::Filter as WarpFilter;
macro_rules! query {
($name:tt => $parameter:tt:$type:tt) => {
#[derive(Debug)]
pub struct Query {
pub access_token: Option<String>,
pub stream: String,
pub media: bool,
pub hashtag: String,
pub list: i64,
}
impl Query {
pub fn update_access_token(
self,
token: Option<String>,
) -> Result<Self, warp::reject::Rejection> {
match token {
Some(token) => Ok(Self {
access_token: Some(token),
..self
}),
None => Ok(self),
}
}
}
macro_rules! make_query_type {
($name:tt => $parameter:tt:$type:ty) => {
#[derive(Deserialize, Debug, Default)]
pub struct $name {
pub $parameter: $type,
@ -19,16 +43,16 @@ macro_rules! query {
}
};
}
query!(Media => only_media:String);
make_query_type!(Media => only_media:String);
impl Media {
pub fn is_truthy(&self) -> bool {
self.only_media == "true" || self.only_media == "1"
}
}
query!(Hashtag => tag: String);
query!(List => list: i64);
query!(Auth => access_token: String);
query!(Stream => stream: String);
make_query_type!(Hashtag => tag: String);
make_query_type!(List => list: i64);
make_query_type!(Auth => access_token: Option<String>);
make_query_type!(Stream => stream: String);
impl ToString for Stream {
fn to_string(&self) -> String {
format!("{:?}", self)
@ -43,3 +67,19 @@ pub fn optional_media_query() -> BoxedFilter<(Media,)> {
.unify()
.boxed()
}
pub struct OptionalAccessToken;
impl OptionalAccessToken {
pub fn from_header() -> warp::filters::BoxedFilter<(Option<String>,)> {
let from_header = warp::header::header::<String>("authorization").map(|auth: String| {
match auth.split(' ').nth(1) {
Some(s) => Some(s.to_string()),
None => None,
}
});
let no_token = warp::any().map(|| None);
from_header.or(no_token).unify().boxed()
}
}

View File

@ -1,9 +1,5 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
use super::{
query,
user::{Filter::*, OptionalAccessToken, User},
};
use crate::config::CustomError;
use super::{query, query::Query, user::User};
use warp::{filters::BoxedFilter, path, Filter};
#[allow(dead_code)]
@ -12,140 +8,111 @@ type TimelineUser = ((String, User),);
/// Helper macro to match on the first of any of the provided filters
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*
$filter$(.or($other_filter).unify())*.boxed()
};
}
pub fn filter_incomming_request() -> BoxedFilter<(String, User)> {
macro_rules! parse_query {
(path => $start:tt $(/ $next:tt)*
endpoint => $endpoint:expr) => {
path!($start $(/ $next)*)
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: $endpoint.to_string(),
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
};
}
pub fn extract_user_or_reject() -> BoxedFilter<(User,)> {
any_of!(
path!("api" / "v1" / "streaming" / "user" / "notification")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_reject)
.map(|user: User| (user.id.to_string(), user.set_filter(Notification))),
// **NOTE**: This endpoint was present in the node.js server, but not in the
// [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local).
// Should it be publicly documented?
path!("api" / "v1" / "streaming" / "user")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_reject)
.map(|user: User| (user.id.to_string(), user)),
path!("api" / "v1" / "streaming" / "public" / "local")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_public_user)
.and(warp::query())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)),
_ => ("public:local".to_owned(), user.set_filter(Language)),
}),
path!("api" / "v1" / "streaming" / "public")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_public_user)
.and(warp::query())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:media".to_owned(), user.set_filter(Language)),
_ => ("public".to_owned(), user.set_filter(Language)),
}),
path!("api" / "v1" / "streaming" / "public" / "local")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_public_user)
.map(|user: User| ("public:local".to_owned(), user.set_filter(Language))),
path!("api" / "v1" / "streaming" / "public")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_public_user)
.map(|user: User| ("public".to_owned(), user.set_filter(Language))),
path!("api" / "v1" / "streaming" / "direct")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_reject)
.map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter))),
// **Note**: Hashtags are *not* filtered on language, right?
path!("api" / "v1" / "streaming" / "hashtag" / "local")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_public_user)
.and(warp::query())
.map(|_, q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())),
path!("api" / "v1" / "streaming" / "hashtag")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_public_user)
.and(warp::query())
.map(|_, q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())),
path!("api" / "v1" / "streaming" / "list")
.and(OptionalAccessToken::from_header_or_query())
.and_then(User::from_access_token_or_reject)
.and(warp::query())
.and_then(|user: User, q: query::List| {
if user.owns_list(q.list) {
(Ok(q.list), Ok(user))
} else {
(Err(CustomError::unauthorized_list()), Ok(user))
}
})
.untuple_one()
.map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter)))
parse_query!(
path => "api" / "v1" / "streaming" / "user" / "notification"
endpoint => "user:notification" ),
parse_query!(
path => "api" / "v1" / "streaming" / "user"
endpoint => "user"),
parse_query!(
path => "api" / "v1" / "streaming" / "public" / "local"
endpoint => "public:local"),
parse_query!(
path => "api" / "v1" / "streaming" / "public"
endpoint => "public"),
parse_query!(
path => "api" / "v1" / "streaming" / "direct"
endpoint => "direct"),
parse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
endpoint => "hashtag:local"),
parse_query!(path => "api" / "v1" / "streaming" / "hashtag"
endpoint => "hashtag"),
parse_query!(path => "api" / "v1" / "streaming" / "list"
endpoint => "list")
)
.untuple_one()
// because SSE requests place their `access_token` in the header instead of in a query
// parameter, we need to update our Query if the header has a token
.and(query::OptionalAccessToken::from_header())
.and_then(Query::update_access_token)
.and_then(User::from_query)
.boxed()
}
#[cfg(test)]
mod test {
use super::*;
struct TestUser;
impl TestUser {
fn logged_in() -> User {
User::from_access_token_or_reject(Some("TEST_USER".to_string())).expect("in test")
}
fn public() -> User {
User::from_access_token_or_public_user(None).expect("in test")
}
}
use crate::parse_client_request::user::{Filter, OauthScope};
macro_rules! test_public_endpoint {
($name:ident {
endpoint: $path:expr,
timeline: $timeline:expr,
user: $user:expr,
}) => {
#[test]
fn $name() {
let (timeline, user) = warp::test::request()
let user = warp::test::request()
.path($path)
.filter(&filter_incomming_request())
.filter(&extract_user_or_reject())
.expect("in test");
assert_eq!(&timeline, $timeline);
assert_eq!(user, $user);
}
};
}
macro_rules! test_private_endpoint {
($name:ident {
endpoint: $path:expr,
$(query: $query:expr,)*
timeline: $timeline:expr,
user: $user:expr,
}) => {
#[test]
fn $name() {
let path = format!("{}?access_token=TEST_USER", $path);
$(let path = format!("{}&{}", path, $query);)*
let (timeline, user) = warp::test::request()
let user = warp::test::request()
.path(&path)
.filter(&filter_incomming_request())
.filter(&extract_user_or_reject())
.expect("in test");
assert_eq!(&timeline, $timeline);
assert_eq!(user, $user);
let (timeline, user) = warp::test::request()
let user = warp::test::request()
.path(&path)
.header("Authorization", "Bearer: TEST_USER")
.filter(&filter_incomming_request())
.filter(&extract_user_or_reject())
.expect("in test");
assert_eq!(&timeline, $timeline);
assert_eq!(user, $user);
}
};
}
macro_rules! test_bad_auth_token_in_query {
($name: ident {
endpoint: $path:expr,
@ -153,19 +120,17 @@ mod test {
}) => {
#[test]
#[should_panic(expected = "Error: Invalid access token")]
fn $name() {
let path = format!("{}?access_token=INVALID", $path);
$(let path = format!("{}&{}", path, $query);)*
dbg!(&path);
warp::test::request()
.path(&path)
.filter(&filter_incomming_request())
.filter(&extract_user_or_reject())
.expect("in test");
}
};
}
macro_rules! test_bad_auth_token_in_header {
($name: ident {
endpoint: $path:expr,
@ -180,7 +145,7 @@ mod test {
warp::test::request()
.path(&path)
.header("Authorization", "Bearer: INVALID")
.filter(&filter_incomming_request())
.filter(&extract_user_or_reject())
.expect("in test");
}
};
@ -197,7 +162,7 @@ mod test {
$(let path = format!("{}?{}", path, $query);)*
warp::test::request()
.path(&path)
.filter(&filter_incomming_request())
.filter(&extract_user_or_reject())
.expect("in test");
}
};
@ -205,13 +170,193 @@ mod test {
test_public_endpoint!(public_media_true {
endpoint: "/api/v1/streaming/public?only_media=true",
timeline: "public:media",
user: TestUser::public().set_filter(Language),
user: User {
target_timeline: "public:media".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(public_media_1 {
endpoint: "/api/v1/streaming/public?only_media=1",
timeline: "public:media",
user: TestUser::public().set_filter(Language),
user: User {
target_timeline: "public:media".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(public_local {
endpoint: "/api/v1/streaming/public/local",
user: User {
target_timeline: "public:local".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(public_local_media_true {
endpoint: "/api/v1/streaming/public/local?only_media=true",
user: User {
target_timeline: "public:local:media".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(public_local_media_1 {
endpoint: "/api/v1/streaming/public/local?only_media=1",
user: User {
target_timeline: "public:local:media".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(hashtag {
endpoint: "/api/v1/streaming/hashtag?tag=a",
user: User {
target_timeline: "hashtag:a".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(hashtag_local {
endpoint: "/api/v1/streaming/hashtag/local?tag=a",
user: User {
target_timeline: "hashtag:local:a".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_private_endpoint!(user {
endpoint: "/api/v1/streaming/user",
user: User {
target_timeline: "1".to_string(),
id: 1,
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
filter: Filter::NoFilter,
},
});
test_private_endpoint!(user_notification {
endpoint: "/api/v1/streaming/user/notification",
user: User {
target_timeline: "1".to_string(),
id: 1,
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
filter: Filter::Notification,
},
});
test_private_endpoint!(direct {
endpoint: "/api/v1/streaming/direct",
user: User {
target_timeline: "direct".to_string(),
id: 1,
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
filter: Filter::NoFilter,
},
});
test_private_endpoint!(list_valid_list {
endpoint: "/api/v1/streaming/list",
query: "list=1",
user: User {
target_timeline: "list:1".to_string(),
id: 1,
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
filter: Filter::NoFilter,
},
});
test_bad_auth_token_in_query!(public_media_true_bad_auth {
endpoint: "/api/v1/streaming/public",
@ -221,29 +366,12 @@ mod test {
endpoint: "/api/v1/streaming/public",
query: "only_media=1",
});
test_public_endpoint!(public_local {
endpoint: "/api/v1/streaming/public/local",
timeline: "public:local",
user: TestUser::public().set_filter(Language),
});
test_bad_auth_token_in_query!(public_local_bad_auth_in_query {
endpoint: "/api/v1/streaming/public/local",
});
test_bad_auth_token_in_header!(public_local_bad_auth_in_header {
endpoint: "/api/v1/streaming/public/local",
});
test_public_endpoint!(public_local_media_true {
endpoint: "/api/v1/streaming/public/local?only_media=true",
timeline: "public:local:media",
user: TestUser::public().set_filter(Language),
});
test_public_endpoint!(public_local_media_1 {
endpoint: "/api/v1/streaming/public/local?only_media=1",
timeline: "public:local:media",
user: TestUser::public().set_filter(Language),
});
test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query {
endpoint: "/api/v1/streaming/public/local",
query: "only_media=1",
@ -252,12 +380,6 @@ mod test {
endpoint: "/api/v1/streaming/public/local",
query: "only_media=true",
});
test_public_endpoint!(hashtag {
endpoint: "/api/v1/streaming/hashtag?tag=a",
timeline: "hashtag:a",
user: TestUser::public(),
});
test_bad_auth_token_in_query!(hashtag_bad_auth_in_query {
endpoint: "/api/v1/streaming/hashtag",
query: "tag=a",
@ -266,26 +388,6 @@ mod test {
endpoint: "/api/v1/streaming/hashtag",
query: "tag=a",
});
test_public_endpoint!(hashtag_local {
endpoint: "/api/v1/streaming/hashtag/local?tag=a",
timeline: "hashtag:a:local",
user: TestUser::public(),
});
test_bad_auth_token_in_query!(hashtag_local_bad_auth_in_query {
endpoint: "/api/v1/streaming/hashtag/local",
query: "tag=a",
});
test_bad_auth_token_in_header!(hashtag_local_bad_auth_in_header {
endpoint: "/api/v1/streaming/hashtag/local",
query: "tag=a",
});
test_private_endpoint!(user {
endpoint: "/api/v1/streaming/user",
timeline: "1",
user: TestUser::logged_in(),
});
test_bad_auth_token_in_query!(user_bad_auth_in_query {
endpoint: "/api/v1/streaming/user",
});
@ -295,12 +397,6 @@ mod test {
test_missing_auth!(user_missing_auth_token {
endpoint: "/api/v1/streaming/user",
});
test_private_endpoint!(user_notification {
endpoint: "/api/v1/streaming/user/notification",
timeline: "1",
user: TestUser::logged_in().set_filter(Notification),
});
test_bad_auth_token_in_query!(user_notification_bad_auth_in_query {
endpoint: "/api/v1/streaming/user/notification",
});
@ -310,12 +406,6 @@ mod test {
test_missing_auth!(user_notification_missing_auth_token {
endpoint: "/api/v1/streaming/user/notification",
});
test_private_endpoint!(direct {
endpoint: "/api/v1/streaming/direct",
timeline: "direct:1",
user: TestUser::logged_in(),
});
test_bad_auth_token_in_query!(direct_bad_auth_in_query {
endpoint: "/api/v1/streaming/direct",
});
@ -325,13 +415,6 @@ mod test {
test_missing_auth!(direct_missing_auth_token {
endpoint: "/api/v1/streaming/direct",
});
test_private_endpoint!(list_valid_list {
endpoint: "/api/v1/streaming/list",
query: "list=1",
timeline: "list:1",
user: TestUser::logged_in(),
});
test_bad_auth_token_in_query!(list_bad_auth_in_query {
endpoint: "/api/v1/streaming/list",
query: "list=1",
@ -345,4 +428,13 @@ mod test {
query: "list=1",
});
#[test]
#[should_panic(expected = "NotFound")]
fn nonexistant_endpoint() {
warp::test::request()
.path("/api/v1/streaming/DOES_NOT_EXIST")
.filter(&extract_user_or_reject())
.expect("in test");
}
}

View File

@ -5,16 +5,8 @@ mod mock_postgres;
use mock_postgres as postgres;
#[cfg(not(test))]
mod postgres;
use crate::parse_client_request::query;
use log::info;
use super::query::Query;
use warp::reject::Rejection;
use warp::Filter as WarpFilter;
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*
};
}
/// The filters that can be applied to toots after they come from Redis
#[derive(Clone, Debug, PartialEq)]
@ -23,10 +15,16 @@ pub enum Filter {
Language,
Notification,
}
impl Default for Filter {
fn default() -> Self {
Filter::NoFilter
}
}
/// The User (with data read from Postgres)
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, Default, PartialEq)]
pub struct User {
pub target_timeline: String,
pub id: i64,
pub access_token: String,
pub scopes: OauthScope,
@ -34,11 +32,7 @@ pub struct User {
pub logged_in: bool,
pub filter: Filter,
}
impl Default for User {
fn default() -> Self {
User::public()
}
}
#[derive(Clone, Debug, Default, PartialEq)]
pub struct OauthScope {
pub all: bool,
@ -62,72 +56,82 @@ impl From<Vec<String>> for OauthScope {
}
}
/// Create a user based on the supplied path and access scope for the resource
#[macro_export]
macro_rules! user_from_path {
($($path_item:tt) / *, $scope:expr) => (path!("api" / "v1" / $($path_item) / +)
.and($scope.get_access_token())
.and_then(|token| User::from_access_token(token, $scope)))
}
impl User {
pub fn from_access_token_or_reject(token: Option<String>) -> Result<Self, Rejection> {
match token {
None => Err(warp::reject::custom("Error: Missing access token")),
pub fn from_query(q: Query) -> Result<Self, Rejection> {
let (id, access_token, scopes, langs, logged_in) = match q.access_token.clone() {
None => (
-1,
"no access token".to_owned(),
OauthScope::default(),
None,
false,
),
Some(token) => {
let (id, langs, scope_list) = postgres::query_for_user_data(&token);
if id == -1 {
return Err(warp::reject::custom("Error: Invalid access token"));
}
let scopes = OauthScope::from(scope_list);
Ok(User {
id,
access_token: token,
scopes,
langs,
logged_in: true,
filter: Filter::NoFilter,
})
(id, token, scopes, langs, true)
}
}
};
let mut user = User {
id,
target_timeline: "PLACEHOLDER".to_string(),
access_token,
scopes,
langs,
logged_in,
filter: Filter::Language,
};
user = user.update_timeline_and_filter(q)?;
Ok(user)
}
pub fn from_access_token_or_public_user(token: Option<String>) -> Result<Self, Rejection> {
match token {
None => Ok(User::public()),
Some(_) => User::from_access_token_or_reject(token),
}
}
/// Create a user from the access token supplied in the header or query paramaters
pub fn from_access_token(
access_token: String,
scope: Scope,
) -> Result<Self, warp::reject::Rejection> {
let (id, langs, scope_list) = postgres::query_for_user_data(&access_token);
let scopes = OauthScope::from(scope_list);
if id != -1 || scope == Scope::Public {
let (logged_in, log_msg) = match id {
-1 => (false, "Public access to non-authenticated endpoints"),
_ => (true, "Granting logged-in access"),
};
info!("{}", log_msg);
Ok(User {
id,
access_token,
scopes,
langs,
logged_in,
filter: Filter::NoFilter,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
/// Set the Notification/Language filter
pub fn set_filter(self, filter: Filter) -> Self {
Self { filter, ..self }
fn update_timeline_and_filter(mut self, q: Query) -> Result<Self, Rejection> {
let read_scope = self.scopes.clone();
let timeline = match q.stream.as_ref() {
// Public endpoints:
tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
tl @ "public" | tl @ "public:local" => tl.to_string(),
// Hashtag endpoints:
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
// Private endpoints: User
"user" if self.logged_in && (read_scope.all || read_scope.statuses) => {
self.filter = Filter::NoFilter;
format!("{}", self.id)
}
"user:notification" if self.logged_in && (read_scope.all || read_scope.notify) => {
self.filter = Filter::Notification;
format!("{}", self.id)
}
// List endpoint:
"list" if self.owns_list(q.list) && (read_scope.all || read_scope.lists) => {
self.filter = Filter::NoFilter;
format!("list:{}", q.list)
}
// Direct endpoint:
"direct" if self.logged_in && (read_scope.all || read_scope.statuses) => {
self.filter = Filter::NoFilter;
"direct".to_string()
}
// Reject unathorized access attempts for private endpoints
"user" | "user:notification" | "direct" | "list" => {
return Err(warp::reject::custom("Error: Missing access token"))
}
// Other endpoints don't exist:
_ => return Err(warp::reject::custom("Error: Nonexistent endpoint")),
};
Ok(Self {
target_timeline: timeline,
..self
})
}
/// Determine whether the User is authorised for a specified list
pub fn owns_list(&self, list: i64) -> bool {
match postgres::query_list_owner(list) {
@ -135,75 +139,4 @@ impl User {
_ => false,
}
}
pub fn public2() -> warp::filters::BoxedFilter<(User,)> {
warp::any()
.map(|| User {
id: -1,
access_token: String::from("no access token"),
scopes: OauthScope::default(),
langs: None,
logged_in: false,
filter: Filter::NoFilter,
})
.boxed()
}
/// A public (non-authenticated) User
pub fn public() -> Self {
User {
id: -1,
access_token: String::from("no access token"),
scopes: OauthScope::default(),
langs: None,
logged_in: false,
filter: Filter::NoFilter,
}
}
}
pub struct OptionalAccessToken;
impl OptionalAccessToken {
pub fn from_header_or_query() -> warp::filters::BoxedFilter<(Option<String>,)> {
let from_header = warp::header::header::<String>("authorization").map(|auth: String| {
match auth.split(' ').nth(1) {
Some(s) => Some(s.to_string()),
None => None,
}
});
let from_query = warp::query().map(|q: query::Auth| Some(q.access_token));
let no_token = warp::any().map(|| None);
any_of!(from_header, from_query, no_token).boxed()
}
}
/// Whether the endpoint requires authentication or not
#[derive(PartialEq)]
pub enum Scope {
Public,
Private,
}
impl Scope {
pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> {
let token_from_header_http_push = warp::header::header::<String>("authorization")
.map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string());
let token_from_header_ws =
warp::header::header::<String>("Sec-WebSocket-Protocol").map(|auth: String| auth);
let token_from_query = warp::query().map(|q: query::Auth| q.access_token);
let private_scopes = any_of!(
token_from_header_http_push,
token_from_header_ws,
token_from_query
);
let public = warp::any().map(|| "no access token".to_string());
match self {
// if they're trying to access a private scope without an access token, reject the request
Scope::Private => private_scopes.boxed(),
// if they're trying to access a public scope without an access token, proceed
Scope::Public => any_of!(private_scopes, public).boxed(),
}
}
}

View File

@ -1,43 +1,316 @@
//! Filters for the WebSocket endpoint
use super::{
query,
user::{Scope, User},
};
use crate::user_from_path;
use super::{query, query::Query, user::User};
use warp::{filters::BoxedFilter, path, Filter};
/// WebSocket filters
pub fn extract_user_and_query() -> BoxedFilter<(User, Query, warp::ws::Ws2)> {
user_from_path!("streaming", Scope::Public)
fn parse_query() -> BoxedFilter<(Query,)> {
path!("api" / "v1" / "streaming")
.and(path::end())
.and(warp::query())
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.and(warp::ws2())
.map(
|user: User,
stream: query::Stream,
|stream: query::Stream,
auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List,
ws: warp::ws::Ws2| {
let query = Query {
list: query::List| {
Query {
access_token: auth.access_token,
stream: stream.stream,
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
};
(user, query, ws)
}
},
)
.untuple_one()
.boxed()
}
#[derive(Debug)]
pub struct Query {
pub stream: String,
pub media: bool,
pub hashtag: String,
pub list: i64,
pub fn extract_user_or_reject() -> BoxedFilter<(User,)> {
parse_query().and_then(User::from_query).boxed()
}
#[cfg(test)]
mod test {
use super::*;
use crate::parse_client_request::user::{Filter, OauthScope};
macro_rules! test_public_endpoint {
($name:ident {
endpoint: $path:expr,
user: $user:expr,
}) => {
#[test]
fn $name() {
let user = warp::test::request()
.path($path)
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.filter(&extract_user_or_reject())
.expect("in test");
assert_eq!(user, $user);
}
};
}
macro_rules! test_private_endpoint {
($name:ident {
endpoint: $path:expr,
user: $user:expr,
}) => {
#[test]
fn $name() {
let path = format!("{}&access_token=TEST_USER", $path);
let user = warp::test::request()
.path(&path)
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.filter(&extract_user_or_reject())
.expect("in test");
assert_eq!(user, $user);
}
};
}
macro_rules! test_bad_auth_token_in_query {
($name: ident {
endpoint: $path:expr,
}) => {
#[test]
#[should_panic(expected = "Error: Invalid access token")]
fn $name() {
let path = format!("{}&access_token=INVALID", $path);
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject())
.expect("in test");
}
};
}
macro_rules! test_missing_auth {
($name: ident {
endpoint: $path:expr,
}) => {
#[test]
#[should_panic(expected = "Error: Missing access token")]
fn $name() {
let path = $path;
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject())
.expect("in test");
}
};
}
test_public_endpoint!(public_media {
endpoint: "/api/v1/streaming?stream=public:media",
user: User {
target_timeline: "public:media".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(public_local {
endpoint: "/api/v1/streaming?stream=public:local",
user: User {
target_timeline: "public:local".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(public_local_media {
endpoint: "/api/v1/streaming?stream=public:local:media",
user: User {
target_timeline: "public:local:media".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(hashtag {
endpoint: "/api/v1/streaming?stream=hashtag&tag=a",
user: User {
target_timeline: "hashtag:a".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_public_endpoint!(hashtag_local {
endpoint: "/api/v1/streaming?stream=hashtag:local&tag=a",
user: User {
target_timeline: "hashtag:local:a".to_string(),
id: -1,
access_token: "no access token".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
filter: Filter::Language,
},
});
test_private_endpoint!(user {
endpoint: "/api/v1/streaming?stream=user",
user: User {
target_timeline: "1".to_string(),
id: 1,
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
filter: Filter::NoFilter,
},
});
test_private_endpoint!(user_notification {
endpoint: "/api/v1/streaming?stream=user:notification",
user: User {
target_timeline: "1".to_string(),
id: 1,
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
filter: Filter::Notification,
},
});
test_private_endpoint!(direct {
endpoint: "/api/v1/streaming?stream=direct",
user: User {
target_timeline: "direct".to_string(),
id: 1,
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
filter: Filter::NoFilter,
},
});
test_private_endpoint!(list_valid_list {
endpoint: "/api/v1/streaming?stream=list&list=1",
user: User {
target_timeline: "list:1".to_string(),
id: 1,
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
filter: Filter::NoFilter,
},
});
test_bad_auth_token_in_query!(public_media_true_bad_auth {
endpoint: "/api/v1/streaming?stream=public:media",
});
test_bad_auth_token_in_query!(public_local_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=public:local",
});
test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=public:local:media",
});
test_bad_auth_token_in_query!(hashtag_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=hashtag&tag=a",
});
test_bad_auth_token_in_query!(user_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=user",
});
test_missing_auth!(user_missing_auth_token {
endpoint: "/api/v1/streaming?stream=user",
});
test_bad_auth_token_in_query!(user_notification_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=user:notification",
});
test_missing_auth!(user_notification_missing_auth_token {
endpoint: "/api/v1/streaming?stream=user:notification",
});
test_bad_auth_token_in_query!(direct_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=direct",
});
test_missing_auth!(direct_missing_auth_token {
endpoint: "/api/v1/streaming?stream=direct",
});
test_bad_auth_token_in_query!(list_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=list&list=1",
});
test_missing_auth!(list_missing_auth_token {
endpoint: "/api/v1/streaming?stream=list&list=1",
});
#[test]
#[should_panic(expected = "NotFound")]
fn nonexistant_endpoint() {
warp::test::request()
.path("/api/v1/streaming/DOES_NOT_EXIST")
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.filter(&extract_user_or_reject())
.expect("in test");
}
}

View File

@ -40,7 +40,7 @@ impl ClientAgent {
receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())),
id: Uuid::default(),
target_timeline: String::new(),
current_user: User::public(),
current_user: User::default(),
}
}
@ -61,12 +61,12 @@ impl ClientAgent {
/// a different user, the `Receiver` is responsible for figuring
/// that out and avoiding duplicated connections. Thus, it is safe to
/// use this method for each new client connection.
pub fn init_for_user(&mut self, target_timeline: &str, user: User) {
pub fn init_for_user(&mut self, user: User) {
self.id = Uuid::new_v4();
self.target_timeline = target_timeline.to_owned();
self.target_timeline = user.target_timeline.to_owned();
self.current_user = user;
let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)");
receiver.manage_new_timeline(self.id, target_timeline);
receiver.manage_new_timeline(self.id, &self.target_timeline);
}
}