mirror of https://github.com/mastodon/flodgatt
Significant progress on type safety
This commit is contained in:
parent
503ddfd510
commit
526e9d99cb
|
@ -1,4 +1,4 @@
|
|||
use crate::config::deployment_cfg_types::*;
|
||||
use super::deployment_cfg_types::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
|
@ -6,7 +6,7 @@ pub struct DeploymentConfig<'a> {
|
|||
pub env: Env,
|
||||
pub log_level: LogLevel,
|
||||
pub address: FlodgattAddr,
|
||||
pub port: Port2,
|
||||
pub port: Port,
|
||||
pub unix_socket: Socket,
|
||||
pub cors: Cors<'a>,
|
||||
pub sse_interval: SseInterval,
|
||||
|
@ -14,22 +14,19 @@ pub struct DeploymentConfig<'a> {
|
|||
}
|
||||
|
||||
impl DeploymentConfig<'_> {
|
||||
pub fn from_env(env_vars: HashMap<String, String>) -> Self {
|
||||
let mut res = Self::default();
|
||||
res.env = Env::from_env_var_or_die(env_vars.get("NODE_ENV"));
|
||||
res.env = Env::from_env_var_or_die(env_vars.get("RUST_ENV"));
|
||||
res.log_level = LogLevel::from_env_var_or_die(env_vars.get("RUST_LOG"));
|
||||
res.address = FlodgattAddr::from_env_var_or_die(env_vars.get("BIND"));
|
||||
res.port = Port2::from_env_var_or_die(env_vars.get("PORT"));
|
||||
res.unix_socket = Socket::from_env_var_or_die(env_vars.get("SOCKET"));
|
||||
res.sse_interval = SseInterval::from_env_var_or_die(env_vars.get("SSE_FREQ"));
|
||||
res.ws_interval = WsInterval::from_env_var_or_die(env_vars.get("WS_FREQ"));
|
||||
|
||||
res.log()
|
||||
}
|
||||
|
||||
fn log(self) -> Self {
|
||||
log::warn!("Using deployment configuration:\n {:#?}", &self);
|
||||
self
|
||||
pub fn from_env(env: HashMap<String, String>) -> Self {
|
||||
let mut cfg = Self {
|
||||
env: Env::default().maybe_update(env.get("NODE_ENV")),
|
||||
log_level: LogLevel::default().maybe_update(env.get("RUST_LOG")),
|
||||
address: FlodgattAddr::default().maybe_update(env.get("BIND")),
|
||||
port: Port::default().maybe_update(env.get("PORT")),
|
||||
unix_socket: Socket::default().maybe_update(env.get("SOCKET")),
|
||||
sse_interval: SseInterval::default().maybe_update(env.get("SSE_FREQ")),
|
||||
ws_interval: WsInterval::default().maybe_update(env.get("WS_FREQ")),
|
||||
cors: Cors::default(),
|
||||
};
|
||||
cfg.env = cfg.env.maybe_update(env.get("RUST_ENV"));
|
||||
log::info!("Using deployment configuration:\n {:#?}", &cfg);
|
||||
cfg
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,82 +1,81 @@
|
|||
use crate::from_env_var;
|
||||
use std::{fmt, net::IpAddr, os::unix::net::UnixListener, str::FromStr, time::Duration};
|
||||
use std::{
|
||||
fmt,
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
os::unix::net::UnixListener,
|
||||
str::FromStr,
|
||||
time::Duration,
|
||||
};
|
||||
use strum_macros::{EnumString, EnumVariantNames};
|
||||
|
||||
from_env_var!(/// The current environment, which controls what file to read other ENV vars from
|
||||
Env {
|
||||
inner: EnvInner::Development; EnvInner,
|
||||
env_var: "RUST_ENV",
|
||||
allowed_values: format!("one of: {:?}", EnvInner::variants()),
|
||||
}
|
||||
inner_from_str(|s| EnvInner::from_str(s).ok())
|
||||
from_env_var!(
|
||||
/// The current environment, which controls what file to read other ENV vars from
|
||||
let name = Env;
|
||||
let default: EnvInner = EnvInner::Development;
|
||||
let (env_var, allowed_values) = ("RUST_ENV", format!("one of: {:?}", EnvInner::variants()));
|
||||
let from_str = |s| EnvInner::from_str(s).ok();
|
||||
);
|
||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum EnvInner {
|
||||
Production,
|
||||
Development,
|
||||
}
|
||||
|
||||
from_env_var!(/// The address to run Flodgatt on
|
||||
FlodgattAddr {
|
||||
inner: IpAddr::V4("127.0.0.1".parse().expect("hardcoded")); IpAddr,
|
||||
env_var: "BIND",
|
||||
allowed_values: "a valid address (e.g., 127.0.0.1)".to_string(),
|
||||
}
|
||||
inner_from_str(|s| s.parse().ok()));
|
||||
from_env_var!(/// How verbosely Flodgatt should log messages
|
||||
LogLevel {
|
||||
inner: LogLevelInner::Warn; LogLevelInner,
|
||||
env_var: "RUST_LOG",
|
||||
allowed_values: format!("one of {:?}", LogLevelInner::variants()),
|
||||
}
|
||||
inner_from_str(|s| LogLevelInner::from_str(s).ok()));
|
||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum LogLevelInner {
|
||||
Trace,
|
||||
Debug,
|
||||
Info,
|
||||
Warn,
|
||||
Error,
|
||||
}
|
||||
from_env_var!(/// A Unix Socket to use in place of a local address
|
||||
Socket{
|
||||
inner: None; Option<UnixListener>,
|
||||
env_var: "SOCKET",
|
||||
allowed_values: "a valid Unix Socket".to_string(),
|
||||
}
|
||||
inner_from_str(|s| match UnixListener::bind(s).ok() {
|
||||
from_env_var!(
|
||||
/// The address to run Flodgatt on
|
||||
let name = FlodgattAddr;
|
||||
let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded"));
|
||||
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)".to_string());
|
||||
let from_str = |s| match s {
|
||||
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
||||
_ => s.parse().ok(),
|
||||
};
|
||||
);
|
||||
from_env_var!(
|
||||
/// How verbosely Flodgatt should log messages
|
||||
let name = LogLevel;
|
||||
let default: LogLevelInner = LogLevelInner::Warn;
|
||||
let (env_var, allowed_values) = ("RUST_LOG", "a valid address (e.g., 127.0.0.1)".to_string());
|
||||
let from_str = |s| LogLevelInner::from_str(s).ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// A Unix Socket to use in place of a local address
|
||||
let name = Socket;
|
||||
let default: Option<UnixListener> = None;
|
||||
let (env_var, allowed_values) = ("SOCKET", "a valid Unix Socket".to_string());
|
||||
let from_str = |s| match UnixListener::bind(s).ok() {
|
||||
Some(socket) => Some(Some(socket)),
|
||||
None => None,
|
||||
}));
|
||||
from_env_var!(/// The time between replies sent via WebSocket
|
||||
WsInterval {
|
||||
inner: Duration::from_millis(100); Duration,
|
||||
env_var: "WS_FREQ",
|
||||
allowed_values: "a number of milliseconds".to_string(),
|
||||
}
|
||||
inner_from_str(|s| s.parse().map(|num| Duration::from_millis(num)).ok()));
|
||||
from_env_var!(/// The time between replies sent via Server Sent Events
|
||||
SseInterval {
|
||||
inner: Duration::from_millis(100); Duration,
|
||||
env_var: "SSE_FREQ",
|
||||
allowed_values: "a number of milliseconds".to_string(),
|
||||
}
|
||||
inner_from_str(|s| s.parse().map(|num| Duration::from_millis(num)).ok()));
|
||||
from_env_var!(/// The port to run Flodgatt on
|
||||
Port2 {
|
||||
inner: 4000; u16,
|
||||
env_var: "PORT",
|
||||
allowed_values: "a number".to_string(),
|
||||
}
|
||||
inner_from_str(|s| s.parse().ok()));
|
||||
|
||||
};
|
||||
);
|
||||
from_env_var!(
|
||||
/// The time between replies sent via WebSocket
|
||||
let name = WsInterval;
|
||||
let default: Duration = Duration::from_millis(100);
|
||||
let (env_var, allowed_values) = ("WS_FREQ", "a valid Unix Socket".to_string());
|
||||
let from_str = |s| s.parse().map(|num| Duration::from_millis(num)).ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// The time between replies sent via Server Sent Events
|
||||
let name = SseInterval;
|
||||
let default: Duration = Duration::from_millis(100);
|
||||
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string());
|
||||
let from_str = |s| s.parse().map(|num| Duration::from_millis(num)).ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// The port to run Flodgatt on
|
||||
let name = Port;
|
||||
let default: u16 = 4000;
|
||||
let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535".to_string());
|
||||
let from_str = |s| s.parse().ok();
|
||||
);
|
||||
/// Permissions for Cross Origin Resource Sharing (CORS)
|
||||
pub struct Cors<'a> {
|
||||
pub allowed_headers: Vec<&'a str>,
|
||||
pub allowed_methods: Vec<&'a str>,
|
||||
}
|
||||
impl std::default::Default for Cors<'_> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allowed_methods: vec!["GET", "OPTIONS"],
|
||||
allowed_headers: vec!["Authorization", "Accept", "Cache-Control"],
|
||||
}
|
||||
}
|
||||
}
|
||||
impl fmt::Debug for Cors<'_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
|
@ -86,11 +85,20 @@ impl fmt::Debug for Cors<'_> {
|
|||
)
|
||||
}
|
||||
}
|
||||
impl std::default::Default for Cors<'_> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allowed_methods: vec!["GET", "OPTIONS"],
|
||||
allowed_headers: vec!["Authorization", "Accept", "Cache-Control"],
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum LogLevelInner {
|
||||
Trace,
|
||||
Debug,
|
||||
Info,
|
||||
Warn,
|
||||
Error,
|
||||
}
|
||||
|
||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
||||
#[strum(serialize_all = "snake_case")]
|
||||
pub enum EnvInner {
|
||||
Production,
|
||||
Development,
|
||||
}
|
||||
|
|
|
@ -2,8 +2,10 @@ mod deployment_cfg;
|
|||
mod deployment_cfg_types;
|
||||
mod postgres_cfg;
|
||||
mod redis_cfg;
|
||||
mod redis_cfg_types;
|
||||
pub use self::{
|
||||
deployment_cfg::DeploymentConfig, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig,
|
||||
redis_cfg_types::RedisInterval,
|
||||
};
|
||||
|
||||
#[macro_export]
|
||||
|
@ -25,59 +27,39 @@ macro_rules! maybe_update {
|
|||
#[macro_export]
|
||||
macro_rules! from_env_var {
|
||||
($(#[$outer:meta])*
|
||||
$name:ident {
|
||||
inner: $inner:expr; $type:ty,
|
||||
env_var: $env_var:tt,
|
||||
allowed_values: $allowed_values:expr,
|
||||
}
|
||||
inner_from_str(|$arg:ident| $body:expr)
|
||||
let name = $name:ident;
|
||||
let default: $type:ty = $inner:expr;
|
||||
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
|
||||
let from_str = |$arg:ident| $body:expr;
|
||||
) => {
|
||||
pub struct $name {
|
||||
pub inner: $type,
|
||||
pub env_var: String,
|
||||
pub allowed_values: String,
|
||||
}
|
||||
pub struct $name(pub $type);
|
||||
impl std::fmt::Debug for $name {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{:?}", self.inner)
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self.0)
|
||||
}
|
||||
}
|
||||
impl std::ops::Deref for $name {
|
||||
type Target = $type;
|
||||
fn deref(&self) -> &$type {
|
||||
&self.inner
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl std::default::Default for $name {
|
||||
fn default() -> Self {
|
||||
$name {
|
||||
inner: $inner,
|
||||
env_var: $env_var.to_string(),
|
||||
allowed_values: $allowed_values,
|
||||
}
|
||||
$name($inner)
|
||||
}
|
||||
}
|
||||
impl $name {
|
||||
fn inner_from_str($arg: &str) -> Option<$type> {
|
||||
$body
|
||||
}
|
||||
fn update_inner(&mut self, inner: $type) -> &Self {
|
||||
self.inner = inner;
|
||||
self
|
||||
}
|
||||
pub fn from_env_var_or_die(env: Option<&String>) -> Self {
|
||||
let mut res = Self::default();
|
||||
if let Some(value) = env {
|
||||
res.update_inner(Self::inner_from_str(value).unwrap_or_else(|| {
|
||||
eprintln!(
|
||||
"\"{}\" is not a valid value for {}. {} must be {}",
|
||||
value, res.env_var, res.env_var, res.allowed_values
|
||||
);
|
||||
std::process::exit(1);
|
||||
}));
|
||||
res
|
||||
pub fn maybe_update(self, var: Option<&String>) -> Self {
|
||||
if let Some(value) = var {
|
||||
Self(Self::inner_from_str(value).unwrap_or_else(|| {
|
||||
crate::err::env_var_fatal($env_var, value, $allowed_values)
|
||||
}))
|
||||
} else {
|
||||
res
|
||||
self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use crate::{err, maybe_update};
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
use url::Url;
|
||||
use super::redis_cfg_types::*;
|
||||
//use crate::{err, maybe_update};
|
||||
use crate::maybe_update;
|
||||
use std::collections::HashMap;
|
||||
//use url::Url;
|
||||
|
||||
fn none_if_empty(item: &str) -> Option<String> {
|
||||
Some(item).filter(|i| !i.is_empty()).map(String::from)
|
||||
|
@ -9,80 +11,51 @@ fn none_if_empty(item: &str) -> Option<String> {
|
|||
#[derive(Debug)]
|
||||
pub struct RedisConfig {
|
||||
pub user: Option<String>,
|
||||
pub password: Option<String>,
|
||||
pub port: u16,
|
||||
pub host: String,
|
||||
pub password: RedisPass,
|
||||
pub port: RedisPort,
|
||||
pub host: RedisHost,
|
||||
pub db: Option<String>,
|
||||
pub namespace: Option<String>,
|
||||
// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver`
|
||||
// (on the order of 1ms rather than 50μs). Thus, changing this setting
|
||||
// would be a good place to start for performance improvements at the cost
|
||||
// of delaying all updates.
|
||||
pub polling_interval: Duration,
|
||||
pub polling_interval: RedisInterval,
|
||||
}
|
||||
impl Default for RedisConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
user: None,
|
||||
password: None,
|
||||
password: RedisPass::default(),
|
||||
db: None,
|
||||
port: 6379,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: RedisPort::default(),
|
||||
host: RedisHost::default(),
|
||||
namespace: None,
|
||||
polling_interval: Duration::from_millis(100),
|
||||
polling_interval: RedisInterval::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RedisConfig {
|
||||
pub fn from_env(env_vars: HashMap<String, String>) -> Self {
|
||||
match env_vars.get("REDIS_URL") {
|
||||
Some(url) => {
|
||||
log::warn!("REDIS_URL env variable set. Connecting to Redis with that URL and ignoring any values set in REDIS_HOST or DB_PORT.");
|
||||
Self::from_url(Url::parse(url).unwrap())
|
||||
}
|
||||
None => RedisConfig::default()
|
||||
.maybe_update_host(env_vars.get("REDIS_HOST").map(String::from))
|
||||
.maybe_update_port(env_vars.get("REDIS_PORT").map(|p| err::unwrap_or_die(
|
||||
p.parse().ok(),"REDIS_PORT must be a number."))),
|
||||
}
|
||||
.maybe_update_namespace(env_vars.get("REDIS_NAMESPACE").map(String::from))
|
||||
.maybe_update_polling_interval(env_vars.get("REDIS_POLL_INTERVAL")
|
||||
.map(|str| Duration::from_millis(str.parse().unwrap()))).log()
|
||||
// TODO handle REDIS_URL
|
||||
|
||||
let mut cfg = RedisConfig::default();
|
||||
cfg.host = RedisHost::default().maybe_update(env_vars.get("REDIS_HOST"));
|
||||
cfg = cfg.maybe_update_namespace(env_vars.get("REDIS_NAMESPACE").map(String::from));
|
||||
|
||||
cfg.port = RedisPort::default().maybe_update(env_vars.get("REDIS_PORT"));
|
||||
cfg.polling_interval =
|
||||
RedisInterval::default().maybe_update(env_vars.get("REDIS_POLL_INTERVAL"));
|
||||
cfg.password = RedisPass::default().maybe_update(env_vars.get("REDIS_PASSWORD"));
|
||||
|
||||
cfg.log()
|
||||
}
|
||||
|
||||
fn from_url(url: Url) -> Self {
|
||||
let mut password = url.password().as_ref().map(|str| str.to_string());
|
||||
let mut db = none_if_empty(&url.path()[1..]);
|
||||
for (k, v) in url.query_pairs() {
|
||||
match k.to_string().as_str() {
|
||||
"password" => { password = Some(v.to_string());},
|
||||
"db" => { db = Some(v.to_string())},
|
||||
_ => { err::die_with_msg(format!("Unsupported parameter {} in REDIS_URL.\n Flodgatt supports only `password` and `db` parameters.", k))}
|
||||
}
|
||||
}
|
||||
let user = none_if_empty(url.username());
|
||||
if let Some(user) = &user {
|
||||
log::error!(
|
||||
"Username {} provided, but Redis does not need a username. Ignoring it",
|
||||
user
|
||||
);
|
||||
}
|
||||
RedisConfig {
|
||||
user,
|
||||
host: err::unwrap_or_die(url.host_str(), "Missing or invalid host in REDIS_URL")
|
||||
.to_string(),
|
||||
port: err::unwrap_or_die(url.port(), "Missing or invalid port in REDIS_URL"),
|
||||
namespace: None,
|
||||
password,
|
||||
db,
|
||||
polling_interval: Duration::from_millis(100),
|
||||
}
|
||||
}
|
||||
|
||||
maybe_update!(maybe_update_host; host: String);
|
||||
maybe_update!(maybe_update_port; port: u16);
|
||||
// maybe_update!(maybe_update_host; host: String);
|
||||
// maybe_update!(maybe_update_port; port: u16);
|
||||
maybe_update!(maybe_update_namespace; Some(namespace: String));
|
||||
maybe_update!(maybe_update_polling_interval; polling_interval: Duration);
|
||||
// maybe_update!(maybe_update_polling_interval; polling_interval: Duration);
|
||||
|
||||
fn log(self) -> Self {
|
||||
log::warn!("Redis configuration:\n{:#?},", &self);
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
use crate::from_env_var;
|
||||
use std::{
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
time::Duration,
|
||||
};
|
||||
//use std::{fmt, net::IpAddr, os::unix::net::UnixListener, str::FromStr, time::Duration};
|
||||
//use strum_macros::{EnumString, EnumVariantNames};
|
||||
|
||||
from_env_var!(
|
||||
/// The host address where Redis is running
|
||||
let name = RedisHost;
|
||||
let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded"));
|
||||
let (env_var, allowed_values) = ("REDIS_HOST", "a valid address (e.g., 127.0.0.1)".to_string());
|
||||
let from_str = |s| match s {
|
||||
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
|
||||
_ => s.parse().ok(),
|
||||
};
|
||||
);
|
||||
|
||||
from_env_var!(
|
||||
/// The port Redis is running on
|
||||
let name = RedisPort;
|
||||
let default: u16 = 6379;
|
||||
let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535".to_string());
|
||||
let from_str = |s| s.parse().ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// How frequently to poll Redis
|
||||
let name = RedisInterval;
|
||||
let default: Duration = Duration::from_millis(100);
|
||||
let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds".to_string());
|
||||
let from_str = |s| s.parse().map(|num| Duration::from_millis(num)).ok();
|
||||
);
|
||||
from_env_var!(
|
||||
/// The password to use for Redis
|
||||
let name = RedisPass;
|
||||
let default: Option<String> = None;
|
||||
let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string".to_string());
|
||||
let from_str = |s| Some(Some(s.to_string()));
|
||||
);
|
11
src/err.rs
11
src/err.rs
|
@ -6,6 +6,17 @@ pub fn die_with_msg(msg: impl Display) -> ! {
|
|||
std::process::exit(1);
|
||||
}
|
||||
|
||||
pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! {
|
||||
eprintln!(
|
||||
r"FATAL ERROR: {var} is set to `{value}`, which is invalid.
|
||||
{var} must be {allowed_vals}.",
|
||||
var = env_var,
|
||||
value = supplied_value,
|
||||
allowed_vals = allowed_values
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! dbg_and_die {
|
||||
($msg:expr) => {
|
||||
|
|
|
@ -79,9 +79,9 @@ fn main() {
|
|||
.allow_methods(cfg.cors.allowed_methods)
|
||||
.allow_headers(cfg.cors.allowed_headers);
|
||||
|
||||
let server_addr = net::SocketAddr::new(*cfg.address, cfg.port.inner);
|
||||
let server_addr = net::SocketAddr::new(*cfg.address, cfg.port.0);
|
||||
|
||||
if let Some(_socket) = cfg.unix_socket.inner.as_ref() {
|
||||
if let Some(_socket) = cfg.unix_socket.0.as_ref() {
|
||||
dbg_and_die!("Unix socket support not yet implemented");
|
||||
} else {
|
||||
warp::serve(websocket_routes.or(sse_routes).with(cors)).run(server_addr);
|
||||
|
|
|
@ -39,6 +39,7 @@ pub fn extract_user_or_reject(pg_conn: PostgresConn) -> BoxedFilter<(User,)> {
|
|||
.and_then(move |q| User::from_query(q, pg_conn.clone()))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
//! Receives data from Redis, sorts it by `ClientAgent`, and stores it until
|
||||
//! polled by the correct `ClientAgent`. Also manages sububscriptions and
|
||||
//! unsubscriptions to/from Redis.
|
||||
use super::{config, redis_cmd, redis_stream, redis_stream::RedisConn};
|
||||
use super::{config, config::RedisInterval, redis_cmd, redis_stream, redis_stream::RedisConn};
|
||||
use crate::pubsub_cmd;
|
||||
use futures::{Async, Poll};
|
||||
use serde_json::Value;
|
||||
|
@ -15,7 +15,7 @@ pub struct Receiver {
|
|||
pub pubsub_connection: net::TcpStream,
|
||||
secondary_redis_connection: net::TcpStream,
|
||||
pub redis_namespace: Option<String>,
|
||||
redis_poll_interval: time::Duration,
|
||||
redis_poll_interval: RedisInterval,
|
||||
redis_polled_at: time::Instant,
|
||||
timeline: String,
|
||||
manager_id: Uuid,
|
||||
|
@ -139,7 +139,7 @@ impl futures::stream::Stream for Receiver {
|
|||
/// been polled lately.
|
||||
fn poll(&mut self) -> Poll<Option<Value>, Self::Error> {
|
||||
let timeline = self.timeline.clone();
|
||||
if self.redis_polled_at.elapsed() > self.redis_poll_interval {
|
||||
if self.redis_polled_at.elapsed() > *self.redis_poll_interval {
|
||||
redis_stream::AsyncReadableStream::poll_redis(self);
|
||||
self.redis_polled_at = time::Instant::now();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::receiver::Receiver;
|
||||
use crate::{config, redis_to_client_stream::redis_cmd};
|
||||
use crate::{config, config::RedisInterval, err, redis_to_client_stream::redis_cmd};
|
||||
use futures::{Async, Poll};
|
||||
use serde_json::Value;
|
||||
use std::{io::Read, io::Write, net, time};
|
||||
|
@ -9,32 +9,77 @@ pub struct RedisConn {
|
|||
pub primary: net::TcpStream,
|
||||
pub secondary: net::TcpStream,
|
||||
pub namespace: Option<String>,
|
||||
pub polling_interval: time::Duration,
|
||||
pub polling_interval: RedisInterval,
|
||||
}
|
||||
|
||||
fn send_password(mut conn: net::TcpStream, password: &String) -> net::TcpStream {
|
||||
conn.write_all(&redis_cmd::cmd("auth", &password)).unwrap();
|
||||
let mut buffer = vec![0u8; 5];
|
||||
conn.read_exact(&mut buffer).unwrap();
|
||||
let reply = String::from_utf8(buffer.to_vec()).unwrap();
|
||||
if reply != "+OK\r\n".to_string() {
|
||||
err::die_with_msg(format!(
|
||||
r"Incorrect Redis password. You supplied `{}`.
|
||||
Please supply correct password with REDIS_PASSWORD environmental variable.",
|
||||
password,
|
||||
))
|
||||
};
|
||||
conn
|
||||
}
|
||||
|
||||
fn send_test_ping(mut conn: net::TcpStream) -> net::TcpStream {
|
||||
conn.write_all(b"PING\r\n").unwrap();
|
||||
let mut buffer = vec![0u8; 7];
|
||||
conn.read_exact(&mut buffer).unwrap();
|
||||
let reply = String::from_utf8(buffer.to_vec()).unwrap();
|
||||
match reply.as_str() {
|
||||
"+PONG\r\n" => (),
|
||||
"-NOAUTH" => err::die_with_msg(
|
||||
r"Invalid authentication for Redis.
|
||||
Redis reports that it needs a password, but you did not provide one.
|
||||
You can set a password with the REDIS_PASSWORD environmental variable.",
|
||||
),
|
||||
"HTTP/1." => err::die_with_msg(
|
||||
r"The server at REDIS_HOST and REDIS_PORT is not a Redis server.
|
||||
Please update the REDIS_HOST and/or REDIS_PORT environmental variables.",
|
||||
),
|
||||
_ => err::die_with_msg(format!(
|
||||
"Could not connect to Redis for unknown reason. Expected `+PONG` reply but got {}",
|
||||
reply
|
||||
)),
|
||||
};
|
||||
conn
|
||||
}
|
||||
|
||||
impl RedisConn {
|
||||
pub fn new(redis_cfg: config::RedisConfig) -> Self {
|
||||
let addr = format!("{}:{}", redis_cfg.host, redis_cfg.port);
|
||||
let mut pubsub_connection =
|
||||
net::TcpStream::connect(addr.clone()).expect("Can connect to Redis");
|
||||
let addr = net::SocketAddr::from((*redis_cfg.host, *redis_cfg.port));
|
||||
let conn_err = |e| {
|
||||
err::die_with_msg(format!(
|
||||
"Could not connect to Redis at {}:{}.\n Error detail: {}",
|
||||
*redis_cfg.host, *redis_cfg.port, e,
|
||||
))
|
||||
};
|
||||
let mut pubsub_connection = net::TcpStream::connect(addr).unwrap_or_else(conn_err);
|
||||
let mut secondary_redis_connection = net::TcpStream::connect(addr).unwrap_or_else(conn_err);
|
||||
|
||||
if let Some(password) = redis_cfg.password.clone() {
|
||||
pubsub_connection = send_password(pubsub_connection, &password);
|
||||
secondary_redis_connection = send_password(secondary_redis_connection, &password)
|
||||
}
|
||||
pubsub_connection = send_test_ping(pubsub_connection);
|
||||
secondary_redis_connection = send_test_ping(secondary_redis_connection);
|
||||
|
||||
pubsub_connection
|
||||
.set_read_timeout(Some(time::Duration::from_millis(10)))
|
||||
.expect("Can set read timeout for Redis connection");
|
||||
pubsub_connection
|
||||
.set_nonblocking(true)
|
||||
.expect("set_nonblocking call failed");
|
||||
let mut secondary_redis_connection =
|
||||
net::TcpStream::connect(addr).expect("Can connect to Redis");
|
||||
|
||||
secondary_redis_connection
|
||||
.set_read_timeout(Some(time::Duration::from_millis(10)))
|
||||
.expect("Can set read timeout for Redis connection");
|
||||
if let Some(password) = redis_cfg.password {
|
||||
pubsub_connection
|
||||
.write_all(&redis_cmd::cmd("auth", &password))
|
||||
.unwrap();
|
||||
secondary_redis_connection
|
||||
.write_all(&redis_cmd::cmd("auth", password))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
if let Some(db) = redis_cfg.db {
|
||||
pubsub_connection
|
||||
|
@ -72,13 +117,18 @@ impl<'a> AsyncReadableStream<'a> {
|
|||
|
||||
if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer).unwrap() {
|
||||
let raw_redis_response = async_stream.as_utf8(buffer, num_bytes_read);
|
||||
dbg!(&raw_redis_response);
|
||||
if raw_redis_response.starts_with("-NOAUTH") {
|
||||
eprintln!(
|
||||
err::die_with_msg(
|
||||
r"Invalid authentication for Redis.
|
||||
Do you need a password?
|
||||
If so, set it with the REDIS_PASSWORD environmental variable"
|
||||
If so, set it with the REDIS_PASSWORD environmental variable.",
|
||||
);
|
||||
} else if raw_redis_response.starts_with("HTTP") {
|
||||
err::die_with_msg(
|
||||
r"The server at REDIS_HOST and REDIS_PORT is not a Redis server.
|
||||
Please update the REDIS_HOST and/or REDIS_PORT environmental variables with the correct values.",
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
receiver.incoming_raw_msg.push_str(&raw_redis_response);
|
||||
|
|
Loading…
Reference in New Issue