Initial cleanup/refactor

This commit is contained in:
Daniel Sockwell 2019-07-05 20:08:50 -04:00
parent f3b86ddac8
commit 1732008840
13 changed files with 1004 additions and 887 deletions

54
src/config.rs Normal file
View File

@ -0,0 +1,54 @@
//! Configuration settings for servers and databases
use dotenv::dotenv;
use log::warn;
use std::{env, net, time};
/// Configure CORS for the API server
pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors {
warp::cors()
.allow_any_origin()
.allow_methods(vec!["GET", "OPTIONS"])
.allow_headers(vec!["Authorization", "Accept", "Cache-Control"])
}
/// Initialize logging and read values from `src/.env`
pub fn logging_and_env() {
pretty_env_logger::init();
dotenv().ok();
}
/// Configure Postgres and return a connection
pub fn postgres() -> postgres::Connection {
let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or_else(|_| {
format!(
"postgres://{}@localhost/mastodon_development",
env::var("USER").unwrap_or_else(|_| {
warn!("No USER env variable set. Connecting to Postgress with default `postgres` user");
"postgres".to_owned()
})
)
});
postgres::Connection::connect(postgres_addr, postgres::TlsMode::None)
.expect("Can connect to local Postgres")
}
pub fn redis_addr() -> (net::TcpStream, net::TcpStream) {
let redis_addr = env::var("REDIS_ADDR").unwrap_or_else(|_| "127.0.0.1:6379".to_string());
let pubsub_connection = net::TcpStream::connect(&redis_addr).expect("Can connect to Redis");
pubsub_connection
.set_read_timeout(Some(time::Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
let secondary_redis_connection =
net::TcpStream::connect(&redis_addr).expect("Can connect to Redis");
secondary_redis_connection
.set_read_timeout(Some(time::Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
(pubsub_connection, secondary_redis_connection)
}
pub fn socket_address() -> net::SocketAddr {
env::var("SERVER_ADDR")
.unwrap_or_else(|_| "127.0.0.1:4000".to_owned())
.parse()
.expect("static string")
}

View File

@ -30,3 +30,7 @@ pub fn handle_errors(
warp::http::StatusCode::UNAUTHORIZED,
))
}
pub fn unauthorized_list() -> warp::reject::Rejection {
warp::reject::custom("Error: Access to list not authorized")
}

38
src/lib.rs Normal file
View File

@ -0,0 +1,38 @@
//! Streaming server for Mastodon
//!
//!
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server
//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to
//! connect to the server with the API described [in Mastodon's public API
//! documentation](https://docs.joinmastodon.org/api/streaming/).
//!
//! # Notes on data flow
//! * **Client Request → Warp**:
//! Warp filters for valid requests and parses request data. Based on that data, it generates a `User`
//! representing the client that made the request with data from the client's request and from
//! Postgres. The `User` is authenticated, if appropriate. Warp //! repeatedly polls the
//! StreamManager for information relevant to the User.
//!
//! * **Warp → StreamManager**:
//! A new `StreamManager` is created for each request. The `StreamManager` exists to manage concurrent
//! access to the (single) `Receiver`, which it can access behind an `Arc<Mutex>`. The `StreamManager`
//! polls the `Receiver` for any updates relevant to the current client. If there are updates, the
//! `StreamManager` filters them with the client's filters and passes any matching updates up to Warp.
//! The `StreamManager` is also responsible for sending `subscribe` commands to Redis (via the
//! `Receiver`) when necessary.
//!
//! * **StreamManager → Receiver**:
//! The Receiver receives data from Redis and stores it in a series of queues (one for each
//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages
//! relevant to that StreamManager and removes them from the queue.
pub mod config;
pub mod error;
pub mod postgres;
pub mod query;
pub mod receiver;
pub mod redis_cmd;
pub mod stream_manager;
pub mod timeline;
pub mod user;
pub mod ws;

View File

@ -1,58 +1,20 @@
//! Streaming server for Mastodon
//!
//!
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server
//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to
//! connect to the server with the API described [in the public API
//! documentation](https://docs.joinmastodon.org/api/streaming/)
//!
//! # Notes on data flow
//! * **Client Request → Warp**:
//! Warp filters for valid requests and parses request data. Based on that data, it generates a `User`
//! representing the client that made the request. The `User` is authenticated, if appropriate. Warp
//! repeatedly polls the StreamManager for information relevant to the User.
//!
//! * **Warp → StreamManager**:
//! A new `StreamManager` is created for each request. The `StreamManager` exists to manage concurrent
//! access to the (single) `Receiver`, which it can access behind an `Arc<Mutex>`. The `StreamManager`
//! polles the `Receiver` for any updates relvant to the current client. If there are updates, the
//! `StreamManager` filters them with the client's filters and passes any matching updates up to Warp.
//! The `StreamManager` is also responsible for sending `subscribe` commands to Redis (via the
//! `Receiver`) when necessary.
//!
//! * **StreamManger → Receiver**:
//! The Receiver receives data from Redis and stores it in a series of queues (one for each
//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages
//! relevant to that StreamManager and removes them from the queue.
pub mod error;
pub mod query;
pub mod receiver;
pub mod redis_cmd;
pub mod stream;
pub mod timeline;
pub mod user;
pub mod ws;
use dotenv::dotenv;
use futures::stream::Stream;
use futures::Async;
use receiver::Receiver;
use std::env;
use std::net::SocketAddr;
use stream::StreamManager;
use user::{OauthScope::*, Scope, User};
use warp::path;
use warp::Filter as WarpFilter;
use futures::{stream::Stream, Async};
use ragequit::{
any_of, config, error,
stream_manager::StreamManager,
timeline,
user::{Filter::*, User},
ws,
};
use warp::{ws::Ws2, Filter as WarpFilter};
fn main() {
pretty_env_logger::init();
dotenv().ok();
config::logging_and_env();
let stream_manager_sse = StreamManager::new();
let stream_manager_ws = stream_manager_sse.clone();
let redis_updates = StreamManager::new(Receiver::new());
let redis_updates_sse = redis_updates.blank_copy();
let redis_updates_ws = redis_updates.blank_copy();
let routes = any_of!(
// Server Sent Events
let sse_routes = any_of!(
// GET /api/v1/streaming/user/notification [private; notification filter]
timeline::user_notifications(),
// GET /api/v1/streaming/user [private; language filter]
@ -77,12 +39,12 @@ fn main() {
.untuple_one()
.and(warp::sse())
.map(move |timeline: String, user: User, sse: warp::sse::Sse| {
let mut redis_stream = redis_updates_sse.configure_copy(&timeline, user);
let mut stream_manager = stream_manager_sse.manage_new_timeline(&timeline, user);
let event_stream = tokio::timer::Interval::new(
std::time::Instant::now(),
std::time::Duration::from_millis(100),
)
.filter_map(move |_| match redis_stream.poll() {
.filter_map(move |_| match stream_manager.poll() {
Ok(Async::Ready(Some(json_value))) => Some((
warp::sse::event(json_value["event"].clone().to_string()),
warp::sse::data(json_value["payload"].clone()),
@ -94,86 +56,54 @@ fn main() {
.with(warp::reply::with::header("Connection", "keep-alive"))
.recover(error::handle_errors);
//let redis_updates_ws = StreamManager::new(Receiver::new());
let websocket = path!("api" / "v1" / "streaming")
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.and(warp::query())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.and(warp::ws2())
.and_then(
move |mut user: User,
q: query::Stream,
m: query::Media,
h: query::Hashtag,
l: query::List,
ws: warp::ws::Ws2| {
let scopes = user.scopes.clone();
let timeline = match q.stream.as_ref() {
// Public endpoints:
tl @ "public" | tl @ "public:local" if m.is_truthy() => format!("{}:media", tl),
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
tl @ "public" | tl @ "public:local" => tl.to_string(),
// Hashtag endpoints:
// TODO: handle missing query
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, h.tag),
// Private endpoints: User
"user"
if user.id > 0
&& (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) =>
{
format!("{}", user.id)
}
"user:notification"
if user.id > 0
&& (scopes.contains(&Read) || scopes.contains(&ReadNotifications)) =>
{
user = user.with_notification_filter();
format!("{}", user.id)
}
// List endpoint:
// TODO: handle missing query
"list"
if user.authorized_for_list(l.list).is_ok()
&& (scopes.contains(&Read) || scopes.contains(&ReadList)) =>
{
format!("list:{}", l.list)
}
// WebSocket
let websocket_routes = ws::websocket_routes()
.and_then(move |mut user: User, q: ws::Query, ws: Ws2| {
let read_scope = user.scopes.clone();
let timeline = match q.stream.as_ref() {
// Public endpoints:
tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
tl @ "public" | tl @ "public:local" => tl.to_string(),
// Hashtag endpoints:
// TODO: handle missing query
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
// Private endpoints: User
"user" if user.logged_in && (read_scope.all || read_scope.statuses) => {
format!("{}", user.id)
}
"user:notification" if user.logged_in && (read_scope.all || read_scope.notify) => {
user = user.set_filter(Notification);
format!("{}", user.id)
}
// List endpoint:
// TODO: handle missing query
"list" if user.owns_list(q.list) && (read_scope.all || read_scope.lists) => {
format!("list:{}", q.list)
}
// Direct endpoint:
"direct" if user.logged_in && (read_scope.all || read_scope.statuses) => {
"direct".to_string()
}
// Reject unathorized access attempts for private endpoints
"user" | "user:notification" | "direct" | "list" => {
return Err(warp::reject::custom("Error: Invalid Access Token"))
}
// Other endpoints don't exist:
_ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")),
};
let token = user.access_token.clone();
let stream_manager = stream_manager_ws.manage_new_timeline(&timeline, user);
// Direct endpoint:
"direct"
if user.id > 0
&& (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) =>
{
"direct".to_string()
}
// Reject unathorized access attempts for private endpoints
"user" | "user:notification" | "direct" | "list" => {
return Err(warp::reject::custom("Error: Invalid Access Token"))
}
// Other endpoints don't exist:
_ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")),
};
let token = user.access_token.clone();
let stream = redis_updates_ws.configure_copy(&timeline, user);
Ok((
ws.on_upgrade(move |socket| ws::send_replies(socket, stream)),
token,
))
},
)
Ok((
ws.on_upgrade(move |socket| ws::send_replies(socket, stream_manager)),
token,
))
})
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
let address: SocketAddr = env::var("SERVER_ADDR")
.unwrap_or("127.0.0.1:4000".to_owned())
.parse()
.expect("static string");
let cors = warp::cors()
.allow_any_origin()
.allow_methods(vec!["GET", "OPTIONS"])
.allow_headers(vec!["Authorization", "Accept", "Cache-Control"]);
warp::serve(websocket.or(routes).with(cors)).run(address);
let cors = config::cross_origin_resource_sharing();
let address = config::socket_address();
warp::serve(websocket_routes.or(sse_routes).with(cors)).run(address);
}

53
src/postgres.rs Normal file
View File

@ -0,0 +1,53 @@
//! Postgres queries
use crate::config;
pub fn query_for_user_data(access_token: &str) -> (i64, Option<Vec<String>>, Vec<String>) {
let conn = config::postgres();
let query_result = conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&access_token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if !query_result.is_empty() {
let only_row = query_result.get(0);
let id: i64 = only_row.get(1);
let scopes = only_row
.get::<_, String>(3)
.split(' ')
.map(|s| s.to_owned())
.collect();
let langs: Option<Vec<String>> = only_row.get(2);
(id, langs, scopes)
} else {
(-1, None, Vec::new())
}
}
pub fn query_list_owner(list_id: i64) -> Option<i64> {
let conn = config::postgres();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT id, account_id
FROM lists
WHERE id = $1
LIMIT 1",
&[&list_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if rows.is_empty() {
None
} else {
Some(rows.get(0).get(1))
}
}

View File

@ -1,165 +1,155 @@
//! Interfacing with Redis and stream the results on to the `StreamManager`
use crate::redis_cmd;
use crate::user::User;
use futures::stream::Stream;
//! Interface with Redis and stream the results to the `StreamManager`
//! There is only one `Receiver`, which suggests that it's name is bad.
//!
//! **TODO**: Consider changing the name. Maybe RedisConnectionPool?
//! There are many AsyncReadableStreams, though. How do they fit in?
//! Figure this out ASAP.
//! A new one is created every time the Receiver is polled
use crate::{config, pubsub_cmd, redis_cmd};
use futures::{Async, Poll};
use log::info;
use regex::Regex;
use serde_json::Value;
use std::collections::{HashMap, VecDeque};
use std::env;
use std::io::{Read, Write};
use std::net::TcpStream;
use std::time::{Duration, Instant};
use std::{collections, io::Read, io::Write, net, time};
use tokio::io::{AsyncRead, Error};
use uuid::Uuid;
#[derive(Debug)]
struct MsgQueue {
messages: VecDeque<Value>,
last_polled_at: Instant,
redis_channel: String,
}
impl MsgQueue {
fn new(redis_channel: impl std::fmt::Display) -> Self {
let redis_channel = redis_channel.to_string();
MsgQueue {
messages: VecDeque::new(),
last_polled_at: Instant::now(),
redis_channel,
}
}
}
/// The item that streams from Redis and is polled by the `StreamManger`
/// The item that streams from Redis and is polled by the `StreamManager`
#[derive(Debug)]
pub struct Receiver {
pubsub_connection: TcpStream,
secondary_redis_connection: TcpStream,
pubsub_connection: net::TcpStream,
secondary_redis_connection: net::TcpStream,
tl: String,
pub user: User,
manager_id: Uuid,
msg_queues: HashMap<Uuid, MsgQueue>,
clients_per_timeline: HashMap<String, i32>,
}
impl Default for Receiver {
fn default() -> Self {
Self::new()
}
msg_queues: collections::HashMap<Uuid, MsgQueue>,
clients_per_timeline: collections::HashMap<String, i32>,
}
impl Receiver {
/// Create a new `Receiver`, with its own Redis connections (but, as yet, no
/// active subscriptions).
pub fn new() -> Self {
let redis_addr = env::var("REDIS_ADDR").unwrap_or("127.0.0.1:6379".to_string());
let pubsub_connection = TcpStream::connect(&redis_addr).expect("Can connect to Redis");
pubsub_connection
.set_read_timeout(Some(Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
let secondary_redis_connection =
TcpStream::connect(&redis_addr).expect("Can connect to Redis");
secondary_redis_connection
.set_read_timeout(Some(Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
let (pubsub_connection, secondary_redis_connection) = config::redis_addr();
Self {
pubsub_connection,
secondary_redis_connection,
tl: String::new(),
user: User::public(),
manager_id: Uuid::new_v4(),
msg_queues: HashMap::new(),
clients_per_timeline: HashMap::new(),
manager_id: Uuid::default(),
msg_queues: collections::HashMap::new(),
clients_per_timeline: collections::HashMap::new(),
}
}
/// Update the `StreamManager` that is currently polling the `Receiver`
pub fn update(&mut self, id: Uuid, timeline: impl std::fmt::Display) {
self.manager_id = id;
/// Assigns the `Receiver` a new timeline to monitor and runs other
/// first-time setup.
///
/// Importantly, this method calls `subscribe_or_unsubscribe_as_needed`,
/// so Redis PubSub subscriptions are only updated when a new timeline
/// comes under management for the first time.
pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: &str) {
self.manager_id = manager_id;
self.tl = timeline.to_string();
let old_value = self
.msg_queues
.insert(self.manager_id, MsgQueue::new(timeline));
// Consider removing/refactoring
if let Some(value) = old_value {
eprintln!(
"Data was overwritten when it shouldn't have been. Old data was: {:#?}",
value
);
}
self.subscribe_or_unsubscribe_as_needed(timeline);
}
/// Set the `Receiver`'s manager_id and target_timeline fields to the approprate
/// value to be polled by the current `StreamManager`.
pub fn configure_for_polling(&mut self, manager_id: Uuid, timeline: &str) {
if &manager_id != &self.manager_id {
//println!("New Manager: {}", &manager_id);
}
self.manager_id = manager_id;
self.tl = timeline.to_string();
}
/// Send a subscribe command to the Redis PubSub (if needed)
pub fn maybe_subscribe(&mut self, tl: &str) {
info!("Subscribing to {}", &tl);
let manager_id = self.manager_id;
self.msg_queues.insert(manager_id, MsgQueue::new(tl));
let current_clients = self
.clients_per_timeline
.entry(tl.to_string())
.and_modify(|n| *n += 1)
.or_insert(1);
if *current_clients == 1 {
let subscribe_cmd = redis_cmd::pubsub("subscribe", tl);
self.pubsub_connection
.write_all(&subscribe_cmd)
.expect("Can subscribe to Redis");
let set_subscribed_cmd = redis_cmd::set(format!("subscribed:timeline:{}", tl), "1");
self.secondary_redis_connection
.write_all(&set_subscribed_cmd)
.expect("Can set Redis");
info!("Now subscribed to: {:#?}", &self.msg_queues);
}
}
/// Drop any PubSub subscriptions that don't have active clients
pub fn unsubscribe_from_empty_channels(&mut self) {
let mut timelines_with_fewer_clients = Vec::new();
/// Drop any PubSub subscriptions that don't have active clients and check
/// that there's a subscription to the current one. If there isn't, then
/// subscribe to it.
fn subscribe_or_unsubscribe_as_needed(&mut self, tl: &str) {
let mut timelines_to_modify = Vec::new();
timelines_to_modify.push((tl.to_owned(), 1));
// Keep only message queues that have been polled recently
self.msg_queues.retain(|_id, msg_queue| {
if msg_queue.last_polled_at.elapsed() < Duration::from_secs(30) {
if msg_queue.last_polled_at.elapsed() < time::Duration::from_secs(30) {
true
} else {
timelines_with_fewer_clients.push(msg_queue.redis_channel.clone());
let timeline = msg_queue.redis_channel.clone();
timelines_to_modify.push((timeline, -1));
false
}
});
// Record the lower number of clients subscribed to that channel
for timeline in timelines_with_fewer_clients {
for (timeline, numerical_change) in timelines_to_modify {
let mut need_to_subscribe = false;
let count_of_subscribed_clients = self
.clients_per_timeline
.entry(timeline.clone())
.and_modify(|n| *n -= 1)
.or_insert(0);
.entry(timeline.to_owned())
.and_modify(|n| *n += numerical_change)
.or_insert_with(|| {
need_to_subscribe = true;
1
});
// If no clients, unsubscribe from the channel
if *count_of_subscribed_clients <= 0 {
self.unsubscribe(&timeline);
info!("Sent unsubscribe command");
pubsub_cmd!("unsubscribe", self, timeline.clone());
}
if need_to_subscribe {
info!("Sent subscribe command");
pubsub_cmd!("subscribe", self, timeline.clone());
}
}
}
/// Send an unsubscribe command to the Redis PubSub
pub fn unsubscribe(&mut self, tl: &str) {
let unsubscribe_cmd = redis_cmd::pubsub("unsubscribe", tl);
info!("Unsubscribing from {}", &tl);
self.pubsub_connection
.write_all(&unsubscribe_cmd)
.expect("Can unsubscribe from Redis");
let set_subscribed_cmd = redis_cmd::set(format!("subscribed:timeline:{}", tl), "0");
self.secondary_redis_connection
.write_all(&set_subscribed_cmd)
.expect("Can set Redis");
info!("Now subscribed only to: {:#?}", &self.msg_queues);
fn log_number_of_msgs_in_queue(&self) {
let messages_waiting = self
.msg_queues
.get(&self.manager_id)
.expect("Guaranteed by match block")
.messages
.len();
match messages_waiting {
number if number > 10 => {
log::error!("{} messages waiting in the queue", messages_waiting)
}
_ => log::info!("{} messages waiting in the queue", messages_waiting),
}
}
}
impl Stream for Receiver {
impl Default for Receiver {
fn default() -> Self {
Receiver::new()
}
}
impl futures::stream::Stream for Receiver {
type Item = Value;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Value>, Self::Error> {
let mut buffer = vec![0u8; 3000];
info!("Being polled by: {}", self.manager_id);
let timeline = self.tl.clone();
// Record current time as last polled time
self.msg_queues
.entry(self.manager_id)
.and_modify(|msg_queue| msg_queue.last_polled_at = Instant::now());
.and_modify(|msg_queue| msg_queue.last_polled_at = time::Instant::now());
// Add any incomming messages to the back of the relevant `msg_queues`
// NOTE: This could be more/other than the `msg_queue` currently being polled
let mut async_stream = AsyncReadableStream(&mut self.pubsub_connection);
let mut async_stream = AsyncReadableStream::new(&mut self.pubsub_connection);
if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer)? {
let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]);
// capture everything between `{` and `}` as potential JSON
@ -183,11 +173,14 @@ impl Stream for Receiver {
match self
.msg_queues
.entry(self.manager_id)
.or_insert_with(|| MsgQueue::new(timeline))
.or_insert_with(|| MsgQueue::new(timeline.clone()))
.messages
.pop_front()
{
Some(value) => Ok(Async::Ready(Some(value))),
Some(value) => {
self.log_number_of_msgs_in_queue();
Ok(Async::Ready(Some(value)))
}
_ => Ok(Async::NotReady),
}
}
@ -195,12 +188,34 @@ impl Stream for Receiver {
impl Drop for Receiver {
fn drop(&mut self) {
let timeline = self.tl.clone();
self.unsubscribe(&timeline);
pubsub_cmd!("unsubscribe", self, self.tl.clone());
}
}
struct AsyncReadableStream<'a>(&'a mut TcpStream);
#[derive(Debug, Clone)]
struct MsgQueue {
pub messages: collections::VecDeque<Value>,
pub last_polled_at: time::Instant,
pub redis_channel: String,
}
impl MsgQueue {
pub fn new(redis_channel: impl std::fmt::Display) -> Self {
let redis_channel = redis_channel.to_string();
MsgQueue {
messages: collections::VecDeque::new(),
last_polled_at: time::Instant::now(),
redis_channel,
}
}
}
struct AsyncReadableStream<'a>(&'a mut net::TcpStream);
impl<'a> AsyncReadableStream<'a> {
pub fn new(stream: &'a mut net::TcpStream) -> Self {
AsyncReadableStream(stream)
}
}
impl<'a> Read for AsyncReadableStream<'a> {
fn read(&mut self, buffer: &mut [u8]) -> Result<usize, std::io::Error> {

View File

@ -1,6 +1,26 @@
//! Send raw TCP commands to the Redis server
use std::fmt::Display;
/// Send a subscribe or unsubscribe to the Redis PubSub channel
#[macro_export]
macro_rules! pubsub_cmd {
($cmd:expr, $self:expr, $tl:expr) => {{
info!("Sending {} command to {}", $cmd, $tl);
$self
.pubsub_connection
.write_all(&redis_cmd::pubsub($cmd, $tl))
.expect("Can send command to Redis");
let new_value = if $cmd == "subscribe" { "1" } else { "0" };
$self
.secondary_redis_connection
.write_all(&redis_cmd::set(
format!("subscribed:timeline:{}", $tl),
new_value,
))
.expect("Can set Redis");
info!("Now subscribed to: {:#?}", $self.msg_queues);
}};
}
/// Send a `SUBSCRIBE` or `UNSUBSCRIBE` command to a specific timeline
pub fn pubsub(command: impl Display, timeline: impl Display) -> Vec<u8> {
let arg = format!("timeline:{}", timeline);

View File

@ -1,93 +0,0 @@
//! Manage all existing Redis PubSub connection
use crate::receiver::Receiver;
use crate::user::{Filter, User};
use futures::stream::Stream;
use futures::{Async, Poll};
use serde_json::json;
use serde_json::Value;
use std::sync::{Arc, Mutex};
use tokio::io::Error;
use uuid::Uuid;
/// Struct for manageing all Redis streams
#[derive(Clone, Debug)]
pub struct StreamManager {
receiver: Arc<Mutex<Receiver>>,
id: uuid::Uuid,
target_timeline: String,
current_user: Option<User>,
}
impl StreamManager {
pub fn new(reciever: Receiver) -> Self {
StreamManager {
receiver: Arc::new(Mutex::new(reciever)),
id: Uuid::default(),
target_timeline: String::new(),
current_user: None,
}
}
/// Create a blank StreamManager copy
pub fn blank_copy(&self) -> Self {
StreamManager { ..self.clone() }
}
/// Create a StreamManager copy with a new unique id manage subscriptions
pub fn configure_copy(&self, timeline: &String, user: User) -> Self {
let id = Uuid::new_v4();
let mut receiver = self.receiver.lock().expect("No panic in other threads");
receiver.update(id, timeline);
receiver.maybe_subscribe(timeline);
StreamManager {
id,
current_user: Some(user),
target_timeline: timeline.clone(),
..self.clone()
}
}
}
impl Stream for StreamManager {
type Item = Value;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let mut receiver = self
.receiver
.lock()
.expect("StreamManager: No other thread panic");
receiver.update(self.id, &self.target_timeline.clone());
match receiver.poll() {
Ok(Async::Ready(Some(value))) => {
let user = self
.clone()
.current_user
.expect("Previously set current user");
let user_langs = user.langs.clone();
let event = value["event"].as_str().expect("Redis string");
let payload = value["payload"].to_string();
match (&user.filter, user_langs) {
(Filter::Notification, _) if event != "notification" => Ok(Async::NotReady),
(Filter::Language, Some(ref user_langs))
if !user_langs.contains(
&value["payload"]["language"]
.as_str()
.expect("Redis str")
.to_string(),
) =>
{
Ok(Async::NotReady)
}
_ => Ok(Async::Ready(Some(json!(
{"event": event,
"payload": payload,}
)))),
}
}
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => Err(e),
}
}
}

154
src/stream_manager.rs Normal file
View File

@ -0,0 +1,154 @@
//! The `StreamManager` is responsible to providing an interface between the `Warp`
//! filters and the underlying mechanics of talking with Redis/managing multiple
//! threads. The `StreamManager` is the only struct that any Warp code should
//! need to communicate with.
//!
//! The `StreamManager`'s interface is very simple. All you can do with it is:
//! * Create a totally new `StreamManger` with no shared data;
//! * Assign an existing `StreamManager` to manage an new timeline/user pair; or
//! * Poll an existing `StreamManager` to see if there are any new messages
//! for clients
//!
//! When you poll the `StreamManager`, it is responsible for polling internal data
//! structures, getting any updates from Redis, and then filtering out any updates
//! that should be excluded by relevant filters.
//!
//! Because `StreamManagers` are lightweight data structures that do not directly
//! communicate with Redis, it is appropriate to create a new `StreamManager` for
//! each new client connection.
use crate::{
receiver::Receiver,
user::{Filter, User},
};
use futures::{Async, Poll};
use serde_json::{json, Value};
use std::sync;
use std::time;
use tokio::io::Error;
use uuid::Uuid;
/// Struct for managing all Redis streams.
#[derive(Clone, Default, Debug)]
pub struct StreamManager {
receiver: sync::Arc<sync::Mutex<Receiver>>,
id: uuid::Uuid,
target_timeline: String,
current_user: User,
}
impl StreamManager {
/// Create a new `StreamManager` with no shared data.
pub fn new() -> Self {
StreamManager {
receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())),
id: Uuid::default(),
target_timeline: String::new(),
current_user: User::public(),
}
}
/// Assign the `StreamManager` to manage a new timeline/user pair.
///
/// Note that this *may or may not* result in a new Redis connection.
/// If the server has already subscribed to the timeline on behalf of
/// a different user, the `StreamManager` is responsible for figuring
/// that out and avoiding duplicated connections. Thus, it is safe to
/// use this method for each new client connection.
pub fn manage_new_timeline(&self, target_timeline: &str, user: User) -> Self {
let manager_id = Uuid::new_v4();
let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)");
receiver.manage_new_timeline(manager_id, target_timeline);
StreamManager {
id: manager_id,
current_user: user,
target_timeline: target_timeline.to_owned(),
receiver: self.receiver.clone(),
}
}
}
/// The stream that the `StreamManager` manages. `Poll` is the only method implemented.
impl futures::stream::Stream for StreamManager {
type Item = Value;
type Error = Error;
/// Checks for any new messages that should be sent to the client.
///
/// The `StreamManager` will poll underlying data structures and will reply
/// with an `Ok(Ready(Some(Value)))` if there is a new message to send to
/// the client. If there is no new message or if the new message should be
/// filtered out based on one of the user's filters, then the `StreamManager`
/// will reply with `Ok(NotReady)`. The `StreamManager` will buble up any
/// errors from the underlying data structures.
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let start_time = time::Instant::now();
let result = {
let mut receiver = self
.receiver
.lock()
.expect("StreamManager: No other thread panic");
receiver.configure_for_polling(self.id, &self.target_timeline.clone());
receiver.poll()
};
println!("Polling took: {:?}", start_time.elapsed());
let result = match result {
Ok(Async::Ready(Some(value))) => {
let user_langs = self.current_user.langs.clone();
let toot = Toot::from_json(value);
toot.ignore_if_caught_by_filter(&self.current_user.filter, user_langs)
}
Ok(inner_value) => Ok(inner_value),
Err(e) => Err(e),
};
result
}
}
struct Toot {
category: String,
payload: String,
language: String,
}
impl Toot {
fn from_json(value: Value) -> Self {
Self {
category: value["event"].as_str().expect("Redis string").to_owned(),
payload: value["payload"].to_string(),
language: value["payload"]["language"]
.as_str()
.expect("Redis str")
.to_string(),
}
}
fn to_optional_json(&self) -> Option<Value> {
Some(json!(
{"event": self.category,
"payload": self.payload,}
))
}
fn ignore_if_caught_by_filter(
&self,
filter: &Filter,
user_langs: Option<Vec<String>>,
) -> Result<Async<Option<Value>>, Error> {
let toot = self;
let (send_msg, skip_msg) = (
Ok(Async::Ready(toot.to_optional_json())),
Ok(Async::NotReady),
);
match &filter {
Filter::NoFilter => send_msg,
Filter::Notification if toot.category == "notification" => send_msg,
// If not, skip it
Filter::Notification => skip_msg,
Filter::Language if user_langs.is_none() => send_msg,
Filter::Language if user_langs.expect("").contains(&toot.language) => send_msg,
// If not, skip it
Filter::Language => skip_msg,
}
}
}

View File

@ -1,6 +1,8 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
use crate::error;
use crate::query;
use crate::user::{Scope, User};
use crate::user::{Filter::*, Scope, User};
use crate::user_from_path;
use warp::filters::BoxedFilter;
use warp::{path, Filter};
@ -8,14 +10,8 @@ use warp::{path, Filter};
type TimelineUser = ((String, User),);
/// GET /api/v1/streaming/user
///
///
/// **private**. Filter: `Language`
pub fn user() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "user")
.and(path::end())
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
user_from_path!("streaming" / "user", Scope::Private)
.map(|user: User| (user.id.to_string(), user))
.boxed()
}
@ -23,477 +19,84 @@ pub fn user() -> BoxedFilter<TimelineUser> {
/// GET /api/v1/streaming/user/notification
///
///
/// **private**. Filter: `Notification`
///
///
/// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented?
pub fn user_notifications() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "user" / "notification")
.and(path::end())
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.map(|user: User| (user.id.to_string(), user.with_notification_filter()))
user_from_path!("streaming" / "user" / "notification", Scope::Private)
.map(|user: User| (user.id.to_string(), user.set_filter(Notification)))
.boxed()
}
/// GET /api/v1/streaming/public
///
///
/// **public**. Filter: `Language`
pub fn public() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public")
.and(path::end())
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.map(|user: User| ("public".to_owned(), user.with_language_filter()))
user_from_path!("streaming" / "public", Scope::Public)
.map(|user: User| ("public".to_owned(), user.set_filter(Language)))
.boxed()
}
/// GET /api/v1/streaming/public?only_media=true
///
///
/// **public**. Filter: `Language`
pub fn public_media() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public")
.and(path::end())
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
user_from_path!("streaming" / "public", Scope::Public)
.and(warp::query())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:media".to_owned(), user.with_language_filter()),
_ => ("public".to_owned(), user.with_language_filter()),
"1" | "true" => ("public:media".to_owned(), user.set_filter(Language)),
_ => ("public".to_owned(), user.set_filter(Language)),
})
.boxed()
}
/// GET /api/v1/streaming/public/local
///
///
/// **public**. Filter: `Language`
pub fn public_local() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public" / "local")
.and(path::end())
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.map(|user: User| ("public:local".to_owned(), user.with_language_filter()))
user_from_path!("streaming" / "public" / "local", Scope::Public)
.map(|user: User| ("public:local".to_owned(), user.set_filter(Language)))
.boxed()
}
/// GET /api/v1/streaming/public/local?only_media=true
///
///
/// **public**. Filter: `Language`
pub fn public_local_media() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public" / "local")
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
user_from_path!("streaming" / "public" / "local", Scope::Public)
.and(warp::query())
.and(path::end())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:local:media".to_owned(), user.with_language_filter()),
_ => ("public:local".to_owned(), user.with_language_filter()),
"1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)),
_ => ("public:local".to_owned(), user.set_filter(Language)),
})
.boxed()
}
/// GET /api/v1/streaming/direct
///
///
/// **private**. Filter: `None`
pub fn direct() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "direct")
.and(path::end())
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.map(|user: User| (format!("direct:{}", user.id), user.with_no_filter()))
user_from_path!("streaming" / "direct", Scope::Private)
.map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter)))
.boxed()
}
/// GET /api/v1/streaming/hashtag?tag=:hashtag
///
///
/// **public**. Filter: `None`
pub fn hashtag() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "hashtag")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public()))
.boxed()
}
/// GET /api/v1/streaming/hashtag/local?tag=:hashtag
///
///
/// **public**. Filter: `None`
pub fn hashtag_local() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "hashtag" / "local")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public()))
.boxed()
}
/// GET /api/v1/streaming/list?list=:list_id
///
///
/// **private**. Filter: `None`
pub fn list() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "list")
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
user_from_path!("streaming" / "list", Scope::Private)
.and(warp::query())
.and_then(|user: User, q: query::List| (user.authorized_for_list(q.list), Ok(user)))
.and_then(|user: User, q: query::List| {
if user.owns_list(q.list) {
(Ok(q.list), Ok(user))
} else {
(Err(error::unauthorized_list()), Ok(user))
}
})
.untuple_one()
.and(path::end())
.map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter()))
.map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter)))
.boxed()
}
/// Combines multiple routes with the same return type together with
/// `or()` and `unify()`
#[macro_export]
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*
};
}
#[cfg(test)]
mod tests {
use super::*;
use crate::user;
#[test]
fn user_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&user());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user",))
.filter(&user());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token={}",
access_token
))
.filter(&user())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&user())
.expect("in test");
let expected_user =
User::from_access_token(access_token, user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn user_notifications_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN",
))
.filter(&user_notifications());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user/notification",))
.filter(&user_notifications());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_notifications_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token={}",
access_token
))
.filter(&user_notifications())
.expect("in test");
let expected_user = User::from_access_token(access_token.clone(), user::Scope::Private)
.expect("in test")
.with_notification_filter();
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user/notification")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&user_notifications())
.expect("in test");
let expected_user = User::from_access_token(access_token, user::Scope::Private)
.expect("in test")
.with_notification_filter();
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn public_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public")
.filter(&public())
.expect("in test");
assert_eq!(value.0, "public".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn public_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=true")
.filter(&public_media())
.expect("in test");
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=1")
.filter(&public_media())
.expect("in test");
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn public_local_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local")
.filter(&public_local())
.expect("in test");
assert_eq!(value.0, "public:local".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn public_local_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=true")
.filter(&public_local_media())
.expect("in test");
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=1")
.filter(&public_local_media())
.expect("in test");
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn direct_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN",
))
.filter(&direct());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/direct",))
.filter(&direct());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn direct_timeline_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token={}",
access_token
))
.filter(&direct())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/direct")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&direct())
.expect("in test");
let expected_user =
User::from_access_token(access_token, user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn hashtag_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag?tag=a")
.filter(&hashtag())
.expect("in test");
assert_eq!(value.0, "hashtag:a".to_string());
assert_eq!(value.1, User::public());
}
#[test]
fn hashtag_timeline_local() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag/local?tag=a")
.filter(&hashtag_local())
.expect("in test");
assert_eq!(value.0, "hashtag:a:local".to_string());
assert_eq!(value.1, User::public());
}
#[test]
#[ignore]
fn list_timeline_auth() {
let list_id = 1;
let list_owner_id = get_list_owner(list_id);
let access_token = get_access_token(list_owner_id);
// Query Auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token={}&list={}",
access_token, list_id,
))
.filter(&list())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
// Header Auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/list?list=1")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&list())
.expect("in test");
let expected_user =
User::from_access_token(access_token, user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn list_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&list());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/list?list=1",))
.filter(&list());
assert!(no_access_token(value));
}
fn get_list_owner(list_number: i32) -> i64 {
let list_number: i64 = list_number.into();
let conn = user::connect_to_postgres();
let rows = &conn
.query(
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_number],
)
.expect("in test");
assert_eq!(
rows.len(),
1,
"Test database must contain at least one user with a list to run this test."
);
rows.get(0).get(1)
}
fn get_access_token(user_id: i64) -> String {
let conn = user::connect_to_postgres();
let rows = &conn
.query(
"SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1",
&[&user_id],
)
.expect("Can get access token from id");
rows.get(0).get(0)
}
fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => {
true
}
_ => false,
},
_ => false,
}
}
fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true,
_ => false,
},
_ => false,
}
}
}

View File

@ -1,24 +1,21 @@
//! Create a User by querying the Postgres database with the user's access_token
use crate::{any_of, query};
//! `User` struct and related functionality
use crate::{postgres, query};
use log::info;
use postgres;
use std::env;
use warp::Filter as WarpFilter;
/// (currently hardcoded to localhost)
pub fn connect_to_postgres() -> postgres::Connection {
let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or(format!(
"postgres://{}@localhost/mastodon_development",
env::var("USER").expect("User env var should exist")
));
postgres::Connection::connect(postgres_addr, postgres::TlsMode::None)
.expect("Can connect to local Postgres")
/// Combine multiple routes with the same return type together with
/// `or()` and `unify()`
#[macro_export]
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*
};
}
/// The filters that can be applied to toots after they come from Redis
#[derive(Clone, Debug, PartialEq)]
pub enum Filter {
None,
NoFilter,
Language,
Notification,
}
@ -28,140 +25,99 @@ pub enum Filter {
pub struct User {
pub id: i64,
pub access_token: String,
pub scopes: Vec<OauthScope>,
pub scopes: OauthScope,
pub langs: Option<Vec<String>>,
pub logged_in: bool,
pub filter: Filter,
}
#[derive(Clone, Debug, PartialEq)]
pub enum OauthScope {
Read,
ReadStatuses,
ReadNotifications,
ReadList,
Other,
}
impl From<&str> for OauthScope {
fn from(scope: &str) -> Self {
use OauthScope::*;
match scope {
"read" => Read,
"read:statuses" => ReadStatuses,
"read:notifications" => ReadNotifications,
"read:lists" => ReadList,
_ => Other,
}
impl Default for User {
fn default() -> Self {
User::public()
}
}
#[derive(Clone, Debug, Default, PartialEq)]
pub struct OauthScope {
pub all: bool,
pub statuses: bool,
pub notify: bool,
pub lists: bool,
}
impl From<Vec<String>> for OauthScope {
fn from(scope_list: Vec<String>) -> Self {
let mut oauth_scope = OauthScope::default();
for scope in scope_list {
match scope.as_str() {
"read" => oauth_scope.all = true,
"read:statuses" => oauth_scope.statuses = true,
"read:notifications" => oauth_scope.notify = true,
"read:lists" => oauth_scope.lists = true,
_ => (),
}
}
oauth_scope
}
}
/// Create a user based on the supplied path and access scope for the resource
#[macro_export]
macro_rules! user_from_path {
($($path_item:tt) / *, $scope:expr) => (path!("api" / "v1" / $($path_item) / +)
.and($scope.get_access_token())
.and_then(|token| User::from_access_token(token, $scope)))
}
impl User {
/// Create a user from the access token supplied in the header or query paramaters
pub fn from_access_token(
access_token: String,
scope: Scope,
) -> Result<Self, warp::reject::Rejection> {
let conn = connect_to_postgres();
let result = &conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&access_token],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if !result.is_empty() {
let only_row = result.get(0);
let id: i64 = only_row.get(1);
let scopes = only_row
.get::<_, String>(3)
.split(' ')
.map(|scope: &str| scope.into())
.filter(|scope| scope != &OauthScope::Other)
.collect();
dbg!(&scopes);
let langs: Option<Vec<String>> = only_row.get(2);
info!("Granting logged-in access");
let (id, langs, scope_list) = postgres::query_for_user_data(&access_token);
let scopes = OauthScope::from(scope_list);
if id != -1 || scope == Scope::Public {
let (logged_in, log_msg) = match id {
-1 => (false, "Public access to non-authenticated endpoints"),
_ => (true, "Granting logged-in access"),
};
info!("{}", log_msg);
Ok(User {
id,
access_token,
scopes,
langs,
logged_in: true,
filter: Filter::None,
})
} else if let Scope::Public = scope {
info!("Granting public access to non-authenticated client");
Ok(User {
id: -1,
access_token,
scopes: Vec::new(),
langs: None,
logged_in: false,
filter: Filter::None,
logged_in,
filter: Filter::NoFilter,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
/// Add a Notification filter
pub fn with_notification_filter(self) -> Self {
Self {
filter: Filter::Notification,
..self
}
}
/// Add a Language filter
pub fn with_language_filter(self) -> Self {
Self {
filter: Filter::Language,
..self
}
}
/// Remove all filters
pub fn with_no_filter(self) -> Self {
Self {
filter: Filter::None,
..self
}
/// Set the Notification/Language filter
pub fn set_filter(self, filter: Filter) -> Self {
Self { filter, ..self }
}
/// Determine whether the User is authorised for a specified list
pub fn authorized_for_list(&self, list: i64) -> Result<i64, warp::reject::Rejection> {
let conn = connect_to_postgres();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
" SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if !rows.is_empty() {
let id_of_account_that_owns_the_list: i64 = rows.get(0).get(1);
if id_of_account_that_owns_the_list == self.id {
return Ok(list);
}
};
Err(warp::reject::custom("Error: Invalid access token"))
pub fn owns_list(&self, list: i64) -> bool {
match postgres::query_list_owner(list) {
Some(i) if i == self.id => true,
_ => false,
}
}
/// A public (non-authenticated) User
pub fn public() -> Self {
User {
id: -1,
access_token: String::new(),
scopes: Vec::new(),
access_token: String::from("no access token"),
scopes: OauthScope::default(),
langs: None,
logged_in: false,
filter: Filter::None,
filter: Filter::NoFilter,
}
}
}
/// Whether the endpoint requires authentication or not
#[derive(PartialEq)]
pub enum Scope {
Public,
Private,

View File

@ -1,44 +1,86 @@
//! WebSocket-specific functionality
use crate::stream::StreamManager;
use crate::query;
use crate::stream_manager::StreamManager;
use crate::user::{Scope, User};
use crate::user_from_path;
use futures::future::Future;
use futures::stream::Stream;
use futures::Async;
use std::time;
use warp::filters::BoxedFilter;
use warp::{path, Filter};
/// Send a stream of replies to a WebSocket client
pub fn send_replies(
socket: warp::ws::WebSocket,
mut stream: StreamManager,
) -> impl futures::future::Future<Item = (), Error = ()> {
let (tx, rx) = futures::sync::mpsc::unbounded();
let (ws_tx, mut ws_rx) = socket.split();
// Create a pipe
let (tx, rx) = futures::sync::mpsc::unbounded();
// Send one end of it to a different thread and tell that end to forward whatever it gets
// on to the websocket client
warp::spawn(
rx.map_err(|()| -> warp::Error { unreachable!() })
.forward(ws_tx)
.map_err(|_| ())
.map(|_r| ()),
);
let event_stream = tokio::timer::Interval::new(
std::time::Instant::now(),
std::time::Duration::from_millis(100),
)
.take_while(move |_| {
if ws_rx.poll().is_err() {
futures::future::ok(false)
} else {
futures::future::ok(true)
}
});
// For as long as the client is still connected, yeild a new event every 100 ms
let event_stream =
tokio::timer::Interval::new(time::Instant::now(), time::Duration::from_millis(100))
.take_while(move |_| match ws_rx.poll() {
Ok(Async::Ready(None)) => futures::future::ok(false),
_ => futures::future::ok(true),
});
// Every time you get an event from that stream, send it through the pipe
event_stream
.for_each(move |_json_value| {
if let Ok(Async::Ready(Some(json_value))) = stream.poll() {
let msg = warp::ws::Message::text(json_value.to_string());
if !tx.is_closed() {
tx.unbounded_send(msg).expect("No send error");
}
tx.unbounded_send(msg).expect("No send error");
};
Ok(())
})
.then(|msg| msg)
.map_err(|e| println!("{}", e))
}
pub fn websocket_routes() -> BoxedFilter<(User, Query, warp::ws::Ws2)> {
user_from_path!("streaming", Scope::Public)
.and(warp::query())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.and(warp::ws2())
.map(
|user: User,
stream: query::Stream,
media: query::Media,
hashtag: query::Hashtag,
list: query::List,
ws: warp::ws::Ws2| {
let query = Query {
stream: stream.stream,
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
};
(user, query, ws)
},
)
.untuple_one()
.boxed()
}
#[derive(Debug)]
pub struct Query {
pub stream: String,
pub media: bool,
pub hashtag: String,
pub list: i64,
}

341
tests/test.rs Normal file
View File

@ -0,0 +1,341 @@
use ragequit::{
config,
timeline::*,
user::{Filter::*, Scope, User},
};
#[test]
fn user_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&user());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user",))
.filter(&user());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token={}",
access_token
))
.filter(&user())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), Scope::Private).expect("in test");
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&user())
.expect("in test");
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test");
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn user_notifications_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN",
))
.filter(&user_notifications());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user/notification",))
.filter(&user_notifications());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_notifications_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token={}",
access_token
))
.filter(&user_notifications())
.expect("in test");
let expected_user = User::from_access_token(access_token.clone(), Scope::Private)
.expect("in test")
.set_filter(Notification);
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user/notification")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&user_notifications())
.expect("in test");
let expected_user = User::from_access_token(access_token, Scope::Private)
.expect("in test")
.set_filter(Notification);
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn public_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public")
.filter(&public())
.expect("in test");
assert_eq!(value.0, "public".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
}
#[test]
fn public_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=true")
.filter(&public_media())
.expect("in test");
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=1")
.filter(&public_media())
.expect("in test");
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
}
#[test]
fn public_local_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local")
.filter(&public_local())
.expect("in test");
assert_eq!(value.0, "public:local".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
}
#[test]
fn public_local_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=true")
.filter(&public_local_media())
.expect("in test");
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=1")
.filter(&public_local_media())
.expect("in test");
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
}
#[test]
fn direct_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN",
))
.filter(&direct());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/direct",))
.filter(&direct());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn direct_timeline_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token={}",
access_token
))
.filter(&direct())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), Scope::Private).expect("in test");
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/direct")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&direct())
.expect("in test");
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test");
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn hashtag_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag?tag=a")
.filter(&hashtag())
.expect("in test");
assert_eq!(value.0, "hashtag:a".to_string());
assert_eq!(value.1, User::public());
}
#[test]
fn hashtag_timeline_local() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag/local?tag=a")
.filter(&hashtag_local())
.expect("in test");
assert_eq!(value.0, "hashtag:a:local".to_string());
assert_eq!(value.1, User::public());
}
#[test]
#[ignore]
fn list_timeline_auth() {
let list_id = 1;
let list_owner_id = get_list_owner(list_id);
let access_token = get_access_token(list_owner_id);
// Query Auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token={}&list={}",
access_token, list_id,
))
.filter(&list())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), Scope::Private).expect("in test");
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
// Header Auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/list?list=1")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&list())
.expect("in test");
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test");
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn list_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&list());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/list?list=1",))
.filter(&list());
assert!(no_access_token(value));
}
// Helper functions for tests
fn get_list_owner(list_number: i32) -> i64 {
let list_number: i64 = list_number.into();
let conn = config::postgres();
let rows = &conn
.query(
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_number],
)
.expect("in test");
assert_eq!(
rows.len(),
1,
"Test database must contain at least one user with a list to run this test."
);
rows.get(0).get(1)
}
fn get_access_token(user_id: i64) -> String {
let conn = config::postgres();
let rows = &conn
.query(
"SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1",
&[&user_id],
)
.expect("Can get access token from id");
rows.get(0).get(0)
}
fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => true,
_ => false,
},
_ => false,
}
}
fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
// The cause could validly be any of these, depending on the order they're checked
// (It would pass with just one, so the last one it doesn't have is "the" cause)
Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true,
Some(c) if format!("{:?}", c) == "InvalidQuery" => true,
Some(c) if format!("{:?}", c) == "MissingHeader(\"Sec-WebSocket-Protocol\")" => true,
_ => false,
},
_ => false,
}
}