From 2725439110511e63514e4816a3b2c53619d38945 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Thu, 23 Apr 2020 19:28:26 -0400 Subject: [PATCH] Update concurrency primitive. (#139) * Initial [WIP] implementation This initial implementation works to send messages but does not yet handle unsubscribing properly. * Implement UnboundedSender * Implement UnboundedChannels for concurrency --- Cargo.lock | 12 +- Cargo.toml | 3 +- src/main.rs | 30 ++--- src/response/event.rs | 2 +- src/response/event/checked_event.rs | 2 +- src/response/event/checked_event/account.rs | 6 +- .../event/checked_event/announcement.rs | 2 +- .../checked_event/announcement_reaction.rs | 2 +- .../event/checked_event/conversation.rs | 2 +- src/response/event/checked_event/emoji.rs | 2 +- src/response/event/checked_event/mention.rs | 2 +- .../event/checked_event/notification.rs | 4 +- src/response/event/checked_event/status.rs | 2 +- .../event/checked_event/status/application.rs | 2 +- .../event/checked_event/status/attachment.rs | 4 +- .../event/checked_event/status/card.rs | 4 +- .../event/checked_event/status/poll.rs | 4 +- src/response/event/checked_event/tag.rs | 4 +- .../event/checked_event/visibility.rs | 2 +- src/response/event/dynamic_event.rs | 6 +- src/response/redis/manager.rs | 100 ++++++++-------- src/response/redis/manager/err.rs | 10 +- src/response/stream/sse.rs | 55 ++++----- src/response/stream/ws.rs | 111 ++++++++---------- 24 files changed, 175 insertions(+), 198 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7008e9f..6356baf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -416,7 +416,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.9.0" +version = "0.9.1" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -437,6 +437,7 @@ dependencies = [ "tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)", "url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "urlencoding 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)", "warp 0.1.20 (git+https://github.com/seanmonstar/warp.git)", ] @@ -2223,6 +2224,14 @@ name = "utf-8" version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "uuid" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "vcpkg" version = "0.2.7" @@ -2589,6 +2598,7 @@ dependencies = [ "checksum url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "75b414f6c464c879d7f9babf951f23bc3743fb7313c081b2e6ca719067ea9d61" "checksum urlencoding 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3df3561629a8bb4c57e5a2e4c43348d9e29c7c29d9b1c4c1f47166deca8f37ed" "checksum utf-8 0.7.5 (registry+https://github.com/rust-lang/crates.io-index)" = "05e42f7c18b8f902290b009cde6d651262f956c98bc51bca4cd1d511c9cd85c7" +"checksum uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9fde2f6a4bea1d6e007c4ad38c6839fa71cbb63b6dbf5b595aa38dc9b1093c11" "checksum vcpkg 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "33dd455d0f96e90a75803cfeb7f948768c08d70a6de9a8d2362461935698bf95" "checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd" "checksum walkdir 2.2.9 (registry+https://github.com/rust-lang/crates.io-index)" = "9658c94fa8b940eab2250bd5a457f9c48b748420d71293b165c8cdbe2f55f71e" diff --git a/Cargo.toml b/Cargo.toml index 495047f..7df2294 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "flodgatt" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" -version = "0.9.0" +version = "0.9.1" authors = ["Daniel Long Sockwell "] edition = "2018" @@ -25,6 +25,7 @@ r2d2 = "0.8.8" lru = "0.4.3" urlencoding = "1.0.0" hashbrown = "0.7.1" +uuid = { version = "0.8.1", features = ["v4"] } [dev-dependencies] criterion = "0.3" diff --git a/src/main.rs b/src/main.rs index 19d47dc..3d9f191 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use flodgatt::config; -use flodgatt::request::{Handler, Subscription, Timeline}; -use flodgatt::response::{Event, RedisManager, SseStream, WsStream}; +use flodgatt::request::{Handler, Subscription}; +use flodgatt::response::{RedisManager, SseStream, WsStream}; use flodgatt::Error; use futures::{future::lazy, stream::Stream as _}; @@ -9,7 +9,7 @@ use std::net::SocketAddr; use std::os::unix::fs::PermissionsExt; use std::time::Instant; use tokio::net::UnixListener; -use tokio::sync::{mpsc, watch}; +use tokio::sync::mpsc; use tokio::timer::Interval; use warp::ws::Ws2; use warp::Filter; @@ -20,25 +20,21 @@ fn main() -> Result<(), Error> { let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect())?; let poll_freq = *redis_cfg.polling_interval; - // Create channels to communicate between threads - let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); - let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); - let request = Handler::new(&postgres_cfg, *cfg.whitelist_mode)?; - let shared_manager = RedisManager::try_from(&redis_cfg, event_tx, cmd_rx)?.into_arc(); + let shared_manager = RedisManager::try_from(&redis_cfg)?.into_arc(); // Server Sent Events let sse_manager = shared_manager.clone(); - let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone()); let sse = request .sse_subscription() .and(warp::sse()) .map(move |subscription: Subscription, sse: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); let mut manager = sse_manager.lock().unwrap_or_else(RedisManager::recover); - manager.subscribe(&subscription); - - SseStream::send_events(sse, sse_cmd_tx.clone(), subscription, sse_rx.clone()) + let (event_tx, event_rx) = mpsc::unbounded_channel(); + manager.subscribe(&subscription, event_tx); + let sse_stream = SseStream::new(subscription); + sse_stream.send_events(sse, event_rx) }) .with(warp::reply::with::header("Connection", "keep-alive")); @@ -50,11 +46,15 @@ fn main() -> Result<(), Error> { .map(move |subscription: Subscription, ws: Ws2| { log::info!("Incoming websocket request for {:?}", subscription.timeline); let mut manager = ws_manager.lock().unwrap_or_else(RedisManager::recover); - manager.subscribe(&subscription); + let (event_tx, event_rx) = mpsc::unbounded_channel(); + manager.subscribe(&subscription, event_tx); let token = subscription.access_token.clone().unwrap_or_default(); // token sent for security - let ws_stream = WsStream::new(cmd_tx.clone(), event_rx.clone(), subscription); + let ws_stream = WsStream::new(subscription); - (ws.on_upgrade(move |ws| ws_stream.send_to(ws)), token) + ( + ws.on_upgrade(move |ws| ws_stream.send_to(ws, event_rx)), + token, + ) }) .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); diff --git a/src/response/event.rs b/src/response/event.rs index 4227124..091a250 100644 --- a/src/response/event.rs +++ b/src/response/event.rs @@ -12,7 +12,7 @@ use std::convert::TryFrom; use std::string::String; use warp::sse::ServerSentEvent; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum Event { TypeSafe(CheckedEvent), Dynamic(DynEvent), diff --git a/src/response/event/checked_event.rs b/src/response/event/checked_event.rs index b763c8b..5176534 100644 --- a/src/response/event/checked_event.rs +++ b/src/response/event/checked_event.rs @@ -22,7 +22,7 @@ use serde::Deserialize; #[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)] #[rustfmt::skip] -#[derive(Deserialize, Debug, Clone, PartialEq)] +#[derive(Deserialize, Debug, Clone, PartialEq, Eq)] pub enum CheckedEvent { Update { payload: Status, queued_at: Option }, Notification { payload: Notification }, diff --git a/src/response/event/checked_event/account.rs b/src/response/event/checked_event/account.rs index 0c00897..53b670e 100644 --- a/src/response/event/checked_event/account.rs +++ b/src/response/event/checked_event/account.rs @@ -3,7 +3,7 @@ use crate::Id; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub(super) struct Account { pub id: Id, username: String, @@ -31,7 +31,7 @@ pub(super) struct Account { } #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] struct Field { name: String, value: String, @@ -39,7 +39,7 @@ struct Field { } #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] struct Source { note: String, fields: Vec, diff --git a/src/response/event/checked_event/announcement.rs b/src/response/event/checked_event/announcement.rs index 4ac88b9..883a02a 100644 --- a/src/response/event/checked_event/announcement.rs +++ b/src/response/event/checked_event/announcement.rs @@ -2,7 +2,7 @@ use super::{emoji::Emoji, mention::Mention, tag::Tag, AnnouncementReaction}; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct Announcement { // Fully undocumented id: String, diff --git a/src/response/event/checked_event/announcement_reaction.rs b/src/response/event/checked_event/announcement_reaction.rs index b17b0be..02d6989 100644 --- a/src/response/event/checked_event/announcement_reaction.rs +++ b/src/response/event/checked_event/announcement_reaction.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct AnnouncementReaction { #[serde(skip_serializing_if = "Option::is_none")] announcement_id: Option, diff --git a/src/response/event/checked_event/conversation.rs b/src/response/event/checked_event/conversation.rs index 5b9fde9..46aa380 100644 --- a/src/response/event/checked_event/conversation.rs +++ b/src/response/event/checked_event/conversation.rs @@ -2,7 +2,7 @@ use super::{account::Account, status::Status}; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct Conversation { id: String, accounts: Vec, diff --git a/src/response/event/checked_event/emoji.rs b/src/response/event/checked_event/emoji.rs index 836f341..f274113 100644 --- a/src/response/event/checked_event/emoji.rs +++ b/src/response/event/checked_event/emoji.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub(super) struct Emoji { shortcode: String, url: String, diff --git a/src/response/event/checked_event/mention.rs b/src/response/event/checked_event/mention.rs index 14c47ad..c26ee2c 100644 --- a/src/response/event/checked_event/mention.rs +++ b/src/response/event/checked_event/mention.rs @@ -2,7 +2,7 @@ use crate::Id; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub(super) struct Mention { pub id: Id, username: String, diff --git a/src/response/event/checked_event/notification.rs b/src/response/event/checked_event/notification.rs index 9f70f90..8958bfd 100644 --- a/src/response/event/checked_event/notification.rs +++ b/src/response/event/checked_event/notification.rs @@ -2,7 +2,7 @@ use super::{account::Account, status::Status}; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct Notification { id: String, r#type: NotificationType, @@ -12,7 +12,7 @@ pub struct Notification { } #[serde(rename_all = "snake_case", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] enum NotificationType { Follow, FollowRequest, // Undocumented diff --git a/src/response/event/checked_event/status.rs b/src/response/event/checked_event/status.rs index 684cf67..ea1655c 100644 --- a/src/response/event/checked_event/status.rs +++ b/src/response/event/checked_event/status.rs @@ -20,7 +20,7 @@ use std::boxed::Box; use std::string::String; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct Status { id: Id, uri: String, diff --git a/src/response/event/checked_event/status/application.rs b/src/response/event/checked_event/status/application.rs index 1fd2f88..0688056 100644 --- a/src/response/event/checked_event/status/application.rs +++ b/src/response/event/checked_event/status/application.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub(super) struct Application { name: String, website: Option, diff --git a/src/response/event/checked_event/status/attachment.rs b/src/response/event/checked_event/status/attachment.rs index bd76c14..8e2eed0 100644 --- a/src/response/event/checked_event/status/attachment.rs +++ b/src/response/event/checked_event/status/attachment.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub(super) struct Attachment { id: String, r#type: AttachmentType, @@ -15,7 +15,7 @@ pub(super) struct Attachment { } #[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] enum AttachmentType { Unknown, Image, diff --git a/src/response/event/checked_event/status/card.rs b/src/response/event/checked_event/status/card.rs index 2a5667f..5cec657 100644 --- a/src/response/event/checked_event/status/card.rs +++ b/src/response/event/checked_event/status/card.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub(super) struct Card { url: String, title: String, @@ -19,7 +19,7 @@ pub(super) struct Card { } #[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] enum CardType { Link, Photo, diff --git a/src/response/event/checked_event/status/poll.rs b/src/response/event/checked_event/status/poll.rs index 908358e..3c99915 100644 --- a/src/response/event/checked_event/status/poll.rs +++ b/src/response/event/checked_event/status/poll.rs @@ -2,7 +2,7 @@ use super::super::emoji::Emoji; use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub(super) struct Poll { id: String, expires_at: String, @@ -17,7 +17,7 @@ pub(super) struct Poll { } #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] struct PollOptions { title: String, votes_count: Option, diff --git a/src/response/event/checked_event/tag.rs b/src/response/event/checked_event/tag.rs index 99fe927..bae5a65 100644 --- a/src/response/event/checked_event/tag.rs +++ b/src/response/event/checked_event/tag.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub(super) struct Tag { name: String, url: String, @@ -9,7 +9,7 @@ pub(super) struct Tag { } #[serde(deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] struct History { day: String, uses: String, diff --git a/src/response/event/checked_event/visibility.rs b/src/response/event/checked_event/visibility.rs index 2c3efba..1334b7a 100644 --- a/src/response/event/checked_event/visibility.rs +++ b/src/response/event/checked_event/visibility.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; #[serde(rename_all = "lowercase", deny_unknown_fields)] -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub(super) enum Visibility { Public, Unlisted, diff --git a/src/response/event/dynamic_event.rs b/src/response/event/dynamic_event.rs index 9ab9db4..f65b5fa 100644 --- a/src/response/event/dynamic_event.rs +++ b/src/response/event/dynamic_event.rs @@ -8,7 +8,7 @@ use hashbrown::HashSet; use serde::{Deserialize, Serialize}; use serde_json::Value; -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct DynEvent { #[serde(skip)] pub(crate) kind: EventKind, @@ -17,7 +17,7 @@ pub struct DynEvent { pub(crate) queued_at: Option, } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum EventKind { Update(DynStatus), NonUpdate, @@ -29,7 +29,7 @@ impl Default for EventKind { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct DynStatus { pub(crate) id: Id, pub(crate) username: String, diff --git a/src/response/redis/manager.rs b/src/response/redis/manager.rs index b32eef2..3c7f5cc 100644 --- a/src/response/redis/manager.rs +++ b/src/response/redis/manager.rs @@ -11,37 +11,29 @@ use crate::request::{Subscription, Timeline}; pub(self) use super::EventErr; -use futures::{Async, Stream as _Stream}; +use futures::Async; use hashbrown::HashMap; use std::sync::{Arc, Mutex, MutexGuard, PoisonError}; use std::time::{Duration, Instant}; -use tokio::sync::{mpsc, watch}; +use tokio::sync::mpsc::UnboundedSender; +use uuid::Uuid; type Result = std::result::Result; /// The item that streams from Redis and is polled by the `ClientAgent` -#[derive(Debug)] pub struct Manager { redis_connection: RedisConn, - clients_per_timeline: HashMap, - tx: watch::Sender<(Timeline, Event)>, - rx: mpsc::UnboundedReceiver, + timelines: HashMap>>, ping_time: Instant, } impl Manager { /// Create a new `Manager`, with its own Redis connections (but, as yet, no /// active subscriptions). - pub fn try_from( - redis_cfg: &config::Redis, - tx: watch::Sender<(Timeline, Event)>, - rx: mpsc::UnboundedReceiver, - ) -> Result { + pub fn try_from(redis_cfg: &config::Redis) -> Result { Ok(Self { redis_connection: RedisConn::new(redis_cfg)?, - clients_per_timeline: HashMap::new(), - tx, - rx, + timelines: HashMap::new(), ping_time: Instant::now(), }) } @@ -50,64 +42,64 @@ impl Manager { Arc::new(Mutex::new(self)) } - pub fn subscribe(&mut self, subscription: &Subscription) { + pub fn subscribe(&mut self, subscription: &Subscription, channel: UnboundedSender) { let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline); if let (Some(hashtag), Some(id)) = (tag, tl.tag()) { self.redis_connection.update_cache(hashtag, id); }; - let number_of_subscriptions = self - .clients_per_timeline - .entry(tl) - .and_modify(|n| *n += 1) - .or_insert(1); + let channels = self.timelines.entry(tl).or_default(); + channels.insert(Uuid::new_v4(), channel); - use RedisCmd::*; - if *number_of_subscriptions == 1 { + if channels.len() == 1 { self.redis_connection - .send_cmd(Subscribe, &tl) + .send_cmd(RedisCmd::Subscribe, &tl) .unwrap_or_else(|e| log::error!("Could not subscribe to the Redis channel: {}", e)); }; } - pub(crate) fn unsubscribe(&mut self, tl: Timeline) -> Result<()> { - let number_of_subscriptions = self - .clients_per_timeline - .entry(tl) - .and_modify(|n| *n -= 1) - .or_insert_with(|| { - log::error!( - "Attempted to unsubscribe from a timeline to which you were not subscribed: {:?}", - tl - ); - 0 - }); - use RedisCmd::*; - if *number_of_subscriptions == 0 { - self.redis_connection.send_cmd(Unsubscribe, &tl)?; - self.clients_per_timeline.remove_entry(&tl); + pub(crate) fn unsubscribe(&mut self, tl: &mut Timeline, id: &Uuid) -> Result<()> { + let channels = self.timelines.get_mut(tl).ok_or(Error::InvalidId)?; + channels.remove(id); + + if channels.len() == 0 { + self.redis_connection.send_cmd(RedisCmd::Unsubscribe, &tl)?; + self.timelines.remove(&tl); }; log::info!("Ended stream for {:?}", tl); Ok(()) } pub fn poll_broadcast(&mut self) -> Result<()> { - while let Ok(Async::Ready(Some(tl))) = self.rx.poll() { - self.unsubscribe(tl)? - } - + let mut completed_timelines = Vec::new(); if self.ping_time.elapsed() > Duration::from_secs(30) { self.ping_time = Instant::now(); - self.tx.broadcast((Timeline::empty(), Event::Ping))? - } else { - match self.redis_connection.poll_redis() { - Ok(Async::NotReady) | Ok(Async::Ready(None)) => (), // None = cmd or msg for other namespace - Ok(Async::Ready(Some((timeline, event)))) => { - self.tx.broadcast((timeline, event))? + for (timeline, channels) in self.timelines.iter_mut() { + for (uuid, channel) in channels.iter_mut() { + match channel.try_send(Event::Ping) { + Ok(_) => (), + Err(_) => completed_timelines.push((*timeline, *uuid)), + } } + } + }; + loop { + match self.redis_connection.poll_redis() { + Ok(Async::NotReady) => break, + Ok(Async::Ready(Some((tl, event)))) => { + for (uuid, tx) in self.timelines.get_mut(&tl).ok_or(Error::InvalidId)? { + tx.try_send(event.clone()) + .unwrap_or_else(|_| completed_timelines.push((tl, *uuid))) + } + } + Ok(Async::Ready(None)) => (), // cmd or msg for other namespace Err(err) => log::error!("{}", err), // drop msg, log err, and proceed } } + + for (tl, channel) in completed_timelines.iter_mut() { + self.unsubscribe(tl, &channel)?; + } Ok(()) } @@ -119,20 +111,20 @@ impl Manager { pub fn count(&self) -> String { format!( "Current connections: {}", - self.clients_per_timeline.values().sum::() + self.timelines.values().map(|el| el.len()).sum::() ) } pub fn list(&self) -> String { let max_len = self - .clients_per_timeline + .timelines .keys() .fold(0, |acc, el| acc.max(format!("{:?}:", el).len())); - self.clients_per_timeline + self.timelines .iter() - .map(|(tl, n)| { + .map(|(tl, channel_map)| { let tl_txt = format!("{:?}:", tl); - format!("{:>1$} {2}\n", tl_txt, max_len, n) + format!("{:>1$} {2}\n", tl_txt, max_len, channel_map.len()) }) .collect() } diff --git a/src/response/redis/manager/err.rs b/src/response/redis/manager/err.rs index 8d6d4f2..9f7dcad 100644 --- a/src/response/redis/manager/err.rs +++ b/src/response/redis/manager/err.rs @@ -6,11 +6,13 @@ use std::fmt; #[derive(Debug)] pub enum Error { InvalidId, + TimelineErr(TimelineErr), EventErr(EventErr), RedisParseErr(RedisParseErr), RedisConnErr(RedisConnErr), ChannelSendErr(tokio::sync::watch::error::SendError<(Timeline, Event)>), + ChannelSendErr2(tokio::sync::mpsc::error::UnboundedTrySendError), } impl std::error::Error for Error {} @@ -21,13 +23,14 @@ impl fmt::Display for Error { match self { InvalidId => write!( f, - "Attempted to get messages for a subscription that had not been set up." + "tried to access a timeline/channel subscription that does not exist" ), EventErr(inner) => write!(f, "{}", inner), RedisParseErr(inner) => write!(f, "{}", inner), RedisConnErr(inner) => write!(f, "{}", inner), TimelineErr(inner) => write!(f, "{}", inner), ChannelSendErr(inner) => write!(f, "{}", inner), + ChannelSendErr2(inner) => write!(f, "{}", inner), }?; Ok(()) } @@ -38,6 +41,11 @@ impl From> for Error { Self::ChannelSendErr(error) } } +impl From> for Error { + fn from(error: tokio::sync::mpsc::error::UnboundedTrySendError) -> Self { + Self::ChannelSendErr2(error) + } +} impl From for Error { fn from(error: EventErr) -> Self { diff --git a/src/response/stream/sse.rs b/src/response/stream/sse.rs index d427440..973c04c 100644 --- a/src/response/stream/sse.rs +++ b/src/response/stream/sse.rs @@ -1,44 +1,29 @@ use super::{Event, Payload}; -use crate::request::{Subscription, Timeline}; +use crate::request::Subscription; use futures::stream::Stream; -use log; use std::time::Duration; -use tokio::sync::{mpsc, watch}; +use tokio::sync::mpsc::UnboundedReceiver; use warp::reply::Reply; use warp::sse::Sse as WarpSse; -pub struct Sse; +type EventRx = UnboundedReceiver; + +pub struct Sse(Subscription); impl Sse { - pub fn send_events( - sse: WarpSse, - mut unsubscribe_tx: mpsc::UnboundedSender, - subscription: Subscription, - sse_rx: watch::Receiver<(Timeline, Event)>, - ) -> impl Reply { - let target_timeline = subscription.timeline; + pub fn new(subscription: Subscription) -> Self { + Self(subscription) + } - let event_stream = sse_rx - .filter(move |(timeline, _)| target_timeline == *timeline) - .filter_map(move |(_timeline, event)| { - match (event.update_payload(), event.dyn_update_payload()) { - (Some(update), _) if Sse::update_not_filtered(subscription.clone(), update) => { - event.to_warp_reply() - } - (None, None) => event.to_warp_reply(), // send all non-updates - (_, Some(update)) if Sse::update_not_filtered(subscription.clone(), update) => { - event.to_warp_reply() - } - (_, _) => None, - } - }) - .then(move |res| { - unsubscribe_tx - .try_send(target_timeline) - .unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e)); - res - }); + pub fn send_events(self, sse: WarpSse, event_rx: EventRx) -> impl Reply { + let event_stream = event_rx.filter_map(move |event| { + match (event.update_payload(), event.dyn_update_payload()) { + (Some(update), _) if self.update_not_filtered(update) => event.to_warp_reply(), + (_, Some(update)) if self.update_not_filtered(update) => event.to_warp_reply(), + (_, _) => event.to_warp_reply(), // send all non-updates + } + }); sse.reply( warp::sse::keep_alive() @@ -48,11 +33,11 @@ impl Sse { ) } - fn update_not_filtered(subscription: Subscription, update: &impl Payload) -> bool { - let blocks = &subscription.blocks; - let allowed_langs = &subscription.allowed_langs; + fn update_not_filtered(&self, update: &impl Payload) -> bool { + let blocks = &self.0.blocks; + let allowed_langs = &self.0.allowed_langs; - match subscription.timeline { + match self.0.timeline { tl if tl.is_public() && !update.language_unset() && !allowed_langs.is_empty() diff --git a/src/response/stream/ws.rs b/src/response/stream/ws.rs index 4485409..e4d38e1 100644 --- a/src/response/stream/ws.rs +++ b/src/response/stream/ws.rs @@ -1,41 +1,30 @@ use super::{Event, Payload}; -use crate::request::{Subscription, Timeline}; +use crate::request::Subscription; -use futures::{future::Future, stream::Stream}; -use tokio::sync::{mpsc, watch}; +use futures::future::Future; +use futures::stream::Stream; +use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use warp::ws::{Message, WebSocket}; -type Result = std::result::Result; +type EventRx = UnboundedReceiver; +type MsgTx = UnboundedSender; -pub struct Ws { - unsubscribe_tx: mpsc::UnboundedSender, - subscription: Subscription, - ws_rx: watch::Receiver<(Timeline, Event)>, - ws_tx: Option>, -} +pub struct Ws(Subscription); impl Ws { - pub fn new( - unsubscribe_tx: mpsc::UnboundedSender, - ws_rx: watch::Receiver<(Timeline, Event)>, - subscription: Subscription, - ) -> Self { - Self { - unsubscribe_tx, - subscription, - ws_rx, - ws_tx: None, - } + pub fn new(subscription: Subscription) -> Self { + Self(subscription) } - pub fn send_to(mut self, ws: WebSocket) -> impl Future { + pub fn send_to( + mut self, + ws: WebSocket, + event_rx: EventRx, + ) -> impl Future { let (transmit_to_ws, _receive_from_ws) = ws.split(); - // Create a pipe - let (ws_tx, ws_rx) = mpsc::unbounded_channel(); - self.ws_tx = Some(ws_tx); - - // Send one end of it to a different green thread and tell that end to forward - // whatever it gets on to the WebSocket client + // Create a pipe, send one end of it to a different green thread and tell that end + // to forward to the WebSocket client + let (mut ws_tx, ws_rx) = mpsc::unbounded_channel(); warp::spawn( ws_rx .map_err(|_| -> warp::Error { unreachable!() }) @@ -49,61 +38,53 @@ impl Ws { }), ); - let target_timeline = self.subscription.timeline; - let incoming_events = self.ws_rx.clone().map_err(|_| ()); - - incoming_events.for_each(move |(tl, event)| { - //TODO log::info!("{:?}, {:?}", &tl, &event); + event_rx.map_err(|_| ()).for_each(move |event| { if matches!(event, Event::Ping) { - self.send_msg(&event)? - } else if target_timeline == tl { + send_msg(&event, &mut ws_tx)? + } else { match (event.update_payload(), event.dyn_update_payload()) { - (Some(update), _) => self.send_or_filter(tl, &event, update)?, - (None, None) => self.send_msg(&event)?, // send all non-updates - (_, Some(dyn_update)) => self.send_or_filter(tl, &event, dyn_update)?, - } + (Some(update), _) => self.send_or_filter(&event, update, &mut ws_tx), + (None, None) => send_msg(&event, &mut ws_tx), // send all non-updates + (_, Some(dyn_update)) => self.send_or_filter(&event, dyn_update, &mut ws_tx), + }? } Ok(()) }) } - fn send_or_filter(&mut self, tl: Timeline, event: &Event, update: &impl Payload) -> Result<()> { - let (blocks, allowed_langs) = (&self.subscription.blocks, &self.subscription.allowed_langs); - const SKIP: Result<()> = Ok(()); + fn send_or_filter( + &mut self, + event: &Event, + update: &impl Payload, + mut ws_tx: &mut MsgTx, + ) -> Result<(), ()> { + let (blocks, allowed_langs) = (&self.0.blocks, &self.0.allowed_langs); - match tl { + let skip = |reason, tl| Ok(log::info!("{:?} msg skipped - {}", tl, reason)); + + match self.0.timeline { tl if tl.is_public() && !update.language_unset() && !allowed_langs.is_empty() && !allowed_langs.contains(&update.language()) => { - log::info!("{:?} msg skipped - disallowed language", tl); - SKIP + skip("disallowed language", tl) } + tl if !blocks.blocked_users.is_disjoint(&update.involved_users()) => { - log::info!("{:?} msg skipped - involves blocked user", tl); - SKIP - } - tl if blocks.blocking_users.contains(update.author()) => { - log::info!("{:?} msg skipped - from blocking user", tl); - SKIP + skip("involves blocked user", tl) } + tl if blocks.blocking_users.contains(update.author()) => skip("from blocking user", tl), tl if blocks.blocked_domains.contains(update.sent_from()) => { - log::info!("{:?} msg skipped - from blocked domain", tl); - SKIP + skip("from blocked domain", tl) } - _ => Ok(self.send_msg(&event)?), + _ => Ok(send_msg(event, &mut ws_tx)?), } } - - fn send_msg(&mut self, event: &Event) -> Result<()> { - let txt = &event.to_json_string(); - let tl = self.subscription.timeline; - let mut channel = self.ws_tx.clone().ok_or(())?; - channel.try_send(Message::text(txt)).map_err(|_| { - self.unsubscribe_tx - .try_send(tl) - .unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e)); - }) - } +} + +fn send_msg(event: &Event, ws_tx: &mut MsgTx) -> Result<(), ()> { + ws_tx + .try_send(Message::text(&event.to_json_string())) + .map_err(|_| log::info!("WebSocket connection closed")) }