diff --git a/src/lib.rs b/src/lib.rs index 4316d80..e29bb98 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,7 +35,7 @@ //! polls the `Receiver` and the frequency with which the `Receiver` polls Redis. //! -#![warn(clippy::pedantic)] +//#![warn(clippy::pedantic)] #![allow(clippy::try_err, clippy::match_bool)] pub mod config; diff --git a/src/main.rs b/src/main.rs index 6d25642..0a1e3c3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,11 +5,14 @@ use flodgatt::request::{PgPool, Subscription, Timeline}; use flodgatt::response::redis; use flodgatt::response::stream; +use futures::{future::lazy, stream::Stream as _Stream}; use std::fs; use std::net::SocketAddr; use std::os::unix::fs::PermissionsExt; +use std::time::Instant; use tokio::net::UnixListener; use tokio::sync::{mpsc, watch}; +use tokio::timer::Interval; use warp::http::StatusCode; use warp::path; use warp::ws::Ws2; @@ -18,28 +21,27 @@ use warp::{Filter, Rejection}; fn main() -> Result<(), FatalErr> { config::merge_dotenv()?; pretty_env_logger::try_init()?; - let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect()); + + // Create channels to communicate between threads let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); let shared_pg_conn = PgPool::new(postgres_cfg, *cfg.whitelist_mode); let poll_freq = *redis_cfg.polling_interval; - let manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc(); + let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc(); // Server Sent Events - let sse_manager = manager.clone(); + let sse_manager = shared_manager.clone(); let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone()); - let sse_routes = Subscription::from_sse_request(shared_pg_conn.clone()) + let sse = Subscription::from_sse_request(shared_pg_conn.clone()) .and(warp::sse()) .map( move |subscription: Subscription, client_conn: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); { let mut manager = sse_manager.lock().unwrap_or_else(redis::Manager::recover); - manager.subscribe(&subscription).unwrap_or_else(|e| { - log::error!("Could not subscribe to the Redis channel: {}", e) - }); + manager.subscribe(&subscription); } stream::Sse::send_events( @@ -53,29 +55,19 @@ fn main() -> Result<(), FatalErr> { .with(warp::reply::with::header("Connection", "keep-alive")); // WebSocket - let ws_manager = manager.clone(); - let ws_routes = Subscription::from_ws_request(shared_pg_conn) + let ws_manager = shared_manager.clone(); + let ws = Subscription::from_ws_request(shared_pg_conn) .and(warp::ws::ws2()) .map(move |subscription: Subscription, ws: Ws2| { log::info!("Incoming websocket request for {:?}", subscription.timeline); { let mut manager = ws_manager.lock().unwrap_or_else(redis::Manager::recover); - - manager.subscribe(&subscription).unwrap_or_else(|e| { - log::error!("Could not subscribe to the Redis channel: {}", e) - }); + manager.subscribe(&subscription); } - let cmd_tx = cmd_tx.clone(); - let ws_rx = event_rx.clone(); - let token = subscription - .clone() - .access_token - .unwrap_or_else(String::new); + let token = subscription.access_token.clone().unwrap_or_default(); // token sent for security + let ws_stream = stream::Ws::new(cmd_tx.clone(), event_rx.clone(), subscription); - let ws_response_stream = ws - .on_upgrade(move |ws| stream::Ws::new(ws, cmd_tx, subscription).send_events(ws_rx)); - - (ws_response_stream, token) + (ws.on_upgrade(move |ws| ws_stream.send_to(ws)), token) }) .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); @@ -84,9 +76,10 @@ fn main() -> Result<(), FatalErr> { .allow_methods(cfg.cors.allowed_methods) .allow_headers(cfg.cors.allowed_headers); + // TODO -- extract to separate file #[cfg(feature = "stub_status")] - let status_endpoints = { - let (r1, r3) = (manager.clone(), manager.clone()); + let status = { + let (r1, r3) = (shared_manager.clone(), shared_manager.clone()); warp::path!("api" / "v1" / "streaming" / "health") .map(|| "OK") .or(warp::path!("api" / "v1" / "streaming" / "status") @@ -98,54 +91,43 @@ fn main() -> Result<(), FatalErr> { ) }; #[cfg(not(feature = "stub_status"))] - let status_endpoints = warp::path!("api" / "v1" / "streaming" / "health").map(|| "OK"); + let status = warp::path!("api" / "v1" / "streaming" / "health").map(|| "OK"); + + let streaming_server = move || { + let manager = shared_manager.clone(); + let stream = Interval::new(Instant::now(), poll_freq) + .map_err(|e| log::error!("{}", e)) + .for_each(move |_| { + let mut manager = manager.lock().unwrap_or_else(redis::Manager::recover); + manager.poll_broadcast().unwrap_or_else(FatalErr::exit); + Ok(()) + }); + warp::spawn(lazy(move || stream)); + warp::serve(ws.or(sse).with(cors).or(status).recover(recover)) + }; if let Some(socket) = &*cfg.unix_socket { log::info!("Using Unix socket {}", socket); fs::remove_file(socket).unwrap_or_default(); - let incoming = UnixListener::bind(socket).unwrap().incoming(); - fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).unwrap(); + let incoming = UnixListener::bind(socket).expect("TODO").incoming(); + fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).expect("TODO"); - warp::serve( - ws_routes - .or(sse_routes) - .with(cors) - .or(status_endpoints) - .recover(|r: Rejection| { - let json_err = match r.cause() { - Some(text) - if text.to_string() == "Missing request header 'authorization'" => - { - warp::reply::json(&"Error: Missing access token".to_string()) - } - Some(text) => warp::reply::json(&text.to_string()), - None => warp::reply::json(&"Error: Nonexistant endpoint".to_string()), - }; - Ok(warp::reply::with_status(json_err, StatusCode::UNAUTHORIZED)) - }), - ) - .run_incoming(incoming); + tokio::run(lazy(|| streaming_server().serve_incoming(incoming))); } else { - use futures::{future::lazy, stream::Stream as _Stream}; - use std::time::Instant; - let server_addr = SocketAddr::new(*cfg.address, *cfg.port); - - tokio::run(lazy(move || { - let receiver = manager.clone(); - - warp::spawn(lazy(move || { - tokio::timer::Interval::new(Instant::now(), poll_freq) - .map_err(|e| log::error!("{}", e)) - .for_each(move |_| { - let mut receiver = receiver.lock().unwrap_or_else(redis::Manager::recover); - receiver.poll_broadcast().unwrap_or_else(FatalErr::exit); - Ok(()) - }) - })); - - warp::serve(ws_routes.or(sse_routes).with(cors).or(status_endpoints)).bind(server_addr) - })); - }; + tokio::run(lazy(move || streaming_server().bind(server_addr))); + } Ok(()) } + +// TODO -- extract to separate file +fn recover(r: Rejection) -> Result { + let json_err = match r.cause() { + Some(text) if text.to_string() == "Missing request header 'authorization'" => { + warp::reply::json(&"Error: Missing access token".to_string()) + } + Some(text) => warp::reply::json(&text.to_string()), + None => warp::reply::json(&"Error: Nonexistant endpoint".to_string()), + }; + Ok(warp::reply::with_status(json_err, StatusCode::UNAUTHORIZED)) +} diff --git a/src/request.rs b/src/request.rs index 83b48ed..9b5a28a 100644 --- a/src/request.rs +++ b/src/request.rs @@ -7,11 +7,60 @@ mod subscription; pub use self::postgres::PgPool; // TODO consider whether we can remove `Stream` from public API pub use subscription::{Blocks, Stream, Subscription, Timeline}; - -//#[cfg(test)] pub use subscription::{Content, Reach}; +use self::query::Query; +use crate::config; +use warp::{filters::BoxedFilter, path, reject::Rejection, Filter}; + #[cfg(test)] mod sse_test; #[cfg(test)] mod ws_test; + +pub struct Handler { + pg_conn: PgPool, +} + +impl Handler { + pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self { + Self { + pg_conn: PgPool::new(postgres_cfg, whitelist_mode), + } + } + + pub fn from_ws_request(&self) -> BoxedFilter<(Subscription,)> { + let pg_conn = self.pg_conn.clone(); + parse_ws_query() + .and(query::OptionalAccessToken::from_ws_header()) + .and_then(Query::update_access_token) + .and_then(move |q| Subscription::from_query(q, pg_conn.clone())) + .boxed() + } +} + +fn parse_ws_query() -> BoxedFilter<(Query,)> { + path!("api" / "v1" / "streaming") + .and(path::end()) + .and(warp::query()) + .and(query::Auth::to_filter()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .map( + |stream: query::Stream, + auth: query::Auth, + media: query::Media, + hashtag: query::Hashtag, + list: query::List| { + Query { + access_token: auth.access_token, + stream: stream.stream, + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + } + }, + ) + .boxed() +} diff --git a/src/request/subscription.rs b/src/request/subscription.rs index 2914eca..7fe5810 100644 --- a/src/request/subscription.rs +++ b/src/request/subscription.rs @@ -117,7 +117,7 @@ impl Subscription { .boxed() } - fn from_query(q: Query, pool: PgPool) -> Result { + pub(super) fn from_query(q: Query, pool: PgPool) -> Result { let user = pool.clone().select_user(&q.access_token)?; let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?; let hashtag_name = match timeline { diff --git a/src/response/redis/manager.rs b/src/response/redis/manager.rs index ba64a23..7ce3054 100644 --- a/src/response/redis/manager.rs +++ b/src/response/redis/manager.rs @@ -50,7 +50,7 @@ impl Manager { Arc::new(Mutex::new(self)) } - pub fn subscribe(&mut self, subscription: &Subscription) -> Result<()> { + pub fn subscribe(&mut self, subscription: &Subscription) { let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline); if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (tag, tl) { self.redis_connection.update_cache(hashtag, id); @@ -64,9 +64,10 @@ impl Manager { use RedisCmd::*; if *number_of_subscriptions == 1 { - self.redis_connection.send_cmd(Subscribe, &tl)? + self.redis_connection + .send_cmd(Subscribe, &tl) + .unwrap_or_else(|e| log::error!("Could not subscribe to the Redis channel: {}", e)); }; - Ok(()) } pub fn unsubscribe(&mut self, tl: Timeline) -> Result<()> { diff --git a/src/response/stream.rs b/src/response/stream.rs index 5baba9e..60a97d4 100644 --- a/src/response/stream.rs +++ b/src/response/stream.rs @@ -12,20 +12,31 @@ use warp::{ }; pub struct Ws { - ws_tx: mpsc::UnboundedSender, unsubscribe_tx: mpsc::UnboundedSender, subscription: Subscription, + ws_rx: watch::Receiver<(Timeline, Event)>, + ws_tx: Option>, } impl Ws { pub fn new( - ws: WebSocket, unsubscribe_tx: mpsc::UnboundedSender, + ws_rx: watch::Receiver<(Timeline, Event)>, subscription: Subscription, ) -> Self { + Self { + unsubscribe_tx, + subscription, + ws_rx, + ws_tx: None, + } + } + + pub fn send_to(mut self, ws: WebSocket) -> impl Future { let (transmit_to_ws, _receive_from_ws) = ws.split(); // Create a pipe let (ws_tx, ws_rx) = mpsc::unbounded_channel(); + self.ws_tx = Some(ws_tx); // Send one end of it to a different green thread and tell that end to forward // whatever it gets on to the WebSocket client @@ -39,20 +50,11 @@ impl Ws { _ => log::warn!("WebSocket send error: {}", e), }), ); - Self { - ws_tx, - unsubscribe_tx, - subscription, - } - } - pub fn send_events( - mut self, - event_rx: watch::Receiver<(Timeline, Event)>, - ) -> impl Future { let target_timeline = self.subscription.timeline; + let incoming_events = self.ws_rx.clone().map_err(|_| ()); - event_rx.map_err(|_| ()).for_each(move |(tl, event)| { + incoming_events.for_each(move |(tl, event)| { if matches!(event, Event::Ping) { self.send_ping() } else if target_timeline == tl { @@ -97,7 +99,7 @@ impl Ws { fn send_txt(&mut self, txt: &str) -> Result<(), ()> { let tl = self.subscription.timeline; - match self.ws_tx.try_send(Message::text(txt)) { + match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) { Ok(_) => Ok(()), Err(_) => { self.unsubscribe_tx