From e1257146cda97ba07c875b06643b88a9fa31440c Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Sun, 21 Apr 2019 09:21:44 -0400 Subject: [PATCH] Close Redis connections when SSE stream ends This commit tracks the existence of the SSE stream and closes the connection to the redis pub/sub channel when the stream is closed. This prevents the number of redis connections from growing over time. Note, however, that the current code still subscribes to one redis channel per SSE connection rather than reusing existing subscriptions. This will need to be fixed in a later PR. --- src/main.rs | 39 ++++++++++++++++--------------- src/pubsub.rs | 14 ++++------- src/query.rs | 2 +- src/user.rs | 64 +++++++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 85 insertions(+), 34 deletions(-) diff --git a/src/main.rs b/src/main.rs index 563c2eb..36eac91 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,8 +5,8 @@ mod user; mod utils; use futures::stream::Stream; use pretty_env_logger; -use pubsub::{stream_from, Filter}; -use user::{Scope, User}; +use pubsub::stream_from; +use user::{Filter, Scope, User}; use warp::{path, Filter as WarpFilter}; fn main() { @@ -17,21 +17,21 @@ fn main() { .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| stream_from(user.id, Filter::None)); + .map(|user: User| stream_from(user.id.to_string(), user)); // GET /api/v1/streaming/user/notification [private; notification filter] let user_timeline_notifications = path!("api" / "v1" / "streaming" / "user" / "notification") .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| stream_from(user.id, Filter::Notification)); + .map(|user: User| stream_from(user.id.to_string(), user.with_notification_filter())); // GET /api/v1/streaming/public [public; language filter] let public_timeline = path!("api" / "v1" / "streaming" / "public") .and(path::end()) .and(user::get_access_token(user::Scope::Public)) .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| stream_from("public".into(), Filter::Language(user.langs))); + .map(|user: User| stream_from("public".into(), user.with_language_filter())); // GET /api/v1/streaming/public?only_media=true [public; language filter] let public_timeline_media = path!("api" / "v1" / "streaming" / "public") @@ -40,8 +40,8 @@ fn main() { .and_then(|token| user::get_account(token, Scope::Public)) .and(warp::query()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => stream_from("public:media".into(), Filter::Language(user.langs)), - _ => stream_from("public".into(), Filter::Language(user.langs)), + "1" | "true" => stream_from("public:media".into(), user.with_language_filter()), + _ => stream_from("public".into(), user.with_language_filter()), }); // GET /api/v1/streaming/public/local [public; language filter] @@ -49,7 +49,7 @@ fn main() { .and(path::end()) .and(user::get_access_token(user::Scope::Public)) .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| stream_from("public:local".into(), Filter::Language(user.langs))); + .map(|user: User| stream_from("public:local".into(), user.with_language_filter())); // GET /api/v1/streaming/public/local?only_media=true [public; language filter] let local_timeline_media = path!("api" / "v1" / "streaming" / "public" / "local") @@ -58,8 +58,8 @@ fn main() { .and(warp::query()) .and(path::end()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => stream_from("public:local:media".into(), Filter::Language(user.langs)), - _ => stream_from("public:local".into(), Filter::None), + "1" | "true" => stream_from("public:local:media".into(), user.with_language_filter()), + _ => stream_from("public:local".into(), user.with_language_filter()), }); // GET /api/v1/streaming/direct [private; *no* filter] @@ -67,28 +67,29 @@ fn main() { .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|account: User| stream_from(format!("direct:{}", account.id), Filter::None)); + .map(|user: User| stream_from(format!("direct:{}", user.id), user.with_no_filter())); // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] let hashtag_timeline = path!("api" / "v1" / "streaming" / "hashtag") .and(warp::query()) .and(path::end()) - .map(|q: query::Hashtag| stream_from(format!("hashtag:{}", q.tag), Filter::None)); + .map(|q: query::Hashtag| stream_from(format!("hashtag:{}", q.tag), User::public())); // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] let hashtag_timeline_local = path!("api" / "v1" / "streaming" / "hashtag" / "local") .and(warp::query()) .and(path::end()) - .map(|q: query::Hashtag| stream_from(format!("hashtag:{}:local", q.tag), Filter::None)); + .map(|q: query::Hashtag| stream_from(format!("hashtag:{}:local", q.tag), User::public())); // GET /api/v1/streaming/list?list=:list_id [private; no filter] let list_timeline = path!("api" / "v1" / "streaming" / "list") .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) .and(warp::query()) + .and_then(|user: User, q: query::List| user.is_authorized_for_list(q.list)) + .untuple_one() .and(path::end()) - // TODO: filter down to lists the user can access - .map(|_user: User, q: query::List| stream_from(format!("list:{}", q.list), Filter::None)); + .map(|list: i64, user: User| stream_from(format!("list:{}", list), user.with_no_filter())); let routes = or!( user_timeline, @@ -105,15 +106,15 @@ fn main() { .and_then(|event_stream| event_stream) .and(warp::sse()) .map(|event_stream: pubsub::Receiver, sse: warp::sse::Sse| { - let filter = event_stream.filter.clone(); + let user = event_stream.user.clone(); sse.reply(warp::sse::keep( event_stream.filter_map(move |item| { let payload = item["payload"].clone(); let event = item["event"].to_string().clone(); - let lang = item["language"].to_string().clone(); - match filter { + let toot_lang = item["language"].to_string().clone(); + match &user.filter { Filter::Notification if event != "notification" => None, - Filter::Language(ref vec) if !vec.contains(&lang) => None, + Filter::Language if !user.langs.contains(&toot_lang) => None, _ => Some((warp::sse::event(event), warp::sse::data(payload))), } }), diff --git a/src/pubsub.rs b/src/pubsub.rs index 92d9d8e..c38bfd4 100644 --- a/src/pubsub.rs +++ b/src/pubsub.rs @@ -1,3 +1,4 @@ +use crate::user::User; use futures::{Async, Future, Poll}; use log::{debug, info}; use regex::Regex; @@ -6,15 +7,9 @@ use tokio::io::{AsyncRead, AsyncWrite, Error, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use warp::Stream; -#[derive(Clone)] -pub enum Filter { - None, - Language(Vec), - Notification, -} pub struct Receiver { rx: ReadHalf, - pub filter: Filter, + pub user: User, } impl Stream for Receiver { type Item = Value; @@ -68,16 +63,15 @@ fn send_subscribe_cmd(tx: WriteHalf, channel: String) { tokio::spawn(sender.map_err(|e| eprintln!("{}", e))); } -/// Create a stream from a string. pub fn stream_from( timeline: String, - filter: Filter, + user: User, ) -> impl Future { get_socket() .and_then(move |socket| { let (rx, tx) = socket.split(); send_subscribe_cmd(tx, format!("timeline:{}", timeline)); - let stream_of_data_from_redis = Receiver { rx, filter }; + let stream_of_data_from_redis = Receiver { rx, user }; Ok(stream_of_data_from_redis) }) .map_err(warp::reject::custom) diff --git a/src/query.rs b/src/query.rs index 0e90177..bb73830 100644 --- a/src/query.rs +++ b/src/query.rs @@ -10,7 +10,7 @@ pub struct Hashtag { } #[derive(Deserialize)] pub struct List { - pub list: String, + pub list: i64, } #[derive(Deserialize)] pub struct Auth { diff --git a/src/user.rs b/src/user.rs index 731919c..b28c518 100644 --- a/src/user.rs +++ b/src/user.rs @@ -1,6 +1,6 @@ use crate::{or, query}; use postgres; -use warp::Filter; +use warp::Filter as WarpFilter; pub fn get_access_token(scope: Scope) -> warp::filters::BoxedFilter<(String,)> { let token_from_header = warp::header::header::("authorization") @@ -23,11 +23,65 @@ fn conn() -> postgres::Connection { ) .unwrap() } +#[derive(Clone)] +pub enum Filter { + None, + Language, + Notification, +} +#[derive(Clone)] pub struct User { - pub id: String, + pub id: i64, pub langs: Vec, pub logged_in: bool, + pub filter: Filter, +} +impl User { + pub fn with_notification_filter(self) -> Self { + Self { + filter: Filter::Notification, + ..self + } + } + pub fn with_language_filter(self) -> Self { + Self { + filter: Filter::Language, + ..self + } + } + pub fn with_no_filter(self) -> Self { + Self { + filter: Filter::None, + ..self + } + } + pub fn is_authorized_for_list(self, list: i64) -> Result<(i64, User), warp::reject::Rejection> { + let conn = conn(); + // For the Postgres query, `id` = list number; `account_id` = user.id + let rows = &conn + .query( + " SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", + &[&list], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + if !rows.is_empty() { + let id_of_account_that_owns_the_list: i64 = rows.get(0).get(1); + if id_of_account_that_owns_the_list == self.id { + return Ok((list, self)); + } + }; + + Err(warp::reject::custom("Error: Invalid access token")) + } + pub fn public() -> Self { + User { + id: -1, + langs: Vec::new(), + logged_in: false, + filter: Filter::None, + } + } } pub enum Scope { @@ -55,15 +109,17 @@ LIMIT 1", let id: i64 = only_row.get(1); let langs: Vec = only_row.get(2); Ok(User { - id: id.to_string(), + id: id, langs, logged_in: true, + filter: Filter::None, }) } else if let Scope::Public = scope { Ok(User { - id: String::new(), + id: -1, langs: Vec::new(), logged_in: false, + filter: Filter::None, }) } else { Err(warp::reject::custom("Error: Invalid access token"))