|
3 | 3 | import inspect |
4 | 4 | import typing as t |
5 | 5 | from dataclasses import dataclass, field |
| 6 | +from functools import cmp_to_key |
6 | 7 |
|
7 | 8 | from piccolo.apps.migrations.auto.diffable_table import DiffableTable |
8 | 9 | from piccolo.apps.migrations.auto.operations import ( |
@@ -117,6 +118,49 @@ def table_class_names(self) -> t.List[str]: |
117 | 118 | return list(set([i.table_class_name for i in self.alter_columns])) |
118 | 119 |
|
119 | 120 |
|
| 121 | +def _compare_tables( |
| 122 | + table_a: t.Type[Table], |
| 123 | + table_b: t.Type[Table], |
| 124 | + iterations: int = 0, |
| 125 | + max_iterations=5, |
| 126 | +) -> int: |
| 127 | + """ |
| 128 | + A comparison function, for sorting Table classes, based on their foreign |
| 129 | + keys. |
| 130 | +
|
| 131 | + :param iterations: |
| 132 | + As this function is called recursively, we use this to limit the depth, |
| 133 | + to prevent an infinite loop. |
| 134 | +
|
| 135 | + """ |
| 136 | + if iterations >= max_iterations: |
| 137 | + return 0 |
| 138 | + |
| 139 | + for fk_column in table_a._meta.foreign_key_columns: |
| 140 | + references = fk_column._foreign_key_meta.resolved_references |
| 141 | + if references._meta.tablename == table_b._meta.tablename: |
| 142 | + return 1 |
| 143 | + else: |
| 144 | + for _fk_column in references._meta.foreign_key_columns: |
| 145 | + _references = _fk_column._foreign_key_meta.resolved_references |
| 146 | + if _compare_tables( |
| 147 | + _references, table_b, iterations=iterations + 1 |
| 148 | + ): |
| 149 | + return 1 |
| 150 | + |
| 151 | + return -1 |
| 152 | + |
| 153 | + |
| 154 | +def sort_table_classes( |
| 155 | + table_classes: t.List[t.Type[Table]], |
| 156 | +) -> t.List[t.Type[Table]]: |
| 157 | + """ |
| 158 | + Sort the table classes based on their foreign keys, so they can be created |
| 159 | + in the correct order. |
| 160 | + """ |
| 161 | + return sorted(table_classes, key=cmp_to_key(_compare_tables)) |
| 162 | + |
| 163 | + |
120 | 164 | @dataclass |
121 | 165 | class MigrationManager: |
122 | 166 | """ |
@@ -563,25 +607,29 @@ async def _run_rename_columns(self, backwards=False): |
563 | 607 | ).run() |
564 | 608 |
|
565 | 609 | async def _run_add_tables(self, backwards=False): |
| 610 | + table_classes: t.List[t.Type[Table]] = [] |
| 611 | + for add_table in self.add_tables: |
| 612 | + add_columns: t.List[ |
| 613 | + AddColumnClass |
| 614 | + ] = self.add_columns.for_table_class_name(add_table.class_name) |
| 615 | + _Table: t.Type[Table] = create_table_class( |
| 616 | + class_name=add_table.class_name, |
| 617 | + class_kwargs={"tablename": add_table.tablename}, |
| 618 | + class_members={ |
| 619 | + add_column.column._meta.name: add_column.column |
| 620 | + for add_column in add_columns |
| 621 | + }, |
| 622 | + ) |
| 623 | + table_classes.append(_Table) |
| 624 | + |
| 625 | + # Sort by foreign key, so they're created in the right order. |
| 626 | + sorted_table_classes = sort_table_classes(table_classes) |
| 627 | + |
566 | 628 | if backwards: |
567 | | - for add_table in self.add_tables: |
568 | | - await add_table.to_table_class().alter().drop_table( |
569 | | - cascade=True |
570 | | - ).run() |
| 629 | + for _Table in reversed(sorted_table_classes): |
| 630 | + await _Table.alter().drop_table(cascade=True).run() |
571 | 631 | else: |
572 | | - for add_table in self.add_tables: |
573 | | - add_columns: t.List[ |
574 | | - AddColumnClass |
575 | | - ] = self.add_columns.for_table_class_name(add_table.class_name) |
576 | | - _Table: t.Type[Table] = create_table_class( |
577 | | - class_name=add_table.class_name, |
578 | | - class_kwargs={"tablename": add_table.tablename}, |
579 | | - class_members={ |
580 | | - add_column.column._meta.name: add_column.column |
581 | | - for add_column in add_columns |
582 | | - }, |
583 | | - ) |
584 | | - |
| 632 | + for _Table in sorted_table_classes: |
585 | 633 | await _Table.create_table().run() |
586 | 634 |
|
587 | 635 | async def _run_add_columns(self, backwards=False): |
|
0 commit comments