Skip to content

Commit aa8e774

Browse files
committed
advection function rk4 and euler
1 parent 9653441 commit aa8e774

File tree

1 file changed

+120
-6
lines changed

1 file changed

+120
-6
lines changed

src/py_eddy_tracker/dataset/grid.py

Lines changed: 120 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from cv2 import filter2D
99
from matplotlib.path import Path as BasePath
1010
from netCDF4 import Dataset
11-
from numba import njit
11+
from numba import njit, prange
1212
from numba import types as numba_types
1313
from numpy import (
1414
arange,
@@ -21,6 +21,7 @@
2121
errstate,
2222
exp,
2323
float_,
24+
floor,
2425
histogram2d,
2526
int8,
2627
int_,
@@ -37,6 +38,7 @@
3738
ones,
3839
percentile,
3940
pi,
41+
radians,
4042
round_,
4143
sin,
4244
sinc,
@@ -242,8 +244,6 @@ class GridDataset(object):
242244
"""
243245

244246
__slots__ = (
245-
"_x_var",
246-
"_y_var",
247247
"x_c",
248248
"y_c",
249249
"x_bounds",
@@ -258,8 +258,6 @@ class GridDataset(object):
258258
"variables_description",
259259
"global_attrs",
260260
"vars",
261-
"interpolators",
262-
"speed_coef",
263261
"contours",
264262
)
265263

@@ -295,7 +293,6 @@ def __init__(
295293
self.coordinates = x_name, y_name
296294
self.vars = dict()
297295
self.indexs = dict() if indexs is None else indexs
298-
self.interpolators = dict()
299296
if centered is None:
300297
logger.warning(
301298
"We assume pixel position of grid is center for %s", filename
@@ -1956,6 +1953,123 @@ def interp(self, grid_name, lons, lats, method="bilinear"):
19561953
self.x_c, self.y_c, g, m, lons, lats, nearest=method == "nearest"
19571954
)
19581955

1956+
def uv_for_advection(self, u_name, v_name, time_step=600, backward=False):
1957+
"""
1958+
Get U,V to be used in degrees with precomputed time step
1959+
1960+
:param str,array u_name: U field to advect obs
1961+
:param str,array v_name: V field to advect obs
1962+
:param int time_step: Number of second for each advection
1963+
"""
1964+
u = (self.grid(u_name) if isinstance(u_name, str) else u_name).copy()
1965+
v = (self.grid(v_name) if isinstance(v_name, str) else v_name).copy()
1966+
# N seconds / 1 degrees in m
1967+
coef = time_step * 180 / pi / self.EARTH_RADIUS
1968+
u *= coef / cos(radians(self.y_c))
1969+
v *= coef
1970+
if backward:
1971+
u = -u
1972+
v = -v
1973+
m = u.mask + v.mask
1974+
return u, v, m
1975+
1976+
@njit(cache=True)
1977+
def advect_rk4(x_g, y_g, u_g, v_g, m_g, x, y, nb_step):
1978+
# Grid coordinates
1979+
x_ref, y_ref = x_g[0], y_g[0]
1980+
x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref
1981+
# On each particule
1982+
for i in prange(x.size):
1983+
# If particule are not valid => continue
1984+
if isnan(x[i]) or isnan(y[i]):
1985+
continue
1986+
# Iterate on whole steps
1987+
for _ in range(nb_step):
1988+
# k1, slope at origin
1989+
u1, v1 = get_uv(x_ref, y_ref, x_step, y_step, u_g, v_g, m_g, x[i], y[i])
1990+
if isnan(u1) or isnan(v1):
1991+
x[i], y[i] = nan, nan
1992+
break
1993+
# k2, slope at middle with first guess position
1994+
x1, y1 = x[i] + u1*.5, y[i] + v1*.5
1995+
u2, v2 = get_uv(x_ref, y_ref, x_step, y_step, u_g, v_g, m_g, x1, y1)
1996+
if isnan(u2) or isnan(v2):
1997+
x[i], y[i] = nan, nan
1998+
break
1999+
# k3, slope at middle with update guess position
2000+
x2, y2 = x[i] + u2 * .5, y[i] + v2 * .5
2001+
u3, v3 = get_uv(x_ref, y_ref, x_step, y_step, u_g, v_g, m_g, x2, y2)
2002+
if isnan(u3) or isnan(v3):
2003+
x[i], y[i] = nan, nan
2004+
break
2005+
# k4, slope at end with update guess position
2006+
x3, y3 = x2 + u3, y2 + v3
2007+
u4, v4 = get_uv(x_ref, y_ref, x_step, y_step, u_g, v_g, m_g, x3, y3)
2008+
if isnan(u4) or isnan(v4):
2009+
x[i], y[i] = nan, nan
2010+
break
2011+
dx = (u1 + 2 * u2 + 2 * u3 + u4) / 6
2012+
dy = (v1 + 2 * v2 + 2 * v3 + v4) / 6
2013+
# # Compute new x,y
2014+
x[i] += dx
2015+
y[i] += dy
2016+
2017+
2018+
@njit(cache=True)
2019+
def get_uv(x0, y0, x_step, y_step, u,v, m, x,y):
2020+
i, j = (x - x0) / x_step, (y - y0) / y_step
2021+
i0, j0 = int(floor(i)), int(floor(j))
2022+
i1, j1 = i0 + 1, j0 + 1
2023+
if m[i0, j0] or m[i0, j1] or m[i1, j0] or m[i1, j1]:
2024+
return nan, nan
2025+
# Extract value for u and v
2026+
u00, u01, u10, u11 = u[i0, j0], u[i0, j1], u[i1, j0], u[i1, j1]
2027+
v00, v01, v10, v11 = v[i0, j0], v[i0, j1], v[i1, j0], v[i1, j1]
2028+
xd, yd = i - i0, j - j0
2029+
xd_i, yd_i = 1 - xd, 1 - yd
2030+
u = (u00 * xd_i + u10 * xd) * yd_i + (u01 * xd_i + u11 * xd) * yd
2031+
v = (v00 * xd_i + v10 * xd) * yd_i + (v01 * xd_i + v11 * xd) * yd
2032+
return u, v
2033+
2034+
@njit(cache=True)
2035+
def advect(x_g, y_g, u_g, v_g, m_g, x, y, nb_step):
2036+
# Grid coordinates
2037+
x_ref, y_ref = x_g[0], y_g[0]
2038+
x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref
2039+
# Indices which should be never exist
2040+
i0_old, j0_old = -100000, -100000
2041+
# On each particule
2042+
for i in prange(x.size):
2043+
# If particule are not valid => continue
2044+
if isnan(x[i]) or isnan(y[i]):
2045+
continue
2046+
# Iterate on whole steps
2047+
for _ in range(nb_step):
2048+
# Compute coordinates in indice referentiel
2049+
x_, y_ = (x[i] - x_ref) / x_step, (y[i] - y_ref) / y_step
2050+
# corner bottom left Indice
2051+
i0_, j0_ = int(floor(x_)), int(floor(y_))
2052+
i0, j0 = i0_, j0_
2053+
i1 = i0 + 1
2054+
# corner are the same need only a new xd and yd
2055+
if i0 != i0_old or j0 != j0_old:
2056+
j1 = j0 + 1
2057+
# If one of nearest pixel is invalid
2058+
if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]:
2059+
x[i], y[i] = nan, nan
2060+
break
2061+
# Extract value for u and v
2062+
u00, u01, u10, u11 = u_g[i0, j0], u_g[i0, j1], u_g[i1, j0], u_g[i1, j1]
2063+
v00, v01, v10, v11 = v_g[i0, j0], v_g[i0, j1], v_g[i1, j0], v_g[i1, j1]
2064+
# Need to be store only on change
2065+
i0_old, j0_old = i0, j0
2066+
# Compute distance
2067+
xd, yd = x_ - i0_, y_ - j0_
2068+
xd_i, yd_i = 1 - xd, 1 - yd
2069+
# Compute new x,y
2070+
x[i] += (u00 * xd_i + u10 * xd) * yd_i + (u01 * xd_i + u11 * xd) * yd
2071+
y[i] += (v00 * xd_i + v10 * xd) * yd_i + (v01 * xd_i + v11 * xd) * yd
2072+
19592073

19602074
@njit(cache=True, fastmath=True)
19612075
def compute_pixel_path(x0, y0, x1, y1, x_ori, y_ori, x_step, y_step, nb_x):

0 commit comments

Comments
 (0)