diff --git a/src/http/axum_implementation/extractors/announce_request.rs b/src/http/axum_implementation/extractors/announce_request.rs index 0371be9a4..1680cd15c 100644 --- a/src/http/axum_implementation/extractors/announce_request.rs +++ b/src/http/axum_implementation/extractors/announce_request.rs @@ -19,27 +19,95 @@ where type Rejection = Response; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - let raw_query = parts.uri.query(); - - if raw_query.is_none() { - return Err(responses::error::Error::from(ParseAnnounceQueryError::MissingParams { - location: Location::caller(), - }) - .into_response()); + match extract_announce_from(parts.uri.query()) { + Ok(announce_request) => Ok(ExtractRequest(announce_request)), + Err(error) => Err(error.into_response()), } + } +} - let query = raw_query.unwrap().parse::(); +fn extract_announce_from(maybe_raw_query: Option<&str>) -> Result { + if maybe_raw_query.is_none() { + return Err(responses::error::Error::from(ParseAnnounceQueryError::MissingParams { + location: Location::caller(), + })); + } - if let Err(error) = query { - return Err(responses::error::Error::from(error).into_response()); - } + let query = maybe_raw_query.unwrap().parse::(); - let announce_request = Announce::try_from(query.unwrap()); + if let Err(error) = query { + return Err(responses::error::Error::from(error)); + } - if let Err(error) = announce_request { - return Err(responses::error::Error::from(error).into_response()); - } + let announce_request = Announce::try_from(query.unwrap()); + + if let Err(error) = announce_request { + return Err(responses::error::Error::from(error)); + } + + Ok(announce_request.unwrap()) +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::extract_announce_from; + use crate::http::axum_implementation::requests::announce::{Announce, Compact, Event}; + use crate::http::axum_implementation::responses::error::Error; + use crate::protocol::info_hash::InfoHash; + use crate::tracker::peer; + + fn assert_error_response(error: &Error, error_message: &str) { + assert!( + error.failure_reason.contains(error_message), + "Error response does not contain message: '{error_message}'. Error: {error:?}" + ); + } + + #[test] + fn it_should_extract_the_announce_request_from_the_url_query_params() { + let raw_query = "info_hash=%3B%24U%04%CF%5F%11%BB%DB%E1%20%1C%EAjk%F4Z%EE%1B%C0&peer_addr=2.137.87.41&downloaded=0&uploaded=0&peer_id=-qB00000000000000001&port=17548&left=0&event=completed&compact=0"; + + let announce = extract_announce_from(Some(raw_query)).unwrap(); + + assert_eq!( + announce, + Announce { + info_hash: InfoHash::from_str("3b245504cf5f11bbdbe1201cea6a6bf45aee1bc0").unwrap(), + peer_id: peer::Id(*b"-qB00000000000000001"), + port: 17548, + downloaded: Some(0), + uploaded: Some(0), + left: Some(0), + event: Some(Event::Completed), + compact: Some(Compact::NotAccepted), + } + ); + } + + #[test] + fn it_should_reject_a_request_without_query_params() { + let response = extract_announce_from(None).unwrap_err(); + + assert_error_response( + &response, + "Cannot parse query params for announce request: missing query params for announce request", + ); + } + + #[test] + fn it_should_reject_a_request_with_a_query_that_cannot_be_parsed() { + let invalid_query = "param1=value1=value2"; + let response = extract_announce_from(Some(invalid_query)).unwrap_err(); + + assert_error_response(&response, "Cannot parse query params"); + } + + #[test] + fn it_should_reject_a_request_with_a_query_that_cannot_be_parsed_into_an_announce_request() { + let response = extract_announce_from(Some("param1=value1")).unwrap_err(); - Ok(ExtractRequest(announce_request.unwrap())) + assert_error_response(&response, "Cannot parse query params for announce request"); } } diff --git a/src/http/axum_implementation/extractors/key.rs b/src/http/axum_implementation/extractors/key.rs index 50aef4a7c..2a3f2a991 100644 --- a/src/http/axum_implementation/extractors/key.rs +++ b/src/http/axum_implementation/extractors/key.rs @@ -1,6 +1,8 @@ +//! Wrapper for Axum `Path` extractor to return custom errors. use std::panic::Location; use axum::async_trait; +use axum::extract::rejection::PathRejection; use axum::extract::{FromRequestParts, Path}; use axum::http::request::Parts; use axum::response::{IntoResponse, Response}; @@ -19,37 +21,74 @@ where type Rejection = Response; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - match Path::::from_request_parts(parts, state).await { - Ok(key_param) => { - let Ok(key) = key_param.0.value().parse::() else { - return Err(responses::error::Error::from( - auth::Error::InvalidKeyFormat { - location: Location::caller() - }) - .into_response()) - }; - Ok(Extract(key)) - } - Err(rejection) => match rejection { - axum::extract::rejection::PathRejection::FailedToDeserializePathParams(_) => { - return Err(responses::error::Error::from(auth::Error::InvalidKeyFormat { - location: Location::caller(), - }) - .into_response()) - } - axum::extract::rejection::PathRejection::MissingPathParams(_) => { - return Err(responses::error::Error::from(auth::Error::MissingAuthKey { - location: Location::caller(), - }) - .into_response()) - } - _ => { - return Err(responses::error::Error::from(auth::Error::CannotExtractKeyParam { - location: Location::caller(), - }) - .into_response()) - } - }, + // Extract `key` from URL path with Axum `Path` extractor + let maybe_path_with_key = Path::::from_request_parts(parts, state).await; + + match extract_key(maybe_path_with_key) { + Ok(key) => Ok(Extract(key)), + Err(error) => Err(error.into_response()), + } + } +} + +fn extract_key(path_extractor_result: Result, PathRejection>) -> Result { + match path_extractor_result { + Ok(key_param) => match parse_key(&key_param.0.value()) { + Ok(key) => Ok(key), + Err(error) => Err(error), + }, + Err(path_rejection) => Err(custom_error(&path_rejection)), + } +} + +fn parse_key(key: &str) -> Result { + let key = key.parse::(); + + match key { + Ok(key) => Ok(key), + Err(_parse_key_error) => Err(responses::error::Error::from(auth::Error::InvalidKeyFormat { + location: Location::caller(), + })), + } +} + +fn custom_error(rejection: &PathRejection) -> responses::error::Error { + match rejection { + axum::extract::rejection::PathRejection::FailedToDeserializePathParams(_) => { + responses::error::Error::from(auth::Error::InvalidKeyFormat { + location: Location::caller(), + }) + } + axum::extract::rejection::PathRejection::MissingPathParams(_) => { + responses::error::Error::from(auth::Error::MissingAuthKey { + location: Location::caller(), + }) } + _ => responses::error::Error::from(auth::Error::CannotExtractKeyParam { + location: Location::caller(), + }), + } +} + +#[cfg(test)] +mod tests { + + use super::parse_key; + use crate::http::axum_implementation::responses::error::Error; + + fn assert_error_response(error: &Error, error_message: &str) { + assert!( + error.failure_reason.contains(error_message), + "Error response does not contain message: '{error_message}'. Error: {error:?}" + ); + } + + #[test] + fn it_should_return_an_authentication_error_if_the_key_cannot_be_parsed() { + let invalid_key = "invalid_key"; + + let response = parse_key(invalid_key).unwrap_err(); + + assert_error_response(&response, "Authentication error: Invalid format for authentication key param"); } } diff --git a/src/http/axum_implementation/extractors/mod.rs b/src/http/axum_implementation/extractors/mod.rs index e6d9e8c67..04e9e306b 100644 --- a/src/http/axum_implementation/extractors/mod.rs +++ b/src/http/axum_implementation/extractors/mod.rs @@ -1,5 +1,4 @@ pub mod announce_request; pub mod key; -pub mod peer_ip; pub mod remote_client_ip; pub mod scrape_request; diff --git a/src/http/axum_implementation/extractors/peer_ip.rs b/src/http/axum_implementation/extractors/peer_ip.rs deleted file mode 100644 index aae348d99..000000000 --- a/src/http/axum_implementation/extractors/peer_ip.rs +++ /dev/null @@ -1,54 +0,0 @@ -use std::net::IpAddr; -use std::panic::Location; - -use axum::response::{IntoResponse, Response}; -use thiserror::Error; - -use super::remote_client_ip::RemoteClientIp; -use crate::http::axum_implementation::responses; - -#[derive(Error, Debug)] -pub enum ResolutionError { - #[error( - "missing or invalid the right most X-Forwarded-For IP (mandatory on reverse proxy tracker configuration) in {location}" - )] - MissingRightMostXForwardedForIp { location: &'static Location<'static> }, - #[error("cannot get the client IP from the connection info in {location}")] - MissingClientIp { location: &'static Location<'static> }, -} - -impl From for responses::error::Error { - fn from(err: ResolutionError) -> Self { - responses::error::Error { - failure_reason: format!("{err}"), - } - } -} - -/// It resolves the peer IP. -/// -/// # Errors -/// -/// Will return an error if the peer IP cannot be obtained according to the configuration. -/// For example, if the IP is extracted from an HTTP header which is missing in the request. -pub fn resolve(on_reverse_proxy: bool, remote_client_ip: &RemoteClientIp) -> Result { - if on_reverse_proxy { - if let Some(ip) = remote_client_ip.right_most_x_forwarded_for { - Ok(ip) - } else { - Err( - responses::error::Error::from(ResolutionError::MissingRightMostXForwardedForIp { - location: Location::caller(), - }) - .into_response(), - ) - } - } else if let Some(ip) = remote_client_ip.connection_info_ip { - Ok(ip) - } else { - Err(responses::error::Error::from(ResolutionError::MissingClientIp { - location: Location::caller(), - }) - .into_response()) - } -} diff --git a/src/http/axum_implementation/extractors/remote_client_ip.rs b/src/http/axum_implementation/extractors/remote_client_ip.rs index e852a1b6f..0f6789261 100644 --- a/src/http/axum_implementation/extractors/remote_client_ip.rs +++ b/src/http/axum_implementation/extractors/remote_client_ip.rs @@ -1,3 +1,5 @@ +//! Wrapper for two Axum extractors to get the relevant information +//! to resolve the remote client IP. use std::net::{IpAddr, SocketAddr}; use axum::async_trait; @@ -18,7 +20,7 @@ use serde::{Deserialize, Serialize}; /// `right_most_x_forwarded_for` = 126.0.0.2 /// `connection_info_ip` = 126.0.0.3 /// -/// More info about inner extractors : +/// More info about inner extractors: #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] pub struct RemoteClientIp { pub right_most_x_forwarded_for: Option, diff --git a/src/http/axum_implementation/extractors/scrape_request.rs b/src/http/axum_implementation/extractors/scrape_request.rs index 4212abfcb..998728f59 100644 --- a/src/http/axum_implementation/extractors/scrape_request.rs +++ b/src/http/axum_implementation/extractors/scrape_request.rs @@ -19,27 +19,117 @@ where type Rejection = Response; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - let raw_query = parts.uri.query(); - - if raw_query.is_none() { - return Err(responses::error::Error::from(ParseScrapeQueryError::MissingParams { - location: Location::caller(), - }) - .into_response()); + match extract_scrape_from(parts.uri.query()) { + Ok(scrape_request) => Ok(ExtractRequest(scrape_request)), + Err(error) => Err(error.into_response()), } + } +} - let query = raw_query.unwrap().parse::(); +fn extract_scrape_from(maybe_raw_query: Option<&str>) -> Result { + if maybe_raw_query.is_none() { + return Err(responses::error::Error::from(ParseScrapeQueryError::MissingParams { + location: Location::caller(), + })); + } - if let Err(error) = query { - return Err(responses::error::Error::from(error).into_response()); - } + let query = maybe_raw_query.unwrap().parse::(); + + if let Err(error) = query { + return Err(responses::error::Error::from(error)); + } + + let scrape_request = Scrape::try_from(query.unwrap()); + + if let Err(error) = scrape_request { + return Err(responses::error::Error::from(error)); + } + + Ok(scrape_request.unwrap()) +} - let scrape_request = Scrape::try_from(query.unwrap()); +#[cfg(test)] +mod tests { + use std::str::FromStr; - if let Err(error) = scrape_request { - return Err(responses::error::Error::from(error).into_response()); + use super::extract_scrape_from; + use crate::http::axum_implementation::requests::scrape::Scrape; + use crate::http::axum_implementation::responses::error::Error; + use crate::protocol::info_hash::InfoHash; + + struct TestInfoHash { + pub bencoded: String, + pub value: InfoHash, + } + + fn test_info_hash() -> TestInfoHash { + TestInfoHash { + bencoded: "%3B%24U%04%CF%5F%11%BB%DB%E1%20%1C%EAjk%F4Z%EE%1B%C0".to_owned(), + value: InfoHash::from_str("3b245504cf5f11bbdbe1201cea6a6bf45aee1bc0").unwrap(), } + } + + fn assert_error_response(error: &Error, error_message: &str) { + assert!( + error.failure_reason.contains(error_message), + "Error response does not contain message: '{error_message}'. Error: {error:?}" + ); + } + + #[test] + fn it_should_extract_the_scrape_request_from_the_url_query_params() { + let info_hash = test_info_hash(); + + let raw_query = format!("info_hash={}", info_hash.bencoded); + + let scrape = extract_scrape_from(Some(&raw_query)).unwrap(); + + assert_eq!( + scrape, + Scrape { + info_hashes: vec![info_hash.value], + } + ); + } + + #[test] + fn it_should_extract_the_scrape_request_from_the_url_query_params_with_more_than_one_info_hash() { + let info_hash = test_info_hash(); + + let raw_query = format!("info_hash={}&info_hash={}", info_hash.bencoded, info_hash.bencoded); + + let scrape = extract_scrape_from(Some(&raw_query)).unwrap(); + + assert_eq!( + scrape, + Scrape { + info_hashes: vec![info_hash.value, info_hash.value], + } + ); + } + + #[test] + fn it_should_reject_a_request_without_query_params() { + let response = extract_scrape_from(None).unwrap_err(); + + assert_error_response( + &response, + "Cannot parse query params for scrape request: missing query params for scrape request", + ); + } + + #[test] + fn it_should_reject_a_request_with_a_query_that_cannot_be_parsed() { + let invalid_query = "param1=value1=value2"; + let response = extract_scrape_from(Some(invalid_query)).unwrap_err(); + + assert_error_response(&response, "Cannot parse query params"); + } + + #[test] + fn it_should_reject_a_request_with_a_query_that_cannot_be_parsed_into_a_scrape_request() { + let response = extract_scrape_from(Some("param1=value1")).unwrap_err(); - Ok(ExtractRequest(scrape_request.unwrap())) + assert_error_response(&response, "Cannot parse query params for scrape request"); } } diff --git a/src/http/axum_implementation/handlers/announce.rs b/src/http/axum_implementation/handlers/announce.rs index 4bb06da73..e4b5ece80 100644 --- a/src/http/axum_implementation/handlers/announce.rs +++ b/src/http/axum_implementation/handlers/announce.rs @@ -7,9 +7,9 @@ use axum::extract::State; use axum::response::{IntoResponse, Response}; use log::debug; +use super::common::peer_ip; use crate::http::axum_implementation::extractors::announce_request::ExtractRequest; use crate::http::axum_implementation::extractors::key::Extract; -use crate::http::axum_implementation::extractors::peer_ip; use crate::http::axum_implementation::extractors::remote_client_ip::RemoteClientIp; use crate::http::axum_implementation::handlers::auth; use crate::http::axum_implementation::requests::announce::{Announce, Compact, Event}; diff --git a/src/http/axum_implementation/handlers/common/mod.rs b/src/http/axum_implementation/handlers/common/mod.rs new file mode 100644 index 000000000..ed159a32b --- /dev/null +++ b/src/http/axum_implementation/handlers/common/mod.rs @@ -0,0 +1 @@ +pub mod peer_ip; diff --git a/src/http/axum_implementation/handlers/common/peer_ip.rs b/src/http/axum_implementation/handlers/common/peer_ip.rs new file mode 100644 index 000000000..1c3b6c815 --- /dev/null +++ b/src/http/axum_implementation/handlers/common/peer_ip.rs @@ -0,0 +1,170 @@ +//! Helper handler function to resolve the peer IP from the `RemoteClientIp` extractor. +use std::net::IpAddr; +use std::panic::Location; + +use axum::response::{IntoResponse, Response}; +use thiserror::Error; + +use crate::http::axum_implementation::extractors::remote_client_ip::RemoteClientIp; +use crate::http::axum_implementation::responses; + +#[derive(Error, Debug)] +pub enum ResolutionError { + #[error( + "missing or invalid the right most X-Forwarded-For IP (mandatory on reverse proxy tracker configuration) in {location}" + )] + MissingRightMostXForwardedForIp { location: &'static Location<'static> }, + #[error("cannot get the client IP from the connection info in {location}")] + MissingClientIp { location: &'static Location<'static> }, +} + +impl From for responses::error::Error { + fn from(err: ResolutionError) -> Self { + responses::error::Error { + failure_reason: format!("Error resolving peer IP: {err}"), + } + } +} + +/// It resolves the peer IP. +/// +/// # Errors +/// +/// Will return an error response if the peer IP cannot be obtained according to the configuration. +/// For example, if the IP is extracted from an HTTP header which is missing in the request. +pub fn resolve(on_reverse_proxy: bool, remote_client_ip: &RemoteClientIp) -> Result { + match resolve_peer_ip(on_reverse_proxy, remote_client_ip) { + Ok(ip) => Ok(ip), + Err(error) => Err(error.into_response()), + } +} + +fn resolve_peer_ip(on_reverse_proxy: bool, remote_client_ip: &RemoteClientIp) -> Result { + if on_reverse_proxy { + resolve_peer_ip_on_reverse_proxy(remote_client_ip) + } else { + resolve_peer_ip_without_reverse_proxy(remote_client_ip) + } +} + +fn resolve_peer_ip_without_reverse_proxy(remote_client_ip: &RemoteClientIp) -> Result { + if let Some(ip) = remote_client_ip.connection_info_ip { + Ok(ip) + } else { + Err(responses::error::Error::from(ResolutionError::MissingClientIp { + location: Location::caller(), + })) + } +} + +fn resolve_peer_ip_on_reverse_proxy(remote_client_ip: &RemoteClientIp) -> Result { + if let Some(ip) = remote_client_ip.right_most_x_forwarded_for { + Ok(ip) + } else { + Err(responses::error::Error::from( + ResolutionError::MissingRightMostXForwardedForIp { + location: Location::caller(), + }, + )) + } +} + +#[cfg(test)] +mod tests { + use super::resolve_peer_ip; + use crate::http::axum_implementation::responses::error::Error; + + fn assert_error_response(error: &Error, error_message: &str) { + assert!( + error.failure_reason.contains(error_message), + "Error response does not contain message: '{error_message}'. Error: {error:?}" + ); + } + + mod working_without_reverse_proxy { + use std::net::IpAddr; + use std::str::FromStr; + + use super::{assert_error_response, resolve_peer_ip}; + use crate::http::axum_implementation::extractors::remote_client_ip::RemoteClientIp; + + #[test] + fn it_should_get_the_peer_ip_from_the_connection_info() { + let on_reverse_proxy = false; + + let ip = resolve_peer_ip( + on_reverse_proxy, + &RemoteClientIp { + right_most_x_forwarded_for: None, + connection_info_ip: Some(IpAddr::from_str("203.0.113.195").unwrap()), + }, + ) + .unwrap(); + + assert_eq!(ip, IpAddr::from_str("203.0.113.195").unwrap()); + } + + #[test] + fn it_should_return_an_error_if_it_cannot_get_the_peer_ip_from_the_connection_info() { + let on_reverse_proxy = false; + + let response = resolve_peer_ip( + on_reverse_proxy, + &RemoteClientIp { + right_most_x_forwarded_for: None, + connection_info_ip: None, + }, + ) + .unwrap_err(); + + assert_error_response( + &response, + "Error resolving peer IP: cannot get the client IP from the connection info", + ); + } + } + + mod working_on_reverse_proxy { + use std::net::IpAddr; + use std::str::FromStr; + + use super::assert_error_response; + use crate::http::axum_implementation::extractors::remote_client_ip::RemoteClientIp; + use crate::http::axum_implementation::handlers::common::peer_ip::resolve_peer_ip; + + #[test] + fn it_should_get_the_peer_ip_from_the_right_most_ip_in_the_x_forwarded_for_header() { + let on_reverse_proxy = true; + + let ip = resolve_peer_ip( + on_reverse_proxy, + &RemoteClientIp { + right_most_x_forwarded_for: Some(IpAddr::from_str("203.0.113.195").unwrap()), + connection_info_ip: None, + }, + ) + .unwrap(); + + assert_eq!(ip, IpAddr::from_str("203.0.113.195").unwrap()); + } + + #[test] + fn it_should_return_an_error_if_it_cannot_get_the_right_most_ip_from_the_x_forwarded_for_header() { + let on_reverse_proxy = true; + + let response = resolve_peer_ip( + on_reverse_proxy, + &RemoteClientIp { + right_most_x_forwarded_for: None, + connection_info_ip: None, + }, + ) + .unwrap_err(); + + assert_error_response( + &response, + "Error resolving peer IP: missing or invalid the right most X-Forwarded-For IP", + ); + } + } +} diff --git a/src/http/axum_implementation/handlers/mod.rs b/src/http/axum_implementation/handlers/mod.rs index e6b13ae91..36a810d95 100644 --- a/src/http/axum_implementation/handlers/mod.rs +++ b/src/http/axum_implementation/handlers/mod.rs @@ -3,6 +3,7 @@ use crate::tracker::error::Error; pub mod announce; pub mod auth; +pub mod common; pub mod scrape; impl From for responses::error::Error { diff --git a/src/http/axum_implementation/handlers/scrape.rs b/src/http/axum_implementation/handlers/scrape.rs index 41d6bf3dc..d8d68a4c3 100644 --- a/src/http/axum_implementation/handlers/scrape.rs +++ b/src/http/axum_implementation/handlers/scrape.rs @@ -4,8 +4,8 @@ use axum::extract::State; use axum::response::{IntoResponse, Response}; use log::debug; +use super::common::peer_ip; use crate::http::axum_implementation::extractors::key::Extract; -use crate::http::axum_implementation::extractors::peer_ip; use crate::http::axum_implementation::extractors::remote_client_ip::RemoteClientIp; use crate::http::axum_implementation::extractors::scrape_request::ExtractRequest; use crate::http::axum_implementation::requests::scrape::Scrape; diff --git a/src/http/axum_implementation/responses/error.rs b/src/http/axum_implementation/responses/error.rs index bcf2aaa57..0bcdbd9fb 100644 --- a/src/http/axum_implementation/responses/error.rs +++ b/src/http/axum_implementation/responses/error.rs @@ -2,7 +2,7 @@ use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use serde::{self, Serialize}; -#[derive(Serialize)] +#[derive(Serialize, Debug, PartialEq)] pub struct Error { #[serde(rename = "failure reason")] pub failure_reason: String,