Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions src/apis/resources/auth_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,29 @@ use std::convert::From;
use serde::{Deserialize, Serialize};

use crate::protocol::clock::DurationSinceUnixEpoch;
use crate::tracker::auth;
use crate::tracker::auth::{self, KeyId};

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
pub struct AuthKey {
pub key: String,
pub key: String, // todo: rename to `id`
pub valid_until: Option<u64>,
}

impl From<AuthKey> for auth::Key {
impl From<AuthKey> for auth::ExpiringKey {
fn from(auth_key_resource: AuthKey) -> Self {
auth::Key {
key: auth_key_resource.key,
auth::ExpiringKey {
id: auth_key_resource.key.parse::<KeyId>().unwrap(),
valid_until: auth_key_resource
.valid_until
.map(|valid_until| DurationSinceUnixEpoch::new(valid_until, 0)),
}
}
}

impl From<auth::Key> for AuthKey {
fn from(auth_key: auth::Key) -> Self {
impl From<auth::ExpiringKey> for AuthKey {
fn from(auth_key: auth::ExpiringKey) -> Self {
AuthKey {
key: auth_key.key,
key: auth_key.id.to_string(),
valid_until: auth_key.valid_until.map(|valid_until| valid_until.as_secs()),
}
}
Expand All @@ -37,7 +37,7 @@ mod tests {

use super::AuthKey;
use crate::protocol::clock::{Current, TimeNow};
use crate::tracker::auth;
use crate::tracker::auth::{self, KeyId};

#[test]
fn it_should_be_convertible_into_an_auth_key() {
Expand All @@ -49,9 +49,9 @@ mod tests {
};

assert_eq!(
auth::Key::from(auth_key_resource),
auth::Key {
key: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".to_string(), // cspell:disable-line
auth::ExpiringKey::from(auth_key_resource),
auth::ExpiringKey {
id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::<KeyId>().unwrap(), // cspell:disable-line
valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap())
}
);
Expand All @@ -61,8 +61,8 @@ mod tests {
fn it_should_be_convertible_from_an_auth_key() {
let duration_in_secs = 60;

let auth_key = auth::Key {
key: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".to_string(), // cspell:disable-line
let auth_key = auth::ExpiringKey {
id: "IaWDneuFNZi8IB4MPA3qW1CD0M30EZSM".parse::<KeyId>().unwrap(), // cspell:disable-line
valid_until: Some(Current::add(&Duration::new(duration_in_secs, 0)).unwrap()),
};

Expand Down
9 changes: 6 additions & 3 deletions src/databases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,25 @@ pub trait Database: Sync + Send {

async fn load_persistent_torrents(&self) -> Result<Vec<(InfoHash, u32)>, Error>;

async fn load_keys(&self) -> Result<Vec<auth::Key>, Error>;
async fn load_keys(&self) -> Result<Vec<auth::ExpiringKey>, Error>;

async fn load_whitelist(&self) -> Result<Vec<InfoHash>, Error>;

async fn save_persistent_torrent(&self, info_hash: &InfoHash, completed: u32) -> Result<(), Error>;

// todo: replace type `&str` with `&InfoHash`
async fn get_info_hash_from_whitelist(&self, info_hash: &str) -> Result<Option<InfoHash>, Error>;

async fn add_info_hash_to_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

async fn remove_info_hash_from_whitelist(&self, info_hash: InfoHash) -> Result<usize, Error>;

async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::Key>, Error>;
// todo: replace type `&str` with `&KeyId`
async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::ExpiringKey>, Error>;

async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result<usize, Error>;
async fn add_key_to_keys(&self, auth_key: &auth::ExpiringKey) -> Result<usize, Error>;

// todo: replace type `&str` with `&KeyId`
async fn remove_key_from_keys(&self, key: &str) -> Result<usize, Error>;

async fn is_info_hash_whitelisted(&self, info_hash: &InfoHash) -> Result<bool, Error> {
Expand Down
18 changes: 9 additions & 9 deletions src/databases/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::driver::Driver;
use crate::databases::{Database, Error};
use crate::protocol::common::AUTH_KEY_LENGTH;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth;
use crate::tracker::auth::{self, KeyId};

const DRIVER: Driver = Driver::MySQL;

Expand Down Expand Up @@ -111,13 +111,13 @@ impl Database for Mysql {
Ok(torrents)
}

async fn load_keys(&self) -> Result<Vec<auth::Key>, Error> {
async fn load_keys(&self) -> Result<Vec<auth::ExpiringKey>, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let keys = conn.query_map(
"SELECT `key`, valid_until FROM `keys`",
|(key, valid_until): (String, i64)| auth::Key {
key,
|(key, valid_until): (String, i64)| auth::ExpiringKey {
id: key.parse::<KeyId>().unwrap(),
valid_until: Some(Duration::from_secs(valid_until.unsigned_abs())),
},
)?;
Expand Down Expand Up @@ -183,24 +183,24 @@ impl Database for Mysql {
Ok(1)
}

async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::Key>, Error> {
async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::ExpiringKey>, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let query =
conn.exec_first::<(String, i64), _, _>("SELECT `key`, valid_until FROM `keys` WHERE `key` = :key", params! { key });

let key = query?;

Ok(key.map(|(key, expiry)| auth::Key {
key,
Ok(key.map(|(key, expiry)| auth::ExpiringKey {
id: key.parse::<KeyId>().unwrap(),
valid_until: Some(Duration::from_secs(expiry.unsigned_abs())),
}))
}

async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result<usize, Error> {
async fn add_key_to_keys(&self, auth_key: &auth::ExpiringKey) -> Result<usize, Error> {
let mut conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let key = auth_key.key.to_string();
let key = auth_key.id.to_string();
let valid_until = auth_key.valid_until.unwrap_or(Duration::ZERO).as_secs().to_string();

conn.exec_drop(
Expand Down
23 changes: 12 additions & 11 deletions src/databases/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::driver::Driver;
use crate::databases::{Database, Error};
use crate::protocol::clock::DurationSinceUnixEpoch;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth;
use crate::tracker::auth::{self, KeyId};

const DRIVER: Driver = Driver::Sqlite3;

Expand Down Expand Up @@ -102,22 +102,22 @@ impl Database for Sqlite {
Ok(torrents)
}

async fn load_keys(&self) -> Result<Vec<auth::Key>, Error> {
async fn load_keys(&self) -> Result<Vec<auth::ExpiringKey>, Error> {
let conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let mut stmt = conn.prepare("SELECT key, valid_until FROM keys")?;

let keys_iter = stmt.query_map([], |row| {
let key = row.get(0)?;
let key: String = row.get(0)?;
let valid_until: i64 = row.get(1)?;

Ok(auth::Key {
key,
Ok(auth::ExpiringKey {
id: key.parse::<KeyId>().unwrap(),
valid_until: Some(DurationSinceUnixEpoch::from_secs(valid_until.unsigned_abs())),
})
})?;

let keys: Vec<auth::Key> = keys_iter.filter_map(std::result::Result::ok).collect();
let keys: Vec<auth::ExpiringKey> = keys_iter.filter_map(std::result::Result::ok).collect();

Ok(keys)
}
Expand Down Expand Up @@ -200,7 +200,7 @@ impl Database for Sqlite {
}
}

async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::Key>, Error> {
async fn get_key_from_keys(&self, key: &str) -> Result<Option<auth::ExpiringKey>, Error> {
let conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let mut stmt = conn.prepare("SELECT key, valid_until FROM keys WHERE key = ?")?;
Expand All @@ -211,19 +211,20 @@ impl Database for Sqlite {

Ok(key.map(|f| {
let expiry: i64 = f.get(1).unwrap();
auth::Key {
key: f.get(0).unwrap(),
let id: String = f.get(0).unwrap();
auth::ExpiringKey {
id: id.parse::<KeyId>().unwrap(),
valid_until: Some(DurationSinceUnixEpoch::from_secs(expiry.unsigned_abs())),
}
}))
}

async fn add_key_to_keys(&self, auth_key: &auth::Key) -> Result<usize, Error> {
async fn add_key_to_keys(&self, auth_key: &auth::ExpiringKey) -> Result<usize, Error> {
let conn = self.pool.get().map_err(|e| (e, DRIVER))?;

let insert = conn.execute(
"INSERT INTO keys (key, valid_until) VALUES (?1, ?2)",
[auth_key.key.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()],
[auth_key.id.to_string(), auth_key.valid_until.unwrap().as_secs().to_string()],
)?;

if insert == 0 {
Expand Down
16 changes: 12 additions & 4 deletions src/http/warp_implementation/filters.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::convert::Infallible;
use std::net::{IpAddr, SocketAddr};
use std::panic::Location;
use std::str::FromStr;
use std::sync::Arc;

use warp::{reject, Filter, Rejection};
Expand All @@ -11,7 +12,8 @@ use super::{request, WebResult};
use crate::http::percent_encoding::{percent_decode_info_hash, percent_decode_peer_id};
use crate::protocol::common::MAX_SCRAPE_TORRENTS;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::{self, auth, peer};
use crate::tracker::auth::KeyId;
use crate::tracker::{self, peer};

/// Pass Arc<tracker::TorrentTracker> along
#[must_use]
Expand All @@ -35,10 +37,16 @@ pub fn with_peer_id() -> impl Filter<Extract = (peer::Id,), Error = Rejection> +

/// Pass Arc<tracker::TorrentTracker> along
#[must_use]
pub fn with_auth_key() -> impl Filter<Extract = (Option<auth::Key>,), Error = Infallible> + Clone {
pub fn with_auth_key_id() -> impl Filter<Extract = (Option<KeyId>,), Error = Infallible> + Clone {
warp::path::param::<String>()
.map(|key: String| auth::Key::from_string(&key))
.or_else(|_| async { Ok::<(Option<auth::Key>,), Infallible>((None,)) })
.map(|key: String| {
let key_id = KeyId::from_str(&key);
match key_id {
Ok(id) => Some(id),
Err(_) => None,
}
})
.or_else(|_| async { Ok::<(Option<KeyId>,), Infallible>((None,)) })
}

/// Check for `PeerAddress`
Expand Down
13 changes: 7 additions & 6 deletions src/http/warp_implementation/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use super::error::Error;
use super::{request, response, WebResult};
use crate::http::warp_implementation::peer_builder;
use crate::protocol::info_hash::InfoHash;
use crate::tracker::auth::KeyId;
use crate::tracker::{self, auth, peer, statistics, torrent};

/// Authenticate `InfoHash` using optional `auth::Key`
Expand All @@ -21,11 +22,11 @@ use crate::tracker::{self, auth, peer, statistics, torrent};
/// Will return `ServerError` that wraps the `tracker::error::Error` if unable to `authenticate_request`.
pub async fn authenticate(
info_hash: &InfoHash,
auth_key: &Option<auth::Key>,
auth_key_id: &Option<auth::KeyId>,
tracker: Arc<tracker::Tracker>,
) -> Result<(), Error> {
tracker
.authenticate_request(info_hash, auth_key)
.authenticate_request(info_hash, auth_key_id)
.await
.map_err(|e| Error::TrackerError {
source: (Arc::new(e) as Arc<dyn std::error::Error + Send + Sync>).into(),
Expand All @@ -37,15 +38,15 @@ pub async fn authenticate(
/// Will return `warp::Rejection` that wraps the `ServerError` if unable to `send_announce_response`.
pub async fn handle_announce(
announce_request: request::Announce,
auth_key: Option<auth::Key>,
auth_key_id: Option<KeyId>,
tracker: Arc<tracker::Tracker>,
) -> WebResult<impl Reply> {
debug!("http announce request: {:#?}", announce_request);

let info_hash = announce_request.info_hash;
let remote_client_ip = announce_request.peer_addr;

authenticate(&info_hash, &auth_key, tracker.clone()).await?;
authenticate(&info_hash, &auth_key_id, tracker.clone()).await?;

let mut peer = peer_builder::from_request(&announce_request, &remote_client_ip);

Expand Down Expand Up @@ -77,7 +78,7 @@ pub async fn handle_announce(
/// Will return `warp::Rejection` that wraps the `ServerError` if unable to `send_scrape_response`.
pub async fn handle_scrape(
scrape_request: request::Scrape,
auth_key: Option<auth::Key>,
auth_key_id: Option<KeyId>,
tracker: Arc<tracker::Tracker>,
) -> WebResult<impl Reply> {
let mut files: HashMap<InfoHash, response::ScrapeEntry> = HashMap::new();
Expand All @@ -86,7 +87,7 @@ pub async fn handle_scrape(
for info_hash in &scrape_request.info_hashes {
let scrape_entry = match db.get(info_hash) {
Some(torrent_info) => {
if authenticate(info_hash, &auth_key, tracker.clone()).await.is_ok() {
if authenticate(info_hash, &auth_key_id, tracker.clone()).await.is_ok() {
let (seeders, completed, leechers) = torrent_info.get_stats();
response::ScrapeEntry {
complete: seeders,
Expand Down
6 changes: 3 additions & 3 deletions src/http/warp_implementation/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;

use warp::{Filter, Rejection};

use super::filters::{with_announce_request, with_auth_key, with_scrape_request, with_tracker};
use super::filters::{with_announce_request, with_auth_key_id, with_scrape_request, with_tracker};
use super::handlers::{handle_announce, handle_scrape, send_error};
use crate::tracker;

Expand All @@ -20,7 +20,7 @@ fn announce(tracker: Arc<tracker::Tracker>) -> impl Filter<Extract = impl warp::
warp::path::path("announce")
.and(warp::filters::method::get())
.and(with_announce_request(tracker.config.on_reverse_proxy))
.and(with_auth_key())
.and(with_auth_key_id())
.and(with_tracker(tracker))
.and_then(handle_announce)
}
Expand All @@ -30,7 +30,7 @@ fn scrape(tracker: Arc<tracker::Tracker>) -> impl Filter<Extract = impl warp::Re
warp::path::path("scrape")
.and(warp::filters::method::get())
.and(with_scrape_request(tracker.config.on_reverse_proxy))
.and(with_auth_key())
.and(with_auth_key_id())
.and(with_tracker(tracker))
.and_then(handle_scrape)
}
Loading