Commit 9115bab

Change to be compatible with Python 3 and to handle detection without area information
1 parent fcce30a commit 9115bab

File tree

9 files changed: +164 -204 lines changed

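Most of the Python 3 work in this commit follows one pattern: filenames are stored as fixed-width bytes ('S256') in numpy record arrays, so they must be decoded back to str before being passed to netCDF4, and dict.iteritems() is replaced by iteration over keys(). A minimal, hypothetical sketch of that pattern (illustrative data, not code from this repository):

# Hypothetical illustration of the bytes/str filename pattern used in the diffs below.
from numpy import empty

files = ['/data/grid_20100101.nc', '/data/grid_20100102.nc']

# Filenames end up as fixed-width bytes ('S256') in a numpy record array ...
dataset_list = empty(len(files),
                     dtype=[('filename', 'S256'), ('date', 'datetime64[D]')])
dataset_list['filename'] = files  # numpy encodes the str values to bytes

# ... so under Python 3 they must be decoded back to str before use,
# e.g. before being handed to netCDF4.Dataset.
for item in dataset_list:
    name = item['filename'].decode('utf-8')
    print(name)

# Python 3 also drops dict.iteritems(); iterate over (sorted) keys instead.
attrs = {'units': 'm', 'long_name': 'sea level anomaly'}
for key in sorted(attrs.keys()):
    print(key, attrs[key])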

src/py_eddy_tracker/grid/__init__.py

Lines changed: 37 additions & 32 deletions
@@ -5,7 +5,7 @@
 from pyproj import Proj
 from numpy import unique, array, unravel_index, r_, floor, interp, arange, \
     sin, cos, deg2rad, arctan2, sqrt, pi, zeros, reciprocal, ma, empty, \
-    concatenate
+    concatenate, bytes_
 import logging
 from ..tracking_objects import nearest
 from re import compile as re_compile
@@ -20,16 +20,35 @@ def browse_dataset_in(data_dir, files_model, date_regexp, date_model,
     full_path = join_path(data_dir, files_model)
     logging.info('Search files : %s', full_path)
 
-    dataset_list = array(glob(full_path),
+    filenames = bytes_(glob(full_path))
+    dataset_list = empty(len(filenames),
                          dtype=[('filename', 'S256'),
                                 ('date', 'datetime64[D]'),
                                 ])
+    dataset_list['filename'] = bytes_(glob(full_path))
 
     logging.info('%s grids available', dataset_list.shape[0])
+    mode_attrs = False
+    if '(' not in date_regexp:
+        logging.debug('Attrs date : %s', date_regexp)
+        mode_attrs = date_regexp.strip().split(':')
+    else:
+        logging.debug('Pattern date : %s', date_regexp)
+
     for item in dataset_list:
-        result = pattern_regexp.match(item['filename'])
-        if result:
-            str_date = result.groups()[0]
+        str_date = None
+        if mode_attrs:
+            with Dataset(item['filename'].decode("utf-8")) as h:
+                if len(mode_attrs) == 1:
+                    str_date = getattr(h, mode_attrs[0])
+                else:
+                    str_date = getattr(h.variables[mode_attrs[0]], mode_attrs[1])
+        else:
+            result = pattern_regexp.match(str(item['filename']))
+            if result:
+                str_date = result.groups()[0]
+
+        if str_date is not None:
             item['date'] = datetime.strptime(str_date, date_model).date()
 
     dataset_list.sort(order=['date', 'filename'])
@@ -43,6 +62,9 @@ def browse_dataset_in(data_dir, files_model, date_regexp, date_model,
         dataset_list = dataset_list[::sub_sampling_step]
 
     if start_date is not None or end_date is not None:
+        logging.info('Available grid from %s to %s',
+                     dataset_list[0]['date'],
+                     dataset_list[-1]['date'])
         logging.info('Filtering grid by time %s, %s', start_date, end_date)
         mask = (dataset_list['date'] >= start_date) * (
             dataset_list['date'] <= end_date)
@@ -148,14 +170,9 @@ def read_nc(self, varname, indices=slice(None)):
         varname : variable ('temp', 'mask_rho', etc) to read
         indices : slice
         """
-        with Dataset(self.grid_filename) as h_nc:
+        with Dataset(self.grid_filename.decode("utf-8")) as h_nc:
             return h_nc.variables[varname][indices]
 
-    @property
-    def nc_variables(self):
-        with Dataset(self.grid_filename) as h_nc:
-            return h_nc.variables.keys()
-
     @property
     def view(self):
         return (self.slice_j,
@@ -179,7 +196,7 @@ def read_nc_att(self, varname, att):
         varname : variable ('temp', 'mask_rho', etc) to read
         att : string of attribute, eg. 'valid_range'
         """
-        with Dataset(self.grid_filename) as h_nc:
+        with Dataset(self.grid_filename.decode("utf-8")) as h_nc:
             return getattr(h_nc.variables[varname], att)
 
     @property
@@ -388,27 +405,14 @@ def vv2vr(vv_in, m_p, l_p):
            mshp, lshp = vv_in.shape
            return vv2vr(vv_in, mshp + 1, lshp)
 
-    def rho2u_2d(self, rho_in):
-        """
-        Convert a 2D field at rho points to a field at u_val points
-        """
-        assert rho_in.ndim == 2, 'rho_in must be 2d'
-        return ((rho_in[:, :-1] + rho_in[:, 1:]) / 2).squeeze()
-
-    def rho2v_2d(self, rho_in):
-        """
-        Convert a 2D field at rho points to a field at v_val points
-        """
-        assert rho_in.ndim == 2, 'rho_in must be 2d'
-        return ((rho_in[:-1] + rho_in[1:]) / 2).squeeze()
-
     def uvmask(self):
        """
        Get mask at U and V points
        """
        logging.info('--- Computing umask and vmask for padded grid')
-       self._umask = self.mask[:, :-1] * self.mask[:, 1:]
-       self._vmask = self.mask[:-1] * self.mask[1:]
+       if getattr(self, 'mask', None) is not None:
+           self._umask = self.mask[:, :-1] * self.mask[:, 1:]
+           self._vmask = self.mask[:-1] * self.mask[1:]
 
     def set_basemap(self, with_pad=True):
        """
@@ -477,10 +481,11 @@ def uspd(self):
        """Get scalar speed
        """
        uspd = (self.u_val ** 2 + self.v_val ** 2) ** .5
-       if hasattr(uspd, 'mask'):
-           uspd.mask += self.mask[self.view_unpad]
-       else:
-           uspd = ma.array(uspd, mask=self.mask[self.view_unpad])
+       if self.mask is not None:
+           if hasattr(uspd, 'mask'):
+               uspd.mask += self.mask[self.view_unpad]
+           else:
+               uspd = ma.array(uspd, mask=self.mask[self.view_unpad])
        return uspd
 
     def set_interp_coeffs(self, sla, uspd):
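For context, browse_dataset_in now takes the grid date either from a filename pattern with a capture group or, when date_regexp contains no '(', from a netCDF attribute path given as 'attr' or 'variable:attr'. A simplified, hypothetical helper (stripped of the record-array handling above, names are not from the repository) could look like:

from datetime import datetime
from re import compile as re_compile

from netCDF4 import Dataset


def extract_date(filename, date_regexp, date_model):
    """Return the grid date, read either from a netCDF attribute or
    from the filename, mirroring the dispatch added to browse_dataset_in."""
    if '(' not in date_regexp:
        # No capture group: treat date_regexp as 'global_attr' or 'variable:attr'.
        attr_path = date_regexp.strip().split(':')
        with Dataset(filename) as h:
            if len(attr_path) == 1:
                str_date = getattr(h, attr_path[0])
            else:
                str_date = getattr(h.variables[attr_path[0]], attr_path[1])
    else:
        # Capture group present: extract the date string from the filename.
        match = re_compile(date_regexp).match(filename)
        if match is None:
            return None
        str_date = match.groups()[0]
    return datetime.strptime(str_date, date_model).date()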

src/py_eddy_tracker/grid/aviso.py

Lines changed: 38 additions & 52 deletions
@@ -4,7 +4,8 @@
 from scipy import spatial
 from dateutil import parser
 from numpy import meshgrid, zeros, array, where, ma, argmin, vstack, ones, \
-    newaxis, sqrt, diff, r_
+    newaxis, sqrt, diff, r_, arange
+from scipy.interpolate import interp1d
 import logging
 from netCDF4 import Dataset
 
@@ -33,6 +34,8 @@ class AvisoGrid(BaseData):
         'fillval',
         '_angle',
         'sla_coeffs',
+        'xinterp',
+        'yinterp',
         'uspd_coeffs',
         '__lon',
         '__lat',
@@ -50,22 +53,37 @@ def __init__(self, aviso_file, the_domain,
        super(AvisoGrid, self).__init__()
        logging.info('Initialising the *AVISO_grid*')
        self.grid_filename = aviso_file
-       self.domain = the_domain
-       self.lonmin = float(lonmin)
-       self.lonmax = float(lonmax)
-       self.latmin = float(latmin)
-       self.latmax = float(latmax)
-       self.grid_filename = aviso_file
-
+
        self.lon_name = lon_name
        self.lat_name = lat_name
        self.grid_name = grid_name
 
        self._lon = self.read_nc(self.lon_name)
        self._lat = self.read_nc(self.lat_name)
+       if the_domain is None:
+           self.domain = 'Automatic Domain'
+           dlon = abs(self._lon[1] - self._lon[0])
+           dlat = abs(self._lat[1] - self._lat[0])
+           self.lonmin = float(self._lon.min()) + dlon * 2
+           self.lonmax = float(self._lon.max()) - dlon * 2
+           self.latmin = float(self._lat.min()) + dlat * 2
+           self.latmax = float(self._lat.max()) - dlat * 2
+           if ((self._lon[-1] + dlon) % 360) == self._lon[0]:
+               self.domain = 'Global'
+               self.lonmin = -100.
+               self.lonmax = 290.
+               self.latmin = -80.
+               self.latmax = 80.
+       else:
+           self.domain = the_domain
+           self.lonmin = float(lonmin)
+           self.lonmax = float(lonmax)
+           self.latmin = float(latmin)
+           self.latmax = float(latmax)
+
        self.fillval = self.read_nc_att(self.grid_name, '_FillValue')
 
-       if lonmin < 0 and lonmax <= 0:
+       if self.lonmin < 0 and self.lonmax <= 0:
            self._lon -= 360.
        self._lon, self._lat = meshgrid(self._lon, self._lat)
        self._angle = zeros(self._lon.shape)
@@ -75,7 +93,7 @@ def __init__(self, aviso_file, the_domain,
 
        # zero_crossing, used for handling a longitude range that
        # crosses zero degree meridian
-       if lonmin < 0 and lonmax >= 0 and 'MedSea' not in self.domain:
+       if self.lonmin < 0 and self.lonmax >= 0 and 'MedSea' not in self.domain:
            if ((self.lonmax < self._lon.max()) and (self.lonmax > self._lon.min()) and (self.lonmin < self._lon.max()) and (self.lonmin > self._lon.min())):
                pass
            else:
@@ -92,9 +110,15 @@ def __init__(self, aviso_file, the_domain,
        self.get_aviso_f_pm_pn()
        self.set_u_v_eke()
        self.shape = self.lon.shape
-       # pad2 = 2 * self.pad
-       # self.shape = (self.f_coriolis.shape[0] - pad2,
-       #               self.f_coriolis.shape[1] - pad2)
+
+       # self.init_pos_interpolator()
+
+    def init_pos_interpolator(self):
+        self.xinterp = interp1d(self.lon[0].copy(), arange(self.lon.shape[1]), assume_sorted=True, copy=False, fill_value=(0, -1), bounds_error=False, kind='nearest')
+        self.yinterp = interp1d(self.lat[:, 0].copy(), arange(self.lon.shape[0]), assume_sorted=True, copy=False, fill_value=(0, -1), bounds_error=False, kind='nearest')
+
+    def nearest_indice(self, lon, lat):
+        return self.xinterp(lon), self.yinterp(lat)
 
     def set_filename(self, file_name):
        self.grid_filename = file_name
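The new init_pos_interpolator / nearest_indice pair maps longitude and latitude values to the nearest grid indices with two 1-D nearest-neighbour interpolators. A standalone sketch of the same idea on a made-up regular grid (axis values are illustrative only, not the AVISO grid):

from numpy import arange
from scipy.interpolate import interp1d

# Made-up 1-D axes of a regular 1/4 degree grid.
lon_axis = arange(-100., 290., .25)
lat_axis = arange(-80., 80.25, .25)

# Map a coordinate value to the index of the nearest grid column/row.
# fill_value=(0, -1) with bounds_error=False clips out-of-range queries
# to the first/last index instead of raising.
xinterp = interp1d(lon_axis, arange(lon_axis.shape[0]), kind='nearest',
                   assume_sorted=True, copy=False,
                   fill_value=(0, -1), bounds_error=False)
yinterp = interp1d(lat_axis, arange(lat_axis.shape[0]), kind='nearest',
                   assume_sorted=True, copy=False,
                   fill_value=(0, -1), bounds_error=False)

i, j = int(xinterp(5.1)), int(yinterp(43.3))  # nearest column and row indices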
@@ -110,7 +134,7 @@ def get_aviso_data(self, aviso_file, dimensions=None):
        if units not in self.KNOWN_UNITS:
            raise Exception('Unknown units : %s' % units)
 
-       with Dataset(self.grid_filename) as h_nc:
+       with Dataset(self.grid_filename.decode('utf-8')) as h_nc:
            grid_dims = array(h_nc.variables[self.grid_name].dimensions)
            lat_dim = h_nc.variables[self.lat_name].dimensions[0]
            lon_dim = h_nc.variables[self.lon_name].dimensions[0]
@@ -159,44 +183,6 @@ def set_mask(self, sla):
            sea_label = self.labels[plus9, plus200]
            self.mask += self.labels != sea_label
 
-    def fillmask(self, data, mask):
-        """
-        Fill missing values in an array with an average of nearest
-        neighbours
-        From http://permalink.gmane.org/gmane.comp.python.scientific.user/19610
-        """
-        raise Exception('Use convolution to fill data')
-        assert data.ndim == 2, 'data must be a 2D array.'
-        fill_value = 9999.99
-        data[mask == 0] = fill_value
-
-        # Create (i, j) point arrays for good and bad data.
-        # Bad data are marked by the fill_value, good data elsewhere.
-        igood = vstack(where(data != fill_value)).T
-        ibad = vstack(where(data == fill_value)).T
-
-        # Create a tree for the bad points, the points to be filled
-        tree = spatial.cKDTree(igood)
-
-        # Get the four closest points to the bad points
-        # here, distance is squared
-        dist, iquery = tree.query(ibad, k=4, p=2)
-
-        # Create a normalised weight, the nearest points are weighted as 1.
-        # Points greater than one are then set to zero
-        weight = dist / (dist.min(axis=1)[:, newaxis])
-        weight *= ones(dist.shape)
-        weight[weight > 1.] = 0.
-
-        # Multiply the queried good points by the weight, selecting only the
-        # nearest points. Divide by the number of nearest points to get average
-        xfill = weight * data[igood[:, 0][iquery], igood[:, 1][iquery]]
-        xfill = (xfill / weight.sum(axis=1)[:, newaxis]).sum(axis=1)
-
-        # Place average of nearest good points, xfill, into bad point locations
-        data[ibad[:, 0], ibad[:, 1]] = xfill
-        return data
-
     @property
     def lon(self):
        if self.__lon is None:
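The deleted fillmask method had already been disabled with raise Exception('Use convolution to fill data'). As a reference for that hint only, a generic normalised-convolution fill, assumed here and not taken from this project, might look like:

from numpy import where
from scipy.ndimage import uniform_filter


def fill_by_convolution(data, valid, size=3, n_pass=1):
    """Fill invalid points with the local mean of valid neighbours
    (normalised box filter). `valid` is True where `data` is usable.
    Generic sketch, not py-eddy-tracker code."""
    filled = where(valid, data, 0.).astype(float)
    weight = valid.astype(float)
    for _ in range(n_pass):
        num = uniform_filter(filled * weight, size=size)  # local sum of valid values / size**2
        den = uniform_filter(weight, size=size)           # local fraction of valid points
        estimate = where(den > 0, num / where(den > 0, den, 1.), 0.)
        filled = where(weight > 0, filled, estimate)      # only touch invalid points
        weight = where(den > 0, 1., weight)               # newly filled points become valid
    return filled

Each extra pass propagates the fill one neighbourhood further into the masked region.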

src/py_eddy_tracker/observations.py

Lines changed: 22 additions & 12 deletions
@@ -110,10 +110,10 @@ def compute_pixel_path(self, x0, y0, x1, y1):
        # Delta index of y
        d_y = i_y1 - i_y0
 
-       d_max = maximum(abs(d_x), abs(d_y))
+       d_max = int_(maximum(abs(d_x), abs(d_y)))
 
        # Compute number of pixel which we go trought
-       nb_value = (abs(d_max) + 1).sum()
+       nb_value = int((abs(d_max) + 1).sum())
        # Create an empty array to store value of pixel across the travel
        # Max Index ~65000
        i_g = empty(nb_value, dtype='u2')
@@ -257,9 +257,12 @@ def merge(self, other):
        """
        nb_obs_self = len(self)
        nb_obs = nb_obs_self + len(other)
-       eddies = self.__class__(size=nb_obs)
-       eddies.obs[:nb_obs_self] = self.obs[:]
-       eddies.obs[nb_obs_self:] = other.obs[:]
+       eddies = self.new_like(self, nb_obs)
+       other_keys = other.obs.dtype.fields.keys()
+       for key in eddies.obs.dtype.fields.keys():
+           eddies.obs[key][:nb_obs_self] = self.obs[key][:]
+           if key in other_keys:
+               eddies.obs[key][nb_obs_self:] = other.obs[key][:]
        eddies.sign_type = self.sign_type
        return eddies
 
@@ -294,11 +297,7 @@ def insert_observations(self, other, index):
            return self
        if index < 0:
            index = self_size + index + 1
-       eddies = self.__class__(new_size,
-                               track_extra_variables=self.track_extra_variables,
-                               track_array_variables=self.track_array_variables,
-                               array_variables=self.array_variables
-                               )
+       eddies = self.new_like(self, new_size)
        eddies.obs[:index] = self.obs[:index]
        eddies.obs[index: index + insert_size] = other.obs
        eddies.obs[index + insert_size:] = self.obs[index:]
@@ -323,13 +322,21 @@ def distance(self, other):
                 dist_result)
        return dist_result
 
+    @staticmethod
+    def new_like(eddies, new_size):
+        return eddies.__class__(new_size,
+                                track_extra_variables=eddies.track_extra_variables,
+                                track_array_variables=eddies.track_array_variables,
+                                array_variables=eddies.array_variables
+                                )
+
     def index(self, index):
        """Return obs from self at the index
        """
        size = 1
        if hasattr(index, '__iter__'):
            size = len(index)
-       eddies = self.__class__(size, self.track_extra_variables)
+       eddies = self.new_like(self, size)
        eddies.obs[:] = self.obs[index]
        return eddies
 
@@ -761,7 +768,10 @@ def create_variable(handler_nc, kwargs_variable, attr_variable,
                                   zlib=True,
                                   complevel=1,
                                   **kwargs_variable)
-        for attr, attr_value in attr_variable.iteritems():
+        attrs = list(attr_variable.keys())
+        attrs.sort()
+        for attr in attrs:
+            attr_value = attr_variable[attr]
            var.setncattr(attr, attr_value)
        if scale_factor is not None:
            var.scale_factor = scale_factor
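The rewritten merge copies record-array fields one by one, so two observation sets whose dtypes only partly overlap can still be combined; fields missing from other are simply not copied into the appended half. A toy illustration with plain numpy structured arrays (hypothetical field names, not the project's observation dtype):

from numpy import empty, zeros

# Two toy observation arrays with overlapping but not identical fields.
a = zeros(3, dtype=[('lon', 'f8'), ('lat', 'f8'), ('track', 'i4')])
b = zeros(2, dtype=[('lon', 'f8'), ('lat', 'f8')])  # no 'track' field
b['lon'] = 10.

merged = empty(len(a) + len(b), dtype=a.dtype)
other_keys = b.dtype.fields.keys()
for key in merged.dtype.fields.keys():
    merged[key][:len(a)] = a[key][:]
    if key in other_keys:
        merged[key][len(a):] = b[key][:]
    else:
        merged[key][len(a):] = 0  # in this toy version, absent fields get a default

print(merged['lon'])  # [ 0.  0.  0. 10. 10.]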
