diff --git a/Cargo.lock b/Cargo.lock
index e33dbe8..7cb4f27 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -440,7 +440,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "flodgatt"
-version = "0.6.4"
+version = "0.6.5"
dependencies = [
"criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
diff --git a/benches/parse_redis.rs b/benches/parse_redis.rs
index a683583..c1fd987 100644
--- a/benches/parse_redis.rs
+++ b/benches/parse_redis.rs
@@ -3,95 +3,27 @@ use criterion::criterion_group;
use criterion::criterion_main;
use criterion::Criterion;
-const ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS: &str = "*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:1\r\n$3790\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102775370117886890\",\"created_at\":\"2019-09-11T18:42:19.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"unlisted\",\"language\":\"en\",\"uri\":\"https://mastodon.host/users/federationbot/statuses/102775346916917099\",\"url\":\"https://mastodon.host/@federationbot/102775346916917099\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"
Trending tags:
#neverforget
#4styles
#newpipe
#uber
#mercredifiction
\",\"reblog\":null,\"account\":{\"id\":\"78\",\"username\":\"federationbot\",\"acct\":\"federationbot@mastodon.host\",\"display_name\":\"Federation Bot\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-09-10T15:04:25.559Z\",\"note\":\"Hello, I am mastodon.host official semi bot.
Follow me if you want to have some updates on the view of the fediverse from here ( I only post unlisted ).
I also randomly boost one of my followers toot every hour !
If you don\'t feel confortable with me following you, tell me: unfollow and I\'ll do it :)
If you want me to follow you, just tell me follow !
If you want automatic follow for new users on your instance and you are an instance admin, contact me !
Other commands are private :)
\",\"url\":\"https://mastodon.host/@federationbot\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":16636,\"following_count\":179532,\"statuses_count\":50554,\"emojis\":[],\"fields\":[{\"name\":\"More stats\",\"value\":\"https://mastodon.host/stats.html\",\"verified_at\":null},{\"name\":\"More infos\",\"value\":\"https://mastodon.host/about/more\",\"verified_at\":null},{\"name\":\"Owner/Friend\",\"value\":\"@gled\",\"verified_at\":null}]},\"media_attachments\":[],\"mentions\":[],\"tags\":[{\"name\":\"4styles\",\"url\":\"https://instance.codesections.com/tags/4styles\"},{\"name\":\"neverforget\",\"url\":\"https://instance.codesections.com/tags/neverforget\"},{\"name\":\"mercredifiction\",\"url\":\"https://instance.codesections.com/tags/mercredifiction\"},{\"name\":\"uber\",\"url\":\"https://instance.codesections.com/tags/uber\"},{\"name\":\"newpipe\",\"url\":\"https://instance.codesections.com/tags/newpipe\"}],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1568227693541}\r\n";
-
-/// Parses the Redis message using a Regex.
-///
-/// The naive approach from Flodgatt's proof-of-concept stage.
-mod regex_parse {
- use regex::Regex;
- use serde_json::Value;
-
- pub fn to_json_value(input: String) -> Value {
- if input.ends_with("}\r\n") {
- let messages = input.as_str().split("message").skip(1);
- let regex = Regex::new(r"timeline:(?P.*?)\r\n\$\d+\r\n(?P.*?)\r\n")
- .expect("Hard-codded");
- for message in messages {
- let _timeline = regex.captures(message).expect("Hard-coded timeline regex")
- ["timeline"]
- .to_string();
-
- let redis_msg: Value = serde_json::from_str(
- ®ex.captures(message).expect("Hard-coded value regex")["value"],
- )
- .expect("Valid json");
-
- return redis_msg;
- }
- unreachable!()
- } else {
- unreachable!()
- }
- }
-}
-
-/// Parse with a simplified inline iterator.
-///
-/// Essentially shows best-case performance for producing a serde_json::Value.
-mod parse_inline {
- use serde_json::Value;
- pub fn to_json_value(input: String) -> Value {
- fn print_next_str(mut end: usize, input: &str) -> (usize, String) {
- let mut start = end + 3;
- end = start + 1;
-
- let mut iter = input.chars();
- iter.nth(start);
-
- while iter.next().unwrap().is_digit(10) {
- end += 1;
- }
- let length = &input[start..end].parse::().unwrap();
- start = end + 2;
- end = start + length;
-
- let string = &input[start..end];
- (end, string.to_string())
- }
-
- if input.ends_with("}\r\n") {
- let end = 2;
- let (end, _) = print_next_str(end, &input);
- let (end, _timeline) = print_next_str(end, &input);
- let (_, msg) = print_next_str(end, &input);
- let redis_msg: Value = serde_json::from_str(&msg).unwrap();
- redis_msg
- } else {
- unreachable!()
- }
- }
-}
-
/// Parse using Flodgatt's current functions
mod flodgatt_parse_event {
- use flodgatt::{messages::Event, redis_to_client_stream::receiver::MessageQueues};
use flodgatt::{
- parse_client_request::subscription::Timeline,
- redis_to_client_stream::{receiver::MsgQueue, redis::redis_stream},
+ messages::Event,
+ parse_client_request::Timeline,
+ redis_to_client_stream::{process_messages, MessageQueues, MsgQueue},
};
use lru::LruCache;
use std::collections::HashMap;
use uuid::Uuid;
/// One-time setup, not included in testing time.
- pub fn setup() -> MessageQueues {
+ pub fn setup() -> (LruCache, MessageQueues, Uuid, Timeline) {
+ let mut cache: LruCache = LruCache::new(1000);
let mut queues_map = HashMap::new();
let id = Uuid::default();
- let timeline = Timeline::from_redis_raw_timeline("1", None);
+ let timeline =
+ Timeline::from_redis_raw_timeline("timeline:1", &mut cache, &None).expect("In test");
queues_map.insert(id, MsgQueue::new(timeline));
let queues = MessageQueues(queues_map);
- queues
+ (cache, queues, id, timeline)
}
pub fn to_event_struct(
@@ -101,201 +33,7 @@ mod flodgatt_parse_event {
id: Uuid,
timeline: Timeline,
) -> Event {
- redis_stream::process_messages(input, &mut None, &mut cache, &mut queues).unwrap();
- queues
- .oldest_msg_in_target_queue(id, timeline)
- .expect("In test")
- }
-}
-
-/// Parse using modified a modified version of Flodgatt's current function.
-///
-/// This version is modified to return a serde_json::Value instead of an Event to shows
-/// the performance we would see if we used serde's built-in method for handling weakly
-/// typed JSON instead of our own strongly typed struct.
-mod flodgatt_parse_value {
- use flodgatt::{log_fatal, parse_client_request::subscription::Timeline};
- use lru::LruCache;
- use serde_json::Value;
- use std::{
- collections::{HashMap, VecDeque},
- time::Instant,
- };
- use uuid::Uuid;
- #[derive(Debug)]
- pub struct RedisMsg<'a> {
- pub raw: &'a str,
- pub cursor: usize,
- pub prefix_len: usize,
- }
-
- impl<'a> RedisMsg<'a> {
- pub fn from_raw(raw: &'a str, prefix_len: usize) -> Self {
- Self {
- raw,
- cursor: "*3\r\n".len(), //length of intro header
- prefix_len,
- }
- }
-
- /// Move the cursor from the beginning of a number through its end and return the number
- pub fn process_number(&mut self) -> usize {
- let (mut selected_number, selection_start) = (0, self.cursor);
- while let Ok(number) = self.raw[selection_start..=self.cursor].parse::() {
- self.cursor += 1;
- selected_number = number;
- }
- selected_number
- }
-
- /// In a pubsub reply from Redis, an item can be either the name of the subscribed channel
- /// or the msg payload. Either way, it follows the same format:
- /// `$[LENGTH_OF_ITEM_BODY]\r\n[ITEM_BODY]\r\n`
- pub fn next_field(&mut self) -> String {
- self.cursor += "$".len();
-
- let item_len = self.process_number();
- self.cursor += "\r\n".len();
- let item_start_position = self.cursor;
- self.cursor += item_len;
- let item = self.raw[item_start_position..self.cursor].to_string();
- self.cursor += "\r\n".len();
- item
- }
-
- pub fn extract_raw_timeline_and_message(&mut self) -> (String, Value) {
- let timeline = &self.next_field()[self.prefix_len..];
- let msg_txt = self.next_field();
- let msg_value: Value = serde_json::from_str(&msg_txt)
- .unwrap_or_else(|_| log_fatal!("Invalid JSON from Redis: {:?}", &msg_txt));
- (timeline.to_string(), msg_value)
- }
- }
-
- pub struct MsgQueue {
- pub timeline: Timeline,
- pub messages: VecDeque,
- _last_polled_at: Instant,
- }
-
- pub struct MessageQueues(HashMap);
- impl std::ops::Deref for MessageQueues {
- type Target = HashMap;
- fn deref(&self) -> &Self::Target {
- &self.0
- }
- }
-
- impl std::ops::DerefMut for MessageQueues {
- fn deref_mut(&mut self) -> &mut Self::Target {
- &mut self.0
- }
- }
-
- impl MessageQueues {
- pub fn oldest_msg_in_target_queue(
- &mut self,
- id: Uuid,
- timeline: Timeline,
- ) -> Option {
- self.entry(id)
- .or_insert_with(|| MsgQueue::new(timeline))
- .messages
- .pop_front()
- }
- }
-
- impl MsgQueue {
- pub fn new(timeline: Timeline) -> Self {
- MsgQueue {
- messages: VecDeque::new(),
- _last_polled_at: Instant::now(),
- timeline,
- }
- }
- }
-
- pub fn process_msg(
- raw_utf: String,
- namespace: &Option,
- hashtag_id_cache: &mut LruCache,
- queues: &mut MessageQueues,
- ) {
- // Only act if we have a full message (end on a msg boundary)
- if !raw_utf.ends_with("}\r\n") {
- return;
- };
- let prefix_to_skip = match namespace {
- Some(namespace) => format!("{}:timeline:", namespace),
- None => "timeline:".to_string(),
- };
-
- let mut msg = RedisMsg::from_raw(&raw_utf, prefix_to_skip.len());
-
- while !msg.raw.is_empty() {
- let command = msg.next_field();
- match command.as_str() {
- "message" => {
- let (raw_timeline, msg_value) = msg.extract_raw_timeline_and_message();
- let hashtag = hashtag_from_timeline(&raw_timeline, hashtag_id_cache);
- let timeline = Timeline::from_redis_raw_timeline(&raw_timeline, hashtag);
- for msg_queue in queues.values_mut() {
- if msg_queue.timeline == timeline {
- msg_queue.messages.push_back(msg_value.clone());
- }
- }
- }
-
- "subscribe" | "unsubscribe" => {
- // No msg, so ignore & advance cursor to end
- let _channel = msg.next_field();
- msg.cursor += ":".len();
- let _active_subscriptions = msg.process_number();
- msg.cursor += "\r\n".len();
- }
- cmd => panic!("Invariant violation: {} is unexpected Redis output", cmd),
- };
- msg = RedisMsg::from_raw(&msg.raw[msg.cursor..], msg.prefix_len);
- }
- }
-
- fn hashtag_from_timeline(
- raw_timeline: &str,
- hashtag_id_cache: &mut LruCache,
- ) -> Option {
- if raw_timeline.starts_with("hashtag") {
- let tag_name = raw_timeline
- .split(':')
- .nth(1)
- .unwrap_or_else(|| log_fatal!("No hashtag found in `{}`", raw_timeline))
- .to_string();
-
- let tag_id = *hashtag_id_cache
- .get(&tag_name)
- .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name));
- Some(tag_id)
- } else {
- None
- }
- }
- pub fn setup() -> (LruCache, MessageQueues, Uuid, Timeline) {
- let cache: LruCache = LruCache::new(1000);
- let mut queues_map = HashMap::new();
- let id = Uuid::default();
- let timeline = Timeline::from_redis_raw_timeline("1", None);
- queues_map.insert(id, MsgQueue::new(timeline));
- let queues = MessageQueues(queues_map);
- (cache, queues, id, timeline)
- }
-
- pub fn to_json_value(
- input: String,
- mut cache: &mut LruCache,
- mut queues: &mut MessageQueues,
- id: Uuid,
- timeline: Timeline,
- ) -> Value {
- process_msg(input, &None, &mut cache, &mut queues);
+ process_messages(&input, &mut cache, &mut None, &mut queues);
queues
.oldest_msg_in_target_queue(id, timeline)
.expect("In test")
@@ -303,28 +41,10 @@ mod flodgatt_parse_value {
}
fn criterion_benchmark(c: &mut Criterion) {
- let input = ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS.to_string(); //INPUT.to_string();
+ let input = ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS.to_string();
let mut group = c.benchmark_group("Parse redis RESP array");
- // group.bench_function("parse to Value with a regex", |b| {
- // b.iter(|| regex_parse::to_json_value(black_box(input.clone())))
- // });
- group.bench_function("parse to Value inline", |b| {
- b.iter(|| parse_inline::to_json_value(black_box(input.clone())))
- });
- let (mut cache, mut queues, id, timeline) = flodgatt_parse_value::setup();
- group.bench_function("parse to Value using Flodgatt functions", |b| {
- b.iter(|| {
- black_box(flodgatt_parse_value::to_json_value(
- black_box(input.clone()),
- black_box(&mut cache),
- black_box(&mut queues),
- black_box(id),
- black_box(timeline),
- ))
- })
- });
- let mut queues = flodgatt_parse_event::setup();
+ let (mut cache, mut queues, id, timeline) = flodgatt_parse_event::setup();
group.bench_function("parse to Event using Flodgatt functions", |b| {
b.iter(|| {
black_box(flodgatt_parse_event::to_event_struct(
@@ -340,3 +60,5 @@ fn criterion_benchmark(c: &mut Criterion) {
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
+
+const ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS: &str = "*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:1\r\n$3790\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102775370117886890\",\"created_at\":\"2019-09-11T18:42:19.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"unlisted\",\"language\":\"en\",\"uri\":\"https://mastodon.host/users/federationbot/statuses/102775346916917099\",\"url\":\"https://mastodon.host/@federationbot/102775346916917099\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"Trending tags:
#neverforget
#4styles
#newpipe
#uber
#mercredifiction
\",\"reblog\":null,\"account\":{\"id\":\"78\",\"username\":\"federationbot\",\"acct\":\"federationbot@mastodon.host\",\"display_name\":\"Federation Bot\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-09-10T15:04:25.559Z\",\"note\":\"Hello, I am mastodon.host official semi bot.
Follow me if you want to have some updates on the view of the fediverse from here ( I only post unlisted ).
I also randomly boost one of my followers toot every hour !
If you don\'t feel confortable with me following you, tell me: unfollow and I\'ll do it :)
If you want me to follow you, just tell me follow !
If you want automatic follow for new users on your instance and you are an instance admin, contact me !
Other commands are private :)
\",\"url\":\"https://mastodon.host/@federationbot\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":16636,\"following_count\":179532,\"statuses_count\":50554,\"emojis\":[],\"fields\":[{\"name\":\"More stats\",\"value\":\"https://mastodon.host/stats.html\",\"verified_at\":null},{\"name\":\"More infos\",\"value\":\"https://mastodon.host/about/more\",\"verified_at\":null},{\"name\":\"Owner/Friend\",\"value\":\"@gled\",\"verified_at\":null}]},\"media_attachments\":[],\"mentions\":[],\"tags\":[{\"name\":\"4styles\",\"url\":\"https://instance.codesections.com/tags/4styles\"},{\"name\":\"neverforget\",\"url\":\"https://instance.codesections.com/tags/neverforget\"},{\"name\":\"mercredifiction\",\"url\":\"https://instance.codesections.com/tags/mercredifiction\"},{\"name\":\"uber\",\"url\":\"https://instance.codesections.com/tags/uber\"},{\"name\":\"newpipe\",\"url\":\"https://instance.codesections.com/tags/newpipe\"}],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1568227693541}\r\n";
diff --git a/src/config/deployment_cfg_types.rs b/src/config/deployment_cfg_types.rs
index b09de80..97d018b 100644
--- a/src/config/deployment_cfg_types.rs
+++ b/src/config/deployment_cfg_types.rs
@@ -11,14 +11,14 @@ from_env_var!(
/// The current environment, which controls what file to read other ENV vars from
let name = Env;
let default: EnvInner = EnvInner::Development;
- let (env_var, allowed_values) = ("RUST_ENV", format!("one of: {:?}", EnvInner::variants()));
+ let (env_var, allowed_values) = ("RUST_ENV", &format!("one of: {:?}", EnvInner::variants()));
let from_str = |s| EnvInner::from_str(s).ok();
);
from_env_var!(
/// The address to run Flodgatt on
let name = FlodgattAddr;
let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded"));
- let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)".to_string());
+ let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)");
let from_str = |s| match s {
"localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
_ => s.parse().ok(),
@@ -28,35 +28,35 @@ from_env_var!(
/// How verbosely Flodgatt should log messages
let name = LogLevel;
let default: LogLevelInner = LogLevelInner::Warn;
- let (env_var, allowed_values) = ("RUST_LOG", format!("one of: {:?}", LogLevelInner::variants()));
+ let (env_var, allowed_values) = ("RUST_LOG", &format!("one of: {:?}", LogLevelInner::variants()));
let from_str = |s| LogLevelInner::from_str(s).ok();
);
from_env_var!(
/// A Unix Socket to use in place of a local address
let name = Socket;
let default: Option = None;
- let (env_var, allowed_values) = ("SOCKET", "any string".to_string());
+ let (env_var, allowed_values) = ("SOCKET", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// The time between replies sent via WebSocket
let name = WsInterval;
let default: Duration = Duration::from_millis(100);
- let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string());
+ let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds");
let from_str = |s| s.parse().map(Duration::from_millis).ok();
);
from_env_var!(
/// The time between replies sent via Server Sent Events
let name = SseInterval;
let default: Duration = Duration::from_millis(100);
- let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string());
+ let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds");
let from_str = |s| s.parse().map(Duration::from_millis).ok();
);
from_env_var!(
/// The port to run Flodgatt on
let name = Port;
let default: u16 = 4000;
- let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535".to_string());
+ let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535");
let from_str = |s| s.parse().ok();
);
from_env_var!(
@@ -66,7 +66,7 @@ from_env_var!(
/// (including otherwise public timelines).
let name = WhitelistMode;
let default: bool = false;
- let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false".to_string());
+ let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false");
let from_str = |s| s.parse().ok();
);
/// Permissions for Cross Origin Resource Sharing (CORS)
diff --git a/src/config/environmental_variables.rs b/src/config/environmental_variables.rs
new file mode 100644
index 0000000..a790c0b
--- /dev/null
+++ b/src/config/environmental_variables.rs
@@ -0,0 +1,137 @@
+use std::{collections::HashMap, fmt};
+
+pub struct EnvVar(pub HashMap);
+impl std::ops::Deref for EnvVar {
+ type Target = HashMap;
+ fn deref(&self) -> &HashMap {
+ &self.0
+ }
+}
+
+impl Clone for EnvVar {
+ fn clone(&self) -> Self {
+ Self(self.0.clone())
+ }
+}
+impl EnvVar {
+ pub fn new(vars: HashMap) -> Self {
+ Self(vars)
+ }
+
+ pub fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option) {
+ if let Some(value) = maybe_value {
+ self.0.insert(key.to_string(), value.to_string());
+ }
+ }
+
+ pub fn err(env_var: &str, supplied_value: &str, allowed_values: &str) -> ! {
+ log::error!(
+ r"{var} is set to `{value}`, which is invalid.
+ {var} must be {allowed_vals}.",
+ var = env_var,
+ value = supplied_value,
+ allowed_vals = allowed_values
+ );
+ std::process::exit(1);
+ }
+}
+impl fmt::Display for EnvVar {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let mut result = String::new();
+ for env_var in [
+ "NODE_ENV",
+ "RUST_LOG",
+ "BIND",
+ "PORT",
+ "SOCKET",
+ "SSE_FREQ",
+ "WS_FREQ",
+ "DATABASE_URL",
+ "DB_USER",
+ "USER",
+ "DB_PORT",
+ "DB_HOST",
+ "DB_PASS",
+ "DB_NAME",
+ "DB_SSLMODE",
+ "REDIS_HOST",
+ "REDIS_USER",
+ "REDIS_PORT",
+ "REDIS_PASSWORD",
+ "REDIS_NAMESPACE",
+ "REDIS_DB",
+ ]
+ .iter()
+ {
+ if let Some(value) = self.get(&env_var.to_string()) {
+ result = format!("{}\n {}: {}", result, env_var, value)
+ }
+ }
+ write!(f, "{}", result)
+ }
+}
+#[macro_export]
+macro_rules! maybe_update {
+ ($name:ident; $item: tt:$type:ty) => (
+ pub fn $name(self, item: Option<$type>) -> Self {
+ match item {
+ Some($item) => Self{ $item, ..self },
+ None => Self { ..self }
+ }
+ });
+ ($name:ident; Some($item: tt: $type:ty)) => (
+ fn $name(self, item: Option<$type>) -> Self{
+ match item {
+ Some($item) => Self{ $item: Some($item), ..self },
+ None => Self { ..self }
+ }
+ })}
+#[macro_export]
+macro_rules! from_env_var {
+ ($(#[$outer:meta])*
+ let name = $name:ident;
+ let default: $type:ty = $inner:expr;
+ let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
+ let from_str = |$arg:ident| $body:expr;
+ ) => {
+ pub struct $name(pub $type);
+ impl std::fmt::Debug for $name {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{:?}", self.0)
+ }
+ }
+ impl std::ops::Deref for $name {
+ type Target = $type;
+ fn deref(&self) -> &$type {
+ &self.0
+ }
+ }
+ impl std::default::Default for $name {
+ fn default() -> Self {
+ $name($inner)
+ }
+ }
+ impl $name {
+ fn inner_from_str($arg: &str) -> Option<$type> {
+ $body
+ }
+ pub fn maybe_update(self, var: Option<&String>) -> Self {
+ match var {
+ Some(empty_string) if empty_string.is_empty() => Self::default(),
+ Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
+ crate::config::EnvVar::err($env_var, value, $allowed_values)
+ })),
+ None => self,
+ }
+
+ // if let Some(value) = var {
+ // Self(Self::inner_from_str(value).unwrap_or_else(|| {
+ // crate::err::env_var_fatal($env_var, value, $allowed_values)
+ // }))
+ // } else {
+ // self
+ // }
+ }
+ }
+ };
+}
diff --git a/src/config/mod.rs b/src/config/mod.rs
index b2b314d..6d7d8bc 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -4,135 +4,7 @@ mod postgres_cfg;
mod postgres_cfg_types;
mod redis_cfg;
mod redis_cfg_types;
-pub use self::{
- deployment_cfg::DeploymentConfig,
- postgres_cfg::PostgresConfig,
- redis_cfg::RedisConfig,
- redis_cfg_types::{RedisInterval, RedisNamespace},
-};
-use std::{collections::HashMap, fmt};
+mod environmental_variables;
-pub struct EnvVar(pub HashMap);
-impl std::ops::Deref for EnvVar {
- type Target = HashMap;
- fn deref(&self) -> &HashMap {
- &self.0
- }
-}
+pub use {deployment_cfg::DeploymentConfig, environmental_variables::EnvVar, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig};
-impl Clone for EnvVar {
- fn clone(&self) -> Self {
- Self(self.0.clone())
- }
-}
-impl EnvVar {
- pub fn new(vars: HashMap) -> Self {
- Self(vars)
- }
-
- fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option) {
- if let Some(value) = maybe_value {
- self.0.insert(key.to_string(), value.to_string());
- }
- }
-}
-impl fmt::Display for EnvVar {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- let mut result = String::new();
- for env_var in [
- "NODE_ENV",
- "RUST_LOG",
- "BIND",
- "PORT",
- "SOCKET",
- "SSE_FREQ",
- "WS_FREQ",
- "DATABASE_URL",
- "DB_USER",
- "USER",
- "DB_PORT",
- "DB_HOST",
- "DB_PASS",
- "DB_NAME",
- "DB_SSLMODE",
- "REDIS_HOST",
- "REDIS_USER",
- "REDIS_PORT",
- "REDIS_PASSWORD",
- "REDIS_USER",
- "REDIS_DB",
- ]
- .iter()
- {
- if let Some(value) = self.get(&env_var.to_string()) {
- result = format!("{}\n {}: {}", result, env_var, value)
- }
- }
- write!(f, "{}", result)
- }
-}
-#[macro_export]
-macro_rules! maybe_update {
- ($name:ident; $item: tt:$type:ty) => (
- pub fn $name(self, item: Option<$type>) -> Self {
- match item {
- Some($item) => Self{ $item, ..self },
- None => Self { ..self }
- }
- });
- ($name:ident; Some($item: tt: $type:ty)) => (
- fn $name(self, item: Option<$type>) -> Self{
- match item {
- Some($item) => Self{ $item: Some($item), ..self },
- None => Self { ..self }
- }
- })}
-#[macro_export]
-macro_rules! from_env_var {
- ($(#[$outer:meta])*
- let name = $name:ident;
- let default: $type:ty = $inner:expr;
- let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
- let from_str = |$arg:ident| $body:expr;
- ) => {
- pub struct $name(pub $type);
- impl std::fmt::Debug for $name {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "{:?}", self.0)
- }
- }
- impl std::ops::Deref for $name {
- type Target = $type;
- fn deref(&self) -> &$type {
- &self.0
- }
- }
- impl std::default::Default for $name {
- fn default() -> Self {
- $name($inner)
- }
- }
- impl $name {
- fn inner_from_str($arg: &str) -> Option<$type> {
- $body
- }
- pub fn maybe_update(self, var: Option<&String>) -> Self {
- match var {
- Some(empty_string) if empty_string.is_empty() => Self::default(),
- Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| {
- crate::err::env_var_fatal($env_var, value, $allowed_values)
- })),
- None => self,
- }
-
- // if let Some(value) = var {
- // Self(Self::inner_from_str(value).unwrap_or_else(|| {
- // crate::err::env_var_fatal($env_var, value, $allowed_values)
- // }))
- // } else {
- // self
- // }
- }
- }
- };
-}
diff --git a/src/config/postgres_cfg_types.rs b/src/config/postgres_cfg_types.rs
index 066f2fd..7551d1b 100644
--- a/src/config/postgres_cfg_types.rs
+++ b/src/config/postgres_cfg_types.rs
@@ -6,7 +6,7 @@ from_env_var!(
/// The user to use for Postgres
let name = PgUser;
let default: String = "postgres".to_string();
- let (env_var, allowed_values) = ("DB_USER", "any string".to_string());
+ let (env_var, allowed_values) = ("DB_USER", "any string");
let from_str = |s| Some(s.to_string());
);
@@ -14,7 +14,7 @@ from_env_var!(
/// The host address where Postgres is running)
let name = PgHost;
let default: String = "localhost".to_string();
- let (env_var, allowed_values) = ("DB_HOST", "any string".to_string());
+ let (env_var, allowed_values) = ("DB_HOST", "any string");
let from_str = |s| Some(s.to_string());
);
@@ -22,7 +22,7 @@ from_env_var!(
/// The password to use with Postgress
let name = PgPass;
let default: Option = None;
- let (env_var, allowed_values) = ("DB_PASS", "any string".to_string());
+ let (env_var, allowed_values) = ("DB_PASS", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
@@ -30,7 +30,7 @@ from_env_var!(
/// The Postgres database to use
let name = PgDatabase;
let default: String = "mastodon_development".to_string();
- let (env_var, allowed_values) = ("DB_NAME", "any string".to_string());
+ let (env_var, allowed_values) = ("DB_NAME", "any string");
let from_str = |s| Some(s.to_string());
);
@@ -38,14 +38,14 @@ from_env_var!(
/// The port Postgres is running on
let name = PgPort;
let default: u16 = 5432;
- let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535".to_string());
+ let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535");
let from_str = |s| s.parse().ok();
);
from_env_var!(
let name = PgSslMode;
let default: PgSslInner = PgSslInner::Prefer;
- let (env_var, allowed_values) = ("DB_SSLMODE", format!("one of: {:?}", PgSslInner::variants()));
+ let (env_var, allowed_values) = ("DB_SSLMODE", &format!("one of: {:?}", PgSslInner::variants()));
let from_str = |s| PgSslInner::from_str(s).ok();
);
diff --git a/src/config/redis_cfg_types.rs b/src/config/redis_cfg_types.rs
index d3dbbf0..8d43c76 100644
--- a/src/config/redis_cfg_types.rs
+++ b/src/config/redis_cfg_types.rs
@@ -7,48 +7,48 @@ from_env_var!(
/// The host address where Redis is running
let name = RedisHost;
let default: String = "127.0.0.1".to_string();
- let (env_var, allowed_values) = ("REDIS_HOST", "any string".to_string());
+ let (env_var, allowed_values) = ("REDIS_HOST", "any string");
let from_str = |s| Some(s.to_string());
);
from_env_var!(
/// The port Redis is running on
let name = RedisPort;
let default: u16 = 6379;
- let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535".to_string());
+ let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535");
let from_str = |s| s.parse().ok();
);
from_env_var!(
/// How frequently to poll Redis
let name = RedisInterval;
let default: Duration = Duration::from_millis(100);
- let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds".to_string());
+ let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds");
let from_str = |s| s.parse().map(Duration::from_millis).ok();
);
from_env_var!(
/// The password to use for Redis
let name = RedisPass;
let default: Option = None;
- let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string".to_string());
+ let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// An optional Redis Namespace
let name = RedisNamespace;
let default: Option = None;
- let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string".to_string());
+ let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// A user for Redis (not supported)
let name = RedisUser;
let default: Option = None;
- let (env_var, allowed_values) = ("REDIS_USER", "any string".to_string());
+ let (env_var, allowed_values) = ("REDIS_USER", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
from_env_var!(
/// The database to use with Redis (no current effect for PubSub connections)
let name = RedisDb;
let default: Option = None;
- let (env_var, allowed_values) = ("REDIS_DB", "any string".to_string());
+ let (env_var, allowed_values) = ("REDIS_DB", "any string");
let from_str = |s| Some(Some(s.to_string()));
);
diff --git a/src/err.rs b/src/err.rs
index 6ffbaa0..2863c5e 100644
--- a/src/err.rs
+++ b/src/err.rs
@@ -1,4 +1,3 @@
-use serde_derive::Serialize;
use std::fmt::Display;
pub fn die_with_msg(msg: impl Display) -> ! {
@@ -14,67 +13,20 @@ macro_rules! log_fatal {
};};
}
-pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! {
- eprintln!(
- r"FATAL ERROR: {var} is set to `{value}`, which is invalid.
- {var} must be {allowed_vals}.",
- var = env_var,
- value = supplied_value,
- allowed_vals = allowed_values
- );
- std::process::exit(1);
+#[derive(Debug)]
+pub enum RedisParseErr {
+ Incomplete,
+ Unrecoverable,
}
-#[macro_export]
-macro_rules! dbg_and_die {
- ($msg:expr) => {
- let message = format!("FATAL ERROR: {}", $msg);
- dbg!(message);
- std::process::exit(1);
- };
-}
-pub fn unwrap_or_die(s: Option, msg: &str) -> T {
- s.unwrap_or_else(|| {
- eprintln!("FATAL ERROR: {}", msg);
- std::process::exit(1)
- })
+#[derive(Debug)]
+pub enum TimelineErr {
+ RedisNamespaceMismatch,
+ InvalidInput,
}
-#[derive(Serialize)]
-pub struct ErrorMessage {
- error: String,
-}
-impl ErrorMessage {
- fn new(msg: impl std::fmt::Display) -> Self {
- Self {
- error: msg.to_string(),
- }
- }
-}
-
-/// Recover from Errors by sending appropriate Warp::Rejections
-pub fn handle_errors(
- rejection: warp::reject::Rejection,
-) -> Result {
- let err_txt = match rejection.cause() {
- Some(text) if text.to_string() == "Missing request header 'authorization'" => {
- "Error: Missing access token".to_string()
- }
- Some(text) => text.to_string(),
- None => "Error: Nonexistant endpoint".to_string(),
- };
- let json = warp::reply::json(&ErrorMessage::new(err_txt));
-
- Ok(warp::reply::with_status(
- json,
- warp::http::StatusCode::UNAUTHORIZED,
- ))
-}
-
-pub struct CustomError {}
-
-impl CustomError {
- pub fn unauthorized_list() -> warp::reject::Rejection {
- warp::reject::custom("Error: Access to list not authorized")
+impl From for TimelineErr {
+ fn from(_error: std::num::ParseIntError) -> Self {
+ Self::InvalidInput
}
}
diff --git a/src/main.rs b/src/main.rs
index b964926..2d2aaa4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,87 +1,76 @@
use flodgatt::{
- config, err,
- parse_client_request::{sse, subscription, ws},
- redis_to_client_stream::{self, ClientAgent},
+ config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig},
+ parse_client_request::{PgPool, Subscription},
+ redis_to_client_stream::{ClientAgent, EventStream},
};
use std::{collections::HashMap, env, fs, net, os::unix::fs::PermissionsExt};
use tokio::net::UnixListener;
use warp::{path, ws::Ws2, Filter};
fn main() {
- dotenv::from_filename(
- match env::var("ENV").ok().as_ref().map(String::as_str) {
+ dotenv::from_filename(match env::var("ENV").ok().as_ref().map(String::as_str) {
Some("production") => ".env.production",
Some("development") | None => ".env",
- Some(_) => err::die_with_msg("Unknown ENV variable specified.\n Valid options are: `production` or `development`."),
- }).ok();
+ Some(unsupported) => EnvVar::err("ENV", unsupported, "`production` or `development`"),
+ })
+ .ok();
let env_vars_map: HashMap<_, _> = dotenv::vars().collect();
- let env_vars = config::EnvVar::new(env_vars_map);
+ let env_vars = EnvVar::new(env_vars_map);
pretty_env_logger::init();
log::info!(
"Flodgatt recognized the following environmental variables:{}",
env_vars.clone()
);
- let redis_cfg = config::RedisConfig::from_env(env_vars.clone());
- let cfg = config::DeploymentConfig::from_env(env_vars.clone());
+ let redis_cfg = RedisConfig::from_env(env_vars.clone());
+ let cfg = DeploymentConfig::from_env(env_vars.clone());
- let postgres_cfg = config::PostgresConfig::from_env(env_vars.clone());
- let pg_pool = subscription::PgPool::new(postgres_cfg);
+ let postgres_cfg = PostgresConfig::from_env(env_vars.clone());
+ let pg_pool = PgPool::new(postgres_cfg);
- let client_agent_sse = ClientAgent::blank(redis_cfg, pg_pool.clone());
+ let client_agent_sse = ClientAgent::blank(redis_cfg);
let client_agent_ws = client_agent_sse.clone_with_shared_receiver();
log::info!("Streaming server initialized and ready to accept connections");
// Server Sent Events
- let (sse_update_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode);
- let sse_routes = sse::extract_user_or_reject(pg_pool.clone(), whitelist_mode)
+ let (sse_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode);
+ let sse_routes = Subscription::from_sse_query(pg_pool.clone(), whitelist_mode)
.and(warp::sse())
.map(
- move |subscription: subscription::Subscription,
- sse_connection_to_client: warp::sse::Sse| {
+ move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| {
log::info!("Incoming SSE request for {:?}", subscription.timeline);
// Create a new ClientAgent
let mut client_agent = client_agent_sse.clone_with_shared_receiver();
// Assign ClientAgent to generate stream of updates for the user/timeline pair
client_agent.init_for_user(subscription);
// send the updates through the SSE connection
- redis_to_client_stream::send_updates_to_sse(
- client_agent,
- sse_connection_to_client,
- sse_update_interval,
- )
+ EventStream::to_sse(client_agent, sse_connection_to_client, sse_interval)
},
)
- .with(warp::reply::with::header("Connection", "keep-alive"))
- .recover(err::handle_errors);
+ .with(warp::reply::with::header("Connection", "keep-alive"));
// WebSocket
let (ws_update_interval, whitelist_mode) = (*cfg.ws_interval, *cfg.whitelist_mode);
- let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone(), whitelist_mode)
+ let websocket_routes = Subscription::from_ws_request(pg_pool.clone(), whitelist_mode)
.and(warp::ws::ws2())
- .map(
- move |subscription: subscription::Subscription, token: Option<String>, ws: Ws2| {
- log::info!("Incoming websocket request for {:?}", subscription.timeline);
- // Create a new ClientAgent
- let mut client_agent = client_agent_ws.clone_with_shared_receiver();
- // Assign that agent to generate a stream of updates for the user/timeline pair
- client_agent.init_for_user(subscription);
- // send the updates through the WS connection (along with the User's access_token
- // which is sent for security)
+ .map(move |subscription: Subscription, ws: Ws2| {
+ log::info!("Incoming websocket request for {:?}", subscription.timeline);
- (
- ws.on_upgrade(move |socket| {
- redis_to_client_stream::send_updates_to_ws(
- socket,
- client_agent,
- ws_update_interval,
- )
- }),
- token.unwrap_or_else(String::new),
- )
- },
- )
+ let token = subscription.access_token.clone();
+ // Create a new ClientAgent
+ let mut client_agent = client_agent_ws.clone_with_shared_receiver();
+ // Assign that agent to generate a stream of updates for the user/timeline pair
+ client_agent.init_for_user(subscription);
+ // send the updates through the WS connection (along with the User's access_token
+ // which is sent for security)
+ (
+ ws.on_upgrade(move |socket| {
+ EventStream::to_ws(socket, client_agent, ws_update_interval)
+ }),
+ token.unwrap_or_else(String::new),
+ )
+ })
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
let cors = warp::cors()
@@ -98,7 +87,28 @@ fn main() {
fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).unwrap();
- warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run_incoming(incoming);
+ warp::serve(
+ health.or(websocket_routes.or(sse_routes).with(cors).recover(
+ |rejection: warp::reject::Rejection| {
+ let err_txt = match rejection.cause() {
+ Some(text)
+ if text.to_string() == "Missing request header 'authorization'" =>
+ {
+ "Error: Missing access token".to_string()
+ }
+ Some(text) => text.to_string(),
+ None => "Error: Nonexistant endpoint".to_string(),
+ };
+ let json = warp::reply::json(&err_txt);
+
+ Ok(warp::reply::with_status(
+ json,
+ warp::http::StatusCode::UNAUTHORIZED,
+ ))
+ },
+ )),
+ )
+ .run_incoming(incoming);
} else {
let server_addr = net::SocketAddr::new(*cfg.address, cfg.port.0);
warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run(server_addr);
diff --git a/src/messages.rs b/src/messages.rs
index 5facc8e..a31cade 100644
--- a/src/messages.rs
+++ b/src/messages.rs
@@ -429,27 +429,24 @@ impl Status {
mod test {
use super::*;
use crate::{
- parse_client_request::subscription::{Content::*, Reach::*, Stream::*, Timeline},
- redis_to_client_stream::{
- receiver::{MessageQueues, MsgQueue},
- redis::{
- redis_msg::{ParseErr, RedisMsg},
- redis_stream,
- },
- },
+ err::RedisParseErr,
+ parse_client_request::{Content::*, Reach::*, Stream::*, Timeline},
+ redis_to_client_stream::{MessageQueues, MsgQueue, RedisMsg},
};
use lru::LruCache;
use std::collections::HashMap;
use uuid::Uuid;
- type Err = ParseErr;
+ type Err = RedisParseErr;
/// Set up state shared between multiple tests of Redis parsing
pub fn shared_setup() -> (LruCache<String, i64>, MessageQueues, Uuid, Timeline) {
- let cache: LruCache<String, i64> = LruCache::new(1000);
+ let mut cache: LruCache<String, i64> = LruCache::new(1000);
let mut queues_map = HashMap::new();
- let id = Uuid::default();
+ let id = dbg!(Uuid::default());
- let timeline = Timeline::from_redis_raw_timeline("4", None);
+ let timeline = dbg!(
+ Timeline::from_redis_raw_timeline("timeline:4", &mut cache, &None).expect("In test")
+ );
queues_map.insert(id, MsgQueue::new(timeline));
let queues = MessageQueues(queues_map);
(cache, queues, id, timeline)
@@ -460,8 +457,8 @@ mod test {
let input ="*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:4\r\n$1386\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102866835379605039\",\"created_at\":\"2019-09-27T22:29:02.590Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"en\",\"uri\":\"http://localhost:3000/users/admin/statuses/102866835379605039\",\"url\":\"http://localhost:3000/@admin/102866835379605039\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"@susan hi
\",\"reblog\":null,\"application\":{\"name\":\"Web\",\"website\":null},\"account\":{\"id\":\"1\",\"username\":\"admin\",\"acct\":\"admin\",\"display_name\":\"\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-07-04T00:21:05.890Z\",\"note\":\"\",\"url\":\"http://localhost:3000/@admin\",\"avatar\":\"http://localhost:3000/avatars/original/missing.png\",\"avatar_static\":\"http://localhost:3000/avatars/original/missing.png\",\"header\":\"http://localhost:3000/headers/original/missing.png\",\"header_static\":\"http://localhost:3000/headers/original/missing.png\",\"followers_count\":3,\"following_count\":3,\"statuses_count\":192,\"emojis\":[],\"fields\":[]},\"media_attachments\":[],\"mentions\":[{\"id\":\"4\",\"username\":\"susan\",\"url\":\"http://localhost:3000/@susan\",\"acct\":\"susan\"}],\"tags\":[],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1569623342825}\r\n";
let (mut cache, mut queues, id, timeline) = shared_setup();
- redis_stream::process_messages(input.to_string(), &mut None, &mut cache, &mut queues)
- .map_err(|_| ParseErr::Unrecoverable)?;
+ crate::redis_to_client_stream::process_messages(input, &mut cache, &mut None, &mut queues);
+
let parsed_event = queues.oldest_msg_in_target_queue(id, timeline).unwrap();
let test_event = Event::Update{ payload: Status {
id: "102866835379605039".to_string(),
@@ -538,40 +535,40 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(subscription_msg1, RedisMsg::SubscriptionMsg));
- let (subscription_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?;
+ let (subscription_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?;
assert!(matches!(subscription_msg2, RedisMsg::SubscriptionMsg));
- let (subscription_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?;
+ let (subscription_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?;
assert!(matches!(subscription_msg3, RedisMsg::SubscriptionMsg));
- let (subscription_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?;
+ let (subscription_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?;
assert!(matches!(subscription_msg4, RedisMsg::SubscriptionMsg));
- let (subscription_msg5, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?;
+ let (subscription_msg5, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?;
assert!(matches!(subscription_msg5, RedisMsg::SubscriptionMsg));
- let (update_msg1, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?;
+ let (update_msg1, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?;
assert!(matches!(
update_msg1,
RedisMsg::EventMsg(_, Event::Update { .. })
));
- let (update_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?;
+ let (update_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?;
assert!(matches!(
update_msg2,
RedisMsg::EventMsg(_, Event::Update { .. })
));
- let (update_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?;
+ let (update_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?;
assert!(matches!(
update_msg3,
RedisMsg::EventMsg(_, Event::Update { .. })
));
- let (update_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?;
+ let (update_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?;
assert!(matches!(
update_msg4,
RedisMsg::EventMsg(_, Event::Update { .. })
@@ -588,7 +585,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(
subscription_msg1,
RedisMsg::EventMsg(Timeline(User(id), Federated, All), Event::Notification { .. }) if id == 55
@@ -603,9 +600,9 @@ mod test {
fn parse_redis_input_delete() -> Result<(), Err> {
let input = "*3\r\n$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$49\r\n{\"event\":\"delete\",\"payload\":\"103864778284581232\"}\r\n";
- let (mut cache, _, _, _) = shared_setup();
+ let (mut cache, _, _, _) = dbg!(shared_setup());
- let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(
subscription_msg1,
RedisMsg::EventMsg(
@@ -625,7 +622,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(
subscription_msg1,
RedisMsg::EventMsg(Timeline(User(id), Federated, All), Event::FiltersChanged) if id == 56
@@ -642,7 +639,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(
msg,
RedisMsg::EventMsg(
@@ -660,7 +657,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(
msg,
RedisMsg::EventMsg(
@@ -679,7 +676,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(
msg,
RedisMsg::EventMsg(
@@ -701,7 +698,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
dbg!(&msg);
assert!(matches!(
msg,
@@ -721,7 +718,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
dbg!(&msg);
assert!(matches!(
msg,
@@ -741,7 +738,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
dbg!(&msg);
assert!(matches!(
msg,
@@ -761,7 +758,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
dbg!(&msg);
assert!(matches!(
msg,
@@ -781,7 +778,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
dbg!(&msg);
assert!(matches!(
msg,
@@ -797,7 +794,7 @@ mod test {
},
)
));
- let (msg2, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?;
+ let (msg2, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?;
dbg!(&msg2);
assert!(matches!(
msg2,
@@ -814,7 +811,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
dbg!(&msg);
assert!(matches!(
msg,
@@ -840,7 +837,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
dbg!(&msg);
assert!(matches!(
msg,
@@ -863,7 +860,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
dbg!(&msg);
assert!(matches!(
msg,
@@ -886,7 +883,7 @@ mod test {
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
dbg!(&msg);
assert!(matches!(
msg,
@@ -904,7 +901,7 @@ mod test {
fn parse_redis_input_from_live_data_1() -> Result<(), Err> {
let input = "*3\r\n$7\r\nmessage\r\n$15\r\ntimeline:public\r\n$2799\r\n{\"event\":\"update\",\"payload\":{\"id\":\"103880088450458596\",\"created_at\":\"2020-03-24T21:12:37.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"es\",\"uri\":\"https://mastodon.social/users/durru/statuses/103880088436492032\",\"url\":\"https://mastodon.social/@durru/103880088436492032\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"content\":\"¡No puedes salir, loca!
\",\"reblog\":null,\"account\":{\"id\":\"2271\",\"username\":\"durru\",\"acct\":\"durru@mastodon.social\",\"display_name\":\"Cloaca Maxima\",\"locked\":false,\"bot\":false,\"discoverable\":true,\"group\":false,\"created_at\":\"2020-03-24T21:27:31.669Z\",\"note\":\"Todo pasa, antes o después, por la Cloaca, diría Vitruvio.
También compongo palíndromos.
\",\"url\":\"https://mastodon.social/@durru\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/002/271/original/d7675a6ff9d9baa7.jpeg?1585085250\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/002/271/original/d7675a6ff9d9baa7.jpeg?1585085250\",\"header\":\"https://instance.codesections.com/system/accounts/headers/000/002/271/original/e3f0a1989b0d8efc.jpeg?1585085250\",\"header_static\":\"https://instance.codesections.com/system/accounts/headers/000/002/271/original/e3f0a1989b0d8efc.jpeg?1585085250\",\"followers_count\":222,\"following_count\":81,\"statuses_count\":5443,\"last_status_at\":\"2020-03-24\",\"emojis\":[],\"fields\":[{\"name\":\"Mis fotos\",\"value\":\"https://pixelfed.de/durru\",\"verified_at\":null},{\"name\":\"diaspora*\",\"value\":\"https://joindiaspora.com/people/75fec0e05114013484870242ac110007\",\"verified_at\":null}]},\"media_attachments\":[{\"id\":\"2864\",\"type\":\"image\",\"url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/864/original/3988312d30936494.jpeg?1585085251\",\"preview_url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/864/small/3988312d30936494.jpeg?1585085251\",\"remote_url\":\"https://files.mastodon.social/media_attachments/files/026/669/690/original/d8171331f956cf38.jpg\",\"text_url\":null,\"meta\":{\"original\":{\"width\":1001,\"height\":662,\"size\":\"1001x662\",\"aspect\":1.512084592145015},\"small\":{\"width\":491,\"height\":325,\"size\":\"491x325\",\"aspect\":1.5107692307692309}},\"description\":null,\"blurhash\":\"UdLqhI4n4TIUIAt7t7ay~qIojtRj?bM{M{of\"}],\"mentions\":[],\"tags\":[],\"emojis\":[],\"card\":null,\"poll\":null}}\r\n";
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(
msg,
RedisMsg::EventMsg(Timeline(Public, Federated, All), Event::Update { .. })
@@ -917,7 +914,7 @@ mod test {
fn parse_redis_input_from_live_data_2() -> Result<(), Err> {
let input = "*3\r\n$7\r\nmessage\r\n$15\r\ntimeline:public\r\n$3888\r\n{\"event\":\"update\",\"payload\":{\"id\":\"103880373579328660\",\"created_at\":\"2020-03-24T22:25:05.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"en\",\"uri\":\"https://newsbots.eu/users/granma/statuses/103880373417385978\",\"url\":\"https://newsbots.eu/@granma/103880373417385978\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"content\":\"A total of 11 measures have been established for the pre-epidemic stage of the battle against #Covid-19 in #Cuba
#CubaPorLaSalud
http://en.granma.cu/cuba/2020-03-23/public-health-measures-in-covid-19-pre-epidemic-stage
\",\"reblog\":null,\"account\":{\"id\":\"717\",\"username\":\"granma\",\"acct\":\"granma@newsbots.eu\",\"display_name\":\"Granma (Unofficial)\",\"locked\":false,\"bot\":true,\"discoverable\":false,\"group\":false,\"created_at\":\"2020-03-13T11:08:08.420Z\",\"note\":\"\",\"url\":\"https://newsbots.eu/@granma\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/717/original/4a1f9ed090fc36e9.jpeg?1584097687\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/717/original/4a1f9ed090fc36e9.jpeg?1584097687\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":57,\"following_count\":1,\"statuses_count\":742,\"last_status_at\":\"2020-03-24\",\"emojis\":[],\"fields\":[{\"name\":\"Source\",\"value\":\"https://twitter.com/Granma_English\",\"verified_at\":null},{\"name\":\"Operator\",\"value\":\"@felix\",\"verified_at\":null},{\"name\":\"Code\",\"value\":\"https://yerbamate.dev/nutomic/tootbot\",\"verified_at\":null}]},\"media_attachments\":[{\"id\":\"2881\",\"type\":\"image\",\"url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/881/original/a1e97908e84efbcd.jpeg?1585088707\",\"preview_url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/881/small/a1e97908e84efbcd.jpeg?1585088707\",\"remote_url\":\"https://newsbots.eu/system/media_attachments/files/000/176/298/original/f30a877d5035f4a6.jpeg\",\"text_url\":null,\"meta\":{\"original\":{\"width\":700,\"height\":795,\"size\":\"700x795\",\"aspect\":0.8805031446540881},\"small\":{\"width\":375,\"height\":426,\"size\":\"375x426\",\"aspect\":0.8802816901408451}},\"description\":null,\"blurhash\":\"UHCY?%sD%1t6}snOxuxu#7rrx]xu$*i_NFNF\"}],\"mentions\":[],\"tags\":[{\"name\":\"covid\",\"url\":\"https://instance.codesections.com/tags/covid\"},{\"name\":\"cuba\",\"url\":\"https://instance.co
desections.com/tags/cuba\"},{\"name\":\"CubaPorLaSalud\",\"url\":\"https://instance.codesections.com/tags/CubaPorLaSalud\"}],\"emojis\":[],\"card\":null,\"poll\":null}}\r\n";
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(
msg,
RedisMsg::EventMsg(Timeline(Public, Federated, All), Event::Update { .. })
@@ -930,7 +927,7 @@ mod test {
fn parse_redis_input_from_live_data_3() -> Result<(), Err> {
let input = "*3\r\n$7\r\nmessage\r\n$15\r\ntimeline:public\r\n$4803\r\n{\"event\":\"update\",\"payload\":{\"id\":\"103880453908763088\",\"created_at\":\"2020-03-24T22:45:33.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"en\",\"uri\":\"https://mstdn.social/users/stux/statuses/103880453855603541\",\"url\":\"https://mstdn.social/@stux/103880453855603541\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"content\":\"When they say lockdown. LOCKDOWN.
\",\"reblog\":null,\"account\":{\"id\":\"806\",\"username\":\"stux\",\"acct\":\"stux@mstdn.social\",\"display_name\":\"sтυx⚡\",\"locked\":false,\"bot\":false,\"discoverable\":true,\"group\":false,\"created_at\":\"2020-03-13T23:02:29.970Z\",\"note\":\"Hi, Stux here! I am running the mstdn.social :mastodon: instance!
For questions and help or just for fun you can always send me a toot♥\u{fe0f}
Oh and no, I am not really a cat! Or am I?
\",\"url\":\"https://mstdn.social/@stux\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/806/original/dae8d9d01d57d7f8.gif?1584140547\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/806/static/dae8d9d01d57d7f8.png?1584140547\",\"header\":\"https://instance.codesections.com/system/accounts/headers/000/000/806/original/88c874d69f7d6989.gif?1584140548\",\"header_static\":\"https://instance.codesections.com/system/accounts/headers/000/000/806/static/88c874d69f7d6989.png?1584140548\",\"followers_count\":13954,\"following_count\":7600,\"statuses_count\":10207,\"last_status_at\":\"2020-03-24\",\"emojis\":[{\"shortcode\":\"mastodon\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/418/original/25ccc64333645735.png?1584140550\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/418/static/25ccc64333645735.png?1584140550\",\"visible_in_picker\":true},{\"shortcode\":\"patreon\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/419/original/3cc463d3dfc1e489.png?1584140550\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/419/static/3cc463d3dfc1e489.png?1584140550\",\"visible_in_picker\":true},{\"shortcode\":\"liberapay\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/420/original/893854353dfa9706.png?1584140551\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/420/static/893854353dfa9706.png?1584140551\",\"visible_in_picker\":true},{\"shortcode\":\"team_valor\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/958/original/96aae26b45292a12.png?1584910917\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/958/static/96aae26b45292a12.png?1584910917\",\"visible_in_picker\":true}],\"fields\":[{\"name\":\"Patreon 
:patreon:\",\"value\":\"https://www.patreon.com/mstdn\",\"verified_at\":null},{\"name\":\"LiberaPay :liberapay:\",\"value\":\"https://liberapay.com/mstdn\",\"verified_at\":null},{\"name\":\"Team :team_valor:\",\"value\":\"https://mstdn.social/team\",\"verified_at\":null},{\"name\":\"Support :mastodon:\",\"value\":\"https://mstdn.social/funding\",\"verified_at\":null}]},\"media_attachments\":[{\"id\":\"2886\",\"type\":\"video\",\"url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/886/original/22b3f98a5e8f86d8.mp4?1585090023\",\"preview_url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/886/small/22b3f98a5e8f86d8.png?1585090023\",\"remote_url\":\"https://cdn.mstdn.social/mstdn-social/media_attachments/files/003/338/384/original/c146f62ba86fe63e.mp4\",\"text_url\":null,\"meta\":{\"length\":\"0:00:27.03\",\"duration\":27.03,\"fps\":30,\"size\":\"272x480\",\"width\":272,\"height\":480,\"aspect\":0.5666666666666667,\"audio_encode\":\"aac (LC) (mp4a / 0x6134706D)\",\"audio_bitrate\":\"44100 Hz\",\"audio_channels\":\"stereo\",\"original\":{\"width\":272,\"height\":480,\"frame_rate\":\"30/1\",\"duration\":27.029,\"bitrate\":481885},\"small\":{\"width\":227,\"height\":400,\"size\":\"227x400\",\"aspect\":0.5675}},\"description\":null,\"blurhash\":\"UBF~N@OF-:xv4mM|s+ob9FE2t6tQ9Fs:t8oN\"}],\"mentions\":[],\"tags\":[],\"emojis\":[],\"card\":null,\"poll\":null}}\r\n";
let (mut cache, _, _, _) = shared_setup();
- let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?;
+ let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?;
assert!(matches!(
msg,
RedisMsg::EventMsg(Timeline(Public, Federated, All), Event::Update { .. })
diff --git a/src/parse_client_request/mod.rs b/src/parse_client_request/mod.rs
index c999c0e..c076eb2 100644
--- a/src/parse_client_request/mod.rs
+++ b/src/parse_client_request/mod.rs
@@ -1,5 +1,13 @@
-//! Parse the client request and return a (possibly authenticated) `User`
-pub mod query;
-pub mod sse;
-pub mod subscription;
-pub mod ws;
+//! Parse the client request and return a Subscription
+mod postgres;
+mod query;
+mod sse;
+mod subscription;
+mod ws;
+
+pub use self::postgres::PgPool;
+// TODO consider whether we can remove `Stream` from public API
+pub use subscription::{Stream, Subscription, Timeline};
+
+#[cfg(test)]
+pub use subscription::{Content, Reach};
diff --git a/src/parse_client_request/postgres.rs b/src/parse_client_request/postgres.rs
new file mode 100644
index 0000000..606f583
--- /dev/null
+++ b/src/parse_client_request/postgres.rs
@@ -0,0 +1,204 @@
+//! Postgres queries
+use crate::{
+ config,
+ parse_client_request::subscription::{Scope, UserData},
+};
+use ::postgres;
+use r2d2_postgres::PostgresConnectionManager;
+use std::collections::HashSet;
+use warp::reject::Rejection;
+
+#[derive(Clone, Debug)]
+pub struct PgPool(pub r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>);
+impl PgPool {
+ pub fn new(pg_cfg: config::PostgresConfig) -> Self {
+ let mut cfg = postgres::Config::new();
+ cfg.user(&pg_cfg.user)
+ .host(&*pg_cfg.host.to_string())
+ .port(*pg_cfg.port)
+ .dbname(&pg_cfg.database);
+ if let Some(password) = &*pg_cfg.password {
+ cfg.password(password);
+ };
+
+ let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
+ let pool = r2d2::Pool::builder()
+ .max_size(10)
+ .build(manager)
+ .expect("Can connect to local postgres");
+ Self(pool)
+ }
+ pub fn select_user(self, token: &str) -> Result<UserData, Rejection> {
+ let mut conn = self.0.get().unwrap();
+ let query_rows = conn
+ .query(
+ "
+SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
+FROM
+oauth_access_tokens
+INNER JOIN users ON
+oauth_access_tokens.resource_owner_id = users.id
+WHERE oauth_access_tokens.token = $1
+AND oauth_access_tokens.revoked_at IS NULL
+LIMIT 1",
+ &[&token.to_owned()],
+ )
+ .expect("Hard-coded query will return Some([0 or more rows])");
+ if let Some(result_columns) = query_rows.get(0) {
+ let id = result_columns.get(1);
+ let allowed_langs = result_columns
+ .try_get::<_, Vec<_>>(2)
+ .unwrap_or_else(|_| Vec::new())
+ .into_iter()
+ .collect();
+ let mut scopes: HashSet<Scope> = result_columns
+ .get::<_, String>(3)
+ .split(' ')
+ .filter_map(|scope| match scope {
+ "read" => Some(Scope::Read),
+ "read:statuses" => Some(Scope::Statuses),
+ "read:notifications" => Some(Scope::Notifications),
+ "read:lists" => Some(Scope::Lists),
+ "write" | "follow" => None, // ignore write scopes
+ unexpected => {
+ log::warn!("Ignoring unknown scope `{}`", unexpected);
+ None
+ }
+ })
+ .collect();
+ // We don't need to separately track read auth - it's just all three others
+ if scopes.remove(&Scope::Read) {
+ scopes.insert(Scope::Statuses);
+ scopes.insert(Scope::Notifications);
+ scopes.insert(Scope::Lists);
+ }
+
+ Ok(UserData {
+ id,
+ allowed_langs,
+ scopes,
+ })
+ } else {
+ Err(warp::reject::custom("Error: Invalid access token"))
+ }
+ }
+
+ pub fn select_hashtag_id(self, tag_name: &String) -> Result<i64, Rejection> {
+ let mut conn = self.0.get().unwrap();
+ let rows = &conn
+ .query(
+ "
+SELECT id
+FROM tags
+WHERE name = $1
+LIMIT 1",
+ &[&tag_name],
+ )
+ .expect("Hard-coded query will return Some([0 or more rows])");
+
+ match rows.get(0) {
+ Some(row) => Ok(row.get(0)),
+ None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
+ }
+ }
+
+ /// Query Postgres for everyone the user has blocked or muted
+ ///
+ /// **NOTE**: because we check this when the user connects, it will not include any blocks
+ /// the user adds until they refresh/reconnect.
+ pub fn select_blocked_users(self, user_id: i64) -> HashSet<i64> {
+ // "
+ // SELECT
+ // 1
+ // FROM blocks
+ // WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
+ // OR (account_id = $2 AND target_account_id = $1)
+ // UNION SELECT
+ // 1
+ // FROM mutes
+ // WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
+ // , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
+ self
+ .0
+ .get()
+ .unwrap()
+ .query(
+ "
+SELECT target_account_id
+ FROM blocks
+ WHERE account_id = $1
+UNION SELECT target_account_id
+ FROM mutes
+ WHERE account_id = $1",
+ &[&user_id],
+ )
+ .expect("Hard-coded query will return Some([0 or more rows])")
+ .iter()
+ .map(|row| row.get(0))
+ .collect()
+ }
+ /// Query Postgres for everyone who has blocked the user
+ ///
+ /// **NOTE**: because we check this when the user connects, it will not include any blocks
+ /// the user adds until they refresh/reconnect.
+ pub fn select_blocking_users(self, user_id: i64) -> HashSet<i64> {
+ self
+ .0
+ .get()
+ .unwrap()
+ .query(
+ "
+SELECT account_id
+ FROM blocks
+ WHERE target_account_id = $1",
+ &[&user_id],
+ )
+ .expect("Hard-coded query will return Some([0 or more rows])")
+ .iter()
+ .map(|row| row.get(0))
+ .collect()
+ }
+
+ /// Query Postgres for all current domain blocks
+ ///
+ /// **NOTE**: because we check this when the user connects, it will not include any blocks
+ /// the user adds until they refresh/reconnect.
+ pub fn select_blocked_domains(self, user_id: i64) -> HashSet<String> {
+ self
+ .0
+ .get()
+ .unwrap()
+ .query(
+ "SELECT domain FROM account_domain_blocks WHERE account_id = $1",
+ &[&user_id],
+ )
+ .expect("Hard-coded query will return Some([0 or more rows])")
+ .iter()
+ .map(|row| row.get(0))
+ .collect()
+ }
+
+ /// Test whether a user owns a list
+ pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool {
+ let mut conn = self.0.get().unwrap();
+ // For the Postgres query, `id` = list number; `account_id` = user.id
+ let rows = &conn
+ .query(
+ "
+SELECT id, account_id
+FROM lists
+WHERE id = $1
+LIMIT 1",
+ &[&list_id],
+ )
+ .expect("Hard-coded query will return Some([0 or more rows])");
+
+ match rows.get(0) {
+ None => false,
+ Some(row) => {
+ let list_owner_id: i64 = row.get(1);
+ list_owner_id == user_id
+ }
+ }
+ }
+}
diff --git a/src/parse_client_request/query.rs b/src/parse_client_request/query.rs
index 3c92c40..c4445a9 100644
--- a/src/parse_client_request/query.rs
+++ b/src/parse_client_request/query.rs
@@ -28,6 +28,12 @@ impl Query {
}
macro_rules! make_query_type {
+ (Stream => $parameter:tt:$type:ty) => {
+ #[derive(Deserialize, Debug, Default)]
+ pub struct Stream {
+ pub $parameter: $type,
+ }
+ };
($name:tt => $parameter:tt:$type:ty) => {
#[derive(Deserialize, Debug, Default)]
pub struct $name {
@@ -59,14 +65,14 @@ impl ToString for Stream {
}
}
-pub fn optional_media_query() -> BoxedFilter<(Media,)> {
- warp::query()
- .or(warp::any().map(|| Media {
- only_media: "false".to_owned(),
- }))
- .unify()
- .boxed()
-}
+// pub fn optional_media_query() -> BoxedFilter<(Media,)> {
+// warp::query()
+// .or(warp::any().map(|| Media {
+// only_media: "false".to_owned(),
+// }))
+// .unify()
+// .boxed()
+// }
pub struct OptionalAccessToken;
diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse.rs
index b62461c..287cd9b 100644
--- a/src/parse_client_request/sse.rs
+++ b/src/parse_client_request/sse.rs
@@ -1,78 +1,4 @@
//! Filters for all the endpoints accessible for Server Sent Event updates
-use super::{
- query::{self, Query},
- subscription::{PgPool, Subscription},
-};
-use warp::{filters::BoxedFilter, path, Filter};
-#[allow(dead_code)]
-type TimelineUser = ((String, Subscription),);
-
-/// Helper macro to match on the first of any of the provided filters
-macro_rules! any_of {
- ($filter:expr, $($other_filter:expr),*) => {
- $filter$(.or($other_filter).unify())*.boxed()
- };
-}
-
-macro_rules! parse_query {
- (path => $start:tt $(/ $next:tt)*
- endpoint => $endpoint:expr) => {
- path!($start $(/ $next)*)
- .and(query::Auth::to_filter())
- .and(query::Media::to_filter())
- .and(query::Hashtag::to_filter())
- .and(query::List::to_filter())
- .map(
- |auth: query::Auth,
- media: query::Media,
- hashtag: query::Hashtag,
- list: query::List| {
- Query {
- access_token: auth.access_token,
- stream: $endpoint.to_string(),
- media: media.is_truthy(),
- hashtag: hashtag.tag,
- list: list.list,
- }
- },
- )
- .boxed()
- };
-}
-pub fn extract_user_or_reject(
- pg_pool: PgPool,
- whitelist_mode: bool,
-) -> BoxedFilter<(Subscription,)> {
- any_of!(
- parse_query!(
- path => "api" / "v1" / "streaming" / "user" / "notification"
- endpoint => "user:notification" ),
- parse_query!(
- path => "api" / "v1" / "streaming" / "user"
- endpoint => "user"),
- parse_query!(
- path => "api" / "v1" / "streaming" / "public" / "local"
- endpoint => "public:local"),
- parse_query!(
- path => "api" / "v1" / "streaming" / "public"
- endpoint => "public"),
- parse_query!(
- path => "api" / "v1" / "streaming" / "direct"
- endpoint => "direct"),
- parse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
- endpoint => "hashtag:local"),
- parse_query!(path => "api" / "v1" / "streaming" / "hashtag"
- endpoint => "hashtag"),
- parse_query!(path => "api" / "v1" / "streaming" / "list"
- endpoint => "list")
- )
- // because SSE requests place their `access_token` in the header instead of in a query
- // parameter, we need to update our Query if the header has a token
- .and(query::OptionalAccessToken::from_sse_header())
- .and_then(Query::update_access_token)
- .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
- .boxed()
-}
// #[cfg(test)]
// mod test {
diff --git a/src/parse_client_request/subscription.rs b/src/parse_client_request/subscription.rs
new file mode 100644
index 0000000..d8e874a
--- /dev/null
+++ b/src/parse_client_request/subscription.rs
@@ -0,0 +1,396 @@
+//! `Subscription` struct and related functionality
+// #[cfg(test)]
+// mod mock_postgres;
+// #[cfg(test)]
+// use mock_postgres as postgres;
+// #[cfg(not(test))]
+
+use super::postgres::PgPool;
+use super::query::Query;
+use crate::err::TimelineErr;
+use crate::log_fatal;
+use lru::LruCache;
+use std::collections::HashSet;
+use warp::reject::Rejection;
+
+use super::query;
+use warp::{filters::BoxedFilter, path, Filter};
+
+/// Helper macro to match on the first of any of the provided filters
+macro_rules! any_of {
+ ($filter:expr, $($other_filter:expr),*) => {
+ $filter$(.or($other_filter).unify())*.boxed()
+ };
+}
+macro_rules! parse_sse_query {
+ (path => $start:tt $(/ $next:tt)*
+ endpoint => $endpoint:expr) => {
+ path!($start $(/ $next)*)
+ .and(query::Auth::to_filter())
+ .and(query::Media::to_filter())
+ .and(query::Hashtag::to_filter())
+ .and(query::List::to_filter())
+ .map(
+ |auth: query::Auth,
+ media: query::Media,
+ hashtag: query::Hashtag,
+ list: query::List| {
+ Query {
+ access_token: auth.access_token,
+ stream: $endpoint.to_string(),
+ media: media.is_truthy(),
+ hashtag: hashtag.tag,
+ list: list.list,
+ }
+ },
+ )
+ .boxed()
+ };
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub struct Subscription {
+ pub timeline: Timeline,
+ pub allowed_langs: HashSet,
+ pub blocks: Blocks,
+ pub hashtag_name: Option,
+ pub access_token: Option,
+}
+
+impl Default for Subscription {
+ fn default() -> Self {
+ Self {
+ timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification),
+ allowed_langs: HashSet::new(),
+ blocks: Blocks::default(),
+ hashtag_name: None,
+ access_token: None,
+ }
+ }
+}
+
+impl Subscription {
+ pub fn from_ws_request(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
+ parse_ws_query()
+ .and(query::OptionalAccessToken::from_ws_header())
+ .and_then(Query::update_access_token)
+ .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
+ .boxed()
+ }
+
+ pub fn from_sse_query(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> {
+ any_of!(
+ parse_sse_query!(
+ path => "api" / "v1" / "streaming" / "user" / "notification"
+ endpoint => "user:notification" ),
+ parse_sse_query!(
+ path => "api" / "v1" / "streaming" / "user"
+ endpoint => "user"),
+ parse_sse_query!(
+ path => "api" / "v1" / "streaming" / "public" / "local"
+ endpoint => "public:local"),
+ parse_sse_query!(
+ path => "api" / "v1" / "streaming" / "public"
+ endpoint => "public"),
+ parse_sse_query!(
+ path => "api" / "v1" / "streaming" / "direct"
+ endpoint => "direct"),
+ parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
+ endpoint => "hashtag:local"),
+ parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag"
+ endpoint => "hashtag"),
+ parse_sse_query!(path => "api" / "v1" / "streaming" / "list"
+ endpoint => "list")
+ )
+ // because SSE requests place their `access_token` in the header instead of in a query
+ // parameter, we need to update our Query if the header has a token
+ .and(query::OptionalAccessToken::from_sse_header())
+ .and_then(Query::update_access_token)
+ .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
+ .boxed()
+ }
+ fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result {
+ let user = match q.access_token.clone() {
+ Some(token) => pool.clone().select_user(&token)?,
+ None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?,
+ None => UserData::public(),
+ };
+ let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?;
+ let hashtag_name = match timeline {
+ Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag),
+ _non_hashtag_timeline => None,
+ };
+
+ Ok(Subscription {
+ timeline,
+ allowed_langs: user.allowed_langs,
+ blocks: Blocks {
+ blocking_users: pool.clone().select_blocking_users(user.id),
+ blocked_users: pool.clone().select_blocked_users(user.id),
+ blocked_domains: pool.clone().select_blocked_domains(user.id),
+ },
+ hashtag_name,
+ access_token: q.access_token,
+ })
+ }
+}
+
+fn parse_ws_query() -> BoxedFilter<(Query,)> {
+ path!("api" / "v1" / "streaming")
+ .and(path::end())
+ .and(warp::query())
+ .and(query::Auth::to_filter())
+ .and(query::Media::to_filter())
+ .and(query::Hashtag::to_filter())
+ .and(query::List::to_filter())
+ .map(
+ |stream: query::Stream,
+ auth: query::Auth,
+ media: query::Media,
+ hashtag: query::Hashtag,
+ list: query::List| {
+ Query {
+ access_token: auth.access_token,
+ stream: stream.stream,
+ media: media.is_truthy(),
+ hashtag: hashtag.tag,
+ list: list.list,
+ }
+ },
+ )
+ .boxed()
+}
+
+#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
+pub struct Timeline(pub Stream, pub Reach, pub Content);
+
+impl Timeline {
+ pub fn empty() -> Self {
+ use {Content::*, Reach::*, Stream::*};
+ Self(Unset, Local, Notification)
+ }
+
+ pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> String {
+ use {Content::*, Reach::*, Stream::*};
+ match self {
+ Timeline(Public, Federated, All) => "timeline:public".into(),
+ Timeline(Public, Local, All) => "timeline:public:local".into(),
+ Timeline(Public, Federated, Media) => "timeline:public:media".into(),
+ Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
+
+ Timeline(Hashtag(id), Federated, All) => format!(
+ "timeline:hashtag:{}",
+ hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
+ ),
+ Timeline(Hashtag(id), Local, All) => format!(
+ "timeline:hashtag:{}:local",
+ hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
+ ),
+ Timeline(User(id), Federated, All) => format!("timeline:{}", id),
+ Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
+ Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
+ Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
+ Timeline(one, _two, _three) => {
+ log_fatal!("Supposedly impossible timeline reached: {:?}", one)
+ }
+ }
+ }
+
+ pub fn from_redis_raw_timeline(
+ timeline: &str,
+ cache: &mut LruCache,
+ namespace: &Option,
+ ) -> Result {
+ use crate::err::TimelineErr::RedisNamespaceMismatch;
+ use {Content::*, Reach::*, Stream::*};
+ let timeline_slice = &timeline.split(":").collect::>()[..];
+
+ #[rustfmt::skip]
+ let (stream, reach, content) = if let Some(ns) = namespace {
+ match timeline_slice {
+ [n, "timeline", "public"] if n == ns => (Public, Federated, All),
+ [_, "timeline", "public"]
+ | ["timeline", "public"] => Err(RedisNamespaceMismatch)?,
+
+ [n, "timeline", "public", "local"] if ns == n => (Public, Local, All),
+ [_, "timeline", "public", "local"]
+ | ["timeline", "public", "local"] => Err(RedisNamespaceMismatch)?,
+
+ [n, "timeline", "public", "media"] if ns == n => (Public, Federated, Media),
+ [_, "timeline", "public", "media"]
+ | ["timeline", "public", "media"] => Err(RedisNamespaceMismatch)?,
+
+ [n, "timeline", "public", "local", "media"] if ns == n => (Public, Local, Media),
+ [_, "timeline", "public", "local", "media"]
+ | ["timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?,
+
+ [n, "timeline", "hashtag", tag_name] if ns == n => {
+ let tag_id = *cache
+ .get(&tag_name.to_string())
+ .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name));
+ (Hashtag(tag_id), Federated, All)
+ }
+ [_, "timeline", "hashtag", _tag]
+ | ["timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?,
+
+ [n, "timeline", "hashtag", _tag, "local"] if ns == n => (Hashtag(0), Local, All),
+ [_, "timeline", "hashtag", _tag, "local"]
+ | ["timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?,
+
+ [n, "timeline", id] if ns == n => (User(id.parse().unwrap()), Federated, All),
+ [_, "timeline", _id]
+ | ["timeline", _id] => Err(RedisNamespaceMismatch)?,
+
+ [n, "timeline", id, "notification"] if ns == n =>
+ (User(id.parse()?), Federated, Notification),
+
+ [_, "timeline", _id, "notification"]
+ | ["timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?,
+
+
+ [n, "timeline", "list", id] if ns == n => (List(id.parse()?), Federated, All),
+ [_, "timeline", "list", _id]
+ | ["timeline", "list", _id] => Err(RedisNamespaceMismatch)?,
+
+ [n, "timeline", "direct", id] if ns == n => (Direct(id.parse()?), Federated, All),
+ [_, "timeline", "direct", _id]
+ | ["timeline", "direct", _id] => Err(RedisNamespaceMismatch)?,
+
+ [..] => log_fatal!("Unexpected channel from Redis: {:?}", timeline_slice),
+ }
+ } else {
+ match timeline_slice {
+ ["timeline", "public"] => (Public, Federated, All),
+ [_, "timeline", "public"] => Err(RedisNamespaceMismatch)?,
+
+ ["timeline", "public", "local"] => (Public, Local, All),
+ [_, "timeline", "public", "local"] => Err(RedisNamespaceMismatch)?,
+
+ ["timeline", "public", "media"] => (Public, Federated, Media),
+
+ [_, "timeline", "public", "media"] => Err(RedisNamespaceMismatch)?,
+
+ ["timeline", "public", "local", "media"] => (Public, Local, Media),
+ [_, "timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?,
+
+ ["timeline", "hashtag", _tag] => (Hashtag(0), Federated, All),
+ [_, "timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?,
+
+ ["timeline", "hashtag", _tag, "local"] => (Hashtag(0), Local, All),
+ [_, "timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?,
+
+ ["timeline", id] => (User(id.parse().unwrap()), Federated, All),
+ [_, "timeline", _id] => Err(RedisNamespaceMismatch)?,
+
+ ["timeline", id, "notification"] => {
+ (User(id.parse().unwrap()), Federated, Notification)
+ }
+ [_, "timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?,
+
+ ["timeline", "list", id] => (List(id.parse().unwrap()), Federated, All),
+ [_, "timeline", "list", _id] => Err(RedisNamespaceMismatch)?,
+
+ ["timeline", "direct", id] => (Direct(id.parse().unwrap()), Federated, All),
+ [_, "timeline", "direct", _id] => Err(RedisNamespaceMismatch)?,
+
+ // Other endpoints don't exist:
+ [..] => Err(TimelineErr::InvalidInput)?,
+ }
+ };
+
+ Ok(Timeline(stream, reach, content))
+ }
+ fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result {
+ use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
+ let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag);
+ let user_owns_list = || pool.clone().user_owns_list(user.id, q.list);
+
+ Ok(match q.stream.as_ref() {
+ "public" => match q.media {
+ true => Timeline(Public, Federated, Media),
+ false => Timeline(Public, Federated, All),
+ },
+ "public:local" => match q.media {
+ true => Timeline(Public, Local, Media),
+ false => Timeline(Public, Local, All),
+ },
+ "public:media" => Timeline(Public, Federated, Media),
+ "public:local:media" => Timeline(Public, Local, Media),
+
+ "hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All),
+ "hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All),
+ "user" => match user.scopes.contains(&Statuses) {
+ true => Timeline(User(user.id), Federated, All),
+ false => Err(custom("Error: Missing access token"))?,
+ },
+ "user:notification" => match user.scopes.contains(&Statuses) {
+ true => Timeline(User(user.id), Federated, Notification),
+ false => Err(custom("Error: Missing access token"))?,
+ },
+ "list" => match user.scopes.contains(&Lists) && user_owns_list() {
+ true => Timeline(List(q.list), Federated, All),
+ false => Err(warp::reject::custom("Error: Missing access token"))?,
+ },
+ "direct" => match user.scopes.contains(&Statuses) {
+ true => Timeline(Direct(user.id), Federated, All),
+ false => Err(custom("Error: Missing access token"))?,
+ },
+ other => {
+ log::warn!("Request for nonexistent endpoint: `{}`", other);
+ Err(custom("Error: Nonexistent endpoint"))?
+ }
+ })
+ }
+}
+#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
+pub enum Stream {
+ User(i64),
+ List(i64),
+ Direct(i64),
+ Hashtag(i64),
+ Public,
+ Unset,
+}
+#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
+pub enum Reach {
+ Local,
+ Federated,
+}
+#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
+pub enum Content {
+ All,
+ Media,
+ Notification,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+pub enum Scope {
+ Read,
+ Statuses,
+ Notifications,
+ Lists,
+}
+
+#[derive(Clone, Default, Debug, PartialEq)]
+pub struct Blocks {
+ pub blocked_domains: HashSet,
+ pub blocked_users: HashSet,
+ pub blocking_users: HashSet,
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub struct UserData {
+ pub id: i64,
+ pub allowed_langs: HashSet,
+ pub scopes: HashSet,
+}
+
+impl UserData {
+ fn public() -> Self {
+ Self {
+ id: -1,
+ allowed_langs: HashSet::new(),
+ scopes: HashSet::new(),
+ }
+ }
+}
diff --git a/src/parse_client_request/subscription/mock_postgres.rs b/src/parse_client_request/subscription/mock_postgres.rs
deleted file mode 100644
index d5a9612..0000000
--- a/src/parse_client_request/subscription/mock_postgres.rs
+++ /dev/null
@@ -1,43 +0,0 @@
-//! Mock Postgres connection (for use in unit testing)
-use super::{OauthScope, Subscription};
-use std::collections::HashSet;
-
-#[derive(Clone)]
-pub struct PgPool;
-impl PgPool {
- pub fn new() -> Self {
- Self
- }
-}
-
-pub fn select_user(
- access_token: &str,
- _pg_pool: PgPool,
-) -> Result {
- let mut user = Subscription::default();
- if access_token == "TEST_USER" {
- user.id = 1;
- user.logged_in = true;
- user.access_token = "TEST_USER".to_string();
- user.email = "user@example.com".to_string();
- user.scopes = OauthScope::from(vec![
- "read".to_string(),
- "write".to_string(),
- "follow".to_string(),
- ]);
- } else if access_token == "INVALID" {
- return Err(warp::reject::custom("Error: Invalid access token"));
- }
- Ok(user)
-}
-
-pub fn select_user_blocks(_id: i64, _pg_pool: PgPool) -> HashSet {
- HashSet::new()
-}
-pub fn select_domain_blocks(_pg_pool: PgPool) -> HashSet {
- HashSet::new()
-}
-
-pub fn user_owns_list(user_id: i64, list_id: i64, _pg_pool: PgPool) -> bool {
- user_id == list_id
-}
diff --git a/src/parse_client_request/subscription/mod.rs b/src/parse_client_request/subscription/mod.rs
deleted file mode 100644
index 3fdb060..0000000
--- a/src/parse_client_request/subscription/mod.rs
+++ /dev/null
@@ -1,196 +0,0 @@
-//! `User` struct and related functionality
-// #[cfg(test)]
-// mod mock_postgres;
-// #[cfg(test)]
-// use mock_postgres as postgres;
-// #[cfg(not(test))]
-pub mod postgres;
-pub use self::postgres::PgPool;
-use super::query::Query;
-use crate::log_fatal;
-use std::collections::HashSet;
-use warp::reject::Rejection;
-
-/// The User (with data read from Postgres)
-#[derive(Clone, Debug, PartialEq)]
-pub struct Subscription {
- pub timeline: Timeline,
- pub allowed_langs: HashSet,
- pub blocks: Blocks,
-}
-
-impl Default for Subscription {
- fn default() -> Self {
- Self {
- timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification),
- allowed_langs: HashSet::new(),
- blocks: Blocks::default(),
- }
- }
-}
-
-impl Subscription {
- pub fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result {
- let user = match q.access_token.clone() {
- Some(token) => postgres::select_user(&token, pool.clone())?,
- None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?,
- None => UserData::public(),
- };
- Ok(Subscription {
- timeline: Timeline::from_query_and_user(&q, &user, pool.clone())?,
- allowed_langs: user.allowed_langs,
- blocks: Blocks {
- blocking_users: postgres::select_blocking_users(user.id, pool.clone()),
- blocked_users: postgres::select_blocked_users(user.id, pool.clone()),
- blocked_domains: postgres::select_blocked_domains(user.id, pool.clone()),
- },
- })
- }
-}
-
-#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
-pub struct Timeline(pub Stream, pub Reach, pub Content);
-
-impl Timeline {
- pub fn empty() -> Self {
- use {Content::*, Reach::*, Stream::*};
- Self(Unset, Local, Notification)
- }
-
- pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> String {
- use {Content::*, Reach::*, Stream::*};
- match self {
- Timeline(Public, Federated, All) => "timeline:public".into(),
- Timeline(Public, Local, All) => "timeline:public:local".into(),
- Timeline(Public, Federated, Media) => "timeline:public:media".into(),
- Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
-
- Timeline(Hashtag(id), Federated, All) => format!(
- "timeline:hashtag:{}",
- hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
- ),
- Timeline(Hashtag(id), Local, All) => format!(
- "timeline:hashtag:{}:local",
- hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id))
- ),
- Timeline(User(id), Federated, All) => format!("timeline:{}", id),
- Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
- Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
- Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
- Timeline(one, _two, _three) => {
- log_fatal!("Supposedly impossible timeline reached: {:?}", one)
- }
- }
- }
- pub fn from_redis_raw_timeline(raw_timeline: &str, hashtag: Option) -> Self {
- use {Content::*, Reach::*, Stream::*};
- match raw_timeline.split(':').collect::>()[..] {
- ["public"] => Timeline(Public, Federated, All),
- ["public", "local"] => Timeline(Public, Local, All),
- ["public", "media"] => Timeline(Public, Federated, Media),
- ["public", "local", "media"] => Timeline(Public, Local, Media),
-
- ["hashtag", _tag] => Timeline(Hashtag(hashtag.unwrap()), Federated, All),
- ["hashtag", _tag, "local"] => Timeline(Hashtag(hashtag.unwrap()), Local, All),
- [id] => Timeline(User(id.parse().unwrap()), Federated, All),
- [id, "notification"] => Timeline(User(id.parse().unwrap()), Federated, Notification),
- ["list", id] => Timeline(List(id.parse().unwrap()), Federated, All),
- ["direct", id] => Timeline(Direct(id.parse().unwrap()), Federated, All),
- // Other endpoints don't exist:
- [..] => log_fatal!("Unexpected channel from Redis: {}", raw_timeline),
- }
- }
- fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result {
- use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
- let id_from_hashtag = || postgres::select_list_id(&q.hashtag, pool.clone());
- let user_owns_list = || postgres::user_owns_list(user.id, q.list, pool.clone());
-
- Ok(match q.stream.as_ref() {
- "public" => match q.media {
- true => Timeline(Public, Federated, Media),
- false => Timeline(Public, Federated, All),
- },
- "public:local" => match q.media {
- true => Timeline(Public, Local, Media),
- false => Timeline(Public, Local, All),
- },
- "public:media" => Timeline(Public, Federated, Media),
- "public:local:media" => Timeline(Public, Local, Media),
-
- "hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All),
- "hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All),
- "user" => match user.scopes.contains(&Statuses) {
- true => Timeline(User(user.id), Federated, All),
- false => Err(custom("Error: Missing access token"))?,
- },
- "user:notification" => match user.scopes.contains(&Statuses) {
- true => Timeline(User(user.id), Federated, Notification),
- false => Err(custom("Error: Missing access token"))?,
- },
- "list" => match user.scopes.contains(&Lists) && user_owns_list() {
- true => Timeline(List(q.list), Federated, All),
- false => Err(warp::reject::custom("Error: Missing access token"))?,
- },
- "direct" => match user.scopes.contains(&Statuses) {
- true => Timeline(Direct(user.id), Federated, All),
- false => Err(custom("Error: Missing access token"))?,
- },
- other => {
- log::warn!("Request for nonexistent endpoint: `{}`", other);
- Err(custom("Error: Nonexistent endpoint"))?
- }
- })
- }
-}
-#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
-pub enum Stream {
- User(i64),
- List(i64),
- Direct(i64),
- Hashtag(i64),
- Public,
- Unset,
-}
-#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
-pub enum Reach {
- Local,
- Federated,
-}
-#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
-pub enum Content {
- All,
- Media,
- Notification,
-}
-
-#[derive(Clone, Debug, PartialEq, Eq, Hash)]
-pub enum Scope {
- Read,
- Statuses,
- Notifications,
- Lists,
-}
-
-#[derive(Clone, Default, Debug, PartialEq)]
-pub struct Blocks {
- pub blocked_domains: HashSet,
- pub blocked_users: HashSet,
- pub blocking_users: HashSet,
-}
-
-#[derive(Clone, Debug, PartialEq)]
-pub struct UserData {
- id: i64,
- allowed_langs: HashSet,
- scopes: HashSet,
-}
-
-impl UserData {
- fn public() -> Self {
- Self {
- id: -1,
- allowed_langs: HashSet::new(),
- scopes: HashSet::new(),
- }
- }
-}
diff --git a/src/parse_client_request/subscription/postgres.rs b/src/parse_client_request/subscription/postgres.rs
deleted file mode 100644
index 9353efa..0000000
--- a/src/parse_client_request/subscription/postgres.rs
+++ /dev/null
@@ -1,225 +0,0 @@
-//! Postgres queries
-use crate::{
- config,
- parse_client_request::subscription::{Scope, UserData},
-};
-use ::postgres;
-use r2d2_postgres::PostgresConnectionManager;
-use std::collections::HashSet;
-use warp::reject::Rejection;
-
-#[derive(Clone, Debug)]
-pub struct PgPool(pub r2d2::Pool>);
-impl PgPool {
- pub fn new(pg_cfg: config::PostgresConfig) -> Self {
- let mut cfg = postgres::Config::new();
- cfg.user(&pg_cfg.user)
- .host(&*pg_cfg.host.to_string())
- .port(*pg_cfg.port)
- .dbname(&pg_cfg.database);
- if let Some(password) = &*pg_cfg.password {
- cfg.password(password);
- };
-
- let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
- let pool = r2d2::Pool::builder()
- .max_size(10)
- .build(manager)
- .expect("Can connect to local postgres");
- Self(pool)
- }
-}
-
-pub fn select_user(token: &str, pool: PgPool) -> Result {
- let mut conn = pool.0.get().unwrap();
- let query_rows = conn
- .query(
- "
-SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
-FROM
-oauth_access_tokens
-INNER JOIN users ON
-oauth_access_tokens.resource_owner_id = users.id
-WHERE oauth_access_tokens.token = $1
-AND oauth_access_tokens.revoked_at IS NULL
-LIMIT 1",
- &[&token.to_owned()],
- )
- .expect("Hard-coded query will return Some([0 or more rows])");
- if let Some(result_columns) = query_rows.get(0) {
- let id = result_columns.get(1);
- let allowed_langs = result_columns
- .try_get::<_, Vec<_>>(2)
- .unwrap_or_else(|_| Vec::new())
- .into_iter()
- .collect();
- let mut scopes: HashSet = result_columns
- .get::<_, String>(3)
- .split(' ')
- .filter_map(|scope| match scope {
- "read" => Some(Scope::Read),
- "read:statuses" => Some(Scope::Statuses),
- "read:notifications" => Some(Scope::Notifications),
- "read:lists" => Some(Scope::Lists),
- "write" | "follow" => None, // ignore write scopes
- unexpected => {
- log::warn!("Ignoring unknown scope `{}`", unexpected);
- None
- }
- })
- .collect();
- // We don't need to separately track read auth - it's just all three others
- if scopes.remove(&Scope::Read) {
- scopes.insert(Scope::Statuses);
- scopes.insert(Scope::Notifications);
- scopes.insert(Scope::Lists);
- }
-
- Ok(UserData {
- id,
- allowed_langs,
- scopes,
- })
- } else {
- Err(warp::reject::custom("Error: Invalid access token"))
- }
-}
-
-pub fn select_list_id(tag_name: &String, pg_pool: PgPool) -> Result {
- let mut conn = pg_pool.0.get().unwrap();
- // For the Postgres query, `id` = list number; `account_id` = user.id
- let rows = &conn
- .query(
- "
-SELECT id
-FROM tags
-WHERE name = $1
-LIMIT 1",
- &[&tag_name],
- )
- .expect("Hard-coded query will return Some([0 or more rows])");
-
- match rows.get(0) {
- Some(row) => Ok(row.get(0)),
- None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
- }
-}
-pub fn select_hashtag_name(tag_id: &i64, pg_pool: PgPool) -> Result {
- let mut conn = pg_pool.0.get().unwrap();
- // For the Postgres query, `id` = list number; `account_id` = user.id
- let rows = &conn
- .query(
- "
-SELECT name
-FROM tags
-WHERE id = $1
-LIMIT 1",
- &[&tag_id],
- )
- .expect("Hard-coded query will return Some([0 or more rows])");
-
- match rows.get(0) {
- Some(row) => Ok(row.get(0)),
- None => Err(warp::reject::custom("Error: Hashtag does not exist.")),
- }
-}
-
-/// Query Postgres for everyone the user has blocked or muted
-///
-/// **NOTE**: because we check this when the user connects, it will not include any blocks
-/// the user adds until they refresh/reconnect.
-pub fn select_blocked_users(user_id: i64, pg_pool: PgPool) -> HashSet {
- // "
- // SELECT
- // 1
- // FROM blocks
- // WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)}))
- // OR (account_id = $2 AND target_account_id = $1)
- // UNION SELECT
- // 1
- // FROM mutes
- // WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})`
- // , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`"
- pg_pool
- .0
- .get()
- .unwrap()
- .query(
- "
-SELECT target_account_id
- FROM blocks
- WHERE account_id = $1
-UNION SELECT target_account_id
- FROM mutes
- WHERE account_id = $1",
- &[&user_id],
- )
- .expect("Hard-coded query will return Some([0 or more rows])")
- .iter()
- .map(|row| row.get(0))
- .collect()
-}
-/// Query Postgres for everyone who has blocked the user
-///
-/// **NOTE**: because we check this when the user connects, it will not include any blocks
-/// the user adds until they refresh/reconnect.
-pub fn select_blocking_users(user_id: i64, pg_pool: PgPool) -> HashSet {
- pg_pool
- .0
- .get()
- .unwrap()
- .query(
- "
-SELECT account_id
- FROM blocks
- WHERE target_account_id = $1",
- &[&user_id],
- )
- .expect("Hard-coded query will return Some([0 or more rows])")
- .iter()
- .map(|row| row.get(0))
- .collect()
-}
-
-/// Query Postgres for all current domain blocks
-///
-/// **NOTE**: because we check this when the user connects, it will not include any blocks
-/// the user adds until they refresh/reconnect.
-pub fn select_blocked_domains(user_id: i64, pg_pool: PgPool) -> HashSet {
- pg_pool
- .0
- .get()
- .unwrap()
- .query(
- "SELECT domain FROM account_domain_blocks WHERE account_id = $1",
- &[&user_id],
- )
- .expect("Hard-coded query will return Some([0 or more rows])")
- .iter()
- .map(|row| row.get(0))
- .collect()
-}
-
-/// Test whether a user owns a list
-pub fn user_owns_list(user_id: i64, list_id: i64, pg_pool: PgPool) -> bool {
- let mut conn = pg_pool.0.get().unwrap();
- // For the Postgres query, `id` = list number; `account_id` = user.id
- let rows = &conn
- .query(
- "
-SELECT id, account_id
-FROM lists
-WHERE id = $1
-LIMIT 1",
- &[&list_id],
- )
- .expect("Hard-coded query will return Some([0 or more rows])");
-
- match rows.get(0) {
- None => false,
- Some(row) => {
- let list_owner_id: i64 = row.get(1);
- list_owner_id == user_id
- }
- }
-}
diff --git a/src/parse_client_request/subscription/stdin b/src/parse_client_request/subscription/stdin
deleted file mode 100644
index e69de29..0000000
diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs
index aa74c60..e101bed 100644
--- a/src/parse_client_request/ws.rs
+++ b/src/parse_client_request/ws.rs
@@ -1,48 +1,9 @@
//! Filters for the WebSocket endpoint
-use super::{
- query::{self, Query},
- subscription::{PgPool, Subscription},
-};
-use warp::{filters::BoxedFilter, path, Filter};
-/// WebSocket filters
-fn parse_query() -> BoxedFilter<(Query,)> {
- path!("api" / "v1" / "streaming")
- .and(path::end())
- .and(warp::query())
- .and(query::Auth::to_filter())
- .and(query::Media::to_filter())
- .and(query::Hashtag::to_filter())
- .and(query::List::to_filter())
- .map(
- |stream: query::Stream,
- auth: query::Auth,
- media: query::Media,
- hashtag: query::Hashtag,
- list: query::List| {
- Query {
- access_token: auth.access_token,
- stream: stream.stream,
- media: media.is_truthy(),
- hashtag: hashtag.tag,
- list: list.list,
- }
- },
- )
- .boxed()
-}
-pub fn extract_user_and_token_or_reject(
- pg_pool: PgPool,
- whitelist_mode: bool,
-) -> BoxedFilter<(Subscription, Option)> {
- parse_query()
- .and(query::OptionalAccessToken::from_ws_header())
- .and_then(Query::update_access_token)
- .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode))
- .and(query::OptionalAccessToken::from_ws_header())
- .boxed()
-}
+
+
+
// #[cfg(test)]
// mod test {
diff --git a/src/redis_to_client_stream/client_agent.rs b/src/redis_to_client_stream/client_agent.rs
index 0162598..18c79d3 100644
--- a/src/redis_to_client_stream/client_agent.rs
+++ b/src/redis_to_client_stream/client_agent.rs
@@ -19,29 +19,29 @@ use super::receiver::Receiver;
use crate::{
config,
messages::Event,
- parse_client_request::subscription::{PgPool, Stream::Public, Subscription, Timeline},
+ parse_client_request::{Stream::Public, Subscription, Timeline},
};
use futures::{
Async::{self, NotReady, Ready},
Poll,
};
-use std::sync;
+use std::sync::{Arc, Mutex};
use tokio::io::Error;
use uuid::Uuid;
/// Struct for managing all Redis streams.
#[derive(Clone, Debug)]
pub struct ClientAgent {
- receiver: sync::Arc<sync::Mutex<Receiver>>,
- id: uuid::Uuid,
+ receiver: Arc<Mutex<Receiver>>,
+ id: Uuid,
pub subscription: Subscription,
}
impl ClientAgent {
/// Create a new `ClientAgent` with no shared data.
- pub fn blank(redis_cfg: config::RedisConfig, pg_pool: PgPool) -> Self {
+ pub fn blank(redis_cfg: config::RedisConfig) -> Self {
ClientAgent {
- receiver: sync::Arc::new(sync::Mutex::new(Receiver::new(redis_cfg, pg_pool))),
+ receiver: Arc::new(Mutex::new(Receiver::new(redis_cfg))),
id: Uuid::default(),
subscription: Subscription::default(),
}
@@ -70,7 +70,11 @@ impl ClientAgent {
self.subscription = subscription;
let start_time = Instant::now();
let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)");
- receiver.manage_new_timeline(self.id, self.subscription.timeline);
+ receiver.manage_new_timeline(
+ self.id,
+ self.subscription.timeline,
+ self.subscription.hashtag_name.clone(),
+ );
log::info!("init_for_user had lock for: {:?}", start_time.elapsed());
}
}
diff --git a/src/redis_to_client_stream/event_stream.rs b/src/redis_to_client_stream/event_stream.rs
new file mode 100644
index 0000000..5157120
--- /dev/null
+++ b/src/redis_to_client_stream/event_stream.rs
@@ -0,0 +1,103 @@
+use super::ClientAgent;
+
+use warp::ws::WebSocket;
+use futures::{future::Future, stream::Stream, Async};
+use log;
+use std::time::{Duration, Instant};
+
+pub struct EventStream;
+
+impl EventStream {
+
+
+/// Send a stream of replies to a WebSocket client.
+ pub fn to_ws(
+ socket: WebSocket,
+ mut client_agent: ClientAgent,
+ update_interval: Duration,
+) -> impl Future<Item = (), Error = ()> {
+ let (ws_tx, mut ws_rx) = socket.split();
+ let timeline = client_agent.subscription.timeline;
+
+ // Create a pipe
+ let (tx, rx) = futures::sync::mpsc::unbounded();
+
+ // Send one end of it to a different thread and tell that end to forward whatever it gets
+ // on to the websocket client
+ warp::spawn(
+ rx.map_err(|()| -> warp::Error { unreachable!() })
+ .forward(ws_tx)
+ .map(|_r| ())
+ .map_err(|e| match e.to_string().as_ref() {
+ "IO error: Broken pipe (os error 32)" => (), // just closed unix socket
+ _ => log::warn!("websocket send error: {}", e),
+ }),
+ );
+
+ // Yield new events for as long as the client is still connected
+ let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while(
+ move |_| match ws_rx.poll() {
+ Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true),
+ Ok(Async::Ready(None)) => {
+ // TODO: consider whether we should manually drop closed connections here
+ log::info!("Client closed WebSocket connection for {:?}", timeline);
+ futures::future::ok(false)
+ }
+ Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => {
+ // no err, just closed Unix socket
+ log::info!("Client closed WebSocket connection for {:?}", timeline);
+ futures::future::ok(false)
+ }
+ Err(e) => {
+ log::warn!("Error in {:?}: {}", timeline, e);
+ futures::future::ok(false)
+ }
+ },
+ );
+
+ let mut time = Instant::now();
+ // Every time you get an event from that stream, send it through the pipe
+ event_stream
+ .for_each(move |_instant| {
+ if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
+ tx.unbounded_send(warp::ws::Message::text(msg.to_json_string()))
+ .expect("No send error");
+ };
+ if time.elapsed() > Duration::from_secs(30) {
+ tx.unbounded_send(warp::ws::Message::text("{}"))
+ .expect("Can ping");
+ time = Instant::now();
+ }
+ Ok(())
+ })
+ .then(move |result| {
+ // TODO: consider whether we should manually drop closed connections here
+ log::info!("WebSocket connection for {:?} closed.", timeline);
+ result
+ })
+ .map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e))
+}
+ pub fn to_sse(
+ mut client_agent: ClientAgent,
+ connection: warp::sse::Sse,
+ update_interval: Duration,
+) ->impl warp::reply::Reply {
+ let event_stream =
+ tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| {
+ match client_agent.poll() {
+ Ok(Async::Ready(Some(event))) => Some((
+ warp::sse::event(event.event_name()),
+ warp::sse::data(event.payload().unwrap_or_else(String::new)),
+ )),
+ _ => None,
+ }
+ });
+
+ connection.reply(
+ warp::sse::keep_alive()
+ .interval(Duration::from_secs(30))
+ .text("thump".to_string())
+ .stream(event_stream),
+ )
+}
+}
diff --git a/src/redis_to_client_stream/message.rs b/src/redis_to_client_stream/message.rs
deleted file mode 100644
index 10cd1d6..0000000
--- a/src/redis_to_client_stream/message.rs
+++ /dev/null
@@ -1,87 +0,0 @@
-use crate::log_fatal;
-use crate::messages::Event;
-use serde_json::Value;
-use std::{collections::HashSet, string::String};
-use strum_macros::Display;
-
-#[derive(Debug, Display, Clone)]
-pub enum Message {
- Update(Status),
- Conversation(Value),
- Notification(Value),
- Delete(String),
- FiltersChanged,
- Announcement(AnnouncementType),
- UnknownEvent(String, Value),
-}
-
-#[derive(Debug, Clone)]
-pub struct Status(Value);
-
-#[derive(Debug, Clone)]
-pub enum AnnouncementType {
- New(Value),
- Delete(String),
- Reaction(Value),
-}
-
-impl Message {
- // pub fn from_json(event: Event) -> Self {
- // use AnnouncementType::*;
-
- // match event.event.as_ref() {
- // "update" => Self::Update(Status(event.payload)),
- // "conversation" => Self::Conversation(event.payload),
- // "notification" => Self::Notification(event.payload),
- // "delete" => Self::Delete(
- // event
- // .payload
- // .as_str()
- // .unwrap_or_else(|| log_fatal!("Could not process `payload` in {:?}", event))
- // .to_string(),
- // ),
- // "filters_changed" => Self::FiltersChanged,
- // "announcement" => Self::Announcement(New(event.payload)),
- // "announcement.reaction" => Self::Announcement(Reaction(event.payload)),
- // "announcement.delete" => Self::Announcement(Delete(
- // event
- // .payload
- // .as_str()
- // .unwrap_or_else(|| log_fatal!("Could not process `payload` in {:?}", event))
- // .to_string(),
- // )),
- // other => {
- // log::warn!("Received unexpected `event` from Redis: {}", other);
- // Self::UnknownEvent(event.event.to_string(), event.payload)
- // }
- // }
- // }
- pub fn event(&self) -> String {
- use AnnouncementType::*;
- match self {
- Self::Update(_) => "update",
- Self::Conversation(_) => "conversation",
- Self::Notification(_) => "notification",
- Self::Announcement(New(_)) => "announcement",
- Self::Announcement(Reaction(_)) => "announcement.reaction",
- Self::UnknownEvent(event, _) => &event,
- Self::Delete(_) => "delete",
- Self::Announcement(Delete(_)) => "announcement.delete",
- Self::FiltersChanged => "filters_changed",
- }
- .to_string()
- }
- pub fn payload(&self) -> String {
- use AnnouncementType::*;
- match self {
- Self::Update(status) => status.0.to_string(),
- Self::Conversation(value)
- | Self::Notification(value)
- | Self::Announcement(New(value))
- | Self::Announcement(Reaction(value))
- | Self::UnknownEvent(_, value) => value.to_string(),
- Self::Delete(id) | Self::Announcement(Delete(id)) => id.clone(),
- Self::FiltersChanged => "".to_string(),
- }
- }
-}
diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs
index 5295073..e4cdfb5 100644
--- a/src/redis_to_client_stream/mod.rs
+++ b/src/redis_to_client_stream/mod.rs
@@ -1,103 +1,14 @@
//! Stream the updates appropriate for a given `User`/`timeline` pair from Redis.
-pub mod client_agent;
-pub mod receiver;
-pub mod redis;
-pub use client_agent::ClientAgent;
-use futures::{future::Future, stream::Stream, Async};
-use log;
-use std::time::{Duration, Instant};
+mod client_agent;
+mod event_stream;
+mod receiver;
+mod redis;
-/// Send a stream of replies to a Server Sent Events client.
-pub fn send_updates_to_sse(
- mut client_agent: ClientAgent,
- connection: warp::sse::Sse,
- update_interval: Duration,
-) -> impl warp::reply::Reply {
- let event_stream =
- tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| {
- match client_agent.poll() {
- Ok(Async::Ready(Some(event))) => Some((
- warp::sse::event(event.event_name()),
- warp::sse::data(event.payload().unwrap_or_else(String::new)),
- )),
- _ => None,
- }
- });
+pub use {client_agent::ClientAgent, event_stream::EventStream};
- connection.reply(
- warp::sse::keep_alive()
- .interval(Duration::from_secs(30))
- .text("thump".to_string())
- .stream(event_stream),
- )
-}
-
-use warp::ws::WebSocket;
-
-/// Send a stream of replies to a WebSocket client.
-pub fn send_updates_to_ws(
- socket: WebSocket,
- mut client_agent: ClientAgent,
- update_interval: Duration,
-) -> impl Future<Item = (), Error = ()> {
- let (ws_tx, mut ws_rx) = socket.split();
- let timeline = client_agent.subscription.timeline;
-
- // Create a pipe
- let (tx, rx) = futures::sync::mpsc::unbounded();
-
- // Send one end of it to a different thread and tell that end to forward whatever it gets
- // on to the websocket client
- warp::spawn(
- rx.map_err(|()| -> warp::Error { unreachable!() })
- .forward(ws_tx)
- .map(|_r| ())
- .map_err(|e| match e.to_string().as_ref() {
- "IO error: Broken pipe (os error 32)" => (), // just closed unix socket
- _ => log::warn!("websocket send error: {}", e),
- }),
- );
-
- // Yield new events for as long as the client is still connected
- let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while(
- move |_| match ws_rx.poll() {
- Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true),
- Ok(Async::Ready(None)) => {
- // TODO: consider whether we should manually drop closed connections here
- log::info!("Client closed WebSocket connection for {:?}", timeline);
- futures::future::ok(false)
- }
- Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => {
- // no err, just closed Unix socket
- log::info!("Client closed WebSocket connection for {:?}", timeline);
- futures::future::ok(false)
- }
- Err(e) => {
- log::warn!("Error in {:?}: {}", timeline, e);
- futures::future::ok(false)
- }
- },
- );
-
- let mut time = Instant::now();
- // Every time you get an event from that stream, send it through the pipe
- event_stream
- .for_each(move |_instant| {
- if let Ok(Async::Ready(Some(msg))) = client_agent.poll() {
- tx.unbounded_send(warp::ws::Message::text(msg.to_json_string()))
- .expect("No send error");
- };
- if time.elapsed() > Duration::from_secs(30) {
- tx.unbounded_send(warp::ws::Message::text("{}"))
- .expect("Can ping");
- time = Instant::now();
- }
- Ok(())
- })
- .then(move |result| {
- // TODO: consider whether we should manually drop closed connections here
- log::info!("WebSocket connection for {:?} closed.", timeline);
- result
- })
- .map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e))
-}
+#[cfg(test)]
+pub use receiver::process_messages;
+#[cfg(test)]
+pub use receiver::{MessageQueues, MsgQueue};
+#[cfg(test)]
+pub use redis::redis_msg::RedisMsg;
diff --git a/src/redis_to_client_stream/receiver/message_queues.rs b/src/redis_to_client_stream/receiver/message_queues.rs
index ffcd31e..eb45b30 100644
--- a/src/redis_to_client_stream/receiver/message_queues.rs
+++ b/src/redis_to_client_stream/receiver/message_queues.rs
@@ -1,5 +1,5 @@
use crate::messages::Event;
-use crate::parse_client_request::subscription::Timeline;
+use crate::parse_client_request::Timeline;
use std::{
collections::{HashMap, VecDeque},
fmt,
diff --git a/src/redis_to_client_stream/receiver/mod.rs b/src/redis_to_client_stream/receiver/mod.rs
index 364d1fd..87b35ad 100644
--- a/src/redis_to_client_stream/receiver/mod.rs
+++ b/src/redis_to_client_stream/receiver/mod.rs
@@ -2,62 +2,70 @@
//! polled by the correct `ClientAgent`. Also manages sububscriptions and
//! unsubscriptions to/from Redis.
mod message_queues;
+
+pub use message_queues::{MessageQueues, MsgQueue};
+
use crate::{
- config::{self, RedisInterval},
- log_fatal,
+ config,
+ err::RedisParseErr,
messages::Event,
- parse_client_request::subscription::{self, postgres, PgPool, Timeline},
+ parse_client_request::{Stream, Timeline},
pubsub_cmd,
- redis_to_client_stream::redis::{redis_cmd, RedisConn, RedisStream},
+ redis_to_client_stream::redis::redis_msg::RedisMsg,
+ redis_to_client_stream::redis::{redis_cmd, RedisConn},
};
use futures::{Async, Poll};
use lru::LruCache;
-pub use message_queues::{MessageQueues, MsgQueue};
-use std::{collections::HashMap, net, time::Instant};
+use tokio::io::AsyncRead;
+
+use std::{
+ collections::HashMap,
+ io::Read,
+ net, str,
+ time::{Duration, Instant},
+};
use tokio::io::Error;
use uuid::Uuid;
/// The item that streams from Redis and is polled by the `ClientAgent`
#[derive(Debug)]
pub struct Receiver {
- pub pubsub_connection: RedisStream,
+ pub pubsub_connection: net::TcpStream,
secondary_redis_connection: net::TcpStream,
- redis_poll_interval: RedisInterval,
+ redis_poll_interval: Duration,
redis_polled_at: Instant,
timeline: Timeline,
manager_id: Uuid,
pub msg_queues: MessageQueues,
 clients_per_timeline: HashMap<Timeline, i32>,
cache: Cache,
- pool: PgPool,
+ redis_input: Vec<u8>,
+ redis_namespace: Option<String>,
}
+
#[derive(Debug)]
pub struct Cache {
+ // TODO: eventually, it might make sense to have Mastodon publish to timelines with
+ // the tag number instead of the tag name. This would save us from dealing
+ // with a cache here and would be consistent with how lists/users are handled.
 id_to_hashtag: LruCache<i64, String>,
 pub hashtag_to_id: LruCache<String, i64>,
}
-impl Cache {
- fn new(size: usize) -> Self {
- Self {
- id_to_hashtag: LruCache::new(size),
- hashtag_to_id: LruCache::new(size),
- }
- }
-}
+
impl Receiver {
/// Create a new `Receiver`, with its own Redis connections (but, as yet, no
/// active subscriptions).
- pub fn new(redis_cfg: config::RedisConfig, pool: PgPool) -> Self {
+ pub fn new(redis_cfg: config::RedisConfig) -> Self {
+ let redis_namespace = redis_cfg.namespace.clone();
+
let RedisConn {
primary: pubsub_connection,
secondary: secondary_redis_connection,
- namespace: redis_namespace,
polling_interval: redis_poll_interval,
} = RedisConn::new(redis_cfg);
Self {
- pubsub_connection: RedisStream::from_stream(pubsub_connection)
- .with_namespace(redis_namespace),
+ pubsub_connection,
secondary_redis_connection,
redis_poll_interval,
redis_polled_at: Instant::now(),
@@ -65,8 +73,12 @@ impl Receiver {
manager_id: Uuid::default(),
msg_queues: MessageQueues(HashMap::new()),
clients_per_timeline: HashMap::new(),
- cache: Cache::new(1000), // should this be a run-time option?
- pool,
+ cache: Cache {
+ id_to_hashtag: LruCache::new(1000),
+ hashtag_to_id: LruCache::new(1000),
+ }, // should these be run-time options?
+ redis_input: Vec::new(),
+ redis_namespace,
}
}
@@ -76,12 +88,15 @@ impl Receiver {
/// Note: this method calls `subscribe_or_unsubscribe_as_needed`,
/// so Redis PubSub subscriptions are only updated when a new timeline
/// comes under management for the first time.
- pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: Timeline) {
- self.manager_id = manager_id;
- self.timeline = timeline;
- self.msg_queues
- .insert(self.manager_id, MsgQueue::new(timeline));
- self.subscribe_or_unsubscribe_as_needed(timeline);
+ pub fn manage_new_timeline(&mut self, id: Uuid, tl: Timeline, hashtag: Option<String>) {
+ self.timeline = tl;
+ if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (hashtag, tl) {
+ self.cache.id_to_hashtag.put(id, hashtag.clone());
+ self.cache.hashtag_to_id.put(hashtag, id);
+ };
+
+ self.msg_queues.insert(id, MsgQueue::new(tl));
+ self.subscribe_or_unsubscribe_as_needed(tl);
}
/// Set the `Receiver`'s manager_id and target_timeline fields to the appropriate
@@ -91,26 +106,6 @@ impl Receiver {
self.timeline = timeline;
}
- fn if_hashtag_timeline_get_hashtag_name(&mut self, timeline: Timeline) -> Option<String> {
- use subscription::Stream::*;
- if let Timeline(Hashtag(id), _, _) = timeline {
- let cached_tag = self.cache.id_to_hashtag.get(&id).map(String::from);
- let tag = match cached_tag {
- Some(tag) => tag,
- None => {
- let new_tag = postgres::select_hashtag_name(&id, self.pool.clone())
- .unwrap_or_else(|_| log_fatal!("No hashtag associated with tag #{}", &id));
- self.cache.hashtag_to_id.put(new_tag.clone(), id);
- self.cache.id_to_hashtag.put(id, new_tag.clone());
- new_tag.to_string()
- }
- };
- Some(tag)
- } else {
- None
- }
- }
-
/// Drop any PubSub subscriptions that don't have active clients and check
/// that there's a subscription to the current one. If there isn't, then
/// subscribe to it.
@@ -121,8 +116,10 @@ impl Receiver {
// Record the lower number of clients subscribed to that channel
for change in timelines_to_modify {
let timeline = change.timeline;
- let hashtag = self.if_hashtag_timeline_get_hashtag_name(timeline);
- let hashtag = hashtag.as_ref();
+ let hashtag = match timeline {
+ Timeline(Stream::Hashtag(id), _, _) => self.cache.id_to_hashtag.get(&id),
+ _non_hashtag_timeline => None,
+ };
let count_of_subscribed_clients = self
.clients_per_timeline
@@ -157,10 +154,27 @@ impl futures::stream::Stream for Receiver {
fn poll(&mut self) -> Poll