Skip to content

Commit 24dbe92

Browse files
committed
Use numba to compute stencil
1 parent d7656c3 commit 24dbe92

File tree

2 files changed

+148
-72
lines changed

2 files changed

+148
-72
lines changed

doc/run_tracking.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Requirements
77

88
Before running tracking, you need to run identification on every time step of your study period.
99

10-
**Advice** : Before to run tracking, displaying some identification file allows one to learn a lot
10+
**Advice** : Before running tracking, displaying some identification files allows you to learn a lot
1111

1212
Default method
1313
**************

src/py_eddy_tracker/dataset/grid.py

Lines changed: 147 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
float_,
2424
floor,
2525
histogram2d,
26-
int8,
2726
int_,
2827
interp,
2928
isnan,
@@ -47,7 +46,7 @@
4746
)
4847
from pint import UnitRegistry
4948
from scipy.interpolate import RectBivariateSpline, interp1d
50-
from scipy.ndimage import convolve, gaussian_filter
49+
from scipy.ndimage import gaussian_filter
5150
from scipy.signal import welch
5251
from scipy.spatial import cKDTree
5352
from scipy.special import j1
@@ -1677,77 +1676,19 @@ def compute_stencil(
16771676
16781677
...
16791678
1680-
16811679
"""
16821680
stencil_halfwidth = max(min(int(stencil_halfwidth), 4), 1)
16831681
logger.debug("Stencil half width apply : %d", stencil_halfwidth)
1684-
# output
1685-
grad = None
1686-
1687-
weights = [
1688-
array((3, -32, 168, -672, 0, 672, -168, 32, -3)) / 840.0,
1689-
array((-1, 9, -45, 0, 45, -9, 1)) / 60.0,
1690-
array((1, -8, 0, 8, -1)) / 12.0,
1691-
array((-1, 0, 1)) / 2.0,
1692-
# uncentered kernel
1693-
# like array((0, -1, 1)) but left value could be default value
1694-
array((-1, 1)),
1695-
# like array((-1, 1, 0)) but right value could be default value
1696-
(1, array((-1, 1))),
1697-
]
1698-
# reduce to stencil selected
1699-
weights = weights[4 - stencil_halfwidth :]
1700-
if vertical:
1701-
data = data.T
1702-
# Iteration from larger stencil to smaller (to fill matrix)
1703-
for weight in weights:
1704-
if isinstance(weight, tuple):
1705-
# In the case of unbalanced diff
1706-
shift, weight = weight
1707-
data_ = data.copy()
1708-
data_[shift:] = data[:-shift]
1709-
if not vertical:
1710-
data_[:shift] = data[-shift:]
1711-
else:
1712-
data_ = data
1713-
# Delta h
1714-
d_h = convolve(data_, weights=weight.reshape((-1, 1)), mode=mode)
1715-
mask = convolve(
1716-
int8(data_.mask), weights=ones(weight.shape).reshape((-1, 1)), mode=mode
1717-
)
1718-
d_h = ma.array(d_h, mask=mask != 0)
1719-
1720-
# Delta d
1721-
if vertical:
1722-
d_h = d_h.T
1723-
d = self.EARTH_RADIUS * 2 * pi / 360 * convolve(self.y_c, weight)
1724-
else:
1725-
if mode == "wrap":
1726-
# Along x axis, we need to close
1727-
# we will compute in two part
1728-
x = self.x_c % 360
1729-
d_degrees = convolve(x, weight, mode=mode)
1730-
d_degrees_180 = convolve((x + 180) % 360 - 180, weight, mode=mode)
1731-
# Arbitrary, to be sure to be far far away of bound
1732-
m = (x < 90) + (x > 270)
1733-
d_degrees[m] = d_degrees_180[m]
1734-
else:
1735-
d_degrees = convolve(self.x_c, weight, mode=mode)
1736-
d = (
1737-
self.EARTH_RADIUS
1738-
* 2
1739-
* pi
1740-
/ 360
1741-
* d_degrees.reshape((-1, 1))
1742-
* cos(deg2rad(self.y_c))
1743-
)
1744-
if grad is None:
1745-
# First Gradient
1746-
grad = d_h / d
1747-
else:
1748-
# Fill hole
1749-
grad[grad.mask] = (d_h / d)[grad.mask]
1750-
return grad
1682+
g, m = compute_stencil(
1683+
self.x_c,
1684+
self.y_c,
1685+
data.data,
1686+
data.mask,
1687+
self.EARTH_RADIUS,
1688+
vertical=vertical,
1689+
stencil_halfwidth=stencil_halfwidth,
1690+
)
1691+
return ma.array(g, mask=m)
17511692

17521693
def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15):
17531694
self.add_uv(grid_height, uname, vname)
@@ -1804,13 +1745,26 @@ def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15):
18041745
self.vars[uname][:, sl] = self.vars[uname][:, sl] * w + u_lagerloef * (1 - w)
18051746

18061747
def add_uv(self, grid_height, uname="u", vname="v", stencil_halfwidth=4):
1807-
"""Compute a u and v grid
1748+
r"""Compute a u and v grid
18081749
18091750
:param str grid_height: grid name where the function will apply stencil method
18101751
:param str uname: future name of u
18111752
:param str vname: future name of v
18121753
:param int stencil_halfwidth: largest stencil that can be applied (max: 4)
18131754
1755+
.. math::
1756+
u = \frac{g}{f} \frac{dh}{dy}
1757+
1758+
v = -\frac{g}{f} \frac{dh}{dx}
1759+
1760+
where
1761+
1762+
.. math::
1763+
g = gravity
1764+
1765+
f = 2 \Omega \sin(\phi)
1766+
1767+
18141768
.. minigallery:: py_eddy_tracker.RegularGridDataset.add_uv
18151769
"""
18161770
logger.info("Add u/v variable with stencil method")
@@ -2665,3 +2619,125 @@ def advect_t_rk4(
26652619
x_ += dx
26662620
y_ += dy
26672621
x[i], y[i] = x_, y_
2622+
2623+
2624+
@njit(
    [
        "Tuple((f8[:,:],b1[:,:]))(f8[:],f8[:],f8[:,:],b1[:,:],f8,b1,i1)",
        "Tuple((f4[:,:],b1[:,:]))(f8[:],f8[:],f4[:,:],b1[:,:],f8,b1,i1)",
    ],
    cache=True,
    fastmath=True,
)
def compute_stencil(x, y, h, m, earth_radius, vertical=False, stencil_halfwidth=4):
    """
    Compute stencil on RegularGrid

    First derivative of ``h`` along x (or along y when ``vertical`` is True),
    computed with the widest centered kernel allowed by ``stencil_halfwidth``
    whose points are all valid. Near masked cells or near the border of a
    non circular axis the kernel shrinks, down to an uncentered two points
    difference. A cell is masked in output when it is masked in input or
    when none of its two direct neighbours is valid.

    :param array x: longitude coordinates (degrees, regular step)
    :param array y: latitude coordinates (degrees, regular step)
    :param array h: 2D array to derivate, indexed as h[i_x, i_y]
    :param array m: mask associate to h to know where are invalid data
        (True means invalid), same shape as h
    :param float earth_radius: Earth radius in m
    :param bool vertical: if True stencil will be vertical (along y)
    :param int stencil_halfwidth: from 1 to 4 to specify maximal kernel usable
    :return: derivative of h (per unit of earth_radius) and mask of the
        derivative, both with the shape of h
    :rtype: (array, array)

    .. note::
        The axis along which the derivative is computed must hold at least
        4 points: the rolling buffer is initialized with h[0..3] without
        bounds checking.

    stencil_halfwidth (coefficients apply on h[i - n] ... h[i] ... h[i + n]):

    - (1) :

        - (-1, 1, 0)
        - (0, -1, 1)
        - (-1, 0, 1) / 2

    - (2) : (1, -8, 0, 8, -1) / 12
    - (3) : (-1, 9, -45, 0, 45, -9, 1) / 60
    - (4) : (3, -32, 168, -672, 0, 672, -168, 32, -3) / 840
    """
    if vertical:
        # If vertical we transpose matrix and inverse coordinates
        h = h.T
        m = m.T
        x, y = y, x
    shape = h.shape
    nb_x, nb_y = shape
    # Out array
    m_out = empty(shape, dtype=numba_types.bool_)
    grad = empty(shape, dtype=h.dtype)
    # Distance step in degrees
    d_step = x[1] - x[0]
    if vertical:
        is_circular = False
    else:
        # Test if matrix is circular: last longitude must be one step
        # before the first one (modulo 360)
        is_circular = abs(x[-1] % 360 - (x[0] - d_step) % 360) < 1e-5

    # Compute caracteristic distance, constant when vertical
    d_ = 360 / (d_step * pi * 2 * earth_radius)
    for j in range(nb_y):
        # Buffer of maximal size of stencil (9), it holds values and masks
        # from i - 4 to i + 4. It is initialized one step before i = 0.
        if is_circular:
            h_3, h_2, h_1, h0 = h[-4, j], h[-3, j], h[-2, j], h[-1, j]
            m_3, m_2, m_1, m0 = m[-4, j], m[-3, j], m[-2, j], m[-1, j]
        else:
            # Points outside of the grid are flagged as invalid, this way
            # the kernel is reduced near the border
            m_3, m_2, m_1, m0 = True, True, True, True
            # Placeholder values, never used because they are masked
            h_3, h_2, h_1, h0 = h[0, j], h[0, j], h[0, j], h[0, j]
        h1, h2, h3, h4 = h[0, j], h[1, j], h[2, j], h[3, j]
        m1, m2, m3, m4 = m[0, j], m[1, j], m[2, j], m[3, j]
        for i in range(nb_x):
            # Roll value and only last
            h_4, h_3, h_2, h_1, h0, h1, h2, h3 = h_3, h_2, h_1, h0, h1, h2, h3, h4
            m_4, m_3, m_2, m_1, m0, m1, m2, m3 = m_3, m_2, m_1, m0, m1, m2, m3, m4
            i_ = i + 4
            if i_ >= nb_x:
                if is_circular:
                    i_ = i_ % nb_x
                    m4 = m[i_, j]
                    h4 = h[i_, j]
                else:
                    # When we are out, the point is invalid (h4 keeps its
                    # previous value which will never be used)
                    m4 = True
            else:
                m4 = m[i_, j]
                h4 = h[i_, j]

            # Current value not defined
            if m0:
                m_out[i, j] = True
                continue
            if not vertical:
                # For each row we compute distance
                d_ = 360 / (d_step * cos(deg2rad(y[j])) * pi * 2 * earth_radius)
            if m_1 and m1:
                # No valid direct neighbour, derivative can't be computed
                m_out[i, j] = True
                continue
            if m_1:
                # unbalanced kernel, only next value is valid
                grad[i, j] = (h1 - h0) * d_
                m_out[i, j] = False
                continue
            if m1:
                # unbalanced kernel, only previous value is valid
                grad[i, j] = (h0 - h_1) * d_
                m_out[i, j] = False
                continue
            if m2 or m_2 or stencil_halfwidth == 1:
                grad[i, j] = (h1 - h_1) / 2 * d_
                m_out[i, j] = False
                continue
            if m3 or m_3 or stencil_halfwidth == 2:
                grad[i, j] = (h_2 - h2 + 8 * (h1 - h_1)) / 12 * d_
                m_out[i, j] = False
                continue
            if m4 or m_4 or stencil_halfwidth == 3:
                grad[i, j] = (h3 - h_3 + 9 * (h_2 - h2) + 45 * (h1 - h_1)) / 60 * d_
                m_out[i, j] = False
                continue
            # If all value of buffer are available
            grad[i, j] = (
                (3 * (h_4 - h4) + 32 * (h3 - h_3) + 168 * (h_2 - h2) + 672 * (h1 - h_1))
                / 840
                * d_
            )
            m_out[i, j] = False
    if vertical:
        return grad.T, m_out.T
    else:
        return grad, m_out

0 commit comments

Comments
 (0)