Skip to content

Commit 3969a25

Browse files
authored
Merge pull request piccolo-orm#71 from piccolo-orm/migration_test_coverage
Migration test coverage
2 parents d66e6a4 + 7af1d69 commit 3969a25

File tree

6 files changed

+252
-28
lines changed

6 files changed

+252
-28
lines changed

piccolo/apps/migrations/auto/diffable_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
102102
for key, _ in delta.items()
103103
}
104104

105-
if delta:
105+
if delta or (column.__class__ != existing_column.__class__):
106106
alter_columns.append(
107107
AlterColumn(
108108
table_class_name=self.class_name,

piccolo/query/methods/alter.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,30 @@ def querystring(self) -> QueryString:
9393

9494
@dataclass
9595
class SetColumnType(AlterStatement):
96-
__slots__ = ("old_column", "new_column")
96+
"""
97+
:param using_expression:
98+
Postgres can't automatically convert between certain column types. You
99+
can tell Postgres which action to take. For example
100+
`my_column_name::integer`.
101+
102+
"""
97103

98104
old_column: Column
99105
new_column: Column
106+
using_expression: t.Optional[str] = None
100107

101108
@property
102109
def querystring(self) -> QueryString:
103110
if self.new_column._meta._table is None:
104111
self.new_column._meta._table = self.old_column._meta.table
105112

106113
column_name = self.old_column._meta.name
107-
return QueryString(
114+
query = (
108115
f"ALTER COLUMN {column_name} TYPE {self.new_column.column_type}"
109116
)
117+
if self.using_expression is not None:
118+
query += f" USING {self.using_expression}"
119+
return QueryString(query)
110120

111121

112122
@dataclass
@@ -353,12 +363,21 @@ def rename_column(
353363
self._rename_columns.append(RenameColumn(column, new_name))
354364
return self
355365

356-
def set_column_type(self, old_column: Column, new_column: Column) -> Alter:
366+
def set_column_type(
367+
self,
368+
old_column: Column,
369+
new_column: Column,
370+
using_expression: t.Optional[str] = None,
371+
) -> Alter:
357372
"""
358373
Change the type of a column.
359374
"""
360375
self._set_column_type.append(
361-
SetColumnType(old_column=old_column, new_column=new_column)
376+
SetColumnType(
377+
old_column=old_column,
378+
new_column=new_column,
379+
using_expression=using_expression,
380+
)
362381
)
363382
return self
364383

tests/apps/migrations/auto/test_migration_manager.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from asyncpg.exceptions import UniqueViolationError
55
from piccolo.apps.migrations.auto import MigrationManager
66
from piccolo.apps.migrations.commands.base import BaseMigrationManager
7-
from piccolo.columns import Varchar
7+
from piccolo.columns import Varchar, Text
88
from piccolo.columns.base import OnDelete, OnUpdate
99

1010
from tests.example_app.tables import Manager
@@ -146,6 +146,36 @@ def test_add_column(self):
146146
response = self.run_sync("SELECT * FROM manager;")
147147
self.assertEqual(response, [{"id": 1, "name": "Dave"}])
148148

149+
@postgres_only
150+
def test_add_column_with_index(self):
151+
"""
152+
Test adding a column with an index to a MigrationManager.
153+
"""
154+
manager = MigrationManager()
155+
manager.add_column(
156+
table_class_name="Manager",
157+
tablename="manager",
158+
column_name="email",
159+
column_class_name="Varchar",
160+
params={
161+
"length": 100,
162+
"default": "",
163+
"null": True,
164+
"primary": False,
165+
"key": False,
166+
"unique": True,
167+
"index": True,
168+
},
169+
)
170+
index_name = Manager._get_index_name(["email"])
171+
172+
asyncio.run(manager.run())
173+
self.assertTrue(index_name in Manager.indexes().run_sync())
174+
175+
# Reverse
176+
asyncio.run(manager.run_backwards())
177+
self.assertTrue(index_name not in Manager.indexes().run_sync())
178+
149179
@postgres_only
150180
def test_add_foreign_key_self_column(self):
151181
"""
@@ -310,6 +340,36 @@ def test_alter_column_unique(self):
310340
response = self.run_sync("SELECT name FROM manager;")
311341
self.assertEqual(response, [{"name": "Dave"}, {"name": "Dave"}])
312342

343+
@postgres_only
344+
def test_alter_column_set_null(self):
345+
"""
346+
Test altering whether a column is nullable with MigrationManager.
347+
"""
348+
manager = MigrationManager()
349+
350+
manager.alter_column(
351+
table_class_name="Manager",
352+
tablename="manager",
353+
column_name="name",
354+
params={"null": True},
355+
old_params={"null": False},
356+
)
357+
358+
asyncio.run(manager.run())
359+
self.assertTrue(
360+
self.get_postgres_is_nullable(
361+
tablename="manager", column_name="name"
362+
)
363+
)
364+
365+
# Reverse
366+
asyncio.run(manager.run_backwards())
367+
self.assertFalse(
368+
self.get_postgres_is_nullable(
369+
tablename="manager", column_name="name"
370+
)
371+
)
372+
313373
def _get_column_precision_and_scale(
314374
self, tablename="ticket", column_name="price"
315375
):
@@ -467,6 +527,68 @@ def test_alter_column_add_index(self):
467527
not in Manager.indexes().run_sync()
468528
)
469529

530+
@postgres_only
531+
def test_alter_column_set_type(self):
532+
"""
533+
Test altering a column to change it's type with MigrationManager.
534+
"""
535+
manager = MigrationManager()
536+
537+
manager.alter_column(
538+
table_class_name="Manager",
539+
tablename="manager",
540+
column_name="name",
541+
params={},
542+
old_params={},
543+
column_class=Text,
544+
old_column_class=Varchar,
545+
)
546+
547+
asyncio.run(manager.run())
548+
column_type_str = self.get_postgres_column_type(
549+
tablename="manager", column_name="name"
550+
)
551+
self.assertEqual(column_type_str, "TEXT")
552+
553+
asyncio.run(manager.run_backwards())
554+
column_type_str = self.get_postgres_column_type(
555+
tablename="manager", column_name="name"
556+
)
557+
self.assertEqual(column_type_str, "CHARACTER VARYING")
558+
559+
@postgres_only
560+
def test_alter_column_set_length(self):
561+
"""
562+
Test altering a Varchar column's length with MigrationManager.
563+
"""
564+
manager = MigrationManager()
565+
566+
manager.alter_column(
567+
table_class_name="Manager",
568+
tablename="manager",
569+
column_name="name",
570+
params={"length": 500},
571+
old_params={"length": 200},
572+
column_class=Text,
573+
old_column_class=Varchar,
574+
)
575+
576+
asyncio.run(manager.run())
577+
self.assertEqual(
578+
self.get_postgres_varchar_length(
579+
tablename="manager", column_name="name"
580+
),
581+
500,
582+
)
583+
584+
asyncio.run(manager.run_backwards())
585+
self.assertEqual(
586+
self.get_postgres_varchar_length(
587+
tablename="manager", column_name="name"
588+
),
589+
200,
590+
)
591+
470592
@postgres_only
471593
@patch.object(BaseMigrationManager, "get_migration_managers")
472594
def test_drop_table(self, get_migration_managers: MagicMock):

tests/apps/migrations/auto/test_schema_differ.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,31 @@ def test_add_table(self):
1414
"""
1515
Test adding a new table.
1616
"""
17-
pass
17+
name_column = Varchar()
18+
name_column._meta.name = "name"
19+
schema: t.List[DiffableTable] = [
20+
DiffableTable(
21+
class_name="Band", tablename="band", columns=[name_column]
22+
)
23+
]
24+
schema_snapshot: t.List[DiffableTable] = []
25+
schema_differ = SchemaDiffer(
26+
schema=schema, schema_snapshot=schema_snapshot, auto_input="y"
27+
)
28+
29+
create_tables = schema_differ.create_tables
30+
self.assertTrue(len(create_tables.statements) == 1)
31+
self.assertEqual(
32+
create_tables.statements[0],
33+
"manager.add_table('Band', tablename='band')",
34+
)
35+
36+
new_table_columns = schema_differ.new_table_columns
37+
self.assertTrue(len(new_table_columns.statements) == 1)
38+
self.assertEqual(
39+
new_table_columns.statements[0],
40+
"manager.add_column(table_class_name='Band', tablename='band', column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary': False, 'key': False, 'unique': False, 'index': False})", # noqa
41+
)
1842

1943
def test_drop_table(self):
2044
"""

tests/base.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,60 @@ def run_sync(self, query):
5050
_Table = type("_Table", (Table,), {})
5151
return _Table.raw(query).run_sync()
5252

53-
def table_exists(self, tablename: str):
53+
def table_exists(self, tablename: str) -> bool:
5454
_Table: t.Type[Table] = type(tablename.upper(), (Table,), {})
5555
_Table._meta.tablename = tablename
5656
return _Table.table_exists().run_sync()
5757

58+
###########################################################################
59+
60+
# Postgres specific utils
61+
62+
def get_postgres_column_definition(
63+
self, tablename: str, column_name: str
64+
) -> t.Dict[str, t.Any]:
65+
query = """
66+
SELECT * FROM information_schema.columns
67+
WHERE table_name = '{tablename}'
68+
AND table_catalog = 'piccolo'
69+
AND column_name = '{column_name}'
70+
""".format(
71+
tablename=tablename, column_name=column_name
72+
)
73+
response = self.run_sync(query)
74+
return response[0]
75+
76+
def get_postgres_column_type(
77+
self, tablename: str, column_name: str
78+
) -> str:
79+
"""
80+
Fetches the column type as a string, from the database.
81+
"""
82+
return self.get_postgres_column_definition(
83+
tablename=tablename, column_name=column_name
84+
)["data_type"].upper()
85+
86+
def get_postgres_is_nullable(self, tablename, column_name: str) -> bool:
87+
"""
88+
Fetches whether the column is defined as nullable, from the database.
89+
"""
90+
return (
91+
self.get_postgres_column_definition(
92+
tablename=tablename, column_name=column_name
93+
)["is_nullable"].upper()
94+
== "YES"
95+
)
96+
97+
def get_postgres_varchar_length(self, tablename, column_name: str) -> int:
98+
"""
99+
Fetches whether the column is defined as nullable, from the database.
100+
"""
101+
return self.get_postgres_column_definition(
102+
tablename=tablename, column_name=column_name
103+
)["character_maximum_length"]
104+
105+
###########################################################################
106+
58107
def create_tables(self):
59108
if ENGINE.engine_type == "postgres":
60109
self.run_sync(

tests/table/test_alter.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from piccolo.columns.column_types import Varchar
21
from unittest import TestCase
32

4-
from piccolo.columns import BigInt, Integer, Numeric
3+
from piccolo.columns import BigInt, Integer, Numeric, Varchar
54
from piccolo.table import Table
65

76
from ..base import DBTestCase, postgres_only
@@ -151,15 +150,12 @@ def test_integer_to_bigint(self):
151150
)
152151
alter_query.run_sync()
153152

154-
query = """
155-
SELECT data_type FROM information_schema.columns
156-
WHERE table_name = 'band'
157-
AND table_catalog = 'piccolo'
158-
AND column_name = 'popularity'
159-
"""
160-
161-
response = Band.raw(query).run_sync()
162-
self.assertEqual(response[0]["data_type"].upper(), "BIGINT")
153+
self.assertEqual(
154+
self.get_postgres_column_type(
155+
tablename="band", column_name="popularity"
156+
),
157+
"BIGINT",
158+
)
163159

164160
popularity = (
165161
Band.select(Band.popularity).first().run_sync()["popularity"]
@@ -177,21 +173,35 @@ def test_integer_to_varchar(self):
177173
)
178174
alter_query.run_sync()
179175

180-
query = """
181-
SELECT data_type FROM information_schema.columns
182-
WHERE table_name = 'band'
183-
AND table_catalog = 'piccolo'
184-
AND column_name = 'popularity'
185-
"""
186-
187-
response = Band.raw(query).run_sync()
188-
self.assertEqual(response[0]["data_type"].upper(), "CHARACTER VARYING")
176+
self.assertEqual(
177+
self.get_postgres_column_type(
178+
tablename="band", column_name="popularity"
179+
),
180+
"CHARACTER VARYING",
181+
)
189182

190183
popularity = (
191184
Band.select(Band.popularity).first().run_sync()["popularity"]
192185
)
193186
self.assertEqual(popularity, "1000")
194187

188+
def test_using_expression(self):
189+
"""
190+
Test the `using_expression` option, which can be used to tell Postgres
191+
how to convert certain column types.
192+
"""
193+
Band(name="1").save().run_sync()
194+
195+
alter_query = Band.alter().set_column_type(
196+
old_column=Band.name,
197+
new_column=Integer(),
198+
using_expression="name::integer",
199+
)
200+
alter_query.run_sync()
201+
202+
popularity = Band.select(Band.name).first().run_sync()["name"]
203+
self.assertEqual(popularity, 1)
204+
195205

196206
@postgres_only
197207
class TestSetNull(DBTestCase):

0 commit comments

Comments
 (0)