Skip to content

Commit e7fee17

Browse files
committed
Restart on last tracking
1 parent 7df1dc3 commit e7fee17

File tree

4 files changed

+196
-61
lines changed

4 files changed

+196
-61
lines changed

src/py_eddy_tracker/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,23 +123,23 @@ def parse_args(self, *args, **kwargs):
123123
),
124124
segment_size=dict(
125125
attr_name=None,
126-
nc_name=None,
126+
nc_name='segment_size',
127127
nc_type='byte',
128-
nc_dims=None,
128+
nc_dims=('Nobs',),
129129
nc_attr=dict()
130130
),
131131
dlon=dict(
132132
attr_name=None,
133-
nc_name=None,
133+
nc_name='dlon',
134134
nc_type='float64',
135-
nc_dims=None,
135+
nc_dims=('Nobs',),
136136
nc_attr=dict()
137137
),
138138
dlat=dict(
139139
attr_name=None,
140-
nc_name=None,
140+
nc_name='dlat',
141141
nc_type='float64',
142-
nc_dims=None,
142+
nc_dims=('Nobs',),
143143
nc_attr=dict()
144144
),
145145
virtual=dict(

src/py_eddy_tracker/observations.py

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ def __init__(self, size=0, track_extra_variables=None,
208208
self.active = True
209209
self.sign_type = None
210210

211+
@property
212+
def shape(self):
213+
return self.observations.shape
214+
211215
def __repr__(self):
212216
return str(self.observations)
213217

@@ -369,11 +373,26 @@ def load_from_netcdf(cls, filename):
369373
eddies.sign_type = h_nc.variables['cyc'][0]
370374
return eddies
371375

376+
@classmethod
377+
def from_netcdf(cls, handler):
378+
nb_obs = len(handler.dimensions['Nobs'])
379+
kwargs = dict()
380+
if hasattr(handler, 'track_array_variables'):
381+
kwargs['track_array_variables'] = handler.track_array_variables
382+
kwargs['array_variables'] = handler.array_variables.split(',')
383+
kwargs['track_extra_variables'] = handler.track_extra_variables.split(',')
384+
for variable in handler.variables:
385+
var_inv = VAR_DESCR_inv[variable]
386+
eddies = cls(size=nb_obs, **kwargs)
387+
for variable in handler.variables:
388+
eddies.obs[VAR_DESCR_inv[variable]] = handler.variables[variable][:]
389+
return eddies
390+
372391
@staticmethod
373392
def cost_function2(records_in, records_out, distance):
374393
nb_records = records_in.shape[0]
375394
costs = ma.empty(nb_records,dtype='f4')
376-
for i_record in xrange(nb_records):
395+
for i_record in range(nb_records):
377396
poly_in = Polygon(
378397
concatenate((
379398
(records_in[i_record]['contour_lon_e'],),
@@ -647,6 +666,55 @@ def tracking(self, other):
647666

648667
return i_self, i_other
649668

669+
def to_netcdf(self, handler):
670+
eddy_size = len(self)
671+
logging.debug('Create Dimensions "Nobs" : %d', eddy_size)
672+
handler.createDimension('Nobs', eddy_size)
673+
handler.track_extra_variables = ','.join(self.track_extra_variables)
674+
if self.track_array_variables != 0:
675+
handler.createDimension('NbSample', self.track_array_variables)
676+
handler.track_array_variables = self.track_array_variables
677+
handler.array_variables = ','.join(self.array_variables)
678+
# Iter on variables to create:
679+
for field in self.observations.dtype.descr:
680+
name = field[0]
681+
logging.debug('Create Variable %s', VAR_DESCR[name]['nc_name'])
682+
self.create_variable(
683+
handler,
684+
dict(varname=VAR_DESCR[name]['nc_name'],
685+
datatype=VAR_DESCR[name]['output_type'],
686+
dimensions=VAR_DESCR[name]['nc_dims']),
687+
VAR_DESCR[name]['nc_attr'],
688+
self.observations[name],
689+
scale_factor=VAR_DESCR[name].get('scale_factor', None),
690+
add_offset=VAR_DESCR[name].get('add_offset', None)
691+
)
692+
693+
@staticmethod
694+
def create_variable(handler_nc, kwargs_variable, attr_variable,
695+
data, scale_factor=None, add_offset=None):
696+
var = handler_nc.createVariable(
697+
zlib=True,
698+
complevel=1,
699+
**kwargs_variable)
700+
attrs = list(attr_variable.keys())
701+
attrs.sort()
702+
for attr in attrs:
703+
attr_value = attr_variable[attr]
704+
var.setncattr(attr, attr_value)
705+
if scale_factor is not None:
706+
var.scale_factor = scale_factor
707+
if add_offset is not None:
708+
var.add_offset = add_offset
709+
else:
710+
var.add_offset = 0
711+
var[:] = data
712+
try:
713+
var.setncattr('min', var[:].min())
714+
var.setncattr('max', var[:].max())
715+
except ValueError:
716+
logging.warning('Data is empty')
717+
650718

651719
class VirtualEddiesObservations(EddiesObservations):
652720
"""Class to work with virtual obs
@@ -680,6 +748,7 @@ def move_function(cls, obs_a, obs_b, out):
680748
@classmethod
681749
def forecast_move(cls, obs_a, obs_b, out):
682750
"""Forecast move of an eddy
751+
work to do
683752
"""
684753
# New dead
685754
for key in obs_b.dtype.fields.keys():
@@ -709,6 +778,7 @@ def forecast_move(cls, obs_a, obs_b, out):
709778
# Count
710779
out['segment_size'][:] += 1
711780

781+
712782
class TrackEddiesObservations(EddiesObservations):
713783
"""Class to practice Tracking on observations
714784
"""
@@ -760,37 +830,13 @@ def elements(self):
760830
elements.extend(['track', 'n', 'virtual'])
761831
return elements
762832

763-
@staticmethod
764-
def create_variable(handler_nc, kwargs_variable, attr_variable,
765-
data, scale_factor=None, add_offset=None):
766-
var = handler_nc.createVariable(
767-
zlib=True,
768-
complevel=1,
769-
**kwargs_variable)
770-
attrs = list(attr_variable.keys())
771-
attrs.sort()
772-
for attr in attrs:
773-
attr_value = attr_variable[attr]
774-
var.setncattr(attr, attr_value)
775-
if scale_factor is not None:
776-
var.scale_factor = scale_factor
777-
if add_offset is not None:
778-
var.add_offset = add_offset
779-
else:
780-
var.add_offset = 0
781-
var[:] = data
782-
try:
783-
var.setncattr('min', var[:].min())
784-
var.setncattr('max', var[:].max())
785-
except ValueError:
786-
logging.warning('Data is empty')
787-
788833
def write_netcdf(self, path='./', filename='%(path)s/%(sign_type)s.nc'):
789834
"""Write a netcdf with eddy obs
790835
"""
791836
eddy_size = len(self.observations)
792837
sign_type = 'Cyclonic' if self.sign_type == -1 else 'Anticyclonic'
793838
filename = filename % dict(path=path, sign_type=sign_type)
839+
logging.info('Store in %s', filename)
794840
with Dataset(filename, 'w', format='NETCDF4') as h_nc:
795841
logging.info('Create file %s', filename)
796842
# Create dimensions

src/py_eddy_tracker/tracking.py

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
===========================================================================
2828
2929
"""
30+
from matplotlib.dates import julian2num, num2date
31+
3032
from py_eddy_tracker.observations import EddiesObservations, \
3133
VirtualEddiesObservations, TrackEddiesObservations
3234
from numpy import bool_, array, arange, ones, setdiff1d, zeros, uint16, \
@@ -47,7 +49,7 @@ class Correspondances(list):
4749
# Track limit to 65535
4850
N_DTYPE = 'u2'
4951

50-
def __init__(self, datasets, virtual=0, class_method=None):
52+
def __init__(self, datasets, virtual=0, class_method=None, previous_correspondance=None):
5153
"""Initiate tracking
5254
"""
5355
super(Correspondances, self).__init__()
@@ -59,6 +61,7 @@ def __init__(self, datasets, virtual=0, class_method=None):
5961
self.class_method = EddiesObservations
6062
else:
6163
self.class_method = class_method
64+
6265
# To count ID
6366
self.current_id = 0
6467
# To know the number maximal of link between two state
@@ -76,10 +79,17 @@ def __init__(self, datasets, virtual=0, class_method=None):
7679
self.virtual = virtual > 0
7780
self.virtual_obs = None
7881
self.previous_virtual_obs = None
82+
83+
# Correspondance to prolongate
84+
self.filename_previous_correspondance = previous_correspondance
85+
self.previous_correspondance = self.load_compatible(self.filename_previous_correspondance)
86+
7987
if self.virtual:
8088
# Add field to dtype to follow virtual observations
8189
self.correspondance_dtype += [
90+
# True if it isn't a real obs
8291
('virtual', bool_),
92+
# Length of virtual segment
8393
('virtual_length', self.VIRTUAL_DTYPE)]
8494

8595
# Array to simply merged
@@ -93,6 +103,17 @@ def reset_dataset_cache(self):
93103
self.previous_obs = None
94104
self.current_obs = None
95105

106+
@property
107+
def period(self):
108+
"""To rethink
109+
110+
Returns: period coverage by obs
111+
112+
"""
113+
date_start = num2date(julian2num(self.class_method.load_from_netcdf(self.datasets[0]).obs['time'][0] - 0.5))
114+
date_stop = num2date(julian2num(self.class_method.load_from_netcdf(self.datasets[-1]).obs['time'][0] - 0.5))
115+
return date_start, date_stop
116+
96117
def swap_dataset(self, dataset):
97118
""" Swap to next dataset
98119
"""
@@ -138,6 +159,8 @@ def store_correspondance(self, i_previous, i_current, nb_real_obs):
138159
"""
139160
# Create array to store correspondance data
140161
correspondance = array(i_previous, dtype=self.correspondance_dtype)
162+
if self.virtual:
163+
correspondance['virtual_length'][:] = 255
141164
# index from current_obs
142165
correspondance['out'] = i_current
143166

@@ -279,14 +302,33 @@ def recense_dead_id_to_extend(self):
279302
# Count
280303
self.virtual_obs['segment_size'][:] += 1
281304

305+
def load_state(self):
306+
# If we have a previous file of correspondance, we will replay only recent part
307+
if self.previous_correspondance is not None:
308+
first_dataset = len(self.previous_correspondance.datasets)
309+
for correspondance in self.previous_correspondance[:first_dataset]:
310+
self.append(correspondance)
311+
self.current_obs = self.class_method.load_from_netcdf(self.datasets[first_dataset - 2])
312+
flg_virtual = self.previous_correspondance.virtual
313+
with Dataset(self.filename_previous_correspondance) as general_handler:
314+
self.current_id = general_handler.last_current_id
315+
# Load last virtual obs
316+
self.virtual_obs = VirtualEddiesObservations.from_netcdf(general_handler.groups['LastVirtualObs'])
317+
# Load and last previous virtual obs to be merge with current => will be previous2_obs
318+
self.current_obs = self.current_obs.merge(
319+
VirtualEddiesObservations.from_netcdf(general_handler.groups['LastPreviousVirtualObs']))
320+
return first_dataset, flg_virtual
321+
return 1, False
322+
282323
def track(self):
283324
"""Run tracking
284325
"""
285-
flg_virtual = False
286326
self.reset_dataset_cache()
287-
self.swap_dataset(self.datasets[0])
327+
first_dataset, flg_virtual = self.load_state()
328+
329+
self.swap_dataset(self.datasets[first_dataset - 1])
288330
# We begin with second file, first one is in previous
289-
for i, file_name in enumerate(self.datasets[1:]):
331+
for file_name in self.datasets[first_dataset:]:
290332
self.swap_dataset(file_name)
291333
logging.debug('%s match with previous state', file_name)
292334
logging.debug('%d obs to match', len(self.current_obs))
@@ -295,13 +337,11 @@ def track(self):
295337
if flg_virtual:
296338
logging.debug('%d virtual obs will be add to previous',
297339
len(self.virtual_obs))
298-
# If you comment this the virtual fonctionnality will be
299-
# disable
300340
self.previous_obs = self.previous_obs.merge(self.virtual_obs)
301-
302341
i_previous, i_current = self.previous_obs.tracking(
303342
self.current_obs)
304343

344+
# return true if the first time (previous2obs is none)
305345
if self.store_correspondance(i_previous, i_current, nb_real_obs):
306346
continue
307347

@@ -310,9 +350,11 @@ def track(self):
310350
if self.virtual:
311351
flg_virtual = True
312352

313-
def save(self, filename):
353+
def save(self, filename, dict_completion=None):
314354
self.prepare_merging()
315355
nb_step = len(self.datasets) - 1
356+
if isinstance(dict_completion, dict):
357+
filename = filename.format(**dict_completion)
316358
logging.info('Create correspondance file %s', filename)
317359
with Dataset(filename, 'w', format='NETCDF4') as h_nc:
318360
# Create dimensions
@@ -337,19 +379,46 @@ def save(self, filename):
337379

338380
for name, dtype in self.correspondance_dtype:
339381
if dtype is bool_:
340-
dtype = 'byte'
382+
dtype = 'u1'
383+
kwargs_cv = dict()
384+
if 'u1' in dtype:
385+
kwargs_cv['fill_value'] = 255,
341386
h_nc.createVariable(zlib=True,
342387
complevel=1,
343388
varname=name,
344389
datatype=dtype,
345-
dimensions=('Nstep', 'Nlink'))
390+
dimensions=('Nstep', 'Nlink'),
391+
**kwargs_cv
392+
)
346393

347394
for i, correspondance in enumerate(self):
348395
nb_elt = correspondance.shape[0]
349396
var_nb_link[i] = nb_elt
350397
for name, _ in self.correspondance_dtype:
351398
h_nc.variables[name][i, :nb_elt] = correspondance[name]
352-
h_nc.virtual = int(self.virtual)
399+
h_nc.virtual_use = str(self.virtual)
400+
h_nc.virtual_max_segment = self.nb_virtual
401+
h_nc.last_current_id = self.current_id
402+
if self.virtual_obs is not None:
403+
group = h_nc.createGroup('LastVirtualObs')
404+
self.virtual_obs.to_netcdf(group)
405+
group = h_nc.createGroup('LastPreviousVirtualObs')
406+
self.previous_virtual_obs.to_netcdf(group)
407+
h_nc.module = self.class_method.__module__
408+
h_nc.classname = self.class_method.__qualname__
409+
410+
def load_compatible(self, filename):
411+
if filename is None:
412+
return None
413+
previous_correspondance = Correspondances.load(filename)
414+
if self.nb_virtual != previous_correspondance.nb_virtual:
415+
raise Exception('File of correspondance IN contains a different virtual segment size : file(%d), yaml(%d)' %
416+
(previous_correspondance.nb_virtual, self.nb_virtual))
417+
418+
if self.class_method != previous_correspondance.class_method:
419+
raise Exception('File of correspondance IN contains a different class method: file(%s), yaml(%s)' %
420+
(previous_correspondance.class_method, self.class_method))
421+
return previous_correspondance
353422

354423
@classmethod
355424
def load(cls, filename):
@@ -358,7 +427,11 @@ def load(cls, filename):
358427
datasets = list(h_nc.variables['FileIn'][:])
359428
datasets.append(h_nc.variables['FileOut'][-1])
360429

361-
obj = cls(datasets, h_nc.virtual)
430+
if hasattr(h_nc, 'module'):
431+
class_method= getattr(__import__(h_nc.module, globals(), locals(), h_nc.classname), h_nc.classname)
432+
else:
433+
class_method= None
434+
obj = cls(datasets, h_nc.virtual_max_segment, class_method=class_method)
362435

363436
id_max = 0
364437
for i, nb_elt in enumerate(h_nc.variables['nb_link'][:]):
@@ -371,7 +444,8 @@ def load(cls, filename):
371444
for name, _ in obj.correspondance_dtype:
372445
if name == 'in':
373446
continue
374-
correspondance[name] = h_nc.variables[name][i, :nb_elt]
447+
if name == 'virtual_length':
448+
correspondance[name] = 255
375449
correspondance[name] = h_nc.variables[name][i, :nb_elt]
376450
id_max = max(id_max, correspondance['id'].max())
377451
obj.append(correspondance)

0 commit comments

Comments
 (0)