From e8145275b58f03a55d57bbb5945393927719a3f3 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Thu, 3 Oct 2019 00:34:41 -0400 Subject: [PATCH] Config refactor (#57) * Refactor configuration * Fix bug with incorrect Host env variable * Improve logging of REDIS_NAMESPACE * Update test for Postgres configuration * Conform Redis config to Postgres changes --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/config.rs | 233 ------------------ src/config/mod.rs | 128 ++++++++++ src/config/postgres_cfg.rs | 63 +++++ src/config/redis_cfg.rs | 56 +++++ src/main.rs | 8 +- src/parse_client_request/sse.rs | 34 ++- .../user/mock_postgres.rs | 15 +- src/parse_client_request/user/mod.rs | 20 +- src/parse_client_request/user/postgres.rs | 33 ++- src/parse_client_request/ws.rs | 24 +- src/redis_to_client_stream/receiver.rs | 11 +- src/redis_to_client_stream/redis_cmd.rs | 12 +- src/redis_to_client_stream/redis_stream.rs | 56 ++++- src/rustfmt.toml | 1 + 16 files changed, 412 insertions(+), 286 deletions(-) delete mode 100644 src/config.rs create mode 100644 src/config/mod.rs create mode 100644 src/config/postgres_cfg.rs create mode 100644 src/config/redis_cfg.rs create mode 100644 src/rustfmt.toml diff --git a/Cargo.lock b/Cargo.lock index 7cb3f5e..9f61f78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -386,7 +386,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.3.4" +version = "0.3.5" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index fc0dd21..679c4e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "flodgatt" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" -version = "0.3.4" +version = "0.3.5" authors = ["Daniel Long Sockwell "] edition = "2018" diff --git a/src/config.rs b/src/config.rs deleted file mode 100644 index afff20c..0000000 --- a/src/config.rs +++ /dev/null @@ -1,233 +0,0 @@ -//! Configuration defaults. All settings with the prefix of `DEFAULT_` can be overridden -//! by an environmental variable of the same name without that prefix (either by setting -//! the variable at runtime or in the `.env` file) -use dotenv::dotenv; -use lazy_static::lazy_static; -use log::warn; -use std::{env, io::Write, net, time}; -use url::Url; - -use crate::{err, redis_to_client_stream::redis_cmd}; - -const CORS_ALLOWED_METHODS: [&str; 2] = ["GET", "OPTIONS"]; -const CORS_ALLOWED_HEADERS: [&str; 3] = ["Authorization", "Accept", "Cache-Control"]; -// Postgres -const DEFAULT_DB_HOST: &str = "localhost"; -const DEFAULT_DB_USER: &str = "postgres"; -const DEFAULT_DB_NAME: &str = "mastodon_development"; -const DEFAULT_DB_PORT: &str = "5432"; -const DEFAULT_DB_SSLMODE: &str = "prefer"; -// Redis -const DEFAULT_REDIS_HOST: &str = "127.0.0.1"; -const DEFAULT_REDIS_PORT: &str = "6379"; - -const _DEFAULT_REDIS_NAMESPACE: &str = ""; -// Deployment -const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000"; - -const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100; -const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100; -/// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver` -/// (on the order of 10ms rather than 50μs). Thus, changing this setting -/// would be a good place to start for performance improvements at the cost -/// of delaying all updates. -const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100; - -fn default(var: &str, default_var: &str) -> String { - env::var(var) - .unwrap_or_else(|_| { - warn!( - "No {} env variable set. Using default value: {}", - var, default_var - ); - default_var.to_string() - }) - .to_string() -} - -lazy_static! { - static ref POSTGRES_ADDR: String = match &env::var("DATABASE_URL") { - Ok(url) => { - warn!("DATABASE_URL env variable set. Connecting to Postgres with that URL and ignoring any values set in DB_HOST, DB_USER, DB_NAME, DB_PASS, or DB_PORT."); - url.to_string() - } - Err(_) => { - let user = &env::var("DB_USER").unwrap_or_else(|_| { - match &env::var("USER") { - Err(_) => default("DB_USER", DEFAULT_DB_USER), - Ok(user) => default("DB_USER", user) - } - }); - let host = &env::var("DB_HOST") - .unwrap_or_else(|_| default("DB_HOST", DEFAULT_DB_HOST)); - let db_name = &env::var("DB_NAME") - .unwrap_or_else(|_| default("DB_NAME", DEFAULT_DB_NAME)); - let port = &env::var("DB_PORT") - .unwrap_or_else(|_| default("DB_PORT", DEFAULT_DB_PORT)); - let ssl_mode = &env::var("DB_SSLMODE") - .unwrap_or_else(|_| default("DB_SSLMODE", DEFAULT_DB_SSLMODE)); - - - match &env::var("DB_PASS") { - Ok(password) => { - format!("postgres://{}:{}@{}:{}/{}?sslmode={}", - user, password, host, port, db_name, ssl_mode)}, - Err(_) => { - warn!("No DB_PASSWORD set. Attempting to connect to Postgres without a password. (This is correct if you are using the `ident` method.)"); - format!("postgres://{}@{}:{}/{}?sslmode={}", - user, host, port, db_name, ssl_mode) - }, - } - } - }; - static ref REDIS_ADDR: RedisConfig = match &env::var("REDIS_URL") { - Ok(url) => { - warn!(r"REDIS_URL env variable set. - Connecting to Redis with that URL and ignoring any values set in REDIS_HOST or DB_PORT."); - let url = Url::parse(url).unwrap(); - fn none_if_empty(item: &str) -> Option { - if item.is_empty() { None } else { Some(item.to_string()) } - }; - - - let user = none_if_empty(url.username()); - let mut password = url.password().as_ref().map(|str| str.to_string()); - let host = err::unwrap_or_die(url.host_str(),"Missing/invalid host in REDIS_URL"); - let port = err::unwrap_or_die(url.port(), "Missing/invalid port in REDIS_URL"); - let mut db = none_if_empty(url.path()); - let query_pairs = url.query_pairs(); - - for (key, value) in query_pairs { - match key.to_string().as_str() { - "password" => { password = Some(value.to_string());}, - "db" => { db = Some(value.to_string())} - _ => { err::die_with_msg(format!("Unsupported parameter {} in REDIS_URL.\n Flodgatt supports only `password` and `db` parameters.", key))} - } - } - RedisConfig { - user, - password, - host, - port, - db - } - } - Err(_) => { - let host = env::var("REDIS_HOST") - .unwrap_or_else(|_| default("REDIS_HOST", DEFAULT_REDIS_HOST)); - let port = env::var("REDIS_PORT") - .unwrap_or_else(|_| default("REDIS_PORT", DEFAULT_REDIS_PORT)); - RedisConfig { - user: None, - password: None, - host, - port, - db: None, - } - } - }; - pub static ref REDIS_NAMESPACE: Option = match env::var("REDIS_NAMESPACE") { - Ok(ns) => { - log::warn!("Using `{}:` as a Redis namespace.", ns); - Some(ns) - }, - _ => None - }; - - - pub static ref SERVER_ADDR: net::SocketAddr = env::var("SERVER_ADDR") - .unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned()) - .parse() - .expect("static string"); - - /// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via SSE. - pub static ref SSE_UPDATE_INTERVAL: u64 = env::var("SSE_UPDATE_INTERVAL") - .map(|s| s.parse().expect("Valid config")) - .unwrap_or(DEFAULT_SSE_UPDATE_INTERVAL); - /// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via WS. - pub static ref WS_UPDATE_INTERVAL: u64 = env::var("WS_UPDATE_INTERVAL") - .map(|s| s.parse().expect("Valid config")) - .unwrap_or(DEFAULT_WS_UPDATE_INTERVAL); - /// Interval, in ms, at which the `Receiver` polls Redis. - pub static ref REDIS_POLL_INTERVAL: u64 = env::var("REDIS_POLL_INTERVAL") - .map(|s| s.parse().expect("Valid config")) - .unwrap_or(DEFAULT_REDIS_POLL_INTERVAL); -} - -/// Configure CORS for the API server -pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors { - warp::cors() - .allow_any_origin() - .allow_methods(CORS_ALLOWED_METHODS.to_vec()) - .allow_headers(CORS_ALLOWED_HEADERS.to_vec()) -} - -/// Initialize logging and read values from `src/.env` -pub fn logging_and_env() { - dotenv().ok(); - pretty_env_logger::init(); - POSTGRES_ADDR.to_string(); -} - -/// Configure Postgres and return a connection -pub fn postgres() -> postgres::Client { - // use openssl::ssl::{SslConnector, SslMethod}; - // use postgres_openssl::MakeTlsConnector; - // let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - // builder.set_ca_file("/etc/ssl/cert.pem").unwrap(); - // let connector = MakeTlsConnector::new(builder.build()); - // TODO: add TLS support, remove `NoTls` - postgres::Client::connect(&POSTGRES_ADDR.to_string(), postgres::NoTls) - .expect("Can connect to local Postgres") -} -#[derive(Default)] -struct RedisConfig { - user: Option, - password: Option, - port: String, - host: String, - db: Option, -} -/// Configure Redis -pub fn redis_addr() -> (net::TcpStream, net::TcpStream) { - let redis = &REDIS_ADDR; - let addr = format!("{}:{}", redis.host, redis.port); - if let Some(user) = &redis.user { - log::error!( - "Username {} provided, but Redis does not need a username. Ignoring it", - user - ); - }; - let mut pubsub_connection = - net::TcpStream::connect(addr.clone()).expect("Can connect to Redis"); - pubsub_connection - .set_read_timeout(Some(time::Duration::from_millis(10))) - .expect("Can set read timeout for Redis connection"); - pubsub_connection - .set_nonblocking(true) - .expect("set_nonblocking call failed"); - let mut secondary_redis_connection = - net::TcpStream::connect(addr).expect("Can connect to Redis"); - secondary_redis_connection - .set_read_timeout(Some(time::Duration::from_millis(10))) - .expect("Can set read timeout for Redis connection"); - if let Some(password) = &REDIS_ADDR.password { - pubsub_connection - .write_all(&redis_cmd::cmd("auth", &password)) - .unwrap(); - secondary_redis_connection - .write_all(&redis_cmd::cmd("auth", password)) - .unwrap(); - } else { - warn!("No REDIS_PASSWORD set. Attempting to connect to Redis without a password. (This is correct if you are following the default setup.)"); - } - if let Some(db) = &REDIS_ADDR.db { - pubsub_connection - .write_all(&redis_cmd::cmd("SELECT", &db)) - .unwrap(); - secondary_redis_connection - .write_all(&redis_cmd::cmd("SELECT", &db)) - .unwrap(); - } - (pubsub_connection, secondary_redis_connection) -} diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..0fba3a9 --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,128 @@ +//! Configuration defaults. All settings with the prefix of `DEFAULT_` can be overridden +//! by an environmental variable of the same name without that prefix (either by setting +//! the variable at runtime or in the `.env` file) +mod postgres_cfg; +mod redis_cfg; +pub use self::{postgres_cfg::PostgresConfig, redis_cfg::RedisConfig}; +use dotenv::dotenv; +use lazy_static::lazy_static; +use log::warn; +use std::{env, net}; +use url::Url; + +const CORS_ALLOWED_METHODS: [&str; 2] = ["GET", "OPTIONS"]; +const CORS_ALLOWED_HEADERS: [&str; 3] = ["Authorization", "Accept", "Cache-Control"]; +// Postgres +// Deployment +const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000"; + +const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100; +const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100; +/// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver` +/// (on the order of 10ms rather than 50μs). Thus, changing this setting +/// would be a good place to start for performance improvements at the cost +/// of delaying all updates. +const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100; + +lazy_static! { + pub static ref SERVER_ADDR: net::SocketAddr = env::var("SERVER_ADDR") + .unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned()) + .parse() + .expect("static string"); + + /// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via SSE. + pub static ref SSE_UPDATE_INTERVAL: u64 = env::var("SSE_UPDATE_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(DEFAULT_SSE_UPDATE_INTERVAL); + /// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via WS. + pub static ref WS_UPDATE_INTERVAL: u64 = env::var("WS_UPDATE_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(DEFAULT_WS_UPDATE_INTERVAL); + /// Interval, in ms, at which the `Receiver` polls Redis. + pub static ref REDIS_POLL_INTERVAL: u64 = env::var("REDIS_POLL_INTERVAL") + .map(|s| s.parse().expect("Valid config")) + .unwrap_or(DEFAULT_REDIS_POLL_INTERVAL); +} + +/// Configure CORS for the API server +pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors { + warp::cors() + .allow_any_origin() + .allow_methods(CORS_ALLOWED_METHODS.to_vec()) + .allow_headers(CORS_ALLOWED_HEADERS.to_vec()) +} + +/// Initialize logging and read values from `src/.env` +pub fn logging_and_env() { + dotenv().ok(); + pretty_env_logger::init(); +} + +/// Configure Postgres and return a connection +pub fn postgres() -> PostgresConfig { + // use openssl::ssl::{SslConnector, SslMethod}; + // use postgres_openssl::MakeTlsConnector; + // let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + // builder.set_ca_file("/etc/ssl/cert.pem").unwrap(); + // let connector = MakeTlsConnector::new(builder.build()); + // TODO: add TLS support, remove `NoTls` + let pg_cfg = match &env::var("DATABASE_URL").ok() { + Some(url) => { + warn!("DATABASE_URL env variable set. Connecting to Postgres with that URL and ignoring any values set in DB_HOST, DB_USER, DB_NAME, DB_PASS, or DB_PORT."); + PostgresConfig::from_url(Url::parse(url).unwrap()) + } + None => PostgresConfig::default() + .maybe_update_user(env::var("USER").ok()) + .maybe_update_user(env::var("DB_USER").ok()) + .maybe_update_host(env::var("DB_HOST").ok()) + .maybe_update_password(env::var("DB_PASS").ok()) + .maybe_update_db(env::var("DB_NAME").ok()) + .maybe_update_sslmode(env::var("DB_SSLMODE").ok()), + }; + log::warn!( + "Connecting to Postgres with the following configuration:\n{:#?}", + &pg_cfg + ); + pg_cfg +} + +/// Configure Redis and return a pair of connections +pub fn redis() -> RedisConfig { + let redis_cfg = match &env::var("REDIS_URL") { + Ok(url) => { + warn!("REDIS_URL env variable set. Connecting to Redis with that URL and ignoring any values set in REDIS_HOST or DB_PORT."); + RedisConfig::from_url(Url::parse(url).unwrap()) + } + Err(_) => RedisConfig::default() + .maybe_update_host(env::var("REDIS_HOST").ok()) + .maybe_update_port(env::var("REDIS_PORT").ok()), + }.maybe_update_namespace(env::var("REDIS_NAMESPACE").ok()); + if let Some(user) = &redis_cfg.user { + log::error!( + "Username {} provided, but Redis does not need a username. Ignoring it", + user + ); + }; + log::warn!( + "Connecting to Redis with the following configuration:\n{:#?}", + &redis_cfg + ); + redis_cfg +} + +#[macro_export] +macro_rules! maybe_update { + ($name:ident; $item: tt) => ( + pub fn $name(self, item: Option) -> Self{ + match item { + Some($item) => Self{ $item, ..self }, + None => Self { ..self } + } + }); + ($name:ident; Some($item: tt)) => ( + pub fn $name(self, item: Option) -> Self{ + match item { + Some($item) => Self{ $item: Some($item), ..self }, + None => Self { ..self } + } + })} diff --git a/src/config/postgres_cfg.rs b/src/config/postgres_cfg.rs new file mode 100644 index 0000000..613635a --- /dev/null +++ b/src/config/postgres_cfg.rs @@ -0,0 +1,63 @@ +use crate::{err, maybe_update}; +use url::Url; + +#[derive(Debug)] +pub struct PostgresConfig { + pub user: String, + pub host: String, + pub password: Option, + pub database: String, + pub port: String, + pub ssl_mode: String, +} + +impl Default for PostgresConfig { + fn default() -> Self { + Self { + user: "postgres".to_string(), + host: "localhost".to_string(), + password: None, + database: "mastodon_development".to_string(), + port: "5432".to_string(), + ssl_mode: "prefer".to_string(), + } + } +} +fn none_if_empty(item: &str) -> Option { + Some(item).filter(|i| !i.is_empty()).map(String::from) +} + +impl PostgresConfig { + maybe_update!(maybe_update_user; user); + maybe_update!(maybe_update_host; host); + maybe_update!(maybe_update_db; database); + maybe_update!(maybe_update_port; port); + maybe_update!(maybe_update_sslmode; ssl_mode); + maybe_update!(maybe_update_password; Some(password)); + + pub fn from_url(url: Url) -> Self { + let (mut user, mut host, mut sslmode, mut password) = (None, None, None, None); + for (k, v) in url.query_pairs() { + match k.to_string().as_str() { + "user" => { user = Some(v.to_string());}, + "password" => { password = Some(v.to_string());}, + "host" => { host = Some(v.to_string());}, + "sslmode" => { sslmode = Some(v.to_string());}, + _ => { err::die_with_msg(format!("Unsupported parameter {} in DATABASE_URL.\n Flodgatt supports only `user`, `password`, `host`, and `sslmode` parameters.", k))} + } + } + + Self::default() + // Values from query parameter + .maybe_update_user(user) + .maybe_update_password(password) + .maybe_update_host(host) + .maybe_update_sslmode(sslmode) + // Values from URL (which override query values if both are present) + .maybe_update_user(none_if_empty(url.username())) + .maybe_update_host(url.host_str().filter(|h| !h.is_empty()).map(String::from)) + .maybe_update_password(url.password().map(String::from)) + .maybe_update_port(url.port().map(|port_num| port_num.to_string())) + .maybe_update_db(none_if_empty(&url.path()[1..])) + } +} diff --git a/src/config/redis_cfg.rs b/src/config/redis_cfg.rs new file mode 100644 index 0000000..019e92a --- /dev/null +++ b/src/config/redis_cfg.rs @@ -0,0 +1,56 @@ +use crate::{err, maybe_update}; +use url::Url; + +fn none_if_empty(item: &str) -> Option { + if item.is_empty() { + None + } else { + Some(item.to_string()) + } +} + +#[derive(Debug)] +pub struct RedisConfig { + pub user: Option, + pub password: Option, + pub port: String, + pub host: String, + pub db: Option, + pub namespace: Option, +} +impl Default for RedisConfig { + fn default() -> Self { + Self { + user: None, + password: None, + db: None, + port: "6379".to_string(), + host: "127.0.0.1".to_string(), + namespace: None, + } + } +} +impl RedisConfig { + pub fn from_url(url: Url) -> Self { + let mut password = url.password().as_ref().map(|str| str.to_string()); + let mut db = none_if_empty(&url.path()[1..]); + for (k, v) in url.query_pairs() { + match k.to_string().as_str() { + "password" => { password = Some(v.to_string());}, + "db" => { db = Some(v.to_string())}, + _ => { err::die_with_msg(format!("Unsupported parameter {} in REDIS_URL.\n Flodgatt supports only `password` and `db` parameters.", k))} + } + } + RedisConfig { + user: none_if_empty(url.username()), + host: err::unwrap_or_die(url.host_str(), "Missing or invalid host in REDIS_URL"), + port: err::unwrap_or_die(url.port(), "Missing or invalid port in REDIS_URL"), + namespace: None, + password, + db, + } + } + maybe_update!(maybe_update_host; host); + maybe_update!(maybe_update_port; port); + maybe_update!(maybe_update_namespace; Some(namespace)); +} diff --git a/src/main.rs b/src/main.rs index 0c6d242..ea1bb94 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,7 @@ use flodgatt::{ config, err, parse_client_request::{sse, user, ws}, - redis_to_client_stream, - redis_to_client_stream::ClientAgent, + redis_to_client_stream::{self, ClientAgent}, }; use log::warn; use warp::{ws::Ws2, Filter as WarpFilter}; @@ -11,11 +10,12 @@ fn main() { config::logging_and_env(); let client_agent_sse = ClientAgent::blank(); let client_agent_ws = client_agent_sse.clone_with_shared_receiver(); + let pg_conn = user::PostgresConn::new(); warn!("Streaming server initialized and ready to accept connections"); // Server Sent Events - let sse_routes = sse::extract_user_or_reject() + let sse_routes = sse::extract_user_or_reject(pg_conn.clone()) .and(warp::sse()) .map( move |user: user::User, sse_connection_to_client: warp::sse::Sse| { @@ -31,7 +31,7 @@ fn main() { .recover(err::handle_errors); // WebSocket - let websocket_routes = ws::extract_user_or_reject() + let websocket_routes = ws::extract_user_or_reject(pg_conn.clone()) .and(warp::ws::ws2()) .map(move |user: user::User, ws: Ws2| { let token = user.access_token.clone(); diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse.rs index f4fef54..3f47187 100644 --- a/src/parse_client_request/sse.rs +++ b/src/parse_client_request/sse.rs @@ -1,7 +1,9 @@ //! Filters for all the endpoints accessible for Server Sent Event updates -use super::{query, query::Query, user::User}; +use super::{ + query::{self, Query}, + user::{PostgresConn, User}, +}; use warp::{filters::BoxedFilter, path, Filter}; - #[allow(dead_code)] type TimelineUser = ((String, User),); @@ -37,7 +39,7 @@ macro_rules! parse_query { .boxed() }; } -pub fn extract_user_or_reject() -> BoxedFilter<(User,)> { +pub fn extract_user_or_reject(pg_conn: PostgresConn) -> BoxedFilter<(User,)> { any_of!( parse_query!( path => "api" / "v1" / "streaming" / "user" / "notification" @@ -65,14 +67,14 @@ pub fn extract_user_or_reject() -> BoxedFilter<(User,)> { // parameter, we need to update our Query if the header has a token .and(query::OptionalAccessToken::from_sse_header()) .and_then(Query::update_access_token) - .and_then(User::from_query) + .and_then(move |q| User::from_query(q, pg_conn.clone())) .boxed() } #[cfg(test)] mod test { use super::*; - use crate::parse_client_request::user::{Filter, OauthScope}; + use crate::parse_client_request::user::{Filter, OauthScope, PostgresConn}; macro_rules! test_public_endpoint { ($name:ident { @@ -81,9 +83,10 @@ mod test { }) => { #[test] fn $name() { + let pg_conn = PostgresConn::new(); let user = warp::test::request() .path($path) - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); assert_eq!(user, $user); } @@ -98,16 +101,17 @@ mod test { #[test] fn $name() { let path = format!("{}?access_token=TEST_USER", $path); + let pg_conn = PostgresConn::new(); $(let path = format!("{}&{}", path, $query);)* let user = warp::test::request() .path(&path) - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn.clone())) .expect("in test"); assert_eq!(user, $user); let user = warp::test::request() .path(&path) .header("Authorization", "Bearer: TEST_USER") - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); assert_eq!(user, $user); } @@ -123,9 +127,10 @@ mod test { fn $name() { let path = format!("{}?access_token=INVALID", $path); $(let path = format!("{}&{}", path, $query);)* + let pg_conn = PostgresConn::new(); warp::test::request() .path(&path) - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); } }; @@ -140,10 +145,12 @@ mod test { fn $name() { let path = $path; $(let path = format!("{}?{}", path, $query);)* + + let pg_conn = PostgresConn::new(); warp::test::request() .path(&path) .header("Authorization", "Bearer: INVALID") - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); } }; @@ -158,9 +165,10 @@ mod test { fn $name() { let path = $path; $(let path = format!("{}?{}", path, $query);)* + let pg_conn = PostgresConn::new(); warp::test::request() .path(&path) - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); } }; @@ -429,10 +437,10 @@ mod test { #[test] #[should_panic(expected = "NotFound")] fn nonexistant_endpoint() { + let pg_conn = PostgresConn::new(); warp::test::request() .path("/api/v1/streaming/DOES_NOT_EXIST") - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); } - } diff --git a/src/parse_client_request/user/mock_postgres.rs b/src/parse_client_request/user/mock_postgres.rs index bf60d0e..df778d6 100644 --- a/src/parse_client_request/user/mock_postgres.rs +++ b/src/parse_client_request/user/mock_postgres.rs @@ -1,6 +1,17 @@ //! Mock Postgres connection (for use in unit testing) +use std::sync::{Arc, Mutex}; -pub fn query_for_user_data(access_token: &str) -> (i64, Option>, Vec) { +#[derive(Clone)] +pub struct PostgresConn(Arc>); +impl PostgresConn { + pub fn new() -> Self { + Self(Arc::new(Mutex::new("MOCK".to_string()))) + } +} +pub fn query_for_user_data( + access_token: &str, + _pg_conn: PostgresConn, +) -> (i64, Option>, Vec) { let (user_id, lang, scopes) = if access_token == "TEST_USER" { ( 1, @@ -17,7 +28,7 @@ pub fn query_for_user_data(access_token: &str) -> (i64, Option>, Vec (user_id, lang, scopes) } -pub fn query_list_owner(list_id: i64) -> Option { +pub fn query_list_owner(list_id: i64, _pg_conn: PostgresConn) -> Option { match list_id { 1 => Some(1), _ => None, diff --git a/src/parse_client_request/user/mod.rs b/src/parse_client_request/user/mod.rs index 53d3c80..5fd62a3 100644 --- a/src/parse_client_request/user/mod.rs +++ b/src/parse_client_request/user/mod.rs @@ -5,6 +5,7 @@ mod mock_postgres; use mock_postgres as postgres; #[cfg(not(test))] mod postgres; +pub use self::postgres::PostgresConn; use super::query::Query; use warp::reject::Rejection; @@ -57,7 +58,7 @@ impl From> for OauthScope { } impl User { - pub fn from_query(q: Query) -> Result { + pub fn from_query(q: Query, pg_conn: PostgresConn) -> Result { let (id, access_token, scopes, langs, logged_in) = match q.access_token.clone() { None => ( -1, @@ -67,7 +68,8 @@ impl User { false, ), Some(token) => { - let (id, langs, scope_list) = postgres::query_for_user_data(&token); + let (id, langs, scope_list) = + postgres::query_for_user_data(&token, pg_conn.clone()); if id == -1 { return Err(warp::reject::custom("Error: Invalid access token")); } @@ -85,12 +87,16 @@ impl User { filter: Filter::Language, }; - user = user.update_timeline_and_filter(q)?; + user = user.update_timeline_and_filter(q, pg_conn.clone())?; Ok(user) } - fn update_timeline_and_filter(mut self, q: Query) -> Result { + fn update_timeline_and_filter( + mut self, + q: Query, + pg_conn: PostgresConn, + ) -> Result { let read_scope = self.scopes.clone(); let timeline = match q.stream.as_ref() { @@ -110,7 +116,7 @@ impl User { format!("{}", self.id) } // List endpoint: - "list" if self.owns_list(q.list) && (read_scope.all || read_scope.lists) => { + "list" if self.owns_list(q.list, pg_conn) && (read_scope.all || read_scope.lists) => { self.filter = Filter::NoFilter; format!("list:{}", q.list) } @@ -133,8 +139,8 @@ impl User { } /// Determine whether the User is authorised for a specified list - pub fn owns_list(&self, list: i64) -> bool { - match postgres::query_list_owner(list) { + pub fn owns_list(&self, list: i64, pg_conn: PostgresConn) -> bool { + match postgres::query_list_owner(list, pg_conn) { Some(i) if i == self.id => true, _ => false, } diff --git a/src/parse_client_request/user/postgres.rs b/src/parse_client_request/user/postgres.rs index 1d8b90e..ca7935c 100644 --- a/src/parse_client_request/user/postgres.rs +++ b/src/parse_client_request/user/postgres.rs @@ -1,9 +1,34 @@ //! Postgres queries use crate::config; +use ::postgres; +use std::sync::{Arc, Mutex}; + +#[derive(Clone)] +pub struct PostgresConn(pub Arc>); +impl PostgresConn { + pub fn new() -> Self { + let pg_cfg = config::postgres(); + let mut con = postgres::Client::configure(); + con.user(&pg_cfg.user) + .host(&pg_cfg.host) + .port(pg_cfg.port.parse::().unwrap()) + .dbname(&pg_cfg.database); + if let Some(password) = &pg_cfg.password { + con.password(password); + }; + Self(Arc::new(Mutex::new( + con.connect(postgres::NoTls) + .expect("Can connect to local Postgres"), + ))) + } +} #[cfg(not(test))] -pub fn query_for_user_data(access_token: &str) -> (i64, Option>, Vec) { - let mut conn = config::postgres(); +pub fn query_for_user_data( + access_token: &str, + pg_conn: PostgresConn, +) -> (i64, Option>, Vec) { + let mut conn = pg_conn.0.lock().unwrap(); let query_result = conn .query( @@ -53,8 +78,8 @@ pub fn query_for_user_data(access_token: &str) -> (i64, Option>, Vec } #[cfg(not(test))] -pub fn query_list_owner(list_id: i64) -> Option { - let mut conn = config::postgres(); +pub fn query_list_owner(list_id: i64, pg_conn: PostgresConn) -> Option { + let mut conn = pg_conn.0.lock().unwrap(); // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn .query( diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs index c50e833..c399e32 100644 --- a/src/parse_client_request/ws.rs +++ b/src/parse_client_request/ws.rs @@ -1,5 +1,8 @@ //! Filters for the WebSocket endpoint -use super::{query, query::Query, user::User}; +use super::{ + query::{self, Query}, + user::{PostgresConn, User}, +}; use warp::{filters::BoxedFilter, path, Filter}; /// WebSocket filters @@ -29,11 +32,11 @@ fn parse_query() -> BoxedFilter<(Query,)> { .boxed() } -pub fn extract_user_or_reject() -> BoxedFilter<(User,)> { +pub fn extract_user_or_reject(pg_conn: PostgresConn) -> BoxedFilter<(User,)> { parse_query() .and(query::OptionalAccessToken::from_ws_header()) .and_then(Query::update_access_token) - .and_then(User::from_query) + .and_then(move |q| User::from_query(q, pg_conn.clone())) .boxed() } #[cfg(test)] @@ -48,13 +51,14 @@ mod test { }) => { #[test] fn $name() { + let pg_conn = PostgresConn::new(); let user = warp::test::request() .path($path) .header("connection", "upgrade") .header("upgrade", "websocket") .header("sec-websocket-version", "13") .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==") - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); assert_eq!(user, $user); } @@ -67,6 +71,7 @@ mod test { }) => { #[test] fn $name() { + let pg_conn = PostgresConn::new(); let path = format!("{}&access_token=TEST_USER", $path); let user = warp::test::request() .path(&path) @@ -74,7 +79,7 @@ mod test { .header("upgrade", "websocket") .header("sec-websocket-version", "13") .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==") - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); assert_eq!(user, $user); } @@ -90,9 +95,10 @@ mod test { fn $name() { let path = format!("{}&access_token=INVALID", $path); + let pg_conn = PostgresConn::new(); warp::test::request() .path(&path) - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); } }; @@ -105,9 +111,10 @@ mod test { #[should_panic(expected = "Error: Missing access token")] fn $name() { let path = $path; + let pg_conn = PostgresConn::new(); warp::test::request() .path(&path) - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); } }; @@ -308,13 +315,14 @@ mod test { #[test] #[should_panic(expected = "NotFound")] fn nonexistant_endpoint() { + let pg_conn = PostgresConn::new(); warp::test::request() .path("/api/v1/streaming/DOES_NOT_EXIST") .header("connection", "upgrade") .header("upgrade", "websocket") .header("sec-websocket-version", "13") .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==") - .filter(&extract_user_or_reject()) + .filter(&extract_user_or_reject(pg_conn)) .expect("in test"); } } diff --git a/src/redis_to_client_stream/receiver.rs b/src/redis_to_client_stream/receiver.rs index 3513948..f308cf7 100644 --- a/src/redis_to_client_stream/receiver.rs +++ b/src/redis_to_client_stream/receiver.rs @@ -1,7 +1,7 @@ //! Receives data from Redis, sorts it by `ClientAgent`, and stores it until //! polled by the correct `ClientAgent`. Also manages sububscriptions and //! unsubscriptions to/from Redis. -use super::{redis_cmd, redis_stream}; +use super::{redis_cmd, redis_stream, redis_stream::RedisConn}; use crate::{config, pubsub_cmd}; use futures::{Async, Poll}; use serde_json::Value; @@ -14,6 +14,7 @@ use uuid::Uuid; pub struct Receiver { pub pubsub_connection: net::TcpStream, secondary_redis_connection: net::TcpStream, + pub redis_namespace: Option, redis_polled_at: time::Instant, timeline: String, manager_id: Uuid, @@ -26,10 +27,16 @@ impl Receiver { /// Create a new `Receiver`, with its own Redis connections (but, as yet, no /// active subscriptions). pub fn new() -> Self { - let (pubsub_connection, secondary_redis_connection) = config::redis_addr(); + let RedisConn { + primary: pubsub_connection, + secondary: secondary_redis_connection, + namespace: redis_namespace, + } = RedisConn::new(); + Self { pubsub_connection, secondary_redis_connection, + redis_namespace, redis_polled_at: time::Instant::now(), timeline: String::new(), manager_id: Uuid::default(), diff --git a/src/redis_to_client_stream/redis_cmd.rs b/src/redis_to_client_stream/redis_cmd.rs index 8b9eb0c..6a9c07a 100644 --- a/src/redis_to_client_stream/redis_cmd.rs +++ b/src/redis_to_client_stream/redis_cmd.rs @@ -1,5 +1,4 @@ //! Send raw TCP commands to the Redis server -use crate::config; use std::fmt::Display; /// Send a subscribe or unsubscribe to the Redis PubSub channel @@ -10,7 +9,7 @@ macro_rules! pubsub_cmd { log::info!("Sending {} command to {}", $cmd, $tl); $self .pubsub_connection - .write_all(&redis_cmd::pubsub($cmd, $tl)) + .write_all(&redis_cmd::pubsub($cmd, $tl, $self.redis_namespace.clone())) .expect("Can send command to Redis"); // Because we keep track of the number of clients subscribed to a channel on our end, // we need to manually tell Redis when we have subscribed or unsubscribed @@ -24,6 +23,7 @@ macro_rules! pubsub_cmd { .write_all(&redis_cmd::set( format!("subscribed:timeline:{}", $tl), subscription_new_number, + $self.redis_namespace.clone(), )) .expect("Can set Redis"); @@ -31,8 +31,8 @@ macro_rules! pubsub_cmd { }}; } /// Send a `SUBSCRIBE` or `UNSUBSCRIBE` command to a specific timeline -pub fn pubsub(command: impl Display, timeline: impl Display) -> Vec { - let arg = match &*config::REDIS_NAMESPACE { +pub fn pubsub(command: impl Display, timeline: impl Display, ns: Option) -> Vec { + let arg = match ns { Some(namespace) => format!("{}:timeline:{}", namespace, timeline), None => format!("timeline:{}", timeline), }; @@ -55,8 +55,8 @@ pub fn cmd(command: impl Display, arg: impl Display) -> Vec { } /// Send a `SET` command (used to manually unsubscribe from Redis) -pub fn set(key: impl Display, value: impl Display) -> Vec { - let key = match &*config::REDIS_NAMESPACE { +pub fn set(key: impl Display, value: impl Display, ns: Option) -> Vec { + let key = match ns { Some(namespace) => format!("{}:{}", namespace, key), None => key.to_string(), }; diff --git a/src/redis_to_client_stream/redis_stream.rs b/src/redis_to_client_stream/redis_stream.rs index 3e4cf33..1173b14 100644 --- a/src/redis_to_client_stream/redis_stream.rs +++ b/src/redis_to_client_stream/redis_stream.rs @@ -1,11 +1,58 @@ use super::receiver::Receiver; -use crate::config; +use crate::{config, redis_to_client_stream::redis_cmd}; use futures::{Async, Poll}; use serde_json::Value; -use std::io::Read; -use std::net; +use std::{io::Read, io::Write, net, time}; use tokio::io::AsyncRead; +pub struct RedisConn { + pub primary: net::TcpStream, + pub secondary: net::TcpStream, + pub namespace: Option, +} +impl RedisConn { + pub fn new() -> Self { + let redis_cfg = config::redis(); + let addr = format!("{}:{}", redis_cfg.host, redis_cfg.port); + let mut pubsub_connection = + net::TcpStream::connect(addr.clone()).expect("Can connect to Redis"); + pubsub_connection + .set_read_timeout(Some(time::Duration::from_millis(10))) + .expect("Can set read timeout for Redis connection"); + pubsub_connection + .set_nonblocking(true) + .expect("set_nonblocking call failed"); + let mut secondary_redis_connection = + net::TcpStream::connect(addr).expect("Can connect to Redis"); + secondary_redis_connection + .set_read_timeout(Some(time::Duration::from_millis(10))) + .expect("Can set read timeout for Redis connection"); + if let Some(password) = redis_cfg.password { + pubsub_connection + .write_all(&redis_cmd::cmd("auth", &password)) + .unwrap(); + secondary_redis_connection + .write_all(&redis_cmd::cmd("auth", password)) + .unwrap(); + } + + if let Some(db) = redis_cfg.db { + pubsub_connection + .write_all(&redis_cmd::cmd("SELECT", &db)) + .unwrap(); + secondary_redis_connection + .write_all(&redis_cmd::cmd("SELECT", &db)) + .unwrap(); + } + + Self { + primary: pubsub_connection, + secondary: secondary_redis_connection, + namespace: redis_cfg.namespace, + } + } +} + pub struct AsyncReadableStream<'a>(&'a mut net::TcpStream); impl<'a> AsyncReadableStream<'a> { @@ -41,7 +88,7 @@ If so, set it with the REDIS_PASSWORD environmental variable" }; let mut msg = RedisMsg::from_raw(&receiver.incoming_raw_msg); - let prefix_to_skip = match &*config::REDIS_NAMESPACE { + let prefix_to_skip = match &receiver.redis_namespace { Some(namespace) => format!("{}:timeline:", namespace), None => "timeline:".to_string(), }; @@ -56,7 +103,6 @@ If so, set it with the REDIS_PASSWORD environmental variable" Ok(v) => v, Err(e) => panic!("Unparseable json {}\n\n{}", msg_txt, e), }; - dbg!(&timeline); for msg_queue in receiver.msg_queues.values_mut() { if msg_queue.redis_channel == timeline { msg_queue.messages.push_back(msg_value.clone()); diff --git a/src/rustfmt.toml b/src/rustfmt.toml new file mode 100644 index 0000000..c51666e --- /dev/null +++ b/src/rustfmt.toml @@ -0,0 +1 @@ +edition = "2018" \ No newline at end of file