Skip to content

Commit e409557

Browse files
committed
MigrationManager handles changes in column defaults
1 parent 822e968 commit e409557

File tree

7 files changed

+106
-14
lines changed

7 files changed

+106
-14
lines changed

piccolo/apps/migrations/auto/migration_manager.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,19 @@ async def _run_alter_columns(self, backwards=False):
399399
column=column, boolean=unique
400400
).run()
401401

402+
if "default" in params.keys():
403+
# Don't compare to None or ..., as both as valid values.
404+
default = params.get("default")
405+
column = Column()
406+
column._meta._table = _Table
407+
column._meta._name = column_name
408+
if default is ...:
409+
await _Table.alter().drop_default(column=column).run()
410+
else:
411+
await _Table.alter().set_default(
412+
column=column, value=default
413+
).run()
414+
402415
# None is a valid value, so retrieve ellipsis if not found.
403416
digits = params.get("digits", ...)
404417
if digits is not ...:

piccolo/columns/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,24 @@ def __init__(
172172
params=kwargs,
173173
)
174174

175+
def _validate_arguments(
176+
default: t.Any, allowed_types: t.Tuple[t.Type[t.Any]], null: bool
177+
) -> bool:
178+
"""
179+
Make sure that the default values are of the allowed types. Also
180+
make sure that a value of None isn't the default if the column is not
181+
nullable.
182+
"""
183+
if default is None:
184+
if not null:
185+
raise ValueError(
186+
"A default value of None isn't allowed if the column is "
187+
"null = False."
188+
)
189+
return None in allowed_types
190+
else:
191+
return type(default) in allowed_types
192+
175193
def is_in(self, values: Iterable) -> Where:
176194
return Where(column=self, values=values, operator=In)
177195

piccolo/columns/column_types.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77

88
from piccolo.columns.base import Column, OnDelete, OnUpdate, ForeignKeyMeta
99
from piccolo.columns.operators.string import ConcatPostgres, ConcatSQLite
10-
from piccolo.custom_types import UUIDDefault
10+
from piccolo.custom_types import (
11+
DateDefault,
12+
TimeDefault,
13+
TimestampDefault,
14+
UUIDDefault,
15+
)
1116
from piccolo.querystring import Unquoted, QueryString
1217

1318
if t.TYPE_CHECKING:
@@ -120,7 +125,9 @@ class Varchar(Column):
120125
value_type = str
121126
concat_delegate: ConcatDelegate = ConcatDelegate()
122127

123-
def __init__(self, length: int = 255, default: str = "", **kwargs) -> None:
128+
def __init__(
129+
self, length: int = 255, default: t.Union[str, None] = "", **kwargs
130+
) -> None:
124131
self.length = length
125132
self.default = default
126133
kwargs.update({"length": length, "default": default})
@@ -168,7 +175,7 @@ class Text(Column):
168175
value_type = str
169176
concat_delegate: ConcatDelegate = ConcatDelegate()
170177

171-
def __init__(self, default: str = "", **kwargs) -> None:
178+
def __init__(self, default: t.Union[str, None] = "", **kwargs) -> None:
172179
self.default = default
173180
super().__init__(**kwargs)
174181

@@ -205,7 +212,7 @@ class Integer(Column):
205212

206213
math_delegate = MathDelegate()
207214

208-
def __init__(self, default: int = None, **kwargs) -> None:
215+
def __init__(self, default: t.Union[str, None] = 0, **kwargs) -> None:
209216
self.default = default
210217
kwargs.update({"default": default})
211218
super().__init__(**kwargs)
@@ -351,7 +358,9 @@ class Timestamp(Column):
351358

352359
value_type = datetime
353360

354-
def __init__(self, default: TimestampArg = None, **kwargs) -> None:
361+
def __init__(
362+
self, default: TimestampArg = TimestampDefault.now, **kwargs
363+
) -> None:
355364
self.default = default
356365
kwargs.update({"default": default})
357366
super().__init__(**kwargs)
@@ -360,7 +369,7 @@ def __init__(self, default: TimestampArg = None, **kwargs) -> None:
360369
class Date(Column):
361370
value_type = date
362371

363-
def __init__(self, default: DateArg = None, **kwargs) -> None:
372+
def __init__(self, default: DateArg = DateDefault.now, **kwargs) -> None:
364373
self.default = default
365374
kwargs.update({"default": default})
366375
super().__init__(**kwargs)
@@ -369,7 +378,7 @@ def __init__(self, default: DateArg = None, **kwargs) -> None:
369378
class Time(Column):
370379
value_type = time
371380

372-
def __init__(self, default: TimeArg = None, **kwargs) -> None:
381+
def __init__(self, default: TimeArg = TimeDefault.now, **kwargs) -> None:
373382
self.default = default
374383
kwargs.update({"default": default})
375384
super().__init__(**kwargs)
@@ -503,6 +512,8 @@ class ForeignKey(Integer):
503512
def __init__(
504513
self,
505514
references: t.Union[t.Type[Table], str],
515+
default: t.Union[int, None] = None,
516+
null: bool = True,
506517
on_delete: OnDelete = OnDelete.cascade,
507518
on_update: OnUpdate = OnUpdate.cascade,
508519
**kwargs,
@@ -522,7 +533,7 @@ def __init__(
522533
"on_update": on_update,
523534
}
524535
)
525-
super().__init__(**kwargs)
536+
super().__init__(default=default, null=null, **kwargs)
526537

527538
if t.TYPE_CHECKING:
528539
# This is here just for type inference - the actual value is set by

piccolo/query/methods/alter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,13 @@ def drop_column(self, column: t.Union[str, Column]) -> Alter:
276276
self._drop.append(DropColumn(column))
277277
return self
278278

279+
def drop_default(self, column: t.Union[str, Column]) -> Alter:
280+
"""
281+
Band.alter().drop_default(Band.popularity)
282+
"""
283+
self._drop_default.append(DropDefault(column=column))
284+
return self
285+
279286
def drop_table(self) -> Alter:
280287
"""
281288
Band.alter().drop_table()
@@ -486,6 +493,7 @@ def querystrings(self) -> t.Sequence[QueryString]:
486493
self._rename_columns,
487494
self._rename_table,
488495
self._drop,
496+
self._drop_default,
489497
self._unique,
490498
self._null,
491499
self._set_length,

tests/apps/migrations/auto/test_migration_manager.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,22 @@ def test_alter_column_unique(self):
228228
response = self.run_sync("SELECT name FROM manager;")
229229
self.assertEqual(response, [{"name": "Dave"}, {"name": "Dave"}])
230230

231-
def _get_precision_and_scale(self):
231+
def _get_column_precision_and_scale(
232+
self, tablename="ticket", column_name="price"
233+
):
232234
return self.run_sync(
233235
"SELECT numeric_precision, numeric_scale "
234236
"FROM information_schema.COLUMNS "
235-
"WHERE TABLE_NAME = 'ticket' AND column_name = 'price';"
237+
f"WHERE table_name = '{tablename}' AND "
238+
f"column_name = '{column_name}';"
239+
)
240+
241+
def _get_column_default(self, tablename="manager", column_name="name"):
242+
return self.run_sync(
243+
"SELECT column_default "
244+
"FROM information_schema.COLUMNS "
245+
f"WHERE table_name = '{tablename}' "
246+
f"AND column_name = '{column_name}';"
236247
)
237248

238249
@postgres_only
@@ -251,18 +262,44 @@ def test_alter_column_digits(self):
251262
)
252263

253264
asyncio.run(manager.run())
254-
255265
self.assertEqual(
256-
self._get_precision_and_scale(),
266+
self._get_column_precision_and_scale(),
257267
[{"numeric_precision": 6, "numeric_scale": 2}],
258268
)
259269

260270
asyncio.run(manager.run_backwards())
261271
self.assertEqual(
262-
self._get_precision_and_scale(),
272+
self._get_column_precision_and_scale(),
263273
[{"numeric_precision": 5, "numeric_scale": 2}],
264274
)
265275

276+
@postgres_only
277+
def test_alter_column_set_default(self):
278+
"""
279+
Test altering a column default with MigrationManager.
280+
"""
281+
manager = MigrationManager()
282+
283+
manager.alter_column(
284+
table_class_name="Manager",
285+
tablename="manager",
286+
column_name="name",
287+
params={"default": "Unknown"},
288+
old_params={"default": ""},
289+
)
290+
291+
asyncio.run(manager.run())
292+
self.assertEqual(
293+
self._get_column_default(),
294+
[{"column_default": "'Unknown'::character varying"}],
295+
)
296+
297+
asyncio.run(manager.run_backwards())
298+
self.assertEqual(
299+
self._get_column_default(),
300+
[{"column_default": "''::character varying"}],
301+
)
302+
266303
@postgres_only
267304
@patch.object(BaseMigrationManager, "get_migration_managers")
268305
def test_drop_table(self, get_migration_managers: MagicMock):

tests/apps/migrations/auto/test_schema_differ.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,6 @@ def test_alter_column_precision(self):
176176
schema_differ.alter_columns[0],
177177
"manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', params={'digits': (5, 2)}, old_params={'digits': (4, 2)})", # noqa
178178
)
179+
180+
def test_alter_default(self):
181+
pass

tests/table/test_alter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def test_add(self):
7474
"""
7575
self.insert_row()
7676

77-
add_query = Band.alter().add_column("weight", Integer(null=True))
77+
add_query = Band.alter().add_column(
78+
"weight", Integer(null=True, default=None)
79+
)
7880
add_query.run_sync()
7981

8082
response = Band.raw("SELECT * FROM band").run_sync()

0 commit comments

Comments
 (0)