mirror of https://github.com/mastodon/flodgatt
Add ability for multiple clients to connect to the same pub/sub connection
This commit is contained in:
parent
425a9d0aae
commit
9e921c1c97
File diff suppressed because it is too large
Load Diff
|
@ -12,12 +12,11 @@ log = "0.4.6"
|
|||
actix = "0.7.9"
|
||||
actix-redis = "0.5.1"
|
||||
redis-async = "0.4.4"
|
||||
uuid = "0.7.2"
|
||||
envconfig = "0.5.0"
|
||||
envconfig_derive = "0.5.0"
|
||||
whoami = "0.4.1"
|
||||
futures = "0.1.26"
|
||||
tokio = "0.1.18"
|
||||
tokio = "0.1.19"
|
||||
warp = "0.1.15"
|
||||
regex = "1.1.5"
|
||||
serde_json = "1.0.39"
|
||||
|
@ -25,6 +24,7 @@ serde_derive = "1.0.90"
|
|||
serde = "1.0.90"
|
||||
pretty_env_logger = "0.3.0"
|
||||
postgres = "0.15.2"
|
||||
uuid = { version = "0.7", features = ["v4"] }
|
||||
|
||||
[features]
|
||||
default = [ "production" ]
|
||||
|
|
97
src/main.rs
97
src/main.rs
|
@ -1,13 +1,11 @@
|
|||
mod error;
|
||||
mod pubsub;
|
||||
mod query;
|
||||
mod stream;
|
||||
mod user;
|
||||
mod utils;
|
||||
use futures::stream::Stream;
|
||||
use futures::{Async, Poll};
|
||||
use pubsub::PubSub;
|
||||
use serde_json::Value;
|
||||
use std::io::Error;
|
||||
use stream::StreamManager;
|
||||
use user::{Filter, Scope, User};
|
||||
use warp::{path, Filter as WarpFilter};
|
||||
|
||||
|
@ -33,7 +31,7 @@ fn main() {
|
|||
.and(path::end())
|
||||
.and(user::get_access_token(user::Scope::Public))
|
||||
.and_then(|token| user::get_account(token, Scope::Public))
|
||||
.map(|user: User| ("public".into(), user.with_language_filter()));
|
||||
.map(|user: User| ("public".to_owned(), user.with_language_filter()));
|
||||
|
||||
// GET /api/v1/streaming/public?only_media=true [public; language filter]
|
||||
let public_timeline_media = path!("api" / "v1" / "streaming" / "public")
|
||||
|
@ -42,8 +40,8 @@ fn main() {
|
|||
.and_then(|token| user::get_account(token, Scope::Public))
|
||||
.and(warp::query())
|
||||
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
|
||||
"1" | "true" => ("public:media".into(), user.with_language_filter()),
|
||||
_ => ("public".into(), user.with_language_filter()),
|
||||
"1" | "true" => ("public:media".to_owned(), user.with_language_filter()),
|
||||
_ => ("public".to_owned(), user.with_language_filter()),
|
||||
});
|
||||
|
||||
// GET /api/v1/streaming/public/local [public; language filter]
|
||||
|
@ -51,7 +49,7 @@ fn main() {
|
|||
.and(path::end())
|
||||
.and(user::get_access_token(user::Scope::Public))
|
||||
.and_then(|token| user::get_account(token, Scope::Public))
|
||||
.map(|user: User| ("public:local".into(), user.with_language_filter()));
|
||||
.map(|user: User| ("public:local".to_owned(), user.with_language_filter()));
|
||||
|
||||
// GET /api/v1/streaming/public/local?only_media=true [public; language filter]
|
||||
let local_timeline_media = path!("api" / "v1" / "streaming" / "public" / "local")
|
||||
|
@ -60,8 +58,8 @@ fn main() {
|
|||
.and(warp::query())
|
||||
.and(path::end())
|
||||
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
|
||||
"1" | "true" => ("public:local:media".into(), user.with_language_filter()),
|
||||
_ => ("public:local".into(), user.with_language_filter()),
|
||||
"1" | "true" => ("public:local:media".to_owned(), user.with_language_filter()),
|
||||
_ => ("public:local".to_owned(), user.with_language_filter()),
|
||||
});
|
||||
|
||||
// GET /api/v1/streaming/direct [private; *no* filter]
|
||||
|
@ -75,10 +73,7 @@ fn main() {
|
|||
let hashtag_timeline = path!("api" / "v1" / "streaming" / "hashtag")
|
||||
.and(warp::query())
|
||||
.and(path::end())
|
||||
.map(|q: query::Hashtag| {
|
||||
dbg!(&q);
|
||||
(format!("hashtag:{}", q.tag), User::public())
|
||||
});
|
||||
.map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public()));
|
||||
|
||||
// GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter]
|
||||
let hashtag_timeline_local = path!("api" / "v1" / "streaming" / "hashtag" / "local")
|
||||
|
@ -95,8 +90,8 @@ fn main() {
|
|||
.untuple_one()
|
||||
.and(path::end())
|
||||
.map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter()));
|
||||
let event_stream = RedisStream::new();
|
||||
let event_stream = warp::any().map(move || event_stream.clone());
|
||||
|
||||
let redis_updates = StreamManager::new();
|
||||
let routes = or!(
|
||||
user_timeline,
|
||||
user_timeline_notifications,
|
||||
|
@ -111,14 +106,22 @@ fn main() {
|
|||
)
|
||||
.untuple_one()
|
||||
.and(warp::sse())
|
||||
.and(event_stream)
|
||||
.and(warp::any().map(move || redis_updates.new_copy()))
|
||||
.map(
|
||||
|timeline: String, user: User, sse: warp::sse::Sse, mut event_stream: RedisStream| {
|
||||
event_stream.add(timeline.clone(), user);
|
||||
|timeline: String, user: User, sse: warp::sse::Sse, mut event_stream: StreamManager| {
|
||||
event_stream.add(&timeline, &user);
|
||||
sse.reply(warp::sse::keep(
|
||||
event_stream.filter_map(move |item| {
|
||||
println!("ding");
|
||||
Some((warp::sse::event("event"), warp::sse::data(item.to_string())))
|
||||
let payload = item["payload"].clone();
|
||||
let event = item["event"].clone().to_string();
|
||||
let toot_lang = payload["language"].as_str().expect("redis str").to_string();
|
||||
let user_langs = user.langs.clone();
|
||||
|
||||
match (&user.filter, user_langs) {
|
||||
(Filter::Notification, _) if event != "notification" => None,
|
||||
(Filter::Language, Some(ref langs)) if !langs.contains(&toot_lang) => None,
|
||||
_ => Some((warp::sse::event(event), warp::sse::data(payload))),
|
||||
}
|
||||
}),
|
||||
None,
|
||||
))
|
||||
|
@ -129,55 +132,3 @@ fn main() {
|
|||
|
||||
warp::serve(routes).run(([127, 0, 0, 1], 3030));
|
||||
}
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
#[derive(Clone)]
|
||||
struct RedisStream {
|
||||
recv: Arc<Mutex<HashMap<String, pubsub::Receiver>>>,
|
||||
current_stream: String,
|
||||
}
|
||||
impl RedisStream {
|
||||
fn new() -> Self {
|
||||
let recv = Arc::new(Mutex::new(HashMap::new()));
|
||||
Self {
|
||||
recv,
|
||||
current_stream: "".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn add(&mut self, timeline: String, user: User) -> &Self {
|
||||
let mut hash_map_of_streams = self.recv.lock().unwrap();
|
||||
if !hash_map_of_streams.contains_key(&timeline) {
|
||||
println!(
|
||||
"First time encountering `{}`, saving it to the HashMap",
|
||||
&timeline
|
||||
);
|
||||
hash_map_of_streams.insert(timeline.clone(), PubSub::from(timeline.clone(), user));
|
||||
} else {
|
||||
println!(
|
||||
"HashMap already contains `{}`, returning unmodified HashMap",
|
||||
&timeline
|
||||
);
|
||||
}
|
||||
self.current_stream = timeline;
|
||||
self
|
||||
}
|
||||
}
|
||||
impl Stream for RedisStream {
|
||||
type Item = Value;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
println!("polling Interval");
|
||||
let mut hash_map_of_streams = self.recv.lock().unwrap();
|
||||
let target_stream = self.current_stream.clone();
|
||||
let stream = hash_map_of_streams.get_mut(&target_stream).unwrap();
|
||||
match stream.poll() {
|
||||
Ok(Async::Ready(Some(value))) => Ok(Async::Ready(Some(value))),
|
||||
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,18 +1,16 @@
|
|||
use crate::stream;
|
||||
use crate::user::User;
|
||||
use futures::{Async, Future, Poll};
|
||||
use log::{debug, info};
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use log::info;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::{thread, time};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, Error, ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use warp::Stream;
|
||||
|
||||
static OPEN_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
|
||||
static MAX_CONNECTIONS: AtomicUsize = AtomicUsize::new(400);
|
||||
pub static OPEN_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
|
||||
pub static MAX_CONNECTIONS: AtomicUsize = AtomicUsize::new(400);
|
||||
|
||||
struct RedisCmd {
|
||||
pub struct RedisCmd {
|
||||
resp_cmd: String,
|
||||
}
|
||||
impl RedisCmd {
|
||||
|
@ -27,13 +25,13 @@ impl RedisCmd {
|
|||
);
|
||||
Self { resp_cmd }
|
||||
}
|
||||
fn subscribe_to_timeline(timeline: &str) -> String {
|
||||
pub fn subscribe_to_timeline(timeline: &str) -> String {
|
||||
let channel = format!("timeline:{}", timeline);
|
||||
let subscribe = RedisCmd::new("subscribe", &channel);
|
||||
info!("Subscribing to {}", &channel);
|
||||
subscribe.resp_cmd
|
||||
}
|
||||
fn unsubscribe_from_timeline(timeline: &str) -> String {
|
||||
pub fn unsubscribe_from_timeline(timeline: &str) -> String {
|
||||
let channel = format!("timeline:{}", timeline);
|
||||
let unsubscribe = RedisCmd::new("unsubscribe", &channel);
|
||||
info!("Unsubscribing from {}", &channel);
|
||||
|
@ -41,54 +39,6 @@ impl RedisCmd {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Receiver {
|
||||
rx: ReadHalf<TcpStream>,
|
||||
tx: WriteHalf<TcpStream>,
|
||||
tl: String,
|
||||
pub user: User,
|
||||
}
|
||||
impl Receiver {
|
||||
fn new(socket: TcpStream, tl: String, user: User) -> Self {
|
||||
println!("created a new Receiver");
|
||||
let (rx, mut tx) = socket.split();
|
||||
tx.poll_write(RedisCmd::subscribe_to_timeline(&tl).as_bytes())
|
||||
.expect("Can subscribe to Redis");
|
||||
Self { rx, tx, tl, user }
|
||||
}
|
||||
}
|
||||
impl Stream for Receiver {
|
||||
type Item = Value;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Value>, Self::Error> {
|
||||
let mut buffer = vec![0u8; 3000];
|
||||
if let Async::Ready(num_bytes_read) = self.rx.poll_read(&mut buffer)? {
|
||||
// capture everything between `{` and `}` as potential JSON
|
||||
let re = Regex::new(r"(?P<json>\{.*\})").expect("Valid hard-coded regex");
|
||||
|
||||
if let Some(cap) = re.captures(&String::from_utf8_lossy(&buffer[..num_bytes_read])) {
|
||||
debug!("{}", &cap["json"]);
|
||||
let json: Value = serde_json::from_str(&cap["json"].to_string().clone())?;
|
||||
return Ok(Async::Ready(Some(json)));
|
||||
}
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
impl Drop for Receiver {
|
||||
fn drop(&mut self) {
|
||||
let channel = format!("timeline:{}", self.tl);
|
||||
self.tx
|
||||
.poll_write(RedisCmd::unsubscribe_from_timeline(&channel).as_bytes())
|
||||
.expect("Can unsubscribe from Redis");
|
||||
let open_connections = OPEN_CONNECTIONS.fetch_sub(1, Ordering::Relaxed) - 1;
|
||||
info!("Receiver dropped. {} connection(s) open", open_connections);
|
||||
}
|
||||
}
|
||||
|
||||
use futures::sink::Sink;
|
||||
use tokio::net::tcp::ConnectFuture;
|
||||
struct Socket {
|
||||
connect: ConnectFuture,
|
||||
|
@ -111,12 +61,12 @@ impl Future for Socket {
|
|||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
match self.connect.poll() {
|
||||
Ok(Async::Ready(socket)) => {
|
||||
self.tx.clone().try_send(socket);
|
||||
self.tx.clone().try_send(socket).expect("Socket created");
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(e) => {
|
||||
println!("failed to connect: {}", e);
|
||||
info!("failed to connect: {}", e);
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
}
|
||||
|
@ -126,12 +76,12 @@ impl Future for Socket {
|
|||
pub struct PubSub {}
|
||||
|
||||
impl PubSub {
|
||||
pub fn from(timeline: impl std::fmt::Display, user: User) -> Receiver {
|
||||
pub fn from(timeline: impl std::fmt::Display, user: &User) -> stream::Receiver {
|
||||
while OPEN_CONNECTIONS.load(Ordering::Relaxed) > MAX_CONNECTIONS.load(Ordering::Relaxed) {
|
||||
thread::sleep(time::Duration::from_millis(1000));
|
||||
}
|
||||
let new_connections = OPEN_CONNECTIONS.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
println!("{} connection(s) now open", new_connections);
|
||||
info!("{} connection(s) now open", new_connections);
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::channel(5);
|
||||
let socket = Socket::new("127.0.0.1:6379", tx);
|
||||
|
@ -146,7 +96,7 @@ impl PubSub {
|
|||
};
|
||||
|
||||
let timeline = timeline.to_string();
|
||||
let stream_of_data_from_redis = Receiver::new(socket, timeline, user);
|
||||
let stream_of_data_from_redis = stream::Receiver::new(socket, timeline, user);
|
||||
stream_of_data_from_redis
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
use crate::pubsub;
|
||||
use crate::pubsub::PubSub;
|
||||
use crate::user::User;
|
||||
use futures::stream::Stream;
|
||||
use futures::{Async, Poll};
|
||||
use log::info;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, Error, ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct StreamManager {
|
||||
recv: Arc<Mutex<HashMap<String, Receiver>>>,
|
||||
last_polled: Arc<Mutex<HashMap<String, Instant>>>,
|
||||
current_stream: String,
|
||||
id: uuid::Uuid,
|
||||
}
|
||||
impl StreamManager {
|
||||
pub fn new() -> Self {
|
||||
StreamManager {
|
||||
recv: Arc::new(Mutex::new(HashMap::new())),
|
||||
last_polled: Arc::new(Mutex::new(HashMap::new())),
|
||||
current_stream: String::new(),
|
||||
id: Uuid::new_v4(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_copy(&self) -> Self {
|
||||
let id = Uuid::new_v4();
|
||||
StreamManager { id, ..self.clone() }
|
||||
}
|
||||
|
||||
pub fn add(&mut self, timeline: &String, user: &User) -> &Self {
|
||||
let mut streams = self.recv.lock().expect("No other thread panic");
|
||||
streams
|
||||
.entry(timeline.clone())
|
||||
.or_insert_with(|| PubSub::from(&timeline, &user));
|
||||
let mut last_polled = self.last_polled.lock().expect("No other thread panic");
|
||||
last_polled.insert(timeline.clone(), Instant::now());
|
||||
|
||||
// Drop any streams that haven't been polled in the last 30 seconds
|
||||
last_polled
|
||||
.clone()
|
||||
.iter()
|
||||
.filter(|(_, time)| time.elapsed().as_secs() > 30)
|
||||
.for_each(|(key, _)| {
|
||||
last_polled.remove(key);
|
||||
streams.remove(key);
|
||||
});
|
||||
|
||||
self.current_stream = timeline.clone();
|
||||
self
|
||||
}
|
||||
}
|
||||
impl Stream for StreamManager {
|
||||
type Item = Value;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
let mut last_polled = self.last_polled.lock().expect("No other thread panic");
|
||||
let target_stream = self.current_stream.clone();
|
||||
last_polled.insert(target_stream.clone(), Instant::now());
|
||||
|
||||
let mut streams = self.recv.lock().expect("No other thread panic");
|
||||
let shared_conn = streams.get_mut(&target_stream).expect("known key");
|
||||
shared_conn.set_polled_by(self.id);
|
||||
|
||||
match shared_conn.poll() {
|
||||
Ok(Async::Ready(Some(value))) => Ok(Async::Ready(Some(value))),
|
||||
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Receiver {
|
||||
rx: ReadHalf<TcpStream>,
|
||||
tx: WriteHalf<TcpStream>,
|
||||
tl: String,
|
||||
pub user: User,
|
||||
polled_by: Uuid,
|
||||
msg_queue: HashMap<Uuid, VecDeque<Value>>,
|
||||
}
|
||||
impl Receiver {
|
||||
pub fn new(socket: TcpStream, tl: String, user: &User) -> Self {
|
||||
let (rx, mut tx) = socket.split();
|
||||
tx.poll_write(pubsub::RedisCmd::subscribe_to_timeline(&tl).as_bytes())
|
||||
.expect("Can subscribe to Redis");
|
||||
Self {
|
||||
rx,
|
||||
tx,
|
||||
tl,
|
||||
user: user.clone(),
|
||||
polled_by: Uuid::new_v4(),
|
||||
msg_queue: HashMap::new(),
|
||||
}
|
||||
}
|
||||
pub fn set_polled_by(&mut self, id: Uuid) -> &Self {
|
||||
self.polled_by = id;
|
||||
self
|
||||
}
|
||||
}
|
||||
impl Stream for Receiver {
|
||||
type Item = Value;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Value>, Self::Error> {
|
||||
let mut buffer = vec![0u8; 3000];
|
||||
let polled_by = self.polled_by;
|
||||
self.msg_queue.entry(polled_by).or_insert(VecDeque::new());
|
||||
info!("Being polled by StreamManager with uuid: {}", polled_by);
|
||||
if let Async::Ready(num_bytes_read) = self.rx.poll_read(&mut buffer)? {
|
||||
// capture everything between `{` and `}` as potential JSON
|
||||
// TODO: figure out if `(?x)` is needed
|
||||
let re = Regex::new(r"(?P<json>\{.*\})").expect("Valid hard-coded regex");
|
||||
|
||||
if let Some(cap) = re.captures(&String::from_utf8_lossy(&buffer[..num_bytes_read])) {
|
||||
let json: Value = serde_json::from_str(&cap["json"].to_string().clone())?;
|
||||
for value in self.msg_queue.values_mut() {
|
||||
value.push_back(json.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(value) = self.msg_queue.entry(polled_by).or_default().pop_front() {
|
||||
Ok(Async::Ready(Some(value)))
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Drop for Receiver {
|
||||
fn drop(&mut self) {
|
||||
let channel = format!("timeline:{}", self.tl);
|
||||
self.tx
|
||||
.poll_write(pubsub::RedisCmd::unsubscribe_from_timeline(&channel).as_bytes())
|
||||
.expect("Can unsubscribe from Redis");
|
||||
let open_connections = pubsub::OPEN_CONNECTIONS.fetch_sub(1, Ordering::Relaxed) - 1;
|
||||
info!("Receiver dropped. {} connection(s) open", open_connections);
|
||||
}
|
||||
}
|
10
src/user.rs
10
src/user.rs
|
@ -33,7 +33,7 @@ pub enum Filter {
|
|||
#[derive(Clone, Debug)]
|
||||
pub struct User {
|
||||
pub id: i64,
|
||||
pub langs: Vec<String>,
|
||||
pub langs: Option<Vec<String>>,
|
||||
pub logged_in: bool,
|
||||
pub filter: Filter,
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ impl User {
|
|||
pub fn public() -> Self {
|
||||
User {
|
||||
id: -1,
|
||||
langs: Vec::new(),
|
||||
langs: None,
|
||||
logged_in: false,
|
||||
filter: Filter::None,
|
||||
}
|
||||
|
@ -107,7 +107,8 @@ LIMIT 1",
|
|||
if !result.is_empty() {
|
||||
let only_row = result.get(0);
|
||||
let id: i64 = only_row.get(1);
|
||||
let langs: Vec<String> = only_row.get(2);
|
||||
let langs: Option<Vec<String>> = only_row.get(2);
|
||||
println!("Granting logged-in access");
|
||||
Ok(User {
|
||||
id,
|
||||
langs,
|
||||
|
@ -115,9 +116,10 @@ LIMIT 1",
|
|||
filter: Filter::None,
|
||||
})
|
||||
} else if let Scope::Public = scope {
|
||||
println!("Granting public access");
|
||||
Ok(User {
|
||||
id: -1,
|
||||
langs: Vec::new(),
|
||||
langs: None,
|
||||
logged_in: false,
|
||||
filter: Filter::None,
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue