Error handling, pt3 (#131)

* Improve handling of Postgres errors

* Finish error handling improvements

* Remove `format!` calls from hot path
This commit is contained in:
Daniel Sockwell 2020-04-14 20:37:49 -04:00 committed by GitHub
parent 45f9d4b9fb
commit 37b652ad79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 372 additions and 257 deletions

View File

2
Cargo.lock generated
View File

@ -406,7 +406,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]] [[package]]
name = "flodgatt" name = "flodgatt"
version = "0.8.3" version = "0.8.4"
dependencies = [ dependencies = [
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@ -1,7 +1,7 @@
[package] [package]
name = "flodgatt" name = "flodgatt"
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
version = "0.8.3" version = "0.8.4"
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"] authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
edition = "2018" edition = "2018"

View File

@ -1,7 +1,7 @@
pub use {deployment_cfg::Deployment, postgres_cfg::Postgres, redis_cfg::Redis}; pub use {deployment_cfg::Deployment, postgres_cfg::Postgres, redis_cfg::Redis};
use self::environmental_variables::EnvVar; use self::environmental_variables::EnvVar;
use super::err; use super::err::FatalErr;
use hashbrown::HashMap; use hashbrown::HashMap;
use std::env; use std::env;
@ -13,22 +13,45 @@ mod postgres_cfg_types;
mod redis_cfg; mod redis_cfg;
mod redis_cfg_types; mod redis_cfg_types;
pub fn merge_dotenv() -> Result<(), err::FatalErr> { type Result<T> = std::result::Result<T, FatalErr>;
// TODO -- should this allow the user to run in a dir without a `.env` file?
dotenv::from_filename(match env::var("ENV").ok().as_deref() { pub fn merge_dotenv() -> Result<()> {
let env_file = match env::var("ENV").ok().as_deref() {
Some("production") => ".env.production", Some("production") => ".env.production",
Some("development") | None => ".env", Some("development") | None => ".env",
Some(_unsupported) => Err(err::FatalErr::Unknown)?, // TODO make more specific Some(v) => Err(FatalErr::config("ENV", v, "`production` or `development`"))?,
})?; };
let res = dotenv::from_filename(env_file);
if let Ok(log_level) = env::var("RUST_LOG") {
if res.is_err() && ["warn", "info", "trace", "debug"].contains(&log_level.as_str()) {
eprintln!(
" WARN: could not load environmental variables from {:?}\n\
{:8}Are you in the right directory? Proceeding with variables from the environment.",
env::current_dir().unwrap_or_else(|_|"./".into()).join(env_file), ""
);
}
}
Ok(()) Ok(())
} }
pub fn from_env<'a>(env_vars: HashMap<String, String>) -> (Postgres, Redis, Deployment<'a>) { #[allow(clippy::implicit_hasher)]
pub fn from_env<'a>(
env_vars: HashMap<String, String>,
) -> Result<(Postgres, Redis, Deployment<'a>)> {
let env_vars = EnvVar::new(env_vars); let env_vars = EnvVar::new(env_vars);
log::info!("Environmental variables Flodgatt received: {}", &env_vars); log::info!(
( "Flodgatt received the following environmental variables:{}",
Postgres::from_env(env_vars.clone()), &env_vars
Redis::from_env(env_vars.clone()), );
Deployment::from_env(env_vars.clone()),
) let pg_cfg = Postgres::from_env(env_vars.clone())?;
log::info!("Configuration for {:#?}", &pg_cfg);
let redis_cfg = Redis::from_env(env_vars.clone())?;
log::info!("Configuration for {:#?},", &redis_cfg);
let deployment_cfg = Deployment::from_env(&env_vars)?;
log::info!("Configuration for {:#?}", &deployment_cfg);
Ok((pg_cfg, redis_cfg, deployment_cfg))
} }

View File

@ -1,4 +1,5 @@
use super::{deployment_cfg_types::*, EnvVar}; use super::{deployment_cfg_types::*, EnvVar};
use crate::err::FatalErr;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Deployment<'a> { pub struct Deployment<'a> {
@ -14,20 +15,19 @@ pub struct Deployment<'a> {
} }
impl Deployment<'_> { impl Deployment<'_> {
pub fn from_env(env: EnvVar) -> Self { pub fn from_env(env: &EnvVar) -> Result<Self, FatalErr> {
let mut cfg = Self { let mut cfg = Self {
env: Env::default().maybe_update(env.get("NODE_ENV")), env: Env::default().maybe_update(env.get("NODE_ENV"))?,
log_level: LogLevel::default().maybe_update(env.get("RUST_LOG")), log_level: LogLevel::default().maybe_update(env.get("RUST_LOG"))?,
address: FlodgattAddr::default().maybe_update(env.get("BIND")), address: FlodgattAddr::default().maybe_update(env.get("BIND"))?,
port: Port::default().maybe_update(env.get("PORT")), port: Port::default().maybe_update(env.get("PORT"))?,
unix_socket: Socket::default().maybe_update(env.get("SOCKET")), unix_socket: Socket::default().maybe_update(env.get("SOCKET"))?,
sse_interval: SseInterval::default().maybe_update(env.get("SSE_FREQ")), sse_interval: SseInterval::default().maybe_update(env.get("SSE_FREQ"))?,
ws_interval: WsInterval::default().maybe_update(env.get("WS_FREQ")), ws_interval: WsInterval::default().maybe_update(env.get("WS_FREQ"))?,
whitelist_mode: WhitelistMode::default().maybe_update(env.get("WHITELIST_MODE")), whitelist_mode: WhitelistMode::default().maybe_update(env.get("WHITELIST_MODE"))?,
cors: Cors::default(), cors: Cors::default(),
}; };
cfg.env = cfg.env.maybe_update(env.get("RUST_ENV")); cfg.env = cfg.env.maybe_update(env.get("RUST_ENV"))?;
log::info!("Using deployment configuration:\n {:#?}", &cfg); Ok(cfg)
cfg
} }
} }

View File

@ -17,7 +17,7 @@ from_env_var!(
from_env_var!( from_env_var!(
/// The address to run Flodgatt on /// The address to run Flodgatt on
let name = FlodgattAddr; let name = FlodgattAddr;
let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded")); let default: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)"); let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)");
let from_str = |s| match s { let from_str = |s| match s {
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), "localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),

View File

@ -25,17 +25,6 @@ impl EnvVar {
self.0.insert(key.to_string(), value.to_string()); self.0.insert(key.to_string(), value.to_string());
} }
} }
pub fn err(env_var: &str, supplied_value: &str, allowed_values: &str) -> ! {
log::error!(
r"{var} is set to `{value}`, which is invalid.
{var} must be {allowed_vals}.",
var = env_var,
value = supplied_value,
allowed_vals = allowed_values
);
std::process::exit(1);
}
} }
impl fmt::Display for EnvVar { impl fmt::Display for EnvVar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@ -98,7 +87,7 @@ macro_rules! from_env_var {
#[derive(Clone)] #[derive(Clone)]
pub struct $name(pub $type); pub struct $name(pub $type);
impl std::fmt::Debug for $name { impl std::fmt::Debug for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self.0) write!(f, "{:?}", self.0)
} }
} }
@ -117,14 +106,14 @@ macro_rules! from_env_var {
fn inner_from_str($arg: &str) -> Option<$type> { fn inner_from_str($arg: &str) -> Option<$type> {
$body $body
} }
pub fn maybe_update(self, var: Option<&String>) -> Self { pub fn maybe_update(self, var: Option<&String>) -> Result<Self, crate::err::FatalErr> {
match var { Ok(match var {
Some(empty_string) if empty_string.is_empty() => Self::default(), Some(empty_string) if empty_string.is_empty() => Self::default(),
Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| { Some(value) => Self(Self::inner_from_str(value).ok_or_else(|| {
crate::config::EnvVar::err($env_var, value, $allowed_values) crate::err::FatalErr::config($env_var, value, $allowed_values)
})), })?),
None => self, None => self,
} })
} }
} }
}; };

View File

@ -1,7 +1,11 @@
use super::{postgres_cfg_types::*, EnvVar}; use super::{postgres_cfg_types::*, EnvVar};
use crate::err::FatalErr;
use url::Url; use url::Url;
use urlencoding; use urlencoding;
type Result<T> = std::result::Result<T, FatalErr>;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Postgres { pub struct Postgres {
pub user: PgUser, pub user: PgUser,
@ -13,8 +17,8 @@ pub struct Postgres {
} }
impl EnvVar { impl EnvVar {
fn update_with_postgres_url(mut self, url_str: &str) -> Self { fn update_with_postgres_url(mut self, url_str: &str) -> Result<Self> {
let url = Url::parse(url_str).unwrap(); let url = Url::parse(url_str)?;
let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) }; let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) };
for (k, v) in url.query_pairs().into_owned() { for (k, v) in url.query_pairs().into_owned() {
@ -23,11 +27,11 @@ impl EnvVar {
"password" => self.maybe_add_env_var("DB_PASS", Some(v.to_string())), "password" => self.maybe_add_env_var("DB_PASS", Some(v.to_string())),
"host" => self.maybe_add_env_var("DB_HOST", Some(v.to_string())), "host" => self.maybe_add_env_var("DB_HOST", Some(v.to_string())),
"sslmode" => self.maybe_add_env_var("DB_SSLMODE", Some(v.to_string())), "sslmode" => self.maybe_add_env_var("DB_SSLMODE", Some(v.to_string())),
_ => crate::err::die_with_msg(format!( _ => Err(FatalErr::config(
r"Unsupported parameter {} in POSTGRES_URL "POSTGRES_URL",
Flodgatt supports only `password`, `user`, `host`, and `sslmode` parameters", &k,
k "a URL with parameters `password`, `user`, `host`, and `sslmode` only",
)), ))?,
} }
} }
@ -35,42 +39,39 @@ impl EnvVar {
self.maybe_add_env_var("DB_PASS", url.password()); self.maybe_add_env_var("DB_PASS", url.password());
self.maybe_add_env_var( self.maybe_add_env_var(
"DB_HOST", "DB_HOST",
url.host().map(|h| { url.host()
urlencoding::decode(&h.to_string()).expect("Non-Unicode text in hostname") .map(|host| urlencoding::decode(&host.to_string()))
}), .transpose()?,
); );
self.maybe_add_env_var("DB_USER", none_if_empty(url.username().to_string())); self.maybe_add_env_var("DB_USER", none_if_empty(url.username().to_string()));
self.maybe_add_env_var("DB_NAME", none_if_empty(url.path()[1..].to_string())); self.maybe_add_env_var("DB_NAME", none_if_empty(url.path()[1..].to_string()));
Ok(self)
self
} }
} }
impl Postgres { impl Postgres {
/// Configure Postgres and return a connection /// Configure Postgres and return a connection
pub fn from_env(env: EnvVar) -> Result<Self> {
pub fn from_env(env: EnvVar) -> Self {
let env = match env.get("DATABASE_URL").cloned() { let env = match env.get("DATABASE_URL").cloned() {
Some(url_str) => env.update_with_postgres_url(&url_str), Some(url_str) => env.update_with_postgres_url(&url_str)?,
None => env, None => env,
}; };
let cfg = Self { let cfg = Self {
user: PgUser::default().maybe_update(env.get("DB_USER")), user: PgUser::default().maybe_update(env.get("DB_USER"))?,
host: PgHost::default().maybe_update(env.get("DB_HOST")), host: PgHost::default().maybe_update(env.get("DB_HOST"))?,
password: PgPass::default().maybe_update(env.get("DB_PASS")), password: PgPass::default().maybe_update(env.get("DB_PASS"))?,
database: PgDatabase::default().maybe_update(env.get("DB_NAME")), database: PgDatabase::default().maybe_update(env.get("DB_NAME"))?,
port: PgPort::default().maybe_update(env.get("DB_PORT")), port: PgPort::default().maybe_update(env.get("DB_PORT"))?,
ssl_mode: PgSslMode::default().maybe_update(env.get("DB_SSLMODE")), ssl_mode: PgSslMode::default().maybe_update(env.get("DB_SSLMODE"))?,
}; };
log::info!("Postgres configuration:\n{:#?}", &cfg); Ok(cfg)
cfg
} }
// // use openssl::ssl::{SslConnector, SslMethod}; // // use openssl::ssl::{SslConnector, SslMethod};
// // use postgres_openssl::MakeTlsConnector; // // use postgres_openssl::MakeTlsConnector;
// // let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); // // let mut builder = SslConnector::builder(SslMethod::tls())?;
// // builder.set_ca_file("/etc/ssl/cert.pem").unwrap(); // // builder.set_ca_file("/etc/ssl/cert.pem")?;
// // let connector = MakeTlsConnector::new(builder.build()); // // let connector = MakeTlsConnector::new(builder.build());
// // TODO: add TLS support, remove `NoTls` // // TODO: add TLS support, remove `NoTls`
} }

View File

@ -1,7 +1,11 @@
use super::redis_cfg_types::*; use super::redis_cfg_types::*;
use crate::config::EnvVar; use super::EnvVar;
use crate::err::FatalErr;
use url::Url; use url::Url;
type Result<T> = std::result::Result<T, FatalErr>;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Redis { pub struct Redis {
pub user: RedisUser, pub user: RedisUser,
@ -17,8 +21,8 @@ pub struct Redis {
} }
impl EnvVar { impl EnvVar {
fn update_with_redis_url(mut self, url_str: &str) -> Self { fn update_with_redis_url(mut self, url_str: &str) -> Result<Self> {
let url = Url::parse(url_str).unwrap(); let url = Url::parse(url_str)?;
let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) }; let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) };
self.maybe_add_env_var("REDIS_PORT", url.port()); self.maybe_add_env_var("REDIS_PORT", url.port());
@ -29,14 +33,14 @@ impl EnvVar {
match k.to_string().as_str() { match k.to_string().as_str() {
"password" => self.maybe_add_env_var("REDIS_PASSWORD", Some(v.to_string())), "password" => self.maybe_add_env_var("REDIS_PASSWORD", Some(v.to_string())),
"db" => self.maybe_add_env_var("REDIS_DB", Some(v.to_string())), "db" => self.maybe_add_env_var("REDIS_DB", Some(v.to_string())),
_ => crate::err::die_with_msg(format!( _ => Err(FatalErr::config(
r"Unsupported parameter {} in REDIS_URL. "REDIS_URL",
Flodgatt supports only `password` and `db` parameters.", &k,
k "a URL with parameters `password`, `db`, only",
)), ))?,
} }
} }
self Ok(self)
} }
} }
@ -46,20 +50,20 @@ impl Redis {
const DB_SET_WARNING: &'static str = r"Redis database specified, but PubSub connections do not use databases. const DB_SET_WARNING: &'static str = r"Redis database specified, but PubSub connections do not use databases.
For similar functionality, you may wish to set a REDIS_NAMESPACE"; For similar functionality, you may wish to set a REDIS_NAMESPACE";
pub fn from_env(env: EnvVar) -> Self { pub fn from_env(env: EnvVar) -> Result<Self> {
let env = match env.get("REDIS_URL").cloned() { let env = match env.get("REDIS_URL").cloned() {
Some(url_str) => env.update_with_redis_url(&url_str), Some(url_str) => env.update_with_redis_url(&url_str)?,
None => env, None => env,
}; };
let cfg = Redis { let cfg = Redis {
user: RedisUser::default().maybe_update(env.get("REDIS_USER")), user: RedisUser::default().maybe_update(env.get("REDIS_USER"))?,
password: RedisPass::default().maybe_update(env.get("REDIS_PASSWORD")), password: RedisPass::default().maybe_update(env.get("REDIS_PASSWORD"))?,
port: RedisPort::default().maybe_update(env.get("REDIS_PORT")), port: RedisPort::default().maybe_update(env.get("REDIS_PORT"))?,
host: RedisHost::default().maybe_update(env.get("REDIS_HOST")), host: RedisHost::default().maybe_update(env.get("REDIS_HOST"))?,
db: RedisDb::default().maybe_update(env.get("REDIS_DB")), db: RedisDb::default().maybe_update(env.get("REDIS_DB"))?,
namespace: RedisNamespace::default().maybe_update(env.get("REDIS_NAMESPACE")), namespace: RedisNamespace::default().maybe_update(env.get("REDIS_NAMESPACE"))?,
polling_interval: RedisInterval::default().maybe_update(env.get("REDIS_FREQ")), polling_interval: RedisInterval::default().maybe_update(env.get("REDIS_FREQ"))?,
}; };
if cfg.db.is_some() { if cfg.db.is_some() {
@ -68,7 +72,6 @@ For similar functionality, you may wish to set a REDIS_NAMESPACE";
if cfg.user.is_some() { if cfg.user.is_some() {
log::warn!("{}", Self::USER_SET_WARNING); log::warn!("{}", Self::USER_SET_WARNING);
} }
log::info!("Redis configuration:\n{:#?},", &cfg); Ok(cfg)
cfg
} }
} }

View File

@ -1,17 +1,29 @@
use crate::request::RequestErr;
use crate::response::ManagerErr; use crate::response::ManagerErr;
use std::fmt; use std::fmt;
pub enum FatalErr { pub enum FatalErr {
Unknown,
ReceiverErr(ManagerErr), ReceiverErr(ManagerErr),
DotEnv(dotenv::Error),
Logger(log::SetLoggerError), Logger(log::SetLoggerError),
Postgres(RequestErr),
Unrecoverable,
StdIo(std::io::Error),
// config errs
UrlParse(url::ParseError),
UrlEncoding(urlencoding::FromUrlEncodingError),
ConfigErr(String),
} }
impl FatalErr { impl FatalErr {
pub fn exit(msg: impl fmt::Display) { pub fn log(msg: impl fmt::Display) {
eprintln!("{}", msg); eprintln!("{}", msg);
std::process::exit(1); }
pub fn config<T: fmt::Display>(var: T, value: T, allowed_vals: T) -> Self {
Self::ConfigErr(format!(
"{0} is set to `{1}`, which is invalid.\n{3:7}{0} must be {2}.",
var, value, allowed_vals, ""
))
} }
} }
@ -29,18 +41,22 @@ impl fmt::Display for FatalErr {
f, f,
"{}", "{}",
match self { match self {
Unknown => "Flodgatt encountered an unknown, unrecoverable error".into(),
ReceiverErr(e) => format!("{}", e), ReceiverErr(e) => format!("{}", e),
Logger(e) => format!("{}", e), Logger(e) => format!("{}", e),
DotEnv(e) => format!("Could not load specified environmental file: {}", e), StdIo(e) => format!("{}", e),
Postgres(e) => format!("could not connect to Postgres.\n{:7}{}", "", e),
ConfigErr(e) => e.to_string(),
UrlParse(e) => format!("could parse Postgres URL.\n{:7}{}", "", e),
UrlEncoding(e) => format!("could not parse POSTGRES_URL.\n{:7}{:?}", "", e),
Unrecoverable => "Flodgatt will now shut down.".into(),
} }
) )
} }
} }
impl From<dotenv::Error> for FatalErr { impl From<RequestErr> for FatalErr {
fn from(e: dotenv::Error) -> Self { fn from(e: RequestErr) -> Self {
Self::DotEnv(e) Self::Postgres(e)
} }
} }
@ -49,15 +65,23 @@ impl From<ManagerErr> for FatalErr {
Self::ReceiverErr(e) Self::ReceiverErr(e)
} }
} }
impl From<urlencoding::FromUrlEncodingError> for FatalErr {
fn from(e: urlencoding::FromUrlEncodingError) -> Self {
Self::UrlEncoding(e)
}
}
impl From<url::ParseError> for FatalErr {
fn from(e: url::ParseError) -> Self {
Self::UrlParse(e)
}
}
impl From<std::io::Error> for FatalErr {
fn from(e: std::io::Error) -> Self {
Self::StdIo(e)
}
}
impl From<log::SetLoggerError> for FatalErr { impl From<log::SetLoggerError> for FatalErr {
fn from(e: log::SetLoggerError) -> Self { fn from(e: log::SetLoggerError) -> Self {
Self::Logger(e) Self::Logger(e)
} }
} }
// TODO delete vvvv when postgres_cfg.rs has better error handling
pub fn die_with_msg(msg: impl fmt::Display) -> ! {
eprintln!("FATAL ERROR: {}", msg);
std::process::exit(1);
}

View File

@ -2,14 +2,14 @@ mod checked_event;
mod dynamic_event; mod dynamic_event;
mod err; mod err;
pub use { pub use checked_event::{CheckedEvent, Id};
checked_event::{CheckedEvent, Id}, pub use dynamic_event::{DynEvent, DynStatus, EventKind};
dynamic_event::{DynEvent, DynStatus, EventKind}, pub use err::EventErr;
err::EventErr,
};
use serde::Serialize; use serde::Serialize;
use std::{convert::TryFrom, string::String}; use std::convert::TryFrom;
use std::string::String;
use warp::sse::ServerSentEvent;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Event { pub enum Event {
@ -20,6 +20,9 @@ pub enum Event {
impl Event { impl Event {
pub fn to_json_string(&self) -> String { pub fn to_json_string(&self) -> String {
if let Event::Ping = self {
"{}".to_string()
} else {
let event = &self.event_name(); let event = &self.event_name();
let sendable_event = match self.payload() { let sendable_event = match self.payload() {
Some(payload) => SendableEvent::WithPayload { event, payload }, Some(payload) => SendableEvent::WithPayload { event, payload },
@ -27,8 +30,20 @@ impl Event {
}; };
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize") serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
} }
}
pub fn event_name(&self) -> String { pub fn to_warp_reply(&self) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
if let Event::Ping = self {
None
} else {
Some((
warp::sse::event(self.event_name()),
warp::sse::data(self.payload().unwrap_or_else(String::new)),
))
}
}
fn event_name(&self) -> String {
String::from(match self { String::from(match self {
Self::TypeSafe(checked) => match checked { Self::TypeSafe(checked) => match checked {
CheckedEvent::Update { .. } => "update", CheckedEvent::Update { .. } => "update",
@ -45,11 +60,11 @@ impl Event {
.. ..
}) => "update", }) => "update",
Self::Dynamic(DynEvent { event, .. }) => event, Self::Dynamic(DynEvent { event, .. }) => event,
Self::Ping => panic!("event_name() called on Ping"), Self::Ping => unreachable!(), // private method only called above
}) })
} }
pub fn payload(&self) -> Option<String> { fn payload(&self) -> Option<String> {
use CheckedEvent::*; use CheckedEvent::*;
match self { match self {
Self::TypeSafe(checked) => match checked { Self::TypeSafe(checked) => match checked {
@ -63,7 +78,7 @@ impl Event {
FiltersChanged => None, FiltersChanged => None,
}, },
Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()), Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()),
Self::Ping => panic!("payload() called on Ping"), Self::Ping => unreachable!(), // private method only called above
} }
} }
} }

View File

@ -43,7 +43,7 @@ type Result<T> = std::result::Result<T, EventErr>;
impl DynEvent { impl DynEvent {
pub fn set_update(self) -> Result<Self> { pub fn set_update(self) -> Result<Self> {
if self.event == "update" { if self.event == "update" {
let kind = EventKind::Update(DynStatus::new(self.payload.clone())?); let kind = EventKind::Update(DynStatus::new(&self.payload.clone())?);
Ok(Self { kind, ..self }) Ok(Self { kind, ..self })
} else { } else {
Ok(self) Ok(self)
@ -52,7 +52,7 @@ impl DynEvent {
} }
impl DynStatus { impl DynStatus {
pub fn new(payload: Value) -> Result<Self> { pub fn new(payload: &Value) -> Result<Self> {
use EventErr::*; use EventErr::*;
Ok(Self { Ok(Self {
@ -61,7 +61,7 @@ impl DynStatus {
.as_str() .as_str()
.ok_or(DynParse)? .ok_or(DynParse)?
.to_string(), .to_string(),
language: payload["language"].as_str().map(|s| s.to_string()), language: payload["language"].as_str().map(String::from),
mentioned_users: HashSet::new(), mentioned_users: HashSet::new(),
replied_to_user: Id::try_from(&payload["in_reply_to_account_id"]).ok(), replied_to_user: Id::try_from(&payload["in_reply_to_account_id"]).ok(),
boosted_user: Id::try_from(&payload["reblog"]["account"]["id"]).ok(), boosted_user: Id::try_from(&payload["reblog"]["account"]["id"]).ok(),

View File

@ -37,6 +37,7 @@
//#![warn(clippy::pedantic)] //#![warn(clippy::pedantic)]
#![allow(clippy::try_err, clippy::match_bool)] #![allow(clippy::try_err, clippy::match_bool)]
//#![allow(clippy::large_enum_variant)]
pub mod config; pub mod config;
pub mod err; pub mod err;

View File

@ -19,20 +19,19 @@ use warp::Filter;
fn main() -> Result<(), FatalErr> { fn main() -> Result<(), FatalErr> {
config::merge_dotenv()?; config::merge_dotenv()?;
pretty_env_logger::try_init()?; pretty_env_logger::try_init()?;
let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect()); let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect())?;
let poll_freq = *redis_cfg.polling_interval;
// Create channels to communicate between threads // Create channels to communicate between threads
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping)); let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
let request = Handler::new(postgres_cfg, *cfg.whitelist_mode); let request = Handler::new(&postgres_cfg, *cfg.whitelist_mode)?;
let poll_freq = *redis_cfg.polling_interval; let shared_manager = redis::Manager::try_from(&redis_cfg, event_tx, cmd_rx)?.into_arc();
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
// Server Sent Events // Server Sent Events
let sse_manager = shared_manager.clone(); let sse_manager = shared_manager.clone();
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone()); let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
let sse = request let sse = request
.sse_subscription() .sse_subscription()
.and(warp::sse()) .and(warp::sse())
@ -85,8 +84,7 @@ fn main() -> Result<(), FatalErr> {
.map_err(|e| log::error!("{}", e)) .map_err(|e| log::error!("{}", e))
.for_each(move |_| { .for_each(move |_| {
let mut manager = manager.lock().unwrap_or_else(redis::Manager::recover); let mut manager = manager.lock().unwrap_or_else(redis::Manager::recover);
manager.poll_broadcast().unwrap_or_else(FatalErr::exit); manager.poll_broadcast().map_err(FatalErr::log)
Ok(())
}); });
warp::spawn(lazy(move || stream)); warp::spawn(lazy(move || stream));
warp::serve(ws.or(sse).with(cors).or(status).recover(Handler::err)) warp::serve(ws.or(sse).with(cors).or(status).recover(Handler::err))
@ -95,13 +93,13 @@ fn main() -> Result<(), FatalErr> {
if let Some(socket) = &*cfg.unix_socket { if let Some(socket) = &*cfg.unix_socket {
log::info!("Using Unix socket {}", socket); log::info!("Using Unix socket {}", socket);
fs::remove_file(socket).unwrap_or_default(); fs::remove_file(socket).unwrap_or_default();
let incoming = UnixListener::bind(socket).expect("TODO").incoming(); let incoming = UnixListener::bind(socket)?.incoming();
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).expect("TODO"); fs::set_permissions(socket, PermissionsExt::from_mode(0o666))?;
tokio::run(lazy(|| streaming_server().serve_incoming(incoming))); tokio::run(lazy(|| streaming_server().serve_incoming(incoming)));
} else { } else {
let server_addr = SocketAddr::new(*cfg.address, *cfg.port); let server_addr = SocketAddr::new(*cfg.address, *cfg.port);
tokio::run(lazy(move || streaming_server().bind(server_addr))); tokio::run(lazy(move || streaming_server().bind(server_addr)));
} }
Ok(()) Err(FatalErr::Unrecoverable) // on get here if there's an unrecoverable error in poll_broadcast.
} }

View File

@ -3,8 +3,10 @@ mod postgres;
mod query; mod query;
pub mod timeline; pub mod timeline;
mod err;
mod subscription; mod subscription;
pub use self::err::RequestErr;
pub use self::postgres::PgPool; pub use self::postgres::PgPool;
// TODO consider whether we can remove `Stream` from public API // TODO consider whether we can remove `Stream` from public API
pub use subscription::{Blocks, Subscription}; pub use subscription::{Blocks, Subscription};
@ -22,6 +24,8 @@ mod sse_test;
#[cfg(test)] #[cfg(test)]
mod ws_test; mod ws_test;
type Result<T> = std::result::Result<T, err::RequestErr>;
/// Helper macro to match on the first of any of the provided filters /// Helper macro to match on the first of any of the provided filters
macro_rules! any_of { macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => { ($filter:expr, $($other_filter:expr),*) => {
@ -56,10 +60,10 @@ pub struct Handler {
} }
impl Handler { impl Handler {
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self { pub fn new(postgres_cfg: &config::Postgres, whitelist_mode: bool) -> Result<Self> {
Self { Ok(Self {
pg_conn: PgPool::new(postgres_cfg, whitelist_mode), pg_conn: PgPool::new(postgres_cfg, whitelist_mode)?,
} })
} }
pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> { pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> {
@ -113,7 +117,7 @@ impl Handler {
warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed() warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed()
} }
pub fn err(r: Rejection) -> Result<impl warp::Reply, warp::Rejection> { pub fn err(r: Rejection) -> std::result::Result<impl warp::Reply, warp::Rejection> {
let json_err = match r.cause() { let json_err = match r.cause() {
Some(text) if text.to_string() == "Missing request header 'authorization'" => { Some(text) if text.to_string() == "Missing request header 'authorization'" => {
warp::reply::json(&"Error: Missing access token".to_string()) warp::reply::json(&"Error: Missing access token".to_string())

32
src/request/err.rs Normal file
View File

@ -0,0 +1,32 @@
use std::fmt;
#[derive(Debug)]
pub enum RequestErr {
Unknown,
PgPool(r2d2::Error),
Pg(postgres::Error),
}
impl std::error::Error for RequestErr {}
impl fmt::Display for RequestErr {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
use RequestErr::*;
let msg = match self {
Unknown => "Encountered an unrecoverable error related to handling a request".into(),
PgPool(e) => format!("{}", e),
Pg(e) => format!("{}", e),
};
write!(f, "{}", msg)
}
}
impl From<r2d2::Error> for RequestErr {
fn from(e: r2d2::Error) -> Self {
Self::PgPool(e)
}
}
impl From<postgres::Error> for RequestErr {
fn from(e: postgres::Error) -> Self {
Self::Pg(e)
}
}

View File

@ -1,13 +1,13 @@
//! Postgres queries //! Postgres queries
use super::err;
use super::timeline::{Scope, UserData};
use crate::config; use crate::config;
use crate::event::Id; use crate::event::Id;
use crate::request::timeline::{Scope, UserData};
use ::postgres; use ::postgres;
use hashbrown::HashSet; use hashbrown::HashSet;
use r2d2_postgres::PostgresConnectionManager; use r2d2_postgres::PostgresConnectionManager;
use std::convert::TryFrom; use std::convert::TryFrom;
use warp::reject::Rejection;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PgPool { pub struct PgPool {
@ -15,8 +15,11 @@ pub struct PgPool {
whitelist_mode: bool, whitelist_mode: bool,
} }
type Result<T> = std::result::Result<T, err::RequestErr>;
type Rejectable<T> = std::result::Result<T, warp::Rejection>;
impl PgPool { impl PgPool {
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self { pub fn new(pg_cfg: &config::Postgres, whitelist_mode: bool) -> Result<Self> {
let mut cfg = postgres::Config::new(); let mut cfg = postgres::Config::new();
cfg.user(&pg_cfg.user) cfg.user(&pg_cfg.user)
.host(&*pg_cfg.host.to_string()) .host(&*pg_cfg.host.to_string())
@ -26,19 +29,20 @@ impl PgPool {
cfg.password(password); cfg.password(password);
}; };
cfg.connect(postgres::NoTls)?; // Test connection, letting us immediately exit with an error
// when Postgres isn't running instead of timing out below
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls); let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
let pool = r2d2::Pool::builder() let pool = r2d2::Pool::builder().max_size(10).build(manager)?;
.max_size(10)
.build(manager) Ok(Self {
.expect("Can connect to local postgres");
Self {
conn: pool, conn: pool,
whitelist_mode, whitelist_mode,
} })
} }
pub fn select_user(self, token: &Option<String>) -> Result<UserData, Rejection> { pub fn select_user(self, token: &Option<String>) -> Rejectable<UserData> {
let mut conn = self.conn.get().unwrap(); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
if let Some(token) = token { if let Some(token) = token {
let query_rows = conn let query_rows = conn
.query(" .query("
@ -48,8 +52,8 @@ INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1", LIMIT 1",
&[&token.to_owned()], &[&token.to_owned()],
) ).map_err(warp::reject::custom)?;
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) { if let Some(result_columns) = query_rows.get(0) {
let id = Id(result_columns.get(1)); let id = Id(result_columns.get(1));
let allowed_langs = result_columns let allowed_langs = result_columns
@ -85,10 +89,10 @@ LIMIT 1",
} }
} }
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> { pub fn select_hashtag_id(self, tag_name: &str) -> Rejectable<i64> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name]) conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
.expect("Hard-coded query will return Some([0 or more rows])") .map_err(warp::reject::custom)?
.get(0) .get(0)
.map(|row| row.get(0)) .map(|row| row.get(0))
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist.")) .ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
@ -98,31 +102,31 @@ LIMIT 1",
/// ///
/// **NOTE**: because we check this when the user connects, it will not include any blocks /// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect. /// the user adds until they refresh/reconnect.
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> { pub fn select_blocked_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query( conn.query(
"SELECT target_account_id FROM blocks WHERE account_id = $1 "SELECT target_account_id FROM blocks WHERE account_id = $1
UNION SELECT target_account_id FROM mutes WHERE account_id = $1", UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
&[&*user_id], &[&*user_id],
) )
.expect("Hard-coded query will return Some([0 or more rows])") .map_err(warp::reject::custom)?
.iter() .iter()
.map(|row| Id(row.get(0))) .map(|row| Ok(Id(row.get(0))))
.collect() .collect()
} }
/// Query Postgres for everyone who has blocked the user /// Query Postgres for everyone who has blocked the user
/// ///
/// **NOTE**: because we check this when the user connects, it will not include any blocks /// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect. /// the user adds until they refresh/reconnect.
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> { pub fn select_blocking_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query( conn.query(
"SELECT account_id FROM blocks WHERE target_account_id = $1", "SELECT account_id FROM blocks WHERE target_account_id = $1",
&[&*user_id], &[&*user_id],
) )
.expect("Hard-coded query will return Some([0 or more rows])") .map_err(warp::reject::custom)?
.iter() .iter()
.map(|row| Id(row.get(0))) .map(|row| Ok(Id(row.get(0))))
.collect() .collect()
} }
@ -130,28 +134,28 @@ LIMIT 1",
/// ///
/// **NOTE**: because we check this when the user connects, it will not include any blocks /// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect. /// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> { pub fn select_blocked_domains(self, user_id: Id) -> Rejectable<HashSet<String>> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
conn.query( conn.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1", "SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&*user_id], &[&*user_id],
) )
.expect("Hard-coded query will return Some([0 or more rows])") .map_err(warp::reject::custom)?
.iter() .iter()
.map(|row| row.get(0)) .map(|row| Ok(row.get(0)))
.collect() .collect()
} }
/// Test whether a user owns a list /// Test whether a user owns a list
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool { pub fn user_owns_list(self, user_id: Id, list_id: i64) -> Rejectable<bool> {
let mut conn = self.conn.get().expect("TODO"); let mut conn = self.conn.get().map_err(warp::reject::custom)?;
// For the Postgres query, `id` = list number; `account_id` = user.id // For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn let rows = &conn
.query( .query(
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", "SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_id], &[&list_id],
) )
.expect("Hard-coded query will return Some([0 or more rows])"); .map_err(warp::reject::custom)?;
rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id) Ok(rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id))
} }
} }

View File

@ -54,7 +54,7 @@ impl Subscription {
let tag = pool.select_hashtag_id(&q.hashtag)?; let tag = pool.select_hashtag_id(&q.hashtag)?;
Timeline(Hashtag(tag), reach, stream) Timeline(Hashtag(tag), reach, stream)
} }
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => { Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id)? => {
Err(warp::reject::custom("Error: Missing access token"))? Err(warp::reject::custom("Error: Missing access token"))?
} }
other_tl => other_tl, other_tl => other_tl,
@ -70,9 +70,9 @@ impl Subscription {
timeline, timeline,
allowed_langs: user.allowed_langs, allowed_langs: user.allowed_langs,
blocks: Blocks { blocks: Blocks {
blocking_users: pool.clone().select_blocking_users(user.id), blocking_users: pool.clone().select_blocking_users(user.id)?,
blocked_users: pool.clone().select_blocked_users(user.id), blocked_users: pool.clone().select_blocked_users(user.id)?,
blocked_domains: pool.select_blocked_domains(user.id), blocked_domains: pool.select_blocked_domains(user.id)?,
}, },
hashtag_name, hashtag_name,
access_token: q.access_token, access_token: q.access_token,

View File

@ -19,25 +19,29 @@ impl Timeline {
} }
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String> { pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String> {
use {Content::*, Reach::*, Stream::*}; // TODO -- does this need to account for namespaces?
use {Content::*, Reach::*, Stream::*, TimelineErr::*};
Ok(match self { Ok(match self {
Timeline(Public, Federated, All) => "timeline:public".into(), Timeline(Public, Federated, All) => "timeline:public".to_string(),
Timeline(Public, Local, All) => "timeline:public:local".into(), Timeline(Public, Local, All) => "timeline:public:local".to_string(),
Timeline(Public, Federated, Media) => "timeline:public:media".into(), Timeline(Public, Federated, Media) => "timeline:public:media".to_string(),
Timeline(Public, Local, Media) => "timeline:public:local:media".into(), Timeline(Public, Local, Media) => "timeline:public:local:media".to_string(),
// TODO -- would `.push_str` be faster here? Timeline(Hashtag(_id), Federated, All) => {
Timeline(Hashtag(_id), Federated, All) => format!( ["timeline:hashtag:", hashtag.ok_or(MissingHashtag)?].concat()
"timeline:hashtag:{}", }
hashtag.ok_or(TimelineErr::MissingHashtag)? Timeline(Hashtag(_id), Local, All) => [
), "timeline:hashtag:",
Timeline(Hashtag(_id), Local, All) => format!( hashtag.ok_or(MissingHashtag)?,
"timeline:hashtag:{}:local", ":local",
hashtag.ok_or(TimelineErr::MissingHashtag)? ]
), .concat(),
Timeline(User(id), Federated, All) => format!("timeline:{}", id), Timeline(User(id), Federated, All) => ["timeline:", &id.to_string()].concat(),
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id), Timeline(User(id), Federated, Notification) => {
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id), ["timeline:", &id.to_string(), ":notification"].concat()
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id), }
Timeline(List(id), Federated, All) => ["timeline:list:", &id.to_string()].concat(),
Timeline(Direct(id), Federated, All) => ["timeline:direct:", &id.to_string()].concat(),
Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?, Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
}) })
} }
@ -57,8 +61,7 @@ impl Timeline {
[id, "notification"] => Timeline(User(id.parse()?), Federated, Notification), [id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
["list", id] => Timeline(List(id.parse()?), Federated, All), ["list", id] => Timeline(List(id.parse()?), Federated, All),
["direct", id] => Timeline(Direct(id.parse()?), Federated, All), ["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
// Other endpoints don't exist: [..] => Err(InvalidInput)?, // Other endpoints don't exist
[..] => Err(InvalidInput)?,
}) })
} }

View File

@ -12,15 +12,43 @@ pub enum RedisCmd {
} }
impl RedisCmd { impl RedisCmd {
pub fn into_sendable(&self, tl: &String) -> (Vec<u8>, Vec<u8>) { pub fn into_sendable(self, tl: &str) -> (Vec<u8>, Vec<u8>) {
match self { match self {
RedisCmd::Subscribe => ( RedisCmd::Subscribe => (
format!("*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(), [
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n1\r\n", tl.len(), tl).into_bytes(), b"*2\r\n$9\r\nsubscribe\r\n$",
tl.len().to_string().as_bytes(),
b"\r\n",
tl.as_bytes(),
b"\r\n",
]
.concat(),
[
b"*3\r\n$3\r\nSET\r\n$",
tl.len().to_string().as_bytes(),
b"\r\n",
tl.as_bytes(),
b"\r\n$1\r\n1\r\n",
]
.concat(),
), ),
RedisCmd::Unsubscribe => ( RedisCmd::Unsubscribe => (
format!("*2\r\n$11\r\nunsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(), [
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n0\r\n", tl.len(), tl).into_bytes(), b"*2\r\n$11\r\nunsubscribe\r\n$",
tl.len().to_string().as_bytes(),
b"\r\n",
tl.as_bytes(),
b"\r\n",
]
.concat(),
[
b"*3\r\n$3\r\nSET\r\n$",
tl.len().to_string().as_bytes(),
b"\r\n",
tl.as_bytes(),
b"\r\n$1\r\n0\r\n",
]
.concat(),
), ),
} }
} }

View File

@ -28,8 +28,9 @@ pub struct RedisConn {
} }
impl RedisConn { impl RedisConn {
pub fn new(redis_cfg: Redis) -> Result<Self> { pub fn new(redis_cfg: &Redis) -> Result<Self> {
let addr = format!("{}:{}", *redis_cfg.host, *redis_cfg.port); let addr = [&*redis_cfg.host, ":", &*redis_cfg.port.to_string()].concat();
let conn = Self::new_connection(&addr, redis_cfg.password.as_ref())?; let conn = Self::new_connection(&addr, redis_cfg.password.as_ref())?;
conn.set_nonblocking(true) conn.set_nonblocking(true)
.map_err(|e| RedisConnErr::with_addr(&addr, e))?; .map_err(|e| RedisConnErr::with_addr(&addr, e))?;
@ -49,7 +50,7 @@ impl RedisConn {
pub fn poll_redis(&mut self) -> Poll<Option<(Timeline, Event)>, ManagerErr> { pub fn poll_redis(&mut self) -> Poll<Option<(Timeline, Event)>, ManagerErr> {
let mut size = 100; // large enough to handle subscribe/unsubscribe notice let mut size = 100; // large enough to handle subscribe/unsubscribe notice
let (mut buffer, mut first_read) = (vec![0u8; size], true); let (mut buffer, mut first_read) = (vec![0_u8; size], true);
loop { loop {
match self.primary.read(&mut buffer) { match self.primary.read(&mut buffer) {
Ok(n) if n != size => break self.redis_input.extend_from_slice(&buffer[..n]), Ok(n) if n != size => break self.redis_input.extend_from_slice(&buffer[..n]),
@ -81,7 +82,7 @@ impl RedisConn {
use {Async::*, RedisParseOutput::*}; use {Async::*, RedisParseOutput::*};
let (res, leftover) = match RedisParseOutput::try_from(input) { let (res, leftover) = match RedisParseOutput::try_from(input) {
Ok(Msg(msg)) => match &self.redis_namespace { Ok(Msg(msg)) => match &self.redis_namespace {
Some(ns) if msg.timeline_txt.starts_with(&format!("{}:timeline:", ns)) => { Some(ns) if msg.timeline_txt.starts_with(&[ns, ":timeline:"].concat()) => {
let trimmed_tl = &msg.timeline_txt[ns.len() + ":timeline:".len()..]; let trimmed_tl = &msg.timeline_txt[ns.len() + ":timeline:".len()..];
let tl = Timeline::from_redis_text(trimmed_tl, &mut self.tag_id_cache)?; let tl = Timeline::from_redis_text(trimmed_tl, &mut self.tag_id_cache)?;
let event = msg.event_txt.try_into()?; let event = msg.event_txt.try_into()?;
@ -135,7 +136,16 @@ impl RedisConn {
} }
fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> { fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> {
conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes()) conn.write_all(
&[
b"*2\r\n$4\r\nauth\r\n$",
pass.len().to_string().as_bytes(),
b"\r\n",
pass.as_bytes(),
b"\r\n",
]
.concat(),
)
.map_err(|e| RedisConnErr::with_addr(&addr, e))?; .map_err(|e| RedisConnErr::with_addr(&addr, e))?;
let mut buffer = vec![0_u8; 5]; let mut buffer = vec![0_u8; 5];
conn.read_exact(&mut buffer) conn.read_exact(&mut buffer)

View File

@ -31,7 +31,7 @@ impl Manager {
/// Create a new `Manager`, with its own Redis connections (but, as yet, no /// Create a new `Manager`, with its own Redis connections (but, as yet, no
/// active subscriptions). /// active subscriptions).
pub fn try_from( pub fn try_from(
redis_cfg: config::Redis, redis_cfg: &config::Redis,
tx: watch::Sender<(Timeline, Event)>, tx: watch::Sender<(Timeline, Event)>,
rx: mpsc::UnboundedReceiver<Timeline>, rx: mpsc::UnboundedReceiver<Timeline>,
) -> Result<Self> { ) -> Result<Self> {
@ -99,11 +99,10 @@ impl Manager {
self.tx.broadcast((Timeline::empty(), Event::Ping))? self.tx.broadcast((Timeline::empty(), Event::Ping))?
} else { } else {
match self.redis_connection.poll_redis() { match self.redis_connection.poll_redis() {
Ok(Async::NotReady) => (), Ok(Async::NotReady) | Ok(Async::Ready(None)) => (), // None = cmd or msg for other namespace
Ok(Async::Ready(Some((timeline, event)))) => { Ok(Async::Ready(Some((timeline, event)))) => {
self.tx.broadcast((timeline, event))? self.tx.broadcast((timeline, event))?
} }
Ok(Async::Ready(None)) => (), // subscription cmd or msg for other namespace
Err(err) => log::error!("{}", err), // drop msg, log err, and proceed Err(err) => log::error!("{}", err), // drop msg, log err, and proceed
} }
} }

View File

@ -82,10 +82,7 @@ fn utf8_to_redis_data<'a>(s: &'a str) -> Result<(RedisData, &'a str), RedisParse
":" => parse_redis_int(s), ":" => parse_redis_int(s),
"$" => parse_redis_bulk_string(s), "$" => parse_redis_bulk_string(s),
"*" => parse_redis_array(s), "*" => parse_redis_array(s),
e => Err(InvalidLineStart(format!( e => Err(InvalidLineStart(e.to_string())),
"Encountered invalid initial character `{}` in line `{}`",
e, s
))),
} }
} }

View File

@ -6,18 +6,11 @@ use log;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
use warp::reply::Reply; use warp::reply::Reply;
use warp::sse::{ServerSentEvent, Sse as WarpSse}; use warp::sse::Sse as WarpSse;
pub struct Sse; pub struct Sse;
impl Sse { impl Sse {
fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
Some((
warp::sse::event(event.event_name()),
warp::sse::data(event.payload().unwrap_or_else(String::new)),
))
}
pub fn send_events( pub fn send_events(
sse: WarpSse, sse: WarpSse,
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>, mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
@ -40,21 +33,20 @@ impl Sse {
TypeSafe(Update { payload, queued_at }) => match timeline { TypeSafe(Update { payload, queued_at }) => match timeline {
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None, Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
_ if payload.involves_any(&blocks) => None, _ if payload.involves_any(&blocks) => None,
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update { _ => Event::TypeSafe(CheckedEvent::Update { payload, queued_at })
payload, .to_warp_reply(),
queued_at,
})),
}, },
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)), TypeSafe(non_update) => Event::TypeSafe(non_update).to_warp_reply(),
Dynamic(dyn_event) => { Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind { if let EventKind::Update(s) = dyn_event.kind {
match timeline { match timeline {
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None, Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
_ if s.involves_any(&blocks) => None, _ if s.involves_any(&blocks) => None,
_ => Self::reply_with(Dynamic(DynEvent { _ => Dynamic(DynEvent {
kind: EventKind::Update(s), kind: EventKind::Update(s),
..dyn_event ..dyn_event
})), })
.to_warp_reply(),
} }
} else { } else {
None None

View File

@ -39,9 +39,11 @@ impl Ws {
.map_err(|_| -> warp::Error { unreachable!() }) .map_err(|_| -> warp::Error { unreachable!() })
.forward(transmit_to_ws) .forward(transmit_to_ws)
.map(|_r| ()) .map(|_r| ())
.map_err(|e| match e.to_string().as_ref() { .map_err(|e| {
match e.to_string().as_ref() {
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket "IO error: Broken pipe (os error 32)" => (), // just closed unix socket
_ => log::warn!("WebSocket send error: {}", e), _ => log::warn!("WebSocket send error: {}", e),
}
}), }),
); );
@ -50,7 +52,7 @@ impl Ws {
incoming_events.for_each(move |(tl, event)| { incoming_events.for_each(move |(tl, event)| {
if matches!(event, Event::Ping) { if matches!(event, Event::Ping) {
self.send_ping() self.send_msg(&event)
} else if target_timeline == tl { } else if target_timeline == tl {
use crate::event::{CheckedEvent::Update, Event::*, EventKind}; use crate::event::{CheckedEvent::Update, Event::*, EventKind};
use crate::request::Stream::Public; use crate::request::Stream::Public;
@ -61,18 +63,18 @@ impl Ws {
TypeSafe(Update { payload, queued_at }) => match tl { TypeSafe(Update { payload, queued_at }) => match tl {
Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()), Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()),
_ if payload.involves_any(&blocks) => Ok(()), _ if payload.involves_any(&blocks) => Ok(()),
_ => self.send_msg(TypeSafe(Update { payload, queued_at })), _ => self.send_msg(&TypeSafe(Update { payload, queued_at })),
}, },
TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)), TypeSafe(non_update) => self.send_msg(&TypeSafe(non_update)),
Dynamic(dyn_event) => { Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind.clone() { if let EventKind::Update(s) = dyn_event.kind.clone() {
match tl { match tl {
Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()), Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()),
_ if s.involves_any(&blocks) => Ok(()), _ if s.involves_any(&blocks) => Ok(()),
_ => self.send_msg(Dynamic(dyn_event)), _ => self.send_msg(&Dynamic(dyn_event)),
} }
} else { } else {
self.send_msg(Dynamic(dyn_event)) self.send_msg(&Dynamic(dyn_event))
} }
} }
Ping => unreachable!(), // handled pings above Ping => unreachable!(), // handled pings above
@ -83,24 +85,14 @@ impl Ws {
}) })
} }
fn send_ping(&mut self) -> Result<(), ()> { fn send_msg(&mut self, event: &Event) -> Result<(), ()> {
self.send_txt("{}") let txt = &event.to_json_string();
}
fn send_msg(&mut self, event: Event) -> Result<(), ()> {
self.send_txt(&event.to_json_string())
}
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
let tl = self.subscription.timeline; let tl = self.subscription.timeline;
match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) { let mut channel = self.ws_tx.clone().ok_or(())?;
Ok(_) => Ok(()), channel.try_send(Message::text(txt)).map_err(|_| {
Err(_) => {
self.unsubscribe_tx self.unsubscribe_tx
.try_send(tl) .try_send(tl)
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e)); .unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
Err(()) })
}
}
} }
} }