Skip to content

Commit e8c72f4

Browse files
committed
refactored so tuple is passed in, to help with auto migrations
1 parent 7f30d73 commit e8c72f4

File tree

11 files changed

+149
-43
lines changed

11 files changed

+149
-43
lines changed

piccolo/apps/migrations/auto/diffable_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
131131
for key, _ in delta.items()
132132
}
133133
)
134+
134135
if delta:
135136
alter_columns.append(
136137
AlterColumn(

piccolo/apps/migrations/auto/migration_manager.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,13 @@ async def _run_alter_columns(self, backwards=False):
399399
column=column, boolean=unique
400400
).run()
401401

402+
# None is a valid value, so retrieve ellipsis if not found.
403+
digits = params.get("digits", ...)
404+
if digits is not ...:
405+
await _Table.alter().set_digits(
406+
column=column.column_name, digits=digits,
407+
).run()
408+
402409
async def _run_drop_tables(self, backwards=False):
403410
if backwards:
404411
for diffable_table in self.drop_tables:
@@ -410,7 +417,9 @@ async def _run_drop_tables(self, backwards=False):
410417
await _Table.create_table().run()
411418
else:
412419
for diffable_table in self.drop_tables:
413-
await diffable_table.to_table_class().alter().drop_table().run()
420+
await (
421+
diffable_table.to_table_class().alter().drop_table().run()
422+
)
414423

415424
async def _run_drop_columns(self, backwards=False):
416425
if backwards:

piccolo/columns/column_types.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -244,30 +244,41 @@ class Numeric(Column):
244244

245245
@property
246246
def column_type(self):
247-
if self.precision and self.scale:
247+
if self.digits:
248248
return f"NUMERIC({self.precision}, {self.scale})"
249249
else:
250250
return "NUMERIC"
251251

252+
@property
253+
def precision(self):
254+
"""
255+
The total number of digits allowed.
256+
"""
257+
return self.digits[0]
258+
259+
@property
260+
def scale(self):
261+
"""
262+
The number of digits after the decimal point.
263+
"""
264+
return self.digits[1]
265+
252266
def __init__(
253267
self,
254-
precision: t.Optional[int] = None,
255-
scale: t.Optional[int] = None,
268+
digits: t.Optional[t.Tuple[int, int]] = None,
256269
default: decimal.Decimal = decimal.Decimal(0.0),
257270
**kwargs,
258271
) -> None:
259-
if (precision, scale).count(None) == 1:
272+
if isinstance(digits, tuple) and len(digits) != 2:
260273
raise ValueError(
261-
"The precision and scale args should either both be None, or "
262-
"neither be None."
274+
"The `digits` argument should be a tuple of length 2, with "
275+
"the first value being the precision, and the second value "
276+
"being the scale."
263277
)
264278

265279
self.default = default
266-
self.precision = precision
267-
self.scale = scale
268-
kwargs.update(
269-
{"default": default, "precision": precision, "scale": scale}
270-
)
280+
self.digits = digits
281+
kwargs.update({"default": default, "digits": digits})
271282
super().__init__(**kwargs)
272283

273284

piccolo/query/methods/alter.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -186,20 +186,21 @@ def querystring(self) -> QueryString:
186186

187187

188188
@dataclass
189-
class SetPrecision(AlterColumnStatement):
189+
class SetDigits(AlterColumnStatement):
190190

191-
__slots__ = ("precision", "scale", "column_type")
191+
__slots__ = ("digits", "column_type")
192192

193-
precision: t.Optional[int]
194-
scale: t.Optional[int]
193+
digits: t.Optional[t.Tuple[int, int]]
195194
column_type: str
196195

197196
@property
198197
def querystring(self) -> QueryString:
199-
if self.precision and self.scale:
198+
if self.digits is not None:
199+
precision = self.digits[0]
200+
scale = self.digits[1]
200201
return QueryString(
201202
f"ALTER COLUMN {self.column_name} TYPE "
202-
f"{self.column_type}({self.precision}, {self.scale})"
203+
f"{self.column_type}({precision}, {scale})"
203204
)
204205
else:
205206
return QueryString(
@@ -220,7 +221,7 @@ class Alter(Query):
220221
"_rename_table",
221222
"_set_length",
222223
"_unique",
223-
"_set_precision",
224+
"_set_digits",
224225
)
225226

226227
def __init__(self, table: t.Type[Table]):
@@ -235,7 +236,7 @@ def __init__(self, table: t.Type[Table]):
235236
self._rename_table: t.List[RenameTable] = []
236237
self._set_length: t.List[SetLength] = []
237238
self._unique: t.List[Unique] = []
238-
self._set_precision: t.List[SetPrecision] = []
239+
self._set_digits: t.List[SetDigits] = []
239240

240241
def add_column(self, name: str, column: Column) -> Alter:
241242
"""
@@ -374,24 +375,18 @@ def add_foreign_key_constraint(
374375
)
375376
return self
376377

377-
def set_precision(
378+
def set_digits(
378379
self,
379380
column: t.Union[str, Numeric],
380-
precision: t.Optional[int],
381-
scale: t.Optional[int],
381+
digits: t.Optional[t.Tuple[int, int]],
382382
):
383383
column_type = (
384384
column.__class__.__name__.upper()
385385
if isinstance(column, Numeric)
386386
else "NUMERIC"
387387
)
388-
self._set_precision.append(
389-
SetPrecision(
390-
precision=precision,
391-
scale=scale,
392-
column=column,
393-
column_type=column_type,
394-
)
388+
self._set_digits.append(
389+
SetDigits(digits=digits, column=column, column_type=column_type,)
395390
)
396391
return self
397392

@@ -459,7 +454,7 @@ def querystrings(self) -> t.Sequence[QueryString]:
459454
self._unique,
460455
self._null,
461456
self._set_length,
462-
self._set_precision,
457+
self._set_digits,
463458
)
464459
]
465460

tests/apps/migrations/auto/test_migration_manager.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,9 @@ def test_rename_table(self):
167167
self.assertEqual(response, [{"id": 1, "name": "Dave"}])
168168

169169
@postgres_only
170-
def test_alter_column(self):
170+
def test_alter_column_unique(self):
171171
"""
172-
Test altering a column with MigrationManager.
172+
Test altering a column uniqueness with MigrationManager.
173173
"""
174174
manager = MigrationManager()
175175

@@ -185,7 +185,8 @@ def test_alter_column(self):
185185

186186
with self.assertRaises(UniqueViolationError):
187187
self.run_sync(
188-
"INSERT INTO manager VALUES (default, 'Dave'), (default, 'Dave');"
188+
"INSERT INTO manager VALUES "
189+
"(default, 'Dave'), (default, 'Dave');"
189190
)
190191

191192
# Reverse
@@ -196,6 +197,41 @@ def test_alter_column(self):
196197
response = self.run_sync("SELECT name FROM manager;")
197198
self.assertEqual(response, [{"name": "Dave"}, {"name": "Dave"}])
198199

200+
def _get_precision_and_scale(self):
201+
return self.run_sync(
202+
"SELECT numeric_precision, numeric_scale "
203+
"FROM information_schema.COLUMNS "
204+
"WHERE TABLE_NAME = 'ticket' AND column_name = 'price';"
205+
)
206+
207+
@postgres_only
208+
def test_alter_column_digits(self):
209+
"""
210+
Test altering a column digits with MigrationManager.
211+
"""
212+
manager = MigrationManager()
213+
214+
manager.alter_column(
215+
table_class_name="Ticket",
216+
tablename="ticket",
217+
column_name="price",
218+
params={"digits": (6, 2)},
219+
old_params={"digits": (5, 2)},
220+
)
221+
222+
asyncio.run(manager.run())
223+
224+
self.assertEqual(
225+
self._get_precision_and_scale(),
226+
[{"numeric_precision": 6, "numeric_scale": 2}],
227+
)
228+
229+
asyncio.run(manager.run_backwards())
230+
self.assertEqual(
231+
self._get_precision_and_scale(),
232+
[{"numeric_precision": 5, "numeric_scale": 2}],
233+
)
234+
199235
@patch.object(BaseMigrationManager, "get_migration_managers")
200236
def test_drop_table(self, get_migration_managers: MagicMock):
201237
self.run_sync("DROP TABLE IF EXISTS musician;")

tests/apps/migrations/auto/test_schema_differ.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import typing as t
22
from unittest import TestCase
33

4-
from piccolo.columns.column_types import Varchar
4+
from piccolo.columns.column_types import Varchar, Numeric
55
from piccolo.apps.migrations.auto import (
66
DiffableTable,
77
SchemaDiffer,
@@ -148,3 +148,31 @@ def test_rename_column(self):
148148
schema_differ.rename_columns[0],
149149
"manager.rename_column(table_class_name='Band', tablename='band', old_column_name='title', new_column_name='name')", # noqa
150150
)
151+
152+
def test_alter_column_precision(self):
153+
price_1 = Numeric(digits=(4, 2))
154+
price_1._meta.name = "price"
155+
156+
price_2 = Numeric(digits=(5, 2))
157+
price_2._meta.name = "price"
158+
159+
schema: t.List[DiffableTable] = [
160+
DiffableTable(
161+
class_name="Ticket", tablename="ticket", columns=[price_1],
162+
)
163+
]
164+
schema_snapshot: t.List[DiffableTable] = [
165+
DiffableTable(
166+
class_name="Ticket", tablename="ticket", columns=[price_2],
167+
)
168+
]
169+
170+
schema_differ = SchemaDiffer(
171+
schema=schema, schema_snapshot=schema_snapshot, auto_input="y"
172+
)
173+
174+
self.assertTrue(len(schema_differ.alter_columns) == 1)
175+
self.assertEqual(
176+
schema_differ.alter_columns[0],
177+
"manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', params={'digits': (5, 2)}, old_params={'digits': (4, 2)})", # noqa
178+
)

tests/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ def create_table(self):
5454
popularity SMALLINT
5555
);"""
5656
)
57+
self.run_sync(
58+
"""
59+
CREATE TABLE ticket (
60+
id SERIAL PRIMARY KEY,
61+
price NUMERIC(5,2)
62+
);"""
63+
)
5764
elif ENGINE.engine_type == "sqlite":
5865
self.run_sync(
5966
"""
@@ -71,6 +78,13 @@ def create_table(self):
7178
popularity SMALLINT
7279
);"""
7380
)
81+
self.run_sync(
82+
"""
83+
CREATE TABLE ticket (
84+
id SERIAL PRIMARY KEY,
85+
price NUMERIC(5,2)
86+
);"""
87+
)
7488
else:
7589
raise Exception("Unrecognised engine")
7690

@@ -141,6 +155,7 @@ def insert_many_rows(self, row_count=10000):
141155
def drop_table(self):
142156
self.run_sync("DROP TABLE IF EXISTS band CASCADE;")
143157
self.run_sync("DROP TABLE IF EXISTS manager CASCADE;")
158+
self.run_sync("DROP TABLE IF EXISTS ticket CASCADE;")
144159

145160
def setUp(self):
146161
self.create_table()

tests/columns/test_numeric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class MyTable(Table):
99
column_a = Numeric()
10-
column_b = Numeric(precision=3, scale=2)
10+
column_b = Numeric(digits=(3, 2))
1111

1212

1313
class TestNumeric(TestCase):

tests/conftest.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77

88

99
async def drop_tables():
10-
for table in ["venue", "concert", "band", "manager", "migration"]:
10+
for table in [
11+
"venue",
12+
"concert",
13+
"band",
14+
"manager",
15+
"ticket",
16+
"migration",
17+
]:
1118
await ENGINE._run_in_new_connection(f"DROP TABLE IF EXISTS {table}")
1219

1320

tests/example_project/tables.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from piccolo.table import Table
2-
from piccolo.columns import Varchar, ForeignKey, Integer
2+
from piccolo.columns import Varchar, ForeignKey, Integer, Numeric
33

44

55
###############################################################################
@@ -29,3 +29,7 @@ class Concert(Table):
2929
band_1 = ForeignKey(Band)
3030
band_2 = ForeignKey(Band)
3131
venue = ForeignKey(Venue)
32+
33+
34+
class Ticket(Table):
35+
price = Numeric(digits=(5, 2))

0 commit comments

Comments
 (0)