Error handling, pt3 (#131)

* Improve handling of Postgres errors

* Finish error handling improvements

* Remove `format!` calls from hot path
This commit is contained in:
Daniel Sockwell 2020-04-14 20:37:49 -04:00 committed by GitHub
parent 45f9d4b9fb
commit 37b652ad79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 372 additions and 257 deletions

View File

2
Cargo.lock generated
View File

@ -406,7 +406,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "flodgatt"
version = "0.8.3"
version = "0.8.4"
dependencies = [
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@ -1,7 +1,7 @@
[package]
name = "flodgatt"
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
version = "0.8.3"
version = "0.8.4"
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
edition = "2018"

View File

@ -1,7 +1,7 @@
pub use {deployment_cfg::Deployment, postgres_cfg::Postgres, redis_cfg::Redis};
use self::environmental_variables::EnvVar;
use super::err;
use super::err::FatalErr;
use hashbrown::HashMap;
use std::env;
@ -13,22 +13,45 @@ mod postgres_cfg_types;
mod redis_cfg;
mod redis_cfg_types;
pub fn merge_dotenv() -> Result<(), err::FatalErr> {
// TODO -- should this allow the user to run in a dir without a `.env` file?
dotenv::from_filename(match env::var("ENV").ok().as_deref() {
type Result<T> = std::result::Result<T, FatalErr>;
pub fn merge_dotenv() -> Result<()> {
let env_file = match env::var("ENV").ok().as_deref() {
Some("production") => ".env.production",
Some("development") | None => ".env",
Some(_unsupported) => Err(err::FatalErr::Unknown)?, // TODO make more specific
})?;
Some(v) => Err(FatalErr::config("ENV", v, "`production` or `development`"))?,
};
let res = dotenv::from_filename(env_file);
if let Ok(log_level) = env::var("RUST_LOG") {
if res.is_err() && ["warn", "info", "trace", "debug"].contains(&log_level.as_str()) {
eprintln!(
" WARN: could not load environmental variables from {:?}\n\
{:8}Are you in the right directory? Proceeding with variables from the environment.",
env::current_dir().unwrap_or_else(|_|"./".into()).join(env_file), ""
);
}
}
Ok(())
}
pub fn from_env<'a>(env_vars: HashMap<String, String>) -> (Postgres, Redis, Deployment<'a>) {
#[allow(clippy::implicit_hasher)]
pub fn from_env<'a>(
env_vars: HashMap<String, String>,
) -> Result<(Postgres, Redis, Deployment<'a>)> {
let env_vars = EnvVar::new(env_vars);
log::info!("Environmental variables Flodgatt received: {}", &env_vars);
(
Postgres::from_env(env_vars.clone()),
Redis::from_env(env_vars.clone()),
Deployment::from_env(env_vars.clone()),
)
log::info!(
"Flodgatt received the following environmental variables:{}",
&env_vars
);
let pg_cfg = Postgres::from_env(env_vars.clone())?;
log::info!("Configuration for {:#?}", &pg_cfg);
let redis_cfg = Redis::from_env(env_vars.clone())?;
log::info!("Configuration for {:#?},", &redis_cfg);
let deployment_cfg = Deployment::from_env(&env_vars)?;
log::info!("Configuration for {:#?}", &deployment_cfg);
Ok((pg_cfg, redis_cfg, deployment_cfg))
}

View File

@ -1,4 +1,5 @@
use super::{deployment_cfg_types::*, EnvVar};
use crate::err::FatalErr;
#[derive(Debug, Default)]
pub struct Deployment<'a> {
@ -14,20 +15,19 @@ pub struct Deployment<'a> {
}
impl Deployment<'_> {
pub fn from_env(env: EnvVar) -> Self {
pub fn from_env(env: &EnvVar) -> Result<Self, FatalErr> {
let mut cfg = Self {
env: Env::default().maybe_update(env.get("NODE_ENV")),
log_level: LogLevel::default().maybe_update(env.get("RUST_LOG")),
address: FlodgattAddr::default().maybe_update(env.get("BIND")),
port: Port::default().maybe_update(env.get("PORT")),
unix_socket: Socket::default().maybe_update(env.get("SOCKET")),
sse_interval: SseInterval::default().maybe_update(env.get("SSE_FREQ")),
ws_interval: WsInterval::default().maybe_update(env.get("WS_FREQ")),
whitelist_mode: WhitelistMode::default().maybe_update(env.get("WHITELIST_MODE")),
env: Env::default().maybe_update(env.get("NODE_ENV"))?,
log_level: LogLevel::default().maybe_update(env.get("RUST_LOG"))?,
address: FlodgattAddr::default().maybe_update(env.get("BIND"))?,
port: Port::default().maybe_update(env.get("PORT"))?,
unix_socket: Socket::default().maybe_update(env.get("SOCKET"))?,
sse_interval: SseInterval::default().maybe_update(env.get("SSE_FREQ"))?,
ws_interval: WsInterval::default().maybe_update(env.get("WS_FREQ"))?,
whitelist_mode: WhitelistMode::default().maybe_update(env.get("WHITELIST_MODE"))?,
cors: Cors::default(),
};
cfg.env = cfg.env.maybe_update(env.get("RUST_ENV"));
log::info!("Using deployment configuration:\n {:#?}", &cfg);
cfg
cfg.env = cfg.env.maybe_update(env.get("RUST_ENV"))?;
Ok(cfg)
}
}

View File

@ -17,7 +17,7 @@ from_env_var!(
from_env_var!(
/// The address to run Flodgatt on
let name = FlodgattAddr;
let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded"));
let default: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)");
let from_str = |s| match s {
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),

View File

@ -25,17 +25,6 @@ impl EnvVar {
self.0.insert(key.to_string(), value.to_string());
}
}
pub fn err(env_var: &str, supplied_value: &str, allowed_values: &str) -> ! {
log::error!(
r"{var} is set to `{value}`, which is invalid.
{var} must be {allowed_vals}.",
var = env_var,
value = supplied_value,
allowed_vals = allowed_values
);
std::process::exit(1);
}
}
impl fmt::Display for EnvVar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@ -98,7 +87,7 @@ macro_rules! from_env_var {
#[derive(Clone)]
pub struct $name(pub $type);
impl std::fmt::Debug for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
@ -117,14 +106,14 @@ macro_rules! from_env_var {
fn inner_from_str($arg: &str) -> Option<$type> {
$body
}
pub fn maybe_update(self, var: Option<&String>) -> Self {
match var {
pub fn maybe_update(self, var: Option<&String>) -> Result<Self, crate::err::FatalErr> {
Ok(match var {
Some(empty_string) if empty_string.is_empty() => Self::default(),
Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
crate::config::EnvVar::err($env_var, value, $allowed_values)
})),
Some(value) => Self(Self::inner_from_str(value).ok_or_else(|| {
crate::err::FatalErr::config($env_var, value, $allowed_values)
})?),
None => self,
}
})
}
}
};

View File

@ -1,7 +1,11 @@
use super::{postgres_cfg_types::*, EnvVar};
use crate::err::FatalErr;
use url::Url;
use urlencoding;
type Result<T> = std::result::Result<T, FatalErr>;
#[derive(Debug, Clone)]
pub struct Postgres {
pub user: PgUser,
@ -13,8 +17,8 @@ pub struct Postgres {
}
impl EnvVar {
fn update_with_postgres_url(mut self, url_str: &str) -> Self {
let url = Url::parse(url_str).unwrap();
fn update_with_postgres_url(mut self, url_str: &str) -> Result<Self> {
let url = Url::parse(url_str)?;
let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) };
for (k, v) in url.query_pairs().into_owned() {
@ -23,11 +27,11 @@ impl EnvVar {
"password" => self.maybe_add_env_var("DB_PASS", Some(v.to_string())),
"host" => self.maybe_add_env_var("DB_HOST", Some(v.to_string())),
"sslmode" => self.maybe_add_env_var("DB_SSLMODE", Some(v.to_string())),
_ => crate::err::die_with_msg(format!(
r"Unsupported parameter {} in POSTGRES_URL
Flodgatt supports only `password`, `user`, `host`, and `sslmode` parameters",
k
)),
_ => Err(FatalErr::config(
"POSTGRES_URL",
&k,
"a URL with parameters `password`, `user`, `host`, and `sslmode` only",
))?,
}
}
@ -35,42 +39,39 @@ impl EnvVar {
self.maybe_add_env_var("DB_PASS", url.password());
self.maybe_add_env_var(
"DB_HOST",
url.host().map(|h| {
urlencoding::decode(&h.to_string()).expect("Non-Unicode text in hostname")
}),
url.host()
.map(|host| urlencoding::decode(&host.to_string()))
.transpose()?,
);
self.maybe_add_env_var("DB_USER", none_if_empty(url.username().to_string()));
self.maybe_add_env_var("DB_NAME", none_if_empty(url.path()[1..].to_string()));
self
Ok(self)
}
}
impl Postgres {
/// Configure Postgres and return a connection
pub fn from_env(env: EnvVar) -> Self {
pub fn from_env(env: EnvVar) -> Result<Self> {
let env = match env.get("DATABASE_URL").cloned() {
Some(url_str) => env.update_with_postgres_url(&url_str),
Some(url_str) => env.update_with_postgres_url(&url_str)?,
None => env,
};
let cfg = Self {
user: PgUser::default().maybe_update(env.get("DB_USER")),
host: PgHost::default().maybe_update(env.get("DB_HOST")),
password: PgPass::default().maybe_update(env.get("DB_PASS")),
database: PgDatabase::default().maybe_update(env.get("DB_NAME")),
port: PgPort::default().maybe_update(env.get("DB_PORT")),
ssl_mode: PgSslMode::default().maybe_update(env.get("DB_SSLMODE")),
user: PgUser::default().maybe_update(env.get("DB_USER"))?,
host: PgHost::default().maybe_update(env.get("DB_HOST"))?,
password: PgPass::default().maybe_update(env.get("DB_PASS"))?,
database: PgDatabase::default().maybe_update(env.get("DB_NAME"))?,
port: PgPort::default().maybe_update(env.get("DB_PORT"))?,
ssl_mode: PgSslMode::default().maybe_update(env.get("DB_SSLMODE"))?,
};
log::info!("Postgres configuration:\n{:#?}", &cfg);
cfg
Ok(cfg)
}
// // use openssl::ssl::{SslConnector, SslMethod};
// // use postgres_openssl::MakeTlsConnector;
// // let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
// // builder.set_ca_file("/etc/ssl/cert.pem").unwrap();
// // let mut builder = SslConnector::builder(SslMethod::tls())?;
// // builder.set_ca_file("/etc/ssl/cert.pem")?;
// // let connector = MakeTlsConnector::new(builder.build());
// // TODO: add TLS support, remove `NoTls`
}

View File

@ -1,7 +1,11 @@
use super::redis_cfg_types::*;
use crate::config::EnvVar;
use super::EnvVar;
use crate::err::FatalErr;
use url::Url;
type Result<T> = std::result::Result<T, FatalErr>;
#[derive(Debug, Default)]
pub struct Redis {
pub user: RedisUser,
@ -17,8 +21,8 @@ pub struct Redis {
}
impl EnvVar {
fn update_with_redis_url(mut self, url_str: &str) -> Self {
let url = Url::parse(url_str).unwrap();
fn update_with_redis_url(mut self, url_str: &str) -> Result<Self> {
let url = Url::parse(url_str)?;
let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) };
self.maybe_add_env_var("REDIS_PORT", url.port());
@ -29,14 +33,14 @@ impl EnvVar {
match k.to_string().as_str() {
"password" => self.maybe_add_env_var("REDIS_PASSWORD", Some(v.to_string())),
"db" => self.maybe_add_env_var("REDIS_DB", Some(v.to_string())),
_ => crate::err::die_with_msg(format!(
r"Unsupported parameter {} in REDIS_URL.
Flodgatt supports only `password` and `db` parameters.",
k
)),
_ => Err(FatalErr::config(
"REDIS_URL",
&k,
"a URL with parameters `password`, `db`, only",
))?,
}
}
self
Ok(self)
}
}
@ -46,20 +50,20 @@ impl Redis {
const DB_SET_WARNING: &'static str = r"Redis database specified, but PubSub connections do not use databases.
For similar functionality, you may wish to set a REDIS_NAMESPACE";
pub fn from_env(env: EnvVar) -> Self {
pub fn from_env(env: EnvVar) -> Result<Self> {
let env = match env.get("REDIS_URL").cloned() {
Some(url_str) => env.update_with_redis_url(&url_str),
Some(url_str) => env.update_with_redis_url(&url_str)?,
None => env,
};
let cfg = Redis {
user: RedisUser::default().maybe_update(env.get("REDIS_USER")),
password: RedisPass::default().maybe_update(env.get("REDIS_PASSWORD")),
port: RedisPort::default().maybe_update(env.get("REDIS_PORT")),
host: RedisHost::default().maybe_update(env.get("REDIS_HOST")),
db: RedisDb::default().maybe_update(env.get("REDIS_DB")),
namespace: RedisNamespace::default().maybe_update(env.get("REDIS_NAMESPACE")),
polling_interval: RedisInterval::default().maybe_update(env.get("REDIS_FREQ")),
user: RedisUser::default().maybe_update(env.get("REDIS_USER"))?,
password: RedisPass::default().maybe_update(env.get("REDIS_PASSWORD"))?,
port: RedisPort::default().maybe_update(env.get("REDIS_PORT"))?,
host: RedisHost::default().maybe_update(env.get("REDIS_HOST"))?,
db: RedisDb::default().maybe_update(env.get("REDIS_DB"))?,
namespace: RedisNamespace::default().maybe_update(env.get("REDIS_NAMESPACE"))?,
polling_interval: RedisInterval::default().maybe_update(env.get("REDIS_FREQ"))?,
};
if cfg.db.is_some() {
@ -68,7 +72,6 @@ For similar functionality, you may wish to set a REDIS_NAMESPACE";
if cfg.user.is_some() {
log::warn!("{}", Self::USER_SET_WARNING);
}
log::info!("Redis configuration:\n{:#?},", &cfg);
cfg
Ok(cfg)
}
}

View File

@ -1,17 +1,29 @@
use crate::request::RequestErr;
use crate::response::ManagerErr;
use std::fmt;
pub enum FatalErr {
Unknown,
ReceiverErr(ManagerErr),
DotEnv(dotenv::Error),
Logger(log::SetLoggerError),
Postgres(RequestErr),
Unrecoverable,
StdIo(std::io::Error),
// config errs
UrlParse(url::ParseError),
UrlEncoding(urlencoding::FromUrlEncodingError),
ConfigErr(String),
}
impl FatalErr {
pub fn exit(msg: impl fmt::Display) {
pub fn log(msg: impl fmt::Display) {
eprintln!("{}", msg);
std::process::exit(1);
}
pub fn config<T: fmt::Display>(var: T, value: T, allowed_vals: T) -> Self {
Self::ConfigErr(format!(
"{0} is set to `{1}`, which is invalid.\n{3:7}{0} must be {2}.",
var, value, allowed_vals, ""
))
}
}
@ -29,18 +41,22 @@ impl fmt::Display for FatalErr {
f,
"{}",
match self {
Unknown => "Flodgatt encountered an unknown, unrecoverable error".into(),
ReceiverErr(e) => format!("{}", e),
Logger(e) => format!("{}", e),
DotEnv(e) => format!("Could not load specified environmental file: {}", e),
StdIo(e) => format!("{}", e),
Postgres(e) => format!("could not connect to Postgres.\n{:7}{}", "", e),
ConfigErr(e) => e.to_string(),
UrlParse(e) => format!("could parse Postgres URL.\n{:7}{}", "", e),
UrlEncoding(e) => format!("could not parse POSTGRES_URL.\n{:7}{:?}", "", e),
Unrecoverable => "Flodgatt will now shut down.".into(),
}
)
}
}
impl From<dotenv::Error> for FatalErr {
fn from(e: dotenv::Error) -> Self {
Self::DotEnv(e)
impl From<RequestErr> for FatalErr {
fn from(e: RequestErr) -> Self {
Self::Postgres(e)
}
}
@ -49,15 +65,23 @@ impl From<ManagerErr> for FatalErr {
Self::ReceiverErr(e)
}
}
impl From<urlencoding::FromUrlEncodingError> for FatalErr {
fn from(e: urlencoding::FromUrlEncodingError) -> Self {
Self::UrlEncoding(e)
}
}
impl From<url::ParseError> for FatalErr {
fn from(e: url::ParseError) -> Self {
Self::UrlParse(e)
}
}
impl From<std::io::Error> for FatalErr {
fn from(e: std::io::Error) -> Self {
Self::StdIo(e)
}
}
impl From<log::SetLoggerError> for FatalErr {
fn from(e: log::SetLoggerError) -> Self {
Self::Logger(e)
}
}
// TODO delete vvvv when postgres_cfg.rs has better error handling
pub fn die_with_msg(msg: impl fmt::Display) -> ! {
eprintln!("FATAL ERROR: {}", msg);
std::process::exit(1);
}

View File

@ -2,14 +2,14 @@ mod checked_event;
mod dynamic_event;
mod err;
pub use {
checked_event::{CheckedEvent, Id},
dynamic_event::{DynEvent, DynStatus, EventKind},
err::EventErr,
};
pub use checked_event::{CheckedEvent, Id};
pub use dynamic_event::{DynEvent, DynStatus, EventKind};
pub use err::EventErr;
use serde::Serialize;
use std::{convert::TryFrom, string::String};
use std::convert::TryFrom;
use std::string::String;
use warp::sse::ServerSentEvent;
#[derive(Debug, Clone)]
pub enum Event {
@ -20,15 +20,30 @@ pub enum Event {
impl Event {
pub fn to_json_string(&self) -> String {
let event = &self.event_name();
let sendable_event = match self.payload() {
Some(payload) => SendableEvent::WithPayload { event, payload },
None => SendableEvent::NoPayload { event },
};
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
if let Event::Ping = self {
"{}".to_string()
} else {
let event = &self.event_name();
let sendable_event = match self.payload() {
Some(payload) => SendableEvent::WithPayload { event, payload },
None => SendableEvent::NoPayload { event },
};
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
}
}
pub fn event_name(&self) -> String {
pub fn to_warp_reply(&self) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
if let Event::Ping = self {
None
} else {
Some((
warp::sse::event(self.event_name()),
warp::sse::data(self.payload().unwrap_or_else(String::new)),
))
}
}
fn event_name(&self) -> String {
String::from(match self {
Self::TypeSafe(checked) => match checked {
CheckedEvent::Update { .. } => "update",
@ -45,11 +60,11 @@ impl Event {
..
}) => "update",
Self::Dynamic(DynEvent { event, .. }) => event,
Self::Ping => panic!("event_name() called on Ping"),
Self::Ping => unreachable!(), // private method only called above
})
}
pub fn payload(&self) -> Option<String> {
fn payload(&self) -> Option<String> {
use CheckedEvent::*;
match self {
Self::TypeSafe(checked) => match checked {
@ -63,7 +78,7 @@ impl Event {
FiltersChanged => None,
},
Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()),
Self::Ping => panic!("payload() called on Ping"),
Self::Ping => unreachable!(), // private method only called above
}
}
}

View File

@ -43,7 +43,7 @@ type Result<T> = std::result::Result<T, EventErr>;
impl DynEvent {
pub fn set_update(self) -> Result<Self> {
if self.event == "update" {
let kind = EventKind::Update(DynStatus::new(self.payload.clone())?);
let kind = EventKind::Update(DynStatus::new(&self.payload.clone())?);
Ok(Self { kind, ..self })
} else {
Ok(self)
@ -52,7 +52,7 @@ impl DynEvent {
}
impl DynStatus {
pub fn new(payload: Value) -> Result<Self> {
pub fn new(payload: &Value) -> Result<Self> {
use EventErr::*;
Ok(Self {
@ -61,7 +61,7 @@ impl DynStatus {
.as_str()
.ok_or(DynParse)?
.to_string(),
language: payload["language"].as_str().map(|s| s.to_string()),
language: payload["language"].as_str().map(String::from),
mentioned_users: HashSet::new(),
replied_to_user: Id::try_from(&payload["in_reply_to_account_id"]).ok(),
boosted_user: Id::try_from(&payload["reblog"]["account"]["id"]).ok(),

View File

@ -37,6 +37,7 @@
//#![warn(clippy::pedantic)]
#![allow(clippy::try_err, clippy::match_bool)]
//#![allow(clippy::large_enum_variant)]
pub mod config;
pub mod err;

View File

@ -19,20 +19,19 @@ use warp::Filter;
fn main() -> Result<(), FatalErr> {
config::merge_dotenv()?;
pretty_env_logger::try_init()?;
let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect());
let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect())?;
let poll_freq = *redis_cfg.polling_interval;
// Create channels to communicate between threads
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
let request = Handler::new(postgres_cfg, *cfg.whitelist_mode);
let poll_freq = *redis_cfg.polling_interval;
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
let request = Handler::new(&postgres_cfg, *cfg.whitelist_mode)?;
let shared_manager = redis::Manager::try_from(&redis_cfg, event_tx, cmd_rx)?.into_arc();
// Server Sent Events
let sse_manager = shared_manager.clone();
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
let sse = request
.sse_subscription()
.and(warp::sse())
@ -85,8 +84,7 @@ fn main() -> Result<(), FatalErr> {
.map_err(|e| log::error!("{}", e))
.for_each(move |_| {
let mut manager = manager.lock().unwrap_or_else(redis::Manager::recover);
manager.poll_broadcast().unwrap_or_else(FatalErr::exit);
Ok(())
manager.poll_broadcast().map_err(FatalErr::log)
});
warp::spawn(lazy(move || stream));
warp::serve(ws.or(sse).with(cors).or(status).recover(Handler::err))
@ -95,13 +93,13 @@ fn main() -> Result<(), FatalErr> {
if let Some(socket) = &*cfg.unix_socket {
log::info!("Using Unix socket {}", socket);
fs::remove_file(socket).unwrap_or_default();
let incoming = UnixListener::bind(socket).expect("TODO").incoming();
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).expect("TODO");
let incoming = UnixListener::bind(socket)?.incoming();
fs::set_permissions(socket, PermissionsExt::from_mode(0o666))?;
tokio::run(lazy(|| streaming_server().serve_incoming(incoming)));
} else {
let server_addr = SocketAddr::new(*cfg.address, *cfg.port);
tokio::run(lazy(move || streaming_server().bind(server_addr)));
}
Ok(())
Err(FatalErr::Unrecoverable) // on get here if there's an unrecoverable error in poll_broadcast.
}

View File

@ -3,8 +3,10 @@ mod postgres;
mod query;
pub mod timeline;
mod err;
mod subscription;
pub use self::err::RequestErr;
pub use self::postgres::PgPool;
// TODO consider whether we can remove `Stream` from public API
pub use subscription::{Blocks, Subscription};
@ -22,6 +24,8 @@ mod sse_test;
#[cfg(test)]
mod ws_test;
type Result<T> = std::result::Result<T, err::RequestErr>;
/// Helper macro to match on the first of any of the provided filters
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
@ -56,10 +60,10 @@ pub struct Handler {
}
impl Handler {
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self {
Self {
pg_conn: PgPool::new(postgres_cfg, whitelist_mode),
}
pub fn new(postgres_cfg: &config::Postgres, whitelist_mode: bool) -> Result<Self> {
Ok(Self {
pg_conn: PgPool::new(postgres_cfg, whitelist_mode)?,
})
}
pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> {
@ -113,7 +117,7 @@ impl Handler {
warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed()
}
pub fn err(r: Rejection) -> Result<impl warp::Reply, warp::Rejection> {
pub fn err(r: Rejection) -> std::result::Result<impl warp::Reply, warp::Rejection> {
let json_err = match r.cause() {
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
warp::reply::json(&"Error: Missing access token".to_string())

32
src/request/err.rs Normal file
View File

@ -0,0 +1,32 @@
use std::fmt;
#[derive(Debug)]
pub enum RequestErr {
Unknown,
PgPool(r2d2::Error),
Pg(postgres::Error),
}
impl std::error::Error for RequestErr {}
impl fmt::Display for RequestErr {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
use RequestErr::*;
let msg = match self {
Unknown => "Encountered an unrecoverable error related to handling a request".into(),
PgPool(e) => format!("{}", e),
Pg(e) => format!("{}", e),
};
write!(f, "{}", msg)
}
}
impl From<r2d2::Error> for RequestErr {
fn from(e: r2d2::Error) -> Self {
Self::PgPool(e)
}
}
impl From<postgres::Error> for RequestErr {
fn from(e: postgres::Error) -> Self {
Self::Pg(e)
}
}

View File

@ -1,13 +1,13 @@
//! Postgres queries
use super::err;
use super::timeline::{Scope, UserData};
use crate::config;
use crate::event::Id;
use crate::request::timeline::{Scope, UserData};
use ::postgres;
use hashbrown::HashSet;
use r2d2_postgres::PostgresConnectionManager;
use std::convert::TryFrom;
use warp::reject::Rejection;
#[derive(Clone, Debug)]
pub struct PgPool {
@ -15,8 +15,11 @@ pub struct PgPool {
whitelist_mode: bool,
}
type Result<T> = std::result::Result<T, err::RequestErr>;
type Rejectable<T> = std::result::Result<T, warp::Rejection>;
impl PgPool {
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self {
pub fn new(pg_cfg: &config::Postgres, whitelist_mode: bool) -> Result<Self> {
let mut cfg = postgres::Config::new();
cfg.user(&pg_cfg.user)
.host(&*pg_cfg.host.to_string())
@ -26,19 +29,20 @@ impl PgPool {
cfg.password(password);
};
cfg.connect(postgres::NoTls)?; // Test connection, letting us immediately exit with an error
// when Postgres isn't running instead of timing out below
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
let pool = r2d2::Pool::builder()
.max_size(10)
.build(manager)
.expect("Can connect to local postgres");
Self {
let pool = r2d2::Pool::builder().max_size(10).build(manager)?;
Ok(Self {
conn: pool,
whitelist_mode,
}
})
}
pub fn select_user(self, token: &Option<String>) -> Result<UserData, Rejection> {
let mut conn = self.conn.get().unwrap();
pub fn select_user(self, token: &Option<String>) -> Rejectable<UserData> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
if let Some(token) = token {
let query_rows = conn
.query("
@ -47,9 +51,9 @@ SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_lan
INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
&[&token.to_owned()],
).map_err(warp::reject::custom)?;
if let Some(result_columns) = query_rows.get(0) {
let id = Id(result_columns.get(1));
let allowed_langs = result_columns
@ -85,10 +89,10 @@ LIMIT 1",
}
}
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> {
let mut conn = self.conn.get().expect("TODO");
pub fn select_hashtag_id(self, tag_name: &str) -> Rejectable<i64> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
.expect("Hard-coded query will return Some([0 or more rows])")
.map_err(warp::reject::custom)?
.get(0)
.map(|row| row.get(0))
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
@ -98,31 +102,31 @@ LIMIT 1",
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
let mut conn = self.conn.get().expect("TODO");
pub fn select_blocked_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query(
"SELECT target_account_id FROM blocks WHERE account_id = $1
UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.map_err(warp::reject::custom)?
.iter()
.map(|row| Id(row.get(0)))
.map(|row| Ok(Id(row.get(0))))
.collect()
}
/// Query Postgres for everyone who has blocked the user
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
let mut conn = self.conn.get().expect("TODO");
pub fn select_blocking_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query(
"SELECT account_id FROM blocks WHERE target_account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.map_err(warp::reject::custom)?
.iter()
.map(|row| Id(row.get(0)))
.map(|row| Ok(Id(row.get(0))))
.collect()
}
@ -130,28 +134,28 @@ LIMIT 1",
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
let mut conn = self.conn.get().expect("TODO");
pub fn select_blocked_domains(self, user_id: Id) -> Rejectable<HashSet<String>> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.map_err(warp::reject::custom)?
.iter()
.map(|row| row.get(0))
.map(|row| Ok(row.get(0)))
.collect()
}
/// Test whether a user owns a list
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
let mut conn = self.conn.get().expect("TODO");
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> Rejectable<bool> {
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id)
.map_err(warp::reject::custom)?;
Ok(rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id))
}
}

View File

@ -54,7 +54,7 @@ impl Subscription {
let tag = pool.select_hashtag_id(&q.hashtag)?;
Timeline(Hashtag(tag), reach, stream)
}
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => {
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id)? => {
Err(warp::reject::custom("Error: Missing access token"))?
}
other_tl => other_tl,
@ -70,9 +70,9 @@ impl Subscription {
timeline,
allowed_langs: user.allowed_langs,
blocks: Blocks {
blocking_users: pool.clone().select_blocking_users(user.id),
blocked_users: pool.clone().select_blocked_users(user.id),
blocked_domains: pool.select_blocked_domains(user.id),
blocking_users: pool.clone().select_blocking_users(user.id)?,
blocked_users: pool.clone().select_blocked_users(user.id)?,
blocked_domains: pool.select_blocked_domains(user.id)?,
},
hashtag_name,
access_token: q.access_token,

View File

@ -19,25 +19,29 @@ impl Timeline {
}
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String> {
use {Content::*, Reach::*, Stream::*};
// TODO -- does this need to account for namespaces?
use {Content::*, Reach::*, Stream::*, TimelineErr::*};
Ok(match self {
Timeline(Public, Federated, All) => "timeline:public".into(),
Timeline(Public, Local, All) => "timeline:public:local".into(),
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
// TODO -- would `.push_str` be faster here?
Timeline(Hashtag(_id), Federated, All) => format!(
"timeline:hashtag:{}",
hashtag.ok_or(TimelineErr::MissingHashtag)?
),
Timeline(Hashtag(_id), Local, All) => format!(
"timeline:hashtag:{}:local",
hashtag.ok_or(TimelineErr::MissingHashtag)?
),
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
Timeline(Public, Federated, All) => "timeline:public".to_string(),
Timeline(Public, Local, All) => "timeline:public:local".to_string(),
Timeline(Public, Federated, Media) => "timeline:public:media".to_string(),
Timeline(Public, Local, Media) => "timeline:public:local:media".to_string(),
Timeline(Hashtag(_id), Federated, All) => {
["timeline:hashtag:", hashtag.ok_or(MissingHashtag)?].concat()
}
Timeline(Hashtag(_id), Local, All) => [
"timeline:hashtag:",
hashtag.ok_or(MissingHashtag)?,
":local",
]
.concat(),
Timeline(User(id), Federated, All) => ["timeline:", &id.to_string()].concat(),
Timeline(User(id), Federated, Notification) => {
["timeline:", &id.to_string(), ":notification"].concat()
}
Timeline(List(id), Federated, All) => ["timeline:list:", &id.to_string()].concat(),
Timeline(Direct(id), Federated, All) => ["timeline:direct:", &id.to_string()].concat(),
Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
})
}
@ -57,8 +61,7 @@ impl Timeline {
[id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
["list", id] => Timeline(List(id.parse()?), Federated, All),
["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
// Other endpoints don't exist:
[..] => Err(InvalidInput)?,
[..] => Err(InvalidInput)?, // Other endpoints don't exist
})
}

View File

@ -12,15 +12,43 @@ pub enum RedisCmd {
}
impl RedisCmd {
pub fn into_sendable(&self, tl: &String) -> (Vec<u8>, Vec<u8>) {
pub fn into_sendable(self, tl: &str) -> (Vec<u8>, Vec<u8>) {
match self {
RedisCmd::Subscribe => (
format!("*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(),
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n1\r\n", tl.len(), tl).into_bytes(),
[
b"*2\r\n$9\r\nsubscribe\r\n$",
tl.len().to_string().as_bytes(),
b"\r\n",
tl.as_bytes(),
b"\r\n",
]
.concat(),
[
b"*3\r\n$3\r\nSET\r\n$",
tl.len().to_string().as_bytes(),
b"\r\n",
tl.as_bytes(),
b"\r\n$1\r\n1\r\n",
]
.concat(),
),
RedisCmd::Unsubscribe => (
format!("*2\r\n$11\r\nunsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(),
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n0\r\n", tl.len(), tl).into_bytes(),
[
b"*2\r\n$11\r\nunsubscribe\r\n$",
tl.len().to_string().as_bytes(),
b"\r\n",
tl.as_bytes(),
b"\r\n",
]
.concat(),
[
b"*3\r\n$3\r\nSET\r\n$",
tl.len().to_string().as_bytes(),
b"\r\n",
tl.as_bytes(),
b"\r\n$1\r\n0\r\n",
]
.concat(),
),
}
}

View File

@ -28,8 +28,9 @@ pub struct RedisConn {
}
impl RedisConn {
pub fn new(redis_cfg: Redis) -> Result<Self> {
let addr = format!("{}:{}", *redis_cfg.host, *redis_cfg.port);
pub fn new(redis_cfg: &Redis) -> Result<Self> {
let addr = [&*redis_cfg.host, ":", &*redis_cfg.port.to_string()].concat();
let conn = Self::new_connection(&addr, redis_cfg.password.as_ref())?;
conn.set_nonblocking(true)
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
@ -49,7 +50,7 @@ impl RedisConn {
pub fn poll_redis(&mut self) -> Poll<Option<(Timeline, Event)>, ManagerErr> {
let mut size = 100; // large enough to handle subscribe/unsubscribe notice
let (mut buffer, mut first_read) = (vec![0u8; size], true);
let (mut buffer, mut first_read) = (vec![0_u8; size], true);
loop {
match self.primary.read(&mut buffer) {
Ok(n) if n != size => break self.redis_input.extend_from_slice(&buffer[..n]),
@ -81,7 +82,7 @@ impl RedisConn {
use {Async::*, RedisParseOutput::*};
let (res, leftover) = match RedisParseOutput::try_from(input) {
Ok(Msg(msg)) => match &self.redis_namespace {
Some(ns) if msg.timeline_txt.starts_with(&format!("{}:timeline:", ns)) => {
Some(ns) if msg.timeline_txt.starts_with(&[ns, ":timeline:"].concat()) => {
let trimmed_tl = &msg.timeline_txt[ns.len() + ":timeline:".len()..];
let tl = Timeline::from_redis_text(trimmed_tl, &mut self.tag_id_cache)?;
let event = msg.event_txt.try_into()?;
@ -135,8 +136,17 @@ impl RedisConn {
}
fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> {
conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes())
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
conn.write_all(
&[
b"*2\r\n$4\r\nauth\r\n$",
pass.len().to_string().as_bytes(),
b"\r\n",
pass.as_bytes(),
b"\r\n",
]
.concat(),
)
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
let mut buffer = vec![0_u8; 5];
conn.read_exact(&mut buffer)
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;

View File

@ -31,7 +31,7 @@ impl Manager {
/// Create a new `Manager`, with its own Redis connections (but, as yet, no
/// active subscriptions).
pub fn try_from(
redis_cfg: config::Redis,
redis_cfg: &config::Redis,
tx: watch::Sender<(Timeline, Event)>,
rx: mpsc::UnboundedReceiver<Timeline>,
) -> Result<Self> {
@ -99,11 +99,10 @@ impl Manager {
self.tx.broadcast((Timeline::empty(), Event::Ping))?
} else {
match self.redis_connection.poll_redis() {
Ok(Async::NotReady) => (),
Ok(Async::NotReady) | Ok(Async::Ready(None)) => (), // None = cmd or msg for other namespace
Ok(Async::Ready(Some((timeline, event)))) => {
self.tx.broadcast((timeline, event))?
}
Ok(Async::Ready(None)) => (), // subscription cmd or msg for other namespace
Err(err) => log::error!("{}", err), // drop msg, log err, and proceed
}
}

View File

@ -82,10 +82,7 @@ fn utf8_to_redis_data<'a>(s: &'a str) -> Result<(RedisData, &'a str), RedisParse
":" => parse_redis_int(s),
"$" => parse_redis_bulk_string(s),
"*" => parse_redis_array(s),
e => Err(InvalidLineStart(format!(
"Encountered invalid initial character `{}` in line `{}`",
e, s
))),
e => Err(InvalidLineStart(e.to_string())),
}
}

View File

@ -6,18 +6,11 @@ use log;
use std::time::Duration;
use tokio::sync::{mpsc, watch};
use warp::reply::Reply;
use warp::sse::{ServerSentEvent, Sse as WarpSse};
use warp::sse::Sse as WarpSse;
pub struct Sse;
impl Sse {
fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
Some((
warp::sse::event(event.event_name()),
warp::sse::data(event.payload().unwrap_or_else(String::new)),
))
}
pub fn send_events(
sse: WarpSse,
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
@ -40,21 +33,20 @@ impl Sse {
TypeSafe(Update { payload, queued_at }) => match timeline {
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
_ if payload.involves_any(&blocks) => None,
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
payload,
queued_at,
})),
_ => Event::TypeSafe(CheckedEvent::Update { payload, queued_at })
.to_warp_reply(),
},
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
TypeSafe(non_update) => Event::TypeSafe(non_update).to_warp_reply(),
Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind {
match timeline {
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
_ if s.involves_any(&blocks) => None,
_ => Self::reply_with(Dynamic(DynEvent {
_ => Dynamic(DynEvent {
kind: EventKind::Update(s),
..dyn_event
})),
})
.to_warp_reply(),
}
} else {
None

View File

@ -39,9 +39,11 @@ impl Ws {
.map_err(|_| -> warp::Error { unreachable!() })
.forward(transmit_to_ws)
.map(|_r| ())
.map_err(|e| match e.to_string().as_ref() {
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
_ => log::warn!("WebSocket send error: {}", e),
.map_err(|e| {
match e.to_string().as_ref() {
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
_ => log::warn!("WebSocket send error: {}", e),
}
}),
);
@ -50,7 +52,7 @@ impl Ws {
incoming_events.for_each(move |(tl, event)| {
if matches!(event, Event::Ping) {
self.send_ping()
self.send_msg(&event)
} else if target_timeline == tl {
use crate::event::{CheckedEvent::Update, Event::*, EventKind};
use crate::request::Stream::Public;
@ -61,18 +63,18 @@ impl Ws {
TypeSafe(Update { payload, queued_at }) => match tl {
Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()),
_ if payload.involves_any(&blocks) => Ok(()),
_ => self.send_msg(TypeSafe(Update { payload, queued_at })),
_ => self.send_msg(&TypeSafe(Update { payload, queued_at })),
},
TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)),
TypeSafe(non_update) => self.send_msg(&TypeSafe(non_update)),
Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind.clone() {
match tl {
Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()),
_ if s.involves_any(&blocks) => Ok(()),
_ => self.send_msg(Dynamic(dyn_event)),
_ => self.send_msg(&Dynamic(dyn_event)),
}
} else {
self.send_msg(Dynamic(dyn_event))
self.send_msg(&Dynamic(dyn_event))
}
}
Ping => unreachable!(), // handled pings above
@ -83,24 +85,14 @@ impl Ws {
})
}
fn send_ping(&mut self) -> Result<(), ()> {
self.send_txt("{}")
}
fn send_msg(&mut self, event: Event) -> Result<(), ()> {
self.send_txt(&event.to_json_string())
}
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
fn send_msg(&mut self, event: &Event) -> Result<(), ()> {
let txt = &event.to_json_string();
let tl = self.subscription.timeline;
match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) {
Ok(_) => Ok(()),
Err(_) => {
self.unsubscribe_tx
.try_send(tl)
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
Err(())
}
}
let mut channel = self.ws_tx.clone().ok_or(())?;
channel.try_send(Message::text(txt)).map_err(|_| {
self.unsubscribe_tx
.try_send(tl)
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
})
}
}