2727===========================================================================
2828
2929"""
30- from numpy import empty , arange , where , unique , interp , ones , bool_ , zeros , array , median
30+ from numpy import (
31+ empty ,
32+ arange ,
33+ where ,
34+ unique ,
35+ interp ,
36+ ones ,
37+ bool_ ,
38+ zeros ,
39+ array ,
40+ median ,
41+ )
3142from .. import VAR_DESCR_inv
3243import logging
3344from datetime import datetime , timedelta
3445from .observation import EddiesObservations
3546from numba import njit
47+ from ..generic import split_line
3648
3749logger = logging .getLogger ("pet" )
3850
@@ -224,28 +236,30 @@ def extract_ids(self, tracks):
224236 return self .__extract_with_mask (mask )
225237
226238 def extract_first_obs_in_box (self , res ):
227- data = empty (self .obs .shape , dtype = [('lon' , 'f4' ), ('lat' , 'f4' ), ('track' , 'i4' )])
228- data ['lon' ] = self .longitude - self .longitude % res
229- data ['lat' ] = self .latitude - self .latitude % res
230- data ['track' ] = self .obs ["track" ]
239+ data = empty (
240+ self .obs .shape , dtype = [("lon" , "f4" ), ("lat" , "f4" ), ("track" , "i4" )]
241+ )
242+ data ["lon" ] = self .longitude - self .longitude % res
243+ data ["lat" ] = self .latitude - self .latitude % res
244+ data ["track" ] = self .obs ["track" ]
231245 _ , indexs = unique (data , return_index = True )
232- mask = zeros (self .obs .shape , dtype = ' bool' )
246+ mask = zeros (self .obs .shape , dtype = " bool" )
233247 mask [indexs ] = True
234248 return self .__extract_with_mask (mask )
235249
236250 def extract_in_direction (self , direction , value = 0 ):
237251 nb_obs = self .nb_obs_by_track
238252 i_start = self .index_from_track
239253 i_stop = i_start + nb_obs - 1
240- if direction in ('S' , 'N' ):
254+ if direction in ("S" , "N" ):
241255 d_lat = self .latitude [i_stop ] - self .latitude [i_start ]
242- mask = d_lat < 0 if 'S' == direction else d_lat > 0
256+ mask = d_lat < 0 if "S" == direction else d_lat > 0
243257 mask &= abs (d_lat ) > value
244258 else :
245- lon_start , lon_end = self .longitude [i_start ], self .longitude [i_stop ]
259+ lon_start , lon_end = self .longitude [i_start ], self .longitude [i_stop ]
246260 lon_end = (lon_end - (lon_start - 180 )) % 360 + lon_start - 180
247261 d_lon = lon_end - lon_start
248- mask = d_lon < 0 if 'W' == direction else d_lon > 0
262+ mask = d_lon < 0 if "W" == direction else d_lon > 0
249263 mask &= abs (d_lon ) > value
250264 mask = mask .repeat (nb_obs )
251265 return self .__extract_with_mask (mask )
@@ -280,7 +294,12 @@ def median_filter(self, half_window, xfield, yfield, inplace=True):
280294 self .obs [yfield ] = result
281295
282296 def __extract_with_mask (
283- self , mask , full_path = False , remove_incomplete = False , compress_id = False , reject_virtual = False ,
297+ self ,
298+ mask ,
299+ full_path = False ,
300+ remove_incomplete = False ,
301+ compress_id = False ,
302+ reject_virtual = False ,
284303 ):
285304 """
286305 Extract a subset of observations
@@ -302,7 +321,7 @@ def __extract_with_mask(
302321
303322 if full_path :
304323 if reject_virtual :
305- mask *= ~ self .obs [' virtual' ].astype (' bool' )
324+ mask *= ~ self .obs [" virtual" ].astype (" bool" )
306325 tracks = unique (self .tracks [mask ])
307326 mask = self .get_mask_from_id (tracks )
308327 elif remove_incomplete :
@@ -333,6 +352,14 @@ def __extract_with_mask(
333352 new .obs ["track" ] = id_translate [new .obs ["track" ]]
334353 return new
335354
355+ def plot (self , ax , ref = None , ** kwargs ):
356+ if "label" in kwargs :
357+ kwargs ["label" ] += " (%s eddies)" % (self .nb_obs_by_track != 0 ).sum ()
358+ x , y = split_line (self .longitude , self .latitude , self .tracks )
359+ if ref is not None :
360+ x = (x - ref ) % 360 + ref
361+ return ax .plot (x , y , ** kwargs )
362+
336363
337364@njit (cache = True )
338365def compute_index (tracks , index , number ):
@@ -373,7 +400,9 @@ def track_loess_filter(half_window, x, y, track):
373400 if i != 0 :
374401 i_previous = i - 1
375402 dx = x [i ] - x [i_previous ]
376- while dx < half_window and i_previous != 0 and cur_track == track [i_previous ]:
403+ while (
404+ dx < half_window and i_previous != 0 and cur_track == track [i_previous ]
405+ ):
377406 w = (1 - (dx / half_window ) ** 3 ) ** 3
378407 y_sum += y [i_previous ] * w
379408 w_sum += w
@@ -412,7 +441,11 @@ def track_median_filter(half_window, x, y, track):
412441 cur_track = track [i ]
413442 while x [i ] - x [i_previous ] > half_window or cur_track != track [i_previous ]:
414443 i_previous += 1
415- while i_next < nb and x [i_next ] - x [i ] <= half_window and cur_track == track [i_next ]:
444+ while (
445+ i_next < nb
446+ and x [i_next ] - x [i ] <= half_window
447+ and cur_track == track [i_next ]
448+ ):
416449 i_next += 1
417450 y_new [i ] = median (y [i_previous :i_next ])
418451 return y_new
0 commit comments