mirror of https://github.com/mastodon/flodgatt
Code reorganization (#130)
* Reorganize files * Refactor main() * Code reorganization [WIP] * Reorganize code [WIP] * Refacto RedisConn [WIP] * Complete code reorganization
This commit is contained in:
parent
0eec8f6f7b
commit
45f9d4b9fb
|
@ -57,28 +57,6 @@ name = "autocfg"
|
|||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"backtrace-sys 0.1.28 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rustc-demangle 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"winapi 0.3.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace-sys"
|
||||
version = "0.1.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"cc 1.0.50 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.10.1"
|
||||
|
@ -386,13 +364,8 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "dotenv"
|
||||
version = "0.14.0"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"regex 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dtoa"
|
||||
|
@ -416,26 +389,6 @@ dependencies = [
|
|||
"termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "failure"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"backtrace 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "failure_derive"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"syn 0.15.34 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"synstructure 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fake-simd"
|
||||
version = "0.1.2"
|
||||
|
@ -453,10 +406,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
|
||||
[[package]]
|
||||
name = "flodgatt"
|
||||
version = "0.8.2"
|
||||
version = "0.8.3"
|
||||
dependencies = [
|
||||
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"hashbrown 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
|
@ -1607,11 +1560,6 @@ name = "rent_to_own"
|
|||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.2.3"
|
||||
|
@ -1830,17 +1778,6 @@ dependencies = [
|
|||
"unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "synstructure"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quote 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"syn 0.15.34 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.1.0"
|
||||
|
@ -2399,8 +2336,6 @@ dependencies = [
|
|||
"checksum atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "9a7d5b8723950951411ee34d271d99dddcc2035a16ab25310ea2c8cfd4369652"
|
||||
"checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2"
|
||||
"checksum autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f8aac770f1885fd7e387acedd76065302551364496e46b3dd00860b2f8359b9d"
|
||||
"checksum backtrace 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)" = "f106c02a3604afcdc0df5d36cc47b44b55917dbaf3d808f71c163a0ddba64637"
|
||||
"checksum backtrace-sys 0.1.28 (registry+https://github.com/rust-lang/crates.io-index)" = "797c830ac25ccc92a7f8a7b9862bde440715531514594a6154e3d4a54dd769b6"
|
||||
"checksum base64 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)" = "0b25d992356d2eb0ed82172f5248873db5560c4721f564b13cb5193bda5e668e"
|
||||
"checksum base64 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b41b7ea54a0c9d92199de89e20e58d49f02f8e699814ef3fdf266f6f748d15c7"
|
||||
"checksum bitflags 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "228047a76f468627ca71776ecdebd732a3423081fcf5125585bcd7c49886ce12"
|
||||
|
@ -2435,12 +2370,10 @@ dependencies = [
|
|||
"checksum darling_macro 0.8.6 (registry+https://github.com/rust-lang/crates.io-index)" = "244e8987bd4e174385240cde20a3657f607fb0797563c28255c353b5819a07b1"
|
||||
"checksum derive_state_machine_future 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1220ad071cb8996454c20adf547a34ba3ac793759dab793d9dc04996a373ac83"
|
||||
"checksum digest 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05f47366984d3ad862010e22c7ce81a7dbcaebbdfb37241a620f8b6596ee135c"
|
||||
"checksum dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7bdb5b956a911106b6b479cdc6bc1364d359a32299f17b49994f5327132e18d9"
|
||||
"checksum dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)" = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
|
||||
"checksum dtoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "ea57b42383d091c85abcc2706240b94ab2a8fa1fc81c10ff23c4de06e2a90b5e"
|
||||
"checksum either 1.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "5527cfe0d098f36e3f8839852688e63c8fff1c90b2b405aef730615f9a7bcf7b"
|
||||
"checksum env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)" = "aafcde04e90a5226a6443b7aabdb016ba2f8307c847d524724bd9b346dd1a2d3"
|
||||
"checksum failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "795bd83d3abeb9220f257e597aa0080a508b27533824adf336529648f6abf7e2"
|
||||
"checksum failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "ea1063915fd7ef4309e222a5a07cf9c319fb9c7836b1f89b85458672dbb127e1"
|
||||
"checksum fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed"
|
||||
"checksum fallible-iterator 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||
"checksum fixedbitset 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "86d4de0081402f5e88cdac65c8dcdcc73118c1a7a465e2a05f0da05843a8ea33"
|
||||
|
@ -2573,7 +2506,6 @@ dependencies = [
|
|||
"checksum regex-syntax 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e734e891f5b408a29efbf8309e656876276f49ab6a6ac208600b4419bd893d90"
|
||||
"checksum remove_dir_all 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "4a83fa3702a688b9359eccba92d153ac33fd2e8462f9e0e3fdf155239ea7792e"
|
||||
"checksum rent_to_own 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05a51ad2b1c5c710fa89e6b1631068dab84ed687bc6a5fe061ad65da3d0c25b2"
|
||||
"checksum rustc-demangle 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)" = "ccc78bfd5acd7bf3e89cffcf899e5cb1a52d6fafa8dec2739ad70c9577a57288"
|
||||
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
||||
"checksum ryu 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c92464b447c0ee8c4fb3824ecc8383b81717b9f1e74ba2e72540aef7b9f82997"
|
||||
"checksum safemem 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "d2b08423011dae9a5ca23f07cf57dac3857f5c885d352b76f6d95f4aea9434d0"
|
||||
|
@ -2604,7 +2536,6 @@ dependencies = [
|
|||
"checksum subtle 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2d67a5a62ba6e01cb2192ff309324cb4875d0c451d55fe2319433abe7a05a8ee"
|
||||
"checksum syn 0.15.34 (registry+https://github.com/rust-lang/crates.io-index)" = "a1393e4a97a19c01e900df2aec855a29f71cf02c402e2f443b8d2747c25c5dbe"
|
||||
"checksum syn 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "66850e97125af79138385e9b88339cbcd037e3f28ceab8c5ad98e64f0f1f80bf"
|
||||
"checksum synstructure 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)" = "73687139bf99285483c96ac0add482c3776528beac1d97d444f6e91f203a2015"
|
||||
"checksum tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9"
|
||||
"checksum termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96d6098003bde162e4277c70665bd87c326f5a0c3f3fbfb285787fa482d54e6e"
|
||||
"checksum termion 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a8fb22f7cde82c8220e5aeacb3258ed7ce996142c77cba193f203515e26c330"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "flodgatt"
|
||||
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
|
||||
version = "0.8.2"
|
||||
version = "0.8.3"
|
||||
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
|
||||
edition = "2018"
|
||||
|
||||
|
@ -15,7 +15,7 @@ serde_json = "1.0.50"
|
|||
serde_derive = "1.0.90"
|
||||
pretty_env_logger = "0.3.0"
|
||||
postgres = "0.17.0"
|
||||
dotenv = "0.14.0"
|
||||
dotenv = "0.15.0"
|
||||
postgres-openssl = { git = "https://github.com/sfackler/rust-postgres.git"}
|
||||
url = "2.1.0"
|
||||
strum = "0.16.0"
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
pub use {deployment_cfg::Deployment, postgres_cfg::Postgres, redis_cfg::Redis};
|
||||
|
||||
use self::environmental_variables::EnvVar;
|
||||
use super::err;
|
||||
use hashbrown::HashMap;
|
||||
use std::env;
|
||||
|
||||
mod deployment_cfg;
|
||||
mod deployment_cfg_types;
|
||||
mod environmental_variables;
|
||||
mod postgres_cfg;
|
||||
mod postgres_cfg_types;
|
||||
mod redis_cfg;
|
||||
mod redis_cfg_types;
|
||||
|
||||
pub fn merge_dotenv() -> Result<(), err::FatalErr> {
|
||||
// TODO -- should this allow the user to run in a dir without a `.env` file?
|
||||
dotenv::from_filename(match env::var("ENV").ok().as_deref() {
|
||||
Some("production") => ".env.production",
|
||||
Some("development") | None => ".env",
|
||||
Some(_unsupported) => Err(err::FatalErr::Unknown)?, // TODO make more specific
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn from_env<'a>(env_vars: HashMap<String, String>) -> (Postgres, Redis, Deployment<'a>) {
|
||||
let env_vars = EnvVar::new(env_vars);
|
||||
log::info!("Environmental variables Flodgatt received: {}", &env_vars);
|
||||
(
|
||||
Postgres::from_env(env_vars.clone()),
|
||||
Redis::from_env(env_vars.clone()),
|
||||
Deployment::from_env(env_vars.clone()),
|
||||
)
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
use super::{deployment_cfg_types::*, EnvVar};
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct DeploymentConfig<'a> {
|
||||
pub struct Deployment<'a> {
|
||||
pub env: Env,
|
||||
pub log_level: LogLevel,
|
||||
pub address: FlodgattAddr,
|
||||
|
@ -13,7 +13,7 @@ pub struct DeploymentConfig<'a> {
|
|||
pub whitelist_mode: WhitelistMode,
|
||||
}
|
||||
|
||||
impl DeploymentConfig<'_> {
|
||||
impl Deployment<'_> {
|
||||
pub fn from_env(env: EnvVar) -> Self {
|
||||
let mut cfg = Self {
|
||||
env: Env::default().maybe_update(env.get("NODE_ENV")),
|
||||
|
|
|
@ -92,7 +92,7 @@ impl fmt::Debug for Cors<'_> {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
||||
#[derive(EnumString, EnumVariantNames, Debug, Clone)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum LogLevelInner {
|
||||
Trace,
|
||||
|
@ -102,7 +102,7 @@ pub enum LogLevelInner {
|
|||
Error,
|
||||
}
|
||||
|
||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
||||
#[derive(EnumString, EnumVariantNames, Debug, Clone)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum EnvInner {
|
||||
Production,
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use hashbrown::HashMap;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EnvVar(pub HashMap<String, String>);
|
||||
impl std::ops::Deref for EnvVar {
|
||||
type Target = HashMap<String, String>;
|
||||
|
@ -39,7 +40,7 @@ impl EnvVar {
|
|||
impl fmt::Display for EnvVar {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let mut result = String::new();
|
||||
for env_var in [
|
||||
for env_var in &[
|
||||
"NODE_ENV",
|
||||
"RUST_LOG",
|
||||
"BIND",
|
||||
|
@ -62,9 +63,7 @@ impl fmt::Display for EnvVar {
|
|||
"REDIS_USER",
|
||||
"REDIS_DB",
|
||||
"REDIS_FREQ",
|
||||
]
|
||||
.iter()
|
||||
{
|
||||
] {
|
||||
if let Some(value) = self.get(&(*env_var).to_string()) {
|
||||
result = format!("{}\n {}: {}", result, env_var, value)
|
||||
}
|
||||
|
@ -96,6 +95,7 @@ macro_rules! from_env_var {
|
|||
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
|
||||
let from_str = |$arg:ident| $body:expr;
|
||||
) => {
|
||||
#[derive(Clone)]
|
||||
pub struct $name(pub $type);
|
||||
impl std::fmt::Debug for $name {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
|
@ -125,14 +125,6 @@ macro_rules! from_env_var {
|
|||
})),
|
||||
None => self,
|
||||
}
|
||||
|
||||
// if let Some(value) = var {
|
||||
// Self(Self::inner_from_str(value).unwrap_or_else(|| {
|
||||
// crate::err::env_var_fatal($env_var, value, $allowed_values)
|
||||
// }))
|
||||
// } else {
|
||||
// self
|
||||
// }
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
mod deployment_cfg;
|
||||
mod deployment_cfg_types;
|
||||
mod postgres_cfg;
|
||||
mod postgres_cfg_types;
|
||||
mod redis_cfg;
|
||||
mod redis_cfg_types;
|
||||
mod environmental_variables;
|
||||
|
||||
pub use {deployment_cfg::DeploymentConfig, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig, environmental_variables::EnvVar};
|
||||
|
|
@ -2,8 +2,8 @@ use super::{postgres_cfg_types::*, EnvVar};
|
|||
use url::Url;
|
||||
use urlencoding;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PostgresConfig {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Postgres {
|
||||
pub user: PgUser,
|
||||
pub host: PgHost,
|
||||
pub password: PgPass,
|
||||
|
@ -46,7 +46,7 @@ impl EnvVar {
|
|||
}
|
||||
}
|
||||
|
||||
impl PostgresConfig {
|
||||
impl Postgres {
|
||||
/// Configure Postgres and return a connection
|
||||
|
||||
pub fn from_env(env: EnvVar) -> Self {
|
||||
|
|
|
@ -49,7 +49,7 @@ from_env_var!(
|
|||
let from_str = |s| PgSslInner::from_str(s).ok();
|
||||
);
|
||||
|
||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
||||
#[derive(EnumString, EnumVariantNames, Debug, Clone)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum PgSslInner {
|
||||
Prefer,
|
||||
|
|
|
@ -3,7 +3,7 @@ use crate::config::EnvVar;
|
|||
use url::Url;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct RedisConfig {
|
||||
pub struct Redis {
|
||||
pub user: RedisUser,
|
||||
pub password: RedisPass,
|
||||
pub port: RedisPort,
|
||||
|
@ -40,7 +40,7 @@ impl EnvVar {
|
|||
}
|
||||
}
|
||||
|
||||
impl RedisConfig {
|
||||
impl Redis {
|
||||
const USER_SET_WARNING: &'static str =
|
||||
"Redis user specified, but Redis did not ask for a username. Ignoring it.";
|
||||
const DB_SET_WARNING: &'static str = r"Redis database specified, but PubSub connections do not use databases.
|
||||
|
@ -52,7 +52,7 @@ For similar functionality, you may wish to set a REDIS_NAMESPACE";
|
|||
None => env,
|
||||
};
|
||||
|
||||
let cfg = RedisConfig {
|
||||
let cfg = Redis {
|
||||
user: RedisUser::default().maybe_update(env.get("REDIS_USER")),
|
||||
password: RedisPass::default().maybe_update(env.get("REDIS_PASSWORD")),
|
||||
port: RedisPort::default().maybe_update(env.get("REDIS_PORT")),
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
use crate::response::ManagerErr;
|
||||
use std::fmt;
|
||||
|
||||
pub enum FatalErr {
|
||||
Unknown,
|
||||
ReceiverErr(ManagerErr),
|
||||
DotEnv(dotenv::Error),
|
||||
Logger(log::SetLoggerError),
|
||||
}
|
||||
|
||||
impl FatalErr {
|
||||
pub fn exit(msg: impl fmt::Display) {
|
||||
eprintln!("{}", msg);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for FatalErr {}
|
||||
impl fmt::Debug for FatalErr {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FatalErr {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
|
||||
use FatalErr::*;
|
||||
write!(
|
||||
f,
|
||||
"{}",
|
||||
match self {
|
||||
Unknown => "Flodgatt encountered an unknown, unrecoverable error".into(),
|
||||
ReceiverErr(e) => format!("{}", e),
|
||||
Logger(e) => format!("{}", e),
|
||||
DotEnv(e) => format!("Could not load specified environmental file: {}", e),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<dotenv::Error> for FatalErr {
|
||||
fn from(e: dotenv::Error) -> Self {
|
||||
Self::DotEnv(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ManagerErr> for FatalErr {
|
||||
fn from(e: ManagerErr) -> Self {
|
||||
Self::ReceiverErr(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<log::SetLoggerError> for FatalErr {
|
||||
fn from(e: log::SetLoggerError) -> Self {
|
||||
Self::Logger(e)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO delete vvvv when postgres_cfg.rs has better error handling
|
||||
pub fn die_with_msg(msg: impl fmt::Display) -> ! {
|
||||
eprintln!("FATAL ERROR: {}", msg);
|
||||
std::process::exit(1);
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
mod timeline;
|
||||
|
||||
pub use timeline::TimelineErr;
|
||||
|
||||
use crate::redis_to_client_stream::ReceiverErr;
|
||||
use std::fmt;
|
||||
|
||||
pub enum FatalErr {
|
||||
Err,
|
||||
ReceiverErr(ReceiverErr),
|
||||
}
|
||||
|
||||
impl FatalErr {
|
||||
pub fn exit(msg: impl fmt::Display) {
|
||||
eprintln!("{}", msg);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for FatalErr {}
|
||||
impl fmt::Debug for FatalErr {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FatalErr {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
|
||||
write!(f, "Error message")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ReceiverErr> for FatalErr {
|
||||
fn from(e: ReceiverErr) -> Self {
|
||||
Self::ReceiverErr(e)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO delete vvvv when postgres_cfg.rs has better error handling
|
||||
pub fn die_with_msg(msg: impl fmt::Display) -> ! {
|
||||
eprintln!("FATAL ERROR: {}", msg);
|
||||
std::process::exit(1);
|
||||
}
|
|
@ -12,7 +12,7 @@ mod tag;
|
|||
mod visibility;
|
||||
|
||||
pub use announcement::Announcement;
|
||||
pub(in crate::messages::event) use announcement_reaction::AnnouncementReaction;
|
||||
pub(in crate::event) use announcement_reaction::AnnouncementReaction;
|
||||
pub use conversation::Conversation;
|
||||
pub use id::Id;
|
||||
pub use notification::Notification;
|
|
@ -8,7 +8,7 @@ use super::{
|
|||
};
|
||||
use {application::Application, attachment::Attachment, card::Card, poll::Poll};
|
||||
|
||||
use crate::parse_client_request::Blocks;
|
||||
use crate::request::Blocks;
|
||||
|
||||
use hashbrown::HashSet;
|
||||
use serde::{Deserialize, Serialize};
|
|
@ -1,5 +1,5 @@
|
|||
use super::{EventErr, Id};
|
||||
use crate::parse_client_request::Blocks;
|
||||
use crate::request::Blocks;
|
||||
|
||||
use std::convert::TryFrom;
|
||||
|
|
@ -35,10 +35,11 @@
|
|||
//! polls the `Receiver` and the frequency with which the `Receiver` polls Redis.
|
||||
//!
|
||||
|
||||
//#![warn(clippy::pedantic)]
|
||||
#![allow(clippy::try_err, clippy::match_bool)]
|
||||
|
||||
pub mod config;
|
||||
pub mod err;
|
||||
pub mod messages;
|
||||
pub mod parse_client_request;
|
||||
pub mod redis_to_client_stream;
|
||||
pub mod event;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
|
|
195
src/main.rs
195
src/main.rs
|
@ -1,158 +1,107 @@
|
|||
use flodgatt::{
|
||||
config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig},
|
||||
err::FatalErr,
|
||||
messages::Event,
|
||||
parse_client_request::{PgPool, Subscription, Timeline},
|
||||
redis_to_client_stream::{Receiver, SseStream, WsStream},
|
||||
};
|
||||
use std::{env, fs, net::SocketAddr, os::unix::fs::PermissionsExt};
|
||||
use tokio::{
|
||||
net::UnixListener,
|
||||
sync::{mpsc, watch},
|
||||
};
|
||||
use warp::{http::StatusCode, path, ws::Ws2, Filter, Rejection};
|
||||
use flodgatt::config;
|
||||
use flodgatt::err::FatalErr;
|
||||
use flodgatt::event::Event;
|
||||
use flodgatt::request::{Handler, Subscription, Timeline};
|
||||
use flodgatt::response::redis;
|
||||
use flodgatt::response::stream;
|
||||
|
||||
use futures::{future::lazy, stream::Stream as _};
|
||||
use std::fs;
|
||||
use std::net::SocketAddr;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::time::Instant;
|
||||
use tokio::net::UnixListener;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tokio::timer::Interval;
|
||||
use warp::ws::Ws2;
|
||||
use warp::Filter;
|
||||
|
||||
fn main() -> Result<(), FatalErr> {
|
||||
dotenv::from_filename(match env::var("ENV").ok().as_deref() {
|
||||
Some("production") => ".env.production",
|
||||
Some("development") | None => ".env",
|
||||
Some(unsupported) => EnvVar::err("ENV", unsupported, "`production` or `development`"),
|
||||
})
|
||||
.ok();
|
||||
let env_vars = EnvVar::new(dotenv::vars().collect());
|
||||
pretty_env_logger::init();
|
||||
log::info!("Environmental variables Flodgatt received: {}", &env_vars);
|
||||
config::merge_dotenv()?;
|
||||
pretty_env_logger::try_init()?;
|
||||
let (postgres_cfg, redis_cfg, cfg) = config::from_env(dotenv::vars().collect());
|
||||
|
||||
let postgres_cfg = PostgresConfig::from_env(env_vars.clone());
|
||||
let redis_cfg = RedisConfig::from_env(env_vars.clone());
|
||||
let cfg = DeploymentConfig::from_env(env_vars);
|
||||
|
||||
let pg_pool = PgPool::new(postgres_cfg);
|
||||
// Create channels to communicate between threads
|
||||
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
|
||||
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
|
||||
|
||||
let request = Handler::new(postgres_cfg, *cfg.whitelist_mode);
|
||||
let poll_freq = *redis_cfg.polling_interval;
|
||||
let receiver = Receiver::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
|
||||
log::info!("Streaming server initialized and ready to accept connections");
|
||||
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
|
||||
|
||||
// Server Sent Events
|
||||
let sse_receiver = receiver.clone();
|
||||
let sse_manager = shared_manager.clone();
|
||||
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
|
||||
let whitelist_mode = *cfg.whitelist_mode;
|
||||
let sse_routes = Subscription::from_sse_query(pg_pool.clone(), whitelist_mode)
|
||||
|
||||
let sse = request
|
||||
.sse_subscription()
|
||||
.and(warp::sse())
|
||||
.map(
|
||||
move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| {
|
||||
log::info!("Incoming SSE request for {:?}", subscription.timeline);
|
||||
{
|
||||
let mut receiver = sse_receiver.lock().unwrap_or_else(Receiver::recover);
|
||||
receiver.subscribe(&subscription).unwrap_or_else(|e| {
|
||||
log::error!("Could not subscribe to the Redis channel: {}", e)
|
||||
});
|
||||
}
|
||||
let cmd_tx = sse_cmd_tx.clone();
|
||||
let sse_rx = sse_rx.clone();
|
||||
// send the updates through the SSE connection
|
||||
SseStream::send_events(sse_connection_to_client, cmd_tx, subscription, sse_rx)
|
||||
},
|
||||
)
|
||||
.map(move |subscription: Subscription, sse: warp::sse::Sse| {
|
||||
log::info!("Incoming SSE request for {:?}", subscription.timeline);
|
||||
let mut manager = sse_manager.lock().unwrap_or_else(redis::Manager::recover);
|
||||
manager.subscribe(&subscription);
|
||||
|
||||
stream::Sse::send_events(sse, sse_cmd_tx.clone(), subscription, sse_rx.clone())
|
||||
})
|
||||
.with(warp::reply::with::header("Connection", "keep-alive"));
|
||||
|
||||
// WebSocket
|
||||
let ws_receiver = receiver.clone();
|
||||
let whitelist_mode = *cfg.whitelist_mode;
|
||||
let ws_routes = Subscription::from_ws_request(pg_pool, whitelist_mode)
|
||||
let ws_manager = shared_manager.clone();
|
||||
let ws = request
|
||||
.ws_subscription()
|
||||
.and(warp::ws::ws2())
|
||||
.map(move |subscription: Subscription, ws: Ws2| {
|
||||
log::info!("Incoming websocket request for {:?}", subscription.timeline);
|
||||
{
|
||||
let mut receiver = ws_receiver.lock().unwrap_or_else(Receiver::recover);
|
||||
let mut manager = ws_manager.lock().unwrap_or_else(redis::Manager::recover);
|
||||
manager.subscribe(&subscription);
|
||||
let token = subscription.access_token.clone().unwrap_or_default(); // token sent for security
|
||||
let ws_stream = stream::Ws::new(cmd_tx.clone(), event_rx.clone(), subscription);
|
||||
|
||||
receiver.subscribe(&subscription).unwrap_or_else(|e| {
|
||||
log::error!("Could not subscribe to the Redis channel: {}", e)
|
||||
});
|
||||
}
|
||||
let cmd_tx = cmd_tx.clone();
|
||||
let ws_rx = event_rx.clone();
|
||||
let token = subscription
|
||||
.clone()
|
||||
.access_token
|
||||
.unwrap_or_else(String::new);
|
||||
|
||||
// send the updates through the WS connection (along with the access_token, for security)
|
||||
(
|
||||
ws.on_upgrade(move |ws| WsStream::new(ws, cmd_tx, subscription).send_events(ws_rx)),
|
||||
token,
|
||||
)
|
||||
(ws.on_upgrade(move |ws| ws_stream.send_to(ws)), token)
|
||||
})
|
||||
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
|
||||
|
||||
#[cfg(feature = "stub_status")]
|
||||
#[rustfmt::skip]
|
||||
let status = {
|
||||
let (r1, r3) = (shared_manager.clone(), shared_manager.clone());
|
||||
request.health().map(|| "OK")
|
||||
.or(request.status()
|
||||
.map(move || r1.lock().unwrap_or_else(redis::Manager::recover).count()))
|
||||
.or(request.status_per_timeline()
|
||||
.map(move || r3.lock().unwrap_or_else(redis::Manager::recover).list()))
|
||||
};
|
||||
#[cfg(not(feature = "stub_status"))]
|
||||
let status = request.health().map(|| "OK");
|
||||
|
||||
let cors = warp::cors()
|
||||
.allow_any_origin()
|
||||
.allow_methods(cfg.cors.allowed_methods)
|
||||
.allow_headers(cfg.cors.allowed_headers);
|
||||
|
||||
#[cfg(feature = "stub_status")]
|
||||
let status_endpoints = {
|
||||
let (r1, r3) = (receiver.clone(), receiver.clone());
|
||||
warp::path!("api" / "v1" / "streaming" / "health")
|
||||
.map(|| "OK")
|
||||
.or(warp::path!("api" / "v1" / "streaming" / "status")
|
||||
.and(warp::path::end())
|
||||
.map(move || r1.lock().unwrap_or_else(Receiver::recover).count()))
|
||||
.or(
|
||||
warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline")
|
||||
.map(move || r3.lock().unwrap_or_else(Receiver::recover).list()),
|
||||
)
|
||||
let streaming_server = move || {
|
||||
let manager = shared_manager.clone();
|
||||
let stream = Interval::new(Instant::now(), poll_freq)
|
||||
.map_err(|e| log::error!("{}", e))
|
||||
.for_each(move |_| {
|
||||
let mut manager = manager.lock().unwrap_or_else(redis::Manager::recover);
|
||||
manager.poll_broadcast().unwrap_or_else(FatalErr::exit);
|
||||
Ok(())
|
||||
});
|
||||
warp::spawn(lazy(move || stream));
|
||||
warp::serve(ws.or(sse).with(cors).or(status).recover(Handler::err))
|
||||
};
|
||||
#[cfg(not(feature = "stub_status"))]
|
||||
let status_endpoints = warp::path!("api" / "v1" / "streaming" / "health").map(|| "OK");
|
||||
|
||||
if let Some(socket) = &*cfg.unix_socket {
|
||||
log::info!("Using Unix socket {}", socket);
|
||||
fs::remove_file(socket).unwrap_or_default();
|
||||
let incoming = UnixListener::bind(socket).unwrap().incoming();
|
||||
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).unwrap();
|
||||
let incoming = UnixListener::bind(socket).expect("TODO").incoming();
|
||||
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).expect("TODO");
|
||||
|
||||
warp::serve(
|
||||
ws_routes
|
||||
.or(sse_routes)
|
||||
.with(cors)
|
||||
.or(status_endpoints)
|
||||
.recover(|r: Rejection| {
|
||||
let json_err = match r.cause() {
|
||||
Some(text)
|
||||
if text.to_string() == "Missing request header 'authorization'" =>
|
||||
{
|
||||
warp::reply::json(&"Error: Missing access token".to_string())
|
||||
}
|
||||
Some(text) => warp::reply::json(&text.to_string()),
|
||||
None => warp::reply::json(&"Error: Nonexistant endpoint".to_string()),
|
||||
};
|
||||
Ok(warp::reply::with_status(json_err, StatusCode::UNAUTHORIZED))
|
||||
}),
|
||||
)
|
||||
.run_incoming(incoming);
|
||||
tokio::run(lazy(|| streaming_server().serve_incoming(incoming)));
|
||||
} else {
|
||||
use futures::{future::lazy, stream::Stream as _Stream};
|
||||
use std::time::Instant;
|
||||
|
||||
let server_addr = SocketAddr::new(*cfg.address, *cfg.port);
|
||||
|
||||
tokio::run(lazy(move || {
|
||||
let receiver = receiver.clone();
|
||||
|
||||
warp::spawn(lazy(move || {
|
||||
tokio::timer::Interval::new(Instant::now(), poll_freq)
|
||||
.map_err(|e| log::error!("{}", e))
|
||||
.for_each(move |_| {
|
||||
let mut receiver = receiver.lock().unwrap_or_else(Receiver::recover);
|
||||
receiver.poll_broadcast().unwrap_or_else(FatalErr::exit);
|
||||
Ok(())
|
||||
})
|
||||
}));
|
||||
|
||||
warp::serve(ws_routes.or(sse_routes).with(cors).or(status_endpoints)).bind(server_addr)
|
||||
}));
|
||||
};
|
||||
tokio::run(lazy(move || streaming_server().bind(server_addr)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
mod event;
|
||||
|
||||
pub use event::{CheckedEvent, DynEvent, Event, EventErr, EventKind, Id};
|
File diff suppressed because one or more lines are too long
|
@ -1,17 +0,0 @@
|
|||
//! Parse the client request and return a Subscription
|
||||
mod postgres;
|
||||
mod query;
|
||||
|
||||
mod subscription;
|
||||
|
||||
pub use self::postgres::PgPool;
|
||||
// TODO consider whether we can remove `Stream` from public API
|
||||
pub use subscription::{Blocks, Stream, Subscription, Timeline};
|
||||
|
||||
//#[cfg(test)]
|
||||
pub use subscription::{Content, Reach};
|
||||
|
||||
#[cfg(test)]
|
||||
mod sse_test;
|
||||
#[cfg(test)]
|
||||
mod ws_test;
|
|
@ -1,188 +0,0 @@
|
|||
//! Postgres queries
|
||||
use crate::{
|
||||
config,
|
||||
messages::Id,
|
||||
parse_client_request::subscription::{Scope, UserData},
|
||||
};
|
||||
use ::postgres;
|
||||
use hashbrown::HashSet;
|
||||
use r2d2_postgres::PostgresConnectionManager;
|
||||
use warp::reject::Rejection;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PgPool(pub r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>);
|
||||
impl PgPool {
|
||||
pub fn new(pg_cfg: config::PostgresConfig) -> Self {
|
||||
let mut cfg = postgres::Config::new();
|
||||
cfg.user(&pg_cfg.user)
|
||||
.host(&*pg_cfg.host.to_string())
|
||||
.port(*pg_cfg.port)
|
||||
.dbname(&pg_cfg.database);
|
||||
if let Some(password) = &*pg_cfg.password {
|
||||
cfg.password(password);
|
||||
};
|
||||
|
||||
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
|
||||
let pool = r2d2::Pool::builder()
|
||||
.max_size(10)
|
||||
.build(manager)
|
||||
.expect("Can connect to local postgres");
|
||||
Self(pool)
|
||||
}
|
||||
|
||||
pub fn select_user(self, token: &str) -> Result<UserData, Rejection> {
|
||||
let mut conn = self.0.get().unwrap();
|
||||
let query_rows = conn
|
||||
.query(
|
||||
"
|
||||
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
|
||||
FROM
|
||||
oauth_access_tokens
|
||||
INNER JOIN users ON
|
||||
oauth_access_tokens.resource_owner_id = users.id
|
||||
WHERE oauth_access_tokens.token = $1
|
||||
AND oauth_access_tokens.revoked_at IS NULL
|
||||
LIMIT 1",
|
||||
&[&token.to_owned()],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
if let Some(result_columns) = query_rows.get(0) {
|
||||
let id = Id(result_columns.get(1));
|
||||
let allowed_langs = result_columns
|
||||
.try_get::<_, Vec<_>>(2)
|
||||
.unwrap_or_else(|_| Vec::new())
|
||||
.into_iter()
|
||||
.collect();
|
||||
let mut scopes: HashSet<Scope> = result_columns
|
||||
.get::<_, String>(3)
|
||||
.split(' ')
|
||||
.filter_map(|scope| match scope {
|
||||
"read" => Some(Scope::Read),
|
||||
"read:statuses" => Some(Scope::Statuses),
|
||||
"read:notifications" => Some(Scope::Notifications),
|
||||
"read:lists" => Some(Scope::Lists),
|
||||
"write" | "follow" => None, // ignore write scopes
|
||||
unexpected => {
|
||||
log::warn!("Ignoring unknown scope `{}`", unexpected);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
// We don't need to separately track read auth - it's just all three others
|
||||
if scopes.remove(&Scope::Read) {
|
||||
scopes.insert(Scope::Statuses);
|
||||
scopes.insert(Scope::Notifications);
|
||||
scopes.insert(Scope::Lists);
|
||||
}
|
||||
|
||||
Ok(UserData {
|
||||
id,
|
||||
allowed_langs,
|
||||
scopes,
|
||||
})
|
||||
} else {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> {
|
||||
let mut conn = self.0.get().unwrap();
|
||||
let rows = &conn
|
||||
.query(
|
||||
"
|
||||
SELECT id
|
||||
FROM tags
|
||||
WHERE name = $1
|
||||
LIMIT 1",
|
||||
&[&tag_name],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
|
||||
rows.get(0)
|
||||
.map(|row| row.get(0))
|
||||
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
|
||||
}
|
||||
|
||||
/// Query Postgres for everyone the user has blocked or muted
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
|
||||
self.0
|
||||
.get()
|
||||
.unwrap()
|
||||
.query(
|
||||
"
|
||||
SELECT target_account_id
|
||||
FROM blocks
|
||||
WHERE account_id = $1
|
||||
UNION SELECT target_account_id
|
||||
FROM mutes
|
||||
WHERE account_id = $1",
|
||||
&[&*user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| Id(row.get(0)))
|
||||
.collect()
|
||||
}
|
||||
/// Query Postgres for everyone who has blocked the user
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
|
||||
self.0
|
||||
.get()
|
||||
.unwrap()
|
||||
.query(
|
||||
"
|
||||
SELECT account_id
|
||||
FROM blocks
|
||||
WHERE target_account_id = $1",
|
||||
&[&*user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| Id(row.get(0)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Query Postgres for all current domain blocks
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
|
||||
self.0
|
||||
.get()
|
||||
.unwrap()
|
||||
.query(
|
||||
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
|
||||
&[&*user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| row.get(0))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Test whether a user owns a list
|
||||
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
|
||||
let mut conn = self.0.get().unwrap();
|
||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||
let rows = &conn
|
||||
.query(
|
||||
"
|
||||
SELECT id, account_id
|
||||
FROM lists
|
||||
WHERE id = $1
|
||||
LIMIT 1",
|
||||
&[&list_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
|
||||
match rows.get(0) {
|
||||
None => false,
|
||||
Some(row) => Id(row.get(1)) == user_id,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,320 +0,0 @@
|
|||
//! `User` struct and related functionality
|
||||
// #[cfg(test)]
|
||||
// mod mock_postgres;
|
||||
// #[cfg(test)]
|
||||
// use mock_postgres as postgres;
|
||||
// #[cfg(not(test))]
|
||||
|
||||
use super::postgres::PgPool;
|
||||
use super::query;
|
||||
use super::query::Query;
|
||||
use crate::err::TimelineErr;
|
||||
|
||||
use crate::messages::Id;
|
||||
|
||||
use hashbrown::HashSet;
|
||||
use lru::LruCache;
|
||||
use warp::{filters::BoxedFilter, path, reject::Rejection, Filter};
|
||||
|
||||
/// Helper macro to match on the first of any of the provided filters
|
||||
macro_rules! any_of {
|
||||
($filter:expr, $($other_filter:expr),*) => {
|
||||
$filter$(.or($other_filter).unify())*.boxed()
|
||||
};
|
||||
}
|
||||
macro_rules! parse_sse_query {
|
||||
(path => $start:tt $(/ $next:tt)*
|
||||
endpoint => $endpoint:expr) => {
|
||||
path!($start $(/ $next)*)
|
||||
.and(query::Auth::to_filter())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.map(
|
||||
|auth: query::Auth,
|
||||
media: query::Media,
|
||||
hashtag: query::Hashtag,
|
||||
list: query::List| {
|
||||
Query {
|
||||
access_token: auth.access_token,
|
||||
stream: $endpoint.to_string(),
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
}
|
||||
},
|
||||
)
|
||||
.boxed()
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct Subscription {
|
||||
pub timeline: Timeline,
|
||||
pub allowed_langs: HashSet<String>,
|
||||
pub blocks: Blocks,
|
||||
pub hashtag_name: Option<String>,
|
||||
pub access_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq)]
|
||||
pub struct Blocks {
|
||||
pub blocked_domains: HashSet<String>,
|
||||
pub blocked_users: HashSet<Id>,
|
||||
pub blocking_users: HashSet<Id>,
|
||||
}
|
||||
|
||||
impl Default for Subscription {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification),
|
||||
allowed_langs: HashSet::new(),
|
||||
blocks: Blocks::default(),
|
||||
hashtag_name: None,
|
||||
access_token: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Subscription {
|
||||
pub fn from_ws_request(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
|
||||
parse_ws_query()
|
||||
.and(query::OptionalAccessToken::from_ws_header())
|
||||
.and_then(Query::update_access_token)
|
||||
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub fn from_sse_query(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
|
||||
any_of!(
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "user" / "notification"
|
||||
endpoint => "user:notification" ),
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "user"
|
||||
endpoint => "user"),
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "public" / "local"
|
||||
endpoint => "public:local"),
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "public"
|
||||
endpoint => "public"),
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "direct"
|
||||
endpoint => "direct"),
|
||||
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
|
||||
endpoint => "hashtag:local"),
|
||||
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag"
|
||||
endpoint => "hashtag"),
|
||||
parse_sse_query!(path => "api" / "v1" / "streaming" / "list"
|
||||
endpoint => "list")
|
||||
)
|
||||
// because SSE requests place their `access_token` in the header instead of in a query
|
||||
// parameter, we need to update our Query if the header has a token
|
||||
.and(query::OptionalAccessToken::from_sse_header())
|
||||
.and_then(Query::update_access_token)
|
||||
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result<Self, Rejection> {
|
||||
let user = match q.access_token.clone() {
|
||||
Some(token) => pool.clone().select_user(&token)?,
|
||||
None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?,
|
||||
None => UserData::public(),
|
||||
};
|
||||
let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?;
|
||||
let hashtag_name = match timeline {
|
||||
Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag),
|
||||
_non_hashtag_timeline => None,
|
||||
};
|
||||
|
||||
Ok(Subscription {
|
||||
timeline,
|
||||
allowed_langs: user.allowed_langs,
|
||||
blocks: Blocks {
|
||||
blocking_users: pool.clone().select_blocking_users(user.id),
|
||||
blocked_users: pool.clone().select_blocked_users(user.id),
|
||||
blocked_domains: pool.select_blocked_domains(user.id),
|
||||
},
|
||||
hashtag_name,
|
||||
access_token: q.access_token,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ws_query() -> BoxedFilter<(Query,)> {
|
||||
path!("api" / "v1" / "streaming")
|
||||
.and(path::end())
|
||||
.and(warp::query())
|
||||
.and(query::Auth::to_filter())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.map(
|
||||
|stream: query::Stream,
|
||||
auth: query::Auth,
|
||||
media: query::Media,
|
||||
hashtag: query::Hashtag,
|
||||
list: query::List| {
|
||||
Query {
|
||||
access_token: auth.access_token,
|
||||
stream: stream.stream,
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
}
|
||||
},
|
||||
)
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
||||
pub struct Timeline(pub Stream, pub Reach, pub Content);
|
||||
|
||||
impl Timeline {
|
||||
pub fn empty() -> Self {
|
||||
use {Content::*, Reach::*, Stream::*};
|
||||
Self(Unset, Local, Notification)
|
||||
}
|
||||
|
||||
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String, TimelineErr> {
|
||||
use {Content::*, Reach::*, Stream::*};
|
||||
Ok(match self {
|
||||
Timeline(Public, Federated, All) => "timeline:public".into(),
|
||||
Timeline(Public, Local, All) => "timeline:public:local".into(),
|
||||
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
|
||||
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
|
||||
|
||||
Timeline(Hashtag(_id), Federated, All) => format!(
|
||||
"timeline:hashtag:{}",
|
||||
hashtag.ok_or_else(|| TimelineErr::MissingHashtag)?
|
||||
),
|
||||
Timeline(Hashtag(_id), Local, All) => format!(
|
||||
"timeline:hashtag:{}:local",
|
||||
hashtag.ok_or_else(|| TimelineErr::MissingHashtag)?
|
||||
),
|
||||
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
|
||||
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
|
||||
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
|
||||
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
|
||||
Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_redis_text(
|
||||
timeline: &str,
|
||||
cache: &mut LruCache<String, i64>,
|
||||
) -> Result<Self, TimelineErr> {
|
||||
let mut id_from_tag = |tag: &str| match cache.get(&tag.to_string()) {
|
||||
Some(id) => Ok(*id),
|
||||
None => Err(TimelineErr::InvalidInput), // TODO more specific
|
||||
};
|
||||
|
||||
use {Content::*, Reach::*, Stream::*};
|
||||
Ok(match &timeline.split(':').collect::<Vec<&str>>()[..] {
|
||||
["public"] => Timeline(Public, Federated, All),
|
||||
["public", "local"] => Timeline(Public, Local, All),
|
||||
["public", "media"] => Timeline(Public, Federated, Media),
|
||||
["public", "local", "media"] => Timeline(Public, Local, Media),
|
||||
["hashtag", tag] => Timeline(Hashtag(id_from_tag(tag)?), Federated, All),
|
||||
["hashtag", tag, "local"] => Timeline(Hashtag(id_from_tag(tag)?), Local, All),
|
||||
[id] => Timeline(User(id.parse()?), Federated, All),
|
||||
[id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
|
||||
["list", id] => Timeline(List(id.parse()?), Federated, All),
|
||||
["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
|
||||
// Other endpoints don't exist:
|
||||
[..] => Err(TimelineErr::InvalidInput)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result<Self, Rejection> {
|
||||
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
|
||||
let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag);
|
||||
let user_owns_list = || pool.clone().user_owns_list(user.id, q.list);
|
||||
|
||||
Ok(match q.stream.as_ref() {
|
||||
"public" => match q.media {
|
||||
true => Timeline(Public, Federated, Media),
|
||||
false => Timeline(Public, Federated, All),
|
||||
},
|
||||
"public:local" => match q.media {
|
||||
true => Timeline(Public, Local, Media),
|
||||
false => Timeline(Public, Local, All),
|
||||
},
|
||||
"public:media" => Timeline(Public, Federated, Media),
|
||||
"public:local:media" => Timeline(Public, Local, Media),
|
||||
|
||||
"hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All),
|
||||
"hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All),
|
||||
"user" => match user.scopes.contains(&Statuses) {
|
||||
true => Timeline(User(user.id), Federated, All),
|
||||
false => Err(custom("Error: Missing access token"))?,
|
||||
},
|
||||
"user:notification" => match user.scopes.contains(&Statuses) {
|
||||
true => Timeline(User(user.id), Federated, Notification),
|
||||
false => Err(custom("Error: Missing access token"))?,
|
||||
},
|
||||
"list" => match user.scopes.contains(&Lists) && user_owns_list() {
|
||||
true => Timeline(List(q.list), Federated, All),
|
||||
false => Err(warp::reject::custom("Error: Missing access token"))?,
|
||||
},
|
||||
"direct" => match user.scopes.contains(&Statuses) {
|
||||
true => Timeline(Direct(*user.id), Federated, All),
|
||||
false => Err(custom("Error: Missing access token"))?,
|
||||
},
|
||||
other => {
|
||||
log::warn!("Request for nonexistent endpoint: `{}`", other);
|
||||
Err(custom("Error: Nonexistent endpoint"))?
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
||||
pub enum Stream {
|
||||
User(Id),
|
||||
// TODO consider whether List, Direct, and Hashtag should all be `id::Id`s
|
||||
List(i64),
|
||||
Direct(i64),
|
||||
Hashtag(i64),
|
||||
Public,
|
||||
Unset,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
||||
pub enum Reach {
|
||||
Local,
|
||||
Federated,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
||||
pub enum Content {
|
||||
All,
|
||||
Media,
|
||||
Notification,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Scope {
|
||||
Read,
|
||||
Statuses,
|
||||
Notifications,
|
||||
Lists,
|
||||
}
|
||||
|
||||
pub struct UserData {
|
||||
pub id: Id,
|
||||
pub allowed_langs: HashSet<String>,
|
||||
pub scopes: HashSet<Scope>,
|
||||
}
|
||||
|
||||
impl UserData {
|
||||
fn public() -> Self {
|
||||
Self {
|
||||
id: Id(-1),
|
||||
allowed_langs: HashSet::new(),
|
||||
scopes: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,186 +0,0 @@
|
|||
use crate::messages::Event;
|
||||
use crate::parse_client_request::{Subscription, Timeline};
|
||||
|
||||
use futures::{future::Future, stream::Stream};
|
||||
use log;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use warp::{
|
||||
reply::Reply,
|
||||
sse::{ServerSentEvent, Sse},
|
||||
ws::{Message, WebSocket},
|
||||
};
|
||||
|
||||
pub struct WsStream {
|
||||
ws_tx: mpsc::UnboundedSender<Message>,
|
||||
unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||
subscription: Subscription,
|
||||
}
|
||||
|
||||
impl WsStream {
|
||||
pub fn new(
|
||||
ws: WebSocket,
|
||||
unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||
subscription: Subscription,
|
||||
) -> Self {
|
||||
let (transmit_to_ws, _receive_from_ws) = ws.split();
|
||||
// Create a pipe
|
||||
let (ws_tx, ws_rx) = mpsc::unbounded_channel();
|
||||
|
||||
// Send one end of it to a different green thread and tell that end to forward
|
||||
// whatever it gets on to the WebSocket client
|
||||
warp::spawn(
|
||||
ws_rx
|
||||
.map_err(|_| -> warp::Error { unreachable!() })
|
||||
.forward(transmit_to_ws)
|
||||
.map(|_r| ())
|
||||
.map_err(|e| match e.to_string().as_ref() {
|
||||
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
|
||||
_ => log::warn!("WebSocket send error: {}", e),
|
||||
}),
|
||||
);
|
||||
Self {
|
||||
ws_tx,
|
||||
unsubscribe_tx,
|
||||
subscription,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_events(
|
||||
mut self,
|
||||
event_rx: watch::Receiver<(Timeline, Event)>,
|
||||
) -> impl Future<Item = (), Error = ()> {
|
||||
let target_timeline = self.subscription.timeline;
|
||||
|
||||
event_rx.map_err(|_| ()).for_each(move |(tl, event)| {
|
||||
if matches!(event, Event::Ping) {
|
||||
self.send_ping()
|
||||
} else if target_timeline == tl {
|
||||
use crate::messages::{CheckedEvent::Update, Event::*, EventKind};
|
||||
use crate::parse_client_request::Stream::Public;
|
||||
let blocks = &self.subscription.blocks;
|
||||
let allowed_langs = &self.subscription.allowed_langs;
|
||||
|
||||
match event {
|
||||
TypeSafe(Update { payload, queued_at }) => match tl {
|
||||
Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()),
|
||||
_ if payload.involves_any(&blocks) => Ok(()),
|
||||
_ => self.send_msg(TypeSafe(Update { payload, queued_at })),
|
||||
},
|
||||
TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)),
|
||||
Dynamic(dyn_event) => {
|
||||
if let EventKind::Update(s) = dyn_event.kind.clone() {
|
||||
match tl {
|
||||
Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()),
|
||||
_ if s.involves_any(&blocks) => Ok(()),
|
||||
_ => self.send_msg(Dynamic(dyn_event)),
|
||||
}
|
||||
} else {
|
||||
self.send_msg(Dynamic(dyn_event))
|
||||
}
|
||||
}
|
||||
Ping => unreachable!(), // handled pings above
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn send_ping(&mut self) -> Result<(), ()> {
|
||||
self.send_txt("{}")
|
||||
}
|
||||
|
||||
fn send_msg(&mut self, event: Event) -> Result<(), ()> {
|
||||
self.send_txt(&event.to_json_string())
|
||||
}
|
||||
|
||||
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
|
||||
let tl = self.subscription.timeline;
|
||||
match self.ws_tx.try_send(Message::text(txt)) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => {
|
||||
self.unsubscribe_tx
|
||||
.try_send(tl)
|
||||
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SseStream {}
|
||||
|
||||
impl SseStream {
|
||||
fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
|
||||
Some((
|
||||
warp::sse::event(event.event_name()),
|
||||
warp::sse::data(event.payload().unwrap_or_else(String::new)),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn send_events(
|
||||
sse: Sse,
|
||||
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||
subscription: Subscription,
|
||||
sse_rx: watch::Receiver<(Timeline, Event)>,
|
||||
) -> impl Reply {
|
||||
let target_timeline = subscription.timeline;
|
||||
let allowed_langs = subscription.allowed_langs;
|
||||
let blocks = subscription.blocks;
|
||||
|
||||
let event_stream = sse_rx
|
||||
.filter_map(move |(timeline, event)| {
|
||||
if target_timeline == timeline {
|
||||
use crate::messages::{
|
||||
CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind,
|
||||
};
|
||||
|
||||
use crate::parse_client_request::Stream::Public;
|
||||
match event {
|
||||
TypeSafe(Update { payload, queued_at }) => match timeline {
|
||||
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
|
||||
_ if payload.involves_any(&blocks) => None,
|
||||
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
|
||||
payload,
|
||||
queued_at,
|
||||
})),
|
||||
},
|
||||
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
|
||||
Dynamic(dyn_event) => {
|
||||
if let EventKind::Update(s) = dyn_event.kind {
|
||||
match timeline {
|
||||
Timeline(Public, _, _) if s.language_not(&allowed_langs) => {
|
||||
None
|
||||
}
|
||||
_ if s.involves_any(&blocks) => None,
|
||||
_ => Self::reply_with(Dynamic(DynEvent {
|
||||
kind: EventKind::Update(s),
|
||||
..dyn_event
|
||||
})),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Ping => None, // pings handled automatically
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.then(move |res| {
|
||||
unsubscribe_tx
|
||||
.try_send(target_timeline)
|
||||
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
|
||||
res
|
||||
});
|
||||
|
||||
sse.reply(
|
||||
warp::sse::keep_alive()
|
||||
.interval(Duration::from_secs(30))
|
||||
.text("thump".to_string())
|
||||
.stream(event_stream),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis.
|
||||
mod event_stream;
|
||||
mod receiver;
|
||||
mod redis;
|
||||
|
||||
pub use {
|
||||
event_stream::{SseStream, WsStream},
|
||||
receiver::{Receiver, ReceiverErr},
|
||||
};
|
||||
|
||||
#[cfg(feature = "bench")]
|
||||
pub use redis::redis_msg::{RedisMsg, RedisParseOutput};
|
|
@ -1,5 +0,0 @@
|
|||
pub mod redis_connection;
|
||||
pub mod redis_msg;
|
||||
|
||||
pub use redis_connection::{RedisConn, RedisConnErr};
|
||||
pub use redis_msg::RedisParseErr;
|
|
@ -0,0 +1,145 @@
|
|||
//! Parse the client request and return a Subscription
|
||||
mod postgres;
|
||||
mod query;
|
||||
pub mod timeline;
|
||||
|
||||
mod subscription;
|
||||
|
||||
pub use self::postgres::PgPool;
|
||||
// TODO consider whether we can remove `Stream` from public API
|
||||
pub use subscription::{Blocks, Subscription};
|
||||
pub use timeline::{Content, Reach, Stream, Timeline, TimelineErr};
|
||||
|
||||
use self::query::Query;
|
||||
use crate::config;
|
||||
use warp::filters::BoxedFilter;
|
||||
use warp::http::StatusCode;
|
||||
use warp::path;
|
||||
use warp::{Filter, Rejection};
|
||||
|
||||
#[cfg(test)]
|
||||
mod sse_test;
|
||||
#[cfg(test)]
|
||||
mod ws_test;
|
||||
|
||||
/// Helper macro to match on the first of any of the provided filters
|
||||
macro_rules! any_of {
|
||||
($filter:expr, $($other_filter:expr),*) => {
|
||||
$filter$(.or($other_filter).unify())*.boxed()
|
||||
};
|
||||
}
|
||||
macro_rules! parse_sse_query {
|
||||
(path => $start:tt $(/ $next:tt)*
|
||||
endpoint => $endpoint:expr) => {
|
||||
path!($start $(/ $next)*)
|
||||
.and(query::Auth::to_filter())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.map(|auth: query::Auth, media: query::Media, hashtag: query::Hashtag, list: query::List| {
|
||||
Query {
|
||||
access_token: auth.access_token,
|
||||
stream: $endpoint.to_string(),
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
}
|
||||
},
|
||||
)
|
||||
.boxed()
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Handler {
|
||||
pg_conn: PgPool,
|
||||
}
|
||||
|
||||
impl Handler {
|
||||
pub fn new(postgres_cfg: config::Postgres, whitelist_mode: bool) -> Self {
|
||||
Self {
|
||||
pg_conn: PgPool::new(postgres_cfg, whitelist_mode),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sse_subscription(&self) -> BoxedFilter<(Subscription,)> {
|
||||
let pg_conn = self.pg_conn.clone();
|
||||
any_of!(
|
||||
parse_sse_query!( path => "api" / "v1" / "streaming" / "user" / "notification"
|
||||
endpoint => "user:notification" ),
|
||||
parse_sse_query!( path => "api" / "v1" / "streaming" / "user"
|
||||
endpoint => "user"),
|
||||
parse_sse_query!( path => "api" / "v1" / "streaming" / "public" / "local"
|
||||
endpoint => "public:local"),
|
||||
parse_sse_query!( path => "api" / "v1" / "streaming" / "public"
|
||||
endpoint => "public"),
|
||||
parse_sse_query!( path => "api" / "v1" / "streaming" / "direct"
|
||||
endpoint => "direct"),
|
||||
parse_sse_query!( path => "api" / "v1" / "streaming" / "hashtag" / "local"
|
||||
endpoint => "hashtag:local"),
|
||||
parse_sse_query!( path => "api" / "v1" / "streaming" / "hashtag"
|
||||
endpoint => "hashtag"),
|
||||
parse_sse_query!( path => "api" / "v1" / "streaming" / "list"
|
||||
endpoint => "list")
|
||||
)
|
||||
// because SSE requests place their `access_token` in the header instead of in a query
|
||||
// parameter, we need to update our Query if the header has a token
|
||||
.and(query::OptionalAccessToken::from_sse_header())
|
||||
.and_then(Query::update_access_token)
|
||||
.and_then(move |q| Subscription::query_postgres(q, pg_conn.clone()))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub fn ws_subscription(&self) -> BoxedFilter<(Subscription,)> {
|
||||
let pg_conn = self.pg_conn.clone();
|
||||
parse_ws_query()
|
||||
.and(query::OptionalAccessToken::from_ws_header())
|
||||
.and_then(Query::update_access_token)
|
||||
.and_then(move |q| Subscription::query_postgres(q, pg_conn.clone()))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub fn health(&self) -> BoxedFilter<()> {
|
||||
warp::path!("api" / "v1" / "streaming" / "health").boxed()
|
||||
}
|
||||
|
||||
pub fn status(&self) -> BoxedFilter<()> {
|
||||
warp::path!("api" / "v1" / "streaming" / "status")
|
||||
.and(warp::path::end())
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub fn status_per_timeline(&self) -> BoxedFilter<()> {
|
||||
warp::path!("api" / "v1" / "streaming" / "status" / "per_timeline").boxed()
|
||||
}
|
||||
|
||||
pub fn err(r: Rejection) -> Result<impl warp::Reply, warp::Rejection> {
|
||||
let json_err = match r.cause() {
|
||||
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
|
||||
warp::reply::json(&"Error: Missing access token".to_string())
|
||||
}
|
||||
Some(text) => warp::reply::json(&text.to_string()),
|
||||
None => warp::reply::json(&"Error: Nonexistant endpoint".to_string()),
|
||||
};
|
||||
Ok(warp::reply::with_status(json_err, StatusCode::UNAUTHORIZED))
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ws_query() -> BoxedFilter<(Query,)> {
|
||||
use query::*;
|
||||
path!("api" / "v1" / "streaming")
|
||||
.and(path::end())
|
||||
.and(warp::query())
|
||||
.and(Auth::to_filter())
|
||||
.and(Media::to_filter())
|
||||
.and(Hashtag::to_filter())
|
||||
.and(List::to_filter())
|
||||
.map(|s: Stream, a: Auth, m: Media, h: Hashtag, l: List| Query {
|
||||
access_token: a.access_token,
|
||||
stream: s.stream,
|
||||
media: m.is_truthy(),
|
||||
hashtag: h.tag,
|
||||
list: l.list,
|
||||
})
|
||||
.boxed()
|
||||
}
|
|
@ -0,0 +1,157 @@
|
|||
//! Postgres queries
|
||||
use crate::config;
|
||||
use crate::event::Id;
|
||||
use crate::request::timeline::{Scope, UserData};
|
||||
|
||||
use ::postgres;
|
||||
use hashbrown::HashSet;
|
||||
use r2d2_postgres::PostgresConnectionManager;
|
||||
use std::convert::TryFrom;
|
||||
use warp::reject::Rejection;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PgPool {
|
||||
pub conn: r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>,
|
||||
whitelist_mode: bool,
|
||||
}
|
||||
|
||||
impl PgPool {
|
||||
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self {
|
||||
let mut cfg = postgres::Config::new();
|
||||
cfg.user(&pg_cfg.user)
|
||||
.host(&*pg_cfg.host.to_string())
|
||||
.port(*pg_cfg.port)
|
||||
.dbname(&pg_cfg.database);
|
||||
if let Some(password) = &*pg_cfg.password {
|
||||
cfg.password(password);
|
||||
};
|
||||
|
||||
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
|
||||
let pool = r2d2::Pool::builder()
|
||||
.max_size(10)
|
||||
.build(manager)
|
||||
.expect("Can connect to local postgres");
|
||||
Self {
|
||||
conn: pool,
|
||||
whitelist_mode,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_user(self, token: &Option<String>) -> Result<UserData, Rejection> {
|
||||
let mut conn = self.conn.get().unwrap();
|
||||
if let Some(token) = token {
|
||||
let query_rows = conn
|
||||
.query("
|
||||
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
|
||||
FROM oauth_access_tokens
|
||||
INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
|
||||
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
|
||||
LIMIT 1",
|
||||
&[&token.to_owned()],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
if let Some(result_columns) = query_rows.get(0) {
|
||||
let id = Id(result_columns.get(1));
|
||||
let allowed_langs = result_columns
|
||||
.try_get::<_, Vec<_>>(2)
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let mut scopes: HashSet<Scope> = result_columns
|
||||
.get::<_, String>(3)
|
||||
.split(' ')
|
||||
.filter_map(|scope| Scope::try_from(scope).ok())
|
||||
.collect();
|
||||
// We don't need to separately track read auth - it's just all three others
|
||||
if scopes.contains(&Scope::Read) {
|
||||
scopes = vec![Scope::Statuses, Scope::Notifications, Scope::Lists]
|
||||
.into_iter()
|
||||
.collect()
|
||||
}
|
||||
|
||||
Ok(UserData {
|
||||
id,
|
||||
allowed_langs,
|
||||
scopes,
|
||||
})
|
||||
} else {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
}
|
||||
} else if self.whitelist_mode {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
} else {
|
||||
Ok(UserData::public())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.get(0)
|
||||
.map(|row| row.get(0))
|
||||
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
|
||||
}
|
||||
|
||||
/// Query Postgres for everyone the user has blocked or muted
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
conn.query(
|
||||
"SELECT target_account_id FROM blocks WHERE account_id = $1
|
||||
UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
|
||||
&[&*user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| Id(row.get(0)))
|
||||
.collect()
|
||||
}
|
||||
/// Query Postgres for everyone who has blocked the user
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
conn.query(
|
||||
"SELECT account_id FROM blocks WHERE target_account_id = $1",
|
||||
&[&*user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| Id(row.get(0)))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Query Postgres for all current domain blocks
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
conn.query(
|
||||
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
|
||||
&[&*user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| row.get(0))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Test whether a user owns a list
|
||||
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
|
||||
let mut conn = self.conn.get().expect("TODO");
|
||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||
let rows = &conn
|
||||
.query(
|
||||
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
|
||||
&[&list_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
//! `User` struct and related functionality
|
||||
// #[cfg(test)]
|
||||
// mod mock_postgres;
|
||||
// #[cfg(test)]
|
||||
// use mock_postgres as postgres;
|
||||
// #[cfg(not(test))]
|
||||
|
||||
use super::postgres::PgPool;
|
||||
use super::query::Query;
|
||||
use super::{Content, Reach, Stream, Timeline};
|
||||
use crate::event::Id;
|
||||
|
||||
use hashbrown::HashSet;
|
||||
|
||||
use warp::reject::Rejection;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct Subscription {
|
||||
pub timeline: Timeline,
|
||||
pub allowed_langs: HashSet<String>,
|
||||
pub blocks: Blocks,
|
||||
pub hashtag_name: Option<String>,
|
||||
pub access_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq)]
|
||||
pub struct Blocks {
|
||||
pub blocked_domains: HashSet<String>,
|
||||
pub blocked_users: HashSet<Id>,
|
||||
pub blocking_users: HashSet<Id>,
|
||||
}
|
||||
|
||||
impl Default for Subscription {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification),
|
||||
allowed_langs: HashSet::new(),
|
||||
blocks: Blocks::default(),
|
||||
hashtag_name: None,
|
||||
access_token: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Subscription {
|
||||
pub(super) fn query_postgres(q: Query, pool: PgPool) -> Result<Self, Rejection> {
|
||||
let user = pool.clone().select_user(&q.access_token)?;
|
||||
let timeline = {
|
||||
let tl = Timeline::from_query_and_user(&q, &user)?;
|
||||
let pool = pool.clone();
|
||||
use Stream::*;
|
||||
match tl {
|
||||
Timeline(Hashtag(_), reach, stream) => {
|
||||
let tag = pool.select_hashtag_id(&q.hashtag)?;
|
||||
Timeline(Hashtag(tag), reach, stream)
|
||||
}
|
||||
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => {
|
||||
Err(warp::reject::custom("Error: Missing access token"))?
|
||||
}
|
||||
other_tl => other_tl,
|
||||
}
|
||||
};
|
||||
|
||||
let hashtag_name = match timeline {
|
||||
Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag),
|
||||
_non_hashtag_timeline => None,
|
||||
};
|
||||
|
||||
Ok(Subscription {
|
||||
timeline,
|
||||
allowed_langs: user.allowed_langs,
|
||||
blocks: Blocks {
|
||||
blocking_users: pool.clone().select_blocking_users(user.id),
|
||||
blocked_users: pool.clone().select_blocked_users(user.id),
|
||||
blocked_domains: pool.select_blocked_domains(user.id),
|
||||
},
|
||||
hashtag_name,
|
||||
access_token: q.access_token,
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,104 @@
|
|||
pub use self::err::TimelineErr;
|
||||
pub use self::inner::{Content, Reach, Scope, Stream, UserData};
|
||||
use super::query::Query;
|
||||
|
||||
use lru::LruCache;
|
||||
use warp::reject::Rejection;
|
||||
|
||||
mod err;
|
||||
mod inner;
|
||||
|
||||
type Result<T> = std::result::Result<T, TimelineErr>;
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
||||
pub struct Timeline(pub Stream, pub Reach, pub Content);
|
||||
|
||||
impl Timeline {
|
||||
pub fn empty() -> Self {
|
||||
Self(Stream::Unset, Reach::Local, Content::Notification)
|
||||
}
|
||||
|
||||
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String> {
|
||||
use {Content::*, Reach::*, Stream::*};
|
||||
Ok(match self {
|
||||
Timeline(Public, Federated, All) => "timeline:public".into(),
|
||||
Timeline(Public, Local, All) => "timeline:public:local".into(),
|
||||
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
|
||||
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
|
||||
// TODO -- would `.push_str` be faster here?
|
||||
Timeline(Hashtag(_id), Federated, All) => format!(
|
||||
"timeline:hashtag:{}",
|
||||
hashtag.ok_or(TimelineErr::MissingHashtag)?
|
||||
),
|
||||
Timeline(Hashtag(_id), Local, All) => format!(
|
||||
"timeline:hashtag:{}:local",
|
||||
hashtag.ok_or(TimelineErr::MissingHashtag)?
|
||||
),
|
||||
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
|
||||
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
|
||||
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
|
||||
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
|
||||
Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_redis_text(timeline: &str, cache: &mut LruCache<String, i64>) -> Result<Self> {
|
||||
use {Content::*, Reach::*, Stream::*, TimelineErr::*};
|
||||
let mut tag_id = |t: &str| cache.get(&t.to_string()).map_or(Err(BadTag), |id| Ok(*id));
|
||||
|
||||
Ok(match &timeline.split(':').collect::<Vec<&str>>()[..] {
|
||||
["public"] => Timeline(Public, Federated, All),
|
||||
["public", "local"] => Timeline(Public, Local, All),
|
||||
["public", "media"] => Timeline(Public, Federated, Media),
|
||||
["public", "local", "media"] => Timeline(Public, Local, Media),
|
||||
["hashtag", tag] => Timeline(Hashtag(tag_id(tag)?), Federated, All),
|
||||
["hashtag", tag, "local"] => Timeline(Hashtag(tag_id(tag)?), Local, All),
|
||||
[id] => Timeline(User(id.parse()?), Federated, All),
|
||||
[id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
|
||||
["list", id] => Timeline(List(id.parse()?), Federated, All),
|
||||
["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
|
||||
// Other endpoints don't exist:
|
||||
[..] => Err(InvalidInput)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn from_query_and_user(q: &Query, user: &UserData) -> std::result::Result<Self, Rejection> {
|
||||
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
|
||||
|
||||
Ok(match q.stream.as_ref() {
|
||||
"public" => match q.media {
|
||||
true => Timeline(Public, Federated, Media),
|
||||
false => Timeline(Public, Federated, All),
|
||||
},
|
||||
"public:local" => match q.media {
|
||||
true => Timeline(Public, Local, Media),
|
||||
false => Timeline(Public, Local, All),
|
||||
},
|
||||
"public:media" => Timeline(Public, Federated, Media),
|
||||
"public:local:media" => Timeline(Public, Local, Media),
|
||||
|
||||
"hashtag" => Timeline(Hashtag(0), Federated, All),
|
||||
"hashtag:local" => Timeline(Hashtag(0), Local, All),
|
||||
"user" => match user.scopes.contains(&Statuses) {
|
||||
true => Timeline(User(user.id), Federated, All),
|
||||
false => Err(custom("Error: Missing access token"))?,
|
||||
},
|
||||
"user:notification" => match user.scopes.contains(&Statuses) {
|
||||
true => Timeline(User(user.id), Federated, Notification),
|
||||
false => Err(custom("Error: Missing access token"))?,
|
||||
},
|
||||
"list" => match user.scopes.contains(&Lists) {
|
||||
true => Timeline(List(q.list), Federated, All),
|
||||
false => Err(warp::reject::custom("Error: Missing access token"))?,
|
||||
},
|
||||
"direct" => match user.scopes.contains(&Statuses) {
|
||||
true => Timeline(Direct(*user.id), Federated, All),
|
||||
false => Err(custom("Error: Missing access token"))?,
|
||||
},
|
||||
other => {
|
||||
log::warn!("Request for nonexistent endpoint: `{}`", other);
|
||||
Err(custom("Error: Nonexistent endpoint"))?
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,6 +4,7 @@ use std::fmt;
|
|||
pub enum TimelineErr {
|
||||
MissingHashtag,
|
||||
InvalidInput,
|
||||
BadTag,
|
||||
}
|
||||
|
||||
impl std::error::Error for TimelineErr {}
|
||||
|
@ -20,6 +21,7 @@ impl fmt::Display for TimelineErr {
|
|||
let msg = match self {
|
||||
InvalidInput => "The timeline text from Redis could not be parsed into a supported timeline. TODO: add incoming timeline text",
|
||||
MissingHashtag => "Attempted to send a hashtag timeline without supplying a tag name",
|
||||
BadTag => "No hashtag exists with the specified hashtag ID"
|
||||
};
|
||||
write!(f, "{}", msg)
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
use super::TimelineErr;
|
||||
use crate::event::Id;
|
||||
|
||||
use hashbrown::HashSet;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
||||
pub enum Stream {
|
||||
User(Id),
|
||||
List(i64),
|
||||
Direct(i64),
|
||||
Hashtag(i64),
|
||||
Public,
|
||||
Unset,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
||||
pub enum Reach {
|
||||
Local,
|
||||
Federated,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
||||
pub enum Content {
|
||||
All,
|
||||
Media,
|
||||
Notification,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Scope {
|
||||
Read,
|
||||
Statuses,
|
||||
Notifications,
|
||||
Lists,
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Scope {
|
||||
type Error = TimelineErr;
|
||||
|
||||
fn try_from(s: &str) -> Result<Self, TimelineErr> {
|
||||
match s {
|
||||
"read" => Ok(Scope::Read),
|
||||
"read:statuses" => Ok(Scope::Statuses),
|
||||
"read:notifications" => Ok(Scope::Notifications),
|
||||
"read:lists" => Ok(Scope::Lists),
|
||||
"write" | "follow" => Err(TimelineErr::InvalidInput), // ignore write scopes
|
||||
unexpected => {
|
||||
log::warn!("Ignoring unknown scope `{}`", unexpected);
|
||||
Err(TimelineErr::InvalidInput)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UserData {
|
||||
pub id: Id,
|
||||
pub allowed_langs: HashSet<String>,
|
||||
pub scopes: HashSet<Scope>,
|
||||
}
|
||||
|
||||
impl UserData {
|
||||
pub fn public() -> Self {
|
||||
Self {
|
||||
id: Id(-1),
|
||||
allowed_langs: HashSet::new(),
|
||||
scopes: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis.
|
||||
|
||||
pub mod redis;
|
||||
pub mod stream;
|
||||
|
||||
pub use redis::{Manager, ManagerErr};
|
||||
|
||||
#[cfg(feature = "bench")]
|
||||
pub use redis::msg::{RedisMsg, RedisParseOutput};
|
|
@ -0,0 +1,27 @@
|
|||
pub mod connection;
|
||||
mod manager;
|
||||
pub mod msg;
|
||||
|
||||
pub use connection::{RedisConn, RedisConnErr};
|
||||
pub use manager::{Manager, ManagerErr};
|
||||
pub use msg::RedisParseErr;
|
||||
|
||||
pub enum RedisCmd {
|
||||
Subscribe,
|
||||
Unsubscribe,
|
||||
}
|
||||
|
||||
impl RedisCmd {
|
||||
pub fn into_sendable(&self, tl: &String) -> (Vec<u8>, Vec<u8>) {
|
||||
match self {
|
||||
RedisCmd::Subscribe => (
|
||||
format!("*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(),
|
||||
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n1\r\n", tl.len(), tl).into_bytes(),
|
||||
),
|
||||
RedisCmd::Unsubscribe => (
|
||||
format!("*2\r\n$11\r\nunsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(),
|
||||
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n0\r\n", tl.len(), tl).into_bytes(),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,24 +1,19 @@
|
|||
mod err;
|
||||
pub use err::RedisConnErr;
|
||||
|
||||
use super::super::receiver::ReceiverErr;
|
||||
use super::redis_msg::{RedisParseErr, RedisParseOutput};
|
||||
use crate::{
|
||||
config::RedisConfig,
|
||||
messages::Event,
|
||||
parse_client_request::{Stream, Timeline},
|
||||
};
|
||||
|
||||
use std::{
|
||||
convert::{TryFrom, TryInto},
|
||||
io::{Read, Write},
|
||||
net::TcpStream,
|
||||
str,
|
||||
time::Duration,
|
||||
};
|
||||
use super::msg::{RedisParseErr, RedisParseOutput};
|
||||
use super::{ManagerErr, RedisCmd};
|
||||
use crate::config::Redis;
|
||||
use crate::event::Event;
|
||||
use crate::request::{Stream, Timeline};
|
||||
|
||||
use futures::{Async, Poll};
|
||||
use lru::LruCache;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpStream;
|
||||
use std::str;
|
||||
use std::time::Duration;
|
||||
|
||||
type Result<T> = std::result::Result<T, RedisConnErr>;
|
||||
|
||||
|
@ -33,7 +28,7 @@ pub struct RedisConn {
|
|||
}
|
||||
|
||||
impl RedisConn {
|
||||
pub fn new(redis_cfg: RedisConfig) -> Result<Self> {
|
||||
pub fn new(redis_cfg: Redis) -> Result<Self> {
|
||||
let addr = format!("{}:{}", *redis_cfg.host, *redis_cfg.port);
|
||||
let conn = Self::new_connection(&addr, redis_cfg.password.as_ref())?;
|
||||
conn.set_nonblocking(true)
|
||||
|
@ -46,29 +41,24 @@ impl RedisConn {
|
|||
// TODO: eventually, it might make sense to have Mastodon publish to timelines with
|
||||
// the tag number instead of the tag name. This would save us from dealing
|
||||
// with a cache here and would be consistent with how lists/users are handled.
|
||||
redis_namespace: redis_cfg.namespace.clone(),
|
||||
redis_namespace: redis_cfg.namespace.clone().0,
|
||||
redis_input: Vec::new(),
|
||||
};
|
||||
Ok(redis_conn)
|
||||
}
|
||||
|
||||
pub fn poll_redis(&mut self) -> Poll<Option<(Timeline, Event)>, ReceiverErr> {
|
||||
pub fn poll_redis(&mut self) -> Poll<Option<(Timeline, Event)>, ManagerErr> {
|
||||
let mut size = 100; // large enough to handle subscribe/unsubscribe notice
|
||||
let (mut buffer, mut first_read) = (vec![0u8; size], true);
|
||||
loop {
|
||||
match self.primary.read(&mut buffer) {
|
||||
Ok(n) if n != size => {
|
||||
self.redis_input.extend_from_slice(&buffer[..n]);
|
||||
break;
|
||||
}
|
||||
Ok(n) => {
|
||||
self.redis_input.extend_from_slice(&buffer[..n]);
|
||||
}
|
||||
Ok(n) if n != size => break self.redis_input.extend_from_slice(&buffer[..n]),
|
||||
Ok(n) => self.redis_input.extend_from_slice(&buffer[..n]),
|
||||
Err(_) => break,
|
||||
};
|
||||
if first_read {
|
||||
size = 2000;
|
||||
buffer = vec![0u8; size];
|
||||
buffer = vec![0_u8; size];
|
||||
first_read = false;
|
||||
}
|
||||
}
|
||||
|
@ -76,6 +66,8 @@ impl RedisConn {
|
|||
if self.redis_input.is_empty() {
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
|
||||
// at this point, we have the raw bytes; now, parse what we can and leave the remainder
|
||||
let input = self.redis_input.clone();
|
||||
self.redis_input.clear();
|
||||
|
||||
|
@ -90,22 +82,22 @@ impl RedisConn {
|
|||
let (res, leftover) = match RedisParseOutput::try_from(input) {
|
||||
Ok(Msg(msg)) => match &self.redis_namespace {
|
||||
Some(ns) if msg.timeline_txt.starts_with(&format!("{}:timeline:", ns)) => {
|
||||
let trimmed_tl_txt = &msg.timeline_txt[ns.len() + ":timeline:".len()..];
|
||||
let tl = Timeline::from_redis_text(trimmed_tl_txt, &mut self.tag_id_cache)?;
|
||||
let trimmed_tl = &msg.timeline_txt[ns.len() + ":timeline:".len()..];
|
||||
let tl = Timeline::from_redis_text(trimmed_tl, &mut self.tag_id_cache)?;
|
||||
let event = msg.event_txt.try_into()?;
|
||||
(Ok(Ready(Some((tl, event)))), msg.leftover_input)
|
||||
(Ok(Ready(Some((tl, event)))), (msg.leftover_input))
|
||||
}
|
||||
None => {
|
||||
let trimmed_tl_txt = &msg.timeline_txt["timeline:".len()..];
|
||||
let tl = Timeline::from_redis_text(trimmed_tl_txt, &mut self.tag_id_cache)?;
|
||||
let trimmed_tl = &msg.timeline_txt["timeline:".len()..];
|
||||
let tl = Timeline::from_redis_text(trimmed_tl, &mut self.tag_id_cache)?;
|
||||
let event = msg.event_txt.try_into()?;
|
||||
(Ok(Ready(Some((tl, event)))), msg.leftover_input)
|
||||
(Ok(Ready(Some((tl, event)))), (msg.leftover_input))
|
||||
}
|
||||
Some(_non_matching_namespace) => (Ok(Ready(None)), msg.leftover_input),
|
||||
},
|
||||
Ok(NonMsg(leftover)) => (Ok(Ready(None)), leftover),
|
||||
Err(RedisParseErr::Incomplete) => (Ok(NotReady), input),
|
||||
Err(other_parse_err) => (Err(ReceiverErr::RedisParseErr(other_parse_err)), input),
|
||||
Err(other_parse_err) => (Err(ManagerErr::RedisParseErr(other_parse_err)), input),
|
||||
};
|
||||
self.redis_input.extend_from_slice(leftover.as_bytes());
|
||||
self.redis_input.extend_from_slice(invalid_bytes);
|
||||
|
@ -117,39 +109,47 @@ impl RedisConn {
|
|||
self.tag_name_cache.put(id, hashtag);
|
||||
}
|
||||
|
||||
fn new_connection(addr: &str, pass: Option<&String>) -> Result<TcpStream> {
|
||||
match TcpStream::connect(&addr) {
|
||||
Ok(mut conn) => {
|
||||
if let Some(password) = pass {
|
||||
Self::auth_connection(&mut conn, &addr, password)?;
|
||||
}
|
||||
pub fn send_cmd(&mut self, cmd: RedisCmd, timeline: &Timeline) -> Result<()> {
|
||||
let hashtag = match timeline {
|
||||
Timeline(Stream::Hashtag(id), _, _) => self.tag_name_cache.get(id),
|
||||
_non_hashtag_timeline => None,
|
||||
};
|
||||
|
||||
Self::validate_connection(&mut conn, &addr)?;
|
||||
conn.set_read_timeout(Some(Duration::from_millis(10)))
|
||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||
Ok(conn)
|
||||
}
|
||||
Err(e) => Err(RedisConnErr::with_addr(&addr, e)),
|
||||
}
|
||||
let tl = timeline.to_redis_raw_timeline(hashtag)?;
|
||||
let (primary_cmd, secondary_cmd) = cmd.into_sendable(&tl);
|
||||
self.primary.write_all(&primary_cmd)?;
|
||||
self.secondary.write_all(&secondary_cmd)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn new_connection(addr: &str, pass: Option<&String>) -> Result<TcpStream> {
|
||||
let mut conn = TcpStream::connect(&addr)?;
|
||||
if let Some(password) = pass {
|
||||
Self::auth_connection(&mut conn, &addr, password)?;
|
||||
}
|
||||
|
||||
Self::validate_connection(&mut conn, &addr)?;
|
||||
conn.set_read_timeout(Some(Duration::from_millis(10)))
|
||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> {
|
||||
conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes())
|
||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||
let mut buffer = vec![0u8; 5];
|
||||
let mut buffer = vec![0_u8; 5];
|
||||
conn.read_exact(&mut buffer)
|
||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||
let reply = String::from_utf8_lossy(&buffer);
|
||||
match &*reply {
|
||||
"+OK\r\n" => (),
|
||||
_ => Err(RedisConnErr::IncorrectPassword(pass.to_string()))?,
|
||||
};
|
||||
if String::from_utf8_lossy(&buffer) != "+OK\r\n" {
|
||||
Err(RedisConnErr::IncorrectPassword(pass.to_string()))?
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_connection(conn: &mut TcpStream, addr: &str) -> Result<()> {
|
||||
conn.write_all(b"PING\r\n")
|
||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||
let mut buffer = vec![0u8; 7];
|
||||
let mut buffer = vec![0_u8; 7];
|
||||
conn.read_exact(&mut buffer)
|
||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||
let reply = String::from_utf8_lossy(&buffer);
|
||||
|
@ -160,31 +160,4 @@ impl RedisConn {
|
|||
_ => Err(RedisConnErr::InvalidRedisReply(reply.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_cmd(&mut self, cmd: RedisCmd, timeline: &Timeline) -> Result<()> {
|
||||
let hashtag = match timeline {
|
||||
Timeline(Stream::Hashtag(id), _, _) => self.tag_name_cache.get(id),
|
||||
_non_hashtag_timeline => None,
|
||||
};
|
||||
|
||||
let tl = timeline.to_redis_raw_timeline(hashtag)?;
|
||||
let (primary_cmd, secondary_cmd) = match cmd {
|
||||
RedisCmd::Subscribe => (
|
||||
format!("*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl),
|
||||
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n1\r\n", tl.len(), tl),
|
||||
),
|
||||
RedisCmd::Unsubscribe => (
|
||||
format!("*2\r\n$11\r\nunsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl),
|
||||
format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n0\r\n", tl.len(), tl),
|
||||
),
|
||||
};
|
||||
self.primary.write_all(&primary_cmd.as_bytes())?;
|
||||
self.secondary.write_all(&secondary_cmd.as_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub enum RedisCmd {
|
||||
Subscribe,
|
||||
Unsubscribe,
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::err::TimelineErr;
|
||||
use crate::request::TimelineErr;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
|
@ -2,31 +2,24 @@
|
|||
//! polled by the correct `ClientAgent`. Also manages sububscriptions and
|
||||
//! unsubscriptions to/from Redis.
|
||||
mod err;
|
||||
pub use err::ReceiverErr;
|
||||
pub use err::ManagerErr;
|
||||
|
||||
use super::redis::{redis_connection::RedisCmd, RedisConn};
|
||||
|
||||
use crate::{
|
||||
config,
|
||||
messages::Event,
|
||||
parse_client_request::{Stream, Subscription, Timeline},
|
||||
};
|
||||
use super::{RedisCmd, RedisConn};
|
||||
use crate::config;
|
||||
use crate::event::Event;
|
||||
use crate::request::{Stream, Subscription, Timeline};
|
||||
|
||||
use futures::{Async, Stream as _Stream};
|
||||
use hashbrown::HashMap;
|
||||
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
||||
use std::{
|
||||
result,
|
||||
sync::{Arc, Mutex, MutexGuard, PoisonError},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
type Result<T> = result::Result<T, ReceiverErr>;
|
||||
type Result<T> = std::result::Result<T, ManagerErr>;
|
||||
|
||||
/// The item that streams from Redis and is polled by the `ClientAgent`
|
||||
#[derive(Debug)]
|
||||
pub struct Receiver {
|
||||
pub struct Manager {
|
||||
redis_connection: RedisConn,
|
||||
clients_per_timeline: HashMap<Timeline, i32>,
|
||||
tx: watch::Sender<(Timeline, Event)>,
|
||||
|
@ -34,18 +27,16 @@ pub struct Receiver {
|
|||
ping_time: Instant,
|
||||
}
|
||||
|
||||
impl Receiver {
|
||||
/// Create a new `Receiver`, with its own Redis connections (but, as yet, no
|
||||
impl Manager {
|
||||
/// Create a new `Manager`, with its own Redis connections (but, as yet, no
|
||||
/// active subscriptions).
|
||||
|
||||
pub fn try_from(
|
||||
redis_cfg: config::RedisConfig,
|
||||
redis_cfg: config::Redis,
|
||||
tx: watch::Sender<(Timeline, Event)>,
|
||||
rx: mpsc::UnboundedReceiver<Timeline>,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
redis_connection: RedisConn::new(redis_cfg)?,
|
||||
|
||||
clients_per_timeline: HashMap::new(),
|
||||
tx,
|
||||
rx,
|
||||
|
@ -57,7 +48,7 @@ impl Receiver {
|
|||
Arc::new(Mutex::new(self))
|
||||
}
|
||||
|
||||
pub fn subscribe(&mut self, subscription: &Subscription) -> Result<()> {
|
||||
pub fn subscribe(&mut self, subscription: &Subscription) {
|
||||
let (tag, tl) = (subscription.hashtag_name.clone(), subscription.timeline);
|
||||
if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (tag, tl) {
|
||||
self.redis_connection.update_cache(hashtag, id);
|
||||
|
@ -71,9 +62,10 @@ impl Receiver {
|
|||
|
||||
use RedisCmd::*;
|
||||
if *number_of_subscriptions == 1 {
|
||||
self.redis_connection.send_cmd(Subscribe, &tl)?
|
||||
self.redis_connection
|
||||
.send_cmd(Subscribe, &tl)
|
||||
.unwrap_or_else(|e| log::error!("Could not subscribe to the Redis channel: {}", e));
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn unsubscribe(&mut self, tl: Timeline) -> Result<()> {
|
|
@ -1,11 +1,10 @@
|
|||
use super::super::redis::{RedisConnErr, RedisParseErr};
|
||||
use crate::err::TimelineErr;
|
||||
use crate::messages::{Event, EventErr};
|
||||
use crate::parse_client_request::Timeline;
|
||||
use super::super::{RedisConnErr, RedisParseErr};
|
||||
use crate::event::{Event, EventErr};
|
||||
use crate::request::{Timeline, TimelineErr};
|
||||
|
||||
use std::fmt;
|
||||
#[derive(Debug)]
|
||||
pub enum ReceiverErr {
|
||||
pub enum ManagerErr {
|
||||
InvalidId,
|
||||
TimelineErr(TimelineErr),
|
||||
EventErr(EventErr),
|
||||
|
@ -14,11 +13,11 @@ pub enum ReceiverErr {
|
|||
ChannelSendErr(tokio::sync::watch::error::SendError<(Timeline, Event)>),
|
||||
}
|
||||
|
||||
impl std::error::Error for ReceiverErr {}
|
||||
impl std::error::Error for ManagerErr {}
|
||||
|
||||
impl fmt::Display for ReceiverErr {
|
||||
impl fmt::Display for ManagerErr {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
|
||||
use ReceiverErr::*;
|
||||
use ManagerErr::*;
|
||||
match self {
|
||||
InvalidId => write!(
|
||||
f,
|
||||
|
@ -34,31 +33,31 @@ impl fmt::Display for ReceiverErr {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<tokio::sync::watch::error::SendError<(Timeline, Event)>> for ReceiverErr {
|
||||
impl From<tokio::sync::watch::error::SendError<(Timeline, Event)>> for ManagerErr {
|
||||
fn from(error: tokio::sync::watch::error::SendError<(Timeline, Event)>) -> Self {
|
||||
Self::ChannelSendErr(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EventErr> for ReceiverErr {
|
||||
impl From<EventErr> for ManagerErr {
|
||||
fn from(error: EventErr) -> Self {
|
||||
Self::EventErr(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RedisConnErr> for ReceiverErr {
|
||||
impl From<RedisConnErr> for ManagerErr {
|
||||
fn from(e: RedisConnErr) -> Self {
|
||||
Self::RedisConnErr(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<TimelineErr> for ReceiverErr {
|
||||
impl From<TimelineErr> for ManagerErr {
|
||||
fn from(e: TimelineErr) -> Self {
|
||||
Self::TimelineErr(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RedisParseErr> for ReceiverErr {
|
||||
impl From<RedisParseErr> for ManagerErr {
|
||||
fn from(e: RedisParseErr) -> Self {
|
||||
Self::RedisParseErr(e)
|
||||
}
|
|
@ -36,6 +36,8 @@ pub enum RedisParseOutput<'a> {
|
|||
NonMsg(&'a str),
|
||||
}
|
||||
|
||||
// TODO -- should this impl Iterator?
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct RedisMsg<'a> {
|
||||
pub timeline_txt: &'a str,
|
|
@ -0,0 +1,5 @@
|
|||
pub use sse::Sse;
|
||||
pub use ws::Ws;
|
||||
|
||||
mod sse;
|
||||
mod ws;
|
|
@ -0,0 +1,80 @@
|
|||
use crate::event::Event;
|
||||
use crate::request::{Subscription, Timeline};
|
||||
|
||||
use futures::stream::Stream;
|
||||
use log;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use warp::reply::Reply;
|
||||
use warp::sse::{ServerSentEvent, Sse as WarpSse};
|
||||
|
||||
pub struct Sse;
|
||||
|
||||
impl Sse {
|
||||
fn reply_with(event: Event) -> Option<(impl ServerSentEvent, impl ServerSentEvent)> {
|
||||
Some((
|
||||
warp::sse::event(event.event_name()),
|
||||
warp::sse::data(event.payload().unwrap_or_else(String::new)),
|
||||
))
|
||||
}
|
||||
|
||||
pub fn send_events(
|
||||
sse: WarpSse,
|
||||
mut unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||
subscription: Subscription,
|
||||
sse_rx: watch::Receiver<(Timeline, Event)>,
|
||||
) -> impl Reply {
|
||||
let target_timeline = subscription.timeline;
|
||||
let allowed_langs = subscription.allowed_langs;
|
||||
let blocks = subscription.blocks;
|
||||
|
||||
let event_stream = sse_rx
|
||||
.filter(move |(timeline, _)| target_timeline == *timeline)
|
||||
.filter_map(move |(timeline, event)| {
|
||||
use crate::event::{
|
||||
CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind,
|
||||
};
|
||||
|
||||
use crate::request::Stream::Public;
|
||||
match event {
|
||||
TypeSafe(Update { payload, queued_at }) => match timeline {
|
||||
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
|
||||
_ if payload.involves_any(&blocks) => None,
|
||||
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
|
||||
payload,
|
||||
queued_at,
|
||||
})),
|
||||
},
|
||||
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
|
||||
Dynamic(dyn_event) => {
|
||||
if let EventKind::Update(s) = dyn_event.kind {
|
||||
match timeline {
|
||||
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
|
||||
_ if s.involves_any(&blocks) => None,
|
||||
_ => Self::reply_with(Dynamic(DynEvent {
|
||||
kind: EventKind::Update(s),
|
||||
..dyn_event
|
||||
})),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Ping => None, // pings handled automatically
|
||||
}
|
||||
})
|
||||
.then(move |res| {
|
||||
unsubscribe_tx
|
||||
.try_send(target_timeline)
|
||||
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
|
||||
res
|
||||
});
|
||||
|
||||
sse.reply(
|
||||
warp::sse::keep_alive()
|
||||
.interval(Duration::from_secs(30))
|
||||
.text("thump".to_string())
|
||||
.stream(event_stream),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
use crate::event::Event;
|
||||
use crate::request::{Subscription, Timeline};
|
||||
|
||||
use futures::{future::Future, stream::Stream};
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use warp::ws::{Message, WebSocket};
|
||||
|
||||
pub struct Ws {
|
||||
unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||
subscription: Subscription,
|
||||
ws_rx: watch::Receiver<(Timeline, Event)>,
|
||||
ws_tx: Option<mpsc::UnboundedSender<Message>>,
|
||||
}
|
||||
|
||||
impl Ws {
|
||||
pub fn new(
|
||||
unsubscribe_tx: mpsc::UnboundedSender<Timeline>,
|
||||
ws_rx: watch::Receiver<(Timeline, Event)>,
|
||||
subscription: Subscription,
|
||||
) -> Self {
|
||||
Self {
|
||||
unsubscribe_tx,
|
||||
subscription,
|
||||
ws_rx,
|
||||
ws_tx: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send_to(mut self, ws: WebSocket) -> impl Future<Item = (), Error = ()> {
|
||||
let (transmit_to_ws, _receive_from_ws) = ws.split();
|
||||
// Create a pipe
|
||||
let (ws_tx, ws_rx) = mpsc::unbounded_channel();
|
||||
self.ws_tx = Some(ws_tx);
|
||||
|
||||
// Send one end of it to a different green thread and tell that end to forward
|
||||
// whatever it gets on to the WebSocket client
|
||||
warp::spawn(
|
||||
ws_rx
|
||||
.map_err(|_| -> warp::Error { unreachable!() })
|
||||
.forward(transmit_to_ws)
|
||||
.map(|_r| ())
|
||||
.map_err(|e| match e.to_string().as_ref() {
|
||||
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
|
||||
_ => log::warn!("WebSocket send error: {}", e),
|
||||
}),
|
||||
);
|
||||
|
||||
let target_timeline = self.subscription.timeline;
|
||||
let incoming_events = self.ws_rx.clone().map_err(|_| ());
|
||||
|
||||
incoming_events.for_each(move |(tl, event)| {
|
||||
if matches!(event, Event::Ping) {
|
||||
self.send_ping()
|
||||
} else if target_timeline == tl {
|
||||
use crate::event::{CheckedEvent::Update, Event::*, EventKind};
|
||||
use crate::request::Stream::Public;
|
||||
let blocks = &self.subscription.blocks;
|
||||
let allowed_langs = &self.subscription.allowed_langs;
|
||||
|
||||
match event {
|
||||
TypeSafe(Update { payload, queued_at }) => match tl {
|
||||
Timeline(Public, _, _) if payload.language_not(allowed_langs) => Ok(()),
|
||||
_ if payload.involves_any(&blocks) => Ok(()),
|
||||
_ => self.send_msg(TypeSafe(Update { payload, queued_at })),
|
||||
},
|
||||
TypeSafe(non_update) => self.send_msg(TypeSafe(non_update)),
|
||||
Dynamic(dyn_event) => {
|
||||
if let EventKind::Update(s) = dyn_event.kind.clone() {
|
||||
match tl {
|
||||
Timeline(Public, _, _) if s.language_not(allowed_langs) => Ok(()),
|
||||
_ if s.involves_any(&blocks) => Ok(()),
|
||||
_ => self.send_msg(Dynamic(dyn_event)),
|
||||
}
|
||||
} else {
|
||||
self.send_msg(Dynamic(dyn_event))
|
||||
}
|
||||
}
|
||||
Ping => unreachable!(), // handled pings above
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn send_ping(&mut self) -> Result<(), ()> {
|
||||
self.send_txt("{}")
|
||||
}
|
||||
|
||||
fn send_msg(&mut self, event: Event) -> Result<(), ()> {
|
||||
self.send_txt(&event.to_json_string())
|
||||
}
|
||||
|
||||
fn send_txt(&mut self, txt: &str) -> Result<(), ()> {
|
||||
let tl = self.subscription.timeline;
|
||||
match self.ws_tx.clone().ok_or(())?.try_send(Message::text(txt)) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => {
|
||||
self.unsubscribe_tx
|
||||
.try_send(tl)
|
||||
.unwrap_or_else(|e| log::error!("could not unsubscribe from channel: {}", e));
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue