diff --git a/src/main.rs b/src/main.rs index 563c2eb..36eac91 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,8 +5,8 @@ mod user; mod utils; use futures::stream::Stream; use pretty_env_logger; -use pubsub::{stream_from, Filter}; -use user::{Scope, User}; +use pubsub::stream_from; +use user::{Filter, Scope, User}; use warp::{path, Filter as WarpFilter}; fn main() { @@ -17,21 +17,21 @@ fn main() { .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| stream_from(user.id, Filter::None)); + .map(|user: User| stream_from(user.id.to_string(), user)); // GET /api/v1/streaming/user/notification [private; notification filter] let user_timeline_notifications = path!("api" / "v1" / "streaming" / "user" / "notification") .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|user: User| stream_from(user.id, Filter::Notification)); + .map(|user: User| stream_from(user.id.to_string(), user.with_notification_filter())); // GET /api/v1/streaming/public [public; language filter] let public_timeline = path!("api" / "v1" / "streaming" / "public") .and(path::end()) .and(user::get_access_token(user::Scope::Public)) .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| stream_from("public".into(), Filter::Language(user.langs))); + .map(|user: User| stream_from("public".into(), user.with_language_filter())); // GET /api/v1/streaming/public?only_media=true [public; language filter] let public_timeline_media = path!("api" / "v1" / "streaming" / "public") @@ -40,8 +40,8 @@ fn main() { .and_then(|token| user::get_account(token, Scope::Public)) .and(warp::query()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => stream_from("public:media".into(), Filter::Language(user.langs)), - _ => stream_from("public".into(), Filter::Language(user.langs)), + "1" | "true" => stream_from("public:media".into(), user.with_language_filter()), + _ => stream_from("public".into(), user.with_language_filter()), }); // GET /api/v1/streaming/public/local [public; language filter] @@ -49,7 +49,7 @@ fn main() { .and(path::end()) .and(user::get_access_token(user::Scope::Public)) .and_then(|token| user::get_account(token, Scope::Public)) - .map(|user: User| stream_from("public:local".into(), Filter::Language(user.langs))); + .map(|user: User| stream_from("public:local".into(), user.with_language_filter())); // GET /api/v1/streaming/public/local?only_media=true [public; language filter] let local_timeline_media = path!("api" / "v1" / "streaming" / "public" / "local") @@ -58,8 +58,8 @@ fn main() { .and(warp::query()) .and(path::end()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { - "1" | "true" => stream_from("public:local:media".into(), Filter::Language(user.langs)), - _ => stream_from("public:local".into(), Filter::None), + "1" | "true" => stream_from("public:local:media".into(), user.with_language_filter()), + _ => stream_from("public:local".into(), user.with_language_filter()), }); // GET /api/v1/streaming/direct [private; *no* filter] @@ -67,28 +67,29 @@ fn main() { .and(path::end()) .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) - .map(|account: User| stream_from(format!("direct:{}", account.id), Filter::None)); + .map(|user: User| stream_from(format!("direct:{}", user.id), user.with_no_filter())); // GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter] let hashtag_timeline = path!("api" / "v1" / "streaming" / "hashtag") .and(warp::query()) .and(path::end()) - .map(|q: query::Hashtag| stream_from(format!("hashtag:{}", q.tag), Filter::None)); + .map(|q: query::Hashtag| stream_from(format!("hashtag:{}", q.tag), User::public())); // GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter] let hashtag_timeline_local = path!("api" / "v1" / "streaming" / "hashtag" / "local") .and(warp::query()) .and(path::end()) - .map(|q: query::Hashtag| stream_from(format!("hashtag:{}:local", q.tag), Filter::None)); + .map(|q: query::Hashtag| stream_from(format!("hashtag:{}:local", q.tag), User::public())); // GET /api/v1/streaming/list?list=:list_id [private; no filter] let list_timeline = path!("api" / "v1" / "streaming" / "list") .and(user::get_access_token(Scope::Private)) .and_then(|token| user::get_account(token, Scope::Private)) .and(warp::query()) + .and_then(|user: User, q: query::List| user.is_authorized_for_list(q.list)) + .untuple_one() .and(path::end()) - // TODO: filter down to lists the user can access - .map(|_user: User, q: query::List| stream_from(format!("list:{}", q.list), Filter::None)); + .map(|list: i64, user: User| stream_from(format!("list:{}", list), user.with_no_filter())); let routes = or!( user_timeline, @@ -105,15 +106,15 @@ fn main() { .and_then(|event_stream| event_stream) .and(warp::sse()) .map(|event_stream: pubsub::Receiver, sse: warp::sse::Sse| { - let filter = event_stream.filter.clone(); + let user = event_stream.user.clone(); sse.reply(warp::sse::keep( event_stream.filter_map(move |item| { let payload = item["payload"].clone(); let event = item["event"].to_string().clone(); - let lang = item["language"].to_string().clone(); - match filter { + let toot_lang = item["language"].to_string().clone(); + match &user.filter { Filter::Notification if event != "notification" => None, - Filter::Language(ref vec) if !vec.contains(&lang) => None, + Filter::Language if !user.langs.contains(&toot_lang) => None, _ => Some((warp::sse::event(event), warp::sse::data(payload))), } }), diff --git a/src/pubsub.rs b/src/pubsub.rs index 92d9d8e..c38bfd4 100644 --- a/src/pubsub.rs +++ b/src/pubsub.rs @@ -1,3 +1,4 @@ +use crate::user::User; use futures::{Async, Future, Poll}; use log::{debug, info}; use regex::Regex; @@ -6,15 +7,9 @@ use tokio::io::{AsyncRead, AsyncWrite, Error, ReadHalf, WriteHalf}; use tokio::net::TcpStream; use warp::Stream; -#[derive(Clone)] -pub enum Filter { - None, - Language(Vec), - Notification, -} pub struct Receiver { rx: ReadHalf, - pub filter: Filter, + pub user: User, } impl Stream for Receiver { type Item = Value; @@ -68,16 +63,15 @@ fn send_subscribe_cmd(tx: WriteHalf, channel: String) { tokio::spawn(sender.map_err(|e| eprintln!("{}", e))); } -/// Create a stream from a string. pub fn stream_from( timeline: String, - filter: Filter, + user: User, ) -> impl Future { get_socket() .and_then(move |socket| { let (rx, tx) = socket.split(); send_subscribe_cmd(tx, format!("timeline:{}", timeline)); - let stream_of_data_from_redis = Receiver { rx, filter }; + let stream_of_data_from_redis = Receiver { rx, user }; Ok(stream_of_data_from_redis) }) .map_err(warp::reject::custom) diff --git a/src/query.rs b/src/query.rs index 0e90177..bb73830 100644 --- a/src/query.rs +++ b/src/query.rs @@ -10,7 +10,7 @@ pub struct Hashtag { } #[derive(Deserialize)] pub struct List { - pub list: String, + pub list: i64, } #[derive(Deserialize)] pub struct Auth { diff --git a/src/user.rs b/src/user.rs index 731919c..b28c518 100644 --- a/src/user.rs +++ b/src/user.rs @@ -1,6 +1,6 @@ use crate::{or, query}; use postgres; -use warp::Filter; +use warp::Filter as WarpFilter; pub fn get_access_token(scope: Scope) -> warp::filters::BoxedFilter<(String,)> { let token_from_header = warp::header::header::("authorization") @@ -23,11 +23,65 @@ fn conn() -> postgres::Connection { ) .unwrap() } +#[derive(Clone)] +pub enum Filter { + None, + Language, + Notification, +} +#[derive(Clone)] pub struct User { - pub id: String, + pub id: i64, pub langs: Vec, pub logged_in: bool, + pub filter: Filter, +} +impl User { + pub fn with_notification_filter(self) -> Self { + Self { + filter: Filter::Notification, + ..self + } + } + pub fn with_language_filter(self) -> Self { + Self { + filter: Filter::Language, + ..self + } + } + pub fn with_no_filter(self) -> Self { + Self { + filter: Filter::None, + ..self + } + } + pub fn is_authorized_for_list(self, list: i64) -> Result<(i64, User), warp::reject::Rejection> { + let conn = conn(); + // For the Postgres query, `id` = list number; `account_id` = user.id + let rows = &conn + .query( + " SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", + &[&list], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + if !rows.is_empty() { + let id_of_account_that_owns_the_list: i64 = rows.get(0).get(1); + if id_of_account_that_owns_the_list == self.id { + return Ok((list, self)); + } + }; + + Err(warp::reject::custom("Error: Invalid access token")) + } + pub fn public() -> Self { + User { + id: -1, + langs: Vec::new(), + logged_in: false, + filter: Filter::None, + } + } } pub enum Scope { @@ -55,15 +109,17 @@ LIMIT 1", let id: i64 = only_row.get(1); let langs: Vec = only_row.get(2); Ok(User { - id: id.to_string(), + id: id, langs, logged_in: true, + filter: Filter::None, }) } else if let Scope::Public = scope { Ok(User { - id: String::new(), + id: -1, langs: Vec::new(), logged_in: false, + filter: Filter::None, }) } else { Err(warp::reject::custom("Error: Invalid access token"))