11from __future__ import annotations
2- import asyncio
3- from concurrent .futures import ThreadPoolExecutor
42import contextvars
53from dataclasses import dataclass
64import typing as t
@@ -201,25 +199,35 @@ class PostgresEngine(Engine):
201199 engine_type = "postgres"
202200 min_version_number = 9.6
203201
204- def __init__ (self , config : t .Dict [str , t .Any ]) -> None :
202+ def __init__ (
203+ self ,
204+ config : t .Dict [str , t .Any ],
205+ extensions : t .Sequence [str ] = ["uuid-ossp" ],
206+ ) -> None :
205207 """
206- The config dictionary is passed to the underlying database adapter,
207- asyncpg. Common arguments you're likely to need are:
208+ :param config:
209+ The config dictionary is passed to the underlying database adapter,
210+ asyncpg. Common arguments you're likely to need are:
208211
209- * host
210- * port
211- * user
212- * password
213- * database
212+ * host
213+ * port
214+ * user
215+ * password
216+ * database
214217
215- For example, ``{'host': 'localhost', 'port': 5432}``.
218+ For example, ``{'host': 'localhost', 'port': 5432}``.
216219
217- To see all available options:
220+ To see all available options:
218221
219- * https://magicstack.github.io/asyncpg/current/api/index.html#connection
222+ * https://magicstack.github.io/asyncpg/current/api/index.html#connection
223+
224+ :param extensions:
225+ When the engine starts, it will try and create these extensions
226+ in Postgres.
220227
221228 """ # noqa: E501
222229 self .config = config
230+ self .extensions = extensions
223231 self .pool : t .Optional [Pool ] = None
224232 database_name = config .get ("database" , "Unknown" )
225233 self .transaction_connection = contextvars .ContextVar (
@@ -241,20 +249,14 @@ def _parse_raw_version_string(version_string: str) -> float:
241249 version = float (f"{ major } .{ minor } " )
242250 return version
243251
244- def get_version (self ) -> float :
252+ async def get_version (self ) -> float :
245253 """
246254 Returns the version of Postgres being run.
247255 """
248- loop = asyncio .new_event_loop ()
249-
250- with ThreadPoolExecutor (max_workers = 1 ) as executor :
251- future = executor .submit (
252- loop .run_until_complete ,
253- self ._run_in_new_connection ("SHOW server_version" ),
254- )
255-
256256 try :
257- response : t .Sequence [t .Dict ] = future .result () # type: ignore
257+ response : t .Sequence [t .Dict ] = await self ._run_in_new_connection (
258+ "SHOW server_version"
259+ )
258260 except ConnectionRefusedError as exception :
259261 # Suppressing the exception, otherwise importing piccolo_conf.py
260262 # containing an engine will raise an ImportError.
@@ -267,26 +269,20 @@ def get_version(self) -> float:
267269 version_string = version_string
268270 )
269271
270- def prep_database (self ):
271- loop = asyncio .new_event_loop ()
272-
273- with ThreadPoolExecutor (max_workers = 1 ) as executor :
274- future = executor .submit (
275- loop .run_until_complete ,
276- self ._run_in_new_connection (
277- 'CREATE EXTENSION IF NOT EXISTS "uuid-ossp"'
278- ),
279- )
280-
281- try :
282- future .result ()
283- except InsufficientPrivilegeError :
284- print (
285- "Unable to create uuid-ossp extension - UUID columns might "
286- "not behave as expected. Make sure your database user has "
287- "permission to create extensions, or add it manually using "
288- '`CREATE EXTENSION "uuid-ossp";`'
289- )
272+ async def prep_database (self ):
273+ for extension in self .extensions :
274+ try :
275+ await self ._run_in_new_connection (
276+ f'CREATE EXTENSION IF NOT EXISTS "{ extension } "' ,
277+ )
278+ except InsufficientPrivilegeError :
279+ print (
280+ f"Unable to create { extension } extension - some "
281+ "functionality may not behave as expected. Make sure your "
282+ "database user has permission to create extensions, or "
283+ "add it manually using "
284+ f'`CREATE EXTENSION "{ extension } ";`'
285+ )
290286
291287 ###########################################################################
292288
0 commit comments