Add unit tests, (some) integration tests, and documentation

This commit is contained in:
Daniel Sockwell 2019-04-30 18:41:13 -04:00
parent ae08218c0f
commit 4649f89442
9 changed files with 642 additions and 177 deletions

View File

@ -1,15 +1,4 @@
# RageQuit
A blazingly fast drop-in replacement for the Mastodon streaming api server
A WIP blazingly fast drop-in replacement for the Mastodon streaming api server.
## Notes on data flow
The current structure of the app is as follows:
Client Request --> Warp
Warp filters for valid requests and parses request data. Based on that data, it repeatedly polls the StreamManager
Warp --> StreamManager
The StreamManager consults a hash table to see if there is a currently open PubSub channel. If there is, it uses that channel; if not, it (synchronously) sends a subscribe command to Redis. The StreamManager polls the Receiver, providing info about which StreamManager it is that is doing the polling. The stream manager is also responsible for monitoring the hash table to see if it should unsubscribe from any channels and, if necessary, sending the unsubscribe command.
StreamManger --> Receiver
The Receiver receives data from Redis and stores it in a series of queues (one for each StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages relevant to that StreamManager and removes them from the queue.

View File

@ -1,6 +1,8 @@
//! Custom Errors and Warp::Rejections
use serde_derive::Serialize;
#[derive(Serialize)]
struct ErrorMessage {
pub struct ErrorMessage {
error: String,
}
impl ErrorMessage {
@ -11,6 +13,7 @@ impl ErrorMessage {
}
}
/// Recover from Errors by sending appropriate Warp::Rejections
pub fn handle_errors(
rejection: warp::reject::Rejection,
) -> Result<impl warp::Reply, warp::reject::Rejection> {

View File

@ -1,109 +1,66 @@
mod error;
mod query;
mod receiver;
mod stream;
mod user;
mod utils;
//! Streaming server for Mastodon
//!
//!
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server
//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to
//! connect to the server with the API described [in the public API
//! documentation](https://docs.joinmastodon.org/api/streaming/)
//!
//! # Notes on data flow
//! * **Client Request → Warp**:
//! Warp filters for valid requests and parses request data. Based on that data, it repeatedly polls
//! the StreamManager
//!
//! * **Warp → StreamManager**:
//! The StreamManager consults a hash table to see if there is a currently open PubSub channel. If
//! there is, it uses that channel; if not, it (synchronously) sends a subscribe command to Redis.
//! The StreamManager polls the Receiver, providing info about which StreamManager it is that is
//! doing the polling. The stream manager is also responsible for monitoring the hash table to see
//! if it should unsubscribe from any channels and, if necessary, sending the unsubscribe command.
//!
//! * **StreamManger → Receiver**:
//! The Receiver receives data from Redis and stores it in a series of queues (one for each
//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages
//! relevant to that StreamManager and removes them from the queue.
pub mod error;
pub mod query;
pub mod receiver;
pub mod stream;
pub mod timeline;
pub mod user;
use futures::stream::Stream;
use receiver::Receiver;
use stream::StreamManager;
use user::{Filter, Scope, User};
use warp::{path, Filter as WarpFilter};
use user::{Filter, User};
use warp::Filter as WarpFilter;
fn main() {
pretty_env_logger::init();
// GET /api/v1/streaming/user [private; language filter]
let user_timeline = path!("api" / "v1" / "streaming" / "user")
.and(path::end())
.and(user::get_access_token(Scope::Private))
.and_then(|token| user::get_account(token, Scope::Private))
.map(|user: User| (user.id.to_string(), user));
// GET /api/v1/streaming/user/notification [private; notification filter]
let user_timeline_notifications = path!("api" / "v1" / "streaming" / "user" / "notification")
.and(path::end())
.and(user::get_access_token(Scope::Private))
.and_then(|token| user::get_account(token, Scope::Private))
.map(|user: User| (user.id.to_string(), user.with_notification_filter()));
// GET /api/v1/streaming/public [public; language filter]
let public_timeline = path!("api" / "v1" / "streaming" / "public")
.and(path::end())
.and(user::get_access_token(user::Scope::Public))
.and_then(|token| user::get_account(token, Scope::Public))
.map(|user: User| ("public".to_owned(), user.with_language_filter()));
// GET /api/v1/streaming/public?only_media=true [public; language filter]
let public_timeline_media = path!("api" / "v1" / "streaming" / "public")
.and(path::end())
.and(user::get_access_token(user::Scope::Public))
.and_then(|token| user::get_account(token, Scope::Public))
.and(warp::query())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:media".to_owned(), user.with_language_filter()),
_ => ("public".to_owned(), user.with_language_filter()),
});
// GET /api/v1/streaming/public/local [public; language filter]
let local_timeline = path!("api" / "v1" / "streaming" / "public" / "local")
.and(path::end())
.and(user::get_access_token(user::Scope::Public))
.and_then(|token| user::get_account(token, Scope::Public))
.map(|user: User| ("public:local".to_owned(), user.with_language_filter()));
// GET /api/v1/streaming/public/local?only_media=true [public; language filter]
let local_timeline_media = path!("api" / "v1" / "streaming" / "public" / "local")
.and(user::get_access_token(user::Scope::Public))
.and_then(|token| user::get_account(token, Scope::Public))
.and(warp::query())
.and(path::end())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:local:media".to_owned(), user.with_language_filter()),
_ => ("public:local".to_owned(), user.with_language_filter()),
});
// GET /api/v1/streaming/direct [private; *no* filter]
let direct_timeline = path!("api" / "v1" / "streaming" / "direct")
.and(path::end())
.and(user::get_access_token(Scope::Private))
.and_then(|token| user::get_account(token, Scope::Private))
.map(|user: User| (format!("direct:{}", user.id), user.with_no_filter()));
// GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter]
let hashtag_timeline = path!("api" / "v1" / "streaming" / "hashtag")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public()));
// GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter]
let hashtag_timeline_local = path!("api" / "v1" / "streaming" / "hashtag" / "local")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public()));
// GET /api/v1/streaming/list?list=:list_id [private; no filter]
let list_timeline = path!("api" / "v1" / "streaming" / "list")
.and(user::get_access_token(Scope::Private))
.and_then(|token| user::get_account(token, Scope::Private))
.and(warp::query())
.and_then(|user: User, q: query::List| (user.is_authorized_for_list(q.list), Ok(user)))
.untuple_one()
.and(path::end())
.map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter()));
let redis_updates = StreamManager::new(Receiver::new());
let routes = or!(
user_timeline,
user_timeline_notifications,
public_timeline_media,
public_timeline,
local_timeline_media,
local_timeline,
direct_timeline,
hashtag_timeline,
hashtag_timeline_local,
list_timeline
let routes = any_of!(
// GET /api/v1/streaming/user/notification [private; notification filter]
timeline::user_notifications(),
// GET /api/v1/streaming/user [private; language filter]
timeline::user(),
// GET /api/v1/streaming/public/local?only_media=true [public; language filter]
timeline::public_local_media(),
// GET /api/v1/streaming/public?only_media=true [public; language filter]
timeline::public_media(),
// GET /api/v1/streaming/public/local [public; language filter]
timeline::public_local(),
// GET /api/v1/streaming/public [public; language filter]
timeline::public(),
// GET /api/v1/streaming/direct [private; *no* filter]
timeline::direct(),
// GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter]
timeline::hashtag(),
// GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter]
timeline::hashtag_local(),
// GET /api/v1/streaming/list?list=:list_id [private; no filter]
timeline::list()
)
.untuple_one()
.and(warp::sse())

View File

@ -1,3 +1,4 @@
//! Validate query prarams with type checking
use serde_derive::Deserialize;
#[derive(Deserialize, Debug)]

View File

@ -1,3 +1,4 @@
//! Interfacing with Redis and stream the results on to the `StreamManager`
use crate::user::User;
use futures::stream::Stream;
use futures::{Async, Poll};
@ -12,6 +13,7 @@ use std::io::{Read, Write};
use std::net::TcpStream;
use std::time::Duration;
/// The item that streams from Redis and is polled by the `StreamManger`
#[derive(Debug)]
pub struct Receiver {
stream: TcpStream,
@ -35,22 +37,25 @@ impl Receiver {
msg_queue: HashMap::new(),
}
}
/// Update the `StreamManager` that is currently polling the `Receiver`
pub fn set_polled_by(&mut self, id: Uuid) -> &Self {
self.polled_by = id;
self
}
/// Send a subscribe command to the Redis PubSub
pub fn subscribe(&mut self, tl: &str) {
let subscribe_cmd = redis_cmd_from("subscribe", &tl);
info!("Subscribing to {}", &tl);
self.stream
.write(&subscribe_cmd)
.write_all(&subscribe_cmd)
.expect("Can subscribe to Redis");
}
/// Send an unsubscribe command to the Redis PubSub
pub fn unsubscribe(&mut self, tl: &str) {
let unsubscribe_cmd = redis_cmd_from("unsubscribe", &tl);
info!("Subscribing to {}", &tl);
self.stream
.write(&unsubscribe_cmd)
.write_all(&unsubscribe_cmd)
.expect("Can unsubscribe from Redis");
}
}

View File

@ -1,3 +1,4 @@
//! Manage all existing Redis PubSub connection
use crate::receiver::Receiver;
use crate::user::User;
use futures::stream::Stream;
@ -9,6 +10,7 @@ use std::time::Instant;
use tokio::io::Error;
use uuid::Uuid;
/// Struct for manageing all Redis streams
#[derive(Clone)]
pub struct StreamManager {
receiver: Arc<Mutex<Receiver>>,
@ -26,11 +28,16 @@ impl StreamManager {
}
}
/// Clone the StreamManager with a new unique id
pub fn new_copy(&self) -> Self {
let id = Uuid::new_v4();
StreamManager { id, ..self.clone() }
}
/// Subscribe to a channel if not already subscribed
///
///
/// `.add()` also unsubscribes from any channels that no longer have clients
pub fn add(&mut self, timeline: &str, _user: &User) -> &Self {
let mut subscriptions = self.subscriptions.lock().expect("No other thread panic");
let mut receiver = self.receiver.lock().unwrap();

496
src/timeline.rs Normal file
View File

@ -0,0 +1,496 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
use crate::query;
use crate::user::{Scope, User};
use warp::filters::BoxedFilter;
use warp::{path, Filter};
#[allow(dead_code)]
type TimelineUser = ((String, User),);
/// GET /api/v1/streaming/user
///
///
/// **private**. Filter: `Language`
pub fn user() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "user")
.and(path::end())
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.map(|user: User| (user.id.to_string(), user))
.boxed()
}
/// GET /api/v1/streaming/user/notification
///
///
/// **private**. Filter: `Notification`
///
///
/// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented?
pub fn user_notifications() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "user" / "notification")
.and(path::end())
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.map(|user: User| (user.id.to_string(), user.with_notification_filter()))
.boxed()
}
/// GET /api/v1/streaming/public
///
///
/// **public**. Filter: `Language`
pub fn public() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public")
.and(path::end())
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.map(|user: User| ("public".to_owned(), user.with_language_filter()))
.boxed()
}
/// GET /api/v1/streaming/public?only_media=true
///
///
/// **public**. Filter: `Language`
pub fn public_media() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public")
.and(path::end())
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.and(warp::query())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:media".to_owned(), user.with_language_filter()),
_ => ("public".to_owned(), user.with_language_filter()),
})
.boxed()
}
/// GET /api/v1/streaming/public/local
///
///
/// **public**. Filter: `Language`
pub fn public_local() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public" / "local")
.and(path::end())
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.map(|user: User| ("public:local".to_owned(), user.with_language_filter()))
.boxed()
}
/// GET /api/v1/streaming/public/local?only_media=true
///
///
/// **public**. Filter: `Language`
pub fn public_local_media() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public" / "local")
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.and(warp::query())
.and(path::end())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:local:media".to_owned(), user.with_language_filter()),
_ => ("public:local".to_owned(), user.with_language_filter()),
})
.boxed()
}
/// GET /api/v1/streaming/direct
///
///
/// **private**. Filter: `None`
pub fn direct() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "direct")
.and(path::end())
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.map(|user: User| (format!("direct:{}", user.id), user.with_no_filter()))
.boxed()
}
/// GET /api/v1/streaming/hashtag?tag=:hashtag
///
///
/// **public**. Filter: `None`
pub fn hashtag() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "hashtag")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public()))
.boxed()
}
/// GET /api/v1/streaming/hashtag/local?tag=:hashtag
///
///
/// **public**. Filter: `None`
pub fn hashtag_local() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "hashtag" / "local")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public()))
.boxed()
}
/// GET /api/v1/streaming/list?list=:list_id
///
///
/// **private**. Filter: `None`
pub fn list() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "list")
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.and(warp::query())
.and_then(|user: User, q: query::List| (user.is_authorized_for_list(q.list), Ok(user)))
.untuple_one()
.and(path::end())
.map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter()))
.boxed()
}
/// Combines multiple routes with the same return type together with
/// `or()` and `unify()`
#[macro_export]
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*
};
}
#[cfg(test)]
mod tests {
use super::*;
use crate::user;
#[test]
fn user_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&user());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user",))
.filter(&user());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token={}",
access_token
))
.filter(&user())
.unwrap();
let expected_user =
User::from_access_token(access_token.clone(), user::Scope::Private).unwrap();
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&user())
.unwrap();
let expected_user = User::from_access_token(access_token, user::Scope::Private).unwrap();
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn user_notifications_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN",
))
.filter(&user_notifications());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user/notification",))
.filter(&user_notifications());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_notifications_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token={}",
access_token
))
.filter(&user_notifications())
.unwrap();
let expected_user = User::from_access_token(access_token.clone(), user::Scope::Private)
.unwrap()
.with_notification_filter();
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user/notification")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&user_notifications())
.unwrap();
let expected_user = User::from_access_token(access_token, user::Scope::Private)
.unwrap()
.with_notification_filter();
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn public_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public")
.filter(&public())
.unwrap();
assert_eq!(value.0, "public".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn public_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=true")
.filter(&public_media())
.unwrap();
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=1")
.filter(&public_media())
.unwrap();
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn public_local_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local")
.filter(&public_local())
.unwrap();
assert_eq!(value.0, "public:local".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn public_local_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=true")
.filter(&public_local_media())
.unwrap();
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=1")
.filter(&public_local_media())
.unwrap();
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn direct_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN",
))
.filter(&direct());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/direct",))
.filter(&direct());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn direct_timeline_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token={}",
access_token
))
.filter(&direct())
.unwrap();
let expected_user =
User::from_access_token(access_token.clone(), user::Scope::Private).unwrap();
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/direct")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&direct())
.unwrap();
let expected_user = User::from_access_token(access_token, user::Scope::Private).unwrap();
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn hashtag_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag?tag=a")
.filter(&hashtag())
.unwrap();
assert_eq!(value.0, "hashtag:a".to_string());
assert_eq!(value.1, User::public());
}
#[test]
fn hashtag_timeline_local() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag/local?tag=a")
.filter(&hashtag_local())
.unwrap();
assert_eq!(value.0, "hashtag:a:local".to_string());
assert_eq!(value.1, User::public());
}
#[test]
#[ignore]
fn list_timeline_auth() {
let list_id = 1;
let list_owner_id = get_list_owner(list_id);
let access_token = get_access_token(list_owner_id);
// Query Auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token={}&list={}",
access_token, list_id,
))
.filter(&list())
.unwrap();
let expected_user =
User::from_access_token(access_token.clone(), user::Scope::Private).unwrap();
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
// Header Auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/list?list=1")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&list())
.unwrap();
let expected_user = User::from_access_token(access_token, user::Scope::Private).unwrap();
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn list_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&list());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/list?list=1",))
.filter(&list());
assert!(no_access_token(value));
}
fn get_list_owner(list_number: i32) -> i64 {
let list_number: i64 = list_number.into();
let conn = user::connect_to_postgres();
let rows = &conn
.query(
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_number],
)
.unwrap();
assert_eq!(
rows.len(),
1,
"Test database must contain at least one user with a list to run this test."
);
rows.get(0).get(1)
}
fn get_access_token(user_id: i64) -> String {
let conn = user::connect_to_postgres();
let rows = &conn
.query(
"SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1",
&[&user_id],
)
.expect("Can get access token from id");
rows.get(0).get(0)
}
fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => {
true
}
_ => false,
},
_ => false,
}
}
fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true,
_ => false,
},
_ => false,
}
}
}

View File

@ -1,36 +1,28 @@
use crate::{or, query};
//! Create a User by querying the Postgres database with the user's access_token
use crate::{any_of, query};
use log::info;
use postgres;
use warp::Filter as WarpFilter;
pub fn get_access_token(scope: Scope) -> warp::filters::BoxedFilter<(String,)> {
let token_from_header = warp::header::header::<String>("authorization")
.map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string());
let token_from_query = warp::query().map(|q: query::Auth| q.access_token);
let public = warp::any().map(|| "no access token".to_string());
match scope {
// if they're trying to access a private scope without an access token, reject the request
Scope::Private => or!(token_from_query, token_from_header).boxed(),
// if they're trying to access a public scope without an access token, proceed
Scope::Public => or!(token_from_query, token_from_header, public).boxed(),
}
}
fn conn() -> postgres::Connection {
/// (currently hardcoded to localhost)
pub fn connect_to_postgres() -> postgres::Connection {
postgres::Connection::connect(
"postgres://dsock@localhost/mastodon_development",
postgres::TlsMode::None,
)
.unwrap()
}
#[derive(Clone, Debug)]
/// The filters that can be applied to toots after they come from Redis
#[derive(Clone, Debug, PartialEq)]
pub enum Filter {
None,
Language,
Notification,
}
#[derive(Clone, Debug)]
/// The User (with data read from Postgres)
#[derive(Clone, Debug, PartialEq)]
pub struct User {
pub id: i64,
pub langs: Option<Vec<String>>,
@ -38,26 +30,70 @@ pub struct User {
pub filter: Filter,
}
impl User {
/// Create a user from the access token supplied in the header or query paramaters
pub fn from_access_token(token: String, scope: Scope) -> Result<Self, warp::reject::Rejection> {
let conn = connect_to_postgres();
let result = &conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&token],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if !result.is_empty() {
let only_row = result.get(0);
let id: i64 = only_row.get(1);
let langs: Option<Vec<String>> = only_row.get(2);
info!("Granting logged-in access");
Ok(User {
id,
langs,
logged_in: true,
filter: Filter::None,
})
} else if let Scope::Public = scope {
info!("Granting public access");
Ok(User {
id: -1,
langs: None,
logged_in: false,
filter: Filter::None,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
/// Add a Notification filter
pub fn with_notification_filter(self) -> Self {
Self {
filter: Filter::Notification,
..self
}
}
/// Add a Language filter
pub fn with_language_filter(self) -> Self {
Self {
filter: Filter::Language,
..self
}
}
/// Remove all filters
pub fn with_no_filter(self) -> Self {
Self {
filter: Filter::None,
..self
}
}
/// Determine whether the User is authorised for a specified list
pub fn is_authorized_for_list(&self, list: i64) -> Result<i64, warp::reject::Rejection> {
let conn = conn();
let conn = connect_to_postgres();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
@ -74,6 +110,7 @@ impl User {
Err(warp::reject::custom("Error: Invalid access token"))
}
/// A public (non-authenticated) User
pub fn public() -> Self {
User {
id: -1,
@ -84,46 +121,23 @@ impl User {
}
}
/// Whether the endpoint requires authentication or not
pub enum Scope {
Public,
Private,
}
pub fn get_account(token: String, scope: Scope) -> Result<User, warp::reject::Rejection> {
let conn = conn();
let result = &conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&token],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if !result.is_empty() {
let only_row = result.get(0);
let id: i64 = only_row.get(1);
let langs: Option<Vec<String>> = only_row.get(2);
println!("Granting logged-in access");
Ok(User {
id,
langs,
logged_in: true,
filter: Filter::None,
})
} else if let Scope::Public = scope {
println!("Granting public access");
Ok(User {
id: -1,
langs: None,
logged_in: false,
filter: Filter::None,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
impl Scope {
pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> {
let token_from_header = warp::header::header::<String>("authorization")
.map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string());
let token_from_query = warp::query().map(|q: query::Auth| q.access_token);
let public = warp::any().map(|| "no access token".to_string());
match self {
// if they're trying to access a private scope without an access token, reject the request
Scope::Private => any_of!(token_from_query, token_from_header).boxed(),
// if they're trying to access a public scope without an access token, proceed
Scope::Public => any_of!(token_from_query, token_from_header, public).boxed(),
}
}
}

View File

@ -1,8 +1 @@
/// Combines multiple routes with the same return type together with
/// `or()` and `unify()`
#[macro_export]
macro_rules! or {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*
};
}