diff --git a/requirements.txt b/requirements.txt index 4c8af099..7f54f14d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -matplotlib +matplotlib==3.7.1 opencv-python pint polygon3 @@ -8,4 +8,4 @@ scipy zarr netCDF4 numpy -numba \ No newline at end of file +numba diff --git a/src/py_eddy_tracker/dataset/grid.py b/src/py_eddy_tracker/dataset/grid.py index edb96bac..4de6f599 100644 --- a/src/py_eddy_tracker/dataset/grid.py +++ b/src/py_eddy_tracker/dataset/grid.py @@ -253,6 +253,7 @@ class GridDataset(object): "filename", "dimensions", "indexs", + "nc4file", "variables_description", "global_attrs", "vars", @@ -275,6 +276,7 @@ def __init__( indexs=None, unset=False, nan_masking=False, + nc4file=None, ): """ :param str filename: Filename to load @@ -301,6 +303,7 @@ def __init__( self.coordinates = x_name, y_name self.vars = dict() self.indexs = dict() if indexs is None else indexs + self.nc4file = Dataset(filename, "r") if nc4file is None else nc4file if centered is None: logger.warning( "We assume pixel position of grid is centered for %s", filename @@ -344,25 +347,25 @@ def load_general_features(self): logger.debug( "Load general feature from %(filename)s", dict(filename=self.filename) ) - with Dataset(self.filename) as h: - # Load generals - self.dimensions = {i: len(v) for i, v in h.dimensions.items()} - self.variables_description = dict() - for i, v in h.variables.items(): - args = (i, v.datatype) - kwargs = dict(dimensions=v.dimensions, zlib=True) - if hasattr(v, "_FillValue"): - kwargs["fill_value"] = (v._FillValue,) - attrs = dict() - for attr in v.ncattrs(): - if attr in kwargs.keys(): - continue - if attr == "_FillValue": - continue - attrs[attr] = getattr(v, attr) - self.variables_description[i] = dict( - args=args, kwargs=kwargs, attrs=attrs, infos=dict() - ) + h = self.nc4file + # Load generals + self.dimensions = {i: len(v) for i, v in h.dimensions.items()} + self.variables_description = dict() + for i, v in h.variables.items(): + args = (i, v.datatype) + kwargs = dict(dimensions=v.dimensions, zlib=True) + if hasattr(v, "_FillValue"): + kwargs["fill_value"] = (v._FillValue,) + attrs = dict() + for attr in v.ncattrs(): + if attr in kwargs.keys(): + continue + if attr == "_FillValue": + continue + attrs[attr] = getattr(v, attr) + self.variables_description[i] = dict( + args=args, kwargs=kwargs, attrs=attrs, infos=dict() + ) self.global_attrs = {attr: getattr(h, attr) for attr in h.ncattrs()} def write(self, filename): @@ -407,14 +410,14 @@ def load(self): Get coordinates and setup coordinates function """ x_name, y_name = self.coordinates - with Dataset(self.filename) as h: - self.x_dim = h.variables[x_name].dimensions - self.y_dim = h.variables[y_name].dimensions + h = self.nc4file + self.x_dim = h.variables[x_name].dimensions + self.y_dim = h.variables[y_name].dimensions - sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim] - sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim] - self.vars[x_name] = h.variables[x_name][sl_x] - self.vars[y_name] = h.variables[y_name][sl_y] + sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim] + sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim] + self.vars[x_name] = h.variables[x_name][sl_x] + self.vars[y_name] = h.variables[y_name][sl_y] self.setup_coordinates() @@ -481,10 +484,10 @@ def units(self, varname): stored_units = self.variables_description[varname]["attrs"].get("units", None) if stored_units is not None: return stored_units - with Dataset(self.filename) as h: - var = h.variables[varname] - if hasattr(var, "units"): - return var.units + h = self.nc4file + var = h.variables[varname] + if hasattr(var, "units"): + return var.units @property def variables(self): @@ -535,24 +538,24 @@ def grid(self, varname, indexs=None): "Load %(varname)s from %(filename)s", dict(varname=varname, filename=self.filename), ) - with Dataset(self.filename) as h: - dims = h.variables[varname].dimensions - sl = [ - indexs.get( - dim, - self.indexs.get( - dim, slice(None) if dim in coordinates_dims else 0 - ), - ) - for dim in dims - ] - self.vars[varname] = h.variables[varname][sl] - if len(self.x_dim) == 1: - i_x = where(array(dims) == self.x_dim)[0][0] - i_y = where(array(dims) == self.y_dim)[0][0] - if i_x > i_y: - self.variables_description[varname]["infos"]["transpose"] = True - self.vars[varname] = self.vars[varname].T + h = self.nc4file + dims = h.variables[varname].dimensions + sl = [ + indexs.get( + dim, + self.indexs.get( + dim, slice(None) if dim in coordinates_dims else 0 + ), + ) + for dim in dims + ] + self.vars[varname] = h.variables[varname][sl] + if len(self.x_dim) == 1: + i_x = where(array(dims) == self.x_dim)[0][0] + i_y = where(array(dims) == self.y_dim)[0][0] + if i_x > i_y: + self.variables_description[varname]["infos"]["transpose"] = True + self.vars[varname] = self.vars[varname].T if self.nan_mask: self.vars[varname] = ma.array( self.vars[varname], @@ -578,20 +581,20 @@ def grid_tiles(self, varname, slice_x, slice_y): slice_x=slice_x, ), ) - with Dataset(self.filename) as h: - dims = h.variables[varname].dimensions - sl = [ - (slice_x if dim in list(self.x_dim) else slice_y) - if dim in coordinates_dims - else 0 - for dim in dims - ] - data = h.variables[varname][sl] - if len(self.x_dim) == 1: - i_x = where(array(dims) == self.x_dim)[0][0] - i_y = where(array(dims) == self.y_dim)[0][0] - if i_x > i_y: - data = data.T + h = self.nc4file + dims = h.variables[varname].dimensions + sl = [ + (slice_x if dim in list(self.x_dim) else slice_y) + if dim in coordinates_dims + else 0 + for dim in dims + ] + data = h.variables[varname][sl] + if len(self.x_dim) == 1: + i_x = where(array(dims) == self.x_dim)[0][0] + i_y = where(array(dims) == self.y_dim)[0][0] + if i_x > i_y: + data = data.T if not hasattr(data, "mask"): data = ma.array(data, mask=zeros(data.shape, dtype="bool")) return data @@ -1086,19 +1089,19 @@ class UnRegularGridDataset(GridDataset): def load(self): """Load variable (data)""" x_name, y_name = self.coordinates - with Dataset(self.filename) as h: - self.x_dim = h.variables[x_name].dimensions - self.y_dim = h.variables[y_name].dimensions + h = self.nc4file + self.x_dim = h.variables[x_name].dimensions + self.y_dim = h.variables[y_name].dimensions - sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim] - sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim] - self.vars[x_name] = h.variables[x_name][sl_x] - self.vars[y_name] = h.variables[y_name][sl_y] + sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim] + sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim] + self.vars[x_name] = h.variables[x_name][sl_x] + self.vars[y_name] = h.variables[y_name][sl_y] - self.x_c = self.vars[x_name] - self.y_c = self.vars[y_name] + self.x_c = self.vars[x_name] + self.y_c = self.vars[y_name] - self.init_pos_interpolator() + self.init_pos_interpolator() @property def bounds(self):