Skip to content

Commit 8569625

Browse files
authored
Schema generation fixes (piccolo-orm#226)
* fix a typo in the column import * sort the tables by their foreign keys * make sure there are no missing imports from column params * check the generated code is valid * use graphlib for table sorting Using just basic sort would fail on occasion * remove walrus operators * remove unnecessary imports from schema generation output * make graph creation recursive
1 parent ed1e502 commit 8569625

File tree

5 files changed

+339
-31
lines changed

5 files changed

+339
-31
lines changed

piccolo/apps/migrations/auto/migration_manager.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import inspect
44
import typing as t
55
from dataclasses import dataclass, field
6-
from functools import cmp_to_key
76

87
from piccolo.apps.migrations.auto.diffable_table import DiffableTable
98
from piccolo.apps.migrations.auto.operations import (
@@ -16,6 +15,7 @@
1615
from piccolo.columns import Column, column_types
1716
from piccolo.engine import engine_finder
1817
from piccolo.table import Table, create_table_class
18+
from piccolo.utils.graphlib import TopologicalSorter
1919

2020

2121
@dataclass
@@ -118,37 +118,48 @@ def table_class_names(self) -> t.List[str]:
118118
return list(set([i.table_class_name for i in self.alter_columns]))
119119

120120

121-
def _compare_tables(
122-
table_a: t.Type[Table],
123-
table_b: t.Type[Table],
121+
def _get_graph(
122+
table_classes: t.List[t.Type[Table]],
124123
iterations: int = 0,
125-
max_iterations=5,
126-
) -> int:
124+
max_iterations: int = 5,
125+
) -> t.Dict[str, t.Set[str]]:
127126
"""
128-
A comparison function, for sorting Table classes, based on their foreign
129-
keys.
127+
Analyses the tables based on their foreign keys, and returns a data
128+
structure like:
129+
130+
.. code-block:: python
130131
131-
:param iterations:
132-
As this function is called recursively, we use this to limit the depth,
133-
to prevent an infinite loop.
132+
{'band': {'manager'}, 'concert': {'band', 'venue'}, 'manager': set()}
133+
134+
The keys are tablenames, and the values are tablenames directly connected
135+
to it via a foreign key.
134136
135137
"""
138+
output: t.Dict[str, t.Set[str]] = {}
139+
136140
if iterations >= max_iterations:
137-
return 0
141+
return output
138142

139-
for fk_column in table_a._meta.foreign_key_columns:
140-
references = fk_column._foreign_key_meta.resolved_references
141-
if references._meta.tablename == table_b._meta.tablename:
142-
return 1
143-
else:
144-
for _fk_column in references._meta.foreign_key_columns:
145-
_references = _fk_column._foreign_key_meta.resolved_references
146-
if _compare_tables(
147-
_references, table_b, iterations=iterations + 1
148-
):
149-
return 1
143+
for table_class in table_classes:
144+
dependents: t.Set[str] = set()
145+
for fk in table_class._meta.foreign_key_columns:
146+
dependents.add(
147+
fk._foreign_key_meta.resolved_references._meta.tablename
148+
)
150149

151-
return -1
150+
# We also recursively check the related tables to get a fuller
151+
# picture of the schema and relationships.
152+
referenced_table = fk._foreign_key_meta.resolved_references
153+
output.update(
154+
_get_graph(
155+
[referenced_table],
156+
iterations=iterations + 1,
157+
)
158+
)
159+
160+
output[table_class._meta.tablename] = dependents
161+
162+
return output
152163

153164

154165
def sort_table_classes(
@@ -158,7 +169,23 @@ def sort_table_classes(
158169
Sort the table classes based on their foreign keys, so they can be created
159170
in the correct order.
160171
"""
161-
return sorted(table_classes, key=cmp_to_key(_compare_tables))
172+
table_class_dict = {
173+
table_class._meta.tablename: table_class
174+
for table_class in table_classes
175+
}
176+
177+
graph = _get_graph(table_classes)
178+
179+
sorter = TopologicalSorter(graph)
180+
ordered_tablenames = tuple(sorter.static_order())
181+
182+
output: t.List[t.Type[Table]] = []
183+
for tablename in ordered_tablenames:
184+
table_class = table_class_dict.get(tablename, None)
185+
if table_class is not None:
186+
output.append(table_class)
187+
188+
return output
162189

163190

164191
@dataclass

piccolo/apps/schema/commands/generate.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import black
77
from typing_extensions import Literal
88

9+
from piccolo.apps.migrations.auto.migration_manager import sort_table_classes
10+
from piccolo.apps.migrations.auto.serialisation import serialise_params
911
from piccolo.columns.base import Column
1012
from piccolo.columns.column_types import (
1113
JSON,
@@ -336,18 +338,22 @@ class Schema(Table, db=engine):
336338
)
337339
else:
338340
kwargs["references"] = ForeignKeyPlaceholder
339-
imports.add(
340-
"from piccolo.columns.base import OnDelete, OnUpdate"
341-
)
342341

343342
imports.add(
344-
"from piccolo.column_types import " + column_type.__name__
343+
"from piccolo.columns.column_types import "
344+
+ column_type.__name__
345345
)
346346

347347
if column_type is Varchar:
348348
kwargs["length"] = pg_row_meta.character_maximum_length
349349

350-
columns[column_name] = column_type(**kwargs)
350+
column = column_type(**kwargs)
351+
352+
serialised_params = serialise_params(column._meta.params)
353+
for extra_import in serialised_params.extra_imports:
354+
imports.add(extra_import.__repr__())
355+
356+
columns[column_name] = column
351357
else:
352358
warnings.append(f"{tablename}.{column_name} ['{data_type}']")
353359

@@ -358,6 +364,13 @@ class Schema(Table, db=engine):
358364
)
359365
tables.append(table)
360366

367+
# Sort the tables based on their ForeignKeys.
368+
tables = sort_table_classes(tables)
369+
370+
# We currently don't show the index argument for columns in the output,
371+
# so we don't need this import for now:
372+
imports.remove("from piccolo.columns.indexes import IndexMethod")
373+
361374
return OutputSchema(
362375
imports=sorted(list(imports)), warnings=warnings, tables=tables
363376
)

piccolo/utils/graphlib/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
try:
2+
from graphlib import CycleError, TopologicalSorter # type: ignore
3+
except ImportError:
4+
# For version < Python 3.9
5+
from ._graphlib import CycleError, TopologicalSorter

0 commit comments

Comments
 (0)