11from __future__ import annotations
22from dataclasses import dataclass , field
33import inspect
4+ from piccolo .query .methods import alter
45import typing as t
56
67from piccolo .columns import Column , column_types
@@ -252,6 +253,8 @@ def alter_column(
252253 column_name : str ,
253254 params : t .Dict [str , t .Any ],
254255 old_params : t .Dict [str , t .Any ],
256+ column_class : t .Optional [t .Type [Column ]] = None ,
257+ old_column_class : t .Optional [t .Type [Column ]] = None ,
255258 ):
256259 """
257260 All possible alterations aren't currently supported.
@@ -263,6 +266,8 @@ def alter_column(
263266 column_name = column_name ,
264267 params = params ,
265268 old_params = old_params ,
269+ column_class = column_class ,
270+ old_column_class = old_column_class ,
266271 )
267272 )
268273
@@ -328,9 +333,34 @@ async def _run_alter_columns(self, backwards=False):
328333 _Table : t .Type [Table ] = type (table_class_name , (Table ,), {})
329334 _Table ._meta .tablename = alter_columns [0 ].tablename
330335
331- for column in alter_columns :
332- params = column .old_params if backwards else column .params
333- column_name = column .column_name
336+ for alter_column in alter_columns :
337+ params = (
338+ alter_column .old_params
339+ if backwards
340+ else alter_column .params
341+ )
342+ column_name = alter_column .column_name
343+
344+ # Change the column type if possible
345+ column_class = alter_column .column_class
346+ old_column_class = alter_column .old_column_class
347+ if (old_column_class is not None ) and (
348+ column_class is not None
349+ ):
350+ if old_column_class != column_class :
351+ old_column = old_column_class (
352+ ** alter_column .old_params
353+ )
354+ old_column ._meta ._table = _Table
355+ old_column ._meta ._name = alter_column .column_name
356+
357+ new_column = column_class (** alter_column .params )
358+ new_column ._meta ._table = _Table
359+ new_column ._meta ._name = alter_column .column_name
360+
361+ await _Table .alter ().set_column_type (
362+ old_column = old_column , new_column = new_column
363+ )
334364
335365 null = params .get ("null" )
336366 if null is not None :
@@ -383,7 +413,7 @@ async def _run_alter_columns(self, backwards=False):
383413 digits = params .get ("digits" , ...)
384414 if digits is not ...:
385415 await _Table .alter ().set_digits (
386- column = column .column_name , digits = digits ,
416+ column = alter_column .column_name , digits = digits ,
387417 ).run ()
388418
389419 async def _run_drop_tables (self , backwards = False ):
0 commit comments