Skip to content

Commit e2289fa

Browse files
committed
Extraction of association cost
1 parent f1caeb1 commit e2289fa

File tree

4 files changed

+53
-17
lines changed

4 files changed

+53
-17
lines changed

src/py_eddy_tracker/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,16 @@ def parse_args(self, *args, **kwargs):
167167
description='Virtual observation: 0 for real',
168168
)
169169
),
170+
cost_association=dict(
171+
attr_name=None,
172+
nc_name='cost_association',
173+
nc_type='float32',
174+
nc_dims=('Nobs',),
175+
nc_attr=dict(
176+
long_name='cost_value_to_associate_with_next_observation',
177+
description='Cost value to associate with the next observation',
178+
)
179+
),
170180
lon=dict(
171181
attr_name='lon',
172182
compute_type='float64',

src/py_eddy_tracker/observations.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class EddiesObservations(object):
164164

165165
ELEMENTS = ['lon', 'lat', 'radius_s', 'radius_e', 'amplitude', 'speed_radius', 'time', 'eke',
166166
'shape_error_e', 'shape_error_s', 'nb_contour_selected',
167-
'height_max_speed_contour', 'height_external_contour', 'height_inner_contour']
167+
'height_max_speed_contour', 'height_external_contour', 'height_inner_contour', 'cost_association']
168168

169169
def __init__(self, size=0, track_extra_variables=None,
170170
track_array_variables=0, array_variables=None):
@@ -452,22 +452,34 @@ def propagate(previous_obs, current_obs, obs_to_extend, dead_track, nb_next, mod
452452
return next_obs
453453

454454
@staticmethod
455-
def cost_function_common_area(records_in, records_out, distance):
455+
def cost_function_common_area(records_in, records_out, distance, intern=False):
456+
x_name, y_name = ('contour_lon_s', 'contour_lat_s') if intern else ('contour_lon_e', 'contour_lat_e')
456457
nb_records = records_in.shape[0]
457458
costs = ma.empty(nb_records,dtype='f4')
459+
tolerance = 0.025
458460
for i_record in range(nb_records):
459461
poly_in = Polygon(
460462
concatenate((
461-
(records_in[i_record]['contour_lon_e'],),
462-
(records_in[i_record]['contour_lat_e'],))
463+
(records_in[i_record][x_name],),
464+
(records_in[i_record][y_name],))
463465
).T
464466
)
465467
poly_out = Polygon(
466468
concatenate((
467-
(records_out[i_record]['contour_lon_e'],),
468-
(records_out[i_record]['contour_lat_e'],))
469+
(records_out[i_record][x_name],),
470+
(records_out[i_record][y_name],))
469471
).T
470472
)
473+
if not poly_in.is_valid:
474+
poly_in = poly_in.simplify(tolerance=tolerance)
475+
if not poly_in.is_valid:
476+
logging.warning('Need to simplify polygon in for a second time')
477+
poly_in = poly_in.buffer(0)
478+
if not poly_out.is_valid:
479+
poly_out = poly_out.simplify(tolerance=tolerance)
480+
if not poly_out.is_valid:
481+
logging.warning('Need to simplify polygon out for a second time')
482+
poly_out = poly_out.buffer(0)
471483
try:
472484
costs[i_record] = 1 - poly_in.intersection(poly_out).area / poly_in.area
473485
except TopologicalError:
@@ -725,7 +737,7 @@ def tracking(self, other):
725737

726738
logging.debug('%d matched with previous', i_self.shape[0])
727739

728-
return i_self, i_other
740+
return i_self, i_other, cost_mat[i_self, i_other]
729741

730742
def to_netcdf(self, handler):
731743
eddy_size = len(self)
@@ -858,7 +870,7 @@ def filled_by_interpolation(self, mask):
858870

859871
for field in self.obs.dtype.descr:
860872
var = field[0]
861-
if var in ['n', 'virtual', 'track'] or var in self.array_variables:
873+
if var in ['n', 'virtual', 'track', 'cost_association'] or var in self.array_variables:
862874
continue
863875
# to normalize longitude before interpolation
864876
if var== 'lon':

src/py_eddy_tracker/tracking.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@
3030
from matplotlib.dates import julian2num, num2date
3131

3232
from py_eddy_tracker.observations import EddiesObservations, VirtualEddiesObservations, TrackEddiesObservations
33-
from numpy import bool_, array, arange, ones, setdiff1d, zeros, uint16, where, empty, isin, unique, concatenate, ma
34-
from netCDF4 import Dataset
33+
from numpy import bool_, array, arange, ones, setdiff1d, zeros, uint16, where, empty, isin, unique, concatenate, \
34+
ma
35+
from netCDF4 import Dataset, default_fillvals
3536
import logging
3637
import platform
3738

@@ -55,7 +56,9 @@ def __init__(self, datasets, virtual=0, class_method=None, previous_correspondan
5556
# Correspondance dtype
5657
self.correspondance_dtype = [('in', 'u2'),
5758
('out', 'u2'),
58-
('id', self.ID_DTYPE)]
59+
('id', self.ID_DTYPE),
60+
('cost_value', 'f4')
61+
]
5962
if class_method is None:
6063
self.class_method = EddiesObservations
6164
else:
@@ -168,7 +171,7 @@ def merge_correspondance(self, other):
168171
# We set new id available
169172
self.current_id = translate[-1] + 1
170173

171-
def store_correspondance(self, i_previous, i_current, nb_real_obs):
174+
def store_correspondance(self, i_previous, i_current, nb_real_obs, association_cost):
172175
"""Storing correspondance in an array
173176
"""
174177
# Create array to store correspondance data
@@ -177,6 +180,7 @@ def store_correspondance(self, i_previous, i_current, nb_real_obs):
177180
correspondance['virtual_length'][:] = 255
178181
# index from current_obs
179182
correspondance['out'] = i_current
183+
correspondance['cost_value'] = association_cost
180184

181185
if self.virtual:
182186
# if index in previous dataset is bigger than real obs number
@@ -314,10 +318,10 @@ def track(self):
314318
if flg_virtual:
315319
logging.debug('%d virtual obs will be add to previous', len(self.virtual_obs))
316320
self.previous_obs = self.previous_obs.merge(self.virtual_obs)
317-
i_previous, i_current = self.previous_obs.tracking(self.current_obs)
321+
i_previous, i_current, association_cost = self.previous_obs.tracking(self.current_obs)
318322

319323
# return true if the first time (previous2obs is none)
320-
if self.store_correspondance(i_previous, i_current, nb_real_obs):
324+
if self.store_correspondance(i_previous, i_current, nb_real_obs, association_cost):
321325
continue
322326

323327
self.recense_dead_id_to_extend()
@@ -376,7 +380,11 @@ def save(self, filename, dict_completion=None):
376380
for name, _ in self.correspondance_dtype:
377381
datas[name][i, :nb_elt] = correspondance[name]
378382
for name, data in datas.items():
379-
h_nc.variables[name][:] = data
383+
h_v = h_nc.variables[name]
384+
h_v[:] = data
385+
if 'File' not in name:
386+
h_v.min = h_v[:].min()
387+
h_v.max = h_v[:].max()
380388

381389
h_nc.virtual_use = str(self.virtual)
382390
h_nc.virtual_max_segment = self.nb_virtual
@@ -512,6 +520,8 @@ def merge(self, until=-1):
512520
track_array_variables=self.current_obs.track_array_variables,
513521
array_variables=self.current_obs.array_variables)
514522

523+
# All the value put at nan, necessary only for all end of track
524+
eddies['cost_association'][:] = default_fillvals['f4']
515525
# Calculate the index in each tracks, we compute in u4 and translate
516526
# in u2 (which are limited to 65535)
517527
logging.debug('Compute global index array (N)')
@@ -547,6 +557,8 @@ def merge(self, until=-1):
547557
# Copy all variable
548558
for field in fields:
549559
var = field[0]
560+
if var == 'cost_association':
561+
continue
550562
eddies[var][index_final[m_first_obs]] = self.previous_obs[var][index_in]
551563
# Increment
552564
self.i_current_by_tracks[i_id[m_first_obs]] += 1
@@ -570,7 +582,10 @@ def merge(self, until=-1):
570582
# Copy all variable
571583
for field in fields:
572584
var = field[0]
573-
eddies[var][index_final] = self.current_obs[var][index_current]
585+
if var == 'cost_association':
586+
eddies[var][index_final-1] = self[i]['cost_value']
587+
else:
588+
eddies[var][index_final] = self.current_obs[var][index_current]
574589

575590
# Add increment for each index used
576591
self.i_current_by_tracks[i_id] += 1

src/scripts/EddyTracking

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ if __name__ == '__main__':
104104
DICT_COMPLETION = dict(date_start=DATE_START, date_stop=DATE_STOP, date_prod=START_TIME,
105105
path=SAVE_DIR, sign_type=CORRESPONDANCES.current_obs.sign_legend)
106106

107-
108107
CORRESPONDANCES.save(CORRESPONDANCES_OUT, DICT_COMPLETION)
109108
if SAVE_STOP:
110109
exit()

0 commit comments

Comments
 (0)