diff --git a/cSpell.json b/cSpell.json index 1df69e4e7..a2c4235c4 100644 --- a/cSpell.json +++ b/cSpell.json @@ -30,10 +30,12 @@ "nanos", "nextest", "nocapture", + "oneshot", "ostr", "Pando", "Rasterbar", "repr", + "reqwest", "rngs", "rusqlite", "rustfmt", diff --git a/src/databases/database.rs b/src/databases/database.rs index c67f39a54..795be0d45 100644 --- a/src/databases/database.rs +++ b/src/databases/database.rs @@ -53,6 +53,17 @@ pub trait Database: Sync + Send { async fn add_key_to_keys(&self, auth_key: &AuthKey) -> Result; async fn remove_key_from_keys(&self, key: &str) -> Result; + + async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result { + if let Err(e) = self.get_info_hash_from_whitelist(&info_hash.to_owned().to_string()).await { + if let Error::QueryReturnedNoRows = e { + return Ok(false); + } else { + return Err(e); + } + } + Ok(true) + } } #[derive(Debug, Display, PartialEq, Eq, Error)] diff --git a/src/databases/mysql.rs b/src/databases/mysql.rs index a4d870101..fc6ff5098 100644 --- a/src/databases/mysql.rs +++ b/src/databases/mysql.rs @@ -141,10 +141,10 @@ impl Database for MysqlDatabase { "SELECT info_hash FROM whitelist WHERE info_hash = :info_hash", params! { info_hash }, ) - .map_err(|_| database::Error::QueryReturnedNoRows)? + .map_err(|_| database::Error::DatabaseError)? { Some(info_hash) => Ok(InfoHash::from_str(&info_hash).unwrap()), - None => Err(database::Error::InvalidQuery), + None => Err(database::Error::QueryReturnedNoRows), } } diff --git a/src/databases/sqlite.rs b/src/databases/sqlite.rs index ef9f12d9c..7a567b07e 100644 --- a/src/databases/sqlite.rs +++ b/src/databases/sqlite.rs @@ -137,13 +137,15 @@ impl Database for SqliteDatabase { let mut stmt = conn.prepare("SELECT info_hash FROM whitelist WHERE info_hash = ?")?; let mut rows = stmt.query([info_hash])?; - if let Some(row) = rows.next()? { - let info_hash: String = row.get(0).unwrap(); - - // should never be able to fail - Ok(InfoHash::from_str(&info_hash).unwrap()) - } else { - Err(database::Error::InvalidQuery) + match rows.next() { + Ok(row) => match row { + Some(row) => Ok(InfoHash::from_str(&row.get_unwrap::<_, String>(0)).unwrap()), + None => Err(database::Error::QueryReturnedNoRows), + }, + Err(e) => { + debug!("{:?}", e); + Err(database::Error::InvalidQuery) + } } } diff --git a/src/jobs/tracker_api.rs b/src/jobs/tracker_api.rs index 97b1fa3b0..ba5b8a1fb 100644 --- a/src/jobs/tracker_api.rs +++ b/src/jobs/tracker_api.rs @@ -1,21 +1,43 @@ use std::sync::Arc; use log::info; +use tokio::sync::oneshot; use tokio::task::JoinHandle; use crate::api::server; use crate::tracker::TorrentTracker; use crate::Configuration; -pub fn start_job(config: &Configuration, tracker: Arc) -> JoinHandle<()> { +#[derive(Debug)] +pub struct ApiServerJobStarted(); + +pub async fn start_job(config: &Configuration, tracker: Arc) -> JoinHandle<()> { let bind_addr = config .http_api .bind_address .parse::() .expect("Tracker API bind_address invalid."); + info!("Starting Torrust API server on: {}", bind_addr); - tokio::spawn(async move { - server::start(bind_addr, tracker).await; - }) + let (tx, rx) = oneshot::channel::(); + + // Run the API server + let join_handle = tokio::spawn(async move { + let handel = server::start(bind_addr, tracker); + + if tx.send(ApiServerJobStarted()).is_err() { + panic!("the start job dropped"); + } + + handel.await; + }); + + // Wait until the API server job is running + match rx.await { + Ok(_msg) => info!("Torrust API server started"), + Err(_) => panic!("the api server dropped"), + } + + join_handle } diff --git a/src/setup.rs b/src/setup.rs index 2ecc1c143..9906a2d03 100644 --- a/src/setup.rs +++ b/src/setup.rs @@ -49,7 +49,7 @@ pub async fn setup(config: &Configuration, tracker: Arc) -> Vec< // Start HTTP API server if config.http_api.enabled { - jobs.push(tracker_api::start_job(config, tracker.clone())); + jobs.push(tracker_api::start_job(config, tracker.clone()).await); } // Remove torrents without peers, every interval diff --git a/src/tracker/mod.rs b/src/tracker/mod.rs index 77f51098a..a3eecd427 100644 --- a/src/tracker/mod.rs +++ b/src/tracker/mod.rs @@ -106,8 +106,14 @@ impl TorrentTracker { Ok(()) } + /// It adds a torrent to the whitelist if it has not been whitelisted previously async fn add_torrent_to_database_whitelist(&self, info_hash: &InfoHash) -> Result<(), database::Error> { + if self.database.is_info_hash_whitelisted(info_hash).await.unwrap() { + return Ok(()); + } + self.database.add_info_hash_to_whitelist(*info_hash).await?; + Ok(()) } diff --git a/tests/api.rs b/tests/api.rs index 38966a81b..278f9d4fb 100644 --- a/tests/api.rs +++ b/tests/api.rs @@ -8,22 +8,22 @@ mod common; mod tracker_api { use core::panic; use std::env; + use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::task::JoinHandle; - use tokio::time::{sleep, Duration}; use torrust_tracker::api::resources::auth_key_resource::AuthKeyResource; use torrust_tracker::jobs::tracker_api; use torrust_tracker::tracker::key::AuthKey; use torrust_tracker::tracker::statistics::StatsTracker; use torrust_tracker::tracker::TorrentTracker; - use torrust_tracker::{ephemeral_instance_keys, logging, static_time, Configuration}; + use torrust_tracker::{ephemeral_instance_keys, logging, static_time, Configuration, InfoHash}; use crate::common::ephemeral_random_port; #[tokio::test] - async fn should_generate_a_new_auth_key() { + async fn should_allow_generating_a_new_auth_key() { let configuration = tracker_configuration(); let api_server = new_running_api_server(configuration.clone()).await; @@ -44,15 +44,60 @@ mod tracker_api { .is_ok()); } + #[tokio::test] + async fn should_allow_whitelisting_a_torrent() { + let configuration = tracker_configuration(); + let api_server = new_running_api_server(configuration.clone()).await; + + let bind_address = api_server.bind_address.unwrap().clone(); + let api_token = configuration.http_api.access_tokens.get_key_value("admin").unwrap().1.clone(); + let info_hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); + + let url = format!("http://{}/api/whitelist/{}?token={}", &bind_address, &info_hash, &api_token); + + let res = reqwest::Client::new().post(url.clone()).send().await.unwrap(); + + assert_eq!(res.status(), 200); + assert!( + api_server + .tracker + .unwrap() + .is_info_hash_whitelisted(&InfoHash::from_str(&info_hash).unwrap()) + .await + ); + } + + #[tokio::test] + async fn should_allow_whitelisting_a_torrent_that_has_been_already_whitelisted() { + let configuration = tracker_configuration(); + let api_server = new_running_api_server(configuration.clone()).await; + + let bind_address = api_server.bind_address.unwrap().clone(); + let api_token = configuration.http_api.access_tokens.get_key_value("admin").unwrap().1.clone(); + let info_hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned(); + + let url = format!("http://{}/api/whitelist/{}?token={}", &bind_address, &info_hash, &api_token); + + // First whitelist request + let res = reqwest::Client::new().post(url.clone()).send().await.unwrap(); + assert_eq!(res.status(), 200); + + // Second whitelist request + let res = reqwest::Client::new().post(url.clone()).send().await.unwrap(); + assert_eq!(res.status(), 200); + } + fn tracker_configuration() -> Arc { let mut config = Configuration::default(); config.log_level = Some("off".to_owned()); - config.http_api.bind_address = format!("127.0.0.1:{}", ephemeral_random_port()); + // Ephemeral socket address + let port = ephemeral_random_port(); + config.http_api.bind_address = format!("127.0.0.1:{}", &port); - // Temp database + // Ephemeral database let temp_directory = env::temp_dir(); - let temp_file = temp_directory.join("data.db"); + let temp_file = temp_directory.join(format!("data_{}.db", &port)); config.db_path = temp_file.to_str().unwrap().to_owned(); Arc::new(config) @@ -107,12 +152,9 @@ mod tracker_api { logging::setup_logging(&configuration); // Start the HTTP API job - self.job = Some(tracker_api::start_job(&configuration, tracker.clone())); + self.job = Some(tracker_api::start_job(&configuration, tracker).await); self.started.store(true, Ordering::Relaxed); - - // Wait to give time to the API server to be ready to accept requests - sleep(Duration::from_millis(100)).await; } } }