diff --git a/Cargo.lock b/Cargo.lock index e33dbe8..7cb4f27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -440,7 +440,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.6.4" +version = "0.6.5" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/benches/parse_redis.rs b/benches/parse_redis.rs index a683583..c1fd987 100644 --- a/benches/parse_redis.rs +++ b/benches/parse_redis.rs @@ -3,95 +3,27 @@ use criterion::criterion_group; use criterion::criterion_main; use criterion::Criterion; -const ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS: &str = "*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:1\r\n$3790\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102775370117886890\",\"created_at\":\"2019-09-11T18:42:19.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"unlisted\",\"language\":\"en\",\"uri\":\"https://mastodon.host/users/federationbot/statuses/102775346916917099\",\"url\":\"https://mastodon.host/@federationbot/102775346916917099\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"

Trending tags:
#neverforget
#4styles
#newpipe
#uber
#mercredifiction

\",\"reblog\":null,\"account\":{\"id\":\"78\",\"username\":\"federationbot\",\"acct\":\"federationbot@mastodon.host\",\"display_name\":\"Federation Bot\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-09-10T15:04:25.559Z\",\"note\":\"

Hello, I am mastodon.host official semi bot.

Follow me if you want to have some updates on the view of the fediverse from here ( I only post unlisted ).

I also randomly boost one of my followers toot every hour !

If you don\'t feel confortable with me following you, tell me: unfollow and I\'ll do it :)

If you want me to follow you, just tell me follow !

If you want automatic follow for new users on your instance and you are an instance admin, contact me !

Other commands are private :)

\",\"url\":\"https://mastodon.host/@federationbot\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":16636,\"following_count\":179532,\"statuses_count\":50554,\"emojis\":[],\"fields\":[{\"name\":\"More stats\",\"value\":\"https://mastodon.host/stats.html\",\"verified_at\":null},{\"name\":\"More infos\",\"value\":\"https://mastodon.host/about/more\",\"verified_at\":null},{\"name\":\"Owner/Friend\",\"value\":\"@gled\",\"verified_at\":null}]},\"media_attachments\":[],\"mentions\":[],\"tags\":[{\"name\":\"4styles\",\"url\":\"https://instance.codesections.com/tags/4styles\"},{\"name\":\"neverforget\",\"url\":\"https://instance.codesections.com/tags/neverforget\"},{\"name\":\"mercredifiction\",\"url\":\"https://instance.codesections.com/tags/mercredifiction\"},{\"name\":\"uber\",\"url\":\"https://instance.codesections.com/tags/uber\"},{\"name\":\"newpipe\",\"url\":\"https://instance.codesections.com/tags/newpipe\"}],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1568227693541}\r\n"; - -/// Parses the Redis message using a Regex. -/// -/// The naive approach from Flodgatt's proof-of-concept stage. -mod regex_parse { - use regex::Regex; - use serde_json::Value; - - pub fn to_json_value(input: String) -> Value { - if input.ends_with("}\r\n") { - let messages = input.as_str().split("message").skip(1); - let regex = Regex::new(r"timeline:(?P.*?)\r\n\$\d+\r\n(?P.*?)\r\n") - .expect("Hard-codded"); - for message in messages { - let _timeline = regex.captures(message).expect("Hard-coded timeline regex") - ["timeline"] - .to_string(); - - let redis_msg: Value = serde_json::from_str( - ®ex.captures(message).expect("Hard-coded value regex")["value"], - ) - .expect("Valid json"); - - return redis_msg; - } - unreachable!() - } else { - unreachable!() - } - } -} - -/// Parse with a simplified inline iterator. -/// -/// Essentially shows best-case performance for producing a serde_json::Value. -mod parse_inline { - use serde_json::Value; - pub fn to_json_value(input: String) -> Value { - fn print_next_str(mut end: usize, input: &str) -> (usize, String) { - let mut start = end + 3; - end = start + 1; - - let mut iter = input.chars(); - iter.nth(start); - - while iter.next().unwrap().is_digit(10) { - end += 1; - } - let length = &input[start..end].parse::().unwrap(); - start = end + 2; - end = start + length; - - let string = &input[start..end]; - (end, string.to_string()) - } - - if input.ends_with("}\r\n") { - let end = 2; - let (end, _) = print_next_str(end, &input); - let (end, _timeline) = print_next_str(end, &input); - let (_, msg) = print_next_str(end, &input); - let redis_msg: Value = serde_json::from_str(&msg).unwrap(); - redis_msg - } else { - unreachable!() - } - } -} - /// Parse using Flodgatt's current functions mod flodgatt_parse_event { - use flodgatt::{messages::Event, redis_to_client_stream::receiver::MessageQueues}; use flodgatt::{ - parse_client_request::subscription::Timeline, - redis_to_client_stream::{receiver::MsgQueue, redis::redis_stream}, + messages::Event, + parse_client_request::Timeline, + redis_to_client_stream::{process_messages, MessageQueues, MsgQueue}, }; use lru::LruCache; use std::collections::HashMap; use uuid::Uuid; /// One-time setup, not included in testing time. - pub fn setup() -> MessageQueues { + pub fn setup() -> (LruCache, MessageQueues, Uuid, Timeline) { + let mut cache: LruCache = LruCache::new(1000); let mut queues_map = HashMap::new(); let id = Uuid::default(); - let timeline = Timeline::from_redis_raw_timeline("1", None); + let timeline = + Timeline::from_redis_raw_timeline("timeline:1", &mut cache, &None).expect("In test"); queues_map.insert(id, MsgQueue::new(timeline)); let queues = MessageQueues(queues_map); - queues + (cache, queues, id, timeline) } pub fn to_event_struct( @@ -101,201 +33,7 @@ mod flodgatt_parse_event { id: Uuid, timeline: Timeline, ) -> Event { - redis_stream::process_messages(input, &mut None, &mut cache, &mut queues).unwrap(); - queues - .oldest_msg_in_target_queue(id, timeline) - .expect("In test") - } -} - -/// Parse using modified a modified version of Flodgatt's current function. -/// -/// This version is modified to return a serde_json::Value instead of an Event to shows -/// the performance we would see if we used serde's built-in method for handling weakly -/// typed JSON instead of our own strongly typed struct. -mod flodgatt_parse_value { - use flodgatt::{log_fatal, parse_client_request::subscription::Timeline}; - use lru::LruCache; - use serde_json::Value; - use std::{ - collections::{HashMap, VecDeque}, - time::Instant, - }; - use uuid::Uuid; - #[derive(Debug)] - pub struct RedisMsg<'a> { - pub raw: &'a str, - pub cursor: usize, - pub prefix_len: usize, - } - - impl<'a> RedisMsg<'a> { - pub fn from_raw(raw: &'a str, prefix_len: usize) -> Self { - Self { - raw, - cursor: "*3\r\n".len(), //length of intro header - prefix_len, - } - } - - /// Move the cursor from the beginning of a number through its end and return the number - pub fn process_number(&mut self) -> usize { - let (mut selected_number, selection_start) = (0, self.cursor); - while let Ok(number) = self.raw[selection_start..=self.cursor].parse::() { - self.cursor += 1; - selected_number = number; - } - selected_number - } - - /// In a pubsub reply from Redis, an item can be either the name of the subscribed channel - /// or the msg payload. Either way, it follows the same format: - /// `$[LENGTH_OF_ITEM_BODY]\r\n[ITEM_BODY]\r\n` - pub fn next_field(&mut self) -> String { - self.cursor += "$".len(); - - let item_len = self.process_number(); - self.cursor += "\r\n".len(); - let item_start_position = self.cursor; - self.cursor += item_len; - let item = self.raw[item_start_position..self.cursor].to_string(); - self.cursor += "\r\n".len(); - item - } - - pub fn extract_raw_timeline_and_message(&mut self) -> (String, Value) { - let timeline = &self.next_field()[self.prefix_len..]; - let msg_txt = self.next_field(); - let msg_value: Value = serde_json::from_str(&msg_txt) - .unwrap_or_else(|_| log_fatal!("Invalid JSON from Redis: {:?}", &msg_txt)); - (timeline.to_string(), msg_value) - } - } - - pub struct MsgQueue { - pub timeline: Timeline, - pub messages: VecDeque, - _last_polled_at: Instant, - } - - pub struct MessageQueues(HashMap); - impl std::ops::Deref for MessageQueues { - type Target = HashMap; - fn deref(&self) -> &Self::Target { - &self.0 - } - } - - impl std::ops::DerefMut for MessageQueues { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } - } - - impl MessageQueues { - pub fn oldest_msg_in_target_queue( - &mut self, - id: Uuid, - timeline: Timeline, - ) -> Option { - self.entry(id) - .or_insert_with(|| MsgQueue::new(timeline)) - .messages - .pop_front() - } - } - - impl MsgQueue { - pub fn new(timeline: Timeline) -> Self { - MsgQueue { - messages: VecDeque::new(), - _last_polled_at: Instant::now(), - timeline, - } - } - } - - pub fn process_msg( - raw_utf: String, - namespace: &Option, - hashtag_id_cache: &mut LruCache, - queues: &mut MessageQueues, - ) { - // Only act if we have a full message (end on a msg boundary) - if !raw_utf.ends_with("}\r\n") { - return; - }; - let prefix_to_skip = match namespace { - Some(namespace) => format!("{}:timeline:", namespace), - None => "timeline:".to_string(), - }; - - let mut msg = RedisMsg::from_raw(&raw_utf, prefix_to_skip.len()); - - while !msg.raw.is_empty() { - let command = msg.next_field(); - match command.as_str() { - "message" => { - let (raw_timeline, msg_value) = msg.extract_raw_timeline_and_message(); - let hashtag = hashtag_from_timeline(&raw_timeline, hashtag_id_cache); - let timeline = Timeline::from_redis_raw_timeline(&raw_timeline, hashtag); - for msg_queue in queues.values_mut() { - if msg_queue.timeline == timeline { - msg_queue.messages.push_back(msg_value.clone()); - } - } - } - - "subscribe" | "unsubscribe" => { - // No msg, so ignore & advance cursor to end - let _channel = msg.next_field(); - msg.cursor += ":".len(); - let _active_subscriptions = msg.process_number(); - msg.cursor += "\r\n".len(); - } - cmd => panic!("Invariant violation: {} is unexpected Redis output", cmd), - }; - msg = RedisMsg::from_raw(&msg.raw[msg.cursor..], msg.prefix_len); - } - } - - fn hashtag_from_timeline( - raw_timeline: &str, - hashtag_id_cache: &mut LruCache, - ) -> Option { - if raw_timeline.starts_with("hashtag") { - let tag_name = raw_timeline - .split(':') - .nth(1) - .unwrap_or_else(|| log_fatal!("No hashtag found in `{}`", raw_timeline)) - .to_string(); - - let tag_id = *hashtag_id_cache - .get(&tag_name) - .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name)); - Some(tag_id) - } else { - None - } - } - pub fn setup() -> (LruCache, MessageQueues, Uuid, Timeline) { - let cache: LruCache = LruCache::new(1000); - let mut queues_map = HashMap::new(); - let id = Uuid::default(); - let timeline = Timeline::from_redis_raw_timeline("1", None); - queues_map.insert(id, MsgQueue::new(timeline)); - let queues = MessageQueues(queues_map); - (cache, queues, id, timeline) - } - - pub fn to_json_value( - input: String, - mut cache: &mut LruCache, - mut queues: &mut MessageQueues, - id: Uuid, - timeline: Timeline, - ) -> Value { - process_msg(input, &None, &mut cache, &mut queues); + process_messages(&input, &mut cache, &mut None, &mut queues); queues .oldest_msg_in_target_queue(id, timeline) .expect("In test") @@ -303,28 +41,10 @@ mod flodgatt_parse_value { } fn criterion_benchmark(c: &mut Criterion) { - let input = ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS.to_string(); //INPUT.to_string(); + let input = ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS.to_string(); let mut group = c.benchmark_group("Parse redis RESP array"); - // group.bench_function("parse to Value with a regex", |b| { - // b.iter(|| regex_parse::to_json_value(black_box(input.clone()))) - // }); - group.bench_function("parse to Value inline", |b| { - b.iter(|| parse_inline::to_json_value(black_box(input.clone()))) - }); - let (mut cache, mut queues, id, timeline) = flodgatt_parse_value::setup(); - group.bench_function("parse to Value using Flodgatt functions", |b| { - b.iter(|| { - black_box(flodgatt_parse_value::to_json_value( - black_box(input.clone()), - black_box(&mut cache), - black_box(&mut queues), - black_box(id), - black_box(timeline), - )) - }) - }); - let mut queues = flodgatt_parse_event::setup(); + let (mut cache, mut queues, id, timeline) = flodgatt_parse_event::setup(); group.bench_function("parse to Event using Flodgatt functions", |b| { b.iter(|| { black_box(flodgatt_parse_event::to_event_struct( @@ -340,3 +60,5 @@ fn criterion_benchmark(c: &mut Criterion) { criterion_group!(benches, criterion_benchmark); criterion_main!(benches); + +const ONE_MESSAGE_FOR_THE_USER_TIMLINE_FROM_REDIS: &str = "*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:1\r\n$3790\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102775370117886890\",\"created_at\":\"2019-09-11T18:42:19.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"unlisted\",\"language\":\"en\",\"uri\":\"https://mastodon.host/users/federationbot/statuses/102775346916917099\",\"url\":\"https://mastodon.host/@federationbot/102775346916917099\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"

Trending tags:
#neverforget
#4styles
#newpipe
#uber
#mercredifiction

\",\"reblog\":null,\"account\":{\"id\":\"78\",\"username\":\"federationbot\",\"acct\":\"federationbot@mastodon.host\",\"display_name\":\"Federation Bot\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-09-10T15:04:25.559Z\",\"note\":\"

Hello, I am mastodon.host official semi bot.

Follow me if you want to have some updates on the view of the fediverse from here ( I only post unlisted ).

I also randomly boost one of my followers toot every hour !

If you don\'t feel confortable with me following you, tell me: unfollow and I\'ll do it :)

If you want me to follow you, just tell me follow !

If you want automatic follow for new users on your instance and you are an instance admin, contact me !

Other commands are private :)

\",\"url\":\"https://mastodon.host/@federationbot\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/078/original/d9e2be5398629cf8.jpeg?1568127863\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":16636,\"following_count\":179532,\"statuses_count\":50554,\"emojis\":[],\"fields\":[{\"name\":\"More stats\",\"value\":\"https://mastodon.host/stats.html\",\"verified_at\":null},{\"name\":\"More infos\",\"value\":\"https://mastodon.host/about/more\",\"verified_at\":null},{\"name\":\"Owner/Friend\",\"value\":\"@gled\",\"verified_at\":null}]},\"media_attachments\":[],\"mentions\":[],\"tags\":[{\"name\":\"4styles\",\"url\":\"https://instance.codesections.com/tags/4styles\"},{\"name\":\"neverforget\",\"url\":\"https://instance.codesections.com/tags/neverforget\"},{\"name\":\"mercredifiction\",\"url\":\"https://instance.codesections.com/tags/mercredifiction\"},{\"name\":\"uber\",\"url\":\"https://instance.codesections.com/tags/uber\"},{\"name\":\"newpipe\",\"url\":\"https://instance.codesections.com/tags/newpipe\"}],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1568227693541}\r\n"; diff --git a/src/config/deployment_cfg_types.rs b/src/config/deployment_cfg_types.rs index b09de80..97d018b 100644 --- a/src/config/deployment_cfg_types.rs +++ b/src/config/deployment_cfg_types.rs @@ -11,14 +11,14 @@ from_env_var!( /// The current environment, which controls what file to read other ENV vars from let name = Env; let default: EnvInner = EnvInner::Development; - let (env_var, allowed_values) = ("RUST_ENV", format!("one of: {:?}", EnvInner::variants())); + let (env_var, allowed_values) = ("RUST_ENV", &format!("one of: {:?}", EnvInner::variants())); let from_str = |s| EnvInner::from_str(s).ok(); ); from_env_var!( /// The address to run Flodgatt on let name = FlodgattAddr; let default: IpAddr = IpAddr::V4("127.0.0.1".parse().expect("hardcoded")); - let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)".to_string()); + let (env_var, allowed_values) = ("BIND", "a valid address (e.g., 127.0.0.1)"); let from_str = |s| match s { "localhost" => Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), _ => s.parse().ok(), @@ -28,35 +28,35 @@ from_env_var!( /// How verbosely Flodgatt should log messages let name = LogLevel; let default: LogLevelInner = LogLevelInner::Warn; - let (env_var, allowed_values) = ("RUST_LOG", format!("one of: {:?}", LogLevelInner::variants())); + let (env_var, allowed_values) = ("RUST_LOG", &format!("one of: {:?}", LogLevelInner::variants())); let from_str = |s| LogLevelInner::from_str(s).ok(); ); from_env_var!( /// A Unix Socket to use in place of a local address let name = Socket; let default: Option = None; - let (env_var, allowed_values) = ("SOCKET", "any string".to_string()); + let (env_var, allowed_values) = ("SOCKET", "any string"); let from_str = |s| Some(Some(s.to_string())); ); from_env_var!( /// The time between replies sent via WebSocket let name = WsInterval; let default: Duration = Duration::from_millis(100); - let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string()); + let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds"); let from_str = |s| s.parse().map(Duration::from_millis).ok(); ); from_env_var!( /// The time between replies sent via Server Sent Events let name = SseInterval; let default: Duration = Duration::from_millis(100); - let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds".to_string()); + let (env_var, allowed_values) = ("WS_FREQ", "a number of milliseconds"); let from_str = |s| s.parse().map(Duration::from_millis).ok(); ); from_env_var!( /// The port to run Flodgatt on let name = Port; let default: u16 = 4000; - let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535".to_string()); + let (env_var, allowed_values) = ("PORT", "a number between 0 and 65535"); let from_str = |s| s.parse().ok(); ); from_env_var!( @@ -66,7 +66,7 @@ from_env_var!( /// (including otherwise public timelines). let name = WhitelistMode; let default: bool = false; - let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false".to_string()); + let (env_var, allowed_values) = ("WHITELIST_MODE", "true or false"); let from_str = |s| s.parse().ok(); ); /// Permissions for Cross Origin Resource Sharing (CORS) diff --git a/src/config/environmental_variables.rs b/src/config/environmental_variables.rs new file mode 100644 index 0000000..a790c0b --- /dev/null +++ b/src/config/environmental_variables.rs @@ -0,0 +1,137 @@ +use std::{collections::HashMap, fmt}; + +pub struct EnvVar(pub HashMap); +impl std::ops::Deref for EnvVar { + type Target = HashMap; + fn deref(&self) -> &HashMap { + &self.0 + } +} + +impl Clone for EnvVar { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} +impl EnvVar { + pub fn new(vars: HashMap) -> Self { + Self(vars) + } + + pub fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option) { + if let Some(value) = maybe_value { + self.0.insert(key.to_string(), value.to_string()); + } + } + + pub fn err(env_var: &str, supplied_value: &str, allowed_values: &str) -> ! { + log::error!( + r"{var} is set to `{value}`, which is invalid. + {var} must be {allowed_vals}.", + var = env_var, + value = supplied_value, + allowed_vals = allowed_values + ); + std::process::exit(1); + } +} +impl fmt::Display for EnvVar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut result = String::new(); + for env_var in [ + "NODE_ENV", + "RUST_LOG", + "BIND", + "PORT", + "SOCKET", + "SSE_FREQ", + "WS_FREQ", + "DATABASE_URL", + "DB_USER", + "USER", + "DB_PORT", + "DB_HOST", + "DB_PASS", + "DB_NAME", + "DB_SSLMODE", + "REDIS_HOST", + "REDIS_USER", + "REDIS_PORT", + "REDIS_PASSWORD", + "REDIS_USER", + "REDIS_DB", + ] + .iter() + { + if let Some(value) = self.get(&env_var.to_string()) { + result = format!("{}\n {}: {}", result, env_var, value) + } + } + write!(f, "{}", result) + } +} +#[macro_export] +macro_rules! maybe_update { + ($name:ident; $item: tt:$type:ty) => ( + pub fn $name(self, item: Option<$type>) -> Self { + match item { + Some($item) => Self{ $item, ..self }, + None => Self { ..self } + } + }); + ($name:ident; Some($item: tt: $type:ty)) => ( + fn $name(self, item: Option<$type>) -> Self{ + match item { + Some($item) => Self{ $item: Some($item), ..self }, + None => Self { ..self } + } + })} +#[macro_export] +macro_rules! from_env_var { + ($(#[$outer:meta])* + let name = $name:ident; + let default: $type:ty = $inner:expr; + let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr); + let from_str = |$arg:ident| $body:expr; + ) => { + pub struct $name(pub $type); + impl std::fmt::Debug for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + impl std::ops::Deref for $name { + type Target = $type; + fn deref(&self) -> &$type { + &self.0 + } + } + impl std::default::Default for $name { + fn default() -> Self { + $name($inner) + } + } + impl $name { + fn inner_from_str($arg: &str) -> Option<$type> { + $body + } + pub fn maybe_update(self, var: Option<&String>) -> Self { + match var { + Some(empty_string) if empty_string.is_empty() => Self::default(), + Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| { + crate::config::EnvVar::err($env_var, value, $allowed_values) + })), + None => self, + } + + // if let Some(value) = var { + // Self(Self::inner_from_str(value).unwrap_or_else(|| { + // crate::err::env_var_fatal($env_var, value, $allowed_values) + // })) + // } else { + // self + // } + } + } + }; +} diff --git a/src/config/mod.rs b/src/config/mod.rs index b2b314d..6d7d8bc 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -4,135 +4,7 @@ mod postgres_cfg; mod postgres_cfg_types; mod redis_cfg; mod redis_cfg_types; -pub use self::{ - deployment_cfg::DeploymentConfig, - postgres_cfg::PostgresConfig, - redis_cfg::RedisConfig, - redis_cfg_types::{RedisInterval, RedisNamespace}, -}; -use std::{collections::HashMap, fmt}; +mod environmental_variables; -pub struct EnvVar(pub HashMap); -impl std::ops::Deref for EnvVar { - type Target = HashMap; - fn deref(&self) -> &HashMap { - &self.0 - } -} +pub use {deployment_cfg::DeploymentConfig, postgres_cfg::PostgresConfig, redis_cfg::RedisConfig, environmental_variables::EnvVar}; -impl Clone for EnvVar { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} -impl EnvVar { - pub fn new(vars: HashMap) -> Self { - Self(vars) - } - - fn maybe_add_env_var(&mut self, key: &str, maybe_value: Option) { - if let Some(value) = maybe_value { - self.0.insert(key.to_string(), value.to_string()); - } - } -} -impl fmt::Display for EnvVar { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut result = String::new(); - for env_var in [ - "NODE_ENV", - "RUST_LOG", - "BIND", - "PORT", - "SOCKET", - "SSE_FREQ", - "WS_FREQ", - "DATABASE_URL", - "DB_USER", - "USER", - "DB_PORT", - "DB_HOST", - "DB_PASS", - "DB_NAME", - "DB_SSLMODE", - "REDIS_HOST", - "REDIS_USER", - "REDIS_PORT", - "REDIS_PASSWORD", - "REDIS_USER", - "REDIS_DB", - ] - .iter() - { - if let Some(value) = self.get(&env_var.to_string()) { - result = format!("{}\n {}: {}", result, env_var, value) - } - } - write!(f, "{}", result) - } -} -#[macro_export] -macro_rules! maybe_update { - ($name:ident; $item: tt:$type:ty) => ( - pub fn $name(self, item: Option<$type>) -> Self { - match item { - Some($item) => Self{ $item, ..self }, - None => Self { ..self } - } - }); - ($name:ident; Some($item: tt: $type:ty)) => ( - fn $name(self, item: Option<$type>) -> Self{ - match item { - Some($item) => Self{ $item: Some($item), ..self }, - None => Self { ..self } - } - })} -#[macro_export] -macro_rules! from_env_var { - ($(#[$outer:meta])* - let name = $name:ident; - let default: $type:ty = $inner:expr; - let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr); - let from_str = |$arg:ident| $body:expr; - ) => { - pub struct $name(pub $type); - impl std::fmt::Debug for $name { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.0) - } - } - impl std::ops::Deref for $name { - type Target = $type; - fn deref(&self) -> &$type { - &self.0 - } - } - impl std::default::Default for $name { - fn default() -> Self { - $name($inner) - } - } - impl $name { - fn inner_from_str($arg: &str) -> Option<$type> { - $body - } - pub fn maybe_update(self, var: Option<&String>) -> Self { - match var { - Some(empty_string) if empty_string.is_empty() => Self::default(), - Some(value) => Self(Self::inner_from_str(value).unwrap_or_else(|| { - crate::err::env_var_fatal($env_var, value, $allowed_values) - })), - None => self, - } - - // if let Some(value) = var { - // Self(Self::inner_from_str(value).unwrap_or_else(|| { - // crate::err::env_var_fatal($env_var, value, $allowed_values) - // })) - // } else { - // self - // } - } - } - }; -} diff --git a/src/config/postgres_cfg_types.rs b/src/config/postgres_cfg_types.rs index 066f2fd..7551d1b 100644 --- a/src/config/postgres_cfg_types.rs +++ b/src/config/postgres_cfg_types.rs @@ -6,7 +6,7 @@ from_env_var!( /// The user to use for Postgres let name = PgUser; let default: String = "postgres".to_string(); - let (env_var, allowed_values) = ("DB_USER", "any string".to_string()); + let (env_var, allowed_values) = ("DB_USER", "any string"); let from_str = |s| Some(s.to_string()); ); @@ -14,7 +14,7 @@ from_env_var!( /// The host address where Postgres is running) let name = PgHost; let default: String = "localhost".to_string(); - let (env_var, allowed_values) = ("DB_HOST", "any string".to_string()); + let (env_var, allowed_values) = ("DB_HOST", "any string"); let from_str = |s| Some(s.to_string()); ); @@ -22,7 +22,7 @@ from_env_var!( /// The password to use with Postgress let name = PgPass; let default: Option = None; - let (env_var, allowed_values) = ("DB_PASS", "any string".to_string()); + let (env_var, allowed_values) = ("DB_PASS", "any string"); let from_str = |s| Some(Some(s.to_string())); ); @@ -30,7 +30,7 @@ from_env_var!( /// The Postgres database to use let name = PgDatabase; let default: String = "mastodon_development".to_string(); - let (env_var, allowed_values) = ("DB_NAME", "any string".to_string()); + let (env_var, allowed_values) = ("DB_NAME", "any string"); let from_str = |s| Some(s.to_string()); ); @@ -38,14 +38,14 @@ from_env_var!( /// The port Postgres is running on let name = PgPort; let default: u16 = 5432; - let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535".to_string()); + let (env_var, allowed_values) = ("DB_PORT", "a number between 0 and 65535"); let from_str = |s| s.parse().ok(); ); from_env_var!( let name = PgSslMode; let default: PgSslInner = PgSslInner::Prefer; - let (env_var, allowed_values) = ("DB_SSLMODE", format!("one of: {:?}", PgSslInner::variants())); + let (env_var, allowed_values) = ("DB_SSLMODE", &format!("one of: {:?}", PgSslInner::variants())); let from_str = |s| PgSslInner::from_str(s).ok(); ); diff --git a/src/config/redis_cfg_types.rs b/src/config/redis_cfg_types.rs index d3dbbf0..8d43c76 100644 --- a/src/config/redis_cfg_types.rs +++ b/src/config/redis_cfg_types.rs @@ -7,48 +7,48 @@ from_env_var!( /// The host address where Redis is running let name = RedisHost; let default: String = "127.0.0.1".to_string(); - let (env_var, allowed_values) = ("REDIS_HOST", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_HOST", "any string"); let from_str = |s| Some(s.to_string()); ); from_env_var!( /// The port Redis is running on let name = RedisPort; let default: u16 = 6379; - let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535".to_string()); + let (env_var, allowed_values) = ("REDIS_PORT", "a number between 0 and 65535"); let from_str = |s| s.parse().ok(); ); from_env_var!( /// How frequently to poll Redis let name = RedisInterval; let default: Duration = Duration::from_millis(100); - let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds".to_string()); + let (env_var, allowed_values) = ("REDIS_POLL_INTERVAL", "a number of milliseconds"); let from_str = |s| s.parse().map(Duration::from_millis).ok(); ); from_env_var!( /// The password to use for Redis let name = RedisPass; let default: Option = None; - let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_PASSWORD", "any string"); let from_str = |s| Some(Some(s.to_string())); ); from_env_var!( /// An optional Redis Namespace let name = RedisNamespace; let default: Option = None; - let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_NAMESPACE", "any string"); let from_str = |s| Some(Some(s.to_string())); ); from_env_var!( /// A user for Redis (not supported) let name = RedisUser; let default: Option = None; - let (env_var, allowed_values) = ("REDIS_USER", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_USER", "any string"); let from_str = |s| Some(Some(s.to_string())); ); from_env_var!( /// The database to use with Redis (no current effect for PubSub connections) let name = RedisDb; let default: Option = None; - let (env_var, allowed_values) = ("REDIS_DB", "any string".to_string()); + let (env_var, allowed_values) = ("REDIS_DB", "any string"); let from_str = |s| Some(Some(s.to_string())); ); diff --git a/src/err.rs b/src/err.rs index 6ffbaa0..2863c5e 100644 --- a/src/err.rs +++ b/src/err.rs @@ -1,4 +1,3 @@ -use serde_derive::Serialize; use std::fmt::Display; pub fn die_with_msg(msg: impl Display) -> ! { @@ -14,67 +13,20 @@ macro_rules! log_fatal { };}; } -pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! { - eprintln!( - r"FATAL ERROR: {var} is set to `{value}`, which is invalid. - {var} must be {allowed_vals}.", - var = env_var, - value = supplied_value, - allowed_vals = allowed_values - ); - std::process::exit(1); +#[derive(Debug)] +pub enum RedisParseErr { + Incomplete, + Unrecoverable, } -#[macro_export] -macro_rules! dbg_and_die { - ($msg:expr) => { - let message = format!("FATAL ERROR: {}", $msg); - dbg!(message); - std::process::exit(1); - }; -} -pub fn unwrap_or_die(s: Option, msg: &str) -> T { - s.unwrap_or_else(|| { - eprintln!("FATAL ERROR: {}", msg); - std::process::exit(1) - }) +#[derive(Debug)] +pub enum TimelineErr { + RedisNamespaceMismatch, + InvalidInput, } -#[derive(Serialize)] -pub struct ErrorMessage { - error: String, -} -impl ErrorMessage { - fn new(msg: impl std::fmt::Display) -> Self { - Self { - error: msg.to_string(), - } - } -} - -/// Recover from Errors by sending appropriate Warp::Rejections -pub fn handle_errors( - rejection: warp::reject::Rejection, -) -> Result { - let err_txt = match rejection.cause() { - Some(text) if text.to_string() == "Missing request header 'authorization'" => { - "Error: Missing access token".to_string() - } - Some(text) => text.to_string(), - None => "Error: Nonexistant endpoint".to_string(), - }; - let json = warp::reply::json(&ErrorMessage::new(err_txt)); - - Ok(warp::reply::with_status( - json, - warp::http::StatusCode::UNAUTHORIZED, - )) -} - -pub struct CustomError {} - -impl CustomError { - pub fn unauthorized_list() -> warp::reject::Rejection { - warp::reject::custom("Error: Access to list not authorized") +impl From for TimelineErr { + fn from(_error: std::num::ParseIntError) -> Self { + Self::InvalidInput } } diff --git a/src/main.rs b/src/main.rs index b964926..2d2aaa4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,87 +1,76 @@ use flodgatt::{ - config, err, - parse_client_request::{sse, subscription, ws}, - redis_to_client_stream::{self, ClientAgent}, + config::{DeploymentConfig, EnvVar, PostgresConfig, RedisConfig}, + parse_client_request::{PgPool, Subscription}, + redis_to_client_stream::{ClientAgent, EventStream}, }; use std::{collections::HashMap, env, fs, net, os::unix::fs::PermissionsExt}; use tokio::net::UnixListener; use warp::{path, ws::Ws2, Filter}; fn main() { - dotenv::from_filename( - match env::var("ENV").ok().as_ref().map(String::as_str) { + dotenv::from_filename(match env::var("ENV").ok().as_ref().map(String::as_str) { Some("production") => ".env.production", Some("development") | None => ".env", - Some(_) => err::die_with_msg("Unknown ENV variable specified.\n Valid options are: `production` or `development`."), - }).ok(); + Some(unsupported) => EnvVar::err("ENV", unsupported, "`production` or `development`"), + }) + .ok(); let env_vars_map: HashMap<_, _> = dotenv::vars().collect(); - let env_vars = config::EnvVar::new(env_vars_map); + let env_vars = EnvVar::new(env_vars_map); pretty_env_logger::init(); log::info!( "Flodgatt recognized the following environmental variables:{}", env_vars.clone() ); - let redis_cfg = config::RedisConfig::from_env(env_vars.clone()); - let cfg = config::DeploymentConfig::from_env(env_vars.clone()); + let redis_cfg = RedisConfig::from_env(env_vars.clone()); + let cfg = DeploymentConfig::from_env(env_vars.clone()); - let postgres_cfg = config::PostgresConfig::from_env(env_vars.clone()); - let pg_pool = subscription::PgPool::new(postgres_cfg); + let postgres_cfg = PostgresConfig::from_env(env_vars.clone()); + let pg_pool = PgPool::new(postgres_cfg); - let client_agent_sse = ClientAgent::blank(redis_cfg, pg_pool.clone()); + let client_agent_sse = ClientAgent::blank(redis_cfg); let client_agent_ws = client_agent_sse.clone_with_shared_receiver(); log::info!("Streaming server initialized and ready to accept connections"); // Server Sent Events - let (sse_update_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode); - let sse_routes = sse::extract_user_or_reject(pg_pool.clone(), whitelist_mode) + let (sse_interval, whitelist_mode) = (*cfg.sse_interval, *cfg.whitelist_mode); + let sse_routes = Subscription::from_sse_query(pg_pool.clone(), whitelist_mode) .and(warp::sse()) .map( - move |subscription: subscription::Subscription, - sse_connection_to_client: warp::sse::Sse| { + move |subscription: Subscription, sse_connection_to_client: warp::sse::Sse| { log::info!("Incoming SSE request for {:?}", subscription.timeline); // Create a new ClientAgent let mut client_agent = client_agent_sse.clone_with_shared_receiver(); // Assign ClientAgent to generate stream of updates for the user/timeline pair client_agent.init_for_user(subscription); // send the updates through the SSE connection - redis_to_client_stream::send_updates_to_sse( - client_agent, - sse_connection_to_client, - sse_update_interval, - ) + EventStream::to_sse(client_agent, sse_connection_to_client, sse_interval) }, ) - .with(warp::reply::with::header("Connection", "keep-alive")) - .recover(err::handle_errors); + .with(warp::reply::with::header("Connection", "keep-alive")); // WebSocket let (ws_update_interval, whitelist_mode) = (*cfg.ws_interval, *cfg.whitelist_mode); - let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone(), whitelist_mode) + let websocket_routes = Subscription::from_ws_request(pg_pool.clone(), whitelist_mode) .and(warp::ws::ws2()) - .map( - move |subscription: subscription::Subscription, token: Option, ws: Ws2| { - log::info!("Incoming websocket request for {:?}", subscription.timeline); - // Create a new ClientAgent - let mut client_agent = client_agent_ws.clone_with_shared_receiver(); - // Assign that agent to generate a stream of updates for the user/timeline pair - client_agent.init_for_user(subscription); - // send the updates through the WS connection (along with the User's access_token - // which is sent for security) + .map(move |subscription: Subscription, ws: Ws2| { + log::info!("Incoming websocket request for {:?}", subscription.timeline); - ( - ws.on_upgrade(move |socket| { - redis_to_client_stream::send_updates_to_ws( - socket, - client_agent, - ws_update_interval, - ) - }), - token.unwrap_or_else(String::new), - ) - }, - ) + let token = subscription.access_token.clone(); + // Create a new ClientAgent + let mut client_agent = client_agent_ws.clone_with_shared_receiver(); + // Assign that agent to generate a stream of updates for the user/timeline pair + client_agent.init_for_user(subscription); + // send the updates through the WS connection (along with the User's access_token + // which is sent for security) + ( + ws.on_upgrade(move |socket| { + EventStream::to_ws(socket, client_agent, ws_update_interval) + }), + token.unwrap_or_else(String::new), + ) + }) .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); let cors = warp::cors() @@ -98,7 +87,28 @@ fn main() { fs::set_permissions(socket, PermissionsExt::from_mode(0o666)).unwrap(); - warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run_incoming(incoming); + warp::serve( + health.or(websocket_routes.or(sse_routes).with(cors).recover( + |rejection: warp::reject::Rejection| { + let err_txt = match rejection.cause() { + Some(text) + if text.to_string() == "Missing request header 'authorization'" => + { + "Error: Missing access token".to_string() + } + Some(text) => text.to_string(), + None => "Error: Nonexistant endpoint".to_string(), + }; + let json = warp::reply::json(&err_txt); + + Ok(warp::reply::with_status( + json, + warp::http::StatusCode::UNAUTHORIZED, + )) + }, + )), + ) + .run_incoming(incoming); } else { let server_addr = net::SocketAddr::new(*cfg.address, cfg.port.0); warp::serve(health.or(websocket_routes.or(sse_routes).with(cors))).run(server_addr); diff --git a/src/messages.rs b/src/messages.rs index 5facc8e..a31cade 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -429,27 +429,24 @@ impl Status { mod test { use super::*; use crate::{ - parse_client_request::subscription::{Content::*, Reach::*, Stream::*, Timeline}, - redis_to_client_stream::{ - receiver::{MessageQueues, MsgQueue}, - redis::{ - redis_msg::{ParseErr, RedisMsg}, - redis_stream, - }, - }, + err::RedisParseErr, + parse_client_request::{Content::*, Reach::*, Stream::*, Timeline}, + redis_to_client_stream::{MessageQueues, MsgQueue, RedisMsg}, }; use lru::LruCache; use std::collections::HashMap; use uuid::Uuid; - type Err = ParseErr; + type Err = RedisParseErr; /// Set up state shared between multiple tests of Redis parsing pub fn shared_setup() -> (LruCache, MessageQueues, Uuid, Timeline) { - let cache: LruCache = LruCache::new(1000); + let mut cache: LruCache = LruCache::new(1000); let mut queues_map = HashMap::new(); - let id = Uuid::default(); + let id = dbg!(Uuid::default()); - let timeline = Timeline::from_redis_raw_timeline("4", None); + let timeline = dbg!( + Timeline::from_redis_raw_timeline("timeline:4", &mut cache, &None).expect("In test") + ); queues_map.insert(id, MsgQueue::new(timeline)); let queues = MessageQueues(queues_map); (cache, queues, id, timeline) @@ -460,8 +457,8 @@ mod test { let input ="*3\r\n$7\r\nmessage\r\n$10\r\ntimeline:4\r\n$1386\r\n{\"event\":\"update\",\"payload\":{\"id\":\"102866835379605039\",\"created_at\":\"2019-09-27T22:29:02.590Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"en\",\"uri\":\"http://localhost:3000/users/admin/statuses/102866835379605039\",\"url\":\"http://localhost:3000/@admin/102866835379605039\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"favourited\":false,\"reblogged\":false,\"muted\":false,\"content\":\"

@susan hi

\",\"reblog\":null,\"application\":{\"name\":\"Web\",\"website\":null},\"account\":{\"id\":\"1\",\"username\":\"admin\",\"acct\":\"admin\",\"display_name\":\"\",\"locked\":false,\"bot\":false,\"created_at\":\"2019-07-04T00:21:05.890Z\",\"note\":\"

\",\"url\":\"http://localhost:3000/@admin\",\"avatar\":\"http://localhost:3000/avatars/original/missing.png\",\"avatar_static\":\"http://localhost:3000/avatars/original/missing.png\",\"header\":\"http://localhost:3000/headers/original/missing.png\",\"header_static\":\"http://localhost:3000/headers/original/missing.png\",\"followers_count\":3,\"following_count\":3,\"statuses_count\":192,\"emojis\":[],\"fields\":[]},\"media_attachments\":[],\"mentions\":[{\"id\":\"4\",\"username\":\"susan\",\"url\":\"http://localhost:3000/@susan\",\"acct\":\"susan\"}],\"tags\":[],\"emojis\":[],\"card\":null,\"poll\":null},\"queued_at\":1569623342825}\r\n"; let (mut cache, mut queues, id, timeline) = shared_setup(); - redis_stream::process_messages(input.to_string(), &mut None, &mut cache, &mut queues) - .map_err(|_| ParseErr::Unrecoverable)?; + crate::redis_to_client_stream::process_messages(input, &mut cache, &mut None, &mut queues); + let parsed_event = queues.oldest_msg_in_target_queue(id, timeline).unwrap(); let test_event = Event::Update{ payload: Status { id: "102866835379605039".to_string(), @@ -538,40 +535,40 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!(subscription_msg1, RedisMsg::SubscriptionMsg)); - let (subscription_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (subscription_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!(subscription_msg2, RedisMsg::SubscriptionMsg)); - let (subscription_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (subscription_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!(subscription_msg3, RedisMsg::SubscriptionMsg)); - let (subscription_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (subscription_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!(subscription_msg4, RedisMsg::SubscriptionMsg)); - let (subscription_msg5, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (subscription_msg5, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!(subscription_msg5, RedisMsg::SubscriptionMsg)); - let (update_msg1, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (update_msg1, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!( update_msg1, RedisMsg::EventMsg(_, Event::Update { .. }) )); - let (update_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (update_msg2, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!( update_msg2, RedisMsg::EventMsg(_, Event::Update { .. }) )); - let (update_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (update_msg3, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!( update_msg3, RedisMsg::EventMsg(_, Event::Update { .. }) )); - let (update_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (update_msg4, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; assert!(matches!( update_msg4, RedisMsg::EventMsg(_, Event::Update { .. }) @@ -588,7 +585,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( subscription_msg1, RedisMsg::EventMsg(Timeline(User(id), Federated, All), Event::Notification { .. }) if id == 55 @@ -603,9 +600,9 @@ mod test { fn parse_redis_input_delete() -> Result<(), Err> { let input = "*3\r\n$7\r\nmessage\r\n$12\r\ntimeline:308\r\n$49\r\n{\"event\":\"delete\",\"payload\":\"103864778284581232\"}\r\n"; - let (mut cache, _, _, _) = shared_setup(); + let (mut cache, _, _, _) = dbg!(shared_setup()); - let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( subscription_msg1, RedisMsg::EventMsg( @@ -625,7 +622,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (subscription_msg1, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( subscription_msg1, RedisMsg::EventMsg(Timeline(User(id), Federated, All), Event::FiltersChanged) if id == 56 @@ -642,7 +639,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg( @@ -660,7 +657,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg( @@ -679,7 +676,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg( @@ -701,7 +698,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -721,7 +718,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -741,7 +738,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -761,7 +758,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -781,7 +778,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -797,7 +794,7 @@ mod test { }, ) )); - let (msg2, rest) = RedisMsg::from_raw(rest, &mut cache, "timeline:".len())?; + let (msg2, rest) = RedisMsg::from_raw(rest, &mut cache, &None)?; dbg!(&msg2); assert!(matches!( msg2, @@ -814,7 +811,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -840,7 +837,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -863,7 +860,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -886,7 +883,7 @@ mod test { let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; dbg!(&msg); assert!(matches!( msg, @@ -904,7 +901,7 @@ mod test { fn parse_redis_input_from_live_data_1() -> Result<(), Err> { let input = "*3\r\n$7\r\nmessage\r\n$15\r\ntimeline:public\r\n$2799\r\n{\"event\":\"update\",\"payload\":{\"id\":\"103880088450458596\",\"created_at\":\"2020-03-24T21:12:37.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"es\",\"uri\":\"https://mastodon.social/users/durru/statuses/103880088436492032\",\"url\":\"https://mastodon.social/@durru/103880088436492032\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"content\":\"

¡No puedes salir, loca!

\",\"reblog\":null,\"account\":{\"id\":\"2271\",\"username\":\"durru\",\"acct\":\"durru@mastodon.social\",\"display_name\":\"Cloaca Maxima\",\"locked\":false,\"bot\":false,\"discoverable\":true,\"group\":false,\"created_at\":\"2020-03-24T21:27:31.669Z\",\"note\":\"

Todo pasa, antes o después, por la Cloaca, diría Vitruvio.
También compongo palíndromos.

\",\"url\":\"https://mastodon.social/@durru\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/002/271/original/d7675a6ff9d9baa7.jpeg?1585085250\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/002/271/original/d7675a6ff9d9baa7.jpeg?1585085250\",\"header\":\"https://instance.codesections.com/system/accounts/headers/000/002/271/original/e3f0a1989b0d8efc.jpeg?1585085250\",\"header_static\":\"https://instance.codesections.com/system/accounts/headers/000/002/271/original/e3f0a1989b0d8efc.jpeg?1585085250\",\"followers_count\":222,\"following_count\":81,\"statuses_count\":5443,\"last_status_at\":\"2020-03-24\",\"emojis\":[],\"fields\":[{\"name\":\"Mis fotos\",\"value\":\"https://pixelfed.de/durru\",\"verified_at\":null},{\"name\":\"diaspora*\",\"value\":\"https://joindiaspora.com/people/75fec0e05114013484870242ac110007\",\"verified_at\":null}]},\"media_attachments\":[{\"id\":\"2864\",\"type\":\"image\",\"url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/864/original/3988312d30936494.jpeg?1585085251\",\"preview_url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/864/small/3988312d30936494.jpeg?1585085251\",\"remote_url\":\"https://files.mastodon.social/media_attachments/files/026/669/690/original/d8171331f956cf38.jpg\",\"text_url\":null,\"meta\":{\"original\":{\"width\":1001,\"height\":662,\"size\":\"1001x662\",\"aspect\":1.512084592145015},\"small\":{\"width\":491,\"height\":325,\"size\":\"491x325\",\"aspect\":1.5107692307692309}},\"description\":null,\"blurhash\":\"UdLqhI4n4TIUIAt7t7ay~qIojtRj?bM{M{of\"}],\"mentions\":[],\"tags\":[],\"emojis\":[],\"card\":null,\"poll\":null}}\r\n"; let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg(Timeline(Public, Federated, All), Event::Update { .. }) @@ -917,7 +914,7 @@ mod test { fn parse_redis_input_from_live_data_2() -> Result<(), Err> { let input = "*3\r\n$7\r\nmessage\r\n$15\r\ntimeline:public\r\n$3888\r\n{\"event\":\"update\",\"payload\":{\"id\":\"103880373579328660\",\"created_at\":\"2020-03-24T22:25:05.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"en\",\"uri\":\"https://newsbots.eu/users/granma/statuses/103880373417385978\",\"url\":\"https://newsbots.eu/@granma/103880373417385978\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"content\":\"

A total of 11 measures have been established for the pre-epidemic stage of the battle against #Covid-19 in #Cuba
#CubaPorLaSalud
http://en.granma.cu/cuba/2020-03-23/public-health-measures-in-covid-19-pre-epidemic-stage 

\",\"reblog\":null,\"account\":{\"id\":\"717\",\"username\":\"granma\",\"acct\":\"granma@newsbots.eu\",\"display_name\":\"Granma (Unofficial)\",\"locked\":false,\"bot\":true,\"discoverable\":false,\"group\":false,\"created_at\":\"2020-03-13T11:08:08.420Z\",\"note\":\"

\",\"url\":\"https://newsbots.eu/@granma\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/717/original/4a1f9ed090fc36e9.jpeg?1584097687\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/717/original/4a1f9ed090fc36e9.jpeg?1584097687\",\"header\":\"https://instance.codesections.com/headers/original/missing.png\",\"header_static\":\"https://instance.codesections.com/headers/original/missing.png\",\"followers_count\":57,\"following_count\":1,\"statuses_count\":742,\"last_status_at\":\"2020-03-24\",\"emojis\":[],\"fields\":[{\"name\":\"Source\",\"value\":\"https://twitter.com/Granma_English\",\"verified_at\":null},{\"name\":\"Operator\",\"value\":\"@felix\",\"verified_at\":null},{\"name\":\"Code\",\"value\":\"https://yerbamate.dev/nutomic/tootbot\",\"verified_at\":null}]},\"media_attachments\":[{\"id\":\"2881\",\"type\":\"image\",\"url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/881/original/a1e97908e84efbcd.jpeg?1585088707\",\"preview_url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/881/small/a1e97908e84efbcd.jpeg?1585088707\",\"remote_url\":\"https://newsbots.eu/system/media_attachments/files/000/176/298/original/f30a877d5035f4a6.jpeg\",\"text_url\":null,\"meta\":{\"original\":{\"width\":700,\"height\":795,\"size\":\"700x795\",\"aspect\":0.8805031446540881},\"small\":{\"width\":375,\"height\":426,\"size\":\"375x426\",\"aspect\":0.8802816901408451}},\"description\":null,\"blurhash\":\"UHCY?%sD%1t6}snOxuxu#7rrx]xu$*i_NFNF\"}],\"mentions\":[],\"tags\":[{\"name\":\"covid\",\"url\":\"https://instance.codesections.com/tags/covid\"},{\"name\":\"cuba\",\"url\":\"https://instance.codesections.com/tags/cuba\"},{\"name\":\"CubaPorLaSalud\",\"url\":\"https://instance.codesections.com/tags/CubaPorLaSalud\"}],\"emojis\":[],\"card\":null,\"poll\":null}}\r\n"; let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg(Timeline(Public, Federated, All), Event::Update { .. }) @@ -930,7 +927,7 @@ mod test { fn parse_redis_input_from_live_data_3() -> Result<(), Err> { let input = "*3\r\n$7\r\nmessage\r\n$15\r\ntimeline:public\r\n$4803\r\n{\"event\":\"update\",\"payload\":{\"id\":\"103880453908763088\",\"created_at\":\"2020-03-24T22:45:33.000Z\",\"in_reply_to_id\":null,\"in_reply_to_account_id\":null,\"sensitive\":false,\"spoiler_text\":\"\",\"visibility\":\"public\",\"language\":\"en\",\"uri\":\"https://mstdn.social/users/stux/statuses/103880453855603541\",\"url\":\"https://mstdn.social/@stux/103880453855603541\",\"replies_count\":0,\"reblogs_count\":0,\"favourites_count\":0,\"content\":\"

When they say lockdown. LOCKDOWN.

\",\"reblog\":null,\"account\":{\"id\":\"806\",\"username\":\"stux\",\"acct\":\"stux@mstdn.social\",\"display_name\":\"sтυx⚡\",\"locked\":false,\"bot\":false,\"discoverable\":true,\"group\":false,\"created_at\":\"2020-03-13T23:02:29.970Z\",\"note\":\"

Hi, Stux here! I am running the mstdn.social :mastodon: instance!

For questions and help or just for fun you can always send me a toot♥\u{fe0f}

Oh and no, I am not really a cat! Or am I?

\",\"url\":\"https://mstdn.social/@stux\",\"avatar\":\"https://instance.codesections.com/system/accounts/avatars/000/000/806/original/dae8d9d01d57d7f8.gif?1584140547\",\"avatar_static\":\"https://instance.codesections.com/system/accounts/avatars/000/000/806/static/dae8d9d01d57d7f8.png?1584140547\",\"header\":\"https://instance.codesections.com/system/accounts/headers/000/000/806/original/88c874d69f7d6989.gif?1584140548\",\"header_static\":\"https://instance.codesections.com/system/accounts/headers/000/000/806/static/88c874d69f7d6989.png?1584140548\",\"followers_count\":13954,\"following_count\":7600,\"statuses_count\":10207,\"last_status_at\":\"2020-03-24\",\"emojis\":[{\"shortcode\":\"mastodon\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/418/original/25ccc64333645735.png?1584140550\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/418/static/25ccc64333645735.png?1584140550\",\"visible_in_picker\":true},{\"shortcode\":\"patreon\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/419/original/3cc463d3dfc1e489.png?1584140550\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/419/static/3cc463d3dfc1e489.png?1584140550\",\"visible_in_picker\":true},{\"shortcode\":\"liberapay\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/420/original/893854353dfa9706.png?1584140551\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/420/static/893854353dfa9706.png?1584140551\",\"visible_in_picker\":true},{\"shortcode\":\"team_valor\",\"url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/958/original/96aae26b45292a12.png?1584910917\",\"static_url\":\"https://instance.codesections.com/system/custom_emojis/images/000/000/958/static/96aae26b45292a12.png?1584910917\",\"visible_in_picker\":true}],\"fields\":[{\"name\":\"Patreon :patreon:\",\"value\":\"https://www.patreon.com/mstdn\",\"verified_at\":null},{\"name\":\"LiberaPay :liberapay:\",\"value\":\"https://liberapay.com/mstdn\",\"verified_at\":null},{\"name\":\"Team :team_valor:\",\"value\":\"https://mstdn.social/team\",\"verified_at\":null},{\"name\":\"Support :mastodon:\",\"value\":\"https://mstdn.social/funding\",\"verified_at\":null}]},\"media_attachments\":[{\"id\":\"2886\",\"type\":\"video\",\"url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/886/original/22b3f98a5e8f86d8.mp4?1585090023\",\"preview_url\":\"https://instance.codesections.com/system/media_attachments/files/000/002/886/small/22b3f98a5e8f86d8.png?1585090023\",\"remote_url\":\"https://cdn.mstdn.social/mstdn-social/media_attachments/files/003/338/384/original/c146f62ba86fe63e.mp4\",\"text_url\":null,\"meta\":{\"length\":\"0:00:27.03\",\"duration\":27.03,\"fps\":30,\"size\":\"272x480\",\"width\":272,\"height\":480,\"aspect\":0.5666666666666667,\"audio_encode\":\"aac (LC) (mp4a / 0x6134706D)\",\"audio_bitrate\":\"44100 Hz\",\"audio_channels\":\"stereo\",\"original\":{\"width\":272,\"height\":480,\"frame_rate\":\"30/1\",\"duration\":27.029,\"bitrate\":481885},\"small\":{\"width\":227,\"height\":400,\"size\":\"227x400\",\"aspect\":0.5675}},\"description\":null,\"blurhash\":\"UBF~N@OF-:xv4mM|s+ob9FE2t6tQ9Fs:t8oN\"}],\"mentions\":[],\"tags\":[],\"emojis\":[],\"card\":null,\"poll\":null}}\r\n"; let (mut cache, _, _, _) = shared_setup(); - let (msg, rest) = RedisMsg::from_raw(input, &mut cache, "timeline:".len())?; + let (msg, rest) = RedisMsg::from_raw(input, &mut cache, &None)?; assert!(matches!( msg, RedisMsg::EventMsg(Timeline(Public, Federated, All), Event::Update { .. }) diff --git a/src/parse_client_request/mod.rs b/src/parse_client_request/mod.rs index c999c0e..c076eb2 100644 --- a/src/parse_client_request/mod.rs +++ b/src/parse_client_request/mod.rs @@ -1,5 +1,13 @@ -//! Parse the client request and return a (possibly authenticated) `User` -pub mod query; -pub mod sse; -pub mod subscription; -pub mod ws; +//! Parse the client request and return a Subscription +mod postgres; +mod query; +mod sse; +mod subscription; +mod ws; + +pub use self::postgres::PgPool; +// TODO consider whether we can remove `Stream` from public API +pub use subscription::{Stream, Subscription, Timeline}; + +#[cfg(test)] +pub use subscription::{Content, Reach}; diff --git a/src/parse_client_request/postgres.rs b/src/parse_client_request/postgres.rs new file mode 100644 index 0000000..606f583 --- /dev/null +++ b/src/parse_client_request/postgres.rs @@ -0,0 +1,204 @@ +//! Postgres queries +use crate::{ + config, + parse_client_request::subscription::{Scope, UserData}, +}; +use ::postgres; +use r2d2_postgres::PostgresConnectionManager; +use std::collections::HashSet; +use warp::reject::Rejection; + +#[derive(Clone, Debug)] +pub struct PgPool(pub r2d2::Pool>); +impl PgPool { + pub fn new(pg_cfg: config::PostgresConfig) -> Self { + let mut cfg = postgres::Config::new(); + cfg.user(&pg_cfg.user) + .host(&*pg_cfg.host.to_string()) + .port(*pg_cfg.port) + .dbname(&pg_cfg.database); + if let Some(password) = &*pg_cfg.password { + cfg.password(password); + }; + + let manager = PostgresConnectionManager::new(cfg, postgres::NoTls); + let pool = r2d2::Pool::builder() + .max_size(10) + .build(manager) + .expect("Can connect to local postgres"); + Self(pool) + } + pub fn select_user(self, token: &str) -> Result { + let mut conn = self.0.get().unwrap(); + let query_rows = conn + .query( + " +SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes +FROM +oauth_access_tokens +INNER JOIN users ON +oauth_access_tokens.resource_owner_id = users.id +WHERE oauth_access_tokens.token = $1 +AND oauth_access_tokens.revoked_at IS NULL +LIMIT 1", + &[&token.to_owned()], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + if let Some(result_columns) = query_rows.get(0) { + let id = result_columns.get(1); + let allowed_langs = result_columns + .try_get::<_, Vec<_>>(2) + .unwrap_or_else(|_| Vec::new()) + .into_iter() + .collect(); + let mut scopes: HashSet = result_columns + .get::<_, String>(3) + .split(' ') + .filter_map(|scope| match scope { + "read" => Some(Scope::Read), + "read:statuses" => Some(Scope::Statuses), + "read:notifications" => Some(Scope::Notifications), + "read:lists" => Some(Scope::Lists), + "write" | "follow" => None, // ignore write scopes + unexpected => { + log::warn!("Ignoring unknown scope `{}`", unexpected); + None + } + }) + .collect(); + // We don't need to separately track read auth - it's just all three others + if scopes.remove(&Scope::Read) { + scopes.insert(Scope::Statuses); + scopes.insert(Scope::Notifications); + scopes.insert(Scope::Lists); + } + + Ok(UserData { + id, + allowed_langs, + scopes, + }) + } else { + Err(warp::reject::custom("Error: Invalid access token")) + } + } + + pub fn select_hashtag_id(self, tag_name: &String) -> Result { + let mut conn = self.0.get().unwrap(); + let rows = &conn + .query( + " +SELECT id +FROM tags +WHERE name = $1 +LIMIT 1", + &[&tag_name], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + + match rows.get(0) { + Some(row) => Ok(row.get(0)), + None => Err(warp::reject::custom("Error: Hashtag does not exist.")), + } + } + + /// Query Postgres for everyone the user has blocked or muted + /// + /// **NOTE**: because we check this when the user connects, it will not include any blocks + /// the user adds until they refresh/reconnect. + pub fn select_blocked_users(self, user_id: i64) -> HashSet { + // " + // SELECT + // 1 + // FROM blocks + // WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})) + // OR (account_id = $2 AND target_account_id = $1) + // UNION SELECT + // 1 + // FROM mutes + // WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})` + // , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`" + self + .0 + .get() + .unwrap() + .query( + " +SELECT target_account_id + FROM blocks + WHERE account_id = $1 +UNION SELECT target_account_id + FROM mutes + WHERE account_id = $1", + &[&user_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| row.get(0)) + .collect() + } + /// Query Postgres for everyone who has blocked the user + /// + /// **NOTE**: because we check this when the user connects, it will not include any blocks + /// the user adds until they refresh/reconnect. + pub fn select_blocking_users(self, user_id: i64) -> HashSet { + self + .0 + .get() + .unwrap() + .query( + " +SELECT account_id + FROM blocks + WHERE target_account_id = $1", + &[&user_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| row.get(0)) + .collect() + } + + /// Query Postgres for all current domain blocks + /// + /// **NOTE**: because we check this when the user connects, it will not include any blocks + /// the user adds until they refresh/reconnect. + pub fn select_blocked_domains(self, user_id: i64) -> HashSet { + self + .0 + .get() + .unwrap() + .query( + "SELECT domain FROM account_domain_blocks WHERE account_id = $1", + &[&user_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])") + .iter() + .map(|row| row.get(0)) + .collect() + } + + /// Test whether a user owns a list + pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool { + let mut conn = self.0.get().unwrap(); + // For the Postgres query, `id` = list number; `account_id` = user.id + let rows = &conn + .query( + " +SELECT id, account_id +FROM lists +WHERE id = $1 +LIMIT 1", + &[&list_id], + ) + .expect("Hard-coded query will return Some([0 or more rows])"); + + match rows.get(0) { + None => false, + Some(row) => { + let list_owner_id: i64 = row.get(1); + list_owner_id == user_id + } + } + } +} diff --git a/src/parse_client_request/query.rs b/src/parse_client_request/query.rs index 3c92c40..c4445a9 100644 --- a/src/parse_client_request/query.rs +++ b/src/parse_client_request/query.rs @@ -28,6 +28,12 @@ impl Query { } macro_rules! make_query_type { + (Stream => $parameter:tt:$type:ty) => { + #[derive(Deserialize, Debug, Default)] + pub struct Stream { + pub $parameter: $type, + } + }; ($name:tt => $parameter:tt:$type:ty) => { #[derive(Deserialize, Debug, Default)] pub struct $name { @@ -59,14 +65,14 @@ impl ToString for Stream { } } -pub fn optional_media_query() -> BoxedFilter<(Media,)> { - warp::query() - .or(warp::any().map(|| Media { - only_media: "false".to_owned(), - })) - .unify() - .boxed() -} +// pub fn optional_media_query() -> BoxedFilter<(Media,)> { +// warp::query() +// .or(warp::any().map(|| Media { +// only_media: "false".to_owned(), +// })) +// .unify() +// .boxed() +// } pub struct OptionalAccessToken; diff --git a/src/parse_client_request/sse.rs b/src/parse_client_request/sse.rs index b62461c..287cd9b 100644 --- a/src/parse_client_request/sse.rs +++ b/src/parse_client_request/sse.rs @@ -1,78 +1,4 @@ //! Filters for all the endpoints accessible for Server Sent Event updates -use super::{ - query::{self, Query}, - subscription::{PgPool, Subscription}, -}; -use warp::{filters::BoxedFilter, path, Filter}; -#[allow(dead_code)] -type TimelineUser = ((String, Subscription),); - -/// Helper macro to match on the first of any of the provided filters -macro_rules! any_of { - ($filter:expr, $($other_filter:expr),*) => { - $filter$(.or($other_filter).unify())*.boxed() - }; -} - -macro_rules! parse_query { - (path => $start:tt $(/ $next:tt)* - endpoint => $endpoint:expr) => { - path!($start $(/ $next)*) - .and(query::Auth::to_filter()) - .and(query::Media::to_filter()) - .and(query::Hashtag::to_filter()) - .and(query::List::to_filter()) - .map( - |auth: query::Auth, - media: query::Media, - hashtag: query::Hashtag, - list: query::List| { - Query { - access_token: auth.access_token, - stream: $endpoint.to_string(), - media: media.is_truthy(), - hashtag: hashtag.tag, - list: list.list, - } - }, - ) - .boxed() - }; -} -pub fn extract_user_or_reject( - pg_pool: PgPool, - whitelist_mode: bool, -) -> BoxedFilter<(Subscription,)> { - any_of!( - parse_query!( - path => "api" / "v1" / "streaming" / "user" / "notification" - endpoint => "user:notification" ), - parse_query!( - path => "api" / "v1" / "streaming" / "user" - endpoint => "user"), - parse_query!( - path => "api" / "v1" / "streaming" / "public" / "local" - endpoint => "public:local"), - parse_query!( - path => "api" / "v1" / "streaming" / "public" - endpoint => "public"), - parse_query!( - path => "api" / "v1" / "streaming" / "direct" - endpoint => "direct"), - parse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local" - endpoint => "hashtag:local"), - parse_query!(path => "api" / "v1" / "streaming" / "hashtag" - endpoint => "hashtag"), - parse_query!(path => "api" / "v1" / "streaming" / "list" - endpoint => "list") - ) - // because SSE requests place their `access_token` in the header instead of in a query - // parameter, we need to update our Query if the header has a token - .and(query::OptionalAccessToken::from_sse_header()) - .and_then(Query::update_access_token) - .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) - .boxed() -} // #[cfg(test)] // mod test { diff --git a/src/parse_client_request/subscription.rs b/src/parse_client_request/subscription.rs new file mode 100644 index 0000000..d8e874a --- /dev/null +++ b/src/parse_client_request/subscription.rs @@ -0,0 +1,396 @@ +//! `User` struct and related functionality +// #[cfg(test)] +// mod mock_postgres; +// #[cfg(test)] +// use mock_postgres as postgres; +// #[cfg(not(test))] + +use super::postgres::PgPool; +use super::query::Query; +use crate::err::TimelineErr; +use crate::log_fatal; +use lru::LruCache; +use std::collections::HashSet; +use warp::reject::Rejection; + +use super::query; +use warp::{filters::BoxedFilter, path, Filter}; + +/// Helper macro to match on the first of any of the provided filters +macro_rules! any_of { + ($filter:expr, $($other_filter:expr),*) => { + $filter$(.or($other_filter).unify())*.boxed() + }; +} +macro_rules! parse_sse_query { + (path => $start:tt $(/ $next:tt)* + endpoint => $endpoint:expr) => { + path!($start $(/ $next)*) + .and(query::Auth::to_filter()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .map( + |auth: query::Auth, + media: query::Media, + hashtag: query::Hashtag, + list: query::List| { + Query { + access_token: auth.access_token, + stream: $endpoint.to_string(), + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + } + }, + ) + .boxed() + }; +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Subscription { + pub timeline: Timeline, + pub allowed_langs: HashSet, + pub blocks: Blocks, + pub hashtag_name: Option, + pub access_token: Option, +} + +impl Default for Subscription { + fn default() -> Self { + Self { + timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification), + allowed_langs: HashSet::new(), + blocks: Blocks::default(), + hashtag_name: None, + access_token: None, + } + } +} + +impl Subscription { + pub fn from_ws_request(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> { + parse_ws_query() + .and(query::OptionalAccessToken::from_ws_header()) + .and_then(Query::update_access_token) + .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) + .boxed() + } + + pub fn from_sse_query(pg_pool: PgPool, whitelist_mode: bool) -> BoxedFilter<(Subscription,)> { + any_of!( + parse_sse_query!( + path => "api" / "v1" / "streaming" / "user" / "notification" + endpoint => "user:notification" ), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "user" + endpoint => "user"), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "public" / "local" + endpoint => "public:local"), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "public" + endpoint => "public"), + parse_sse_query!( + path => "api" / "v1" / "streaming" / "direct" + endpoint => "direct"), + parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local" + endpoint => "hashtag:local"), + parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" + endpoint => "hashtag"), + parse_sse_query!(path => "api" / "v1" / "streaming" / "list" + endpoint => "list") + ) + // because SSE requests place their `access_token` in the header instead of in a query + // parameter, we need to update our Query if the header has a token + .and(query::OptionalAccessToken::from_sse_header()) + .and_then(Query::update_access_token) + .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) + .boxed() + } + fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result { + let user = match q.access_token.clone() { + Some(token) => pool.clone().select_user(&token)?, + None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?, + None => UserData::public(), + }; + let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?; + let hashtag_name = match timeline { + Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag), + _non_hashtag_timeline => None, + }; + + Ok(Subscription { + timeline, + allowed_langs: user.allowed_langs, + blocks: Blocks { + blocking_users: pool.clone().select_blocking_users(user.id), + blocked_users: pool.clone().select_blocked_users(user.id), + blocked_domains: pool.clone().select_blocked_domains(user.id), + }, + hashtag_name, + access_token: q.access_token, + }) + } +} + +fn parse_ws_query() -> BoxedFilter<(Query,)> { + path!("api" / "v1" / "streaming") + .and(path::end()) + .and(warp::query()) + .and(query::Auth::to_filter()) + .and(query::Media::to_filter()) + .and(query::Hashtag::to_filter()) + .and(query::List::to_filter()) + .map( + |stream: query::Stream, + auth: query::Auth, + media: query::Media, + hashtag: query::Hashtag, + list: query::List| { + Query { + access_token: auth.access_token, + stream: stream.stream, + media: media.is_truthy(), + hashtag: hashtag.tag, + list: list.list, + } + }, + ) + .boxed() +} + +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub struct Timeline(pub Stream, pub Reach, pub Content); + +impl Timeline { + pub fn empty() -> Self { + use {Content::*, Reach::*, Stream::*}; + Self(Unset, Local, Notification) + } + + pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> String { + use {Content::*, Reach::*, Stream::*}; + match self { + Timeline(Public, Federated, All) => "timeline:public".into(), + Timeline(Public, Local, All) => "timeline:public:local".into(), + Timeline(Public, Federated, Media) => "timeline:public:media".into(), + Timeline(Public, Local, Media) => "timeline:public:local:media".into(), + + Timeline(Hashtag(id), Federated, All) => format!( + "timeline:hashtag:{}", + hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) + ), + Timeline(Hashtag(id), Local, All) => format!( + "timeline:hashtag:{}:local", + hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) + ), + Timeline(User(id), Federated, All) => format!("timeline:{}", id), + Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id), + Timeline(List(id), Federated, All) => format!("timeline:list:{}", id), + Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id), + Timeline(one, _two, _three) => { + log_fatal!("Supposedly impossible timeline reached: {:?}", one) + } + } + } + + pub fn from_redis_raw_timeline( + timeline: &str, + cache: &mut LruCache, + namespace: &Option, + ) -> Result { + use crate::err::TimelineErr::RedisNamespaceMismatch; + use {Content::*, Reach::*, Stream::*}; + let timeline_slice = &timeline.split(":").collect::>()[..]; + + #[rustfmt::skip] + let (stream, reach, content) = if let Some(ns) = namespace { + match timeline_slice { + [n, "timeline", "public"] if n == ns => (Public, Federated, All), + [_, "timeline", "public"] + | ["timeline", "public"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "public", "local"] if ns == n => (Public, Local, All), + [_, "timeline", "public", "local"] + | ["timeline", "public", "local"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "public", "media"] if ns == n => (Public, Federated, Media), + [_, "timeline", "public", "media"] + | ["timeline", "public", "media"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "public", "local", "media"] if ns == n => (Public, Local, Media), + [_, "timeline", "public", "local", "media"] + | ["timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "hashtag", tag_name] if ns == n => { + let tag_id = *cache + .get(&tag_name.to_string()) + .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name)); + (Hashtag(tag_id), Federated, All) + } + [_, "timeline", "hashtag", _tag] + | ["timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "hashtag", _tag, "local"] if ns == n => (Hashtag(0), Local, All), + [_, "timeline", "hashtag", _tag, "local"] + | ["timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", id] if ns == n => (User(id.parse().unwrap()), Federated, All), + [_, "timeline", _id] + | ["timeline", _id] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", id, "notification"] if ns == n => + (User(id.parse()?), Federated, Notification), + + [_, "timeline", _id, "notification"] + | ["timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?, + + + [n, "timeline", "list", id] if ns == n => (List(id.parse()?), Federated, All), + [_, "timeline", "list", _id] + | ["timeline", "list", _id] => Err(RedisNamespaceMismatch)?, + + [n, "timeline", "direct", id] if ns == n => (Direct(id.parse()?), Federated, All), + [_, "timeline", "direct", _id] + | ["timeline", "direct", _id] => Err(RedisNamespaceMismatch)?, + + [..] => log_fatal!("Unexpected channel from Redis: {:?}", timeline_slice), + } + } else { + match timeline_slice { + ["timeline", "public"] => (Public, Federated, All), + [_, "timeline", "public"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "public", "local"] => (Public, Local, All), + [_, "timeline", "public", "local"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "public", "media"] => (Public, Federated, Media), + + [_, "timeline", "public", "media"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "public", "local", "media"] => (Public, Local, Media), + [_, "timeline", "public", "local", "media"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "hashtag", _tag] => (Hashtag(0), Federated, All), + [_, "timeline", "hashtag", _tag] => Err(RedisNamespaceMismatch)?, + + ["timeline", "hashtag", _tag, "local"] => (Hashtag(0), Local, All), + [_, "timeline", "hashtag", _tag, "local"] => Err(RedisNamespaceMismatch)?, + + ["timeline", id] => (User(id.parse().unwrap()), Federated, All), + [_, "timeline", _id] => Err(RedisNamespaceMismatch)?, + + ["timeline", id, "notification"] => { + (User(id.parse().unwrap()), Federated, Notification) + } + [_, "timeline", _id, "notification"] => Err(RedisNamespaceMismatch)?, + + ["timeline", "list", id] => (List(id.parse().unwrap()), Federated, All), + [_, "timeline", "list", _id] => Err(RedisNamespaceMismatch)?, + + ["timeline", "direct", id] => (Direct(id.parse().unwrap()), Federated, All), + [_, "timeline", "direct", _id] => Err(RedisNamespaceMismatch)?, + + // Other endpoints don't exist: + [..] => Err(TimelineErr::InvalidInput)?, + } + }; + + Ok(Timeline(stream, reach, content)) + } + fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result { + use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*}; + let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag); + let user_owns_list = || pool.clone().user_owns_list(user.id, q.list); + + Ok(match q.stream.as_ref() { + "public" => match q.media { + true => Timeline(Public, Federated, Media), + false => Timeline(Public, Federated, All), + }, + "public:local" => match q.media { + true => Timeline(Public, Local, Media), + false => Timeline(Public, Local, All), + }, + "public:media" => Timeline(Public, Federated, Media), + "public:local:media" => Timeline(Public, Local, Media), + + "hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All), + "hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All), + "user" => match user.scopes.contains(&Statuses) { + true => Timeline(User(user.id), Federated, All), + false => Err(custom("Error: Missing access token"))?, + }, + "user:notification" => match user.scopes.contains(&Statuses) { + true => Timeline(User(user.id), Federated, Notification), + false => Err(custom("Error: Missing access token"))?, + }, + "list" => match user.scopes.contains(&Lists) && user_owns_list() { + true => Timeline(List(q.list), Federated, All), + false => Err(warp::reject::custom("Error: Missing access token"))?, + }, + "direct" => match user.scopes.contains(&Statuses) { + true => Timeline(Direct(user.id), Federated, All), + false => Err(custom("Error: Missing access token"))?, + }, + other => { + log::warn!("Request for nonexistent endpoint: `{}`", other); + Err(custom("Error: Nonexistent endpoint"))? + } + }) + } +} +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Stream { + User(i64), + List(i64), + Direct(i64), + Hashtag(i64), + Public, + Unset, +} +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Reach { + Local, + Federated, +} +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Content { + All, + Media, + Notification, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Scope { + Read, + Statuses, + Notifications, + Lists, +} + +#[derive(Clone, Default, Debug, PartialEq)] +pub struct Blocks { + pub blocked_domains: HashSet, + pub blocked_users: HashSet, + pub blocking_users: HashSet, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct UserData { + pub id: i64, + pub allowed_langs: HashSet, + pub scopes: HashSet, +} + +impl UserData { + fn public() -> Self { + Self { + id: -1, + allowed_langs: HashSet::new(), + scopes: HashSet::new(), + } + } +} diff --git a/src/parse_client_request/subscription/mock_postgres.rs b/src/parse_client_request/subscription/mock_postgres.rs deleted file mode 100644 index d5a9612..0000000 --- a/src/parse_client_request/subscription/mock_postgres.rs +++ /dev/null @@ -1,43 +0,0 @@ -//! Mock Postgres connection (for use in unit testing) -use super::{OauthScope, Subscription}; -use std::collections::HashSet; - -#[derive(Clone)] -pub struct PgPool; -impl PgPool { - pub fn new() -> Self { - Self - } -} - -pub fn select_user( - access_token: &str, - _pg_pool: PgPool, -) -> Result { - let mut user = Subscription::default(); - if access_token == "TEST_USER" { - user.id = 1; - user.logged_in = true; - user.access_token = "TEST_USER".to_string(); - user.email = "user@example.com".to_string(); - user.scopes = OauthScope::from(vec![ - "read".to_string(), - "write".to_string(), - "follow".to_string(), - ]); - } else if access_token == "INVALID" { - return Err(warp::reject::custom("Error: Invalid access token")); - } - Ok(user) -} - -pub fn select_user_blocks(_id: i64, _pg_pool: PgPool) -> HashSet { - HashSet::new() -} -pub fn select_domain_blocks(_pg_pool: PgPool) -> HashSet { - HashSet::new() -} - -pub fn user_owns_list(user_id: i64, list_id: i64, _pg_pool: PgPool) -> bool { - user_id == list_id -} diff --git a/src/parse_client_request/subscription/mod.rs b/src/parse_client_request/subscription/mod.rs deleted file mode 100644 index 3fdb060..0000000 --- a/src/parse_client_request/subscription/mod.rs +++ /dev/null @@ -1,196 +0,0 @@ -//! `User` struct and related functionality -// #[cfg(test)] -// mod mock_postgres; -// #[cfg(test)] -// use mock_postgres as postgres; -// #[cfg(not(test))] -pub mod postgres; -pub use self::postgres::PgPool; -use super::query::Query; -use crate::log_fatal; -use std::collections::HashSet; -use warp::reject::Rejection; - -/// The User (with data read from Postgres) -#[derive(Clone, Debug, PartialEq)] -pub struct Subscription { - pub timeline: Timeline, - pub allowed_langs: HashSet, - pub blocks: Blocks, -} - -impl Default for Subscription { - fn default() -> Self { - Self { - timeline: Timeline(Stream::Unset, Reach::Local, Content::Notification), - allowed_langs: HashSet::new(), - blocks: Blocks::default(), - } - } -} - -impl Subscription { - pub fn from_query(q: Query, pool: PgPool, whitelist_mode: bool) -> Result { - let user = match q.access_token.clone() { - Some(token) => postgres::select_user(&token, pool.clone())?, - None if whitelist_mode => Err(warp::reject::custom("Error: Invalid access token"))?, - None => UserData::public(), - }; - Ok(Subscription { - timeline: Timeline::from_query_and_user(&q, &user, pool.clone())?, - allowed_langs: user.allowed_langs, - blocks: Blocks { - blocking_users: postgres::select_blocking_users(user.id, pool.clone()), - blocked_users: postgres::select_blocked_users(user.id, pool.clone()), - blocked_domains: postgres::select_blocked_domains(user.id, pool.clone()), - }, - }) - } -} - -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub struct Timeline(pub Stream, pub Reach, pub Content); - -impl Timeline { - pub fn empty() -> Self { - use {Content::*, Reach::*, Stream::*}; - Self(Unset, Local, Notification) - } - - pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> String { - use {Content::*, Reach::*, Stream::*}; - match self { - Timeline(Public, Federated, All) => "timeline:public".into(), - Timeline(Public, Local, All) => "timeline:public:local".into(), - Timeline(Public, Federated, Media) => "timeline:public:media".into(), - Timeline(Public, Local, Media) => "timeline:public:local:media".into(), - - Timeline(Hashtag(id), Federated, All) => format!( - "timeline:hashtag:{}", - hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) - ), - Timeline(Hashtag(id), Local, All) => format!( - "timeline:hashtag:{}:local", - hashtag.unwrap_or_else(|| log_fatal!("Did not supply a name for hashtag #{}", id)) - ), - Timeline(User(id), Federated, All) => format!("timeline:{}", id), - Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id), - Timeline(List(id), Federated, All) => format!("timeline:list:{}", id), - Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id), - Timeline(one, _two, _three) => { - log_fatal!("Supposedly impossible timeline reached: {:?}", one) - } - } - } - pub fn from_redis_raw_timeline(raw_timeline: &str, hashtag: Option) -> Self { - use {Content::*, Reach::*, Stream::*}; - match raw_timeline.split(':').collect::>()[..] { - ["public"] => Timeline(Public, Federated, All), - ["public", "local"] => Timeline(Public, Local, All), - ["public", "media"] => Timeline(Public, Federated, Media), - ["public", "local", "media"] => Timeline(Public, Local, Media), - - ["hashtag", _tag] => Timeline(Hashtag(hashtag.unwrap()), Federated, All), - ["hashtag", _tag, "local"] => Timeline(Hashtag(hashtag.unwrap()), Local, All), - [id] => Timeline(User(id.parse().unwrap()), Federated, All), - [id, "notification"] => Timeline(User(id.parse().unwrap()), Federated, Notification), - ["list", id] => Timeline(List(id.parse().unwrap()), Federated, All), - ["direct", id] => Timeline(Direct(id.parse().unwrap()), Federated, All), - // Other endpoints don't exist: - [..] => log_fatal!("Unexpected channel from Redis: {}", raw_timeline), - } - } - fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result { - use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*}; - let id_from_hashtag = || postgres::select_list_id(&q.hashtag, pool.clone()); - let user_owns_list = || postgres::user_owns_list(user.id, q.list, pool.clone()); - - Ok(match q.stream.as_ref() { - "public" => match q.media { - true => Timeline(Public, Federated, Media), - false => Timeline(Public, Federated, All), - }, - "public:local" => match q.media { - true => Timeline(Public, Local, Media), - false => Timeline(Public, Local, All), - }, - "public:media" => Timeline(Public, Federated, Media), - "public:local:media" => Timeline(Public, Local, Media), - - "hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All), - "hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All), - "user" => match user.scopes.contains(&Statuses) { - true => Timeline(User(user.id), Federated, All), - false => Err(custom("Error: Missing access token"))?, - }, - "user:notification" => match user.scopes.contains(&Statuses) { - true => Timeline(User(user.id), Federated, Notification), - false => Err(custom("Error: Missing access token"))?, - }, - "list" => match user.scopes.contains(&Lists) && user_owns_list() { - true => Timeline(List(q.list), Federated, All), - false => Err(warp::reject::custom("Error: Missing access token"))?, - }, - "direct" => match user.scopes.contains(&Statuses) { - true => Timeline(Direct(user.id), Federated, All), - false => Err(custom("Error: Missing access token"))?, - }, - other => { - log::warn!("Request for nonexistent endpoint: `{}`", other); - Err(custom("Error: Nonexistent endpoint"))? - } - }) - } -} -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Stream { - User(i64), - List(i64), - Direct(i64), - Hashtag(i64), - Public, - Unset, -} -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Reach { - Local, - Federated, -} -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Content { - All, - Media, - Notification, -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum Scope { - Read, - Statuses, - Notifications, - Lists, -} - -#[derive(Clone, Default, Debug, PartialEq)] -pub struct Blocks { - pub blocked_domains: HashSet, - pub blocked_users: HashSet, - pub blocking_users: HashSet, -} - -#[derive(Clone, Debug, PartialEq)] -pub struct UserData { - id: i64, - allowed_langs: HashSet, - scopes: HashSet, -} - -impl UserData { - fn public() -> Self { - Self { - id: -1, - allowed_langs: HashSet::new(), - scopes: HashSet::new(), - } - } -} diff --git a/src/parse_client_request/subscription/postgres.rs b/src/parse_client_request/subscription/postgres.rs deleted file mode 100644 index 9353efa..0000000 --- a/src/parse_client_request/subscription/postgres.rs +++ /dev/null @@ -1,225 +0,0 @@ -//! Postgres queries -use crate::{ - config, - parse_client_request::subscription::{Scope, UserData}, -}; -use ::postgres; -use r2d2_postgres::PostgresConnectionManager; -use std::collections::HashSet; -use warp::reject::Rejection; - -#[derive(Clone, Debug)] -pub struct PgPool(pub r2d2::Pool>); -impl PgPool { - pub fn new(pg_cfg: config::PostgresConfig) -> Self { - let mut cfg = postgres::Config::new(); - cfg.user(&pg_cfg.user) - .host(&*pg_cfg.host.to_string()) - .port(*pg_cfg.port) - .dbname(&pg_cfg.database); - if let Some(password) = &*pg_cfg.password { - cfg.password(password); - }; - - let manager = PostgresConnectionManager::new(cfg, postgres::NoTls); - let pool = r2d2::Pool::builder() - .max_size(10) - .build(manager) - .expect("Can connect to local postgres"); - Self(pool) - } -} - -pub fn select_user(token: &str, pool: PgPool) -> Result { - let mut conn = pool.0.get().unwrap(); - let query_rows = conn - .query( - " -SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes -FROM -oauth_access_tokens -INNER JOIN users ON -oauth_access_tokens.resource_owner_id = users.id -WHERE oauth_access_tokens.token = $1 -AND oauth_access_tokens.revoked_at IS NULL -LIMIT 1", - &[&token.to_owned()], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - if let Some(result_columns) = query_rows.get(0) { - let id = result_columns.get(1); - let allowed_langs = result_columns - .try_get::<_, Vec<_>>(2) - .unwrap_or_else(|_| Vec::new()) - .into_iter() - .collect(); - let mut scopes: HashSet = result_columns - .get::<_, String>(3) - .split(' ') - .filter_map(|scope| match scope { - "read" => Some(Scope::Read), - "read:statuses" => Some(Scope::Statuses), - "read:notifications" => Some(Scope::Notifications), - "read:lists" => Some(Scope::Lists), - "write" | "follow" => None, // ignore write scopes - unexpected => { - log::warn!("Ignoring unknown scope `{}`", unexpected); - None - } - }) - .collect(); - // We don't need to separately track read auth - it's just all three others - if scopes.remove(&Scope::Read) { - scopes.insert(Scope::Statuses); - scopes.insert(Scope::Notifications); - scopes.insert(Scope::Lists); - } - - Ok(UserData { - id, - allowed_langs, - scopes, - }) - } else { - Err(warp::reject::custom("Error: Invalid access token")) - } -} - -pub fn select_list_id(tag_name: &String, pg_pool: PgPool) -> Result { - let mut conn = pg_pool.0.get().unwrap(); - // For the Postgres query, `id` = list number; `account_id` = user.id - let rows = &conn - .query( - " -SELECT id -FROM tags -WHERE name = $1 -LIMIT 1", - &[&tag_name], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - - match rows.get(0) { - Some(row) => Ok(row.get(0)), - None => Err(warp::reject::custom("Error: Hashtag does not exist.")), - } -} -pub fn select_hashtag_name(tag_id: &i64, pg_pool: PgPool) -> Result { - let mut conn = pg_pool.0.get().unwrap(); - // For the Postgres query, `id` = list number; `account_id` = user.id - let rows = &conn - .query( - " -SELECT name -FROM tags -WHERE id = $1 -LIMIT 1", - &[&tag_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - - match rows.get(0) { - Some(row) => Ok(row.get(0)), - None => Err(warp::reject::custom("Error: Hashtag does not exist.")), - } -} - -/// Query Postgres for everyone the user has blocked or muted -/// -/// **NOTE**: because we check this when the user connects, it will not include any blocks -/// the user adds until they refresh/reconnect. -pub fn select_blocked_users(user_id: i64, pg_pool: PgPool) -> HashSet { - // " - // SELECT - // 1 - // FROM blocks - // WHERE (account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})) - // OR (account_id = $2 AND target_account_id = $1) - // UNION SELECT - // 1 - // FROM mutes - // WHERE account_id = $1 AND target_account_id IN (${placeholders(targetAccountIds, 2)})` - // , [req.accountId, unpackedPayload.account.id].concat(targetAccountIds)),`" - pg_pool - .0 - .get() - .unwrap() - .query( - " -SELECT target_account_id - FROM blocks - WHERE account_id = $1 -UNION SELECT target_account_id - FROM mutes - WHERE account_id = $1", - &[&user_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])") - .iter() - .map(|row| row.get(0)) - .collect() -} -/// Query Postgres for everyone who has blocked the user -/// -/// **NOTE**: because we check this when the user connects, it will not include any blocks -/// the user adds until they refresh/reconnect. -pub fn select_blocking_users(user_id: i64, pg_pool: PgPool) -> HashSet { - pg_pool - .0 - .get() - .unwrap() - .query( - " -SELECT account_id - FROM blocks - WHERE target_account_id = $1", - &[&user_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])") - .iter() - .map(|row| row.get(0)) - .collect() -} - -/// Query Postgres for all current domain blocks -/// -/// **NOTE**: because we check this when the user connects, it will not include any blocks -/// the user adds until they refresh/reconnect. -pub fn select_blocked_domains(user_id: i64, pg_pool: PgPool) -> HashSet { - pg_pool - .0 - .get() - .unwrap() - .query( - "SELECT domain FROM account_domain_blocks WHERE account_id = $1", - &[&user_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])") - .iter() - .map(|row| row.get(0)) - .collect() -} - -/// Test whether a user owns a list -pub fn user_owns_list(user_id: i64, list_id: i64, pg_pool: PgPool) -> bool { - let mut conn = pg_pool.0.get().unwrap(); - // For the Postgres query, `id` = list number; `account_id` = user.id - let rows = &conn - .query( - " -SELECT id, account_id -FROM lists -WHERE id = $1 -LIMIT 1", - &[&list_id], - ) - .expect("Hard-coded query will return Some([0 or more rows])"); - - match rows.get(0) { - None => false, - Some(row) => { - let list_owner_id: i64 = row.get(1); - list_owner_id == user_id - } - } -} diff --git a/src/parse_client_request/subscription/stdin b/src/parse_client_request/subscription/stdin deleted file mode 100644 index e69de29..0000000 diff --git a/src/parse_client_request/ws.rs b/src/parse_client_request/ws.rs index aa74c60..e101bed 100644 --- a/src/parse_client_request/ws.rs +++ b/src/parse_client_request/ws.rs @@ -1,48 +1,9 @@ //! Filters for the WebSocket endpoint -use super::{ - query::{self, Query}, - subscription::{PgPool, Subscription}, -}; -use warp::{filters::BoxedFilter, path, Filter}; -/// WebSocket filters -fn parse_query() -> BoxedFilter<(Query,)> { - path!("api" / "v1" / "streaming") - .and(path::end()) - .and(warp::query()) - .and(query::Auth::to_filter()) - .and(query::Media::to_filter()) - .and(query::Hashtag::to_filter()) - .and(query::List::to_filter()) - .map( - |stream: query::Stream, - auth: query::Auth, - media: query::Media, - hashtag: query::Hashtag, - list: query::List| { - Query { - access_token: auth.access_token, - stream: stream.stream, - media: media.is_truthy(), - hashtag: hashtag.tag, - list: list.list, - } - }, - ) - .boxed() -} -pub fn extract_user_and_token_or_reject( - pg_pool: PgPool, - whitelist_mode: bool, -) -> BoxedFilter<(Subscription, Option)> { - parse_query() - .and(query::OptionalAccessToken::from_ws_header()) - .and_then(Query::update_access_token) - .and_then(move |q| Subscription::from_query(q, pg_pool.clone(), whitelist_mode)) - .and(query::OptionalAccessToken::from_ws_header()) - .boxed() -} + + + // #[cfg(test)] // mod test { diff --git a/src/redis_to_client_stream/client_agent.rs b/src/redis_to_client_stream/client_agent.rs index 0162598..18c79d3 100644 --- a/src/redis_to_client_stream/client_agent.rs +++ b/src/redis_to_client_stream/client_agent.rs @@ -19,29 +19,29 @@ use super::receiver::Receiver; use crate::{ config, messages::Event, - parse_client_request::subscription::{PgPool, Stream::Public, Subscription, Timeline}, + parse_client_request::{Stream::Public, Subscription, Timeline}, }; use futures::{ Async::{self, NotReady, Ready}, Poll, }; -use std::sync; +use std::sync::{Arc, Mutex}; use tokio::io::Error; use uuid::Uuid; /// Struct for managing all Redis streams. #[derive(Clone, Debug)] pub struct ClientAgent { - receiver: sync::Arc>, - id: uuid::Uuid, + receiver: Arc>, + id: Uuid, pub subscription: Subscription, } impl ClientAgent { /// Create a new `ClientAgent` with no shared data. - pub fn blank(redis_cfg: config::RedisConfig, pg_pool: PgPool) -> Self { + pub fn blank(redis_cfg: config::RedisConfig) -> Self { ClientAgent { - receiver: sync::Arc::new(sync::Mutex::new(Receiver::new(redis_cfg, pg_pool))), + receiver: Arc::new(Mutex::new(Receiver::new(redis_cfg))), id: Uuid::default(), subscription: Subscription::default(), } @@ -70,7 +70,11 @@ impl ClientAgent { self.subscription = subscription; let start_time = Instant::now(); let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)"); - receiver.manage_new_timeline(self.id, self.subscription.timeline); + receiver.manage_new_timeline( + self.id, + self.subscription.timeline, + self.subscription.hashtag_name.clone(), + ); log::info!("init_for_user had lock for: {:?}", start_time.elapsed()); } } diff --git a/src/redis_to_client_stream/event_stream.rs b/src/redis_to_client_stream/event_stream.rs new file mode 100644 index 0000000..5157120 --- /dev/null +++ b/src/redis_to_client_stream/event_stream.rs @@ -0,0 +1,103 @@ +use super::ClientAgent; + +use warp::ws::WebSocket; +use futures::{future::Future, stream::Stream, Async}; +use log; +use std::time::{Duration, Instant}; + +pub struct EventStream; + +impl EventStream { + + +/// Send a stream of replies to a WebSocket client. + pub fn to_ws( + socket: WebSocket, + mut client_agent: ClientAgent, + update_interval: Duration, +) -> impl Future { + let (ws_tx, mut ws_rx) = socket.split(); + let timeline = client_agent.subscription.timeline; + + // Create a pipe + let (tx, rx) = futures::sync::mpsc::unbounded(); + + // Send one end of it to a different thread and tell that end to forward whatever it gets + // on to the websocket client + warp::spawn( + rx.map_err(|()| -> warp::Error { unreachable!() }) + .forward(ws_tx) + .map(|_r| ()) + .map_err(|e| match e.to_string().as_ref() { + "IO error: Broken pipe (os error 32)" => (), // just closed unix socket + _ => log::warn!("websocket send error: {}", e), + }), + ); + + // Yield new events for as long as the client is still connected + let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while( + move |_| match ws_rx.poll() { + Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true), + Ok(Async::Ready(None)) => { + // TODO: consider whether we should manually drop closed connections here + log::info!("Client closed WebSocket connection for {:?}", timeline); + futures::future::ok(false) + } + Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => { + // no err, just closed Unix socket + log::info!("Client closed WebSocket connection for {:?}", timeline); + futures::future::ok(false) + } + Err(e) => { + log::warn!("Error in {:?}: {}", timeline, e); + futures::future::ok(false) + } + }, + ); + + let mut time = Instant::now(); + // Every time you get an event from that stream, send it through the pipe + event_stream + .for_each(move |_instant| { + if let Ok(Async::Ready(Some(msg))) = client_agent.poll() { + tx.unbounded_send(warp::ws::Message::text(msg.to_json_string())) + .expect("No send error"); + }; + if time.elapsed() > Duration::from_secs(30) { + tx.unbounded_send(warp::ws::Message::text("{}")) + .expect("Can ping"); + time = Instant::now(); + } + Ok(()) + }) + .then(move |result| { + // TODO: consider whether we should manually drop closed connections here + log::info!("WebSocket connection for {:?} closed.", timeline); + result + }) + .map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e)) +} + pub fn to_sse( + mut client_agent: ClientAgent, + connection: warp::sse::Sse, + update_interval: Duration, +) ->impl warp::reply::Reply { + let event_stream = + tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| { + match client_agent.poll() { + Ok(Async::Ready(Some(event))) => Some(( + warp::sse::event(event.event_name()), + warp::sse::data(event.payload().unwrap_or_else(String::new)), + )), + _ => None, + } + }); + + connection.reply( + warp::sse::keep_alive() + .interval(Duration::from_secs(30)) + .text("thump".to_string()) + .stream(event_stream), + ) +} +} diff --git a/src/redis_to_client_stream/message.rs b/src/redis_to_client_stream/message.rs deleted file mode 100644 index 10cd1d6..0000000 --- a/src/redis_to_client_stream/message.rs +++ /dev/null @@ -1,87 +0,0 @@ -use crate::log_fatal; -use crate::messages::Event; -use serde_json::Value; -use std::{collections::HashSet, string::String}; -use strum_macros::Display; - -#[derive(Debug, Display, Clone)] -pub enum Message { - Update(Status), - Conversation(Value), - Notification(Value), - Delete(String), - FiltersChanged, - Announcement(AnnouncementType), - UnknownEvent(String, Value), -} - -#[derive(Debug, Clone)] -pub struct Status(Value); - -#[derive(Debug, Clone)] -pub enum AnnouncementType { - New(Value), - Delete(String), - Reaction(Value), -} - -impl Message { - // pub fn from_json(event: Event) -> Self { - // use AnnouncementType::*; - - // match event.event.as_ref() { - // "update" => Self::Update(Status(event.payload)), - // "conversation" => Self::Conversation(event.payload), - // "notification" => Self::Notification(event.payload), - // "delete" => Self::Delete( - // event - // .payload - // .as_str() - // .unwrap_or_else(|| log_fatal!("Could not process `payload` in {:?}", event)) - // .to_string(), - // ), - // "filters_changed" => Self::FiltersChanged, - // "announcement" => Self::Announcement(New(event.payload)), - // "announcement.reaction" => Self::Announcement(Reaction(event.payload)), - // "announcement.delete" => Self::Announcement(Delete( - // event - // .payload - // .as_str() - // .unwrap_or_else(|| log_fatal!("Could not process `payload` in {:?}", event)) - // .to_string(), - // )), - // other => { - // log::warn!("Received unexpected `event` from Redis: {}", other); - // Self::UnknownEvent(event.event.to_string(), event.payload) - // } - // } - // } - pub fn event(&self) -> String { - use AnnouncementType::*; - match self { - Self::Update(_) => "update", - Self::Conversation(_) => "conversation", - Self::Notification(_) => "notification", - Self::Announcement(New(_)) => "announcement", - Self::Announcement(Reaction(_)) => "announcement.reaction", - Self::UnknownEvent(event, _) => &event, - Self::Delete(_) => "delete", - Self::Announcement(Delete(_)) => "announcement.delete", - Self::FiltersChanged => "filters_changed", - } - .to_string() - } - pub fn payload(&self) -> String { - use AnnouncementType::*; - match self { - Self::Update(status) => status.0.to_string(), - Self::Conversation(value) - | Self::Notification(value) - | Self::Announcement(New(value)) - | Self::Announcement(Reaction(value)) - | Self::UnknownEvent(_, value) => value.to_string(), - Self::Delete(id) | Self::Announcement(Delete(id)) => id.clone(), - Self::FiltersChanged => "".to_string(), - } - } -} diff --git a/src/redis_to_client_stream/mod.rs b/src/redis_to_client_stream/mod.rs index 5295073..e4cdfb5 100644 --- a/src/redis_to_client_stream/mod.rs +++ b/src/redis_to_client_stream/mod.rs @@ -1,103 +1,14 @@ //! Stream the updates appropriate for a given `User`/`timeline` pair from Redis. -pub mod client_agent; -pub mod receiver; -pub mod redis; -pub use client_agent::ClientAgent; -use futures::{future::Future, stream::Stream, Async}; -use log; -use std::time::{Duration, Instant}; +mod client_agent; +mod event_stream; +mod receiver; +mod redis; -/// Send a stream of replies to a Server Sent Events client. -pub fn send_updates_to_sse( - mut client_agent: ClientAgent, - connection: warp::sse::Sse, - update_interval: Duration, -) -> impl warp::reply::Reply { - let event_stream = - tokio::timer::Interval::new(Instant::now(), update_interval).filter_map(move |_| { - match client_agent.poll() { - Ok(Async::Ready(Some(event))) => Some(( - warp::sse::event(event.event_name()), - warp::sse::data(event.payload().unwrap_or_else(String::new)), - )), - _ => None, - } - }); +pub use {client_agent::ClientAgent, event_stream::EventStream}; - connection.reply( - warp::sse::keep_alive() - .interval(Duration::from_secs(30)) - .text("thump".to_string()) - .stream(event_stream), - ) -} - -use warp::ws::WebSocket; - -/// Send a stream of replies to a WebSocket client. -pub fn send_updates_to_ws( - socket: WebSocket, - mut client_agent: ClientAgent, - update_interval: Duration, -) -> impl Future { - let (ws_tx, mut ws_rx) = socket.split(); - let timeline = client_agent.subscription.timeline; - - // Create a pipe - let (tx, rx) = futures::sync::mpsc::unbounded(); - - // Send one end of it to a different thread and tell that end to forward whatever it gets - // on to the websocket client - warp::spawn( - rx.map_err(|()| -> warp::Error { unreachable!() }) - .forward(ws_tx) - .map(|_r| ()) - .map_err(|e| match e.to_string().as_ref() { - "IO error: Broken pipe (os error 32)" => (), // just closed unix socket - _ => log::warn!("websocket send error: {}", e), - }), - ); - - // Yield new events for as long as the client is still connected - let event_stream = tokio::timer::Interval::new(Instant::now(), update_interval).take_while( - move |_| match ws_rx.poll() { - Ok(Async::NotReady) | Ok(Async::Ready(Some(_))) => futures::future::ok(true), - Ok(Async::Ready(None)) => { - // TODO: consider whether we should manually drop closed connections here - log::info!("Client closed WebSocket connection for {:?}", timeline); - futures::future::ok(false) - } - Err(e) if e.to_string() == "IO error: Broken pipe (os error 32)" => { - // no err, just closed Unix socket - log::info!("Client closed WebSocket connection for {:?}", timeline); - futures::future::ok(false) - } - Err(e) => { - log::warn!("Error in {:?}: {}", timeline, e); - futures::future::ok(false) - } - }, - ); - - let mut time = Instant::now(); - // Every time you get an event from that stream, send it through the pipe - event_stream - .for_each(move |_instant| { - if let Ok(Async::Ready(Some(msg))) = client_agent.poll() { - tx.unbounded_send(warp::ws::Message::text(msg.to_json_string())) - .expect("No send error"); - }; - if time.elapsed() > Duration::from_secs(30) { - tx.unbounded_send(warp::ws::Message::text("{}")) - .expect("Can ping"); - time = Instant::now(); - } - Ok(()) - }) - .then(move |result| { - // TODO: consider whether we should manually drop closed connections here - log::info!("WebSocket connection for {:?} closed.", timeline); - result - }) - .map_err(move |e| log::warn!("Error sending to {:?}: {}", timeline, e)) -} +#[cfg(test)] +pub use receiver::process_messages; +#[cfg(test)] +pub use receiver::{MessageQueues, MsgQueue}; +#[cfg(test)] +pub use redis::redis_msg::RedisMsg; diff --git a/src/redis_to_client_stream/receiver/message_queues.rs b/src/redis_to_client_stream/receiver/message_queues.rs index ffcd31e..eb45b30 100644 --- a/src/redis_to_client_stream/receiver/message_queues.rs +++ b/src/redis_to_client_stream/receiver/message_queues.rs @@ -1,5 +1,5 @@ use crate::messages::Event; -use crate::parse_client_request::subscription::Timeline; +use crate::parse_client_request::Timeline; use std::{ collections::{HashMap, VecDeque}, fmt, diff --git a/src/redis_to_client_stream/receiver/mod.rs b/src/redis_to_client_stream/receiver/mod.rs index 364d1fd..87b35ad 100644 --- a/src/redis_to_client_stream/receiver/mod.rs +++ b/src/redis_to_client_stream/receiver/mod.rs @@ -2,62 +2,70 @@ //! polled by the correct `ClientAgent`. Also manages sububscriptions and //! unsubscriptions to/from Redis. mod message_queues; + +pub use message_queues::{MessageQueues, MsgQueue}; + use crate::{ - config::{self, RedisInterval}, - log_fatal, + config, + err::RedisParseErr, messages::Event, - parse_client_request::subscription::{self, postgres, PgPool, Timeline}, + parse_client_request::{Stream, Timeline}, pubsub_cmd, - redis_to_client_stream::redis::{redis_cmd, RedisConn, RedisStream}, + redis_to_client_stream::redis::redis_msg::RedisMsg, + redis_to_client_stream::redis::{redis_cmd, RedisConn}, }; use futures::{Async, Poll}; use lru::LruCache; -pub use message_queues::{MessageQueues, MsgQueue}; -use std::{collections::HashMap, net, time::Instant}; +use tokio::io::AsyncRead; + +use std::{ + collections::HashMap, + io::Read, + net, str, + time::{Duration, Instant}, +}; use tokio::io::Error; use uuid::Uuid; /// The item that streams from Redis and is polled by the `ClientAgent` #[derive(Debug)] pub struct Receiver { - pub pubsub_connection: RedisStream, + pub pubsub_connection: net::TcpStream, secondary_redis_connection: net::TcpStream, - redis_poll_interval: RedisInterval, + redis_poll_interval: Duration, redis_polled_at: Instant, timeline: Timeline, manager_id: Uuid, pub msg_queues: MessageQueues, clients_per_timeline: HashMap, cache: Cache, - pool: PgPool, + redis_input: Vec, + redis_namespace: Option, } + #[derive(Debug)] pub struct Cache { + // TODO: eventually, it might make sense to have Mastodon publish to timelines with + // the tag number instead of the tag name. This would save us from dealing + // with a cache here and would be consistent with how lists/users are handled. id_to_hashtag: LruCache, pub hashtag_to_id: LruCache, } -impl Cache { - fn new(size: usize) -> Self { - Self { - id_to_hashtag: LruCache::new(size), - hashtag_to_id: LruCache::new(size), - } - } -} + impl Receiver { /// Create a new `Receiver`, with its own Redis connections (but, as yet, no /// active subscriptions). - pub fn new(redis_cfg: config::RedisConfig, pool: PgPool) -> Self { + pub fn new(redis_cfg: config::RedisConfig) -> Self { + let redis_namespace = redis_cfg.namespace.clone(); + let RedisConn { primary: pubsub_connection, secondary: secondary_redis_connection, - namespace: redis_namespace, polling_interval: redis_poll_interval, } = RedisConn::new(redis_cfg); Self { - pubsub_connection: RedisStream::from_stream(pubsub_connection) - .with_namespace(redis_namespace), + pubsub_connection, secondary_redis_connection, redis_poll_interval, redis_polled_at: Instant::now(), @@ -65,8 +73,12 @@ impl Receiver { manager_id: Uuid::default(), msg_queues: MessageQueues(HashMap::new()), clients_per_timeline: HashMap::new(), - cache: Cache::new(1000), // should this be a run-time option? - pool, + cache: Cache { + id_to_hashtag: LruCache::new(1000), + hashtag_to_id: LruCache::new(1000), + }, // should these be run-time options? + redis_input: Vec::new(), + redis_namespace, } } @@ -76,12 +88,15 @@ impl Receiver { /// Note: this method calls `subscribe_or_unsubscribe_as_needed`, /// so Redis PubSub subscriptions are only updated when a new timeline /// comes under management for the first time. - pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: Timeline) { - self.manager_id = manager_id; - self.timeline = timeline; - self.msg_queues - .insert(self.manager_id, MsgQueue::new(timeline)); - self.subscribe_or_unsubscribe_as_needed(timeline); + pub fn manage_new_timeline(&mut self, id: Uuid, tl: Timeline, hashtag: Option) { + self.timeline = tl; + if let (Some(hashtag), Timeline(Stream::Hashtag(id), _, _)) = (hashtag, tl) { + self.cache.id_to_hashtag.put(id, hashtag.clone()); + self.cache.hashtag_to_id.put(hashtag, id); + }; + + self.msg_queues.insert(id, MsgQueue::new(tl)); + self.subscribe_or_unsubscribe_as_needed(tl); } /// Set the `Receiver`'s manager_id and target_timeline fields to the appropriate @@ -91,26 +106,6 @@ impl Receiver { self.timeline = timeline; } - fn if_hashtag_timeline_get_hashtag_name(&mut self, timeline: Timeline) -> Option { - use subscription::Stream::*; - if let Timeline(Hashtag(id), _, _) = timeline { - let cached_tag = self.cache.id_to_hashtag.get(&id).map(String::from); - let tag = match cached_tag { - Some(tag) => tag, - None => { - let new_tag = postgres::select_hashtag_name(&id, self.pool.clone()) - .unwrap_or_else(|_| log_fatal!("No hashtag associated with tag #{}", &id)); - self.cache.hashtag_to_id.put(new_tag.clone(), id); - self.cache.id_to_hashtag.put(id, new_tag.clone()); - new_tag.to_string() - } - }; - Some(tag) - } else { - None - } - } - /// Drop any PubSub subscriptions that don't have active clients and check /// that there's a subscription to the current one. If there isn't, then /// subscribe to it. @@ -121,8 +116,10 @@ impl Receiver { // Record the lower number of clients subscribed to that channel for change in timelines_to_modify { let timeline = change.timeline; - let hashtag = self.if_hashtag_timeline_get_hashtag_name(timeline); - let hashtag = hashtag.as_ref(); + let hashtag = match timeline { + Timeline(Stream::Hashtag(id), _, _) => self.cache.id_to_hashtag.get(&id), + _non_hashtag_timeline => None, + }; let count_of_subscribed_clients = self .clients_per_timeline @@ -157,10 +154,27 @@ impl futures::stream::Stream for Receiver { fn poll(&mut self) -> Poll, Self::Error> { let (timeline, id) = (self.timeline.clone(), self.manager_id); - if self.redis_polled_at.elapsed() > *self.redis_poll_interval { - self.pubsub_connection - .poll_redis(&mut self.cache.hashtag_to_id, &mut self.msg_queues); - self.redis_polled_at = Instant::now(); + if self.redis_polled_at.elapsed() > self.redis_poll_interval { + let mut buffer = vec![0u8; 6000]; + if let Ok(Async::Ready(bytes_read)) = self.poll_read(&mut buffer) { + let binary_input = buffer[..bytes_read].to_vec(); + let (input, extra_bytes) = match str::from_utf8(&binary_input) { + Ok(input) => (input, "".as_bytes()), + Err(e) => { + let (valid, after_valid) = binary_input.split_at(e.valid_up_to()); + let input = str::from_utf8(valid).expect("Guaranteed by `.valid_up_to`"); + (input, after_valid) + } + }; + + let (cache, namespace) = (&mut self.cache.hashtag_to_id, &self.redis_namespace); + + let remaining_input = + process_messages(input, cache, namespace, &mut self.msg_queues); + + self.redis_input.extend_from_slice(remaining_input); + self.redis_input.extend_from_slice(extra_bytes); + } } // Record current time as last polled time @@ -173,3 +187,49 @@ impl futures::stream::Stream for Receiver { } } } + +impl Read for Receiver { + fn read(&mut self, buffer: &mut [u8]) -> Result { + self.pubsub_connection.read(buffer) + } +} + +impl AsyncRead for Receiver { + fn poll_read(&mut self, buf: &mut [u8]) -> Poll { + match self.read(buf) { + Ok(t) => Ok(Async::Ready(t)), + Err(_) => Ok(Async::NotReady), + } + } +} + +#[must_use] +pub fn process_messages<'a>( + input: &'a str, + mut cache: &mut LruCache, + namespace: &Option, + msg_queues: &mut MessageQueues, +) -> &'a [u8] { + let mut remaining_input = input; + use RedisMsg::*; + loop { + match RedisMsg::from_raw(&mut remaining_input, &mut cache, namespace) { + Ok((EventMsg(timeline, event), rest)) => { + for msg_queue in msg_queues.values_mut() { + if msg_queue.timeline == timeline { + msg_queue.messages.push_back(event.clone()); + } + } + remaining_input = rest; + } + Ok((SubscriptionMsg, rest)) | Ok((MsgForDifferentNamespace, rest)) => { + remaining_input = rest; + } + Err(RedisParseErr::Incomplete) => break, + Err(RedisParseErr::Unrecoverable) => { + panic!("Failed parsing Redis msg: {}", &remaining_input) + } + }; + } + remaining_input.as_bytes() +} diff --git a/src/redis_to_client_stream/redis/mod.rs b/src/redis_to_client_stream/redis/mod.rs index c63e243..70ad337 100644 --- a/src/redis_to_client_stream/redis/mod.rs +++ b/src/redis_to_client_stream/redis/mod.rs @@ -1,9 +1,5 @@ pub mod redis_cmd; pub mod redis_connection; pub mod redis_msg; -pub mod redis_stream; pub use redis_connection::RedisConn; -pub use redis_stream::RedisStream; - - diff --git a/src/redis_to_client_stream/redis/redis_cmd.rs b/src/redis_to_client_stream/redis/redis_cmd.rs index b8d8b32..f353ffc 100644 --- a/src/redis_to_client_stream/redis/redis_cmd.rs +++ b/src/redis_to_client_stream/redis/redis_cmd.rs @@ -7,7 +7,7 @@ macro_rules! pubsub_cmd { ($cmd:expr, $self:expr, $tl:expr) => {{ use std::io::Write; log::info!("Sending {} command to {}", $cmd, $tl); - let namespace = $self.pubsub_connection.namespace.clone(); + let namespace = $self.redis_namespace.clone(); $self .pubsub_connection diff --git a/src/redis_to_client_stream/redis/redis_connection.rs b/src/redis_to_client_stream/redis/redis_connection.rs index 26c06d3..8261695 100644 --- a/src/redis_to_client_stream/redis/redis_connection.rs +++ b/src/redis_to_client_stream/redis/redis_connection.rs @@ -1,13 +1,12 @@ use super::redis_cmd; -use crate::config::{RedisConfig, RedisInterval, RedisNamespace}; +use crate::config::RedisConfig; use crate::err; -use std::{io::Read, io::Write, net, time}; +use std::{io::Read, io::Write, net, time::Duration}; pub struct RedisConn { pub primary: net::TcpStream, pub secondary: net::TcpStream, - pub namespace: RedisNamespace, - pub polling_interval: RedisInterval, + pub polling_interval: Duration, } fn send_password(mut conn: net::TcpStream, password: &str) -> net::TcpStream { @@ -68,7 +67,7 @@ impl RedisConn { conn = send_password(conn, &password); } conn = send_test_ping(conn); - conn.set_read_timeout(Some(time::Duration::from_millis(10))) + conn.set_read_timeout(Some(Duration::from_millis(10))) .expect("Can set read timeout for Redis connection"); if let Some(db) = &*redis_cfg.db { conn = set_db(conn, db); @@ -86,8 +85,7 @@ impl RedisConn { Self { primary: primary_conn, secondary: secondary_conn, - namespace: redis_cfg.namespace, - polling_interval: redis_cfg.polling_interval, + polling_interval: *redis_cfg.polling_interval, } } } diff --git a/src/redis_to_client_stream/redis/redis_msg.rs b/src/redis_to_client_stream/redis/redis_msg.rs index f03cd09..d0ef023 100644 --- a/src/redis_to_client_stream/redis/redis_msg.rs +++ b/src/redis_to_client_stream/redis/redis_msg.rs @@ -18,26 +18,31 @@ //! three characters, the second is a bulk string with ten characters, and the third is a //! bulk string with 1,386 characters. -use crate::{log_fatal, messages::Event, parse_client_request::subscription::Timeline}; +use crate::{ + err::{RedisParseErr, TimelineErr}, + messages::Event, + parse_client_request::Timeline, +}; use lru::LruCache; -type Parser<'a, Item> = Result<(Item, &'a str), ParseErr>; -#[derive(Debug)] -pub enum ParseErr { - Incomplete, - Unrecoverable, -} -use ParseErr::*; + +type Parser<'a, Item> = Result<(Item, &'a str), RedisParseErr>; /// A message that has been parsed from an incoming raw message from Redis. #[derive(Debug, Clone)] pub enum RedisMsg { EventMsg(Timeline, Event), SubscriptionMsg, + MsgForDifferentNamespace, } +use RedisParseErr::*; type Hashtags = LruCache; impl RedisMsg { - pub fn from_raw<'a>(input: &'a str, cache: &mut Hashtags, prefix: usize) -> Parser<'a, Self> { + pub fn from_raw<'a>( + input: &'a str, + cache: &mut Hashtags, + namespace: &Option, + ) -> Parser<'a, Self> { // No need to parse the Redis Array header, just skip it let input = input.get("*3\r\n".len()..).ok_or(Incomplete)?; let (command, rest) = parse_redis_bulk_string(&input)?; @@ -46,14 +51,16 @@ impl RedisMsg { // Messages look like; // $10\r\ntimeline:4\r\n // $1386\r\n{\"event\":\"update\",\"payload\"...\"queued_at\":1569623342825}\r\n - let (raw_timeline, rest) = parse_redis_bulk_string(&rest)?; + let (timeline, rest) = parse_redis_bulk_string(&rest)?; let (msg_txt, rest) = parse_redis_bulk_string(&rest)?; + let event: Event = serde_json::from_str(&msg_txt).map_err(|_| Unrecoverable)?; - let raw_timeline = &raw_timeline.get(prefix..).ok_or(Unrecoverable)?; - let event: Event = serde_json::from_str(&msg_txt).unwrap(); - let hashtag = hashtag_from_timeline(&raw_timeline, cache); - let timeline = Timeline::from_redis_raw_timeline(&raw_timeline, hashtag); - Ok((Self::EventMsg(timeline, event), rest)) + use TimelineErr::*; + match Timeline::from_redis_raw_timeline(timeline, cache, namespace) { + Ok(timeline) => Ok((Self::EventMsg(timeline, event), rest)), + Err(RedisNamespaceMismatch) => Ok((Self::MsgForDifferentNamespace, rest)), + Err(InvalidInput) => Err(RedisParseErr::Unrecoverable), + } } "subscribe" | "unsubscribe" => { // subscription statuses look like: @@ -101,18 +108,3 @@ fn parse_number_at(input: &str) -> Parser { let rest = &input.get(number_len..).ok_or(Incomplete)?; Ok((number, rest)) } -fn hashtag_from_timeline(raw_timeline: &str, hashtag_id_cache: &mut Hashtags) -> Option { - if raw_timeline.starts_with("hashtag") { - let tag_name = raw_timeline - .split(':') - .nth(1) - .unwrap_or_else(|| log_fatal!("No hashtag found in `{}`", raw_timeline)) - .to_string(); - let tag_id = *hashtag_id_cache - .get(&tag_name) - .unwrap_or_else(|| log_fatal!("No cached id for `{}`", tag_name)); - Some(tag_id) - } else { - None - } -} diff --git a/src/redis_to_client_stream/redis/redis_stream.rs b/src/redis_to_client_stream/redis/redis_stream.rs deleted file mode 100644 index b50ca0c..0000000 --- a/src/redis_to_client_stream/redis/redis_stream.rs +++ /dev/null @@ -1,127 +0,0 @@ -use super::redis_msg::{ParseErr, RedisMsg}; -use crate::config::RedisNamespace; -use crate::log_fatal; -use crate::redis_to_client_stream::receiver::MessageQueues; -use futures::{Async, Poll}; -use lru::LruCache; -use std::{error::Error, io::Read, net}; -use tokio::io::AsyncRead; - -#[derive(Debug)] -pub struct RedisStream { - pub inner: net::TcpStream, - incoming_raw_msg: String, - pub namespace: RedisNamespace, -} - -impl RedisStream { - pub fn from_stream(inner: net::TcpStream) -> Self { - RedisStream { - inner, - incoming_raw_msg: String::new(), - namespace: RedisNamespace(None), - } - } - pub fn with_namespace(self, namespace: RedisNamespace) -> Self { - RedisStream { namespace, ..self } - } - // Text comes in from redis as a raw stream, which could be more than one message and - // is not guaranteed to end on a message boundary. We need to break it down into - // messages. Incoming messages *are* guaranteed to be RESP arrays (though still not - // guaranteed to end at an array boundary). See https://redis.io/topics/protocol - /// Adds any new Redis messages to the `MsgQueue` for the appropriate `ClientAgent`. - pub fn poll_redis( - &mut self, - hashtag_to_id_cache: &mut LruCache, - queues: &mut MessageQueues, - ) { - let mut buffer = vec![0u8; 6000]; - if let Ok(Async::Ready(num_bytes_read)) = self.poll_read(&mut buffer) { - let raw_utf = self.as_utf8(buffer, num_bytes_read); - self.incoming_raw_msg.push_str(&raw_utf); - match process_messages( - self.incoming_raw_msg.clone(), - &mut self.namespace.0, - hashtag_to_id_cache, - queues, - ) { - Ok(None) => self.incoming_raw_msg.clear(), - Ok(Some(msg_fragment)) => self.incoming_raw_msg = msg_fragment, - Err(e) => { - log::error!("{}", e); - log_fatal!("Could not process RedisStream: {:?}", &self); - } - } - } - } - - fn as_utf8(&mut self, cur_buffer: Vec, size: usize) -> String { - String::from_utf8(cur_buffer[..size].to_vec()).unwrap_or_else(|_| { - let mut new_buffer = vec![0u8; 1]; - self.poll_read(&mut new_buffer).unwrap(); - let buffer = ([cur_buffer, new_buffer]).concat(); - self.as_utf8(buffer, size + 1) - }) - } -} - -type HashtagCache = LruCache; -pub fn process_messages( - raw_msg: String, - namespace: &mut Option, - cache: &mut HashtagCache, - queues: &mut MessageQueues, -) -> Result, Box> { - let prefix_len = match namespace { - Some(namespace) => format!("{}:timeline:", namespace).len(), - None => "timeline:".len(), - }; - - let mut input = raw_msg.as_str(); - loop { - let rest = match RedisMsg::from_raw(&input, cache, prefix_len) { - Ok((RedisMsg::EventMsg(timeline, event), rest)) => { - for msg_queue in queues.values_mut() { - if msg_queue.timeline == timeline { - msg_queue.messages.push_back(event.clone()); - } - } - rest - } - Ok((RedisMsg::SubscriptionMsg, rest)) => rest, - Err(ParseErr::Incomplete) => break, - Err(ParseErr::Unrecoverable) => log_fatal!("Failed parsing Redis msg: {}", &input), - }; - input = rest - } - - Ok(Some(input.to_string())) -} - -impl std::ops::Deref for RedisStream { - type Target = net::TcpStream; - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl std::ops::DerefMut for RedisStream { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} - -impl Read for RedisStream { - fn read(&mut self, buffer: &mut [u8]) -> Result { - self.inner.read(buffer) - } -} - -impl AsyncRead for RedisStream { - fn poll_read(&mut self, buf: &mut [u8]) -> Poll { - match self.read(buf) { - Ok(t) => Ok(Async::Ready(t)), - Err(_) => Ok(Async::NotReady), - } - } -}