Config refactor (#57)

* Refactor configuration

* Fix bug with incorrect Host env variable

* Improve logging of REDIS_NAMESPACE

* Update test for Postgres configuration

* Conform Redis config to Postgres changes
This commit is contained in:
Daniel Sockwell 2019-10-03 00:34:41 -04:00 committed by GitHub
parent 11661d2fdc
commit e8145275b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 412 additions and 286 deletions

2
Cargo.lock generated
View File

@ -386,7 +386,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "flodgatt"
version = "0.3.4"
version = "0.3.5"
dependencies = [
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@ -1,7 +1,7 @@
[package]
name = "flodgatt"
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
version = "0.3.4"
version = "0.3.5"
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
edition = "2018"

View File

@ -1,233 +0,0 @@
//! Configuration defaults. All settings with the prefix of `DEFAULT_` can be overridden
//! by an environmental variable of the same name without that prefix (either by setting
//! the variable at runtime or in the `.env` file)
use dotenv::dotenv;
use lazy_static::lazy_static;
use log::warn;
use std::{env, io::Write, net, time};
use url::Url;
use crate::{err, redis_to_client_stream::redis_cmd};
const CORS_ALLOWED_METHODS: [&str; 2] = ["GET", "OPTIONS"];
const CORS_ALLOWED_HEADERS: [&str; 3] = ["Authorization", "Accept", "Cache-Control"];
// Postgres
const DEFAULT_DB_HOST: &str = "localhost";
const DEFAULT_DB_USER: &str = "postgres";
const DEFAULT_DB_NAME: &str = "mastodon_development";
const DEFAULT_DB_PORT: &str = "5432";
const DEFAULT_DB_SSLMODE: &str = "prefer";
// Redis
const DEFAULT_REDIS_HOST: &str = "127.0.0.1";
const DEFAULT_REDIS_PORT: &str = "6379";
const _DEFAULT_REDIS_NAMESPACE: &str = "";
// Deployment
const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000";
const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100;
const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100;
/// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver`
/// (on the order of 10ms rather than 50μs). Thus, changing this setting
/// would be a good place to start for performance improvements at the cost
/// of delaying all updates.
const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100;
fn default(var: &str, default_var: &str) -> String {
env::var(var)
.unwrap_or_else(|_| {
warn!(
"No {} env variable set. Using default value: {}",
var, default_var
);
default_var.to_string()
})
.to_string()
}
lazy_static! {
static ref POSTGRES_ADDR: String = match &env::var("DATABASE_URL") {
Ok(url) => {
warn!("DATABASE_URL env variable set. Connecting to Postgres with that URL and ignoring any values set in DB_HOST, DB_USER, DB_NAME, DB_PASS, or DB_PORT.");
url.to_string()
}
Err(_) => {
let user = &env::var("DB_USER").unwrap_or_else(|_| {
match &env::var("USER") {
Err(_) => default("DB_USER", DEFAULT_DB_USER),
Ok(user) => default("DB_USER", user)
}
});
let host = &env::var("DB_HOST")
.unwrap_or_else(|_| default("DB_HOST", DEFAULT_DB_HOST));
let db_name = &env::var("DB_NAME")
.unwrap_or_else(|_| default("DB_NAME", DEFAULT_DB_NAME));
let port = &env::var("DB_PORT")
.unwrap_or_else(|_| default("DB_PORT", DEFAULT_DB_PORT));
let ssl_mode = &env::var("DB_SSLMODE")
.unwrap_or_else(|_| default("DB_SSLMODE", DEFAULT_DB_SSLMODE));
match &env::var("DB_PASS") {
Ok(password) => {
format!("postgres://{}:{}@{}:{}/{}?sslmode={}",
user, password, host, port, db_name, ssl_mode)},
Err(_) => {
warn!("No DB_PASSWORD set. Attempting to connect to Postgres without a password. (This is correct if you are using the `ident` method.)");
format!("postgres://{}@{}:{}/{}?sslmode={}",
user, host, port, db_name, ssl_mode)
},
}
}
};
static ref REDIS_ADDR: RedisConfig = match &env::var("REDIS_URL") {
Ok(url) => {
warn!(r"REDIS_URL env variable set.
Connecting to Redis with that URL and ignoring any values set in REDIS_HOST or DB_PORT.");
let url = Url::parse(url).unwrap();
fn none_if_empty(item: &str) -> Option<String> {
if item.is_empty() { None } else { Some(item.to_string()) }
};
let user = none_if_empty(url.username());
let mut password = url.password().as_ref().map(|str| str.to_string());
let host = err::unwrap_or_die(url.host_str(),"Missing/invalid host in REDIS_URL");
let port = err::unwrap_or_die(url.port(), "Missing/invalid port in REDIS_URL");
let mut db = none_if_empty(url.path());
let query_pairs = url.query_pairs();
for (key, value) in query_pairs {
match key.to_string().as_str() {
"password" => { password = Some(value.to_string());},
"db" => { db = Some(value.to_string())}
_ => { err::die_with_msg(format!("Unsupported parameter {} in REDIS_URL.\n Flodgatt supports only `password` and `db` parameters.", key))}
}
}
RedisConfig {
user,
password,
host,
port,
db
}
}
Err(_) => {
let host = env::var("REDIS_HOST")
.unwrap_or_else(|_| default("REDIS_HOST", DEFAULT_REDIS_HOST));
let port = env::var("REDIS_PORT")
.unwrap_or_else(|_| default("REDIS_PORT", DEFAULT_REDIS_PORT));
RedisConfig {
user: None,
password: None,
host,
port,
db: None,
}
}
};
pub static ref REDIS_NAMESPACE: Option<String> = match env::var("REDIS_NAMESPACE") {
Ok(ns) => {
log::warn!("Using `{}:` as a Redis namespace.", ns);
Some(ns)
},
_ => None
};
pub static ref SERVER_ADDR: net::SocketAddr = env::var("SERVER_ADDR")
.unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned())
.parse()
.expect("static string");
/// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via SSE.
pub static ref SSE_UPDATE_INTERVAL: u64 = env::var("SSE_UPDATE_INTERVAL")
.map(|s| s.parse().expect("Valid config"))
.unwrap_or(DEFAULT_SSE_UPDATE_INTERVAL);
/// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via WS.
pub static ref WS_UPDATE_INTERVAL: u64 = env::var("WS_UPDATE_INTERVAL")
.map(|s| s.parse().expect("Valid config"))
.unwrap_or(DEFAULT_WS_UPDATE_INTERVAL);
/// Interval, in ms, at which the `Receiver` polls Redis.
pub static ref REDIS_POLL_INTERVAL: u64 = env::var("REDIS_POLL_INTERVAL")
.map(|s| s.parse().expect("Valid config"))
.unwrap_or(DEFAULT_REDIS_POLL_INTERVAL);
}
/// Configure CORS for the API server
pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors {
warp::cors()
.allow_any_origin()
.allow_methods(CORS_ALLOWED_METHODS.to_vec())
.allow_headers(CORS_ALLOWED_HEADERS.to_vec())
}
/// Initialize logging and read values from `src/.env`
pub fn logging_and_env() {
dotenv().ok();
pretty_env_logger::init();
POSTGRES_ADDR.to_string();
}
/// Configure Postgres and return a connection
pub fn postgres() -> postgres::Client {
// use openssl::ssl::{SslConnector, SslMethod};
// use postgres_openssl::MakeTlsConnector;
// let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
// builder.set_ca_file("/etc/ssl/cert.pem").unwrap();
// let connector = MakeTlsConnector::new(builder.build());
// TODO: add TLS support, remove `NoTls`
postgres::Client::connect(&POSTGRES_ADDR.to_string(), postgres::NoTls)
.expect("Can connect to local Postgres")
}
#[derive(Default)]
struct RedisConfig {
user: Option<String>,
password: Option<String>,
port: String,
host: String,
db: Option<String>,
}
/// Configure Redis
pub fn redis_addr() -> (net::TcpStream, net::TcpStream) {
let redis = &REDIS_ADDR;
let addr = format!("{}:{}", redis.host, redis.port);
if let Some(user) = &redis.user {
log::error!(
"Username {} provided, but Redis does not need a username. Ignoring it",
user
);
};
let mut pubsub_connection =
net::TcpStream::connect(addr.clone()).expect("Can connect to Redis");
pubsub_connection
.set_read_timeout(Some(time::Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
pubsub_connection
.set_nonblocking(true)
.expect("set_nonblocking call failed");
let mut secondary_redis_connection =
net::TcpStream::connect(addr).expect("Can connect to Redis");
secondary_redis_connection
.set_read_timeout(Some(time::Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
if let Some(password) = &REDIS_ADDR.password {
pubsub_connection
.write_all(&redis_cmd::cmd("auth", &password))
.unwrap();
secondary_redis_connection
.write_all(&redis_cmd::cmd("auth", password))
.unwrap();
} else {
warn!("No REDIS_PASSWORD set. Attempting to connect to Redis without a password. (This is correct if you are following the default setup.)");
}
if let Some(db) = &REDIS_ADDR.db {
pubsub_connection
.write_all(&redis_cmd::cmd("SELECT", &db))
.unwrap();
secondary_redis_connection
.write_all(&redis_cmd::cmd("SELECT", &db))
.unwrap();
}
(pubsub_connection, secondary_redis_connection)
}

128
src/config/mod.rs Normal file
View File

@ -0,0 +1,128 @@
//! Configuration defaults. All settings with the prefix of `DEFAULT_` can be overridden
//! by an environmental variable of the same name without that prefix (either by setting
//! the variable at runtime or in the `.env` file)
mod postgres_cfg;
mod redis_cfg;
pub use self::{postgres_cfg::PostgresConfig, redis_cfg::RedisConfig};
use dotenv::dotenv;
use lazy_static::lazy_static;
use log::warn;
use std::{env, net};
use url::Url;
const CORS_ALLOWED_METHODS: [&str; 2] = ["GET", "OPTIONS"];
const CORS_ALLOWED_HEADERS: [&str; 3] = ["Authorization", "Accept", "Cache-Control"];
// Postgres
// Deployment
const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000";
const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100;
const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100;
/// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver`
/// (on the order of 10ms rather than 50μs). Thus, changing this setting
/// would be a good place to start for performance improvements at the cost
/// of delaying all updates.
const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100;
lazy_static! {
pub static ref SERVER_ADDR: net::SocketAddr = env::var("SERVER_ADDR")
.unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned())
.parse()
.expect("static string");
/// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via SSE.
pub static ref SSE_UPDATE_INTERVAL: u64 = env::var("SSE_UPDATE_INTERVAL")
.map(|s| s.parse().expect("Valid config"))
.unwrap_or(DEFAULT_SSE_UPDATE_INTERVAL);
/// Interval, in ms, at which `ClientAgent` polls `Receiver` for updates to send via WS.
pub static ref WS_UPDATE_INTERVAL: u64 = env::var("WS_UPDATE_INTERVAL")
.map(|s| s.parse().expect("Valid config"))
.unwrap_or(DEFAULT_WS_UPDATE_INTERVAL);
/// Interval, in ms, at which the `Receiver` polls Redis.
pub static ref REDIS_POLL_INTERVAL: u64 = env::var("REDIS_POLL_INTERVAL")
.map(|s| s.parse().expect("Valid config"))
.unwrap_or(DEFAULT_REDIS_POLL_INTERVAL);
}
/// Configure CORS for the API server
pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors {
warp::cors()
.allow_any_origin()
.allow_methods(CORS_ALLOWED_METHODS.to_vec())
.allow_headers(CORS_ALLOWED_HEADERS.to_vec())
}
/// Initialize logging and read values from `src/.env`
pub fn logging_and_env() {
dotenv().ok();
pretty_env_logger::init();
}
/// Configure Postgres and return a connection
pub fn postgres() -> PostgresConfig {
// use openssl::ssl::{SslConnector, SslMethod};
// use postgres_openssl::MakeTlsConnector;
// let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
// builder.set_ca_file("/etc/ssl/cert.pem").unwrap();
// let connector = MakeTlsConnector::new(builder.build());
// TODO: add TLS support, remove `NoTls`
let pg_cfg = match &env::var("DATABASE_URL").ok() {
Some(url) => {
warn!("DATABASE_URL env variable set. Connecting to Postgres with that URL and ignoring any values set in DB_HOST, DB_USER, DB_NAME, DB_PASS, or DB_PORT.");
PostgresConfig::from_url(Url::parse(url).unwrap())
}
None => PostgresConfig::default()
.maybe_update_user(env::var("USER").ok())
.maybe_update_user(env::var("DB_USER").ok())
.maybe_update_host(env::var("DB_HOST").ok())
.maybe_update_password(env::var("DB_PASS").ok())
.maybe_update_db(env::var("DB_NAME").ok())
.maybe_update_sslmode(env::var("DB_SSLMODE").ok()),
};
log::warn!(
"Connecting to Postgres with the following configuration:\n{:#?}",
&pg_cfg
);
pg_cfg
}
/// Configure Redis and return a pair of connections
pub fn redis() -> RedisConfig {
let redis_cfg = match &env::var("REDIS_URL") {
Ok(url) => {
warn!("REDIS_URL env variable set. Connecting to Redis with that URL and ignoring any values set in REDIS_HOST or DB_PORT.");
RedisConfig::from_url(Url::parse(url).unwrap())
}
Err(_) => RedisConfig::default()
.maybe_update_host(env::var("REDIS_HOST").ok())
.maybe_update_port(env::var("REDIS_PORT").ok()),
}.maybe_update_namespace(env::var("REDIS_NAMESPACE").ok());
if let Some(user) = &redis_cfg.user {
log::error!(
"Username {} provided, but Redis does not need a username. Ignoring it",
user
);
};
log::warn!(
"Connecting to Redis with the following configuration:\n{:#?}",
&redis_cfg
);
redis_cfg
}
#[macro_export]
macro_rules! maybe_update {
($name:ident; $item: tt) => (
pub fn $name(self, item: Option<String>) -> Self{
match item {
Some($item) => Self{ $item, ..self },
None => Self { ..self }
}
});
($name:ident; Some($item: tt)) => (
pub fn $name(self, item: Option<String>) -> Self{
match item {
Some($item) => Self{ $item: Some($item), ..self },
None => Self { ..self }
}
})}

View File

@ -0,0 +1,63 @@
use crate::{err, maybe_update};
use url::Url;
#[derive(Debug)]
pub struct PostgresConfig {
pub user: String,
pub host: String,
pub password: Option<String>,
pub database: String,
pub port: String,
pub ssl_mode: String,
}
impl Default for PostgresConfig {
fn default() -> Self {
Self {
user: "postgres".to_string(),
host: "localhost".to_string(),
password: None,
database: "mastodon_development".to_string(),
port: "5432".to_string(),
ssl_mode: "prefer".to_string(),
}
}
}
fn none_if_empty(item: &str) -> Option<String> {
Some(item).filter(|i| !i.is_empty()).map(String::from)
}
impl PostgresConfig {
maybe_update!(maybe_update_user; user);
maybe_update!(maybe_update_host; host);
maybe_update!(maybe_update_db; database);
maybe_update!(maybe_update_port; port);
maybe_update!(maybe_update_sslmode; ssl_mode);
maybe_update!(maybe_update_password; Some(password));
pub fn from_url(url: Url) -> Self {
let (mut user, mut host, mut sslmode, mut password) = (None, None, None, None);
for (k, v) in url.query_pairs() {
match k.to_string().as_str() {
"user" => { user = Some(v.to_string());},
"password" => { password = Some(v.to_string());},
"host" => { host = Some(v.to_string());},
"sslmode" => { sslmode = Some(v.to_string());},
_ => { err::die_with_msg(format!("Unsupported parameter {} in DATABASE_URL.\n Flodgatt supports only `user`, `password`, `host`, and `sslmode` parameters.", k))}
}
}
Self::default()
// Values from query parameter
.maybe_update_user(user)
.maybe_update_password(password)
.maybe_update_host(host)
.maybe_update_sslmode(sslmode)
// Values from URL (which override query values if both are present)
.maybe_update_user(none_if_empty(url.username()))
.maybe_update_host(url.host_str().filter(|h| !h.is_empty()).map(String::from))
.maybe_update_password(url.password().map(String::from))
.maybe_update_port(url.port().map(|port_num| port_num.to_string()))
.maybe_update_db(none_if_empty(&url.path()[1..]))
}
}

56
src/config/redis_cfg.rs Normal file
View File

@ -0,0 +1,56 @@
use crate::{err, maybe_update};
use url::Url;
fn none_if_empty(item: &str) -> Option<String> {
if item.is_empty() {
None
} else {
Some(item.to_string())
}
}
#[derive(Debug)]
pub struct RedisConfig {
pub user: Option<String>,
pub password: Option<String>,
pub port: String,
pub host: String,
pub db: Option<String>,
pub namespace: Option<String>,
}
impl Default for RedisConfig {
fn default() -> Self {
Self {
user: None,
password: None,
db: None,
port: "6379".to_string(),
host: "127.0.0.1".to_string(),
namespace: None,
}
}
}
impl RedisConfig {
pub fn from_url(url: Url) -> Self {
let mut password = url.password().as_ref().map(|str| str.to_string());
let mut db = none_if_empty(&url.path()[1..]);
for (k, v) in url.query_pairs() {
match k.to_string().as_str() {
"password" => { password = Some(v.to_string());},
"db" => { db = Some(v.to_string())},
_ => { err::die_with_msg(format!("Unsupported parameter {} in REDIS_URL.\n Flodgatt supports only `password` and `db` parameters.", k))}
}
}
RedisConfig {
user: none_if_empty(url.username()),
host: err::unwrap_or_die(url.host_str(), "Missing or invalid host in REDIS_URL"),
port: err::unwrap_or_die(url.port(), "Missing or invalid port in REDIS_URL"),
namespace: None,
password,
db,
}
}
maybe_update!(maybe_update_host; host);
maybe_update!(maybe_update_port; port);
maybe_update!(maybe_update_namespace; Some(namespace));
}

View File

@ -1,8 +1,7 @@
use flodgatt::{
config, err,
parse_client_request::{sse, user, ws},
redis_to_client_stream,
redis_to_client_stream::ClientAgent,
redis_to_client_stream::{self, ClientAgent},
};
use log::warn;
use warp::{ws::Ws2, Filter as WarpFilter};
@ -11,11 +10,12 @@ fn main() {
config::logging_and_env();
let client_agent_sse = ClientAgent::blank();
let client_agent_ws = client_agent_sse.clone_with_shared_receiver();
let pg_conn = user::PostgresConn::new();
warn!("Streaming server initialized and ready to accept connections");
// Server Sent Events
let sse_routes = sse::extract_user_or_reject()
let sse_routes = sse::extract_user_or_reject(pg_conn.clone())
.and(warp::sse())
.map(
move |user: user::User, sse_connection_to_client: warp::sse::Sse| {
@ -31,7 +31,7 @@ fn main() {
.recover(err::handle_errors);
// WebSocket
let websocket_routes = ws::extract_user_or_reject()
let websocket_routes = ws::extract_user_or_reject(pg_conn.clone())
.and(warp::ws::ws2())
.map(move |user: user::User, ws: Ws2| {
let token = user.access_token.clone();

View File

@ -1,7 +1,9 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
use super::{query, query::Query, user::User};
use super::{
query::{self, Query},
user::{PostgresConn, User},
};
use warp::{filters::BoxedFilter, path, Filter};
#[allow(dead_code)]
type TimelineUser = ((String, User),);
@ -37,7 +39,7 @@ macro_rules! parse_query {
.boxed()
};
}
pub fn extract_user_or_reject() -> BoxedFilter<(User,)> {
pub fn extract_user_or_reject(pg_conn: PostgresConn) -> BoxedFilter<(User,)> {
any_of!(
parse_query!(
path => "api" / "v1" / "streaming" / "user" / "notification"
@ -65,14 +67,14 @@ pub fn extract_user_or_reject() -> BoxedFilter<(User,)> {
// parameter, we need to update our Query if the header has a token
.and(query::OptionalAccessToken::from_sse_header())
.and_then(Query::update_access_token)
.and_then(User::from_query)
.and_then(move |q| User::from_query(q, pg_conn.clone()))
.boxed()
}
#[cfg(test)]
mod test {
use super::*;
use crate::parse_client_request::user::{Filter, OauthScope};
use crate::parse_client_request::user::{Filter, OauthScope, PostgresConn};
macro_rules! test_public_endpoint {
($name:ident {
@ -81,9 +83,10 @@ mod test {
}) => {
#[test]
fn $name() {
let pg_conn = PostgresConn::new();
let user = warp::test::request()
.path($path)
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
assert_eq!(user, $user);
}
@ -98,16 +101,17 @@ mod test {
#[test]
fn $name() {
let path = format!("{}?access_token=TEST_USER", $path);
let pg_conn = PostgresConn::new();
$(let path = format!("{}&{}", path, $query);)*
let user = warp::test::request()
.path(&path)
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn.clone()))
.expect("in test");
assert_eq!(user, $user);
let user = warp::test::request()
.path(&path)
.header("Authorization", "Bearer: TEST_USER")
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
assert_eq!(user, $user);
}
@ -123,9 +127,10 @@ mod test {
fn $name() {
let path = format!("{}?access_token=INVALID", $path);
$(let path = format!("{}&{}", path, $query);)*
let pg_conn = PostgresConn::new();
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
}
};
@ -140,10 +145,12 @@ mod test {
fn $name() {
let path = $path;
$(let path = format!("{}?{}", path, $query);)*
let pg_conn = PostgresConn::new();
warp::test::request()
.path(&path)
.header("Authorization", "Bearer: INVALID")
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
}
};
@ -158,9 +165,10 @@ mod test {
fn $name() {
let path = $path;
$(let path = format!("{}?{}", path, $query);)*
let pg_conn = PostgresConn::new();
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
}
};
@ -429,10 +437,10 @@ mod test {
#[test]
#[should_panic(expected = "NotFound")]
fn nonexistant_endpoint() {
let pg_conn = PostgresConn::new();
warp::test::request()
.path("/api/v1/streaming/DOES_NOT_EXIST")
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
}
}

View File

@ -1,6 +1,17 @@
//! Mock Postgres connection (for use in unit testing)
use std::sync::{Arc, Mutex};
pub fn query_for_user_data(access_token: &str) -> (i64, Option<Vec<String>>, Vec<String>) {
#[derive(Clone)]
pub struct PostgresConn(Arc<Mutex<String>>);
impl PostgresConn {
pub fn new() -> Self {
Self(Arc::new(Mutex::new("MOCK".to_string())))
}
}
pub fn query_for_user_data(
access_token: &str,
_pg_conn: PostgresConn,
) -> (i64, Option<Vec<String>>, Vec<String>) {
let (user_id, lang, scopes) = if access_token == "TEST_USER" {
(
1,
@ -17,7 +28,7 @@ pub fn query_for_user_data(access_token: &str) -> (i64, Option<Vec<String>>, Vec
(user_id, lang, scopes)
}
pub fn query_list_owner(list_id: i64) -> Option<i64> {
pub fn query_list_owner(list_id: i64, _pg_conn: PostgresConn) -> Option<i64> {
match list_id {
1 => Some(1),
_ => None,

View File

@ -5,6 +5,7 @@ mod mock_postgres;
use mock_postgres as postgres;
#[cfg(not(test))]
mod postgres;
pub use self::postgres::PostgresConn;
use super::query::Query;
use warp::reject::Rejection;
@ -57,7 +58,7 @@ impl From<Vec<String>> for OauthScope {
}
impl User {
pub fn from_query(q: Query) -> Result<Self, Rejection> {
pub fn from_query(q: Query, pg_conn: PostgresConn) -> Result<Self, Rejection> {
let (id, access_token, scopes, langs, logged_in) = match q.access_token.clone() {
None => (
-1,
@ -67,7 +68,8 @@ impl User {
false,
),
Some(token) => {
let (id, langs, scope_list) = postgres::query_for_user_data(&token);
let (id, langs, scope_list) =
postgres::query_for_user_data(&token, pg_conn.clone());
if id == -1 {
return Err(warp::reject::custom("Error: Invalid access token"));
}
@ -85,12 +87,16 @@ impl User {
filter: Filter::Language,
};
user = user.update_timeline_and_filter(q)?;
user = user.update_timeline_and_filter(q, pg_conn.clone())?;
Ok(user)
}
fn update_timeline_and_filter(mut self, q: Query) -> Result<Self, Rejection> {
fn update_timeline_and_filter(
mut self,
q: Query,
pg_conn: PostgresConn,
) -> Result<Self, Rejection> {
let read_scope = self.scopes.clone();
let timeline = match q.stream.as_ref() {
@ -110,7 +116,7 @@ impl User {
format!("{}", self.id)
}
// List endpoint:
"list" if self.owns_list(q.list) && (read_scope.all || read_scope.lists) => {
"list" if self.owns_list(q.list, pg_conn) && (read_scope.all || read_scope.lists) => {
self.filter = Filter::NoFilter;
format!("list:{}", q.list)
}
@ -133,8 +139,8 @@ impl User {
}
/// Determine whether the User is authorised for a specified list
pub fn owns_list(&self, list: i64) -> bool {
match postgres::query_list_owner(list) {
pub fn owns_list(&self, list: i64, pg_conn: PostgresConn) -> bool {
match postgres::query_list_owner(list, pg_conn) {
Some(i) if i == self.id => true,
_ => false,
}

View File

@ -1,9 +1,34 @@
//! Postgres queries
use crate::config;
use ::postgres;
use std::sync::{Arc, Mutex};
#[derive(Clone)]
pub struct PostgresConn(pub Arc<Mutex<postgres::Client>>);
impl PostgresConn {
pub fn new() -> Self {
let pg_cfg = config::postgres();
let mut con = postgres::Client::configure();
con.user(&pg_cfg.user)
.host(&pg_cfg.host)
.port(pg_cfg.port.parse::<u16>().unwrap())
.dbname(&pg_cfg.database);
if let Some(password) = &pg_cfg.password {
con.password(password);
};
Self(Arc::new(Mutex::new(
con.connect(postgres::NoTls)
.expect("Can connect to local Postgres"),
)))
}
}
#[cfg(not(test))]
pub fn query_for_user_data(access_token: &str) -> (i64, Option<Vec<String>>, Vec<String>) {
let mut conn = config::postgres();
pub fn query_for_user_data(
access_token: &str,
pg_conn: PostgresConn,
) -> (i64, Option<Vec<String>>, Vec<String>) {
let mut conn = pg_conn.0.lock().unwrap();
let query_result = conn
.query(
@ -53,8 +78,8 @@ pub fn query_for_user_data(access_token: &str) -> (i64, Option<Vec<String>>, Vec
}
#[cfg(not(test))]
pub fn query_list_owner(list_id: i64) -> Option<i64> {
let mut conn = config::postgres();
pub fn query_list_owner(list_id: i64, pg_conn: PostgresConn) -> Option<i64> {
let mut conn = pg_conn.0.lock().unwrap();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(

View File

@ -1,5 +1,8 @@
//! Filters for the WebSocket endpoint
use super::{query, query::Query, user::User};
use super::{
query::{self, Query},
user::{PostgresConn, User},
};
use warp::{filters::BoxedFilter, path, Filter};
/// WebSocket filters
@ -29,11 +32,11 @@ fn parse_query() -> BoxedFilter<(Query,)> {
.boxed()
}
pub fn extract_user_or_reject() -> BoxedFilter<(User,)> {
pub fn extract_user_or_reject(pg_conn: PostgresConn) -> BoxedFilter<(User,)> {
parse_query()
.and(query::OptionalAccessToken::from_ws_header())
.and_then(Query::update_access_token)
.and_then(User::from_query)
.and_then(move |q| User::from_query(q, pg_conn.clone()))
.boxed()
}
#[cfg(test)]
@ -48,13 +51,14 @@ mod test {
}) => {
#[test]
fn $name() {
let pg_conn = PostgresConn::new();
let user = warp::test::request()
.path($path)
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
assert_eq!(user, $user);
}
@ -67,6 +71,7 @@ mod test {
}) => {
#[test]
fn $name() {
let pg_conn = PostgresConn::new();
let path = format!("{}&access_token=TEST_USER", $path);
let user = warp::test::request()
.path(&path)
@ -74,7 +79,7 @@ mod test {
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
assert_eq!(user, $user);
}
@ -90,9 +95,10 @@ mod test {
fn $name() {
let path = format!("{}&access_token=INVALID", $path);
let pg_conn = PostgresConn::new();
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
}
};
@ -105,9 +111,10 @@ mod test {
#[should_panic(expected = "Error: Missing access token")]
fn $name() {
let path = $path;
let pg_conn = PostgresConn::new();
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
}
};
@ -308,13 +315,14 @@ mod test {
#[test]
#[should_panic(expected = "NotFound")]
fn nonexistant_endpoint() {
let pg_conn = PostgresConn::new();
warp::test::request()
.path("/api/v1/streaming/DOES_NOT_EXIST")
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.filter(&extract_user_or_reject())
.filter(&extract_user_or_reject(pg_conn))
.expect("in test");
}
}

View File

@ -1,7 +1,7 @@
//! Receives data from Redis, sorts it by `ClientAgent`, and stores it until
//! polled by the correct `ClientAgent`. Also manages sububscriptions and
//! unsubscriptions to/from Redis.
use super::{redis_cmd, redis_stream};
use super::{redis_cmd, redis_stream, redis_stream::RedisConn};
use crate::{config, pubsub_cmd};
use futures::{Async, Poll};
use serde_json::Value;
@ -14,6 +14,7 @@ use uuid::Uuid;
pub struct Receiver {
pub pubsub_connection: net::TcpStream,
secondary_redis_connection: net::TcpStream,
pub redis_namespace: Option<String>,
redis_polled_at: time::Instant,
timeline: String,
manager_id: Uuid,
@ -26,10 +27,16 @@ impl Receiver {
/// Create a new `Receiver`, with its own Redis connections (but, as yet, no
/// active subscriptions).
pub fn new() -> Self {
let (pubsub_connection, secondary_redis_connection) = config::redis_addr();
let RedisConn {
primary: pubsub_connection,
secondary: secondary_redis_connection,
namespace: redis_namespace,
} = RedisConn::new();
Self {
pubsub_connection,
secondary_redis_connection,
redis_namespace,
redis_polled_at: time::Instant::now(),
timeline: String::new(),
manager_id: Uuid::default(),

View File

@ -1,5 +1,4 @@
//! Send raw TCP commands to the Redis server
use crate::config;
use std::fmt::Display;
/// Send a subscribe or unsubscribe to the Redis PubSub channel
@ -10,7 +9,7 @@ macro_rules! pubsub_cmd {
log::info!("Sending {} command to {}", $cmd, $tl);
$self
.pubsub_connection
.write_all(&redis_cmd::pubsub($cmd, $tl))
.write_all(&redis_cmd::pubsub($cmd, $tl, $self.redis_namespace.clone()))
.expect("Can send command to Redis");
// Because we keep track of the number of clients subscribed to a channel on our end,
// we need to manually tell Redis when we have subscribed or unsubscribed
@ -24,6 +23,7 @@ macro_rules! pubsub_cmd {
.write_all(&redis_cmd::set(
format!("subscribed:timeline:{}", $tl),
subscription_new_number,
$self.redis_namespace.clone(),
))
.expect("Can set Redis");
@ -31,8 +31,8 @@ macro_rules! pubsub_cmd {
}};
}
/// Send a `SUBSCRIBE` or `UNSUBSCRIBE` command to a specific timeline
pub fn pubsub(command: impl Display, timeline: impl Display) -> Vec<u8> {
let arg = match &*config::REDIS_NAMESPACE {
pub fn pubsub(command: impl Display, timeline: impl Display, ns: Option<String>) -> Vec<u8> {
let arg = match ns {
Some(namespace) => format!("{}:timeline:{}", namespace, timeline),
None => format!("timeline:{}", timeline),
};
@ -55,8 +55,8 @@ pub fn cmd(command: impl Display, arg: impl Display) -> Vec<u8> {
}
/// Send a `SET` command (used to manually unsubscribe from Redis)
pub fn set(key: impl Display, value: impl Display) -> Vec<u8> {
let key = match &*config::REDIS_NAMESPACE {
pub fn set(key: impl Display, value: impl Display, ns: Option<String>) -> Vec<u8> {
let key = match ns {
Some(namespace) => format!("{}:{}", namespace, key),
None => key.to_string(),
};

View File

@ -1,11 +1,58 @@
use super::receiver::Receiver;
use crate::config;
use crate::{config, redis_to_client_stream::redis_cmd};
use futures::{Async, Poll};
use serde_json::Value;
use std::io::Read;
use std::net;
use std::{io::Read, io::Write, net, time};
use tokio::io::AsyncRead;
pub struct RedisConn {
pub primary: net::TcpStream,
pub secondary: net::TcpStream,
pub namespace: Option<String>,
}
impl RedisConn {
pub fn new() -> Self {
let redis_cfg = config::redis();
let addr = format!("{}:{}", redis_cfg.host, redis_cfg.port);
let mut pubsub_connection =
net::TcpStream::connect(addr.clone()).expect("Can connect to Redis");
pubsub_connection
.set_read_timeout(Some(time::Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
pubsub_connection
.set_nonblocking(true)
.expect("set_nonblocking call failed");
let mut secondary_redis_connection =
net::TcpStream::connect(addr).expect("Can connect to Redis");
secondary_redis_connection
.set_read_timeout(Some(time::Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
if let Some(password) = redis_cfg.password {
pubsub_connection
.write_all(&redis_cmd::cmd("auth", &password))
.unwrap();
secondary_redis_connection
.write_all(&redis_cmd::cmd("auth", password))
.unwrap();
}
if let Some(db) = redis_cfg.db {
pubsub_connection
.write_all(&redis_cmd::cmd("SELECT", &db))
.unwrap();
secondary_redis_connection
.write_all(&redis_cmd::cmd("SELECT", &db))
.unwrap();
}
Self {
primary: pubsub_connection,
secondary: secondary_redis_connection,
namespace: redis_cfg.namespace,
}
}
}
pub struct AsyncReadableStream<'a>(&'a mut net::TcpStream);
impl<'a> AsyncReadableStream<'a> {
@ -41,7 +88,7 @@ If so, set it with the REDIS_PASSWORD environmental variable"
};
let mut msg = RedisMsg::from_raw(&receiver.incoming_raw_msg);
let prefix_to_skip = match &*config::REDIS_NAMESPACE {
let prefix_to_skip = match &receiver.redis_namespace {
Some(namespace) => format!("{}:timeline:", namespace),
None => "timeline:".to_string(),
};
@ -56,7 +103,6 @@ If so, set it with the REDIS_PASSWORD environmental variable"
Ok(v) => v,
Err(e) => panic!("Unparseable json {}\n\n{}", msg_txt, e),
};
dbg!(&timeline);
for msg_queue in receiver.msg_queues.values_mut() {
if msg_queue.redis_channel == timeline {
msg_queue.messages.push_back(msg_value.clone());

1
src/rustfmt.toml Normal file
View File

@ -0,0 +1 @@
edition = "2018"