mirror of https://github.com/mastodon/flodgatt
Add tests for websocket routes (#38)
* Refactor organazation of SSE This commit refactors how SSE requests are handled to bring them into line with how WS requests are handled and increase consistency. * Add websocket tests * Bump version to 0.2.0 Bump version and update name from ragequit to flodgatt. * Add test for non-existant endpoints * Update documentation for recent changes``
This commit is contained in:
parent
90602d17ed
commit
ecfdda093c
|
@ -300,6 +300,27 @@ name = "fixedbitset"
|
|||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "flodgatt"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"openssl 0.10.24 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"postgres 0.16.0-rc.2 (git+https://github.com/sfackler/rust-postgres.git)",
|
||||
"postgres-openssl 0.2.0-rc.1 (git+https://github.com/sfackler/rust-postgres.git)",
|
||||
"pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde_derive 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"uuid 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"warp 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fnv"
|
||||
version = "1.0.6"
|
||||
|
@ -877,27 +898,6 @@ dependencies = [
|
|||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ragequit"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"openssl 0.10.24 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"postgres 0.16.0-rc.2 (git+https://github.com/sfackler/rust-postgres.git)",
|
||||
"postgres-openssl 0.2.0-rc.1 (git+https://github.com/sfackler/rust-postgres.git)",
|
||||
"pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde_derive 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"uuid 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"warp 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.5.6"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "ragequit"
|
||||
name = "flodgatt"
|
||||
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
|
||||
edition = "2018"
|
||||
|
||||
|
|
41
src/lib.rs
41
src/lib.rs
|
@ -1,26 +1,24 @@
|
|||
//! Streaming server for Mastodon
|
||||
//!
|
||||
//!
|
||||
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server
|
||||
//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to
|
||||
//! connect to the server with the API described [in Mastodon's public API
|
||||
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a
|
||||
//! server is running this sever, Mastodon clients can use either Server Sent Events or
|
||||
//! WebSockets to connect to the server with the API described [in Mastodon's public API
|
||||
//! documentation](https://docs.joinmastodon.org/api/streaming/).
|
||||
//!
|
||||
//! # Data Flow
|
||||
//! * **Parsing the client request**
|
||||
//! When the client request first comes in, it is parsed based on the endpoint it targets (for
|
||||
//! server sent events), its query parameters, and its headers (for WebSocket). Based on this
|
||||
//! data, we authenticate the user, retrieve relevant user data from Postgres, and determine the
|
||||
//! timeline targeted by the request. Successfully parsing the client request results in generating
|
||||
//! a `User` and target `timeline` for the request. If any requests are invalid/not authorized, we
|
||||
//! reject them in this stage.
|
||||
//! * **Streaming update from Redis to the client**:
|
||||
//! After the user request is parsed, we pass the `User` and `timeline` data on to the
|
||||
//! `ClientAgent`. The `ClientAgent` is responsible for communicating the user's request to the
|
||||
//! `Receiver`, polling the `Receiver` for any updates, and then for wording those updates on to the
|
||||
//! client. The `Receiver`, in tern, is responsible for managing the Redis subscriptions,
|
||||
//! periodically polling Redis, and sorting the replies from Redis into queues for when it is polled
|
||||
//! by the `ClientAgent`.
|
||||
//! * **Parsing the client request** When the client request first comes in, it is
|
||||
//! parsed based on the endpoint it targets (for server sent events), its query parameters,
|
||||
//! and its headers (for WebSocket). Based on this data, we authenticate the user, retrieve
|
||||
//! relevant user data from Postgres, and determine the timeline targeted by the request.
|
||||
//! Successfully parsing the client request results in generating a `User` corresponding to
|
||||
//! the request. If any requests are invalid/not authorized, we reject them in this stage.
|
||||
//! * **Streaming update from Redis to the client**: After the user request is parsed, we pass
|
||||
//! the `User` data on to the `ClientAgent`. The `ClientAgent` is responsible for
|
||||
//! communicating the user's request to the `Receiver`, polling the `Receiver` for any
|
||||
//! updates, and then for wording those updates on to the client. The `Receiver`, in tern, is
|
||||
//! responsible for managing the Redis subscriptions, periodically polling Redis, and sorting
|
||||
//! the replies from Redis into queues for when it is polled by the `ClientAgent`.
|
||||
//!
|
||||
//! # Concurrency
|
||||
//! The `Receiver` is created when the server is first initialized, and there is only one
|
||||
|
@ -31,11 +29,10 @@
|
|||
//! that the `Receiver`'s poll of Redis be fast, since there will only ever be one
|
||||
//! `Receiver`.
|
||||
//!
|
||||
//! # Configuration
|
||||
//! By default, the server uses config values from the `config.rs` module; these values can be
|
||||
//! overwritten with environmental variables or in the `.env` file. The most important settings
|
||||
//! for performance control the frequency with which the `ClientAgent` polls the `Receiver` and
|
||||
//! the frequency with which the `Receiver` polls Redis.
|
||||
//! # Configuration By default, the server uses config values from the `config.rs` module;
|
||||
//! these values can be overwritten with environmental variables or in the `.env` file. The
|
||||
//! most important settings for performance control the frequency with which the `ClientAgent`
|
||||
//! polls the `Receiver` and the frequency with which the `Receiver` polls Redis.
|
||||
//!
|
||||
pub mod config;
|
||||
pub mod parse_client_request;
|
||||
|
|
61
src/main.rs
61
src/main.rs
|
@ -1,10 +1,10 @@
|
|||
use log::{log_enabled, Level};
|
||||
use ragequit::{
|
||||
use flodgatt::{
|
||||
config,
|
||||
parse_client_request::{sse, user, ws},
|
||||
redis_to_client_stream,
|
||||
redis_to_client_stream::ClientAgent,
|
||||
};
|
||||
use log::{log_enabled, Level};
|
||||
use warp::{ws::Ws2, Filter as WarpFilter};
|
||||
|
||||
fn main() {
|
||||
|
@ -17,18 +17,14 @@ fn main() {
|
|||
};
|
||||
|
||||
// Server Sent Events
|
||||
//
|
||||
// For SSE, the API requires users to use different endpoints, so we first filter based on
|
||||
// the endpoint. Using that endpoint determine the `timeline` the user is requesting,
|
||||
// the scope for that `timeline`, and authenticate the `User` if they provided a token.
|
||||
let sse_routes = sse::filter_incomming_request()
|
||||
let sse_routes = sse::extract_user_or_reject()
|
||||
.and(warp::sse())
|
||||
.map(
|
||||
move |timeline: String, user: user::User, sse_connection_to_client: warp::sse::Sse| {
|
||||
move |user: user::User, sse_connection_to_client: warp::sse::Sse| {
|
||||
// Create a new ClientAgent
|
||||
let mut client_agent = client_agent_sse.clone_with_shared_receiver();
|
||||
// Assign that agent to generate a stream of updates for the user/timeline pair
|
||||
client_agent.init_for_user(&timeline, user);
|
||||
// Assign ClientAgent to generate stream of updates for the user/timeline pair
|
||||
client_agent.init_for_user(user);
|
||||
// send the updates through the SSE connection
|
||||
redis_to_client_stream::send_updates_to_sse(client_agent, sse_connection_to_client)
|
||||
},
|
||||
|
@ -37,52 +33,17 @@ fn main() {
|
|||
.recover(config::handle_errors);
|
||||
|
||||
// WebSocket
|
||||
//
|
||||
// For WS, the API specifies a single endpoint, so we extract the User/timeline pair
|
||||
// directy from the query
|
||||
let websocket_routes = ws::extract_user_and_query()
|
||||
.and_then(move |mut user: user::User, q: ws::Query, ws: Ws2| {
|
||||
let websocket_routes = ws::extract_user_or_reject()
|
||||
.and(warp::ws::ws2())
|
||||
.and_then(move |user: user::User, ws: Ws2| {
|
||||
let token = user.access_token.clone();
|
||||
let read_scope = user.scopes.clone();
|
||||
|
||||
let timeline = match q.stream.as_ref() {
|
||||
// Public endpoints:
|
||||
tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
|
||||
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
|
||||
tl @ "public" | tl @ "public:local" => tl.to_string(),
|
||||
// Hashtag endpoints:
|
||||
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
|
||||
// Private endpoints: User
|
||||
"user" if user.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
format!("{}", user.id)
|
||||
}
|
||||
"user:notification" if user.logged_in && (read_scope.all || read_scope.notify) => {
|
||||
user = user.set_filter(user::Filter::Notification);
|
||||
format!("{}", user.id)
|
||||
}
|
||||
// List endpoint:
|
||||
"list" if user.owns_list(q.list) && (read_scope.all || read_scope.lists) => {
|
||||
format!("list:{}", q.list)
|
||||
}
|
||||
// Direct endpoint:
|
||||
"direct" if user.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
"direct".to_string()
|
||||
}
|
||||
// Reject unathorized access attempts for private endpoints
|
||||
"user" | "user:notification" | "direct" | "list" => {
|
||||
return Err(warp::reject::custom("Error: Invalid Access Token"))
|
||||
}
|
||||
// Other endpoints don't exist:
|
||||
_ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")),
|
||||
};
|
||||
|
||||
// Create a new ClientAgent
|
||||
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
|
||||
// Assign that agent to generate a stream of updates for the user/timeline pair
|
||||
client_agent.init_for_user(&timeline, user);
|
||||
client_agent.init_for_user(user);
|
||||
// send the updates through the WS connection (along with the User's access_token
|
||||
// which is sent for security)
|
||||
Ok((
|
||||
Ok::<_, warp::Rejection>((
|
||||
ws.on_upgrade(move |socket| {
|
||||
redis_to_client_stream::send_updates_to_ws(socket, client_agent)
|
||||
}),
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//! Parse the client request and return a 'timeline' and a (maybe authenticated) `User`
|
||||
//! Parse the client request and return a (possibly authenticated) `User`
|
||||
pub mod query;
|
||||
pub mod sse;
|
||||
pub mod user;
|
||||
|
|
|
@ -3,8 +3,32 @@ use serde_derive::Deserialize;
|
|||
use warp::filters::BoxedFilter;
|
||||
use warp::Filter as WarpFilter;
|
||||
|
||||
macro_rules! query {
|
||||
($name:tt => $parameter:tt:$type:tt) => {
|
||||
#[derive(Debug)]
|
||||
pub struct Query {
|
||||
pub access_token: Option<String>,
|
||||
pub stream: String,
|
||||
pub media: bool,
|
||||
pub hashtag: String,
|
||||
pub list: i64,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
pub fn update_access_token(
|
||||
self,
|
||||
token: Option<String>,
|
||||
) -> Result<Self, warp::reject::Rejection> {
|
||||
match token {
|
||||
Some(token) => Ok(Self {
|
||||
access_token: Some(token),
|
||||
..self
|
||||
}),
|
||||
None => Ok(self),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! make_query_type {
|
||||
($name:tt => $parameter:tt:$type:ty) => {
|
||||
#[derive(Deserialize, Debug, Default)]
|
||||
pub struct $name {
|
||||
pub $parameter: $type,
|
||||
|
@ -19,16 +43,16 @@ macro_rules! query {
|
|||
}
|
||||
};
|
||||
}
|
||||
query!(Media => only_media:String);
|
||||
make_query_type!(Media => only_media:String);
|
||||
impl Media {
|
||||
pub fn is_truthy(&self) -> bool {
|
||||
self.only_media == "true" || self.only_media == "1"
|
||||
}
|
||||
}
|
||||
query!(Hashtag => tag: String);
|
||||
query!(List => list: i64);
|
||||
query!(Auth => access_token: String);
|
||||
query!(Stream => stream: String);
|
||||
make_query_type!(Hashtag => tag: String);
|
||||
make_query_type!(List => list: i64);
|
||||
make_query_type!(Auth => access_token: Option<String>);
|
||||
make_query_type!(Stream => stream: String);
|
||||
impl ToString for Stream {
|
||||
fn to_string(&self) -> String {
|
||||
format!("{:?}", self)
|
||||
|
@ -43,3 +67,19 @@ pub fn optional_media_query() -> BoxedFilter<(Media,)> {
|
|||
.unify()
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub struct OptionalAccessToken;
|
||||
|
||||
impl OptionalAccessToken {
|
||||
pub fn from_header() -> warp::filters::BoxedFilter<(Option<String>,)> {
|
||||
let from_header = warp::header::header::<String>("authorization").map(|auth: String| {
|
||||
match auth.split(' ').nth(1) {
|
||||
Some(s) => Some(s.to_string()),
|
||||
None => None,
|
||||
}
|
||||
});
|
||||
let no_token = warp::any().map(|| None);
|
||||
|
||||
from_header.or(no_token).unify().boxed()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,5 @@
|
|||
//! Filters for all the endpoints accessible for Server Sent Event updates
|
||||
use super::{
|
||||
query,
|
||||
user::{Filter::*, OptionalAccessToken, User},
|
||||
};
|
||||
use crate::config::CustomError;
|
||||
use super::{query, query::Query, user::User};
|
||||
use warp::{filters::BoxedFilter, path, Filter};
|
||||
|
||||
#[allow(dead_code)]
|
||||
|
@ -12,140 +8,111 @@ type TimelineUser = ((String, User),);
|
|||
/// Helper macro to match on the first of any of the provided filters
|
||||
macro_rules! any_of {
|
||||
($filter:expr, $($other_filter:expr),*) => {
|
||||
$filter$(.or($other_filter).unify())*
|
||||
$filter$(.or($other_filter).unify())*.boxed()
|
||||
};
|
||||
}
|
||||
|
||||
pub fn filter_incomming_request() -> BoxedFilter<(String, User)> {
|
||||
macro_rules! parse_query {
|
||||
(path => $start:tt $(/ $next:tt)*
|
||||
endpoint => $endpoint:expr) => {
|
||||
path!($start $(/ $next)*)
|
||||
.and(query::Auth::to_filter())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.map(
|
||||
|auth: query::Auth,
|
||||
media: query::Media,
|
||||
hashtag: query::Hashtag,
|
||||
list: query::List| {
|
||||
Query {
|
||||
access_token: auth.access_token,
|
||||
stream: $endpoint.to_string(),
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
}
|
||||
},
|
||||
)
|
||||
.boxed()
|
||||
};
|
||||
}
|
||||
pub fn extract_user_or_reject() -> BoxedFilter<(User,)> {
|
||||
any_of!(
|
||||
path!("api" / "v1" / "streaming" / "user" / "notification")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_reject)
|
||||
.map(|user: User| (user.id.to_string(), user.set_filter(Notification))),
|
||||
// **NOTE**: This endpoint was present in the node.js server, but not in the
|
||||
// [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local).
|
||||
// Should it be publicly documented?
|
||||
path!("api" / "v1" / "streaming" / "user")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_reject)
|
||||
.map(|user: User| (user.id.to_string(), user)),
|
||||
path!("api" / "v1" / "streaming" / "public" / "local")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_public_user)
|
||||
.and(warp::query())
|
||||
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
|
||||
"1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)),
|
||||
_ => ("public:local".to_owned(), user.set_filter(Language)),
|
||||
}),
|
||||
path!("api" / "v1" / "streaming" / "public")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_public_user)
|
||||
.and(warp::query())
|
||||
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
|
||||
"1" | "true" => ("public:media".to_owned(), user.set_filter(Language)),
|
||||
_ => ("public".to_owned(), user.set_filter(Language)),
|
||||
}),
|
||||
path!("api" / "v1" / "streaming" / "public" / "local")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_public_user)
|
||||
.map(|user: User| ("public:local".to_owned(), user.set_filter(Language))),
|
||||
path!("api" / "v1" / "streaming" / "public")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_public_user)
|
||||
.map(|user: User| ("public".to_owned(), user.set_filter(Language))),
|
||||
path!("api" / "v1" / "streaming" / "direct")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_reject)
|
||||
.map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter))),
|
||||
// **Note**: Hashtags are *not* filtered on language, right?
|
||||
path!("api" / "v1" / "streaming" / "hashtag" / "local")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_public_user)
|
||||
.and(warp::query())
|
||||
.map(|_, q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())),
|
||||
path!("api" / "v1" / "streaming" / "hashtag")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_public_user)
|
||||
.and(warp::query())
|
||||
.map(|_, q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())),
|
||||
path!("api" / "v1" / "streaming" / "list")
|
||||
.and(OptionalAccessToken::from_header_or_query())
|
||||
.and_then(User::from_access_token_or_reject)
|
||||
.and(warp::query())
|
||||
.and_then(|user: User, q: query::List| {
|
||||
if user.owns_list(q.list) {
|
||||
(Ok(q.list), Ok(user))
|
||||
} else {
|
||||
(Err(CustomError::unauthorized_list()), Ok(user))
|
||||
}
|
||||
})
|
||||
.untuple_one()
|
||||
.map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter)))
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "user" / "notification"
|
||||
endpoint => "user:notification" ),
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "user"
|
||||
endpoint => "user"),
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "public" / "local"
|
||||
endpoint => "public:local"),
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "public"
|
||||
endpoint => "public"),
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "direct"
|
||||
endpoint => "direct"),
|
||||
parse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
|
||||
endpoint => "hashtag:local"),
|
||||
parse_query!(path => "api" / "v1" / "streaming" / "hashtag"
|
||||
endpoint => "hashtag"),
|
||||
parse_query!(path => "api" / "v1" / "streaming" / "list"
|
||||
endpoint => "list")
|
||||
)
|
||||
.untuple_one()
|
||||
// because SSE requests place their `access_token` in the header instead of in a query
|
||||
// parameter, we need to update our Query if the header has a token
|
||||
.and(query::OptionalAccessToken::from_header())
|
||||
.and_then(Query::update_access_token)
|
||||
.and_then(User::from_query)
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
struct TestUser;
|
||||
impl TestUser {
|
||||
fn logged_in() -> User {
|
||||
User::from_access_token_or_reject(Some("TEST_USER".to_string())).expect("in test")
|
||||
}
|
||||
fn public() -> User {
|
||||
User::from_access_token_or_public_user(None).expect("in test")
|
||||
}
|
||||
}
|
||||
use crate::parse_client_request::user::{Filter, OauthScope};
|
||||
|
||||
macro_rules! test_public_endpoint {
|
||||
($name:ident {
|
||||
endpoint: $path:expr,
|
||||
timeline: $timeline:expr,
|
||||
user: $user:expr,
|
||||
}) => {
|
||||
#[test]
|
||||
fn $name() {
|
||||
let (timeline, user) = warp::test::request()
|
||||
let user = warp::test::request()
|
||||
.path($path)
|
||||
.filter(&filter_incomming_request())
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
assert_eq!(&timeline, $timeline);
|
||||
assert_eq!(user, $user);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! test_private_endpoint {
|
||||
($name:ident {
|
||||
endpoint: $path:expr,
|
||||
$(query: $query:expr,)*
|
||||
timeline: $timeline:expr,
|
||||
user: $user:expr,
|
||||
}) => {
|
||||
#[test]
|
||||
fn $name() {
|
||||
let path = format!("{}?access_token=TEST_USER", $path);
|
||||
$(let path = format!("{}&{}", path, $query);)*
|
||||
let (timeline, user) = warp::test::request()
|
||||
let user = warp::test::request()
|
||||
.path(&path)
|
||||
.filter(&filter_incomming_request())
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
assert_eq!(&timeline, $timeline);
|
||||
assert_eq!(user, $user);
|
||||
let (timeline, user) = warp::test::request()
|
||||
let user = warp::test::request()
|
||||
.path(&path)
|
||||
.header("Authorization", "Bearer: TEST_USER")
|
||||
.filter(&filter_incomming_request())
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
assert_eq!(&timeline, $timeline);
|
||||
assert_eq!(user, $user);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! test_bad_auth_token_in_query {
|
||||
($name: ident {
|
||||
endpoint: $path:expr,
|
||||
|
@ -153,19 +120,17 @@ mod test {
|
|||
}) => {
|
||||
#[test]
|
||||
#[should_panic(expected = "Error: Invalid access token")]
|
||||
|
||||
fn $name() {
|
||||
let path = format!("{}?access_token=INVALID", $path);
|
||||
$(let path = format!("{}&{}", path, $query);)*
|
||||
dbg!(&path);
|
||||
warp::test::request()
|
||||
.path(&path)
|
||||
.filter(&filter_incomming_request())
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! test_bad_auth_token_in_header {
|
||||
($name: ident {
|
||||
endpoint: $path:expr,
|
||||
|
@ -180,7 +145,7 @@ mod test {
|
|||
warp::test::request()
|
||||
.path(&path)
|
||||
.header("Authorization", "Bearer: INVALID")
|
||||
.filter(&filter_incomming_request())
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
}
|
||||
};
|
||||
|
@ -197,7 +162,7 @@ mod test {
|
|||
$(let path = format!("{}?{}", path, $query);)*
|
||||
warp::test::request()
|
||||
.path(&path)
|
||||
.filter(&filter_incomming_request())
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
}
|
||||
};
|
||||
|
@ -205,13 +170,193 @@ mod test {
|
|||
|
||||
test_public_endpoint!(public_media_true {
|
||||
endpoint: "/api/v1/streaming/public?only_media=true",
|
||||
timeline: "public:media",
|
||||
user: TestUser::public().set_filter(Language),
|
||||
user: User {
|
||||
target_timeline: "public:media".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(public_media_1 {
|
||||
endpoint: "/api/v1/streaming/public?only_media=1",
|
||||
timeline: "public:media",
|
||||
user: TestUser::public().set_filter(Language),
|
||||
user: User {
|
||||
target_timeline: "public:media".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(public_local {
|
||||
endpoint: "/api/v1/streaming/public/local",
|
||||
user: User {
|
||||
target_timeline: "public:local".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(public_local_media_true {
|
||||
endpoint: "/api/v1/streaming/public/local?only_media=true",
|
||||
user: User {
|
||||
target_timeline: "public:local:media".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(public_local_media_1 {
|
||||
endpoint: "/api/v1/streaming/public/local?only_media=1",
|
||||
user: User {
|
||||
target_timeline: "public:local:media".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(hashtag {
|
||||
endpoint: "/api/v1/streaming/hashtag?tag=a",
|
||||
user: User {
|
||||
target_timeline: "hashtag:a".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(hashtag_local {
|
||||
endpoint: "/api/v1/streaming/hashtag/local?tag=a",
|
||||
user: User {
|
||||
target_timeline: "hashtag:local:a".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
|
||||
test_private_endpoint!(user {
|
||||
endpoint: "/api/v1/streaming/user",
|
||||
user: User {
|
||||
target_timeline: "1".to_string(),
|
||||
id: 1,
|
||||
access_token: "TEST_USER".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: true,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: true,
|
||||
filter: Filter::NoFilter,
|
||||
},
|
||||
});
|
||||
test_private_endpoint!(user_notification {
|
||||
endpoint: "/api/v1/streaming/user/notification",
|
||||
user: User {
|
||||
target_timeline: "1".to_string(),
|
||||
id: 1,
|
||||
access_token: "TEST_USER".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: true,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: true,
|
||||
filter: Filter::Notification,
|
||||
},
|
||||
});
|
||||
test_private_endpoint!(direct {
|
||||
endpoint: "/api/v1/streaming/direct",
|
||||
user: User {
|
||||
target_timeline: "direct".to_string(),
|
||||
id: 1,
|
||||
access_token: "TEST_USER".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: true,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: true,
|
||||
filter: Filter::NoFilter,
|
||||
},
|
||||
});
|
||||
|
||||
test_private_endpoint!(list_valid_list {
|
||||
endpoint: "/api/v1/streaming/list",
|
||||
query: "list=1",
|
||||
user: User {
|
||||
target_timeline: "list:1".to_string(),
|
||||
id: 1,
|
||||
access_token: "TEST_USER".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: true,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: true,
|
||||
filter: Filter::NoFilter,
|
||||
},
|
||||
});
|
||||
test_bad_auth_token_in_query!(public_media_true_bad_auth {
|
||||
endpoint: "/api/v1/streaming/public",
|
||||
|
@ -221,29 +366,12 @@ mod test {
|
|||
endpoint: "/api/v1/streaming/public",
|
||||
query: "only_media=1",
|
||||
});
|
||||
|
||||
test_public_endpoint!(public_local {
|
||||
endpoint: "/api/v1/streaming/public/local",
|
||||
timeline: "public:local",
|
||||
user: TestUser::public().set_filter(Language),
|
||||
});
|
||||
test_bad_auth_token_in_query!(public_local_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming/public/local",
|
||||
});
|
||||
test_bad_auth_token_in_header!(public_local_bad_auth_in_header {
|
||||
endpoint: "/api/v1/streaming/public/local",
|
||||
});
|
||||
|
||||
test_public_endpoint!(public_local_media_true {
|
||||
endpoint: "/api/v1/streaming/public/local?only_media=true",
|
||||
timeline: "public:local:media",
|
||||
user: TestUser::public().set_filter(Language),
|
||||
});
|
||||
test_public_endpoint!(public_local_media_1 {
|
||||
endpoint: "/api/v1/streaming/public/local?only_media=1",
|
||||
timeline: "public:local:media",
|
||||
user: TestUser::public().set_filter(Language),
|
||||
});
|
||||
test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming/public/local",
|
||||
query: "only_media=1",
|
||||
|
@ -252,12 +380,6 @@ mod test {
|
|||
endpoint: "/api/v1/streaming/public/local",
|
||||
query: "only_media=true",
|
||||
});
|
||||
|
||||
test_public_endpoint!(hashtag {
|
||||
endpoint: "/api/v1/streaming/hashtag?tag=a",
|
||||
timeline: "hashtag:a",
|
||||
user: TestUser::public(),
|
||||
});
|
||||
test_bad_auth_token_in_query!(hashtag_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming/hashtag",
|
||||
query: "tag=a",
|
||||
|
@ -266,26 +388,6 @@ mod test {
|
|||
endpoint: "/api/v1/streaming/hashtag",
|
||||
query: "tag=a",
|
||||
});
|
||||
|
||||
test_public_endpoint!(hashtag_local {
|
||||
endpoint: "/api/v1/streaming/hashtag/local?tag=a",
|
||||
timeline: "hashtag:a:local",
|
||||
user: TestUser::public(),
|
||||
});
|
||||
test_bad_auth_token_in_query!(hashtag_local_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming/hashtag/local",
|
||||
query: "tag=a",
|
||||
});
|
||||
test_bad_auth_token_in_header!(hashtag_local_bad_auth_in_header {
|
||||
endpoint: "/api/v1/streaming/hashtag/local",
|
||||
query: "tag=a",
|
||||
});
|
||||
|
||||
test_private_endpoint!(user {
|
||||
endpoint: "/api/v1/streaming/user",
|
||||
timeline: "1",
|
||||
user: TestUser::logged_in(),
|
||||
});
|
||||
test_bad_auth_token_in_query!(user_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming/user",
|
||||
});
|
||||
|
@ -295,12 +397,6 @@ mod test {
|
|||
test_missing_auth!(user_missing_auth_token {
|
||||
endpoint: "/api/v1/streaming/user",
|
||||
});
|
||||
|
||||
test_private_endpoint!(user_notification {
|
||||
endpoint: "/api/v1/streaming/user/notification",
|
||||
timeline: "1",
|
||||
user: TestUser::logged_in().set_filter(Notification),
|
||||
});
|
||||
test_bad_auth_token_in_query!(user_notification_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming/user/notification",
|
||||
});
|
||||
|
@ -310,12 +406,6 @@ mod test {
|
|||
test_missing_auth!(user_notification_missing_auth_token {
|
||||
endpoint: "/api/v1/streaming/user/notification",
|
||||
});
|
||||
|
||||
test_private_endpoint!(direct {
|
||||
endpoint: "/api/v1/streaming/direct",
|
||||
timeline: "direct:1",
|
||||
user: TestUser::logged_in(),
|
||||
});
|
||||
test_bad_auth_token_in_query!(direct_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming/direct",
|
||||
});
|
||||
|
@ -325,13 +415,6 @@ mod test {
|
|||
test_missing_auth!(direct_missing_auth_token {
|
||||
endpoint: "/api/v1/streaming/direct",
|
||||
});
|
||||
|
||||
test_private_endpoint!(list_valid_list {
|
||||
endpoint: "/api/v1/streaming/list",
|
||||
query: "list=1",
|
||||
timeline: "list:1",
|
||||
user: TestUser::logged_in(),
|
||||
});
|
||||
test_bad_auth_token_in_query!(list_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming/list",
|
||||
query: "list=1",
|
||||
|
@ -345,4 +428,13 @@ mod test {
|
|||
query: "list=1",
|
||||
});
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "NotFound")]
|
||||
fn nonexistant_endpoint() {
|
||||
warp::test::request()
|
||||
.path("/api/v1/streaming/DOES_NOT_EXIST")
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -5,16 +5,8 @@ mod mock_postgres;
|
|||
use mock_postgres as postgres;
|
||||
#[cfg(not(test))]
|
||||
mod postgres;
|
||||
use crate::parse_client_request::query;
|
||||
use log::info;
|
||||
use super::query::Query;
|
||||
use warp::reject::Rejection;
|
||||
use warp::Filter as WarpFilter;
|
||||
|
||||
macro_rules! any_of {
|
||||
($filter:expr, $($other_filter:expr),*) => {
|
||||
$filter$(.or($other_filter).unify())*
|
||||
};
|
||||
}
|
||||
|
||||
/// The filters that can be applied to toots after they come from Redis
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
|
@ -23,10 +15,16 @@ pub enum Filter {
|
|||
Language,
|
||||
Notification,
|
||||
}
|
||||
impl Default for Filter {
|
||||
fn default() -> Self {
|
||||
Filter::NoFilter
|
||||
}
|
||||
}
|
||||
|
||||
/// The User (with data read from Postgres)
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct User {
|
||||
pub target_timeline: String,
|
||||
pub id: i64,
|
||||
pub access_token: String,
|
||||
pub scopes: OauthScope,
|
||||
|
@ -34,11 +32,7 @@ pub struct User {
|
|||
pub logged_in: bool,
|
||||
pub filter: Filter,
|
||||
}
|
||||
impl Default for User {
|
||||
fn default() -> Self {
|
||||
User::public()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct OauthScope {
|
||||
pub all: bool,
|
||||
|
@ -62,72 +56,82 @@ impl From<Vec<String>> for OauthScope {
|
|||
}
|
||||
}
|
||||
|
||||
/// Create a user based on the supplied path and access scope for the resource
|
||||
#[macro_export]
|
||||
macro_rules! user_from_path {
|
||||
($($path_item:tt) / *, $scope:expr) => (path!("api" / "v1" / $($path_item) / +)
|
||||
.and($scope.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, $scope)))
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn from_access_token_or_reject(token: Option<String>) -> Result<Self, Rejection> {
|
||||
match token {
|
||||
None => Err(warp::reject::custom("Error: Missing access token")),
|
||||
pub fn from_query(q: Query) -> Result<Self, Rejection> {
|
||||
let (id, access_token, scopes, langs, logged_in) = match q.access_token.clone() {
|
||||
None => (
|
||||
-1,
|
||||
"no access token".to_owned(),
|
||||
OauthScope::default(),
|
||||
None,
|
||||
false,
|
||||
),
|
||||
Some(token) => {
|
||||
let (id, langs, scope_list) = postgres::query_for_user_data(&token);
|
||||
if id == -1 {
|
||||
return Err(warp::reject::custom("Error: Invalid access token"));
|
||||
}
|
||||
let scopes = OauthScope::from(scope_list);
|
||||
|
||||
Ok(User {
|
||||
id,
|
||||
access_token: token,
|
||||
scopes,
|
||||
langs,
|
||||
logged_in: true,
|
||||
filter: Filter::NoFilter,
|
||||
})
|
||||
(id, token, scopes, langs, true)
|
||||
}
|
||||
}
|
||||
};
|
||||
let mut user = User {
|
||||
id,
|
||||
target_timeline: "PLACEHOLDER".to_string(),
|
||||
access_token,
|
||||
scopes,
|
||||
langs,
|
||||
logged_in,
|
||||
filter: Filter::Language,
|
||||
};
|
||||
|
||||
user = user.update_timeline_and_filter(q)?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub fn from_access_token_or_public_user(token: Option<String>) -> Result<Self, Rejection> {
|
||||
match token {
|
||||
None => Ok(User::public()),
|
||||
Some(_) => User::from_access_token_or_reject(token),
|
||||
}
|
||||
}
|
||||
/// Create a user from the access token supplied in the header or query paramaters
|
||||
pub fn from_access_token(
|
||||
access_token: String,
|
||||
scope: Scope,
|
||||
) -> Result<Self, warp::reject::Rejection> {
|
||||
let (id, langs, scope_list) = postgres::query_for_user_data(&access_token);
|
||||
let scopes = OauthScope::from(scope_list);
|
||||
if id != -1 || scope == Scope::Public {
|
||||
let (logged_in, log_msg) = match id {
|
||||
-1 => (false, "Public access to non-authenticated endpoints"),
|
||||
_ => (true, "Granting logged-in access"),
|
||||
};
|
||||
info!("{}", log_msg);
|
||||
Ok(User {
|
||||
id,
|
||||
access_token,
|
||||
scopes,
|
||||
langs,
|
||||
logged_in,
|
||||
filter: Filter::NoFilter,
|
||||
})
|
||||
} else {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
}
|
||||
}
|
||||
/// Set the Notification/Language filter
|
||||
pub fn set_filter(self, filter: Filter) -> Self {
|
||||
Self { filter, ..self }
|
||||
fn update_timeline_and_filter(mut self, q: Query) -> Result<Self, Rejection> {
|
||||
let read_scope = self.scopes.clone();
|
||||
|
||||
let timeline = match q.stream.as_ref() {
|
||||
// Public endpoints:
|
||||
tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
|
||||
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
|
||||
tl @ "public" | tl @ "public:local" => tl.to_string(),
|
||||
// Hashtag endpoints:
|
||||
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
|
||||
// Private endpoints: User
|
||||
"user" if self.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
self.filter = Filter::NoFilter;
|
||||
format!("{}", self.id)
|
||||
}
|
||||
"user:notification" if self.logged_in && (read_scope.all || read_scope.notify) => {
|
||||
self.filter = Filter::Notification;
|
||||
format!("{}", self.id)
|
||||
}
|
||||
// List endpoint:
|
||||
"list" if self.owns_list(q.list) && (read_scope.all || read_scope.lists) => {
|
||||
self.filter = Filter::NoFilter;
|
||||
format!("list:{}", q.list)
|
||||
}
|
||||
// Direct endpoint:
|
||||
"direct" if self.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
self.filter = Filter::NoFilter;
|
||||
"direct".to_string()
|
||||
}
|
||||
// Reject unathorized access attempts for private endpoints
|
||||
"user" | "user:notification" | "direct" | "list" => {
|
||||
return Err(warp::reject::custom("Error: Missing access token"))
|
||||
}
|
||||
// Other endpoints don't exist:
|
||||
_ => return Err(warp::reject::custom("Error: Nonexistent endpoint")),
|
||||
};
|
||||
Ok(Self {
|
||||
target_timeline: timeline,
|
||||
..self
|
||||
})
|
||||
}
|
||||
|
||||
/// Determine whether the User is authorised for a specified list
|
||||
pub fn owns_list(&self, list: i64) -> bool {
|
||||
match postgres::query_list_owner(list) {
|
||||
|
@ -135,75 +139,4 @@ impl User {
|
|||
_ => false,
|
||||
}
|
||||
}
|
||||
pub fn public2() -> warp::filters::BoxedFilter<(User,)> {
|
||||
warp::any()
|
||||
.map(|| User {
|
||||
id: -1,
|
||||
access_token: String::from("no access token"),
|
||||
scopes: OauthScope::default(),
|
||||
langs: None,
|
||||
logged_in: false,
|
||||
filter: Filter::NoFilter,
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
/// A public (non-authenticated) User
|
||||
pub fn public() -> Self {
|
||||
User {
|
||||
id: -1,
|
||||
access_token: String::from("no access token"),
|
||||
scopes: OauthScope::default(),
|
||||
langs: None,
|
||||
logged_in: false,
|
||||
filter: Filter::NoFilter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OptionalAccessToken;
|
||||
|
||||
impl OptionalAccessToken {
|
||||
pub fn from_header_or_query() -> warp::filters::BoxedFilter<(Option<String>,)> {
|
||||
let from_header = warp::header::header::<String>("authorization").map(|auth: String| {
|
||||
match auth.split(' ').nth(1) {
|
||||
Some(s) => Some(s.to_string()),
|
||||
None => None,
|
||||
}
|
||||
});
|
||||
let from_query = warp::query().map(|q: query::Auth| Some(q.access_token));
|
||||
let no_token = warp::any().map(|| None);
|
||||
|
||||
any_of!(from_header, from_query, no_token).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the endpoint requires authentication or not
|
||||
#[derive(PartialEq)]
|
||||
pub enum Scope {
|
||||
Public,
|
||||
Private,
|
||||
}
|
||||
impl Scope {
|
||||
pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> {
|
||||
let token_from_header_http_push = warp::header::header::<String>("authorization")
|
||||
.map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string());
|
||||
let token_from_header_ws =
|
||||
warp::header::header::<String>("Sec-WebSocket-Protocol").map(|auth: String| auth);
|
||||
let token_from_query = warp::query().map(|q: query::Auth| q.access_token);
|
||||
|
||||
let private_scopes = any_of!(
|
||||
token_from_header_http_push,
|
||||
token_from_header_ws,
|
||||
token_from_query
|
||||
);
|
||||
|
||||
let public = warp::any().map(|| "no access token".to_string());
|
||||
|
||||
match self {
|
||||
// if they're trying to access a private scope without an access token, reject the request
|
||||
Scope::Private => private_scopes.boxed(),
|
||||
// if they're trying to access a public scope without an access token, proceed
|
||||
Scope::Public => any_of!(private_scopes, public).boxed(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,43 +1,316 @@
|
|||
//! Filters for the WebSocket endpoint
|
||||
use super::{
|
||||
query,
|
||||
user::{Scope, User},
|
||||
};
|
||||
use crate::user_from_path;
|
||||
use super::{query, query::Query, user::User};
|
||||
use warp::{filters::BoxedFilter, path, Filter};
|
||||
|
||||
/// WebSocket filters
|
||||
pub fn extract_user_and_query() -> BoxedFilter<(User, Query, warp::ws::Ws2)> {
|
||||
user_from_path!("streaming", Scope::Public)
|
||||
fn parse_query() -> BoxedFilter<(Query,)> {
|
||||
path!("api" / "v1" / "streaming")
|
||||
.and(path::end())
|
||||
.and(warp::query())
|
||||
.and(query::Auth::to_filter())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.and(warp::ws2())
|
||||
.map(
|
||||
|user: User,
|
||||
stream: query::Stream,
|
||||
|stream: query::Stream,
|
||||
auth: query::Auth,
|
||||
media: query::Media,
|
||||
hashtag: query::Hashtag,
|
||||
list: query::List,
|
||||
ws: warp::ws::Ws2| {
|
||||
let query = Query {
|
||||
list: query::List| {
|
||||
Query {
|
||||
access_token: auth.access_token,
|
||||
stream: stream.stream,
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
};
|
||||
(user, query, ws)
|
||||
}
|
||||
},
|
||||
)
|
||||
.untuple_one()
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Query {
|
||||
pub stream: String,
|
||||
pub media: bool,
|
||||
pub hashtag: String,
|
||||
pub list: i64,
|
||||
pub fn extract_user_or_reject() -> BoxedFilter<(User,)> {
|
||||
parse_query().and_then(User::from_query).boxed()
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use crate::parse_client_request::user::{Filter, OauthScope};
|
||||
|
||||
macro_rules! test_public_endpoint {
|
||||
($name:ident {
|
||||
endpoint: $path:expr,
|
||||
user: $user:expr,
|
||||
}) => {
|
||||
#[test]
|
||||
fn $name() {
|
||||
let user = warp::test::request()
|
||||
.path($path)
|
||||
.header("connection", "upgrade")
|
||||
.header("upgrade", "websocket")
|
||||
.header("sec-websocket-version", "13")
|
||||
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
assert_eq!(user, $user);
|
||||
}
|
||||
};
|
||||
}
|
||||
macro_rules! test_private_endpoint {
|
||||
($name:ident {
|
||||
endpoint: $path:expr,
|
||||
user: $user:expr,
|
||||
}) => {
|
||||
#[test]
|
||||
fn $name() {
|
||||
let path = format!("{}&access_token=TEST_USER", $path);
|
||||
let user = warp::test::request()
|
||||
.path(&path)
|
||||
.header("connection", "upgrade")
|
||||
.header("upgrade", "websocket")
|
||||
.header("sec-websocket-version", "13")
|
||||
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
assert_eq!(user, $user);
|
||||
}
|
||||
};
|
||||
}
|
||||
macro_rules! test_bad_auth_token_in_query {
|
||||
($name: ident {
|
||||
endpoint: $path:expr,
|
||||
|
||||
}) => {
|
||||
#[test]
|
||||
#[should_panic(expected = "Error: Invalid access token")]
|
||||
|
||||
fn $name() {
|
||||
let path = format!("{}&access_token=INVALID", $path);
|
||||
warp::test::request()
|
||||
.path(&path)
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
}
|
||||
};
|
||||
}
|
||||
macro_rules! test_missing_auth {
|
||||
($name: ident {
|
||||
endpoint: $path:expr,
|
||||
}) => {
|
||||
#[test]
|
||||
#[should_panic(expected = "Error: Missing access token")]
|
||||
fn $name() {
|
||||
let path = $path;
|
||||
warp::test::request()
|
||||
.path(&path)
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
test_public_endpoint!(public_media {
|
||||
endpoint: "/api/v1/streaming?stream=public:media",
|
||||
user: User {
|
||||
target_timeline: "public:media".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(public_local {
|
||||
endpoint: "/api/v1/streaming?stream=public:local",
|
||||
user: User {
|
||||
target_timeline: "public:local".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(public_local_media {
|
||||
endpoint: "/api/v1/streaming?stream=public:local:media",
|
||||
user: User {
|
||||
target_timeline: "public:local:media".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(hashtag {
|
||||
endpoint: "/api/v1/streaming?stream=hashtag&tag=a",
|
||||
user: User {
|
||||
target_timeline: "hashtag:a".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
test_public_endpoint!(hashtag_local {
|
||||
endpoint: "/api/v1/streaming?stream=hashtag:local&tag=a",
|
||||
user: User {
|
||||
target_timeline: "hashtag:local:a".to_string(),
|
||||
id: -1,
|
||||
access_token: "no access token".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: false,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: false,
|
||||
filter: Filter::Language,
|
||||
},
|
||||
});
|
||||
|
||||
test_private_endpoint!(user {
|
||||
endpoint: "/api/v1/streaming?stream=user",
|
||||
user: User {
|
||||
target_timeline: "1".to_string(),
|
||||
id: 1,
|
||||
access_token: "TEST_USER".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: true,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: true,
|
||||
filter: Filter::NoFilter,
|
||||
},
|
||||
});
|
||||
test_private_endpoint!(user_notification {
|
||||
endpoint: "/api/v1/streaming?stream=user:notification",
|
||||
user: User {
|
||||
target_timeline: "1".to_string(),
|
||||
id: 1,
|
||||
access_token: "TEST_USER".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: true,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: true,
|
||||
filter: Filter::Notification,
|
||||
},
|
||||
});
|
||||
test_private_endpoint!(direct {
|
||||
endpoint: "/api/v1/streaming?stream=direct",
|
||||
user: User {
|
||||
target_timeline: "direct".to_string(),
|
||||
id: 1,
|
||||
access_token: "TEST_USER".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: true,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: true,
|
||||
filter: Filter::NoFilter,
|
||||
},
|
||||
});
|
||||
test_private_endpoint!(list_valid_list {
|
||||
endpoint: "/api/v1/streaming?stream=list&list=1",
|
||||
user: User {
|
||||
target_timeline: "list:1".to_string(),
|
||||
id: 1,
|
||||
access_token: "TEST_USER".to_string(),
|
||||
langs: None,
|
||||
scopes: OauthScope {
|
||||
all: true,
|
||||
statuses: false,
|
||||
notify: false,
|
||||
lists: false,
|
||||
},
|
||||
logged_in: true,
|
||||
filter: Filter::NoFilter,
|
||||
},
|
||||
});
|
||||
|
||||
test_bad_auth_token_in_query!(public_media_true_bad_auth {
|
||||
endpoint: "/api/v1/streaming?stream=public:media",
|
||||
});
|
||||
test_bad_auth_token_in_query!(public_local_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming?stream=public:local",
|
||||
});
|
||||
test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming?stream=public:local:media",
|
||||
});
|
||||
test_bad_auth_token_in_query!(hashtag_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming?stream=hashtag&tag=a",
|
||||
});
|
||||
test_bad_auth_token_in_query!(user_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming?stream=user",
|
||||
});
|
||||
test_missing_auth!(user_missing_auth_token {
|
||||
endpoint: "/api/v1/streaming?stream=user",
|
||||
});
|
||||
test_bad_auth_token_in_query!(user_notification_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming?stream=user:notification",
|
||||
});
|
||||
test_missing_auth!(user_notification_missing_auth_token {
|
||||
endpoint: "/api/v1/streaming?stream=user:notification",
|
||||
});
|
||||
test_bad_auth_token_in_query!(direct_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming?stream=direct",
|
||||
});
|
||||
test_missing_auth!(direct_missing_auth_token {
|
||||
endpoint: "/api/v1/streaming?stream=direct",
|
||||
});
|
||||
test_bad_auth_token_in_query!(list_bad_auth_in_query {
|
||||
endpoint: "/api/v1/streaming?stream=list&list=1",
|
||||
});
|
||||
test_missing_auth!(list_missing_auth_token {
|
||||
endpoint: "/api/v1/streaming?stream=list&list=1",
|
||||
});
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "NotFound")]
|
||||
fn nonexistant_endpoint() {
|
||||
warp::test::request()
|
||||
.path("/api/v1/streaming/DOES_NOT_EXIST")
|
||||
.header("connection", "upgrade")
|
||||
.header("upgrade", "websocket")
|
||||
.header("sec-websocket-version", "13")
|
||||
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
.filter(&extract_user_or_reject())
|
||||
.expect("in test");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ impl ClientAgent {
|
|||
receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())),
|
||||
id: Uuid::default(),
|
||||
target_timeline: String::new(),
|
||||
current_user: User::public(),
|
||||
current_user: User::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,12 +61,12 @@ impl ClientAgent {
|
|||
/// a different user, the `Receiver` is responsible for figuring
|
||||
/// that out and avoiding duplicated connections. Thus, it is safe to
|
||||
/// use this method for each new client connection.
|
||||
pub fn init_for_user(&mut self, target_timeline: &str, user: User) {
|
||||
pub fn init_for_user(&mut self, user: User) {
|
||||
self.id = Uuid::new_v4();
|
||||
self.target_timeline = target_timeline.to_owned();
|
||||
self.target_timeline = user.target_timeline.to_owned();
|
||||
self.current_user = user;
|
||||
let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)");
|
||||
receiver.manage_new_timeline(self.id, target_timeline);
|
||||
receiver.manage_new_timeline(self.id, &self.target_timeline);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue