66from glob import glob
77
88from numba import njit
9- from numpy import arange , array , bincount , empty , ones , uint32 , unique
9+ from numpy import arange , array , bincount , empty , ones , uint32 , unique , zeros
1010
1111from ..generic import build_index , wrap_longitude
1212from ..poly import bbox_intersection , vertice_overlap
1313from .observation import EddiesObservations
14- from .tracking import TrackEddiesObservations , track_median_filter
14+ from .tracking import TrackEddiesObservations , track_loess_filter , track_median_filter
1515
1616logger = logging .getLogger ("pet" )
1717
@@ -71,6 +71,26 @@ def elements(self):
7171 elements .extend (["track" , "segment" , "next_obs" , "previous_obs" ])
7272 return list (set (elements ))
7373
74+ def longer_than (self , nb_day_min = - 1 , nb_day_max = - 1 ):
75+ """
76+ Select network on time duration
77+
78+ :param int nb_day_min: Minimal number of day which must be covered by one network, if negative -> not used
79+ :param int nb_day_max: Maximal number of day which must be covered by one network, if negative -> not used
80+ """
81+ if nb_day_max < 0 :
82+ nb_day_max = 1000000000000
83+ mask = zeros (self .shape , dtype = "bool" )
84+ for i , b0 , b1 in self .iter_on (self .segment_track_array ()):
85+ nb = i .stop - i .start
86+ if nb == 0 :
87+ continue
88+ t = self .time [i ]
89+ dt = t .max () - t .min ()
90+ if nb_day_min <= dt <= nb_day_max :
91+ mask [i ] = True
92+ return self .extract_with_mask (mask )
93+
7494 @classmethod
7595 def from_split_network (cls , group_dataset , indexs , ** kwargs ):
7696 """
@@ -160,6 +180,13 @@ def relative(self, i_obs, order=2, direct=True, only_past=False, only_future=Fal
160180 m = (d <= order ) * (d != - 1 )
161181 return self .extract_with_mask (m )
162182
183+ def numbering_segment (self ):
184+ """
185+ New numbering of segment
186+ """
187+ for i , _ , _ in self .iter_on ("track" ):
188+ new_numbering (self .segment [i ])
189+
163190 def only_one_network (self ):
164191 """
165192 Raise a warning or error?
@@ -168,17 +195,35 @@ def only_one_network(self):
168195 # TODO
169196 pass
170197
198+ def position_filter (self , median_half_window , loess_half_window ):
199+ self .median_filter (median_half_window , "time" , "lon" ).loess_filter (
200+ loess_half_window , "time" , "lon"
201+ )
202+ self .median_filter (median_half_window , "time" , "lat" ).loess_filter (
203+ loess_half_window , "time" , "lat"
204+ )
205+
206+ def loess_filter (self , half_window , xfield , yfield , inplace = True ):
207+ result = track_loess_filter (
208+ half_window , self .obs [xfield ], self .obs [yfield ], self .segment_track_array ()
209+ )
210+ if inplace :
211+ self .obs [yfield ] = result
212+ return self
213+ return result
214+
171215 def median_filter (self , half_window , xfield , yfield , inplace = True ):
172- # FIXME: segments is not enough with several network
173216 result = track_median_filter (
174- half_window , self [xfield ], self [yfield ], self .segment
217+ half_window , self [xfield ], self [yfield ], self .segment_track_array ()
175218 )
176219 if inplace :
177220 self [yfield ][:] = result
178221 return self
179222 return result
180223
181- def display_timeline (self , ax , event = True , field = None , method = None ):
224+ def display_timeline (
225+ self , ax , event = True , field = None , method = None , factor = 1 , ** kwargs
226+ ):
182227 """
183228 Must be call on only one network
184229 """
@@ -191,9 +236,12 @@ def display_timeline(self, ax, event=True, field=None, method=None):
191236 zorder = 1 ,
192237 lw = 3 ,
193238 )
239+ line_kw .update (kwargs )
194240 mappables = dict (lines = list ())
195241 if event :
196- mappables .update (self .event_timeline (ax , field = field , method = method ))
242+ mappables .update (
243+ self .event_timeline (ax , field = field , method = method , factor = factor )
244+ )
197245 for i , b0 , b1 in self .iter_on ("segment" ):
198246 x = self .time [i ]
199247 if x .shape [0 ] == 0 :
@@ -202,24 +250,24 @@ def display_timeline(self, ax, event=True, field=None, method=None):
202250 y = b0 * ones (x .shape )
203251 else :
204252 if method == "all" :
205- y = self [field ][i ]
253+ y = self [field ][i ] * factor
206254 else :
207- y = self [field ][i ].mean () * ones (x .shape )
255+ y = self [field ][i ].mean () * ones (x .shape ) * factor
208256 line = ax .plot (x , y , ** line_kw , color = self .COLORS [j % self .NB_COLORS ])[0 ]
209257 mappables ["lines" ].append (line )
210258 j += 1
211259
212260 return mappables
213261
214- def event_timeline (self , ax , field = None , method = None ):
262+ def event_timeline (self , ax , field = None , method = None , factor = 1 ):
215263 j = 0
216264 # TODO : fill mappables dict
217265 y_seg = dict ()
218266 if field is not None and method != "all" :
219267 for i , b0 , _ in self .iter_on ("segment" ):
220268 y = self [field ][i ]
221269 if y .shape [0 ] != 0 :
222- y_seg [b0 ] = y .mean ()
270+ y_seg [b0 ] = y .mean () * factor
223271 mappables = dict ()
224272 for i , b0 , b1 in self .iter_on ("segment" ):
225273 x = self .time [i ]
@@ -234,26 +282,34 @@ def event_timeline(self, ax, field=None, method=None):
234282 y0 = b0
235283 else :
236284 if method == "all" :
237- y0 = self [field ][i .stop - 1 ]
285+ y0 = self [field ][i .stop - 1 ] * factor
238286 else :
239287 y0 = y_seg [b0 ]
240288 if i_n != - 1 :
241289 seg_next = self .segment [i_n ]
242290 y1 = (
243291 seg_next
244292 if field is None
245- else (self [field ][i_n ] if method == "all" else y_seg [seg_next ])
293+ else (
294+ self [field ][i_n ] * factor
295+ if method == "all"
296+ else y_seg [seg_next ]
297+ )
246298 )
247299 ax .plot ((x [- 1 ], self .time [i_n ]), (y0 , y1 ), ** event_kw )[0 ]
248300 ax .plot (x [- 1 ], y0 , color = "k" , marker = ">" , markersize = 10 , zorder = - 1 )[0 ]
249301 if i_p != - 1 :
250302 seg_previous = self .segment [i_p ]
251303 if field is not None and method == "all" :
252- y0 = self [field ][i .start ]
304+ y0 = self [field ][i .start ] * factor
253305 y1 = (
254306 seg_previous
255307 if field is None
256- else (self [field ][i_p ] if method == "all" else y_seg [seg_previous ])
308+ else (
309+ self [field ][i_p ] * factor
310+ if method == "all"
311+ else y_seg [seg_previous ]
312+ )
257313 )
258314 ax .plot ((x [0 ], self .time [i_p ]), (y0 , y1 ), ** event_kw )[0 ]
259315 ax .plot (x [0 ], y0 , color = "k" , marker = "*" , markersize = 12 , zorder = - 1 )[0 ]
@@ -300,7 +356,7 @@ def segment_track_array(self):
300356 def birth_event (self ):
301357 # FIXME how to manage group 0
302358 indices = list ()
303- for i , b0 , b1 in self .iter_on (self .segment_track_array ()):
359+ for i , _ , _ in self .iter_on (self .segment_track_array ()):
304360 nb = i .stop - i .start
305361 if nb == 0 :
306362 continue
@@ -312,7 +368,7 @@ def birth_event(self):
312368 def death_event (self ):
313369 # FIXME how to manage group 0
314370 indices = list ()
315- for i , b0 , b1 in self .iter_on (self .segment_track_array ()):
371+ for i , _ , _ in self .iter_on (self .segment_track_array ()):
316372 nb = i .stop - i .start
317373 if nb == 0 :
318374 continue
@@ -323,7 +379,7 @@ def death_event(self):
323379
324380 def merging_event (self ):
325381 indices = list ()
326- for i , b0 , b1 in self .iter_on (self .segment_track_array ()):
382+ for i , _ , _ in self .iter_on (self .segment_track_array ()):
327383 nb = i .stop - i .start
328384 if nb == 0 :
329385 continue
@@ -334,7 +390,7 @@ def merging_event(self):
334390
335391 def spliting_event (self ):
336392 indices = list ()
337- for i , b0 , b1 in self .iter_on (self .segment_track_array ()):
393+ for i , _ , _ in self .iter_on (self .segment_track_array ()):
338394 nb = i .stop - i .start
339395 if nb == 0 :
340396 continue
@@ -425,6 +481,9 @@ def extract_with_mask(self, mask):
425481 if nb_obs == 0 :
426482 logger .warning ("Empty dataset will be created" )
427483 else :
484+ logger .info (
485+ f"{ nb_obs } observations will be extract ({ nb_obs * 100. / self .shape [0 ]} %)"
486+ )
428487 for field in self .obs .dtype .descr :
429488 if field in ("next_obs" , "previous_obs" ):
430489 continue
@@ -592,3 +651,15 @@ def build_unique_array(id1, id2):
592651 new_id [i ] = k
593652 id1_previous , id2_previous = id1_ , id2_
594653 return new_id
654+
655+
656+ @njit (cache = True )
657+ def new_numbering (segs ):
658+ nb = len (segs )
659+ s0 = segs [0 ]
660+ j = 0
661+ for i in range (nb ):
662+ if segs [i ] != s0 :
663+ s0 = segs [i ]
664+ j += 1
665+ segs [i ] = j
0 commit comments