diff --git a/Cargo.lock b/Cargo.lock index 9f61f78..7ed6a10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -386,17 +386,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.3.5" +version = "0.3.6" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", "futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)", - "lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", "postgres 0.16.0-rc.2 (git+https://github.com/sfackler/rust-postgres.git)", "postgres-openssl 0.2.0-rc.1 (git+https://github.com/sfackler/rust-postgres.git)", "pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", - "regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)", "serde_derive 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)", "serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index 679c4e0..796dbad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "flodgatt" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" -version = "0.3.5" +version = "0.3.6" authors = ["Daniel Long Sockwell "] edition = "2018" @@ -10,7 +10,6 @@ log = "0.4.6" futures = "0.1.26" tokio = "0.1.19" warp = "0.1.15" -regex = "1.1.5" serde_json = "1.0.39" serde_derive = "1.0.90" serde = "1.0.90" @@ -18,7 +17,6 @@ pretty_env_logger = "0.3.0" postgres = { git = "https://github.com/sfackler/rust-postgres.git" } uuid = { version = "0.7", features = ["v4"] } dotenv = "0.14.0" -lazy_static = "1.3.0" postgres-openssl = { git = "https://github.com/sfackler/rust-postgres.git"} url = "2.1.0" diff --git a/src/config/deployment_cfg.rs b/src/config/deployment_cfg.rs new file mode 100644 index 0000000..12ba5c5 --- /dev/null +++ b/src/config/deployment_cfg.rs @@ -0,0 +1,102 @@ +use crate::{err, maybe_update}; +use std::{ + collections::HashMap, + fmt, + net::{IpAddr, Ipv4Addr}, + os::unix::net::UnixListener, + time::Duration, +}; + +#[derive(Debug)] +pub struct DeploymentConfig<'a> { + pub env: String, + pub log_level: String, + pub address: IpAddr, + pub port: u16, + pub unix_socket: Option, + pub cors: Cors<'a>, + pub sse_interval: Duration, + pub ws_interval: Duration, +} + +pub struct Cors<'a> { + pub allowed_headers: Vec<&'a str>, + pub allowed_methods: Vec<&'a str>, +} +impl fmt::Debug for Cors<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "allowed headers: {:?}\n allowed methods: {:?}", + self.allowed_headers, self.allowed_methods + ) + } +} + +impl Default for DeploymentConfig<'_> { + fn default() -> Self { + Self { + env: "development".to_string(), + log_level: "error".to_string(), + address: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + port: 4000, + unix_socket: None, + cors: Cors { + allowed_methods: vec!["GET", "OPTIONS"], + allowed_headers: vec!["Authorization", "Accept", "Cache-Control"], + }, + sse_interval: Duration::from_millis(100), + ws_interval: Duration::from_millis(100), + } + } +} +impl DeploymentConfig<'_> { + pub fn from_env(env_vars: HashMap) -> Self { + Self::default() + .maybe_update_env(env_vars.get("NODE_ENV").map(String::from)) + .maybe_update_env(env_vars.get("RUST_ENV").map(String::from)) + .maybe_update_address( + env_vars + .get("BIND") + .map(|a| err::unwrap_or_die(a.parse().ok(), "BIND must be a valid address")), + ) + .maybe_update_port( + env_vars + .get("PORT") + .map(|port| err::unwrap_or_die(port.parse().ok(), "PORT must be a number")), + ) + .maybe_update_unix_socket( + env_vars + .get("SOCKET") + .map(|s| UnixListener::bind(s).unwrap()), + ) + .maybe_update_log_level(env_vars.get("RUST_LOG").map(|level| match level.as_ref() { + l @ "trace" | l @ "debug" | l @ "info" | l @ "warn" | l @ "error" => l.to_string(), + _ => err::die_with_msg("Invalid log level specified"), + })) + .maybe_update_sse_interval( + env_vars + .get("SSE_UPDATE_INTERVAL") + .map(|str| Duration::from_millis(str.parse().unwrap())), + ) + .maybe_update_ws_interval( + env_vars + .get("WS_UPDATE_INTERVAL") + .map(|str| Duration::from_millis(str.parse().unwrap())), + ) + .log() + } + + maybe_update!(maybe_update_env; env: String); + maybe_update!(maybe_update_port; port: u16); + maybe_update!(maybe_update_address; address: IpAddr); + maybe_update!(maybe_update_unix_socket; Some(unix_socket: UnixListener)); + maybe_update!(maybe_update_log_level; log_level: String); + maybe_update!(maybe_update_sse_interval; sse_interval: Duration); + maybe_update!(maybe_update_ws_interval; ws_interval: Duration); + + fn log(self) -> Self { + log::warn!("Using deployment configuration:\n {:#?}", &self); + self + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 0fba3a9..f040b5b 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,126 +1,29 @@ //! Configuration defaults. All settings with the prefix of `DEFAULT_` can be overridden //! by an environmental variable of the same name without that prefix (either by setting //! the variable at runtime or in the `.env` file) +mod deployment_cfg; mod postgres_cfg; mod redis_cfg; -pub use self::{postgres_cfg::PostgresConfig, redis_cfg::RedisConfig}; -use dotenv::dotenv; -use lazy_static::lazy_static; -use log::warn; -use std::{env, net}; -use url::Url; +pub use self::{ + deployment_cfg::DeploymentConfig, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig, +}; -const CORS_ALLOWED_METHODS: [&str; 2] = ["GET", "OPTIONS"]; -const CORS_ALLOWED_HEADERS: [&str; 3] = ["Authorization", "Accept", "Cache-Control"]; -// Postgres -// Deployment -const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000"; - -const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100; -const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100; -/// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver` -/// (on the order of 10ms rather than 50μs). Thus, changing this setting -/// would be a good place to start for performance improvements at the cost -/// of delaying all updates. -const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100; - -lazy_static! { - pub static ref SERVER_ADDR: net::SocketAddr = env::var("SERVER_ADDR") - .unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned()) - .parse() - .expect("static string"); - - /// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via SSE. - pub static ref SSE_UPDATE_INTERVAL: u64 = env::var("SSE_UPDATE_INTERVAL") - .map(|s| s.parse().expect("Valid config")) - .unwrap_or(DEFAULT_SSE_UPDATE_INTERVAL); - /// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via WS. - pub static ref WS_UPDATE_INTERVAL: u64 = env::var("WS_UPDATE_INTERVAL") - .map(|s| s.parse().expect("Valid config")) - .unwrap_or(DEFAULT_WS_UPDATE_INTERVAL); - /// Interval, in ms, at which the `Receiver` polls Redis. - pub static ref REDIS_POLL_INTERVAL: u64 = env::var("REDIS_POLL_INTERVAL") - .map(|s| s.parse().expect("Valid config")) - .unwrap_or(DEFAULT_REDIS_POLL_INTERVAL); -} - -/// Configure CORS for the API server -pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors { - warp::cors() - .allow_any_origin() - .allow_methods(CORS_ALLOWED_METHODS.to_vec()) - .allow_headers(CORS_ALLOWED_HEADERS.to_vec()) -} - -/// Initialize logging and read values from `src/.env` -pub fn logging_and_env() { - dotenv().ok(); - pretty_env_logger::init(); -} - -/// Configure Postgres and return a connection -pub fn postgres() -> PostgresConfig { - // use openssl::ssl::{SslConnector, SslMethod}; - // use postgres_openssl::MakeTlsConnector; - // let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - // builder.set_ca_file("/etc/ssl/cert.pem").unwrap(); - // let connector = MakeTlsConnector::new(builder.build()); - // TODO: add TLS support, remove `NoTls` - let pg_cfg = match &env::var("DATABASE_URL").ok() { - Some(url) => { - warn!("DATABASE_URL env variable set. Connecting to Postgres with that URL and ignoring any values set in DB_HOST, DB_USER, DB_NAME, DB_PASS, or DB_PORT."); - PostgresConfig::from_url(Url::parse(url).unwrap()) - } - None => PostgresConfig::default() - .maybe_update_user(env::var("USER").ok()) - .maybe_update_user(env::var("DB_USER").ok()) - .maybe_update_host(env::var("DB_HOST").ok()) - .maybe_update_password(env::var("DB_PASS").ok()) - .maybe_update_db(env::var("DB_NAME").ok()) - .maybe_update_sslmode(env::var("DB_SSLMODE").ok()), - }; - log::warn!( - "Connecting to Postgres with the following configuration:\n{:#?}", - &pg_cfg - ); - pg_cfg -} - -/// Configure Redis and return a pair of connections -pub fn redis() -> RedisConfig { - let redis_cfg = match &env::var("REDIS_URL") { - Ok(url) => { - warn!("REDIS_URL env variable set. Connecting to Redis with that URL and ignoring any values set in REDIS_HOST or DB_PORT."); - RedisConfig::from_url(Url::parse(url).unwrap()) - } - Err(_) => RedisConfig::default() - .maybe_update_host(env::var("REDIS_HOST").ok()) - .maybe_update_port(env::var("REDIS_PORT").ok()), - }.maybe_update_namespace(env::var("REDIS_NAMESPACE").ok()); - if let Some(user) = &redis_cfg.user { - log::error!( - "Username {} provided, but Redis does not need a username. Ignoring it", - user - ); - }; - log::warn!( - "Connecting to Redis with the following configuration:\n{:#?}", - &redis_cfg - ); - redis_cfg -} +// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver` +// (on the order of 10ms rather than 50μs). Thus, changing this setting +// would be a good place to start for performance improvements at the cost +// of delaying all updates. #[macro_export] macro_rules! maybe_update { - ($name:ident; $item: tt) => ( - pub fn $name(self, item: Option) -> Self{ + ($name:ident; $item: tt:$type:ty) => ( + pub fn $name(self, item: Option<$type>) -> Self { match item { Some($item) => Self{ $item, ..self }, None => Self { ..self } } }); - ($name:ident; Some($item: tt)) => ( - pub fn $name(self, item: Option) -> Self{ + ($name:ident; Some($item: tt: $type:ty)) => ( + fn $name(self, item: Option<$type>) -> Self{ match item { Some($item) => Self{ $item: Some($item), ..self }, None => Self { ..self } diff --git a/src/config/postgres_cfg.rs b/src/config/postgres_cfg.rs index 613635a..ab2733c 100644 --- a/src/config/postgres_cfg.rs +++ b/src/config/postgres_cfg.rs @@ -1,4 +1,5 @@ use crate::{err, maybe_update}; +use std::collections::HashMap; use url::Url; #[derive(Debug)] @@ -7,7 +8,7 @@ pub struct PostgresConfig { pub host: String, pub password: Option, pub database: String, - pub port: String, + pub port: u16, pub ssl_mode: String, } @@ -18,7 +19,7 @@ impl Default for PostgresConfig { host: "localhost".to_string(), password: None, database: "mastodon_development".to_string(), - port: "5432".to_string(), + port: 5432, ssl_mode: "prefer".to_string(), } } @@ -28,14 +29,36 @@ fn none_if_empty(item: &str) -> Option { } impl PostgresConfig { - maybe_update!(maybe_update_user; user); - maybe_update!(maybe_update_host; host); - maybe_update!(maybe_update_db; database); - maybe_update!(maybe_update_port; port); - maybe_update!(maybe_update_sslmode; ssl_mode); - maybe_update!(maybe_update_password; Some(password)); + /// Configure Postgres and return a connection + pub fn from_env(env_vars: HashMap) -> Self { + // use openssl::ssl::{SslConnector, SslMethod}; + // use postgres_openssl::MakeTlsConnector; + // let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + // builder.set_ca_file("/etc/ssl/cert.pem").unwrap(); + // let connector = MakeTlsConnector::new(builder.build()); + // TODO: add TLS support, remove `NoTls` + match env_vars.get("DATABASE_URL") { + Some(url) => { + log::warn!("DATABASE_URL env variable set. Connecting to Postgres with that URL and ignoring any values set in DB_HOST, DB_USER, DB_NAME, DB_PASS, or DB_PORT."); + PostgresConfig::from_url(Url::parse(url).unwrap()) + } + None => Self::default() + .maybe_update_user(env_vars.get("USER").map(String::from)) + .maybe_update_user(env_vars.get("DB_USER").map(String::from)) + .maybe_update_host(env_vars.get("DB_HOST").map(String::from)) + .maybe_update_password(env_vars.get("DB_PASS").map(String::from)) + .maybe_update_db(env_vars.get("DB_NAME").map(String::from)) + .maybe_update_sslmode(env_vars.get("DB_SSLMODE").map(String::from))} + .log() + } + maybe_update!(maybe_update_user; user: String); + maybe_update!(maybe_update_host; host: String); + maybe_update!(maybe_update_db; database: String); + maybe_update!(maybe_update_port; port: u16); + maybe_update!(maybe_update_sslmode; ssl_mode: String); + maybe_update!(maybe_update_password; Some(password: String)); - pub fn from_url(url: Url) -> Self { + fn from_url(url: Url) -> Self { let (mut user, mut host, mut sslmode, mut password) = (None, None, None, None); for (k, v) in url.query_pairs() { match k.to_string().as_str() { @@ -57,7 +80,11 @@ impl PostgresConfig { .maybe_update_user(none_if_empty(url.username())) .maybe_update_host(url.host_str().filter(|h| !h.is_empty()).map(String::from)) .maybe_update_password(url.password().map(String::from)) - .maybe_update_port(url.port().map(|port_num| port_num.to_string())) + .maybe_update_port(url.port()) .maybe_update_db(none_if_empty(&url.path()[1..])) } + fn log(self) -> Self { + log::warn!("Postgres configuration:\n{:#?}", &self); + self + } } diff --git a/src/config/redis_cfg.rs b/src/config/redis_cfg.rs index 019e92a..1373d31 100644 --- a/src/config/redis_cfg.rs +++ b/src/config/redis_cfg.rs @@ -1,22 +1,20 @@ use crate::{err, maybe_update}; +use std::{collections::HashMap, time::Duration}; use url::Url; fn none_if_empty(item: &str) -> Option { - if item.is_empty() { - None - } else { - Some(item.to_string()) - } + Some(item).filter(|i| !i.is_empty()).map(String::from) } #[derive(Debug)] pub struct RedisConfig { pub user: Option, pub password: Option, - pub port: String, + pub port: u16, pub host: String, pub db: Option, pub namespace: Option, + pub polling_interval: Duration, } impl Default for RedisConfig { fn default() -> Self { @@ -24,14 +22,31 @@ impl Default for RedisConfig { user: None, password: None, db: None, - port: "6379".to_string(), + port: 6379, host: "127.0.0.1".to_string(), namespace: None, + polling_interval: Duration::from_millis(100), } } } impl RedisConfig { - pub fn from_url(url: Url) -> Self { + pub fn from_env(env_vars: HashMap) -> Self { + match env_vars.get("REDIS_URL") { + Some(url) => { + log::warn!("REDIS_URL env variable set. Connecting to Redis with that URL and ignoring any values set in REDIS_HOST or DB_PORT."); + Self::from_url(Url::parse(url).unwrap()) + } + None => RedisConfig::default() + .maybe_update_host(env_vars.get("REDIS_HOST").map(String::from)) + .maybe_update_port(env_vars.get("REDIS_PORT").map(|p| err::unwrap_or_die( + p.parse().ok(),"REDIS_PORT must be a number."))), + } + .maybe_update_namespace(env_vars.get("REDIS_NAMESPACE").map(String::from)) + .maybe_update_polling_interval(env_vars.get("REDIS_POLL_INTERVAL") + .map(|str| Duration::from_millis(str.parse().unwrap()))).log() + } + + fn from_url(url: Url) -> Self { let mut password = url.password().as_ref().map(|str| str.to_string()); let mut db = none_if_empty(&url.path()[1..]); for (k, v) in url.query_pairs() { @@ -41,16 +56,32 @@ impl RedisConfig { _ => { err::die_with_msg(format!("Unsupported parameter {} in REDIS_URL.\n Flodgatt supports only `password` and `db` parameters.", k))} } } + let user = none_if_empty(url.username()); + if let Some(user) = &user { + log::error!( + "Username {} provided, but Redis does not need a username. Ignoring it", + user + ); + } RedisConfig { - user: none_if_empty(url.username()), - host: err::unwrap_or_die(url.host_str(), "Missing or invalid host in REDIS_URL"), + user, + host: err::unwrap_or_die(url.host_str(), "Missing or invalid host in REDIS_URL") + .to_string(), port: err::unwrap_or_die(url.port(), "Missing or invalid port in REDIS_URL"), namespace: None, password, db, + polling_interval: Duration::from_millis(100), } } - maybe_update!(maybe_update_host; host); - maybe_update!(maybe_update_port; port); - maybe_update!(maybe_update_namespace; Some(namespace)); + + maybe_update!(maybe_update_host; host: String); + maybe_update!(maybe_update_port; port: u16); + maybe_update!(maybe_update_namespace; Some(namespace: String)); + maybe_update!(maybe_update_polling_interval; polling_interval: Duration); + + fn log(self) -> Self { + log::warn!("Redis configuration:\n{:#?},", &self); + self + } } diff --git a/src/err.rs b/src/err.rs index a9eff65..8f4e422 100644 --- a/src/err.rs +++ b/src/err.rs @@ -5,12 +5,20 @@ pub fn die_with_msg(msg: impl Display) -> ! { eprintln!("FATAL ERROR: {}", msg); std::process::exit(1); } -pub fn unwrap_or_die(s: Option, msg: &str) -> String { + +#[macro_export] +macro_rules! dbg_and_die { + ($msg:expr) => { + let message = format!("FATAL ERROR: {}", $msg); + dbg!(message); + std::process::exit(1); + }; +} +pub fn unwrap_or_die(s: Option, msg: &str) -> T { s.unwrap_or_else(|| { eprintln!("FATAL ERROR: {}", msg); std::process::exit(1) }) - .to_string() } #[derive(Serialize)] diff --git a/src/main.rs b/src/main.rs index ea1bb94..9cd9e4b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,34 @@ use flodgatt::{ - config, err, + config, dbg_and_die, err, parse_client_request::{sse, user, ws}, redis_to_client_stream::{self, ClientAgent}, }; use log::warn; +use std::{collections::HashMap, env, net}; use warp::{ws::Ws2, Filter as WarpFilter}; fn main() { - config::logging_and_env(); - let client_agent_sse = ClientAgent::blank(); + dotenv::from_filename( + match env::var("ENV").ok().as_ref().map(String::as_str) { + Some("production") => ".env.production", + Some("development") | None => ".env", + Some(_) => err::die_with_msg("Unknown ENV variable specified.\n Valid options are: `production` or `development`."), + }).ok(); + let env_vars: HashMap<_, _> = dotenv::vars().collect(); + pretty_env_logger::init(); + + let cfg = config::DeploymentConfig::from_env(env_vars.clone()); + let redis_cfg = config::RedisConfig::from_env(env_vars.clone()); + let postgres_cfg = config::PostgresConfig::from_env(env_vars.clone()); + + let client_agent_sse = ClientAgent::blank(redis_cfg); let client_agent_ws = client_agent_sse.clone_with_shared_receiver(); - let pg_conn = user::PostgresConn::new(); + let pg_conn = user::PostgresConn::new(postgres_cfg); warn!("Streaming server initialized and ready to accept connections"); // Server Sent Events + let sse_update_interval = cfg.ws_interval; let sse_routes = sse::extract_user_or_reject(pg_conn.clone()) .and(warp::sse()) .map( @@ -24,13 +38,18 @@ fn main() { // Assign ClientAgent to generate stream of updates for the user/timeline pair client_agent.init_for_user(user); // send the updates through the SSE connection - redis_to_client_stream::send_updates_to_sse(client_agent, sse_connection_to_client) + redis_to_client_stream::send_updates_to_sse( + client_agent, + sse_connection_to_client, + sse_update_interval, + ) }, ) .with(warp::reply::with::header("Connection", "keep-alive")) .recover(err::handle_errors); // WebSocket + let ws_update_interval = cfg.ws_interval; let websocket_routes = ws::extract_user_or_reject(pg_conn.clone()) .and(warp::ws::ws2()) .map(move |user: user::User, ws: Ws2| { @@ -44,14 +63,27 @@ fn main() { ( ws.on_upgrade(move |socket| { - redis_to_client_stream::send_updates_to_ws(socket, client_agent) + redis_to_client_stream::send_updates_to_ws( + socket, + client_agent, + ws_update_interval, + ) }), token, ) }) .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); - let cors = config::cross_origin_resource_sharing(); + let cors = warp::cors() + .allow_any_origin() + .allow_methods(cfg.cors.allowed_methods) + .allow_headers(cfg.cors.allowed_headers); - warp::serve(websocket_routes.or(sse_routes).with(cors)).run(*config::SERVER_ADDR); + let server_addr = net::SocketAddr::new(cfg.address, cfg.port); + + if let Some(_socket) = cfg.unix_socket { + dbg_and_die!("Unix socket support not yet implemented"); + } else { + warp::serve(websocket_routes.or(sse_routes).with(cors)).run(server_addr); + } } diff --git a/src/parse_client_request/user/postgres.rs b/src/parse_client_request/user/postgres.rs index ca7935c..f93121b 100644 --- a/src/parse_client_request/user/postgres.rs +++ b/src/parse_client_request/user/postgres.rs @@ -6,12 +6,11 @@ use std::sync::{Arc, Mutex}; #[derive(Clone)] pub struct PostgresConn(pub Arc>); impl PostgresConn { - pub fn new() -> Self { - let pg_cfg = config::postgres(); + pub fn new(pg_cfg: config::PostgresConfig) -> Self { let mut con = postgres::Client::configure(); con.user(&pg_cfg.user) .host(&pg_cfg.host) - .port(pg_cfg.port.parse::().unwrap()) + .port(pg_cfg.port) .dbname(&pg_cfg.database); if let Some(password) = &pg_cfg.password { con.password(password); diff --git a/src/redis_to_client_stream/client_agent.rs b/src/redis_to_client_stream/client_agent.rs index fdf0404..e4e8451 100644 --- a/src/redis_to_client_stream/client_agent.rs +++ b/src/redis_to_client_stream/client_agent.rs @@ -16,7 +16,7 @@ //! communicate with Redis, it we create a new `ClientAgent` for //! each new client connection (each in its own thread). use super::receiver::Receiver; -use crate::parse_client_request::user::User; +use crate::{config, parse_client_request::user::User}; use futures::{Async, Poll}; use serde_json::{json, Value}; use std::sync; @@ -24,7 +24,7 @@ use tokio::io::Error; use uuid::Uuid; /// Struct for managing all Redis streams. -#[derive(Clone, Default, Debug)] +#[derive(Clone, Debug)] pub struct ClientAgent { receiver: sync::Arc>, id: uuid::Uuid, @@ -34,9 +34,9 @@ pub struct ClientAgent { impl ClientAgent { /// Create a new `ClientAgent` with no shared data. - pub fn blank() -> Self { + pub fn blank(redis_cfg: config::RedisConfig) -> Self { ClientAgent { - receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())), + receiver: sync::Arc::new(sync::Mutex::new(Receiver::new(redis_cfg))), id: Uuid::default(), target_timeline: String::new(), current_user: User::default(), diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs index 3040bf5..54c318a 100644 --- a/src/redis_to_client_stream/mod.rs +++ b/src/redis_to_client_stream/mod.rs @@ -14,18 +14,16 @@ use std::time; pub fn send_updates_to_sse( mut client_agent: ClientAgent, connection: warp::sse::Sse, + update_interval: time::Duration, ) -> impl warp::reply::Reply { - let event_stream = tokio::timer::Interval::new( - time::Instant::now(), - time::Duration::from_millis(*config::SSE_UPDATE_INTERVAL), - ) - .filter_map(move |_| match client_agent.poll() { - Ok(Async::Ready(Some(json_value))) => Some(( - warp::sse::event(json_value["event"].clone().to_string()), - warp::sse::data(json_value["payload"].clone()), - )), - _ => None, - }); + let event_stream = tokio::timer::Interval::new(time::Instant::now(), update_interval) + .filter_map(move |_| match client_agent.poll() { + Ok(Async::Ready(Some(json_value))) => Some(( + warp::sse::event(json_value["event"].clone().to_string()), + warp::sse::data(json_value["payload"].clone()), + )), + _ => None, + }); connection.reply(warp::sse::keep(event_stream, None)) } @@ -34,6 +32,7 @@ pub fn send_updates_to_sse( pub fn send_updates_to_ws( socket: warp::ws::WebSocket, mut stream: ClientAgent, + update_interval: time::Duration, ) -> impl futures::future::Future { let (ws_tx, mut ws_rx) = socket.split(); @@ -50,22 +49,19 @@ pub fn send_updates_to_ws( ); // Yield new events for as long as the client is still connected - let event_stream = tokio::timer::Interval::new( - time::Instant::now(), - time::Duration::from_millis(*config::WS_UPDATE_INTERVAL), - ) - .take_while(move |_| match ws_rx.poll() { - Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true), - Ok(Async::Ready(None)) => { - // TODO: consider whether we should manually drop closed connections here - log::info!("Client closed WebSocket connection"); - futures::future::ok(false) - } - Err(e) => { - log::warn!("{}", e); - futures::future::ok(false) - } - }); + let event_stream = tokio::timer::Interval::new(time::Instant::now(), update_interval) + .take_while(move |_| match ws_rx.poll() { + Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true), + Ok(Async::Ready(None)) => { + // TODO: consider whether we should manually drop closed connections here + log::info!("Client closed WebSocket connection"); + futures::future::ok(false) + } + Err(e) => { + log::warn!("{}", e); + futures::future::ok(false) + } + }); // Every time you get an event from that stream, send it through the pipe event_stream diff --git a/src/redis_to_client_stream/receiver.rs b/src/redis_to_client_stream/receiver.rs index f308cf7..02b7851 100644 --- a/src/redis_to_client_stream/receiver.rs +++ b/src/redis_to_client_stream/receiver.rs @@ -1,8 +1,8 @@ //! Receives data from Redis, sorts it by `ClientAgent`, and stores it until //! polled by the correct `ClientAgent`. Also manages sububscriptions and //! unsubscriptions to/from Redis. -use super::{redis_cmd, redis_stream, redis_stream::RedisConn}; -use crate::{config, pubsub_cmd}; +use super::{config, redis_cmd, redis_stream, redis_stream::RedisConn}; +use crate::pubsub_cmd; use futures::{Async, Poll}; use serde_json::Value; use std::{collections, net, time}; @@ -15,6 +15,7 @@ pub struct Receiver { pub pubsub_connection: net::TcpStream, secondary_redis_connection: net::TcpStream, pub redis_namespace: Option, + redis_poll_interval: time::Duration, redis_polled_at: time::Instant, timeline: String, manager_id: Uuid, @@ -26,17 +27,19 @@ pub struct Receiver { impl Receiver { /// Create a new `Receiver`, with its own Redis connections (but, as yet, no /// active subscriptions). - pub fn new() -> Self { + pub fn new(redis_cfg: config::RedisConfig) -> Self { let RedisConn { primary: pubsub_connection, secondary: secondary_redis_connection, namespace: redis_namespace, - } = RedisConn::new(); + polling_interval: redis_poll_interval, + } = RedisConn::new(redis_cfg); Self { pubsub_connection, secondary_redis_connection, redis_namespace, + redis_poll_interval, redis_polled_at: time::Instant::now(), timeline: String::new(), manager_id: Uuid::default(), @@ -123,12 +126,6 @@ impl Receiver { } } -impl Default for Receiver { - fn default() -> Self { - Receiver::new() - } -} - /// The stream that the ClientAgent polls to learn about new messages. impl futures::stream::Stream for Receiver { type Item = Value; @@ -142,9 +139,7 @@ impl futures::stream::Stream for Receiver { /// been polled lately. fn poll(&mut self) -> Poll, Self::Error> { let timeline = self.timeline.clone(); - if self.redis_polled_at.elapsed() - > time::Duration::from_millis(*config::REDIS_POLL_INTERVAL) - { + if self.redis_polled_at.elapsed() > self.redis_poll_interval { redis_stream::AsyncReadableStream::poll_redis(self); self.redis_polled_at = time::Instant::now(); } diff --git a/src/redis_to_client_stream/redis_stream.rs b/src/redis_to_client_stream/redis_stream.rs index 1173b14..4295d0c 100644 --- a/src/redis_to_client_stream/redis_stream.rs +++ b/src/redis_to_client_stream/redis_stream.rs @@ -9,10 +9,10 @@ pub struct RedisConn { pub primary: net::TcpStream, pub secondary: net::TcpStream, pub namespace: Option, + pub polling_interval: time::Duration, } impl RedisConn { - pub fn new() -> Self { - let redis_cfg = config::redis(); + pub fn new(redis_cfg: config::RedisConfig) -> Self { let addr = format!("{}:{}", redis_cfg.host, redis_cfg.port); let mut pubsub_connection = net::TcpStream::connect(addr.clone()).expect("Can connect to Redis"); @@ -49,6 +49,7 @@ impl RedisConn { primary: pubsub_connection, secondary: secondary_redis_connection, namespace: redis_cfg.namespace, + polling_interval: redis_cfg.polling_interval, } } }