Skip to content

Commit 1b9ab25

Browse files
committed
numba correction with masked array
1 parent 1728815 commit 1b9ab25

File tree

12 files changed

+256
-94
lines changed

12 files changed

+256
-94
lines changed

doc/spectrum.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ Compute and display spectrum
2828
ax.set_title("Spectrum")
2929
ax.set_xlabel("km")
3030
for name_area, area in areas.items():
31-
3231
lon_spec, lat_spec = raw.spectrum_lonlat("adt", area=area)
3332
mappable = ax.loglog(*lat_spec, label="lat %s raw" % name_area)[0]
3433
ax.loglog(

examples/16_network/pet_atlas.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@ def update_axes(ax, mappable=None):
129129
# Merging in networks longer than 10 days, with dead end remove (shorter than 10 observations)
130130
# --------------------------------------------------------------------------------------------
131131
ax = start_axes("")
132-
merger = n10.remove_dead_end(nobs=10).merging_event()
132+
n10_ = n10.copy()
133+
n10_.remove_dead_end(nobs=10)
134+
merger = n10_.merging_event()
133135
g_10_merging = merger.grid_count(bins)
134136
m = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1)
135137
update_axes(ax, m).set_label("Pixel used in % of time")

examples/16_network/pet_follow_particle.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def save(self, *args, **kwargs):
4141
# %%
4242
n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651)
4343
n = n.extract_with_mask((n.time >= 20180) * (n.time <= 20269))
44-
n = n.remove_dead_end(nobs=0, ndays=10)
44+
n.remove_dead_end(nobs=0, ndays=10)
45+
n = n.remove_trash()
4546
n.numbering_segment()
4647
c = GridCollection.from_netcdf_cube(
4748
get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"),

examples/16_network/pet_relative.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@
127127
# Remove dead branch
128128
# ------------------
129129
# Remove all tiny segments with less than N obs which didn't join two segments
130-
n_clean = n.remove_dead_end(nobs=5, ndays=10)
130+
n_clean = n.copy()
131+
n_clean.remove_dead_end(nobs=5, ndays=10)
132+
n_clean = n_clean.remove_trash()
131133
fig = plt.figure(figsize=(15, 12))
132134
ax = fig.add_axes([0.04, 0.54, 0.90, 0.40])
133135
ax.set_title(f"Original network ({n.infos()})")
@@ -261,7 +263,9 @@
261263
# --------------------
262264

263265
# Get a simplified network
264-
n = n2.remove_dead_end(nobs=50, recursive=1)
266+
n = n2.copy()
267+
n.remove_dead_end(nobs=50, recursive=1)
268+
n = n.remove_trash()
265269
n.numbering_segment()
266270
# %%
267271
# Only a map can be tricky to understand, with a timeline it's easier!

examples/16_network/pet_replay_segmentation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ def get_obs(dataset):
163163
for b0, b1 in [
164164
(datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007, 2008)
165165
]:
166-
167166
ref, delta = datetime(1950, 1, 1), 20
168167
b0_, b1_ = (b0 - ref).days, (b1 - ref).days
169168
ax = timeline_axes()

src/py_eddy_tracker/dataset/grid.py

Lines changed: 87 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,45 @@
22
"""
33
Class to load and manipulate RegularGrid and UnRegularGrid
44
"""
5-
import logging
65
from datetime import datetime
6+
import logging
77

88
from cv2 import filter2D
99
from matplotlib.path import Path as BasePath
1010
from netCDF4 import Dataset
11-
from numba import njit, prange
12-
from numba import types as numba_types
13-
from numpy import (arange, array, ceil, concatenate, cos, deg2rad, empty,
14-
errstate, exp, float_, floor, histogram2d, int_, interp,
15-
isnan, linspace, ma)
16-
from numpy import mean as np_mean
17-
from numpy import (meshgrid, nan, nanmean, ones, percentile, pi, radians,
18-
round_, sin, sinc, where, zeros)
11+
from numba import njit, prange, types as numba_types
12+
from numpy import (
13+
arange,
14+
array,
15+
ceil,
16+
concatenate,
17+
cos,
18+
deg2rad,
19+
empty,
20+
errstate,
21+
exp,
22+
float_,
23+
floor,
24+
histogram2d,
25+
int_,
26+
interp,
27+
isnan,
28+
linspace,
29+
ma,
30+
mean as np_mean,
31+
meshgrid,
32+
nan,
33+
nanmean,
34+
ones,
35+
percentile,
36+
pi,
37+
radians,
38+
round_,
39+
sin,
40+
sinc,
41+
where,
42+
zeros,
43+
)
1944
from pint import UnitRegistry
2045
from scipy.interpolate import RectBivariateSpline, interp1d
2146
from scipy.ndimage import gaussian_filter
@@ -26,13 +51,25 @@
2651
from .. import VAR_DESCR
2752
from ..data import get_demo_path
2853
from ..eddy_feature import Amplitude, Contours
29-
from ..generic import (bbox_indice_regular, coordinates_to_local, distance,
30-
interp2d_geo, local_to_coordinates, nearest_grd_indice,
31-
uniform_resample)
54+
from ..generic import (
55+
bbox_indice_regular,
56+
coordinates_to_local,
57+
distance,
58+
interp2d_geo,
59+
local_to_coordinates,
60+
nearest_grd_indice,
61+
uniform_resample,
62+
)
3263
from ..observations.observation import EddiesObservations
33-
from ..poly import (create_vertice, fit_circle, get_pixel_in_regular,
34-
poly_area, poly_contain_poly, visvalingam,
35-
winding_number_poly)
64+
from ..poly import (
65+
create_vertice,
66+
fit_circle,
67+
get_pixel_in_regular,
68+
poly_area,
69+
poly_contain_poly,
70+
visvalingam,
71+
winding_number_poly,
72+
)
3673

3774
logger = logging.getLogger("pet")
3875

@@ -86,7 +123,7 @@ def value_on_regular_contour(x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size
86123

87124
@njit(cache=True)
88125
def mean_on_regular_contour(
89-
x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=None, nan_remove=False
126+
x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=-1, nan_remove=False
90127
):
91128
x_val, y_val = vertices[:, 0], vertices[:, 1]
92129
x_new, y_new = uniform_resample(x_val, y_val, num_fac, fixed_size)
@@ -406,8 +443,8 @@ def setup_coordinates(self):
406443
x_name, y_name = self.coordinates
407444
if self.is_centered:
408445
# logger.info("Grid center")
409-
self.x_c = self.vars[x_name].astype("float64")
410-
self.y_c = self.vars[y_name].astype("float64")
446+
self.x_c = array(self.vars[x_name].astype("float64"))
447+
self.y_c = array(self.vars[y_name].astype("float64"))
411448

412449
self.x_bounds = concatenate((self.x_c, (2 * self.x_c[-1] - self.x_c[-2],)))
413450
self.y_bounds = concatenate((self.y_c, (2 * self.y_c[-1] - self.y_c[-2],)))
@@ -419,8 +456,8 @@ def setup_coordinates(self):
419456
self.y_bounds[-1] -= d_y[-1] / 2
420457

421458
else:
422-
self.x_bounds = self.vars[x_name].astype("float64")
423-
self.y_bounds = self.vars[y_name].astype("float64")
459+
self.x_bounds = array(self.vars[x_name].astype("float64"))
460+
self.y_bounds = array(self.vars[y_name].astype("float64"))
424461

425462
if len(self.x_dim) == 1:
426463
self.x_c = self.x_bounds.copy()
@@ -757,7 +794,7 @@ def eddy_identification(
757794

758795
# Test of the rotating sense: cyclone or anticyclone
759796
if has_value(
760-
data, i_x_in, i_y_in, cvalues, below=anticyclonic_search
797+
data.data, i_x_in, i_y_in, cvalues, below=anticyclonic_search
761798
):
762799
continue
763800

@@ -788,7 +825,6 @@ def eddy_identification(
788825
contour.reject = 4
789826
continue
790827
if reset_centroid:
791-
792828
if self.is_circular():
793829
centi = self.normalize_x_indice(reset_centroid[0])
794830
else:
@@ -1285,8 +1321,8 @@ def compute_pixel_path(self, x0, y0, x1, y1):
12851321
def clean_land(self, name):
12861322
"""Function to remove all land pixel"""
12871323
mask_land = self.__class__(get_demo_path("mask_1_60.nc"), "lon", "lat")
1288-
x,y = meshgrid(self.x_c, self.y_c)
1289-
m = mask_land.interp('mask', x.reshape(-1), y.reshape(-1), 'nearest')
1324+
x, y = meshgrid(self.x_c, self.y_c)
1325+
m = mask_land.interp("mask", x.reshape(-1), y.reshape(-1), "nearest")
12901326
data = self.grid(name)
12911327
self.vars[name] = ma.array(data, mask=m.reshape(x.shape).T)
12921328

@@ -1310,7 +1346,7 @@ def get_step_in_km(self, lat, wave_length):
13101346
min_wave_length = max(step_x_km, step_y_km) * 2
13111347
if wave_length < min_wave_length:
13121348
logger.error(
1313-
"wave_length too short for resolution, must be > %d km",
1349+
"Wave_length too short for resolution, must be > %d km",
13141350
ceil(min_wave_length),
13151351
)
13161352
raise Exception()
@@ -1361,6 +1397,24 @@ def kernel_lanczos(self, lat, wave_length, order=1):
13611397
kernel[dist_norm > order] = 0
13621398
return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt)
13631399

1400+
def kernel_loess(self, lat, wave_length, order=1):
1401+
"""
1402+
https://fr.wikipedia.org/wiki/R%C3%A9gression_locale
1403+
"""
1404+
order = self.check_order(order)
1405+
half_x_pt, half_y_pt, dist_norm = self.estimate_kernel_shape(
1406+
lat, wave_length, order
1407+
)
1408+
1409+
def inc_func(xdist):
1410+
f = zeros(xdist.size)
1411+
f[abs(xdist) < 1] = 1
1412+
return f
1413+
1414+
kernel = (1 - abs(dist_norm) ** 3) ** 3
1415+
kernel[abs(dist_norm) > order] = 0
1416+
return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt)
1417+
13641418
def kernel_bessel(self, lat, wave_length, order=1):
13651419
"""wave_length in km
13661420
order must be int
@@ -1638,11 +1692,13 @@ def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=Fal
16381692
data1[-schema:] = nan
16391693
data2[:schema] = nan
16401694

1641-
d = self.EARTH_RADIUS * 2 * pi / 360 * 2 * schema
1695+
# Distance for one degree
1696+
d = self.EARTH_RADIUS * 2 * pi / 360
1697+
# Mulitply by 2 step
16421698
if vertical:
1643-
d *= self.ystep
1699+
d *= self.ystep * 2 * schema
16441700
else:
1645-
d *= self.xstep * cos(deg2rad(self.y_c))
1701+
d *= self.xstep * cos(deg2rad(self.y_c)) * 2 * schema
16461702
return (data1 - data2) / d
16471703

16481704
def compute_stencil(
@@ -1855,7 +1911,7 @@ def speed_coef_mean(self, contour):
18551911
return mean_on_regular_contour(
18561912
self.x_c,
18571913
self.y_c,
1858-
self._speed_ev,
1914+
self._speed_ev.data,
18591915
self._speed_ev.mask,
18601916
contour.vertices,
18611917
nan_remove=True,
@@ -1945,7 +2001,7 @@ def interp(self, grid_name, lons, lats, method="bilinear"):
19452001
g = self.grid(grid_name)
19462002
m = self.get_mask(g)
19472003
return interp2d_geo(
1948-
self.x_c, self.y_c, g, m, lons, lats, nearest=method == "nearest"
2004+
self.x_c, self.y_c, g.data, m, lons, lats, nearest=method == "nearest"
19492005
)
19502006

19512007
def uv_for_advection(
@@ -1981,7 +2037,7 @@ def uv_for_advection(
19812037
u = -u
19822038
v = -v
19832039
m = u.mask + v.mask
1984-
return u, v, m
2040+
return u.data, v.data, m
19852041

19862042
def advect(self, x, y, u_name, v_name, nb_step=10, rk4=True, **kw):
19872043
"""

src/py_eddy_tracker/featured_tracking/old_tracker_reference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99

1010
class CheltonTracker(Model):
11-
1211
__slots__ = tuple()
1312

1413
GROUND = RegularGridDataset(

src/py_eddy_tracker/generic.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,27 @@
33
Tool method which use mostly numba
44
"""
55

6-
from numba import njit, prange
7-
from numba import types as numba_types
8-
from numpy import (absolute, arcsin, arctan2, bool_, cos, empty, floor,
9-
histogram, interp, isnan, linspace, nan, ones, pi, radians,
10-
sin, where, zeros)
6+
from numba import njit, prange, types as numba_types
7+
from numpy import (
8+
absolute,
9+
arcsin,
10+
arctan2,
11+
bool_,
12+
cos,
13+
empty,
14+
floor,
15+
histogram,
16+
interp,
17+
isnan,
18+
linspace,
19+
nan,
20+
ones,
21+
pi,
22+
radians,
23+
sin,
24+
where,
25+
zeros,
26+
)
1127

1228

1329
@njit(cache=True)
@@ -285,14 +301,14 @@ def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y):
285301

286302

287303
@njit(cache=True, fastmath=True)
288-
def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None):
304+
def uniform_resample(x_val, y_val, num_fac=2, fixed_size=-1):
289305
"""
290306
Resample contours to have (nearly) equal spacing.
291307
292308
:param array_like x_val: input x contour coordinates
293309
:param array_like y_val: input y contour coordinates
294310
:param int num_fac: factor to increase lengths of output coordinates
295-
:param int,None fixed_size: if defined, will be used to set sampling
311+
:param int fixed_size: if > -1, will be used to set sampling
296312
"""
297313
nb = x_val.shape[0]
298314
# Get distances
@@ -303,7 +319,7 @@ def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None):
303319
dist[1:][dist[1:] < 1e-3] = 1e-3
304320
dist = dist.cumsum()
305321
# Get uniform distances
306-
if fixed_size is None:
322+
if fixed_size == -1:
307323
fixed_size = dist.size * num_fac
308324
d_uniform = linspace(0, dist[-1], fixed_size)
309325
x_new = interp(d_uniform, dist, x_val)

0 commit comments

Comments
 (0)