mirror of https://github.com/mastodon/flodgatt
Initial cleanup/refactor
This commit is contained in:
parent
f3b86ddac8
commit
1732008840
|
@ -0,0 +1,54 @@
|
|||
//! Configuration settings for servers and databases
|
||||
use dotenv::dotenv;
|
||||
use log::warn;
|
||||
use std::{env, net, time};
|
||||
|
||||
/// Configure CORS for the API server
|
||||
pub fn cross_origin_resource_sharing() -> warp::filters::cors::Cors {
|
||||
warp::cors()
|
||||
.allow_any_origin()
|
||||
.allow_methods(vec!["GET", "OPTIONS"])
|
||||
.allow_headers(vec!["Authorization", "Accept", "Cache-Control"])
|
||||
}
|
||||
|
||||
/// Initialize logging and read values from `src/.env`
|
||||
pub fn logging_and_env() {
|
||||
pretty_env_logger::init();
|
||||
dotenv().ok();
|
||||
}
|
||||
|
||||
/// Configure Postgres and return a connection
|
||||
pub fn postgres() -> postgres::Connection {
|
||||
let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or_else(|_| {
|
||||
format!(
|
||||
"postgres://{}@localhost/mastodon_development",
|
||||
env::var("USER").unwrap_or_else(|_| {
|
||||
warn!("No USER env variable set. Connecting to Postgress with default `postgres` user");
|
||||
"postgres".to_owned()
|
||||
})
|
||||
)
|
||||
});
|
||||
postgres::Connection::connect(postgres_addr, postgres::TlsMode::None)
|
||||
.expect("Can connect to local Postgres")
|
||||
}
|
||||
|
||||
pub fn redis_addr() -> (net::TcpStream, net::TcpStream) {
|
||||
let redis_addr = env::var("REDIS_ADDR").unwrap_or_else(|_| "127.0.0.1:6379".to_string());
|
||||
let pubsub_connection = net::TcpStream::connect(&redis_addr).expect("Can connect to Redis");
|
||||
pubsub_connection
|
||||
.set_read_timeout(Some(time::Duration::from_millis(10)))
|
||||
.expect("Can set read timeout for Redis connection");
|
||||
let secondary_redis_connection =
|
||||
net::TcpStream::connect(&redis_addr).expect("Can connect to Redis");
|
||||
secondary_redis_connection
|
||||
.set_read_timeout(Some(time::Duration::from_millis(10)))
|
||||
.expect("Can set read timeout for Redis connection");
|
||||
(pubsub_connection, secondary_redis_connection)
|
||||
}
|
||||
|
||||
pub fn socket_address() -> net::SocketAddr {
|
||||
env::var("SERVER_ADDR")
|
||||
.unwrap_or_else(|_| "127.0.0.1:4000".to_owned())
|
||||
.parse()
|
||||
.expect("static string")
|
||||
}
|
|
@ -30,3 +30,7 @@ pub fn handle_errors(
|
|||
warp::http::StatusCode::UNAUTHORIZED,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn unauthorized_list() -> warp::reject::Rejection {
|
||||
warp::reject::custom("Error: Access to list not authorized")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
//! Streaming server for Mastodon
|
||||
//!
|
||||
//!
|
||||
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server
|
||||
//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to
|
||||
//! connect to the server with the API described [in Mastodon's public API
|
||||
//! documentation](https://docs.joinmastodon.org/api/streaming/).
|
||||
//!
|
||||
//! # Notes on data flow
|
||||
//! * **Client Request → Warp**:
|
||||
//! Warp filters for valid requests and parses request data. Based on that data, it generates a `User`
|
||||
//! representing the client that made the request with data from the client's request and from
|
||||
//! Postgres. The `User` is authenticated, if appropriate. Warp //! repeatedly polls the
|
||||
//! StreamManager for information relevant to the User.
|
||||
//!
|
||||
//! * **Warp → StreamManager**:
|
||||
//! A new `StreamManager` is created for each request. The `StreamManager` exists to manage concurrent
|
||||
//! access to the (single) `Receiver`, which it can access behind an `Arc<Mutex>`. The `StreamManager`
|
||||
//! polls the `Receiver` for any updates relevant to the current client. If there are updates, the
|
||||
//! `StreamManager` filters them with the client's filters and passes any matching updates up to Warp.
|
||||
//! The `StreamManager` is also responsible for sending `subscribe` commands to Redis (via the
|
||||
//! `Receiver`) when necessary.
|
||||
//!
|
||||
//! * **StreamManager → Receiver**:
|
||||
//! The Receiver receives data from Redis and stores it in a series of queues (one for each
|
||||
//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages
|
||||
//! relevant to that StreamManager and removes them from the queue.
|
||||
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod postgres;
|
||||
pub mod query;
|
||||
pub mod receiver;
|
||||
pub mod redis_cmd;
|
||||
pub mod stream_manager;
|
||||
pub mod timeline;
|
||||
pub mod user;
|
||||
pub mod ws;
|
196
src/main.rs
196
src/main.rs
|
@ -1,58 +1,20 @@
|
|||
//! Streaming server for Mastodon
|
||||
//!
|
||||
//!
|
||||
//! This server provides live, streaming updates for Mastodon clients. Specifically, when a server
|
||||
//! is running this sever, Mastodon clients can use either Server Sent Events or WebSockets to
|
||||
//! connect to the server with the API described [in the public API
|
||||
//! documentation](https://docs.joinmastodon.org/api/streaming/)
|
||||
//!
|
||||
//! # Notes on data flow
|
||||
//! * **Client Request → Warp**:
|
||||
//! Warp filters for valid requests and parses request data. Based on that data, it generates a `User`
|
||||
//! representing the client that made the request. The `User` is authenticated, if appropriate. Warp
|
||||
//! repeatedly polls the StreamManager for information relevant to the User.
|
||||
//!
|
||||
//! * **Warp → StreamManager**:
|
||||
//! A new `StreamManager` is created for each request. The `StreamManager` exists to manage concurrent
|
||||
//! access to the (single) `Receiver`, which it can access behind an `Arc<Mutex>`. The `StreamManager`
|
||||
//! polles the `Receiver` for any updates relvant to the current client. If there are updates, the
|
||||
//! `StreamManager` filters them with the client's filters and passes any matching updates up to Warp.
|
||||
//! The `StreamManager` is also responsible for sending `subscribe` commands to Redis (via the
|
||||
//! `Receiver`) when necessary.
|
||||
//!
|
||||
//! * **StreamManger → Receiver**:
|
||||
//! The Receiver receives data from Redis and stores it in a series of queues (one for each
|
||||
//! StreamManager). When (asynchronously) polled by the StreamManager, it sends back the messages
|
||||
//! relevant to that StreamManager and removes them from the queue.
|
||||
|
||||
pub mod error;
|
||||
pub mod query;
|
||||
pub mod receiver;
|
||||
pub mod redis_cmd;
|
||||
pub mod stream;
|
||||
pub mod timeline;
|
||||
pub mod user;
|
||||
pub mod ws;
|
||||
use dotenv::dotenv;
|
||||
use futures::stream::Stream;
|
||||
use futures::Async;
|
||||
use receiver::Receiver;
|
||||
use std::env;
|
||||
use std::net::SocketAddr;
|
||||
use stream::StreamManager;
|
||||
use user::{OauthScope::*, Scope, User};
|
||||
use warp::path;
|
||||
use warp::Filter as WarpFilter;
|
||||
use futures::{stream::Stream, Async};
|
||||
use ragequit::{
|
||||
any_of, config, error,
|
||||
stream_manager::StreamManager,
|
||||
timeline,
|
||||
user::{Filter::*, User},
|
||||
ws,
|
||||
};
|
||||
use warp::{ws::Ws2, Filter as WarpFilter};
|
||||
|
||||
fn main() {
|
||||
pretty_env_logger::init();
|
||||
dotenv().ok();
|
||||
config::logging_and_env();
|
||||
let stream_manager_sse = StreamManager::new();
|
||||
let stream_manager_ws = stream_manager_sse.clone();
|
||||
|
||||
let redis_updates = StreamManager::new(Receiver::new());
|
||||
let redis_updates_sse = redis_updates.blank_copy();
|
||||
let redis_updates_ws = redis_updates.blank_copy();
|
||||
|
||||
let routes = any_of!(
|
||||
// Server Sent Events
|
||||
let sse_routes = any_of!(
|
||||
// GET /api/v1/streaming/user/notification [private; notification filter]
|
||||
timeline::user_notifications(),
|
||||
// GET /api/v1/streaming/user [private; language filter]
|
||||
|
@ -77,12 +39,12 @@ fn main() {
|
|||
.untuple_one()
|
||||
.and(warp::sse())
|
||||
.map(move |timeline: String, user: User, sse: warp::sse::Sse| {
|
||||
let mut redis_stream = redis_updates_sse.configure_copy(&timeline, user);
|
||||
let mut stream_manager = stream_manager_sse.manage_new_timeline(&timeline, user);
|
||||
let event_stream = tokio::timer::Interval::new(
|
||||
std::time::Instant::now(),
|
||||
std::time::Duration::from_millis(100),
|
||||
)
|
||||
.filter_map(move |_| match redis_stream.poll() {
|
||||
.filter_map(move |_| match stream_manager.poll() {
|
||||
Ok(Async::Ready(Some(json_value))) => Some((
|
||||
warp::sse::event(json_value["event"].clone().to_string()),
|
||||
warp::sse::data(json_value["payload"].clone()),
|
||||
|
@ -94,86 +56,54 @@ fn main() {
|
|||
.with(warp::reply::with::header("Connection", "keep-alive"))
|
||||
.recover(error::handle_errors);
|
||||
|
||||
//let redis_updates_ws = StreamManager::new(Receiver::new());
|
||||
let websocket = path!("api" / "v1" / "streaming")
|
||||
.and(Scope::Public.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, Scope::Public))
|
||||
.and(warp::query())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.and(warp::ws2())
|
||||
.and_then(
|
||||
move |mut user: User,
|
||||
q: query::Stream,
|
||||
m: query::Media,
|
||||
h: query::Hashtag,
|
||||
l: query::List,
|
||||
ws: warp::ws::Ws2| {
|
||||
let scopes = user.scopes.clone();
|
||||
let timeline = match q.stream.as_ref() {
|
||||
// Public endpoints:
|
||||
tl @ "public" | tl @ "public:local" if m.is_truthy() => format!("{}:media", tl),
|
||||
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
|
||||
tl @ "public" | tl @ "public:local" => tl.to_string(),
|
||||
// Hashtag endpoints:
|
||||
// TODO: handle missing query
|
||||
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, h.tag),
|
||||
// Private endpoints: User
|
||||
"user"
|
||||
if user.id > 0
|
||||
&& (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) =>
|
||||
{
|
||||
format!("{}", user.id)
|
||||
}
|
||||
"user:notification"
|
||||
if user.id > 0
|
||||
&& (scopes.contains(&Read) || scopes.contains(&ReadNotifications)) =>
|
||||
{
|
||||
user = user.with_notification_filter();
|
||||
format!("{}", user.id)
|
||||
}
|
||||
// List endpoint:
|
||||
// TODO: handle missing query
|
||||
"list"
|
||||
if user.authorized_for_list(l.list).is_ok()
|
||||
&& (scopes.contains(&Read) || scopes.contains(&ReadList)) =>
|
||||
{
|
||||
format!("list:{}", l.list)
|
||||
}
|
||||
// WebSocket
|
||||
let websocket_routes = ws::websocket_routes()
|
||||
.and_then(move |mut user: User, q: ws::Query, ws: Ws2| {
|
||||
let read_scope = user.scopes.clone();
|
||||
let timeline = match q.stream.as_ref() {
|
||||
// Public endpoints:
|
||||
tl @ "public" | tl @ "public:local" if q.media => format!("{}:media", tl),
|
||||
tl @ "public:media" | tl @ "public:local:media" => tl.to_string(),
|
||||
tl @ "public" | tl @ "public:local" => tl.to_string(),
|
||||
// Hashtag endpoints:
|
||||
// TODO: handle missing query
|
||||
tl @ "hashtag" | tl @ "hashtag:local" => format!("{}:{}", tl, q.hashtag),
|
||||
// Private endpoints: User
|
||||
"user" if user.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
format!("{}", user.id)
|
||||
}
|
||||
"user:notification" if user.logged_in && (read_scope.all || read_scope.notify) => {
|
||||
user = user.set_filter(Notification);
|
||||
format!("{}", user.id)
|
||||
}
|
||||
// List endpoint:
|
||||
// TODO: handle missing query
|
||||
"list" if user.owns_list(q.list) && (read_scope.all || read_scope.lists) => {
|
||||
format!("list:{}", q.list)
|
||||
}
|
||||
// Direct endpoint:
|
||||
"direct" if user.logged_in && (read_scope.all || read_scope.statuses) => {
|
||||
"direct".to_string()
|
||||
}
|
||||
// Reject unathorized access attempts for private endpoints
|
||||
"user" | "user:notification" | "direct" | "list" => {
|
||||
return Err(warp::reject::custom("Error: Invalid Access Token"))
|
||||
}
|
||||
// Other endpoints don't exist:
|
||||
_ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")),
|
||||
};
|
||||
let token = user.access_token.clone();
|
||||
let stream_manager = stream_manager_ws.manage_new_timeline(&timeline, user);
|
||||
|
||||
// Direct endpoint:
|
||||
"direct"
|
||||
if user.id > 0
|
||||
&& (scopes.contains(&Read) || scopes.contains(&ReadStatuses)) =>
|
||||
{
|
||||
"direct".to_string()
|
||||
}
|
||||
// Reject unathorized access attempts for private endpoints
|
||||
"user" | "user:notification" | "direct" | "list" => {
|
||||
return Err(warp::reject::custom("Error: Invalid Access Token"))
|
||||
}
|
||||
// Other endpoints don't exist:
|
||||
_ => return Err(warp::reject::custom("Error: Nonexistent WebSocket query")),
|
||||
};
|
||||
let token = user.access_token.clone();
|
||||
let stream = redis_updates_ws.configure_copy(&timeline, user);
|
||||
|
||||
Ok((
|
||||
ws.on_upgrade(move |socket| ws::send_replies(socket, stream)),
|
||||
token,
|
||||
))
|
||||
},
|
||||
)
|
||||
Ok((
|
||||
ws.on_upgrade(move |socket| ws::send_replies(socket, stream_manager)),
|
||||
token,
|
||||
))
|
||||
})
|
||||
.map(|(reply, token)| warp::reply::with_header(reply, "sec-websocket-protocol", token));
|
||||
|
||||
let address: SocketAddr = env::var("SERVER_ADDR")
|
||||
.unwrap_or("127.0.0.1:4000".to_owned())
|
||||
.parse()
|
||||
.expect("static string");
|
||||
let cors = warp::cors()
|
||||
.allow_any_origin()
|
||||
.allow_methods(vec!["GET", "OPTIONS"])
|
||||
.allow_headers(vec!["Authorization", "Accept", "Cache-Control"]);
|
||||
warp::serve(websocket.or(routes).with(cors)).run(address);
|
||||
let cors = config::cross_origin_resource_sharing();
|
||||
let address = config::socket_address();
|
||||
|
||||
warp::serve(websocket_routes.or(sse_routes).with(cors)).run(address);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
//! Postgres queries
|
||||
use crate::config;
|
||||
|
||||
pub fn query_for_user_data(access_token: &str) -> (i64, Option<Vec<String>>, Vec<String>) {
|
||||
let conn = config::postgres();
|
||||
let query_result = conn
|
||||
.query(
|
||||
"
|
||||
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
|
||||
FROM
|
||||
oauth_access_tokens
|
||||
INNER JOIN users ON
|
||||
oauth_access_tokens.resource_owner_id = users.id
|
||||
WHERE oauth_access_tokens.token = $1
|
||||
AND oauth_access_tokens.revoked_at IS NULL
|
||||
LIMIT 1",
|
||||
&[&access_token.to_owned()],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
if !query_result.is_empty() {
|
||||
let only_row = query_result.get(0);
|
||||
let id: i64 = only_row.get(1);
|
||||
let scopes = only_row
|
||||
.get::<_, String>(3)
|
||||
.split(' ')
|
||||
.map(|s| s.to_owned())
|
||||
.collect();
|
||||
let langs: Option<Vec<String>> = only_row.get(2);
|
||||
(id, langs, scopes)
|
||||
} else {
|
||||
(-1, None, Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn query_list_owner(list_id: i64) -> Option<i64> {
|
||||
let conn = config::postgres();
|
||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||
let rows = &conn
|
||||
.query(
|
||||
"
|
||||
SELECT id, account_id
|
||||
FROM lists
|
||||
WHERE id = $1
|
||||
LIMIT 1",
|
||||
&[&list_id],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
if rows.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(rows.get(0).get(1))
|
||||
}
|
||||
}
|
235
src/receiver.rs
235
src/receiver.rs
|
@ -1,165 +1,155 @@
|
|||
//! Interfacing with Redis and stream the results on to the `StreamManager`
|
||||
use crate::redis_cmd;
|
||||
use crate::user::User;
|
||||
use futures::stream::Stream;
|
||||
//! Interface with Redis and stream the results to the `StreamManager`
|
||||
//! There is only one `Receiver`, which suggests that it's name is bad.
|
||||
//!
|
||||
//! **TODO**: Consider changing the name. Maybe RedisConnectionPool?
|
||||
//! There are many AsyncReadableStreams, though. How do they fit in?
|
||||
//! Figure this out ASAP.
|
||||
//! A new one is created every time the Receiver is polled
|
||||
use crate::{config, pubsub_cmd, redis_cmd};
|
||||
use futures::{Async, Poll};
|
||||
use log::info;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::env;
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpStream;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::{collections, io::Read, io::Write, net, time};
|
||||
use tokio::io::{AsyncRead, Error};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MsgQueue {
|
||||
messages: VecDeque<Value>,
|
||||
last_polled_at: Instant,
|
||||
redis_channel: String,
|
||||
}
|
||||
impl MsgQueue {
|
||||
fn new(redis_channel: impl std::fmt::Display) -> Self {
|
||||
let redis_channel = redis_channel.to_string();
|
||||
MsgQueue {
|
||||
messages: VecDeque::new(),
|
||||
last_polled_at: Instant::now(),
|
||||
redis_channel,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The item that streams from Redis and is polled by the `StreamManger`
|
||||
/// The item that streams from Redis and is polled by the `StreamManager`
|
||||
#[derive(Debug)]
|
||||
pub struct Receiver {
|
||||
pubsub_connection: TcpStream,
|
||||
secondary_redis_connection: TcpStream,
|
||||
pubsub_connection: net::TcpStream,
|
||||
secondary_redis_connection: net::TcpStream,
|
||||
tl: String,
|
||||
pub user: User,
|
||||
manager_id: Uuid,
|
||||
msg_queues: HashMap<Uuid, MsgQueue>,
|
||||
clients_per_timeline: HashMap<String, i32>,
|
||||
}
|
||||
impl Default for Receiver {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
msg_queues: collections::HashMap<Uuid, MsgQueue>,
|
||||
clients_per_timeline: collections::HashMap<String, i32>,
|
||||
}
|
||||
|
||||
impl Receiver {
|
||||
/// Create a new `Receiver`, with its own Redis connections (but, as yet, no
|
||||
/// active subscriptions).
|
||||
pub fn new() -> Self {
|
||||
let redis_addr = env::var("REDIS_ADDR").unwrap_or("127.0.0.1:6379".to_string());
|
||||
let pubsub_connection = TcpStream::connect(&redis_addr).expect("Can connect to Redis");
|
||||
pubsub_connection
|
||||
.set_read_timeout(Some(Duration::from_millis(10)))
|
||||
.expect("Can set read timeout for Redis connection");
|
||||
let secondary_redis_connection =
|
||||
TcpStream::connect(&redis_addr).expect("Can connect to Redis");
|
||||
secondary_redis_connection
|
||||
.set_read_timeout(Some(Duration::from_millis(10)))
|
||||
.expect("Can set read timeout for Redis connection");
|
||||
let (pubsub_connection, secondary_redis_connection) = config::redis_addr();
|
||||
Self {
|
||||
pubsub_connection,
|
||||
secondary_redis_connection,
|
||||
tl: String::new(),
|
||||
user: User::public(),
|
||||
manager_id: Uuid::new_v4(),
|
||||
msg_queues: HashMap::new(),
|
||||
clients_per_timeline: HashMap::new(),
|
||||
manager_id: Uuid::default(),
|
||||
msg_queues: collections::HashMap::new(),
|
||||
clients_per_timeline: collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the `StreamManager` that is currently polling the `Receiver`
|
||||
pub fn update(&mut self, id: Uuid, timeline: impl std::fmt::Display) {
|
||||
self.manager_id = id;
|
||||
/// Assigns the `Receiver` a new timeline to monitor and runs other
|
||||
/// first-time setup.
|
||||
///
|
||||
/// Importantly, this method calls `subscribe_or_unsubscribe_as_needed`,
|
||||
/// so Redis PubSub subscriptions are only updated when a new timeline
|
||||
/// comes under management for the first time.
|
||||
pub fn manage_new_timeline(&mut self, manager_id: Uuid, timeline: &str) {
|
||||
self.manager_id = manager_id;
|
||||
self.tl = timeline.to_string();
|
||||
let old_value = self
|
||||
.msg_queues
|
||||
.insert(self.manager_id, MsgQueue::new(timeline));
|
||||
// Consider removing/refactoring
|
||||
if let Some(value) = old_value {
|
||||
eprintln!(
|
||||
"Data was overwritten when it shouldn't have been. Old data was: {:#?}",
|
||||
value
|
||||
);
|
||||
}
|
||||
self.subscribe_or_unsubscribe_as_needed(timeline);
|
||||
}
|
||||
|
||||
/// Set the `Receiver`'s manager_id and target_timeline fields to the approprate
|
||||
/// value to be polled by the current `StreamManager`.
|
||||
pub fn configure_for_polling(&mut self, manager_id: Uuid, timeline: &str) {
|
||||
if &manager_id != &self.manager_id {
|
||||
//println!("New Manager: {}", &manager_id);
|
||||
}
|
||||
self.manager_id = manager_id;
|
||||
self.tl = timeline.to_string();
|
||||
}
|
||||
|
||||
/// Send a subscribe command to the Redis PubSub (if needed)
|
||||
pub fn maybe_subscribe(&mut self, tl: &str) {
|
||||
info!("Subscribing to {}", &tl);
|
||||
|
||||
let manager_id = self.manager_id;
|
||||
self.msg_queues.insert(manager_id, MsgQueue::new(tl));
|
||||
let current_clients = self
|
||||
.clients_per_timeline
|
||||
.entry(tl.to_string())
|
||||
.and_modify(|n| *n += 1)
|
||||
.or_insert(1);
|
||||
|
||||
if *current_clients == 1 {
|
||||
let subscribe_cmd = redis_cmd::pubsub("subscribe", tl);
|
||||
self.pubsub_connection
|
||||
.write_all(&subscribe_cmd)
|
||||
.expect("Can subscribe to Redis");
|
||||
let set_subscribed_cmd = redis_cmd::set(format!("subscribed:timeline:{}", tl), "1");
|
||||
self.secondary_redis_connection
|
||||
.write_all(&set_subscribed_cmd)
|
||||
.expect("Can set Redis");
|
||||
info!("Now subscribed to: {:#?}", &self.msg_queues);
|
||||
}
|
||||
}
|
||||
|
||||
/// Drop any PubSub subscriptions that don't have active clients
|
||||
pub fn unsubscribe_from_empty_channels(&mut self) {
|
||||
let mut timelines_with_fewer_clients = Vec::new();
|
||||
/// Drop any PubSub subscriptions that don't have active clients and check
|
||||
/// that there's a subscription to the current one. If there isn't, then
|
||||
/// subscribe to it.
|
||||
fn subscribe_or_unsubscribe_as_needed(&mut self, tl: &str) {
|
||||
let mut timelines_to_modify = Vec::new();
|
||||
timelines_to_modify.push((tl.to_owned(), 1));
|
||||
|
||||
// Keep only message queues that have been polled recently
|
||||
self.msg_queues.retain(|_id, msg_queue| {
|
||||
if msg_queue.last_polled_at.elapsed() < Duration::from_secs(30) {
|
||||
if msg_queue.last_polled_at.elapsed() < time::Duration::from_secs(30) {
|
||||
true
|
||||
} else {
|
||||
timelines_with_fewer_clients.push(msg_queue.redis_channel.clone());
|
||||
let timeline = msg_queue.redis_channel.clone();
|
||||
timelines_to_modify.push((timeline, -1));
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
// Record the lower number of clients subscribed to that channel
|
||||
for timeline in timelines_with_fewer_clients {
|
||||
for (timeline, numerical_change) in timelines_to_modify {
|
||||
let mut need_to_subscribe = false;
|
||||
let count_of_subscribed_clients = self
|
||||
.clients_per_timeline
|
||||
.entry(timeline.clone())
|
||||
.and_modify(|n| *n -= 1)
|
||||
.or_insert(0);
|
||||
.entry(timeline.to_owned())
|
||||
.and_modify(|n| *n += numerical_change)
|
||||
.or_insert_with(|| {
|
||||
need_to_subscribe = true;
|
||||
1
|
||||
});
|
||||
// If no clients, unsubscribe from the channel
|
||||
if *count_of_subscribed_clients <= 0 {
|
||||
self.unsubscribe(&timeline);
|
||||
info!("Sent unsubscribe command");
|
||||
pubsub_cmd!("unsubscribe", self, timeline.clone());
|
||||
}
|
||||
if need_to_subscribe {
|
||||
info!("Sent subscribe command");
|
||||
pubsub_cmd!("subscribe", self, timeline.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send an unsubscribe command to the Redis PubSub
|
||||
pub fn unsubscribe(&mut self, tl: &str) {
|
||||
let unsubscribe_cmd = redis_cmd::pubsub("unsubscribe", tl);
|
||||
info!("Unsubscribing from {}", &tl);
|
||||
self.pubsub_connection
|
||||
.write_all(&unsubscribe_cmd)
|
||||
.expect("Can unsubscribe from Redis");
|
||||
let set_subscribed_cmd = redis_cmd::set(format!("subscribed:timeline:{}", tl), "0");
|
||||
self.secondary_redis_connection
|
||||
.write_all(&set_subscribed_cmd)
|
||||
.expect("Can set Redis");
|
||||
info!("Now subscribed only to: {:#?}", &self.msg_queues);
|
||||
fn log_number_of_msgs_in_queue(&self) {
|
||||
let messages_waiting = self
|
||||
.msg_queues
|
||||
.get(&self.manager_id)
|
||||
.expect("Guaranteed by match block")
|
||||
.messages
|
||||
.len();
|
||||
match messages_waiting {
|
||||
number if number > 10 => {
|
||||
log::error!("{} messages waiting in the queue", messages_waiting)
|
||||
}
|
||||
_ => log::info!("{} messages waiting in the queue", messages_waiting),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Stream for Receiver {
|
||||
impl Default for Receiver {
|
||||
fn default() -> Self {
|
||||
Receiver::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl futures::stream::Stream for Receiver {
|
||||
type Item = Value;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Value>, Self::Error> {
|
||||
let mut buffer = vec![0u8; 3000];
|
||||
info!("Being polled by: {}", self.manager_id);
|
||||
let timeline = self.tl.clone();
|
||||
|
||||
// Record current time as last polled time
|
||||
self.msg_queues
|
||||
.entry(self.manager_id)
|
||||
.and_modify(|msg_queue| msg_queue.last_polled_at = Instant::now());
|
||||
.and_modify(|msg_queue| msg_queue.last_polled_at = time::Instant::now());
|
||||
|
||||
// Add any incomming messages to the back of the relevant `msg_queues`
|
||||
// NOTE: This could be more/other than the `msg_queue` currently being polled
|
||||
let mut async_stream = AsyncReadableStream(&mut self.pubsub_connection);
|
||||
let mut async_stream = AsyncReadableStream::new(&mut self.pubsub_connection);
|
||||
if let Async::Ready(num_bytes_read) = async_stream.poll_read(&mut buffer)? {
|
||||
let raw_redis_response = &String::from_utf8_lossy(&buffer[..num_bytes_read]);
|
||||
// capture everything between `{` and `}` as potential JSON
|
||||
|
@ -183,11 +173,14 @@ impl Stream for Receiver {
|
|||
match self
|
||||
.msg_queues
|
||||
.entry(self.manager_id)
|
||||
.or_insert_with(|| MsgQueue::new(timeline))
|
||||
.or_insert_with(|| MsgQueue::new(timeline.clone()))
|
||||
.messages
|
||||
.pop_front()
|
||||
{
|
||||
Some(value) => Ok(Async::Ready(Some(value))),
|
||||
Some(value) => {
|
||||
self.log_number_of_msgs_in_queue();
|
||||
Ok(Async::Ready(Some(value)))
|
||||
}
|
||||
_ => Ok(Async::NotReady),
|
||||
}
|
||||
}
|
||||
|
@ -195,12 +188,34 @@ impl Stream for Receiver {
|
|||
|
||||
impl Drop for Receiver {
|
||||
fn drop(&mut self) {
|
||||
let timeline = self.tl.clone();
|
||||
self.unsubscribe(&timeline);
|
||||
pubsub_cmd!("unsubscribe", self, self.tl.clone());
|
||||
}
|
||||
}
|
||||
|
||||
struct AsyncReadableStream<'a>(&'a mut TcpStream);
|
||||
#[derive(Debug, Clone)]
|
||||
struct MsgQueue {
|
||||
pub messages: collections::VecDeque<Value>,
|
||||
pub last_polled_at: time::Instant,
|
||||
pub redis_channel: String,
|
||||
}
|
||||
|
||||
impl MsgQueue {
|
||||
pub fn new(redis_channel: impl std::fmt::Display) -> Self {
|
||||
let redis_channel = redis_channel.to_string();
|
||||
MsgQueue {
|
||||
messages: collections::VecDeque::new(),
|
||||
last_polled_at: time::Instant::now(),
|
||||
redis_channel,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct AsyncReadableStream<'a>(&'a mut net::TcpStream);
|
||||
impl<'a> AsyncReadableStream<'a> {
|
||||
pub fn new(stream: &'a mut net::TcpStream) -> Self {
|
||||
AsyncReadableStream(stream)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Read for AsyncReadableStream<'a> {
|
||||
fn read(&mut self, buffer: &mut [u8]) -> Result<usize, std::io::Error> {
|
||||
|
|
|
@ -1,6 +1,26 @@
|
|||
//! Send raw TCP commands to the Redis server
|
||||
use std::fmt::Display;
|
||||
|
||||
/// Send a subscribe or unsubscribe to the Redis PubSub channel
|
||||
#[macro_export]
|
||||
macro_rules! pubsub_cmd {
|
||||
($cmd:expr, $self:expr, $tl:expr) => {{
|
||||
info!("Sending {} command to {}", $cmd, $tl);
|
||||
$self
|
||||
.pubsub_connection
|
||||
.write_all(&redis_cmd::pubsub($cmd, $tl))
|
||||
.expect("Can send command to Redis");
|
||||
let new_value = if $cmd == "subscribe" { "1" } else { "0" };
|
||||
$self
|
||||
.secondary_redis_connection
|
||||
.write_all(&redis_cmd::set(
|
||||
format!("subscribed:timeline:{}", $tl),
|
||||
new_value,
|
||||
))
|
||||
.expect("Can set Redis");
|
||||
info!("Now subscribed to: {:#?}", $self.msg_queues);
|
||||
}};
|
||||
}
|
||||
/// Send a `SUBSCRIBE` or `UNSUBSCRIBE` command to a specific timeline
|
||||
pub fn pubsub(command: impl Display, timeline: impl Display) -> Vec<u8> {
|
||||
let arg = format!("timeline:{}", timeline);
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
//! Manage all existing Redis PubSub connection
|
||||
use crate::receiver::Receiver;
|
||||
use crate::user::{Filter, User};
|
||||
use futures::stream::Stream;
|
||||
use futures::{Async, Poll};
|
||||
use serde_json::json;
|
||||
use serde_json::Value;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::io::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Struct for manageing all Redis streams
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct StreamManager {
|
||||
receiver: Arc<Mutex<Receiver>>,
|
||||
id: uuid::Uuid,
|
||||
target_timeline: String,
|
||||
current_user: Option<User>,
|
||||
}
|
||||
impl StreamManager {
|
||||
pub fn new(reciever: Receiver) -> Self {
|
||||
StreamManager {
|
||||
receiver: Arc::new(Mutex::new(reciever)),
|
||||
id: Uuid::default(),
|
||||
target_timeline: String::new(),
|
||||
current_user: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a blank StreamManager copy
|
||||
pub fn blank_copy(&self) -> Self {
|
||||
StreamManager { ..self.clone() }
|
||||
}
|
||||
/// Create a StreamManager copy with a new unique id manage subscriptions
|
||||
pub fn configure_copy(&self, timeline: &String, user: User) -> Self {
|
||||
let id = Uuid::new_v4();
|
||||
let mut receiver = self.receiver.lock().expect("No panic in other threads");
|
||||
receiver.update(id, timeline);
|
||||
receiver.maybe_subscribe(timeline);
|
||||
StreamManager {
|
||||
id,
|
||||
current_user: Some(user),
|
||||
target_timeline: timeline.clone(),
|
||||
..self.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for StreamManager {
|
||||
type Item = Value;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
let mut receiver = self
|
||||
.receiver
|
||||
.lock()
|
||||
.expect("StreamManager: No other thread panic");
|
||||
receiver.update(self.id, &self.target_timeline.clone());
|
||||
match receiver.poll() {
|
||||
Ok(Async::Ready(Some(value))) => {
|
||||
let user = self
|
||||
.clone()
|
||||
.current_user
|
||||
.expect("Previously set current user");
|
||||
|
||||
let user_langs = user.langs.clone();
|
||||
let event = value["event"].as_str().expect("Redis string");
|
||||
let payload = value["payload"].to_string();
|
||||
|
||||
match (&user.filter, user_langs) {
|
||||
(Filter::Notification, _) if event != "notification" => Ok(Async::NotReady),
|
||||
(Filter::Language, Some(ref user_langs))
|
||||
if !user_langs.contains(
|
||||
&value["payload"]["language"]
|
||||
.as_str()
|
||||
.expect("Redis str")
|
||||
.to_string(),
|
||||
) =>
|
||||
{
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
_ => Ok(Async::Ready(Some(json!(
|
||||
{"event": event,
|
||||
"payload": payload,}
|
||||
)))),
|
||||
}
|
||||
}
|
||||
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,154 @@
|
|||
//! The `StreamManager` is responsible to providing an interface between the `Warp`
|
||||
//! filters and the underlying mechanics of talking with Redis/managing multiple
|
||||
//! threads. The `StreamManager` is the only struct that any Warp code should
|
||||
//! need to communicate with.
|
||||
//!
|
||||
//! The `StreamManager`'s interface is very simple. All you can do with it is:
|
||||
//! * Create a totally new `StreamManger` with no shared data;
|
||||
//! * Assign an existing `StreamManager` to manage an new timeline/user pair; or
|
||||
//! * Poll an existing `StreamManager` to see if there are any new messages
|
||||
//! for clients
|
||||
//!
|
||||
//! When you poll the `StreamManager`, it is responsible for polling internal data
|
||||
//! structures, getting any updates from Redis, and then filtering out any updates
|
||||
//! that should be excluded by relevant filters.
|
||||
//!
|
||||
//! Because `StreamManagers` are lightweight data structures that do not directly
|
||||
//! communicate with Redis, it is appropriate to create a new `StreamManager` for
|
||||
//! each new client connection.
|
||||
use crate::{
|
||||
receiver::Receiver,
|
||||
user::{Filter, User},
|
||||
};
|
||||
use futures::{Async, Poll};
|
||||
use serde_json::{json, Value};
|
||||
use std::sync;
|
||||
use std::time;
|
||||
use tokio::io::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Struct for managing all Redis streams.
|
||||
#[derive(Clone, Default, Debug)]
|
||||
pub struct StreamManager {
|
||||
receiver: sync::Arc<sync::Mutex<Receiver>>,
|
||||
id: uuid::Uuid,
|
||||
target_timeline: String,
|
||||
current_user: User,
|
||||
}
|
||||
|
||||
impl StreamManager {
|
||||
/// Create a new `StreamManager` with no shared data.
|
||||
pub fn new() -> Self {
|
||||
StreamManager {
|
||||
receiver: sync::Arc::new(sync::Mutex::new(Receiver::new())),
|
||||
id: Uuid::default(),
|
||||
target_timeline: String::new(),
|
||||
current_user: User::public(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign the `StreamManager` to manage a new timeline/user pair.
|
||||
///
|
||||
/// Note that this *may or may not* result in a new Redis connection.
|
||||
/// If the server has already subscribed to the timeline on behalf of
|
||||
/// a different user, the `StreamManager` is responsible for figuring
|
||||
/// that out and avoiding duplicated connections. Thus, it is safe to
|
||||
/// use this method for each new client connection.
|
||||
pub fn manage_new_timeline(&self, target_timeline: &str, user: User) -> Self {
|
||||
let manager_id = Uuid::new_v4();
|
||||
let mut receiver = self.receiver.lock().expect("No thread panic (stream.rs)");
|
||||
receiver.manage_new_timeline(manager_id, target_timeline);
|
||||
StreamManager {
|
||||
id: manager_id,
|
||||
current_user: user,
|
||||
target_timeline: target_timeline.to_owned(),
|
||||
receiver: self.receiver.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The stream that the `StreamManager` manages. `Poll` is the only method implemented.
|
||||
impl futures::stream::Stream for StreamManager {
|
||||
type Item = Value;
|
||||
type Error = Error;
|
||||
|
||||
/// Checks for any new messages that should be sent to the client.
|
||||
///
|
||||
/// The `StreamManager` will poll underlying data structures and will reply
|
||||
/// with an `Ok(Ready(Some(Value)))` if there is a new message to send to
|
||||
/// the client. If there is no new message or if the new message should be
|
||||
/// filtered out based on one of the user's filters, then the `StreamManager`
|
||||
/// will reply with `Ok(NotReady)`. The `StreamManager` will buble up any
|
||||
/// errors from the underlying data structures.
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
let start_time = time::Instant::now();
|
||||
let result = {
|
||||
let mut receiver = self
|
||||
.receiver
|
||||
.lock()
|
||||
.expect("StreamManager: No other thread panic");
|
||||
receiver.configure_for_polling(self.id, &self.target_timeline.clone());
|
||||
receiver.poll()
|
||||
};
|
||||
println!("Polling took: {:?}", start_time.elapsed());
|
||||
let result = match result {
|
||||
Ok(Async::Ready(Some(value))) => {
|
||||
let user_langs = self.current_user.langs.clone();
|
||||
let toot = Toot::from_json(value);
|
||||
toot.ignore_if_caught_by_filter(&self.current_user.filter, user_langs)
|
||||
}
|
||||
Ok(inner_value) => Ok(inner_value),
|
||||
Err(e) => Err(e),
|
||||
};
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
struct Toot {
|
||||
category: String,
|
||||
payload: String,
|
||||
language: String,
|
||||
}
|
||||
impl Toot {
|
||||
fn from_json(value: Value) -> Self {
|
||||
Self {
|
||||
category: value["event"].as_str().expect("Redis string").to_owned(),
|
||||
payload: value["payload"].to_string(),
|
||||
language: value["payload"]["language"]
|
||||
.as_str()
|
||||
.expect("Redis str")
|
||||
.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_optional_json(&self) -> Option<Value> {
|
||||
Some(json!(
|
||||
{"event": self.category,
|
||||
"payload": self.payload,}
|
||||
))
|
||||
}
|
||||
|
||||
fn ignore_if_caught_by_filter(
|
||||
&self,
|
||||
filter: &Filter,
|
||||
user_langs: Option<Vec<String>>,
|
||||
) -> Result<Async<Option<Value>>, Error> {
|
||||
let toot = self;
|
||||
|
||||
let (send_msg, skip_msg) = (
|
||||
Ok(Async::Ready(toot.to_optional_json())),
|
||||
Ok(Async::NotReady),
|
||||
);
|
||||
|
||||
match &filter {
|
||||
Filter::NoFilter => send_msg,
|
||||
Filter::Notification if toot.category == "notification" => send_msg,
|
||||
// If not, skip it
|
||||
Filter::Notification => skip_msg,
|
||||
Filter::Language if user_langs.is_none() => send_msg,
|
||||
Filter::Language if user_langs.expect("").contains(&toot.language) => send_msg,
|
||||
// If not, skip it
|
||||
Filter::Language => skip_msg,
|
||||
}
|
||||
}
|
||||
}
|
451
src/timeline.rs
451
src/timeline.rs
|
@ -1,6 +1,8 @@
|
|||
//! Filters for all the endpoints accessible for Server Sent Event updates
|
||||
use crate::error;
|
||||
use crate::query;
|
||||
use crate::user::{Scope, User};
|
||||
use crate::user::{Filter::*, Scope, User};
|
||||
use crate::user_from_path;
|
||||
use warp::filters::BoxedFilter;
|
||||
use warp::{path, Filter};
|
||||
|
||||
|
@ -8,14 +10,8 @@ use warp::{path, Filter};
|
|||
type TimelineUser = ((String, User),);
|
||||
|
||||
/// GET /api/v1/streaming/user
|
||||
///
|
||||
///
|
||||
/// **private**. Filter: `Language`
|
||||
pub fn user() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "user")
|
||||
.and(path::end())
|
||||
.and(Scope::Private.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, Scope::Private))
|
||||
user_from_path!("streaming" / "user", Scope::Private)
|
||||
.map(|user: User| (user.id.to_string(), user))
|
||||
.boxed()
|
||||
}
|
||||
|
@ -23,477 +19,84 @@ pub fn user() -> BoxedFilter<TimelineUser> {
|
|||
/// GET /api/v1/streaming/user/notification
|
||||
///
|
||||
///
|
||||
/// **private**. Filter: `Notification`
|
||||
///
|
||||
///
|
||||
/// **NOTE**: This endpoint is not included in the [public API docs](https://docs.joinmastodon.org/api/streaming/#get-api-v1-streaming-public-local). But it was present in the JavaScript implementation, so has been included here. Should it be publicly documented?
|
||||
pub fn user_notifications() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "user" / "notification")
|
||||
.and(path::end())
|
||||
.and(Scope::Private.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, Scope::Private))
|
||||
.map(|user: User| (user.id.to_string(), user.with_notification_filter()))
|
||||
user_from_path!("streaming" / "user" / "notification", Scope::Private)
|
||||
.map(|user: User| (user.id.to_string(), user.set_filter(Notification)))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// GET /api/v1/streaming/public
|
||||
///
|
||||
///
|
||||
/// **public**. Filter: `Language`
|
||||
pub fn public() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "public")
|
||||
.and(path::end())
|
||||
.and(Scope::Public.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, Scope::Public))
|
||||
.map(|user: User| ("public".to_owned(), user.with_language_filter()))
|
||||
user_from_path!("streaming" / "public", Scope::Public)
|
||||
.map(|user: User| ("public".to_owned(), user.set_filter(Language)))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// GET /api/v1/streaming/public?only_media=true
|
||||
///
|
||||
///
|
||||
/// **public**. Filter: `Language`
|
||||
pub fn public_media() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "public")
|
||||
.and(path::end())
|
||||
.and(Scope::Public.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, Scope::Public))
|
||||
user_from_path!("streaming" / "public", Scope::Public)
|
||||
.and(warp::query())
|
||||
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
|
||||
"1" | "true" => ("public:media".to_owned(), user.with_language_filter()),
|
||||
_ => ("public".to_owned(), user.with_language_filter()),
|
||||
"1" | "true" => ("public:media".to_owned(), user.set_filter(Language)),
|
||||
_ => ("public".to_owned(), user.set_filter(Language)),
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// GET /api/v1/streaming/public/local
|
||||
///
|
||||
///
|
||||
/// **public**. Filter: `Language`
|
||||
pub fn public_local() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "public" / "local")
|
||||
.and(path::end())
|
||||
.and(Scope::Public.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, Scope::Public))
|
||||
.map(|user: User| ("public:local".to_owned(), user.with_language_filter()))
|
||||
user_from_path!("streaming" / "public" / "local", Scope::Public)
|
||||
.map(|user: User| ("public:local".to_owned(), user.set_filter(Language)))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// GET /api/v1/streaming/public/local?only_media=true
|
||||
///
|
||||
///
|
||||
/// **public**. Filter: `Language`
|
||||
pub fn public_local_media() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "public" / "local")
|
||||
.and(Scope::Public.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, Scope::Public))
|
||||
user_from_path!("streaming" / "public" / "local", Scope::Public)
|
||||
.and(warp::query())
|
||||
.and(path::end())
|
||||
.map(|user: User, q: query::Media| match q.only_media.as_ref() {
|
||||
"1" | "true" => ("public:local:media".to_owned(), user.with_language_filter()),
|
||||
_ => ("public:local".to_owned(), user.with_language_filter()),
|
||||
"1" | "true" => ("public:local:media".to_owned(), user.set_filter(Language)),
|
||||
_ => ("public:local".to_owned(), user.set_filter(Language)),
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// GET /api/v1/streaming/direct
|
||||
///
|
||||
///
|
||||
/// **private**. Filter: `None`
|
||||
pub fn direct() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "direct")
|
||||
.and(path::end())
|
||||
.and(Scope::Private.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, Scope::Private))
|
||||
.map(|user: User| (format!("direct:{}", user.id), user.with_no_filter()))
|
||||
user_from_path!("streaming" / "direct", Scope::Private)
|
||||
.map(|user: User| (format!("direct:{}", user.id), user.set_filter(NoFilter)))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// GET /api/v1/streaming/hashtag?tag=:hashtag
|
||||
///
|
||||
///
|
||||
/// **public**. Filter: `None`
|
||||
pub fn hashtag() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "hashtag")
|
||||
.and(warp::query())
|
||||
.and(path::end())
|
||||
.map(|q: query::Hashtag| (format!("hashtag:{}", q.tag), User::public()))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// GET /api/v1/streaming/hashtag/local?tag=:hashtag
|
||||
///
|
||||
///
|
||||
/// **public**. Filter: `None`
|
||||
pub fn hashtag_local() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "hashtag" / "local")
|
||||
.and(warp::query())
|
||||
.and(path::end())
|
||||
.map(|q: query::Hashtag| (format!("hashtag:{}:local", q.tag), User::public()))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// GET /api/v1/streaming/list?list=:list_id
|
||||
///
|
||||
///
|
||||
/// **private**. Filter: `None`
|
||||
pub fn list() -> BoxedFilter<TimelineUser> {
|
||||
path!("api" / "v1" / "streaming" / "list")
|
||||
.and(Scope::Private.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, Scope::Private))
|
||||
user_from_path!("streaming" / "list", Scope::Private)
|
||||
.and(warp::query())
|
||||
.and_then(|user: User, q: query::List| (user.authorized_for_list(q.list), Ok(user)))
|
||||
.and_then(|user: User, q: query::List| {
|
||||
if user.owns_list(q.list) {
|
||||
(Ok(q.list), Ok(user))
|
||||
} else {
|
||||
(Err(error::unauthorized_list()), Ok(user))
|
||||
}
|
||||
})
|
||||
.untuple_one()
|
||||
.and(path::end())
|
||||
.map(|list: i64, user: User| (format!("list:{}", list), user.with_no_filter()))
|
||||
.map(|list: i64, user: User| (format!("list:{}", list), user.set_filter(NoFilter)))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Combines multiple routes with the same return type together with
|
||||
/// `or()` and `unify()`
|
||||
#[macro_export]
|
||||
macro_rules! any_of {
|
||||
($filter:expr, $($other_filter:expr),*) => {
|
||||
$filter$(.or($other_filter).unify())*
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::user;
|
||||
|
||||
#[test]
|
||||
fn user_unauthorized() {
|
||||
let value = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1",
|
||||
))
|
||||
.filter(&user());
|
||||
assert!(invalid_access_token(value));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path(&format!("/api/v1/streaming/user",))
|
||||
.filter(&user());
|
||||
assert!(no_access_token(value));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn user_auth() {
|
||||
let user_id: i64 = 1;
|
||||
let access_token = get_access_token(user_id);
|
||||
|
||||
// Query auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/user?access_token={}",
|
||||
access_token
|
||||
))
|
||||
.filter(&user())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user =
|
||||
User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
|
||||
// Header auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path("/api/v1/streaming/user")
|
||||
.header("Authorization", format!("Bearer: {}", access_token.clone()))
|
||||
.filter(&user())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user =
|
||||
User::from_access_token(access_token, user::Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_notifications_unauthorized() {
|
||||
let value = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN",
|
||||
))
|
||||
.filter(&user_notifications());
|
||||
assert!(invalid_access_token(value));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path(&format!("/api/v1/streaming/user/notification",))
|
||||
.filter(&user_notifications());
|
||||
assert!(no_access_token(value));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn user_notifications_auth() {
|
||||
let user_id: i64 = 1;
|
||||
let access_token = get_access_token(user_id);
|
||||
|
||||
// Query auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/user/notification?access_token={}",
|
||||
access_token
|
||||
))
|
||||
.filter(&user_notifications())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user = User::from_access_token(access_token.clone(), user::Scope::Private)
|
||||
.expect("in test")
|
||||
.with_notification_filter();
|
||||
|
||||
assert_eq!(actual_timeline, "1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
|
||||
// Header auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path("/api/v1/streaming/user/notification")
|
||||
.header("Authorization", format!("Bearer: {}", access_token.clone()))
|
||||
.filter(&user_notifications())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user = User::from_access_token(access_token, user::Scope::Private)
|
||||
.expect("in test")
|
||||
.with_notification_filter();
|
||||
|
||||
assert_eq!(actual_timeline, "1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
}
|
||||
#[test]
|
||||
fn public_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public")
|
||||
.filter(&public())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public".to_string());
|
||||
assert_eq!(value.1, User::public().with_language_filter());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_media_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public?only_media=true")
|
||||
.filter(&public_media())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:media".to_string());
|
||||
assert_eq!(value.1, User::public().with_language_filter());
|
||||
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public?only_media=1")
|
||||
.filter(&public_media())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:media".to_string());
|
||||
assert_eq!(value.1, User::public().with_language_filter());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_local_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public/local")
|
||||
.filter(&public_local())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:local".to_string());
|
||||
assert_eq!(value.1, User::public().with_language_filter());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_local_media_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public/local?only_media=true")
|
||||
.filter(&public_local_media())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:local:media".to_string());
|
||||
assert_eq!(value.1, User::public().with_language_filter());
|
||||
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public/local?only_media=1")
|
||||
.filter(&public_local_media())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:local:media".to_string());
|
||||
assert_eq!(value.1, User::public().with_language_filter());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn direct_timeline_unauthorized() {
|
||||
let value = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN",
|
||||
))
|
||||
.filter(&direct());
|
||||
assert!(invalid_access_token(value));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path(&format!("/api/v1/streaming/direct",))
|
||||
.filter(&direct());
|
||||
assert!(no_access_token(value));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn direct_timeline_auth() {
|
||||
let user_id: i64 = 1;
|
||||
let access_token = get_access_token(user_id);
|
||||
|
||||
// Query auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/direct?access_token={}",
|
||||
access_token
|
||||
))
|
||||
.filter(&direct())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user =
|
||||
User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "direct:1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
|
||||
// Header auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path("/api/v1/streaming/direct")
|
||||
.header("Authorization", format!("Bearer: {}", access_token.clone()))
|
||||
.filter(&direct())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user =
|
||||
User::from_access_token(access_token, user::Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "direct:1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hashtag_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/hashtag?tag=a")
|
||||
.filter(&hashtag())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "hashtag:a".to_string());
|
||||
assert_eq!(value.1, User::public());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hashtag_timeline_local() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/hashtag/local?tag=a")
|
||||
.filter(&hashtag_local())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "hashtag:a:local".to_string());
|
||||
assert_eq!(value.1, User::public());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn list_timeline_auth() {
|
||||
let list_id = 1;
|
||||
let list_owner_id = get_list_owner(list_id);
|
||||
let access_token = get_access_token(list_owner_id);
|
||||
|
||||
// Query Auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/list?access_token={}&list={}",
|
||||
access_token, list_id,
|
||||
))
|
||||
.filter(&list())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user =
|
||||
User::from_access_token(access_token.clone(), user::Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "list:1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
|
||||
// Header Auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path("/api/v1/streaming/list?list=1")
|
||||
.header("Authorization", format!("Bearer: {}", access_token.clone()))
|
||||
.filter(&list())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user =
|
||||
User::from_access_token(access_token, user::Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "list:1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_timeline_unauthorized() {
|
||||
let value = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1",
|
||||
))
|
||||
.filter(&list());
|
||||
assert!(invalid_access_token(value));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path(&format!("/api/v1/streaming/list?list=1",))
|
||||
.filter(&list());
|
||||
assert!(no_access_token(value));
|
||||
}
|
||||
|
||||
fn get_list_owner(list_number: i32) -> i64 {
|
||||
let list_number: i64 = list_number.into();
|
||||
let conn = user::connect_to_postgres();
|
||||
let rows = &conn
|
||||
.query(
|
||||
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
|
||||
&[&list_number],
|
||||
)
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(
|
||||
rows.len(),
|
||||
1,
|
||||
"Test database must contain at least one user with a list to run this test."
|
||||
);
|
||||
|
||||
rows.get(0).get(1)
|
||||
}
|
||||
fn get_access_token(user_id: i64) -> String {
|
||||
let conn = user::connect_to_postgres();
|
||||
let rows = &conn
|
||||
.query(
|
||||
"SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1",
|
||||
&[&user_id],
|
||||
)
|
||||
.expect("Can get access token from id");
|
||||
rows.get(0).get(0)
|
||||
}
|
||||
fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
|
||||
match value {
|
||||
Err(error) => match error.cause() {
|
||||
Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => {
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
|
||||
match value {
|
||||
Err(error) => match error.cause() {
|
||||
Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true,
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
178
src/user.rs
178
src/user.rs
|
@ -1,24 +1,21 @@
|
|||
//! Create a User by querying the Postgres database with the user's access_token
|
||||
use crate::{any_of, query};
|
||||
//! `User` struct and related functionality
|
||||
use crate::{postgres, query};
|
||||
use log::info;
|
||||
use postgres;
|
||||
use std::env;
|
||||
use warp::Filter as WarpFilter;
|
||||
|
||||
/// (currently hardcoded to localhost)
|
||||
pub fn connect_to_postgres() -> postgres::Connection {
|
||||
let postgres_addr = env::var("POSTGRESS_ADDR").unwrap_or(format!(
|
||||
"postgres://{}@localhost/mastodon_development",
|
||||
env::var("USER").expect("User env var should exist")
|
||||
));
|
||||
postgres::Connection::connect(postgres_addr, postgres::TlsMode::None)
|
||||
.expect("Can connect to local Postgres")
|
||||
/// Combine multiple routes with the same return type together with
|
||||
/// `or()` and `unify()`
|
||||
#[macro_export]
|
||||
macro_rules! any_of {
|
||||
($filter:expr, $($other_filter:expr),*) => {
|
||||
$filter$(.or($other_filter).unify())*
|
||||
};
|
||||
}
|
||||
|
||||
/// The filters that can be applied to toots after they come from Redis
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum Filter {
|
||||
None,
|
||||
NoFilter,
|
||||
Language,
|
||||
Notification,
|
||||
}
|
||||
|
@ -28,140 +25,99 @@ pub enum Filter {
|
|||
pub struct User {
|
||||
pub id: i64,
|
||||
pub access_token: String,
|
||||
pub scopes: Vec<OauthScope>,
|
||||
pub scopes: OauthScope,
|
||||
pub langs: Option<Vec<String>>,
|
||||
pub logged_in: bool,
|
||||
pub filter: Filter,
|
||||
}
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub enum OauthScope {
|
||||
Read,
|
||||
ReadStatuses,
|
||||
ReadNotifications,
|
||||
ReadList,
|
||||
Other,
|
||||
}
|
||||
impl From<&str> for OauthScope {
|
||||
fn from(scope: &str) -> Self {
|
||||
use OauthScope::*;
|
||||
match scope {
|
||||
"read" => Read,
|
||||
"read:statuses" => ReadStatuses,
|
||||
"read:notifications" => ReadNotifications,
|
||||
"read:lists" => ReadList,
|
||||
_ => Other,
|
||||
}
|
||||
impl Default for User {
|
||||
fn default() -> Self {
|
||||
User::public()
|
||||
}
|
||||
}
|
||||
#[derive(Clone, Debug, Default, PartialEq)]
|
||||
pub struct OauthScope {
|
||||
pub all: bool,
|
||||
pub statuses: bool,
|
||||
pub notify: bool,
|
||||
pub lists: bool,
|
||||
}
|
||||
impl From<Vec<String>> for OauthScope {
|
||||
fn from(scope_list: Vec<String>) -> Self {
|
||||
let mut oauth_scope = OauthScope::default();
|
||||
for scope in scope_list {
|
||||
match scope.as_str() {
|
||||
"read" => oauth_scope.all = true,
|
||||
"read:statuses" => oauth_scope.statuses = true,
|
||||
"read:notifications" => oauth_scope.notify = true,
|
||||
"read:lists" => oauth_scope.lists = true,
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
oauth_scope
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a user based on the supplied path and access scope for the resource
|
||||
#[macro_export]
|
||||
macro_rules! user_from_path {
|
||||
($($path_item:tt) / *, $scope:expr) => (path!("api" / "v1" / $($path_item) / +)
|
||||
.and($scope.get_access_token())
|
||||
.and_then(|token| User::from_access_token(token, $scope)))
|
||||
}
|
||||
|
||||
impl User {
|
||||
/// Create a user from the access token supplied in the header or query paramaters
|
||||
pub fn from_access_token(
|
||||
access_token: String,
|
||||
scope: Scope,
|
||||
) -> Result<Self, warp::reject::Rejection> {
|
||||
let conn = connect_to_postgres();
|
||||
let result = &conn
|
||||
.query(
|
||||
"
|
||||
SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes
|
||||
FROM
|
||||
oauth_access_tokens
|
||||
INNER JOIN users ON
|
||||
oauth_access_tokens.resource_owner_id = users.id
|
||||
WHERE oauth_access_tokens.token = $1
|
||||
AND oauth_access_tokens.revoked_at IS NULL
|
||||
LIMIT 1",
|
||||
&[&access_token],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
if !result.is_empty() {
|
||||
let only_row = result.get(0);
|
||||
let id: i64 = only_row.get(1);
|
||||
let scopes = only_row
|
||||
.get::<_, String>(3)
|
||||
.split(' ')
|
||||
.map(|scope: &str| scope.into())
|
||||
.filter(|scope| scope != &OauthScope::Other)
|
||||
.collect();
|
||||
dbg!(&scopes);
|
||||
let langs: Option<Vec<String>> = only_row.get(2);
|
||||
info!("Granting logged-in access");
|
||||
let (id, langs, scope_list) = postgres::query_for_user_data(&access_token);
|
||||
let scopes = OauthScope::from(scope_list);
|
||||
if id != -1 || scope == Scope::Public {
|
||||
let (logged_in, log_msg) = match id {
|
||||
-1 => (false, "Public access to non-authenticated endpoints"),
|
||||
_ => (true, "Granting logged-in access"),
|
||||
};
|
||||
info!("{}", log_msg);
|
||||
Ok(User {
|
||||
id,
|
||||
access_token,
|
||||
scopes,
|
||||
langs,
|
||||
logged_in: true,
|
||||
filter: Filter::None,
|
||||
})
|
||||
} else if let Scope::Public = scope {
|
||||
info!("Granting public access to non-authenticated client");
|
||||
Ok(User {
|
||||
id: -1,
|
||||
access_token,
|
||||
scopes: Vec::new(),
|
||||
langs: None,
|
||||
logged_in: false,
|
||||
filter: Filter::None,
|
||||
logged_in,
|
||||
filter: Filter::NoFilter,
|
||||
})
|
||||
} else {
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
}
|
||||
}
|
||||
/// Add a Notification filter
|
||||
pub fn with_notification_filter(self) -> Self {
|
||||
Self {
|
||||
filter: Filter::Notification,
|
||||
..self
|
||||
}
|
||||
}
|
||||
/// Add a Language filter
|
||||
pub fn with_language_filter(self) -> Self {
|
||||
Self {
|
||||
filter: Filter::Language,
|
||||
..self
|
||||
}
|
||||
}
|
||||
/// Remove all filters
|
||||
pub fn with_no_filter(self) -> Self {
|
||||
Self {
|
||||
filter: Filter::None,
|
||||
..self
|
||||
}
|
||||
/// Set the Notification/Language filter
|
||||
pub fn set_filter(self, filter: Filter) -> Self {
|
||||
Self { filter, ..self }
|
||||
}
|
||||
/// Determine whether the User is authorised for a specified list
|
||||
pub fn authorized_for_list(&self, list: i64) -> Result<i64, warp::reject::Rejection> {
|
||||
let conn = connect_to_postgres();
|
||||
// For the Postgres query, `id` = list number; `account_id` = user.id
|
||||
let rows = &conn
|
||||
.query(
|
||||
" SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
|
||||
&[&list],
|
||||
)
|
||||
.expect("Hard-coded query will return Some([0 or more rows])");
|
||||
if !rows.is_empty() {
|
||||
let id_of_account_that_owns_the_list: i64 = rows.get(0).get(1);
|
||||
if id_of_account_that_owns_the_list == self.id {
|
||||
return Ok(list);
|
||||
}
|
||||
};
|
||||
|
||||
Err(warp::reject::custom("Error: Invalid access token"))
|
||||
pub fn owns_list(&self, list: i64) -> bool {
|
||||
match postgres::query_list_owner(list) {
|
||||
Some(i) if i == self.id => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
/// A public (non-authenticated) User
|
||||
pub fn public() -> Self {
|
||||
User {
|
||||
id: -1,
|
||||
access_token: String::new(),
|
||||
scopes: Vec::new(),
|
||||
access_token: String::from("no access token"),
|
||||
scopes: OauthScope::default(),
|
||||
langs: None,
|
||||
logged_in: false,
|
||||
filter: Filter::None,
|
||||
filter: Filter::NoFilter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the endpoint requires authentication or not
|
||||
#[derive(PartialEq)]
|
||||
pub enum Scope {
|
||||
Public,
|
||||
Private,
|
||||
|
|
74
src/ws.rs
74
src/ws.rs
|
@ -1,44 +1,86 @@
|
|||
//! WebSocket-specific functionality
|
||||
use crate::stream::StreamManager;
|
||||
use crate::query;
|
||||
use crate::stream_manager::StreamManager;
|
||||
use crate::user::{Scope, User};
|
||||
use crate::user_from_path;
|
||||
use futures::future::Future;
|
||||
use futures::stream::Stream;
|
||||
use futures::Async;
|
||||
use std::time;
|
||||
use warp::filters::BoxedFilter;
|
||||
use warp::{path, Filter};
|
||||
|
||||
/// Send a stream of replies to a WebSocket client
|
||||
pub fn send_replies(
|
||||
socket: warp::ws::WebSocket,
|
||||
mut stream: StreamManager,
|
||||
) -> impl futures::future::Future<Item = (), Error = ()> {
|
||||
let (tx, rx) = futures::sync::mpsc::unbounded();
|
||||
let (ws_tx, mut ws_rx) = socket.split();
|
||||
|
||||
// Create a pipe
|
||||
let (tx, rx) = futures::sync::mpsc::unbounded();
|
||||
|
||||
// Send one end of it to a different thread and tell that end to forward whatever it gets
|
||||
// on to the websocket client
|
||||
warp::spawn(
|
||||
rx.map_err(|()| -> warp::Error { unreachable!() })
|
||||
.forward(ws_tx)
|
||||
.map_err(|_| ())
|
||||
.map(|_r| ()),
|
||||
);
|
||||
let event_stream = tokio::timer::Interval::new(
|
||||
std::time::Instant::now(),
|
||||
std::time::Duration::from_millis(100),
|
||||
)
|
||||
.take_while(move |_| {
|
||||
if ws_rx.poll().is_err() {
|
||||
futures::future::ok(false)
|
||||
} else {
|
||||
futures::future::ok(true)
|
||||
}
|
||||
});
|
||||
|
||||
// For as long as the client is still connected, yeild a new event every 100 ms
|
||||
let event_stream =
|
||||
tokio::timer::Interval::new(time::Instant::now(), time::Duration::from_millis(100))
|
||||
.take_while(move |_| match ws_rx.poll() {
|
||||
Ok(Async::Ready(None)) => futures::future::ok(false),
|
||||
_ => futures::future::ok(true),
|
||||
});
|
||||
|
||||
// Every time you get an event from that stream, send it through the pipe
|
||||
event_stream
|
||||
.for_each(move |_json_value| {
|
||||
if let Ok(Async::Ready(Some(json_value))) = stream.poll() {
|
||||
let msg = warp::ws::Message::text(json_value.to_string());
|
||||
if !tx.is_closed() {
|
||||
tx.unbounded_send(msg).expect("No send error");
|
||||
}
|
||||
tx.unbounded_send(msg).expect("No send error");
|
||||
};
|
||||
Ok(())
|
||||
})
|
||||
.then(|msg| msg)
|
||||
.map_err(|e| println!("{}", e))
|
||||
}
|
||||
|
||||
pub fn websocket_routes() -> BoxedFilter<(User, Query, warp::ws::Ws2)> {
|
||||
user_from_path!("streaming", Scope::Public)
|
||||
.and(warp::query())
|
||||
.and(query::Media::to_filter())
|
||||
.and(query::Hashtag::to_filter())
|
||||
.and(query::List::to_filter())
|
||||
.and(warp::ws2())
|
||||
.map(
|
||||
|user: User,
|
||||
stream: query::Stream,
|
||||
media: query::Media,
|
||||
hashtag: query::Hashtag,
|
||||
list: query::List,
|
||||
ws: warp::ws::Ws2| {
|
||||
let query = Query {
|
||||
stream: stream.stream,
|
||||
media: media.is_truthy(),
|
||||
hashtag: hashtag.tag,
|
||||
list: list.list,
|
||||
};
|
||||
(user, query, ws)
|
||||
},
|
||||
)
|
||||
.untuple_one()
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Query {
|
||||
pub stream: String,
|
||||
pub media: bool,
|
||||
pub hashtag: String,
|
||||
pub list: i64,
|
||||
}
|
||||
|
|
|
@ -0,0 +1,341 @@
|
|||
use ragequit::{
|
||||
config,
|
||||
timeline::*,
|
||||
user::{Filter::*, Scope, User},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn user_unauthorized() {
|
||||
let value = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/user?access_token=BAD_ACCESS_TOKEN&list=1",
|
||||
))
|
||||
.filter(&user());
|
||||
assert!(invalid_access_token(value));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path(&format!("/api/v1/streaming/user",))
|
||||
.filter(&user());
|
||||
assert!(no_access_token(value));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn user_auth() {
|
||||
let user_id: i64 = 1;
|
||||
let access_token = get_access_token(user_id);
|
||||
|
||||
// Query auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/user?access_token={}",
|
||||
access_token
|
||||
))
|
||||
.filter(&user())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user =
|
||||
User::from_access_token(access_token.clone(), Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
|
||||
// Header auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path("/api/v1/streaming/user")
|
||||
.header("Authorization", format!("Bearer: {}", access_token.clone()))
|
||||
.filter(&user())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn user_notifications_unauthorized() {
|
||||
let value = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/user/notification?access_token=BAD_ACCESS_TOKEN",
|
||||
))
|
||||
.filter(&user_notifications());
|
||||
assert!(invalid_access_token(value));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path(&format!("/api/v1/streaming/user/notification",))
|
||||
.filter(&user_notifications());
|
||||
assert!(no_access_token(value));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn user_notifications_auth() {
|
||||
let user_id: i64 = 1;
|
||||
let access_token = get_access_token(user_id);
|
||||
|
||||
// Query auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/user/notification?access_token={}",
|
||||
access_token
|
||||
))
|
||||
.filter(&user_notifications())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user = User::from_access_token(access_token.clone(), Scope::Private)
|
||||
.expect("in test")
|
||||
.set_filter(Notification);
|
||||
|
||||
assert_eq!(actual_timeline, "1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
|
||||
// Header auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path("/api/v1/streaming/user/notification")
|
||||
.header("Authorization", format!("Bearer: {}", access_token.clone()))
|
||||
.filter(&user_notifications())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user = User::from_access_token(access_token, Scope::Private)
|
||||
.expect("in test")
|
||||
.set_filter(Notification);
|
||||
|
||||
assert_eq!(actual_timeline, "1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
}
|
||||
#[test]
|
||||
fn public_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public")
|
||||
.filter(&public())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public".to_string());
|
||||
assert_eq!(value.1, User::public().set_filter(Language));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_media_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public?only_media=true")
|
||||
.filter(&public_media())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:media".to_string());
|
||||
assert_eq!(value.1, User::public().set_filter(Language));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public?only_media=1")
|
||||
.filter(&public_media())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:media".to_string());
|
||||
assert_eq!(value.1, User::public().set_filter(Language));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_local_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public/local")
|
||||
.filter(&public_local())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:local".to_string());
|
||||
assert_eq!(value.1, User::public().set_filter(Language));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn public_local_media_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public/local?only_media=true")
|
||||
.filter(&public_local_media())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:local:media".to_string());
|
||||
assert_eq!(value.1, User::public().set_filter(Language));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/public/local?only_media=1")
|
||||
.filter(&public_local_media())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "public:local:media".to_string());
|
||||
assert_eq!(value.1, User::public().set_filter(Language));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn direct_timeline_unauthorized() {
|
||||
let value = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/direct?access_token=BAD_ACCESS_TOKEN",
|
||||
))
|
||||
.filter(&direct());
|
||||
assert!(invalid_access_token(value));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path(&format!("/api/v1/streaming/direct",))
|
||||
.filter(&direct());
|
||||
assert!(no_access_token(value));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn direct_timeline_auth() {
|
||||
let user_id: i64 = 1;
|
||||
let access_token = get_access_token(user_id);
|
||||
|
||||
// Query auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/direct?access_token={}",
|
||||
access_token
|
||||
))
|
||||
.filter(&direct())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user =
|
||||
User::from_access_token(access_token.clone(), Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "direct:1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
|
||||
// Header auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path("/api/v1/streaming/direct")
|
||||
.header("Authorization", format!("Bearer: {}", access_token.clone()))
|
||||
.filter(&direct())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "direct:1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hashtag_timeline() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/hashtag?tag=a")
|
||||
.filter(&hashtag())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "hashtag:a".to_string());
|
||||
assert_eq!(value.1, User::public());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hashtag_timeline_local() {
|
||||
let value = warp::test::request()
|
||||
.path("/api/v1/streaming/hashtag/local?tag=a")
|
||||
.filter(&hashtag_local())
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(value.0, "hashtag:a:local".to_string());
|
||||
assert_eq!(value.1, User::public());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn list_timeline_auth() {
|
||||
let list_id = 1;
|
||||
let list_owner_id = get_list_owner(list_id);
|
||||
let access_token = get_access_token(list_owner_id);
|
||||
|
||||
// Query Auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/list?access_token={}&list={}",
|
||||
access_token, list_id,
|
||||
))
|
||||
.filter(&list())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user =
|
||||
User::from_access_token(access_token.clone(), Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "list:1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
|
||||
// Header Auth
|
||||
let (actual_timeline, actual_user) = warp::test::request()
|
||||
.path("/api/v1/streaming/list?list=1")
|
||||
.header("Authorization", format!("Bearer: {}", access_token.clone()))
|
||||
.filter(&list())
|
||||
.expect("in test");
|
||||
|
||||
let expected_user = User::from_access_token(access_token, Scope::Private).expect("in test");
|
||||
|
||||
assert_eq!(actual_timeline, "list:1");
|
||||
assert_eq!(actual_user, expected_user);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_timeline_unauthorized() {
|
||||
let value = warp::test::request()
|
||||
.path(&format!(
|
||||
"/api/v1/streaming/list?access_token=BAD_ACCESS_TOKEN&list=1",
|
||||
))
|
||||
.filter(&list());
|
||||
assert!(invalid_access_token(value));
|
||||
|
||||
let value = warp::test::request()
|
||||
.path(&format!("/api/v1/streaming/list?list=1",))
|
||||
.filter(&list());
|
||||
assert!(no_access_token(value));
|
||||
}
|
||||
|
||||
// Helper functions for tests
|
||||
fn get_list_owner(list_number: i32) -> i64 {
|
||||
let list_number: i64 = list_number.into();
|
||||
let conn = config::postgres();
|
||||
let rows = &conn
|
||||
.query(
|
||||
"SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1",
|
||||
&[&list_number],
|
||||
)
|
||||
.expect("in test");
|
||||
|
||||
assert_eq!(
|
||||
rows.len(),
|
||||
1,
|
||||
"Test database must contain at least one user with a list to run this test."
|
||||
);
|
||||
|
||||
rows.get(0).get(1)
|
||||
}
|
||||
|
||||
fn get_access_token(user_id: i64) -> String {
|
||||
let conn = config::postgres();
|
||||
let rows = &conn
|
||||
.query(
|
||||
"SELECT token FROM oauth_access_tokens WHERE resource_owner_id = $1",
|
||||
&[&user_id],
|
||||
)
|
||||
.expect("Can get access token from id");
|
||||
rows.get(0).get(0)
|
||||
}
|
||||
|
||||
fn invalid_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
|
||||
match value {
|
||||
Err(error) => match error.cause() {
|
||||
Some(c) if format!("{:?}", c) == "StringError(\"Error: Invalid access token\")" => true,
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn no_access_token(value: Result<(String, User), warp::reject::Rejection>) -> bool {
|
||||
match value {
|
||||
Err(error) => match error.cause() {
|
||||
// The cause could validly be any of these, depending on the order they're checked
|
||||
// (It would pass with just one, so the last one it doesn't have is "the" cause)
|
||||
Some(c) if format!("{:?}", c) == "MissingHeader(\"authorization\")" => true,
|
||||
Some(c) if format!("{:?}", c) == "InvalidQuery" => true,
|
||||
Some(c) if format!("{:?}", c) == "MissingHeader(\"Sec-WebSocket-Protocol\")" => true,
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue