diff --git a/Cargo.lock b/Cargo.lock index aec9118..c9e83a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -406,7 +406,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" [[package]] name = "flodgatt" -version = "0.8.2" +version = "0.8.3" dependencies = [ "criterion 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", "dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index 1a28d9a..9828af0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "flodgatt" description = "A blazingly fast drop-in replacement for the Mastodon streaming api server" -version = "0.8.2" +version = "0.8.3" authors = ["Daniel Long Sockwell "] edition = "2018" diff --git a/src/err.rs b/src/err.rs index edf9c6c..3a61f4b 100644 --- a/src/err.rs +++ b/src/err.rs @@ -1,7 +1,3 @@ -mod timeline; - -pub use timeline::TimelineErr; - use crate::response::ManagerErr; use std::fmt; diff --git a/src/main.rs b/src/main.rs index c8be66d..7e385ff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -62,9 +62,9 @@ fn main() -> Result<(), FatalErr> { .map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token)); #[cfg(feature = "stub_status")] + #[rustfmt::skip] let status = { let (r1, r3) = (shared_manager.clone(), shared_manager.clone()); - #[rustfmt::skip] request.health().map(|| "OK") .or(request.status() .map(move || r1.lock().unwrap_or_else(redis::Manager::recover).count())) diff --git a/src/request.rs b/src/request.rs index c9a4fed..abeea43 100644 --- a/src/request.rs +++ b/src/request.rs @@ -8,7 +8,7 @@ mod subscription; pub use self::postgres::PgPool; // TODO consider whether we can remove `Stream` from public API pub use subscription::{Blocks, Subscription}; -pub use timeline::{Content, Reach, Stream, Timeline}; +pub use timeline::{Content, Reach, Stream, Timeline, TimelineErr}; use self::query::Query; use crate::config; diff --git a/src/request/timeline.rs b/src/request/timeline.rs index ed2f0e2..39818ec 100644 --- a/src/request/timeline.rs +++ b/src/request/timeline.rs @@ -1,12 +1,15 @@ +pub use self::err::TimelineErr; +pub use self::inner::{Content, Reach, Scope, Stream, UserData}; use super::query::Query; -use crate::err::TimelineErr; -use crate::event::Id; -use hashbrown::HashSet; use lru::LruCache; -use std::convert::TryFrom; use warp::reject::Rejection; +mod err; +mod inner; + +type Result = std::result::Result; + #[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] pub struct Timeline(pub Stream, pub Reach, pub Content); @@ -15,7 +18,7 @@ impl Timeline { Self(Stream::Unset, Reach::Local, Content::Notification) } - pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result { + pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result { use {Content::*, Reach::*, Stream::*}; Ok(match self { Timeline(Public, Federated, All) => "timeline:public".into(), @@ -39,34 +42,27 @@ impl Timeline { }) } - pub fn from_redis_text( - timeline: &str, - cache: &mut LruCache, - ) -> Result { - // TODO -- can a combinator shorten this? - let mut id_from_tag = |tag: &str| match cache.get(&tag.to_string()) { - Some(id) => Ok(*id), - None => Err(TimelineErr::InvalidInput), // TODO more specific - }; + pub fn from_redis_text(timeline: &str, cache: &mut LruCache) -> Result { + use {Content::*, Reach::*, Stream::*, TimelineErr::*}; + let mut tag_id = |t: &str| cache.get(&t.to_string()).map_or(Err(BadTag), |id| Ok(*id)); - use {Content::*, Reach::*, Stream::*}; Ok(match &timeline.split(':').collect::>()[..] { ["public"] => Timeline(Public, Federated, All), ["public", "local"] => Timeline(Public, Local, All), ["public", "media"] => Timeline(Public, Federated, Media), ["public", "local", "media"] => Timeline(Public, Local, Media), - ["hashtag", tag] => Timeline(Hashtag(id_from_tag(tag)?), Federated, All), - ["hashtag", tag, "local"] => Timeline(Hashtag(id_from_tag(tag)?), Local, All), + ["hashtag", tag] => Timeline(Hashtag(tag_id(tag)?), Federated, All), + ["hashtag", tag, "local"] => Timeline(Hashtag(tag_id(tag)?), Local, All), [id] => Timeline(User(id.parse()?), Federated, All), [id, "notification"] => Timeline(User(id.parse()?), Federated, Notification), ["list", id] => Timeline(List(id.parse()?), Federated, All), ["direct", id] => Timeline(Direct(id.parse()?), Federated, All), // Other endpoints don't exist: - [..] => Err(TimelineErr::InvalidInput)?, + [..] => Err(InvalidInput)?, }) } - pub fn from_query_and_user(q: &Query, user: &UserData) -> Result { + pub fn from_query_and_user(q: &Query, user: &UserData) -> std::result::Result { use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*}; Ok(match q.stream.as_ref() { @@ -106,69 +102,3 @@ impl Timeline { }) } } - -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Stream { - User(Id), - // TODO consider whether List, Direct, and Hashtag should all be `id::Id`s - List(i64), - Direct(i64), - Hashtag(i64), - Public, - Unset, -} - -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Reach { - Local, - Federated, -} - -#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] -pub enum Content { - All, - Media, - Notification, -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum Scope { - Read, - Statuses, - Notifications, - Lists, -} - -impl TryFrom<&str> for Scope { - type Error = TimelineErr; - - fn try_from(s: &str) -> Result { - match s { - "read" => Ok(Scope::Read), - "read:statuses" => Ok(Scope::Statuses), - "read:notifications" => Ok(Scope::Notifications), - "read:lists" => Ok(Scope::Lists), - "write" | "follow" => Err(TimelineErr::InvalidInput), // ignore write scopes - unexpected => { - log::warn!("Ignoring unknown scope `{}`", unexpected); - Err(TimelineErr::InvalidInput) - } - } - } -} - -pub struct UserData { - pub id: Id, - pub allowed_langs: HashSet, - pub scopes: HashSet, -} - -impl UserData { - pub fn public() -> Self { - Self { - id: Id(-1), - allowed_langs: HashSet::new(), - scopes: HashSet::new(), - } - } -} diff --git a/src/err/timeline.rs b/src/request/timeline/err.rs similarity index 89% rename from src/err/timeline.rs rename to src/request/timeline/err.rs index 6f05f89..5dbf660 100644 --- a/src/err/timeline.rs +++ b/src/request/timeline/err.rs @@ -4,6 +4,7 @@ use std::fmt; pub enum TimelineErr { MissingHashtag, InvalidInput, + BadTag, } impl std::error::Error for TimelineErr {} @@ -20,6 +21,7 @@ impl fmt::Display for TimelineErr { let msg = match self { InvalidInput => "The timeline text from Redis could not be parsed into a supported timeline. TODO: add incoming timeline text", MissingHashtag => "Attempted to send a hashtag timeline without supplying a tag name", + BadTag => "No hashtag exists with the specified hashtag ID" }; write!(f, "{}", msg) } diff --git a/src/request/timeline/inner.rs b/src/request/timeline/inner.rs new file mode 100644 index 0000000..abe12d3 --- /dev/null +++ b/src/request/timeline/inner.rs @@ -0,0 +1,70 @@ +use super::TimelineErr; +use crate::event::Id; + +use hashbrown::HashSet; +use std::convert::TryFrom; + +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Stream { + User(Id), + List(i64), + Direct(i64), + Hashtag(i64), + Public, + Unset, +} + +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Reach { + Local, + Federated, +} + +#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)] +pub enum Content { + All, + Media, + Notification, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Scope { + Read, + Statuses, + Notifications, + Lists, +} + +impl TryFrom<&str> for Scope { + type Error = TimelineErr; + + fn try_from(s: &str) -> Result { + match s { + "read" => Ok(Scope::Read), + "read:statuses" => Ok(Scope::Statuses), + "read:notifications" => Ok(Scope::Notifications), + "read:lists" => Ok(Scope::Lists), + "write" | "follow" => Err(TimelineErr::InvalidInput), // ignore write scopes + unexpected => { + log::warn!("Ignoring unknown scope `{}`", unexpected); + Err(TimelineErr::InvalidInput) + } + } + } +} + +pub struct UserData { + pub id: Id, + pub allowed_langs: HashSet, + pub scopes: HashSet, +} + +impl UserData { + pub fn public() -> Self { + Self { + id: Id(-1), + allowed_langs: HashSet::new(), + scopes: HashSet::new(), + } + } +} diff --git a/src/response/redis.rs b/src/response/redis.rs index a345ae6..0a060d6 100644 --- a/src/response/redis.rs +++ b/src/response/redis.rs @@ -5,3 +5,23 @@ pub mod msg; pub use connection::{RedisConn, RedisConnErr}; pub use manager::{Manager, ManagerErr}; pub use msg::RedisParseErr; + +pub enum RedisCmd { + Subscribe, + Unsubscribe, +} + +impl RedisCmd { + pub fn into_sendable(&self, tl: &String) -> (Vec, Vec) { + match self { + RedisCmd::Subscribe => ( + format!("*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(), + format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n1\r\n", tl.len(), tl).into_bytes(), + ), + RedisCmd::Unsubscribe => ( + format!("*2\r\n$11\r\nunsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(), + format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n0\r\n", tl.len(), tl).into_bytes(), + ), + } + } +} diff --git a/src/response/redis/connection.rs b/src/response/redis/connection.rs index d47ce89..29d7ac6 100644 --- a/src/response/redis/connection.rs +++ b/src/response/redis/connection.rs @@ -1,13 +1,13 @@ mod err; pub use err::RedisConnErr; -use super::msg::{RedisMsg, RedisParseErr, RedisParseOutput}; -use super::ManagerErr; +use super::msg::{RedisParseErr, RedisParseOutput}; +use super::{ManagerErr, RedisCmd}; use crate::config::Redis; use crate::event::Event; use crate::request::{Stream, Timeline}; -use futures::Async; +use futures::{Async, Poll}; use lru::LruCache; use std::convert::{TryFrom, TryInto}; use std::io::{Read, Write}; @@ -16,7 +16,6 @@ use std::str; use std::time::Duration; type Result = std::result::Result; -type Poll = futures::Poll, ManagerErr>; #[derive(Debug)] pub struct RedisConn { @@ -48,7 +47,7 @@ impl RedisConn { Ok(redis_conn) } - pub fn poll_redis(&mut self) -> Poll { + pub fn poll_redis(&mut self) -> Poll, ManagerErr> { let mut size = 100; // large enough to handle subscribe/unsubscribe notice let (mut buffer, mut first_read) = (vec![0u8; size], true); loop { @@ -68,6 +67,7 @@ impl RedisConn { return Ok(Async::NotReady); } + // at this point, we have the raw bytes; now, parse what we can and leave the remainder let input = self.redis_input.clone(); self.redis_input.clear(); @@ -83,11 +83,15 @@ impl RedisConn { Ok(Msg(msg)) => match &self.redis_namespace { Some(ns) if msg.timeline_txt.starts_with(&format!("{}:timeline:", ns)) => { let trimmed_tl = &msg.timeline_txt[ns.len() + ":timeline:".len()..]; - (self.into_tl_event(trimmed_tl, &msg)?, msg.leftover_input) + let tl = Timeline::from_redis_text(trimmed_tl, &mut self.tag_id_cache)?; + let event = msg.event_txt.try_into()?; + (Ok(Ready(Some((tl, event)))), (msg.leftover_input)) } None => { let trimmed_tl = &msg.timeline_txt["timeline:".len()..]; - (self.into_tl_event(trimmed_tl, &msg)?, msg.leftover_input) + let tl = Timeline::from_redis_text(trimmed_tl, &mut self.tag_id_cache)?; + let event = msg.event_txt.try_into()?; + (Ok(Ready(Some((tl, event)))), (msg.leftover_input)) } Some(_non_matching_namespace) => (Ok(Ready(None)), msg.leftover_input), }, @@ -100,12 +104,6 @@ impl RedisConn { res } - fn into_tl_event<'a>(&mut self, tl: &'a str, msg: &'a RedisMsg) -> Result { - let tl = Timeline::from_redis_text(tl, &mut self.tag_id_cache)?; - let event = msg.event_txt.try_into().expect("TODO"); - Ok(Ok(Async::Ready(Some((tl, event))))) - } - pub fn update_cache(&mut self, hashtag: String, id: i64) { self.tag_id_cache.put(hashtag.clone(), id); self.tag_name_cache.put(id, hashtag); @@ -135,10 +133,11 @@ impl RedisConn { .map_err(|e| RedisConnErr::with_addr(&addr, e))?; Ok(conn) } + fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> { conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes()) .map_err(|e| RedisConnErr::with_addr(&addr, e))?; - let mut buffer = vec![0u8; 5]; + let mut buffer = vec![0_u8; 5]; conn.read_exact(&mut buffer) .map_err(|e| RedisConnErr::with_addr(&addr, e))?; if String::from_utf8_lossy(&buffer) != "+OK\r\n" { @@ -150,7 +149,7 @@ impl RedisConn { fn validate_connection(conn: &mut TcpStream, addr: &str) -> Result<()> { conn.write_all(b"PING\r\n") .map_err(|e| RedisConnErr::with_addr(&addr, e))?; - let mut buffer = vec![0u8; 7]; + let mut buffer = vec![0_u8; 7]; conn.read_exact(&mut buffer) .map_err(|e| RedisConnErr::with_addr(&addr, e))?; let reply = String::from_utf8_lossy(&buffer); @@ -162,23 +161,3 @@ impl RedisConn { } } } - -pub enum RedisCmd { - Subscribe, - Unsubscribe, -} - -impl RedisCmd { - pub fn into_sendable(&self, tl: &String) -> (Vec, Vec) { - match self { - RedisCmd::Subscribe => ( - format!("*2\r\n$9\r\nsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(), - format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n1\r\n", tl.len(), tl).into_bytes(), - ), - RedisCmd::Unsubscribe => ( - format!("*2\r\n$11\r\nunsubscribe\r\n${}\r\n{}\r\n", tl.len(), tl).into_bytes(), - format!("*3\r\n$3\r\nSET\r\n${}\r\n{}\r\n$1\r\n0\r\n", tl.len(), tl).into_bytes(), - ), - } - } -} diff --git a/src/response/redis/connection/err.rs b/src/response/redis/connection/err.rs index bb702c2..ec945d8 100644 --- a/src/response/redis/connection/err.rs +++ b/src/response/redis/connection/err.rs @@ -1,4 +1,4 @@ -use crate::err::TimelineErr; +use crate::request::TimelineErr; use std::fmt; #[derive(Debug)] diff --git a/src/response/redis/manager.rs b/src/response/redis/manager.rs index c0fc6d8..04ea052 100644 --- a/src/response/redis/manager.rs +++ b/src/response/redis/manager.rs @@ -4,7 +4,7 @@ mod err; pub use err::ManagerErr; -use super::{connection::RedisCmd, RedisConn}; +use super::{RedisCmd, RedisConn}; use crate::config; use crate::event::Event; use crate::request::{Stream, Subscription, Timeline}; diff --git a/src/response/redis/manager/err.rs b/src/response/redis/manager/err.rs index 3566457..b543520 100644 --- a/src/response/redis/manager/err.rs +++ b/src/response/redis/manager/err.rs @@ -1,7 +1,6 @@ use super::super::{RedisConnErr, RedisParseErr}; -use crate::err::TimelineErr; use crate::event::{Event, EventErr}; -use crate::request::Timeline; +use crate::request::{Timeline, TimelineErr}; use std::fmt; #[derive(Debug)]