diff --git a/cSpell.json b/cSpell.json index dc51c87c5..9f10d99e4 100644 --- a/cSpell.json +++ b/cSpell.json @@ -28,6 +28,7 @@ "Hydranode", "incompletei", "infohash", + "infohashes", "infoschema", "intervali", "leecher", @@ -58,6 +59,7 @@ "sharktorrent", "socketaddr", "sqllite", + "subsec", "Swatinem", "Swiftbit", "thiserror", diff --git a/src/apis/handlers.rs b/src/apis/handlers.rs index 8d9689025..38959edbe 100644 --- a/src/apis/handlers.rs +++ b/src/apis/handlers.rs @@ -66,8 +66,8 @@ pub async fn add_torrent_to_whitelist_handler( match InfoHash::from_str(&info_hash.0) { Err(_) => invalid_info_hash_param_response(&info_hash.0), Ok(info_hash) => match tracker.add_torrent_to_whitelist(&info_hash).await { - Ok(..) => ok_response(), - Err(..) => failed_to_whitelist_torrent_response(), + Ok(_) => ok_response(), + Err(e) => failed_to_whitelist_torrent_response(e), }, } } @@ -79,16 +79,16 @@ pub async fn remove_torrent_from_whitelist_handler( match InfoHash::from_str(&info_hash.0) { Err(_) => invalid_info_hash_param_response(&info_hash.0), Ok(info_hash) => match tracker.remove_torrent_from_whitelist(&info_hash).await { - Ok(..) => ok_response(), - Err(..) => failed_to_remove_torrent_from_whitelist_response(), + Ok(_) => ok_response(), + Err(e) => failed_to_remove_torrent_from_whitelist_response(e), }, } } pub async fn reload_whitelist_handler(State(tracker): State>) -> Response { match tracker.load_whitelist().await { - Ok(..) => ok_response(), - Err(..) 
=> failed_to_reload_whitelist_response(), + Ok(_) => ok_response(), + Err(e) => failed_to_reload_whitelist_response(e), } } @@ -96,7 +96,7 @@ pub async fn generate_auth_key_handler(State(tracker): State>, Path let seconds_valid = seconds_valid_or_key; match tracker.generate_auth_key(Duration::from_secs(seconds_valid)).await { Ok(auth_key) => auth_key_response(&AuthKey::from(auth_key)), - Err(_) => failed_to_generate_key_response(), + Err(e) => failed_to_generate_key_response(e), } } @@ -111,15 +111,15 @@ pub async fn delete_auth_key_handler( Err(_) => invalid_auth_key_param_response(&seconds_valid_or_key.0), Ok(key_id) => match tracker.remove_auth_key(&key_id.to_string()).await { Ok(_) => ok_response(), - Err(_) => failed_to_delete_key_response(), + Err(e) => failed_to_delete_key_response(e), }, } } pub async fn reload_keys_handler(State(tracker): State>) -> Response { match tracker.load_keys().await { - Ok(..) => ok_response(), - Err(..) => failed_to_reload_keys_response(), + Ok(_) => ok_response(), + Err(e) => failed_to_reload_keys_response(e), } } diff --git a/src/apis/responses.rs b/src/apis/responses.rs index b150b4bff..3b0946396 100644 --- a/src/apis/responses.rs +++ b/src/apis/responses.rs @@ -1,3 +1,5 @@ +use std::error::Error; + use axum::http::{header, StatusCode}; use axum::response::{IntoResponse, Json, Response}; use serde::Serialize; @@ -110,33 +112,33 @@ pub fn torrent_not_known_response() -> Response { } #[must_use] -pub fn failed_to_remove_torrent_from_whitelist_response() -> Response { - unhandled_rejection_response("failed to remove torrent from whitelist".to_string()) +pub fn failed_to_remove_torrent_from_whitelist_response(e: E) -> Response { + unhandled_rejection_response(format!("failed to remove torrent from whitelist: {e}")) } #[must_use] -pub fn failed_to_whitelist_torrent_response() -> Response { - unhandled_rejection_response("failed to whitelist torrent".to_string()) +pub fn failed_to_whitelist_torrent_response(e: E) -> Response { + 
unhandled_rejection_response(format!("failed to whitelist torrent: {e}")) } #[must_use] -pub fn failed_to_reload_whitelist_response() -> Response { - unhandled_rejection_response("failed to reload whitelist".to_string()) +pub fn failed_to_reload_whitelist_response(e: E) -> Response { + unhandled_rejection_response(format!("failed to reload whitelist: {e}")) } #[must_use] -pub fn failed_to_generate_key_response() -> Response { - unhandled_rejection_response("failed to generate key".to_string()) +pub fn failed_to_generate_key_response(e: E) -> Response { + unhandled_rejection_response(format!("failed to generate key: {e}")) } #[must_use] -pub fn failed_to_delete_key_response() -> Response { - unhandled_rejection_response("failed to delete key".to_string()) +pub fn failed_to_delete_key_response(e: E) -> Response { + unhandled_rejection_response(format!("failed to delete key: {e}")) } #[must_use] -pub fn failed_to_reload_keys_response() -> Response { - unhandled_rejection_response("failed to reload keys".to_string()) +pub fn failed_to_reload_keys_response(e: E) -> Response { + unhandled_rejection_response(format!("failed to reload keys: {e}")) } /// This error response is to keep backward compatibility with the old Warp API. 
diff --git a/src/config.rs b/src/config.rs index 3ca4b37d8..7ed0f9fa7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,16 +1,21 @@ use std::collections::{HashMap, HashSet}; use std::net::IpAddr; +use std::panic::Location; use std::path::Path; use std::str::FromStr; +use std::sync::Arc; use std::{env, fs}; use config::{Config, ConfigError, File, FileFormat}; +use log::warn; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, NoneAsEmptyString}; +use thiserror::Error; use {std, toml}; use crate::databases::driver::Driver; +use crate::located_error::{Located, LocatedError}; use crate::tracker::mode; #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] @@ -74,13 +79,30 @@ pub struct Configuration { pub http_api: HttpApi, } -#[derive(Debug)] +#[derive(Error, Debug)] pub enum Error { - Message(String), - ConfigError(ConfigError), - IOError(std::io::Error), - ParseError(toml::de::Error), - TrackerModeIncompatible, + #[error("Unable to load from Environmental Variable: {source}")] + UnableToLoadFromEnvironmentVariable { + source: LocatedError<'static, dyn std::error::Error + Send + Sync>, + }, + + #[error("Default configuration created at: `{path}`, please review and reload tracker, {location}")] + CreatedNewConfigHalt { + location: &'static Location<'static>, + path: String, + }, + + #[error("Failed processing the configuration: {source}")] + ConfigError { source: LocatedError<'static, ConfigError> }, +} + +impl From for Error { + #[track_caller] + fn from(err: ConfigError) -> Self { + Self::ConfigError { + source: Located(err).into(), + } + } } /// This configuration is used for testing. 
It generates random config values so they do not collide @@ -129,20 +151,6 @@ fn random_port() -> u16 { rng.gen_range(49152..65535) } -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Error::Message(e) => e.fmt(f), - Error::ConfigError(e) => e.fmt(f), - Error::IOError(e) => e.fmt(f), - Error::ParseError(e) => e.fmt(f), - Error::TrackerModeIncompatible => write!(f, "{self:?}"), - } - } -} - -impl std::error::Error for Error {} - impl Default for Configuration { fn default() -> Self { let mut configuration = Configuration { @@ -210,21 +218,19 @@ impl Configuration { let mut config = Config::default(); if Path::new(path).exists() { - config = config_builder - .add_source(File::with_name(path)) - .build() - .map_err(Error::ConfigError)?; + config = config_builder.add_source(File::with_name(path)).build()?; } else { - eprintln!("No config file found."); - eprintln!("Creating config file.."); + warn!("No config file found."); + warn!("Creating config file.."); let config = Configuration::default(); config.save_to_file(path)?; - return Err(Error::Message( - "Please edit the config.TOML and restart the tracker.".to_string(), - )); + return Err(Error::CreatedNewConfigHalt { + location: Location::caller(), + path: path.to_string(), + }); } - let torrust_config: Configuration = config.try_deserialize().map_err(Error::ConfigError)?; + let torrust_config: Configuration = config.try_deserialize()?; Ok(torrust_config) } @@ -237,15 +243,13 @@ impl Configuration { Ok(config_toml) => { let config_builder = Config::builder() .add_source(File::from_str(&config_toml, FileFormat::Toml)) - .build() - .map_err(Error::ConfigError)?; - let config = config_builder.try_deserialize().map_err(Error::ConfigError)?; + .build()?; + let config = config_builder.try_deserialize()?; Ok(config) } - Err(_) => Err(Error::Message(format!( - "No environment variable for configuration found: {}", - &config_env_var_name - ))), + Err(e) => 
Err(Error::UnableToLoadFromEnvironmentVariable { + source: (Arc::new(e) as Arc).into(), + }), } } @@ -262,7 +266,7 @@ impl Configuration { #[cfg(test)] mod tests { - use crate::config::{Configuration, Error}; + use crate::config::Configuration; #[cfg(test)] fn default_config_toml() -> String { @@ -381,13 +385,6 @@ mod tests { assert_eq!(configuration, Configuration::default()); } - #[test] - fn configuration_error_could_be_displayed() { - let error = Error::TrackerModeIncompatible; - - assert_eq!(format!("{error}"), "TrackerModeIncompatible"); - } - #[test] fn http_api_configuration_should_check_if_it_contains_a_token() { let configuration = Configuration::default(); diff --git a/src/databases/driver.rs b/src/databases/driver.rs index 7eaa9064e..c601f1866 100644 --- a/src/databases/driver.rs +++ b/src/databases/driver.rs @@ -1,7 +1,30 @@ use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] +use super::error::Error; +use super::mysql::Mysql; +use super::sqlite::Sqlite; +use super::{Builder, Database}; + +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, derive_more::Display, Clone)] pub enum Driver { Sqlite3, MySQL, } + +impl Driver { + /// . + /// + /// # Errors + /// + /// This function will return an error if unable to connect to the database. 
+ pub fn build(&self, db_path: &str) -> Result, Error> { + let database = match self { + Driver::Sqlite3 => Builder::::build(db_path), + Driver::MySQL => Builder::::build(db_path), + }?; + + database.create_database_tables().expect("Could not create database tables."); + + Ok(database) + } +} diff --git a/src/databases/error.rs b/src/databases/error.rs index 467db407f..4bee82f19 100644 --- a/src/databases/error.rs +++ b/src/databases/error.rs @@ -1,21 +1,95 @@ -use derive_more::{Display, Error}; +use std::panic::Location; +use std::sync::Arc; -#[derive(Debug, Display, PartialEq, Eq, Error)] -#[allow(dead_code)] +use r2d2_mysql::mysql::UrlError; + +use super::driver::Driver; +use crate::located_error::{Located, LocatedError}; + +#[derive(thiserror::Error, Debug, Clone)] pub enum Error { - #[display(fmt = "Query returned no rows.")] - QueryReturnedNoRows, - #[display(fmt = "Invalid query.")] - InvalidQuery, - #[display(fmt = "Database error.")] - DatabaseError, + #[error("The {driver} query unexpectedly returned nothing: {source}")] + QueryReturnedNoRows { + source: LocatedError<'static, dyn std::error::Error + Send + Sync>, + driver: Driver, + }, + + #[error("The {driver} query was malformed: {source}")] + InvalidQuery { + source: LocatedError<'static, dyn std::error::Error + Send + Sync>, + driver: Driver, + }, + + #[error("Unable to insert record into {driver} database, {location}")] + InsertFailed { + location: &'static Location<'static>, + driver: Driver, + }, + + #[error("Failed to remove record from {driver} database, error-code: {error_code}, {location}")] + DeleteFailed { + location: &'static Location<'static>, + error_code: usize, + driver: Driver, + }, + + #[error("Failed to connect to {driver} database: {source}")] + ConnectionError { + source: LocatedError<'static, UrlError>, + driver: Driver, + }, + + #[error("Failed to create r2d2 {driver} connection pool: {source}")] + ConnectionPool { + source: LocatedError<'static, r2d2::Error>, + driver: Driver, + 
}, } impl From for Error { - fn from(e: r2d2_sqlite::rusqlite::Error) -> Self { - match e { - r2d2_sqlite::rusqlite::Error::QueryReturnedNoRows => Error::QueryReturnedNoRows, - _ => Error::InvalidQuery, + #[track_caller] + fn from(err: r2d2_sqlite::rusqlite::Error) -> Self { + match err { + r2d2_sqlite::rusqlite::Error::QueryReturnedNoRows => Error::QueryReturnedNoRows { + source: (Arc::new(err) as Arc).into(), + driver: Driver::Sqlite3, + }, + _ => Error::InvalidQuery { + source: (Arc::new(err) as Arc).into(), + driver: Driver::Sqlite3, + }, + } + } +} + +impl From for Error { + #[track_caller] + fn from(err: r2d2_mysql::mysql::Error) -> Self { + let e: Arc = Arc::new(err); + Error::InvalidQuery { + source: e.into(), + driver: Driver::MySQL, + } + } +} + +impl From for Error { + #[track_caller] + fn from(err: UrlError) -> Self { + Self::ConnectionError { + source: Located(err).into(), + driver: Driver::MySQL, + } + } +} + +impl From<(r2d2::Error, Driver)> for Error { + #[track_caller] + fn from(e: (r2d2::Error, Driver)) -> Self { + let (err, driver) = e; + Self::ConnectionPool { + source: Located(err).into(), + driver, } } } diff --git a/src/databases/mod.rs b/src/databases/mod.rs index 873dd70eb..809decc2c 100644 --- a/src/databases/mod.rs +++ b/src/databases/mod.rs @@ -3,37 +3,48 @@ pub mod error; pub mod mysql; pub mod sqlite; +use std::marker::PhantomData; + use async_trait::async_trait; -use self::driver::Driver; use self::error::Error; -use crate::databases::mysql::Mysql; -use crate::databases::sqlite::Sqlite; use crate::protocol::info_hash::InfoHash; use crate::tracker::auth; -/// # Errors -/// -/// Will return `r2d2::Error` if `db_path` is not able to create a database. 
-pub fn connect(db_driver: &Driver, db_path: &str) -> Result, r2d2::Error> { - let database: Box = match db_driver { - Driver::Sqlite3 => { - let db = Sqlite::new(db_path)?; - Box::new(db) - } - Driver::MySQL => { - let db = Mysql::new(db_path)?; - Box::new(db) - } - }; - - database.create_database_tables().expect("Could not create database tables."); - - Ok(database) +pub(self) struct Builder +where + T: Database, +{ + phantom: PhantomData, +} + +impl Builder +where + T: Database + 'static, +{ + /// . + /// + /// # Errors + /// + /// Will return `r2d2::Error` if `db_path` is not able to create a database. + pub(self) fn build(db_path: &str) -> Result, Error> { + Ok(Box::new(T::new(db_path)?)) + } } #[async_trait] pub trait Database: Sync + Send { + /// . + /// + /// # Errors + /// + /// Will return `r2d2::Error` if `db_path` is not able to create a database. + fn new(db_path: &str) -> Result + where + Self: std::marker::Sized; + + /// . + /// /// # Errors /// /// Will return `Error` if unable to create own tables. 
@@ -52,27 +63,22 @@ pub trait Database: Sync + Send { async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>; - async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result; + async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result, Error>; async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result; async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result; - async fn get_key_from_keys(&self, key: &str) -> Result; + async fn get_key_from_keys(&self, key: &str) -> Result, Error>; async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result; async fn remove_key_from_keys(&self, key: &str) -> Result; async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result { - self.get_info_hash_from_whitelist(&info_hash.clone().to_string()) - .await - .map_or_else( - |e| match e { - Error::QueryReturnedNoRows => Ok(false), - e => Err(e), - }, - |_| Ok(true), - ) + Ok(self + .get_info_hash_from_whitelist(&info_hash.clone().to_string()) + .await? + .is_some()) } } diff --git a/src/databases/mysql.rs b/src/databases/mysql.rs index 71b06378c..ac54ebb82 100644 --- a/src/databases/mysql.rs +++ b/src/databases/mysql.rs @@ -8,33 +8,32 @@ use r2d2_mysql::mysql::prelude::Queryable; use r2d2_mysql::mysql::{params, Opts, OptsBuilder}; use r2d2_mysql::MysqlConnectionManager; +use super::driver::Driver; use crate::databases::{Database, Error}; use crate::protocol::common::AUTH_KEY_LENGTH; use crate::protocol::info_hash::InfoHash; use crate::tracker::auth; +const DRIVER: Driver = Driver::MySQL; + pub struct Mysql { pool: Pool, } -impl Mysql { +#[async_trait] +impl Database for Mysql { /// # Errors /// /// Will return `r2d2::Error` if `db_path` is not able to create `MySQL` database. 
- pub fn new(db_path: &str) -> Result { - let opts = Opts::from_url(db_path).expect("Failed to connect to MySQL database."); + fn new(db_path: &str) -> Result { + let opts = Opts::from_url(db_path)?; let builder = OptsBuilder::from_opts(opts); let manager = MysqlConnectionManager::new(builder); - let pool = r2d2::Pool::builder() - .build(manager) - .expect("Failed to create r2d2 MySQL connection pool."); + let pool = r2d2::Pool::builder().build(manager).map_err(|e| (e, DRIVER))?; Ok(Self { pool }) } -} -#[async_trait] -impl Database for Mysql { fn create_database_tables(&self) -> Result<(), Error> { let create_whitelist_table = " CREATE TABLE IF NOT EXISTS whitelist ( @@ -63,7 +62,7 @@ impl Database for Mysql { i8::try_from(AUTH_KEY_LENGTH).expect("auth::Auth Key Length Should fit within a i8!") ); - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; conn.query_drop(&create_torrents_table) .expect("Could not create torrents table."); @@ -87,7 +86,7 @@ impl Database for Mysql { DROP TABLE `keys`;" .to_string(); - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; conn.query_drop(&drop_whitelist_table) .expect("Could not drop `whitelist` table."); @@ -99,155 +98,124 @@ impl Database for Mysql { } async fn load_persistent_torrents(&self) -> Result, Error> { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; - - let torrents: Vec<(InfoHash, u32)> = conn - .query_map( - "SELECT info_hash, completed FROM torrents", - |(info_hash_string, completed): (String, u32)| { - let info_hash = InfoHash::from_str(&info_hash_string).unwrap(); - (info_hash, completed) - }, - ) - .map_err(|_| Error::QueryReturnedNoRows)?; + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; + + let torrents = conn.query_map( + "SELECT info_hash, completed FROM torrents", + |(info_hash_string, completed): (String, u32)| { + let 
info_hash = InfoHash::from_str(&info_hash_string).unwrap(); + (info_hash, completed) + }, + )?; Ok(torrents) } async fn load_keys(&self) -> Result, Error> { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; - - let keys: Vec = conn - .query_map( - "SELECT `key`, valid_until FROM `keys`", - |(key, valid_until): (String, i64)| auth::Key { - key, - valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())), - }, - ) - .map_err(|_| Error::QueryReturnedNoRows)?; + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; + + let keys = conn.query_map( + "SELECT `key`, valid_until FROM `keys`", + |(key, valid_until): (String, i64)| auth::Key { + key, + valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())), + }, + )?; Ok(keys) } async fn load_whitelist(&self) -> Result, Error> { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; - let info_hashes: Vec = conn - .query_map("SELECT info_hash FROM whitelist", |info_hash: String| { - InfoHash::from_str(&info_hash).unwrap() - }) - .map_err(|_| Error::QueryReturnedNoRows)?; + let info_hashes = conn.query_map("SELECT info_hash FROM whitelist", |info_hash: String| { + InfoHash::from_str(&info_hash).unwrap() + })?; Ok(info_hashes) } async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error> { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + const COMMAND : &str = "INSERT INTO torrents (info_hash, completed) VALUES (:info_hash_str, :completed) ON DUPLICATE KEY UPDATE completed = VALUES(completed)"; + + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; let info_hash_str = info_hash.to_string(); debug!("{}", info_hash_str); - match conn.exec_drop("INSERT INTO torrents (info_hash, completed) VALUES (:info_hash_str, :completed) ON DUPLICATE KEY UPDATE completed = VALUES(completed)", params! 
{ info_hash_str, completed }) { - Ok(_) => { - Ok(()) - } - Err(e) => { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } - } + Ok(conn.exec_drop(COMMAND, params! { info_hash_str, completed })?) } - async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; - - match conn - .exec_first::( - "SELECT info_hash FROM whitelist WHERE info_hash = :info_hash", - params! { info_hash }, - ) - .map_err(|_| Error::DatabaseError)? - { - Some(info_hash) => Ok(InfoHash::from_str(&info_hash).unwrap()), - None => Err(Error::QueryReturnedNoRows), - } + async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result, Error> { + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; + + let select = conn.exec_first::( + "SELECT info_hash FROM whitelist WHERE info_hash = :info_hash", + params! { info_hash }, + )?; + + let info_hash = select.map(|f| InfoHash::from_str(&f).expect("Failed to decode InfoHash String from DB!")); + + Ok(info_hash) } async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; let info_hash_str = info_hash.to_string(); - match conn.exec_drop( + conn.exec_drop( "INSERT INTO whitelist (info_hash) VALUES (:info_hash_str)", params! { info_hash_str }, - ) { - Ok(_) => Ok(1), - Err(e) => { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } - } + )?; + + Ok(1) } async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; let info_hash = info_hash.to_string(); - match conn.exec_drop("DELETE FROM whitelist WHERE info_hash = :info_hash", params! 
{ info_hash }) { - Ok(_) => Ok(1), - Err(e) => { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } - } + conn.exec_drop("DELETE FROM whitelist WHERE info_hash = :info_hash", params! { info_hash })?; + + Ok(1) } - async fn get_key_from_keys(&self, key: &str) -> Result { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + async fn get_key_from_keys(&self, key: &str) -> Result, Error> { + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; - match conn - .exec_first::<(String, i64), _, _>("SELECT `key`, valid_until FROM `keys` WHERE `key` = :key", params! { key }) - .map_err(|_| Error::QueryReturnedNoRows)? - { - Some((key, valid_until)) => Ok(auth::Key { - key, - valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())), - }), - None => Err(Error::InvalidQuery), - } + let query = + conn.exec_first::<(String, i64), _, _>("SELECT `key`, valid_until FROM `keys` WHERE `key` = :key", params! { key }); + + let key = query?; + + Ok(key.map(|(key, expiry)| auth::Key { + key, + valid_until: Some(Duration::from_secs(expiry.unsigned_abs())), + })) } async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; let key = auth_key.key.to_string(); let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string(); - match conn.exec_drop( + conn.exec_drop( "INSERT INTO `keys` (`key`, valid_until) VALUES (:key, :valid_until)", params! { key, valid_until }, - ) { - Ok(_) => Ok(1), - Err(e) => { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } - } + )?; + + Ok(1) } async fn remove_key_from_keys(&self, key: &str) -> Result { - let mut conn = self.pool.get().map_err(|_| Error::DatabaseError)?; - - match conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! 
{ key }) { - Ok(_) => Ok(1), - Err(e) => { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } - } + let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; + + conn.exec_drop("DELETE FROM `keys` WHERE key = :key", params! { key })?; + + Ok(1) } } diff --git a/src/databases/sqlite.rs b/src/databases/sqlite.rs index 1d7caf052..3425b15c8 100644 --- a/src/databases/sqlite.rs +++ b/src/databases/sqlite.rs @@ -1,32 +1,32 @@ +use std::panic::Location; use std::str::FromStr; use async_trait::async_trait; -use log::debug; use r2d2::Pool; use r2d2_sqlite::SqliteConnectionManager; +use super::driver::Driver; use crate::databases::{Database, Error}; use crate::protocol::clock::DurationSinceUnixEpoch; use crate::protocol::info_hash::InfoHash; use crate::tracker::auth; +const DRIVER: Driver = Driver::Sqlite3; + pub struct Sqlite { pool: Pool, } -impl Sqlite { +#[async_trait] +impl Database for Sqlite { /// # Errors /// /// Will return `r2d2::Error` if `db_path` is not able to create `SqLite` database. 
- pub fn new(db_path: &str) -> Result { + fn new(db_path: &str) -> Result { let cm = SqliteConnectionManager::file(db_path); - let pool = Pool::new(cm).expect("Failed to create r2d2 SQLite connection pool."); - Ok(Sqlite { pool }) + Pool::new(cm).map_or_else(|err| Err((err, Driver::Sqlite3).into()), |pool| Ok(Sqlite { pool })) } -} -#[async_trait] -impl Database for Sqlite { fn create_database_tables(&self) -> Result<(), Error> { let create_whitelist_table = " CREATE TABLE IF NOT EXISTS whitelist ( @@ -51,13 +51,13 @@ impl Database for Sqlite { );" .to_string(); - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; + + conn.execute(&create_whitelist_table, [])?; + conn.execute(&create_keys_table, [])?; + conn.execute(&create_torrents_table, [])?; - conn.execute(&create_whitelist_table, []) - .and_then(|_| conn.execute(&create_keys_table, [])) - .and_then(|_| conn.execute(&create_torrents_table, [])) - .map_err(|_| Error::InvalidQuery) - .map(|_| ()) + Ok(()) } fn drop_database_tables(&self) -> Result<(), Error> { @@ -73,17 +73,17 @@ impl Database for Sqlite { DROP TABLE keys;" .to_string(); - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; conn.execute(&drop_whitelist_table, []) .and_then(|_| conn.execute(&drop_torrents_table, [])) - .and_then(|_| conn.execute(&drop_keys_table, [])) - .map_err(|_| Error::InvalidQuery) - .map(|_| ()) + .and_then(|_| conn.execute(&drop_keys_table, []))?; + + Ok(()) } async fn load_persistent_torrents(&self) -> Result, Error> { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let mut stmt = conn.prepare("SELECT info_hash, completed FROM torrents")?; @@ -94,13 +94,16 @@ impl Database for Sqlite { Ok((info_hash, completed)) })?; + //torrent_iter?; + //let torrent_iter = torrent_iter.unwrap(); + let torrents: Vec<(InfoHash, u32)> = 
torrent_iter.filter_map(std::result::Result::ok).collect(); Ok(torrents) } async fn load_keys(&self) -> Result, Error> { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let mut stmt = conn.prepare("SELECT key, valid_until FROM keys")?; @@ -120,7 +123,7 @@ impl Database for Sqlite { } async fn load_whitelist(&self) -> Result, Error> { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let mut stmt = conn.prepare("SELECT info_hash FROM whitelist")?; @@ -136,130 +139,117 @@ impl Database for Sqlite { } async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error> { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; - match conn.execute( + let insert = conn.execute( "INSERT INTO torrents (info_hash, completed) VALUES (?1, ?2) ON CONFLICT(info_hash) DO UPDATE SET completed = ?2", [info_hash.to_string(), completed.to_string()], - ) { - Ok(updated) => { - if updated > 0 { - return Ok(()); - } - Err(Error::QueryReturnedNoRows) - } - Err(e) => { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } + )?; + + if insert == 0 { + Err(Error::InsertFailed { + location: Location::caller(), + driver: DRIVER, + }) + } else { + Ok(()) } } - async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result, Error> { + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let mut stmt = conn.prepare("SELECT info_hash FROM whitelist WHERE info_hash = ?")?; + let mut rows = stmt.query([info_hash])?; - match rows.next() { - Ok(row) => match row { - Some(row) => Ok(InfoHash::from_str(&row.get_unwrap::<_, String>(0)).unwrap()), - None => Err(Error::QueryReturnedNoRows), - }, - Err(e) => { - 
debug!("{:?}", e); - Err(Error::InvalidQuery) - } - } + let query = rows.next()?; + + Ok(query.map(|f| InfoHash::from_str(&f.get_unwrap::<_, String>(0)).unwrap())) } async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; - - match conn.execute("INSERT INTO whitelist (info_hash) VALUES (?)", [info_hash.to_string()]) { - Ok(updated) => { - if updated > 0 { - return Ok(updated); - } - Err(Error::QueryReturnedNoRows) - } - Err(e) => { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; + + let insert = conn.execute("INSERT INTO whitelist (info_hash) VALUES (?)", [info_hash.to_string()])?; + + if insert == 0 { + Err(Error::InsertFailed { + location: Location::caller(), + driver: DRIVER, + }) + } else { + Ok(insert) } } async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; - - match conn.execute("DELETE FROM whitelist WHERE info_hash = ?", [info_hash.to_string()]) { - Ok(updated) => { - if updated > 0 { - return Ok(updated); - } - Err(Error::QueryReturnedNoRows) - } - Err(e) => { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; + + let deleted = conn.execute("DELETE FROM whitelist WHERE info_hash = ?", [info_hash.to_string()])?; + + if deleted == 1 { + // should only remove a single record. 
+ Ok(deleted) + } else { + Err(Error::DeleteFailed { + location: Location::caller(), + error_code: deleted, + driver: DRIVER, + }) } } - async fn get_key_from_keys(&self, key: &str) -> Result { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + async fn get_key_from_keys(&self, key: &str) -> Result, Error> { + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let mut stmt = conn.prepare("SELECT key, valid_until FROM keys WHERE key = ?")?; + let mut rows = stmt.query([key.to_string()])?; - if let Some(row) = rows.next()? { - let key: String = row.get(0).unwrap(); - let valid_until: i64 = row.get(1).unwrap(); + let key = rows.next()?; - Ok(auth::Key { - key, - valid_until: Some(DurationSinceUnixEpoch::from_secs(valid_until.unsigned_abs())), - }) - } else { - Err(Error::QueryReturnedNoRows) - } + Ok(key.map(|f| { + let expiry: i64 = f.get(1).unwrap(); + auth::Key { + key: f.get(0).unwrap(), + valid_until: Some(DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs())), + } + })) } async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; - match conn.execute( + let insert = conn.execute( "INSERT INTO keys (key, valid_until) VALUES (?1, ?2)", [auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()], - ) { - Ok(updated) => { - if updated > 0 { - return Ok(updated); - } - Err(Error::QueryReturnedNoRows) - } - Err(e) => { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } + )?; + + if insert == 0 { + Err(Error::InsertFailed { + location: Location::caller(), + driver: DRIVER, + }) + } else { + Ok(insert) } } async fn remove_key_from_keys(&self, key: &str) -> Result { - let conn = self.pool.get().map_err(|_| Error::DatabaseError)?; - - match conn.execute("DELETE FROM keys WHERE key = ?", [key]) { - Ok(updated) => { - if updated > 0 { - return Ok(updated); - } - Err(Error::QueryReturnedNoRows) - } - Err(e) 
=> { - debug!("{:?}", e); - Err(Error::InvalidQuery) - } + let conn = self.pool.get().map_err(|e| (e, DRIVER))?; + + let deleted = conn.execute("DELETE FROM keys WHERE key = ?", [key])?; + + if deleted == 1 { + // should only remove a single record. + Ok(deleted) + } else { + Err(Error::DeleteFailed { + location: Location::caller(), + error_code: deleted, + driver: DRIVER, + }) } } } diff --git a/src/http/error.rs b/src/http/error.rs index b6c08a8ba..f07c32f6d 100644 --- a/src/http/error.rs +++ b/src/http/error.rs @@ -1,34 +1,40 @@ +use std::panic::Location; + use thiserror::Error; use warp::reject::Reject; +use crate::located_error::LocatedError; + #[derive(Error, Debug)] pub enum Error { - #[error("internal server error")] - InternalServer, - - #[error("info_hash is either missing or invalid")] - InvalidInfo, - - #[error("peer_id is either missing or invalid")] - InvalidPeerId, - - #[error("could not find remote address")] - AddressNotFound, - - #[error("torrent has no peers")] - NoPeersFound, - - #[error("torrent not on whitelist")] - TorrentNotWhitelisted, - - #[error("peer not authenticated")] - PeerNotAuthenticated, - - #[error("invalid authentication key")] - PeerKeyNotValid, - - #[error("exceeded info_hash limit")] - ExceededInfoHashLimit, + #[error("tracker server error: {source}")] + TrackerError { + source: LocatedError<'static, dyn std::error::Error + Send + Sync>, + }, + + #[error("internal server error: {message}, {location}")] + InternalServer { + location: &'static Location<'static>, + message: String, + }, + + #[error("no valid infohashes found, {location}")] + EmptyInfoHash { location: &'static Location<'static> }, + + #[error("peer_id is either missing or invalid, {location}")] + InvalidPeerId { location: &'static Location<'static> }, + + #[error("could not find remote address: {message}, {location}")] + AddressNotFound { + location: &'static Location<'static>, + message: String, + }, + + #[error("too many infohashes: {message}, {location}")] + 
TooManyInfoHashes { + location: &'static Location<'static>, + message: String, + }, } impl Reject for Error {} diff --git a/src/http/filters.rs b/src/http/filters.rs index 0fe369eba..2760c995c 100644 --- a/src/http/filters.rs +++ b/src/http/filters.rs @@ -1,5 +1,6 @@ use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; +use std::panic::Location; use std::str::FromStr; use std::sync::Arc; @@ -87,9 +88,14 @@ fn info_hashes(raw_query: &String) -> WebResult> { } if info_hashes.len() > MAX_SCRAPE_TORRENTS as usize { - Err(reject::custom(Error::ExceededInfoHashLimit)) + Err(reject::custom(Error::TooManyInfoHashes { + location: Location::caller(), + message: format! {"found: {}, but limit is: {}",info_hashes.len(), MAX_SCRAPE_TORRENTS}, + })) } else if info_hashes.is_empty() { - Err(reject::custom(Error::InvalidInfo)) + Err(reject::custom(Error::EmptyInfoHash { + location: Location::caller(), + })) } else { Ok(info_hashes) } @@ -114,7 +120,9 @@ fn peer_id(raw_query: &String) -> WebResult { // peer_id must be 20 bytes if peer_id_bytes.len() != 20 { - return Err(reject::custom(Error::InvalidPeerId)); + return Err(reject::custom(Error::InvalidPeerId { + location: Location::caller(), + })); } // clone peer_id_bytes into fixed length array @@ -128,18 +136,26 @@ fn peer_id(raw_query: &String) -> WebResult { match peer_id { Some(id) => Ok(id), - None => Err(reject::custom(Error::InvalidPeerId)), + None => Err(reject::custom(Error::InvalidPeerId { + location: Location::caller(), + })), } } /// Get `PeerAddress` from `RemoteAddress` or Forwarded fn peer_addr((on_reverse_proxy, remote_addr, x_forwarded_for): (bool, Option, Option)) -> WebResult { if !on_reverse_proxy && remote_addr.is_none() { - return Err(reject::custom(Error::AddressNotFound)); + return Err(reject::custom(Error::AddressNotFound { + location: Location::caller(), + message: "not on a reverse proxy and no remote address was found".to_string(), + })); } if on_reverse_proxy && x_forwarded_for.is_none() { - 
return Err(reject::custom(Error::AddressNotFound)); + return Err(reject::custom(Error::AddressNotFound { + location: Location::caller(), + message: "must have a x-forwarded-for when using a reverse proxy".to_string(), + })); } if on_reverse_proxy { @@ -151,7 +167,14 @@ fn peer_addr((on_reverse_proxy, remote_addr, x_forwarded_for): (bool, Option, tracker: Arc, ) -> Result<(), Error> { - tracker.authenticate_request(info_hash, auth_key).await.map_err(|e| match e { - torrent::Error::TorrentNotWhitelisted => Error::TorrentNotWhitelisted, - torrent::Error::PeerNotAuthenticated => Error::PeerNotAuthenticated, - torrent::Error::PeerKeyNotValid => Error::PeerKeyNotValid, - torrent::Error::NoPeersFound => Error::NoPeersFound, - torrent::Error::CouldNotSendResponse => Error::InternalServer, - torrent::Error::InvalidInfoHash => Error::InvalidInfo, - }) + tracker + .authenticate_request(info_hash, auth_key) + .await + .map_err(|e| Error::TrackerError { + source: (Arc::new(e) as Arc).into(), + }) } /// Handle announce request @@ -42,9 +41,7 @@ pub async fn handle_announce( auth_key: Option, tracker: Arc, ) -> WebResult { - authenticate(&announce_request.info_hash, &auth_key, tracker.clone()) - .await - .map_err(reject::custom)?; + authenticate(&announce_request.info_hash, &auth_key, tracker.clone()).await?; debug!("{:?}", announce_request); @@ -161,7 +158,10 @@ fn send_announce_response( if let Some(1) = announce_request.compact { match res.write_compact() { Ok(body) => Ok(Response::new(body)), - Err(_) => Err(reject::custom(Error::InternalServer)), + Err(e) => Err(reject::custom(Error::InternalServer { + message: e.to_string(), + location: Location::caller(), + })), } } else { Ok(Response::new(res.write().into())) @@ -174,7 +174,10 @@ fn send_scrape_response(files: HashMap) -> WebR match res.write() { Ok(body) => Ok(Response::new(body)), - Err(_) => Err(reject::custom(Error::InternalServer)), + Err(e) => Err(reject::custom(Error::InternalServer { + message: e.to_string(), + 
location: Location::caller(), + })), } } @@ -184,15 +187,21 @@ fn send_scrape_response(files: HashMap) -> WebR /// /// Will not return a error, `Infallible`, but instead convert the `ServerError` into a `Response`. pub fn send_error(r: &Rejection) -> std::result::Result { - let body = if let Some(server_error) = r.find::() { - debug!("{:?}", server_error); + let warp_reject_error = r.find::(); + + let body = if let Some(error) = warp_reject_error { + debug!("{:?}", error); response::Error { - failure_reason: server_error.to_string(), + failure_reason: error.to_string(), } .write() } else { response::Error { - failure_reason: Error::InternalServer.to_string(), + failure_reason: Error::InternalServer { + message: "Undefined".to_string(), + location: Location::caller(), + } + .to_string(), } .write() }; diff --git a/src/lib.rs b/src/lib.rs index e8cf53045..cbda2854c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ pub mod config; pub mod databases; pub mod http; pub mod jobs; +pub mod located_error; pub mod logging; pub mod protocol; pub mod setup; diff --git a/src/located_error.rs b/src/located_error.rs new file mode 100644 index 000000000..d45517e5a --- /dev/null +++ b/src/located_error.rs @@ -0,0 +1,103 @@ +// https://stackoverflow.com/questions/74336993/getting-line-numbers-with-when-using-boxdyn-stderrorerror + +use std::error::Error; +use std::panic::Location; +use std::sync::Arc; + +pub struct Located(pub E); + +#[derive(Debug)] +pub struct LocatedError<'a, E> +where + E: Error + ?Sized + Send + Sync, +{ + source: Arc, + location: Box>, +} + +impl<'a, E> std::fmt::Display for LocatedError<'a, E> +where + E: Error + ?Sized + Send + Sync, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}, {}", self.source, self.location) + } +} + +impl<'a, E> Error for LocatedError<'a, E> +where + E: Error + ?Sized + Send + Sync + 'static, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.source) + } +} + 
+impl<'a, E> Clone for LocatedError<'a, E> +where + E: Error + ?Sized + Send + Sync, +{ + fn clone(&self) -> Self { + LocatedError { + source: self.source.clone(), + location: self.location.clone(), + } + } +} + +#[allow(clippy::from_over_into)] +impl<'a, E> Into> for Located +where + E: Error + Send + Sync, + Arc: Clone, +{ + #[track_caller] + fn into(self) -> LocatedError<'a, E> { + let e = LocatedError { + source: Arc::new(self.0), + location: Box::new(*std::panic::Location::caller()), + }; + log::debug!("{e}"); + e + } +} + +#[allow(clippy::from_over_into)] +impl<'a> Into> for Arc { + #[track_caller] + fn into(self) -> LocatedError<'a, dyn std::error::Error + Send + Sync> { + LocatedError { + source: self, + location: Box::new(*std::panic::Location::caller()), + } + } +} + +#[cfg(test)] +mod tests { + use std::panic::Location; + + use super::LocatedError; + use crate::located_error::Located; + + #[derive(thiserror::Error, Debug)] + enum TestError { + #[error("Test")] + Test, + } + + #[track_caller] + fn get_caller_location() -> Location<'static> { + *Location::caller() + } + + #[test] + fn error_should_include_location() { + let e = TestError::Test; + + let b: LocatedError = Located(e).into(); + let l = get_caller_location(); + + assert_eq!(b.location.file(), l.file()); + } +} diff --git a/src/tracker/auth.rs b/src/tracker/auth.rs index 3b8af96a1..197e0dc37 100644 --- a/src/tracker/auth.rs +++ b/src/tracker/auth.rs @@ -1,12 +1,17 @@ +use std::panic::Location; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; -use derive_more::{Display, Error}; +use chrono::{DateTime, NaiveDateTime, Utc}; +use derive_more::Display; use log::debug; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use serde::{Deserialize, Serialize}; +use thiserror::Error; +use crate::located_error::LocatedError; use crate::protocol::clock::{Current, DurationSinceUnixEpoch, Time, TimeNow}; use crate::protocol::common::AUTH_KEY_LENGTH; @@ -38,14 +43,19 @@ pub 
fn verify(auth_key: &Key) -> Result<(), Error> { let current_time: DurationSinceUnixEpoch = Current::now(); match auth_key.valid_until { - Some(valid_untill) => { - if valid_untill < current_time { - Err(Error::KeyExpired) + Some(valid_until) => { + if valid_until < current_time { + Err(Error::KeyExpired { + location: Location::caller(), + }) } else { Ok(()) } } - None => Err(Error::KeyInvalid), + None => Err(Error::UnableToReadKey { + location: Location::caller(), + key: Box::new(auth_key.clone()), + }), } } @@ -57,6 +67,29 @@ pub struct Key { pub valid_until: Option, } +impl std::fmt::Display for Key { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "key: `{}`, valid until `{}`", + self.key, + match self.valid_until { + Some(duration) => format!( + "{}", + DateTime::::from_utc( + NaiveDateTime::from_timestamp( + i64::try_from(duration.as_secs()).expect("Overflow of i64 seconds, very future!"), + duration.subsec_nanos(), + ), + Utc + ) + ), + None => "Empty!?".to_string(), + } + ) + } +} + impl Key { #[must_use] pub fn from_buffer(key_buffer: [u8; AUTH_KEY_LENGTH]) -> Option { @@ -108,21 +141,27 @@ impl FromStr for KeyId { } } -#[derive(Debug, Display, PartialEq, Eq, Error)] +#[derive(Debug, Error)] #[allow(dead_code)] pub enum Error { - #[display(fmt = "Key could not be verified.")] - KeyVerificationError, - #[display(fmt = "Key is invalid.")] - KeyInvalid, - #[display(fmt = "Key has expired.")] - KeyExpired, + #[error("Key could not be verified: {source}")] + KeyVerificationError { + source: LocatedError<'static, dyn std::error::Error + Send + Sync>, + }, + #[error("Failed to read key: {key}, {location}")] + UnableToReadKey { + location: &'static Location<'static>, + key: Box, + }, + #[error("Key has expired, {location}")] + KeyExpired { location: &'static Location<'static> }, } impl From for Error { fn from(e: r2d2_sqlite::rusqlite::Error) -> Self { - eprintln!("{e}"); - Error::KeyVerificationError + 
Error::KeyVerificationError { + source: (Arc::new(e) as Arc).into(), + } } } diff --git a/src/tracker/error.rs b/src/tracker/error.rs new file mode 100644 index 000000000..51bcbf3bb --- /dev/null +++ b/src/tracker/error.rs @@ -0,0 +1,20 @@ +use std::panic::Location; + +use crate::located_error::LocatedError; + +#[derive(thiserror::Error, Debug, Clone)] +pub enum Error { + #[error("The supplied key: {key:?}, is not valid: {source}")] + PeerKeyNotValid { + key: super::auth::Key, + source: LocatedError<'static, dyn std::error::Error + Send + Sync>, + }, + #[error("The peer is not authenticated, {location}")] + PeerNotAuthenticated { location: &'static Location<'static> }, + + #[error("The torrent: {info_hash}, is not whitelisted, {location}")] + TorrentNotWhitelisted { + info_hash: crate::protocol::info_hash::InfoHash, + location: &'static Location<'static>, + }, +} diff --git a/src/tracker/mod.rs b/src/tracker/mod.rs index 4f1dab49b..acbf7d536 100644 --- a/src/tracker/mod.rs +++ b/src/tracker/mod.rs @@ -1,4 +1,5 @@ pub mod auth; +pub mod error; pub mod mode; pub mod peer; pub mod services; @@ -8,13 +9,16 @@ pub mod torrent; use std::collections::btree_map::Entry; use std::collections::BTreeMap; use std::net::SocketAddr; +use std::panic::Location; use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc::error::SendError; use tokio::sync::{RwLock, RwLockReadGuard}; +use self::error::Error; use crate::config::Configuration; +use crate::databases::driver::Driver; use crate::databases::{self, Database}; use crate::protocol::info_hash::InfoHash; @@ -40,13 +44,13 @@ pub struct TorrentsMetrics { impl Tracker { /// # Errors /// - /// Will return a `r2d2::Error` if unable to connect to database. + /// Will return a `databases::error::Error` if unable to connect to database. 
pub fn new( config: &Arc, stats_event_sender: Option>, stats_repository: statistics::Repo, - ) -> Result { - let database = databases::connect(&config.db_driver, &config.db_path)?; + ) -> Result { + let database = Driver::build(&config.db_driver, &config.db_path)?; Ok(Tracker { config: config.clone(), @@ -97,7 +101,10 @@ impl Tracker { pub async fn verify_auth_key(&self, auth_key: &auth::Key) -> Result<(), auth::Error> { // todo: use auth::KeyId for the function argument `auth_key` match self.keys.read().await.get(&auth_key.key) { - None => Err(auth::Error::KeyInvalid), + None => Err(auth::Error::UnableToReadKey { + location: Location::caller(), + key: Box::new(auth_key.clone()), + }), Some(key) => auth::verify(key), } } @@ -203,7 +210,7 @@ impl Tracker { /// Will return a `torrent::Error::PeerNotAuthenticated` if the `key` is `None`. /// /// Will return a `torrent::Error::TorrentNotWhitelisted` if the the Tracker is in listed mode and the `info_hash` is not whitelisted. - pub async fn authenticate_request(&self, info_hash: &InfoHash, key: &Option) -> Result<(), torrent::Error> { + pub async fn authenticate_request(&self, info_hash: &InfoHash, key: &Option) -> Result<(), Error> { // no authentication needed in public mode if self.is_public() { return Ok(()); @@ -213,19 +220,27 @@ impl Tracker { if self.is_private() { match key { Some(key) => { - if self.verify_auth_key(key).await.is_err() { - return Err(torrent::Error::PeerKeyNotValid); + if let Err(e) = self.verify_auth_key(key).await { + return Err(Error::PeerKeyNotValid { + key: key.clone(), + source: (Arc::new(e) as Arc).into(), + }); } } None => { - return Err(torrent::Error::PeerNotAuthenticated); + return Err(Error::PeerNotAuthenticated { + location: Location::caller(), + }); } } } // check if info_hash is whitelisted if self.is_whitelisted() && !self.is_info_hash_whitelisted(info_hash).await { - return Err(torrent::Error::TorrentNotWhitelisted); + return Err(Error::TorrentNotWhitelisted { + info_hash: 
*info_hash, + location: Location::caller(), + }); } Ok(()) diff --git a/src/tracker/torrent.rs b/src/tracker/torrent.rs index e292dff54..b5535a932 100644 --- a/src/tracker/torrent.rs +++ b/src/tracker/torrent.rs @@ -99,16 +99,6 @@ pub struct SwamStats { pub leechers: u32, } -#[derive(Debug)] -pub enum Error { - TorrentNotWhitelisted, - PeerNotAuthenticated, - PeerKeyNotValid, - NoPeersFound, - CouldNotSendResponse, - InvalidInfoHash, -} - #[cfg(test)] mod tests { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; diff --git a/src/udp/connection_cookie.rs b/src/udp/connection_cookie.rs index 3daa3e0f6..ef2a8b219 100644 --- a/src/udp/connection_cookie.rs +++ b/src/udp/connection_cookie.rs @@ -1,4 +1,5 @@ use std::net::SocketAddr; +use std::panic::Location; use aquatic_udp_protocol::ConnectionId; @@ -49,7 +50,9 @@ pub fn check(remote_address: &SocketAddr, connection_cookie: &Cookie) -> Result< return Ok(checking_time_extent); } } - Err(Error::InvalidConnectionId) + Err(Error::InvalidConnectionId { + location: Location::caller(), + }) } mod cookie_builder { diff --git a/src/udp/error.rs b/src/udp/error.rs index c5fbb3929..de66eb2bf 100644 --- a/src/udp/error.rs +++ b/src/udp/error.rs @@ -1,49 +1,27 @@ +use std::panic::Location; + use thiserror::Error; -use crate::tracker::torrent; +use crate::located_error::LocatedError; #[derive(Error, Debug)] pub enum Error { - #[error("internal server error")] - InternalServer, + #[error("tracker server error: {source}")] + TrackerError { + source: LocatedError<'static, dyn std::error::Error + Send + Sync>, + }, - #[error("info_hash is either missing or invalid")] - InvalidInfoHash, + #[error("internal server error: {message}, {location}")] + InternalServer { + location: &'static Location<'static>, + message: String, + }, #[error("connection id could not be verified")] - InvalidConnectionId, - - #[error("could not find remote address")] - AddressNotFound, - - #[error("torrent has no peers")] - NoPeersFound, - - #[error("torrent not on 
whitelist")] - TorrentNotWhitelisted, - - #[error("peer not authenticated")] - PeerNotAuthenticated, - - #[error("invalid authentication key")] - PeerKeyNotValid, - - #[error("exceeded info_hash limit")] - ExceededInfoHashLimit, - - #[error("bad request")] - BadRequest, -} + InvalidConnectionId { location: &'static Location<'static> }, -impl From for Error { - fn from(e: torrent::Error) -> Self { - match e { - torrent::Error::TorrentNotWhitelisted => Error::TorrentNotWhitelisted, - torrent::Error::PeerNotAuthenticated => Error::PeerNotAuthenticated, - torrent::Error::PeerKeyNotValid => Error::PeerKeyNotValid, - torrent::Error::NoPeersFound => Error::NoPeersFound, - torrent::Error::CouldNotSendResponse => Error::InternalServer, - torrent::Error::InvalidInfoHash => Error::InvalidInfoHash, - } - } + #[error("bad request: {source}")] + BadRequest { + source: LocatedError<'static, dyn std::error::Error + Send + Sync>, + }, } diff --git a/src/udp/handlers.rs b/src/udp/handlers.rs index 076710fb6..b36399f89 100644 --- a/src/udp/handlers.rs +++ b/src/udp/handlers.rs @@ -1,4 +1,5 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::panic::Location; use std::sync::Arc; use aquatic_udp_protocol::{ @@ -14,7 +15,10 @@ use crate::udp::error::Error; use crate::udp::request::AnnounceWrapper; pub async fn handle_packet(remote_addr: SocketAddr, payload: Vec, tracker: Arc) -> Response { - match Request::from_bytes(&payload[..payload.len()], MAX_SCRAPE_TORRENTS).map_err(|_| Error::InternalServer) { + match Request::from_bytes(&payload[..payload.len()], MAX_SCRAPE_TORRENTS).map_err(|e| Error::InternalServer { + message: format!("{e:?}"), + location: Location::caller(), + }) { Ok(request) => { let transaction_id = match &request { Request::Connect(connect_request) => connect_request.transaction_id, @@ -28,7 +32,12 @@ pub async fn handle_packet(remote_addr: SocketAddr, payload: Vec, tracker: A } } // bad request - Err(_) => handle_error(&Error::BadRequest, 
TransactionId(0)), + Err(e) => handle_error( + &Error::BadRequest { + source: (Arc::new(e) as Arc).into(), + }, + TransactionId(0), + ), } } @@ -90,7 +99,10 @@ pub async fn handle_announce( tracker .authenticate_request(&wrapped_announce_request.info_hash, &None) - .await?; + .await + .map_err(|e| Error::TrackerError { + source: (Arc::new(e) as Arc).into(), + })?; let peer = peer::Peer::from_udp_announce_request( &wrapped_announce_request.announce_request, diff --git a/tests/api/asserts.rs b/tests/api/asserts.rs index 5f9d39705..5a4abfb62 100644 --- a/tests/api/asserts.rs +++ b/tests/api/asserts.rs @@ -37,9 +37,20 @@ pub async fn assert_auth_key_utf8(response: Response) -> AuthKey { // OK response pub async fn assert_ok(response: Response) { - assert_eq!(response.status(), 200); - assert_eq!(response.headers().get("content-type").unwrap(), "application/json"); - assert_eq!(response.text().await.unwrap(), "{\"status\":\"ok\"}"); + let response_status = response.status(); + let response_headers = response.headers().get("content-type").cloned().unwrap(); + let response_text = response.text().await.unwrap(); + + let details = format!( + r#" + status: ´{response_status}´ + headers: ´{response_headers:?}´ + text: ´"{response_text}"´"# + ); + + assert_eq!(response_status, 200, "details:{details}."); + assert_eq!(response_headers, "application/json", "\ndetails:{details}."); + assert_eq!(response_text, "{\"status\":\"ok\"}", "\ndetails:{details}."); } // Error responses @@ -118,8 +129,11 @@ pub async fn assert_failed_to_reload_keys(response: Response) { async fn assert_unhandled_rejection(response: Response, reason: &str) { assert_eq!(response.status(), 500); assert_eq!(response.headers().get("content-type").unwrap(), "text/plain; charset=utf-8"); - assert_eq!( - response.text().await.unwrap(), - format!("Unhandled rejection: Err {{ reason: \"{reason}\" }}") + + let reason_text = format!("Unhandled rejection: Err {{ reason: \"{reason}"); + let response_text = 
response.text().await.unwrap(); + assert!( + response_text.contains(&reason_text), + ":\n response: `\"{response_text}\"`\n does not contain: `\"{reason_text}\"`." ); } diff --git a/tests/http/asserts.rs b/tests/http/asserts.rs index 59f4ed42a..211a7bb33 100644 --- a/tests/http/asserts.rs +++ b/tests/http/asserts.rs @@ -1,9 +1,27 @@ +use std::panic::Location; + use reqwest::Response; use super::responses::announce::{Announce, Compact, DeserializedCompact}; use super::responses::scrape; use crate::http::responses::error::Error; +pub fn assert_error_bencoded(response_text: &String, expected_failure_reason: &str, location: &'static Location<'static>) { + let error_failure_reason = serde_bencode::from_str::(response_text) + .unwrap_or_else(|_| panic!( + "response body should be a valid bencoded string for the '{expected_failure_reason}' error, got \"{response_text}\"" + ) + ) + .failure_reason; + + assert!( + error_failure_reason.contains(expected_failure_reason), + r#": + response: `"{error_failure_reason}"` + does not contain: `"{expected_failure_reason}"`, {location}"# + ); +} + pub async fn assert_empty_announce_response(response: Response) { assert_eq!(response.status(), 200); let announce_response: Announce = serde_bencode::from_str(&response.text().await.unwrap()).unwrap(); @@ -64,90 +82,48 @@ pub async fn assert_is_announce_response(response: Response) { pub async fn assert_internal_server_error_response(response: Response) { assert_eq!(response.status(), 200); - let body = response.text().await.unwrap(); - let error_response: Error = serde_bencode::from_str(&body).unwrap_or_else(|_| { - panic!( - "response body should be a valid bencoded string for the 'internal server' error, got \"{}\"", - &body - ) - }); - let expected_error_response = Error { - failure_reason: "internal server error".to_string(), - }; - assert_eq!(error_response, expected_error_response); + + assert_error_bencoded(&response.text().await.unwrap(), "internal server", Location::caller()); } 
pub async fn assert_invalid_info_hash_error_response(response: Response) { assert_eq!(response.status(), 200); - let body = response.text().await.unwrap(); - let error_response: Error = serde_bencode::from_str(&body).unwrap_or_else(|_| { - panic!( - "response body should be a valid bencoded string for the 'invalid info_hash' error, got \"{}\"", - &body - ) - }); - let expected_error_response = Error { - failure_reason: "info_hash is either missing or invalid".to_string(), - }; - assert_eq!(error_response, expected_error_response); + + assert_error_bencoded( + &response.text().await.unwrap(), + "no valid infohashes found", + Location::caller(), + ); } pub async fn assert_invalid_peer_id_error_response(response: Response) { assert_eq!(response.status(), 200); - let body = response.text().await.unwrap(); - let error_response: Error = serde_bencode::from_str(&body).unwrap_or_else(|_| { - panic!( - "response body should be a valid bencoded string for the 'invalid peer id' error, got \"{}\"", - &body - ) - }); - let expected_error_response = Error { - failure_reason: "peer_id is either missing or invalid".to_string(), - }; - assert_eq!(error_response, expected_error_response); + + assert_error_bencoded( + &response.text().await.unwrap(), + "peer_id is either missing or invalid", + Location::caller(), + ); } pub async fn assert_torrent_not_in_whitelist_error_response(response: Response) { assert_eq!(response.status(), 200); - let body = response.text().await.unwrap(); - let error_response: Error = serde_bencode::from_str(&body).unwrap_or_else(|_| { - panic!( - "response body should be a valid bencoded string for the 'torrent not on whitelist' error, got \"{}\"", - &body - ) - }); - let expected_error_response = Error { - failure_reason: "torrent not on whitelist".to_string(), - }; - assert_eq!(error_response, expected_error_response); + + assert_error_bencoded(&response.text().await.unwrap(), "is not whitelisted", Location::caller()); } pub async fn 
assert_peer_not_authenticated_error_response(response: Response) { assert_eq!(response.status(), 200); - let body = response.text().await.unwrap(); - let error_response: Error = serde_bencode::from_str(&body).unwrap_or_else(|_| { - panic!( - "response body should be a valid bencoded string for the 'peer not authenticated' error, got \"{}\"", - &body - ) - }); - let expected_error_response = Error { - failure_reason: "peer not authenticated".to_string(), - }; - assert_eq!(error_response, expected_error_response); + + assert_error_bencoded( + &response.text().await.unwrap(), + "The peer is not authenticated", + Location::caller(), + ); } pub async fn assert_invalid_authentication_key_error_response(response: Response) { assert_eq!(response.status(), 200); - let body = response.text().await.unwrap(); - let error_response: Error = serde_bencode::from_str(&body).unwrap_or_else(|_| { - panic!( - "response body should be a valid bencoded string for the 'invalid authentication key' error, got \"{}\"", - &body - ) - }); - let expected_error_response = Error { - failure_reason: "invalid authentication key".to_string(), - }; - assert_eq!(error_response, expected_error_response); + + assert_error_bencoded(&response.text().await.unwrap(), "is not valid", Location::caller()); }