Skip to content

Commit 752c1de

Browse files
committed
slight refactor, and output required imports
1 parent 6ee3c36 commit 752c1de

File tree

2 files changed

+134
-69
lines changed

2 files changed

+134
-69
lines changed

piccolo/apps/schema/commands/generate.py

Lines changed: 127 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from piccolo.engine.finder import engine_finder
2121
from piccolo.engine.postgres import PostgresEngine
2222
from piccolo.table import Table, create_table_class
23+
from piccolo.utils.naming import _snake_to_camel
2324

2425
if t.TYPE_CHECKING: # pragma: no cover
2526
from piccolo.engine.base import Engine
@@ -30,11 +31,103 @@ class PostgresRowMeta:
3031
column_default: str
3132
column_name: str
3233
is_nullable: Literal["YES", "NO"]
33-
ordinal_position: int
3434
table_name: str
3535
character_maximum_length: t.Optional[int]
3636
data_type: str
3737

38+
@classmethod
39+
def get_column_name_str(cls) -> str:
40+
return ", ".join([i.name for i in dataclasses.fields(cls)])
41+
42+
43+
@dataclasses.dataclass
44+
class PostgresContraint:
45+
contraint_type: Literal["PRIMARY KEY", "UNIQUE", "FOREIGN KEY", "CHECK"]
46+
constraint_name: str
47+
48+
@classmethod
49+
def get_column_name_str(cls) -> str:
50+
return ", ".join([i.name for i in dataclasses.fields(cls)])
51+
52+
53+
COLUMN_TYPE_MAP = {
54+
"bigint": BigInt,
55+
"boolean": Boolean,
56+
"character varying": Varchar,
57+
"integer": Integer,
58+
"numeric": Numeric,
59+
"smallint": SmallInt,
60+
"text": Text,
61+
"timestamp without time zone": Timestamp,
62+
"timestamp with time zone": Timestamptz,
63+
"uuid": UUID,
64+
}
65+
66+
67+
async def get_foreign_keys(
68+
table_class: t.Type[Table], schema_name: str = "public"
69+
):
70+
"""
71+
:param table_class:
72+
Any Table subclass - just used to execute raw queries on the database.
73+
74+
"""
75+
foreign_key_meta: t.List[str] = await table_class.raw(
76+
(
77+
f"SELECT {PostgresContraint.get_column_name_str()} FROM "
78+
"information_schema.table_constraints "
79+
"WHERE table_schema = {} "
80+
"AND table_name = {};"
81+
),
82+
schema_name,
83+
)
84+
return foreign_key_meta
85+
86+
87+
async def get_tablenames(
88+
table_class: t.Type[Table], schema_name: str = "public"
89+
) -> t.List[str]:
90+
"""
91+
:param table_class:
92+
Any Table subclass - just used to execute raw queries on the database.
93+
:returns:
94+
A list of tablenames for the given schema.
95+
96+
"""
97+
tablenames: t.List[str] = [
98+
i["tablename"]
99+
for i in await table_class.raw(
100+
(
101+
"SELECT tablename FROM pg_catalog.pg_tables WHERE "
102+
"schemaname = {};"
103+
),
104+
schema_name,
105+
).run()
106+
]
107+
return tablenames
108+
109+
110+
async def get_table_schema(
111+
table_class: t.Type[Table], tablename: str, schema_name: str = "public"
112+
) -> t.List[PostgresRowMeta]:
113+
"""
114+
:returns:
115+
A list, with each item containing information about a colum in the
116+
table.
117+
118+
"""
119+
row_meta_list = await table_class.raw(
120+
(
121+
f"SELECT {PostgresRowMeta.get_column_name_str()} FROM "
122+
"information_schema.columns "
123+
"WHERE table_schema = {} "
124+
"AND TABLE_NAME = {};"
125+
),
126+
schema_name,
127+
tablename,
128+
).run()
129+
return [PostgresRowMeta(**row_meta) for row_meta in row_meta_list]
130+
38131

39132
# This is currently a beta version, and can be improved. However, having
40133
# something working is still useful for people migrating large schemas to
@@ -61,81 +154,46 @@ async def generate(schema_name: str = "public"):
61154
class Schema(Table, db=engine):
62155
pass
63156

64-
tablenames: t.List[str] = [
65-
i["tablename"]
66-
for i in await Schema.raw(
67-
(
68-
"SELECT tablename FROM pg_catalog.pg_tables WHERE "
69-
"schemaname = {};"
70-
),
71-
schema_name,
72-
).run()
73-
]
74-
75-
schema_column_names = ", ".join(
76-
[i.name for i in dataclasses.fields(PostgresRowMeta)]
77-
)
157+
tablenames = await get_tablenames(Schema, schema_name=schema_name)
78158

79-
output = []
159+
output: t.List[str] = []
160+
imports: t.Set[str] = {"from piccolo.table import Table"}
80161
warnings: t.List[str] = []
81162

82163
for tablename in tablenames:
83-
row_meta_list = await Schema.raw(
84-
(
85-
f"SELECT {schema_column_names} FROM "
86-
"information_schema.columns "
87-
"WHERE table_schema = {} "
88-
"AND TABLE_NAME = {};"
89-
),
90-
schema_name,
91-
tablename,
92-
).run()
164+
table_schema = await get_table_schema(Schema, tablename, schema_name)
165+
166+
columns: t.Dict[str, Column] = {}
93167

94-
class_name = tablename.title()
95-
96-
if row_meta_list:
97-
columns: t.Dict[str, Column] = {}
98-
99-
for row_meta in row_meta_list:
100-
pg_row_meta = PostgresRowMeta(**row_meta)
101-
102-
column_type_map = {
103-
"bigint": BigInt,
104-
"boolean": Boolean,
105-
"character varying": Varchar,
106-
"integer": Integer,
107-
"numeric": Numeric,
108-
"smallint": SmallInt,
109-
"text": Text,
110-
"timestamp without time zone": Timestamp,
111-
"timestamp with time zone": Timestamptz,
112-
"uuid": UUID,
168+
for pg_row_meta in table_schema:
169+
data_type = pg_row_meta.data_type
170+
column_type = COLUMN_TYPE_MAP.get(data_type, None)
171+
column_name = pg_row_meta.column_name
172+
173+
if column_type:
174+
kwargs: t.Dict[str, t.Any] = {
175+
"null": pg_row_meta.is_nullable == "YES"
113176
}
114177

115-
data_type = pg_row_meta.data_type
116-
column_type = column_type_map.get(data_type, None)
117-
column_name = pg_row_meta.column_name
118-
119-
if column_type:
120-
kwargs: t.Dict[str, t.Any] = {
121-
"null": pg_row_meta.is_nullable == "YES"
122-
}
123-
124-
if column_type is Varchar:
125-
kwargs["length"] = pg_row_meta.character_maximum_length
126-
127-
columns[pg_row_meta.column_name] = column_type(**kwargs)
128-
else:
129-
warnings.append(
130-
f"{tablename}.{column_name} ['{data_type}']"
131-
)
132-
133-
table = create_table_class(
134-
class_name=class_name,
135-
class_kwargs={"tablename": tablename},
136-
class_members=columns,
137-
)
138-
output.append(table._table_str())
178+
imports.add(
179+
"from piccolo.column_types import " + column_type.__name__
180+
)
181+
182+
if column_type is Varchar:
183+
kwargs["length"] = pg_row_meta.character_maximum_length
184+
185+
columns[column_name] = column_type(**kwargs)
186+
else:
187+
warnings.append(f"{tablename}.{column_name} ['{data_type}']")
188+
189+
table = create_table_class(
190+
class_name=_snake_to_camel(tablename),
191+
class_kwargs={"tablename": tablename},
192+
class_members=columns,
193+
)
194+
output.append(table._table_str())
195+
196+
output = sorted(list(imports)) + output
139197

140198
if warnings:
141199
warning_str = "\n".join(warnings)

piccolo/utils/naming.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,10 @@ def _camel_to_snake(string: str):
66
Convert CamelCase to snake_case.
77
"""
88
return inflection.underscore(string)
9+
10+
11+
def _snake_to_camel(string: str):
12+
"""
13+
Convert snake_case to CamelCase.
14+
"""
15+
return inflection.camelize(string)

0 commit comments

Comments
 (0)