Merge pull request #22 from tootsuite/cleanup_and_document

Refactor, cleanup, and document
This commit is contained in:
Daniel Sockwell 2019-07-09 13:19:50 -04:00 committed by GitHub
commit a67317b0a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1578 additions and 1324 deletions

44
Cargo.lock generated
View File

@ -27,7 +27,7 @@ version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.54 (registry+https://github.com/rust-lang/crates.io-index)",
"termion 1.5.2 (registry+https://github.com/rust-lang/crates.io-index)",
"termion 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
"winapi 0.3.7 (registry+https://github.com/rust-lang/crates.io-index)",
]
@ -144,11 +144,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "chrono"
version = "0.4.6"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"num-integer 0.1.39 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.54 (registry+https://github.com/rust-lang/crates.io-index)",
"num-integer 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)",
"time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)",
]
@ -246,14 +247,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "env_logger"
version = "0.6.1"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"atty 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)",
"humantime 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.1.6 (registry+https://github.com/rust-lang/crates.io-index)",
"termcolor 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)",
"termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@ -637,16 +638,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "num-integer"
version = "0.1.39"
version = "0.1.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num-traits"
version = "0.2.6"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num_cpus"
@ -782,8 +787,8 @@ name = "pretty_env_logger"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"chrono 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
"env_logger 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
"chrono 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)",
"env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
@ -814,6 +819,7 @@ version = "0.1.0"
dependencies = [
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
"futures 0.1.26 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
"postgres 0.15.2 (registry+https://github.com/rust-lang/crates.io-index)",
"pretty_env_logger 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
@ -1168,7 +1174,7 @@ dependencies = [
[[package]]
name = "termcolor"
version = "1.0.4"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"wincolor 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
@ -1176,7 +1182,7 @@ dependencies = [
[[package]]
name = "termion"
version = "1.5.2"
version = "1.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"libc 0.2.54 (registry+https://github.com/rust-lang/crates.io-index)",
@ -1597,7 +1603,7 @@ dependencies = [
"checksum bytes 0.4.12 (registry+https://github.com/rust-lang/crates.io-index)" = "206fdffcfa2df7cbe15601ef46c813fce0965eb3286db6b56c583b814b51c81c"
"checksum cc 1.0.36 (registry+https://github.com/rust-lang/crates.io-index)" = "a0c56216487bb80eec9c4516337b2588a4f2a2290d72a1416d930e4dcdb0c90d"
"checksum cfg-if 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "11d43355396e872eefb45ce6342e4374ed7bc2b3a502d1b28e36d6e23c05d1f4"
"checksum chrono 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "45912881121cb26fad7c38c17ba7daa18764771836b34fab7d3fbd93ed633878"
"checksum chrono 0.4.7 (registry+https://github.com/rust-lang/crates.io-index)" = "77d81f58b7301084de3b958691458a53c3f7e0b1d702f77e550b6a88e3a88abe"
"checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f"
"checksum constant_time_eq 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8ff012e225ce166d4422e0e78419d901719760f62ae2b7969ca6b564d1b54a9e"
"checksum crossbeam-deque 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b18cd2e169ad86297e6bc0ad9aa679aee9daa4f19e8163860faf7c164e4f5a71"
@ -1609,7 +1615,7 @@ dependencies = [
"checksum digest 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "05f47366984d3ad862010e22c7ce81a7dbcaebbdfb37241a620f8b6596ee135c"
"checksum dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7bdb5b956a911106b6b479cdc6bc1364d359a32299f17b49994f5327132e18d9"
"checksum dtoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "ea57b42383d091c85abcc2706240b94ab2a8fa1fc81c10ff23c4de06e2a90b5e"
"checksum env_logger 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b61fa891024a945da30a9581546e8cfaf5602c7b3f4c137a2805cf388f92075a"
"checksum env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)" = "aafcde04e90a5226a6443b7aabdb016ba2f8307c847d524724bd9b346dd1a2d3"
"checksum failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "795bd83d3abeb9220f257e597aa0080a508b27533824adf336529648f6abf7e2"
"checksum failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "ea1063915fd7ef4309e222a5a07cf9c319fb9c7836b1f89b85458672dbb127e1"
"checksum fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed"
@ -1655,8 +1661,8 @@ dependencies = [
"checksum miow 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "8c1f2f3b1cf331de6896aabf6e9d55dca90356cc9960cca7eaaf408a355ae919"
"checksum net2 0.2.33 (registry+https://github.com/rust-lang/crates.io-index)" = "42550d9fb7b6684a6d404d9fa7250c2eb2646df731d1c06afc06dcee9e1bcf88"
"checksum nodrop 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "2f9667ddcc6cc8a43afc9b7917599d7216aa09c463919ea32c59ed6cac8bc945"
"checksum num-integer 0.1.39 (registry+https://github.com/rust-lang/crates.io-index)" = "e83d528d2677f0518c570baf2b7abdcf0cd2d248860b68507bdcb3e91d4c0cea"
"checksum num-traits 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "0b3a5d7cc97d6d30d8b9bc8fa19bf45349ffe46241e8816f50f62f6d6aaabee1"
"checksum num-integer 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "8b8af8caa3184078cd419b430ff93684cb13937970fcb7639f728992f33ce674"
"checksum num-traits 0.2.7 (registry+https://github.com/rust-lang/crates.io-index)" = "d9c79c952a4a139f44a0fe205c4ee66ce239c0e6ce72cd935f5f7e2f717549dd"
"checksum num_cpus 1.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1a23f0ed30a54abaa0c7e83b1d2d87ada7c3c23078d1d87815af3e3b6385fbba"
"checksum numtoa 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b8f8bdf33df195859076e54ab11ee78a1b208382d3a26ec40d142ffc1ecc49ef"
"checksum opaque-debug 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "93f5bb2e8e8dec81642920ccff6b61f1eb94fa3020c5a325c9851ff604152409"
@ -1716,8 +1722,8 @@ dependencies = [
"checksum stringprep 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1"
"checksum syn 0.15.34 (registry+https://github.com/rust-lang/crates.io-index)" = "a1393e4a97a19c01e900df2aec855a29f71cf02c402e2f443b8d2747c25c5dbe"
"checksum synstructure 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)" = "73687139bf99285483c96ac0add482c3776528beac1d97d444f6e91f203a2015"
"checksum termcolor 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4096add70612622289f2fdcdbd5086dc81c1e2675e6ae58d6c4f62a16c6d7f2f"
"checksum termion 1.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "dde0593aeb8d47accea5392b39350015b5eccb12c0d98044d856983d89548dea"
"checksum termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96d6098003bde162e4277c70665bd87c326f5a0c3f3fbfb285787fa482d54e6e"
"checksum termion 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a8fb22f7cde82c8220e5aeacb3258ed7ce996142c77cba193f203515e26c330"
"checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b"
"checksum time 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "db8dcfca086c1143c9270ac42a2bbd8a7ee477b78ac8e45b19abfb0cbede4b6f"
"checksum tokio 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)" = "cec6c34409089be085de9403ba2010b80e36938c9ca992c4f67f407bb13db0b1"

View File

@ -18,6 +18,7 @@ pretty_env_logger = "0.3.0"
postgres = "0.15.2"
uuid = { version = "0.7", features = ["v4"] }
dotenv = "0.14.0"
lazy_static = "1.3.0"
[features]
default = [ "production" ]

View File

@ -2,31 +2,63 @@
A WIP blazingly fast drop-in replacement for the Mastodon streaming api server.
## Current status
The streaming server is very much a work in progress. It is currently missing essential features including support for SSL, CORS, and separate development/production environments. However, it has reached the point where it is usable/testable in a localhost development environment and I would greatly appreciate any testing, bug reports, or other feedback you could provide.
The streaming server is currently a work in progress. However, it is now testable and, if
configured properly, would theoretically be usable in production—though production use is
not advisable until we have completed further testing. I would greatly appreciate any testing,
bug reports, or other feedback you could provide.
## Installation
Installing the WIP version requires the Rust toolchain (the released version will be available as a pre-compiled binary). To install, clone this repository and run `cargo build` (to build the server) `cargo run` (to both build and run the server), or `cargo build --release` (to build the server with release optimizations).
Installing the WIP version requires the Rust toolchain (the released version will be available
as a pre-compiled binary). To install, clone this repository and run `cargo build` (to build
the server) `cargo run` (to both build and run the server), or `cargo build --release` (to
build the server with release optimizations).
## Connection to Mastodon
The streaming server expects to connect to a running development version of Mastodon built off of the `master` branch. Specifically, it needs to connect to both the Postgres database (to authenticate users) and to the Redis database. You should run Mastodon in whatever way you normally do and configure the streaming server to connect to the appropriate databases.
The streaming server expects to connect to a running development version of Mastodon built off of
the `master` branch. Specifically, it needs to connect to both the Postgres database (to
authenticate users) and to the Redis database. You should run Mastodon in whatever way you
normally do and configure the streaming server to connect to the appropriate databases.
## Configuring
You may edit the (currently limited) configuration variables in the `.env` file. Note that, by default, this server is configured to run on port 4000. This allows for easy testing with the development version of Mastodon (which, by default, is configured to communicate with a streaming server running on `localhost:4000`). However, it also conflicts with the current/Node.js version of Mastodon's streaming server, which runs on the same port. Thus, to test this server, you should disable the other streaming server or move it to a non-conflicting port.
You may edit the configuration variables in the `config.rs` module. You can also overwrite the
default config variables in the `.env` file. Note that, by default, this server is configured
to run on port 4000. This allows for easy testing with the development version of Mastodon
(which, by default, is configured to communicate with a streaming server running on
`localhost:4000`). However, it also conflicts with the current/Node.js version of Mastodon's
streaming server, which runs on the same port. Thus, to test this server, you should disable
the Node streaming server or move it to a non-conflicting port.
## Documentation
Build documentation with `cargo doc --open`, which will build the Markdown docs and open them in your browser. Please consult those docs for a description of the code structure/organization.
Build documentation with `cargo doc --open`, which will build the Markdown docs and open them
in your browser. Please consult those docs for a detailed description of the code
structure/organization. The documentation also contains additional notes about data flow and
options for configuration.
## Running
As noted above, you can run the server with `cargo run`. Alternatively, if you built the sever using `cargo build` or `cargo build --release`, you can run the executable produced in the `target/build/debug` folder or the `target/build/release` folder.
As noted above, you can run the server with `cargo run`. Alternatively, if you built the sever
using `cargo build` or `cargo build --release`, you can run the executable produced in the
`target/build/debug` folder or the `target/build/release` folder.
## Unit and (limited) integration tests
You can run basic unit test of the public Server Sent Event endpoints with `cargo test`. You can run integration tests of the authenticated SSE endpoints (which require a Postgres connection) with `cargo test -- --ignored`.
You can run basic unit test of the public Server Sent Event endpoints with `cargo test`. You can
run integration tests of the authenticated SSE endpoints (which require a Postgres connection)
with `cargo test -- --ignored`.
## Manual testing
Once the streaming server is running, you can also test it manually. You can test it using a browser connected to the relevant Mastodon development server. Or you can test the SSE endpoints with `curl`, PostMan, or any other HTTP client. Similarly, you can test the WebSocket endpoints with `websocat` or any other WebSocket client.
Once the streaming server is running, you can also test it manually. You can test it using a
browser connected to the relevant Mastodon development server. Or you can test the SSE endpoints
with `curl`, PostMan, or any other HTTP client. Similarly, you can test the WebSocket endpoints
with `websocat` or any other WebSocket client.
## Memory/CPU usage
Note that memory usage is higher when running the development version of the streaming server (the one generated with `cargo run` or `cargo build`). If you are interested in measuring RAM or CPU usage, you should likely run `cargo build --release` and test the release version of the executable.
Note that memory usage is higher when running the development version of the streaming server (the
one generated with `cargo run` or `cargo build`). If you are interested in measuring RAM or CPU
usage, you should likely run `cargo build --release` and test the release version of the executable.
## Load testing
I have not yet found a good way to test the streaming server under load. I have experimented with using `artillery` or other load-testing utilities. However, every utility I am familiar with or have found is built around either HTTP requests or WebSocket connections in which the client sends messages. I have not found a good solution to test receiving SSEs or WebSocket connections where the client does not transmit data after establishing the connection. If you are aware of a good way to do load testing, please let me know.
I have not yet found a good way to test the streaming server under load. I have experimented with
using `artillery` or other load-testing utilities. However, every utility I am familiar with or
have found is built around either HTTP requests or WebSocket connections in which the client sends
messages. I have not found a good solution to test receiving SSEs or WebSocket connections where
the client does not transmit data after establishing the connection. If you are aware of a good
way to do load testing, please let me know.

View File

@ -3,3 +3,4 @@
#SERVER_ADDR=
#REDIS_ADDR=
#POSTGRES_ADDR=
CORS_ALLOWED_METHODS="GET OPTIONS"

128
src/config.rs Normal file
View File

@ -0,0 +1,128 @@
//! Configuration defaults. All settings with the prefix of `DEFAULT_` can be overridden
//! by an environmental variable of the same name without that prefix (either by setting
//! the variable at runtime or in the `.env` file)
use dotenv::dotenv;
use lazy_static::lazy_static;
use log::warn;
use serde_derive::Serialize;
use std::{env, net, time};
const CORS_ALLOWED_METHODS: [&str; 2] = ["GET", "OPTIONS"];
const CORS_ALLOWED_HEADERS: [&str; 3] = ["Authorization", "Accept", "Cache-Control"];
const DEFAULT_POSTGRES_ADDR: &str = "postgres://@localhost/mastodon_development";
const DEFAULT_REDIS_ADDR: &str = "127.0.0.1:6379";
const DEFAULT_SERVER_ADDR: &str = "127.0.0.1:4000";
const DEFAULT_SSE_UPDATE_INTERVAL: u64 = 100;
const DEFAULT_WS_UPDATE_INTERVAL: u64 = 100;
const DEFAULT_REDIS_POLL_INTERVAL: u64 = 100;
lazy_static! {
static ref POSTGRES_ADDR: String = env::var("POSTGRESS_ADDR").unwrap_or_else(|_| {
let mut postgres_addr = DEFAULT_POSTGRES_ADDR.to_string();
postgres_addr.insert_str(11,
&env::var("USER").unwrap_or_else(|_| {
warn!("No USER env variable set. Connecting to Postgress with default `postgres` user");
"postgres".to_string()
}).as_str()
);
postgres_addr
});
static ref REDIS_ADDR: String = env::var("REDIS_ADDR").unwrap_or_else(|_| DEFAULT_REDIS_ADDR.to_owned());
pub static ref SERVER_ADDR: net::SocketAddr = env::var("SERVER_ADDR")
.unwrap_or_else(|_| DEFAULT_SERVER_ADDR.to_owned())
.parse()
.expect("static string");
/// Interval, in ms, at which the `ClientAgent` polls the `Receiver` for updates to send via SSE.
pub static ref SSE_UPDATE_INTERVAL: u64 = env::var("SSE_UPDATE_INTERVAL")
.map(|s| s.parse().expect("Valid config"))
.unwrap_or(DEFAULT_SSE_UPDATE_INTERVAL);
/// Interval, in ms, at which the `ClientAgent` polls the `Receiver` for updates to send via WS.
pub static ref WS_UPDATE_INTERVAL: u64 = env::var("WS_UPDATE_INTERVAL")
.map(|s| s.parse().expect("Valid config"))
.unwrap_or(DEFAULT_WS_UPDATE_INTERVAL);
/// Interval, in ms, at which the `Receiver` polls Redis.
/// **NOTE**: Polling Redis is much more time consuming than polling the `Receiver`
/// (on the order of 10ms rather than 50μs). Thus, changing this setting
/// would be a good place to start for performance improvements at the cost
/// of delaying all updates.
pub static ref REDIS_POLL_INTERVAL: u64 = env::var("REDIS_POLL_INTERVAL")
.map(|s| s.parse().expect("Valid config"))
.unwrap_or(DEFAULT_REDIS_POLL_INTERVAL);
}
/// Configure CORS for the API server
pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors {
warp::cors()
.allow_any_origin()
.allow_methods(CORS_ALLOWED_METHODS.to_vec())
.allow_headers(CORS_ALLOWED_HEADERS.to_vec())
}
/// Initialize logging and read values from `src/.env`
pub fn logging_and_env() {
pretty_env_logger::init();
dotenv().ok();
}
/// Configure Postgres and return a connection
pub fn postgres() -> postgres::Connection {
postgres::Connection::connect(POSTGRES_ADDR.to_string(), postgres::TlsMode::None)
.expect("Can connect to local Postgres")
}
/// Configure Redis
pub fn redis_addr() -> (net::TcpStream, net::TcpStream) {
let pubsub_connection =
net::TcpStream::connect(&REDIS_ADDR.to_string()).expect("Can connect to Redis");
pubsub_connection
.set_read_timeout(Some(time::Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
let secondary_redis_connection =
net::TcpStream::connect(&REDIS_ADDR.to_string()).expect("Can connect to Redis");
secondary_redis_connection
.set_read_timeout(Some(time::Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
(pubsub_connection, secondary_redis_connection)
}
#[derive(Serialize)]
pub struct ErrorMessage {
error: String,
}
impl ErrorMessage {
fn new(msg: impl std::fmt::Display) -> Self {
Self {
error: msg.to_string(),
}
}
}
/// Recover from Errors by sending appropriate Warp::Rejections
pub fn handle_errors(
rejection: warp::reject::Rejection,
) -> Result<impl warp::Reply, warp::reject::Rejection> {
let err_txt = match rejection.cause() {
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
"Error: Missing access token".to_string()
}
Some(text) => text.to_string(),
None => "Error: Nonexistant endpoint".to_string(),
};
let json = warp::reply::json(&ErrorMessage::new(err_txt));
Ok(warp::reply::with_status(
json,
warp::http::StatusCode::UNAUTHORIZED,
))
}
pub struct CustomError {}
impl CustomError {
pub fn unauthorized_list() -> warp::reject::Rejection {
warp::reject::custom("Error: Access to list not authorized")
}
}

View File

@ -1,32 +0,0 @@
//! Custom Errors and Warp::Rejections
use serde_derive::Serialize;
#[derive(Serialize)]
pub struct ErrorMessage {
error: String,
}
impl ErrorMessage {
fn new(msg: impl std::fmt::Display) -> Self {
Self {
error: msg.to_string(),
}
}
}
/// Recover from Errors by sending appropriate Warp::Rejections
pub fn handle_errors(
rejection: warp::reject::Rejection,
) -> Result<impl warp::Reply, warp::reject::Rejection> {
let err_txt = match rejection.cause() {
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
"Error: Missing access token".to_string()
}
Some(text) => text.to_string(),
None => "Error: Nonexistant endpoint".to_string(),
};
let json = warp::reply::json(&ErrorMessage::new(err_txt));
Ok(warp::reply::with_status(
json,
warp::http::StatusCode::UNAUTHORIZED,
))
}

42
src/lib.rs Normal file
View File

@ -0,0 +1,42 @@
//! Streaming server for Mastodon
//!
//!
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server
//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to
//! connect to the server with the API described [in Mastodon's public API
//! documentation](https://docs.joinmastodon.org/api/streaming/).
//!
//! # Data Flow
//! * **Parsing the client request**
//! When the client request first comes in, it is parsed based on the endpoint it targets (for
//! server sent events), its query parameters, and its headers (for WebSocket). Based on this
//! data, we authenticate the user, retrieve relevant user data from Postgres, and determine the
//! timeline targeted by the request. Successfully parsing the client request results in generating
//! a `User` and target `timeline` for the request. If any requests are invalid/not authorized, we
//! reject them in this stage.
//! * **Streaming update from Redis to the client**:
//! After the user request is parsed, we pass the `User` and `timeline` data on to the
//! `ClientAgent`. The `ClientAgent` is responsible for communicating the user's request to the
//! `Receiver`, polling the `Receiver` for any updates, and then for wording those updates on to the
//! client. The `Receiver`, in tern, is responsible for managing the Redis subscriptions,
//! periodically polling Redis, and sorting the replies from Redis into queues for when it is polled
//! by the `ClientAgent`.
//!
//! # Concurrency
//! The `Receiver` is created when the server is first initialized, and there is only one
//! `Receiver`. Thus, the `Receiver` is a potential bottleneck. On the other hand, each
//! client request results in a new green thread, which spawns its own `ClientAgent`. Thus,
//! their will be many `ClientAgent`s polling a single `Receiver`. Accordingly, it is very
//! important that polling the `Receiver` remain as fast as possible. It is less important
//! that the `Receiver`'s poll of Redis be fast, since there will only ever be one
//! `Receiver`.
//!
//! # Configuration
//! By default, the server uses config values from the `config.rs` module; these values can be
//! overwritten with environmental variables or in the `.env` file. The most important settings
//! for performance control the frequency with which the `ClientAgent` polls the `Receiver` and
//! the frequency with which the `Receiver` polls Redis.
//!
pub mod config;
pub mod parse_client_request;
pub mod redis_to_client_stream;

View File

@ -1,175 +1,115 @@
//! Streaming server for Mastodon
//!
//!
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server
//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to
//! connect to the server with the API described [in the public API
//! documentation](https://docs.joinmastodon.org/api/streaming/)
//!
//! # Notes on data flow
//! * **Client Request → Warp**:
//! Warp filters for valid requests and parses request data. Based on that data, it generates a `User`
//! representing the client that made the request. The `User` is authenticated, if appropriate. Warp
//! repeatedly polls the StreamManager for information relevant to the User.
//!
//! * **Warp → StreamManager**:
//! A new `StreamManager` is created for each request. The `StreamManager` exists to manage concurrent
//! access to the (single) `Receiver`, which it can access behind an `Arc<Mutex>`. The `StreamManager`
//! polles the `Receiver` for any updates relvant to the current client. If there are updates, the
//! `StreamManager` filters them with the client's filters and passes any matching updates up to Warp.
//! The `StreamManager` is also responsible for sending `subscribe` commands to Redis (via the
//! `Receiver`) when necessary.
//!
//! * **StreamManger → Receiver**:
//! The Receiver receives data from Redis and stores it in a series of queues (one for each
//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages
//! relevant to that StreamManager and removes them from the queue.
use ragequit::{
any_of, config,
parse_client_request::{sse, user, ws},
redis_to_client_stream,
redis_to_client_stream::ClientAgent,
};
pub mod error;
pub mod query;
pub mod receiver;
pub mod redis_cmd;
pub mod stream;
pub mod timeline;
pub mod user;
pub mod ws;
use dotenv::dotenv;
use futures::stream::Stream;
use futures::Async;
use receiver::Receiver;
use std::env;
use std::net::SocketAddr;
use stream::StreamManager;
use user::{OauthScope::*, Scope, User};
use warp::path;
use warp::Filter as WarpFilter;
use warp::{ws::Ws2, Filter as WarpFilter};
fn main() {
pretty_env_logger::init();
dotenv().ok();
config::logging_and_env();
let client_agent_sse = ClientAgent::blank();
let client_agent_ws = client_agent_sse.clone_with_shared_receiver();
let redis_updates = StreamManager::new(Receiver::new());
let redis_updates_sse = redis_updates.blank_copy();
let redis_updates_ws = redis_updates.blank_copy();
let routes = any_of!(
// Server Sent Events
//
// For SSE, the API requires users to use different endpoints, so we first filter based on
// the endpoint. Using that endpoint determine the `timeline` the user is requesting,
// the scope for that `timeline`, and authenticate the `User` if they provided a token.
let sse_routes = any_of!(
// GET /api/v1/streaming/user/notification [private; notification filter]
timeline::user_notifications(),
sse::Request::user_notifications(),
// GET /api/v1/streaming/user [private; language filter]
timeline::user(),
sse::Request::user(),
// GET /api/v1/streaming/public/local?only_media=true [public; language filter]
timeline::public_local_media(),
sse::Request::public_local_media(),
// GET /api/v1/streaming/public?only_media=true [public; language filter]
timeline::public_media(),
sse::Request::public_media(),
// GET /api/v1/streaming/public/local [public; language filter]
timeline::public_local(),
sse::Request::public_local(),
// GET /api/v1/streaming/public [public; language filter]
timeline::public(),
sse::Request::public(),
// GET /api/v1/streaming/direct [private; *no* filter]
timeline::direct(),
sse::Request::direct(),
// GET /api/v1/streaming/hashtag?tag=:hashtag [public; no filter]
timeline::hashtag(),
sse::Request::hashtag(),
// GET /api/v1/streaming/hashtag/local?tag=:hashtag [public; no filter]
timeline::hashtag_local(),
sse::Request::hashtag_local(),
// GET /api/v1/streaming/list?list=:list_id [private; no filter]
timeline::list()
sse::Request::list()
)
.untuple_one()
.and(warp::sse())
.map(move |timeline: String, user: User, sse: warp::sse::Sse| {
let mut redis_stream = redis_updates_sse.configure_copy(&timeline, user);
let event_stream = tokio::timer::Interval::new(
std::time::Instant::now(),
std::time::Duration::from_millis(100),
)
.filter_map(move |_| match redis_stream.poll() {
Ok(Async::Ready(Some(json_value))) => Some((
warp::sse::event(json_value["event"].clone().to_string()),
warp::sse::data(json_value["payload"].clone()),
)),
_ => None,
});
sse.reply(warp::sse::keep(event_stream, None))
})
.map(
move |timeline: String, user: user::User, sse_connection_to_client: warp::sse::Sse| {
// Create a new ClientAgent
let mut client_agent = client_agent_sse.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(&timeline, user);
// send the updates through the SSE connection
redis_to_client_stream::send_updates_to_sse(client_agent, sse_connection_to_client)
},
)
.with(warp::reply::with::header("Connection", "keep-alive"))
.recover(error::handle_errors);
.recover(config::handle_errors);
//let redis_updates_ws = StreamManager::new(Receiver::new());
let websocket = path!("api" / "v1" / "streaming")
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.and(warp::query())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.and(warp::ws2())
.and_then(
move |mut user: User,
q: query::Stream,
m: query::Media,
h: query::Hashtag,
l: query::List,
ws: warp::ws::Ws2| {
let scopes = user.scopes.clone();
let timeline = match q.stream.as_ref() {
// Public endpoints:
tl @ "public" | tl @ "public:local" if m.is_truthy() => format!("{}:media", tl),
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
tl @ "public" | tl @ "public:local" => tl.to_string(),
// Hashtag endpoints:
// TODO: handle missing query
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, h.tag),
// Private endpoints: User
"user"
if user.id > 0
&& (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) =>
{
format!("{}", user.id)
}
"user:notification"
if user.id > 0
&& (scopes.contains(&Read) || scopes.contains(&ReadNotifications)) =>
{
user = user.with_notification_filter();
format!("{}", user.id)
}
// List endpoint:
// TODO: handle missing query
"list"
if user.authorized_for_list(l.list).is_ok()
&& (scopes.contains(&Read) || scopes.contains(&ReadList)) =>
{
format!("list:{}", l.list)
}
// WebSocket
//
// For WS, the API specifies a single endpoint, so we extract the User/timeline pair
// directy from the query
let websocket_routes = ws::extract_user_and_query()
.and_then(move |mut user: user::User, q: ws::Query, ws: Ws2| {
let token = user.access_token.clone();
let read_scope = user.scopes.clone();
// Direct endpoint:
"direct"
if user.id > 0
&& (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) =>
{
"direct".to_string()
}
// Reject unathorized access attempts for private endpoints
"user" | "user:notification" | "direct" | "list" => {
return Err(warp::reject::custom("Error: Invalid Access Token"))
}
// Other endpoints don't exist:
_ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")),
};
let token = user.access_token.clone();
let stream = redis_updates_ws.configure_copy(&timeline, user);
let timeline = match q.stream.as_ref() {
// Public endpoints:
tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
tl @ "public" | tl @ "public:local" => tl.to_string(),
// Hashtag endpoints:
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
// Private endpoints: User
"user" if user.logged_in && (read_scope.all || read_scope.statuses) => {
format!("{}", user.id)
}
"user:notification" if user.logged_in && (read_scope.all || read_scope.notify) => {
user = user.set_filter(user::Filter::Notification);
format!("{}", user.id)
}
// List endpoint:
"list" if user.owns_list(q.list) && (read_scope.all || read_scope.lists) => {
format!("list:{}", q.list)
}
// Direct endpoint:
"direct" if user.logged_in && (read_scope.all || read_scope.statuses) => {
"direct".to_string()
}
// Reject unathorized access attempts for private endpoints
"user" | "user:notification" | "direct" | "list" => {
return Err(warp::reject::custom("Error: Invalid Access Token"))
}
// Other endpoints don't exist:
_ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")),
};
Ok((
ws.on_upgrade(move |socket| ws::send_replies(socket, stream)),
token,
))
},
)
// Create a new ClientAgent
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(&timeline, user);
// send the updates through the WS connection (along with the User's access_token
// which is sent for security)
Ok((
ws.on_upgrade(move |socket| {
redis_to_client_stream::send_updates_to_ws(socket, client_agent)
}),
token,
))
})
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
let address: SocketAddr = env::var("SERVER_ADDR")
.unwrap_or("127.0.0.1:4000".to_owned())
.parse()
.expect("static string");
warp::serve(websocket.or(routes)).run(address);
let cors = config::cross_origin_resource_sharing();
warp::serve(websocket_routes.or(sse_routes).with(cors)).run(*config::SERVER_ADDR);
}

View File

@ -0,0 +1,5 @@
//! Parse the client request and return a 'timeline' and a (maybe authenticated) `User`
pub mod query;
pub mod sse;
pub mod user;
pub mod ws;

View File

@ -0,0 +1,45 @@
//! Validate query prarams with type checking
use serde_derive::Deserialize;
use warp::filters::BoxedFilter;
use warp::Filter as WarpFilter;
macro_rules! query {
($name:tt => $parameter:tt:$type:tt) => {
#[derive(Deserialize, Debug, Default)]
pub struct $name {
pub $parameter: $type,
}
impl $name {
pub fn to_filter() -> BoxedFilter<(Self,)> {
warp::query()
.or(warp::any().map(Self::default))
.unify()
.boxed()
}
}
};
}
query!(Media => only_media:String);
impl Media {
pub fn is_truthy(&self) -> bool {
self.only_media == "true" || self.only_media == "1"
}
}
query!(Hashtag => tag: String);
query!(List => list: i64);
query!(Auth => access_token: String);
query!(Stream => stream: String);
impl ToString for Stream {
fn to_string(&self) -> String {
format!("{:?}", self)
}
}
pub fn optional_media_query() -> BoxedFilter<(Media,)> {
warp::query()
.or(warp::any().map(|| Media {
only_media: "false".to_owned(),
}))
.unify()
.boxed()
}

View File

@ -0,0 +1,106 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
use super::{
query,
user::{Filter::*, Scope, User},
};
use crate::{config::CustomError, user_from_path};
use warp::{filters::BoxedFilter, path, Filter};
#[allow(dead_code)]
type TimelineUser = ((String, User),);
pub enum Request {}
impl Request {
/// GET /api/v1/streaming/user
pub fn user() -> BoxedFilter<TimelineUser> {
user_from_path!("streaming" / "user", Scope::Private)
.map(|user: User| (user.id.to_string(), user))
.boxed()
}
/// GET /api/v1/streaming/user/notification
///
///
/// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented?
pub fn user_notifications() -> BoxedFilter<TimelineUser> {
user_from_path!("streaming" / "user" / "notification", Scope::Private)
.map(|user: User| (user.id.to_string(), user.set_filter(Notification)))
.boxed()
}
/// GET /api/v1/streaming/public
pub fn public() -> BoxedFilter<TimelineUser> {
user_from_path!("streaming" / "public", Scope::Public)
.map(|user: User| ("public".to_owned(), user.set_filter(Language)))
.boxed()
}
/// GET /api/v1/streaming/public?only_media=true
pub fn public_media() -> BoxedFilter<TimelineUser> {
user_from_path!("streaming" / "public", Scope::Public)
.and(warp::query())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:media".to_owned(), user.set_filter(Language)),
_ => ("public".to_owned(), user.set_filter(Language)),
})
.boxed()
}
/// GET /api/v1/streaming/public/local
pub fn public_local() -> BoxedFilter<TimelineUser> {
user_from_path!("streaming" / "public" / "local", Scope::Public)
.map(|user: User| ("public:local".to_owned(), user.set_filter(Language)))
.boxed()
}
/// GET /api/v1/streaming/public/local?only_media=true
pub fn public_local_media() -> BoxedFilter<TimelineUser> {
user_from_path!("streaming" / "public" / "local", Scope::Public)
.and(warp::query())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)),
_ => ("public:local".to_owned(), user.set_filter(Language)),
})
.boxed()
}
/// GET /api/v1/streaming/direct
pub fn direct() -> BoxedFilter<TimelineUser> {
user_from_path!("streaming" / "direct", Scope::Private)
.map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter)))
.boxed()
}
/// GET /api/v1/streaming/hashtag?tag=:hashtag
pub fn hashtag() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "hashtag")
.and(warp::query())
.map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public()))
.boxed()
}
/// GET /api/v1/streaming/hashtag/local?tag=:hashtag
pub fn hashtag_local() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "hashtag" / "local")
.and(warp::query())
.map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public()))
.boxed()
}
/// GET /api/v1/streaming/list?list=:list_id
pub fn list() -> BoxedFilter<TimelineUser> {
user_from_path!("streaming" / "list", Scope::Private)
.and(warp::query())
.and_then(|user: User, q: query::List| {
if user.owns_list(q.list) {
(Ok(q.list), Ok(user))
} else {
(Err(CustomError::unauthorized_list()), Ok(user))
}
})
.untuple_one()
.map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter)))
.boxed()
}
}

View File

@ -0,0 +1,149 @@
//! `User` struct and related functionality
mod postgres;
use crate::parse_client_request::query;
use log::info;
use warp::Filter as WarpFilter;
/// Combine multiple routes with the same return type together with
/// `or()` and `unify()`
#[macro_export]
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*
};
}
/// The filters that can be applied to toots after they come from Redis
#[derive(Clone, Debug, PartialEq)]
pub enum Filter {
NoFilter,
Language,
Notification,
}
/// The User (with data read from Postgres)
#[derive(Clone, Debug, PartialEq)]
pub struct User {
pub id: i64,
pub access_token: String,
pub scopes: OauthScope,
pub langs: Option<Vec<String>>,
pub logged_in: bool,
pub filter: Filter,
}
impl Default for User {
fn default() -> Self {
User::public()
}
}
#[derive(Clone, Debug, Default, PartialEq)]
pub struct OauthScope {
pub all: bool,
pub statuses: bool,
pub notify: bool,
pub lists: bool,
}
impl From<Vec<String>> for OauthScope {
fn from(scope_list: Vec<String>) -> Self {
let mut oauth_scope = OauthScope::default();
for scope in scope_list {
match scope.as_str() {
"read" => oauth_scope.all = true,
"read:statuses" => oauth_scope.statuses = true,
"read:notifications" => oauth_scope.notify = true,
"read:lists" => oauth_scope.lists = true,
_ => (),
}
}
oauth_scope
}
}
/// Create a user based on the supplied path and access scope for the resource
#[macro_export]
macro_rules! user_from_path {
($($path_item:tt) / *, $scope:expr) => (path!("api" / "v1" / $($path_item) / +)
.and($scope.get_access_token())
.and_then(|token| User::from_access_token(token, $scope)))
}
impl User {
/// Create a user from the access token supplied in the header or query paramaters
pub fn from_access_token(
access_token: String,
scope: Scope,
) -> Result<Self, warp::reject::Rejection> {
let (id, langs, scope_list) = postgres::query_for_user_data(&access_token);
let scopes = OauthScope::from(scope_list);
if id != -1 || scope == Scope::Public {
let (logged_in, log_msg) = match id {
-1 => (false, "Public access to non-authenticated endpoints"),
_ => (true, "Granting logged-in access"),
};
info!("{}", log_msg);
Ok(User {
id,
access_token,
scopes,
langs,
logged_in,
filter: Filter::NoFilter,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
/// Set the Notification/Language filter
pub fn set_filter(self, filter: Filter) -> Self {
Self { filter, ..self }
}
/// Determine whether the User is authorised for a specified list
pub fn owns_list(&self, list: i64) -> bool {
match postgres::query_list_owner(list) {
Some(i) if i == self.id => true,
_ => false,
}
}
/// A public (non-authenticated) User
pub fn public() -> Self {
User {
id: -1,
access_token: String::from("no access token"),
scopes: OauthScope::default(),
langs: None,
logged_in: false,
filter: Filter::NoFilter,
}
}
}
/// Whether the endpoint requires authentication or not
#[derive(PartialEq)]
pub enum Scope {
Public,
Private,
}
impl Scope {
pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> {
let token_from_header_http_push = warp::header::header::<String>("authorization")
.map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string());
let token_from_header_ws =
warp::header::header::<String>("Sec-WebSocket-Protocol").map(|auth: String| auth);
let token_from_query = warp::query().map(|q: query::Auth| q.access_token);
let private_scopes = any_of!(
token_from_header_http_push,
token_from_header_ws,
token_from_query
);
let public = warp::any().map(|| "no access token".to_string());
match self {
// if they're trying to access a private scope without an access token, reject the request
Scope::Private => private_scopes.boxed(),
// if they're trying to access a public scope without an access token, proceed
Scope::Public => any_of!(private_scopes, public).boxed(),
}
}
}

View File

@ -0,0 +1,53 @@
//! Postgres queries
use crate::config;
pub fn query_for_user_data(access_token: &str) -> (i64, Option<Vec<String>>, Vec<String>) {
let conn = config::postgres();
let query_result = conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&access_token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if !query_result.is_empty() {
let only_row = query_result.get(0);
let id: i64 = only_row.get(1);
let scopes = only_row
.get::<_, String>(3)
.split(' ')
.map(|s| s.to_owned())
.collect();
let langs: Option<Vec<String>> = only_row.get(2);
(id, langs, scopes)
} else {
(-1, None, Vec::new())
}
}
pub fn query_list_owner(list_id: i64) -> Option<i64> {
let conn = config::postgres();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT id, account_id
FROM lists
WHERE id = $1
LIMIT 1",
&[&list_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if rows.is_empty() {
None
} else {
Some(rows.get(0).get(1))
}
}

View File

@ -0,0 +1,43 @@
//! Filters for the WebSocket endpoint
use super::{
query,
user::{Scope, User},
};
use crate::user_from_path;
use warp::{filters::BoxedFilter, path, Filter};
/// WebSocket filters
pub fn extract_user_and_query() -> BoxedFilter<(User, Query, warp::ws::Ws2)> {
user_from_path!("streaming", Scope::Public)
.and(warp::query())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.and(warp::ws2())
.map(
|user: User,
stream: query::Stream,
media: query::Media,
hashtag: query::Hashtag,
list: query::List,
ws: warp::ws::Ws2| {
let query = Query {
stream: stream.stream,
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
};
(user, query, ws)
},
)
.untuple_one()
.boxed()
}
#[derive(Debug)]
pub struct Query {
pub stream: String,
pub media: bool,
pub hashtag: String,
pub list: i64,
}

View File

@ -1,66 +0,0 @@
//! Validate query prarams with type checking
use serde_derive::Deserialize;
use warp::filters::BoxedFilter;
use warp::Filter as WarpFilter;
#[derive(Deserialize, Debug, Default)]
pub struct Media {
pub only_media: String,
}
impl Media {
pub fn to_filter() -> BoxedFilter<(Self,)> {
warp::query()
.or(warp::any().map(Self::default))
.unify()
.boxed()
}
pub fn is_truthy(&self) -> bool {
self.only_media == "true" || self.only_media == "1"
}
}
#[derive(Deserialize, Debug, Default)]
pub struct Hashtag {
pub tag: String,
}
impl Hashtag {
pub fn to_filter() -> BoxedFilter<(Self,)> {
warp::query()
.or(warp::any().map(Self::default))
.unify()
.boxed()
}
}
#[derive(Deserialize, Debug, Default)]
pub struct List {
pub list: i64,
}
impl List {
pub fn to_filter() -> BoxedFilter<(Self,)> {
warp::query()
.or(warp::any().map(Self::default))
.unify()
.boxed()
}
}
#[derive(Deserialize, Debug)]
pub struct Auth {
pub access_token: String,
}
#[derive(Deserialize, Debug)]
pub struct Stream {
pub stream: String,
}
impl ToString for Stream {
fn to_string(&self) -> String {
format!("{:?}", self)
}
}
pub fn optional_media_query() -> BoxedFilter<(Media,)> {
warp::query()
.or(warp::any().map(|| Media {
only_media: "false".to_owned(),
}))
.unify()
.boxed()
}

View File

@ -1,218 +0,0 @@
//! Interfacing with Redis and stream the results on to the `StreamManager`
use crate::redis_cmd;
use crate::user::User;
use futures::stream::Stream;
use futures::{Async, Poll};
use log::info;
use regex::Regex;
use serde_json::Value;
use std::collections::{HashMap, VecDeque};
use std::env;
use std::io::{Read, Write};
use std::net::TcpStream;
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, Error};
use uuid::Uuid;
#[derive(Debug)]
struct MsgQueue {
messages: VecDeque<Value>,
last_polled_at: Instant,
redis_channel: String,
}
impl MsgQueue {
fn new(redis_channel: impl std::fmt::Display) -> Self {
let redis_channel = redis_channel.to_string();
MsgQueue {
messages: VecDeque::new(),
last_polled_at: Instant::now(),
redis_channel,
}
}
}
/// The item that streams from Redis and is polled by the `StreamManger`
#[derive(Debug)]
pub struct Receiver {
pubsub_connection: TcpStream,
secondary_redis_connection: TcpStream,
tl: String,
pub user: User,
manager_id: Uuid,
msg_queues: HashMap<Uuid, MsgQueue>,
clients_per_timeline: HashMap<String, i32>,
}
impl Default for Receiver {
fn default() -> Self {
Self::new()
}
}
impl Receiver {
pub fn new() -> Self {
let redis_addr = env::var("REDIS_ADDR").unwrap_or("127.0.0.1:6379".to_string());
let pubsub_connection = TcpStream::connect(&redis_addr).expect("Can connect to Redis");
pubsub_connection
.set_read_timeout(Some(Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
let secondary_redis_connection =
TcpStream::connect(&redis_addr).expect("Can connect to Redis");
secondary_redis_connection
.set_read_timeout(Some(Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
Self {
pubsub_connection,
secondary_redis_connection,
tl: String::new(),
user: User::public(),
manager_id: Uuid::new_v4(),
msg_queues: HashMap::new(),
clients_per_timeline: HashMap::new(),
}
}
/// Update the `StreamManager` that is currently polling the `Receiver`
pub fn update(&mut self, id: Uuid, timeline: impl std::fmt::Display) {
self.manager_id = id;
self.tl = timeline.to_string();
}
/// Send a subscribe command to the Redis PubSub (if needed)
pub fn maybe_subscribe(&mut self, tl: &str) {
info!("Subscribing to {}", &tl);
let manager_id = self.manager_id;
self.msg_queues.insert(manager_id, MsgQueue::new(tl));
let current_clients = self
.clients_per_timeline
.entry(tl.to_string())
.and_modify(|n| *n += 1)
.or_insert(1);
if *current_clients == 1 {
let subscribe_cmd = redis_cmd::pubsub("subscribe", tl);
self.pubsub_connection
.write_all(&subscribe_cmd)
.expect("Can subscribe to Redis");
let set_subscribed_cmd = redis_cmd::set(format!("subscribed:timeline:{}", tl), "1");
self.secondary_redis_connection
.write_all(&set_subscribed_cmd)
.expect("Can set Redis");
info!("Now subscribed to: {:#?}", &self.msg_queues);
}
}
/// Drop any PubSub subscriptions that don't have active clients
pub fn unsubscribe_from_empty_channels(&mut self) {
let mut timelines_with_fewer_clients = Vec::new();
// Keep only message queues that have been polled recently
self.msg_queues.retain(|_id, msg_queue| {
if msg_queue.last_polled_at.elapsed() < Duration::from_secs(30) {
true
} else {
timelines_with_fewer_clients.push(msg_queue.redis_channel.clone());
false
}
});
// Record the lower number of clients subscribed to that channel
for timeline in timelines_with_fewer_clients {
let count_of_subscribed_clients = self
.clients_per_timeline
.entry(timeline.clone())
.and_modify(|n| *n -= 1)
.or_insert(0);
// If no clients, unsubscribe from the channel
if *count_of_subscribed_clients <= 0 {
self.unsubscribe(&timeline);
}
}
}
/// Send an unsubscribe command to the Redis PubSub
pub fn unsubscribe(&mut self, tl: &str) {
let unsubscribe_cmd = redis_cmd::pubsub("unsubscribe", tl);
info!("Unsubscribing from {}", &tl);
self.pubsub_connection
.write_all(&unsubscribe_cmd)
.expect("Can unsubscribe from Redis");
let set_subscribed_cmd = redis_cmd::set(format!("subscribed:timeline:{}", tl), "0");
self.secondary_redis_connection
.write_all(&set_subscribed_cmd)
.expect("Can set Redis");
info!("Now subscribed only to: {:#?}", &self.msg_queues);
}
}
impl Stream for Receiver {
type Item = Value;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Value>, Self::Error> {
let mut buffer = vec![0u8; 3000];
info!("Being polled by: {}", self.manager_id);
let timeline = self.tl.clone();
// Record current time as last polled time
self.msg_queues
.entry(self.manager_id)
.and_modify(|msg_queue| msg_queue.last_polled_at = Instant::now());
// Add any incomming messages to the back of the relevant `msg_queues`
// NOTE: This could be more/other than the `msg_queue` currently being polled
let mut async_stream = AsyncReadableStream(&mut self.pubsub_connection);
if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer)? {
let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]);
// capture everything between `{` and `}` as potential JSON
let json_regex = Regex::new(r"(?P<json>\{.*\})").expect("Hard-coded");
// capture the timeline so we know which queues to add it to
let timeline_regex = Regex::new(r"timeline:(?P<timeline>.*?)\r").expect("Hard-codded");
if let Some(result) = json_regex.captures(raw_redis_response) {
let timeline =
timeline_regex.captures(raw_redis_response).unwrap()["timeline"].to_string();
let msg: Value = serde_json::from_str(&result["json"].to_string().clone())?;
for msg_queue in self.msg_queues.values_mut() {
if msg_queue.redis_channel == timeline {
msg_queue.messages.push_back(msg.clone());
}
}
}
}
// If the `msg_queue` being polled has any new messages, return the first (oldest) one
match self
.msg_queues
.entry(self.manager_id)
.or_insert_with(|| MsgQueue::new(timeline))
.messages
.pop_front()
{
Some(value) => Ok(Async::Ready(Some(value))),
_ => Ok(Async::NotReady),
}
}
}
impl Drop for Receiver {
fn drop(&mut self) {
let timeline = self.tl.clone();
self.unsubscribe(&timeline);
}
}
struct AsyncReadableStream<'a>(&'a mut TcpStream);
impl<'a> Read for AsyncReadableStream<'a> {
fn read(&mut self, buffer: &mut [u8]) -> Result<usize, std::io::Error> {
self.0.read(buffer)
}
}
impl<'a> AsyncRead for AsyncReadableStream<'a> {
fn poll_read(&mut self, buf: &mut [u8]) -> Poll<usize, std::io::Error> {
match self.read(buf) {
Ok(t) => Ok(Async::Ready(t)),
Err(_) => Ok(Async::NotReady),
}
}
}

View File

@ -0,0 +1,161 @@
//! Provides an interface between the `Warp` filters and the underlying
//! mechanics of talking with Redis/managing multiple threads.
//!
//! The `ClientAgent`'s interface is very simple. All you can do with it is:
//! * Create a totally new `ClientAgent` with no shared data;
//! * Clone an existing `ClientAgent`, sharing the `Receiver`;
//! * Manage an new timeline/user pair; or
//! * Poll an existing `ClientAgent` to see if there are any new messages
//! for clients
//!
//! When you poll the `ClientAgent`, it is responsible for polling internal data
//! structures, getting any updates from Redis, and then filtering out any updates
//! that should be excluded by relevant filters.
//!
//! Because `StreamManagers` are lightweight data structures that do not directly
//! communicate with Redis, it we create a new `ClientAgent` for
//! each new client connection (each in its own thread).
use super::receiver::Receiver;
use crate::parse_client_request::user::User;
use futures::{Async, Poll};
use log;
use serde_json::{json, Value};
use std::{sync, time};
use tokio::io::Error;
use uuid::Uuid;
/// Struct for managing all Redis streams.
#[derive(Clone, Default, Debug)]
pub struct ClientAgent {
receiver: sync::Arc<sync::Mutex<Receiver>>,
id: uuid::Uuid,
target_timeline: String,
current_user: User,
}
impl ClientAgent {
/// Create a new `ClientAgent` with no shared data.
pub fn blank() -> Self {
ClientAgent {
receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())),
id: Uuid::default(),
target_timeline: String::new(),
current_user: User::public(),
}
}
/// Clones the `ClientAgent`, sharing the `Receiver`.
pub fn clone_with_shared_receiver(&self) -> Self {
Self {
receiver: self.receiver.clone(),
id: self.id,
target_timeline: self.target_timeline.clone(),
current_user: self.current_user.clone(),
}
}
/// Initializes the `ClientAgent` with a unique ID, a `User`, and the target timeline.
/// Also passes values to the `Receiver` for it's initialization.
///
/// Note that this *may or may not* result in a new Redis connection.
/// If the server has already subscribed to the timeline on behalf of
/// a different user, the `Receiver` is responsible for figuring
/// that out and avoiding duplicated connections. Thus, it is safe to
/// use this method for each new client connection.
pub fn init_for_user(&mut self, target_timeline: &str, user: User) {
self.id = Uuid::new_v4();
self.target_timeline = target_timeline.to_owned();
self.current_user = user;
let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)");
receiver.manage_new_timeline(self.id, target_timeline);
}
}
/// The stream that the `ClientAgent` manages. `Poll` is the only method implemented.
impl futures::stream::Stream for ClientAgent {
type Item = Value;
type Error = Error;
/// Checks for any new messages that should be sent to the client.
///
/// The `ClientAgent` polls the `Receiver` and replies
/// with `Ok(Ready(Some(Value)))` if there is a new message to send to
/// the client. If there is no new message or if the new message should be
/// filtered out based on one of the user's filters, then the `ClientAgent`
/// replies with `Ok(NotReady)`. The `ClientAgent` bubles up any
/// errors from the underlying data structures.
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let start_time = time::Instant::now();
let result = {
let mut receiver = self
.receiver
.lock()
.expect("ClientAgent: No other thread panic");
receiver.configure_for_polling(self.id, &self.target_timeline.clone());
receiver.poll()
};
if start_time.elapsed() > time::Duration::from_millis(20) {
log::warn!("Polling took: {:?}", start_time.elapsed());
}
match result {
Ok(Async::Ready(Some(value))) => {
let user = &self.current_user;
let toot = Toot::from_json(value);
toot.filter(&user)
}
Ok(inner_value) => Ok(inner_value),
Err(e) => Err(e),
}
}
}
/// The message to send to the client (which might not literally be a toot in some cases).
struct Toot {
category: String,
payload: String,
language: String,
}
impl Toot {
/// Construct a `Toot` from well-formed JSON.
fn from_json(value: Value) -> Self {
Self {
category: value["event"].as_str().expect("Redis string").to_owned(),
payload: value["payload"].to_string(),
language: value["payload"]["language"]
.as_str()
.expect("Redis str")
.to_string(),
}
}
/// Convert a `Toot` to JSON inside an Option.
fn to_optional_json(&self) -> Option<Value> {
Some(json!(
{"event": self.category,
"payload": self.payload,}
))
}
/// Filter out any `Toot`'s that fail the provided filter.
fn filter(&self, user: &User) -> Result<Async<Option<Value>>, Error> {
let toot = self;
let (send_msg, skip_msg) = (
Ok(Async::Ready(toot.to_optional_json())),
Ok(Async::NotReady),
);
use crate::parse_client_request::user::Filter;
match &user.filter {
Filter::NoFilter => send_msg,
Filter::Notification if toot.category == "notification" => send_msg,
// If not, skip it
Filter::Notification => skip_msg,
Filter::Language if user.langs.is_none() => send_msg,
Filter::Language if user.langs.clone().expect("").contains(&toot.language) => send_msg,
// If not, skip it
Filter::Language => skip_msg,
}
}
}

View File

@ -0,0 +1,72 @@
//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis.
pub mod client_agent;
pub mod receiver;
pub mod redis_cmd;
use crate::config;
pub use client_agent::ClientAgent;
use futures::{future::Future, stream::Stream, Async};
use log;
use std::time;
/// Send a stream of replies to a Server Sent Events client.
pub fn send_updates_to_sse(
mut client_agent: ClientAgent,
connection: warp::sse::Sse,
) -> impl warp::reply::Reply {
let event_stream = tokio::timer::Interval::new(
time::Instant::now(),
time::Duration::from_millis(*config::SSE_UPDATE_INTERVAL),
)
.filter_map(move |_| match client_agent.poll() {
Ok(Async::Ready(Some(json_value))) => Some((
warp::sse::event(json_value["event"].clone().to_string()),
warp::sse::data(json_value["payload"].clone()),
)),
_ => None,
});
connection.reply(warp::sse::keep(event_stream, None))
}
/// Send a stream of replies to a WebSocket client.
pub fn send_updates_to_ws(
socket: warp::ws::WebSocket,
mut stream: ClientAgent,
) -> impl futures::future::Future<Item = (), Error = ()> {
let (ws_tx, mut ws_rx) = socket.split();
// Create a pipe
let (tx, rx) = futures::sync::mpsc::unbounded();
// Send one end of it to a different thread and tell that end to forward whatever it gets
// on to the websocket client
warp::spawn(
rx.map_err(|()| -> warp::Error { unreachable!() })
.forward(ws_tx)
.map_err(|_| ())
.map(|_r| ()),
);
// For as long as the client is still connected, yeild a new event every 100 ms
let event_stream = tokio::timer::Interval::new(
time::Instant::now(),
time::Duration::from_millis(*config::WS_UPDATE_INTERVAL),
)
.take_while(move |_| match ws_rx.poll() {
Ok(Async::Ready(None)) => futures::future::ok(false),
_ => futures::future::ok(true),
});
// Every time you get an event from that stream, send it through the pipe
event_stream
.for_each(move |_json_value| {
if let Ok(Async::Ready(Some(json_value))) = stream.poll() {
let msg = warp::ws::Message::text(json_value.to_string());
tx.unbounded_send(msg).expect("No send error");
};
Ok(())
})
.then(|msg| msg)
.map_err(|e| log::error!("{}", e))
}

View File

@ -0,0 +1,251 @@
//! Receives data from Redis, sorts it by `ClientAgent`, and stores it until
//! polled by the correct `ClientAgent`. Also manages sububscriptions and
//! unsubscriptions to/from Redis.
use super::redis_cmd;
use crate::{config, pubsub_cmd};
use futures::{Async, Poll};
use log::info;
use regex::Regex;
use serde_json::Value;
use std::{collections, io::Read, io::Write, net, time};
use tokio::io::{AsyncRead, Error};
use uuid::Uuid;
/// The item that streams from Redis and is polled by the `ClientAgent`
#[derive(Debug)]
pub struct Receiver {
pubsub_connection: net::TcpStream,
secondary_redis_connection: net::TcpStream,
redis_polled_at: time::Instant,
timeline: String,
manager_id: Uuid,
msg_queues: collections::HashMap<Uuid, MsgQueue>,
clients_per_timeline: collections::HashMap<String, i32>,
}
impl Receiver {
/// Create a new `Receiver`, with its own Redis connections (but, as yet, no
/// active subscriptions).
pub fn new() -> Self {
let (pubsub_connection, secondary_redis_connection) = config::redis_addr();
Self {
pubsub_connection,
secondary_redis_connection,
redis_polled_at: time::Instant::now(),
timeline: String::new(),
manager_id: Uuid::default(),
msg_queues: collections::HashMap::new(),
clients_per_timeline: collections::HashMap::new(),
}
}
/// Assigns the `Receiver` a new timeline to monitor and runs other
/// first-time setup.
///
/// Note: this method calls `subscribe_or_unsubscribe_as_needed`,
/// so Redis PubSub subscriptions are only updated when a new timeline
/// comes under management for the first time.
pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: &str) {
self.manager_id = manager_id;
self.timeline = timeline.to_string();
self.msg_queues
.insert(self.manager_id, MsgQueue::new(timeline));
self.subscribe_or_unsubscribe_as_needed(timeline);
}
/// Set the `Receiver`'s manager_id and target_timeline fields to the appropriate
/// value to be polled by the current `StreamManager`.
pub fn configure_for_polling(&mut self, manager_id: Uuid, timeline: &str) {
self.manager_id = manager_id;
self.timeline = timeline.to_string();
}
/// Drop any PubSub subscriptions that don't have active clients and check
/// that there's a subscription to the current one. If there isn't, then
/// subscribe to it.
fn subscribe_or_unsubscribe_as_needed(&mut self, timeline: &str) {
let mut timelines_to_modify = Vec::new();
struct Change {
timeline: String,
change_in_subscriber_number: i32,
}
timelines_to_modify.push(Change {
timeline: timeline.to_owned(),
change_in_subscriber_number: 1,
});
// Keep only message queues that have been polled recently
self.msg_queues.retain(|_id, msg_queue| {
if msg_queue.last_polled_at.elapsed() < time::Duration::from_secs(30) {
true
} else {
let timeline = &msg_queue.redis_channel;
timelines_to_modify.push(Change {
timeline: timeline.to_owned(),
change_in_subscriber_number: -1,
});
false
}
});
// Record the lower number of clients subscribed to that channel
for change in timelines_to_modify {
let mut need_to_subscribe = false;
let count_of_subscribed_clients = self
.clients_per_timeline
.entry(change.timeline.clone())
.and_modify(|n| *n += change.change_in_subscriber_number)
.or_insert_with(|| {
need_to_subscribe = true;
1
});
// If no clients, unsubscribe from the channel
if *count_of_subscribed_clients <= 0 {
pubsub_cmd!("unsubscribe", self, change.timeline.clone());
}
if need_to_subscribe {
pubsub_cmd!("subscribe", self, change.timeline.clone());
}
}
}
fn log_number_of_msgs_in_queue(&self) {
let messages_waiting = self
.msg_queues
.get(&self.manager_id)
.expect("Guaranteed by match block")
.messages
.len();
match messages_waiting {
number if number > 10 => {
log::error!("{} messages waiting in the queue", messages_waiting)
}
_ => log::info!("{} messages waiting in the queue", messages_waiting),
}
}
fn get_target_msg_queue(&mut self) -> collections::hash_map::Entry<Uuid, MsgQueue> {
self.msg_queues.entry(self.manager_id)
}
}
impl Default for Receiver {
fn default() -> Self {
Receiver::new()
}
}
/// The stream that the ClientAgent polls to learn about new messages.
impl futures::stream::Stream for Receiver {
type Item = Value;
type Error = Error;
/// Returns the oldest message in the `ClientAgent`'s queue (if any).
///
/// Note: This method does **not** poll Redis every time, because polling
/// Redis is signifiantly more time consuming that simply returning the
/// message already in a queue. Thus, we only poll Redis if it has not
/// been polled lately.
fn poll(&mut self) -> Poll<Option<Value>, Self::Error> {
let timeline = self.timeline.clone();
if self.redis_polled_at.elapsed()
> time::Duration::from_millis(*config::REDIS_POLL_INTERVAL)
{
AsyncReadableStream::poll_redis(self);
self.redis_polled_at = time::Instant::now();
}
// Record current time as last polled time
self.get_target_msg_queue()
.and_modify(|msg_queue| msg_queue.last_polled_at = time::Instant::now());
// If the `msg_queue` being polled has any new messages, return the first (oldest) one
match self
.get_target_msg_queue()
.or_insert_with(|| MsgQueue::new(timeline.clone()))
.messages
.pop_front()
{
Some(value) => {
self.log_number_of_msgs_in_queue();
Ok(Async::Ready(Some(value)))
}
_ => Ok(Async::NotReady),
}
}
}
impl Drop for Receiver {
fn drop(&mut self) {
pubsub_cmd!("unsubscribe", self, self.timeline.clone());
}
}
#[derive(Debug, Clone)]
struct MsgQueue {
messages: collections::VecDeque<Value>,
last_polled_at: time::Instant,
redis_channel: String,
}
impl MsgQueue {
fn new(redis_channel: impl std::fmt::Display) -> Self {
let redis_channel = redis_channel.to_string();
MsgQueue {
messages: collections::VecDeque::new(),
last_polled_at: time::Instant::now(),
redis_channel,
}
}
}
struct AsyncReadableStream<'a>(&'a mut net::TcpStream);
impl<'a> AsyncReadableStream<'a> {
fn new(stream: &'a mut net::TcpStream) -> Self {
AsyncReadableStream(stream)
}
/// Polls Redis for any new messages and adds them to the `MsgQueue` for
/// the appropriate `ClientAgent`.
fn poll_redis(receiver: &mut Receiver) {
let mut buffer = vec![0u8; 3000];
// Add any incoming messages to the back of the relevant `msg_queues`
// NOTE: This could be more/other than the `msg_queue` currently being polled
let mut async_stream = AsyncReadableStream::new(&mut receiver.pubsub_connection);
if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer).unwrap() {
let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]);
// capture everything between `{` and `}` as potential JSON
let json_regex = Regex::new(r"(?P<json>\{.*\})").expect("Hard-coded");
// capture the timeline so we know which queues to add it to
let timeline_regex = Regex::new(r"timeline:(?P<timeline>.*?)\r").expect("Hard-codded");
if let Some(result) = json_regex.captures(raw_redis_response) {
let timeline =
timeline_regex.captures(raw_redis_response).unwrap()["timeline"].to_string();
let msg: Value = serde_json::from_str(&result["json"].to_string().clone()).unwrap();
for msg_queue in receiver.msg_queues.values_mut() {
if msg_queue.redis_channel == timeline {
msg_queue.messages.push_back(msg.clone());
}
}
}
}
}
}
impl<'a> Read for AsyncReadableStream<'a> {
fn read(&mut self, buffer: &mut [u8]) -> Result<usize, std::io::Error> {
self.0.read(buffer)
}
}
impl<'a> AsyncRead for AsyncReadableStream<'a> {
fn poll_read(&mut self, buf: &mut [u8]) -> Poll<usize, std::io::Error> {
match self.read(buf) {
Ok(t) => Ok(Async::Ready(t)),
Err(_) => Ok(Async::NotReady),
}
}
}

View File

@ -1,10 +1,32 @@
//! Send raw TCP commands to the Redis server
use log::info;
use std::fmt::Display;
/// Send a subscribe or unsubscribe to the Redis PubSub channel
#[macro_export]
macro_rules! pubsub_cmd {
($cmd:expr, $self:expr, $tl:expr) => {{
info!("Sending {} command to {}", $cmd, $tl);
$self
.pubsub_connection
.write_all(&redis_cmd::pubsub($cmd, $tl))
.expect("Can send command to Redis");
let new_value = if $cmd == "subscribe" { "1" } else { "0" };
$self
.secondary_redis_connection
.write_all(&redis_cmd::set(
format!("subscribed:timeline:{}", $tl),
new_value,
))
.expect("Can set Redis");
info!("Now subscribed to: {:#?}", $self.msg_queues);
}};
}
/// Send a `SUBSCRIBE` or `UNSUBSCRIBE` command to a specific timeline
pub fn pubsub(command: impl Display, timeline: impl Display) -> Vec<u8> {
let arg = format!("timeline:{}", timeline);
let command = command.to_string();
info!("Sent {} command", &command);
format!(
"*2\r\n${cmd_length}\r\n{cmd}\r\n${arg_length}\r\n{arg}\r\n",
cmd_length = command.len(),

View File

@ -1,93 +0,0 @@
//! Manage all existing Redis PubSub connection
use crate::receiver::Receiver;
use crate::user::{Filter, User};
use futures::stream::Stream;
use futures::{Async, Poll};
use serde_json::json;
use serde_json::Value;
use std::sync::{Arc, Mutex};
use tokio::io::Error;
use uuid::Uuid;
/// Struct for manageing all Redis streams
#[derive(Clone, Debug)]
pub struct StreamManager {
receiver: Arc<Mutex<Receiver>>,
id: uuid::Uuid,
target_timeline: String,
current_user: Option<User>,
}
impl StreamManager {
pub fn new(reciever: Receiver) -> Self {
StreamManager {
receiver: Arc::new(Mutex::new(reciever)),
id: Uuid::default(),
target_timeline: String::new(),
current_user: None,
}
}
/// Create a blank StreamManager copy
pub fn blank_copy(&self) -> Self {
StreamManager { ..self.clone() }
}
/// Create a StreamManager copy with a new unique id manage subscriptions
pub fn configure_copy(&self, timeline: &String, user: User) -> Self {
let id = Uuid::new_v4();
let mut receiver = self.receiver.lock().expect("No panic in other threads");
receiver.update(id, timeline);
receiver.maybe_subscribe(timeline);
StreamManager {
id,
current_user: Some(user),
target_timeline: timeline.clone(),
..self.clone()
}
}
}
impl Stream for StreamManager {
type Item = Value;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let mut receiver = self
.receiver
.lock()
.expect("StreamManager: No other thread panic");
receiver.update(self.id, &self.target_timeline.clone());
match receiver.poll() {
Ok(Async::Ready(Some(value))) => {
let user = self
.clone()
.current_user
.expect("Previously set current user");
let user_langs = user.langs.clone();
let event = value["event"].as_str().expect("Redis string");
let payload = value["payload"].to_string();
match (&user.filter, user_langs) {
(Filter::Notification, _) if event != "notification" => Ok(Async::NotReady),
(Filter::Language, Some(ref user_langs))
if !user_langs.contains(
&value["payload"]["language"]
.as_str()
.expect("Redis str")
.to_string(),
) =>
{
Ok(Async::NotReady)
}
_ => Ok(Async::Ready(Some(json!(
{"event": event,
"payload": payload,}
)))),
}
}
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => Err(e),
}
}
}

View File

@ -1,499 +0,0 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
use crate::query;
use crate::user::{Scope, User};
use warp::filters::BoxedFilter;
use warp::{path, Filter};
#[allow(dead_code)]
type TimelineUser = ((String, User),);
/// GET /api/v1/streaming/user
///
///
/// **private**. Filter: `Language`
pub fn user() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "user")
.and(path::end())
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.map(|user: User| (user.id.to_string(), user))
.boxed()
}
/// GET /api/v1/streaming/user/notification
///
///
/// **private**. Filter: `Notification`
///
///
/// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented?
pub fn user_notifications() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "user" / "notification")
.and(path::end())
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.map(|user: User| (user.id.to_string(), user.with_notification_filter()))
.boxed()
}
/// GET /api/v1/streaming/public
///
///
/// **public**. Filter: `Language`
pub fn public() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public")
.and(path::end())
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.map(|user: User| ("public".to_owned(), user.with_language_filter()))
.boxed()
}
/// GET /api/v1/streaming/public?only_media=true
///
///
/// **public**. Filter: `Language`
pub fn public_media() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public")
.and(path::end())
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.and(warp::query())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:media".to_owned(), user.with_language_filter()),
_ => ("public".to_owned(), user.with_language_filter()),
})
.boxed()
}
/// GET /api/v1/streaming/public/local
///
///
/// **public**. Filter: `Language`
pub fn public_local() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public" / "local")
.and(path::end())
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.map(|user: User| ("public:local".to_owned(), user.with_language_filter()))
.boxed()
}
/// GET /api/v1/streaming/public/local?only_media=true
///
///
/// **public**. Filter: `Language`
pub fn public_local_media() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "public" / "local")
.and(Scope::Public.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Public))
.and(warp::query())
.and(path::end())
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
"1" | "true" => ("public:local:media".to_owned(), user.with_language_filter()),
_ => ("public:local".to_owned(), user.with_language_filter()),
})
.boxed()
}
/// GET /api/v1/streaming/direct
///
///
/// **private**. Filter: `None`
pub fn direct() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "direct")
.and(path::end())
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.map(|user: User| (format!("direct:{}", user.id), user.with_no_filter()))
.boxed()
}
/// GET /api/v1/streaming/hashtag?tag=:hashtag
///
///
/// **public**. Filter: `None`
pub fn hashtag() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "hashtag")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public()))
.boxed()
}
/// GET /api/v1/streaming/hashtag/local?tag=:hashtag
///
///
/// **public**. Filter: `None`
pub fn hashtag_local() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "hashtag" / "local")
.and(warp::query())
.and(path::end())
.map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public()))
.boxed()
}
/// GET /api/v1/streaming/list?list=:list_id
///
///
/// **private**. Filter: `None`
pub fn list() -> BoxedFilter<TimelineUser> {
path!("api" / "v1" / "streaming" / "list")
.and(Scope::Private.get_access_token())
.and_then(|token| User::from_access_token(token, Scope::Private))
.and(warp::query())
.and_then(|user: User, q: query::List| (user.authorized_for_list(q.list), Ok(user)))
.untuple_one()
.and(path::end())
.map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter()))
.boxed()
}
/// Combines multiple routes with the same return type together with
/// `or()` and `unify()`
#[macro_export]
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*
};
}
#[cfg(test)]
mod tests {
use super::*;
use crate::user;
#[test]
fn user_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&user());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user",))
.filter(&user());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token={}",
access_token
))
.filter(&user())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&user())
.expect("in test");
let expected_user =
User::from_access_token(access_token, user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn user_notifications_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN",
))
.filter(&user_notifications());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user/notification",))
.filter(&user_notifications());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_notifications_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token={}",
access_token
))
.filter(&user_notifications())
.expect("in test");
let expected_user = User::from_access_token(access_token.clone(), user::Scope::Private)
.expect("in test")
.with_notification_filter();
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user/notification")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&user_notifications())
.expect("in test");
let expected_user = User::from_access_token(access_token, user::Scope::Private)
.expect("in test")
.with_notification_filter();
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn public_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public")
.filter(&public())
.expect("in test");
assert_eq!(value.0, "public".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn public_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=true")
.filter(&public_media())
.expect("in test");
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=1")
.filter(&public_media())
.expect("in test");
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn public_local_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local")
.filter(&public_local())
.expect("in test");
assert_eq!(value.0, "public:local".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn public_local_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=true")
.filter(&public_local_media())
.expect("in test");
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=1")
.filter(&public_local_media())
.expect("in test");
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().with_language_filter());
}
#[test]
fn direct_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN",
))
.filter(&direct());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/direct",))
.filter(&direct());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn direct_timeline_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token={}",
access_token
))
.filter(&direct())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/direct")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&direct())
.expect("in test");
let expected_user =
User::from_access_token(access_token, user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn hashtag_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag?tag=a")
.filter(&hashtag())
.expect("in test");
assert_eq!(value.0, "hashtag:a".to_string());
assert_eq!(value.1, User::public());
}
#[test]
fn hashtag_timeline_local() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag/local?tag=a")
.filter(&hashtag_local())
.expect("in test");
assert_eq!(value.0, "hashtag:a:local".to_string());
assert_eq!(value.1, User::public());
}
#[test]
#[ignore]
fn list_timeline_auth() {
let list_id = 1;
let list_owner_id = get_list_owner(list_id);
let access_token = get_access_token(list_owner_id);
// Query Auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token={}&list={}",
access_token, list_id,
))
.filter(&list())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
// Header Auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/list?list=1")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&list())
.expect("in test");
let expected_user =
User::from_access_token(access_token, user::Scope::Private).expect("in test");
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn list_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&list());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/list?list=1",))
.filter(&list());
assert!(no_access_token(value));
}
fn get_list_owner(list_number: i32) -> i64 {
let list_number: i64 = list_number.into();
let conn = user::connect_to_postgres();
let rows = &conn
.query(
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_number],
)
.expect("in test");
assert_eq!(
rows.len(),
1,
"Test database must contain at least one user with a list to run this test."
);
rows.get(0).get(1)
}
fn get_access_token(user_id: i64) -> String {
let conn = user::connect_to_postgres();
let rows = &conn
.query(
"SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1",
&[&user_id],
)
.expect("Can get access token from id");
rows.get(0).get(0)
}
fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => {
true
}
_ => false,
},
_ => false,
}
}
fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true,
_ => false,
},
_ => false,
}
}
}

View File

@ -1,192 +0,0 @@
//! Create a User by querying the Postgres database with the user's access_token
use crate::{any_of, query};
use log::info;
use postgres;
use std::env;
use warp::Filter as WarpFilter;
/// (currently hardcoded to localhost)
pub fn connect_to_postgres() -> postgres::Connection {
let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or(format!(
"postgres://{}@localhost/mastodon_development",
env::var("USER").expect("User env var should exist")
));
postgres::Connection::connect(postgres_addr, postgres::TlsMode::None)
.expect("Can connect to local Postgres")
}
/// The filters that can be applied to toots after they come from Redis
#[derive(Clone, Debug, PartialEq)]
pub enum Filter {
None,
Language,
Notification,
}
/// The User (with data read from Postgres)
#[derive(Clone, Debug, PartialEq)]
pub struct User {
pub id: i64,
pub access_token: String,
pub scopes: Vec<OauthScope>,
pub langs: Option<Vec<String>>,
pub logged_in: bool,
pub filter: Filter,
}
#[derive(Clone, Debug, PartialEq)]
pub enum OauthScope {
Read,
ReadStatuses,
ReadNotifications,
ReadList,
Other,
}
impl From<&str> for OauthScope {
fn from(scope: &str) -> Self {
use OauthScope::*;
match scope {
"read" => Read,
"read:statuses" => ReadStatuses,
"read:notifications" => ReadNotifications,
"read:lists" => ReadList,
_ => Other,
}
}
}
impl User {
/// Create a user from the access token supplied in the header or query paramaters
pub fn from_access_token(
access_token: String,
scope: Scope,
) -> Result<Self, warp::reject::Rejection> {
let conn = connect_to_postgres();
let result = &conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&access_token],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if !result.is_empty() {
let only_row = result.get(0);
let id: i64 = only_row.get(1);
let scopes = only_row
.get::<_, String>(3)
.split(' ')
.map(|scope: &str| scope.into())
.filter(|scope| scope != &OauthScope::Other)
.collect();
dbg!(&scopes);
let langs: Option<Vec<String>> = only_row.get(2);
info!("Granting logged-in access");
Ok(User {
id,
access_token,
scopes,
langs,
logged_in: true,
filter: Filter::None,
})
} else if let Scope::Public = scope {
info!("Granting public access to non-authenticated client");
Ok(User {
id: -1,
access_token,
scopes: Vec::new(),
langs: None,
logged_in: false,
filter: Filter::None,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
/// Add a Notification filter
pub fn with_notification_filter(self) -> Self {
Self {
filter: Filter::Notification,
..self
}
}
/// Add a Language filter
pub fn with_language_filter(self) -> Self {
Self {
filter: Filter::Language,
..self
}
}
/// Remove all filters
pub fn with_no_filter(self) -> Self {
Self {
filter: Filter::None,
..self
}
}
/// Determine whether the User is authorised for a specified list
pub fn authorized_for_list(&self, list: i64) -> Result<i64, warp::reject::Rejection> {
let conn = connect_to_postgres();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
" SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if !rows.is_empty() {
let id_of_account_that_owns_the_list: i64 = rows.get(0).get(1);
if id_of_account_that_owns_the_list == self.id {
return Ok(list);
}
};
Err(warp::reject::custom("Error: Invalid access token"))
}
/// A public (non-authenticated) User
pub fn public() -> Self {
User {
id: -1,
access_token: String::new(),
scopes: Vec::new(),
langs: None,
logged_in: false,
filter: Filter::None,
}
}
}
/// Whether the endpoint requires authentication or not
pub enum Scope {
Public,
Private,
}
impl Scope {
pub fn get_access_token(self) -> warp::filters::BoxedFilter<(String,)> {
let token_from_header_http_push = warp::header::header::<String>("authorization")
.map(|auth: String| auth.split(' ').nth(1).unwrap_or("invalid").to_string());
let token_from_header_ws =
warp::header::header::<String>("Sec-WebSocket-Protocol").map(|auth: String| auth);
let token_from_query = warp::query().map(|q: query::Auth| q.access_token);
let private_scopes = any_of!(
token_from_header_http_push,
token_from_header_ws,
token_from_query
);
let public = warp::any().map(|| "no access token".to_string());
match self {
// if they're trying to access a private scope without an access token, reject the request
Scope::Private => private_scopes.boxed(),
// if they're trying to access a public scope without an access token, proceed
Scope::Public => any_of!(private_scopes, public).boxed(),
}
}
}

View File

@ -1,44 +0,0 @@
//! WebSocket-specific functionality
use crate::stream::StreamManager;
use futures::future::Future;
use futures::stream::Stream;
use futures::Async;
/// Send a stream of replies to a WebSocket client
pub fn send_replies(
socket: warp::ws::WebSocket,
mut stream: StreamManager,
) -> impl futures::future::Future<Item = (), Error = ()> {
let (tx, rx) = futures::sync::mpsc::unbounded();
let (ws_tx, mut ws_rx) = socket.split();
warp::spawn(
rx.map_err(|()| -> warp::Error { unreachable!() })
.forward(ws_tx)
.map_err(|_| ())
.map(|_r| ()),
);
let event_stream = tokio::timer::Interval::new(
std::time::Instant::now(),
std::time::Duration::from_millis(100),
)
.take_while(move |_| {
if ws_rx.poll().is_err() {
futures::future::ok(false)
} else {
futures::future::ok(true)
}
});
event_stream
.for_each(move |_json_value| {
if let Ok(Async::Ready(Some(json_value))) = stream.poll() {
let msg = warp::ws::Message::text(json_value.to_string());
if !tx.is_closed() {
tx.unbounded_send(msg).expect("No send error");
}
};
Ok(())
})
.then(|msg| msg)
.map_err(|e| println!("{}", e))
}

341
tests/test.rs Normal file
View File

@ -0,0 +1,341 @@
use ragequit::{
config,
parse_client_request::sse::Request,
parse_client_request::user::{Filter::*, Scope, User},
};
#[test]
fn user_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&Request::user());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user",))
.filter(&Request::user());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user?access_token={}",
access_token
))
.filter(&Request::user())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), Scope::Private).expect("in test");
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&Request::user())
.expect("in test");
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test");
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn user_notifications_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN",
))
.filter(&Request::user_notifications());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/user/notification",))
.filter(&Request::user_notifications());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn user_notifications_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/user/notification?access_token={}",
access_token
))
.filter(&Request::user_notifications())
.expect("in test");
let expected_user = User::from_access_token(access_token.clone(), Scope::Private)
.expect("in test")
.set_filter(Notification);
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/user/notification")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&Request::user_notifications())
.expect("in test");
let expected_user = User::from_access_token(access_token, Scope::Private)
.expect("in test")
.set_filter(Notification);
assert_eq!(actual_timeline, "1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn public_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public")
.filter(&Request::public())
.expect("in test");
assert_eq!(value.0, "public".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
}
#[test]
fn public_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=true")
.filter(&Request::public_media())
.expect("in test");
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
let value = warp::test::request()
.path("/api/v1/streaming/public?only_media=1")
.filter(&Request::public_media())
.expect("in test");
assert_eq!(value.0, "public:media".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
}
#[test]
fn public_local_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local")
.filter(&Request::public_local())
.expect("in test");
assert_eq!(value.0, "public:local".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
}
#[test]
fn public_local_media_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=true")
.filter(&Request::public_local_media())
.expect("in test");
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
let value = warp::test::request()
.path("/api/v1/streaming/public/local?only_media=1")
.filter(&Request::public_local_media())
.expect("in test");
assert_eq!(value.0, "public:local:media".to_string());
assert_eq!(value.1, User::public().set_filter(Language));
}
#[test]
fn direct_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN",
))
.filter(&Request::direct());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/direct",))
.filter(&Request::direct());
assert!(no_access_token(value));
}
#[test]
#[ignore]
fn direct_timeline_auth() {
let user_id: i64 = 1;
let access_token = get_access_token(user_id);
// Query auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/direct?access_token={}",
access_token
))
.filter(&Request::direct())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), Scope::Private).expect("in test");
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
// Header auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/direct")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&Request::direct())
.expect("in test");
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test");
assert_eq!(actual_timeline, "direct:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn hashtag_timeline() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag?tag=a")
.filter(&Request::hashtag())
.expect("in test");
assert_eq!(value.0, "hashtag:a".to_string());
assert_eq!(value.1, User::public());
}
#[test]
fn hashtag_timeline_local() {
let value = warp::test::request()
.path("/api/v1/streaming/hashtag/local?tag=a")
.filter(&Request::hashtag_local())
.expect("in test");
assert_eq!(value.0, "hashtag:a:local".to_string());
assert_eq!(value.1, User::public());
}
#[test]
#[ignore]
fn list_timeline_auth() {
let list_id = 1;
let list_owner_id = get_list_owner(list_id);
let access_token = get_access_token(list_owner_id);
// Query Auth
let (actual_timeline, actual_user) = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token={}&list={}",
access_token, list_id,
))
.filter(&Request::list())
.expect("in test");
let expected_user =
User::from_access_token(access_token.clone(), Scope::Private).expect("in test");
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
// Header Auth
let (actual_timeline, actual_user) = warp::test::request()
.path("/api/v1/streaming/list?list=1")
.header("Authorization", format!("Bearer: {}", access_token.clone()))
.filter(&Request::list())
.expect("in test");
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test");
assert_eq!(actual_timeline, "list:1");
assert_eq!(actual_user, expected_user);
}
#[test]
fn list_timeline_unauthorized() {
let value = warp::test::request()
.path(&format!(
"/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1",
))
.filter(&Request::list());
assert!(invalid_access_token(value));
let value = warp::test::request()
.path(&format!("/api/v1/streaming/list?list=1",))
.filter(&Request::list());
assert!(no_access_token(value));
}
// Helper functions for tests
fn get_list_owner(list_number: i32) -> i64 {
let list_number: i64 = list_number.into();
let conn = config::postgres();
let rows = &conn
.query(
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_number],
)
.expect("in test");
assert_eq!(
rows.len(),
1,
"Test database must contain at least one user with a list to run this test."
);
rows.get(0).get(1)
}
fn get_access_token(user_id: i64) -> String {
let conn = config::postgres();
let rows = &conn
.query(
"SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1",
&[&user_id],
)
.expect("Can get access token from id");
rows.get(0).get(0)
}
fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => true,
_ => false,
},
_ => false,
}
}
fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
match value {
Err(error) => match error.cause() {
// The cause could validly be any of these, depending on the order they're checked
// (It would pass with just one, so the last one it doesn't have is "the" cause)
Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true,
Some(c) if format!("{:?}", c) == "InvalidQuery" => true,
Some(c) if format!("{:?}", c) == "MissingHeader(\"Sec-WebSocket-Protocol\")" => true,
_ => false,
},
_ => false,
}
}