Skip to content

Commit 8f81399

Browse files
committed
numba evolution, bug around bounds for interpolation
1 parent fc50aa2 commit 8f81399

File tree

2 files changed

+339
-179
lines changed

2 files changed

+339
-179
lines changed

src/py_eddy_tracker/dataset/grid.py

Lines changed: 159 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
from scipy.spatial import cKDTree
1616
from scipy.signal import welch
1717
from cv2 import filter2D
18-
from numba import njit, prange
18+
from numba import njit, prange, types as numba_types
1919
from matplotlib.path import Path as BasePath
2020
from matplotlib.contour import QuadContourSet as BaseQuadContourSet
2121
from pyproj import Proj
2222
from pint import UnitRegistry
23-
from ..tools import winding_number_poly, poly_contain_poly
2423
from ..observations import EddiesObservations
2524
from ..eddy_feature import Amplitude, Contours
2625
from .. import VAR_DESCR
@@ -102,7 +101,7 @@ def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None):
102101
dist[0] = 0
103102
dist[1:] = distance(x_val[:-1], y_val[:-1], x_val[1:], y_val[1:])
104103
# To be still monotonous
105-
dist[dist==0] = 1e-10
104+
dist[1:][dist[1:]<1e-10] = 1e-10
106105
dist = dist.cumsum()
107106
# Get uniform distances
108107
if fixed_size is None:
@@ -122,6 +121,12 @@ def uniform_resample_stack(vertices, num_fac=2, fixed_size=None):
122121
data[:, 1] = y_new
123122
return data
124123

124+
@njit(cache=True)
125+
def value_on_regular_contour(x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=None):
126+
x_val, y_val = vertices[:, 0], vertices[:, 1]
127+
x_new, y_new = uniform_resample(x_val, y_val, num_fac, fixed_size)
128+
return interp2d_geo(x_g, y_g, z_g, m_g, x_new[1:], y_new[1:])
129+
125130

126131
def fit_circle_path(self):
127132
if not hasattr(self, '_circle_params'):
@@ -179,32 +184,6 @@ def local_to_coordinates(x, y, lon0, lat0):
179184
return lon, lat / D2R
180185

181186

182-
def _fit_circle_path_old(self):
183-
lon_mean, lat_mean = self.mean_coordinates
184-
# Prepare for shape test and get eddy_radius_e
185-
# http://www.geo.hunter.cuny.edu/~jochen/gtech201/lectures/
186-
# lec6concepts/map%20coordinate%20systems/
187-
# how%20to%20choose%20a%20projection.htm
188-
proj = Proj('+proj=aeqd +ellps=WGS84 +lat_0=%s +lon_0=%s'
189-
% (lat_mean, lon_mean))
190-
191-
c_x, c_y = proj(self.lon, self.lat)
192-
try:
193-
centlon_e, centlat_e, eddy_radius_e, aerr = fit_circle_c_numba(c_x, c_y)
194-
centlon_e, centlat_e = proj(centlon_e, centlat_e, inverse=True)
195-
centlon_e = (centlon_e - lon_mean + 180) % 360 + lon_mean - 180
196-
self._circle_params = centlon_e, centlat_e, eddy_radius_e, aerr
197-
except ZeroDivisionError:
198-
# Some time, edge is only a dot of few coordinates
199-
d_lon = self.lon.max() - self.lon.min()
200-
d_lat = self.lat.max() - self.lat.min()
201-
if d_lon < 1e-7 and d_lat < 1e-7:
202-
logging.warning('An edge is only define in one position')
203-
logging.debug('%d coordinates %s,%s', len(self.lon), self.lon,
204-
self.lat)
205-
self._circle_params = 0, -90, nan, nan
206-
207-
208187
@njit(cache=True)
209188
def fit_circle_c_numba(x_vec, y_vec):
210189
nb_elt = x_vec.shape[0]
@@ -304,7 +283,92 @@ def nb_pixel(self):
304283
raise Exception('No pixels_in call before!')
305284
return self._pixels_in[0].shape[0]
306285

307-
286+
287+
@njit(cache=True)
288+
def is_left(x_line_0, y_line_0, x_line_1, y_line_1, x_test, y_test):
289+
"""
290+
http://geomalgorithms.com/a03-_inclusion.html
291+
isLeft(): tests if a point is Left|On|Right of an infinite line.
292+
Input: three points P0, P1, and P2
293+
Return: >0 for P2 left of the line through P0 and P1
294+
=0 for P2 on the line
295+
<0 for P2 right of the line
296+
See: Algorithm 1 "Area of Triangles and Polygons"
297+
"""
298+
# Vector product
299+
product = (x_line_1 - x_line_0) * (y_test - y_line_0) - (x_test - x_line_0) * (y_line_1 - y_line_0)
300+
return product > 0
301+
302+
303+
@njit(cache=True)
304+
def poly_contain_poly(xy_poly_out, xy_poly_in):
305+
nb_elt = xy_poly_in.shape[0]
306+
for i_elt in prange(nb_elt):
307+
wn = winding_number_poly(xy_poly_in[i_elt, 0], xy_poly_in[i_elt, 1], xy_poly_out)
308+
if wn == 0:
309+
return False
310+
return True
311+
312+
313+
@njit(cache=True)
314+
def winding_number_poly(x, y, xy_poly):
315+
nb_elt = xy_poly.shape[0]
316+
wn = 0
317+
# loop through all edges of the polygon
318+
for i_elt in range(nb_elt):
319+
if i_elt + 1 == nb_elt:
320+
x_next = xy_poly[0, 0]
321+
y_next = xy_poly[0, 1]
322+
else:
323+
x_next = xy_poly[i_elt + 1, 0]
324+
y_next = xy_poly[i_elt + 1, 1]
325+
if xy_poly[i_elt, 1] <= y:
326+
if y_next > y:
327+
if is_left(xy_poly[i_elt, 0],
328+
xy_poly[i_elt, 1],
329+
x_next,
330+
y_next,
331+
x, y
332+
):
333+
wn += 1
334+
else:
335+
if y_next <= y:
336+
if not is_left(xy_poly[i_elt, 0],
337+
xy_poly[i_elt, 1],
338+
x_next,
339+
y_next,
340+
x, y
341+
):
342+
wn -= 1
343+
return wn
344+
345+
346+
@njit(cache=True)
347+
def winding_number_grid_in_poly(x_1d, y_1d, i_x0, i_x1, x_size, i_y0, xy_poly):
348+
"""
349+
http://geomalgorithms.com/a03-_inclusion.html
350+
wn_PnPoly(): winding number test for a point in a polygon
351+
Input: P = a point,
352+
V[] = vertex points of a polygon V[n+1] with V[n]=V[0]
353+
Return: wn = the winding number (=0 only when P is outside)
354+
"""
355+
# the winding number counter
356+
nb_x, nb_y = len(x_1d), len(y_1d)
357+
wn = empty((nb_x, nb_y), dtype=numba_types.bool_)
358+
for i in range(nb_x):
359+
x_pt = x_1d[i]
360+
for j in range(nb_y):
361+
y_pt = y_1d[j]
362+
wn[i, j] = winding_number_poly(x_pt, y_pt, xy_poly)
363+
i_x, i_y = where(wn)
364+
i_x += i_x0
365+
i_y += i_y0
366+
if i_x1 < i_x0:
367+
i_x %= x_size
368+
return i_x, i_y
369+
370+
371+
308372
BasePath.pixels_in = pixels_in
309373
BasePath.pixels_index = pixels_index
310374
BasePath.bbox_slice = bbox_slice
@@ -597,7 +661,7 @@ def eddy_identification(self, grid_height, uname, vname, date, step=0.005, shape
597661
i_x, i_y = self.nearest_grd_indice(centlon_e, centlat_e)
598662

599663
# Check if centroid is on define value
600-
if hasattr(data, 'mask') and data.mask[i_x, i_y]:
664+
if data.mask[i_x, i_y]:
601665
continue
602666
# Test to know cyclone or anticyclone
603667
acyc_not_cyc = data[i_x, i_y] >= cvalues
@@ -806,16 +870,16 @@ def bbox_indice(self, vertices):
806870
dist, idx = self.index_interp.query(vertices, k=1)
807871
i_y = idx % self.x_c.shape[1]
808872
i_x = int_((idx - i_y) / self.x_c.shape[1])
809-
return slice(i_x.min() - self.N, i_x.max() + self.N + 1), slice(i_y.min() - self.N, i_y.max() + self.N + 1)
873+
return (i_x.min() - self.N, i_x.max() + self.N + 1), (i_y.min() - self.N, i_y.max() + self.N + 1)
810874

811875
def get_pixels_in(self, contour):
812-
slice_x, slice_y = contour.bbox_slice
813-
pts = array((self.x_c[slice_x, slice_y].reshape(-1),
814-
self.y_c[slice_x, slice_y].reshape(-1))).T
815-
mask = contour.contains_points(pts).reshape((slice_x.stop - slice_x.start, -1))
876+
(x_start, x_stop), (y_start, y_stop) = contour.bbox_slice
877+
pts = array((self.x_c[x_start:x_stop, y_start:x_stop].reshape(-1),
878+
self.y_c[x_start:y_stop, y_start:y_stop].reshape(-1))).T
879+
mask = contour.contains_points(pts).reshape((x_stop - x_start, -1))
816880
i_x, i_y = where(mask)
817-
i_x += slice_x.start
818-
i_y += slice_y.start
881+
i_x += x_start
882+
i_y += y_start
819883
return i_x, i_y
820884

821885
def normalize_x_indice(self, indices):
@@ -917,38 +981,24 @@ def init_pos_interpolator(self):
917981
self.yinterp = arange(self.y_bounds.shape[0])
918982

919983
def bbox_indice(self, vertices):
920-
lon, lat = vertices.T
921-
lon_min, lon_max = lon.min(), lon.max()
922-
lat_min, lat_max = lat.min(), lat.max()
923-
i_x0, i_y0 = self.nearest_grd_indice(lon_min, lat_min)
924-
i_x1, i_y1 = self.nearest_grd_indice(lon_max, lat_max)
925-
slice_x = slice(i_x0 - self.N, i_x1 + self.N + 1)
926-
slice_y = slice(i_y0 - self.N, i_y1 + self.N + 1)
927-
return slice_x, slice_y
984+
return bbox_indice_regular(vertices, self.x_bounds[0], self.y_bounds[0], self.xstep, self.ystep, self.N)
928985

929986
def get_pixels_in(self, contour):
930-
slice_x, slice_y = contour.bbox_slice
931-
if slice_x.stop < slice_x.start:
987+
(x_start, x_stop), (y_start, y_stop) = contour.bbox_slice
988+
if x_stop < x_start:
932989
x_ref = contour.vertices[0, 0]
933-
x_array = (concatenate((self.x_c[slice_x.start:], self.x_c[:slice_x.stop])) - x_ref + 180) % 360 + x_ref -180
990+
x_array = (concatenate((self.x_c[x_start:], self.x_c[:x_stop])) - x_ref + 180) % 360 + x_ref -180
934991
else:
935-
x_array = self.x_c[slice_x]
936-
x, y = meshgrid(x_array, self.y_c[slice_y])
937-
pts = array((x.reshape(-1), y.reshape(-1))).T
938-
mask = contour.contains_points(pts).reshape(x.shape)
939-
i_x, i_y = where(mask.T)
940-
i_x += slice_x.start
941-
i_y += slice_y.start
942-
if slice_x.stop < slice_x.start:
943-
i_x %= self.x_size
944-
return i_x, i_y
992+
x_array = self.x_c[x_start:x_stop]
993+
return winding_number_grid_in_poly(x_array, self.y_c[y_start:y_stop], x_start, x_stop, self.x_size, y_start, contour.vertices)
994+
945995

946996
def normalize_x_indice(self, indices):
947997
return indices % self.x_size
948998

949999
def nearest_grd_indice(self, x, y):
9501000
return int32(((x - self.x_bounds[0]) % 360) // self.xstep), \
951-
int32(((y - self.y_bounds[0]) % 360) // self.ystep)
1001+
int32((y - self.y_bounds[0]) // self.ystep)
9521002

9531003
@property
9541004
def xstep(self):
@@ -1110,7 +1160,7 @@ def convolve_filter_with_dynamic_kernel(self, grid_name, kernel_func, lat_max=85
11101160
logging.warning('No filtering above %f degrees of latitude', lat_max)
11111161
data = self.grid(grid_name).copy()
11121162
# Matrix for result
1113-
data_out = ma.zeros(data.shape)
1163+
data_out = ma.empty(data.shape)
11141164
data_out.mask = ones(data_out.shape, dtype=bool)
11151165
for i, lat in enumerate(self.y_c):
11161166
if abs(lat) > lat_max or data[:, i].mask.all():
@@ -1141,13 +1191,11 @@ def convolve_filter_with_dynamic_kernel(self, grid_name, kernel_func, lat_max=85
11411191
tmp_matrix[~m] = 0
11421192

11431193
demi_x, demi_y = k_shape[0] // 2, k_shape[1] // 2
1144-
# custom_(tmp_matrix, m.astype('f8'), kernel)
1145-
values_sum = filter2D(tmp_matrix, -1, kernel)[demi_x:-demi_x, demi_y]
1194+
values_sum = filter2D(tmp_matrix.data, -1, kernel)[demi_x:-demi_x, demi_y]
11461195
kernel_sum = filter2D(m.astype(float), -1, kernel)[demi_x:-demi_x, demi_y]
11471196
with errstate(invalid='ignore'):
11481197
data_out[:, i] = values_sum / kernel_sum
11491198
data_out = ma.array(data_out, mask=data.mask + data_out.mask)
1150-
11511199
return data_out
11521200

11531201
def _low_filter(self, grid_name, x_cut, y_cut):
@@ -1301,16 +1349,18 @@ def add_uv(self, grid_height):
13011349
self.vars['v'] = d_hx / d_x * gof
13021350

13031351
def speed_coef(self, contour):
1304-
lon, lat = uniform_resample_stack(contour.vertices)[1:].T
1305-
return self._speed_ev(lon, lat)
1352+
"""some nan can be compute over contour if we are near border,
1353+
something to explore
1354+
"""
1355+
return value_on_regular_contour(
1356+
self.x_c, self.y_c,
1357+
self._speed_ev, self._speed_ev.mask,
1358+
contour.vertices)
13061359

13071360
def init_speed_coef(self, uname='u', vname='v'):
13081361
"""Draft
13091362
"""
1310-
speed = (self.grid(uname) ** 2 + self.grid(vname) ** 2) ** .5
1311-
# Evaluation near masked value will be smoothed to 0 !!!, not perfect
1312-
speed[speed.mask] = 0
1313-
self._speed_ev = RectBivariateSpline(self.x_c, self.y_c, speed, kx=1, ky=1).ev
1363+
self._speed_ev = (self.grid(uname) ** 2 + self.grid(vname) ** 2) ** .5
13141364

13151365
def display(self, ax, name, **kwargs):
13161366
if 'cmap' not in kwargs:
@@ -1342,52 +1392,67 @@ def interp(self, grid_name, lons, lats):
13421392
return z
13431393

13441394

1345-
# @njit(cache=True, fastmath=True, parallel=True)
1346-
@njit(cache=True, fastmath=True)
1347-
def custom_(data, mask, kernel):
1395+
@njit(cache=True, fastmath=True, parallel=True)
1396+
def custom_convolution(data, mask, kernel):
13481397
"""do sortin at high lattitude big part of value are masked"""
13491398
nb_x = kernel.shape[0]
13501399
demi_x = int((nb_x - 1) / 2)
13511400
demi_y = int((kernel.shape[1] - 1) / 2)
13521401
out = empty(data.shape[0] - nb_x + 1)
13531402
for i in prange(out.shape[0]):
1354-
if mask[i + demi_x, demi_y] != 0:
1355-
continue
1356-
p = (mask[i:i + nb_x] * kernel).sum()
1357-
if p != 0:
1358-
out[i] = (data[i:i + nb_x] * kernel).sum() / p
1403+
if mask[i + demi_x, demi_y] == 1:
1404+
w = (mask[i:i + nb_x] * kernel).sum()
1405+
if w != 0:
1406+
out[i] = (data[i:i + nb_x] * kernel).sum() / w
1407+
else:
1408+
out[i] = nan
13591409
else:
13601410
out[i] = nan
13611411
return out
13621412

13631413

1364-
@njit(parralel=True, cache=True)
1365-
def interp_numba(x_g, y_g, z, x, y, dest_z, fill_value):
1414+
@njit(cache=True, fastmath=True)
1415+
def interp2d_geo(x_g, y_g, z_g, m_g, x, y):
1416+
"""For geographic grid, test of cicularity
1417+
Maybe test if we are out of bounds
1418+
"""
13661419
x_ref = x_g[0]
13671420
y_ref = y_g[0]
13681421
x_step = x_g[1] - x_ref
13691422
y_step = y_g[1] - y_ref
1423+
nb_x = x_g.shape[0]
1424+
is_circular = (x_g[-1] + x_step) % 360 == x_g[0] % 360
1425+
z = empty(x.shape)
13701426
for i in prange(x.size):
13711427
x_ = (x[i] - x_ref) / x_step
13721428
y_ = (y[i] - y_ref) / y_step
13731429
i0 = int(floor(x_))
13741430
i1 = i0 + 1
1431+
if is_circular:
1432+
xd = (x_ - i0)
1433+
i0 %= nb_x
1434+
i1 %= nb_x
13751435
j0 = int(floor(y_))
13761436
j1 = j0 + 1
1377-
xd = (x_ - i0)
13781437
yd = (y_ - j0)
1379-
z00 = z[i0, j0]
1380-
z01 = z[i0, j1]
1381-
z10 = z[i1, j0]
1382-
z11 = z[i1, j1]
1383-
if z00 == fill_value:
1384-
dest_z[i] = nan
1385-
elif z01 == fill_value:
1386-
dest_z[i] = nan
1387-
elif z10 == fill_value:
1388-
dest_z[i] = nan
1389-
elif z11 == fill_value:
1390-
dest_z[i] = nan
1438+
z00 = z_g[i0, j0]
1439+
z01 = z_g[i0, j1]
1440+
z10 = z_g[i1, j0]
1441+
z11 = z_g[i1, j1]
1442+
if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]:
1443+
z[i] = nan
13911444
else:
1392-
dest_z[i] = (z00 * (1 - xd) + (z10 * xd)) * (1 - yd) + (z01 * (1 - xd) + z11 * xd) * yd
1445+
z[i] = (z00 * (1 - xd) + (z10 * xd)) * (1 - yd) + (z01 * (1 - xd) + z11 * xd) * yd
1446+
return z
13931447

1448+
1449+
@njit(cache=True)
1450+
def bbox_indice_regular(vertices, x0, y0, xstep, ystep, N):
1451+
lon, lat = vertices[:,0], vertices[:,1]
1452+
lon_min, lon_max = lon.min(), lon.max()
1453+
lat_min, lat_max = lat.min(), lat.max()
1454+
i_x0, i_y0 = int32(((lon_min - x0) % 360) // xstep), int32((lat_min - y0) // ystep)
1455+
i_x1, i_y1 = int32(((lon_max - x0) % 360) // xstep), int32((lat_max - y0) // ystep)
1456+
slice_x = i_x0 - N, i_x1 + N + 1
1457+
slice_y = i_y0 - N, i_y1 + N + 1
1458+
return slice_x, slice_y

0 commit comments

Comments
 (0)