From 1657113c58e01d5b6466840c5565810a1409a7a4 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Thu, 9 Apr 2020 13:32:36 -0400 Subject: [PATCH] Stream events via a watch channel (#128) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This squashed commit makes a fairly significant structural change to significantly reduce Flodgatt's CPU usage. Flodgatt connects to Redis in a single (green) thread, and then creates a new thread to handle each WebSocket/SSE connection. Previously, each thread was responsible for polling the Redis thread to determine whether it had a message relevant to the connected client. I initially selected this structure both because it was simple and because it minimized memory overhead – no messages are sent to a particular thread unless they are relevant to the client connected to the thread. However, I recently ran some load tests that show this approach to have unacceptable CPU costs when 300+ clients are simultaneously connected. Accordingly, Flodgatt now uses a different structure: the main Redis thread now announces each incoming message via a watch channel connected to every client thread, and each client thread filters out irrelevant messages. In theory, this could lead to slightly higher memory use, but tests I have run so far have not found a measurable increase. On the other hand, Flodgatt's CPU use is now an order of magnitude lower in tests I've run. This approach does run a (very slight) risk of dropping messages under extremely heavy load: because a watch channel only stores the most recent message transmitted, if Flodgatt adds a second message before the thread can read the first message, the first message will be overwritten and never transmitted. This seems unlikely to happen in practice, and we can avoid the issue entirely by changing to a broadcast channel when we upgrade to the most recent Tokio version (see #75). --- Cargo.lock | 2 +- Cargo.toml | 9 +- src/lib.rs | 3 + src/main.rs | 86 +++++-- .../event/checked_event/status/mod.rs | 11 +- src/messages/event/dynamic_event.rs | 19 +- src/messages/event/mod.rs | 3 + src/parse_client_request/subscription.rs | 2 +- src/redis_to_client_stream/client_agent.rs | 127 ----------- src/redis_to_client_stream/event_stream.rs | 210 +++++++++++------- src/redis_to_client_stream/mod.rs | 6 +- .../receiver/message_queues.rs | 53 ----- src/redis_to_client_stream/receiver/mod.rs | 111 ++++----- .../redis/redis_connection/mod.rs | 2 +- 14 files changed, 265 insertions(+), 379 deletions(-) delete mode 100644 src/redis_to_client_stream/client_agent.rs delete mode 100644 src/redis_to_client_stream/receiver/message_queues.rs diff --git a/Cargo.lock b/Cargo.lock index 0318b3a..e0d0277 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -453,7 +453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.7.1" +version = "0.8.0" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index bd642ae..35a728e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "flodgatt" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" -version = "0.7.1" +version = "0.8.0" authors = ["Daniel Long Sockwell "] edition = "2018" @@ -43,8 +43,9 @@ stub_status = [] production = [] [profile.release] -lto = "fat" -panic = "abort" -codegen-units = 1 +#lto = "fat" +#panic = "abort" +#codegen-units = 1 + diff --git a/src/lib.rs b/src/lib.rs index ede8963..8af4e25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,6 +34,9 @@ //! most important settings for performance control the frequency with which the `ClientAgent` //! polls the `Receiver` and the frequency with which the `Receiver` polls Redis. //! + +#![allow(clippy::try_err, clippy::match_bool)] + pub mod config; pub mod err; pub mod messages; diff --git a/src/main.rs b/src/main.rs index 74b6702..079d8ba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,14 @@ use flodgatt::{ config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig}, - parse_client_request::{PgPool, Subscription}, - redis_to_client_stream::{ClientAgent, EventStream, Receiver}, + messages::Event, + parse_client_request::{PgPool, Subscription, Timeline}, + redis_to_client_stream::{Receiver, SseStream, WsStream}, }; use std::{env, fs, net::SocketAddr, os::unix::fs::PermissionsExt}; -use tokio::net::UnixListener; +use tokio::{ + net::UnixListener, + sync::{mpsc, watch}, +}; use warp::{http::StatusCode, path, ws::Ws2, Filter, Rejection}; fn main() { @@ -23,8 +27,10 @@ fn main() { let cfg = DeploymentConfig::from_env(env_vars); let pg_pool = PgPool::new(postgres_cfg); - - let receiver = Receiver::try_from(redis_cfg) + let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); + let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); + let poll_freq = *redis_cfg.polling_interval; + let receiver = Receiver::try_from(redis_cfg, event_tx, cmd_rx) .unwrap_or_else(|e| { log::error!("{}\nFlodgatt shutting down...", e); std::process::exit(1); @@ -34,38 +40,57 @@ fn main() { // Server Sent Events let sse_receiver = receiver.clone(); - let (sse_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode); + let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone()); + let whitelist_mode = *cfg.whitelist_mode; let sse_routes = Subscription::from_sse_query(pg_pool.clone(), whitelist_mode) .and(warp::sse()) .map( move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); - let mut client_agent = ClientAgent::new(sse_receiver.clone(), &subscription); - client_agent.subscribe(); - + { + let mut receiver = sse_receiver.lock().expect("TODO"); + receiver.subscribe(&subscription).unwrap_or_else(|e| { + log::error!("Could not subscribe to the Redis channel: {}", e) + }); + } + let cmd_tx = sse_cmd_tx.clone(); + let sse_rx = sse_rx.clone(); + // self.sse.reply( + // warp::sse::keep_alive() + // .interval(Duration::from_secs(30)) + // .text("thump".to_string()) + // .stream(event_stream), + // ) // send the updates through the SSE connection - EventStream::send_to_sse(client_agent, sse_connection_to_client, sse_interval) + SseStream::send_events(sse_connection_to_client, cmd_tx, subscription, sse_rx) }, ) .with(warp::reply::with::header("Connection", "keep-alive")); // WebSocket let ws_receiver = receiver.clone(); - let (ws_update_interval, whitelist_mode) = (*cfg.ws_interval, *cfg.whitelist_mode); + let whitelist_mode = *cfg.whitelist_mode; let ws_routes = Subscription::from_ws_request(pg_pool, whitelist_mode) .and(warp::ws::ws2()) .map(move |subscription: Subscription, ws: Ws2| { log::info!("Incoming websocket request for {:?}", subscription.timeline); - let mut client_agent = ClientAgent::new(ws_receiver.clone(), &subscription); - client_agent.subscribe(); + { + let mut receiver = ws_receiver.lock().expect("TODO"); + receiver.subscribe(&subscription).unwrap_or_else(|e| { + log::error!("Could not subscribe to the Redis channel: {}", e) + }); + } + let cmd_tx = cmd_tx.clone(); + let ws_rx = event_rx.clone(); + let token = subscription + .clone() + .access_token + .unwrap_or_else(String::new); - // send the updates through the WS connection - // (along with the User's access_token which is sent for security) + // send the updates through the WS connection (along with the access_token, for security) ( - ws.on_upgrade(move |s| { - EventStream::send_to_ws(s, client_agent, ws_update_interval) - }), - subscription.access_token.unwrap_or_else(String::new), + ws.on_upgrade(move |ws| WsStream::new(ws, cmd_tx, subscription).send_events(ws_rx)), + token, ) }) .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); @@ -77,14 +102,12 @@ fn main() { #[cfg(feature = "stub_status")] let status_endpoints = { - let (r1, r2, r3) = (receiver.clone(), receiver.clone(), receiver.clone()); + let (r1, r3) = (receiver.clone(), receiver.clone()); warp::path!("api" / "v1" / "streaming" / "health") .map(|| "OK") .or(warp::path!("api" / "v1" / "streaming" / "status") .and(warp::path::end()) .map(move || r1.lock().expect("TODO").count_connections())) - .or(warp::path!("api" / "v1" / "streaming" / "status" / "queue") - .map(move || r2.lock().expect("TODO").queue_length())) .or( warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline") .map(move || r3.lock().expect("TODO").list_connections()), @@ -119,7 +142,24 @@ fn main() { ) .run_incoming(incoming); } else { + use futures::{future::lazy, stream::Stream as _Stream}; + use std::time::Instant; + let server_addr = SocketAddr::new(*cfg.address, *cfg.port); - warp::serve(ws_routes.or(sse_routes).with(cors).or(status_endpoints)).run(server_addr); + + tokio::run(lazy(move || { + let receiver = receiver.clone(); + warp::spawn(lazy(move || { + tokio::timer::Interval::new(Instant::now(), poll_freq) + .map_err(|e| log::error!("{}", e)) + .for_each(move |_| { + let receiver = receiver.clone(); + receiver.lock().expect("TODO").poll_broadcast(); + Ok(()) + }) + })); + + warp::serve(ws_routes.or(sse_routes).with(cors).or(status_endpoints)).bind(server_addr) + })); }; } diff --git a/src/messages/event/checked_event/status/mod.rs b/src/messages/event/checked_event/status/mod.rs index 8656231..8a60540 100644 --- a/src/messages/event/checked_event/status/mod.rs +++ b/src/messages/event/checked_event/status/mod.rs @@ -40,6 +40,7 @@ pub struct Status { poll: Option, card: Option, language: Option, + text: Option, // ↓↓↓ Only for authorized users favourited: Option, @@ -85,16 +86,14 @@ impl Status { pub fn involves_any(&self, blocks: &Blocks) -> bool { const ALLOW: bool = false; const REJECT: bool = true; - let Blocks { blocked_users, blocking_users, blocked_domains, } = blocks; + let user_id = &self.account.id.parse().expect("TODO"); - if !self.calculate_involved_users().is_disjoint(blocked_users) { - REJECT - } else if blocking_users.contains(&self.account.id.parse().expect("TODO")) { + if blocking_users.contains(user_id) || self.involves(blocked_users) { REJECT } else { let full_username = &self.account.acct; @@ -105,7 +104,7 @@ impl Status { } } - fn calculate_involved_users(&self) -> HashSet { + fn involves(&self, blocked_users: &HashSet) -> bool { // TODO replace vvvv with error handling let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self); @@ -126,6 +125,6 @@ impl Status { if let Some(boosted_status) = self.reblog.clone() { involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err)); } - involved_users + !involved_users.is_disjoint(blocked_users) } } diff --git a/src/messages/event/dynamic_event.rs b/src/messages/event/dynamic_event.rs index 9bff2fd..6d2ead8 100644 --- a/src/messages/event/dynamic_event.rs +++ b/src/messages/event/dynamic_event.rs @@ -23,7 +23,7 @@ impl DynamicEvent { match self.payload["language"].as_str() { Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW, None => ALLOW, // If toot language is unknown, toot is always allowed - Some(empty) if empty == &String::new() => ALLOW, + Some(empty) if empty == String::new() => ALLOW, Some(_toot_language) => REJECT, } } @@ -45,12 +45,10 @@ impl DynamicEvent { blocked_domains, } = blocks; - let user_id = self.payload["account"]["id"].as_str().expect("TODO"); + let id = self.payload["account"]["id"].as_str().expect("TODO"); let username = self.payload["account"]["acct"].as_str().expect("TODO"); - if !self.calculate_involved_users().is_disjoint(blocked_users) { - REJECT - } else if blocking_users.contains(&user_id.parse().expect("TODO")) { + if self.involves(blocked_users) || blocking_users.contains(&id.parse().expect("TODO")) { REJECT } else { let full_username = &username; @@ -60,9 +58,11 @@ impl DynamicEvent { } } } - fn calculate_involved_users(&self) -> HashSet { + + // involved_users = mentioned_users + author + replied-to user + boosted user + fn involves(&self, blocked_users: &HashSet) -> bool { + // mentions let mentions = self.payload["mentions"].as_array().expect("TODO"); - // involved_users = mentioned_users + author + replied-to user + boosted user let mut involved_users: HashSet = mentions .iter() .map(|mention| mention["id"].as_str().expect("TODO").parse().expect("TODO")) @@ -73,16 +73,15 @@ impl DynamicEvent { involved_users.insert(author_id.parse::().expect("TODO")); // replied-to user let replied_to_user = self.payload["in_reply_to_account_id"].as_str(); - if let Some(user_id) = replied_to_user.clone() { + if let Some(user_id) = replied_to_user { involved_users.insert(user_id.parse().expect("TODO")); } // boosted user - let id_of_boosted_user = self.payload["reblog"]["account"]["id"] .as_str() .expect("TODO"); involved_users.insert(id_of_boosted_user.parse().expect("TODO")); - involved_users + !involved_users.is_disjoint(blocked_users) } } diff --git a/src/messages/event/mod.rs b/src/messages/event/mod.rs index 133ffa2..3c70b7b 100644 --- a/src/messages/event/mod.rs +++ b/src/messages/event/mod.rs @@ -11,6 +11,7 @@ use std::string::String; pub enum Event { TypeSafe(CheckedEvent), Dynamic(DynamicEvent), + Ping, } impl Event { @@ -37,6 +38,7 @@ impl Event { CheckedEvent::FiltersChanged => "filters_changed", }, Self::Dynamic(dyn_event) => &dyn_event.event, + Self::Ping => panic!("event_name() called on EventNotReady"), }) } @@ -54,6 +56,7 @@ impl Event { FiltersChanged => None, }, Self::Dynamic(dyn_event) => Some(dyn_event.payload.to_string()), + Self::Ping => panic!("payload() called on EventNotReady"), } } } diff --git a/src/parse_client_request/subscription.rs b/src/parse_client_request/subscription.rs index 783ee3c..5f796f8 100644 --- a/src/parse_client_request/subscription.rs +++ b/src/parse_client_request/subscription.rs @@ -218,7 +218,7 @@ impl Timeline { }; use {Content::*, Reach::*, Stream::*}; - Ok(match &timeline.split(":").collect::>()[..] { + Ok(match &timeline.split(':').collect::>()[..] { ["public"] => Timeline(Public, Federated, All), ["public", "local"] => Timeline(Public, Local, All), ["public", "media"] => Timeline(Public, Federated, Media), diff --git a/src/redis_to_client_stream/client_agent.rs b/src/redis_to_client_stream/client_agent.rs deleted file mode 100644 index a9681e0..0000000 --- a/src/redis_to_client_stream/client_agent.rs +++ /dev/null @@ -1,127 +0,0 @@ -//! Provides an interface between the `Warp` filters and the underlying -//! mechanics of talking with Redis/managing multiple threads. -//! -//! The `ClientAgent`'s interface is very simple. All you can do with it is: -//! * Create a totally new `ClientAgent` with no shared data; -//! * Clone an existing `ClientAgent`, sharing the `Receiver`; -//! * Manage an new timeline/user pair; or -//! * Poll an existing `ClientAgent` to see if there are any new messages -//! for clients -//! -//! When you poll the `ClientAgent`, it is responsible for polling internal data -//! structures, getting any updates from Redis, and then filtering out any updates -//! that should be excluded by relevant filters. -//! -//! Because `StreamManagers` are lightweight data structures that do not directly -//! communicate with Redis, it we create a new `ClientAgent` for -//! each new client connection (each in its own thread).use super::{message::Message, receiver::Receiver} -use super::receiver::{Receiver, ReceiverErr}; -use crate::{ - messages::Event, - parse_client_request::{Stream::Public, Subscription, Timeline}, -}; -use futures::{ - Async::{self, NotReady, Ready}, - Poll, -}; -use std::sync::{Arc, Mutex, MutexGuard}; - -/// Struct for managing all Redis streams. -#[derive(Clone, Debug)] -pub struct ClientAgent { - receiver: Arc>, - pub subscription: Subscription, -} - -impl ClientAgent { - pub fn new(receiver: Arc>, subscription: &Subscription) -> Self { - ClientAgent { - receiver, - subscription: subscription.clone(), - } - } - - /// Initializes the `ClientAgent` with a unique ID associated with a specific user's - /// subscription. Also passes values to the `Receiver` for it's initialization. - /// - /// Note that this *may or may not* result in a new Redis connection. - /// If the server has already subscribed to the timeline on behalf of - /// a different user, the `Receiver` is responsible for figuring - /// that out and avoiding duplicated connections. Thus, it is safe to - /// use this method for each new client connection. - pub fn subscribe(&mut self) { - let mut receiver = self.lock_receiver(); - receiver - .add_subscription(&self.subscription) - .unwrap_or_else(|e| log::error!("Could not subscribe to the Redis channel: {}", e)) - } - - pub fn disconnect(&self) -> futures::future::FutureResult { - let mut receiver = self.lock_receiver(); - receiver - .remove_subscription(&self.subscription) - .unwrap_or_else(|e| log::error!("Could not unsubscribe from: {}", e)); - futures::future::ok(false) - } - - fn lock_receiver(&self) -> MutexGuard { - match self.receiver.lock() { - Ok(inner) => inner, - Err(e) => { - log::error!( - "Another thread crashed: {}\n - Attempting to continue, possibly with invalid data", - e - ); - e.into_inner() - } - } - } -} - -/// The stream that the `ClientAgent` manages. `Poll` is the only method implemented. -impl futures::stream::Stream for ClientAgent { - type Item = Event; - type Error = ReceiverErr; - - /// Checks for any new messages that should be sent to the client. - /// - /// The `ClientAgent` polls the `Receiver` and replies - /// with `Ok(Ready(Some(Value)))` if there is a new message to send to - /// the client. If there is no new message or if the new message should be - /// filtered out based on one of the user's filters, then the `ClientAgent` - /// replies with `Ok(NotReady)`. The `ClientAgent` bubles up any - /// errors from the underlying data structures. - fn poll(&mut self) -> Poll, Self::Error> { - let result = { - let mut receiver = self.lock_receiver(); - receiver.poll_for(self.subscription.id) - }; - - let timeline = &self.subscription.timeline; - let allowed_langs = &self.subscription.allowed_langs; - let blocks = &self.subscription.blocks; - let (send, block) = (|msg| Ok(Ready(Some(msg))), Ok(NotReady)); - - use crate::messages::{CheckedEvent::Update, Event::*}; - match result { - Ok(NotReady) => Ok(NotReady), - Ok(Ready(None)) => Ok(Ready(None)), - Ok(Async::Ready(Some(event))) => match event { - TypeSafe(Update { payload, queued_at }) => match timeline { - Timeline(Public, _, _) if payload.language_not(allowed_langs) => block, - _ if payload.involves_any(blocks) => block, - _ => send(TypeSafe(Update { payload, queued_at })), - }, - TypeSafe(non_update) => send(Event::TypeSafe(non_update)), - Dynamic(event) if event.event == "update" => match timeline { - Timeline(Public, _, _) if event.language_not(allowed_langs) => block, - _ if event.involves_any(blocks) => block, - _ => send(Dynamic(event)), - }, - Dynamic(non_update) => send(Dynamic(non_update)), - }, - Err(e) => Err(e), - } - } -} diff --git a/src/redis_to_client_stream/event_stream.rs b/src/redis_to_client_stream/event_stream.rs index 8c28ecb..35a86bc 100644 --- a/src/redis_to_client_stream/event_stream.rs +++ b/src/redis_to_client_stream/event_stream.rs @@ -1,32 +1,37 @@ -use super::ClientAgent; +use crate::messages::Event; +use crate::parse_client_request::{Subscription, Timeline}; -use futures::{future::Future, stream::Stream, Async}; +use futures::{future::Future, stream::Stream}; use log; -use std::time::{Duration, Instant}; +use std::time::Duration; +use tokio::sync::{mpsc, watch}; use warp::{ reply::Reply, - sse::Sse, + sse::{ServerSentEvent, Sse}, ws::{Message, WebSocket}, }; -pub struct EventStream; -impl EventStream { - /// Send a stream of replies to a WebSocket client. - pub fn send_to_ws( +pub struct WsStream { + ws_tx: mpsc::UnboundedSender, + unsubscribe_tx: mpsc::UnboundedSender, + subscription: Subscription, +} + +impl WsStream { + pub fn new( ws: WebSocket, - mut client_agent: ClientAgent, - interval: Duration, - ) -> impl Future { + unsubscribe_tx: mpsc::UnboundedSender, + subscription: Subscription, + ) -> Self { let (transmit_to_ws, _receive_from_ws) = ws.split(); - let timeline = client_agent.subscription.timeline; - // Create a pipe - let (tx, rx) = futures::sync::mpsc::unbounded(); + let (ws_tx, ws_rx) = mpsc::unbounded_channel(); - // Send one end of it to a different thread and tell that end to forward whatever it gets - // on to the WebSocket client + // Send one end of it to a different green thread and tell that end to forward + // whatever it gets on to the WebSocket client warp::spawn( - rx.map_err(|()| -> warp::Error { unreachable!() }) + ws_rx + .map_err(|_| -> warp::Error { unreachable!() }) .forward(transmit_to_ws) .map(|_r| ()) .map_err(|e| match e.to_string().as_ref() { @@ -34,70 +39,119 @@ impl EventStream { _ => log::warn!("WebSocket send error: {}", e), }), ); - - let mut last_ping_time = Instant::now(); - tokio::timer::Interval::new(Instant::now(), interval) - .take_while(move |_| { - // Right now, we do not need to see if we have any messages _from_ the - // WebSocket connection because the API doesn't support clients sending - // commands via the WebSocket. However, if the [stream multiplexing API - // change](github.com/tootsuite/flodgatt/issues/121) is implemented, we'll - // need to receive messages from the client. If so, we'll need a - // `receive_from_ws.poll() call here (or later)` - match client_agent.poll() { - Ok(Async::NotReady) => { - if last_ping_time.elapsed() > Duration::from_secs(30) { - last_ping_time = Instant::now(); - match tx.unbounded_send(Message::text("{}")) { - Ok(_) => futures::future::ok(true), - Err(_) => client_agent.disconnect(), - } - } else { - futures::future::ok(true) - } - } - Ok(Async::Ready(Some(msg))) => { - match tx.unbounded_send(Message::text(msg.to_json_string())) { - Ok(_) => futures::future::ok(true), - Err(_) => client_agent.disconnect(), - } - } - Err(e) => { - log::error!("{}\n Dropping WebSocket message and continuing.", e); - futures::future::ok(true) - } - Ok(Async::Ready(None)) => { - log::info!("WebSocket ClientAgent got Ready(None)"); - futures::future::ok(true) - } - } - }) - .for_each(move |_instant| Ok(())) - .then(move |result| { - log::info!("WebSocket connection for {:?} closed.", timeline); - result - }) - .map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e)) + Self { + ws_tx, + unsubscribe_tx, + subscription, + } } - pub fn send_to_sse(mut client_agent: ClientAgent, sse: Sse, interval: Duration) -> impl Reply { - let event_stream = - tokio::timer::Interval::new(Instant::now(), interval).filter_map(move |_| { - match client_agent.poll() { - Ok(Async::Ready(Some(event))) => Some(( - warp::sse::event(event.event_name()), - warp::sse::data(event.payload().unwrap_or_else(String::new)), - )), - Ok(Async::Ready(None)) => { - log::info!("SSE ClientAgent got Ready(None)"); - None - } - Ok(Async::NotReady) => None, - Err(e) => { - log::error!("{}\n Dropping SSE message and continuing.", e); - None - } + pub fn send_events( + mut self, + event_rx: watch::Receiver<(Timeline, Event)>, + ) -> impl Future { + let target_timeline = self.subscription.timeline; + + event_rx.map_err(|_| ()).for_each(move |(tl, event)| { + if matches!(event, Event::Ping) { + self.send_ping() + } else if target_timeline == tl { + use crate::messages::{CheckedEvent::Update, Event::*}; + use crate::parse_client_request::Stream::Public; + let blocks = &self.subscription.blocks; + let allowed_langs = &self.subscription.allowed_langs; + + match event { + TypeSafe(Update { payload, queued_at }) => match tl { + Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()), + _ if payload.involves_any(&blocks) => Ok(()), + _ => self.send_msg(TypeSafe(Update { payload, queued_at })), + }, + TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)), + Dynamic(event) if event.event == "update" => match tl { + Timeline(Public, _, _) if event.language_not(allowed_langs) => Ok(()), + _ if event.involves_any(&blocks) => Ok(()), + _ => self.send_msg(Dynamic(event)), + }, + Dynamic(non_update) => self.send_msg(Dynamic(non_update)), + Ping => unreachable!(), // handled pings above } + } else { + Ok(()) + } + }) + } + + fn send_ping(&mut self) -> Result<(), ()> { + self.send_txt("{}") + } + + fn send_msg(&mut self, event: Event) -> Result<(), ()> { + self.send_txt(&event.to_json_string()) + } + + fn send_txt(&mut self, txt: &str) -> Result<(), ()> { + let tl = self.subscription.timeline; + match self.ws_tx.try_send(Message::text(txt)) { + Ok(_) => Ok(()), + Err(_) => { + self.unsubscribe_tx.try_send(tl).expect("TODO"); + Err(()) + } + } + } +} + +pub struct SseStream {} + +impl SseStream { + fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> { + Some(( + warp::sse::event(event.event_name()), + warp::sse::data(event.payload().unwrap_or_else(String::new)), + )) + } + + pub fn send_events( + sse: Sse, + mut unsubscribe_tx: mpsc::UnboundedSender, + subscription: Subscription, + sse_rx: watch::Receiver<(Timeline, Event)>, + ) -> impl Reply { + let target_timeline = subscription.timeline; + let allowed_langs = subscription.allowed_langs; + let blocks = subscription.blocks; + + let event_stream = sse_rx + .filter_map(move |(timeline, event)| { + if target_timeline == timeline { + use crate::messages::{CheckedEvent, CheckedEvent::Update, Event::*}; + use crate::parse_client_request::Stream::Public; + match event { + TypeSafe(Update { payload, queued_at }) => match timeline { + Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None, + _ if payload.involves_any(&blocks) => None, + _ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update { + payload, + queued_at, + })), + }, + TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)), + Dynamic(event) if event.event == "update" => match timeline { + Timeline(Public, _, _) if event.language_not(&allowed_langs) => None, + _ if event.involves_any(&blocks) => None, + _ => Self::reply_with(Event::Dynamic(event)), + }, + Dynamic(non_update) => Self::reply_with(Event::Dynamic(non_update)), + Ping => None, // pings handled automatically + } + } else { + None + } + }) + .then(move |res| { + unsubscribe_tx.try_send(target_timeline).expect("TODO"); + res }); sse.reply( diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs index 3f419d1..f8b88b9 100644 --- a/src/redis_to_client_stream/mod.rs +++ b/src/redis_to_client_stream/mod.rs @@ -1,10 +1,12 @@ //! Stream the updates appropriate for a given `User`/`timeline` pair from Redis. -mod client_agent; mod event_stream; mod receiver; mod redis; -pub use {client_agent::ClientAgent, event_stream::EventStream, receiver::Receiver}; +pub use { + event_stream::{SseStream, WsStream}, + receiver::Receiver, +}; #[cfg(feature = "bench")] pub use redis::redis_msg::{RedisMsg, RedisParseOutput}; diff --git a/src/redis_to_client_stream/receiver/message_queues.rs b/src/redis_to_client_stream/receiver/message_queues.rs deleted file mode 100644 index ace23d8..0000000 --- a/src/redis_to_client_stream/receiver/message_queues.rs +++ /dev/null @@ -1,53 +0,0 @@ -use crate::messages::Event; -use crate::parse_client_request::Timeline; - -use hashbrown::HashMap; -use std::{collections::VecDeque, fmt}; -use uuid::Uuid; - -#[derive(Clone)] -pub struct MsgQueue { - pub timeline: Timeline, - pub messages: VecDeque, -} - -impl MsgQueue { - pub fn new(timeline: Timeline) -> Self { - MsgQueue { - messages: VecDeque::new(), - - timeline, - } - } -} - -#[derive(Debug)] -pub struct MessageQueues(pub HashMap); - -impl MessageQueues {} - -impl fmt::Debug for MsgQueue { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "\ -MsgQueue {{ - timeline: {:?}, - messages: {:?}, -}}", - self.timeline, self.messages, - ) - } -} - -impl std::ops::Deref for MessageQueues { - type Target = HashMap; - fn deref(&self) -> &Self::Target { - &self.0 - } -} -impl std::ops::DerefMut for MessageQueues { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} diff --git a/src/redis_to_client_stream/receiver/mod.rs b/src/redis_to_client_stream/receiver/mod.rs index 2df1cbc..3d7bed3 100644 --- a/src/redis_to_client_stream/receiver/mod.rs +++ b/src/redis_to_client_stream/receiver/mod.rs @@ -2,10 +2,7 @@ //! polled by the correct `ClientAgent`. Also manages sububscriptions and //! unsubscriptions to/from Redis. mod err; -mod message_queues; - pub use err::ReceiverErr; -pub use message_queues::{MessageQueues, MsgQueue}; use super::redis::{redis_connection::RedisCmd, RedisConn}; @@ -15,11 +12,9 @@ use crate::{ parse_client_request::{Stream, Subscription, Timeline}, }; -use { - futures::{Async, Poll}, - hashbrown::HashMap, - uuid::Uuid, -}; +use futures::{Async, Stream as _Stream}; +use hashbrown::HashMap; +use tokio::sync::{mpsc, watch}; use std::{ result, @@ -33,25 +28,28 @@ type Result = result::Result; #[derive(Debug)] pub struct Receiver { redis_connection: RedisConn, - redis_poll_interval: Duration, - redis_polled_at: Instant, - pub msg_queues: MessageQueues, clients_per_timeline: HashMap, + tx: watch::Sender<(Timeline, Event)>, + rx: mpsc::UnboundedReceiver, + ping_time: Instant, } impl Receiver { /// Create a new `Receiver`, with its own Redis connections (but, as yet, no /// active subscriptions). - pub fn try_from(redis_cfg: config::RedisConfig) -> Result { - let redis_poll_interval = *redis_cfg.polling_interval; - let redis_connection = RedisConn::new(redis_cfg)?; + pub fn try_from( + redis_cfg: config::RedisConfig, + tx: watch::Sender<(Timeline, Event)>, + rx: mpsc::UnboundedReceiver, + ) -> Result { Ok(Self { - redis_polled_at: Instant::now(), - redis_poll_interval, - redis_connection, - msg_queues: MessageQueues(HashMap::new()), + redis_connection: RedisConn::new(redis_cfg)?, + clients_per_timeline: HashMap::new(), + tx, + rx, + ping_time: Instant::now(), }) } @@ -59,15 +57,12 @@ impl Receiver { Arc::new(Mutex::new(self)) } - /// Assigns the `Receiver` a new timeline to monitor and runs other - /// first-time setup. - pub fn add_subscription(&mut self, subscription: &Subscription) -> Result<()> { + pub fn subscribe(&mut self, subscription: &Subscription) -> Result<()> { let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline); if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (tag, tl) { self.redis_connection.update_cache(hashtag, id); }; - self.msg_queues.insert(subscription.id, MsgQueue::new(tl)); let number_of_subscriptions = self .clients_per_timeline @@ -79,13 +74,11 @@ impl Receiver { if *number_of_subscriptions == 1 { self.redis_connection.send_cmd(Subscribe, &tl)? }; - + log::info!("Started stream for {:?}", tl); Ok(()) } - pub fn remove_subscription(&mut self, subscription: &Subscription) -> Result<()> { - let tl = subscription.timeline; - self.msg_queues.remove(&subscription.id); + pub fn unsubscribe(&mut self, tl: Timeline) -> Result<()> { let number_of_subscriptions = self .clients_per_timeline .entry(tl) @@ -102,48 +95,30 @@ impl Receiver { self.redis_connection.send_cmd(Unsubscribe, &tl)?; self.clients_per_timeline.remove_entry(&tl); }; - + log::info!("Ended stream for {:?}", tl); Ok(()) } - /// Returns the oldest message in the `ClientAgent`'s queue (if any). - /// - /// Note: This method does **not** poll Redis every time, because polling - /// Redis is significantly more time consuming that simply returning the - /// message already in a queue. Thus, we only poll Redis if it has not - /// been polled lately. - pub fn poll_for(&mut self, id: Uuid) -> Poll, ReceiverErr> { - // let (t1, mut polled_redis) = (Instant::now(), false); - if self.redis_polled_at.elapsed() > self.redis_poll_interval { - loop { - match self.redis_connection.poll_redis() { - Ok(Async::NotReady) => break, - Ok(Async::Ready(Some((timeline, event)))) => { - self.msg_queues - .values_mut() - .filter(|msg_queue| msg_queue.timeline == timeline) - .for_each(|msg_queue| { - msg_queue.messages.push_back(event.clone()); - }); - } - Ok(Async::Ready(None)) => (), // subscription cmd or msg for other namespace - Err(err) => Err(err)?, - } - } - // polled_redis = true; - self.redis_polled_at = Instant::now(); + pub fn poll_broadcast(&mut self) { + while let Ok(Async::Ready(Some(tl))) = self.rx.poll() { + self.unsubscribe(tl).expect("TODO"); } - // If the `msg_queue` being polled has any new messages, return the first (oldest) one - let msg_q = self.msg_queues.get_mut(&id).ok_or(ReceiverErr::InvalidId)?; - let res = match msg_q.messages.pop_front() { - Some(event) => Ok(Async::Ready(Some(event))), - None => Ok(Async::NotReady), - }; - // if !polled_redis { - // log::info!("poll_for in {:?}", t1.elapsed()); - // } - res + if self.ping_time.elapsed() > Duration::from_secs(30) { + self.ping_time = Instant::now(); + self.tx + .broadcast((Timeline::empty(), Event::Ping)) + .expect("TODO"); + } else { + match self.redis_connection.poll_redis() { + Ok(Async::NotReady) => (), + Ok(Async::Ready(Some((timeline, event)))) => { + self.tx.broadcast((timeline, event)).expect("TODO"); + } + Ok(Async::Ready(None)) => (), // subscription cmd or msg for other namespace + Err(_err) => panic!("TODO"), + } + } } pub fn count_connections(&self) -> String { @@ -166,14 +141,4 @@ impl Receiver { }) .collect() } - - pub fn queue_length(&self) -> String { - format!( - "Longest MessageQueue: {}", - self.msg_queues - .0 - .values() - .fold(0, |acc, el| acc.max(el.messages.len())) - ) - } } diff --git a/src/redis_to_client_stream/redis/redis_connection/mod.rs b/src/redis_to_client_stream/redis/redis_connection/mod.rs index f34bc8b..872fac0 100644 --- a/src/redis_to_client_stream/redis/redis_connection/mod.rs +++ b/src/redis_to_client_stream/redis/redis_connection/mod.rs @@ -80,7 +80,7 @@ impl RedisConn { self.redis_input.clear(); let (input, invalid_bytes) = str::from_utf8(&input) - .map(|input| (input, "".as_bytes())) + .map(|input| (input, &b""[..])) .unwrap_or_else(|e| { let (valid, invalid) = input.split_at(e.valid_up_to()); (str::from_utf8(valid).expect("Guaranteed by ^^^^"), invalid)