mirror of https://github.com/mastodon/flodgatt
Refactor main()
This commit is contained in:
parent
bdb402798c
commit
cf55eb7019
|
@ -35,7 +35,7 @@
|
||||||
//! polls the `Receiver` and the frequency with which the `Receiver` polls Redis.
|
//! polls the `Receiver` and the frequency with which the `Receiver` polls Redis.
|
||||||
//!
|
//!
|
||||||
|
|
||||||
#![warn(clippy::pedantic)]
|
//#![warn(clippy::pedantic)]
|
||||||
#![allow(clippy::try_err, clippy::match_bool)]
|
#![allow(clippy::try_err, clippy::match_bool)]
|
||||||
|
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
|
116
src/main.rs
116
src/main.rs
|
@ -5,11 +5,14 @@ use flodgatt::request::{PgPool, Subscription, Timeline};
|
||||||
use flodgatt::response::redis;
|
use flodgatt::response::redis;
|
||||||
use flodgatt::response::stream;
|
use flodgatt::response::stream;
|
||||||
|
|
||||||
|
use futures::{future::lazy, stream::Stream as _Stream};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::os::unix::fs::PermissionsExt;
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
use std::time::Instant;
|
||||||
use tokio::net::UnixListener;
|
use tokio::net::UnixListener;
|
||||||
use tokio::sync::{mpsc, watch};
|
use tokio::sync::{mpsc, watch};
|
||||||
|
use tokio::timer::Interval;
|
||||||
use warp::http::StatusCode;
|
use warp::http::StatusCode;
|
||||||
use warp::path;
|
use warp::path;
|
||||||
use warp::ws::Ws2;
|
use warp::ws::Ws2;
|
||||||
|
@ -18,28 +21,27 @@ use warp::{Filter, Rejection};
|
||||||
fn main() -> Result<(), FatalErr> {
|
fn main() -> Result<(), FatalErr> {
|
||||||
config::merge_dotenv()?;
|
config::merge_dotenv()?;
|
||||||
pretty_env_logger::try_init()?;
|
pretty_env_logger::try_init()?;
|
||||||
|
|
||||||
let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect());
|
let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect());
|
||||||
|
|
||||||
|
// Create channels to communicate between threads
|
||||||
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
|
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
|
||||||
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
|
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
let shared_pg_conn = PgPool::new(postgres_cfg, *cfg.whitelist_mode);
|
let shared_pg_conn = PgPool::new(postgres_cfg, *cfg.whitelist_mode);
|
||||||
let poll_freq = *redis_cfg.polling_interval;
|
let poll_freq = *redis_cfg.polling_interval;
|
||||||
let manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
|
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
|
||||||
|
|
||||||
// Server Sent Events
|
// Server Sent Events
|
||||||
let sse_manager = manager.clone();
|
let sse_manager = shared_manager.clone();
|
||||||
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
|
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
|
||||||
let sse_routes = Subscription::from_sse_request(shared_pg_conn.clone())
|
let sse = Subscription::from_sse_request(shared_pg_conn.clone())
|
||||||
.and(warp::sse())
|
.and(warp::sse())
|
||||||
.map(
|
.map(
|
||||||
move |subscription: Subscription, client_conn: warp::sse::Sse| {
|
move |subscription: Subscription, client_conn: warp::sse::Sse| {
|
||||||
log::info!("Incoming SSE request for {:?}", subscription.timeline);
|
log::info!("Incoming SSE request for {:?}", subscription.timeline);
|
||||||
{
|
{
|
||||||
let mut manager = sse_manager.lock().unwrap_or_else(redis::Manager::recover);
|
let mut manager = sse_manager.lock().unwrap_or_else(redis::Manager::recover);
|
||||||
manager.subscribe(&subscription).unwrap_or_else(|e| {
|
manager.subscribe(&subscription);
|
||||||
log::error!("Could not subscribe to the Redis channel: {}", e)
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
stream::Sse::send_events(
|
stream::Sse::send_events(
|
||||||
|
@ -53,29 +55,19 @@ fn main() -> Result<(), FatalErr> {
|
||||||
.with(warp::reply::with::header("Connection", "keep-alive"));
|
.with(warp::reply::with::header("Connection", "keep-alive"));
|
||||||
|
|
||||||
// WebSocket
|
// WebSocket
|
||||||
let ws_manager = manager.clone();
|
let ws_manager = shared_manager.clone();
|
||||||
let ws_routes = Subscription::from_ws_request(shared_pg_conn)
|
let ws = Subscription::from_ws_request(shared_pg_conn)
|
||||||
.and(warp::ws::ws2())
|
.and(warp::ws::ws2())
|
||||||
.map(move |subscription: Subscription, ws: Ws2| {
|
.map(move |subscription: Subscription, ws: Ws2| {
|
||||||
log::info!("Incoming websocket request for {:?}", subscription.timeline);
|
log::info!("Incoming websocket request for {:?}", subscription.timeline);
|
||||||
{
|
{
|
||||||
let mut manager = ws_manager.lock().unwrap_or_else(redis::Manager::recover);
|
let mut manager = ws_manager.lock().unwrap_or_else(redis::Manager::recover);
|
||||||
|
manager.subscribe(&subscription);
|
||||||
manager.subscribe(&subscription).unwrap_or_else(|e| {
|
|
||||||
log::error!("Could not subscribe to the Redis channel: {}", e)
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
let cmd_tx = cmd_tx.clone();
|
let token = subscription.access_token.clone().unwrap_or_default(); // token sent for security
|
||||||
let ws_rx = event_rx.clone();
|
let ws_stream = stream::Ws::new(cmd_tx.clone(), event_rx.clone(), subscription);
|
||||||
let token = subscription
|
|
||||||
.clone()
|
|
||||||
.access_token
|
|
||||||
.unwrap_or_else(String::new);
|
|
||||||
|
|
||||||
let ws_response_stream = ws
|
(ws.on_upgrade(move |ws| ws_stream.send_to(ws)), token)
|
||||||
.on_upgrade(move |ws| stream::Ws::new(ws, cmd_tx, subscription).send_events(ws_rx));
|
|
||||||
|
|
||||||
(ws_response_stream, token)
|
|
||||||
})
|
})
|
||||||
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
|
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
|
||||||
|
|
||||||
|
@ -84,9 +76,10 @@ fn main() -> Result<(), FatalErr> {
|
||||||
.allow_methods(cfg.cors.allowed_methods)
|
.allow_methods(cfg.cors.allowed_methods)
|
||||||
.allow_headers(cfg.cors.allowed_headers);
|
.allow_headers(cfg.cors.allowed_headers);
|
||||||
|
|
||||||
|
// TODO -- extract to separate file
|
||||||
#[cfg(feature = "stub_status")]
|
#[cfg(feature = "stub_status")]
|
||||||
let status_endpoints = {
|
let status = {
|
||||||
let (r1, r3) = (manager.clone(), manager.clone());
|
let (r1, r3) = (shared_manager.clone(), shared_manager.clone());
|
||||||
warp::path!("api" / "v1" / "streaming" / "health")
|
warp::path!("api" / "v1" / "streaming" / "health")
|
||||||
.map(|| "OK")
|
.map(|| "OK")
|
||||||
.or(warp::path!("api" / "v1" / "streaming" / "status")
|
.or(warp::path!("api" / "v1" / "streaming" / "status")
|
||||||
|
@ -98,54 +91,43 @@ fn main() -> Result<(), FatalErr> {
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
#[cfg(not(feature = "stub_status"))]
|
#[cfg(not(feature = "stub_status"))]
|
||||||
let status_endpoints = warp::path!("api" / "v1" / "streaming" / "health").map(|| "OK");
|
let status = warp::path!("api" / "v1" / "streaming" / "health").map(|| "OK");
|
||||||
|
|
||||||
|
let streaming_server = move || {
|
||||||
|
let manager = shared_manager.clone();
|
||||||
|
let stream = Interval::new(Instant::now(), poll_freq)
|
||||||
|
.map_err(|e| log::error!("{}", e))
|
||||||
|
.for_each(move |_| {
|
||||||
|
let mut manager = manager.lock().unwrap_or_else(redis::Manager::recover);
|
||||||
|
manager.poll_broadcast().unwrap_or_else(FatalErr::exit);
|
||||||
|
Ok(())
|
||||||
|
});
|
||||||
|
warp::spawn(lazy(move || stream));
|
||||||
|
warp::serve(ws.or(sse).with(cors).or(status).recover(recover))
|
||||||
|
};
|
||||||
|
|
||||||
if let Some(socket) = &*cfg.unix_socket {
|
if let Some(socket) = &*cfg.unix_socket {
|
||||||
log::info!("Using Unix socket {}", socket);
|
log::info!("Using Unix socket {}", socket);
|
||||||
fs::remove_file(socket).unwrap_or_default();
|
fs::remove_file(socket).unwrap_or_default();
|
||||||
let incoming = UnixListener::bind(socket).unwrap().incoming();
|
let incoming = UnixListener::bind(socket).expect("TODO").incoming();
|
||||||
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).unwrap();
|
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).expect("TODO");
|
||||||
|
|
||||||
warp::serve(
|
tokio::run(lazy(|| streaming_server().serve_incoming(incoming)));
|
||||||
ws_routes
|
|
||||||
.or(sse_routes)
|
|
||||||
.with(cors)
|
|
||||||
.or(status_endpoints)
|
|
||||||
.recover(|r: Rejection| {
|
|
||||||
let json_err = match r.cause() {
|
|
||||||
Some(text)
|
|
||||||
if text.to_string() == "Missing request header 'authorization'" =>
|
|
||||||
{
|
|
||||||
warp::reply::json(&"Error: Missing access token".to_string())
|
|
||||||
}
|
|
||||||
Some(text) => warp::reply::json(&text.to_string()),
|
|
||||||
None => warp::reply::json(&"Error: Nonexistant endpoint".to_string()),
|
|
||||||
};
|
|
||||||
Ok(warp::reply::with_status(json_err, StatusCode::UNAUTHORIZED))
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.run_incoming(incoming);
|
|
||||||
} else {
|
} else {
|
||||||
use futures::{future::lazy, stream::Stream as _Stream};
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
let server_addr = SocketAddr::new(*cfg.address, *cfg.port);
|
let server_addr = SocketAddr::new(*cfg.address, *cfg.port);
|
||||||
|
tokio::run(lazy(move || streaming_server().bind(server_addr)));
|
||||||
tokio::run(lazy(move || {
|
}
|
||||||
let receiver = manager.clone();
|
|
||||||
|
|
||||||
warp::spawn(lazy(move || {
|
|
||||||
tokio::timer::Interval::new(Instant::now(), poll_freq)
|
|
||||||
.map_err(|e| log::error!("{}", e))
|
|
||||||
.for_each(move |_| {
|
|
||||||
let mut receiver = receiver.lock().unwrap_or_else(redis::Manager::recover);
|
|
||||||
receiver.poll_broadcast().unwrap_or_else(FatalErr::exit);
|
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
}));
|
|
||||||
|
|
||||||
warp::serve(ws_routes.or(sse_routes).with(cors).or(status_endpoints)).bind(server_addr)
|
|
||||||
}));
|
|
||||||
};
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO -- extract to separate file
|
||||||
|
fn recover(r: Rejection) -> Result<impl warp::Reply, warp::Rejection> {
|
||||||
|
let json_err = match r.cause() {
|
||||||
|
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
|
||||||
|
warp::reply::json(&"Error: Missing access token".to_string())
|
||||||
|
}
|
||||||
|
Some(text) => warp::reply::json(&text.to_string()),
|
||||||
|
None => warp::reply::json(&"Error: Nonexistant endpoint".to_string()),
|
||||||
|
};
|
||||||
|
Ok(warp::reply::with_status(json_err, StatusCode::UNAUTHORIZED))
|
||||||
|
}
|
||||||
|
|
|
@ -7,11 +7,60 @@ mod subscription;
|
||||||
pub use self::postgres::PgPool;
|
pub use self::postgres::PgPool;
|
||||||
// TODO consider whether we can remove `Stream` from public API
|
// TODO consider whether we can remove `Stream` from public API
|
||||||
pub use subscription::{Blocks, Stream, Subscription, Timeline};
|
pub use subscription::{Blocks, Stream, Subscription, Timeline};
|
||||||
|
|
||||||
//#[cfg(test)]
|
|
||||||
pub use subscription::{Content, Reach};
|
pub use subscription::{Content, Reach};
|
||||||
|
|
||||||
|
use self::query::Query;
|
||||||
|
use crate::config;
|
||||||
|
use warp::{filters::BoxedFilter, path, reject::Rejection, Filter};
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod sse_test;
|
mod sse_test;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod ws_test;
|
mod ws_test;
|
||||||
|
|
||||||
|
pub struct Handler {
|
||||||
|
pg_conn: PgPool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Handler {
|
||||||
|
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
pg_conn: PgPool::new(postgres_cfg, whitelist_mode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_ws_request(&self) -> BoxedFilter<(Subscription,)> {
|
||||||
|
let pg_conn = self.pg_conn.clone();
|
||||||
|
parse_ws_query()
|
||||||
|
.and(query::OptionalAccessToken::from_ws_header())
|
||||||
|
.and_then(Query::update_access_token)
|
||||||
|
.and_then(move |q| Subscription::from_query(q, pg_conn.clone()))
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_ws_query() -> BoxedFilter<(Query,)> {
|
||||||
|
path!("api" / "v1" / "streaming")
|
||||||
|
.and(path::end())
|
||||||
|
.and(warp::query())
|
||||||
|
.and(query::Auth::to_filter())
|
||||||
|
.and(query::Media::to_filter())
|
||||||
|
.and(query::Hashtag::to_filter())
|
||||||
|
.and(query::List::to_filter())
|
||||||
|
.map(
|
||||||
|
|stream: query::Stream,
|
||||||
|
auth: query::Auth,
|
||||||
|
media: query::Media,
|
||||||
|
hashtag: query::Hashtag,
|
||||||
|
list: query::List| {
|
||||||
|
Query {
|
||||||
|
access_token: auth.access_token,
|
||||||
|
stream: stream.stream,
|
||||||
|
media: media.is_truthy(),
|
||||||
|
hashtag: hashtag.tag,
|
||||||
|
list: list.list,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
|
|
@ -117,7 +117,7 @@ impl Subscription {
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_query(q: Query, pool: PgPool) -> Result<Self, Rejection> {
|
pub(super) fn from_query(q: Query, pool: PgPool) -> Result<Self, Rejection> {
|
||||||
let user = pool.clone().select_user(&q.access_token)?;
|
let user = pool.clone().select_user(&q.access_token)?;
|
||||||
let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?;
|
let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?;
|
||||||
let hashtag_name = match timeline {
|
let hashtag_name = match timeline {
|
||||||
|
|
|
@ -50,7 +50,7 @@ impl Manager {
|
||||||
Arc::new(Mutex::new(self))
|
Arc::new(Mutex::new(self))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn subscribe(&mut self, subscription: &Subscription) -> Result<()> {
|
pub fn subscribe(&mut self, subscription: &Subscription) {
|
||||||
let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline);
|
let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline);
|
||||||
if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (tag, tl) {
|
if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (tag, tl) {
|
||||||
self.redis_connection.update_cache(hashtag, id);
|
self.redis_connection.update_cache(hashtag, id);
|
||||||
|
@ -64,9 +64,10 @@ impl Manager {
|
||||||
|
|
||||||
use RedisCmd::*;
|
use RedisCmd::*;
|
||||||
if *number_of_subscriptions == 1 {
|
if *number_of_subscriptions == 1 {
|
||||||
self.redis_connection.send_cmd(Subscribe, &tl)?
|
self.redis_connection
|
||||||
|
.send_cmd(Subscribe, &tl)
|
||||||
|
.unwrap_or_else(|e| log::error!("Could not subscribe to the Redis channel: {}", e));
|
||||||
};
|
};
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unsubscribe(&mut self, tl: Timeline) -> Result<()> {
|
pub fn unsubscribe(&mut self, tl: Timeline) -> Result<()> {
|
||||||
|
|
|
@ -12,20 +12,31 @@ use warp::{
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct Ws {
|
pub struct Ws {
|
||||||
ws_tx: mpsc::UnboundedSender<Message>,
|
|
||||||
unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||||
subscription: Subscription,
|
subscription: Subscription,
|
||||||
|
ws_rx: watch::Receiver<(Timeline, Event)>,
|
||||||
|
ws_tx: Option<mpsc::UnboundedSender<Message>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Ws {
|
impl Ws {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
ws: WebSocket,
|
|
||||||
unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||||
|
ws_rx: watch::Receiver<(Timeline, Event)>,
|
||||||
subscription: Subscription,
|
subscription: Subscription,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
unsubscribe_tx,
|
||||||
|
subscription,
|
||||||
|
ws_rx,
|
||||||
|
ws_tx: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_to(mut self, ws: WebSocket) -> impl Future<Item = (), Error = ()> {
|
||||||
let (transmit_to_ws, _receive_from_ws) = ws.split();
|
let (transmit_to_ws, _receive_from_ws) = ws.split();
|
||||||
// Create a pipe
|
// Create a pipe
|
||||||
let (ws_tx, ws_rx) = mpsc::unbounded_channel();
|
let (ws_tx, ws_rx) = mpsc::unbounded_channel();
|
||||||
|
self.ws_tx = Some(ws_tx);
|
||||||
|
|
||||||
// Send one end of it to a different green thread and tell that end to forward
|
// Send one end of it to a different green thread and tell that end to forward
|
||||||
// whatever it gets on to the WebSocket client
|
// whatever it gets on to the WebSocket client
|
||||||
|
@ -39,20 +50,11 @@ impl Ws {
|
||||||
_ => log::warn!("WebSocket send error: {}", e),
|
_ => log::warn!("WebSocket send error: {}", e),
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
Self {
|
|
||||||
ws_tx,
|
|
||||||
unsubscribe_tx,
|
|
||||||
subscription,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_events(
|
|
||||||
mut self,
|
|
||||||
event_rx: watch::Receiver<(Timeline, Event)>,
|
|
||||||
) -> impl Future<Item = (), Error = ()> {
|
|
||||||
let target_timeline = self.subscription.timeline;
|
let target_timeline = self.subscription.timeline;
|
||||||
|
let incoming_events = self.ws_rx.clone().map_err(|_| ());
|
||||||
|
|
||||||
event_rx.map_err(|_| ()).for_each(move |(tl, event)| {
|
incoming_events.for_each(move |(tl, event)| {
|
||||||
if matches!(event, Event::Ping) {
|
if matches!(event, Event::Ping) {
|
||||||
self.send_ping()
|
self.send_ping()
|
||||||
} else if target_timeline == tl {
|
} else if target_timeline == tl {
|
||||||
|
@ -97,7 +99,7 @@ impl Ws {
|
||||||
|
|
||||||
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
|
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
|
||||||
let tl = self.subscription.timeline;
|
let tl = self.subscription.timeline;
|
||||||
match self.ws_tx.try_send(Message::text(txt)) {
|
match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) {
|
||||||
Ok(_) => Ok(()),
|
Ok(_) => Ok(()),
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
self.unsubscribe_tx
|
self.unsubscribe_tx
|
||||||
|
|
Loading…
Reference in New Issue