Major refactor

This commit is contained in:
Daniel Sockwell 2020-03-14 21:16:48 -04:00
parent 4df364d1ac
commit f3d20153e5
9 changed files with 162 additions and 114 deletions

View File

@ -6,6 +6,14 @@ pub fn die_with_msg(msg: impl Display) -> ! {
std::process::exit(1);
}
#[macro_export]
macro_rules! log_fatal {
($str:expr, $var:expr) => {{
log::error!($str, $var);
panic!();
};};
}
pub fn env_var_fatal(env_var: &str, supplied_value: &str, allowed_values: String) -> ! {
eprintln!(
r"FATAL ERROR: {var} is set to `{value}`, which is invalid.

View File

@ -57,11 +57,10 @@ fn main() {
// WebSocket
let ws_update_interval = *cfg.ws_interval;
let websocket_routes = ws::extract_user_or_reject(pg_pool.clone())
let websocket_routes = ws::extract_user_and_token_or_reject(pg_pool.clone())
.and(warp::ws::ws2())
.map(move |user: user::User, ws: Ws2| {
.map(move |user: user::User, token: Option<String>, ws: Ws2| {
log::info!("Incoming websocket request");
let token = user.access_token.clone();
// Create a new ClientAgent
let mut client_agent = client_agent_ws.clone_with_shared_receiver();
// Assign that agent to generate a stream of updates for the user/timeline pair
@ -77,7 +76,7 @@ fn main() {
ws_update_interval,
)
}),
token,
token.unwrap_or_else(String::new),
)
})
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));

View File

@ -190,7 +190,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(public_media_1 {
@ -209,7 +209,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(public_local {
@ -228,7 +228,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(public_local_media_true {
@ -247,7 +247,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(public_local_media_1 {
@ -266,7 +266,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(hashtag {
@ -285,7 +285,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(hashtag_local {
@ -304,7 +304,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
@ -324,7 +324,7 @@ mod test {
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
allowed_langs: Filter::NoFilter,
},
});
test_private_endpoint!(user_notification {
@ -343,7 +343,7 @@ mod test {
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::Notification,
allowed_langs: Filter::Notification,
},
});
test_private_endpoint!(direct {
@ -362,7 +362,7 @@ mod test {
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
allowed_langs: Filter::NoFilter,
},
});
@ -383,7 +383,7 @@ mod test {
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
allowed_langs: Filter::NoFilter,
},
});
test_bad_auth_token_in_query!(public_media_true_bad_auth {

View File

@ -43,8 +43,6 @@ pub struct Blocks {
#[derive(Clone, Debug, PartialEq)]
pub struct User {
pub target_timeline: String,
pub email: String, // We only use email for logging; we could cut it for performance
pub access_token: String, // We only need this once (to send back with the WS reply). Cut?
pub id: i64,
pub scopes: OauthScope,
pub logged_in: bool,
@ -56,8 +54,6 @@ impl Default for User {
fn default() -> Self {
Self {
id: -1,
email: "".to_string(),
access_token: "".to_string(),
scopes: OauthScope::default(),
logged_in: false,
target_timeline: String::new(),
@ -67,18 +63,23 @@ impl Default for User {
}
}
// impl fmt::Display for User {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// write!(f, r##"User {} "##)
// }
// }
impl User {
pub fn from_query(q: Query, pool: PgPool) -> Result<Self, Rejection> {
println!("Creating user...");
let mut user: User = match q.access_token.clone() {
let token = q.access_token.clone();
let mut user: User = match token {
None => User::default(),
Some(token) => postgres::select_user(&token, pool.clone())?,
};
user = user.set_timeline_and_filter(q, pool.clone())?;
user.blocks.user_blocks = postgres::select_user_blocks(user.id, pool.clone());
user.blocks.domain_blocks = postgres::select_domain_blocks(pool.clone());
dbg!(&user);
user.blocks.domain_blocks = postgres::select_domain_blocks(user.id, pool.clone());
Ok(user)
}

View File

@ -36,10 +36,10 @@ impl PgPool {
/// methods to do so. In general, this function shouldn't be needed outside `User`.
pub fn select_user(access_token: &str, pg_pool: PgPool) -> Result<User, Rejection> {
let mut conn = pg_pool.0.get().unwrap();
let query_result = conn
let query_rows = conn
.query(
"
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.email, users.chosen_languages, oauth_access_tokens.scopes
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM
oauth_access_tokens
INNER JOIN users ON
@ -49,33 +49,29 @@ AND oauth_access_tokens.revoked_at IS NULL
LIMIT 1",
&[&access_token.to_owned()],
)
.expect("Hard-coded query will return Some([0 or more rows])");
if query_result.is_empty() {
Err(warp::reject::custom("Error: Invalid access token"))
} else {
// TODO: better name than `only_row`
let only_row: &postgres::Row = query_result.get(0).unwrap();
let scope_vec: Vec<String> = only_row
.get::<_, String>(4)
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) {
let scope_vec: Vec<String> = result_columns
.get::<_, String>(3)
.split(' ')
.map(|s| s.to_owned())
.collect();
let mut allowed_langs = HashSet::new();
if let Ok(langs_vec) = only_row.try_get::<_, Vec<String>>(3) {
if let Ok(langs_vec) = result_columns.try_get::<_, Vec<String>>(2) {
for lang in langs_vec {
allowed_langs.insert(lang);
}
}
Ok(User {
email: only_row.get(2),
access_token: access_token.to_string(),
id: only_row.get(1),
id: result_columns.get(1),
scopes: OauthScope::from(scope_vec),
logged_in: true,
allowed_langs,
..User::default()
})
} else {
Err(warp::reject::custom("Error: Invalid access token"))
}
}
@ -108,12 +104,15 @@ UNION SELECT target_account_id
///
/// **NOTE**: because we check this when the user connects, it will not include any blocks
/// the user adds until they refresh/reconnect.
pub fn select_domain_blocks(pg_pool: PgPool) -> HashSet<String> {
pub fn select_domain_blocks(user_id: i64, pg_pool: PgPool) -> HashSet<String> {
pg_pool
.0
.get()
.unwrap()
.query("SELECT domain FROM account_domain_blocks", &[])
.query(
"SELECT domain FROM account_domain_blocks WHERE account_id = $1",
&[&user_id],
)
.expect("Hard-coded query will return Some([0 or more rows])")
.iter()
.map(|row| row.get(0))

View File

@ -32,11 +32,12 @@ fn parse_query() -> BoxedFilter<(Query,)> {
.boxed()
}
pub fn extract_user_or_reject(pg_pool: PgPool) -> BoxedFilter<(User,)> {
pub fn extract_user_and_token_or_reject(pg_pool: PgPool) -> BoxedFilter<(User, Option<String>)> {
parse_query()
.and(query::OptionalAccessToken::from_ws_header())
.and_then(Query::update_access_token)
.and_then(move |q| User::from_query(q, pg_pool.clone()))
.and(query::OptionalAccessToken::from_ws_header())
.boxed()
}
@ -137,7 +138,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(public_local {
@ -156,7 +157,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(public_local_media {
@ -175,7 +176,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(hashtag {
@ -194,7 +195,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
test_public_endpoint!(hashtag_local {
@ -213,7 +214,7 @@ mod test {
},
logged_in: false,
blocks: Blocks::default(),
filter: Filter::Language,
allowed_langs: Filter::Language,
},
});
@ -233,7 +234,7 @@ mod test {
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
allowed_langs: Filter::NoFilter,
},
});
test_private_endpoint!(user_notification {
@ -252,7 +253,7 @@ mod test {
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::Notification,
allowed_langs: Filter::Notification,
},
});
test_private_endpoint!(direct {
@ -271,7 +272,7 @@ mod test {
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
allowed_langs: Filter::NoFilter,
},
});
test_private_endpoint!(list_valid_list {
@ -290,7 +291,7 @@ mod test {
},
logged_in: true,
blocks: Blocks::default(),
filter: Filter::NoFilter,
allowed_langs: Filter::NoFilter,
},
});

View File

@ -14,10 +14,14 @@
//!
//! Because `StreamManagers` are lightweight data structures that do not directly
//! communicate with Redis, it we create a new `ClientAgent` for
//! each new client connection (each in its own thread).
//! each new client connection (each in its own thread).use super::{message::Message, receiver::Receiver}
use super::{message::Message, receiver::Receiver};
use crate::{config, parse_client_request::user::User};
use futures::{Async, Poll};
use futures::{
Async::{self, NotReady, Ready},
Poll,
};
use std::sync;
use tokio::io::Error;
use uuid::Uuid;
@ -95,15 +99,20 @@ impl futures::stream::Stream for ClientAgent {
log::warn!("Polling the Receiver took: {:?}", start_time.elapsed());
};
let (filter, blocks) = (&self.current_user.allowed_langs, &self.current_user.blocks);
let allowed_langs = &self.current_user.allowed_langs;
let blocked_users = &self.current_user.blocks.user_blocks;
let blocked_domains = &self.current_user.blocks.domain_blocks;
const BLOCK_TOOT: Result<Async<Option<Message>>, Error> = Ok(NotReady);
match result {
Ok(Async::Ready(Some(json))) => match Message::from_json(json) {
Message::Update(status) if status.is_filtered_out(filter) => Ok(Async::NotReady),
Message::Update(status) if status.is_blocked(blocks) => Ok(Async::NotReady),
no_filtering_needed => Ok(Async::Ready(Some(no_filtering_needed))),
Message::Update(toot) if toot.language_not_allowed(allowed_langs) => BLOCK_TOOT,
Message::Update(toot) if toot.involves_blocked_user(blocked_users) => BLOCK_TOOT,
Message::Update(toot) if toot.from_blocked_domain(blocked_domains) => BLOCK_TOOT,
other_message => Ok(Ready(Some(other_message))),
},
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Ok(Ready(None)) => Ok(Ready(None)),
Ok(NotReady) => Ok(NotReady),
Err(e) => Err(e),
}
}

View File

@ -1,4 +1,5 @@
use crate::parse_client_request::user::Blocks;
use crate::log_fatal;
use log::{log_enabled, Level};
use serde_json::Value;
use std::{collections::HashSet, string::String};
use strum_macros::Display;
@ -13,17 +14,23 @@ pub enum Message {
}
#[derive(Debug, Clone)]
pub struct Status(pub Value);
pub struct Status(Value);
impl Message {
pub fn from_json(json: Value) -> Self {
match json["event"].as_str().unwrap() {
let event = json["event"]
.as_str()
.unwrap_or_else(|| log_fatal!("Could not process `event` in {:?}", json));
match event {
"update" => Self::Update(Status(json["payload"].clone())),
"conversation" => Self::Conversation(json["payload"].clone()),
"notification" => Self::Notification(json["payload"].clone()),
"delete" => Self::Delete(json["payload"].to_string()),
"filters_changed" => Self::FiltersChanged,
_ => unreachable!(),
unsupported_event => log_fatal!(
"Received an unsupported `event` type from Redis: {}",
unsupported_event
),
}
}
pub fn event(&self) -> String {
@ -40,59 +47,88 @@ impl Message {
}
impl Status {
pub fn get_originating_domain(&self) -> HashSet<String> {
let api = "originating Invariant Violation: JSON value does not conform to Mastodon API";
let mut originating_domain = HashSet::new();
// TODO: make this log an error instead of panicking.
originating_domain.insert(
self.0["account"]["acct"]
.as_str()
.expect(&api)
.split('@')
.nth(1)
.expect(&api)
.to_string(),
);
originating_domain
}
/// Returns `true` if the status is filtered out based on its language
pub fn language_not_allowed(&self, allowed_langs: &HashSet<String>) -> bool {
const ALLOW: bool = false;
const REJECT: bool = true;
pub fn get_involved_users(&self) -> HashSet<i64> {
let mut involved_users: HashSet<i64> = HashSet::new();
let msg = self.0.clone();
let api = "Invariant Violation: JSON value does not conform to Mastodon API";
involved_users.insert(msg["account"]["id"].str_to_i64().expect(&api));
if let Some(mentions) = msg["mentions"].as_array() {
for mention in mentions {
involved_users.insert(mention["id"].str_to_i64().expect(&api));
let reject_and_maybe_log = |toot_language| {
if log_enabled!(Level::Info) {
log::info!(
"Language `{toot_language}` is not in list `{allowed_langs:?}`",
toot_language = toot_language,
allowed_langs = allowed_langs
);
log::info!("Filtering out toot from `{}`", &self.0["account"]["acct"],);
}
REJECT
};
if allowed_langs.is_empty() {
return ALLOW; // listing no allowed_langs results in allowing all languages
}
if let Some(replied_to_account) = msg["in_reply_to_account_id"].as_str() {
involved_users.insert(replied_to_account.parse().expect(&api));
match self.0["language"].as_str() {
Some(toot_language) if allowed_langs.contains(toot_language) => ALLOW,
Some(toot_language) => reject_and_maybe_log(toot_language),
None => ALLOW, // If toot language is null, toot is always allowed
}
if let Some(reblog) = msg["reblog"].as_object() {
involved_users.insert(reblog["account"]["id"].str_to_i64().expect(&api));
}
involved_users
}
pub fn is_filtered_out(&self, permitted_langs: &HashSet<String>) -> bool {
// TODO add logging
let toot_language = self.0["language"]
/// Returns `true` if this toot originated from a domain the User has blocked.
pub fn from_blocked_domain(&self, blocked_domains: &HashSet<String>) -> bool {
let full_username = self.0["account"]["acct"]
.as_str()
.expect("Valid language")
.to_string();
!{ permitted_langs.is_empty() || permitted_langs.contains(&toot_language) }
.unwrap_or_else(|| log_fatal!("Could not process `account.acct` in {:?}", self.0));
match full_username.split('@').nth(1) {
Some(originating_domain) => blocked_domains.contains(originating_domain),
None => false, // None means the user is on the local instance, which can't be blocked
}
}
/// Returns `true` if the status is blocked by _either_ domain blocks or _user_ blocks
pub fn is_blocked(&self, b: &Blocks) -> bool {
// TODO add logging
!{
b.domain_blocks.is_disjoint(&self.get_originating_domain())
&& b.user_blocks.is_disjoint(&self.get_involved_users())
}
/// Returns `true` if the User's list of blocked users includes a user involved in this toot.
///
/// A user is involved if they:
/// * Wrote this toot
/// * Are mentioned in this toot
/// * Wrote a toot that this toot is replying to (if any)
/// * Wrote the toot that this toot is boosting (if any)
pub fn involves_blocked_user(&self, blocked_users: &HashSet<i64>) -> bool {
let toot = self.0.clone();
let author_user = match toot["account"]["id"].str_to_i64() {
Ok(user_id) => vec![user_id].into_iter(),
Err(_) => log_fatal!("Could not process `account.id` in {:?}", toot),
};
let mentioned_users = (match &toot["mentions"] {
Value::Array(inner) => inner,
_ => log_fatal!("Could not process `mentions` in {:?}", toot),
})
.into_iter()
.map(|mention| match mention["id"].str_to_i64() {
Ok(user_id) => user_id,
Err(_) => log_fatal!("Could not process `id` field of mention in {:?}", toot),
});
let replied_to_user = match toot["in_reply_to_account_id"].str_to_i64() {
Ok(user_id) => vec![user_id].into_iter(),
Err(_) => vec![].into_iter(), // no error; just no replied_to_user
};
let boosted_user = match toot["reblog"].as_object() {
Some(boosted_user) => match boosted_user["account"]["id"].str_to_i64() {
Ok(user_id) => vec![user_id].into_iter(),
Err(_) => log_fatal!("Could not process `reblog.account.id` in {:?}", toot),
},
None => vec![].into_iter(), // no error; just no boosted_user
};
let involved_users = author_user
.chain(mentioned_users)
.chain(replied_to_user)
.chain(boosted_user)
.collect::<HashSet<i64>>();
involved_users.is_disjoint(blocked_users)
}
}
@ -102,10 +138,6 @@ trait ConvertValue {
impl ConvertValue for Value {
fn str_to_i64(&self) -> Result<i64, Box<dyn std::error::Error>> {
Ok(self
.as_str()
.ok_or(format!("{} is not a string", &self))?
.parse()
.map_err(|_| "Could not parse str")?)
Ok(self.as_str().ok_or("none_err")?.parse()?)
}
}

View File

@ -56,9 +56,8 @@ pub fn send_updates_to_ws(
}),
);
let (tl, email, id) = (
let (tl, id) = (
client_agent.current_user.target_timeline.clone(),
client_agent.current_user.email.clone(),
client_agent.current_user.id,
);
// Yield new events for as long as the client is still connected
@ -76,7 +75,7 @@ pub fn send_updates_to_ws(
futures::future::ok(false)
}
Err(e) => {
log::warn!("Error in TL {}\nfor user: {}({})\n{}", tl, email, id, e);
log::warn!("Error in TL {}\nfor user: #{}\n{}", tl, id, e);
futures::future::ok(false)
}
});