From 425a9d0aae900de9aef0cd9b3b96680ba0ff71dc Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Fri, 26 Apr 2019 20:00:11 -0400 Subject: [PATCH] Allow seperate SSE responses to share Redis pubsub MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit implements a shared stream of data from Redis, which allows all SSE connections that send the same data to the client to share a single connection to Redis. (Previously, each client got their own connection, which would significantly increase the number of open Redis connections—especially since nearly all clients will subscribe to `/public`.) --- src/main.rs | 119 ++++++++++++++++++++++++++++++++++++-------------- src/pubsub.rs | 67 +++++++++++++++++++++------- src/query.rs | 8 ++-- src/user.rs | 4 +- 4 files changed, 144 insertions(+), 54 deletions(-) diff --git a/src/main.rs b/src/main.rs index fa218a8..499abf5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,7 +4,10 @@ mod query; mod user; mod utils; use futures::stream::Stream; +use futures::{Async, Poll}; use pubsub::PubSub; +use serde_json::Value; +use std::io::Error; use user::{Filter, Scope, User}; use warp::{path, Filter as WarpFilter}; @@ -16,21 +19,21 @@ fn main() { .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| PubSub::from(user.id.to_string(), user)); + .map(|user: User| (user.id.to_string(), user)); // GET /api/v1/streaming/user/notification [private; notification filter] let user_timeline_notifications = path!("api" / "v1" / "streaming" / "user" / "notification") .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| PubSub::from(user.id.to_string(), user.with_notification_filter())); + .map(|user: User| (user.id.to_string(), user.with_notification_filter())); // GET /api/v1/streaming/public [public; language filter] let public_timeline = path!("api" / "v1" / "streaming" / "public") .and(path::end()) .and(user::get_access_token(user::Scope::Public)) .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| PubSub::from("public".into(), user.with_language_filter())); + .map(|user: User| ("public".into(), user.with_language_filter())); // GET /api/v1/streaming/public?only_media=true [public; language filter] let public_timeline_media = path!("api" / "v1" / "streaming" / "public") @@ -39,8 +42,8 @@ fn main() { .and_then(|token| user::get_account(token, Scope::Public)) .and(warp::query()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => PubSub::from("public:media".into(), user.with_language_filter()), - _ => PubSub::from("public".into(), user.with_language_filter()), + "1" | "true" => ("public:media".into(), user.with_language_filter()), + _ => ("public".into(), user.with_language_filter()), }); // GET /api/v1/streaming/public/local [public; language filter] @@ -48,7 +51,7 @@ fn main() { .and(path::end()) .and(user::get_access_token(user::Scope::Public)) .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| PubSub::from("public:local".into(), user.with_language_filter())); + .map(|user: User| ("public:local".into(), user.with_language_filter())); // GET /api/v1/streaming/public/local?only_media=true [public; language filter] let local_timeline_media = path!("api" / "v1" / "streaming" / "public" / "local") @@ -57,8 +60,8 @@ fn main() { .and(warp::query()) .and(path::end()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => PubSub::from("public:local:media".into(), user.with_language_filter()), - _ => PubSub::from("public:local".into(), user.with_language_filter()), + "1" | "true" => ("public:local:media".into(), user.with_language_filter()), + _ => ("public:local".into(), user.with_language_filter()), }); // GET /api/v1/streaming/direct [private; *no* filter] @@ -66,19 +69,22 @@ fn main() { .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| PubSub::from(format!("direct:{}", user.id), user.with_no_filter())); + .map(|user: User| (format!("direct:{}", user.id), user.with_no_filter())); // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] let hashtag_timeline = path!("api" / "v1" / "streaming" / "hashtag") .and(warp::query()) .and(path::end()) - .map(|q: query::Hashtag| PubSub::from(format!("hashtag:{}", q.tag), User::public())); + .map(|q: query::Hashtag| { + dbg!(&q); + (format!("hashtag:{}", q.tag), User::public()) + }); // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] let hashtag_timeline_local = path!("api" / "v1" / "streaming" / "hashtag" / "local") .and(warp::query()) .and(path::end()) - .map(|q: query::Hashtag| PubSub::from(format!("hashtag:{}:local", q.tag), User::public())); + .map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public())); // GET /api/v1/streaming/list?list=:list_id [private; no filter] let list_timeline = path!("api" / "v1" / "streaming" / "list") @@ -88,8 +94,9 @@ fn main() { .and_then(|user: User, q: query::List| (user.is_authorized_for_list(q.list), Ok(user))) .untuple_one() .and(path::end()) - .map(|list: i64, user: User| PubSub::from(format!("list:{}", list), user.with_no_filter())); - + .map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter())); + let event_stream = RedisStream::new(); + let event_stream = warp::any().map(move || event_stream.clone()); let routes = or!( user_timeline, user_timeline_notifications, @@ -102,29 +109,75 @@ fn main() { hashtag_timeline_local, list_timeline ) - .and_then(|event_stream| event_stream) + .untuple_one() .and(warp::sse()) - .map(|event_stream: pubsub::Receiver, sse: warp::sse::Sse| { - let user = event_stream.user.clone(); - sse.reply(warp::sse::keep( - event_stream.filter_map(move |item| { - let payload = item["payload"].clone(); - let event = item["event"].to_string().clone(); - let toot_lang = item["language"].to_string().clone(); - - println!("ding"); - - match &user.filter { - Filter::Notification if event != "notification" => None, - Filter::Language if !user.langs.contains(&toot_lang) => None, - _ => Some((warp::sse::event(event), warp::sse::data(payload))), - } - }), - None, - )) - }) + .and(event_stream) + .map( + |timeline: String, user: User, sse: warp::sse::Sse, mut event_stream: RedisStream| { + event_stream.add(timeline.clone(), user); + sse.reply(warp::sse::keep( + event_stream.filter_map(move |item| { + println!("ding"); + Some((warp::sse::event("event"), warp::sse::data(item.to_string()))) + }), + None, + )) + }, + ) .with(warp::reply::with::header("Connection", "keep-alive")) .recover(error::handle_errors); warp::serve(routes).run(([127, 0, 0, 1], 3030)); } + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +#[derive(Clone)] +struct RedisStream { + recv: Arc>>, + current_stream: String, +} +impl RedisStream { + fn new() -> Self { + let recv = Arc::new(Mutex::new(HashMap::new())); + Self { + recv, + current_stream: "".to_string(), + } + } + + fn add(&mut self, timeline: String, user: User) -> &Self { + let mut hash_map_of_streams = self.recv.lock().unwrap(); + if !hash_map_of_streams.contains_key(&timeline) { + println!( + "First time encountering `{}`, saving it to the HashMap", + &timeline + ); + hash_map_of_streams.insert(timeline.clone(), PubSub::from(timeline.clone(), user)); + } else { + println!( + "HashMap already contains `{}`, returning unmodified HashMap", + &timeline + ); + } + self.current_stream = timeline; + self + } +} +impl Stream for RedisStream { + type Item = Value; + type Error = Error; + + fn poll(&mut self) -> Poll, Self::Error> { + println!("polling Interval"); + let mut hash_map_of_streams = self.recv.lock().unwrap(); + let target_stream = self.current_stream.clone(); + let stream = hash_map_of_streams.get_mut(&target_stream).unwrap(); + match stream.poll() { + Ok(Async::Ready(Some(value))) => Ok(Async::Ready(Some(value))), + Ok(Async::Ready(None)) => Ok(Async::Ready(None)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(e) => Err(e), + } + } +} diff --git a/src/pubsub.rs b/src/pubsub.rs index 562d992..c02207c 100644 --- a/src/pubsub.rs +++ b/src/pubsub.rs @@ -41,6 +41,7 @@ impl RedisCmd { } } +#[derive(Debug)] pub struct Receiver { rx: ReadHalf, tx: WriteHalf, @@ -49,6 +50,7 @@ pub struct Receiver { } impl Receiver { fn new(socket: TcpStream, tl: String, user: User) -> Self { + println!("created a new Receiver"); let (rx, mut tx) = socket.split(); tx.poll_write(RedisCmd::subscribe_to_timeline(&tl).as_bytes()) .expect("Can subscribe to Redis"); @@ -86,30 +88,65 @@ impl Drop for Receiver { } } +use futures::sink::Sink; +use tokio::net::tcp::ConnectFuture; +struct Socket { + connect: ConnectFuture, + tx: tokio::sync::mpsc::Sender, +} +impl Socket { + fn new(address: impl std::fmt::Display, tx: tokio::sync::mpsc::Sender) -> Self { + let address = address + .to_string() + .parse() + .expect("Unable to parse address"); + let connect = TcpStream::connect(&address); + Self { connect, tx } + } +} +impl Future for Socket { + type Item = (); + type Error = (); + + fn poll(&mut self) -> Poll { + match self.connect.poll() { + Ok(Async::Ready(socket)) => { + self.tx.clone().try_send(socket); + Ok(Async::Ready(())) + } + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(e) => { + println!("failed to connect: {}", e); + Ok(Async::Ready(())) + } + } + } +} + pub struct PubSub {} impl PubSub { - pub fn from( - timeline: impl std::fmt::Display, - user: User, - ) -> impl Future { + pub fn from(timeline: impl std::fmt::Display, user: User) -> Receiver { while OPEN_CONNECTIONS.load(Ordering::Relaxed) > MAX_CONNECTIONS.load(Ordering::Relaxed) { thread::sleep(time::Duration::from_millis(1000)); } let new_connections = OPEN_CONNECTIONS.fetch_add(1, Ordering::Relaxed) + 1; println!("{} connection(s) now open", new_connections); + let (tx, mut rx) = tokio::sync::mpsc::channel(5); + let socket = Socket::new("127.0.0.1:6379", tx); + + tokio::spawn(futures::future::lazy(move || socket)); + + let socket = loop { + if let Ok(Async::Ready(Some(msg))) = rx.poll() { + break msg; + } + thread::sleep(time::Duration::from_millis(100)); + }; + let timeline = timeline.to_string(); - fn get_socket() -> impl Future> { - let address = "127.0.0.1:6379".parse().expect("Unable to parse address"); - let connection = TcpStream::connect(&address); - connection.and_then(Ok).map_err(Box::new) - } - get_socket() - .and_then(move |socket| { - let stream_of_data_from_redis = Receiver::new(socket, timeline, user); - Ok(stream_of_data_from_redis) - }) - .map_err(warp::reject::custom) + let stream_of_data_from_redis = Receiver::new(socket, timeline, user); + stream_of_data_from_redis } } diff --git a/src/query.rs b/src/query.rs index bb73830..9c019b4 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,18 +1,18 @@ use serde_derive::Deserialize; -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] pub struct Media { pub only_media: String, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] pub struct Hashtag { pub tag: String, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] pub struct List { pub list: i64, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug)] pub struct Auth { pub access_token: String, } diff --git a/src/user.rs b/src/user.rs index 5a08997..1d2968d 100644 --- a/src/user.rs +++ b/src/user.rs @@ -23,14 +23,14 @@ fn conn() -> postgres::Connection { ) .unwrap() } -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum Filter { None, Language, Notification, } -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct User { pub id: i64, pub langs: Vec,