2020from piccolo .engine .finder import engine_finder
2121from piccolo .engine .postgres import PostgresEngine
2222from piccolo .table import Table , create_table_class
23+ from piccolo .utils .naming import _snake_to_camel
2324
2425if t .TYPE_CHECKING : # pragma: no cover
2526 from piccolo .engine .base import Engine
@@ -30,11 +31,103 @@ class PostgresRowMeta:
3031 column_default : str
3132 column_name : str
3233 is_nullable : Literal ["YES" , "NO" ]
33- ordinal_position : int
3434 table_name : str
3535 character_maximum_length : t .Optional [int ]
3636 data_type : str
3737
38+ @classmethod
39+ def get_column_name_str (cls ) -> str :
40+ return ", " .join ([i .name for i in dataclasses .fields (cls )])
41+
42+
43+ @dataclasses .dataclass
44+ class PostgresContraint :
45+ contraint_type : Literal ["PRIMARY KEY" , "UNIQUE" , "FOREIGN KEY" , "CHECK" ]
46+ constraint_name : str
47+
48+ @classmethod
49+ def get_column_name_str (cls ) -> str :
50+ return ", " .join ([i .name for i in dataclasses .fields (cls )])
51+
52+
53+ COLUMN_TYPE_MAP = {
54+ "bigint" : BigInt ,
55+ "boolean" : Boolean ,
56+ "character varying" : Varchar ,
57+ "integer" : Integer ,
58+ "numeric" : Numeric ,
59+ "smallint" : SmallInt ,
60+ "text" : Text ,
61+ "timestamp without time zone" : Timestamp ,
62+ "timestamp with time zone" : Timestamptz ,
63+ "uuid" : UUID ,
64+ }
65+
66+
67+ async def get_foreign_keys (
68+ table_class : t .Type [Table ], schema_name : str = "public"
69+ ):
70+ """
71+ :param table_class:
72+ Any Table subclass - just used to execute raw queries on the database.
73+
74+ """
75+ foreign_key_meta : t .List [str ] = await table_class .raw (
76+ (
77+ f"SELECT { PostgresContraint .get_column_name_str ()} FROM "
78+ "information_schema.table_constraints "
79+ "WHERE table_schema = {} "
80+ "AND table_name = {};"
81+ ),
82+ schema_name ,
83+ )
84+ return foreign_key_meta
85+
86+
87+ async def get_tablenames (
88+ table_class : t .Type [Table ], schema_name : str = "public"
89+ ) -> t .List [str ]:
90+ """
91+ :param table_class:
92+ Any Table subclass - just used to execute raw queries on the database.
93+ :returns:
94+ A list of tablenames for the given schema.
95+
96+ """
97+ tablenames : t .List [str ] = [
98+ i ["tablename" ]
99+ for i in await table_class .raw (
100+ (
101+ "SELECT tablename FROM pg_catalog.pg_tables WHERE "
102+ "schemaname = {};"
103+ ),
104+ schema_name ,
105+ ).run ()
106+ ]
107+ return tablenames
108+
109+
110+ async def get_table_schema (
111+ table_class : t .Type [Table ], tablename : str , schema_name : str = "public"
112+ ) -> t .List [PostgresRowMeta ]:
113+ """
114+ :returns:
115+ A list, with each item containing information about a colum in the
116+ table.
117+
118+ """
119+ row_meta_list = await table_class .raw (
120+ (
121+ f"SELECT { PostgresRowMeta .get_column_name_str ()} FROM "
122+ "information_schema.columns "
123+ "WHERE table_schema = {} "
124+ "AND TABLE_NAME = {};"
125+ ),
126+ schema_name ,
127+ tablename ,
128+ ).run ()
129+ return [PostgresRowMeta (** row_meta ) for row_meta in row_meta_list ]
130+
38131
39132# This is currently a beta version, and can be improved. However, having
40133# something working is still useful for people migrating large schemas to
@@ -61,81 +154,46 @@ async def generate(schema_name: str = "public"):
61154 class Schema (Table , db = engine ):
62155 pass
63156
64- tablenames : t .List [str ] = [
65- i ["tablename" ]
66- for i in await Schema .raw (
67- (
68- "SELECT tablename FROM pg_catalog.pg_tables WHERE "
69- "schemaname = {};"
70- ),
71- schema_name ,
72- ).run ()
73- ]
74-
75- schema_column_names = ", " .join (
76- [i .name for i in dataclasses .fields (PostgresRowMeta )]
77- )
157+ tablenames = await get_tablenames (Schema , schema_name = schema_name )
78158
79- output = []
159+ output : t .List [str ] = []
160+ imports : t .Set [str ] = {"from piccolo.table import Table" }
80161 warnings : t .List [str ] = []
81162
82163 for tablename in tablenames :
83- row_meta_list = await Schema .raw (
84- (
85- f"SELECT { schema_column_names } FROM "
86- "information_schema.columns "
87- "WHERE table_schema = {} "
88- "AND TABLE_NAME = {};"
89- ),
90- schema_name ,
91- tablename ,
92- ).run ()
164+ table_schema = await get_table_schema (Schema , tablename , schema_name )
165+
166+ columns : t .Dict [str , Column ] = {}
93167
94- class_name = tablename .title ()
95-
96- if row_meta_list :
97- columns : t .Dict [str , Column ] = {}
98-
99- for row_meta in row_meta_list :
100- pg_row_meta = PostgresRowMeta (** row_meta )
101-
102- column_type_map = {
103- "bigint" : BigInt ,
104- "boolean" : Boolean ,
105- "character varying" : Varchar ,
106- "integer" : Integer ,
107- "numeric" : Numeric ,
108- "smallint" : SmallInt ,
109- "text" : Text ,
110- "timestamp without time zone" : Timestamp ,
111- "timestamp with time zone" : Timestamptz ,
112- "uuid" : UUID ,
168+ for pg_row_meta in table_schema :
169+ data_type = pg_row_meta .data_type
170+ column_type = COLUMN_TYPE_MAP .get (data_type , None )
171+ column_name = pg_row_meta .column_name
172+
173+ if column_type :
174+ kwargs : t .Dict [str , t .Any ] = {
175+ "null" : pg_row_meta .is_nullable == "YES"
113176 }
114177
115- data_type = pg_row_meta .data_type
116- column_type = column_type_map .get (data_type , None )
117- column_name = pg_row_meta .column_name
118-
119- if column_type :
120- kwargs : t .Dict [str , t .Any ] = {
121- "null" : pg_row_meta .is_nullable == "YES"
122- }
123-
124- if column_type is Varchar :
125- kwargs ["length" ] = pg_row_meta .character_maximum_length
126-
127- columns [pg_row_meta .column_name ] = column_type (** kwargs )
128- else :
129- warnings .append (
130- f"{ tablename } .{ column_name } ['{ data_type } ']"
131- )
132-
133- table = create_table_class (
134- class_name = class_name ,
135- class_kwargs = {"tablename" : tablename },
136- class_members = columns ,
137- )
138- output .append (table ._table_str ())
178+ imports .add (
179+ "from piccolo.column_types import " + column_type .__name__
180+ )
181+
182+ if column_type is Varchar :
183+ kwargs ["length" ] = pg_row_meta .character_maximum_length
184+
185+ columns [column_name ] = column_type (** kwargs )
186+ else :
187+ warnings .append (f"{ tablename } .{ column_name } ['{ data_type } ']" )
188+
189+ table = create_table_class (
190+ class_name = _snake_to_camel (tablename ),
191+ class_kwargs = {"tablename" : tablename },
192+ class_members = columns ,
193+ )
194+ output .append (table ._table_str ())
195+
196+ output = sorted (list (imports )) + output
139197
140198 if warnings :
141199 warning_str = "\n " .join (warnings )
0 commit comments