Skip to content
2 changes: 2 additions & 0 deletions cSpell.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
"nanos",
"nextest",
"nocapture",
"oneshot",
"ostr",
"Pando",
"Rasterbar",
"repr",
"reqwest",
"rngs",
"rusqlite",
"rustfmt",
Expand Down
11 changes: 11 additions & 0 deletions src/databases/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ pub trait Database: Sync + Send {
async fn add_key_to_keys(&self, auth_key: &AuthKey) -> Result<usize, Error>;

async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error>;

async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result<bool, Error> {
if let Err(e) = self.get_info_hash_from_whitelist(&info_hash.to_owned().to_string()).await {
if let Error::QueryReturnedNoRows = e {
return Ok(false);
} else {
return Err(e);
}
}
Ok(true)
}
}

#[derive(Debug, Display, PartialEq, Eq, Error)]
Expand Down
4 changes: 2 additions & 2 deletions src/databases/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ impl Database for MysqlDatabase {
"SELECT info_hash FROM whitelist WHERE info_hash = :info_hash",
params! { info_hash },
)
.map_err(|_| database::Error::QueryReturnedNoRows)?
.map_err(|_| database::Error::DatabaseError)?
{
Some(info_hash) => Ok(InfoHash::from_str(&info_hash).unwrap()),
None => Err(database::Error::InvalidQuery),
None => Err(database::Error::QueryReturnedNoRows),
}
}

Expand Down
16 changes: 9 additions & 7 deletions src/databases/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,15 @@ impl Database for SqliteDatabase {
let mut stmt = conn.prepare("SELECT info_hash FROM whitelist WHERE info_hash = ?")?;
let mut rows = stmt.query([info_hash])?;

if let Some(row) = rows.next()? {
let info_hash: String = row.get(0).unwrap();

// should never be able to fail
Ok(InfoHash::from_str(&info_hash).unwrap())
} else {
Err(database::Error::InvalidQuery)
match rows.next() {
Ok(row) => match row {
Some(row) => Ok(InfoHash::from_str(&row.get_unwrap::<_, String>(0)).unwrap()),
None => Err(database::Error::QueryReturnedNoRows),
},
Err(e) => {
debug!("{:?}", e);
Err(database::Error::InvalidQuery)
}
}
}

Expand Down
30 changes: 26 additions & 4 deletions src/jobs/tracker_api.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,43 @@
use std::sync::Arc;

use log::info;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;

use crate::api::server;
use crate::tracker::TorrentTracker;
use crate::Configuration;

pub fn start_job(config: &Configuration, tracker: Arc<TorrentTracker>) -> JoinHandle<()> {
/// Marker message sent through the `oneshot` channel by the spawned API
/// server task to signal `start_job` that the task is running.
#[derive(Debug)]
pub struct ApiServerJobStarted();

/// Spawns the HTTP API server as a background task and waits until the task
/// has signalled that it is running before returning its `JoinHandle`.
///
/// # Panics
///
/// Panics if `config.http_api.bind_address` is not a valid socket address,
/// or if the spawned task drops the start-signal channel before sending.
pub async fn start_job(config: &Configuration, tracker: Arc<TorrentTracker>) -> JoinHandle<()> {
    let bind_addr = config
        .http_api
        .bind_address
        .parse::<std::net::SocketAddr>()
        .expect("Tracker API bind_address invalid.");

    info!("Starting Torrust API server on: {}", bind_addr);

    // Channel used by the spawned task to tell this function it has started.
    let (tx, rx) = oneshot::channel::<ApiServerJobStarted>();

    // Run the API server
    let join_handle = tokio::spawn(async move {
        // typo fix: `handel` -> `handle`
        let handle = server::start(bind_addr, tracker);

        // NOTE(review): this signal is sent before the server future is polled,
        // so it means "task spawned", not "socket bound and accepting" —
        // confirm callers do not rely on the latter.
        if tx.send(ApiServerJobStarted()).is_err() {
            panic!("the start job dropped");
        }

        handle.await;
    });

    // Wait until the API server job is running
    match rx.await {
        Ok(_msg) => info!("Torrust API server started"),
        Err(_) => panic!("the api server dropped"),
    }

    join_handle
}
2 changes: 1 addition & 1 deletion src/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub async fn setup(config: &Configuration, tracker: Arc<TorrentTracker>) -> Vec<

// Start HTTP API server
if config.http_api.enabled {
jobs.push(tracker_api::start_job(config, tracker.clone()));
jobs.push(tracker_api::start_job(config, tracker.clone()).await);
}

// Remove torrents without peers, every interval
Expand Down
6 changes: 6 additions & 0 deletions src/tracker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,14 @@ impl TorrentTracker {
Ok(())
}

/// It adds a torrent to the whitelist if it has not been whitelisted previously
///
/// # Errors
///
/// Returns a `database::Error` if either the whitelist lookup or the insert
/// fails.
async fn add_torrent_to_database_whitelist(&self, info_hash: &InfoHash) -> Result<(), database::Error> {
    // Propagate database errors with `?` instead of `.unwrap()`, which would
    // panic on any DB failure despite this function returning a Result.
    if self.database.is_info_hash_whitelisted(info_hash).await? {
        return Ok(());
    }

    self.database.add_info_hash_to_whitelist(*info_hash).await?;

    Ok(())
}

Expand Down
62 changes: 52 additions & 10 deletions tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@ mod common;
mod tracker_api {
use core::panic;
use std::env;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use tokio::task::JoinHandle;
use tokio::time::{sleep, Duration};
use torrust_tracker::api::resources::auth_key_resource::AuthKeyResource;
use torrust_tracker::jobs::tracker_api;
use torrust_tracker::tracker::key::AuthKey;
use torrust_tracker::tracker::statistics::StatsTracker;
use torrust_tracker::tracker::TorrentTracker;
use torrust_tracker::{ephemeral_instance_keys, logging, static_time, Configuration};
use torrust_tracker::{ephemeral_instance_keys, logging, static_time, Configuration, InfoHash};

use crate::common::ephemeral_random_port;

#[tokio::test]
async fn should_generate_a_new_auth_key() {
async fn should_allow_generating_a_new_auth_key() {
let configuration = tracker_configuration();
let api_server = new_running_api_server(configuration.clone()).await;

Expand All @@ -44,15 +44,60 @@ mod tracker_api {
.is_ok());
}

#[tokio::test]
async fn should_allow_whitelisting_a_torrent() {
    let configuration = tracker_configuration();
    let api_server = new_running_api_server(configuration.clone()).await;

    // `unwrap()` already yields an owned value; the trailing `.clone()` was redundant.
    let bind_address = api_server.bind_address.unwrap();
    let api_token = configuration.http_api.access_tokens.get_key_value("admin").unwrap().1.clone();
    let info_hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned();

    let url = format!("http://{}/api/whitelist/{}?token={}", &bind_address, &info_hash, &api_token);

    // `url` is used exactly once, so it can be moved instead of cloned.
    let res = reqwest::Client::new().post(url).send().await.unwrap();

    assert_eq!(res.status(), 200);
    assert!(
        api_server
            .tracker
            .unwrap()
            .is_info_hash_whitelisted(&InfoHash::from_str(&info_hash).unwrap())
            .await
    );
}

#[tokio::test]
async fn should_allow_whitelisting_a_torrent_that_has_been_already_whitelisted() {
    let configuration = tracker_configuration();
    let api_server = new_running_api_server(configuration.clone()).await;

    // `unwrap()` already yields an owned value; the trailing `.clone()` was redundant.
    let bind_address = api_server.bind_address.unwrap();
    let api_token = configuration.http_api.access_tokens.get_key_value("admin").unwrap().1.clone();
    let info_hash = "9e0217d0fa71c87332cd8bf9dbeabcb2c2cf3c4d".to_owned();

    let url = format!("http://{}/api/whitelist/{}?token={}", &bind_address, &info_hash, &api_token);

    // First whitelist request
    let res = reqwest::Client::new().post(url.clone()).send().await.unwrap();
    assert_eq!(res.status(), 200);

    // Second whitelist request: whitelisting must be idempotent. `url` can be
    // moved here since this is its last use.
    let res = reqwest::Client::new().post(url).send().await.unwrap();
    assert_eq!(res.status(), 200);
}

fn tracker_configuration() -> Arc<Configuration> {
let mut config = Configuration::default();
config.log_level = Some("off".to_owned());

config.http_api.bind_address = format!("127.0.0.1:{}", ephemeral_random_port());
// Ephemeral socket address
let port = ephemeral_random_port();
config.http_api.bind_address = format!("127.0.0.1:{}", &port);

// Temp database
// Ephemeral database
let temp_directory = env::temp_dir();
let temp_file = temp_directory.join("data.db");
let temp_file = temp_directory.join(format!("data_{}.db", &port));
config.db_path = temp_file.to_str().unwrap().to_owned();

Arc::new(config)
Expand Down Expand Up @@ -107,12 +152,9 @@ mod tracker_api {
logging::setup_logging(&configuration);

// Start the HTTP API job
self.job = Some(tracker_api::start_job(&configuration, tracker.clone()));
self.job = Some(tracker_api::start_job(&configuration, tracker).await);

self.started.store(true, Ordering::Relaxed);

// Wait to give time to the API server to be ready to accept requests
sleep(Duration::from_millis(100)).await;
}
}
}
Expand Down