Skip to content

Commit 98161f3

Browse files
committed
Add dummy test on convolution => which detect an index error in original code(corrected)
1 parent a7ef56e commit 98161f3

File tree

4 files changed

+44
-33
lines changed

4 files changed

+44
-33
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
matplotlib
1+
matplotlib<3.5
22
netCDF4
33
numba>=0.53
44
numpy<1.21

src/py_eddy_tracker/dataset/grid.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def eddy_identification(
680680
)
681681
z_min, z_max = z_min_p, z_max_p
682682

683+
logger.debug("Levels from %f to %f", z_min, z_max)
683684
levels = arange(z_min - z_min % step, z_max - z_max % step + 2 * step, step)
684685

685686
# Get x and y values
@@ -1404,7 +1405,8 @@ def convolve_filter_with_dynamic_kernel(
14041405
tmp_matrix = ma.zeros((2 * d_lon + data.shape[0], k_shape[1]))
14051406
tmp_matrix.mask = ones(tmp_matrix.shape, dtype=bool)
14061407
# Slice to apply on input data
1407-
sl_lat_data = slice(max(0, i - d_lat), min(i + d_lat, data.shape[1]))
1408+
# +1 for upper bound, to take in acount this column
1409+
sl_lat_data = slice(max(0, i - d_lat), min(i + d_lat + 1, data.shape[1]))
14081410
# slice to apply on temporary matrix to store input data
14091411
sl_lat_in = slice(
14101412
d_lat - (i - sl_lat_data.start), d_lat + (sl_lat_data.stop - i)

src/py_eddy_tracker/observations/network.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import time
77
from glob import glob
88

9+
import netCDF4
10+
import zarr
911
from numba import njit
1012
from numpy import (
1113
arange,
@@ -23,9 +25,6 @@
2325
zeros,
2426
)
2527

26-
import netCDF4
27-
import zarr
28-
2928
from ..dataset.grid import GridCollection
3029
from ..generic import build_index, wrap_longitude
3130
from ..poly import bbox_intersection, vertice_overlap
@@ -680,13 +679,7 @@ def display_timeline(
680679
"""
681680
self.only_one_network()
682681
j = 0
683-
line_kw = dict(
684-
ls="-",
685-
marker="+",
686-
markersize=6,
687-
zorder=1,
688-
lw=3,
689-
)
682+
line_kw = dict(ls="-", marker="+", markersize=6, zorder=1, lw=3,)
690683
line_kw.update(kwargs)
691684
mappables = dict(lines=list())
692685

@@ -919,10 +912,7 @@ def event_map(self, ax, **kwargs):
919912
"""Add the merging and splitting events to a map"""
920913
j = 0
921914
mappables = dict()
922-
symbol_kw = dict(
923-
markersize=10,
924-
color="k",
925-
)
915+
symbol_kw = dict(markersize=10, color="k",)
926916
symbol_kw.update(kwargs)
927917
symbol_kw_split = symbol_kw.copy()
928918
symbol_kw_split["markersize"] += 4
@@ -951,13 +941,7 @@ def event_map(self, ax, **kwargs):
951941
return mappables
952942

953943
def scatter(
954-
self,
955-
ax,
956-
name="time",
957-
factor=1,
958-
ref=None,
959-
edgecolor_cycle=None,
960-
**kwargs,
944+
self, ax, name="time", factor=1, ref=None, edgecolor_cycle=None, **kwargs,
961945
):
962946
"""
963947
This function scatters the path of each network, with the merging and splitting events

tests/test_grid.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,13 @@
11
from matplotlib.path import Path
2-
from numpy import array, isnan, ma
2+
from numpy import arange, array, isnan, ma, nan, ones, zeros
33
from pytest import approx
44

55
from py_eddy_tracker.data import get_demo_path
66
from py_eddy_tracker.dataset.grid import RegularGridDataset
77

88
G = RegularGridDataset(get_demo_path("mask_1_60.nc"), "lon", "lat")
99
X = 0.025
10-
contour = Path(
11-
(
12-
(-X, 0),
13-
(X, 0),
14-
(X, X),
15-
(-X, X),
16-
(-X, 0),
17-
)
18-
)
10+
contour = Path(((-X, 0), (X, 0), (X, X), (-X, X), (-X, 0),))
1911

2012

2113
# contour
@@ -85,3 +77,36 @@ def test_interp():
8577
assert g.interp("z", x0, y0) == 1.5
8678
assert g.interp("z", x1, y1) == 2
8779
assert isnan(g.interp("z", x2, y2))
80+
81+
82+
def test_convolution():
83+
"""
84+
Add some dummy check on convolution filter
85+
"""
86+
# Fake grid
87+
z = ma.array(
88+
arange(12).reshape((-1, 1)) * arange(10).reshape((1, -1)),
89+
mask=zeros((12, 10), dtype="bool"),
90+
dtype="f4",
91+
)
92+
g = RegularGridDataset.with_array(
93+
coordinates=("x", "y"),
94+
datas=dict(z=z, x=arange(0, 6, 0.5), y=arange(0, 5, 0.5),),
95+
centered=True,
96+
)
97+
98+
def kernel_func(lat):
99+
return ones((3, 3))
100+
101+
# After transpose we must get same result
102+
d = g.convolve_filter_with_dynamic_kernel("z", kernel_func)
103+
assert (d.T[:9, :9] == d[:9, :9]).all()
104+
# We mask one value and check convolution result
105+
z.mask[2, 2] = True
106+
d = g.convolve_filter_with_dynamic_kernel("z", kernel_func)
107+
assert d[1, 1] == z[:3, :3].sum() / 8
108+
# Add nan and check only nearest value is contaminate
109+
z[2, 2] = nan
110+
d = g.convolve_filter_with_dynamic_kernel("z", kernel_func)
111+
assert not isnan(d[0, 0])
112+
assert isnan(d[1:4, 1:4]).all()

0 commit comments

Comments
 (0)