@@ -140,6 +140,9 @@ class MigrationManager:
140140 default_factory = AlterColumnCollection
141141 )
142142 raw : t .List [t .Union [t .Callable , t .Coroutine ]] = field (default_factory = list )
143+ raw_backwards : t .List [t .Union [t .Callable , t .Coroutine ]] = field (
144+ default_factory = list
145+ )
143146
144147 def add_table (
145148 self ,
@@ -162,11 +165,16 @@ def drop_table(self, class_name: str, tablename: str):
162165 )
163166
164167 def rename_table (
165- self , old_class_name : str , new_class_name : str , new_tablename : str
168+ self ,
169+ old_class_name : str ,
170+ old_tablename : str ,
171+ new_class_name : str ,
172+ new_tablename : str ,
166173 ):
167174 self .rename_tables .append (
168175 RenameTable (
169176 old_class_name = old_class_name ,
177+ old_tablename = old_tablename ,
170178 new_class_name = new_class_name ,
171179 new_tablename = new_tablename ,
172180 )
@@ -245,6 +253,13 @@ def add_raw(self, raw: t.Union[t.Callable, t.Coroutine]):
245253 """
246254 self .raw .append (raw )
247255
256+ def add_raw_backwards (self , raw : t .Union [t .Callable , t .Coroutine ]):
257+ """
258+ When reversing a migration, you may want to run extra code to help
259+ clean up.
260+ """
261+ self .raw_backwards .append (raw )
262+
248263 ###########################################################################
249264
250265 def deserialise_params (
@@ -325,16 +340,19 @@ async def run(self):
325340 ###################################################################
326341 # Add tables
327342
328- for table in self .add_tables :
329- columns = self .add_columns .columns_for_table_class_name (
330- table .class_name
343+ for add_table in self .add_tables :
344+ columns = (
345+ self .add_columns .columns_for_table_class_name (
346+ add_table .class_name
347+ )
348+ + add_table .columns
331349 )
332350 _Table : t .Type [Table ] = type (
333- table .class_name ,
351+ add_table .class_name ,
334352 (Table ,),
335353 {column ._meta .name : column for column in columns },
336354 )
337- _Table ._meta .tablename = table .tablename
355+ _Table ._meta .tablename = add_table .tablename
338356
339357 await _Table .create_table ().run ()
340358
@@ -351,6 +369,7 @@ async def run(self):
351369 _Table : t .Type [Table ] = type (
352370 rename_table .old_class_name , (Table ,), {}
353371 )
372+ _Table ._meta .tablename = rename_table .old_tablename
354373 await _Table .alter ().rename_table (
355374 new_name = rename_table .new_tablename
356375 ).run ()
@@ -453,3 +472,108 @@ async def run(self):
453472 await _Table .alter ().set_unique (
454473 column = row_name , boolean = unique
455474 ).run ()
475+
476+ ###########################################################################
477+
478+ async def run_backwards (self ):
479+ print ("Reversing MigrationManager ..." )
480+
481+ engine = engine_finder ()
482+
483+ if not engine :
484+ raise Exception ("Can't find engine" )
485+
486+ async with engine .transaction ():
487+
488+ for raw in self .raw_backwards :
489+ if inspect .iscoroutinefunction (raw ):
490+ await raw ()
491+ else :
492+ raw ()
493+
494+ ###################################################################
495+ # Reverse add tables
496+
497+ for add_table in self .add_tables :
498+ await add_table .to_table_class ().alter ().drop_table ().run ()
499+
500+ ###################################################################
501+ # Reverse drop tables
502+
503+ if self .drop_tables :
504+ print ("Dropped tables can't currently be reversed." )
505+
506+ ###################################################################
507+ # Reverse rename tables
508+
509+ for rename_table in self .rename_tables :
510+ _Table : t .Type [Table ] = type (
511+ rename_table .new_class_name , (Table ,), {}
512+ )
513+ _Table ._meta .tablename = rename_table .new_tablename
514+
515+ await _Table .alter ().rename_table (
516+ new_name = rename_table .old_tablename
517+ ).run ()
518+
519+ ###################################################################
520+ # Reverse add columns
521+
522+ for add_column in self .add_columns .add_columns :
523+ _Table : t .Type [Table ] = type (
524+ add_column .table_class_name , (Table ,), {}
525+ )
526+ _Table ._meta .tablename = add_column .tablename
527+
528+ await _Table .alter ().drop_column (add_column .column ).run ()
529+
530+ ###################################################################
531+ # Reverse drop columns
532+
533+ if self .drop_columns :
534+ print ("Dropped columns can't currently be reversed." )
535+
536+ ###################################################################
537+ # Reverse rename columns
538+
539+ for rename_column in self .rename_columns .rename_columns :
540+ _Table : t .Type [Table ] = type (
541+ rename_column .table_class_name , (Table ,), {}
542+ )
543+ _Table ._meta .tablename = rename_column .tablename
544+
545+ await _Table .alter ().rename_column (
546+ column = rename_column .new_column_name ,
547+ new_name = rename_column .old_column_name ,
548+ ).run ()
549+
550+ ###################################################################
551+ # Alter columns
552+
553+ # TODO - need to find what the old values are.
554+ # for alter_column in self.alter_columns.alter_columns:
555+ # _Table: t.Type[Table] = type(
556+ # alter_column.table_class_name, (Table,), {}
557+ # )
558+ # _Table._meta.tablename = alter_column.tablename
559+
560+ # params = alter_column.params
561+ # row_name = alter_column.row_name
562+
563+ # null = params.get("null")
564+ # if null is not None:
565+ # await _Table.alter().set_null(
566+ # column=row_name, boolean=null
567+ # ).run()
568+
569+ # length = params.get("length")
570+ # if length is not None:
571+ # await _Table.alter().set_length(
572+ # column=row_name, length=length
573+ # ).run()
574+
575+ # unique = params.get("unique")
576+ # if unique is not None:
577+ # await _Table.alter().set_unique(
578+ # column=row_name, boolean=unique
579+ # ).run()
0 commit comments