Reorganize code, pt1 (#110)

* Prevent Receiver from querying Postgres

Before this commit, the Receiver would query Postgres for the name
associated with a hashtag when it encountered one not in its cache.
This ensured that the Receiver never encountered a (valid) hashtag id
that it couldn't handle, but it caused an extra DB query and made
independent sections of the code more entangled than they needed to be.

Now, we pass the relevant tag name to the Receiver when it first
starts managing a new subscription, and the Receiver adds the tag name
to its cache at that point.
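
A minimal sketch of the new flow (the `Receiver` fields and the
`add_subscription` method shown here are illustrative, not the commit's
exact API):

use lru::LruCache;

pub struct Receiver {
    // Maps hashtag names (as they appear in Redis channel names) to ids.
    hashtag_cache: LruCache<String, i64>,
}

impl Receiver {
    // Called once, when the Receiver starts managing a new subscription.
    // The caller has already resolved the tag name via Postgres, so the
    // Receiver never needs a DB connection of its own.
    pub fn add_subscription(&mut self, tag: Option<(String, i64)>) {
        if let Some((name, id)) = tag {
            self.hashtag_cache.put(name, id);
        }
        // ...begin polling the subscription's Redis channel...
    }
}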

* Improve module boundary/privacy

* Reorganize Receiver to cut RedisStream

* Fix tests for code reorganization

Note that this change includes testing some private functionality by
exposing it publicly in tests via conditional compilation.  This
doesn't expose that functionality for the benchmarks, so the benchmark
tests do not currently pass without adding a few `pub use`
statements.  This might be worth changing later, but benchmark tests
aren't part of our CI and it's not hard to change when we want to test
performance.
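
The conditional-compilation pattern is the one visible in the
`parse_client_request` module later in this diff: private items are
re-exported only for test builds.

mod subscription;

// The public API proper.
pub use subscription::{Stream, Subscription, Timeline};

// Visible to unit tests only; `cargo bench` compiles without `cfg(test)`,
// so the benchmarks can't see these without extra `pub use` statements.
#[cfg(test)]
pub use subscription::{Content, Reach};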

This change also cuts the benchmarks that measured older ways Flodgatt
used to work.  They were useful for comparison, but they have served
their purpose: we've firmly moved away from the older/slower approach.

* Fix Receiver for tests
Daniel Sockwell 2020-03-27 12:00:48 -04:00 committed by GitHub
parent 2dd9ccbf91
commit 0acbde3eee
31 changed files with 1181 additions and 1604 deletions

Cargo.lock (generated; 2 lines changed)
View File

@@ -440,7 +440,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "flodgatt"
version = "0.6.4"
version = "0.6.5"
dependencies = [
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@@ -3,95 +3,27 @@ use criterion::criterion_group;
use criterion::criterion_main;
use criterion::Criterion;
const ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS: &str = "*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:1\r\n$3790\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102775370117886890\",\"created_at\":\"2019-09-11T18:42:19.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"unlisted\",\"language\":\"en\",\"uri\":\"https://mastodon.host/users/federationbot/statuses/102775346916917099\",\"url\":\"https://mastodon.host/@federationbot/102775346916917099\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"<p>Trending tags:<br><a href=\\\"https://mastodon.host/tags/neverforget\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>neverforget</span></a><br><a href=\\\"https://mastodon.host/tags/4styles\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>4styles</span></a><br><a href=\\\"https://mastodon.host/tags/newpipe\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>newpipe</span></a><br><a href=\\\"https://mastodon.host/tags/uber\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>uber</span></a><br><a href=\\\"https://mastodon.host/tags/mercredifiction\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>mercredifiction</span></a></p>\",\"reblog\":null,\"account\":{\"id\":\"78\",\"username\":\"federationbot\",\"acct\":\"federationbot@mastodon.host\",\"display_name\":\"Federation Bot\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-09-10T15:04:25.559Z\",\"note\":\"<p>Hello, I am mastodon.host official semi bot.</p><p>Follow me if you want to have some updates on the view of the fediverse from here ( I only post unlisted ). </p><p>I also randomly boost one of my followers toot every hour !</p><p>If you don\'t feel confortable with me following you, tell me: unfollow and I\'ll do it :)</p><p>If you want me to follow you, just tell me follow ! 
</p><p>If you want automatic follow for new users on your instance and you are an instance admin, contact me !</p><p>Other commands are private :)</p>\",\"url\":\"https://mastodon.host/@federationbot\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":16636,\"following_count\":179532,\"statuses_count\":50554,\"emojis\":[],\"fields\":[{\"name\":\"More stats\",\"value\":\"<a href=\\\"https://mastodon.host/stats.html\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\"><span class=\\\"invisible\\\">https://</span><span class=\\\"\\\">mastodon.host/stats.html</span><span class=\\\"invisible\\\"></span></a>\",\"verified_at\":null},{\"name\":\"More infos\",\"value\":\"<a href=\\\"https://mastodon.host/about/more\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\"><span class=\\\"invisible\\\">https://</span><span class=\\\"\\\">mastodon.host/about/more</span><span class=\\\"invisible\\\"></span></a>\",\"verified_at\":null},{\"name\":\"Owner/Friend\",\"value\":\"<span class=\\\"h-card\\\"><a href=\\\"https://mastodon.host/@gled\\\" class=\\\"u-url mention\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">@<span>gled</span></a></span>\",\"verified_at\":null}]},\"media_attachments\":[],\"mentions\":[],\"tags\":[{\"name\":\"4styles\",\"url\":\"https://instance.codesections.com/tags/4styles\"},{\"name\":\"neverforget\",\"url\":\"https://instance.codesections.com/tags/neverforget\"},{\"name\":\"mercredifiction\",\"url\":\"https://instance.codesections.com/tags/mercredifiction\"},{\"name\":\"uber\",\"url\":\"https://instance.codesections.com/tags/uber\"},{\"name\":\"newpipe\",\"url\":\"https://instance.codesections.com/tags/newpipe\"}],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1568227693541}\r\n";
/// Parses the Redis message using a Regex.
///
/// The naive approach from Flodgatt's proof-of-concept stage.
mod regex_parse {
use regex::Regex;
use serde_json::Value;
pub fn to_json_value(input: String) -> Value {
if input.ends_with("}\r\n") {
let messages = input.as_str().split("message").skip(1);
let regex = Regex::new(r"timeline:(?P<timeline>.*?)\r\n\$\d+\r\n(?P<value>.*?)\r\n")
.expect("Hard-codded");
for message in messages {
let _timeline = regex.captures(message).expect("Hard-coded timeline regex")
["timeline"]
.to_string();
let redis_msg: Value = serde_json::from_str(
&regex.captures(message).expect("Hard-coded value regex")["value"],
)
.expect("Valid json");
return redis_msg;
}
unreachable!()
} else {
unreachable!()
}
}
}
/// Parse with a simplified inline iterator.
///
/// Essentially shows best-case performance for producing a serde_json::Value.
mod parse_inline {
use serde_json::Value;
pub fn to_json_value(input: String) -> Value {
fn print_next_str(mut end: usize, input: &str) -> (usize, String) {
let mut start = end + 3;
end = start + 1;
let mut iter = input.chars();
iter.nth(start);
while iter.next().unwrap().is_digit(10) {
end += 1;
}
let length = &input[start..end].parse::<usize>().unwrap();
start = end + 2;
end = start + length;
let string = &input[start..end];
(end, string.to_string())
}
if input.ends_with("}\r\n") {
let end = 2;
let (end, _) = print_next_str(end, &input);
let (end, _timeline) = print_next_str(end, &input);
let (_, msg) = print_next_str(end, &input);
let redis_msg: Value = serde_json::from_str(&msg).unwrap();
redis_msg
} else {
unreachable!()
}
}
}
/// Parse using Flodgatt's current functions
mod flodgatt_parse_event {
use flodgatt::{messages::Event, redis_to_client_stream::receiver::MessageQueues};
use flodgatt::{
parse_client_request::subscription::Timeline,
redis_to_client_stream::{receiver::MsgQueue, redis::redis_stream},
messages::Event,
parse_client_request::Timeline,
redis_to_client_stream::{process_messages, MessageQueues, MsgQueue},
};
use lru::LruCache;
use std::collections::HashMap;
use uuid::Uuid;
/// One-time setup, not included in testing time.
pub fn setup() -> MessageQueues {
pub fn setup() -> (LruCache<String, i64>, MessageQueues, Uuid, Timeline) {
let mut cache: LruCache<String, i64> = LruCache::new(1000);
let mut queues_map = HashMap::new();
let id = Uuid::default();
let timeline = Timeline::from_redis_raw_timeline("1", None);
let timeline =
Timeline::from_redis_raw_timeline("timeline:1", &mut cache, &None).expect("In test");
queues_map.insert(id, MsgQueue::new(timeline));
let queues = MessageQueues(queues_map);
queues
(cache, queues, id, timeline)
}
pub fn to_event_struct(
@@ -101,201 +33,7 @@ mod flodgatt_parse_event {
id: Uuid,
timeline: Timeline,
) -> Event {
redis_stream::process_messages(input, &mut None, &mut cache, &mut queues).unwrap();
queues
.oldest_msg_in_target_queue(id, timeline)
.expect("In test")
}
}
/// Parse using a modified version of Flodgatt's current functions.
///
/// This version is modified to return a serde_json::Value instead of an Event to show
/// the performance we would see if we used serde's built-in method for handling weakly
/// typed JSON instead of our own strongly typed struct.
mod flodgatt_parse_value {
use flodgatt::{log_fatal, parse_client_request::subscription::Timeline};
use lru::LruCache;
use serde_json::Value;
use std::{
collections::{HashMap, VecDeque},
time::Instant,
};
use uuid::Uuid;
#[derive(Debug)]
pub struct RedisMsg<'a> {
pub raw: &'a str,
pub cursor: usize,
pub prefix_len: usize,
}
impl<'a> RedisMsg<'a> {
pub fn from_raw(raw: &'a str, prefix_len: usize) -> Self {
Self {
raw,
cursor: "*3\r\n".len(), //length of intro header
prefix_len,
}
}
/// Move the cursor from the beginning of a number through its end and return the number
pub fn process_number(&mut self) -> usize {
let (mut selected_number, selection_start) = (0, self.cursor);
while let Ok(number) = self.raw[selection_start..=self.cursor].parse::<usize>() {
self.cursor += 1;
selected_number = number;
}
selected_number
}
/// In a pubsub reply from Redis, an item can be either the name of the subscribed channel
/// or the msg payload. Either way, it follows the same format:
/// `$[LENGTH_OF_ITEM_BODY]\r\n[ITEM_BODY]\r\n`
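/// For example, `$5\r\nhello\r\n` encodes the five-byte item `hello`.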
pub fn next_field(&mut self) -> String {
self.cursor += "$".len();
let item_len = self.process_number();
self.cursor += "\r\n".len();
let item_start_position = self.cursor;
self.cursor += item_len;
let item = self.raw[item_start_position..self.cursor].to_string();
self.cursor += "\r\n".len();
item
}
pub fn extract_raw_timeline_and_message(&mut self) -> (String, Value) {
let timeline = &self.next_field()[self.prefix_len..];
let msg_txt = self.next_field();
let msg_value: Value = serde_json::from_str(&msg_txt)
.unwrap_or_else(|_| log_fatal!("Invalid JSON from Redis: {:?}", &msg_txt));
(timeline.to_string(), msg_value)
}
}
pub struct MsgQueue {
pub timeline: Timeline,
pub messages: VecDeque<Value>,
_last_polled_at: Instant,
}
pub struct MessageQueues(HashMap<Uuid, MsgQueue>);
impl std::ops::Deref for MessageQueues {
type Target = HashMap<Uuid, MsgQueue>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::DerefMut for MessageQueues {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl MessageQueues {
pub fn oldest_msg_in_target_queue(
&mut self,
id: Uuid,
timeline: Timeline,
) -> Option<Value> {
self.entry(id)
.or_insert_with(|| MsgQueue::new(timeline))
.messages
.pop_front()
}
}
impl MsgQueue {
pub fn new(timeline: Timeline) -> Self {
MsgQueue {
messages: VecDeque::new(),
_last_polled_at: Instant::now(),
timeline,
}
}
}
pub fn process_msg(
raw_utf: String,
namespace: &Option<String>,
hashtag_id_cache: &mut LruCache<String, i64>,
queues: &mut MessageQueues,
) {
// Only act if we have a full message (end on a msg boundary)
if !raw_utf.ends_with("}\r\n") {
return;
};
let prefix_to_skip = match namespace {
Some(namespace) => format!("{}:timeline:", namespace),
None => "timeline:".to_string(),
};
let mut msg = RedisMsg::from_raw(&raw_utf, prefix_to_skip.len());
while !msg.raw.is_empty() {
let command = msg.next_field();
match command.as_str() {
"message" => {
let (raw_timeline, msg_value) = msg.extract_raw_timeline_and_message();
let hashtag = hashtag_from_timeline(&raw_timeline, hashtag_id_cache);
let timeline = Timeline::from_redis_raw_timeline(&raw_timeline, hashtag);
for msg_queue in queues.values_mut() {
if msg_queue.timeline == timeline {
msg_queue.messages.push_back(msg_value.clone());
}
}
}
"subscribe" | "unsubscribe" => {
// No msg, so ignore & advance cursor to end
let _channel = msg.next_field();
msg.cursor += ":".len();
let _active_subscriptions = msg.process_number();
msg.cursor += "\r\n".len();
}
cmd => panic!("Invariant violation: {} is unexpected Redis output", cmd),
};
msg = RedisMsg::from_raw(&msg.raw[msg.cursor..], msg.prefix_len);
}
}
fn hashtag_from_timeline(
raw_timeline: &str,
hashtag_id_cache: &mut LruCache<String, i64>,
) -> Option<i64> {
if raw_timeline.starts_with("hashtag") {
let tag_name = raw_timeline
.split(':')
.nth(1)
.unwrap_or_else(|| log_fatal!("No hashtag found in `{}`", raw_timeline))
.to_string();
let tag_id = *hashtag_id_cache
.get(&tag_name)
.unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name));
Some(tag_id)
} else {
None
}
}
pub fn setup() -> (LruCache<String, i64>, MessageQueues, Uuid, Timeline) {
let cache: LruCache<String, i64> = LruCache::new(1000);
let mut queues_map = HashMap::new();
let id = Uuid::default();
let timeline = Timeline::from_redis_raw_timeline("1", None);
queues_map.insert(id, MsgQueue::new(timeline));
let queues = MessageQueues(queues_map);
(cache, queues, id, timeline)
}
pub fn to_json_value(
input: String,
mut cache: &mut LruCache<String, i64>,
mut queues: &mut MessageQueues,
id: Uuid,
timeline: Timeline,
) -> Value {
process_msg(input, &None, &mut cache, &mut queues);
process_messages(&input, &mut cache, &mut None, &mut queues);
queues
.oldest_msg_in_target_queue(id, timeline)
.expect("In test")
@@ -303,28 +41,10 @@ mod flodgatt_parse_value {
}
fn criterion_benchmark(c: &mut Criterion) {
let input = ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS.to_string(); //INPUT.to_string();
let input = ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS.to_string();
let mut group = c.benchmark_group("Parse redis RESP array");
// group.bench_function("parse to Value with a regex", |b| {
// b.iter(|| regex_parse::to_json_value(black_box(input.clone())))
// });
group.bench_function("parse to Value inline", |b| {
b.iter(|| parse_inline::to_json_value(black_box(input.clone())))
});
let (mut cache, mut queues, id, timeline) = flodgatt_parse_value::setup();
group.bench_function("parse to Value using Flodgatt functions", |b| {
b.iter(|| {
black_box(flodgatt_parse_value::to_json_value(
black_box(input.clone()),
black_box(&mut cache),
black_box(&mut queues),
black_box(id),
black_box(timeline),
))
})
});
let mut queues = flodgatt_parse_event::setup();
let (mut cache, mut queues, id, timeline) = flodgatt_parse_event::setup();
group.bench_function("parse to Event using Flodgatt functions", |b| {
b.iter(|| {
black_box(flodgatt_parse_event::to_event_struct(
@@ -340,3 +60,5 @@ fn criterion_benchmark(c: &mut Criterion) {
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
const ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS: &str = "*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:1\r\n$3790\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102775370117886890\",\"created_at\":\"2019-09-11T18:42:19.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"unlisted\",\"language\":\"en\",\"uri\":\"https://mastodon.host/users/federationbot/statuses/102775346916917099\",\"url\":\"https://mastodon.host/@federationbot/102775346916917099\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"<p>Trending tags:<br><a href=\\\"https://mastodon.host/tags/neverforget\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>neverforget</span></a><br><a href=\\\"https://mastodon.host/tags/4styles\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>4styles</span></a><br><a href=\\\"https://mastodon.host/tags/newpipe\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>newpipe</span></a><br><a href=\\\"https://mastodon.host/tags/uber\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>uber</span></a><br><a href=\\\"https://mastodon.host/tags/mercredifiction\\\" class=\\\"mention hashtag\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">#<span>mercredifiction</span></a></p>\",\"reblog\":null,\"account\":{\"id\":\"78\",\"username\":\"federationbot\",\"acct\":\"federationbot@mastodon.host\",\"display_name\":\"Federation Bot\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-09-10T15:04:25.559Z\",\"note\":\"<p>Hello, I am mastodon.host official semi bot.</p><p>Follow me if you want to have some updates on the view of the fediverse from here ( I only post unlisted ). </p><p>I also randomly boost one of my followers toot every hour !</p><p>If you don\'t feel confortable with me following you, tell me: unfollow and I\'ll do it :)</p><p>If you want me to follow you, just tell me follow ! 
</p><p>If you want automatic follow for new users on your instance and you are an instance admin, contact me !</p><p>Other commands are private :)</p>\",\"url\":\"https://mastodon.host/@federationbot\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":16636,\"following_count\":179532,\"statuses_count\":50554,\"emojis\":[],\"fields\":[{\"name\":\"More stats\",\"value\":\"<a href=\\\"https://mastodon.host/stats.html\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\"><span class=\\\"invisible\\\">https://</span><span class=\\\"\\\">mastodon.host/stats.html</span><span class=\\\"invisible\\\"></span></a>\",\"verified_at\":null},{\"name\":\"More infos\",\"value\":\"<a href=\\\"https://mastodon.host/about/more\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\"><span class=\\\"invisible\\\">https://</span><span class=\\\"\\\">mastodon.host/about/more</span><span class=\\\"invisible\\\"></span></a>\",\"verified_at\":null},{\"name\":\"Owner/Friend\",\"value\":\"<span class=\\\"h-card\\\"><a href=\\\"https://mastodon.host/@gled\\\" class=\\\"u-url mention\\\" rel=\\\"nofollow noopener\\\" target=\\\"_blank\\\">@<span>gled</span></a></span>\",\"verified_at\":null}]},\"media_attachments\":[],\"mentions\":[],\"tags\":[{\"name\":\"4styles\",\"url\":\"https://instance.codesections.com/tags/4styles\"},{\"name\":\"neverforget\",\"url\":\"https://instance.codesections.com/tags/neverforget\"},{\"name\":\"mercredifiction\",\"url\":\"https://instance.codesections.com/tags/mercredifiction\"},{\"name\":\"uber\",\"url\":\"https://instance.codesections.com/tags/uber\"},{\"name\":\"newpipe\",\"url\":\"https://instance.codesections.com/tags/newpipe\"}],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1568227693541}\r\n";

View File

@@ -11,14 +11,14 @@ from_env_var!(
/// The current environment, which controls what file to read other ENV vars from
let name = Env;
let default: EnvInner = EnvInner::Development;
let (env_var, allowed_values) = ("RUST_ENV", format!("one of: {:?}", EnvInner::variants()));
let (env_var, allowed_values) = ("RUST_ENV", &format!("one of: {:?}", EnvInner::variants()));
let from_str = |s| EnvInner::from_str(s).ok();
);
from_env_var!(
/// The address to run Flodgatt on
let name = FlodgattAddr;
let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded"));
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)".to_string());
let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)");
let from_str = |s| match s {
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
_ => s.parse().ok(),
@@ -28,35 +28,35 @@ from_env_var!(
/// How verbosely Flodgatt should log messages
let name = LogLevel;
let default: LogLevelInner = LogLevelInner::Warn;
let (env_var, allowed_values) = ("RUST_LOG", format!("one of: {:?}", LogLevelInner::variants()));
let (env_var, allowed_values) = ("RUST_LOG", &format!("one of: {:?}", LogLevelInner::variants()));
let from_str = |s| LogLevelInner::from_str(s).ok();
);
from_env_var!(
/// A Unix Socket to use in place of a local address
let name = Socket;
let default: Option<String> = None;
let (env_var, allowed_values) = ("SOCKET", "any string".to_string());
let (env_var, allowed_values) = ("SOCKET", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// The time between replies sent via WebSocket
let name = WsInterval;
let default: Duration = Duration::from_millis(100);
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string());
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds");
let from_str = |s| s.parse().map(Duration::from_millis).ok();
);
from_env_var!(
/// The time between replies sent via Server Sent Events
let name = SseInterval;
let default: Duration = Duration::from_millis(100);
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string());
let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds");
let from_str = |s| s.parse().map(Duration::from_millis).ok();
);
from_env_var!(
/// The port to run Flodgatt on
let name = Port;
let default: u16 = 4000;
let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535".to_string());
let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535");
let from_str = |s| s.parse().ok();
);
from_env_var!(
@@ -66,7 +66,7 @@ from_env_var!(
/// (including otherwise public timelines).
let name = WhitelistMode;
let default: bool = false;
let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false".to_string());
let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false");
let from_str = |s| s.parse().ok();
);
/// Permissions for Cross Origin Resource Sharing (CORS)

View File

@@ -0,0 +1,137 @@
use std::{collections::HashMap, fmt};
pub struct EnvVar(pub HashMap<String, String>);
impl std::ops::Deref for EnvVar {
type Target = HashMap<String, String>;
fn deref(&self) -> &HashMap<String, String> {
&self.0
}
}
impl Clone for EnvVar {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl EnvVar {
pub fn new(vars: HashMap<String, String>) -> Self {
Self(vars)
}
pub fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option<impl ToString>) {
if let Some(value) = maybe_value {
self.0.insert(key.to_string(), value.to_string());
}
}
pub fn err(env_var: &str, supplied_value: &str, allowed_values: &str) -> ! {
log::error!(
r"{var} is set to `{value}`, which is invalid.
{var} must be {allowed_vals}.",
var = env_var,
value = supplied_value,
allowed_vals = allowed_values
);
std::process::exit(1);
}
}
impl fmt::Display for EnvVar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut result = String::new();
for env_var in [
"NODE_ENV",
"RUST_LOG",
"BIND",
"PORT",
"SOCKET",
"SSE_FREQ",
"WS_FREQ",
"DATABASE_URL",
"DB_USER",
"USER",
"DB_PORT",
"DB_HOST",
"DB_PASS",
"DB_NAME",
"DB_SSLMODE",
"REDIS_HOST",
"REDIS_USER",
"REDIS_PORT",
"REDIS_PASSWORD",
"REDIS_USER",
"REDIS_DB",
]
.iter()
{
if let Some(value) = self.get(&env_var.to_string()) {
result = format!("{}\n {}: {}", result, env_var, value)
}
}
write!(f, "{}", result)
}
}
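/// Generate a builder-style method that replaces the named field when given
/// `Some(value)` and returns `self` unchanged when given `None`.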
#[macro_export]
macro_rules! maybe_update {
($name:ident; $item: tt:$type:ty) => (
pub fn $name(self, item: Option<$type>) -> Self {
match item {
Some($item) => Self{ $item, ..self },
None => Self { ..self }
}
});
($name:ident; Some($item: tt: $type:ty)) => (
fn $name(self, item: Option<$type>) -> Self{
match item {
Some($item) => Self{ $item: Some($item), ..self },
None => Self { ..self }
}
})}
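/// Generate a newtype wrapper for one configuration value: `Deref` to the
/// inner type, a hard-coded default, and a `maybe_update` method that parses
/// the environment variable's string (exiting with a helpful error on
/// invalid input), resets to the default when the variable is set but
/// empty, and keeps the current value when the variable is unset.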
#[macro_export]
macro_rules! from_env_var {
($(#[$outer:meta])*
let name = $name:ident;
let default: $type:ty = $inner:expr;
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
let from_str = |$arg:ident| $body:expr;
) => {
pub struct $name(pub $type);
impl std::fmt::Debug for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl std::ops::Deref for $name {
type Target = $type;
fn deref(&self) -> &$type {
&self.0
}
}
impl std::default::Default for $name {
fn default() -> Self {
$name($inner)
}
}
impl $name {
fn inner_from_str($arg: &str) -> Option<$type> {
$body
}
pub fn maybe_update(self, var: Option<&String>) -> Self {
match var {
Some(empty_string) if empty_string.is_empty() => Self::default(),
Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
crate::config::EnvVar::err($env_var, value, $allowed_values)
})),
None => self,
}
// if let Some(value) = var {
// Self(Self::inner_from_str(value).unwrap_or_else(|| {
// crate::err::env_var_fatal($env_var, value, $allowed_values)
// }))
// } else {
// self
// }
}
}
};
}

View File

@@ -4,135 +4,7 @@ mod postgres_cfg;
mod postgres_cfg_types;
mod redis_cfg;
mod redis_cfg_types;
pub use self::{
deployment_cfg::DeploymentConfig,
postgres_cfg::PostgresConfig,
redis_cfg::RedisConfig,
redis_cfg_types::{RedisInterval, RedisNamespace},
};
use std::{collections::HashMap, fmt};
mod environmental_variables;
pub struct EnvVar(pub HashMap<String, String>);
impl std::ops::Deref for EnvVar {
type Target = HashMap<String, String>;
fn deref(&self) -> &HashMap<String, String> {
&self.0
}
}
pub use {deployment_cfg::DeploymentConfig, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig, environmental_variables::EnvVar};
impl Clone for EnvVar {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl EnvVar {
pub fn new(vars: HashMap<String, String>) -> Self {
Self(vars)
}
fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option<impl ToString>) {
if let Some(value) = maybe_value {
self.0.insert(key.to_string(), value.to_string());
}
}
}
impl fmt::Display for EnvVar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut result = String::new();
for env_var in [
"NODE_ENV",
"RUST_LOG",
"BIND",
"PORT",
"SOCKET",
"SSE_FREQ",
"WS_FREQ",
"DATABASE_URL",
"DB_USER",
"USER",
"DB_PORT",
"DB_HOST",
"DB_PASS",
"DB_NAME",
"DB_SSLMODE",
"REDIS_HOST",
"REDIS_USER",
"REDIS_PORT",
"REDIS_PASSWORD",
"REDIS_USER",
"REDIS_DB",
]
.iter()
{
if let Some(value) = self.get(&env_var.to_string()) {
result = format!("{}\n {}: {}", result, env_var, value)
}
}
write!(f, "{}", result)
}
}
#[macro_export]
macro_rules! maybe_update {
($name:ident; $item: tt:$type:ty) => (
pub fn $name(self, item: Option<$type>) -> Self {
match item {
Some($item) => Self{ $item, ..self },
None => Self { ..self }
}
});
($name:ident; Some($item: tt: $type:ty)) => (
fn $name(self, item: Option<$type>) -> Self{
match item {
Some($item) => Self{ $item: Some($item), ..self },
None => Self { ..self }
}
})}
#[macro_export]
macro_rules! from_env_var {
($(#[$outer:meta])*
let name = $name:ident;
let default: $type:ty = $inner:expr;
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
let from_str = |$arg:ident| $body:expr;
) => {
pub struct $name(pub $type);
impl std::fmt::Debug for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}
impl std::ops::Deref for $name {
type Target = $type;
fn deref(&self) -> &$type {
&self.0
}
}
impl std::default::Default for $name {
fn default() -> Self {
$name($inner)
}
}
impl $name {
fn inner_from_str($arg: &str) -> Option<$type> {
$body
}
pub fn maybe_update(self, var: Option<&String>) -> Self {
match var {
Some(empty_string) if empty_string.is_empty() => Self::default(),
Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
crate::err::env_var_fatal($env_var, value, $allowed_values)
})),
None => self,
}
// if let Some(value) = var {
// Self(Self::inner_from_str(value).unwrap_or_else(|| {
// crate::err::env_var_fatal($env_var, value, $allowed_values)
// }))
// } else {
// self
// }
}
}
};
}

View File

@@ -6,7 +6,7 @@ from_env_var!(
/// The user to use for Postgres
let name = PgUser;
let default: String = "postgres".to_string();
let (env_var, allowed_values) = ("DB_USER", "any string".to_string());
let (env_var, allowed_values) = ("DB_USER", "any string");
let from_str = |s| Some(s.to_string());
);
@@ -14,7 +14,7 @@ from_env_var!(
/// The host address where Postgres is running
let name = PgHost;
let default: String = "localhost".to_string();
let (env_var, allowed_values) = ("DB_HOST", "any string".to_string());
let (env_var, allowed_values) = ("DB_HOST", "any string");
let from_str = |s| Some(s.to_string());
);
@@ -22,7 +22,7 @@ from_env_var!(
/// The password to use with Postgres
let name = PgPass;
let default: Option<String> = None;
let (env_var, allowed_values) = ("DB_PASS", "any string".to_string());
let (env_var, allowed_values) = ("DB_PASS", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
@@ -30,7 +30,7 @@ from_env_var!(
/// The Postgres database to use
let name = PgDatabase;
let default: String = "mastodon_development".to_string();
let (env_var, allowed_values) = ("DB_NAME", "any string".to_string());
let (env_var, allowed_values) = ("DB_NAME", "any string");
let from_str = |s| Some(s.to_string());
);
@@ -38,14 +38,14 @@ from_env_var!(
/// The port Postgres is running on
let name = PgPort;
let default: u16 = 5432;
let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535".to_string());
let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535");
let from_str = |s| s.parse().ok();
);
from_env_var!(
let name = PgSslMode;
let default: PgSslInner = PgSslInner::Prefer;
let (env_var, allowed_values) = ("DB_SSLMODE", format!("one of: {:?}", PgSslInner::variants()));
let (env_var, allowed_values) = ("DB_SSLMODE", &format!("one of: {:?}", PgSslInner::variants()));
let from_str = |s| PgSslInner::from_str(s).ok();
);

View File

@@ -7,48 +7,48 @@ from_env_var!(
/// The host address where Redis is running
let name = RedisHost;
let default: String = "127.0.0.1".to_string();
let (env_var, allowed_values) = ("REDIS_HOST", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_HOST", "any string");
let from_str = |s| Some(s.to_string());
);
from_env_var!(
/// The port Redis is running on
let name = RedisPort;
let default: u16 = 6379;
let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535".to_string());
let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535");
let from_str = |s| s.parse().ok();
);
from_env_var!(
/// How frequently to poll Redis
let name = RedisInterval;
let default: Duration = Duration::from_millis(100);
let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds".to_string());
let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds");
let from_str = |s| s.parse().map(Duration::from_millis).ok();
);
from_env_var!(
/// The password to use for Redis
let name = RedisPass;
let default: Option<String> = None;
let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// An optional Redis Namespace
let name = RedisNamespace;
let default: Option<String> = None;
let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// A user for Redis (not supported)
let name = RedisUser;
let default: Option<String> = None;
let (env_var, allowed_values) = ("REDIS_USER", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_USER", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// The database to use with Redis (no current effect for PubSub connections)
let name = RedisDb;
let default: Option<String> = None;
let (env_var, allowed_values) = ("REDIS_DB", "any string".to_string());
let (env_var, allowed_values) = ("REDIS_DB", "any string");
let from_str = |s| Some(Some(s.to_string()));
);

View File

@@ -1,4 +1,3 @@
use serde_derive::Serialize;
use std::fmt::Display;
pub fn die_with_msg(msg: impl Display) -> ! {
@@ -14,67 +13,20 @@ macro_rules! log_fatal {
};};
}
pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! {
eprintln!(
r"FATAL ERROR: {var} is set to `{value}`, which is invalid.
{var} must be {allowed_vals}.",
var = env_var,
value = supplied_value,
allowed_vals = allowed_values
);
std::process::exit(1);
#[derive(Debug)]
pub enum RedisParseErr {
Incomplete,
Unrecoverable,
}
#[macro_export]
macro_rules! dbg_and_die {
($msg:expr) => {
let message = format!("FATAL ERROR: {}", $msg);
dbg!(message);
std::process::exit(1);
};
}
pub fn unwrap_or_die<T>(s: Option<T>, msg: &str) -> T {
s.unwrap_or_else(|| {
eprintln!("FATAL ERROR: {}", msg);
std::process::exit(1)
})
#[derive(Debug)]
pub enum TimelineErr {
RedisNamespaceMismatch,
InvalidInput,
}
#[derive(Serialize)]
pub struct ErrorMessage {
error: String,
}
impl ErrorMessage {
fn new(msg: impl std::fmt::Display) -> Self {
Self {
error: msg.to_string(),
}
}
}
/// Recover from Errors by sending appropriate Warp::Rejections
pub fn handle_errors(
rejection: warp::reject::Rejection,
) -> Result<impl warp::Reply, warp::reject::Rejection> {
let err_txt = match rejection.cause() {
Some(text) if text.to_string() == "Missing request header 'authorization'" => {
"Error: Missing access token".to_string()
}
Some(text) => text.to_string(),
None => "Error: Nonexistant endpoint".to_string(),
};
let json = warp::reply::json(&ErrorMessage::new(err_txt));
Ok(warp::reply::with_status(
json,
warp::http::StatusCode::UNAUTHORIZED,
))
}
pub struct CustomError {}
impl CustomError {
pub fn unauthorized_list() -> warp::reject::Rejection {
warp::reject::custom("Error: Access to list not authorized")
impl From<std::num::ParseIntError> for TimelineErr {
fn from(_error: std::num::ParseIntError) -> Self {
Self::InvalidInput
}
}

View File

@@ -1,87 +1,76 @@
use flodgatt::{
config, err,
parse_client_request::{sse, subscription, ws},
redis_to_client_stream::{self, ClientAgent},
config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig},
parse_client_request::{PgPool, Subscription},
redis_to_client_stream::{ClientAgent, EventStream},
};
use std::{collections::HashMap, env, fs, net, os::unix::fs::PermissionsExt};
use tokio::net::UnixListener;
use warp::{path, ws::Ws2, Filter};
fn main() {
dotenv::from_filename(
match env::var("ENV").ok().as_ref().map(String::as_str) {
dotenv::from_filename(match env::var("ENV").ok().as_ref().map(String::as_str) {
Some("production") => ".env.production",
Some("development") | None => ".env",
Some(_) => err::die_with_msg("Unknown ENV variable specified.\n Valid options are: `production` or `development`."),
}).ok();
Some(unsupported) => EnvVar::err("ENV", unsupported, "`production` or `development`"),
})
.ok();
let env_vars_map: HashMap<_, _> = dotenv::vars().collect();
let env_vars = config::EnvVar::new(env_vars_map);
let env_vars = EnvVar::new(env_vars_map);
pretty_env_logger::init();
log::info!(
"Flodgatt recognized the following environmental variables:{}",
env_vars.clone()
);
let redis_cfg = config::RedisConfig::from_env(env_vars.clone());
let cfg = config::DeploymentConfig::from_env(env_vars.clone());
let redis_cfg = RedisConfig::from_env(env_vars.clone());
let cfg = DeploymentConfig::from_env(env_vars.clone());
let postgres_cfg = config::PostgresConfig::from_env(env_vars.clone());
let pg_pool = subscription::PgPool::new(postgres_cfg);
let postgres_cfg = PostgresConfig::from_env(env_vars.clone());
let pg_pool = PgPool::new(postgres_cfg);
let client_agent_sse = ClientAgent::blank(redis_cfg, pg_pool.clone());
let client_agent_sse = ClientAgent::blank(redis_cfg);
let client_agent_ws = client_agent_sse.clone_with_shared_receiver();
log::info!("Streaming server initialized and ready to accept connections");
// Server Sent Events
let (sse_update_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode);
let sse_routes = sse::extract_user_or_reject(pg_pool.clone(), whitelist_mode)
let (sse_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode);
let sse_routes = Subscription::from_sse_query(pg_pool.clone(), whitelist_mode)
.and(warp::sse())
.map(
move |subscription: subscription::Subscription,
sse_connection_to_client: warp::sse::Sse| {
move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| {
log::info!("Incoming SSE request for {:?}", subscription.timeline);
// Create a new ClientAgent
let mut client_agent = client_agent_sse.clone_with_shared_receiver();
// Assign ClientAgent to generate stream of updates for the user/timeline pair
client_agent.init_for_user(subscription);
// send the updates through the SSE connection
redis_to_client_stream::send_updates_to_sse(
client_agent,
sse_connection_to_client,
sse_update_interval,
)
EventStream::to_sse(client_agent, sse_connection_to_client, sse_interval)
},
)
.with(warp::reply::with::header("Connection", "keep-alive"))
.recover(err::handle_errors);
.with(warp::reply::with::header("Connection", "keep-alive"));
// WebSocket
let (ws_update_interval, whitelist_mode) = (*cfg.ws_interval, *cfg.whitelist_mode);
let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone(), whitelist_mode)
let websocket_routes = Subscription::from_ws_request(pg_pool.clone(), whitelist_mode)
.and(warp::ws::ws2())
.map(
move |subscription: subscription::Subscription, token: Option<String>, ws: Ws2| {
log::info!("Incoming websocket request for {:?}", subscription.timeline);
// Create a new ClientAgent
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(subscription);
// send the updates through the WS connection (along with the User's access_token
// which is sent for security)
.map(move |subscription: Subscription, ws: Ws2| {
log::info!("Incoming websocket request for {:?}", subscription.timeline);
(
ws.on_upgrade(move |socket| {
redis_to_client_stream::send_updates_to_ws(
socket,
client_agent,
ws_update_interval,
)
}),
token.unwrap_or_else(String::new),
)
},
)
let token = subscription.access_token.clone();
// Create a new ClientAgent
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
client_agent.init_for_user(subscription);
// send the updates through the WS connection (along with the User's access_token
// which is sent for security)
(
ws.on_upgrade(move |socket| {
EventStream::to_ws(socket, client_agent, ws_update_interval)
}),
token.unwrap_or_else(String::new),
)
})
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
let cors = warp::cors()
@@ -98,7 +87,28 @@ fn main() {
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).unwrap();
warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run_incoming(incoming);
warp::serve(
health.or(websocket_routes.or(sse_routes).with(cors).recover(
|rejection: warp::reject::Rejection| {
let err_txt = match rejection.cause() {
Some(text)
if text.to_string() == "Missing request header 'authorization'" =>
{
"Error: Missing access token".to_string()
}
Some(text) => text.to_string(),
None => "Error: Nonexistant endpoint".to_string(),
};
let json = warp::reply::json(&err_txt);
Ok(warp::reply::with_status(
json,
warp::http::StatusCode::UNAUTHORIZED,
))
},
)),
)
.run_incoming(incoming);
} else {
let server_addr = net::SocketAddr::new(*cfg.address, cfg.port.0);
warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run(server_addr);

File diff suppressed because one or more lines are too long

View File

@@ -1,5 +1,13 @@
//! Parse the client request and return a (possibly authenticated) `User`
pub mod query;
pub mod sse;
pub mod subscription;
pub mod ws;
//! Parse the client request and return a Subscription
mod postgres;
mod query;
mod sse;
mod subscription;
mod ws;
pub use self::postgres::PgPool;
// TODO consider whether we can remove `Stream` from public API
pub use subscription::{Stream, Subscription, Timeline};
#[cfg(test)]
pub use subscription::{Content, Reach};

View File

@@ -0,0 +1,204 @@
//! Postgres queries
use crate::{
config,
parse_client_request::subscription::{Scope, UserData},
};
use ::postgres;
use r2d2_postgres::PostgresConnectionManager;
use std::collections::HashSet;
use warp::reject::Rejection;
#[derive(Clone, Debug)]
pub struct PgPool(pub r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>);
impl PgPool {
pub fn new(pg_cfg: config::PostgresConfig) -> Self {
let mut cfg = postgres::Config::new();
cfg.user(&pg_cfg.user)
.host(&*pg_cfg.host.to_string())
.port(*pg_cfg.port)
.dbname(&pg_cfg.database);
if let Some(password) = &*pg_cfg.password {
cfg.password(password);
};
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
let pool = r2d2::Pool::builder()
.max_size(10)
.build(manager)
.expect("Can connect to local postgres");
Self(pool)
}
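/// Query Postgres for the user authenticated by an access token, returning
/// the user's id, allowed languages, and OAuth scopes (or rejecting the
/// request if the token is invalid).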
pub fn select_user(self, token: &str) -> Result<UserData, Rejection> {
let mut conn = self.0.get().unwrap();
let query_rows = conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) {
let id = result_columns.get(1);
let allowed_langs = result_columns
.try_get::<_, Vec<_>>(2)
.unwrap_or_else(|_| Vec::new())
.into_iter()
.collect();
let mut scopes: HashSet<Scope> = result_columns
.get::<_, String>(3)
.split(' ')
.filter_map(|scope| match scope {
"read" => Some(Scope::Read),
"read:statuses" => Some(Scope::Statuses),
"read:notifications" => Some(Scope::Notifications),
"read:lists" => Some(Scope::Lists),
"write" | "follow" => None, // ignore write scopes
unexpected => {
log::warn!("Ignoring unknown scope `{}`", unexpected);
None
}
})
.collect();
// No need to track the `read` scope separately; it simply implies the other three
if scopes.remove(&Scope::Read) {
scopes.insert(Scope::Statuses);
scopes.insert(Scope::Notifications);
scopes.insert(Scope::Lists);
}
Ok(UserData {
id,
allowed_langs,
scopes,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
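/// Query Postgres for a hashtag's id by name.  This runs at subscription
/// time, so the Receiver itself never needs to query Postgres.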
pub fn select_hashtag_id(self, tag_name: &String) -> Result<i64, Rejection> {
let mut conn = self.0.get().unwrap();
let rows = &conn
.query(
"
SELECT id
FROM tags
WHERE name = $1
LIMIT 1",
&[&tag_name],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
Some(row) => Ok(row.get(0)),
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
}
}
/// Query Postgres for everyone the user has blocked or muted
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_users(self, user_id: i64) -> HashSet<i64> {
// "
// SELECT
// 1
// FROM blocks
// WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
// OR (account_id = $2 AND target_account_id = $1)
// UNION SELECT
// 1
// FROM mutes
// WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
// , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
self
.0
.get()
.unwrap()
.query(
"
SELECT target_account_id
FROM blocks
WHERE account_id = $1
UNION SELECT target_account_id
FROM mutes
WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for everyone who has blocked the user
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(self, user_id: i64) -> HashSet<i64> {
self
.0
.get()
.unwrap()
.query(
"
SELECT account_id
FROM blocks
WHERE target_account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for all current domain blocks
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(self, user_id: i64) -> HashSet<String> {
self
.0
.get()
.unwrap()
.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Test whether a user owns a list
pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool {
let mut conn = self.0.get().unwrap();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT id, account_id
FROM lists
WHERE id = $1
LIMIT 1",
&[&list_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
None => false,
Some(row) => {
let list_owner_id: i64 = row.get(1);
list_owner_id == user_id
}
}
}
}

View File

@@ -28,6 +28,12 @@ impl Query {
}
macro_rules! make_query_type {
(Stream => $parameter:tt:$type:ty) => {
#[derive(Deserialize, Debug, Default)]
pub struct Stream {
pub $parameter: $type,
}
};
($name:tt => $parameter:tt:$type:ty) => {
#[derive(Deserialize, Debug, Default)]
pub struct $name {
@@ -59,14 +65,14 @@ impl ToString for Stream {
}
}
pub fn optional_media_query() -> BoxedFilter<(Media,)> {
warp::query()
.or(warp::any().map(|| Media {
only_media: "false".to_owned(),
}))
.unify()
.boxed()
}
// pub fn optional_media_query() -> BoxedFilter<(Media,)> {
// warp::query()
// .or(warp::any().map(|| Media {
// only_media: "false".to_owned(),
// }))
// .unify()
// .boxed()
// }
pub struct OptionalAccessToken;

View File

@@ -1,78 +1,4 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
use super::{
query::{self, Query},
subscription::{PgPool, Subscription},
};
use warp::{filters::BoxedFilter, path, Filter};
#[allow(dead_code)]
type TimelineUser = ((String, Subscription),);
/// Helper macro to match on the first of any of the provided filters
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*.boxed()
};
}
macro_rules! parse_query {
(path => $start:tt $(/ $next:tt)*
endpoint => $endpoint:expr) => {
path!($start $(/ $next)*)
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: $endpoint.to_string(),
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
};
}
pub fn extract_user_or_reject(
pg_pool: PgPool,
whitelist_mode: bool,
) -> BoxedFilter<(Subscription,)> {
any_of!(
parse_query!(
path => "api" / "v1" / "streaming" / "user" / "notification"
endpoint => "user:notification" ),
parse_query!(
path => "api" / "v1" / "streaming" / "user"
endpoint => "user"),
parse_query!(
path => "api" / "v1" / "streaming" / "public" / "local"
endpoint => "public:local"),
parse_query!(
path => "api" / "v1" / "streaming" / "public"
endpoint => "public"),
parse_query!(
path => "api" / "v1" / "streaming" / "direct"
endpoint => "direct"),
parse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
endpoint => "hashtag:local"),
parse_query!(path => "api" / "v1" / "streaming" / "hashtag"
endpoint => "hashtag"),
parse_query!(path => "api" / "v1" / "streaming" / "list"
endpoint => "list")
)
// because SSE requests place their `access_token` in the header instead of in a query
// parameter, we need to update our Query if the header has a token
.and(query::OptionalAccessToken::from_sse_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
.boxed()
}
// #[cfg(test)]
// mod test {

View File

@@ -0,0 +1,396 @@
//! `User` struct and related functionality
// #[cfg(test)]
// mod mock_postgres;
// #[cfg(test)]
// use mock_postgres as postgres;
// #[cfg(not(test))]
use super::postgres::PgPool;
use super::query::Query;
use crate::err::TimelineErr;
use crate::log_fatal;
use lru::LruCache;
use std::collections::HashSet;
use warp::reject::Rejection;
use super::query;
use warp::{filters::BoxedFilter, path, Filter};
/// Helper macro to match on the first of any of the provided filters
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*.boxed()
};
}
macro_rules! parse_sse_query {
(path => $start:tt $(/ $next:tt)*
endpoint => $endpoint:expr) => {
path!($start $(/ $next)*)
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: $endpoint.to_string(),
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
};
}
#[derive(Clone, Debug, PartialEq)]
pub struct Subscription {
pub timeline: Timeline,
pub allowed_langs: HashSet<String>,
pub blocks: Blocks,
pub hashtag_name: Option<String>,
pub access_token: Option<String>,
}
impl Default for Subscription {
fn default() -> Self {
Self {
timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification),
allowed_langs: HashSet::new(),
blocks: Blocks::default(),
hashtag_name: None,
access_token: None,
}
}
}
impl Subscription {
pub fn from_ws_request(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
parse_ws_query()
.and(query::OptionalAccessToken::from_ws_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
.boxed()
}
pub fn from_sse_query(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
any_of!(
parse_sse_query!(
path => "api" / "v1" / "streaming" / "user" / "notification"
endpoint => "user:notification" ),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "user"
endpoint => "user"),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "public" / "local"
endpoint => "public:local"),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "public"
endpoint => "public"),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "direct"
endpoint => "direct"),
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
endpoint => "hashtag:local"),
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag"
endpoint => "hashtag"),
parse_sse_query!(path => "api" / "v1" / "streaming" / "list"
endpoint => "list")
)
// because SSE requests place their `access_token` in the header instead of in a query
// parameter, we need to update our Query if the header has a token
.and(query::OptionalAccessToken::from_sse_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
.boxed()
}
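/// Authenticate the access token (if any), resolve the requested timeline,
/// and load the user's language filters and blocks into a Subscription.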
fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result<Self, Rejection> {
let user = match q.access_token.clone() {
Some(token) => pool.clone().select_user(&token)?,
None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?,
None => UserData::public(),
};
let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?;
let hashtag_name = match timeline {
Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag),
_non_hashtag_timeline => None,
};
Ok(Subscription {
timeline,
allowed_langs: user.allowed_langs,
blocks: Blocks {
blocking_users: pool.clone().select_blocking_users(user.id),
blocked_users: pool.clone().select_blocked_users(user.id),
blocked_domains: pool.clone().select_blocked_domains(user.id),
},
hashtag_name,
access_token: q.access_token,
})
}
}
fn parse_ws_query() -> BoxedFilter<(Query,)> {
path!("api" / "v1" / "streaming")
.and(path::end())
.and(warp::query())
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|stream: query::Stream,
auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: stream.stream,
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub struct Timeline(pub Stream, pub Reach, pub Content);
impl Timeline {
pub fn empty() -> Self {
use {Content::*, Reach::*, Stream::*};
Self(Unset, Local, Notification)
}
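/// Render this Timeline as the Redis channel name it corresponds to, e.g.
/// `timeline:public:local`.  Hashtag timelines additionally need the tag's
/// name, since Redis channels use tag names while Timeline stores ids.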
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> String {
use {Content::*, Reach::*, Stream::*};
match self {
Timeline(Public, Federated, All) => "timeline:public".into(),
Timeline(Public, Local, All) => "timeline:public:local".into(),
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
Timeline(Hashtag(id), Federated, All) => format!(
"timeline:hashtag:{}",
hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
),
Timeline(Hashtag(id), Local, All) => format!(
"timeline:hashtag:{}:local",
hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
),
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
Timeline(one, _two, _three) => {
log_fatal!("Supposedly impossible timeline reached: {:?}", one)
}
}
}
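/// Parse a raw Redis channel name such as `timeline:public:local` back into
/// a Timeline, translating hashtag names to ids via the (pre-filled) LRU
/// cache and rejecting channels whose Redis namespace does not match the
/// configured one.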
pub fn from_redis_raw_timeline(
timeline: &str,
cache: &mut LruCache<String, i64>,
namespace: &Option<String>,
) -> Result<Self, TimelineErr> {
use crate::err::TimelineErr::RedisNamespaceMismatch;
use {Content::*, Reach::*, Stream::*};
let timeline_slice = &timeline.split(":").collect::<Vec<&str>>()[..];
#[rustfmt::skip]
let (stream, reach, content) = if let Some(ns) = namespace {
match timeline_slice {
[n, "timeline", "public"] if n == ns => (Public, Federated, All),
[_, "timeline", "public"]
| ["timeline", "public"] => Err(RedisNamespaceMismatch)?,
[n, "timeline", "public", "local"] if ns == n => (Public, Local, All),
[_, "timeline", "public", "local"]
| ["timeline", "public", "local"] => Err(RedisNamespaceMismatch)?,
[n, "timeline", "public", "media"] if ns == n => (Public, Federated, Media),
[_, "timeline", "public", "media"]
| ["timeline", "public", "media"] => Err(RedisNamespaceMismatch)?,
[n, "timeline", "public", "local", "media"] if ns == n => (Public, Local, Media),
[_, "timeline", "public", "local", "media"]
| ["timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?,
[n, "timeline", "hashtag", tag_name] if ns == n => {
let tag_id = *cache
.get(&tag_name.to_string())
.unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name));
(Hashtag(tag_id), Federated, All)
}
[_, "timeline", "hashtag", _tag]
| ["timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?,
[n, "timeline", "hashtag", _tag, "local"] if ns == n => (Hashtag(0), Local, All),
[_, "timeline", "hashtag", _tag, "local"]
| ["timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?,
[n, "timeline", id] if ns == n => (User(id.parse().unwrap()), Federated, All),
[_, "timeline", _id]
| ["timeline", _id] => Err(RedisNamespaceMismatch)?,
[n, "timeline", id, "notification"] if ns == n =>
(User(id.parse()?), Federated, Notification),
[_, "timeline", _id, "notification"]
| ["timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?,
[n, "timeline", "list", id] if ns == n => (List(id.parse()?), Federated, All),
[_, "timeline", "list", _id]
| ["timeline", "list", _id] => Err(RedisNamespaceMismatch)?,
[n, "timeline", "direct", id] if ns == n => (Direct(id.parse()?), Federated, All),
[_, "timeline", "direct", _id]
| ["timeline", "direct", _id] => Err(RedisNamespaceMismatch)?,
[..] => log_fatal!("Unexpected channel from Redis: {:?}", timeline_slice),
}
} else {
match timeline_slice {
["timeline", "public"] => (Public, Federated, All),
[_, "timeline", "public"] => Err(RedisNamespaceMismatch)?,
["timeline", "public", "local"] => (Public, Local, All),
[_, "timeline", "public", "local"] => Err(RedisNamespaceMismatch)?,
["timeline", "public", "media"] => (Public, Federated, Media),
[_, "timeline", "public", "media"] => Err(RedisNamespaceMismatch)?,
["timeline", "public", "local", "media"] => (Public, Local, Media),
[_, "timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?,
["timeline", "hashtag", _tag] => (Hashtag(0), Federated, All),
[_, "timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?,
["timeline", "hashtag", _tag, "local"] => (Hashtag(0), Local, All),
[_, "timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?,
["timeline", id] => (User(id.parse().unwrap()), Federated, All),
[_, "timeline", _id] => Err(RedisNamespaceMismatch)?,
["timeline", id, "notification"] => {
(User(id.parse().unwrap()), Federated, Notification)
}
[_, "timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?,
["timeline", "list", id] => (List(id.parse().unwrap()), Federated, All),
[_, "timeline", "list", _id] => Err(RedisNamespaceMismatch)?,
["timeline", "direct", id] => (Direct(id.parse().unwrap()), Federated, All),
[_, "timeline", "direct", _id] => Err(RedisNamespaceMismatch)?,
// Other endpoints don't exist:
[..] => Err(TimelineErr::InvalidInput)?,
}
};
Ok(Timeline(stream, reach, content))
}
fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result<Self, Rejection> {
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag);
let user_owns_list = || pool.clone().user_owns_list(user.id, q.list);
Ok(match q.stream.as_ref() {
"public" => match q.media {
true => Timeline(Public, Federated, Media),
false => Timeline(Public, Federated, All),
},
"public:local" => match q.media {
true => Timeline(Public, Local, Media),
false => Timeline(Public, Local, All),
},
"public:media" => Timeline(Public, Federated, Media),
"public:local:media" => Timeline(Public, Local, Media),
"hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All),
"hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All),
"user" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
"user:notification" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, Notification),
false => Err(custom("Error: Missing access token"))?,
},
"list" => match user.scopes.contains(&Lists) && user_owns_list() {
true => Timeline(List(q.list), Federated, All),
false => Err(warp::reject::custom("Error: Missing access token"))?,
},
"direct" => match user.scopes.contains(&Statuses) {
true => Timeline(Direct(user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
other => {
log::warn!("Request for nonexistent endpoint: `{}`", other);
Err(custom("Error: Nonexistent endpoint"))?
}
})
}
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Stream {
User(i64),
List(i64),
Direct(i64),
Hashtag(i64),
Public,
Unset,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Reach {
Local,
Federated,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Content {
All,
Media,
Notification,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
Read,
Statuses,
Notifications,
Lists,
}
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Blocks {
pub blocked_domains: HashSet<String>,
pub blocked_users: HashSet<i64>,
pub blocking_users: HashSet<i64>,
}
#[derive(Clone, Debug, PartialEq)]
pub struct UserData {
pub id: i64,
pub allowed_langs: HashSet<String>,
pub scopes: HashSet<Scope>,
}
impl UserData {
fn public() -> Self {
Self {
id: -1,
allowed_langs: HashSet::new(),
scopes: HashSet::new(),
}
}
}
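For orientation, here is a hypothetical round-trip check for the channel-name mapping above. It is a sketch rather than part of the commit: it assumes the `Timeline`, `Stream`, `Reach`, and `Content` items from this module are in scope, plus the `lru` crate used elsewhere in the file.

#[cfg(test)]
mod timeline_channel_sketch {
    use super::*;
    use lru::LruCache;

    #[test]
    fn user_timeline_round_trips_through_its_channel_name() {
        use {Content::*, Reach::*, Stream::*};
        let tl = Timeline(User(1), Federated, All);
        // Only hashtag timelines need a tag name supplied here.
        assert_eq!(tl.to_redis_raw_timeline(None), "timeline:1");

        let mut cache: LruCache<String, i64> = LruCache::new(10);
        match Timeline::from_redis_raw_timeline("timeline:1", &mut cache, &None) {
            Ok(parsed) => assert_eq!(parsed, tl),
            Err(_) => panic!("`timeline:1` should parse without a namespace"),
        }
    }
}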


@ -1,43 +0,0 @@
//! Mock Postgres connection (for use in unit testing)
use super::{OauthScope, Subscription};
use std::collections::HashSet;
#[derive(Clone)]
pub struct PgPool;
impl PgPool {
pub fn new() -> Self {
Self
}
}
pub fn select_user(
access_token: &str,
_pg_pool: PgPool,
) -> Result<Subscription, warp::reject::Rejection> {
let mut user = Subscription::default();
if access_token == "TEST_USER" {
user.id = 1;
user.logged_in = true;
user.access_token = "TEST_USER".to_string();
user.email = "user@example.com".to_string();
user.scopes = OauthScope::from(vec![
"read".to_string(),
"write".to_string(),
"follow".to_string(),
]);
} else if access_token == "INVALID" {
return Err(warp::reject::custom("Error: Invalid access token"));
}
Ok(user)
}
pub fn select_user_blocks(_id: i64, _pg_pool: PgPool) -> HashSet<i64> {
HashSet::new()
}
pub fn select_domain_blocks(_pg_pool: PgPool) -> HashSet<String> {
HashSet::new()
}
pub fn user_owns_list(user_id: i64, list_id: i64, _pg_pool: PgPool) -> bool {
user_id == list_id
}


@ -1,196 +0,0 @@
//! `User` struct and related functionality
// #[cfg(test)]
// mod mock_postgres;
// #[cfg(test)]
// use mock_postgres as postgres;
// #[cfg(not(test))]
pub mod postgres;
pub use self::postgres::PgPool;
use super::query::Query;
use crate::log_fatal;
use std::collections::HashSet;
use warp::reject::Rejection;
/// The User (with data read from Postgres)
#[derive(Clone, Debug, PartialEq)]
pub struct Subscription {
pub timeline: Timeline,
pub allowed_langs: HashSet<String>,
pub blocks: Blocks,
}
impl Default for Subscription {
fn default() -> Self {
Self {
timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification),
allowed_langs: HashSet::new(),
blocks: Blocks::default(),
}
}
}
impl Subscription {
pub fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result<Self, Rejection> {
let user = match q.access_token.clone() {
Some(token) => postgres::select_user(&token, pool.clone())?,
None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?,
None => UserData::public(),
};
Ok(Subscription {
timeline: Timeline::from_query_and_user(&q, &user, pool.clone())?,
allowed_langs: user.allowed_langs,
blocks: Blocks {
blocking_users: postgres::select_blocking_users(user.id, pool.clone()),
blocked_users: postgres::select_blocked_users(user.id, pool.clone()),
blocked_domains: postgres::select_blocked_domains(user.id, pool.clone()),
},
})
}
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub struct Timeline(pub Stream, pub Reach, pub Content);
impl Timeline {
pub fn empty() -> Self {
use {Content::*, Reach::*, Stream::*};
Self(Unset, Local, Notification)
}
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> String {
use {Content::*, Reach::*, Stream::*};
match self {
Timeline(Public, Federated, All) => "timeline:public".into(),
Timeline(Public, Local, All) => "timeline:public:local".into(),
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
Timeline(Hashtag(id), Federated, All) => format!(
"timeline:hashtag:{}",
hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
),
Timeline(Hashtag(id), Local, All) => format!(
"timeline:hashtag:{}:local",
hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
),
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
Timeline(one, _two, _three) => {
log_fatal!("Supposedly impossible timeline reached: {:?}", one)
}
}
}
pub fn from_redis_raw_timeline(raw_timeline: &str, hashtag: Option<i64>) -> Self {
use {Content::*, Reach::*, Stream::*};
match raw_timeline.split(':').collect::<Vec<&str>>()[..] {
["public"] => Timeline(Public, Federated, All),
["public", "local"] => Timeline(Public, Local, All),
["public", "media"] => Timeline(Public, Federated, Media),
["public", "local", "media"] => Timeline(Public, Local, Media),
["hashtag", _tag] => Timeline(Hashtag(hashtag.unwrap()), Federated, All),
["hashtag", _tag, "local"] => Timeline(Hashtag(hashtag.unwrap()), Local, All),
[id] => Timeline(User(id.parse().unwrap()), Federated, All),
[id, "notification"] => Timeline(User(id.parse().unwrap()), Federated, Notification),
["list", id] => Timeline(List(id.parse().unwrap()), Federated, All),
["direct", id] => Timeline(Direct(id.parse().unwrap()), Federated, All),
// Other endpoints don't exist:
[..] => log_fatal!("Unexpected channel from Redis: {}", raw_timeline),
}
}
fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result<Self, Rejection> {
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
let id_from_hashtag = || postgres::select_list_id(&q.hashtag, pool.clone());
let user_owns_list = || postgres::user_owns_list(user.id, q.list, pool.clone());
Ok(match q.stream.as_ref() {
"public" => match q.media {
true => Timeline(Public, Federated, Media),
false => Timeline(Public, Federated, All),
},
"public:local" => match q.media {
true => Timeline(Public, Local, Media),
false => Timeline(Public, Local, All),
},
"public:media" => Timeline(Public, Federated, Media),
"public:local:media" => Timeline(Public, Local, Media),
"hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All),
"hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All),
"user" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
"user:notification" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, Notification),
false => Err(custom("Error: Missing access token"))?,
},
"list" => match user.scopes.contains(&Lists) && user_owns_list() {
true => Timeline(List(q.list), Federated, All),
false => Err(warp::reject::custom("Error: Missing access token"))?,
},
"direct" => match user.scopes.contains(&Statuses) {
true => Timeline(Direct(user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
other => {
log::warn!("Request for nonexistent endpoint: `{}`", other);
Err(custom("Error: Nonexistent endpoint"))?
}
})
}
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Stream {
User(i64),
List(i64),
Direct(i64),
Hashtag(i64),
Public,
Unset,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Reach {
Local,
Federated,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Content {
All,
Media,
Notification,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
Read,
Statuses,
Notifications,
Lists,
}
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Blocks {
pub blocked_domains: HashSet<String>,
pub blocked_users: HashSet<i64>,
pub blocking_users: HashSet<i64>,
}
#[derive(Clone, Debug, PartialEq)]
pub struct UserData {
id: i64,
allowed_langs: HashSet<String>,
scopes: HashSet<Scope>,
}
impl UserData {
fn public() -> Self {
Self {
id: -1,
allowed_langs: HashSet::new(),
scopes: HashSet::new(),
}
}
}


@ -1,225 +0,0 @@
//! Postgres queries
use crate::{
config,
parse_client_request::subscription::{Scope, UserData},
};
use ::postgres;
use r2d2_postgres::PostgresConnectionManager;
use std::collections::HashSet;
use warp::reject::Rejection;
#[derive(Clone, Debug)]
pub struct PgPool(pub r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>);
impl PgPool {
pub fn new(pg_cfg: config::PostgresConfig) -> Self {
let mut cfg = postgres::Config::new();
cfg.user(&pg_cfg.user)
.host(&*pg_cfg.host.to_string())
.port(*pg_cfg.port)
.dbname(&pg_cfg.database);
if let Some(password) = &*pg_cfg.password {
cfg.password(password);
};
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
let pool = r2d2::Pool::builder()
.max_size(10)
.build(manager)
.expect("Can connect to local postgres");
Self(pool)
}
}
pub fn select_user(token: &str, pool: PgPool) -> Result<UserData, Rejection> {
let mut conn = pool.0.get().unwrap();
let query_rows = conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) {
let id = result_columns.get(1);
let allowed_langs = result_columns
.try_get::<_, Vec<_>>(2)
.unwrap_or_else(|_| Vec::new())
.into_iter()
.collect();
let mut scopes: HashSet<Scope> = result_columns
.get::<_, String>(3)
.split(' ')
.filter_map(|scope| match scope {
"read" => Some(Scope::Read),
"read:statuses" => Some(Scope::Statuses),
"read:notifications" => Some(Scope::Notifications),
"read:lists" => Some(Scope::Lists),
"write" | "follow" => None, // ignore write scopes
unexpected => {
log::warn!("Ignoring unknown scope `{}`", unexpected);
None
}
})
.collect();
// We don't need to separately track read auth - it's just all three others
if scopes.remove(&Scope::Read) {
scopes.insert(Scope::Statuses);
scopes.insert(Scope::Notifications);
scopes.insert(Scope::Lists);
}
Ok(UserData {
id,
allowed_langs,
scopes,
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
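The `read`-scope expansion above is worth isolating: `read` is shorthand for the three specific read scopes, so it is replaced rather than tracked separately. A self-contained sketch of that rule (with a local mirror of the `Scope` enum purely for illustration, and silently dropping the unknown scopes the real code warns about):

use std::collections::HashSet;

#[derive(PartialEq, Eq, Hash, Debug)]
enum Scope { Read, Statuses, Notifications, Lists } // mirrors the enum in this codebase

fn expand(raw: &str) -> HashSet<Scope> {
    let mut scopes: HashSet<Scope> = raw
        .split(' ')
        .filter_map(|s| match s {
            "read" => Some(Scope::Read),
            "read:statuses" => Some(Scope::Statuses),
            "read:notifications" => Some(Scope::Notifications),
            "read:lists" => Some(Scope::Lists),
            _ => None, // write/follow/unknown scopes are ignored here
        })
        .collect();
    // `read` implies all three specific read scopes
    if scopes.remove(&Scope::Read) {
        scopes.insert(Scope::Statuses);
        scopes.insert(Scope::Notifications);
        scopes.insert(Scope::Lists);
    }
    scopes
}

#[test]
fn read_expands_to_all_three() {
    let scopes = expand("read write follow");
    assert_eq!(scopes.len(), 3);
    assert!(scopes.contains(&Scope::Statuses));
}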
pub fn select_list_id(tag_name: &String, pg_pool: PgPool) -> Result<i64, Rejection> {
let mut conn = pg_pool.0.get().unwrap();
// For the Postgres query, `id` = the hashtag's id in the `tags` table
let rows = &conn
.query(
"
SELECT id
FROM tags
WHERE name = $1
LIMIT 1",
&[&tag_name],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
Some(row) => Ok(row.get(0)),
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
}
}
pub fn select_hashtag_name(tag_id: &i64, pg_pool: PgPool) -> Result<String, Rejection> {
let mut conn = pg_pool.0.get().unwrap();
// For the Postgres query, `id` = the hashtag's id; `name` = the hashtag's name
let rows = &conn
.query(
"
SELECT name
FROM tags
WHERE id = $1
LIMIT 1",
&[&tag_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
Some(row) => Ok(row.get(0)),
None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
}
}
/// Query Postgres for everyone the user has blocked or muted
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_users(user_id: i64, pg_pool: PgPool) -> HashSet<i64> {
// "
// SELECT
// 1
// FROM blocks
// WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
// OR (account_id = $2 AND target_account_id = $1)
// UNION SELECT
// 1
// FROM mutes
// WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
// , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
pg_pool
.0
.get()
.unwrap()
.query(
"
SELECT target_account_id
FROM blocks
WHERE account_id = $1
UNION SELECT target_account_id
FROM mutes
WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for everyone who has blocked the user
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(user_id: i64, pg_pool: PgPool) -> HashSet<i64> {
pg_pool
.0
.get()
.unwrap()
.query(
"
SELECT account_id
FROM blocks
WHERE target_account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Query Postgres for all current domain blocks
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(user_id: i64, pg_pool: PgPool) -> HashSet<String> {
pg_pool
.0
.get()
.unwrap()
.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Test whether a user owns a list
pub fn user_owns_list(user_id: i64, list_id: i64, pg_pool: PgPool) -> bool {
let mut conn = pg_pool.0.get().unwrap();
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT id, account_id
FROM lists
WHERE id = $1
LIMIT 1",
&[&list_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
None => false,
Some(row) => {
let list_owner_id: i64 = row.get(1);
list_owner_id == user_id
}
}
}


@ -1,48 +1,9 @@
//! Filters for the WebSocket endpoint
use super::{
query::{self, Query},
subscription::{PgPool, Subscription},
};
use warp::{filters::BoxedFilter, path, Filter};
/// WebSocket filters
fn parse_query() -> BoxedFilter<(Query,)> {
path!("api" / "v1" / "streaming")
.and(path::end())
.and(warp::query())
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|stream: query::Stream,
auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: stream.stream,
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
}
pub fn extract_user_and_token_or_reject(
pg_pool: PgPool,
whitelist_mode: bool,
) -> BoxedFilter<(Subscription, Option<String>)> {
parse_query()
.and(query::OptionalAccessToken::from_ws_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
.and(query::OptionalAccessToken::from_ws_header())
.boxed()
}
// #[cfg(test)]
// mod test {


@ -19,29 +19,29 @@ use super::receiver::Receiver;
use crate::{
config,
messages::Event,
parse_client_request::subscription::{PgPool, Stream::Public, Subscription, Timeline},
parse_client_request::{Stream::Public, Subscription, Timeline},
};
use futures::{
Async::{self, NotReady, Ready},
Poll,
};
use std::sync;
use std::sync::{Arc, Mutex};
use tokio::io::Error;
use uuid::Uuid;
/// Struct for managing all Redis streams.
#[derive(Clone, Debug)]
pub struct ClientAgent {
receiver: sync::Arc<sync::Mutex<Receiver>>,
id: uuid::Uuid,
receiver: Arc<Mutex<Receiver>>,
id: Uuid,
pub subscription: Subscription,
}
impl ClientAgent {
/// Create a new `ClientAgent` with no shared data.
pub fn blank(redis_cfg: config::RedisConfig, pg_pool: PgPool) -> Self {
pub fn blank(redis_cfg: config::RedisConfig) -> Self {
ClientAgent {
receiver: sync::Arc::new(sync::Mutex::new(Receiver::new(redis_cfg, pg_pool))),
receiver: Arc::new(Mutex::new(Receiver::new(redis_cfg))),
id: Uuid::default(),
subscription: Subscription::default(),
}
@ -70,7 +70,11 @@ impl ClientAgent {
self.subscription = subscription;
let start_time = Instant::now();
let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)");
receiver.manage_new_timeline(self.id, self.subscription.timeline);
receiver.manage_new_timeline(
self.id,
self.subscription.timeline,
self.subscription.hashtag_name.clone(),
);
log::info!("init_for_user had lock for: {:?}", start_time.elapsed());
}
}


@ -0,0 +1,103 @@
use super::ClientAgent;
use warp::ws::WebSocket;
use futures::{future::Future, stream::Stream, Async};
use log;
use std::time::{Duration, Instant};
pub struct EventStream;
impl EventStream {
/// Send a stream of replies to a WebSocket client.
pub fn to_ws(
socket: WebSocket,
mut client_agent: ClientAgent,
update_interval: Duration,
) -> impl Future<Item = (), Error = ()> {
let (ws_tx, mut ws_rx) = socket.split();
let timeline = client_agent.subscription.timeline;
// Create a pipe
let (tx, rx) = futures::sync::mpsc::unbounded();
// Send one end of it to a different thread and tell that end to forward whatever it gets
// on to the websocket client
warp::spawn(
rx.map_err(|()| -> warp::Error { unreachable!() })
.forward(ws_tx)
.map(|_r| ())
.map_err(|e| match e.to_string().as_ref() {
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
_ => log::warn!("websocket send error: {}", e),
}),
);
// Yield new events for as long as the client is still connected
let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while(
move |_| match ws_rx.poll() {
Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true),
Ok(Async::Ready(None)) => {
// TODO: consider whether we should manually drop closed connections here
log::info!("Client closed WebSocket connection for {:?}", timeline);
futures::future::ok(false)
}
Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => {
// no err, just closed Unix socket
log::info!("Client closed WebSocket connection for {:?}", timeline);
futures::future::ok(false)
}
Err(e) => {
log::warn!("Error in {:?}: {}", timeline, e);
futures::future::ok(false)
}
},
);
let mut time = Instant::now();
// Every time you get an event from that stream, send it through the pipe
event_stream
.for_each(move |_instant| {
if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
tx.unbounded_send(warp::ws::Message::text(msg.to_json_string()))
.expect("No send error");
};
if time.elapsed() > Duration::from_secs(30) {
tx.unbounded_send(warp::ws::Message::text("{}"))
.expect("Can ping");
time = Instant::now();
}
Ok(())
})
.then(move |result| {
// TODO: consider whether we should manually drop closed connections here
log::info!("WebSocket connection for {:?} closed.", timeline);
result
})
.map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e))
}
pub fn to_sse(
mut client_agent: ClientAgent,
connection: warp::sse::Sse,
update_interval: Duration,
) -> impl warp::reply::Reply {
let event_stream =
tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| {
match client_agent.poll() {
Ok(Async::Ready(Some(event))) => Some((
warp::sse::event(event.event_name()),
warp::sse::data(event.payload().unwrap_or_else(String::new)),
)),
_ => None,
}
});
connection.reply(
warp::sse::keep_alive()
.interval(Duration::from_secs(30))
.text("thump".to_string())
.stream(event_stream),
)
}
}
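The channel-and-forward shape of `to_ws` can be shown in miniature with plain futures 0.1 types in place of warp's WebSocket halves. A sketch, assuming the futures 0.1 and tokio 0.1 crates this codebase builds against:

use futures::{stream::Stream, sync::mpsc};

fn main() {
    let (tx, rx) = mpsc::unbounded::<String>();
    // One end of the pipe becomes a future that handles whatever arrives;
    // in `to_ws` this role is played by `rx.forward(ws_tx)` spawned on warp's executor.
    let consumer = rx.for_each(|msg| {
        println!("forwarding: {}", msg);
        Ok(())
    });
    tx.unbounded_send("hello".to_string()).expect("receiver alive");
    drop(tx); // close the channel so the consumer future completes
    tokio::run(consumer);
}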


@ -1,87 +0,0 @@
use crate::log_fatal;
use crate::messages::Event;
use serde_json::Value;
use std::{collections::HashSet, string::String};
use strum_macros::Display;
#[derive(Debug, Display, Clone)]
pub enum Message {
Update(Status),
Conversation(Value),
Notification(Value),
Delete(String),
FiltersChanged,
Announcement(AnnouncementType),
UnknownEvent(String, Value),
}
#[derive(Debug, Clone)]
pub struct Status(Value);
#[derive(Debug, Clone)]
pub enum AnnouncementType {
New(Value),
Delete(String),
Reaction(Value),
}
impl Message {
// pub fn from_json(event: Event) -> Self {
// use AnnouncementType::*;
// match event.event.as_ref() {
// "update" => Self::Update(Status(event.payload)),
// "conversation" => Self::Conversation(event.payload),
// "notification" => Self::Notification(event.payload),
// "delete" => Self::Delete(
// event
// .payload
// .as_str()
// .unwrap_or_else(|| log_fatal!("Could not process `payload` in {:?}", event))
// .to_string(),
// ),
// "filters_changed" => Self::FiltersChanged,
// "announcement" => Self::Announcement(New(event.payload)),
// "announcement.reaction" => Self::Announcement(Reaction(event.payload)),
// "announcement.delete" => Self::Announcement(Delete(
// event
// .payload
// .as_str()
// .unwrap_or_else(|| log_fatal!("Could not process `payload` in {:?}", event))
// .to_string(),
// )),
// other => {
// log::warn!("Received unexpected `event` from Redis: {}", other);
// Self::UnknownEvent(event.event.to_string(), event.payload)
// }
// }
// }
pub fn event(&self) -> String {
use AnnouncementType::*;
match self {
Self::Update(_) => "update",
Self::Conversation(_) => "conversation",
Self::Notification(_) => "notification",
Self::Announcement(New(_)) => "announcement",
Self::Announcement(Reaction(_)) => "announcement.reaction",
Self::UnknownEvent(event, _) => &event,
Self::Delete(_) => "delete",
Self::Announcement(Delete(_)) => "announcement.delete",
Self::FiltersChanged => "filters_changed",
}
.to_string()
}
pub fn payload(&self) -> String {
use AnnouncementType::*;
match self {
Self::Update(status) => status.0.to_string(),
Self::Conversation(value)
| Self::Notification(value)
| Self::Announcement(New(value))
| Self::Announcement(Reaction(value))
| Self::UnknownEvent(_, value) => value.to_string(),
Self::Delete(id) | Self::Announcement(Delete(id)) => id.clone(),
Self::FiltersChanged => "".to_string(),
}
}
}


@ -1,103 +1,14 @@
//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis.
pub mod client_agent;
pub mod receiver;
pub mod redis;
pub use client_agent::ClientAgent;
use futures::{future::Future, stream::Stream, Async};
use log;
use std::time::{Duration, Instant};
mod client_agent;
mod event_stream;
mod receiver;
mod redis;
/// Send a stream of replies to a Server Sent Events client.
pub fn send_updates_to_sse(
mut client_agent: ClientAgent,
connection: warp::sse::Sse,
update_interval: Duration,
) -> impl warp::reply::Reply {
let event_stream =
tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| {
match client_agent.poll() {
Ok(Async::Ready(Some(event))) => Some((
warp::sse::event(event.event_name()),
warp::sse::data(event.payload().unwrap_or_else(String::new)),
)),
_ => None,
}
});
pub use {client_agent::ClientAgent, event_stream::EventStream};
connection.reply(
warp::sse::keep_alive()
.interval(Duration::from_secs(30))
.text("thump".to_string())
.stream(event_stream),
)
}
use warp::ws::WebSocket;
/// Send a stream of replies to a WebSocket client.
pub fn send_updates_to_ws(
socket: WebSocket,
mut client_agent: ClientAgent,
update_interval: Duration,
) -> impl Future<Item = (), Error = ()> {
let (ws_tx, mut ws_rx) = socket.split();
let timeline = client_agent.subscription.timeline;
// Create a pipe
let (tx, rx) = futures::sync::mpsc::unbounded();
// Send one end of it to a different thread and tell that end to forward whatever it gets
// on to the websocket client
warp::spawn(
rx.map_err(|()| -> warp::Error { unreachable!() })
.forward(ws_tx)
.map(|_r| ())
.map_err(|e| match e.to_string().as_ref() {
"IO error: Broken pipe (os error 32)" => (), // just closed unix socket
_ => log::warn!("websocket send error: {}", e),
}),
);
// Yield new events for as long as the client is still connected
let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while(
move |_| match ws_rx.poll() {
Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true),
Ok(Async::Ready(None)) => {
// TODO: consider whether we should manually drop closed connections here
log::info!("Client closed WebSocket connection for {:?}", timeline);
futures::future::ok(false)
}
Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => {
// no err, just closed Unix socket
log::info!("Client closed WebSocket connection for {:?}", timeline);
futures::future::ok(false)
}
Err(e) => {
log::warn!("Error in {:?}: {}", timeline, e);
futures::future::ok(false)
}
},
);
let mut time = Instant::now();
// Every time you get an event from that stream, send it through the pipe
event_stream
.for_each(move |_instant| {
if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
tx.unbounded_send(warp::ws::Message::text(msg.to_json_string()))
.expect("No send error");
};
if time.elapsed() > Duration::from_secs(30) {
tx.unbounded_send(warp::ws::Message::text("{}"))
.expect("Can ping");
time = Instant::now();
}
Ok(())
})
.then(move |result| {
// TODO: consider whether we should manually drop closed connections here
log::info!("WebSocket connection for {:?} closed.", timeline);
result
})
.map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e))
}
#[cfg(test)]
pub use receiver::process_messages;
#[cfg(test)]
pub use receiver::{MessageQueues, MsgQueue};
#[cfg(test)]
pub use redis::redis_msg::RedisMsg;


@ -1,5 +1,5 @@
use crate::messages::Event;
use crate::parse_client_request::subscription::Timeline;
use crate::parse_client_request::Timeline;
use std::{
collections::{HashMap, VecDeque},
fmt,


@ -2,62 +2,70 @@
//! polled by the correct `ClientAgent`. Also manages subscriptions and
//! unsubscriptions to/from Redis.
mod message_queues;
pub use message_queues::{MessageQueues, MsgQueue};
use crate::{
config::{self, RedisInterval},
log_fatal,
config,
err::RedisParseErr,
messages::Event,
parse_client_request::subscription::{self, postgres, PgPool, Timeline},
parse_client_request::{Stream, Timeline},
pubsub_cmd,
redis_to_client_stream::redis::{redis_cmd, RedisConn, RedisStream},
redis_to_client_stream::redis::redis_msg::RedisMsg,
redis_to_client_stream::redis::{redis_cmd, RedisConn},
};
use futures::{Async, Poll};
use lru::LruCache;
pub use message_queues::{MessageQueues, MsgQueue};
use std::{collections::HashMap, net, time::Instant};
use tokio::io::AsyncRead;
use std::{
collections::HashMap,
io::Read,
net, str,
time::{Duration, Instant},
};
use tokio::io::Error;
use uuid::Uuid;
/// The item that streams from Redis and is polled by the `ClientAgent`
#[derive(Debug)]
pub struct Receiver {
pub pubsub_connection: RedisStream,
pub pubsub_connection: net::TcpStream,
secondary_redis_connection: net::TcpStream,
redis_poll_interval: RedisInterval,
redis_poll_interval: Duration,
redis_polled_at: Instant,
timeline: Timeline,
manager_id: Uuid,
pub msg_queues: MessageQueues,
clients_per_timeline: HashMap<Timeline, i32>,
cache: Cache,
pool: PgPool,
redis_input: Vec<u8>,
redis_namespace: Option<String>,
}
#[derive(Debug)]
pub struct Cache {
// TODO: eventually, it might make sense to have Mastodon publish to timelines with
// the tag number instead of the tag name. This would save us from dealing
// with a cache here and would be consistent with how lists/users are handled.
id_to_hashtag: LruCache<i64, String>,
pub hashtag_to_id: LruCache<String, i64>,
}
impl Cache {
fn new(size: usize) -> Self {
Self {
id_to_hashtag: LruCache::new(size),
hashtag_to_id: LruCache::new(size),
}
}
}
impl Receiver {
/// Create a new `Receiver`, with its own Redis connections (but, as yet, no
/// active subscriptions).
pub fn new(redis_cfg: config::RedisConfig, pool: PgPool) -> Self {
pub fn new(redis_cfg: config::RedisConfig) -> Self {
let redis_namespace = redis_cfg.namespace.clone();
let RedisConn {
primary: pubsub_connection,
secondary: secondary_redis_connection,
namespace: redis_namespace,
polling_interval: redis_poll_interval,
} = RedisConn::new(redis_cfg);
Self {
pubsub_connection: RedisStream::from_stream(pubsub_connection)
.with_namespace(redis_namespace),
pubsub_connection,
secondary_redis_connection,
redis_poll_interval,
redis_polled_at: Instant::now(),
@ -65,8 +73,12 @@ impl Receiver {
manager_id: Uuid::default(),
msg_queues: MessageQueues(HashMap::new()),
clients_per_timeline: HashMap::new(),
cache: Cache::new(1000), // should this be a run-time option?
pool,
cache: Cache {
id_to_hashtag: LruCache::new(1000),
hashtag_to_id: LruCache::new(1000),
}, // should these be run-time options?
redis_input: Vec::new(),
redis_namespace,
}
}
@ -76,12 +88,15 @@ impl Receiver {
/// Note: this method calls `subscribe_or_unsubscribe_as_needed`,
/// so Redis PubSub subscriptions are only updated when a new timeline
/// comes under management for the first time.
pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: Timeline) {
self.manager_id = manager_id;
self.timeline = timeline;
self.msg_queues
.insert(self.manager_id, MsgQueue::new(timeline));
self.subscribe_or_unsubscribe_as_needed(timeline);
pub fn manage_new_timeline(&mut self, id: Uuid, tl: Timeline, hashtag: Option<String>) {
self.timeline = tl;
if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (hashtag, tl) {
self.cache.id_to_hashtag.put(id, hashtag.clone());
self.cache.hashtag_to_id.put(hashtag, id);
};
self.msg_queues.insert(id, MsgQueue::new(tl));
self.subscribe_or_unsubscribe_as_needed(tl);
}
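A hypothetical test sketch of the two-way cache seeding just above (assumes the `lru` crate; the id and tag name are made up):

#[cfg(test)]
mod hashtag_seeding_sketch {
    use lru::LruCache;

    #[test]
    fn both_directions_are_cached() {
        let mut id_to_hashtag: LruCache<i64, String> = LruCache::new(10);
        let mut hashtag_to_id: LruCache<String, i64> = LruCache::new(10);
        // What `manage_new_timeline` does when handed `Some("rust")` for tag 8:
        id_to_hashtag.put(8, "rust".to_string());
        hashtag_to_id.put("rust".to_string(), 8);
        assert_eq!(id_to_hashtag.get(&8), Some(&"rust".to_string()));
        assert_eq!(hashtag_to_id.get(&"rust".to_string()), Some(&8));
    }
}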
/// Set the `Receiver`'s manager_id and target_timeline fields to the appropriate
@ -91,26 +106,6 @@ impl Receiver {
self.timeline = timeline;
}
fn if_hashtag_timeline_get_hashtag_name(&mut self, timeline: Timeline) -> Option<String> {
use subscription::Stream::*;
if let Timeline(Hashtag(id), _, _) = timeline {
let cached_tag = self.cache.id_to_hashtag.get(&id).map(String::from);
let tag = match cached_tag {
Some(tag) => tag,
None => {
let new_tag = postgres::select_hashtag_name(&id, self.pool.clone())
.unwrap_or_else(|_| log_fatal!("No hashtag associated with tag #{}", &id));
self.cache.hashtag_to_id.put(new_tag.clone(), id);
self.cache.id_to_hashtag.put(id, new_tag.clone());
new_tag.to_string()
}
};
Some(tag)
} else {
None
}
}
/// Drop any PubSub subscriptions that don't have active clients and check
/// that there's a subscription to the current one. If there isn't, then
/// subscribe to it.
@ -121,8 +116,10 @@ impl Receiver {
// Record the lower number of clients subscribed to that channel
for change in timelines_to_modify {
let timeline = change.timeline;
let hashtag = self.if_hashtag_timeline_get_hashtag_name(timeline);
let hashtag = hashtag.as_ref();
let hashtag = match timeline {
Timeline(Stream::Hashtag(id), _, _) => self.cache.id_to_hashtag.get(&id),
_non_hashtag_timeline => None,
};
let count_of_subscribed_clients = self
.clients_per_timeline
@ -157,10 +154,27 @@ impl futures::stream::Stream for Receiver {
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let (timeline, id) = (self.timeline.clone(), self.manager_id);
if self.redis_polled_at.elapsed() > *self.redis_poll_interval {
self.pubsub_connection
.poll_redis(&mut self.cache.hashtag_to_id, &mut self.msg_queues);
self.redis_polled_at = Instant::now();
if self.redis_polled_at.elapsed() > self.redis_poll_interval {
let mut buffer = vec![0u8; 6000];
if let Ok(Async::Ready(bytes_read)) = self.poll_read(&mut buffer) {
let binary_input = buffer[..bytes_read].to_vec();
let (input, extra_bytes) = match str::from_utf8(&binary_input) {
Ok(input) => (input, "".as_bytes()),
Err(e) => {
let (valid, after_valid) = binary_input.split_at(e.valid_up_to());
let input = str::from_utf8(valid).expect("Guaranteed by `.valid_up_to`");
(input, after_valid)
}
};
let (cache, namespace) = (&mut self.cache.hashtag_to_id, &self.redis_namespace);
let remaining_input =
process_messages(input, cache, namespace, &mut self.msg_queues);
self.redis_input.extend_from_slice(remaining_input);
self.redis_input.extend_from_slice(extra_bytes);
}
}
// Record current time as last polled time
@ -173,3 +187,49 @@ impl futures::stream::Stream for Receiver {
}
}
}
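The `valid_up_to` handling in `poll` merits a standalone illustration: a socket read can end partway through a multi-byte character, so the valid prefix is parsed now and the trailing bytes are carried into `redis_input` for the next read. A minimal std-only sketch of that split:

fn split_valid_utf8(bytes: &[u8]) -> (&str, &[u8]) {
    match std::str::from_utf8(bytes) {
        Ok(s) => (s, &bytes[bytes.len()..]), // fully valid: empty remainder
        Err(e) => {
            let (valid, after_valid) = bytes.split_at(e.valid_up_to());
            let s = std::str::from_utf8(valid).expect("Guaranteed by `valid_up_to`");
            (s, after_valid)
        }
    }
}

#[test]
fn partial_multibyte_char_is_carried_over() {
    // "é" is 0xC3 0xA9 in UTF-8; cut the buffer between the two bytes.
    let bytes = b"timeline:caf\xC3";
    let (valid, rest) = split_valid_utf8(bytes);
    assert_eq!(valid, "timeline:caf");
    assert_eq!(rest.len(), 1);
    assert_eq!(rest[0], 0xC3);
}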
impl Read for Receiver {
fn read(&mut self, buffer: &mut [u8]) -> Result<usize, std::io::Error> {
self.pubsub_connection.read(buffer)
}
}
impl AsyncRead for Receiver {
fn poll_read(&mut self, buf: &mut [u8]) -> Poll<usize, std::io::Error> {
match self.read(buf) {
Ok(t) => Ok(Async::Ready(t)),
Err(_) => Ok(Async::NotReady),
}
}
}
#[must_use]
pub fn process_messages<'a>(
input: &'a str,
mut cache: &mut LruCache<String, i64>,
namespace: &Option<String>,
msg_queues: &mut MessageQueues,
) -> &'a [u8] {
let mut remaining_input = input;
use RedisMsg::*;
loop {
match RedisMsg::from_raw(&mut remaining_input, &mut cache, namespace) {
Ok((EventMsg(timeline, event), rest)) => {
for msg_queue in msg_queues.values_mut() {
if msg_queue.timeline == timeline {
msg_queue.messages.push_back(event.clone());
}
}
remaining_input = rest;
}
Ok((SubscriptionMsg, rest)) | Ok((MsgForDifferentNamespace, rest)) => {
remaining_input = rest;
}
Err(RedisParseErr::Incomplete) => break,
Err(RedisParseErr::Unrecoverable) => {
panic!("Failed parsing Redis msg: {}", &remaining_input)
}
};
}
remaining_input.as_bytes()
}
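The error contract here: `Incomplete` means the buffer ends partway through a message (stop and keep the tail for the next read), while `Unrecoverable` means the input can never parse. The same loop shape over a trivially framed protocol, as a std-only toy:

// Messages are `<text>;` — a trailing fragment without `;` is incomplete.
fn drain_complete<'a>(input: &'a str, out: &mut Vec<String>) -> &'a str {
    let mut rest = input;
    while let Some(end) = rest.find(';') {
        out.push(rest[..end].to_string());
        rest = &rest[end + 1..];
    }
    rest // carried over and prepended to the next read
}

#[test]
fn keeps_the_incomplete_tail() {
    let mut msgs = Vec::new();
    let tail = drain_complete("one;two;thr", &mut msgs);
    assert_eq!(msgs, vec!["one", "two"]);
    assert_eq!(tail, "thr");
}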


@ -1,9 +1,5 @@
pub mod redis_cmd;
pub mod redis_connection;
pub mod redis_msg;
pub mod redis_stream;
pub use redis_connection::RedisConn;
pub use redis_stream::RedisStream;


@ -7,7 +7,7 @@ macro_rules! pubsub_cmd {
($cmd:expr, $self:expr, $tl:expr) => {{
use std::io::Write;
log::info!("Sending {} command to {}", $cmd, $tl);
let namespace = $self.pubsub_connection.namespace.clone();
let namespace = $self.redis_namespace.clone();
$self
.pubsub_connection


@ -1,13 +1,12 @@
use super::redis_cmd;
use crate::config::{RedisConfig, RedisInterval, RedisNamespace};
use crate::config::RedisConfig;
use crate::err;
use std::{io::Read, io::Write, net, time};
use std::{io::Read, io::Write, net, time::Duration};
pub struct RedisConn {
pub primary: net::TcpStream,
pub secondary: net::TcpStream,
pub namespace: RedisNamespace,
pub polling_interval: RedisInterval,
pub polling_interval: Duration,
}
fn send_password(mut conn: net::TcpStream, password: &str) -> net::TcpStream {
@ -68,7 +67,7 @@ impl RedisConn {
conn = send_password(conn, &password);
}
conn = send_test_ping(conn);
conn.set_read_timeout(Some(time::Duration::from_millis(10)))
conn.set_read_timeout(Some(Duration::from_millis(10)))
.expect("Can set read timeout for Redis connection");
if let Some(db) = &*redis_cfg.db {
conn = set_db(conn, db);
@ -86,8 +85,7 @@ impl RedisConn {
Self {
primary: primary_conn,
secondary: secondary_conn,
namespace: redis_cfg.namespace,
polling_interval: redis_cfg.polling_interval,
polling_interval: *redis_cfg.polling_interval,
}
}
}
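A standalone sketch of the connection pattern above: the short read timeout keeps blocking `read` calls from stalling the polling loop (std-only; the address is hypothetical):

use std::{net::TcpStream, time::Duration};

fn connect_for_polling(addr: &str) -> std::io::Result<TcpStream> {
    let conn = TcpStream::connect(addr)?; // e.g. "127.0.0.1:6379"
    // Reads return promptly when no data is waiting, so the caller can poll.
    conn.set_read_timeout(Some(Duration::from_millis(10)))?;
    Ok(conn)
}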


@ -18,26 +18,31 @@
//! three characters, the second is a bulk string with ten characters, and the third is a
//! bulk string with 1,386 characters.
use crate::{log_fatal, messages::Event, parse_client_request::subscription::Timeline};
use crate::{
err::{RedisParseErr, TimelineErr},
messages::Event,
parse_client_request::Timeline,
};
use lru::LruCache;
type Parser<'a, Item> = Result<(Item, &'a str), ParseErr>;
#[derive(Debug)]
pub enum ParseErr {
Incomplete,
Unrecoverable,
}
use ParseErr::*;
type Parser<'a, Item> = Result<(Item, &'a str), RedisParseErr>;
/// A message that has been parsed from an incoming raw message from Redis.
#[derive(Debug, Clone)]
pub enum RedisMsg {
EventMsg(Timeline, Event),
SubscriptionMsg,
MsgForDifferentNamespace,
}
use RedisParseErr::*;
type Hashtags = LruCache<String, i64>;
impl RedisMsg {
pub fn from_raw<'a>(input: &'a str, cache: &mut Hashtags, prefix: usize) -> Parser<'a, Self> {
pub fn from_raw<'a>(
input: &'a str,
cache: &mut Hashtags,
namespace: &Option<String>,
) -> Parser<'a, Self> {
// No need to parse the Redis Array header, just skip it
let input = input.get("*3\r\n".len()..).ok_or(Incomplete)?;
let (command, rest) = parse_redis_bulk_string(&input)?;
@ -46,14 +51,16 @@ impl RedisMsg {
// Messages look like;
// $10\r\ntimeline:4\r\n
// $1386\r\n{\"event\":\"update\",\"payload\"...\"queued_at\":1569623342825}\r\n
let (raw_timeline, rest) = parse_redis_bulk_string(&rest)?;
let (timeline, rest) = parse_redis_bulk_string(&rest)?;
let (msg_txt, rest) = parse_redis_bulk_string(&rest)?;
let event: Event = serde_json::from_str(&msg_txt).map_err(|_| Unrecoverable)?;
let raw_timeline = &raw_timeline.get(prefix..).ok_or(Unrecoverable)?;
let event: Event = serde_json::from_str(&msg_txt).unwrap();
let hashtag = hashtag_from_timeline(&raw_timeline, cache);
let timeline = Timeline::from_redis_raw_timeline(&raw_timeline, hashtag);
Ok((Self::EventMsg(timeline, event), rest))
use TimelineErr::*;
match Timeline::from_redis_raw_timeline(timeline, cache, namespace) {
Ok(timeline) => Ok((Self::EventMsg(timeline, event), rest)),
Err(RedisNamespaceMismatch) => Ok((Self::MsgForDifferentNamespace, rest)),
Err(InvalidInput) => Err(RedisParseErr::Unrecoverable),
}
}
"subscribe" | "unsubscribe" => {
// subscription statuses look like:
@ -101,18 +108,3 @@ fn parse_number_at(input: &str) -> Parser<usize> {
let rest = &input.get(number_len..).ok_or(Incomplete)?;
Ok((number, rest))
}
fn hashtag_from_timeline(raw_timeline: &str, hashtag_id_cache: &mut Hashtags) -> Option<i64> {
if raw_timeline.starts_with("hashtag") {
let tag_name = raw_timeline
.split(':')
.nth(1)
.unwrap_or_else(|| log_fatal!("No hashtag found in `{}`", raw_timeline))
.to_string();
let tag_id = *hashtag_id_cache
.get(&tag_name)
.unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name));
Some(tag_id)
} else {
None
}
}
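As a worked example of the bulk-string grammar described at the top of this file (`$<len>\r\n<payload>\r\n`), here is a standalone parser sketch; it is a hypothetical helper, not the module's actual `parse_redis_bulk_string`:

fn parse_bulk_string(input: &str) -> Option<(&str, &str)> {
    // e.g. `$7\r\nmessage\r\n` parses to ("message", <rest of input>)
    if !input.starts_with('$') {
        return None;
    }
    let after_sigil = &input[1..];
    let newline = after_sigil.find("\r\n")?;
    let len: usize = after_sigil[..newline].parse().ok()?;
    let rest = after_sigil.get(newline + 2..)?;
    let payload = rest.get(..len)?;
    // Skip the payload and its trailing `\r\n`
    Some((payload, rest.get(len + 2..)?))
}

#[test]
fn parses_the_start_of_a_pubsub_message() {
    let input = "$7\r\nmessage\r\n$10\r\ntimeline:1\r\n";
    let (cmd, rest) = parse_bulk_string(input).unwrap();
    assert_eq!(cmd, "message");
    let (channel, rest) = parse_bulk_string(rest).unwrap();
    assert_eq!(channel, "timeline:1");
    assert!(rest.is_empty());
}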


@ -1,127 +0,0 @@
use super::redis_msg::{ParseErr, RedisMsg};
use crate::config::RedisNamespace;
use crate::log_fatal;
use crate::redis_to_client_stream::receiver::MessageQueues;
use futures::{Async, Poll};
use lru::LruCache;
use std::{error::Error, io::Read, net};
use tokio::io::AsyncRead;
#[derive(Debug)]
pub struct RedisStream {
pub inner: net::TcpStream,
incoming_raw_msg: String,
pub namespace: RedisNamespace,
}
impl RedisStream {
pub fn from_stream(inner: net::TcpStream) -> Self {
RedisStream {
inner,
incoming_raw_msg: String::new(),
namespace: RedisNamespace(None),
}
}
pub fn with_namespace(self, namespace: RedisNamespace) -> Self {
RedisStream { namespace, ..self }
}
// Text comes in from redis as a raw stream, which could be more than one message and
// is not guaranteed to end on a message boundary. We need to break it down into
// messages. Incoming messages *are* guaranteed to be RESP arrays (though still not
// guaranteed to end at an array boundary). See https://redis.io/topics/protocol
/// Adds any new Redis messages to the `MsgQueue` for the appropriate `ClientAgent`.
pub fn poll_redis(
&mut self,
hashtag_to_id_cache: &mut LruCache<String, i64>,
queues: &mut MessageQueues,
) {
let mut buffer = vec![0u8; 6000];
if let Ok(Async::Ready(num_bytes_read)) = self.poll_read(&mut buffer) {
let raw_utf = self.as_utf8(buffer, num_bytes_read);
self.incoming_raw_msg.push_str(&raw_utf);
match process_messages(
self.incoming_raw_msg.clone(),
&mut self.namespace.0,
hashtag_to_id_cache,
queues,
) {
Ok(None) => self.incoming_raw_msg.clear(),
Ok(Some(msg_fragment)) => self.incoming_raw_msg = msg_fragment,
Err(e) => {
log::error!("{}", e);
log_fatal!("Could not process RedisStream: {:?}", &self);
}
}
}
}
fn as_utf8(&mut self, cur_buffer: Vec<u8>, size: usize) -> String {
String::from_utf8(cur_buffer[..size].to_vec()).unwrap_or_else(|_| {
let mut new_buffer = vec![0u8; 1];
self.poll_read(&mut new_buffer).unwrap();
let buffer = ([cur_buffer, new_buffer]).concat();
self.as_utf8(buffer, size + 1)
})
}
}
type HashtagCache = LruCache<String, i64>;
pub fn process_messages(
raw_msg: String,
namespace: &mut Option<String>,
cache: &mut HashtagCache,
queues: &mut MessageQueues,
) -> Result<Option<String>, Box<dyn Error>> {
let prefix_len = match namespace {
Some(namespace) => format!("{}:timeline:", namespace).len(),
None => "timeline:".len(),
};
let mut input = raw_msg.as_str();
loop {
let rest = match RedisMsg::from_raw(&input, cache, prefix_len) {
Ok((RedisMsg::EventMsg(timeline, event), rest)) => {
for msg_queue in queues.values_mut() {
if msg_queue.timeline == timeline {
msg_queue.messages.push_back(event.clone());
}
}
rest
}
Ok((RedisMsg::SubscriptionMsg, rest)) => rest,
Err(ParseErr::Incomplete) => break,
Err(ParseErr::Unrecoverable) => log_fatal!("Failed parsing Redis msg: {}", &input),
};
input = rest
}
Ok(Some(input.to_string()))
}
impl std::ops::Deref for RedisStream {
type Target = net::TcpStream;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl std::ops::DerefMut for RedisStream {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl Read for RedisStream {
fn read(&mut self, buffer: &mut [u8]) -> Result<usize, std::io::Error> {
self.inner.read(buffer)
}
}
impl AsyncRead for RedisStream {
fn poll_read(&mut self, buf: &mut [u8]) -> Poll<usize, std::io::Error> {
match self.read(buf) {
Ok(t) => Ok(Async::Ready(t)),
Err(_) => Ok(Async::NotReady),
}
}
}