Skip to content

Commit fc50aa2

Browse files
committed
Change to accelerate with numba
1 parent 890d38e commit fc50aa2

File tree

2 files changed

+219
-39
lines changed

2 files changed

+219
-39
lines changed

src/py_eddy_tracker/dataset/grid.py

Lines changed: 127 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import logging
55
from numpy import concatenate, int32, empty, maximum, where, array, \
66
sin, deg2rad, pi, ones, cos, ma, int8, histogram2d, arange, float_, \
7-
linspace, errstate, int_, column_stack, interp, meshgrid, nan, ceil, sinc, float64, isnan, \
8-
floor, percentile, zeros
7+
linspace, errstate, int_, interp, meshgrid, nan, ceil, sinc, isnan, \
8+
floor, percentile, zeros, arctan2, arcsin
99
from numpy.linalg import lstsq
1010
from datetime import datetime
1111
from scipy.special import j1
@@ -20,7 +20,7 @@
2020
from matplotlib.contour import QuadContourSet as BaseQuadContourSet
2121
from pyproj import Proj
2222
from pint import UnitRegistry
23-
from ..tools import distance_vector, winding_number_poly, poly_contain_poly, distance, distance_point_vector
23+
from ..tools import winding_number_poly, poly_contain_poly
2424
from ..observations import EddiesObservations
2525
from ..eddy_feature import Amplitude, Contours
2626
from .. import VAR_DESCR
@@ -65,6 +65,32 @@ def lat(self):
6565
BasePath.lat = lat
6666

6767

68+
@njit(cache=True, fastmath=True)
69+
def distance(lon0, lat0, lon1, lat1):
70+
D2R = pi / 180.
71+
sin_dlat = sin((lat1 - lat0) * 0.5 * D2R)
72+
sin_dlon = sin((lon1 - lon0) * 0.5 * D2R)
73+
cos_lat1 = cos(lat0 * D2R)
74+
cos_lat2 = cos(lat1 * D2R)
75+
a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2
76+
return 6370997.0 * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5)
77+
78+
@njit(cache=True)
79+
def distance_vincenty(lon0, lat0, lon1, lat1):
80+
""" better than haversine but buggy ??"""
81+
D2R = pi / 180.
82+
dlon = (lon1 - lon0) * D2R
83+
cos_dlon = cos(dlon)
84+
cos_lat1 = cos(lat0 * D2R)
85+
cos_lat2 = cos(lat1 * D2R)
86+
sin_lat1 = sin(lat0 * D2R)
87+
sin_lat2 = sin(lat1 * D2R)
88+
return 6370997.0 * arctan2(
89+
((cos_lat2 * sin(dlon) ** 2) + (cos_lat1 * sin_lat2 - sin_lat1 * cos_lat2 * cos_dlon) ** 2) ** .5,
90+
sin_lat1 * sin_lat2 + cos_lat1 * cos_lat2 * cos_dlon)
91+
92+
93+
@njit(cache=True, fastmath=True)
6894
def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None):
6995
"""
7096
Resample contours to have (nearly) equal spacing
@@ -74,37 +100,86 @@ def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None):
74100
# Get distances
75101
dist = empty(x_val.shape)
76102
dist[0] = 0
77-
distance_vector(
78-
x_val[:-1], y_val[:-1], x_val[1:], y_val[1:], dist[1:])
79-
dist.cumsum(out=dist)
103+
dist[1:] = distance(x_val[:-1], y_val[:-1], x_val[1:], y_val[1:])
104+
# To be still monotonous
105+
dist[dist==0] = 1e-10
106+
dist = dist.cumsum()
80107
# Get uniform distances
81108
if fixed_size is None:
82109
fixed_size = dist.size * num_fac
83-
d_uniform = linspace(0, dist[-1], num=fixed_size)
110+
d_uniform = linspace(0, dist[-1], fixed_size)
84111
x_new = interp(d_uniform, dist, x_val)
85112
y_new = interp(d_uniform, dist, y_val)
86113
return x_new, y_new
87114

88115

89-
@property
90-
def regular_coordinates(self):
91-
"""Give a standard/regular/double sample of contour
92-
"""
93-
if not hasattr(self, '_regular_coordinates'):
94-
self._regular_coordinates = column_stack(uniform_resample(self.lon, self.lat))
95-
return self._regular_coordinates
96-
97-
98-
BasePath.regular_coordinates = regular_coordinates
116+
@njit(cache=True)
117+
def uniform_resample_stack(vertices, num_fac=2, fixed_size=None):
118+
x_val, y_val = vertices[:, 0], vertices[:, 1]
119+
x_new, y_new = uniform_resample(x_val, y_val, num_fac, fixed_size)
120+
data = empty((x_new.shape[0], 2))
121+
data[:, 0] = x_new
122+
data[:, 1] = y_new
123+
return data
99124

100125

101126
def fit_circle_path(self):
102127
if not hasattr(self, '_circle_params'):
103-
self._fit_circle_path()
128+
self._circle_params = _fit_circle_path(self.vertices)
104129
return self._circle_params
105130

106131

107-
def _fit_circle_path(self):
132+
@njit(cache=True, fastmath=True)
133+
def _fit_circle_path(vertice):
134+
lons, lats = vertice[:, 0], vertice[:, 1]
135+
lon0, lat0 = lons.mean(), lats.mean()
136+
c_x, c_y = coordinates_to_local(lons, lats, lon0, lat0)
137+
# Some time, edge is only a dot of few coordinates
138+
d_lon = lons.max() - lons.min()
139+
d_lat = lats.max() - lats.min()
140+
if d_lon < 1e-7 and d_lat < 1e-7:
141+
# logging.warning('An edge is only define in one position')
142+
# logging.debug('%d coordinates %s,%s', len(lons),lons,
143+
# lats)
144+
return 0, -90, nan, nan
145+
centlon_e, centlat_e, eddy_radius_e, aerr = fit_circle_c_numba(c_x, c_y)
146+
centlon_e, centlat_e = local_to_coordinates(centlon_e, centlat_e, lon0, lat0)
147+
centlon_e = (centlon_e - lon0 + 180) % 360 + lon0 - 180
148+
return centlon_e, centlat_e, eddy_radius_e, aerr
149+
150+
151+
@njit(cache=True, fastmath=True)
152+
def coordinates_to_local(lon, lat, lon0, lat0):
153+
D2R = pi / 180.
154+
R = 6370997
155+
dlon = (lon - lon0) * D2R
156+
sin_dlat = sin((lat - lat0) * 0.5 * D2R)
157+
sin_dlon = sin(dlon * 0.5)
158+
cos_lat0 = cos(lat0 * D2R)
159+
cos_lat = cos(lat * D2R)
160+
a_val = sin_dlon ** 2 * cos_lat0 * cos_lat + sin_dlat ** 2
161+
module = R * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5)
162+
163+
dx = lon - lon0
164+
dy = lat - lat0
165+
azimuth = pi /2 - arctan2(
166+
cos_lat * sin(dlon),
167+
cos_lat0 * sin(lat * D2R) - sin(lat0 * D2R) * cos_lat * cos(dlon))
168+
return module * cos(azimuth), module * sin(azimuth)
169+
170+
171+
@njit(cache=True, fastmath=True)
172+
def local_to_coordinates(x, y, lon0, lat0):
173+
D2R = pi / 180.
174+
R = 6370997
175+
d = (x ** 2 + y ** 2) ** .5 / R
176+
a = -(arctan2(y, x) -pi / 2)
177+
lat = arcsin(sin(lat0 * D2R) * cos(d) + cos(lat0 * D2R) * sin(d) * cos(a))
178+
lon = lon0 + arctan2(sin(a) * sin(d) * cos(lat0 * D2R), cos(d) - sin(lat0 * D2R) * sin(lat)) / D2R
179+
return lon, lat / D2R
180+
181+
182+
def _fit_circle_path_old(self):
108183
lon_mean, lat_mean = self.mean_coordinates
109184
# Prepare for shape test and get eddy_radius_e
110185
# http://www.geo.hunter.cuny.edu/~jochen/gtech201/lectures/
@@ -199,7 +274,6 @@ def fit_circle_c_numba(x_vec, y_vec):
199274

200275

201276
BasePath.fit_circle = fit_circle_path
202-
BasePath._fit_circle_path = _fit_circle_path
203277

204278

205279
def pixels_in(self, grid):
@@ -807,7 +881,7 @@ def _low_filter(self, grid_name, x_cut, y_cut, factor=40.):
807881
return ma.array(z_interp, mask=m_interp.ev(x, y) > 0.00001)
808882

809883
def speed_coef(self, contour):
810-
dist, idx = self.index_interp.query(contour.regular_coordinates[1:], k=4)
884+
dist, idx = self.index_interp.query(uniform_resample_stack(contour.vertices)[1:], k=4)
811885
i_y = idx % self.x_c.shape[1]
812886
i_x = int_((idx - i_y) / self.x_c.shape[1])
813887
# A simplified solution to be change by a weight mean
@@ -990,10 +1064,7 @@ def kernel_lanczos(self, lat, wave_length, order=1):
9901064
self.xstep)
9911065

9921066
y, x = meshgrid(y, x)
993-
out_shape = x.shape
994-
dist = empty(out_shape, dtype=float64).flatten()
995-
distance_point_vector(0, lat, x.astype(float64).flatten(), y.astype(float64).flatten(), dist)
996-
dist_norm = dist.reshape(out_shape) / 1000. / wave_length
1067+
dist_norm = distance(0, lat, x, y) / 1000. / wave_length
9971068

9981069
# sinc(d_x) and sinc(d_y) are windows and bessel function give an equivalent of sinc for lanczos filter
9991070
kernel = sinc(dist_norm/order) * sinc(dist_norm)
@@ -1015,28 +1086,25 @@ def kernel_bessel(self, lat, wave_length, order=1):
10151086
raise Exception()
10161087
# half size will be multiply with by order
10171088
half_x_pt, half_y_pt = ceil(wave_length / step_x_km).astype(int), ceil(wave_length / step_y_km).astype(int)
1018-
1089+
# x size is not good over 60 degrees
10191090
y = arange(
10201091
lat - self.ystep * half_y_pt * order,
10211092
lat + self.ystep * half_y_pt * order + 0.01 * self.ystep,
10221093
self.ystep)
1023-
x = arange(
1024-
-self.xstep * half_x_pt * order,
1025-
self.xstep * half_x_pt * order + 0.01 * self.xstep,
1026-
self.xstep)
1027-
1094+
# We compute half + 1 and the other part will be compute by symetry
1095+
x = arange(0, self.xstep * half_x_pt * order + 0.01 * self.xstep, self.xstep)
10281096
y, x = meshgrid(y, x)
1029-
out_shape = x.shape
1030-
dist = empty(out_shape, dtype=float64).flatten()
1031-
distance_point_vector(0, lat, x.astype(float64).flatten(), y.astype(float64).flatten(), dist)
1032-
dist_norm = dist.reshape(out_shape) / 1000. / wave_length
1033-
1097+
dist_norm = distance(0, lat, x, y) / 1000. / wave_length
10341098
# sinc(d_x) and sinc(d_y) are windows and bessel function give an equivalent of sinc for lanczos filter
10351099
with errstate(invalid='ignore'):
10361100
kernel = sinc(dist_norm/order) * j1(2 * pi * dist_norm) / dist_norm
1037-
kernel[half_x_pt * order,half_y_pt * order] = pi
1101+
kernel[0, half_y_pt * order] = pi
10381102
kernel[dist_norm > order] = 0
1039-
return kernel
1103+
# Symetry
1104+
kernel_ = empty((half_x_pt * 2 * order + 1, half_y_pt * 2 * order + 1))
1105+
kernel_[half_x_pt * order:] = kernel
1106+
kernel_[:half_x_pt * order] = kernel[:0:-1]
1107+
return kernel_
10401108

10411109
def convolve_filter_with_dynamic_kernel(self, grid_name, kernel_func, lat_max=85, **kwargs_func):
10421110
logging.warning('No filtering above %f degrees of latitude', lat_max)
@@ -1073,6 +1141,7 @@ def convolve_filter_with_dynamic_kernel(self, grid_name, kernel_func, lat_max=85
10731141
tmp_matrix[~m] = 0
10741142

10751143
demi_x, demi_y = k_shape[0] // 2, k_shape[1] // 2
1144+
# custom_(tmp_matrix, m.astype('f8'), kernel)
10761145
values_sum = filter2D(tmp_matrix, -1, kernel)[demi_x:-demi_x, demi_y]
10771146
kernel_sum = filter2D(m.astype(float), -1, kernel)[demi_x:-demi_x, demi_y]
10781147
with errstate(invalid='ignore'):
@@ -1232,7 +1301,7 @@ def add_uv(self, grid_height):
12321301
self.vars['v'] = d_hx / d_x * gof
12331302

12341303
def speed_coef(self, contour):
1235-
lon, lat = contour.regular_coordinates[1:].T
1304+
lon, lat = uniform_resample_stack(contour.vertices)[1:].T
12361305
return self._speed_ev(lon, lat)
12371306

12381307
def init_speed_coef(self, uname='u', vname='v'):
@@ -1273,6 +1342,25 @@ def interp(self, grid_name, lons, lats):
12731342
return z
12741343

12751344

1345+
# @njit(cache=True, fastmath=True, parallel=True)
1346+
@njit(cache=True, fastmath=True)
1347+
def custom_(data, mask, kernel):
1348+
"""do sortin at high lattitude big part of value are masked"""
1349+
nb_x = kernel.shape[0]
1350+
demi_x = int((nb_x - 1) / 2)
1351+
demi_y = int((kernel.shape[1] - 1) / 2)
1352+
out = empty(data.shape[0] - nb_x + 1)
1353+
for i in prange(out.shape[0]):
1354+
if mask[i + demi_x, demi_y] != 0:
1355+
continue
1356+
p = (mask[i:i + nb_x] * kernel).sum()
1357+
if p != 0:
1358+
out[i] = (data[i:i + nb_x] * kernel).sum() / p
1359+
else:
1360+
out[i] = nan
1361+
return out
1362+
1363+
12761364
@njit(parralel=True, cache=True)
12771365
def interp_numba(x_g, y_g, z, x, y, dest_z, fill_value):
12781366
x_ref = x_g[0]

src/py_eddy_tracker/tools.pyx

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,101 @@ ctypedef double DTYPE_coord
1212

1313

1414
cdef DTYPE_coord D2R = 0.017453292519943295
15+
cdef DTYPE_coord PI = 3.141592653589793
1516
cdef DTYPE_coord EARTH_DIAMETER = 6371.3150 * 2
1617

1718

19+
@wraparound(False)
20+
@boundscheck(False)
21+
def fit_circle_c(
22+
ndarray[DTYPE_coord] x_vec,
23+
ndarray[DTYPE_coord] y_vec
24+
):
25+
"""
26+
Fit the circle
27+
Adapted from ETRACK (KCCMC11)
28+
"""
29+
cdef DTYPE_ui i_elt, i_start, i_end, nb_elt
30+
cdef DTYPE_coord x_mean, y_mean, scale, norme_max, center_x, center_y, radius
31+
cdef DTYPE_coord p_area, c_area, a_err, p_area_incirc, dist_poly
32+
nb_elt = x_vec.shape[0]
33+
34+
cdef DTYPE_coord * p_inon_x = <DTYPE_coord * >malloc(nb_elt * sizeof(DTYPE_coord))
35+
if not p_inon_x:
36+
raise MemoryError()
37+
cdef DTYPE_coord * p_inon_y = <DTYPE_coord * >malloc(nb_elt * sizeof(DTYPE_coord))
38+
if not p_inon_y:
39+
raise MemoryError()
40+
41+
x_mean = 0
42+
y_mean = 0
43+
44+
for i_elt from 0 <= i_elt < nb_elt:
45+
x_mean += x_vec[i_elt]
46+
y_mean += y_vec[i_elt]
47+
y_mean /= nb_elt
48+
x_mean /= nb_elt
49+
50+
norme = (x_vec - x_mean) ** 2 + (y_vec - y_mean) ** 2
51+
norme_max = norme.max()
52+
scale = norme_max ** .5
53+
54+
# Form matrix equation and solve it
55+
# Maybe put f4
56+
datas = ones((nb_elt, 3), dtype='f8')
57+
for i_elt from 0 <= i_elt < nb_elt:
58+
datas[i_elt, 0] = 2. * (x_vec[i_elt] - x_mean) / scale
59+
datas[i_elt, 1] = 2. * (y_vec[i_elt] - y_mean) / scale
60+
61+
(center_x, center_y, radius), _, _, _ = lstsq(datas, norme / norme_max, rcond=None)
62+
63+
# Unscale data and get circle variables
64+
radius += center_x ** 2 + center_y ** 2
65+
radius **= .5
66+
center_x *= scale
67+
center_y *= scale
68+
# radius of fitted circle
69+
radius *= scale
70+
# center X-position of fitted circle
71+
center_x += x_mean
72+
# center Y-position of fitted circle
73+
center_y += y_mean
74+
75+
# area of fitted circle
76+
c_area = (radius ** 2) * PI
77+
78+
# Find distance between circle center and contour points_inside_poly
79+
for i_elt from 0 <= i_elt < nb_elt:
80+
# Find distance between circle center and contour points_inside_poly
81+
dist_poly = ((x_vec[i_elt] - center_x) ** 2 + (y_vec[i_elt] - center_y) ** 2) ** .5
82+
# Indices of polygon points outside circle
83+
# p_inon_? : polygon x or y points inside & on the circle
84+
if dist_poly > radius:
85+
p_inon_y[i_elt] = center_y + radius * (y_vec[i_elt] - center_y) / dist_poly
86+
p_inon_x[i_elt] = center_x - (center_x - x_vec[i_elt]) * (center_y - p_inon_y[i_elt]) / (center_y - y_vec[i_elt])
87+
else:
88+
p_inon_x[i_elt] = x_vec[i_elt]
89+
p_inon_y[i_elt] = y_vec[i_elt]
90+
91+
# Area of closed contour/polygon enclosed by the circle
92+
p_area_incirc = 0
93+
p_area = 0
94+
for i_elt from 0 <= i_elt < (nb_elt - 1):
95+
# Indices of polygon points outside circle
96+
# p_inon_? : polygon x or y points inside & on the circle
97+
p_area_incirc += p_inon_x[i_elt] * p_inon_y[1 + i_elt] - p_inon_x[i_elt + 1] * p_inon_y[i_elt]
98+
# Shape test
99+
# Area and centroid of closed contour/polygon
100+
p_area += x_vec[i_elt] * y_vec[1 + i_elt] - x_vec[1 + i_elt] * y_vec[i_elt]
101+
p_area = abs(p_area) * .5
102+
free(p_inon_x)
103+
free(p_inon_y)
104+
p_area_incirc = abs(p_area_incirc) * .5
105+
106+
a_err = (c_area - 2 * p_area_incirc + p_area) * 100. / c_area
107+
return center_x, center_y, radius, a_err
108+
109+
18110
@wraparound(False)
19111
@boundscheck(False)
20112
cdef is_left(

0 commit comments

Comments
 (0)