@@ -278,7 +278,7 @@ def median_filter(self, half_window, xfield, yfield, inplace=True):
278278 return result
279279
280280 def display_timeline (
281- self , ax , event = True , field = None , method = None , factor = 1 , ** kwargs
281+ self , ax , event = True , field = None , method = None , factor = 1 , colors_mode = "roll" , ** kwargs
282282 ):
283283 """
284284 Plot a timeline of a network.
@@ -289,6 +289,7 @@ def display_timeline(
289289 :param str,array field: yaxis values, if None, segments are used
290290 :param str method: if None, mean values are used
291291 :param float factor: to multiply field
292+ :param str colors_mode: color of lines. "roll" means looping through colors, "y" means color adapt the y values (for matching color plots)
292293 :return: plot mappable
293294 """
294295 self .only_one_network ()
@@ -302,9 +303,10 @@ def display_timeline(
302303 )
303304 line_kw .update (kwargs )
304305 mappables = dict (lines = list ())
306+
305307 if event :
306308 mappables .update (
307- self .event_timeline (ax , field = field , method = method , factor = factor )
309+ self .event_timeline (ax , field = field , method = method , factor = factor , colors_mode = colors_mode )
308310 )
309311 for i , b0 , b1 in self .iter_on ("segment" ):
310312 x = self .time [i ]
@@ -317,14 +319,25 @@ def display_timeline(
317319 y = self [field ][i ] * factor
318320 else :
319321 y = self [field ][i ].mean () * ones (x .shape ) * factor
320- line = ax .plot (x , y , ** line_kw , color = self .COLORS [j % self .NB_COLORS ])[0 ]
322+
323+ if colors_mode == "roll" :
324+ _color = self .get_color (j )
325+ elif colors_mode == "y" :
326+ _color = self .get_color (b0 - 1 )
327+ else :
328+ raise NotImplementedError (f"colors_mode '{ colors_mode } ' not defined" )
329+
330+ line = ax .plot (x , y , ** line_kw , color = _color )[0 ]
321331 mappables ["lines" ].append (line )
322332 j += 1
323333
324334 return mappables
325335
326- def event_timeline (self , ax , field = None , method = None , factor = 1 ):
336+ def event_timeline (self , ax , field = None , method = None , factor = 1 , colors_mode = "roll" ):
337+ """mark events in plot"""
327338 j = 0
339+ events = dict (spliting = [], merging = [])
340+
328341 # TODO : fill mappables dict
329342 y_seg = dict ()
330343 if field is not None and method != "all" :
@@ -337,7 +350,16 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
337350 x = self .time [i ]
338351 if x .shape [0 ] == 0 :
339352 continue
340- event_kw = dict (color = self .COLORS [j % self .NB_COLORS ], ls = "-" , zorder = 1 )
353+
354+ if colors_mode == "roll" :
355+ _color = self .get_color (j )
356+ elif colors_mode == "y" :
357+ _color = self .get_color (b0 - 1 )
358+ else :
359+ raise NotImplementedError (f"colors_mode '{ colors_mode } ' not defined" )
360+
361+ event_kw = dict (color = _color , ls = "-" , zorder = 1 )
362+
341363 i_n , i_p = (
342364 self .next_obs [i .stop - 1 ],
343365 self .previous_obs [i .start ],
@@ -361,7 +383,8 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
361383 )
362384 )
363385 ax .plot ((x [- 1 ], self .time [i_n ]), (y0 , y1 ), ** event_kw )[0 ]
364- ax .plot (x [- 1 ], y0 , color = "k" , marker = "H" , markersize = 10 , zorder = - 1 )[0 ]
386+ events ["merging" ].append ((x [- 1 ], y0 ))
387+
365388 if i_p != - 1 :
366389 seg_previous = self .segment [i_p ]
367390 if field is not None and method == "all" :
@@ -376,8 +399,21 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
376399 )
377400 )
378401 ax .plot ((x [0 ], self .time [i_p ]), (y0 , y1 ), ** event_kw )[0 ]
379- ax .plot (x [0 ], y0 , color = "k" , marker = "*" , markersize = 12 , zorder = - 1 )[0 ]
402+ events ["spliting" ].append ((x [0 ], y0 ))
403+
380404 j += 1
405+
406+ kwargs = dict (color = "k" , zorder = - 1 , linestyle = " " )
407+ if len (events ["spliting" ]) > 0 :
408+ X , Y = list (zip (* events ["spliting" ]))
409+ ref = ax .plot (X , Y , marker = "*" , markersize = 12 , label = "spliting events" , ** kwargs )[0 ]
410+ mappables .setdefault ("events" ,[]).append (ref )
411+
412+ if len (events ["merging" ]) > 0 :
413+ X , Y = list (zip (* events ["merging" ]))
414+ ref = ax .plot (X , Y , marker = "H" , markersize = 10 , label = "merging events" , ** kwargs )[0 ]
415+ mappables .setdefault ("events" ,[]).append (ref )
416+
381417 return mappables
382418
383419 def mean_by_segment (self , y , ** kw ):
0 commit comments