Fix valid language (#93)

* Fix panic on delete events

Previously, the code attempted to check the toot's language regardless
of the event type. That caused a panic for `delete` events, which lack
a language.
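
A minimal sketch of the fix (editorial; the `Event` shape and the
`passes_language_filter` helper are illustrative assumptions, not the
crate's actual API):

use std::collections::HashSet;

enum Event {
    Update { language: Option<String> },
    Delete { status_id: String },
}

fn passes_language_filter(event: &Event, allowed_langs: &HashSet<String>) -> bool {
    match event {
        // `delete` events carry no language, so never inspect one
        Event::Delete { .. } => true,
        Event::Update { language: Some(lang) } => {
            allowed_langs.is_empty() || allowed_langs.contains(lang)
        }
        Event::Update { language: None } => true,
    }
}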

* WIP implementation of Message refactor

* Major refactor

* Refactor scope management to use an enum

* Use Timeline type instead of String

* Clean up Receiver's use of Timeline

* Make debug output more readable

* Block statuses from blocking users

This commit fixes an issue where a status from A would be displayed on
B's public timelines even when A had B blocked (i.e., it would treat B
as though they were muted rather than blocked for the purpose of
public timelines).
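
A hedged sketch of the intended semantics (the field names mirror the
`Blocks` struct introduced later in this diff; the helper itself is
illustrative):

use std::collections::HashSet;

fn should_hide_from_public(
    author_id: i64,
    blocked_users: &HashSet<i64>,  // users the viewer has blocked
    blocking_users: &HashSet<i64>, // users who have blocked the viewer
) -> bool {
    // A block hides statuses in both directions; a mute-style check
    // would consult only `blocked_users`.
    blocked_users.contains(&author_id) || blocking_users.contains(&author_id)
}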

* Fix a bug that caused incorrect parsing of the incoming timeline

* Disable outdated tests

* Bump version
Daniel Sockwell, 2020-03-18 20:37:10 -04:00 (committed by GitHub)
parent 440d691b0f
commit 8843f18f5f
18 changed files with 1436 additions and 1100 deletions

Cargo.lock (generated)

@@ -1,5 +1,13 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
[[package]]
name = "ahash"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"const-random 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "aho-corasick"
version = "0.7.6"
@@ -33,7 +41,7 @@ dependencies = [
[[package]]
name = "autocfg"
version = "0.1.2"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
@@ -41,7 +49,7 @@ name = "backtrace"
version = "0.3.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"backtrace-sys 0.1.28 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -192,6 +200,24 @@ dependencies = [
"bitflags 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "const-random"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"const-random-macro 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro-hack 0.5.11 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "const-random-macro"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"getrandom 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro-hack 0.5.11 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "criterion"
version = "0.3.0"
@@ -414,12 +440,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "flodgatt"
version = "0.4.8"
version = "0.5.0"
dependencies = [
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
"futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
"lru 0.4.3 (registry+https://github.com/rust-lang/crates.io-index)",
"postgres 0.17.0 (registry+https://github.com/rust-lang/crates.io-index)",
"postgres-openssl 0.2.0-rc.1 (git+https://github.com/sfackler/rust-postgres.git)",
"pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -613,6 +640,15 @@ dependencies = [
"tokio-io 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "hashbrown"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"ahash 0.2.18 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "headers"
version = "0.2.1"
@@ -817,6 +853,14 @@ dependencies = [
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "lru"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"hashbrown 0.6.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "matches"
version = "0.1.8"
@@ -957,7 +1001,7 @@ name = "num-integer"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)",
]
@@ -966,7 +1010,7 @@ name = "num-traits"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -1005,7 +1049,7 @@ name = "openssl-sys"
version = "0.9.49"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"cc 1.0.36 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)",
"pkg-config 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1315,7 +1359,7 @@ name = "rand"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_chacha 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1345,7 +1389,7 @@ name = "rand_chacha"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
@@ -1440,7 +1484,7 @@ name = "rand_pcg"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
@@ -2334,11 +2378,12 @@ dependencies = [
]
[metadata]
"checksum ahash 0.2.18 (registry+https://github.com/rust-lang/crates.io-index)" = "6f33b5018f120946c1dcf279194f238a9f146725593ead1c08fa47ff22b0b5d3"
"checksum aho-corasick 0.7.6 (registry+https://github.com/rust-lang/crates.io-index)" = "58fb5e95d83b38284460a5fda7d6470aa0b8844d283a0b614b8535e880800d2d"
"checksum antidote 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "34fde25430d87a9388dadbe6e34d7f72a462c8b43ac8d309b42b0a8505d7e2a5"
"checksum arrayvec 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)" = "92c7fb76bc8826a8b33b4ee5bb07a247a81e76764ab4d55e8f73e3a4d8808c71"
"checksum atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "9a7d5b8723950951411ee34d271d99dddcc2035a16ab25310ea2c8cfd4369652"
"checksum autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "a6d640bee2da49f60a4068a7fae53acde8982514ab7bae8b8cea9e88cbcfd799"
"checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2"
"checksum backtrace 0.3.15 (registry+https://github.com/rust-lang/crates.io-index)" = "f106c02a3604afcdc0df5d36cc47b44b55917dbaf3d808f71c163a0ddba64637"
"checksum backtrace-sys 0.1.28 (registry+https://github.com/rust-lang/crates.io-index)" = "797c830ac25ccc92a7f8a7b9862bde440715531514594a6154e3d4a54dd769b6"
"checksum base64 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)" = "0b25d992356d2eb0ed82172f5248873db5560c4721f564b13cb5193bda5e668e"
@@ -2359,6 +2404,8 @@ dependencies = [
"checksum chrono 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)" = "77d81f58b7301084de3b958691458a53c3f7e0b1d702f77e550b6a88e3a88abe"
"checksum clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9"
"checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f"
"checksum const-random 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "2f1af9ac737b2dd2d577701e59fd09ba34822f6f2ebdb30a7647405d9e55e16a"
"checksum const-random-macro 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "25e4c606eb459dd29f7c57b2e0879f2b6f14ee130918c2b78ccb58a9624e6c7a"
"checksum criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "938703e165481c8d612ea3479ac8342e5615185db37765162e762ec3523e2fc6"
"checksum criterion-plot 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "eccdc6ce8bbe352ca89025bee672aa6d24f4eb8c53e3a8b5d1bc58011da072a2"
"checksum crossbeam-deque 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b18cd2e169ad86297e6bc0ad9aa679aee9daa4f19e8163860faf7c164e4f5a71"
@@ -2403,6 +2450,7 @@ dependencies = [
"checksum generic-array 0.13.2 (registry+https://github.com/rust-lang/crates.io-index)" = "0ed1e761351b56f54eb9dcd0cfaca9fd0daecf93918e1cfc01c8a3d26ee7adcd"
"checksum getrandom 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)" = "473a1265acc8ff1e808cd0a1af8cee3c2ee5200916058a2ca113c29f2d903571"
"checksum h2 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)" = "85ab6286db06040ddefb71641b50017c06874614001a134b423783e2db2920bd"
"checksum hashbrown 0.6.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8e6073d0ca812575946eb5f35ff68dbe519907b25c42530389ff946dc84c6ead"
"checksum headers 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "dc6e2e51d356081258ef05ff4c648138b5d3fe64b7300aaad3b820554a2b7fb6"
"checksum headers-core 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "51ae5b0b5417559ee1d2733b21d33b0868ae9e406bd32eb1a51d613f66ed472a"
"checksum headers-derive 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "97c462e8066bca4f0968ddf8d12de64c40f2c2187b3b9a2fa994d06e8ad444a9"
@@ -2426,6 +2474,7 @@ dependencies = [
"checksum lock_api 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "79b2de95ecb4691949fea4716ca53cdbcfccb2c612e19644a8bad05edcf9f47b"
"checksum log 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)" = "e19e8d5c34a3e0e2223db8e060f9e8264aeeb5c5fc64a4ee9965c062211c024b"
"checksum log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c84ec4b527950aa83a329754b01dbe3f58361d1c5efacd1f6d68c494d08a17c6"
"checksum lru 0.4.3 (registry+https://github.com/rust-lang/crates.io-index)" = "0609345ddee5badacf857d4f547e0e5a2e987db77085c24cd887f73573a04237"
"checksum matches 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "7ffc5c5338469d4d3ea17d269fa8ea3512ad247247c30bd2df69e68309ed0a08"
"checksum md5 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7e6bcd6433cff03a4bfc3d9834d504467db1f1cf6d0ea765d37d330249ed629d"
"checksum md5 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"


@@ -1,7 +1,7 @@
[package]
name = "flodgatt"
description = "A blazingly fast drop-in replacement for the Mastodon streaming api server"
version = "0.5.0"
version = "0.6.0"
authors = ["Daniel Long Sockwell <daniel@codesections.com", "Julian Laubstein <contact@julianlaubstein.de>"]
edition = "2018"
@@ -23,6 +23,7 @@ strum = "0.16.0"
strum_macros = "0.16.0"
r2d2_postgres = "0.16.0"
r2d2 = "0.8.8"
lru = "0.4.3"
[dev-dependencies]
criterion = "0.3"


@@ -6,6 +6,14 @@ pub fn die_with_msg(msg: impl Display) -> ! {
std::process::exit(1);
}
#[macro_export]
macro_rules! log_fatal {
($str:expr, $var:expr) => {{
log::error!($str, $var);
panic!();
};};
}
pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! {
eprintln!(
r"FATAL ERROR: {var} is set to `{value}`, which is invalid.


@@ -26,11 +26,11 @@ fn main() {
let cfg = config::DeploymentConfig::from_env(env_vars.clone());
let postgres_cfg = config::PostgresConfig::from_env(env_vars.clone());
let client_agent_sse = ClientAgent::blank(redis_cfg);
let client_agent_ws = client_agent_sse.clone_with_shared_receiver();
let pg_pool = user::PgPool::new(postgres_cfg);
let client_agent_sse = ClientAgent::blank(redis_cfg, pg_pool.clone());
let client_agent_ws = client_agent_sse.clone_with_shared_receiver();
log::warn!("Streaming server initialized and ready to accept connections");
// Server Sent Events
@@ -38,7 +38,7 @@ fn main() {
let sse_routes = sse::extract_user_or_reject(pg_pool.clone())
.and(warp::sse())
.map(
move |user: user::User, sse_connection_to_client: warp::sse::Sse| {
move |user: user::Subscription, sse_connection_to_client: warp::sse::Sse| {
log::info!("Incoming SSE request");
// Create a new ClientAgent
let mut client_agent = client_agent_sse.clone_with_shared_receiver();
@@ -57,29 +57,30 @@
// WebSocket
let ws_update_interval = *cfg.ws_interval;
let websocket_routes = ws::extract_user_or_reject(pg_pool.clone())
let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone())
.and(warp::ws::ws2())
.map(move |user: user::User, ws: Ws2| {
log::info!("Incoming websocket request");
let token = user.access_token.clone();
// Create a new ClientAgent
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(user);
// send the updates through the WS connection (along with the User's access_token
// which is sent for security)
.map(
move |user: user::Subscription, token: Option<String>, ws: Ws2| {
log::info!("Incoming websocket request");
// Create a new ClientAgent
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(user);
// send the updates through the WS connection (along with the User's access_token
// which is sent for security)
(
ws.on_upgrade(move |socket| {
redis_to_client_stream::send_updates_to_ws(
socket,
client_agent,
ws_update_interval,
)
}),
token,
)
})
(
ws.on_upgrade(move |socket| {
redis_to_client_stream::send_updates_to_ws(
socket,
client_agent,
ws_update_interval,
)
}),
token.unwrap_or_else(String::new),
)
},
)
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
let cors = warp::cors()


@@ -1,11 +1,11 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
use super::{
query::{self, Query},
user::{PgPool, User},
user::{PgPool, Subscription},
};
use warp::{filters::BoxedFilter, path, Filter};
#[allow(dead_code)]
type TimelineUser = ((String, User),);
type TimelineUser = ((String, Subscription),);
/// Helper macro to match on the first of any of the provided filters
macro_rules! any_of {
@@ -39,7 +39,7 @@ macro_rules! parse_query {
.boxed()
};
}
pub fn extract_user_or_reject(pg_pool: PgPool) -> BoxedFilter<(User,)> {
pub fn extract_user_or_reject(pg_pool: PgPool) -> BoxedFilter<(Subscription,)> {
any_of!(
parse_query!(
path => "api" / "v1" / "streaming" / "user" / "notification"
@@ -67,402 +67,402 @@ pub fn extract_user_or_reject(pg_pool: PgPool) -> BoxedFilter<(User,)> {
// parameter, we need to update our Query if the header has a token
.and(query::OptionalAccessToken::from_sse_header())
.and_then(Query::update_access_token)
.and_then(move |q| User::from_query(q, pg_pool.clone()))
.and_then(move |q| Subscription::from_query(q, pg_pool.clone()))
.boxed()
}
#[cfg(test)]
mod test {
use super::*;
use crate::parse_client_request::user::{Blocks, Filter, OauthScope, PgPool};
// #[cfg(test)]
// mod test {
// use super::*;
// use crate::parse_client_request::user::{Blocks, Filter, OauthScope, PgPool};
macro_rules! test_public_endpoint {
($name:ident {
endpoint: $path:expr,
user: $user:expr,
}) => {
#[test]
fn $name() {
let mock_pg_pool = PgPool::new();
let user = warp::test::request()
.path($path)
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
assert_eq!(user, $user);
}
};
}
macro_rules! test_private_endpoint {
($name:ident {
endpoint: $path:expr,
$(query: $query:expr,)*
user: $user:expr,
}) => {
#[test]
fn $name() {
let path = format!("{}?access_token=TEST_USER", $path);
let mock_pg_pool = PgPool::new();
$(let path = format!("{}&{}", path, $query);)*
let user = warp::test::request()
.path(&path)
.filter(&extract_user_or_reject(mock_pg_pool.clone()))
.expect("in test");
assert_eq!(user, $user);
let user = warp::test::request()
.path(&path)
.header("Authorization", "Bearer: TEST_USER")
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
assert_eq!(user, $user);
}
};
}
macro_rules! test_bad_auth_token_in_query {
($name: ident {
endpoint: $path:expr,
$(query: $query:expr,)*
}) => {
#[test]
#[should_panic(expected = "Error: Invalid access token")]
fn $name() {
let path = format!("{}?access_token=INVALID", $path);
$(let path = format!("{}&{}", path, $query);)*
let mock_pg_pool = PgPool::new();
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
}
};
}
macro_rules! test_bad_auth_token_in_header {
($name: ident {
endpoint: $path:expr,
$(query: $query:expr,)*
}) => {
#[test]
#[should_panic(expected = "Error: Invalid access token")]
fn $name() {
let path = $path;
$(let path = format!("{}?{}", path, $query);)*
// macro_rules! test_public_endpoint {
// ($name:ident {
// endpoint: $path:expr,
// user: $user:expr,
// }) => {
// #[test]
// fn $name() {
// let mock_pg_pool = PgPool::new();
// let user = warp::test::request()
// .path($path)
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// assert_eq!(user, $user);
// }
// };
// }
// macro_rules! test_private_endpoint {
// ($name:ident {
// endpoint: $path:expr,
// $(query: $query:expr,)*
// user: $user:expr,
// }) => {
// #[test]
// fn $name() {
// let path = format!("{}?access_token=TEST_USER", $path);
// let mock_pg_pool = PgPool::new();
// $(let path = format!("{}&{}", path, $query);)*
// let user = warp::test::request()
// .path(&path)
// .filter(&extract_user_or_reject(mock_pg_pool.clone()))
// .expect("in test");
// assert_eq!(user, $user);
// let user = warp::test::request()
// .path(&path)
// .header("Authorization", "Bearer: TEST_USER")
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// assert_eq!(user, $user);
// }
// };
// }
// macro_rules! test_bad_auth_token_in_query {
// ($name: ident {
// endpoint: $path:expr,
// $(query: $query:expr,)*
// }) => {
// #[test]
// #[should_panic(expected = "Error: Invalid access token")]
// fn $name() {
// let path = format!("{}?access_token=INVALID", $path);
// $(let path = format!("{}&{}", path, $query);)*
// let mock_pg_pool = PgPool::new();
// warp::test::request()
// .path(&path)
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// }
// };
// }
// macro_rules! test_bad_auth_token_in_header {
// ($name: ident {
// endpoint: $path:expr,
// $(query: $query:expr,)*
// }) => {
// #[test]
// #[should_panic(expected = "Error: Invalid access token")]
// fn $name() {
// let path = $path;
// $(let path = format!("{}?{}", path, $query);)*
let mock_pg_pool = PgPool::new();
warp::test::request()
.path(&path)
.header("Authorization", "Bearer: INVALID")
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
}
};
}
macro_rules! test_missing_auth {
($name: ident {
endpoint: $path:expr,
$(query: $query:expr,)*
}) => {
#[test]
#[should_panic(expected = "Error: Missing access token")]
fn $name() {
let path = $path;
$(let path = format!("{}?{}", path, $query);)*
let mock_pg_pool = PgPool::new();
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
}
};
}
// let mock_pg_pool = PgPool::new();
// warp::test::request()
// .path(&path)
// .header("Authorization", "Bearer: INVALID")
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// }
// };
// }
// macro_rules! test_missing_auth {
// ($name: ident {
// endpoint: $path:expr,
// $(query: $query:expr,)*
// }) => {
// #[test]
// #[should_panic(expected = "Error: Missing access token")]
// fn $name() {
// let path = $path;
// $(let path = format!("{}?{}", path, $query);)*
// let mock_pg_pool = PgPool::new();
// warp::test::request()
// .path(&path)
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// }
// };
// }
test_public_endpoint!(public_media_true {
endpoint: "/api/v1/streaming/public?only_media=true",
user: User {
target_timeline: "public:media".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(public_media_1 {
endpoint: "/api/v1/streaming/public?only_media=1",
user: User {
target_timeline: "public:media".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(public_local {
endpoint: "/api/v1/streaming/public/local",
user: User {
target_timeline: "public:local".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(public_local_media_true {
endpoint: "/api/v1/streaming/public/local?only_media=true",
user: User {
target_timeline: "public:local:media".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(public_local_media_1 {
endpoint: "/api/v1/streaming/public/local?only_media=1",
user: User {
target_timeline: "public:local:media".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(hashtag {
endpoint: "/api/v1/streaming/hashtag?tag=a",
user: User {
target_timeline: "hashtag:a".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(hashtag_local {
endpoint: "/api/v1/streaming/hashtag/local?tag=a",
user: User {
target_timeline: "hashtag:local:a".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
// test_public_endpoint!(public_media_true {
// endpoint: "/api/v1/streaming/public?only_media=true",
// user: Subscription {
// timeline: "public:media".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(public_media_1 {
// endpoint: "/api/v1/streaming/public?only_media=1",
// user: Subscription {
// timeline: "public:media".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(public_local {
// endpoint: "/api/v1/streaming/public/local",
// user: Subscription {
// timeline: "public:local".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(public_local_media_true {
// endpoint: "/api/v1/streaming/public/local?only_media=true",
// user: Subscription {
// timeline: "public:local:media".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(public_local_media_1 {
// endpoint: "/api/v1/streaming/public/local?only_media=1",
// user: Subscription {
// timeline: "public:local:media".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(hashtag {
// endpoint: "/api/v1/streaming/hashtag?tag=a",
// user: Subscription {
// timeline: "hashtag:a".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(hashtag_local {
// endpoint: "/api/v1/streaming/hashtag/local?tag=a",
// user: Subscription {
// timeline: "hashtag:local:a".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
test_private_endpoint!(user {
endpoint: "/api/v1/streaming/user",
user: User {
target_timeline: "1".to_string(),
id: 1,
email: "user@example.com".to_string(),
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
},
});
test_private_endpoint!(user_notification {
endpoint: "/api/v1/streaming/user/notification",
user: User {
target_timeline: "1".to_string(),
id: 1,
email: "user@example.com".to_string(),
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::Notification,
},
});
test_private_endpoint!(direct {
endpoint: "/api/v1/streaming/direct",
user: User {
target_timeline: "direct".to_string(),
id: 1,
email: "user@example.com".to_string(),
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
},
});
// test_private_endpoint!(user {
// endpoint: "/api/v1/streaming/user",
// user: Subscription {
// timeline: "1".to_string(),
// id: 1,
// email: "user@example.com".to_string(),
// access_token: "TEST_USER".to_string(),
// langs: None,
// scopes: OauthScope {
// all: true,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: true,
// blocks: Blocks::default(),
// allowed_langs: Filter::NoFilter,
// },
// });
// test_private_endpoint!(user_notification {
// endpoint: "/api/v1/streaming/user/notification",
// user: Subscription {
// timeline: "1".to_string(),
// id: 1,
// email: "user@example.com".to_string(),
// access_token: "TEST_USER".to_string(),
// langs: None,
// scopes: OauthScope {
// all: true,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: true,
// blocks: Blocks::default(),
// allowed_langs: Filter::Notification,
// },
// });
// test_private_endpoint!(direct {
// endpoint: "/api/v1/streaming/direct",
// user: Subscription {
// timeline: "direct".to_string(),
// id: 1,
// email: "user@example.com".to_string(),
// access_token: "TEST_USER".to_string(),
// langs: None,
// scopes: OauthScope {
// all: true,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: true,
// blocks: Blocks::default(),
// allowed_langs: Filter::NoFilter,
// },
// });
test_private_endpoint!(list_valid_list {
endpoint: "/api/v1/streaming/list",
query: "list=1",
user: User {
target_timeline: "list:1".to_string(),
id: 1,
email: "user@example.com".to_string(),
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
},
});
test_bad_auth_token_in_query!(public_media_true_bad_auth {
endpoint: "/api/v1/streaming/public",
query: "only_media=true",
});
test_bad_auth_token_in_header!(public_media_1_bad_auth {
endpoint: "/api/v1/streaming/public",
query: "only_media=1",
});
test_bad_auth_token_in_query!(public_local_bad_auth_in_query {
endpoint: "/api/v1/streaming/public/local",
});
test_bad_auth_token_in_header!(public_local_bad_auth_in_header {
endpoint: "/api/v1/streaming/public/local",
});
test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query {
endpoint: "/api/v1/streaming/public/local",
query: "only_media=1",
});
test_bad_auth_token_in_header!(public_local_media_timeline_bad_token_in_header {
endpoint: "/api/v1/streaming/public/local",
query: "only_media=true",
});
test_bad_auth_token_in_query!(hashtag_bad_auth_in_query {
endpoint: "/api/v1/streaming/hashtag",
query: "tag=a",
});
test_bad_auth_token_in_header!(hashtag_bad_auth_in_header {
endpoint: "/api/v1/streaming/hashtag",
query: "tag=a",
});
test_bad_auth_token_in_query!(user_bad_auth_in_query {
endpoint: "/api/v1/streaming/user",
});
test_bad_auth_token_in_header!(user_bad_auth_in_header {
endpoint: "/api/v1/streaming/user",
});
test_missing_auth!(user_missing_auth_token {
endpoint: "/api/v1/streaming/user",
});
test_bad_auth_token_in_query!(user_notification_bad_auth_in_query {
endpoint: "/api/v1/streaming/user/notification",
});
test_bad_auth_token_in_header!(user_notification_bad_auth_in_header {
endpoint: "/api/v1/streaming/user/notification",
});
test_missing_auth!(user_notification_missing_auth_token {
endpoint: "/api/v1/streaming/user/notification",
});
test_bad_auth_token_in_query!(direct_bad_auth_in_query {
endpoint: "/api/v1/streaming/direct",
});
test_bad_auth_token_in_header!(direct_bad_auth_in_header {
endpoint: "/api/v1/streaming/direct",
});
test_missing_auth!(direct_missing_auth_token {
endpoint: "/api/v1/streaming/direct",
});
test_bad_auth_token_in_query!(list_bad_auth_in_query {
endpoint: "/api/v1/streaming/list",
query: "list=1",
});
test_bad_auth_token_in_header!(list_bad_auth_in_header {
endpoint: "/api/v1/streaming/list",
query: "list=1",
});
test_missing_auth!(list_missing_auth_token {
endpoint: "/api/v1/streaming/list",
query: "list=1",
});
// test_private_endpoint!(list_valid_list {
// endpoint: "/api/v1/streaming/list",
// query: "list=1",
// user: Subscription {
// timeline: "list:1".to_string(),
// id: 1,
// email: "user@example.com".to_string(),
// access_token: "TEST_USER".to_string(),
// langs: None,
// scopes: OauthScope {
// all: true,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: true,
// blocks: Blocks::default(),
// allowed_langs: Filter::NoFilter,
// },
// });
// test_bad_auth_token_in_query!(public_media_true_bad_auth {
// endpoint: "/api/v1/streaming/public",
// query: "only_media=true",
// });
// test_bad_auth_token_in_header!(public_media_1_bad_auth {
// endpoint: "/api/v1/streaming/public",
// query: "only_media=1",
// });
// test_bad_auth_token_in_query!(public_local_bad_auth_in_query {
// endpoint: "/api/v1/streaming/public/local",
// });
// test_bad_auth_token_in_header!(public_local_bad_auth_in_header {
// endpoint: "/api/v1/streaming/public/local",
// });
// test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query {
// endpoint: "/api/v1/streaming/public/local",
// query: "only_media=1",
// });
// test_bad_auth_token_in_header!(public_local_media_timeline_bad_token_in_header {
// endpoint: "/api/v1/streaming/public/local",
// query: "only_media=true",
// });
// test_bad_auth_token_in_query!(hashtag_bad_auth_in_query {
// endpoint: "/api/v1/streaming/hashtag",
// query: "tag=a",
// });
// test_bad_auth_token_in_header!(hashtag_bad_auth_in_header {
// endpoint: "/api/v1/streaming/hashtag",
// query: "tag=a",
// });
// test_bad_auth_token_in_query!(user_bad_auth_in_query {
// endpoint: "/api/v1/streaming/user",
// });
// test_bad_auth_token_in_header!(user_bad_auth_in_header {
// endpoint: "/api/v1/streaming/user",
// });
// test_missing_auth!(user_missing_auth_token {
// endpoint: "/api/v1/streaming/user",
// });
// test_bad_auth_token_in_query!(user_notification_bad_auth_in_query {
// endpoint: "/api/v1/streaming/user/notification",
// });
// test_bad_auth_token_in_header!(user_notification_bad_auth_in_header {
// endpoint: "/api/v1/streaming/user/notification",
// });
// test_missing_auth!(user_notification_missing_auth_token {
// endpoint: "/api/v1/streaming/user/notification",
// });
// test_bad_auth_token_in_query!(direct_bad_auth_in_query {
// endpoint: "/api/v1/streaming/direct",
// });
// test_bad_auth_token_in_header!(direct_bad_auth_in_header {
// endpoint: "/api/v1/streaming/direct",
// });
// test_missing_auth!(direct_missing_auth_token {
// endpoint: "/api/v1/streaming/direct",
// });
// test_bad_auth_token_in_query!(list_bad_auth_in_query {
// endpoint: "/api/v1/streaming/list",
// query: "list=1",
// });
// test_bad_auth_token_in_header!(list_bad_auth_in_header {
// endpoint: "/api/v1/streaming/list",
// query: "list=1",
// });
// test_missing_auth!(list_missing_auth_token {
// endpoint: "/api/v1/streaming/list",
// query: "list=1",
// });
#[test]
#[should_panic(expected = "NotFound")]
fn nonexistant_endpoint() {
let mock_pg_pool = PgPool::new();
warp::test::request()
.path("/api/v1/streaming/DOES_NOT_EXIST")
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
}
}
// #[test]
// #[should_panic(expected = "NotFound")]
// fn nonexistant_endpoint() {
// let mock_pg_pool = PgPool::new();
// warp::test::request()
// .path("/api/v1/streaming/DOES_NOT_EXIST")
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// }
// }


@@ -1,5 +1,5 @@
//! Mock Postgres connection (for use in unit testing)
use super::{OauthScope, User};
use super::{OauthScope, Subscription};
use std::collections::HashSet;
#[derive(Clone)]
@@ -10,8 +10,11 @@ impl PgPool {
}
}
pub fn select_user(access_token: &str, _pg_pool: PgPool) -> Result<User, warp::reject::Rejection> {
let mut user = User::default();
pub fn select_user(
access_token: &str,
_pg_pool: PgPool,
) -> Result<Subscription, warp::reject::Rejection> {
let mut user = Subscription::default();
if access_token == "TEST_USER" {
user.id = 1;
user.logged_in = true;


@@ -1,144 +1,195 @@
//! `User` struct and related functionality
#[cfg(test)]
mod mock_postgres;
#[cfg(test)]
use mock_postgres as postgres;
#[cfg(not(test))]
mod postgres;
// #[cfg(test)]
// mod mock_postgres;
// #[cfg(test)]
// use mock_postgres as postgres;
// #[cfg(not(test))]
pub mod postgres;
pub use self::postgres::PgPool;
use super::query::Query;
use crate::log_fatal;
use std::collections::HashSet;
use warp::reject::Rejection;
/// The filters that can be applied to toots after they come from Redis
#[derive(Clone, Debug, PartialEq)]
pub enum Filter {
NoFilter,
Language,
Notification,
}
impl Default for Filter {
fn default() -> Self {
Filter::Language
}
}
#[derive(Clone, Debug, Default, PartialEq)]
pub struct OauthScope {
pub all: bool,
pub statuses: bool,
pub notify: bool,
pub lists: bool,
}
impl From<Vec<String>> for OauthScope {
fn from(scope_list: Vec<String>) -> Self {
let mut oauth_scope = OauthScope::default();
for scope in scope_list {
match scope.as_str() {
"read" => oauth_scope.all = true,
"read:statuses" => oauth_scope.statuses = true,
"read:notifications" => oauth_scope.notify = true,
"read:lists" => oauth_scope.lists = true,
_ => (),
}
}
oauth_scope
}
}
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Blocks {
pub domain_blocks: HashSet<String>,
pub user_blocks: HashSet<i64>,
}
/// The User (with data read from Postgres)
#[derive(Clone, Debug, PartialEq)]
pub struct User {
pub target_timeline: String,
pub email: String, // We only use email for logging; we could cut it for performance
pub access_token: String, // We only need this once (to send back with the WS reply). Cut?
pub id: i64,
pub scopes: OauthScope,
pub langs: Option<Vec<String>>,
pub logged_in: bool,
pub filter: Filter,
pub struct Subscription {
pub timeline: Timeline,
pub allowed_langs: HashSet<String>,
pub blocks: Blocks,
}
impl Default for User {
impl Default for Subscription {
fn default() -> Self {
Self {
id: -1,
email: "".to_string(),
access_token: "".to_string(),
scopes: OauthScope::default(),
langs: None,
logged_in: false,
target_timeline: String::new(),
filter: Filter::default(),
timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification),
allowed_langs: HashSet::new(),
blocks: Blocks::default(),
}
}
}
impl User {
impl Subscription {
pub fn from_query(q: Query, pool: PgPool) -> Result<Self, Rejection> {
println!("Creating user...");
let mut user: User = match q.access_token.clone() {
None => User::default(),
let user = match q.access_token.clone() {
Some(token) => postgres::select_user(&token, pool.clone())?,
None => UserData::public(),
};
user = user.set_timeline_and_filter(q, pool.clone())?;
user.blocks.user_blocks = postgres::select_user_blocks(user.id, pool.clone());
user.blocks.domain_blocks = postgres::select_domain_blocks(pool.clone());
dbg!(&user);
Ok(user)
}
fn set_timeline_and_filter(mut self, q: Query, pool: PgPool) -> Result<Self, Rejection> {
let read_scope = self.scopes.clone();
let timeline = match q.stream.as_ref() {
// Public endpoints:
tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
tl @ "public" | tl @ "public:local" => tl.to_string(),
// Hashtag endpoints:
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
// Private endpoints: User:
"user" if self.logged_in && (read_scope.all || read_scope.statuses) => {
self.filter = Filter::NoFilter;
format!("{}", self.id)
}
"user:notification" if self.logged_in && (read_scope.all || read_scope.notify) => {
self.filter = Filter::Notification;
format!("{}", self.id)
}
// List endpoint:
"list" if self.owns_list(q.list, pool) && (read_scope.all || read_scope.lists) => {
self.filter = Filter::NoFilter;
format!("list:{}", q.list)
}
// Direct endpoint:
"direct" if self.logged_in && (read_scope.all || read_scope.statuses) => {
self.filter = Filter::NoFilter;
"direct".to_string()
}
// Reject unathorized access attempts for private endpoints
"user" | "user:notification" | "direct" | "list" => {
return Err(warp::reject::custom("Error: Missing access token"))
}
// Other endpoints don't exist:
_ => return Err(warp::reject::custom("Error: Nonexistent endpoint")),
};
Ok(Self {
target_timeline: timeline,
..self
Ok(Subscription {
timeline: Timeline::from_query_and_user(&q, &user, pool.clone())?,
allowed_langs: user.allowed_langs,
blocks: Blocks {
blocking_users: postgres::select_blocking_users(user.id, pool.clone()),
blocked_users: postgres::select_blocked_users(user.id, pool.clone()),
blocked_domains: postgres::select_blocked_domains(user.id, pool.clone()),
},
})
}
}
fn owns_list(&self, list: i64, pool: PgPool) -> bool {
postgres::user_owns_list(self.id, list, pool)
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub struct Timeline(pub Stream, pub Reach, pub Content);
impl Timeline {
pub fn empty() -> Self {
use {Content::*, Reach::*, Stream::*};
Self(Unset, Local, Notification)
}
pub fn to_redis_str(&self, hashtag: Option<&String>) -> String {
use {Content::*, Reach::*, Stream::*};
match self {
Timeline(Public, Federated, All) => "timeline:public".into(),
Timeline(Public, Local, All) => "timeline:public:local".into(),
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
Timeline(Hashtag(id), Federated, All) => format!(
"timeline:hashtag:{}",
hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
),
Timeline(Hashtag(id), Local, All) => format!(
"timeline:hashtag:{}:local",
hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
),
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
Timeline(one, _two, _three) => {
log_fatal!("Supposedly impossible timeline reached: {:?}", one)
}
}
}
pub fn from_redis_str(raw_timeline: &str, hashtag: Option<i64>) -> Self {
use {Content::*, Reach::*, Stream::*};
match raw_timeline.split(':').collect::<Vec<&str>>()[..] {
["public"] => Timeline(Public, Federated, All),
["public", "local"] => Timeline(Public, Local, All),
["public", "media"] => Timeline(Public, Federated, Media),
["public", "local", "media"] => Timeline(Public, Local, Media),
["hashtag", _tag] => Timeline(Hashtag(hashtag.unwrap()), Federated, All),
["hashtag", _tag, "local"] => Timeline(Hashtag(hashtag.unwrap()), Local, All),
[id] => Timeline(User(id.parse().unwrap()), Federated, All),
[id, "notification"] => Timeline(User(id.parse().unwrap()), Federated, Notification),
["list", id] => Timeline(List(id.parse().unwrap()), Federated, All),
["direct", id] => Timeline(Direct(id.parse().unwrap()), Federated, All),
// Other endpoints don't exist:
[..] => log_fatal!("Unexpected channel from Redis: {}", raw_timeline),
}
}
fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result<Self, Rejection> {
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
let id_from_hashtag = || postgres::select_list_id(&q.hashtag, pool.clone());
let user_owns_list = || postgres::user_owns_list(user.id, q.list, pool.clone());
Ok(match q.stream.as_ref() {
"public" => match q.media {
true => Timeline(Public, Federated, Media),
false => Timeline(Public, Federated, All),
},
"public:local" => match q.media {
true => Timeline(Public, Local, Media),
false => Timeline(Public, Local, All),
},
"public:media" => Timeline(Public, Federated, Media),
"public:local:media" => Timeline(Public, Local, Media),
"hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All),
"hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All),
"user" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
"user:notification" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, Notification),
false => Err(custom("Error: Missing access token"))?,
},
"list" => match user.scopes.contains(&Lists) && user_owns_list() {
true => Timeline(List(q.list), Federated, All),
false => Err(warp::reject::custom("Error: Missing access token"))?,
},
"direct" => match user.scopes.contains(&Statuses) {
true => Timeline(Direct(user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
other => {
log::warn!("Client attempted to subscribe to: `{}`", other);
Err(custom("Error: Nonexistent endpoint"))?
}
})
}
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Stream {
User(i64),
List(i64),
Direct(i64),
Hashtag(i64),
Public,
Unset,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Reach {
Local,
Federated,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Content {
All,
Media,
Notification,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
Read,
Statuses,
Notifications,
Lists,
}
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Blocks {
pub blocked_domains: HashSet<String>,
pub blocked_users: HashSet<i64>,
pub blocking_users: HashSet<i64>,
}
#[derive(Clone, Debug, PartialEq)]
pub struct UserData {
id: i64,
allowed_langs: HashSet<String>,
scopes: HashSet<Scope>,
}
impl UserData {
fn public() -> Self {
Self {
id: -1,
allowed_langs: HashSet::new(),
scopes: HashSet::new(),
}
}
}
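
A hedged usage sketch of the new `Timeline` type (editorial; it assumes the
caller strips the leading "timeline:" prefix before calling `from_redis_str`,
as the match arms above suggest):

fn timeline_round_trip() {
    use {Content::*, Reach::*, Stream::*};
    let tl = Timeline(Public, Local, Media);
    // Converting to a Redis channel name prepends the "timeline:" prefix
    assert_eq!(tl.to_redis_str(None), "timeline:public:local:media");
    // Parsing the prefix-stripped channel name recovers the same Timeline
    assert_eq!(Timeline::from_redis_str("public:local:media", None), tl);
}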


@@ -1,14 +1,14 @@
//! Postgres queries
use crate::{
config,
parse_client_request::user::{OauthScope, User},
parse_client_request::user::{Scope, UserData},
};
use ::postgres;
use r2d2_postgres::PostgresConnectionManager;
use std::collections::HashSet;
use warp::reject::Rejection;
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct PgPool(pub r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>);
impl PgPool {
pub fn new(pg_cfg: config::PostgresConfig) -> Self {
@@ -30,16 +30,12 @@ impl PgPool {
}
}
/// Build a user based on the result of querying Postgres with the access token
///
/// This does _not_ set the timeline, filter, or blocks fields. Use the various `User`
/// methods to do so. In general, this function shouldn't be needed outside `User`.
pub fn select_user(access_token: &str, pg_pool: PgPool) -> Result<User, Rejection> {
let mut conn = pg_pool.0.get().unwrap();
let query_result = conn
pub fn select_user(token: &str, pool: PgPool) -> Result<UserData, Rejection> {
let mut conn = pool.0.get().unwrap();
let query_rows = conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.email, users.chosen_languages, oauth_access_tokens.scopes
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
@@ -47,27 +43,84 @@ oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&access_token.to_owned()],
&[&token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if query_result.is_empty() {
Err(warp::reject::custom("Error: Invalid access token"))
} else {
let only_row: &postgres::Row = query_result.get(0).unwrap();
let scope_vec: Vec<String> = only_row
.get::<_, String>(4)
.split(' ')
.map(|s| s.to_owned())
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) {
let id = result_columns.get(1);
let allowed_langs = result_columns
.try_get::<_, Vec<_>>(2)
.unwrap_or_else(|_| Vec::new())
.into_iter()
.collect();
Ok(User {
id: only_row.get(1),
access_token: access_token.to_string(),
email: only_row.get(2),
logged_in: true,
scopes: OauthScope::from(scope_vec),
langs: only_row.get(3),
..User::default()
let mut scopes: HashSet<Scope> = result_columns
.get::<_, String>(3)
.split(' ')
.filter_map(|scope| match scope {
"read" => Some(Scope::Read),
"read:statuses" => Some(Scope::Statuses),
"read:notifications" => Some(Scope::Notifications),
"read:lists" => Some(Scope::Lists),
"write" | "follow" => None, // ignore write scopes
unexpected => {
log::warn!("Ignoring unknown scope `{}`", unexpected);
None
}
})
.collect();
// We don't need to separately track read auth - it's just all three others
if scopes.remove(&Scope::Read) {
scopes.insert(Scope::Statuses);
scopes.insert(Scope::Notifications);
scopes.insert(Scope::Lists);
}
Ok(UserData {
id,
allowed_langs,
scopes,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
pub fn select_list_id(tag_name: &String, pg_pool: PgPool) -> Result<i64, Rejection> {
let mut conn = pg_pool.0.get().unwrap();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT id
FROM tags
WHERE name = $1
LIMIT 1",
&[&tag_name],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
Some(row) => Ok(row.get(0)),
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
}
}
pub fn select_hashtag_name(tag_id: &i64, pg_pool: PgPool) -> Result<String, Rejection> {
let mut conn = pg_pool.0.get().unwrap();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT name
FROM tags
WHERE id = $1
LIMIT 1",
&[&tag_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
Some(row) => Ok(row.get(0)),
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
}
}
@@ -75,7 +128,18 @@ LIMIT 1",
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_user_blocks(user_id: i64, pg_pool: PgPool) -> HashSet<i64> {
pub fn select_blocked_users(user_id: i64, pg_pool: PgPool) -> HashSet<i64> {
// "
// SELECT
// 1
// FROM blocks
// WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
// OR (account_id = $2 AND target_account_id = $1)
// UNION SELECT
// 1
// FROM mutes
// WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
// , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
pg_pool
.0
.get()
@@ -95,17 +159,41 @@ UNION SELECT target_account_id
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for everyone who has blocked the user
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(user_id: i64, pg_pool: PgPool) -> HashSet<i64> {
pg_pool
.0
.get()
.unwrap()
.query(
"
SELECT account_id
FROM blocks
WHERE target_account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for all current domain blocks
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_domain_blocks(pg_pool: PgPool) -> HashSet<String> {
pub fn select_blocked_domains(user_id: i64, pg_pool: PgPool) -> HashSet<String> {
pg_pool
.0
.get()
.unwrap()
.query("SELECT domain FROM account_domain_blocks", &[])
.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))


@@ -1,7 +1,7 @@
//! Filters for the WebSocket endpoint
use super::{
query::{self, Query},
user::{PgPool, User},
user::{PgPool, Subscription},
};
use warp::{filters::BoxedFilter, path, Filter};
@@ -32,316 +32,319 @@ fn parse_query() -> BoxedFilter<(Query,)> {
.boxed()
}
pub fn extract_user_or_reject(pg_pool: PgPool) -> BoxedFilter<(User,)> {
pub fn extract_user_and_token_or_reject(
pg_pool: PgPool,
) -> BoxedFilter<(Subscription, Option<String>)> {
parse_query()
.and(query::OptionalAccessToken::from_ws_header())
.and_then(Query::update_access_token)
.and_then(move |q| User::from_query(q, pg_pool.clone()))
.and_then(move |q| Subscription::from_query(q, pg_pool.clone()))
.and(query::OptionalAccessToken::from_ws_header())
.boxed()
}
#[cfg(test)]
mod test {
use super::*;
use crate::parse_client_request::user::{Blocks, Filter, OauthScope};
// #[cfg(test)]
// mod test {
// use super::*;
// use crate::parse_client_request::user::{Blocks, Filter, OauthScope};
macro_rules! test_public_endpoint {
($name:ident {
endpoint: $path:expr,
user: $user:expr,
}) => {
#[test]
fn $name() {
let mock_pg_pool = PgPool::new();
let user = warp::test::request()
.path($path)
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
assert_eq!(user, $user);
}
};
}
macro_rules! test_private_endpoint {
($name:ident {
endpoint: $path:expr,
user: $user:expr,
}) => {
#[test]
fn $name() {
let mock_pg_pool = PgPool::new();
let path = format!("{}&access_token=TEST_USER", $path);
let user = warp::test::request()
.path(&path)
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
assert_eq!(user, $user);
}
};
}
macro_rules! test_bad_auth_token_in_query {
($name: ident {
endpoint: $path:expr,
// macro_rules! test_public_endpoint {
// ($name:ident {
// endpoint: $path:expr,
// user: $user:expr,
// }) => {
// #[test]
// fn $name() {
// let mock_pg_pool = PgPool::new();
// let user = warp::test::request()
// .path($path)
// .header("connection", "upgrade")
// .header("upgrade", "websocket")
// .header("sec-websocket-version", "13")
// .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// assert_eq!(user, $user);
// }
// };
// }
// macro_rules! test_private_endpoint {
// ($name:ident {
// endpoint: $path:expr,
// user: $user:expr,
// }) => {
// #[test]
// fn $name() {
// let mock_pg_pool = PgPool::new();
// let path = format!("{}&access_token=TEST_USER", $path);
// let user = warp::test::request()
// .path(&path)
// .header("connection", "upgrade")
// .header("upgrade", "websocket")
// .header("sec-websocket-version", "13")
// .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// assert_eq!(user, $user);
// }
// };
// }
// macro_rules! test_bad_auth_token_in_query {
// ($name: ident {
// endpoint: $path:expr,
}) => {
#[test]
#[should_panic(expected = "Error: Invalid access token")]
// }) => {
// #[test]
// #[should_panic(expected = "Error: Invalid access token")]
fn $name() {
let path = format!("{}&access_token=INVALID", $path);
let mock_pg_pool = PgPool::new();
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
}
};
}
macro_rules! test_missing_auth {
($name: ident {
endpoint: $path:expr,
}) => {
#[test]
#[should_panic(expected = "Error: Missing access token")]
fn $name() {
let path = $path;
let mock_pg_pool = PgPool::new();
warp::test::request()
.path(&path)
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
}
};
}
// fn $name() {
// let path = format!("{}&access_token=INVALID", $path);
// let mock_pg_pool = PgPool::new();
// warp::test::request()
// .path(&path)
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// }
// };
// }
// macro_rules! test_missing_auth {
// ($name: ident {
// endpoint: $path:expr,
// }) => {
// #[test]
// #[should_panic(expected = "Error: Missing access token")]
// fn $name() {
// let path = $path;
// let mock_pg_pool = PgPool::new();
// warp::test::request()
// .path(&path)
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// }
// };
// }
test_public_endpoint!(public_media {
endpoint: "/api/v1/streaming?stream=public:media",
user: User {
target_timeline: "public:media".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(public_local {
endpoint: "/api/v1/streaming?stream=public:local",
user: User {
target_timeline: "public:local".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(public_local_media {
endpoint: "/api/v1/streaming?stream=public:local:media",
user: User {
target_timeline: "public:local:media".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(hashtag {
endpoint: "/api/v1/streaming?stream=hashtag&tag=a",
user: User {
target_timeline: "hashtag:a".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
test_public_endpoint!(hashtag_local {
endpoint: "/api/v1/streaming?stream=hashtag:local&tag=a",
user: User {
target_timeline: "hashtag:local:a".to_string(),
id: -1,
email: "".to_string(),
access_token: "".to_string(),
langs: None,
scopes: OauthScope {
all: false,
statuses: false,
notify: false,
lists: false,
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
},
});
// test_public_endpoint!(public_media {
// endpoint: "/api/v1/streaming?stream=public:media",
// user: Subscription {
// timeline: "public:media".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(public_local {
// endpoint: "/api/v1/streaming?stream=public:local",
// user: Subscription {
// timeline: "public:local".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(public_local_media {
// endpoint: "/api/v1/streaming?stream=public:local:media",
// user: Subscription {
// timeline: "public:local:media".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(hashtag {
// endpoint: "/api/v1/streaming?stream=hashtag&tag=a",
// user: Subscription {
// timeline: "hashtag:a".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
// test_public_endpoint!(hashtag_local {
// endpoint: "/api/v1/streaming?stream=hashtag:local&tag=a",
// user: Subscription {
// timeline: "hashtag:local:a".to_string(),
// id: -1,
// email: "".to_string(),
// access_token: "".to_string(),
// langs: None,
// scopes: OauthScope {
// all: false,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: false,
// blocks: Blocks::default(),
// allowed_langs: Filter::Language,
// },
// });
test_private_endpoint!(user {
endpoint: "/api/v1/streaming?stream=user",
user: User {
target_timeline: "1".to_string(),
id: 1,
email: "user@example.com".to_string(),
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
},
});
test_private_endpoint!(user_notification {
endpoint: "/api/v1/streaming?stream=user:notification",
user: User {
target_timeline: "1".to_string(),
id: 1,
email: "user@example.com".to_string(),
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::Notification,
},
});
test_private_endpoint!(direct {
endpoint: "/api/v1/streaming?stream=direct",
user: User {
target_timeline: "direct".to_string(),
id: 1,
email: "user@example.com".to_string(),
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
},
});
test_private_endpoint!(list_valid_list {
endpoint: "/api/v1/streaming?stream=list&list=1",
user: User {
target_timeline: "list:1".to_string(),
id: 1,
email: "user@example.com".to_string(),
access_token: "TEST_USER".to_string(),
langs: None,
scopes: OauthScope {
all: true,
statuses: false,
notify: false,
lists: false,
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
},
});
// test_private_endpoint!(user {
// endpoint: "/api/v1/streaming?stream=user",
// user: Subscription {
// timeline: "1".to_string(),
// id: 1,
// email: "user@example.com".to_string(),
// access_token: "TEST_USER".to_string(),
// langs: None,
// scopes: OauthScope {
// all: true,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: true,
// blocks: Blocks::default(),
// allowed_langs: Filter::NoFilter,
// },
// });
// test_private_endpoint!(user_notification {
// endpoint: "/api/v1/streaming?stream=user:notification",
// user: Subscription {
// timeline: "1".to_string(),
// id: 1,
// email: "user@example.com".to_string(),
// access_token: "TEST_USER".to_string(),
// langs: None,
// scopes: OauthScope {
// all: true,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: true,
// blocks: Blocks::default(),
// allowed_langs: Filter::Notification,
// },
// });
// test_private_endpoint!(direct {
// endpoint: "/api/v1/streaming?stream=direct",
// user: Subscription {
// timeline: "direct".to_string(),
// id: 1,
// email: "user@example.com".to_string(),
// access_token: "TEST_USER".to_string(),
// langs: None,
// scopes: OauthScope {
// all: true,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: true,
// blocks: Blocks::default(),
// allowed_langs: Filter::NoFilter,
// },
// });
// test_private_endpoint!(list_valid_list {
// endpoint: "/api/v1/streaming?stream=list&list=1",
// user: Subscription {
// timeline: "list:1".to_string(),
// id: 1,
// email: "user@example.com".to_string(),
// access_token: "TEST_USER".to_string(),
// langs: None,
// scopes: OauthScope {
// all: true,
// statuses: false,
// notify: false,
// lists: false,
// },
// logged_in: true,
// blocks: Blocks::default(),
// allowed_langs: Filter::NoFilter,
// },
// });
test_bad_auth_token_in_query!(public_media_true_bad_auth {
endpoint: "/api/v1/streaming?stream=public:media",
});
test_bad_auth_token_in_query!(public_local_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=public:local",
});
test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=public:local:media",
});
test_bad_auth_token_in_query!(hashtag_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=hashtag&tag=a",
});
test_bad_auth_token_in_query!(user_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=user",
});
test_missing_auth!(user_missing_auth_token {
endpoint: "/api/v1/streaming?stream=user",
});
test_bad_auth_token_in_query!(user_notification_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=user:notification",
});
test_missing_auth!(user_notification_missing_auth_token {
endpoint: "/api/v1/streaming?stream=user:notification",
});
test_bad_auth_token_in_query!(direct_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=direct",
});
test_missing_auth!(direct_missing_auth_token {
endpoint: "/api/v1/streaming?stream=direct",
});
test_bad_auth_token_in_query!(list_bad_auth_in_query {
endpoint: "/api/v1/streaming?stream=list&list=1",
});
test_missing_auth!(list_missing_auth_token {
endpoint: "/api/v1/streaming?stream=list&list=1",
});
// test_bad_auth_token_in_query!(public_media_true_bad_auth {
// endpoint: "/api/v1/streaming?stream=public:media",
// });
// test_bad_auth_token_in_query!(public_local_bad_auth_in_query {
// endpoint: "/api/v1/streaming?stream=public:local",
// });
// test_bad_auth_token_in_query!(public_local_media_timeline_bad_auth_in_query {
// endpoint: "/api/v1/streaming?stream=public:local:media",
// });
// test_bad_auth_token_in_query!(hashtag_bad_auth_in_query {
// endpoint: "/api/v1/streaming?stream=hashtag&tag=a",
// });
// test_bad_auth_token_in_query!(user_bad_auth_in_query {
// endpoint: "/api/v1/streaming?stream=user",
// });
// test_missing_auth!(user_missing_auth_token {
// endpoint: "/api/v1/streaming?stream=user",
// });
// test_bad_auth_token_in_query!(user_notification_bad_auth_in_query {
// endpoint: "/api/v1/streaming?stream=user:notification",
// });
// test_missing_auth!(user_notification_missing_auth_token {
// endpoint: "/api/v1/streaming?stream=user:notification",
// });
// test_bad_auth_token_in_query!(direct_bad_auth_in_query {
// endpoint: "/api/v1/streaming?stream=direct",
// });
// test_missing_auth!(direct_missing_auth_token {
// endpoint: "/api/v1/streaming?stream=direct",
// });
// test_bad_auth_token_in_query!(list_bad_auth_in_query {
// endpoint: "/api/v1/streaming?stream=list&list=1",
// });
// test_missing_auth!(list_missing_auth_token {
// endpoint: "/api/v1/streaming?stream=list&list=1",
// });
#[test]
#[should_panic(expected = "NotFound")]
fn nonexistant_endpoint() {
let mock_pg_pool = PgPool::new();
warp::test::request()
.path("/api/v1/streaming/DOES_NOT_EXIST")
.header("connection", "upgrade")
.header("upgrade", "websocket")
.header("sec-websocket-version", "13")
.header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
.filter(&extract_user_or_reject(mock_pg_pool))
.expect("in test");
}
}
// #[test]
// #[should_panic(expected = "NotFound")]
// fn nonexistant_endpoint() {
// let mock_pg_pool = PgPool::new();
// warp::test::request()
// .path("/api/v1/streaming/DOES_NOT_EXIST")
// .header("connection", "upgrade")
// .header("upgrade", "websocket")
// .header("sec-websocket-version", "13")
// .header("sec-websocket-key", "dGhlIHNhbXBsZSBub25jZQ==")
// .filter(&extract_user_or_reject(mock_pg_pool))
// .expect("in test");
// }
// }


@ -14,12 +14,18 @@
//!
//! Because `StreamManagers` are lightweight data structures that do not directly
//! communicate with Redis, we create a new `ClientAgent` for
//! each new client connection (each in its own thread).
use super::receiver::Receiver;
use crate::{config, parse_client_request::user::User};
use futures::{Async, Poll};
use serde_json::Value;
use std::{collections::HashSet, sync};
use super::{message::Message, receiver::Receiver};
use crate::{
config,
parse_client_request::user::{PgPool, Subscription},
};
use futures::{
Async::{self, NotReady, Ready},
Poll,
};
use std::sync;
use tokio::io::Error;
use uuid::Uuid;
@ -28,18 +34,17 @@ use uuid::Uuid;
pub struct ClientAgent {
receiver: sync::Arc<sync::Mutex<Receiver>>,
id: uuid::Uuid,
pub target_timeline: String,
pub current_user: User,
// pub current_timeline: String,
subscription: Subscription,
}
impl ClientAgent {
/// Create a new `ClientAgent` with no shared data.
pub fn blank(redis_cfg: config::RedisConfig) -> Self {
pub fn blank(redis_cfg: config::RedisConfig, pg_pool: PgPool) -> Self {
ClientAgent {
receiver: sync::Arc::new(sync::Mutex::new(Receiver::new(redis_cfg))),
receiver: sync::Arc::new(sync::Mutex::new(Receiver::new(redis_cfg, pg_pool))),
id: Uuid::default(),
target_timeline: String::new(),
current_user: User::default(),
subscription: Subscription::default(),
}
}
@ -48,30 +53,29 @@ impl ClientAgent {
Self {
receiver: self.receiver.clone(),
id: self.id,
target_timeline: self.target_timeline.clone(),
current_user: self.current_user.clone(),
subscription: self.subscription.clone(),
}
}
/// Initializes the `ClientAgent` with a unique ID, a `User`, and the target timeline.
/// Also passes values to the `Receiver` for its initialization.
/// Initializes the `ClientAgent` with a unique ID associated with a specific user's
/// subscription. Also passes values to the `Receiver` for its initialization.
///
/// Note that this *may or may not* result in a new Redis connection.
/// If the server has already subscribed to the timeline on behalf of
/// a different user, the `Receiver` is responsible for figuring
/// that out and avoiding duplicated connections. Thus, it is safe to
/// use this method for each new client connection.
pub fn init_for_user(&mut self, user: User) {
pub fn init_for_user(&mut self, subscription: Subscription) {
self.id = Uuid::new_v4();
self.target_timeline = user.target_timeline.to_owned();
self.current_user = user;
self.subscription = subscription;
let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)");
receiver.manage_new_timeline(self.id, &self.target_timeline);
receiver.manage_new_timeline(self.id, self.subscription.timeline);
}
}
/// The stream that the `ClientAgent` manages. `Poll` is the only method implemented.
impl futures::stream::Stream for ClientAgent {
type Item = Toot;
type Item = Message;
type Error = Error;
/// Checks for any new messages that should be sent to the client.
@ -89,126 +93,34 @@ impl futures::stream::Stream for ClientAgent {
.receiver
.lock()
.expect("ClientAgent: No other thread panic");
receiver.configure_for_polling(self.id, &self.target_timeline.clone());
receiver.configure_for_polling(self.id, self.subscription.timeline);
receiver.poll()
};
if start_time.elapsed().as_millis() > 1 {
log::warn!("Polling the Receiver took: {:?}", start_time.elapsed());
};
let allowed_langs = &self.subscription.allowed_langs;
let blocked_users = &self.subscription.blocks.blocked_users;
let blocking_users = &self.subscription.blocks.blocking_users;
let blocked_domains = &self.subscription.blocks.blocked_domains;
let (send, block) = (|msg| Ok(Ready(Some(msg))), Ok(NotReady));
use Message::*;
match result {
Ok(Async::Ready(Some(value))) => {
let user = &self.current_user;
let toot = Toot::from_json(value);
toot.filter(&user)
}
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Ok(Async::Ready(Some(json))) => match Message::from_json(json) {
Update(status) if status.language_not_allowed(allowed_langs) => block,
Update(status) if status.involves_blocked_user(blocked_users) => block,
Update(status) if status.from_blocked_domain(blocked_domains) => block,
Update(status) if status.from_blocking_user(blocking_users) => block,
Update(status) => send(Update(status)),
Notification(notification) => send(Notification(notification)),
Conversation(notification) => send(Conversation(notification)),
Delete(status_id) => send(Delete(status_id)),
FiltersChanged => send(FiltersChanged),
},
Ok(Ready(None)) => Ok(Ready(None)),
Ok(NotReady) => Ok(NotReady),
Err(e) => Err(e),
}
}
}
/// The message to send to the client (which might not literally be a toot in some cases).
#[derive(Debug, Clone)]
pub struct Toot {
pub category: String,
pub payload: Value,
pub language: Option<String>,
}
impl Toot {
/// Construct a `Toot` from well-formed JSON.
pub fn from_json(value: Value) -> Self {
let category = value["event"].as_str().expect("Redis string").to_owned();
let language = if category == "update" {
Some(value["payload"]["language"].to_string())
} else {
None
};
Self {
category,
payload: value["payload"].clone(),
language,
}
}
pub fn get_originating_domain(&self) -> HashSet<String> {
let api = "originating Invariant Violation: JSON value does not conform to Mastdon API";
let mut originating_domain = HashSet::new();
originating_domain.insert(
self.payload["account"]["acct"]
.as_str()
.expect(&api)
.split("@")
.nth(1)
.expect(&api)
.to_string(),
);
originating_domain
}
pub fn get_involved_users(&self) -> HashSet<i64> {
let mut involved_users: HashSet<i64> = HashSet::new();
let msg = self.payload.clone();
let api = "Invariant Violation: JSON value does not conform to Mastdon API";
involved_users.insert(msg["account"]["id"].str_to_i64().expect(&api));
if let Some(mentions) = msg["mentions"].as_array() {
for mention in mentions {
involved_users.insert(mention["id"].str_to_i64().expect(&api));
}
}
if let Some(replied_to_account) = msg["in_reply_to_account_id"].as_str() {
involved_users.insert(replied_to_account.parse().expect(&api));
}
if let Some(reblog) = msg["reblog"].as_object() {
involved_users.insert(reblog["account"]["id"].str_to_i64().expect(&api));
}
involved_users
}
/// Filter out any `Toot`s that fail the provided filter.
pub fn filter(self, user: &User) -> Result<Async<Option<Self>>, Error> {
let toot = self;
let category = toot.category.clone();
let toot_language = &toot.language.clone().expect("Valid language");
let (send_msg, skip_msg) = (Ok(Async::Ready(Some(toot))), Ok(Async::NotReady));
if category == "update" {
use crate::parse_client_request::user::Filter;
match &user.filter {
Filter::NoFilter => send_msg,
Filter::Notification if category == "notification" => send_msg,
// If not, skip it
Filter::Notification => skip_msg,
Filter::Language if user.langs.is_none() => send_msg,
Filter::Language if user.langs.clone().expect("").contains(toot_language) => {
send_msg
}
// If not, skip it
Filter::Language => skip_msg,
}
} else {
send_msg
}
}
}
trait ConvertValue {
fn str_to_i64(&self) -> Result<i64, Box<dyn std::error::Error>>;
}
impl ConvertValue for Value {
fn str_to_i64(&self) -> Result<i64, Box<dyn std::error::Error>> {
Ok(self
.as_str()
.ok_or(format!("{} is not a string", &self))?
.parse()
.map_err(|_| "Could not parse str")?)
}
}
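The refactored `poll` replaces the old `Toot::filter` above with guard clauses over the `Message` enum: each `Update` arm rejects a status for one reason, and anything that survives the guards is forwarded unchanged. A standalone sketch of that dispatch pattern, with simplified stand-ins for the crate's types (all names here are illustrative):

// Minimal sketch of guard-clause filtering over an event enum.
enum Msg {
    Update { lang: String },
    Delete(String),
}

fn filter(msg: Msg, allowed_langs: &[&str]) -> Option<Msg> {
    match msg {
        // Reject an update whose language is not on the allow-list
        // (an empty list allows everything, as in the code above).
        Msg::Update { ref lang }
            if !allowed_langs.is_empty() && !allowed_langs.contains(&lang.as_str()) =>
        {
            None
        }
        // Updates that pass, deletes, etc. are forwarded unchanged.
        other => Some(other),
    }
}
// e.g. filter(Msg::Update { lang: "en".into() }, &["de"]) drops the update.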


@ -0,0 +1,167 @@
use crate::log_fatal;
use log::{log_enabled, Level};
use serde_json::Value;
use std::{collections::HashSet, string::String};
use strum_macros::Display;
#[derive(Debug, Display, Clone)]
pub enum Message {
Update(Status),
Conversation(Value),
Notification(Value),
Delete(String),
FiltersChanged,
}
#[derive(Debug, Clone)]
pub struct Status(Value);
impl Message {
pub fn from_json(json: Value) -> Self {
let event = json["event"]
.as_str()
.unwrap_or_else(|| log_fatal!("Could not process `event` in {:?}", json));
match event {
"update" => Self::Update(Status(json["payload"].clone())),
"conversation" => Self::Conversation(json["payload"].clone()),
"notification" => Self::Notification(json["payload"].clone()),
"delete" => Self::Delete(json["payload"].to_string()),
"filters_changed" => Self::FiltersChanged,
unsupported_event => log_fatal!(
"Received an unsupported `event` type from Redis: {}",
unsupported_event
),
}
}
pub fn event(&self) -> String {
format!("{}", self).to_lowercase()
}
pub fn payload(&self) -> String {
match self {
Self::Delete(id) => id.clone(),
Self::Update(status) => status.0.to_string(),
Self::Conversation(value) | Self::Notification(value) => value.to_string(),
Self::FiltersChanged => "".to_string(),
}
}
}
impl Status {
/// Returns `true` if the status is filtered out based on its language
pub fn language_not_allowed(&self, allowed_langs: &HashSet<String>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
let reject_and_maybe_log = |toot_language| {
if log_enabled!(Level::Info) {
log::info!(
"Language `{toot_language}` is not in list `{allowed_langs:?}`",
toot_language = toot_language,
allowed_langs = allowed_langs
);
log::info!("Filtering out toot from `{}`", &self.0["account"]["acct"],);
}
REJECT
};
if allowed_langs.is_empty() {
return ALLOW; // listing no allowed_langs results in allowing all languages
}
match self.0["language"].as_str() {
Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW,
Some(toot_language) => reject_and_maybe_log(toot_language),
None => ALLOW, // If toot language is null, toot is always allowed
}
}
/// Returns `true` if this toot originated from a domain the User has blocked.
pub fn from_blocked_domain(&self, blocked_domains: &HashSet<String>) -> bool {
let full_username = self.0["account"]["acct"]
.as_str()
.unwrap_or_else(|| log_fatal!("Could not process `account.acct` in {:?}", self.0));
match full_username.split('@').nth(1) {
Some(originating_domain) => blocked_domains.contains(originating_domain),
None => false, // None means the user is on the local instance, which can't be blocked
}
}
/// Returns `true` if the Status is from an account that has blocked the current user.
pub fn from_blocking_user(&self, blocking_users: &HashSet<i64>) -> bool {
let toot = self.0.clone();
const ALLOW: bool = false;
const REJECT: bool = true;
let author = toot["account"]["id"]
.str_to_i64()
.unwrap_or_else(|_| log_fatal!("Could not process `account.id` in {:?}", toot));
if blocking_users.contains(&author) {
REJECT
} else {
ALLOW
}
}
/// Returns `true` if the User's list of blocked and muted users includes a user
/// involved in this toot.
///
/// A user is involved if they:
/// * Wrote this toot
/// * Are mentioned in this toot
/// * Wrote a toot that this toot is replying to (if any)
/// * Wrote the toot that this toot is boosting (if any)
pub fn involves_blocked_user(&self, blocked_users: &HashSet<i64>) -> bool {
let toot = self.0.clone();
const ALLOW: bool = false;
const REJECT: bool = true;
let author_user = match toot["account"]["id"].str_to_i64() {
Ok(user_id) => vec![user_id].into_iter(),
Err(_) => log_fatal!("Could not process `account.id` in {:?}", toot),
};
let mentioned_users = (match &toot["mentions"] {
Value::Array(inner) => inner,
_ => log_fatal!("Could not process `mentions` in {:?}", toot),
})
.into_iter()
.map(|mention| match mention["id"].str_to_i64() {
Ok(user_id) => user_id,
Err(_) => log_fatal!("Could not process `id` field of mention in {:?}", toot),
});
let replied_to_user = match toot["in_reply_to_account_id"].str_to_i64() {
Ok(user_id) => vec![user_id].into_iter(),
Err(_) => vec![].into_iter(), // no error; just no replied_to_user
};
let boosted_user = match toot["reblog"].as_object() {
Some(boosted_user) => match boosted_user["account"]["id"].str_to_i64() {
Ok(user_id) => vec![user_id].into_iter(),
Err(_) => log_fatal!("Could not process `reblog.account.id` in {:?}", toot),
},
None => vec![].into_iter(), // no error; just no boosted_user
};
let involved_users = author_user
.chain(mentioned_users)
.chain(replied_to_user)
.chain(boosted_user)
.collect::<HashSet<i64>>();
if involved_users.is_disjoint(blocked_users) {
ALLOW
} else {
REJECT
}
}
}
trait ConvertValue {
fn str_to_i64(&self) -> Result<i64, Box<dyn std::error::Error>>;
}
impl ConvertValue for Value {
fn str_to_i64(&self) -> Result<i64, Box<dyn std::error::Error>> {
Ok(self.as_str().ok_or("none_err")?.parse()?)
}
}
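Taken together, `from_json`, `event`, and `payload` turn one raw Redis value into the two strings the streaming layer sends on. A hedged usage sketch (the JSON body is illustrative, not a full Mastodon payload):

use serde_json::json;

let raw = json!({
    "event": "update",
    "payload": { "language": "en", "content": "<p>hi</p>" }
});
let msg = Message::from_json(raw);
assert_eq!(msg.event(), "update"); // strum's Display impl, lowercased
assert!(msg.payload().contains(r#""language":"en""#)); // payload re-serialized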


@ -1,5 +1,6 @@
//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis.
pub mod client_agent;
pub mod message;
pub mod receiver;
pub mod redis;
@ -17,9 +18,9 @@ pub fn send_updates_to_sse(
) -> impl warp::reply::Reply {
let event_stream = tokio::timer::Interval::new(time::Instant::now(), update_interval)
.filter_map(move |_| match client_agent.poll() {
Ok(Async::Ready(Some(toot))) => Some((
warp::sse::event(toot.category),
warp::sse::data(toot.payload),
Ok(Async::Ready(Some(msg))) => Some((
warp::sse::event(msg.event()),
warp::sse::data(msg.payload()),
)),
_ => None,
});
@ -55,11 +56,6 @@ pub fn send_updates_to_ws(
}),
);
let (tl, email, id) = (
client_agent.current_user.target_timeline.clone(),
client_agent.current_user.email.clone(),
client_agent.current_user.id,
);
// Yield new events for as long as the client is still connected
let event_stream = tokio::timer::Interval::new(time::Instant::now(), update_interval)
.take_while(move |_| match ws_rx.poll() {
@ -75,39 +71,23 @@ pub fn send_updates_to_ws(
futures::future::ok(false)
}
Err(e) => {
log::warn!("Error in TL {}\nfor user: {}({})\n{}", tl, email, id, e);
log::warn!("Error in TL {}", e);
futures::future::ok(false)
}
});
let mut time = time::Instant::now();
let (tl, email, id, blocked_users, blocked_domains) = (
client_agent.current_user.target_timeline.clone(),
client_agent.current_user.email.clone(),
client_agent.current_user.id,
client_agent.current_user.blocks.user_blocks.clone(),
client_agent.current_user.blocks.domain_blocks.clone(),
);
// Every time you get an event from that stream, send it through the pipe
event_stream
.for_each(move |_instant| {
if let Ok(Async::Ready(Some(toot))) = client_agent.poll() {
if blocked_domains.is_disjoint(&toot.get_originating_domain())
&& blocked_users.is_disjoint(&toot.get_involved_users())
{
let txt = &toot.payload["content"];
log::warn!("toot: {}\nTL: {}\nUser: {}({})", txt, tl, email, id);
tx.unbounded_send(warp::ws::Message::text(
json!({ "event": toot.category,
"payload": &toot.payload.to_string() })
.to_string(),
))
.expect("No send error");
} else {
log::info!("Blocked a message to {}", email);
}
if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
tx.unbounded_send(warp::ws::Message::text(
json!({ "event": msg.event(),
"payload": msg.payload() })
.to_string(),
))
.expect("No send error");
};
if time.elapsed() > time::Duration::from_secs(30) {
tx.unbounded_send(warp::ws::Message::text("{}"))
@ -121,5 +101,5 @@ pub fn send_updates_to_ws(
log::info!("WebSocket connection closed.");
result
})
.map_err(move |e| log::warn!("Error sending to user: {}\n{}", id, e))
.map_err(move |e| log::warn!("Error sending to user: {}", e))
}
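Each message that reaches the WebSocket branch above is wrapped in a small JSON envelope before being sent. The envelope's shape in isolation (values illustrative; note the payload is embedded as a string, exactly as `msg.payload()` returns it):

use serde_json::json;

let (event, payload) = ("update", r#"{"content":"hi"}"#);
let frame = json!({ "event": event, "payload": payload }).to_string();
assert_eq!(frame, r#"{"event":"update","payload":"{\"content\":\"hi\"}"}"#);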


@ -1,21 +1,37 @@
use crate::parse_client_request::user::Timeline;
use serde_json::Value;
use std::{collections, time};
use std::{collections, fmt, time};
use uuid::Uuid;
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct MsgQueue {
pub timeline: Timeline,
pub messages: collections::VecDeque<Value>,
last_polled_at: time::Instant,
pub redis_channel: String,
}
impl fmt::Debug for MsgQueue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"\
MsgQueue {{
timeline: {:?},
messages: {:?},
last_polled_at: {:?},
}}",
self.timeline,
self.messages,
self.last_polled_at.elapsed(),
)
}
}
impl MsgQueue {
pub fn new(redis_channel: impl std::fmt::Display) -> Self {
let redis_channel = redis_channel.to_string();
pub fn new(timeline: Timeline) -> Self {
MsgQueue {
messages: collections::VecDeque::new(),
last_polled_at: time::Instant::now(),
redis_channel,
timeline,
}
}
}
@ -29,26 +45,26 @@ impl MessageQueues {
.and_modify(|queue| queue.last_polled_at = time::Instant::now());
}
pub fn oldest_msg_in_target_queue(&mut self, id: Uuid, timeline: String) -> Option<Value> {
pub fn oldest_msg_in_target_queue(&mut self, id: Uuid, timeline: Timeline) -> Option<Value> {
self.entry(id)
.or_insert_with(|| MsgQueue::new(timeline))
.messages
.pop_front()
}
pub fn calculate_timelines_to_add_or_drop(&mut self, timeline: String) -> Vec<Change> {
pub fn calculate_timelines_to_add_or_drop(&mut self, timeline: Timeline) -> Vec<Change> {
let mut timelines_to_modify = Vec::new();
timelines_to_modify.push(Change {
timeline: timeline.to_owned(),
timeline,
in_subscriber_number: 1,
});
self.retain(|_id, msg_queue| {
if msg_queue.last_polled_at.elapsed() < time::Duration::from_secs(30) {
true
} else {
let timeline = &msg_queue.redis_channel;
let timeline = &msg_queue.timeline;
timelines_to_modify.push(Change {
timeline: timeline.to_owned(),
timeline: *timeline,
in_subscriber_number: -1,
});
false
@ -58,7 +74,7 @@ impl MessageQueues {
}
}
pub struct Change {
pub timeline: String,
pub timeline: Timeline,
pub in_subscriber_number: i32,
}
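The `retain` call above is the whole queue-lifecycle policy: any queue that no `ClientAgent` has polled for 30 seconds is dropped, and its timeline is recorded so the `Receiver` can unsubscribe. The same pattern reduced to a standalone sketch (simplified stand-ins for the crate's types; names illustrative):

use std::collections::HashMap;
use std::time::{Duration, Instant};

struct Queue {
    timeline: &'static str, // stands in for the crate's `Timeline` enum
    last_polled_at: Instant,
}

fn drop_stale(queues: &mut HashMap<u32, Queue>) -> Vec<&'static str> {
    let mut to_unsubscribe = Vec::new();
    queues.retain(|_id, q| {
        if q.last_polled_at.elapsed() < Duration::from_secs(30) {
            true // recently polled: keep the queue
        } else {
            to_unsubscribe.push(q.timeline); // record for "unsubscribe"
            false // stale: drop the queue
        }
    });
    to_unsubscribe
}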


@ -4,13 +4,16 @@
mod message_queues;
use crate::{
config::{self, RedisInterval},
log_fatal,
parse_client_request::user::{self, postgres, PgPool, Timeline},
pubsub_cmd,
redis_to_client_stream::redis::{redis_cmd, RedisConn, RedisStream},
};
use futures::{Async, Poll};
use lru::LruCache;
pub use message_queues::{MessageQueues, MsgQueue};
use serde_json::Value;
use std::{collections, net, time};
use std::{collections::HashMap, net, time};
use tokio::io::Error;
use uuid::Uuid;
@ -21,16 +24,30 @@ pub struct Receiver {
secondary_redis_connection: net::TcpStream,
redis_poll_interval: RedisInterval,
redis_polled_at: time::Instant,
timeline: String,
timeline: Timeline,
manager_id: Uuid,
pub msg_queues: MessageQueues,
clients_per_timeline: collections::HashMap<String, i32>,
clients_per_timeline: HashMap<Timeline, i32>,
cache: Cache,
pool: PgPool,
}
#[derive(Debug)]
struct Cache {
id_to_hashtag: LruCache<i64, String>,
hashtag_to_id: LruCache<String, i64>,
}
impl Cache {
fn new(size: usize) -> Self {
Self {
id_to_hashtag: LruCache::new(size),
hashtag_to_id: LruCache::new(size),
}
}
}
impl Receiver {
/// Create a new `Receiver`, with its own Redis connections (but, as yet, no
/// active subscriptions).
pub fn new(redis_cfg: config::RedisConfig) -> Self {
pub fn new(redis_cfg: config::RedisConfig, pool: PgPool) -> Self {
let RedisConn {
primary: pubsub_connection,
secondary: secondary_redis_connection,
@ -44,10 +61,12 @@ impl Receiver {
secondary_redis_connection,
redis_poll_interval,
redis_polled_at: time::Instant::now(),
timeline: String::new(),
timeline: Timeline::empty(),
manager_id: Uuid::default(),
msg_queues: MessageQueues(collections::HashMap::new()),
clients_per_timeline: collections::HashMap::new(),
msg_queues: MessageQueues(HashMap::new()),
clients_per_timeline: HashMap::new(),
cache: Cache::new(1000), // should this be a run-time option?
pool,
}
}
@ -57,9 +76,9 @@ impl Receiver {
/// Note: this method calls `subscribe_or_unsubscribe_as_needed`,
/// so Redis PubSub subscriptions are only updated when a new timeline
/// comes under management for the first time.
pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: &str) {
pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: Timeline) {
self.manager_id = manager_id;
self.timeline = timeline.to_string();
self.timeline = timeline;
self.msg_queues
.insert(self.manager_id, MsgQueue::new(timeline));
self.subscribe_or_unsubscribe_as_needed(timeline);
@ -67,32 +86,55 @@ impl Receiver {
/// Set the `Receiver`'s manager_id and target_timeline fields to the appropriate
/// value to be polled by the current `StreamManager`.
pub fn configure_for_polling(&mut self, manager_id: Uuid, timeline: &str) {
pub fn configure_for_polling(&mut self, manager_id: Uuid, timeline: Timeline) {
self.manager_id = manager_id;
self.timeline = timeline.to_string();
self.timeline = timeline;
}
fn if_hashtag_timeline_get_hashtag_name(&mut self, timeline: Timeline) -> Option<String> {
use user::Stream::*;
if let Timeline(Hashtag(id), _, _) = timeline {
let cached_tag = self.cache.id_to_hashtag.get(&id).map(String::from);
let tag = match cached_tag {
Some(tag) => tag,
None => {
let new_tag = postgres::select_hashtag_name(&id, self.pool.clone())
.unwrap_or_else(|_| log_fatal!("No hashtag associated with tag #{}", &id));
self.cache.hashtag_to_id.put(new_tag.clone(), id);
self.cache.id_to_hashtag.put(id, new_tag.clone());
new_tag.to_string()
}
};
Some(tag)
} else {
None
}
}
/// Drop any PubSub subscriptions that don't have active clients and check
/// that there's a subscription to the current one. If there isn't, then
/// subscribe to it.
fn subscribe_or_unsubscribe_as_needed(&mut self, timeline: &str) {
fn subscribe_or_unsubscribe_as_needed(&mut self, timeline: Timeline) {
let start_time = std::time::Instant::now();
let timelines_to_modify = self
.msg_queues
.calculate_timelines_to_add_or_drop(timeline.to_string());
let timelines_to_modify = self.msg_queues.calculate_timelines_to_add_or_drop(timeline);
// Record the lower number of clients subscribed to that channel
for change in timelines_to_modify {
let timeline = change.timeline;
let opt_hashtag = self.if_hashtag_timeline_get_hashtag_name(timeline);
let opt_hashtag = opt_hashtag.as_ref();
let count_of_subscribed_clients = self
.clients_per_timeline
.entry(change.timeline.clone())
.entry(timeline)
.and_modify(|n| *n += change.in_subscriber_number)
.or_insert_with(|| 1);
// If no clients, unsubscribe from the channel
if *count_of_subscribed_clients <= 0 {
pubsub_cmd!("unsubscribe", self, change.timeline.clone());
pubsub_cmd!("unsubscribe", self, timeline.to_redis_str(opt_hashtag));
} else if *count_of_subscribed_clients == 1 && change.in_subscriber_number == 1 {
pubsub_cmd!("subscribe", self, change.timeline.clone());
pubsub_cmd!("subscribe", self, timeline.to_redis_str(opt_hashtag));
}
}
if start_time.elapsed().as_millis() > 1 {
@ -115,7 +157,29 @@ impl futures::stream::Stream for Receiver {
fn poll(&mut self) -> Poll<Option<Value>, Self::Error> {
let (timeline, id) = (self.timeline.clone(), self.manager_id);
if self.redis_polled_at.elapsed() > *self.redis_poll_interval {
self.pubsub_connection.poll_redis(&mut self.msg_queues);
for (raw_timeline, msg_value) in self.pubsub_connection.poll_redis() {
let hashtag = if raw_timeline.starts_with("hashtag") {
let tag_name = raw_timeline
.split(':')
.nth(1)
.unwrap_or_else(|| log_fatal!("No hashtag found in `{}`", raw_timeline))
.to_string();
let tag_id = *self
.cache
.hashtag_to_id
.get(&tag_name)
.unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name));
Some(tag_id)
} else {
None
};
let timeline = Timeline::from_redis_str(&raw_timeline, hashtag);
for msg_queue in self.msg_queues.values_mut() {
if msg_queue.timeline == timeline {
msg_queue.messages.push_back(msg_value.clone());
}
}
}
self.redis_polled_at = time::Instant::now();
}
@ -129,9 +193,3 @@ impl futures::stream::Stream for Receiver {
}
}
}
impl Drop for Receiver {
fn drop(&mut self) {
pubsub_cmd!("unsubscribe", self, self.timeline.clone());
}
}
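The two `LruCache`s above give the `Receiver` a bounded, two-way hashtag mapping: ids arrive inside `Timeline` values, names arrive in raw Redis channel strings, and a Postgres lookup is only needed on a miss. A sketch of the round trip using the `lru` crate this commit adds (sizes and values illustrative):

use lru::LruCache;

let mut id_to_hashtag: LruCache<i64, String> = LruCache::new(1000);
let mut hashtag_to_id: LruCache<String, i64> = LruCache::new(1000);

// On a miss, `if_hashtag_timeline_get_hashtag_name` queries Postgres,
// then fills both directions of the cache:
id_to_hashtag.put(42, "rust".to_string());
hashtag_to_id.put("rust".to_string(), 42);

// Later, `poll` can translate the raw channel `hashtag:rust` back to
// its id without touching the database:
assert_eq!(hashtag_to_id.get(&"rust".to_string()), Some(&42));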


@ -23,7 +23,7 @@ macro_rules! pubsub_cmd {
$self
.secondary_redis_connection
.write_all(&redis_cmd::set(
format!("subscribed:timeline:{}", $tl),
format!("subscribed:{}", $tl),
subscription_new_number,
namespace.clone(),
))
@ -35,8 +35,8 @@ macro_rules! pubsub_cmd {
/// Send a `SUBSCRIBE` or `UNSUBSCRIBE` command to a specific timeline
pub fn pubsub(command: impl Display, timeline: impl Display, ns: Option<String>) -> Vec<u8> {
let arg = match ns {
Some(namespace) => format!("{}:timeline:{}", namespace, timeline),
None => format!("timeline:{}", timeline),
Some(namespace) => format!("{}:{}", namespace, timeline),
None => format!("{}", timeline),
};
cmd(command, arg)
}
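`pubsub` above only builds the argument; the bytes on the wire must follow Redis's RESP protocol (an array of bulk strings). The crate's `cmd` helper is defined outside this hunk, so here is a standalone sketch of the encoding such a helper produces for a two-part command (illustrative only):

/// RESP-encode a two-part command, e.g. `subscribe timeline:public`.
fn resp_cmd(command: &str, arg: &str) -> Vec<u8> {
    format!(
        "*2\r\n${}\r\n{}\r\n${}\r\n{}\r\n",
        command.len(),
        command,
        arg.len(),
        arg
    )
    .into_bytes()
}

// resp_cmd("subscribe", "timeline:public") yields:
// *2\r\n$9\r\nsubscribe\r\n$15\r\ntimeline:public\r\n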


@ -39,7 +39,7 @@ impl<'a> RedisMsg<'a> {
item
}
pub fn extract_timeline_and_message(&mut self) -> (String, Value) {
pub fn extract_raw_timeline_and_message(&mut self) -> (String, Value) {
let timeline = &self.next_field()[self.prefix_len..];
let msg_txt = self.next_field();
let msg_value: Value =


@ -1,6 +1,7 @@
use super::redis_msg::RedisMsg;
use crate::{config::RedisNamespace, redis_to_client_stream::receiver::MessageQueues};
use crate::config::RedisNamespace;
use futures::{Async, Poll};
use serde_json::Value;
use std::{io::Read, net};
use tokio::io::AsyncRead;
@ -27,8 +28,9 @@ impl RedisStream {
// into messages. Incoming messages *are* guaranteed to be RESP arrays,
// https://redis.io/topics/protocol
/// Adds any new Redis messages to the `MsgQueue` for the appropriate `ClientAgent`.
pub fn poll_redis(&mut self, msg_queues: &mut MessageQueues) {
pub fn poll_redis(&mut self) -> Vec<(String, Value)> {
let mut buffer = vec![0u8; 6000];
let mut messages = Vec::new();
if let Async::Ready(num_bytes_read) = self.poll_read(&mut buffer).unwrap() {
let raw_utf = self.as_utf8(buffer, num_bytes_read);
@ -36,7 +38,7 @@ impl RedisStream {
// Only act if we have a full message (end on a msg boundary)
if !self.incoming_raw_msg.ends_with("}\r\n") {
return;
return messages;
};
let prefix_to_skip = match &*self.namespace {
Some(namespace) => format!("{}:timeline:", namespace),
@ -49,12 +51,8 @@ impl RedisStream {
let command = msg.next_field();
match command.as_str() {
"message" => {
let (timeline, msg_value) = msg.extract_timeline_and_message();
for msg_queue in msg_queues.values_mut() {
if msg_queue.redis_channel == timeline {
msg_queue.messages.push_back(msg_value.clone());
}
}
let (raw_timeline, msg_value) = msg.extract_raw_timeline_and_message();
messages.push((raw_timeline, msg_value));
}
"subscribe" | "unsubscribe" => {
@ -64,12 +62,13 @@ impl RedisStream {
let _active_subscriptions = msg.process_number();
msg.cursor += "\r\n".len();
}
cmd => panic!("Invariant violation: {} is invalid Redis input", cmd),
cmd => panic!("Invariant violation: {} is unexpected Redis output", cmd),
};
msg = RedisMsg::from_raw(&msg.raw[msg.cursor..], msg.prefix_len);
}
self.incoming_raw_msg.clear();
}
messages
}
fn as_utf8(&mut self, cur_buffer: Vec<u8>, size: usize) -> String {