Skip to content

Commit a3ba91e

Browse files
committed
Add function to know if positions are in eddies
1 parent e761210 commit a3ba91e

File tree

2 files changed

+66
-21
lines changed

2 files changed

+66
-21
lines changed

src/py_eddy_tracker/generic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ def simplify(x, y, precision=0.1):
290290
mask = ones(nb, dtype=bool_)
291291
for i in range(1, nb):
292292
x_, y_ = x[i], y[i]
293+
if isnan(x_) or isnan(y_):
294+
continue
293295
d_x = x_ - x_previous
294296
if d_x > precision:
295297
x_previous, y_previous = x_, y_

src/py_eddy_tracker/observations/observation.py

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
create_vertice,
8282
close_center,
8383
get_pixel_in_regular,
84+
winding_number_poly,
8485
)
8586

8687
logger = logging.getLogger("pet")
@@ -1497,14 +1498,14 @@ def scatter(self, ax, name=None, ref=None, factor=1, **kwargs):
14971498
if ref is not None:
14981499
x = (x - ref) % 360 + ref
14991500
kwargs = kwargs.copy()
1500-
if name is not None and 'c' not in kwargs:
1501-
kwargs['c'] = self[name] * factor
1501+
if name is not None and "c" not in kwargs:
1502+
kwargs["c"] = self[name] * factor
15021503
return ax.scatter(x, self.latitude, **kwargs)
15031504

15041505
def filled(
15051506
self,
15061507
ax,
1507-
varname,
1508+
varname=None,
15081509
ref=None,
15091510
intern=False,
15101511
cmap="magma_r",
@@ -1516,7 +1517,7 @@ def filled(
15161517
):
15171518
"""
15181519
:param matplotlib.axes.Axes ax: matplotlib axes use to draw
1519-
:param str,array varname: var which will be use to fill contour, or an array of same size of obs
1520+
:param str,array varname, None: var which will be use to fill contour, or an array of same size of obs
15201521
:param float,None ref: if define use like west bound
15211522
:param bool intern: if True draw speed contour instead of effective contour
15221523
:param str cmap: matplotlib colormap name
@@ -1529,26 +1530,28 @@ def filled(
15291530
15301531
.. minigallery:: py_eddy_tracker.EddiesObservations.filled
15311532
"""
1532-
cmap = get_cmap(cmap, lut)
15331533
x_name, y_name = self.intern(intern)
1534-
v = (self[varname] if isinstance(varname, str) else varname) * factor
15351534
x, y = self[x_name], self[y_name]
15361535
if ref is not None:
15371536
# TODO : maybe buggy with global display
15381537
shape_out = x.shape
15391538
x, y = wrap_longitude(x.reshape(-1), y.reshape(-1), ref)
15401539
x, y = x.reshape(shape_out), y.reshape(shape_out)
1541-
if vmin is None:
1542-
vmin = v.min()
1543-
if vmax is None:
1544-
vmax = v.max()
1545-
v = (v - vmin) / (vmax - vmin)
15461540
verts = list()
1547-
colors = list()
1548-
for x_, y_, v_ in zip(x, y, v):
1541+
for x_, y_ in zip(x, y):
15491542
verts.append(create_vertice(x_, y_))
1550-
colors.append(cmap(v_))
1551-
c = PolyCollection(verts, facecolors=colors, **kwargs)
1543+
if "facecolors" not in kwargs:
1544+
kwargs = kwargs.copy()
1545+
cmap = get_cmap(cmap, lut)
1546+
v = (self[varname] if isinstance(varname, str) else varname) * factor
1547+
if vmin is None:
1548+
vmin = v.min()
1549+
if vmax is None:
1550+
vmax = v.max()
1551+
v = (v - vmin) / (vmax - vmin)
1552+
colors = [cmap(v_) for v_ in v]
1553+
kwargs["facecolors"] = colors
1554+
c = PolyCollection(verts, **kwargs)
15521555
ax.add_collection(c)
15531556
c.cmap = cmap
15541557
c.norm = Normalize(vmin=vmin, vmax=vmax)
@@ -1627,6 +1630,19 @@ def last_obs(self):
16271630
m[:-1][self["n"][1:] == 0] = True
16281631
return self.extract_with_mask(m)
16291632

1633+
def inside(self, x, y, intern=False):
1634+
"""
1635+
True for each postion inside an eddy
1636+
1637+
:param array x: longitude
1638+
:param array y: latitude
1639+
:param bool intern: If true use speed contour instead of effective contour
1640+
:return: flag
1641+
:rtype: array[bool]
1642+
"""
1643+
xname, yname = self.intern(intern)
1644+
return insidepoly(x, y, self[xname], self[yname])
1645+
16301646
def grid_count(self, bins, intern=False, center=False):
16311647
"""
16321648
Compute count of eddies in each bin (use of all pixel in each contour)
@@ -1720,14 +1736,14 @@ def interp_grid(
17201736
"""
17211737
if method == "center":
17221738
return grid_object.interp(varname, self.longitude, self.latitude)
1723-
elif method in ("min", "max", "mean", 'count'):
1739+
elif method in ("min", "max", "mean", "count"):
17241740
x0 = grid_object.x_bounds[0]
17251741
x_name, y_name = self.intern(False if intern is None else intern)
17261742
x_ref = ((self.longitude - x0) % 360 + x0 - 180).reshape(-1, 1)
17271743
x, y = (self[x_name] - x_ref) % 360 + x_ref, self[y_name]
17281744
grid = grid_object.grid(varname)
17291745
result = empty(self.shape, dtype=grid.dtype if dtype is None else dtype)
1730-
min_method = method == 'min'
1746+
min_method = method == "min"
17311747
grid_stat(
17321748
grid_object.x_c,
17331749
grid_object.y_c,
@@ -1736,7 +1752,7 @@ def interp_grid(
17361752
y,
17371753
result,
17381754
grid_object.is_circular(),
1739-
method='max' if min_method else method
1755+
method="max" if min_method else method,
17401756
)
17411757
return -result if min_method else result
17421758
else:
@@ -1760,7 +1776,34 @@ def grid_count_(grid, i, j):
17601776

17611777

17621778
@njit(cache=True)
1763-
def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method='mean'):
1779+
def insidepoly(x_p, y_p, x_c, y_c):
1780+
"""
1781+
True for each postion inside an contour
1782+
1783+
:param array x_p: longitude to test
1784+
:param array y_p: latitude to test
1785+
:param array x_c: longitude of contours
1786+
:param array y_c: latitude of contours
1787+
"""
1788+
nb_p = x_p.shape[0]
1789+
nb_c = x_c.shape[0]
1790+
flag = zeros(nb_p, dtype=numba_types.bool_)
1791+
for i in range(nb_c):
1792+
x_c_min, y_c_min = x_c[i].min(), y_c[i].min()
1793+
x_c_max, y_c_max = x_c[i].max(), y_c[i].max()
1794+
v = create_vertice(x_c[i], y_c[i])
1795+
for j in range(nb_p):
1796+
x, y = x_p[j], y_p[j]
1797+
if flag[j]:
1798+
continue
1799+
if x > x_c_min and x < x_c_max and y > y_c_min and y < y_c_max:
1800+
if winding_number_poly(x, y, v) != 0:
1801+
flag[j] = True
1802+
return flag
1803+
1804+
1805+
@njit(cache=True)
1806+
def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method="mean"):
17641807
"""
17651808
Compute mean of grid for each contour
17661809
@@ -1777,8 +1820,8 @@ def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method='mean'):
17771820
xstep, ystep = x_c[1] - x_c[0], y_c[1] - y_c[0]
17781821
x0, y0 = x_c - xstep / 2.0, y_c - ystep / 2.0
17791822
nb_x = x_c.shape[0]
1780-
max_method = 'max' == method
1781-
mean_method = 'mean' == method
1823+
max_method = "max" == method
1824+
mean_method = "mean" == method
17821825
for elt in range(nb):
17831826
v = create_vertice(x[elt], y[elt],)
17841827
(x_start, x_stop), (y_start, y_stop) = bbox_indice_regular(

0 commit comments

Comments
 (0)