diff --git a/Cargo.lock b/Cargo.lock index 0318b3a..e0d0277 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -453,7 +453,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.7.1" +version = "0.8.0" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index bd642ae..35a728e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "flodgatt" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" -version = "0.7.1" +version = "0.8.0" authors = ["Daniel Long Sockwell "] edition = "2018" @@ -43,8 +43,9 @@ stub_status = [] production = [] [profile.release] -lto = "fat" -panic = "abort" -codegen-units = 1 +#lto = "fat" +#panic = "abort" +#codegen-units = 1 + diff --git a/src/lib.rs b/src/lib.rs index ede8963..8af4e25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,6 +34,9 @@ //! most important settings for performance control the frequency with which the `ClientAgent` //! polls the `Receiver` and the frequency with which the `Receiver` polls Redis. //! + +#![allow(clippy::try_err, clippy::match_bool)] + pub mod config; pub mod err; pub mod messages; diff --git a/src/main.rs b/src/main.rs index 74b6702..079d8ba 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,14 @@ use flodgatt::{ config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig}, - parse_client_request::{PgPool, Subscription}, - redis_to_client_stream::{ClientAgent, EventStream, Receiver}, + messages::Event, + parse_client_request::{PgPool, Subscription, Timeline}, + redis_to_client_stream::{Receiver, SseStream, WsStream}, }; use std::{env, fs, net::SocketAddr, os::unix::fs::PermissionsExt}; -use tokio::net::UnixListener; +use tokio::{ + net::UnixListener, + sync::{mpsc, watch}, +}; use warp::{http::StatusCode, path, ws::Ws2, Filter, Rejection}; fn main() { @@ -23,8 +27,10 @@ fn main() { let cfg = DeploymentConfig::from_env(env_vars); let pg_pool = PgPool::new(postgres_cfg); - - let receiver = Receiver::try_from(redis_cfg) + let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); + let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); + let poll_freq = *redis_cfg.polling_interval; + let receiver = Receiver::try_from(redis_cfg, event_tx, cmd_rx) .unwrap_or_else(|e| { log::error!("{}\nFlodgatt shutting down...", e); std::process::exit(1); @@ -34,38 +40,57 @@ fn main() { // Server Sent Events let sse_receiver = receiver.clone(); - let (sse_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode); + let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone()); + let whitelist_mode = *cfg.whitelist_mode; let sse_routes = Subscription::from_sse_query(pg_pool.clone(), whitelist_mode) .and(warp::sse()) .map( move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); - let mut client_agent = ClientAgent::new(sse_receiver.clone(), &subscription); - client_agent.subscribe(); - + { + let mut receiver = sse_receiver.lock().expect("TODO"); + receiver.subscribe(&subscription).unwrap_or_else(|e| { + log::error!("Could not subscribe to the Redis channel: {}", e) + }); + } + let cmd_tx = sse_cmd_tx.clone(); + let sse_rx = sse_rx.clone(); + // self.sse.reply( + // warp::sse::keep_alive() + // .interval(Duration::from_secs(30)) + // .text("thump".to_string()) + // .stream(event_stream), + // ) // send the updates through the SSE connection - EventStream::send_to_sse(client_agent, sse_connection_to_client, sse_interval) + SseStream::send_events(sse_connection_to_client, cmd_tx, subscription, sse_rx) }, ) .with(warp::reply::with::header("Connection", "keep-alive")); // WebSocket let ws_receiver = receiver.clone(); - let (ws_update_interval, whitelist_mode) = (*cfg.ws_interval, *cfg.whitelist_mode); + let whitelist_mode = *cfg.whitelist_mode; let ws_routes = Subscription::from_ws_request(pg_pool, whitelist_mode) .and(warp::ws::ws2()) .map(move |subscription: Subscription, ws: Ws2| { log::info!("Incoming websocket request for {:?}", subscription.timeline); - let mut client_agent = ClientAgent::new(ws_receiver.clone(), &subscription); - client_agent.subscribe(); + { + let mut receiver = ws_receiver.lock().expect("TODO"); + receiver.subscribe(&subscription).unwrap_or_else(|e| { + log::error!("Could not subscribe to the Redis channel: {}", e) + }); + } + let cmd_tx = cmd_tx.clone(); + let ws_rx = event_rx.clone(); + let token = subscription + .clone() + .access_token + .unwrap_or_else(String::new); - // send the updates through the WS connection - // (along with the User's access_token which is sent for security) + // send the updates through the WS connection (along with the access_token, for security) ( - ws.on_upgrade(move |s| { - EventStream::send_to_ws(s, client_agent, ws_update_interval) - }), - subscription.access_token.unwrap_or_else(String::new), + ws.on_upgrade(move |ws| WsStream::new(ws, cmd_tx, subscription).send_events(ws_rx)), + token, ) }) .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); @@ -77,14 +102,12 @@ fn main() { #[cfg(feature = "stub_status")] let status_endpoints = { - let (r1, r2, r3) = (receiver.clone(), receiver.clone(), receiver.clone()); + let (r1, r3) = (receiver.clone(), receiver.clone()); warp::path!("api" / "v1" / "streaming" / "health") .map(|| "OK") .or(warp::path!("api" / "v1" / "streaming" / "status") .and(warp::path::end()) .map(move || r1.lock().expect("TODO").count_connections())) - .or(warp::path!("api" / "v1" / "streaming" / "status" / "queue") - .map(move || r2.lock().expect("TODO").queue_length())) .or( warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline") .map(move || r3.lock().expect("TODO").list_connections()), @@ -119,7 +142,24 @@ fn main() { ) .run_incoming(incoming); } else { + use futures::{future::lazy, stream::Stream as _Stream}; + use std::time::Instant; + let server_addr = SocketAddr::new(*cfg.address, *cfg.port); - warp::serve(ws_routes.or(sse_routes).with(cors).or(status_endpoints)).run(server_addr); + + tokio::run(lazy(move || { + let receiver = receiver.clone(); + warp::spawn(lazy(move || { + tokio::timer::Interval::new(Instant::now(), poll_freq) + .map_err(|e| log::error!("{}", e)) + .for_each(move |_| { + let receiver = receiver.clone(); + receiver.lock().expect("TODO").poll_broadcast(); + Ok(()) + }) + })); + + warp::serve(ws_routes.or(sse_routes).with(cors).or(status_endpoints)).bind(server_addr) + })); }; } diff --git a/src/messages/event/checked_event/status/mod.rs b/src/messages/event/checked_event/status/mod.rs index 8656231..8a60540 100644 --- a/src/messages/event/checked_event/status/mod.rs +++ b/src/messages/event/checked_event/status/mod.rs @@ -40,6 +40,7 @@ pub struct Status { poll: Option, card: Option, language: Option, + text: Option, // ↓↓↓ Only for authorized users favourited: Option, @@ -85,16 +86,14 @@ impl Status { pub fn involves_any(&self, blocks: &Blocks) -> bool { const ALLOW: bool = false; const REJECT: bool = true; - let Blocks { blocked_users, blocking_users, blocked_domains, } = blocks; + let user_id = &self.account.id.parse().expect("TODO"); - if !self.calculate_involved_users().is_disjoint(blocked_users) { - REJECT - } else if blocking_users.contains(&self.account.id.parse().expect("TODO")) { + if blocking_users.contains(user_id) || self.involves(blocked_users) { REJECT } else { let full_username = &self.account.acct; @@ -105,7 +104,7 @@ impl Status { } } - fn calculate_involved_users(&self) -> HashSet { + fn involves(&self, blocked_users: &HashSet) -> bool { // TODO replace vvvv with error handling let err = |_| log_fatal!("Could not process an `id` field in {:?}", &self); @@ -126,6 +125,6 @@ impl Status { if let Some(boosted_status) = self.reblog.clone() { involved_users.insert(boosted_status.account.id.parse().unwrap_or_else(err)); } - involved_users + !involved_users.is_disjoint(blocked_users) } } diff --git a/src/messages/event/dynamic_event.rs b/src/messages/event/dynamic_event.rs index 9bff2fd..6d2ead8 100644 --- a/src/messages/event/dynamic_event.rs +++ b/src/messages/event/dynamic_event.rs @@ -23,7 +23,7 @@ impl DynamicEvent { match self.payload["language"].as_str() { Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW, None => ALLOW, // If toot language is unknown, toot is always allowed - Some(empty) if empty == &String::new() => ALLOW, + Some(empty) if empty == String::new() => ALLOW, Some(_toot_language) => REJECT, } } @@ -45,12 +45,10 @@ impl DynamicEvent { blocked_domains, } = blocks; - let user_id = self.payload["account"]["id"].as_str().expect("TODO"); + let id = self.payload["account"]["id"].as_str().expect("TODO"); let username = self.payload["account"]["acct"].as_str().expect("TODO"); - if !self.calculate_involved_users().is_disjoint(blocked_users) { - REJECT - } else if blocking_users.contains(&user_id.parse().expect("TODO")) { + if self.involves(blocked_users) || blocking_users.contains(&id.parse().expect("TODO")) { REJECT } else { let full_username = &username; @@ -60,9 +58,11 @@ impl DynamicEvent { } } } - fn calculate_involved_users(&self) -> HashSet { + + // involved_users = mentioned_users + author + replied-to user + boosted user + fn involves(&self, blocked_users: &HashSet) -> bool { + // mentions let mentions = self.payload["mentions"].as_array().expect("TODO"); - // involved_users = mentioned_users + author + replied-to user + boosted user let mut involved_users: HashSet = mentions .iter() .map(|mention| mention["id"].as_str().expect("TODO").parse().expect("TODO")) @@ -73,16 +73,15 @@ impl DynamicEvent { involved_users.insert(author_id.parse::().expect("TODO")); // replied-to user let replied_to_user = self.payload["in_reply_to_account_id"].as_str(); - if let Some(user_id) = replied_to_user.clone() { + if let Some(user_id) = replied_to_user { involved_users.insert(user_id.parse().expect("TODO")); } // boosted user - let id_of_boosted_user = self.payload["reblog"]["account"]["id"] .as_str() .expect("TODO"); involved_users.insert(id_of_boosted_user.parse().expect("TODO")); - involved_users + !involved_users.is_disjoint(blocked_users) } } diff --git a/src/messages/event/mod.rs b/src/messages/event/mod.rs index 133ffa2..3c70b7b 100644 --- a/src/messages/event/mod.rs +++ b/src/messages/event/mod.rs @@ -11,6 +11,7 @@ use std::string::String; pub enum Event { TypeSafe(CheckedEvent), Dynamic(DynamicEvent), + Ping, } impl Event { @@ -37,6 +38,7 @@ impl Event { CheckedEvent::FiltersChanged => "filters_changed", }, Self::Dynamic(dyn_event) => &dyn_event.event, + Self::Ping => panic!("event_name() called on EventNotReady"), }) } @@ -54,6 +56,7 @@ impl Event { FiltersChanged => None, }, Self::Dynamic(dyn_event) => Some(dyn_event.payload.to_string()), + Self::Ping => panic!("payload() called on EventNotReady"), } } } diff --git a/src/parse_client_request/subscription.rs b/src/parse_client_request/subscription.rs index 783ee3c..5f796f8 100644 --- a/src/parse_client_request/subscription.rs +++ b/src/parse_client_request/subscription.rs @@ -218,7 +218,7 @@ impl Timeline { }; use {Content::*, Reach::*, Stream::*}; - Ok(match &timeline.split(":").collect::>()[..] { + Ok(match &timeline.split(':').collect::>()[..] { ["public"] => Timeline(Public, Federated, All), ["public", "local"] => Timeline(Public, Local, All), ["public", "media"] => Timeline(Public, Federated, Media), diff --git a/src/redis_to_client_stream/client_agent.rs b/src/redis_to_client_stream/client_agent.rs deleted file mode 100644 index a9681e0..0000000 --- a/src/redis_to_client_stream/client_agent.rs +++ /dev/null @@ -1,127 +0,0 @@ -//! Provides an interface between the `Warp` filters and the underlying -//! mechanics of talking with Redis/managing multiple threads. -//! -//! The `ClientAgent`'s interface is very simple. All you can do with it is: -//! * Create a totally new `ClientAgent` with no shared data; -//! * Clone an existing `ClientAgent`, sharing the `Receiver`; -//! * Manage an new timeline/user pair; or -//! * Poll an existing `ClientAgent` to see if there are any new messages -//! for clients -//! -//! When you poll the `ClientAgent`, it is responsible for polling internal data -//! structures, getting any updates from Redis, and then filtering out any updates -//! that should be excluded by relevant filters. -//! -//! Because `StreamManagers` are lightweight data structures that do not directly -//! communicate with Redis, it we create a new `ClientAgent` for -//! each new client connection (each in its own thread).use super::{message::Message, receiver::Receiver} -use super::receiver::{Receiver, ReceiverErr}; -use crate::{ - messages::Event, - parse_client_request::{Stream::Public, Subscription, Timeline}, -}; -use futures::{ - Async::{self, NotReady, Ready}, - Poll, -}; -use std::sync::{Arc, Mutex, MutexGuard}; - -/// Struct for managing all Redis streams. -#[derive(Clone, Debug)] -pub struct ClientAgent { - receiver: Arc>, - pub subscription: Subscription, -} - -impl ClientAgent { - pub fn new(receiver: Arc>, subscription: &Subscription) -> Self { - ClientAgent { - receiver, - subscription: subscription.clone(), - } - } - - /// Initializes the `ClientAgent` with a unique ID associated with a specific user's - /// subscription. Also passes values to the `Receiver` for it's initialization. - /// - /// Note that this *may or may not* result in a new Redis connection. - /// If the server has already subscribed to the timeline on behalf of - /// a different user, the `Receiver` is responsible for figuring - /// that out and avoiding duplicated connections. Thus, it is safe to - /// use this method for each new client connection. - pub fn subscribe(&mut self) { - let mut receiver = self.lock_receiver(); - receiver - .add_subscription(&self.subscription) - .unwrap_or_else(|e| log::error!("Could not subscribe to the Redis channel: {}", e)) - } - - pub fn disconnect(&self) -> futures::future::FutureResult { - let mut receiver = self.lock_receiver(); - receiver - .remove_subscription(&self.subscription) - .unwrap_or_else(|e| log::error!("Could not unsubscribe from: {}", e)); - futures::future::ok(false) - } - - fn lock_receiver(&self) -> MutexGuard { - match self.receiver.lock() { - Ok(inner) => inner, - Err(e) => { - log::error!( - "Another thread crashed: {}\n - Attempting to continue, possibly with invalid data", - e - ); - e.into_inner() - } - } - } -} - -/// The stream that the `ClientAgent` manages. `Poll` is the only method implemented. -impl futures::stream::Stream for ClientAgent { - type Item = Event; - type Error = ReceiverErr; - - /// Checks for any new messages that should be sent to the client. - /// - /// The `ClientAgent` polls the `Receiver` and replies - /// with `Ok(Ready(Some(Value)))` if there is a new message to send to - /// the client. If there is no new message or if the new message should be - /// filtered out based on one of the user's filters, then the `ClientAgent` - /// replies with `Ok(NotReady)`. The `ClientAgent` bubles up any - /// errors from the underlying data structures. - fn poll(&mut self) -> Poll, Self::Error> { - let result = { - let mut receiver = self.lock_receiver(); - receiver.poll_for(self.subscription.id) - }; - - let timeline = &self.subscription.timeline; - let allowed_langs = &self.subscription.allowed_langs; - let blocks = &self.subscription.blocks; - let (send, block) = (|msg| Ok(Ready(Some(msg))), Ok(NotReady)); - - use crate::messages::{CheckedEvent::Update, Event::*}; - match result { - Ok(NotReady) => Ok(NotReady), - Ok(Ready(None)) => Ok(Ready(None)), - Ok(Async::Ready(Some(event))) => match event { - TypeSafe(Update { payload, queued_at }) => match timeline { - Timeline(Public, _, _) if payload.language_not(allowed_langs) => block, - _ if payload.involves_any(blocks) => block, - _ => send(TypeSafe(Update { payload, queued_at })), - }, - TypeSafe(non_update) => send(Event::TypeSafe(non_update)), - Dynamic(event) if event.event == "update" => match timeline { - Timeline(Public, _, _) if event.language_not(allowed_langs) => block, - _ if event.involves_any(blocks) => block, - _ => send(Dynamic(event)), - }, - Dynamic(non_update) => send(Dynamic(non_update)), - }, - Err(e) => Err(e), - } - } -} diff --git a/src/redis_to_client_stream/event_stream.rs b/src/redis_to_client_stream/event_stream.rs index 8c28ecb..35a86bc 100644 --- a/src/redis_to_client_stream/event_stream.rs +++ b/src/redis_to_client_stream/event_stream.rs @@ -1,32 +1,37 @@ -use super::ClientAgent; +use crate::messages::Event; +use crate::parse_client_request::{Subscription, Timeline}; -use futures::{future::Future, stream::Stream, Async}; +use futures::{future::Future, stream::Stream}; use log; -use std::time::{Duration, Instant}; +use std::time::Duration; +use tokio::sync::{mpsc, watch}; use warp::{ reply::Reply, - sse::Sse, + sse::{ServerSentEvent, Sse}, ws::{Message, WebSocket}, }; -pub struct EventStream; -impl EventStream { - /// Send a stream of replies to a WebSocket client. - pub fn send_to_ws( +pub struct WsStream { + ws_tx: mpsc::UnboundedSender, + unsubscribe_tx: mpsc::UnboundedSender, + subscription: Subscription, +} + +impl WsStream { + pub fn new( ws: WebSocket, - mut client_agent: ClientAgent, - interval: Duration, - ) -> impl Future { + unsubscribe_tx: mpsc::UnboundedSender, + subscription: Subscription, + ) -> Self { let (transmit_to_ws, _receive_from_ws) = ws.split(); - let timeline = client_agent.subscription.timeline; - // Create a pipe - let (tx, rx) = futures::sync::mpsc::unbounded(); + let (ws_tx, ws_rx) = mpsc::unbounded_channel(); - // Send one end of it to a different thread and tell that end to forward whatever it gets - // on to the WebSocket client + // Send one end of it to a different green thread and tell that end to forward + // whatever it gets on to the WebSocket client warp::spawn( - rx.map_err(|()| -> warp::Error { unreachable!() }) + ws_rx + .map_err(|_| -> warp::Error { unreachable!() }) .forward(transmit_to_ws) .map(|_r| ()) .map_err(|e| match e.to_string().as_ref() { @@ -34,70 +39,119 @@ impl EventStream { _ => log::warn!("WebSocket send error: {}", e), }), ); - - let mut last_ping_time = Instant::now(); - tokio::timer::Interval::new(Instant::now(), interval) - .take_while(move |_| { - // Right now, we do not need to see if we have any messages _from_ the - // WebSocket connection because the API doesn't support clients sending - // commands via the WebSocket. However, if the [stream multiplexing API - // change](github.com/tootsuite/flodgatt/issues/121) is implemented, we'll - // need to receive messages from the client. If so, we'll need a - // `receive_from_ws.poll() call here (or later)` - match client_agent.poll() { - Ok(Async::NotReady) => { - if last_ping_time.elapsed() > Duration::from_secs(30) { - last_ping_time = Instant::now(); - match tx.unbounded_send(Message::text("{}")) { - Ok(_) => futures::future::ok(true), - Err(_) => client_agent.disconnect(), - } - } else { - futures::future::ok(true) - } - } - Ok(Async::Ready(Some(msg))) => { - match tx.unbounded_send(Message::text(msg.to_json_string())) { - Ok(_) => futures::future::ok(true), - Err(_) => client_agent.disconnect(), - } - } - Err(e) => { - log::error!("{}\n Dropping WebSocket message and continuing.", e); - futures::future::ok(true) - } - Ok(Async::Ready(None)) => { - log::info!("WebSocket ClientAgent got Ready(None)"); - futures::future::ok(true) - } - } - }) - .for_each(move |_instant| Ok(())) - .then(move |result| { - log::info!("WebSocket connection for {:?} closed.", timeline); - result - }) - .map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e)) + Self { + ws_tx, + unsubscribe_tx, + subscription, + } } - pub fn send_to_sse(mut client_agent: ClientAgent, sse: Sse, interval: Duration) -> impl Reply { - let event_stream = - tokio::timer::Interval::new(Instant::now(), interval).filter_map(move |_| { - match client_agent.poll() { - Ok(Async::Ready(Some(event))) => Some(( - warp::sse::event(event.event_name()), - warp::sse::data(event.payload().unwrap_or_else(String::new)), - )), - Ok(Async::Ready(None)) => { - log::info!("SSE ClientAgent got Ready(None)"); - None - } - Ok(Async::NotReady) => None, - Err(e) => { - log::error!("{}\n Dropping SSE message and continuing.", e); - None - } + pub fn send_events( + mut self, + event_rx: watch::Receiver<(Timeline, Event)>, + ) -> impl Future { + let target_timeline = self.subscription.timeline; + + event_rx.map_err(|_| ()).for_each(move |(tl, event)| { + if matches!(event, Event::Ping) { + self.send_ping() + } else if target_timeline == tl { + use crate::messages::{CheckedEvent::Update, Event::*}; + use crate::parse_client_request::Stream::Public; + let blocks = &self.subscription.blocks; + let allowed_langs = &self.subscription.allowed_langs; + + match event { + TypeSafe(Update { payload, queued_at }) => match tl { + Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()), + _ if payload.involves_any(&blocks) => Ok(()), + _ => self.send_msg(TypeSafe(Update { payload, queued_at })), + }, + TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)), + Dynamic(event) if event.event == "update" => match tl { + Timeline(Public, _, _) if event.language_not(allowed_langs) => Ok(()), + _ if event.involves_any(&blocks) => Ok(()), + _ => self.send_msg(Dynamic(event)), + }, + Dynamic(non_update) => self.send_msg(Dynamic(non_update)), + Ping => unreachable!(), // handled pings above } + } else { + Ok(()) + } + }) + } + + fn send_ping(&mut self) -> Result<(), ()> { + self.send_txt("{}") + } + + fn send_msg(&mut self, event: Event) -> Result<(), ()> { + self.send_txt(&event.to_json_string()) + } + + fn send_txt(&mut self, txt: &str) -> Result<(), ()> { + let tl = self.subscription.timeline; + match self.ws_tx.try_send(Message::text(txt)) { + Ok(_) => Ok(()), + Err(_) => { + self.unsubscribe_tx.try_send(tl).expect("TODO"); + Err(()) + } + } + } +} + +pub struct SseStream {} + +impl SseStream { + fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> { + Some(( + warp::sse::event(event.event_name()), + warp::sse::data(event.payload().unwrap_or_else(String::new)), + )) + } + + pub fn send_events( + sse: Sse, + mut unsubscribe_tx: mpsc::UnboundedSender, + subscription: Subscription, + sse_rx: watch::Receiver<(Timeline, Event)>, + ) -> impl Reply { + let target_timeline = subscription.timeline; + let allowed_langs = subscription.allowed_langs; + let blocks = subscription.blocks; + + let event_stream = sse_rx + .filter_map(move |(timeline, event)| { + if target_timeline == timeline { + use crate::messages::{CheckedEvent, CheckedEvent::Update, Event::*}; + use crate::parse_client_request::Stream::Public; + match event { + TypeSafe(Update { payload, queued_at }) => match timeline { + Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None, + _ if payload.involves_any(&blocks) => None, + _ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update { + payload, + queued_at, + })), + }, + TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)), + Dynamic(event) if event.event == "update" => match timeline { + Timeline(Public, _, _) if event.language_not(&allowed_langs) => None, + _ if event.involves_any(&blocks) => None, + _ => Self::reply_with(Event::Dynamic(event)), + }, + Dynamic(non_update) => Self::reply_with(Event::Dynamic(non_update)), + Ping => None, // pings handled automatically + } + } else { + None + } + }) + .then(move |res| { + unsubscribe_tx.try_send(target_timeline).expect("TODO"); + res }); sse.reply( diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs index 3f419d1..f8b88b9 100644 --- a/src/redis_to_client_stream/mod.rs +++ b/src/redis_to_client_stream/mod.rs @@ -1,10 +1,12 @@ //! Stream the updates appropriate for a given `User`/`timeline` pair from Redis. -mod client_agent; mod event_stream; mod receiver; mod redis; -pub use {client_agent::ClientAgent, event_stream::EventStream, receiver::Receiver}; +pub use { + event_stream::{SseStream, WsStream}, + receiver::Receiver, +}; #[cfg(feature = "bench")] pub use redis::redis_msg::{RedisMsg, RedisParseOutput}; diff --git a/src/redis_to_client_stream/receiver/message_queues.rs b/src/redis_to_client_stream/receiver/message_queues.rs deleted file mode 100644 index ace23d8..0000000 --- a/src/redis_to_client_stream/receiver/message_queues.rs +++ /dev/null @@ -1,53 +0,0 @@ -use crate::messages::Event; -use crate::parse_client_request::Timeline; - -use hashbrown::HashMap; -use std::{collections::VecDeque, fmt}; -use uuid::Uuid; - -#[derive(Clone)] -pub struct MsgQueue { - pub timeline: Timeline, - pub messages: VecDeque, -} - -impl MsgQueue { - pub fn new(timeline: Timeline) -> Self { - MsgQueue { - messages: VecDeque::new(), - - timeline, - } - } -} - -#[derive(Debug)] -pub struct MessageQueues(pub HashMap); - -impl MessageQueues {} - -impl fmt::Debug for MsgQueue { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "\ -MsgQueue {{ - timeline: {:?}, - messages: {:?}, -}}", - self.timeline, self.messages, - ) - } -} - -impl std::ops::Deref for MessageQueues { - type Target = HashMap; - fn deref(&self) -> &Self::Target { - &self.0 - } -} -impl std::ops::DerefMut for MessageQueues { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} diff --git a/src/redis_to_client_stream/receiver/mod.rs b/src/redis_to_client_stream/receiver/mod.rs index 2df1cbc..3d7bed3 100644 --- a/src/redis_to_client_stream/receiver/mod.rs +++ b/src/redis_to_client_stream/receiver/mod.rs @@ -2,10 +2,7 @@ //! polled by the correct `ClientAgent`. Also manages sububscriptions and //! unsubscriptions to/from Redis. mod err; -mod message_queues; - pub use err::ReceiverErr; -pub use message_queues::{MessageQueues, MsgQueue}; use super::redis::{redis_connection::RedisCmd, RedisConn}; @@ -15,11 +12,9 @@ use crate::{ parse_client_request::{Stream, Subscription, Timeline}, }; -use { - futures::{Async, Poll}, - hashbrown::HashMap, - uuid::Uuid, -}; +use futures::{Async, Stream as _Stream}; +use hashbrown::HashMap; +use tokio::sync::{mpsc, watch}; use std::{ result, @@ -33,25 +28,28 @@ type Result = result::Result; #[derive(Debug)] pub struct Receiver { redis_connection: RedisConn, - redis_poll_interval: Duration, - redis_polled_at: Instant, - pub msg_queues: MessageQueues, clients_per_timeline: HashMap, + tx: watch::Sender<(Timeline, Event)>, + rx: mpsc::UnboundedReceiver, + ping_time: Instant, } impl Receiver { /// Create a new `Receiver`, with its own Redis connections (but, as yet, no /// active subscriptions). - pub fn try_from(redis_cfg: config::RedisConfig) -> Result { - let redis_poll_interval = *redis_cfg.polling_interval; - let redis_connection = RedisConn::new(redis_cfg)?; + pub fn try_from( + redis_cfg: config::RedisConfig, + tx: watch::Sender<(Timeline, Event)>, + rx: mpsc::UnboundedReceiver, + ) -> Result { Ok(Self { - redis_polled_at: Instant::now(), - redis_poll_interval, - redis_connection, - msg_queues: MessageQueues(HashMap::new()), + redis_connection: RedisConn::new(redis_cfg)?, + clients_per_timeline: HashMap::new(), + tx, + rx, + ping_time: Instant::now(), }) } @@ -59,15 +57,12 @@ impl Receiver { Arc::new(Mutex::new(self)) } - /// Assigns the `Receiver` a new timeline to monitor and runs other - /// first-time setup. - pub fn add_subscription(&mut self, subscription: &Subscription) -> Result<()> { + pub fn subscribe(&mut self, subscription: &Subscription) -> Result<()> { let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline); if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (tag, tl) { self.redis_connection.update_cache(hashtag, id); }; - self.msg_queues.insert(subscription.id, MsgQueue::new(tl)); let number_of_subscriptions = self .clients_per_timeline @@ -79,13 +74,11 @@ impl Receiver { if *number_of_subscriptions == 1 { self.redis_connection.send_cmd(Subscribe, &tl)? }; - + log::info!("Started stream for {:?}", tl); Ok(()) } - pub fn remove_subscription(&mut self, subscription: &Subscription) -> Result<()> { - let tl = subscription.timeline; - self.msg_queues.remove(&subscription.id); + pub fn unsubscribe(&mut self, tl: Timeline) -> Result<()> { let number_of_subscriptions = self .clients_per_timeline .entry(tl) @@ -102,48 +95,30 @@ impl Receiver { self.redis_connection.send_cmd(Unsubscribe, &tl)?; self.clients_per_timeline.remove_entry(&tl); }; - + log::info!("Ended stream for {:?}", tl); Ok(()) } - /// Returns the oldest message in the `ClientAgent`'s queue (if any). - /// - /// Note: This method does **not** poll Redis every time, because polling - /// Redis is significantly more time consuming that simply returning the - /// message already in a queue. Thus, we only poll Redis if it has not - /// been polled lately. - pub fn poll_for(&mut self, id: Uuid) -> Poll, ReceiverErr> { - // let (t1, mut polled_redis) = (Instant::now(), false); - if self.redis_polled_at.elapsed() > self.redis_poll_interval { - loop { - match self.redis_connection.poll_redis() { - Ok(Async::NotReady) => break, - Ok(Async::Ready(Some((timeline, event)))) => { - self.msg_queues - .values_mut() - .filter(|msg_queue| msg_queue.timeline == timeline) - .for_each(|msg_queue| { - msg_queue.messages.push_back(event.clone()); - }); - } - Ok(Async::Ready(None)) => (), // subscription cmd or msg for other namespace - Err(err) => Err(err)?, - } - } - // polled_redis = true; - self.redis_polled_at = Instant::now(); + pub fn poll_broadcast(&mut self) { + while let Ok(Async::Ready(Some(tl))) = self.rx.poll() { + self.unsubscribe(tl).expect("TODO"); } - // If the `msg_queue` being polled has any new messages, return the first (oldest) one - let msg_q = self.msg_queues.get_mut(&id).ok_or(ReceiverErr::InvalidId)?; - let res = match msg_q.messages.pop_front() { - Some(event) => Ok(Async::Ready(Some(event))), - None => Ok(Async::NotReady), - }; - // if !polled_redis { - // log::info!("poll_for in {:?}", t1.elapsed()); - // } - res + if self.ping_time.elapsed() > Duration::from_secs(30) { + self.ping_time = Instant::now(); + self.tx + .broadcast((Timeline::empty(), Event::Ping)) + .expect("TODO"); + } else { + match self.redis_connection.poll_redis() { + Ok(Async::NotReady) => (), + Ok(Async::Ready(Some((timeline, event)))) => { + self.tx.broadcast((timeline, event)).expect("TODO"); + } + Ok(Async::Ready(None)) => (), // subscription cmd or msg for other namespace + Err(_err) => panic!("TODO"), + } + } } pub fn count_connections(&self) -> String { @@ -166,14 +141,4 @@ impl Receiver { }) .collect() } - - pub fn queue_length(&self) -> String { - format!( - "Longest MessageQueue: {}", - self.msg_queues - .0 - .values() - .fold(0, |acc, el| acc.max(el.messages.len())) - ) - } } diff --git a/src/redis_to_client_stream/redis/redis_connection/mod.rs b/src/redis_to_client_stream/redis/redis_connection/mod.rs index f34bc8b..872fac0 100644 --- a/src/redis_to_client_stream/redis/redis_connection/mod.rs +++ b/src/redis_to_client_stream/redis/redis_connection/mod.rs @@ -80,7 +80,7 @@ impl RedisConn { self.redis_input.clear(); let (input, invalid_bytes) = str::from_utf8(&input) - .map(|input| (input, "".as_bytes())) + .map(|input| (input, &b""[..])) .unwrap_or_else(|e| { let (valid, invalid) = input.split_at(e.valid_up_to()); (str::from_utf8(valid).expect("Guaranteed by ^^^^"), invalid)