|
40 | 40 | from piccolo.query.methods.indexes import Indexes |
41 | 41 | from piccolo.querystring import QueryString, Unquoted |
42 | 42 | from piccolo.utils import _camel_to_snake |
| 43 | +from piccolo.utils.graphlib import TopologicalSorter |
43 | 44 | from piccolo.utils.sql_values import convert_to_sql_value |
44 | 45 |
|
45 | 46 | if t.TYPE_CHECKING: |
46 | 47 | from piccolo.columns import Selectable |
47 | 48 |
|
48 | | - |
49 | 49 | PROTECTED_TABLENAMES = ("user",) |
50 | 50 |
|
51 | 51 |
|
@@ -126,7 +126,6 @@ def __str__(cls): |
126 | 126 |
|
127 | 127 |
|
128 | 128 | class Table(metaclass=TableMetaclass): |
129 | | - |
130 | 129 | # These are just placeholder values, so type inference isn't confused - the |
131 | 130 | # actual values are set in __init_subclass__. |
132 | 131 | _meta = TableMeta() |
@@ -975,3 +974,82 @@ def create_table_class( |
975 | 974 | kwds=class_kwargs, |
976 | 975 | exec_body=lambda namespace: namespace.update(class_members), |
977 | 976 | ) |
| 977 | + |
| 978 | + |
| 979 | +def create_tables(*args: t.Type[Table], if_not_exists: bool = False) -> None: |
| 980 | + """ |
| 981 | + Creates multiple tables that passed to it. |
| 982 | + """ |
| 983 | + sorted_table_classes = sort_table_classes(list(args)) |
| 984 | + for table in sorted_table_classes: |
| 985 | + Create(table=table, if_not_exists=if_not_exists).run_sync() |
| 986 | + |
| 987 | + |
| 988 | +def sort_table_classes( |
| 989 | + table_classes: t.List[t.Type[Table]], |
| 990 | +) -> t.List[t.Type[Table]]: |
| 991 | + """ |
| 992 | + Sort the table classes based on their foreign keys, so they can be created |
| 993 | + in the correct order. |
| 994 | + """ |
| 995 | + table_class_dict = { |
| 996 | + table_class._meta.tablename: table_class |
| 997 | + for table_class in table_classes |
| 998 | + } |
| 999 | + |
| 1000 | + graph = _get_graph(table_classes) |
| 1001 | + |
| 1002 | + sorter = TopologicalSorter(graph) |
| 1003 | + ordered_tablenames = tuple(sorter.static_order()) |
| 1004 | + |
| 1005 | + output: t.List[t.Type[Table]] = [] |
| 1006 | + for tablename in ordered_tablenames: |
| 1007 | + table_class = table_class_dict.get(tablename, None) |
| 1008 | + if table_class is not None: |
| 1009 | + output.append(table_class) |
| 1010 | + |
| 1011 | + return output |
| 1012 | + |
| 1013 | + |
| 1014 | +def _get_graph( |
| 1015 | + table_classes: t.List[t.Type[Table]], |
| 1016 | + iterations: int = 0, |
| 1017 | + max_iterations: int = 5, |
| 1018 | +) -> t.Dict[str, t.Set[str]]: |
| 1019 | + """ |
| 1020 | + Analyses the tables based on their foreign keys, and returns a data |
| 1021 | + structure like: |
| 1022 | +
|
| 1023 | + .. code-block:: python |
| 1024 | +
|
| 1025 | + {'band': {'manager'}, 'concert': {'band', 'venue'}, 'manager': set()} |
| 1026 | +
|
| 1027 | + The keys are tablenames, and the values are tablenames directly connected |
| 1028 | + to it via a foreign key. |
| 1029 | +
|
| 1030 | + """ |
| 1031 | + output: t.Dict[str, t.Set[str]] = {} |
| 1032 | + |
| 1033 | + if iterations >= max_iterations: |
| 1034 | + return output |
| 1035 | + |
| 1036 | + for table_class in table_classes: |
| 1037 | + dependents: t.Set[str] = set() |
| 1038 | + for fk in table_class._meta.foreign_key_columns: |
| 1039 | + dependents.add( |
| 1040 | + fk._foreign_key_meta.resolved_references._meta.tablename |
| 1041 | + ) |
| 1042 | + |
| 1043 | + # We also recursively check the related tables to get a fuller |
| 1044 | + # picture of the schema and relationships. |
| 1045 | + referenced_table = fk._foreign_key_meta.resolved_references |
| 1046 | + output.update( |
| 1047 | + _get_graph( |
| 1048 | + [referenced_table], |
| 1049 | + iterations=iterations + 1, |
| 1050 | + ) |
| 1051 | + ) |
| 1052 | + |
| 1053 | + output[table_class._meta.tablename] = dependents |
| 1054 | + |
| 1055 | + return output |
0 commit comments