diff --git a/Cargo.lock b/Cargo.lock index cd041a9..767381e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,3 +1,5 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. [[package]] name = "aho-corasick" version = "0.7.3" diff --git a/src/main.rs b/src/main.rs index 44f00dd..900ea39 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,7 +40,7 @@ use receiver::Receiver; use std::env; use std::net::SocketAddr; use stream::StreamManager; -use user::{Scope, User}; +use user::{Method, Scope, User}; use warp::path; use warp::Filter as WarpFilter; @@ -96,7 +96,7 @@ fn main() { //let redis_updates_ws = StreamManager::new(Receiver::new()); let websocket = path!("api" / "v1" / "streaming") - .and(Scope::Public.get_access_token()) + .and(Scope::Public.get_access_token(Method::WS)) .and_then(|token| User::from_access_token(token, Scope::Public)) .and(warp::query()) .and(query::Media::to_filter()) @@ -140,7 +140,14 @@ fn main() { Ok(ws.on_upgrade(move |socket| ws::send_replies(socket, stream))) }, - ); + ) + .map(|reply| { + warp::reply::with_header( + reply, + "sec-websocket-protocol", + "LhbVOxKckgqyMg3nDLaEu5vgqY6Yzc9Pk1w8_yKQwS8", + ) + }); let address: SocketAddr = env::var("SERVER_ADDR") .unwrap_or("127.0.0.1:4000".to_owned()) diff --git a/src/timeline.rs b/src/timeline.rs index 541ac2b..4cdf012 100644 --- a/src/timeline.rs +++ b/src/timeline.rs @@ -1,6 +1,6 @@ //! Filters for all the endpoints accessible for Server Sent Event updates use crate::query; -use crate::user::{Scope, User}; +use crate::user::{Method, Scope, User}; use warp::filters::BoxedFilter; use warp::{path, Filter}; @@ -14,7 +14,7 @@ type TimelineUser = ((String, User),); pub fn user() -> BoxedFilter { path!("api" / "v1" / "streaming" / "user") .and(path::end()) - .and(Scope::Private.get_access_token()) + .and(Scope::Private.get_access_token(Method::HttpPush)) .and_then(|token| User::from_access_token(token, Scope::Private)) .map(|user: User| (user.id.to_string(), user)) .boxed() @@ -30,7 +30,7 @@ pub fn user() -> BoxedFilter { pub fn user_notifications() -> BoxedFilter { path!("api" / "v1" / "streaming" / "user" / "notification") .and(path::end()) - .and(Scope::Private.get_access_token()) + .and(Scope::Private.get_access_token(Method::HttpPush)) .and_then(|token| User::from_access_token(token, Scope::Private)) .map(|user: User| (user.id.to_string(), user.with_notification_filter())) .boxed() @@ -43,7 +43,7 @@ pub fn user_notifications() -> BoxedFilter { pub fn public() -> BoxedFilter { path!("api" / "v1" / "streaming" / "public") .and(path::end()) - .and(Scope::Public.get_access_token()) + .and(Scope::Public.get_access_token(Method::HttpPush)) .and_then(|token| User::from_access_token(token, Scope::Public)) .map(|user: User| ("public".to_owned(), user.with_language_filter())) .boxed() @@ -56,7 +56,7 @@ pub fn public() -> BoxedFilter { pub fn public_media() -> BoxedFilter { path!("api" / "v1" / "streaming" / "public") .and(path::end()) - .and(Scope::Public.get_access_token()) + .and(Scope::Public.get_access_token(Method::HttpPush)) .and_then(|token| User::from_access_token(token, Scope::Public)) .and(warp::query()) .map(|user: User, q: query::Media| match q.only_media.as_ref() { @@ -73,7 +73,7 @@ pub fn public_media() -> BoxedFilter { pub fn public_local() -> BoxedFilter { path!("api" / "v1" / "streaming" / "public" / "local") .and(path::end()) - .and(Scope::Public.get_access_token()) + .and(Scope::Public.get_access_token(Method::HttpPush)) .and_then(|token| User::from_access_token(token, Scope::Public)) .map(|user: User| ("public:local".to_owned(), user.with_language_filter())) .boxed() @@ -85,7 +85,7 @@ pub fn public_local() -> BoxedFilter { /// **public**. Filter: `Language` pub fn public_local_media() -> BoxedFilter { path!("api" / "v1" / "streaming" / "public" / "local") - .and(Scope::Public.get_access_token()) + .and(Scope::Public.get_access_token(Method::HttpPush)) .and_then(|token| User::from_access_token(token, Scope::Public)) .and(warp::query()) .and(path::end()) @@ -103,7 +103,7 @@ pub fn public_local_media() -> BoxedFilter { pub fn direct() -> BoxedFilter { path!("api" / "v1" / "streaming" / "direct") .and(path::end()) - .and(Scope::Private.get_access_token()) + .and(Scope::Private.get_access_token(Method::HttpPush)) .and_then(|token| User::from_access_token(token, Scope::Private)) .map(|user: User| (format!("direct:{}", user.id), user.with_no_filter())) .boxed() @@ -139,7 +139,7 @@ pub fn hashtag_local() -> BoxedFilter { /// **private**. Filter: `None` pub fn list() -> BoxedFilter { path!("api" / "v1" / "streaming" / "list") - .and(Scope::Private.get_access_token()) + .and(Scope::Private.get_access_token(Method::HttpPush)) .and_then(|token| User::from_access_token(token, Scope::Private)) .and(warp::query()) .and_then(|user: User, q: query::List| (user.authorized_for_list(q.list), Ok(user))) diff --git a/src/user.rs b/src/user.rs index 5a5cbb2..3923aa2 100644 --- a/src/user.rs +++ b/src/user.rs @@ -27,6 +27,7 @@ pub enum Filter { #[derive(Clone, Debug, PartialEq)] pub struct User { pub id: i64, + pub access_token: String, pub langs: Option>, pub logged_in: bool, pub filter: Filter, @@ -49,6 +50,7 @@ LIMIT 1", &[&token], ) .expect("Hard-coded query will return Some([0 or more rows])"); + dbg!(&result); if !result.is_empty() { let only_row = result.get(0); let id: i64 = only_row.get(1); @@ -56,6 +58,7 @@ LIMIT 1", info!("Granting logged-in access"); Ok(User { id, + access_token: token, langs, logged_in: true, filter: Filter::None, @@ -64,6 +67,7 @@ LIMIT 1", info!("Granting public access to non-authenticated client"); Ok(User { id: -1, + access_token: token, langs: None, logged_in: false, filter: Filter::None, @@ -116,6 +120,7 @@ LIMIT 1", pub fn public() -> Self { User { id: -1, + access_token: String::new(), langs: None, logged_in: false, filter: Filter::None, @@ -128,18 +133,41 @@ pub enum Scope { Public, Private, } +pub enum Method { + WS, + HttpPush, +} impl Scope { - pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> { - let token_from_header = warp::header::header::("authorization") - .map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string()); - let token_from_query = warp::query().map(|q: query::Auth| q.access_token); + pub fn get_access_token(self, method: Method) -> warp::filters::BoxedFilter<(String,)> { + let token_from_header_http_push = + warp::header::header::("authorization").map(|auth: String| { + dbg!(auth.split(' ').nth(1).unwrap_or("invalid").to_string()); + auth.split(' ').nth(1).unwrap_or("invalid").to_string() + }); + let token_from_header_ws = + warp::header::header::("Sec-WebSocket-Protocol").map(|auth: String| { + dbg!(&auth); + auth + }); + let token_from_query = warp::query().map(|q: query::Auth| { + dbg!(&q.access_token); + q.access_token + }); let public = warp::any().map(|| "no access token".to_string()); - match self { + match (self, method) { // if they're trying to access a private scope without an access token, reject the request - Scope::Private => any_of!(token_from_query, token_from_header).boxed(), + (Scope::Private, Method::HttpPush) => { + any_of!(token_from_query, token_from_header_http_push).boxed() + } + (Scope::Private, Method::WS) => any_of!(token_from_query, token_from_header_ws).boxed(), // if they're trying to access a public scope without an access token, proceed - Scope::Public => any_of!(token_from_query, token_from_header, public).boxed(), + (Scope::Public, Method::HttpPush) => { + any_of!(token_from_query, token_from_header_http_push, public).boxed() + } + (Scope::Public, Method::WS) => { + any_of!(token_from_query, token_from_header_ws, public).boxed() + } } } }