Skip to content

Commit 30afec0

Browse files
committed
added set_default method to Alter
1 parent 2c8a531 commit 30afec0

File tree

3 files changed

+76
-18
lines changed

3 files changed

+76
-18
lines changed

piccolo/columns/base.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,37 @@ def get_select_string(self, engine_type: str, just_alias=False) -> str:
245245
"""
246246
return self._meta.get_full_name(just_alias=just_alias)
247247

248+
def get_sql_value(self, value: t.Any) -> str:
249+
"""
250+
When using DDL statements, we can't parameterise the values. An example
251+
is when setting the default for a column. So we have to convert from
252+
the Python type to a string representation which we can include in our
253+
DDL statements.
254+
255+
:param value:
256+
The Python value to convert to a string usable in a DDL statement
257+
e.g. 1.
258+
:returns:
259+
The string usable in the DDL statement e.g. '1'.
260+
261+
"""
262+
if isinstance(value, Default):
263+
output = getattr(value, self._meta.engine_type)
264+
elif value is None:
265+
output = "null"
266+
elif isinstance(value, (float, decimal.Decimal)):
267+
output = str(value)
268+
elif isinstance(value, str):
269+
output = f"'{value}'"
270+
elif isinstance(value, bool):
271+
output = str(value).lower()
272+
elif isinstance(value, datetime.datetime):
273+
output = f"'{value.isoformat().replace('T', '')}'"
274+
else:
275+
output = value
276+
277+
return output
278+
248279
@property
249280
def querystring(self) -> QueryString:
250281
"""
@@ -278,22 +309,8 @@ def querystring(self) -> QueryString:
278309

279310
if not self._meta.primary:
280311
default = self.get_default_value()
281-
if isinstance(default, Default):
282-
default_value = getattr(default, self._meta.engine_type)
283-
elif default is None:
284-
default_value = "null"
285-
elif isinstance(default, (float, decimal.Decimal)):
286-
default_value = str(default)
287-
elif isinstance(default, str):
288-
default_value = f"'{default}'"
289-
elif isinstance(default, bool):
290-
default_value = str(default).lower()
291-
elif isinstance(default, datetime.datetime):
292-
default_value = f"'{default.isoformat().replace('T', '')}'"
293-
else:
294-
default_value = default
295-
296-
query += f" DEFAULT {default_value}"
312+
sql_value = self.get_sql_value(value=default)
313+
query += f" DEFAULT {sql_value}"
297314

298315
return QueryString(query)
299316

piccolo/query/methods/alter.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,21 @@ def querystring(self) -> QueryString:
9191
return QueryString("DROP DEFAULT {}", self.column_name)
9292

9393

94+
@dataclass
95+
class SetDefault(AlterColumnStatement):
96+
__slots__ = ("column", "value")
97+
98+
column: Column
99+
value: t.Any
100+
101+
@property
102+
def querystring(self) -> QueryString:
103+
sql_value = self.column.get_sql_value(self.value)
104+
return QueryString(
105+
f"ALTER COLUMN {self.column_name} SET DEFAULT {sql_value}"
106+
)
107+
108+
94109
@dataclass
95110
class Unique(AlterColumnStatement):
96111
__slots__ = ("boolean",)
@@ -224,9 +239,10 @@ class Alter(Query):
224239
"_null",
225240
"_rename_columns",
226241
"_rename_table",
242+
"_set_default",
243+
"_set_digits",
227244
"_set_length",
228245
"_unique",
229-
"_set_digits",
230246
)
231247

232248
def __init__(self, table: t.Type[Table]):
@@ -240,9 +256,10 @@ def __init__(self, table: t.Type[Table]):
240256
self._null: t.List[Null] = []
241257
self._rename_columns: t.List[RenameColumn] = []
242258
self._rename_table: t.List[RenameTable] = []
259+
self._set_default: t.List[SetDefault] = []
260+
self._set_digits: t.List[SetDigits] = []
243261
self._set_length: t.List[SetLength] = []
244262
self._unique: t.List[Unique] = []
245-
self._set_digits: t.List[SetDigits] = []
246263

247264
def add_column(self, name: str, column: Column) -> Alter:
248265
"""
@@ -284,6 +301,15 @@ def rename_column(
284301
self._rename_columns.append(RenameColumn(column, new_name))
285302
return self
286303

304+
def set_default(self, column: Column, value: t.Any) -> Alter:
305+
"""
306+
Set the default for a column.
307+
308+
Band.alter().set_default(Band.popularity, 0)
309+
"""
310+
self._set_default.append(SetDefault(column=column, value=value))
311+
return self
312+
287313
def set_null(
288314
self, column: t.Union[str, Column], boolean: bool = True
289315
) -> Alter:
@@ -463,6 +489,7 @@ def querystrings(self) -> t.Sequence[QueryString]:
463489
self._unique,
464490
self._null,
465491
self._set_length,
492+
self._set_default,
466493
self._set_digits,
467494
)
468495
]

tests/table/test_alter.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,20 @@ def test_null(self):
141141
pass
142142

143143

144+
@postgres_only
145+
class TestSetDefault(DBTestCase):
146+
def test_set_default(self):
147+
Manager.alter().set_default(Manager.name, "Pending").run_sync()
148+
149+
# Bypassing the ORM to make sure the database default is present.
150+
Band.raw(
151+
"INSERT INTO manager (id, name) VALUES (DEFAULT, DEFAULT);"
152+
).run_sync()
153+
154+
manager = Manager.objects().first().run_sync()
155+
self.assertTrue(manager.name == "Pending")
156+
157+
144158
###############################################################################
145159

146160

0 commit comments

Comments
 (0)