Skip to content

Commit b301c55

Browse files
committed
allow custom column types in migrations
1 parent 06e90ae commit b301c55

File tree

5 files changed

+52
-17
lines changed

5 files changed

+52
-17
lines changed

piccolo/apps/migrations/auto/diffable_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
6969
table_class_name=self.class_name,
7070
column_name=i._meta.name,
7171
column_class_name=i.__class__.__name__,
72+
column_class=i.__class__,
7273
params=i._meta.params,
7374
)
7475
for i in (set(self.columns) - set(value.columns))

piccolo/apps/migrations/auto/migration_manager.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,27 @@ def add_column(
186186
table_class_name: str,
187187
tablename: str,
188188
column_name: str,
189-
column_class_name: str,
189+
column_class_name: str = "",
190+
column_class: t.Optional[t.Type[Column]] = None,
190191
params: t.Dict[str, t.Any] = {},
191192
):
192-
column_class = getattr(column_types, column_class_name)
193+
"""
194+
Add a new column to the table.
195+
196+
:param column_class_name:
197+
The column type was traditionally specified as a string, using this
198+
variable. This didn't allow users to define custom column types
199+
though, which is why newer migrations directly reference a
200+
``Column`` subclass using ``column_class``.
201+
:param column_class:
202+
A direct reference to a ``Column`` subclass.
203+
204+
"""
205+
column_class = column_class or getattr(column_types, column_class_name)
206+
207+
if column_class is None:
208+
raise ValueError("Unrecognised column type")
209+
193210
cleaned_params = deserialise_params(params=params)
194211
column = column_class(**cleaned_params)
195212
column._meta.name = column_name

piccolo/apps/migrations/auto/operations.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
from piccolo.columns.base import Column
23
import typing as t
34

45

@@ -39,4 +40,5 @@ class AddColumn:
3940
table_class_name: str
4041
column_name: str
4142
column_class_name: str
43+
column_class: t.Type[Column]
4244
params: t.Dict[str, t.Any]

piccolo/apps/migrations/auto/schema_differ.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,9 @@ def _get_snapshot_table(
298298

299299
@property
300300
def alter_columns(self) -> AlterStatements:
301-
response = []
302-
extra_imports = []
303-
extra_definitions = []
301+
response: t.List[str] = []
302+
extra_imports: t.List[Import] = []
303+
extra_definitions: t.List[str] = []
304304
for table in self.schema:
305305
snapshot_table = self._get_snapshot_table(table.class_name)
306306
if snapshot_table:
@@ -351,30 +351,38 @@ def drop_columns(self) -> AlterStatements:
351351

352352
@property
353353
def add_columns(self) -> AlterStatements:
354-
response = []
355-
extra_imports = []
356-
extra_definitions = []
354+
response: t.List[str] = []
355+
extra_imports: t.List[Import] = []
356+
extra_definitions: t.List[str] = []
357357
for table in self.schema:
358358
snapshot_table = self._get_snapshot_table(table.class_name)
359359
if snapshot_table:
360360
delta: TableDelta = table - snapshot_table
361361
else:
362362
continue
363363

364-
for column in delta.add_columns:
364+
for add_column in delta.add_columns:
365365
if (
366-
column.column_name
366+
add_column.column_name
367367
in self.rename_columns_collection.new_column_names
368368
):
369369
continue
370370

371-
params = serialise_params(column.params)
371+
params = serialise_params(add_column.params)
372372
cleaned_params = params.params
373373
extra_imports.extend(params.extra_imports)
374374
extra_definitions.extend(params.extra_definitions)
375375

376+
column_class = add_column.column_class
377+
extra_imports.append(
378+
Import(
379+
module=column_class.__module__,
380+
target=column_class.__name__,
381+
)
382+
)
383+
376384
response.append(
377-
f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column.column_name}', column_class_name='{column.column_class_name}', params={str(cleaned_params)})" # noqa: E501
385+
f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{add_column.column_name}', column_class_name='{add_column.column_class_name}', column_class={column_class.__name__ if column_class else ''}, params={str(cleaned_params)})" # noqa: E501
378386
)
379387
return AlterStatements(
380388
statements=response,
@@ -399,9 +407,9 @@ def new_table_columns(self) -> AlterStatements:
399407
set(self.schema) - set(self.schema_snapshot)
400408
)
401409

402-
response = []
403-
extra_imports = []
404-
extra_definitions = []
410+
response: t.List[str] = []
411+
extra_imports: t.List[Import] = []
412+
extra_definitions: t.List[str] = []
405413
for table in new_tables:
406414
if (
407415
table.class_name
@@ -417,8 +425,15 @@ def new_table_columns(self) -> AlterStatements:
417425
extra_imports.extend(_params.extra_imports)
418426
extra_definitions.extend(_params.extra_definitions)
419427

428+
extra_imports.append(
429+
Import(
430+
module=column.__class__.__module__,
431+
target=column.__class__.__name__,
432+
)
433+
)
434+
420435
response.append(
421-
f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column._meta.name}', column_class_name='{column.__class__.__name__}', params={str(cleaned_params)})" # noqa: E501
436+
f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column._meta.name}', column_class_name='{column.__class__.__name__}', column_class={column.__class__.__name__}, params={str(cleaned_params)})" # noqa: E501
422437
)
423438
return AlterStatements(
424439
statements=response,

tests/apps/migrations/auto/test_schema_differ.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_add_column(self):
9595
self.assertTrue(len(schema_differ.add_columns.statements) == 1)
9696
self.assertEqual(
9797
schema_differ.add_columns.statements[0],
98-
"manager.add_column(table_class_name='Band', tablename='band', column_name='genre', column_class_name='Varchar', params={'length': 255, 'default': '', 'null': False, 'primary': False, 'key': False, 'unique': False, 'index': False})", # noqa
98+
"manager.add_column(table_class_name='Band', tablename='band', column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary': False, 'key': False, 'unique': False, 'index': False})", # noqa
9999
)
100100

101101
def test_drop_column(self):

0 commit comments

Comments
 (0)