Allow seperate SSE responses to share Redis pubsub

This commit implements a shared stream of data from Redis, which
allows all SSE connections that send the same data to the client
to share a single connection to Redis.  (Previously, each client
got their own connection, which would significantly increase the
number of open Redis connections—especially since nearly all clients
will subscribe to `/public`.)
This commit is contained in:
Daniel Sockwell 2019-04-26 20:00:11 -04:00
parent f676e51ce4
commit 425a9d0aae
4 changed files with 144 additions and 54 deletions

View File

@ -4,7 +4,10 @@ mod query;
mod user;
mod utils;
use futures::stream::Stream;
use futures::{Async, Poll};
use pubsub::PubSub;
use serde_json::Value;
use std::io::Error;
use user::{Filter, Scope, User};
use warp::{path, Filter as WarpFilter};
@ -16,21 +19,21 @@ fn main() {
.and(path::end())
.and(user::get_access_token(Scope::Private))
.and_then(|token| user::get_account(token, Scope::Private))
.map(|user: User| PubSub::from(user.id.to_string(), user));
.map(|user: User| (user.id.to_string(), user));
// GET /api/v1/streaming/user/notification [private; notification filter]
let user_timeline_notifications = path!("api" / "v1" / "streaming" / "user" / "notification")
.and(path::end())
.and(user::get_access_token(Scope::Private))
.and_then(|token| user::get_account(token, Scope::Private))
.map(|user: User| PubSub::from(user.id.to_string(), user.with_notification_filter()));
.map(|user: User| (user.id.to_string(), user.with_notification_filter()));
// GET /api/v1/streaming/public [public; language filter]
let public_timeline = path!("api" / "v1" / "streaming" / "public")
.and(path::end())
.and(user::get_access_token(user::Scope::Public))
.and_then(|token| user::get_account(token, Scope::Public))
.map(|user: User| PubSub::from("public".into(), user.with_language_filter()));
.map(|user: User| ("public".into(), user.with_language_filter()));
// GET /api/v1/streaming/public?only_media=true [public; language filter]
let public_timeline_media = path!("api" / "v1" / "streaming" / "public")
@ -39,8 +42,8 @@ fn main() {
.and_then(|token| user::get_account(token, Scope::Public))
.and(warp::query())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => PubSub::from("public:media".into(), user.with_language_filter()),
_ => PubSub::from("public".into(), user.with_language_filter()),
"1" | "true" => ("public:media".into(), user.with_language_filter()),
_ => ("public".into(), user.with_language_filter()),
});
// GET /api/v1/streaming/public/local [public; language filter]
@ -48,7 +51,7 @@ fn main() {
.and(path::end())
.and(user::get_access_token(user::Scope::Public))
.and_then(|token| user::get_account(token, Scope::Public))
.map(|user: User| PubSub::from("public:local".into(), user.with_language_filter()));
.map(|user: User| ("public:local".into(), user.with_language_filter()));
// GET /api/v1/streaming/public/local?only_media=true [public; language filter]
let local_timeline_media = path!("api" / "v1" / "streaming" / "public" / "local")
@ -57,8 +60,8 @@ fn main() {
.and(warp::query())
.and(path::end())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => PubSub::from("public:local:media".into(), user.with_language_filter()),
_ => PubSub::from("public:local".into(), user.with_language_filter()),
"1" | "true" => ("public:local:media".into(), user.with_language_filter()),
_ => ("public:local".into(), user.with_language_filter()),
});
// GET /api/v1/streaming/direct [private; *no* filter]
@ -66,19 +69,22 @@ fn main() {
.and(path::end())
.and(user::get_access_token(Scope::Private))
.and_then(|token| user::get_account(token, Scope::Private))
.map(|user: User| PubSub::from(format!("direct:{}", user.id), user.with_no_filter()));
.map(|user: User| (format!("direct:{}", user.id), user.with_no_filter()));
// GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter]
let hashtag_timeline = path!("api" / "v1" / "streaming" / "hashtag")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| PubSub::from(format!("hashtag:{}", q.tag), User::public()));
.map(|q: query::Hashtag| {
dbg!(&q);
(format!("hashtag:{}", q.tag), User::public())
});
// GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter]
let hashtag_timeline_local = path!("api" / "v1" / "streaming" / "hashtag" / "local")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| PubSub::from(format!("hashtag:{}:local", q.tag), User::public()));
.map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public()));
// GET /api/v1/streaming/list?list=:list_id [private; no filter]
let list_timeline = path!("api" / "v1" / "streaming" / "list")
@ -88,8 +94,9 @@ fn main() {
.and_then(|user: User, q: query::List| (user.is_authorized_for_list(q.list), Ok(user)))
.untuple_one()
.and(path::end())
.map(|list: i64, user: User| PubSub::from(format!("list:{}", list), user.with_no_filter()));
.map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter()));
let event_stream = RedisStream::new();
let event_stream = warp::any().map(move || event_stream.clone());
let routes = or!(
user_timeline,
user_timeline_notifications,
@ -102,29 +109,75 @@ fn main() {
hashtag_timeline_local,
list_timeline
)
.and_then(|event_stream| event_stream)
.untuple_one()
.and(warp::sse())
.map(|event_stream: pubsub::Receiver, sse: warp::sse::Sse| {
let user = event_stream.user.clone();
sse.reply(warp::sse::keep(
event_stream.filter_map(move |item| {
let payload = item["payload"].clone();
let event = item["event"].to_string().clone();
let toot_lang = item["language"].to_string().clone();
println!("ding");
match &user.filter {
Filter::Notification if event != "notification" => None,
Filter::Language if !user.langs.contains(&toot_lang) => None,
_ => Some((warp::sse::event(event), warp::sse::data(payload))),
}
}),
None,
))
})
.and(event_stream)
.map(
|timeline: String, user: User, sse: warp::sse::Sse, mut event_stream: RedisStream| {
event_stream.add(timeline.clone(), user);
sse.reply(warp::sse::keep(
event_stream.filter_map(move |item| {
println!("ding");
Some((warp::sse::event("event"), warp::sse::data(item.to_string())))
}),
None,
))
},
)
.with(warp::reply::with::header("Connection", "keep-alive"))
.recover(error::handle_errors);
warp::serve(routes).run(([127, 0, 0, 1], 3030));
}
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[derive(Clone)]
struct RedisStream {
recv: Arc<Mutex<HashMap<String, pubsub::Receiver>>>,
current_stream: String,
}
impl RedisStream {
fn new() -> Self {
let recv = Arc::new(Mutex::new(HashMap::new()));
Self {
recv,
current_stream: "".to_string(),
}
}
fn add(&mut self, timeline: String, user: User) -> &Self {
let mut hash_map_of_streams = self.recv.lock().unwrap();
if !hash_map_of_streams.contains_key(&timeline) {
println!(
"First time encountering `{}`, saving it to the HashMap",
&timeline
);
hash_map_of_streams.insert(timeline.clone(), PubSub::from(timeline.clone(), user));
} else {
println!(
"HashMap already contains `{}`, returning unmodified HashMap",
&timeline
);
}
self.current_stream = timeline;
self
}
}
impl Stream for RedisStream {
type Item = Value;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
println!("polling Interval");
let mut hash_map_of_streams = self.recv.lock().unwrap();
let target_stream = self.current_stream.clone();
let stream = hash_map_of_streams.get_mut(&target_stream).unwrap();
match stream.poll() {
Ok(Async::Ready(Some(value))) => Ok(Async::Ready(Some(value))),
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => Err(e),
}
}
}

View File

@ -41,6 +41,7 @@ impl RedisCmd {
}
}
#[derive(Debug)]
pub struct Receiver {
rx: ReadHalf<TcpStream>,
tx: WriteHalf<TcpStream>,
@ -49,6 +50,7 @@ pub struct Receiver {
}
impl Receiver {
fn new(socket: TcpStream, tl: String, user: User) -> Self {
println!("created a new Receiver");
let (rx, mut tx) = socket.split();
tx.poll_write(RedisCmd::subscribe_to_timeline(&tl).as_bytes())
.expect("Can subscribe to Redis");
@ -86,30 +88,65 @@ impl Drop for Receiver {
}
}
use futures::sink::Sink;
use tokio::net::tcp::ConnectFuture;
struct Socket {
connect: ConnectFuture,
tx: tokio::sync::mpsc::Sender<TcpStream>,
}
impl Socket {
fn new(address: impl std::fmt::Display, tx: tokio::sync::mpsc::Sender<TcpStream>) -> Self {
let address = address
.to_string()
.parse()
.expect("Unable to parse address");
let connect = TcpStream::connect(&address);
Self { connect, tx }
}
}
impl Future for Socket {
type Item = ();
type Error = ();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.connect.poll() {
Ok(Async::Ready(socket)) => {
self.tx.clone().try_send(socket);
Ok(Async::Ready(()))
}
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => {
println!("failed to connect: {}", e);
Ok(Async::Ready(()))
}
}
}
}
pub struct PubSub {}
impl PubSub {
pub fn from(
timeline: impl std::fmt::Display,
user: User,
) -> impl Future<Item = Receiver, Error = warp::reject::Rejection> {
pub fn from(timeline: impl std::fmt::Display, user: User) -> Receiver {
while OPEN_CONNECTIONS.load(Ordering::Relaxed) > MAX_CONNECTIONS.load(Ordering::Relaxed) {
thread::sleep(time::Duration::from_millis(1000));
}
let new_connections = OPEN_CONNECTIONS.fetch_add(1, Ordering::Relaxed) + 1;
println!("{} connection(s) now open", new_connections);
let (tx, mut rx) = tokio::sync::mpsc::channel(5);
let socket = Socket::new("127.0.0.1:6379", tx);
tokio::spawn(futures::future::lazy(move || socket));
let socket = loop {
if let Ok(Async::Ready(Some(msg))) = rx.poll() {
break msg;
}
thread::sleep(time::Duration::from_millis(100));
};
let timeline = timeline.to_string();
fn get_socket() -> impl Future<Item = TcpStream, Error = Box<Error>> {
let address = "127.0.0.1:6379".parse().expect("Unable to parse address");
let connection = TcpStream::connect(&address);
connection.and_then(Ok).map_err(Box::new)
}
get_socket()
.and_then(move |socket| {
let stream_of_data_from_redis = Receiver::new(socket, timeline, user);
Ok(stream_of_data_from_redis)
})
.map_err(warp::reject::custom)
let stream_of_data_from_redis = Receiver::new(socket, timeline, user);
stream_of_data_from_redis
}
}

View File

@ -1,18 +1,18 @@
use serde_derive::Deserialize;
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
pub struct Media {
pub only_media: String,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
pub struct Hashtag {
pub tag: String,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
pub struct List {
pub list: i64,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
pub struct Auth {
pub access_token: String,
}

View File

@ -23,14 +23,14 @@ fn conn() -> postgres::Connection {
)
.unwrap()
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum Filter {
None,
Language,
Notification,
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct User {
pub id: i64,
pub langs: Vec<String>,