@@ -68,9 +68,27 @@ class NetworkObservations(EddiesObservations):
6868 @property
6969 def elements (self ):
7070 elements = super ().elements
71- elements .extend (["track" , "segment" , "next_obs" , "previous_obs" ])
71+ elements .extend (
72+ [
73+ "track" ,
74+ "segment" ,
75+ "next_obs" ,
76+ "previous_obs" ,
77+ "next_cost" ,
78+ "previous_cost" ,
79+ ]
80+ )
7281 return list (set (elements ))
7382
83+ def astype (self , cls ):
84+ new = cls .new_like (self , self .shape )
85+ print ()
86+ for k in new .obs .dtype .names :
87+ if k in self .obs .dtype .names :
88+ new [k ][:] = self [k ][:]
89+ new .sign_type = self .sign_type
90+ return new
91+
7492 def longer_than (self , nb_day_min = - 1 , nb_day_max = - 1 ):
7593 """
7694 Select network on time duration
@@ -81,7 +99,7 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
8199 if nb_day_max < 0 :
82100 nb_day_max = 1000000000000
83101 mask = zeros (self .shape , dtype = "bool" )
84- for i , b0 , b1 in self .iter_on (self .segment_track_array () ):
102+ for i , b0 , b1 in self .iter_on (self .segment_track_array ):
85103 nb = i .stop - i .start
86104 if nb == 0 :
87105 continue
@@ -115,6 +133,8 @@ def from_split_network(cls, group_dataset, indexs, **kwargs):
115133 translate [index_order ] = arange (index_order .shape [0 ])
116134 network .next_obs [:] = translate [n ]
117135 network .previous_obs [:] = translate [p ]
136+ network .next_cost [:] = indexs ["next_cost" ][index_order ]
137+ network .previous_cost [:] = indexs ["previous_cost" ][index_order ]
118138 return network
119139
120140 def infos (self , label = "" ):
@@ -205,7 +225,7 @@ def position_filter(self, median_half_window, loess_half_window):
205225
206226 def loess_filter (self , half_window , xfield , yfield , inplace = True ):
207227 result = track_loess_filter (
208- half_window , self .obs [xfield ], self .obs [yfield ], self .segment_track_array ()
228+ half_window , self .obs [xfield ], self .obs [yfield ], self .segment_track_array
209229 )
210230 if inplace :
211231 self .obs [yfield ] = result
@@ -214,7 +234,7 @@ def loess_filter(self, half_window, xfield, yfield, inplace=True):
214234
215235 def median_filter (self , half_window , xfield , yfield , inplace = True ):
216236 result = track_median_filter (
217- half_window , self [xfield ], self [yfield ], self .segment_track_array ()
237+ half_window , self [xfield ], self [yfield ], self .segment_track_array
218238 )
219239 if inplace :
220240 self [yfield ][:] = result
@@ -316,18 +336,59 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
316336 j += 1
317337 return mappables
318338
319- def scatter_timeline (self , ax , name , factor = 1 , event = True , ** kwargs ):
339+ def mean_by_segment (self , y , ** kw ):
340+ kw ["dtype" ] = y .dtype
341+ return self .map_segment (lambda x : x .mean (), y , ** kw )
342+
343+ def map_segment (self , method , y , same = True , ** kw ):
344+ if same :
345+ out = empty (y .shape , ** kw )
346+ else :
347+ out = list ()
348+ for i , b0 , b1 in self .iter_on (self .segment_track_array ):
349+ res = method (y [i ])
350+ if same :
351+ out [i ] = res
352+ else :
353+ if isinstance (i , slice ):
354+ if i .start == i .stop :
355+ continue
356+ elif len (i ) == 0 :
357+ continue
358+ out .append (res )
359+ if not same :
360+ out = array (out )
361+ return out
362+
363+ def scatter_timeline (
364+ self ,
365+ ax ,
366+ name ,
367+ factor = 1 ,
368+ event = True ,
369+ yfield = None ,
370+ yfactor = 1 ,
371+ method = None ,
372+ ** kwargs ,
373+ ):
320374 """
321375 Must be call on only one network
322376 """
323377 self .only_one_network ()
378+ y = (self .segment if yfield is None else self [yfield ]) * yfactor
379+ if method == "all" :
380+ pass
381+ else :
382+ y = self .mean_by_segment (y )
324383 mappables = dict ()
325384 if event :
326- mappables .update (self .event_timeline (ax ))
385+ mappables .update (
386+ self .event_timeline (ax , field = yfield , method = method , factor = yfactor )
387+ )
327388 if "c" not in kwargs :
328389 v = self .parse_varname (name )
329390 kwargs ["c" ] = v * factor
330- mappables ["scatter" ] = ax .scatter (self .time , self . segment , ** kwargs )
391+ mappables ["scatter" ] = ax .scatter (self .time , y , ** kwargs )
331392 return mappables
332393
333394 def insert_virtual (self ):
@@ -350,13 +411,14 @@ def extract_event(self, indices):
350411 new .sign_type = self .sign_type
351412 return new
352413
414+ @property
353415 def segment_track_array (self ):
354416 return build_unique_array (self .segment , self .track )
355417
356418 def birth_event (self ):
357419 # FIXME how to manage group 0
358420 indices = list ()
359- for i , _ , _ in self .iter_on (self .segment_track_array () ):
421+ for i , _ , _ in self .iter_on (self .segment_track_array ):
360422 nb = i .stop - i .start
361423 if nb == 0 :
362424 continue
@@ -368,7 +430,7 @@ def birth_event(self):
368430 def death_event (self ):
369431 # FIXME how to manage group 0
370432 indices = list ()
371- for i , _ , _ in self .iter_on (self .segment_track_array () ):
433+ for i , _ , _ in self .iter_on (self .segment_track_array ):
372434 nb = i .stop - i .start
373435 if nb == 0 :
374436 continue
@@ -379,7 +441,7 @@ def death_event(self):
379441
380442 def merging_event (self ):
381443 indices = list ()
382- for i , _ , _ in self .iter_on (self .segment_track_array () ):
444+ for i , _ , _ in self .iter_on (self .segment_track_array ):
383445 nb = i .stop - i .start
384446 if nb == 0 :
385447 continue
@@ -390,7 +452,7 @@ def merging_event(self):
390452
391453 def spliting_event (self ):
392454 indices = list ()
393- for i , _ , _ in self .iter_on (self .segment_track_array () ):
455+ for i , _ , _ in self .iter_on (self .segment_track_array ):
394456 nb = i .stop - i .start
395457 if nb == 0 :
396458 continue
@@ -403,7 +465,7 @@ def fully_connected(self):
403465 self .only_one_network ()
404466 # TODO
405467
406- def plot (self , ax , ref = None , ** kwargs ):
468+ def plot (self , ax , ref = None , color_cycle = None , ** kwargs ):
407469 """
408470 This function will draw path of each trajectory
409471
@@ -412,17 +474,25 @@ def plot(self, ax, ref=None, **kwargs):
412474 :param dict kwargs: keyword arguments for Axes.plot
413475 :return: a list of matplotlib mappables
414476 """
477+ nb_colors = 0
478+ if color_cycle is not None :
479+ kwargs = kwargs .copy ()
480+ nb_colors = len (color_cycle )
415481 mappables = list ()
416482 if "label" in kwargs :
417483 kwargs ["label" ] = self .format_label (kwargs ["label" ])
418- for i , b0 , b1 in self .iter_on ("segment" ):
484+ j = 0
485+ for i , _ , _ in self .iter_on ("segment" ):
419486 nb = i .stop - i .start
420487 if nb == 0 :
421488 continue
489+ if nb_colors :
490+ kwargs ["color" ] = color_cycle [j % nb_colors ]
422491 x , y = self .lon [i ], self .lat [i ]
423492 if ref is not None :
424493 x , y = wrap_longitude (x , y , ref , cut = True )
425494 mappables .append (ax .plot (x , y , ** kwargs )[0 ])
495+ j += 1
426496 return mappables
427497
428498 def remove_dead_branch (self , nobs = 3 ):
0 commit comments