From 4649f89442cfc8f3623f6e634fb56435fefaa445 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Tue, 30 Apr 2019 18:41:13 -0400 Subject: [PATCH] Add unit tests, (some) integration tests, and documentation --- README.md | 13 +- src/error.rs | 5 +- src/main.rs | 153 ++++++--------- src/query.rs | 1 + src/receiver.rs | 9 +- src/stream.rs | 7 + src/timeline.rs | 496 ++++++++++++++++++++++++++++++++++++++++++++++++ src/user.rs | 126 ++++++------ src/utils.rs | 9 +- 9 files changed, 642 insertions(+), 177 deletions(-) create mode 100644 src/timeline.rs diff --git a/README.md b/README.md index 1a991db..808cc65 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,4 @@ # RageQuit -A blazingly fast drop-in replacement for the Mastodon streaming api server +A WIP blazingly fast drop-in replacement for the Mastodon streaming api server. -## Notes on data flow -The current structure of the app is as follows: - -Client Request --> Warp - Warp filters for valid requests and parses request data. Based on that data, it repeatedly polls the StreamManager - -Warp --> StreamManager - The StreamManager consults a hash table to see if there is a currently open PubSub channel. If there is, it uses that channel; if not, it (synchronously) sends a subscribe command to Redis. The StreamManager polls the Receiver, providing info about which StreamManager it is that is doing the polling. The stream manager is also responsible for monitoring the hash table to see if it should unsubscribe from any channels and, if necessary, sending the unsubscribe command. - -StreamManger --> Receiver - The Receiver receives data from Redis and stores it in a series of queues (one for each StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages relevant to that StreamManager and removes them from the queue. diff --git a/src/error.rs b/src/error.rs index fd24bb1..bede2f6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,8 @@ +//! Custom Errors and Warp::Rejections use serde_derive::Serialize; + #[derive(Serialize)] -struct ErrorMessage { +pub struct ErrorMessage { error: String, } impl ErrorMessage { @@ -11,6 +13,7 @@ impl ErrorMessage { } } +/// Recover from Errors by sending appropriate Warp::Rejections pub fn handle_errors( rejection: warp::reject::Rejection, ) -> Result { diff --git a/src/main.rs b/src/main.rs index 0dead6b..ebc80f6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,109 +1,66 @@ -mod error; -mod query; -mod receiver; -mod stream; -mod user; -mod utils; +//! Streaming server for Mastodon +//! +//! +//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server +//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to +//! connect to the server with the API described [in the public API +//! documentation](https://docs.joinmastodon.org/api/streaming/) +//! +//! # Notes on data flow +//! * **Client Request → Warp**: +//! Warp filters for valid requests and parses request data. Based on that data, it repeatedly polls +//! the StreamManager +//! +//! * **Warp → StreamManager**: +//! The StreamManager consults a hash table to see if there is a currently open PubSub channel. If +//! there is, it uses that channel; if not, it (synchronously) sends a subscribe command to Redis. +//! The StreamManager polls the Receiver, providing info about which StreamManager it is that is +//! doing the polling. The stream manager is also responsible for monitoring the hash table to see +//! if it should unsubscribe from any channels and, if necessary, sending the unsubscribe command. +//! +//! * **StreamManger → Receiver**: +//! The Receiver receives data from Redis and stores it in a series of queues (one for each +//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages +//! relevant to that StreamManager and removes them from the queue. + +pub mod error; +pub mod query; +pub mod receiver; +pub mod stream; +pub mod timeline; +pub mod user; use futures::stream::Stream; use receiver::Receiver; use stream::StreamManager; -use user::{Filter, Scope, User}; -use warp::{path, Filter as WarpFilter}; +use user::{Filter, User}; +use warp::Filter as WarpFilter; fn main() { pretty_env_logger::init(); - // GET /api/v1/streaming/user [private; language filter] - let user_timeline = path!("api" / "v1" / "streaming" / "user") - .and(path::end()) - .and(user::get_access_token(Scope::Private)) - .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| (user.id.to_string(), user)); - - // GET /api/v1/streaming/user/notification [private; notification filter] - let user_timeline_notifications = path!("api" / "v1" / "streaming" / "user" / "notification") - .and(path::end()) - .and(user::get_access_token(Scope::Private)) - .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| (user.id.to_string(), user.with_notification_filter())); - - // GET /api/v1/streaming/public [public; language filter] - let public_timeline = path!("api" / "v1" / "streaming" / "public") - .and(path::end()) - .and(user::get_access_token(user::Scope::Public)) - .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| ("public".to_owned(), user.with_language_filter())); - - // GET /api/v1/streaming/public?only_media=true [public; language filter] - let public_timeline_media = path!("api" / "v1" / "streaming" / "public") - .and(path::end()) - .and(user::get_access_token(user::Scope::Public)) - .and_then(|token| user::get_account(token, Scope::Public)) - .and(warp::query()) - .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:media".to_owned(), user.with_language_filter()), - _ => ("public".to_owned(), user.with_language_filter()), - }); - - // GET /api/v1/streaming/public/local [public; language filter] - let local_timeline = path!("api" / "v1" / "streaming" / "public" / "local") - .and(path::end()) - .and(user::get_access_token(user::Scope::Public)) - .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| ("public:local".to_owned(), user.with_language_filter())); - - // GET /api/v1/streaming/public/local?only_media=true [public; language filter] - let local_timeline_media = path!("api" / "v1" / "streaming" / "public" / "local") - .and(user::get_access_token(user::Scope::Public)) - .and_then(|token| user::get_account(token, Scope::Public)) - .and(warp::query()) - .and(path::end()) - .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:local:media".to_owned(), user.with_language_filter()), - _ => ("public:local".to_owned(), user.with_language_filter()), - }); - - // GET /api/v1/streaming/direct [private; *no* filter] - let direct_timeline = path!("api" / "v1" / "streaming" / "direct") - .and(path::end()) - .and(user::get_access_token(Scope::Private)) - .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| (format!("direct:{}", user.id), user.with_no_filter())); - - // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] - let hashtag_timeline = path!("api" / "v1" / "streaming" / "hashtag") - .and(warp::query()) - .and(path::end()) - .map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())); - - // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] - let hashtag_timeline_local = path!("api" / "v1" / "streaming" / "hashtag" / "local") - .and(warp::query()) - .and(path::end()) - .map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())); - - // GET /api/v1/streaming/list?list=:list_id [private; no filter] - let list_timeline = path!("api" / "v1" / "streaming" / "list") - .and(user::get_access_token(Scope::Private)) - .and_then(|token| user::get_account(token, Scope::Private)) - .and(warp::query()) - .and_then(|user: User, q: query::List| (user.is_authorized_for_list(q.list), Ok(user))) - .untuple_one() - .and(path::end()) - .map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter())); - let redis_updates = StreamManager::new(Receiver::new()); - let routes = or!( - user_timeline, - user_timeline_notifications, - public_timeline_media, - public_timeline, - local_timeline_media, - local_timeline, - direct_timeline, - hashtag_timeline, - hashtag_timeline_local, - list_timeline + + let routes = any_of!( + // GET /api/v1/streaming/user/notification [private; notification filter] + timeline::user_notifications(), + // GET /api/v1/streaming/user [private; language filter] + timeline::user(), + // GET /api/v1/streaming/public/local?only_media=true [public; language filter] + timeline::public_local_media(), + // GET /api/v1/streaming/public?only_media=true [public; language filter] + timeline::public_media(), + // GET /api/v1/streaming/public/local [public; language filter] + timeline::public_local(), + // GET /api/v1/streaming/public [public; language filter] + timeline::public(), + // GET /api/v1/streaming/direct [private; *no* filter] + timeline::direct(), + // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] + timeline::hashtag(), + // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] + timeline::hashtag_local(), + // GET /api/v1/streaming/list?list=:list_id [private; no filter] + timeline::list() ) .untuple_one() .and(warp::sse()) diff --git a/src/query.rs b/src/query.rs index 9c019b4..4b8aebc 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,3 +1,4 @@ +//! Validate query prarams with type checking use serde_derive::Deserialize; #[derive(Deserialize, Debug)] diff --git a/src/receiver.rs b/src/receiver.rs index f3d400f..cef6355 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -1,3 +1,4 @@ +//! Interfacing with Redis and stream the results on to the `StreamManager` use crate::user::User; use futures::stream::Stream; use futures::{Async, Poll}; @@ -12,6 +13,7 @@ use std::io::{Read, Write}; use std::net::TcpStream; use std::time::Duration; +/// The item that streams from Redis and is polled by the `StreamManger` #[derive(Debug)] pub struct Receiver { stream: TcpStream, @@ -35,22 +37,25 @@ impl Receiver { msg_queue: HashMap::new(), } } + /// Update the `StreamManager` that is currently polling the `Receiver` pub fn set_polled_by(&mut self, id: Uuid) -> &Self { self.polled_by = id; self } + /// Send a subscribe command to the Redis PubSub pub fn subscribe(&mut self, tl: &str) { let subscribe_cmd = redis_cmd_from("subscribe", &tl); info!("Subscribing to {}", &tl); self.stream - .write(&subscribe_cmd) + .write_all(&subscribe_cmd) .expect("Can subscribe to Redis"); } + /// Send an unsubscribe command to the Redis PubSub pub fn unsubscribe(&mut self, tl: &str) { let unsubscribe_cmd = redis_cmd_from("unsubscribe", &tl); info!("Subscribing to {}", &tl); self.stream - .write(&unsubscribe_cmd) + .write_all(&unsubscribe_cmd) .expect("Can unsubscribe from Redis"); } } diff --git a/src/stream.rs b/src/stream.rs index 55c295c..2b670f5 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,3 +1,4 @@ +//! Manage all existing Redis PubSub connection use crate::receiver::Receiver; use crate::user::User; use futures::stream::Stream; @@ -9,6 +10,7 @@ use std::time::Instant; use tokio::io::Error; use uuid::Uuid; +/// Struct for manageing all Redis streams #[derive(Clone)] pub struct StreamManager { receiver: Arc>, @@ -26,11 +28,16 @@ impl StreamManager { } } + /// Clone the StreamManager with a new unique id pub fn new_copy(&self) -> Self { let id = Uuid::new_v4(); StreamManager { id, ..self.clone() } } + /// Subscribe to a channel if not already subscribed + /// + /// + /// `.add()` also unsubscribes from any channels that no longer have clients pub fn add(&mut self, timeline: &str, _user: &User) -> &Self { let mut subscriptions = self.subscriptions.lock().expect("No other thread panic"); let mut receiver = self.receiver.lock().unwrap(); diff --git a/src/timeline.rs b/src/timeline.rs new file mode 100644 index 0000000..ee2530f --- /dev/null +++ b/src/timeline.rs @@ -0,0 +1,496 @@ +//! Filters for all the endpoints accessible for Server Sent Event updates +use crate::query; +use crate::user::{Scope, User}; +use warp::filters::BoxedFilter; +use warp::{path, Filter}; + +#[allow(dead_code)] +type TimelineUser = ((String, User),); + +/// GET /api/v1/streaming/user +/// +/// +/// **private**. Filter: `Language` +pub fn user() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "user") + .and(path::end()) + .and(Scope::Private.get_access_token()) + .and_then(|token| User::from_access_token(token, Scope::Private)) + .map(|user: User| (user.id.to_string(), user)) + .boxed() +} + +/// GET /api/v1/streaming/user/notification +/// +/// +/// **private**. Filter: `Notification` +/// +/// +/// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented? +pub fn user_notifications() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "user" / "notification") + .and(path::end()) + .and(Scope::Private.get_access_token()) + .and_then(|token| User::from_access_token(token, Scope::Private)) + .map(|user: User| (user.id.to_string(), user.with_notification_filter())) + .boxed() +} + +/// GET /api/v1/streaming/public +/// +/// +/// **public**. Filter: `Language` +pub fn public() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "public") + .and(path::end()) + .and(Scope::Public.get_access_token()) + .and_then(|token| User::from_access_token(token, Scope::Public)) + .map(|user: User| ("public".to_owned(), user.with_language_filter())) + .boxed() +} + +/// GET /api/v1/streaming/public?only_media=true +/// +/// +/// **public**. Filter: `Language` +pub fn public_media() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "public") + .and(path::end()) + .and(Scope::Public.get_access_token()) + .and_then(|token| User::from_access_token(token, Scope::Public)) + .and(warp::query()) + .map(|user: User, q: query::Media| match q.only_media.as_ref() { + "1" | "true" => ("public:media".to_owned(), user.with_language_filter()), + _ => ("public".to_owned(), user.with_language_filter()), + }) + .boxed() +} + +/// GET /api/v1/streaming/public/local +/// +/// +/// **public**. Filter: `Language` +pub fn public_local() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "public" / "local") + .and(path::end()) + .and(Scope::Public.get_access_token()) + .and_then(|token| User::from_access_token(token, Scope::Public)) + .map(|user: User| ("public:local".to_owned(), user.with_language_filter())) + .boxed() +} + +/// GET /api/v1/streaming/public/local?only_media=true +/// +/// +/// **public**. Filter: `Language` +pub fn public_local_media() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "public" / "local") + .and(Scope::Public.get_access_token()) + .and_then(|token| User::from_access_token(token, Scope::Public)) + .and(warp::query()) + .and(path::end()) + .map(|user: User, q: query::Media| match q.only_media.as_ref() { + "1" | "true" => ("public:local:media".to_owned(), user.with_language_filter()), + _ => ("public:local".to_owned(), user.with_language_filter()), + }) + .boxed() +} + +/// GET /api/v1/streaming/direct +/// +/// +/// **private**. Filter: `None` +pub fn direct() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "direct") + .and(path::end()) + .and(Scope::Private.get_access_token()) + .and_then(|token| User::from_access_token(token, Scope::Private)) + .map(|user: User| (format!("direct:{}", user.id), user.with_no_filter())) + .boxed() +} + +/// GET /api/v1/streaming/hashtag?tag=:hashtag +/// +/// +/// **public**. Filter: `None` +pub fn hashtag() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "hashtag") + .and(warp::query()) + .and(path::end()) + .map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())) + .boxed() +} + +/// GET /api/v1/streaming/hashtag/local?tag=:hashtag +/// +/// +/// **public**. Filter: `None` +pub fn hashtag_local() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "hashtag" / "local") + .and(warp::query()) + .and(path::end()) + .map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())) + .boxed() +} + +/// GET /api/v1/streaming/list?list=:list_id +/// +/// +/// **private**. Filter: `None` +pub fn list() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "list") + .and(Scope::Private.get_access_token()) + .and_then(|token| User::from_access_token(token, Scope::Private)) + .and(warp::query()) + .and_then(|user: User, q: query::List| (user.is_authorized_for_list(q.list), Ok(user))) + .untuple_one() + .and(path::end()) + .map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter())) + .boxed() +} + +/// Combines multiple routes with the same return type together with +/// `or()` and `unify()` +#[macro_export] +macro_rules! any_of { + ($filter:expr, $($other_filter:expr),*) => { + $filter$(.or($other_filter).unify())* + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::user; + + #[test] + fn user_unauthorized() { + let value = warp::test::request() + .path(&format!( + "/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1", + )) + .filter(&user()); + assert!(invalid_access_token(value)); + + let value = warp::test::request() + .path(&format!("/api/v1/streaming/user",)) + .filter(&user()); + assert!(no_access_token(value)); + } + + #[test] + #[ignore] + fn user_auth() { + let user_id: i64 = 1; + let access_token = get_access_token(user_id); + + // Query auth + let (actual_timeline, actual_user) = warp::test::request() + .path(&format!( + "/api/v1/streaming/user?access_token={}", + access_token + )) + .filter(&user()) + .unwrap(); + + let expected_user = + User::from_access_token(access_token.clone(), user::Scope::Private).unwrap(); + + assert_eq!(actual_timeline, "1"); + assert_eq!(actual_user, expected_user); + + // Header auth + let (actual_timeline, actual_user) = warp::test::request() + .path("/api/v1/streaming/user") + .header("Authorization", format!("Bearer: {}", access_token.clone())) + .filter(&user()) + .unwrap(); + + let expected_user = User::from_access_token(access_token, user::Scope::Private).unwrap(); + + assert_eq!(actual_timeline, "1"); + assert_eq!(actual_user, expected_user); + } + + #[test] + fn user_notifications_unauthorized() { + let value = warp::test::request() + .path(&format!( + "/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN", + )) + .filter(&user_notifications()); + assert!(invalid_access_token(value)); + + let value = warp::test::request() + .path(&format!("/api/v1/streaming/user/notification",)) + .filter(&user_notifications()); + assert!(no_access_token(value)); + } + + #[test] + #[ignore] + fn user_notifications_auth() { + let user_id: i64 = 1; + let access_token = get_access_token(user_id); + + // Query auth + let (actual_timeline, actual_user) = warp::test::request() + .path(&format!( + "/api/v1/streaming/user/notification?access_token={}", + access_token + )) + .filter(&user_notifications()) + .unwrap(); + + let expected_user = User::from_access_token(access_token.clone(), user::Scope::Private) + .unwrap() + .with_notification_filter(); + + assert_eq!(actual_timeline, "1"); + assert_eq!(actual_user, expected_user); + + // Header auth + let (actual_timeline, actual_user) = warp::test::request() + .path("/api/v1/streaming/user/notification") + .header("Authorization", format!("Bearer: {}", access_token.clone())) + .filter(&user_notifications()) + .unwrap(); + + let expected_user = User::from_access_token(access_token, user::Scope::Private) + .unwrap() + .with_notification_filter(); + + assert_eq!(actual_timeline, "1"); + assert_eq!(actual_user, expected_user); + } + #[test] + fn public_timeline() { + let value = warp::test::request() + .path("/api/v1/streaming/public") + .filter(&public()) + .unwrap(); + + assert_eq!(value.0, "public".to_string()); + assert_eq!(value.1, User::public().with_language_filter()); + } + + #[test] + fn public_media_timeline() { + let value = warp::test::request() + .path("/api/v1/streaming/public?only_media=true") + .filter(&public_media()) + .unwrap(); + + assert_eq!(value.0, "public:media".to_string()); + assert_eq!(value.1, User::public().with_language_filter()); + + let value = warp::test::request() + .path("/api/v1/streaming/public?only_media=1") + .filter(&public_media()) + .unwrap(); + + assert_eq!(value.0, "public:media".to_string()); + assert_eq!(value.1, User::public().with_language_filter()); + } + + #[test] + fn public_local_timeline() { + let value = warp::test::request() + .path("/api/v1/streaming/public/local") + .filter(&public_local()) + .unwrap(); + + assert_eq!(value.0, "public:local".to_string()); + assert_eq!(value.1, User::public().with_language_filter()); + } + + #[test] + fn public_local_media_timeline() { + let value = warp::test::request() + .path("/api/v1/streaming/public/local?only_media=true") + .filter(&public_local_media()) + .unwrap(); + + assert_eq!(value.0, "public:local:media".to_string()); + assert_eq!(value.1, User::public().with_language_filter()); + + let value = warp::test::request() + .path("/api/v1/streaming/public/local?only_media=1") + .filter(&public_local_media()) + .unwrap(); + + assert_eq!(value.0, "public:local:media".to_string()); + assert_eq!(value.1, User::public().with_language_filter()); + } + + #[test] + fn direct_timeline_unauthorized() { + let value = warp::test::request() + .path(&format!( + "/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN", + )) + .filter(&direct()); + assert!(invalid_access_token(value)); + + let value = warp::test::request() + .path(&format!("/api/v1/streaming/direct",)) + .filter(&direct()); + assert!(no_access_token(value)); + } + + #[test] + #[ignore] + fn direct_timeline_auth() { + let user_id: i64 = 1; + let access_token = get_access_token(user_id); + + // Query auth + let (actual_timeline, actual_user) = warp::test::request() + .path(&format!( + "/api/v1/streaming/direct?access_token={}", + access_token + )) + .filter(&direct()) + .unwrap(); + + let expected_user = + User::from_access_token(access_token.clone(), user::Scope::Private).unwrap(); + + assert_eq!(actual_timeline, "direct:1"); + assert_eq!(actual_user, expected_user); + + // Header auth + let (actual_timeline, actual_user) = warp::test::request() + .path("/api/v1/streaming/direct") + .header("Authorization", format!("Bearer: {}", access_token.clone())) + .filter(&direct()) + .unwrap(); + + let expected_user = User::from_access_token(access_token, user::Scope::Private).unwrap(); + + assert_eq!(actual_timeline, "direct:1"); + assert_eq!(actual_user, expected_user); + } + + #[test] + fn hashtag_timeline() { + let value = warp::test::request() + .path("/api/v1/streaming/hashtag?tag=a") + .filter(&hashtag()) + .unwrap(); + + assert_eq!(value.0, "hashtag:a".to_string()); + assert_eq!(value.1, User::public()); + } + + #[test] + fn hashtag_timeline_local() { + let value = warp::test::request() + .path("/api/v1/streaming/hashtag/local?tag=a") + .filter(&hashtag_local()) + .unwrap(); + + assert_eq!(value.0, "hashtag:a:local".to_string()); + assert_eq!(value.1, User::public()); + } + + #[test] + #[ignore] + fn list_timeline_auth() { + let list_id = 1; + let list_owner_id = get_list_owner(list_id); + let access_token = get_access_token(list_owner_id); + + // Query Auth + let (actual_timeline, actual_user) = warp::test::request() + .path(&format!( + "/api/v1/streaming/list?access_token={}&list={}", + access_token, list_id, + )) + .filter(&list()) + .unwrap(); + + let expected_user = + User::from_access_token(access_token.clone(), user::Scope::Private).unwrap(); + + assert_eq!(actual_timeline, "list:1"); + assert_eq!(actual_user, expected_user); + + // Header Auth + let (actual_timeline, actual_user) = warp::test::request() + .path("/api/v1/streaming/list?list=1") + .header("Authorization", format!("Bearer: {}", access_token.clone())) + .filter(&list()) + .unwrap(); + + let expected_user = User::from_access_token(access_token, user::Scope::Private).unwrap(); + + assert_eq!(actual_timeline, "list:1"); + assert_eq!(actual_user, expected_user); + } + + #[test] + fn list_timeline_unauthorized() { + let value = warp::test::request() + .path(&format!( + "/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1", + )) + .filter(&list()); + assert!(invalid_access_token(value)); + + let value = warp::test::request() + .path(&format!("/api/v1/streaming/list?list=1",)) + .filter(&list()); + assert!(no_access_token(value)); + } + + fn get_list_owner(list_number: i32) -> i64 { + let list_number: i64 = list_number.into(); + let conn = user::connect_to_postgres(); + let rows = &conn + .query( + "SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", + &[&list_number], + ) + .unwrap(); + + assert_eq!( + rows.len(), + 1, + "Test database must contain at least one user with a list to run this test." + ); + + rows.get(0).get(1) + } + fn get_access_token(user_id: i64) -> String { + let conn = user::connect_to_postgres(); + let rows = &conn + .query( + "SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1", + &[&user_id], + ) + .expect("Can get access token from id"); + rows.get(0).get(0) + } + fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool { + match value { + Err(error) => match error.cause() { + Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => { + true + } + _ => false, + }, + _ => false, + } + } + + fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool { + match value { + Err(error) => match error.cause() { + Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true, + _ => false, + }, + _ => false, + } + } +} diff --git a/src/user.rs b/src/user.rs index d766410..f541275 100644 --- a/src/user.rs +++ b/src/user.rs @@ -1,36 +1,28 @@ -use crate::{or, query}; +//! Create a User by querying the Postgres database with the user's access_token +use crate::{any_of, query}; +use log::info; use postgres; use warp::Filter as WarpFilter; -pub fn get_access_token(scope: Scope) -> warp::filters::BoxedFilter<(String,)> { - let token_from_header = warp::header::header::("authorization") - .map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string()); - let token_from_query = warp::query().map(|q: query::Auth| q.access_token); - let public = warp::any().map(|| "no access token".to_string()); - - match scope { - // if they're trying to access a private scope without an access token, reject the request - Scope::Private => or!(token_from_query, token_from_header).boxed(), - // if they're trying to access a public scope without an access token, proceed - Scope::Public => or!(token_from_query, token_from_header, public).boxed(), - } -} - -fn conn() -> postgres::Connection { +/// (currently hardcoded to localhost) +pub fn connect_to_postgres() -> postgres::Connection { postgres::Connection::connect( "postgres://dsock@localhost/mastodon_development", postgres::TlsMode::None, ) .unwrap() } -#[derive(Clone, Debug)] + +/// The filters that can be applied to toots after they come from Redis +#[derive(Clone, Debug, PartialEq)] pub enum Filter { None, Language, Notification, } -#[derive(Clone, Debug)] +/// The User (with data read from Postgres) +#[derive(Clone, Debug, PartialEq)] pub struct User { pub id: i64, pub langs: Option>, @@ -38,26 +30,70 @@ pub struct User { pub filter: Filter, } impl User { + /// Create a user from the access token supplied in the header or query paramaters + pub fn from_access_token(token: String, scope: Scope) -> Result { + let conn = connect_to_postgres(); + let result = &conn + .query( + " +SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages +FROM +oauth_access_tokens +INNER JOIN users ON +oauth_access_tokens.resource_owner_id = users.id +WHERE oauth_access_tokens.token = $1 +AND oauth_access_tokens.revoked_at IS NULL +LIMIT 1", + &[&token], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + if !result.is_empty() { + let only_row = result.get(0); + let id: i64 = only_row.get(1); + let langs: Option> = only_row.get(2); + info!("Granting logged-in access"); + Ok(User { + id, + langs, + logged_in: true, + filter: Filter::None, + }) + } else if let Scope::Public = scope { + info!("Granting public access"); + Ok(User { + id: -1, + langs: None, + logged_in: false, + filter: Filter::None, + }) + } else { + Err(warp::reject::custom("Error: Invalid access token")) + } + } + /// Add a Notification filter pub fn with_notification_filter(self) -> Self { Self { filter: Filter::Notification, ..self } } + /// Add a Language filter pub fn with_language_filter(self) -> Self { Self { filter: Filter::Language, ..self } } + /// Remove all filters pub fn with_no_filter(self) -> Self { Self { filter: Filter::None, ..self } } + /// Determine whether the User is authorised for a specified list pub fn is_authorized_for_list(&self, list: i64) -> Result { - let conn = conn(); + let conn = connect_to_postgres(); // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn .query( @@ -74,6 +110,7 @@ impl User { Err(warp::reject::custom("Error: Invalid access token")) } + /// A public (non-authenticated) User pub fn public() -> Self { User { id: -1, @@ -84,46 +121,23 @@ impl User { } } +/// Whether the endpoint requires authentication or not pub enum Scope { Public, Private, } -pub fn get_account(token: String, scope: Scope) -> Result { - let conn = conn(); - let result = &conn - .query( - " -SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages -FROM -oauth_access_tokens -INNER JOIN users ON -oauth_access_tokens.resource_owner_id = users.id -WHERE oauth_access_tokens.token = $1 -AND oauth_access_tokens.revoked_at IS NULL -LIMIT 1", - &[&token], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - if !result.is_empty() { - let only_row = result.get(0); - let id: i64 = only_row.get(1); - let langs: Option> = only_row.get(2); - println!("Granting logged-in access"); - Ok(User { - id, - langs, - logged_in: true, - filter: Filter::None, - }) - } else if let Scope::Public = scope { - println!("Granting public access"); - Ok(User { - id: -1, - langs: None, - logged_in: false, - filter: Filter::None, - }) - } else { - Err(warp::reject::custom("Error: Invalid access token")) +impl Scope { + pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> { + let token_from_header = warp::header::header::("authorization") + .map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string()); + let token_from_query = warp::query().map(|q: query::Auth| q.access_token); + let public = warp::any().map(|| "no access token".to_string()); + + match self { + // if they're trying to access a private scope without an access token, reject the request + Scope::Private => any_of!(token_from_query, token_from_header).boxed(), + // if they're trying to access a public scope without an access token, proceed + Scope::Public => any_of!(token_from_query, token_from_header, public).boxed(), + } } } diff --git a/src/utils.rs b/src/utils.rs index 8c59e43..8b13789 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,8 +1 @@ -/// Combines multiple routes with the same return type together with -/// `or()` and `unify()` -#[macro_export] -macro_rules! or { - ($filter:expr, $($other_filter:expr),*) => { - $filter$(.or($other_filter).unify())* - }; -} +