diff --git a/src/main.rs b/src/main.rs index 44192fa..fa218a8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,8 +4,7 @@ mod query; mod user; mod utils; use futures::stream::Stream; -use pretty_env_logger; -use pubsub::stream_from; +use pubsub::PubSub; use user::{Filter, Scope, User}; use warp::{path, Filter as WarpFilter}; @@ -17,21 +16,21 @@ fn main() { .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| stream_from(user.id.to_string(), user)); + .map(|user: User| PubSub::from(user.id.to_string(), user)); // GET /api/v1/streaming/user/notification [private; notification filter] let user_timeline_notifications = path!("api" / "v1" / "streaming" / "user" / "notification") .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| stream_from(user.id.to_string(), user.with_notification_filter())); + .map(|user: User| PubSub::from(user.id.to_string(), user.with_notification_filter())); // GET /api/v1/streaming/public [public; language filter] let public_timeline = path!("api" / "v1" / "streaming" / "public") .and(path::end()) .and(user::get_access_token(user::Scope::Public)) .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| stream_from("public".into(), user.with_language_filter())); + .map(|user: User| PubSub::from("public".into(), user.with_language_filter())); // GET /api/v1/streaming/public?only_media=true [public; language filter] let public_timeline_media = path!("api" / "v1" / "streaming" / "public") @@ -40,8 +39,8 @@ fn main() { .and_then(|token| user::get_account(token, Scope::Public)) .and(warp::query()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => stream_from("public:media".into(), user.with_language_filter()), - _ => stream_from("public".into(), user.with_language_filter()), + "1" | "true" => PubSub::from("public:media".into(), user.with_language_filter()), + _ => PubSub::from("public".into(), user.with_language_filter()), }); // GET /api/v1/streaming/public/local [public; language filter] @@ -49,7 +48,7 @@ fn main() { .and(path::end()) .and(user::get_access_token(user::Scope::Public)) .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| stream_from("public:local".into(), user.with_language_filter())); + .map(|user: User| PubSub::from("public:local".into(), user.with_language_filter())); // GET /api/v1/streaming/public/local?only_media=true [public; language filter] let local_timeline_media = path!("api" / "v1" / "streaming" / "public" / "local") @@ -58,8 +57,8 @@ fn main() { .and(warp::query()) .and(path::end()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => stream_from("public:local:media".into(), user.with_language_filter()), - _ => stream_from("public:local".into(), user.with_language_filter()), + "1" | "true" => PubSub::from("public:local:media".into(), user.with_language_filter()), + _ => PubSub::from("public:local".into(), user.with_language_filter()), }); // GET /api/v1/streaming/direct [private; *no* filter] @@ -67,29 +66,29 @@ fn main() { .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| stream_from(format!("direct:{}", user.id), user.with_no_filter())); + .map(|user: User| PubSub::from(format!("direct:{}", user.id), user.with_no_filter())); // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] let hashtag_timeline = path!("api" / "v1" / "streaming" / "hashtag") .and(warp::query()) .and(path::end()) - .map(|q: query::Hashtag| stream_from(format!("hashtag:{}", q.tag), User::public())); + .map(|q: query::Hashtag| PubSub::from(format!("hashtag:{}", q.tag), User::public())); // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] let hashtag_timeline_local = path!("api" / "v1" / "streaming" / "hashtag" / "local") .and(warp::query()) .and(path::end()) - .map(|q: query::Hashtag| stream_from(format!("hashtag:{}:local", q.tag), User::public())); + .map(|q: query::Hashtag| PubSub::from(format!("hashtag:{}:local", q.tag), User::public())); // GET /api/v1/streaming/list?list=:list_id [private; no filter] let list_timeline = path!("api" / "v1" / "streaming" / "list") .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) .and(warp::query()) - .and_then(|user: User, q: query::List| user.is_authorized_for_list(q.list)) + .and_then(|user: User, q: query::List| (user.is_authorized_for_list(q.list), Ok(user))) .untuple_one() .and(path::end()) - .map(|list: i64, user: User| stream_from(format!("list:{}", list), user.with_no_filter())); + .map(|list: i64, user: User| PubSub::from(format!("list:{}", list), user.with_no_filter())); let routes = or!( user_timeline, diff --git a/src/pubsub.rs b/src/pubsub.rs index c2eb787..562d992 100644 --- a/src/pubsub.rs +++ b/src/pubsub.rs @@ -3,34 +3,56 @@ use futures::{Async, Future, Poll}; use log::{debug, info}; use regex::Regex; use serde_json::Value; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{thread, time}; use tokio::io::{AsyncRead, AsyncWrite, Error, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use warp::Stream; +static OPEN_CONNECTIONS: AtomicUsize = AtomicUsize::new(0); +static MAX_CONNECTIONS: AtomicUsize = AtomicUsize::new(400); + +struct RedisCmd { + resp_cmd: String, +} +impl RedisCmd { + fn new(cmd: impl std::fmt::Display, arg: impl std::fmt::Display) -> Self { + let (cmd, arg) = (cmd.to_string(), arg.to_string()); + let resp_cmd = format!( + "*2\r\n${cmd_length}\r\n{cmd}\r\n${arg_length}\r\n{arg}\r\n", + cmd_length = cmd.len(), + cmd = cmd, + arg_length = arg.len(), + arg = arg + ); + Self { resp_cmd } + } + fn subscribe_to_timeline(timeline: &str) -> String { + let channel = format!("timeline:{}", timeline); + let subscribe = RedisCmd::new("subscribe", &channel); + info!("Subscribing to {}", &channel); + subscribe.resp_cmd + } + fn unsubscribe_from_timeline(timeline: &str) -> String { + let channel = format!("timeline:{}", timeline); + let unsubscribe = RedisCmd::new("unsubscribe", &channel); + info!("Unsubscribing from {}", &channel); + unsubscribe.resp_cmd + } +} + pub struct Receiver { rx: ReadHalf, tx: WriteHalf, - timeline: String, + tl: String, pub user: User, } impl Receiver { - fn new(socket: TcpStream, timeline: String, user: User) -> Self { + fn new(socket: TcpStream, tl: String, user: User) -> Self { let (rx, mut tx) = socket.split(); - let channel = format!("timeline:{}", timeline); - info!("Subscribing to {}", &channel); - let subscribe_cmd = format!( - "*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", - channel.len(), - channel - ); - let buffer = subscribe_cmd.as_bytes(); - tx.poll_write(&buffer).unwrap(); - Self { - rx, - tx, - timeline, - user, - } + tx.poll_write(RedisCmd::subscribe_to_timeline(&tl).as_bytes()) + .expect("Can subscribe to Redis"); + Self { rx, tx, tl, user } } } impl Stream for Receiver { @@ -40,12 +62,12 @@ impl Stream for Receiver { fn poll(&mut self) -> Poll, Self::Error> { let mut buffer = vec![0u8; 3000]; if let Async::Ready(num_bytes_read) = self.rx.poll_read(&mut buffer)? { - let re = Regex::new(r"(?x)(?P\{.*\})").unwrap(); + // capture everything between `{` and `}` as potential JSON + let re = Regex::new(r"(?P\{.*\})").expect("Valid hard-coded regex"); if let Some(cap) = re.captures(&String::from_utf8_lossy(&buffer[..num_bytes_read])) { debug!("{}", &cap["json"]); - let json_string = cap["json"].to_string(); - let json: Value = serde_json::from_str(&json_string.clone())?; + let json: Value = serde_json::from_str(&cap["json"].to_string().clone())?; return Ok(Async::Ready(Some(json))); } return Ok(Async::NotReady); @@ -55,31 +77,39 @@ impl Stream for Receiver { } impl Drop for Receiver { fn drop(&mut self) { - let channel = format!("timeline:{}", self.timeline); - let unsubscribe_cmd = format!( - "*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", - channel.len(), - channel - ); - self.tx.poll_write(unsubscribe_cmd.as_bytes()).unwrap(); - println!("Receiver got dropped!"); + let channel = format!("timeline:{}", self.tl); + self.tx + .poll_write(RedisCmd::unsubscribe_from_timeline(&channel).as_bytes()) + .expect("Can unsubscribe from Redis"); + let open_connections = OPEN_CONNECTIONS.fetch_sub(1, Ordering::Relaxed) - 1; + info!("Receiver dropped. {} connection(s) open", open_connections); } } -fn get_socket() -> impl Future> { - let address = "127.0.0.1:6379".parse().expect("Unable to parse address"); - let connection = TcpStream::connect(&address); - connection.and_then(Ok).map_err(Box::new) -} +pub struct PubSub {} -pub fn stream_from( - timeline: String, - user: User, -) -> impl Future { - get_socket() - .and_then(move |socket| { - let stream_of_data_from_redis = Receiver::new(socket, timeline, user); - Ok(stream_of_data_from_redis) - }) - .map_err(warp::reject::custom) +impl PubSub { + pub fn from( + timeline: impl std::fmt::Display, + user: User, + ) -> impl Future { + while OPEN_CONNECTIONS.load(Ordering::Relaxed) > MAX_CONNECTIONS.load(Ordering::Relaxed) { + thread::sleep(time::Duration::from_millis(1000)); + } + let new_connections = OPEN_CONNECTIONS.fetch_add(1, Ordering::Relaxed) + 1; + println!("{} connection(s) now open", new_connections); + + let timeline = timeline.to_string(); + fn get_socket() -> impl Future> { + let address = "127.0.0.1:6379".parse().expect("Unable to parse address"); + let connection = TcpStream::connect(&address); + connection.and_then(Ok).map_err(Box::new) + } + get_socket() + .and_then(move |socket| { + let stream_of_data_from_redis = Receiver::new(socket, timeline, user); + Ok(stream_of_data_from_redis) + }) + .map_err(warp::reject::custom) + } } diff --git a/src/user.rs b/src/user.rs index b28c518..5a08997 100644 --- a/src/user.rs +++ b/src/user.rs @@ -56,7 +56,7 @@ impl User { ..self } } - pub fn is_authorized_for_list(self, list: i64) -> Result<(i64, User), warp::reject::Rejection> { + pub fn is_authorized_for_list(&self, list: i64) -> Result { let conn = conn(); // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn @@ -68,7 +68,7 @@ impl User { if !rows.is_empty() { let id_of_account_that_owns_the_list: i64 = rows.get(0).get(1); if id_of_account_that_owns_the_list == self.id { - return Ok((list, self)); + return Ok(list); } }; @@ -109,7 +109,7 @@ LIMIT 1", let id: i64 = only_row.get(1); let langs: Vec = only_row.get(2); Ok(User { - id: id, + id, langs, logged_in: true, filter: Filter::None,