From 6d037dd5afe6bef5565c5c48ba034c1be75ea8ba Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Wed, 8 May 2019 23:02:01 -0400 Subject: [PATCH] Working WS implemetation, but not cleaned up --- src/error.rs | 1 + src/main.rs | 258 +++++++++++++++++++++++++++++++++++++++--------- src/query.rs | 53 +++++++++- src/receiver.rs | 87 +++++++++++++--- src/stream.rs | 111 ++++++++++++++------- src/timeline.rs | 2 +- src/user.rs | 21 +++- 7 files changed, 426 insertions(+), 107 deletions(-) diff --git a/src/error.rs b/src/error.rs index bede2f6..62c2d68 100644 --- a/src/error.rs +++ b/src/error.rs @@ -25,6 +25,7 @@ pub fn handle_errors( None => "Error: Nonexistant endpoint".to_string(), }; let json = warp::reply::json(&ErrorMessage::new(err_txt)); + println!("REJECTED!"); Ok(warp::reply::with_status( json, warp::http::StatusCode::UNAUTHORIZED, diff --git a/src/main.rs b/src/main.rs index ebc80f6..58673f7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,55 +38,219 @@ use warp::Filter as WarpFilter; fn main() { pretty_env_logger::init(); - let redis_updates = StreamManager::new(Receiver::new()); + // let redis_updates = StreamManager::new(Receiver::new()); - let routes = any_of!( - // GET /api/v1/streaming/user/notification [private; notification filter] - timeline::user_notifications(), - // GET /api/v1/streaming/user [private; language filter] - timeline::user(), - // GET /api/v1/streaming/public/local?only_media=true [public; language filter] - timeline::public_local_media(), - // GET /api/v1/streaming/public?only_media=true [public; language filter] - timeline::public_media(), - // GET /api/v1/streaming/public/local [public; language filter] - timeline::public_local(), - // GET /api/v1/streaming/public [public; language filter] - timeline::public(), - // GET /api/v1/streaming/direct [private; *no* filter] - timeline::direct(), - // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] - timeline::hashtag(), - // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] - timeline::hashtag_local(), - // GET /api/v1/streaming/list?list=:list_id [private; no filter] - timeline::list() - ) - .untuple_one() - .and(warp::sse()) - .and(warp::any().map(move || redis_updates.new_copy())) - .map( - |timeline: String, user: User, sse: warp::sse::Sse, mut event_stream: StreamManager| { - event_stream.add(&timeline, &user); - sse.reply(warp::sse::keep( - event_stream.filter_map(move |item| { - let payload = item["payload"].clone(); - let event = item["event"].clone().to_string(); - let toot_lang = payload["language"].as_str().expect("redis str").to_string(); - let user_langs = user.langs.clone(); + // let routes = any_of!( + // // GET /api/v1/streaming/user/notification [private; notification filter] + // timeline::user_notifications(), + // // GET /api/v1/streaming/user [private; language filter] + // timeline::user(), + // // GET /api/v1/streaming/public/local?only_media=true [public; language filter] + // timeline::public_local_media(), + // // GET /api/v1/streaming/public?only_media=true [public; language filter] + // timeline::public_media(), + // // GET /api/v1/streaming/public/local [public; language filter] + // timeline::public_local(), + // // GET /api/v1/streaming/public [public; language filter] + // timeline::public(), + // // GET /api/v1/streaming/direct [private; *no* filter] + // timeline::direct(), + // // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] + // timeline::hashtag(), + // // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] + // timeline::hashtag_local(), + // // GET /api/v1/streaming/list?list=:list_id [private; no filter] + // timeline::list() + // ) + // .untuple_one() + // .and(warp::sse()) + // .and(warp::any().map(move || redis_updates.new_copy())) + // .map( + // |timeline: String, user: User, sse: warp::sse::Sse, mut event_stream: StreamManager| { + // dbg!(&event_stream); + // event_stream.add(&timeline, &user); + // sse.reply(warp::sse::keep( + // event_stream.filter_map(move |item| { + // let payload = item["payload"].clone(); + // let event = item["event"].clone().to_string(); + // let toot_lang = payload["language"].as_str().expect("redis str").to_string(); + // let user_langs = user.langs.clone(); - match (&user.filter, user_langs) { - (Filter::Notification, _) if event != "notification" => None, - (Filter::Language, Some(ref langs)) if !langs.contains(&toot_lang) => None, - _ => Some((warp::sse::event(event), warp::sse::data(payload))), + // match (&user.filter, user_langs) { + // (Filter::Notification, _) if event != "notification" => None, + // (Filter::Language, Some(ref langs)) if !langs.contains(&toot_lang) => None, + // _ => Some((warp::sse::event(event), warp::sse::data(payload))), + // } + // }), + // None, + // )) + // }, + // ) + // .with(warp::reply::with::header("Connection", "keep-alive")) + // .recover(error::handle_errors); + + use futures::future::Future; + use futures::sink::Sink; + use futures::Async; + use user::Scope; + use warp::path; + let redis_updates_ws = StreamManager::new(Receiver::new()); + let websocket = path!("api" / "v1" / "streaming") + .and(Scope::Public.get_access_token()) + .and_then(|token| User::from_access_token(token, Scope::Public)) + .and(warp::query()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .and(warp::ws2()) + .and(warp::any().map(move || { + println!("Getting StreamManager.new_copy()"); + redis_updates_ws.new_copy() + })) + .and_then( + |mut user: User, + q: query::Stream, + m: query::Media, + h: query::Hashtag, + l: query::List, + ws: warp::ws::Ws2, + mut stream: StreamManager| { + println!("DING"); + let unauthorized = Err(warp::reject::custom("Error: Invalid Access Token")); + let timeline = match q.stream.as_ref() { + // Public endpoints: + tl @ "public" | tl @ "public:local" if m.is_truthy() => format!("{}:media", tl), + tl @ "public:media" | tl @ "public:local:media" => format!("{}", tl), + tl @ "public" | tl @ "public:local" => format!("{}", tl), + // User + "user" if user.id == -1 => return unauthorized, + "user" => format!("{}", user.id), + "user:notification" => { + user = user.with_notification_filter(); + format!("{}", user.id) } - }), - None, - )) - }, - ) - .with(warp::reply::with::header("Connection", "keep-alive")) - .recover(error::handle_errors); + // Hashtag endpoints: + // TODO: handle missing query + tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, h.tag), + // List endpoint: + // TODO: handle missing query + "list" if user.authorized_for_list(l.list).is_err() => return unauthorized, + "list" => format!("list:{}", l.list), + // Direct endpoint: + "direct" if user.id == -1 => return unauthorized, + "direct" => format!("direct"), + // Other endpoints don't exist: + _ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")), + }; - warp::serve(routes).run(([127, 0, 0, 1], 3030)); + stream.add(&timeline, &user); + stream.set_user(user); + dbg!(&stream); + Ok(ws.on_upgrade(move |socket| handle_ws(socket, stream))) + }, + ); + + fn handle_ws( + socket: warp::ws::WebSocket, + mut stream: StreamManager, + ) -> impl futures::future::Future { + let (mut tx, rx) = futures::sync::mpsc::unbounded(); + let (ws_tx, mut ws_rx) = socket.split(); + // let event_stream = stream + // .map(move |value| warp::ws::Message::text(value.to_string())) + // .map_err(|_| unreachable!()); + warp::spawn( + rx.map_err(|()| -> warp::Error { unreachable!() }) + .forward(ws_tx) + .map_err(|_| ()) + .map(|_r| ()), + ); + let event_stream = tokio::timer::Interval::new( + std::time::Instant::now(), + std::time::Duration::from_secs(10), + ) + .take_while(move |_| { + if ws_rx.poll().is_err() { + println!("Need to close WS"); + futures::future::ok(false) + } else { + // println!("We can still send to WS"); + futures::future::ok(true) + } + }); + + event_stream + .for_each(move |_json_value| { + // println!("For each triggered"); + if let Ok(Async::Ready(Some(json_value))) = stream.poll() { + let msg = warp::ws::Message::text(json_value.to_string()); + tx.unbounded_send(msg).unwrap(); + }; + Ok(()) + }) + .then(|msg| { + println!("Done with stream"); + msg + }) + .map_err(|e| { + println!("{}", e); + }) + } + + let log = warp::any().map(|| { + println!("----got request----"); + warp::reply() + }); + warp::serve(websocket.or(log)).run(([127, 0, 0, 1], 3030)); } + +// loop { +// //println!("Awake"); +// match stream.poll() { +// Err(_) | Ok(Async::Ready(None)) => { +// eprintln!("Breaking out of poll loop due to an error"); +// break; +// } +// Ok(Async::NotReady) => (), +// Ok(Async::Ready(Some(item))) => { +// let user_langs = user.langs.clone(); +// let copy = item.clone(); +// let event = copy["event"].as_str().unwrap(); +// let copy = item.clone(); +// let payload = copy["payload"].to_string(); +// let copy = item.clone(); +// let toot_lang = copy["payload"]["language"] +// .as_str() +// .expect("redis str") +// .to_string(); + +// println!("sending: {:?}", &payload); +// match (&user.filter, user_langs) { +// (Filter::Notification, _) if event != "notification" => continue, +// (Filter::Language, Some(ref langs)) if !langs.contains(&toot_lang) => { +// continue; +// } +// _ => match tx.unbounded_send(warp::ws::Message::text( +// json!( +// {"event": event, +// "payload": payload,} +// ) +// .to_string(), +// )) { +// Ok(()) => println!("Sent OK"), +// Err(e) => { +// println!("Couldn't send: {}", e); +// } +// }, +// } +// } +// }; +// if ws_rx.poll().is_err() { +// println!("Need to close WS"); +// break; +// } else { +// println!("We can still send to WS"); +// } +// std::thread::sleep(std::time::Duration::from_millis(2000)); +// //println!("Asleep"); +// } diff --git a/src/query.rs b/src/query.rs index 4b8aebc..4ec5d15 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,19 +1,66 @@ //! Validate query prarams with type checking use serde_derive::Deserialize; +use warp::filters::BoxedFilter; +use warp::Filter as WarpFilter; -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Debug, Default)] pub struct Media { pub only_media: String, } -#[derive(Deserialize, Debug)] +impl Media { + pub fn to_filter() -> BoxedFilter<(Self,)> { + warp::query() + .or(warp::any().map(|| Self::default())) + .unify() + .boxed() + } + pub fn is_truthy(&self) -> bool { + self.only_media == "true" || self.only_media == "1" + } +} +#[derive(Deserialize, Debug, Default)] pub struct Hashtag { pub tag: String, } -#[derive(Deserialize, Debug)] +impl Hashtag { + pub fn to_filter() -> BoxedFilter<(Self,)> { + warp::query() + .or(warp::any().map(|| Self::default())) + .unify() + .boxed() + } +} +#[derive(Deserialize, Debug, Default)] pub struct List { pub list: i64, } +impl List { + pub fn to_filter() -> BoxedFilter<(Self,)> { + warp::query() + .or(warp::any().map(|| Self::default())) + .unify() + .boxed() + } +} #[derive(Deserialize, Debug)] pub struct Auth { pub access_token: String, } +#[derive(Deserialize, Debug)] +pub struct Stream { + pub stream: String, +} +impl ToString for Stream { + fn to_string(&self) -> String { + format!("{:?}", self) + } +} + +pub fn optional_media_query() -> BoxedFilter<(Media,)> { + warp::query() + .or(warp::any().map(|| Media { + only_media: "false".to_owned(), + })) + .unify() + .boxed() +} diff --git a/src/receiver.rs b/src/receiver.rs index cef6355..456f7f3 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -12,6 +12,24 @@ use uuid::Uuid; use std::io::{Read, Write}; use std::net::TcpStream; use std::time::Duration; +use std::time::Instant; + +#[derive(Debug)] +struct MsgQueue { + messages: VecDeque, + last_polled_at: Instant, + redis_channel: String, +} +impl MsgQueue { + fn new(redis_channel: impl std::fmt::Display) -> Self { + let redis_channel = redis_channel.to_string(); + MsgQueue { + messages: VecDeque::new(), + last_polled_at: Instant::now(), + redis_channel, + } + } +} /// The item that streams from Redis and is polled by the `StreamManger` #[derive(Debug)] @@ -19,8 +37,9 @@ pub struct Receiver { stream: TcpStream, tl: String, pub user: User, - polled_by: Uuid, - msg_queue: HashMap>, + manager_id: Uuid, + msg_queue: HashMap, + subscribed_timelines: HashMap, } impl Receiver { pub fn new() -> Self { @@ -33,30 +52,63 @@ impl Receiver { stream, tl: String::new(), user: User::public(), - polled_by: Uuid::new_v4(), + manager_id: Uuid::new_v4(), msg_queue: HashMap::new(), + subscribed_timelines: HashMap::new(), } } /// Update the `StreamManager` that is currently polling the `Receiver` - pub fn set_polled_by(&mut self, id: Uuid) -> &Self { - self.polled_by = id; - self + pub fn set_manager_id(&mut self, id: Uuid) { + self.manager_id = id; } - /// Send a subscribe command to the Redis PubSub + /// Send a subscribe command to the Redis PubSub and check if any subscriptions should be dropped pub fn subscribe(&mut self, tl: &str) { - let subscribe_cmd = redis_cmd_from("subscribe", &tl); info!("Subscribing to {}", &tl); + + let manager_id = self.manager_id; + self.msg_queue.insert(manager_id, MsgQueue::new(tl)); + self.subscribed_timelines + .entry(tl.to_string()) + .and_modify(|n| *n += 1) + .or_insert(1); + + let mut timelines_with_dropped_clients = Vec::new(); + self.msg_queue.retain(|id, msg_queue| { + if msg_queue.last_polled_at.elapsed() > Duration::from_secs(30) { + timelines_with_dropped_clients.push(msg_queue.redis_channel.clone()); + println!("Dropping: {}", id); + false + } else { + println!("Retaining: {}", id); + true + } + }); + + for timeline in timelines_with_dropped_clients { + let count_of_subscribed_clients = self + .subscribed_timelines + .entry(timeline.clone()) + .and_modify(|n| *n -= 1) + .or_insert(0); + if *count_of_subscribed_clients <= 0 { + self.unsubscribe(&timeline); + } + } + + let subscribe_cmd = redis_cmd_from("subscribe", &tl); self.stream .write_all(&subscribe_cmd) .expect("Can subscribe to Redis"); + println!("Done subscribing"); } /// Send an unsubscribe command to the Redis PubSub pub fn unsubscribe(&mut self, tl: &str) { let unsubscribe_cmd = redis_cmd_from("unsubscribe", &tl); - info!("Subscribing to {}", &tl); + info!("Unsubscribing from {}", &tl); self.stream .write_all(&unsubscribe_cmd) .expect("Can unsubscribe from Redis"); + println!("Done unsubscribing"); } } impl Stream for Receiver { @@ -65,10 +117,10 @@ impl Stream for Receiver { fn poll(&mut self) -> Poll, Self::Error> { let mut buffer = vec![0u8; 3000]; - let polled_by = self.polled_by; + let polled_by = self.manager_id; self.msg_queue .entry(polled_by) - .or_insert_with(VecDeque::new); + .and_modify(|msg_queue| msg_queue.last_polled_at = Instant::now()); info!("Being polled by StreamManager with uuid: {}", polled_by); let mut async_stream = AsyncReadableStream(&mut self.stream); @@ -80,12 +132,19 @@ impl Stream for Receiver { if let Some(cap) = re.captures(&String::from_utf8_lossy(&buffer[..num_bytes_read])) { let json: Value = serde_json::from_str(&cap["json"].to_string().clone())?; - for value in self.msg_queue.values_mut() { - value.push_back(json.clone()); + for msg_queue in self.msg_queue.values_mut() { + msg_queue.messages.push_back(json.clone()); } } } - if let Some(value) = self.msg_queue.entry(polled_by).or_default().pop_front() { + dbg!(&self); + if let Some(value) = self + .msg_queue + .entry(polled_by) + .or_insert(MsgQueue::new(self.tl.clone())) + .messages + .pop_front() + { Ok(Async::Ready(Some(value))) } else { Ok(Async::NotReady) diff --git a/src/stream.rs b/src/stream.rs index 2b670f5..d4c30f0 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -11,20 +11,20 @@ use tokio::io::Error; use uuid::Uuid; /// Struct for manageing all Redis streams -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct StreamManager { receiver: Arc>, - subscriptions: Arc>>, - current_stream: String, + //subscriptions: Arc>>, id: uuid::Uuid, + current_user: Option, } impl StreamManager { pub fn new(reciever: Receiver) -> Self { StreamManager { receiver: Arc::new(Mutex::new(reciever)), - subscriptions: Arc::new(Mutex::new(HashMap::new())), - current_stream: String::new(), + // subscriptions: Arc::new(Mutex::new(HashMap::new())), id: Uuid::new_v4(), + current_user: None, } } @@ -38,52 +38,89 @@ impl StreamManager { /// /// /// `.add()` also unsubscribes from any channels that no longer have clients - pub fn add(&mut self, timeline: &str, _user: &User) -> &Self { - let mut subscriptions = self.subscriptions.lock().expect("No other thread panic"); + pub fn add(&mut self, timeline: &str, _user: &User) { + println!("ADD lock"); let mut receiver = self.receiver.lock().unwrap(); - subscriptions - .entry(timeline.to_string()) - .or_insert_with(|| { - receiver.subscribe(timeline); - Instant::now() - }); + receiver.set_manager_id(self.id); + receiver.subscribe(timeline); + dbg!(&receiver); - // Unsubscribe from that haven't been polled in the last 30 seconds - let channels = subscriptions.clone(); - let channels_to_unsubscribe = channels - .iter() - .filter(|(_, time)| time.elapsed().as_secs() > 30); - for (channel, _) in channels_to_unsubscribe { - receiver.unsubscribe(&channel); - } - // Update our map of streams - *subscriptions = channels - .clone() - .into_iter() - .filter(|(_, time)| time.elapsed().as_secs() > 30) - .collect(); + println!("ADD unlock"); + } - self.current_stream = timeline.to_string(); - self + pub fn set_user(&mut self, user: User) { + self.current_user = Some(user); } } +use crate::user::Filter; +use serde_json::json; + impl Stream for StreamManager { type Item = Value; type Error = Error; fn poll(&mut self) -> Poll, Self::Error> { - let mut subscriptions = self.subscriptions.lock().expect("No other thread panic"); - let target_stream = self.current_stream.clone(); - subscriptions.insert(target_stream.clone(), Instant::now()); - let mut receiver = self.receiver.lock().expect("No other thread panic"); - receiver.set_polled_by(self.id); + receiver.set_manager_id(self.id); + let result = match receiver.poll() { + Ok(Async::Ready(Some(value))) => { + let user = self.clone().current_user.unwrap(); - match receiver.poll() { - Ok(Async::Ready(Some(value))) => Ok(Async::Ready(Some(value))), + let user_langs = user.langs.clone(); + let copy = value.clone(); + let event = copy["event"].as_str().unwrap(); + let copy = value.clone(); + let payload = copy["payload"].to_string(); + let copy = value.clone(); + let toot_lang = copy["payload"]["language"] + .as_str() + .expect("redis str") + .to_string(); + + println!("sending: {:?}", &payload); + match (&user.filter, user_langs) { + (Filter::Notification, _) if event != "notification" => Ok(Async::NotReady), + (Filter::Language, Some(ref langs)) if !langs.contains(&toot_lang) => { + Ok(Async::NotReady) + } + + _ => Ok(Async::Ready(Some(json!( + {"event": event, + "payload": payload,} + )))), + } + } Ok(Async::Ready(None)) => Ok(Async::Ready(None)), Ok(Async::NotReady) => Ok(Async::NotReady), Err(e) => Err(e), - } + }; + // dbg!(&result); + result } } + +// CUT FROM .add +// let mut subscriptions = self.subscriptions.lo ck().expect("No other thread panic"); +// subscriptions +// .entry(timeline.to_string()) +// .or_insert_with(|| { +// println!("Inserting TL: {}", &timeline); +//***** // +// Instant::now() +// }); + +// self.current_stream = timeline.to_string(); +// // Unsubscribe from that haven't been polled in the last 30 seconds +// let channels = subscriptions.clone(); +// let channels_to_unsubscribe = channels +// .iter() +// .filter(|(_, time)| time.elapsed().as_secs() > 30); +// for (channel, _) in channels_to_unsubscribe { +//***** // receiver.unsubscribe(&channel); +// } +// // Update our map of streams +// *subscriptions = channels +// .clone() +// .into_iter() +// .filter(|(_, time)| time.elapsed().as_secs() < 30) +// .collect(); diff --git a/src/timeline.rs b/src/timeline.rs index ee2530f..38b4e6e 100644 --- a/src/timeline.rs +++ b/src/timeline.rs @@ -142,7 +142,7 @@ pub fn list() -> BoxedFilter { .and(Scope::Private.get_access_token()) .and_then(|token| User::from_access_token(token, Scope::Private)) .and(warp::query()) - .and_then(|user: User, q: query::List| (user.is_authorized_for_list(q.list), Ok(user))) + .and_then(|user: User, q: query::List| (user.authorized_for_list(q.list), Ok(user))) .untuple_one() .and(path::end()) .map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter())) diff --git a/src/user.rs b/src/user.rs index f541275..f752f8e 100644 --- a/src/user.rs +++ b/src/user.rs @@ -32,6 +32,7 @@ pub struct User { impl User { /// Create a user from the access token supplied in the header or query paramaters pub fn from_access_token(token: String, scope: Scope) -> Result { + println!("Getting user"); let conn = connect_to_postgres(); let result = &conn .query( @@ -59,7 +60,7 @@ LIMIT 1", filter: Filter::None, }) } else if let Scope::Public = scope { - info!("Granting public access"); + println!("Granting public access"); Ok(User { id: -1, langs: None, @@ -92,7 +93,7 @@ LIMIT 1", } } /// Determine whether the User is authorised for a specified list - pub fn is_authorized_for_list(&self, list: i64) -> Result { + pub fn authorized_for_list(&self, list: i64) -> Result { let conn = connect_to_postgres(); // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn @@ -128,9 +129,19 @@ pub enum Scope { } impl Scope { pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> { - let token_from_header = warp::header::header::("authorization") - .map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string()); - let token_from_query = warp::query().map(|q: query::Auth| q.access_token); + println!("Getting access token"); + let token_from_header = + warp::header::header::("authorization").map(|auth: String| { + println!( + "Got token_from_header: {}", + auth.split(' ').nth(1).unwrap_or("invalid").to_string() + ); + auth.split(' ').nth(1).unwrap_or("invalid").to_string() + }); + let token_from_query = warp::query().map(|q: query::Auth| { + println!("Got token_from_query: {}", &q.access_token); + q.access_token + }); let public = warp::any().map(|| "no access token".to_string()); match self {