@@ -36,6 +36,25 @@ def __eq__(self, value: TableDelta) -> bool: # type: ignore
3636 return True
3737
3838
39+ @dataclass
40+ class ColumnComparison :
41+ """
42+ As Column overrides it's `__eq__` method, to allow it to be used in the
43+ `where` clause of a query, we need to wrap `Column` if we want to compare
44+ them.
45+ """
46+
47+ column : Column
48+
49+ def __hash__ (self ) -> int :
50+ return self .column .__hash__ ()
51+
52+ def __eq__ (self , value ) -> bool :
53+ if isinstance (value , ColumnComparison ):
54+ return self .column ._meta .name == value .column ._meta .name
55+ return False
56+
57+
3958@dataclass
4059class DiffableTable :
4160 """
@@ -67,21 +86,27 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
6786 add_columns = [
6887 AddColumn (
6988 table_class_name = self .class_name ,
70- column_name = i ._meta .name ,
71- column_class_name = i .__class__ .__name__ ,
72- column_class = i .__class__ ,
73- params = i ._meta .params ,
89+ column_name = i .column ._meta .name ,
90+ column_class_name = i .column .__class__ .__name__ ,
91+ column_class = i .column .__class__ ,
92+ params = i .column ._meta .params ,
93+ )
94+ for i in (
95+ {ColumnComparison (column = column ) for column in self .columns }
96+ - {ColumnComparison (column = column ) for column in value .columns }
7497 )
75- for i in (set (self .columns ) - set (value .columns ))
7698 ]
7799
78100 drop_columns = [
79101 DropColumn (
80102 table_class_name = self .class_name ,
81- column_name = i ._meta .name ,
103+ column_name = i .column . _meta .name ,
82104 tablename = value .tablename ,
83105 )
84- for i in (set (value .columns ) - set (self .columns ))
106+ for i in (
107+ {ColumnComparison (column = column ) for column in value .columns }
108+ - {ColumnComparison (column = column ) for column in self .columns }
109+ )
85110 ]
86111
87112 alter_columns : t .List [AlterColumn ] = []
0 commit comments