Skip to content

Commit 5e7d9a6

Browse files
authored
Array and JSON column migration fixes (piccolo-orm#213)
* create DDL class * some test fixes * fix mypy warnings * add tests for JSON and JSONB * add tests for arrays of varchars * reduce code repetition
1 parent 1c42d6a commit 5e7d9a6

File tree

11 files changed

+312
-111
lines changed

11 files changed

+312
-111
lines changed

piccolo/apps/migrations/auto/serialisation.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@
2121
###############################################################################
2222

2323

24+
def check_equality(self, other):
25+
if getattr(other, "__hash__", None) is not None:
26+
return self.__hash__() == other.__hash__()
27+
else:
28+
return False
29+
30+
2431
@dataclass
2532
class SerialisedBuiltin:
2633
builtin: t.Any
@@ -29,7 +36,7 @@ def __hash__(self):
2936
return hash(self.builtin.__name__)
3037

3138
def __eq__(self, other):
32-
return self.__hash__() == other.__hash__()
39+
return check_equality(self, other)
3340

3441
def __repr__(self):
3542
return self.builtin.__name__
@@ -43,7 +50,7 @@ def __hash__(self):
4350
return self.instance.__hash__()
4451

4552
def __eq__(self, other):
46-
return self.__hash__() == other.__hash__()
53+
return check_equality(self, other)
4754

4855
def __repr__(self):
4956
return repr_class_instance(self.instance)
@@ -58,7 +65,7 @@ def __hash__(self):
5865
return self.instance.__hash__()
5966

6067
def __eq__(self, other):
61-
return self.__hash__() == other.__hash__()
68+
return check_equality(self, other)
6269

6370
def __repr__(self):
6471
args = ", ".join(
@@ -78,7 +85,7 @@ def __hash__(self):
7885
return hash(self.__repr__())
7986

8087
def __eq__(self, other):
81-
return self.__hash__() == other.__hash__()
88+
return check_equality(self, other)
8289

8390
def __repr__(self):
8491
return f"{self.instance.__class__.__name__}.{self.instance.name}"
@@ -94,7 +101,7 @@ def __hash__(self):
94101
)
95102

96103
def __eq__(self, other):
97-
return self.__hash__() == other.__hash__()
104+
return check_equality(self, other)
98105

99106
def __repr__(self):
100107
tablename = self.table_type._meta.tablename
@@ -126,7 +133,7 @@ def __hash__(self):
126133
return hash(self.__repr__())
127134

128135
def __eq__(self, other):
129-
return self.__hash__() == other.__hash__()
136+
return check_equality(self, other)
130137

131138
def __repr__(self):
132139
class_name = self.enum_type.__name__
@@ -142,7 +149,7 @@ def __hash__(self):
142149
return hash(self.callable_.__name__)
143150

144151
def __eq__(self, other):
145-
return self.__hash__() == other.__hash__()
152+
return check_equality(self, other)
146153

147154
def __repr__(self):
148155
return self.callable_.__name__
@@ -156,7 +163,7 @@ def __hash__(self):
156163
return self.instance.int
157164

158165
def __eq__(self, other):
159-
return self.__hash__() == other.__hash__()
166+
return check_equality(self, other)
160167

161168
def __repr__(self):
162169
return f"UUID('{str(self.instance)}')"

piccolo/columns/base.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
NotLike,
3333
)
3434
from piccolo.columns.reference import LazyTableReference
35-
from piccolo.querystring import QueryString
3635
from piccolo.utils.warnings import colored_warning
3736

3837
if t.TYPE_CHECKING: # pragma: no cover
@@ -579,7 +578,16 @@ def get_sql_value(self, value: t.Any) -> t.Any:
579578
elif isinstance(value, list):
580579
# Convert to the array syntax.
581580
output = (
582-
"'{" + ", ".join([self.get_sql_value(i) for i in value]) + "}'"
581+
"'{"
582+
+ ", ".join(
583+
[
584+
f'"{i}"'
585+
if isinstance(i, str)
586+
else str(self.get_sql_value(i))
587+
for i in value
588+
]
589+
)
590+
+ "}'"
583591
)
584592
else:
585593
output = value
@@ -591,7 +599,7 @@ def column_type(self):
591599
return self.__class__.__name__.upper()
592600

593601
@property
594-
def querystring(self) -> QueryString:
602+
def ddl(self) -> str:
595603
"""
596604
Used when creating tables.
597605
"""
@@ -621,16 +629,9 @@ def querystring(self) -> QueryString:
621629
if not self._meta.primary_key:
622630
default = self.get_default_value()
623631
sql_value = self.get_sql_value(value=default)
624-
# Escape the value if it contains a pair of curly braces, otherwise
625-
# an empty value will appear in the compiled querystring.
626-
sql_value = (
627-
sql_value.replace("{}", "{{}}")
628-
if isinstance(sql_value, str)
629-
else sql_value
630-
)
631632
query += f" DEFAULT {sql_value}"
632633

633-
return QueryString(query)
634+
return query
634635

635636
def copy(self) -> Column:
636637
column: Column = copy.copy(self)
@@ -645,7 +646,7 @@ def __deepcopy__(self, memo) -> Column:
645646
return self.copy()
646647

647648
def __str__(self):
648-
return self.querystring.__str__()
649+
return self.ddl.__str__()
649650

650651
def __repr__(self):
651652
try:

piccolo/engine/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ async def batch(self, query: Query, batch_size: int = 100) -> Batch:
5050
async def run_querystring(self, querystring: QueryString, in_pool: bool):
5151
pass
5252

53+
@abstractmethod
54+
async def run_ddl(self, ddl: str, in_pool: bool = True):
55+
pass
56+
5357
async def check_version(self):
5458
"""
5559
Warn if the database version isn't supported.

piccolo/engine/postgres.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from piccolo.engine.base import Batch, Engine
88
from piccolo.engine.exceptions import TransactionError
9-
from piccolo.query.base import Query
9+
from piccolo.query.base import DDL, Query
1010
from piccolo.querystring import QueryString
1111
from piccolo.utils.lazy_loader import LazyLoader
1212
from piccolo.utils.sync import run_sync
@@ -100,11 +100,15 @@ def add(self, *query: Query):
100100
async def _run_queries(self, connection):
101101
async with connection.transaction():
102102
for query in self.queries:
103-
for querystring in query.querystrings:
104-
_query, args = querystring.compile_string(
105-
engine_type=self.engine.engine_type
106-
)
107-
await connection.execute(_query, *args)
103+
if isinstance(query, Query):
104+
for querystring in query.querystrings:
105+
_query, args = querystring.compile_string(
106+
engine_type=self.engine.engine_type
107+
)
108+
await connection.execute(_query, *args)
109+
elif isinstance(query, DDL):
110+
for ddl in query.ddl:
111+
await connection.execute(ddl)
108112

109113
self.queries = []
110114

@@ -391,6 +395,19 @@ async def run_querystring(
391395
else:
392396
return await self._run_in_new_connection(query, query_args)
393397

398+
async def run_ddl(self, ddl: str, in_pool: bool = True):
399+
if self.log_queries:
400+
print(ddl)
401+
402+
# If running inside a transaction:
403+
connection = self.transaction_connection.get()
404+
if connection:
405+
return await connection.fetch(ddl)
406+
elif in_pool and self.pool:
407+
return await self._run_in_pool(ddl)
408+
else:
409+
return await self._run_in_new_connection(ddl)
410+
394411
def atomic(self) -> Atomic:
395412
return Atomic(engine=self)
396413

piccolo/engine/sqlite.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from piccolo.engine.base import Batch, Engine
1313
from piccolo.engine.exceptions import TransactionError
14-
from piccolo.query.base import Query
14+
from piccolo.query.base import DDL, Query
1515
from piccolo.querystring import QueryString
1616
from piccolo.utils.encoding import dump_json, load_json
1717
from piccolo.utils.lazy_loader import LazyLoader
@@ -250,12 +250,17 @@ async def run(self):
250250

251251
try:
252252
for query in self.queries:
253-
for querystring in query.querystrings:
254-
await connection.execute(
255-
*querystring.compile_string(
256-
engine_type=self.engine.engine_type
253+
if isinstance(query, Query):
254+
for querystring in query.querystrings:
255+
await connection.execute(
256+
*querystring.compile_string(
257+
engine_type=self.engine.engine_type
258+
)
257259
)
258-
)
260+
elif isinstance(query, DDL):
261+
for ddl in query.ddl:
262+
await connection.execute(ddl)
263+
259264
except Exception as exception:
260265
await connection.execute("ROLLBACK")
261266
await connection.close()
@@ -513,6 +518,23 @@ async def run_querystring(
513518
table=querystring.table,
514519
)
515520

521+
async def run_ddl(self, ddl: str, in_pool: bool = False):
522+
"""
523+
Connection pools aren't currently supported - the argument is there
524+
for consistency with other engines.
525+
"""
526+
# If running inside a transaction:
527+
connection = self.transaction_connection.get()
528+
if connection:
529+
return await self._run_in_existing_connection(
530+
connection=connection,
531+
query=ddl,
532+
)
533+
534+
return await self._run_in_new_connection(
535+
query=ddl,
536+
)
537+
516538
def atomic(self) -> Atomic:
517539
return Atomic(engine=self)
518540

piccolo/query/base.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,92 @@ def __getattr__(self, name: str):
317317

318318
def __str__(self) -> str:
319319
return self.query.__str__()
320+
321+
322+
class DDL:
323+
324+
__slots__ = ("table",)
325+
326+
def __init__(self, table: t.Type[Table], **kwargs):
327+
self.table = table
328+
329+
@property
330+
def engine_type(self) -> str:
331+
engine = self.table._meta.db
332+
if engine:
333+
return engine.engine_type
334+
else:
335+
raise ValueError("Engine isn't defined.")
336+
337+
@property
338+
def sqlite_ddl(self) -> t.Sequence[str]:
339+
raise NotImplementedError
340+
341+
@property
342+
def postgres_ddl(self) -> t.Sequence[str]:
343+
raise NotImplementedError
344+
345+
@property
346+
def default_ddl(self) -> t.Sequence[str]:
347+
raise NotImplementedError
348+
349+
@property
350+
def ddl(self) -> t.Sequence[str]:
351+
"""
352+
Calls the correct underlying method, depending on the current engine.
353+
"""
354+
engine_type = self.engine_type
355+
if engine_type == "postgres":
356+
try:
357+
return self.postgres_ddl
358+
except NotImplementedError:
359+
return self.default_ddl
360+
elif engine_type == "sqlite":
361+
try:
362+
return self.sqlite_ddl
363+
except NotImplementedError:
364+
return self.default_ddl
365+
else:
366+
raise Exception(
367+
f"No querystring found for the {engine_type} engine."
368+
)
369+
370+
def __await__(self):
371+
"""
372+
If the user doesn't explicity call .run(), proxy to it as a
373+
convenience.
374+
"""
375+
return self.run().__await__()
376+
377+
async def run(self, in_pool=True):
378+
engine = self.table._meta.db
379+
if not engine:
380+
raise ValueError(
381+
f"Table {self.table._meta.tablename} has no db defined in "
382+
"_meta"
383+
)
384+
385+
if len(self.ddl) == 1:
386+
return await engine.run_ddl(self.ddl[0], in_pool=in_pool)
387+
else:
388+
responses = []
389+
# TODO - run in a transaction
390+
for ddl in self.ddl:
391+
response = await engine.run_ddl(ddl, in_pool=in_pool)
392+
responses.append(response)
393+
return responses
394+
395+
def run_sync(self, timed=False, *args, **kwargs):
396+
"""
397+
A convenience method for running the coroutine synchronously.
398+
"""
399+
coroutine = self.run(*args, **kwargs, in_pool=False)
400+
401+
if timed:
402+
with Timer():
403+
return run_sync(coroutine)
404+
else:
405+
return run_sync(coroutine)
406+
407+
def __str__(self) -> str:
408+
return self.ddl.__str__()

0 commit comments

Comments
 (0)