mirror of https://github.com/mastodon/flodgatt
Finish building out postgres auth
This commit is contained in:
parent
debf01770e
commit
6746514f9a
|
@ -1,4 +1,22 @@
|
|||
use super::query;
|
||||
use postgres;
|
||||
use warp::Filter;
|
||||
|
||||
pub fn get_token() -> warp::filters::BoxedFilter<(String,)> {
|
||||
let token_from_header = warp::header::header::<String>("authorization")
|
||||
.map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string());
|
||||
|
||||
let token_from_query = warp::query().map(|q: query::Auth| q.access_token);
|
||||
token_from_query.or(token_from_header).unify().boxed()
|
||||
}
|
||||
|
||||
pub fn get_account_id_from_token(token: String) -> Result<i64, warp::reject::Rejection> {
|
||||
if let Ok(account_id) = get_account_id(token) {
|
||||
Ok(account_id)
|
||||
} else {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
}
|
||||
}
|
||||
|
||||
fn conn() -> postgres::Connection {
|
||||
postgres::Connection::connect(
|
|
@ -0,0 +1,29 @@
|
|||
use serde_derive::Serialize;
|
||||
#[derive(Serialize)]
|
||||
struct ErrorMessage {
|
||||
error: String,
|
||||
}
|
||||
impl ErrorMessage {
|
||||
fn new(msg: impl std::fmt::Display) -> Self {
|
||||
Self {
|
||||
error: msg.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_errors(
|
||||
rejection: warp::reject::Rejection,
|
||||
) -> Result<impl warp::Reply, warp::reject::Rejection> {
|
||||
let err_txt = match rejection.cause() {
|
||||
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
|
||||
"Error: Missing access token".to_string()
|
||||
}
|
||||
Some(text) => text.to_string(),
|
||||
None => "Unknown server error".to_string(),
|
||||
};
|
||||
let json = warp::reply::json(&ErrorMessage::new(err_txt));
|
||||
Ok(warp::reply::with_status(
|
||||
json,
|
||||
warp::http::StatusCode::UNAUTHORIZED,
|
||||
))
|
||||
}
|
70
src/main.rs
70
src/main.rs
|
@ -1,74 +1,22 @@
|
|||
mod pg;
|
||||
mod auth;
|
||||
mod error;
|
||||
mod pubsub;
|
||||
mod query;
|
||||
use futures::stream::Stream;
|
||||
use log::info;
|
||||
use pretty_env_logger;
|
||||
use serde_derive::Serialize;
|
||||
use warp::{path, Filter};
|
||||
|
||||
fn main() {
|
||||
pretty_env_logger::init();
|
||||
let base = path!("api" / "v1" / "streaming");
|
||||
|
||||
let token = warp::any()
|
||||
.and(warp::header::optional::<String>("authorization"))
|
||||
.map(|auth_header: Option<String>| {
|
||||
if let Some(header_value) = auth_header {
|
||||
Some(
|
||||
header_value
|
||||
.split(" ")
|
||||
.nth(1)
|
||||
.unwrap_or("invalid token")
|
||||
.to_string(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
fn get_account_id_from_token(token: Option<String>) -> Result<i64, warp::reject::Rejection> {
|
||||
if token.is_none() {
|
||||
Err(warp::reject::custom("Error: Missing access token"))
|
||||
} else if let Ok(account_id) = pg::get_account_id(token.unwrap()) {
|
||||
Ok(account_id)
|
||||
} else {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ErrorMessage {
|
||||
error: String,
|
||||
}
|
||||
impl ErrorMessage {
|
||||
fn new(msg: impl std::fmt::Display) -> Self {
|
||||
Self {
|
||||
error: msg.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_errors(
|
||||
rejection: warp::reject::Rejection,
|
||||
) -> Result<impl warp::Reply, warp::reject::Rejection> {
|
||||
let err_txt = match rejection.cause() {
|
||||
Some(text) => text.to_string(),
|
||||
None => "Unknown server error".to_string(),
|
||||
};
|
||||
let json = warp::reply::json(&ErrorMessage::new(err_txt));
|
||||
Ok(warp::reply::with_status(
|
||||
json,
|
||||
warp::http::StatusCode::UNAUTHORIZED,
|
||||
))
|
||||
}
|
||||
|
||||
// GET /api/v1/streaming/user
|
||||
let user_timeline = base
|
||||
.and(path("user"))
|
||||
.and(path::end())
|
||||
.and(token)
|
||||
.and_then(get_account_id_from_token)
|
||||
.and(auth::get_token())
|
||||
.and_then(auth::get_account_id_from_token)
|
||||
.map(|account_id: i64| {
|
||||
info!("GET /api/v1/streaming/user");
|
||||
pubsub::stream_from(account_id.to_string())
|
||||
|
@ -78,8 +26,8 @@ fn main() {
|
|||
let user_timeline_notifications = base
|
||||
.and(path!("user" / "notification"))
|
||||
.and(path::end())
|
||||
.and(token)
|
||||
.and_then(get_account_id_from_token)
|
||||
.and(auth::get_token())
|
||||
.and_then(auth::get_account_id_from_token)
|
||||
.map(|account_id: i64| {
|
||||
let full_stream = pubsub::stream_from(account_id.to_string());
|
||||
// TODO: filter stream to just have notifications
|
||||
|
@ -133,8 +81,8 @@ fn main() {
|
|||
let direct_timeline = base
|
||||
.and(path("direct"))
|
||||
.and(path::end())
|
||||
.and(token)
|
||||
.and_then(get_account_id_from_token)
|
||||
.and(auth::get_token())
|
||||
.and_then(auth::get_account_id_from_token)
|
||||
.map(|account_id: i64| {
|
||||
info!("GET /api/v1/streaming/direct");
|
||||
pubsub::stream_from(format!("direct:{}", account_id))
|
||||
|
@ -201,7 +149,7 @@ fn main() {
|
|||
None,
|
||||
))
|
||||
})
|
||||
.recover(handle_errors);
|
||||
.recover(error::handle_errors);
|
||||
|
||||
info!("starting streaming api server");
|
||||
warp::serve(routes).run(([127, 0, 0, 1], 3030));
|
||||
|
|
|
@ -12,3 +12,7 @@ pub struct Hashtag {
|
|||
pub struct List {
|
||||
pub list: String,
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
pub struct Auth {
|
||||
pub access_token: String,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue