Skip to content

Commit ee44c4a

Browse files
committed
add ugly method to filter unregular
1 parent 17327f4 commit ee44c4a

File tree

1 file changed

+77
-59
lines changed

1 file changed

+77
-59
lines changed

src/py_eddy_tracker/dataset/grid.py

Lines changed: 77 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,9 @@ class GridDataset(object):
304304
# indice margin (if put to 0, raise warning that i don't understand)
305305
N = 1
306306

307-
def __init__(self, filename, x_name, y_name, centered=None, indexs=None):
307+
def __init__(
308+
self, filename, x_name, y_name, centered=None, indexs=None, unset=False
309+
):
308310
self.dimensions = None
309311
self.variables_description = None
310312
self.global_attrs = None
@@ -327,8 +329,9 @@ def __init__(self, filename, x_name, y_name, centered=None, indexs=None):
327329
logger.warning(
328330
"We assume the position of grid is the center corner for %s", filename,
329331
)
330-
self.load_general_features()
331-
self.load()
332+
if not unset:
333+
self.load_general_features()
334+
self.load()
332335

333336
@property
334337
def is_centered(self):
@@ -414,41 +417,40 @@ def load(self):
414417
self.vars[x_name] = h.variables[x_name][sl_x]
415418
self.vars[y_name] = h.variables[y_name][sl_y]
416419

417-
if self.is_centered:
418-
logger.info("Grid center")
419-
self.x_c = self.vars[x_name]
420-
self.y_c = self.vars[y_name]
420+
self.setup_coordinates()
421+
self.init_pos_interpolator()
421422

422-
self.x_bounds = concatenate(
423-
(self.x_c, (2 * self.x_c[-1] - self.x_c[-2],))
424-
)
425-
self.y_bounds = concatenate(
426-
(self.y_c, (2 * self.y_c[-1] - self.y_c[-2],))
427-
)
428-
d_x = self.x_bounds[1:] - self.x_bounds[:-1]
429-
d_y = self.y_bounds[1:] - self.y_bounds[:-1]
430-
self.x_bounds[:-1] -= d_x / 2
431-
self.x_bounds[-1] -= d_x[-1] / 2
432-
self.y_bounds[:-1] -= d_y / 2
433-
self.y_bounds[-1] -= d_y[-1] / 2
423+
def setup_coordinates(self):
424+
x_name, y_name = self.coordinates
425+
if self.is_centered:
426+
logger.info("Grid center")
427+
self.x_c = self.vars[x_name]
428+
self.y_c = self.vars[y_name]
434429

435-
else:
436-
self.x_bounds = self.vars[x_name]
437-
self.y_bounds = self.vars[y_name]
430+
self.x_bounds = concatenate((self.x_c, (2 * self.x_c[-1] - self.x_c[-2],)))
431+
self.y_bounds = concatenate((self.y_c, (2 * self.y_c[-1] - self.y_c[-2],)))
432+
d_x = self.x_bounds[1:] - self.x_bounds[:-1]
433+
d_y = self.y_bounds[1:] - self.y_bounds[:-1]
434+
self.x_bounds[:-1] -= d_x / 2
435+
self.x_bounds[-1] -= d_x[-1] / 2
436+
self.y_bounds[:-1] -= d_y / 2
437+
self.y_bounds[-1] -= d_y[-1] / 2
438438

439-
if len(self.x_dim) == 1:
440-
self.x_c = self.x_bounds.copy()
441-
dx2 = (self.x_bounds[1:] - self.x_bounds[:-1]) / 2
442-
self.x_c[:-1] += dx2
443-
self.x_c[-1] += dx2[-1]
444-
self.y_c = self.y_bounds.copy()
445-
dy2 = (self.y_bounds[1:] - self.y_bounds[:-1]) / 2
446-
self.y_c[:-1] += dy2
447-
self.y_c[-1] += dy2[-1]
448-
else:
449-
raise Exception("not write")
439+
else:
440+
self.x_bounds = self.vars[x_name]
441+
self.y_bounds = self.vars[y_name]
450442

451-
self.init_pos_interpolator()
443+
if len(self.x_dim) == 1:
444+
self.x_c = self.x_bounds.copy()
445+
dx2 = (self.x_bounds[1:] - self.x_bounds[:-1]) / 2
446+
self.x_c[:-1] += dx2
447+
self.x_c[-1] += dx2[-1]
448+
self.y_c = self.y_bounds.copy()
449+
dy2 = (self.y_bounds[1:] - self.y_bounds[:-1]) / 2
450+
self.y_c[:-1] += dy2
451+
self.y_c[-1] += dy2[-1]
452+
else:
453+
raise Exception("not write")
452454

453455
def is_circular(self):
454456
"""Check grid circularity
@@ -1084,44 +1086,47 @@ def init_pos_interpolator(self):
10841086

10851087
logger.debug("... OK")
10861088

1087-
def _low_filter(self, grid_name, x_cut, y_cut, factor=40.):
1089+
def _low_filter(self, grid_name, w_cut, factor=8.0):
10881090
data = self.grid(grid_name)
10891091
mean_data = data.mean()
10901092
x = self.grid(self.coordinates[0])
10911093
y = self.grid(self.coordinates[1])
1092-
regrid_x_step = x_cut / factor
1093-
regrid_y_step = y_cut / factor
1094+
regrid_step = w_cut / 111.0 / factor
10941095
x_min, x_max, y_min, y_max = self.bounds
1095-
x_array = arange(x_min, x_max + regrid_x_step, regrid_x_step)
1096-
y_array = arange(y_min, y_max + regrid_y_step, regrid_y_step)
1096+
x_array = arange(x_min, x_max + regrid_step, regrid_step)
1097+
y_array = arange(y_min, min(y_max + regrid_step, 89), regrid_step)
10971098
bins = (x_array, y_array)
10981099

10991100
x_flat, y_flat, z_flat = x.reshape((-1,)), y.reshape((-1,)), data.reshape((-1,))
1100-
m = -z_flat.mask
1101+
m = ~z_flat.mask
11011102
x_flat, y_flat, z_flat = x_flat[m], y_flat[m], z_flat[m]
11021103

1103-
nb_value, bounds_x, bounds_y = histogram2d(
1104-
x_flat, y_flat,
1105-
bins=bins)
1104+
nb_value, _, _ = histogram2d(x_flat, y_flat, bins=bins)
11061105

1107-
sum_value, _, _ = histogram2d(
1108-
x_flat, y_flat,
1109-
bins=bins,
1110-
weights=z_flat)
1106+
sum_value, _, _ = histogram2d(x_flat, y_flat, bins=bins, weights=z_flat)
11111107

1112-
with errstate(invalid='ignore'):
1108+
with errstate(invalid="ignore"):
11131109
z_grid = ma.array(sum_value / nb_value, mask=nb_value == 0)
1114-
i_x, i_y = x_cut * 0.125 / regrid_x_step, y_cut * 0.125 / regrid_y_step
1115-
m = nb_value == 0
1116-
1117-
z_filtered = self._gaussian_filter(z_grid, (i_x, i_y))
1118-
1119-
z_filtered[m] = 0
1120-
x_center = (bounds_x[:-1] + bounds_x[1:]) / 2
1121-
y_center = (bounds_y[:-1] + bounds_y[1:]) / 2
1110+
regular_grid = RegularGridDataset.with_array(
1111+
coordinates=self.coordinates,
1112+
datas={
1113+
grid_name: z_grid,
1114+
self.coordinates[0]: x_array[:-1],
1115+
self.coordinates[1]: y_array[:-1],
1116+
},
1117+
)
1118+
regular_grid.bessel_low_filter(grid_name, w_cut, order=1)
1119+
z_filtered = regular_grid.grid(grid_name)
1120+
x_center = (x_array[:-1] + x_array[1:]) / 2
1121+
y_center = (y_array[:-1] + y_array[1:]) / 2
11221122
opts_interpolation = dict(kx=1, ky=1, s=0)
1123-
m_interp = RectBivariateSpline(x_center, y_center, m, **opts_interpolation)
1124-
z_interp = RectBivariateSpline(x_center, y_center, z_filtered, **opts_interpolation).ev(x, y)
1123+
m_interp = RectBivariateSpline(
1124+
x_center, y_center, z_filtered.mask, **opts_interpolation
1125+
)
1126+
z_filtered.data[z_filtered.mask] = 0
1127+
z_interp = RectBivariateSpline(
1128+
x_center, y_center, z_filtered.data, **opts_interpolation
1129+
).ev(x, y)
11251130
return ma.array(z_interp, mask=m_interp.ev(x, y) > 0.00001)
11261131

11271132
def speed_coef_mean(self, contour):
@@ -1152,10 +1157,23 @@ class RegularGridDataset(GridDataset):
11521157
def __init__(self, *args, **kwargs):
11531158
super(RegularGridDataset, self).__init__(*args, **kwargs)
11541159
self._is_circular = None
1160+
1161+
def setup_coordinates(self):
1162+
super(RegularGridDataset, self).setup_coordinates()
11551163
self.x_size = self.x_c.shape[0]
11561164
self._x_step = (self.x_c[1:] - self.x_c[:-1]).mean()
11571165
self._y_step = (self.y_c[1:] - self.y_c[:-1]).mean()
11581166

1167+
@classmethod
1168+
def with_array(cls, coordinates, datas):
1169+
x_name, y_name = coordinates[0], coordinates[1]
1170+
obj = cls("array", x_name, y_name, unset=True, centered=False)
1171+
obj.x_dim = (x_name,)
1172+
obj.y_dim = (y_name,)
1173+
for k, v in datas.items():
1174+
obj.vars[k] = v
1175+
obj.setup_coordinates()
1176+
return obj
11591177

11601178
def init_pos_interpolator(self):
11611179
"""Create function to have a quick index interpolator
@@ -1277,7 +1295,7 @@ def kernel_bessel(self, lat, wave_length, order=1):
12771295
min_wave_length = max(step_x_km * 2, step_y_km * 2)
12781296
if wave_length < min_wave_length:
12791297
logger.error(
1280-
"Wave_length to short for resolution, must be > %d km",
1298+
"Wave_length too short for resolution, must be > %d km",
12811299
ceil(min_wave_length),
12821300
)
12831301
raise Exception()

0 commit comments

Comments
 (0)