Skip to content

Commit 76594b7

Browse files
committed
feat(tracker-core): add async sqlx mysql driver in parallel module
1 parent 2fb25a1 commit 76594b7

6 files changed

Lines changed: 648 additions & 0 deletions

File tree

packages/tracker-core/src/databases/sqlx/driver/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![allow(dead_code)]
22

3+
pub mod mysql;
34
pub mod sqlite;
45

56
#[cfg(test)]
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
use ::sqlx::Row;
2+
use async_trait::async_trait;
3+
use torrust_tracker_primitives::DurationSinceUnixEpoch;
4+
5+
use super::{MysqlSqlx, DRIVER};
6+
use crate::authentication::{self, Key};
7+
use crate::databases::error::Error;
8+
use crate::databases::sqlx::traits::AsyncAuthKeyStore;
9+
10+
#[async_trait]
11+
impl AsyncAuthKeyStore for MysqlSqlx {
12+
async fn load_keys(&self) -> Result<Vec<authentication::PeerKey>, Error> {
13+
self.ensure_schema().await?;
14+
15+
let rows = ::sqlx::query("SELECT `key`, valid_until FROM `keys`")
16+
.fetch_all(&self.pool)
17+
.await
18+
.map_err(|e| (e, DRIVER))?;
19+
20+
rows.into_iter()
21+
.map(|row| {
22+
let key_value: String = row.try_get("key").map_err(|e| (e, DRIVER))?;
23+
let valid_until: Option<i64> = row.try_get("valid_until").map_err(|e| (e, DRIVER))?;
24+
25+
let parsed_key = key_value.parse::<Key>().map_err(|e| Error::MalformedDatabaseRecord {
26+
message: e.to_string(),
27+
driver: DRIVER,
28+
})?;
29+
30+
Ok(match valid_until {
31+
Some(value) => authentication::PeerKey {
32+
key: parsed_key,
33+
valid_until: Some(DurationSinceUnixEpoch::from_secs(value.unsigned_abs())),
34+
},
35+
None => authentication::PeerKey {
36+
key: parsed_key,
37+
valid_until: None,
38+
},
39+
})
40+
})
41+
.collect()
42+
}
43+
44+
async fn get_key_from_keys(&self, key: &Key) -> Result<Option<authentication::PeerKey>, Error> {
45+
self.ensure_schema().await?;
46+
47+
let maybe_row = ::sqlx::query("SELECT `key`, valid_until FROM `keys` WHERE `key` = ?")
48+
.bind(key.to_string())
49+
.fetch_optional(&self.pool)
50+
.await
51+
.map_err(|e| (e, DRIVER))?;
52+
53+
maybe_row
54+
.map(|row| {
55+
let key_value: String = row.try_get("key").map_err(|e| (e, DRIVER))?;
56+
let valid_until: Option<i64> = row.try_get("valid_until").map_err(|e| (e, DRIVER))?;
57+
58+
let parsed_key = key_value.parse::<Key>().map_err(|e| Error::MalformedDatabaseRecord {
59+
message: e.to_string(),
60+
driver: DRIVER,
61+
})?;
62+
63+
Ok(match valid_until {
64+
Some(value) => authentication::PeerKey {
65+
key: parsed_key,
66+
valid_until: Some(DurationSinceUnixEpoch::from_secs(value.unsigned_abs())),
67+
},
68+
None => authentication::PeerKey {
69+
key: parsed_key,
70+
valid_until: None,
71+
},
72+
})
73+
})
74+
.transpose()
75+
}
76+
77+
async fn add_key_to_keys(&self, auth_key: &authentication::PeerKey) -> Result<usize, Error> {
78+
self.ensure_schema().await?;
79+
80+
let valid_until = auth_key
81+
.valid_until
82+
.map(|value| {
83+
i64::try_from(value.as_secs()).map_err(|e| Error::MalformedDatabaseRecord {
84+
message: e.to_string(),
85+
driver: DRIVER,
86+
})
87+
})
88+
.transpose()?;
89+
90+
let insert = ::sqlx::query("INSERT INTO `keys` (`key`, valid_until) VALUES (?, ?)")
91+
.bind(auth_key.key.to_string())
92+
.bind(valid_until)
93+
.execute(&self.pool)
94+
.await
95+
.map_err(|e| (e, DRIVER))?
96+
.rows_affected();
97+
98+
if insert == 0 {
99+
Err(Error::InsertFailed {
100+
location: std::panic::Location::caller(),
101+
driver: DRIVER,
102+
})
103+
} else {
104+
Ok(usize::try_from(insert).unwrap_or(0))
105+
}
106+
}
107+
108+
async fn remove_key_from_keys(&self, key: &Key) -> Result<usize, Error> {
109+
self.ensure_schema().await?;
110+
111+
let deleted = ::sqlx::query("DELETE FROM `keys` WHERE `key` = ?")
112+
.bind(key.to_string())
113+
.execute(&self.pool)
114+
.await
115+
.map_err(|e| (e, DRIVER))?
116+
.rows_affected();
117+
118+
if deleted == 1 {
119+
Ok(1)
120+
} else {
121+
Err(Error::DeleteFailed {
122+
location: std::panic::Location::caller(),
123+
error_code: usize::try_from(deleted).unwrap_or(0),
124+
driver: DRIVER,
125+
})
126+
}
127+
}
128+
}
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
#![allow(dead_code)]
2+
3+
use std::str::FromStr;
4+
use std::sync::atomic::{AtomicBool, Ordering};
5+
6+
use ::sqlx::mysql::{MySqlConnectOptions, MySqlPoolOptions};
7+
use ::sqlx::{MySqlPool, Row};
8+
use tokio::sync::Mutex;
9+
use torrust_tracker_primitives::NumberOfDownloads;
10+
11+
use crate::databases::driver::Driver;
12+
use crate::databases::error::Error;
13+
use crate::databases::sqlx::traits::AsyncSchemaMigrator;
14+
15+
mod auth_key_store;
16+
mod schema_migrator;
17+
mod torrent_metrics_store;
18+
mod whitelist_store;
19+
20+
const DRIVER: Driver = Driver::MySQL;
21+
22+
pub(crate) struct MysqlSqlx {
23+
pool: MySqlPool,
24+
schema_ready: AtomicBool,
25+
schema_lock: Mutex<()>,
26+
}
27+
28+
impl MysqlSqlx {
29+
pub fn new(db_path: &str) -> Result<Self, Error> {
30+
let options = MySqlConnectOptions::from_str(db_path).map_err(|e| (e, DRIVER))?;
31+
32+
let pool = MySqlPoolOptions::new().connect_lazy_with(options);
33+
34+
Ok(Self {
35+
pool,
36+
schema_ready: AtomicBool::new(false),
37+
schema_lock: Mutex::new(()),
38+
})
39+
}
40+
41+
async fn ensure_schema(&self) -> Result<(), Error> {
42+
if self.schema_ready.load(Ordering::Acquire) {
43+
return Ok(());
44+
}
45+
46+
let _guard = self.schema_lock.lock().await;
47+
if self.schema_ready.load(Ordering::Acquire) {
48+
return Ok(());
49+
}
50+
51+
self.create_database_tables().await?;
52+
self.schema_ready.store(true, Ordering::Release);
53+
54+
Ok(())
55+
}
56+
57+
async fn load_torrent_aggregate_metric(&self, metric_name: &str) -> Result<Option<NumberOfDownloads>, Error> {
58+
let maybe_row = ::sqlx::query("SELECT value FROM torrent_aggregate_metrics WHERE metric_name = ?")
59+
.bind(metric_name)
60+
.fetch_optional(&self.pool)
61+
.await
62+
.map_err(|e| (e, DRIVER))?;
63+
64+
maybe_row
65+
.map(|row| {
66+
let value: i64 = row.try_get("value").map_err(|e| (e, DRIVER))?;
67+
u32::try_from(value).map_err(|e| Error::MalformedDatabaseRecord {
68+
message: e.to_string(),
69+
driver: DRIVER,
70+
})
71+
})
72+
.transpose()
73+
}
74+
75+
async fn save_torrent_aggregate_metric(&self, metric_name: &str, completed: NumberOfDownloads) -> Result<(), Error> {
76+
let insert = ::sqlx::query(
77+
"INSERT INTO torrent_aggregate_metrics (metric_name, value) VALUES (?, ?) ON DUPLICATE KEY UPDATE value = VALUES(value)",
78+
)
79+
.bind(metric_name)
80+
.bind(i64::from(completed))
81+
.execute(&self.pool)
82+
.await
83+
.map_err(|e| (e, DRIVER))?
84+
.rows_affected();
85+
86+
if insert == 0 {
87+
Err(Error::InsertFailed {
88+
location: std::panic::Location::caller(),
89+
driver: DRIVER,
90+
})
91+
} else {
92+
Ok(())
93+
}
94+
}
95+
}
96+
97+
#[cfg(all(test, feature = "db-compatibility-tests"))]
98+
mod tests {
99+
use std::sync::Arc;
100+
101+
use testcontainers::core::IntoContainerPort;
102+
use testcontainers::runners::AsyncRunner;
103+
use testcontainers::{ContainerAsync, GenericImage, ImageExt};
104+
use torrust_tracker_configuration::Core;
105+
106+
use super::MysqlSqlx;
107+
use crate::databases::sqlx::driver::tests::run_tests;
108+
use crate::databases::sqlx::traits::AsyncDatabase;
109+
110+
#[derive(Debug, Default)]
111+
struct StoppedMysqlContainer {}
112+
113+
impl StoppedMysqlContainer {
114+
async fn run(self, config: &MysqlConfiguration) -> Result<RunningMysqlContainer, Box<dyn std::error::Error + 'static>> {
115+
let image_tag = std::env::var("TORRUST_TRACKER_CORE_MYSQL_DRIVER_IMAGE_TAG").unwrap_or_else(|_| "8.0".to_string());
116+
117+
let container = GenericImage::new("mysql", image_tag.as_str())
118+
.with_exposed_port(config.internal_port.tcp())
119+
.with_env_var("MYSQL_ROOT_PASSWORD", config.db_root_password.clone())
120+
.with_env_var("MYSQL_DATABASE", config.database.clone())
121+
.with_env_var("MYSQL_ROOT_HOST", "%")
122+
.start()
123+
.await?;
124+
125+
Ok(RunningMysqlContainer::new(container, config.internal_port))
126+
}
127+
}
128+
129+
struct RunningMysqlContainer {
130+
container: ContainerAsync<GenericImage>,
131+
internal_port: u16,
132+
}
133+
134+
impl RunningMysqlContainer {
135+
fn new(container: ContainerAsync<GenericImage>, internal_port: u16) -> Self {
136+
Self {
137+
container,
138+
internal_port,
139+
}
140+
}
141+
142+
async fn stop(self) {
143+
self.container.stop().await.unwrap();
144+
}
145+
146+
async fn get_host(&self) -> url::Host {
147+
self.container.get_host().await.unwrap()
148+
}
149+
150+
async fn get_host_port_ipv4(&self) -> u16 {
151+
self.container.get_host_port_ipv4(self.internal_port).await.unwrap()
152+
}
153+
}
154+
155+
impl Default for MysqlConfiguration {
156+
fn default() -> Self {
157+
Self {
158+
internal_port: 3306,
159+
database: "torrust_tracker_test".to_string(),
160+
db_user: "root".to_string(),
161+
db_root_password: "test".to_string(),
162+
}
163+
}
164+
}
165+
166+
struct MysqlConfiguration {
167+
pub internal_port: u16,
168+
pub database: String,
169+
pub db_user: String,
170+
pub db_root_password: String,
171+
}
172+
173+
fn core_configuration(host: &url::Host, port: u16, mysql_configuration: &MysqlConfiguration) -> Core {
174+
let mut config = Core::default();
175+
176+
let database = mysql_configuration.database.clone();
177+
let db_user = mysql_configuration.db_user.clone();
178+
let db_password = mysql_configuration.db_root_password.clone();
179+
180+
config.database.path = format!("mysql://{db_user}:{db_password}@{host}:{port}/{database}");
181+
182+
config
183+
}
184+
185+
fn initialize_driver(config: &Core) -> Arc<Box<dyn AsyncDatabase>> {
186+
Arc::new(Box::new(MysqlSqlx::new(&config.database.path).unwrap()))
187+
}
188+
189+
#[tokio::test]
190+
async fn run_mysql_sqlx_driver_tests() -> Result<(), Box<dyn std::error::Error + 'static>> {
191+
if std::env::var("TORRUST_TRACKER_CORE_RUN_MYSQL_DRIVER_TEST").is_err() {
192+
println!("Skipping the MySQL sqlx driver tests.");
193+
return Ok(());
194+
}
195+
196+
let mysql_configuration = MysqlConfiguration::default();
197+
198+
let stopped_mysql_container = StoppedMysqlContainer::default();
199+
200+
let mysql_container = stopped_mysql_container.run(&mysql_configuration).await.unwrap();
201+
202+
let host = mysql_container.get_host().await;
203+
let port = mysql_container.get_host_port_ipv4().await;
204+
205+
let config = core_configuration(&host, port, &mysql_configuration);
206+
207+
let driver = initialize_driver(&config);
208+
209+
run_tests(&driver).await;
210+
211+
mysql_container.stop().await;
212+
213+
Ok(())
214+
}
215+
}

0 commit comments

Comments
 (0)