Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Modification to intake to allow use netcdf data directly, e.g. via xa…
…rray. Updates requirements.txt which has a long-standing error with matplotlib versions after 3.7.1 due to their updated contour algorithm.
  • Loading branch information
wienkers committed Jan 15, 2024
commit e2e50f514b1566ed7b6e099235ba9f2610e042ec
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
matplotlib
matplotlib==3.7.1
opencv-python
pint
polygon3
Expand All @@ -8,4 +8,4 @@ scipy
zarr
netCDF4
numpy
numba
numba
147 changes: 75 additions & 72 deletions src/py_eddy_tracker/dataset/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class GridDataset(object):
"filename",
"dimensions",
"indexs",
"nc4file",
"variables_description",
"global_attrs",
"vars",
Expand All @@ -275,6 +276,7 @@ def __init__(
indexs=None,
unset=False,
nan_masking=False,
nc4file=None,
):
"""
:param str filename: Filename to load
Expand All @@ -301,6 +303,7 @@ def __init__(
self.coordinates = x_name, y_name
self.vars = dict()
self.indexs = dict() if indexs is None else indexs
self.nc4file = Dataset(filename, "r") if nc4file is None else nc4file
if centered is None:
logger.warning(
"We assume pixel position of grid is centered for %s", filename
Expand Down Expand Up @@ -344,25 +347,25 @@ def load_general_features(self):
logger.debug(
"Load general feature from %(filename)s", dict(filename=self.filename)
)
with Dataset(self.filename) as h:
# Load generals
self.dimensions = {i: len(v) for i, v in h.dimensions.items()}
self.variables_description = dict()
for i, v in h.variables.items():
args = (i, v.datatype)
kwargs = dict(dimensions=v.dimensions, zlib=True)
if hasattr(v, "_FillValue"):
kwargs["fill_value"] = (v._FillValue,)
attrs = dict()
for attr in v.ncattrs():
if attr in kwargs.keys():
continue
if attr == "_FillValue":
continue
attrs[attr] = getattr(v, attr)
self.variables_description[i] = dict(
args=args, kwargs=kwargs, attrs=attrs, infos=dict()
)
h = self.nc4file
# Load generals
self.dimensions = {i: len(v) for i, v in h.dimensions.items()}
self.variables_description = dict()
for i, v in h.variables.items():
args = (i, v.datatype)
kwargs = dict(dimensions=v.dimensions, zlib=True)
if hasattr(v, "_FillValue"):
kwargs["fill_value"] = (v._FillValue,)
attrs = dict()
for attr in v.ncattrs():
if attr in kwargs.keys():
continue
if attr == "_FillValue":
continue
attrs[attr] = getattr(v, attr)
self.variables_description[i] = dict(
args=args, kwargs=kwargs, attrs=attrs, infos=dict()
)
self.global_attrs = {attr: getattr(h, attr) for attr in h.ncattrs()}

def write(self, filename):
Expand Down Expand Up @@ -407,14 +410,14 @@ def load(self):
Get coordinates and setup coordinates function
"""
x_name, y_name = self.coordinates
with Dataset(self.filename) as h:
self.x_dim = h.variables[x_name].dimensions
self.y_dim = h.variables[y_name].dimensions
h = self.nc4file
self.x_dim = h.variables[x_name].dimensions
self.y_dim = h.variables[y_name].dimensions

sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
self.vars[x_name] = h.variables[x_name][sl_x]
self.vars[y_name] = h.variables[y_name][sl_y]
sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
self.vars[x_name] = h.variables[x_name][sl_x]
self.vars[y_name] = h.variables[y_name][sl_y]

self.setup_coordinates()

Expand Down Expand Up @@ -481,10 +484,10 @@ def units(self, varname):
stored_units = self.variables_description[varname]["attrs"].get("units", None)
if stored_units is not None:
return stored_units
with Dataset(self.filename) as h:
var = h.variables[varname]
if hasattr(var, "units"):
return var.units
h = self.nc4file
var = h.variables[varname]
if hasattr(var, "units"):
return var.units

@property
def variables(self):
Expand Down Expand Up @@ -535,24 +538,24 @@ def grid(self, varname, indexs=None):
"Load %(varname)s from %(filename)s",
dict(varname=varname, filename=self.filename),
)
with Dataset(self.filename) as h:
dims = h.variables[varname].dimensions
sl = [
indexs.get(
dim,
self.indexs.get(
dim, slice(None) if dim in coordinates_dims else 0
),
)
for dim in dims
]
self.vars[varname] = h.variables[varname][sl]
if len(self.x_dim) == 1:
i_x = where(array(dims) == self.x_dim)[0][0]
i_y = where(array(dims) == self.y_dim)[0][0]
if i_x > i_y:
self.variables_description[varname]["infos"]["transpose"] = True
self.vars[varname] = self.vars[varname].T
h = self.nc4file
dims = h.variables[varname].dimensions
sl = [
indexs.get(
dim,
self.indexs.get(
dim, slice(None) if dim in coordinates_dims else 0
),
)
for dim in dims
]
self.vars[varname] = h.variables[varname][sl]
if len(self.x_dim) == 1:
i_x = where(array(dims) == self.x_dim)[0][0]
i_y = where(array(dims) == self.y_dim)[0][0]
if i_x > i_y:
self.variables_description[varname]["infos"]["transpose"] = True
self.vars[varname] = self.vars[varname].T
if self.nan_mask:
self.vars[varname] = ma.array(
self.vars[varname],
Expand All @@ -578,20 +581,20 @@ def grid_tiles(self, varname, slice_x, slice_y):
slice_x=slice_x,
),
)
with Dataset(self.filename) as h:
dims = h.variables[varname].dimensions
sl = [
(slice_x if dim in list(self.x_dim) else slice_y)
if dim in coordinates_dims
else 0
for dim in dims
]
data = h.variables[varname][sl]
if len(self.x_dim) == 1:
i_x = where(array(dims) == self.x_dim)[0][0]
i_y = where(array(dims) == self.y_dim)[0][0]
if i_x > i_y:
data = data.T
h = self.nc4file
dims = h.variables[varname].dimensions
sl = [
(slice_x if dim in list(self.x_dim) else slice_y)
if dim in coordinates_dims
else 0
for dim in dims
]
data = h.variables[varname][sl]
if len(self.x_dim) == 1:
i_x = where(array(dims) == self.x_dim)[0][0]
i_y = where(array(dims) == self.y_dim)[0][0]
if i_x > i_y:
data = data.T
if not hasattr(data, "mask"):
data = ma.array(data, mask=zeros(data.shape, dtype="bool"))
return data
Expand Down Expand Up @@ -1086,19 +1089,19 @@ class UnRegularGridDataset(GridDataset):
def load(self):
"""Load variable (data)"""
x_name, y_name = self.coordinates
with Dataset(self.filename) as h:
self.x_dim = h.variables[x_name].dimensions
self.y_dim = h.variables[y_name].dimensions
h = self.nc4file
self.x_dim = h.variables[x_name].dimensions
self.y_dim = h.variables[y_name].dimensions

sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
self.vars[x_name] = h.variables[x_name][sl_x]
self.vars[y_name] = h.variables[y_name][sl_y]
sl_x = [self.indexs.get(dim, slice(None)) for dim in self.x_dim]
sl_y = [self.indexs.get(dim, slice(None)) for dim in self.y_dim]
self.vars[x_name] = h.variables[x_name][sl_x]
self.vars[y_name] = h.variables[y_name][sl_y]

self.x_c = self.vars[x_name]
self.y_c = self.vars[y_name]
self.x_c = self.vars[x_name]
self.y_c = self.vars[y_name]

self.init_pos_interpolator()
self.init_pos_interpolator()

@property
def bounds(self):
Expand Down