66from  glob  import  glob 
77
88from  numba  import  njit 
9- from  numpy  import  arange , array , bincount , empty , ones , uint32 , unique 
9+ from  numpy  import  arange , array , bincount , empty , ones , uint32 , unique ,  zeros 
1010
1111from  ..generic  import  build_index , wrap_longitude 
1212from  ..poly  import  bbox_intersection , vertice_overlap 
1313from  .observation  import  EddiesObservations 
14- from  .tracking  import  TrackEddiesObservations , track_median_filter 
14+ from  .tracking  import  TrackEddiesObservations , track_loess_filter ,  track_median_filter 
1515
1616logger  =  logging .getLogger ("pet" )
1717
@@ -71,6 +71,26 @@ def elements(self):
7171        elements .extend (["track" , "segment" , "next_obs" , "previous_obs" ])
7272        return  list (set (elements ))
7373
74+     def  longer_than (self , nb_day_min = - 1 , nb_day_max = - 1 ):
75+         """ 
76+         Select network on time duration 
77+ 
78+         :param int nb_day_min: Minimal number of day which must be covered by one network, if negative -> not used 
79+         :param int nb_day_max: Maximal number of day which must be covered by one network, if negative -> not used 
80+         """ 
81+         if  nb_day_max  <  0 :
82+             nb_day_max  =  1000000000000 
83+         mask  =  zeros (self .shape , dtype = "bool" )
84+         for  i , b0 , b1  in  self .iter_on (self .segment_track_array ()):
85+             nb  =  i .stop  -  i .start 
86+             if  nb  ==  0 :
87+                 continue 
88+             t  =  self .time [i ]
89+             dt  =  t .max () -  t .min ()
90+             if  nb_day_min  <=  dt  <=  nb_day_max :
91+                 mask [i ] =  True 
92+         return  self .extract_with_mask (mask )
93+ 
7494    @classmethod  
7595    def  from_split_network (cls , group_dataset , indexs , ** kwargs ):
7696        """ 
@@ -160,6 +180,13 @@ def relative(self, i_obs, order=2, direct=True, only_past=False, only_future=Fal
160180        m  =  (d  <=  order ) *  (d  !=  - 1 )
161181        return  self .extract_with_mask (m )
162182
183+     def  numbering_segment (self ):
184+         """ 
185+         New numbering of segment 
186+         """ 
187+         for  i , _ , _  in  self .iter_on ("track" ):
188+             new_numbering (self .segment [i ])
189+ 
163190    def  only_one_network (self ):
164191        """ 
165192        Raise a warning or error? 
@@ -168,17 +195,35 @@ def only_one_network(self):
168195        # TODO 
169196        pass 
170197
198+     def  position_filter (self , median_half_window , loess_half_window ):
199+         self .median_filter (median_half_window , "time" , "lon" ).loess_filter (
200+             loess_half_window , "time" , "lon" 
201+         )
202+         self .median_filter (median_half_window , "time" , "lat" ).loess_filter (
203+             loess_half_window , "time" , "lat" 
204+         )
205+ 
206+     def  loess_filter (self , half_window , xfield , yfield , inplace = True ):
207+         result  =  track_loess_filter (
208+             half_window , self .obs [xfield ], self .obs [yfield ], self .segment_track_array ()
209+         )
210+         if  inplace :
211+             self .obs [yfield ] =  result 
212+             return  self 
213+         return  result 
214+ 
171215    def  median_filter (self , half_window , xfield , yfield , inplace = True ):
172-         # FIXME: segments is not enough with several network 
173216        result  =  track_median_filter (
174-             half_window , self [xfield ], self [yfield ], self .segment 
217+             half_window , self [xfield ], self [yfield ], self .segment_track_array () 
175218        )
176219        if  inplace :
177220            self [yfield ][:] =  result 
178221            return  self 
179222        return  result 
180223
181-     def  display_timeline (self , ax , event = True , field = None , method = None ):
224+     def  display_timeline (
225+         self , ax , event = True , field = None , method = None , factor = 1 , ** kwargs 
226+     ):
182227        """ 
183228        Must be call on only one network 
184229        """ 
@@ -191,9 +236,12 @@ def display_timeline(self, ax, event=True, field=None, method=None):
191236            zorder = 1 ,
192237            lw = 3 ,
193238        )
239+         line_kw .update (kwargs )
194240        mappables  =  dict (lines = list ())
195241        if  event :
196-             mappables .update (self .event_timeline (ax , field = field , method = method ))
242+             mappables .update (
243+                 self .event_timeline (ax , field = field , method = method , factor = factor )
244+             )
197245        for  i , b0 , b1  in  self .iter_on ("segment" ):
198246            x  =  self .time [i ]
199247            if  x .shape [0 ] ==  0 :
@@ -202,24 +250,24 @@ def display_timeline(self, ax, event=True, field=None, method=None):
202250                y  =  b0  *  ones (x .shape )
203251            else :
204252                if  method  ==  "all" :
205-                     y  =  self [field ][i ]
253+                     y  =  self [field ][i ]  *   factor 
206254                else :
207-                     y  =  self [field ][i ].mean () *  ones (x .shape )
255+                     y  =  self [field ][i ].mean () *  ones (x .shape )  *   factor 
208256            line  =  ax .plot (x , y , ** line_kw , color = self .COLORS [j  %  self .NB_COLORS ])[0 ]
209257            mappables ["lines" ].append (line )
210258            j  +=  1 
211259
212260        return  mappables 
213261
214-     def  event_timeline (self , ax , field = None , method = None ):
262+     def  event_timeline (self , ax , field = None , method = None ,  factor = 1 ):
215263        j  =  0 
216264        # TODO : fill mappables dict 
217265        y_seg  =  dict ()
218266        if  field  is  not None  and  method  !=  "all" :
219267            for  i , b0 , _  in  self .iter_on ("segment" ):
220268                y  =  self [field ][i ]
221269                if  y .shape [0 ] !=  0 :
222-                     y_seg [b0 ] =  y .mean ()
270+                     y_seg [b0 ] =  y .mean ()  *   factor 
223271        mappables  =  dict ()
224272        for  i , b0 , b1  in  self .iter_on ("segment" ):
225273            x  =  self .time [i ]
@@ -234,26 +282,34 @@ def event_timeline(self, ax, field=None, method=None):
234282                y0  =  b0 
235283            else :
236284                if  method  ==  "all" :
237-                     y0  =  self [field ][i .stop  -  1 ]
285+                     y0  =  self [field ][i .stop  -  1 ]  *   factor 
238286                else :
239287                    y0  =  y_seg [b0 ]
240288            if  i_n  !=  - 1 :
241289                seg_next  =  self .segment [i_n ]
242290                y1  =  (
243291                    seg_next 
244292                    if  field  is  None 
245-                     else  (self [field ][i_n ] if  method  ==  "all"  else  y_seg [seg_next ])
293+                     else  (
294+                         self [field ][i_n ] *  factor 
295+                         if  method  ==  "all" 
296+                         else  y_seg [seg_next ]
297+                     )
246298                )
247299                ax .plot ((x [- 1 ], self .time [i_n ]), (y0 , y1 ), ** event_kw )[0 ]
248300                ax .plot (x [- 1 ], y0 , color = "k" , marker = ">" , markersize = 10 , zorder = - 1 )[0 ]
249301            if  i_p  !=  - 1 :
250302                seg_previous  =  self .segment [i_p ]
251303                if  field  is  not None  and  method  ==  "all" :
252-                     y0  =  self [field ][i .start ]
304+                     y0  =  self [field ][i .start ]  *   factor 
253305                y1  =  (
254306                    seg_previous 
255307                    if  field  is  None 
256-                     else  (self [field ][i_p ] if  method  ==  "all"  else  y_seg [seg_previous ])
308+                     else  (
309+                         self [field ][i_p ] *  factor 
310+                         if  method  ==  "all" 
311+                         else  y_seg [seg_previous ]
312+                     )
257313                )
258314                ax .plot ((x [0 ], self .time [i_p ]), (y0 , y1 ), ** event_kw )[0 ]
259315                ax .plot (x [0 ], y0 , color = "k" , marker = "*" , markersize = 12 , zorder = - 1 )[0 ]
@@ -300,7 +356,7 @@ def segment_track_array(self):
300356    def  birth_event (self ):
301357        # FIXME how to manage group 0 
302358        indices  =  list ()
303-         for  i , b0 ,  b1  in  self .iter_on (self .segment_track_array ()):
359+         for  i , _ ,  _  in  self .iter_on (self .segment_track_array ()):
304360            nb  =  i .stop  -  i .start 
305361            if  nb  ==  0 :
306362                continue 
@@ -312,7 +368,7 @@ def birth_event(self):
312368    def  death_event (self ):
313369        # FIXME how to manage group 0 
314370        indices  =  list ()
315-         for  i , b0 ,  b1  in  self .iter_on (self .segment_track_array ()):
371+         for  i , _ ,  _  in  self .iter_on (self .segment_track_array ()):
316372            nb  =  i .stop  -  i .start 
317373            if  nb  ==  0 :
318374                continue 
@@ -323,7 +379,7 @@ def death_event(self):
323379
324380    def  merging_event (self ):
325381        indices  =  list ()
326-         for  i , b0 ,  b1  in  self .iter_on (self .segment_track_array ()):
382+         for  i , _ ,  _  in  self .iter_on (self .segment_track_array ()):
327383            nb  =  i .stop  -  i .start 
328384            if  nb  ==  0 :
329385                continue 
@@ -334,7 +390,7 @@ def merging_event(self):
334390
335391    def  spliting_event (self ):
336392        indices  =  list ()
337-         for  i , b0 ,  b1  in  self .iter_on (self .segment_track_array ()):
393+         for  i , _ ,  _  in  self .iter_on (self .segment_track_array ()):
338394            nb  =  i .stop  -  i .start 
339395            if  nb  ==  0 :
340396                continue 
@@ -425,6 +481,9 @@ def extract_with_mask(self, mask):
425481        if  nb_obs  ==  0 :
426482            logger .warning ("Empty dataset will be created" )
427483        else :
484+             logger .info (
485+                 f"{ nb_obs } { nb_obs  *  100.  /  self .shape [0 ]}  
486+             )
428487            for  field  in  self .obs .dtype .descr :
429488                if  field  in  ("next_obs" , "previous_obs" ):
430489                    continue 
@@ -592,3 +651,15 @@ def build_unique_array(id1, id2):
592651        new_id [i ] =  k 
593652        id1_previous , id2_previous  =  id1_ , id2_ 
594653    return  new_id 
654+ 
655+ 
656+ @njit (cache = True ) 
657+ def  new_numbering (segs ):
658+     nb  =  len (segs )
659+     s0  =  segs [0 ]
660+     j  =  0 
661+     for  i  in  range (nb ):
662+         if  segs [i ] !=  s0 :
663+             s0  =  segs [i ]
664+             j  +=  1 
665+         segs [i ] =  j 
0 commit comments