mirror of https://github.com/mastodon/flodgatt
Error handling, pt3 (#131)
* Improve handling of Postgres errors * Finish error handling improvements * Remove `format!` calls from hot path
This commit is contained in:
parent
45f9d4b9fb
commit
37b652ad79
|
@ -406,7 +406,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "flodgatt"
|
name = "flodgatt"
|
||||||
version = "0.8.3"
|
version = "0.8.4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
[package]
|
[package]
|
||||||
name = "flodgatt"
|
name = "flodgatt"
|
||||||
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
|
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
|
||||||
version = "0.8.3"
|
version = "0.8.4"
|
||||||
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
|
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
pub use {deployment_cfg::Deployment, postgres_cfg::Postgres, redis_cfg::Redis};
|
pub use {deployment_cfg::Deployment, postgres_cfg::Postgres, redis_cfg::Redis};
|
||||||
|
|
||||||
use self::environmental_variables::EnvVar;
|
use self::environmental_variables::EnvVar;
|
||||||
use super::err;
|
use super::err::FatalErr;
|
||||||
use hashbrown::HashMap;
|
use hashbrown::HashMap;
|
||||||
use std::env;
|
use std::env;
|
||||||
|
|
||||||
|
@ -13,22 +13,45 @@ mod postgres_cfg_types;
|
||||||
mod redis_cfg;
|
mod redis_cfg;
|
||||||
mod redis_cfg_types;
|
mod redis_cfg_types;
|
||||||
|
|
||||||
pub fn merge_dotenv() -> Result<(), err::FatalErr> {
|
type Result<T> = std::result::Result<T, FatalErr>;
|
||||||
// TODO -- should this allow the user to run in a dir without a `.env` file?
|
|
||||||
dotenv::from_filename(match env::var("ENV").ok().as_deref() {
|
pub fn merge_dotenv() -> Result<()> {
|
||||||
|
let env_file = match env::var("ENV").ok().as_deref() {
|
||||||
Some("production") => ".env.production",
|
Some("production") => ".env.production",
|
||||||
Some("development") | None => ".env",
|
Some("development") | None => ".env",
|
||||||
Some(_unsupported) => Err(err::FatalErr::Unknown)?, // TODO make more specific
|
Some(v) => Err(FatalErr::config("ENV", v, "`production` or `development`"))?,
|
||||||
})?;
|
};
|
||||||
|
let res = dotenv::from_filename(env_file);
|
||||||
|
|
||||||
|
if let Ok(log_level) = env::var("RUST_LOG") {
|
||||||
|
if res.is_err() && ["warn", "info", "trace", "debug"].contains(&log_level.as_str()) {
|
||||||
|
eprintln!(
|
||||||
|
" WARN: could not load environmental variables from {:?}\n\
|
||||||
|
{:8}Are you in the right directory? Proceeding with variables from the environment.",
|
||||||
|
env::current_dir().unwrap_or_else(|_|"./".into()).join(env_file), ""
|
||||||
|
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_env<'a>(env_vars: HashMap<String, String>) -> (Postgres, Redis, Deployment<'a>) {
|
#[allow(clippy::implicit_hasher)]
|
||||||
|
pub fn from_env<'a>(
|
||||||
|
env_vars: HashMap<String, String>,
|
||||||
|
) -> Result<(Postgres, Redis, Deployment<'a>)> {
|
||||||
let env_vars = EnvVar::new(env_vars);
|
let env_vars = EnvVar::new(env_vars);
|
||||||
log::info!("Environmental variables Flodgatt received: {}", &env_vars);
|
log::info!(
|
||||||
(
|
"Flodgatt received the following environmental variables:{}",
|
||||||
Postgres::from_env(env_vars.clone()),
|
&env_vars
|
||||||
Redis::from_env(env_vars.clone()),
|
);
|
||||||
Deployment::from_env(env_vars.clone()),
|
|
||||||
)
|
let pg_cfg = Postgres::from_env(env_vars.clone())?;
|
||||||
|
log::info!("Configuration for {:#?}", &pg_cfg);
|
||||||
|
let redis_cfg = Redis::from_env(env_vars.clone())?;
|
||||||
|
log::info!("Configuration for {:#?},", &redis_cfg);
|
||||||
|
let deployment_cfg = Deployment::from_env(&env_vars)?;
|
||||||
|
log::info!("Configuration for {:#?}", &deployment_cfg);
|
||||||
|
|
||||||
|
Ok((pg_cfg, redis_cfg, deployment_cfg))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use super::{deployment_cfg_types::*, EnvVar};
|
use super::{deployment_cfg_types::*, EnvVar};
|
||||||
|
use crate::err::FatalErr;
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub struct Deployment<'a> {
|
pub struct Deployment<'a> {
|
||||||
|
@ -14,20 +15,19 @@ pub struct Deployment<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Deployment<'_> {
|
impl Deployment<'_> {
|
||||||
pub fn from_env(env: EnvVar) -> Self {
|
pub fn from_env(env: &EnvVar) -> Result<Self, FatalErr> {
|
||||||
let mut cfg = Self {
|
let mut cfg = Self {
|
||||||
env: Env::default().maybe_update(env.get("NODE_ENV")),
|
env: Env::default().maybe_update(env.get("NODE_ENV"))?,
|
||||||
log_level: LogLevel::default().maybe_update(env.get("RUST_LOG")),
|
log_level: LogLevel::default().maybe_update(env.get("RUST_LOG"))?,
|
||||||
address: FlodgattAddr::default().maybe_update(env.get("BIND")),
|
address: FlodgattAddr::default().maybe_update(env.get("BIND"))?,
|
||||||
port: Port::default().maybe_update(env.get("PORT")),
|
port: Port::default().maybe_update(env.get("PORT"))?,
|
||||||
unix_socket: Socket::default().maybe_update(env.get("SOCKET")),
|
unix_socket: Socket::default().maybe_update(env.get("SOCKET"))?,
|
||||||
sse_interval: SseInterval::default().maybe_update(env.get("SSE_FREQ")),
|
sse_interval: SseInterval::default().maybe_update(env.get("SSE_FREQ"))?,
|
||||||
ws_interval: WsInterval::default().maybe_update(env.get("WS_FREQ")),
|
ws_interval: WsInterval::default().maybe_update(env.get("WS_FREQ"))?,
|
||||||
whitelist_mode: WhitelistMode::default().maybe_update(env.get("WHITELIST_MODE")),
|
whitelist_mode: WhitelistMode::default().maybe_update(env.get("WHITELIST_MODE"))?,
|
||||||
cors: Cors::default(),
|
cors: Cors::default(),
|
||||||
};
|
};
|
||||||
cfg.env = cfg.env.maybe_update(env.get("RUST_ENV"));
|
cfg.env = cfg.env.maybe_update(env.get("RUST_ENV"))?;
|
||||||
log::info!("Using deployment configuration:\n {:#?}", &cfg);
|
Ok(cfg)
|
||||||
cfg
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,7 +17,7 @@ from_env_var!(
|
||||||
from_env_var!(
|
from_env_var!(
|
||||||
/// The address to run Flodgatt on
|
/// The address to run Flodgatt on
|
||||||
let name = FlodgattAddr;
|
let name = FlodgattAddr;
|
||||||
let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded"));
|
let default: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
|
||||||
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)");
|
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)");
|
||||||
let from_str = |s| match s {
|
let from_str = |s| match s {
|
||||||
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
||||||
|
|
|
@ -25,17 +25,6 @@ impl EnvVar {
|
||||||
self.0.insert(key.to_string(), value.to_string());
|
self.0.insert(key.to_string(), value.to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn err(env_var: &str, supplied_value: &str, allowed_values: &str) -> ! {
|
|
||||||
log::error!(
|
|
||||||
r"{var} is set to `{value}`, which is invalid.
|
|
||||||
{var} must be {allowed_vals}.",
|
|
||||||
var = env_var,
|
|
||||||
value = supplied_value,
|
|
||||||
allowed_vals = allowed_values
|
|
||||||
);
|
|
||||||
std::process::exit(1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
impl fmt::Display for EnvVar {
|
impl fmt::Display for EnvVar {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
@ -98,7 +87,7 @@ macro_rules! from_env_var {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct $name(pub $type);
|
pub struct $name(pub $type);
|
||||||
impl std::fmt::Debug for $name {
|
impl std::fmt::Debug for $name {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
write!(f, "{:?}", self.0)
|
write!(f, "{:?}", self.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -117,14 +106,14 @@ macro_rules! from_env_var {
|
||||||
fn inner_from_str($arg: &str) -> Option<$type> {
|
fn inner_from_str($arg: &str) -> Option<$type> {
|
||||||
$body
|
$body
|
||||||
}
|
}
|
||||||
pub fn maybe_update(self, var: Option<&String>) -> Self {
|
pub fn maybe_update(self, var: Option<&String>) -> Result<Self, crate::err::FatalErr> {
|
||||||
match var {
|
Ok(match var {
|
||||||
Some(empty_string) if empty_string.is_empty() => Self::default(),
|
Some(empty_string) if empty_string.is_empty() => Self::default(),
|
||||||
Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
|
Some(value) => Self(Self::inner_from_str(value).ok_or_else(|| {
|
||||||
crate::config::EnvVar::err($env_var, value, $allowed_values)
|
crate::err::FatalErr::config($env_var, value, $allowed_values)
|
||||||
})),
|
})?),
|
||||||
None => self,
|
None => self,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
use super::{postgres_cfg_types::*, EnvVar};
|
use super::{postgres_cfg_types::*, EnvVar};
|
||||||
|
use crate::err::FatalErr;
|
||||||
|
|
||||||
use url::Url;
|
use url::Url;
|
||||||
use urlencoding;
|
use urlencoding;
|
||||||
|
|
||||||
|
type Result<T> = std::result::Result<T, FatalErr>;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Postgres {
|
pub struct Postgres {
|
||||||
pub user: PgUser,
|
pub user: PgUser,
|
||||||
|
@ -13,8 +17,8 @@ pub struct Postgres {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EnvVar {
|
impl EnvVar {
|
||||||
fn update_with_postgres_url(mut self, url_str: &str) -> Self {
|
fn update_with_postgres_url(mut self, url_str: &str) -> Result<Self> {
|
||||||
let url = Url::parse(url_str).unwrap();
|
let url = Url::parse(url_str)?;
|
||||||
let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) };
|
let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) };
|
||||||
|
|
||||||
for (k, v) in url.query_pairs().into_owned() {
|
for (k, v) in url.query_pairs().into_owned() {
|
||||||
|
@ -23,11 +27,11 @@ impl EnvVar {
|
||||||
"password" => self.maybe_add_env_var("DB_PASS", Some(v.to_string())),
|
"password" => self.maybe_add_env_var("DB_PASS", Some(v.to_string())),
|
||||||
"host" => self.maybe_add_env_var("DB_HOST", Some(v.to_string())),
|
"host" => self.maybe_add_env_var("DB_HOST", Some(v.to_string())),
|
||||||
"sslmode" => self.maybe_add_env_var("DB_SSLMODE", Some(v.to_string())),
|
"sslmode" => self.maybe_add_env_var("DB_SSLMODE", Some(v.to_string())),
|
||||||
_ => crate::err::die_with_msg(format!(
|
_ => Err(FatalErr::config(
|
||||||
r"Unsupported parameter {} in POSTGRES_URL
|
"POSTGRES_URL",
|
||||||
Flodgatt supports only `password`, `user`, `host`, and `sslmode` parameters",
|
&k,
|
||||||
k
|
"a URL with parameters `password`, `user`, `host`, and `sslmode` only",
|
||||||
)),
|
))?,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,42 +39,39 @@ impl EnvVar {
|
||||||
self.maybe_add_env_var("DB_PASS", url.password());
|
self.maybe_add_env_var("DB_PASS", url.password());
|
||||||
self.maybe_add_env_var(
|
self.maybe_add_env_var(
|
||||||
"DB_HOST",
|
"DB_HOST",
|
||||||
url.host().map(|h| {
|
url.host()
|
||||||
urlencoding::decode(&h.to_string()).expect("Non-Unicode text in hostname")
|
.map(|host| urlencoding::decode(&host.to_string()))
|
||||||
}),
|
.transpose()?,
|
||||||
);
|
);
|
||||||
self.maybe_add_env_var("DB_USER", none_if_empty(url.username().to_string()));
|
self.maybe_add_env_var("DB_USER", none_if_empty(url.username().to_string()));
|
||||||
self.maybe_add_env_var("DB_NAME", none_if_empty(url.path()[1..].to_string()));
|
self.maybe_add_env_var("DB_NAME", none_if_empty(url.path()[1..].to_string()));
|
||||||
|
Ok(self)
|
||||||
self
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Postgres {
|
impl Postgres {
|
||||||
/// Configure Postgres and return a connection
|
/// Configure Postgres and return a connection
|
||||||
|
pub fn from_env(env: EnvVar) -> Result<Self> {
|
||||||
pub fn from_env(env: EnvVar) -> Self {
|
|
||||||
let env = match env.get("DATABASE_URL").cloned() {
|
let env = match env.get("DATABASE_URL").cloned() {
|
||||||
Some(url_str) => env.update_with_postgres_url(&url_str),
|
Some(url_str) => env.update_with_postgres_url(&url_str)?,
|
||||||
None => env,
|
None => env,
|
||||||
};
|
};
|
||||||
|
|
||||||
let cfg = Self {
|
let cfg = Self {
|
||||||
user: PgUser::default().maybe_update(env.get("DB_USER")),
|
user: PgUser::default().maybe_update(env.get("DB_USER"))?,
|
||||||
host: PgHost::default().maybe_update(env.get("DB_HOST")),
|
host: PgHost::default().maybe_update(env.get("DB_HOST"))?,
|
||||||
password: PgPass::default().maybe_update(env.get("DB_PASS")),
|
password: PgPass::default().maybe_update(env.get("DB_PASS"))?,
|
||||||
database: PgDatabase::default().maybe_update(env.get("DB_NAME")),
|
database: PgDatabase::default().maybe_update(env.get("DB_NAME"))?,
|
||||||
port: PgPort::default().maybe_update(env.get("DB_PORT")),
|
port: PgPort::default().maybe_update(env.get("DB_PORT"))?,
|
||||||
ssl_mode: PgSslMode::default().maybe_update(env.get("DB_SSLMODE")),
|
ssl_mode: PgSslMode::default().maybe_update(env.get("DB_SSLMODE"))?,
|
||||||
};
|
};
|
||||||
log::info!("Postgres configuration:\n{:#?}", &cfg);
|
Ok(cfg)
|
||||||
cfg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// // use openssl::ssl::{SslConnector, SslMethod};
|
// // use openssl::ssl::{SslConnector, SslMethod};
|
||||||
// // use postgres_openssl::MakeTlsConnector;
|
// // use postgres_openssl::MakeTlsConnector;
|
||||||
// // let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
|
// // let mut builder = SslConnector::builder(SslMethod::tls())?;
|
||||||
// // builder.set_ca_file("/etc/ssl/cert.pem").unwrap();
|
// // builder.set_ca_file("/etc/ssl/cert.pem")?;
|
||||||
// // let connector = MakeTlsConnector::new(builder.build());
|
// // let connector = MakeTlsConnector::new(builder.build());
|
||||||
// // TODO: add TLS support, remove `NoTls`
|
// // TODO: add TLS support, remove `NoTls`
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
use super::redis_cfg_types::*;
|
use super::redis_cfg_types::*;
|
||||||
use crate::config::EnvVar;
|
use super::EnvVar;
|
||||||
|
use crate::err::FatalErr;
|
||||||
|
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
|
type Result<T> = std::result::Result<T, FatalErr>;
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub struct Redis {
|
pub struct Redis {
|
||||||
pub user: RedisUser,
|
pub user: RedisUser,
|
||||||
|
@ -17,8 +21,8 @@ pub struct Redis {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EnvVar {
|
impl EnvVar {
|
||||||
fn update_with_redis_url(mut self, url_str: &str) -> Self {
|
fn update_with_redis_url(mut self, url_str: &str) -> Result<Self> {
|
||||||
let url = Url::parse(url_str).unwrap();
|
let url = Url::parse(url_str)?;
|
||||||
let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) };
|
let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) };
|
||||||
|
|
||||||
self.maybe_add_env_var("REDIS_PORT", url.port());
|
self.maybe_add_env_var("REDIS_PORT", url.port());
|
||||||
|
@ -29,14 +33,14 @@ impl EnvVar {
|
||||||
match k.to_string().as_str() {
|
match k.to_string().as_str() {
|
||||||
"password" => self.maybe_add_env_var("REDIS_PASSWORD", Some(v.to_string())),
|
"password" => self.maybe_add_env_var("REDIS_PASSWORD", Some(v.to_string())),
|
||||||
"db" => self.maybe_add_env_var("REDIS_DB", Some(v.to_string())),
|
"db" => self.maybe_add_env_var("REDIS_DB", Some(v.to_string())),
|
||||||
_ => crate::err::die_with_msg(format!(
|
_ => Err(FatalErr::config(
|
||||||
r"Unsupported parameter {} in REDIS_URL.
|
"REDIS_URL",
|
||||||
Flodgatt supports only `password` and `db` parameters.",
|
&k,
|
||||||
k
|
"a URL with parameters `password`, `db`, only",
|
||||||
)),
|
))?,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self
|
Ok(self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,20 +50,20 @@ impl Redis {
|
||||||
const DB_SET_WARNING: &'static str = r"Redis database specified, but PubSub connections do not use databases.
|
const DB_SET_WARNING: &'static str = r"Redis database specified, but PubSub connections do not use databases.
|
||||||
For similar functionality, you may wish to set a REDIS_NAMESPACE";
|
For similar functionality, you may wish to set a REDIS_NAMESPACE";
|
||||||
|
|
||||||
pub fn from_env(env: EnvVar) -> Self {
|
pub fn from_env(env: EnvVar) -> Result<Self> {
|
||||||
let env = match env.get("REDIS_URL").cloned() {
|
let env = match env.get("REDIS_URL").cloned() {
|
||||||
Some(url_str) => env.update_with_redis_url(&url_str),
|
Some(url_str) => env.update_with_redis_url(&url_str)?,
|
||||||
None => env,
|
None => env,
|
||||||
};
|
};
|
||||||
|
|
||||||
let cfg = Redis {
|
let cfg = Redis {
|
||||||
user: RedisUser::default().maybe_update(env.get("REDIS_USER")),
|
user: RedisUser::default().maybe_update(env.get("REDIS_USER"))?,
|
||||||
password: RedisPass::default().maybe_update(env.get("REDIS_PASSWORD")),
|
password: RedisPass::default().maybe_update(env.get("REDIS_PASSWORD"))?,
|
||||||
port: RedisPort::default().maybe_update(env.get("REDIS_PORT")),
|
port: RedisPort::default().maybe_update(env.get("REDIS_PORT"))?,
|
||||||
host: RedisHost::default().maybe_update(env.get("REDIS_HOST")),
|
host: RedisHost::default().maybe_update(env.get("REDIS_HOST"))?,
|
||||||
db: RedisDb::default().maybe_update(env.get("REDIS_DB")),
|
db: RedisDb::default().maybe_update(env.get("REDIS_DB"))?,
|
||||||
namespace: RedisNamespace::default().maybe_update(env.get("REDIS_NAMESPACE")),
|
namespace: RedisNamespace::default().maybe_update(env.get("REDIS_NAMESPACE"))?,
|
||||||
polling_interval: RedisInterval::default().maybe_update(env.get("REDIS_FREQ")),
|
polling_interval: RedisInterval::default().maybe_update(env.get("REDIS_FREQ"))?,
|
||||||
};
|
};
|
||||||
|
|
||||||
if cfg.db.is_some() {
|
if cfg.db.is_some() {
|
||||||
|
@ -68,7 +72,6 @@ For similar functionality, you may wish to set a REDIS_NAMESPACE";
|
||||||
if cfg.user.is_some() {
|
if cfg.user.is_some() {
|
||||||
log::warn!("{}", Self::USER_SET_WARNING);
|
log::warn!("{}", Self::USER_SET_WARNING);
|
||||||
}
|
}
|
||||||
log::info!("Redis configuration:\n{:#?},", &cfg);
|
Ok(cfg)
|
||||||
cfg
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
56
src/err.rs
56
src/err.rs
|
@ -1,17 +1,29 @@
|
||||||
|
use crate::request::RequestErr;
|
||||||
use crate::response::ManagerErr;
|
use crate::response::ManagerErr;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
pub enum FatalErr {
|
pub enum FatalErr {
|
||||||
Unknown,
|
|
||||||
ReceiverErr(ManagerErr),
|
ReceiverErr(ManagerErr),
|
||||||
DotEnv(dotenv::Error),
|
|
||||||
Logger(log::SetLoggerError),
|
Logger(log::SetLoggerError),
|
||||||
|
Postgres(RequestErr),
|
||||||
|
Unrecoverable,
|
||||||
|
StdIo(std::io::Error),
|
||||||
|
// config errs
|
||||||
|
UrlParse(url::ParseError),
|
||||||
|
UrlEncoding(urlencoding::FromUrlEncodingError),
|
||||||
|
ConfigErr(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FatalErr {
|
impl FatalErr {
|
||||||
pub fn exit(msg: impl fmt::Display) {
|
pub fn log(msg: impl fmt::Display) {
|
||||||
eprintln!("{}", msg);
|
eprintln!("{}", msg);
|
||||||
std::process::exit(1);
|
}
|
||||||
|
|
||||||
|
pub fn config<T: fmt::Display>(var: T, value: T, allowed_vals: T) -> Self {
|
||||||
|
Self::ConfigErr(format!(
|
||||||
|
"{0} is set to `{1}`, which is invalid.\n{3:7}{0} must be {2}.",
|
||||||
|
var, value, allowed_vals, ""
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,18 +41,22 @@ impl fmt::Display for FatalErr {
|
||||||
f,
|
f,
|
||||||
"{}",
|
"{}",
|
||||||
match self {
|
match self {
|
||||||
Unknown => "Flodgatt encountered an unknown, unrecoverable error".into(),
|
|
||||||
ReceiverErr(e) => format!("{}", e),
|
ReceiverErr(e) => format!("{}", e),
|
||||||
Logger(e) => format!("{}", e),
|
Logger(e) => format!("{}", e),
|
||||||
DotEnv(e) => format!("Could not load specified environmental file: {}", e),
|
StdIo(e) => format!("{}", e),
|
||||||
|
Postgres(e) => format!("could not connect to Postgres.\n{:7}{}", "", e),
|
||||||
|
ConfigErr(e) => e.to_string(),
|
||||||
|
UrlParse(e) => format!("could parse Postgres URL.\n{:7}{}", "", e),
|
||||||
|
UrlEncoding(e) => format!("could not parse POSTGRES_URL.\n{:7}{:?}", "", e),
|
||||||
|
Unrecoverable => "Flodgatt will now shut down.".into(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<dotenv::Error> for FatalErr {
|
impl From<RequestErr> for FatalErr {
|
||||||
fn from(e: dotenv::Error) -> Self {
|
fn from(e: RequestErr) -> Self {
|
||||||
Self::DotEnv(e)
|
Self::Postgres(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,15 +65,23 @@ impl From<ManagerErr> for FatalErr {
|
||||||
Self::ReceiverErr(e)
|
Self::ReceiverErr(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl From<urlencoding::FromUrlEncodingError> for FatalErr {
|
||||||
|
fn from(e: urlencoding::FromUrlEncodingError) -> Self {
|
||||||
|
Self::UrlEncoding(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl From<url::ParseError> for FatalErr {
|
||||||
|
fn from(e: url::ParseError) -> Self {
|
||||||
|
Self::UrlParse(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl From<std::io::Error> for FatalErr {
|
||||||
|
fn from(e: std::io::Error) -> Self {
|
||||||
|
Self::StdIo(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
impl From<log::SetLoggerError> for FatalErr {
|
impl From<log::SetLoggerError> for FatalErr {
|
||||||
fn from(e: log::SetLoggerError) -> Self {
|
fn from(e: log::SetLoggerError) -> Self {
|
||||||
Self::Logger(e)
|
Self::Logger(e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO delete vvvv when postgres_cfg.rs has better error handling
|
|
||||||
pub fn die_with_msg(msg: impl fmt::Display) -> ! {
|
|
||||||
eprintln!("FATAL ERROR: {}", msg);
|
|
||||||
std::process::exit(1);
|
|
||||||
}
|
|
||||||
|
|
47
src/event.rs
47
src/event.rs
|
@ -2,14 +2,14 @@ mod checked_event;
|
||||||
mod dynamic_event;
|
mod dynamic_event;
|
||||||
mod err;
|
mod err;
|
||||||
|
|
||||||
pub use {
|
pub use checked_event::{CheckedEvent, Id};
|
||||||
checked_event::{CheckedEvent, Id},
|
pub use dynamic_event::{DynEvent, DynStatus, EventKind};
|
||||||
dynamic_event::{DynEvent, DynStatus, EventKind},
|
pub use err::EventErr;
|
||||||
err::EventErr,
|
|
||||||
};
|
|
||||||
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use std::{convert::TryFrom, string::String};
|
use std::convert::TryFrom;
|
||||||
|
use std::string::String;
|
||||||
|
use warp::sse::ServerSentEvent;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Event {
|
pub enum Event {
|
||||||
|
@ -20,15 +20,30 @@ pub enum Event {
|
||||||
|
|
||||||
impl Event {
|
impl Event {
|
||||||
pub fn to_json_string(&self) -> String {
|
pub fn to_json_string(&self) -> String {
|
||||||
let event = &self.event_name();
|
if let Event::Ping = self {
|
||||||
let sendable_event = match self.payload() {
|
"{}".to_string()
|
||||||
Some(payload) => SendableEvent::WithPayload { event, payload },
|
} else {
|
||||||
None => SendableEvent::NoPayload { event },
|
let event = &self.event_name();
|
||||||
};
|
let sendable_event = match self.payload() {
|
||||||
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
|
Some(payload) => SendableEvent::WithPayload { event, payload },
|
||||||
|
None => SendableEvent::NoPayload { event },
|
||||||
|
};
|
||||||
|
serde_json::to_string(&sendable_event).expect("Guaranteed: SendableEvent is Serialize")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn event_name(&self) -> String {
|
pub fn to_warp_reply(&self) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
|
||||||
|
if let Event::Ping = self {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some((
|
||||||
|
warp::sse::event(self.event_name()),
|
||||||
|
warp::sse::data(self.payload().unwrap_or_else(String::new)),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn event_name(&self) -> String {
|
||||||
String::from(match self {
|
String::from(match self {
|
||||||
Self::TypeSafe(checked) => match checked {
|
Self::TypeSafe(checked) => match checked {
|
||||||
CheckedEvent::Update { .. } => "update",
|
CheckedEvent::Update { .. } => "update",
|
||||||
|
@ -45,11 +60,11 @@ impl Event {
|
||||||
..
|
..
|
||||||
}) => "update",
|
}) => "update",
|
||||||
Self::Dynamic(DynEvent { event, .. }) => event,
|
Self::Dynamic(DynEvent { event, .. }) => event,
|
||||||
Self::Ping => panic!("event_name() called on Ping"),
|
Self::Ping => unreachable!(), // private method only called above
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn payload(&self) -> Option<String> {
|
fn payload(&self) -> Option<String> {
|
||||||
use CheckedEvent::*;
|
use CheckedEvent::*;
|
||||||
match self {
|
match self {
|
||||||
Self::TypeSafe(checked) => match checked {
|
Self::TypeSafe(checked) => match checked {
|
||||||
|
@ -63,7 +78,7 @@ impl Event {
|
||||||
FiltersChanged => None,
|
FiltersChanged => None,
|
||||||
},
|
},
|
||||||
Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()),
|
Self::Dynamic(DynEvent { payload, .. }) => Some(payload.to_string()),
|
||||||
Self::Ping => panic!("payload() called on Ping"),
|
Self::Ping => unreachable!(), // private method only called above
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,7 +43,7 @@ type Result<T> = std::result::Result<T, EventErr>;
|
||||||
impl DynEvent {
|
impl DynEvent {
|
||||||
pub fn set_update(self) -> Result<Self> {
|
pub fn set_update(self) -> Result<Self> {
|
||||||
if self.event == "update" {
|
if self.event == "update" {
|
||||||
let kind = EventKind::Update(DynStatus::new(self.payload.clone())?);
|
let kind = EventKind::Update(DynStatus::new(&self.payload.clone())?);
|
||||||
Ok(Self { kind, ..self })
|
Ok(Self { kind, ..self })
|
||||||
} else {
|
} else {
|
||||||
Ok(self)
|
Ok(self)
|
||||||
|
@ -52,7 +52,7 @@ impl DynEvent {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DynStatus {
|
impl DynStatus {
|
||||||
pub fn new(payload: Value) -> Result<Self> {
|
pub fn new(payload: &Value) -> Result<Self> {
|
||||||
use EventErr::*;
|
use EventErr::*;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
@ -61,7 +61,7 @@ impl DynStatus {
|
||||||
.as_str()
|
.as_str()
|
||||||
.ok_or(DynParse)?
|
.ok_or(DynParse)?
|
||||||
.to_string(),
|
.to_string(),
|
||||||
language: payload["language"].as_str().map(|s| s.to_string()),
|
language: payload["language"].as_str().map(String::from),
|
||||||
mentioned_users: HashSet::new(),
|
mentioned_users: HashSet::new(),
|
||||||
replied_to_user: Id::try_from(&payload["in_reply_to_account_id"]).ok(),
|
replied_to_user: Id::try_from(&payload["in_reply_to_account_id"]).ok(),
|
||||||
boosted_user: Id::try_from(&payload["reblog"]["account"]["id"]).ok(),
|
boosted_user: Id::try_from(&payload["reblog"]["account"]["id"]).ok(),
|
||||||
|
|
|
@ -37,6 +37,7 @@
|
||||||
|
|
||||||
//#![warn(clippy::pedantic)]
|
//#![warn(clippy::pedantic)]
|
||||||
#![allow(clippy::try_err, clippy::match_bool)]
|
#![allow(clippy::try_err, clippy::match_bool)]
|
||||||
|
//#![allow(clippy::large_enum_variant)]
|
||||||
|
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod err;
|
pub mod err;
|
||||||
|
|
18
src/main.rs
18
src/main.rs
|
@ -19,20 +19,19 @@ use warp::Filter;
|
||||||
fn main() -> Result<(), FatalErr> {
|
fn main() -> Result<(), FatalErr> {
|
||||||
config::merge_dotenv()?;
|
config::merge_dotenv()?;
|
||||||
pretty_env_logger::try_init()?;
|
pretty_env_logger::try_init()?;
|
||||||
let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect());
|
let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect())?;
|
||||||
|
let poll_freq = *redis_cfg.polling_interval;
|
||||||
|
|
||||||
// Create channels to communicate between threads
|
// Create channels to communicate between threads
|
||||||
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
|
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
|
||||||
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
|
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
let request = Handler::new(postgres_cfg, *cfg.whitelist_mode);
|
let request = Handler::new(&postgres_cfg, *cfg.whitelist_mode)?;
|
||||||
let poll_freq = *redis_cfg.polling_interval;
|
let shared_manager = redis::Manager::try_from(&redis_cfg, event_tx, cmd_rx)?.into_arc();
|
||||||
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
|
|
||||||
|
|
||||||
// Server Sent Events
|
// Server Sent Events
|
||||||
let sse_manager = shared_manager.clone();
|
let sse_manager = shared_manager.clone();
|
||||||
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
|
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
|
||||||
|
|
||||||
let sse = request
|
let sse = request
|
||||||
.sse_subscription()
|
.sse_subscription()
|
||||||
.and(warp::sse())
|
.and(warp::sse())
|
||||||
|
@ -85,8 +84,7 @@ fn main() -> Result<(), FatalErr> {
|
||||||
.map_err(|e| log::error!("{}", e))
|
.map_err(|e| log::error!("{}", e))
|
||||||
.for_each(move |_| {
|
.for_each(move |_| {
|
||||||
let mut manager = manager.lock().unwrap_or_else(redis::Manager::recover);
|
let mut manager = manager.lock().unwrap_or_else(redis::Manager::recover);
|
||||||
manager.poll_broadcast().unwrap_or_else(FatalErr::exit);
|
manager.poll_broadcast().map_err(FatalErr::log)
|
||||||
Ok(())
|
|
||||||
});
|
});
|
||||||
warp::spawn(lazy(move || stream));
|
warp::spawn(lazy(move || stream));
|
||||||
warp::serve(ws.or(sse).with(cors).or(status).recover(Handler::err))
|
warp::serve(ws.or(sse).with(cors).or(status).recover(Handler::err))
|
||||||
|
@ -95,13 +93,13 @@ fn main() -> Result<(), FatalErr> {
|
||||||
if let Some(socket) = &*cfg.unix_socket {
|
if let Some(socket) = &*cfg.unix_socket {
|
||||||
log::info!("Using Unix socket {}", socket);
|
log::info!("Using Unix socket {}", socket);
|
||||||
fs::remove_file(socket).unwrap_or_default();
|
fs::remove_file(socket).unwrap_or_default();
|
||||||
let incoming = UnixListener::bind(socket).expect("TODO").incoming();
|
let incoming = UnixListener::bind(socket)?.incoming();
|
||||||
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).expect("TODO");
|
fs::set_permissions(socket, PermissionsExt::from_mode(0o666))?;
|
||||||
|
|
||||||
tokio::run(lazy(|| streaming_server().serve_incoming(incoming)));
|
tokio::run(lazy(|| streaming_server().serve_incoming(incoming)));
|
||||||
} else {
|
} else {
|
||||||
let server_addr = SocketAddr::new(*cfg.address, *cfg.port);
|
let server_addr = SocketAddr::new(*cfg.address, *cfg.port);
|
||||||
tokio::run(lazy(move || streaming_server().bind(server_addr)));
|
tokio::run(lazy(move || streaming_server().bind(server_addr)));
|
||||||
}
|
}
|
||||||
Ok(())
|
Err(FatalErr::Unrecoverable) // on get here if there's an unrecoverable error in poll_broadcast.
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,8 +3,10 @@ mod postgres;
|
||||||
mod query;
|
mod query;
|
||||||
pub mod timeline;
|
pub mod timeline;
|
||||||
|
|
||||||
|
mod err;
|
||||||
mod subscription;
|
mod subscription;
|
||||||
|
|
||||||
|
pub use self::err::RequestErr;
|
||||||
pub use self::postgres::PgPool;
|
pub use self::postgres::PgPool;
|
||||||
// TODO consider whether we can remove `Stream` from public API
|
// TODO consider whether we can remove `Stream` from public API
|
||||||
pub use subscription::{Blocks, Subscription};
|
pub use subscription::{Blocks, Subscription};
|
||||||
|
@ -22,6 +24,8 @@ mod sse_test;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod ws_test;
|
mod ws_test;
|
||||||
|
|
||||||
|
type Result<T> = std::result::Result<T, err::RequestErr>;
|
||||||
|
|
||||||
/// Helper macro to match on the first of any of the provided filters
|
/// Helper macro to match on the first of any of the provided filters
|
||||||
macro_rules! any_of {
|
macro_rules! any_of {
|
||||||
($filter:expr, $($other_filter:expr),*) => {
|
($filter:expr, $($other_filter:expr),*) => {
|
||||||
|
@ -56,10 +60,10 @@ pub struct Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Handler {
|
impl Handler {
|
||||||
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self {
|
pub fn new(postgres_cfg: &config::Postgres, whitelist_mode: bool) -> Result<Self> {
|
||||||
Self {
|
Ok(Self {
|
||||||
pg_conn: PgPool::new(postgres_cfg, whitelist_mode),
|
pg_conn: PgPool::new(postgres_cfg, whitelist_mode)?,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> {
|
pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> {
|
||||||
|
@ -113,7 +117,7 @@ impl Handler {
|
||||||
warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed()
|
warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn err(r: Rejection) -> Result<impl warp::Reply, warp::Rejection> {
|
pub fn err(r: Rejection) -> std::result::Result<impl warp::Reply, warp::Rejection> {
|
||||||
let json_err = match r.cause() {
|
let json_err = match r.cause() {
|
||||||
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
|
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
|
||||||
warp::reply::json(&"Error: Missing access token".to_string())
|
warp::reply::json(&"Error: Missing access token".to_string())
|
||||||
|
|
|
@ -0,0 +1,32 @@
|
||||||
|
use std::fmt;
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum RequestErr {
|
||||||
|
Unknown,
|
||||||
|
PgPool(r2d2::Error),
|
||||||
|
Pg(postgres::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for RequestErr {}
|
||||||
|
|
||||||
|
impl fmt::Display for RequestErr {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
|
||||||
|
use RequestErr::*;
|
||||||
|
let msg = match self {
|
||||||
|
Unknown => "Encountered an unrecoverable error related to handling a request".into(),
|
||||||
|
PgPool(e) => format!("{}", e),
|
||||||
|
Pg(e) => format!("{}", e),
|
||||||
|
};
|
||||||
|
write!(f, "{}", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<r2d2::Error> for RequestErr {
|
||||||
|
fn from(e: r2d2::Error) -> Self {
|
||||||
|
Self::PgPool(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl From<postgres::Error> for RequestErr {
|
||||||
|
fn from(e: postgres::Error) -> Self {
|
||||||
|
Self::Pg(e)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,13 +1,13 @@
|
||||||
//! Postgres queries
|
//! Postgres queries
|
||||||
|
use super::err;
|
||||||
|
use super::timeline::{Scope, UserData};
|
||||||
use crate::config;
|
use crate::config;
|
||||||
use crate::event::Id;
|
use crate::event::Id;
|
||||||
use crate::request::timeline::{Scope, UserData};
|
|
||||||
|
|
||||||
use ::postgres;
|
use ::postgres;
|
||||||
use hashbrown::HashSet;
|
use hashbrown::HashSet;
|
||||||
use r2d2_postgres::PostgresConnectionManager;
|
use r2d2_postgres::PostgresConnectionManager;
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
use warp::reject::Rejection;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct PgPool {
|
pub struct PgPool {
|
||||||
|
@ -15,8 +15,11 @@ pub struct PgPool {
|
||||||
whitelist_mode: bool,
|
whitelist_mode: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Result<T> = std::result::Result<T, err::RequestErr>;
|
||||||
|
type Rejectable<T> = std::result::Result<T, warp::Rejection>;
|
||||||
|
|
||||||
impl PgPool {
|
impl PgPool {
|
||||||
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self {
|
pub fn new(pg_cfg: &config::Postgres, whitelist_mode: bool) -> Result<Self> {
|
||||||
let mut cfg = postgres::Config::new();
|
let mut cfg = postgres::Config::new();
|
||||||
cfg.user(&pg_cfg.user)
|
cfg.user(&pg_cfg.user)
|
||||||
.host(&*pg_cfg.host.to_string())
|
.host(&*pg_cfg.host.to_string())
|
||||||
|
@ -26,19 +29,20 @@ impl PgPool {
|
||||||
cfg.password(password);
|
cfg.password(password);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
cfg.connect(postgres::NoTls)?; // Test connection, letting us immediately exit with an error
|
||||||
|
// when Postgres isn't running instead of timing out below
|
||||||
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
|
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
|
||||||
let pool = r2d2::Pool::builder()
|
let pool = r2d2::Pool::builder().max_size(10).build(manager)?;
|
||||||
.max_size(10)
|
|
||||||
.build(manager)
|
Ok(Self {
|
||||||
.expect("Can connect to local postgres");
|
|
||||||
Self {
|
|
||||||
conn: pool,
|
conn: pool,
|
||||||
whitelist_mode,
|
whitelist_mode,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn select_user(self, token: &Option<String>) -> Result<UserData, Rejection> {
|
pub fn select_user(self, token: &Option<String>) -> Rejectable<UserData> {
|
||||||
let mut conn = self.conn.get().unwrap();
|
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||||
|
|
||||||
if let Some(token) = token {
|
if let Some(token) = token {
|
||||||
let query_rows = conn
|
let query_rows = conn
|
||||||
.query("
|
.query("
|
||||||
|
@ -47,9 +51,9 @@ SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_lan
|
||||||
INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
|
INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
|
||||||
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
|
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
|
||||||
LIMIT 1",
|
LIMIT 1",
|
||||||
&[&token.to_owned()],
|
&[&token.to_owned()],
|
||||||
)
|
).map_err(warp::reject::custom)?;
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
|
||||||
if let Some(result_columns) = query_rows.get(0) {
|
if let Some(result_columns) = query_rows.get(0) {
|
||||||
let id = Id(result_columns.get(1));
|
let id = Id(result_columns.get(1));
|
||||||
let allowed_langs = result_columns
|
let allowed_langs = result_columns
|
||||||
|
@ -85,10 +89,10 @@ LIMIT 1",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> {
|
pub fn select_hashtag_id(self, tag_name: &str) -> Rejectable<i64> {
|
||||||
let mut conn = self.conn.get().expect("TODO");
|
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||||
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
|
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
.map_err(warp::reject::custom)?
|
||||||
.get(0)
|
.get(0)
|
||||||
.map(|row| row.get(0))
|
.map(|row| row.get(0))
|
||||||
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
|
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
|
||||||
|
@ -98,31 +102,31 @@ LIMIT 1",
|
||||||
///
|
///
|
||||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||||
/// the user adds until they refresh/reconnect.
|
/// the user adds until they refresh/reconnect.
|
||||||
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
|
pub fn select_blocked_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
|
||||||
let mut conn = self.conn.get().expect("TODO");
|
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||||
conn.query(
|
conn.query(
|
||||||
"SELECT target_account_id FROM blocks WHERE account_id = $1
|
"SELECT target_account_id FROM blocks WHERE account_id = $1
|
||||||
UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
|
UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
|
||||||
&[&*user_id],
|
&[&*user_id],
|
||||||
)
|
)
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
.map_err(warp::reject::custom)?
|
||||||
.iter()
|
.iter()
|
||||||
.map(|row| Id(row.get(0)))
|
.map(|row| Ok(Id(row.get(0))))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
/// Query Postgres for everyone who has blocked the user
|
/// Query Postgres for everyone who has blocked the user
|
||||||
///
|
///
|
||||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||||
/// the user adds until they refresh/reconnect.
|
/// the user adds until they refresh/reconnect.
|
||||||
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
|
pub fn select_blocking_users(self, user_id: Id) -> Rejectable<HashSet<Id>> {
|
||||||
let mut conn = self.conn.get().expect("TODO");
|
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||||
conn.query(
|
conn.query(
|
||||||
"SELECT account_id FROM blocks WHERE target_account_id = $1",
|
"SELECT account_id FROM blocks WHERE target_account_id = $1",
|
||||||
&[&*user_id],
|
&[&*user_id],
|
||||||
)
|
)
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
.map_err(warp::reject::custom)?
|
||||||
.iter()
|
.iter()
|
||||||
.map(|row| Id(row.get(0)))
|
.map(|row| Ok(Id(row.get(0))))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,28 +134,28 @@ LIMIT 1",
|
||||||
///
|
///
|
||||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||||
/// the user adds until they refresh/reconnect.
|
/// the user adds until they refresh/reconnect.
|
||||||
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
|
pub fn select_blocked_domains(self, user_id: Id) -> Rejectable<HashSet<String>> {
|
||||||
let mut conn = self.conn.get().expect("TODO");
|
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||||
conn.query(
|
conn.query(
|
||||||
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
|
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
|
||||||
&[&*user_id],
|
&[&*user_id],
|
||||||
)
|
)
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
.map_err(warp::reject::custom)?
|
||||||
.iter()
|
.iter()
|
||||||
.map(|row| row.get(0))
|
.map(|row| Ok(row.get(0)))
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test whether a user owns a list
|
/// Test whether a user owns a list
|
||||||
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
|
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> Rejectable<bool> {
|
||||||
let mut conn = self.conn.get().expect("TODO");
|
let mut conn = self.conn.get().map_err(warp::reject::custom)?;
|
||||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||||
let rows = &conn
|
let rows = &conn
|
||||||
.query(
|
.query(
|
||||||
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
|
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
|
||||||
&[&list_id],
|
&[&list_id],
|
||||||
)
|
)
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
.map_err(warp::reject::custom)?;
|
||||||
rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id)
|
Ok(rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,7 +54,7 @@ impl Subscription {
|
||||||
let tag = pool.select_hashtag_id(&q.hashtag)?;
|
let tag = pool.select_hashtag_id(&q.hashtag)?;
|
||||||
Timeline(Hashtag(tag), reach, stream)
|
Timeline(Hashtag(tag), reach, stream)
|
||||||
}
|
}
|
||||||
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => {
|
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id)? => {
|
||||||
Err(warp::reject::custom("Error: Missing access token"))?
|
Err(warp::reject::custom("Error: Missing access token"))?
|
||||||
}
|
}
|
||||||
other_tl => other_tl,
|
other_tl => other_tl,
|
||||||
|
@ -70,9 +70,9 @@ impl Subscription {
|
||||||
timeline,
|
timeline,
|
||||||
allowed_langs: user.allowed_langs,
|
allowed_langs: user.allowed_langs,
|
||||||
blocks: Blocks {
|
blocks: Blocks {
|
||||||
blocking_users: pool.clone().select_blocking_users(user.id),
|
blocking_users: pool.clone().select_blocking_users(user.id)?,
|
||||||
blocked_users: pool.clone().select_blocked_users(user.id),
|
blocked_users: pool.clone().select_blocked_users(user.id)?,
|
||||||
blocked_domains: pool.select_blocked_domains(user.id),
|
blocked_domains: pool.select_blocked_domains(user.id)?,
|
||||||
},
|
},
|
||||||
hashtag_name,
|
hashtag_name,
|
||||||
access_token: q.access_token,
|
access_token: q.access_token,
|
||||||
|
|
|
@ -19,25 +19,29 @@ impl Timeline {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String> {
|
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String> {
|
||||||
use {Content::*, Reach::*, Stream::*};
|
// TODO -- does this need to account for namespaces?
|
||||||
|
use {Content::*, Reach::*, Stream::*, TimelineErr::*};
|
||||||
|
|
||||||
Ok(match self {
|
Ok(match self {
|
||||||
Timeline(Public, Federated, All) => "timeline:public".into(),
|
Timeline(Public, Federated, All) => "timeline:public".to_string(),
|
||||||
Timeline(Public, Local, All) => "timeline:public:local".into(),
|
Timeline(Public, Local, All) => "timeline:public:local".to_string(),
|
||||||
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
|
Timeline(Public, Federated, Media) => "timeline:public:media".to_string(),
|
||||||
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
|
Timeline(Public, Local, Media) => "timeline:public:local:media".to_string(),
|
||||||
// TODO -- would `.push_str` be faster here?
|
Timeline(Hashtag(_id), Federated, All) => {
|
||||||
Timeline(Hashtag(_id), Federated, All) => format!(
|
["timeline:hashtag:", hashtag.ok_or(MissingHashtag)?].concat()
|
||||||
"timeline:hashtag:{}",
|
}
|
||||||
hashtag.ok_or(TimelineErr::MissingHashtag)?
|
Timeline(Hashtag(_id), Local, All) => [
|
||||||
),
|
"timeline:hashtag:",
|
||||||
Timeline(Hashtag(_id), Local, All) => format!(
|
hashtag.ok_or(MissingHashtag)?,
|
||||||
"timeline:hashtag:{}:local",
|
":local",
|
||||||
hashtag.ok_or(TimelineErr::MissingHashtag)?
|
]
|
||||||
),
|
.concat(),
|
||||||
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
|
Timeline(User(id), Federated, All) => ["timeline:", &id.to_string()].concat(),
|
||||||
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
|
Timeline(User(id), Federated, Notification) => {
|
||||||
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
|
["timeline:", &id.to_string(), ":notification"].concat()
|
||||||
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
|
}
|
||||||
|
Timeline(List(id), Federated, All) => ["timeline:list:", &id.to_string()].concat(),
|
||||||
|
Timeline(Direct(id), Federated, All) => ["timeline:direct:", &id.to_string()].concat(),
|
||||||
Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
|
Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -57,8 +61,7 @@ impl Timeline {
|
||||||
[id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
|
[id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
|
||||||
["list", id] => Timeline(List(id.parse()?), Federated, All),
|
["list", id] => Timeline(List(id.parse()?), Federated, All),
|
||||||
["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
|
["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
|
||||||
// Other endpoints don't exist:
|
[..] => Err(InvalidInput)?, // Other endpoints don't exist
|
||||||
[..] => Err(InvalidInput)?,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,15 +12,43 @@ pub enum RedisCmd {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RedisCmd {
|
impl RedisCmd {
|
||||||
pub fn into_sendable(&self, tl: &String) -> (Vec<u8>, Vec<u8>) {
|
pub fn into_sendable(self, tl: &str) -> (Vec<u8>, Vec<u8>) {
|
||||||
match self {
|
match self {
|
||||||
RedisCmd::Subscribe => (
|
RedisCmd::Subscribe => (
|
||||||
format!("*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(),
|
[
|
||||||
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n1\r\n", tl.len(), tl).into_bytes(),
|
b"*2\r\n$9\r\nsubscribe\r\n$",
|
||||||
|
tl.len().to_string().as_bytes(),
|
||||||
|
b"\r\n",
|
||||||
|
tl.as_bytes(),
|
||||||
|
b"\r\n",
|
||||||
|
]
|
||||||
|
.concat(),
|
||||||
|
[
|
||||||
|
b"*3\r\n$3\r\nSET\r\n$",
|
||||||
|
tl.len().to_string().as_bytes(),
|
||||||
|
b"\r\n",
|
||||||
|
tl.as_bytes(),
|
||||||
|
b"\r\n$1\r\n1\r\n",
|
||||||
|
]
|
||||||
|
.concat(),
|
||||||
),
|
),
|
||||||
RedisCmd::Unsubscribe => (
|
RedisCmd::Unsubscribe => (
|
||||||
format!("*2\r\n$11\r\nunsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(),
|
[
|
||||||
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n0\r\n", tl.len(), tl).into_bytes(),
|
b"*2\r\n$11\r\nunsubscribe\r\n$",
|
||||||
|
tl.len().to_string().as_bytes(),
|
||||||
|
b"\r\n",
|
||||||
|
tl.as_bytes(),
|
||||||
|
b"\r\n",
|
||||||
|
]
|
||||||
|
.concat(),
|
||||||
|
[
|
||||||
|
b"*3\r\n$3\r\nSET\r\n$",
|
||||||
|
tl.len().to_string().as_bytes(),
|
||||||
|
b"\r\n",
|
||||||
|
tl.as_bytes(),
|
||||||
|
b"\r\n$1\r\n0\r\n",
|
||||||
|
]
|
||||||
|
.concat(),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,8 +28,9 @@ pub struct RedisConn {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RedisConn {
|
impl RedisConn {
|
||||||
pub fn new(redis_cfg: Redis) -> Result<Self> {
|
pub fn new(redis_cfg: &Redis) -> Result<Self> {
|
||||||
let addr = format!("{}:{}", *redis_cfg.host, *redis_cfg.port);
|
let addr = [&*redis_cfg.host, ":", &*redis_cfg.port.to_string()].concat();
|
||||||
|
|
||||||
let conn = Self::new_connection(&addr, redis_cfg.password.as_ref())?;
|
let conn = Self::new_connection(&addr, redis_cfg.password.as_ref())?;
|
||||||
conn.set_nonblocking(true)
|
conn.set_nonblocking(true)
|
||||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||||
|
@ -49,7 +50,7 @@ impl RedisConn {
|
||||||
|
|
||||||
pub fn poll_redis(&mut self) -> Poll<Option<(Timeline, Event)>, ManagerErr> {
|
pub fn poll_redis(&mut self) -> Poll<Option<(Timeline, Event)>, ManagerErr> {
|
||||||
let mut size = 100; // large enough to handle subscribe/unsubscribe notice
|
let mut size = 100; // large enough to handle subscribe/unsubscribe notice
|
||||||
let (mut buffer, mut first_read) = (vec![0u8; size], true);
|
let (mut buffer, mut first_read) = (vec![0_u8; size], true);
|
||||||
loop {
|
loop {
|
||||||
match self.primary.read(&mut buffer) {
|
match self.primary.read(&mut buffer) {
|
||||||
Ok(n) if n != size => break self.redis_input.extend_from_slice(&buffer[..n]),
|
Ok(n) if n != size => break self.redis_input.extend_from_slice(&buffer[..n]),
|
||||||
|
@ -81,7 +82,7 @@ impl RedisConn {
|
||||||
use {Async::*, RedisParseOutput::*};
|
use {Async::*, RedisParseOutput::*};
|
||||||
let (res, leftover) = match RedisParseOutput::try_from(input) {
|
let (res, leftover) = match RedisParseOutput::try_from(input) {
|
||||||
Ok(Msg(msg)) => match &self.redis_namespace {
|
Ok(Msg(msg)) => match &self.redis_namespace {
|
||||||
Some(ns) if msg.timeline_txt.starts_with(&format!("{}:timeline:", ns)) => {
|
Some(ns) if msg.timeline_txt.starts_with(&[ns, ":timeline:"].concat()) => {
|
||||||
let trimmed_tl = &msg.timeline_txt[ns.len() + ":timeline:".len()..];
|
let trimmed_tl = &msg.timeline_txt[ns.len() + ":timeline:".len()..];
|
||||||
let tl = Timeline::from_redis_text(trimmed_tl, &mut self.tag_id_cache)?;
|
let tl = Timeline::from_redis_text(trimmed_tl, &mut self.tag_id_cache)?;
|
||||||
let event = msg.event_txt.try_into()?;
|
let event = msg.event_txt.try_into()?;
|
||||||
|
@ -135,8 +136,17 @@ impl RedisConn {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> {
|
fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> {
|
||||||
conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes())
|
conn.write_all(
|
||||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
&[
|
||||||
|
b"*2\r\n$4\r\nauth\r\n$",
|
||||||
|
pass.len().to_string().as_bytes(),
|
||||||
|
b"\r\n",
|
||||||
|
pass.as_bytes(),
|
||||||
|
b"\r\n",
|
||||||
|
]
|
||||||
|
.concat(),
|
||||||
|
)
|
||||||
|
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||||
let mut buffer = vec![0_u8; 5];
|
let mut buffer = vec![0_u8; 5];
|
||||||
conn.read_exact(&mut buffer)
|
conn.read_exact(&mut buffer)
|
||||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||||
|
|
|
@ -31,7 +31,7 @@ impl Manager {
|
||||||
/// Create a new `Manager`, with its own Redis connections (but, as yet, no
|
/// Create a new `Manager`, with its own Redis connections (but, as yet, no
|
||||||
/// active subscriptions).
|
/// active subscriptions).
|
||||||
pub fn try_from(
|
pub fn try_from(
|
||||||
redis_cfg: config::Redis,
|
redis_cfg: &config::Redis,
|
||||||
tx: watch::Sender<(Timeline, Event)>,
|
tx: watch::Sender<(Timeline, Event)>,
|
||||||
rx: mpsc::UnboundedReceiver<Timeline>,
|
rx: mpsc::UnboundedReceiver<Timeline>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
|
@ -99,11 +99,10 @@ impl Manager {
|
||||||
self.tx.broadcast((Timeline::empty(), Event::Ping))?
|
self.tx.broadcast((Timeline::empty(), Event::Ping))?
|
||||||
} else {
|
} else {
|
||||||
match self.redis_connection.poll_redis() {
|
match self.redis_connection.poll_redis() {
|
||||||
Ok(Async::NotReady) => (),
|
Ok(Async::NotReady) | Ok(Async::Ready(None)) => (), // None = cmd or msg for other namespace
|
||||||
Ok(Async::Ready(Some((timeline, event)))) => {
|
Ok(Async::Ready(Some((timeline, event)))) => {
|
||||||
self.tx.broadcast((timeline, event))?
|
self.tx.broadcast((timeline, event))?
|
||||||
}
|
}
|
||||||
Ok(Async::Ready(None)) => (), // subscription cmd or msg for other namespace
|
|
||||||
Err(err) => log::error!("{}", err), // drop msg, log err, and proceed
|
Err(err) => log::error!("{}", err), // drop msg, log err, and proceed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,10 +82,7 @@ fn utf8_to_redis_data<'a>(s: &'a str) -> Result<(RedisData, &'a str), RedisParse
|
||||||
":" => parse_redis_int(s),
|
":" => parse_redis_int(s),
|
||||||
"$" => parse_redis_bulk_string(s),
|
"$" => parse_redis_bulk_string(s),
|
||||||
"*" => parse_redis_array(s),
|
"*" => parse_redis_array(s),
|
||||||
e => Err(InvalidLineStart(format!(
|
e => Err(InvalidLineStart(e.to_string())),
|
||||||
"Encountered invalid initial character `{}` in line `{}`",
|
|
||||||
e, s
|
|
||||||
))),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,18 +6,11 @@ use log;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::sync::{mpsc, watch};
|
use tokio::sync::{mpsc, watch};
|
||||||
use warp::reply::Reply;
|
use warp::reply::Reply;
|
||||||
use warp::sse::{ServerSentEvent, Sse as WarpSse};
|
use warp::sse::Sse as WarpSse;
|
||||||
|
|
||||||
pub struct Sse;
|
pub struct Sse;
|
||||||
|
|
||||||
impl Sse {
|
impl Sse {
|
||||||
fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
|
|
||||||
Some((
|
|
||||||
warp::sse::event(event.event_name()),
|
|
||||||
warp::sse::data(event.payload().unwrap_or_else(String::new)),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_events(
|
pub fn send_events(
|
||||||
sse: WarpSse,
|
sse: WarpSse,
|
||||||
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||||
|
@ -40,21 +33,20 @@ impl Sse {
|
||||||
TypeSafe(Update { payload, queued_at }) => match timeline {
|
TypeSafe(Update { payload, queued_at }) => match timeline {
|
||||||
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
|
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
|
||||||
_ if payload.involves_any(&blocks) => None,
|
_ if payload.involves_any(&blocks) => None,
|
||||||
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
|
_ => Event::TypeSafe(CheckedEvent::Update { payload, queued_at })
|
||||||
payload,
|
.to_warp_reply(),
|
||||||
queued_at,
|
|
||||||
})),
|
|
||||||
},
|
},
|
||||||
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
|
TypeSafe(non_update) => Event::TypeSafe(non_update).to_warp_reply(),
|
||||||
Dynamic(dyn_event) => {
|
Dynamic(dyn_event) => {
|
||||||
if let EventKind::Update(s) = dyn_event.kind {
|
if let EventKind::Update(s) = dyn_event.kind {
|
||||||
match timeline {
|
match timeline {
|
||||||
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
|
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
|
||||||
_ if s.involves_any(&blocks) => None,
|
_ if s.involves_any(&blocks) => None,
|
||||||
_ => Self::reply_with(Dynamic(DynEvent {
|
_ => Dynamic(DynEvent {
|
||||||
kind: EventKind::Update(s),
|
kind: EventKind::Update(s),
|
||||||
..dyn_event
|
..dyn_event
|
||||||
})),
|
})
|
||||||
|
.to_warp_reply(),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|
|
@ -39,9 +39,11 @@ impl Ws {
|
||||||
.map_err(|_| -> warp::Error { unreachable!() })
|
.map_err(|_| -> warp::Error { unreachable!() })
|
||||||
.forward(transmit_to_ws)
|
.forward(transmit_to_ws)
|
||||||
.map(|_r| ())
|
.map(|_r| ())
|
||||||
.map_err(|e| match e.to_string().as_ref() {
|
.map_err(|e| {
|
||||||
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
|
match e.to_string().as_ref() {
|
||||||
_ => log::warn!("WebSocket send error: {}", e),
|
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
|
||||||
|
_ => log::warn!("WebSocket send error: {}", e),
|
||||||
|
}
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -50,7 +52,7 @@ impl Ws {
|
||||||
|
|
||||||
incoming_events.for_each(move |(tl, event)| {
|
incoming_events.for_each(move |(tl, event)| {
|
||||||
if matches!(event, Event::Ping) {
|
if matches!(event, Event::Ping) {
|
||||||
self.send_ping()
|
self.send_msg(&event)
|
||||||
} else if target_timeline == tl {
|
} else if target_timeline == tl {
|
||||||
use crate::event::{CheckedEvent::Update, Event::*, EventKind};
|
use crate::event::{CheckedEvent::Update, Event::*, EventKind};
|
||||||
use crate::request::Stream::Public;
|
use crate::request::Stream::Public;
|
||||||
|
@ -61,18 +63,18 @@ impl Ws {
|
||||||
TypeSafe(Update { payload, queued_at }) => match tl {
|
TypeSafe(Update { payload, queued_at }) => match tl {
|
||||||
Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()),
|
Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()),
|
||||||
_ if payload.involves_any(&blocks) => Ok(()),
|
_ if payload.involves_any(&blocks) => Ok(()),
|
||||||
_ => self.send_msg(TypeSafe(Update { payload, queued_at })),
|
_ => self.send_msg(&TypeSafe(Update { payload, queued_at })),
|
||||||
},
|
},
|
||||||
TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)),
|
TypeSafe(non_update) => self.send_msg(&TypeSafe(non_update)),
|
||||||
Dynamic(dyn_event) => {
|
Dynamic(dyn_event) => {
|
||||||
if let EventKind::Update(s) = dyn_event.kind.clone() {
|
if let EventKind::Update(s) = dyn_event.kind.clone() {
|
||||||
match tl {
|
match tl {
|
||||||
Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()),
|
Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()),
|
||||||
_ if s.involves_any(&blocks) => Ok(()),
|
_ if s.involves_any(&blocks) => Ok(()),
|
||||||
_ => self.send_msg(Dynamic(dyn_event)),
|
_ => self.send_msg(&Dynamic(dyn_event)),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
self.send_msg(Dynamic(dyn_event))
|
self.send_msg(&Dynamic(dyn_event))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ping => unreachable!(), // handled pings above
|
Ping => unreachable!(), // handled pings above
|
||||||
|
@ -83,24 +85,14 @@ impl Ws {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn send_ping(&mut self) -> Result<(), ()> {
|
fn send_msg(&mut self, event: &Event) -> Result<(), ()> {
|
||||||
self.send_txt("{}")
|
let txt = &event.to_json_string();
|
||||||
}
|
|
||||||
|
|
||||||
fn send_msg(&mut self, event: Event) -> Result<(), ()> {
|
|
||||||
self.send_txt(&event.to_json_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
|
|
||||||
let tl = self.subscription.timeline;
|
let tl = self.subscription.timeline;
|
||||||
match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) {
|
let mut channel = self.ws_tx.clone().ok_or(())?;
|
||||||
Ok(_) => Ok(()),
|
channel.try_send(Message::text(txt)).map_err(|_| {
|
||||||
Err(_) => {
|
self.unsubscribe_tx
|
||||||
self.unsubscribe_tx
|
.try_send(tl)
|
||||||
.try_send(tl)
|
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
|
||||||
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
|
})
|
||||||
Err(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue