Skip to content

Commit 4df6d4b

Browse files
committed
making auto migrations work in reverse
1 parent 258f06a commit 4df6d4b

File tree

12 files changed

+243
-31
lines changed

12 files changed

+243
-31
lines changed

docs/src/piccolo/migrations/create.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ The contents of an empty migration file looks like this:
3838
manager.add_raw(run)
3939
return manager
4040
41-
42-
async def backwards():
43-
pass
44-
4541
Replace the `run` function with whatever you want the migration to do -
4642
typically running some SQL. It can be a function or a coroutine.
4743

piccolo/apps/migrations/auto/migration_manager.py

Lines changed: 130 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ class MigrationManager:
140140
default_factory=AlterColumnCollection
141141
)
142142
raw: t.List[t.Union[t.Callable, t.Coroutine]] = field(default_factory=list)
143+
raw_backwards: t.List[t.Union[t.Callable, t.Coroutine]] = field(
144+
default_factory=list
145+
)
143146

144147
def add_table(
145148
self,
@@ -162,11 +165,16 @@ def drop_table(self, class_name: str, tablename: str):
162165
)
163166

164167
def rename_table(
165-
self, old_class_name: str, new_class_name: str, new_tablename: str
168+
self,
169+
old_class_name: str,
170+
old_tablename: str,
171+
new_class_name: str,
172+
new_tablename: str,
166173
):
167174
self.rename_tables.append(
168175
RenameTable(
169176
old_class_name=old_class_name,
177+
old_tablename=old_tablename,
170178
new_class_name=new_class_name,
171179
new_tablename=new_tablename,
172180
)
@@ -245,6 +253,13 @@ def add_raw(self, raw: t.Union[t.Callable, t.Coroutine]):
245253
"""
246254
self.raw.append(raw)
247255

256+
def add_raw_backwards(self, raw: t.Union[t.Callable, t.Coroutine]):
257+
"""
258+
When reversing a migration, you may want to run extra code to help
259+
clean up.
260+
"""
261+
self.raw_backwards.append(raw)
262+
248263
###########################################################################
249264

250265
def deserialise_params(
@@ -325,16 +340,19 @@ async def run(self):
325340
###################################################################
326341
# Add tables
327342

328-
for table in self.add_tables:
329-
columns = self.add_columns.columns_for_table_class_name(
330-
table.class_name
343+
for add_table in self.add_tables:
344+
columns = (
345+
self.add_columns.columns_for_table_class_name(
346+
add_table.class_name
347+
)
348+
+ add_table.columns
331349
)
332350
_Table: t.Type[Table] = type(
333-
table.class_name,
351+
add_table.class_name,
334352
(Table,),
335353
{column._meta.name: column for column in columns},
336354
)
337-
_Table._meta.tablename = table.tablename
355+
_Table._meta.tablename = add_table.tablename
338356

339357
await _Table.create_table().run()
340358

@@ -351,6 +369,7 @@ async def run(self):
351369
_Table: t.Type[Table] = type(
352370
rename_table.old_class_name, (Table,), {}
353371
)
372+
_Table._meta.tablename = rename_table.old_tablename
354373
await _Table.alter().rename_table(
355374
new_name=rename_table.new_tablename
356375
).run()
@@ -453,3 +472,108 @@ async def run(self):
453472
await _Table.alter().set_unique(
454473
column=row_name, boolean=unique
455474
).run()
475+
476+
###########################################################################
477+
478+
async def run_backwards(self):
479+
print("Reversing MigrationManager ...")
480+
481+
engine = engine_finder()
482+
483+
if not engine:
484+
raise Exception("Can't find engine")
485+
486+
async with engine.transaction():
487+
488+
for raw in self.raw_backwards:
489+
if inspect.iscoroutinefunction(raw):
490+
await raw()
491+
else:
492+
raw()
493+
494+
###################################################################
495+
# Reverse add tables
496+
497+
for add_table in self.add_tables:
498+
await add_table.to_table_class().alter().drop_table().run()
499+
500+
###################################################################
501+
# Reverse drop tables
502+
503+
if self.drop_tables:
504+
print("Dropped tables can't currently be reversed.")
505+
506+
###################################################################
507+
# Reverse rename tables
508+
509+
for rename_table in self.rename_tables:
510+
_Table: t.Type[Table] = type(
511+
rename_table.new_class_name, (Table,), {}
512+
)
513+
_Table._meta.tablename = rename_table.new_tablename
514+
515+
await _Table.alter().rename_table(
516+
new_name=rename_table.old_tablename
517+
).run()
518+
519+
###################################################################
520+
# Reverse add columns
521+
522+
for add_column in self.add_columns.add_columns:
523+
_Table: t.Type[Table] = type(
524+
add_column.table_class_name, (Table,), {}
525+
)
526+
_Table._meta.tablename = add_column.tablename
527+
528+
await _Table.alter().drop_column(add_column.column).run()
529+
530+
###################################################################
531+
# Reverse drop columns
532+
533+
if self.drop_columns:
534+
print("Dropped columns can't currently be reversed.")
535+
536+
###################################################################
537+
# Reverse rename columns
538+
539+
for rename_column in self.rename_columns.rename_columns:
540+
_Table: t.Type[Table] = type(
541+
rename_column.table_class_name, (Table,), {}
542+
)
543+
_Table._meta.tablename = rename_column.tablename
544+
545+
await _Table.alter().rename_column(
546+
column=rename_column.new_column_name,
547+
new_name=rename_column.old_column_name,
548+
).run()
549+
550+
###################################################################
551+
# Alter columns
552+
553+
# TODO - need to find what the old values are.
554+
# for alter_column in self.alter_columns.alter_columns:
555+
# _Table: t.Type[Table] = type(
556+
# alter_column.table_class_name, (Table,), {}
557+
# )
558+
# _Table._meta.tablename = alter_column.tablename
559+
560+
# params = alter_column.params
561+
# row_name = alter_column.row_name
562+
563+
# null = params.get("null")
564+
# if null is not None:
565+
# await _Table.alter().set_null(
566+
# column=row_name, boolean=null
567+
# ).run()
568+
569+
# length = params.get("length")
570+
# if length is not None:
571+
# await _Table.alter().set_length(
572+
# column=row_name, length=length
573+
# ).run()
574+
575+
# unique = params.get("unique")
576+
# if unique is not None:
577+
# await _Table.alter().set_unique(
578+
# column=row_name, boolean=unique
579+
# ).run()

piccolo/apps/migrations/auto/operations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
@dataclass
66
class RenameTable:
77
old_class_name: str
8+
old_tablename: str
89
new_class_name: str
910
new_tablename: str
1011

piccolo/apps/migrations/auto/schema_differ.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def check_rename_tables(self) -> RenameTableCollection:
134134
collection.append(
135135
RenameTable(
136136
old_class_name=drop_table.class_name,
137+
old_tablename=drop_table.tablename,
137138
new_class_name=new_table.class_name,
138139
new_tablename=new_table.tablename,
139140
)
@@ -237,7 +238,7 @@ def drop_tables(self) -> t.List[str]:
237238
@property
238239
def rename_tables(self) -> t.List[str]:
239240
return [
240-
f"manager.rename_table(old_class_name='{renamed_table.old_class_name}', new_class_name='{renamed_table.new_class_name}', new_tablename='{renamed_table.new_tablename}')" # noqa
241+
f"manager.rename_table(old_class_name='{renamed_table.old_class_name}', old_tablename='{renamed_table.old_tablename}', new_class_name='{renamed_table.new_class_name}', new_tablename='{renamed_table.new_tablename}')" # noqa
241242
for renamed_table in self.rename_tables_collection.rename_tables
242243
]
243244

piccolo/apps/migrations/commands/backwards.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ def run(self):
5757
print(f"Reversing {migration_name}")
5858
migration_module = migration_modules[migration_name]
5959
response = asyncio.run(
60-
migration_module.backwards()
60+
migration_module.forwards()
6161
) # type: ignore
6262

6363
if isinstance(response, MigrationManager):
64-
asyncio.run(response.run())
64+
asyncio.run(response.run_backwards())
6565

6666
Migration.delete().where(
6767
Migration.name == migration_name

piccolo/apps/migrations/commands/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ class MigrationModule(ModuleType):
1818
async def forwards() -> None:
1919
pass
2020

21-
@staticmethod
22-
async def backwards() -> None:
23-
pass
24-
2521

2622
class PiccoloAppModule(ModuleType):
2723
APP_CONFIG: AppConfig

piccolo/apps/migrations/commands/templates/migration.py.jinja

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,3 @@ async def forwards():
1717
manager.add_raw(run)
1818
{% endif %}
1919
return manager
20-
21-
22-
async def backwards():
23-
pass

piccolo/apps/user/piccolo_migrations/2019-11-14T21:52:21.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,3 @@ async def forwards():
8282
)
8383

8484
return manager
85-
86-
87-
async def backwards():
88-
pass

tests/apps/migrations/auto/test_migration_manager.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22

33
from piccolo.apps.migrations.auto import MigrationManager
4+
from piccolo.columns import Varchar
45

56
from tests.base import DBTestCase
67

@@ -26,6 +27,12 @@ def test_rename_column(self):
2627
self.assertTrue("title" in response[0].keys())
2728
self.assertTrue("name" not in response[0].keys())
2829

30+
# Now reverse it
31+
asyncio.run(manager.run_backwards())
32+
response = self.run_sync("SELECT * FROM band;")
33+
self.assertTrue("title" not in response[0].keys())
34+
self.assertTrue("name" in response[0].keys())
35+
2936
def test_raw_function(self):
3037
"""
3138
Test adding raw functions to a MigrationManager.
@@ -39,10 +46,15 @@ def run():
3946

4047
manager = MigrationManager()
4148
manager.add_raw(run)
49+
manager.add_raw_backwards(run)
4250

4351
with self.assertRaises(HasRun):
4452
asyncio.run(manager.run())
4553

54+
# Reverse
55+
with self.assertRaises(HasRun):
56+
asyncio.run(manager.run_backwards())
57+
4658
def test_raw_coroutine(self):
4759
"""
4860
Test adding raw coroutines to a MigrationManager.
@@ -56,6 +68,96 @@ async def run():
5668

5769
manager = MigrationManager()
5870
manager.add_raw(run)
71+
manager.add_raw_backwards(run)
5972

6073
with self.assertRaises(HasRun):
6174
asyncio.run(manager.run())
75+
76+
# Reverse
77+
with self.assertRaises(HasRun):
78+
asyncio.run(manager.run_backwards())
79+
80+
def test_add_table(self):
81+
"""
82+
Test adding a table to a MigrationManager.
83+
"""
84+
self.run_sync("DROP TABLE IF EXISTS musician;")
85+
86+
manager = MigrationManager()
87+
name_column = Varchar()
88+
name_column._meta.name = "name"
89+
manager.add_table(
90+
class_name="Musician", tablename="musician", columns=[name_column]
91+
)
92+
asyncio.run(manager.run())
93+
94+
self.run_sync("INSERT INTO musician VALUES (default, 'Bob Jones');")
95+
response = self.run_sync("SELECT * FROM musician;")
96+
97+
self.assertEqual(response, [{"id": 1, "name": "Bob Jones"}])
98+
99+
# Reverse
100+
asyncio.run(manager.run_backwards())
101+
self.assertEqual(self.table_exists("musician"), False)
102+
self.run_sync("DROP TABLE IF EXISTS musician;")
103+
104+
def test_add_column(self):
105+
"""
106+
Test adding a column to a MigrationManager.
107+
"""
108+
manager = MigrationManager()
109+
manager.add_column(
110+
table_class_name="Manager",
111+
tablename="manager",
112+
column_name="email",
113+
column_class_name="Varchar",
114+
params={
115+
"length": 100,
116+
"default": "",
117+
"null": True,
118+
"primary": False,
119+
"key": False,
120+
"unique": True,
121+
"index": False,
122+
},
123+
)
124+
asyncio.run(manager.run())
125+
126+
self.run_sync(
127+
"INSERT INTO manager VALUES (default, 'Dave', 'dave@me.com');"
128+
)
129+
130+
response = self.run_sync("SELECT * FROM manager;")
131+
self.assertEqual(
132+
response, [{"id": 1, "name": "Dave", "email": "dave@me.com"}]
133+
)
134+
135+
# Reverse
136+
asyncio.run(manager.run_backwards())
137+
response = self.run_sync("SELECT * FROM manager;")
138+
self.assertEqual(response, [{"id": 1, "name": "Dave"}])
139+
140+
def test_rename_table(self):
141+
"""
142+
Test renaming a table with MigrationManager.
143+
"""
144+
manager = MigrationManager()
145+
146+
manager.rename_table(
147+
old_class_name="Manager",
148+
old_tablename="manager",
149+
new_class_name="Director",
150+
new_tablename="director",
151+
)
152+
153+
asyncio.run(manager.run())
154+
155+
self.run_sync("INSERT INTO director VALUES (default, 'Dave');")
156+
157+
response = self.run_sync("SELECT * FROM director;")
158+
self.assertEqual(response, [{"id": 1, "name": "Dave"}])
159+
160+
# Reverse
161+
asyncio.run(manager.run_backwards())
162+
response = self.run_sync("SELECT * FROM manager;")
163+
self.assertEqual(response, [{"id": 1, "name": "Dave"}])

0 commit comments

Comments
 (0)