Skip to content

Commit 319d583

Browse files
committed
added initial tests
1 parent 752c1de commit 319d583

File tree

4 files changed

+174
-15
lines changed

4 files changed

+174
-15
lines changed

piccolo/apps/schema/commands/generate.py

Lines changed: 68 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import dataclasses
24
import typing as t
35

@@ -6,11 +8,17 @@
68

79
from piccolo.columns.base import Column
810
from piccolo.columns.column_types import (
11+
JSON,
12+
JSONB,
913
UUID,
1014
BigInt,
1115
Boolean,
16+
Bytea,
17+
Date,
1218
Integer,
19+
Interval,
1320
Numeric,
21+
Real,
1422
SmallInt,
1523
Text,
1624
Timestamp,
@@ -50,16 +58,47 @@ def get_column_name_str(cls) -> str:
5058
return ", ".join([i.name for i in dataclasses.fields(cls)])
5159

5260

61+
@dataclasses.dataclass
62+
class OutputSchema:
63+
"""
64+
Represents the schema which will printed output.
65+
66+
:param imports:
67+
e.g. ["from piccolo.table import Table"]
68+
:param warnings:
69+
e.g. ["some_table.some_column unrecognised_type"]
70+
:param tables:
71+
e.g. ["class MyTable(Table): ..."]
72+
73+
"""
74+
75+
imports: t.List[str]
76+
warnings: t.List[str]
77+
tables: t.List[t.Type[Table]]
78+
79+
def get_table_with_name(self, name: str) -> t.Type[Table]:
80+
"""
81+
Just used by unit tests.
82+
"""
83+
return next(table for table in self.tables if table.__name__ == name)
84+
85+
5386
COLUMN_TYPE_MAP = {
5487
"bigint": BigInt,
5588
"boolean": Boolean,
89+
"bytea": Bytea,
5690
"character varying": Varchar,
91+
"date": Date,
5792
"integer": Integer,
93+
"interval": Interval,
94+
"json": JSON,
95+
"jsonb": JSONB,
5896
"numeric": Numeric,
97+
"real": Real,
5998
"smallint": SmallInt,
6099
"text": Text,
61-
"timestamp without time zone": Timestamp,
62100
"timestamp with time zone": Timestamptz,
101+
"timestamp without time zone": Timestamp,
63102
"uuid": UUID,
64103
}
65104

@@ -129,15 +168,7 @@ async def get_table_schema(
129168
return [PostgresRowMeta(**row_meta) for row_meta in row_meta_list]
130169

131170

132-
# This is currently a beta version, and can be improved. However, having
133-
# something working is still useful for people migrating large schemas to
134-
# Piccolo.
135-
async def generate(schema_name: str = "public"):
136-
"""
137-
Automatically generates Piccolo Table classes by introspecting the
138-
database. Please check the generated code in case there are errors.
139-
140-
"""
171+
async def get_output_schema(schema_name: str = "public") -> OutputSchema:
141172
engine: t.Optional[Engine] = engine_finder()
142173

143174
if engine is None:
@@ -152,11 +183,15 @@ async def generate(schema_name: str = "public"):
152183
)
153184

154185
class Schema(Table, db=engine):
186+
"""
187+
Just used for making raw queries on the db.
188+
"""
189+
155190
pass
156191

157192
tablenames = await get_tablenames(Schema, schema_name=schema_name)
158193

159-
output: t.List[str] = []
194+
tables: t.List[t.Type[Table]] = []
160195
imports: t.Set[str] = {"from piccolo.table import Table"}
161196
warnings: t.List[str] = []
162197

@@ -191,12 +226,30 @@ class Schema(Table, db=engine):
191226
class_kwargs={"tablename": tablename},
192227
class_members=columns,
193228
)
194-
output.append(table._table_str())
229+
tables.append(table)
230+
231+
return OutputSchema(
232+
imports=sorted(list(imports)), warnings=warnings, tables=tables
233+
)
234+
235+
236+
# This is currently a beta version, and can be improved. However, having
237+
# something working is still useful for people migrating large schemas to
238+
# Piccolo.
239+
async def generate(schema_name: str = "public"):
240+
"""
241+
Automatically generates Piccolo Table classes by introspecting the
242+
database. Please check the generated code in case there are errors.
243+
244+
"""
245+
output_schema = await get_output_schema(schema_name=schema_name)
195246

196-
output = sorted(list(imports)) + output
247+
output = output_schema.imports + [
248+
i._table_str() for i in output_schema.tables
249+
]
197250

198-
if warnings:
199-
warning_str = "\n".join(warnings)
251+
if output_schema.warnings:
252+
warning_str = "\n".join(output_schema.warnings)
200253

201254
output.append('"""')
202255
output.append(

tests/apps/schema/__init__.py

Whitespace-only changes.
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
from unittest import TestCase
5+
6+
from piccolo.apps.schema.commands.generate import (
7+
OutputSchema,
8+
get_output_schema,
9+
)
10+
from piccolo.columns.column_types import (
11+
JSON,
12+
JSONB,
13+
UUID,
14+
BigInt,
15+
Boolean,
16+
Bytea,
17+
Date,
18+
ForeignKey,
19+
Integer,
20+
Interval,
21+
Numeric,
22+
Real,
23+
SmallInt,
24+
Text,
25+
Timestamp,
26+
Timestamptz,
27+
Varchar,
28+
)
29+
from piccolo.table import Table
30+
from piccolo.utils.sync import run_sync
31+
32+
33+
class SmallTable(Table):
34+
varchar_col = Varchar()
35+
36+
37+
class MegaTable(Table):
38+
"""
39+
A table containing all of the column types, and different column kwargs.
40+
"""
41+
42+
bigint_col = BigInt()
43+
boolean_col = Boolean()
44+
bytea_col = Bytea()
45+
date_col = Date()
46+
foreignkey_col = ForeignKey(SmallTable)
47+
integer_col = Integer()
48+
interval_col = Interval()
49+
json_col = JSON()
50+
jsonb_col = JSONB()
51+
numeric_col = Numeric()
52+
real_col = Real()
53+
smallint_col = SmallInt()
54+
text_col = Text()
55+
timestamp_col = Timestamp()
56+
timestamptz_col = Timestamptz()
57+
uuid_col = UUID()
58+
varchar_col = Varchar()
59+
60+
unique_col = Varchar(unique=True)
61+
null_col = Varchar(null=True)
62+
not_null_col = Varchar(null=False)
63+
64+
65+
class TestGenerate(TestCase):
66+
def setUp(self):
67+
for table_class in (SmallTable, MegaTable):
68+
table_class.create_table().run_sync()
69+
70+
def tearDown(self):
71+
for table_class in (MegaTable, SmallTable):
72+
table_class.alter().drop_table().run_sync()
73+
74+
def _compare_table_columns(
75+
self, table_1: t.Type[Table], table_2: t.Type[Table]
76+
):
77+
"""
78+
Make sure that for each column in table_1, there is a corresponding
79+
column in table_2 of the same type.
80+
"""
81+
column_names = [
82+
column._meta.name for column in table_1._meta.non_default_columns
83+
]
84+
for column_name in column_names:
85+
self.assertTrue(
86+
type(table_1._meta.get_column_by_name(column_name)),
87+
type(table_2._meta.get_column_by_name(column_name)),
88+
)
89+
90+
def test_get_output_schema(self):
91+
"""
92+
Make sure that the a Piccolo schema can be generated from the database.
93+
"""
94+
output_schema: OutputSchema = run_sync(get_output_schema())
95+
96+
self.assertTrue(len(output_schema.warnings) == 0)
97+
self.assertTrue(len(output_schema.tables) == 2)
98+
self.assertTrue(len(output_schema.imports) > 0)
99+
100+
MegaTable_ = output_schema.get_table_with_name("MegaTable")
101+
self._compare_table_columns(MegaTable, MegaTable_)
102+
103+
SmallTable_ = output_schema.get_table_with_name("SmallTable")
104+
self._compare_table_columns(SmallTable, SmallTable_)

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ async def drop_tables():
1818
"my_table",
1919
"recording_studio",
2020
"shirt",
21+
"mega_table",
22+
"small_table",
2123
]:
2224
await ENGINE._run_in_new_connection(f"DROP TABLE IF EXISTS {table}")
2325

0 commit comments

Comments
 (0)