Improve parsing of values from Postgres

This commit is contained in:
Daniel Sockwell 2020-04-22 16:59:04 -04:00
parent bb5a601851
commit 63a6d0ba13
2 changed files with 49 additions and 38 deletions

View File

@ -39,10 +39,9 @@ impl PgPool {
cfg.connect(postgres::NoTls)?; // Test connection, letting us immediately exit with an error
// when Postgres isn't running instead of timing out below
let manager = PostgresConnectionManager::new(cfg, postgres::NoTls);
let pool = r2d2::Pool::builder().max_size(10).build(manager)?;
Ok(Self {
conn: pool,
conn: r2d2::Pool::builder().max_size(10).build(manager)?,
whitelist_mode,
})
}
@ -60,7 +59,7 @@ impl PgPool {
Err(reject::custom(Self::BAD_TOKEN))?;
};
let query_rows = conn
let rows = conn
.simple_query(&format!("
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
FROM oauth_access_tokens
@ -69,25 +68,25 @@ INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id
LIMIT 1", &token.to_owned())
).map_err(reject::custom)?;
let result_columns = match query_rows
.get(0)
.ok_or_else(|| reject::custom(Self::SERVER_ERR))?
{
postgres::SimpleQueryMessage::Row(row) => row,
let row = match rows.get(0) {
Some(postgres::SimpleQueryMessage::Row(row)) => row,
_ => Err(reject::custom(Self::PG_NULL))?, // Wildcard required by #[non_exhaustive]
};
let id = Id(get_col_or_reject(result_columns, 1)?
.parse()
.map_err(reject::custom)?);
let allowed_langs = result_columns
let id = Id(get_col_or_reject(row, 1)?.parse().map_err(reject::custom)?);
let allowed_langs: HashSet<_> = row
.try_get(2)
.unwrap_or_default()
.into_iter()
.map(String::from)
.collect();
.map_err(reject::custom)? // looks like `Some("{en,eo,es}")`
.map_or_else(HashSet::new, |str| {
str.trim_start_matches('{')
.trim_end_matches('}')
.split(',')
.map(String::from)
.collect()
});
let mut scopes: HashSet<Scope> = get_col_or_reject(result_columns, 3)?
let mut scopes: HashSet<Scope> = get_col_or_reject(row, 3)?
.split(' ')
.filter_map(|scope| Scope::try_from(scope).ok())
.collect();
@ -142,13 +141,13 @@ LIMIT 1", &token.to_owned())
))
.map_err(reject::custom)?
.iter()
.map(|msg| match msg {
postgres::SimpleQueryMessage::Row(row) => Ok(Id(get_col_or_reject(row, 0)?
.parse()
.map_err(reject::custom)?)),
_ => Ok(Id(0)),
.try_fold(HashSet::new(), |mut set, row| match row {
SimpleQueryMessage::Row(row) => {
set.insert(get_col_or_reject(row, 0)?.parse().map_err(reject::custom)?);
Ok(set)
}
_ => Ok(set),
})
.collect()
}
/// Query Postgres for everyone who has blocked the user
@ -163,13 +162,13 @@ LIMIT 1", &token.to_owned())
))
.map_err(reject::custom)?
.iter()
.map(|msg| match msg {
postgres::SimpleQueryMessage::Row(row) => Ok(Id(get_col_or_reject(row, 0)?
.parse()
.map_err(reject::custom)?)),
_ => Ok(Id(0)),
.try_fold(HashSet::new(), |mut set, row| match row {
SimpleQueryMessage::Row(row) => {
set.insert(get_col_or_reject(row, 0)?.parse().map_err(reject::custom)?);
Ok(set)
}
_ => Ok(set),
})
.collect()
}
/// Query Postgres for all current domain blocks
@ -184,11 +183,13 @@ LIMIT 1", &token.to_owned())
))
.map_err(reject::custom)?
.iter()
.map(|msg| match msg {
postgres::SimpleQueryMessage::Row(row) => Ok(get_col_or_reject(row, 0)?.to_string()),
_ => Ok(String::new()),
.try_fold(HashSet::new(), |mut set, row| match row {
SimpleQueryMessage::Row(row) => {
set.insert(get_col_or_reject(row, 0)?.to_string());
Ok(set)
}
_ => Ok(set),
})
.collect()
}
/// Test whether a user owns a list

View File

@ -53,6 +53,7 @@ impl Ws {
let incoming_events = self.ws_rx.clone().map_err(|_| ());
incoming_events.for_each(move |(tl, event)| {
// dbg!(&tl, &event);
if matches!(event, Event::Ping) {
self.send_msg(&event)?
} else if target_timeline == tl {
@ -67,8 +68,7 @@ impl Ws {
}
fn send_or_filter(&mut self, tl: Timeline, event: &Event, update: &impl Payload) -> Result<()> {
let blocks = &self.subscription.blocks;
let allowed_langs = &self.subscription.allowed_langs;
let (blocks, allowed_langs) = (&self.subscription.blocks, &self.subscription.allowed_langs);
const SKIP: Result<()> = Ok(());
match tl {
tl if tl.is_public()
@ -76,11 +76,21 @@ impl Ws {
&& !allowed_langs.is_empty()
&& !allowed_langs.contains(&update.language()) =>
{
log::info!("{:?} msg skipped - disallowed language", tl);
SKIP
}
tl if !blocks.blocked_users.is_disjoint(&update.involved_users()) => {
log::info!("{:?} msg skipped - involves blocked user", tl);
SKIP
}
tl if blocks.blocking_users.contains(update.author()) => {
log::info!("{:?} msg skipped - from blocking user", tl);
SKIP
}
tl if blocks.blocked_domains.contains(update.sent_from()) => {
log::info!("{:?} msg skipped - from blocked domain", tl);
SKIP
}
_ if !blocks.blocked_users.is_disjoint(&update.involved_users()) => SKIP,
_ if blocks.blocking_users.contains(update.author()) => SKIP,
_ if blocks.blocked_domains.contains(update.sent_from()) => SKIP,
_ => Ok(self.send_msg(&event)?),
}
}