Skip to content

Commit 0747ef8

Browse files
committed
added ColumnComparison
1 parent 98de5b6 commit 0747ef8

File tree

1 file changed

+35
-7
lines changed

1 file changed

+35
-7
lines changed

piccolo/apps/migrations/auto/diffable_table.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,28 @@ def __eq__(self, value: TableDelta) -> bool: # type: ignore
3636
return True
3737

3838

39+
@dataclass
40+
class ColumnComparison:
41+
"""
42+
As Column overrides it's `__eq__` method, to allow it to be used in the
43+
`where` clause of a query, we need to wrap `Column` if we want to compare
44+
them.
45+
"""
46+
47+
column: Column
48+
49+
def __hash__(self) -> int:
50+
return self.column.__hash__()
51+
52+
def __eq__(self, value) -> bool:
53+
if isinstance(value, ColumnComparison):
54+
return (
55+
serialise_params(self.column._meta.params).params
56+
== serialise_params(value.column._meta.params).params
57+
) and (self.column._meta.name == value.column._meta.name)
58+
return False
59+
60+
3961
@dataclass
4062
class DiffableTable:
4163
"""
@@ -67,21 +89,27 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
6789
add_columns = [
6890
AddColumn(
6991
table_class_name=self.class_name,
70-
column_name=i._meta.name,
71-
column_class_name=i.__class__.__name__,
72-
column_class=i.__class__,
73-
params=i._meta.params,
92+
column_name=i.column._meta.name,
93+
column_class_name=i.column.__class__.__name__,
94+
column_class=i.column.__class__,
95+
params=i.column._meta.params,
96+
)
97+
for i in (
98+
{ColumnComparison(column=column) for column in self.columns}
99+
- {ColumnComparison(column=column) for column in value.columns}
74100
)
75-
for i in (set(self.columns) - set(value.columns))
76101
]
77102

78103
drop_columns = [
79104
DropColumn(
80105
table_class_name=self.class_name,
81-
column_name=i._meta.name,
106+
column_name=i.column._meta.name,
82107
tablename=value.tablename,
83108
)
84-
for i in (set(value.columns) - set(self.columns))
109+
for i in (
110+
{ColumnComparison(column=column) for column in value.columns}
111+
- {ColumnComparison(column=column) for column in self.columns}
112+
)
85113
]
86114

87115
alter_columns: t.List[AlterColumn] = []

0 commit comments

Comments
 (0)