Update concurrency primitive. (#139)

* Initial [WIP] implementation

This initial implementation works to send messages but does not yet
handle unsubscribing properly.

* Implement UnboundedSender

* Implement UnboundedChannels for concurrency
This commit is contained in:
Daniel Sockwell 2020-04-23 19:28:26 -04:00 committed by GitHub
parent 91186fb9f7
commit 2725439110
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 175 additions and 198 deletions

12
Cargo.lock generated
View File

@ -416,7 +416,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]] [[package]]
name = "flodgatt" name = "flodgatt"
version = "0.9.0" version = "0.9.1"
dependencies = [ dependencies = [
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)",
@ -437,6 +437,7 @@ dependencies = [
"tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)", "tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)",
"url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"urlencoding 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "urlencoding 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
"uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
"warp 0.1.20 (git+https://github.com/seanmonstar/warp.git)", "warp 0.1.20 (git+https://github.com/seanmonstar/warp.git)",
] ]
@ -2223,6 +2224,14 @@ name = "utf-8"
version = "0.7.5" version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "uuid"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]] [[package]]
name = "vcpkg" name = "vcpkg"
version = "0.2.7" version = "0.2.7"
@ -2589,6 +2598,7 @@ dependencies = [
"checksum url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "75b414f6c464c879d7f9babf951f23bc3743fb7313c081b2e6ca719067ea9d61" "checksum url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "75b414f6c464c879d7f9babf951f23bc3743fb7313c081b2e6ca719067ea9d61"
"checksum urlencoding 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3df3561629a8bb4c57e5a2e4c43348d9e29c7c29d9b1c4c1f47166deca8f37ed" "checksum urlencoding 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3df3561629a8bb4c57e5a2e4c43348d9e29c7c29d9b1c4c1f47166deca8f37ed"
"checksum utf-8 0.7.5 (registry+https://github.com/rust-lang/crates.io-index)" = "05e42f7c18b8f902290b009cde6d651262f956c98bc51bca4cd1d511c9cd85c7" "checksum utf-8 0.7.5 (registry+https://github.com/rust-lang/crates.io-index)" = "05e42f7c18b8f902290b009cde6d651262f956c98bc51bca4cd1d511c9cd85c7"
"checksum uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9fde2f6a4bea1d6e007c4ad38c6839fa71cbb63b6dbf5b595aa38dc9b1093c11"
"checksum vcpkg 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "33dd455d0f96e90a75803cfeb7f948768c08d70a6de9a8d2362461935698bf95" "checksum vcpkg 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "33dd455d0f96e90a75803cfeb7f948768c08d70a6de9a8d2362461935698bf95"
"checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd" "checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
"checksum walkdir 2.2.9 (registry+https://github.com/rust-lang/crates.io-index)" = "9658c94fa8b940eab2250bd5a457f9c48b748420d71293b165c8cdbe2f55f71e" "checksum walkdir 2.2.9 (registry+https://github.com/rust-lang/crates.io-index)" = "9658c94fa8b940eab2250bd5a457f9c48b748420d71293b165c8cdbe2f55f71e"

View File

@ -1,7 +1,7 @@
[package] [package]
name = "flodgatt" name = "flodgatt"
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
version = "0.9.0" version = "0.9.1"
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"] authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
edition = "2018" edition = "2018"
@ -25,6 +25,7 @@ r2d2 = "0.8.8"
lru = "0.4.3" lru = "0.4.3"
urlencoding = "1.0.0" urlencoding = "1.0.0"
hashbrown = "0.7.1" hashbrown = "0.7.1"
uuid = { version = "0.8.1", features = ["v4"] }
[dev-dependencies] [dev-dependencies]
criterion = "0.3" criterion = "0.3"

View File

@ -1,6 +1,6 @@
use flodgatt::config; use flodgatt::config;
use flodgatt::request::{Handler, Subscription, Timeline}; use flodgatt::request::{Handler, Subscription};
use flodgatt::response::{Event, RedisManager, SseStream, WsStream}; use flodgatt::response::{RedisManager, SseStream, WsStream};
use flodgatt::Error; use flodgatt::Error;
use futures::{future::lazy, stream::Stream as _}; use futures::{future::lazy, stream::Stream as _};
@ -9,7 +9,7 @@ use std::net::SocketAddr;
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
use std::time::Instant; use std::time::Instant;
use tokio::net::UnixListener; use tokio::net::UnixListener;
use tokio::sync::{mpsc, watch}; use tokio::sync::mpsc;
use tokio::timer::Interval; use tokio::timer::Interval;
use warp::ws::Ws2; use warp::ws::Ws2;
use warp::Filter; use warp::Filter;
@ -20,25 +20,21 @@ fn main() -> Result<(), Error> {
let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect())?; let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect())?;
let poll_freq = *redis_cfg.polling_interval; let poll_freq = *redis_cfg.polling_interval;
// Create channels to communicate between threads
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
let request = Handler::new(&postgres_cfg, *cfg.whitelist_mode)?; let request = Handler::new(&postgres_cfg, *cfg.whitelist_mode)?;
let shared_manager = RedisManager::try_from(&redis_cfg, event_tx, cmd_rx)?.into_arc(); let shared_manager = RedisManager::try_from(&redis_cfg)?.into_arc();
// Server Sent Events // Server Sent Events
let sse_manager = shared_manager.clone(); let sse_manager = shared_manager.clone();
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
let sse = request let sse = request
.sse_subscription() .sse_subscription()
.and(warp::sse()) .and(warp::sse())
.map(move |subscription: Subscription, sse: warp::sse::Sse| { .map(move |subscription: Subscription, sse: warp::sse::Sse| {
log::info!("Incoming SSE request for {:?}", subscription.timeline); log::info!("Incoming SSE request for {:?}", subscription.timeline);
let mut manager = sse_manager.lock().unwrap_or_else(RedisManager::recover); let mut manager = sse_manager.lock().unwrap_or_else(RedisManager::recover);
manager.subscribe(&subscription); let (event_tx, event_rx) = mpsc::unbounded_channel();
manager.subscribe(&subscription, event_tx);
SseStream::send_events(sse, sse_cmd_tx.clone(), subscription, sse_rx.clone()) let sse_stream = SseStream::new(subscription);
sse_stream.send_events(sse, event_rx)
}) })
.with(warp::reply::with::header("Connection", "keep-alive")); .with(warp::reply::with::header("Connection", "keep-alive"));
@ -50,11 +46,15 @@ fn main() -> Result<(), Error> {
.map(move |subscription: Subscription, ws: Ws2| { .map(move |subscription: Subscription, ws: Ws2| {
log::info!("Incoming websocket request for {:?}", subscription.timeline); log::info!("Incoming websocket request for {:?}", subscription.timeline);
let mut manager = ws_manager.lock().unwrap_or_else(RedisManager::recover); let mut manager = ws_manager.lock().unwrap_or_else(RedisManager::recover);
manager.subscribe(&subscription); let (event_tx, event_rx) = mpsc::unbounded_channel();
manager.subscribe(&subscription, event_tx);
let token = subscription.access_token.clone().unwrap_or_default(); // token sent for security let token = subscription.access_token.clone().unwrap_or_default(); // token sent for security
let ws_stream = WsStream::new(cmd_tx.clone(), event_rx.clone(), subscription); let ws_stream = WsStream::new(subscription);
(ws.on_upgrade(move |ws| ws_stream.send_to(ws)), token) (
ws.on_upgrade(move |ws| ws_stream.send_to(ws, event_rx)),
token,
)
}) })
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));

View File

@ -12,7 +12,7 @@ use std::convert::TryFrom;
use std::string::String; use std::string::String;
use warp::sse::ServerSentEvent; use warp::sse::ServerSentEvent;
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event { pub enum Event {
TypeSafe(CheckedEvent), TypeSafe(CheckedEvent),
Dynamic(DynEvent), Dynamic(DynEvent),

View File

@ -22,7 +22,7 @@ use serde::Deserialize;
#[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)] #[serde(rename_all = "snake_case", tag = "event", deny_unknown_fields)]
#[rustfmt::skip] #[rustfmt::skip]
#[derive(Deserialize, Debug, Clone, PartialEq)] #[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum CheckedEvent { pub enum CheckedEvent {
Update { payload: Status, queued_at: Option<i64> }, Update { payload: Status, queued_at: Option<i64> },
Notification { payload: Notification }, Notification { payload: Notification },

View File

@ -3,7 +3,7 @@ use crate::Id;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(super) struct Account { pub(super) struct Account {
pub id: Id, pub id: Id,
username: String, username: String,
@ -31,7 +31,7 @@ pub(super) struct Account {
} }
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct Field { struct Field {
name: String, name: String,
value: String, value: String,
@ -39,7 +39,7 @@ struct Field {
} }
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct Source { struct Source {
note: String, note: String,
fields: Vec<Field>, fields: Vec<Field>,

View File

@ -2,7 +2,7 @@ use super::{emoji::Emoji, mention::Mention, tag::Tag, AnnouncementReaction};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct Announcement { pub struct Announcement {
// Fully undocumented // Fully undocumented
id: String, id: String,

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct AnnouncementReaction { pub struct AnnouncementReaction {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
announcement_id: Option<String>, announcement_id: Option<String>,

View File

@ -2,7 +2,7 @@ use super::{account::Account, status::Status};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct Conversation { pub struct Conversation {
id: String, id: String,
accounts: Vec<Account>, accounts: Vec<Account>,

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(super) struct Emoji { pub(super) struct Emoji {
shortcode: String, shortcode: String,
url: String, url: String,

View File

@ -2,7 +2,7 @@ use crate::Id;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(super) struct Mention { pub(super) struct Mention {
pub id: Id, pub id: Id,
username: String, username: String,

View File

@ -2,7 +2,7 @@ use super::{account::Account, status::Status};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct Notification { pub struct Notification {
id: String, id: String,
r#type: NotificationType, r#type: NotificationType,
@ -12,7 +12,7 @@ pub struct Notification {
} }
#[serde(rename_all = "snake_case", deny_unknown_fields)] #[serde(rename_all = "snake_case", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
enum NotificationType { enum NotificationType {
Follow, Follow,
FollowRequest, // Undocumented FollowRequest, // Undocumented

View File

@ -20,7 +20,7 @@ use std::boxed::Box;
use std::string::String; use std::string::String;
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct Status { pub struct Status {
id: Id, id: Id,
uri: String, uri: String,

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(super) struct Application { pub(super) struct Application {
name: String, name: String,
website: Option<String>, website: Option<String>,

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(super) struct Attachment { pub(super) struct Attachment {
id: String, id: String,
r#type: AttachmentType, r#type: AttachmentType,
@ -15,7 +15,7 @@ pub(super) struct Attachment {
} }
#[serde(rename_all = "lowercase", deny_unknown_fields)] #[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
enum AttachmentType { enum AttachmentType {
Unknown, Unknown,
Image, Image,

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(super) struct Card { pub(super) struct Card {
url: String, url: String,
title: String, title: String,
@ -19,7 +19,7 @@ pub(super) struct Card {
} }
#[serde(rename_all = "lowercase", deny_unknown_fields)] #[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
enum CardType { enum CardType {
Link, Link,
Photo, Photo,

View File

@ -2,7 +2,7 @@ use super::super::emoji::Emoji;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(super) struct Poll { pub(super) struct Poll {
id: String, id: String,
expires_at: String, expires_at: String,
@ -17,7 +17,7 @@ pub(super) struct Poll {
} }
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct PollOptions { struct PollOptions {
title: String, title: String,
votes_count: Option<i32>, votes_count: Option<i32>,

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(super) struct Tag { pub(super) struct Tag {
name: String, name: String,
url: String, url: String,
@ -9,7 +9,7 @@ pub(super) struct Tag {
} }
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
struct History { struct History {
day: String, day: String,
uses: String, uses: String,

View File

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[serde(rename_all = "lowercase", deny_unknown_fields)] #[serde(rename_all = "lowercase", deny_unknown_fields)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub(super) enum Visibility { pub(super) enum Visibility {
Public, Public,
Unlisted, Unlisted,

View File

@ -8,7 +8,7 @@ use hashbrown::HashSet;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct DynEvent { pub struct DynEvent {
#[serde(skip)] #[serde(skip)]
pub(crate) kind: EventKind, pub(crate) kind: EventKind,
@ -17,7 +17,7 @@ pub struct DynEvent {
pub(crate) queued_at: Option<i64>, pub(crate) queued_at: Option<i64>,
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum EventKind { pub(crate) enum EventKind {
Update(DynStatus), Update(DynStatus),
NonUpdate, NonUpdate,
@ -29,7 +29,7 @@ impl Default for EventKind {
} }
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct DynStatus { pub(crate) struct DynStatus {
pub(crate) id: Id, pub(crate) id: Id,
pub(crate) username: String, pub(crate) username: String,

View File

@ -11,37 +11,29 @@ use crate::request::{Subscription, Timeline};
pub(self) use super::EventErr; pub(self) use super::EventErr;
use futures::{Async, Stream as _Stream}; use futures::Async;
use hashbrown::HashMap; use hashbrown::HashMap;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError}; use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::sync::{mpsc, watch}; use tokio::sync::mpsc::UnboundedSender;
use uuid::Uuid;
type Result<T> = std::result::Result<T, Error>; type Result<T> = std::result::Result<T, Error>;
/// The item that streams from Redis and is polled by the `ClientAgent` /// The item that streams from Redis and is polled by the `ClientAgent`
#[derive(Debug)]
pub struct Manager { pub struct Manager {
redis_connection: RedisConn, redis_connection: RedisConn,
clients_per_timeline: HashMap<Timeline, i32>, timelines: HashMap<Timeline, HashMap<Uuid, UnboundedSender<Event>>>,
tx: watch::Sender<(Timeline, Event)>,
rx: mpsc::UnboundedReceiver<Timeline>,
ping_time: Instant, ping_time: Instant,
} }
impl Manager { impl Manager {
/// Create a new `Manager`, with its own Redis connections (but, as yet, no /// Create a new `Manager`, with its own Redis connections (but, as yet, no
/// active subscriptions). /// active subscriptions).
pub fn try_from( pub fn try_from(redis_cfg: &config::Redis) -> Result<Self> {
redis_cfg: &config::Redis,
tx: watch::Sender<(Timeline, Event)>,
rx: mpsc::UnboundedReceiver<Timeline>,
) -> Result<Self> {
Ok(Self { Ok(Self {
redis_connection: RedisConn::new(redis_cfg)?, redis_connection: RedisConn::new(redis_cfg)?,
clients_per_timeline: HashMap::new(), timelines: HashMap::new(),
tx,
rx,
ping_time: Instant::now(), ping_time: Instant::now(),
}) })
} }
@ -50,64 +42,64 @@ impl Manager {
Arc::new(Mutex::new(self)) Arc::new(Mutex::new(self))
} }
pub fn subscribe(&mut self, subscription: &Subscription) { pub fn subscribe(&mut self, subscription: &Subscription, channel: UnboundedSender<Event>) {
let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline); let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline);
if let (Some(hashtag), Some(id)) = (tag, tl.tag()) { if let (Some(hashtag), Some(id)) = (tag, tl.tag()) {
self.redis_connection.update_cache(hashtag, id); self.redis_connection.update_cache(hashtag, id);
}; };
let number_of_subscriptions = self let channels = self.timelines.entry(tl).or_default();
.clients_per_timeline channels.insert(Uuid::new_v4(), channel);
.entry(tl)
.and_modify(|n| *n += 1)
.or_insert(1);
use RedisCmd::*; if channels.len() == 1 {
if *number_of_subscriptions == 1 {
self.redis_connection self.redis_connection
.send_cmd(Subscribe, &tl) .send_cmd(RedisCmd::Subscribe, &tl)
.unwrap_or_else(|e| log::error!("Could not subscribe to the Redis channel: {}", e)); .unwrap_or_else(|e| log::error!("Could not subscribe to the Redis channel: {}", e));
}; };
} }
pub(crate) fn unsubscribe(&mut self, tl: Timeline) -> Result<()> { pub(crate) fn unsubscribe(&mut self, tl: &mut Timeline, id: &Uuid) -> Result<()> {
let number_of_subscriptions = self let channels = self.timelines.get_mut(tl).ok_or(Error::InvalidId)?;
.clients_per_timeline channels.remove(id);
.entry(tl)
.and_modify(|n| *n -= 1) if channels.len() == 0 {
.or_insert_with(|| { self.redis_connection.send_cmd(RedisCmd::Unsubscribe, &tl)?;
log::error!( self.timelines.remove(&tl);
"Attempted to unsubscribe from a timeline to which you were not subscribed: {:?}",
tl
);
0
});
use RedisCmd::*;
if *number_of_subscriptions == 0 {
self.redis_connection.send_cmd(Unsubscribe, &tl)?;
self.clients_per_timeline.remove_entry(&tl);
}; };
log::info!("Ended stream for {:?}", tl); log::info!("Ended stream for {:?}", tl);
Ok(()) Ok(())
} }
pub fn poll_broadcast(&mut self) -> Result<()> { pub fn poll_broadcast(&mut self) -> Result<()> {
while let Ok(Async::Ready(Some(tl))) = self.rx.poll() { let mut completed_timelines = Vec::new();
self.unsubscribe(tl)?
}
if self.ping_time.elapsed() > Duration::from_secs(30) { if self.ping_time.elapsed() > Duration::from_secs(30) {
self.ping_time = Instant::now(); self.ping_time = Instant::now();
self.tx.broadcast((Timeline::empty(), Event::Ping))? for (timeline, channels) in self.timelines.iter_mut() {
} else { for (uuid, channel) in channels.iter_mut() {
match self.redis_connection.poll_redis() { match channel.try_send(Event::Ping) {
Ok(Async::NotReady) | Ok(Async::Ready(None)) => (), // None = cmd or msg for other namespace Ok(_) => (),
Ok(Async::Ready(Some((timeline, event)))) => { Err(_) => completed_timelines.push((*timeline, *uuid)),
self.tx.broadcast((timeline, event))? }
} }
}
};
loop {
match self.redis_connection.poll_redis() {
Ok(Async::NotReady) => break,
Ok(Async::Ready(Some((tl, event)))) => {
for (uuid, tx) in self.timelines.get_mut(&tl).ok_or(Error::InvalidId)? {
tx.try_send(event.clone())
.unwrap_or_else(|_| completed_timelines.push((tl, *uuid)))
}
}
Ok(Async::Ready(None)) => (), // cmd or msg for other namespace
Err(err) => log::error!("{}", err), // drop msg, log err, and proceed Err(err) => log::error!("{}", err), // drop msg, log err, and proceed
} }
} }
for (tl, channel) in completed_timelines.iter_mut() {
self.unsubscribe(tl, &channel)?;
}
Ok(()) Ok(())
} }
@ -119,20 +111,20 @@ impl Manager {
pub fn count(&self) -> String { pub fn count(&self) -> String {
format!( format!(
"Current connections: {}", "Current connections: {}",
self.clients_per_timeline.values().sum::<i32>() self.timelines.values().map(|el| el.len()).sum::<usize>()
) )
} }
pub fn list(&self) -> String { pub fn list(&self) -> String {
let max_len = self let max_len = self
.clients_per_timeline .timelines
.keys() .keys()
.fold(0, |acc, el| acc.max(format!("{:?}:", el).len())); .fold(0, |acc, el| acc.max(format!("{:?}:", el).len()));
self.clients_per_timeline self.timelines
.iter() .iter()
.map(|(tl, n)| { .map(|(tl, channel_map)| {
let tl_txt = format!("{:?}:", tl); let tl_txt = format!("{:?}:", tl);
format!("{:>1$} {2}\n", tl_txt, max_len, n) format!("{:>1$} {2}\n", tl_txt, max_len, channel_map.len())
}) })
.collect() .collect()
} }

View File

@ -6,11 +6,13 @@ use std::fmt;
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
InvalidId, InvalidId,
TimelineErr(TimelineErr), TimelineErr(TimelineErr),
EventErr(EventErr), EventErr(EventErr),
RedisParseErr(RedisParseErr), RedisParseErr(RedisParseErr),
RedisConnErr(RedisConnErr), RedisConnErr(RedisConnErr),
ChannelSendErr(tokio::sync::watch::error::SendError<(Timeline, Event)>), ChannelSendErr(tokio::sync::watch::error::SendError<(Timeline, Event)>),
ChannelSendErr2(tokio::sync::mpsc::error::UnboundedTrySendError<Event>),
} }
impl std::error::Error for Error {} impl std::error::Error for Error {}
@ -21,13 +23,14 @@ impl fmt::Display for Error {
match self { match self {
InvalidId => write!( InvalidId => write!(
f, f,
"Attempted to get messages for a subscription that had not been set up." "tried to access a timeline/channel subscription that does not exist"
), ),
EventErr(inner) => write!(f, "{}", inner), EventErr(inner) => write!(f, "{}", inner),
RedisParseErr(inner) => write!(f, "{}", inner), RedisParseErr(inner) => write!(f, "{}", inner),
RedisConnErr(inner) => write!(f, "{}", inner), RedisConnErr(inner) => write!(f, "{}", inner),
TimelineErr(inner) => write!(f, "{}", inner), TimelineErr(inner) => write!(f, "{}", inner),
ChannelSendErr(inner) => write!(f, "{}", inner), ChannelSendErr(inner) => write!(f, "{}", inner),
ChannelSendErr2(inner) => write!(f, "{}", inner),
}?; }?;
Ok(()) Ok(())
} }
@ -38,6 +41,11 @@ impl From<tokio::sync::watch::error::SendError<(Timeline, Event)>> for Error {
Self::ChannelSendErr(error) Self::ChannelSendErr(error)
} }
} }
impl From<tokio::sync::mpsc::error::UnboundedTrySendError<Event>> for Error {
fn from(error: tokio::sync::mpsc::error::UnboundedTrySendError<Event>) -> Self {
Self::ChannelSendErr2(error)
}
}
impl From<EventErr> for Error { impl From<EventErr> for Error {
fn from(error: EventErr) -> Self { fn from(error: EventErr) -> Self {

View File

@ -1,44 +1,29 @@
use super::{Event, Payload}; use super::{Event, Payload};
use crate::request::{Subscription, Timeline}; use crate::request::Subscription;
use futures::stream::Stream; use futures::stream::Stream;
use log;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{mpsc, watch}; use tokio::sync::mpsc::UnboundedReceiver;
use warp::reply::Reply; use warp::reply::Reply;
use warp::sse::Sse as WarpSse; use warp::sse::Sse as WarpSse;
pub struct Sse; type EventRx = UnboundedReceiver<Event>;
pub struct Sse(Subscription);
impl Sse { impl Sse {
pub fn send_events( pub fn new(subscription: Subscription) -> Self {
sse: WarpSse, Self(subscription)
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>, }
subscription: Subscription,
sse_rx: watch::Receiver<(Timeline, Event)>,
) -> impl Reply {
let target_timeline = subscription.timeline;
let event_stream = sse_rx pub fn send_events(self, sse: WarpSse, event_rx: EventRx) -> impl Reply {
.filter(move |(timeline, _)| target_timeline == *timeline) let event_stream = event_rx.filter_map(move |event| {
.filter_map(move |(_timeline, event)| { match (event.update_payload(), event.dyn_update_payload()) {
match (event.update_payload(), event.dyn_update_payload()) { (Some(update), _) if self.update_not_filtered(update) => event.to_warp_reply(),
(Some(update), _) if Sse::update_not_filtered(subscription.clone(), update) => { (_, Some(update)) if self.update_not_filtered(update) => event.to_warp_reply(),
event.to_warp_reply() (_, _) => event.to_warp_reply(), // send all non-updates
} }
(None, None) => event.to_warp_reply(), // send all non-updates });
(_, Some(update)) if Sse::update_not_filtered(subscription.clone(), update) => {
event.to_warp_reply()
}
(_, _) => None,
}
})
.then(move |res| {
unsubscribe_tx
.try_send(target_timeline)
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
res
});
sse.reply( sse.reply(
warp::sse::keep_alive() warp::sse::keep_alive()
@ -48,11 +33,11 @@ impl Sse {
) )
} }
fn update_not_filtered(subscription: Subscription, update: &impl Payload) -> bool { fn update_not_filtered(&self, update: &impl Payload) -> bool {
let blocks = &subscription.blocks; let blocks = &self.0.blocks;
let allowed_langs = &subscription.allowed_langs; let allowed_langs = &self.0.allowed_langs;
match subscription.timeline { match self.0.timeline {
tl if tl.is_public() tl if tl.is_public()
&& !update.language_unset() && !update.language_unset()
&& !allowed_langs.is_empty() && !allowed_langs.is_empty()

View File

@ -1,41 +1,30 @@
use super::{Event, Payload}; use super::{Event, Payload};
use crate::request::{Subscription, Timeline}; use crate::request::Subscription;
use futures::{future::Future, stream::Stream}; use futures::future::Future;
use tokio::sync::{mpsc, watch}; use futures::stream::Stream;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use warp::ws::{Message, WebSocket}; use warp::ws::{Message, WebSocket};
type Result<T> = std::result::Result<T, ()>; type EventRx = UnboundedReceiver<Event>;
type MsgTx = UnboundedSender<Message>;
pub struct Ws { pub struct Ws(Subscription);
unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
subscription: Subscription,
ws_rx: watch::Receiver<(Timeline, Event)>,
ws_tx: Option<mpsc::UnboundedSender<Message>>,
}
impl Ws { impl Ws {
pub fn new( pub fn new(subscription: Subscription) -> Self {
unsubscribe_tx: mpsc::UnboundedSender<Timeline>, Self(subscription)
ws_rx: watch::Receiver<(Timeline, Event)>,
subscription: Subscription,
) -> Self {
Self {
unsubscribe_tx,
subscription,
ws_rx,
ws_tx: None,
}
} }
pub fn send_to(mut self, ws: WebSocket) -> impl Future<Item = (), Error = ()> { pub fn send_to(
mut self,
ws: WebSocket,
event_rx: EventRx,
) -> impl Future<Item = (), Error = ()> {
let (transmit_to_ws, _receive_from_ws) = ws.split(); let (transmit_to_ws, _receive_from_ws) = ws.split();
// Create a pipe // Create a pipe, send one end of it to a different green thread and tell that end
let (ws_tx, ws_rx) = mpsc::unbounded_channel(); // to forward to the WebSocket client
self.ws_tx = Some(ws_tx); let (mut ws_tx, ws_rx) = mpsc::unbounded_channel();
// Send one end of it to a different green thread and tell that end to forward
// whatever it gets on to the WebSocket client
warp::spawn( warp::spawn(
ws_rx ws_rx
.map_err(|_| -> warp::Error { unreachable!() }) .map_err(|_| -> warp::Error { unreachable!() })
@ -49,61 +38,53 @@ impl Ws {
}), }),
); );
let target_timeline = self.subscription.timeline; event_rx.map_err(|_| ()).for_each(move |event| {
let incoming_events = self.ws_rx.clone().map_err(|_| ());
incoming_events.for_each(move |(tl, event)| {
//TODO log::info!("{:?}, {:?}", &tl, &event);
if matches!(event, Event::Ping) { if matches!(event, Event::Ping) {
self.send_msg(&event)? send_msg(&event, &mut ws_tx)?
} else if target_timeline == tl { } else {
match (event.update_payload(), event.dyn_update_payload()) { match (event.update_payload(), event.dyn_update_payload()) {
(Some(update), _) => self.send_or_filter(tl, &event, update)?, (Some(update), _) => self.send_or_filter(&event, update, &mut ws_tx),
(None, None) => self.send_msg(&event)?, // send all non-updates (None, None) => send_msg(&event, &mut ws_tx), // send all non-updates
(_, Some(dyn_update)) => self.send_or_filter(tl, &event, dyn_update)?, (_, Some(dyn_update)) => self.send_or_filter(&event, dyn_update, &mut ws_tx),
} }?
} }
Ok(()) Ok(())
}) })
} }
fn send_or_filter(&mut self, tl: Timeline, event: &Event, update: &impl Payload) -> Result<()> { fn send_or_filter(
let (blocks, allowed_langs) = (&self.subscription.blocks, &self.subscription.allowed_langs); &mut self,
const SKIP: Result<()> = Ok(()); event: &Event,
update: &impl Payload,
mut ws_tx: &mut MsgTx,
) -> Result<(), ()> {
let (blocks, allowed_langs) = (&self.0.blocks, &self.0.allowed_langs);
match tl { let skip = |reason, tl| Ok(log::info!("{:?} msg skipped - {}", tl, reason));
match self.0.timeline {
tl if tl.is_public() tl if tl.is_public()
&& !update.language_unset() && !update.language_unset()
&& !allowed_langs.is_empty() && !allowed_langs.is_empty()
&& !allowed_langs.contains(&update.language()) => && !allowed_langs.contains(&update.language()) =>
{ {
log::info!("{:?} msg skipped - disallowed language", tl); skip("disallowed language", tl)
SKIP
} }
tl if !blocks.blocked_users.is_disjoint(&update.involved_users()) => { tl if !blocks.blocked_users.is_disjoint(&update.involved_users()) => {
log::info!("{:?} msg skipped - involves blocked user", tl); skip("involves blocked user", tl)
SKIP
}
tl if blocks.blocking_users.contains(update.author()) => {
log::info!("{:?} msg skipped - from blocking user", tl);
SKIP
} }
tl if blocks.blocking_users.contains(update.author()) => skip("from blocking user", tl),
tl if blocks.blocked_domains.contains(update.sent_from()) => { tl if blocks.blocked_domains.contains(update.sent_from()) => {
log::info!("{:?} msg skipped - from blocked domain", tl); skip("from blocked domain", tl)
SKIP
} }
_ => Ok(self.send_msg(&event)?), _ => Ok(send_msg(event, &mut ws_tx)?),
} }
} }
}
fn send_msg(&mut self, event: &Event) -> Result<()> {
let txt = &event.to_json_string(); fn send_msg(event: &Event, ws_tx: &mut MsgTx) -> Result<(), ()> {
let tl = self.subscription.timeline; ws_tx
let mut channel = self.ws_tx.clone().ok_or(())?; .try_send(Message::text(&event.to_json_string()))
channel.try_send(Message::text(txt)).map_err(|_| { .map_err(|_| log::info!("WebSocket connection closed"))
self.unsubscribe_tx
.try_send(tl)
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
})
}
} }