//! Postgres queries use super::err; use super::timeline::{Scope, UserData}; use crate::config; use crate::event::Id; use ::postgres; use hashbrown::HashSet; use r2d2_postgres::PostgresConnectionManager; use std::convert::TryFrom; #[derive(Clone, Debug)] pub struct PgPool { pub conn: r2d2::Pool>, whitelist_mode: bool, } type Result = std::result::Result; type Rejectable = std::result::Result; impl PgPool { pub fn new(pg_cfg: &config::Postgres, whitelist_mode: bool) -> Result { let mut cfg = postgres::Config::new(); cfg.user(&pg_cfg.user) .host(&*pg_cfg.host.to_string()) .port(*pg_cfg.port) .dbname(&pg_cfg.database); if let Some(password) = &*pg_cfg.password { cfg.password(password); }; cfg.connect(postgres::NoTls)?; // Test connection, letting us immediately exit with an error // when Postgres isn't running instead of timing out below let manager = PostgresConnectionManager::new(cfg, postgres::NoTls); let pool = r2d2::Pool::builder().max_size(10).build(manager)?; Ok(Self { conn: pool, whitelist_mode, }) } pub fn select_user(self, token: &Option) -> Rejectable { let mut conn = self.conn.get().map_err(warp::reject::custom)?; if let Some(token) = token { let query_rows = conn .query(" SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1", &[&token.to_owned()], ).map_err(warp::reject::custom)?; if let Some(result_columns) = query_rows.get(0) { let id = Id(result_columns.get(1)); let allowed_langs = result_columns .try_get::<_, Vec<_>>(2) .unwrap_or_default() .into_iter() .collect(); let mut scopes: HashSet = result_columns .get::<_, String>(3) .split(' ') .filter_map(|scope| Scope::try_from(scope).ok()) .collect(); // We don't need to separately track read auth - it's just all three others if scopes.contains(&Scope::Read) { scopes = vec![Scope::Statuses, Scope::Notifications, Scope::Lists] .into_iter() .collect() } Ok(UserData { id, allowed_langs, scopes, }) } else { Err(warp::reject::custom("Error: Invalid access token")) } } else if self.whitelist_mode { Err(warp::reject::custom("Error: Invalid access token")) } else { Ok(UserData::public()) } } pub fn select_hashtag_id(self, tag_name: &str) -> Rejectable { let mut conn = self.conn.get().map_err(warp::reject::custom)?; conn.query("SELECT id FROM tags WHERE name = $1 LIMIT 1", &[&tag_name]) .map_err(warp::reject::custom)? .get(0) .map(|row| row.get(0)) .ok_or_else(|| warp::reject::custom("Error: Hashtag does not exist.")) } /// Query Postgres for everyone the user has blocked or muted /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocked_users(self, user_id: Id) -> Rejectable> { let mut conn = self.conn.get().map_err(warp::reject::custom)?; conn.query( "SELECT target_account_id FROM blocks WHERE account_id = $1 UNION SELECT target_account_id FROM mutes WHERE account_id = $1", &[&*user_id], ) .map_err(warp::reject::custom)? .iter() .map(|row| Ok(Id(row.get(0)))) .collect() } /// Query Postgres for everyone who has blocked the user /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocking_users(self, user_id: Id) -> Rejectable> { let mut conn = self.conn.get().map_err(warp::reject::custom)?; conn.query( "SELECT account_id FROM blocks WHERE target_account_id = $1", &[&*user_id], ) .map_err(warp::reject::custom)? .iter() .map(|row| Ok(Id(row.get(0)))) .collect() } /// Query Postgres for all current domain blocks /// /// **NOTE**: because we check this when the user connects, it will not include any blocks /// the user adds until they refresh/reconnect. pub fn select_blocked_domains(self, user_id: Id) -> Rejectable> { let mut conn = self.conn.get().map_err(warp::reject::custom)?; conn.query( "SELECT domain FROM account_domain_blocks WHERE account_id = $1", &[&*user_id], ) .map_err(warp::reject::custom)? .iter() .map(|row| Ok(row.get(0))) .collect() } /// Test whether a user owns a list pub fn user_owns_list(self, user_id: Id, list_id: i64) -> Rejectable { let mut conn = self.conn.get().map_err(warp::reject::custom)?; // For the Postgres query, `id` = list number; `account_id` = user.id let rows = &conn .query( "SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1", &[&list_id], ) .map_err(warp::reject::custom)?; Ok(rows.get(0).map_or(false, |row| Id(row.get(1)) == user_id)) } }