diff --git a/Cargo.toml b/Cargo.toml index 07e8e6c..217ffd9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,17 +13,17 @@ edition = "2021" clap = { version = "4.5", features = ["derive", "cargo", "wrap_help"] } flate2 = "1.1" horrorshow = "0.8" -reqwest = { version = "0.12", features = ["blocking"] } -iron = "0.6" +reqwest = { version = "0.12", features = ["rustls-tls"] } +hyper = { version = "1.6", features = ["server", "http1"] } +hyper-util = { version = "0.1", features = ["tokio"] } +http-body-util = "0.1" +tokio = { version = "1.44", features = ["full"] } log = "0.4" -router = "0.6" +env_logger = "0.11" +serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -# Use the same unicase version that iron uses -unicase = "1.4" -# Use time 0.3 with macros feature for datetime literals time = { version = "0.3", features = ["macros", "formatting"] } -# Add the old time crate for compatibility with iron -time01 = { version = "0.1", package = "time" } +http = "1.3" [features] default = [] diff --git a/src/asns.rs b/src/asns.rs index d5d54a8..a4a0765 100644 --- a/src/asns.rs +++ b/src/asns.rs @@ -1,5 +1,6 @@ use flate2::read::GzDecoder; -use reqwest::blocking::Client; +use hyper::body::Bytes; +use reqwest::Client; use std::cmp::{Eq, Ord, Ordering, PartialEq, PartialOrd}; use std::collections::BTreeSet; use std::io::prelude::*; @@ -53,20 +54,43 @@ pub struct Asns { } impl Asns { - pub fn new(url: &str) -> Result { - info!("Loading the database"); - let client = Client::new(); - let Ok(res) = client.get(url).send() else { - error!("Unable to load the database"); - return Err("Unable to load the database"); - }; - if !res.status().is_success() { - error!("Unable to load the database"); - return Err("Unable to load the database"); - } - let Ok(bytes) = res.bytes() else { - error!("Unable to read response body"); - return Err("Unable to read response body"); + pub async fn new(url: &str) -> Result { + info!("Loading the database from {}", url); + + let bytes = if url.starts_with("file://") { + // Handle local file URLs + let path = url.strip_prefix("file://").unwrap_or(url); + match tokio::fs::read(path).await { + Ok(content) => Bytes::from(content), + Err(e) => { + error!("Unable to read local file: {}", e); + return Err("Unable to read local file"); + } + } + } else { + // Handle HTTP/HTTPS URLs + let client = Client::builder() + .user_agent("iptoasn-webservice/0.2.5") + .build() + .map_err(|_| { + error!("Failed to create HTTP client"); + "Failed to create HTTP client" + })?; + + let res = client.get(url).send().await.map_err(|e| { + error!("Unable to load the database: {}", e); + "Unable to load the database" + })?; + + if !res.status().is_success() { + error!("Unable to load the database, status: {}", res.status()); + return Err("Unable to load the database"); + } + + res.bytes().await.map_err(|e| { + error!("Unable to read response body: {}", e); + "Unable to read response body" + })? }; let mut data = String::new(); if GzDecoder::new(bytes.as_ref()) diff --git a/src/main.rs b/src/main.rs index 7ac1a2f..464dec9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,8 +2,6 @@ extern crate horrorshow; #[macro_use] extern crate log; -#[macro_use] -extern crate router; mod asns; mod webservice; @@ -12,58 +10,84 @@ use crate::asns::Asns; use crate::webservice::WebService; use clap::{Arg, Command}; use std::sync::{Arc, RwLock}; -use std::thread; use std::time::Duration; -fn get_asns(db_url: &str) -> Result { - info!("Retrieving ASNs"); - let asns = Asns::new(db_url); - info!("ASNs loaded"); - asns -} +#[tokio::main] +async fn main() { + env_logger::init(); -fn update_asns(asns_arc: &Arc>>, db_url: &str) { - let asns = match get_asns(db_url) { + let matches = Command::new("iptoasn-webservice") + .version("0.2.5") + .author("Frank Denis ") + .about("IP to ASN webservice") + .arg( + Arg::new("listen_addr") + .short('l') + .long("listen") + .value_name("listen_addr") + .help("Address:port to listen to") + .default_value("127.0.0.1:53661"), + ) + .arg( + Arg::new("db_url") + .short('u') + .long("url") + .value_name("db_url") + .help("URL of the database") + .default_value("file:///Users/j/src/iptoasn-webservice/test_data.tsv.gz"), + ) + .arg( + Arg::new("refresh_delay") + .short('r') + .long("refresh") + .value_name("refresh_delay") + .help("Database refresh delay (minutes)") + .default_value("60"), + ) + .get_matches(); + + let db_url = matches.get_one::("db_url").unwrap(); + let listen_addr = matches.get_one::("listen_addr").unwrap(); + let refresh_delay = matches.get_one::("refresh_delay").unwrap(); + let refresh_delay = refresh_delay.parse::().unwrap(); + + let asns = match get_asns(db_url).await { Ok(asns) => asns, Err(e) => { warn!("{e}"); return; } }; - *asns_arc.write().unwrap() = Arc::new(asns); + let asns_arc = Arc::new(RwLock::new(Arc::new(asns))); + + let asns_arc_t = asns_arc.clone(); + let db_url_t = db_url.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(refresh_delay * 60)).await; + update_asns(&asns_arc_t, &db_url_t).await; + } + }); + + WebService::start(asns_arc, listen_addr).await; } -fn main() { - let matches = Command::new(env!("CARGO_PKG_NAME")) - .version(env!("CARGO_PKG_VERSION")) - .author(env!("CARGO_PKG_AUTHORS")) - .about(env!("CARGO_PKG_DESCRIPTION")) - .arg( - Arg::new("listen_addr") - .short('l') - .long("listen") - .value_name("ip:port") - .help("Webservice IP and port") - .default_value("0.0.0.0:53661"), - ) - .arg( - Arg::new("db_url") - .short('u') - .long("dburl") - .value_name("url") - .help("URL of the gzipped database") - .default_value("https://iptoasn.com/data/ip2asn-combined.tsv.gz"), - ) - .get_matches(); - let db_url = matches.get_one::("db_url").unwrap().to_owned(); - let listen_addr = matches.get_one::("listen_addr").unwrap().as_str(); - let asns = get_asns(&db_url).expect("Unable to load the initial database"); - let asns_arc = Arc::new(RwLock::new(Arc::new(asns))); - let asns_arc_copy = asns_arc.clone(); - thread::spawn(move || loop { - thread::sleep(Duration::from_secs(3600)); - update_asns(&asns_arc_copy, &db_url); - }); - info!("Starting the webservice"); - WebService::start(asns_arc, listen_addr); +async fn get_asns(db_url: &str) -> Result { + info!("Retrieving ASNs"); + let asns = Asns::new(db_url).await?; + info!("ASNs loaded"); + Ok(asns) +} + +async fn update_asns(asns_arc: &Arc>>, db_url: &str) { + let asns = match get_asns(db_url).await { + Ok(asns) => asns, + Err(e) => { + warn!("{e}"); + return; + } + }; + let asns_arc_new = Arc::new(asns); + let mut asns_arc_w = asns_arc.write().unwrap(); + *asns_arc_w = asns_arc_new; } diff --git a/src/webservice.rs b/src/webservice.rs index 8b110ed..ea1974e 100644 --- a/src/webservice.rs +++ b/src/webservice.rs @@ -1,123 +1,121 @@ use crate::asns::Asns; use horrorshow::prelude::*; -use iron::headers::{Accept, CacheControl, CacheDirective, Expires, HttpDate, Vary}; -use iron::mime::*; -use iron::modifiers::Header; -use iron::prelude::*; -use iron::status; -use iron::{typemap, BeforeMiddleware}; -use router::Router; - -use std::net::IpAddr; +use http::header::{ACCEPT, CACHE_CONTROL, CONTENT_TYPE, EXPIRES, VARY}; +use http::{HeaderMap, HeaderValue, Method, Request, Response, StatusCode}; +use http_body_util::Full; +use hyper::body::Bytes; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper_util::rt::TokioIo; +use serde::{Deserialize, Serialize}; +use std::convert::Infallible; +use std::net::SocketAddr; use std::str::FromStr; use std::sync::{Arc, RwLock}; -// Use the old time crate for iron compatibility -extern crate time01 as time_old; -// Import unicase for Vary headers -use unicase::UniCase; +use time::macros::format_description; +use time::OffsetDateTime; +use tokio::net::TcpListener; const TTL: u32 = 86_400; -struct AsnsMiddleware { - asns_arc: Arc>>, -} - -impl typemap::Key for AsnsMiddleware { - type Value = Arc; -} - -impl AsnsMiddleware { - fn new(asns_arc: Arc>>) -> Self { - Self { asns_arc } - } -} - -impl BeforeMiddleware for AsnsMiddleware { - fn before(&self, req: &mut Request<'_, '_>) -> IronResult<()> { - req.extensions - .insert::(self.asns_arc.read().unwrap().clone()); - Ok(()) - } -} - enum OutputType { Json, Html, } +#[derive(Serialize, Deserialize)] +struct IpLookupResponse { + ip: String, + announced: bool, + #[serde(skip_serializing_if = "Option::is_none")] + first_ip: Option, + #[serde(skip_serializing_if = "Option::is_none")] + last_ip: Option, + #[serde(skip_serializing_if = "Option::is_none")] + as_number: Option, + #[serde(skip_serializing_if = "Option::is_none")] + as_country_code: Option, + #[serde(skip_serializing_if = "Option::is_none")] + as_description: Option, +} + pub struct WebService; impl WebService { - fn index(_: &mut Request<'_, '_>) -> IronResult { - Ok(Response::with(( - status::Ok, - Mime( - TopLevel::Text, - SubLevel::Plain, - vec![(Attr::Charset, Value::Utf8)], - ), - Header(CacheControl(vec![ - CacheDirective::Public, - CacheDirective::MaxAge(TTL), - ])), - Header(Expires(HttpDate( - time_old::now() + time_old::Duration::seconds(TTL.into()), - ))), - "See https://iptoasn.com", - ))) + async fn handle_request( + req: Request, + asns_arc: Arc>>, + ) -> Result>, Infallible> { + let method = req.method(); + let uri = req.uri().path(); + + match (method, uri) { + (&Method::GET, "/") => Ok(Self::index()), + (&Method::GET, path) if path.starts_with("/v1/as/ip/") => { + let ip_s = path.strip_prefix("/v1/as/ip/").unwrap_or(""); + Self::ip_lookup(ip_s, req.headers(), asns_arc) + } + _ => { + let mut response = Response::new(Full::new(Bytes::from("Not Found"))); + *response.status_mut() = StatusCode::NOT_FOUND; + Ok(response) + } + } } - fn accept_type(req: &Request<'_, '_>) -> OutputType { - let mut output_type = OutputType::Json; - if let Some(header_accept) = req.headers.get::() { - for header in header_accept.iter() { - match header.item { - Mime(TopLevel::Text, SubLevel::Html, _) => { - output_type = OutputType::Html; - break; - } - Mime(_, SubLevel::Json, _) => { - output_type = OutputType::Json; - break; - } - _ => {} + fn index() -> Response> { + let mut response = Response::new(Full::new(Bytes::from("iptoasn-webservice\n"))); + response.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/plain; charset=utf-8"), + ); + *response.status_mut() = StatusCode::OK; + response + } + + fn accept_type(headers: &HeaderMap) -> OutputType { + if let Some(accept) = headers.get(ACCEPT) { + if let Ok(accept_str) = accept.to_str() { + if accept_str.contains("application/json") { + return OutputType::Json; } } } - output_type + OutputType::Html } - fn output_json( - map: &serde_json::Map, - cache_headers: (Header, Header), - vary_header: Header, - ) -> Response { - let json = serde_json::to_string(&map).unwrap(); - let mime_json = Mime( - TopLevel::Application, - SubLevel::Json, - vec![(Attr::Charset, Value::Utf8)], + fn cache_headers(headers: &mut HeaderMap) { + let now = OffsetDateTime::now_utc(); + let expires = now + time::Duration::seconds(TTL as i64); + + let format = format_description!( + "[weekday repr:short], [day] [month repr:short] [year] [hour]:[minute]:[second] GMT" ); - Response::with(( - status::Ok, - mime_json, - cache_headers.0, - cache_headers.1, - vary_header, - json, - )) + let expires_str = expires.format(&format).unwrap(); + + headers.insert( + CACHE_CONTROL, + HeaderValue::from_str(&format!("max-age={}", TTL)).unwrap(), + ); + headers.insert(EXPIRES, HeaderValue::from_str(&expires_str).unwrap()); + headers.insert(VARY, HeaderValue::from_static("Accept")); } - fn output_html( - map: &serde_json::Map, - cache_headers: (Header, Header), - vary_header: Header, - ) -> Response { - let mime_html = Mime( - TopLevel::Text, - SubLevel::Html, - vec![(Attr::Charset, Value::Utf8)], + fn output_json(response: &IpLookupResponse) -> Response> { + let json = serde_json::to_string(&response).unwrap(); + let mut response = Response::new(Full::new(Bytes::from(json))); + + response.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("application/json; charset=utf-8"), ); + Self::cache_headers(response.headers_mut()); + *response.status_mut() = StatusCode::OK; + + response + } + + fn output_html(response: &IpLookupResponse) -> Response> { let html = html! { head { title : "iptoasn lookup"; @@ -127,35 +125,35 @@ impl WebService { } body(class="container-fluid") { header { - h1 : format_args!("Information for IP address: {}", map.get("ip").unwrap().as_str().unwrap()); + h1 : format_args!("Information for IP address: {}", response.ip); } table { tr { th : "Announced"; td { - @ if map.get("announced").unwrap().as_bool().unwrap() { + @ if response.announced { : "Yes"; } else { : "No"; } } } - @ if map.get("announced").unwrap().as_bool().unwrap() { + @ if response.announced { tr { th : "AS Number"; - td : format_args!("AS{}", map.get("as_number").unwrap().as_u64().unwrap()); + td : format_args!("AS{}", response.as_number.unwrap()); } tr { th : "AS Range"; - td : format_args!("{} - {}", map.get("first_ip").unwrap().as_str().unwrap(), map.get("last_ip").unwrap().as_str().unwrap()); + td : format_args!("{} - {}", response.first_ip.as_ref().unwrap(), response.last_ip.as_ref().unwrap()); } tr { th : "AS Country Code"; - td : map.get("as_country_code").unwrap().as_str().unwrap(); + td : response.as_country_code.as_ref().unwrap(); } tr { th : "AS Description"; - td : map.get("as_description").unwrap().as_str().unwrap(); + td : response.as_description.as_ref().unwrap(); } } } @@ -169,130 +167,98 @@ impl WebService { }.into_string() .unwrap(); let html = format!("\n{html}"); - Response::with(( - status::Ok, - mime_html, - cache_headers.0, - cache_headers.1, - vary_header, - html, - )) + + let mut response = Response::new(Full::new(Bytes::from(html))); + response.headers_mut().insert( + CONTENT_TYPE, + HeaderValue::from_static("text/html; charset=utf-8"), + ); + Self::cache_headers(response.headers_mut()); + *response.status_mut() = StatusCode::OK; + + response } - fn output( - output_type: &OutputType, - map: &serde_json::Map, - cache_headers: (Header, Header), - vary_header: Header, - ) -> Response { + fn output(output_type: &OutputType, response: &IpLookupResponse) -> Response> { match *output_type { - OutputType::Json => Self::output_json(map, cache_headers, vary_header), - OutputType::Html => Self::output_html(map, cache_headers, vary_header), + OutputType::Json => Self::output_json(response), + OutputType::Html => Self::output_html(response), } } - fn ip_lookup(req: &mut Request<'_, '_>) -> IronResult { - let mime_text = Mime( - TopLevel::Text, - SubLevel::Plain, - vec![(Attr::Charset, Value::Utf8)], - ); - let cache_headers = ( - Header(CacheControl(vec![ - CacheDirective::Public, - CacheDirective::MaxAge(TTL), - ])), - Header(Expires(HttpDate( - time_old::now() + time_old::Duration::seconds(TTL.into()), - ))), - ); - let vary_header = Header(Vary::Items(vec![ - UniCase::from_str("accept-encoding").unwrap(), - UniCase::from_str("accept").unwrap(), - ])); - let ip_str = match req.extensions.get::().unwrap().find("ip") { - None => { - let response = Response::with(( - status::BadRequest, - mime_text, - cache_headers, - "Missing IP address", - )); - return Ok(response); - } - Some(ip_str) => ip_str, - }; - let ip = match IpAddr::from_str(ip_str) { + fn ip_lookup( + ip_s: &str, + headers: &HeaderMap, + asns_arc: Arc>>, + ) -> Result>, Infallible> { + let ip = match std::net::IpAddr::from_str(ip_s) { Err(_) => { - return Ok(Response::with(( - status::BadRequest, - mime_text, - cache_headers, - "Invalid IP address", - ))); + let response = IpLookupResponse { + ip: ip_s.to_owned(), + announced: false, + first_ip: None, + last_ip: None, + as_number: None, + as_country_code: None, + as_description: None, + }; + return Ok(Self::output(&Self::accept_type(headers), &response)); } Ok(ip) => ip, }; - let asns = req.extensions.get::().unwrap(); - let mut map = serde_json::Map::new(); - map.insert( - "ip".to_string(), - serde_json::value::Value::String(ip_str.to_string()), - ); + + let asns = asns_arc.read().unwrap().clone(); + let found = match asns.lookup_by_ip(ip) { None => { - map.insert( - "announced".to_string(), - serde_json::value::Value::Bool(false), - ); - return Ok(Self::output( - &Self::accept_type(req), - &map, - cache_headers, - vary_header, - )); + let response = IpLookupResponse { + ip: ip.to_string(), + announced: false, + first_ip: None, + last_ip: None, + as_number: None, + as_country_code: None, + as_description: None, + }; + return Ok(Self::output(&Self::accept_type(headers), &response)); } Some(found) => found, }; - map.insert( - "announced".to_string(), - serde_json::value::Value::Bool(true), - ); - map.insert( - "first_ip".to_string(), - serde_json::value::Value::String(found.first_ip.to_string()), - ); - map.insert( - "last_ip".to_string(), - serde_json::value::Value::String(found.last_ip.to_string()), - ); - map.insert( - "as_number".to_string(), - serde_json::value::Value::Number(serde_json::Number::from(found.number)), - ); - map.insert( - "as_country_code".to_string(), - serde_json::value::Value::String(found.country.clone()), - ); - map.insert( - "as_description".to_string(), - serde_json::value::Value::String(found.description.clone()), - ); - Ok(Self::output( - &Self::accept_type(req), - &map, - cache_headers, - vary_header, - )) + + let response = IpLookupResponse { + ip: ip.to_string(), + announced: true, + first_ip: Some(found.first_ip.to_string()), + last_ip: Some(found.last_ip.to_string()), + as_number: Some(found.number), + as_country_code: Some(found.country.clone()), + as_description: Some(found.description.clone()), + }; + + Ok(Self::output(&Self::accept_type(headers), &response)) } - pub fn start(asns_arc: Arc>>, listen_addr: &str) { - let router = router!(index: get "/" => Self::index, - ip_lookup: get "/v1/as/ip/:ip" => Self::ip_lookup); - let mut chain = Chain::new(router); - let asns_middleware = AsnsMiddleware::new(asns_arc); - chain.link_before(asns_middleware); - warn!("webservice ready"); - Iron::new(chain).http(listen_addr).unwrap(); + pub async fn start(asns_arc: Arc>>, listen_addr: &str) { + let addr: SocketAddr = listen_addr.parse().expect("Could not parse socket address"); + let listener = TcpListener::bind(addr).await.unwrap(); + + log::warn!("webservice ready"); + + loop { + let (tcp, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(tcp); + let asns_arc = asns_arc.clone(); + + tokio::task::spawn(async move { + let service = service_fn(move |req| { + let asns_arc = asns_arc.clone(); + async move { Self::handle_request(req, asns_arc).await } + }); + + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + log::error!("Error serving connection: {:?}", err); + } + }); + } } }