From 08c2393a053ee9bffcf77957fd369ddc120a0d1a Mon Sep 17 00:00:00 2001
From: Cori Pegliasco
Date: Tue, 17 Aug 2021 12:14:07 +0200
Subject: [PATCH 1/2] - Capitalization + spelling
 - cast n.period to int for particle advection
 - code layout

---
 .../pet_eddy_detection_ACC.py                 |  13 +-
 examples/06_grid_manipulation/pet_lavd.py     |   6 +-
 examples/16_network/pet_atlas.py              |  28 +--
 examples/16_network/pet_follow_particle.py    |   4 +-
 examples/16_network/pet_relative.py           |   6 +-
 .../16_network/pet_replay_segmentation.py     |   8 +-
 examples/16_network/pet_segmentation_anim.py  |   3 +-
 src/py_eddy_tracker/__init__.py               |  15 +-
 src/py_eddy_tracker/appli/eddies.py           |  18 +-
 src/py_eddy_tracker/appli/network.py          |   4 +-
 src/py_eddy_tracker/eddy_feature.py           |  19 +-
 src/py_eddy_tracker/observations/groups.py    |  37 +++-
 src/py_eddy_tracker/observations/network.py   | 175 ++++++++++++++----
 .../observations/observation.py               |  23 ++-
 src/py_eddy_tracker/observations/tracking.py  |  33 +++-
 tests/test_grid.py                            |  10 +-
 16 files changed, 306 insertions(+), 96 deletions(-)

diff --git a/examples/02_eddy_identification/pet_eddy_detection_ACC.py b/examples/02_eddy_identification/pet_eddy_detection_ACC.py
index e6c5e381..c799a45e 100644
--- a/examples/02_eddy_identification/pet_eddy_detection_ACC.py
+++ b/examples/02_eddy_identification/pet_eddy_detection_ACC.py
@@ -65,7 +65,8 @@ def set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight="semibold"
     y_name="latitude",
     # Manual area subset
     indexs=dict(
-        latitude=slice(100 - margin, 220 + margin), longitude=slice(0, 230 + margin),
+        latitude=slice(100 - margin, 220 + margin),
+        longitude=slice(0, 230 + margin),
     ),
 )
 g_raw = RegularGridDataset(**kw_data)
@@ -187,10 +188,16 @@ def set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight="semibold"
 ax.set_ylabel("With filter")
 ax.plot(
-    a_[field][i_a] * factor, a[field][j_a] * factor, "r.", label="Anticyclonic",
+    a_[field][i_a] * factor,
+    a[field][j_a] * factor,
+    "r.",
+    label="Anticyclonic",
 )
 ax.plot(
-    c_[field][i_c] * factor, c[field][j_c] * factor, "b.", label="Cyclonic",
+    c_[field][i_c] * factor,
+    c[field][j_c] * factor,
+    "b.",
+    label="Cyclonic",
 )
 ax.set_aspect("equal"), ax.grid()
 ax.plot((0, 1000), (0, 1000), "g")

diff --git a/examples/06_grid_manipulation/pet_lavd.py b/examples/06_grid_manipulation/pet_lavd.py
index d96c0b06..e597821c 100644
--- a/examples/06_grid_manipulation/pet_lavd.py
+++ b/examples/06_grid_manipulation/pet_lavd.py
@@ -159,7 +159,11 @@ def update(i_frame):
 # Format LAVD data
 lavd = RegularGridDataset.with_array(
     coordinates=("lon", "lat"),
-    datas=dict(lavd=lavd.T, lon=x_g, lat=y_g,),
+    datas=dict(
+        lavd=lavd.T,
+        lon=x_g,
+        lat=y_g,
+    ),
     centered=True,
 )

diff --git a/examples/16_network/pet_atlas.py b/examples/16_network/pet_atlas.py
index 7f86790a..6927f169 100644
--- a/examples/16_network/pet_atlas.py
+++ b/examples/16_network/pet_atlas.py
@@ -153,33 +153,33 @@ def update_axes(ax, mappable=None):
 update_axes(ax, m).set_label("Pixel used in % all atlas")
 
 # %%
-# All Spliting
-# ------------
-# Display the occurence of spliting events
+# All splitting
+# -------------
+# Display the occurrence of splitting events
 ax = start_axes("")
-g_all_spliting = n.spliting_event().grid_count(bins)
-m = g_all_spliting.display(ax, **kw_time, vmin=0, vmax=1)
+g_all_splitting = n.splitting_event().grid_count(bins)
+m = g_all_splitting.display(ax, **kw_time, vmin=0, vmax=1)
 update_axes(ax, m).set_label("Pixel used in % of time")
 
 # %%
-# Ratio spliting events / eddy presence
+# Ratio splitting events / eddy presence
 ax = start_axes("")
-g_ = g_all_spliting.vars["count"] * 100.0 / g_all.vars["count"]
-m = g_all_spliting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)
+g_ = g_all_splitting.vars["count"] * 100.0 / g_all.vars["count"]
+m = g_all_splitting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)
 update_axes(ax, m).set_label("Pixel used in % all atlas")
 
 # %%
-# Spliting in networks longer than 10 days
-# ----------------------------------------
+# Splitting in networks longer than 10 days
+# -----------------------------------------
 ax = start_axes("")
-g_10_spliting = n10.spliting_event().grid_count(bins)
-m = g_10_spliting.display(ax, **kw_time, vmin=0, vmax=1)
+g_10_splitting = n10.splitting_event().grid_count(bins)
+m = g_10_splitting.display(ax, **kw_time, vmin=0, vmax=1)
 update_axes(ax, m).set_label("Pixel used in % of time")
 
 # %%
 ax = start_axes("")
 g_ = ma.array(
-    g_10_spliting.vars["count"] * 100.0 / g_10.vars["count"],
+    g_10_splitting.vars["count"] * 100.0 / g_10.vars["count"],
     mask=g_10.vars["count"] < 365,
 )
-m = g_10_spliting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)
+m = g_10_splitting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)
 update_axes(ax, m).set_label("Pixel used in % all atlas")

diff --git a/examples/16_network/pet_follow_particle.py b/examples/16_network/pet_follow_particle.py
index a5a252e2..1c858879 100644
--- a/examples/16_network/pet_follow_particle.py
+++ b/examples/16_network/pet_follow_particle.py
@@ -125,9 +125,11 @@ def update(frame):
 # %%
 # Particle advection
 # ^^^^^^^^^^^^^^^^^^
+# Advection from speed contour to speed contour (default)
+
 step = 1 / 60.0
 
-t_start, t_end = n.period
+t_start, t_end = int(n.period[0]), int(n.period[1])
 dt = 14
 
 shape = (n.obs.size, 2)

diff --git a/examples/16_network/pet_relative.py b/examples/16_network/pet_relative.py
index c4989edb..f5e8bc92 100644
--- a/examples/16_network/pet_relative.py
+++ b/examples/16_network/pet_relative.py
@@ -292,13 +292,13 @@
 m1
 
 # %%
-# Get spliting event
-# ------------------
+# Get splitting event
+# -------------------
 # Display the position of the eddies before a splitting
 fig = plt.figure(figsize=(15, 8))
 ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)
 n.plot(ax, color_cycle=n.COLORS)
-s0, s1, s1_start = n.spliting_event(triplet=True)
+s0, s1, s1_start = n.splitting_event(triplet=True)
 s0.display(ax, color="violet", lw=2, label="Eddies before splitting")
 s1.display(ax, color="blueviolet", lw=2, label="Eddies after splitting")
 s1_start.display(ax, color="black", lw=2, label="Eddies starting by splitting")

diff --git a/examples/16_network/pet_replay_segmentation.py b/examples/16_network/pet_replay_segmentation.py
index d6b4568b..757854d5 100644
--- a/examples/16_network/pet_replay_segmentation.py
+++ b/examples/16_network/pet_replay_segmentation.py
@@ -149,7 +149,13 @@ def get_obs(dataset):
 n_.median_filter(15, "time", "latitude")
 kw["s"] = (n_.radius_e * 1e-3) ** 2 / 30 ** 2 * 20
 m = n_.scatter_timeline(
-    ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lon", method="all",
+    ax,
+    "shape_error_e",
+    vmin=14,
+    vmax=70,
+    **kw,
+    yfield="lon",
+    method="all",
 )
 ax.set_ylabel("Longitude")
 cb = update_axes(ax, m["scatter"])

diff --git a/examples/16_network/pet_segmentation_anim.py b/examples/16_network/pet_segmentation_anim.py
index 503229e7..340163a1 100644
--- a/examples/16_network/pet_segmentation_anim.py
+++ b/examples/16_network/pet_segmentation_anim.py
@@ -96,7 +96,8 @@ def update(i_frame):
     indices_frames = INDICES[i_frame]
     mappable_CONTOUR.set_data(
-        e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames],
+        e.contour_lon_e[indices_frames],
+        e.contour_lat_e[indices_frames],
     )
     mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)])
     return (mappable_tracks,)

diff --git a/src/py_eddy_tracker/__init__.py b/src/py_eddy_tracker/__init__.py
index fbeb1450..f3ecec84 100644
--- a/src/py_eddy_tracker/__init__.py
+++ b/src/py_eddy_tracker/__init__.py
@@ -404,7 +404,7 @@ def identify_time(str_date):
         nc_dims=("obs",),
         nc_attr=dict(
             long_name="Previous observation index",
-            comment="Index of previous observation in a spliting case",
+            comment="Index of previous observation in a splitting case",
         ),
     ),
     next_obs=dict(
@@ -422,14 +422,20 @@ def identify_time(str_date):
         nc_name="previous_cost",
         nc_type="float32",
         nc_dims=("obs",),
-        nc_attr=dict(long_name="Previous cost for previous observation", comment="",),
+        nc_attr=dict(
+            long_name="Previous cost for previous observation",
+            comment="",
+        ),
     ),
     next_cost=dict(
         attr_name=None,
         nc_name="next_cost",
         nc_type="float32",
         nc_dims=("obs",),
-        nc_attr=dict(long_name="Next cost for next observation", comment="",),
+        nc_attr=dict(
+            long_name="Next cost for next observation",
+            comment="",
+        ),
     ),
     n=dict(
         attr_name=None,
@@ -640,7 +646,8 @@ def identify_time(str_date):
         nc_type="f4",
         nc_dims=("obs",),
         nc_attr=dict(
-            long_name="Log base 10 background chlorophyll", units="Log(Chl/[mg/m^3])",
+            long_name="Log base 10 background chlorophyll",
+            units="Log(Chl/[mg/m^3])",
         ),
     ),
     year=dict(

diff --git a/src/py_eddy_tracker/appli/eddies.py b/src/py_eddy_tracker/appli/eddies.py
index 4809fddf..df4e7d43 100644
--- a/src/py_eddy_tracker/appli/eddies.py
+++ b/src/py_eddy_tracker/appli/eddies.py
@@ -243,7 +243,8 @@ def browse_dataset_in(
         filenames = bytes_(glob(full_path))
 
     dataset_list = empty(
-        len(filenames), dtype=[("filename", "S500"), ("date", "datetime64[s]")],
+        len(filenames),
+        dtype=[("filename", "S500"), ("date", "datetime64[s]")],
     )
     dataset_list["filename"] = filenames
 
@@ -371,7 +372,8 @@ def track(
     logger.info("Longer track saved have %d obs", c.nb_obs_by_tracks.max())
     logger.info(
-        "The mean length is %d observations for long track", c.nb_obs_by_tracks.mean(),
+        "The mean length is %d observations for long track",
+        c.nb_obs_by_tracks.mean(),
     )
 
     long_track.write_file(**kw_write)
@@ -381,7 +383,14 @@ def track(
 
 
 def get_group(
-    dataset1, dataset2, index1, index2, score, invalid=2, low=10, high=60,
+    dataset1,
+    dataset2,
+    index1,
+    index2,
+    score,
+    invalid=2,
+    low=10,
+    high=60,
 ):
     group1, group2 = dict(), dict()
     m_valid = (score * 100) >= invalid
@@ -490,7 +499,8 @@ def get_values(v, dataset):
     ]
 
     labels = dict(
-        high=f"{high:0.0f} <= high", low=f"{invalid:0.0f} <= low < {low:0.0f}",
+        high=f"{high:0.0f} <= high",
+        low=f"{invalid:0.0f} <= low < {low:0.0f}",
     )
 
     keys = [labels.get(key, key) for key in list(gr_ref.values())[0].keys()]

diff --git a/src/py_eddy_tracker/appli/network.py b/src/py_eddy_tracker/appli/network.py
index c1a752ee..5c4cdcaf 100644
--- a/src/py_eddy_tracker/appli/network.py
+++ b/src/py_eddy_tracker/appli/network.py
@@ -76,7 +76,9 @@ def subset_network():
         help="Remove short dead end, first is for minimal obs number and second for minimal segment time to keep",
     )
     parser.add_argument(
-        "--remove_trash", action="store_true", help="Remove trash (network id == 0)",
+        "--remove_trash",
+        action="store_true",
+        help="Remove trash (network id == 0)",
     )
     parser.add_argument(
         "-p",

diff --git a/src/py_eddy_tracker/eddy_feature.py b/src/py_eddy_tracker/eddy_feature.py
index f6db848b..59a042fe 100644
--- a/src/py_eddy_tracker/eddy_feature.py
+++ b/src/py_eddy_tracker/eddy_feature.py
@@ -61,13 +61,13 @@ def __init__(
         """
         Create amplitude object
 
-        :param Contours contour:
-        :param float contour_height:
-        :param array data:
-        :param float interval:
+        :param Contours contour: useful class defined below
+        :param float contour_height: field value of the contour
+        :param array data: grid
+        :param float interval: step between two contours
         :param int mle: maximum number of local extrema in contour
-        :param int nb_step_min: number of intervals to consider an eddy
-        :param int nb_step_to_be_mle: number of intervals to be considered as an another maxima
+        :param int nb_step_min: minimum number of intervals to consider the contour as an eddy
+        :param int nb_step_to_be_mle: number of intervals to be considered as another extremum
         """
 
         # Height of the contour
@@ -116,8 +116,7 @@ def within_amplitude_limits(self):
     def all_pixels_below_h0(self, level):
         """
         Check CSS11 criterion 1: The SSH values of all of the pixels
-        are below (above) a given SSH threshold for cyclonic (anticyclonic)
-        eddies.
+        are below a given SSH threshold for cyclonic eddies.
         """
         # In some cases pixel value may be very close to the contour bounds
         if self.sla.mask.any() or ((self.sla.data - self.h_0) > self.EPSILON).any():
@@ -602,8 +601,8 @@ def display(
         4. - Amplitude criterion (yellow)
 
         :param str field: Must be 'shape_error', 'x', 'y' or 'radius'.
-            If define display_criterion is not use.
-            bins argument must be define
+            If defined, display_criterion is not used.
+            bins argument must be defined
         :param array bins: bins used to colorize contour
         :param str cmap: Name of cmap for field display
         :param dict kwargs: look at :py:meth:`matplotlib.collections.LineCollection`

diff --git a/src/py_eddy_tracker/observations/groups.py b/src/py_eddy_tracker/observations/groups.py
index 544fd5f5..6fea0ace 100644
--- a/src/py_eddy_tracker/observations/groups.py
+++ b/src/py_eddy_tracker/observations/groups.py
@@ -68,7 +68,7 @@ def get_missing_indices(
 
 def advect(x, y, c, t0, n_days):
     """
-    Advect particle from t0 to t0 + n_days, with data cube.
+    Advect particles from t0 to t0 + n_days, with data cube.
 
     :param np.array(float) x: longitude of particles
     :param np.array(float) y: latitude of particles
@@ -87,7 +87,17 @@ def advect(x, y, c, t0, n_days):
     return t, x, y
 
 
-def particle_candidate(c, eddies, step_mesh, t_start, i_target, pct, **kwargs):
+def particle_candidate(
+    c,
+    eddies,
+    step_mesh,
+    t_start,
+    i_target,
+    pct,
+    contour_start="speed",
+    contour_end="effective",
+    **kwargs
+):
     """Select particles within eddies, advect them, return target observation and associated percentages
 
     :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles
     :param int t_start: julian day of the advection
     :param np.array(int) i_target: corresponding obs where particles are advected
     :param np.array(int) pct: corresponding percentage of avected particles
+    :param str contour_start: contour where particles are injected
+    :param str contour_end: contour where particles are counted after advection
     :params dict kwargs: dict of params given to `advect`
 
     """
 
     # to be able to get global index
     translate_start = where(m_start)[0]
 
-    x, y, i_start = e.create_particles(step_mesh)
+    # Create particles in specified contour
+    if contour_start == "speed":
+        x, y, i_start = e.create_particles(step_mesh, intern=True)
+    elif contour_start == "effective":
+        x, y, i_start = e.create_particles(step_mesh, intern=False)
+    else:
+        x, y, i_start = e.create_particles(step_mesh, intern=True)
+        print("The contour_start was not correct, speed contour is used")
 
     # Advection
     t_end, x, y = advect(x, y, c, t_start, **kwargs)
 
     # to be able to get global index
     translate_end = where(m_end)[0]
 
-    # Id eddies for each alive particle (in core and extern)
-    i_end = e_end.contains(x, y)
+    # Id eddies for each alive particle in specified contour
+    if contour_end == "speed":
+        i_end = e_end.contains(x, y, intern=True)
+    elif contour_end == "effective":
+        i_end = e_end.contains(x, y, intern=False)
+    else:
+        i_end = e_end.contains(x, y, intern=True)
+        print("The contour_end was not correct, speed contour is used")
 
     # compute matrix and fill target array
     get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct)
@@ -206,7 +231,7 @@ def filled_by_interpolation(self, mask):
         )
 
     def insert_virtual(self):
-        """insert virtual observations on segments where observations are missing"""
+        """Insert virtual observations on segments where observations are missing"""
 
         dt_theorical = median(self.time[1:] - self.time[:-1])
         indices = self.get_missing_indices(dt_theorical)

diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py
index cb6d3986..0ae80634 100644
--- a/src/py_eddy_tracker/observations/network.py
+++ b/src/py_eddy_tracker/observations/network.py
@@ -23,6 +23,9 @@
     zeros,
 )
 
+import netCDF4
+import zarr
+
 from ..dataset.grid import GridCollection
 from ..generic import build_index, wrap_longitude
 from ..poly import bbox_intersection, vertice_overlap
@@ -147,7 +150,7 @@ def get_missing_indices(self, dt):
         )
 
     def fix_next_previous_obs(self):
-        """function used after 'insert_virtual', to correct next_obs and
+        """Function used after 'insert_virtual', to correct next_obs and
         previous obs.
""" @@ -577,7 +580,7 @@ def close_network(self, other, nb_obs_min=10, **kwargs): return other.extract_with_mask(m) def normalize_longitude(self): - """Normalize all longitude + """Normalize all longitudes Normalize longitude field and in the same range : - longitude_max @@ -677,7 +680,13 @@ def display_timeline( """ self.only_one_network() j = 0 - line_kw = dict(ls="-", marker="+", markersize=6, zorder=1, lw=3,) + line_kw = dict( + ls="-", + marker="+", + markersize=6, + zorder=1, + lw=3, + ) line_kw.update(kwargs) mappables = dict(lines=list()) @@ -719,7 +728,7 @@ def display_timeline( def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="roll"): """Mark events in plot""" j = 0 - events = dict(spliting=[], merging=[]) + events = dict(splitting=[], merging=[]) # TODO : fill mappables dict y_seg = dict() @@ -784,15 +793,15 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol ) ) ax.plot((x[0], _time[i_p]), (y0, y1), **event_kw)[0] - events["spliting"].append((x[0], y0)) + events["splitting"].append((x[0], y0)) j += 1 kwargs = dict(color="k", zorder=-1, linestyle=" ") - if len(events["spliting"]) > 0: - X, Y = list(zip(*events["spliting"])) + if len(events["splitting"]) > 0: + X, Y = list(zip(*events["splitting"])) ref = ax.plot( - X, Y, marker="*", markersize=12, label="spliting events", **kwargs + X, Y, marker="*", markersize=12, label="splitting events", **kwargs )[0] mappables.setdefault("events", []).append(ref) @@ -910,7 +919,10 @@ def event_map(self, ax, **kwargs): """Add the merging and splitting events to a map""" j = 0 mappables = dict() - symbol_kw = dict(markersize=10, color="k",) + symbol_kw = dict( + markersize=10, + color="k", + ) symbol_kw.update(kwargs) symbol_kw_split = symbol_kw.copy() symbol_kw_split["markersize"] += 4 @@ -939,7 +951,13 @@ def event_map(self, ax, **kwargs): return mappables def scatter( - self, ax, name="time", factor=1, ref=None, edgecolor_cycle=None, **kwargs, + self, + ax, + name="time", + factor=1, + ref=None, + edgecolor_cycle=None, + **kwargs, ): """ This function scatters the path of each network, with the merging and splitting events @@ -1001,6 +1019,8 @@ def segment_track_array(self): return build_unique_array(self.segment, self.track) def birth_event(self): + """Extract birth events. + Advice : individual eddies (self.track == 0) should be removed before -> apply remove_trash.""" # FIXME how to manage group 0 indices = list() previous_obs = self.previous_obs @@ -1014,6 +1034,8 @@ def birth_event(self): return self.extract_event(list(set(indices))) def death_event(self): + """Extract death events. + Advice : individual eddies (self.track == 0) should be removed before -> apply remove_trash.""" # FIXME how to manage group 0 indices = list() next_obs = self.next_obs @@ -1064,7 +1086,7 @@ def merging_event(self, triplet=False, only_index=False): else: return self.extract_event(idx_m1) - def spliting_event(self, triplet=False, only_index=False): + def splitting_event(self, triplet=False, only_index=False): """Return observation before a splitting event. 
         If `triplet=True` return the eddy before a splitting event,
         the eddy after the splitting event,
 
@@ -1105,7 +1127,7 @@ def splitting_event(self, triplet=False, only_index=False):
 
     def dissociate_network(self):
         """
-        Dissociate networks with no known interaction (spliting/merging)
+        Dissociate networks with no known interaction (splitting/merging)
         """
         tags = self.tag_segment(multi_network=True)
@@ -1183,7 +1205,7 @@ def fully_connected(self):
 
     def remove_trash(self):
         """
-        Remove the lonely eddies (only 1 obs in segment, associated segment number is 0)
+        Remove the lonely eddies (only 1 obs in segment, associated network number is 0)
         """
         return self.extract_with_mask(self.track != 0)
 
@@ -1372,7 +1394,7 @@ def analysis_coherence(
         date_function,
         uv_params,
         advection_mode="both",
-        dt_advect=14,
+        n_days=14,
         step_mesh=1.0 / 50,
         output_name=None,
         dissociate_network=False,
@@ -1380,7 +1402,26 @@ def analysis_coherence(
         remove_dead_end=0,
     ):
 
-        """Global function to analyse segments coherence, with network preprocessing"""
+        """Global function to analyse segments coherence, with network preprocessing.
+
+        :param callable date_function: python function, takes an `int` (julian day) as parameter
+            and returns the data filename associated to the date
+        :param dict uv_params: dict of parameters used by
+            :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list`
+        :param int n_days: number of days for advection
+        :param float step_mesh: step for particle mesh in degrees
+        :param str output_name: path/name for the output (without extension) to store the clean
+            network in .nc and the coherence results in .zarr. Works only for advection_mode = "both"
+        :param bool dissociate_network: if True apply
+            :py:meth:`~py_eddy_tracker.observation.network.NetworkObservations.dissociate_network`
+        :param int correct_close_events: number of days in
+            :py:meth:`~py_eddy_tracker.observation.network.NetworkObservations.correct_close_events`
+        :param int remove_dead_end: number of days in
+            :py:meth:`~py_eddy_tracker.observation.network.NetworkObservations.remove_dead_end`
+        :return target_forward, target_backward: 2D numpy.array with the eddy observation the
+            particles ended in after advection
+        :return pct_forward, pct_backward: percentage of ending particles within the
+            eddy observation with regards to the starting number
+        """
 
         if dissociate_network:
             self.dissociate_network()
@@ -1393,19 +1434,53 @@ def analysis_coherence(
         else:
             network_clean = self
 
-        res = network_clean.segment_coherence(
-            date_function=date_function,
-            uv_params=uv_params,
-            advection_mode=advection_mode,
-            output_name=output_name,
-            dt_advect=dt_advect,
-            step_mesh=step_mesh,
-        )
+        network_clean.numbering_segment()
+
+        res = []
+        if (advection_mode == "both") | (advection_mode == "forward"):
+            target_forward, pct_forward = network_clean.segment_coherence_forward(
+                date_function=date_function,
+                uv_params=uv_params,
+                n_days=n_days,
+                step_mesh=step_mesh,
+            )
+            res = res + [target_forward, pct_forward]
+
+        if (advection_mode == "both") | (advection_mode == "backward"):
+            target_backward, pct_backward = network_clean.segment_coherence_backward(
+                date_function=date_function,
+                uv_params=uv_params,
+                n_days=n_days,
+                step_mesh=step_mesh,
+            )
+            res = res + [target_backward, pct_backward]
+
+        if (output_name is not None) & (advection_mode == "both"):
+            # TODO : put some path verification?
+            # Save the clean network in netcdf
+            with netCDF4.Dataset(output_name + ".nc", "w") as fh:
+                network_clean.to_netcdf(fh)
+            # Save the results of particles advection in zarr
+            # zarr compression parameters
+            # TODO : check size? compression?
+            params_seg = dict()
+            params_pct = dict()
+            zg = zarr.open(output_name + ".zarr", mode="w")
+            zg.array("target_forward", target_forward, **params_seg)
+            zg.array("pct_forward", pct_forward, **params_pct)
+            zg.array("target_backward", target_backward, **params_seg)
+            zg.array("pct_backward", pct_backward, **params_pct)
 
         return network_clean, res
 
     def segment_coherence_backward(
-        self, date_function, uv_params, n_days=14, step_mesh=1.0 / 50, output_name=None,
+        self,
+        date_function,
+        uv_params,
+        n_days=14,
+        step_mesh=1.0 / 50,
+        contour_start="speed",
+        contour_end="speed",
     ):
 
         """
@@ -1434,7 +1509,7 @@ def date2file(julian_day):
         itb_final = -ones((self.obs.size, 2), dtype="i4")
         ptb_final = zeros((self.obs.size, 2), dtype="i1")
 
-        t_start, t_end = self.period
+        t_start, t_end = int(self.period[0]), int(self.period[1])
 
         dates = arange(t_start, t_start + n_days + 1)
         first_files = [date_function(x) for x in dates]
@@ -1455,17 +1530,33 @@ def date2file(julian_day):
             # add next date to GridCollection and delete last date
             c.shift_files(t_shift, date_function(int(t_shift)), **uv_params)
             particle_candidate(
-                c, self, step_mesh, _t, itb_final, ptb_final, n_days=-n_days
+                c,
+                self,
+                step_mesh,
+                _t,
+                itb_final,
+                ptb_final,
+                n_days=-n_days,
+                contour_start=contour_start,
+                contour_end=contour_end,
+            )
+            logger.info(
+                (
+                    f"coherence {_t} / {range_end-1} ({(_t - range_start) / (range_end - range_start-1):.1%})"
+                    f" : {time.time()-_timestamp:5.2f}s"
+                )
             )
-            logger.info((
-                f"coherence {_t} / {range_end-1} ({(_t - range_start) / (range_end - range_start-1):.1%})"
-                f" : {time.time()-_timestamp:5.2f}s"
-            ))
 
         return itb_final, ptb_final
 
     def segment_coherence_forward(
-        self, date_function, uv_params, n_days=14, step_mesh=1.0 / 50,
+        self,
+        date_function,
+        uv_params,
+        n_days=14,
+        step_mesh=1.0 / 50,
+        contour_start="speed",
+        contour_end="speed",
     ):
 
         """
@@ -1494,7 +1585,7 @@ def date2file(julian_day):
         itf_final = -ones((self.obs.size, 2), dtype="i4")
         ptf_final = zeros((self.obs.size, 2), dtype="i1")
 
-        t_start, t_end = self.period
+        t_start, t_end = int(self.period[0]), int(self.period[1])
         # if begin is not None and begin > t_start:
         #     t_start = begin
         # if end is not None and end < t_end:
@@ -1519,12 +1610,22 @@ def date2file(julian_day):
             # add next date to GridCollection and delete last date
             c.shift_files(t_shift, date_function(int(t_shift)), **uv_params)
             particle_candidate(
-                c, self, step_mesh, _t, itf_final, ptf_final, n_days=n_days
+                c,
+                self,
+                step_mesh,
+                _t,
+                itf_final,
+                ptf_final,
+                n_days=n_days,
+                contour_start=contour_start,
+                contour_end=contour_end,
+            )
+            logger.info(
+                (
+                    f"coherence {_t} / {range_end-1} ({(_t - range_start) / (range_end - range_start-1):.1%})"
+                    f" : {time.time()-_timestamp:5.2f}s"
+                )
             )
-            logger.info((
-                f"coherence {_t} / {range_end-1} ({(_t - range_start) / (range_end - range_start-1):.1%})"
-                f" : {time.time()-_timestamp:5.2f}s"
-            ))
 
         return itf_final, ptf_final

diff --git a/src/py_eddy_tracker/observations/observation.py b/src/py_eddy_tracker/observations/observation.py
index db0c2a45..dec9a6b0 100644
--- a/src/py_eddy_tracker/observations/observation.py
+++ b/src/py_eddy_tracker/observations/observation.py
@@ -702,7 +702,11 @@ def load_file(cls, filename, **kwargs):
         .. code-block:: python
 
             kwargs_latlon_300 = dict(
-                include_vars=["longitude", "latitude",], indexs=dict(obs=slice(0, 300)),
+                include_vars=[
+                    "longitude",
+                    "latitude",
+                ],
+                indexs=dict(obs=slice(0, 300)),
             )
             small_dataset = TrackEddiesObservations.load_file(
                 filename, **kwargs_latlon_300
             )
@@ -1973,7 +1977,11 @@ def bins_stat(self, xname, bins=None, yname=None, method=None, mask=None):
 
     def format_label(self, label):
         t0, t1 = self.period
-        return label.format(t0=t0, t1=t1, nb_obs=len(self),)
+        return label.format(
+            t0=t0,
+            t1=t1,
+            nb_obs=len(self),
+        )
 
     def display(self, ax, ref=None, extern_only=False, intern_only=False, **kwargs):
         """Plot the speed and effective (dashed) contour of the eddies
@@ -2283,7 +2291,7 @@ def nb_days(self):
         return self.period[1] - self.period[0] + 1
 
     def create_particles(self, step, intern=True):
-        """create particles only inside speed contour. Avoid creating too large numpy arrays, only to me masked
+        """Create particles inside contour (default: speed contour). Avoid creating too large numpy arrays, only to be masked
 
         :param step: step for particles
         :type step: float
@@ -2345,7 +2353,14 @@ def grid_count_pixel_in(
             x_, y_ = reduce_size(x_, y_)
             v = create_vertice(x_, y_)
             (x_start, x_stop), (y_start, y_stop) = bbox_indice_regular(
-                v, x_bounds, y_bounds, xstep, ystep, N, is_circular, x_size,
+                v,
+                x_bounds,
+                y_bounds,
+                xstep,
+                ystep,
+                N,
+                is_circular,
+                x_size,
             )
             i, j = get_pixel_in_regular(v, x_c, y_c, x_start, x_stop, y_start, y_stop)
             grid_count_(grid, i, j)

diff --git a/src/py_eddy_tracker/observations/tracking.py b/src/py_eddy_tracker/observations/tracking.py
index 492842c7..3aa43387 100644
--- a/src/py_eddy_tracker/observations/tracking.py
+++ b/src/py_eddy_tracker/observations/tracking.py
@@ -578,7 +578,10 @@ def close_tracks(self, other, nb_obs_min=10, **kwargs):
     def format_label(self, label):
         t0, t1 = self.period
         return label.format(
-            t0=t0, t1=t1, nb_obs=len(self), nb_tracks=(self.nb_obs_by_track != 0).sum(),
+            t0=t0,
+            t1=t1,
+            nb_obs=len(self),
+            nb_tracks=(self.nb_obs_by_track != 0).sum(),
         )
 
     def plot(self, ax, ref=None, **kwargs):
@@ -702,7 +705,16 @@ def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs):
 
     @staticmethod
     def get_previous_obs(
-        i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs
+        i_current,
+        ids,
+        x,
+        y,
+        time_s,
+        time_e,
+        time_ref,
+        window,
+        min_overlap=0.01,
+        **kwargs,
     ):
         """Backward association of observations to the segments"""
         time_cur = int_(ids["time"][i_current])
@@ -720,7 +732,7 @@ def get_previous_obs(
             c = zeros(len(xj))
             c[ij] = vertice_overlap(xi[ii], yi[ii], xj[ij], yj[ij], **kwargs)
             # We remove low overlap
-            c[c < 0.01] = 0
+            c[c < min_overlap] = 0
             # We get index of maximal overlap
             i = c.argmax()
             c_i = c[i]
@@ -732,7 +744,18 @@ def get_previous_obs(
             break
 
     @staticmethod
-    def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs):
+    def get_next_obs(
+        i_current,
+        ids,
+        x,
+        y,
+        time_s,
+        time_e,
+        time_ref,
+        window,
+        min_overlap=0.01,
+        **kwargs,
+    ):
         """Forward association of observations to the segments"""
         time_max = time_e.shape[0] - 1
         time_cur = int_(ids["time"][i_current])
@@ -752,7 +775,7 @@ def get_next_obs(
             c = zeros(len(xj))
             c[ij] = vertice_overlap(xi[ii], yi[ii], xj[ij], yj[ij], **kwargs)
             # We remove low overlap
-            c[c < 0.01] = 0
+            c[c < min_overlap] = 0
             # We get index of maximal overlap
             i = c.argmax()
             c_i = c[i]

diff --git a/tests/test_grid.py b/tests/test_grid.py
index 34187357..2c89550a 100644
--- a/tests/test_grid.py
+++ b/tests/test_grid.py
@@ -7,7 +7,15 @@
 G = RegularGridDataset(get_demo_path("mask_1_60.nc"), "lon", "lat")
 X = 0.025
-contour = Path(((-X, 0), (X, 0), (X, X), (-X, X), (-X, 0),))
+contour = Path(
+    (
+        (-X, 0),
+        (X, 0),
+        (X, X),
+        (-X, X),
+        (-X, 0),
+    )
+)
 
 # contour

From 80c529a0981d3e56bc5efd4eddf0f165aa7c6a61 Mon Sep 17 00:00:00 2001
From: Cori Pegliasco
Date: Tue, 17 Aug 2021 14:00:51 +0200
Subject: [PATCH 2/2] pin numpy for numba: the doc compilation fails with
 "numpy 1.21.2 is installed but numpy<1.21,>=1.17 is required by {'numba'}"

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 097e786a..477cf32d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 matplotlib
 netCDF4
 numba>=0.53
-numpy
+numpy<1.21
 opencv-python
 pint
 polygon3
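---
Usage note (reviewer sketch): the snippet below illustrates the API changes carried by these patches — the `spliting_event` -> `splitting_event` rename and the new `contour_start` / `contour_end` keywords of the coherence methods. The file names, the `date2file` naming scheme, and the variable names in `uv_params` are hypothetical placeholders, not part of the patch.

.. code-block:: python

    from py_eddy_tracker.observations.network import NetworkObservations

    # Load a network file and drop lonely eddies (network id == 0)
    n = NetworkObservations.load_file("network.nc").remove_trash()

    # Renamed method: spliting_event -> splitting_event
    before, after, started = n.splitting_event(triplet=True)

    # Hypothetical julian-day -> velocity-file mapping
    def date2file(julian_day):
        return f"uv_{julian_day:05d}.nc"

    # Grid variable names here are placeholder assumptions
    uv_params = dict(x_name="longitude", y_name="latitude", u_name="u", v_name="v")

    # New keywords: seed particles in the effective contour and count them
    # in the speed contour after a 14-day forward advection
    target_forward, pct_forward = n.segment_coherence_forward(
        date_function=date2file,
        uv_params=uv_params,
        n_days=14,
        step_mesh=1.0 / 50,
        contour_start="effective",
        contour_end="speed",
    )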