diff --git a/src/main.rs b/src/main.rs index 19d47dc..afd8b26 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,6 +22,7 @@ fn main() -> Result<(), Error> { // Create channels to communicate between threads let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); + let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); let request = Handler::new(&postgres_cfg, *cfg.whitelist_mode)?; @@ -36,7 +37,8 @@ fn main() -> Result<(), Error> { .map(move |subscription: Subscription, sse: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); let mut manager = sse_manager.lock().unwrap_or_else(RedisManager::recover); - manager.subscribe(&subscription); + let (event_tx_2, _event_rx_2) = mpsc::unbounded_channel(); + manager.subscribe(&subscription, event_tx_2); SseStream::send_events(sse, sse_cmd_tx.clone(), subscription, sse_rx.clone()) }) @@ -50,11 +52,15 @@ fn main() -> Result<(), Error> { .map(move |subscription: Subscription, ws: Ws2| { log::info!("Incoming websocket request for {:?}", subscription.timeline); let mut manager = ws_manager.lock().unwrap_or_else(RedisManager::recover); - manager.subscribe(&subscription); + let (event_tx_2, event_rx_2) = mpsc::unbounded_channel(); + manager.subscribe(&subscription, event_tx_2); let token = subscription.access_token.clone().unwrap_or_default(); // token sent for security - let ws_stream = WsStream::new(cmd_tx.clone(), event_rx.clone(), subscription); + let ws_stream = WsStream::new(cmd_tx.clone(), subscription); - (ws.on_upgrade(move |ws| ws_stream.send_to(ws)), token) + ( + ws.on_upgrade(move |ws| ws_stream.send_to(ws, event_rx_2)), + token, + ) }) .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); diff --git a/src/response/redis/manager.rs b/src/response/redis/manager.rs index b32eef2..4033aec 100644 --- a/src/response/redis/manager.rs +++ b/src/response/redis/manager.rs @@ -25,6 +25,7 @@ pub struct Manager { redis_connection: RedisConn, clients_per_timeline: HashMap, tx: watch::Sender<(Timeline, Event)>, + timelines: HashMap>>, rx: mpsc::UnboundedReceiver, ping_time: Instant, } @@ -40,6 +41,7 @@ impl Manager { Ok(Self { redis_connection: RedisConn::new(redis_cfg)?, clients_per_timeline: HashMap::new(), + timelines: HashMap::new(), tx, rx, ping_time: Instant::now(), @@ -50,12 +52,21 @@ impl Manager { Arc::new(Mutex::new(self)) } - pub fn subscribe(&mut self, subscription: &Subscription) { + pub fn subscribe( + &mut self, + subscription: &Subscription, + channel: mpsc::UnboundedSender, + ) { let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline); if let (Some(hashtag), Some(id)) = (tag, tl.tag()) { self.redis_connection.update_cache(hashtag, id); }; + self.timelines + .entry(tl) + .and_modify(|vec| vec.push(channel.clone())) + .or_insert_with(|| vec![channel]); + let number_of_subscriptions = self .clients_per_timeline .entry(tl) @@ -70,7 +81,16 @@ impl Manager { }; } - pub(crate) fn unsubscribe(&mut self, tl: Timeline) -> Result<()> { + pub(crate) fn unsubscribe( + &mut self, + tl: Timeline, + _target_channel: mpsc::UnboundedSender, + ) -> Result<()> { + let channels = self.timelines.get(&tl).expect("TODO"); + for (_i, _channel) in channels.iter().enumerate() { + // TODO - find alternate implementation + } + let number_of_subscriptions = self .clients_per_timeline .entry(tl) @@ -92,22 +112,40 @@ impl Manager { } pub fn poll_broadcast(&mut self) -> Result<()> { - while let Ok(Async::Ready(Some(tl))) = self.rx.poll() { - self.unsubscribe(tl)? - } - + // while let Ok(Async::Ready(Some(tl))) = self.rx.poll() { + // self.unsubscribe(tl)? + // } + let mut completed_timelines = Vec::new(); if self.ping_time.elapsed() > Duration::from_secs(30) { self.ping_time = Instant::now(); - self.tx.broadcast((Timeline::empty(), Event::Ping))? - } else { - match self.redis_connection.poll_redis() { - Ok(Async::NotReady) | Ok(Async::Ready(None)) => (), // None = cmd or msg for other namespace - Ok(Async::Ready(Some((timeline, event)))) => { - self.tx.broadcast((timeline, event))? + for (timeline, channels) in self.timelines.iter_mut() { + for channel in channels.iter_mut() { + match channel.try_send(Event::Ping) { + Ok(_) => (), + Err(_) => completed_timelines.push((*timeline, channel.clone())), + } } + } + }; + loop { + match self.redis_connection.poll_redis() { + Ok(Async::NotReady) => break, + Ok(Async::Ready(Some((timeline, event)))) => { + for channel in self.timelines.get_mut(&timeline).ok_or(Error::InvalidId)? { + match channel.try_send(event.clone()) { + Ok(_) => (), + Err(_) => completed_timelines.push((timeline, channel.clone())), + } + } + } + Ok(Async::Ready(None)) => (), // None = cmd or msg for other namespace Err(err) => log::error!("{}", err), // drop msg, log err, and proceed } } + + for (tl, channel) in completed_timelines { + self.unsubscribe(tl, channel)?; + } Ok(()) } diff --git a/src/response/redis/manager/err.rs b/src/response/redis/manager/err.rs index 8d6d4f2..4ef4653 100644 --- a/src/response/redis/manager/err.rs +++ b/src/response/redis/manager/err.rs @@ -11,6 +11,7 @@ pub enum Error { RedisParseErr(RedisParseErr), RedisConnErr(RedisConnErr), ChannelSendErr(tokio::sync::watch::error::SendError<(Timeline, Event)>), + ChannelSendErr2(tokio::sync::mpsc::error::UnboundedTrySendError), } impl std::error::Error for Error {} @@ -28,6 +29,7 @@ impl fmt::Display for Error { RedisConnErr(inner) => write!(f, "{}", inner), TimelineErr(inner) => write!(f, "{}", inner), ChannelSendErr(inner) => write!(f, "{}", inner), + ChannelSendErr2(inner) => write!(f, "{}", inner), }?; Ok(()) } @@ -38,6 +40,11 @@ impl From> for Error { Self::ChannelSendErr(error) } } +impl From> for Error { + fn from(error: tokio::sync::mpsc::error::UnboundedTrySendError) -> Self { + Self::ChannelSendErr2(error) + } +} impl From for Error { fn from(error: EventErr) -> Self { diff --git a/src/response/stream/ws.rs b/src/response/stream/ws.rs index 4485409..595c7bb 100644 --- a/src/response/stream/ws.rs +++ b/src/response/stream/ws.rs @@ -2,7 +2,7 @@ use super::{Event, Payload}; use crate::request::{Subscription, Timeline}; use futures::{future::Future, stream::Stream}; -use tokio::sync::{mpsc, watch}; +use tokio::sync::mpsc; use warp::ws::{Message, WebSocket}; type Result = std::result::Result; @@ -10,25 +10,27 @@ type Result = std::result::Result; pub struct Ws { unsubscribe_tx: mpsc::UnboundedSender, subscription: Subscription, - ws_rx: watch::Receiver<(Timeline, Event)>, ws_tx: Option>, } impl Ws { pub fn new( unsubscribe_tx: mpsc::UnboundedSender, - ws_rx: watch::Receiver<(Timeline, Event)>, subscription: Subscription, ) -> Self { Self { unsubscribe_tx, subscription, - ws_rx, + ws_tx: None, } } - pub fn send_to(mut self, ws: WebSocket) -> impl Future { + pub fn send_to( + mut self, + ws: WebSocket, + incoming_events: mpsc::UnboundedReceiver, + ) -> impl Future { let (transmit_to_ws, _receive_from_ws) = ws.split(); // Create a pipe let (ws_tx, ws_rx) = mpsc::unbounded_channel(); @@ -49,29 +51,25 @@ impl Ws { }), ); - let target_timeline = self.subscription.timeline; - let incoming_events = self.ws_rx.clone().map_err(|_| ()); - - incoming_events.for_each(move |(tl, event)| { - //TODO log::info!("{:?}, {:?}", &tl, &event); + incoming_events.map_err(|_| ()).for_each(move |event| { if matches!(event, Event::Ping) { self.send_msg(&event)? - } else if target_timeline == tl { + } else { match (event.update_payload(), event.dyn_update_payload()) { - (Some(update), _) => self.send_or_filter(tl, &event, update)?, + (Some(update), _) => self.send_or_filter(&event, update)?, (None, None) => self.send_msg(&event)?, // send all non-updates - (_, Some(dyn_update)) => self.send_or_filter(tl, &event, dyn_update)?, + (_, Some(dyn_update)) => self.send_or_filter(&event, dyn_update)?, } } Ok(()) }) } - fn send_or_filter(&mut self, tl: Timeline, event: &Event, update: &impl Payload) -> Result<()> { + fn send_or_filter(&mut self, event: &Event, update: &impl Payload) -> Result<()> { let (blocks, allowed_langs) = (&self.subscription.blocks, &self.subscription.allowed_langs); const SKIP: Result<()> = Ok(()); - match tl { + match self.subscription.timeline { tl if tl.is_public() && !update.language_unset() && !allowed_langs.is_empty()