Skip to content

Commit 255c5c1

Browse files
authored
Merge pull request piccolo-orm#105 from piccolo-orm/choices
Choices
2 parents c85b202 + 94b2241 commit 255c5c1

File tree

27 files changed

+465
-38
lines changed

27 files changed

+465
-38
lines changed

docs/src/piccolo/schema/advanced.rst

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,59 @@ use mixins to reduce the amount of repetition.
8383
8484
class Manager(FavouriteMixin, Table):
8585
name = Varchar()
86+
87+
-------------------------------------------------------------------------------
88+
89+
Choices
90+
-------
91+
92+
You can specify choices for a column, using Python's ``Enum`` support.
93+
94+
.. code-block:: python
95+
96+
from enum import Enum
97+
98+
from piccolo.columns import Varchar
99+
from piccolo.table import Table
100+
101+
102+
class Shirt(Table):
103+
class Size(str, Enum):
104+
small = 's'
105+
medium = 'm'
106+
large = 'l'
107+
108+
size = Varchar(length=1, choices=Size)
109+
110+
We can then use the ``Enum`` in our queries.
111+
112+
.. code-block:: python
113+
114+
>>> Shirt(size=Shirt.Size.large).save().run_sync()
115+
116+
>>> Shirt.select().run_sync()
117+
[{'id': 1, 'size': 'l'}]
118+
119+
Note how the value stored in the database is the ``Enum`` value (in this case ``'l'``).
120+
121+
You can also use the ``Enum`` in ``where`` clauses, and in most other situations
122+
where a query requires a value.
123+
124+
.. code-block:: python
125+
126+
>>> Shirt.insert(
127+
>>> Shirt(size=Shirt.Size.small),
128+
>>> Shirt(size=Shirt.Size.medium)
129+
>>> ).run_sync()
130+
131+
>>> Shirt.select().where(Shirt.size == Shirt.Size.small).run_sync()
132+
[{'id': 1, 'size': 's'}]
133+
134+
Advantages
135+
~~~~~~~~~~
136+
137+
By using choices, you get the following benefits:
138+
139+
* Signalling to other programmers what values are acceptable for the column.
140+
* Improved storage efficiency (we can store ``'l'`` instead of ``'large'``).
141+
* Piccolo admin support (in progress)

piccolo/apps/migrations/auto/serialisation.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,22 @@ def __lt__(self, other):
9494
return repr(self) < repr(other)
9595

9696

97+
@dataclass
98+
class SerialisedEnumType:
99+
enum_type: t.Type[Enum]
100+
101+
def __hash__(self):
102+
return hash(self.__repr__())
103+
104+
def __eq__(self, other):
105+
return self.__hash__() == other.__hash__()
106+
107+
def __repr__(self):
108+
class_name = self.enum_type.__name__
109+
params = {i.name: i.value for i in self.enum_type}
110+
return f"Enum('{class_name}', {params})"
111+
112+
97113
@dataclass
98114
class SerialisedCallable:
99115
callable_: t.Callable
@@ -162,10 +178,7 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams:
162178
for key, value in params.items():
163179

164180
# Builtins, such as str, list and dict.
165-
if (
166-
hasattr(value, "__module__")
167-
and value.__module__ == builtins.__name__
168-
):
181+
if inspect.getmodule(value) == builtins:
169182
params[key] = SerialisedBuiltin(builtin=value)
170183
continue
171184

@@ -238,6 +251,20 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams:
238251
)
239252
continue
240253

254+
# Enum types
255+
if inspect.isclass(value) and issubclass(value, Enum):
256+
params[key] = SerialisedEnumType(enum_type=value)
257+
extra_imports.append(Import(module="enum", target="Enum"))
258+
for member in value:
259+
type_ = type(member.value)
260+
module = inspect.getmodule(type_)
261+
262+
if module and module != builtins:
263+
module_name = module.__name__
264+
extra_imports.append(
265+
Import(module=module_name, target=type_.__name__)
266+
)
267+
241268
# Functions
242269
if inspect.isfunction(value):
243270
if value.__name__ == "<lambda>":
@@ -300,5 +327,7 @@ def deserialise_params(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
300327
params[key] = value.callable_
301328
elif isinstance(value, SerialisedTableType):
302329
params[key] = value.table_type
330+
elif isinstance(value, SerialisedEnumType):
331+
params[key] = value.enum_type
303332

304333
return params

piccolo/apps/migrations/commands/new.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,11 @@ async def _create_new_migration(app_config: AppConfig, auto=False) -> None:
7575
chain(*[i.statements for i in alter_statements])
7676
)
7777
extra_imports = sorted(
78-
list(set(chain(*[i.extra_imports for i in alter_statements])))
78+
list(set(chain(*[i.extra_imports for i in alter_statements]))),
79+
key=lambda x: x.__repr__(),
7980
)
8081
extra_definitions = sorted(
81-
list(set(chain(*[i.extra_definitions for i in alter_statements])))
82+
list(set(chain(*[i.extra_definitions for i in alter_statements]))),
8283
)
8384

8485
if sum([len(i.statements) for i in alter_statements]) == 0:

piccolo/columns/base.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
NotLike,
2626
)
2727
from piccolo.columns.combination import Where
28+
from piccolo.columns.choices import Choice
2829
from piccolo.columns.defaults.base import Default
2930
from piccolo.columns.reference import LazyTableReference
3031
from piccolo.columns.indexes import IndexMethod
@@ -124,6 +125,7 @@ class ColumnMeta:
124125
index_method: IndexMethod = IndexMethod.btree
125126
required: bool = False
126127
help_text: t.Optional[str] = None
128+
choices: t.Optional[t.Type[Enum]] = None
127129

128130
# Used for representing the table in migrations and the playground.
129131
params: t.Dict[str, t.Any] = field(default_factory=dict)
@@ -164,6 +166,30 @@ def engine_type(self) -> str:
164166
else:
165167
raise ValueError("The table has no engine defined.")
166168

169+
def get_choices_dict(self) -> t.Optional[t.Dict[str, t.Any]]:
170+
"""
171+
Return the choices Enum as a dict. It maps the attribute name to a
172+
dict containing the display name, and value.
173+
"""
174+
if self.choices is None:
175+
return None
176+
else:
177+
output = {}
178+
for element in self.choices:
179+
if isinstance(element.value, Choice):
180+
display_name = element.value.display_name
181+
value = element.value.value
182+
else:
183+
display_name = element.name.replace("_", " ").title()
184+
value = element.value
185+
186+
output[element.name] = {
187+
"display_name": display_name,
188+
"value": value,
189+
}
190+
191+
return output
192+
167193
def get_full_name(self, just_alias=False) -> str:
168194
"""
169195
Returns the full column name, taking into account joins.
@@ -183,6 +209,8 @@ def get_full_name(self, just_alias=False) -> str:
183209
else:
184210
return f'{alias} AS "{column_name}"'
185211

212+
###########################################################################
213+
186214
def copy(self) -> ColumnMeta:
187215
kwargs = self.__dict__.copy()
188216
kwargs.update(
@@ -266,11 +294,14 @@ def __init__(
266294
index_method: IndexMethod = IndexMethod.btree,
267295
required: bool = False,
268296
help_text: t.Optional[str] = None,
297+
choices: t.Optional[t.Type[Enum]] = None,
269298
**kwargs,
270299
) -> None:
271300
# Used for migrations.
272301
# We deliberately omit 'required', and 'help_text' as they don't effect
273302
# the actual schema.
303+
# 'choices' isn't used directly in the schema, but may be important
304+
# for data migrations.
274305
kwargs.update(
275306
{
276307
"null": null,
@@ -279,6 +310,7 @@ def __init__(
279310
"unique": unique,
280311
"index": index,
281312
"index_method": index_method,
313+
"choices": choices,
282314
}
283315
)
284316

@@ -288,6 +320,9 @@ def __init__(
288320
"not nullable."
289321
)
290322

323+
if choices is not None:
324+
self._validate_choices(choices, allowed_type=self.value_type)
325+
291326
self._meta = ColumnMeta(
292327
null=null,
293328
primary=primary,
@@ -298,6 +333,7 @@ def __init__(
298333
params=kwargs,
299334
required=required,
300335
help_text=help_text,
336+
choices=choices,
301337
)
302338

303339
self.alias: t.Optional[str] = None
@@ -324,12 +360,37 @@ def _validate_default(
324360
elif callable(default):
325361
self._validated = True
326362
return True
363+
elif (
364+
isinstance(default, Enum) and type(default.value) in allowed_types
365+
):
366+
self._validated = True
367+
return True
327368
else:
328369
raise ValueError(
329370
f"The default {default} isn't one of the permitted types - "
330371
f"{allowed_types}"
331372
)
332373

374+
def _validate_choices(
375+
self, choices: t.Type[Enum], allowed_type: t.Type[t.Any]
376+
) -> bool:
377+
"""
378+
Make sure the choices value has values of the allowed_type.
379+
"""
380+
for element in choices:
381+
if isinstance(element.value, allowed_type):
382+
continue
383+
elif isinstance(element.value, Choice) and isinstance(
384+
element.value.value, allowed_type
385+
):
386+
continue
387+
else:
388+
raise ValueError(
389+
f"{element.name} doesn't have the correct type"
390+
)
391+
392+
return True
393+
333394
def is_in(self, values: t.List[t.Any]) -> Where:
334395
if len(values) == 0:
335396
raise ValueError(

piccolo/columns/choices.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from __future__ import annotations
2+
from dataclasses import dataclass
3+
import typing as t
4+
5+
6+
@dataclass
7+
class Choice:
8+
"""
9+
When defining enums for ``Column`` choices, they can either be defined
10+
like:
11+
12+
.. code-block:: python
13+
14+
class Title(Enum):
15+
mr = 1
16+
mrs = 2
17+
18+
If using Piccolo Admin, the values shown will be ``Mr`` and ``Mrs``. If you
19+
want more control, you can use ``Choice`` for the value instead.
20+
21+
.. code-block:: python
22+
23+
class Title(Enum):
24+
mr = Choice(value=1, display_name="Mr.")
25+
mrs = Choice(value=1, display_name="Mrs.")
26+
27+
Now the values shown will be ``Mr.`` and ``Mrs.``.
28+
29+
"""
30+
31+
value: t.Any
32+
display_name: str

piccolo/columns/column_types.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import copy
44
import decimal
5+
from enum import Enum
56
import typing as t
67
import uuid
78
from datetime import date, datetime, time, timedelta
@@ -165,7 +166,7 @@ class Band(Table):
165166
def __init__(
166167
self,
167168
length: int = 255,
168-
default: t.Union[str, t.Callable[[], str], None] = "",
169+
default: t.Union[str, Enum, t.Callable[[], str], None] = "",
169170
**kwargs,
170171
) -> None:
171172
self._validate_default(default, (str, None))
@@ -251,7 +252,9 @@ class Band(Table):
251252
concat_delegate: ConcatDelegate = ConcatDelegate()
252253

253254
def __init__(
254-
self, default: t.Union[str, None, t.Callable[[], str]] = "", **kwargs
255+
self,
256+
default: t.Union[str, Enum, None, t.Callable[[], str]] = "",
257+
**kwargs,
255258
) -> None:
256259
self._validate_default(default, (str, None))
257260
self.default = default
@@ -333,7 +336,9 @@ class Band(Table):
333336
math_delegate = MathDelegate()
334337

335338
def __init__(
336-
self, default: t.Union[int, t.Callable[[], int], None] = 0, **kwargs
339+
self,
340+
default: t.Union[int, Enum, t.Callable[[], int], None] = 0,
341+
**kwargs,
337342
) -> None:
338343
self._validate_default(default, (int, None))
339344
self.default = default
@@ -771,7 +776,7 @@ class Band(Table):
771776

772777
def __init__(
773778
self,
774-
default: t.Union[bool, t.Callable[[], bool], None] = False,
779+
default: t.Union[bool, Enum, t.Callable[[], bool], None] = False,
775780
**kwargs,
776781
) -> None:
777782
self._validate_default(default, (bool, None))
@@ -841,7 +846,7 @@ def __init__(
841846
self,
842847
digits: t.Optional[t.Tuple[int, int]] = None,
843848
default: t.Union[
844-
decimal.Decimal, t.Callable[[], decimal.Decimal], None
849+
decimal.Decimal, Enum, t.Callable[[], decimal.Decimal], None
845850
] = decimal.Decimal(0.0),
846851
**kwargs,
847852
) -> None:
@@ -897,7 +902,7 @@ class Concert(Table):
897902

898903
def __init__(
899904
self,
900-
default: t.Union[float, t.Callable[[], float], None] = 0.0,
905+
default: t.Union[float, Enum, t.Callable[[], float], None] = 0.0,
901906
**kwargs,
902907
) -> None:
903908
self._validate_default(default, (float, None))
@@ -1087,7 +1092,7 @@ class Band(Table):
10871092
def __init__(
10881093
self,
10891094
references: t.Union[t.Type[Table], LazyTableReference, str],
1090-
default: t.Union[int, None] = None,
1095+
default: t.Union[int, Enum, None] = None,
10911096
null: bool = True,
10921097
on_delete: OnDelete = OnDelete.cascade,
10931098
on_update: OnUpdate = OnUpdate.cascade,
@@ -1324,6 +1329,7 @@ def __init__(
13241329
default: t.Union[
13251330
bytes,
13261331
bytearray,
1332+
Enum,
13271333
t.Callable[[], bytes],
13281334
t.Callable[[], bytearray],
13291335
None,
@@ -1376,7 +1382,7 @@ class Ticket(Table):
13761382
def __init__(
13771383
self,
13781384
base_column: Column,
1379-
default: t.Union[t.List, t.Callable[[], t.List], None] = list,
1385+
default: t.Union[t.List, Enum, t.Callable[[], t.List], None] = list,
13801386
**kwargs,
13811387
) -> None:
13821388
if isinstance(base_column, ForeignKey):

0 commit comments

Comments
 (0)