Skip to content

Commit 328f6cd

Browse files
committed
adding support for altering foreign key constraints
1 parent d72530c commit 328f6cd

File tree

18 files changed

+251
-128
lines changed

18 files changed

+251
-128
lines changed

piccolo/columns/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
ForeignKey,
1010
UUID,
1111
)
12-
from .base import Column, Selectable # noqa
12+
from .base import Column, Selectable, OnDelete # noqa
1313
from .combination import And, Or, Where # noqa

piccolo/columns/base.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
from abc import ABCMeta, abstractmethod
33
from dataclasses import dataclass, field
4+
from enum import Enum
45
import typing as t
56

67
from piccolo.columns.operators import (
@@ -28,6 +29,24 @@
2829
from .column_types import ForeignKey # noqa
2930

3031

32+
class OnDelete(str, Enum):
33+
cascade = "CASCADE"
34+
restrict = "RESTRICT"
35+
no_action = "NO ACTION"
36+
set_null = "SET NULL"
37+
set_default = "SET DEFAULT"
38+
39+
40+
OnUpdate = OnDelete
41+
42+
43+
@dataclass
44+
class ForeignKeyMeta:
45+
references: t.Type[Table]
46+
on_delete: OnDelete
47+
proxy_columns: t.List[Column] = field(default_factory=list)
48+
49+
3150
@dataclass
3251
class ColumnMeta:
3352
"""
@@ -46,7 +65,7 @@ class ColumnMeta:
4665

4766
# Set by the Table Metaclass:
4867
_name: t.Optional[str] = None
49-
_table: t.Optional[Table] = None
68+
_table: t.Optional[t.Type[Table]] = None
5069

5170
# Used by Foreign Keys:
5271
call_chain: t.List["ForeignKey"] = field(default_factory=lambda: [])
@@ -219,10 +238,13 @@ def querystring(self) -> QueryString:
219238
if not self._meta.null:
220239
query += " NOT NULL"
221240

222-
foreign_key_meta = getattr(self, "_foreign_key_meta", None)
241+
foreign_key_meta: t.Optional[ForeignKeyMeta] = getattr(
242+
self, "_foreign_key_meta", None
243+
)
223244
if foreign_key_meta:
224-
references = foreign_key_meta.references
225-
query += f" REFERENCES {references._meta.tablename}"
245+
tablename = foreign_key_meta.references._meta.tablename
246+
on_delete = foreign_key_meta.on_delete.value
247+
query += f" REFERENCES {tablename} ON DELETE {on_delete}"
226248

227249
return QueryString(query)
228250

piccolo/columns/column_types.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
import copy
33
from dataclasses import dataclass, field
44
from datetime import datetime
5-
from enum import Enum
65
import typing as t
76
import uuid
87

9-
from piccolo.columns.base import Column
8+
from piccolo.columns.base import Column, ForeignKeyMeta, OnDelete
109
from piccolo.querystring import Unquoted
1110

1211
if t.TYPE_CHECKING:
@@ -126,18 +125,6 @@ def __init__(self, default: bool = False, **kwargs) -> None:
126125
super().__init__(**kwargs)
127126

128127

129-
@dataclass
130-
class ForeignKeyMeta:
131-
references: t.Type[Table]
132-
proxy_columns: t.List[Column] = field(default_factory=list)
133-
134-
135-
class OnDelete(str, Enum):
136-
cascade = "CASCADE"
137-
restrict = "RESTRICT"
138-
no_action = "NO ACTION"
139-
140-
141128
class ForeignKey(Integer):
142129
"""
143130
Returns an integer, representing the referenced row's ID.
@@ -169,10 +156,17 @@ class ForeignKey(Integer):
169156

170157
column_type = "INTEGER"
171158

172-
def __init__(self, references: t.Type[Table], **kwargs) -> None:
159+
def __init__(
160+
self,
161+
references: t.Type[Table],
162+
on_delete: OnDelete = OnDelete.cascade,
163+
**kwargs,
164+
) -> None:
173165
kwargs.update({"references": references})
174166
super().__init__(**kwargs)
175-
self._foreign_key_meta = ForeignKeyMeta(references=references)
167+
self._foreign_key_meta = ForeignKeyMeta(
168+
references=references, on_delete=on_delete
169+
)
176170

177171
# Allow columns on the referenced table to be accessed via auto
178172
# completion.

piccolo/columns/combination.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, first: Combinable, second: Combinable) -> None:
2828
self.second = second
2929

3030
@property
31-
def querystring(self):
31+
def querystring(self) -> QueryString:
3232
return QueryString(
3333
"({} " + self.operator + " {})",
3434
self.first.querystring,

piccolo/engine/postgres.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ async def __anext__(self):
4949
async def __aenter__(self):
5050
self._transaction = self.connection.transaction()
5151
await self._transaction.start()
52-
template, template_args = self.query.querystring.compile_string()
52+
querystring = self.query.querystring[0]
53+
template, template_args = querystring.compile_string()
5354

5455
self._cursor = await self.connection.cursor(template, *template_args)
5556
return self
@@ -84,10 +85,11 @@ def add(self, *query: Query):
8485
async def _run_queries(self, connection):
8586
async with connection.transaction():
8687
for query in self.queries:
87-
_query, args = query.querystring.compile_string(
88-
engine_type=self.engine.engine_type
89-
)
90-
await connection.execute(_query, *args)
88+
for querystring in query.querystring:
89+
_query, args = querystring.compile_string(
90+
engine_type=self.engine.engine_type
91+
)
92+
await connection.execute(_query, *args)
9193

9294
self.queries = []
9395

piccolo/engine/sqlite.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ async def __anext__(self):
4141
return response
4242

4343
async def __aenter__(self):
44-
template, template_args = self.query.querystring.compile_string()
44+
querystring = self.query.querystring[0]
45+
template, template_args = querystring.compile_string()
4546

4647
self._cursor = await self.connection.execute(template, *template_args)
4748
return self

piccolo/query/base.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,20 @@ async def run(self, in_pool=True):
9090
f"Table {self.table._meta.tablename} has no db defined in _meta"
9191
)
9292

93-
results = await engine.run_querystring(
94-
self.querystring, in_pool=in_pool
95-
)
96-
97-
return await self._process_results(results)
93+
if len(self.querystring) == 1:
94+
results = await engine.run_querystring(
95+
self.querystring[0], in_pool=in_pool
96+
)
97+
return await self._process_results(results)
98+
else:
99+
responses = []
100+
# TODO - run in a transaction
101+
for querystring in self.querystring:
102+
results = await engine.run_querystring(
103+
querystring, in_pool=in_pool
104+
)
105+
responses.append(await self._process_results(results))
106+
return responses
98107

99108
def run_sync(self, *args, **kwargs):
100109
"""
@@ -112,19 +121,19 @@ async def response_handler(self, response):
112121
###########################################################################
113122

114123
@property
115-
def sqlite_querystring(self) -> QueryString:
124+
def sqlite_querystring(self) -> t.Sequence[QueryString]:
116125
raise NotImplementedError
117126

118127
@property
119-
def postgres_querystring(self) -> QueryString:
128+
def postgres_querystring(self) -> t.Sequence[QueryString]:
120129
raise NotImplementedError
121130

122131
@property
123-
def default_querystring(self) -> QueryString:
132+
def default_querystring(self) -> t.Sequence[QueryString]:
124133
raise NotImplementedError
125134

126135
@property
127-
def querystring(self) -> QueryString:
136+
def querystring(self) -> t.Sequence[QueryString]:
128137
"""
129138
Calls the correct underlying method, depending on the current engine.
130139
"""
@@ -147,4 +156,4 @@ def querystring(self) -> QueryString:
147156
###########################################################################
148157

149158
def __str__(self) -> str:
150-
return self.querystring.__str__()
159+
return "; ".join([i.__str__() for i in self.querystring])

0 commit comments

Comments
 (0)