From 0acbde3eee2a8d536ec769699e5a15a34966a254 Mon Sep 17 00:00:00 2001 From: Daniel Sockwell Date: Fri, 27 Mar 2020 12:00:48 -0400 Subject: [PATCH] Reorganize code, pt1 (#110) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Prevent Receiver from querying postgres Before this commit, the Receiver would query Postgres for the name associated with a hashtag when it encountered one not in its cache. This ensured that the Receiver never encountered a (valid) hashtag id that it couldn't handle, but caused an extra DB query and made independent sections of the code more entangled than they need to be. Now, we pass the relevant tag name to the Receiver when it first starts managing a new subscription and it adds the tag name to its cache then. * Improve module boundary/privacy * Reorganize Receiver to cut RedisStream * Fix tests for code reorganization Note that this change includes testing some private functionality by exposing it publicly in tests via conditional compilation. This doesn't expose that functionality for the benchmarks, so the benchmark tests do not currently pass without adding a few `pub use` statements. This might be worth changing later, but benchmark tests aren't part of our CI and it's not hard to change when we want to test performance. This change also cuts the benchmark tests that were benchmarking old ways Flodgatt functioned. Those were useful for comparison purposes, but have served their purpose – we've firmly moved away from the older/slower approach. 
* Fix Receiver for tests --- Cargo.lock | 2 +- benches/parse_redis.rs | 304 +------------- src/config/deployment_cfg_types.rs | 16 +- src/config/environmental_variables.rs | 137 ++++++ src/config/mod.rs | 132 +----- src/config/postgres_cfg_types.rs | 12 +- src/config/redis_cfg_types.rs | 14 +- src/err.rs | 70 +--- src/main.rs | 104 ++--- src/messages.rs | 83 ++-- src/parse_client_request/mod.rs | 18 +- src/parse_client_request/postgres.rs | 204 +++++++++ src/parse_client_request/query.rs | 22 +- src/parse_client_request/sse.rs | 74 ---- src/parse_client_request/subscription.rs | 396 ++++++++++++++++++ .../subscription/mock_postgres.rs | 43 -- src/parse_client_request/subscription/mod.rs | 196 --------- .../subscription/postgres.rs | 225 ---------- src/parse_client_request/subscription/stdin | 0 src/parse_client_request/ws.rs | 45 +- src/redis_to_client_stream/client_agent.rs | 18 +- src/redis_to_client_stream/event_stream.rs | 103 +++++ src/redis_to_client_stream/message.rs | 87 ---- src/redis_to_client_stream/mod.rs | 111 +---- .../receiver/message_queues.rs | 2 +- src/redis_to_client_stream/receiver/mod.rs | 170 +++++--- src/redis_to_client_stream/redis/mod.rs | 4 - src/redis_to_client_stream/redis/redis_cmd.rs | 2 +- .../redis/redis_connection.rs | 12 +- src/redis_to_client_stream/redis/redis_msg.rs | 52 +-- .../redis/redis_stream.rs | 127 ------ 31 files changed, 1181 insertions(+), 1604 deletions(-) create mode 100644 src/config/environmental_variables.rs create mode 100644 src/parse_client_request/postgres.rs create mode 100644 src/parse_client_request/subscription.rs delete mode 100644 src/parse_client_request/subscription/mock_postgres.rs delete mode 100644 src/parse_client_request/subscription/mod.rs delete mode 100644 src/parse_client_request/subscription/postgres.rs delete mode 100644 src/parse_client_request/subscription/stdin create mode 100644 src/redis_to_client_stream/event_stream.rs delete mode 100644 src/redis_to_client_stream/message.rs delete 
mode 100644 src/redis_to_client_stream/redis/redis_stream.rs diff --git a/Cargo.lock b/Cargo.lock index e33dbe8..7cb4f27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -440,7 +440,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.6.4" +version = "0.6.5" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/benches/parse_redis.rs b/benches/parse_redis.rs index a683583..c1fd987 100644 --- a/benches/parse_redis.rs +++ b/benches/parse_redis.rs @@ -3,95 +3,27 @@ use criterion::criterion_group; use criterion::criterion_main; use criterion::Criterion; -const ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS: &str = "*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:1\r\n$3790\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102775370117886890\",\"created_at\":\"2019-09-11T18:42:19.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"unlisted\",\"language\":\"en\",\"uri\":\"https://mastodon.host/users/federationbot/statuses/102775346916917099\",\"url\":\"https://mastodon.host/@federationbot/102775346916917099\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"

Trending tags:
#neverforget
#4styles
#newpipe
#uber
#mercredifiction

\",\"reblog\":null,\"account\":{\"id\":\"78\",\"username\":\"federationbot\",\"acct\":\"federationbot@mastodon.host\",\"display_name\":\"Federation Bot\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-09-10T15:04:25.559Z\",\"note\":\"

Hello, I am mastodon.host official semi bot.

Follow me if you want to have some updates on the view of the fediverse from here ( I only post unlisted ).

I also randomly boost one of my followers toot every hour !

If you don\'t feel confortable with me following you, tell me: unfollow and I\'ll do it :)

If you want me to follow you, just tell me follow !

If you want automatic follow for new users on your instance and you are an instance admin, contact me !

Other commands are private :)

\",\"url\":\"https://mastodon.host/@federationbot\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":16636,\"following_count\":179532,\"statuses_count\":50554,\"emojis\":[],\"fields\":[{\"name\":\"More stats\",\"value\":\"https://mastodon.host/stats.html\",\"verified_at\":null},{\"name\":\"More infos\",\"value\":\"https://mastodon.host/about/more\",\"verified_at\":null},{\"name\":\"Owner/Friend\",\"value\":\"@gled\",\"verified_at\":null}]},\"media_attachments\":[],\"mentions\":[],\"tags\":[{\"name\":\"4styles\",\"url\":\"https://instance.codesections.com/tags/4styles\"},{\"name\":\"neverforget\",\"url\":\"https://instance.codesections.com/tags/neverforget\"},{\"name\":\"mercredifiction\",\"url\":\"https://instance.codesections.com/tags/mercredifiction\"},{\"name\":\"uber\",\"url\":\"https://instance.codesections.com/tags/uber\"},{\"name\":\"newpipe\",\"url\":\"https://instance.codesections.com/tags/newpipe\"}],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1568227693541}\r\n"; - -/// Parses the Redis message using a Regex. -/// -/// The naive approach from Flodgatt's proof-of-concept stage. 
-mod regex_parse { - use regex::Regex; - use serde_json::Value; - - pub fn to_json_value(input: String) -> Value { - if input.ends_with("}\r\n") { - let messages = input.as_str().split("message").skip(1); - let regex = Regex::new(r"timeline:(?P.*?)\r\n\$\d+\r\n(?P.*?)\r\n") - .expect("Hard-codded"); - for message in messages { - let _timeline = regex.captures(message).expect("Hard-coded timeline regex") - ["timeline"] - .to_string(); - - let redis_msg: Value = serde_json::from_str( - ®ex.captures(message).expect("Hard-coded value regex")["value"], - ) - .expect("Valid json"); - - return redis_msg; - } - unreachable!() - } else { - unreachable!() - } - } -} - -/// Parse with a simplified inline iterator. -/// -/// Essentially shows best-case performance for producing a serde_json::Value. -mod parse_inline { - use serde_json::Value; - pub fn to_json_value(input: String) -> Value { - fn print_next_str(mut end: usize, input: &str) -> (usize, String) { - let mut start = end + 3; - end = start + 1; - - let mut iter = input.chars(); - iter.nth(start); - - while iter.next().unwrap().is_digit(10) { - end += 1; - } - let length = &input[start..end].parse::().unwrap(); - start = end + 2; - end = start + length; - - let string = &input[start..end]; - (end, string.to_string()) - } - - if input.ends_with("}\r\n") { - let end = 2; - let (end, _) = print_next_str(end, &input); - let (end, _timeline) = print_next_str(end, &input); - let (_, msg) = print_next_str(end, &input); - let redis_msg: Value = serde_json::from_str(&msg).unwrap(); - redis_msg - } else { - unreachable!() - } - } -} - /// Parse using Flodgatt's current functions mod flodgatt_parse_event { - use flodgatt::{messages::Event, redis_to_client_stream::receiver::MessageQueues}; use flodgatt::{ - parse_client_request::subscription::Timeline, - redis_to_client_stream::{receiver::MsgQueue, redis::redis_stream}, + messages::Event, + parse_client_request::Timeline, + redis_to_client_stream::{process_messages, 
MessageQueues, MsgQueue}, }; use lru::LruCache; use std::collections::HashMap; use uuid::Uuid; /// One-time setup, not included in testing time. - pub fn setup() -> MessageQueues { + pub fn setup() -> (LruCache, MessageQueues, Uuid, Timeline) { + let mut cache: LruCache = LruCache::new(1000); let mut queues_map = HashMap::new(); let id = Uuid::default(); - let timeline = Timeline::from_redis_raw_timeline("1", None); + let timeline = + Timeline::from_redis_raw_timeline("timeline:1", &mut cache, &None).expect("In test"); queues_map.insert(id, MsgQueue::new(timeline)); let queues = MessageQueues(queues_map); - queues + (cache, queues, id, timeline) } pub fn to_event_struct( @@ -101,201 +33,7 @@ mod flodgatt_parse_event { id: Uuid, timeline: Timeline, ) -> Event { - redis_stream::process_messages(input, &mut None, &mut cache, &mut queues).unwrap(); - queues - .oldest_msg_in_target_queue(id, timeline) - .expect("In test") - } -} - -/// Parse using modified a modified version of Flodgatt's current function. -/// -/// This version is modified to return a serde_json::Value instead of an Event to shows -/// the performance we would see if we used serde's built-in method for handling weakly -/// typed JSON instead of our own strongly typed struct. 
-mod flodgatt_parse_value { - use flodgatt::{log_fatal, parse_client_request::subscription::Timeline}; - use lru::LruCache; - use serde_json::Value; - use std::{ - collections::{HashMap, VecDeque}, - time::Instant, - }; - use uuid::Uuid; - #[derive(Debug)] - pub struct RedisMsg<'a> { - pub raw: &'a str, - pub cursor: usize, - pub prefix_len: usize, - } - - impl<'a> RedisMsg<'a> { - pub fn from_raw(raw: &'a str, prefix_len: usize) -> Self { - Self { - raw, - cursor: "*3\r\n".len(), //length of intro header - prefix_len, - } - } - - /// Move the cursor from the beginning of a number through its end and return the number - pub fn process_number(&mut self) -> usize { - let (mut selected_number, selection_start) = (0, self.cursor); - while let Ok(number) = self.raw[selection_start..=self.cursor].parse::() { - self.cursor += 1; - selected_number = number; - } - selected_number - } - - /// In a pubsub reply from Redis, an item can be either the name of the subscribed channel - /// or the msg payload. 
Either way, it follows the same format: - /// `$[LENGTH_OF_ITEM_BODY]\r\n[ITEM_BODY]\r\n` - pub fn next_field(&mut self) -> String { - self.cursor += "$".len(); - - let item_len = self.process_number(); - self.cursor += "\r\n".len(); - let item_start_position = self.cursor; - self.cursor += item_len; - let item = self.raw[item_start_position..self.cursor].to_string(); - self.cursor += "\r\n".len(); - item - } - - pub fn extract_raw_timeline_and_message(&mut self) -> (String, Value) { - let timeline = &self.next_field()[self.prefix_len..]; - let msg_txt = self.next_field(); - let msg_value: Value = serde_json::from_str(&msg_txt) - .unwrap_or_else(|_| log_fatal!("Invalid JSON from Redis: {:?}", &msg_txt)); - (timeline.to_string(), msg_value) - } - } - - pub struct MsgQueue { - pub timeline: Timeline, - pub messages: VecDeque, - _last_polled_at: Instant, - } - - pub struct MessageQueues(HashMap); - impl std::ops::Deref for MessageQueues { - type Target = HashMap; - fn deref(&self) -> &Self::Target { - &self.0 - } - } - - impl std::ops::DerefMut for MessageQueues { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } - } - - impl MessageQueues { - pub fn oldest_msg_in_target_queue( - &mut self, - id: Uuid, - timeline: Timeline, - ) -> Option { - self.entry(id) - .or_insert_with(|| MsgQueue::new(timeline)) - .messages - .pop_front() - } - } - - impl MsgQueue { - pub fn new(timeline: Timeline) -> Self { - MsgQueue { - messages: VecDeque::new(), - _last_polled_at: Instant::now(), - timeline, - } - } - } - - pub fn process_msg( - raw_utf: String, - namespace: &Option, - hashtag_id_cache: &mut LruCache, - queues: &mut MessageQueues, - ) { - // Only act if we have a full message (end on a msg boundary) - if !raw_utf.ends_with("}\r\n") { - return; - }; - let prefix_to_skip = match namespace { - Some(namespace) => format!("{}:timeline:", namespace), - None => "timeline:".to_string(), - }; - - let mut msg = RedisMsg::from_raw(&raw_utf, prefix_to_skip.len()); - - 
while !msg.raw.is_empty() { - let command = msg.next_field(); - match command.as_str() { - "message" => { - let (raw_timeline, msg_value) = msg.extract_raw_timeline_and_message(); - let hashtag = hashtag_from_timeline(&raw_timeline, hashtag_id_cache); - let timeline = Timeline::from_redis_raw_timeline(&raw_timeline, hashtag); - for msg_queue in queues.values_mut() { - if msg_queue.timeline == timeline { - msg_queue.messages.push_back(msg_value.clone()); - } - } - } - - "subscribe" | "unsubscribe" => { - // No msg, so ignore & advance cursor to end - let _channel = msg.next_field(); - msg.cursor += ":".len(); - let _active_subscriptions = msg.process_number(); - msg.cursor += "\r\n".len(); - } - cmd => panic!("Invariant violation: {} is unexpected Redis output", cmd), - }; - msg = RedisMsg::from_raw(&msg.raw[msg.cursor..], msg.prefix_len); - } - } - - fn hashtag_from_timeline( - raw_timeline: &str, - hashtag_id_cache: &mut LruCache, - ) -> Option { - if raw_timeline.starts_with("hashtag") { - let tag_name = raw_timeline - .split(':') - .nth(1) - .unwrap_or_else(|| log_fatal!("No hashtag found in `{}`", raw_timeline)) - .to_string(); - - let tag_id = *hashtag_id_cache - .get(&tag_name) - .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name)); - Some(tag_id) - } else { - None - } - } - pub fn setup() -> (LruCache, MessageQueues, Uuid, Timeline) { - let cache: LruCache = LruCache::new(1000); - let mut queues_map = HashMap::new(); - let id = Uuid::default(); - let timeline = Timeline::from_redis_raw_timeline("1", None); - queues_map.insert(id, MsgQueue::new(timeline)); - let queues = MessageQueues(queues_map); - (cache, queues, id, timeline) - } - - pub fn to_json_value( - input: String, - mut cache: &mut LruCache, - mut queues: &mut MessageQueues, - id: Uuid, - timeline: Timeline, - ) -> Value { - process_msg(input, &None, &mut cache, &mut queues); + process_messages(&input, &mut cache, &mut None, &mut queues); queues .oldest_msg_in_target_queue(id, 
timeline) .expect("In test") @@ -303,28 +41,10 @@ mod flodgatt_parse_value { } fn criterion_benchmark(c: &mut Criterion) { - let input = ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS.to_string(); //INPUT.to_string(); + let input = ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS.to_string(); let mut group = c.benchmark_group("Parse redis RESP array"); - // group.bench_function("parse to Value with a regex", |b| { - // b.iter(|| regex_parse::to_json_value(black_box(input.clone()))) - // }); - group.bench_function("parse to Value inline", |b| { - b.iter(|| parse_inline::to_json_value(black_box(input.clone()))) - }); - let (mut cache, mut queues, id, timeline) = flodgatt_parse_value::setup(); - group.bench_function("parse to Value using Flodgatt functions", |b| { - b.iter(|| { - black_box(flodgatt_parse_value::to_json_value( - black_box(input.clone()), - black_box(&mut cache), - black_box(&mut queues), - black_box(id), - black_box(timeline), - )) - }) - }); - let mut queues = flodgatt_parse_event::setup(); + let (mut cache, mut queues, id, timeline) = flodgatt_parse_event::setup(); group.bench_function("parse to Event using Flodgatt functions", |b| { b.iter(|| { black_box(flodgatt_parse_event::to_event_struct( @@ -340,3 +60,5 @@ fn criterion_benchmark(c: &mut Criterion) { criterion_group!(benches, criterion_benchmark); criterion_main!(benches); + +const ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS: &str = 
"*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:1\r\n$3790\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102775370117886890\",\"created_at\":\"2019-09-11T18:42:19.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"unlisted\",\"language\":\"en\",\"uri\":\"https://mastodon.host/users/federationbot/statuses/102775346916917099\",\"url\":\"https://mastodon.host/@federationbot/102775346916917099\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"

Trending tags:
#neverforget
#4styles
#newpipe
#uber
#mercredifiction

\",\"reblog\":null,\"account\":{\"id\":\"78\",\"username\":\"federationbot\",\"acct\":\"federationbot@mastodon.host\",\"display_name\":\"Federation Bot\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-09-10T15:04:25.559Z\",\"note\":\"

Hello, I am mastodon.host official semi bot.

Follow me if you want to have some updates on the view of the fediverse from here ( I only post unlisted ).

I also randomly boost one of my followers toot every hour !

If you don\'t feel confortable with me following you, tell me: unfollow and I\'ll do it :)

If you want me to follow you, just tell me follow !

If you want automatic follow for new users on your instance and you are an instance admin, contact me !

Other commands are private :)

\",\"url\":\"https://mastodon.host/@federationbot\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":16636,\"following_count\":179532,\"statuses_count\":50554,\"emojis\":[],\"fields\":[{\"name\":\"More stats\",\"value\":\"https://mastodon.host/stats.html\",\"verified_at\":null},{\"name\":\"More infos\",\"value\":\"https://mastodon.host/about/more\",\"verified_at\":null},{\"name\":\"Owner/Friend\",\"value\":\"@gled\",\"verified_at\":null}]},\"media_attachments\":[],\"mentions\":[],\"tags\":[{\"name\":\"4styles\",\"url\":\"https://instance.codesections.com/tags/4styles\"},{\"name\":\"neverforget\",\"url\":\"https://instance.codesections.com/tags/neverforget\"},{\"name\":\"mercredifiction\",\"url\":\"https://instance.codesections.com/tags/mercredifiction\"},{\"name\":\"uber\",\"url\":\"https://instance.codesections.com/tags/uber\"},{\"name\":\"newpipe\",\"url\":\"https://instance.codesections.com/tags/newpipe\"}],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1568227693541}\r\n"; diff --git a/src/config/deployment_cfg_types.rs b/src/config/deployment_cfg_types.rs index b09de80..97d018b 100644 --- a/src/config/deployment_cfg_types.rs +++ b/src/config/deployment_cfg_types.rs @@ -11,14 +11,14 @@ from_env_var!( /// The current environment, which controls what file to read other ENV vars from let name = Env; let default: EnvInner = EnvInner::Development; - let (env_var, allowed_values) = ("RUST_ENV", format!("one of: {:?}", EnvInner::variants())); + let (env_var, allowed_values) = ("RUST_ENV", &format!("one of: {:?}", EnvInner::variants())); let from_str = |s| 
EnvInner::from_str(s).ok(); ); from_env_var!( /// The address to run Flodgatt on let name = FlodgattAddr; let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded")); - let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)".to_string()); + let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)"); let from_str = |s| match s { "localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), _ => s.parse().ok(), @@ -28,35 +28,35 @@ from_env_var!( /// How verbosely Flodgatt should log messages let name = LogLevel; let default: LogLevelInner = LogLevelInner::Warn; - let (env_var, allowed_values) = ("RUST_LOG", format!("one of: {:?}", LogLevelInner::variants())); + let (env_var, allowed_values) = ("RUST_LOG", &format!("one of: {:?}", LogLevelInner::variants())); let from_str = |s| LogLevelInner::from_str(s).ok(); ); from_env_var!( /// A Unix Socket to use in place of a local address let name = Socket; let default: Option = None; - let (env_var, allowed_values) = ("SOCKET", "any string".to_string()); + let (env_var, allowed_values) = ("SOCKET", "any string"); let from_str = |s| Some(Some(s.to_string())); ); from_env_var!( /// The time between replies sent via WebSocket let name = WsInterval; let default: Duration = Duration::from_millis(100); - let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string()); + let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds"); let from_str = |s| s.parse().map(Duration::from_millis).ok(); ); from_env_var!( /// The time between replies sent via Server Sent Events let name = SseInterval; let default: Duration = Duration::from_millis(100); - let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string()); + let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds"); let from_str = |s| s.parse().map(Duration::from_millis).ok(); ); from_env_var!( /// The port to run Flodgatt on let name = Port; let default: u16 = 4000; - let 
(env_var, allowed_values) = ("PORT", "a number between 0 and 65535".to_string()); + let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535"); let from_str = |s| s.parse().ok(); ); from_env_var!( @@ -66,7 +66,7 @@ from_env_var!( /// (including otherwise public timelines). let name = WhitelistMode; let default: bool = false; - let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false".to_string()); + let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false"); let from_str = |s| s.parse().ok(); ); /// Permissions for Cross Origin Resource Sharing (CORS) diff --git a/src/config/environmental_variables.rs b/src/config/environmental_variables.rs new file mode 100644 index 0000000..a790c0b --- /dev/null +++ b/src/config/environmental_variables.rs @@ -0,0 +1,137 @@ +use std::{collections::HashMap, fmt}; + +pub struct EnvVar(pub HashMap); +impl std::ops::Deref for EnvVar { + type Target = HashMap; + fn deref(&self) -> &HashMap { + &self.0 + } +} + +impl Clone for EnvVar { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} +impl EnvVar { + pub fn new(vars: HashMap) -> Self { + Self(vars) + } + + pub fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option) { + if let Some(value) = maybe_value { + self.0.insert(key.to_string(), value.to_string()); + } + } + + pub fn err(env_var: &str, supplied_value: &str, allowed_values: &str) -> ! { + log::error!( + r"{var} is set to `{value}`, which is invalid. 
+ {var} must be {allowed_vals}.", + var = env_var, + value = supplied_value, + allowed_vals = allowed_values + ); + std::process::exit(1); + } +} +impl fmt::Display for EnvVar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut result = String::new(); + for env_var in [ + "NODE_ENV", + "RUST_LOG", + "BIND", + "PORT", + "SOCKET", + "SSE_FREQ", + "WS_FREQ", + "DATABASE_URL", + "DB_USER", + "USER", + "DB_PORT", + "DB_HOST", + "DB_PASS", + "DB_NAME", + "DB_SSLMODE", + "REDIS_HOST", + "REDIS_USER", + "REDIS_PORT", + "REDIS_PASSWORD", + "REDIS_USER", + "REDIS_DB", + ] + .iter() + { + if let Some(value) = self.get(&env_var.to_string()) { + result = format!("{}\n {}: {}", result, env_var, value) + } + } + write!(f, "{}", result) + } +} +#[macro_export] +macro_rules! maybe_update { + ($name:ident; $item: tt:$type:ty) => ( + pub fn $name(self, item: Option<$type>) -> Self { + match item { + Some($item) => Self{ $item, ..self }, + None => Self { ..self } + } + }); + ($name:ident; Some($item: tt: $type:ty)) => ( + fn $name(self, item: Option<$type>) -> Self{ + match item { + Some($item) => Self{ $item: Some($item), ..self }, + None => Self { ..self } + } + })} +#[macro_export] +macro_rules! 
from_env_var { + ($(#[$outer:meta])* + let name = $name:ident; + let default: $type:ty = $inner:expr; + let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr); + let from_str = |$arg:ident| $body:expr; + ) => { + pub struct $name(pub $type); + impl std::fmt::Debug for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + impl std::ops::Deref for $name { + type Target = $type; + fn deref(&self) -> &$type { + &self.0 + } + } + impl std::default::Default for $name { + fn default() -> Self { + $name($inner) + } + } + impl $name { + fn inner_from_str($arg: &str) -> Option<$type> { + $body + } + pub fn maybe_update(self, var: Option<&String>) -> Self { + match var { + Some(empty_string) if empty_string.is_empty() => Self::default(), + Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| { + crate::config::EnvVar::err($env_var, value, $allowed_values) + })), + None => self, + } + + // if let Some(value) = var { + // Self(Self::inner_from_str(value).unwrap_or_else(|| { + // crate::err::env_var_fatal($env_var, value, $allowed_values) + // })) + // } else { + // self + // } + } + } + }; +} diff --git a/src/config/mod.rs b/src/config/mod.rs index b2b314d..6d7d8bc 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -4,135 +4,7 @@ mod postgres_cfg; mod postgres_cfg_types; mod redis_cfg; mod redis_cfg_types; -pub use self::{ - deployment_cfg::DeploymentConfig, - postgres_cfg::PostgresConfig, - redis_cfg::RedisConfig, - redis_cfg_types::{RedisInterval, RedisNamespace}, -}; -use std::{collections::HashMap, fmt}; +mod environmental_variables; -pub struct EnvVar(pub HashMap); -impl std::ops::Deref for EnvVar { - type Target = HashMap; - fn deref(&self) -> &HashMap { - &self.0 - } -} +pub use {deployment_cfg::DeploymentConfig, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig, environmental_variables::EnvVar}; -impl Clone for EnvVar { - fn clone(&self) -> Self { - Self(self.0.clone()) - 
} -} -impl EnvVar { - pub fn new(vars: HashMap) -> Self { - Self(vars) - } - - fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option) { - if let Some(value) = maybe_value { - self.0.insert(key.to_string(), value.to_string()); - } - } -} -impl fmt::Display for EnvVar { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut result = String::new(); - for env_var in [ - "NODE_ENV", - "RUST_LOG", - "BIND", - "PORT", - "SOCKET", - "SSE_FREQ", - "WS_FREQ", - "DATABASE_URL", - "DB_USER", - "USER", - "DB_PORT", - "DB_HOST", - "DB_PASS", - "DB_NAME", - "DB_SSLMODE", - "REDIS_HOST", - "REDIS_USER", - "REDIS_PORT", - "REDIS_PASSWORD", - "REDIS_USER", - "REDIS_DB", - ] - .iter() - { - if let Some(value) = self.get(&env_var.to_string()) { - result = format!("{}\n {}: {}", result, env_var, value) - } - } - write!(f, "{}", result) - } -} -#[macro_export] -macro_rules! maybe_update { - ($name:ident; $item: tt:$type:ty) => ( - pub fn $name(self, item: Option<$type>) -> Self { - match item { - Some($item) => Self{ $item, ..self }, - None => Self { ..self } - } - }); - ($name:ident; Some($item: tt: $type:ty)) => ( - fn $name(self, item: Option<$type>) -> Self{ - match item { - Some($item) => Self{ $item: Some($item), ..self }, - None => Self { ..self } - } - })} -#[macro_export] -macro_rules! 
from_env_var { - ($(#[$outer:meta])* - let name = $name:ident; - let default: $type:ty = $inner:expr; - let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr); - let from_str = |$arg:ident| $body:expr; - ) => { - pub struct $name(pub $type); - impl std::fmt::Debug for $name { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.0) - } - } - impl std::ops::Deref for $name { - type Target = $type; - fn deref(&self) -> &$type { - &self.0 - } - } - impl std::default::Default for $name { - fn default() -> Self { - $name($inner) - } - } - impl $name { - fn inner_from_str($arg: &str) -> Option<$type> { - $body - } - pub fn maybe_update(self, var: Option<&String>) -> Self { - match var { - Some(empty_string) if empty_string.is_empty() => Self::default(), - Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| { - crate::err::env_var_fatal($env_var, value, $allowed_values) - })), - None => self, - } - - // if let Some(value) = var { - // Self(Self::inner_from_str(value).unwrap_or_else(|| { - // crate::err::env_var_fatal($env_var, value, $allowed_values) - // })) - // } else { - // self - // } - } - } - }; -} diff --git a/src/config/postgres_cfg_types.rs b/src/config/postgres_cfg_types.rs index 066f2fd..7551d1b 100644 --- a/src/config/postgres_cfg_types.rs +++ b/src/config/postgres_cfg_types.rs @@ -6,7 +6,7 @@ from_env_var!( /// The user to use for Postgres let name = PgUser; let default: String = "postgres".to_string(); - let (env_var, allowed_values) = ("DB_USER", "any string".to_string()); + let (env_var, allowed_values) = ("DB_USER", "any string"); let from_str = |s| Some(s.to_string()); ); @@ -14,7 +14,7 @@ from_env_var!( /// The host address where Postgres is running) let name = PgHost; let default: String = "localhost".to_string(); - let (env_var, allowed_values) = ("DB_HOST", "any string".to_string()); + let (env_var, allowed_values) = ("DB_HOST", "any string"); let from_str = |s| 
Some(s.to_string()); ); @@ -22,7 +22,7 @@ from_env_var!( /// The password to use with Postgress let name = PgPass; let default: Option = None; - let (env_var, allowed_values) = ("DB_PASS", "any string".to_string()); + let (env_var, allowed_values) = ("DB_PASS", "any string"); let from_str = |s| Some(Some(s.to_string())); ); @@ -30,7 +30,7 @@ from_env_var!( /// The Postgres database to use let name = PgDatabase; let default: String = "mastodon_development".to_string(); - let (env_var, allowed_values) = ("DB_NAME", "any string".to_string()); + let (env_var, allowed_values) = ("DB_NAME", "any string"); let from_str = |s| Some(s.to_string()); ); @@ -38,14 +38,14 @@ from_env_var!( /// The port Postgres is running on let name = PgPort; let default: u16 = 5432; - let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535".to_string()); + let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535"); let from_str = |s| s.parse().ok(); ); from_env_var!( let name = PgSslMode; let default: PgSslInner = PgSslInner::Prefer; - let (env_var, allowed_values) = ("DB_SSLMODE", format!("one of: {:?}", PgSslInner::variants())); + let (env_var, allowed_values) = ("DB_SSLMODE", &format!("one of: {:?}", PgSslInner::variants())); let from_str = |s| PgSslInner::from_str(s).ok(); ); diff --git a/src/config/redis_cfg_types.rs b/src/config/redis_cfg_types.rs index d3dbbf0..8d43c76 100644 --- a/src/config/redis_cfg_types.rs +++ b/src/config/redis_cfg_types.rs @@ -7,48 +7,48 @@ from_env_var!( /// The host address where Redis is running let name = RedisHost; let default: String = "127.0.0.1".to_string(); - let (env_var, allowed_values) = ("REDIS_HOST", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_HOST", "any string"); let from_str = |s| Some(s.to_string()); ); from_env_var!( /// The port Redis is running on let name = RedisPort; let default: u16 = 6379; - let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 
65535".to_string()); + let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535"); let from_str = |s| s.parse().ok(); ); from_env_var!( /// How frequently to poll Redis let name = RedisInterval; let default: Duration = Duration::from_millis(100); - let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds".to_string()); + let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds"); let from_str = |s| s.parse().map(Duration::from_millis).ok(); ); from_env_var!( /// The password to use for Redis let name = RedisPass; let default: Option = None; - let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string"); let from_str = |s| Some(Some(s.to_string())); ); from_env_var!( /// An optional Redis Namespace let name = RedisNamespace; let default: Option = None; - let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string"); let from_str = |s| Some(Some(s.to_string())); ); from_env_var!( /// A user for Redis (not supported) let name = RedisUser; let default: Option = None; - let (env_var, allowed_values) = ("REDIS_USER", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_USER", "any string"); let from_str = |s| Some(Some(s.to_string())); ); from_env_var!( /// The database to use with Redis (no current effect for PubSub connections) let name = RedisDb; let default: Option = None; - let (env_var, allowed_values) = ("REDIS_DB", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_DB", "any string"); let from_str = |s| Some(Some(s.to_string())); ); diff --git a/src/err.rs b/src/err.rs index 6ffbaa0..2863c5e 100644 --- a/src/err.rs +++ b/src/err.rs @@ -1,4 +1,3 @@ -use serde_derive::Serialize; use std::fmt::Display; pub fn die_with_msg(msg: impl Display) -> ! { @@ -14,67 +13,20 @@ macro_rules! 
log_fatal { };}; } -pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! { - eprintln!( - r"FATAL ERROR: {var} is set to `{value}`, which is invalid. - {var} must be {allowed_vals}.", - var = env_var, - value = supplied_value, - allowed_vals = allowed_values - ); - std::process::exit(1); +#[derive(Debug)] +pub enum RedisParseErr { + Incomplete, + Unrecoverable, } -#[macro_export] -macro_rules! dbg_and_die { - ($msg:expr) => { - let message = format!("FATAL ERROR: {}", $msg); - dbg!(message); - std::process::exit(1); - }; -} -pub fn unwrap_or_die(s: Option, msg: &str) -> T { - s.unwrap_or_else(|| { - eprintln!("FATAL ERROR: {}", msg); - std::process::exit(1) - }) +#[derive(Debug)] +pub enum TimelineErr { + RedisNamespaceMismatch, + InvalidInput, } -#[derive(Serialize)] -pub struct ErrorMessage { - error: String, -} -impl ErrorMessage { - fn new(msg: impl std::fmt::Display) -> Self { - Self { - error: msg.to_string(), - } - } -} - -/// Recover from Errors by sending appropriate Warp::Rejections -pub fn handle_errors( - rejection: warp::reject::Rejection, -) -> Result { - let err_txt = match rejection.cause() { - Some(text) if text.to_string() == "Missing request header 'authorization'" => { - "Error: Missing access token".to_string() - } - Some(text) => text.to_string(), - None => "Error: Nonexistant endpoint".to_string(), - }; - let json = warp::reply::json(&ErrorMessage::new(err_txt)); - - Ok(warp::reply::with_status( - json, - warp::http::StatusCode::UNAUTHORIZED, - )) -} - -pub struct CustomError {} - -impl CustomError { - pub fn unauthorized_list() -> warp::reject::Rejection { - warp::reject::custom("Error: Access to list not authorized") +impl From for TimelineErr { + fn from(_error: std::num::ParseIntError) -> Self { + Self::InvalidInput } } diff --git a/src/main.rs b/src/main.rs index b964926..2d2aaa4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,87 +1,76 @@ use flodgatt::{ - config, err, - parse_client_request::{sse, 
subscription, ws}, - redis_to_client_stream::{self, ClientAgent}, + config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig}, + parse_client_request::{PgPool, Subscription}, + redis_to_client_stream::{ClientAgent, EventStream}, }; use std::{collections::HashMap, env, fs, net, os::unix::fs::PermissionsExt}; use tokio::net::UnixListener; use warp::{path, ws::Ws2, Filter}; fn main() { - dotenv::from_filename( - match env::var("ENV").ok().as_ref().map(String::as_str) { + dotenv::from_filename(match env::var("ENV").ok().as_ref().map(String::as_str) { Some("production") => ".env.production", Some("development") | None => ".env", - Some(_) => err::die_with_msg("Unknown ENV variable specified.\n Valid options are: `production` or `development`."), - }).ok(); + Some(unsupported) => EnvVar::err("ENV", unsupported, "`production` or `development`"), + }) + .ok(); let env_vars_map: HashMap<_, _> = dotenv::vars().collect(); - let env_vars = config::EnvVar::new(env_vars_map); + let env_vars = EnvVar::new(env_vars_map); pretty_env_logger::init(); log::info!( "Flodgatt recognized the following environmental variables:{}", env_vars.clone() ); - let redis_cfg = config::RedisConfig::from_env(env_vars.clone()); - let cfg = config::DeploymentConfig::from_env(env_vars.clone()); + let redis_cfg = RedisConfig::from_env(env_vars.clone()); + let cfg = DeploymentConfig::from_env(env_vars.clone()); - let postgres_cfg = config::PostgresConfig::from_env(env_vars.clone()); - let pg_pool = subscription::PgPool::new(postgres_cfg); + let postgres_cfg = PostgresConfig::from_env(env_vars.clone()); + let pg_pool = PgPool::new(postgres_cfg); - let client_agent_sse = ClientAgent::blank(redis_cfg, pg_pool.clone()); + let client_agent_sse = ClientAgent::blank(redis_cfg); let client_agent_ws = client_agent_sse.clone_with_shared_receiver(); log::info!("Streaming server initialized and ready to accept connections"); // Server Sent Events - let (sse_update_interval, whitelist_mode) = (*cfg.sse_interval, 
*cfg.whitelist_mode); - let sse_routes = sse::extract_user_or_reject(pg_pool.clone(), whitelist_mode) + let (sse_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode); + let sse_routes = Subscription::from_sse_query(pg_pool.clone(), whitelist_mode) .and(warp::sse()) .map( - move |subscription: subscription::Subscription, - sse_connection_to_client: warp::sse::Sse| { + move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); // Create a new ClientAgent let mut client_agent = client_agent_sse.clone_with_shared_receiver(); // Assign ClientAgent to generate stream of updates for the user/timeline pair client_agent.init_for_user(subscription); // send the updates through the SSE connection - redis_to_client_stream::send_updates_to_sse( - client_agent, - sse_connection_to_client, - sse_update_interval, - ) + EventStream::to_sse(client_agent, sse_connection_to_client, sse_interval) }, ) - .with(warp::reply::with::header("Connection", "keep-alive")) - .recover(err::handle_errors); + .with(warp::reply::with::header("Connection", "keep-alive")); // WebSocket let (ws_update_interval, whitelist_mode) = (*cfg.ws_interval, *cfg.whitelist_mode); - let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone(), whitelist_mode) + let websocket_routes = Subscription::from_ws_request(pg_pool.clone(), whitelist_mode) .and(warp::ws::ws2()) - .map( - move |subscription: subscription::Subscription, token: Option, ws: Ws2| { - log::info!("Incoming websocket request for {:?}", subscription.timeline); - // Create a new ClientAgent - let mut client_agent = client_agent_ws.clone_with_shared_receiver(); - // Assign that agent to generate a stream of updates for the user/timeline pair - client_agent.init_for_user(subscription); - // send the updates through the WS connection (along with the User's access_token - // which is sent for security) + .map(move |subscription: 
Subscription, ws: Ws2| { + log::info!("Incoming websocket request for {:?}", subscription.timeline); - ( - ws.on_upgrade(move |socket| { - redis_to_client_stream::send_updates_to_ws( - socket, - client_agent, - ws_update_interval, - ) - }), - token.unwrap_or_else(String::new), - ) - }, - ) + let token = subscription.access_token.clone(); + // Create a new ClientAgent + let mut client_agent = client_agent_ws.clone_with_shared_receiver(); + // Assign that agent to generate a stream of updates for the user/timeline pair + client_agent.init_for_user(subscription); + // send the updates through the WS connection (along with the User's access_token + // which is sent for security) + ( + ws.on_upgrade(move |socket| { + EventStream::to_ws(socket, client_agent, ws_update_interval) + }), + token.unwrap_or_else(String::new), + ) + }) .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); let cors = warp::cors() @@ -98,7 +87,28 @@ fn main() { fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).unwrap(); - warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run_incoming(incoming); + warp::serve( + health.or(websocket_routes.or(sse_routes).with(cors).recover( + |rejection: warp::reject::Rejection| { + let err_txt = match rejection.cause() { + Some(text) + if text.to_string() == "Missing request header 'authorization'" => + { + "Error: Missing access token".to_string() + } + Some(text) => text.to_string(), + None => "Error: Nonexistant endpoint".to_string(), + }; + let json = warp::reply::json(&err_txt); + + Ok(warp::reply::with_status( + json, + warp::http::StatusCode::UNAUTHORIZED, + )) + }, + )), + ) + .run_incoming(incoming); } else { let server_addr = net::SocketAddr::new(*cfg.address, cfg.port.0); warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run(server_addr); diff --git a/src/messages.rs b/src/messages.rs index 5facc8e..a31cade 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -429,27 
+429,24 @@ impl Status { mod test { use super::*; use crate::{ - parse_client_request::subscription::{Content::*, Reach::*, Stream::*, Timeline}, - redis_to_client_stream::{ - receiver::{MessageQueues, MsgQueue}, - redis::{ - redis_msg::{ParseErr, RedisMsg}, - redis_stream, - }, - }, + err::RedisParseErr, + parse_client_request::{Content::*, Reach::*, Stream::*, Timeline}, + redis_to_client_stream::{MessageQueues, MsgQueue, RedisMsg}, }; use lru::LruCache; use std::collections::HashMap; use uuid::Uuid; - type Err = ParseErr; + type Err = RedisParseErr; /// Set up state shared between multiple tests of Redis parsing pub fn shared_setup() -> (LruCache, MessageQueues, Uuid, Timeline) { - let cache: LruCache = LruCache::new(1000); + let mut cache: LruCache = LruCache::new(1000); let mut queues_map = HashMap::new(); - let id = Uuid::default(); + let id = dbg!(Uuid::default()); - let timeline = Timeline::from_redis_raw_timeline("4", None); + let timeline = dbg!( + Timeline::from_redis_raw_timeline("timeline:4", &mut cache, &None).expect("In test") + ); queues_map.insert(id, MsgQueue::new(timeline)); let queues = MessageQueues(queues_map); (cache, queues, id, timeline) @@ -460,8 +457,8 @@ mod test { let input ="*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:4\r\n$1386\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102866835379605039\",\"created_at\":\"2019-09-27T22:29:02.590Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"en\",\"uri\":\"http://localhost:3000/users/admin/statuses/102866835379605039\",\"url\":\"http://localhost:3000/@admin/102866835379605039\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"

@susan hi

\",\"reblog\":null,\"application\":{\"name\":\"Web\",\"website\":null},\"account\":{\"id\":\"1\",\"username\":\"admin\",\"acct\":\"admin\",\"display_name\":\"\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-07-04T00:21:05.890Z\",\"note\":\"

\",\"url\":\"http://localhost:3000/@admin\",\"avatar\":\"http://localhost:3000/avatars/original/missing.png\",\"avatar_static\":\"http://localhost:3000/avatars/original/missing.png\",\"header\":\"http://localhost:3000/headers/original/missing.png\",\"header_static\":\"http://localhost:3000/headers/original/missing.png\",\"followers_count\":3,\"following_count\":3,\"statuses_count\":192,\"emojis\":[],\"fields\":[]},\"media_attachments\":[],\"mentions\":[{\"id\":\"4\",\"username\":\"susan\",\"url\":\"http://localhost:3000/@susan\",\"acct\":\"susan\"}],\"tags\":[],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1569623342825}\r\n"; let (mut cache, mut queues, id, timeline) = shared_setup(); - redis_stream::process_messages(input.to_string(), &mut None, &mut cache, &mut queues) - .map_err(|_| ParseErr::Unrecoverable)?; + crate::redis_to_client_stream::process_messages(input, &mut cache, &mut None, &mut queues); + let parsed_event = queues.oldest_msg_in_target_queue(id, timeline).unwrap(); let test_event = Event::Update{ payload: Status { id: "102866835379605039".to_string(), @@ -538,40 +535,40 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!(subscription_msg1, RedisMsg::SubscriptionMsg)); - let (subscription_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (subscription_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!(subscription_msg2, RedisMsg::SubscriptionMsg)); - let (subscription_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (subscription_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!(subscription_msg3, RedisMsg::SubscriptionMsg)); - let (subscription_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let 
(subscription_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!(subscription_msg4, RedisMsg::SubscriptionMsg)); - let (subscription_msg5, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (subscription_msg5, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!(subscription_msg5, RedisMsg::SubscriptionMsg)); - let (update_msg1, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (update_msg1, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!( update_msg1, RedisMsg::EventMsg(_, Event::Update { .. }) )); - let (update_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (update_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!( update_msg2, RedisMsg::EventMsg(_, Event::Update { .. }) )); - let (update_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (update_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!( update_msg3, RedisMsg::EventMsg(_, Event::Update { .. }) )); - let (update_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (update_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!( update_msg4, RedisMsg::EventMsg(_, Event::Update { .. }) @@ -588,7 +585,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( subscription_msg1, RedisMsg::EventMsg(Timeline(User(id), Federated, All), Event::Notification { .. 
}) if id == 55 @@ -603,9 +600,9 @@ mod test { fn parse_redis_input_delete() -> Result<(), Err> { let input = "*3\r\n$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$49\r\n{\"event\":\"delete\",\"payload\":\"103864778284581232\"}\r\n"; - let (mut cache, _, _, _) = shared_setup(); + let (mut cache, _, _, _) = dbg!(shared_setup()); - let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( subscription_msg1, RedisMsg::EventMsg( @@ -625,7 +622,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( subscription_msg1, RedisMsg::EventMsg(Timeline(User(id), Federated, All), Event::FiltersChanged) if id == 56 @@ -642,7 +639,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg( @@ -660,7 +657,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg( @@ -679,7 +676,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg( @@ -701,7 +698,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -721,7 +718,7 @@ mod 
test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -741,7 +738,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -761,7 +758,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -781,7 +778,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -797,7 +794,7 @@ mod test { }, ) )); - let (msg2, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (msg2, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; dbg!(&msg2); assert!(matches!( msg2, @@ -814,7 +811,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -840,7 +837,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -863,7 +860,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -886,7 
+883,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -904,7 +901,7 @@ mod test { fn parse_redis_input_from_live_data_1() -> Result<(), Err> { let input = "*3\r\n$7\r\nmessage\r\n$15\r\ntimeline:public\r\n$2799\r\n{\"event\":\"update\",\"payload\":{\"id\":\"103880088450458596\",\"created_at\":\"2020-03-24T21:12:37.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"es\",\"uri\":\"https://mastodon.social/users/durru/statuses/103880088436492032\",\"url\":\"https://mastodon.social/@durru/103880088436492032\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"content\":\"

¡No puedes salir, loca!

\",\"reblog\":null,\"account\":{\"id\":\"2271\",\"username\":\"durru\",\"acct\":\"durru@mastodon.social\",\"display_name\":\"Cloaca Maxima\",\"locked\":false,\"bot\":false,\"discoverable\":true,\"group\":false,\"created_at\":\"2020-03-24T21:27:31.669Z\",\"note\":\"

Todo pasa, antes o después, por la Cloaca, diría Vitruvio.
También compongo palíndromos.

\",\"url\":\"https://mastodon.social/@durru\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/002/271/original/d7675a6ff9d9baa7.jpeg?1585085250\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/002/271/original/d7675a6ff9d9baa7.jpeg?1585085250\",\"header\":\"https://instance.codesections.com/system/accounts/headers/000/002/271/original/e3f0a1989b0d8efc.jpeg?1585085250\",\"header_static\":\"https://instance.codesections.com/system/accounts/headers/000/002/271/original/e3f0a1989b0d8efc.jpeg?1585085250\",\"followers_count\":222,\"following_count\":81,\"statuses_count\":5443,\"last_status_at\":\"2020-03-24\",\"emojis\":[],\"fields\":[{\"name\":\"Mis fotos\",\"value\":\"https://pixelfed.de/durru\",\"verified_at\":null},{\"name\":\"diaspora*\",\"value\":\"https://joindiaspora.com/people/75fec0e05114013484870242ac110007\",\"verified_at\":null}]},\"media_attachments\":[{\"id\":\"2864\",\"type\":\"image\",\"url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/864/original/3988312d30936494.jpeg?1585085251\",\"preview_url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/864/small/3988312d30936494.jpeg?1585085251\",\"remote_url\":\"https://files.mastodon.social/media_attachments/files/026/669/690/original/d8171331f956cf38.jpg\",\"text_url\":null,\"meta\":{\"original\":{\"width\":1001,\"height\":662,\"size\":\"1001x662\",\"aspect\":1.512084592145015},\"small\":{\"width\":491,\"height\":325,\"size\":\"491x325\",\"aspect\":1.5107692307692309}},\"description\":null,\"blurhash\":\"UdLqhI4n4TIUIAt7t7ay~qIojtRj?bM{M{of\"}],\"mentions\":[],\"tags\":[],\"emojis\":[],\"card\":null,\"poll\":null}}\r\n"; let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg(Timeline(Public, Federated, All), 
Event::Update { .. }) @@ -917,7 +914,7 @@ mod test { fn parse_redis_input_from_live_data_2() -> Result<(), Err> { let input = "*3\r\n$7\r\nmessage\r\n$15\r\ntimeline:public\r\n$3888\r\n{\"event\":\"update\",\"payload\":{\"id\":\"103880373579328660\",\"created_at\":\"2020-03-24T22:25:05.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"en\",\"uri\":\"https://newsbots.eu/users/granma/statuses/103880373417385978\",\"url\":\"https://newsbots.eu/@granma/103880373417385978\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"content\":\"

A total of 11 measures have been established for the pre-epidemic stage of the battle against #Covid-19 in #Cuba
#CubaPorLaSalud
http://en.granma.cu/cuba/2020-03-23/public-health-measures-in-covid-19-pre-epidemic-stage 

\",\"reblog\":null,\"account\":{\"id\":\"717\",\"username\":\"granma\",\"acct\":\"granma@newsbots.eu\",\"display_name\":\"Granma (Unofficial)\",\"locked\":false,\"bot\":true,\"discoverable\":false,\"group\":false,\"created_at\":\"2020-03-13T11:08:08.420Z\",\"note\":\"

\",\"url\":\"https://newsbots.eu/@granma\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/717/original/4a1f9ed090fc36e9.jpeg?1584097687\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/717/original/4a1f9ed090fc36e9.jpeg?1584097687\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":57,\"following_count\":1,\"statuses_count\":742,\"last_status_at\":\"2020-03-24\",\"emojis\":[],\"fields\":[{\"name\":\"Source\",\"value\":\"https://twitter.com/Granma_English\",\"verified_at\":null},{\"name\":\"Operator\",\"value\":\"@felix\",\"verified_at\":null},{\"name\":\"Code\",\"value\":\"https://yerbamate.dev/nutomic/tootbot\",\"verified_at\":null}]},\"media_attachments\":[{\"id\":\"2881\",\"type\":\"image\",\"url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/881/original/a1e97908e84efbcd.jpeg?1585088707\",\"preview_url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/881/small/a1e97908e84efbcd.jpeg?1585088707\",\"remote_url\":\"https://newsbots.eu/system/media_attachments/files/000/176/298/original/f30a877d5035f4a6.jpeg\",\"text_url\":null,\"meta\":{\"original\":{\"width\":700,\"height\":795,\"size\":\"700x795\",\"aspect\":0.8805031446540881},\"small\":{\"width\":375,\"height\":426,\"size\":\"375x426\",\"aspect\":0.8802816901408451}},\"description\":null,\"blurhash\":\"UHCY?%sD%1t6}snOxuxu#7rrx]xu$*i_NFNF\"}],\"mentions\":[],\"tags\":[{\"name\":\"covid\",\"url\":\"https://instance.codesections.com/tags/covid\"},{\"name\":\"cuba\",\"url\":\"https://instance.codesections.com/tags/cuba\"},{\"name\":\"CubaPorLaSalud\",\"url\":\"https://instance.codesections.com/tags/CubaPorLaSalud\"}],\"emojis\":[],\"card\":null,\"poll\":null}}\r\n"; let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, 
&mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg(Timeline(Public, Federated, All), Event::Update { .. }) @@ -930,7 +927,7 @@ mod test { fn parse_redis_input_from_live_data_3() -> Result<(), Err> { let input = "*3\r\n$7\r\nmessage\r\n$15\r\ntimeline:public\r\n$4803\r\n{\"event\":\"update\",\"payload\":{\"id\":\"103880453908763088\",\"created_at\":\"2020-03-24T22:45:33.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"en\",\"uri\":\"https://mstdn.social/users/stux/statuses/103880453855603541\",\"url\":\"https://mstdn.social/@stux/103880453855603541\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"content\":\"

When they say lockdown. LOCKDOWN.

\",\"reblog\":null,\"account\":{\"id\":\"806\",\"username\":\"stux\",\"acct\":\"stux@mstdn.social\",\"display_name\":\"sтυx⚡\",\"locked\":false,\"bot\":false,\"discoverable\":true,\"group\":false,\"created_at\":\"2020-03-13T23:02:29.970Z\",\"note\":\"

Hi, Stux here! I am running the mstdn.social :mastodon: instance!

For questions and help or just for fun you can always send me a toot♥\u{fe0f}

Oh and no, I am not really a cat! Or am I?

\",\"url\":\"https://mstdn.social/@stux\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/806/original/dae8d9d01d57d7f8.gif?1584140547\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/806/static/dae8d9d01d57d7f8.png?1584140547\",\"header\":\"https://instance.codesections.com/system/accounts/headers/000/000/806/original/88c874d69f7d6989.gif?1584140548\",\"header_static\":\"https://instance.codesections.com/system/accounts/headers/000/000/806/static/88c874d69f7d6989.png?1584140548\",\"followers_count\":13954,\"following_count\":7600,\"statuses_count\":10207,\"last_status_at\":\"2020-03-24\",\"emojis\":[{\"shortcode\":\"mastodon\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/418/original/25ccc64333645735.png?1584140550\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/418/static/25ccc64333645735.png?1584140550\",\"visible_in_picker\":true},{\"shortcode\":\"patreon\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/419/original/3cc463d3dfc1e489.png?1584140550\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/419/static/3cc463d3dfc1e489.png?1584140550\",\"visible_in_picker\":true},{\"shortcode\":\"liberapay\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/420/original/893854353dfa9706.png?1584140551\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/420/static/893854353dfa9706.png?1584140551\",\"visible_in_picker\":true},{\"shortcode\":\"team_valor\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/958/original/96aae26b45292a12.png?1584910917\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/958/static/96aae26b45292a12.png?1584910917\",\"visible_in_picker\":true}],\"fields\":[{\"name\":\"Patreon 
:patreon:\",\"value\":\"https://www.patreon.com/mstdn\",\"verified_at\":null},{\"name\":\"LiberaPay :liberapay:\",\"value\":\"https://liberapay.com/mstdn\",\"verified_at\":null},{\"name\":\"Team :team_valor:\",\"value\":\"https://mstdn.social/team\",\"verified_at\":null},{\"name\":\"Support :mastodon:\",\"value\":\"https://mstdn.social/funding\",\"verified_at\":null}]},\"media_attachments\":[{\"id\":\"2886\",\"type\":\"video\",\"url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/886/original/22b3f98a5e8f86d8.mp4?1585090023\",\"preview_url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/886/small/22b3f98a5e8f86d8.png?1585090023\",\"remote_url\":\"https://cdn.mstdn.social/mstdn-social/media_attachments/files/003/338/384/original/c146f62ba86fe63e.mp4\",\"text_url\":null,\"meta\":{\"length\":\"0:00:27.03\",\"duration\":27.03,\"fps\":30,\"size\":\"272x480\",\"width\":272,\"height\":480,\"aspect\":0.5666666666666667,\"audio_encode\":\"aac (LC) (mp4a / 0x6134706D)\",\"audio_bitrate\":\"44100 Hz\",\"audio_channels\":\"stereo\",\"original\":{\"width\":272,\"height\":480,\"frame_rate\":\"30/1\",\"duration\":27.029,\"bitrate\":481885},\"small\":{\"width\":227,\"height\":400,\"size\":\"227x400\",\"aspect\":0.5675}},\"description\":null,\"blurhash\":\"UBF~N@OF-:xv4mM|s+ob9FE2t6tQ9Fs:t8oN\"}],\"mentions\":[],\"tags\":[],\"emojis\":[],\"card\":null,\"poll\":null}}\r\n"; let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg(Timeline(Public, Federated, All), Event::Update { .. }) diff --git a/src/parse_client_request/mod.rs b/src/parse_client_request/mod.rs index c999c0e..c076eb2 100644 --- a/src/parse_client_request/mod.rs +++ b/src/parse_client_request/mod.rs @@ -1,5 +1,13 @@ -//! 
Parse the client request and return a (possibly authenticated) `User` -pub mod query; -pub mod sse; -pub mod subscription; -pub mod ws; +//! Parse the client request and return a Subscription +mod postgres; +mod query; +mod sse; +mod subscription; +mod ws; + +pub use self::postgres::PgPool; +// TODO consider whether we can remove `Stream` from public API +pub use subscription::{Stream, Subscription, Timeline}; + +#[cfg(test)] +pub use subscription::{Content, Reach}; diff --git a/src/parse_client_request/postgres.rs b/src/parse_client_request/postgres.rs new file mode 100644 index 0000000..606f583 --- /dev/null +++ b/src/parse_client_request/postgres.rs @@ -0,0 +1,204 @@ +//! Postgres queries +use crate::{ + config, + parse_client_request::subscription::{Scope, UserData}, +}; +use ::postgres; +use r2d2_postgres::PostgresConnectionManager; +use std::collections::HashSet; +use warp::reject::Rejection; + +#[derive(Clone, Debug)] +pub struct PgPool(pub r2d2::Pool>); +impl PgPool { + pub fn new(pg_cfg: config::PostgresConfig) -> Self { + let mut cfg = postgres::Config::new(); + cfg.user(&pg_cfg.user) + .host(&*pg_cfg.host.to_string()) + .port(*pg_cfg.port) + .dbname(&pg_cfg.database); + if let Some(password) = &*pg_cfg.password { + cfg.password(password); + }; + + let manager = PostgresConnectionManager::new(cfg, postgres::NoTls); + let pool = r2d2::Pool::builder() + .max_size(10) + .build(manager) + .expect("Can connect to local postgres"); + Self(pool) + } + pub fn select_user(self, token: &str) -> Result { + let mut conn = self.0.get().unwrap(); + let query_rows = conn + .query( + " +SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes +FROM +oauth_access_tokens +INNER JOIN users ON +oauth_access_tokens.resource_owner_id = users.id +WHERE oauth_access_tokens.token = $1 +AND oauth_access_tokens.revoked_at IS NULL +LIMIT 1", + &[&token.to_owned()], + ) + .expect("Hard-coded query will return Some([0 or more 
rows])"); + if let Some(result_columns) = query_rows.get(0) { + let id = result_columns.get(1); + let allowed_langs = result_columns + .try_get::<_, Vec<_>>(2) + .unwrap_or_else(|_| Vec::new()) + .into_iter() + .collect(); + let mut scopes: HashSet = result_columns + .get::<_, String>(3) + .split(' ') + .filter_map(|scope| match scope { + "read" => Some(Scope::Read), + "read:statuses" => Some(Scope::Statuses), + "read:notifications" => Some(Scope::Notifications), + "read:lists" => Some(Scope::Lists), + "write" | "follow" => None, // ignore write scopes + unexpected => { + log::warn!("Ignoring unknown scope `{}`", unexpected); + None + } + }) + .collect(); + // We don't need to separately track read auth - it's just all three others + if scopes.remove(&Scope::Read) { + scopes.insert(Scope::Statuses); + scopes.insert(Scope::Notifications); + scopes.insert(Scope::Lists); + } + + Ok(UserData { + id, + allowed_langs, + scopes, + }) + } else { + Err(warp::reject::custom("Error: Invalid access token")) + } + } + + pub fn select_hashtag_id(self, tag_name: &String) -> Result { + let mut conn = self.0.get().unwrap(); + let rows = &conn + .query( + " +SELECT id +FROM tags +WHERE name = $1 +LIMIT 1", + &[&tag_name], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + + match rows.get(0) { + Some(row) => Ok(row.get(0)), + None => Err(warp::reject::custom("Error: Hashtag does not exist.")), + } + } + + /// Query Postgres for everyone the user has blocked or muted + /// + /// **NOTE**: because we check this when the user connects, it will not include any blocks + /// the user adds until they refresh/reconnect. 
+ pub fn select_blocked_users(self, user_id: i64) -> HashSet { + // " + // SELECT + // 1 + // FROM blocks + // WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})) + // OR (account_id = $2 AND target_account_id = $1) + // UNION SELECT + // 1 + // FROM mutes + // WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})` + // , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`" + self + .0 + .get() + .unwrap() + .query( + " +SELECT target_account_id + FROM blocks + WHERE account_id = $1 +UNION SELECT target_account_id + FROM mutes + WHERE account_id = $1", + &[&user_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| row.get(0)) + .collect() + } + /// Query Postgres for everyone who has blocked the user + /// + /// **NOTE**: because we check this when the user connects, it will not include any blocks + /// the user adds until they refresh/reconnect. + pub fn select_blocking_users(self, user_id: i64) -> HashSet { + self + .0 + .get() + .unwrap() + .query( + " +SELECT account_id + FROM blocks + WHERE target_account_id = $1", + &[&user_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| row.get(0)) + .collect() + } + + /// Query Postgres for all current domain blocks + /// + /// **NOTE**: because we check this when the user connects, it will not include any blocks + /// the user adds until they refresh/reconnect. 
+ pub fn select_blocked_domains(self, user_id: i64) -> HashSet { + self + .0 + .get() + .unwrap() + .query( + "SELECT domain FROM account_domain_blocks WHERE account_id = $1", + &[&user_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| row.get(0)) + .collect() + } + + /// Test whether a user owns a list + pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool { + let mut conn = self.0.get().unwrap(); + // For the Postgres query, `id` = list number; `account_id` = user.id + let rows = &conn + .query( + " +SELECT id, account_id +FROM lists +WHERE id = $1 +LIMIT 1", + &[&list_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + + match rows.get(0) { + None => false, + Some(row) => { + let list_owner_id: i64 = row.get(1); + list_owner_id == user_id + } + } + } +} diff --git a/src/parse_client_request/query.rs b/src/parse_client_request/query.rs index 3c92c40..c4445a9 100644 --- a/src/parse_client_request/query.rs +++ b/src/parse_client_request/query.rs @@ -28,6 +28,12 @@ impl Query { } macro_rules! make_query_type { + (Stream => $parameter:tt:$type:ty) => { + #[derive(Deserialize, Debug, Default)] + pub struct Stream { + pub $parameter: $type, + } + }; ($name:tt => $parameter:tt:$type:ty) => { #[derive(Deserialize, Debug, Default)] pub struct $name { @@ -59,14 +65,14 @@ impl ToString for Stream { } } -pub fn optional_media_query() -> BoxedFilter<(Media,)> { - warp::query() - .or(warp::any().map(|| Media { - only_media: "false".to_owned(), - })) - .unify() - .boxed() -} +// pub fn optional_media_query() -> BoxedFilter<(Media,)> { +// warp::query() +// .or(warp::any().map(|| Media { +// only_media: "false".to_owned(), +// })) +// .unify() +// .boxed() +// } pub struct OptionalAccessToken; diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse.rs index b62461c..287cd9b 100644 --- a/src/parse_client_request/sse.rs +++ b/src/parse_client_request/sse.rs @@ -1,78 +1,4 @@ //! 
Filters for all the endpoints accessible for Server Sent Event updates -use super::{ - query::{self, Query}, - subscription::{PgPool, Subscription}, -}; -use warp::{filters::BoxedFilter, path, Filter}; -#[allow(dead_code)] -type TimelineUser = ((String, Subscription),); - -/// Helper macro to match on the first of any of the provided filters -macro_rules! any_of { - ($filter:expr, $($other_filter:expr),*) => { - $filter$(.or($other_filter).unify())*.boxed() - }; -} - -macro_rules! parse_query { - (path => $start:tt $(/ $next:tt)* - endpoint => $endpoint:expr) => { - path!($start $(/ $next)*) - .and(query::Auth::to_filter()) - .and(query::Media::to_filter()) - .and(query::Hashtag::to_filter()) - .and(query::List::to_filter()) - .map( - |auth: query::Auth, - media: query::Media, - hashtag: query::Hashtag, - list: query::List| { - Query { - access_token: auth.access_token, - stream: $endpoint.to_string(), - media: media.is_truthy(), - hashtag: hashtag.tag, - list: list.list, - } - }, - ) - .boxed() - }; -} -pub fn extract_user_or_reject( - pg_pool: PgPool, - whitelist_mode: bool, -) -> BoxedFilter<(Subscription,)> { - any_of!( - parse_query!( - path => "api" / "v1" / "streaming" / "user" / "notification" - endpoint => "user:notification" ), - parse_query!( - path => "api" / "v1" / "streaming" / "user" - endpoint => "user"), - parse_query!( - path => "api" / "v1" / "streaming" / "public" / "local" - endpoint => "public:local"), - parse_query!( - path => "api" / "v1" / "streaming" / "public" - endpoint => "public"), - parse_query!( - path => "api" / "v1" / "streaming" / "direct" - endpoint => "direct"), - parse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local" - endpoint => "hashtag:local"), - parse_query!(path => "api" / "v1" / "streaming" / "hashtag" - endpoint => "hashtag"), - parse_query!(path => "api" / "v1" / "streaming" / "list" - endpoint => "list") - ) - // because SSE requests place their `access_token` in the header instead of in a query - // 
parameter, we need to update our Query if the header has a token - .and(query::OptionalAccessToken::from_sse_header()) - .and_then(Query::update_access_token) - .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) - .boxed() -} // #[cfg(test)] // mod test { diff --git a/src/parse_client_request/subscription.rs b/src/parse_client_request/subscription.rs new file mode 100644 index 0000000..d8e874a --- /dev/null +++ b/src/parse_client_request/subscription.rs @@ -0,0 +1,396 @@ +//! `User` struct and related functionality +// #[cfg(test)] +// mod mock_postgres; +// #[cfg(test)] +// use mock_postgres as postgres; +// #[cfg(not(test))] + +use super::postgres::PgPool; +use super::query::Query; +use crate::err::TimelineErr; +use crate::log_fatal; +use lru::LruCache; +use std::collections::HashSet; +use warp::reject::Rejection; + +use super::query; +use warp::{filters::BoxedFilter, path, Filter}; + +/// Helper macro to match on the first of any of the provided filters +macro_rules! any_of { + ($filter:expr, $($other_filter:expr),*) => { + $filter$(.or($other_filter).unify())*.boxed() + }; +} +macro_rules! 
parse_sse_query { + (path => $start:tt $(/ $next:tt)* + endpoint => $endpoint:expr) => { + path!($start $(/ $next)*) + .and(query::Auth::to_filter()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .map( + |auth: query::Auth, + media: query::Media, + hashtag: query::Hashtag, + list: query::List| { + Query { + access_token: auth.access_token, + stream: $endpoint.to_string(), + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + } + }, + ) + .boxed() + }; +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Subscription { + pub timeline: Timeline, + pub allowed_langs: HashSet, + pub blocks: Blocks, + pub hashtag_name: Option, + pub access_token: Option, +} + +impl Default for Subscription { + fn default() -> Self { + Self { + timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification), + allowed_langs: HashSet::new(), + blocks: Blocks::default(), + hashtag_name: None, + access_token: None, + } + } +} + +impl Subscription { + pub fn from_ws_request(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> { + parse_ws_query() + .and(query::OptionalAccessToken::from_ws_header()) + .and_then(Query::update_access_token) + .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) + .boxed() + } + + pub fn from_sse_query(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> { + any_of!( + parse_sse_query!( + path => "api" / "v1" / "streaming" / "user" / "notification" + endpoint => "user:notification" ), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "user" + endpoint => "user"), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "public" / "local" + endpoint => "public:local"), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "public" + endpoint => "public"), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "direct" + endpoint => "direct"), + parse_sse_query!(path => "api" / "v1" / "streaming" / 
"hashtag" / "local" + endpoint => "hashtag:local"), + parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" + endpoint => "hashtag"), + parse_sse_query!(path => "api" / "v1" / "streaming" / "list" + endpoint => "list") + ) + // because SSE requests place their `access_token` in the header instead of in a query + // parameter, we need to update our Query if the header has a token + .and(query::OptionalAccessToken::from_sse_header()) + .and_then(Query::update_access_token) + .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) + .boxed() + } + fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result { + let user = match q.access_token.clone() { + Some(token) => pool.clone().select_user(&token)?, + None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?, + None => UserData::public(), + }; + let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?; + let hashtag_name = match timeline { + Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag), + _non_hashtag_timeline => None, + }; + + Ok(Subscription { + timeline, + allowed_langs: user.allowed_langs, + blocks: Blocks { + blocking_users: pool.clone().select_blocking_users(user.id), + blocked_users: pool.clone().select_blocked_users(user.id), + blocked_domains: pool.clone().select_blocked_domains(user.id), + }, + hashtag_name, + access_token: q.access_token, + }) + } +} + +fn parse_ws_query() -> BoxedFilter<(Query,)> { + path!("api" / "v1" / "streaming") + .and(path::end()) + .and(warp::query()) + .and(query::Auth::to_filter()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .map( + |stream: query::Stream, + auth: query::Auth, + media: query::Media, + hashtag: query::Hashtag, + list: query::List| { + Query { + access_token: auth.access_token, + stream: stream.stream, + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + } + }, + ) + .boxed() +} + 
+#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub struct Timeline(pub Stream, pub Reach, pub Content); + +impl Timeline { + pub fn empty() -> Self { + use {Content::*, Reach::*, Stream::*}; + Self(Unset, Local, Notification) + } + + pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> String { + use {Content::*, Reach::*, Stream::*}; + match self { + Timeline(Public, Federated, All) => "timeline:public".into(), + Timeline(Public, Local, All) => "timeline:public:local".into(), + Timeline(Public, Federated, Media) => "timeline:public:media".into(), + Timeline(Public, Local, Media) => "timeline:public:local:media".into(), + + Timeline(Hashtag(id), Federated, All) => format!( + "timeline:hashtag:{}", + hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) + ), + Timeline(Hashtag(id), Local, All) => format!( + "timeline:hashtag:{}:local", + hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) + ), + Timeline(User(id), Federated, All) => format!("timeline:{}", id), + Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id), + Timeline(List(id), Federated, All) => format!("timeline:list:{}", id), + Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id), + Timeline(one, _two, _three) => { + log_fatal!("Supposedly impossible timeline reached: {:?}", one) + } + } + } + + pub fn from_redis_raw_timeline( + timeline: &str, + cache: &mut LruCache, + namespace: &Option, + ) -> Result { + use crate::err::TimelineErr::RedisNamespaceMismatch; + use {Content::*, Reach::*, Stream::*}; + let timeline_slice = &timeline.split(":").collect::>()[..]; + + #[rustfmt::skip] + let (stream, reach, content) = if let Some(ns) = namespace { + match timeline_slice { + [n, "timeline", "public"] if n == ns => (Public, Federated, All), + [_, "timeline", "public"] + | ["timeline", "public"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "public", "local"] if ns == n => 
(Public, Local, All), + [_, "timeline", "public", "local"] + | ["timeline", "public", "local"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "public", "media"] if ns == n => (Public, Federated, Media), + [_, "timeline", "public", "media"] + | ["timeline", "public", "media"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "public", "local", "media"] if ns == n => (Public, Local, Media), + [_, "timeline", "public", "local", "media"] + | ["timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "hashtag", tag_name] if ns == n => { + let tag_id = *cache + .get(&tag_name.to_string()) + .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name)); + (Hashtag(tag_id), Federated, All) + } + [_, "timeline", "hashtag", _tag] + | ["timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "hashtag", _tag, "local"] if ns == n => (Hashtag(0), Local, All), + [_, "timeline", "hashtag", _tag, "local"] + | ["timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", id] if ns == n => (User(id.parse().unwrap()), Federated, All), + [_, "timeline", _id] + | ["timeline", _id] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", id, "notification"] if ns == n => + (User(id.parse()?), Federated, Notification), + + [_, "timeline", _id, "notification"] + | ["timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?, + + + [n, "timeline", "list", id] if ns == n => (List(id.parse()?), Federated, All), + [_, "timeline", "list", _id] + | ["timeline", "list", _id] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "direct", id] if ns == n => (Direct(id.parse()?), Federated, All), + [_, "timeline", "direct", _id] + | ["timeline", "direct", _id] => Err(RedisNamespaceMismatch)?, + + [..] 
=> log_fatal!("Unexpected channel from Redis: {:?}", timeline_slice), + } + } else { + match timeline_slice { + ["timeline", "public"] => (Public, Federated, All), + [_, "timeline", "public"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "public", "local"] => (Public, Local, All), + [_, "timeline", "public", "local"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "public", "media"] => (Public, Federated, Media), + + [_, "timeline", "public", "media"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "public", "local", "media"] => (Public, Local, Media), + [_, "timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "hashtag", _tag] => (Hashtag(0), Federated, All), + [_, "timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?, + + ["timeline", "hashtag", _tag, "local"] => (Hashtag(0), Local, All), + [_, "timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?, + + ["timeline", id] => (User(id.parse().unwrap()), Federated, All), + [_, "timeline", _id] => Err(RedisNamespaceMismatch)?, + + ["timeline", id, "notification"] => { + (User(id.parse().unwrap()), Federated, Notification) + } + [_, "timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "list", id] => (List(id.parse().unwrap()), Federated, All), + [_, "timeline", "list", _id] => Err(RedisNamespaceMismatch)?, + + ["timeline", "direct", id] => (Direct(id.parse().unwrap()), Federated, All), + [_, "timeline", "direct", _id] => Err(RedisNamespaceMismatch)?, + + // Other endpoints don't exist: + [..] 
=> Err(TimelineErr::InvalidInput)?, + } + }; + + Ok(Timeline(stream, reach, content)) + } + fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result { + use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*}; + let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag); + let user_owns_list = || pool.clone().user_owns_list(user.id, q.list); + + Ok(match q.stream.as_ref() { + "public" => match q.media { + true => Timeline(Public, Federated, Media), + false => Timeline(Public, Federated, All), + }, + "public:local" => match q.media { + true => Timeline(Public, Local, Media), + false => Timeline(Public, Local, All), + }, + "public:media" => Timeline(Public, Federated, Media), + "public:local:media" => Timeline(Public, Local, Media), + + "hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All), + "hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All), + "user" => match user.scopes.contains(&Statuses) { + true => Timeline(User(user.id), Federated, All), + false => Err(custom("Error: Missing access token"))?, + }, + "user:notification" => match user.scopes.contains(&Statuses) { + true => Timeline(User(user.id), Federated, Notification), + false => Err(custom("Error: Missing access token"))?, + }, + "list" => match user.scopes.contains(&Lists) && user_owns_list() { + true => Timeline(List(q.list), Federated, All), + false => Err(warp::reject::custom("Error: Missing access token"))?, + }, + "direct" => match user.scopes.contains(&Statuses) { + true => Timeline(Direct(user.id), Federated, All), + false => Err(custom("Error: Missing access token"))?, + }, + other => { + log::warn!("Request for nonexistent endpoint: `{}`", other); + Err(custom("Error: Nonexistent endpoint"))? 
+ } + }) + } +} +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Stream { + User(i64), + List(i64), + Direct(i64), + Hashtag(i64), + Public, + Unset, +} +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Reach { + Local, + Federated, +} +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Content { + All, + Media, + Notification, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Scope { + Read, + Statuses, + Notifications, + Lists, +} + +#[derive(Clone, Default, Debug, PartialEq)] +pub struct Blocks { + pub blocked_domains: HashSet, + pub blocked_users: HashSet, + pub blocking_users: HashSet, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct UserData { + pub id: i64, + pub allowed_langs: HashSet, + pub scopes: HashSet, +} + +impl UserData { + fn public() -> Self { + Self { + id: -1, + allowed_langs: HashSet::new(), + scopes: HashSet::new(), + } + } +} diff --git a/src/parse_client_request/subscription/mock_postgres.rs b/src/parse_client_request/subscription/mock_postgres.rs deleted file mode 100644 index d5a9612..0000000 --- a/src/parse_client_request/subscription/mock_postgres.rs +++ /dev/null @@ -1,43 +0,0 @@ -//! 
Mock Postgres connection (for use in unit testing) -use super::{OauthScope, Subscription}; -use std::collections::HashSet; - -#[derive(Clone)] -pub struct PgPool; -impl PgPool { - pub fn new() -> Self { - Self - } -} - -pub fn select_user( - access_token: &str, - _pg_pool: PgPool, -) -> Result { - let mut user = Subscription::default(); - if access_token == "TEST_USER" { - user.id = 1; - user.logged_in = true; - user.access_token = "TEST_USER".to_string(); - user.email = "user@example.com".to_string(); - user.scopes = OauthScope::from(vec![ - "read".to_string(), - "write".to_string(), - "follow".to_string(), - ]); - } else if access_token == "INVALID" { - return Err(warp::reject::custom("Error: Invalid access token")); - } - Ok(user) -} - -pub fn select_user_blocks(_id: i64, _pg_pool: PgPool) -> HashSet { - HashSet::new() -} -pub fn select_domain_blocks(_pg_pool: PgPool) -> HashSet { - HashSet::new() -} - -pub fn user_owns_list(user_id: i64, list_id: i64, _pg_pool: PgPool) -> bool { - user_id == list_id -} diff --git a/src/parse_client_request/subscription/mod.rs b/src/parse_client_request/subscription/mod.rs deleted file mode 100644 index 3fdb060..0000000 --- a/src/parse_client_request/subscription/mod.rs +++ /dev/null @@ -1,196 +0,0 @@ -//! 
`User` struct and related functionality -// #[cfg(test)] -// mod mock_postgres; -// #[cfg(test)] -// use mock_postgres as postgres; -// #[cfg(not(test))] -pub mod postgres; -pub use self::postgres::PgPool; -use super::query::Query; -use crate::log_fatal; -use std::collections::HashSet; -use warp::reject::Rejection; - -/// The User (with data read from Postgres) -#[derive(Clone, Debug, PartialEq)] -pub struct Subscription { - pub timeline: Timeline, - pub allowed_langs: HashSet, - pub blocks: Blocks, -} - -impl Default for Subscription { - fn default() -> Self { - Self { - timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification), - allowed_langs: HashSet::new(), - blocks: Blocks::default(), - } - } -} - -impl Subscription { - pub fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result { - let user = match q.access_token.clone() { - Some(token) => postgres::select_user(&token, pool.clone())?, - None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?, - None => UserData::public(), - }; - Ok(Subscription { - timeline: Timeline::from_query_and_user(&q, &user, pool.clone())?, - allowed_langs: user.allowed_langs, - blocks: Blocks { - blocking_users: postgres::select_blocking_users(user.id, pool.clone()), - blocked_users: postgres::select_blocked_users(user.id, pool.clone()), - blocked_domains: postgres::select_blocked_domains(user.id, pool.clone()), - }, - }) - } -} - -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub struct Timeline(pub Stream, pub Reach, pub Content); - -impl Timeline { - pub fn empty() -> Self { - use {Content::*, Reach::*, Stream::*}; - Self(Unset, Local, Notification) - } - - pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> String { - use {Content::*, Reach::*, Stream::*}; - match self { - Timeline(Public, Federated, All) => "timeline:public".into(), - Timeline(Public, Local, All) => "timeline:public:local".into(), - Timeline(Public, Federated, Media) => 
"timeline:public:media".into(), - Timeline(Public, Local, Media) => "timeline:public:local:media".into(), - - Timeline(Hashtag(id), Federated, All) => format!( - "timeline:hashtag:{}", - hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) - ), - Timeline(Hashtag(id), Local, All) => format!( - "timeline:hashtag:{}:local", - hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) - ), - Timeline(User(id), Federated, All) => format!("timeline:{}", id), - Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id), - Timeline(List(id), Federated, All) => format!("timeline:list:{}", id), - Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id), - Timeline(one, _two, _three) => { - log_fatal!("Supposedly impossible timeline reached: {:?}", one) - } - } - } - pub fn from_redis_raw_timeline(raw_timeline: &str, hashtag: Option) -> Self { - use {Content::*, Reach::*, Stream::*}; - match raw_timeline.split(':').collect::>()[..] { - ["public"] => Timeline(Public, Federated, All), - ["public", "local"] => Timeline(Public, Local, All), - ["public", "media"] => Timeline(Public, Federated, Media), - ["public", "local", "media"] => Timeline(Public, Local, Media), - - ["hashtag", _tag] => Timeline(Hashtag(hashtag.unwrap()), Federated, All), - ["hashtag", _tag, "local"] => Timeline(Hashtag(hashtag.unwrap()), Local, All), - [id] => Timeline(User(id.parse().unwrap()), Federated, All), - [id, "notification"] => Timeline(User(id.parse().unwrap()), Federated, Notification), - ["list", id] => Timeline(List(id.parse().unwrap()), Federated, All), - ["direct", id] => Timeline(Direct(id.parse().unwrap()), Federated, All), - // Other endpoints don't exist: - [..] 
=> log_fatal!("Unexpected channel from Redis: {}", raw_timeline), - } - } - fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result { - use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*}; - let id_from_hashtag = || postgres::select_list_id(&q.hashtag, pool.clone()); - let user_owns_list = || postgres::user_owns_list(user.id, q.list, pool.clone()); - - Ok(match q.stream.as_ref() { - "public" => match q.media { - true => Timeline(Public, Federated, Media), - false => Timeline(Public, Federated, All), - }, - "public:local" => match q.media { - true => Timeline(Public, Local, Media), - false => Timeline(Public, Local, All), - }, - "public:media" => Timeline(Public, Federated, Media), - "public:local:media" => Timeline(Public, Local, Media), - - "hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All), - "hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All), - "user" => match user.scopes.contains(&Statuses) { - true => Timeline(User(user.id), Federated, All), - false => Err(custom("Error: Missing access token"))?, - }, - "user:notification" => match user.scopes.contains(&Statuses) { - true => Timeline(User(user.id), Federated, Notification), - false => Err(custom("Error: Missing access token"))?, - }, - "list" => match user.scopes.contains(&Lists) && user_owns_list() { - true => Timeline(List(q.list), Federated, All), - false => Err(warp::reject::custom("Error: Missing access token"))?, - }, - "direct" => match user.scopes.contains(&Statuses) { - true => Timeline(Direct(user.id), Federated, All), - false => Err(custom("Error: Missing access token"))?, - }, - other => { - log::warn!("Request for nonexistent endpoint: `{}`", other); - Err(custom("Error: Nonexistent endpoint"))? 
- } - }) - } -} -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Stream { - User(i64), - List(i64), - Direct(i64), - Hashtag(i64), - Public, - Unset, -} -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Reach { - Local, - Federated, -} -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Content { - All, - Media, - Notification, -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum Scope { - Read, - Statuses, - Notifications, - Lists, -} - -#[derive(Clone, Default, Debug, PartialEq)] -pub struct Blocks { - pub blocked_domains: HashSet, - pub blocked_users: HashSet, - pub blocking_users: HashSet, -} - -#[derive(Clone, Debug, PartialEq)] -pub struct UserData { - id: i64, - allowed_langs: HashSet, - scopes: HashSet, -} - -impl UserData { - fn public() -> Self { - Self { - id: -1, - allowed_langs: HashSet::new(), - scopes: HashSet::new(), - } - } -} diff --git a/src/parse_client_request/subscription/postgres.rs b/src/parse_client_request/subscription/postgres.rs deleted file mode 100644 index 9353efa..0000000 --- a/src/parse_client_request/subscription/postgres.rs +++ /dev/null @@ -1,225 +0,0 @@ -//! 
Postgres queries -use crate::{ - config, - parse_client_request::subscription::{Scope, UserData}, -}; -use ::postgres; -use r2d2_postgres::PostgresConnectionManager; -use std::collections::HashSet; -use warp::reject::Rejection; - -#[derive(Clone, Debug)] -pub struct PgPool(pub r2d2::Pool>); -impl PgPool { - pub fn new(pg_cfg: config::PostgresConfig) -> Self { - let mut cfg = postgres::Config::new(); - cfg.user(&pg_cfg.user) - .host(&*pg_cfg.host.to_string()) - .port(*pg_cfg.port) - .dbname(&pg_cfg.database); - if let Some(password) = &*pg_cfg.password { - cfg.password(password); - }; - - let manager = PostgresConnectionManager::new(cfg, postgres::NoTls); - let pool = r2d2::Pool::builder() - .max_size(10) - .build(manager) - .expect("Can connect to local postgres"); - Self(pool) - } -} - -pub fn select_user(token: &str, pool: PgPool) -> Result { - let mut conn = pool.0.get().unwrap(); - let query_rows = conn - .query( - " -SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes -FROM -oauth_access_tokens -INNER JOIN users ON -oauth_access_tokens.resource_owner_id = users.id -WHERE oauth_access_tokens.token = $1 -AND oauth_access_tokens.revoked_at IS NULL -LIMIT 1", - &[&token.to_owned()], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - if let Some(result_columns) = query_rows.get(0) { - let id = result_columns.get(1); - let allowed_langs = result_columns - .try_get::<_, Vec<_>>(2) - .unwrap_or_else(|_| Vec::new()) - .into_iter() - .collect(); - let mut scopes: HashSet = result_columns - .get::<_, String>(3) - .split(' ') - .filter_map(|scope| match scope { - "read" => Some(Scope::Read), - "read:statuses" => Some(Scope::Statuses), - "read:notifications" => Some(Scope::Notifications), - "read:lists" => Some(Scope::Lists), - "write" | "follow" => None, // ignore write scopes - unexpected => { - log::warn!("Ignoring unknown scope `{}`", unexpected); - None - } - }) - .collect(); - // We 
don't need to separately track read auth - it's just all three others - if scopes.remove(&Scope::Read) { - scopes.insert(Scope::Statuses); - scopes.insert(Scope::Notifications); - scopes.insert(Scope::Lists); - } - - Ok(UserData { - id, - allowed_langs, - scopes, - }) - } else { - Err(warp::reject::custom("Error: Invalid access token")) - } -} - -pub fn select_list_id(tag_name: &String, pg_pool: PgPool) -> Result { - let mut conn = pg_pool.0.get().unwrap(); - // For the Postgres query, `id` = list number; `account_id` = user.id - let rows = &conn - .query( - " -SELECT id -FROM tags -WHERE name = $1 -LIMIT 1", - &[&tag_name], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - - match rows.get(0) { - Some(row) => Ok(row.get(0)), - None => Err(warp::reject::custom("Error: Hashtag does not exist.")), - } -} -pub fn select_hashtag_name(tag_id: &i64, pg_pool: PgPool) -> Result { - let mut conn = pg_pool.0.get().unwrap(); - // For the Postgres query, `id` = list number; `account_id` = user.id - let rows = &conn - .query( - " -SELECT name -FROM tags -WHERE id = $1 -LIMIT 1", - &[&tag_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - - match rows.get(0) { - Some(row) => Ok(row.get(0)), - None => Err(warp::reject::custom("Error: Hashtag does not exist.")), - } -} - -/// Query Postgres for everyone the user has blocked or muted -/// -/// **NOTE**: because we check this when the user connects, it will not include any blocks -/// the user adds until they refresh/reconnect. 
-pub fn select_blocked_users(user_id: i64, pg_pool: PgPool) -> HashSet { - // " - // SELECT - // 1 - // FROM blocks - // WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})) - // OR (account_id = $2 AND target_account_id = $1) - // UNION SELECT - // 1 - // FROM mutes - // WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})` - // , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`" - pg_pool - .0 - .get() - .unwrap() - .query( - " -SELECT target_account_id - FROM blocks - WHERE account_id = $1 -UNION SELECT target_account_id - FROM mutes - WHERE account_id = $1", - &[&user_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])") - .iter() - .map(|row| row.get(0)) - .collect() -} -/// Query Postgres for everyone who has blocked the user -/// -/// **NOTE**: because we check this when the user connects, it will not include any blocks -/// the user adds until they refresh/reconnect. -pub fn select_blocking_users(user_id: i64, pg_pool: PgPool) -> HashSet { - pg_pool - .0 - .get() - .unwrap() - .query( - " -SELECT account_id - FROM blocks - WHERE target_account_id = $1", - &[&user_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])") - .iter() - .map(|row| row.get(0)) - .collect() -} - -/// Query Postgres for all current domain blocks -/// -/// **NOTE**: because we check this when the user connects, it will not include any blocks -/// the user adds until they refresh/reconnect. 
-pub fn select_blocked_domains(user_id: i64, pg_pool: PgPool) -> HashSet { - pg_pool - .0 - .get() - .unwrap() - .query( - "SELECT domain FROM account_domain_blocks WHERE account_id = $1", - &[&user_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])") - .iter() - .map(|row| row.get(0)) - .collect() -} - -/// Test whether a user owns a list -pub fn user_owns_list(user_id: i64, list_id: i64, pg_pool: PgPool) -> bool { - let mut conn = pg_pool.0.get().unwrap(); - // For the Postgres query, `id` = list number; `account_id` = user.id - let rows = &conn - .query( - " -SELECT id, account_id -FROM lists -WHERE id = $1 -LIMIT 1", - &[&list_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - - match rows.get(0) { - None => false, - Some(row) => { - let list_owner_id: i64 = row.get(1); - list_owner_id == user_id - } - } -} diff --git a/src/parse_client_request/subscription/stdin b/src/parse_client_request/subscription/stdin deleted file mode 100644 index e69de29..0000000 diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs index aa74c60..e101bed 100644 --- a/src/parse_client_request/ws.rs +++ b/src/parse_client_request/ws.rs @@ -1,48 +1,9 @@ //! 
Filters for the WebSocket endpoint -use super::{ - query::{self, Query}, - subscription::{PgPool, Subscription}, -}; -use warp::{filters::BoxedFilter, path, Filter}; -/// WebSocket filters -fn parse_query() -> BoxedFilter<(Query,)> { - path!("api" / "v1" / "streaming") - .and(path::end()) - .and(warp::query()) - .and(query::Auth::to_filter()) - .and(query::Media::to_filter()) - .and(query::Hashtag::to_filter()) - .and(query::List::to_filter()) - .map( - |stream: query::Stream, - auth: query::Auth, - media: query::Media, - hashtag: query::Hashtag, - list: query::List| { - Query { - access_token: auth.access_token, - stream: stream.stream, - media: media.is_truthy(), - hashtag: hashtag.tag, - list: list.list, - } - }, - ) - .boxed() -} -pub fn extract_user_and_token_or_reject( - pg_pool: PgPool, - whitelist_mode: bool, -) -> BoxedFilter<(Subscription, Option)> { - parse_query() - .and(query::OptionalAccessToken::from_ws_header()) - .and_then(Query::update_access_token) - .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) - .and(query::OptionalAccessToken::from_ws_header()) - .boxed() -} + + + // #[cfg(test)] // mod test { diff --git a/src/redis_to_client_stream/client_agent.rs b/src/redis_to_client_stream/client_agent.rs index 0162598..18c79d3 100644 --- a/src/redis_to_client_stream/client_agent.rs +++ b/src/redis_to_client_stream/client_agent.rs @@ -19,29 +19,29 @@ use super::receiver::Receiver; use crate::{ config, messages::Event, - parse_client_request::subscription::{PgPool, Stream::Public, Subscription, Timeline}, + parse_client_request::{Stream::Public, Subscription, Timeline}, }; use futures::{ Async::{self, NotReady, Ready}, Poll, }; -use std::sync; +use std::sync::{Arc, Mutex}; use tokio::io::Error; use uuid::Uuid; /// Struct for managing all Redis streams. 
#[derive(Clone, Debug)] pub struct ClientAgent { - receiver: sync::Arc>, - id: uuid::Uuid, + receiver: Arc>, + id: Uuid, pub subscription: Subscription, } impl ClientAgent { /// Create a new `ClientAgent` with no shared data. - pub fn blank(redis_cfg: config::RedisConfig, pg_pool: PgPool) -> Self { + pub fn blank(redis_cfg: config::RedisConfig) -> Self { ClientAgent { - receiver: sync::Arc::new(sync::Mutex::new(Receiver::new(redis_cfg, pg_pool))), + receiver: Arc::new(Mutex::new(Receiver::new(redis_cfg))), id: Uuid::default(), subscription: Subscription::default(), } @@ -70,7 +70,11 @@ impl ClientAgent { self.subscription = subscription; let start_time = Instant::now(); let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)"); - receiver.manage_new_timeline(self.id, self.subscription.timeline); + receiver.manage_new_timeline( + self.id, + self.subscription.timeline, + self.subscription.hashtag_name.clone(), + ); log::info!("init_for_user had lock for: {:?}", start_time.elapsed()); } } diff --git a/src/redis_to_client_stream/event_stream.rs b/src/redis_to_client_stream/event_stream.rs new file mode 100644 index 0000000..5157120 --- /dev/null +++ b/src/redis_to_client_stream/event_stream.rs @@ -0,0 +1,103 @@ +use super::ClientAgent; + +use warp::ws::WebSocket; +use futures::{future::Future, stream::Stream, Async}; +use log; +use std::time::{Duration, Instant}; + +pub struct EventStream; + +impl EventStream { + + +/// Send a stream of replies to a WebSocket client. 
+ pub fn to_ws( + socket: WebSocket, + mut client_agent: ClientAgent, + update_interval: Duration, +) -> impl Future { + let (ws_tx, mut ws_rx) = socket.split(); + let timeline = client_agent.subscription.timeline; + + // Create a pipe + let (tx, rx) = futures::sync::mpsc::unbounded(); + + // Send one end of it to a different thread and tell that end to forward whatever it gets + // on to the websocket client + warp::spawn( + rx.map_err(|()| -> warp::Error { unreachable!() }) + .forward(ws_tx) + .map(|_r| ()) + .map_err(|e| match e.to_string().as_ref() { + "IO error: Broken pipe (os error 32)" => (), // just closed unix socket + _ => log::warn!("websocket send error: {}", e), + }), + ); + + // Yield new events for as long as the client is still connected + let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while( + move |_| match ws_rx.poll() { + Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true), + Ok(Async::Ready(None)) => { + // TODO: consider whether we should manually drop closed connections here + log::info!("Client closed WebSocket connection for {:?}", timeline); + futures::future::ok(false) + } + Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => { + // no err, just closed Unix socket + log::info!("Client closed WebSocket connection for {:?}", timeline); + futures::future::ok(false) + } + Err(e) => { + log::warn!("Error in {:?}: {}", timeline, e); + futures::future::ok(false) + } + }, + ); + + let mut time = Instant::now(); + // Every time you get an event from that stream, send it through the pipe + event_stream + .for_each(move |_instant| { + if let Ok(Async::Ready(Some(msg))) = client_agent.poll() { + tx.unbounded_send(warp::ws::Message::text(msg.to_json_string())) + .expect("No send error"); + }; + if time.elapsed() > Duration::from_secs(30) { + tx.unbounded_send(warp::ws::Message::text("{}")) + .expect("Can ping"); + time = Instant::now(); + } + Ok(()) + }) + .then(move 
|result| { + // TODO: consider whether we should manually drop closed connections here + log::info!("WebSocket connection for {:?} closed.", timeline); + result + }) + .map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e)) +} + pub fn to_sse( + mut client_agent: ClientAgent, + connection: warp::sse::Sse, + update_interval: Duration, +) ->impl warp::reply::Reply { + let event_stream = + tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| { + match client_agent.poll() { + Ok(Async::Ready(Some(event))) => Some(( + warp::sse::event(event.event_name()), + warp::sse::data(event.payload().unwrap_or_else(String::new)), + )), + _ => None, + } + }); + + connection.reply( + warp::sse::keep_alive() + .interval(Duration::from_secs(30)) + .text("thump".to_string()) + .stream(event_stream), + ) +} +} diff --git a/src/redis_to_client_stream/message.rs b/src/redis_to_client_stream/message.rs deleted file mode 100644 index 10cd1d6..0000000 --- a/src/redis_to_client_stream/message.rs +++ /dev/null @@ -1,87 +0,0 @@ -use crate::log_fatal; -use crate::messages::Event; -use serde_json::Value; -use std::{collections::HashSet, string::String}; -use strum_macros::Display; - -#[derive(Debug, Display, Clone)] -pub enum Message { - Update(Status), - Conversation(Value), - Notification(Value), - Delete(String), - FiltersChanged, - Announcement(AnnouncementType), - UnknownEvent(String, Value), -} - -#[derive(Debug, Clone)] -pub struct Status(Value); - -#[derive(Debug, Clone)] -pub enum AnnouncementType { - New(Value), - Delete(String), - Reaction(Value), -} - -impl Message { - // pub fn from_json(event: Event) -> Self { - // use AnnouncementType::*; - - // match event.event.as_ref() { - // "update" => Self::Update(Status(event.payload)), - // "conversation" => Self::Conversation(event.payload), - // "notification" => Self::Notification(event.payload), - // "delete" => Self::Delete( - // event - // .payload - // .as_str() - // .unwrap_or_else(|| 
log_fatal!("Could not process `payload` in {:?}", event)) - // .to_string(), - // ), - // "filters_changed" => Self::FiltersChanged, - // "announcement" => Self::Announcement(New(event.payload)), - // "announcement.reaction" => Self::Announcement(Reaction(event.payload)), - // "announcement.delete" => Self::Announcement(Delete( - // event - // .payload - // .as_str() - // .unwrap_or_else(|| log_fatal!("Could not process `payload` in {:?}", event)) - // .to_string(), - // )), - // other => { - // log::warn!("Received unexpected `event` from Redis: {}", other); - // Self::UnknownEvent(event.event.to_string(), event.payload) - // } - // } - // } - pub fn event(&self) -> String { - use AnnouncementType::*; - match self { - Self::Update(_) => "update", - Self::Conversation(_) => "conversation", - Self::Notification(_) => "notification", - Self::Announcement(New(_)) => "announcement", - Self::Announcement(Reaction(_)) => "announcement.reaction", - Self::UnknownEvent(event, _) => &event, - Self::Delete(_) => "delete", - Self::Announcement(Delete(_)) => "announcement.delete", - Self::FiltersChanged => "filters_changed", - } - .to_string() - } - pub fn payload(&self) -> String { - use AnnouncementType::*; - match self { - Self::Update(status) => status.0.to_string(), - Self::Conversation(value) - | Self::Notification(value) - | Self::Announcement(New(value)) - | Self::Announcement(Reaction(value)) - | Self::UnknownEvent(_, value) => value.to_string(), - Self::Delete(id) | Self::Announcement(Delete(id)) => id.clone(), - Self::FiltersChanged => "".to_string(), - } - } -} diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs index 5295073..e4cdfb5 100644 --- a/src/redis_to_client_stream/mod.rs +++ b/src/redis_to_client_stream/mod.rs @@ -1,103 +1,14 @@ //! Stream the updates appropriate for a given `User`/`timeline` pair from Redis. 
-pub mod client_agent; -pub mod receiver; -pub mod redis; -pub use client_agent::ClientAgent; -use futures::{future::Future, stream::Stream, Async}; -use log; -use std::time::{Duration, Instant}; +mod client_agent; +mod event_stream; +mod receiver; +mod redis; -/// Send a stream of replies to a Server Sent Events client. -pub fn send_updates_to_sse( - mut client_agent: ClientAgent, - connection: warp::sse::Sse, - update_interval: Duration, -) -> impl warp::reply::Reply { - let event_stream = - tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| { - match client_agent.poll() { - Ok(Async::Ready(Some(event))) => Some(( - warp::sse::event(event.event_name()), - warp::sse::data(event.payload().unwrap_or_else(String::new)), - )), - _ => None, - } - }); +pub use {client_agent::ClientAgent, event_stream::EventStream}; - connection.reply( - warp::sse::keep_alive() - .interval(Duration::from_secs(30)) - .text("thump".to_string()) - .stream(event_stream), - ) -} - -use warp::ws::WebSocket; - -/// Send a stream of replies to a WebSocket client. 
-pub fn send_updates_to_ws( - socket: WebSocket, - mut client_agent: ClientAgent, - update_interval: Duration, -) -> impl Future { - let (ws_tx, mut ws_rx) = socket.split(); - let timeline = client_agent.subscription.timeline; - - // Create a pipe - let (tx, rx) = futures::sync::mpsc::unbounded(); - - // Send one end of it to a different thread and tell that end to forward whatever it gets - // on to the websocket client - warp::spawn( - rx.map_err(|()| -> warp::Error { unreachable!() }) - .forward(ws_tx) - .map(|_r| ()) - .map_err(|e| match e.to_string().as_ref() { - "IO error: Broken pipe (os error 32)" => (), // just closed unix socket - _ => log::warn!("websocket send error: {}", e), - }), - ); - - // Yield new events for as long as the client is still connected - let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while( - move |_| match ws_rx.poll() { - Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true), - Ok(Async::Ready(None)) => { - // TODO: consider whether we should manually drop closed connections here - log::info!("Client closed WebSocket connection for {:?}", timeline); - futures::future::ok(false) - } - Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => { - // no err, just closed Unix socket - log::info!("Client closed WebSocket connection for {:?}", timeline); - futures::future::ok(false) - } - Err(e) => { - log::warn!("Error in {:?}: {}", timeline, e); - futures::future::ok(false) - } - }, - ); - - let mut time = Instant::now(); - // Every time you get an event from that stream, send it through the pipe - event_stream - .for_each(move |_instant| { - if let Ok(Async::Ready(Some(msg))) = client_agent.poll() { - tx.unbounded_send(warp::ws::Message::text(msg.to_json_string())) - .expect("No send error"); - }; - if time.elapsed() > Duration::from_secs(30) { - tx.unbounded_send(warp::ws::Message::text("{}")) - .expect("Can ping"); - time = Instant::now(); - } - Ok(()) - }) - 
.then(move |result| { - // TODO: consider whether we should manually drop closed connections here - log::info!("WebSocket connection for {:?} closed.", timeline); - result - }) - .map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e)) -} +#[cfg(test)] +pub use receiver::process_messages; +#[cfg(test)] +pub use receiver::{MessageQueues, MsgQueue}; +#[cfg(test)] +pub use redis::redis_msg::RedisMsg; diff --git a/src/redis_to_client_stream/receiver/message_queues.rs b/src/redis_to_client_stream/receiver/message_queues.rs index ffcd31e..eb45b30 100644 --- a/src/redis_to_client_stream/receiver/message_queues.rs +++ b/src/redis_to_client_stream/receiver/message_queues.rs @@ -1,5 +1,5 @@ use crate::messages::Event; -use crate::parse_client_request::subscription::Timeline; +use crate::parse_client_request::Timeline; use std::{ collections::{HashMap, VecDeque}, fmt, diff --git a/src/redis_to_client_stream/receiver/mod.rs b/src/redis_to_client_stream/receiver/mod.rs index 364d1fd..87b35ad 100644 --- a/src/redis_to_client_stream/receiver/mod.rs +++ b/src/redis_to_client_stream/receiver/mod.rs @@ -2,62 +2,70 @@ //! polled by the correct `ClientAgent`. Also manages sububscriptions and //! unsubscriptions to/from Redis. 
mod message_queues; + +pub use message_queues::{MessageQueues, MsgQueue}; + use crate::{ - config::{self, RedisInterval}, - log_fatal, + config, + err::RedisParseErr, messages::Event, - parse_client_request::subscription::{self, postgres, PgPool, Timeline}, + parse_client_request::{Stream, Timeline}, pubsub_cmd, - redis_to_client_stream::redis::{redis_cmd, RedisConn, RedisStream}, + redis_to_client_stream::redis::redis_msg::RedisMsg, + redis_to_client_stream::redis::{redis_cmd, RedisConn}, }; use futures::{Async, Poll}; use lru::LruCache; -pub use message_queues::{MessageQueues, MsgQueue}; -use std::{collections::HashMap, net, time::Instant}; +use tokio::io::AsyncRead; + +use std::{ + collections::HashMap, + io::Read, + net, str, + time::{Duration, Instant}, +}; use tokio::io::Error; use uuid::Uuid; /// The item that streams from Redis and is polled by the `ClientAgent` #[derive(Debug)] pub struct Receiver { - pub pubsub_connection: RedisStream, + pub pubsub_connection: net::TcpStream, secondary_redis_connection: net::TcpStream, - redis_poll_interval: RedisInterval, + redis_poll_interval: Duration, redis_polled_at: Instant, timeline: Timeline, manager_id: Uuid, pub msg_queues: MessageQueues, clients_per_timeline: HashMap, cache: Cache, - pool: PgPool, + redis_input: Vec, + redis_namespace: Option, } + #[derive(Debug)] pub struct Cache { + // TODO: eventually, it might make sense to have Mastodon publish to timelines with + // the tag number instead of the tag name. This would save us from dealing + // with a cache here and would be consistent with how lists/users are handled. id_to_hashtag: LruCache, pub hashtag_to_id: LruCache, } -impl Cache { - fn new(size: usize) -> Self { - Self { - id_to_hashtag: LruCache::new(size), - hashtag_to_id: LruCache::new(size), - } - } -} + impl Receiver { /// Create a new `Receiver`, with its own Redis connections (but, as yet, no /// active subscriptions). 
- pub fn new(redis_cfg: config::RedisConfig, pool: PgPool) -> Self { + pub fn new(redis_cfg: config::RedisConfig) -> Self { + let redis_namespace = redis_cfg.namespace.clone(); + let RedisConn { primary: pubsub_connection, secondary: secondary_redis_connection, - namespace: redis_namespace, polling_interval: redis_poll_interval, } = RedisConn::new(redis_cfg); Self { - pubsub_connection: RedisStream::from_stream(pubsub_connection) - .with_namespace(redis_namespace), + pubsub_connection, secondary_redis_connection, redis_poll_interval, redis_polled_at: Instant::now(), @@ -65,8 +73,12 @@ impl Receiver { manager_id: Uuid::default(), msg_queues: MessageQueues(HashMap::new()), clients_per_timeline: HashMap::new(), - cache: Cache::new(1000), // should this be a run-time option? - pool, + cache: Cache { + id_to_hashtag: LruCache::new(1000), + hashtag_to_id: LruCache::new(1000), + }, // should these be run-time options? + redis_input: Vec::new(), + redis_namespace, } } @@ -76,12 +88,15 @@ impl Receiver { /// Note: this method calls `subscribe_or_unsubscribe_as_needed`, /// so Redis PubSub subscriptions are only updated when a new timeline /// comes under management for the first time. 
- pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: Timeline) { - self.manager_id = manager_id; - self.timeline = timeline; - self.msg_queues - .insert(self.manager_id, MsgQueue::new(timeline)); - self.subscribe_or_unsubscribe_as_needed(timeline); + pub fn manage_new_timeline(&mut self, id: Uuid, tl: Timeline, hashtag: Option) { + self.timeline = tl; + if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (hashtag, tl) { + self.cache.id_to_hashtag.put(id, hashtag.clone()); + self.cache.hashtag_to_id.put(hashtag, id); + }; + + self.msg_queues.insert(id, MsgQueue::new(tl)); + self.subscribe_or_unsubscribe_as_needed(tl); } /// Set the `Receiver`'s manager_id and target_timeline fields to the appropriate @@ -91,26 +106,6 @@ impl Receiver { self.timeline = timeline; } - fn if_hashtag_timeline_get_hashtag_name(&mut self, timeline: Timeline) -> Option { - use subscription::Stream::*; - if let Timeline(Hashtag(id), _, _) = timeline { - let cached_tag = self.cache.id_to_hashtag.get(&id).map(String::from); - let tag = match cached_tag { - Some(tag) => tag, - None => { - let new_tag = postgres::select_hashtag_name(&id, self.pool.clone()) - .unwrap_or_else(|_| log_fatal!("No hashtag associated with tag #{}", &id)); - self.cache.hashtag_to_id.put(new_tag.clone(), id); - self.cache.id_to_hashtag.put(id, new_tag.clone()); - new_tag.to_string() - } - }; - Some(tag) - } else { - None - } - } - /// Drop any PubSub subscriptions that don't have active clients and check /// that there's a subscription to the current one. If there isn't, then /// subscribe to it. 
@@ -121,8 +116,10 @@ impl Receiver { // Record the lower number of clients subscribed to that channel for change in timelines_to_modify { let timeline = change.timeline; - let hashtag = self.if_hashtag_timeline_get_hashtag_name(timeline); - let hashtag = hashtag.as_ref(); + let hashtag = match timeline { + Timeline(Stream::Hashtag(id), _, _) => self.cache.id_to_hashtag.get(&id), + _non_hashtag_timeline => None, + }; let count_of_subscribed_clients = self .clients_per_timeline @@ -157,10 +154,27 @@ impl futures::stream::Stream for Receiver { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { let (timeline, id) = (self.timeline.clone(), self.manager_id); - if self.redis_polled_at.elapsed() > *self.redis_poll_interval { - self.pubsub_connection - .poll_redis(&mut self.cache.hashtag_to_id, &mut self.msg_queues); - self.redis_polled_at = Instant::now(); + if self.redis_polled_at.elapsed() > self.redis_poll_interval { + let mut buffer = vec![0u8; 6000]; + if let Ok(Async::Ready(bytes_read)) = self.poll_read(&mut buffer) { + let binary_input = buffer[..bytes_read].to_vec(); + let (input, extra_bytes) = match str::from_utf8(&binary_input) { + Ok(input) => (input, "".as_bytes()), + Err(e) => { + let (valid, after_valid) = binary_input.split_at(e.valid_up_to()); + let input = str::from_utf8(valid).expect("Guaranteed by `.valid_up_to`"); + (input, after_valid) + } + }; + + let (cache, namespace) = (&mut self.cache.hashtag_to_id, &self.redis_namespace); + + let remaining_input = + process_messages(input, cache, namespace, &mut self.msg_queues); + + self.redis_input.extend_from_slice(remaining_input); + self.redis_input.extend_from_slice(extra_bytes); + } } // Record current time as last polled time @@ -173,3 +187,49 @@ impl futures::stream::Stream for Receiver { } } } + +impl Read for Receiver { + fn read(&mut self, buffer: &mut [u8]) -> Result<usize, std::io::Error> { + self.pubsub_connection.read(buffer) + } +} + +impl AsyncRead for Receiver { + fn poll_read(&mut self, buf: &mut [u8]) -> Poll<usize, std::io::Error> { + match
self.read(buf) { + Ok(t) => Ok(Async::Ready(t)), + Err(_) => Ok(Async::NotReady), + } + } +} + +#[must_use] +pub fn process_messages<'a>( + input: &'a str, + mut cache: &mut LruCache<String, i64>, + namespace: &Option<String>, + msg_queues: &mut MessageQueues, +) -> &'a [u8] { + let mut remaining_input = input; + use RedisMsg::*; + loop { + match RedisMsg::from_raw(&mut remaining_input, &mut cache, namespace) { + Ok((EventMsg(timeline, event), rest)) => { + for msg_queue in msg_queues.values_mut() { + if msg_queue.timeline == timeline { + msg_queue.messages.push_back(event.clone()); + } + } + remaining_input = rest; + } + Ok((SubscriptionMsg, rest)) | Ok((MsgForDifferentNamespace, rest)) => { + remaining_input = rest; + } + Err(RedisParseErr::Incomplete) => break, + Err(RedisParseErr::Unrecoverable) => { + panic!("Failed parsing Redis msg: {}", &remaining_input) + } + }; + } + remaining_input.as_bytes() +} diff --git a/src/redis_to_client_stream/redis/mod.rs b/src/redis_to_client_stream/redis/mod.rs index c63e243..70ad337 100644 --- a/src/redis_to_client_stream/redis/mod.rs +++ b/src/redis_to_client_stream/redis/mod.rs @@ -1,9 +1,5 @@ pub mod redis_cmd; pub mod redis_connection; pub mod redis_msg; -pub mod redis_stream; pub use redis_connection::RedisConn; -pub use redis_stream::RedisStream; - - diff --git a/src/redis_to_client_stream/redis/redis_cmd.rs b/src/redis_to_client_stream/redis/redis_cmd.rs index b8d8b32..f353ffc 100644 --- a/src/redis_to_client_stream/redis/redis_cmd.rs +++ b/src/redis_to_client_stream/redis/redis_cmd.rs @@ -7,7 +7,7 @@ macro_rules!
pubsub_cmd { ($cmd:expr, $self:expr, $tl:expr) => {{ use std::io::Write; log::info!("Sending {} command to {}", $cmd, $tl); - let namespace = $self.pubsub_connection.namespace.clone(); + let namespace = $self.redis_namespace.clone(); $self .pubsub_connection diff --git a/src/redis_to_client_stream/redis/redis_connection.rs b/src/redis_to_client_stream/redis/redis_connection.rs index 26c06d3..8261695 100644 --- a/src/redis_to_client_stream/redis/redis_connection.rs +++ b/src/redis_to_client_stream/redis/redis_connection.rs @@ -1,13 +1,12 @@ use super::redis_cmd; -use crate::config::{RedisConfig, RedisInterval, RedisNamespace}; +use crate::config::RedisConfig; use crate::err; -use std::{io::Read, io::Write, net, time}; +use std::{io::Read, io::Write, net, time::Duration}; pub struct RedisConn { pub primary: net::TcpStream, pub secondary: net::TcpStream, - pub namespace: RedisNamespace, - pub polling_interval: RedisInterval, + pub polling_interval: Duration, } fn send_password(mut conn: net::TcpStream, password: &str) -> net::TcpStream { @@ -68,7 +67,7 @@ impl RedisConn { conn = send_password(conn, &password); } conn = send_test_ping(conn); - conn.set_read_timeout(Some(time::Duration::from_millis(10))) + conn.set_read_timeout(Some(Duration::from_millis(10))) .expect("Can set read timeout for Redis connection"); if let Some(db) = &*redis_cfg.db { conn = set_db(conn, db); @@ -86,8 +85,7 @@ impl RedisConn { Self { primary: primary_conn, secondary: secondary_conn, - namespace: redis_cfg.namespace, - polling_interval: redis_cfg.polling_interval, + polling_interval: *redis_cfg.polling_interval, } } } diff --git a/src/redis_to_client_stream/redis/redis_msg.rs b/src/redis_to_client_stream/redis/redis_msg.rs index f03cd09..d0ef023 100644 --- a/src/redis_to_client_stream/redis/redis_msg.rs +++ b/src/redis_to_client_stream/redis/redis_msg.rs @@ -18,26 +18,31 @@ //! three characters, the second is a bulk string with ten characters, and the third is a //! 
bulk string with 1,386 characters. -use crate::{log_fatal, messages::Event, parse_client_request::subscription::Timeline}; +use crate::{ + err::{RedisParseErr, TimelineErr}, + messages::Event, + parse_client_request::Timeline, +}; use lru::LruCache; -type Parser<'a, Item> = Result<(Item, &'a str), ParseErr>; -#[derive(Debug)] -pub enum ParseErr { - Incomplete, - Unrecoverable, -} -use ParseErr::*; + +type Parser<'a, Item> = Result<(Item, &'a str), RedisParseErr>; /// A message that has been parsed from an incoming raw message from Redis. #[derive(Debug, Clone)] pub enum RedisMsg { EventMsg(Timeline, Event), SubscriptionMsg, + MsgForDifferentNamespace, } +use RedisParseErr::*; type Hashtags = LruCache<String, i64>; impl RedisMsg { - pub fn from_raw<'a>(input: &'a str, cache: &mut Hashtags, prefix: usize) -> Parser<'a, Self> { + pub fn from_raw<'a>( + input: &'a str, + cache: &mut Hashtags, + namespace: &Option<String>, + ) -> Parser<'a, Self> { // No need to parse the Redis Array header, just skip it let input = input.get("*3\r\n".len()..).ok_or(Incomplete)?; let (command, rest) = parse_redis_bulk_string(&input)?; @@ -46,14 +51,16 @@ impl RedisMsg { // Messages look like; // $10\r\ntimeline:4\r\n // $1386\r\n{\"event\":\"update\",\"payload\"...\"queued_at\":1569623342825}\r\n - let (raw_timeline, rest) = parse_redis_bulk_string(&rest)?; + let (timeline, rest) = parse_redis_bulk_string(&rest)?; let (msg_txt, rest) = parse_redis_bulk_string(&rest)?; + let event: Event = serde_json::from_str(&msg_txt).map_err(|_| Unrecoverable)?; - let raw_timeline = &raw_timeline.get(prefix..).ok_or(Unrecoverable)?; - let event: Event = serde_json::from_str(&msg_txt).unwrap(); - let hashtag = hashtag_from_timeline(&raw_timeline, cache); - let timeline = Timeline::from_redis_raw_timeline(&raw_timeline, hashtag); - Ok((Self::EventMsg(timeline, event), rest)) + use TimelineErr::*; + match Timeline::from_redis_raw_timeline(timeline, cache, namespace) { + Ok(timeline) => Ok((Self::EventMsg(timeline, event),
rest)), + Err(RedisNamespaceMismatch) => Ok((Self::MsgForDifferentNamespace, rest)), + Err(InvalidInput) => Err(RedisParseErr::Unrecoverable), + } } "subscribe" | "unsubscribe" => { // subscription statuses look like: @@ -101,18 +108,3 @@ fn parse_number_at(input: &str) -> Parser<usize> { let rest = &input.get(number_len..).ok_or(Incomplete)?; Ok((number, rest)) } -fn hashtag_from_timeline(raw_timeline: &str, hashtag_id_cache: &mut Hashtags) -> Option<i64> { - if raw_timeline.starts_with("hashtag") { - let tag_name = raw_timeline - .split(':') - .nth(1) - .unwrap_or_else(|| log_fatal!("No hashtag found in `{}`", raw_timeline)) - .to_string(); - let tag_id = *hashtag_id_cache - .get(&tag_name) - .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name)); - Some(tag_id) - } else { - None - } -} diff --git a/src/redis_to_client_stream/redis/redis_stream.rs b/src/redis_to_client_stream/redis/redis_stream.rs deleted file mode 100644 index b50ca0c..0000000 --- a/src/redis_to_client_stream/redis/redis_stream.rs +++ /dev/null @@ -1,127 +0,0 @@ -use super::redis_msg::{ParseErr, RedisMsg}; -use crate::config::RedisNamespace; -use crate::log_fatal; -use crate::redis_to_client_stream::receiver::MessageQueues; -use futures::{Async, Poll}; -use lru::LruCache; -use std::{error::Error, io::Read, net}; -use tokio::io::AsyncRead; - -#[derive(Debug)] -pub struct RedisStream { - pub inner: net::TcpStream, - incoming_raw_msg: String, - pub namespace: RedisNamespace, -} - -impl RedisStream { - pub fn from_stream(inner: net::TcpStream) -> Self { - RedisStream { - inner, - incoming_raw_msg: String::new(), - namespace: RedisNamespace(None), - } - } - pub fn with_namespace(self, namespace: RedisNamespace) -> Self { - RedisStream { namespace, ..self } - } - // Text comes in from redis as a raw stream, which could be more than one message and - // is not guaranteed to end on a message boundary. We need to break it down into - // messages.
Incoming messages *are* guaranteed to be RESP arrays (though still not - guaranteed to end at an array boundary). See https://redis.io/topics/protocol - /// Adds any new Redis messages to the `MsgQueue` for the appropriate `ClientAgent`. - pub fn poll_redis( - &mut self, - hashtag_to_id_cache: &mut LruCache<String, i64>, - queues: &mut MessageQueues, - ) { - let mut buffer = vec![0u8; 6000]; - if let Ok(Async::Ready(num_bytes_read)) = self.poll_read(&mut buffer) { - let raw_utf = self.as_utf8(buffer, num_bytes_read); - self.incoming_raw_msg.push_str(&raw_utf); - match process_messages( - self.incoming_raw_msg.clone(), - &mut self.namespace.0, - hashtag_to_id_cache, - queues, - ) { - Ok(None) => self.incoming_raw_msg.clear(), - Ok(Some(msg_fragment)) => self.incoming_raw_msg = msg_fragment, - Err(e) => { - log::error!("{}", e); - log_fatal!("Could not process RedisStream: {:?}", &self); - } - } - } - } - - fn as_utf8(&mut self, cur_buffer: Vec<u8>, size: usize) -> String { - String::from_utf8(cur_buffer[..size].to_vec()).unwrap_or_else(|_| { - let mut new_buffer = vec![0u8; 1]; - self.poll_read(&mut new_buffer).unwrap(); - let buffer = ([cur_buffer, new_buffer]).concat(); - self.as_utf8(buffer, size + 1) - }) - } -} - -type HashtagCache = LruCache<String, i64>; -pub fn process_messages( - raw_msg: String, - namespace: &mut Option<String>, - cache: &mut HashtagCache, - queues: &mut MessageQueues, -) -> Result<Option<String>, Box<dyn Error>> { - let prefix_len = match namespace { - Some(namespace) => format!("{}:timeline:", namespace).len(), - None => "timeline:".len(), - }; - - let mut input = raw_msg.as_str(); - loop { - let rest = match RedisMsg::from_raw(&input, cache, prefix_len) { - Ok((RedisMsg::EventMsg(timeline, event), rest)) => { - for msg_queue in queues.values_mut() { - if msg_queue.timeline == timeline { - msg_queue.messages.push_back(event.clone()); - } - } - rest - } - Ok((RedisMsg::SubscriptionMsg, rest)) => rest, - Err(ParseErr::Incomplete) => break, - Err(ParseErr::Unrecoverable) => log_fatal!("Failed parsing
Redis msg: {}", &input), - }; - input = rest - } - - Ok(Some(input.to_string())) -} - -impl std::ops::Deref for RedisStream { - type Target = net::TcpStream; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl std::ops::DerefMut for RedisStream { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -impl Read for RedisStream { - fn read(&mut self, buffer: &mut [u8]) -> Result<usize, std::io::Error> { - self.inner.read(buffer) - } -} - -impl AsyncRead for RedisStream { - fn poll_read(&mut self, buf: &mut [u8]) -> Poll<usize, std::io::Error> { - match self.read(buf) { - Ok(t) => Ok(Async::Ready(t)), - Err(_) => Ok(Async::NotReady), - } - } -}