mirror of https://github.com/mastodon/flodgatt
Refactor scope managment to use enum
This commit is contained in:
parent
f3d20153e5
commit
7dafa834c1
|
@ -10,27 +10,12 @@ use super::query::Query;
|
|||
use std::collections::HashSet;
|
||||
use warp::reject::Rejection;
|
||||
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct OauthScope {
|
||||
pub all: bool,
|
||||
pub statuses: bool,
|
||||
pub notify: bool,
|
||||
pub lists: bool,
|
||||
}
|
||||
impl From<Vec<String>> for OauthScope {
|
||||
fn from(scope_list: Vec<String>) -> Self {
|
||||
let mut oauth_scope = OauthScope::default();
|
||||
for scope in scope_list {
|
||||
match scope.as_str() {
|
||||
"read" => oauth_scope.all = true,
|
||||
"read:statuses" => oauth_scope.statuses = true,
|
||||
"read:notifications" => oauth_scope.notify = true,
|
||||
"read:lists" => oauth_scope.lists = true,
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
oauth_scope
|
||||
}
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum Scope {
|
||||
All,
|
||||
Statuses,
|
||||
Notifications,
|
||||
Lists,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq)]
|
||||
|
@ -44,7 +29,7 @@ pub struct Blocks {
|
|||
pub struct User {
|
||||
pub target_timeline: String,
|
||||
pub id: i64,
|
||||
pub scopes: OauthScope,
|
||||
pub scopes: HashSet<Scope>,
|
||||
pub logged_in: bool,
|
||||
pub allowed_langs: HashSet<String>,
|
||||
pub blocks: Blocks,
|
||||
|
@ -54,7 +39,7 @@ impl Default for User {
|
|||
fn default() -> Self {
|
||||
Self {
|
||||
id: -1,
|
||||
scopes: OauthScope::default(),
|
||||
scopes: HashSet::new(),
|
||||
logged_in: false,
|
||||
target_timeline: String::new(),
|
||||
allowed_langs: HashSet::new(),
|
||||
|
@ -63,12 +48,6 @@ impl Default for User {
|
|||
}
|
||||
}
|
||||
|
||||
// impl fmt::Display for User {
|
||||
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
// write!(f, r##"User {} "##)
|
||||
// }
|
||||
// }
|
||||
|
||||
impl User {
|
||||
pub fn from_query(q: Query, pool: PgPool) -> Result<Self, Rejection> {
|
||||
let token = q.access_token.clone();
|
||||
|
@ -80,11 +59,13 @@ impl User {
|
|||
user = user.set_timeline_and_filter(q, pool.clone())?;
|
||||
user.blocks.user_blocks = postgres::select_user_blocks(user.id, pool.clone());
|
||||
user.blocks.domain_blocks = postgres::select_domain_blocks(user.id, pool.clone());
|
||||
log::info!("Creating user: {:#?}", user);
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
fn set_timeline_and_filter(self, q: Query, pool: PgPool) -> Result<Self, Rejection> {
|
||||
let (read_scope, f) = (self.scopes.clone(), self.allowed_langs.clone());
|
||||
use Scope::*;
|
||||
let (filter, target_timeline) = match q.stream.as_ref() {
|
||||
// Public endpoints:
|
||||
tl @ "public" | tl @ "public:local" if q.media => (f, format!("{}:media", tl)),
|
||||
|
@ -94,18 +75,18 @@ impl User {
|
|||
// Hashtag endpoints:
|
||||
tl @ "hashtag" | tl @ "hashtag:local" => (f, format!("{}:{}", tl, q.hashtag)),
|
||||
// Private endpoints: User:
|
||||
"user" if self.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
"user" if self.logged_in && read_scope.contains(&Statuses) => {
|
||||
(HashSet::new(), format!("{}", self.id))
|
||||
}
|
||||
"user:notification" if self.logged_in && (read_scope.all || read_scope.notify) => {
|
||||
"user:notification" if self.logged_in && read_scope.contains(&Notifications) => {
|
||||
(HashSet::new(), format!("{}", self.id))
|
||||
}
|
||||
// List endpoint:
|
||||
"list" if self.owns_list(q.list, pool) && (read_scope.all || read_scope.lists) => {
|
||||
"list" if self.owns_list(q.list, pool) && read_scope.contains(&Lists) => {
|
||||
(HashSet::new(), format!("list:{}", q.list))
|
||||
}
|
||||
// Direct endpoint:
|
||||
"direct" if self.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
"direct" if self.logged_in && read_scope.contains(&Statuses) => {
|
||||
(HashSet::new(), "direct".to_string())
|
||||
}
|
||||
// Reject unathorized access attempts for private endpoints
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
//! Postgres queries
|
||||
use crate::{
|
||||
config,
|
||||
parse_client_request::user::{OauthScope, User},
|
||||
parse_client_request::user::{Scope, User},
|
||||
};
|
||||
use ::postgres;
|
||||
use r2d2_postgres::PostgresConnectionManager;
|
||||
|
@ -51,11 +51,25 @@ LIMIT 1",
|
|||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
if let Some(result_columns) = query_rows.get(0) {
|
||||
let scope_vec: Vec<String> = result_columns
|
||||
let mut scopes: HashSet<Scope> = result_columns
|
||||
.get::<_, String>(3)
|
||||
.split(' ')
|
||||
.map(|s| s.to_owned())
|
||||
.filter_map(|scope| match scope {
|
||||
"read" => Some(Scope::All),
|
||||
"read:statuses" => Some(Scope::Statuses),
|
||||
"read:notifications" => Some(Scope::Notifications),
|
||||
"read:lists" => Some(Scope::Lists),
|
||||
unexpected => {
|
||||
log::warn!("Unable to parse scope `{}`, ignoring it.", unexpected);
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
if scopes.remove(&Scope::All) {
|
||||
scopes.insert(Scope::Statuses);
|
||||
scopes.insert(Scope::Notifications);
|
||||
scopes.insert(Scope::Lists);
|
||||
}
|
||||
let mut allowed_langs = HashSet::new();
|
||||
if let Ok(langs_vec) = result_columns.try_get::<_, Vec<String>>(2) {
|
||||
for lang in langs_vec {
|
||||
|
@ -65,7 +79,7 @@ LIMIT 1",
|
|||
|
||||
Ok(User {
|
||||
id: result_columns.get(1),
|
||||
scopes: OauthScope::from(scope_vec),
|
||||
scopes,
|
||||
logged_in: true,
|
||||
allowed_langs,
|
||||
..User::default()
|
||||
|
|
|
@ -99,17 +99,20 @@ impl futures::stream::Stream for ClientAgent {
|
|||
log::warn!("Polling the Receiver took: {:?}", start_time.elapsed());
|
||||
};
|
||||
|
||||
let allowed_langs = &self.current_user.allowed_langs;
|
||||
let blocked_users = &self.current_user.blocks.user_blocks;
|
||||
let blocked_domains = &self.current_user.blocks.domain_blocks;
|
||||
const BLOCK_TOOT: Result<Async<Option<Message>>, Error> = Ok(NotReady);
|
||||
|
||||
let (allowed_langs, blocks) = (&self.current_user.allowed_langs, &self.current_user.blocks);
|
||||
let (blocked_users, blocked_domains) = (&blocks.user_blocks, &blocks.domain_blocks);
|
||||
let (send, block) = (|msg| Ok(Ready(Some(msg))), Ok(NotReady));
|
||||
use Message::*;
|
||||
match result {
|
||||
Ok(Async::Ready(Some(json))) => match Message::from_json(json) {
|
||||
Message::Update(toot) if toot.language_not_allowed(allowed_langs) => BLOCK_TOOT,
|
||||
Message::Update(toot) if toot.involves_blocked_user(blocked_users) => BLOCK_TOOT,
|
||||
Message::Update(toot) if toot.from_blocked_domain(blocked_domains) => BLOCK_TOOT,
|
||||
other_message => Ok(Ready(Some(other_message))),
|
||||
Update(status) if status.language_not_allowed(allowed_langs) => block,
|
||||
Update(status) if status.involves_blocked_user(blocked_users) => block,
|
||||
Update(status) if status.from_blocked_domain(blocked_domains) => block,
|
||||
Update(status) => send(Update(status)),
|
||||
Notification(notification) => send(Notification(notification)),
|
||||
Conversation(notification) => send(Conversation(notification)),
|
||||
Delete(status_id) => send(Delete(status_id)),
|
||||
FiltersChanged => send(FiltersChanged),
|
||||
},
|
||||
Ok(Ready(None)) => Ok(Ready(None)),
|
||||
Ok(NotReady) => Ok(NotReady),
|
||||
|
|
|
@ -34,7 +34,7 @@ impl Message {
|
|||
}
|
||||
}
|
||||
pub fn event(&self) -> String {
|
||||
format!("{}", self)
|
||||
format!("{}", self).to_lowercase()
|
||||
}
|
||||
pub fn payload(&self) -> String {
|
||||
match self {
|
||||
|
@ -72,6 +72,7 @@ impl Status {
|
|||
None => ALLOW, // If toot language is null, toot is always allowed
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if this toot originated from a domain the User has blocked.
|
||||
pub fn from_blocked_domain(&self, blocked_domains: &HashSet<String>) -> bool {
|
||||
let full_username = self.0["account"]["acct"]
|
||||
|
@ -93,6 +94,8 @@ impl Status {
|
|||
/// * Wrote the toot that this toot is boosting (if any)
|
||||
pub fn involves_blocked_user(&self, blocked_users: &HashSet<i64>) -> bool {
|
||||
let toot = self.0.clone();
|
||||
const ALLOW: bool = false;
|
||||
const REJECT: bool = true;
|
||||
|
||||
let author_user = match toot["account"]["id"].str_to_i64() {
|
||||
Ok(user_id) => vec![user_id].into_iter(),
|
||||
|
@ -128,7 +131,11 @@ impl Status {
|
|||
.chain(boosted_user)
|
||||
.collect::<HashSet<i64>>();
|
||||
|
||||
involved_users.is_disjoint(blocked_users)
|
||||
if involved_users.is_disjoint(blocked_users) {
|
||||
ALLOW
|
||||
} else {
|
||||
REJECT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue