From d6ae45b292444ac6cb6bc62c7d873e13738b7b7e Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Mon, 8 Jul 2019 07:31:42 -0400 Subject: [PATCH] Code reorganization --- Cargo.lock | 43 +++--- src/.env | 1 + src/config.rs | 76 ++++++++-- src/error.rs | 36 ----- src/lib.rs | 29 ++-- src/main.rs | 91 ++++++------ src/parse_client_request/mod.rs | 4 + src/parse_client_request/query.rs | 45 ++++++ src/parse_client_request/sse.rs | 106 +++++++++++++ .../user/mod.rs} | 3 +- .../user}/postgres.rs | 0 src/parse_client_request/ws.rs | 43 ++++++ src/query.rs | 66 --------- .../client_agent.rs} | 114 +++++++------- src/redis_to_client_stream/mod.rs | 75 ++++++++++ src/{ => redis_to_client_stream}/receiver.rs | 140 ++++++++++-------- src/{ => redis_to_client_stream}/redis_cmd.rs | 0 src/timeline.rs | 102 ------------- src/ws.rs | 86 ----------- tests/test.rs | 52 +++---- 20 files changed, 590 insertions(+), 522 deletions(-) delete mode 100644 src/error.rs create mode 100644 src/parse_client_request/mod.rs create mode 100644 src/parse_client_request/query.rs create mode 100644 src/parse_client_request/sse.rs rename src/{user.rs => parse_client_request/user/mod.rs} (98%) rename src/{ => parse_client_request/user}/postgres.rs (100%) create mode 100644 src/parse_client_request/ws.rs delete mode 100644 src/query.rs rename src/{stream_manager.rs => redis_to_client_stream/client_agent.rs} (51%) create mode 100644 src/redis_to_client_stream/mod.rs rename src/{ => redis_to_client_stream}/receiver.rs (71%) rename src/{ => redis_to_client_stream}/redis_cmd.rs (100%) delete mode 100644 src/timeline.rs delete mode 100644 src/ws.rs diff --git a/Cargo.lock b/Cargo.lock index 767381e..0c11ccf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,7 +27,7 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "libc 0.2.54 (registry+https://github.com/rust-lang/crates.io-index)", - "termion 1.5.2 (registry+https://github.com/rust-lang/crates.io-index)", + "termion 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)", "winapi 0.3.7 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -144,11 +144,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "chrono" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "num-integer 0.1.39 (registry+https://github.com/rust-lang/crates.io-index)", - "num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.54 (registry+https://github.com/rust-lang/crates.io-index)", + "num-integer 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)", "time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -246,14 +247,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "env_logger" -version = "0.6.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", "humantime 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)", - "termcolor 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", + "termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -637,16 +638,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "num-integer" -version = "0.1.39" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", + "autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] name = "num-traits" -version = "0.2.6" +version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)", +] [[package]] name = "num_cpus" @@ -782,8 +787,8 @@ name = "pretty_env_logger" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ - "chrono 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", - "env_logger 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", + "chrono 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)", + "env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -1168,7 +1173,7 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "wincolor 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1176,7 +1181,7 @@ dependencies = [ [[package]] name = "termion" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "libc 0.2.54 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1597,7 +1602,7 @@ dependencies = [ "checksum bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c" "checksum cc 1.0.36 (registry+https://github.com/rust-lang/crates.io-index)" = "a0c56216487bb80eec9c4516337b2588a4f2a2290d72a1416d930e4dcdb0c90d" "checksum cfg-if 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "11d43355396e872eefb45ce6342e4374ed7bc2b3a502d1b28e36d6e23c05d1f4" -"checksum chrono 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "45912881121cb26fad7c38c17ba7daa18764771836b34fab7d3fbd93ed633878" +"checksum chrono 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)" = "77d81f58b7301084de3b958691458a53c3f7e0b1d702f77e550b6a88e3a88abe" "checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f" "checksum constant_time_eq 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8ff012e225ce166d4422e0e78419d901719760f62ae2b7969ca6b564d1b54a9e" "checksum crossbeam-deque 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b18cd2e169ad86297e6bc0ad9aa679aee9daa4f19e8163860faf7c164e4f5a71" @@ -1609,7 +1614,7 @@ dependencies = [ "checksum digest 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05f47366984d3ad862010e22c7ce81a7dbcaebbdfb37241a620f8b6596ee135c" "checksum dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7bdb5b956a911106b6b479cdc6bc1364d359a32299f17b49994f5327132e18d9" "checksum dtoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "ea57b42383d091c85abcc2706240b94ab2a8fa1fc81c10ff23c4de06e2a90b5e" -"checksum env_logger 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b61fa891024a945da30a9581546e8cfaf5602c7b3f4c137a2805cf388f92075a" +"checksum env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)" = "aafcde04e90a5226a6443b7aabdb016ba2f8307c847d524724bd9b346dd1a2d3" "checksum failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "795bd83d3abeb9220f257e597aa0080a508b27533824adf336529648f6abf7e2" "checksum failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "ea1063915fd7ef4309e222a5a07cf9c319fb9c7836b1f89b85458672dbb127e1" "checksum fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" @@ -1655,8 +1660,8 @@ dependencies = [ "checksum miow 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "8c1f2f3b1cf331de6896aabf6e9d55dca90356cc9960cca7eaaf408a355ae919" "checksum net2 0.2.33 (registry+https://github.com/rust-lang/crates.io-index)" = "42550d9fb7b6684a6d404d9fa7250c2eb2646df731d1c06afc06dcee9e1bcf88" "checksum nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "2f9667ddcc6cc8a43afc9b7917599d7216aa09c463919ea32c59ed6cac8bc945" -"checksum num-integer 0.1.39 (registry+https://github.com/rust-lang/crates.io-index)" = "e83d528d2677f0518c570baf2b7abdcf0cd2d248860b68507bdcb3e91d4c0cea" -"checksum num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3a5d7cc97d6d30d8b9bc8fa19bf45349ffe46241e8816f50f62f6d6aaabee1" +"checksum num-integer 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "8b8af8caa3184078cd419b430ff93684cb13937970fcb7639f728992f33ce674" +"checksum num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "d9c79c952a4a139f44a0fe205c4ee66ce239c0e6ce72cd935f5f7e2f717549dd" "checksum num_cpus 1.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1a23f0ed30a54abaa0c7e83b1d2d87ada7c3c23078d1d87815af3e3b6385fbba" "checksum numtoa 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b8f8bdf33df195859076e54ab11ee78a1b208382d3a26ec40d142ffc1ecc49ef" "checksum opaque-debug 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "93f5bb2e8e8dec81642920ccff6b61f1eb94fa3020c5a325c9851ff604152409" @@ -1716,8 +1721,8 @@ dependencies = [ "checksum stringprep 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1" "checksum syn 0.15.34 (registry+https://github.com/rust-lang/crates.io-index)" = "a1393e4a97a19c01e900df2aec855a29f71cf02c402e2f443b8d2747c25c5dbe" "checksum synstructure 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)" = "73687139bf99285483c96ac0add482c3776528beac1d97d444f6e91f203a2015" -"checksum termcolor 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4096add70612622289f2fdcdbd5086dc81c1e2675e6ae58d6c4f62a16c6d7f2f" -"checksum termion 1.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "dde0593aeb8d47accea5392b39350015b5eccb12c0d98044d856983d89548dea" +"checksum termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96d6098003bde162e4277c70665bd87c326f5a0c3f3fbfb285787fa482d54e6e" +"checksum termion 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a8fb22f7cde82c8220e5aeacb3258ed7ce996142c77cba193f203515e26c330" "checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b" "checksum time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "db8dcfca086c1143c9270ac42a2bbd8a7ee477b78ac8e45b19abfb0cbede4b6f" "checksum tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)" = "cec6c34409089be085de9403ba2010b80e36938c9ca992c4f67f407bb13db0b1" diff --git a/src/.env b/src/.env index 9e3e3c8..94e0c32 100644 --- a/src/.env +++ b/src/.env @@ -3,3 +3,4 @@ #SERVER_ADDR= #REDIS_ADDR= #POSTGRES_ADDR= +CORS_ALLOWED_METHODS="GET OPTIONS" diff --git a/src/config.rs b/src/config.rs index fad70a4..19e011a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,14 +1,26 @@ -//! Configuration settings for servers and databases +//! Configuration settings and custom errors for servers and databases use dotenv::dotenv; use log::warn; +use serde_derive::Serialize; use std::{env, net, time}; +const CORS_ALLOWED_METHODS: [&str; 2] = ["GET", "OPTIONS"]; +const CORS_ALLOWED_HEADERS: [&str; 3] = ["Authorization", "Accept", "Cache-Control"]; +const DEFAULT_POSTGRES_ADDR: &str = "postgres://@localhost/mastodon_development"; +const DEFAULT_REDIS_ADDR: &str = "127.0.0.1:6379"; +const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000"; + +/// The frequency with which the StreamAgent will poll for updates to send via SSE +pub const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100; +pub const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100; +pub const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100; + /// Configure CORS for the API server pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors { warp::cors() .allow_any_origin() - .allow_methods(vec!["GET", "OPTIONS"]) - .allow_headers(vec!["Authorization", "Accept", "Cache-Control"]) + .allow_methods(CORS_ALLOWED_METHODS.to_vec()) + .allow_headers(CORS_ALLOWED_HEADERS.to_vec()) } /// Initialize logging and read values from `src/.env` @@ -20,20 +32,22 @@ pub fn logging_and_env() { /// Configure Postgres and return a connection pub fn postgres() -> postgres::Connection { let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or_else(|_| { - format!( - "postgres://{}@localhost/mastodon_development", - env::var("USER").unwrap_or_else(|_| { - warn!("No USER env variable set. Connecting to Postgress with default `postgres` user"); - "postgres".to_owned() - }) - ) + let mut postgres_addr = DEFAULT_POSTGRES_ADDR.to_string(); + postgres_addr.insert_str(11, + &env::var("USER").unwrap_or_else(|_| { + warn!("No USER env variable set. Connecting to Postgress with default `postgres` user"); + "postgres".to_string() + }).as_str() + ); + postgres_addr }); postgres::Connection::connect(postgres_addr, postgres::TlsMode::None) .expect("Can connect to local Postgres") } +/// Configure Redis pub fn redis_addr() -> (net::TcpStream, net::TcpStream) { - let redis_addr = env::var("REDIS_ADDR").unwrap_or_else(|_| "127.0.0.1:6379".to_string()); + let redis_addr = env::var("REDIS_ADDR").unwrap_or_else(|_| DEFAULT_REDIS_ADDR.to_owned()); let pubsub_connection = net::TcpStream::connect(&redis_addr).expect("Can connect to Redis"); pubsub_connection .set_read_timeout(Some(time::Duration::from_millis(10))) @@ -48,7 +62,45 @@ pub fn redis_addr() -> (net::TcpStream, net::TcpStream) { pub fn socket_address() -> net::SocketAddr { env::var("SERVER_ADDR") - .unwrap_or_else(|_| "127.0.0.1:4000".to_owned()) + .unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned()) .parse() .expect("static string") } + +#[derive(Serialize)] +pub struct ErrorMessage { + error: String, +} +impl ErrorMessage { + fn new(msg: impl std::fmt::Display) -> Self { + Self { + error: msg.to_string(), + } + } +} + +/// Recover from Errors by sending appropriate Warp::Rejections +pub fn handle_errors( + rejection: warp::reject::Rejection, +) -> Result { + let err_txt = match rejection.cause() { + Some(text) if text.to_string() == "Missing request header 'authorization'" => { + "Error: Missing access token".to_string() + } + Some(text) => text.to_string(), + None => "Error: Nonexistant endpoint".to_string(), + }; + let json = warp::reply::json(&ErrorMessage::new(err_txt)); + Ok(warp::reply::with_status( + json, + warp::http::StatusCode::UNAUTHORIZED, + )) +} + +pub struct CustomError {} + +impl CustomError { + pub fn unauthorized_list() -> warp::reject::Rejection { + warp::reject::custom("Error: Access to list not authorized") + } +} diff --git a/src/error.rs b/src/error.rs deleted file mode 100644 index 4fa7bcd..0000000 --- a/src/error.rs +++ /dev/null @@ -1,36 +0,0 @@ -//! Custom Errors and Warp::Rejections -use serde_derive::Serialize; - -#[derive(Serialize)] -pub struct ErrorMessage { - error: String, -} -impl ErrorMessage { - fn new(msg: impl std::fmt::Display) -> Self { - Self { - error: msg.to_string(), - } - } -} - -/// Recover from Errors by sending appropriate Warp::Rejections -pub fn handle_errors( - rejection: warp::reject::Rejection, -) -> Result { - let err_txt = match rejection.cause() { - Some(text) if text.to_string() == "Missing request header 'authorization'" => { - "Error: Missing access token".to_string() - } - Some(text) => text.to_string(), - None => "Error: Nonexistant endpoint".to_string(), - }; - let json = warp::reply::json(&ErrorMessage::new(err_txt)); - Ok(warp::reply::with_status( - json, - warp::http::StatusCode::UNAUTHORIZED, - )) -} - -pub fn unauthorized_list() -> warp::reject::Rejection { - warp::reject::custom("Error: Access to list not authorized") -} diff --git a/src/lib.rs b/src/lib.rs index a5b342f..4d129eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,28 +11,21 @@ //! Warp filters for valid requests and parses request data. Based on that data, it generates a `User` //! representing the client that made the request with data from the client's request and from //! Postgres. The `User` is authenticated, if appropriate. Warp //! repeatedly polls the -//! StreamManager for information relevant to the User. +//! ClientAgent for information relevant to the User. //! -//! * **Warp → StreamManager**: -//! A new `StreamManager` is created for each request. The `StreamManager` exists to manage concurrent -//! access to the (single) `Receiver`, which it can access behind an `Arc`. The `StreamManager` +//! * **Warp → ClientAgent**: +//! A new `ClientAgent` is created for each request. The `ClientAgent` exists to manage concurrent +//! access to the (single) `Receiver`, which it can access behind an `Arc`. The `ClientAgent` //! polls the `Receiver` for any updates relevant to the current client. If there are updates, the -//! `StreamManager` filters them with the client's filters and passes any matching updates up to Warp. -//! The `StreamManager` is also responsible for sending `subscribe` commands to Redis (via the +//! `ClientAgent` filters them with the client's filters and passes any matching updates up to Warp. +//! The `ClientAgent` is also responsible for sending `subscribe` commands to Redis (via the //! `Receiver`) when necessary. //! -//! * **StreamManager → Receiver**: +//! * **ClientAgent → Receiver**: //! The Receiver receives data from Redis and stores it in a series of queues (one for each -//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages -//! relevant to that StreamManager and removes them from the queue. +//! ClientAgent). When (asynchronously) polled by the ClientAgent, it sends back the messages +//! relevant to that ClientAgent and removes them from the queue. pub mod config; -pub mod error; -pub mod postgres; -pub mod query; -pub mod receiver; -pub mod redis_cmd; -pub mod stream_manager; -pub mod timeline; -pub mod user; -pub mod ws; +pub mod parse_client_request; +pub mod redis_to_client_stream; diff --git a/src/main.rs b/src/main.rs index 32db4d8..ad8ee66 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,83 +1,84 @@ -use futures::{stream::Stream, Async}; use ragequit::{ - any_of, config, error, - stream_manager::StreamManager, - timeline, - user::{Filter::*, User}, - ws, + any_of, config, + parse_client_request::{sse, user, ws}, + redis_to_client_stream, + redis_to_client_stream::ClientAgent, }; + use warp::{ws::Ws2, Filter as WarpFilter}; fn main() { config::logging_and_env(); - let stream_manager_sse = StreamManager::new(); - let stream_manager_ws = stream_manager_sse.clone(); + let client_agent_sse = ClientAgent::blank(); + let client_agent_ws = client_agent_sse.clone_with_shared_receiver(); // Server Sent Events + // + // For SSE, the API requires users to use different endpoints, so we first filter based on + // the endpoint. Using that endpoint determine the `timeline` the user is requesting, + // the scope for that `timeline`, and authenticate the `User` if they provided a token. let sse_routes = any_of!( // GET /api/v1/streaming/user/notification [private; notification filter] - timeline::user_notifications(), + sse::Request::user_notifications(), // GET /api/v1/streaming/user [private; language filter] - timeline::user(), + sse::Request::user(), // GET /api/v1/streaming/public/local?only_media=true [public; language filter] - timeline::public_local_media(), + sse::Request::public_local_media(), // GET /api/v1/streaming/public?only_media=true [public; language filter] - timeline::public_media(), + sse::Request::public_media(), // GET /api/v1/streaming/public/local [public; language filter] - timeline::public_local(), + sse::Request::public_local(), // GET /api/v1/streaming/public [public; language filter] - timeline::public(), + sse::Request::public(), // GET /api/v1/streaming/direct [private; *no* filter] - timeline::direct(), + sse::Request::direct(), // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] - timeline::hashtag(), + sse::Request::hashtag(), // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] - timeline::hashtag_local(), + sse::Request::hashtag_local(), // GET /api/v1/streaming/list?list=:list_id [private; no filter] - timeline::list() + sse::Request::list() ) .untuple_one() .and(warp::sse()) - .map(move |timeline: String, user: User, sse: warp::sse::Sse| { - let mut stream_manager = stream_manager_sse.manage_new_timeline(&timeline, user); - let event_stream = tokio::timer::Interval::new( - std::time::Instant::now(), - std::time::Duration::from_millis(100), - ) - .filter_map(move |_| match stream_manager.poll() { - Ok(Async::Ready(Some(json_value))) => Some(( - warp::sse::event(json_value["event"].clone().to_string()), - warp::sse::data(json_value["payload"].clone()), - )), - _ => None, - }); - sse.reply(warp::sse::keep(event_stream, None)) - }) + .map( + move |timeline: String, user: user::User, sse_connection_to_client: warp::sse::Sse| { + // Create a new ClientAgent + let mut client_agent = client_agent_sse.clone_with_shared_receiver(); + // Assign that agent to generate a stream of updates for the user/timeline pair + client_agent.init_for_user(&timeline, user); + // send the updates through the SSE connection + redis_to_client_stream::send_updates_to_sse(client_agent, sse_connection_to_client) + }, + ) .with(warp::reply::with::header("Connection", "keep-alive")) - .recover(error::handle_errors); + .recover(config::handle_errors); // WebSocket - let websocket_routes = ws::websocket_routes() - .and_then(move |mut user: User, q: ws::Query, ws: Ws2| { + // + // For WS, the API specifies a single endpoint, so we extract the User/timeline pair + // directy from the query + let websocket_routes = ws::extract_user_and_query() + .and_then(move |mut user: user::User, q: ws::Query, ws: Ws2| { + let token = user.access_token.clone(); let read_scope = user.scopes.clone(); + let timeline = match q.stream.as_ref() { // Public endpoints: tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl), tl @ "public:media" | tl @ "public:local:media" => tl.to_string(), tl @ "public" | tl @ "public:local" => tl.to_string(), // Hashtag endpoints: - // TODO: handle missing query tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag), // Private endpoints: User "user" if user.logged_in && (read_scope.all || read_scope.statuses) => { format!("{}", user.id) } "user:notification" if user.logged_in && (read_scope.all || read_scope.notify) => { - user = user.set_filter(Notification); + user = user.set_filter(user::Filter::Notification); format!("{}", user.id) } // List endpoint: - // TODO: handle missing query "list" if user.owns_list(q.list) && (read_scope.all || read_scope.lists) => { format!("list:{}", q.list) } @@ -92,11 +93,17 @@ fn main() { // Other endpoints don't exist: _ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")), }; - let token = user.access_token.clone(); - let stream_manager = stream_manager_ws.manage_new_timeline(&timeline, user); + // Create a new ClientAgent + let mut client_agent = client_agent_ws.clone_with_shared_receiver(); + // Assign that agent to generate a stream of updates for the user/timeline pair + client_agent.init_for_user(&timeline, user); + // send the updates through the WS connection (along with the User's access_token + // which is sent for security) Ok(( - ws.on_upgrade(move |socket| ws::send_replies(socket, stream_manager)), + ws.on_upgrade(move |socket| { + redis_to_client_stream::send_updates_to_ws(socket, client_agent) + }), token, )) }) diff --git a/src/parse_client_request/mod.rs b/src/parse_client_request/mod.rs new file mode 100644 index 0000000..d870823 --- /dev/null +++ b/src/parse_client_request/mod.rs @@ -0,0 +1,4 @@ +pub mod query; +pub mod sse; +pub mod user; +pub mod ws; diff --git a/src/parse_client_request/query.rs b/src/parse_client_request/query.rs new file mode 100644 index 0000000..a14d559 --- /dev/null +++ b/src/parse_client_request/query.rs @@ -0,0 +1,45 @@ +//! Validate query prarams with type checking +use serde_derive::Deserialize; +use warp::filters::BoxedFilter; +use warp::Filter as WarpFilter; + +macro_rules! query { + ($name:tt => $parameter:tt:$type:tt) => { + #[derive(Deserialize, Debug, Default)] + pub struct $name { + pub $parameter: $type, + } + impl $name { + pub fn to_filter() -> BoxedFilter<(Self,)> { + warp::query() + .or(warp::any().map(Self::default)) + .unify() + .boxed() + } + } + }; +} +query!(Media => only_media:String); +impl Media { + pub fn is_truthy(&self) -> bool { + self.only_media == "true" || self.only_media == "1" + } +} +query!(Hashtag => tag: String); +query!(List => list: i64); +query!(Auth => access_token: String); +query!(Stream => stream: String); +impl ToString for Stream { + fn to_string(&self) -> String { + format!("{:?}", self) + } +} + +pub fn optional_media_query() -> BoxedFilter<(Media,)> { + warp::query() + .or(warp::any().map(|| Media { + only_media: "false".to_owned(), + })) + .unify() + .boxed() +} diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse.rs new file mode 100644 index 0000000..fa3fa5d --- /dev/null +++ b/src/parse_client_request/sse.rs @@ -0,0 +1,106 @@ +//! Filters for all the endpoints accessible for Server Sent Event updates +use super::{ + query, + user::{Filter::*, Scope, User}, +}; +use crate::{config::CustomError, user_from_path}; +use warp::{filters::BoxedFilter, path, Filter}; + +#[allow(dead_code)] +type TimelineUser = ((String, User),); + +pub enum Request {} + +impl Request { + /// GET /api/v1/streaming/user + pub fn user() -> BoxedFilter { + user_from_path!("streaming" / "user", Scope::Private) + .map(|user: User| (user.id.to_string(), user)) + .boxed() + } + + /// GET /api/v1/streaming/user/notification + /// + /// + /// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented? + pub fn user_notifications() -> BoxedFilter { + user_from_path!("streaming" / "user" / "notification", Scope::Private) + .map(|user: User| (user.id.to_string(), user.set_filter(Notification))) + .boxed() + } + + /// GET /api/v1/streaming/public + pub fn public() -> BoxedFilter { + user_from_path!("streaming" / "public", Scope::Public) + .map(|user: User| ("public".to_owned(), user.set_filter(Language))) + .boxed() + } + + /// GET /api/v1/streaming/public?only_media=true + pub fn public_media() -> BoxedFilter { + user_from_path!("streaming" / "public", Scope::Public) + .and(warp::query()) + .map(|user: User, q: query::Media| match q.only_media.as_ref() { + "1" | "true" => ("public:media".to_owned(), user.set_filter(Language)), + _ => ("public".to_owned(), user.set_filter(Language)), + }) + .boxed() + } + + /// GET /api/v1/streaming/public/local + pub fn public_local() -> BoxedFilter { + user_from_path!("streaming" / "public" / "local", Scope::Public) + .map(|user: User| ("public:local".to_owned(), user.set_filter(Language))) + .boxed() + } + + /// GET /api/v1/streaming/public/local?only_media=true + pub fn public_local_media() -> BoxedFilter { + user_from_path!("streaming" / "public" / "local", Scope::Public) + .and(warp::query()) + .map(|user: User, q: query::Media| match q.only_media.as_ref() { + "1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)), + _ => ("public:local".to_owned(), user.set_filter(Language)), + }) + .boxed() + } + + /// GET /api/v1/streaming/direct + pub fn direct() -> BoxedFilter { + user_from_path!("streaming" / "direct", Scope::Private) + .map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter))) + .boxed() + } + + /// GET /api/v1/streaming/hashtag?tag=:hashtag + pub fn hashtag() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "hashtag") + .and(warp::query()) + .map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())) + .boxed() + } + + /// GET /api/v1/streaming/hashtag/local?tag=:hashtag + pub fn hashtag_local() -> BoxedFilter { + path!("api" / "v1" / "streaming" / "hashtag" / "local") + .and(warp::query()) + .map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())) + .boxed() + } + + /// GET /api/v1/streaming/list?list=:list_id + pub fn list() -> BoxedFilter { + user_from_path!("streaming" / "list", Scope::Private) + .and(warp::query()) + .and_then(|user: User, q: query::List| { + if user.owns_list(q.list) { + (Ok(q.list), Ok(user)) + } else { + (Err(CustomError::unauthorized_list()), Ok(user)) + } + }) + .untuple_one() + .map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter))) + .boxed() + } +} diff --git a/src/user.rs b/src/parse_client_request/user/mod.rs similarity index 98% rename from src/user.rs rename to src/parse_client_request/user/mod.rs index c8db5de..5684ccc 100644 --- a/src/user.rs +++ b/src/parse_client_request/user/mod.rs @@ -1,5 +1,6 @@ //! `User` struct and related functionality -use crate::{postgres, query}; +mod postgres; +use crate::parse_client_request::query; use log::info; use warp::Filter as WarpFilter; diff --git a/src/postgres.rs b/src/parse_client_request/user/postgres.rs similarity index 100% rename from src/postgres.rs rename to src/parse_client_request/user/postgres.rs diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs new file mode 100644 index 0000000..77ab1d6 --- /dev/null +++ b/src/parse_client_request/ws.rs @@ -0,0 +1,43 @@ +//! WebSocket functionality +use super::{ + query, + user::{Scope, User}, +}; +use crate::user_from_path; +use warp::{filters::BoxedFilter, path, Filter}; + +/// WebSocket filters +pub fn extract_user_and_query() -> BoxedFilter<(User, Query, warp::ws::Ws2)> { + user_from_path!("streaming", Scope::Public) + .and(warp::query()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .and(warp::ws2()) + .map( + |user: User, + stream: query::Stream, + media: query::Media, + hashtag: query::Hashtag, + list: query::List, + ws: warp::ws::Ws2| { + let query = Query { + stream: stream.stream, + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + }; + (user, query, ws) + }, + ) + .untuple_one() + .boxed() +} + +#[derive(Debug)] +pub struct Query { + pub stream: String, + pub media: bool, + pub hashtag: String, + pub list: i64, +} diff --git a/src/query.rs b/src/query.rs deleted file mode 100644 index 8b43d71..0000000 --- a/src/query.rs +++ /dev/null @@ -1,66 +0,0 @@ -//! Validate query prarams with type checking -use serde_derive::Deserialize; -use warp::filters::BoxedFilter; -use warp::Filter as WarpFilter; - -#[derive(Deserialize, Debug, Default)] -pub struct Media { - pub only_media: String, -} -impl Media { - pub fn to_filter() -> BoxedFilter<(Self,)> { - warp::query() - .or(warp::any().map(Self::default)) - .unify() - .boxed() - } - pub fn is_truthy(&self) -> bool { - self.only_media == "true" || self.only_media == "1" - } -} -#[derive(Deserialize, Debug, Default)] -pub struct Hashtag { - pub tag: String, -} -impl Hashtag { - pub fn to_filter() -> BoxedFilter<(Self,)> { - warp::query() - .or(warp::any().map(Self::default)) - .unify() - .boxed() - } -} -#[derive(Deserialize, Debug, Default)] -pub struct List { - pub list: i64, -} -impl List { - pub fn to_filter() -> BoxedFilter<(Self,)> { - warp::query() - .or(warp::any().map(Self::default)) - .unify() - .boxed() - } -} -#[derive(Deserialize, Debug)] -pub struct Auth { - pub access_token: String, -} -#[derive(Deserialize, Debug)] -pub struct Stream { - pub stream: String, -} -impl ToString for Stream { - fn to_string(&self) -> String { - format!("{:?}", self) - } -} - -pub fn optional_media_query() -> BoxedFilter<(Media,)> { - warp::query() - .or(warp::any().map(|| Media { - only_media: "false".to_owned(), - })) - .unify() - .boxed() -} diff --git a/src/stream_manager.rs b/src/redis_to_client_stream/client_agent.rs similarity index 51% rename from src/stream_manager.rs rename to src/redis_to_client_stream/client_agent.rs index ecffd8a..74ae4e8 100644 --- a/src/stream_manager.rs +++ b/src/redis_to_client_stream/client_agent.rs @@ -1,45 +1,41 @@ -//! The `StreamManager` is responsible to providing an interface between the `Warp` -//! filters and the underlying mechanics of talking with Redis/managing multiple -//! threads. The `StreamManager` is the only struct that any Warp code should -//! need to communicate with. +//! Provides an interface between the `Warp` filters and the underlying +//! mechanics of talking with Redis/managing multiple threads. //! -//! The `StreamManager`'s interface is very simple. All you can do with it is: -//! * Create a totally new `StreamManger` with no shared data; -//! * Assign an existing `StreamManager` to manage an new timeline/user pair; or -//! * Poll an existing `StreamManager` to see if there are any new messages +//! The `ClientAgent`'s interface is very simple. All you can do with it is: +//! * Create a totally new `ClientAgent` with no shared data; +//! * Clone an existing `ClientAgent`, sharing the `Receiver`; +//! * to manage an new timeline/user pair; or +//! * Poll an existing `ClientAgent` to see if there are any new messages //! for clients //! -//! When you poll the `StreamManager`, it is responsible for polling internal data +//! When you poll the `ClientAgent`, it is responsible for polling internal data //! structures, getting any updates from Redis, and then filtering out any updates //! that should be excluded by relevant filters. //! //! Because `StreamManagers` are lightweight data structures that do not directly -//! communicate with Redis, it is appropriate to create a new `StreamManager` for -//! each new client connection. -use crate::{ - receiver::Receiver, - user::{Filter, User}, -}; +//! communicate with Redis, it we create a new `ClientAgent` for +//! each new client connection (each in its own thread). +use super::receiver::Receiver; +use crate::parse_client_request::user::User; use futures::{Async, Poll}; use serde_json::{json, Value}; -use std::sync; -use std::time; +use std::{sync, time}; use tokio::io::Error; use uuid::Uuid; /// Struct for managing all Redis streams. #[derive(Clone, Default, Debug)] -pub struct StreamManager { +pub struct ClientAgent { receiver: sync::Arc>, id: uuid::Uuid, target_timeline: String, current_user: User, } -impl StreamManager { - /// Create a new `StreamManager` with no shared data. - pub fn new() -> Self { - StreamManager { +impl ClientAgent { + /// Create a new `ClientAgent` with no shared data. + pub fn blank() -> Self { + ClientAgent { receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())), id: Uuid::default(), target_timeline: String::new(), @@ -47,38 +43,44 @@ impl StreamManager { } } - /// Assign the `StreamManager` to manage a new timeline/user pair. + /// Clones the `ClientAgent`, sharing the `Receiver`. + pub fn clone_with_shared_receiver(&self) -> Self { + Self { + receiver: self.receiver.clone(), + id: self.id, + target_timeline: self.target_timeline.clone(), + current_user: self.current_user.clone(), + } + } + /// Initializes the `ClientAgent` with a unique ID, a `User`, and the target timeline. + /// Also passes values to the `Receiver` for it's initialization. /// /// Note that this *may or may not* result in a new Redis connection. /// If the server has already subscribed to the timeline on behalf of - /// a different user, the `StreamManager` is responsible for figuring + /// a different user, the `Receiver` is responsible for figuring /// that out and avoiding duplicated connections. Thus, it is safe to /// use this method for each new client connection. - pub fn manage_new_timeline(&self, target_timeline: &str, user: User) -> Self { - let manager_id = Uuid::new_v4(); + pub fn init_for_user(&mut self, target_timeline: &str, user: User) { + self.id = Uuid::new_v4(); + self.target_timeline = target_timeline.to_owned(); + self.current_user = user; let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)"); - receiver.manage_new_timeline(manager_id, target_timeline); - StreamManager { - id: manager_id, - current_user: user, - target_timeline: target_timeline.to_owned(), - receiver: self.receiver.clone(), - } + receiver.manage_new_timeline(self.id, target_timeline); } } -/// The stream that the `StreamManager` manages. `Poll` is the only method implemented. -impl futures::stream::Stream for StreamManager { +/// The stream that the `ClientAgent` manages. `Poll` is the only method implemented. +impl futures::stream::Stream for ClientAgent { type Item = Value; type Error = Error; /// Checks for any new messages that should be sent to the client. /// - /// The `StreamManager` will poll underlying data structures and will reply - /// with an `Ok(Ready(Some(Value)))` if there is a new message to send to + /// The `ClientAgent` polls the `Receiver` and replies + /// with `Ok(Ready(Some(Value)))` if there is a new message to send to /// the client. If there is no new message or if the new message should be - /// filtered out based on one of the user's filters, then the `StreamManager` - /// will reply with `Ok(NotReady)`. The `StreamManager` will buble up any + /// filtered out based on one of the user's filters, then the `ClientAgent` + /// replies with `Ok(NotReady)`. The `ClientAgent` bubles up any /// errors from the underlying data structures. fn poll(&mut self) -> Poll, Self::Error> { let start_time = time::Instant::now(); @@ -86,30 +88,35 @@ impl futures::stream::Stream for StreamManager { let mut receiver = self .receiver .lock() - .expect("StreamManager: No other thread panic"); + .expect("ClientAgent: No other thread panic"); receiver.configure_for_polling(self.id, &self.target_timeline.clone()); receiver.poll() }; - println!("Polling took: {:?}", start_time.elapsed()); - let result = match result { + + if start_time.elapsed() > time::Duration::from_millis(20) { + println!("Polling took: {:?}", start_time.elapsed()); + } + match result { Ok(Async::Ready(Some(value))) => { - let user_langs = self.current_user.langs.clone(); + let user = &self.current_user; let toot = Toot::from_json(value); - toot.ignore_if_caught_by_filter(&self.current_user.filter, user_langs) + toot.filter(&user) } Ok(inner_value) => Ok(inner_value), Err(e) => Err(e), - }; - result + } } } +/// The message to send to the client (which might not literally be a toot in some cases). struct Toot { category: String, payload: String, language: String, } + impl Toot { + /// Construct a `Toot` from well-formed JSON. fn from_json(value: Value) -> Self { Self { category: value["event"].as_str().expect("Redis string").to_owned(), @@ -121,6 +128,7 @@ impl Toot { } } + /// Convert a `Toot` to JSON inside an Option. fn to_optional_json(&self) -> Option { Some(json!( {"event": self.category, @@ -128,11 +136,8 @@ impl Toot { )) } - fn ignore_if_caught_by_filter( - &self, - filter: &Filter, - user_langs: Option>, - ) -> Result>, Error> { + /// Filter out any `Toot`'s that fail the provided filter. + fn filter(&self, user: &User) -> Result>, Error> { let toot = self; let (send_msg, skip_msg) = ( @@ -140,13 +145,14 @@ impl Toot { Ok(Async::NotReady), ); - match &filter { + use crate::parse_client_request::user::Filter; + match &user.filter { Filter::NoFilter => send_msg, Filter::Notification if toot.category == "notification" => send_msg, // If not, skip it Filter::Notification => skip_msg, - Filter::Language if user_langs.is_none() => send_msg, - Filter::Language if user_langs.expect("").contains(&toot.language) => send_msg, + Filter::Language if user.langs.is_none() => send_msg, + Filter::Language if user.langs.clone().expect("").contains(&toot.language) => send_msg, // If not, skip it Filter::Language => skip_msg, } diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs new file mode 100644 index 0000000..b8ae146 --- /dev/null +++ b/src/redis_to_client_stream/mod.rs @@ -0,0 +1,75 @@ +pub mod client_agent; +pub mod receiver; +pub mod redis_cmd; + +use crate::config; +pub use client_agent::ClientAgent; +use futures::{future::Future, stream::Stream, Async}; +use std::{env, time}; + +pub fn send_updates_to_sse( + mut client_agent: ClientAgent, + connection: warp::sse::Sse, +) -> impl warp::reply::Reply { + let sse_update_interval = env::var("SSE_UPDATE_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(config::DEFAULT_SSE_UPDATE_INTERVAL); + let event_stream = tokio::timer::Interval::new( + time::Instant::now(), + time::Duration::from_millis(sse_update_interval), + ) + .filter_map(move |_| match client_agent.poll() { + Ok(Async::Ready(Some(json_value))) => Some(( + warp::sse::event(json_value["event"].clone().to_string()), + warp::sse::data(json_value["payload"].clone()), + )), + _ => None, + }); + + connection.reply(warp::sse::keep(event_stream, None)) +} + +/// Send a stream of replies to a WebSocket client +pub fn send_updates_to_ws( + socket: warp::ws::WebSocket, + mut stream: ClientAgent, +) -> impl futures::future::Future { + let (ws_tx, mut ws_rx) = socket.split(); + + // Create a pipe + let (tx, rx) = futures::sync::mpsc::unbounded(); + + // Send one end of it to a different thread and tell that end to forward whatever it gets + // on to the websocket client + warp::spawn( + rx.map_err(|()| -> warp::Error { unreachable!() }) + .forward(ws_tx) + .map_err(|_| ()) + .map(|_r| ()), + ); + + // For as long as the client is still connected, yeild a new event every 100 ms + let ws_update_interval = env::var("WS_UPDATE_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(config::DEFAULT_WS_UPDATE_INTERVAL); + let event_stream = tokio::timer::Interval::new( + time::Instant::now(), + time::Duration::from_millis(ws_update_interval), + ) + .take_while(move |_| match ws_rx.poll() { + Ok(Async::Ready(None)) => futures::future::ok(false), + _ => futures::future::ok(true), + }); + + // Every time you get an event from that stream, send it through the pipe + event_stream + .for_each(move |_json_value| { + if let Ok(Async::Ready(Some(json_value))) = stream.poll() { + let msg = warp::ws::Message::text(json_value.to_string()); + tx.unbounded_send(msg).expect("No send error"); + }; + Ok(()) + }) + .then(|msg| msg) + .map_err(|e| println!("{}", e)) +} diff --git a/src/receiver.rs b/src/redis_to_client_stream/receiver.rs similarity index 71% rename from src/receiver.rs rename to src/redis_to_client_stream/receiver.rs index 4bf7d04..88e57f6 100644 --- a/src/receiver.rs +++ b/src/redis_to_client_stream/receiver.rs @@ -1,16 +1,13 @@ -//! Interface with Redis and stream the results to the `StreamManager` -//! There is only one `Receiver`, which suggests that it's name is bad. -//! -//! **TODO**: Consider changing the name. Maybe RedisConnectionPool? -//! There are many AsyncReadableStreams, though. How do they fit in? -//! Figure this out ASAP. -//! A new one is created every time the Receiver is polled -use crate::{config, pubsub_cmd, redis_cmd}; +//! Receives data from Redis, sorts it by `ClientAgent`, and stores it until +//! polled by the correct `ClientAgent`. Also manages sububscriptions and +//! unsubscriptions to/from Redis. +use super::redis_cmd; +use crate::{config, pubsub_cmd}; use futures::{Async, Poll}; use log::info; use regex::Regex; use serde_json::Value; -use std::{collections, io::Read, io::Write, net, time}; +use std::{collections, env, io::Read, io::Write, net, time}; use tokio::io::{AsyncRead, Error}; use uuid::Uuid; @@ -19,7 +16,8 @@ use uuid::Uuid; pub struct Receiver { pubsub_connection: net::TcpStream, secondary_redis_connection: net::TcpStream, - tl: String, + redis_polled_at: time::Instant, + timeline: String, manager_id: Uuid, msg_queues: collections::HashMap, clients_per_timeline: collections::HashMap, @@ -33,7 +31,8 @@ impl Receiver { Self { pubsub_connection, secondary_redis_connection, - tl: String::new(), + redis_polled_at: time::Instant::now(), + timeline: String::new(), manager_id: Uuid::default(), msg_queues: collections::HashMap::new(), clients_per_timeline: collections::HashMap::new(), @@ -43,60 +42,60 @@ impl Receiver { /// Assigns the `Receiver` a new timeline to monitor and runs other /// first-time setup. /// - /// Importantly, this method calls `subscribe_or_unsubscribe_as_needed`, + /// Note: this method calls `subscribe_or_unsubscribe_as_needed`, /// so Redis PubSub subscriptions are only updated when a new timeline /// comes under management for the first time. pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: &str) { self.manager_id = manager_id; - self.tl = timeline.to_string(); - let old_value = self - .msg_queues + self.timeline = timeline.to_string(); + self.msg_queues .insert(self.manager_id, MsgQueue::new(timeline)); - // Consider removing/refactoring - if let Some(value) = old_value { - eprintln!( - "Data was overwritten when it shouldn't have been. Old data was: {:#?}", - value - ); - } self.subscribe_or_unsubscribe_as_needed(timeline); } /// Set the `Receiver`'s manager_id and target_timeline fields to the approprate /// value to be polled by the current `StreamManager`. pub fn configure_for_polling(&mut self, manager_id: Uuid, timeline: &str) { - if &manager_id != &self.manager_id { - //println!("New Manager: {}", &manager_id); - } self.manager_id = manager_id; - self.tl = timeline.to_string(); + self.timeline = timeline.to_string(); } /// Drop any PubSub subscriptions that don't have active clients and check /// that there's a subscription to the current one. If there isn't, then /// subscribe to it. - fn subscribe_or_unsubscribe_as_needed(&mut self, tl: &str) { + fn subscribe_or_unsubscribe_as_needed(&mut self, timeline: &str) { let mut timelines_to_modify = Vec::new(); - timelines_to_modify.push((tl.to_owned(), 1)); + struct Change { + timeline: String, + change_in_subscriber_number: i32, + } + + timelines_to_modify.push(Change { + timeline: timeline.to_owned(), + change_in_subscriber_number: 1, + }); // Keep only message queues that have been polled recently self.msg_queues.retain(|_id, msg_queue| { if msg_queue.last_polled_at.elapsed() < time::Duration::from_secs(30) { true } else { - let timeline = msg_queue.redis_channel.clone(); - timelines_to_modify.push((timeline, -1)); + let timeline = &msg_queue.redis_channel; + timelines_to_modify.push(Change { + timeline: timeline.to_owned(), + change_in_subscriber_number: -1, + }); false } }); // Record the lower number of clients subscribed to that channel - for (timeline, numerical_change) in timelines_to_modify { + for change in timelines_to_modify { let mut need_to_subscribe = false; let count_of_subscribed_clients = self .clients_per_timeline - .entry(timeline.to_owned()) - .and_modify(|n| *n += numerical_change) + .entry(change.timeline.clone()) + .and_modify(|n| *n += change.change_in_subscriber_number) .or_insert_with(|| { need_to_subscribe = true; 1 @@ -104,11 +103,38 @@ impl Receiver { // If no clients, unsubscribe from the channel if *count_of_subscribed_clients <= 0 { info!("Sent unsubscribe command"); - pubsub_cmd!("unsubscribe", self, timeline.clone()); + pubsub_cmd!("unsubscribe", self, change.timeline.clone()); } if need_to_subscribe { info!("Sent subscribe command"); - pubsub_cmd!("subscribe", self, timeline.clone()); + pubsub_cmd!("subscribe", self, change.timeline.clone()); + } + } + } + + /// Polls Redis for any new messages and adds them to the `MsgQueue` for + /// the appropriate `ClientAgent`. + fn poll_redis(&mut self) { + let mut buffer = vec![0u8; 3000]; + // Add any incoming messages to the back of the relevant `msg_queues` + // NOTE: This could be more/other than the `msg_queue` currently being polled + let mut async_stream = AsyncReadableStream::new(&mut self.pubsub_connection); + if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer).unwrap() { + let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]); + // capture everything between `{` and `}` as potential JSON + let json_regex = Regex::new(r"(?P\{.*\})").expect("Hard-coded"); + // capture the timeline so we know which queues to add it to + let timeline_regex = Regex::new(r"timeline:(?P.*?)\r").expect("Hard-codded"); + if let Some(result) = json_regex.captures(raw_redis_response) { + let timeline = + timeline_regex.captures(raw_redis_response).unwrap()["timeline"].to_string(); + + let msg: Value = serde_json::from_str(&result["json"].to_string().clone()).unwrap(); + for msg_queue in self.msg_queues.values_mut() { + if msg_queue.redis_channel == timeline { + msg_queue.messages.push_back(msg.clone()); + } + } } } } @@ -128,47 +154,41 @@ impl Receiver { } } } + impl Default for Receiver { fn default() -> Self { Receiver::new() } } +/// The stream that the ClientAgent polls to learn about new messages. impl futures::stream::Stream for Receiver { type Item = Value; type Error = Error; + /// Returns the oldest message in the `ClientAgent`'s queue (if any). + /// + /// Note: This method does **not** poll Redis every time, because polling + /// Redis is signifiantly more time consuming that simply returning the + /// message already in a queue. Thus, we only poll Redis if it has not + /// been polled lately. fn poll(&mut self) -> Poll, Self::Error> { - let mut buffer = vec![0u8; 3000]; - let timeline = self.tl.clone(); + let timeline = self.timeline.clone(); + + let redis_poll_interval = env::var("REDIS_POLL_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(config::DEFAULT_REDIS_POLL_INTERVAL); + + if self.redis_polled_at.elapsed() > time::Duration::from_millis(redis_poll_interval) { + self.poll_redis(); + self.redis_polled_at = time::Instant::now(); + } // Record current time as last polled time self.msg_queues .entry(self.manager_id) .and_modify(|msg_queue| msg_queue.last_polled_at = time::Instant::now()); - // Add any incomming messages to the back of the relevant `msg_queues` - // NOTE: This could be more/other than the `msg_queue` currently being polled - let mut async_stream = AsyncReadableStream::new(&mut self.pubsub_connection); - if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer)? { - let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]); - // capture everything between `{` and `}` as potential JSON - let json_regex = Regex::new(r"(?P\{.*\})").expect("Hard-coded"); - // capture the timeline so we know which queues to add it to - let timeline_regex = Regex::new(r"timeline:(?P.*?)\r").expect("Hard-codded"); - if let Some(result) = json_regex.captures(raw_redis_response) { - let timeline = - timeline_regex.captures(raw_redis_response).unwrap()["timeline"].to_string(); - - let msg: Value = serde_json::from_str(&result["json"].to_string().clone())?; - for msg_queue in self.msg_queues.values_mut() { - if msg_queue.redis_channel == timeline { - msg_queue.messages.push_back(msg.clone()); - } - } - } - } - // If the `msg_queue` being polled has any new messages, return the first (oldest) one match self .msg_queues @@ -188,7 +208,7 @@ impl futures::stream::Stream for Receiver { impl Drop for Receiver { fn drop(&mut self) { - pubsub_cmd!("unsubscribe", self, self.tl.clone()); + pubsub_cmd!("unsubscribe", self, self.timeline.clone()); } } diff --git a/src/redis_cmd.rs b/src/redis_to_client_stream/redis_cmd.rs similarity index 100% rename from src/redis_cmd.rs rename to src/redis_to_client_stream/redis_cmd.rs diff --git a/src/timeline.rs b/src/timeline.rs deleted file mode 100644 index 71b6ab2..0000000 --- a/src/timeline.rs +++ /dev/null @@ -1,102 +0,0 @@ -//! Filters for all the endpoints accessible for Server Sent Event updates -use crate::error; -use crate::query; -use crate::user::{Filter::*, Scope, User}; -use crate::user_from_path; -use warp::filters::BoxedFilter; -use warp::{path, Filter}; - -#[allow(dead_code)] -type TimelineUser = ((String, User),); - -/// GET /api/v1/streaming/user -pub fn user() -> BoxedFilter { - user_from_path!("streaming" / "user", Scope::Private) - .map(|user: User| (user.id.to_string(), user)) - .boxed() -} - -/// GET /api/v1/streaming/user/notification -/// -/// -/// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented? -pub fn user_notifications() -> BoxedFilter { - user_from_path!("streaming" / "user" / "notification", Scope::Private) - .map(|user: User| (user.id.to_string(), user.set_filter(Notification))) - .boxed() -} - -/// GET /api/v1/streaming/public -pub fn public() -> BoxedFilter { - user_from_path!("streaming" / "public", Scope::Public) - .map(|user: User| ("public".to_owned(), user.set_filter(Language))) - .boxed() -} - -/// GET /api/v1/streaming/public?only_media=true -pub fn public_media() -> BoxedFilter { - user_from_path!("streaming" / "public", Scope::Public) - .and(warp::query()) - .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:media".to_owned(), user.set_filter(Language)), - _ => ("public".to_owned(), user.set_filter(Language)), - }) - .boxed() -} - -/// GET /api/v1/streaming/public/local -pub fn public_local() -> BoxedFilter { - user_from_path!("streaming" / "public" / "local", Scope::Public) - .map(|user: User| ("public:local".to_owned(), user.set_filter(Language))) - .boxed() -} - -/// GET /api/v1/streaming/public/local?only_media=true -pub fn public_local_media() -> BoxedFilter { - user_from_path!("streaming" / "public" / "local", Scope::Public) - .and(warp::query()) - .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)), - _ => ("public:local".to_owned(), user.set_filter(Language)), - }) - .boxed() -} - -/// GET /api/v1/streaming/direct -pub fn direct() -> BoxedFilter { - user_from_path!("streaming" / "direct", Scope::Private) - .map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter))) - .boxed() -} - -/// GET /api/v1/streaming/hashtag?tag=:hashtag -pub fn hashtag() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "hashtag") - .and(warp::query()) - .map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public())) - .boxed() -} - -/// GET /api/v1/streaming/hashtag/local?tag=:hashtag -pub fn hashtag_local() -> BoxedFilter { - path!("api" / "v1" / "streaming" / "hashtag" / "local") - .and(warp::query()) - .map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())) - .boxed() -} - -/// GET /api/v1/streaming/list?list=:list_id -pub fn list() -> BoxedFilter { - user_from_path!("streaming" / "list", Scope::Private) - .and(warp::query()) - .and_then(|user: User, q: query::List| { - if user.owns_list(q.list) { - (Ok(q.list), Ok(user)) - } else { - (Err(error::unauthorized_list()), Ok(user)) - } - }) - .untuple_one() - .map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter))) - .boxed() -} diff --git a/src/ws.rs b/src/ws.rs deleted file mode 100644 index 13c2907..0000000 --- a/src/ws.rs +++ /dev/null @@ -1,86 +0,0 @@ -//! WebSocket-specific functionality -use crate::query; -use crate::stream_manager::StreamManager; -use crate::user::{Scope, User}; -use crate::user_from_path; -use futures::future::Future; -use futures::stream::Stream; -use futures::Async; -use std::time; -use warp::filters::BoxedFilter; -use warp::{path, Filter}; - -/// Send a stream of replies to a WebSocket client -pub fn send_replies( - socket: warp::ws::WebSocket, - mut stream: StreamManager, -) -> impl futures::future::Future { - let (ws_tx, mut ws_rx) = socket.split(); - - // Create a pipe - let (tx, rx) = futures::sync::mpsc::unbounded(); - - // Send one end of it to a different thread and tell that end to forward whatever it gets - // on to the websocket client - warp::spawn( - rx.map_err(|()| -> warp::Error { unreachable!() }) - .forward(ws_tx) - .map_err(|_| ()) - .map(|_r| ()), - ); - - // For as long as the client is still connected, yeild a new event every 100 ms - let event_stream = - tokio::timer::Interval::new(time::Instant::now(), time::Duration::from_millis(100)) - .take_while(move |_| match ws_rx.poll() { - Ok(Async::Ready(None)) => futures::future::ok(false), - _ => futures::future::ok(true), - }); - - // Every time you get an event from that stream, send it through the pipe - event_stream - .for_each(move |_json_value| { - if let Ok(Async::Ready(Some(json_value))) = stream.poll() { - let msg = warp::ws::Message::text(json_value.to_string()); - tx.unbounded_send(msg).expect("No send error"); - }; - Ok(()) - }) - .then(|msg| msg) - .map_err(|e| println!("{}", e)) -} - -pub fn websocket_routes() -> BoxedFilter<(User, Query, warp::ws::Ws2)> { - user_from_path!("streaming", Scope::Public) - .and(warp::query()) - .and(query::Media::to_filter()) - .and(query::Hashtag::to_filter()) - .and(query::List::to_filter()) - .and(warp::ws2()) - .map( - |user: User, - stream: query::Stream, - media: query::Media, - hashtag: query::Hashtag, - list: query::List, - ws: warp::ws::Ws2| { - let query = Query { - stream: stream.stream, - media: media.is_truthy(), - hashtag: hashtag.tag, - list: list.list, - }; - (user, query, ws) - }, - ) - .untuple_one() - .boxed() -} - -#[derive(Debug)] -pub struct Query { - pub stream: String, - pub media: bool, - pub hashtag: String, - pub list: i64, -} diff --git a/tests/test.rs b/tests/test.rs index 431d0c2..3f1a3d4 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,7 +1,7 @@ use ragequit::{ config, - timeline::*, - user::{Filter::*, Scope, User}, + parse_client_request::sse::Request, + parse_client_request::user::{Filter::*, Scope, User}, }; #[test] @@ -10,12 +10,12 @@ fn user_unauthorized() { .path(&format!( "/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1", )) - .filter(&user()); + .filter(&Request::user()); assert!(invalid_access_token(value)); let value = warp::test::request() .path(&format!("/api/v1/streaming/user",)) - .filter(&user()); + .filter(&Request::user()); assert!(no_access_token(value)); } @@ -31,7 +31,7 @@ fn user_auth() { "/api/v1/streaming/user?access_token={}", access_token )) - .filter(&user()) + .filter(&Request::user()) .expect("in test"); let expected_user = @@ -44,7 +44,7 @@ fn user_auth() { let (actual_timeline, actual_user) = warp::test::request() .path("/api/v1/streaming/user") .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&user()) + .filter(&Request::user()) .expect("in test"); let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test"); @@ -59,12 +59,12 @@ fn user_notifications_unauthorized() { .path(&format!( "/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN", )) - .filter(&user_notifications()); + .filter(&Request::user_notifications()); assert!(invalid_access_token(value)); let value = warp::test::request() .path(&format!("/api/v1/streaming/user/notification",)) - .filter(&user_notifications()); + .filter(&Request::user_notifications()); assert!(no_access_token(value)); } @@ -80,7 +80,7 @@ fn user_notifications_auth() { "/api/v1/streaming/user/notification?access_token={}", access_token )) - .filter(&user_notifications()) + .filter(&Request::user_notifications()) .expect("in test"); let expected_user = User::from_access_token(access_token.clone(), Scope::Private) @@ -94,7 +94,7 @@ fn user_notifications_auth() { let (actual_timeline, actual_user) = warp::test::request() .path("/api/v1/streaming/user/notification") .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&user_notifications()) + .filter(&Request::user_notifications()) .expect("in test"); let expected_user = User::from_access_token(access_token, Scope::Private) @@ -108,7 +108,7 @@ fn user_notifications_auth() { fn public_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public") - .filter(&public()) + .filter(&Request::public()) .expect("in test"); assert_eq!(value.0, "public".to_string()); @@ -119,7 +119,7 @@ fn public_timeline() { fn public_media_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public?only_media=true") - .filter(&public_media()) + .filter(&Request::public_media()) .expect("in test"); assert_eq!(value.0, "public:media".to_string()); @@ -127,7 +127,7 @@ fn public_media_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public?only_media=1") - .filter(&public_media()) + .filter(&Request::public_media()) .expect("in test"); assert_eq!(value.0, "public:media".to_string()); @@ -138,7 +138,7 @@ fn public_media_timeline() { fn public_local_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public/local") - .filter(&public_local()) + .filter(&Request::public_local()) .expect("in test"); assert_eq!(value.0, "public:local".to_string()); @@ -149,7 +149,7 @@ fn public_local_timeline() { fn public_local_media_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public/local?only_media=true") - .filter(&public_local_media()) + .filter(&Request::public_local_media()) .expect("in test"); assert_eq!(value.0, "public:local:media".to_string()); @@ -157,7 +157,7 @@ fn public_local_media_timeline() { let value = warp::test::request() .path("/api/v1/streaming/public/local?only_media=1") - .filter(&public_local_media()) + .filter(&Request::public_local_media()) .expect("in test"); assert_eq!(value.0, "public:local:media".to_string()); @@ -170,12 +170,12 @@ fn direct_timeline_unauthorized() { .path(&format!( "/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN", )) - .filter(&direct()); + .filter(&Request::direct()); assert!(invalid_access_token(value)); let value = warp::test::request() .path(&format!("/api/v1/streaming/direct",)) - .filter(&direct()); + .filter(&Request::direct()); assert!(no_access_token(value)); } @@ -191,7 +191,7 @@ fn direct_timeline_auth() { "/api/v1/streaming/direct?access_token={}", access_token )) - .filter(&direct()) + .filter(&Request::direct()) .expect("in test"); let expected_user = @@ -204,7 +204,7 @@ fn direct_timeline_auth() { let (actual_timeline, actual_user) = warp::test::request() .path("/api/v1/streaming/direct") .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&direct()) + .filter(&Request::direct()) .expect("in test"); let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test"); @@ -217,7 +217,7 @@ fn direct_timeline_auth() { fn hashtag_timeline() { let value = warp::test::request() .path("/api/v1/streaming/hashtag?tag=a") - .filter(&hashtag()) + .filter(&Request::hashtag()) .expect("in test"); assert_eq!(value.0, "hashtag:a".to_string()); @@ -228,7 +228,7 @@ fn hashtag_timeline() { fn hashtag_timeline_local() { let value = warp::test::request() .path("/api/v1/streaming/hashtag/local?tag=a") - .filter(&hashtag_local()) + .filter(&Request::hashtag_local()) .expect("in test"); assert_eq!(value.0, "hashtag:a:local".to_string()); @@ -248,7 +248,7 @@ fn list_timeline_auth() { "/api/v1/streaming/list?access_token={}&list={}", access_token, list_id, )) - .filter(&list()) + .filter(&Request::list()) .expect("in test"); let expected_user = @@ -261,7 +261,7 @@ fn list_timeline_auth() { let (actual_timeline, actual_user) = warp::test::request() .path("/api/v1/streaming/list?list=1") .header("Authorization", format!("Bearer: {}", access_token.clone())) - .filter(&list()) + .filter(&Request::list()) .expect("in test"); let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test"); @@ -276,12 +276,12 @@ fn list_timeline_unauthorized() { .path(&format!( "/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1", )) - .filter(&list()); + .filter(&Request::list()); assert!(invalid_access_token(value)); let value = warp::test::request() .path(&format!("/api/v1/streaming/list?list=1",)) - .filter(&list()); + .filter(&Request::list()); assert!(no_access_token(value)); }