Refactor scope managment to use enum

This commit is contained in:
Daniel Sockwell 2020-03-15 16:54:16 -04:00
parent f3d20153e5
commit 7dafa834c1
4 changed files with 53 additions and 48 deletions

View File

@ -10,27 +10,12 @@ use super::query::Query;
use std::collections::HashSet;
use warp::reject::Rejection;
#[derive(Clone, Debug, Default, PartialEq)]
pub struct OauthScope {
pub all: bool,
pub statuses: bool,
pub notify: bool,
pub lists: bool,
}
impl From<Vec<String>> for OauthScope {
fn from(scope_list: Vec<String>) -> Self {
let mut oauth_scope = OauthScope::default();
for scope in scope_list {
match scope.as_str() {
"read" => oauth_scope.all = true,
"read:statuses" => oauth_scope.statuses = true,
"read:notifications" => oauth_scope.notify = true,
"read:lists" => oauth_scope.lists = true,
_ => (),
}
}
oauth_scope
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
All,
Statuses,
Notifications,
Lists,
}
#[derive(Clone, Default, Debug, PartialEq)]
@ -44,7 +29,7 @@ pub struct Blocks {
pub struct User {
pub target_timeline: String,
pub id: i64,
pub scopes: OauthScope,
pub scopes: HashSet<Scope>,
pub logged_in: bool,
pub allowed_langs: HashSet<String>,
pub blocks: Blocks,
@ -54,7 +39,7 @@ impl Default for User {
fn default() -> Self {
Self {
id: -1,
scopes: OauthScope::default(),
scopes: HashSet::new(),
logged_in: false,
target_timeline: String::new(),
allowed_langs: HashSet::new(),
@ -63,12 +48,6 @@ impl Default for User {
}
}
// impl fmt::Display for User {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// write!(f, r##"User {} "##)
// }
// }
impl User {
pub fn from_query(q: Query, pool: PgPool) -> Result<Self, Rejection> {
let token = q.access_token.clone();
@ -80,11 +59,13 @@ impl User {
user = user.set_timeline_and_filter(q, pool.clone())?;
user.blocks.user_blocks = postgres::select_user_blocks(user.id, pool.clone());
user.blocks.domain_blocks = postgres::select_domain_blocks(user.id, pool.clone());
log::info!("Creating user: {:#?}", user);
Ok(user)
}
fn set_timeline_and_filter(self, q: Query, pool: PgPool) -> Result<Self, Rejection> {
let (read_scope, f) = (self.scopes.clone(), self.allowed_langs.clone());
use Scope::*;
let (filter, target_timeline) = match q.stream.as_ref() {
// Public endpoints:
tl @ "public" | tl @ "public:local" if q.media => (f, format!("{}:media", tl)),
@ -94,18 +75,18 @@ impl User {
// Hashtag endpoints:
tl @ "hashtag" | tl @ "hashtag:local" => (f, format!("{}:{}", tl, q.hashtag)),
// Private endpoints: User:
"user" if self.logged_in && (read_scope.all || read_scope.statuses) => {
"user" if self.logged_in && read_scope.contains(&Statuses) => {
(HashSet::new(), format!("{}", self.id))
}
"user:notification" if self.logged_in && (read_scope.all || read_scope.notify) => {
"user:notification" if self.logged_in && read_scope.contains(&Notifications) => {
(HashSet::new(), format!("{}", self.id))
}
// List endpoint:
"list" if self.owns_list(q.list, pool) && (read_scope.all || read_scope.lists) => {
"list" if self.owns_list(q.list, pool) && read_scope.contains(&Lists) => {
(HashSet::new(), format!("list:{}", q.list))
}
// Direct endpoint:
"direct" if self.logged_in && (read_scope.all || read_scope.statuses) => {
"direct" if self.logged_in && read_scope.contains(&Statuses) => {
(HashSet::new(), "direct".to_string())
}
// Reject unathorized access attempts for private endpoints

View File

@ -1,7 +1,7 @@
//! Postgres queries
use crate::{
config,
parse_client_request::user::{OauthScope, User},
parse_client_request::user::{Scope, User},
};
use ::postgres;
use r2d2_postgres::PostgresConnectionManager;
@ -51,11 +51,25 @@ LIMIT 1",
)
.expect("Hard-coded query will return Some([0 or more rows])");
if let Some(result_columns) = query_rows.get(0) {
let scope_vec: Vec<String> = result_columns
let mut scopes: HashSet<Scope> = result_columns
.get::<_, String>(3)
.split(' ')
.map(|s| s.to_owned())
.filter_map(|scope| match scope {
"read" => Some(Scope::All),
"read:statuses" => Some(Scope::Statuses),
"read:notifications" => Some(Scope::Notifications),
"read:lists" => Some(Scope::Lists),
unexpected => {
log::warn!("Unable to parse scope `{}`, ignoring it.", unexpected);
None
}
})
.collect();
if scopes.remove(&Scope::All) {
scopes.insert(Scope::Statuses);
scopes.insert(Scope::Notifications);
scopes.insert(Scope::Lists);
}
let mut allowed_langs = HashSet::new();
if let Ok(langs_vec) = result_columns.try_get::<_, Vec<String>>(2) {
for lang in langs_vec {
@ -65,7 +79,7 @@ LIMIT 1",
Ok(User {
id: result_columns.get(1),
scopes: OauthScope::from(scope_vec),
scopes,
logged_in: true,
allowed_langs,
..User::default()

View File

@ -99,17 +99,20 @@ impl futures::stream::Stream for ClientAgent {
log::warn!("Polling the Receiver took: {:?}", start_time.elapsed());
};
let allowed_langs = &self.current_user.allowed_langs;
let blocked_users = &self.current_user.blocks.user_blocks;
let blocked_domains = &self.current_user.blocks.domain_blocks;
const BLOCK_TOOT: Result<Async<Option<Message>>, Error> = Ok(NotReady);
let (allowed_langs, blocks) = (&self.current_user.allowed_langs, &self.current_user.blocks);
let (blocked_users, blocked_domains) = (&blocks.user_blocks, &blocks.domain_blocks);
let (send, block) = (|msg| Ok(Ready(Some(msg))), Ok(NotReady));
use Message::*;
match result {
Ok(Async::Ready(Some(json))) => match Message::from_json(json) {
Message::Update(toot) if toot.language_not_allowed(allowed_langs) => BLOCK_TOOT,
Message::Update(toot) if toot.involves_blocked_user(blocked_users) => BLOCK_TOOT,
Message::Update(toot) if toot.from_blocked_domain(blocked_domains) => BLOCK_TOOT,
other_message => Ok(Ready(Some(other_message))),
Update(status) if status.language_not_allowed(allowed_langs) => block,
Update(status) if status.involves_blocked_user(blocked_users) => block,
Update(status) if status.from_blocked_domain(blocked_domains) => block,
Update(status) => send(Update(status)),
Notification(notification) => send(Notification(notification)),
Conversation(notification) => send(Conversation(notification)),
Delete(status_id) => send(Delete(status_id)),
FiltersChanged => send(FiltersChanged),
},
Ok(Ready(None)) => Ok(Ready(None)),
Ok(NotReady) => Ok(NotReady),

View File

@ -34,7 +34,7 @@ impl Message {
}
}
pub fn event(&self) -> String {
format!("{}", self)
format!("{}", self).to_lowercase()
}
pub fn payload(&self) -> String {
match self {
@ -72,6 +72,7 @@ impl Status {
None => ALLOW, // If toot language is null, toot is always allowed
}
}
/// Returns `true` if this toot originated from a domain the User has blocked.
pub fn from_blocked_domain(&self, blocked_domains: &HashSet<String>) -> bool {
let full_username = self.0["account"]["acct"]
@ -93,6 +94,8 @@ impl Status {
/// * Wrote the toot that this toot is boosting (if any)
pub fn involves_blocked_user(&self, blocked_users: &HashSet<i64>) -> bool {
let toot = self.0.clone();
const ALLOW: bool = false;
const REJECT: bool = true;
let author_user = match toot["account"]["id"].str_to_i64() {
Ok(user_id) => vec![user_id].into_iter(),
@ -128,7 +131,11 @@ impl Status {
.chain(boosted_user)
.collect::<HashSet<i64>>();
involved_users.is_disjoint(blocked_users)
if involved_users.is_disjoint(blocked_users) {
ALLOW
} else {
REJECT
}
}
}