mirror of https://github.com/mastodon/flodgatt
Improve module boundary/privacy
This commit is contained in:
parent
a7603739ee
commit
631e818998
|
@ -77,7 +77,7 @@ mod parse_inline {
|
|||
mod flodgatt_parse_event {
|
||||
use flodgatt::{messages::Event, redis_to_client_stream::receiver::MessageQueues};
|
||||
use flodgatt::{
|
||||
parse_client_request::subscription::Timeline,
|
||||
parse_client_request::Timeline,
|
||||
redis_to_client_stream::{receiver::MsgQueue, redis::redis_stream},
|
||||
};
|
||||
use lru::LruCache;
|
||||
|
@ -114,7 +114,7 @@ mod flodgatt_parse_event {
|
|||
/// the performance we would see if we used serde's built-in method for handling weakly
|
||||
/// typed JSON instead of our own strongly typed struct.
|
||||
mod flodgatt_parse_value {
|
||||
use flodgatt::{log_fatal, parse_client_request::subscription::Timeline};
|
||||
use flodgatt::{log_fatal, parse_client_request::Timeline};
|
||||
use lru::LruCache;
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
|
|
|
@ -11,14 +11,14 @@ from_env_var!(
|
|||
/// The current environment, which controls what file to read other ENV vars from
|
||||
let name = Env;
|
||||
let default: EnvInner = EnvInner::Development;
|
||||
let (env_var, allowed_values) = ("RUST_ENV", format!("one of: {:?}", EnvInner::variants()));
|
||||
let (env_var, allowed_values) = ("RUST_ENV", &format!("one of: {:?}", EnvInner::variants()));
|
||||
let from_str = |s| EnvInner::from_str(s).ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// The address to run Flodgatt on
|
||||
let name = FlodgattAddr;
|
||||
let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded"));
|
||||
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)".to_string());
|
||||
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)");
|
||||
let from_str = |s| match s {
|
||||
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
||||
_ => s.parse().ok(),
|
||||
|
@ -28,35 +28,35 @@ from_env_var!(
|
|||
/// How verbosely Flodgatt should log messages
|
||||
let name = LogLevel;
|
||||
let default: LogLevelInner = LogLevelInner::Warn;
|
||||
let (env_var, allowed_values) = ("RUST_LOG", format!("one of: {:?}", LogLevelInner::variants()));
|
||||
let (env_var, allowed_values) = ("RUST_LOG", &format!("one of: {:?}", LogLevelInner::variants()));
|
||||
let from_str = |s| LogLevelInner::from_str(s).ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// A Unix Socket to use in place of a local address
|
||||
let name = Socket;
|
||||
let default: Option<String> = None;
|
||||
let (env_var, allowed_values) = ("SOCKET", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("SOCKET", "any string");
|
||||
let from_str = |s| Some(Some(s.to_string()));
|
||||
);
|
||||
from_env_var!(
|
||||
/// The time between replies sent via WebSocket
|
||||
let name = WsInterval;
|
||||
let default: Duration = Duration::from_millis(100);
|
||||
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string());
|
||||
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds");
|
||||
let from_str = |s| s.parse().map(Duration::from_millis).ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// The time between replies sent via Server Sent Events
|
||||
let name = SseInterval;
|
||||
let default: Duration = Duration::from_millis(100);
|
||||
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string());
|
||||
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds");
|
||||
let from_str = |s| s.parse().map(Duration::from_millis).ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// The port to run Flodgatt on
|
||||
let name = Port;
|
||||
let default: u16 = 4000;
|
||||
let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535".to_string());
|
||||
let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535");
|
||||
let from_str = |s| s.parse().ok();
|
||||
);
|
||||
from_env_var!(
|
||||
|
@ -66,7 +66,7 @@ from_env_var!(
|
|||
/// (including otherwise public timelines).
|
||||
let name = WhitelistMode;
|
||||
let default: bool = false;
|
||||
let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false".to_string());
|
||||
let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false");
|
||||
let from_str = |s| s.parse().ok();
|
||||
);
|
||||
/// Permissions for Cross Origin Resource Sharing (CORS)
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
use std::{collections::HashMap, fmt};
|
||||
|
||||
pub struct EnvVar(pub HashMap<String, String>);
|
||||
impl std::ops::Deref for EnvVar {
|
||||
type Target = HashMap<String, String>;
|
||||
fn deref(&self) -> &HashMap<String, String> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for EnvVar {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
}
|
||||
}
|
||||
impl EnvVar {
|
||||
pub fn new(vars: HashMap<String, String>) -> Self {
|
||||
Self(vars)
|
||||
}
|
||||
|
||||
pub fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option<impl ToString>) {
|
||||
if let Some(value) = maybe_value {
|
||||
self.0.insert(key.to_string(), value.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn err(env_var: &str, supplied_value: &str, allowed_values: &str) -> ! {
|
||||
log::error!(
|
||||
r"{var} is set to `{value}`, which is invalid.
|
||||
{var} must be {allowed_vals}.",
|
||||
var = env_var,
|
||||
value = supplied_value,
|
||||
allowed_vals = allowed_values
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
impl fmt::Display for EnvVar {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let mut result = String::new();
|
||||
for env_var in [
|
||||
"NODE_ENV",
|
||||
"RUST_LOG",
|
||||
"BIND",
|
||||
"PORT",
|
||||
"SOCKET",
|
||||
"SSE_FREQ",
|
||||
"WS_FREQ",
|
||||
"DATABASE_URL",
|
||||
"DB_USER",
|
||||
"USER",
|
||||
"DB_PORT",
|
||||
"DB_HOST",
|
||||
"DB_PASS",
|
||||
"DB_NAME",
|
||||
"DB_SSLMODE",
|
||||
"REDIS_HOST",
|
||||
"REDIS_USER",
|
||||
"REDIS_PORT",
|
||||
"REDIS_PASSWORD",
|
||||
"REDIS_USER",
|
||||
"REDIS_DB",
|
||||
]
|
||||
.iter()
|
||||
{
|
||||
if let Some(value) = self.get(&env_var.to_string()) {
|
||||
result = format!("{}\n {}: {}", result, env_var, value)
|
||||
}
|
||||
}
|
||||
write!(f, "{}", result)
|
||||
}
|
||||
}
|
||||
#[macro_export]
|
||||
macro_rules! maybe_update {
|
||||
($name:ident; $item: tt:$type:ty) => (
|
||||
pub fn $name(self, item: Option<$type>) -> Self {
|
||||
match item {
|
||||
Some($item) => Self{ $item, ..self },
|
||||
None => Self { ..self }
|
||||
}
|
||||
});
|
||||
($name:ident; Some($item: tt: $type:ty)) => (
|
||||
fn $name(self, item: Option<$type>) -> Self{
|
||||
match item {
|
||||
Some($item) => Self{ $item: Some($item), ..self },
|
||||
None => Self { ..self }
|
||||
}
|
||||
})}
|
||||
#[macro_export]
|
||||
macro_rules! from_env_var {
|
||||
($(#[$outer:meta])*
|
||||
let name = $name:ident;
|
||||
let default: $type:ty = $inner:expr;
|
||||
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
|
||||
let from_str = |$arg:ident| $body:expr;
|
||||
) => {
|
||||
pub struct $name(pub $type);
|
||||
impl std::fmt::Debug for $name {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self.0)
|
||||
}
|
||||
}
|
||||
impl std::ops::Deref for $name {
|
||||
type Target = $type;
|
||||
fn deref(&self) -> &$type {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl std::default::Default for $name {
|
||||
fn default() -> Self {
|
||||
$name($inner)
|
||||
}
|
||||
}
|
||||
impl $name {
|
||||
fn inner_from_str($arg: &str) -> Option<$type> {
|
||||
$body
|
||||
}
|
||||
pub fn maybe_update(self, var: Option<&String>) -> Self {
|
||||
match var {
|
||||
Some(empty_string) if empty_string.is_empty() => Self::default(),
|
||||
Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
|
||||
crate::config::EnvVar::err($env_var, value, $allowed_values)
|
||||
})),
|
||||
None => self,
|
||||
}
|
||||
|
||||
// if let Some(value) = var {
|
||||
// Self(Self::inner_from_str(value).unwrap_or_else(|| {
|
||||
// crate::err::env_var_fatal($env_var, value, $allowed_values)
|
||||
// }))
|
||||
// } else {
|
||||
// self
|
||||
// }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
|
@ -4,135 +4,7 @@ mod postgres_cfg;
|
|||
mod postgres_cfg_types;
|
||||
mod redis_cfg;
|
||||
mod redis_cfg_types;
|
||||
pub use self::{
|
||||
deployment_cfg::DeploymentConfig,
|
||||
postgres_cfg::PostgresConfig,
|
||||
redis_cfg::RedisConfig,
|
||||
redis_cfg_types::{RedisInterval, RedisNamespace},
|
||||
};
|
||||
use std::{collections::HashMap, fmt};
|
||||
mod environmental_variables;
|
||||
|
||||
pub struct EnvVar(pub HashMap<String, String>);
|
||||
impl std::ops::Deref for EnvVar {
|
||||
type Target = HashMap<String, String>;
|
||||
fn deref(&self) -> &HashMap<String, String> {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
pub use {deployment_cfg::DeploymentConfig, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig, environmental_variables::EnvVar};
|
||||
|
||||
impl Clone for EnvVar {
|
||||
fn clone(&self) -> Self {
|
||||
Self(self.0.clone())
|
||||
}
|
||||
}
|
||||
impl EnvVar {
|
||||
pub fn new(vars: HashMap<String, String>) -> Self {
|
||||
Self(vars)
|
||||
}
|
||||
|
||||
fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option<impl ToString>) {
|
||||
if let Some(value) = maybe_value {
|
||||
self.0.insert(key.to_string(), value.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
impl fmt::Display for EnvVar {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let mut result = String::new();
|
||||
for env_var in [
|
||||
"NODE_ENV",
|
||||
"RUST_LOG",
|
||||
"BIND",
|
||||
"PORT",
|
||||
"SOCKET",
|
||||
"SSE_FREQ",
|
||||
"WS_FREQ",
|
||||
"DATABASE_URL",
|
||||
"DB_USER",
|
||||
"USER",
|
||||
"DB_PORT",
|
||||
"DB_HOST",
|
||||
"DB_PASS",
|
||||
"DB_NAME",
|
||||
"DB_SSLMODE",
|
||||
"REDIS_HOST",
|
||||
"REDIS_USER",
|
||||
"REDIS_PORT",
|
||||
"REDIS_PASSWORD",
|
||||
"REDIS_USER",
|
||||
"REDIS_DB",
|
||||
]
|
||||
.iter()
|
||||
{
|
||||
if let Some(value) = self.get(&env_var.to_string()) {
|
||||
result = format!("{}\n {}: {}", result, env_var, value)
|
||||
}
|
||||
}
|
||||
write!(f, "{}", result)
|
||||
}
|
||||
}
|
||||
#[macro_export]
|
||||
macro_rules! maybe_update {
|
||||
($name:ident; $item: tt:$type:ty) => (
|
||||
pub fn $name(self, item: Option<$type>) -> Self {
|
||||
match item {
|
||||
Some($item) => Self{ $item, ..self },
|
||||
None => Self { ..self }
|
||||
}
|
||||
});
|
||||
($name:ident; Some($item: tt: $type:ty)) => (
|
||||
fn $name(self, item: Option<$type>) -> Self{
|
||||
match item {
|
||||
Some($item) => Self{ $item: Some($item), ..self },
|
||||
None => Self { ..self }
|
||||
}
|
||||
})}
|
||||
#[macro_export]
|
||||
macro_rules! from_env_var {
|
||||
($(#[$outer:meta])*
|
||||
let name = $name:ident;
|
||||
let default: $type:ty = $inner:expr;
|
||||
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
|
||||
let from_str = |$arg:ident| $body:expr;
|
||||
) => {
|
||||
pub struct $name(pub $type);
|
||||
impl std::fmt::Debug for $name {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self.0)
|
||||
}
|
||||
}
|
||||
impl std::ops::Deref for $name {
|
||||
type Target = $type;
|
||||
fn deref(&self) -> &$type {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl std::default::Default for $name {
|
||||
fn default() -> Self {
|
||||
$name($inner)
|
||||
}
|
||||
}
|
||||
impl $name {
|
||||
fn inner_from_str($arg: &str) -> Option<$type> {
|
||||
$body
|
||||
}
|
||||
pub fn maybe_update(self, var: Option<&String>) -> Self {
|
||||
match var {
|
||||
Some(empty_string) if empty_string.is_empty() => Self::default(),
|
||||
Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
|
||||
crate::err::env_var_fatal($env_var, value, $allowed_values)
|
||||
})),
|
||||
None => self,
|
||||
}
|
||||
|
||||
// if let Some(value) = var {
|
||||
// Self(Self::inner_from_str(value).unwrap_or_else(|| {
|
||||
// crate::err::env_var_fatal($env_var, value, $allowed_values)
|
||||
// }))
|
||||
// } else {
|
||||
// self
|
||||
// }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ from_env_var!(
|
|||
/// The user to use for Postgres
|
||||
let name = PgUser;
|
||||
let default: String = "postgres".to_string();
|
||||
let (env_var, allowed_values) = ("DB_USER", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("DB_USER", "any string");
|
||||
let from_str = |s| Some(s.to_string());
|
||||
);
|
||||
|
||||
|
@ -14,7 +14,7 @@ from_env_var!(
|
|||
/// The host address where Postgres is running)
|
||||
let name = PgHost;
|
||||
let default: String = "localhost".to_string();
|
||||
let (env_var, allowed_values) = ("DB_HOST", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("DB_HOST", "any string");
|
||||
let from_str = |s| Some(s.to_string());
|
||||
);
|
||||
|
||||
|
@ -22,7 +22,7 @@ from_env_var!(
|
|||
/// The password to use with Postgress
|
||||
let name = PgPass;
|
||||
let default: Option<String> = None;
|
||||
let (env_var, allowed_values) = ("DB_PASS", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("DB_PASS", "any string");
|
||||
let from_str = |s| Some(Some(s.to_string()));
|
||||
);
|
||||
|
||||
|
@ -30,7 +30,7 @@ from_env_var!(
|
|||
/// The Postgres database to use
|
||||
let name = PgDatabase;
|
||||
let default: String = "mastodon_development".to_string();
|
||||
let (env_var, allowed_values) = ("DB_NAME", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("DB_NAME", "any string");
|
||||
let from_str = |s| Some(s.to_string());
|
||||
);
|
||||
|
||||
|
@ -38,14 +38,14 @@ from_env_var!(
|
|||
/// The port Postgres is running on
|
||||
let name = PgPort;
|
||||
let default: u16 = 5432;
|
||||
let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535".to_string());
|
||||
let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535");
|
||||
let from_str = |s| s.parse().ok();
|
||||
);
|
||||
|
||||
from_env_var!(
|
||||
let name = PgSslMode;
|
||||
let default: PgSslInner = PgSslInner::Prefer;
|
||||
let (env_var, allowed_values) = ("DB_SSLMODE", format!("one of: {:?}", PgSslInner::variants()));
|
||||
let (env_var, allowed_values) = ("DB_SSLMODE", &format!("one of: {:?}", PgSslInner::variants()));
|
||||
let from_str = |s| PgSslInner::from_str(s).ok();
|
||||
);
|
||||
|
||||
|
|
|
@ -7,48 +7,48 @@ from_env_var!(
|
|||
/// The host address where Redis is running
|
||||
let name = RedisHost;
|
||||
let default: String = "127.0.0.1".to_string();
|
||||
let (env_var, allowed_values) = ("REDIS_HOST", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("REDIS_HOST", "any string");
|
||||
let from_str = |s| Some(s.to_string());
|
||||
);
|
||||
from_env_var!(
|
||||
/// The port Redis is running on
|
||||
let name = RedisPort;
|
||||
let default: u16 = 6379;
|
||||
let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535".to_string());
|
||||
let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535");
|
||||
let from_str = |s| s.parse().ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// How frequently to poll Redis
|
||||
let name = RedisInterval;
|
||||
let default: Duration = Duration::from_millis(100);
|
||||
let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds".to_string());
|
||||
let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds");
|
||||
let from_str = |s| s.parse().map(Duration::from_millis).ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// The password to use for Redis
|
||||
let name = RedisPass;
|
||||
let default: Option<String> = None;
|
||||
let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string");
|
||||
let from_str = |s| Some(Some(s.to_string()));
|
||||
);
|
||||
from_env_var!(
|
||||
/// An optional Redis Namespace
|
||||
let name = RedisNamespace;
|
||||
let default: Option<String> = None;
|
||||
let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string");
|
||||
let from_str = |s| Some(Some(s.to_string()));
|
||||
);
|
||||
from_env_var!(
|
||||
/// A user for Redis (not supported)
|
||||
let name = RedisUser;
|
||||
let default: Option<String> = None;
|
||||
let (env_var, allowed_values) = ("REDIS_USER", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("REDIS_USER", "any string");
|
||||
let from_str = |s| Some(Some(s.to_string()));
|
||||
);
|
||||
from_env_var!(
|
||||
/// The database to use with Redis (no current effect for PubSub connections)
|
||||
let name = RedisDb;
|
||||
let default: Option<String> = None;
|
||||
let (env_var, allowed_values) = ("REDIS_DB", "any string".to_string());
|
||||
let (env_var, allowed_values) = ("REDIS_DB", "any string");
|
||||
let from_str = |s| Some(Some(s.to_string()));
|
||||
);
|
||||
|
|
66
src/err.rs
66
src/err.rs
|
@ -1,4 +1,3 @@
|
|||
use serde_derive::Serialize;
|
||||
use std::fmt::Display;
|
||||
|
||||
pub fn die_with_msg(msg: impl Display) -> ! {
|
||||
|
@ -13,68 +12,3 @@ macro_rules! log_fatal {
|
|||
panic!();
|
||||
};};
|
||||
}
|
||||
|
||||
pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! {
|
||||
eprintln!(
|
||||
r"FATAL ERROR: {var} is set to `{value}`, which is invalid.
|
||||
{var} must be {allowed_vals}.",
|
||||
var = env_var,
|
||||
value = supplied_value,
|
||||
allowed_vals = allowed_values
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! dbg_and_die {
|
||||
($msg:expr) => {
|
||||
let message = format!("FATAL ERROR: {}", $msg);
|
||||
dbg!(message);
|
||||
std::process::exit(1);
|
||||
};
|
||||
}
|
||||
pub fn unwrap_or_die<T>(s: Option<T>, msg: &str) -> T {
|
||||
s.unwrap_or_else(|| {
|
||||
eprintln!("FATAL ERROR: {}", msg);
|
||||
std::process::exit(1)
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ErrorMessage {
|
||||
error: String,
|
||||
}
|
||||
impl ErrorMessage {
|
||||
fn new(msg: impl std::fmt::Display) -> Self {
|
||||
Self {
|
||||
error: msg.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Recover from Errors by sending appropriate Warp::Rejections
|
||||
pub fn handle_errors(
|
||||
rejection: warp::reject::Rejection,
|
||||
) -> Result<impl warp::Reply, warp::reject::Rejection> {
|
||||
let err_txt = match rejection.cause() {
|
||||
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
|
||||
"Error: Missing access token".to_string()
|
||||
}
|
||||
Some(text) => text.to_string(),
|
||||
None => "Error: Nonexistant endpoint".to_string(),
|
||||
};
|
||||
let json = warp::reply::json(&ErrorMessage::new(err_txt));
|
||||
|
||||
Ok(warp::reply::with_status(
|
||||
json,
|
||||
warp::http::StatusCode::UNAUTHORIZED,
|
||||
))
|
||||
}
|
||||
|
||||
pub struct CustomError {}
|
||||
|
||||
impl CustomError {
|
||||
pub fn unauthorized_list() -> warp::reject::Rejection {
|
||||
warp::reject::custom("Error: Access to list not authorized")
|
||||
}
|
||||
}
|
||||
|
|
102
src/main.rs
102
src/main.rs
|
@ -1,32 +1,32 @@
|
|||
use flodgatt::{
|
||||
config, err,
|
||||
parse_client_request::{sse, subscription, ws},
|
||||
redis_to_client_stream::{self, ClientAgent},
|
||||
config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig},
|
||||
parse_client_request::{PgPool, Subscription},
|
||||
redis_to_client_stream::{ClientAgent, EventStream},
|
||||
};
|
||||
use std::{collections::HashMap, env, fs, net, os::unix::fs::PermissionsExt};
|
||||
use tokio::net::UnixListener;
|
||||
use warp::{path, ws::Ws2, Filter};
|
||||
|
||||
fn main() {
|
||||
dotenv::from_filename(
|
||||
match env::var("ENV").ok().as_ref().map(String::as_str) {
|
||||
dotenv::from_filename(match env::var("ENV").ok().as_ref().map(String::as_str) {
|
||||
Some("production") => ".env.production",
|
||||
Some("development") | None => ".env",
|
||||
Some(_) => err::die_with_msg("Unknown ENV variable specified.\n Valid options are: `production` or `development`."),
|
||||
}).ok();
|
||||
Some(unsupported) => EnvVar::err("ENV", unsupported, "`production` or `development`"),
|
||||
})
|
||||
.ok();
|
||||
let env_vars_map: HashMap<_, _> = dotenv::vars().collect();
|
||||
let env_vars = config::EnvVar::new(env_vars_map);
|
||||
let env_vars = EnvVar::new(env_vars_map);
|
||||
pretty_env_logger::init();
|
||||
|
||||
log::info!(
|
||||
"Flodgatt recognized the following environmental variables:{}",
|
||||
env_vars.clone()
|
||||
);
|
||||
let redis_cfg = config::RedisConfig::from_env(env_vars.clone());
|
||||
let cfg = config::DeploymentConfig::from_env(env_vars.clone());
|
||||
let redis_cfg = RedisConfig::from_env(env_vars.clone());
|
||||
let cfg = DeploymentConfig::from_env(env_vars.clone());
|
||||
|
||||
let postgres_cfg = config::PostgresConfig::from_env(env_vars.clone());
|
||||
let pg_pool = subscription::PgPool::new(postgres_cfg);
|
||||
let postgres_cfg = PostgresConfig::from_env(env_vars.clone());
|
||||
let pg_pool = PgPool::new(postgres_cfg);
|
||||
|
||||
let client_agent_sse = ClientAgent::blank(redis_cfg);
|
||||
let client_agent_ws = client_agent_sse.clone_with_shared_receiver();
|
||||
|
@ -34,54 +34,43 @@ fn main() {
|
|||
log::info!("Streaming server initialized and ready to accept connections");
|
||||
|
||||
// Server Sent Events
|
||||
let (sse_update_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode);
|
||||
let sse_routes = sse::extract_user_or_reject(pg_pool.clone(), whitelist_mode)
|
||||
let (sse_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode);
|
||||
let sse_routes = Subscription::from_sse_query(pg_pool.clone(), whitelist_mode)
|
||||
.and(warp::sse())
|
||||
.map(
|
||||
move |subscription: subscription::Subscription,
|
||||
sse_connection_to_client: warp::sse::Sse| {
|
||||
move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| {
|
||||
log::info!("Incoming SSE request for {:?}", subscription.timeline);
|
||||
// Create a new ClientAgent
|
||||
let mut client_agent = client_agent_sse.clone_with_shared_receiver();
|
||||
// Assign ClientAgent to generate stream of updates for the user/timeline pair
|
||||
client_agent.init_for_user(subscription);
|
||||
// send the updates through the SSE connection
|
||||
redis_to_client_stream::send_updates_to_sse(
|
||||
client_agent,
|
||||
sse_connection_to_client,
|
||||
sse_update_interval,
|
||||
)
|
||||
EventStream::to_sse(client_agent, sse_connection_to_client, sse_interval)
|
||||
},
|
||||
)
|
||||
.with(warp::reply::with::header("Connection", "keep-alive"))
|
||||
.recover(err::handle_errors);
|
||||
.with(warp::reply::with::header("Connection", "keep-alive"));
|
||||
|
||||
// WebSocket
|
||||
let (ws_update_interval, whitelist_mode) = (*cfg.ws_interval, *cfg.whitelist_mode);
|
||||
let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone(), whitelist_mode)
|
||||
let websocket_routes = Subscription::from_ws_request(pg_pool.clone(), whitelist_mode)
|
||||
.and(warp::ws::ws2())
|
||||
.map(
|
||||
move |subscription: subscription::Subscription, token: Option<String>, ws: Ws2| {
|
||||
log::info!("Incoming websocket request for {:?}", subscription.timeline);
|
||||
// Create a new ClientAgent
|
||||
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
|
||||
// Assign that agent to generate a stream of updates for the user/timeline pair
|
||||
client_agent.init_for_user(subscription);
|
||||
// send the updates through the WS connection (along with the User's access_token
|
||||
// which is sent for security)
|
||||
.map(move |subscription: Subscription, ws: Ws2| {
|
||||
log::info!("Incoming websocket request for {:?}", subscription.timeline);
|
||||
|
||||
(
|
||||
ws.on_upgrade(move |socket| {
|
||||
redis_to_client_stream::send_updates_to_ws(
|
||||
socket,
|
||||
client_agent,
|
||||
ws_update_interval,
|
||||
)
|
||||
}),
|
||||
token.unwrap_or_else(String::new),
|
||||
)
|
||||
},
|
||||
)
|
||||
let token = subscription.access_token.clone();
|
||||
// Create a new ClientAgent
|
||||
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
|
||||
// Assign that agent to generate a stream of updates for the user/timeline pair
|
||||
client_agent.init_for_user(subscription);
|
||||
// send the updates through the WS connection (along with the User's access_token
|
||||
// which is sent for security)
|
||||
(
|
||||
ws.on_upgrade(move |socket| {
|
||||
EventStream::to_ws(socket, client_agent, ws_update_interval)
|
||||
}),
|
||||
token.unwrap_or_else(String::new),
|
||||
)
|
||||
})
|
||||
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
|
||||
|
||||
let cors = warp::cors()
|
||||
|
@ -98,7 +87,28 @@ fn main() {
|
|||
|
||||
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).unwrap();
|
||||
|
||||
warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run_incoming(incoming);
|
||||
warp::serve(
|
||||
health.or(websocket_routes.or(sse_routes).with(cors).recover(
|
||||
|rejection: warp::reject::Rejection| {
|
||||
let err_txt = match rejection.cause() {
|
||||
Some(text)
|
||||
if text.to_string() == "Missing request header 'authorization'" =>
|
||||
{
|
||||
"Error: Missing access token".to_string()
|
||||
}
|
||||
Some(text) => text.to_string(),
|
||||
None => "Error: Nonexistant endpoint".to_string(),
|
||||
};
|
||||
let json = warp::reply::json(&err_txt);
|
||||
|
||||
Ok(warp::reply::with_status(
|
||||
json,
|
||||
warp::http::StatusCode::UNAUTHORIZED,
|
||||
))
|
||||
},
|
||||
)),
|
||||
)
|
||||
.run_incoming(incoming);
|
||||
} else {
|
||||
let server_addr = net::SocketAddr::new(*cfg.address, cfg.port.0);
|
||||
warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run(server_addr);
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
//! Parse the client request and return a (possibly authenticated) `User`
|
||||
pub mod query;
|
||||
pub mod sse;
|
||||
pub mod subscription;
|
||||
pub mod ws;
|
||||
//! Parse the client request and return a Subscription
|
||||
mod postgres;
|
||||
mod query;
|
||||
mod sse;
|
||||
mod subscription;
|
||||
mod ws;
|
||||
|
||||
pub use subscription::{Stream, Timeline};
|
||||
pub use self::postgres::PgPool;
|
||||
pub use subscription::{Stream, Subscription, Timeline};
|
||||
|
|
|
@ -0,0 +1,204 @@
|
|||
//! Postgres queries
|
||||
use crate::{
|
||||
config,
|
||||
parse_client_request::subscription::{Scope, UserData},
|
||||
};
|
||||
use ::postgres;
|
||||
use r2d2_postgres::PostgresConnectionManager;
|
||||
use std::collections::HashSet;
|
||||
use warp::reject::Rejection;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PgPool(pub r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>);
|
||||
impl PgPool {
|
||||
pub fn new(pg_cfg: config::PostgresConfig) -> Self {
|
||||
let mut cfg = postgres::Config::new();
|
||||
cfg.user(&pg_cfg.user)
|
||||
.host(&*pg_cfg.host.to_string())
|
||||
.port(*pg_cfg.port)
|
||||
.dbname(&pg_cfg.database);
|
||||
if let Some(password) = &*pg_cfg.password {
|
||||
cfg.password(password);
|
||||
};
|
||||
|
||||
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
|
||||
let pool = r2d2::Pool::builder()
|
||||
.max_size(10)
|
||||
.build(manager)
|
||||
.expect("Can connect to local postgres");
|
||||
Self(pool)
|
||||
}
|
||||
pub fn select_user(self, token: &str) -> Result<UserData, Rejection> {
|
||||
let mut conn = self.0.get().unwrap();
|
||||
let query_rows = conn
|
||||
.query(
|
||||
"
|
||||
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
|
||||
FROM
|
||||
oauth_access_tokens
|
||||
INNER JOIN users ON
|
||||
oauth_access_tokens.resource_owner_id = users.id
|
||||
WHERE oauth_access_tokens.token = $1
|
||||
AND oauth_access_tokens.revoked_at IS NULL
|
||||
LIMIT 1",
|
||||
&[&token.to_owned()],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
if let Some(result_columns) = query_rows.get(0) {
|
||||
let id = result_columns.get(1);
|
||||
let allowed_langs = result_columns
|
||||
.try_get::<_, Vec<_>>(2)
|
||||
.unwrap_or_else(|_| Vec::new())
|
||||
.into_iter()
|
||||
.collect();
|
||||
let mut scopes: HashSet<Scope> = result_columns
|
||||
.get::<_, String>(3)
|
||||
.split(' ')
|
||||
.filter_map(|scope| match scope {
|
||||
"read" => Some(Scope::Read),
|
||||
"read:statuses" => Some(Scope::Statuses),
|
||||
"read:notifications" => Some(Scope::Notifications),
|
||||
"read:lists" => Some(Scope::Lists),
|
||||
"write" | "follow" => None, // ignore write scopes
|
||||
unexpected => {
|
||||
log::warn!("Ignoring unknown scope `{}`", unexpected);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
// We don't need to separately track read auth - it's just all three others
|
||||
if scopes.remove(&Scope::Read) {
|
||||
scopes.insert(Scope::Statuses);
|
||||
scopes.insert(Scope::Notifications);
|
||||
scopes.insert(Scope::Lists);
|
||||
}
|
||||
|
||||
Ok(UserData {
|
||||
id,
|
||||
allowed_langs,
|
||||
scopes,
|
||||
})
|
||||
} else {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_hashtag_id(self, tag_name: &String) -> Result<i64, Rejection> {
|
||||
let mut conn = self.0.get().unwrap();
|
||||
let rows = &conn
|
||||
.query(
|
||||
"
|
||||
SELECT id
|
||||
FROM tags
|
||||
WHERE name = $1
|
||||
LIMIT 1",
|
||||
&[&tag_name],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
|
||||
match rows.get(0) {
|
||||
Some(row) => Ok(row.get(0)),
|
||||
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Query Postgres for everyone the user has blocked or muted
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_users(self, user_id: i64) -> HashSet<i64> {
|
||||
// "
|
||||
// SELECT
|
||||
// 1
|
||||
// FROM blocks
|
||||
// WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
|
||||
// OR (account_id = $2 AND target_account_id = $1)
|
||||
// UNION SELECT
|
||||
// 1
|
||||
// FROM mutes
|
||||
// WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
|
||||
// , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
|
||||
self
|
||||
.0
|
||||
.get()
|
||||
.unwrap()
|
||||
.query(
|
||||
"
|
||||
SELECT target_account_id
|
||||
FROM blocks
|
||||
WHERE account_id = $1
|
||||
UNION SELECT target_account_id
|
||||
FROM mutes
|
||||
WHERE account_id = $1",
|
||||
&[&user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| row.get(0))
|
||||
.collect()
|
||||
}
|
||||
/// Query Postgres for everyone who has blocked the user
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocking_users(self, user_id: i64) -> HashSet<i64> {
|
||||
self
|
||||
.0
|
||||
.get()
|
||||
.unwrap()
|
||||
.query(
|
||||
"
|
||||
SELECT account_id
|
||||
FROM blocks
|
||||
WHERE target_account_id = $1",
|
||||
&[&user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| row.get(0))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Query Postgres for all current domain blocks
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_domains(self, user_id: i64) -> HashSet<String> {
|
||||
self
|
||||
.0
|
||||
.get()
|
||||
.unwrap()
|
||||
.query(
|
||||
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
|
||||
&[&user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| row.get(0))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Test whether a user owns a list
|
||||
pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool {
|
||||
let mut conn = self.0.get().unwrap();
|
||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||
let rows = &conn
|
||||
.query(
|
||||
"
|
||||
SELECT id, account_id
|
||||
FROM lists
|
||||
WHERE id = $1
|
||||
LIMIT 1",
|
||||
&[&list_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
|
||||
match rows.get(0) {
|
||||
None => false,
|
||||
Some(row) => {
|
||||
let list_owner_id: i64 = row.get(1);
|
||||
list_owner_id == user_id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -28,6 +28,12 @@ impl Query {
|
|||
}
|
||||
|
||||
macro_rules! make_query_type {
|
||||
(Stream => $parameter:tt:$type:ty) => {
|
||||
#[derive(Deserialize, Debug, Default)]
|
||||
pub struct Stream {
|
||||
pub $parameter: $type,
|
||||
}
|
||||
};
|
||||
($name:tt => $parameter:tt:$type:ty) => {
|
||||
#[derive(Deserialize, Debug, Default)]
|
||||
pub struct $name {
|
||||
|
@ -59,14 +65,14 @@ impl ToString for Stream {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn optional_media_query() -> BoxedFilter<(Media,)> {
|
||||
warp::query()
|
||||
.or(warp::any().map(|| Media {
|
||||
only_media: "false".to_owned(),
|
||||
}))
|
||||
.unify()
|
||||
.boxed()
|
||||
}
|
||||
// pub fn optional_media_query() -> BoxedFilter<(Media,)> {
|
||||
// warp::query()
|
||||
// .or(warp::any().map(|| Media {
|
||||
// only_media: "false".to_owned(),
|
||||
// }))
|
||||
// .unify()
|
||||
// .boxed()
|
||||
// }
|
||||
|
||||
pub struct OptionalAccessToken;
|
||||
|
||||
|
|
|
@ -1,78 +1,4 @@
|
|||
//! Filters for all the endpoints accessible for Server Sent Event updates
|
||||
use super::{
|
||||
query::{self, Query},
|
||||
subscription::{PgPool, Subscription},
|
||||
};
|
||||
use warp::{filters::BoxedFilter, path, Filter};
|
||||
#[allow(dead_code)]
|
||||
type TimelineUser = ((String, Subscription),);
|
||||
|
||||
/// Helper macro to match on the first of any of the provided filters
|
||||
macro_rules! any_of {
|
||||
($filter:expr, $($other_filter:expr),*) => {
|
||||
$filter$(.or($other_filter).unify())*.boxed()
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! parse_query {
|
||||
(path => $start:tt $(/ $next:tt)*
|
||||
endpoint => $endpoint:expr) => {
|
||||
path!($start $(/ $next)*)
|
||||
.and(query::Auth::to_filter())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.map(
|
||||
|auth: query::Auth,
|
||||
media: query::Media,
|
||||
hashtag: query::Hashtag,
|
||||
list: query::List| {
|
||||
Query {
|
||||
access_token: auth.access_token,
|
||||
stream: $endpoint.to_string(),
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
}
|
||||
},
|
||||
)
|
||||
.boxed()
|
||||
};
|
||||
}
|
||||
pub fn extract_user_or_reject(
|
||||
pg_pool: PgPool,
|
||||
whitelist_mode: bool,
|
||||
) -> BoxedFilter<(Subscription,)> {
|
||||
any_of!(
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "user" / "notification"
|
||||
endpoint => "user:notification" ),
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "user"
|
||||
endpoint => "user"),
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "public" / "local"
|
||||
endpoint => "public:local"),
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "public"
|
||||
endpoint => "public"),
|
||||
parse_query!(
|
||||
path => "api" / "v1" / "streaming" / "direct"
|
||||
endpoint => "direct"),
|
||||
parse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
|
||||
endpoint => "hashtag:local"),
|
||||
parse_query!(path => "api" / "v1" / "streaming" / "hashtag"
|
||||
endpoint => "hashtag"),
|
||||
parse_query!(path => "api" / "v1" / "streaming" / "list"
|
||||
endpoint => "list")
|
||||
)
|
||||
// because SSE requests place their `access_token` in the header instead of in a query
|
||||
// parameter, we need to update our Query if the header has a token
|
||||
.and(query::OptionalAccessToken::from_sse_header())
|
||||
.and_then(Query::update_access_token)
|
||||
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod test {
|
||||
|
|
|
@ -4,20 +4,55 @@
|
|||
// #[cfg(test)]
|
||||
// use mock_postgres as postgres;
|
||||
// #[cfg(not(test))]
|
||||
pub mod postgres;
|
||||
pub use self::postgres::PgPool;
|
||||
|
||||
use super::postgres::PgPool;
|
||||
use super::query::Query;
|
||||
use crate::log_fatal;
|
||||
use std::collections::HashSet;
|
||||
use warp::reject::Rejection;
|
||||
|
||||
/// The User (with data read from Postgres)
|
||||
use super::query;
|
||||
use warp::{filters::BoxedFilter, path, Filter};
|
||||
|
||||
/// Helper macro to match on the first of any of the provided filters
|
||||
macro_rules! any_of {
|
||||
($filter:expr, $($other_filter:expr),*) => {
|
||||
$filter$(.or($other_filter).unify())*.boxed()
|
||||
};
|
||||
}
|
||||
macro_rules! parse_sse_query {
|
||||
(path => $start:tt $(/ $next:tt)*
|
||||
endpoint => $endpoint:expr) => {
|
||||
path!($start $(/ $next)*)
|
||||
.and(query::Auth::to_filter())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.map(
|
||||
|auth: query::Auth,
|
||||
media: query::Media,
|
||||
hashtag: query::Hashtag,
|
||||
list: query::List| {
|
||||
Query {
|
||||
access_token: auth.access_token,
|
||||
stream: $endpoint.to_string(),
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
}
|
||||
},
|
||||
)
|
||||
.boxed()
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct Subscription {
|
||||
pub timeline: Timeline,
|
||||
pub allowed_langs: HashSet<String>,
|
||||
pub blocks: Blocks,
|
||||
pub hashtag_name: Option<String>,
|
||||
pub access_token: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for Subscription {
|
||||
|
@ -27,14 +62,54 @@ impl Default for Subscription {
|
|||
allowed_langs: HashSet::new(),
|
||||
blocks: Blocks::default(),
|
||||
hashtag_name: None,
|
||||
access_token: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Subscription {
|
||||
pub fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result<Self, Rejection> {
|
||||
pub fn from_ws_request(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
|
||||
parse_ws_query()
|
||||
.and(query::OptionalAccessToken::from_ws_header())
|
||||
.and_then(Query::update_access_token)
|
||||
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub fn from_sse_query(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
|
||||
any_of!(
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "user" / "notification"
|
||||
endpoint => "user:notification" ),
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "user"
|
||||
endpoint => "user"),
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "public" / "local"
|
||||
endpoint => "public:local"),
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "public"
|
||||
endpoint => "public"),
|
||||
parse_sse_query!(
|
||||
path => "api" / "v1" / "streaming" / "direct"
|
||||
endpoint => "direct"),
|
||||
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
|
||||
endpoint => "hashtag:local"),
|
||||
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag"
|
||||
endpoint => "hashtag"),
|
||||
parse_sse_query!(path => "api" / "v1" / "streaming" / "list"
|
||||
endpoint => "list")
|
||||
)
|
||||
// because SSE requests place their `access_token` in the header instead of in a query
|
||||
// parameter, we need to update our Query if the header has a token
|
||||
.and(query::OptionalAccessToken::from_sse_header())
|
||||
.and_then(Query::update_access_token)
|
||||
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
|
||||
.boxed()
|
||||
}
|
||||
fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result<Self, Rejection> {
|
||||
let user = match q.access_token.clone() {
|
||||
Some(token) => postgres::select_user(&token, pool.clone())?,
|
||||
Some(token) => pool.clone().select_user(&token)?,
|
||||
None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?,
|
||||
None => UserData::public(),
|
||||
};
|
||||
|
@ -48,15 +123,42 @@ impl Subscription {
|
|||
timeline,
|
||||
allowed_langs: user.allowed_langs,
|
||||
blocks: Blocks {
|
||||
blocking_users: postgres::select_blocking_users(user.id, pool.clone()),
|
||||
blocked_users: postgres::select_blocked_users(user.id, pool.clone()),
|
||||
blocked_domains: postgres::select_blocked_domains(user.id, pool.clone()),
|
||||
blocking_users: pool.clone().select_blocking_users(user.id),
|
||||
blocked_users: pool.clone().select_blocked_users(user.id),
|
||||
blocked_domains: pool.clone().select_blocked_domains(user.id),
|
||||
},
|
||||
hashtag_name,
|
||||
access_token: q.access_token,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_ws_query() -> BoxedFilter<(Query,)> {
|
||||
path!("api" / "v1" / "streaming")
|
||||
.and(path::end())
|
||||
.and(warp::query())
|
||||
.and(query::Auth::to_filter())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.map(
|
||||
|stream: query::Stream,
|
||||
auth: query::Auth,
|
||||
media: query::Media,
|
||||
hashtag: query::Hashtag,
|
||||
list: query::List| {
|
||||
Query {
|
||||
access_token: auth.access_token,
|
||||
stream: stream.stream,
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
}
|
||||
},
|
||||
)
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
||||
pub struct Timeline(pub Stream, pub Reach, pub Content);
|
||||
|
||||
|
@ -111,8 +213,8 @@ impl Timeline {
|
|||
}
|
||||
fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result<Self, Rejection> {
|
||||
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
|
||||
let id_from_hashtag = || postgres::select_hashtag_id(&q.hashtag, pool.clone());
|
||||
let user_owns_list = || postgres::user_owns_list(user.id, q.list, pool.clone());
|
||||
let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag);
|
||||
let user_owns_list = || pool.clone().user_owns_list(user.id, q.list);
|
||||
|
||||
Ok(match q.stream.as_ref() {
|
||||
"public" => match q.media {
|
||||
|
@ -189,9 +291,9 @@ pub struct Blocks {
|
|||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct UserData {
|
||||
id: i64,
|
||||
allowed_langs: HashSet<String>,
|
||||
scopes: HashSet<Scope>,
|
||||
pub id: i64,
|
||||
pub allowed_langs: HashSet<String>,
|
||||
pub scopes: HashSet<Scope>,
|
||||
}
|
||||
|
||||
impl UserData {
|
|
@ -1,43 +0,0 @@
|
|||
//! Mock Postgres connection (for use in unit testing)
|
||||
use super::{OauthScope, Subscription};
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PgPool;
|
||||
impl PgPool {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_user(
|
||||
access_token: &str,
|
||||
_pg_pool: PgPool,
|
||||
) -> Result<Subscription, warp::reject::Rejection> {
|
||||
let mut user = Subscription::default();
|
||||
if access_token == "TEST_USER" {
|
||||
user.id = 1;
|
||||
user.logged_in = true;
|
||||
user.access_token = "TEST_USER".to_string();
|
||||
user.email = "user@example.com".to_string();
|
||||
user.scopes = OauthScope::from(vec![
|
||||
"read".to_string(),
|
||||
"write".to_string(),
|
||||
"follow".to_string(),
|
||||
]);
|
||||
} else if access_token == "INVALID" {
|
||||
return Err(warp::reject::custom("Error: Invalid access token"));
|
||||
}
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub fn select_user_blocks(_id: i64, _pg_pool: PgPool) -> HashSet<i64> {
|
||||
HashSet::new()
|
||||
}
|
||||
pub fn select_domain_blocks(_pg_pool: PgPool) -> HashSet<String> {
|
||||
HashSet::new()
|
||||
}
|
||||
|
||||
pub fn user_owns_list(user_id: i64, list_id: i64, _pg_pool: PgPool) -> bool {
|
||||
user_id == list_id
|
||||
}
|
|
@ -1,224 +0,0 @@
|
|||
//! Postgres queries
|
||||
use crate::{
|
||||
config,
|
||||
parse_client_request::subscription::{Scope, UserData},
|
||||
};
|
||||
use ::postgres;
|
||||
use r2d2_postgres::PostgresConnectionManager;
|
||||
use std::collections::HashSet;
|
||||
use warp::reject::Rejection;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PgPool(pub r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>);
|
||||
impl PgPool {
|
||||
pub fn new(pg_cfg: config::PostgresConfig) -> Self {
|
||||
let mut cfg = postgres::Config::new();
|
||||
cfg.user(&pg_cfg.user)
|
||||
.host(&*pg_cfg.host.to_string())
|
||||
.port(*pg_cfg.port)
|
||||
.dbname(&pg_cfg.database);
|
||||
if let Some(password) = &*pg_cfg.password {
|
||||
cfg.password(password);
|
||||
};
|
||||
|
||||
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
|
||||
let pool = r2d2::Pool::builder()
|
||||
.max_size(10)
|
||||
.build(manager)
|
||||
.expect("Can connect to local postgres");
|
||||
Self(pool)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_user(token: &str, pool: PgPool) -> Result<UserData, Rejection> {
|
||||
let mut conn = pool.0.get().unwrap();
|
||||
let query_rows = conn
|
||||
.query(
|
||||
"
|
||||
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
|
||||
FROM
|
||||
oauth_access_tokens
|
||||
INNER JOIN users ON
|
||||
oauth_access_tokens.resource_owner_id = users.id
|
||||
WHERE oauth_access_tokens.token = $1
|
||||
AND oauth_access_tokens.revoked_at IS NULL
|
||||
LIMIT 1",
|
||||
&[&token.to_owned()],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
if let Some(result_columns) = query_rows.get(0) {
|
||||
let id = result_columns.get(1);
|
||||
let allowed_langs = result_columns
|
||||
.try_get::<_, Vec<_>>(2)
|
||||
.unwrap_or_else(|_| Vec::new())
|
||||
.into_iter()
|
||||
.collect();
|
||||
let mut scopes: HashSet<Scope> = result_columns
|
||||
.get::<_, String>(3)
|
||||
.split(' ')
|
||||
.filter_map(|scope| match scope {
|
||||
"read" => Some(Scope::Read),
|
||||
"read:statuses" => Some(Scope::Statuses),
|
||||
"read:notifications" => Some(Scope::Notifications),
|
||||
"read:lists" => Some(Scope::Lists),
|
||||
"write" | "follow" => None, // ignore write scopes
|
||||
unexpected => {
|
||||
log::warn!("Ignoring unknown scope `{}`", unexpected);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
// We don't need to separately track read auth - it's just all three others
|
||||
if scopes.remove(&Scope::Read) {
|
||||
scopes.insert(Scope::Statuses);
|
||||
scopes.insert(Scope::Notifications);
|
||||
scopes.insert(Scope::Lists);
|
||||
}
|
||||
|
||||
Ok(UserData {
|
||||
id,
|
||||
allowed_langs,
|
||||
scopes,
|
||||
})
|
||||
} else {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_hashtag_id(tag_name: &String, pg_pool: PgPool) -> Result<i64, Rejection> {
|
||||
let mut conn = pg_pool.0.get().unwrap();
|
||||
let rows = &conn
|
||||
.query(
|
||||
"
|
||||
SELECT id
|
||||
FROM tags
|
||||
WHERE name = $1
|
||||
LIMIT 1",
|
||||
&[&tag_name],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
|
||||
match rows.get(0) {
|
||||
Some(row) => Ok(row.get(0)),
|
||||
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
|
||||
}
|
||||
}
|
||||
pub fn select_hashtag_name(tag_id: &i64, pg_pool: PgPool) -> Result<String, Rejection> {
|
||||
let mut conn = pg_pool.0.get().unwrap();
|
||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||
let rows = &conn
|
||||
.query(
|
||||
"
|
||||
SELECT name
|
||||
FROM tags
|
||||
WHERE id = $1
|
||||
LIMIT 1",
|
||||
&[&tag_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
|
||||
match rows.get(0) {
|
||||
Some(row) => Ok(row.get(0)),
|
||||
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Query Postgres for everyone the user has blocked or muted
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_users(user_id: i64, pg_pool: PgPool) -> HashSet<i64> {
|
||||
// "
|
||||
// SELECT
|
||||
// 1
|
||||
// FROM blocks
|
||||
// WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
|
||||
// OR (account_id = $2 AND target_account_id = $1)
|
||||
// UNION SELECT
|
||||
// 1
|
||||
// FROM mutes
|
||||
// WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
|
||||
// , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
|
||||
pg_pool
|
||||
.0
|
||||
.get()
|
||||
.unwrap()
|
||||
.query(
|
||||
"
|
||||
SELECT target_account_id
|
||||
FROM blocks
|
||||
WHERE account_id = $1
|
||||
UNION SELECT target_account_id
|
||||
FROM mutes
|
||||
WHERE account_id = $1",
|
||||
&[&user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| row.get(0))
|
||||
.collect()
|
||||
}
|
||||
/// Query Postgres for everyone who has blocked the user
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocking_users(user_id: i64, pg_pool: PgPool) -> HashSet<i64> {
|
||||
pg_pool
|
||||
.0
|
||||
.get()
|
||||
.unwrap()
|
||||
.query(
|
||||
"
|
||||
SELECT account_id
|
||||
FROM blocks
|
||||
WHERE target_account_id = $1",
|
||||
&[&user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| row.get(0))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Query Postgres for all current domain blocks
|
||||
///
|
||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||
/// the user adds until they refresh/reconnect.
|
||||
pub fn select_blocked_domains(user_id: i64, pg_pool: PgPool) -> HashSet<String> {
|
||||
pg_pool
|
||||
.0
|
||||
.get()
|
||||
.unwrap()
|
||||
.query(
|
||||
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
|
||||
&[&user_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||
.iter()
|
||||
.map(|row| row.get(0))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Test whether a user owns a list
|
||||
pub fn user_owns_list(user_id: i64, list_id: i64, pg_pool: PgPool) -> bool {
|
||||
let mut conn = pg_pool.0.get().unwrap();
|
||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||
let rows = &conn
|
||||
.query(
|
||||
"
|
||||
SELECT id, account_id
|
||||
FROM lists
|
||||
WHERE id = $1
|
||||
LIMIT 1",
|
||||
&[&list_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
|
||||
match rows.get(0) {
|
||||
None => false,
|
||||
Some(row) => {
|
||||
let list_owner_id: i64 = row.get(1);
|
||||
list_owner_id == user_id
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,48 +1,9 @@
|
|||
//! Filters for the WebSocket endpoint
|
||||
use super::{
|
||||
query::{self, Query},
|
||||
subscription::{PgPool, Subscription},
|
||||
};
|
||||
use warp::{filters::BoxedFilter, path, Filter};
|
||||
|
||||
/// WebSocket filters
|
||||
fn parse_query() -> BoxedFilter<(Query,)> {
|
||||
path!("api" / "v1" / "streaming")
|
||||
.and(path::end())
|
||||
.and(warp::query())
|
||||
.and(query::Auth::to_filter())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.map(
|
||||
|stream: query::Stream,
|
||||
auth: query::Auth,
|
||||
media: query::Media,
|
||||
hashtag: query::Hashtag,
|
||||
list: query::List| {
|
||||
Query {
|
||||
access_token: auth.access_token,
|
||||
stream: stream.stream,
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
}
|
||||
},
|
||||
)
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub fn extract_user_and_token_or_reject(
|
||||
pg_pool: PgPool,
|
||||
whitelist_mode: bool,
|
||||
) -> BoxedFilter<(Subscription, Option<String>)> {
|
||||
parse_query()
|
||||
.and(query::OptionalAccessToken::from_ws_header())
|
||||
.and_then(Query::update_access_token)
|
||||
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
|
||||
.and(query::OptionalAccessToken::from_ws_header())
|
||||
.boxed()
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
// #[cfg(test)]
|
||||
// mod test {
|
||||
|
|
|
@ -19,7 +19,7 @@ use super::receiver::Receiver;
|
|||
use crate::{
|
||||
config,
|
||||
messages::Event,
|
||||
parse_client_request::subscription::{Stream::Public, Subscription, Timeline},
|
||||
parse_client_request::{Stream::Public, Subscription, Timeline},
|
||||
};
|
||||
use futures::{
|
||||
Async::{self, NotReady, Ready},
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
use super::ClientAgent;
|
||||
|
||||
use warp::ws::WebSocket;
|
||||
use futures::{future::Future, stream::Stream, Async};
|
||||
use log;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
pub struct EventStream;
|
||||
|
||||
impl EventStream {
|
||||
|
||||
|
||||
/// Send a stream of replies to a WebSocket client.
|
||||
pub fn to_ws(
|
||||
socket: WebSocket,
|
||||
mut client_agent: ClientAgent,
|
||||
update_interval: Duration,
|
||||
) -> impl Future<Item = (), Error = ()> {
|
||||
let (ws_tx, mut ws_rx) = socket.split();
|
||||
let timeline = client_agent.subscription.timeline;
|
||||
|
||||
// Create a pipe
|
||||
let (tx, rx) = futures::sync::mpsc::unbounded();
|
||||
|
||||
// Send one end of it to a different thread and tell that end to forward whatever it gets
|
||||
// on to the websocket client
|
||||
warp::spawn(
|
||||
rx.map_err(|()| -> warp::Error { unreachable!() })
|
||||
.forward(ws_tx)
|
||||
.map(|_r| ())
|
||||
.map_err(|e| match e.to_string().as_ref() {
|
||||
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
|
||||
_ => log::warn!("websocket send error: {}", e),
|
||||
}),
|
||||
);
|
||||
|
||||
// Yield new events for as long as the client is still connected
|
||||
let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while(
|
||||
move |_| match ws_rx.poll() {
|
||||
Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true),
|
||||
Ok(Async::Ready(None)) => {
|
||||
// TODO: consider whether we should manually drop closed connections here
|
||||
log::info!("Client closed WebSocket connection for {:?}", timeline);
|
||||
futures::future::ok(false)
|
||||
}
|
||||
Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => {
|
||||
// no err, just closed Unix socket
|
||||
log::info!("Client closed WebSocket connection for {:?}", timeline);
|
||||
futures::future::ok(false)
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Error in {:?}: {}", timeline, e);
|
||||
futures::future::ok(false)
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let mut time = Instant::now();
|
||||
// Every time you get an event from that stream, send it through the pipe
|
||||
event_stream
|
||||
.for_each(move |_instant| {
|
||||
if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
|
||||
tx.unbounded_send(warp::ws::Message::text(msg.to_json_string()))
|
||||
.expect("No send error");
|
||||
};
|
||||
if time.elapsed() > Duration::from_secs(30) {
|
||||
tx.unbounded_send(warp::ws::Message::text("{}"))
|
||||
.expect("Can ping");
|
||||
time = Instant::now();
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.then(move |result| {
|
||||
// TODO: consider whether we should manually drop closed connections here
|
||||
log::info!("WebSocket connection for {:?} closed.", timeline);
|
||||
result
|
||||
})
|
||||
.map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e))
|
||||
}
|
||||
pub fn to_sse(
|
||||
mut client_agent: ClientAgent,
|
||||
connection: warp::sse::Sse,
|
||||
update_interval: Duration,
|
||||
) ->impl warp::reply::Reply {
|
||||
let event_stream =
|
||||
tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| {
|
||||
match client_agent.poll() {
|
||||
Ok(Async::Ready(Some(event))) => Some((
|
||||
warp::sse::event(event.event_name()),
|
||||
warp::sse::data(event.payload().unwrap_or_else(String::new)),
|
||||
)),
|
||||
_ => None,
|
||||
}
|
||||
});
|
||||
|
||||
connection.reply(
|
||||
warp::sse::keep_alive()
|
||||
.interval(Duration::from_secs(30))
|
||||
.text("thump".to_string())
|
||||
.stream(event_stream),
|
||||
)
|
||||
}
|
||||
}
|
|
@ -1,103 +1,10 @@
|
|||
//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis.
|
||||
pub mod client_agent;
|
||||
pub mod receiver;
|
||||
pub mod redis;
|
||||
pub use client_agent::ClientAgent;
|
||||
use futures::{future::Future, stream::Stream, Async};
|
||||
use log;
|
||||
use std::time::{Duration, Instant};
|
||||
mod client_agent;
|
||||
mod receiver;
|
||||
mod redis;
|
||||
mod event_stream;
|
||||
|
||||
/// Send a stream of replies to a Server Sent Events client.
|
||||
pub fn send_updates_to_sse(
|
||||
mut client_agent: ClientAgent,
|
||||
connection: warp::sse::Sse,
|
||||
update_interval: Duration,
|
||||
) -> impl warp::reply::Reply {
|
||||
let event_stream =
|
||||
tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| {
|
||||
match client_agent.poll() {
|
||||
Ok(Async::Ready(Some(event))) => Some((
|
||||
warp::sse::event(event.event_name()),
|
||||
warp::sse::data(event.payload().unwrap_or_else(String::new)),
|
||||
)),
|
||||
_ => None,
|
||||
}
|
||||
});
|
||||
pub use {client_agent::ClientAgent, event_stream::EventStream};
|
||||
|
||||
connection.reply(
|
||||
warp::sse::keep_alive()
|
||||
.interval(Duration::from_secs(30))
|
||||
.text("thump".to_string())
|
||||
.stream(event_stream),
|
||||
)
|
||||
}
|
||||
|
||||
use warp::ws::WebSocket;
|
||||
|
||||
/// Send a stream of replies to a WebSocket client.
|
||||
pub fn send_updates_to_ws(
|
||||
socket: WebSocket,
|
||||
mut client_agent: ClientAgent,
|
||||
update_interval: Duration,
|
||||
) -> impl Future<Item = (), Error = ()> {
|
||||
let (ws_tx, mut ws_rx) = socket.split();
|
||||
let timeline = client_agent.subscription.timeline;
|
||||
|
||||
// Create a pipe
|
||||
let (tx, rx) = futures::sync::mpsc::unbounded();
|
||||
|
||||
// Send one end of it to a different thread and tell that end to forward whatever it gets
|
||||
// on to the websocket client
|
||||
warp::spawn(
|
||||
rx.map_err(|()| -> warp::Error { unreachable!() })
|
||||
.forward(ws_tx)
|
||||
.map(|_r| ())
|
||||
.map_err(|e| match e.to_string().as_ref() {
|
||||
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
|
||||
_ => log::warn!("websocket send error: {}", e),
|
||||
}),
|
||||
);
|
||||
|
||||
// Yield new events for as long as the client is still connected
|
||||
let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while(
|
||||
move |_| match ws_rx.poll() {
|
||||
Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true),
|
||||
Ok(Async::Ready(None)) => {
|
||||
// TODO: consider whether we should manually drop closed connections here
|
||||
log::info!("Client closed WebSocket connection for {:?}", timeline);
|
||||
futures::future::ok(false)
|
||||
}
|
||||
Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => {
|
||||
// no err, just closed Unix socket
|
||||
log::info!("Client closed WebSocket connection for {:?}", timeline);
|
||||
futures::future::ok(false)
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("Error in {:?}: {}", timeline, e);
|
||||
futures::future::ok(false)
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let mut time = Instant::now();
|
||||
// Every time you get an event from that stream, send it through the pipe
|
||||
event_stream
|
||||
.for_each(move |_instant| {
|
||||
if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
|
||||
tx.unbounded_send(warp::ws::Message::text(msg.to_json_string()))
|
||||
.expect("No send error");
|
||||
};
|
||||
if time.elapsed() > Duration::from_secs(30) {
|
||||
tx.unbounded_send(warp::ws::Message::text("{}"))
|
||||
.expect("Can ping");
|
||||
time = Instant::now();
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.then(move |result| {
|
||||
// TODO: consider whether we should manually drop closed connections here
|
||||
log::info!("WebSocket connection for {:?} closed.", timeline);
|
||||
result
|
||||
})
|
||||
.map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e))
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::messages::Event;
|
||||
use crate::parse_client_request::subscription::Timeline;
|
||||
use crate::parse_client_request::Timeline;
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt,
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
//! unsubscriptions to/from Redis.
|
||||
mod message_queues;
|
||||
use crate::{
|
||||
config::{self, RedisInterval},
|
||||
config,
|
||||
messages::Event,
|
||||
parse_client_request::{Stream, Timeline},
|
||||
pubsub_cmd,
|
||||
|
@ -12,7 +12,11 @@ use crate::{
|
|||
use futures::{Async, Poll};
|
||||
use lru::LruCache;
|
||||
pub use message_queues::{MessageQueues, MsgQueue};
|
||||
use std::{collections::HashMap, net, time::Instant};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tokio::io::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -21,7 +25,7 @@ use uuid::Uuid;
|
|||
pub struct Receiver {
|
||||
pub pubsub_connection: RedisStream,
|
||||
secondary_redis_connection: net::TcpStream,
|
||||
redis_poll_interval: RedisInterval,
|
||||
redis_poll_interval: Duration,
|
||||
redis_polled_at: Instant,
|
||||
timeline: Timeline,
|
||||
manager_id: Uuid,
|
||||
|
@ -29,8 +33,12 @@ pub struct Receiver {
|
|||
clients_per_timeline: HashMap<Timeline, i32>,
|
||||
cache: Cache,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Cache {
|
||||
// TODO: eventually, it might make sense to have Mastodon publish to timelines with
|
||||
// the tag number instead of the tag name. This would save us from dealing
|
||||
// with a cache here and would be consistent with how lists/users are handled.
|
||||
id_to_hashtag: LruCache<i64, String>,
|
||||
pub hashtag_to_id: LruCache<String, i64>,
|
||||
}
|
||||
|
@ -135,7 +143,7 @@ impl futures::stream::Stream for Receiver {
|
|||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
let (timeline, id) = (self.timeline.clone(), self.manager_id);
|
||||
|
||||
if self.redis_polled_at.elapsed() > *self.redis_poll_interval {
|
||||
if self.redis_polled_at.elapsed() > self.redis_poll_interval {
|
||||
self.pubsub_connection
|
||||
.poll_redis(&mut self.cache.hashtag_to_id, &mut self.msg_queues);
|
||||
self.redis_polled_at = Instant::now();
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
use super::redis_cmd;
|
||||
use crate::config::{RedisConfig, RedisInterval, RedisNamespace};
|
||||
use crate::config::RedisConfig;
|
||||
use crate::err;
|
||||
use std::{io::Read, io::Write, net, time};
|
||||
use std::{io::Read, io::Write, net, time::Duration};
|
||||
|
||||
pub struct RedisConn {
|
||||
pub primary: net::TcpStream,
|
||||
pub secondary: net::TcpStream,
|
||||
pub namespace: RedisNamespace,
|
||||
pub polling_interval: RedisInterval,
|
||||
pub namespace: Option<String>,
|
||||
pub polling_interval: Duration,
|
||||
}
|
||||
|
||||
fn send_password(mut conn: net::TcpStream, password: &str) -> net::TcpStream {
|
||||
|
@ -68,7 +68,7 @@ impl RedisConn {
|
|||
conn = send_password(conn, &password);
|
||||
}
|
||||
conn = send_test_ping(conn);
|
||||
conn.set_read_timeout(Some(time::Duration::from_millis(10)))
|
||||
conn.set_read_timeout(Some(Duration::from_millis(10)))
|
||||
.expect("Can set read timeout for Redis connection");
|
||||
if let Some(db) = &*redis_cfg.db {
|
||||
conn = set_db(conn, db);
|
||||
|
@ -86,8 +86,8 @@ impl RedisConn {
|
|||
Self {
|
||||
primary: primary_conn,
|
||||
secondary: secondary_conn,
|
||||
namespace: redis_cfg.namespace,
|
||||
polling_interval: redis_cfg.polling_interval,
|
||||
namespace: redis_cfg.namespace.clone(),
|
||||
polling_interval: *redis_cfg.polling_interval,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
//! three characters, the second is a bulk string with ten characters, and the third is a
|
||||
//! bulk string with 1,386 characters.
|
||||
|
||||
use crate::{log_fatal, messages::Event, parse_client_request::subscription::Timeline};
|
||||
use crate::{log_fatal, messages::Event, parse_client_request::Timeline};
|
||||
use lru::LruCache;
|
||||
type Parser<'a, Item> = Result<(Item, &'a str), ParseErr>;
|
||||
#[derive(Debug)]
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use super::super::receiver::MessageQueues;
|
||||
use super::redis_msg::{ParseErr, RedisMsg};
|
||||
use crate::{config::RedisNamespace, log_fatal};
|
||||
|
||||
use crate::log_fatal;
|
||||
use futures::{Async, Poll};
|
||||
use lru::LruCache;
|
||||
use std::{error::Error, io::Read, net};
|
||||
|
@ -11,7 +10,7 @@ use tokio::io::AsyncRead;
|
|||
pub struct RedisStream {
|
||||
pub inner: net::TcpStream,
|
||||
incoming_raw_msg: String,
|
||||
pub namespace: RedisNamespace,
|
||||
pub namespace: Option<String>,
|
||||
}
|
||||
|
||||
impl RedisStream {
|
||||
|
@ -19,10 +18,10 @@ impl RedisStream {
|
|||
RedisStream {
|
||||
inner,
|
||||
incoming_raw_msg: String::new(),
|
||||
namespace: RedisNamespace(None),
|
||||
namespace: None,
|
||||
}
|
||||
}
|
||||
pub fn with_namespace(self, namespace: RedisNamespace) -> Self {
|
||||
pub fn with_namespace(self, namespace: Option<String>) -> Self {
|
||||
RedisStream { namespace, ..self }
|
||||
}
|
||||
// Text comes in from redis as a raw stream, which could be more than one message and
|
||||
|
@ -41,7 +40,7 @@ impl RedisStream {
|
|||
self.incoming_raw_msg.push_str(&raw_utf);
|
||||
match process_messages(
|
||||
self.incoming_raw_msg.clone(),
|
||||
&mut self.namespace.0,
|
||||
&mut self.namespace,
|
||||
hashtag_to_id_cache,
|
||||
queues,
|
||||
) {
|
||||
|
|
Loading…
Reference in New Issue