Finish building out postgres auth

This commit is contained in:
Daniel Sockwell 2019-04-18 19:02:29 -04:00
parent debf01770e
commit 6746514f9a
4 changed files with 60 additions and 61 deletions

View File

@ -1,4 +1,22 @@
use super::query;
use postgres;
use warp::Filter;
pub fn get_token() -> warp::filters::BoxedFilter<(String,)> {
let token_from_header = warp::header::header::<String>("authorization")
.map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string());
let token_from_query = warp::query().map(|q: query::Auth| q.access_token);
token_from_query.or(token_from_header).unify().boxed()
}
pub fn get_account_id_from_token(token: String) -> Result<i64, warp::reject::Rejection> {
if let Ok(account_id) = get_account_id(token) {
Ok(account_id)
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
fn conn() -> postgres::Connection {
postgres::Connection::connect(

29
src/error.rs Normal file
View File

@ -0,0 +1,29 @@
use serde_derive::Serialize;
#[derive(Serialize)]
struct ErrorMessage {
error: String,
}
impl ErrorMessage {
fn new(msg: impl std::fmt::Display) -> Self {
Self {
error: msg.to_string(),
}
}
}
pub fn handle_errors(
rejection: warp::reject::Rejection,
) -> Result<impl warp::Reply, warp::reject::Rejection> {
let err_txt = match rejection.cause() {
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
"Error: Missing access token".to_string()
}
Some(text) => text.to_string(),
None => "Unknown server error".to_string(),
};
let json = warp::reply::json(&ErrorMessage::new(err_txt));
Ok(warp::reply::with_status(
json,
warp::http::StatusCode::UNAUTHORIZED,
))
}

View File

@ -1,74 +1,22 @@
mod pg;
mod auth;
mod error;
mod pubsub;
mod query;
use futures::stream::Stream;
use log::info;
use pretty_env_logger;
use serde_derive::Serialize;
use warp::{path, Filter};
fn main() {
pretty_env_logger::init();
let base = path!("api" / "v1" / "streaming");
let token = warp::any()
.and(warp::header::optional::<String>("authorization"))
.map(|auth_header: Option<String>| {
if let Some(header_value) = auth_header {
Some(
header_value
.split(" ")
.nth(1)
.unwrap_or("invalid token")
.to_string(),
)
} else {
None
}
});
fn get_account_id_from_token(token: Option<String>) -> Result<i64, warp::reject::Rejection> {
if token.is_none() {
Err(warp::reject::custom("Error: Missing access token"))
} else if let Ok(account_id) = pg::get_account_id(token.unwrap()) {
Ok(account_id)
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
#[derive(Serialize)]
struct ErrorMessage {
error: String,
}
impl ErrorMessage {
fn new(msg: impl std::fmt::Display) -> Self {
Self {
error: msg.to_string(),
}
}
}
fn handle_errors(
rejection: warp::reject::Rejection,
) -> Result<impl warp::Reply, warp::reject::Rejection> {
let err_txt = match rejection.cause() {
Some(text) => text.to_string(),
None => "Unknown server error".to_string(),
};
let json = warp::reply::json(&ErrorMessage::new(err_txt));
Ok(warp::reply::with_status(
json,
warp::http::StatusCode::UNAUTHORIZED,
))
}
// GET /api/v1/streaming/user
let user_timeline = base
.and(path("user"))
.and(path::end())
.and(token)
.and_then(get_account_id_from_token)
.and(auth::get_token())
.and_then(auth::get_account_id_from_token)
.map(|account_id: i64| {
info!("GET /api/v1/streaming/user");
pubsub::stream_from(account_id.to_string())
@ -78,8 +26,8 @@ fn main() {
let user_timeline_notifications = base
.and(path!("user" / "notification"))
.and(path::end())
.and(token)
.and_then(get_account_id_from_token)
.and(auth::get_token())
.and_then(auth::get_account_id_from_token)
.map(|account_id: i64| {
let full_stream = pubsub::stream_from(account_id.to_string());
// TODO: filter stream to just have notifications
@ -133,8 +81,8 @@ fn main() {
let direct_timeline = base
.and(path("direct"))
.and(path::end())
.and(token)
.and_then(get_account_id_from_token)
.and(auth::get_token())
.and_then(auth::get_account_id_from_token)
.map(|account_id: i64| {
info!("GET /api/v1/streaming/direct");
pubsub::stream_from(format!("direct:{}", account_id))
@ -201,7 +149,7 @@ fn main() {
None,
))
})
.recover(handle_errors);
.recover(error::handle_errors);
info!("starting streaming api server");
warp::serve(routes).run(([127, 0, 0, 1], 3030));

View File

@ -12,3 +12,7 @@ pub struct Hashtag {
pub struct List {
pub list: String,
}
#[derive(Deserialize)]
pub struct Auth {
pub access_token: String,
}