Skip to content

Commit 2c811b8

Browse files
author
adelepoulle
committed
Add modification to manage array variable
1 parent 472cea2 commit 2c811b8

File tree

2 files changed

+69
-31
lines changed

2 files changed

+69
-31
lines changed

src/py_eddy_tracker/observations.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,10 @@ class EddiesObservations(object):
6767
Time
6868
"""
6969

70-
def __init__(self, size=0, track_extra_variables=False):
70+
def __init__(self, size=0, track_extra_variables=False, track_array_variables=0, array_variables=[]):
7171
self.track_extra_variables = track_extra_variables
72+
self.track_array_variables = track_array_variables
73+
self.array_variables = array_variables
7274
for elt in self.elements:
7375
if elt not in VAR_DESCR:
7476
raise Exception('Unknown element : %s' % elt)
@@ -90,9 +92,13 @@ def dtype(self):
9092
"""
9193
dtype = list()
9294
for elt in self.elements:
93-
dtype.append((elt, VAR_DESCR[elt][
95+
data_type = VAR_DESCR[elt][
9496
'compute_type' if 'compute_type' in VAR_DESCR[elt] else
95-
'nc_type']))
97+
'nc_type']
98+
if elt in self.array_variables:
99+
dtype.append((elt, data_type, (self.track_array_variables,)))
100+
else:
101+
dtype.append((elt, data_type))
96102
return dtype
97103

98104
@property
@@ -108,7 +114,9 @@ def elements(self):
108114
'speed_radius', # 'uavg'
109115
'eke', # 'teke'
110116
'time'] # 'rtime'
111-
117+
if self.track_array_variables > 0:
118+
elements += self.array_variables
119+
112120
if self.track_extra_variables:
113121
elements += ['contour_e',
114122
'contour_s',
@@ -120,7 +128,9 @@ def elements(self):
120128
def coherence(self, other):
121129
"""Check coherence between two dataset
122130
"""
123-
return self.track_extra_variables == other.track_extra_variables
131+
test = self.track_array_variables == other.track_array_variables
132+
test *= self.array_variables == other.array_variables
133+
return test
124134

125135
def merge(self, other):
126136
"""Merge two dataset
@@ -133,9 +143,12 @@ def merge(self, other):
133143
eddies.sign_type = self.sign_type
134144
return eddies
135145

146+
def reset(self):
147+
self.observations = np.zeros(0, dtype=self.dtype)
148+
136149
@property
137150
def obs(self):
138-
"""returan array observations
151+
"""return an array observations
139152
"""
140153
return self.observations
141154

@@ -161,7 +174,10 @@ def insert_observations(self, other, index):
161174
return self
162175
if index < 0:
163176
index = self_size + index + 1
164-
eddies = self.__class__(new_size, self.track_extra_variables)
177+
eddies = self.__class__(new_size, self.track_extra_variables,
178+
track_array_variables=self.track_array_variables,
179+
array_variables=self.array_variables
180+
)
165181
eddies.obs[:index] = self.obs[:index]
166182
eddies.obs[index: index + insert_size] = other.obs
167183
eddies.obs[index + insert_size:] = self.obs[index:]
@@ -198,9 +214,18 @@ def index(self, index):
198214

199215
@staticmethod
200216
def load_from_netcdf(filename):
217+
array_dim = 'NbSample'
201218
with Dataset(filename) as h_nc:
202219
nb_obs = len(h_nc.dimensions['Nobs'])
203-
eddies = EddiesObservations(size=nb_obs)
220+
kwargs = dict()
221+
if array_dim in h_nc.dimensions:
222+
kwargs['track_array_variables'] = len(h_nc.dimensions[array_dim])
223+
kwargs['array_variables'] = []
224+
for variable in h_nc.variables:
225+
if array_dim in h_nc.variables[variable].dimensions:
226+
kwargs['array_variables'].append(str(variable))
227+
228+
eddies = EddiesObservations(size=nb_obs, ** kwargs)
204229
for variable in h_nc.variables:
205230
if variable == 'cyc':
206231
continue
@@ -297,8 +322,9 @@ def filled_by_interpolation(self, mask):
297322
nb_obs = len(self)
298323
index = arange(nb_obs)
299324

300-
for var, _ in self.obs.dtype.descr:
301-
if var in ['n', 'virtual', 'track']:
325+
for field in self.obs.dtype.descr:
326+
var = field[0]
327+
if var in ['n', 'virtual', 'track'] or var in self.array_variables:
302328
continue
303329
self.obs[var][mask] = interp(index[mask], index[-mask],
304330
self.obs[var][-mask])
@@ -309,9 +335,14 @@ def extract_longer_eddies(self, nb_min, nb_obs, compress_id=True):
309335
mask = nb_obs >= nb_min
310336
nb_obs_select = mask.sum()
311337
logging.info('Selection of %d observations', nb_obs_select)
312-
eddies = TrackEddiesObservations(size=nb_obs_select)
338+
eddies = TrackEddiesObservations(
339+
size=nb_obs_select,
340+
track_array_variables=self.track_array_variables,
341+
array_variables=self.array_variables
342+
)
313343
eddies.sign_type = self.sign_type
314-
for var, _ in eddies.obs.dtype.descr:
344+
for field in self.obs.dtype.descr:
345+
var = field[0]
315346
eddies.obs[var] = self.obs[var][mask]
316347
if compress_id:
317348
list_id = unique(eddies.obs['track'])
@@ -329,22 +360,20 @@ def elements(self):
329360

330361
@staticmethod
331362
def create_variable(handler_nc, kwargs_variable,
332-
attr_variable, data, scale_factor=None):
333-
"""Create variable
334-
"""
363+
attr_variable, data, scale_factor=None, add_offset=None):
335364
var = handler_nc.createVariable(
336365
zlib=True,
337366
complevel=1,
338367
**kwargs_variable)
339368
for attr, attr_value in attr_variable.iteritems():
340369
var.setncattr(attr, attr_value)
341-
342-
var[:] = data
343-
344-
#~ var.set_auto_maskandscale(False)
345370
if scale_factor is not None:
346371
var.scale_factor = scale_factor
347-
372+
if add_offset is not None:
373+
var.add_offset = add_offset
374+
else:
375+
var.add_offset = 0
376+
var[:] = data
348377
try:
349378
var.setncattr('min', var[:].min())
350379
var.setncattr('max', var[:].max())
@@ -362,19 +391,21 @@ def write_netcdf(self, path='./'):
362391
# Create dimensions
363392
logging.debug('Create Dimensions "Nobs" : %d', eddy_size)
364393
h_nc.createDimension('Nobs', eddy_size)
394+
if self.track_array_variables != 0:
395+
h_nc.createDimension('NbSample', self.track_array_variables)
365396
# Iter on variables to create:
366-
for name, _ in self.observations.dtype.descr:
397+
for field in self.observations.dtype.descr:
398+
name = field[0]
367399
logging.debug('Create Variable %s', VAR_DESCR[name]['nc_name'])
368400
self.create_variable(
369401
h_nc,
370402
dict(varname=VAR_DESCR[name]['nc_name'],
371-
datatype=VAR_DESCR[name]['nc_type'],
403+
datatype=VAR_DESCR[name]['output_type'],
372404
dimensions=VAR_DESCR[name]['nc_dims']),
373405
VAR_DESCR[name]['nc_attr'],
374406
self.observations[name],
375-
scale_factor=None
376-
if 'scale_factor' not in VAR_DESCR[name] else
377-
VAR_DESCR[name]['scale_factor']
407+
scale_factor=VAR_DESCR[name].get('scale_factor', None),
408+
add_offset=VAR_DESCR[name].get('add_offset', None)
378409
)
379410

380411
# Add cyclonic information

src/py_eddy_tracker/tracking.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,16 @@ def prepare_merging(self):
278278
def merge(self):
279279
"""Merge all the correspondance in one array with all fields
280280
"""
281+
# Start loading identification again to save in the finals tracks
282+
# Load first file
283+
self.swap_dataset(self.datasets[0])
284+
281285
# Start create netcdf to agglomerate all eddy
282-
eddies = TrackEddiesObservations(size=self.nb_obs)
286+
eddies = TrackEddiesObservations(
287+
size=self.nb_obs,
288+
track_array_variables=self.current_obs.track_array_variables,
289+
array_variables=self.current_obs.array_variables,
290+
)
283291

284292
# Calculate the index in each tracks, we compute in u4 and translate
285293
# in u2 (which are limited to 65535)
@@ -291,9 +299,6 @@ def merge(self):
291299
eddies['track'][:] = arange(self.current_id
292300
).repeat(self.nb_obs_by_tracks)
293301

294-
# Start loading identification again to save in the finals tracks
295-
# Load first file
296-
self.swap_dataset(self.datasets[0])
297302
# Set type of eddy with first file
298303
eddies.sign_type = self.current_obs.sign_type
299304
# Fields to copy
@@ -317,7 +322,8 @@ def merge(self):
317322
# Index in the previous file
318323
index_in = self[i]['in'][m_first_obs]
319324
# Copy all variable
320-
for var, _ in fields:
325+
for field in fields:
326+
var = field[0]
321327
eddies[var][index_final[m_first_obs]
322328
] = self.previous_obs[var][index_in]
323329
# Increment
@@ -341,7 +347,8 @@ def merge(self):
341347
index_current = self[i]['out']
342348

343349
# Copy all variable
344-
for var, _ in fields:
350+
for field in fields:
351+
var = field[0]
345352
eddies[var][index_final
346353
] = self.current_obs[var][index_current]
347354

0 commit comments

Comments
 (0)