1+ from __future__ import annotations
12from dataclasses import dataclass
3+ import logging
24import os
35import sqlite3
46import typing as t
1113from piccolo .querystring import QueryString
1214
1315
16+ logger = logging .getLogger (__file__ )
17+
18+
1419@dataclass
1520class AsyncBatch (Batch ):
1621
@@ -41,7 +46,7 @@ async def __anext__(self):
4146 return response
4247
4348 async def __aenter__ (self ):
44- querystring = self .query .querystring [0 ]
49+ querystring = self .query .querystrings [0 ]
4550 template , template_args = querystring .compile_string ()
4651
4752 self ._cursor = await self .connection .execute (template , * template_args )
@@ -56,6 +61,52 @@ async def __aexit__(self, exc_type, exc, tb):
5661###############################################################################
5762
5863
64+ class Transaction :
65+ """
66+ Usage:
67+
68+ transaction = engine.transaction()
69+ transaction.add(Foo.create_table())
70+
71+ # Either:
72+ transaction.run_sync()
73+ await transaction.run()
74+ """
75+
76+ __slots__ = ("engine" , "queries" )
77+
78+ def __init__ (self , engine : SQLiteEngine ):
79+ self .engine = engine
80+ self .queries : t .List [Query ] = []
81+
82+ def add (self , * query : Query ):
83+ self .queries += list (query )
84+
85+ async def run (self ):
86+ for query in self .queries :
87+ await self .engine .run ("BEGIN" )
88+ try :
89+ for querystring in query .querystrings :
90+ await self .engine .run (
91+ * querystring .compile_string (
92+ engine_type = self .engine_type
93+ ),
94+ query_type = querystring .query_type ,
95+ )
96+ except Exception as exception :
97+ logger .error (exception )
98+ await self .engine .run ("ROLLBACK" )
99+ else :
100+ await self .engine .run ("COMMIT" )
101+ self .queries = []
102+
103+ def run_sync (self ):
104+ return run_sync (self ._run ())
105+
106+
107+ ###############################################################################
108+
109+
59110def dict_factory (cursor , row ):
60111 d = {}
61112 for idx , col in enumerate (cursor .description ):
@@ -120,7 +171,9 @@ async def batch(self, query: Query, batch_size=100) -> AsyncBatch:
120171
121172 async def get_connection (self ) -> Connection :
122173 connection = await aiosqlite .connect (
123- self .path , detect_types = sqlite3 .PARSE_DECLTYPES
174+ self .path ,
175+ detect_types = sqlite3 .PARSE_DECLTYPES ,
176+ isolation_level = None ,
124177 )
125178 connection .row_factory = dict_factory
126179 return connection
0 commit comments