Skip to content

Commit 4f4b443

Browse files
committed
improved enum serialisation
When an enum instance (e.g. `Colours.red`) was passed as a `default` argument to a Column, migrations wouldn't serialise properly.
1 parent 467a0b6 commit 4f4b443

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

piccolo/apps/migrations/auto/serialisation.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,20 @@ def __repr__(self):
7373
return f"{self.instance.__class__.__name__}({args})"
7474

7575

76+
@dataclass
77+
class SerialisedEnumInstance:
78+
instance: Enum
79+
80+
def __hash__(self):
81+
return hash(self.__repr__())
82+
83+
def __eq__(self, other):
84+
return self.__hash__() == other.__hash__()
85+
86+
def __repr__(self):
87+
return f"{self.instance.__class__.__name__}.{self.instance.name}"
88+
89+
7690
@dataclass
7791
class SerialisedTableType:
7892
table_type: t.Type[Table]
@@ -241,14 +255,29 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams:
241255
extra_imports.append(Import(module="decimal", target="Decimal"))
242256
continue
243257

244-
# Enums
258+
# Enum instances
245259
if isinstance(value, Enum):
246-
# Enums already have a good __repr__.
247-
extra_imports.append(
248-
Import(
249-
module=value.__module__, target=value.__class__.__name__,
260+
if value.__module__.startswith("piccolo"):
261+
# It's an Enum defined within Piccolo, so we can safely import
262+
# it.
263+
params[key] = SerialisedEnumInstance(instance=value)
264+
extra_imports.append(
265+
Import(
266+
module=value.__module__,
267+
target=value.__class__.__name__,
268+
)
250269
)
251-
)
270+
else:
271+
# It's a user defined Enum, so we'll insert the raw value.
272+
enum_serialised_params: SerialisedParams = serialise_params(
273+
params={key: value.value}
274+
)
275+
params[key] = enum_serialised_params.params[key]
276+
extra_imports.extend(enum_serialised_params.extra_imports)
277+
extra_definitions.extend(
278+
enum_serialised_params.extra_definitions
279+
)
280+
252281
continue
253282

254283
# Enum types
@@ -329,5 +358,7 @@ def deserialise_params(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
329358
params[key] = value.table_type
330359
elif isinstance(value, SerialisedEnumType):
331360
params[key] = value.enum_type
361+
elif isinstance(value, SerialisedEnumInstance):
362+
params[key] = value.instance
332363

333364
return params

0 commit comments

Comments
 (0)