From 9eca3cdf9e8a42080616cb4b797ea088398a6e75 Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 08:51:59 +0100 Subject: [PATCH 01/10] set edition for rustfmt --- rustfmt.toml | 1 + 1 file changed, 1 insertion(+) create mode 100644 rustfmt.toml diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 000000000..32a9786fa --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +edition = "2018" From 19b21a92abdf7fef305699c1fd8e6f08dbd016ef Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 08:52:30 +0100 Subject: [PATCH 02/10] run rustfmt --- src/common.rs | 5 +- src/config.rs | 56 ++++----- src/database.rs | 104 +++++++++++------ src/http_api_server.rs | 125 +++++++++++++-------- src/http_server.rs | 249 ++++++++++++++++++++++++++--------------- src/key_manager.rs | 36 +++--- src/lib.rs | 22 ++-- src/logging.rs | 24 ++-- src/main.rs | 41 ++++--- src/response.rs | 18 +-- src/tracker.rs | 169 ++++++++++++++++------------ src/udp_server.rs | 192 ++++++++++++++++++------------- src/utils.rs | 11 +- 13 files changed, 632 insertions(+), 420 deletions(-) diff --git a/src/common.rs b/src/common.rs index 82ea19ab8..439adf69c 100644 --- a/src/common.rs +++ b/src/common.rs @@ -252,8 +252,9 @@ impl PeerId { } impl Serialize for PeerId { fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, { + where + S: serde::Serializer, + { let mut tmp = [0u8; 40]; binascii::bin2hex(&self.0, &mut tmp).unwrap(); let id = std::str::from_utf8(&tmp).ok(); diff --git a/src/config.rs b/src/config.rs index 9a7e47e37..ac67fb1ad 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,13 +1,13 @@ pub use crate::tracker::TrackerMode; -use serde::{Serialize, Deserialize, Serializer}; +use config::{Config, ConfigError, File}; +use serde::{Deserialize, Serialize, Serializer}; use std; use std::collections::HashMap; use std::fs; -use toml; -use std::net::{IpAddr}; +use std::net::IpAddr; use std::path::Path; use std::str::FromStr; -use config::{ConfigError, Config, File}; +use toml; #[derive(Serialize, Deserialize)] pub struct UdpTrackerConfig { @@ -24,7 +24,7 @@ pub struct HttpTrackerConfig { #[serde(serialize_with = "none_as_empty_string")] pub ssl_cert_path: Option, #[serde(serialize_with = "none_as_empty_string")] - pub ssl_key_path: Option + pub ssl_key_path: Option, } impl HttpTrackerConfig { @@ -70,9 +70,9 @@ impl std::fmt::Display for ConfigurationError { impl std::error::Error for ConfigurationError {} pub fn none_as_empty_string(option: &Option, serializer: S) -> Result - where - T: Serialize, - S: Serializer, +where + T: Serialize, + S: Serializer, { if let Some(value) = option { value.serialize(serializer) @@ -89,26 +89,20 @@ impl Configuration { pub fn load_file(path: &str) -> Result { match std::fs::read(path) { Err(e) => Err(ConfigurationError::IOError(e)), - Ok(data) => { - match Self::load(data.as_slice()) { - Ok(cfg) => { - Ok(cfg) - }, - Err(e) => Err(ConfigurationError::ParseError(e)), - } - } + Ok(data) => match Self::load(data.as_slice()) { + Ok(cfg) => Ok(cfg), + Err(e) => Err(ConfigurationError::ParseError(e)), + }, } } pub fn get_ext_ip(&self) -> Option { match &self.external_ip { None => None, - Some(external_ip) => { - match IpAddr::from_str(external_ip) { - Ok(external_ip) => Some(external_ip), - Err(_) => None - } - } + Some(external_ip) => match IpAddr::from_str(external_ip) { + Ok(external_ip) => Some(external_ip), + Err(_) => None, + }, } } } @@ -131,12 +125,15 @@ impl Configuration { announce_interval: 120, ssl_enabled: false, ssl_cert_path: None, - ssl_key_path: None + ssl_key_path: None, }), http_api: Option::from(HttpApiConfig { enabled: true, bind_address: String::from("127.0.0.1:1212"), - access_tokens: [(String::from("admin"), String::from("MyAccessToken"))].iter().cloned().collect(), + access_tokens: [(String::from("admin"), String::from("MyAccessToken"))] + .iter() + .cloned() + .collect(), }), } } @@ -153,16 +150,21 @@ impl Configuration { eprintln!("Creating config file.."); let config = Configuration::default(); let _ = config.save_to_file(); - return Err(ConfigError::Message(format!("Please edit the config.TOML in the root folder and restart the tracker."))) + return Err(ConfigError::Message(format!( + "Please edit the config.TOML in the root folder and restart the tracker." + ))); } match config.try_into() { Ok(data) => Ok(data), - Err(e) => Err(ConfigError::Message(format!("Errors while processing config: {}.", e))), + Err(e) => Err(ConfigError::Message(format!( + "Errors while processing config: {}.", + e + ))), } } - pub fn save_to_file(&self) -> Result<(), ()>{ + pub fn save_to_file(&self) -> Result<(), ()> { let toml_string = toml::to_string(self).expect("Could not encode TOML value"); fs::write("config.toml", toml_string).expect("Could not write to file!"); Ok(()) diff --git a/src/database.rs b/src/database.rs index fbec824a0..9dda516e5 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,55 +1,58 @@ +use crate::key_manager::AuthKey; use crate::{InfoHash, AUTH_KEY_LENGTH}; use log::debug; -use r2d2_sqlite::{SqliteConnectionManager, rusqlite}; -use r2d2::{Pool}; +use r2d2::Pool; use r2d2_sqlite::rusqlite::NO_PARAMS; -use crate::key_manager::AuthKey; +use r2d2_sqlite::{rusqlite, SqliteConnectionManager}; use std::str::FromStr; pub struct SqliteDatabase { - pool: Pool + pool: Pool, } impl SqliteDatabase { pub fn new(db_path: &str) -> Result { let sqlite_connection_manager = SqliteConnectionManager::file(db_path); - let sqlite_pool = r2d2::Pool::new(sqlite_connection_manager).expect("Failed to create r2d2 SQLite connection pool."); - let sqlite_database = SqliteDatabase { - pool: sqlite_pool - }; + let sqlite_pool = r2d2::Pool::new(sqlite_connection_manager) + .expect("Failed to create r2d2 SQLite connection pool."); + let sqlite_database = SqliteDatabase { pool: sqlite_pool }; if let Err(error) = SqliteDatabase::create_database_tables(&sqlite_database.pool) { - return Err(error) + return Err(error); }; Ok(sqlite_database) } - pub fn create_database_tables(pool: &Pool) -> Result { + pub fn create_database_tables( + pool: &Pool, + ) -> Result { let create_whitelist_table = " CREATE TABLE IF NOT EXISTS whitelist ( id integer PRIMARY KEY AUTOINCREMENT, info_hash VARCHAR(20) NOT NULL UNIQUE - );".to_string(); + );" + .to_string(); - let create_keys_table = format!(" + let create_keys_table = format!( + " CREATE TABLE IF NOT EXISTS keys ( id integer PRIMARY KEY AUTOINCREMENT, key VARCHAR({}) NOT NULL UNIQUE, valid_until INT(10) NOT NULL - );", AUTH_KEY_LENGTH as i8); + );", + AUTH_KEY_LENGTH as i8 + ); let conn = pool.get().unwrap(); match conn.execute(&create_whitelist_table, NO_PARAMS) { - Ok(updated) => { - match conn.execute(&create_keys_table, NO_PARAMS) { - Ok(updated2) => Ok(updated + updated2), - Err(e) => { - debug!("{:?}", e); - Err(e) - } + Ok(updated) => match conn.execute(&create_keys_table, NO_PARAMS) { + Ok(updated2) => Ok(updated + updated2), + Err(e) => { + debug!("{:?}", e); + Err(e) } - } + }, Err(e) => { debug!("{:?}", e); Err(e) @@ -57,7 +60,10 @@ impl SqliteDatabase { } } - pub async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result { + pub async fn get_info_hash_from_whitelist( + &self, + info_hash: &str, + ) -> Result { let conn = self.pool.get().unwrap(); let mut stmt = conn.prepare("SELECT info_hash FROM whitelist WHERE info_hash = ?")?; let mut rows = stmt.query(&[info_hash])?; @@ -72,13 +78,21 @@ impl SqliteDatabase { } } - pub async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result { + pub async fn add_info_hash_to_whitelist( + &self, + info_hash: InfoHash, + ) -> Result { let conn = self.pool.get().unwrap(); - match conn.execute("INSERT INTO whitelist (info_hash) VALUES (?)", &[info_hash.to_string()]) { + match conn.execute( + "INSERT INTO whitelist (info_hash) VALUES (?)", + &[info_hash.to_string()], + ) { Ok(updated) => { - if updated > 0 { return Ok(updated) } + if updated > 0 { + return Ok(updated); + } Err(rusqlite::Error::ExecuteReturnedResults) - }, + } Err(e) => { debug!("{:?}", e); Err(e) @@ -86,13 +100,21 @@ impl SqliteDatabase { } } - pub async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result { + pub async fn remove_info_hash_from_whitelist( + &self, + info_hash: InfoHash, + ) -> Result { let conn = self.pool.get().unwrap(); - match conn.execute("DELETE FROM whitelist WHERE info_hash = ?", &[info_hash.to_string()]) { + match conn.execute( + "DELETE FROM whitelist WHERE info_hash = ?", + &[info_hash.to_string()], + ) { Ok(updated) => { - if updated > 0 { return Ok(updated) } + if updated > 0 { + return Ok(updated); + } Err(rusqlite::Error::ExecuteReturnedResults) - }, + } Err(e) => { debug!("{:?}", e); Err(e) @@ -111,7 +133,7 @@ impl SqliteDatabase { Ok(AuthKey { key, - valid_until: Some(valid_until_i64 as u64) + valid_until: Some(valid_until_i64 as u64), }) } else { Err(rusqlite::Error::QueryReturnedNoRows) @@ -120,13 +142,19 @@ impl SqliteDatabase { pub async fn add_key_to_keys(&self, auth_key: &AuthKey) -> Result { let conn = self.pool.get().unwrap(); - match conn.execute("INSERT INTO keys (key, valid_until) VALUES (?1, ?2)", - &[auth_key.key.to_string(), auth_key.valid_until.unwrap().to_string()] + match conn.execute( + "INSERT INTO keys (key, valid_until) VALUES (?1, ?2)", + &[ + auth_key.key.to_string(), + auth_key.valid_until.unwrap().to_string(), + ], ) { Ok(updated) => { - if updated > 0 { return Ok(updated) } + if updated > 0 { + return Ok(updated); + } Err(rusqlite::Error::ExecuteReturnedResults) - }, + } Err(e) => { debug!("{:?}", e); Err(e) @@ -138,9 +166,11 @@ impl SqliteDatabase { let conn = self.pool.get().unwrap(); match conn.execute("DELETE FROM keys WHERE key = ?", &[key]) { Ok(updated) => { - if updated > 0 { return Ok(updated) } + if updated > 0 { + return Ok(updated); + } Err(rusqlite::Error::ExecuteReturnedResults) - }, + } Err(e) => { debug!("{:?}", e); Err(e) diff --git a/src/http_api_server.rs b/src/http_api_server.rs index 5f7339036..6c3fa000c 100644 --- a/src/http_api_server.rs +++ b/src/http_api_server.rs @@ -1,10 +1,10 @@ -use crate::tracker::{TorrentTracker}; +use super::common::*; +use crate::tracker::TorrentTracker; use serde::{Deserialize, Serialize}; use std::cmp::min; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use warp::{filters, reply, reply::Reply, serve, Filter, Server}; -use super::common::*; #[derive(Deserialize, Debug)] struct TorrentInfoQuery { @@ -34,7 +34,9 @@ enum ActionStatus<'a> { impl warp::reject::Reject for ActionStatus<'static> {} -fn authenticate(tokens: HashMap) -> impl Filter + Clone { +fn authenticate( + tokens: HashMap, +) -> impl Filter + Clone { #[derive(Deserialize)] struct AuthToken { token: Option, @@ -51,19 +53,25 @@ fn authenticate(tokens: HashMap) -> impl Filter { if !tokens.contains(&token) { - return Err(warp::reject::custom(ActionStatus::Err { reason: "token not valid".into() })) + return Err(warp::reject::custom(ActionStatus::Err { + reason: "token not valid".into(), + })); } Ok(()) } - None => Err(warp::reject::custom(ActionStatus::Err { reason: "unauthorized".into() })) + None => Err(warp::reject::custom(ActionStatus::Err { + reason: "unauthorized".into(), + })), } } }) .untuple_one() } -pub fn build_server(tracker: Arc) -> Server + Clone + Send + Sync + 'static> { +pub fn build_server( + tracker: Arc, +) -> Server + Clone + Send + Sync + 'static> { // GET /api/torrents?offset=:u32&limit=:u32 // View torrent list let t1 = tracker.clone(); @@ -75,32 +83,34 @@ pub fn build_server(tracker: Arc) -> Server)| { - async move { - let offset = limits.offset.unwrap_or(0); - let limit = min(limits.limit.unwrap_or(1000), 4000); + .and_then( + |(limits, tracker): (TorrentInfoQuery, Arc)| { + async move { + let offset = limits.offset.unwrap_or(0); + let limit = min(limits.limit.unwrap_or(1000), 4000); - let db = tracker.get_torrents().await; - let results: Vec<_> = db - .iter() - .map(|(info_hash, torrent_entry)| { - let (seeders, completed, leechers) = torrent_entry.get_stats(); - Torrent { - info_hash, - data: torrent_entry, - seeders, - completed, - leechers, - peers: None, - } - }) - .skip(offset as usize) - .take(limit as usize) - .collect(); + let db = tracker.get_torrents().await; + let results: Vec<_> = db + .iter() + .map(|(info_hash, torrent_entry)| { + let (seeders, completed, leechers) = torrent_entry.get_stats(); + Torrent { + info_hash, + data: torrent_entry, + seeders, + completed, + leechers, + peers: None, + } + }) + .skip(offset as usize) + .take(limit as usize) + .collect(); - Result::<_, warp::reject::Rejection>::Ok(reply::json(&results)) - } - }); + Result::<_, warp::reject::Rejection>::Ok(reply::json(&results)) + } + }, + ); // GET /api/torrent/:infohash // View torrent info @@ -119,7 +129,9 @@ pub fn build_server(tracker: Arc) -> Server) -> Server)| { async move { - match tracker.remove_torrent_from_whitelist(&info_hash).await { - Ok(_) => Ok(warp::reply::json(&ActionStatus::Ok)), - Err(_) => Err(warp::reject::custom(ActionStatus::Err { reason: "failed to remove torrent from whitelist".into() })) - } + match tracker.remove_torrent_from_whitelist(&info_hash).await { + Ok(_) => Ok(warp::reply::json(&ActionStatus::Ok)), + Err(_) => Err(warp::reject::custom(ActionStatus::Err { + reason: "failed to remove torrent from whitelist".into(), + })), + } } }); @@ -177,7 +191,9 @@ pub fn build_server(tracker: Arc) -> Server Ok(warp::reply::json(&ActionStatus::Ok)), - Err(..) => Err(warp::reject::custom(ActionStatus::Err { reason: "failed to whitelist torrent".into() })) + Err(..) => Err(warp::reject::custom(ActionStatus::Err { + reason: "failed to whitelist torrent".into(), + })), } } }); @@ -197,7 +213,9 @@ pub fn build_server(tracker: Arc) -> Server Ok(warp::reply::json(&auth_key)), - Err(..) => Err(warp::reject::custom(ActionStatus::Err { reason: "failed to generate key".into() })) + Err(..) => Err(warp::reject::custom(ActionStatus::Err { + reason: "failed to generate key".into(), + })), } } }); @@ -217,22 +235,31 @@ pub fn build_server(tracker: Arc) -> Server Ok(warp::reply::json(&ActionStatus::Ok)), - Err(_) => Err(warp::reject::custom(ActionStatus::Err { reason: "failed to delete key".into() })) + Err(_) => Err(warp::reject::custom(ActionStatus::Err { + reason: "failed to delete key".into(), + })), } } }); - let api_routes = - filters::path::path("api") - .and(view_torrent_list - .or(delete_torrent) - .or(view_torrent_info) - .or(add_torrent) - .or(create_key) - .or(delete_key) - ); - - let server = api_routes.and(authenticate(tracker.config.http_api.as_ref().unwrap().access_tokens.clone())); + let api_routes = filters::path::path("api").and( + view_torrent_list + .or(delete_torrent) + .or(view_torrent_info) + .or(add_torrent) + .or(create_key) + .or(delete_key), + ); + + let server = api_routes.and(authenticate( + tracker + .config + .http_api + .as_ref() + .unwrap() + .access_tokens + .clone(), + )); serve(server) } diff --git a/src/http_server.rs b/src/http_server.rs index bf1f7f88d..0182d6db7 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -1,19 +1,19 @@ -use std::collections::{HashMap}; -use crate::tracker::{TorrentTracker}; +use super::common::*; +use crate::key_manager::AuthKey; +use crate::tracker::TorrentTracker; +use crate::utils::url_encode_bytes; +use crate::{TorrentError, TorrentPeer, TorrentStats}; +use log::debug; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::convert::Infallible; use std::error::Error; use std::io::Write; use std::net::{IpAddr, SocketAddr}; -use std::sync::Arc; use std::str::FromStr; -use log::{debug}; -use warp::{filters, reply::Reply, Filter}; +use std::sync::Arc; use warp::http::Response; -use crate::{TorrentError, TorrentPeer, TorrentStats}; -use crate::key_manager::AuthKey; -use crate::utils::url_encode_bytes; -use super::common::*; +use warp::{filters, reply::Reply, Filter}; #[derive(Deserialize, Debug)] pub struct AnnounceRequest { @@ -52,7 +52,7 @@ struct AnnounceResponse { //tracker_id: String, complete: u32, incomplete: u32, - peers: Vec + peers: Vec, } impl AnnounceResponse { @@ -104,7 +104,7 @@ impl AnnounceResponse { #[derive(Serialize)] struct ScrapeResponse { - files: HashMap + files: HashMap, } impl ScrapeResponse { @@ -122,7 +122,7 @@ struct ScrapeResponseEntry { #[derive(Serialize)] struct ErrorResponse { - failure_reason: String + failure_reason: String, } impl warp::Reply for ErrorResponse { @@ -138,77 +138,105 @@ pub struct HttpServer { impl HttpServer { pub fn new(tracker: Arc) -> HttpServer { - HttpServer { - tracker - } + HttpServer { tracker } } // &self did not work here - pub fn routes(http_server: Arc) -> impl Filter + Clone + Send + Sync + 'static { + pub fn routes( + http_server: Arc, + ) -> impl Filter + Clone + Send + Sync + 'static { // optional tracker key - let opt_key = warp::path::param::() - .map(Some) - .or_else(|_| async { + let opt_key = warp::path::param::().map(Some).or_else(|_| { + async { // Ok(None) Ok::<(Option,), std::convert::Infallible>((None,)) - }); + } + }); // GET /announce?key=:String // Announce peer let hs1 = http_server.clone(); - let announce_route = - filters::path::path("announce") - .and(filters::method::get()) - .and(warp::addr::remote()) - .and(opt_key) - .and(filters::query::raw()) - .and(filters::query::query()) - .map(move |remote_addr, key, raw_query, query| { - debug!("Request: {}", raw_query); - (remote_addr, key, raw_query, query, hs1.clone()) - }) - .and_then(move |(remote_addr, key, raw_query, mut query, http_server): (Option, Option, String, AnnounceRequest, Arc)| { + let announce_route = filters::path::path("announce") + .and(filters::method::get()) + .and(warp::addr::remote()) + .and(opt_key) + .and(filters::query::raw()) + .and(filters::query::query()) + .map(move |remote_addr, key, raw_query, query| { + debug!("Request: {}", raw_query); + (remote_addr, key, raw_query, query, hs1.clone()) + }) + .and_then( + move |(remote_addr, key, raw_query, mut query, http_server): ( + Option, + Option, + String, + AnnounceRequest, + Arc, + )| { async move { - if remote_addr.is_none() { return HttpServer::send_error("could not get remote address") } + if remote_addr.is_none() { + return HttpServer::send_error("could not get remote address"); + } // query.info_hash somehow receives a corrupt string // so we have to get the info_hash manually from the raw query let info_hashes = HttpServer::info_hashes_from_raw_query(&raw_query); - if info_hashes.len() < 1 { return HttpServer::send_error("info_hash not found") } + if info_hashes.len() < 1 { + return HttpServer::send_error("info_hash not found"); + } query.info_hash = info_hashes[0].to_string(); debug!("{:?}", query.info_hash); - if let Some(err) = http_server.authenticate_request(&query.info_hash, key).await { return err } + if let Some(err) = http_server + .authenticate_request(&query.info_hash, key) + .await + { + return err; + } - http_server.handle_announce(query, remote_addr.unwrap()).await + http_server + .handle_announce(query, remote_addr.unwrap()) + .await } - }); + }, + ); // GET /scrape?key=:String // Get torrent info let hs2 = http_server.clone(); - let scrape_route = - filters::path::path("scrape") - .and(filters::method::get()) - .and(opt_key) - .and(filters::query::raw()) - .map(move |key, raw_query| { - debug!("Request: {}", raw_query); - (key, raw_query, hs2.clone()) - }) - .and_then(move |(key, raw_query, http_server): (Option, String, Arc)| { + let scrape_route = filters::path::path("scrape") + .and(filters::method::get()) + .and(opt_key) + .and(filters::query::raw()) + .map(move |key, raw_query| { + debug!("Request: {}", raw_query); + (key, raw_query, hs2.clone()) + }) + .and_then( + move |(key, raw_query, http_server): (Option, String, Arc)| { async move { let info_hashes = HttpServer::info_hashes_from_raw_query(&raw_query); - if info_hashes.len() < 1 { return HttpServer::send_error("info_hash not found") } - if info_hashes.len() > 50 { return HttpServer::send_error("exceeded the max of 50 info_hashes") } + if info_hashes.len() < 1 { + return HttpServer::send_error("info_hash not found"); + } + if info_hashes.len() > 50 { + return HttpServer::send_error("exceeded the max of 50 info_hashes"); + } debug!("{:?}", info_hashes); // todo: verify all info_hashes before scrape - if let Some(err) = http_server.authenticate_request(&info_hashes[0].to_string(), key).await { return err } + if let Some(err) = http_server + .authenticate_request(&info_hashes[0].to_string(), key) + .await + { + return err; + } http_server.handle_scrape(info_hashes).await } - }); + }, + ); // all routes warp::any().and(announce_route.or(scrape_route)) @@ -221,7 +249,8 @@ impl HttpServer { for v in split_raw_query { if v.contains("info_hash") { let raw_info_hash = v.split("=").collect::>()[1]; - let info_hash_bytes = percent_encoding::percent_decode_str(raw_info_hash).collect::>(); + let info_hash_bytes = + percent_encoding::percent_decode_str(raw_info_hash).collect::>(); let info_hash = InfoHash::from_str(&hex::encode(info_hash_bytes)); if let Ok(ih) = info_hash { info_hashes.push(ih); @@ -232,18 +261,26 @@ impl HttpServer { info_hashes } - fn send_announce_response(query: &AnnounceRequest, torrent_stats: TorrentStats, peers: Vec, interval: u32) -> Result { - let http_peers: Vec = peers.iter().map(|peer| Peer { - peer_id: String::from_utf8_lossy(&peer.peer_id.0).to_string(), - ip: peer.peer_addr.ip(), - port: peer.peer_addr.port() - }).collect(); + fn send_announce_response( + query: &AnnounceRequest, + torrent_stats: TorrentStats, + peers: Vec, + interval: u32, + ) -> Result { + let http_peers: Vec = peers + .iter() + .map(|peer| Peer { + peer_id: String::from_utf8_lossy(&peer.peer_id.0).to_string(), + ip: peer.peer_addr.ip(), + port: peer.peer_addr.port(), + }) + .collect(); let res = AnnounceResponse { interval, complete: torrent_stats.seeders, incomplete: torrent_stats.leechers, - peers: http_peers + peers: http_peers, }; // check for compact response request @@ -270,20 +307,31 @@ impl HttpServer { fn send_error(msg: &str) -> Result { Ok(ErrorResponse { - failure_reason: msg.to_string() - }.into_response()) + failure_reason: msg.to_string(), + } + .into_response()) } - async fn authenticate_request(&self, info_hash_str: &str, key: Option) -> Option> { - let info_hash= InfoHash::from_str(info_hash_str); - if info_hash.is_err() { return Some(HttpServer::send_error("invalid info_hash")) } + async fn authenticate_request( + &self, + info_hash_str: &str, + key: Option, + ) -> Option> { + let info_hash = InfoHash::from_str(info_hash_str); + if info_hash.is_err() { + return Some(HttpServer::send_error("invalid info_hash")); + } let auth_key = match key { None => None, - Some(v) => AuthKey::from_string(&v) + Some(v) => AuthKey::from_string(&v), }; - if let Err(e) = self.tracker.authenticate_request(&info_hash.unwrap(), &auth_key).await { + if let Err(e) = self + .tracker + .authenticate_request(&info_hash.unwrap(), &auth_key) + .await + { return match e { TorrentError::TorrentNotWhitelisted => { debug!("Info_hash not whitelisted."); @@ -297,45 +345,72 @@ impl HttpServer { debug!("Peer not authenticated."); Some(HttpServer::send_error("peer not authenticated")) } - } + }; } None } - async fn handle_announce(&self, query: AnnounceRequest, remote_addr: SocketAddr) -> Result { + async fn handle_announce( + &self, + query: AnnounceRequest, + remote_addr: SocketAddr, + ) -> Result { let info_hash = match InfoHash::from_str(&query.info_hash) { Ok(v) => v, - Err(_) => { - return HttpServer::send_error("info_hash is invalid") - } + Err(_) => return HttpServer::send_error("info_hash is invalid"), }; - let peer = TorrentPeer::from_http_announce_request(&query, remote_addr, self.tracker.config.get_ext_ip()); - - match self.tracker.update_torrent_with_peer_and_get_stats(&info_hash, &peer).await { + let peer = TorrentPeer::from_http_announce_request( + &query, + remote_addr, + self.tracker.config.get_ext_ip(), + ); + + match self + .tracker + .update_torrent_with_peer_and_get_stats(&info_hash, &peer) + .await + { Err(e) => { debug!("{:?}", e); HttpServer::send_error("server error") } Ok(torrent_stats) => { // get all peers excluding the client_addr - let peers = self.tracker.get_torrent_peers(&info_hash, &peer.peer_addr).await; + let peers = self + .tracker + .get_torrent_peers(&info_hash, &peer.peer_addr) + .await; if peers.is_none() { debug!("No peers found after announce."); - return HttpServer::send_error("peer is invalid") + return HttpServer::send_error("peer is invalid"); } // todo: add http announce interval config option // success response - let announce_interval = self.tracker.config.http_tracker.as_ref().unwrap().announce_interval; - HttpServer::send_announce_response(&query, torrent_stats, peers.unwrap(), announce_interval) + let announce_interval = self + .tracker + .config + .http_tracker + .as_ref() + .unwrap() + .announce_interval; + HttpServer::send_announce_response( + &query, + torrent_stats, + peers.unwrap(), + announce_interval, + ) } } } - async fn handle_scrape(&self, info_hashes: Vec) -> Result { + async fn handle_scrape( + &self, + info_hashes: Vec, + ) -> Result { let mut res = ScrapeResponse { - files: HashMap::new() + files: HashMap::new(), }; let db = self.tracker.get_torrents().await; @@ -347,16 +422,14 @@ impl HttpServer { ScrapeResponseEntry { complete: seeders, downloaded: completed, - incomplete: leechers - } - } - None => { - ScrapeResponseEntry { - complete: 0, - downloaded: 0, - incomplete: 0 + incomplete: leechers, } } + None => ScrapeResponseEntry { + complete: 0, + downloaded: 0, + incomplete: 0, + }, }; if let Ok(encoded_info_hash) = url_encode_bytes(&info_hash.0) { diff --git a/src/key_manager.rs b/src/key_manager.rs index b1f16f1dc..d571700b9 100644 --- a/src/key_manager.rs +++ b/src/key_manager.rs @@ -1,10 +1,10 @@ use super::common::AUTH_KEY_LENGTH; use crate::utils::current_time; -use rand::{thread_rng, Rng}; +use derive_more::{Display, Error}; +use log::debug; use rand::distributions::Alphanumeric; +use rand::{thread_rng, Rng}; use serde::Serialize; -use log::debug; -use derive_more::{Display, Error}; pub fn generate_auth_key(seconds_valid: u64) -> AuthKey { let key: String = thread_rng() @@ -13,7 +13,10 @@ pub fn generate_auth_key(seconds_valid: u64) -> AuthKey { .map(char::from) .collect(); - debug!("Generated key: {}, valid for: {} seconds", key, seconds_valid); + debug!( + "Generated key: {}, valid for: {} seconds", + key, seconds_valid + ); AuthKey { key, @@ -23,8 +26,12 @@ pub fn generate_auth_key(seconds_valid: u64) -> AuthKey { pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> { let current_time = current_time(); - if auth_key.valid_until.is_none() { return Err(Error::KeyInvalid) } - if auth_key.valid_until.unwrap() < current_time { return Err(Error::KeyExpired) } + if auth_key.valid_until.is_none() { + return Err(Error::KeyInvalid); + } + if auth_key.valid_until.unwrap() < current_time { + return Err(Error::KeyExpired); + } Ok(()) } @@ -67,7 +74,7 @@ pub enum Error { #[display(fmt = "Key is invalid.")] KeyInvalid, #[display(fmt = "Key has expired.")] - KeyExpired + KeyExpired, } impl From for Error { @@ -83,17 +90,10 @@ mod tests { #[test] fn auth_key_from_buffer() { - let auth_key = key_manager::AuthKey::from_buffer( - [ - 89, 90, 83, 108, - 52, 108, 77, 90, - 117, 112, 82, 117, - 79, 112, 83, 82, - 67, 51, 107, 114, - 73, 75, 82, 53, - 66, 80, 66, 49, - 52, 110, 114, 74] - ); + let auth_key = key_manager::AuthKey::from_buffer([ + 89, 90, 83, 108, 52, 108, 77, 90, 117, 112, 82, 117, 79, 112, 83, 82, 67, 51, 107, 114, + 73, 75, 82, 53, 66, 80, 66, 49, 52, 110, 114, 74, + ]); assert!(auth_key.is_some()); assert_eq!(auth_key.unwrap().key, "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ"); diff --git a/src/lib.rs b/src/lib.rs index 375c0f903..bb451e125 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,19 +1,19 @@ -pub mod config; -pub mod udp_server; -pub mod http_server; -pub mod tracker; -pub mod http_api_server; pub mod common; -pub mod response; -pub mod utils; +pub mod config; pub mod database; +pub mod http_api_server; +pub mod http_server; pub mod key_manager; pub mod logging; +pub mod response; +pub mod tracker; +pub mod udp_server; +pub mod utils; +pub use self::common::*; pub use self::config::*; -pub use self::udp_server::*; -pub use self::http_server::*; -pub use self::tracker::*; pub use self::http_api_server::*; -pub use self::common::*; +pub use self::http_server::*; pub use self::response::*; +pub use self::tracker::*; +pub use self::udp_server::*; diff --git a/src/logging.rs b/src/logging.rs index 580e35094..9489e6fa8 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -1,22 +1,20 @@ -use log::info; use crate::Configuration; +use log::info; pub fn setup_logging(cfg: &Configuration) { let log_level = match &cfg.log_level { None => log::LevelFilter::Info, - Some(level) => { - match level.as_str() { - "off" => log::LevelFilter::Off, - "trace" => log::LevelFilter::Trace, - "debug" => log::LevelFilter::Debug, - "info" => log::LevelFilter::Info, - "warn" => log::LevelFilter::Warn, - "error" => log::LevelFilter::Error, - _ => { - panic!("Unknown log level encountered: '{}'", level.as_str()); - } + Some(level) => match level.as_str() { + "off" => log::LevelFilter::Off, + "trace" => log::LevelFilter::Trace, + "debug" => log::LevelFilter::Debug, + "info" => log::LevelFilter::Info, + "warn" => log::LevelFilter::Warn, + "error" => log::LevelFilter::Error, + _ => { + panic!("Unknown log level encountered: '{}'", level.as_str()); } - } + }, }; if let Err(_err) = fern::Dispatch::new() diff --git a/src/main.rs b/src/main.rs index 74a905c0d..93b0ba595 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,17 @@ -use log::{info}; -use torrust_tracker::{http_api_server, Configuration, TorrentTracker, UdpServer, HttpTrackerConfig, UdpTrackerConfig, HttpApiConfig, logging}; +use log::info; use std::sync::Arc; use tokio::task::JoinHandle; use torrust_tracker::http_server::HttpServer; +use torrust_tracker::{ + http_api_server, logging, Configuration, HttpApiConfig, HttpTrackerConfig, TorrentTracker, + UdpServer, UdpTrackerConfig, +}; #[tokio::main] async fn main() { let config = match Configuration::load_from_file() { Ok(config) => Arc::new(config), - Err(error) => { - panic!("{}", error) - } + Err(error) => panic!("{}", error), }; logging::setup_logging(&config); @@ -47,7 +48,10 @@ async fn main() { } } -fn start_torrent_cleanup_job(config: Arc, tracker: Arc) -> Option> { +fn start_torrent_cleanup_job( + config: Arc, + tracker: Arc, +) -> Option> { let weak_tracker = std::sync::Arc::downgrade(&tracker); let interval = config.cleanup_interval.unwrap_or(600); @@ -55,7 +59,7 @@ fn start_torrent_cleanup_job(config: Arc, tracker: Arc, tracker: Arc) -> JoinHandle<()> { @@ -77,10 +81,16 @@ fn start_api_server(config: &HttpApiConfig, tracker: Arc) -> Joi }) } -fn start_http_tracker_server(config: &HttpTrackerConfig, tracker: Arc) -> JoinHandle<()> { +fn start_http_tracker_server( + config: &HttpTrackerConfig, + tracker: Arc, +) -> JoinHandle<()> { info!("Starting HTTP server on: {}", config.bind_address); let http_tracker = Arc::new(HttpServer::new(tracker)); - let bind_addr = config.bind_address.parse::().unwrap(); + let bind_addr = config + .bind_address + .parse::() + .unwrap(); let ssl_enabled = config.ssl_enabled; let ssl_cert_path = config.ssl_cert_path.clone(); let ssl_key_path = config.ssl_key_path.clone(); @@ -93,15 +103,20 @@ fn start_http_tracker_server(config: &HttpTrackerConfig, tracker: Arc) -> JoinHandle<()> { +async fn start_udp_tracker_server( + config: &UdpTrackerConfig, + tracker: Arc, +) -> JoinHandle<()> { info!("Starting UDP server on: {}", config.bind_address); let udp_server = UdpServer::new(tracker).await.unwrap_or_else(|e| { panic!("Could not start UDP server: {}", e); diff --git a/src/response.rs b/src/response.rs index 9734e3769..ef9203196 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,10 +1,10 @@ -use std; -use std::io::{Write}; -use std::net::{SocketAddr}; -use byteorder::{NetworkEndian, WriteBytesExt}; use super::common::*; -use std::io; use crate::TorrentPeer; +use byteorder::{NetworkEndian, WriteBytesExt}; +use std; +use std::io; +use std::io::Write; +use std::net::SocketAddr; #[derive(PartialEq, Eq, Clone, Debug)] pub enum UdpResponse { @@ -83,7 +83,7 @@ impl UdpResponse { bytes.write_i32::(0)?; // 0 = connect bytes.write_i32::(r.transaction_id.0)?; bytes.write_i64::(r.connection_id.0)?; - }, + } UdpResponse::Announce(r) => { bytes.write_i32::(1)?; // 1 = announce bytes.write_i32::(r.transaction_id.0)?; @@ -103,7 +103,7 @@ impl UdpResponse { } } } - }, + } UdpResponse::Scrape(r) => { bytes.write_i32::(2)?; // 2 = scrape bytes.write_i32::(r.transaction_id.0)?; @@ -113,12 +113,12 @@ impl UdpResponse { bytes.write_i32::(torrent_stat.completed)?; bytes.write_i32::(torrent_stat.leechers)?; } - }, + } UdpResponse::Error(r) => { bytes.write_i32::(3)?; bytes.write_i32::(r.transaction_id.0)?; bytes.write_all(r.message.as_bytes())?; - }, + } } Ok(()) diff --git a/src/tracker.rs b/src/tracker.rs index 3e6bcca3e..0d6eb22ef 100644 --- a/src/tracker.rs +++ b/src/tracker.rs @@ -1,17 +1,17 @@ -use serde::{Deserialize, Serialize}; -use std::borrow::Cow; -use std::collections::BTreeMap; -use tokio::sync::RwLock; -use crate::common::{NumberOfBytes, InfoHash}; use super::common::*; -use std::net::{SocketAddr, IpAddr}; -use crate::{Configuration, http_server, key_manager, udp_server}; -use std::collections::btree_map::Entry; +use crate::common::{InfoHash, NumberOfBytes}; use crate::database::SqliteDatabase; -use std::sync::Arc; +use crate::key_manager::AuthKey; +use crate::{http_server, key_manager, udp_server, Configuration}; use log::debug; -use crate::key_manager::{AuthKey}; use r2d2_sqlite::rusqlite; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::collections::btree_map::Entry; +use std::collections::BTreeMap; +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; +use tokio::sync::RwLock; const TWO_HOURS: std::time::Duration = std::time::Duration::from_secs(3600 * 2); const FIVE_MINUTES: std::time::Duration = std::time::Duration::from_secs(300); @@ -50,7 +50,11 @@ pub struct TorrentPeer { } impl TorrentPeer { - pub fn from_udp_announce_request(announce_request: &udp_server::AnnounceRequest, remote_addr: SocketAddr, peer_addr: Option) -> Self { + pub fn from_udp_announce_request( + announce_request: &udp_server::AnnounceRequest, + remote_addr: SocketAddr, + peer_addr: Option, + ) -> Self { // Potentially substitute localhost IP with external IP let peer_addr = match peer_addr { None => SocketAddr::new(IpAddr::from(remote_addr.ip()), announce_request.port.0), @@ -70,11 +74,15 @@ impl TorrentPeer { uploaded: announce_request.bytes_uploaded, downloaded: announce_request.bytes_downloaded, left: announce_request.bytes_left, - event: announce_request.event + event: announce_request.event, } } - pub fn from_http_announce_request(announce_request: &http_server::AnnounceRequest, remote_addr: SocketAddr, peer_addr: Option) -> Self { + pub fn from_http_announce_request( + announce_request: &http_server::AnnounceRequest, + remote_addr: SocketAddr, + peer_addr: Option, + ) -> Self { // Potentially substitute localhost IP with external IP let peer_addr = match peer_addr { None => SocketAddr::new(IpAddr::from(remote_addr.ip()), announce_request.port), @@ -92,7 +100,7 @@ impl TorrentPeer { "started" => AnnounceEvent::Started, "stopped" => AnnounceEvent::Stopped, "completed" => AnnounceEvent::Completed, - _ => AnnounceEvent::None + _ => AnnounceEvent::None, } } else { AnnounceEvent::None @@ -105,11 +113,13 @@ impl TorrentPeer { uploaded: announce_request.uploaded, downloaded: announce_request.downloaded, left: announce_request.left, - event + event, } } - fn is_seeder(&self) -> bool { self.left.0 <= 0 && self.event != AnnounceEvent::Stopped } + fn is_seeder(&self) -> bool { + self.left.0 <= 0 && self.event != AnnounceEvent::Stopped + } fn is_completed(&self) -> bool { self.event == AnnounceEvent::Completed @@ -159,7 +169,6 @@ impl TorrentEntry { .filter(|e| e.1.peer_addr.is_ipv4()) .take(MAX_SCRAPE_TORRENTS as usize) { - // skip ip address of client if peer.peer_addr == *remote_addr { //continue; @@ -174,7 +183,11 @@ impl TorrentEntry { self.peers.iter() } - pub fn update_torrent_stats_with_peer(&mut self, peer: &TorrentPeer, peer_old: Option) { + pub fn update_torrent_stats_with_peer( + &mut self, + peer: &TorrentPeer, + peer_old: Option, + ) { match peer_old { None => { if peer.is_seeder() { @@ -253,9 +266,8 @@ pub struct TorrentTracker { impl TorrentTracker { pub fn new(config: Arc) -> TorrentTracker { - let database = SqliteDatabase::new(&config.db_path).unwrap_or_else(|error| { - panic!("Could not create SQLite database. Reason: {}", error) - }); + let database = SqliteDatabase::new(&config.db_path) + .unwrap_or_else(|error| panic!("Could not create SQLite database. Reason: {}", error)); TorrentTracker { config, @@ -268,7 +280,9 @@ impl TorrentTracker { let auth_key = key_manager::generate_auth_key(seconds_valid); // add key to database - if let Err(error) = self.database.add_key_to_keys(&auth_key).await { return Err(error) } + if let Err(error) = self.database.add_key_to_keys(&auth_key).await { + return Err(error); + } Ok(auth_key) } @@ -282,95 +296,100 @@ impl TorrentTracker { key_manager::verify_auth_key(&db_key) } - pub async fn authenticate_request(&self, info_hash: &InfoHash, key: &Option) -> Result<(), TorrentError> { + pub async fn authenticate_request( + &self, + info_hash: &InfoHash, + key: &Option, + ) -> Result<(), TorrentError> { match self.config.mode { TrackerMode::PublicMode => Ok(()), TrackerMode::ListedMode => { if !self.is_info_hash_whitelisted(info_hash).await { - return Err(TorrentError::TorrentNotWhitelisted) + return Err(TorrentError::TorrentNotWhitelisted); } Ok(()) } - TrackerMode::PrivateMode => { - match key { - Some(key) => { - if self.verify_auth_key(key).await.is_err() { - return Err(TorrentError::PeerKeyNotValid) - } - - Ok(()) - } - None => { - return Err(TorrentError::PeerNotAuthenticated) + TrackerMode::PrivateMode => match key { + Some(key) => { + if self.verify_auth_key(key).await.is_err() { + return Err(TorrentError::PeerKeyNotValid); } - } - } - TrackerMode::PrivateListedMode => { - match key { - Some(key) => { - if self.verify_auth_key(key).await.is_err() { - return Err(TorrentError::PeerKeyNotValid) - } - - if !self.is_info_hash_whitelisted(info_hash).await { - return Err(TorrentError::TorrentNotWhitelisted) - } - Ok(()) + Ok(()) + } + None => return Err(TorrentError::PeerNotAuthenticated), + }, + TrackerMode::PrivateListedMode => match key { + Some(key) => { + if self.verify_auth_key(key).await.is_err() { + return Err(TorrentError::PeerKeyNotValid); } - None => { - return Err(TorrentError::PeerNotAuthenticated) + + if !self.is_info_hash_whitelisted(info_hash).await { + return Err(TorrentError::TorrentNotWhitelisted); } + + Ok(()) } - } + None => return Err(TorrentError::PeerNotAuthenticated), + }, } } // Adding torrents is not relevant to public trackers. - pub async fn add_torrent_to_whitelist(&self, info_hash: &InfoHash) -> Result { - self.database.add_info_hash_to_whitelist(info_hash.clone()).await + pub async fn add_torrent_to_whitelist( + &self, + info_hash: &InfoHash, + ) -> Result { + self.database + .add_info_hash_to_whitelist(info_hash.clone()) + .await } // Removing torrents is not relevant to public trackers. - pub async fn remove_torrent_from_whitelist(&self, info_hash: &InfoHash) -> Result { - self.database.remove_info_hash_from_whitelist(info_hash.clone()).await + pub async fn remove_torrent_from_whitelist( + &self, + info_hash: &InfoHash, + ) -> Result { + self.database + .remove_info_hash_from_whitelist(info_hash.clone()) + .await } pub async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> bool { - match self.database.get_info_hash_from_whitelist(&info_hash.to_string()).await { + match self + .database + .get_info_hash_from_whitelist(&info_hash.to_string()) + .await + { Ok(_) => true, - Err(_) => false + Err(_) => false, } } - pub async fn get_torrent_peers( &self, info_hash: &InfoHash, - peer_addr: &std::net::SocketAddr + peer_addr: &std::net::SocketAddr, ) -> Option> { let read_lock = self.torrents.read().await; match read_lock.get(info_hash) { - None => { - None - } - Some(entry) => { - Some(entry.get_peers(peer_addr)) - } + None => None, + Some(entry) => Some(entry.get_peers(peer_addr)), } } - pub async fn update_torrent_with_peer_and_get_stats(&self, info_hash: &InfoHash, peer: &TorrentPeer) -> Result { + pub async fn update_torrent_with_peer_and_get_stats( + &self, + info_hash: &InfoHash, + peer: &TorrentPeer, + ) -> Result { let mut torrents = self.torrents.write().await; let torrent_entry = match torrents.entry(info_hash.clone()) { - Entry::Vacant(vacant) => { - Ok(vacant.insert(TorrentEntry::new())) - } - Entry::Occupied(entry) => { - Ok(entry.into_mut()) - } + Entry::Vacant(vacant) => Ok(vacant.insert(TorrentEntry::new())), + Entry::Occupied(entry) => Ok(entry.into_mut()), }; match torrent_entry { @@ -385,11 +404,13 @@ impl TorrentTracker { completed, }) } - Err(e) => Err(e) + Err(e) => Err(e), } } - pub async fn get_torrents(&self) -> tokio::sync::RwLockReadGuard<'_, BTreeMap> { + pub async fn get_torrents( + &self, + ) -> tokio::sync::RwLockReadGuard<'_, BTreeMap> { self.torrents.read().await } diff --git a/src/udp_server.rs b/src/udp_server.rs index 079ff67a0..8f0131829 100644 --- a/src/udp_server.rs +++ b/src/udp_server.rs @@ -1,19 +1,19 @@ -use log::{debug}; +use byteorder::{NetworkEndian, ReadBytesExt}; +use log::debug; use std; use std::convert::TryInto; use std::io; +use std::io::{Cursor, Read}; use std::net::{Ipv4Addr, SocketAddr}; use std::sync::Arc; -use std::io::{Cursor, Read}; use tokio::net::UdpSocket; -use byteorder::{NetworkEndian, ReadBytesExt}; use super::common::*; +use crate::key_manager::AuthKey; use crate::response::*; -use crate::utils::get_connection_id; use crate::tracker::TorrentTracker; -use crate::{TorrentPeer, TrackerMode, TorrentError}; -use crate::key_manager::AuthKey; +use crate::utils::get_connection_id; +use crate::{TorrentError, TorrentPeer, TrackerMode}; #[derive(PartialEq, Eq, Clone, Debug)] pub enum Request { @@ -114,8 +114,6 @@ impl Request { .read_i32::() .map_err(RequestParseError::io)?; - - match action { // Connect 0 => { @@ -123,7 +121,7 @@ impl Request { Ok((ConnectRequest { transaction_id: TransactionId(transaction_id), }) - .into()) + .into()) } else { Err(RequestParseError::text( transaction_id, @@ -208,7 +206,7 @@ impl Request { port: Port(port), auth_key, }) - .into()) + .into()) } // Scrape @@ -227,7 +225,7 @@ impl Request { transaction_id: TransactionId(transaction_id), info_hashes, }) - .into()) + .into()) } _ => Err(RequestParseError::text(transaction_id, "Invalid action")), @@ -250,48 +248,51 @@ impl UdpServer { }) } - pub async fn authenticate_announce_request(&self, announce_request: &AnnounceRequest) -> Result<(), TorrentError> { + pub async fn authenticate_announce_request( + &self, + announce_request: &AnnounceRequest, + ) -> Result<(), TorrentError> { match self.tracker.config.mode { TrackerMode::PublicMode => Ok(()), TrackerMode::ListedMode => { - if !self.tracker.is_info_hash_whitelisted(&announce_request.info_hash).await { - return Err(TorrentError::TorrentNotWhitelisted) + if !self + .tracker + .is_info_hash_whitelisted(&announce_request.info_hash) + .await + { + return Err(TorrentError::TorrentNotWhitelisted); } Ok(()) } - TrackerMode::PrivateMode => { - match &announce_request.auth_key { - Some(auth_key) => { - if self.tracker.verify_auth_key(auth_key).await.is_err() { - return Err(TorrentError::PeerKeyNotValid) - } - - Ok(()) + TrackerMode::PrivateMode => match &announce_request.auth_key { + Some(auth_key) => { + if self.tracker.verify_auth_key(auth_key).await.is_err() { + return Err(TorrentError::PeerKeyNotValid); } - None => { - return Err(TorrentError::PeerNotAuthenticated) - } - } - } - TrackerMode::PrivateListedMode => { - match &announce_request.auth_key { - Some(auth_key) => { - if self.tracker.verify_auth_key(auth_key).await.is_err() { - return Err(TorrentError::PeerKeyNotValid) - } - - if !self.tracker.is_info_hash_whitelisted(&announce_request.info_hash).await { - return Err(TorrentError::TorrentNotWhitelisted) - } - Ok(()) + Ok(()) + } + None => return Err(TorrentError::PeerNotAuthenticated), + }, + TrackerMode::PrivateListedMode => match &announce_request.auth_key { + Some(auth_key) => { + if self.tracker.verify_auth_key(auth_key).await.is_err() { + return Err(TorrentError::PeerKeyNotValid); } - None => { - return Err(TorrentError::PeerNotAuthenticated) + + if !self + .tracker + .is_info_hash_whitelisted(&announce_request.info_hash) + .await + { + return Err(TorrentError::TorrentNotWhitelisted); } + + Ok(()) } - } + None => return Err(TorrentError::PeerNotAuthenticated), + }, } } @@ -321,27 +322,44 @@ impl UdpServer { match request { Request::Connect(r) => self.handle_connect(remote_addr, r).await, Request::Announce(r) => { - match self.tracker.authenticate_request(&r.info_hash, &r.auth_key).await { + match self + .tracker + .authenticate_request(&r.info_hash, &r.auth_key) + .await + { Ok(()) => self.handle_announce(remote_addr, r).await, - Err(e) => { - match e { - TorrentError::TorrentNotWhitelisted => { - debug!("Info_hash not whitelisted."); - self.send_error(remote_addr, &r.transaction_id, "torrent not whitelisted").await; - } - TorrentError::PeerKeyNotValid => { - debug!("Peer key not valid."); - self.send_error(remote_addr, &r.transaction_id, "peer key not valid").await; - } - TorrentError::PeerNotAuthenticated => { - debug!("Peer not authenticated."); - self.send_error(remote_addr, &r.transaction_id, "peer not authenticated").await; - } + Err(e) => match e { + TorrentError::TorrentNotWhitelisted => { + debug!("Info_hash not whitelisted."); + self.send_error( + remote_addr, + &r.transaction_id, + "torrent not whitelisted", + ) + .await; } - } + TorrentError::PeerKeyNotValid => { + debug!("Peer key not valid."); + self.send_error( + remote_addr, + &r.transaction_id, + "peer key not valid", + ) + .await; + } + TorrentError::PeerNotAuthenticated => { + debug!("Peer not authenticated."); + self.send_error( + remote_addr, + &r.transaction_id, + "peer not authenticated", + ) + .await; + } + }, } - }, - Request::Scrape(r) => self.handle_scrape(remote_addr, r).await + } + Request::Scrape(r) => self.handle_scrape(remote_addr, r).await, } } Err(err) => { @@ -363,12 +381,24 @@ impl UdpServer { } async fn handle_announce(&self, remote_addr: SocketAddr, request: AnnounceRequest) { - let peer = TorrentPeer::from_udp_announce_request(&request, remote_addr, self.tracker.config.get_ext_ip()); - - match self.tracker.update_torrent_with_peer_and_get_stats(&request.info_hash, &peer).await { + let peer = TorrentPeer::from_udp_announce_request( + &request, + remote_addr, + self.tracker.config.get_ext_ip(), + ); + + match self + .tracker + .update_torrent_with_peer_and_get_stats(&request.info_hash, &peer) + .await + { Ok(torrent_stats) => { // get all peers excluding the client_addr - let peers = match self.tracker.get_torrent_peers(&request.info_hash, &peer.peer_addr).await { + let peers = match self + .tracker + .get_torrent_peers(&request.info_hash, &peer.peer_addr) + .await + { Some(v) => v, None => { debug!("announce: No peers found."); @@ -389,7 +419,8 @@ impl UdpServer { } Err(e) => { debug!("{:?}", e); - self.send_error(remote_addr, &request.transaction_id, "error adding torrent").await; + self.send_error(remote_addr, &request.transaction_id, "error adding torrent") + .await; } } } @@ -414,13 +445,11 @@ impl UdpServer { leechers: leechers as i32, } } - None => { - UdpScrapeResponseEntry { - seeders: 0, - completed: 0, - leechers: 0, - } - } + None => UdpScrapeResponseEntry { + seeders: 0, + completed: 0, + leechers: 0, + }, }; scrape_response.torrent_stats.push(scrape_entry); @@ -431,7 +460,11 @@ impl UdpServer { let _ = self.send_response(remote_addr, response).await; } - async fn send_response(&self, remote_addr: SocketAddr, response: UdpResponse) -> Result { + async fn send_response( + &self, + remote_addr: SocketAddr, + response: UdpResponse, + ) -> Result { debug!("sending response to: {:?}", &remote_addr); let buffer = vec![0u8; MAX_PACKET_SIZE]; @@ -458,17 +491,26 @@ impl UdpServer { } } - async fn send_packet(&self, remote_addr: &SocketAddr, payload: &[u8]) -> Result { + async fn send_packet( + &self, + remote_addr: &SocketAddr, + payload: &[u8], + ) -> Result { match self.socket.send_to(payload, remote_addr).await { Err(err) => { debug!("failed to send a packet: {}", err); Err(err) - }, + } Ok(sz) => Ok(sz), } } - async fn send_error(&self, remote_addr: SocketAddr, transaction_id: &TransactionId, error_msg: &str) { + async fn send_error( + &self, + remote_addr: SocketAddr, + transaction_id: &TransactionId, + error_msg: &str, + ) { let error_response = UdpErrorResponse { action: Actions::Error, transaction_id: transaction_id.clone(), diff --git a/src/utils.rs b/src/utils.rs index 11c61e4fb..405796489 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,19 +1,22 @@ -use std::net::SocketAddr; use crate::common::*; -use std::time::SystemTime; use std::error::Error; use std::fmt::Write; +use std::net::SocketAddr; +use std::time::SystemTime; pub fn get_connection_id(remote_address: &SocketAddr) -> ConnectionId { match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) { - Ok(duration) => ConnectionId(((duration.as_secs() / 3600) | ((remote_address.port() as u64) << 36)) as i64), + Ok(duration) => ConnectionId( + ((duration.as_secs() / 3600) | ((remote_address.port() as u64) << 36)) as i64, + ), Err(_) => ConnectionId(0x7FFFFFFFFFFFFFFF), } } pub fn current_time() -> u64 { SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH).unwrap() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() .as_secs() } From a41961e58162f4e5ed29642b535da4820239260b Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 09:01:32 +0100 Subject: [PATCH 03/10] refactoring database.rs --- src/database.rs | 115 +++++++++++++++++------------------------------- 1 file changed, 41 insertions(+), 74 deletions(-) diff --git a/src/database.rs b/src/database.rs index 9dda516e5..7fce68621 100644 --- a/src/database.rs +++ b/src/database.rs @@ -4,6 +4,7 @@ use log::debug; use r2d2::Pool; use r2d2_sqlite::rusqlite::NO_PARAMS; use r2d2_sqlite::{rusqlite, SqliteConnectionManager}; +use std::convert::TryInto; use std::str::FromStr; pub struct SqliteDatabase { @@ -15,13 +16,9 @@ impl SqliteDatabase { let sqlite_connection_manager = SqliteConnectionManager::file(db_path); let sqlite_pool = r2d2::Pool::new(sqlite_connection_manager) .expect("Failed to create r2d2 SQLite connection pool."); - let sqlite_database = SqliteDatabase { pool: sqlite_pool }; + SqliteDatabase::create_database_tables(&sqlite_pool)?; - if let Err(error) = SqliteDatabase::create_database_tables(&sqlite_database.pool) { - return Err(error); - }; - - Ok(sqlite_database) + Ok(SqliteDatabase { pool: sqlite_pool }) } pub fn create_database_tables( @@ -31,8 +28,7 @@ impl SqliteDatabase { CREATE TABLE IF NOT EXISTS whitelist ( id integer PRIMARY KEY AUTOINCREMENT, info_hash VARCHAR(20) NOT NULL UNIQUE - );" - .to_string(); + );"; let create_keys_table = format!( " @@ -41,23 +37,16 @@ impl SqliteDatabase { key VARCHAR({}) NOT NULL UNIQUE, valid_until INT(10) NOT NULL );", - AUTH_KEY_LENGTH as i8 + AUTH_KEY_LENGTH ); let conn = pool.get().unwrap(); - match conn.execute(&create_whitelist_table, NO_PARAMS) { - Ok(updated) => match conn.execute(&create_keys_table, NO_PARAMS) { - Ok(updated2) => Ok(updated + updated2), - Err(e) => { - debug!("{:?}", e); - Err(e) - } - }, - Err(e) => { - debug!("{:?}", e); - Err(e) - } - } + conn.execute(create_whitelist_table, NO_PARAMS) + .and_then(|updated| { + conn.execute(&create_keys_table, NO_PARAMS) + .map(|updated2| updated + updated2) + }) + .map_err(trace_debug) } pub async fn get_info_hash_from_whitelist( @@ -83,21 +72,12 @@ impl SqliteDatabase { info_hash: InfoHash, ) -> Result { let conn = self.pool.get().unwrap(); - match conn.execute( + conn.execute( "INSERT INTO whitelist (info_hash) VALUES (?)", &[info_hash.to_string()], - ) { - Ok(updated) => { - if updated > 0 { - return Ok(updated); - } - Err(rusqlite::Error::ExecuteReturnedResults) - } - Err(e) => { - debug!("{:?}", e); - Err(e) - } - } + ) + .map_err(trace_debug) + .and_then(validate_updated) } pub async fn remove_info_hash_from_whitelist( @@ -105,21 +85,12 @@ impl SqliteDatabase { info_hash: InfoHash, ) -> Result { let conn = self.pool.get().unwrap(); - match conn.execute( + conn.execute( "DELETE FROM whitelist WHERE info_hash = ?", &[info_hash.to_string()], - ) { - Ok(updated) => { - if updated > 0 { - return Ok(updated); - } - Err(rusqlite::Error::ExecuteReturnedResults) - } - Err(e) => { - debug!("{:?}", e); - Err(e) - } - } + ) + .map_err(trace_debug) + .and_then(validate_updated) } pub async fn get_key_from_keys(&self, key: &str) -> Result { @@ -133,7 +104,7 @@ impl SqliteDatabase { Ok(AuthKey { key, - valid_until: Some(valid_until_i64 as u64), + valid_until: Some(valid_until_i64.try_into().unwrap()), }) } else { Err(rusqlite::Error::QueryReturnedNoRows) @@ -142,39 +113,35 @@ impl SqliteDatabase { pub async fn add_key_to_keys(&self, auth_key: &AuthKey) -> Result { let conn = self.pool.get().unwrap(); - match conn.execute( + + conn.execute( "INSERT INTO keys (key, valid_until) VALUES (?1, ?2)", &[ auth_key.key.to_string(), auth_key.valid_until.unwrap().to_string(), ], - ) { - Ok(updated) => { - if updated > 0 { - return Ok(updated); - } - Err(rusqlite::Error::ExecuteReturnedResults) - } - Err(e) => { - debug!("{:?}", e); - Err(e) - } - } + ) + .map_err(trace_debug) + .and_then(validate_updated) } pub async fn remove_key_from_keys(&self, key: String) -> Result { let conn = self.pool.get().unwrap(); - match conn.execute("DELETE FROM keys WHERE key = ?", &[key]) { - Ok(updated) => { - if updated > 0 { - return Ok(updated); - } - Err(rusqlite::Error::ExecuteReturnedResults) - } - Err(e) => { - debug!("{:?}", e); - Err(e) - } - } + conn.execute("DELETE FROM keys WHERE key = ?", &[key]) + .map_err(trace_debug) + .and_then(validate_updated) + } +} + +fn trace_debug(value: T) -> T { + debug!("{:?}", value); + value +} + +fn validate_updated(updated: usize) -> Result { + if updated > 0 { + Ok(updated) + } else { + Err(rusqlite::Error::ExecuteReturnedResults) } } From c111891c931199ec8b0e7ef19a442333fe669e10 Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 09:11:07 +0100 Subject: [PATCH 04/10] updated readme --- README.md | 66 +------------------------------------------------------ 1 file changed, 1 insertion(+), 65 deletions(-) diff --git a/README.md b/README.md index 3bb8c3e56..1b5fd4b2c 100644 --- a/README.md +++ b/README.md @@ -1,67 +1,3 @@ # Torrust Tracker -![Test](https://github.com/torrust/torrust-tracker/actions/workflows/test_build_release.yml/badge.svg) -## Project Description -Torrust Tracker is a lightweight but incredibly powerful and feature-rich BitTorrent tracker made using Rust. - - -### Features -* [X] UDP server -* [X] HTTP (optional SSL) server -* [X] Private & Whitelisted mode -* [X] Built-in API -* [X] Torrent whitelisting -* [X] Peer authentication using time-bound keys - -### Implemented BEPs -* [BEP 15](http://www.bittorrent.org/beps/bep_0015.html): UDP Tracker Protocol for BitTorrent -* [BEP 23](http://bittorrent.org/beps/bep_0023.html): Tracker Returns Compact Peer Lists -* [BEP 27](http://bittorrent.org/beps/bep_0027.html): Private Torrents -* [BEP 41](http://bittorrent.org/beps/bep_0041.html): UDP Tracker Protocol Extensions -* [BEP 48](http://bittorrent.org/beps/bep_0048.html): Tracker Protocol Extension: Scrape - -## Getting Started -You can get the latest binaries from [releases](https://github.com/torrust/torrust-tracker/releases) or follow the install from scratch instructions below. - -### Install From Scratch -1. Clone the repo. -```bash -git clone https://github.com/torrust/torrust-tracker.git -cd torrust-tracker -``` - -2. Build the source code. -```bash -cargo build --release -``` - -3. Copy binaries: `torrust-tracker/target/torrust-tracker` to a new folder. - -### Usage -1. Navigate to the folder you put the torrust-tracker binaries in. - - -2. Run the torrust-tracker once to create the `config.toml` file: -```bash -./torrust-tracker -``` - - -3. Edit the newly created config.toml file in the same folder as your torrust-tracker binaries according to your liking. See [configuration documentation](https://torrust.github.io/torrust-documentation/torrust-tracker/config/). - - -4. Run the torrust-tracker again: -```bash -./torrust-tracker -``` - -### Tracker URL -Your tracker will be `udp://tracker-ip:port/announce` or `https://tracker-ip:port/announce` depending on your tracker mode. -In private mode, tracker keys are added after the tracker URL like: `https://tracker-ip:port/announce/tracker-key`. - -### Built-in API -Read the API documentation [here](https://torrust.github.io/torrust-documentation/torrust-tracker/api/). - -### Credits -This project was a joint effort by [Nautilus Cyberneering GmbH](https://nautilus-cyberneering.de/) and [Dutch Bits](https://dutchbits.nl). -Also thanks to [Naim A.](https://github.com/naim94a/udpt) for some parts of the code. +This fork of https://github.com/torrust/torrust-tracker exists so that I can mess around with it :3 From 700f67372cbb179a153b31f7cc8c7cd5762f043c Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 09:42:45 +0100 Subject: [PATCH 05/10] validate logging configuration at parse time --- src/config.rs | 33 +++++++++++++++++++++++++++++++-- src/logging.rs | 18 +++--------------- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/src/config.rs b/src/config.rs index ac67fb1ad..2b612eee5 100644 --- a/src/config.rs +++ b/src/config.rs @@ -40,9 +40,38 @@ pub struct HttpApiConfig { pub access_tokens: HashMap, } +#[derive(Deserialize, Serialize, Copy, Clone, Debug)] +pub enum LogLevel { + #[serde(rename = "off")] + Off, + #[serde(rename = "trace")] + Trace, + #[serde(rename = "debug")] + Debug, + #[serde(rename = "info")] + Info, + #[serde(rename = "warn")] + Warn, + #[serde(rename = "error")] + Error, +} + +impl Into for LogLevel { + fn into(self) -> log::LevelFilter { + match self { + LogLevel::Off => log::LevelFilter::Off, + LogLevel::Trace => log::LevelFilter::Trace, + LogLevel::Debug => log::LevelFilter::Debug, + LogLevel::Info => log::LevelFilter::Info, + LogLevel::Warn => log::LevelFilter::Warn, + LogLevel::Error => log::LevelFilter::Error, + } + } +} + #[derive(Serialize, Deserialize)] pub struct Configuration { - pub log_level: Option, + pub log_level: Option, pub mode: TrackerMode, pub db_path: String, pub cleanup_interval: Option, @@ -110,7 +139,7 @@ impl Configuration { impl Configuration { pub fn default() -> Configuration { Configuration { - log_level: Option::from(String::from("info")), + log_level: Some(LogLevel::Info), mode: TrackerMode::PublicMode, db_path: String::from("data.db"), cleanup_interval: Some(600), diff --git a/src/logging.rs b/src/logging.rs index 9489e6fa8..201dd0e49 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -1,21 +1,9 @@ +use crate::config::LogLevel; use crate::Configuration; use log::info; pub fn setup_logging(cfg: &Configuration) { - let log_level = match &cfg.log_level { - None => log::LevelFilter::Info, - Some(level) => match level.as_str() { - "off" => log::LevelFilter::Off, - "trace" => log::LevelFilter::Trace, - "debug" => log::LevelFilter::Debug, - "info" => log::LevelFilter::Info, - "warn" => log::LevelFilter::Warn, - "error" => log::LevelFilter::Error, - _ => { - panic!("Unknown log level encountered: '{}'", level.as_str()); - } - }, - }; + let log_level = cfg.log_level.unwrap_or(LogLevel::Info); if let Err(_err) = fern::Dispatch::new() .format(|out, message, record| { @@ -27,7 +15,7 @@ pub fn setup_logging(cfg: &Configuration) { message )) }) - .level(log_level) + .level(log_level.into()) .chain(std::io::stdout()) .apply() { From 62799af1e97abff64c4bdad331caec5c02c1e689 Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 09:48:23 +0100 Subject: [PATCH 06/10] validate external_ip parameter at parse time --- src/config.rs | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/config.rs b/src/config.rs index 2b612eee5..6bc7e0676 100644 --- a/src/config.rs +++ b/src/config.rs @@ -75,7 +75,7 @@ pub struct Configuration { pub mode: TrackerMode, pub db_path: String, pub cleanup_interval: Option, - pub external_ip: Option, + pub external_ip: Option, pub udp_tracker: UdpTrackerConfig, pub http_tracker: Option, pub http_api: Option, @@ -126,13 +126,7 @@ impl Configuration { } pub fn get_ext_ip(&self) -> Option { - match &self.external_ip { - None => None, - Some(external_ip) => match IpAddr::from_str(external_ip) { - Ok(external_ip) => Some(external_ip), - Err(_) => None, - }, - } + self.external_ip.clone() } } @@ -143,7 +137,7 @@ impl Configuration { mode: TrackerMode::PublicMode, db_path: String::from("data.db"), cleanup_interval: Some(600), - external_ip: Some(String::from("0.0.0.0")), + external_ip: IpAddr::from_str("0.0.0.0").ok(), udp_tracker: UdpTrackerConfig { bind_address: String::from("0.0.0.0:6969"), announce_interval: 120, From 75dd27fa8207fdce41c431a9c27bfb8c11b7ae05 Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 13:11:30 +0100 Subject: [PATCH 07/10] refactor: one way to eliminate 'boolean blindness' --- src/key_manager.rs | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/key_manager.rs b/src/key_manager.rs index d571700b9..57cb23e94 100644 --- a/src/key_manager.rs +++ b/src/key_manager.rs @@ -26,14 +26,12 @@ pub fn generate_auth_key(seconds_valid: u64) -> AuthKey { pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> { let current_time = current_time(); - if auth_key.valid_until.is_none() { - return Err(Error::KeyInvalid); - } - if auth_key.valid_until.unwrap() < current_time { - return Err(Error::KeyExpired); - } - Ok(()) + match auth_key.valid_until { + Some(valid_until) if valid_until < current_time => Ok(()), + Some(_) => Err(Error::KeyExpired), + None => Err(Error::KeyInvalid), + } } #[derive(Serialize, Debug, Eq, PartialEq, Clone)] From a29d1e0115991288d46f5f1464c8f6e8e7da428b Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 13:22:29 +0100 Subject: [PATCH 08/10] refactor: move functions into an impl, use Self syntax --- src/key_manager.rs | 78 +++++++++++++++++++++++----------------------- src/tracker.rs | 8 +++-- 2 files changed, 44 insertions(+), 42 deletions(-) diff --git a/src/key_manager.rs b/src/key_manager.rs index 57cb23e94..70efe01e8 100644 --- a/src/key_manager.rs +++ b/src/key_manager.rs @@ -6,34 +6,6 @@ use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use serde::Serialize; -pub fn generate_auth_key(seconds_valid: u64) -> AuthKey { - let key: String = thread_rng() - .sample_iter(&Alphanumeric) - .take(AUTH_KEY_LENGTH) - .map(char::from) - .collect(); - - debug!( - "Generated key: {}, valid for: {} seconds", - key, seconds_valid - ); - - AuthKey { - key, - valid_until: Some(current_time() + seconds_valid), - } -} - -pub fn verify_auth_key(auth_key: &AuthKey) -> Result<(), Error> { - let current_time = current_time(); - - match auth_key.valid_until { - Some(valid_until) if valid_until < current_time => Ok(()), - Some(_) => Err(Error::KeyExpired), - None => Err(Error::KeyInvalid), - } -} - #[derive(Serialize, Debug, Eq, PartialEq, Clone)] pub struct AuthKey { pub key: String, @@ -41,9 +13,37 @@ pub struct AuthKey { } impl AuthKey { - pub fn from_buffer(key_buffer: [u8; AUTH_KEY_LENGTH]) -> Option { + pub fn generate(seconds_valid: u64) -> Self { + let key: String = thread_rng() + .sample_iter(&Alphanumeric) + .take(AUTH_KEY_LENGTH) + .map(char::from) + .collect(); + + debug!( + "Generated key: {}, valid for: {} seconds", + key, seconds_valid + ); + + Self { + key, + valid_until: Some(current_time() + seconds_valid), + } + } + + pub fn verify(&self) -> Result<(), Error> { + let current_time = current_time(); + + match self.valid_until { + Some(valid_until) if valid_until < current_time => Ok(()), + Some(_) => Err(Error::KeyExpired), + None => Err(Error::KeyInvalid), + } + } + + pub fn from_buffer(key_buffer: [u8; AUTH_KEY_LENGTH]) -> Option { if let Ok(key) = String::from_utf8(Vec::from(key_buffer)) { - Some(AuthKey { + Some(Self { key, valid_until: None, }) @@ -52,11 +52,11 @@ impl AuthKey { } } - pub fn from_string(key: &str) -> Option { + pub fn from_string(key: &str) -> Option { if key.len() != AUTH_KEY_LENGTH { None } else { - Some(AuthKey { + Some(Self { key: key.to_string(), valid_until: None, }) @@ -84,11 +84,11 @@ impl From for Error { #[cfg(test)] mod tests { - use crate::key_manager; + use crate::key_manager::AuthKey; #[test] fn auth_key_from_buffer() { - let auth_key = key_manager::AuthKey::from_buffer([ + let auth_key = AuthKey::from_buffer([ 89, 90, 83, 108, 52, 108, 77, 90, 117, 112, 82, 117, 79, 112, 83, 82, 67, 51, 107, 114, 73, 75, 82, 53, 66, 80, 66, 49, 52, 110, 114, 74, ]); @@ -100,7 +100,7 @@ mod tests { #[test] fn auth_key_from_string() { let key_string = "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ"; - let auth_key = key_manager::AuthKey::from_string(key_string); + let auth_key = AuthKey::from_string(key_string); assert!(auth_key.is_some()); assert_eq!(auth_key.unwrap().key, key_string); @@ -108,16 +108,16 @@ mod tests { #[test] fn generate_valid_auth_key() { - let auth_key = key_manager::generate_auth_key(9999); + let auth_key = AuthKey::generate(9999); - assert!(key_manager::verify_auth_key(&auth_key).is_ok()); + assert!(&auth_key.verify().is_ok()); } #[test] fn generate_expired_auth_key() { - let mut auth_key = key_manager::generate_auth_key(0); + let mut auth_key = AuthKey::generate(0); auth_key.valid_until = Some(0); - assert!(key_manager::verify_auth_key(&auth_key).is_err()); + assert!(&auth_key.verify().is_err()); } } diff --git a/src/tracker.rs b/src/tracker.rs index 0d6eb22ef..f2000b81c 100644 --- a/src/tracker.rs +++ b/src/tracker.rs @@ -277,7 +277,7 @@ impl TorrentTracker { } pub async fn generate_auth_key(&self, seconds_valid: u64) -> Result { - let auth_key = key_manager::generate_auth_key(seconds_valid); + let auth_key = key_manager::AuthKey::generate(seconds_valid); // add key to database if let Err(error) = self.database.add_key_to_keys(&auth_key).await { @@ -292,8 +292,10 @@ impl TorrentTracker { } pub async fn verify_auth_key(&self, auth_key: &AuthKey) -> Result<(), key_manager::Error> { - let db_key = self.database.get_key_from_keys(&auth_key.key).await?; - key_manager::verify_auth_key(&db_key) + self.database + .get_key_from_keys(&auth_key.key) + .await? + .verify() } pub async fn authenticate_request( From ad1b15716a121d1374a74d339bb67413ef54db81 Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 13:28:55 +0100 Subject: [PATCH 09/10] refactor: further refactoring of key_manager.rs --- src/key_manager.rs | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/key_manager.rs b/src/key_manager.rs index 70efe01e8..a51dae7d4 100644 --- a/src/key_manager.rs +++ b/src/key_manager.rs @@ -42,25 +42,19 @@ impl AuthKey { } pub fn from_buffer(key_buffer: [u8; AUTH_KEY_LENGTH]) -> Option { - if let Ok(key) = String::from_utf8(Vec::from(key_buffer)) { - Some(Self { + String::from_utf8(Vec::from(key_buffer)) + .ok() + .map(|key| Self { key, valid_until: None, }) - } else { - None - } } pub fn from_string(key: &str) -> Option { - if key.len() != AUTH_KEY_LENGTH { - None - } else { - Some(Self { - key: key.to_string(), - valid_until: None, - }) - } + (key.len() == AUTH_KEY_LENGTH).then(|| Self { + key: key.to_string(), + valid_until: None, + }) } } @@ -77,7 +71,7 @@ pub enum Error { impl From for Error { fn from(e: r2d2_sqlite::rusqlite::Error) -> Self { - eprintln!("{}", e); + debug!("{}", e); Error::KeyVerificationError } } From 07f012437493a558f8d5214e7c686fae3e85cf10 Mon Sep 17 00:00:00 2001 From: Long Huynh Huu Date: Thu, 27 Jan 2022 13:29:44 +0100 Subject: [PATCH 10/10] fix: wrong inequality --- src/key_manager.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/key_manager.rs b/src/key_manager.rs index a51dae7d4..1f75735aa 100644 --- a/src/key_manager.rs +++ b/src/key_manager.rs @@ -35,7 +35,7 @@ impl AuthKey { let current_time = current_time(); match self.valid_until { - Some(valid_until) if valid_until < current_time => Ok(()), + Some(valid_until) if valid_until > current_time => Ok(()), Some(_) => Err(Error::KeyExpired), None => Err(Error::KeyInvalid), }