From f3b86ddac87b59aafa1b71b20acae2ebea99e60b Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Thu, 4 Jul 2019 14:00:35 -0400 Subject: [PATCH 1/4] Add CORS support Cross-Origin requests were already implicitly allowed, but this commit allows them explicitly and prohibits request methods other than GET. --- src/main.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 2fe5d89..3dae80c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -171,5 +171,9 @@ fn main() { .unwrap_or("127.0.0.1:4000".to_owned()) .parse() .expect("static string"); - warp::serve(websocket.or(routes)).run(address); + let cors = warp::cors() + .allow_any_origin() + .allow_methods(vec!["GET", "OPTIONS"]) + .allow_headers(vec!["Authorization", "Accept", "Cache-Control"]); + warp::serve(websocket.or(routes).with(cors)).run(address); } From 17320088404419bad272364d801fddd767509b55 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Fri, 5 Jul 2019 20:08:50 -0400 Subject: [PATCH 2/4] Initial cleanup/refactor --- src/config.rs | 54 +++++ src/error.rs | 4 + src/lib.rs | 38 ++++ src/main.rs | 196 ++++++------------ src/postgres.rs | 53 +++++ src/receiver.rs | 235 +++++++++++----------- src/redis_cmd.rs | 20 ++ src/stream.rs | 93 --------- src/stream_manager.rs | 154 +++++++++++++++ src/timeline.rs | 451 +++--------------------------------------- src/user.rs | 178 +++++++---------- src/ws.rs | 74 +++++-- tests/test.rs | 341 ++++++++++++++++++++++++++++++++ 13 files changed, 1004 insertions(+), 887 deletions(-) create mode 100644 src/config.rs create mode 100644 src/lib.rs create mode 100644 src/postgres.rs delete mode 100644 src/stream.rs create mode 100644 src/stream_manager.rs create mode 100644 tests/test.rs diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..fad70a4 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,54 @@ +//! 
Configuration settings for servers and databases +use dotenv::dotenv; +use log::warn; +use std::{env, net, time}; + +/// Configure CORS for the API server +pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors { + warp::cors() + .allow_any_origin() + .allow_methods(vec!["GET", "OPTIONS"]) + .allow_headers(vec!["Authorization", "Accept", "Cache-Control"]) +} + +/// Initialize logging and read values from `src/.env` +pub fn logging_and_env() { + pretty_env_logger::init(); + dotenv().ok(); +} + +/// Configure Postgres and return a connection +pub fn postgres() -> postgres::Connection { + let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or_else(|_| { + format!( + "postgres://{}@localhost/mastodon_development", + env::var("USER").unwrap_or_else(|_| { + warn!("No USER env variable set. Connecting to Postgress with default `postgres` user"); + "postgres".to_owned() + }) + ) + }); + postgres::Connection::connect(postgres_addr, postgres::TlsMode::None) + .expect("Can connect to local Postgres") +} + +pub fn redis_addr() -> (net::TcpStream, net::TcpStream) { + let redis_addr = env::var("REDIS_ADDR").unwrap_or_else(|_| "127.0.0.1:6379".to_string()); + let pubsub_connection = net::TcpStream::connect(&redis_addr).expect("Can connect to Redis"); + pubsub_connection + .set_read_timeout(Some(time::Duration::from_millis(10))) + .expect("Can set read timeout for Redis connection"); + let secondary_redis_connection = + net::TcpStream::connect(&redis_addr).expect("Can connect to Redis"); + secondary_redis_connection + .set_read_timeout(Some(time::Duration::from_millis(10))) + .expect("Can set read timeout for Redis connection"); + (pubsub_connection, secondary_redis_connection) +} + +pub fn socket_address() -> net::SocketAddr { + env::var("SERVER_ADDR") + .unwrap_or_else(|_| "127.0.0.1:4000".to_owned()) + .parse() + .expect("static string") +} diff --git a/src/error.rs b/src/error.rs index bede2f6..4fa7bcd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -30,3 +30,7 
@@ pub fn handle_errors( warp::http::StatusCode::UNAUTHORIZED, )) } + +pub fn unauthorized_list() -> warp::reject::Rejection { + warp::reject::custom("Error: Access to list not authorized") +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..a5b342f --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,38 @@ +//! Streaming server for Mastodon +//! +//! +//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server +//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to +//! connect to the server with the API described [in Mastodon's public API +//! documentation](https://docs.joinmastodon.org/api/streaming/). +//! +//! # Notes on data flow +//! * **Client Request → Warp**: +//! Warp filters for valid requests and parses request data. Based on that data, it generates a `User` +//! representing the client that made the request with data from the client's request and from +//! Postgres. The `User` is authenticated, if appropriate. Warp //! repeatedly polls the +//! StreamManager for information relevant to the User. +//! +//! * **Warp → StreamManager**: +//! A new `StreamManager` is created for each request. The `StreamManager` exists to manage concurrent +//! access to the (single) `Receiver`, which it can access behind an `Arc`. The `StreamManager` +//! polls the `Receiver` for any updates relevant to the current client. If there are updates, the +//! `StreamManager` filters them with the client's filters and passes any matching updates up to Warp. +//! The `StreamManager` is also responsible for sending `subscribe` commands to Redis (via the +//! `Receiver`) when necessary. +//! +//! * **StreamManager → Receiver**: +//! The Receiver receives data from Redis and stores it in a series of queues (one for each +//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages +//! relevant to that StreamManager and removes them from the queue. 
+ +pub mod config; +pub mod error; +pub mod postgres; +pub mod query; +pub mod receiver; +pub mod redis_cmd; +pub mod stream_manager; +pub mod timeline; +pub mod user; +pub mod ws; diff --git a/src/main.rs b/src/main.rs index 3dae80c..32db4d8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,58 +1,20 @@ -//! Streaming server for Mastodon -//! -//! -//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server -//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to -//! connect to the server with the API described [in the public API -//! documentation](https://docs.joinmastodon.org/api/streaming/) -//! -//! # Notes on data flow -//! * **Client Request → Warp**: -//! Warp filters for valid requests and parses request data. Based on that data, it generates a `User` -//! representing the client that made the request. The `User` is authenticated, if appropriate. Warp -//! repeatedly polls the StreamManager for information relevant to the User. -//! -//! * **Warp → StreamManager**: -//! A new `StreamManager` is created for each request. The `StreamManager` exists to manage concurrent -//! access to the (single) `Receiver`, which it can access behind an `Arc`. The `StreamManager` -//! polles the `Receiver` for any updates relvant to the current client. If there are updates, the -//! `StreamManager` filters them with the client's filters and passes any matching updates up to Warp. -//! The `StreamManager` is also responsible for sending `subscribe` commands to Redis (via the -//! `Receiver`) when necessary. -//! -//! * **StreamManger → Receiver**: -//! The Receiver receives data from Redis and stores it in a series of queues (one for each -//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages -//! relevant to that StreamManager and removes them from the queue. 
- -pub mod error; -pub mod query; -pub mod receiver; -pub mod redis_cmd; -pub mod stream; -pub mod timeline; -pub mod user; -pub mod ws; -use dotenv::dotenv; -use futures::stream::Stream; -use futures::Async; -use receiver::Receiver; -use std::env; -use std::net::SocketAddr; -use stream::StreamManager; -use user::{OauthScope::*, Scope, User}; -use warp::path; -use warp::Filter as WarpFilter; +use futures::{stream::Stream, Async}; +use ragequit::{ + any_of, config, error, + stream_manager::StreamManager, + timeline, + user::{Filter::*, User}, + ws, +}; +use warp::{ws::Ws2, Filter as WarpFilter}; fn main() { - pretty_env_logger::init(); - dotenv().ok(); + config::logging_and_env(); + let stream_manager_sse = StreamManager::new(); + let stream_manager_ws = stream_manager_sse.clone(); - let redis_updates = StreamManager::new(Receiver::new()); - let redis_updates_sse = redis_updates.blank_copy(); - let redis_updates_ws = redis_updates.blank_copy(); - - let routes = any_of!( + // Server Sent Events + let sse_routes = any_of!( // GET /api/v1/streaming/user/notification [private; notification filter] timeline::user_notifications(), // GET /api/v1/streaming/user [private; language filter] @@ -77,12 +39,12 @@ fn main() { .untuple_one() .and(warp::sse()) .map(move |timeline: String, user: User, sse: warp::sse::Sse| { - let mut redis_stream = redis_updates_sse.configure_copy(&timeline, user); + let mut stream_manager = stream_manager_sse.manage_new_timeline(&timeline, user); let event_stream = tokio::timer::Interval::new( std::time::Instant::now(), std::time::Duration::from_millis(100), ) - .filter_map(move |_| match redis_stream.poll() { + .filter_map(move |_| match stream_manager.poll() { Ok(Async::Ready(Some(json_value))) => Some(( warp::sse::event(json_value["event"].clone().to_string()), warp::sse::data(json_value["payload"].clone()), @@ -94,86 +56,54 @@ fn main() { .with(warp::reply::with::header("Connection", "keep-alive")) .recover(error::handle_errors); - //let 
redis_updates_ws = StreamManager::new(Receiver::new()); - let websocket = path!("api" / "v1" / "streaming") - .and(Scope::Public.get_access_token()) - .and_then(|token| User::from_access_token(token, Scope::Public)) - .and(warp::query()) - .and(query::Media::to_filter()) - .and(query::Hashtag::to_filter()) - .and(query::List::to_filter()) - .and(warp::ws2()) - .and_then( - move |mut user: User, - q: query::Stream, - m: query::Media, - h: query::Hashtag, - l: query::List, - ws: warp::ws::Ws2| { - let scopes = user.scopes.clone(); - let timeline = match q.stream.as_ref() { - // Public endpoints: - tl @ "public" | tl @ "public:local" if m.is_truthy() => format!("{}:media", tl), - tl @ "public:media" | tl @ "public:local:media" => tl.to_string(), - tl @ "public" | tl @ "public:local" => tl.to_string(), - // Hashtag endpoints: - // TODO: handle missing query - tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, h.tag), - // Private endpoints: User - "user" - if user.id > 0 - && (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) => - { - format!("{}", user.id) - } - "user:notification" - if user.id > 0 - && (scopes.contains(&Read) || scopes.contains(&ReadNotifications)) => - { - user = user.with_notification_filter(); - format!("{}", user.id) - } - // List endpoint: - // TODO: handle missing query - "list" - if user.authorized_for_list(l.list).is_ok() - && (scopes.contains(&Read) || scopes.contains(&ReadList)) => - { - format!("list:{}", l.list) - } + // WebSocket + let websocket_routes = ws::websocket_routes() + .and_then(move |mut user: User, q: ws::Query, ws: Ws2| { + let read_scope = user.scopes.clone(); + let timeline = match q.stream.as_ref() { + // Public endpoints: + tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl), + tl @ "public:media" | tl @ "public:local:media" => tl.to_string(), + tl @ "public" | tl @ "public:local" => tl.to_string(), + // Hashtag endpoints: + // TODO: handle missing query + tl @ "hashtag" | tl @ 
"hashtag:local" => format!("{}:{}", tl, q.hashtag), + // Private endpoints: User + "user" if user.logged_in && (read_scope.all || read_scope.statuses) => { + format!("{}", user.id) + } + "user:notification" if user.logged_in && (read_scope.all || read_scope.notify) => { + user = user.set_filter(Notification); + format!("{}", user.id) + } + // List endpoint: + // TODO: handle missing query + "list" if user.owns_list(q.list) && (read_scope.all || read_scope.lists) => { + format!("list:{}", q.list) + } + // Direct endpoint: + "direct" if user.logged_in && (read_scope.all || read_scope.statuses) => { + "direct".to_string() + } + // Reject unathorized access attempts for private endpoints + "user" | "user:notification" | "direct" | "list" => { + return Err(warp::reject::custom("Error: Invalid Access Token")) + } + // Other endpoints don't exist: + _ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")), + }; + let token = user.access_token.clone(); + let stream_manager = stream_manager_ws.manage_new_timeline(&timeline, user); - // Direct endpoint: - "direct" - if user.id > 0 - && (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) => - { - "direct".to_string() - } - // Reject unathorized access attempts for private endpoints - "user" | "user:notification" | "direct" | "list" => { - return Err(warp::reject::custom("Error: Invalid Access Token")) - } - // Other endpoints don't exist: - _ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")), - }; - let token = user.access_token.clone(); - let stream = redis_updates_ws.configure_copy(&timeline, user); - - Ok(( - ws.on_upgrade(move |socket| ws::send_replies(socket, stream)), - token, - )) - }, - ) + Ok(( + ws.on_upgrade(move |socket| ws::send_replies(socket, stream_manager)), + token, + )) + }) .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); - let address: SocketAddr = env::var("SERVER_ADDR") - .unwrap_or("127.0.0.1:4000".to_owned()) - 
.parse() - .expect("static string"); - let cors = warp::cors() - .allow_any_origin() - .allow_methods(vec!["GET", "OPTIONS"]) - .allow_headers(vec!["Authorization", "Accept", "Cache-Control"]); - warp::serve(websocket.or(routes).with(cors)).run(address); + let cors = config::cross_origin_resource_sharing(); + let address = config::socket_address(); + + warp::serve(websocket_routes.or(sse_routes).with(cors)).run(address); } diff --git a/src/postgres.rs b/src/postgres.rs new file mode 100644 index 0000000..50af32b --- /dev/null +++ b/src/postgres.rs @@ -0,0 +1,53 @@ +//! Postgres queries +use crate::config; + +pub fn query_for_user_data(access_token: &str) -> (i64, Option>, Vec) { + let conn = config::postgres(); + let query_result = conn + .query( + " +SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes +FROM +oauth_access_tokens +INNER JOIN users ON +oauth_access_tokens.resource_owner_id = users.id +WHERE oauth_access_tokens.token = $1 +AND oauth_access_tokens.revoked_at IS NULL +LIMIT 1", + &[&access_token.to_owned()], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + if !query_result.is_empty() { + let only_row = query_result.get(0); + let id: i64 = only_row.get(1); + let scopes = only_row + .get::<_, String>(3) + .split(' ') + .map(|s| s.to_owned()) + .collect(); + let langs: Option> = only_row.get(2); + (id, langs, scopes) + } else { + (-1, None, Vec::new()) + } +} + +pub fn query_list_owner(list_id: i64) -> Option { + let conn = config::postgres(); + // For the Postgres query, `id` = list number; `account_id` = user.id + let rows = &conn + .query( + " +SELECT id, account_id +FROM lists +WHERE id = $1 +LIMIT 1", + &[&list_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + if rows.is_empty() { + None + } else { + Some(rows.get(0).get(1)) + } +} diff --git a/src/receiver.rs b/src/receiver.rs index 567b30d..4bf7d04 100644 --- a/src/receiver.rs +++ 
b/src/receiver.rs @@ -1,165 +1,155 @@ -//! Interfacing with Redis and stream the results on to the `StreamManager` -use crate::redis_cmd; -use crate::user::User; -use futures::stream::Stream; +//! Interface with Redis and stream the results to the `StreamManager` +//! There is only one `Receiver`, which suggests that it's name is bad. +//! +//! **TODO**: Consider changing the name. Maybe RedisConnectionPool? +//! There are many AsyncReadableStreams, though. How do they fit in? +//! Figure this out ASAP. +//! A new one is created every time the Receiver is polled +use crate::{config, pubsub_cmd, redis_cmd}; use futures::{Async, Poll}; use log::info; use regex::Regex; use serde_json::Value; -use std::collections::{HashMap, VecDeque}; -use std::env; -use std::io::{Read, Write}; -use std::net::TcpStream; -use std::time::{Duration, Instant}; +use std::{collections, io::Read, io::Write, net, time}; use tokio::io::{AsyncRead, Error}; use uuid::Uuid; -#[derive(Debug)] -struct MsgQueue { - messages: VecDeque, - last_polled_at: Instant, - redis_channel: String, -} -impl MsgQueue { - fn new(redis_channel: impl std::fmt::Display) -> Self { - let redis_channel = redis_channel.to_string(); - MsgQueue { - messages: VecDeque::new(), - last_polled_at: Instant::now(), - redis_channel, - } - } -} - -/// The item that streams from Redis and is polled by the `StreamManger` +/// The item that streams from Redis and is polled by the `StreamManager` #[derive(Debug)] pub struct Receiver { - pubsub_connection: TcpStream, - secondary_redis_connection: TcpStream, + pubsub_connection: net::TcpStream, + secondary_redis_connection: net::TcpStream, tl: String, - pub user: User, manager_id: Uuid, - msg_queues: HashMap, - clients_per_timeline: HashMap, -} -impl Default for Receiver { - fn default() -> Self { - Self::new() - } + msg_queues: collections::HashMap, + clients_per_timeline: collections::HashMap, } + impl Receiver { + /// Create a new `Receiver`, with its own Redis connections (but, as 
yet, no + /// active subscriptions). pub fn new() -> Self { - let redis_addr = env::var("REDIS_ADDR").unwrap_or("127.0.0.1:6379".to_string()); - let pubsub_connection = TcpStream::connect(&redis_addr).expect("Can connect to Redis"); - pubsub_connection - .set_read_timeout(Some(Duration::from_millis(10))) - .expect("Can set read timeout for Redis connection"); - let secondary_redis_connection = - TcpStream::connect(&redis_addr).expect("Can connect to Redis"); - secondary_redis_connection - .set_read_timeout(Some(Duration::from_millis(10))) - .expect("Can set read timeout for Redis connection"); + let (pubsub_connection, secondary_redis_connection) = config::redis_addr(); Self { pubsub_connection, secondary_redis_connection, tl: String::new(), - user: User::public(), - manager_id: Uuid::new_v4(), - msg_queues: HashMap::new(), - clients_per_timeline: HashMap::new(), + manager_id: Uuid::default(), + msg_queues: collections::HashMap::new(), + clients_per_timeline: collections::HashMap::new(), } } - /// Update the `StreamManager` that is currently polling the `Receiver` - pub fn update(&mut self, id: Uuid, timeline: impl std::fmt::Display) { - self.manager_id = id; + /// Assigns the `Receiver` a new timeline to monitor and runs other + /// first-time setup. + /// + /// Importantly, this method calls `subscribe_or_unsubscribe_as_needed`, + /// so Redis PubSub subscriptions are only updated when a new timeline + /// comes under management for the first time. + pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: &str) { + self.manager_id = manager_id; + self.tl = timeline.to_string(); + let old_value = self + .msg_queues + .insert(self.manager_id, MsgQueue::new(timeline)); + // Consider removing/refactoring + if let Some(value) = old_value { + eprintln!( + "Data was overwritten when it shouldn't have been. 
Old data was: {:#?}", + value + ); + } + self.subscribe_or_unsubscribe_as_needed(timeline); + } + + /// Set the `Receiver`'s manager_id and target_timeline fields to the approprate + /// value to be polled by the current `StreamManager`. + pub fn configure_for_polling(&mut self, manager_id: Uuid, timeline: &str) { + if &manager_id != &self.manager_id { + //println!("New Manager: {}", &manager_id); + } + self.manager_id = manager_id; self.tl = timeline.to_string(); } - /// Send a subscribe command to the Redis PubSub (if needed) - pub fn maybe_subscribe(&mut self, tl: &str) { - info!("Subscribing to {}", &tl); - - let manager_id = self.manager_id; - self.msg_queues.insert(manager_id, MsgQueue::new(tl)); - let current_clients = self - .clients_per_timeline - .entry(tl.to_string()) - .and_modify(|n| *n += 1) - .or_insert(1); - - if *current_clients == 1 { - let subscribe_cmd = redis_cmd::pubsub("subscribe", tl); - self.pubsub_connection - .write_all(&subscribe_cmd) - .expect("Can subscribe to Redis"); - let set_subscribed_cmd = redis_cmd::set(format!("subscribed:timeline:{}", tl), "1"); - self.secondary_redis_connection - .write_all(&set_subscribed_cmd) - .expect("Can set Redis"); - info!("Now subscribed to: {:#?}", &self.msg_queues); - } - } - - /// Drop any PubSub subscriptions that don't have active clients - pub fn unsubscribe_from_empty_channels(&mut self) { - let mut timelines_with_fewer_clients = Vec::new(); + /// Drop any PubSub subscriptions that don't have active clients and check + /// that there's a subscription to the current one. If there isn't, then + /// subscribe to it. 
+ fn subscribe_or_unsubscribe_as_needed(&mut self, tl: &str) { + let mut timelines_to_modify = Vec::new(); + timelines_to_modify.push((tl.to_owned(), 1)); // Keep only message queues that have been polled recently self.msg_queues.retain(|_id, msg_queue| { - if msg_queue.last_polled_at.elapsed() < Duration::from_secs(30) { + if msg_queue.last_polled_at.elapsed() < time::Duration::from_secs(30) { true } else { - timelines_with_fewer_clients.push(msg_queue.redis_channel.clone()); + let timeline = msg_queue.redis_channel.clone(); + timelines_to_modify.push((timeline, -1)); false } }); // Record the lower number of clients subscribed to that channel - for timeline in timelines_with_fewer_clients { + for (timeline, numerical_change) in timelines_to_modify { + let mut need_to_subscribe = false; let count_of_subscribed_clients = self .clients_per_timeline - .entry(timeline.clone()) - .and_modify(|n| *n -= 1) - .or_insert(0); + .entry(timeline.to_owned()) + .and_modify(|n| *n += numerical_change) + .or_insert_with(|| { + need_to_subscribe = true; + 1 + }); // If no clients, unsubscribe from the channel if *count_of_subscribed_clients <= 0 { - self.unsubscribe(&timeline); + info!("Sent unsubscribe command"); + pubsub_cmd!("unsubscribe", self, timeline.clone()); + } + if need_to_subscribe { + info!("Sent subscribe command"); + pubsub_cmd!("subscribe", self, timeline.clone()); } } } - /// Send an unsubscribe command to the Redis PubSub - pub fn unsubscribe(&mut self, tl: &str) { - let unsubscribe_cmd = redis_cmd::pubsub("unsubscribe", tl); - info!("Unsubscribing from {}", &tl); - self.pubsub_connection - .write_all(&unsubscribe_cmd) - .expect("Can unsubscribe from Redis"); - let set_subscribed_cmd = redis_cmd::set(format!("subscribed:timeline:{}", tl), "0"); - self.secondary_redis_connection - .write_all(&set_subscribed_cmd) - .expect("Can set Redis"); - info!("Now subscribed only to: {:#?}", &self.msg_queues); + fn log_number_of_msgs_in_queue(&self) { + let messages_waiting = 
self + .msg_queues + .get(&self.manager_id) + .expect("Guaranteed by match block") + .messages + .len(); + match messages_waiting { + number if number > 10 => { + log::error!("{} messages waiting in the queue", messages_waiting) + } + _ => log::info!("{} messages waiting in the queue", messages_waiting), + } } } -impl Stream for Receiver { +impl Default for Receiver { + fn default() -> Self { + Receiver::new() + } +} + +impl futures::stream::Stream for Receiver { type Item = Value; type Error = Error; fn poll(&mut self) -> Poll, Self::Error> { let mut buffer = vec![0u8; 3000]; - info!("Being polled by: {}", self.manager_id); let timeline = self.tl.clone(); // Record current time as last polled time self.msg_queues .entry(self.manager_id) - .and_modify(|msg_queue| msg_queue.last_polled_at = Instant::now()); + .and_modify(|msg_queue| msg_queue.last_polled_at = time::Instant::now()); // Add any incomming messages to the back of the relevant `msg_queues` // NOTE: This could be more/other than the `msg_queue` currently being polled - let mut async_stream = AsyncReadableStream(&mut self.pubsub_connection); + let mut async_stream = AsyncReadableStream::new(&mut self.pubsub_connection); if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer)? 
{ let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]); // capture everything between `{` and `}` as potential JSON @@ -183,11 +173,14 @@ impl Stream for Receiver { match self .msg_queues .entry(self.manager_id) - .or_insert_with(|| MsgQueue::new(timeline)) + .or_insert_with(|| MsgQueue::new(timeline.clone())) .messages .pop_front() { - Some(value) => Ok(Async::Ready(Some(value))), + Some(value) => { + self.log_number_of_msgs_in_queue(); + Ok(Async::Ready(Some(value))) + } _ => Ok(Async::NotReady), } } @@ -195,12 +188,34 @@ impl Stream for Receiver { impl Drop for Receiver { fn drop(&mut self) { - let timeline = self.tl.clone(); - self.unsubscribe(&timeline); + pubsub_cmd!("unsubscribe", self, self.tl.clone()); } } -struct AsyncReadableStream<'a>(&'a mut TcpStream); +#[derive(Debug, Clone)] +struct MsgQueue { + pub messages: collections::VecDeque, + pub last_polled_at: time::Instant, + pub redis_channel: String, +} + +impl MsgQueue { + pub fn new(redis_channel: impl std::fmt::Display) -> Self { + let redis_channel = redis_channel.to_string(); + MsgQueue { + messages: collections::VecDeque::new(), + last_polled_at: time::Instant::now(), + redis_channel, + } + } +} + +struct AsyncReadableStream<'a>(&'a mut net::TcpStream); +impl<'a> AsyncReadableStream<'a> { + pub fn new(stream: &'a mut net::TcpStream) -> Self { + AsyncReadableStream(stream) + } +} impl<'a> Read for AsyncReadableStream<'a> { fn read(&mut self, buffer: &mut [u8]) -> Result { diff --git a/src/redis_cmd.rs b/src/redis_cmd.rs index 3a9c07c..b032887 100644 --- a/src/redis_cmd.rs +++ b/src/redis_cmd.rs @@ -1,6 +1,26 @@ //! Send raw TCP commands to the Redis server use std::fmt::Display; +/// Send a subscribe or unsubscribe to the Redis PubSub channel +#[macro_export] +macro_rules! 
pubsub_cmd { + ($cmd:expr, $self:expr, $tl:expr) => {{ + info!("Sending {} command to {}", $cmd, $tl); + $self + .pubsub_connection + .write_all(&redis_cmd::pubsub($cmd, $tl)) + .expect("Can send command to Redis"); + let new_value = if $cmd == "subscribe" { "1" } else { "0" }; + $self + .secondary_redis_connection + .write_all(&redis_cmd::set( + format!("subscribed:timeline:{}", $tl), + new_value, + )) + .expect("Can set Redis"); + info!("Now subscribed to: {:#?}", $self.msg_queues); + }}; +} /// Send a `SUBSCRIBE` or `UNSUBSCRIBE` command to a specific timeline pub fn pubsub(command: impl Display, timeline: impl Display) -> Vec { let arg = format!("timeline:{}", timeline); diff --git a/src/stream.rs b/src/stream.rs deleted file mode 100644 index 928f97d..0000000 --- a/src/stream.rs +++ /dev/null @@ -1,93 +0,0 @@ -//! Manage all existing Redis PubSub connection -use crate::receiver::Receiver; -use crate::user::{Filter, User}; -use futures::stream::Stream; -use futures::{Async, Poll}; -use serde_json::json; -use serde_json::Value; -use std::sync::{Arc, Mutex}; -use tokio::io::Error; -use uuid::Uuid; - -/// Struct for manageing all Redis streams -#[derive(Clone, Debug)] -pub struct StreamManager { - receiver: Arc>, - id: uuid::Uuid, - target_timeline: String, - current_user: Option, -} -impl StreamManager { - pub fn new(reciever: Receiver) -> Self { - StreamManager { - receiver: Arc::new(Mutex::new(reciever)), - id: Uuid::default(), - target_timeline: String::new(), - current_user: None, - } - } - - /// Create a blank StreamManager copy - pub fn blank_copy(&self) -> Self { - StreamManager { ..self.clone() } - } - /// Create a StreamManager copy with a new unique id manage subscriptions - pub fn configure_copy(&self, timeline: &String, user: User) -> Self { - let id = Uuid::new_v4(); - let mut receiver = self.receiver.lock().expect("No panic in other threads"); - receiver.update(id, timeline); - receiver.maybe_subscribe(timeline); - StreamManager { - id, - 
current_user: Some(user), - target_timeline: timeline.clone(), - ..self.clone() - } - } -} - -impl Stream for StreamManager { - type Item = Value; - type Error = Error; - - fn poll(&mut self) -> Poll, Self::Error> { - let mut receiver = self - .receiver - .lock() - .expect("StreamManager: No other thread panic"); - receiver.update(self.id, &self.target_timeline.clone()); - match receiver.poll() { - Ok(Async::Ready(Some(value))) => { - let user = self - .clone() - .current_user - .expect("Previously set current user"); - - let user_langs = user.langs.clone(); - let event = value["event"].as_str().expect("Redis string"); - let payload = value["payload"].to_string(); - - match (&user.filter, user_langs) { - (Filter::Notification, _) if event != "notification" => Ok(Async::NotReady), - (Filter::Language, Some(ref user_langs)) - if !user_langs.contains( - &value["payload"]["language"] - .as_str() - .expect("Redis str") - .to_string(), - ) => - { - Ok(Async::NotReady) - } - _ => Ok(Async::Ready(Some(json!( - {"event": event, - "payload": payload,} - )))), - } - } - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Err(e), - } - } -} diff --git a/src/stream_manager.rs b/src/stream_manager.rs new file mode 100644 index 0000000..ecffd8a --- /dev/null +++ b/src/stream_manager.rs @@ -0,0 +1,154 @@ +//! The `StreamManager` is responsible to providing an interface between the `Warp` +//! filters and the underlying mechanics of talking with Redis/managing multiple +//! threads. The `StreamManager` is the only struct that any Warp code should +//! need to communicate with. +//! +//! The `StreamManager`'s interface is very simple. All you can do with it is: +//! * Create a totally new `StreamManger` with no shared data; +//! * Assign an existing `StreamManager` to manage an new timeline/user pair; or +//! * Poll an existing `StreamManager` to see if there are any new messages +//! for clients +//! +//! 
When you poll the `StreamManager`, it is responsible for polling internal data +//! structures, getting any updates from Redis, and then filtering out any updates +//! that should be excluded by relevant filters. +//! +//! Because `StreamManagers` are lightweight data structures that do not directly +//! communicate with Redis, it is appropriate to create a new `StreamManager` for +//! each new client connection. +use crate::{ + receiver::Receiver, + user::{Filter, User}, +}; +use futures::{Async, Poll}; +use serde_json::{json, Value}; +use std::sync; +use std::time; +use tokio::io::Error; +use uuid::Uuid; + +/// Struct for managing all Redis streams. +#[derive(Clone, Default, Debug)] +pub struct StreamManager { + receiver: sync::Arc>, + id: uuid::Uuid, + target_timeline: String, + current_user: User, +} + +impl StreamManager { + /// Create a new `StreamManager` with no shared data. + pub fn new() -> Self { + StreamManager { + receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())), + id: Uuid::default(), + target_timeline: String::new(), + current_user: User::public(), + } + } + + /// Assign the `StreamManager` to manage a new timeline/user pair. + /// + /// Note that this *may or may not* result in a new Redis connection. + /// If the server has already subscribed to the timeline on behalf of + /// a different user, the `StreamManager` is responsible for figuring + /// that out and avoiding duplicated connections. Thus, it is safe to + /// use this method for each new client connection. + pub fn manage_new_timeline(&self, target_timeline: &str, user: User) -> Self { + let manager_id = Uuid::new_v4(); + let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)"); + receiver.manage_new_timeline(manager_id, target_timeline); + StreamManager { + id: manager_id, + current_user: user, + target_timeline: target_timeline.to_owned(), + receiver: self.receiver.clone(), + } + } +} + +/// The stream that the `StreamManager` manages. 
`Poll` is the only method implemented. +impl futures::stream::Stream for StreamManager { + type Item = Value; + type Error = Error; + + /// Checks for any new messages that should be sent to the client. + /// + /// The `StreamManager` will poll underlying data structures and will reply + /// with an `Ok(Ready(Some(Value)))` if there is a new message to send to + /// the client. If there is no new message or if the new message should be + /// filtered out based on one of the user's filters, then the `StreamManager` + /// will reply with `Ok(NotReady)`. The `StreamManager` will buble up any + /// errors from the underlying data structures. + fn poll(&mut self) -> Poll, Self::Error> { + let start_time = time::Instant::now(); + let result = { + let mut receiver = self + .receiver + .lock() + .expect("StreamManager: No other thread panic"); + receiver.configure_for_polling(self.id, &self.target_timeline.clone()); + receiver.poll() + }; + println!("Polling took: {:?}", start_time.elapsed()); + let result = match result { + Ok(Async::Ready(Some(value))) => { + let user_langs = self.current_user.langs.clone(); + let toot = Toot::from_json(value); + toot.ignore_if_caught_by_filter(&self.current_user.filter, user_langs) + } + Ok(inner_value) => Ok(inner_value), + Err(e) => Err(e), + }; + result + } +} + +struct Toot { + category: String, + payload: String, + language: String, +} +impl Toot { + fn from_json(value: Value) -> Self { + Self { + category: value["event"].as_str().expect("Redis string").to_owned(), + payload: value["payload"].to_string(), + language: value["payload"]["language"] + .as_str() + .expect("Redis str") + .to_string(), + } + } + + fn to_optional_json(&self) -> Option { + Some(json!( + {"event": self.category, + "payload": self.payload,} + )) + } + + fn ignore_if_caught_by_filter( + &self, + filter: &Filter, + user_langs: Option>, + ) -> Result>, Error> { + let toot = self; + + let (send_msg, skip_msg) = ( + Ok(Async::Ready(toot.to_optional_json())), + 
Ok(Async::NotReady), + ); + + match &filter { + Filter::NoFilter => send_msg, + Filter::Notification if toot.category == "notification" => send_msg, + // If not, skip it + Filter::Notification => skip_msg, + Filter::Language if user_langs.is_none() => send_msg, + Filter::Language if user_langs.expect("").contains(&toot.language) => send_msg, + // If not, skip it + Filter::Language => skip_msg, + } + } +} diff --git a/src/timeline.rs b/src/timeline.rs index 541ac2b..71b6ab2 100644 --- a/src/timeline.rs +++ b/src/timeline.rs @@ -1,6 +1,8 @@ //! Filters for all the endpoints accessible for Server Sent Event updates +use crate::error; use crate::query; -use crate::user::{Scope, User}; +use crate::user::{Filter::*, Scope, User}; +use crate::user_from_path; use warp::filters::BoxedFilter; use warp::{path, Filter}; @@ -8,14 +10,8 @@ use warp::{path, Filter}; type TimelineUser = ((String, User),); /// GET /api/v1/streaming/user -/// -/// -/// **private**. Filter: `Language` pub fn user() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "user") - .and(path::end()) - .and(Scope::Private.get_access_token()) - .and_then(|token| User::from_access_token(token, Scope::Private)) + user_from_path!("streaming" / "user", Scope::Private) .map(|user: User| (user.id.to_string(), user)) .boxed() } @@ -23,477 +19,84 @@ pub fn user() -> BoxedFilter { /// GET /api/v1/streaming/user/notification /// /// -/// **private**. Filter: `Notification` -/// -/// /// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented? 
pub fn user_notifications() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "user" / "notification") - .and(path::end()) - .and(Scope::Private.get_access_token()) - .and_then(|token| User::from_access_token(token, Scope::Private)) - .map(|user: User| (user.id.to_string(), user.with_notification_filter())) + user_from_path!("streaming" / "user" / "notification", Scope::Private) + .map(|user: User| (user.id.to_string(), user.set_filter(Notification))) .boxed() } /// GET /api/v1/streaming/public -/// -/// -/// **public**. Filter: `Language` pub fn public() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "public") - .and(path::end()) - .and(Scope::Public.get_access_token()) - .and_then(|token| User::from_access_token(token, Scope::Public)) - .map(|user: User| ("public".to_owned(), user.with_language_filter())) + user_from_path!("streaming" / "public", Scope::Public) + .map(|user: User| ("public".to_owned(), user.set_filter(Language))) .boxed() } /// GET /api/v1/streaming/public?only_media=true -/// -/// -/// **public**. Filter: `Language` pub fn public_media() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "public") - .and(path::end()) - .and(Scope::Public.get_access_token()) - .and_then(|token| User::from_access_token(token, Scope::Public)) + user_from_path!("streaming" / "public", Scope::Public) .and(warp::query()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:media".to_owned(), user.with_language_filter()), - _ => ("public".to_owned(), user.with_language_filter()), + "1" | "true" => ("public:media".to_owned(), user.set_filter(Language)), + _ => ("public".to_owned(), user.set_filter(Language)), }) .boxed() } /// GET /api/v1/streaming/public/local -/// -/// -/// **public**. 
Filter: `Language` pub fn public_local() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "public" / "local") - .and(path::end()) - .and(Scope::Public.get_access_token()) - .and_then(|token| User::from_access_token(token, Scope::Public)) - .map(|user: User| ("public:local".to_owned(), user.with_language_filter())) + user_from_path!("streaming" / "public" / "local", Scope::Public) + .map(|user: User| ("public:local".to_owned(), user.set_filter(Language))) .boxed() } /// GET /api/v1/streaming/public/local?only_media=true -/// -/// -/// **public**. Filter: `Language` pub fn public_local_media() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "public" / "local") - .and(Scope::Public.get_access_token()) - .and_then(|token| User::from_access_token(token, Scope::Public)) + user_from_path!("streaming" / "public" / "local", Scope::Public) .and(warp::query()) - .and(path::end()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:local:media".to_owned(), user.with_language_filter()), - _ => ("public:local".to_owned(), user.with_language_filter()), + "1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)), + _ => ("public:local".to_owned(), user.set_filter(Language)), }) .boxed() } /// GET /api/v1/streaming/direct -/// -/// -/// **private**. Filter: `None` pub fn direct() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "direct") - .and(path::end()) - .and(Scope::Private.get_access_token()) - .and_then(|token| User::from_access_token(token, Scope::Private)) - .map(|user: User| (format!("direct:{}", user.id), user.with_no_filter())) + user_from_path!("streaming" / "direct", Scope::Private) + .map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter))) .boxed() } /// GET /api/v1/streaming/hashtag?tag=:hashtag -/// -/// -/// **public**. 
Filter: `None` pub fn hashtag() -> BoxedFilter { path!("api" / "v1" / "streaming" / "hashtag") .and(warp::query()) - .and(path::end()) .map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())) .boxed() } /// GET /api/v1/streaming/hashtag/local?tag=:hashtag -/// -/// -/// **public**. Filter: `None` pub fn hashtag_local() -> BoxedFilter { path!("api" / "v1" / "streaming" / "hashtag" / "local") .and(warp::query()) - .and(path::end()) .map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())) .boxed() } /// GET /api/v1/streaming/list?list=:list_id -/// -/// -/// **private**. Filter: `None` pub fn list() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "list") - .and(Scope::Private.get_access_token()) - .and_then(|token| User::from_access_token(token, Scope::Private)) + user_from_path!("streaming" / "list", Scope::Private) .and(warp::query()) - .and_then(|user: User, q: query::List| (user.authorized_for_list(q.list), Ok(user))) + .and_then(|user: User, q: query::List| { + if user.owns_list(q.list) { + (Ok(q.list), Ok(user)) + } else { + (Err(error::unauthorized_list()), Ok(user)) + } + }) .untuple_one() - .and(path::end()) - .map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter())) + .map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter))) .boxed() } - -/// Combines multiple routes with the same return type together with -/// `or()` and `unify()` -#[macro_export] -macro_rules! 
any_of { - ($filter:expr, $($other_filter:expr),*) => { - $filter$(.or($other_filter).unify())* - }; -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::user; - - #[test] - fn user_unauthorized() { - let value = warp::test::request() - .path(&format!( - "/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1", - )) - .filter(&user()); - assert!(invalid_access_token(value)); - - let value = warp::test::request() - .path(&format!("/api/v1/streaming/user",)) - .filter(&user()); - assert!(no_access_token(value)); - } - - #[test] - #[ignore] - fn user_auth() { - let user_id: i64 = 1; - let access_token = get_access_token(user_id); - - // Query auth - let (actual_timeline, actual_user) = warp::test::request() - .path(&format!( - "/api/v1/streaming/user?access_token={}", - access_token - )) - .filter(&user()) - .expect("in test"); - - let expected_user = - User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test"); - - assert_eq!(actual_timeline, "1"); - assert_eq!(actual_user, expected_user); - - // Header auth - let (actual_timeline, actual_user) = warp::test::request() - .path("/api/v1/streaming/user") - .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&user()) - .expect("in test"); - - let expected_user = - User::from_access_token(access_token, user::Scope::Private).expect("in test"); - - assert_eq!(actual_timeline, "1"); - assert_eq!(actual_user, expected_user); - } - - #[test] - fn user_notifications_unauthorized() { - let value = warp::test::request() - .path(&format!( - "/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN", - )) - .filter(&user_notifications()); - assert!(invalid_access_token(value)); - - let value = warp::test::request() - .path(&format!("/api/v1/streaming/user/notification",)) - .filter(&user_notifications()); - assert!(no_access_token(value)); - } - - #[test] - #[ignore] - fn user_notifications_auth() { - let user_id: i64 = 1; - let access_token = 
get_access_token(user_id); - - // Query auth - let (actual_timeline, actual_user) = warp::test::request() - .path(&format!( - "/api/v1/streaming/user/notification?access_token={}", - access_token - )) - .filter(&user_notifications()) - .expect("in test"); - - let expected_user = User::from_access_token(access_token.clone(), user::Scope::Private) - .expect("in test") - .with_notification_filter(); - - assert_eq!(actual_timeline, "1"); - assert_eq!(actual_user, expected_user); - - // Header auth - let (actual_timeline, actual_user) = warp::test::request() - .path("/api/v1/streaming/user/notification") - .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&user_notifications()) - .expect("in test"); - - let expected_user = User::from_access_token(access_token, user::Scope::Private) - .expect("in test") - .with_notification_filter(); - - assert_eq!(actual_timeline, "1"); - assert_eq!(actual_user, expected_user); - } - #[test] - fn public_timeline() { - let value = warp::test::request() - .path("/api/v1/streaming/public") - .filter(&public()) - .expect("in test"); - - assert_eq!(value.0, "public".to_string()); - assert_eq!(value.1, User::public().with_language_filter()); - } - - #[test] - fn public_media_timeline() { - let value = warp::test::request() - .path("/api/v1/streaming/public?only_media=true") - .filter(&public_media()) - .expect("in test"); - - assert_eq!(value.0, "public:media".to_string()); - assert_eq!(value.1, User::public().with_language_filter()); - - let value = warp::test::request() - .path("/api/v1/streaming/public?only_media=1") - .filter(&public_media()) - .expect("in test"); - - assert_eq!(value.0, "public:media".to_string()); - assert_eq!(value.1, User::public().with_language_filter()); - } - - #[test] - fn public_local_timeline() { - let value = warp::test::request() - .path("/api/v1/streaming/public/local") - .filter(&public_local()) - .expect("in test"); - - assert_eq!(value.0, "public:local".to_string()); - 
assert_eq!(value.1, User::public().with_language_filter()); - } - - #[test] - fn public_local_media_timeline() { - let value = warp::test::request() - .path("/api/v1/streaming/public/local?only_media=true") - .filter(&public_local_media()) - .expect("in test"); - - assert_eq!(value.0, "public:local:media".to_string()); - assert_eq!(value.1, User::public().with_language_filter()); - - let value = warp::test::request() - .path("/api/v1/streaming/public/local?only_media=1") - .filter(&public_local_media()) - .expect("in test"); - - assert_eq!(value.0, "public:local:media".to_string()); - assert_eq!(value.1, User::public().with_language_filter()); - } - - #[test] - fn direct_timeline_unauthorized() { - let value = warp::test::request() - .path(&format!( - "/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN", - )) - .filter(&direct()); - assert!(invalid_access_token(value)); - - let value = warp::test::request() - .path(&format!("/api/v1/streaming/direct",)) - .filter(&direct()); - assert!(no_access_token(value)); - } - - #[test] - #[ignore] - fn direct_timeline_auth() { - let user_id: i64 = 1; - let access_token = get_access_token(user_id); - - // Query auth - let (actual_timeline, actual_user) = warp::test::request() - .path(&format!( - "/api/v1/streaming/direct?access_token={}", - access_token - )) - .filter(&direct()) - .expect("in test"); - - let expected_user = - User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test"); - - assert_eq!(actual_timeline, "direct:1"); - assert_eq!(actual_user, expected_user); - - // Header auth - let (actual_timeline, actual_user) = warp::test::request() - .path("/api/v1/streaming/direct") - .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&direct()) - .expect("in test"); - - let expected_user = - User::from_access_token(access_token, user::Scope::Private).expect("in test"); - - assert_eq!(actual_timeline, "direct:1"); - assert_eq!(actual_user, expected_user); - } - - 
#[test] - fn hashtag_timeline() { - let value = warp::test::request() - .path("/api/v1/streaming/hashtag?tag=a") - .filter(&hashtag()) - .expect("in test"); - - assert_eq!(value.0, "hashtag:a".to_string()); - assert_eq!(value.1, User::public()); - } - - #[test] - fn hashtag_timeline_local() { - let value = warp::test::request() - .path("/api/v1/streaming/hashtag/local?tag=a") - .filter(&hashtag_local()) - .expect("in test"); - - assert_eq!(value.0, "hashtag:a:local".to_string()); - assert_eq!(value.1, User::public()); - } - - #[test] - #[ignore] - fn list_timeline_auth() { - let list_id = 1; - let list_owner_id = get_list_owner(list_id); - let access_token = get_access_token(list_owner_id); - - // Query Auth - let (actual_timeline, actual_user) = warp::test::request() - .path(&format!( - "/api/v1/streaming/list?access_token={}&list={}", - access_token, list_id, - )) - .filter(&list()) - .expect("in test"); - - let expected_user = - User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test"); - - assert_eq!(actual_timeline, "list:1"); - assert_eq!(actual_user, expected_user); - - // Header Auth - let (actual_timeline, actual_user) = warp::test::request() - .path("/api/v1/streaming/list?list=1") - .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&list()) - .expect("in test"); - - let expected_user = - User::from_access_token(access_token, user::Scope::Private).expect("in test"); - - assert_eq!(actual_timeline, "list:1"); - assert_eq!(actual_user, expected_user); - } - - #[test] - fn list_timeline_unauthorized() { - let value = warp::test::request() - .path(&format!( - "/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1", - )) - .filter(&list()); - assert!(invalid_access_token(value)); - - let value = warp::test::request() - .path(&format!("/api/v1/streaming/list?list=1",)) - .filter(&list()); - assert!(no_access_token(value)); - } - - fn get_list_owner(list_number: i32) -> i64 { - let list_number: i64 
= list_number.into(); - let conn = user::connect_to_postgres(); - let rows = &conn - .query( - "SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", - &[&list_number], - ) - .expect("in test"); - - assert_eq!( - rows.len(), - 1, - "Test database must contain at least one user with a list to run this test." - ); - - rows.get(0).get(1) - } - fn get_access_token(user_id: i64) -> String { - let conn = user::connect_to_postgres(); - let rows = &conn - .query( - "SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1", - &[&user_id], - ) - .expect("Can get access token from id"); - rows.get(0).get(0) - } - fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool { - match value { - Err(error) => match error.cause() { - Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => { - true - } - _ => false, - }, - _ => false, - } - } - - fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool { - match value { - Err(error) => match error.cause() { - Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true, - _ => false, - }, - _ => false, - } - } -} diff --git a/src/user.rs b/src/user.rs index 11c2a66..c8db5de 100644 --- a/src/user.rs +++ b/src/user.rs @@ -1,24 +1,21 @@ -//! Create a User by querying the Postgres database with the user's access_token -use crate::{any_of, query}; +//! 
`User` struct and related functionality +use crate::{postgres, query}; use log::info; -use postgres; -use std::env; use warp::Filter as WarpFilter; -/// (currently hardcoded to localhost) -pub fn connect_to_postgres() -> postgres::Connection { - let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or(format!( - "postgres://{}@localhost/mastodon_development", - env::var("USER").expect("User env var should exist") - )); - postgres::Connection::connect(postgres_addr, postgres::TlsMode::None) - .expect("Can connect to local Postgres") +/// Combine multiple routes with the same return type together with +/// `or()` and `unify()` +#[macro_export] +macro_rules! any_of { + ($filter:expr, $($other_filter:expr),*) => { + $filter$(.or($other_filter).unify())* + }; } /// The filters that can be applied to toots after they come from Redis #[derive(Clone, Debug, PartialEq)] pub enum Filter { - None, + NoFilter, Language, Notification, } @@ -28,140 +25,99 @@ pub enum Filter { pub struct User { pub id: i64, pub access_token: String, - pub scopes: Vec, + pub scopes: OauthScope, pub langs: Option>, pub logged_in: bool, pub filter: Filter, } -#[derive(Clone, Debug, PartialEq)] -pub enum OauthScope { - Read, - ReadStatuses, - ReadNotifications, - ReadList, - Other, -} -impl From<&str> for OauthScope { - fn from(scope: &str) -> Self { - use OauthScope::*; - match scope { - "read" => Read, - "read:statuses" => ReadStatuses, - "read:notifications" => ReadNotifications, - "read:lists" => ReadList, - _ => Other, - } +impl Default for User { + fn default() -> Self { + User::public() } } +#[derive(Clone, Debug, Default, PartialEq)] +pub struct OauthScope { + pub all: bool, + pub statuses: bool, + pub notify: bool, + pub lists: bool, +} +impl From> for OauthScope { + fn from(scope_list: Vec) -> Self { + let mut oauth_scope = OauthScope::default(); + for scope in scope_list { + match scope.as_str() { + "read" => oauth_scope.all = true, + "read:statuses" => oauth_scope.statuses = true, + 
"read:notifications" => oauth_scope.notify = true, + "read:lists" => oauth_scope.lists = true, + _ => (), + } + } + oauth_scope + } +} + +/// Create a user based on the supplied path and access scope for the resource +#[macro_export] +macro_rules! user_from_path { + ($($path_item:tt) / *, $scope:expr) => (path!("api" / "v1" / $($path_item) / +) + .and($scope.get_access_token()) + .and_then(|token| User::from_access_token(token, $scope))) +} + impl User { /// Create a user from the access token supplied in the header or query paramaters pub fn from_access_token( access_token: String, scope: Scope, ) -> Result { - let conn = connect_to_postgres(); - let result = &conn - .query( - " -SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes -FROM -oauth_access_tokens -INNER JOIN users ON -oauth_access_tokens.resource_owner_id = users.id -WHERE oauth_access_tokens.token = $1 -AND oauth_access_tokens.revoked_at IS NULL -LIMIT 1", - &[&access_token], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - if !result.is_empty() { - let only_row = result.get(0); - let id: i64 = only_row.get(1); - let scopes = only_row - .get::<_, String>(3) - .split(' ') - .map(|scope: &str| scope.into()) - .filter(|scope| scope != &OauthScope::Other) - .collect(); - dbg!(&scopes); - let langs: Option> = only_row.get(2); - info!("Granting logged-in access"); + let (id, langs, scope_list) = postgres::query_for_user_data(&access_token); + let scopes = OauthScope::from(scope_list); + if id != -1 || scope == Scope::Public { + let (logged_in, log_msg) = match id { + -1 => (false, "Public access to non-authenticated endpoints"), + _ => (true, "Granting logged-in access"), + }; + info!("{}", log_msg); Ok(User { id, access_token, scopes, langs, - logged_in: true, - filter: Filter::None, - }) - } else if let Scope::Public = scope { - info!("Granting public access to non-authenticated client"); - Ok(User { - id: -1, - 
access_token, - scopes: Vec::new(), - langs: None, - logged_in: false, - filter: Filter::None, + logged_in, + filter: Filter::NoFilter, }) } else { Err(warp::reject::custom("Error: Invalid access token")) } } - /// Add a Notification filter - pub fn with_notification_filter(self) -> Self { - Self { - filter: Filter::Notification, - ..self - } - } - /// Add a Language filter - pub fn with_language_filter(self) -> Self { - Self { - filter: Filter::Language, - ..self - } - } - /// Remove all filters - pub fn with_no_filter(self) -> Self { - Self { - filter: Filter::None, - ..self - } + /// Set the Notification/Language filter + pub fn set_filter(self, filter: Filter) -> Self { + Self { filter, ..self } } /// Determine whether the User is authorised for a specified list - pub fn authorized_for_list(&self, list: i64) -> Result { - let conn = connect_to_postgres(); - // For the Postgres query, `id` = list number; `account_id` = user.id - let rows = &conn - .query( - " SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", - &[&list], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - if !rows.is_empty() { - let id_of_account_that_owns_the_list: i64 = rows.get(0).get(1); - if id_of_account_that_owns_the_list == self.id { - return Ok(list); - } - }; - - Err(warp::reject::custom("Error: Invalid access token")) + pub fn owns_list(&self, list: i64) -> bool { + match postgres::query_list_owner(list) { + Some(i) if i == self.id => true, + _ => false, + } } /// A public (non-authenticated) User pub fn public() -> Self { User { id: -1, - access_token: String::new(), - scopes: Vec::new(), + access_token: String::from("no access token"), + scopes: OauthScope::default(), langs: None, logged_in: false, - filter: Filter::None, + filter: Filter::NoFilter, } } } /// Whether the endpoint requires authentication or not +#[derive(PartialEq)] pub enum Scope { Public, Private, diff --git a/src/ws.rs b/src/ws.rs index dc82d72..13c2907 100644 --- a/src/ws.rs +++ 
b/src/ws.rs @@ -1,44 +1,86 @@ //! WebSocket-specific functionality -use crate::stream::StreamManager; +use crate::query; +use crate::stream_manager::StreamManager; +use crate::user::{Scope, User}; +use crate::user_from_path; use futures::future::Future; use futures::stream::Stream; use futures::Async; +use std::time; +use warp::filters::BoxedFilter; +use warp::{path, Filter}; /// Send a stream of replies to a WebSocket client pub fn send_replies( socket: warp::ws::WebSocket, mut stream: StreamManager, ) -> impl futures::future::Future { - let (tx, rx) = futures::sync::mpsc::unbounded(); let (ws_tx, mut ws_rx) = socket.split(); + + // Create a pipe + let (tx, rx) = futures::sync::mpsc::unbounded(); + + // Send one end of it to a different thread and tell that end to forward whatever it gets + // on to the websocket client warp::spawn( rx.map_err(|()| -> warp::Error { unreachable!() }) .forward(ws_tx) .map_err(|_| ()) .map(|_r| ()), ); - let event_stream = tokio::timer::Interval::new( - std::time::Instant::now(), - std::time::Duration::from_millis(100), - ) - .take_while(move |_| { - if ws_rx.poll().is_err() { - futures::future::ok(false) - } else { - futures::future::ok(true) - } - }); + // For as long as the client is still connected, yeild a new event every 100 ms + let event_stream = + tokio::timer::Interval::new(time::Instant::now(), time::Duration::from_millis(100)) + .take_while(move |_| match ws_rx.poll() { + Ok(Async::Ready(None)) => futures::future::ok(false), + _ => futures::future::ok(true), + }); + + // Every time you get an event from that stream, send it through the pipe event_stream .for_each(move |_json_value| { if let Ok(Async::Ready(Some(json_value))) = stream.poll() { let msg = warp::ws::Message::text(json_value.to_string()); - if !tx.is_closed() { - tx.unbounded_send(msg).expect("No send error"); - } + tx.unbounded_send(msg).expect("No send error"); }; Ok(()) }) .then(|msg| msg) .map_err(|e| println!("{}", e)) } + +pub fn websocket_routes() -> 
BoxedFilter<(User, Query, warp::ws::Ws2)> { + user_from_path!("streaming", Scope::Public) + .and(warp::query()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .and(warp::ws2()) + .map( + |user: User, + stream: query::Stream, + media: query::Media, + hashtag: query::Hashtag, + list: query::List, + ws: warp::ws::Ws2| { + let query = Query { + stream: stream.stream, + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + }; + (user, query, ws) + }, + ) + .untuple_one() + .boxed() +} + +#[derive(Debug)] +pub struct Query { + pub stream: String, + pub media: bool, + pub hashtag: String, + pub list: i64, +} diff --git a/tests/test.rs b/tests/test.rs new file mode 100644 index 0000000..431d0c2 --- /dev/null +++ b/tests/test.rs @@ -0,0 +1,341 @@ +use ragequit::{ + config, + timeline::*, + user::{Filter::*, Scope, User}, +}; + +#[test] +fn user_unauthorized() { + let value = warp::test::request() + .path(&format!( + "/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1", + )) + .filter(&user()); + assert!(invalid_access_token(value)); + + let value = warp::test::request() + .path(&format!("/api/v1/streaming/user",)) + .filter(&user()); + assert!(no_access_token(value)); +} + +#[test] +#[ignore] +fn user_auth() { + let user_id: i64 = 1; + let access_token = get_access_token(user_id); + + // Query auth + let (actual_timeline, actual_user) = warp::test::request() + .path(&format!( + "/api/v1/streaming/user?access_token={}", + access_token + )) + .filter(&user()) + .expect("in test"); + + let expected_user = + User::from_access_token(access_token.clone(), Scope::Private).expect("in test"); + + assert_eq!(actual_timeline, "1"); + assert_eq!(actual_user, expected_user); + + // Header auth + let (actual_timeline, actual_user) = warp::test::request() + .path("/api/v1/streaming/user") + .header("Authorization", format!("Bearer: {}", access_token.clone())) + .filter(&user()) + .expect("in test"); + + 
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test"); + + assert_eq!(actual_timeline, "1"); + assert_eq!(actual_user, expected_user); +} + +#[test] +fn user_notifications_unauthorized() { + let value = warp::test::request() + .path(&format!( + "/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN", + )) + .filter(&user_notifications()); + assert!(invalid_access_token(value)); + + let value = warp::test::request() + .path(&format!("/api/v1/streaming/user/notification",)) + .filter(&user_notifications()); + assert!(no_access_token(value)); +} + +#[test] +#[ignore] +fn user_notifications_auth() { + let user_id: i64 = 1; + let access_token = get_access_token(user_id); + + // Query auth + let (actual_timeline, actual_user) = warp::test::request() + .path(&format!( + "/api/v1/streaming/user/notification?access_token={}", + access_token + )) + .filter(&user_notifications()) + .expect("in test"); + + let expected_user = User::from_access_token(access_token.clone(), Scope::Private) + .expect("in test") + .set_filter(Notification); + + assert_eq!(actual_timeline, "1"); + assert_eq!(actual_user, expected_user); + + // Header auth + let (actual_timeline, actual_user) = warp::test::request() + .path("/api/v1/streaming/user/notification") + .header("Authorization", format!("Bearer: {}", access_token.clone())) + .filter(&user_notifications()) + .expect("in test"); + + let expected_user = User::from_access_token(access_token, Scope::Private) + .expect("in test") + .set_filter(Notification); + + assert_eq!(actual_timeline, "1"); + assert_eq!(actual_user, expected_user); +} +#[test] +fn public_timeline() { + let value = warp::test::request() + .path("/api/v1/streaming/public") + .filter(&public()) + .expect("in test"); + + assert_eq!(value.0, "public".to_string()); + assert_eq!(value.1, User::public().set_filter(Language)); +} + +#[test] +fn public_media_timeline() { + let value = warp::test::request() + 
.path("/api/v1/streaming/public?only_media=true") + .filter(&public_media()) + .expect("in test"); + + assert_eq!(value.0, "public:media".to_string()); + assert_eq!(value.1, User::public().set_filter(Language)); + + let value = warp::test::request() + .path("/api/v1/streaming/public?only_media=1") + .filter(&public_media()) + .expect("in test"); + + assert_eq!(value.0, "public:media".to_string()); + assert_eq!(value.1, User::public().set_filter(Language)); +} + +#[test] +fn public_local_timeline() { + let value = warp::test::request() + .path("/api/v1/streaming/public/local") + .filter(&public_local()) + .expect("in test"); + + assert_eq!(value.0, "public:local".to_string()); + assert_eq!(value.1, User::public().set_filter(Language)); +} + +#[test] +fn public_local_media_timeline() { + let value = warp::test::request() + .path("/api/v1/streaming/public/local?only_media=true") + .filter(&public_local_media()) + .expect("in test"); + + assert_eq!(value.0, "public:local:media".to_string()); + assert_eq!(value.1, User::public().set_filter(Language)); + + let value = warp::test::request() + .path("/api/v1/streaming/public/local?only_media=1") + .filter(&public_local_media()) + .expect("in test"); + + assert_eq!(value.0, "public:local:media".to_string()); + assert_eq!(value.1, User::public().set_filter(Language)); +} + +#[test] +fn direct_timeline_unauthorized() { + let value = warp::test::request() + .path(&format!( + "/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN", + )) + .filter(&direct()); + assert!(invalid_access_token(value)); + + let value = warp::test::request() + .path(&format!("/api/v1/streaming/direct",)) + .filter(&direct()); + assert!(no_access_token(value)); +} + +#[test] +#[ignore] +fn direct_timeline_auth() { + let user_id: i64 = 1; + let access_token = get_access_token(user_id); + + // Query auth + let (actual_timeline, actual_user) = warp::test::request() + .path(&format!( + "/api/v1/streaming/direct?access_token={}", + access_token + )) + 
.filter(&direct()) + .expect("in test"); + + let expected_user = + User::from_access_token(access_token.clone(), Scope::Private).expect("in test"); + + assert_eq!(actual_timeline, "direct:1"); + assert_eq!(actual_user, expected_user); + + // Header auth + let (actual_timeline, actual_user) = warp::test::request() + .path("/api/v1/streaming/direct") + .header("Authorization", format!("Bearer: {}", access_token.clone())) + .filter(&direct()) + .expect("in test"); + + let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test"); + + assert_eq!(actual_timeline, "direct:1"); + assert_eq!(actual_user, expected_user); +} + +#[test] +fn hashtag_timeline() { + let value = warp::test::request() + .path("/api/v1/streaming/hashtag?tag=a") + .filter(&hashtag()) + .expect("in test"); + + assert_eq!(value.0, "hashtag:a".to_string()); + assert_eq!(value.1, User::public()); +} + +#[test] +fn hashtag_timeline_local() { + let value = warp::test::request() + .path("/api/v1/streaming/hashtag/local?tag=a") + .filter(&hashtag_local()) + .expect("in test"); + + assert_eq!(value.0, "hashtag:a:local".to_string()); + assert_eq!(value.1, User::public()); +} + +#[test] +#[ignore] +fn list_timeline_auth() { + let list_id = 1; + let list_owner_id = get_list_owner(list_id); + let access_token = get_access_token(list_owner_id); + + // Query Auth + let (actual_timeline, actual_user) = warp::test::request() + .path(&format!( + "/api/v1/streaming/list?access_token={}&list={}", + access_token, list_id, + )) + .filter(&list()) + .expect("in test"); + + let expected_user = + User::from_access_token(access_token.clone(), Scope::Private).expect("in test"); + + assert_eq!(actual_timeline, "list:1"); + assert_eq!(actual_user, expected_user); + + // Header Auth + let (actual_timeline, actual_user) = warp::test::request() + .path("/api/v1/streaming/list?list=1") + .header("Authorization", format!("Bearer: {}", access_token.clone())) + .filter(&list()) + .expect("in test"); + + 
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test"); + + assert_eq!(actual_timeline, "list:1"); + assert_eq!(actual_user, expected_user); +} + +#[test] +fn list_timeline_unauthorized() { + let value = warp::test::request() + .path(&format!( + "/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1", + )) + .filter(&list()); + assert!(invalid_access_token(value)); + + let value = warp::test::request() + .path(&format!("/api/v1/streaming/list?list=1",)) + .filter(&list()); + assert!(no_access_token(value)); +} + +// Helper functions for tests +fn get_list_owner(list_number: i32) -> i64 { + let list_number: i64 = list_number.into(); + let conn = config::postgres(); + let rows = &conn + .query( + "SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", + &[&list_number], + ) + .expect("in test"); + + assert_eq!( + rows.len(), + 1, + "Test database must contain at least one user with a list to run this test." + ); + + rows.get(0).get(1) +} + +fn get_access_token(user_id: i64) -> String { + let conn = config::postgres(); + let rows = &conn + .query( + "SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1", + &[&user_id], + ) + .expect("Can get access token from id"); + rows.get(0).get(0) +} + +fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool { + match value { + Err(error) => match error.cause() { + Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => true, + _ => false, + }, + _ => false, + } +} + +fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool { + match value { + Err(error) => match error.cause() { + // The cause could validly be any of these, depending on the order they're checked + // (It would pass with just one, so the last one it doesn't have is "the" cause) + Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true, + Some(c) if format!("{:?}", c) == "InvalidQuery" => true, + Some(c) 
if format!("{:?}", c) == "MissingHeader(\"Sec-WebSocket-Protocol\")" => true, + _ => false, + }, + _ => false, + } +} From d6ae45b292444ac6cb6bc62c7d873e13738b7b7e Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Mon, 8 Jul 2019 07:31:42 -0400 Subject: [PATCH 3/4] Code reorganization --- Cargo.lock | 43 +++--- src/.env | 1 + src/config.rs | 76 ++++++++-- src/error.rs | 36 ----- src/lib.rs | 29 ++-- src/main.rs | 91 ++++++------ src/parse_client_request/mod.rs | 4 + src/parse_client_request/query.rs | 45 ++++++ src/parse_client_request/sse.rs | 106 +++++++++++++ .../user/mod.rs} | 3 +- .../user}/postgres.rs | 0 src/parse_client_request/ws.rs | 43 ++++++ src/query.rs | 66 --------- .../client_agent.rs} | 114 +++++++------- src/redis_to_client_stream/mod.rs | 75 ++++++++++ src/{ => redis_to_client_stream}/receiver.rs | 140 ++++++++++-------- src/{ => redis_to_client_stream}/redis_cmd.rs | 0 src/timeline.rs | 102 ------------- src/ws.rs | 86 ----------- tests/test.rs | 52 +++---- 20 files changed, 590 insertions(+), 522 deletions(-) delete mode 100644 src/error.rs create mode 100644 src/parse_client_request/mod.rs create mode 100644 src/parse_client_request/query.rs create mode 100644 src/parse_client_request/sse.rs rename src/{user.rs => parse_client_request/user/mod.rs} (98%) rename src/{ => parse_client_request/user}/postgres.rs (100%) create mode 100644 src/parse_client_request/ws.rs delete mode 100644 src/query.rs rename src/{stream_manager.rs => redis_to_client_stream/client_agent.rs} (51%) create mode 100644 src/redis_to_client_stream/mod.rs rename src/{ => redis_to_client_stream}/receiver.rs (71%) rename src/{ => redis_to_client_stream}/redis_cmd.rs (100%) delete mode 100644 src/timeline.rs delete mode 100644 src/ws.rs diff --git a/Cargo.lock b/Cargo.lock index 767381e..0c11ccf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,7 +27,7 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "libc 0.2.54 
(registry+https://github.com/rust-lang/crates.io-index)", - "termion 1.5.2 (registry+https://github.com/rust-lang/crates.io-index)", + "termion 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", "winapi 0.3.7 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -144,11 +144,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "chrono" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "num-integer 0.1.39 (registry+https://github.com/rust-lang/crates.io-index)", - "num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.54 (registry+https://github.com/rust-lang/crates.io-index)", + "num-integer 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)", "time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -246,14 +247,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "env_logger" -version = "0.6.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", "humantime 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)", - "termcolor 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", + "termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -637,16 +638,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "num-integer" -version = "0.1.39" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", + 
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] name = "num-traits" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", +] [[package]] name = "num_cpus" @@ -782,8 +787,8 @@ name = "pretty_env_logger" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "chrono 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", - "env_logger 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", + "chrono 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)", + "env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -1168,7 +1173,7 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "wincolor 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1176,7 +1181,7 @@ dependencies = [ [[package]] name = "termion" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "libc 0.2.54 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1597,7 +1602,7 @@ dependencies = [ "checksum bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c" "checksum cc 1.0.36 (registry+https://github.com/rust-lang/crates.io-index)" = "a0c56216487bb80eec9c4516337b2588a4f2a2290d72a1416d930e4dcdb0c90d" "checksum cfg-if 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "11d43355396e872eefb45ce6342e4374ed7bc2b3a502d1b28e36d6e23c05d1f4" -"checksum chrono 0.4.6 
(registry+https://github.com/rust-lang/crates.io-index)" = "45912881121cb26fad7c38c17ba7daa18764771836b34fab7d3fbd93ed633878" +"checksum chrono 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)" = "77d81f58b7301084de3b958691458a53c3f7e0b1d702f77e550b6a88e3a88abe" "checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f" "checksum constant_time_eq 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8ff012e225ce166d4422e0e78419d901719760f62ae2b7969ca6b564d1b54a9e" "checksum crossbeam-deque 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b18cd2e169ad86297e6bc0ad9aa679aee9daa4f19e8163860faf7c164e4f5a71" @@ -1609,7 +1614,7 @@ dependencies = [ "checksum digest 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05f47366984d3ad862010e22c7ce81a7dbcaebbdfb37241a620f8b6596ee135c" "checksum dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7bdb5b956a911106b6b479cdc6bc1364d359a32299f17b49994f5327132e18d9" "checksum dtoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "ea57b42383d091c85abcc2706240b94ab2a8fa1fc81c10ff23c4de06e2a90b5e" -"checksum env_logger 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b61fa891024a945da30a9581546e8cfaf5602c7b3f4c137a2805cf388f92075a" +"checksum env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)" = "aafcde04e90a5226a6443b7aabdb016ba2f8307c847d524724bd9b346dd1a2d3" "checksum failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "795bd83d3abeb9220f257e597aa0080a508b27533824adf336529648f6abf7e2" "checksum failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "ea1063915fd7ef4309e222a5a07cf9c319fb9c7836b1f89b85458672dbb127e1" "checksum fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" @@ 
-1655,8 +1660,8 @@ dependencies = [ "checksum miow 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "8c1f2f3b1cf331de6896aabf6e9d55dca90356cc9960cca7eaaf408a355ae919" "checksum net2 0.2.33 (registry+https://github.com/rust-lang/crates.io-index)" = "42550d9fb7b6684a6d404d9fa7250c2eb2646df731d1c06afc06dcee9e1bcf88" "checksum nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "2f9667ddcc6cc8a43afc9b7917599d7216aa09c463919ea32c59ed6cac8bc945" -"checksum num-integer 0.1.39 (registry+https://github.com/rust-lang/crates.io-index)" = "e83d528d2677f0518c570baf2b7abdcf0cd2d248860b68507bdcb3e91d4c0cea" -"checksum num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3a5d7cc97d6d30d8b9bc8fa19bf45349ffe46241e8816f50f62f6d6aaabee1" +"checksum num-integer 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "8b8af8caa3184078cd419b430ff93684cb13937970fcb7639f728992f33ce674" +"checksum num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "d9c79c952a4a139f44a0fe205c4ee66ce239c0e6ce72cd935f5f7e2f717549dd" "checksum num_cpus 1.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1a23f0ed30a54abaa0c7e83b1d2d87ada7c3c23078d1d87815af3e3b6385fbba" "checksum numtoa 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b8f8bdf33df195859076e54ab11ee78a1b208382d3a26ec40d142ffc1ecc49ef" "checksum opaque-debug 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "93f5bb2e8e8dec81642920ccff6b61f1eb94fa3020c5a325c9851ff604152409" @@ -1716,8 +1721,8 @@ dependencies = [ "checksum stringprep 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1" "checksum syn 0.15.34 (registry+https://github.com/rust-lang/crates.io-index)" = "a1393e4a97a19c01e900df2aec855a29f71cf02c402e2f443b8d2747c25c5dbe" "checksum synstructure 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)" = 
"73687139bf99285483c96ac0add482c3776528beac1d97d444f6e91f203a2015" -"checksum termcolor 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4096add70612622289f2fdcdbd5086dc81c1e2675e6ae58d6c4f62a16c6d7f2f" -"checksum termion 1.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "dde0593aeb8d47accea5392b39350015b5eccb12c0d98044d856983d89548dea" +"checksum termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96d6098003bde162e4277c70665bd87c326f5a0c3f3fbfb285787fa482d54e6e" +"checksum termion 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a8fb22f7cde82c8220e5aeacb3258ed7ce996142c77cba193f203515e26c330" "checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b" "checksum time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "db8dcfca086c1143c9270ac42a2bbd8a7ee477b78ac8e45b19abfb0cbede4b6f" "checksum tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)" = "cec6c34409089be085de9403ba2010b80e36938c9ca992c4f67f407bb13db0b1" diff --git a/src/.env b/src/.env index 9e3e3c8..94e0c32 100644 --- a/src/.env +++ b/src/.env @@ -3,3 +3,4 @@ #SERVER_ADDR= #REDIS_ADDR= #POSTGRES_ADDR= +CORS_ALLOWED_METHODS="GET OPTIONS" diff --git a/src/config.rs b/src/config.rs index fad70a4..19e011a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,14 +1,26 @@ -//! Configuration settings for servers and databases +//! 
Configuration settings and custom errors for servers and databases use dotenv::dotenv; use log::warn; +use serde_derive::Serialize; use std::{env, net, time}; +const CORS_ALLOWED_METHODS: [&str; 2] = ["GET", "OPTIONS"]; +const CORS_ALLOWED_HEADERS: [&str; 3] = ["Authorization", "Accept", "Cache-Control"]; +const DEFAULT_POSTGRES_ADDR: &str = "postgres://@localhost/mastodon_development"; +const DEFAULT_REDIS_ADDR: &str = "127.0.0.1:6379"; +const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000"; + +/// The frequency with which the StreamAgent will poll for updates to send via SSE +pub const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100; +pub const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100; +pub const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100; + /// Configure CORS for the API server pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors { warp::cors() .allow_any_origin() - .allow_methods(vec!["GET", "OPTIONS"]) - .allow_headers(vec!["Authorization", "Accept", "Cache-Control"]) + .allow_methods(CORS_ALLOWED_METHODS.to_vec()) + .allow_headers(CORS_ALLOWED_HEADERS.to_vec()) } /// Initialize logging and read values from `src/.env` @@ -20,20 +32,22 @@ pub fn logging_and_env() { /// Configure Postgres and return a connection pub fn postgres() -> postgres::Connection { let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or_else(|_| { - format!( - "postgres://{}@localhost/mastodon_development", - env::var("USER").unwrap_or_else(|_| { - warn!("No USER env variable set. Connecting to Postgress with default `postgres` user"); - "postgres".to_owned() - }) - ) + let mut postgres_addr = DEFAULT_POSTGRES_ADDR.to_string(); + postgres_addr.insert_str(11, + &env::var("USER").unwrap_or_else(|_| { + warn!("No USER env variable set. 
Connecting to Postgress with default `postgres` user"); + "postgres".to_string() + }).as_str() + ); + postgres_addr }); postgres::Connection::connect(postgres_addr, postgres::TlsMode::None) .expect("Can connect to local Postgres") } +/// Configure Redis pub fn redis_addr() -> (net::TcpStream, net::TcpStream) { - let redis_addr = env::var("REDIS_ADDR").unwrap_or_else(|_| "127.0.0.1:6379".to_string()); + let redis_addr = env::var("REDIS_ADDR").unwrap_or_else(|_| DEFAULT_REDIS_ADDR.to_owned()); let pubsub_connection = net::TcpStream::connect(&redis_addr).expect("Can connect to Redis"); pubsub_connection .set_read_timeout(Some(time::Duration::from_millis(10))) @@ -48,7 +62,45 @@ pub fn redis_addr() -> (net::TcpStream, net::TcpStream) { pub fn socket_address() -> net::SocketAddr { env::var("SERVER_ADDR") - .unwrap_or_else(|_| "127.0.0.1:4000".to_owned()) + .unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned()) .parse() .expect("static string") } + +#[derive(Serialize)] +pub struct ErrorMessage { + error: String, +} +impl ErrorMessage { + fn new(msg: impl std::fmt::Display) -> Self { + Self { + error: msg.to_string(), + } + } +} + +/// Recover from Errors by sending appropriate Warp::Rejections +pub fn handle_errors( + rejection: warp::reject::Rejection, +) -> Result { + let err_txt = match rejection.cause() { + Some(text) if text.to_string() == "Missing request header 'authorization'" => { + "Error: Missing access token".to_string() + } + Some(text) => text.to_string(), + None => "Error: Nonexistant endpoint".to_string(), + }; + let json = warp::reply::json(&ErrorMessage::new(err_txt)); + Ok(warp::reply::with_status( + json, + warp::http::StatusCode::UNAUTHORIZED, + )) +} + +pub struct CustomError {} + +impl CustomError { + pub fn unauthorized_list() -> warp::reject::Rejection { + warp::reject::custom("Error: Access to list not authorized") + } +} diff --git a/src/error.rs b/src/error.rs deleted file mode 100644 index 4fa7bcd..0000000 --- a/src/error.rs +++ /dev/null @@ 
-1,36 +0,0 @@ -//! Custom Errors and Warp::Rejections -use serde_derive::Serialize; - -#[derive(Serialize)] -pub struct ErrorMessage { - error: String, -} -impl ErrorMessage { - fn new(msg: impl std::fmt::Display) -> Self { - Self { - error: msg.to_string(), - } - } -} - -/// Recover from Errors by sending appropriate Warp::Rejections -pub fn handle_errors( - rejection: warp::reject::Rejection, -) -> Result { - let err_txt = match rejection.cause() { - Some(text) if text.to_string() == "Missing request header 'authorization'" => { - "Error: Missing access token".to_string() - } - Some(text) => text.to_string(), - None => "Error: Nonexistant endpoint".to_string(), - }; - let json = warp::reply::json(&ErrorMessage::new(err_txt)); - Ok(warp::reply::with_status( - json, - warp::http::StatusCode::UNAUTHORIZED, - )) -} - -pub fn unauthorized_list() -> warp::reject::Rejection { - warp::reject::custom("Error: Access to list not authorized") -} diff --git a/src/lib.rs b/src/lib.rs index a5b342f..4d129eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,28 +11,21 @@ //! Warp filters for valid requests and parses request data. Based on that data, it generates a `User` //! representing the client that made the request with data from the client's request and from //! Postgres. The `User` is authenticated, if appropriate. Warp //! repeatedly polls the -//! StreamManager for information relevant to the User. +//! ClientAgent for information relevant to the User. //! -//! * **Warp → StreamManager**: -//! A new `StreamManager` is created for each request. The `StreamManager` exists to manage concurrent -//! access to the (single) `Receiver`, which it can access behind an `Arc`. The `StreamManager` +//! * **Warp → ClientAgent**: +//! A new `ClientAgent` is created for each request. The `ClientAgent` exists to manage concurrent +//! access to the (single) `Receiver`, which it can access behind an `Arc`. The `ClientAgent` //! 
polls the `Receiver` for any updates relevant to the current client. If there are updates, the -//! `StreamManager` filters them with the client's filters and passes any matching updates up to Warp. -//! The `StreamManager` is also responsible for sending `subscribe` commands to Redis (via the +//! `ClientAgent` filters them with the client's filters and passes any matching updates up to Warp. +//! The `ClientAgent` is also responsible for sending `subscribe` commands to Redis (via the //! `Receiver`) when necessary. //! -//! * **StreamManager → Receiver**: +//! * **ClientAgent → Receiver**: //! The Receiver receives data from Redis and stores it in a series of queues (one for each -//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages -//! relevant to that StreamManager and removes them from the queue. +//! ClientAgent). When (asynchronously) polled by the ClientAgent, it sends back the messages +//! relevant to that ClientAgent and removes them from the queue. 
pub mod config; -pub mod error; -pub mod postgres; -pub mod query; -pub mod receiver; -pub mod redis_cmd; -pub mod stream_manager; -pub mod timeline; -pub mod user; -pub mod ws; +pub mod parse_client_request; +pub mod redis_to_client_stream; diff --git a/src/main.rs b/src/main.rs index 32db4d8..ad8ee66 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,83 +1,84 @@ -use futures::{stream::Stream, Async}; use ragequit::{ - any_of, config, error, - stream_manager::StreamManager, - timeline, - user::{Filter::*, User}, - ws, + any_of, config, + parse_client_request::{sse, user, ws}, + redis_to_client_stream, + redis_to_client_stream::ClientAgent, }; + use warp::{ws::Ws2, Filter as WarpFilter}; fn main() { config::logging_and_env(); - let stream_manager_sse = StreamManager::new(); - let stream_manager_ws = stream_manager_sse.clone(); + let client_agent_sse = ClientAgent::blank(); + let client_agent_ws = client_agent_sse.clone_with_shared_receiver(); // Server Sent Events + // + // For SSE, the API requires users to use different endpoints, so we first filter based on + // the endpoint. Using that endpoint determine the `timeline` the user is requesting, + // the scope for that `timeline`, and authenticate the `User` if they provided a token. 
let sse_routes = any_of!( // GET /api/v1/streaming/user/notification [private; notification filter] - timeline::user_notifications(), + sse::Request::user_notifications(), // GET /api/v1/streaming/user [private; language filter] - timeline::user(), + sse::Request::user(), // GET /api/v1/streaming/public/local?only_media=true [public; language filter] - timeline::public_local_media(), + sse::Request::public_local_media(), // GET /api/v1/streaming/public?only_media=true [public; language filter] - timeline::public_media(), + sse::Request::public_media(), // GET /api/v1/streaming/public/local [public; language filter] - timeline::public_local(), + sse::Request::public_local(), // GET /api/v1/streaming/public [public; language filter] - timeline::public(), + sse::Request::public(), // GET /api/v1/streaming/direct [private; *no* filter] - timeline::direct(), + sse::Request::direct(), // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] - timeline::hashtag(), + sse::Request::hashtag(), // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] - timeline::hashtag_local(), + sse::Request::hashtag_local(), // GET /api/v1/streaming/list?list=:list_id [private; no filter] - timeline::list() + sse::Request::list() ) .untuple_one() .and(warp::sse()) - .map(move |timeline: String, user: User, sse: warp::sse::Sse| { - let mut stream_manager = stream_manager_sse.manage_new_timeline(&timeline, user); - let event_stream = tokio::timer::Interval::new( - std::time::Instant::now(), - std::time::Duration::from_millis(100), - ) - .filter_map(move |_| match stream_manager.poll() { - Ok(Async::Ready(Some(json_value))) => Some(( - warp::sse::event(json_value["event"].clone().to_string()), - warp::sse::data(json_value["payload"].clone()), - )), - _ => None, - }); - sse.reply(warp::sse::keep(event_stream, None)) - }) + .map( + move |timeline: String, user: user::User, sse_connection_to_client: warp::sse::Sse| { + // Create a new ClientAgent + let mut client_agent = 
client_agent_sse.clone_with_shared_receiver(); + // Assign that agent to generate a stream of updates for the user/timeline pair + client_agent.init_for_user(&timeline, user); + // send the updates through the SSE connection + redis_to_client_stream::send_updates_to_sse(client_agent, sse_connection_to_client) + }, + ) .with(warp::reply::with::header("Connection", "keep-alive")) - .recover(error::handle_errors); + .recover(config::handle_errors); // WebSocket - let websocket_routes = ws::websocket_routes() - .and_then(move |mut user: User, q: ws::Query, ws: Ws2| { + // + // For WS, the API specifies a single endpoint, so we extract the User/timeline pair + // directy from the query + let websocket_routes = ws::extract_user_and_query() + .and_then(move |mut user: user::User, q: ws::Query, ws: Ws2| { + let token = user.access_token.clone(); let read_scope = user.scopes.clone(); + let timeline = match q.stream.as_ref() { // Public endpoints: tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl), tl @ "public:media" | tl @ "public:local:media" => tl.to_string(), tl @ "public" | tl @ "public:local" => tl.to_string(), // Hashtag endpoints: - // TODO: handle missing query tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag), // Private endpoints: User "user" if user.logged_in && (read_scope.all || read_scope.statuses) => { format!("{}", user.id) } "user:notification" if user.logged_in && (read_scope.all || read_scope.notify) => { - user = user.set_filter(Notification); + user = user.set_filter(user::Filter::Notification); format!("{}", user.id) } // List endpoint: - // TODO: handle missing query "list" if user.owns_list(q.list) && (read_scope.all || read_scope.lists) => { format!("list:{}", q.list) } @@ -92,11 +93,17 @@ fn main() { // Other endpoints don't exist: _ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")), }; - let token = user.access_token.clone(); - let stream_manager = 
stream_manager_ws.manage_new_timeline(&timeline, user); + // Create a new ClientAgent + let mut client_agent = client_agent_ws.clone_with_shared_receiver(); + // Assign that agent to generate a stream of updates for the user/timeline pair + client_agent.init_for_user(&timeline, user); + // send the updates through the WS connection (along with the User's access_token + // which is sent for security) Ok(( - ws.on_upgrade(move |socket| ws::send_replies(socket, stream_manager)), + ws.on_upgrade(move |socket| { + redis_to_client_stream::send_updates_to_ws(socket, client_agent) + }), token, )) }) diff --git a/src/parse_client_request/mod.rs b/src/parse_client_request/mod.rs new file mode 100644 index 0000000..d870823 --- /dev/null +++ b/src/parse_client_request/mod.rs @@ -0,0 +1,4 @@ +pub mod query; +pub mod sse; +pub mod user; +pub mod ws; diff --git a/src/parse_client_request/query.rs b/src/parse_client_request/query.rs new file mode 100644 index 0000000..a14d559 --- /dev/null +++ b/src/parse_client_request/query.rs @@ -0,0 +1,45 @@ +//! Validate query prarams with type checking +use serde_derive::Deserialize; +use warp::filters::BoxedFilter; +use warp::Filter as WarpFilter; + +macro_rules! 
query { + ($name:tt => $parameter:tt:$type:tt) => { + #[derive(Deserialize, Debug, Default)] + pub struct $name { + pub $parameter: $type, + } + impl $name { + pub fn to_filter() -> BoxedFilter<(Self,)> { + warp::query() + .or(warp::any().map(Self::default)) + .unify() + .boxed() + } + } + }; +} +query!(Media => only_media:String); +impl Media { + pub fn is_truthy(&self) -> bool { + self.only_media == "true" || self.only_media == "1" + } +} +query!(Hashtag => tag: String); +query!(List => list: i64); +query!(Auth => access_token: String); +query!(Stream => stream: String); +impl ToString for Stream { + fn to_string(&self) -> String { + format!("{:?}", self) + } +} + +pub fn optional_media_query() -> BoxedFilter<(Media,)> { + warp::query() + .or(warp::any().map(|| Media { + only_media: "false".to_owned(), + })) + .unify() + .boxed() +} diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse.rs new file mode 100644 index 0000000..fa3fa5d --- /dev/null +++ b/src/parse_client_request/sse.rs @@ -0,0 +1,106 @@ +//! Filters for all the endpoints accessible for Server Sent Event updates +use super::{ + query, + user::{Filter::*, Scope, User}, +}; +use crate::{config::CustomError, user_from_path}; +use warp::{filters::BoxedFilter, path, Filter}; + +#[allow(dead_code)] +type TimelineUser = ((String, User),); + +pub enum Request {} + +impl Request { + /// GET /api/v1/streaming/user + pub fn user() -> BoxedFilter { + user_from_path!("streaming" / "user", Scope::Private) + .map(|user: User| (user.id.to_string(), user)) + .boxed() + } + + /// GET /api/v1/streaming/user/notification + /// + /// + /// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented? 
+ pub fn user_notifications() -> BoxedFilter { + user_from_path!("streaming" / "user" / "notification", Scope::Private) + .map(|user: User| (user.id.to_string(), user.set_filter(Notification))) + .boxed() + } + + /// GET /api/v1/streaming/public + pub fn public() -> BoxedFilter { + user_from_path!("streaming" / "public", Scope::Public) + .map(|user: User| ("public".to_owned(), user.set_filter(Language))) + .boxed() + } + + /// GET /api/v1/streaming/public?only_media=true + pub fn public_media() -> BoxedFilter { + user_from_path!("streaming" / "public", Scope::Public) + .and(warp::query()) + .map(|user: User, q: query::Media| match q.only_media.as_ref() { + "1" | "true" => ("public:media".to_owned(), user.set_filter(Language)), + _ => ("public".to_owned(), user.set_filter(Language)), + }) + .boxed() + } + + /// GET /api/v1/streaming/public/local + pub fn public_local() -> BoxedFilter { + user_from_path!("streaming" / "public" / "local", Scope::Public) + .map(|user: User| ("public:local".to_owned(), user.set_filter(Language))) + .boxed() + } + + /// GET /api/v1/streaming/public/local?only_media=true + pub fn public_local_media() -> BoxedFilter { + user_from_path!("streaming" / "public" / "local", Scope::Public) + .and(warp::query()) + .map(|user: User, q: query::Media| match q.only_media.as_ref() { + "1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)), + _ => ("public:local".to_owned(), user.set_filter(Language)), + }) + .boxed() + } + + /// GET /api/v1/streaming/direct + pub fn direct() -> BoxedFilter { + user_from_path!("streaming" / "direct", Scope::Private) + .map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter))) + .boxed() + } + + /// GET /api/v1/streaming/hashtag?tag=:hashtag + pub fn hashtag() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "hashtag") + .and(warp::query()) + .map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())) + .boxed() + } + + /// GET 
/api/v1/streaming/hashtag/local?tag=:hashtag + pub fn hashtag_local() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "hashtag" / "local") + .and(warp::query()) + .map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())) + .boxed() + } + + /// GET /api/v1/streaming/list?list=:list_id + pub fn list() -> BoxedFilter { + user_from_path!("streaming" / "list", Scope::Private) + .and(warp::query()) + .and_then(|user: User, q: query::List| { + if user.owns_list(q.list) { + (Ok(q.list), Ok(user)) + } else { + (Err(CustomError::unauthorized_list()), Ok(user)) + } + }) + .untuple_one() + .map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter))) + .boxed() + } +} diff --git a/src/user.rs b/src/parse_client_request/user/mod.rs similarity index 98% rename from src/user.rs rename to src/parse_client_request/user/mod.rs index c8db5de..5684ccc 100644 --- a/src/user.rs +++ b/src/parse_client_request/user/mod.rs @@ -1,5 +1,6 @@ //! `User` struct and related functionality -use crate::{postgres, query}; +mod postgres; +use crate::parse_client_request::query; use log::info; use warp::Filter as WarpFilter; diff --git a/src/postgres.rs b/src/parse_client_request/user/postgres.rs similarity index 100% rename from src/postgres.rs rename to src/parse_client_request/user/postgres.rs diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs new file mode 100644 index 0000000..77ab1d6 --- /dev/null +++ b/src/parse_client_request/ws.rs @@ -0,0 +1,43 @@ +//! 
WebSocket functionality +use super::{ + query, + user::{Scope, User}, +}; +use crate::user_from_path; +use warp::{filters::BoxedFilter, path, Filter}; + +/// WebSocket filters +pub fn extract_user_and_query() -> BoxedFilter<(User, Query, warp::ws::Ws2)> { + user_from_path!("streaming", Scope::Public) + .and(warp::query()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .and(warp::ws2()) + .map( + |user: User, + stream: query::Stream, + media: query::Media, + hashtag: query::Hashtag, + list: query::List, + ws: warp::ws::Ws2| { + let query = Query { + stream: stream.stream, + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + }; + (user, query, ws) + }, + ) + .untuple_one() + .boxed() +} + +#[derive(Debug)] +pub struct Query { + pub stream: String, + pub media: bool, + pub hashtag: String, + pub list: i64, +} diff --git a/src/query.rs b/src/query.rs deleted file mode 100644 index 8b43d71..0000000 --- a/src/query.rs +++ /dev/null @@ -1,66 +0,0 @@ -//! 
Validate query prarams with type checking -use serde_derive::Deserialize; -use warp::filters::BoxedFilter; -use warp::Filter as WarpFilter; - -#[derive(Deserialize, Debug, Default)] -pub struct Media { - pub only_media: String, -} -impl Media { - pub fn to_filter() -> BoxedFilter<(Self,)> { - warp::query() - .or(warp::any().map(Self::default)) - .unify() - .boxed() - } - pub fn is_truthy(&self) -> bool { - self.only_media == "true" || self.only_media == "1" - } -} -#[derive(Deserialize, Debug, Default)] -pub struct Hashtag { - pub tag: String, -} -impl Hashtag { - pub fn to_filter() -> BoxedFilter<(Self,)> { - warp::query() - .or(warp::any().map(Self::default)) - .unify() - .boxed() - } -} -#[derive(Deserialize, Debug, Default)] -pub struct List { - pub list: i64, -} -impl List { - pub fn to_filter() -> BoxedFilter<(Self,)> { - warp::query() - .or(warp::any().map(Self::default)) - .unify() - .boxed() - } -} -#[derive(Deserialize, Debug)] -pub struct Auth { - pub access_token: String, -} -#[derive(Deserialize, Debug)] -pub struct Stream { - pub stream: String, -} -impl ToString for Stream { - fn to_string(&self) -> String { - format!("{:?}", self) - } -} - -pub fn optional_media_query() -> BoxedFilter<(Media,)> { - warp::query() - .or(warp::any().map(|| Media { - only_media: "false".to_owned(), - })) - .unify() - .boxed() -} diff --git a/src/stream_manager.rs b/src/redis_to_client_stream/client_agent.rs similarity index 51% rename from src/stream_manager.rs rename to src/redis_to_client_stream/client_agent.rs index ecffd8a..74ae4e8 100644 --- a/src/stream_manager.rs +++ b/src/redis_to_client_stream/client_agent.rs @@ -1,45 +1,41 @@ -//! The `StreamManager` is responsible to providing an interface between the `Warp` -//! filters and the underlying mechanics of talking with Redis/managing multiple -//! threads. The `StreamManager` is the only struct that any Warp code should -//! need to communicate with. +//! 
Provides an interface between the `Warp` filters and the underlying
+//! mechanics of talking with Redis/managing multiple threads.
 //!
-//! The `StreamManager`'s interface is very simple. All you can do with it is:
-//! * Create a totally new `StreamManger` with no shared data;
-//! * Assign an existing `StreamManager` to manage an new timeline/user pair; or
-//! * Poll an existing `StreamManager` to see if there are any new messages
+//! The `ClientAgent`'s interface is very simple. All you can do with it is:
+//! * Create a totally new `ClientAgent` with no shared data;
+//! * Clone an existing `ClientAgent`, sharing the `Receiver`;
+//! * to manage a new timeline/user pair; or
+//! * Poll an existing `ClientAgent` to see if there are any new messages
 //! for clients
 //!
-//! When you poll the `StreamManager`, it is responsible for polling internal data
+//! When you poll the `ClientAgent`, it is responsible for polling internal data
 //! structures, getting any updates from Redis, and then filtering out any updates
 //! that should be excluded by relevant filters.
 //!
 //! Because `StreamManagers` are lightweight data structures that do not directly
-//! communicate with Redis, it is appropriate to create a new `StreamManager` for
-//! each new client connection.
-use crate::{
-    receiver::Receiver,
-    user::{Filter, User},
-};
+//! communicate with Redis, we create a new `ClientAgent` for
+//! each new client connection (each in its own thread).
+use super::receiver::Receiver;
+use crate::parse_client_request::user::User;
 use futures::{Async, Poll};
 use serde_json::{json, Value};
-use std::sync;
-use std::time;
+use std::{sync, time};
 use tokio::io::Error;
 use uuid::Uuid;
 
 /// Struct for managing all Redis streams.
 #[derive(Clone, Default, Debug)]
-pub struct StreamManager {
+pub struct ClientAgent {
     receiver: sync::Arc>,
     id: uuid::Uuid,
     target_timeline: String,
     current_user: User,
 }
 
-impl StreamManager {
-    /// Create a new `StreamManager` with no shared data. 
-    pub fn new() -> Self {
-        StreamManager {
+impl ClientAgent {
+    /// Create a new `ClientAgent` with no shared data.
+    pub fn blank() -> Self {
+        ClientAgent {
             receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())),
             id: Uuid::default(),
             target_timeline: String::new(),
@@ -47,38 +43,44 @@ impl StreamManager {
         }
     }
 
-    /// Assign the `StreamManager` to manage a new timeline/user pair.
+    /// Clones the `ClientAgent`, sharing the `Receiver`.
+    pub fn clone_with_shared_receiver(&self) -> Self {
+        Self {
+            receiver: self.receiver.clone(),
+            id: self.id,
+            target_timeline: self.target_timeline.clone(),
+            current_user: self.current_user.clone(),
+        }
+    }
+    /// Initializes the `ClientAgent` with a unique ID, a `User`, and the target timeline.
+    /// Also passes values to the `Receiver` for its initialization.
     ///
     /// Note that this *may or may not* result in a new Redis connection.
     /// If the server has already subscribed to the timeline on behalf of
-    /// a different user, the `StreamManager` is responsible for figuring
+    /// a different user, the `Receiver` is responsible for figuring
     /// that out and avoiding duplicated connections.  Thus, it is safe to
     /// use this method for each new client connection.
-    pub fn manage_new_timeline(&self, target_timeline: &str, user: User) -> Self {
-        let manager_id = Uuid::new_v4();
+    pub fn init_for_user(&mut self, target_timeline: &str, user: User) {
+        self.id = Uuid::new_v4();
+        self.target_timeline = target_timeline.to_owned();
+        self.current_user = user;
         let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)");
-        receiver.manage_new_timeline(manager_id, target_timeline);
-        StreamManager {
-            id: manager_id,
-            current_user: user,
-            target_timeline: target_timeline.to_owned(),
-            receiver: self.receiver.clone(),
-        }
+        receiver.manage_new_timeline(self.id, target_timeline);
     }
 }
 
-/// The stream that the `StreamManager` manages.  `Poll` is the only method implemented. 
-impl futures::stream::Stream for StreamManager {
+/// The stream that the `ClientAgent` manages.  `Poll` is the only method implemented.
+impl futures::stream::Stream for ClientAgent {
     type Item = Value;
     type Error = Error;
 
     /// Checks for any new messages that should be sent to the client.
     ///
-    /// The `StreamManager` will poll underlying data structures and will reply
-    /// with an `Ok(Ready(Some(Value)))` if there is a new message to send to
+    /// The `ClientAgent` polls the `Receiver` and replies
+    /// with `Ok(Ready(Some(Value)))` if there is a new message to send to
     /// the client.  If there is no new message or if the new message should be
-    /// filtered out based on one of the user's filters, then the `StreamManager`
-    /// will reply with `Ok(NotReady)`.  The `StreamManager` will buble up any
+    /// filtered out based on one of the user's filters, then the `ClientAgent`
+    /// replies with `Ok(NotReady)`.  The `ClientAgent` bubbles up any
     /// errors from the underlying data structures. 
fn poll(&mut self) -> Poll, Self::Error> { let start_time = time::Instant::now(); @@ -86,30 +88,35 @@ impl futures::stream::Stream for StreamManager { let mut receiver = self .receiver .lock() - .expect("StreamManager: No other thread panic"); + .expect("ClientAgent: No other thread panic"); receiver.configure_for_polling(self.id, &self.target_timeline.clone()); receiver.poll() }; - println!("Polling took: {:?}", start_time.elapsed()); - let result = match result { + + if start_time.elapsed() > time::Duration::from_millis(20) { + println!("Polling took: {:?}", start_time.elapsed()); + } + match result { Ok(Async::Ready(Some(value))) => { - let user_langs = self.current_user.langs.clone(); + let user = &self.current_user; let toot = Toot::from_json(value); - toot.ignore_if_caught_by_filter(&self.current_user.filter, user_langs) + toot.filter(&user) } Ok(inner_value) => Ok(inner_value), Err(e) => Err(e), - }; - result + } } } +/// The message to send to the client (which might not literally be a toot in some cases). struct Toot { category: String, payload: String, language: String, } + impl Toot { + /// Construct a `Toot` from well-formed JSON. fn from_json(value: Value) -> Self { Self { category: value["event"].as_str().expect("Redis string").to_owned(), @@ -121,6 +128,7 @@ impl Toot { } } + /// Convert a `Toot` to JSON inside an Option. fn to_optional_json(&self) -> Option { Some(json!( {"event": self.category, @@ -128,11 +136,8 @@ impl Toot { )) } - fn ignore_if_caught_by_filter( - &self, - filter: &Filter, - user_langs: Option>, - ) -> Result>, Error> { + /// Filter out any `Toot`'s that fail the provided filter. 
+ fn filter(&self, user: &User) -> Result>, Error> { let toot = self; let (send_msg, skip_msg) = ( @@ -140,13 +145,14 @@ impl Toot { Ok(Async::NotReady), ); - match &filter { + use crate::parse_client_request::user::Filter; + match &user.filter { Filter::NoFilter => send_msg, Filter::Notification if toot.category == "notification" => send_msg, // If not, skip it Filter::Notification => skip_msg, - Filter::Language if user_langs.is_none() => send_msg, - Filter::Language if user_langs.expect("").contains(&toot.language) => send_msg, + Filter::Language if user.langs.is_none() => send_msg, + Filter::Language if user.langs.clone().expect("").contains(&toot.language) => send_msg, // If not, skip it Filter::Language => skip_msg, } diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs new file mode 100644 index 0000000..b8ae146 --- /dev/null +++ b/src/redis_to_client_stream/mod.rs @@ -0,0 +1,75 @@ +pub mod client_agent; +pub mod receiver; +pub mod redis_cmd; + +use crate::config; +pub use client_agent::ClientAgent; +use futures::{future::Future, stream::Stream, Async}; +use std::{env, time}; + +pub fn send_updates_to_sse( + mut client_agent: ClientAgent, + connection: warp::sse::Sse, +) -> impl warp::reply::Reply { + let sse_update_interval = env::var("SSE_UPDATE_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(config::DEFAULT_SSE_UPDATE_INTERVAL); + let event_stream = tokio::timer::Interval::new( + time::Instant::now(), + time::Duration::from_millis(sse_update_interval), + ) + .filter_map(move |_| match client_agent.poll() { + Ok(Async::Ready(Some(json_value))) => Some(( + warp::sse::event(json_value["event"].clone().to_string()), + warp::sse::data(json_value["payload"].clone()), + )), + _ => None, + }); + + connection.reply(warp::sse::keep(event_stream, None)) +} + +/// Send a stream of replies to a WebSocket client +pub fn send_updates_to_ws( + socket: warp::ws::WebSocket, + mut stream: ClientAgent, +) -> impl 
futures::future::Future { + let (ws_tx, mut ws_rx) = socket.split(); + + // Create a pipe + let (tx, rx) = futures::sync::mpsc::unbounded(); + + // Send one end of it to a different thread and tell that end to forward whatever it gets + // on to the websocket client + warp::spawn( + rx.map_err(|()| -> warp::Error { unreachable!() }) + .forward(ws_tx) + .map_err(|_| ()) + .map(|_r| ()), + ); + + // For as long as the client is still connected, yeild a new event every 100 ms + let ws_update_interval = env::var("WS_UPDATE_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(config::DEFAULT_WS_UPDATE_INTERVAL); + let event_stream = tokio::timer::Interval::new( + time::Instant::now(), + time::Duration::from_millis(ws_update_interval), + ) + .take_while(move |_| match ws_rx.poll() { + Ok(Async::Ready(None)) => futures::future::ok(false), + _ => futures::future::ok(true), + }); + + // Every time you get an event from that stream, send it through the pipe + event_stream + .for_each(move |_json_value| { + if let Ok(Async::Ready(Some(json_value))) = stream.poll() { + let msg = warp::ws::Message::text(json_value.to_string()); + tx.unbounded_send(msg).expect("No send error"); + }; + Ok(()) + }) + .then(|msg| msg) + .map_err(|e| println!("{}", e)) +} diff --git a/src/receiver.rs b/src/redis_to_client_stream/receiver.rs similarity index 71% rename from src/receiver.rs rename to src/redis_to_client_stream/receiver.rs index 4bf7d04..88e57f6 100644 --- a/src/receiver.rs +++ b/src/redis_to_client_stream/receiver.rs @@ -1,16 +1,13 @@ -//! Interface with Redis and stream the results to the `StreamManager` -//! There is only one `Receiver`, which suggests that it's name is bad. -//! -//! **TODO**: Consider changing the name. Maybe RedisConnectionPool? -//! There are many AsyncReadableStreams, though. How do they fit in? -//! Figure this out ASAP. -//! A new one is created every time the Receiver is polled -use crate::{config, pubsub_cmd, redis_cmd}; +//! 
Receives data from Redis, sorts it by `ClientAgent`, and stores it until +//! polled by the correct `ClientAgent`. Also manages sububscriptions and +//! unsubscriptions to/from Redis. +use super::redis_cmd; +use crate::{config, pubsub_cmd}; use futures::{Async, Poll}; use log::info; use regex::Regex; use serde_json::Value; -use std::{collections, io::Read, io::Write, net, time}; +use std::{collections, env, io::Read, io::Write, net, time}; use tokio::io::{AsyncRead, Error}; use uuid::Uuid; @@ -19,7 +16,8 @@ use uuid::Uuid; pub struct Receiver { pubsub_connection: net::TcpStream, secondary_redis_connection: net::TcpStream, - tl: String, + redis_polled_at: time::Instant, + timeline: String, manager_id: Uuid, msg_queues: collections::HashMap, clients_per_timeline: collections::HashMap, @@ -33,7 +31,8 @@ impl Receiver { Self { pubsub_connection, secondary_redis_connection, - tl: String::new(), + redis_polled_at: time::Instant::now(), + timeline: String::new(), manager_id: Uuid::default(), msg_queues: collections::HashMap::new(), clients_per_timeline: collections::HashMap::new(), @@ -43,60 +42,60 @@ impl Receiver { /// Assigns the `Receiver` a new timeline to monitor and runs other /// first-time setup. /// - /// Importantly, this method calls `subscribe_or_unsubscribe_as_needed`, + /// Note: this method calls `subscribe_or_unsubscribe_as_needed`, /// so Redis PubSub subscriptions are only updated when a new timeline /// comes under management for the first time. pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: &str) { self.manager_id = manager_id; - self.tl = timeline.to_string(); - let old_value = self - .msg_queues + self.timeline = timeline.to_string(); + self.msg_queues .insert(self.manager_id, MsgQueue::new(timeline)); - // Consider removing/refactoring - if let Some(value) = old_value { - eprintln!( - "Data was overwritten when it shouldn't have been. 
Old data was: {:#?}", - value - ); - } self.subscribe_or_unsubscribe_as_needed(timeline); } /// Set the `Receiver`'s manager_id and target_timeline fields to the approprate /// value to be polled by the current `StreamManager`. pub fn configure_for_polling(&mut self, manager_id: Uuid, timeline: &str) { - if &manager_id != &self.manager_id { - //println!("New Manager: {}", &manager_id); - } self.manager_id = manager_id; - self.tl = timeline.to_string(); + self.timeline = timeline.to_string(); } /// Drop any PubSub subscriptions that don't have active clients and check /// that there's a subscription to the current one. If there isn't, then /// subscribe to it. - fn subscribe_or_unsubscribe_as_needed(&mut self, tl: &str) { + fn subscribe_or_unsubscribe_as_needed(&mut self, timeline: &str) { let mut timelines_to_modify = Vec::new(); - timelines_to_modify.push((tl.to_owned(), 1)); + struct Change { + timeline: String, + change_in_subscriber_number: i32, + } + + timelines_to_modify.push(Change { + timeline: timeline.to_owned(), + change_in_subscriber_number: 1, + }); // Keep only message queues that have been polled recently self.msg_queues.retain(|_id, msg_queue| { if msg_queue.last_polled_at.elapsed() < time::Duration::from_secs(30) { true } else { - let timeline = msg_queue.redis_channel.clone(); - timelines_to_modify.push((timeline, -1)); + let timeline = &msg_queue.redis_channel; + timelines_to_modify.push(Change { + timeline: timeline.to_owned(), + change_in_subscriber_number: -1, + }); false } }); // Record the lower number of clients subscribed to that channel - for (timeline, numerical_change) in timelines_to_modify { + for change in timelines_to_modify { let mut need_to_subscribe = false; let count_of_subscribed_clients = self .clients_per_timeline - .entry(timeline.to_owned()) - .and_modify(|n| *n += numerical_change) + .entry(change.timeline.clone()) + .and_modify(|n| *n += change.change_in_subscriber_number) .or_insert_with(|| { need_to_subscribe = true; 1 
@@ -104,11 +103,38 @@ impl Receiver { // If no clients, unsubscribe from the channel if *count_of_subscribed_clients <= 0 { info!("Sent unsubscribe command"); - pubsub_cmd!("unsubscribe", self, timeline.clone()); + pubsub_cmd!("unsubscribe", self, change.timeline.clone()); } if need_to_subscribe { info!("Sent subscribe command"); - pubsub_cmd!("subscribe", self, timeline.clone()); + pubsub_cmd!("subscribe", self, change.timeline.clone()); + } + } + } + + /// Polls Redis for any new messages and adds them to the `MsgQueue` for + /// the appropriate `ClientAgent`. + fn poll_redis(&mut self) { + let mut buffer = vec![0u8; 3000]; + // Add any incoming messages to the back of the relevant `msg_queues` + // NOTE: This could be more/other than the `msg_queue` currently being polled + let mut async_stream = AsyncReadableStream::new(&mut self.pubsub_connection); + if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer).unwrap() { + let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]); + // capture everything between `{` and `}` as potential JSON + let json_regex = Regex::new(r"(?P\{.*\})").expect("Hard-coded"); + // capture the timeline so we know which queues to add it to + let timeline_regex = Regex::new(r"timeline:(?P.*?)\r").expect("Hard-codded"); + if let Some(result) = json_regex.captures(raw_redis_response) { + let timeline = + timeline_regex.captures(raw_redis_response).unwrap()["timeline"].to_string(); + + let msg: Value = serde_json::from_str(&result["json"].to_string().clone()).unwrap(); + for msg_queue in self.msg_queues.values_mut() { + if msg_queue.redis_channel == timeline { + msg_queue.messages.push_back(msg.clone()); + } + } } } } @@ -128,47 +154,41 @@ impl Receiver { } } } + impl Default for Receiver { fn default() -> Self { Receiver::new() } } +/// The stream that the ClientAgent polls to learn about new messages. 
impl futures::stream::Stream for Receiver {
     type Item = Value;
     type Error = Error;
 
+    /// Returns the oldest message in the `ClientAgent`'s queue (if any).
+    ///
+    /// Note: This method does **not** poll Redis every time, because polling
+    /// Redis is significantly more time consuming than simply returning the
+    /// message already in a queue.  Thus, we only poll Redis if it has not
+    /// been polled lately.
     fn poll(&mut self) -> Poll, Self::Error> {
-        let mut buffer = vec![0u8; 3000];
-        let timeline = self.tl.clone();
+        let timeline = self.timeline.clone();
+
+        let redis_poll_interval = env::var("REDIS_POLL_INTERVAL")
+            .map(|s| s.parse().expect("Valid config"))
+            .unwrap_or(config::DEFAULT_REDIS_POLL_INTERVAL);
+
+        if self.redis_polled_at.elapsed() > time::Duration::from_millis(redis_poll_interval) {
+            self.poll_redis();
+            self.redis_polled_at = time::Instant::now();
+        }
 
         // Record current time as last polled time
         self.msg_queues
             .entry(self.manager_id)
             .and_modify(|msg_queue| msg_queue.last_polled_at = time::Instant::now());
 
-        // Add any incomming messages to the back of the relevant `msg_queues`
-        // NOTE: This could be more/other than the `msg_queue` currently being polled
-        let mut async_stream = AsyncReadableStream::new(&mut self.pubsub_connection);
-        if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer)? 
{ - let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]); - // capture everything between `{` and `}` as potential JSON - let json_regex = Regex::new(r"(?P\{.*\})").expect("Hard-coded"); - // capture the timeline so we know which queues to add it to - let timeline_regex = Regex::new(r"timeline:(?P.*?)\r").expect("Hard-codded"); - if let Some(result) = json_regex.captures(raw_redis_response) { - let timeline = - timeline_regex.captures(raw_redis_response).unwrap()["timeline"].to_string(); - - let msg: Value = serde_json::from_str(&result["json"].to_string().clone())?; - for msg_queue in self.msg_queues.values_mut() { - if msg_queue.redis_channel == timeline { - msg_queue.messages.push_back(msg.clone()); - } - } - } - } - // If the `msg_queue` being polled has any new messages, return the first (oldest) one match self .msg_queues @@ -188,7 +208,7 @@ impl futures::stream::Stream for Receiver { impl Drop for Receiver { fn drop(&mut self) { - pubsub_cmd!("unsubscribe", self, self.tl.clone()); + pubsub_cmd!("unsubscribe", self, self.timeline.clone()); } } diff --git a/src/redis_cmd.rs b/src/redis_to_client_stream/redis_cmd.rs similarity index 100% rename from src/redis_cmd.rs rename to src/redis_to_client_stream/redis_cmd.rs diff --git a/src/timeline.rs b/src/timeline.rs deleted file mode 100644 index 71b6ab2..0000000 --- a/src/timeline.rs +++ /dev/null @@ -1,102 +0,0 @@ -//! 
Filters for all the endpoints accessible for Server Sent Event updates -use crate::error; -use crate::query; -use crate::user::{Filter::*, Scope, User}; -use crate::user_from_path; -use warp::filters::BoxedFilter; -use warp::{path, Filter}; - -#[allow(dead_code)] -type TimelineUser = ((String, User),); - -/// GET /api/v1/streaming/user -pub fn user() -> BoxedFilter { - user_from_path!("streaming" / "user", Scope::Private) - .map(|user: User| (user.id.to_string(), user)) - .boxed() -} - -/// GET /api/v1/streaming/user/notification -/// -/// -/// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented? -pub fn user_notifications() -> BoxedFilter { - user_from_path!("streaming" / "user" / "notification", Scope::Private) - .map(|user: User| (user.id.to_string(), user.set_filter(Notification))) - .boxed() -} - -/// GET /api/v1/streaming/public -pub fn public() -> BoxedFilter { - user_from_path!("streaming" / "public", Scope::Public) - .map(|user: User| ("public".to_owned(), user.set_filter(Language))) - .boxed() -} - -/// GET /api/v1/streaming/public?only_media=true -pub fn public_media() -> BoxedFilter { - user_from_path!("streaming" / "public", Scope::Public) - .and(warp::query()) - .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:media".to_owned(), user.set_filter(Language)), - _ => ("public".to_owned(), user.set_filter(Language)), - }) - .boxed() -} - -/// GET /api/v1/streaming/public/local -pub fn public_local() -> BoxedFilter { - user_from_path!("streaming" / "public" / "local", Scope::Public) - .map(|user: User| ("public:local".to_owned(), user.set_filter(Language))) - .boxed() -} - -/// GET /api/v1/streaming/public/local?only_media=true -pub fn public_local_media() -> BoxedFilter { - user_from_path!("streaming" / 
"public" / "local", Scope::Public) - .and(warp::query()) - .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)), - _ => ("public:local".to_owned(), user.set_filter(Language)), - }) - .boxed() -} - -/// GET /api/v1/streaming/direct -pub fn direct() -> BoxedFilter { - user_from_path!("streaming" / "direct", Scope::Private) - .map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter))) - .boxed() -} - -/// GET /api/v1/streaming/hashtag?tag=:hashtag -pub fn hashtag() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "hashtag") - .and(warp::query()) - .map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())) - .boxed() -} - -/// GET /api/v1/streaming/hashtag/local?tag=:hashtag -pub fn hashtag_local() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "hashtag" / "local") - .and(warp::query()) - .map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())) - .boxed() -} - -/// GET /api/v1/streaming/list?list=:list_id -pub fn list() -> BoxedFilter { - user_from_path!("streaming" / "list", Scope::Private) - .and(warp::query()) - .and_then(|user: User, q: query::List| { - if user.owns_list(q.list) { - (Ok(q.list), Ok(user)) - } else { - (Err(error::unauthorized_list()), Ok(user)) - } - }) - .untuple_one() - .map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter))) - .boxed() -} diff --git a/src/ws.rs b/src/ws.rs deleted file mode 100644 index 13c2907..0000000 --- a/src/ws.rs +++ /dev/null @@ -1,86 +0,0 @@ -//! 
WebSocket-specific functionality -use crate::query; -use crate::stream_manager::StreamManager; -use crate::user::{Scope, User}; -use crate::user_from_path; -use futures::future::Future; -use futures::stream::Stream; -use futures::Async; -use std::time; -use warp::filters::BoxedFilter; -use warp::{path, Filter}; - -/// Send a stream of replies to a WebSocket client -pub fn send_replies( - socket: warp::ws::WebSocket, - mut stream: StreamManager, -) -> impl futures::future::Future { - let (ws_tx, mut ws_rx) = socket.split(); - - // Create a pipe - let (tx, rx) = futures::sync::mpsc::unbounded(); - - // Send one end of it to a different thread and tell that end to forward whatever it gets - // on to the websocket client - warp::spawn( - rx.map_err(|()| -> warp::Error { unreachable!() }) - .forward(ws_tx) - .map_err(|_| ()) - .map(|_r| ()), - ); - - // For as long as the client is still connected, yeild a new event every 100 ms - let event_stream = - tokio::timer::Interval::new(time::Instant::now(), time::Duration::from_millis(100)) - .take_while(move |_| match ws_rx.poll() { - Ok(Async::Ready(None)) => futures::future::ok(false), - _ => futures::future::ok(true), - }); - - // Every time you get an event from that stream, send it through the pipe - event_stream - .for_each(move |_json_value| { - if let Ok(Async::Ready(Some(json_value))) = stream.poll() { - let msg = warp::ws::Message::text(json_value.to_string()); - tx.unbounded_send(msg).expect("No send error"); - }; - Ok(()) - }) - .then(|msg| msg) - .map_err(|e| println!("{}", e)) -} - -pub fn websocket_routes() -> BoxedFilter<(User, Query, warp::ws::Ws2)> { - user_from_path!("streaming", Scope::Public) - .and(warp::query()) - .and(query::Media::to_filter()) - .and(query::Hashtag::to_filter()) - .and(query::List::to_filter()) - .and(warp::ws2()) - .map( - |user: User, - stream: query::Stream, - media: query::Media, - hashtag: query::Hashtag, - list: query::List, - ws: warp::ws::Ws2| { - let query = Query { - stream: 
stream.stream, - media: media.is_truthy(), - hashtag: hashtag.tag, - list: list.list, - }; - (user, query, ws) - }, - ) - .untuple_one() - .boxed() -} - -#[derive(Debug)] -pub struct Query { - pub stream: String, - pub media: bool, - pub hashtag: String, - pub list: i64, -} diff --git a/tests/test.rs b/tests/test.rs index 431d0c2..3f1a3d4 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,7 +1,7 @@ use ragequit::{ config, - timeline::*, - user::{Filter::*, Scope, User}, + parse_client_request::sse::Request, + parse_client_request::user::{Filter::*, Scope, User}, }; #[test] @@ -10,12 +10,12 @@ fn user_unauthorized() { .path(&format!( "/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1", )) - .filter(&user()); + .filter(&Request::user()); assert!(invalid_access_token(value)); let value = warp::test::request() .path(&format!("/api/v1/streaming/user",)) - .filter(&user()); + .filter(&Request::user()); assert!(no_access_token(value)); } @@ -31,7 +31,7 @@ fn user_auth() { "/api/v1/streaming/user?access_token={}", access_token )) - .filter(&user()) + .filter(&Request::user()) .expect("in test"); let expected_user = @@ -44,7 +44,7 @@ fn user_auth() { let (actual_timeline, actual_user) = warp::test::request() .path("/api/v1/streaming/user") .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&user()) + .filter(&Request::user()) .expect("in test"); let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test"); @@ -59,12 +59,12 @@ fn user_notifications_unauthorized() { .path(&format!( "/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN", )) - .filter(&user_notifications()); + .filter(&Request::user_notifications()); assert!(invalid_access_token(value)); let value = warp::test::request() .path(&format!("/api/v1/streaming/user/notification",)) - .filter(&user_notifications()); + .filter(&Request::user_notifications()); assert!(no_access_token(value)); } @@ -80,7 +80,7 @@ fn 
user_notifications_auth() { "/api/v1/streaming/user/notification?access_token={}", access_token )) - .filter(&user_notifications()) + .filter(&Request::user_notifications()) .expect("in test"); let expected_user = User::from_access_token(access_token.clone(), Scope::Private) @@ -94,7 +94,7 @@ fn user_notifications_auth() { let (actual_timeline, actual_user) = warp::test::request() .path("/api/v1/streaming/user/notification") .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&user_notifications()) + .filter(&Request::user_notifications()) .expect("in test"); let expected_user = User::from_access_token(access_token, Scope::Private) @@ -108,7 +108,7 @@ fn user_notifications_auth() { fn public_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public") - .filter(&public()) + .filter(&Request::public()) .expect("in test"); assert_eq!(value.0, "public".to_string()); @@ -119,7 +119,7 @@ fn public_timeline() { fn public_media_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public?only_media=true") - .filter(&public_media()) + .filter(&Request::public_media()) .expect("in test"); assert_eq!(value.0, "public:media".to_string()); @@ -127,7 +127,7 @@ fn public_media_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public?only_media=1") - .filter(&public_media()) + .filter(&Request::public_media()) .expect("in test"); assert_eq!(value.0, "public:media".to_string()); @@ -138,7 +138,7 @@ fn public_media_timeline() { fn public_local_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public/local") - .filter(&public_local()) + .filter(&Request::public_local()) .expect("in test"); assert_eq!(value.0, "public:local".to_string()); @@ -149,7 +149,7 @@ fn public_local_timeline() { fn public_local_media_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public/local?only_media=true") - .filter(&public_local_media()) + 
.filter(&Request::public_local_media()) .expect("in test"); assert_eq!(value.0, "public:local:media".to_string()); @@ -157,7 +157,7 @@ fn public_local_media_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public/local?only_media=1") - .filter(&public_local_media()) + .filter(&Request::public_local_media()) .expect("in test"); assert_eq!(value.0, "public:local:media".to_string()); @@ -170,12 +170,12 @@ fn direct_timeline_unauthorized() { .path(&format!( "/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN", )) - .filter(&direct()); + .filter(&Request::direct()); assert!(invalid_access_token(value)); let value = warp::test::request() .path(&format!("/api/v1/streaming/direct",)) - .filter(&direct()); + .filter(&Request::direct()); assert!(no_access_token(value)); } @@ -191,7 +191,7 @@ fn direct_timeline_auth() { "/api/v1/streaming/direct?access_token={}", access_token )) - .filter(&direct()) + .filter(&Request::direct()) .expect("in test"); let expected_user = @@ -204,7 +204,7 @@ fn direct_timeline_auth() { let (actual_timeline, actual_user) = warp::test::request() .path("/api/v1/streaming/direct") .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&direct()) + .filter(&Request::direct()) .expect("in test"); let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test"); @@ -217,7 +217,7 @@ fn direct_timeline_auth() { fn hashtag_timeline() { let value = warp::test::request() .path("/api/v1/streaming/hashtag?tag=a") - .filter(&hashtag()) + .filter(&Request::hashtag()) .expect("in test"); assert_eq!(value.0, "hashtag:a".to_string()); @@ -228,7 +228,7 @@ fn hashtag_timeline() { fn hashtag_timeline_local() { let value = warp::test::request() .path("/api/v1/streaming/hashtag/local?tag=a") - .filter(&hashtag_local()) + .filter(&Request::hashtag_local()) .expect("in test"); assert_eq!(value.0, "hashtag:a:local".to_string()); @@ -248,7 +248,7 @@ fn list_timeline_auth() { 
"/api/v1/streaming/list?access_token={}&list={}", access_token, list_id, )) - .filter(&list()) + .filter(&Request::list()) .expect("in test"); let expected_user = @@ -261,7 +261,7 @@ fn list_timeline_auth() { let (actual_timeline, actual_user) = warp::test::request() .path("/api/v1/streaming/list?list=1") .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&list()) + .filter(&Request::list()) .expect("in test"); let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test"); @@ -276,12 +276,12 @@ fn list_timeline_unauthorized() { .path(&format!( "/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1", )) - .filter(&list()); + .filter(&Request::list()); assert!(invalid_access_token(value)); let value = warp::test::request() .path(&format!("/api/v1/streaming/list?list=1",)) - .filter(&list()); + .filter(&Request::list()); assert!(no_access_token(value)); } From 866f3ee34d2de06269894901200daac7547b5f74 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Mon, 8 Jul 2019 15:21:02 -0400 Subject: [PATCH 4/4] Update documentation and restructure code --- Cargo.lock | 1 + Cargo.toml | 1 + README.md | 52 +++++++++--- src/config.rs | 74 +++++++++++------ src/lib.rs | 47 ++++++----- src/main.rs | 3 +- src/parse_client_request/mod.rs | 1 + src/parse_client_request/ws.rs | 2 +- src/redis_to_client_stream/client_agent.rs | 5 +- src/redis_to_client_stream/mod.rs | 19 ++--- src/redis_to_client_stream/receiver.rs | 92 +++++++++++----------- src/redis_to_client_stream/redis_cmd.rs | 2 + 12 files changed, 182 insertions(+), 117 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0c11ccf..6720c55 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -819,6 +819,7 @@ version = "0.1.0" dependencies = [ "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", "futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "log 
0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", "postgres 0.15.2 (registry+https://github.com/rust-lang/crates.io-index)", "pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index 3bbb07f..a3514cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ pretty_env_logger = "0.3.0" postgres = "0.15.2" uuid = { version = "0.7", features = ["v4"] } dotenv = "0.14.0" +lazy_static = "1.3.0" [features] default = [ "production" ] diff --git a/README.md b/README.md index 59c33ff..fbdb9f7 100644 --- a/README.md +++ b/README.md @@ -2,31 +2,63 @@ A WIP blazingly fast drop-in replacement for the Mastodon streaming api server. ## Current status -The streaming server is very much a work in progress. It is currently missing essential features including support for SSL, CORS, and separate development/production environments. However, it has reached the point where it is usable/testable in a localhost development environment and I would greatly appreciate any testing, bug reports, or other feedback you could provide. +The streaming server is currently a work in progress. However, it is now testable and, if +configured properly, would theoretically be usable in production—though production use is +not advisable until we have completed further testing. I would greatly appreciate any testing, +bug reports, or other feedback you could provide. ## Installation -Installing the WIP version requires the Rust toolchain (the released version will be available as a pre-compiled binary). To install, clone this repository and run `cargo build` (to build the server) `cargo run` (to both build and run the server), or `cargo build --release` (to build the server with release optimizations). +Installing the WIP version requires the Rust toolchain (the released version will be available +as a pre-compiled binary). 
To install, clone this repository and run `cargo build` (to build +the server) `cargo run` (to both build and run the server), or `cargo build --release` (to +build the server with release optimizations). ## Connection to Mastodon -The streaming server expects to connect to a running development version of Mastodon built off of the `master` branch. Specifically, it needs to connect to both the Postgres database (to authenticate users) and to the Redis database. You should run Mastodon in whatever way you normally do and configure the streaming server to connect to the appropriate databases. +The streaming server expects to connect to a running development version of Mastodon built off of +the `master` branch. Specifically, it needs to connect to both the Postgres database (to +authenticate users) and to the Redis database. You should run Mastodon in whatever way you +normally do and configure the streaming server to connect to the appropriate databases. ## Configuring -You may edit the (currently limited) configuration variables in the `.env` file. Note that, by default, this server is configured to run on port 4000. This allows for easy testing with the development version of Mastodon (which, by default, is configured to communicate with a streaming server running on `localhost:4000`). However, it also conflicts with the current/Node.js version of Mastodon's streaming server, which runs on the same port. Thus, to test this server, you should disable the other streaming server or move it to a non-conflicting port. +You may edit the configuration variables in the `config.rs` module. You can also overwrite the +default config variables in the `.env` file. Note that, by default, this server is configured +to run on port 4000. This allows for easy testing with the development version of Mastodon +(which, by default, is configured to communicate with a streaming server running on +`localhost:4000`). 
However, it also conflicts with the current/Node.js version of Mastodon's +streaming server, which runs on the same port. Thus, to test this server, you should disable +the Node streaming server or move it to a non-conflicting port. ## Documentation -Build documentation with `cargo doc --open`, which will build the Markdown docs and open them in your browser. Please consult those docs for a description of the code structure/organization. +Build documentation with `cargo doc --open`, which will build the Markdown docs and open them +in your browser. Please consult those docs for a detailed description of the code +structure/organization. The documentation also contains additional notes about data flow and +options for configuration. ## Running -As noted above, you can run the server with `cargo run`. Alternatively, if you built the sever using `cargo build` or `cargo build --release`, you can run the executable produced in the `target/build/debug` folder or the `target/build/release` folder. +As noted above, you can run the server with `cargo run`. Alternatively, if you built the sever +using `cargo build` or `cargo build --release`, you can run the executable produced in the +`target/build/debug` folder or the `target/build/release` folder. ## Unit and (limited) integration tests -You can run basic unit test of the public Server Sent Event endpoints with `cargo test`. You can run integration tests of the authenticated SSE endpoints (which require a Postgres connection) with `cargo test -- --ignored`. +You can run basic unit test of the public Server Sent Event endpoints with `cargo test`. You can +run integration tests of the authenticated SSE endpoints (which require a Postgres connection) +with `cargo test -- --ignored`. ## Manual testing -Once the streaming server is running, you can also test it manually. You can test it using a browser connected to the relevant Mastodon development server. 
Or you can test the SSE endpoints with `curl`, PostMan, or any other HTTP client. Similarly, you can test the WebSocket endpoints with `websocat` or any other WebSocket client. +Once the streaming server is running, you can also test it manually. You can test it using a +browser connected to the relevant Mastodon development server. Or you can test the SSE endpoints +with `curl`, PostMan, or any other HTTP client. Similarly, you can test the WebSocket endpoints +with `websocat` or any other WebSocket client. ## Memory/CPU usage -Note that memory usage is higher when running the development version of the streaming server (the one generated with `cargo run` or `cargo build`). If you are interested in measuring RAM or CPU usage, you should likely run `cargo build --release` and test the release version of the executable. +Note that memory usage is higher when running the development version of the streaming server (the +one generated with `cargo run` or `cargo build`). If you are interested in measuring RAM or CPU +usage, you should likely run `cargo build --release` and test the release version of the executable. ## Load testing -I have not yet found a good way to test the streaming server under load. I have experimented with using `artillery` or other load-testing utilities. However, every utility I am familiar with or have found is built around either HTTP requests or WebSocket connections in which the client sends messages. I have not found a good solution to test receiving SSEs or WebSocket connections where the client does not transmit data after establishing the connection. If you are aware of a good way to do load testing, please let me know. +I have not yet found a good way to test the streaming server under load. I have experimented with +using `artillery` or other load-testing utilities. However, every utility I am familiar with or +have found is built around either HTTP requests or WebSocket connections in which the client sends +messages. 
I have not found a good solution to test receiving SSEs or WebSocket connections where +the client does not transmit data after establishing the connection. If you are aware of a good +way to do load testing, please let me know. diff --git a/src/config.rs b/src/config.rs index 19e011a..c7a1aa8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,5 +1,8 @@ -//! Configuration settings and custom errors for servers and databases +//! Configuration defaults. All settings with the prefix of `DEFAULT_` can be overridden +//! by an environmental variable of the same name without that prefix (either by setting +//! the variable at runtime or in the `.env` file) use dotenv::dotenv; +use lazy_static::lazy_static; use log::warn; use serde_derive::Serialize; use std::{env, net, time}; @@ -10,10 +13,46 @@ const DEFAULT_POSTGRES_ADDR: &str = "postgres://@localhost/mastodon_development" const DEFAULT_REDIS_ADDR: &str = "127.0.0.1:6379"; const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000"; -/// The frequency with which the StreamAgent will poll for updates to send via SSE -pub const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100; -pub const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100; -pub const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100; +const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100; +const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100; +const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100; + +lazy_static! { + static ref POSTGRES_ADDR: String = env::var("POSTGRESS_ADDR").unwrap_or_else(|_| { + let mut postgres_addr = DEFAULT_POSTGRES_ADDR.to_string(); + postgres_addr.insert_str(11, + &env::var("USER").unwrap_or_else(|_| { + warn!("No USER env variable set. 
Connecting to Postgress with default `postgres` user"); + "postgres".to_string() + }).as_str() + ); + postgres_addr + }); + + static ref REDIS_ADDR: String = env::var("REDIS_ADDR").unwrap_or_else(|_| DEFAULT_REDIS_ADDR.to_owned()); + + pub static ref SERVER_ADDR: net::SocketAddr = env::var("SERVER_ADDR") + .unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned()) + .parse() + .expect("static string"); + + /// Interval, in ms, at which the `ClientAgent` polls the `Receiver` for updates to send via SSE. + pub static ref SSE_UPDATE_INTERVAL: u64 = env::var("SSE_UPDATE_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(DEFAULT_SSE_UPDATE_INTERVAL); + /// Interval, in ms, at which the `ClientAgent` polls the `Receiver` for updates to send via WS. + pub static ref WS_UPDATE_INTERVAL: u64 = env::var("WS_UPDATE_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(DEFAULT_WS_UPDATE_INTERVAL); + /// Interval, in ms, at which the `Receiver` polls Redis. + /// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver` + /// (on the order of 10ms rather than 50μs). Thus, changing this setting + /// would be a good place to start for performance improvements at the cost + /// of delaying all updates. + pub static ref REDIS_POLL_INTERVAL: u64 = env::var("REDIS_POLL_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(DEFAULT_REDIS_POLL_INTERVAL); +} /// Configure CORS for the API server pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors { @@ -31,42 +70,25 @@ pub fn logging_and_env() { /// Configure Postgres and return a connection pub fn postgres() -> postgres::Connection { - let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or_else(|_| { - let mut postgres_addr = DEFAULT_POSTGRES_ADDR.to_string(); - postgres_addr.insert_str(11, - &env::var("USER").unwrap_or_else(|_| { - warn!("No USER env variable set. 
Connecting to Postgress with default `postgres` user"); - "postgres".to_string() - }).as_str() - ); - postgres_addr - }); - postgres::Connection::connect(postgres_addr, postgres::TlsMode::None) + postgres::Connection::connect(POSTGRES_ADDR.to_string(), postgres::TlsMode::None) .expect("Can connect to local Postgres") } /// Configure Redis pub fn redis_addr() -> (net::TcpStream, net::TcpStream) { - let redis_addr = env::var("REDIS_ADDR").unwrap_or_else(|_| DEFAULT_REDIS_ADDR.to_owned()); - let pubsub_connection = net::TcpStream::connect(&redis_addr).expect("Can connect to Redis"); + let pubsub_connection = + net::TcpStream::connect(&REDIS_ADDR.to_string()).expect("Can connect to Redis"); pubsub_connection .set_read_timeout(Some(time::Duration::from_millis(10))) .expect("Can set read timeout for Redis connection"); let secondary_redis_connection = - net::TcpStream::connect(&redis_addr).expect("Can connect to Redis"); + net::TcpStream::connect(&REDIS_ADDR.to_string()).expect("Can connect to Redis"); secondary_redis_connection .set_read_timeout(Some(time::Duration::from_millis(10))) .expect("Can set read timeout for Redis connection"); (pubsub_connection, secondary_redis_connection) } -pub fn socket_address() -> net::SocketAddr { - env::var("SERVER_ADDR") - .unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned()) - .parse() - .expect("static string") -} - #[derive(Serialize)] pub struct ErrorMessage { error: String, diff --git a/src/lib.rs b/src/lib.rs index 4d129eb..2c4ec05 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,26 +6,37 @@ //! connect to the server with the API described [in Mastodon's public API //! documentation](https://docs.joinmastodon.org/api/streaming/). //! -//! # Notes on data flow -//! * **Client Request → Warp**: -//! Warp filters for valid requests and parses request data. Based on that data, it generates a `User` -//! representing the client that made the request with data from the client's request and from -//! Postgres. 
The `User` is authenticated, if appropriate. Warp //! repeatedly polls the -//! ClientAgent for information relevant to the User. +//! # Data Flow +//! * **Parsing the client request** +//! When the client request first comes in, it is parsed based on the endpoint it targets (for +//! server sent events), its query parameters, and its headers (for WebSocket). Based on this +//! data, we authenticate the user, retrieve relevant user data from Postgres, and determine the +//! timeline targeted by the request. Successfully parsing the client request results in generating +//! a `User` and target `timeline` for the request. If any requests are invalid/not authorized, we +//! reject them in this stage. +//! * **Streaming update from Redis to the client**: +//! After the user request is parsed, we pass the `User` and `timeline` data on to the +//! `ClientAgent`. The `ClientAgent` is responsible for communicating the user's request to the +//! `Receiver`, polling the `Receiver` for any updates, and then for forwarding those updates on to the +//! client. The `Receiver`, in turn, is responsible for managing the Redis subscriptions, +//! periodically polling Redis, and sorting the replies from Redis into queues for when it is polled +//! by the `ClientAgent`. //! -//! * **Warp → ClientAgent**: -//! A new `ClientAgent` is created for each request. The `ClientAgent` exists to manage concurrent -//! access to the (single) `Receiver`, which it can access behind an `Arc`. The `ClientAgent` -//! polls the `Receiver` for any updates relevant to the current client. If there are updates, the -//! `ClientAgent` filters them with the client's filters and passes any matching updates up to Warp. -//! The `ClientAgent` is also responsible for sending `subscribe` commands to Redis (via the -//! `Receiver`) when necessary. +//! # Concurrency +//! The `Receiver` is created when the server is first initialized, and there is only one +//! `Receiver`. 
Thus, the `Receiver` is a potential bottleneck. On the other hand, each +//! client request results in a new green thread, which spawns its own `ClientAgent`. Thus, +//! there will be many `ClientAgent`s polling a single `Receiver`. Accordingly, it is very +//! important that polling the `Receiver` remain as fast as possible. It is less important +//! that the `Receiver`'s poll of Redis be fast, since there will only ever be one +//! `Receiver`. +//! +//! # Configuration +//! By default, the server uses config values from the `config.rs` module; these values can be +//! overwritten with environmental variables or in the `.env` file. The most important settings +//! for performance control the frequency with which the `ClientAgent` polls the `Receiver` and +//! the frequency with which the `Receiver` polls Redis. //! -//! * **ClientAgent → Receiver**: -//! The Receiver receives data from Redis and stores it in a series of queues (one for each -//! ClientAgent). When (asynchronously) polled by the ClientAgent, it sends back the messages -//! relevant to that ClientAgent and removes them from the queue. - pub mod config; pub mod parse_client_request; pub mod redis_to_client_stream; diff --git a/src/main.rs b/src/main.rs index ad8ee66..c989738 100644 --- a/src/main.rs +++ b/src/main.rs @@ -110,7 +110,6 @@ fn main() { .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); let cors = config::cross_origin_resource_sharing(); - let address = config::socket_address(); - warp::serve(websocket_routes.or(sse_routes).with(cors)).run(address); + warp::serve(websocket_routes.or(sse_routes).with(cors)).run(*config::SERVER_ADDR); } diff --git a/src/parse_client_request/mod.rs b/src/parse_client_request/mod.rs index d870823..d6a0ef7 100644 --- a/src/parse_client_request/mod.rs +++ b/src/parse_client_request/mod.rs @@ -1,3 +1,4 @@ +//! 
Parse the client request and return a 'timeline' and a (maybe authenticated) `User` pub mod query; pub mod sse; pub mod user; diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs index 77ab1d6..77c7c68 100644 --- a/src/parse_client_request/ws.rs +++ b/src/parse_client_request/ws.rs @@ -1,4 +1,4 @@ -//! WebSocket functionality +//! Filters for the WebSocket endpoint use super::{ query, user::{Scope, User}, diff --git a/src/redis_to_client_stream/client_agent.rs b/src/redis_to_client_stream/client_agent.rs index 74ae4e8..1f5a972 100644 --- a/src/redis_to_client_stream/client_agent.rs +++ b/src/redis_to_client_stream/client_agent.rs @@ -4,7 +4,7 @@ //! The `ClientAgent`'s interface is very simple. All you can do with it is: //! * Create a totally new `ClientAgent` with no shared data; //! * Clone an existing `ClientAgent`, sharing the `Receiver`; -//! * to manage an new timeline/user pair; or +//! * Manage a new timeline/user pair; or //! * Poll an existing `ClientAgent` to see if there are any new messages //! for clients //! @@ -18,6 +18,7 @@ use super::receiver::Receiver; use crate::parse_client_request::user::User; use futures::{Async, Poll}; +use log; use serde_json::{json, Value}; use std::{sync, time}; use tokio::io::Error; @@ -94,7 +94,7 @@ impl futures::stream::Stream for ClientAgent { }; if start_time.elapsed() > time::Duration::from_millis(20) { - println!("Polling took: {:?}", start_time.elapsed()); + log::warn!("Polling took: {:?}", start_time.elapsed()); } match result { Ok(Async::Ready(Some(value))) => { diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs index b8ae146..b74c372 100644 --- a/src/redis_to_client_stream/mod.rs +++ b/src/redis_to_client_stream/mod.rs @@ -1,3 +1,4 @@ +//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis. 
pub mod client_agent; pub mod receiver; pub mod redis_cmd; @@ -5,18 +6,17 @@ pub mod redis_cmd; use crate::config; pub use client_agent::ClientAgent; use futures::{future::Future, stream::Stream, Async}; -use std::{env, time}; +use log; +use std::time; +/// Send a stream of replies to a Server Sent Events client. pub fn send_updates_to_sse( mut client_agent: ClientAgent, connection: warp::sse::Sse, ) -> impl warp::reply::Reply { - let sse_update_interval = env::var("SSE_UPDATE_INTERVAL") - .map(|s| s.parse().expect("Valid config")) - .unwrap_or(config::DEFAULT_SSE_UPDATE_INTERVAL); let event_stream = tokio::timer::Interval::new( time::Instant::now(), - time::Duration::from_millis(sse_update_interval), + time::Duration::from_millis(*config::SSE_UPDATE_INTERVAL), ) .filter_map(move |_| match client_agent.poll() { Ok(Async::Ready(Some(json_value))) => Some(( @@ -29,7 +29,7 @@ pub fn send_updates_to_sse( connection.reply(warp::sse::keep(event_stream, None)) } -/// Send a stream of replies to a WebSocket client +/// Send a stream of replies to a WebSocket client. 
pub fn send_updates_to_ws( socket: warp::ws::WebSocket, mut stream: ClientAgent, @@ -49,12 +49,9 @@ pub fn send_updates_to_ws( ); // For as long as the client is still connected, yeild a new event every 100 ms - let ws_update_interval = env::var("WS_UPDATE_INTERVAL") - .map(|s| s.parse().expect("Valid config")) - .unwrap_or(config::DEFAULT_WS_UPDATE_INTERVAL); let event_stream = tokio::timer::Interval::new( time::Instant::now(), - time::Duration::from_millis(ws_update_interval), + time::Duration::from_millis(*config::WS_UPDATE_INTERVAL), ) .take_while(move |_| match ws_rx.poll() { Ok(Async::Ready(None)) => futures::future::ok(false), @@ -71,5 +68,5 @@ pub fn send_updates_to_ws( Ok(()) }) .then(|msg| msg) - .map_err(|e| println!("{}", e)) + .map_err(|e| log::error!("{}", e)) } diff --git a/src/redis_to_client_stream/receiver.rs b/src/redis_to_client_stream/receiver.rs index 88e57f6..a7e0a12 100644 --- a/src/redis_to_client_stream/receiver.rs +++ b/src/redis_to_client_stream/receiver.rs @@ -7,11 +7,11 @@ use futures::{Async, Poll}; use log::info; use regex::Regex; use serde_json::Value; -use std::{collections, env, io::Read, io::Write, net, time}; +use std::{collections, io::Read, io::Write, net, time}; use tokio::io::{AsyncRead, Error}; use uuid::Uuid; -/// The item that streams from Redis and is polled by the `StreamManager` +/// The item that streams from Redis and is polled by the `ClientAgent` #[derive(Debug)] pub struct Receiver { pubsub_connection: net::TcpStream, @@ -53,7 +53,7 @@ impl Receiver { self.subscribe_or_unsubscribe_as_needed(timeline); } - /// Set the `Receiver`'s manager_id and target_timeline fields to the approprate + /// Set the `Receiver`'s manager_id and target_timeline fields to the appropriate /// value to be polled by the current `StreamManager`. 
pub fn configure_for_polling(&mut self, manager_id: Uuid, timeline: &str) { self.manager_id = manager_id; @@ -102,43 +102,14 @@ impl Receiver { }); // If no clients, unsubscribe from the channel if *count_of_subscribed_clients <= 0 { - info!("Sent unsubscribe command"); pubsub_cmd!("unsubscribe", self, change.timeline.clone()); } if need_to_subscribe { - info!("Sent subscribe command"); pubsub_cmd!("subscribe", self, change.timeline.clone()); } } } - /// Polls Redis for any new messages and adds them to the `MsgQueue` for - /// the appropriate `ClientAgent`. - fn poll_redis(&mut self) { - let mut buffer = vec![0u8; 3000]; - // Add any incoming messages to the back of the relevant `msg_queues` - // NOTE: This could be more/other than the `msg_queue` currently being polled - let mut async_stream = AsyncReadableStream::new(&mut self.pubsub_connection); - if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer).unwrap() { - let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]); - // capture everything between `{` and `}` as potential JSON - let json_regex = Regex::new(r"(?P\{.*\})").expect("Hard-coded"); - // capture the timeline so we know which queues to add it to - let timeline_regex = Regex::new(r"timeline:(?P.*?)\r").expect("Hard-codded"); - if let Some(result) = json_regex.captures(raw_redis_response) { - let timeline = - timeline_regex.captures(raw_redis_response).unwrap()["timeline"].to_string(); - - let msg: Value = serde_json::from_str(&result["json"].to_string().clone()).unwrap(); - for msg_queue in self.msg_queues.values_mut() { - if msg_queue.redis_channel == timeline { - msg_queue.messages.push_back(msg.clone()); - } - } - } - } - } - fn log_number_of_msgs_in_queue(&self) { let messages_waiting = self .msg_queues @@ -153,6 +124,10 @@ impl Receiver { _ => log::info!("{} messages waiting in the queue", messages_waiting), } } + + fn get_target_msg_queue(&mut self) -> collections::hash_map::Entry { + 
self.msg_queues.entry(self.manager_id) + } } impl Default for Receiver { @@ -175,24 +150,20 @@ impl futures::stream::Stream for Receiver { fn poll(&mut self) -> Poll, Self::Error> { let timeline = self.timeline.clone(); - let redis_poll_interval = env::var("REDIS_POLL_INTERVAL") - .map(|s| s.parse().expect("Valid config")) - .unwrap_or(config::DEFAULT_REDIS_POLL_INTERVAL); - - if self.redis_polled_at.elapsed() > time::Duration::from_millis(redis_poll_interval) { - self.poll_redis(); + if self.redis_polled_at.elapsed() + > time::Duration::from_millis(*config::REDIS_POLL_INTERVAL) + { + AsyncReadableStream::poll_redis(self); self.redis_polled_at = time::Instant::now(); } // Record current time as last polled time - self.msg_queues - .entry(self.manager_id) + self.get_target_msg_queue() .and_modify(|msg_queue| msg_queue.last_polled_at = time::Instant::now()); // If the `msg_queue` being polled has any new messages, return the first (oldest) one match self - .msg_queues - .entry(self.manager_id) + .get_target_msg_queue() .or_insert_with(|| MsgQueue::new(timeline.clone())) .messages .pop_front() @@ -214,13 +185,13 @@ impl Drop for Receiver { #[derive(Debug, Clone)] struct MsgQueue { - pub messages: collections::VecDeque, - pub last_polled_at: time::Instant, - pub redis_channel: String, + messages: collections::VecDeque, + last_polled_at: time::Instant, + redis_channel: String, } impl MsgQueue { - pub fn new(redis_channel: impl std::fmt::Display) -> Self { + fn new(redis_channel: impl std::fmt::Display) -> Self { let redis_channel = redis_channel.to_string(); MsgQueue { messages: collections::VecDeque::new(), @@ -232,9 +203,36 @@ impl MsgQueue { struct AsyncReadableStream<'a>(&'a mut net::TcpStream); impl<'a> AsyncReadableStream<'a> { - pub fn new(stream: &'a mut net::TcpStream) -> Self { + fn new(stream: &'a mut net::TcpStream) -> Self { AsyncReadableStream(stream) } + /// Polls Redis for any new messages and adds them to the `MsgQueue` for + /// the appropriate 
`ClientAgent`. + fn poll_redis(receiver: &mut Receiver) { + let mut buffer = vec![0u8; 3000]; + // Add any incoming messages to the back of the relevant `msg_queues` + // NOTE: This could be more/other than the `msg_queue` currently being polled + + let mut async_stream = AsyncReadableStream::new(&mut receiver.pubsub_connection); + if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer).unwrap() { + let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]); + // capture everything between `{` and `}` as potential JSON + let json_regex = Regex::new(r"(?P\{.*\})").expect("Hard-coded"); + // capture the timeline so we know which queues to add it to + let timeline_regex = Regex::new(r"timeline:(?P.*?)\r").expect("Hard-codded"); + if let Some(result) = json_regex.captures(raw_redis_response) { + let timeline = + timeline_regex.captures(raw_redis_response).unwrap()["timeline"].to_string(); + + let msg: Value = serde_json::from_str(&result["json"].to_string().clone()).unwrap(); + for msg_queue in receiver.msg_queues.values_mut() { + if msg_queue.redis_channel == timeline { + msg_queue.messages.push_back(msg.clone()); + } + } + } + } + } } impl<'a> Read for AsyncReadableStream<'a> { diff --git a/src/redis_to_client_stream/redis_cmd.rs b/src/redis_to_client_stream/redis_cmd.rs index b032887..7591ac8 100644 --- a/src/redis_to_client_stream/redis_cmd.rs +++ b/src/redis_to_client_stream/redis_cmd.rs @@ -1,4 +1,5 @@ //! Send raw TCP commands to the Redis server +use log::info; use std::fmt::Display; /// Send a subscribe or unsubscribe to the Redis PubSub channel @@ -25,6 +26,7 @@ macro_rules! pubsub_cmd { pub fn pubsub(command: impl Display, timeline: impl Display) -> Vec { let arg = format!("timeline:{}", timeline); let command = command.to_string(); + info!("Sent {} command", &command); format!( "*2\r\n${cmd_length}\r\n{cmd}\r\n${arg_length}\r\n{arg}\r\n", cmd_length = command.len(),