Improve module boundary/privacy

Daniel Sockwell 2020-03-26 14:43:14 -04:00
parent a7603739ee
commit 631e818998
24 changed files with 697 additions and 793 deletions

View File

@@ -77,7 +77,7 @@ mod parse_inline {
mod flodgatt_parse_event {
use flodgatt::{messages::Event, redis_to_client_stream::receiver::MessageQueues};
use flodgatt::{
parse_client_request::subscription::Timeline,
parse_client_request::Timeline,
redis_to_client_stream::{receiver::MsgQueue, redis::redis_stream},
};
use lru::LruCache;
@@ -114,7 +114,7 @@ mod flodgatt_parse_event {
/// the performance we would see if we used serde's built-in method for handling weakly
/// typed JSON instead of our own strongly typed struct.
mod flodgatt_parse_value {
use flodgatt::{log_fatal, parse_client_request::subscription::Timeline};
use flodgatt::{log_fatal, parse_client_request::Timeline};
use lru::LruCache;
use serde_json::Value;
use std::{

View File

@@ -11,14 +11,14 @@ from_env_var!(
/// The current environment, which controls what file to read other ENV vars from
let name = Env;
let default: EnvInner = EnvInner::Development;
let (env_var, allowed_values) = ("RUST_ENV", format!("one of: {:?}", EnvInner::variants()));
let (env_var, allowed_values) = ("RUST_ENV", &format!("one of: {:?}", EnvInner::variants()));
let from_str = |s| EnvInner::from_str(s).ok();
);
from_env_var!(
/// The address to run Flodgatt on
let name = FlodgattAddr;
let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded"));
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)".to_string());
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)");
let from_str = |s| match s {
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
_ => s.parse().ok(),
@@ -28,35 +28,35 @@ from_env_var!(
/// How verbosely Flodgatt should log messages
let name = LogLevel;
let default: LogLevelInner = LogLevelInner::Warn;
let (env_var, allowed_values) = ("RUST_LOG", format!("one of: {:?}", LogLevelInner::variants()));
let (env_var, allowed_values) = ("RUST_LOG", &format!("one of: {:?}", LogLevelInner::variants()));
let from_str = |s| LogLevelInner::from_str(s).ok();
);
from_env_var!(
/// A Unix Socket to use in place of a local address
let name = Socket;
let default: Option<String> = None;
let (env_var, allowed_values) = ("SOCKET", "any string".to_string());
let (env_var, allowed_values) = ("SOCKET", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// The time between replies sent via WebSocket
let name = WsInterval;
let default: Duration = Duration::from_millis(100);
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string());
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds");
let from_str = |s| s.parse().map(Duration::from_millis).ok();
);
from_env_var!(
/// The time between replies sent via Server Sent Events
let name = SseInterval;
let default: Duration = Duration::from_millis(100);
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string());
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds");
let from_str = |s| s.parse().map(Duration::from_millis).ok();
);
from_env_var!(
/// The port to run Flodgatt on
let name = Port;
let default: u16 = 4000;
let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535".to_string());
let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535");
let from_str = |s| s.parse().ok();
);
from_env_var!(
@@ -66,7 +66,7 @@ from_env_var!(
/// (including otherwise public timelines).
let name = WhitelistMode;
let default: bool = false;
let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false".to_string());
let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false");
let from_str = |s| s.parse().ok();
);
/// Permissions for Cross Origin Resource Sharing (CORS)
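A minimal sketch, not part of the diff, of how one of these generated types behaves once `from_env_var!` (defined in `environmental_variables.rs` below) has expanded it; `demo_addr` is a hypothetical helper:

fn demo_addr() {
    use std::net::{IpAddr, Ipv4Addr};
    // `FlodgattAddr::default()` wraps 127.0.0.1; `maybe_update` runs the
    // `from_str` closure above, which special-cases the string "localhost".
    let addr = FlodgattAddr::default().maybe_update(Some(&"localhost".to_string()));
    assert_eq!(*addr, IpAddr::V4(Ipv4Addr::LOCALHOST)); // Deref yields the inner IpAddr
}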

View File

@@ -0,0 +1,137 @@
use std::{collections::HashMap, fmt};
pub struct EnvVar(pub HashMap<String, String>);
impl std::ops::Deref for EnvVar {
type Target = HashMap<String, String>;
fn deref(&self) -> &HashMap<String, String> {
&self.0
}
}
impl Clone for EnvVar {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl EnvVar {
pub fn new(vars: HashMap<String, String>) -> Self {
Self(vars)
}
pub fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option<impl ToString>) {
if let Some(value) = maybe_value {
self.0.insert(key.to_string(), value.to_string());
}
}
pub fn err(env_var: &str, supplied_value: &str, allowed_values: &str) -> ! {
log::error!(
r"{var} is set to `{value}`, which is invalid.
{var} must be {allowed_vals}.",
var = env_var,
value = supplied_value,
allowed_vals = allowed_values
);
std::process::exit(1);
}
}
impl fmt::Display for EnvVar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut result = String::new();
for env_var in [
"NODE_ENV",
"RUST_LOG",
"BIND",
"PORT",
"SOCKET",
"SSE_FREQ",
"WS_FREQ",
"DATABASE_URL",
"DB_USER",
"USER",
"DB_PORT",
"DB_HOST",
"DB_PASS",
"DB_NAME",
"DB_SSLMODE",
"REDIS_HOST",
"REDIS_USER",
"REDIS_PORT",
"REDIS_PASSWORD",
"REDIS_USER",
"REDIS_DB",
]
.iter()
{
if let Some(value) = self.get(&env_var.to_string()) {
result = format!("{}\n {}: {}", result, env_var, value)
}
}
write!(f, "{}", result)
}
}
#[macro_export]
macro_rules! maybe_update {
($name:ident; $item: tt:$type:ty) => (
pub fn $name(self, item: Option<$type>) -> Self {
match item {
Some($item) => Self{ $item, ..self },
None => Self { ..self }
}
});
($name:ident; Some($item: tt: $type:ty)) => (
fn $name(self, item: Option<$type>) -> Self{
match item {
Some($item) => Self{ $item: Some($item), ..self },
None => Self { ..self }
}
})}
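For reference, a minimal sketch of what the first `maybe_update!` arm generates; `Config` and `maybe_update_port` are hypothetical names, not part of this commit:

struct Config {
    port: u16,
    host: String,
}
impl Config {
    // Generates a builder-style method that replaces `port` only when a value
    // was actually supplied, leaving every other field untouched:
    maybe_update!(maybe_update_port; port: u16);
}
// Config { port: 4000, host: "x".into() }.maybe_update_port(Some(8080)) // port == 8080
// Config { port: 4000, host: "x".into() }.maybe_update_port(None)       // port == 4000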
#[macro_export]
macro_rules! from_env_var {
($(#[$outer:meta])*
let name = $name:ident;
let default: $type:ty = $inner:expr;
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
let from_str = |$arg:ident| $body:expr;
) => {
pub struct $name(pub $type);
impl std::fmt::Debug for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl std::ops::Deref for $name {
type Target = $type;
fn deref(&self) -> &$type {
&self.0
}
}
impl std::default::Default for $name {
fn default() -> Self {
$name($inner)
}
}
impl $name {
fn inner_from_str($arg: &str) -> Option<$type> {
$body
}
pub fn maybe_update(self, var: Option<&String>) -> Self {
match var {
Some(empty_string) if empty_string.is_empty() => Self::default(),
Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
crate::config::EnvVar::err($env_var, value, $allowed_values)
})),
None => self,
}
// if let Some(value) = var {
// Self(Self::inner_from_str(value).unwrap_or_else(|| {
// crate::err::env_var_fatal($env_var, value, $allowed_values)
// }))
// } else {
// self
// }
}
}
};
}
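Each `from_env_var!` invocation above therefore yields a newtype with `Debug`, `Deref`, and `Default` impls plus the `maybe_update` constructor. A minimal sketch of the resulting API, using the generated `Port` type from the deployment config (values taken from its definition above):

let port = Port::default();                              // wraps the default, 4000
let port = port.maybe_update(Some(&"8080".to_string())); // parses and wraps 8080
let port = port.maybe_update(None);                      // env var unset: unchanged
let port = port.maybe_update(Some(&String::new()));      // empty string: reset to default
assert_eq!(*port, 4000);
// An unparseable value such as "eight" would instead call
// EnvVar::err("PORT", "eight", "a number between 0 and 65535") and exit.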

View File

@@ -4,135 +4,7 @@ mod postgres_cfg;
mod postgres_cfg_types;
mod redis_cfg;
mod redis_cfg_types;
pub use self::{
deployment_cfg::DeploymentConfig,
postgres_cfg::PostgresConfig,
redis_cfg::RedisConfig,
redis_cfg_types::{RedisInterval, RedisNamespace},
};
use std::{collections::HashMap, fmt};
mod environmental_variables;
pub struct EnvVar(pub HashMap<String, String>);
impl std::ops::Deref for EnvVar {
type Target = HashMap<String, String>;
fn deref(&self) -> &HashMap<String, String> {
&self.0
}
}
pub use {deployment_cfg::DeploymentConfig, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig, environmental_variables::EnvVar};
impl Clone for EnvVar {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl EnvVar {
pub fn new(vars: HashMap<String, String>) -> Self {
Self(vars)
}
fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option<impl ToString>) {
if let Some(value) = maybe_value {
self.0.insert(key.to_string(), value.to_string());
}
}
}
impl fmt::Display for EnvVar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut result = String::new();
for env_var in [
"NODE_ENV",
"RUST_LOG",
"BIND",
"PORT",
"SOCKET",
"SSE_FREQ",
"WS_FREQ",
"DATABASE_URL",
"DB_USER",
"USER",
"DB_PORT",
"DB_HOST",
"DB_PASS",
"DB_NAME",
"DB_SSLMODE",
"REDIS_HOST",
"REDIS_USER",
"REDIS_PORT",
"REDIS_PASSWORD",
"REDIS_USER",
"REDIS_DB",
]
.iter()
{
if let Some(value) = self.get(&env_var.to_string()) {
result = format!("{}\n {}: {}", result, env_var, value)
}
}
write!(f, "{}", result)
}
}
#[macro_export]
macro_rules! maybe_update {
($name:ident; $item: tt:$type:ty) => (
pub fn $name(self, item: Option<$type>) -> Self {
match item {
Some($item) => Self{ $item, ..self },
None => Self { ..self }
}
});
($name:ident; Some($item: tt: $type:ty)) => (
fn $name(self, item: Option<$type>) -> Self{
match item {
Some($item) => Self{ $item: Some($item), ..self },
None => Self { ..self }
}
})}
#[macro_export]
macro_rules! from_env_var {
($(#[$outer:meta])*
let name = $name:ident;
let default: $type:ty = $inner:expr;
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
let from_str = |$arg:ident| $body:expr;
) => {
pub struct $name(pub $type);
impl std::fmt::Debug for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl std::ops::Deref for $name {
type Target = $type;
fn deref(&self) -> &$type {
&self.0
}
}
impl std::default::Default for $name {
fn default() -> Self {
$name($inner)
}
}
impl $name {
fn inner_from_str($arg: &str) -> Option<$type> {
$body
}
pub fn maybe_update(self, var: Option<&String>) -> Self {
match var {
Some(empty_string) if empty_string.is_empty() => Self::default(),
Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
crate::err::env_var_fatal($env_var, value, $allowed_values)
})),
None => self,
}
// if let Some(value) = var {
// Self(Self::inner_from_str(value).unwrap_or_else(|| {
// crate::err::env_var_fatal($env_var, value, $allowed_values)
// }))
// } else {
// self
// }
}
}
};
}

View File

@@ -6,7 +6,7 @@ from_env_var!(
/// The user to use for Postgres
let name = PgUser;
let default: String = "postgres".to_string();
let (env_var, allowed_values) = ("DB_USER", "any string".to_string());
let (env_var, allowed_values) = ("DB_USER", "any string");
let from_str = |s| Some(s.to_string());
);
@@ -14,7 +14,7 @@ from_env_var!(
/// The host address where Postgres is running
let name = PgHost;
let default: String = "localhost".to_string();
let (env_var, allowed_values) = ("DB_HOST", "any string".to_string());
let (env_var, allowed_values) = ("DB_HOST", "any string");
let from_str = |s| Some(s.to_string());
);
@@ -22,7 +22,7 @@ from_env_var!(
/// The password to use with Postgres
let name = PgPass;
let default: Option<String> = None;
let (env_var, allowed_values) = ("DB_PASS", "any string".to_string());
let (env_var, allowed_values) = ("DB_PASS", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
@@ -30,7 +30,7 @@ from_env_var!(
/// The Postgres database to use
let name = PgDatabase;
let default: String = "mastodon_development".to_string();
let (env_var, allowed_values) = ("DB_NAME", "any string".to_string());
let (env_var, allowed_values) = ("DB_NAME", "any string");
let from_str = |s| Some(s.to_string());
);
@@ -38,14 +38,14 @@ from_env_var!(
/// The port Postgres is running on
let name = PgPort;
let default: u16 = 5432;
let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535".to_string());
let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535");
let from_str = |s| s.parse().ok();
);
from_env_var!(
let name = PgSslMode;
let default: PgSslInner = PgSslInner::Prefer;
let (env_var, allowed_values) = ("DB_SSLMODE", format!("one of: {:?}", PgSslInner::variants()));
let (env_var, allowed_values) = ("DB_SSLMODE", &format!("one of: {:?}", PgSslInner::variants()));
let from_str = |s| PgSslInner::from_str(s).ok();
);

View File

@@ -7,48 +7,48 @@ from_env_var!(
/// The host address where Redis is running
let name = RedisHost;
let default: String = "127.0.0.1".to_string();
let (env_var, allowed_values) = ("REDIS_HOST", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_HOST", "any string");
let from_str = |s| Some(s.to_string());
);
from_env_var!(
/// The port Redis is running on
let name = RedisPort;
let default: u16 = 6379;
let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535".to_string());
let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535");
let from_str = |s| s.parse().ok();
);
from_env_var!(
/// How frequently to poll Redis
let name = RedisInterval;
let default: Duration = Duration::from_millis(100);
let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds".to_string());
let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds");
let from_str = |s| s.parse().map(Duration::from_millis).ok();
);
from_env_var!(
/// The password to use for Redis
let name = RedisPass;
let default: Option<String> = None;
let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// An optional Redis Namespace
let name = RedisNamespace;
let default: Option<String> = None;
let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// A user for Redis (not supported)
let name = RedisUser;
let default: Option<String> = None;
let (env_var, allowed_values) = ("REDIS_USER", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_USER", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// The database to use with Redis (no current effect for PubSub connections)
let name = RedisDb;
let default: Option<String> = None;
let (env_var, allowed_values) = ("REDIS_DB", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_DB", "any string");
let from_str = |s| Some(Some(s.to_string()));
);

View File

@@ -1,4 +1,3 @@
use serde_derive::Serialize;
use std::fmt::Display;
pub fn die_with_msg(msg: impl Display) -> ! {
@@ -13,68 +12,3 @@ macro_rules! log_fatal {
panic!();
};};
}
pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! {
eprintln!(
r"FATAL ERROR: {var} is set to `{value}`, which is invalid.
{var} must be {allowed_vals}.",
var = env_var,
value = supplied_value,
allowed_vals = allowed_values
);
std::process::exit(1);
}
#[macro_export]
macro_rules! dbg_and_die {
($msg:expr) => {
let message = format!("FATAL ERROR: {}", $msg);
dbg!(message);
std::process::exit(1);
};
}
pub fn unwrap_or_die<T>(s: Option<T>, msg: &str) -> T {
s.unwrap_or_else(|| {
eprintln!("FATAL ERROR: {}", msg);
std::process::exit(1)
})
}
#[derive(Serialize)]
pub struct ErrorMessage {
error: String,
}
impl ErrorMessage {
fn new(msg: impl std::fmt::Display) -> Self {
Self {
error: msg.to_string(),
}
}
}
/// Recover from Errors by sending appropriate Warp::Rejections
pub fn handle_errors(
rejection: warp::reject::Rejection,
) -> Result<impl warp::Reply, warp::reject::Rejection> {
let err_txt = match rejection.cause() {
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
"Error: Missing access token".to_string()
}
Some(text) => text.to_string(),
None => "Error: Nonexistant endpoint".to_string(),
};
let json = warp::reply::json(&ErrorMessage::new(err_txt));
Ok(warp::reply::with_status(
json,
warp::http::StatusCode::UNAUTHORIZED,
))
}
pub struct CustomError {}
impl CustomError {
pub fn unauthorized_list() -> warp::reject::Rejection {
warp::reject::custom("Error: Access to list not authorized")
}
}

View File

@@ -1,32 +1,32 @@
use flodgatt::{
config, err,
parse_client_request::{sse, subscription, ws},
redis_to_client_stream::{self, ClientAgent},
config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig},
parse_client_request::{PgPool, Subscription},
redis_to_client_stream::{ClientAgent, EventStream},
};
use std::{collections::HashMap, env, fs, net, os::unix::fs::PermissionsExt};
use tokio::net::UnixListener;
use warp::{path, ws::Ws2, Filter};
fn main() {
dotenv::from_filename(
match env::var("ENV").ok().as_ref().map(String::as_str) {
dotenv::from_filename(match env::var("ENV").ok().as_ref().map(String::as_str) {
Some("production") => ".env.production",
Some("development") | None => ".env",
Some(_) => err::die_with_msg("Unknown ENV variable specified.\n Valid options are: `production` or `development`."),
}).ok();
Some(unsupported) => EnvVar::err("ENV", unsupported, "`production` or `development`"),
})
.ok();
let env_vars_map: HashMap<_, _> = dotenv::vars().collect();
let env_vars = config::EnvVar::new(env_vars_map);
let env_vars = EnvVar::new(env_vars_map);
pretty_env_logger::init();
log::info!(
"Flodgatt recognized the following environmental variables:{}",
env_vars.clone()
);
let redis_cfg = config::RedisConfig::from_env(env_vars.clone());
let cfg = config::DeploymentConfig::from_env(env_vars.clone());
let redis_cfg = RedisConfig::from_env(env_vars.clone());
let cfg = DeploymentConfig::from_env(env_vars.clone());
let postgres_cfg = config::PostgresConfig::from_env(env_vars.clone());
let pg_pool = subscription::PgPool::new(postgres_cfg);
let postgres_cfg = PostgresConfig::from_env(env_vars.clone());
let pg_pool = PgPool::new(postgres_cfg);
let client_agent_sse = ClientAgent::blank(redis_cfg);
let client_agent_ws = client_agent_sse.clone_with_shared_receiver();
@@ -34,54 +34,43 @@ fn main() {
log::info!("Streaming server initialized and ready to accept connections");
// Server Sent Events
let (sse_update_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode);
let sse_routes = sse::extract_user_or_reject(pg_pool.clone(), whitelist_mode)
let (sse_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode);
let sse_routes = Subscription::from_sse_query(pg_pool.clone(), whitelist_mode)
.and(warp::sse())
.map(
move |subscription: subscription::Subscription,
sse_connection_to_client: warp::sse::Sse| {
move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| {
log::info!("Incoming SSE request for {:?}", subscription.timeline);
// Create a new ClientAgent
let mut client_agent = client_agent_sse.clone_with_shared_receiver();
// Assign ClientAgent to generate stream of updates for the user/timeline pair
client_agent.init_for_user(subscription);
// send the updates through the SSE connection
redis_to_client_stream::send_updates_to_sse(
client_agent,
sse_connection_to_client,
sse_update_interval,
)
EventStream::to_sse(client_agent, sse_connection_to_client, sse_interval)
},
)
.with(warp::reply::with::header("Connection", "keep-alive"))
.recover(err::handle_errors);
.with(warp::reply::with::header("Connection", "keep-alive"));
// WebSocket
let (ws_update_interval, whitelist_mode) = (*cfg.ws_interval, *cfg.whitelist_mode);
let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone(), whitelist_mode)
let websocket_routes = Subscription::from_ws_request(pg_pool.clone(), whitelist_mode)
.and(warp::ws::ws2())
.map(
move |subscription: subscription::Subscription, token: Option<String>, ws: Ws2| {
log::info!("Incoming websocket request for {:?}", subscription.timeline);
// Create a new ClientAgent
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(subscription);
// send the updates through the WS connection (along with the User's access_token
// which is sent for security)
.map(move |subscription: Subscription, ws: Ws2| {
log::info!("Incoming websocket request for {:?}", subscription.timeline);
(
ws.on_upgrade(move |socket| {
redis_to_client_stream::send_updates_to_ws(
socket,
client_agent,
ws_update_interval,
)
}),
token.unwrap_or_else(String::new),
)
},
)
let token = subscription.access_token.clone();
// Create a new ClientAgent
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(subscription);
// send the updates through the WS connection (along with the User's access_token
// which is sent for security)
(
ws.on_upgrade(move |socket| {
EventStream::to_ws(socket, client_agent, ws_update_interval)
}),
token.unwrap_or_else(String::new),
)
})
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
let cors = warp::cors()
@@ -98,7 +87,28 @@ fn main() {
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).unwrap();
warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run_incoming(incoming);
warp::serve(
health.or(websocket_routes.or(sse_routes).with(cors).recover(
|rejection: warp::reject::Rejection| {
let err_txt = match rejection.cause() {
Some(text)
if text.to_string() == "Missing request header 'authorization'" =>
{
"Error: Missing access token".to_string()
}
Some(text) => text.to_string(),
None => "Error: Nonexistant endpoint".to_string(),
};
let json = warp::reply::json(&err_txt);
Ok(warp::reply::with_status(
json,
warp::http::StatusCode::UNAUTHORIZED,
))
},
)),
)
.run_incoming(incoming);
} else {
let server_addr = net::SocketAddr::new(*cfg.address, cfg.port.0);
warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run(server_addr);

View File

@@ -1,7 +1,9 @@
//! Parse the client request and return a (possibly authenticated) `User`
pub mod query;
pub mod sse;
pub mod subscription;
pub mod ws;
//! Parse the client request and return a Subscription
mod postgres;
mod query;
mod sse;
mod subscription;
mod ws;
pub use subscription::{Stream, Timeline};
pub use self::postgres::PgPool;
pub use subscription::{Stream, Subscription, Timeline};

View File

@@ -0,0 +1,204 @@
//! Postgres queries
use crate::{
config,
parse_client_request::subscription::{Scope, UserData},
};
use ::postgres;
use r2d2_postgres::PostgresConnectionManager;
use std::collections::HashSet;
use warp::reject::Rejection;
#[derive(Clone, Debug)]
pub struct PgPool(pub r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>);
impl PgPool {
pub fn new(pg_cfg: config::PostgresConfig) -> Self {
let mut cfg = postgres::Config::new();
cfg.user(&pg_cfg.user)
.host(&*pg_cfg.host.to_string())
.port(*pg_cfg.port)
.dbname(&pg_cfg.database);
if let Some(password) = &*pg_cfg.password {
cfg.password(password);
};
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
let pool = r2d2::Pool::builder()
.max_size(10)
.build(manager)
.expect("Can connect to local postgres");
Self(pool)
}
pub fn select_user(self, token: &str) -> Result<UserData, Rejection> {
let mut conn = self.0.get().unwrap();
let query_rows = conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) {
let id = result_columns.get(1);
let allowed_langs = result_columns
.try_get::<_, Vec<_>>(2)
.unwrap_or_else(|_| Vec::new())
.into_iter()
.collect();
let mut scopes: HashSet<Scope> = result_columns
.get::<_, String>(3)
.split(' ')
.filter_map(|scope| match scope {
"read" => Some(Scope::Read),
"read:statuses" => Some(Scope::Statuses),
"read:notifications" => Some(Scope::Notifications),
"read:lists" => Some(Scope::Lists),
"write" | "follow" => None, // ignore write scopes
unexpected => {
log::warn!("Ignoring unknown scope `{}`", unexpected);
None
}
})
.collect();
// We don't need to separately track read auth - it's just all three others
if scopes.remove(&Scope::Read) {
scopes.insert(Scope::Statuses);
scopes.insert(Scope::Notifications);
scopes.insert(Scope::Lists);
}
Ok(UserData {
id,
allowed_langs,
scopes,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
pub fn select_hashtag_id(self, tag_name: &String) -> Result<i64, Rejection> {
let mut conn = self.0.get().unwrap();
let rows = &conn
.query(
"
SELECT id
FROM tags
WHERE name = $1
LIMIT 1",
&[&tag_name],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
Some(row) => Ok(row.get(0)),
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
}
}
/// Query Postgres for everyone the user has blocked or muted
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_users(self, user_id: i64) -> HashSet<i64> {
// "
// SELECT
// 1
// FROM blocks
// WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
// OR (account_id = $2 AND target_account_id = $1)
// UNION SELECT
// 1
// FROM mutes
// WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
// , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
self
.0
.get()
.unwrap()
.query(
"
SELECT target_account_id
FROM blocks
WHERE account_id = $1
UNION SELECT target_account_id
FROM mutes
WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for everyone who has blocked the user
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(self, user_id: i64) -> HashSet<i64> {
self
.0
.get()
.unwrap()
.query(
"
SELECT account_id
FROM blocks
WHERE target_account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for all current domain blocks
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(self, user_id: i64) -> HashSet<String> {
self
.0
.get()
.unwrap()
.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Test whether a user owns a list
pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool {
let mut conn = self.0.get().unwrap();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT id, account_id
FROM lists
WHERE id = $1
LIMIT 1",
&[&list_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
None => false,
Some(row) => {
let list_owner_id: i64 = row.get(1);
list_owner_id == user_id
}
}
}
}
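As `main.rs` and `subscription.rs` elsewhere in this commit show, the pool is built once from the Postgres config, and each query method takes `self` by value, so call sites clone the (cheap) r2d2 handle. A condensed sketch, assuming a handler that returns `Result<_, Rejection>` and a hypothetical `token`:

let pg_pool = PgPool::new(postgres_cfg);
let user = pg_pool.clone().select_user(token)?; // Err(Rejection) for a bad token
let blocked_users = pg_pool.clone().select_blocked_users(user.id);
let blocking_users = pg_pool.clone().select_blocking_users(user.id);
let blocked_domains = pg_pool.clone().select_blocked_domains(user.id);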

View File

@@ -28,6 +28,12 @@ impl Query {
}
macro_rules! make_query_type {
(Stream => $parameter:tt:$type:ty) => {
#[derive(Deserialize, Debug, Default)]
pub struct Stream {
pub $parameter: $type,
}
};
($name:tt => $parameter:tt:$type:ty) => {
#[derive(Deserialize, Debug, Default)]
pub struct $name {
@@ -59,14 +65,14 @@ impl ToString for Stream {
}
}
pub fn optional_media_query() -> BoxedFilter<(Media,)> {
warp::query()
.or(warp::any().map(|| Media {
only_media: "false".to_owned(),
}))
.unify()
.boxed()
}
// pub fn optional_media_query() -> BoxedFilter<(Media,)> {
// warp::query()
// .or(warp::any().map(|| Media {
// only_media: "false".to_owned(),
// }))
// .unify()
// .boxed()
// }
pub struct OptionalAccessToken;
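For reference, the `Stream` arm of `make_query_type!` shown above expands to a plain deserializable struct; a sketch of `make_query_type!(Stream => stream: String)`, an invocation this hunk does not show:

#[derive(Deserialize, Debug, Default)]
pub struct Stream {
    pub stream: String,
}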

View File

@@ -1,78 +1,4 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
use super::{
query::{self, Query},
subscription::{PgPool, Subscription},
};
use warp::{filters::BoxedFilter, path, Filter};
#[allow(dead_code)]
type TimelineUser = ((String, Subscription),);
/// Helper macro to match on the first of any of the provided filters
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*.boxed()
};
}
macro_rules! parse_query {
(path => $start:tt $(/ $next:tt)*
endpoint => $endpoint:expr) => {
path!($start $(/ $next)*)
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: $endpoint.to_string(),
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
};
}
pub fn extract_user_or_reject(
pg_pool: PgPool,
whitelist_mode: bool,
) -> BoxedFilter<(Subscription,)> {
any_of!(
parse_query!(
path => "api" / "v1" / "streaming" / "user" / "notification"
endpoint => "user:notification" ),
parse_query!(
path => "api" / "v1" / "streaming" / "user"
endpoint => "user"),
parse_query!(
path => "api" / "v1" / "streaming" / "public" / "local"
endpoint => "public:local"),
parse_query!(
path => "api" / "v1" / "streaming" / "public"
endpoint => "public"),
parse_query!(
path => "api" / "v1" / "streaming" / "direct"
endpoint => "direct"),
parse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
endpoint => "hashtag:local"),
parse_query!(path => "api" / "v1" / "streaming" / "hashtag"
endpoint => "hashtag"),
parse_query!(path => "api" / "v1" / "streaming" / "list"
endpoint => "list")
)
// because SSE requests place their `access_token` in the header instead of in a query
// parameter, we need to update our Query if the header has a token
.and(query::OptionalAccessToken::from_sse_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
.boxed()
}
// #[cfg(test)]
// mod test {

View File

@@ -4,20 +4,55 @@
// #[cfg(test)]
// use mock_postgres as postgres;
// #[cfg(not(test))]
pub mod postgres;
pub use self::postgres::PgPool;
use super::postgres::PgPool;
use super::query::Query;
use crate::log_fatal;
use std::collections::HashSet;
use warp::reject::Rejection;
/// The User (with data read from Postgres)
use super::query;
use warp::{filters::BoxedFilter, path, Filter};
/// Helper macro to match on the first of any of the provided filters
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*.boxed()
};
}
macro_rules! parse_sse_query {
(path => $start:tt $(/ $next:tt)*
endpoint => $endpoint:expr) => {
path!($start $(/ $next)*)
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: $endpoint.to_string(),
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
};
}
#[derive(Clone, Debug, PartialEq)]
pub struct Subscription {
pub timeline: Timeline,
pub allowed_langs: HashSet<String>,
pub blocks: Blocks,
pub hashtag_name: Option<String>,
pub access_token: Option<String>,
}
impl Default for Subscription {
@@ -27,14 +62,54 @@ impl Default for Subscription {
allowed_langs: HashSet::new(),
blocks: Blocks::default(),
hashtag_name: None,
access_token: None,
}
}
}
impl Subscription {
pub fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result<Self, Rejection> {
pub fn from_ws_request(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
parse_ws_query()
.and(query::OptionalAccessToken::from_ws_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
.boxed()
}
pub fn from_sse_query(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
any_of!(
parse_sse_query!(
path => "api" / "v1" / "streaming" / "user" / "notification"
endpoint => "user:notification" ),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "user"
endpoint => "user"),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "public" / "local"
endpoint => "public:local"),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "public"
endpoint => "public"),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "direct"
endpoint => "direct"),
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
endpoint => "hashtag:local"),
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag"
endpoint => "hashtag"),
parse_sse_query!(path => "api" / "v1" / "streaming" / "list"
endpoint => "list")
)
// because SSE requests place their `access_token` in the header instead of in a query
// parameter, we need to update our Query if the header has a token
.and(query::OptionalAccessToken::from_sse_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
.boxed()
}
fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result<Self, Rejection> {
let user = match q.access_token.clone() {
Some(token) => postgres::select_user(&token, pool.clone())?,
Some(token) => pool.clone().select_user(&token)?,
None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?,
None => UserData::public(),
};
@@ -48,15 +123,42 @@ impl Subscription {
timeline,
allowed_langs: user.allowed_langs,
blocks: Blocks {
blocking_users: postgres::select_blocking_users(user.id, pool.clone()),
blocked_users: postgres::select_blocked_users(user.id, pool.clone()),
blocked_domains: postgres::select_blocked_domains(user.id, pool.clone()),
blocking_users: pool.clone().select_blocking_users(user.id),
blocked_users: pool.clone().select_blocked_users(user.id),
blocked_domains: pool.clone().select_blocked_domains(user.id),
},
hashtag_name,
access_token: q.access_token,
})
}
}
fn parse_ws_query() -> BoxedFilter<(Query,)> {
path!("api" / "v1" / "streaming")
.and(path::end())
.and(warp::query())
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|stream: query::Stream,
auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: stream.stream,
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
}
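To make the flow concrete: a sketch of the `Query` that a WebSocket request such as `GET /api/v1/streaming?stream=hashtag&tag=rust` would produce under the filter chain above. The defaults for absent parameters are assumptions, since the `query::*` types are not shown in this diff:

Query {
    access_token: None,            // possibly filled in later by `update_access_token`
    stream: "hashtag".to_string(),
    media: false,
    hashtag: "rust".to_string(),
    list: 0,                       // assumed default for an absent `list` parameter
}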
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub struct Timeline(pub Stream, pub Reach, pub Content);
@@ -111,8 +213,8 @@ impl Timeline {
}
fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result<Self, Rejection> {
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
let id_from_hashtag = || postgres::select_hashtag_id(&q.hashtag, pool.clone());
let user_owns_list = || postgres::user_owns_list(user.id, q.list, pool.clone());
let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag);
let user_owns_list = || pool.clone().user_owns_list(user.id, q.list);
Ok(match q.stream.as_ref() {
"public" => match q.media {
@@ -189,9 +291,9 @@ pub struct Blocks {
#[derive(Clone, Debug, PartialEq)]
pub struct UserData {
id: i64,
allowed_langs: HashSet<String>,
scopes: HashSet<Scope>,
pub id: i64,
pub allowed_langs: HashSet<String>,
pub scopes: HashSet<Scope>,
}
impl UserData {

View File

@@ -1,43 +0,0 @@
//! Mock Postgres connection (for use in unit testing)
use super::{OauthScope, Subscription};
use std::collections::HashSet;
#[derive(Clone)]
pub struct PgPool;
impl PgPool {
pub fn new() -> Self {
Self
}
}
pub fn select_user(
access_token: &str,
_pg_pool: PgPool,
) -> Result<Subscription, warp::reject::Rejection> {
let mut user = Subscription::default();
if access_token == "TEST_USER" {
user.id = 1;
user.logged_in = true;
user.access_token = "TEST_USER".to_string();
user.email = "user@example.com".to_string();
user.scopes = OauthScope::from(vec![
"read".to_string(),
"write".to_string(),
"follow".to_string(),
]);
} else if access_token == "INVALID" {
return Err(warp::reject::custom("Error: Invalid access token"));
}
Ok(user)
}
pub fn select_user_blocks(_id: i64, _pg_pool: PgPool) -> HashSet<i64> {
HashSet::new()
}
pub fn select_domain_blocks(_pg_pool: PgPool) -> HashSet<String> {
HashSet::new()
}
pub fn user_owns_list(user_id: i64, list_id: i64, _pg_pool: PgPool) -> bool {
user_id == list_id
}

View File

@@ -1,224 +0,0 @@
//! Postgres queries
use crate::{
config,
parse_client_request::subscription::{Scope, UserData},
};
use ::postgres;
use r2d2_postgres::PostgresConnectionManager;
use std::collections::HashSet;
use warp::reject::Rejection;
#[derive(Clone, Debug)]
pub struct PgPool(pub r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>);
impl PgPool {
pub fn new(pg_cfg: config::PostgresConfig) -> Self {
let mut cfg = postgres::Config::new();
cfg.user(&pg_cfg.user)
.host(&*pg_cfg.host.to_string())
.port(*pg_cfg.port)
.dbname(&pg_cfg.database);
if let Some(password) = &*pg_cfg.password {
cfg.password(password);
};
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
let pool = r2d2::Pool::builder()
.max_size(10)
.build(manager)
.expect("Can connect to local postgres");
Self(pool)
}
}
pub fn select_user(token: &str, pool: PgPool) -> Result<UserData, Rejection> {
let mut conn = pool.0.get().unwrap();
let query_rows = conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) {
let id = result_columns.get(1);
let allowed_langs = result_columns
.try_get::<_, Vec<_>>(2)
.unwrap_or_else(|_| Vec::new())
.into_iter()
.collect();
let mut scopes: HashSet<Scope> = result_columns
.get::<_, String>(3)
.split(' ')
.filter_map(|scope| match scope {
"read" => Some(Scope::Read),
"read:statuses" => Some(Scope::Statuses),
"read:notifications" => Some(Scope::Notifications),
"read:lists" => Some(Scope::Lists),
"write" | "follow" => None, // ignore write scopes
unexpected => {
log::warn!("Ignoring unknown scope `{}`", unexpected);
None
}
})
.collect();
// We don't need to separately track read auth - it's just all three others
if scopes.remove(&Scope::Read) {
scopes.insert(Scope::Statuses);
scopes.insert(Scope::Notifications);
scopes.insert(Scope::Lists);
}
Ok(UserData {
id,
allowed_langs,
scopes,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
pub fn select_hashtag_id(tag_name: &String, pg_pool: PgPool) -> Result<i64, Rejection> {
let mut conn = pg_pool.0.get().unwrap();
let rows = &conn
.query(
"
SELECT id
FROM tags
WHERE name = $1
LIMIT 1",
&[&tag_name],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
Some(row) => Ok(row.get(0)),
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
}
}
pub fn select_hashtag_name(tag_id: &i64, pg_pool: PgPool) -> Result<String, Rejection> {
let mut conn = pg_pool.0.get().unwrap();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT name
FROM tags
WHERE id = $1
LIMIT 1",
&[&tag_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
Some(row) => Ok(row.get(0)),
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
}
}
/// Query Postgres for everyone the user has blocked or muted
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_users(user_id: i64, pg_pool: PgPool) -> HashSet<i64> {
// "
// SELECT
// 1
// FROM blocks
// WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
// OR (account_id = $2 AND target_account_id = $1)
// UNION SELECT
// 1
// FROM mutes
// WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
// , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
pg_pool
.0
.get()
.unwrap()
.query(
"
SELECT target_account_id
FROM blocks
WHERE account_id = $1
UNION SELECT target_account_id
FROM mutes
WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for everyone who has blocked the user
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(user_id: i64, pg_pool: PgPool) -> HashSet<i64> {
pg_pool
.0
.get()
.unwrap()
.query(
"
SELECT account_id
FROM blocks
WHERE target_account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for all current domain blocks
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(user_id: i64, pg_pool: PgPool) -> HashSet<String> {
pg_pool
.0
.get()
.unwrap()
.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Test whether a user owns a list
pub fn user_owns_list(user_id: i64, list_id: i64, pg_pool: PgPool) -> bool {
let mut conn = pg_pool.0.get().unwrap();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT id, account_id
FROM lists
WHERE id = $1
LIMIT 1",
&[&list_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
None => false,
Some(row) => {
let list_owner_id: i64 = row.get(1);
list_owner_id == user_id
}
}
}

View File

@@ -1,48 +1,9 @@
//! Filters for the WebSocket endpoint
use super::{
query::{self, Query},
subscription::{PgPool, Subscription},
};
use warp::{filters::BoxedFilter, path, Filter};
/// WebSocket filters
fn parse_query() -> BoxedFilter<(Query,)> {
path!("api" / "v1" / "streaming")
.and(path::end())
.and(warp::query())
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|stream: query::Stream,
auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: stream.stream,
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
}
pub fn extract_user_and_token_or_reject(
pg_pool: PgPool,
whitelist_mode: bool,
) -> BoxedFilter<(Subscription, Option<String>)> {
parse_query()
.and(query::OptionalAccessToken::from_ws_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
.and(query::OptionalAccessToken::from_ws_header())
.boxed()
}
// #[cfg(test)]
// mod test {

View File

@@ -19,7 +19,7 @@ use super::receiver::Receiver;
use crate::{
config,
messages::Event,
parse_client_request::subscription::{Stream::Public, Subscription, Timeline},
parse_client_request::{Stream::Public, Subscription, Timeline},
};
use futures::{
Async::{self, NotReady, Ready},

View File

@@ -0,0 +1,103 @@
use super::ClientAgent;
use warp::ws::WebSocket;
use futures::{future::Future, stream::Stream, Async};
use log;
use std::time::{Duration, Instant};
pub struct EventStream;
impl EventStream {
/// Send a stream of replies to a WebSocket client.
pub fn to_ws(
socket: WebSocket,
mut client_agent: ClientAgent,
update_interval: Duration,
) -> impl Future<Item = (), Error = ()> {
let (ws_tx, mut ws_rx) = socket.split();
let timeline = client_agent.subscription.timeline;
// Create a pipe
let (tx, rx) = futures::sync::mpsc::unbounded();
// Send one end of it to a different thread and tell that end to forward whatever it gets
// on to the websocket client
warp::spawn(
rx.map_err(|()| -> warp::Error { unreachable!() })
.forward(ws_tx)
.map(|_r| ())
.map_err(|e| match e.to_string().as_ref() {
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
_ => log::warn!("websocket send error: {}", e),
}),
);
// Yield new events for as long as the client is still connected
let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while(
move |_| match ws_rx.poll() {
Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true),
Ok(Async::Ready(None)) => {
// TODO: consider whether we should manually drop closed connections here
log::info!("Client closed WebSocket connection for {:?}", timeline);
futures::future::ok(false)
}
Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => {
// no err, just closed Unix socket
log::info!("Client closed WebSocket connection for {:?}", timeline);
futures::future::ok(false)
}
Err(e) => {
log::warn!("Error in {:?}: {}", timeline, e);
futures::future::ok(false)
}
},
);
let mut time = Instant::now();
// Every time you get an event from that stream, send it through the pipe
event_stream
.for_each(move |_instant| {
if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
tx.unbounded_send(warp::ws::Message::text(msg.to_json_string()))
.expect("No send error");
};
if time.elapsed() > Duration::from_secs(30) {
tx.unbounded_send(warp::ws::Message::text("{}"))
.expect("Can ping");
time = Instant::now();
}
Ok(())
})
.then(move |result| {
// TODO: consider whether we should manually drop closed connections here
log::info!("WebSocket connection for {:?} closed.", timeline);
result
})
.map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e))
}
pub fn to_sse(
mut client_agent: ClientAgent,
connection: warp::sse::Sse,
update_interval: Duration,
) -> impl warp::reply::Reply {
let event_stream =
tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| {
match client_agent.poll() {
Ok(Async::Ready(Some(event))) => Some((
warp::sse::event(event.event_name()),
warp::sse::data(event.payload().unwrap_or_else(String::new)),
)),
_ => None,
}
});
connection.reply(
warp::sse::keep_alive()
.interval(Duration::from_secs(30))
.text("thump".to_string())
.stream(event_stream),
)
}
}
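Both constructors are driven from the route handlers in `main.rs`, condensed here:

// SSE route handler:
EventStream::to_sse(client_agent, sse_connection_to_client, sse_interval)
// WebSocket route handler, inside the upgrade callback:
ws.on_upgrade(move |socket| EventStream::to_ws(socket, client_agent, ws_update_interval))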

View File

@@ -1,103 +1,10 @@
//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis.
pub mod client_agent;
pub mod receiver;
pub mod redis;
pub use client_agent::ClientAgent;
use futures::{future::Future, stream::Stream, Async};
use log;
use std::time::{Duration, Instant};
mod client_agent;
mod receiver;
mod redis;
mod event_stream;
/// Send a stream of replies to a Server Sent Events client.
pub fn send_updates_to_sse(
mut client_agent: ClientAgent,
connection: warp::sse::Sse,
update_interval: Duration,
) -> impl warp::reply::Reply {
let event_stream =
tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| {
match client_agent.poll() {
Ok(Async::Ready(Some(event))) => Some((
warp::sse::event(event.event_name()),
warp::sse::data(event.payload().unwrap_or_else(String::new)),
)),
_ => None,
}
});
pub use {client_agent::ClientAgent, event_stream::EventStream};
connection.reply(
warp::sse::keep_alive()
.interval(Duration::from_secs(30))
.text("thump".to_string())
.stream(event_stream),
)
}
use warp::ws::WebSocket;
/// Send a stream of replies to a WebSocket client.
pub fn send_updates_to_ws(
socket: WebSocket,
mut client_agent: ClientAgent,
update_interval: Duration,
) -> impl Future<Item = (), Error = ()> {
let (ws_tx, mut ws_rx) = socket.split();
let timeline = client_agent.subscription.timeline;
// Create a pipe
let (tx, rx) = futures::sync::mpsc::unbounded();
// Send one end of it to a different thread and tell that end to forward whatever it gets
// on to the websocket client
warp::spawn(
rx.map_err(|()| -> warp::Error { unreachable!() })
.forward(ws_tx)
.map(|_r| ())
.map_err(|e| match e.to_string().as_ref() {
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
_ => log::warn!("websocket send error: {}", e),
}),
);
// Yield new events for as long as the client is still connected
let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while(
move |_| match ws_rx.poll() {
Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true),
Ok(Async::Ready(None)) => {
// TODO: consider whether we should manually drop closed connections here
log::info!("Client closed WebSocket connection for {:?}", timeline);
futures::future::ok(false)
}
Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => {
// no err, just closed Unix socket
log::info!("Client closed WebSocket connection for {:?}", timeline);
futures::future::ok(false)
}
Err(e) => {
log::warn!("Error in {:?}: {}", timeline, e);
futures::future::ok(false)
}
},
);
let mut time = Instant::now();
// Every time you get an event from that stream, send it through the pipe
event_stream
.for_each(move |_instant| {
if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
tx.unbounded_send(warp::ws::Message::text(msg.to_json_string()))
.expect("No send error");
};
if time.elapsed() > Duration::from_secs(30) {
tx.unbounded_send(warp::ws::Message::text("{}"))
.expect("Can ping");
time = Instant::now();
}
Ok(())
})
.then(move |result| {
// TODO: consider whether we should manually drop closed connections here
log::info!("WebSocket connection for {:?} closed.", timeline);
result
})
.map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e))
}

View File

@@ -1,5 +1,5 @@
use crate::messages::Event;
use crate::parse_client_request::subscription::Timeline;
use crate::parse_client_request::Timeline;
use std::{
collections::{HashMap, VecDeque},
fmt,

View File

@@ -3,7 +3,7 @@
//! unsubscriptions to/from Redis.
mod message_queues;
use crate::{
config::{self, RedisInterval},
config,
messages::Event,
parse_client_request::{Stream, Timeline},
pubsub_cmd,
@@ -12,7 +12,11 @@ use crate::{
use futures::{Async, Poll};
use lru::LruCache;
pub use message_queues::{MessageQueues, MsgQueue};
use std::{collections::HashMap, net, time::Instant};
use std::{
collections::HashMap,
net,
time::{Duration, Instant},
};
use tokio::io::Error;
use uuid::Uuid;
@@ -21,7 +25,7 @@ use uuid::Uuid;
pub struct Receiver {
pub pubsub_connection: RedisStream,
secondary_redis_connection: net::TcpStream,
redis_poll_interval: RedisInterval,
redis_poll_interval: Duration,
redis_polled_at: Instant,
timeline: Timeline,
manager_id: Uuid,
@@ -29,8 +33,12 @@ pub struct Receiver {
clients_per_timeline: HashMap<Timeline, i32>,
cache: Cache,
}
#[derive(Debug)]
pub struct Cache {
// TODO: eventually, it might make sense to have Mastodon publish to timelines with
// the tag number instead of the tag name. This would save us from dealing
// with a cache here and would be consistent with how lists/users are handled.
id_to_hashtag: LruCache<i64, String>,
pub hashtag_to_id: LruCache<String, i64>,
}
@@ -135,7 +143,7 @@ impl futures::stream::Stream for Receiver {
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let (timeline, id) = (self.timeline.clone(), self.manager_id);
if self.redis_polled_at.elapsed() > *self.redis_poll_interval {
if self.redis_polled_at.elapsed() > self.redis_poll_interval {
self.pubsub_connection
.poll_redis(&mut self.cache.hashtag_to_id, &mut self.msg_queues);
self.redis_polled_at = Instant::now();

View File

@@ -1,13 +1,13 @@
use super::redis_cmd;
use crate::config::{RedisConfig, RedisInterval, RedisNamespace};
use crate::config::RedisConfig;
use crate::err;
use std::{io::Read, io::Write, net, time};
use std::{io::Read, io::Write, net, time::Duration};
pub struct RedisConn {
pub primary: net::TcpStream,
pub secondary: net::TcpStream,
pub namespace: RedisNamespace,
pub polling_interval: RedisInterval,
pub namespace: Option<String>,
pub polling_interval: Duration,
}
fn send_password(mut conn: net::TcpStream, password: &str) -> net::TcpStream {
@@ -68,7 +68,7 @@ impl RedisConn {
conn = send_password(conn, &password);
}
conn = send_test_ping(conn);
conn.set_read_timeout(Some(time::Duration::from_millis(10)))
conn.set_read_timeout(Some(Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
if let Some(db) = &*redis_cfg.db {
conn = set_db(conn, db);
@@ -86,8 +86,8 @@ impl RedisConn {
Self {
primary: primary_conn,
secondary: secondary_conn,
namespace: redis_cfg.namespace,
polling_interval: redis_cfg.polling_interval,
namespace: redis_cfg.namespace.clone(),
polling_interval: *redis_cfg.polling_interval,
}
}
}

View File

@@ -18,7 +18,7 @@
//! three characters, the second is a bulk string with ten characters, and the third is a
//! bulk string with 1,386 characters.
use crate::{log_fatal, messages::Event, parse_client_request::subscription::Timeline};
use crate::{log_fatal, messages::Event, parse_client_request::Timeline};
use lru::LruCache;
type Parser<'a, Item> = Result<(Item, &'a str), ParseErr>;
#[derive(Debug)]

View File

@@ -1,7 +1,6 @@
use super::super::receiver::MessageQueues;
use super::redis_msg::{ParseErr, RedisMsg};
use crate::{config::RedisNamespace, log_fatal};
use crate::log_fatal;
use futures::{Async, Poll};
use lru::LruCache;
use std::{error::Error, io::Read, net};
@@ -11,7 +10,7 @@ use tokio::io::AsyncRead;
pub struct RedisStream {
pub inner: net::TcpStream,
incoming_raw_msg: String,
pub namespace: RedisNamespace,
pub namespace: Option<String>,
}
impl RedisStream {
@@ -19,10 +18,10 @@ impl RedisStream {
RedisStream {
inner,
incoming_raw_msg: String::new(),
namespace: RedisNamespace(None),
namespace: None,
}
}
pub fn with_namespace(self, namespace: RedisNamespace) -> Self {
pub fn with_namespace(self, namespace: Option<String>) -> Self {
RedisStream { namespace, ..self }
}
// Text comes in from redis as a raw stream, which could be more than one message and
@@ -41,7 +40,7 @@ impl RedisStream {
self.incoming_raw_msg.push_str(&raw_utf);
match process_messages(
self.incoming_raw_msg.clone(),
&mut self.namespace.0,
&mut self.namespace,
hashtag_to_id_cache,
queues,
) {