Skip to content

Commit 72eb453

Browse files
committed
scoping transaction to engine
1 parent 01d54ee commit 72eb453

File tree

4 files changed

+79
-19
lines changed

4 files changed

+79
-19
lines changed

piccolo/engine/postgres.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@
1717
from piccolo.utils.sync import run_sync
1818

1919

20-
pg_transaction_connection = contextvars.ContextVar(
21-
"pg_transaction_connection", default=None
22-
)
23-
24-
2520
@dataclass
2621
class AsyncBatch(Batch):
2722

@@ -136,6 +131,8 @@ def run_sync(self):
136131

137132

138133
###############################################################################
134+
135+
139136
class Transaction:
140137
"""
141138
Used for wrapping queries in a transaction, using a context manager.
@@ -162,7 +159,7 @@ async def __aenter__(self):
162159

163160
self.transaction = self.connection.transaction()
164161
await self.transaction.start()
165-
self.context = pg_transaction_connection.set(self.connection)
162+
self.context = self.engine.transaction_connection.set(self.connection)
166163

167164
async def commit(self):
168165
await self.transaction.commit()
@@ -181,9 +178,9 @@ async def __aexit__(self, exception_type, exception, traceback):
181178
else:
182179
self.connection.close()
183180

184-
pg_transaction_connection.reset(self.context)
181+
self.engine.transaction_connection.reset(self.context)
185182

186-
return exception is not None
183+
return exception is None
187184

188185

189186
###############################################################################
@@ -199,6 +196,10 @@ class PostgresEngine(Engine):
199196
def __init__(self, config: t.Dict[str, t.Any]) -> None:
200197
self.config = config
201198
self.pool: t.Optional[Pool] = None
199+
database_name = config.get("database", "Unknown")
200+
self.transaction_connection = contextvars.ContextVar(
201+
f"pg_transaction_connection_{database_name}", default=None
202+
)
202203
super().__init__()
203204

204205
def get_version(self) -> float:
@@ -283,7 +284,7 @@ async def run_querystring(
283284
)
284285

285286
# If running inside a transaction:
286-
connection = pg_transaction_connection.get()
287+
connection = self.transaction_connection.get()
287288
if connection:
288289
return await connection.fetch(query, *query_args)
289290
elif in_pool and self.pool:

piccolo/engine/sqlite.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@
1414
from piccolo.utils.sync import run_sync
1515

1616

17-
sqlite_transaction_connection = contextvars.ContextVar(
18-
"sqlite_transaction_connection", default=None
19-
)
20-
21-
2217
@dataclass
2318
class AsyncBatch(Batch):
2419

@@ -135,7 +130,7 @@ def __init__(self, engine: SQLiteEngine):
135130
async def __aenter__(self):
136131
self.connection = await self.engine.get_connection()
137132
await self.connection.execute("BEGIN")
138-
self.context = sqlite_transaction_connection.set(self.connection)
133+
self.context = self.engine.transaction_connection.set(self.connection)
139134

140135
async def __aexit__(self, exception_type, exception, traceback):
141136
if exception:
@@ -144,7 +139,9 @@ async def __aexit__(self, exception_type, exception, traceback):
144139
await self.connection.execute("COMMIT")
145140

146141
await self.connection.close()
147-
return exception is not None
142+
self.engine.transaction_connection.reset(self.context)
143+
144+
return exception is None
148145

149146

150147
###############################################################################
@@ -186,6 +183,11 @@ def __init__(
186183
}
187184
)
188185
self.connection_kwargs = connection_kwargs
186+
187+
self.transaction_connection = contextvars.ContextVar(
188+
f"sqlite_transaction_connection_{path}", default=None
189+
)
190+
189191
super().__init__()
190192

191193
@property
@@ -297,7 +299,7 @@ async def run_querystring(
297299
)
298300

299301
# If running inside a transaction:
300-
connection = sqlite_transaction_connection.get()
302+
connection = self.transaction_connection.get()
301303
if connection:
302304
return await self._run_in_existing_connection(
303305
connection=connection,
@@ -307,7 +309,7 @@ async def run_querystring(
307309
)
308310

309311
return await self._run_in_new_connection(
310-
query=query, args=query_args, query_type=querystring.query_type,
312+
query=query, args=query_args, query_type=querystring.query_type
311313
)
312314

313315
def atomic(self) -> Atomic:
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import asyncio
2+
from unittest import TestCase
3+
4+
from piccolo.columns.column_types import Varchar
5+
from piccolo.engine.sqlite import SQLiteEngine
6+
from piccolo.table import Table
7+
8+
from ..base import sqlite_only
9+
10+
11+
ENGINE_1 = SQLiteEngine(path="engine1.sqlite")
12+
ENGINE_2 = SQLiteEngine(path="engine2.sqlite")
13+
14+
15+
class Musician(Table, db=ENGINE_1):
16+
name = Varchar(length=100)
17+
18+
19+
class Roadie(Table, db=ENGINE_2):
20+
name = Varchar(length=100)
21+
22+
23+
@sqlite_only
24+
class TestNestedTransaction(TestCase):
25+
def setUp(self):
26+
ENGINE_1.remove_db_file()
27+
ENGINE_2.remove_db_file()
28+
29+
def tearDown(self):
30+
ENGINE_1.remove_db_file()
31+
ENGINE_2.remove_db_file()
32+
33+
async def run_nested(self):
34+
"""
35+
Make sure nested transactions which reference different databases work
36+
as expected.
37+
"""
38+
async with Musician._meta.db.transaction():
39+
await Musician.create_table().run()
40+
await Musician(name="Bob").save().run()
41+
42+
async with Roadie._meta.db.transaction():
43+
await Roadie.create_table().run()
44+
await Roadie(name="Dave").save().run()
45+
46+
self.assertTrue(await Musician.table_exists().run())
47+
musician = await Musician.select("name").first().run()
48+
self.assertTrue(musician["name"] == "Bob")
49+
50+
self.assertTrue(await Roadie.table_exists().run())
51+
roadie = await Roadie.select("name").first().run()
52+
self.assertTrue(roadie["name"] == "Dave")
53+
54+
def test_nested(self):
55+
asyncio.run(self.run_nested())

tests/engine/test_transaction.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import asyncio
22
from unittest import TestCase
33

4+
from piccolo.engine.sqlite import SQLiteEngine
5+
46
from ..example_project.tables import Band, Manager
5-
from ..base import postgres_only
7+
from ..base import postgres_only, sqlite_only
68

79

810
class TestAtomic(TestCase):

0 commit comments

Comments
 (0)