Skip to content

Commit a91a226

Browse files
committed
- Add option to load the netCDF file into memory before opening it
- Use safe_load for YAML - replace some numpy operations with numba equivalents
1 parent eaef32b commit a91a226

File tree

13 files changed

+429
-368
lines changed

13 files changed

+429
-368
lines changed

examples/02_eddy_identification/pet_eddy_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def update_axes(ax, mappable=None):
9191
update_axes(ax)
9292

9393
# %%
94-
# Creteria for rejecting a contour
94+
# Criteria for rejecting a contour
9595
# 0. - Accepted (green)
9696
# 1. - Rejection for shape error (red)
9797
# 2. - Masked value within contour (blue)

notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@
235235
},
236236
"outputs": [],
237237
"source": [
238-
"ax = start_axes(\"Detected Eddies\")\na.display(ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic ({nb_obs} eddies)\", ref=-10)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic ({nb_obs} eddies)\", ref=-10)\nax.legend()\nupdate_axes(ax)"
238+
"ax = start_axes(\"Detected Eddies\")\na.display(\n ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic ({nb_obs} eddies)\", ref=-10\n)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic ({nb_obs} eddies)\", ref=-10)\nax.legend()\nupdate_axes(ax)"
239239
]
240240
},
241241
{

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
scripts=[
2929
"src/scripts/EddySubSetter",
3030
"src/scripts/EddyTranslate",
31-
"src/scripts/EddyTracking",
3231
"src/scripts/EddyFinalTracking",
3332
"src/scripts/EddyMergeCorrespondances",
3433
],
@@ -43,6 +42,7 @@
4342
"EddyFrequency = py_eddy_tracker.appli.eddies:get_frequency_grid",
4443
"EddyInfos = py_eddy_tracker.appli.eddies:display_infos",
4544
"EddyCircle = py_eddy_tracker.appli.eddies:eddies_add_circle",
45+
"EddyTracking = py_eddy_tracker.appli.eddies:eddies_tracking",
4646
# network
4747
"EddyNetworkGroup = py_eddy_tracker.appli.network:build_network",
4848
"EddyNetworkBuildPath = py_eddy_tracker.appli.network:divide_network",

src/py_eddy_tracker/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ def add_base_argument(self):
8585
help="Levels : DEBUG, INFO, WARNING," " ERROR, CRITICAL",
8686
)
8787

88+
def memory_arg(self):
89+
self.add_argument(
90+
"--memory",
91+
action="store_true",
92+
help="Load file in memory before to read with netCDF library",
93+
)
94+
8895
def parse_args(self, *args, **kwargs):
8996
logger = start_logger()
9097
# Parsing

src/py_eddy_tracker/appli/eddies.py

Lines changed: 245 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,24 @@
33
Applications on detection and tracking files
44
"""
55
import argparse
6+
import logging
7+
from datetime import datetime
8+
from glob import glob
9+
from os import mkdir
10+
from os.path import basename, dirname, exists
11+
from os.path import join as join_path
12+
from re import compile as re_compile
613

714
from netCDF4 import Dataset
15+
from numpy import bytes_, empty, unique
16+
from yaml import safe_load
817

918
from .. import EddyParser
1019
from ..observations.observation import EddiesObservations
1120
from ..observations.tracking import TrackEddiesObservations
21+
from ..tracking import Correspondances
22+
23+
logger = logging.getLogger("pet")
1224

1325

1426
def eddies_add_circle():
@@ -41,24 +53,22 @@ def merge_eddies():
4153
parser.add_argument(
4254
"--include_var", nargs="+", type=str, help="use only listed variable"
4355
)
56+
parser.memory_arg()
4457
args = parser.parse_args()
4558

4659
if args.include_var is None:
4760
with Dataset(args.filename[0]) as h:
4861
args.include_var = h.variables.keys()
4962

50-
obs = TrackEddiesObservations.load_file(
51-
args.filename[0], raw_data=True, include_vars=args.include_var
52-
)
53-
if args.add_rotation_variable:
54-
obs = obs.add_rotation_type()
55-
for filename in args.filename[1:]:
56-
other = TrackEddiesObservations.load_file(
63+
obs = list()
64+
for filename in args.filename:
65+
e = TrackEddiesObservations.load_file(
5766
filename, raw_data=True, include_vars=args.include_var
5867
)
5968
if args.add_rotation_variable:
60-
other = other.add_rotation_type()
61-
obs = obs.merge(other)
69+
e = e.add_rotation_type()
70+
obs.append(e)
71+
obs = TrackEddiesObservations.concatenate(obs)
6272
obs.write_file(filename=args.out)
6373

6474

@@ -141,3 +151,229 @@ def display_infos():
141151
)
142152
e = e.extract_with_area(area)
143153
print(e)
154+
155+
156+
def eddies_tracking():
157+
parser = EddyParser("Tool to use identification step to compute tracking")
158+
parser.add_argument("yaml_file", help="Yaml file to configure py-eddy-tracker")
159+
parser.add_argument("--correspondance_in", help="Filename of saved correspondance")
160+
parser.add_argument("--correspondance_out", help="Filename to save correspondance")
161+
parser.add_argument(
162+
"--save_correspondance_and_stop",
163+
action="store_true",
164+
help="Stop tracking after correspondance computation,"
165+
" merging can be done with EddyFinalTracking",
166+
)
167+
parser.add_argument(
168+
"--zarr", action="store_true", help="Output will be wrote in zarr"
169+
)
170+
parser.add_argument("--unraw", action="store_true", help="Load unraw data")
171+
parser.add_argument(
172+
"--blank_period",
173+
type=int,
174+
default=0,
175+
help="Nb of detection which will not use at the end of the period",
176+
)
177+
parser.memory_arg()
178+
args = parser.parse_args()
179+
180+
# Read yaml configuration file
181+
with open(args.yaml_file, "r") as stream:
182+
config = safe_load(stream)
183+
184+
if "CLASS" in config:
185+
classname = config["CLASS"]["CLASS"]
186+
obs_class = dict(
187+
class_method=getattr(
188+
__import__(config["CLASS"]["MODULE"], globals(), locals(), classname),
189+
classname,
190+
),
191+
class_kw=config["CLASS"].get("OPTIONS", dict()),
192+
)
193+
else:
194+
obs_class = dict()
195+
196+
c_in, c_out = args.correspondance_in, args.correspondance_out
197+
if c_in is None:
198+
c_in = config["PATHS"].get("CORRESPONDANCES_IN", None)
199+
y_c_out = config["PATHS"].get(
200+
"CORRESPONDANCES_OUT", "{path}/{sign_type}_correspondances.nc"
201+
)
202+
if c_out is None:
203+
c_out = y_c_out
204+
205+
# Create ouput folder if necessary
206+
save_dir = config["PATHS"].get("SAVE_DIR", None)
207+
if save_dir is not None and not exists(save_dir):
208+
mkdir(save_dir)
209+
210+
track(
211+
pattern=config["PATHS"]["FILES_PATTERN"],
212+
output_dir=save_dir,
213+
c_out=c_out,
214+
**obs_class,
215+
virtual=int(config.get("VIRTUAL_LENGTH_MAX", 0)),
216+
previous_correspondance=c_in,
217+
memory=args.memory,
218+
correspondances_only=args.save_correspondance_and_stop,
219+
raw=not args.unraw,
220+
zarr=args.zarr,
221+
nb_obs_min=int(config.get("TRACK_DURATION_MIN", 10)),
222+
blank_period=args.blank_period,
223+
)
224+
225+
226+
def browse_dataset_in(
227+
data_dir,
228+
files_model,
229+
date_regexp,
230+
date_model,
231+
start_date=None,
232+
end_date=None,
233+
sub_sampling_step=1,
234+
files=None,
235+
):
236+
pattern_regexp = re_compile(".*/" + date_regexp)
237+
if files is not None:
238+
filenames = bytes_(files)
239+
else:
240+
full_path = join_path(data_dir, files_model)
241+
logger.info("Search files : %s", full_path)
242+
filenames = bytes_(glob(full_path))
243+
244+
dataset_list = empty(
245+
len(filenames),
246+
dtype=[
247+
("filename", "S500"),
248+
("date", "datetime64[D]"),
249+
],
250+
)
251+
dataset_list["filename"] = filenames
252+
253+
logger.info("%s grids available", dataset_list.shape[0])
254+
mode_attrs = False
255+
if "(" not in date_regexp:
256+
logger.debug("Attrs date : %s", date_regexp)
257+
mode_attrs = date_regexp.strip().split(":")
258+
else:
259+
logger.debug("Pattern date : %s", date_regexp)
260+
261+
for item in dataset_list:
262+
str_date = None
263+
if mode_attrs:
264+
with Dataset(item["filename"].decode("utf-8")) as h:
265+
if len(mode_attrs) == 1:
266+
str_date = getattr(h, mode_attrs[0])
267+
else:
268+
str_date = getattr(h.variables[mode_attrs[0]], mode_attrs[1])
269+
else:
270+
result = pattern_regexp.match(str(item["filename"]))
271+
if result:
272+
str_date = result.groups()[0]
273+
274+
if str_date is not None:
275+
item["date"] = datetime.strptime(str_date, date_model).date()
276+
277+
dataset_list.sort(order=["date", "filename"])
278+
279+
steps = unique(dataset_list["date"][1:] - dataset_list["date"][:-1])
280+
if len(steps) > 1:
281+
raise Exception("Several days steps in grid dataset %s" % steps)
282+
283+
if sub_sampling_step != 1:
284+
logger.info("Grid subsampling %d", sub_sampling_step)
285+
dataset_list = dataset_list[::sub_sampling_step]
286+
287+
if start_date is not None or end_date is not None:
288+
logger.info(
289+
"Available grid from %s to %s",
290+
dataset_list[0]["date"],
291+
dataset_list[-1]["date"],
292+
)
293+
logger.info("Filtering grid by time %s, %s", start_date, end_date)
294+
mask = (dataset_list["date"] >= start_date) * (dataset_list["date"] <= end_date)
295+
296+
dataset_list = dataset_list[mask]
297+
return dataset_list
298+
299+
300+
def track(
301+
pattern,
302+
output_dir,
303+
c_out,
304+
nb_obs_min=10,
305+
raw=True,
306+
zarr=False,
307+
blank_period=0,
308+
correspondances_only=False,
309+
**kw_c,
310+
):
311+
kw = dict(date_regexp=".*_([0-9]*?).[nz].*", date_model="%Y%m%d")
312+
if isinstance(pattern, list):
313+
kw.update(dict(data_dir=None, files_model=None, files=pattern))
314+
else:
315+
kw.update(dict(data_dir=dirname(pattern), files_model=basename(pattern)))
316+
datasets = browse_dataset_in(**kw)
317+
if blank_period > 0:
318+
datasets = datasets[:-blank_period]
319+
logger.info("Last %d files will be pop", blank_period)
320+
321+
if nb_obs_min > len(datasets):
322+
raise Exception(
323+
"Input file number (%s) is shorter than TRACK_DURATION_MIN (%s)."
324+
% (len(datasets), nb_obs_min)
325+
)
326+
327+
c = Correspondances(datasets=datasets["filename"], **kw_c)
328+
c.track()
329+
logger.info("Track finish")
330+
t0, t1 = c.period
331+
kw_save = dict(
332+
date_start=t0,
333+
date_stop=t1,
334+
date_prod=datetime.now(),
335+
path=output_dir,
336+
sign_type=c.current_obs.sign_legend,
337+
)
338+
339+
c.save(c_out, kw_save)
340+
if correspondances_only:
341+
return
342+
343+
logger.info("Start merging")
344+
c.prepare_merging()
345+
logger.info("Longer track saved have %d obs", c.nb_obs_by_tracks.max())
346+
logger.info(
347+
"The mean length is %d observations for all tracks", c.nb_obs_by_tracks.mean()
348+
)
349+
350+
kw_write = dict(path=output_dir, zarr_flag=zarr)
351+
352+
c.get_unused_data(raw_data=raw).write_file(
353+
filename="%(path)s/%(sign_type)s_untracked.nc", **kw_write
354+
)
355+
356+
short_c = c._copy()
357+
short_c.shorter_than(size_max=nb_obs_min)
358+
c.longer_than(size_min=nb_obs_min)
359+
360+
long_track = c.merge(raw_data=raw)
361+
short_track = short_c.merge(raw_data=raw)
362+
363+
# We flag obs
364+
if c.virtual:
365+
long_track["virtual"][:] = long_track["time"] == 0
366+
long_track.filled_by_interpolation(long_track["virtual"] == 1)
367+
short_track["virtual"][:] = short_track["time"] == 0
368+
short_track.filled_by_interpolation(short_track["virtual"] == 1)
369+
370+
logger.info("Longer track saved have %d obs", c.nb_obs_by_tracks.max())
371+
logger.info(
372+
"The mean length is %d observations for long track",
373+
c.nb_obs_by_tracks.mean(),
374+
)
375+
376+
long_track.write_file(**kw_write)
377+
short_track.write_file(
378+
filename="%(path)s/%(sign_type)s_track_too_short.nc", **kw_write
379+
)

src/py_eddy_tracker/appli/network.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,16 @@ def build_network():
3232
action="store_true",
3333
help="Use intern contour instead of outter contour",
3434
)
35+
36+
parser.memory_arg()
3537
args = parser.parse_args()
3638

37-
n = Network(args.identification_regex, window=args.window, intern=args.intern)
39+
n = Network(
40+
args.identification_regex,
41+
window=args.window,
42+
intern=args.intern,
43+
memory=args.memory,
44+
)
3845
group = n.group_observations(minimal_area=True)
3946
n.build_dataset(group).write_file(filename=args.out)
4047

src/py_eddy_tracker/featured_tracking/area_tracker.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

3-
from numpy import ma
3+
from numba import njit
4+
from numpy import empty, ma, ones
45

56
from ..observations.observation import EddiesObservations as Model
67

@@ -29,11 +30,8 @@ def needed_variable(cls):
2930
def tracking(self, other):
3031
shape = (self.shape[0], other.shape[0])
3132
i, j, c = self.match(other, intern=False)
32-
cost_mat = ma.empty(shape, dtype="f4")
33-
cost_mat.mask = ma.ones(shape, dtype="bool")
34-
m = c > self.cmin
35-
i, j, c = i[m], j[m], c[m]
36-
cost_mat[i, j] = 1 - c
33+
cost_mat = ma.array(empty(shape, dtype="f4"), mask=ones(shape, dtype="bool"))
34+
mask_cmin(i, j, c, self.cmin, cost_mat.data, cost_mat.mask)
3735

3836
i_self, i_other = self.solve_function(cost_mat)
3937
i_self, i_other = self.post_process_link(other, i_self, i_other)
@@ -55,3 +53,13 @@ def propagate(
5553
if nb_virtual_extend > 0:
5654
virtual[key][nb_dead:] = obs_to_extend[key]
5755
return virtual
56+
57+
58+
@njit(cache=True)
59+
def mask_cmin(i, j, c, cmin, cost_mat, mask):
60+
for k in range(c.shape[0]):
61+
c_ = c[k]
62+
if c_ > cmin:
63+
i_, j_ = i[k], j[k]
64+
cost_mat[i_, j_] = 1 - c_
65+
mask[i_, j_] = False

src/py_eddy_tracker/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def build_index(groups):
8383
first_index[group - i0 + 1 : next_group - i0 + 1] = i + 1
8484
last_index = zeros(amplitude, dtype=numba_types.int_)
8585
last_index[:-1] = first_index[1:]
86-
last_index[-1] = i + 2
86+
last_index[-1] = i + 1
8787
return first_index, last_index, i0
8888

8989

0 commit comments

Comments
 (0)