2
2
"""
3
3
Class to create network of observations
4
4
"""
5
- from glob import glob
6
5
import logging
7
6
import time
7
+ from glob import glob
8
8
9
9
import netCDF4
10
- from numba import njit , types as nb_types
11
- from numba .typed import List
12
- from numpy import (
13
- arange ,
14
- array ,
15
- bincount ,
16
- bool_ ,
17
- concatenate ,
18
- empty ,
19
- nan ,
20
- ones ,
21
- percentile ,
22
- uint16 ,
23
- uint32 ,
24
- unique ,
25
- where ,
26
- zeros ,
27
- )
28
10
import zarr
11
+ from numba import njit
12
+ from numba import types as nb_types
13
+ from numba .typed import List
14
+ from numpy import (arange , array , bincount , bool_ , concatenate , empty , nan ,
15
+ ones , percentile , uint16 , uint32 , unique , where , zeros )
29
16
30
17
from ..dataset .grid import GridCollection
31
18
from ..generic import build_index , wrap_longitude
32
19
from ..poly import bbox_intersection , vertice_overlap
33
- from .groups import GroupEddiesObservations , get_missing_indices , particle_candidate
20
+ from .groups import (GroupEddiesObservations , get_missing_indices ,
21
+ particle_candidate )
34
22
from .observation import EddiesObservations
35
- from .tracking import TrackEddiesObservations , track_loess_filter , track_median_filter
23
+ from .tracking import (TrackEddiesObservations , track_loess_filter ,
24
+ track_median_filter )
36
25
37
26
logger = logging .getLogger ("pet" )
38
27
@@ -280,6 +269,15 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
280
269
"""
281
270
Select network on time duration
282
271
272
+ :param int nb_day_min: Minimal number of days covered by one network, if negative -> not used
273
+ :param int nb_day_max: Maximal number of days covered by one network, if negative -> not used
274
+ """
275
+ return self .extract_with_mask (self .mask_longer_than (nb_day_min , nb_day_max ))
276
+
277
+ def mask_longer_than (self , nb_day_min = - 1 , nb_day_max = - 1 ):
278
+ """
279
+ Select network on time duration
280
+
283
281
:param int nb_day_min: Minimal number of days covered by one network, if negative -> not used
284
282
:param int nb_day_max: Maximal number of days covered by one network, if negative -> not used
285
283
"""
@@ -293,7 +291,7 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
293
291
continue
294
292
if nb_day_min <= (ptp (t [i ]) + 1 ) <= nb_day_max :
295
293
mask [i ] = True
296
- return self . extract_with_mask ( mask )
294
+ return mask
297
295
298
296
@classmethod
299
297
def from_split_network (cls , group_dataset , indexs , ** kwargs ):
@@ -800,7 +798,7 @@ def display_timeline(
800
798
if field is not None :
801
799
field = self .parse_varname (field )
802
800
for i , b0 , b1 in self .iter_on ("segment" ):
803
- x = self .time [i ]
801
+ x = self .time_datetime64 [i ]
804
802
if x .shape [0 ] == 0 :
805
803
continue
806
804
if field is None :
@@ -831,7 +829,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
831
829
832
830
# TODO : fill mappables dict
833
831
y_seg = dict ()
834
- _time = self .time
832
+ _time = self .time_datetime64
835
833
836
834
if field is not None and method != "all" :
837
835
for i , b0 , _ in self .iter_on ("segment" ):
@@ -1011,7 +1009,7 @@ def scatter_timeline(
1011
1009
if "c" not in kwargs :
1012
1010
v = self .parse_varname (name )
1013
1011
kwargs ["c" ] = v * factor
1014
- mappables ["scatter" ] = ax .scatter (self .time , y , ** kwargs )
1012
+ mappables ["scatter" ] = ax .scatter (self .time_datetime64 , y , ** kwargs )
1015
1013
return mappables
1016
1014
1017
1015
def event_map (self , ax , ** kwargs ):
@@ -1244,7 +1242,7 @@ def networks_mask(self, id_networks, segment=False):
1244
1242
1245
1243
def networks (self , id_networks ):
1246
1244
return self .extract_with_mask (
1247
- generate_mask_from_ids (id_networks , self .track .size , * self .index_network )
1245
+ generate_mask_from_ids (array ( id_networks ) , self .track .size , * self .index_network )
1248
1246
)
1249
1247
1250
1248
@property
@@ -1423,10 +1421,10 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs):
1423
1421
:param dict kwargs: keyword arguments for Axes.plot
1424
1422
:return: a list of matplotlib mappables
1425
1423
"""
1426
- nb_colors = 0
1427
- if color_cycle is not None :
1428
- kwargs = kwargs . copy ()
1429
- nb_colors = len (color_cycle )
1424
+ kwargs = kwargs . copy ()
1425
+ if color_cycle is None :
1426
+ color_cycle = self . COLORS
1427
+ nb_colors = len (color_cycle )
1430
1428
mappables = list ()
1431
1429
if "label" in kwargs :
1432
1430
kwargs ["label" ] = self .format_label (kwargs ["label" ])
0 commit comments