@@ -68,9 +68,27 @@ class NetworkObservations(EddiesObservations):
6868    @property  
6969    def  elements (self ):
7070        elements  =  super ().elements 
71-         elements .extend (["track" , "segment" , "next_obs" , "previous_obs" ])
71+         elements .extend (
72+             [
73+                 "track" ,
74+                 "segment" ,
75+                 "next_obs" ,
76+                 "previous_obs" ,
77+                 "next_cost" ,
78+                 "previous_cost" ,
79+             ]
80+         )
7281        return  list (set (elements ))
7382
83+     def  astype (self , cls ):
84+         new  =  cls .new_like (self , self .shape )
85+         print ()
86+         for  k  in  new .obs .dtype .names :
87+             if  k  in  self .obs .dtype .names :
88+                 new [k ][:] =  self [k ][:]
89+         new .sign_type  =  self .sign_type 
90+         return  new 
91+ 
7492    def  longer_than (self , nb_day_min = - 1 , nb_day_max = - 1 ):
7593        """ 
7694        Select network on time duration 
@@ -81,7 +99,7 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
8199        if  nb_day_max  <  0 :
82100            nb_day_max  =  1000000000000 
83101        mask  =  zeros (self .shape , dtype = "bool" )
84-         for  i , b0 , b1  in  self .iter_on (self .segment_track_array () ):
102+         for  i , b0 , b1  in  self .iter_on (self .segment_track_array ):
85103            nb  =  i .stop  -  i .start 
86104            if  nb  ==  0 :
87105                continue 
@@ -115,6 +133,8 @@ def from_split_network(cls, group_dataset, indexs, **kwargs):
115133        translate [index_order ] =  arange (index_order .shape [0 ])
116134        network .next_obs [:] =  translate [n ]
117135        network .previous_obs [:] =  translate [p ]
136+         network .next_cost [:] =  indexs ["next_cost" ][index_order ]
137+         network .previous_cost [:] =  indexs ["previous_cost" ][index_order ]
118138        return  network 
119139
120140    def  infos (self , label = "" ):
@@ -205,7 +225,7 @@ def position_filter(self, median_half_window, loess_half_window):
205225
206226    def  loess_filter (self , half_window , xfield , yfield , inplace = True ):
207227        result  =  track_loess_filter (
208-             half_window , self .obs [xfield ], self .obs [yfield ], self .segment_track_array () 
228+             half_window , self .obs [xfield ], self .obs [yfield ], self .segment_track_array 
209229        )
210230        if  inplace :
211231            self .obs [yfield ] =  result 
@@ -214,7 +234,7 @@ def loess_filter(self, half_window, xfield, yfield, inplace=True):
214234
215235    def  median_filter (self , half_window , xfield , yfield , inplace = True ):
216236        result  =  track_median_filter (
217-             half_window , self [xfield ], self [yfield ], self .segment_track_array () 
237+             half_window , self [xfield ], self [yfield ], self .segment_track_array 
218238        )
219239        if  inplace :
220240            self [yfield ][:] =  result 
@@ -316,18 +336,59 @@ def event_timeline(self, ax, field=None, method=None, factor=1):
316336            j  +=  1 
317337        return  mappables 
318338
319-     def  scatter_timeline (self , ax , name , factor = 1 , event = True , ** kwargs ):
339+     def  mean_by_segment (self , y , ** kw ):
340+         kw ["dtype" ] =  y .dtype 
341+         return  self .map_segment (lambda  x : x .mean (), y , ** kw )
342+ 
343+     def  map_segment (self , method , y , same = True , ** kw ):
344+         if  same :
345+             out  =  empty (y .shape , ** kw )
346+         else :
347+             out  =  list ()
348+         for  i , b0 , b1  in  self .iter_on (self .segment_track_array ):
349+             res  =  method (y [i ])
350+             if  same :
351+                 out [i ] =  res 
352+             else :
353+                 if  isinstance (i , slice ):
354+                     if  i .start  ==  i .stop :
355+                         continue 
356+                 elif  len (i ) ==  0 :
357+                     continue 
358+                 out .append (res )
359+         if  not  same :
360+             out  =  array (out )
361+         return  out 
362+ 
363+     def  scatter_timeline (
364+         self ,
365+         ax ,
366+         name ,
367+         factor = 1 ,
368+         event = True ,
369+         yfield = None ,
370+         yfactor = 1 ,
371+         method = None ,
372+         ** kwargs ,
373+     ):
320374        """ 
321375        Must be call on only one network 
322376        """ 
323377        self .only_one_network ()
378+         y  =  (self .segment  if  yfield  is  None  else  self [yfield ]) *  yfactor 
379+         if  method  ==  "all" :
380+             pass 
381+         else :
382+             y  =  self .mean_by_segment (y )
324383        mappables  =  dict ()
325384        if  event :
326-             mappables .update (self .event_timeline (ax ))
385+             mappables .update (
386+                 self .event_timeline (ax , field = yfield , method = method , factor = yfactor )
387+             )
327388        if  "c"  not  in kwargs :
328389            v  =  self .parse_varname (name )
329390            kwargs ["c" ] =  v  *  factor 
330-         mappables ["scatter" ] =  ax .scatter (self .time , self . segment , ** kwargs )
391+         mappables ["scatter" ] =  ax .scatter (self .time , y , ** kwargs )
331392        return  mappables 
332393
333394    def  insert_virtual (self ):
@@ -350,13 +411,14 @@ def extract_event(self, indices):
350411        new .sign_type  =  self .sign_type 
351412        return  new 
352413
414+     @property  
353415    def  segment_track_array (self ):
354416        return  build_unique_array (self .segment , self .track )
355417
356418    def  birth_event (self ):
357419        # FIXME how to manage group 0 
358420        indices  =  list ()
359-         for  i , _ , _  in  self .iter_on (self .segment_track_array () ):
421+         for  i , _ , _  in  self .iter_on (self .segment_track_array ):
360422            nb  =  i .stop  -  i .start 
361423            if  nb  ==  0 :
362424                continue 
@@ -368,7 +430,7 @@ def birth_event(self):
368430    def  death_event (self ):
369431        # FIXME how to manage group 0 
370432        indices  =  list ()
371-         for  i , _ , _  in  self .iter_on (self .segment_track_array () ):
433+         for  i , _ , _  in  self .iter_on (self .segment_track_array ):
372434            nb  =  i .stop  -  i .start 
373435            if  nb  ==  0 :
374436                continue 
@@ -379,7 +441,7 @@ def death_event(self):
379441
380442    def  merging_event (self ):
381443        indices  =  list ()
382-         for  i , _ , _  in  self .iter_on (self .segment_track_array () ):
444+         for  i , _ , _  in  self .iter_on (self .segment_track_array ):
383445            nb  =  i .stop  -  i .start 
384446            if  nb  ==  0 :
385447                continue 
@@ -390,7 +452,7 @@ def merging_event(self):
390452
391453    def  spliting_event (self ):
392454        indices  =  list ()
393-         for  i , _ , _  in  self .iter_on (self .segment_track_array () ):
455+         for  i , _ , _  in  self .iter_on (self .segment_track_array ):
394456            nb  =  i .stop  -  i .start 
395457            if  nb  ==  0 :
396458                continue 
@@ -403,7 +465,7 @@ def fully_connected(self):
403465        self .only_one_network ()
404466        # TODO 
405467
406-     def  plot (self , ax , ref = None , ** kwargs ):
468+     def  plot (self , ax , ref = None , color_cycle = None ,  ** kwargs ):
407469        """ 
408470        This function will draw path of each trajectory 
409471
@@ -412,17 +474,25 @@ def plot(self, ax, ref=None, **kwargs):
412474        :param dict kwargs: keyword arguments for Axes.plot 
413475        :return: a list of matplotlib mappables 
414476        """ 
477+         nb_colors  =  0 
478+         if  color_cycle  is  not None :
479+             kwargs  =  kwargs .copy ()
480+             nb_colors  =  len (color_cycle )
415481        mappables  =  list ()
416482        if  "label"  in  kwargs :
417483            kwargs ["label" ] =  self .format_label (kwargs ["label" ])
418-         for  i , b0 , b1  in  self .iter_on ("segment" ):
484+         j  =  0 
485+         for  i , _ , _  in  self .iter_on ("segment" ):
419486            nb  =  i .stop  -  i .start 
420487            if  nb  ==  0 :
421488                continue 
489+             if  nb_colors :
490+                 kwargs ["color" ] =  color_cycle [j  %  nb_colors ]
422491            x , y  =  self .lon [i ], self .lat [i ]
423492            if  ref  is  not None :
424493                x , y  =  wrap_longitude (x , y , ref , cut = True )
425494            mappables .append (ax .plot (x , y , ** kwargs )[0 ])
495+             j  +=  1 
426496        return  mappables 
427497
428498    def  remove_dead_branch (self , nobs = 3 ):
0 commit comments