Skip to content

Commit b6e5cd8

Browse files
committed
Filter shorter tracks out of the correspondances to use less memory
1 parent 424fbfb commit b6e5cd8

File tree

3 files changed

+62
-104
lines changed

3 files changed

+62
-104
lines changed

src/py_eddy_tracker/tracking.py

Lines changed: 50 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,8 @@
2929
"""
3030
from matplotlib.dates import julian2num, num2date
3131

32-
from py_eddy_tracker.observations import EddiesObservations, \
33-
VirtualEddiesObservations, TrackEddiesObservations
34-
from numpy import bool_, array, arange, ones, setdiff1d, zeros, uint16, \
35-
where, empty
32+
from py_eddy_tracker.observations import EddiesObservations, VirtualEddiesObservations, TrackEddiesObservations
33+
from numpy import bool_, array, arange, ones, setdiff1d, zeros, uint16, where, empty, isin
3634
from netCDF4 import Dataset
3735
import logging
3836

@@ -237,17 +235,13 @@ def recense_dead_id_to_extend(self):
237235
# get id already dead from few time
238236
nb_virtual_extend = 0
239237
if self.virtual_obs is not None:
240-
virtual_dead_id = setdiff1d(self.virtual_obs['track'],
241-
self[-1]['id'])
238+
virtual_dead_id = setdiff1d(self.virtual_obs['track'], self[-1]['id'])
242239
list_previous_virtual_id = self.virtual_obs['track'].tolist()
243-
i_virtual_dead_id = [
244-
list_previous_virtual_id.index(i) for i in virtual_dead_id]
240+
i_virtual_dead_id = [list_previous_virtual_id.index(i) for i in virtual_dead_id]
245241
# Virtual obs which can be prolongate
246-
alive_virtual_obs = self.virtual_obs['segment_size'
247-
][i_virtual_dead_id] < self.nb_virtual
242+
alive_virtual_obs = self.virtual_obs['segment_size'][i_virtual_dead_id] < self.nb_virtual
248243
nb_virtual_extend = alive_virtual_obs.sum()
249-
logging.debug('%d virtual obs will be prolongate on the '
250-
'next step', nb_virtual_extend)
244+
logging.debug('%d virtual obs will be prolongate on the next step', nb_virtual_extend)
251245

252246
# Save previous state to count virtual obs
253247
self.previous_virtual_obs = self.virtual_obs
@@ -270,35 +264,27 @@ def recense_dead_id_to_extend(self):
270264
# Position N-1 : B
271265
# Virtual Position : C
272266
# New position C = B + AB
273-
for key in obs_b.dtype.fields.keys():
274-
if key in ['lon', 'lat', 'time', 'track', 'segment_size',
275-
'dlon', 'dlat'] or 'contour_' in key:
267+
for key in self.previous_obs.elements:
268+
if key in ['lon', 'lat', 'time'] or 'contour_' in key:
276269
continue
277270
self.virtual_obs[key][:nb_dead] = obs_b[key]
278271
self.virtual_obs['dlon'][:nb_dead] = obs_b['lon'] - obs_a['lon']
279272
self.virtual_obs['dlat'][:nb_dead] = obs_b['lat'] - obs_a['lat']
280-
self.virtual_obs['lon'][:nb_dead
281-
] = obs_b['lon'] + self.virtual_obs['dlon'][:nb_dead]
282-
self.virtual_obs['lat'][:nb_dead
283-
] = obs_b['lat'] + self.virtual_obs['dlat'][:nb_dead]
273+
self.virtual_obs['lon'][:nb_dead] = obs_b['lon'] + self.virtual_obs['dlon'][:nb_dead]
274+
self.virtual_obs['lat'][:nb_dead] = obs_b['lat'] + self.virtual_obs['dlat'][:nb_dead]
284275
# Id which are extended
285276
self.virtual_obs['track'][:nb_dead] = dead_id
286277
# Add previous virtual
287278
if nb_virtual_extend > 0:
288-
obs_to_extend = self.previous_virtual_obs.obs[i_virtual_dead_id
289-
][alive_virtual_obs]
290-
for key in obs_b.dtype.fields.keys():
291-
if key in ['lon', 'lat', 'time', 'track', 'segment_size',
292-
'dlon', 'dlat'] or 'contour_' in key:
279+
obs_to_extend = self.previous_virtual_obs.obs[i_virtual_dead_id][alive_virtual_obs]
280+
for key in self.virtual_obs.elements:
281+
if key in ['lon', 'lat', 'time', 'track', 'segment_size'] or 'contour_' in key:
293282
continue
294283
self.virtual_obs[key][nb_dead:] = obs_to_extend[key]
295-
self.virtual_obs['lon'][nb_dead:
296-
] = obs_to_extend['lon'] + obs_to_extend['dlon']
297-
self.virtual_obs['lat'][nb_dead:
298-
] = obs_to_extend['lat'] + obs_to_extend['dlat']
284+
self.virtual_obs['lon'][nb_dead:] = obs_to_extend['lon'] + obs_to_extend['dlon']
285+
self.virtual_obs['lat'][nb_dead:] = obs_to_extend['lat'] + obs_to_extend['dlat']
299286
self.virtual_obs['track'][nb_dead:] = obs_to_extend['track']
300-
self.virtual_obs['segment_size'][nb_dead:
301-
] = obs_to_extend['segment_size']
287+
self.virtual_obs['segment_size'][nb_dead:] = obs_to_extend['segment_size']
302288
# Count
303289
self.virtual_obs['segment_size'][:] += 1
304290

@@ -335,11 +321,9 @@ def track(self):
335321

336322
nb_real_obs = len(self.previous_obs)
337323
if flg_virtual:
338-
logging.debug('%d virtual obs will be add to previous',
339-
len(self.virtual_obs))
324+
logging.debug('%d virtual obs will be add to previous', len(self.virtual_obs))
340325
self.previous_obs = self.previous_obs.merge(self.virtual_obs)
341-
i_previous, i_current = self.previous_obs.tracking(
342-
self.current_obs)
326+
i_previous, i_current = self.previous_obs.tracking(self.current_obs)
343327

344328
# return true if the first time (previous2obs is none)
345329
if self.store_correspondance(i_previous, i_current, nb_real_obs):
@@ -455,59 +439,63 @@ def load(cls, filename):
455439
def prepare_merging(self):
456440
# count obs by tracks (we add directly one, because correspondance
457441
# is an interval)
458-
self.nb_obs_by_tracks = zeros(self.current_id, dtype=self.N_DTYPE) + 1
442+
self.nb_obs_by_tracks = ones(self.current_id, dtype=self.N_DTYPE)
459443
for correspondance in self:
460444
self.nb_obs_by_tracks[correspondance['id']] += 1
461445
if self.virtual:
462446
# When start is virtual, we don't have a previous
463447
# correspondance
464-
self.nb_obs_by_tracks[
465-
correspondance['id'][correspondance['virtual']]
466-
] += correspondance['virtual_length'][
467-
correspondance['virtual']]
448+
self.nb_obs_by_tracks[correspondance['id'][correspondance['virtual']]
449+
] += correspondance['virtual_length'][correspondance['virtual']]
468450

469451
# Compute index of each tracks
470-
self.i_current_by_tracks = \
471-
self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks
452+
self.i_current_by_tracks = self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks
472453
# Number of global obs
473454
self.nb_obs = self.nb_obs_by_tracks.sum()
474455
logging.info('%d tracks identified', self.current_id)
475456
logging.info('%d observations will be join', self.nb_obs)
476457

477-
def merge(self, until=-1):
458+
def merge(self, until=-1, size_min=None):
478459
"""Merge all the correspondance in one array with all fields
479460
"""
480461
# Start loading identification again to save in the finals tracks
481462
# Load first file
463+
self.reset_dataset_cache()
482464
self.swap_dataset(self.datasets[0])
483465

484466
# Start create netcdf to agglomerate all eddy
485467
logging.debug('We will create an array (size %d)', self.nb_obs)
468+
i_keep_track = slice(None)
469+
if size_min is not None:
470+
i_keep_track = where(self.nb_obs_by_tracks >= size_min)
471+
self.nb_obs_by_tracks = self.nb_obs_by_tracks[i_keep_track]
472+
self.i_current_by_tracks[i_keep_track] = self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks
473+
self.nb_obs = self.nb_obs_by_tracks.sum()
474+
# ??
475+
self.current_id = self.nb_obs_by_tracks.shape[0]
486476
eddies = TrackEddiesObservations(
487477
size=self.nb_obs,
488478
track_extra_variables=self.current_obs.track_extra_variables,
489479
track_array_variables=self.current_obs.track_array_variables,
490-
array_variables=self.current_obs.array_variables,
491-
)
480+
array_variables=self.current_obs.array_variables)
492481

493482
# Calculate the index in each tracks, we compute in u4 and translate
494483
# in u2 (which are limited to 65535)
495484
logging.debug('Compute global index array (N)')
496485
eddies['n'][:] = uint16(
497-
arange(self.nb_obs, dtype='u4')
498-
- self.i_current_by_tracks.repeat(self.nb_obs_by_tracks))
486+
arange(self.nb_obs, dtype='u4') - self.i_current_by_tracks[i_keep_track].repeat(self.nb_obs_by_tracks))
499487
logging.debug('Compute global track array')
500-
eddies['track'][:] = arange(self.current_id
501-
).repeat(self.nb_obs_by_tracks)
488+
eddies['track'][:] = arange(self.current_id).repeat(self.nb_obs_by_tracks)
489+
if size_min is not None:
490+
eddies['track'][:] += 1
502491

503492
# Set type of eddy with first file
504493
eddies.sign_type = self.current_obs.sign_type
505494
# Fields to copy
506495
fields = self.current_obs.obs.dtype.descr
507496

508497
# To know if the track start
509-
first_obs_save_in_tracks = zeros(self.i_current_by_tracks.shape,
510-
dtype=bool_)
498+
first_obs_save_in_tracks = zeros(self.i_current_by_tracks.shape, dtype=bool_)
511499

512500
for i, file_name in enumerate(self.datasets[1:]):
513501
if until != -1 and i >= until:
@@ -517,19 +505,23 @@ def merge(self, until=-1):
517505
self.swap_dataset(file_name)
518506
# We select the list of id which are involve in the correspondance
519507
i_id = self[i]['id']
520-
# Index where we will write in the final object
508+
if size_min is not None:
509+
m_id = isin(i_id, i_keep_track)
510+
i_id= i_id[m_id]
511+
else:
512+
m_id = slice(None)
513+
# Index where we will write in the final object
521514
index_final = self.i_current_by_tracks[i_id]
522515

523516
# First obs of eddies
524517
m_first_obs = ~first_obs_save_in_tracks[i_id]
525518
if m_first_obs.any():
526519
# Index in the previous file
527-
index_in = self[i]['in'][m_first_obs]
520+
index_in = self[i]['in'][m_id][m_first_obs]
528521
# Copy all variable
529522
for field in fields:
530523
var = field[0]
531-
eddies[var][index_final[m_first_obs]
532-
] = self.previous_obs[var][index_in]
524+
eddies[var][index_final[m_first_obs]] = self.previous_obs[var][index_in]
533525
# Increment
534526
self.i_current_by_tracks[i_id[m_first_obs]] += 1
535527
# Active this flag, we have only one first by tracks
@@ -539,22 +531,20 @@ def merge(self, until=-1):
539531
if self.virtual:
540532
# If the flag virtual in correspondance is active,
541533
# the previous is virtual
542-
m_virtual = self[i]['virtual']
534+
m_virtual = self[i]['virtual'][m_id]
543535
if m_virtual.any():
544536
# Incrementing index
545-
self.i_current_by_tracks[i_id[m_virtual]
546-
] += self[i]['virtual_length'][m_virtual]
537+
self.i_current_by_tracks[i_id[m_virtual]] += self[i]['virtual_length'][m_id][m_virtual]
547538
# Get new index
548539
index_final = self.i_current_by_tracks[i_id]
549540

550541
# Index in the current file
551-
index_current = self[i]['out']
542+
index_current = self[i]['out'][m_id]
552543

553544
# Copy all variable
554545
for field in fields:
555546
var = field[0]
556-
eddies[var][index_final
557-
] = self.current_obs[var][index_current]
547+
eddies[var][index_final] = self.current_obs[var][index_current]
558548

559549
# Add increment for each index used
560550
self.i_current_by_tracks[i_id] += 1

src/scripts/EddyFinalTracking

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,35 +46,19 @@ if __name__ == '__main__':
4646

4747
CORRESPONDANCES.prepare_merging()
4848

49-
FINAL_EDDIES = CORRESPONDANCES.merge()
49+
logging.info('The longest tracks have %d observations', CORRESPONDANCES.nb_obs_by_tracks.max())
50+
logging.info('The mean length is %d observations before filtering', CORRESPONDANCES.nb_obs_by_tracks.mean())
51+
FINAL_EDDIES = CORRESPONDANCES.merge(size_min=CONFIG.nb_obs_min)
5052

5153
# We flag obs
5254
if CORRESPONDANCES.virtual:
5355
FINAL_EDDIES['virtual'][:] = FINAL_EDDIES['time'] == 0
54-
5556
FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES['virtual'] == 1)
5657

5758
FULL_TIME = dt.datetime.now() - START_TIME
5859
logging.info('Duration : %s', FULL_TIME)
5960

60-
logging.info('The longest tracks have %d observations',
61-
CORRESPONDANCES.nb_obs_by_tracks.max())
62-
logging.info('The mean length is %d observations before filtering',
63-
CORRESPONDANCES.nb_obs_by_tracks.mean())
64-
65-
SUBSET_EDDIES = FINAL_EDDIES.extract_longer_eddies(
66-
CONFIG.nb_obs_min,
67-
CORRESPONDANCES.nb_obs_by_tracks.repeat(
68-
CORRESPONDANCES.nb_obs_by_tracks)
69-
)
70-
71-
logging.info('%d tracks will be saved',
72-
len(unique(SUBSET_EDDIES['track'])))
73-
74-
logging.info(
75-
'The mean length is %d observations after filtering',
76-
CORRESPONDANCES.nb_obs_by_tracks[
77-
CORRESPONDANCES.nb_obs_by_tracks > CONFIG.nb_obs_min
78-
].mean())
61+
logging.info('%d tracks will be saved', len(unique(FINAL_EDDIES['track'])))
62+
logging.info('The mean length is %d observations after filtering', CORRESPONDANCES.nb_obs_by_tracks.mean())
7963

80-
SUBSET_EDDIES.write_netcdf(path=CONFIG.path_out)
64+
FINAL_EDDIES.write_netcdf(path=CONFIG.path_out)

src/scripts/EddyTracking

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,13 @@ if __name__ == '__main__':
9292
NB_OBS_MIN = int(CONFIG['TRACK_DURATION_MIN'])
9393
CORRESPONDANCES.prepare_merging()
9494

95-
FINAL_EDDIES = CORRESPONDANCES.merge()
95+
logging.info('The longest tracks have %d observations', CORRESPONDANCES.nb_obs_by_tracks.max())
96+
logging.info('The mean length is %d observations before filtering', CORRESPONDANCES.nb_obs_by_tracks.mean())
97+
FINAL_EDDIES = CORRESPONDANCES.merge(size_min=NB_OBS_MIN)
9698

9799
# We flag obs
98100
if CORRESPONDANCES.virtual:
99101
FINAL_EDDIES['virtual'][:] = FINAL_EDDIES['time'] == 0
100-
101102
FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES['virtual'] == 1)
102103

103104
# Total running time
@@ -106,24 +107,7 @@ if __name__ == '__main__':
106107
FULL_TIME / (len(FILENAMES) - 1))
107108
logging.info('Duration : %s', FULL_TIME)
108109

109-
logging.info('The longest tracks have %d observations',
110-
CORRESPONDANCES.nb_obs_by_tracks.max())
111-
logging.info('The mean length is %d observations before filtering',
112-
CORRESPONDANCES.nb_obs_by_tracks.mean())
113-
114-
SUBSET_EDDIES = FINAL_EDDIES.extract_longer_eddies(
115-
NB_OBS_MIN,
116-
CORRESPONDANCES.nb_obs_by_tracks.repeat(
117-
CORRESPONDANCES.nb_obs_by_tracks)
118-
)
119-
120-
logging.info('%d tracks will be saved',
121-
len(unique(SUBSET_EDDIES['track'])))
122-
123-
logging.info(
124-
'The mean length is %d observations after filtering',
125-
CORRESPONDANCES.nb_obs_by_tracks[
126-
CORRESPONDANCES.nb_obs_by_tracks >= NB_OBS_MIN
127-
].mean())
110+
logging.info('%d tracks will be saved', CORRESPONDANCES.nb_obs_by_tracks.max())
111+
logging.info('The mean length is %d observations after filtering', CORRESPONDANCES.nb_obs_by_tracks.mean())
128112

129-
SUBSET_EDDIES.write_netcdf(path=SAVE_DIR)
113+
FINAL_EDDIES.write_netcdf(path=SAVE_DIR)

0 commit comments

Comments
 (0)