Code reorganization [WIP]

This commit is contained in:
Daniel Sockwell 2020-04-12 22:01:55 -04:00
parent cf55eb7019
commit a2e879ec3a
12 changed files with 404 additions and 457 deletions

View File

@ -14,6 +14,7 @@ mod redis_cfg;
mod redis_cfg_types;
pub fn merge_dotenv() -> Result<(), err::FatalErr> {
// TODO -- should this allow the user to run in a dir without a `.env` file?
dotenv::from_filename(match env::var("ENV").ok().as_deref() {
Some("production") => ".env.production",
Some("development") | None => ".env",

View File

@ -92,7 +92,7 @@ impl fmt::Debug for Cors<'_> {
}
}
#[derive(EnumString, EnumVariantNames, Debug)]
#[derive(EnumString, EnumVariantNames, Debug, Clone)]
#[strum(serialize_all = "snake_case")]
pub enum LogLevelInner {
Trace,
@ -102,7 +102,7 @@ pub enum LogLevelInner {
Error,
}
#[derive(EnumString, EnumVariantNames, Debug)]
#[derive(EnumString, EnumVariantNames, Debug, Clone)]
#[strum(serialize_all = "snake_case")]
pub enum EnvInner {
Production,

View File

@ -1,6 +1,7 @@
use hashbrown::HashMap;
use std::fmt;
#[derive(Debug)]
pub struct EnvVar(pub HashMap<String, String>);
impl std::ops::Deref for EnvVar {
type Target = HashMap<String, String>;
@ -94,6 +95,7 @@ macro_rules! from_env_var {
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
let from_str = |$arg:ident| $body:expr;
) => {
#[derive(Clone)]
pub struct $name(pub $type);
impl std::fmt::Debug for $name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {

View File

@ -2,7 +2,7 @@ use super::{postgres_cfg_types::*, EnvVar};
use url::Url;
use urlencoding;
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Postgres {
pub user: PgUser,
pub host: PgHost,

View File

@ -49,7 +49,7 @@ from_env_var!(
let from_str = |s| PgSslInner::from_str(s).ok();
);
#[derive(EnumString, EnumVariantNames, Debug)]
#[derive(EnumString, EnumVariantNames, Debug, Clone)]
#[strum(serialize_all = "snake_case")]
pub enum PgSslInner {
Prefer,

View File

@ -1,7 +1,7 @@
use flodgatt::config;
use flodgatt::err::FatalErr;
use flodgatt::messages::Event;
use flodgatt::request::{PgPool, Subscription, Timeline};
use flodgatt::request::{self, Subscription, Timeline};
use flodgatt::response::redis;
use flodgatt::response::stream;
@ -27,14 +27,16 @@ fn main() -> Result<(), FatalErr> {
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
let shared_pg_conn = PgPool::new(postgres_cfg, *cfg.whitelist_mode);
let request_handler = request::Handler::new(postgres_cfg, *cfg.whitelist_mode);
let poll_freq = *redis_cfg.polling_interval;
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
// Server Sent Events
let sse_manager = shared_manager.clone();
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
let sse = Subscription::from_sse_request(shared_pg_conn.clone())
let sse = request_handler
.parse_sse_request()
.and(warp::sse())
.map(
move |subscription: Subscription, client_conn: warp::sse::Sse| {
@ -56,7 +58,8 @@ fn main() -> Result<(), FatalErr> {
// WebSocket
let ws_manager = shared_manager.clone();
let ws = Subscription::from_ws_request(shared_pg_conn)
let ws = request_handler
.parse_ws_request()
.and(warp::ws::ws2())
.map(move |subscription: Subscription, ws: Ws2| {
log::info!("Incoming websocket request for {:?}", subscription.timeline);

View File

@ -1,23 +1,53 @@
//! Parse the client request and return a Subscription
mod postgres;
mod query;
pub mod timeline;
mod subscription;
pub use self::postgres::PgPool;
// TODO consider whether we can remove `Stream` from public API
pub use subscription::{Blocks, Stream, Subscription, Timeline};
pub use subscription::{Content, Reach};
pub use subscription::{Blocks, Subscription};
pub use timeline::{Content, Reach, Stream, Timeline};
use self::query::Query;
use crate::config;
use warp::{filters::BoxedFilter, path, reject::Rejection, Filter};
use warp::{filters::BoxedFilter, path, Filter};
#[cfg(test)]
mod sse_test;
#[cfg(test)]
mod ws_test;
/// Helper macro to match on the first of any of the provided filters
/// (combines them left-to-right with `.or(...).unify()` and boxes the result
/// into a single `BoxedFilter`).
macro_rules! any_of {
    ($filter:expr, $($other_filter:expr),*) => {
        $filter$(.or($other_filter).unify())*.boxed()
    };
}
/// Helper macro: builds a warp filter for one SSE endpoint.
///
/// Matches the given path, collects the `auth`/`media`/`hashtag`/`list` query
/// parameters, and maps them into a `Query` whose `stream` field is the
/// hard-coded endpoint name supplied to the macro.
macro_rules! parse_sse_query {
    (path => $start:tt $(/ $next:tt)*
     endpoint => $endpoint:expr) => {
        path!($start $(/ $next)*)
            .and(query::Auth::to_filter())
            .and(query::Media::to_filter())
            .and(query::Hashtag::to_filter())
            .and(query::List::to_filter())
            .map(|auth: query::Auth, media: query::Media, hashtag: query::Hashtag, list: query::List| {
                Query {
                    access_token: auth.access_token,
                    stream: $endpoint.to_string(),
                    media: media.is_truthy(),
                    hashtag: hashtag.tag,
                    list: list.list,
                }
            },
        )
        .boxed()
    };
}
/// Parses incoming client requests (SSE or WebSocket) into `Subscription`s,
/// using the pooled Postgres connection to look up user data.
#[derive(Debug, Clone)]
pub struct Handler {
    pg_conn: PgPool,
}
@ -29,14 +59,47 @@ impl Handler {
}
}
pub fn from_ws_request(&self) -> BoxedFilter<(Subscription,)> {
pub fn parse_ws_request(&self) -> BoxedFilter<(Subscription,)> {
let pg_conn = self.pg_conn.clone();
parse_ws_query()
.and(query::OptionalAccessToken::from_ws_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_conn.clone()))
.and_then(move |q| Subscription::query_postgres(q, pg_conn.clone()))
.boxed()
}
/// Build the boxed warp filter that parses an SSE request into a `Subscription`.
///
/// Tries each streaming endpoint path in order (most specific first, so e.g.
/// `user/notification` is matched before `user`), collects the query
/// parameters into a `Query`, then resolves the user via Postgres.
pub fn parse_sse_request(&self) -> BoxedFilter<(Subscription,)> {
    let pg_conn = self.pg_conn.clone();
    any_of!(
        parse_sse_query!(
            path => "api" / "v1" / "streaming" / "user" / "notification"
            endpoint => "user:notification" ),
        parse_sse_query!(
            path => "api" / "v1" / "streaming" / "user"
            endpoint => "user"),
        parse_sse_query!(
            path => "api" / "v1" / "streaming" / "public" / "local"
            endpoint => "public:local"),
        parse_sse_query!(
            path => "api" / "v1" / "streaming" / "public"
            endpoint => "public"),
        parse_sse_query!(
            path => "api" / "v1" / "streaming" / "direct"
            endpoint => "direct"),
        parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
                         endpoint => "hashtag:local"),
        parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag"
                         endpoint => "hashtag"),
        parse_sse_query!(path => "api" / "v1" / "streaming" / "list"
                         endpoint => "list")
    )
    // because SSE requests place their `access_token` in the header instead of in a query
    // parameter, we need to update our Query if the header has a token
    .and(query::OptionalAccessToken::from_sse_header())
    .and_then(Query::update_access_token)
    .and_then(move |q| Subscription::query_postgres(q, pg_conn.clone()))
    .boxed()
}
}
fn parse_ws_query() -> BoxedFilter<(Query,)> {

View File

@ -1,12 +1,12 @@
//! Postgres queries
use crate::{
config,
messages::Id,
request::subscription::{Scope, UserData},
};
use crate::config;
use crate::messages::Id;
use crate::request::timeline::{Scope, UserData};
use ::postgres;
use hashbrown::HashSet;
use r2d2_postgres::PostgresConnectionManager;
use std::convert::TryFrom;
use warp::reject::Rejection;
#[derive(Clone, Debug)]
@ -14,6 +14,7 @@ pub struct PgPool {
pub conn: r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>,
whitelist_mode: bool,
}
impl PgPool {
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self {
let mut cfg = postgres::Config::new();
@ -40,15 +41,11 @@ impl PgPool {
let mut conn = self.conn.get().unwrap();
if let Some(token) = token {
let query_rows = conn
.query(
"
.query("
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1
AND oauth_access_tokens.revoked_at IS NULL
FROM oauth_access_tokens
INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&token.to_owned()],
)
@ -57,29 +54,20 @@ LIMIT 1",
let id = Id(result_columns.get(1));
let allowed_langs = result_columns
.try_get::<_, Vec<_>>(2)
.unwrap_or_else(|_| Vec::new())
.unwrap_or_default()
.into_iter()
.collect();
let mut scopes: HashSet<Scope> = result_columns
.get::<_, String>(3)
.split(' ')
.filter_map(|scope| match scope {
"read" => Some(Scope::Read),
"read:statuses" => Some(Scope::Statuses),
"read:notifications" => Some(Scope::Notifications),
"read:lists" => Some(Scope::Lists),
"write" | "follow" => None, // ignore write scopes
unexpected => {
log::warn!("Ignoring unknown scope `{}`", unexpected);
None
}
})
.filter_map(|scope| Scope::try_from(scope).ok())
.collect();
// We don't need to separately track read auth - it's just all three others
if scopes.remove(&Scope::Read) {
scopes.insert(Scope::Statuses);
scopes.insert(Scope::Notifications);
scopes.insert(Scope::Lists);
if scopes.contains(&Scope::Read) {
scopes = vec![Scope::Statuses, Scope::Notifications, Scope::Lists]
.into_iter()
.collect()
}
Ok(UserData {
@ -98,19 +86,10 @@ LIMIT 1",
}
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> {
let mut conn = self.conn.get().unwrap();
let rows = &conn
.query(
"
SELECT id
FROM tags
WHERE name = $1
LIMIT 1",
&[&tag_name],
)
.expect("Hard-coded query will return Some([0 or more rows])");
rows.get(0)
let mut conn = self.conn.get().expect("TODO");
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
.expect("Hard-coded query will return Some([0 or more rows])")
.get(0)
.map(|row| row.get(0))
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
}
@ -120,43 +99,31 @@ LIMIT 1",
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
self.conn
.get()
.unwrap()
.query(
"
SELECT target_account_id
FROM blocks
WHERE account_id = $1
UNION SELECT target_account_id
FROM mutes
WHERE account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| Id(row.get(0)))
.collect()
let mut conn = self.conn.get().expect("TODO");
conn.query(
"SELECT target_account_id FROM blocks WHERE account_id = $1
UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| Id(row.get(0)))
.collect()
}
/// Query Postgres for everyone who has blocked the user
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
self.conn
.get()
.unwrap()
.query(
"
SELECT account_id
FROM blocks
WHERE target_account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| Id(row.get(0)))
.collect()
let mut conn = self.conn.get().expect("TODO");
conn.query(
"SELECT account_id FROM blocks WHERE target_account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| Id(row.get(0)))
.collect()
}
/// Query Postgres for all current domain blocks
@ -164,37 +131,27 @@ SELECT account_id
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
self.conn
.get()
.unwrap()
.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
let mut conn = self.conn.get().expect("TODO");
conn.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&*user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))
.collect()
}
/// Test whether a user owns a list
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
let mut conn = self.conn.get().unwrap();
let mut conn = self.conn.get().expect("TODO");
// For the Postgres query, `id` = list number; `account_id` = user.id
let rows = &conn
.query(
"
SELECT id, account_id
FROM lists
WHERE id = $1
LIMIT 1",
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
&[&list_id],
)
.expect("Hard-coded query will return Some([0 or more rows])");
match rows.get(0) {
None => false,
Some(row) => Id(row.get(1)) == user_id,
}
rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id)
}
}

View File

@ -6,47 +6,13 @@
// #[cfg(not(test))]
use super::postgres::PgPool;
use super::query;
use super::query::Query;
use crate::err::TimelineErr;
use super::{Content, Reach, Stream, Timeline};
use crate::messages::Id;
use hashbrown::HashSet;
use lru::LruCache;
use warp::{filters::BoxedFilter, path, reject::Rejection, Filter};
/// Helper macro to match on the first of any of the provided filters
macro_rules! any_of {
($filter:expr, $($other_filter:expr),*) => {
$filter$(.or($other_filter).unify())*.boxed()
};
}
macro_rules! parse_sse_query {
(path => $start:tt $(/ $next:tt)*
endpoint => $endpoint:expr) => {
path!($start $(/ $next)*)
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: $endpoint.to_string(),
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
};
}
use warp::reject::Rejection;
#[derive(Clone, Debug, PartialEq)]
pub struct Subscription {
@ -77,49 +43,24 @@ impl Default for Subscription {
}
impl Subscription {
pub fn from_ws_request(pg_pool: PgPool) -> BoxedFilter<(Subscription,)> {
parse_ws_query()
.and(query::OptionalAccessToken::from_ws_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone()))
.boxed()
}
pub fn from_sse_request(pg_pool: PgPool) -> BoxedFilter<(Subscription,)> {
any_of!(
parse_sse_query!(
path => "api" / "v1" / "streaming" / "user" / "notification"
endpoint => "user:notification" ),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "user"
endpoint => "user"),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "public" / "local"
endpoint => "public:local"),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "public"
endpoint => "public"),
parse_sse_query!(
path => "api" / "v1" / "streaming" / "direct"
endpoint => "direct"),
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
endpoint => "hashtag:local"),
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag"
endpoint => "hashtag"),
parse_sse_query!(path => "api" / "v1" / "streaming" / "list"
endpoint => "list")
)
// because SSE requests place their `access_token` in the header instead of in a query
// parameter, we need to update our Query if the header has a token
.and(query::OptionalAccessToken::from_sse_header())
.and_then(Query::update_access_token)
.and_then(move |q| Subscription::from_query(q, pg_pool.clone()))
.boxed()
}
pub(super) fn from_query(q: Query, pool: PgPool) -> Result<Self, Rejection> {
pub(super) fn query_postgres(q: Query, pool: PgPool) -> Result<Self, Rejection> {
let user = pool.clone().select_user(&q.access_token)?;
let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?;
let timeline = {
let tl = Timeline::from_query_and_user(&q, &user)?;
let pool = pool.clone();
use Stream::*;
match tl {
Timeline(Hashtag(_), reach, stream) => {
let tag = pool.select_hashtag_id(&q.hashtag)?;
Timeline(Hashtag(tag), reach, stream)
}
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => {
Err(warp::reject::custom("Error: Missing access token"))?
}
other_tl => other_tl,
}
};
let hashtag_name = match timeline {
Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag),
_non_hashtag_timeline => None,
@ -138,179 +79,3 @@ impl Subscription {
})
}
}
fn parse_ws_query() -> BoxedFilter<(Query,)> {
path!("api" / "v1" / "streaming")
.and(path::end())
.and(warp::query())
.and(query::Auth::to_filter())
.and(query::Media::to_filter())
.and(query::Hashtag::to_filter())
.and(query::List::to_filter())
.map(
|stream: query::Stream,
auth: query::Auth,
media: query::Media,
hashtag: query::Hashtag,
list: query::List| {
Query {
access_token: auth.access_token,
stream: stream.stream,
media: media.is_truthy(),
hashtag: hashtag.tag,
list: list.list,
}
},
)
.boxed()
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub struct Timeline(pub Stream, pub Reach, pub Content);
impl Timeline {
pub fn empty() -> Self {
use {Content::*, Reach::*, Stream::*};
Self(Unset, Local, Notification)
}
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String, TimelineErr> {
use {Content::*, Reach::*, Stream::*};
Ok(match self {
Timeline(Public, Federated, All) => "timeline:public".into(),
Timeline(Public, Local, All) => "timeline:public:local".into(),
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
// TODO -- would `.push_str` be faster here?
Timeline(Hashtag(_id), Federated, All) => format!(
"timeline:hashtag:{}",
hashtag.ok_or_else(|| TimelineErr::MissingHashtag)?
),
Timeline(Hashtag(_id), Local, All) => format!(
"timeline:hashtag:{}:local",
hashtag.ok_or_else(|| TimelineErr::MissingHashtag)?
),
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
})
}
pub fn from_redis_text(
timeline: &str,
cache: &mut LruCache<String, i64>,
) -> Result<Self, TimelineErr> {
let mut id_from_tag = |tag: &str| match cache.get(&tag.to_string()) {
Some(id) => Ok(*id),
None => Err(TimelineErr::InvalidInput), // TODO more specific
};
use {Content::*, Reach::*, Stream::*};
Ok(match &timeline.split(':').collect::<Vec<&str>>()[..] {
["public"] => Timeline(Public, Federated, All),
["public", "local"] => Timeline(Public, Local, All),
["public", "media"] => Timeline(Public, Federated, Media),
["public", "local", "media"] => Timeline(Public, Local, Media),
["hashtag", tag] => Timeline(Hashtag(id_from_tag(tag)?), Federated, All),
["hashtag", tag, "local"] => Timeline(Hashtag(id_from_tag(tag)?), Local, All),
[id] => Timeline(User(id.parse()?), Federated, All),
[id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
["list", id] => Timeline(List(id.parse()?), Federated, All),
["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
// Other endpoints don't exist:
[..] => Err(TimelineErr::InvalidInput)?,
})
}
fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result<Self, Rejection> {
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag);
let user_owns_list = || pool.clone().user_owns_list(user.id, q.list);
Ok(match q.stream.as_ref() {
"public" => match q.media {
true => Timeline(Public, Federated, Media),
false => Timeline(Public, Federated, All),
},
"public:local" => match q.media {
true => Timeline(Public, Local, Media),
false => Timeline(Public, Local, All),
},
"public:media" => Timeline(Public, Federated, Media),
"public:local:media" => Timeline(Public, Local, Media),
"hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All),
"hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All),
"user" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
"user:notification" => match user.scopes.contains(&Statuses) {
true => Timeline(User(user.id), Federated, Notification),
false => Err(custom("Error: Missing access token"))?,
},
"list" => match user.scopes.contains(&Lists) && user_owns_list() {
true => Timeline(List(q.list), Federated, All),
false => Err(warp::reject::custom("Error: Missing access token"))?,
},
"direct" => match user.scopes.contains(&Statuses) {
true => Timeline(Direct(*user.id), Federated, All),
false => Err(custom("Error: Missing access token"))?,
},
other => {
log::warn!("Request for nonexistent endpoint: `{}`", other);
Err(custom("Error: Nonexistent endpoint"))?
}
})
}
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Stream {
User(Id),
// TODO consider whether List, Direct, and Hashtag should all be `id::Id`s
List(i64),
Direct(i64),
Hashtag(i64),
Public,
Unset,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Reach {
Local,
Federated,
}
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Content {
All,
Media,
Notification,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
Read,
Statuses,
Notifications,
Lists,
}
pub struct UserData {
pub id: Id,
pub allowed_langs: HashSet<String>,
pub scopes: HashSet<Scope>,
}
impl UserData {
pub fn public() -> Self {
Self {
id: Id(-1),
allowed_langs: HashSet::new(),
scopes: HashSet::new(),
}
}
}

174
src/request/timeline.rs Normal file
View File

@ -0,0 +1,174 @@
use super::query::Query;
use crate::err::TimelineErr;
use crate::messages::Id;
use hashbrown::HashSet;
use lru::LruCache;
use std::convert::TryFrom;
use warp::reject::Rejection;
/// A timeline a client can subscribe to, as a (stream, reach, content) triple.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub struct Timeline(pub Stream, pub Reach, pub Content);
impl Timeline {
    /// The placeholder timeline used before any real timeline is known
    /// (e.g. as the initial value of the event channel).
    pub fn empty() -> Self {
        use {Content::*, Reach::*, Stream::*};
        Self(Unset, Local, Notification)
    }

    /// Convert this timeline to the raw `timeline:…` string Redis uses as a
    /// channel name.
    ///
    /// Hashtag timelines are keyed in Redis by tag *name*, not id, so the
    /// caller must supply the name; returns `Err(TimelineErr::MissingHashtag)`
    /// when it is needed but absent.  Timeline combinations that have no Redis
    /// channel return `Err(TimelineErr::InvalidInput)`.
    pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String, TimelineErr> {
        use {Content::*, Reach::*, Stream::*};
        Ok(match self {
            Timeline(Public, Federated, All) => "timeline:public".into(),
            Timeline(Public, Local, All) => "timeline:public:local".into(),
            Timeline(Public, Federated, Media) => "timeline:public:media".into(),
            Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
            // TODO -- would `.push_str` be faster here?
            Timeline(Hashtag(_id), Federated, All) => format!(
                "timeline:hashtag:{}",
                hashtag.ok_or(TimelineErr::MissingHashtag)?
            ),
            Timeline(Hashtag(_id), Local, All) => format!(
                "timeline:hashtag:{}:local",
                hashtag.ok_or(TimelineErr::MissingHashtag)?
            ),
            Timeline(User(id), Federated, All) => format!("timeline:{}", id),
            Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
            Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
            Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
            Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
        })
    }

    /// Parse a raw Redis `timeline:…` suffix back into a `Timeline`.
    ///
    /// Hashtag ids are resolved through the supplied name→id cache; a tag that
    /// is not cached is an error (TODO: a more specific error than
    /// `InvalidInput`).
    pub fn from_redis_text(
        timeline: &str,
        cache: &mut LruCache<String, i64>,
    ) -> Result<Self, TimelineErr> {
        let mut id_from_tag = |tag: &str| match cache.get(&tag.to_string()) {
            Some(id) => Ok(*id),
            None => Err(TimelineErr::InvalidInput), // TODO more specific
        };

        use {Content::*, Reach::*, Stream::*};
        Ok(match &timeline.split(':').collect::<Vec<&str>>()[..] {
            ["public"] => Timeline(Public, Federated, All),
            ["public", "local"] => Timeline(Public, Local, All),
            ["public", "media"] => Timeline(Public, Federated, Media),
            ["public", "local", "media"] => Timeline(Public, Local, Media),
            ["hashtag", tag] => Timeline(Hashtag(id_from_tag(tag)?), Federated, All),
            ["hashtag", tag, "local"] => Timeline(Hashtag(id_from_tag(tag)?), Local, All),
            [id] => Timeline(User(id.parse()?), Federated, All),
            [id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
            ["list", id] => Timeline(List(id.parse()?), Federated, All),
            ["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
            // Other endpoints don't exist:
            [..] => Err(TimelineErr::InvalidInput)?,
        })
    }

    /// Build a `Timeline` from a parsed client `Query`, enforcing the scopes
    /// the user's token grants.  Rejects the request when a required scope is
    /// missing or the endpoint does not exist.
    pub fn from_query_and_user(q: &Query, user: &UserData) -> Result<Self, Rejection> {
        use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
        Ok(match q.stream.as_ref() {
            "public" if q.media => Timeline(Public, Federated, Media),
            "public" => Timeline(Public, Federated, All),
            "public:local" if q.media => Timeline(Public, Local, Media),
            "public:local" => Timeline(Public, Local, All),
            "public:media" => Timeline(Public, Federated, Media),
            "public:local:media" => Timeline(Public, Local, Media),
            // `0` is a placeholder tag id; the caller replaces it with the id
            // looked up in Postgres (see `Subscription::query_postgres`).
            "hashtag" => Timeline(Hashtag(0), Federated, All),
            "hashtag:local" => Timeline(Hashtag(0), Local, All),
            "user" if user.scopes.contains(&Statuses) => Timeline(User(user.id), Federated, All),
            "user:notification" if user.scopes.contains(&Statuses) => {
                Timeline(User(user.id), Federated, Notification)
            }
            "list" if user.scopes.contains(&Lists) => Timeline(List(q.list), Federated, All),
            "direct" if user.scopes.contains(&Statuses) => {
                Timeline(Direct(*user.id), Federated, All)
            }
            // Known endpoints reached without the required scope:
            "user" | "user:notification" | "list" | "direct" => {
                Err(custom("Error: Missing access token"))?
            }
            other => {
                log::warn!("Request for nonexistent endpoint: `{}`", other);
                Err(custom("Error: Nonexistent endpoint"))?
            }
        })
    }
}
/// Which stream of events a timeline refers to.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Stream {
    /// A single user's stream, keyed by that user's id
    User(Id),
    // TODO consider whether List, Direct, and Hashtag should all be `id::Id`s
    List(i64),
    Direct(i64),
    Hashtag(i64),
    Public,
    /// Placeholder value used by `Timeline::empty()`
    Unset,
}
/// Whether a timeline covers only the local instance or the whole fediverse.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Reach {
    Local,
    Federated,
}
/// Which kinds of events a timeline carries.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Content {
    All,
    Media,
    Notification,
}
/// The OAuth read scopes a token can grant.
///
/// `Read` is the umbrella scope; when user data is loaded it is expanded into
/// the three narrower scopes, so downstream checks only test those.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    Read,
    Statuses,
    Notifications,
    Lists,
}
impl TryFrom<&str> for Scope {
    type Error = TimelineErr;

    /// Parse a single OAuth scope string into a `Scope`.
    ///
    /// Write-style scopes (`write`, `follow`) are deliberately rejected, and
    /// any unrecognized scope is rejected after logging a warning.
    fn try_from(s: &str) -> Result<Self, Self::Error> {
        let scope = match s {
            "read" => Scope::Read,
            "read:statuses" => Scope::Statuses,
            "read:notifications" => Scope::Notifications,
            "read:lists" => Scope::Lists,
            "write" | "follow" => return Err(TimelineErr::InvalidInput), // ignore write scopes
            unknown => {
                log::warn!("Ignoring unknown scope `{}`", unknown);
                return Err(TimelineErr::InvalidInput);
            }
        };
        Ok(scope)
    }
}
/// Data about the requesting user, loaded from Postgres when they connect.
pub struct UserData {
    pub id: Id,
    // From the user's `chosen_languages` setting
    pub allowed_langs: HashSet<String>,
    pub scopes: HashSet<Scope>,
}
impl UserData {
    /// The `UserData` used for unauthenticated (public) requests: a sentinel
    /// id of `-1` with no language filters and no granted scopes.
    pub fn public() -> Self {
        let (allowed_langs, scopes) = (HashSet::new(), HashSet::new());
        Self {
            id: Id(-1),
            allowed_langs,
            scopes,
        }
    }
}

View File

@ -3,22 +3,16 @@ pub use err::RedisConnErr;
use super::msg::{RedisParseErr, RedisParseOutput};
use super::ManagerErr;
use crate::{
config::Redis,
messages::Event,
request::{Stream, Timeline},
};
use std::{
convert::{TryFrom, TryInto},
io::{Read, Write},
net::TcpStream,
str,
time::Duration,
};
use crate::config::Redis;
use crate::messages::Event;
use crate::request::{Stream, Timeline};
use futures::{Async, Poll};
use lru::LruCache;
use std::convert::{TryFrom, TryInto};
use std::io::{Read, Write};
use std::net::TcpStream;
use std::str;
use std::time::Duration;
type Result<T> = std::result::Result<T, RedisConnErr>;
@ -46,7 +40,7 @@ impl RedisConn {
// TODO: eventually, it might make sense to have Mastodon publish to timelines with
// the tag number instead of the tag name. This would save us from dealing
// with a cache here and would be consistent with how lists/users are handled.
redis_namespace: redis_cfg.namespace.clone(),
redis_namespace: redis_cfg.namespace.clone().0,
redis_input: Vec::new(),
};
Ok(redis_conn)
@ -61,14 +55,12 @@ impl RedisConn {
self.redis_input.extend_from_slice(&buffer[..n]);
break;
}
Ok(n) => {
self.redis_input.extend_from_slice(&buffer[..n]);
}
Ok(n) => self.redis_input.extend_from_slice(&buffer[..n]),
Err(_) => break,
};
if first_read {
size = 2000;
buffer = vec![0u8; size];
buffer = vec![0_u8; size];
first_read = false;
}
}
@ -117,50 +109,6 @@ impl RedisConn {
self.tag_name_cache.put(id, hashtag);
}
fn new_connection(addr: &str, pass: Option<&String>) -> Result<TcpStream> {
match TcpStream::connect(&addr) {
Ok(mut conn) => {
if let Some(password) = pass {
Self::auth_connection(&mut conn, &addr, password)?;
}
Self::validate_connection(&mut conn, &addr)?;
conn.set_read_timeout(Some(Duration::from_millis(10)))
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
Ok(conn)
}
Err(e) => Err(RedisConnErr::with_addr(&addr, e)),
}
}
fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> {
conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes())
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
let mut buffer = vec![0u8; 5];
conn.read_exact(&mut buffer)
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
let reply = String::from_utf8_lossy(&buffer);
match &*reply {
"+OK\r\n" => (),
_ => Err(RedisConnErr::IncorrectPassword(pass.to_string()))?,
};
Ok(())
}
fn validate_connection(conn: &mut TcpStream, addr: &str) -> Result<()> {
conn.write_all(b"PING\r\n")
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
let mut buffer = vec![0u8; 7];
conn.read_exact(&mut buffer)
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
let reply = String::from_utf8_lossy(&buffer);
match &*reply {
"+PONG\r\n" => Ok(()),
"-NOAUTH" => Err(RedisConnErr::MissingPassword),
"HTTP/1." => Err(RedisConnErr::NotRedis(addr.to_string())),
_ => Err(RedisConnErr::InvalidRedisReply(reply.to_string())),
}
}
pub fn send_cmd(&mut self, cmd: RedisCmd, timeline: &Timeline) -> Result<()> {
let hashtag = match timeline {
Timeline(Stream::Hashtag(id), _, _) => self.tag_name_cache.get(id),
@ -182,6 +130,44 @@ impl RedisConn {
self.secondary.write_all(&secondary_cmd.as_bytes())?;
Ok(())
}
/// Open, authenticate (if a password is supplied), and validate a TCP
/// connection to Redis, then set a short read timeout for non-blocking polls.
fn new_connection(addr: &str, pass: Option<&String>) -> Result<TcpStream> {
    // Attach the address to the connect error, consistent with every other
    // error path in this function.
    let mut conn = TcpStream::connect(&addr).map_err(|e| RedisConnErr::with_addr(&addr, e))?;
    if let Some(password) = pass {
        Self::auth_connection(&mut conn, &addr, password)?;
    }
    Self::validate_connection(&mut conn, &addr)?;
    conn.set_read_timeout(Some(Duration::from_millis(10)))
        .map_err(|e| RedisConnErr::with_addr(&addr, e))?;
    Ok(conn)
}
/// Send a RESP `AUTH <password>` command and confirm Redis replies `+OK`.
fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> {
    let auth_cmd = format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass);
    // `as_bytes()` already yields `&[u8]`; no extra borrow needed.
    conn.write_all(auth_cmd.as_bytes())
        .map_err(|e| RedisConnErr::with_addr(&addr, e))?;
    // "+OK\r\n" is exactly 5 bytes.
    let mut buffer = vec![0_u8; 5];
    conn.read_exact(&mut buffer)
        .map_err(|e| RedisConnErr::with_addr(&addr, e))?;
    if String::from_utf8_lossy(&buffer) != "+OK\r\n" {
        return Err(RedisConnErr::IncorrectPassword(pass.to_string()));
    }
    Ok(())
}
/// Send `PING` and inspect the reply to confirm the peer really is Redis.
fn validate_connection(conn: &mut TcpStream, addr: &str) -> Result<()> {
    conn.write_all(b"PING\r\n")
        .map_err(|e| RedisConnErr::with_addr(&addr, e))?;
    // Read exactly 7 bytes -- enough to distinguish all the replies matched
    // below ("+PONG\r\n" is 7 bytes; "-NOAUTH" and "HTTP/1." are 7-byte prefixes).
    let mut buffer = vec![0u8; 7];
    conn.read_exact(&mut buffer)
        .map_err(|e| RedisConnErr::with_addr(&addr, e))?;
    let reply = String::from_utf8_lossy(&buffer);
    match &*reply {
        "+PONG\r\n" => Ok(()),
        // Redis requires a password but none was sent
        "-NOAUTH" => Err(RedisConnErr::MissingPassword),
        // An HTTP server answered -- the address does not point at Redis
        "HTTP/1." => Err(RedisConnErr::NotRedis(addr.to_string())),
        _ => Err(RedisConnErr::InvalidRedisReply(reply.to_string())),
    }
}
}
pub enum RedisCmd {

View File

@ -132,43 +132,38 @@ impl Sse {
let blocks = subscription.blocks;
let event_stream = sse_rx
.filter(move |(timeline, _)| target_timeline == *timeline)
.filter_map(move |(timeline, event)| {
if target_timeline == timeline {
use crate::messages::{
CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind,
};
use crate::messages::{
CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind,
};
use crate::request::Stream::Public;
match event {
TypeSafe(Update { payload, queued_at }) => match timeline {
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
_ if payload.involves_any(&blocks) => None,
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
payload,
queued_at,
})),
},
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind {
match timeline {
Timeline(Public, _, _) if s.language_not(&allowed_langs) => {
None
}
_ if s.involves_any(&blocks) => None,
_ => Self::reply_with(Dynamic(DynEvent {
kind: EventKind::Update(s),
..dyn_event
})),
}
} else {
None
use crate::request::Stream::Public;
match event {
TypeSafe(Update { payload, queued_at }) => match timeline {
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
_ if payload.involves_any(&blocks) => None,
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
payload,
queued_at,
})),
},
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
Dynamic(dyn_event) => {
if let EventKind::Update(s) = dyn_event.kind {
match timeline {
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
_ if s.involves_any(&blocks) => None,
_ => Self::reply_with(Dynamic(DynEvent {
kind: EventKind::Update(s),
..dyn_event
})),
}
} else {
None
}
Ping => None, // pings handled automatically
}
} else {
None
Ping => None, // pings handled automatically
}
})
.then(move |res| {
@ -186,3 +181,4 @@ impl Sse {
)
}
}
// TODO -- split WS and SSE into separate files and add misc stuff from main.rs here