Skip to content

Commit a44b2e1

Browse files
author
adelepoulle
committed
update tracking
1 parent e5af5ad commit a44b2e1

File tree

2 files changed

+112
-23
lines changed

2 files changed

+112
-23
lines changed

src/py_eddy_tracker/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class ColoredFormatter(logging.Formatter):
99
ERROR="\033[31;47m",
1010
WARNING="\033[30;47m",
1111
INFO="\033[36m",
12-
DEBUG="\033[34m",
12+
DEBUG="\033[34m\t",
1313
)
1414

1515
def __init__(self, message):

src/py_eddy_tracker/observations.py

Lines changed: 111 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
===========================================================================
2929
3030
"""
31-
from numpy import zeros, empty, nan, arange, interp, where, unique
31+
from numpy import zeros, empty, nan, arange, interp, where, unique, \
32+
ma
3233
from netCDF4 import Dataset
3334
from py_eddy_tracker.tools import distance_matrix
3435
from . import VAR_DESCR, VAR_DESCR_inv
@@ -82,9 +83,11 @@ def __getitem__(self, attr):
8283
if attr in self.elements:
8384
return self.observations[attr]
8485
raise KeyError('%s unknown' % attr)
85-
86+
8687
@property
8788
def dtype(self):
89+
"""Return dtype to build numpy array
90+
"""
8891
dtype = list()
8992
for elt in self.elements:
9093
dtype.append((elt, VAR_DESCR[elt][
@@ -94,6 +97,8 @@ def dtype(self):
9497

9598
@property
9699
def elements(self):
100+
"""Return all variable name
101+
"""
97102
elements = [
98103
'lon', # 'centlon'
99104
'lat', # 'centlat'
@@ -113,9 +118,13 @@ def elements(self):
113118
return elements
114119

115120
def coherence(self, other):
121+
"""Check coherence between two dataset
122+
"""
116123
return self.track_extra_variables == other.track_extra_variables
117-
124+
118125
def merge(self, other):
126+
"""Merge two dataset
127+
"""
119128
nb_obs_self = len(self)
120129
nb_obs = nb_obs_self + len(other)
121130
eddies = self.__class__(size=nb_obs)
@@ -126,6 +135,8 @@ def merge(self, other):
126135

127136
@property
128137
def obs(self):
138+
"""returan array observations
139+
"""
129140
return self.observations
130141

131142
def __len__(self):
@@ -136,6 +147,8 @@ def __iter__(self):
136147
yield obs
137148

138149
def insert_observations(self, other, index):
150+
"""Insert other obs in self at the index
151+
"""
139152
if not self.coherence(other):
140153
raise Exception('Observations with no coherence')
141154
insert_size = len(other.obs)
@@ -156,6 +169,8 @@ def insert_observations(self, other, index):
156169
return self
157170

158171
def append(self, other):
172+
"""Merge
173+
"""
159174
return self + other
160175

161176
def __add__(self, other):
@@ -172,13 +187,15 @@ def distance(self, other):
172187
return dist_result
173188

174189
def index(self, index):
190+
"""Return obs from self at the index
191+
"""
175192
size = 1
176193
if hasattr(index, '__iter__'):
177194
size = len(index)
178195
eddies = self.__class__(size, self.track_extra_variables)
179196
eddies.obs[:] = self.obs[index]
180197
return eddies
181-
198+
182199
@staticmethod
183200
def load_from_netcdf(filename):
184201
with Dataset(filename) as h_nc:
@@ -187,22 +204,79 @@ def load_from_netcdf(filename):
187204
for variable in h_nc.variables:
188205
if variable == 'cyc':
189206
continue
190-
eddies.obs[VAR_DESCR_inv[variable]] = h_nc.variables[variable][:]
207+
eddies.obs[VAR_DESCR_inv[variable]
208+
] = h_nc.variables[variable][:]
191209
eddies.sign_type = h_nc.variables['cyc'][0]
192210
return eddies
193211

194212
def tracking(self, other):
195-
"""Track obs between from self to other
213+
"""Track obs between self and other
196214
"""
197-
dist = self.distance(other)
198-
i_self, i_other = where(dist < 20.)
215+
cost = self.distance(other)
216+
# Links available which respect a maximal cost
217+
mask_accept_cost = cost < 75
218+
cost = ma.array(cost, mask=-mask_accept_cost, dtype='i2')
219+
220+
# Count number of link by self obs and other obs
221+
self_links = mask_accept_cost.sum(axis=1)
222+
other_links = mask_accept_cost.sum(axis=0)
223+
max_links = max(self_links.max(), other_links.max())
224+
if max_links > 5:
225+
logging.warning('One observation have %d links', max_links)
226+
227+
# If some obs have multiple link, we keep only one link by eddy
228+
eddies_separation = 1 < self_links
229+
eddies_merge = 1 < other_links
230+
test = eddies_separation.any() or eddies_merge.any()
231+
if test:
232+
# We extract matrix which contains concflict
233+
obs_linking_to_self = mask_accept_cost[eddies_separation
234+
].any(axis=0)
235+
obs_linking_to_other = mask_accept_cost[:, eddies_merge
236+
].any(axis=1)
237+
i_self_keep = where(obs_linking_to_other + eddies_separation)[0]
238+
i_other_keep = where(obs_linking_to_self + eddies_merge)[0]
239+
240+
# Cost to resolve conflict
241+
cost_reduce = cost[i_self_keep][:, i_other_keep]
242+
shape = cost_reduce.shape
243+
logging.debug('Shape conflict matrix : %s', shape)
244+
245+
matrix_size = shape[0] * shape[1]
246+
if (matrix_size) >= 20000:
247+
logging.warning('High number of conflict : %d (matrix_size)',
248+
matrix_size)
249+
250+
links_resolve = 0
251+
while False in cost_reduce.mask:
252+
i_min_value = cost_reduce.argmin()
253+
i, j = i_min_value / shape[1], i_min_value % shape[1]
254+
# Set to False all link
255+
mask_accept_cost[i_self_keep[i]] = False
256+
mask_accept_cost[:, i_other_keep[j]] = False
257+
cost_reduce.mask[i] = True
258+
cost_reduce.mask[:, j] = True
259+
# we active only this link
260+
mask_accept_cost[i_self_keep[i], i_other_keep[j]] = True
261+
links_resolve += 1
262+
logging.debug('%d links resolve', links_resolve)
263+
264+
i_self, i_other = where(mask_accept_cost)
199265

200266
logging.debug('%d matched with previous', i_self.shape[0])
267+
268+
# Check
269+
if unique(i_other).shape[0] != i_other.shape[0]:
270+
raise Exception()
271+
if unique(i_self).shape[0] != i_self.shape[0]:
272+
raise Exception()
201273
return i_self, i_other
202274

203275

204276
class VirtualEddiesObservations(EddiesObservations):
205-
277+
"""Class to work with virtual obs
278+
"""
279+
206280
@property
207281
def elements(self):
208282
elements = super(VirtualEddiesObservations, self).elements
@@ -211,7 +285,9 @@ def elements(self):
211285

212286

213287
class TrackEddiesObservations(EddiesObservations):
214-
288+
"""Class to practice Tracking on observations
289+
"""
290+
215291
def filled_by_interpolation(self, mask):
216292
"""Filled selected values by interpolation
217293
"""
@@ -228,42 +304,47 @@ def filled_by_interpolation(self, mask):
228304
self.obs[var][-mask])
229305

230306
def extract_longer_eddies(self, nb_min, nb_obs, compress_id=True):
231-
m = nb_obs >= nb_min
232-
nb_obs_select = m.sum()
307+
"""Select eddies which are longer than nb_min
308+
"""
309+
mask = nb_obs >= nb_min
310+
nb_obs_select = mask.sum()
233311
logging.info('Selection of %d observations', nb_obs_select)
234312
eddies = TrackEddiesObservations(size=nb_obs_select)
235313
eddies.sign_type = self.sign_type
236314
for var, _ in eddies.obs.dtype.descr:
237-
eddies.obs[var] = self.obs[var][m]
315+
eddies.obs[var] = self.obs[var][mask]
238316
if compress_id:
239317
list_id = unique(eddies.obs['track'])
240318
list_id.sort()
241319
id_translate = arange(list_id.max() + 1)
242320
id_translate[list_id] = arange(len(list_id)) + 1
243321
eddies.obs['track'] = id_translate[eddies.obs['track']]
244322
return eddies
245-
323+
246324
@property
247325
def elements(self):
248326
elements = super(TrackEddiesObservations, self).elements
249327
elements.extend(['track', 'n', 'virtual'])
250328
return elements
251-
252-
def create_variable(self, handler_nc, kwargs_variable,
329+
330+
@staticmethod
331+
def create_variable(handler_nc, kwargs_variable,
253332
attr_variable, data, scale_factor=None):
333+
"""Create variable
334+
"""
254335
var = handler_nc.createVariable(
255336
zlib=True,
256337
complevel=1,
257338
**kwargs_variable)
258339
for attr, attr_value in attr_variable.iteritems():
259340
var.setncattr(attr, attr_value)
260-
341+
261342
var[:] = data
262-
343+
263344
#~ var.set_auto_maskandscale(False)
264345
if scale_factor is not None:
265346
var.scale_factor = scale_factor
266-
347+
267348
try:
268349
var.setncattr('min', var[:].min())
269350
var.setncattr('max', var[:].max())
@@ -291,7 +372,10 @@ def write_netcdf(self):
291372
dimensions=VAR_DESCR[name]['nc_dims']),
292373
VAR_DESCR[name]['nc_attr'],
293374
self.observations[name],
294-
scale_factor=None if 'scale_factor' not in VAR_DESCR[name] else VAR_DESCR[name]['scale_factor'])
375+
scale_factor=None
376+
if 'scale_factor' not in VAR_DESCR[name] else
377+
VAR_DESCR[name]['scale_factor']
378+
)
295379

296380
# Add cyclonic information
297381
self.create_variable(
@@ -305,7 +389,12 @@ def write_netcdf(self):
305389
self.set_global_attr_netcdf(h_nc)
306390

307391
def set_global_attr_netcdf(self, h_nc):
308-
h_nc.title = 'Cyclonic' if self.sign_type == -1 else 'Anticyclonic' + ' eddy tracks'
392+
"""Set global attr
393+
"""
394+
if self.sign_type == -1:
395+
h_nc.title = 'Cyclonic'
396+
else:
397+
h_nc.title = 'Anticyclonic' + ' eddy tracks'
309398
#~ h_nc.grid_filename = self.grd.grid_filename
310399
#~ h_nc.grid_date = str(self.grd.grid_date)
311400
#~ h_nc.product = self.product
@@ -324,7 +413,7 @@ def set_global_attr_netcdf(self, h_nc):
324413
#~ h_nc.evolve_amp_max = self.evolve_amp_max
325414
#~ h_nc.evolve_area_min = self.evolve_area_min
326415
#~ h_nc.evolve_area_max = self.evolve_area_max
327-
#~
416+
328417
#~ h_nc.llcrnrlon = self.grd.lonmin
329418
#~ h_nc.urcrnrlon = self.grd.lonmax
330419
#~ h_nc.llcrnrlat = self.grd.latmin

0 commit comments

Comments
 (0)