Skip to content

Commit 4215eac

Browse files
author
adelepoulle
committed
move action in class method
1 parent e925820 commit 4215eac

File tree

2 files changed

+98
-72
lines changed

2 files changed

+98
-72
lines changed

src/py_eddy_tracker/observations.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
===========================================================================
2929
3030
"""
31-
from numpy import zeros, empty, nan, arange, interp
31+
from numpy import zeros, empty, nan, arange, interp, where
3232
from netCDF4 import Dataset
3333
from py_eddy_tracker.tools import distance_matrix
3434
from . import VAR_DESCR, VAR_DESCR_inv
@@ -186,6 +186,15 @@ def load_from_netcdf(filename):
186186
eddies.sign_type = h_nc.variables['cyc'][0]
187187
return eddies
188188

189+
def tracking(self, other):
190+
"""Track observations from self to other
191+
"""
192+
dist = self.distance(other)
193+
i_self, i_other = where(dist < 20.)
194+
195+
logging.debug('%d match with previous', i_self.shape[0])
196+
return i_self, i_other
197+
189198

190199
class VirtualEddiesObservations(EddiesObservations):
191200

src/scripts/EddyTracking

Lines changed: 88 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,81 @@
66
from py_eddy_tracker import EddyParser
77
from glob import glob
88
from yaml import load as yaml_load
9+
from py_eddy_tracker.tracking import \
10+
Correspondance
911
from py_eddy_tracker.observations import \
1012
EddiesObservations, TrackEddiesObservations, \
1113
VirtualEddiesObservations
1214

1315
import logging
14-
import numpy as np
16+
from numpy import array, arange, bool_, uint16, unique, setdiff1d, \
17+
ones, zeros
1518
import datetime as dt
1619

1720

18-
D2R = 0.017453292519943295
1921
UINT32_MAX = 4294967295
2022
UINT16_MAX = 65535
23+
# ID limit to 4294967295
24+
ID_DTYPE = 'u4'
25+
# Track limit to 65535
26+
N_DTYPE = 'u2'
27+
# Prolongation limit to 255
28+
VIRTUAL_DTYPE = 'u1'
2129

2230

23-
if __name__ == '__main__':
31+
def usage():
32+
"""Usage
33+
"""
2434
# Run using:
25-
PARSER = EddyParser(
35+
parser = EddyParser(
2636
"Tool to use identification step to compute tracking")
27-
PARSER.add_argument('yaml_file',
37+
parser.add_argument('yaml_file',
2838
help='Yaml file to configure py-eddy-tracker')
29-
YAML_FILE = PARSER.parse_args().yaml_file
39+
yaml_file = parser.parse_args().yaml_file
3040

3141
# Read yaml configuration file
32-
with open(YAML_FILE, 'r') as stream:
33-
CONFIG = yaml_load(stream)
42+
with open(yaml_file, 'r') as stream:
43+
config = yaml_load(stream)
44+
return config
45+
46+
47+
if __name__ == '__main__':
48+
CONFIG = usage()
3449

3550
NB_OBS_MIN = int(CONFIG['TRACK_DURATION_MIN'])
3651
NB_VIRTUAL_OBS_MAX_BY_SEGMENT = int(CONFIG['VIRTUAL_LEGNTH_MAX'])
52+
53+
ACTIVE_VIRTUAL = NB_VIRTUAL_OBS_MAX_BY_SEGMENT > 0
3754

3855
PATTERN = CONFIG['PATHS']['FILES_PATTERN']
3956
FILENAMES = glob(PATTERN)
4057
FILENAMES.sort()
4158

42-
e_previous = EddiesObservations.load_from_netcdf(FILENAMES[0])
43-
4459
START_TIME = dt.datetime.now()
4560
logging.info('Start tracking on %d files', len(FILENAMES))
61+
4662
# To count id tracks
4763
CURRENT_ID = 0
64+
# Will contain all correspondance between each step
4865
CORRESPONDANCES = []
4966
START = True
5067
FLG_VIRTUAL = False
68+
69+
# Dtype for correspondance
70+
STANDARD_DTYPE = [
71+
('in', 'u2'),
72+
('out', 'u2'),
73+
('id', ID_DTYPE)]
74+
75+
if ACTIVE_VIRTUAL:
76+
STANDARD_DTYPE += [
77+
('virtual', bool_),
78+
('virtual_length', VIRTUAL_DTYPE)]
79+
80+
# Init
81+
e_previous = EddiesObservations.load_from_netcdf(FILENAMES[0])
5182

83+
# We begin with second file, first one is in previous
5284
for file_name in FILENAMES[1:]:
5385
logging.debug('%s match with previous state', file_name)
5486
e_current = EddiesObservations.load_from_netcdf(file_name)
@@ -60,35 +92,29 @@ if __name__ == '__main__':
6092
len(virtual_obs))
6193
# If you comment this out, the virtual functionality will be disabled
6294
e_previous = e_previous.merge(virtual_obs)
63-
dist = e_previous.distance(e_current)
64-
i_previous, i_current = np.where(dist < 20.)
95+
96+
i_previous, i_current = e_previous.tracking(e_current)
6597
nb_match = i_previous.shape[0]
6698

67-
logging.debug('%d match with previous', nb_match)
68-
correspondance = np.array(
69-
i_previous,
70-
dtype=[
71-
('in', 'u2'),
72-
('out', 'u2'),
73-
('id', 'u4'),
74-
('virtual', np.bool_),
75-
('virtual_length', 'u1')])
99+
#~ Correspondance()
100+
correspondance = array(i_previous, dtype=STANDARD_DTYPE)
76101
correspondance['out'] = i_current
77102

78-
if FLG_VIRTUAL:
79-
correspondance['virtual'] = i_previous >= nb_real_obs
80-
else:
81-
correspondance['virtual'] = False
103+
if ACTIVE_VIRTUAL:
104+
if FLG_VIRTUAL:
105+
correspondance['virtual'] = i_previous >= nb_real_obs
106+
else:
107+
correspondance['virtual'] = False
82108

83109
if START:
84110
START = False
85111
# Set an id for each match
86-
correspondance['id'] = np.arange(nb_match)
112+
correspondance['id'] = arange(nb_match)
87113
# Set counter
88114
CURRENT_ID += nb_match
89115
else:
90116
# We set all id to UINT32_MAX
91-
id_previous = np.ones(len(e_previous), dtype='u4') * UINT32_MAX
117+
id_previous = ones(len(e_previous), dtype=ID_DTYPE) * UINT32_MAX
92118
# We get the old id of previously tracked eddies
93119
previous_id = CORRESPONDANCES[-1]['id']
94120
id_previous[CORRESPONDANCES[-1]['out']] = previous_id
@@ -110,7 +136,7 @@ if __name__ == '__main__':
110136
if FLG_VIRTUAL:
111137
# Save previous state to count virtual obs
112138
previous_virtual_obs = virtual_obs
113-
virtual_dead_id = np.setdiff1d(virtual_obs.obs['track'],
139+
virtual_dead_id = setdiff1d(virtual_obs.obs['track'],
114140
correspondance['id'])
115141
list_previous_virtual_id = virtual_obs.obs['track'].tolist()
116142
i_virtual_dead_id = [
@@ -122,7 +148,7 @@ if __name__ == '__main__':
122148
'next step', nb_virtual_prolongate)
123149

124150
# List previous ids which are not used in the next step
125-
dead_id = np.setdiff1d(previous_id, correspondance['id'])
151+
dead_id = setdiff1d(previous_id, correspondance['id'])
126152
nb_dead_track = len(dead_id)
127153
logging.debug('%d death of real obs in this step', nb_dead_track)
128154
# Creation of a virtual step for dead ones
@@ -163,7 +189,7 @@ if __name__ == '__main__':
163189
] = obs_to_prolongate['segment_size']
164190
# Count
165191
virtual_obs.obs['segment_size'] += 1
166-
if NB_VIRTUAL_OBS_MAX_BY_SEGMENT > 0:
192+
if ACTIVE_VIRTUAL:
167193
FLG_VIRTUAL = True
168194
# END
169195

@@ -173,7 +199,7 @@ if __name__ == '__main__':
173199
nb_new_tracks = mask_new_id.sum()
174200
logging.debug('%d birth in this step', nb_new_tracks)
175201
# Set new id
176-
correspondance['id'][mask_new_id] = np.arange(
202+
correspondance['id'][mask_new_id] = arange(
177203
CURRENT_ID, CURRENT_ID + nb_new_tracks)
178204
# Set counter
179205
CURRENT_ID += nb_new_tracks
@@ -186,12 +212,13 @@ if __name__ == '__main__':
186212
logging.info('Track finish')
187213
logging.info('Start merging')
188214
# count obs by tracks
189-
nb_obs_by_tracks = np.zeros(CURRENT_ID, dtype='u2') + 1
215+
nb_obs_by_tracks = zeros(CURRENT_ID, dtype=N_DTYPE) + 1
190216
for correspondance in CORRESPONDANCES:
191217
nb_obs_by_tracks[correspondance['id']] += 1
192-
# When start is virtual, we don't have a previous correspondance
193-
nb_obs_by_tracks[correspondance['id'][correspondance['virtual']]
194-
] += correspondance['virtual_length'][correspondance['virtual']]
218+
if ACTIVE_VIRTUAL:
219+
# When start is virtual, we don't have a previous correspondance
220+
nb_obs_by_tracks[correspondance['id'][correspondance['virtual']]
221+
] += correspondance['virtual_length'][correspondance['virtual']]
195222

196223
# Compute index of each tracks
197224
i_current_by_tracks = nb_obs_by_tracks.cumsum() - nb_obs_by_tracks
@@ -205,11 +232,11 @@ if __name__ == '__main__':
205232
# Calculate the index in each track, we compute in u4 and translate
206233
# in u2 (which are limited to 65535)
207234
logging.debug('Compute global index array (N)')
208-
n = np.arange(nb_obs,
209-
dtype='u4') - i_current_by_tracks.repeat(nb_obs_by_tracks)
210-
FINAL_EDDIES.obs['n'] = np.uint16(n)
235+
n = arange(nb_obs,
236+
dtype='u4') - i_current_by_tracks.repeat(nb_obs_by_tracks)
237+
FINAL_EDDIES.obs['n'] = uint16(n)
211238
logging.debug('Compute global track array')
212-
FINAL_EDDIES.obs['track'] = np.arange(CURRENT_ID).repeat(nb_obs_by_tracks)
239+
FINAL_EDDIES.obs['track'] = arange(CURRENT_ID).repeat(nb_obs_by_tracks)
213240

214241
# Start loading identification again to save in the finals tracks
215242
# Load first file
@@ -218,8 +245,8 @@ if __name__ == '__main__':
218245
FINAL_EDDIES.sign_type = eddies_previous.sign_type
219246

220247
# To know if the track start
221-
first_obs_save_in_tracks = np.zeros(i_current_by_tracks.shape,
222-
dtype=np.bool_)
248+
first_obs_save_in_tracks = zeros(i_current_by_tracks.shape,
249+
dtype=bool_)
223250

224251
for i, file_name in enumerate(FILENAMES[1:]):
225252
# Load current file (we begin with second one)
@@ -246,16 +273,18 @@ if __name__ == '__main__':
246273

247274
# Index in the current file
248275
index_current = CORRESPONDANCES[i]['out']
249-
# If the flag virtual in correspondance is active,
250-
# the previous is virtual
251-
m_virtual = CORRESPONDANCES[i]['virtual']
252-
if m_virtual.any():
253-
index_virtual = index_final[m_virtual]
254-
# Incrementing index
255-
i_current_by_tracks[i_id[m_virtual]
256-
] += CORRESPONDANCES[i]['virtual_length'][m_virtual]
257-
# Get new index
258-
index_final = i_current_by_tracks[i_id]
276+
277+
if ACTIVE_VIRTUAL:
278+
# If the flag virtual in correspondance is active,
279+
# the previous is virtual
280+
m_virtual = CORRESPONDANCES[i]['virtual']
281+
if m_virtual.any():
282+
index_virtual = index_final[m_virtual]
283+
# Incrementing index
284+
i_current_by_tracks[i_id[m_virtual]
285+
] += CORRESPONDANCES[i]['virtual_length'][m_virtual]
286+
# Get new index
287+
index_final = i_current_by_tracks[i_id]
259288

260289
# Copy all variable
261290
for var, _ in eddies_current.obs.dtype.descr:
@@ -267,28 +296,16 @@ if __name__ == '__main__':
267296
eddies_previous = eddies_current
268297

269298
# We flag obs
270-
FINAL_EDDIES.obs['virtual'] = FINAL_EDDIES.obs['time'] == 0
271-
272-
FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES.obs['virtual'] == 1)
273-
# Localization of virtual observation
274-
m_i = FINAL_EDDIES.obs['virtual'] == 1
275-
# Count virtual observations
276-
nb_virtual = FINAL_EDDIES.obs['virtual'].sum()
277-
278-
logging.info('%d obs are virtual (unobserved)', nb_virtual)
279-
logging.info('Start extrapolation of values for virtual observations')
280-
nb_obs = len(FINAL_EDDIES)
281-
index = np.arange(nb_obs)
282-
for var, _ in eddies_current.obs.dtype.descr:
283-
FINAL_EDDIES.obs[var][m_i] = np.interp(
284-
index[m_i],
285-
index[-m_i],
286-
FINAL_EDDIES.obs[var][-m_i])
299+
if ACTIVE_VIRTUAL:
300+
FINAL_EDDIES.obs['virtual'] = FINAL_EDDIES.obs['time'] == 0
301+
302+
FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES.obs['virtual'] == 1)
287303

288304
# Total running time
305+
FULL_TIME = dt.datetime.now() - START_TIME
289306
logging.info('Mean duration by loop : %s',
290-
(dt.datetime.now() - START_TIME) / (len(FILENAMES) - 1))
291-
logging.info('Duration : %s', dt.datetime.now() - START_TIME)
307+
FULL_TIME / (len(FILENAMES) - 1))
308+
logging.info('Duration : %s', FULL_TIME)
292309

293310
logging.info('The longest tracks have %d observations',
294311
nb_obs_by_tracks.max())
@@ -297,6 +314,6 @@ if __name__ == '__main__':
297314
NB_OBS_MIN, nb_obs_by_tracks.repeat(nb_obs_by_tracks))
298315

299316
logging.info('%d tracks will be saved',
300-
len(np.unique(SUBSET_EDDIES.obs['track'])))
317+
len(unique(SUBSET_EDDIES.obs['track'])))
301318

302319
SUBSET_EDDIES.write_netcdf()

0 commit comments

Comments
 (0)