Skip to content

Commit 9085eac

Browse files
committed
- Add period to cube
- Add some methods for display - Speed up overlap
1 parent 7fc19df commit 9085eac

File tree

7 files changed

+133
-178
lines changed

7 files changed

+133
-178
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
[![PyPI version](https://badge.fury.io/py/pyEddyTracker.svg)](https://badge.fury.io/py/pyEddyTracker)
2+
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.6333988.svg)](https://doi.org/10.5281/zenodo.6333988)
23
[![Documentation Status](https://readthedocs.org/projects/py-eddy-tracker/badge/?version=stable)](https://py-eddy-tracker.readthedocs.io/en/stable/?badge=stable)
34
[![Gitter](https://badges.gitter.im/py-eddy-tracker/community.svg)](https://gitter.im/py-eddy-tracker/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
45
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/AntSimi/py-eddy-tracker/master?urlpath=lab/tree/notebooks/python_module/)
56
[![pytest](https://github.com/AntSimi/py-eddy-tracker/actions/workflows/python-app.yml/badge.svg)](https://github.com/AntSimi/py-eddy-tracker/actions/workflows/python-app.yml)
67

78
# README #
89

10+
### How to cite code? ###
11+
12+
Zenodo provide DOI for each tagged version, [all DOI are available here](https://doi.org/10.5281/zenodo.6333988)
13+
914
### Method ###
1015

1116
Method was described in :

src/py_eddy_tracker/appli/network.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def subset_network():
128128
action="store_true",
129129
help="Remove trash (network id == 0)",
130130
)
131+
parser.add_argument(
132+
"-i", "--ids", nargs="+", type=int, help="List of network which will be extract"
133+
)
131134
parser.add_argument(
132135
"-p",
133136
"--period",
@@ -138,6 +141,8 @@ def subset_network():
138141
)
139142
args = parser.parse_args()
140143
n = NetworkObservations.load_file(args.input, raw_data=True)
144+
if args.ids is not None:
145+
n = n.networks(args.ids)
141146
if args.length is not None:
142147
n = n.longer_than(*args.length)
143148
if args.remove_dead_end is not None:

src/py_eddy_tracker/dataset/grid.py

Lines changed: 31 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,20 @@
22
"""
33
Class to load and manipulate RegularGrid and UnRegularGrid
44
"""
5-
from datetime import datetime
65
import logging
6+
from datetime import datetime
77

88
from cv2 import filter2D
99
from matplotlib.path import Path as BasePath
1010
from netCDF4 import Dataset
11-
from numba import njit, prange, types as numba_types
12-
from numpy import (
13-
arange,
14-
array,
15-
ceil,
16-
concatenate,
17-
cos,
18-
deg2rad,
19-
empty,
20-
errstate,
21-
exp,
22-
float_,
23-
floor,
24-
histogram2d,
25-
int_,
26-
interp,
27-
isnan,
28-
linspace,
29-
ma,
30-
mean as np_mean,
31-
meshgrid,
32-
nan,
33-
nanmean,
34-
ones,
35-
percentile,
36-
pi,
37-
radians,
38-
round_,
39-
sin,
40-
sinc,
41-
where,
42-
zeros,
43-
)
11+
from numba import njit, prange
12+
from numba import types as numba_types
13+
from numpy import (arange, array, ceil, concatenate, cos, deg2rad, empty,
14+
errstate, exp, float_, floor, histogram2d, int_, interp,
15+
isnan, linspace, ma)
16+
from numpy import mean as np_mean
17+
from numpy import (meshgrid, nan, nanmean, ones, percentile, pi, radians,
18+
round_, sin, sinc, where, zeros)
4419
from pint import UnitRegistry
4520
from scipy.interpolate import RectBivariateSpline, interp1d
4621
from scipy.ndimage import gaussian_filter
@@ -49,26 +24,15 @@
4924
from scipy.special import j1
5025

5126
from .. import VAR_DESCR
27+
from ..data import get_demo_path
5228
from ..eddy_feature import Amplitude, Contours
53-
from ..generic import (
54-
bbox_indice_regular,
55-
coordinates_to_local,
56-
distance,
57-
interp2d_geo,
58-
local_to_coordinates,
59-
nearest_grd_indice,
60-
uniform_resample,
61-
)
29+
from ..generic import (bbox_indice_regular, coordinates_to_local, distance,
30+
interp2d_geo, local_to_coordinates, nearest_grd_indice,
31+
uniform_resample)
6232
from ..observations.observation import EddiesObservations
63-
from ..poly import (
64-
create_vertice,
65-
fit_circle,
66-
get_pixel_in_regular,
67-
poly_area,
68-
poly_contain_poly,
69-
visvalingam,
70-
winding_number_poly,
71-
)
33+
from ..poly import (create_vertice, fit_circle, get_pixel_in_regular,
34+
poly_area, poly_contain_poly, visvalingam,
35+
winding_number_poly)
7236

7337
logger = logging.getLogger("pet")
7438

@@ -1318,9 +1282,13 @@ def compute_pixel_path(self, x0, y0, x1, y1):
13181282
self.x_size,
13191283
)
13201284

1321-
def clean_land(self):
1285+
def clean_land(self, name):
13221286
"""Function to remove all land pixel"""
1323-
pass
1287+
mask_land = self.__class__(get_demo_path("mask_1_60.nc"), "lon", "lat")
1288+
x,y = meshgrid(self.x_c, self.y_c)
1289+
m = mask_land.interp('mask', x.reshape(-1), y.reshape(-1), 'nearest')
1290+
data = self.grid(name)
1291+
self.vars[name] = ma.array(data, mask=m.reshape(x.shape).T)
13241292

13251293
def is_circular(self):
13261294
"""Check if the grid is circular"""
@@ -2392,6 +2360,15 @@ def __iter__(self):
23922360
for _, d in self.datasets:
23932361
yield d
23942362

2363+
@property
2364+
def time(self):
2365+
return array([t for t, _ in self.datasets])
2366+
2367+
@property
2368+
def period(self):
2369+
t = self.time
2370+
return t.min(), t.max()
2371+
23952372
def __getitem__(self, item):
23962373
for t, d in self.datasets:
23972374
if t == item:

src/py_eddy_tracker/generic.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,11 @@
33
Tool method which use mostly numba
44
"""
55

6-
from numba import njit, prange, types as numba_types
7-
from numpy import (
8-
absolute,
9-
arcsin,
10-
arctan2,
11-
bool_,
12-
cos,
13-
empty,
14-
floor,
15-
histogram,
16-
interp,
17-
isnan,
18-
linspace,
19-
nan,
20-
ones,
21-
pi,
22-
radians,
23-
sin,
24-
where,
25-
zeros,
26-
)
6+
from numba import njit, prange
7+
from numba import types as numba_types
8+
from numpy import (absolute, arcsin, arctan2, bool_, cos, empty, floor,
9+
histogram, interp, isnan, linspace, nan, ones, pi, radians,
10+
sin, where, zeros)
2711

2812

2913
@njit(cache=True)
@@ -426,7 +410,7 @@ def split_line(x, y, i):
426410
"""
427411
nb_jump = len(where(i[1:] - i[:-1] != 0)[0])
428412
nb_value = x.shape[0]
429-
final_size = (nb_jump - 1) + nb_value
413+
final_size = nb_jump + nb_value
430414
new_x = empty(final_size, dtype=x.dtype)
431415
new_y = empty(final_size, dtype=y.dtype)
432416
new_j = 0

src/py_eddy_tracker/observations/network.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,26 @@
22
"""
33
Class to create network of observations
44
"""
5-
from glob import glob
65
import logging
76
import time
7+
from glob import glob
88

99
import netCDF4
10-
from numba import njit, types as nb_types
11-
from numba.typed import List
12-
from numpy import (
13-
arange,
14-
array,
15-
bincount,
16-
bool_,
17-
concatenate,
18-
empty,
19-
nan,
20-
ones,
21-
percentile,
22-
uint16,
23-
uint32,
24-
unique,
25-
where,
26-
zeros,
27-
)
2810
import zarr
11+
from numba import njit
12+
from numba import types as nb_types
13+
from numba.typed import List
14+
from numpy import (arange, array, bincount, bool_, concatenate, empty, nan,
15+
ones, percentile, uint16, uint32, unique, where, zeros)
2916

3017
from ..dataset.grid import GridCollection
3118
from ..generic import build_index, wrap_longitude
3219
from ..poly import bbox_intersection, vertice_overlap
33-
from .groups import GroupEddiesObservations, get_missing_indices, particle_candidate
20+
from .groups import (GroupEddiesObservations, get_missing_indices,
21+
particle_candidate)
3422
from .observation import EddiesObservations
35-
from .tracking import TrackEddiesObservations, track_loess_filter, track_median_filter
23+
from .tracking import (TrackEddiesObservations, track_loess_filter,
24+
track_median_filter)
3625

3726
logger = logging.getLogger("pet")
3827

@@ -280,6 +269,15 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
280269
"""
281270
Select network on time duration
282271
272+
:param int nb_day_min: Minimal number of days covered by one network, if negative -> not used
273+
:param int nb_day_max: Maximal number of days covered by one network, if negative -> not used
274+
"""
275+
return self.extract_with_mask(self.mask_longer_than(nb_day_min, nb_day_max))
276+
277+
def mask_longer_than(self, nb_day_min=-1, nb_day_max=-1):
278+
"""
279+
Select network on time duration
280+
283281
:param int nb_day_min: Minimal number of days covered by one network, if negative -> not used
284282
:param int nb_day_max: Maximal number of days covered by one network, if negative -> not used
285283
"""
@@ -293,7 +291,7 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1):
293291
continue
294292
if nb_day_min <= (ptp(t[i]) + 1) <= nb_day_max:
295293
mask[i] = True
296-
return self.extract_with_mask(mask)
294+
return mask
297295

298296
@classmethod
299297
def from_split_network(cls, group_dataset, indexs, **kwargs):
@@ -800,7 +798,7 @@ def display_timeline(
800798
if field is not None:
801799
field = self.parse_varname(field)
802800
for i, b0, b1 in self.iter_on("segment"):
803-
x = self.time[i]
801+
x = self.time_datetime64[i]
804802
if x.shape[0] == 0:
805803
continue
806804
if field is None:
@@ -831,7 +829,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol
831829

832830
# TODO : fill mappables dict
833831
y_seg = dict()
834-
_time = self.time
832+
_time = self.time_datetime64
835833

836834
if field is not None and method != "all":
837835
for i, b0, _ in self.iter_on("segment"):
@@ -1011,7 +1009,7 @@ def scatter_timeline(
10111009
if "c" not in kwargs:
10121010
v = self.parse_varname(name)
10131011
kwargs["c"] = v * factor
1014-
mappables["scatter"] = ax.scatter(self.time, y, **kwargs)
1012+
mappables["scatter"] = ax.scatter(self.time_datetime64, y, **kwargs)
10151013
return mappables
10161014

10171015
def event_map(self, ax, **kwargs):
@@ -1244,7 +1242,7 @@ def networks_mask(self, id_networks, segment=False):
12441242

12451243
def networks(self, id_networks):
12461244
return self.extract_with_mask(
1247-
generate_mask_from_ids(id_networks, self.track.size, *self.index_network)
1245+
generate_mask_from_ids(array(id_networks), self.track.size, *self.index_network)
12481246
)
12491247

12501248
@property
@@ -1423,10 +1421,10 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs):
14231421
:param dict kwargs: keyword arguments for Axes.plot
14241422
:return: a list of matplotlib mappables
14251423
"""
1426-
nb_colors = 0
1427-
if color_cycle is not None:
1428-
kwargs = kwargs.copy()
1429-
nb_colors = len(color_cycle)
1424+
kwargs = kwargs.copy()
1425+
if color_cycle is None:
1426+
color_cycle = self.COLORS
1427+
nb_colors = len(color_cycle)
14301428
mappables = list()
14311429
if "label" in kwargs:
14321430
kwargs["label"] = self.format_label(kwargs["label"])

0 commit comments

Comments
 (0)