diff --git a/Cargo.lock b/Cargo.lock index e4525e7..0ca68d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -416,6 +416,8 @@ dependencies = [ "serde 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)", "serde_derive 1.0.91 (registry+https://github.com/rust-lang/crates.io-index)", "serde_json 1.0.39 (registry+https://github.com/rust-lang/crates.io-index)", + "strum 0.16.0 (registry+https://github.com/rust-lang/crates.io-index)", + "strum_macros 0.16.0 (registry+https://github.com/rust-lang/crates.io-index)", "tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)", "url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "uuid 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1059,6 +1061,14 @@ dependencies = [ "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "proc-macro2" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "quick-error" version = "1.2.2" @@ -1072,6 +1082,14 @@ dependencies = [ "proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "quote" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rand" version = "0.6.5" @@ -1467,6 +1485,22 @@ dependencies = [ "unicode-normalization 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "strum" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "strum_macros" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "heck 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro2 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "subtle" version = "1.0.0" @@ -1482,6 +1516,16 @@ dependencies = [ "unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "syn" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", + "unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "synstructure" version = "0.10.1" @@ -1859,6 +1903,11 @@ name = "unicode-xid" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "unicode-xid" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "url" version = "2.1.0" @@ -2127,8 +2176,10 @@ dependencies = [ "checksum ppv-lite86 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)" = "e3cbf9f658cdb5000fcf6f362b8ea2ba154b9f146a61c7a20d647034c6b6561b" "checksum pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df8b3f4e0475def7d9c2e5de8e5a1306949849761e107b360d03e98eafaffd61" "checksum proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)" = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759" +"checksum proc-macro2 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "90cf5f418035b98e655e9cdb225047638296b862b42411c4e45bb88d700f7fc0" "checksum quick-error 1.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "9274b940887ce9addde99c4eee6b5c44cc494b182b97e73dc8ffdcb3397fd3f0" "checksum quote 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "faf4799c5d274f3868a4aae320a0a182cbd2baee377b378f080e16a23e9d80db" +"checksum quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "053a8c8bcc71fcce321828dc897a98ab9760bef03a4fc36693c231e5b3216cfe" "checksum rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)" = "6d71dacdc3c88c1fde3885a3be3fbab9f35724e6ce99467f7d9c5026132184ca" "checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412" "checksum rand_chacha 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "556d3a1ca6600bfcbab7c7c91ccb085ac7fbbcd70e008a98742e7847f4f7bcef" @@ -2178,8 +2229,11 @@ dependencies = [ "checksum state_machine_future 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "530e1d624baae485bce12e6647acb76aafa253346ee8a16751974eed5a24b13d" "checksum string 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "b639411d0b9c738748b5397d5ceba08e648f4f1992231aa859af1a017f31f60b" "checksum stringprep 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1" +"checksum strum 0.16.0 (registry+https://github.com/rust-lang/crates.io-index)" = "6138f8f88a16d90134763314e3fc76fa3ed6a7db4725d6acf9a3ef95a3188d22" +"checksum strum_macros 0.16.0 (registry+https://github.com/rust-lang/crates.io-index)" = "0054a7df764039a6cd8592b9de84be4bec368ff081d203a7d5371cbfa8e65c81" "checksum subtle 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2d67a5a62ba6e01cb2192ff309324cb4875d0c451d55fe2319433abe7a05a8ee" "checksum syn 0.15.34 (registry+https://github.com/rust-lang/crates.io-index)" = "a1393e4a97a19c01e900df2aec855a29f71cf02c402e2f443b8d2747c25c5dbe" +"checksum syn 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "66850e97125af79138385e9b88339cbcd037e3f28ceab8c5ad98e64f0f1f80bf" "checksum synstructure 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)" = "73687139bf99285483c96ac0add482c3776528beac1d97d444f6e91f203a2015" "checksum tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9" "checksum termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96d6098003bde162e4277c70665bd87c326f5a0c3f3fbfb285787fa482d54e6e" @@ -2216,6 +2270,7 @@ dependencies = [ "checksum unicode-segmentation 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1967f4cdfc355b37fd76d2a954fb2ed3871034eb4f26d60537d88795cfc332a9" "checksum unicode-width 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7007dbd421b92cc6e28410fe7362e2e0a2503394908f417b68ec8d1c364c4e20" "checksum unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc" +"checksum unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c" "checksum url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "75b414f6c464c879d7f9babf951f23bc3743fb7313c081b2e6ca719067ea9d61" "checksum urlencoding 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3df3561629a8bb4c57e5a2e4c43348d9e29c7c29d9b1c4c1f47166deca8f37ed" "checksum utf-8 0.7.5 (registry+https://github.com/rust-lang/crates.io-index)" = "05e42f7c18b8f902290b009cde6d651262f956c98bc51bca4cd1d511c9cd85c7" diff --git a/Cargo.toml b/Cargo.toml index 2fa483c..a463b2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,8 @@ uuid = { version = "0.7", features = ["v4"] } dotenv = "0.14.0" postgres-openssl = { git = "https://github.com/sfackler/rust-postgres.git"} url = "2.1.0" +strum = "0.16.0" +strum_macros = "0.16.0" [dev-dependencies] criterion = "0.3" diff --git a/src/config/deployment_cfg.rs b/src/config/deployment_cfg.rs index 12ba5c5..d94a79e 100644 --- a/src/config/deployment_cfg.rs +++ b/src/config/deployment_cfg.rs @@ -1,102 +1,31 @@ -use crate::{err, maybe_update}; -use std::{ - collections::HashMap, - fmt, - net::{IpAddr, Ipv4Addr}, - os::unix::net::UnixListener, - time::Duration, -}; +use super::{deployment_cfg_types::*, EnvVar}; -#[derive(Debug)] +#[derive(Debug, Default)] pub struct DeploymentConfig<'a> { - pub env: String, - pub log_level: String, - pub address: IpAddr, - pub port: u16, - pub unix_socket: Option, + pub env: Env, + pub log_level: LogLevel, + pub address: FlodgattAddr, + pub port: Port, + pub unix_socket: Socket, pub cors: Cors<'a>, - pub sse_interval: Duration, - pub ws_interval: Duration, + pub sse_interval: SseInterval, + pub ws_interval: WsInterval, } -pub struct Cors<'a> { - pub allowed_headers: Vec<&'a str>, - pub allowed_methods: Vec<&'a str>, -} -impl fmt::Debug for Cors<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "allowed headers: {:?}\n allowed methods: {:?}", - self.allowed_headers, self.allowed_methods - ) - } -} - -impl Default for DeploymentConfig<'_> { - fn default() -> Self { - Self { - env: "development".to_string(), - log_level: "error".to_string(), - address: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), - port: 4000, - unix_socket: None, - cors: Cors { - allowed_methods: vec!["GET", "OPTIONS"], - allowed_headers: vec!["Authorization", "Accept", "Cache-Control"], - }, - sse_interval: Duration::from_millis(100), - ws_interval: Duration::from_millis(100), - } - } -} impl DeploymentConfig<'_> { - pub fn from_env(env_vars: HashMap) -> Self { - Self::default() - .maybe_update_env(env_vars.get("NODE_ENV").map(String::from)) - .maybe_update_env(env_vars.get("RUST_ENV").map(String::from)) - .maybe_update_address( - env_vars - .get("BIND") - .map(|a| err::unwrap_or_die(a.parse().ok(), "BIND must be a valid address")), - ) - .maybe_update_port( - env_vars - .get("PORT") - .map(|port| err::unwrap_or_die(port.parse().ok(), "PORT must be a number")), - ) - .maybe_update_unix_socket( - env_vars - .get("SOCKET") - .map(|s| UnixListener::bind(s).unwrap()), - ) - .maybe_update_log_level(env_vars.get("RUST_LOG").map(|level| match level.as_ref() { - l @ "trace" | l @ "debug" | l @ "info" | l @ "warn" | l @ "error" => l.to_string(), - _ => err::die_with_msg("Invalid log level specified"), - })) - .maybe_update_sse_interval( - env_vars - .get("SSE_UPDATE_INTERVAL") - .map(|str| Duration::from_millis(str.parse().unwrap())), - ) - .maybe_update_ws_interval( - env_vars - .get("WS_UPDATE_INTERVAL") - .map(|str| Duration::from_millis(str.parse().unwrap())), - ) - .log() - } - - maybe_update!(maybe_update_env; env: String); - maybe_update!(maybe_update_port; port: u16); - maybe_update!(maybe_update_address; address: IpAddr); - maybe_update!(maybe_update_unix_socket; Some(unix_socket: UnixListener)); - maybe_update!(maybe_update_log_level; log_level: String); - maybe_update!(maybe_update_sse_interval; sse_interval: Duration); - maybe_update!(maybe_update_ws_interval; ws_interval: Duration); - - fn log(self) -> Self { - log::warn!("Using deployment configuration:\n {:#?}", &self); - self + pub fn from_env(env: EnvVar) -> Self { + let mut cfg = Self { + env: Env::default().maybe_update(env.get("NODE_ENV")), + log_level: LogLevel::default().maybe_update(env.get("RUST_LOG")), + address: FlodgattAddr::default().maybe_update(env.get("BIND")), + port: Port::default().maybe_update(env.get("PORT")), + unix_socket: Socket::default().maybe_update(env.get("SOCKET")), + sse_interval: SseInterval::default().maybe_update(env.get("SSE_FREQ")), + ws_interval: WsInterval::default().maybe_update(env.get("WS_FREQ")), + cors: Cors::default(), + }; + cfg.env = cfg.env.maybe_update(env.get("RUST_ENV")); + log::info!("Using deployment configuration:\n {:#?}", &cfg); + cfg } } diff --git a/src/config/deployment_cfg_types.rs b/src/config/deployment_cfg_types.rs new file mode 100644 index 0000000..31a7c03 --- /dev/null +++ b/src/config/deployment_cfg_types.rs @@ -0,0 +1,104 @@ +use crate::from_env_var; +use std::{ + fmt, + net::{IpAddr, Ipv4Addr}, + os::unix::net::UnixListener, + str::FromStr, + time::Duration, +}; +use strum_macros::{EnumString, EnumVariantNames}; + +from_env_var!( + /// The current environment, which controls what file to read other ENV vars from + let name = Env; + let default: EnvInner = EnvInner::Development; + let (env_var, allowed_values) = ("RUST_ENV", format!("one of: {:?}", EnvInner::variants())); + let from_str = |s| EnvInner::from_str(s).ok(); +); +from_env_var!( + /// The address to run Flodgatt on + let name = FlodgattAddr; + let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded")); + let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)".to_string()); + let from_str = |s| match s { + "localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), + _ => s.parse().ok(), + }; +); +from_env_var!( + /// How verbosely Flodgatt should log messages + let name = LogLevel; + let default: LogLevelInner = LogLevelInner::Warn; + let (env_var, allowed_values) = ("RUST_LOG", "a valid address (e.g., 127.0.0.1)".to_string()); + let from_str = |s| LogLevelInner::from_str(s).ok(); +); +from_env_var!( + /// A Unix Socket to use in place of a local address + let name = Socket; + let default: Option = None; + let (env_var, allowed_values) = ("SOCKET", "a valid Unix Socket".to_string()); + let from_str = |s| match UnixListener::bind(s).ok() { + Some(socket) => Some(Some(socket)), + None => None, + }; +); +from_env_var!( + /// The time between replies sent via WebSocket + let name = WsInterval; + let default: Duration = Duration::from_millis(100); + let (env_var, allowed_values) = ("WS_FREQ", "a valid Unix Socket".to_string()); + let from_str = |s| s.parse().map(Duration::from_millis).ok(); +); +from_env_var!( + /// The time between replies sent via Server Sent Events + let name = SseInterval; + let default: Duration = Duration::from_millis(100); + let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string()); + let from_str = |s| s.parse().map(Duration::from_millis).ok(); +); +from_env_var!( + /// The port to run Flodgatt on + let name = Port; + let default: u16 = 4000; + let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535".to_string()); + let from_str = |s| s.parse().ok(); +); +/// Permissions for Cross Origin Resource Sharing (CORS) +pub struct Cors<'a> { + pub allowed_headers: Vec<&'a str>, + pub allowed_methods: Vec<&'a str>, +} +impl std::default::Default for Cors<'_> { + fn default() -> Self { + Self { + allowed_methods: vec!["GET", "OPTIONS"], + allowed_headers: vec!["Authorization", "Accept", "Cache-Control"], + } + } +} +impl fmt::Debug for Cors<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "allowed headers: {:?}\n allowed methods: {:?}", + self.allowed_headers, self.allowed_methods + ) + } +} + +#[derive(EnumString, EnumVariantNames, Debug)] +#[strum(serialize_all = "snake_case")] +pub enum LogLevelInner { + Trace, + Debug, + Info, + Warn, + Error, +} + +#[derive(EnumString, EnumVariantNames, Debug)] +#[strum(serialize_all = "snake_case")] +pub enum EnvInner { + Production, + Development, +} diff --git a/src/config/mod.rs b/src/config/mod.rs index f040b5b..cf2dfde 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,17 +1,58 @@ -//! Configuration defaults. All settings with the prefix of `DEFAULT_` can be overridden -//! by an environmental variable of the same name without that prefix (either by setting -//! the variable at runtime or in the `.env` file) mod deployment_cfg; +mod deployment_cfg_types; mod postgres_cfg; mod redis_cfg; +mod redis_cfg_types; pub use self::{ - deployment_cfg::DeploymentConfig, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig, + deployment_cfg::DeploymentConfig, + postgres_cfg::PostgresConfig, + redis_cfg::RedisConfig, + redis_cfg_types::{RedisInterval, RedisNamespace}, }; +use std::collections::HashMap; +use url::Url; -// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver` -// (on the order of 10ms rather than 50μs). Thus, changing this setting -// would be a good place to start for performance improvements at the cost -// of delaying all updates. +pub struct EnvVar(pub HashMap); +impl std::ops::Deref for EnvVar { + type Target = HashMap; + fn deref(&self) -> &HashMap { + &self.0 + } +} +impl Clone for EnvVar { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} +impl EnvVar { + fn update_with_url(mut self, url_str: &str) -> Self { + let url = Url::parse(url_str).unwrap(); + let none_if_empty = |s: String| if s.is_empty() { None } else { Some(s) }; + + self.maybe_add_env_var("REDIS_PORT", url.port()); + self.maybe_add_env_var("REDIS_PASSWORD", url.password()); + self.maybe_add_env_var("REDIS_USERNAME", none_if_empty(url.username().to_string())); + self.maybe_add_env_var("REDIS_DB", none_if_empty(url.path()[1..].to_string())); + for (k, v) in url.query_pairs().into_owned() { + match k.to_string().as_str() { + "password" => self.maybe_add_env_var("REDIS_PASSWORD", Some(v.to_string())), + "db" => self.maybe_add_env_var("REDIS_DB", Some(v.to_string())), + _ => crate::err::die_with_msg(format!( + r"Unsupported parameter {} in REDIS_URL. + Flodgatt supports only `password` and `db` parameters.", + k + )), + } + } + + self + } + fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option) { + if let Some(value) = maybe_value { + self.0.insert(key.to_string(), value.to_string()); + } + } +} #[macro_export] macro_rules! maybe_update { @@ -29,3 +70,44 @@ macro_rules! maybe_update { None => Self { ..self } } })} +#[macro_export] +macro_rules! from_env_var { + ($(#[$outer:meta])* + let name = $name:ident; + let default: $type:ty = $inner:expr; + let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr); + let from_str = |$arg:ident| $body:expr; + ) => { + pub struct $name(pub $type); + impl std::fmt::Debug for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + impl std::ops::Deref for $name { + type Target = $type; + fn deref(&self) -> &$type { + &self.0 + } + } + impl std::default::Default for $name { + fn default() -> Self { + $name($inner) + } + } + impl $name { + fn inner_from_str($arg: &str) -> Option<$type> { + $body + } + pub fn maybe_update(self, var: Option<&String>) -> Self { + if let Some(value) = var { + Self(Self::inner_from_str(value).unwrap_or_else(|| { + crate::err::env_var_fatal($env_var, value, $allowed_values) + })) + } else { + self + } + } + } + }; +} diff --git a/src/config/postgres_cfg.rs b/src/config/postgres_cfg.rs index ab2733c..532c249 100644 --- a/src/config/postgres_cfg.rs +++ b/src/config/postgres_cfg.rs @@ -1,5 +1,5 @@ +use super::EnvVar; use crate::{err, maybe_update}; -use std::collections::HashMap; use url::Url; #[derive(Debug)] @@ -30,7 +30,7 @@ fn none_if_empty(item: &str) -> Option { impl PostgresConfig { /// Configure Postgres and return a connection - pub fn from_env(env_vars: HashMap) -> Self { + pub fn from_env(env_vars: EnvVar) -> Self { // use openssl::ssl::{SslConnector, SslMethod}; // use postgres_openssl::MakeTlsConnector; // let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); diff --git a/src/config/redis_cfg.rs b/src/config/redis_cfg.rs index 1373d31..f73717f 100644 --- a/src/config/redis_cfg.rs +++ b/src/config/redis_cfg.rs @@ -1,87 +1,50 @@ -use crate::{err, maybe_update}; -use std::{collections::HashMap, time::Duration}; -use url::Url; +use super::redis_cfg_types::*; +use crate::config::EnvVar; -fn none_if_empty(item: &str) -> Option { - Some(item).filter(|i| !i.is_empty()).map(String::from) -} - -#[derive(Debug)] +#[derive(Debug, Default)] pub struct RedisConfig { - pub user: Option, - pub password: Option, - pub port: u16, - pub host: String, - pub db: Option, - pub namespace: Option, - pub polling_interval: Duration, -} -impl Default for RedisConfig { - fn default() -> Self { - Self { - user: None, - password: None, - db: None, - port: 6379, - host: "127.0.0.1".to_string(), - namespace: None, - polling_interval: Duration::from_millis(100), - } - } + pub user: RedisUser, + pub password: RedisPass, + pub port: RedisPort, + pub host: RedisHost, + pub db: RedisDb, + pub namespace: RedisNamespace, + // **NOTE**: Polling Redis is much more time consuming than polling the `Receiver` (~1ms + // compared to ~50μs). Thus, changing this setting with REDIS_POLL_INTERVAL may be a good + // place to start for performance improvements at the cost of delaying all updates. + pub polling_interval: RedisInterval, } + impl RedisConfig { - pub fn from_env(env_vars: HashMap) -> Self { - match env_vars.get("REDIS_URL") { - Some(url) => { - log::warn!("REDIS_URL env variable set. Connecting to Redis with that URL and ignoring any values set in REDIS_HOST or DB_PORT."); - Self::from_url(Url::parse(url).unwrap()) - } - None => RedisConfig::default() - .maybe_update_host(env_vars.get("REDIS_HOST").map(String::from)) - .maybe_update_port(env_vars.get("REDIS_PORT").map(|p| err::unwrap_or_die( - p.parse().ok(),"REDIS_PORT must be a number."))), - } - .maybe_update_namespace(env_vars.get("REDIS_NAMESPACE").map(String::from)) - .maybe_update_polling_interval(env_vars.get("REDIS_POLL_INTERVAL") - .map(|str| Duration::from_millis(str.parse().unwrap()))).log() - } + const USER_SET_WARNING: &'static str = + "Redis user specified, but Redis did not ask for a username. Ignoring it."; + const DB_SET_WARNING: &'static str = + r"Redis database specified, but PubSub connections do not use databases. +For similar functionality, you may wish to set a REDIS_NAMESPACE"; - fn from_url(url: Url) -> Self { - let mut password = url.password().as_ref().map(|str| str.to_string()); - let mut db = none_if_empty(&url.path()[1..]); - for (k, v) in url.query_pairs() { - match k.to_string().as_str() { - "password" => { password = Some(v.to_string());}, - "db" => { db = Some(v.to_string())}, - _ => { err::die_with_msg(format!("Unsupported parameter {} in REDIS_URL.\n Flodgatt supports only `password` and `db` parameters.", k))} - } - } - let user = none_if_empty(url.username()); - if let Some(user) = &user { - log::error!( - "Username {} provided, but Redis does not need a username. Ignoring it", - user - ); - } - RedisConfig { - user, - host: err::unwrap_or_die(url.host_str(), "Missing or invalid host in REDIS_URL") - .to_string(), - port: err::unwrap_or_die(url.port(), "Missing or invalid port in REDIS_URL"), - namespace: None, - password, - db, - polling_interval: Duration::from_millis(100), - } - } + pub fn from_env(env: EnvVar) -> Self { + let env = match env.get("REDIS_URL").map(|s| s.clone()) { + Some(url_str) => env.update_with_url(&url_str), + None => env, + }; - maybe_update!(maybe_update_host; host: String); - maybe_update!(maybe_update_port; port: u16); - maybe_update!(maybe_update_namespace; Some(namespace: String)); - maybe_update!(maybe_update_polling_interval; polling_interval: Duration); + let cfg = RedisConfig { + user: RedisUser::default().maybe_update(env.get("REDIS_USER")), + password: RedisPass::default().maybe_update(env.get("REDIS_PASSWORD")), + port: RedisPort::default().maybe_update(env.get("REDIS_PORT")), + host: RedisHost::default().maybe_update(env.get("REDIS_HOST")), + db: RedisDb::default().maybe_update(env.get("REDIS_DB")), + namespace: RedisNamespace::default().maybe_update(env.get("REDIS_NAMESPACE")), + polling_interval: RedisInterval::default().maybe_update(env.get("REDIS_POLL_INTERVAL")), + }; - fn log(self) -> Self { - log::warn!("Redis configuration:\n{:#?},", &self); - self + if cfg.db.is_some() { + log::warn!("{}", Self::DB_SET_WARNING); + } + if cfg.user.is_some() { + log::warn!("{}", Self::USER_SET_WARNING); + } + log::info!("Redis configuration:\n{:#?},", &cfg); + cfg } } diff --git a/src/config/redis_cfg_types.rs b/src/config/redis_cfg_types.rs new file mode 100644 index 0000000..0a2f3b5 --- /dev/null +++ b/src/config/redis_cfg_types.rs @@ -0,0 +1,60 @@ +use crate::from_env_var; +use std::{ + net::{IpAddr, Ipv4Addr}, + time::Duration, +}; +//use std::{fmt, net::IpAddr, os::unix::net::UnixListener, str::FromStr, time::Duration}; +//use strum_macros::{EnumString, EnumVariantNames}; + +from_env_var!( + /// The host address where Redis is running + let name = RedisHost; + let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded")); + let (env_var, allowed_values) = ("REDIS_HOST", "a valid address (e.g., 127.0.0.1)".to_string()); + let from_str = |s| match s { + "localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), + _ => s.parse().ok(), + }; +); +from_env_var!( + /// The port Redis is running on + let name = RedisPort; + let default: u16 = 6379; + let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535".to_string()); + let from_str = |s| s.parse().ok(); +); +from_env_var!( + /// How frequently to poll Redis + let name = RedisInterval; + let default: Duration = Duration::from_millis(100); + let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds".to_string()); + let from_str = |s| s.parse().map(Duration::from_millis).ok(); +); +from_env_var!( + /// The password to use for Redis + let name = RedisPass; + let default: Option = None; + let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string".to_string()); + let from_str = |s| Some(Some(s.to_string())); +); +from_env_var!( + /// An optional Redis Namespace + let name = RedisNamespace; + let default: Option = None; + let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string".to_string()); + let from_str = |s| Some(Some(s.to_string())); +); +from_env_var!( + /// A user for Redis (not supported) + let name = RedisUser; + let default: Option = None; + let (env_var, allowed_values) = ("REDIS_USER", "any string".to_string()); + let from_str = |s| Some(Some(s.to_string())); +); +from_env_var!( + /// The database to use with Redis (no current effect for PubSub connections) + let name = RedisDb; + let default: Option = None; + let (env_var, allowed_values) = ("REDIS_DB", "any string".to_string()); + let from_str = |s| Some(Some(s.to_string())); +); diff --git a/src/err.rs b/src/err.rs index 8f4e422..c9b143d 100644 --- a/src/err.rs +++ b/src/err.rs @@ -6,6 +6,17 @@ pub fn die_with_msg(msg: impl Display) -> ! { std::process::exit(1); } +pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! { + eprintln!( + r"FATAL ERROR: {var} is set to `{value}`, which is invalid. + {var} must be {allowed_vals}.", + var = env_var, + value = supplied_value, + allowed_vals = allowed_values + ); + std::process::exit(1); +} + #[macro_export] macro_rules! dbg_and_die { ($msg:expr) => { @@ -14,7 +25,7 @@ macro_rules! dbg_and_die { std::process::exit(1); }; } -pub fn unwrap_or_die(s: Option, msg: &str) -> T { +pub fn unwrap_or_die(s: Option, msg: &str) -> T { s.unwrap_or_else(|| { eprintln!("FATAL ERROR: {}", msg); std::process::exit(1) diff --git a/src/main.rs b/src/main.rs index 9cd9e4b..80d3739 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,11 +14,12 @@ fn main() { Some("development") | None => ".env", Some(_) => err::die_with_msg("Unknown ENV variable specified.\n Valid options are: `production` or `development`."), }).ok(); - let env_vars: HashMap<_, _> = dotenv::vars().collect(); + let env_vars_map: HashMap<_, _> = dotenv::vars().collect(); + let env_vars = config::EnvVar(env_vars_map); pretty_env_logger::init(); - - let cfg = config::DeploymentConfig::from_env(env_vars.clone()); let redis_cfg = config::RedisConfig::from_env(env_vars.clone()); + let cfg = config::DeploymentConfig::from_env(env_vars.clone()); + let postgres_cfg = config::PostgresConfig::from_env(env_vars.clone()); let client_agent_sse = ClientAgent::blank(redis_cfg); @@ -28,7 +29,7 @@ fn main() { warn!("Streaming server initialized and ready to accept connections"); // Server Sent Events - let sse_update_interval = cfg.ws_interval; + let sse_update_interval = *cfg.ws_interval; let sse_routes = sse::extract_user_or_reject(pg_conn.clone()) .and(warp::sse()) .map( @@ -49,7 +50,7 @@ fn main() { .recover(err::handle_errors); // WebSocket - let ws_update_interval = cfg.ws_interval; + let ws_update_interval = *cfg.ws_interval; let websocket_routes = ws::extract_user_or_reject(pg_conn.clone()) .and(warp::ws::ws2()) .map(move |user: user::User, ws: Ws2| { @@ -79,9 +80,9 @@ fn main() { .allow_methods(cfg.cors.allowed_methods) .allow_headers(cfg.cors.allowed_headers); - let server_addr = net::SocketAddr::new(cfg.address, cfg.port); + let server_addr = net::SocketAddr::new(*cfg.address, cfg.port.0); - if let Some(_socket) = cfg.unix_socket { + if let Some(_socket) = cfg.unix_socket.0.as_ref() { dbg_and_die!("Unix socket support not yet implemented"); } else { warp::serve(websocket_routes.or(sse_routes).with(cors)).run(server_addr); diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs index c399e32..7be71c2 100644 --- a/src/parse_client_request/ws.rs +++ b/src/parse_client_request/ws.rs @@ -39,6 +39,7 @@ pub fn extract_user_or_reject(pg_conn: PostgresConn) -> BoxedFilter<(User,)> { .and_then(move |q| User::from_query(q, pg_conn.clone())) .boxed() } + #[cfg(test)] mod test { use super::*; diff --git a/src/redis_to_client_stream/receiver.rs b/src/redis_to_client_stream/receiver.rs index 02b7851..25997d8 100644 --- a/src/redis_to_client_stream/receiver.rs +++ b/src/redis_to_client_stream/receiver.rs @@ -1,7 +1,11 @@ //! Receives data from Redis, sorts it by `ClientAgent`, and stores it until //! polled by the correct `ClientAgent`. Also manages sububscriptions and //! unsubscriptions to/from Redis. -use super::{config, redis_cmd, redis_stream, redis_stream::RedisConn}; +use super::{ + config::{self, RedisInterval, RedisNamespace}, + redis_cmd, redis_stream, + redis_stream::RedisConn, +}; use crate::pubsub_cmd; use futures::{Async, Poll}; use serde_json::Value; @@ -14,8 +18,8 @@ use uuid::Uuid; pub struct Receiver { pub pubsub_connection: net::TcpStream, secondary_redis_connection: net::TcpStream, - pub redis_namespace: Option, - redis_poll_interval: time::Duration, + pub redis_namespace: RedisNamespace, + redis_poll_interval: RedisInterval, redis_polled_at: time::Instant, timeline: String, manager_id: Uuid, @@ -139,7 +143,7 @@ impl futures::stream::Stream for Receiver { /// been polled lately. fn poll(&mut self) -> Poll, Self::Error> { let timeline = self.timeline.clone(); - if self.redis_polled_at.elapsed() > self.redis_poll_interval { + if self.redis_polled_at.elapsed() > *self.redis_poll_interval { redis_stream::AsyncReadableStream::poll_redis(self); self.redis_polled_at = time::Instant::now(); } diff --git a/src/redis_to_client_stream/redis_stream.rs b/src/redis_to_client_stream/redis_stream.rs index 4295d0c..f0132d8 100644 --- a/src/redis_to_client_stream/redis_stream.rs +++ b/src/redis_to_client_stream/redis_stream.rs @@ -1,5 +1,9 @@ use super::receiver::Receiver; -use crate::{config, redis_to_client_stream::redis_cmd}; +use crate::{ + config::{self, RedisInterval, RedisNamespace}, + err, + redis_to_client_stream::redis_cmd, +}; use futures::{Async, Poll}; use serde_json::Value; use std::{io::Read, io::Write, net, time}; @@ -8,46 +12,86 @@ use tokio::io::AsyncRead; pub struct RedisConn { pub primary: net::TcpStream, pub secondary: net::TcpStream, - pub namespace: Option, - pub polling_interval: time::Duration, + pub namespace: RedisNamespace, + pub polling_interval: RedisInterval, } + +fn send_password(mut conn: net::TcpStream, password: &str) -> net::TcpStream { + conn.write_all(&redis_cmd::cmd("auth", &password)).unwrap(); + let mut buffer = vec![0u8; 5]; + conn.read_exact(&mut buffer).unwrap(); + let reply = String::from_utf8(buffer.to_vec()).unwrap(); + if reply != "+OK\r\n" { + err::die_with_msg(format!( + r"Incorrect Redis password. You supplied `{}`. + Please supply correct password with REDIS_PASSWORD environmental variable.", + password, + )) + }; + conn +} + +fn set_db(mut conn: net::TcpStream, db: &str) -> net::TcpStream { + conn.write_all(&redis_cmd::cmd("SELECT", &db)).unwrap(); + conn +} + +fn send_test_ping(mut conn: net::TcpStream) -> net::TcpStream { + conn.write_all(b"PING\r\n").unwrap(); + let mut buffer = vec![0u8; 7]; + conn.read_exact(&mut buffer).unwrap(); + let reply = String::from_utf8(buffer.to_vec()).unwrap(); + match reply.as_str() { + "+PONG\r\n" => (), + "-NOAUTH" => err::die_with_msg( + r"Invalid authentication for Redis. + Redis reports that it needs a password, but you did not provide one. + You can set a password with the REDIS_PASSWORD environmental variable.", + ), + "HTTP/1." => err::die_with_msg( + r"The server at REDIS_HOST and REDIS_PORT is not a Redis server. + Please update the REDIS_HOST and/or REDIS_PORT environmental variables.", + ), + _ => err::die_with_msg(format!( + "Could not connect to Redis for unknown reason. Expected `+PONG` reply but got {}", + reply + )), + }; + conn +} + impl RedisConn { pub fn new(redis_cfg: config::RedisConfig) -> Self { - let addr = format!("{}:{}", redis_cfg.host, redis_cfg.port); - let mut pubsub_connection = - net::TcpStream::connect(addr.clone()).expect("Can connect to Redis"); - pubsub_connection - .set_read_timeout(Some(time::Duration::from_millis(10))) - .expect("Can set read timeout for Redis connection"); - pubsub_connection + let addr = net::SocketAddr::from((*redis_cfg.host, *redis_cfg.port)); + let conn_err = |e| { + err::die_with_msg(format!( + "Could not connect to Redis at {}:{}.\n Error detail: {}", + *redis_cfg.host, *redis_cfg.port, e, + )) + }; + let update_conn = |mut conn| { + if let Some(password) = redis_cfg.password.clone() { + conn = send_password(conn, &password); + } + conn = send_test_ping(conn); + conn.set_read_timeout(Some(time::Duration::from_millis(10))) + .expect("Can set read timeout for Redis connection"); + if let Some(db) = &*redis_cfg.db { + conn = set_db(conn, db); + } + conn + }; + let (primary_conn, secondary_conn) = ( + update_conn(net::TcpStream::connect(addr).unwrap_or_else(conn_err)), + update_conn(net::TcpStream::connect(addr).unwrap_or_else(conn_err)), + ); + primary_conn .set_nonblocking(true) .expect("set_nonblocking call failed"); - let mut secondary_redis_connection = - net::TcpStream::connect(addr).expect("Can connect to Redis"); - secondary_redis_connection - .set_read_timeout(Some(time::Duration::from_millis(10))) - .expect("Can set read timeout for Redis connection"); - if let Some(password) = redis_cfg.password { - pubsub_connection - .write_all(&redis_cmd::cmd("auth", &password)) - .unwrap(); - secondary_redis_connection - .write_all(&redis_cmd::cmd("auth", password)) - .unwrap(); - } - - if let Some(db) = redis_cfg.db { - pubsub_connection - .write_all(&redis_cmd::cmd("SELECT", &db)) - .unwrap(); - secondary_redis_connection - .write_all(&redis_cmd::cmd("SELECT", &db)) - .unwrap(); - } Self { - primary: pubsub_connection, - secondary: secondary_redis_connection, + primary: primary_conn, + secondary: secondary_conn, namespace: redis_cfg.namespace, polling_interval: redis_cfg.polling_interval, } @@ -72,13 +116,18 @@ impl<'a> AsyncReadableStream<'a> { if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer).unwrap() { let raw_redis_response = async_stream.as_utf8(buffer, num_bytes_read); + dbg!(&raw_redis_response); if raw_redis_response.starts_with("-NOAUTH") { - eprintln!( + err::die_with_msg( r"Invalid authentication for Redis. Do you need a password? -If so, set it with the REDIS_PASSWORD environmental variable" +If so, set it with the REDIS_PASSWORD environmental variable.", + ); + } else if raw_redis_response.starts_with("HTTP") { + err::die_with_msg( + r"The server at REDIS_HOST and REDIS_PORT is not a Redis server. +Please update the REDIS_HOST and/or REDIS_PORT environmental variables with the correct values.", ); - std::process::exit(1); } receiver.incoming_raw_msg.push_str(&raw_redis_response); @@ -89,7 +138,7 @@ If so, set it with the REDIS_PASSWORD environmental variable" }; let mut msg = RedisMsg::from_raw(&receiver.incoming_raw_msg); - let prefix_to_skip = match &receiver.redis_namespace { + let prefix_to_skip = match &*receiver.redis_namespace { Some(namespace) => format!("{}:timeline:", namespace), None => "timeline:".to_string(), };