mirror of https://github.com/mastodon/flodgatt
Code reorganization [WIP]
This commit is contained in:
parent
cf55eb7019
commit
a2e879ec3a
|
@ -14,6 +14,7 @@ mod redis_cfg;
|
||||||
mod redis_cfg_types;
|
mod redis_cfg_types;
|
||||||
|
|
||||||
pub fn merge_dotenv() -> Result<(), err::FatalErr> {
|
pub fn merge_dotenv() -> Result<(), err::FatalErr> {
|
||||||
|
// TODO -- should this allow the user to run in a dir without a `.env` file?
|
||||||
dotenv::from_filename(match env::var("ENV").ok().as_deref() {
|
dotenv::from_filename(match env::var("ENV").ok().as_deref() {
|
||||||
Some("production") => ".env.production",
|
Some("production") => ".env.production",
|
||||||
Some("development") | None => ".env",
|
Some("development") | None => ".env",
|
||||||
|
|
|
@ -92,7 +92,7 @@ impl fmt::Debug for Cors<'_> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
#[derive(EnumString, EnumVariantNames, Debug, Clone)]
|
||||||
#[strum(serialize_all = "snake_case")]
|
#[strum(serialize_all = "snake_case")]
|
||||||
pub enum LogLevelInner {
|
pub enum LogLevelInner {
|
||||||
Trace,
|
Trace,
|
||||||
|
@ -102,7 +102,7 @@ pub enum LogLevelInner {
|
||||||
Error,
|
Error,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
#[derive(EnumString, EnumVariantNames, Debug, Clone)]
|
||||||
#[strum(serialize_all = "snake_case")]
|
#[strum(serialize_all = "snake_case")]
|
||||||
pub enum EnvInner {
|
pub enum EnvInner {
|
||||||
Production,
|
Production,
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use hashbrown::HashMap;
|
use hashbrown::HashMap;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct EnvVar(pub HashMap<String, String>);
|
pub struct EnvVar(pub HashMap<String, String>);
|
||||||
impl std::ops::Deref for EnvVar {
|
impl std::ops::Deref for EnvVar {
|
||||||
type Target = HashMap<String, String>;
|
type Target = HashMap<String, String>;
|
||||||
|
@ -94,6 +95,7 @@ macro_rules! from_env_var {
|
||||||
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
|
let (env_var, allowed_values) = ($env_var:tt, $allowed_values:expr);
|
||||||
let from_str = |$arg:ident| $body:expr;
|
let from_str = |$arg:ident| $body:expr;
|
||||||
) => {
|
) => {
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct $name(pub $type);
|
pub struct $name(pub $type);
|
||||||
impl std::fmt::Debug for $name {
|
impl std::fmt::Debug for $name {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
|
|
@ -2,7 +2,7 @@ use super::{postgres_cfg_types::*, EnvVar};
|
||||||
use url::Url;
|
use url::Url;
|
||||||
use urlencoding;
|
use urlencoding;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Postgres {
|
pub struct Postgres {
|
||||||
pub user: PgUser,
|
pub user: PgUser,
|
||||||
pub host: PgHost,
|
pub host: PgHost,
|
||||||
|
|
|
@ -49,7 +49,7 @@ from_env_var!(
|
||||||
let from_str = |s| PgSslInner::from_str(s).ok();
|
let from_str = |s| PgSslInner::from_str(s).ok();
|
||||||
);
|
);
|
||||||
|
|
||||||
#[derive(EnumString, EnumVariantNames, Debug)]
|
#[derive(EnumString, EnumVariantNames, Debug, Clone)]
|
||||||
#[strum(serialize_all = "snake_case")]
|
#[strum(serialize_all = "snake_case")]
|
||||||
pub enum PgSslInner {
|
pub enum PgSslInner {
|
||||||
Prefer,
|
Prefer,
|
||||||
|
|
11
src/main.rs
11
src/main.rs
|
@ -1,7 +1,7 @@
|
||||||
use flodgatt::config;
|
use flodgatt::config;
|
||||||
use flodgatt::err::FatalErr;
|
use flodgatt::err::FatalErr;
|
||||||
use flodgatt::messages::Event;
|
use flodgatt::messages::Event;
|
||||||
use flodgatt::request::{PgPool, Subscription, Timeline};
|
use flodgatt::request::{self, Subscription, Timeline};
|
||||||
use flodgatt::response::redis;
|
use flodgatt::response::redis;
|
||||||
use flodgatt::response::stream;
|
use flodgatt::response::stream;
|
||||||
|
|
||||||
|
@ -27,14 +27,16 @@ fn main() -> Result<(), FatalErr> {
|
||||||
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
|
let (event_tx, event_rx) = watch::channel((Timeline::empty(), Event::Ping));
|
||||||
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
|
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
let shared_pg_conn = PgPool::new(postgres_cfg, *cfg.whitelist_mode);
|
let request_handler = request::Handler::new(postgres_cfg, *cfg.whitelist_mode);
|
||||||
let poll_freq = *redis_cfg.polling_interval;
|
let poll_freq = *redis_cfg.polling_interval;
|
||||||
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
|
let shared_manager = redis::Manager::try_from(redis_cfg, event_tx, cmd_rx)?.into_arc();
|
||||||
|
|
||||||
// Server Sent Events
|
// Server Sent Events
|
||||||
let sse_manager = shared_manager.clone();
|
let sse_manager = shared_manager.clone();
|
||||||
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
|
let (sse_rx, sse_cmd_tx) = (event_rx.clone(), cmd_tx.clone());
|
||||||
let sse = Subscription::from_sse_request(shared_pg_conn.clone())
|
|
||||||
|
let sse = request_handler
|
||||||
|
.parse_sse_request()
|
||||||
.and(warp::sse())
|
.and(warp::sse())
|
||||||
.map(
|
.map(
|
||||||
move |subscription: Subscription, client_conn: warp::sse::Sse| {
|
move |subscription: Subscription, client_conn: warp::sse::Sse| {
|
||||||
|
@ -56,7 +58,8 @@ fn main() -> Result<(), FatalErr> {
|
||||||
|
|
||||||
// WebSocket
|
// WebSocket
|
||||||
let ws_manager = shared_manager.clone();
|
let ws_manager = shared_manager.clone();
|
||||||
let ws = Subscription::from_ws_request(shared_pg_conn)
|
let ws = request_handler
|
||||||
|
.parse_ws_request()
|
||||||
.and(warp::ws::ws2())
|
.and(warp::ws::ws2())
|
||||||
.map(move |subscription: Subscription, ws: Ws2| {
|
.map(move |subscription: Subscription, ws: Ws2| {
|
||||||
log::info!("Incoming websocket request for {:?}", subscription.timeline);
|
log::info!("Incoming websocket request for {:?}", subscription.timeline);
|
||||||
|
|
|
@ -1,23 +1,53 @@
|
||||||
//! Parse the client request and return a Subscription
|
//! Parse the client request and return a Subscription
|
||||||
mod postgres;
|
mod postgres;
|
||||||
mod query;
|
mod query;
|
||||||
|
pub mod timeline;
|
||||||
|
|
||||||
mod subscription;
|
mod subscription;
|
||||||
|
|
||||||
pub use self::postgres::PgPool;
|
pub use self::postgres::PgPool;
|
||||||
// TODO consider whether we can remove `Stream` from public API
|
// TODO consider whether we can remove `Stream` from public API
|
||||||
pub use subscription::{Blocks, Stream, Subscription, Timeline};
|
pub use subscription::{Blocks, Subscription};
|
||||||
pub use subscription::{Content, Reach};
|
pub use timeline::{Content, Reach, Stream, Timeline};
|
||||||
|
|
||||||
use self::query::Query;
|
use self::query::Query;
|
||||||
use crate::config;
|
use crate::config;
|
||||||
use warp::{filters::BoxedFilter, path, reject::Rejection, Filter};
|
use warp::{filters::BoxedFilter, path, Filter};
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod sse_test;
|
mod sse_test;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod ws_test;
|
mod ws_test;
|
||||||
|
|
||||||
|
/// Helper macro to match on the first of any of the provided filters
|
||||||
|
macro_rules! any_of {
|
||||||
|
($filter:expr, $($other_filter:expr),*) => {
|
||||||
|
$filter$(.or($other_filter).unify())*.boxed()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
macro_rules! parse_sse_query {
|
||||||
|
(path => $start:tt $(/ $next:tt)*
|
||||||
|
endpoint => $endpoint:expr) => {
|
||||||
|
path!($start $(/ $next)*)
|
||||||
|
.and(query::Auth::to_filter())
|
||||||
|
.and(query::Media::to_filter())
|
||||||
|
.and(query::Hashtag::to_filter())
|
||||||
|
.and(query::List::to_filter())
|
||||||
|
.map(|auth: query::Auth, media: query::Media, hashtag: query::Hashtag, list: query::List| {
|
||||||
|
Query {
|
||||||
|
access_token: auth.access_token,
|
||||||
|
stream: $endpoint.to_string(),
|
||||||
|
media: media.is_truthy(),
|
||||||
|
hashtag: hashtag.tag,
|
||||||
|
list: list.list,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.boxed()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Handler {
|
pub struct Handler {
|
||||||
pg_conn: PgPool,
|
pg_conn: PgPool,
|
||||||
}
|
}
|
||||||
|
@ -29,14 +59,47 @@ impl Handler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_ws_request(&self) -> BoxedFilter<(Subscription,)> {
|
pub fn parse_ws_request(&self) -> BoxedFilter<(Subscription,)> {
|
||||||
let pg_conn = self.pg_conn.clone();
|
let pg_conn = self.pg_conn.clone();
|
||||||
parse_ws_query()
|
parse_ws_query()
|
||||||
.and(query::OptionalAccessToken::from_ws_header())
|
.and(query::OptionalAccessToken::from_ws_header())
|
||||||
.and_then(Query::update_access_token)
|
.and_then(Query::update_access_token)
|
||||||
.and_then(move |q| Subscription::from_query(q, pg_conn.clone()))
|
.and_then(move |q| Subscription::query_postgres(q, pg_conn.clone()))
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn parse_sse_request(&self) -> BoxedFilter<(Subscription,)> {
|
||||||
|
let pg_conn = self.pg_conn.clone();
|
||||||
|
any_of!(
|
||||||
|
parse_sse_query!(
|
||||||
|
path => "api" / "v1" / "streaming" / "user" / "notification"
|
||||||
|
endpoint => "user:notification" ),
|
||||||
|
parse_sse_query!(
|
||||||
|
path => "api" / "v1" / "streaming" / "user"
|
||||||
|
endpoint => "user"),
|
||||||
|
parse_sse_query!(
|
||||||
|
path => "api" / "v1" / "streaming" / "public" / "local"
|
||||||
|
endpoint => "public:local"),
|
||||||
|
parse_sse_query!(
|
||||||
|
path => "api" / "v1" / "streaming" / "public"
|
||||||
|
endpoint => "public"),
|
||||||
|
parse_sse_query!(
|
||||||
|
path => "api" / "v1" / "streaming" / "direct"
|
||||||
|
endpoint => "direct"),
|
||||||
|
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
|
||||||
|
endpoint => "hashtag:local"),
|
||||||
|
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag"
|
||||||
|
endpoint => "hashtag"),
|
||||||
|
parse_sse_query!(path => "api" / "v1" / "streaming" / "list"
|
||||||
|
endpoint => "list")
|
||||||
|
)
|
||||||
|
// because SSE requests place their `access_token` in the header instead of in a query
|
||||||
|
// parameter, we need to update our Query if the header has a token
|
||||||
|
.and(query::OptionalAccessToken::from_sse_header())
|
||||||
|
.and_then(Query::update_access_token)
|
||||||
|
.and_then(move |q| Subscription::query_postgres(q, pg_conn.clone()))
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_ws_query() -> BoxedFilter<(Query,)> {
|
fn parse_ws_query() -> BoxedFilter<(Query,)> {
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
//! Postgres queries
|
//! Postgres queries
|
||||||
use crate::{
|
use crate::config;
|
||||||
config,
|
use crate::messages::Id;
|
||||||
messages::Id,
|
use crate::request::timeline::{Scope, UserData};
|
||||||
request::subscription::{Scope, UserData},
|
|
||||||
};
|
|
||||||
use ::postgres;
|
use ::postgres;
|
||||||
use hashbrown::HashSet;
|
use hashbrown::HashSet;
|
||||||
use r2d2_postgres::PostgresConnectionManager;
|
use r2d2_postgres::PostgresConnectionManager;
|
||||||
|
use std::convert::TryFrom;
|
||||||
use warp::reject::Rejection;
|
use warp::reject::Rejection;
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
|
@ -14,6 +14,7 @@ pub struct PgPool {
|
||||||
pub conn: r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>,
|
pub conn: r2d2::Pool<PostgresConnectionManager<postgres::NoTls>>,
|
||||||
whitelist_mode: bool,
|
whitelist_mode: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PgPool {
|
impl PgPool {
|
||||||
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self {
|
pub fn new(pg_cfg: config::Postgres, whitelist_mode: bool) -> Self {
|
||||||
let mut cfg = postgres::Config::new();
|
let mut cfg = postgres::Config::new();
|
||||||
|
@ -40,15 +41,11 @@ impl PgPool {
|
||||||
let mut conn = self.conn.get().unwrap();
|
let mut conn = self.conn.get().unwrap();
|
||||||
if let Some(token) = token {
|
if let Some(token) = token {
|
||||||
let query_rows = conn
|
let query_rows = conn
|
||||||
.query(
|
.query("
|
||||||
"
|
|
||||||
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
|
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
|
||||||
FROM
|
FROM oauth_access_tokens
|
||||||
oauth_access_tokens
|
INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
|
||||||
INNER JOIN users ON
|
WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL
|
||||||
oauth_access_tokens.resource_owner_id = users.id
|
|
||||||
WHERE oauth_access_tokens.token = $1
|
|
||||||
AND oauth_access_tokens.revoked_at IS NULL
|
|
||||||
LIMIT 1",
|
LIMIT 1",
|
||||||
&[&token.to_owned()],
|
&[&token.to_owned()],
|
||||||
)
|
)
|
||||||
|
@ -57,29 +54,20 @@ LIMIT 1",
|
||||||
let id = Id(result_columns.get(1));
|
let id = Id(result_columns.get(1));
|
||||||
let allowed_langs = result_columns
|
let allowed_langs = result_columns
|
||||||
.try_get::<_, Vec<_>>(2)
|
.try_get::<_, Vec<_>>(2)
|
||||||
.unwrap_or_else(|_| Vec::new())
|
.unwrap_or_default()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let mut scopes: HashSet<Scope> = result_columns
|
let mut scopes: HashSet<Scope> = result_columns
|
||||||
.get::<_, String>(3)
|
.get::<_, String>(3)
|
||||||
.split(' ')
|
.split(' ')
|
||||||
.filter_map(|scope| match scope {
|
.filter_map(|scope| Scope::try_from(scope).ok())
|
||||||
"read" => Some(Scope::Read),
|
|
||||||
"read:statuses" => Some(Scope::Statuses),
|
|
||||||
"read:notifications" => Some(Scope::Notifications),
|
|
||||||
"read:lists" => Some(Scope::Lists),
|
|
||||||
"write" | "follow" => None, // ignore write scopes
|
|
||||||
unexpected => {
|
|
||||||
log::warn!("Ignoring unknown scope `{}`", unexpected);
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
.collect();
|
||||||
// We don't need to separately track read auth - it's just all three others
|
// We don't need to separately track read auth - it's just all three others
|
||||||
if scopes.remove(&Scope::Read) {
|
if scopes.contains(&Scope::Read) {
|
||||||
scopes.insert(Scope::Statuses);
|
scopes = vec![Scope::Statuses, Scope::Notifications, Scope::Lists]
|
||||||
scopes.insert(Scope::Notifications);
|
.into_iter()
|
||||||
scopes.insert(Scope::Lists);
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(UserData {
|
Ok(UserData {
|
||||||
|
@ -98,19 +86,10 @@ LIMIT 1",
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> {
|
pub fn select_hashtag_id(self, tag_name: &str) -> Result<i64, Rejection> {
|
||||||
let mut conn = self.conn.get().unwrap();
|
let mut conn = self.conn.get().expect("TODO");
|
||||||
let rows = &conn
|
conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name])
|
||||||
.query(
|
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||||
"
|
.get(0)
|
||||||
SELECT id
|
|
||||||
FROM tags
|
|
||||||
WHERE name = $1
|
|
||||||
LIMIT 1",
|
|
||||||
&[&tag_name],
|
|
||||||
)
|
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
|
||||||
|
|
||||||
rows.get(0)
|
|
||||||
.map(|row| row.get(0))
|
.map(|row| row.get(0))
|
||||||
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
|
.ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist."))
|
||||||
}
|
}
|
||||||
|
@ -120,43 +99,31 @@ LIMIT 1",
|
||||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||||
/// the user adds until they refresh/reconnect.
|
/// the user adds until they refresh/reconnect.
|
||||||
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
|
pub fn select_blocked_users(self, user_id: Id) -> HashSet<Id> {
|
||||||
self.conn
|
let mut conn = self.conn.get().expect("TODO");
|
||||||
.get()
|
conn.query(
|
||||||
.unwrap()
|
"SELECT target_account_id FROM blocks WHERE account_id = $1
|
||||||
.query(
|
UNION SELECT target_account_id FROM mutes WHERE account_id = $1",
|
||||||
"
|
&[&*user_id],
|
||||||
SELECT target_account_id
|
)
|
||||||
FROM blocks
|
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||||
WHERE account_id = $1
|
.iter()
|
||||||
UNION SELECT target_account_id
|
.map(|row| Id(row.get(0)))
|
||||||
FROM mutes
|
.collect()
|
||||||
WHERE account_id = $1",
|
|
||||||
&[&*user_id],
|
|
||||||
)
|
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
|
||||||
.iter()
|
|
||||||
.map(|row| Id(row.get(0)))
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
/// Query Postgres for everyone who has blocked the user
|
/// Query Postgres for everyone who has blocked the user
|
||||||
///
|
///
|
||||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||||
/// the user adds until they refresh/reconnect.
|
/// the user adds until they refresh/reconnect.
|
||||||
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
|
pub fn select_blocking_users(self, user_id: Id) -> HashSet<Id> {
|
||||||
self.conn
|
let mut conn = self.conn.get().expect("TODO");
|
||||||
.get()
|
conn.query(
|
||||||
.unwrap()
|
"SELECT account_id FROM blocks WHERE target_account_id = $1",
|
||||||
.query(
|
&[&*user_id],
|
||||||
"
|
)
|
||||||
SELECT account_id
|
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||||
FROM blocks
|
.iter()
|
||||||
WHERE target_account_id = $1",
|
.map(|row| Id(row.get(0)))
|
||||||
&[&*user_id],
|
.collect()
|
||||||
)
|
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
|
||||||
.iter()
|
|
||||||
.map(|row| Id(row.get(0)))
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Query Postgres for all current domain blocks
|
/// Query Postgres for all current domain blocks
|
||||||
|
@ -164,37 +131,27 @@ SELECT account_id
|
||||||
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
/// **NOTE**: because we check this when the user connects, it will not include any blocks
|
||||||
/// the user adds until they refresh/reconnect.
|
/// the user adds until they refresh/reconnect.
|
||||||
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
|
pub fn select_blocked_domains(self, user_id: Id) -> HashSet<String> {
|
||||||
self.conn
|
let mut conn = self.conn.get().expect("TODO");
|
||||||
.get()
|
conn.query(
|
||||||
.unwrap()
|
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
|
||||||
.query(
|
&[&*user_id],
|
||||||
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
|
)
|
||||||
&[&*user_id],
|
.expect("Hard-coded query will return Some([0 or more rows])")
|
||||||
)
|
.iter()
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])")
|
.map(|row| row.get(0))
|
||||||
.iter()
|
.collect()
|
||||||
.map(|row| row.get(0))
|
|
||||||
.collect()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test whether a user owns a list
|
/// Test whether a user owns a list
|
||||||
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
|
pub fn user_owns_list(self, user_id: Id, list_id: i64) -> bool {
|
||||||
let mut conn = self.conn.get().unwrap();
|
let mut conn = self.conn.get().expect("TODO");
|
||||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||||
let rows = &conn
|
let rows = &conn
|
||||||
.query(
|
.query(
|
||||||
"
|
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
|
||||||
SELECT id, account_id
|
|
||||||
FROM lists
|
|
||||||
WHERE id = $1
|
|
||||||
LIMIT 1",
|
|
||||||
&[&list_id],
|
&[&list_id],
|
||||||
)
|
)
|
||||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||||
|
rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id)
|
||||||
match rows.get(0) {
|
|
||||||
None => false,
|
|
||||||
Some(row) => Id(row.get(1)) == user_id,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,47 +6,13 @@
|
||||||
// #[cfg(not(test))]
|
// #[cfg(not(test))]
|
||||||
|
|
||||||
use super::postgres::PgPool;
|
use super::postgres::PgPool;
|
||||||
use super::query;
|
|
||||||
use super::query::Query;
|
use super::query::Query;
|
||||||
use crate::err::TimelineErr;
|
use super::{Content, Reach, Stream, Timeline};
|
||||||
|
|
||||||
use crate::messages::Id;
|
use crate::messages::Id;
|
||||||
|
|
||||||
use hashbrown::HashSet;
|
use hashbrown::HashSet;
|
||||||
use lru::LruCache;
|
|
||||||
use warp::{filters::BoxedFilter, path, reject::Rejection, Filter};
|
|
||||||
|
|
||||||
/// Helper macro to match on the first of any of the provided filters
|
use warp::reject::Rejection;
|
||||||
macro_rules! any_of {
|
|
||||||
($filter:expr, $($other_filter:expr),*) => {
|
|
||||||
$filter$(.or($other_filter).unify())*.boxed()
|
|
||||||
};
|
|
||||||
}
|
|
||||||
macro_rules! parse_sse_query {
|
|
||||||
(path => $start:tt $(/ $next:tt)*
|
|
||||||
endpoint => $endpoint:expr) => {
|
|
||||||
path!($start $(/ $next)*)
|
|
||||||
.and(query::Auth::to_filter())
|
|
||||||
.and(query::Media::to_filter())
|
|
||||||
.and(query::Hashtag::to_filter())
|
|
||||||
.and(query::List::to_filter())
|
|
||||||
.map(
|
|
||||||
|auth: query::Auth,
|
|
||||||
media: query::Media,
|
|
||||||
hashtag: query::Hashtag,
|
|
||||||
list: query::List| {
|
|
||||||
Query {
|
|
||||||
access_token: auth.access_token,
|
|
||||||
stream: $endpoint.to_string(),
|
|
||||||
media: media.is_truthy(),
|
|
||||||
hashtag: hashtag.tag,
|
|
||||||
list: list.list,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.boxed()
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, PartialEq)]
|
#[derive(Clone, Debug, PartialEq)]
|
||||||
pub struct Subscription {
|
pub struct Subscription {
|
||||||
|
@ -77,49 +43,24 @@ impl Default for Subscription {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Subscription {
|
impl Subscription {
|
||||||
pub fn from_ws_request(pg_pool: PgPool) -> BoxedFilter<(Subscription,)> {
|
pub(super) fn query_postgres(q: Query, pool: PgPool) -> Result<Self, Rejection> {
|
||||||
parse_ws_query()
|
|
||||||
.and(query::OptionalAccessToken::from_ws_header())
|
|
||||||
.and_then(Query::update_access_token)
|
|
||||||
.and_then(move |q| Subscription::from_query(q, pg_pool.clone()))
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_sse_request(pg_pool: PgPool) -> BoxedFilter<(Subscription,)> {
|
|
||||||
any_of!(
|
|
||||||
parse_sse_query!(
|
|
||||||
path => "api" / "v1" / "streaming" / "user" / "notification"
|
|
||||||
endpoint => "user:notification" ),
|
|
||||||
parse_sse_query!(
|
|
||||||
path => "api" / "v1" / "streaming" / "user"
|
|
||||||
endpoint => "user"),
|
|
||||||
parse_sse_query!(
|
|
||||||
path => "api" / "v1" / "streaming" / "public" / "local"
|
|
||||||
endpoint => "public:local"),
|
|
||||||
parse_sse_query!(
|
|
||||||
path => "api" / "v1" / "streaming" / "public"
|
|
||||||
endpoint => "public"),
|
|
||||||
parse_sse_query!(
|
|
||||||
path => "api" / "v1" / "streaming" / "direct"
|
|
||||||
endpoint => "direct"),
|
|
||||||
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag" / "local"
|
|
||||||
endpoint => "hashtag:local"),
|
|
||||||
parse_sse_query!(path => "api" / "v1" / "streaming" / "hashtag"
|
|
||||||
endpoint => "hashtag"),
|
|
||||||
parse_sse_query!(path => "api" / "v1" / "streaming" / "list"
|
|
||||||
endpoint => "list")
|
|
||||||
)
|
|
||||||
// because SSE requests place their `access_token` in the header instead of in a query
|
|
||||||
// parameter, we need to update our Query if the header has a token
|
|
||||||
.and(query::OptionalAccessToken::from_sse_header())
|
|
||||||
.and_then(Query::update_access_token)
|
|
||||||
.and_then(move |q| Subscription::from_query(q, pg_pool.clone()))
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn from_query(q: Query, pool: PgPool) -> Result<Self, Rejection> {
|
|
||||||
let user = pool.clone().select_user(&q.access_token)?;
|
let user = pool.clone().select_user(&q.access_token)?;
|
||||||
let timeline = Timeline::from_query_and_user(&q, &user, pool.clone())?;
|
let timeline = {
|
||||||
|
let tl = Timeline::from_query_and_user(&q, &user)?;
|
||||||
|
let pool = pool.clone();
|
||||||
|
use Stream::*;
|
||||||
|
match tl {
|
||||||
|
Timeline(Hashtag(_), reach, stream) => {
|
||||||
|
let tag = pool.select_hashtag_id(&q.hashtag)?;
|
||||||
|
Timeline(Hashtag(tag), reach, stream)
|
||||||
|
}
|
||||||
|
Timeline(List(list_id), _, _) if !pool.user_owns_list(user.id, list_id) => {
|
||||||
|
Err(warp::reject::custom("Error: Missing access token"))?
|
||||||
|
}
|
||||||
|
other_tl => other_tl,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let hashtag_name = match timeline {
|
let hashtag_name = match timeline {
|
||||||
Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag),
|
Timeline(Stream::Hashtag(_), _, _) => Some(q.hashtag),
|
||||||
_non_hashtag_timeline => None,
|
_non_hashtag_timeline => None,
|
||||||
|
@ -138,179 +79,3 @@ impl Subscription {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_ws_query() -> BoxedFilter<(Query,)> {
|
|
||||||
path!("api" / "v1" / "streaming")
|
|
||||||
.and(path::end())
|
|
||||||
.and(warp::query())
|
|
||||||
.and(query::Auth::to_filter())
|
|
||||||
.and(query::Media::to_filter())
|
|
||||||
.and(query::Hashtag::to_filter())
|
|
||||||
.and(query::List::to_filter())
|
|
||||||
.map(
|
|
||||||
|stream: query::Stream,
|
|
||||||
auth: query::Auth,
|
|
||||||
media: query::Media,
|
|
||||||
hashtag: query::Hashtag,
|
|
||||||
list: query::List| {
|
|
||||||
Query {
|
|
||||||
access_token: auth.access_token,
|
|
||||||
stream: stream.stream,
|
|
||||||
media: media.is_truthy(),
|
|
||||||
hashtag: hashtag.tag,
|
|
||||||
list: list.list,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
|
|
||||||
pub struct Timeline(pub Stream, pub Reach, pub Content);
|
|
||||||
|
|
||||||
impl Timeline {
|
|
||||||
pub fn empty() -> Self {
|
|
||||||
use {Content::*, Reach::*, Stream::*};
|
|
||||||
Self(Unset, Local, Notification)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String, TimelineErr> {
|
|
||||||
use {Content::*, Reach::*, Stream::*};
|
|
||||||
Ok(match self {
|
|
||||||
Timeline(Public, Federated, All) => "timeline:public".into(),
|
|
||||||
Timeline(Public, Local, All) => "timeline:public:local".into(),
|
|
||||||
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
|
|
||||||
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
|
|
||||||
// TODO -- would `.push_str` be faster here?
|
|
||||||
Timeline(Hashtag(_id), Federated, All) => format!(
|
|
||||||
"timeline:hashtag:{}",
|
|
||||||
hashtag.ok_or_else(|| TimelineErr::MissingHashtag)?
|
|
||||||
),
|
|
||||||
Timeline(Hashtag(_id), Local, All) => format!(
|
|
||||||
"timeline:hashtag:{}:local",
|
|
||||||
hashtag.ok_or_else(|| TimelineErr::MissingHashtag)?
|
|
||||||
),
|
|
||||||
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
|
|
||||||
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
|
|
||||||
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
|
|
||||||
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
|
|
||||||
Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_redis_text(
|
|
||||||
timeline: &str,
|
|
||||||
cache: &mut LruCache<String, i64>,
|
|
||||||
) -> Result<Self, TimelineErr> {
|
|
||||||
let mut id_from_tag = |tag: &str| match cache.get(&tag.to_string()) {
|
|
||||||
Some(id) => Ok(*id),
|
|
||||||
None => Err(TimelineErr::InvalidInput), // TODO more specific
|
|
||||||
};
|
|
||||||
|
|
||||||
use {Content::*, Reach::*, Stream::*};
|
|
||||||
Ok(match &timeline.split(':').collect::<Vec<&str>>()[..] {
|
|
||||||
["public"] => Timeline(Public, Federated, All),
|
|
||||||
["public", "local"] => Timeline(Public, Local, All),
|
|
||||||
["public", "media"] => Timeline(Public, Federated, Media),
|
|
||||||
["public", "local", "media"] => Timeline(Public, Local, Media),
|
|
||||||
["hashtag", tag] => Timeline(Hashtag(id_from_tag(tag)?), Federated, All),
|
|
||||||
["hashtag", tag, "local"] => Timeline(Hashtag(id_from_tag(tag)?), Local, All),
|
|
||||||
[id] => Timeline(User(id.parse()?), Federated, All),
|
|
||||||
[id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
|
|
||||||
["list", id] => Timeline(List(id.parse()?), Federated, All),
|
|
||||||
["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
|
|
||||||
// Other endpoints don't exist:
|
|
||||||
[..] => Err(TimelineErr::InvalidInput)?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn from_query_and_user(q: &Query, user: &UserData, pool: PgPool) -> Result<Self, Rejection> {
|
|
||||||
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
|
|
||||||
let id_from_hashtag = || pool.clone().select_hashtag_id(&q.hashtag);
|
|
||||||
let user_owns_list = || pool.clone().user_owns_list(user.id, q.list);
|
|
||||||
|
|
||||||
Ok(match q.stream.as_ref() {
|
|
||||||
"public" => match q.media {
|
|
||||||
true => Timeline(Public, Federated, Media),
|
|
||||||
false => Timeline(Public, Federated, All),
|
|
||||||
},
|
|
||||||
"public:local" => match q.media {
|
|
||||||
true => Timeline(Public, Local, Media),
|
|
||||||
false => Timeline(Public, Local, All),
|
|
||||||
},
|
|
||||||
"public:media" => Timeline(Public, Federated, Media),
|
|
||||||
"public:local:media" => Timeline(Public, Local, Media),
|
|
||||||
|
|
||||||
"hashtag" => Timeline(Hashtag(id_from_hashtag()?), Federated, All),
|
|
||||||
"hashtag:local" => Timeline(Hashtag(id_from_hashtag()?), Local, All),
|
|
||||||
"user" => match user.scopes.contains(&Statuses) {
|
|
||||||
true => Timeline(User(user.id), Federated, All),
|
|
||||||
false => Err(custom("Error: Missing access token"))?,
|
|
||||||
},
|
|
||||||
"user:notification" => match user.scopes.contains(&Statuses) {
|
|
||||||
true => Timeline(User(user.id), Federated, Notification),
|
|
||||||
false => Err(custom("Error: Missing access token"))?,
|
|
||||||
},
|
|
||||||
"list" => match user.scopes.contains(&Lists) && user_owns_list() {
|
|
||||||
true => Timeline(List(q.list), Federated, All),
|
|
||||||
false => Err(warp::reject::custom("Error: Missing access token"))?,
|
|
||||||
},
|
|
||||||
"direct" => match user.scopes.contains(&Statuses) {
|
|
||||||
true => Timeline(Direct(*user.id), Federated, All),
|
|
||||||
false => Err(custom("Error: Missing access token"))?,
|
|
||||||
},
|
|
||||||
other => {
|
|
||||||
log::warn!("Request for nonexistent endpoint: `{}`", other);
|
|
||||||
Err(custom("Error: Nonexistent endpoint"))?
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The data source a timeline draws its events from.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Stream {
    /// A single user's timeline, keyed by their `Id`.
    User(Id),
    // TODO consider whether List, Direct, and Hashtag should all be `id::Id`s
    /// A list timeline (list id).
    List(i64),
    /// Direct messages (user id).
    Direct(i64),
    /// A hashtag timeline (hashtag id).
    Hashtag(i64),
    /// The public timeline.
    Public,
    /// Placeholder for a timeline that has not been determined yet.
    Unset,
}
|
|
||||||
|
|
||||||
/// Whether a timeline covers only this server or the whole federation.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Reach {
    /// Only events originating on this server (the `:local` channels).
    Local,
    /// Events from this server and federated peers.
    Federated,
}
|
|
||||||
|
|
||||||
/// Filter on the kind of events a timeline delivers.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Content {
    /// All events.
    All,
    /// Only media events (the `:media` channels).
    Media,
    /// Only notification events.
    Notification,
}
|
|
||||||
|
|
||||||
/// A read-related OAuth scope attached to a user's access token.
///
/// Derives `Copy` (all variants are unit variants), matching the other
/// field-less enums in this module (`Reach`, `Content`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// The blanket `read` scope.
    Read,
    /// `read:statuses`.
    Statuses,
    /// `read:notifications`.
    Notifications,
    /// `read:lists`.
    Lists,
}
|
|
||||||
|
|
||||||
pub struct UserData {
|
|
||||||
pub id: Id,
|
|
||||||
pub allowed_langs: HashSet<String>,
|
|
||||||
pub scopes: HashSet<Scope>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UserData {
|
|
||||||
pub fn public() -> Self {
|
|
||||||
Self {
|
|
||||||
id: Id(-1),
|
|
||||||
allowed_langs: HashSet::new(),
|
|
||||||
scopes: HashSet::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,174 @@
|
||||||
|
use super::query::Query;
|
||||||
|
use crate::err::TimelineErr;
|
||||||
|
use crate::messages::Id;
|
||||||
|
|
||||||
|
use hashbrown::HashSet;
|
||||||
|
use lru::LruCache;
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
use warp::reject::Rejection;
|
||||||
|
|
||||||
|
/// A timeline, described by its source stream, its reach, and the content
/// filter applied to it.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub struct Timeline(pub Stream, pub Reach, pub Content);
|
||||||
|
|
||||||
|
impl Timeline {
|
||||||
|
pub fn empty() -> Self {
|
||||||
|
use {Content::*, Reach::*, Stream::*};
|
||||||
|
Self(Unset, Local, Notification)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_redis_raw_timeline(&self, hashtag: Option<&String>) -> Result<String, TimelineErr> {
|
||||||
|
use {Content::*, Reach::*, Stream::*};
|
||||||
|
Ok(match self {
|
||||||
|
Timeline(Public, Federated, All) => "timeline:public".into(),
|
||||||
|
Timeline(Public, Local, All) => "timeline:public:local".into(),
|
||||||
|
Timeline(Public, Federated, Media) => "timeline:public:media".into(),
|
||||||
|
Timeline(Public, Local, Media) => "timeline:public:local:media".into(),
|
||||||
|
// TODO -- would `.push_str` be faster here?
|
||||||
|
Timeline(Hashtag(_id), Federated, All) => format!(
|
||||||
|
"timeline:hashtag:{}",
|
||||||
|
hashtag.ok_or_else(|| TimelineErr::MissingHashtag)?
|
||||||
|
),
|
||||||
|
Timeline(Hashtag(_id), Local, All) => format!(
|
||||||
|
"timeline:hashtag:{}:local",
|
||||||
|
hashtag.ok_or_else(|| TimelineErr::MissingHashtag)?
|
||||||
|
),
|
||||||
|
Timeline(User(id), Federated, All) => format!("timeline:{}", id),
|
||||||
|
Timeline(User(id), Federated, Notification) => format!("timeline:{}:notification", id),
|
||||||
|
Timeline(List(id), Federated, All) => format!("timeline:list:{}", id),
|
||||||
|
Timeline(Direct(id), Federated, All) => format!("timeline:direct:{}", id),
|
||||||
|
Timeline(_one, _two, _three) => Err(TimelineErr::InvalidInput)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_redis_text(
|
||||||
|
timeline: &str,
|
||||||
|
cache: &mut LruCache<String, i64>,
|
||||||
|
) -> Result<Self, TimelineErr> {
|
||||||
|
let mut id_from_tag = |tag: &str| match cache.get(&tag.to_string()) {
|
||||||
|
Some(id) => Ok(*id),
|
||||||
|
None => Err(TimelineErr::InvalidInput), // TODO more specific
|
||||||
|
};
|
||||||
|
|
||||||
|
use {Content::*, Reach::*, Stream::*};
|
||||||
|
Ok(match &timeline.split(':').collect::<Vec<&str>>()[..] {
|
||||||
|
["public"] => Timeline(Public, Federated, All),
|
||||||
|
["public", "local"] => Timeline(Public, Local, All),
|
||||||
|
["public", "media"] => Timeline(Public, Federated, Media),
|
||||||
|
["public", "local", "media"] => Timeline(Public, Local, Media),
|
||||||
|
["hashtag", tag] => Timeline(Hashtag(id_from_tag(tag)?), Federated, All),
|
||||||
|
["hashtag", tag, "local"] => Timeline(Hashtag(id_from_tag(tag)?), Local, All),
|
||||||
|
[id] => Timeline(User(id.parse()?), Federated, All),
|
||||||
|
[id, "notification"] => Timeline(User(id.parse()?), Federated, Notification),
|
||||||
|
["list", id] => Timeline(List(id.parse()?), Federated, All),
|
||||||
|
["direct", id] => Timeline(Direct(id.parse()?), Federated, All),
|
||||||
|
// Other endpoints don't exist:
|
||||||
|
[..] => Err(TimelineErr::InvalidInput)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_query_and_user(q: &Query, user: &UserData) -> Result<Self, Rejection> {
|
||||||
|
use {warp::reject::custom, Content::*, Reach::*, Scope::*, Stream::*};
|
||||||
|
|
||||||
|
Ok(match q.stream.as_ref() {
|
||||||
|
"public" => match q.media {
|
||||||
|
true => Timeline(Public, Federated, Media),
|
||||||
|
false => Timeline(Public, Federated, All),
|
||||||
|
},
|
||||||
|
"public:local" => match q.media {
|
||||||
|
true => Timeline(Public, Local, Media),
|
||||||
|
false => Timeline(Public, Local, All),
|
||||||
|
},
|
||||||
|
"public:media" => Timeline(Public, Federated, Media),
|
||||||
|
"public:local:media" => Timeline(Public, Local, Media),
|
||||||
|
|
||||||
|
"hashtag" => Timeline(Hashtag(0), Federated, All),
|
||||||
|
"hashtag:local" => Timeline(Hashtag(0), Local, All),
|
||||||
|
"user" => match user.scopes.contains(&Statuses) {
|
||||||
|
true => Timeline(User(user.id), Federated, All),
|
||||||
|
false => Err(custom("Error: Missing access token"))?,
|
||||||
|
},
|
||||||
|
"user:notification" => match user.scopes.contains(&Statuses) {
|
||||||
|
true => Timeline(User(user.id), Federated, Notification),
|
||||||
|
false => Err(custom("Error: Missing access token"))?,
|
||||||
|
},
|
||||||
|
"list" => match user.scopes.contains(&Lists) {
|
||||||
|
true => Timeline(List(q.list), Federated, All),
|
||||||
|
false => Err(warp::reject::custom("Error: Missing access token"))?,
|
||||||
|
},
|
||||||
|
"direct" => match user.scopes.contains(&Statuses) {
|
||||||
|
true => Timeline(Direct(*user.id), Federated, All),
|
||||||
|
false => Err(custom("Error: Missing access token"))?,
|
||||||
|
},
|
||||||
|
other => {
|
||||||
|
log::warn!("Request for nonexistent endpoint: `{}`", other);
|
||||||
|
Err(custom("Error: Nonexistent endpoint"))?
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The data source a timeline draws its events from.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Stream {
    /// A single user's timeline, keyed by their `Id`.
    User(Id),
    // TODO consider whether List, Direct, and Hashtag should all be `id::Id`s
    /// A list timeline (list id).
    List(i64),
    /// Direct messages (user id).
    Direct(i64),
    /// A hashtag timeline (hashtag id).
    Hashtag(i64),
    /// The public timeline.
    Public,
    /// Placeholder for a timeline that has not been determined yet
    /// (see `Timeline::empty`).
    Unset,
}
|
||||||
|
|
||||||
|
/// Whether a timeline covers only this server or the whole federation.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Reach {
    /// Only events originating on this server (the `:local` channels).
    Local,
    /// Events from this server and federated peers.
    Federated,
}
|
||||||
|
|
||||||
|
/// Filter on the kind of events a timeline delivers.
#[derive(Clone, Debug, Copy, Eq, Hash, PartialEq)]
pub enum Content {
    /// All events.
    All,
    /// Only media events (the `:media` channels).
    Media,
    /// Only notification events.
    Notification,
}
|
||||||
|
|
||||||
|
/// A read-related OAuth scope attached to a user's access token.
///
/// Derives `Copy` (all variants are unit variants), matching the other
/// field-less enums in this module (`Reach`, `Content`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// The blanket `read` scope.
    Read,
    /// `read:statuses`.
    Statuses,
    /// `read:notifications`.
    Notifications,
    /// `read:lists`.
    Lists,
}
|
||||||
|
|
||||||
|
impl TryFrom<&str> for Scope {
|
||||||
|
type Error = TimelineErr;
|
||||||
|
|
||||||
|
fn try_from(s: &str) -> Result<Self, Self::Error> {
|
||||||
|
match s {
|
||||||
|
"read" => Ok(Scope::Read),
|
||||||
|
"read:statuses" => Ok(Scope::Statuses),
|
||||||
|
"read:notifications" => Ok(Scope::Notifications),
|
||||||
|
"read:lists" => Ok(Scope::Lists),
|
||||||
|
"write" | "follow" => Err(TimelineErr::InvalidInput), // ignore write scopes
|
||||||
|
unexpected => {
|
||||||
|
log::warn!("Ignoring unknown scope `{}`", unexpected);
|
||||||
|
Err(TimelineErr::InvalidInput)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct UserData {
|
||||||
|
pub id: Id,
|
||||||
|
pub allowed_langs: HashSet<String>,
|
||||||
|
pub scopes: HashSet<Scope>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserData {
|
||||||
|
pub fn public() -> Self {
|
||||||
|
Self {
|
||||||
|
id: Id(-1),
|
||||||
|
allowed_langs: HashSet::new(),
|
||||||
|
scopes: HashSet::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -3,22 +3,16 @@ pub use err::RedisConnErr;
|
||||||
|
|
||||||
use super::msg::{RedisParseErr, RedisParseOutput};
|
use super::msg::{RedisParseErr, RedisParseOutput};
|
||||||
use super::ManagerErr;
|
use super::ManagerErr;
|
||||||
use crate::{
|
use crate::config::Redis;
|
||||||
config::Redis,
|
use crate::messages::Event;
|
||||||
messages::Event,
|
use crate::request::{Stream, Timeline};
|
||||||
request::{Stream, Timeline},
|
|
||||||
};
|
|
||||||
|
|
||||||
use std::{
|
|
||||||
convert::{TryFrom, TryInto},
|
|
||||||
io::{Read, Write},
|
|
||||||
net::TcpStream,
|
|
||||||
str,
|
|
||||||
time::Duration,
|
|
||||||
};
|
|
||||||
|
|
||||||
use futures::{Async, Poll};
|
use futures::{Async, Poll};
|
||||||
use lru::LruCache;
|
use lru::LruCache;
|
||||||
|
use std::convert::{TryFrom, TryInto};
|
||||||
|
use std::io::{Read, Write};
|
||||||
|
use std::net::TcpStream;
|
||||||
|
use std::str;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
type Result<T> = std::result::Result<T, RedisConnErr>;
|
type Result<T> = std::result::Result<T, RedisConnErr>;
|
||||||
|
|
||||||
|
@ -46,7 +40,7 @@ impl RedisConn {
|
||||||
// TODO: eventually, it might make sense to have Mastodon publish to timelines with
|
// TODO: eventually, it might make sense to have Mastodon publish to timelines with
|
||||||
// the tag number instead of the tag name. This would save us from dealing
|
// the tag number instead of the tag name. This would save us from dealing
|
||||||
// with a cache here and would be consistent with how lists/users are handled.
|
// with a cache here and would be consistent with how lists/users are handled.
|
||||||
redis_namespace: redis_cfg.namespace.clone(),
|
redis_namespace: redis_cfg.namespace.clone().0,
|
||||||
redis_input: Vec::new(),
|
redis_input: Vec::new(),
|
||||||
};
|
};
|
||||||
Ok(redis_conn)
|
Ok(redis_conn)
|
||||||
|
@ -61,14 +55,12 @@ impl RedisConn {
|
||||||
self.redis_input.extend_from_slice(&buffer[..n]);
|
self.redis_input.extend_from_slice(&buffer[..n]);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
Ok(n) => {
|
Ok(n) => self.redis_input.extend_from_slice(&buffer[..n]),
|
||||||
self.redis_input.extend_from_slice(&buffer[..n]);
|
|
||||||
}
|
|
||||||
Err(_) => break,
|
Err(_) => break,
|
||||||
};
|
};
|
||||||
if first_read {
|
if first_read {
|
||||||
size = 2000;
|
size = 2000;
|
||||||
buffer = vec![0u8; size];
|
buffer = vec![0_u8; size];
|
||||||
first_read = false;
|
first_read = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -117,50 +109,6 @@ impl RedisConn {
|
||||||
self.tag_name_cache.put(id, hashtag);
|
self.tag_name_cache.put(id, hashtag);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_connection(addr: &str, pass: Option<&String>) -> Result<TcpStream> {
|
|
||||||
match TcpStream::connect(&addr) {
|
|
||||||
Ok(mut conn) => {
|
|
||||||
if let Some(password) = pass {
|
|
||||||
Self::auth_connection(&mut conn, &addr, password)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Self::validate_connection(&mut conn, &addr)?;
|
|
||||||
conn.set_read_timeout(Some(Duration::from_millis(10)))
|
|
||||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
|
||||||
Ok(conn)
|
|
||||||
}
|
|
||||||
Err(e) => Err(RedisConnErr::with_addr(&addr, e)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> {
|
|
||||||
conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes())
|
|
||||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
|
||||||
let mut buffer = vec![0u8; 5];
|
|
||||||
conn.read_exact(&mut buffer)
|
|
||||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
|
||||||
let reply = String::from_utf8_lossy(&buffer);
|
|
||||||
match &*reply {
|
|
||||||
"+OK\r\n" => (),
|
|
||||||
_ => Err(RedisConnErr::IncorrectPassword(pass.to_string()))?,
|
|
||||||
};
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn validate_connection(conn: &mut TcpStream, addr: &str) -> Result<()> {
|
|
||||||
conn.write_all(b"PING\r\n")
|
|
||||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
|
||||||
let mut buffer = vec![0u8; 7];
|
|
||||||
conn.read_exact(&mut buffer)
|
|
||||||
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
|
||||||
let reply = String::from_utf8_lossy(&buffer);
|
|
||||||
match &*reply {
|
|
||||||
"+PONG\r\n" => Ok(()),
|
|
||||||
"-NOAUTH" => Err(RedisConnErr::MissingPassword),
|
|
||||||
"HTTP/1." => Err(RedisConnErr::NotRedis(addr.to_string())),
|
|
||||||
_ => Err(RedisConnErr::InvalidRedisReply(reply.to_string())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_cmd(&mut self, cmd: RedisCmd, timeline: &Timeline) -> Result<()> {
|
pub fn send_cmd(&mut self, cmd: RedisCmd, timeline: &Timeline) -> Result<()> {
|
||||||
let hashtag = match timeline {
|
let hashtag = match timeline {
|
||||||
Timeline(Stream::Hashtag(id), _, _) => self.tag_name_cache.get(id),
|
Timeline(Stream::Hashtag(id), _, _) => self.tag_name_cache.get(id),
|
||||||
|
@ -182,6 +130,44 @@ impl RedisConn {
|
||||||
self.secondary.write_all(&secondary_cmd.as_bytes())?;
|
self.secondary.write_all(&secondary_cmd.as_bytes())?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn new_connection(addr: &str, pass: Option<&String>) -> Result<TcpStream> {
|
||||||
|
let mut conn = TcpStream::connect(&addr)?;
|
||||||
|
if let Some(password) = pass {
|
||||||
|
Self::auth_connection(&mut conn, &addr, password)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Self::validate_connection(&mut conn, &addr)?;
|
||||||
|
conn.set_read_timeout(Some(Duration::from_millis(10)))
|
||||||
|
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||||
|
Ok(conn)
|
||||||
|
}
|
||||||
|
fn auth_connection(conn: &mut TcpStream, addr: &str, pass: &str) -> Result<()> {
|
||||||
|
conn.write_all(&format!("*2\r\n$4\r\nauth\r\n${}\r\n{}\r\n", pass.len(), pass).as_bytes())
|
||||||
|
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||||
|
let mut buffer = vec![0u8; 5];
|
||||||
|
conn.read_exact(&mut buffer)
|
||||||
|
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||||
|
if String::from_utf8_lossy(&buffer) != "+OK\r\n" {
|
||||||
|
Err(RedisConnErr::IncorrectPassword(pass.to_string()))?
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_connection(conn: &mut TcpStream, addr: &str) -> Result<()> {
|
||||||
|
conn.write_all(b"PING\r\n")
|
||||||
|
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||||
|
let mut buffer = vec![0u8; 7];
|
||||||
|
conn.read_exact(&mut buffer)
|
||||||
|
.map_err(|e| RedisConnErr::with_addr(&addr, e))?;
|
||||||
|
let reply = String::from_utf8_lossy(&buffer);
|
||||||
|
match &*reply {
|
||||||
|
"+PONG\r\n" => Ok(()),
|
||||||
|
"-NOAUTH" => Err(RedisConnErr::MissingPassword),
|
||||||
|
"HTTP/1." => Err(RedisConnErr::NotRedis(addr.to_string())),
|
||||||
|
_ => Err(RedisConnErr::InvalidRedisReply(reply.to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum RedisCmd {
|
pub enum RedisCmd {
|
||||||
|
|
|
@ -132,43 +132,38 @@ impl Sse {
|
||||||
let blocks = subscription.blocks;
|
let blocks = subscription.blocks;
|
||||||
|
|
||||||
let event_stream = sse_rx
|
let event_stream = sse_rx
|
||||||
|
.filter(move |(timeline, _)| target_timeline == *timeline)
|
||||||
.filter_map(move |(timeline, event)| {
|
.filter_map(move |(timeline, event)| {
|
||||||
if target_timeline == timeline {
|
use crate::messages::{
|
||||||
use crate::messages::{
|
CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind,
|
||||||
CheckedEvent, CheckedEvent::Update, DynEvent, Event::*, EventKind,
|
};
|
||||||
};
|
|
||||||
|
|
||||||
use crate::request::Stream::Public;
|
use crate::request::Stream::Public;
|
||||||
match event {
|
match event {
|
||||||
TypeSafe(Update { payload, queued_at }) => match timeline {
|
TypeSafe(Update { payload, queued_at }) => match timeline {
|
||||||
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
|
Timeline(Public, _, _) if payload.language_not(&allowed_langs) => None,
|
||||||
_ if payload.involves_any(&blocks) => None,
|
_ if payload.involves_any(&blocks) => None,
|
||||||
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
|
_ => Self::reply_with(Event::TypeSafe(CheckedEvent::Update {
|
||||||
payload,
|
payload,
|
||||||
queued_at,
|
queued_at,
|
||||||
})),
|
})),
|
||||||
},
|
},
|
||||||
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
|
TypeSafe(non_update) => Self::reply_with(Event::TypeSafe(non_update)),
|
||||||
Dynamic(dyn_event) => {
|
Dynamic(dyn_event) => {
|
||||||
if let EventKind::Update(s) = dyn_event.kind {
|
if let EventKind::Update(s) = dyn_event.kind {
|
||||||
match timeline {
|
match timeline {
|
||||||
Timeline(Public, _, _) if s.language_not(&allowed_langs) => {
|
Timeline(Public, _, _) if s.language_not(&allowed_langs) => None,
|
||||||
None
|
_ if s.involves_any(&blocks) => None,
|
||||||
}
|
_ => Self::reply_with(Dynamic(DynEvent {
|
||||||
_ if s.involves_any(&blocks) => None,
|
kind: EventKind::Update(s),
|
||||||
_ => Self::reply_with(Dynamic(DynEvent {
|
..dyn_event
|
||||||
kind: EventKind::Update(s),
|
})),
|
||||||
..dyn_event
|
|
||||||
})),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
Ping => None, // pings handled automatically
|
|
||||||
}
|
}
|
||||||
} else {
|
Ping => None, // pings handled automatically
|
||||||
None
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.then(move |res| {
|
.then(move |res| {
|
||||||
|
@ -186,3 +181,4 @@ impl Sse {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// TODO -- split WS and SSE into separate files and add misc stuff from main.rs here
|
||||||
|
|
Loading…
Reference in New Issue