Skip to content

Commit c979b2f

Browse files
committed
Improve unregular grid
1 parent 5001165 commit c979b2f

File tree

3 files changed

+51
-15
lines changed

3 files changed

+51
-15
lines changed

src/py_eddy_tracker/dataset/grid.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def lat(self):
5555
BasePath.lat = lat
5656

5757

58+
@njit(cache=True)
59+
def prepare_for_kdtree(x_val, y_val):
60+
data = empty((x_val.shape[0], 2))
61+
data[:, 0] = x_val
62+
data[:, 1] = y_val
63+
return data
64+
65+
5866
@njit(cache=True)
5967
def uniform_resample_stack(vertices, num_fac=2, fixed_size=None):
6068
x_val, y_val = vertices[:, 0], vertices[:, 1]
@@ -196,6 +204,7 @@ class GridDataset(object):
196204
'coordinates',
197205
'filename',
198206
'dimensions',
207+
'indexs',
199208
'variables_description',
200209
'global_attrs',
201210
'vars',
@@ -209,7 +218,7 @@ class GridDataset(object):
209218
# EARTH_RADIUS = 6378136.3
210219
N = 1
211220

212-
def __init__(self, filename, x_name, y_name, centered=None):
221+
def __init__(self, filename, x_name, y_name, centered=None, indexs=None):
213222
self.dimensions = None
214223
self.variables_description = None
215224
self.global_attrs = None
@@ -226,6 +235,7 @@ def __init__(self, filename, x_name, y_name, centered=None):
226235
self.filename = filename
227236
self.coordinates = x_name, y_name
228237
self.vars = dict()
238+
self.indexs = None if indexs is None else indexs
229239
self.interpolators = dict()
230240
if centered is None:
231241
logger.warning('We assume the position of grid is the center'
@@ -312,8 +322,10 @@ def load(self):
312322
self.x_dim = h.variables[x_name].dimensions
313323
self.y_dim = h.variables[y_name].dimensions
314324

315-
self.vars[x_name] = h.variables[x_name][:]
316-
self.vars[y_name] = h.variables[y_name][:]
325+
sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
326+
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
327+
self.vars[x_name] = h.variables[x_name][sl_x]
328+
self.vars[y_name] = h.variables[y_name][sl_y]
317329

318330
if self.is_centered:
319331
logger.info('Grid center')
@@ -382,16 +394,18 @@ def copy(self, grid_in, grid_out):
382394
)
383395
self.vars[grid_out] = self.grid(grid_in).copy()
384396

385-
def grid(self, varname):
397+
def grid(self, varname, indexs=None):
386398
"""give grid required
387399
"""
400+
if indexs is None:
401+
indexs = dict()
388402
if varname not in self.vars:
389403
coordinates_dims = list(self.x_dim)
390404
coordinates_dims.extend(list(self.y_dim))
391405
logger.debug('Load %(varname)s from %(filename)s', dict(varname=varname, filename=self.filename))
392406
with Dataset(self.filename) as h:
393407
dims = h.variables[varname].dimensions
394-
sl = [slice(None) if dim in coordinates_dims else 0 for dim in dims]
408+
sl = [indexs.get(dim, self.indexs.get(dim, slice(None) if dim in coordinates_dims else 0)) for dim in dims]
395409
self.vars[varname] = h.variables[varname][sl]
396410
if len(self.x_dim) == 1:
397411
i_x = where(array(dims) == self.x_dim)[0][0]
@@ -483,6 +497,10 @@ def eddy_identification(self, grid_height, uname, vname, date, step=0.005, shape
483497

484498
# Get h grid
485499
data = self.grid(grid_height).astype('f8')
500+
# In case of a reduce mask
501+
if len(data.mask.shape) == 0 and not data.mask:
502+
data.mask = zeros(data.shape, dtype='bool')
503+
# we remove noisy information
486504
if precision is not None:
487505
data = (data / precision).round() * precision
488506
# Compute levels for ssh
@@ -753,6 +771,24 @@ class UnRegularGridDataset(GridDataset):
753771
'_speed_norm',
754772
)
755773

774+
def load(self):
775+
"""Load variable (data)
776+
"""
777+
x_name, y_name = self.coordinates
778+
with Dataset(self.filename) as h:
779+
self.x_dim = h.variables[x_name].dimensions
780+
self.y_dim = h.variables[y_name].dimensions
781+
782+
sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
783+
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
784+
self.vars[x_name] = h.variables[x_name][sl_x]
785+
self.vars[y_name] = h.variables[y_name][sl_y]
786+
787+
self.x_c = self.vars[x_name]
788+
self.y_c = self.vars[y_name]
789+
790+
self.init_pos_interpolator()
791+
756792
def bbox_indice(self, vertices):
757793
dist, idx = self.index_interp.query(vertices, k=1)
758794
i_y = idx % self.x_c.shape[1]
@@ -761,8 +797,9 @@ def bbox_indice(self, vertices):
761797

762798
def get_pixels_in(self, contour):
763799
(x_start, x_stop), (y_start, y_stop) = contour.bbox_slice
764-
pts = array((self.x_c[x_start:x_stop, y_start:x_stop].reshape(-1),
765-
self.y_c[x_start:y_stop, y_start:y_stop].reshape(-1))).T
800+
pts = array((self.x_c[x_start:x_stop, y_start:y_stop].reshape(-1),
801+
self.y_c[x_start:x_stop, y_start:y_stop].reshape(-1))).T
802+
x_stop = min(x_stop, self.x_c.shape[0])
766803
mask = contour.contains_points(pts).reshape((x_stop - x_start, -1))
767804
i_x, i_y = where(mask)
768805
i_x += x_start
@@ -785,10 +822,8 @@ def compute_pixel_path(self, x0, y0, x1, y1):
785822
def init_pos_interpolator(self):
786823
logger.debug('Create a KdTree could be long ...')
787824
self.index_interp = cKDTree(
788-
uniform_resample_stack((
789-
self.x_c.reshape(-1),
790-
self.y_c.reshape(-1)
791-
)))
825+
prepare_for_kdtree(self.x_c.reshape(-1), self.y_c.reshape(-1)))
826+
792827
logger.debug('... OK')
793828

794829
def _low_filter(self, grid_name, x_cut, y_cut, factor=40.):

src/py_eddy_tracker/eddy_feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def __init__(self, x, y, z, levels, wrap_x=False, keep_unclose=False):
352352
logger.debug('Y shape : %s', y.shape)
353353
logger.debug('Z shape : %s', z.shape)
354354
logger.info('Start computing iso lines with %d levels from %f to %f ...', len(levels), levels[0], levels[-1])
355-
self.contours = ax.contour(x, y, z.T, levels, cmap='rainbow')
355+
self.contours = ax.contour(x, y, z.T if z.shape != x.shape else z, levels, cmap='rainbow')
356356
if wrap_x:
357357
self.find_wrapcut_path_and_join(x[0], x[-1])
358358
logger.info('Finish computing iso lines')

src/scripts/EddyId

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Do identification
66
from datetime import datetime
77
from netCDF4 import Dataset
88
from py_eddy_tracker import EddyParser
9-
from py_eddy_tracker.dataset.grid import RegularGridDataset
9+
from py_eddy_tracker.dataset.grid import RegularGridDataset, UnRegularGridDataset
1010
import zarr
1111

1212
def id_parser():
@@ -30,6 +30,7 @@ def id_parser():
3030
help='Force height unit')
3131
parser.add_argument('--speed_unit', default=None, type=str,
3232
help='Force speed unit')
33+
parser.add_argument('--unregular', action='store_true', help='if grid is unregular')
3334
parser.add_argument('--zarr',
3435
action='store_true',
3536
help='Output will be wrote in zarr')
@@ -38,8 +39,8 @@ def id_parser():
3839

3940
if __name__ == '__main__':
4041
args = id_parser().parse_args()
41-
42-
h = RegularGridDataset(args.filename, args.longitude, args.latitude)
42+
grid_class = (UnRegularGridDataset if args.unregular else RegularGridDataset)
43+
h = grid_class(args.filename, args.longitude, args.latitude)
4344
date = datetime.strptime(args.datetime, '%Y%m%d')
4445
if args.u == 'None' and args.v == 'None':
4546
h.add_uv(args.h)

0 commit comments

Comments
 (0)