Skip to content

Commit aacd02d

Browse files
committed
modified postgres engine - can specify extensions to install on start
1 parent 06e90ae commit aacd02d

File tree

3 files changed

+48
-51
lines changed

3 files changed

+48
-51
lines changed

piccolo/engine/base.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from abc import abstractmethod, ABCMeta
33
import typing as t
44

5+
from piccolo.utils.sync import run_sync
56
from piccolo.utils.warnings import colored_warning, Level
67

78
if t.TYPE_CHECKING: # pragma: no cover
@@ -14,8 +15,8 @@ class Batch:
1415

1516
class Engine(metaclass=ABCMeta):
1617
def __init__(self):
17-
self.check_version()
18-
self.prep_database()
18+
run_sync(self.check_version())
19+
run_sync(self.prep_database())
1920

2021
@property
2122
@abstractmethod
@@ -28,23 +29,23 @@ def min_version_number(self) -> float:
2829
pass
2930

3031
@abstractmethod
31-
def get_version(self) -> float:
32+
async def get_version(self) -> float:
3233
pass
3334

3435
@abstractmethod
35-
def prep_database(self):
36+
async def prep_database(self):
3637
pass
3738

3839
@abstractmethod
3940
async def batch(self, query: Query, batch_size: int = 100) -> Batch:
4041
pass
4142

42-
def check_version(self):
43+
async def check_version(self):
4344
"""
4445
Warn if the database version isn't supported.
4546
"""
4647
try:
47-
version_number = self.get_version()
48+
version_number = await self.get_version()
4849
except Exception as exception:
4950
colored_warning(
5051
f"Unable to fetch server version: {exception}",

piccolo/engine/postgres.py

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from __future__ import annotations
2-
import asyncio
3-
from concurrent.futures import ThreadPoolExecutor
42
import contextvars
53
from dataclasses import dataclass
64
import typing as t
@@ -201,25 +199,35 @@ class PostgresEngine(Engine):
201199
engine_type = "postgres"
202200
min_version_number = 9.6
203201

204-
def __init__(self, config: t.Dict[str, t.Any]) -> None:
202+
def __init__(
203+
self,
204+
config: t.Dict[str, t.Any],
205+
extensions: t.Sequence[str] = ["uuid-ossp"],
206+
) -> None:
205207
"""
206-
The config dictionary is passed to the underlying database adapter,
207-
asyncpg. Common arguments you're likely to need are:
208+
:param config:
209+
The config dictionary is passed to the underlying database adapter,
210+
asyncpg. Common arguments you're likely to need are:
208211
209-
* host
210-
* port
211-
* user
212-
* password
213-
* database
212+
* host
213+
* port
214+
* user
215+
* password
216+
* database
214217
215-
For example, ``{'host': 'localhost', 'port': 5432}``.
218+
For example, ``{'host': 'localhost', 'port': 5432}``.
216219
217-
To see all available options:
220+
To see all available options:
218221
219-
* https://magicstack.github.io/asyncpg/current/api/index.html#connection
222+
* https://magicstack.github.io/asyncpg/current/api/index.html#connection
223+
224+
:param extensions:
225+
When the engine starts, it will try and create these extensions
226+
in Postgres.
220227
221228
""" # noqa: E501
222229
self.config = config
230+
self.extensions = extensions
223231
self.pool: t.Optional[Pool] = None
224232
database_name = config.get("database", "Unknown")
225233
self.transaction_connection = contextvars.ContextVar(
@@ -241,20 +249,14 @@ def _parse_raw_version_string(version_string: str) -> float:
241249
version = float(f"{major}.{minor}")
242250
return version
243251

244-
def get_version(self) -> float:
252+
async def get_version(self) -> float:
245253
"""
246254
Returns the version of Postgres being run.
247255
"""
248-
loop = asyncio.new_event_loop()
249-
250-
with ThreadPoolExecutor(max_workers=1) as executor:
251-
future = executor.submit(
252-
loop.run_until_complete,
253-
self._run_in_new_connection("SHOW server_version"),
254-
)
255-
256256
try:
257-
response: t.Sequence[t.Dict] = future.result() # type: ignore
257+
response: t.Sequence[t.Dict] = await self._run_in_new_connection(
258+
"SHOW server_version"
259+
)
258260
except ConnectionRefusedError as exception:
259261
# Suppressing the exception, otherwise importing piccolo_conf.py
260262
# containing an engine will raise an ImportError.
@@ -267,26 +269,20 @@ def get_version(self) -> float:
267269
version_string=version_string
268270
)
269271

270-
def prep_database(self):
271-
loop = asyncio.new_event_loop()
272-
273-
with ThreadPoolExecutor(max_workers=1) as executor:
274-
future = executor.submit(
275-
loop.run_until_complete,
276-
self._run_in_new_connection(
277-
'CREATE EXTENSION IF NOT EXISTS "uuid-ossp"'
278-
),
279-
)
280-
281-
try:
282-
future.result()
283-
except InsufficientPrivilegeError:
284-
print(
285-
"Unable to create uuid-ossp extension - UUID columns might "
286-
"not behave as expected. Make sure your database user has "
287-
"permission to create extensions, or add it manually using "
288-
'`CREATE EXTENSION "uuid-ossp";`'
289-
)
272+
async def prep_database(self):
273+
for extension in self.extensions:
274+
try:
275+
await self._run_in_new_connection(
276+
f'CREATE EXTENSION IF NOT EXISTS "{extension}"',
277+
)
278+
except InsufficientPrivilegeError:
279+
print(
280+
f"Unable to create {extension} extension - some "
281+
"functionality may not behave as expected. Make sure your "
282+
"database user has permission to create extensions, or "
283+
"add it manually using "
284+
f'`CREATE EXTENSION "{extension}";`'
285+
)
290286

291287
###########################################################################
292288

piccolo/engine/sqlite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,14 +337,14 @@ def path(self):
337337
def path(self, value: str):
338338
self.connection_kwargs["database"] = value
339339

340-
def get_version(self) -> float:
340+
async def get_version(self) -> float:
341341
"""
342342
Warn if the version of SQLite isn't supported.
343343
"""
344344
major, minor, _ = sqlite3.sqlite_version_info
345345
return float(f"{major}.{minor}")
346346

347-
def prep_database(self):
347+
async def prep_database(self):
348348
pass
349349

350350
###########################################################################

0 commit comments

Comments
 (0)