//! Postgres queries use crate::{ config, parse_client_request::subscription::{Scope, UserData}, }; use ::postgres; use hashbrown::HashSet; use r2d2_postgres::PostgresConnectionManager; use warp::reject::Rejection; #[derive(Clone, Debug)] pub struct PgPool(pub r2d2::Pool>); impl PgPool { pub fn new(pg_cfg: config::PostgresConfig) -> Self { let mut cfg = postgres::Config::new(); cfg.user(&pg_cfg.user) .host(&*pg_cfg.host.to_string()) .port(*pg_cfg.port) .dbname(&pg_cfg.database); if let Some(password) = &*pg_cfg.password { cfg.password(password); }; let manager = PostgresConnectionManager::new(cfg, postgres::NoTls); let pool = r2d2::Pool::builder() .max_size(10) .build(manager) .expect("Can connect to local postgres"); Self(pool) } pub fn select_user(self, token: &str) -> Result { let mut conn = self.0.get().unwrap(); let query_rows = conn .query( " SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1", &[&token.to_owned()], ) .expect("Hard-coded query will return Some([0 or more rows])"); if let Some(result_columns) = query_rows.get(0) { let id = result_columns.get(1); let allowed_langs = result_columns .try_get::<_, Vec<_>>(2) .unwrap_or_else(|_| Vec::new()) .into_iter() .collect(); let mut scopes: HashSet = result_columns .get::<_, String>(3) .split(' ') .filter_map(|scope| match scope { "read" => Some(Scope::Read), "read:statuses" => Some(Scope::Statuses), "read:notifications" => Some(Scope::Notifications), "read:lists" => Some(Scope::Lists), "write" | "follow" => None, // ignore write scopes unexpected => { log::warn!("Ignoring unknown scope `{}`", unexpected); None } }) .collect(); // We don't need to separately track read auth - it's just all three others if scopes.remove(&Scope::Read) { scopes.insert(Scope::Statuses); scopes.insert(Scope::Notifications); scopes.insert(Scope::Lists); } Ok(UserData { id, allowed_langs, scopes, }) } else { Err(warp::reject::custom("Error: Invalid access token")) } } pub fn select_hashtag_id(self, tag_name: &str) -> Result { let mut conn = self.0.get().unwrap(); let rows = &conn .query( " SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name], ) .expect("Hard-coded query will return Some([0 or more rows])"); match rows.get(0) { Some(row) => Ok(row.get(0)), None => Err(warp::reject::custom("Error: Hashtag does not exist.")), } } /// Query Postgres for everyone the user has blocked or muted /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocked_users(self, user_id: i64) -> HashSet { self.0 .get() .unwrap() .query( " SELECT target_account_id FROM blocks WHERE account_id = $1 UNION SELECT target_account_id FROM mutes WHERE account_id = $1", &[&user_id], ) .expect("Hard-coded query will return Some([0 or more rows])") .iter() .map(|row| row.get(0)) .collect() } /// Query Postgres for everyone who has blocked the user /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocking_users(self, user_id: i64) -> HashSet { self.0 .get() .unwrap() .query( " SELECT account_id FROM blocks WHERE target_account_id = $1", &[&user_id], ) .expect("Hard-coded query will return Some([0 or more rows])") .iter() .map(|row| row.get(0)) .collect() } /// Query Postgres for all current domain blocks /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocked_domains(self, user_id: i64) -> HashSet { self.0 .get() .unwrap() .query( "SELECT domain FROM account_domain_blocks WHERE account_id = $1", &[&user_id], ) .expect("Hard-coded query will return Some([0 or more rows])") .iter() .map(|row| row.get(0)) .collect() } /// Test whether a user owns a list pub fn user_owns_list(self, user_id: i64, list_id: i64) -> bool { let mut conn = self.0.get().unwrap(); // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn .query( " SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", &[&list_id], ) .expect("Hard-coded query will return Some([0 or more rows])"); match rows.get(0) { None => false, Some(row) => { let list_owner_id: i64 = row.get(1); list_owner_id == user_id } } } }