From 7cdd63ee42b4868734038280c6a4f83e07c511ad Mon Sep 17 00:00:00 2001 From: Jose Celano Date: Mon, 27 Feb 2023 18:52:25 +0000 Subject: [PATCH 1/2] refactor: [#171] use KeyId in auth:Key The struct `KeyId` was extracted to wrap the primitive type but it was not being used in the `auth::Key` struct. --- src/apis/resources/auth_key.rs | 14 +++---- src/databases/mod.rs | 3 ++ src/databases/mysql.rs | 8 ++-- src/databases/sqlite.rs | 11 +++--- src/http/warp_implementation/filters.rs | 16 ++++++-- src/http/warp_implementation/handlers.rs | 13 ++++--- src/http/warp_implementation/routes.rs | 6 +-- src/tracker/auth.rs | 49 +++++++++++++----------- src/tracker/error.rs | 4 +- src/tracker/mod.rs | 24 +++++++----- tests/tracker_api.rs | 12 +++--- 11 files changed, 91 insertions(+), 69 deletions(-) diff --git a/src/apis/resources/auth_key.rs b/src/apis/resources/auth_key.rs index d5c08f496..207a0c482 100644 --- a/src/apis/resources/auth_key.rs +++ b/src/apis/resources/auth_key.rs @@ -3,18 +3,18 @@ use std::convert::From; use serde::{Deserialize, Serialize}; use crate::protocol::clock::DurationSinceUnixEpoch; -use crate::tracker::auth; +use crate::tracker::auth::{self, KeyId}; #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct AuthKey { - pub key: String, + pub key: String, // todo: rename to `id` pub valid_until: Option, } impl From for auth::Key { fn from(auth_key_resource: AuthKey) -> Self { auth::Key { - key: auth_key_resource.key, + id: auth_key_resource.key.parse::().unwrap(), valid_until: auth_key_resource .valid_until .map(|valid_until| DurationSinceUnixEpoch::new(valid_until, 0)), @@ -25,7 +25,7 @@ impl From for auth::Key { impl From for AuthKey { fn from(auth_key: auth::Key) -> Self { AuthKey { - key: auth_key.key, + key: auth_key.id.to_string(), valid_until: auth_key.valid_until.map(|valid_until| valid_until.as_secs()), } } @@ -37,7 +37,7 @@ mod tests { use super::AuthKey; use crate::protocol::clock::{Current, TimeNow}; - use crate::tracker::auth; + use crate::tracker::auth::{self, KeyId}; #[test] fn it_should_be_convertible_into_an_auth_key() { @@ -51,7 +51,7 @@ mod tests { assert_eq!( auth::Key::from(auth_key_resource), auth::Key { - key: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".to_string(), // cspell:disable-line + id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::().unwrap(), // cspell:disable-line valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap()) } ); @@ -62,7 +62,7 @@ mod tests { let duration_in_secs = 60; let auth_key = auth::Key { - key: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".to_string(), // cspell:disable-line + id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::().unwrap(), // cspell:disable-line valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap()), }; diff --git a/src/databases/mod.rs b/src/databases/mod.rs index 809decc2c..70cc9eb75 100644 --- a/src/databases/mod.rs +++ b/src/databases/mod.rs @@ -63,16 +63,19 @@ pub trait Database: Sync + Send { async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>; + // todo: replace type `&str` with `&InfoHash` async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result, Error>; async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result; async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result; + // todo: replace type `&str` with `&KeyId` async fn get_key_from_keys(&self, key: &str) -> Result, Error>; async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result; + // todo: replace type `&str` with `&KeyId` async fn remove_key_from_keys(&self, key: &str) -> Result; async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result { diff --git a/src/databases/mysql.rs b/src/databases/mysql.rs index ac54ebb82..532ba1dcb 100644 --- a/src/databases/mysql.rs +++ b/src/databases/mysql.rs @@ -12,7 +12,7 @@ use super::driver::Driver; use crate::databases::{Database, Error}; use crate::protocol::common::AUTH_KEY_LENGTH; use crate::protocol::info_hash::InfoHash; -use crate::tracker::auth; +use crate::tracker::auth::{self, KeyId}; const DRIVER: Driver = Driver::MySQL; @@ -117,7 +117,7 @@ impl Database for Mysql { let keys = conn.query_map( "SELECT `key`, valid_until FROM `keys`", |(key, valid_until): (String, i64)| auth::Key { - key, + id: key.parse::().unwrap(), valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())), }, )?; @@ -192,7 +192,7 @@ impl Database for Mysql { let key = query?; Ok(key.map(|(key, expiry)| auth::Key { - key, + id: key.parse::().unwrap(), valid_until: Some(Duration::from_secs(expiry.unsigned_abs())), })) } @@ -200,7 +200,7 @@ impl Database for Mysql { async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result { let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; - let key = auth_key.key.to_string(); + let key = auth_key.id.to_string(); let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string(); conn.exec_drop( diff --git a/src/databases/sqlite.rs b/src/databases/sqlite.rs index 3425b15c8..d6915c850 100644 --- a/src/databases/sqlite.rs +++ b/src/databases/sqlite.rs @@ -9,7 +9,7 @@ use super::driver::Driver; use crate::databases::{Database, Error}; use crate::protocol::clock::DurationSinceUnixEpoch; use crate::protocol::info_hash::InfoHash; -use crate::tracker::auth; +use crate::tracker::auth::{self, KeyId}; const DRIVER: Driver = Driver::Sqlite3; @@ -108,11 +108,11 @@ impl Database for Sqlite { let mut stmt = conn.prepare("SELECT key, valid_until FROM keys")?; let keys_iter = stmt.query_map([], |row| { - let key = row.get(0)?; + let key: String = row.get(0)?; let valid_until: i64 = row.get(1)?; Ok(auth::Key { - key, + id: key.parse::().unwrap(), valid_until: Some(DurationSinceUnixEpoch::from_secs(valid_until.unsigned_abs())), }) })?; @@ -211,8 +211,9 @@ impl Database for Sqlite { Ok(key.map(|f| { let expiry: i64 = f.get(1).unwrap(); + let id: String = f.get(0).unwrap(); auth::Key { - key: f.get(0).unwrap(), + id: id.parse::().unwrap(), valid_until: Some(DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs())), } })) @@ -223,7 +224,7 @@ impl Database for Sqlite { let insert = conn.execute( "INSERT INTO keys (key, valid_until) VALUES (?1, ?2)", - [auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()], + [auth_key.id.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()], )?; if insert == 0 { diff --git a/src/http/warp_implementation/filters.rs b/src/http/warp_implementation/filters.rs index fc8ef20bc..eb7abcd4d 100644 --- a/src/http/warp_implementation/filters.rs +++ b/src/http/warp_implementation/filters.rs @@ -1,6 +1,7 @@ use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; use std::panic::Location; +use std::str::FromStr; use std::sync::Arc; use warp::{reject, Filter, Rejection}; @@ -11,7 +12,8 @@ use super::{request, WebResult}; use crate::http::percent_encoding::{percent_decode_info_hash, percent_decode_peer_id}; use crate::protocol::common::MAX_SCRAPE_TORRENTS; use crate::protocol::info_hash::InfoHash; -use crate::tracker::{self, auth, peer}; +use crate::tracker::auth::KeyId; +use crate::tracker::{self, peer}; /// Pass Arc along #[must_use] @@ -35,10 +37,16 @@ pub fn with_peer_id() -> impl Filter + /// Pass Arc along #[must_use] -pub fn with_auth_key() -> impl Filter,), Error = Infallible> + Clone { +pub fn with_auth_key_id() -> impl Filter,), Error = Infallible> + Clone { warp::path::param::() - .map(|key: String| auth::Key::from_string(&key)) - .or_else(|_| async { Ok::<(Option,), Infallible>((None,)) }) + .map(|key: String| { + let key_id = KeyId::from_str(&key); + match key_id { + Ok(id) => Some(id), + Err(_) => None, + } + }) + .or_else(|_| async { Ok::<(Option,), Infallible>((None,)) }) } /// Check for `PeerAddress` diff --git a/src/http/warp_implementation/handlers.rs b/src/http/warp_implementation/handlers.rs index 400cc5762..6019bf016 100644 --- a/src/http/warp_implementation/handlers.rs +++ b/src/http/warp_implementation/handlers.rs @@ -12,6 +12,7 @@ use super::error::Error; use super::{request, response, WebResult}; use crate::http::warp_implementation::peer_builder; use crate::protocol::info_hash::InfoHash; +use crate::tracker::auth::KeyId; use crate::tracker::{self, auth, peer, statistics, torrent}; /// Authenticate `InfoHash` using optional `auth::Key` @@ -21,11 +22,11 @@ use crate::tracker::{self, auth, peer, statistics, torrent}; /// Will return `ServerError` that wraps the `tracker::error::Error` if unable to `authenticate_request`. pub async fn authenticate( info_hash: &InfoHash, - auth_key: &Option, + auth_key_id: &Option, tracker: Arc, ) -> Result<(), Error> { tracker - .authenticate_request(info_hash, auth_key) + .authenticate_request(info_hash, auth_key_id) .await .map_err(|e| Error::TrackerError { source: (Arc::new(e) as Arc).into(), @@ -37,7 +38,7 @@ pub async fn authenticate( /// Will return `warp::Rejection` that wraps the `ServerError` if unable to `send_announce_response`. pub async fn handle_announce( announce_request: request::Announce, - auth_key: Option, + auth_key_id: Option, tracker: Arc, ) -> WebResult { debug!("http announce request: {:#?}", announce_request); @@ -45,7 +46,7 @@ pub async fn handle_announce( let info_hash = announce_request.info_hash; let remote_client_ip = announce_request.peer_addr; - authenticate(&info_hash, &auth_key, tracker.clone()).await?; + authenticate(&info_hash, &auth_key_id, tracker.clone()).await?; let mut peer = peer_builder::from_request(&announce_request, &remote_client_ip); @@ -77,7 +78,7 @@ pub async fn handle_announce( /// Will return `warp::Rejection` that wraps the `ServerError` if unable to `send_scrape_response`. pub async fn handle_scrape( scrape_request: request::Scrape, - auth_key: Option, + auth_key_id: Option, tracker: Arc, ) -> WebResult { let mut files: HashMap = HashMap::new(); @@ -86,7 +87,7 @@ pub async fn handle_scrape( for info_hash in &scrape_request.info_hashes { let scrape_entry = match db.get(info_hash) { Some(torrent_info) => { - if authenticate(info_hash, &auth_key, tracker.clone()).await.is_ok() { + if authenticate(info_hash, &auth_key_id, tracker.clone()).await.is_ok() { let (seeders, completed, leechers) = torrent_info.get_stats(); response::ScrapeEntry { complete: seeders, diff --git a/src/http/warp_implementation/routes.rs b/src/http/warp_implementation/routes.rs index c46c502e4..2ee60e8c9 100644 --- a/src/http/warp_implementation/routes.rs +++ b/src/http/warp_implementation/routes.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use warp::{Filter, Rejection}; -use super::filters::{with_announce_request, with_auth_key, with_scrape_request, with_tracker}; +use super::filters::{with_announce_request, with_auth_key_id, with_scrape_request, with_tracker}; use super::handlers::{handle_announce, handle_scrape, send_error}; use crate::tracker; @@ -20,7 +20,7 @@ fn announce(tracker: Arc) -> impl Filter) -> impl Filter Key { - let key: String = thread_rng() + let random_id: String = thread_rng() .sample_iter(&Alphanumeric) .take(AUTH_KEY_LENGTH) .map(char::from) .collect(); - debug!("Generated key: {}, valid for: {:?} seconds", key, lifetime); + debug!("Generated key: {}, valid for: {:?} seconds", random_id, lifetime); Key { - key, + id: random_id.parse::().unwrap(), valid_until: Some(Current::add(&lifetime).unwrap()), } } @@ -54,16 +54,14 @@ pub fn verify(auth_key: &Key) -> Result<(), Error> { } None => Err(Error::UnableToReadKey { location: Location::caller(), - key: Box::new(auth_key.clone()), + key_id: Box::new(auth_key.id.clone()), }), } } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] pub struct Key { - // todo: replace key field definition with: - // pub key: KeyId, - pub key: String, + pub id: KeyId, pub valid_until: Option, } @@ -72,7 +70,7 @@ impl std::fmt::Display for Key { write!( f, "key: `{}`, valid until `{}`", - self.key, + self.id, match self.valid_until { Some(duration) => format!( "{}", @@ -91,20 +89,29 @@ impl std::fmt::Display for Key { } impl Key { + /// # Panics + /// + /// Will panic if bytes cannot be converted into a valid `KeyId`. #[must_use] pub fn from_buffer(key_buffer: [u8; AUTH_KEY_LENGTH]) -> Option { if let Ok(key) = String::from_utf8(Vec::from(key_buffer)) { - Some(Key { key, valid_until: None }) + Some(Key { + id: key.parse::().unwrap(), + valid_until: None, + }) } else { None } } + /// # Panics + /// + /// Will panic if string cannot be converted into a valid `KeyId`. #[must_use] pub fn from_string(key: &str) -> Option { if key.len() == AUTH_KEY_LENGTH { Some(Key { - key: key.to_string(), + id: key.parse::().unwrap(), valid_until: None, }) } else { @@ -112,18 +119,13 @@ impl Key { } } - /// # Panics - /// - /// Will fail if the key id is not a valid key id. #[must_use] pub fn id(&self) -> KeyId { - // todo: replace the type of field `key` with type `KeyId`. - // The constructor should fail if an invalid KeyId is provided. - KeyId::from_str(&self.key).unwrap() + self.id.clone() } } -#[derive(Debug, Display, PartialEq, Clone)] +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone, Display, Hash)] pub struct KeyId(String); #[derive(Debug, PartialEq, Eq)] @@ -148,10 +150,10 @@ pub enum Error { KeyVerificationError { source: LocatedError<'static, dyn std::error::Error + Send + Sync>, }, - #[error("Failed to read key: {key}, {location}")] + #[error("Failed to read key: {key_id}, {location}")] UnableToReadKey { location: &'static Location<'static>, - key: Box, + key_id: Box, }, #[error("Key has expired, {location}")] KeyExpired { location: &'static Location<'static> }, @@ -171,7 +173,7 @@ mod tests { use std::time::Duration; use crate::protocol::clock::{Current, StoppedTime}; - use crate::tracker::auth; + use crate::tracker::auth::{self, KeyId}; #[test] fn auth_key_from_buffer() { @@ -181,7 +183,10 @@ mod tests { ]); assert!(auth_key.is_some()); - assert_eq!(auth_key.unwrap().key, "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ"); + assert_eq!( + auth_key.unwrap().id, + "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ".parse::().unwrap() + ); } #[test] @@ -190,7 +195,7 @@ mod tests { let auth_key = auth::Key::from_string(key_string); assert!(auth_key.is_some()); - assert_eq!(auth_key.unwrap().key, key_string); + assert_eq!(auth_key.unwrap().id, key_string.parse::().unwrap()); } #[test] diff --git a/src/tracker/error.rs b/src/tracker/error.rs index 51bcbf3bb..acc85a1c2 100644 --- a/src/tracker/error.rs +++ b/src/tracker/error.rs @@ -4,9 +4,9 @@ use crate::located_error::LocatedError; #[derive(thiserror::Error, Debug, Clone)] pub enum Error { - #[error("The supplied key: {key:?}, is not valid: {source}")] + #[error("The supplied key: {key_id:?}, is not valid: {source}")] PeerKeyNotValid { - key: super::auth::Key, + key_id: super::auth::KeyId, source: LocatedError<'static, dyn std::error::Error + Send + Sync>, }, #[error("The peer is not authenticated, {location}")] diff --git a/src/tracker/mod.rs b/src/tracker/mod.rs index 3e5e97439..147c889ac 100644 --- a/src/tracker/mod.rs +++ b/src/tracker/mod.rs @@ -16,6 +16,7 @@ use std::time::Duration; use tokio::sync::mpsc::error::SendError; use tokio::sync::{RwLock, RwLockReadGuard}; +use self::auth::KeyId; use self::error::Error; use self::peer::Peer; use self::torrent::{SwamStats, SwarmMetadata}; @@ -27,7 +28,7 @@ use crate::protocol::info_hash::InfoHash; pub struct Tracker { pub config: Arc, mode: mode::Mode, - keys: RwLock>, + keys: RwLock>, whitelist: RwLock>, torrents: RwLock>, stats_event_sender: Option>, @@ -155,28 +156,31 @@ impl Tracker { pub async fn generate_auth_key(&self, lifetime: Duration) -> Result { let auth_key = auth::generate(lifetime); self.database.add_key_to_keys(&auth_key).await?; - self.keys.write().await.insert(auth_key.key.clone(), auth_key.clone()); + self.keys.write().await.insert(auth_key.id.clone(), auth_key.clone()); Ok(auth_key) } /// # Errors /// /// Will return a `database::Error` if unable to remove the `key` to the database. + /// + /// # Panics + /// + /// Will panic if key cannot be converted into a valid `KeyId`. pub async fn remove_auth_key(&self, key: &str) -> Result<(), databases::error::Error> { self.database.remove_key_from_keys(key).await?; - self.keys.write().await.remove(key); + self.keys.write().await.remove(&key.parse::().unwrap()); Ok(()) } /// # Errors /// /// Will return a `key::Error` if unable to get any `auth_key`. - pub async fn verify_auth_key(&self, auth_key: &auth::Key) -> Result<(), auth::Error> { - // todo: use auth::KeyId for the function argument `auth_key` - match self.keys.read().await.get(&auth_key.key) { + pub async fn verify_auth_key(&self, key_id: &KeyId) -> Result<(), auth::Error> { + match self.keys.read().await.get(key_id) { None => Err(auth::Error::UnableToReadKey { location: Location::caller(), - key: Box::new(auth_key.clone()), + key_id: Box::new(key_id.clone()), }), Some(key) => auth::verify(key), } @@ -192,7 +196,7 @@ impl Tracker { keys.clear(); for key in keys_from_database { - keys.insert(key.key.clone(), key); + keys.insert(key.id.clone(), key); } Ok(()) @@ -283,7 +287,7 @@ impl Tracker { /// Will return a `torrent::Error::PeerNotAuthenticated` if the `key` is `None`. /// /// Will return a `torrent::Error::TorrentNotWhitelisted` if the the Tracker is in listed mode and the `info_hash` is not whitelisted. - pub async fn authenticate_request(&self, info_hash: &InfoHash, key: &Option) -> Result<(), Error> { + pub async fn authenticate_request(&self, info_hash: &InfoHash, key: &Option) -> Result<(), Error> { // no authentication needed in public mode if self.is_public() { return Ok(()); @@ -295,7 +299,7 @@ impl Tracker { Some(key) => { if let Err(e) = self.verify_auth_key(key).await { return Err(Error::PeerKeyNotValid { - key: key.clone(), + key_id: key.clone(), source: (Arc::new(e) as Arc).into(), }); } diff --git a/tests/tracker_api.rs b/tests/tracker_api.rs index 193c6487c..bec22e2b4 100644 --- a/tests/tracker_api.rs +++ b/tests/tracker_api.rs @@ -638,7 +638,7 @@ mod tracker_apis { mod for_key_resources { use std::time::Duration; - use torrust_tracker::tracker::auth::Key; + use torrust_tracker::tracker::auth::KeyId; use crate::api::asserts::{ assert_auth_key_utf8, assert_failed_to_delete_key, assert_failed_to_generate_key, assert_failed_to_reload_keys, @@ -665,7 +665,7 @@ mod tracker_apis { // Verify the key with the tracker assert!(api_server .tracker - .verify_auth_key(&Key::from(auth_key_resource)) + .verify_auth_key(&auth_key_resource.key.parse::().unwrap()) .await .is_ok()); } @@ -734,7 +734,7 @@ mod tracker_apis { .unwrap(); let response = Client::new(api_server.get_connection_info()) - .delete_auth_key(&auth_key.key) + .delete_auth_key(&auth_key.id.to_string()) .await; assert_ok(response).await; @@ -777,7 +777,7 @@ mod tracker_apis { force_database_error(&api_server.tracker); let response = Client::new(api_server.get_connection_info()) - .delete_auth_key(&auth_key.key) + .delete_auth_key(&auth_key.id.to_string()) .await; assert_failed_to_delete_key(response).await; @@ -797,7 +797,7 @@ mod tracker_apis { .unwrap(); let response = Client::new(connection_with_invalid_token(&api_server.get_bind_address())) - .delete_auth_key(&auth_key.key) + .delete_auth_key(&auth_key.id.to_string()) .await; assert_token_not_valid(response).await; @@ -810,7 +810,7 @@ mod tracker_apis { .unwrap(); let response = Client::new(connection_with_no_token(&api_server.get_bind_address())) - .delete_auth_key(&auth_key.key) + .delete_auth_key(&auth_key.id.to_string()) .await; assert_unauthorized(response).await; From 28e655fbd698f64b71c16d78bd3dbd211419d47d Mon Sep 17 00:00:00 2001 From: Jose Celano Date: Mon, 27 Feb 2023 19:03:16 +0000 Subject: [PATCH 2/2] refactor: [#171] rename auth::Key to auth::ExpiringKey --- src/apis/resources/auth_key.rs | 14 +++++++------- src/databases/mod.rs | 6 +++--- src/databases/mysql.rs | 10 +++++----- src/databases/sqlite.rs | 12 ++++++------ src/tracker/auth.rs | 24 ++++++++++++------------ src/tracker/mod.rs | 4 ++-- 6 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/apis/resources/auth_key.rs b/src/apis/resources/auth_key.rs index 207a0c482..e9989ca75 100644 --- a/src/apis/resources/auth_key.rs +++ b/src/apis/resources/auth_key.rs @@ -11,9 +11,9 @@ pub struct AuthKey { pub valid_until: Option, } -impl From for auth::Key { +impl From for auth::ExpiringKey { fn from(auth_key_resource: AuthKey) -> Self { - auth::Key { + auth::ExpiringKey { id: auth_key_resource.key.parse::().unwrap(), valid_until: auth_key_resource .valid_until @@ -22,8 +22,8 @@ impl From for auth::Key { } } -impl From for AuthKey { - fn from(auth_key: auth::Key) -> Self { +impl From for AuthKey { + fn from(auth_key: auth::ExpiringKey) -> Self { AuthKey { key: auth_key.id.to_string(), valid_until: auth_key.valid_until.map(|valid_until| valid_until.as_secs()), @@ -49,8 +49,8 @@ mod tests { }; assert_eq!( - auth::Key::from(auth_key_resource), - auth::Key { + auth::ExpiringKey::from(auth_key_resource), + auth::ExpiringKey { id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::().unwrap(), // cspell:disable-line valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap()) } @@ -61,7 +61,7 @@ mod tests { fn it_should_be_convertible_from_an_auth_key() { let duration_in_secs = 60; - let auth_key = auth::Key { + let auth_key = auth::ExpiringKey { id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::().unwrap(), // cspell:disable-line valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap()), }; diff --git a/src/databases/mod.rs b/src/databases/mod.rs index 70cc9eb75..038be0ea3 100644 --- a/src/databases/mod.rs +++ b/src/databases/mod.rs @@ -57,7 +57,7 @@ pub trait Database: Sync + Send { async fn load_persistent_torrents(&self) -> Result, Error>; - async fn load_keys(&self) -> Result, Error>; + async fn load_keys(&self) -> Result, Error>; async fn load_whitelist(&self) -> Result, Error>; @@ -71,9 +71,9 @@ pub trait Database: Sync + Send { async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result; // todo: replace type `&str` with `&KeyId` - async fn get_key_from_keys(&self, key: &str) -> Result, Error>; + async fn get_key_from_keys(&self, key: &str) -> Result, Error>; - async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result; + async fn add_key_to_keys(&self, auth_key: &auth::ExpiringKey) -> Result; // todo: replace type `&str` with `&KeyId` async fn remove_key_from_keys(&self, key: &str) -> Result; diff --git a/src/databases/mysql.rs b/src/databases/mysql.rs index 532ba1dcb..0d545aaa9 100644 --- a/src/databases/mysql.rs +++ b/src/databases/mysql.rs @@ -111,12 +111,12 @@ impl Database for Mysql { Ok(torrents) } - async fn load_keys(&self) -> Result, Error> { + async fn load_keys(&self) -> Result, Error> { let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; let keys = conn.query_map( "SELECT `key`, valid_until FROM `keys`", - |(key, valid_until): (String, i64)| auth::Key { + |(key, valid_until): (String, i64)| auth::ExpiringKey { id: key.parse::().unwrap(), valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())), }, @@ -183,7 +183,7 @@ impl Database for Mysql { Ok(1) } - async fn get_key_from_keys(&self, key: &str) -> Result, Error> { + async fn get_key_from_keys(&self, key: &str) -> Result, Error> { let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; let query = @@ -191,13 +191,13 @@ impl Database for Mysql { let key = query?; - Ok(key.map(|(key, expiry)| auth::Key { + Ok(key.map(|(key, expiry)| auth::ExpiringKey { id: key.parse::().unwrap(), valid_until: Some(Duration::from_secs(expiry.unsigned_abs())), })) } - async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result { + async fn add_key_to_keys(&self, auth_key: &auth::ExpiringKey) -> Result { let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?; let key = auth_key.id.to_string(); diff --git a/src/databases/sqlite.rs b/src/databases/sqlite.rs index d6915c850..ab0addf4b 100644 --- a/src/databases/sqlite.rs +++ b/src/databases/sqlite.rs @@ -102,7 +102,7 @@ impl Database for Sqlite { Ok(torrents) } - async fn load_keys(&self) -> Result, Error> { + async fn load_keys(&self) -> Result, Error> { let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let mut stmt = conn.prepare("SELECT key, valid_until FROM keys")?; @@ -111,13 +111,13 @@ impl Database for Sqlite { let key: String = row.get(0)?; let valid_until: i64 = row.get(1)?; - Ok(auth::Key { + Ok(auth::ExpiringKey { id: key.parse::().unwrap(), valid_until: Some(DurationSinceUnixEpoch::from_secs(valid_until.unsigned_abs())), }) })?; - let keys: Vec = keys_iter.filter_map(std::result::Result::ok).collect(); + let keys: Vec = keys_iter.filter_map(std::result::Result::ok).collect(); Ok(keys) } @@ -200,7 +200,7 @@ impl Database for Sqlite { } } - async fn get_key_from_keys(&self, key: &str) -> Result, Error> { + async fn get_key_from_keys(&self, key: &str) -> Result, Error> { let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let mut stmt = conn.prepare("SELECT key, valid_until FROM keys WHERE key = ?")?; @@ -212,14 +212,14 @@ impl Database for Sqlite { Ok(key.map(|f| { let expiry: i64 = f.get(1).unwrap(); let id: String = f.get(0).unwrap(); - auth::Key { + auth::ExpiringKey { id: id.parse::().unwrap(), valid_until: Some(DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs())), } })) } - async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result { + async fn add_key_to_keys(&self, auth_key: &auth::ExpiringKey) -> Result { let conn = self.pool.get().map_err(|e| (e, DRIVER))?; let insert = conn.execute( diff --git a/src/tracker/auth.rs b/src/tracker/auth.rs index 53304657a..22f734e48 100644 --- a/src/tracker/auth.rs +++ b/src/tracker/auth.rs @@ -19,7 +19,7 @@ use crate::protocol::common::AUTH_KEY_LENGTH; /// # Panics /// /// It would panic if the `lifetime: Duration` + Duration is more than `Duration::MAX`. -pub fn generate(lifetime: Duration) -> Key { +pub fn generate(lifetime: Duration) -> ExpiringKey { let random_id: String = thread_rng() .sample_iter(&Alphanumeric) .take(AUTH_KEY_LENGTH) @@ -28,7 +28,7 @@ pub fn generate(lifetime: Duration) -> Key { debug!("Generated key: {}, valid for: {:?} seconds", random_id, lifetime); - Key { + ExpiringKey { id: random_id.parse::().unwrap(), valid_until: Some(Current::add(&lifetime).unwrap()), } @@ -39,7 +39,7 @@ pub fn generate(lifetime: Duration) -> Key { /// Will return `Error::KeyExpired` if `auth_key.valid_until` is past the `current_time`. /// /// Will return `Error::KeyInvalid` if `auth_key.valid_until` is past the `None`. -pub fn verify(auth_key: &Key) -> Result<(), Error> { +pub fn verify(auth_key: &ExpiringKey) -> Result<(), Error> { let current_time: DurationSinceUnixEpoch = Current::now(); match auth_key.valid_until { @@ -60,12 +60,12 @@ pub fn verify(auth_key: &Key) -> Result<(), Error> { } #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] -pub struct Key { +pub struct ExpiringKey { pub id: KeyId, pub valid_until: Option, } -impl std::fmt::Display for Key { +impl std::fmt::Display for ExpiringKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, @@ -88,14 +88,14 @@ impl std::fmt::Display for Key { } } -impl Key { +impl ExpiringKey { /// # Panics /// /// Will panic if bytes cannot be converted into a valid `KeyId`. #[must_use] - pub fn from_buffer(key_buffer: [u8; AUTH_KEY_LENGTH]) -> Option { + pub fn from_buffer(key_buffer: [u8; AUTH_KEY_LENGTH]) -> Option { if let Ok(key) = String::from_utf8(Vec::from(key_buffer)) { - Some(Key { + Some(ExpiringKey { id: key.parse::().unwrap(), valid_until: None, }) @@ -108,9 +108,9 @@ impl Key { /// /// Will panic if string cannot be converted into a valid `KeyId`. #[must_use] - pub fn from_string(key: &str) -> Option { + pub fn from_string(key: &str) -> Option { if key.len() == AUTH_KEY_LENGTH { - Some(Key { + Some(ExpiringKey { id: key.parse::().unwrap(), valid_until: None, }) @@ -177,7 +177,7 @@ mod tests { #[test] fn auth_key_from_buffer() { - let auth_key = auth::Key::from_buffer([ + let auth_key = auth::ExpiringKey::from_buffer([ 89, 90, 83, 108, 52, 108, 77, 90, 117, 112, 82, 117, 79, 112, 83, 82, 67, 51, 107, 114, 73, 75, 82, 53, 66, 80, 66, 49, 52, 110, 114, 74, ]); @@ -192,7 +192,7 @@ mod tests { #[test] fn auth_key_from_string() { let key_string = "YZSl4lMZupRuOpSRC3krIKR5BPB14nrJ"; - let auth_key = auth::Key::from_string(key_string); + let auth_key = auth::ExpiringKey::from_string(key_string); assert!(auth_key.is_some()); assert_eq!(auth_key.unwrap().id, key_string.parse::().unwrap()); diff --git a/src/tracker/mod.rs b/src/tracker/mod.rs index 147c889ac..0fb434aea 100644 --- a/src/tracker/mod.rs +++ b/src/tracker/mod.rs @@ -28,7 +28,7 @@ use crate::protocol::info_hash::InfoHash; pub struct Tracker { pub config: Arc, mode: mode::Mode, - keys: RwLock>, + keys: RwLock>, whitelist: RwLock>, torrents: RwLock>, stats_event_sender: Option>, @@ -153,7 +153,7 @@ impl Tracker { /// # Errors /// /// Will return a `database::Error` if unable to add the `auth_key` to the database. - pub async fn generate_auth_key(&self, lifetime: Duration) -> Result { + pub async fn generate_auth_key(&self, lifetime: Duration) -> Result { let auth_key = auth::generate(lifetime); self.database.add_key_to_keys(&auth_key).await?; self.keys.write().await.insert(auth_key.id.clone(), auth_key.clone());