diff --git a/src/pg.rs b/src/auth.rs similarity index 57% rename from src/pg.rs rename to src/auth.rs index d7242e9..740c1e8 100644 --- a/src/pg.rs +++ b/src/auth.rs @@ -1,4 +1,22 @@ +use super::query; use postgres; +use warp::Filter; + +pub fn get_token() -> warp::filters::BoxedFilter<(String,)> { + let token_from_header = warp::header::header::("authorization") + .map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string()); + + let token_from_query = warp::query().map(|q: query::Auth| q.access_token); + token_from_query.or(token_from_header).unify().boxed() +} + +pub fn get_account_id_from_token(token: String) -> Result { + if let Ok(account_id) = get_account_id(token) { + Ok(account_id) + } else { + Err(warp::reject::custom("Error: Invalid access token")) + } +} fn conn() -> postgres::Connection { postgres::Connection::connect( diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..e56e64e --- /dev/null +++ b/src/error.rs @@ -0,0 +1,29 @@ +use serde_derive::Serialize; +#[derive(Serialize)] +struct ErrorMessage { + error: String, +} +impl ErrorMessage { + fn new(msg: impl std::fmt::Display) -> Self { + Self { + error: msg.to_string(), + } + } +} + +pub fn handle_errors( + rejection: warp::reject::Rejection, +) -> Result { + let err_txt = match rejection.cause() { + Some(text) if text.to_string() == "Missing request header 'authorization'" => { + "Error: Missing access token".to_string() + } + Some(text) => text.to_string(), + None => "Unknown server error".to_string(), + }; + let json = warp::reply::json(&ErrorMessage::new(err_txt)); + Ok(warp::reply::with_status( + json, + warp::http::StatusCode::UNAUTHORIZED, + )) +} diff --git a/src/main.rs b/src/main.rs index a29f056..5749e83 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,74 +1,22 @@ -mod pg; +mod auth; +mod error; mod pubsub; mod query; use futures::stream::Stream; use log::info; use pretty_env_logger; -use serde_derive::Serialize; use warp::{path, Filter}; fn main() { pretty_env_logger::init(); let base = path!("api" / "v1" / "streaming"); - let token = warp::any() - .and(warp::header::optional::("authorization")) - .map(|auth_header: Option| { - if let Some(header_value) = auth_header { - Some( - header_value - .split(" ") - .nth(1) - .unwrap_or("invalid token") - .to_string(), - ) - } else { - None - } - }); - - fn get_account_id_from_token(token: Option) -> Result { - if token.is_none() { - Err(warp::reject::custom("Error: Missing access token")) - } else if let Ok(account_id) = pg::get_account_id(token.unwrap()) { - Ok(account_id) - } else { - Err(warp::reject::custom("Error: Invalid access token")) - } - } - - #[derive(Serialize)] - struct ErrorMessage { - error: String, - } - impl ErrorMessage { - fn new(msg: impl std::fmt::Display) -> Self { - Self { - error: msg.to_string(), - } - } - } - - fn handle_errors( - rejection: warp::reject::Rejection, - ) -> Result { - let err_txt = match rejection.cause() { - Some(text) => text.to_string(), - None => "Unknown server error".to_string(), - }; - let json = warp::reply::json(&ErrorMessage::new(err_txt)); - Ok(warp::reply::with_status( - json, - warp::http::StatusCode::UNAUTHORIZED, - )) - } - // GET /api/v1/streaming/user let user_timeline = base .and(path("user")) .and(path::end()) - .and(token) - .and_then(get_account_id_from_token) + .and(auth::get_token()) + .and_then(auth::get_account_id_from_token) .map(|account_id: i64| { info!("GET /api/v1/streaming/user"); pubsub::stream_from(account_id.to_string()) @@ -78,8 +26,8 @@ fn main() { let user_timeline_notifications = base .and(path!("user" / "notification")) .and(path::end()) - .and(token) - .and_then(get_account_id_from_token) + .and(auth::get_token()) + .and_then(auth::get_account_id_from_token) .map(|account_id: i64| { let full_stream = pubsub::stream_from(account_id.to_string()); // TODO: filter stream to just have notifications @@ -133,8 +81,8 @@ fn main() { let direct_timeline = base .and(path("direct")) .and(path::end()) - .and(token) - .and_then(get_account_id_from_token) + .and(auth::get_token()) + .and_then(auth::get_account_id_from_token) .map(|account_id: i64| { info!("GET /api/v1/streaming/direct"); pubsub::stream_from(format!("direct:{}", account_id)) @@ -201,7 +149,7 @@ fn main() { None, )) }) - .recover(handle_errors); + .recover(error::handle_errors); info!("starting streaming api server"); warp::serve(routes).run(([127, 0, 0, 1], 3030)); diff --git a/src/query.rs b/src/query.rs index 42b4ea7..0e90177 100644 --- a/src/query.rs +++ b/src/query.rs @@ -12,3 +12,7 @@ pub struct Hashtag { pub struct List { pub list: String, } +#[derive(Deserialize)] +pub struct Auth { + pub access_token: String, +}