diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 00000000..d9437d16 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,70 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ master ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ master ] + schedule: + - cron: '41 16 * * 4' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Learn more about CodeQL language support at https://git.io/codeql-language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹ️ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index a6fcceed..f2f4753e 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -11,7 +11,7 @@ jobs: matrix: # os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, windows-latest] - python_version: [3.7, 3.8, 3.9] + python_version: ['3.10', '3.11', '3.12'] name: Run py eddy tracker build tests runs-on: ${{ matrix.os }} defaults: diff --git a/.readthedocs.yml b/.readthedocs.yml index 1299f38e..ba36f8ea 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,7 +1,13 @@ version: 2 conda: environment: doc/environment.yml +build: + os: ubuntu-lts-latest + tools: + python: "mambaforge-latest" python: - install: - - method: setuptools - path: . + install: + - method: pip + path: . +sphinx: + configuration: doc/conf.py \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 110c6081..6d6d6a30 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,7 +7,51 @@ The format is based on `Keep a Changelog `_ and this project adheres to `Semantic Versioning `_. 
[Unreleased] ------------- +------------- +Changed +^^^^^^^ + +Fixed +^^^^^ + +Added +^^^^^ + +[3.6.2] - 2025-06-06 +-------------------- +Changed +^^^^^^^ + +- `remove_dead_end` method for network now moves dead ends to the trash instead of removing observations + +Fixed +^^^^^ + +- Fix matplotlib version + +[3.6.1] - 2022-10-14 +-------------------- +Changed +^^^^^^^ + +- Rewrite particle candidate code so it can be easily parallelized + +Fixed +^^^^^ + +- Check strictly increasing coordinates for RegularGridDataset. +- Grid mask is checked to replace a single-value mask by a 2D mask with fixed value + +Added +^^^^^ + +- Add method to colorize contours with a field +- Add option to force align on to return all steps for the reference dataset +- Add method and property to network to easily select segments and networks +- Add method to find same track/segment/network in dataset + +[3.6.0] - 2022-01-12 +-------------------- Changed ^^^^^^^ @@ -15,8 +59,8 @@ Changed New identifications are produced with this type, old files can still be loaded. If you use old identifications for tracking, use the `--unraw` option to unpack old times and store data with the new format. - Now amplitude is stored with .1 mm of precision (instead of 1 mm), same advice as for time. -- expose more parameters to users for bash tools build_network & divide_network -- add warning when loading a file created from a previous version of py-eddy-tracker. +- Expose more parameters to users for bash tools build_network & divide_network +- Add warning when loading a file created from a previous version of py-eddy-tracker. diff --git a/README.md b/README.md index e26e15ac..0cc34894 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ [![PyPI version](https://badge.fury.io/py/pyEddyTracker.svg)](https://badge.fury.io/py/pyEddyTracker) +[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.6333988.svg)](https://doi.org/10.5281/zenodo.6333988) [![Documentation Status](https://readthedocs.org/projects/py-eddy-tracker/badge/?version=stable)](https://py-eddy-tracker.readthedocs.io/en/stable/?badge=stable) [![Gitter](https://badges.gitter.im/py-eddy-tracker/community.svg)](https://gitter.im/py-eddy-tracker/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/AntSimi/py-eddy-tracker/master?urlpath=lab/tree/notebooks/python_module/) @@ -6,26 +7,32 @@ # README # +### How to cite this code? ### + +Zenodo provides a DOI for each tagged version, [all DOIs are available here](https://doi.org/10.5281/zenodo.6333988) + ### Method ### The method was described in: -[Pegliasco, C., Delepoulle, A., Morrow, R., Faugère, Y., and Dibarboure, G.: META3.1exp : A new Global Mesoscale Eddy Trajectories Atlas derived from altimetry, Earth Syst. Sci. Data Discuss.](https://doi.org/10.5194/essd-2021-300) +[Pegliasco, C., Delepoulle, A., Morrow, R., Faugère, Y., and Dibarboure, G.: META3.1exp : A new Global Mesoscale Eddy Trajectories Atlas derived from altimetry, Earth Syst. Sci. Data Discuss.](https://doi.org/10.5194/essd-14-1087-2022) [Mason, E., A. Pascual, and J. C. McWilliams, 2014: A new sea surface height–based code for oceanic mesoscale eddy tracking.](https://doi.org/10.1175/JTECH-D-14-00019.1) ### Use case ### The method is used in: - + [Mason, E., A. Pascual, P. Gaube, S.Ruiz, J. Pelegrí, A. Delepoulle, 2017: Subregional characterization of mesoscale eddies across the Brazil-Malvinas Confluence](https://doi.org/10.1002/2016JC012611) ### How do I get set up? 
### #### Short story #### + ```bash pip install pyeddytracker ``` + #### Long story #### To avoid problems with installation, use of the virtualenv Python virtual environment is recommended. @@ -36,12 +43,20 @@ Then use pip to install all dependencies (numpy, scipy, matplotlib, netCDF4, ... pip install numpy scipy netCDF4 matplotlib opencv-python pyyaml pint polygon3 ``` -Then run the following to install the eddy tracker: +Clone the repository: + +```bash +git clone https://github.com/AntSimi/py-eddy-tracker +``` + +Then run the following to install the eddy tracker: ```bash python setup.py install ``` + ### Tools gallery ### + Several examples based on the py eddy tracker module are [here](https://py-eddy-tracker.readthedocs.io/en/latest/python_module/index.html). [![](https://py-eddy-tracker.readthedocs.io/en/latest/_static/logo.png)](https://py-eddy-tracker.readthedocs.io/en/latest/python_module/index.html) diff --git a/check.sh b/check.sh index ddafab69..a402bf52 100644 --- a/check.sh +++ b/check.sh @@ -1,7 +1,5 @@ -isort src tests examples -black src tests examples -blackdoc src tests examples -flake8 tests examples src --count --select=E9,F63,F7,F82 --show-source --statistics -# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -flake8 tests examples src --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -pytest -vv --cov py_eddy_tracker --cov-report html +isort . +black . +blackdoc . +flake8 . +python -m pytest -vv --cov py_eddy_tracker --cov-report html diff --git a/doc/conf.py b/doc/conf.py index ccf26e4e..0844d585 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -96,9 +96,9 @@ master_doc = "index" # General information about the project. -project = u"py-eddy-tracker" -copyright = u"2019, A. Delepoulle & E. Mason" -author = u"A. Delepoulle & E. Mason" +project = "py-eddy-tracker" +copyright = "2019, A. Delepoulle & E. Mason" +author = "A. Delepoulle & E. Mason" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -272,8 +272,8 @@ ( master_doc, "py-eddy-tracker.tex", - u"py-eddy-tracker Documentation", - u"A. Delepoulle \\& E. Mason", + "py-eddy-tracker Documentation", + "A. Delepoulle \\& E. Mason", "manual", ), ] @@ -304,7 +304,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, "py-eddy-tracker", u"py-eddy-tracker Documentation", [author], 1) + (master_doc, "py-eddy-tracker", "py-eddy-tracker Documentation", [author], 1) ] # If true, show URL addresses after external links. @@ -320,7 +320,7 @@ ( master_doc, "py-eddy-tracker", - u"py-eddy-tracker Documentation", + "py-eddy-tracker Documentation", author, "py-eddy-tracker", "One line description of project.", diff --git a/doc/environment.yml b/doc/environment.yml index 7dcb504d..063a60de 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -1,10 +1,12 @@ channels: - conda-forge - - defaults dependencies: - - python=3.8 + - python>=3.10, <3.13 - ffmpeg + - pip - pip: + - -r ../requirements.txt + - git+https://github.com/AntSimi/py-eddy-tracker-sample-id.git - sphinx-gallery - sphinx_rtd_theme - sphinx>=3.1 diff --git a/doc/grid_identification.rst b/doc/grid_identification.rst index c645f80c..2cc3fb52 100644 --- a/doc/grid_identification.rst +++ b/doc/grid_identification.rst @@ -47,38 +47,42 @@ Activate verbose .. 
code-block:: python from py_eddy_tracker import start_logger - start_logger().setLevel('DEBUG') # Available options: ERROR, WARNING, INFO, DEBUG + + start_logger().setLevel("DEBUG") # Available options: ERROR, WARNING, INFO, DEBUG Run identification .. code-block:: python from datetime import datetime + h = RegularGridDataset(grid_name, lon_name, lat_name) - h.bessel_high_filter('adt', 500, order=3) + h.bessel_high_filter("adt", 500, order=3) date = datetime(2019, 2, 23) a, c = h.eddy_identification( - 'adt', 'ugos', 'vgos', # Variables used for identification - date, # Date of identification - 0.002, # step between two isolines of detection (m) - pixel_limit=(5, 2000), # Min and max pixel count for valid contour - shape_error=55, # Error max (%) between ratio of circle fit and contour - ) + "adt", + "ugos", + "vgos", # Variables used for identification + date, # Date of identification + 0.002, # step between two isolines of detection (m) + pixel_limit=(5, 2000), # Min and max pixel count for valid contour + shape_error=55, # Error max (%) between ratio of circle fit and contour + ) Plot the resulting identification .. code-block:: python - fig = plt.figure(figsize=(15,7)) - ax = fig.add_axes([.03,.03,.94,.94]) - ax.set_title('Eddies detected -- Cyclonic(red) and Anticyclonic(blue)') - ax.set_ylim(-75,75) - ax.set_xlim(0,360) - ax.set_aspect('equal') - a.display(ax, color='b', linewidth=.5) - c.display(ax, color='r', linewidth=.5) + fig = plt.figure(figsize=(15, 7)) + ax = fig.add_axes([0.03, 0.03, 0.94, 0.94]) + ax.set_title("Eddies detected -- Cyclonic(red) and Anticyclonic(blue)") + ax.set_ylim(-75, 75) + ax.set_xlim(0, 360) + ax.set_aspect("equal") + a.display(ax, color="b", linewidth=0.5) + c.display(ax, color="r", linewidth=0.5) ax.grid() - fig.savefig('share/png/eddies.png') + fig.savefig("share/png/eddies.png") .. image:: ../share/png/eddies.png @@ -87,7 +91,8 @@ Save identification data .. code-block:: python from netCDF4 import Dataset - with Dataset(date.strftime('share/Anticyclonic_%Y%m%d.nc'), 'w') as h: + + with Dataset(date.strftime("share/Anticyclonic_%Y%m%d.nc"), "w") as h: a.to_netcdf(h) - with Dataset(date.strftime('share/Cyclonic_%Y%m%d.nc'), 'w') as h: + with Dataset(date.strftime("share/Cyclonic_%Y%m%d.nc"), "w") as h: c.to_netcdf(h) diff --git a/doc/grid_load_display.rst b/doc/grid_load_display.rst index 2e570274..2f0e3765 100644 --- a/doc/grid_load_display.rst +++ b/doc/grid_load_display.rst @@ -7,7 +7,12 @@ Loading grid .. code-block:: python from py_eddy_tracker.dataset.grid import RegularGridDataset - grid_name, lon_name, lat_name = 'share/nrt_global_allsat_phy_l4_20190223_20190226.nc', 'longitude', 'latitude' + + grid_name, lon_name, lat_name = ( + "share/nrt_global_allsat_phy_l4_20190223_20190226.nc", + "longitude", + "latitude", + ) h = RegularGridDataset(grid_name, lon_name, lat_name) Plotting grid .. code-block:: python from matplotlib import pyplot as plt + fig = plt.figure(figsize=(14, 12)) - ax = fig.add_axes([.02, .51, .9, .45]) - ax.set_title('ADT (m)') + ax = fig.add_axes([0.02, 0.51, 0.9, 0.45]) + ax.set_title("ADT (m)") ax.set_ylim(-75, 75) - ax.set_aspect('equal') - m = h.display(ax, name='adt', vmin=-1, vmax=1) + ax.set_aspect("equal") + m = h.display(ax, name="adt", vmin=-1, vmax=1) ax.grid(True) - plt.colorbar(m, cax=fig.add_axes([.94, .51, .01, .45])) + plt.colorbar(m, cax=fig.add_axes([0.94, 0.51, 0.01, 0.45])) Filtering @@ -30,27 +36,27 @@ .. 
code-block:: python h = RegularGridDataset(grid_name, lon_name, lat_name) - h.bessel_high_filter('adt', 500, order=3) + h.bessel_high_filter("adt", 500, order=3) Save grid .. code-block:: python - h.write('/tmp/grid.nc') + h.write("/tmp/grid.nc") Add second plot .. code-block:: python - ax = fig.add_axes([.02, .02, .9, .45]) - ax.set_title('ADT Filtered (m)') - ax.set_aspect('equal') + ax = fig.add_axes([0.02, 0.02, 0.9, 0.45]) + ax.set_title("ADT Filtered (m)") + ax.set_aspect("equal") ax.set_ylim(-75, 75) - m = h.display(ax, name='adt', vmin=-.1, vmax=.1) + m = h.display(ax, name="adt", vmin=-0.1, vmax=0.1) ax.grid(True) - plt.colorbar(m, cax=fig.add_axes([.94, .02, .01, .45])) - fig.savefig('share/png/filter.png') + plt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.45])) + fig.savefig("share/png/filter.png") .. image:: ../share/png/filter.png \ No newline at end of file diff --git a/doc/spectrum.rst b/doc/spectrum.rst index d751b909..f96e30a0 100644 --- a/doc/spectrum.rst +++ b/doc/spectrum.rst @@ -11,7 +11,7 @@ Load data raw = RegularGridDataset(grid_name, lon_name, lat_name) filtered = RegularGridDataset(grid_name, lon_name, lat_name) - filtered.bessel_low_filter('adt', 150, order=3) + filtered.bessel_low_filter("adt", 150, order=3) areas = dict( sud_pacific=dict(llcrnrlon=188, urcrnrlon=280, llcrnrlat=-64, urcrnrlat=-7), @@ -23,24 +23,33 @@ Compute and display spectrum .. code-block:: python - fig = plt.figure(figsize=(10,6)) + fig = plt.figure(figsize=(10, 6)) ax = fig.add_subplot(111) - ax.set_title('Spectrum') - ax.set_xlabel('km') + ax.set_title("Spectrum") + ax.set_xlabel("km") for name_area, area in areas.items(): - - lon_spec, lat_spec = raw.spectrum_lonlat('adt', area=area) - mappable = ax.loglog(*lat_spec, label='lat %s raw' % name_area)[0] - ax.loglog(*lon_spec, label='lon %s raw' % name_area, color=mappable.get_color(), linestyle='--') - - lon_spec, lat_spec = filtered.spectrum_lonlat('adt', area=area) - mappable = ax.loglog(*lat_spec, label='lat %s high' % name_area)[0] - ax.loglog(*lon_spec, label='lon %s high' % name_area, color=mappable.get_color(), linestyle='--') - - ax.set_xscale('log') + lon_spec, lat_spec = raw.spectrum_lonlat("adt", area=area) + mappable = ax.loglog(*lat_spec, label="lat %s raw" % name_area)[0] + ax.loglog( + *lon_spec, + label="lon %s raw" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + lon_spec, lat_spec = filtered.spectrum_lonlat("adt", area=area) + mappable = ax.loglog(*lat_spec, label="lat %s high" % name_area)[0] + ax.loglog( + *lon_spec, + label="lon %s high" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + ax.set_xscale("log") ax.legend() ax.grid() - fig.savefig('share/png/spectrum.png') + fig.savefig("share/png/spectrum.png") .. image:: ../share/png/spectrum.png @@ -49,18 +58,23 @@ Compute and display spectrum ratio .. 
code-block:: python - fig = plt.figure(figsize=(10,6)) + fig = plt.figure(figsize=(10, 6)) ax = fig.add_subplot(111) - ax.set_title('Spectrum ratio') - ax.set_xlabel('km') + ax.set_title("Spectrum ratio") + ax.set_xlabel("km") for name_area, area in areas.items(): - lon_spec, lat_spec = filtered.spectrum_lonlat('adt', area=area, ref=raw) - mappable = ax.plot(*lat_spec, label='lat %s high' % name_area)[0] - ax.plot(*lon_spec, label='lon %s high' % name_area, color=mappable.get_color(), linestyle='--') - - ax.set_xscale('log') + lon_spec, lat_spec = filtered.spectrum_lonlat("adt", area=area, ref=raw) + mappable = ax.plot(*lat_spec, label="lat %s high" % name_area)[0] + ax.plot( + *lon_spec, + label="lon %s high" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + ax.set_xscale("log") ax.legend() ax.grid() - fig.savefig('share/png/spectrum_ratio.png') + fig.savefig("share/png/spectrum_ratio.png") .. image:: ../share/png/spectrum_ratio.png diff --git a/environment.yml b/environment.yml index 4ea8f840..e94c7bc1 100644 --- a/environment.yml +++ b/environment.yml @@ -1,9 +1,9 @@ name: binder-pyeddytracker channels: - conda-forge - - defaults dependencies: - - python=3.8 + - python>=3.10, <3.13 + - pip - ffmpeg - pip: - -r requirements.txt diff --git a/examples/01_general_things/pet_storage.py b/examples/01_general_things/pet_storage.py index ccd01f1c..918ebbee 100644 --- a/examples/01_general_things/pet_storage.py +++ b/examples/01_general_things/pet_storage.py @@ -15,9 +15,9 @@ manage eddies associated in networks, the ```track``` and ```segment``` fields allow separating observations """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, outer +import py_eddy_tracker_sample from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.observations.network import NetworkObservations diff --git a/examples/02_eddy_identification/pet_eddy_detection_ACC.py b/examples/02_eddy_identification/pet_eddy_detection_ACC.py index e6c5e381..d12c62f3 100644 --- a/examples/02_eddy_identification/pet_eddy_detection_ACC.py +++ b/examples/02_eddy_identification/pet_eddy_detection_ACC.py @@ -7,10 +7,10 @@ Two detections are provided: with a filtered ADT and without filtering """ + from datetime import datetime -from matplotlib import pyplot as plt -from matplotlib import style +from matplotlib import pyplot as plt, style from py_eddy_tracker import data from py_eddy_tracker.dataset.grid import RegularGridDataset @@ -65,7 +65,8 @@ def set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight="semibold" y_name="latitude", # Manual area subset indexs=dict( - latitude=slice(100 - margin, 220 + margin), longitude=slice(0, 230 + margin), + latitude=slice(100 - margin, 220 + margin), + longitude=slice(0, 230 + margin), ), ) g_raw = RegularGridDataset(**kw_data) @@ -80,7 +81,7 @@ def set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight="semibold" # Identification # ^^^^^^^^^^^^^^ # Run the identification step with slices of 2 mm -date = datetime(2016, 5, 15) +date = datetime(2019, 2, 23) kw_ident = dict( date=date, step=0.002, shape_error=70, sampling=30, uname="u", vname="v" ) @@ -187,10 +188,16 @@ def set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight="semibold" ax.set_ylabel("With filter") ax.plot( - a_[field][i_a] * factor, a[field][j_a] * factor, "r.", label="Anticyclonic", + a_[field][i_a] * factor, + a[field][j_a] * factor, + "r.", + label="Anticyclonic", ) ax.plot( - c_[field][i_c] * 
factor, "b.", label="Cyclonic", + c_[field][i_c] * factor, + c[field][j_c] * factor, + "b.", + label="Cyclonic", ) ax.set_aspect("equal"), ax.grid() ax.plot((0, 1000), (0, 1000), "g") diff --git a/examples/02_eddy_identification/pet_interp_grid_on_dataset.py b/examples/02_eddy_identification/pet_interp_grid_on_dataset.py index f9e5d4c3..fa27a3d1 100644 --- a/examples/02_eddy_identification/pet_interp_grid_on_dataset.py +++ b/examples/02_eddy_identification/pet_interp_grid_on_dataset.py @@ -43,7 +43,7 @@ def update_axes(ax, mappable=None): # %% # Compute and store eke in cm²/s² aviso_map.add_grid( - "eke", (aviso_map.grid("u") ** 2 + aviso_map.grid("v") ** 2) * 0.5 * (100 ** 2) + "eke", (aviso_map.grid("u") ** 2 + aviso_map.grid("v") ** 2) * 0.5 * (100**2) ) eke_kwargs = dict(vmin=1, vmax=1000, cmap="magma_r") diff --git a/examples/02_eddy_identification/pet_statistics_on_identification.py b/examples/02_eddy_identification/pet_statistics_on_identification.py new file mode 100644 index 00000000..dbd73c61 --- /dev/null +++ b/examples/02_eddy_identification/pet_statistics_on_identification.py @@ -0,0 +1,105 @@ +""" +Stastics on identification files +================================ + +Some statistics on raw identification without any tracking +""" +from matplotlib import pyplot as plt +from matplotlib.dates import date2num +import numpy as np + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.observations.observation import EddiesObservations + +start_logger().setLevel("ERROR") + + +# %% +def start_axes(title): + fig = plt.figure(figsize=(13, 5)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) + ax.set_aspect("equal") + ax.set_title(title) + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9])) + + +# %% +# We load demo sample and take only first year. +# +# Replace by a list of filename to apply on your own dataset. +file_objects = get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" +)[:365] + +# %% +# Merge all identification datasets in one object +all_a = EddiesObservations.concatenate( + [EddiesObservations.load_file(i) for i in file_objects] +) + +# %% +# We define polygon bound +x0, x1, y0, y1 = 15, 20, 33, 38 +xs = np.array([[x0, x1, x1, x0, x0]], dtype="f8") +ys = np.array([[y0, y0, y1, y1, y0]], dtype="f8") +# Polygon object created for the match function use. +polygon = dict(contour_lon_e=xs, contour_lat_e=ys, contour_lon_s=xs, contour_lat_s=ys) + +# %% +# Geographic frequency of eddies +step = 0.125 +ax = start_axes("") +# Count pixel encompassed in each contour +g_a = all_a.grid_count(bins=((-10, 37, step), (30, 46, step)), intern=True) +m = g_a.display( + ax, cmap="terrain_r", vmin=0, vmax=0.75, factor=1 / all_a.nb_days, name="count" +) +ax.plot(polygon["contour_lon_e"][0], polygon["contour_lat_e"][0], "r") +update_axes(ax, m) + +# %% +# We use the match function to count the number of eddies that intersect the polygon defined previously +# `p1_area` option allow to get in c_e/c_s output, precentage of area occupy by eddies in the polygon. 
+i_e, j_e, c_e = all_a.match(polygon, p1_area=True, intern=False) +i_s, j_s, c_s = all_a.match(polygon, p1_area=True, intern=True) + +# %% +dt = np.datetime64("1970-01-01") - np.datetime64("1950-01-01") +kw_hist = dict( + bins=date2num(np.arange(21900, 22300).astype("datetime64") - dt), histtype="step" +) +# translate julian days into datetime64 +t = all_a.time.astype("datetime64") - dt +# %% +# Number of eddies within a polygon +ax = plt.figure(figsize=(12, 6)).add_subplot(111) +ax.set_title("Different ways to count eddies within a polygon") +ax.set_ylabel("Count") +m = all_a.mask_from_polygons(((xs, ys),)) +ax.hist(t[m], label="Eddy Center in polygon", **kw_hist) +ax.hist(t[i_s[c_s > 0]], label="Intersection Speed contour and polygon", **kw_hist) +ax.hist(t[i_e[c_e > 0]], label="Intersection Effective contour and polygon", **kw_hist) +ax.legend() +ax.set_xlim(np.datetime64("2010"), np.datetime64("2011")) +ax.grid() + +# %% +# Percent of the area of interest occupied by eddies. +ax = plt.figure(figsize=(12, 6)).add_subplot(111) +ax.set_title("Percent of polygon occupied by an anticyclonic eddy") +ax.set_ylabel("Percent of polygon") +ax.hist(t[i_s[c_s > 0]], weights=c_s[c_s > 0] * 100.0, label="speed contour", **kw_hist) +ax.hist( + t[i_e[c_e > 0]], weights=c_e[c_e > 0] * 100.0, label="effective contour", **kw_hist +) +ax.legend(), ax.set_ylim(0, 25) +ax.set_xlim(np.datetime64("2010"), np.datetime64("2011")) +ax.grid() diff --git a/examples/06_grid_manipulation/pet_advect.py b/examples/06_grid_manipulation/pet_advect.py index 1a98536a..d7cc67e9 100644 --- a/examples/06_grid_manipulation/pet_advect.py +++ b/examples/06_grid_manipulation/pet_advect.py @@ -73,7 +73,7 @@ def save(self, *args, **kwargs): # %% # Movie properties kwargs = dict(frames=arange(51), interval=100) -kw_p = dict(nb_step=2, time_step=21600) +kw_p = dict(u_name="u", v_name="v", nb_step=2, time_step=21600) frame_t = kw_p["nb_step"] * kw_p["time_step"] / 86400.0 @@ -102,7 +102,7 @@ def update(i_frame, t_step): # ^^^^^^^^^^^^^^^^ # Draw the 3 last positions in one path for each particle, # it can be run backward with `backward=True` option in filament method -p = g.filament(x, y, "u", "v", **kw_p, filament_size=3) +p = g.filament(x, y, **kw_p, filament_size=3) fig, txt, l, t = anim_ax(lw=0.5) _ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,)) # %% # Particle forward # ^^^^^^^^^^^^^^^^^ # Forward advection of particles -p = g.advect(x, y, "u", "v", **kw_p) +p = g.advect(x, y, **kw_p) fig, txt, l, t = anim_ax(ls="", marker=".", markersize=1) _ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,)) # %% # We get the last position and run backward until the original position -p = g.advect(x, y, "u", "v", **kw_p, backward=True) +p = g.advect(x, y, **kw_p, backward=True) fig, txt, l, _ = anim_ax(ls="", marker=".", markersize=1) _ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,)) @@ -139,9 +139,11 @@ def update(i_frame, t_step): ) for time_step in (10800, 21600, 43200, 86400): x, y = x0.copy(), y0.copy() - kw_advect = dict(nb_step=int(50 * 86400 / time_step), time_step=time_step) - g.advect(x, y, "u", "v", **kw_advect).__next__() - g.advect(x, y, "u", "v", **kw_advect, backward=True).__next__() + kw_advect = dict( + nb_step=int(50 * 86400 / time_step), time_step=time_step, u_name="u", v_name="v" + ) + g.advect(x, y, **kw_advect).__next__() + g.advect(x, y, **kw_advect, backward=True).__next__() d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 ax.hist(d, **kw, label=f"{86400. 
/ time_step:.0f} time steps per day") ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() @@ -158,9 +160,14 @@ def update(i_frame, t_step): time_step = 10800 for duration in (5, 50, 100): x, y = x0.copy(), y0.copy() - kw_advect = dict(nb_step=int(duration * 86400 / time_step), time_step=time_step) - g.advect(x, y, "u", "v", **kw_advect).__next__() - g.advect(x, y, "u", "v", **kw_advect, backward=True).__next__() + kw_advect = dict( + nb_step=int(duration * 86400 / time_step), + time_step=time_step, + u_name="u", + v_name="v", + ) + g.advect(x, y, **kw_advect).__next__() + g.advect(x, y, **kw_advect, backward=True).__next__() d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 ax.hist(d, **kw, label=f"Time duration {duration} days") ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() diff --git a/examples/06_grid_manipulation/pet_lavd.py b/examples/06_grid_manipulation/pet_lavd.py index 89d64108..a3ea846e 100644 --- a/examples/06_grid_manipulation/pet_lavd.py +++ b/examples/06_grid_manipulation/pet_lavd.py @@ -110,9 +110,11 @@ def save(self, *args, **kwargs): step_by_day = 3 # Compute step of advection every 4h nb_step = 2 -kw_p = dict(nb_step=nb_step, time_step=86400 / step_by_day / nb_step) +kw_p = dict( + nb_step=nb_step, time_step=86400 / step_by_day / nb_step, u_name="u", v_name="v" +) # Start a generator which at each iteration returns the new position at the next time step -particule = g.advect(x, y, "u", "v", **kw_p, rk4=True) +particule = g.advect(x, y, **kw_p, rk4=True) # %% # LAVD @@ -158,9 +160,7 @@ def update(i_frame): # %% # Format LAVD data lavd = RegularGridDataset.with_array( - coordinates=("lon", "lat"), - datas=dict(lavd=lavd.T, lon=x_g, lat=y_g,), - centered=True, + coordinates=("lon", "lat"), datas=dict(lavd=lavd.T, lon=x_g, lat=y_g), centered=True ) # %% diff --git a/examples/06_grid_manipulation/pet_okubo_weiss.py b/examples/06_grid_manipulation/pet_okubo_weiss.py index 818a6742..aa8a063e 100644 --- a/examples/06_grid_manipulation/pet_okubo_weiss.py +++ b/examples/06_grid_manipulation/pet_okubo_weiss.py @@ -2,7 +2,7 @@ Get Okubo-Weiss =============== -.. math:: OW = S_n^2 + S_s^2 + \omega^2 +.. math:: OW = S_n^2 + S_s^2 - \omega^2 with normal strain (:math:`S_n`), shear strain (:math:`S_s`) and vorticity (:math:`\omega`) diff --git a/examples/07_cube_manipulation/README.rst b/examples/07_cube_manipulation/README.rst index 147ce3f3..7cecfbd4 100644 --- a/examples/07_cube_manipulation/README.rst +++ b/examples/07_cube_manipulation/README.rst @@ -1,6 +1,2 @@ Time grid computation ===================== - -.. warning:: - - Time grid is under development, API could move quickly! diff --git a/examples/07_cube_manipulation/pet_cube.py b/examples/07_cube_manipulation/pet_cube.py index 7f30c4e1..cba6c85b 100644 --- a/examples/07_cube_manipulation/pet_cube.py +++ b/examples/07_cube_manipulation/pet_cube.py @@ -4,9 +4,10 @@ Example which uses CMEMS surface current with a Runge-Kutta 4 algorithm to advect particles. 
""" +from datetime import datetime, timedelta + # sphinx_gallery_thumbnail_number = 2 import re -from datetime import datetime, timedelta from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation diff --git a/examples/07_cube_manipulation/pet_fsle_med.py b/examples/07_cube_manipulation/pet_fsle_med.py index b128286a..9d78ea02 100644 --- a/examples/07_cube_manipulation/pet_fsle_med.py +++ b/examples/07_cube_manipulation/pet_fsle_med.py @@ -49,7 +49,7 @@ def check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6): Check if distance between eastern or northern particle to center particle is bigger than `dist_max` """ nb_p = x.shape[0] // 3 - delta = dist_max ** 2 + delta = dist_max**2 for i in range(nb_p): i0 = i * 3 i_n = i0 + 1 @@ -59,10 +59,10 @@ def check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6): continue # Distance with north dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n] - dn = dxn ** 2 + dyn ** 2 + dn = dxn**2 + dyn**2 # Distance with east dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e] - de = dxe ** 2 + dye ** 2 + de = dxe**2 + dye**2 if dn >= delta or de >= delta: s1 = dn + de @@ -71,7 +71,7 @@ def check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6): s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * ( (dxn - dye) ** 2 + (dxe + dyn) ** 2 ) - flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init ** 2) * (s1 + s2 ** 0.5)) + flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init**2) * (s1 + s2**0.5)) theta[i] = arctan2(at1, at2 + s2) * 180 / pi # To know where value are set m_set[i] = False @@ -142,8 +142,10 @@ def build_triplet(x, y, step=0.02): used = zeros(x.shape[0], dtype="bool") # advection generator -kw = dict(t_init=t0, nb_step=1, backward=backward, mask_particule=used) -p = c.advect(x, y, "u", "v", time_step=86400 / time_step_by_days, **kw) +kw = dict( + t_init=t0, nb_step=1, backward=backward, mask_particule=used, u_name="u", v_name="v" +) +p = c.advect(x, y, time_step=86400 / time_step_by_days, **kw) # We check at each step of advection if particle distance is over `dist_max` for i in range(time_step_by_days * nb_days): diff --git a/examples/07_cube_manipulation/pet_lavd_detection.py b/examples/07_cube_manipulation/pet_lavd_detection.py index 1fa4d60b..4dace120 100644 --- a/examples/07_cube_manipulation/pet_lavd_detection.py +++ b/examples/07_cube_manipulation/pet_lavd_detection.py @@ -93,7 +93,7 @@ def update_axes(ax, mappable=None): # Time properties, for example with advection only 25 days nb_days, step_by_day = 25, 6 nb_time = step_by_day * nb_days -kw_p = dict(nb_step=1, time_step=86400 / step_by_day) +kw_p = dict(nb_step=1, time_step=86400 / step_by_day, u_name="u", v_name="v") t0 = 20236 t0_grid = c[t0] # Geographic properties, we use a coarser resolution for time consuming reasons @@ -114,7 +114,7 @@ def update_axes(ax, mappable=None): # ---------------------------- lavd = zeros(original_shape) lavd_ = lavd[m] -p = c.advect(x0.copy(), y0.copy(), "u", "v", t_init=t0, **kw_p) +p = c.advect(x0.copy(), y0.copy(), t_init=t0, **kw_p) for _ in range(nb_time): t, x, y = p.__next__() lavd_ += abs(c.interp("vort", t / 86400.0, x, y)) @@ -131,7 +131,7 @@ def update_axes(ax, mappable=None): # ----------------------------- lavd = zeros(original_shape) lavd_ = lavd[m] -p = c.advect(x0.copy(), y0.copy(), "u", "v", t_init=t0, backward=True, **kw_p) +p = c.advect(x0.copy(), y0.copy(), t_init=t0, backward=True, **kw_p) for i in range(nb_time): t, x, y = p.__next__() lavd_ += abs(c.interp("vort", t / 86400.0, x, y)) @@ -148,7 +148,7 
@@ def update_axes(ax, mappable=None): # --------------------------- lavd = zeros(original_shape) lavd_ = lavd[m] -p = t0_grid.advect(x0.copy(), y0.copy(), "u", "v", **kw_p) +p = t0_grid.advect(x0.copy(), y0.copy(), **kw_p) for _ in range(nb_time): x, y = p.__next__() lavd_ += abs(t0_grid.interp("vort", x, y)) @@ -165,7 +165,7 @@ def update_axes(ax, mappable=None): # ---------------------------- lavd = zeros(original_shape) lavd_ = lavd[m] -p = t0_grid.advect(x0.copy(), y0.copy(), "u", "v", backward=True, **kw_p) +p = t0_grid.advect(x0.copy(), y0.copy(), backward=True, **kw_p) for i in range(nb_time): x, y = p.__next__() lavd_ += abs(t0_grid.interp("vort", x, y)) diff --git a/examples/07_cube_manipulation/pet_particles_drift.py b/examples/07_cube_manipulation/pet_particles_drift.py new file mode 100644 index 00000000..3d7aa1a4 --- /dev/null +++ b/examples/07_cube_manipulation/pet_particles_drift.py @@ -0,0 +1,46 @@ +""" +Build paths of drifting particles +================================= + +""" + +from matplotlib import pyplot as plt +from numpy import arange, meshgrid + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection + +start_logger().setLevel("ERROR") + +# %% +# Load data cube +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + unset=True, +) + +# %% +# Advection properties +nb_days, step_by_day = 10, 6 +nb_time = step_by_day * nb_days +kw_p = dict(nb_step=1, time_step=86400 / step_by_day) +t0 = 20210 + +# %% +# Get paths +x0, y0 = meshgrid(arange(32, 35, 0.5), arange(32.5, 34.5, 0.5)) +x0, y0 = x0.reshape(-1), y0.reshape(-1) +t, x, y = c.path(x0, y0, h_name="adt", t_init=t0, **kw_p, nb_time=nb_time) + +# %% +# Plot paths +ax = plt.figure(figsize=(9, 6)).add_subplot(111, aspect="equal") +ax.plot(x0, y0, "k.", ms=20) +ax.plot(x, y, lw=3) +ax.set_title("10-day particle paths") +ax.set_xlim(31, 35), ax.set_ylim(32, 34.5) +ax.grid() diff --git a/examples/08_tracking_manipulation/pet_display_field.py b/examples/08_tracking_manipulation/pet_display_field.py index 30ad75a6..b943a2ba 100644 --- a/examples/08_tracking_manipulation/pet_display_field.py +++ b/examples/08_tracking_manipulation/pet_display_field.py @@ -4,8 +4,8 @@ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/08_tracking_manipulation/pet_display_track.py b/examples/08_tracking_manipulation/pet_display_track.py index 13a8d3ad..b15d51d7 100644 --- a/examples/08_tracking_manipulation/pet_display_track.py +++ b/examples/08_tracking_manipulation/pet_display_track.py @@ -4,8 +4,8 @@ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py b/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py new file mode 100644 index 00000000..8161ad81 --- /dev/null +++ b/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py @@ -0,0 +1,94 @@ +""" +Correspondances +=============== + +Correspondances is a mechanism intended to continue tracking with new detections + +""" + +import logging + +# %% +from matplotlib import pyplot as plt +from netCDF4 import Dataset + +from py_eddy_tracker import start_logger 
+from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker + +# In order to hide some warnings +import py_eddy_tracker.observations.observation +from py_eddy_tracker.tracking import Correspondances + +py_eddy_tracker.observations.observation._display_check_warning = False + + +# %% +def plot_eddy(ed): + fig = plt.figure(figsize=(10, 5)) + ax = fig.add_axes([0.05, 0.03, 0.90, 0.94]) + ed.plot(ax, ref=-10, marker="x") + lc = ed.display_color(ax, field=ed.time, ref=-10, intern=True) + plt.colorbar(lc).set_label("Time in Julian days (from 1950/01/01)") + ax.set_xlim(4.5, 8), ax.set_ylim(36.8, 38.3) + ax.set_aspect("equal") + ax.grid() + + +# %% +# Get remote data, we will keep only the first 20 days, +# the `get_remote_demo_sample` function is only used to get the demo dataset; for your own case, give a list of identification filenames +# and don't mix cyclonic and anticyclonic files. +file_objects = get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" +)[:20] + +# %% +# We run a tracking with a tracker which uses contour overlap, on the first 10 time steps +c_first_run = Correspondances( + datasets=file_objects[:10], class_method=AreaTracker, virtual=4 +) +start_logger().setLevel("INFO") +c_first_run.track() +start_logger().setLevel("WARNING") +with Dataset("correspondances.nc", "w") as h: + c_first_run.to_netcdf(h) +# Next steps are done only to build the atlas and display it +c_first_run.prepare_merging() + +# We now have an eddy object +eddies_area_tracker = c_first_run.merge(raw_data=False) +eddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0 +eddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1) + +# %% +# Plot the first ten days +plot_eddy(eddies_area_tracker) + +# %% +# Restart from previous run +# ------------------------- +# We give all filenames, the new ones and the filenames from the previous run +c_second_run = Correspondances( + datasets=file_objects[:20], + # This parameter must be identical in each run + class_method=AreaTracker, + virtual=4, + # Previously saved correspondances + previous_correspondance="correspondances.nc", +) +start_logger().setLevel("INFO") +c_second_run.track() +start_logger().setLevel("WARNING") +c_second_run.prepare_merging() +# We now have another eddy object +eddies_area_tracker_extend = c_second_run.merge(raw_data=False) +eddies_area_tracker_extend.virtual[:] = eddies_area_tracker_extend.time == 0 +eddies_area_tracker_extend.filled_by_interpolation( + eddies_area_tracker_extend.virtual == 1 +) + + +# %% +# Plot with time extension +plot_eddy(eddies_area_tracker_extend) diff --git a/examples/08_tracking_manipulation/pet_one_track.py b/examples/08_tracking_manipulation/pet_one_track.py index 9f930281..a2536c34 100644 --- a/examples/08_tracking_manipulation/pet_one_track.py +++ b/examples/08_tracking_manipulation/pet_one_track.py @@ -2,8 +2,8 @@ One Track =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/08_tracking_manipulation/pet_select_track_across_area.py b/examples/08_tracking_manipulation/pet_select_track_across_area.py index b88f37e1..58184e1f 100644 --- a/examples/08_tracking_manipulation/pet_select_track_across_area.py +++ b/examples/08_tracking_manipulation/pet_select_track_across_area.py @@ -3,8 +3,8 @@ ============================ """ -import py_eddy_tracker_sample from matplotlib 
import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py b/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py index 81e57e59..b686fd67 100644 --- a/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py +++ b/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py @@ -9,9 +9,9 @@ """ import re -import py_eddy_tracker_sample from matplotlib.animation import FuncAnimation from numpy import arange +import py_eddy_tracker_sample from py_eddy_tracker.appli.gui import Anim from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_birth_and_death.py b/examples/10_tracking_diagnostics/pet_birth_and_death.py index d917efbd..b67993a2 100644 --- a/examples/10_tracking_diagnostics/pet_birth_and_death.py +++ b/examples/10_tracking_diagnostics/pet_birth_and_death.py @@ -5,8 +5,8 @@ Following figures are based on https://doi.org/10.1016/j.pocean.2011.01.002 """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_center_count.py b/examples/10_tracking_diagnostics/pet_center_count.py index 6d9fa417..77a4dcda 100644 --- a/examples/10_tracking_diagnostics/pet_center_count.py +++ b/examples/10_tracking_diagnostics/pet_center_count.py @@ -5,9 +5,9 @@ Do Geo stat with center and compare with frequency method show: :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_pixel_used.py` """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from matplotlib.colors import LogNorm +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations @@ -27,7 +27,7 @@ step = 0.125 bins = ((-10, 37, step), (30, 46, step)) kwargs_pcolormesh = dict( - cmap="terrain_r", vmin=0, vmax=2, factor=1 / (a.nb_days * step ** 2), name="count" + cmap="terrain_r", vmin=0, vmax=2, factor=1 / (a.nb_days * step**2), name="count" ) diff --git a/examples/10_tracking_diagnostics/pet_geographic_stats.py b/examples/10_tracking_diagnostics/pet_geographic_stats.py index d2a7e90d..a2e3f6b5 100644 --- a/examples/10_tracking_diagnostics/pet_geographic_stats.py +++ b/examples/10_tracking_diagnostics/pet_geographic_stats.py @@ -4,8 +4,8 @@ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_groups.py b/examples/10_tracking_diagnostics/pet_groups.py index f6e800ae..deedcc3f 100644 --- a/examples/10_tracking_diagnostics/pet_groups.py +++ b/examples/10_tracking_diagnostics/pet_groups.py @@ -3,9 +3,9 @@ =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, ones, percentile +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_histo.py b/examples/10_tracking_diagnostics/pet_histo.py index b2eff842..abf97c38 100644 --- a/examples/10_tracking_diagnostics/pet_histo.py +++ b/examples/10_tracking_diagnostics/pet_histo.py @@ -3,9 +3,9 @@ =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from 
numpy import arange +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_lifetime.py b/examples/10_tracking_diagnostics/pet_lifetime.py index 9f84e790..4e2500fd 100644 --- a/examples/10_tracking_diagnostics/pet_lifetime.py +++ b/examples/10_tracking_diagnostics/pet_lifetime.py @@ -3,9 +3,9 @@ =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, ones +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_pixel_used.py b/examples/10_tracking_diagnostics/pet_pixel_used.py index 3907ce19..75a826d6 100644 --- a/examples/10_tracking_diagnostics/pet_pixel_used.py +++ b/examples/10_tracking_diagnostics/pet_pixel_used.py @@ -5,9 +5,9 @@ Do Geo stat with frequency and compare with center count method: :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_center_count.py` """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from matplotlib.colors import LogNorm +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_propagation.py b/examples/10_tracking_diagnostics/pet_propagation.py index 6a65a212..e6bc6c1b 100644 --- a/examples/10_tracking_diagnostics/pet_propagation.py +++ b/examples/10_tracking_diagnostics/pet_propagation.py @@ -3,9 +3,9 @@ ===================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, ones +import py_eddy_tracker_sample from py_eddy_tracker.generic import cumsum_by_track from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/12_external_data/pet_drifter_loopers.py b/examples/12_external_data/pet_drifter_loopers.py index 92707906..5266db7b 100644 --- a/examples/12_external_data/pet_drifter_loopers.py +++ b/examples/12_external_data/pet_drifter_loopers.py @@ -8,10 +8,10 @@ import re -import numpy as np -import py_eddy_tracker_sample from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation +import numpy as np +import py_eddy_tracker_sample from py_eddy_tracker import data from py_eddy_tracker.appli.gui import Anim diff --git a/examples/14_generic_tools/pet_visvalingam.py b/examples/14_generic_tools/pet_visvalingam.py index f7b29c10..736e8852 100644 --- a/examples/14_generic_tools/pet_visvalingam.py +++ b/examples/14_generic_tools/pet_visvalingam.py @@ -2,8 +2,8 @@ Visvalingam algorithm ===================== """ -import matplotlib.animation as animation from matplotlib import pyplot as plt +import matplotlib.animation as animation from numba import njit from numpy import array, empty diff --git a/examples/16_network/pet_atlas.py b/examples/16_network/pet_atlas.py index 6927f169..48b374e2 100644 --- a/examples/16_network/pet_atlas.py +++ b/examples/16_network/pet_atlas.py @@ -129,7 +129,9 @@ def update_axes(ax, mappable=None): # Merging in networks longer than 10 days, with dead end remove (shorter than 10 observations) # -------------------------------------------------------------------------------------------- ax = start_axes("") -merger = n10.remove_dead_end(nobs=10).merging_event() +n10_ = n10.copy() +n10_.remove_dead_end(nobs=10) +merger = n10_.merging_event() g_10_merging = merger.grid_count(bins) m = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1) update_axes(ax, 
m).set_label("Pixel used in % of time") diff --git a/examples/16_network/pet_follow_particle.py b/examples/16_network/pet_follow_particle.py index dbe0753e..6815fb6e 100644 --- a/examples/16_network/pet_follow_particle.py +++ b/examples/16_network/pet_follow_particle.py @@ -5,8 +5,7 @@ """ import re -from matplotlib import colors -from matplotlib import pyplot as plt +from matplotlib import colors, pyplot as plt from matplotlib.animation import FuncAnimation from numpy import arange, meshgrid, ones, unique, zeros @@ -42,7 +41,8 @@ def save(self, *args, **kwargs): # %% n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651) n = n.extract_with_mask((n.time >= 20180) * (n.time <= 20269)) -n = n.remove_dead_end(nobs=0, ndays=10) +n.remove_dead_end(nobs=0, ndays=10) +n = n.remove_trash() n.numbering_segment() c = GridCollection.from_netcdf_cube( get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), @@ -96,11 +96,17 @@ def save(self, *args, **kwargs): a.txt.set_position((25, 31)) step = 0.25 -kw_p = dict(nb_step=2, time_step=86400 * step * 0.5, t_init=t_snapshot - 2 * step) +kw_p = dict( + nb_step=2, + time_step=86400 * step * 0.5, + t_init=t_snapshot - 2 * step, + u_name="u", + v_name="v", +) mappables = dict() -particules = c.advect(x, y, "u", "v", **kw_p) -filament = c.filament(x_f, y_f, "u", "v", **kw_p, filament_size=3) +particules = c.advect(x, y, **kw_p) +filament = c.filament(x_f, y_f, **kw_p, filament_size=3) kw = dict(ls="", marker=".", markersize=0.25) for k in index_: m = k == index diff --git a/examples/16_network/pet_group_anim.py b/examples/16_network/pet_group_anim.py index 047f5820..f2d439ed 100644 --- a/examples/16_network/pet_group_anim.py +++ b/examples/16_network/pet_group_anim.py @@ -2,9 +2,10 @@ Network group process ===================== """ +from datetime import datetime + # sphinx_gallery_thumbnail_number = 2 import re -from datetime import datetime from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation diff --git a/examples/16_network/pet_ioannou_2017_case.py b/examples/16_network/pet_ioannou_2017_case.py index b02b846a..56bec82e 100644 --- a/examples/16_network/pet_ioannou_2017_case.py +++ b/examples/16_network/pet_ioannou_2017_case.py @@ -6,12 +6,12 @@ We want to find the Ierapetra Eddy described above in a network demonstration run. 
""" +from datetime import datetime, timedelta + # %% import re -from datetime import datetime, timedelta -from matplotlib import colors -from matplotlib import pyplot as plt +from matplotlib import colors, pyplot as plt from matplotlib.animation import FuncAnimation from matplotlib.ticker import FuncFormatter from numpy import arange, array, pi, where diff --git a/examples/16_network/pet_relative.py b/examples/16_network/pet_relative.py index f5e8bc92..dd97b538 100644 --- a/examples/16_network/pet_relative.py +++ b/examples/16_network/pet_relative.py @@ -127,7 +127,9 @@ # Remove dead branch # ------------------ # Remove all tiny segments with less than N obs which didn't join two segments -n_clean = n.remove_dead_end(nobs=5, ndays=10) +n_clean = n.copy() +n_clean.remove_dead_end(nobs=5, ndays=10) +n_clean = n_clean.remove_trash() fig = plt.figure(figsize=(15, 12)) ax = fig.add_axes([0.04, 0.54, 0.90, 0.40]) ax.set_title(f"Original network ({n.infos()})") @@ -261,7 +263,9 @@ # -------------------- # Get a simplified network -n = n2.remove_dead_end(nobs=50, recursive=1) +n = n2.copy() +n.remove_dead_end(nobs=50, recursive=1) +n = n.remove_trash() n.numbering_segment() # %% # Only a map can be tricky to understand, with a timeline it's easier! diff --git a/examples/16_network/pet_replay_segmentation.py b/examples/16_network/pet_replay_segmentation.py index d6b4568b..d909af7f 100644 --- a/examples/16_network/pet_replay_segmentation.py +++ b/examples/16_network/pet_replay_segmentation.py @@ -147,9 +147,9 @@ def get_obs(dataset): ax = timeline_axes() n_.median_filter(15, "time", "latitude") -kw["s"] = (n_.radius_e * 1e-3) ** 2 / 30 ** 2 * 20 +kw["s"] = (n_.radius_e * 1e-3) ** 2 / 30**2 * 20 m = n_.scatter_timeline( - ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lon", method="all", + ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lon", method="all" ) ax.set_ylabel("Longitude") cb = update_axes(ax, m["scatter"]) @@ -163,7 +163,6 @@ def get_obs(dataset): for b0, b1 in [ (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007, 2008) ]: - ref, delta = datetime(1950, 1, 1), 20 b0_, b1_ = (b0 - ref).days, (b1 - ref).days ax = timeline_axes() diff --git a/examples/16_network/pet_segmentation_anim.py b/examples/16_network/pet_segmentation_anim.py index 58f71188..1fcb9ae1 100644 --- a/examples/16_network/pet_segmentation_anim.py +++ b/examples/16_network/pet_segmentation_anim.py @@ -96,7 +96,7 @@ def update(i_frame): indices_frames = INDICES[i_frame] mappable_CONTOUR.set_data( - e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames], + e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames] ) mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)]) return (mappable_tracks,) diff --git a/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb b/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb new file mode 100644 index 00000000..7fa04435 --- /dev/null +++ b/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Stastics on identification files\n\nSome statistics on raw identification without any tracking\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import numpy as np\nfrom matplotlib import pyplot as plt\nfrom matplotlib.dates import date2num\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.observations.observation import EddiesObservations\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load demo sample and take only first year.\n\nReplace by a list of filename to apply on your own dataset.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:365]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Merge all identification dataset in one object\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "all_a = EddiesObservations.concatenate(\n [EddiesObservations.load_file(i) for i in file_objects]\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define polygon bound\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0, x1, y0, y1 = 15, 20, 33, 38\nxs = np.array([[x0, x1, x1, x0, x0]], dtype=\"f8\")\nys = np.array([[y0, y0, y1, y1, y0]], dtype=\"f8\")\n# Polygon object is create to be usable by match function.\npolygon = dict(contour_lon_e=xs, contour_lat_e=ys, contour_lon_s=xs, contour_lat_s=ys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Geographic frequency of eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 0.125\nax = start_axes(\"\")\n# Count pixel used for each contour\ng_a = all_a.grid_count(bins=((-10, 37, step), (30, 46, step)), intern=True)\nm = g_a.display(\n ax, cmap=\"terrain_r\", vmin=0, vmax=0.75, factor=1 / all_a.nb_days, name=\"count\"\n)\nax.plot(polygon[\"contour_lon_e\"][0], polygon[\"contour_lat_e\"][0], \"r\")\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use match function to count number of eddies which intersect the polygon defined previously.\n`p1_area` option allow to get in c_e/c_s output, precentage of area occupy by eddies in the polygon.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i_e, j_e, c_e = all_a.match(polygon, p1_area=True, intern=False)\ni_s, j_s, c_s = all_a.match(polygon, p1_area=True, intern=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "dt = np.datetime64(\"1970-01-01\") - 
np.datetime64(\"1950-01-01\")\nkw_hist = dict(\n bins=date2num(np.arange(21900, 22300).astype(\"datetime64\") - dt), histtype=\"step\"\n)\n# translate julian day in datetime64\nt = all_a.time.astype(\"datetime64\") - dt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Count how many are in polygon\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(12, 6)).add_subplot(111)\nax.set_title(\"Different way to count eddies presence in a polygon\")\nax.set_ylabel(\"Count\")\nm = all_a.mask_from_polygons(((xs, ys),))\nax.hist(t[m], label=\"center in polygon\", **kw_hist)\nax.hist(t[i_s[c_s > 0]], label=\"intersect speed contour with polygon\", **kw_hist)\nax.hist(t[i_e[c_e > 0]], label=\"intersect extern contour with polygon\", **kw_hist)\nax.legend()\nax.set_xlim(np.datetime64(\"2010\"), np.datetime64(\"2011\"))\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Percent of are of interest occupy by eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(12, 6)).add_subplot(111)\nax.set_title(\"Percent of polygon occupy by an anticyclonic eddy\")\nax.set_ylabel(\"Percent of polygon\")\nax.hist(t[i_s[c_s > 0]], weights=c_s[c_s > 0] * 100.0, label=\"speed contour\", **kw_hist)\nax.hist(t[i_e[c_e > 0]], weights=c_e[c_e > 0] * 100.0, label=\"effective contour\", **kw_hist)\nax.legend(), ax.set_ylim(0, 25)\nax.set_xlim(np.datetime64(\"2010\"), np.datetime64(\"2011\"))\nax.grid()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb b/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb index 79d69b0d..90ee1722 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nGrid advection\n==============\n\nDummy advection which use only static geostrophic current, which didn't solve the complex circulation of the ocean.\n" + "\n# Grid advection\n\nDummy advection which use only static geostrophic current, which didn't solve the complex circulation of the ocean.\n" ] }, { @@ -98,7 +98,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Anim\n----\nParticles setup\n\n" + "## Anim\nParticles setup\n\n" ] }, { @@ -127,7 +127,7 @@ }, "outputs": [], "source": [ - "kwargs = dict(frames=arange(51), interval=100)\nkw_p = dict(nb_step=2, time_step=21600)\nframe_t = kw_p[\"nb_step\"] * kw_p[\"time_step\"] / 86400.0" + "kwargs = dict(frames=arange(51), interval=100)\nkw_p = dict(u_name=\"u\", v_name=\"v\", nb_step=2, time_step=21600)\nframe_t = kw_p[\"nb_step\"] * kw_p[\"time_step\"] / 86400.0" ] }, { @@ -152,7 +152,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Filament forward\n^^^^^^^^^^^^^^^^\nDraw 3 last position in one path for each particles.,\nit could be run 
backward with `backward=True` option in filament method\n\n" + "### Filament forward\nDraw 3 last position in one path for each particles.,\nit could be run backward with `backward=True` option in filament method\n\n" ] }, { @@ -163,14 +163,14 @@ }, "outputs": [], "source": [ - "p = g.filament(x, y, \"u\", \"v\", **kw_p, filament_size=3)\nfig, txt, l, t = anim_ax(lw=0.5)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" + "p = g.filament(x, y, **kw_p, filament_size=3)\nfig, txt, l, t = anim_ax(lw=0.5)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Particle forward\n^^^^^^^^^^^^^^^^^\nForward advection of particles\n\n" + "### Particle forward\nForward advection of particles\n\n" ] }, { @@ -181,7 +181,7 @@ }, "outputs": [], "source": [ - "p = g.advect(x, y, \"u\", \"v\", **kw_p)\nfig, txt, l, t = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" + "p = g.advect(x, y, **kw_p)\nfig, txt, l, t = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" ] }, { @@ -199,21 +199,21 @@ }, "outputs": [], "source": [ - "p = g.advect(x, y, \"u\", \"v\", **kw_p, backward=True)\nfig, txt, l, _ = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,))" + "p = g.advect(x, y, **kw_p, backward=True)\nfig, txt, l, _ = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Particles stat\n--------------\n\n" + "## Particles stat\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Time_step settings\n^^^^^^^^^^^^^^^^^^\nDummy experiment to test advection precision, we run particles 50 days forward and backward with different time step\nand we measure distance between new positions and original positions.\n\n" + "### Time_step settings\nDummy experiment to test advection precision, we run particles 50 days forward and backward with different time step\nand we measure distance between new positions and original positions.\n\n" ] }, { @@ -224,14 +224,14 @@ }, "outputs": [], "source": [ - "fig = plt.figure()\nax = fig.add_subplot(111)\nkw = dict(\n bins=arange(0, 50, 0.001),\n cumulative=True,\n weights=ones(x0.shape) / x0.shape[0] * 100.0,\n histtype=\"step\",\n)\nfor time_step in (10800, 21600, 43200, 86400):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(50 * 86400 / time_step), time_step=time_step)\n g.advect(x, y, \"u\", \"v\", **kw_advect).__next__()\n g.advect(x, y, \"u\", \"v\", **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"{86400. 
/ time_step:.0f} time step by day\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\"Distance after 50 days forward and 50 days backward\")\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than\")" + "fig = plt.figure()\nax = fig.add_subplot(111)\nkw = dict(\n bins=arange(0, 50, 0.001),\n cumulative=True,\n weights=ones(x0.shape) / x0.shape[0] * 100.0,\n histtype=\"step\",\n)\nfor time_step in (10800, 21600, 43200, 86400):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(50 * 86400 / time_step), time_step=time_step, u_name=\"u\", v_name=\"v\")\n g.advect(x, y, **kw_advect).__next__()\n g.advect(x, y, **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"{86400. / time_step:.0f} time step by day\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\"Distance after 50 days forward and 50 days backward\")\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Time duration\n^^^^^^^^^^^^^\nWe keep same time_step but change time duration\n\n" + "### Time duration\nWe keep same time_step but change time duration\n\n" ] }, { @@ -242,7 +242,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure()\nax = fig.add_subplot(111)\ntime_step = 10800\nfor duration in (5, 50, 100):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(duration * 86400 / time_step), time_step=time_step)\n g.advect(x, y, \"u\", \"v\", **kw_advect).__next__()\n g.advect(x, y, \"u\", \"v\", **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"Time duration {duration} days\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\n \"Distance after N days forward and N days backward\\nwith a time step of 1/8 days\"\n)\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than \")" + "fig = plt.figure()\nax = fig.add_subplot(111)\ntime_step = 10800\nfor duration in (5, 50, 100):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(duration * 86400 / time_step), time_step=time_step, u_name=\"u\", v_name=\"v\")\n g.advect(x, y, **kw_advect).__next__()\n g.advect(x, y, **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"Time duration {duration} days\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\n \"Distance after N days forward and N days backward\\nwith a time step of 1/8 days\"\n)\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than \")" ] } ], @@ -262,7 +262,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb b/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb index c4a4da84..cbe6de64 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb +++ 
b/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb
@@ -15,7 +15,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "\nLAVD experiment\n===============\n\nNaive method to reproduce LAVD(Lagrangian-Averaged Vorticity deviation) method with a static velocity field.\nIn the current example we didn't remove a mean vorticity.\n\nMethod are described here:\n\n - Abernathey, Ryan, and George Haller. \"Transport by Lagrangian Vortices in the Eastern Pacific\",\n Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021,\n https://doi.org/10.1175/JPO-D-17-0102.1\n - `Transport by Coherent Lagrangian Vortices`_,\n R. Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019,\n Talk a t the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop\n\n https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf\n"
+ "\n# LAVD experiment\n\nNaive method to reproduce the LAVD (Lagrangian-Averaged Vorticity Deviation) method with a static velocity field.\nIn the current example we didn't remove a mean vorticity.\n\nThe method is described here:\n\n - Abernathey, Ryan, and George Haller. \"Transport by Lagrangian Vortices in the Eastern Pacific\",\n Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021,\n https://doi.org/10.1175/JPO-D-17-0102.1\n - `Transport by Coherent Lagrangian Vortices`_,\n R. Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019,\n Talk at the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop\n\n https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf\n"
 ]
 },
@@ -55,7 +55,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "Data\n----\nTo compute vorticity ($\\omega$) we compute u/v field with a stencil and apply the following equation with stencil\nmethod :\n\n\\begin{align}\\omega = \\frac{\\partial v}{\\partial x} - \\frac{\\partial u}{\\partial y}\\end{align}\n\n"
+ "## Data\nTo compute the vorticity ($\\omega$) we compute the u/v field with a stencil and apply the following equation with the stencil\nmethod:\n\n\\begin{align}\\omega = \\frac{\\partial v}{\\partial x} - \\frac{\\partial u}{\\partial y}\\end{align}\n\n"
 ]
 },
@@ -91,7 +91,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "Particles\n---------\nParticles specification\n\n"
+ "## Particles\nParticles specification\n\n"
 ]
 },
@@ -102,14 +102,14 @@
 },
 "outputs": [],
 "source": [
- "step = 1 / 32\nx_g, y_g = arange(0, 36, step), arange(28, 46, step)\nx, y = meshgrid(x_g, y_g)\noriginal_shape = x.shape\nx, y = x.reshape(-1), y.reshape(-1)\nprint(f\"{len(x)} particles advected\")\n# A frame every 8h\nstep_by_day = 3\n# Compute step of advection every 4h\nnb_step = 2\nkw_p = dict(nb_step=nb_step, time_step=86400 / step_by_day / nb_step)\n# Start a generator which at each iteration return new position at next time step\nparticule = g.advect(x, y, \"u\", \"v\", **kw_p, rk4=True)"
+ "step = 1 / 32\nx_g, y_g = arange(0, 36, step), arange(28, 46, step)\nx, y = meshgrid(x_g, y_g)\noriginal_shape = x.shape\nx, y = x.reshape(-1), y.reshape(-1)\nprint(f\"{len(x)} particles advected\")\n# A frame every 8h\nstep_by_day = 3\n# Compute step of advection every 4h\nnb_step = 2\nkw_p = dict(nb_step=nb_step, time_step=86400 / step_by_day / nb_step, u_name=\"u\", v_name=\"v\")\n# Start a generator which at each iteration returns new positions at the next time step\nparticule = g.advect(x, y, **kw_p, rk4=True)"
 ]
 },
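These notebook hunks all carry the same signature change: the velocity components are now passed as `u_name`/`v_name` keyword arguments rather than positional `"u", "v"` strings. A hedged sketch of the updated call, assuming `g` is a grid dataset holding "u" and "v" variables and `x`, `y` are 1-D arrays of particle positions:

```python
# Sketch of the updated advection API (assumptions: `g` is a static grid
# dataset with "u"/"v" variables; x, y are 1-D particle position arrays).
kw_p = dict(u_name="u", v_name="v", nb_step=2, time_step=21600)
p = g.advect(x, y, **kw_p, rk4=True)  # generator yielding updated positions
x, y = next(p)                        # advance by nb_step sub-steps
```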
"source": [ - "LAVD\n----\n\n" + "## LAVD\n\n" ] }, { @@ -127,7 +127,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Anim\n^^^^\nMovie of LAVD integration at each integration time step.\n\n" + "### Anim\nMovie of LAVD integration at each integration time step.\n\n" ] }, { @@ -145,7 +145,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Final LAVD\n^^^^^^^^^^\n\n" + "### Final LAVD\n\n" ] }, { @@ -163,7 +163,7 @@ }, "outputs": [], "source": [ - "lavd = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"),\n datas=dict(lavd=lavd.T, lon=x_g, lat=y_g,),\n centered=True,\n)" + "lavd = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"), datas=dict(lavd=lavd.T, lon=x_g, lat=y_g), centered=True\n)" ] }, { @@ -201,7 +201,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb b/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb index 8ee136b3..6f52e750 100644 --- a/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb +++ b/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nFSLE experiment in med\n======================\n\nExample to build Finite Size Lyapunov Exponents, parameter values must be adapted for your case.\n\nExample use a method similar to `AVISO flse`_\n\n https://www.aviso.altimetry.fr/en/data/products/value-added-products/\n fsle-finite-size-lyapunov-exponents/fsle-description.html\n" + "\n# FSLE experiment in med\n\nExample to build Finite Size Lyapunov Exponents, parameter values must be adapted for your case.\n\nExample use a method similar to `AVISO flse`_\n\n https://www.aviso.altimetry.fr/en/data/products/value-added-products/\n fsle-finite-size-lyapunov-exponents/fsle-description.html\n" ] }, { @@ -33,7 +33,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "ADT in med\n----------\n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_cube` method is\nmade for data stores in time cube, you could use also \n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` method to\nload data-cube from multiple file.\n\n" + "## ADT in med\n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_cube` method is\nmade for data stores in time cube, you could use also\n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` method to\nload data-cube from multiple file.\n\n" ] }, { @@ -51,7 +51,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Methods to compute FSLE\n-----------------------\n\n" + "## Methods to compute FSLE\n\n" ] }, { @@ -62,14 +62,14 @@ }, "outputs": [], "source": [ - "@njit(cache=True, fastmath=True)\ndef check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6):\n \"\"\"\n Check if distance between eastern or northern particle to center particle is bigger than `dist_max`\n \"\"\"\n nb_p = x.shape[0] // 3\n delta = dist_max ** 2\n for i in range(nb_p):\n i0 = i * 3\n i_n = i0 + 1\n i_e = i0 + 2\n # If particle already set, we skip\n if m[i0] or m[i_n] or m[i_e]:\n continue\n # Distance with north\n dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n]\n dn = dxn ** 2 + dyn ** 2\n # Distance with east\n dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e]\n de = dxe ** 2 + dye ** 2\n\n if dn >= delta or de >= delta:\n s1 = dn + de\n at1 = 2 * (dxe * dxn + dye * dyn)\n at2 = de - 
dn\n s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * (\n (dxn - dye) ** 2 + (dxe + dyn) ** 2\n )\n flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init ** 2) * (s1 + s2 ** 0.5))\n theta[i] = arctan2(at1, at2 + s2) * 180 / pi\n # To know where value are set\n m_set[i] = False\n # To stop particle advection\n m[i0], m[i_n], m[i_e] = True, True, True\n\n\n@njit(cache=True)\ndef build_triplet(x, y, step=0.02):\n \"\"\"\n Triplet building for each position we add east and north point with defined step\n \"\"\"\n nb_x = x.shape[0]\n x_ = empty(nb_x * 3, dtype=x.dtype)\n y_ = empty(nb_x * 3, dtype=y.dtype)\n for i in range(nb_x):\n i0 = i * 3\n i_n, i_e = i0 + 1, i0 + 2\n x__, y__ = x[i], y[i]\n x_[i0], y_[i0] = x__, y__\n x_[i_n], y_[i_n] = x__, y__ + step\n x_[i_e], y_[i_e] = x__ + step, y__\n return x_, y_" + "@njit(cache=True, fastmath=True)\ndef check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6):\n \"\"\"\n Check if distance between eastern or northern particle to center particle is bigger than `dist_max`\n \"\"\"\n nb_p = x.shape[0] // 3\n delta = dist_max**2\n for i in range(nb_p):\n i0 = i * 3\n i_n = i0 + 1\n i_e = i0 + 2\n # If particle already set, we skip\n if m[i0] or m[i_n] or m[i_e]:\n continue\n # Distance with north\n dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n]\n dn = dxn**2 + dyn**2\n # Distance with east\n dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e]\n de = dxe**2 + dye**2\n\n if dn >= delta or de >= delta:\n s1 = dn + de\n at1 = 2 * (dxe * dxn + dye * dyn)\n at2 = de - dn\n s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * (\n (dxn - dye) ** 2 + (dxe + dyn) ** 2\n )\n flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init**2) * (s1 + s2**0.5))\n theta[i] = arctan2(at1, at2 + s2) * 180 / pi\n # To know where value are set\n m_set[i] = False\n # To stop particle advection\n m[i0], m[i_n], m[i_e] = True, True, True\n\n\n@njit(cache=True)\ndef build_triplet(x, y, step=0.02):\n \"\"\"\n Triplet building for each position we add east and north point with defined step\n \"\"\"\n nb_x = x.shape[0]\n x_ = empty(nb_x * 3, dtype=x.dtype)\n y_ = empty(nb_x * 3, dtype=y.dtype)\n for i in range(nb_x):\n i0 = i * 3\n i_n, i_e = i0 + 1, i0 + 2\n x__, y__ = x[i], y[i]\n x_[i0], y_[i0] = x__, y__\n x_[i_n], y_[i_n] = x__, y__ + step\n x_[i_e], y_[i_e] = x__ + step, y__\n return x_, y_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Settings\n--------\n\n" + "## Settings\n\n" ] }, { @@ -87,7 +87,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Particles\n---------\n\n" + "## Particles\n\n" ] }, { @@ -105,7 +105,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "FSLE\n----\n\n" + "## FSLE\n\n" ] }, { @@ -116,14 +116,14 @@ }, "outputs": [], "source": [ - "# Array to compute fsle\nfsle = zeros(x0.shape[0], dtype=\"f4\")\ntheta = zeros(x0.shape[0], dtype=\"f4\")\nmask = ones(x0.shape[0], dtype=\"f4\")\nx, y = build_triplet(x0, y0, dist_init)\nused = zeros(x.shape[0], dtype=\"bool\")\n\n# advection generator\nkw = dict(t_init=t0, nb_step=1, backward=backward, mask_particule=used)\np = c.advect(x, y, \"u\", \"v\", time_step=86400 / time_step_by_days, **kw)\n\n# We check at each step of advection if particle distance is over `dist_max`\nfor i in range(time_step_by_days * nb_days):\n t, xt, yt = p.__next__()\n dt = t / 86400.0 - t0\n check_p(xt, yt, fsle, theta, mask, used, dt, dist_max=dist_max, dist_init=dist_init)\n\n# Get index with original_position\ni = ((x0 - x0_) / step_grid_out).astype(\"i4\")\nj = ((y0 - y0_) / step_grid_out).astype(\"i4\")\nfsle_ = 
empty(grid_shape, dtype=\"f4\")\ntheta_ = empty(grid_shape, dtype=\"f4\")\nmask_ = ones(grid_shape, dtype=\"bool\")\nfsle_[i, j] = fsle\ntheta_[i, j] = theta\nmask_[i, j] = mask\n# Create a grid object\nfsle_custom = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"),\n datas=dict(\n fsle=ma.array(fsle_, mask=mask_),\n theta=ma.array(theta_, mask=mask_),\n lon=lon_p,\n lat=lat_p,\n ),\n centered=True,\n)" + "# Array to compute fsle\nfsle = zeros(x0.shape[0], dtype=\"f4\")\ntheta = zeros(x0.shape[0], dtype=\"f4\")\nmask = ones(x0.shape[0], dtype=\"f4\")\nx, y = build_triplet(x0, y0, dist_init)\nused = zeros(x.shape[0], dtype=\"bool\")\n\n# advection generator\nkw = dict(t_init=t0, nb_step=1, backward=backward, mask_particule=used, u_name=\"u\", v_name=\"v\")\np = c.advect(x, y, time_step=86400 / time_step_by_days, **kw)\n\n# We check at each step of advection if particle distance is over `dist_max`\nfor i in range(time_step_by_days * nb_days):\n t, xt, yt = p.__next__()\n dt = t / 86400.0 - t0\n check_p(xt, yt, fsle, theta, mask, used, dt, dist_max=dist_max, dist_init=dist_init)\n\n# Get index with original_position\ni = ((x0 - x0_) / step_grid_out).astype(\"i4\")\nj = ((y0 - y0_) / step_grid_out).astype(\"i4\")\nfsle_ = empty(grid_shape, dtype=\"f4\")\ntheta_ = empty(grid_shape, dtype=\"f4\")\nmask_ = ones(grid_shape, dtype=\"bool\")\nfsle_[i, j] = fsle\ntheta_[i, j] = theta\nmask_[i, j] = mask\n# Create a grid object\nfsle_custom = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"),\n datas=dict(\n fsle=ma.array(fsle_, mask=mask_),\n theta=ma.array(theta_, mask=mask_),\n lon=lon_p,\n lat=lat_p,\n ),\n centered=True,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Display FSLE\n------------\n\n" + "## Display FSLE\n\n" ] }, { @@ -141,7 +141,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Display Theta\n-------------\n\n" + "## Display Theta\n\n" ] }, { @@ -172,7 +172,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb b/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb index bd197c57..708d7024 100644 --- a/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb +++ b/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb @@ -84,7 +84,7 @@ }, "outputs": [], "source": [ - "# Time properties, for example with advection only 25 days\nnb_days, step_by_day = 25, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day)\nt0 = 20236\nt0_grid = c[t0]\n# Geographic properties, we use a coarser resolution for time consuming reasons\nstep = 1 / 32.0\nx_g, y_g = arange(-6, 36, step), arange(30, 46, step)\nx0, y0 = meshgrid(x_g, y_g)\noriginal_shape = x0.shape\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\n# Get all particles in defined area\nm = ~isnan(t0_grid.interp(\"vort\", x0, y0))\nx0, y0 = x0[m], y0[m]\nprint(f\"{x0.size} particles advected\")\n# Gridded mask\nm = m.reshape(original_shape)" + "# Time properties, for example with advection only 25 days\nnb_days, step_by_day = 25, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day, u_name=\"u\", v_name=\"v\")\nt0 = 20236\nt0_grid = c[t0]\n# Geographic properties, we use a coarser resolution for time consuming reasons\nstep = 1 / 32.0\nx_g, y_g = arange(-6, 36, step), arange(30, 46, step)\nx0, y0 = 
meshgrid(x_g, y_g)\noriginal_shape = x0.shape\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\n# Get all particles in defined area\nm = ~isnan(t0_grid.interp(\"vort\", x0, y0))\nx0, y0 = x0[m], y0[m]\nprint(f\"{x0.size} particles advected\")\n# Gridded mask\nm = m.reshape(original_shape)" ] }, { @@ -102,7 +102,7 @@ }, "outputs": [], "source": [ - "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), \"u\", \"v\", t_init=t0, **kw_p)\nfor _ in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection\")\nmappable = lavd_forward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), t_init=t0, **kw_p)\nfor _ in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection\")\nmappable = lavd_forward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" ] }, { @@ -120,7 +120,7 @@ }, "outputs": [], "source": [ - "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), \"u\", \"v\", t_init=t0, backward=True, **kw_p)\nfor i in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection\")\nmappable = lavd_backward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), t_init=t0, backward=True, **kw_p)\nfor i in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection\")\nmappable = lavd_backward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" ] }, { @@ -138,7 +138,7 @@ }, "outputs": [], "source": [ - "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), \"u\", \"v\", **kw_p)\nfor _ in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection on a static velocity field\")\nmappable = lavd_forward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), **kw_p)\nfor _ in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward 
advection on a static velocity field\")\nmappable = lavd_forward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" ] }, { @@ -156,7 +156,7 @@ }, "outputs": [], "source": [ - "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), \"u\", \"v\", backward=True, **kw_p)\nfor i in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection on a static velocity field\")\nmappable = lavd_backward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), backward=True, **kw_p)\nfor i in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection on a static velocity field\")\nmappable = lavd_backward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" ] }, { @@ -194,7 +194,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.9" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb b/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb new file mode 100644 index 00000000..b92c4d21 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb @@ -0,0 +1,126 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Build path of particle drifting\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import arange, meshgrid\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load data cube\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n unset=True\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Advection properties\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "nb_days, step_by_day = 10, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day)\nt0 = 20210" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get paths\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0, y0 = meshgrid(arange(32, 35, 0.5), arange(32.5, 34.5, 0.5))\nx0, y0 = 
x0.reshape(-1), y0.reshape(-1)\nt, x, y = c.path(x0, y0, h_name=\"adt\", t_init=t0, **kw_p, nb_time=nb_time)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot paths\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "ax = plt.figure(figsize=(9, 6)).add_subplot(111, aspect=\"equal\")\nax.plot(x0, y0, \"k.\", ms=20)\nax.plot(x, y, lw=3)\nax.set_title(\"10 days particle paths\")\nax.set_xlim(31, 35), ax.set_ylim(32, 34.5)\nax.grid()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb
new file mode 100644
index 00000000..0681c0fc
--- /dev/null
+++ b/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb
@@ -0,0 +1,155 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Correspondances\n\nCorrespondances is a mechanism intended to continue tracking with new detections\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "import logging"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "from matplotlib import pyplot as plt\nfrom netCDF4 import Dataset\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.featured_tracking.area_tracker import AreaTracker\n\n# In order to hide some warnings\nimport py_eddy_tracker.observations.observation\nfrom py_eddy_tracker.tracking import Correspondances\n\npy_eddy_tracker.observations.observation._display_check_warning = False"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "def plot_eddy(ed):\n fig = plt.figure(figsize=(10, 5))\n ax = fig.add_axes([0.05, 0.03, 0.90, 0.94])\n ed.plot(ax, ref=-10, marker=\"x\")\n lc = ed.display_color(ax, field=ed.time, ref=-10, intern=True)\n plt.colorbar(lc).set_label(\"Time in Julian days (from 1950/01/01)\")\n ax.set_xlim(4.5, 8), ax.set_ylim(36.8, 38.3)\n ax.set_aspect(\"equal\")\n ax.grid()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Get remote data; we will keep only the first 20 days.\nThe `get_remote_demo_sample` function is only there to get the demo dataset; in your own case give a list of identification filenames\nand don't mix cyclonic and anticyclonic files.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:20]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We run a tracking with a tracker which uses contour overlap, on the first 10 time steps\n\n"
+ ]
+ },
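The cells that follow exercise the save/restart mechanism; condensed into one hedged sketch (names mirror this notebook, and `file_objects` is assumed to be a list of identification filenames):

```python
# Condensed sketch of the correspondance save/restart workflow walked
# through in the cells below.
from netCDF4 import Dataset
from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker
from py_eddy_tracker.tracking import Correspondances

c = Correspondances(datasets=file_objects[:10], class_method=AreaTracker, virtual=4)
c.track()
with Dataset("correspondances.nc", "w") as h:
    c.to_netcdf(h)  # save the tracking state to resume later

c2 = Correspondances(
    datasets=file_objects[:20],     # old files plus the new detections
    class_method=AreaTracker,       # must match the first run
    virtual=4,
    previous_correspondance="correspondances.nc",
)
c2.track()  # continues instead of starting from scratch
```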
"execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c_first_run = Correspondances(\n datasets=file_objects[:10], class_method=AreaTracker, virtual=4\n)\nstart_logger().setLevel(\"INFO\")\nc_first_run.track()\nstart_logger().setLevel(\"WARNING\")\nwith Dataset(\"correspondances.nc\", \"w\") as h:\n c_first_run.to_netcdf(h)\n# Next step are done only to build atlas and display it\nc_first_run.prepare_merging()\n\n# We have now an eddy object\neddies_area_tracker = c_first_run.merge(raw_data=False)\neddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0\neddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot from first ten days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "plot_eddy(eddies_area_tracker)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Restart from previous run\nWe give all filenames, the new one and filename from previous run\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c_second_run = Correspondances(\n datasets=file_objects[:20],\n # This parameter must be identical in each run\n class_method=AreaTracker,\n virtual=4,\n # Previous saved correspondancs\n previous_correspondance=\"correspondances.nc\",\n)\nstart_logger().setLevel(\"INFO\")\nc_second_run.track()\nstart_logger().setLevel(\"WARNING\")\nc_second_run.prepare_merging()\n# We have now another eddy object\neddies_area_tracker_extend = c_second_run.merge(raw_data=False)\neddies_area_tracker_extend.virtual[:] = eddies_area_tracker_extend.time == 0\neddies_area_tracker_extend.filled_by_interpolation(\n eddies_area_tracker_extend.virtual == 1\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot with time extension\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "plot_eddy(eddies_area_tracker_extend)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 477cf32d..556cabbf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -matplotlib -netCDF4 -numba>=0.53 -numpy<1.21 +matplotlib < 3.8 # need an update of contour management opencv-python pint polygon3 pyyaml requests scipy -zarr +zarr < 3.0 +netCDF4 +numpy +numba \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index dfed5c3b..7e773ae8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,29 @@ + +[yapf] +column_limit = 100 + [flake8] +max-line-length = 140 ignore = - E203, # whitespace before ':' - W503, # line break before binary operator + E203, + W503, +exclude= + build + doc + versioneer.py + +[isort] +combine_as_imports=True +force_grid_wrap=0 +force_sort_within_sections=True +force_to_top=typing +include_trailing_comma=True +line_length=140 +multi_line_output=3 +skip= + build + 
doc/conf.py + [versioneer] VCS = git @@ -13,4 +35,5 @@ parentdir_prefix = [tool:pytest] filterwarnings= - ignore:tostring.*is deprecated \ No newline at end of file + ignore:tostring.*is deprecated + diff --git a/setup.py b/setup.py index 06432bd1..7b836763 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- +from setuptools import find_packages, setup + import versioneer -from setuptools import setup, find_packages with open("README.md", "r") as fh: long_description = fh.read() @@ -9,7 +10,7 @@ setup( name="pyEddyTracker", - python_requires=">=3.7", + python_requires=">=3.10", version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), description="Py-Eddy-Tracker libraries", @@ -48,6 +49,7 @@ "EddyNetworkGroup = py_eddy_tracker.appli.network:build_network", "EddyNetworkBuildPath = py_eddy_tracker.appli.network:divide_network", "EddyNetworkSubSetter = py_eddy_tracker.appli.network:subset_network", + "EddyNetworkQuickCompare = py_eddy_tracker.appli.network:quick_compare", # anim/gui "EddyAnim = py_eddy_tracker.appli.gui:anim", "GUIEddy = py_eddy_tracker.appli.gui:guieddy", diff --git a/share/fig.py b/share/fig.py index 8640abcb..80c7f12b 100644 --- a/share/fig.py +++ b/share/fig.py @@ -1,8 +1,10 @@ -from matplotlib import pyplot as plt -from py_eddy_tracker.dataset.grid import RegularGridDataset from datetime import datetime import logging +from matplotlib import pyplot as plt + +from py_eddy_tracker.dataset.grid import RegularGridDataset + grid_name, lon_name, lat_name = ( "nrt_global_allsat_phy_l4_20190223_20190226.nc", "longitude", diff --git a/src/py_eddy_tracker/__init__.py b/src/py_eddy_tracker/__init__.py index 275bb795..7115bf67 100644 --- a/src/py_eddy_tracker/__init__.py +++ b/src/py_eddy_tracker/__init__.py @@ -20,9 +20,9 @@ """ -import logging from argparse import ArgumentParser from datetime import datetime +import logging import zarr @@ -32,13 +32,13 @@ del get_versions -def start_logger(): +def start_logger(color=True): FORMAT_LOG = "%(levelname)-8s %(asctime)s %(module)s.%(funcName)s :\n\t%(message)s" logger = logging.getLogger("pet") if len(logger.handlers) == 0: # set up logging to CONSOLE console = logging.StreamHandler() - console.setFormatter(ColoredFormatter(FORMAT_LOG)) + console.setFormatter(ColoredFormatter(FORMAT_LOG, color=color)) # add the handler to the root logger logger.addHandler(console) return logger @@ -53,13 +53,14 @@ class ColoredFormatter(logging.Formatter): DEBUG="\033[34m\t", ) - def __init__(self, message): + def __init__(self, message, color=True): super().__init__(message) + self.with_color = color def format(self, record): color = self.COLOR_LEVEL.get(record.levelname, "") color_reset = "\033[0m" - model = color + "%s" + color_reset + model = (color + "%s" + color_reset) if self.with_color else "%s" record.msg = model % record.msg record.funcName = model % record.funcName record.module = model % record.module @@ -422,14 +423,20 @@ def identify_time(str_date): nc_name="previous_cost", nc_type="float32", nc_dims=("obs",), - nc_attr=dict(long_name="Previous cost for previous observation", comment="",), + nc_attr=dict( + long_name="Previous cost for previous observation", + comment="", + ), ), next_cost=dict( attr_name=None, nc_name="next_cost", nc_type="float32", nc_dims=("obs",), - nc_attr=dict(long_name="Next cost for next observation", comment="",), + nc_attr=dict( + long_name="Next cost for next observation", + comment="", + ), ), n=dict( attr_name=None, @@ -640,7 +647,8 @@ def identify_time(str_date): 
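One user-facing change in the `src/py_eddy_tracker/__init__.py` hunks above is the new `color` flag on `start_logger`, threaded down to `ColoredFormatter`. A short usage sketch, assuming the behavior implied by the diff (with `color=False` the formatter skips the ANSI escape wrapping):

```python
from py_eddy_tracker import start_logger

# color=False keeps records free of "\033[...m" escapes, which is what you
# want when logs go to a file or any non-TTY stream.
logger = start_logger(color=False)
logger.setLevel("INFO")
logger.info("plain, uncolored log record")
```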
nc_type="f4", nc_dims=("obs",), nc_attr=dict( - long_name="Log base 10 background chlorophyll", units="Log(Chl/[mg/m^3])", + long_name="Log base 10 background chlorophyll", + units="Log(Chl/[mg/m^3])", ), ), year=dict( @@ -689,3 +697,6 @@ def identify_time(str_date): VAR_DESCR_inv[VAR_DESCR[key]["nc_name"]] = key for key_old in VAR_DESCR[key].get("old_nc_name", list()): VAR_DESCR_inv[key_old] = key + +from . import _version +__version__ = _version.get_versions()['version'] diff --git a/src/py_eddy_tracker/_version.py b/src/py_eddy_tracker/_version.py index 44367e3a..589e706f 100644 --- a/src/py_eddy_tracker/_version.py +++ b/src/py_eddy_tracker/_version.py @@ -1,11 +1,13 @@ + # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -14,9 +16,11 @@ import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -32,8 +36,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool + -def get_config(): +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -51,41 +62,50 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} - +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f - return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + 
popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -96,18 +116,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -116,64 +138,64 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
- keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -186,11 +208,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -199,7 +221,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -207,30 +229,33 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -241,7 +266,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -249,33 +282,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -284,16 +341,17 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] + git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out + # unparsable. Maybe git-describe is misbehaving? 
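This error branch sits in the middle of the describe parsing, which is compact enough to restate: `git describe` output has the form `TAG-NUM-gHEX[-dirty]`, and TAG itself may contain hyphens, hence the greedy `(.+)` in the regexp. A toy parser with the same logic:

```python
import re

def parse_describe(out):
    """Split 'git describe' output of the form TAG-NUM-gHEX[-dirty]."""
    pieces = {"dirty": out.endswith("-dirty")}
    if pieces["dirty"]:
        out = out[: out.rindex("-dirty")]
    mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", out)
    if mo:
        pieces["closest-tag"] = mo.group(1)
        pieces["distance"] = int(mo.group(2))  # commits since the tag
        pieces["short"] = mo.group(3)          # abbreviated commit hash
    return pieces

print(parse_describe("v3.6.2-14-g1a2b3c4-dirty"))
# {'dirty': True, 'closest-tag': 'v3.6.2', 'distance': 14, 'short': '1a2b3c4'}
```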
+ pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) return pieces # tag @@ -302,12 +360,10 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] + pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -318,26 +374,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -355,29 +412,78 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). 
+ """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -404,12 +510,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -426,7 +561,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -446,7 +581,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
@@ -466,26 +601,28 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -495,16 +632,12 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -515,7 +648,8 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) except NotThisMethod: pass @@ -524,16 +658,13 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
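+        # e.g. a versionfile_source of "src/pkg/_version.py" applies
+        # dirname() three times to step from this file up to the project root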
- for i in cfg.versionfile_source.split("/"): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -547,10 +678,6 @@ def get_versions(): except NotThisMethod: pass - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} diff --git a/src/py_eddy_tracker/appli/eddies.py b/src/py_eddy_tracker/appli/eddies.py index 4809fddf..c1c7a90d 100644 --- a/src/py_eddy_tracker/appli/eddies.py +++ b/src/py_eddy_tracker/appli/eddies.py @@ -3,12 +3,11 @@ Applications on detection and tracking files """ import argparse -import logging from datetime import datetime from glob import glob +import logging from os import mkdir -from os.path import basename, dirname, exists -from os.path import join as join_path +from os.path import basename, dirname, exists, join as join_path from re import compile as re_compile from netCDF4 import Dataset @@ -243,7 +242,8 @@ def browse_dataset_in( filenames = bytes_(glob(full_path)) dataset_list = empty( - len(filenames), dtype=[("filename", "S500"), ("date", "datetime64[s]")], + len(filenames), + dtype=[("filename", "S500"), ("date", "datetime64[s]")], ) dataset_list["filename"] = filenames @@ -371,7 +371,8 @@ def track( logger.info("Longer track saved have %d obs", c.nb_obs_by_tracks.max()) logger.info( - "The mean length is %d observations for long track", c.nb_obs_by_tracks.mean(), + "The mean length is %d observations for long track", + c.nb_obs_by_tracks.mean(), ) long_track.write_file(**kw_write) @@ -381,7 +382,14 @@ def track( def get_group( - dataset1, dataset2, index1, index2, score, invalid=2, low=10, high=60, + dataset1, + dataset2, + index1, + index2, + score, + invalid=2, + low=10, + high=60, ): group1, group2 = dict(), dict() m_valid = (score * 100) >= invalid @@ -490,7 +498,8 @@ def get_values(v, dataset): ] labels = dict( - high=f"{high:0.0f} <= high", low=f"{invalid:0.0f} <= low < {low:0.0f}", + high=f"{high:0.0f} <= high", + low=f"{invalid:0.0f} <= low < {low:0.0f}", ) keys = [labels.get(key, key) for key in list(gr_ref.values())[0].keys()] diff --git a/src/py_eddy_tracker/appli/gui.py b/src/py_eddy_tracker/appli/gui.py index 427db24b..c3d7619b 100644 --- a/src/py_eddy_tracker/appli/gui.py +++ b/src/py_eddy_tracker/appli/gui.py @@ -3,15 +3,15 @@ Entry point of graphic user interface """ -import logging from datetime import datetime, timedelta from itertools import chain +import logging from matplotlib import pyplot from matplotlib.animation import FuncAnimation from matplotlib.axes import Axes from matplotlib.collections import LineCollection -from numpy import arange, where +from numpy import arange, where, nan from .. 
import EddyParser
 from ..gui import GUI

@@ -58,7 +58,10 @@ def setup(
         self.kw_label["fontweight"] = kwargs.pop("fontweight", "demibold")
         # To text each visible eddy
         if field_txt:
-            self.field_txt = self.eddy[field_txt]
+            if isinstance(field_txt, str):
+                self.field_txt = self.eddy[field_txt]
+            else:
+                self.field_txt = field_txt
         if field_color:
             # To color each visible eddy
             self.field_color = self.eddy[field_color].astype("f4")
diff --git a/src/py_eddy_tracker/appli/network.py b/src/py_eddy_tracker/appli/network.py
index e9baa7be..0a3d06ca 100644
--- a/src/py_eddy_tracker/appli/network.py
+++ b/src/py_eddy_tracker/appli/network.py
@@ -5,6 +5,8 @@

 import logging

+from numpy import in1d, zeros
+
 from .. import EddyParser
 from ..observations.network import Network, NetworkObservations
 from ..observations.tracking import TrackEddiesObservations
@@ -34,6 +36,11 @@ def build_network():
         action="store_true",
         help="If True, use intersection/little polygon, else intersection/union",
     )
+    parser.add_argument(
+        "--hybrid-area",
+        action="store_true",
+        help="If True, use the minimal-area method when overlap is below min overlap, else intersection/union",
+    )

     parser.contour_intern_arg()

@@ -47,7 +54,9 @@ def build_network():
         memory=args.memory,
     )
     group = n.group_observations(
-        min_overlap=args.min_overlap, minimal_area=args.minimal_area
+        min_overlap=args.min_overlap,
+        minimal_area=args.minimal_area,
+        hybrid_area=args.hybrid_area,
     )
     n.build_dataset(group).write_file(filename=args.out)

@@ -72,6 +81,11 @@ def divide_network():
         action="store_true",
         help="If True, use intersection/little polygon, else intersection/union",
     )
+    parser.add_argument(
+        "--hybrid-area",
+        action="store_true",
+        help="If True, use the minimal-area method when overlap is below min overlap, else intersection/union",
+    )
     args = parser.parse_args()
     contour_name = TrackEddiesObservations.intern(args.intern, public_label=True)
     e = TrackEddiesObservations.load_file(
@@ -85,6 +99,7 @@ def divide_network():
             window=args.window,
             min_overlap=args.min_overlap,
             minimal_area=args.minimal_area,
+            hybrid_area=args.hybrid_area,
         ),
     )
     n.write_file(filename=args.out)
@@ -109,7 +124,12 @@ def subset_network():
         help="Remove short dead end, first is for minimal obs number and second for minimal segment time to keep",
     )
     parser.add_argument(
-        "--remove_trash", action="store_true", help="Remove trash (network id == 0)",
+        "--remove_trash",
+        action="store_true",
+        help="Remove trash (network id == 0)",
+    )
+    parser.add_argument(
+        "-i", "--ids", nargs="+", type=int, help="List of network ids to extract"
     )
     parser.add_argument(
         "-p",
@@ -121,6 +141,8 @@ def subset_network():
     )
     args = parser.parse_args()
     n = NetworkObservations.load_file(args.input, raw_data=True)
+    if args.ids is not None:
+        n = n.networks(args.ids)
     if args.length is not None:
         n = n.longer_than(*args.length)
     if args.remove_dead_end is not None:
@@ -128,3 +150,153 @@ def subset_network():
     if args.period is not None:
         n = n.extract_with_period(args.period)
     n.write_file(filename=args.out)
+
+
+def quick_compare():
+    parser = EddyParser(
+        """Tool for a quick comparison between several networks:
+    - N : network
+    - S : segment
+    - Obs : observations
+    """
+    )
+    parser.add_argument("ref", help="Identification file of reference")
+    parser.add_argument("others", nargs="+", help="Identification files to compare")
+    parser.add_argument(
+        "--path_out", default=None, help="Save each group in a separate file"
+    )
+    args = parser.parse_args()
+
+    kw = dict(
+        include_vars=[
+            "longitude",
+            "latitude",
+            "time",
+            "track",
+            "segment",
+            "next_obs",
+            "previous_obs",
+        ]
+    )
+
+    if args.path_out is not None:
+        kw = dict()
+
+    ref = NetworkObservations.load_file(args.ref, **kw)
+    print(
+        f"[ref] {args.ref} -> {ref.nb_network} network / {ref.nb_segment} segment / {len(ref)} obs "
+        f"-> {ref.network_size(0)} trash obs, "
+        f"{len(ref.merging_event())} merging, {len(ref.splitting_event())} splitting"
+    )
+    others = {
+        other: NetworkObservations.load_file(other, **kw) for other in args.others
+    }
+
+    # if args.path_out is not None:
+    #     groups_ref, groups_other = run_compare(ref, others, **kwargs)
+    #     if not exists(args.path_out):
+    #         mkdir(args.path_out)
+    #     for i, other_ in enumerate(args.others):
+    #         dirname_ = f"{args.path_out}/{other_.replace('/', '_')}/"
+    #         if not exists(dirname_):
+    #             mkdir(dirname_)
+    #         for k, v in groups_other[other_].items():
+    #             basename_ = f"other_{k}.nc"
+    #             others[other_].index(v).write_file(filename=f"{dirname_}/{basename_}")
+    #         for k, v in groups_ref[other_].items():
+    #             basename_ = f"ref_{k}.nc"
+    #             ref.index(v).write_file(filename=f"{dirname_}/{basename_}")
+    #     return
+    display_compare(ref, others)
+
+
+def run_compare(ref, others):
+    outs = dict()
+    for i, (k, other) in enumerate(others.items()):
+        out = dict()
+        print(
+            f"[{i}] {k} -> {other.nb_network} network / {other.nb_segment} segment / {len(other)} obs "
+            f"-> {other.network_size(0)} trash obs, "
+            f"{len(other.merging_event())} merging, {len(other.splitting_event())} splitting"
+        )
+        ref_id, other_id = ref.identify_in(other, size_min=2)
+        m = other_id != -1
+        ref_id, other_id = ref_id[m], other_id[m]
+        out["same N(N)"] = m.sum()
+        out["same N(Obs)"] = ref.network_size(ref_id).sum()
+
+        # For network which have same obs
+        ref_, other_ = ref.networks(ref_id), other.networks(other_id)
+        ref_segu, other_segu = ref_.identify_in(other_, segment=True)
+        m = other_segu == -1
+        ref_track_no_match, _ = ref_.unique_segment_to_id(ref_segu[m])
+        ref_segu, other_segu = ref_segu[~m], other_segu[~m]
+        m = ~in1d(ref_id, ref_track_no_match)
+        out["same NS(N)"] = m.sum()
+        out["same NS(Obs)"] = ref.network_size(ref_id[m]).sum()

+        # Check merge/split
+        def follow_obs(d, i_follow):
+            m = i_follow != -1
+            i_follow = i_follow[m]
+            t, x, y = (
+                zeros(m.size, d.time.dtype),
+                zeros(m.size, d.longitude.dtype),
+                zeros(m.size, d.latitude.dtype),
+            )
+            t[m], x[m], y[m] = (
+                d.time[i_follow],
+                d.longitude[i_follow],
+                d.latitude[i_follow],
+            )
+            return t, x, y
+
+        def next_obs(d, i_seg):
+            last_i = d.index_segment_track[1][i_seg] - 1
+            return follow_obs(d, d.next_obs[last_i])
+
+        def previous_obs(d, i_seg):
+            first_i = d.index_segment_track[0][i_seg]
+            return follow_obs(d, d.previous_obs[first_i])
+
+        tref, xref, yref = next_obs(ref_, ref_segu)
+        tother, xother, yother = next_obs(other_, other_segu)
+
+        m = (tref == tother) & (xref == xother) & (yref == yother)
+        print(m.sum(), m.size, ref_segu.size, ref_track_no_match.size)
+
+        tref, xref, yref = previous_obs(ref_, ref_segu)
+        tother, xother, yother = previous_obs(other_, other_segu)
+
+        m = (tref == tother) & (xref == xother) & (yref == yother)
+        print(m.sum(), m.size, ref_segu.size, ref_track_no_match.size)
+
+        ref_segu, other_segu = ref.identify_in(other, segment=True)
+        m = other_segu != -1
+        out["same S(S)"] = m.sum()
+        out["same S(Obs)"] = ref.segment_size()[ref_segu[m]].sum()
+
+        outs[k] = out
+    return outs
+
+
+def display_compare(ref, others):
+    def display(value, ref=None):
+        if ref:
+            outs = [f"{v / ref[k] * 100:.1f}% ({v})" for k, v in value.items()]
+        else:
+ outs = value + return "".join([f"{v:^18}" for v in outs]) + + datas = run_compare(ref, others) + ref_ = { + "same N(N)": ref.nb_network, + "same N(Obs)": len(ref), + "same NS(N)": ref.nb_network, + "same NS(Obs)": len(ref), + "same S(S)": ref.nb_segment, + "same S(Obs)": len(ref), + } + print(" ", display(ref_.keys())) + for i, (_, v) in enumerate(datas.items()): + print(f"[{i:2}] ", display(v, ref=ref_)) diff --git a/src/py_eddy_tracker/data/__init__.py b/src/py_eddy_tracker/data/__init__.py index 4702af8f..bf062983 100644 --- a/src/py_eddy_tracker/data/__init__.py +++ b/src/py_eddy_tracker/data/__init__.py @@ -8,10 +8,11 @@ 20160515 adt None None longitude latitude . \ --cut 800 --fil 1 """ + import io import lzma -import tarfile from os import path +import tarfile import requests @@ -26,14 +27,20 @@ def get_remote_demo_sample(path): if path.endswith(".nc"): return io.BytesIO(content) else: - if path.endswith(".nc"): + try: + import py_eddy_tracker_sample_id + if path.endswith(".nc"): + return py_eddy_tracker_sample_id.get_remote_demo_sample(path) + content = open(py_eddy_tracker_sample_id.get_remote_demo_sample(f"{path}.tar.xz"), "rb").read() + except: + if path.endswith(".nc"): + content = requests.get( + f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}" + ).content + return io.BytesIO(content) content = requests.get( - f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}" + f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}.tar.xz" ).content - return io.BytesIO(content) - content = requests.get( - f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}.tar.xz" - ).content # Tar module could manage lzma tar, but it will apply uncompress for each extractfile tar = tarfile.open(mode="r", fileobj=io.BytesIO(lzma.decompress(content))) diff --git a/src/py_eddy_tracker/dataset/grid.py b/src/py_eddy_tracker/dataset/grid.py index 091d2016..f15503b2 100644 --- a/src/py_eddy_tracker/dataset/grid.py +++ b/src/py_eddy_tracker/dataset/grid.py @@ -2,14 +2,14 @@ """ Class to load and manipulate RegularGrid and UnRegularGrid """ -import logging from datetime import datetime +import logging from cv2 import filter2D from matplotlib.path import Path as BasePath from netCDF4 import Dataset -from numba import njit, prange -from numba import types as numba_types +from numba import njit, prange, types as numba_types +import numpy as np from numpy import ( arange, array, @@ -28,9 +28,7 @@ isnan, linspace, ma, -) -from numpy import mean as np_mean -from numpy import ( + mean as np_mean, meshgrid, nan, nanmean, @@ -38,9 +36,9 @@ percentile, pi, radians, - round_, sin, sinc, + sqrt, where, zeros, ) @@ -52,6 +50,7 @@ from scipy.special import j1 from .. 
import VAR_DESCR +from ..data import get_demo_path from ..eddy_feature import Amplitude, Contours from ..generic import ( bbox_indice_regular, @@ -125,7 +124,7 @@ def value_on_regular_contour(x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size @njit(cache=True) def mean_on_regular_contour( - x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=None, nan_remove=False + x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=-1, nan_remove=False ): x_val, y_val = vertices[:, 0], vertices[:, 1] x_new, y_new = uniform_resample(x_val, y_val, num_fac, fixed_size) @@ -307,9 +306,26 @@ def __init__( "We assume pixel position of grid is centered for %s", filename ) if not unset: + self.populate() + + def populate(self): + if self.dimensions is None: self.load_general_features() self.load() + def clean(self): + self.dimensions = None + self.variables_description = None + self.global_attrs = None + self.x_c = None + self.y_c = None + self.x_bounds = None + self.y_bounds = None + self.x_dim = None + self.y_dim = None + self.contours = None + self.vars = dict() + @property def is_centered(self): """Give True if pixel is described with its center's position or @@ -402,6 +418,14 @@ def load(self): self.setup_coordinates() + @staticmethod + def get_mask(a): + if len(a.mask.shape): + m = a.mask + else: + m = ones(a.shape, dtype="bool") if a.mask else zeros(a.shape, dtype="bool") + return m + @staticmethod def c_to_bounds(c): """ @@ -419,9 +443,9 @@ def c_to_bounds(c): def setup_coordinates(self): x_name, y_name = self.coordinates if self.is_centered: - logger.info("Grid center") - self.x_c = self.vars[x_name].astype("float64") - self.y_c = self.vars[y_name].astype("float64") + # logger.info("Grid center") + self.x_c = array(self.vars[x_name].astype("float64")) + self.y_c = array(self.vars[y_name].astype("float64")) self.x_bounds = concatenate((self.x_c, (2 * self.x_c[-1] - self.x_c[-2],))) self.y_bounds = concatenate((self.y_c, (2 * self.y_c[-1] - self.y_c[-2],))) @@ -433,8 +457,8 @@ def setup_coordinates(self): self.y_bounds[-1] -= d_y[-1] / 2 else: - self.x_bounds = self.vars[x_name].astype("float64") - self.y_bounds = self.vars[y_name].astype("float64") + self.x_bounds = array(self.vars[x_name].astype("float64")) + self.y_bounds = array(self.vars[y_name].astype("float64")) if len(self.x_dim) == 1: self.x_c = self.x_bounds.copy() @@ -531,7 +555,8 @@ def grid(self, varname, indexs=None): self.vars[varname] = self.vars[varname].T if self.nan_mask: self.vars[varname] = ma.array( - self.vars[varname], mask=isnan(self.vars[varname]), + self.vars[varname], + mask=isnan(self.vars[varname]), ) if not hasattr(self.vars[varname], "mask"): self.vars[varname] = ma.array( @@ -770,7 +795,7 @@ def eddy_identification( # Test of the rotating sense: cyclone or anticyclone if has_value( - data, i_x_in, i_y_in, cvalues, below=anticyclonic_search + data.data, i_x_in, i_y_in, cvalues, below=anticyclonic_search ): continue @@ -801,13 +826,12 @@ def eddy_identification( contour.reject = 4 continue if reset_centroid: - if self.is_circular(): centi = self.normalize_x_indice(reset_centroid[0]) else: centi = reset_centroid[0] centj = reset_centroid[1] - # To move in regular and unregular grid + # FIXME : To move in regular and unregular grid if len(x.shape) == 1: centlon_e = x[centi] centlat_e = y[centj] @@ -861,7 +885,9 @@ def eddy_identification( num_fac=presampling_multiplier, ) xy_e = uniform_resample( - contour.lon, contour.lat, num_fac=presampling_multiplier, + contour.lon, + contour.lat, + num_fac=presampling_multiplier, ) xy_s 
= uniform_resample(
             speed_contour.lon,
@@ -1126,7 +1152,7 @@ def _low_filter(self, grid_name, w_cut, factor=8.0):
         bins = (x_array, y_array)
         x_flat, y_flat, z_flat = x.reshape((-1,)), y.reshape((-1,)), data.reshape((-1,))
-        m = ~z_flat.mask
+        m = ~self.get_mask(z_flat)
         x_flat, y_flat, z_flat = x_flat[m], y_flat[m], z_flat[m]

         nb_value, _, _ = histogram2d(x_flat, y_flat, bins=bins)
@@ -1193,6 +1219,12 @@ def setup_coordinates(self):
             raise Exception(
                 "Coordinates in RegularGridDataset must be 1D array, or think to use UnRegularGridDataset"
             )
+        dx = self.x_bounds[1:] - self.x_bounds[:-1]
+        dy = self.y_bounds[1:] - self.y_bounds[:-1]
+        if (dx < 0).any() or (dy < 0).any():
+            raise Exception(
+                "Coordinates in RegularGridDataset must be strictly increasing"
+            )
         self._x_step = (self.x_c[1:] - self.x_c[:-1]).mean()
         self._y_step = (self.y_c[1:] - self.y_c[:-1]).mean()
@@ -1287,9 +1319,13 @@ def compute_pixel_path(self, x0, y0, x1, y1):
             self.x_size,
         )

-    def clean_land(self):
+    def clean_land(self, name):
         """Function to remove all land pixel"""
-        pass
+        mask_land = self.__class__(get_demo_path("mask_1_60.nc"), "lon", "lat")
+        x, y = meshgrid(self.x_c, self.y_c)
+        m = mask_land.interp("mask", x.reshape(-1), y.reshape(-1), "nearest")
+        data = self.grid(name)
+        self.vars[name] = ma.array(data, mask=m.reshape(x.shape).T)

     def is_circular(self):
         """Check if the grid is circular"""
@@ -1311,7 +1347,7 @@ def get_step_in_km(self, lat, wave_length):
         min_wave_length = max(step_x_km, step_y_km) * 2
         if wave_length < min_wave_length:
             logger.error(
-                "wave_length too short for resolution, must be > %d km",
+                "Wave_length too short for resolution, must be > %d km",
                 ceil(min_wave_length),
             )
             raise Exception()
@@ -1362,6 +1398,24 @@ def kernel_lanczos(self, lat, wave_length, order=1):
         kernel[dist_norm > order] = 0
         return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt)

+    def kernel_loess(self, lat, wave_length, order=1):
+        """
+        Loess (tricube) kernel, see https://fr.wikipedia.org/wiki/R%C3%A9gression_locale
+        """
+        order = self.check_order(order)
+        half_x_pt, half_y_pt, dist_norm = self.estimate_kernel_shape(
+            lat, wave_length, order
+        )
+
+        def inc_func(xdist):
+            f = zeros(xdist.size)
+            f[abs(xdist) < 1] = 1
+            return f
+
+        kernel = (1 - abs(dist_norm) ** 3) ** 3
+        kernel[abs(dist_norm) > order] = 0
+        return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt)
+
     def kernel_bessel(self, lat, wave_length, order=1):
         """wave_length in km
         order must be int
@@ -1616,7 +1670,7 @@ def spectrum_lonlat(self, grid_name, area=None, ref=None, **kwargs):
             (lat_content[0], lat_content[1] / ref_lat_content[1]),
         )

-    def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=False):
+    def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=False, second=False):
         if not isinstance(schema, int) and schema < 1:
             raise Exception("schema must be a positive int")

@@ -1639,12 +1693,17 @@ def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=Fal
             data1[-schema:] = nan
             data2[:schema] = nan

+        # Distance for one degree
         d = self.EARTH_RADIUS * 2 * pi / 360 * 2 * schema
+        # Multiply by the grid step (in degrees)
         if vertical:
-            d *= self.ystep
+            d *= self.ystep
         else:
             d *= self.xstep * cos(deg2rad(self.y_c))
-        return (data1 - data2) / d
+        if second:
+            return (data1 + data2 - 2 * data) / (d ** 2 / 4)
+        else:
+            return (data1 - data2) / d

     def compute_stencil(
         self, data, stencil_halfwidth=4, mode="reflect", vertical=False
@@ -1724,20 +1783,21 @@ def compute_stencil(
             self.x_c,
             self.y_c,
             data.data,
-            data.mask,
+
self.get_mask(data), self.EARTH_RADIUS, vertical=vertical, stencil_halfwidth=stencil_halfwidth, ) return ma.array(g, mask=m) - def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15): - self.add_uv(grid_height, uname, vname) + def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15, **kwargs): + self.add_uv(grid_height, uname, vname, **kwargs) latmax = 5 - _, (i_start, i_end) = self.nearest_grd_indice((0, 0), (-latmax, latmax)) + _, i_start = self.nearest_grd_indice(0, -latmax) + _, i_end = self.nearest_grd_indice(0, latmax) sl = slice(i_start, i_end) # Divide by sideral day - lat = self.y_c[sl] + lat = self.y_c gob = ( cos(deg2rad(lat)) * ones((self.x_c.shape[0], 1)) @@ -1751,39 +1811,26 @@ def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15): mode = "wrap" if self.is_circular() else "reflect" # fill data to compute a finite difference on all point - data = self.convolve_filter_with_dynamic_kernel( - grid_height, - self.kernel_bessel, - lat_max=10, - wave_length=500, - order=1, - extend=0.1, - ) - data = self.convolve_filter_with_dynamic_kernel( - data, self.kernel_bessel, lat_max=10, wave_length=500, order=1, extend=0.1 - ) - data = self.convolve_filter_with_dynamic_kernel( - data, self.kernel_bessel, lat_max=10, wave_length=500, order=1, extend=0.1 - ) + kw_filter = dict(kernel_func=self.kernel_bessel, order=1, extend=.1) + data = self.convolve_filter_with_dynamic_kernel(grid_height, wave_length=500, **kw_filter, lat_max=6+5+2+3) v_lagerloef = ( self.compute_finite_difference( - self.compute_finite_difference(data, mode=mode, schema=schema), - mode=mode, - schema=schema, - )[:, sl] - * gob - ) - u_lagerloef = ( - -self.compute_finite_difference( - self.compute_finite_difference(data, vertical=True, schema=schema), - vertical=True, - schema=schema, - )[:, sl] + self.compute_finite_difference(data, mode=mode, schema=1), + vertical=True, schema=1 + ) * gob ) - w = 1 - exp(-((lat / 2.2) ** 2)) - self.vars[vname][:, sl] = self.vars[vname][:, sl] * w + v_lagerloef * (1 - w) - self.vars[uname][:, sl] = self.vars[uname][:, sl] * w + u_lagerloef * (1 - w) + u_lagerloef = -self.compute_finite_difference(data, vertical=True, schema=schema, second=True) * gob + + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=195, **kw_filter, lat_max=6 + 5 +2) + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=416, **kw_filter, lat_max=6 + 5) + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=416, **kw_filter, lat_max=6) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=195, **kw_filter, lat_max=6 + 5 +2) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=416, **kw_filter, lat_max=6 + 5) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=416, **kw_filter, lat_max=6) + w = 1 - exp(-((lat[sl] / 2.2) ** 2)) + self.vars[vname][:, sl] = self.vars[vname][:, sl] * w + v_lagerloef[:, sl] * (1 - w) + self.vars[uname][:, sl] = self.vars[uname][:, sl] * w + u_lagerloef[:, sl] * (1 - w) def add_uv(self, grid_height, uname="u", vname="v", stencil_halfwidth=4): r"""Compute a u and v grid @@ -1856,7 +1903,7 @@ def speed_coef_mean(self, contour): return mean_on_regular_contour( self.x_c, self.y_c, - self._speed_ev, + self._speed_ev.data, self._speed_ev.mask, contour.vertices, nan_remove=True, @@ -1864,7 +1911,8 @@ def speed_coef_mean(self, contour): def init_speed_coef(self, uname="u", 
vname="v"): """Draft""" - self._speed_ev = (self.grid(uname) ** 2 + self.grid(vname) ** 2) ** 0.5 + u, v = self.grid(uname), self.grid(vname) + self._speed_ev = sqrt(u * u + v * v) def display(self, ax, name, factor=1, ref=None, **kwargs): """ @@ -1932,14 +1980,6 @@ def regrid(self, other, grid_name, new_name=None): # self.variables_description[new_name]['infos'] = False # self.variables_description[new_name]['kwargs']['dimensions'] = ... - @staticmethod - def get_mask(a): - if len(a.mask.shape): - m = a.mask - else: - m = ones(a.shape) if a.mask else zeros(a.shape) - return m - def interp(self, grid_name, lons, lats, method="bilinear"): """ Compute z over lons, lats @@ -1954,17 +1994,32 @@ def interp(self, grid_name, lons, lats, method="bilinear"): g = self.grid(grid_name) m = self.get_mask(g) return interp2d_geo( - self.x_c, self.y_c, g, m, lons, lats, nearest=method == "nearest" + self.x_c, self.y_c, g.data, m, lons, lats, nearest=method == "nearest" ) - def uv_for_advection(self, u_name, v_name, time_step=600, backward=False, factor=1): + def uv_for_advection( + self, + u_name=None, + v_name=None, + time_step=600, + h_name=None, + backward=False, + factor=1, + ): """ Get U,V to be used in degrees with precomputed time step - :param str,array u_name: U field to advect obs - :param str,array v_name: V field to advect obs + :param None,str,array u_name: U field to advect obs, if h_name is None + :param None,str,array v_name: V field to advect obs, if h_name is None + :param None,str,array h_name: H field to compute UV to advect obs, if u_name and v_name are None :param int time_step: Number of second for each advection """ + if h_name is not None: + u_name, v_name = "u", "v" + if u_name not in self.vars: + self.add_uv(h_name) + self.vars.pop(h_name, None) + u = (self.grid(u_name) if isinstance(u_name, str) else u_name).copy() v = (self.grid(v_name) if isinstance(v_name, str) else v_name).copy() # N seconds / 1 degrees in m @@ -1975,7 +2030,7 @@ def uv_for_advection(self, u_name, v_name, time_step=600, backward=False, factor u = -u v = -v m = u.mask + v.mask - return u, v, m + return u.data, v.data, m def advect(self, x, y, u_name, v_name, nb_step=10, rk4=True, **kw): """ @@ -2196,12 +2251,11 @@ def compute_pixel_path(x0, y0, x1, y1, x_ori, y_ori, x_step, y_step, nb_x): i_x1 = empty(nx, dtype=numba_types.int_) i_y0 = empty(nx, dtype=numba_types.int_) i_y1 = empty(nx, dtype=numba_types.int_) - # Because round_ is not accepted with array in numba for i in range(nx): - i_x0[i] = round_(((x0[i] - x_ori) % 360) / x_step) - i_x1[i] = round_(((x1[i] - x_ori) % 360) / x_step) - i_y0[i] = round_((y0[i] - y_ori) / y_step) - i_y1[i] = round_((y1[i] - y_ori) / y_step) + i_x0[i] = np.round(((x0[i] - x_ori) % 360) / x_step) + i_x1[i] = np.round(((x1[i] - x_ori) % 360) / x_step) + i_y0[i] = np.round((y0[i] - y_ori) / y_step) + i_y1[i] = np.round((y1[i] - y_ori) / y_step) # Delta index of x d_x = i_x1 - i_x0 d_x = (d_x + nb_x // 2) % nb_x - (nb_x // 2) @@ -2281,28 +2335,40 @@ def __init__(self): self.datasets = list() @classmethod - def from_netcdf_cube(cls, filename, x_name, y_name, t_name, heigth=None): + def from_netcdf_cube(cls, filename, x_name, y_name, t_name, heigth=None, **kwargs): new = cls() with Dataset(filename) as h: for i, t in enumerate(h.variables[t_name][:]): - d = RegularGridDataset(filename, x_name, y_name, indexs={t_name: i}) + d = RegularGridDataset( + filename, x_name, y_name, indexs={t_name: i}, **kwargs + ) if heigth is not None: d.add_uv(heigth) new.datasets.append((t, d)) return 
new

     @classmethod
-    def from_netcdf_list(cls, filenames, t, x_name, y_name, indexs=None, heigth=None):
+    def from_netcdf_list(
+        cls, filenames, t, x_name, y_name, indexs=None, heigth=None, **kwargs
+    ):
         new = cls()
         for i, _t in enumerate(t):
             filename = filenames[i]
             logger.debug(f"load file {i:02d}/{len(t)} t={_t} : {filename}")
-            d = RegularGridDataset(filename, x_name, y_name, indexs=indexs)
+            d = RegularGridDataset(filename, x_name, y_name, indexs=indexs, **kwargs)
             if heigth is not None:
                 d.add_uv(heigth)
             new.datasets.append((_t, d))
         return new

+    @property
+    def are_loaded(self):
+        return ~array([d.dimensions is None for _, d in self.datasets])
+
+    def __repr__(self):
+        nb_dataset = len(self.datasets)
+        return f"{self.are_loaded.sum()}/{nb_dataset} datasets loaded"
+
     def shift_files(self, t, filename, heigth=None, **rgd_kwargs):
         """Add next file to the list and remove the oldest"""

@@ -2342,9 +2408,19 @@ def __iter__(self):
         for _, d in self.datasets:
             yield d

+    @property
+    def time(self):
+        return array([t for t, _ in self.datasets])
+
+    @property
+    def period(self):
+        t = self.time
+        return t.min(), t.max()
+
     def __getitem__(self, item):
         for t, d in self.datasets:
             if t == item:
+                d.populate()
                 return d
         raise KeyError(item)
@@ -2424,17 +2500,23 @@ def filament(
                 t += dt
                 yield t, f_x, f_y

+    def reset_grids(self, N=None):
+        if N is not None:
+            m = self.are_loaded
+            if m.sum() > N:
+                for i in where(m)[0]:
+                    self.datasets[i][1].clean()
+
     def advect(
         self,
         x,
         y,
-        u_name,
-        v_name,
         t_init,
         mask_particule=None,
         nb_step=10,
         time_step=600,
         rk4=True,
+        reset_grid=None,
         **kw,
     ):
         """
@@ -2442,15 +2524,18 @@ def advect(

         :param array x: Longitude of obs to move
         :param array y: Latitude of obs to move
-        :param str,array u_name: U field to advect obs
-        :param str,array v_name: V field to advect obs
+        :param float t_init: time to start advection
+        :param array,None mask_particule: advect only if mask is True
         :param int nb_step: Number of iteration before to release data
         :param int time_step: Number of second for each advection
+        :param bool rk4: Use rk4 algorithm instead of finite difference
+        :param int reset_grid: Delete all loaded data in cube if there are more than N grids loaded

-        :return: x,y position
+        :return: t,x,y position

         ..
minigallery:: py_eddy_tracker.GridCollection.advect """ + self.reset_grids(reset_grid) backward = kw.get("backward", False) if backward: generator = self.get_previous_time_step(t_init) @@ -2461,9 +2546,9 @@ def advect( dt = nb_step * time_step t_step = time_step t0, d0 = generator.__next__() - u0, v0, m0 = d0.uv_for_advection(u_name, v_name, time_step, **kw) + u0, v0, m0 = d0.uv_for_advection(time_step=time_step, **kw) t1, d1 = generator.__next__() - u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw) + u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw) t0 = t0 * 86400 t1 = t1 * 86400 t = t_init * 86400 @@ -2473,12 +2558,12 @@ def advect( else: mask_particule += isnan(x) + isnan(y) while True: - logger.debug(f"advect : t={t}") + logger.debug(f"advect : t={t/86400}") if (backward and t <= t1) or (not backward and t >= t1): t0, u0, v0, m0 = t1, u1, v1, m1 t1, d1 = generator.__next__() t1 = t1 * 86400 - u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw) + u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw) w = 1 - (arange(t, t + dt, t_step) - t0) / (t1 - t0) half_w = t_step / 2.0 / (t1 - t0) advect_( @@ -2503,7 +2588,7 @@ def get_next_time_step(self, t_init): for i, (t, dataset) in enumerate(self.datasets): if t < t_init: continue - + dataset.populate() logger.debug(f"i={i}, t={t}, dataset={dataset}") yield t, dataset @@ -2513,10 +2598,32 @@ def get_previous_time_step(self, t_init): i -= 1 if t > t_init: continue - + dataset.populate() logger.debug(f"i={i}, t={t}, dataset={dataset}") yield t, dataset + def path(self, x0, y0, *args, nb_time=2, **kwargs): + """ + At each call it will update position in place with u & v field + + :param array x0: Longitude of obs to move + :param array y0: Latitude of obs to move + :param int nb_time: Number of iteration for particle + :param dict kwargs: look at :py:meth:`GridCollection.advect` + + :return: t,x,y + + .. 
minigallery:: py_eddy_tracker.GridCollection.path + """ + particles = self.advect(x0.copy(), y0.copy(), *args, **kwargs) + t = empty(nb_time + 1, dtype="f8") + x = empty((nb_time + 1, x0.size), dtype=x0.dtype) + y = empty(x.shape, dtype=y0.dtype) + t[0], x[0], y[0] = kwargs.get("t_init"), x0, y0 + for i in range(nb_time): + t[i + 1], x[i + 1], y[i + 1] = particles.__next__() + return t, x, y + @njit(cache=True) def advect_t(x_g, y_g, u_g0, v_g0, m_g0, u_g1, v_g1, m_g1, x, y, m, weigths, half_w=0): @@ -2833,7 +2940,7 @@ def compute_stencil(x, y, h, m, earth_radius, vertical=False, stencil_halfwidth= h_3, h_2, h_1, h0 = h[-4, j], h[-3, j], h[-2, j], h[-1, j] m_3, m_2, m_1, m0 = m[-4, j], m[-3, j], m[-2, j], m[-1, j] else: - m_3, m_2, m_1, m0 = False, False, False, False + m_3, m_2, m_1, m0 = True, True, True, True h1, h2, h3, h4 = h[0, j], h[1, j], h[2, j], h[3, j] m1, m2, m3, m4 = m[0, j], m[1, j], m[2, j], m[3, j] for i in range(nb_x): diff --git a/src/py_eddy_tracker/eddy_feature.py b/src/py_eddy_tracker/eddy_feature.py index 3640b306..8bc139ab 100644 --- a/src/py_eddy_tracker/eddy_feature.py +++ b/src/py_eddy_tracker/eddy_feature.py @@ -8,8 +8,7 @@ from matplotlib.cm import get_cmap from matplotlib.colors import Normalize from matplotlib.figure import Figure -from numba import njit -from numba import types as numba_types +from numba import njit, types as numba_types from numpy import ( array, concatenate, @@ -433,8 +432,8 @@ def __init__(self, x, y, z, levels, wrap_x=False, keep_unclose=False): closed_contours = 0 # Count level and contour for i, collection in enumerate(self.contours.collections): - collection.get_nearest_path_bbox_contain_pt = lambda x, y, i=i: self.get_index_nearest_path_bbox_contain_pt( - i, x, y + collection.get_nearest_path_bbox_contain_pt = ( + lambda x, y, i=i: self.get_index_nearest_path_bbox_contain_pt(i, x, y) ) nb_level += 1 @@ -784,7 +783,7 @@ def index_from_nearest_path_with_pt_in_bbox_( d_x = x_value[i_elt_pt] - xpt_ if abs(d_x) > 180: d_x = (d_x + 180) % 360 - 180 - dist = d_x ** 2 + (y_value[i_elt_pt] - ypt) ** 2 + dist = d_x**2 + (y_value[i_elt_pt] - ypt) ** 2 if dist < dist_ref: dist_ref = dist i_ref = i_elt_c diff --git a/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py b/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py index 41e02db9..b0d4abfa 100644 --- a/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py +++ b/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py @@ -8,7 +8,6 @@ class CheltonTracker(Model): - __slots__ = tuple() GROUND = RegularGridDataset( diff --git a/src/py_eddy_tracker/generic.py b/src/py_eddy_tracker/generic.py index 94cf321f..2fdb737a 100644 --- a/src/py_eddy_tracker/generic.py +++ b/src/py_eddy_tracker/generic.py @@ -3,8 +3,7 @@ Tool method which use mostly numba """ -from numba import njit, prange -from numba import types as numba_types +from numba import njit, prange, types as numba_types from numpy import ( absolute, arcsin, @@ -132,8 +131,8 @@ def distance_grid(lon0, lat0, lon1, lat1): sin_dlon = sin((dlon) * 0.5 * D2R) cos_lat1 = cos(lat0[i] * D2R) cos_lat2 = cos(lat1[j] * D2R) - a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2 - dist[i, j] = 6370.997 * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat1 * cos_lat2 + sin_dlat**2 + dist[i, j] = 6370.997 * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) return dist @@ -154,8 +153,8 @@ def distance(lon0, lat0, lon1, lat1): sin_dlon = sin((lon1 - lon0) * 0.5 * D2R) cos_lat1 = cos(lat0 * D2R) 
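+    # haversine formula on a sphere of radius 6370997 m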
cos_lat2 = cos(lat1 * D2R) - a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2 - return 6370997.0 * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat1 * cos_lat2 + sin_dlat**2 + return 6370997.0 * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) @njit(cache=True) @@ -302,14 +301,14 @@ def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y): @njit(cache=True, fastmath=True) -def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None): +def uniform_resample(x_val, y_val, num_fac=2, fixed_size=-1): """ Resample contours to have (nearly) equal spacing. :param array_like x_val: input x contour coordinates :param array_like y_val: input y contour coordinates :param int num_fac: factor to increase lengths of output coordinates - :param int,None fixed_size: if defined, will be used to set sampling + :param int fixed_size: if > -1, will be used to set sampling """ nb = x_val.shape[0] # Get distances @@ -320,7 +319,7 @@ def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None): dist[1:][dist[1:] < 1e-3] = 1e-3 dist = dist.cumsum() # Get uniform distances - if fixed_size is None: + if fixed_size == -1: fixed_size = dist.size * num_fac d_uniform = linspace(0, dist[-1], fixed_size) x_new = interp(d_uniform, dist, x_val) @@ -367,7 +366,7 @@ def simplify(x, y, precision=0.1): :return: (x,y) :rtype: (array,array) """ - precision2 = precision ** 2 + precision2 = precision**2 nb = x.shape[0] # will be True for kept values mask = ones(nb, dtype=bool_) @@ -399,7 +398,7 @@ def simplify(x, y, precision=0.1): if d_y > precision: x_previous, y_previous = x_, y_ continue - d2 = d_x ** 2 + d_y ** 2 + d2 = d_x**2 + d_y**2 if d2 > precision2: x_previous, y_previous = x_, y_ continue @@ -427,7 +426,7 @@ def split_line(x, y, i): """ nb_jump = len(where(i[1:] - i[:-1] != 0)[0]) nb_value = x.shape[0] - final_size = (nb_jump - 1) + nb_value + final_size = nb_jump + nb_value new_x = empty(final_size, dtype=x.dtype) new_y = empty(final_size, dtype=y.dtype) new_j = 0 @@ -457,17 +456,18 @@ def wrap_longitude(x, y, ref, cut=False): if cut: indexs = list() nb = x.shape[0] - new_previous = (x[0] - ref) % 360 + + new_x_previous = (x[0] - ref) % 360 + ref x_previous = x[0] for i in range(1, nb): x_ = x[i] - new_x = (x_ - ref) % 360 + new_x = (x_ - ref) % 360 + ref if not isnan(x_) and not isnan(x_previous): - d_new = new_x - new_previous + d_new = new_x - new_x_previous d = x_ - x_previous if abs(d - d_new) > 1e-5: indexs.append(i) - x_previous, new_previous = x_, new_x + x_previous, new_x_previous = x_, new_x nb_indexs = len(indexs) new_size = nb + nb_indexs * 3 @@ -478,6 +478,7 @@ def wrap_longitude(x, y, ref, cut=False): for i in range(nb): if j < nb_indexs and i == indexs[j]: j += 1 + # FIXME need check cor = 360 if x[i - 1] > x[i] else -360 out_x[i + i_] = (x[i] - ref) % 360 + ref - cor out_y[i + i_] = y[i] @@ -517,8 +518,8 @@ def coordinates_to_local(lon, lat, lon0, lat0): sin_dlon = sin(dlon * 0.5) cos_lat0 = cos(lat0 * D2R) cos_lat = cos(lat * D2R) - a_val = sin_dlon ** 2 * cos_lat0 * cos_lat + sin_dlat ** 2 - module = R * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat0 * cos_lat + sin_dlat**2 + module = R * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) azimuth = pi / 2 - arctan2( cos_lat * sin(dlon), @@ -541,7 +542,7 @@ def local_to_coordinates(x, y, lon0, lat0): """ D2R = pi / 180.0 R = 6370997 - d = (x ** 2 + y ** 2) ** 0.5 / R + d = (x**2 + y**2) ** 0.5 / R a = -(arctan2(y, x) - pi / 2) lat = arcsin(sin(lat0 * D2R) * cos(d) + cos(lat0 * D2R) * sin(d) * 
cos(a)) lon = ( @@ -612,3 +613,48 @@ def build_circle(x0, y0, r): angle = radians(linspace(0, 360, 50)) x_norm, y_norm = cos(angle), sin(angle) return x_norm * r + x0, y_norm * r + y0 + + +def window_index(x, x0, half_window=1): + """ + Give for a fixed half_window each start and end index for each x0, in + an unsorted array. + + :param array x: array of value + :param array x0: array of window center + :param float half_window: half window + """ + # Sort array, bounds will be sort also + i_ordered = x.argsort(kind="mergesort") + return window_index_(x, i_ordered, x0, half_window) + + +@njit(cache=True) +def window_index_(x, i_ordered, x0, half_window=1): + nb_x, nb_pt = x.size, x0.size + first_index = empty(nb_pt, dtype=i_ordered.dtype) + last_index = empty(nb_pt, dtype=i_ordered.dtype) + # First bound to find + j_min, j_max = 0, 0 + x_min = x0[j_min] - half_window + x_max = x0[j_max] + half_window + # We iterate on ordered x + for i, i_x in enumerate(i_ordered): + x_ = x[i_x] + # if x bigger than x_min , we found bound and search next one + while x_ > x_min and j_min < nb_pt: + first_index[j_min] = i + j_min += 1 + x_min = x0[j_min] - half_window + # if x bigger than x_max , we found bound and search next one + while x_ > x_max and j_max < nb_pt: + last_index[j_max] = i + j_max += 1 + x_max = x0[j_max] + half_window + if j_max == nb_pt: + break + for i in range(j_min, nb_pt): + first_index[i] = nb_x + for i in range(j_max, nb_pt): + last_index[i] = nb_x + return i_ordered, first_index, last_index diff --git a/src/py_eddy_tracker/gui.py b/src/py_eddy_tracker/gui.py index deeb6660..a85e9c18 100644 --- a/src/py_eddy_tracker/gui.py +++ b/src/py_eddy_tracker/gui.py @@ -4,13 +4,16 @@ """ from datetime import datetime, timedelta +import logging +from matplotlib.projections import register_projection import matplotlib.pyplot as plt import numpy as np -from matplotlib.projections import register_projection from .generic import flatten_line_matrix, split_line +logger = logging.getLogger("pet") + try: from pylook.axes import PlatCarreAxes except ImportError: @@ -91,7 +94,7 @@ def set_initial_values(self): for dataset in self.datasets.values(): t0_, t1_ = dataset.period t0, t1 = min(t0, t0_), max(t1, t1_) - + logger.debug("period detected %f -> %f", t0, t1) self.settings = dict(period=(t0, t1), now=t1) @property diff --git a/src/py_eddy_tracker/misc.py b/src/py_eddy_tracker/misc.py new file mode 100644 index 00000000..647bfba3 --- /dev/null +++ b/src/py_eddy_tracker/misc.py @@ -0,0 +1,21 @@ +import re + +from matplotlib.animation import FuncAnimation + + +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) diff --git a/src/py_eddy_tracker/observations/groups.py b/src/py_eddy_tracker/observations/groups.py index 6fea0ace..81929e1e 100644 --- a/src/py_eddy_tracker/observations/groups.py +++ b/src/py_eddy_tracker/observations/groups.py @@ -1,10 +1,11 @@ -import logging from abc import ABC, abstractmethod +import logging -from numba import njit -from numba import types as nb_types -from numpy import arange, 
int32, interp, median, where, zeros +from numba import njit, types as nb_types +from numpy import arange, full, int32, interp, isnan, median, where, zeros +from ..generic import window_index +from ..poly import create_meshed_particles, poly_indexs from .observation import EddiesObservations logger = logging.getLogger("pet") @@ -66,7 +67,7 @@ def get_missing_indices( return indices -def advect(x, y, c, t0, n_days): +def advect(x, y, c, t0, n_days, u_name="u", v_name="v"): """ Advect particles from t0 to t0 + n_days, with data cube. @@ -75,18 +76,59 @@ def advect(x, y, c, t0, n_days): :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles :param int t0: julian day of advection start :param int n_days: number of days to advect + :param str u_name: variable name for u component + :param str v_name: variable name for v component """ kw = dict(nb_step=6, time_step=86400 / 6) if n_days < 0: kw["backward"] = True n_days = -n_days - p = c.advect(x, y, "u", "v", t_init=t0, **kw) + p = c.advect(x, y, u_name=u_name, v_name=v_name, t_init=t0, **kw) for _ in range(n_days): t, x, y = p.__next__() return t, x, y +def particle_candidate_step( + t_start, contours_start, contours_end, space_step, dt, c, day_fraction=6, **kwargs +): + """Select particles within eddies, advect them, return target observation and associated percentages. + For one time step. + + :param int t_start: julian day of the advection + :param (np.array(float),np.array(float)) contours_start: origin contour + :param (np.array(float),np.array(float)) contours_end: destination contour + :param float space_step: step between 2 particles + :param int dt: duration of advection + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param int day_fraction: fraction of day + :params dict kwargs: dict of params given to advection + :return (np.array,np.array): return target index and percent associate + """ + # In case of zarr array + contours_start = [i[:] for i in contours_start] + contours_end = [i[:] for i in contours_end] + # Create particles in start contour + x, y, i_start = create_meshed_particles(*contours_start, space_step) + # Advect particles + kw = dict(nb_step=day_fraction, time_step=86400 / day_fraction) + p = c.advect(x, y, t_init=t_start, **kwargs, **kw) + for _ in range(abs(dt)): + _, x, y = p.__next__() + m = ~(isnan(x) + isnan(y)) + i_end = full(x.shape, -1, dtype="i4") + if m.any(): + # Id eddies for each alive particle in start contour + i_end[m] = poly_indexs(x[m], y[m], *contours_end) + shape = (contours_start[0].shape[0], 2) + # Get target for each contour + i_target, pct_target = full(shape, -1, dtype="i4"), zeros(shape, dtype="f8") + nb_end = contours_end[0].shape[0] + get_targets(i_start, i_end, i_target, pct_target, nb_end) + return i_target, pct_target.astype("i1") + + def particle_candidate( c, eddies, @@ -118,14 +160,8 @@ def particle_candidate( translate_start = where(m_start)[0] # Create particles in specified contour - if contour_start == "speed": - x, y, i_start = e.create_particles(step_mesh, intern=True) - elif contour_start == "effective": - x, y, i_start = e.create_particles(step_mesh, intern=False) - else: - x, y, i_start = e.create_particles(step_mesh, intern=True) - print("The contour_start was not correct, speed contour is used") - + intern = False if contour_start == "effective" else True + x, y, i_start = e.create_particles(step_mesh, intern=intern) # Advection t_end, x, y = advect(x, y, c, t_start, **kwargs) @@ 
-137,18 +173,55 @@ def particle_candidate( translate_end = where(m_end)[0] # Id eddies for each alive particle in specified contour - if contour_end == "speed": - i_end = e_end.contains(x, y, intern=True) - elif contour_end == "effective": - i_end = e_end.contains(x, y, intern=False) - else: - i_end = e_end.contains(x, y, intern=True) - print("The contour_end was not correct, speed contour is used") + intern = False if contour_end == "effective" else True + i_end = e_end.contains(x, y, intern=intern) # compute matrix and fill target array get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct) +@njit(cache=True) +def get_targets(i_start, i_end, i_target, pct, nb_end): + """Compute target observation and associated percentages + + :param array(int) i_start: indices in time 0 + :param array(int) i_end: indices in time N + :param array(int) i_target: corresponding obs where particles are advected + :param array(int) pct: corresponding percentage of avected particles + :param int nb_end: number of contour at time N + """ + nb_start = i_target.shape[0] + # Matrix which will store count for every couple + counts = zeros((nb_start, nb_end), dtype=nb_types.int32) + # Number of particles in each origin observation + ref = zeros(nb_start, dtype=nb_types.int32) + # For each particle + for i in range(i_start.size): + i_end_ = i_end[i] + i_start_ = i_start[i] + ref[i_start_] += 1 + if i_end_ != -1: + counts[i_start_, i_end_] += 1 + # From i to j + for i in range(nb_start): + for j in range(nb_end): + count = counts[i, j] + if count == 0: + continue + pct_ = count / ref[i] * 100 + pct_0 = pct[i, 0] + # If percent is higher than previous stored in rank 0 + if pct_ > pct_0: + pct[i, 1] = pct_0 + pct[i, 0] = pct_ + i_target[i, 1] = i_target[i, 0] + i_target[i, 0] = j + # If percent is higher than previous stored in rank 1 + elif pct_ > pct[i, 1]: + pct[i, 1] = pct_ + i_target[i, 1] = j + + @njit(cache=True) def get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct): """Compute target observation and associated percentages @@ -219,15 +292,14 @@ def filled_by_interpolation(self, mask): nb_obs = len(self) index = arange(nb_obs) - for field in self.obs.dtype.descr: - var = field[0] + for field in self.fields: if ( - var in ["n", "virtual", "track", "cost_association"] - or var in self.array_variables + field in ["n", "virtual", "track", "cost_association"] + or field in self.array_variables ): continue - self.obs[var][mask] = interp( - index[mask], index[~mask], self.obs[var][~mask] + self.obs[field][mask] = interp( + index[mask], index[~mask], self.obs[field][~mask] ) def insert_virtual(self): @@ -277,3 +349,128 @@ def keep_tracks_by_date(self, date, nb_days): mask[i] = True return self.extract_with_mask(mask) + + def particle_candidate_atlas( + self, + cube, + space_step, + dt, + start_intern=False, + end_intern=False, + callback_coherence=None, + finalize_coherence=None, + **kwargs + ): + """Select particles within eddies, advect them, return target observation and associated percentages + + :param `~py_eddy_tracker.dataset.grid.GridCollection` cube: GridCollection with speed for particles + :param float space_step: step between 2 particles + :param int dt: duration of advection + :param bool start_intern: Use intern or extern contour at injection, defaults to False + :param bool end_intern: Use intern or extern contour at end of advection, defaults to False + :param dict kwargs: dict of params given to advection + :param func callback_coherence: if None we will use 
cls.fill_coherence + :param func finalize_coherence: to apply on results of callback_coherence + :return (np.array,np.array): return target index and percent associate + """ + t_start, t_end = int(self.period[0]), int(self.period[1]) + # Pre-compute to get time index + i_sort, i_start, i_end = window_index( + self.time, arange(t_start, t_end + 1), 0.5 + ) + # Out shape + shape = (len(self), 2) + i_target, pct = full(shape, -1, dtype="i4"), zeros(shape, dtype="i1") + # Backward or forward + times = arange(t_start, t_end - dt) if dt > 0 else arange(t_start - dt, t_end) + + if callback_coherence is None: + callback_coherence = self.fill_coherence + indexs = dict() + results = list() + kw_coherence = dict(space_step=space_step, dt=dt, c=cube) + kw_coherence.update(kwargs) + for t in times: + logger.info( + "Coherence for time step : %s in [%s:%s]", t, times[0], times[-1] + ) + # Get index for origin + i = t - t_start + indexs0 = i_sort[i_start[i] : i_end[i]] + # Get index for end + i = t + dt - t_start + indexs1 = i_sort[i_start[i] : i_end[i]] + if indexs0.size == 0 or indexs1.size == 0: + continue + + results.append( + callback_coherence( + self, + i_target, + pct, + indexs0, + indexs1, + start_intern, + end_intern, + t_start=t, + **kw_coherence + ) + ) + indexs[results[-1]] = indexs0, indexs1 + + if finalize_coherence is not None: + finalize_coherence(results, indexs, i_target, pct) + return i_target, pct + + @classmethod + def fill_coherence( + cls, + network, + i_targets, + percents, + i_origin, + i_end, + start_intern, + end_intern, + **kwargs + ): + """_summary_ + + :param array i_targets: global target + :param array percents: + :param array i_origin: indices of origins + :param array i_end: indices of ends + :param bool start_intern: Use intern or extern contour at injection + :param bool end_intern: Use intern or extern contour at end of advection + """ + # Get contour data + contours_start = [ + network[label][i_origin] for label in cls.intern(start_intern) + ] + contours_end = [network[label][i_end] for label in cls.intern(end_intern)] + # Compute local coherence + i_local_targets, local_percents = particle_candidate_step( + contours_start=contours_start, contours_end=contours_end, **kwargs + ) + # Store + cls.merge_particle_result( + i_targets, percents, i_local_targets, local_percents, i_origin, i_end + ) + + @staticmethod + def merge_particle_result( + i_targets, percents, i_local_targets, local_percents, i_origin, i_end + ): + """Copy local result in merged result with global indexation + + :param array i_targets: global target + :param array percents: + :param array i_local_targets: local index target + :param array local_percents: + :param array i_origin: indices of origins + :param array i_end: indices of ends + """ + m = i_local_targets != -1 + i_local_targets[m] = i_end[i_local_targets[m]] + i_targets[i_origin] = i_local_targets + percents[i_origin] = local_percents diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py index 1c078bf8..f0b9d7cc 100644 --- a/src/py_eddy_tracker/observations/network.py +++ b/src/py_eddy_tracker/observations/network.py @@ -2,28 +2,33 @@ """ Class to create network of observations """ +from glob import glob import logging import time -from glob import glob - +from datetime import timedelta, datetime +import os import netCDF4 -import zarr -from numba import njit +from numba import njit, types as nb_types +from numba.typed import List +import numpy as np from numpy import ( arange, array, bincount, 
bool_, concatenate, + empty, - in1d, + nan, ones, + percentile, uint16, uint32, unique, where, zeros, ) +import zarr from ..dataset.grid import GridCollection from ..generic import build_index, wrap_longitude @@ -102,14 +107,37 @@ def fix_next_previous_obs(next_obs, previous_obs, flag_virtual): class NetworkObservations(GroupEddiesObservations): - - __slots__ = ("_index_network",) - + __slots__ = ("_index_network", "_index_segment_track", "_segment_track_array") NOGROUP = 0 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.reset_index() + + def __repr__(self): + m_event, s_event = ( + self.merging_event(only_index=True, triplet=True)[0], + self.splitting_event(only_index=True, triplet=True)[0], + ) + period = (self.period[1] - self.period[0]) / 365.25 + nb_by_network = self.network_size() + nb_trash = 0 if self.ref_index != 0 else nb_by_network[0] + lifetime=self.lifetime + big = 50_000 + infos = [ + f"Atlas with {self.nb_network} networks ({self.nb_network / period:0.0f} networks/year)," + f" {self.nb_segment} segments ({self.nb_segment / period:0.0f} segments/year), {len(self)} observations ({len(self) / period:0.0f} observations/year)", + f" {m_event.size} merging ({m_event.size / period:0.0f} merging/year), {s_event.size} splitting ({s_event.size / period:0.0f} splitting/year)", + f" with {(nb_by_network > big).sum()} network with more than {big} obs and the biggest have {nb_by_network.max()} observations ({nb_by_network[nb_by_network > big].sum()} observations cumulate)", + f" {nb_trash} observations in trash", + f" {lifetime.max()} days max of lifetime", + ] + return "\n".join(infos) + + def reset_index(self): self._index_network = None + self._index_segment_track = None + self._segment_track_array = None def find_segments_relative(self, obs, stopped=None, order=1): """ @@ -161,6 +189,80 @@ def index_network(self): self._index_network = build_index(self.track) return self._index_network + @property + def index_segment_track(self): + if self._index_segment_track is None: + self._index_segment_track = build_index(self.segment_track_array) + return self._index_segment_track + + def segment_size(self): + return self.index_segment_track[1] - self.index_segment_track[0] + + @property + def ref_segment_track_index(self): + return self.index_segment_track[2] + + @property + def ref_index(self): + return self.index_network[2] + + @property + def lifetime(self): + """Return lifetime for each observation""" + lt=self.networks_period.astype("int") + nb_by_network=self.network_size() + return lt.repeat(nb_by_network) + + def network_segment_size(self, id_networks=None): + """Get number of segment by network + + :return array: + """ + i0, i1, ref = build_index(self.track[self.index_segment_track[0]]) + if id_networks is None: + return i1 - i0 + else: + i = id_networks - ref + return i1[i] - i0[i] + + def network_size(self, id_networks=None): + """ + Return size for specified network + + :param list,array, None id_networks: ids to identify network + """ + if id_networks is None: + return self.index_network[1] - self.index_network[0] + else: + i = id_networks - self.index_network[2] + return self.index_network[1][i] - self.index_network[0][i] + + @property + def networks_period(self): + """ + Return period for each network + """ + return get_period_with_index(self.time, *self.index_network[:2]) + + + + def unique_segment_to_id(self, id_unique): + """Return id network and id segment for a unique id + + :param array id_unique: + """ + i = 
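`network_size`, `lifetime` and the other index-based properties all lean on the `(i_start, i_end, ref)` triplet returned by `build_index`; a rough numpy equivalent of that convention (the real implementation lives in `py_eddy_tracker.generic` and may differ in detail):

.. code-block:: python

    import numpy as np

    track = np.array([2, 2, 2, 3, 5, 5])  # sorted network ids (toy values)
    ref = track.min()
    ids = np.arange(ref, track.max() + 1)
    i_start = np.searchsorted(track, ids, side="left")
    i_end = np.searchsorted(track, ids, side="right")
    sizes = i_end - i_start                # network_size() -> [3, 1, 0, 2]
    print(sizes[np.array([2, 5]) - ref])   # sizes of networks 2 and 5 -> [3 2]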
self.index_segment_track[0][id_unique] - self.ref_segment_track_index + return self.track[i], self.segment[i] + + def segment_slice(self, id_network, id_segment): + """ + Return slice for one segment + + :param int id_network: id to identify network + :param int id_segment: id to identify segment + """ + raise Exception("need to be implemented") + def network_slice(self, id_network): """ Return slice for one network @@ -188,17 +290,25 @@ def elements(self): def astype(self, cls): new = cls.new_like(self, self.shape) - print() - for k in new.obs.dtype.names: - if k in self.obs.dtype.names: + for k in new.fields: + if k in self.fields: new[k][:] = self[k][:] new.sign_type = self.sign_type return new - + def longer_than(self, nb_day_min=-1, nb_day_max=-1): """ Select network on time duration + :param int nb_day_min: Minimal number of days covered by one network, if negative -> not used + :param int nb_day_max: Maximal number of days covered by one network, if negative -> not used + """ + return self.extract_with_mask(self.mask_longer_than(nb_day_min, nb_day_max)) + + def mask_longer_than(self, nb_day_min=-1, nb_day_max=-1): + """ + Select network on time duration + :param int nb_day_min: Minimal number of days covered by one network, if negative -> not used :param int nb_day_max: Maximal number of days covered by one network, if negative -> not used """ @@ -206,13 +316,13 @@ def longer_than(self, nb_day_min=-1, nb_day_max=-1): nb_day_max = 1000000000000 mask = zeros(self.shape, dtype="bool") t = self.time - for i, b0, b1 in self.iter_on(self.track): + for i, _, _ in self.iter_on(self.track): nb = i.stop - i.start if nb == 0: continue - if nb_day_min <= ptp(t[i]) <= nb_day_max: + if nb_day_min <= (ptp(t[i]) + 1) <= nb_day_max: mask[i] = True - return self.extract_with_mask(mask) + return mask @classmethod def from_split_network(cls, group_dataset, indexs, **kwargs): @@ -249,13 +359,19 @@ def correct_close_events(self, nb_days_max=20): """ Transform event where segment A splits from segment B, then x days after segment B merges with A - to - segment A splits from segment B then x days after segment A merges with B (B will be longer) - These events have to last less than `nb_days_max` to be changed. 
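Note that the duration test now uses `ptp(t) + 1`, an inclusive day count: a network observed from day 10 to day 19 covers 10 days, not 9. A hypothetical use of the mask/extract pair, assuming `n` is a NetworkObservations instance:

.. code-block:: python

    # Boolean mask of observations belonging to networks covering <= 30 days
    m_short = n.mask_longer_than(nb_day_max=30)
    # Directly extract the networks lasting at least one year
    n_long = n.longer_than(nb_day_min=365)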
+ + ------------------- A + / / + B -------------------- + to + --A-- + / \ + B ----------------------------------- + :param float nb_days_max: maximum time to search for splitting-merging event """ @@ -278,38 +394,38 @@ def correct_close_events(self, nb_days_max=20): segments_connexion[seg] = [i, i_p, i_n] for seg in sorted(segments_connexion.keys()): - seg_slice, i_seg_p, i_seg_n = segments_connexion[seg] + seg_slice, _, i_seg_n = segments_connexion[seg] # the segment ID has to be corrected, because we may have changed it since seg_corrected = segment[seg_slice.stop - 1] # we keep the real segment number seg_corrected_copy = segment_copy[seg_slice.stop - 1] + if i_seg_n == -1: + continue + # if segment is split n_seg = segment[i_seg_n] - # if segment is split - if i_seg_n != -1: - seg2_slice, i2_seg_p, i2_seg_n = segments_connexion[n_seg] - p2_seg = segment[i2_seg_p] - - # if it merges on the first in a certain time - if (p2_seg == seg_corrected) and ( - _time[i_seg_n] - _time[i2_seg_p] < nb_days_max - ): - my_slice = slice(i_seg_n, seg2_slice.stop) - # correct the factice segment - segment[my_slice] = seg_corrected - # correct the good segment - segment_copy[my_slice] = seg_corrected_copy - previous_obs[i_seg_n] = seg_slice.stop - 1 - - segments_connexion[seg_corrected][0] = my_slice - - self.segment[:] = segment_copy - self.previous_obs[:] = previous_obs + seg2_slice, i2_seg_p, _ = segments_connexion[n_seg] + if i2_seg_p == -1: + continue + p2_seg = segment[i2_seg_p] - self.sort() + # if it merges on the first in a certain time + if (p2_seg == seg_corrected) and ( + _time[i_seg_n] - _time[i2_seg_p] < nb_days_max + ): + my_slice = slice(i_seg_n, seg2_slice.stop) + # correct the factice segment + segment[my_slice] = seg_corrected + # correct the good segment + segment_copy[my_slice] = seg_corrected_copy + previous_obs[i_seg_n] = seg_slice.stop - 1 + + segments_connexion[seg_corrected][0] = my_slice + + return self.sort() def sort(self, order=("track", "segment", "time")): """ @@ -317,14 +433,19 @@ def sort(self, order=("track", "segment", "time")): :param tuple order: order or sorting. 
Given to :func:`numpy.argsort` """ - index_order = self.obs.argsort(order=order) - for field in self.elements: + index_order = self.obs.argsort(order=order, kind="mergesort") + self.reset_index() + for field in self.fields: self[field][:] = self[field][index_order] - translate = -ones(index_order.max() + 2, dtype="i4") - translate[index_order] = arange(index_order.shape[0]) + nb_obs = len(self) + # we add 1 for -1 index return index -1 + translate = -ones(nb_obs + 1, dtype="i4") + translate[index_order] = arange(nb_obs) + # next & previous must be re-indexed self.next_obs[:] = translate[self.next_obs] self.previous_obs[:] = translate[self.previous_obs] + return index_order, translate def obs_relative_order(self, i_obs): self.only_one_network() @@ -375,7 +496,6 @@ def find_link(self, i_observations, forward=True, backward=False): segments_connexion[seg][0] = i_slice if i_p != -1: - if p_seg not in segments_connexion: segments_connexion[p_seg] = [None, [], []] @@ -427,8 +547,10 @@ def func_backward(seg, indice): return self.extract_with_mask(mask) def connexions(self, multi_network=False): - """ - Create dictionnary for each segment, gives the segments in interaction with + """Create dictionnary for each segment, gives the segments in interaction with + + :param bool multi_network: use segment_track_array instead of segment, defaults to False + :return dict: Return dict of set, for each seg id we get set of segment which have event with him """ if multi_network: segment = self.segment_track_array @@ -437,25 +559,28 @@ def connexions(self, multi_network=False): segment = self.segment segments_connexion = dict() - def add_seg(father, child): - if father not in segments_connexion: - segments_connexion[father] = set() - segments_connexion[father].add(child) - - previous_obs, next_obs = self.previous_obs, self.next_obs - for i, seg, _ in self.iter_on(segment): - if i.start == i.stop: - continue - i_p, i_n = previous_obs[i.start], next_obs[i.stop - 1] - # segment in interaction - p_seg, n_seg = segment[i_p], segment[i_n] - # Where segment are called - if i_p != -1: - add_seg(p_seg, seg) - add_seg(seg, p_seg) - if i_n != -1: - add_seg(n_seg, seg) - add_seg(seg, n_seg) + def add_seg(s1, s2): + if s1 not in segments_connexion: + segments_connexion[s1] = set() + if s2 not in segments_connexion: + segments_connexion[s2] = set() + segments_connexion[s1].add(s2), segments_connexion[s2].add(s1) + + # Get index for each segment + i0, i1, _ = self.index_segment_track + i1 = i1 - 1 + # Check if segment merge + i_next = self.next_obs[i1] + m_n = i_next != -1 + # Check if segment come from splitting + i_previous = self.previous_obs[i0] + m_p = i_previous != -1 + # For each split + for s1, s2 in zip(segment[i_previous[m_p]], segment[i0[m_p]]): + add_seg(s1, s2) + # For each merge + for s1, s2 in zip(segment[i_next[m_n]], segment[i1[m_n]]): + add_seg(s1, s2) return segments_connexion @classmethod @@ -477,6 +602,7 @@ def segment_relative_order(self, seg_origine): """ Compute the relative order of each segment to the chosen segment """ + self.only_one_network() i_s, i_e, i_ref = build_index(self.segment) segment_connexions = self.connexions() relative_tr = -ones(i_s.shape, dtype="i4") @@ -518,7 +644,6 @@ def relatives(self, obs, order=2): segments_connexion[seg][0] = i_slice if i_p != -1: - if p_seg not in segments_connexion: segments_connexion[p_seg] = [None, []] @@ -590,16 +715,16 @@ def normalize_longitude(self): lon0 = (self.lon[i_start] - 180).repeat(i_stop - i_start) logger.debug("Normalize longitude") 
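The re-indexing of `next_obs`/`previous_obs` in `sort` relies on a sentinel slot: `translate` is one entry longer than the dataset, so an index of -1 (no link) falls on the last slot and keeps mapping to -1 after reordering. A minimal sketch with made-up links:

.. code-block:: python

    import numpy as np

    next_obs = np.array([2, -1, 0])    # links before sorting
    index_order = np.array([2, 0, 1])  # new order of the observations
    nb_obs = len(next_obs)
    translate = -np.ones(nb_obs + 1, dtype="i4")
    translate[index_order] = np.arange(nb_obs)
    # fields are permuted first, then links are translated
    next_obs = translate[next_obs[index_order]]  # -> [1, 0, -1]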
self.lon[:] = (self.lon - lon0) % 360 + lon0 - if "lon_max" in self.obs.dtype.names: + if "lon_max" in self.fields: logger.debug("Normalize longitude_max") self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180 if not self.raw_data: - if "contour_lon_e" in self.obs.dtype.names: + if "contour_lon_e" in self.fields: logger.debug("Normalize effective contour longitude") self.contour_lon_e[:] = ( (self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180 ).T - if "contour_lon_s" in self.obs.dtype.names: + if "contour_lon_s" in self.fields: logger.debug("Normalize speed contour longitude") self.contour_lon_s[:] = ( (self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180 @@ -624,7 +749,7 @@ def only_one_network(self): if there are more than one network """ _, i_start, _ = self.index_network - if len(i_start) > 1: + if i_start.size > 1: raise Exception("Several networks") def position_filter(self, median_half_window, loess_half_window): @@ -679,7 +804,13 @@ def display_timeline( """ self.only_one_network() j = 0 - line_kw = dict(ls="-", marker="+", markersize=6, zorder=1, lw=3,) + line_kw = dict( + ls="-", + marker="+", + markersize=6, + zorder=1, + lw=3, + ) line_kw.update(kwargs) mappables = dict(lines=list()) @@ -693,17 +824,19 @@ def display_timeline( colors_mode=colors_mode, ) ) + if field is not None: + field = self.parse_varname(field) for i, b0, b1 in self.iter_on("segment"): - x = self.time[i] + x = self.time_datetime64[i] if x.shape[0] == 0: continue if field is None: y = b0 * ones(x.shape) else: if method == "all": - y = self[field][i] * factor + y = field[i] * factor else: - y = self[field][i].mean() * ones(x.shape) * factor + y = field[i].mean() * ones(x.shape) * factor if colors_mode == "roll": _color = self.get_color(j) @@ -725,11 +858,11 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol # TODO : fill mappables dict y_seg = dict() - _time = self.time + _time = self.time_datetime64 if field is not None and method != "all": for i, b0, _ in self.iter_on("segment"): - y = self[field][i] + y = self.parse_varname(field)[i] if y.shape[0] != 0: y_seg[b0] = y.mean() * factor mappables = dict() @@ -755,7 +888,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol y0 = b0 else: if method == "all": - y0 = self[field][i.stop - 1] * factor + y0 = self.parse_varname(field)[i.stop - 1] * factor else: y0 = y_seg[b0] if i_n != -1: @@ -764,7 +897,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol seg_next if field is None else ( - self[field][i_n] * factor + self.parse_varname(field)[i_n] * factor if method == "all" else y_seg[seg_next] ) @@ -780,7 +913,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol seg_previous if field is None else ( - self[field][i_p] * factor + self.parse_varname(field)[i_p] * factor if method == "all" else y_seg[seg_previous] ) @@ -816,7 +949,7 @@ def map_segment(self, method, y, same=True, **kw): out = empty(y.shape, **kw) else: out = list() - for i, b0, b1 in self.iter_on(self.segment_track_array): + for i, _, _ in self.iter_on(self.segment_track_array): res = method(y[i]) if same: out[i] = res @@ -905,14 +1038,17 @@ def scatter_timeline( if "c" not in kwargs: v = self.parse_varname(name) kwargs["c"] = v * factor - mappables["scatter"] = ax.scatter(self.time, y, **kwargs) + mappables["scatter"] = ax.scatter(self.time_datetime64, y, **kwargs) return mappables def event_map(self, ax, **kwargs): """Add the merging 
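A hypothetical call of the timeline display, assuming `n1` is a one-network subset (the network id and the field name are made up for the example):

.. code-block:: python

    import matplotlib.pyplot as plt

    n1 = n.network(1283)  # display_timeline expects a single network
    fig, ax = plt.subplots(figsize=(15, 5))
    n1.display_timeline(ax, field="radius_s", method="all", factor=1e-3)
    plt.show()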
and splitting events to a map""" j = 0 mappables = dict() - symbol_kw = dict(markersize=10, color="k",) + symbol_kw = dict( + markersize=10, + color="k", + ) symbol_kw.update(kwargs) symbol_kw_split = symbol_kw.copy() symbol_kw_split["markersize"] += 4 @@ -941,7 +1077,13 @@ def event_map(self, ax, **kwargs): return mappables def scatter( - self, ax, name="time", factor=1, ref=None, edgecolor_cycle=None, **kwargs, + self, + ax, + name="time", + factor=1, + ref=None, + edgecolor_cycle=None, + **kwargs, ): """ This function scatters the path of each network, with the merging and splitting events @@ -992,7 +1134,7 @@ def extract_event(self, indices): raw_data=self.raw_data, ) - for k in new.obs.dtype.names: + for k in new.fields: new[k][:] = self[k][indices] new.sign_type = self.sign_type return new @@ -1000,37 +1142,35 @@ def extract_event(self, indices): @property def segment_track_array(self): """Return a unique segment id when multiple networks are considered""" - return build_unique_array(self.segment, self.track) - - def birth_event(self): - """Extract birth events. - Advice : individual eddies (self.track == 0) should be removed before -> apply remove_trash.""" - # FIXME how to manage group 0 - indices = list() - previous_obs = self.previous_obs - for i, _, _ in self.iter_on(self.segment_track_array): - nb = i.stop - i.start - if nb == 0: - continue - i_p = previous_obs[i.start] - if i_p == -1: - indices.append(i.start) - return self.extract_event(list(set(indices))) - - def death_event(self): - """Extract death events. - Advice : individual eddies (self.track == 0) should be removed before -> apply remove_trash.""" - # FIXME how to manage group 0 - indices = list() - next_obs = self.next_obs - for i, _, _ in self.iter_on(self.segment_track_array): - nb = i.stop - i.start - if nb == 0: - continue - i_n = next_obs[i.stop - 1] - if i_n == -1: - indices.append(i.stop - 1) - return self.extract_event(list(set(indices))) + if self._segment_track_array is None: + self._segment_track_array = build_unique_array(self.segment, self.track) + return self._segment_track_array + + def birth_event(self, only_index=False): + """Extract birth events.""" + i_start, _, _ = self.index_segment_track + indices = i_start[self.previous_obs[i_start] == -1] + if self.first_is_trash(): + indices = indices[1:] + if only_index: + return indices + else : + return self.extract_event(indices) + + generation_event = birth_event + + def death_event(self, only_index=False): + """Extract death events.""" + _, i_stop, _ = self.index_segment_track + indices = i_stop[self.next_obs[i_stop - 1] == -1] - 1 + if self.first_is_trash(): + indices = indices[1:] + if only_index: + return indices + else : + return self.extract_event(indices) + + dissipation_event = death_event def merging_event(self, triplet=False, only_index=False): """Return observation after a merging event. @@ -1038,25 +1178,26 @@ def merging_event(self, triplet=False, only_index=False): If `triplet=True` return the eddy after a merging event, the eddy before the merging event, and the eddy stopped due to merging. 
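The rewritten events are pure index arithmetic on segment bounds. For `birth_event`, a segment is a birth when its first observation has no antecedent; with toy arrays:

.. code-block:: python

    import numpy as np

    i_start = np.array([0, 3, 7])  # first obs of each segment (toy values)
    previous_obs = np.array([-1, 0, 1, -1, 3, 4, 5, 2, 7, 8])
    births = i_start[previous_obs[i_start] == -1]  # -> array([0, 3])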
""" - idx_m1 = list() + # Get start and stop for each segment, there is no empty segment + _, i1, _ = self.index_segment_track + # Get last index for each segment + i_stop = i1 - 1 + # Get target index + idx_m1 = self.next_obs[i_stop] + # Get mask and valid target + m = idx_m1 != -1 + idx_m1 = idx_m1[m] + # Sort by time event + i = self.time[idx_m1].argsort() + idx_m1 = idx_m1[i] if triplet: - idx_m0_stop = list() - idx_m0 = list() - next_obs, previous_obs = self.next_obs, self.previous_obs - for i, _, _ in self.iter_on(self.segment_track_array): - nb = i.stop - i.start - if nb == 0: - continue - i_n = next_obs[i.stop - 1] - if i_n != -1: - if triplet: - idx_m0_stop.append(i.stop - 1) - idx_m0.append(previous_obs[i_n]) - idx_m1.append(i_n) + # Get obs before target + idx_m0_stop = i_stop[m][i] + idx_m0 = self.previous_obs[idx_m1].copy() if triplet: if only_index: - return (idx_m1, idx_m0, idx_m0_stop) + return idx_m1, idx_m0, idx_m0_stop else: return ( self.extract_event(idx_m1), @@ -1064,7 +1205,7 @@ def merging_event(self, triplet=False, only_index=False): self.extract_event(idx_m0_stop), ) else: - idx_m1 = list(set(idx_m1)) + idx_m1 = unique(idx_m1) if only_index: return idx_m1 else: @@ -1076,34 +1217,33 @@ def splitting_event(self, triplet=False, only_index=False): If `triplet=True` return the eddy before a splitting event, the eddy after the splitting event, and the eddy starting due to splitting. """ - idx_s0 = list() + # Get start and stop for each segment, there is no empty segment + i_start, _, _ = self.index_segment_track + # Get target index + idx_s0 = self.previous_obs[i_start] + # Get mask and valid target + m = idx_s0 != -1 + idx_s0 = idx_s0[m] + # Sort by time event + i = self.time[idx_s0].argsort() + idx_s0 = idx_s0[i] if triplet: - idx_s1_start = list() - idx_s1 = list() - next_obs, previous_obs = self.next_obs, self.previous_obs - for i, _, _ in self.iter_on(self.segment_track_array): - nb = i.stop - i.start - if nb == 0: - continue - i_p = previous_obs[i.start] - if i_p != -1: - if triplet: - idx_s1_start.append(i.start) - idx_s1.append(next_obs[i_p]) - idx_s0.append(i_p) + # Get obs after target + idx_s1_start = i_start[m][i] + idx_s1 = self.next_obs[idx_s0].copy() if triplet: if only_index: - return (idx_s0, idx_s1, idx_s1_start) + return idx_s0, idx_s1, idx_s1_start else: return ( - self.extract_event(list(idx_s0)), - self.extract_event(list(idx_s1)), - self.extract_event(list(idx_s1_start)), + self.extract_event(idx_s0), + self.extract_event(idx_s1), + self.extract_event(idx_s1_start), ) else: - idx_s0 = list(set(idx_s0)) + idx_s0 = unique(idx_s0) if only_index: return idx_s0 else: @@ -1113,35 +1253,131 @@ def dissociate_network(self): """ Dissociate networks with no known interaction (splitting/merging) """ - - tags = self.tag_segment(multi_network=True) + tags = self.tag_segment() if self.track[0] == 0: tags -= 1 - self.track[:] = tags[self.segment_track_array] + return self.sort() - i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort") - # Sort directly obs, with hope to save memory - self.obs.sort(order=("track", "segment", "time"), kind="mergesort") - self._index_network = None - - # n & p must be re-indexed - n, p = self.next_obs, self.previous_obs - # we add 2 for -1 index return index -1 - nb_obs = len(self) - translate = -ones(nb_obs + 1, dtype="i4") - translate[:-1][i_sort] = arange(nb_obs) - self.next_obs[:] = translate[n] - self.previous_obs[:] = translate[p] + def network_segment(self, id_network, id_segment): + return 
self.extract_with_mask(self.segment_slice(id_network, id_segment)) def network(self, id_network): return self.extract_with_mask(self.network_slice(id_network)) + def networks_mask(self, id_networks, segment=False): + if segment: + return generate_mask_from_ids( + id_networks, self.track.size, *self.index_segment_track + ) + else: + return generate_mask_from_ids( + id_networks, self.track.size, *self.index_network + ) + def networks(self, id_networks): - m = zeros(self.track.shape, dtype=bool) - for tr in id_networks: - m[self.network_slice(tr)] = True - return self.extract_with_mask(m) + return self.extract_with_mask( + generate_mask_from_ids( + array(id_networks), self.track.size, *self.index_network + ) + ) + + @property + def nb_network(self): + """ + Count and return number of network + """ + return (self.network_size() != 0).sum() + + @property + def nb_segment(self): + """ + Count and return number of segment in all network + """ + return self.index_segment_track[0].size + + def identify_in(self, other, size_min=1, segment=False): + """ + Return couple of segment or network which are equal + + :param other: other atlas to compare + :param int size_min: number of observation in network/segment + :param bool segment: segment mode + """ + if segment: + counts = self.segment_size(), other.segment_size() + i_self_ref, i_other_ref = ( + self.ref_segment_track_index, + other.ref_segment_track_index, + ) + var_id = "segment" + else: + counts = self.network_size(), other.network_size() + i_self_ref, i_other_ref = self.ref_index, other.ref_index + var_id = "track" + # object to contain index of couple + in_self, in_other = list(), list() + # We iterate on item of same size + for i_self, i_other, i0, _ in self.align_on(other, counts, all_ref=True): + if i0 < size_min: + continue + if isinstance(i_other, slice): + i_other = arange(i_other.start, i_other.stop) + # All_ref will give all item of self, sometime there is no things to compare with other + if i_other.size == 0: + id_self = i_self + i_self_ref + in_self.append(id_self) + in_other.append(-ones(id_self.shape, dtype=id_self.dtype)) + continue + if isinstance(i_self, slice): + i_self = arange(i_self.start, i_self.stop) + # We get absolute id + id_self, id_other = i_self + i_self_ref, i_other + i_other_ref + # We compute mask to select data + m_self, m_other = self.networks_mask(id_self, segment), other.networks_mask( + id_other, segment + ) + + # We extract obs + obs_self, obs_other = self.obs[m_self], other.obs[m_other] + x1, y1, t1 = obs_self["lon"], obs_self["lat"], obs_self["time"] + x2, y2, t2 = obs_other["lon"], obs_other["lat"], obs_other["time"] + + if segment: + ids1 = build_unique_array(obs_self["segment"], obs_self["track"]) + ids2 = build_unique_array(obs_other["segment"], obs_other["track"]) + label1 = self.segment_track_array[m_self] + label2 = other.segment_track_array[m_other] + else: + label1, label2 = ids1, ids2 = obs_self[var_id], obs_other[var_id] + # For each item we get index to sort + i01, indexs1, id1 = list(), List(), list() + for sl_self, id_, _ in self.iter_on(ids1): + i01.append(sl_self.start) + indexs1.append(obs_self[sl_self].argsort(order=["time", "lon", "lat"])) + id1.append(label1[sl_self.start]) + i02, indexs2, id2 = list(), List(), list() + for sl_other, _, _ in other.iter_on(ids2): + i02.append(sl_other.start) + indexs2.append( + obs_other[sl_other].argsort(order=["time", "lon", "lat"]) + ) + id2.append(label2[sl_other.start]) + + id1, id2 = array(id1), array(id2) + # We search item from self in item of 
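`merging_event` and `splitting_event` above follow the same pattern from the segment ends: take the last (or first) observation of every segment, keep the ones whose `next_obs` (or `previous_obs`) is set, then order the events by time. With toy arrays, for the merging side:

.. code-block:: python

    import numpy as np

    i1 = np.array([3, 7, 10])  # exclusive segment stops (toy values)
    i_stop = i1 - 1            # last obs of each segment
    next_obs = np.array([1, 2, 8, 4, 5, 6, 9, 8, 9, -1])
    idx_m1 = next_obs[i_stop]      # -> [8, 9, -1]
    idx_m1 = idx_m1[idx_m1 != -1]  # segments 0 and 1 merge into segment 2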
others + i_local_target = same_position( + x1, y1, t1, x2, y2, t2, array(i01), array(i02), indexs1, indexs2 + ) + + # -1 => no item found in other dataset + m = i_local_target != -1 + in_self.append(id1) + track2_ = -ones(id1.shape, dtype="i4") + track2_[m] = id2[i_local_target[m]] + in_other.append(track2_) + + return concatenate(in_self), concatenate(in_other) @classmethod def __tag_segment(cls, seg, tag, groups, connexions): @@ -1165,16 +1401,22 @@ def __tag_segment(cls, seg, tag, groups, connexions): # For each connexion we apply same function cls.__tag_segment(seg, tag, groups, connexions) - def tag_segment(self, multi_network=False): - if multi_network: - nb = self.segment_track_array[-1] + 1 - else: - nb = self.segment.max() + 1 + def tag_segment(self): + """For each segment, method give a new network id, and all segment are connected + + :return array: for each unique seg id, it return new network id + """ + nb = self.segment_track_array[-1] + 1 sub_group = zeros(nb, dtype="u4") - c = self.connexions(multi_network=multi_network) + c = self.connexions(multi_network=True) j = 1 # for each available id for i in range(nb): + # No connexions, no need to explore + if i not in c: + sub_group[i] = j + j += 1 + continue # Skip if already set if sub_group[i] != 0: continue @@ -1184,14 +1426,28 @@ def tag_segment(self, multi_network=False): return sub_group def fully_connected(self): + """Suspicious""" + raise Exception("Must be check") self.only_one_network() return self.tag_segment().shape[0] == 1 + def first_is_trash(self): + """Check if first network is Trash + + :return bool: True if first network is trash + """ + i_start, i_stop, _ = self.index_segment_track + sl = slice(i_start[0], i_stop[0]) + return (self.previous_obs[sl] == -1).all() and (self.next_obs[sl] == -1).all() + def remove_trash(self): """ Remove the lonely eddies (only 1 obs in segment, associated network number is 0) """ - return self.extract_with_mask(self.track != 0) + if self.first_is_trash(): + return self.extract_with_mask(self.track != 0) + else: + return self def plot(self, ax, ref=None, color_cycle=None, **kwargs): """ @@ -1202,10 +1458,10 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs): :param dict kwargs: keyword arguments for Axes.plot :return: a list of matplotlib mappables """ - nb_colors = 0 - if color_cycle is not None: - kwargs = kwargs.copy() - nb_colors = len(color_cycle) + kwargs = kwargs.copy() + if color_cycle is None: + color_cycle = self.COLORS + nb_colors = len(color_cycle) mappables = list() if "label" in kwargs: kwargs["label"] = self.format_label(kwargs["label"]) @@ -1223,7 +1479,7 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs): j += 1 return mappables - def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None): + def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None, return_mask=False): """ Remove short segments that don't connect several segments @@ -1235,35 +1491,61 @@ def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None): .. 
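`tag_segment` walks the `connexions` graph so that every segment reachable through split/merge links receives the same network id. A toy, pure-Python version of that walk:

.. code-block:: python

    # Segments 0, 1, 2 are linked by events; segment 3 is isolated
    connexions = {0: {1}, 1: {0, 2}, 2: {1}, 3: set()}
    tags = {}

    def tag(seg, label):
        if seg in tags:
            return
        tags[seg] = label
        for other in connexions.get(seg, ()):
            tag(other, label)

    label = 0
    for seg in connexions:
        if seg not in tags:
            tag(seg, label)
            label += 1
    # tags -> {0: 0, 1: 0, 2: 0, 3: 1}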
warning:: It will remove short segment that splits from then merges with the same segment """ - segments_keep = list() connexions = self.connexions(multi_network=True) - t = self.time - for i, b0, _ in self.iter_on(self.segment_track_array): - if mask and mask[i].any(): - segments_keep.append(b0) - continue - nb = i.stop - i.start - dt = t[i.stop - 1] - t[i.start] - if (nb < nobs or dt < ndays) and len(connexions.get(b0, tuple())) < 2: - continue - segments_keep.append(b0) - if recursive > 0: - return self.extract_segment(segments_keep, absolute=True).remove_dead_end( - nobs, ndays, recursive - 1 + i0, i1, _ = self.index_segment_track + dt = self.time[i1 - 1] - self.time[i0] + 1 + nb = i1 - i0 + m = (dt >= ndays) * (nb >= nobs) + nb_connexions = array([len(connexions.get(i, tuple())) for i in where(~m)[0]]) + m[~m] = nb_connexions >= 2 + segments_keep = where(m)[0] + if mask is not None: + segments_keep = unique( + concatenate((segments_keep, self.segment_track_array[mask])) ) - return self.extract_segment(segments_keep, absolute=True) + # get mask for selected obs + m = ~self.segment_mask(segments_keep) + if return_mask: + return ~m + self.track[m] = 0 + self.segment[m] = 0 + self.previous_obs[m] = -1 + self.previous_cost[m] = 0 + self.next_obs[m] = -1 + self.next_cost[m] = 0 + + m_previous = m[self.previous_obs] + self.previous_obs[m_previous] = -1 + self.previous_cost[m_previous] = 0 + m_next = m[self.next_obs] + self.next_obs[m_next] = -1 + self.next_cost[m_next] = 0 + + self.sort() + if recursive > 0: + self.remove_dead_end(nobs, ndays, recursive - 1) + + def extract_segment(self, segments, absolute=False): - mask = ones(self.shape, dtype="bool") - segments = array(segments) - values = self.segment_track_array if absolute else "segment" - keep = ones(values.max() + 1, dtype="bool") - v = unique(values) - keep[v] = in1d(v, segments) - for i, b0, b1 in self.iter_on(values): - if not keep[b0]: - mask[i] = False - return self.extract_with_mask(mask) + """Extract given segments + + :param array,tuple,list segments: list of segment to extract + :param bool absolute: keep for compatibility, defaults to False + :return NetworkObservations: Return observations from selected segment + """ + if not absolute: + raise Exception("Not implemented") + return self.extract_with_mask(self.segment_mask(segments)) + + def segment_mask(self, segments): + """Get mask from list of segment + + :param list,array segments: absolute id of segment + """ + return generate_mask_from_ids( + array(segments), len(self), *self.index_segment_track + ) def get_mask_with_period(self, period): """ @@ -1371,12 +1653,11 @@ def extract_with_mask(self, mask): logger.debug( f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})" ) - for field in self.obs.dtype.descr: + for field in self.fields: if field in ("next_obs", "previous_obs"): continue logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var][mask] + new.obs[field] = self.obs[field][mask] # n & p must be re-index n, p = self.next_obs[mask], self.previous_obs[mask] # we add 2 for -1 index return index -1 @@ -1398,7 +1679,6 @@ def analysis_coherence( correct_close_events=0, remove_dead_end=0, ): - """Global function to analyse segments coherence, with network preprocessing. 
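Unlike the previous version, `remove_dead_end` now edits the atlas in place: removed observations are moved to the trash network (track 0) and their links are cut. Hypothetical usage, with `n` a NetworkObservations instance:

.. code-block:: python

    # Drop segments with fewer than 10 obs or spanning fewer than 5 days,
    # unless they connect at least two other segments
    n.remove_dead_end(nobs=10, ndays=5)
    # Dry run: only compute the mask of observations that would be kept
    keep = n.remove_dead_end(nobs=10, ndays=5, return_mask=True)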
:param callable date_function: python function, takes as param `int` (julian day) and return data filename associated to the date @@ -1479,7 +1759,6 @@ def segment_coherence_backward( contour_start="speed", contour_end="speed", ): - """ Percentage of particules and their targets after backward advection from a specific eddy. @@ -1502,9 +1781,9 @@ def date2file(julian_day): return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc" """ - - itb_final = -ones((self.obs.size, 2), dtype="i4") - ptb_final = zeros((self.obs.size, 2), dtype="i1") + shape = len(self), 2 + itb_final = -ones(shape, dtype="i4") + ptb_final = zeros(shape, dtype="i1") t_start, t_end = int(self.period[0]), int(self.period[1]) @@ -1540,8 +1819,8 @@ def date2file(julian_day): ) logger.info( ( - f"coherence {_t} / {range_end-1} ({(_t - range_start) / (range_end - range_start-1):.1%})" - f" : {time.time()-_timestamp:5.2f}s" + f"coherence {_t} / {range_end - 1} ({(_t - range_start) / (range_end - range_start - 1):.1%})" + f" : {time.time() - _timestamp:5.2f}s" ) ) @@ -1555,8 +1834,8 @@ def segment_coherence_forward( step_mesh=1.0 / 50, contour_start="speed", contour_end="speed", + **kwargs, ): - """ Percentage of particules and their targets after forward advection from a specific eddy. @@ -1579,9 +1858,9 @@ def date2file(julian_day): return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc" """ - - itf_final = -ones((self.obs.size, 2), dtype="i4") - ptf_final = zeros((self.obs.size, 2), dtype="i1") + shape = len(self), 2 + itf_final = -ones(shape, dtype="i4") + ptf_final = zeros(shape, dtype="i1") t_start, t_end = int(self.period[0]), int(self.period[1]) @@ -1613,15 +1892,80 @@ def date2file(julian_day): n_days=n_days, contour_start=contour_start, contour_end=contour_end, + **kwargs, ) logger.info( ( - f"coherence {_t} / {range_end-1} ({(_t - range_start) / (range_end - range_start-1):.1%})" - f" : {time.time()-_timestamp:5.2f}s" + f"coherence {_t} / {range_end - 1} ({(_t - range_start) / (range_end - range_start - 1):.1%})" + f" : {time.time() - _timestamp:5.2f}s" ) ) return itf_final, ptf_final + def mask_obs_close_event(self, merging=True, spliting=True, dt=3): + """Build a mask of close observation from event + + :param n: Network + :param bool merging: select merging event, defaults to True + :param bool spliting: select splitting event, defaults to True + :param int dt: delta of time max , defaults to 3 + :return array: mask + """ + m = zeros(len(self), dtype="bool") + if merging: + i_target, ip1, ip2 = self.merging_event(triplet=True, only_index=True) + mask_follow_obs(m, self.previous_obs, self.time, ip1, dt) + mask_follow_obs(m, self.previous_obs, self.time, ip2, dt) + mask_follow_obs(m, self.next_obs, self.time, i_target, dt) + if spliting: + i_target, in1, in2 = self.splitting_event(triplet=True, only_index=True) + mask_follow_obs(m, self.next_obs, self.time, in1, dt) + mask_follow_obs(m, self.next_obs, self.time, in2, dt) + mask_follow_obs(m, self.previous_obs, self.time, i_target, dt) + return m + + def swap_track( + self, + length_main_max_after_event=2, + length_secondary_min_after_event=10, + delta_pct_max=-0.2, + ): + events = self.splitting_event(triplet=True, only_index=True) + count = 0 + for i_main, i1, i2 in zip(*events): + seg_main, _, seg2 = ( + self.segment_track_array[i_main], + self.segment_track_array[i1], + self.segment_track_array[i2], + ) + i_start, i_end, i0 = self.index_segment_track + # For splitting + last_index_main = i_end[seg_main - i0] - 1 + last_index_secondary = i_end[seg2 - i0] - 1 + 
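`mask_obs_close_event` combines naturally with `extract_with_mask` to focus on, or exclude, the neighbourhood of events; a hypothetical use with the default 3-day window (pairing the two methods this way is a suggestion, not taken from the patch):

.. code-block:: python

    m = n.mask_obs_close_event(merging=True, spliting=True, dt=3)
    quiet = n.extract_with_mask(~m)  # observations far from any event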
last_main_next_obs = self.next_obs[last_index_main] + t_event, t_main_end, t_secondary_start, t_secondary_end = ( + self.time[i_main], + self.time[last_index_main], + self.time[i2], + self.time[last_index_secondary], + ) + dt_main, dt_secondary = ( + t_main_end - t_event, + t_secondary_end - t_secondary_start, + ) + delta_cost = self.previous_cost[i2] - self.previous_cost[i1] + if ( + dt_main <= length_main_max_after_event + and dt_secondary >= length_secondary_min_after_event + and last_main_next_obs == -1 + and delta_cost > delta_pct_max + ): + self.segment[i1 : last_index_main + 1] = self.segment[i2] + self.segment[i2 : last_index_secondary + 1] = self.segment[i_main] + count += 1 + logger.info("%d segmnent swap on %d", count, len(events[0])) + return self.sort() + class Network: __slots__ = ( @@ -1709,26 +2053,20 @@ def group_translator(nb, duos): apply_replace(translate, gr_i, gr_j) return translate - def group_observations(self, min_overlap=0.2, minimal_area=False): + def group_observations(self, min_overlap=0.2, minimal_area=False, **kwargs): """Store every interaction between identifications - Parameters - ---------- - minimal_area : bool, optional - If True, function will compute intersection/little polygon, else intersection/union, by default False + :param bool minimal_area: If True, function will compute intersection/little polygon, else intersection/union, by default False + :param float min_overlap: minimum overlap area to associate observations, by default 0.2 - min_overlap : float, optional - minimum overlap area to associate observations, by default 0.2 - - Returns - ------- - TrackEddiesObservations - netcdf with interactions + :return: + :rtype: TrackEddiesObservations """ results, nb_obs = list(), list() # To display print only in INFO display_iteration = logger.getEffectiveLevel() == logging.INFO + for i, filename in enumerate(self.filenames): if display_iteration: print(f"{filename} compared to {self.window} next", end="\r") @@ -1741,9 +2079,15 @@ def group_observations(self, min_overlap=0.2, minimal_area=False): ii, ij = bbox_intersection(xi, yi, xj, yj) m = ( vertice_overlap( - xi[ii], yi[ii], xj[ij], yj[ij], minimal_area=minimal_area + xi[ii], + yi[ii], + xj[ij], + yj[ij], + minimal_area=minimal_area, + min_overlap=min_overlap, + **kwargs, ) - > min_overlap + != 0 ) results.append((i, j, ii[m], ij[m])) if display_iteration: @@ -1753,7 +2097,7 @@ def group_observations(self, min_overlap=0.2, minimal_area=False): nb_alone, nb_obs, nb_gr = (gr == self.NOGROUP).sum(), len(gr), len(unique(gr)) logger.info( f"{nb_alone} alone / {nb_obs} obs, {nb_gr} groups, " - f"{nb_alone *100./nb_obs:.2f} % alone, {(nb_obs - nb_alone) / (nb_gr - 1):.1f} obs/group" + f"{nb_alone * 100. 
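An end-to-end sketch of the grouping stage; the `Network` constructor arguments here are assumptions (check the class definition above), only the `group_observations` and `build_dataset` signatures come from the patch:

.. code-block:: python

    # Associate identifications which overlap by at least 20 %
    network = Network("eddies_*.nc", window=5)  # assumed constructor
    gr = network.group_observations(min_overlap=0.2, minimal_area=False)
    eddies = network.build_dataset(gr, raw_data=True)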
/ nb_obs:.2f} % alone, {(nb_obs - nb_alone) / (nb_gr - 1):.1f} obs/group" ) return gr @@ -1788,6 +2132,94 @@ def build_dataset(self, group, raw_data=True): return eddies +@njit(cache=True) +def get_percentile_on_following_obs( + i, indexs, percents, follow_obs, t, segment, i_target, window, q=50, nb_min=1 +): + """Get stat on a part of segment close of an event + + :param int i: index to follow + :param array indexs: indexs from coherence + :param array percents: percent from coherence + :param array[int] follow_obs: give index for the following observation + :param array t: time for each observation + :param array segment: segment for each observation + :param int i_target: index of target + :param int window: time window of search + :param int q: Percentile from 0 to 100, defaults to 50 + :param int nb_min: Number minimal of observation to provide statistics, defaults to 1 + :return float : return statistic + """ + last_t, segment_follow = t[i], segment[i] + segment_target = segment[i_target] + percent_target = empty(window, dtype=percents.dtype) + j = 0 + while abs(last_t - t[i]) < window and i != -1 and segment_follow == segment[i]: + # Iter on primary & secondary + for index, percent in zip(indexs[i], percents[i]): + if index != -1 and segment[index] == segment_target: + percent_target[j] = percent + j += 1 + i = follow_obs[i] + if j < nb_min: + return nan + return percentile(percent_target[:j], q) + + +@njit(cache=True) +def get_percentile_around_event( + i, + i1, + i2, + ind, + pct, + follow_obs, + t, + segment, + window=10, + follow_parent=False, + q=50, + nb_min=1, +): + """Get stat around event + + :param array[int] i: Indexs of target + :param array[int] i1: Indexs of primary origin + :param array[int] i2: Indexs of secondary origin + :param array ind: indexs from coherence + :param array pct: percent from coherence + :param array[int] follow_obs: give index for the following observation + :param array t: time for each observation + :param array segment: segment for each observation + :param int window: time window of search, defaults to 10 + :param bool follow_parent: Follow parent instead of child, defaults to False + :param int q: Percentile from 0 to 100, defaults to 50 + :param int nb_min: Number minimal of observation to provide statistics, defaults to 1 + :return (array,array) : statistic for each event + """ + stat1 = empty(i.size, dtype=nb_types.float32) + stat2 = empty(i.size, dtype=nb_types.float32) + # iter on event + for j, (i_, i1_, i2_) in enumerate(zip(i, i1, i2)): + if follow_parent: + # We follow parent + stat1[j] = get_percentile_on_following_obs( + i_, ind, pct, follow_obs, t, segment, i1_, window, q, nb_min + ) + stat2[j] = get_percentile_on_following_obs( + i_, ind, pct, follow_obs, t, segment, i2_, window, q, nb_min + ) + else: + # We follow child + stat1[j] = get_percentile_on_following_obs( + i1_, ind, pct, follow_obs, t, segment, i_, window, q, nb_min + ) + stat2[j] = get_percentile_on_following_obs( + i2_, ind, pct, follow_obs, t, segment, i_, window, q, nb_min + ) + return stat1, stat2 + + @njit(cache=True) def get_next_index(gr): """Return for each obs index the new position to join all groups""" @@ -1839,3 +2271,98 @@ def new_numbering(segs, start=0): @njit(cache=True) def ptp(values): return values.max() - values.min() + + +@njit(cache=True) +def generate_mask_from_ids(id_networks, nb, istart, iend, i0): + """From list of id, we generate a mask + + :param array id_networks: list of ids + :param int nb: size of mask + :param array istart: first 
index for each id from :py:meth:`~py_eddy_tracker.generic.build_index`
+    :param array iend: last index for each id from :py:meth:`~py_eddy_tracker.generic.build_index`
+    :param int i0: ref index from :py:meth:`~py_eddy_tracker.generic.build_index`
+    :return array: the resulting mask
+    """
+    m = zeros(nb, dtype="bool")
+    for i in id_networks:
+        for j in range(istart[i - i0], iend[i - i0]):
+            m[j] = True
+    return m
+
+
+@njit(cache=True)
+def same_position(x0, y0, t0, x1, y1, t1, i00, i01, i0, i1):
+    """Return index of track/segment found in the other dataset
+
+    :param array x0:
+    :param array y0:
+    :param array t0:
+    :param array x1:
+    :param array y1:
+    :param array t1:
+    :param array i00: First index of track/segment/network in dataset0
+    :param array i01: First index of track/segment/network in dataset1
+    :param List(array) i0: list of arrays which contain indices to order dataset0
+    :param List(array) i1: list of arrays which contain indices to order dataset1
+    :return array: indices of dataset1 which match dataset0, -1 => no match
+    """
+    nb0, nb1 = i00.size, i01.size
+    i_target = -ones(nb0, dtype="i4")
+    # Avoid comparing an item several times once it has already matched
+    used1 = zeros(nb1, dtype="bool")
+    for j0 in range(nb0):
+        for j1 in range(nb1):
+            if used1[j1]:
+                continue
+            test = True
+            for i0_, i1_ in zip(i0[j0], i1[j1]):
+                i0_ += i00[j0]
+                i1_ += i01[j1]
+                if t0[i0_] != t1[i1_] or x0[i0_] != x1[i1_] or y0[i0_] != y1[i1_]:
+                    test = False
+                    break
+            if test:
+                i_target[j0] = j1
+                used1[j1] = True
+                break
+    return i_target
+
+
+@njit(cache=True)
+def mask_follow_obs(m, next_obs, time, indexs, dt=3):
+    """Generate a mask to select obs close in time to the given start indices
+
+    :param array m: mask to fill with True
+    :param array next_obs: index of the next observation
+    :param array time: time of each obs
+    :param array indexs: indices where the follow starts
+    :param int dt: maximum time delta from the start index, defaults to 3
+    """
+    for i in indexs:
+        t0 = time[i]
+        m[i] = True
+        i_next = next_obs[i]
+        dt_ = abs(time[i_next] - t0)
+        while dt_ < dt and i_next != -1:
+            m[i_next] = True
+            i_next = next_obs[i_next]
+            dt_ = abs(time[i_next] - t0)
+
+
+@njit(cache=True)
+def get_period_with_index(t, i0, i1):
+    """Return the peak-to-peak spread covered by each slice defined by i0 and i1
+
+    :param array t: array which contains the values used to estimate the spread
+    :param array i0: indices which determine the start of each slice
+    :param array i1: indices which determine the end of each slice
+    :return array: peak to peak of t
+    """
+    periods = np.empty(i0.size, t.dtype)
+    for i in range(i0.size):
+        if i1[i] == i0[i]:
+            periods[i] = 0
+            continue
+        periods[i] = t[i0[i] : i1[i]].ptp()
+    return periods
diff --git a/src/py_eddy_tracker/observations/observation.py b/src/py_eddy_tracker/observations/observation.py
index 3543caa7..b39f7f83 100644
--- a/src/py_eddy_tracker/observations/observation.py
+++ b/src/py_eddy_tracker/observations/observation.py
@@ -2,20 +2,18 @@
 """
 Base class to manage eddy observation
 """
-import logging
 from datetime import datetime
 from io import BufferedReader, BytesIO
+import logging
 from tarfile import ExFileObject
 from tokenize import TokenError
-import packaging
-import zarr
+from Polygon import Polygon
 from matplotlib.cm import get_cmap
-from matplotlib.collections import PolyCollection
+from matplotlib.collections import LineCollection, PolyCollection
 from matplotlib.colors import Normalize
 from netCDF4 import Dataset
-from numba import njit
-from numba import types as numba_types
+from numba import njit, types as numba_types
 from numpy import (
     absolute,
arange, @@ -24,6 +22,7 @@ ceil, concatenate, cos, + datetime64, digitize, empty, errstate, @@ -44,9 +43,10 @@ where, zeros, ) +import packaging.version from pint import UnitRegistry from pint.errors import UndefinedUnitError -from Polygon import Polygon +import zarr from .. import VAR_DESCR, VAR_DESCR_inv, __version__ from ..generic import ( @@ -58,19 +58,20 @@ hist_numba, local_to_coordinates, reverse_index, + window_index, wrap_longitude, ) from ..poly import ( bbox_intersection, close_center, convexs, + create_meshed_particles, create_vertice, get_pixel_in_regular, insidepoly, poly_indexs, reduce_size, vertice_overlap, - winding_number_poly, ) logger = logging.getLogger("pet") @@ -79,6 +80,7 @@ _software_version_reduced = packaging.version.Version( "{v.major}.{v.minor}".format(v=packaging.version.parse(__version__)) ) +_display_check_warning = True def _check_versions(version): @@ -89,7 +91,8 @@ def _check_versions(version): :param version: string version of software used to create the file. If None, version was not provided :type version: str, None """ - + if not _display_check_warning: + return file_version = packaging.version.parse(version) if version is not None else None if file_version is None or file_version < _software_version_reduced: logger.warning( @@ -130,7 +133,7 @@ def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5 if dx > major[j]: m[j, i] = False continue - d_normalize = dx ** 2 / major[j] ** 2 + dy ** 2 / minor ** 2 + d_normalize = dx**2 / major[j] ** 2 + dy**2 / minor**2 m[j, i] = d_normalize < 1.0 return m @@ -265,7 +268,7 @@ def get_infos(self): bins_lat=(-90, -60, -15, 15, 60, 90), bins_amplitude=array((0, 1, 2, 3, 4, 5, 10, 500)), bins_radius=array((0, 15, 30, 45, 60, 75, 100, 200, 2000)), - nb_obs=self.observations.shape[0], + nb_obs=len(self), ) t0, t1 = self.period infos["t0"], infos["t1"] = t0, t1 @@ -307,12 +310,16 @@ def box_display(value): """Return values evenly spaced with few numbers""" return "".join([f"{v_:10.2f}" for v_ in value]) + @property + def fields(self): + return list(self.obs.dtype.names) + def field_table(self): """ Produce description table of the fields available in this object """ rows = [("Name (Unit)", "Long name", "Scale factor", "Offset")] - names = list(self.obs.dtype.names) + names = self.fields names.sort() for field in names: infos = VAR_DESCR[field] @@ -338,7 +345,7 @@ def __repr__(self): bins_lat = (-90, -60, -15, 15, 60, 90) bins_amplitude = array((0, 1, 2, 3, 4, 5, 10, 500)) bins_radius = array((0, 15, 30, 45, 60, 75, 100, 200, 2000)) - nb_obs = self.observations.shape[0] + nb_obs = len(self) return f""" | {nb_obs} observations from {t0} to {t1} ({period} days, ~{nb_obs / period:.0f} obs/day) | Speed area : {self.speed_area.sum() / period / 1e12:.2f} Mkm²/day @@ -413,9 +420,9 @@ def remove_fields(self, *fields): """ Copy with fields listed remove """ - nb_obs = self.obs.shape[0] + nb_obs = len(self) fields = set(fields) - only_variables = set(self.obs.dtype.names) - fields + only_variables = set(self.fields) - fields track_extra_variables = set(self.track_extra_variables) - fields array_variables = set(self.array_variables) - fields new = self.__class__( @@ -427,7 +434,7 @@ def remove_fields(self, *fields): raw_data=self.raw_data, ) new.sign_type = self.sign_type - for name in new.obs.dtype.names: + for name in new.fields: logger.debug("Copy of field %s ...", name) new.obs[name] = self.obs[name] return new @@ -436,7 +443,7 @@ def add_fields(self, fields=list(), array_fields=list()): """ Add a new 
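With the new `fields` property, listing per-observation variables no longer goes through `obs.dtype.names`; for instance:

.. code-block:: python

    for name in eddies.fields:  # previously eddies.obs.dtype.names
        print(name, eddies[name].dtype)

    # Drop heavy contour fields before further processing (illustrative)
    slim = eddies.remove_fields("contour_lon_e", "contour_lon_s")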
field. """ - nb_obs = self.obs.shape[0] + nb_obs = len(self) new = self.__class__( size=nb_obs, track_extra_variables=list( @@ -444,13 +451,11 @@ def add_fields(self, fields=list(), array_fields=list()): ), track_array_variables=self.track_array_variables, array_variables=list(concatenate((self.array_variables, array_fields))), - only_variables=list( - concatenate((self.obs.dtype.names, fields, array_fields)) - ), + only_variables=list(concatenate((self.fields, fields, array_fields))), raw_data=self.raw_data, ) new.sign_type = self.sign_type - for name in self.obs.dtype.names: + for name in self.fields: logger.debug("Copy of field %s ...", name) new.obs[name] = self.obs[name] return new @@ -468,8 +473,8 @@ def circle_contour(self, only_virtual=False, factor=1): """ angle = radians(linspace(0, 360, self.track_array_variables)) x_norm, y_norm = cos(angle), sin(angle) - radius_s = "contour_lon_s" in self.obs.dtype.names - radius_e = "contour_lon_e" in self.obs.dtype.names + radius_s = "contour_lon_s" in self.fields + radius_e = "contour_lon_e" in self.fields for i, obs in enumerate(self): if only_virtual and not obs["virtual"]: continue @@ -544,9 +549,9 @@ def merge(self, other): nb_obs_self = len(self) nb_obs = nb_obs_self + len(other) eddies = self.new_like(self, nb_obs) - other_keys = other.obs.dtype.fields.keys() - self_keys = self.obs.dtype.fields.keys() - for key in eddies.obs.dtype.fields.keys(): + other_keys = other.fields + self_keys = self.fields + for key in eddies.fields: eddies.obs[key][:nb_obs_self] = self.obs[key][:] if key in other_keys: eddies.obs[key][nb_obs_self:] = other.obs[key][:] @@ -571,60 +576,76 @@ def __iter__(self): for obs in self.obs: yield obs - def iter_on(self, xname, bins=None): + def iter_on(self, xname, window=None, bins=None): """ Yield observation group for each bin. :param str,array xname: - :param array bins: bounds of each bin , - :return: index or mask, bound low, bound up + :param float,None window: if defined we use a moving window with value like half window + :param array bins: bounds of each bin + :yield array,float,float: index in self, lower bound, upper bound .. 
minigallery:: py_eddy_tracker.EddiesObservations.iter_on """ - x = self[xname] if isinstance(xname, str) else xname - d = x[1:] - x[:-1] - if bins is None: - bins = arange(x.min(), x.max() + 2) - elif not isinstance(bins, ndarray): - bins = array(bins) - nb_bins = len(bins) - 1 - - # Not monotonous - if (d < 0).any(): - # If bins cover a small part of value - test, translate, x = iter_mode_reduce(x, bins) - # convert value in bins number - i = numba_digitize(x, bins) - 1 - # Order by bins - i_sort = i.argsort() - # If in reduced mode we will translate i_sort in full array index - i_sort_ = translate[i_sort] if test else i_sort - # Bound for each bins in sorting view - i0, i1, _ = build_index(i[i_sort]) - m = ~(i0 == i1) - i0, i1 = i0[m], i1[m] - for i0_, i1_ in zip(i0, i1): - i_bins = i[i_sort[i0_]] - if i_bins == -1 or i_bins == nb_bins: - continue - yield i_sort_[i0_:i1_], bins[i_bins], bins[i_bins + 1] + x = self.parse_varname(xname) + if window is not None: + x0 = arange(x.min(), x.max()) if bins is None else array(bins) + i_ordered, first_index, last_index = window_index(x, x0, window) + for x_, i0, i1 in zip(x0, first_index, last_index): + yield i_ordered[i0:i1], x_ - window, x_ + window else: - i = numba_digitize(x, bins) - 1 - i0, i1, _ = build_index(i) - m = ~(i0 == i1) - i0, i1 = i0[m], i1[m] - for i0_, i1_ in zip(i0, i1): - i_bins = i[i0_] - yield slice(i0_, i1_), bins[i_bins], bins[i_bins + 1] + d = x[1:] - x[:-1] + if bins is None: + bins = arange(x.min(), x.max() + 2) + elif not isinstance(bins, ndarray): + bins = array(bins) + nb_bins = len(bins) - 1 + + # Not monotonous + if (d < 0).any(): + # If bins cover a small part of value + test, translate, x = iter_mode_reduce(x, bins) + # convert value in bins number + i = numba_digitize(x, bins) - 1 + # Order by bins + i_sort = i.argsort() + # If in reduced mode we will translate i_sort in full array index + i_sort_ = translate[i_sort] if test else i_sort + # Bound for each bins in sorting view + i0, i1, _ = build_index(i[i_sort]) + m = ~(i0 == i1) + i0, i1 = i0[m], i1[m] + for i0_, i1_ in zip(i0, i1): + i_bins = i[i_sort[i0_]] + if i_bins == -1 or i_bins == nb_bins: + continue + yield i_sort_[i0_:i1_], bins[i_bins], bins[i_bins + 1] + else: + i = numba_digitize(x, bins) - 1 + i0, i1, _ = build_index(i) + m = ~(i0 == i1) + i0, i1 = i0[m], i1[m] + for i0_, i1_ in zip(i0, i1): + i_bins = i[i0_] + yield slice(i0_, i1_), bins[i_bins], bins[i_bins + 1] - def align_on(self, other, var_name="time", **kwargs): + def align_on(self, other, var_name="time", all_ref=False, **kwargs): """ - Align the time indices of two datasets. + Align the variable indices of two datasets. + + :param other: other compare with self + :param str,tuple var_name: variable name to align or two array, defaults to "time" + :param bool all_ref: yield all value of ref, if false only common value, defaults to False + :yield array,array,float,float: index in self, index in other, lower bound, upper bound .. 
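With the new `window` argument, `iter_on` yields, for each reference value, every observation within plus or minus `window` (a half-window) instead of exact bins; hypothetical usage:

.. code-block:: python

    # Group observations by time with a +/- 2 day tolerance
    for indices, t_low, t_up in eddies.iter_on("time", window=2):
        print(t_low, t_up, indices.size)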
minigallery:: py_eddy_tracker.EddiesObservations.align_on """ - iter_self = self.iter_on(var_name, **kwargs) - iter_other = other.iter_on(var_name, **kwargs) + if isinstance(var_name, str): + iter_self = self.iter_on(var_name, **kwargs) + iter_other = other.iter_on(var_name, **kwargs) + else: + iter_self = self.iter_on(var_name[0], **kwargs) + iter_other = other.iter_on(var_name[1], **kwargs) indexs_other, b0_other, b1_other = iter_other.__next__() for indexs_self, b0_self, b1_self in iter_self: if b0_self > b0_other: @@ -634,6 +655,10 @@ def align_on(self, other, var_name="time", **kwargs): except StopIteration: break if b0_self < b0_other: + if all_ref: + yield indexs_self, empty( + 0, dtype=indexs_self.dtype + ), b0_self, b1_self continue yield indexs_self, indexs_other, b0_self, b1_self @@ -641,8 +666,8 @@ def insert_observations(self, other, index): """Insert other obs in self at the given index.""" if not self.coherence(other): raise Exception("Observations with no coherence") - insert_size = len(other.obs) - self_size = len(self.obs) + insert_size = len(other) + self_size = len(self) new_size = self_size + insert_size if self_size == 0: self.observations = other.obs @@ -672,7 +697,7 @@ def distance(self, other): def __copy__(self): eddies = self.new_like(self, len(self)) - for k in self.obs.dtype.names: + for k in self.fields: eddies[k][:] = self[k][:] eddies.sign_type = self.sign_type return eddies @@ -729,7 +754,11 @@ def load_file(cls, filename, **kwargs): .. code-block:: python kwargs_latlon_300 = dict( - include_vars=["longitude", "latitude",], indexs=dict(obs=slice(0, 300)), + include_vars=[ + "longitude", + "latitude", + ], + indexs=dict(obs=slice(0, 300)), ) small_dataset = TrackEddiesObservations.load_file( filename, **kwargs_latlon_300 @@ -747,7 +776,7 @@ def load_file(cls, filename, **kwargs): zarr_file = filename_.endswith(end) else: zarr_file = False - logger.info(f"loading file '{filename}'") + logger.info(f"loading file '{filename_}'") if zarr_file: return cls.load_from_zarr(filename, **kwargs) else: @@ -1046,6 +1075,17 @@ def compare_units(input_unit, output_unit, name): input_unit, output_unit, ) + return factor + else: + return 1 + + @classmethod + def from_array(cls, arrays, **kwargs): + nb = arrays["time"].size + eddies = cls(size=nb, **kwargs) + for k, v in arrays.items(): + eddies.obs[k] = v + return eddies @classmethod def from_zarr(cls, handler): @@ -1302,7 +1342,7 @@ def fixed_ellipsoid_mask( if isinstance(minor, ndarray): minor = minor[index_self] # focal distance - f_degree = ((major ** 2 - minor ** 2) ** 0.5) / ( + f_degree = ((major**2 - minor**2) ** 0.5) / ( 111.2 * cos(radians(self.lat[index_self])) ) @@ -1511,8 +1551,7 @@ def to_zarr(self, handler, **kwargs): handler.attrs["track_array_variables"] = self.track_array_variables handler.attrs["array_variables"] = ",".join(self.array_variables) # Iter on variables to create: - fields = [field[0] for field in self.observations.dtype.descr] - for ori_name in fields: + for ori_name in self.fields: # Patch for a transition name = ori_name # @@ -1557,12 +1596,9 @@ def to_netcdf(self, handler, **kwargs): handler.track_array_variables = self.track_array_variables handler.array_variables = ",".join(self.array_variables) # Iter on variables to create: - fields = [field[0] for field in self.observations.dtype.descr] - fields_ = array( - [VAR_DESCR[field[0]]["nc_name"] for field in self.observations.dtype.descr] - ) + fields_ = array([VAR_DESCR[field]["nc_name"] for field in self.fields]) i = fields_.argsort() - for 
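The new `from_array` constructor sizes the dataset on the `time` array and copies every provided field. A minimal sketch; the field names and dtypes are assumptions to check against VAR_DESCR:

.. code-block:: python

    import numpy as np

    from py_eddy_tracker.observations.observation import EddiesObservations

    arrays = dict(
        time=np.array([25933, 25934]),
        lon=np.array([10.5, 11.0]),
        lat=np.array([38.2, 38.4]),
    )
    eddies = EddiesObservations.from_array(arrays)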
ori_name in array(fields)[i]: + for ori_name in array(self.fields)[i]: # Patch for a transition name = ori_name # @@ -1629,6 +1665,33 @@ def create_variable( except ValueError: logger.warning("Data is empty") + @staticmethod + def get_filters_zarr(name): + """Get filters to store in zarr for known variable + + :param str name: private variable name + :return list: filters list + """ + content = VAR_DESCR.get(name) + filters = list() + store_dtype = content["output_type"] + scale_factor, add_offset = content.get("scale_factor", None), content.get( + "add_offset", None + ) + if scale_factor is not None or add_offset is not None: + if add_offset is None: + add_offset = 0 + filters.append( + zarr.FixedScaleOffset( + offset=add_offset, + scale=1 / scale_factor, + dtype=content["nc_type"], + astype=store_dtype, + ) + ) + filters.extend(content.get("filters", [])) + return filters + def create_variable_zarr( self, handler_zarr, @@ -1784,6 +1847,11 @@ def extract_with_area(self, area, **kwargs): mask *= (lon > lon0) * (lon < area["urcrnrlon"]) return self.extract_with_mask(mask, **kwargs) + @property + def time_datetime64(self): + dt = (datetime64("1970-01-01") - datetime64("1950-01-01")).astype("i8") + return (self.time - dt).astype("datetime64[D]") + def time_sub_sample(self, t0, time_step): """ Time sub sampling @@ -1809,10 +1877,9 @@ def extract_with_mask(self, mask): if nb_obs == 0: logger.warning("Empty dataset will be created") else: - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var][mask] + new.obs[field] = self.obs[field][mask] return new def scatter(self, ax, name=None, ref=None, factor=1, **kwargs): @@ -2010,7 +2077,43 @@ def bins_stat(self, xname, bins=None, yname=None, method=None, mask=None): def format_label(self, label): t0, t1 = self.period - return label.format(t0=t0, t1=t1, nb_obs=len(self),) + return label.format( + t0=t0, + t1=t1, + nb_obs=len(self), + ) + + def display_color(self, ax, field, ref=None, intern=False, **kwargs): + """Plot colored contour of eddies + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param str,array field: color field + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary + :param bool intern: if True, draw the speed contour + :param dict kwargs: look at :py:meth:`matplotlib.collections.LineCollection` + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.display_color
+        """
+        xname, yname = self.intern(intern)
+        x, y = self[xname], self[yname]
+
+        if ref is not None:
+            # TODO : maybe buggy with global display
+            shape_out = x.shape
+            x, y = wrap_longitude(x.reshape(-1), y.reshape(-1), ref)
+            x, y = x.reshape(shape_out), y.reshape(shape_out)
+
+        c = self.parse_varname(field)
+        cmap = get_cmap(kwargs.pop("cmap", "Spectral_r"))
+        cmin, cmax = kwargs.pop("vmin", c.min()), kwargs.pop("vmax", c.max())
+        colors = cmap((c - cmin) / (cmax - cmin))
+        lines = LineCollection(
+            [create_vertice(i, j) for i, j in zip(x, y)], colors=colors, **kwargs
+        )
+        ax.add_collection(lines)
+        lines.cmap = cmap
+        lines.norm = Normalize(vmin=cmin, vmax=cmax)
+        return lines
 
     def display(self, ax, ref=None, extern_only=False, intern_only=False, **kwargs):
         """Plot the speed and effective (dashed) contour of the eddies
@@ -2153,7 +2256,7 @@ def grid_count(self, bins, intern=False, center=False, filter=slice(None)):
             x_ref = ((self.longitude[filter] - x0) % 360 + x0 - 180).reshape(-1, 1)
             x_contour, y_contour = self[x_name][filter], self[y_name][filter]
             grid_count_pixel_in(
-                grid,
+                grid.data,
                 x_contour,
                 y_contour,
                 x_ref,
@@ -2256,7 +2359,7 @@ def grid_stat(self, bins, varname, data=None):
         return regular_grid
 
     def interp_grid(
-        self, grid_object, varname, method="center", dtype=None, intern=None
+        self, grid_object, varname, i=None, method="center", dtype=None, intern=None
     ):
         """
         Interpolate a grid on a center or contour with mean, min or max method
 
         :param grid_object: Handler of grid to interp
         :type grid_object: py_eddy_tracker.dataset.grid.RegularGridDataset
         :param str varname: Name of variable to use
+        :param array[bool,int],None i:
+            Index or mask used to subset observations; this avoids building a dedicated dataset.
         :param str method: 'center', 'mean', 'max', 'min', 'nearest'
         :param str dtype: if None we use var dtype
         :param bool intern: Use extern or intern contour
 
         .. 
minigallery:: py_eddy_tracker.EddiesObservations.interp_grid """ if method in ("center", "nearest"): - return grid_object.interp(varname, self.longitude, self.latitude, method) + x, y = self.longitude, self.latitude + if i is not None: + x, y = x[i], y[i] + return grid_object.interp(varname, x, y, method) elif method in ("min", "max", "mean", "count"): x0 = grid_object.x_bounds[0] x_name, y_name = self.intern(False if intern is None else intern) x_ref = ((self.longitude - x0) % 360 + x0 - 180).reshape(-1, 1) x, y = (self[x_name] - x_ref) % 360 + x_ref, self[y_name] + if i is not None: + x, y = x[i], y[i] grid = grid_object.grid(varname) - result = empty(self.shape, dtype=grid.dtype if dtype is None else dtype) + result = empty(x.shape[0], dtype=grid.dtype if dtype is None else dtype) min_method = method == "min" grid_stat( grid_object.x_c, grid_object.y_c, - -grid if min_method else grid, + -grid.data if min_method else grid.data, + grid.mask, x, y, result, @@ -2329,7 +2440,10 @@ def create_particles(self, step, intern=True): """ xname, yname = self.intern(intern) - return _create_meshed_particles(self[xname], self[yname], step) + return create_meshed_particles(self[xname], self[yname], step) + + def empty_dataset(self): + return self.new_like(self, 0) @njit(cache=True) @@ -2381,7 +2495,14 @@ def grid_count_pixel_in( x_, y_ = reduce_size(x_, y_) v = create_vertice(x_, y_) (x_start, x_stop), (y_start, y_stop) = bbox_indice_regular( - v, x_bounds, y_bounds, xstep, ystep, N, is_circular, x_size, + v, + x_bounds, + y_bounds, + xstep, + ystep, + N, + is_circular, + x_size, ) i, j = get_pixel_in_regular(v, x_c, y_c, x_start, x_stop, y_start, y_stop) grid_count_(grid, i, j) @@ -2440,13 +2561,14 @@ def grid_box_stat(x_c, y_c, grid, mask, x, y, value, circular=False, method=50): @njit(cache=True) -def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method="mean"): +def grid_stat(x_c, y_c, grid, mask, x, y, result, circular=False, method="mean"): """ Compute the mean or the max of the grid for each contour :param array_like x_c: the grid longitude coordinates :param array_like y_c: the grid latitude coordinates :param array_like grid: grid value + :param array[bool] mask: mask for invalid value :param array_like x: longitude of contours :param array_like y: latitude of contours :param array_like result: return values @@ -2472,9 +2594,12 @@ def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method="mean"): result[elt] = i.shape[0] elif mean_method: v_sum = 0 + nb_ = 0 for i_, j_ in zip(i, j): + if mask[i_, j_]: + continue v_sum += grid[i_, j_] - nb_ = i.shape[0] + nb_ += 1 # FIXME : how does it work on grid bound, if nb_ == 0: result[elt] = nan @@ -2483,28 +2608,12 @@ def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method="mean"): elif max_method: v_max = -1e40 for i_, j_ in zip(i, j): - v_max = max(v_max, grid[i_, j_]) + values = grid[i_, j_] + # FIXME must use mask + v_max = max(v_max, values) result[elt] = v_max -@njit(cache=True) -def _create_meshed_particles(lons, lats, step): - x_out, y_out, i_out = list(), list(), list() - for i, (lon, lat) in enumerate(zip(lons, lats)): - lon_min, lon_max = lon.min(), lon.max() - lat_min, lat_max = lat.min(), lat.max() - lon_min -= lon_min % step - lon_max -= lon_max % step - step * 2 - lat_min -= lat_min % step - lat_max -= lat_max % step - step * 2 - - for x in arange(lon_min, lon_max, step): - for y in arange(lat_min, lat_max, step): - if winding_number_poly(x, y, create_vertice(*reduce_size(lon, lat))): - x_out.append(x), 
y_out.append(y), i_out.append(i) - return array(x_out), array(y_out), array(i_out) - - class VirtualEddiesObservations(EddiesObservations): """Class to work with virtual obs""" diff --git a/src/py_eddy_tracker/observations/tracking.py b/src/py_eddy_tracker/observations/tracking.py index 6612c6d5..fa1c1f93 100644 --- a/src/py_eddy_tracker/observations/tracking.py +++ b/src/py_eddy_tracker/observations/tracking.py @@ -2,8 +2,8 @@ """ Class to manage observations gathered in trajectories """ -import logging from datetime import datetime, timedelta +import logging from numba import njit from numpy import ( @@ -68,6 +68,10 @@ def __init__(self, *args, **kwargs): self.__obs_by_track = None self.__nb_track = None + def track_slice(self, track): + i0 = self.index_from_track[track] + return slice(i0, i0 + self.nb_obs_by_track[track]) + def iter_track(self): """ Yield track @@ -114,7 +118,7 @@ def __repr__(self): t0, t1 = self.period period = t1 - t0 + 1 nb = self.nb_obs_by_track - nb_obs = self.observations.shape[0] + nb_obs = len(self) m = self.virtual.astype("bool") nb_m = m.sum() bins_t = (1, 30, 90, 180, 270, 365, 1000, 10000) @@ -143,7 +147,7 @@ def __repr__(self): def add_distance(self): """Add a field of distance (m) between two consecutive observations, 0 for the last observation of each track""" - if "distance_next" in self.observations.dtype.descr: + if "distance_next" in self.fields: return self new = self.add_fields(("distance_next",)) new["distance_next"][:1] = self.distance_to_next() @@ -179,16 +183,16 @@ def normalize_longitude(self): lon0 = (self.lon[self.index_from_track] - 180).repeat(self.nb_obs_by_track) logger.debug("Normalize longitude") self.lon[:] = (self.lon - lon0) % 360 + lon0 - if "lon_max" in self.obs.dtype.names: + if "lon_max" in self.fields: logger.debug("Normalize longitude_max") self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180 if not self.raw_data: - if "contour_lon_e" in self.obs.dtype.names: + if "contour_lon_e" in self.fields: logger.debug("Normalize effective contour longitude") self.contour_lon_e[:] = ( (self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180 ).T - if "contour_lon_s" in self.obs.dtype.names: + if "contour_lon_s" in self.fields: logger.debug("Normalize speed contour longitude") self.contour_lon_s[:] = ( (self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180 @@ -201,10 +205,9 @@ def extract_longer_eddies(self, nb_min, nb_obs, compress_id=True): logger.info("Selection of %d observations", nb_obs_select) eddies = self.__class__.new_like(self, nb_obs_select) eddies.sign_type = self.sign_type - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - eddies.obs[var] = self.obs[var][mask] + eddies.obs[field] = self.obs[field][mask] if compress_id: list_id = unique(eddies.obs.track) list_id.sort() @@ -377,19 +380,17 @@ def extract_toward_direction(self, west=True, delta_lon=None): d_lon = lon[i1] - lon[i0] m = d_lon < 0 if west else d_lon > 0 if delta_lon is not None: - m *= delta_lon < d_lon + m *= delta_lon < abs(d_lon) m = m.repeat(nb) return self.extract_with_mask(m) def extract_first_obs_in_box(self, res): - data = empty( - self.obs.shape, dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")] - ) + data = empty(len(self), dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")]) data["lon"] = self.longitude - self.longitude % res data["lat"] = self.latitude - self.latitude % res data["track"] = self.track _, indexs = unique(data, 
return_index=True) - mask = zeros(self.obs.shape, dtype="bool") + mask = zeros(len(self), dtype="bool") mask[indexs] = True return self.extract_with_mask(mask) @@ -434,9 +435,6 @@ def extract_with_length(self, bounds): raise Exception("One bound must be positive") return self.extract_with_mask(track_mask.repeat(self.nb_obs_by_track)) - def empty_dataset(self): - return self.new_like(self, 0) - def loess_filter(self, half_window, xfield, yfield, inplace=True): track = self.track x = self.obs[xfield] @@ -504,10 +502,9 @@ def extract_with_mask( if nb_obs == 0: logger.info("Empty dataset will be created") else: - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var][mask] + new.obs[field] = self.obs[field][mask] if compress_id: list_id = unique(new.track) list_id.sort() @@ -582,7 +579,10 @@ def close_tracks(self, other, nb_obs_min=10, **kwargs): def format_label(self, label): t0, t1 = self.period return label.format( - t0=t0, t1=t1, nb_obs=len(self), nb_tracks=(self.nb_obs_by_track != 0).sum(), + t0=t0, + t1=t1, + nb_obs=len(self), + nb_tracks=(self.nb_obs_by_track != 0).sum(), ) def plot(self, ax, ref=None, **kwargs): @@ -715,7 +715,7 @@ def get_previous_obs( time_ref, window, min_overlap=0.2, - minimal_area=False, + **kwargs, ): """Backward association of observations to the segments""" time_cur = int_(ids["time"][i_current]) @@ -732,10 +732,8 @@ def get_previous_obs( continue c = zeros(len(xj)) c[ij] = vertice_overlap( - xi[ii], yi[ii], xj[ij], yj[ij], minimal_area=minimal_area + xi[ii], yi[ii], xj[ij], yj[ij], min_overlap=min_overlap, **kwargs ) - # We remove low overlap - c[c < min_overlap] = 0 # We get index of maximal overlap i = c.argmax() c_i = c[i] @@ -757,7 +755,7 @@ def get_next_obs( time_ref, window, min_overlap=0.2, - minimal_area=False, + **kwargs, ): """Forward association of observations to the segments""" time_max = time_e.shape[0] - 1 @@ -777,10 +775,8 @@ def get_next_obs( continue c = zeros(len(xj)) c[ij] = vertice_overlap( - xi[ii], yi[ii], xj[ij], yj[ij], minimal_area=minimal_area + xi[ii], yi[ii], xj[ij], yj[ij], min_overlap=min_overlap, **kwargs ) - # We remove low overlap - c[c < min_overlap] = 0 # We get index of maximal overlap i = c.argmax() c_i = c[i] diff --git a/src/py_eddy_tracker/poly.py b/src/py_eddy_tracker/poly.py index abe8becb..b5849610 100644 --- a/src/py_eddy_tracker/poly.py +++ b/src/py_eddy_tracker/poly.py @@ -5,11 +5,10 @@ import heapq -from numba import njit, prange -from numba import types as numba_types +from Polygon import Polygon +from numba import njit, prange, types as numba_types from numpy import arctan, array, concatenate, empty, nan, ones, pi, where, zeros from numpy.linalg import lstsq -from Polygon import Polygon from .generic import build_index @@ -279,7 +278,10 @@ def close_center(x0, y0, x1, y1, delta=0.1): for i0 in range(nb0): xi0, yi0 = x0[i0], y0[i0] for i1 in range(nb1): - if abs(x1[i1] - xi0) > delta: + d_x = x1[i1] - xi0 + if abs(d_x) > 180: + d_x = (d_x + 180) % 360 - 180 + if abs(d_x) > delta: continue if abs(y1[i1] - yi0) > delta: continue @@ -287,6 +289,27 @@ def close_center(x0, y0, x1, y1, delta=0.1): return array(i), array(j), array(c) +@njit(cache=True) +def create_meshed_particles(lons, lats, step): + x_out, y_out, i_out = list(), list(), list() + nb = lons.shape[0] + for i in range(nb): + lon, lat = lons[i], lats[i] + vertice = create_vertice(*reduce_size(lon, lat)) + lon_min, lon_max = lon.min(), lon.max() + lat_min, 
lat_max = lat.min(), lat.max()
+        y0 = lat_min - lat_min % step
+        x = lon_min - lon_min % step
+        while x <= lon_max:
+            y = y0
+            while y <= lat_max:
+                if winding_number_poly(x, y, vertice):
+                    x_out.append(x), y_out.append(y), i_out.append(i)
+                y += step
+            x += step
+    return array(x_out), array(y_out), array(i_out)
+
+
 @njit(cache=True, fastmath=True)
 def bbox_intersection(x0, y0, x1, y1):
     """
@@ -411,7 +434,9 @@ def merge(x, y):
     return concatenate(x), concatenate(y)
 
 
-def vertice_overlap(x0, y0, x1, y1, minimal_area=False):
+def vertice_overlap(
+    x0, y0, x1, y1, minimal_area=False, p1_area=False, hybrid_area=False, min_overlap=0
+):
     r"""
     Return percent of overlap for each item.
 
     :param array x0: x for polygon list 0
     :param array x1: x for polygon list 1
     :param array y1: y for polygon list 1
     :param bool minimal_area: If True, function will compute intersection/little polygon, else intersection/union
+    :param bool p1_area: If True, the function computes intersection/p1 area instead of intersection/union
+    :param bool hybrid_area: If True, cost is computed as for the union, but an observation whose cost
+        falls below min_overlap is still kept when one polygon is almost fully included in the other
+    :param float min_overlap: costs below this value are set to zero
     :return: Result of cost function
     :rtype: array
 
@@ -430,6 +459,10 @@
     If minimal area:
 
     .. math:: Score = \frac{Intersection(P_0,P_1)_{area}}{min(P_{0 area},P_{1 area})}
+
+    If P1 area:
+
+    .. math:: Score = \frac{Intersection(P_0,P_1)_{area}}{P_{1 area}}
     """
     nb = x0.shape[0]
     cost = empty(nb)
@@ -441,11 +474,29 @@
         # Area of intersection
         intersection = (p0 & p1).area()
         # we divide intersection with the little one result from 0 to 1
+        if intersection == 0:
+            cost[i] = 0
+            continue
+        p0_area_, p1_area_ = p0.area(), p1.area()
         if minimal_area:
-            cost[i] = intersection / min(p0.area(), p1.area())
+            cost_ = intersection / min(p0_area_, p1_area_)
+        # we divide intersection by p1 area
+        elif p1_area:
+            cost_ = intersection / p1_area_
         # we divide intersection with polygon merging result from 0 to 1
         else:
-            cost[i] = intersection / (p0 + p1).area()
+            cost_ = intersection / (p0_area_ + p1_area_ - intersection)
+        if cost_ >= min_overlap:
+            cost[i] = cost_
+        else:
+            if (
+                hybrid_area
+                and cost_ != 0
+                and (intersection / min(p0_area_, p1_area_)) > 0.99
+            ):
+                cost[i] = cost_
+            else:
+                cost[i] = 0
     return cost
@@ -495,7 +546,7 @@ def fit_circle(x, y):
     norme = (x[1:] - x_mean) ** 2 + (y[1:] - y_mean) ** 2
     norme_max = norme.max()
-    scale = norme_max ** 0.5
+    scale = norme_max**0.5
 
     # Form matrix equation and solve it
     # Maybe put f4
@@ -506,7 +557,7 @@
     (x0, y0, radius), _, _, _ = lstsq(datas, norme / norme_max)
 
     # Unscale data and get circle variables
-    radius += x0 ** 2 + y0 ** 2
+    radius += x0**2 + y0**2
     radius **= 0.5
     x0 *= scale
     y0 *= scale
@@ -538,21 +589,21 @@ def fit_ellipse(x, y):
     """
     nb = x.shape[0]
     datas = ones((nb, 5), dtype=x.dtype)
-    datas[:, 0] = x ** 2
+    datas[:, 0] = x**2
     datas[:, 1] = x * y
-    datas[:, 2] = y ** 2
+    datas[:, 2] = y**2
     datas[:, 3] = x
     datas[:, 4] = y
     (a, b, c, d, e), _, _, _ = lstsq(datas, ones(nb, dtype=x.dtype))
-    det = b ** 2 - 4 * a * c
+    det = b**2 - 4 * a * c
     if det > 0:
         print(det)
     x0 = (2 * c * d - b * e) / det
     y0 = (2 * a * e - b * d) / det
-    AB1 = 2 * (a * e ** 2 + c * d ** 2 - b * d * e - det)
+    AB1 = 2 * (a * e**2 + c * d**2 - b * d * e - det)
     AB2 = a + c
-    AB3 = ((a - c) ** 2 + b ** 2) ** 0.5
+    AB3 = 
((a - c) ** 2 + b**2) ** 0.5 A = -((AB1 * (AB2 + AB3)) ** 0.5) / det B = -((AB1 * (AB2 - AB3)) ** 0.5) / det theta = arctan((c - a - AB3) / b) @@ -613,7 +664,7 @@ def fit_circle_(x, y): # Linear regression (a, b, c), _, _, _ = lstsq(datas, x[1:] ** 2 + y[1:] ** 2) x0, y0 = a / 2.0, b / 2.0 - radius = (c + x0 ** 2 + y0 ** 2) ** 0.5 + radius = (c + x0**2 + y0**2) ** 0.5 err = shape_error(x, y, x0, y0, radius) return x0, y0, radius, err @@ -638,14 +689,14 @@ def shape_error(x, y, x0, y0, r): :rtype: float """ # circle area - c_area = (r ** 2) * pi + c_area = (r**2) * pi p_area = poly_area(x, y) nb = x.shape[0] x, y = x.copy(), y.copy() # Find distance between circle center and polygon for i in range(nb): dx, dy = x[i] - x0, y[i] - y0 - rd = r / (dx ** 2 + dy ** 2) ** 0.5 + rd = r / (dx**2 + dy**2) ** 0.5 if rd < 1: x[i] = x0 + dx * rd y[i] = y0 + dy * rd diff --git a/src/py_eddy_tracker/tracking.py b/src/py_eddy_tracker/tracking.py index 7543a4d3..b64b6fcc 100644 --- a/src/py_eddy_tracker/tracking.py +++ b/src/py_eddy_tracker/tracking.py @@ -2,15 +2,14 @@ """ Class to store link between observations """ - +from datetime import datetime, timedelta import json import logging import platform -from datetime import datetime, timedelta +from tarfile import ExFileObject from netCDF4 import Dataset, default_fillvals -from numba import njit -from numba import types as numba_types +from numba import njit, types as numba_types from numpy import ( arange, array, @@ -376,7 +375,10 @@ def track(self): # We begin with second file, first one is in previous for file_name in self.datasets[first_dataset:]: self.swap_dataset(file_name, **kwargs) - logger.info("%s match with previous state", file_name) + filename_ = ( + file_name.filename if isinstance(file_name, ExFileObject) else file_name + ) + logger.info("%s match with previous state", filename_) logger.debug("%d obs to match", len(self.current_obs)) nb_real_obs = len(self.previous_obs) @@ -410,14 +412,14 @@ def to_netcdf(self, handler): logger.debug('Create Dimensions "Nstep" : %d', nb_step) handler.createDimension("Nstep", nb_step) var_file_in = handler.createVariable( - zlib=True, + zlib=False, complevel=1, varname="FileIn", datatype="S1024", dimensions="Nstep", ) var_file_out = handler.createVariable( - zlib=True, + zlib=False, complevel=1, varname="FileOut", datatype="S1024", @@ -659,7 +661,7 @@ def merge(self, until=-1, raw_data=True): # Set type of eddy with first file eddies.sign_type = self.current_obs.sign_type # Fields to copy - fields = self.current_obs.obs.dtype.names + fields = self.current_obs.fields # To know if the track start first_obs_save_in_tracks = zeros(self.i_current_by_tracks.shape, dtype=bool_) @@ -708,7 +710,7 @@ def merge(self, until=-1, raw_data=True): # Index in the current file index_current = self[i]["out"] - if "cost_association" in eddies.obs.dtype.names: + if "cost_association" in eddies.fields: eddies["cost_association"][index_final - 1] = self[i]["cost_value"] # Copy all variable for field in fields: diff --git a/src/scripts/EddyTranslate b/src/scripts/EddyTranslate index 26ab3a7b..a0060e9b 100644 --- a/src/scripts/EddyTranslate +++ b/src/scripts/EddyTranslate @@ -3,8 +3,8 @@ """ Translate eddy Dataset """ -import zarr from netCDF4 import Dataset +import zarr from py_eddy_tracker import EddyParser from py_eddy_tracker.observations.observation import EddiesObservations diff --git a/tests/test_generic.py b/tests/test_generic.py index ab3832cc..ee2d7881 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,6 
+1,6 @@ from numpy import arange, array, nan, ones, zeros -from py_eddy_tracker.generic import cumsum_by_track, simplify +from py_eddy_tracker.generic import cumsum_by_track, simplify, wrap_longitude def test_simplify(): @@ -30,3 +30,22 @@ def test_cumsum_by_track(): a = ones(10, dtype="i4") * 2 track = array([1, 1, 2, 2, 2, 2, 44, 44, 44, 48]) assert (cumsum_by_track(a, track) == [2, 4, 2, 4, 6, 8, 2, 4, 6, 2]).all() + + +def test_wrapping(): + y = x = arange(-5, 5, dtype="f4") + x_, _ = wrap_longitude(x, y, ref=-10) + assert (x_ == x).all() + x_, _ = wrap_longitude(x, y, ref=1) + assert x.size == x_.size + assert (x_[6:] == x[6:]).all() + assert (x_[:6] == x[:6] + 360).all() + x_, _ = wrap_longitude(x, y, ref=1, cut=True) + assert x.size + 3 == x_.size + assert (x_[6 + 3 :] == x[6:]).all() + assert (x_[:7] == x[:7] + 360).all() + + # FIXME Need evolution in wrap_longitude + # x %= 360 + # x_, _ = wrap_longitude(x, y, ref=-10, cut=True) + # assert x.size == x_.size diff --git a/tests/test_grid.py b/tests/test_grid.py index 759a40e1..0e6dd586 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -7,7 +7,15 @@ G = RegularGridDataset(get_demo_path("mask_1_60.nc"), "lon", "lat") X = 0.025 -contour = Path(((-X, 0), (X, 0), (X, X), (-X, X), (-X, 0),)) +contour = Path( + ( + (-X, 0), + (X, 0), + (X, X), + (-X, X), + (-X, 0), + ) +) # contour @@ -91,7 +99,11 @@ def test_convolution(): ) g = RegularGridDataset.with_array( coordinates=("x", "y"), - datas=dict(z=z, x=arange(0, 6, 0.5), y=arange(0, 5, 0.5),), + datas=dict( + z=z, + x=arange(0, 6, 0.5), + y=arange(0, 5, 0.5), + ), centered=True, ) diff --git a/tests/test_poly.py b/tests/test_poly.py index cca53635..a780f64d 100644 --- a/tests/test_poly.py +++ b/tests/test_poly.py @@ -22,7 +22,7 @@ def test_fit_circle(): x0, y0, r, err = fit_circle(*V) assert x0 == approx(2.5, rel=1e-10) assert y0 == approx(-9.5, rel=1e-10) - assert r == approx(2 ** 0.5 / 2, rel=1e-10) + assert r == approx(2**0.5 / 2, rel=1e-10) assert err == approx((1 - 2 / pi) * 100, rel=1e-10) diff --git a/tests/test_track.py b/tests/test_track.py index 4f362a26..f7e83786 100644 --- a/tests/test_track.py +++ b/tests/test_track.py @@ -1,5 +1,5 @@ -import zarr from netCDF4 import Dataset +import zarr from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker diff --git a/versioneer.py b/versioneer.py index 2b545405..1e3753e6 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,5 @@ -# Version: 0.18 + +# Version: 0.29 """The Versioneer - like a rocketeer, but for versions. @@ -6,18 +7,14 @@ ============== * like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer +* https://github.com/python-versioneer/python-versioneer * Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based +* License: Public Domain (Unlicense) +* Compatible with: Python 3.7, 3.8, 3.9, 3.10, 3.11 and pypy3 +* [![Latest Version][pypi-image]][pypi-url] +* [![Build Status][travis-image]][travis-url] + +This is a tool for managing a recorded version number in setuptools-based python projects. 
The goal is to remove the tedious and error-prone "update the embedded version string" step from your release process. Making a new release should be as easy as recording a new tag in your version-control @@ -26,9 +23,38 @@ ## Quick Install -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results +Versioneer provides two installation modes. The "classic" vendored mode installs +a copy of versioneer into your repository. The experimental build-time dependency mode +is intended to allow you to skip this step and simplify the process of upgrading. + +### Vendored mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) + * Note that you will need to add `tomli; python_version < "3.11"` to your + build-time dependencies if you use `pyproject.toml` +* run `versioneer install --vendor` in your source tree, commit the results +* verify version information with `python setup.py version` + +### Build-time dependency mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) +* add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`) + to the `requires` key of the `build-system` table in `pyproject.toml`: + ```toml + [build-system] + requires = ["setuptools", "versioneer[toml]"] + build-backend = "setuptools.build_meta" + ``` +* run `versioneer install --no-vendor` in your source tree, commit the results +* verify version information with `python setup.py version` ## Version Identifiers @@ -60,7 +86,7 @@ for example `git describe --tags --dirty --always` reports things like "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. +uncommitted changes). The version identifier is used for multiple purposes: @@ -165,7 +191,7 @@ Some situations are known to cause problems for Versioneer. This details the most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). +[issues page](https://github.com/python-versioneer/python-versioneer/issues). ### Subprojects @@ -179,7 +205,7 @@ `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI distributions (and upload multiple independently-installable tarballs). * Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. + provide bindings to Python (and perhaps other languages) in subdirectories. Versioneer will look for `.git` in parent directories, and most operations should get the right version string. However `pip` and `setuptools` have bugs @@ -193,9 +219,9 @@ Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in some later version. 
-[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking +[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the +[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the issue from the Versioneer side in more detail. [pip PR#3176](https://github.com/pypa/pip/pull/3176) and [pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve @@ -223,31 +249,20 @@ cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into a different virtualenv), so this can be surprising. -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes +[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes this one, but upgrading to a newer version of setuptools should probably resolve it. -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - ## Updating Versioneer To upgrade your project to a new release of Versioneer, do the following: * install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace +* edit `setup.cfg` and `pyproject.toml`, if necessary, + to include any new configuration settings indicated by the release notes. + See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install --[no-]vendor` in your source tree, to replace `SRC/_version.py` * commit any changed files @@ -264,36 +279,70 @@ direction and include code from all supported VCS systems, reducing the number of intermediate scripts. +## Similar projects + +* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time + dependency +* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of + versioneer +* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools + plugin ## License To make Versioneer easier to embed, all its code is dedicated to the public domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . +Specifically, both are released under the "Unlicense", as described in +https://unlicense.org/. 
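The vendored workflow described earlier in this docstring ends with the host
project's `setup.py` delegating to this module; a minimal sketch (the `name`
value is illustrative, and py-eddy-tracker's own `setup.py` is not part of
this diff):

```python
# setup.py (sketch): wire the vendored versioneer into setuptools
import versioneer
from setuptools import setup

setup(
    name="pyEddyTracker",                # illustrative
    version=versioneer.get_version(),    # version string derived from git tags
    cmdclass=versioneer.get_cmdclass(),  # commands that freeze _version.py into builds
)
```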
-""" +[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg +[pypi-url]: https://pypi.python.org/pypi/versioneer/ +[travis-image]: +https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg +[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer -from __future__ import print_function +""" +# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring +# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements +# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error +# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with +# pylint:disable=attribute-defined-outside-init,too-many-arguments -try: - import configparser -except ImportError: - import ConfigParser as configparser +import configparser import errno import json import os import re import subprocess import sys +from pathlib import Path +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import NoReturn +import functools + +have_tomllib = True +if sys.version_info >= (3, 11): + import tomllib +else: + try: + import tomli as tomllib + except ImportError: + have_tomllib = False class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + versionfile_source: str + versionfile_build: Optional[str] + parentdir_prefix: Optional[str] + verbose: Optional[bool] -def get_root(): + +def get_root() -> str: """Get the project root directory. We require that all commands are run from the project root, i.e. the @@ -301,20 +350,28 @@ def get_root(): """ root = os.path.realpath(os.path.abspath(os.getcwd())) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): # allow 'python path/to/setup.py COMMAND' root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ( - "Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND')." - ) + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): + err = ("Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND').") raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -323,46 +380,62 @@ def get_root(): # module-import table will cache the first one. So we can't use # os.path.dirname(__file__), as that will find whichever # versioneer.py was first imported, even in later projects. 
- me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) + my_path = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print( - "Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py) - ) + if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): + print("Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py)) except NameError: pass return root -def get_config_from_root(root): +def get_config_from_root(root: str) -> VersioneerConfig: """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or + # This might raise OSError (if setup.cfg is missing), or # configparser.NoSectionError (if it lacks a [versioneer] section), or # configparser.NoOptionError (if it lacks "VCS="). See the docstring at # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None + root_pth = Path(root) + pyproject_toml = root_pth / "pyproject.toml" + setup_cfg = root_pth / "setup.cfg" + section: Union[Dict[str, Any], configparser.SectionProxy, None] = None + if pyproject_toml.exists() and have_tomllib: + try: + with open(pyproject_toml, 'rb') as fobj: + pp = tomllib.load(fobj) + section = pp['tool']['versioneer'] + except (tomllib.TOMLDecodeError, KeyError) as e: + print(f"Failed to load config from {pyproject_toml}: {e}") + print("Try to load it from setup.cfg") + if not section: + parser = configparser.ConfigParser() + with open(setup_cfg) as cfg_file: + parser.read_file(cfg_file) + parser.get("versioneer", "VCS") # raise error if missing + + section = parser["versioneer"] + + # `cast`` really shouldn't be used, but its simplest for the + # common VersioneerConfig users at the moment. 
We verify against + # `None` values elsewhere where it matters cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): + cfg.VCS = section['VCS'] + cfg.style = section.get("style", "") + cfg.versionfile_source = cast(str, section.get("versionfile_source")) + cfg.versionfile_build = section.get("versionfile_build") + cfg.tag_prefix = cast(str, section.get("tag_prefix")) + if cfg.tag_prefix in ("''", '""', None): cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") + cfg.parentdir_prefix = section.get("parentdir_prefix") + if isinstance(section, configparser.SectionProxy): + # Make sure configparser translates to bool + cfg.verbose = section.getboolean("verbose") + else: + cfg.verbose = section.get("verbose") + return cfg @@ -371,41 +444,48 @@ class NotThisMethod(Exception): # these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f + HANDLERS.setdefault(vcs, {})[method] = f return f - return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -416,28 +496,25 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, 
p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -LONG_VERSION_PY[ - "git" -] = ''' +LONG_VERSION_PY['git'] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -446,9 +523,11 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -464,8 +543,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool + -def get_config(): +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -483,13 +569,13 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} @@ -498,22 +584,35 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + 
stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -524,18 +623,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, if verbose: print("unable to find command, tried %%s" %% (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %%s (error)" %% dispcmd) print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -544,15 +645,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print("Tried directories %%s but none started with prefix %%s" %% @@ -561,41 +661,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
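# Illustrative sketch (hypothetical values): once `git archive` has expanded
# the export-subst keywords, the three variables this function greps for look
# roughly like
#     git_refnames = " (HEAD -> master, tag: v3.6.2)"
#     git_full = "574ab98cca7c1d1d6f647c31329d298165b512a0"
#     git_date = "2025-06-06 10:00:00 +0200"
# and the returned dict exposes them under the keys "refnames", "full" and "date".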
- keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -608,11 +715,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %%d @@ -621,7 +728,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%%s', no digits" %% ",".join(refs - tags)) if verbose: @@ -630,6 +737,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %%s" %% r) return {"version": r, @@ -645,7 +757,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -656,8 +773,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %%s not under git control" %% root) @@ -665,24 +789,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -699,7 +856,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%%s'" %% describe_out) return pieces @@ -724,26 +881,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() + date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -768,23 +926,71 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). 
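    Illustratively (hypothetical inputs; note the implementation below returns
    None, not -1, when no post-release segment is present):

        pep440_split_post("1.2.3.post2")  # -> ("1.2.3", 2)
        pep440_split_post("1.2.3.post")   # -> ("1.2.3", 0)
        pep440_split_post("1.2.3")        # -> ("1.2.3", None)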
+ """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%%d" %% (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] + rendered = "0.post0.dev%%d" %% pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -811,12 +1017,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -833,7 +1068,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -853,7 +1088,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
@@ -873,7 +1108,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -887,10 +1122,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -905,7 +1144,7 @@ def render(pieces, style): "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -926,7 +1165,7 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: return {"version": "0+unknown", "full-revisionid": None, @@ -953,41 +1192,48 @@ def get_versions(): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. 
Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -1000,11 +1246,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1013,7 +1259,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1021,30 +1267,33 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -1055,7 +1304,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1063,33 +1320,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -1098,16 +1379,17 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] + git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) return pieces # tag @@ -1116,12 +1398,10 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] + pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1132,19 +1412,20 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def do_vcs_install(manifest_in, versionfile_source, ipy): +def do_vcs_install(versionfile_source: str, ipy: Optional[str]) -> None: """Git-specific installation logic for Versioneer. For Git, this means creating/changing .gitattributes to mark _version.py @@ -1153,36 +1434,40 @@ def do_vcs_install(manifest_in, versionfile_source, ipy): GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] + files = [versionfile_source] if ipy: files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) + if "VERSIONEER_PEP518" not in globals(): + try: + my_path = __file__ + if my_path.endswith((".pyc", ".pyo")): + my_path = os.path.splitext(my_path)[0] + ".py" + versioneer_file = os.path.relpath(my_path) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) present = False try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: + with open(".gitattributes", "r") as fobj: + for line in fobj: + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + break + except OSError: pass if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() + with open(".gitattributes", "a+") as fobj: + fobj.write(f"{versionfile_source} export-subst\n") files.append(".gitattributes") run_command(GITS, ["add", "--"] + files) -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from 
the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -1191,30 +1476,23 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from +# This file was generated by 'versioneer.py' (0.29) from # revision-control system data, or from the parent directory name of an # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. @@ -1231,43 +1509,41 @@ def get_versions(): """ -def versions_from_file(filename): +def versions_from_file(filename: str) -> Dict[str, Any]: """Try to determine the version from _version.py if present.""" try: with open(filename) as f: contents = f.read() - except EnvironmentError: + except OSError: raise NotThisMethod("unable to read _version.py") - mo = re.search( - r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) + mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) if not mo: - mo = re.search( - r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) + mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) -def write_to_version_file(filename, versions): +def write_to_version_file(filename: str, versions: Dict[str, Any]) -> None: """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, + indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) print("set %s to '%s'" % (filename, versions["version"])) -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -1285,29 +1561,78 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. 
+def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -1334,12 +1659,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -1356,7 +1710,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. 
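The reworked pre-release style is the subtlest change in this upgrade: when the closest tag itself already ends in `.postN`, `render_pep440_pre` now bumps `N` instead of appending a second post segment, via the new `pep440_split_post` helper (which, despite the `-1` mentioned in its docstring, actually returns `None` when no post segment is present). A sketch with made-up inputs:

```python
pep440_split_post("1.2")        # -> ("1.2", None)  no post-release segment
pep440_split_post("1.2.post3")  # -> ("1.2", 3)
pep440_split_post("1.2.post")   # -> ("1.2", 0)     bare ".post" counts as 0

# With closest-tag "1.2.post3" and distance 5, render_pep440_pre now
# yields "1.2.post4.dev5"; the old code produced "1.2.post3.post.dev5",
# which has two post segments and is not a valid PEP 440 version.
# With distance 0, the tag is returned unchanged.
```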
@@ -1376,7 +1730,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -1396,26 +1750,28 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -1425,20 +1781,16 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} class VersioneerBadRootError(Exception): """The project root directory is unknown or missing key files.""" -def get_versions(verbose=False): +def get_versions(verbose: bool = False) -> Dict[str, Any]: """Get the project version from whatever source is available. Returns dict with two keys: 'version' and 'full'. @@ -1453,10 +1805,9 @@ def get_versions(verbose=False): assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert ( - cfg.versionfile_source is not None - ), "please set versioneer.versionfile_source" + verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` + assert cfg.versionfile_source is not None, \ + "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1510,22 +1861,22 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, "error": "unable to compute version", + "date": None} -def get_version(): +def get_version() -> str: """Get the short version string for this project.""" return get_versions()["version"] -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" +def get_cmdclass(cmdclass: Optional[Dict[str, Any]] = None): + """Get the custom setuptools subclasses used by Versioneer. + + If the package uses a different cmdclass (e.g. one from numpy), it + should be provide as an argument. 
+ """ if "versioneer" in sys.modules: del sys.modules["versioneer"] # this fixes the "python setup.py develop" case (also 'install' and @@ -1539,25 +1890,25 @@ def get_cmdclass(): # parent is protected against the child's "import versioneer". By # removing ourselves from sys.modules here, before the child build # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 + # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - cmds = {} + cmds = {} if cmdclass is None else cmdclass.copy() - # we add "version" to both distutils and setuptools - from distutils.core import Command + # we add "version" to setuptools + from setuptools import Command class cmd_version(Command): description = "report generated version string" - user_options = [] - boolean_options = [] + user_options: List[Tuple[str, str, str]] = [] + boolean_options: List[str] = [] - def initialize_options(self): + def initialize_options(self) -> None: pass - def finalize_options(self): + def finalize_options(self) -> None: pass - def run(self): + def run(self) -> None: vers = get_versions(verbose=True) print("Version: %s" % vers["version"]) print(" full-revisionid: %s" % vers.get("full-revisionid")) @@ -1565,10 +1916,9 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - # we override "build_py" in both distutils and setuptools + # we override "build_py" in setuptools # # most invocation pathways end up running build_py: # distutils/build -> build_py @@ -1583,30 +1933,68 @@ def run(self): # then does setup.py bdist_wheel, or sometimes setup.py install # setup.py egg_info -> ? + # pip install -e . and setuptool/editable_wheel will invoke build_py + # but the build_py command is not expected to copy any files. + # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py + if 'build_py' in cmds: + _build_py: Any = cmds['build_py'] else: - from distutils.command.build_py import build_py as _build_py + from setuptools.command.build_py import build_py as _build_py class cmd_build_py(_build_py): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() _build_py.run(self) + if getattr(self, "editable_mode", False): + # During editable installs `.py` and data files are + # not copied to build_lib + return # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe + if 'build_ext' in cmds: + _build_ext: Any = cmds['build_ext'] + else: + from setuptools.command.build_ext import build_ext as _build_ext + + class cmd_build_ext(_build_ext): + def run(self) -> None: + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_ext.run(self) + if self.inplace: + # build_ext --inplace will only build extensions in + # build/lib<..> dir with no _version.py to write to. 
+ # As in place builds will already have a _version.py + # in the module dir, we do not need to write one. + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if not cfg.versionfile_build: + return + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + if not os.path.exists(target_versionfile): + print(f"Warning: {target_versionfile} does not exist, skipping " + "version update. This can happen if you are running build_ext " + "without first running build_py.") + return + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_ext"] = cmd_build_ext + if "cx_Freeze" in sys.modules: # cx_freeze enabled? + from cx_Freeze.dist import build_exe as _build_exe # type: ignore # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1615,7 +2003,7 @@ def run(self): # ... class cmd_build_exe(_build_exe): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() @@ -1627,28 +2015,24 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if "py2exe" in sys.modules: # py2exe enabled? + if 'py2exe' in sys.modules: # py2exe enabled? 
try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 + from py2exe.setuptools_buildexe import py2exe as _py2exe # type: ignore except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 + from py2exe.distutils_buildexe import py2exe as _py2exe # type: ignore class cmd_py2exe(_py2exe): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() @@ -1660,27 +2044,60 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) cmds["py2exe"] = cmd_py2exe + # sdist farms its file list building out to egg_info + if 'egg_info' in cmds: + _egg_info: Any = cmds['egg_info'] + else: + from setuptools.command.egg_info import egg_info as _egg_info + + class cmd_egg_info(_egg_info): + def find_sources(self) -> None: + # egg_info.find_sources builds the manifest list and writes it + # in one shot + super().find_sources() + + # Modify the filelist and normalize it + root = get_root() + cfg = get_config_from_root(root) + self.filelist.append('versioneer.py') + if cfg.versionfile_source: + # There are rare cases where versionfile_source might not be + # included by default, so we must be explicit + self.filelist.append(cfg.versionfile_source) + self.filelist.sort() + self.filelist.remove_duplicates() + + # The write method is hidden in the manifest_maker instance that + # generated the filelist and was thrown away + # We will instead replicate their final normalization (to unicode, + # and POSIX-style paths) + from setuptools import unicode_utils + normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/') + for f in self.filelist.files] + + manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt') + with open(manifest_filename, 'w') as fobj: + fobj.write('\n'.join(normalized)) + + cmds['egg_info'] = cmd_egg_info + # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist + if 'sdist' in cmds: + _sdist: Any = cmds['sdist'] else: - from distutils.command.sdist import sdist as _sdist + from setuptools.command.sdist import sdist as _sdist class cmd_sdist(_sdist): - def run(self): + def run(self) -> None: versions = get_versions() self._versioneer_generated_versions = versions # unless we update this, the command will keep using the old @@ -1688,7 +2105,7 @@ def run(self): self.distribution.metadata.version = versions["version"] return _sdist.run(self) - def make_release_tree(self, base_dir, files): + def make_release_tree(self, base_dir: str, files: List[str]) -> None: root = get_root() cfg = get_config_from_root(root) _sdist.make_release_tree(self, base_dir, files) @@ -1697,10 +2114,8 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file( - target_versionfile, self._versioneer_generated_versions - ) - + write_to_version_file(target_versionfile, + self._versioneer_generated_versions) cmds["sdist"] = cmd_sdist return cmds @@ -1743,25 
+2158,28 @@ def make_release_tree(self, base_dir, files): """ -INIT_PY_SNIPPET = """ +OLD_SNIPPET = """ from ._version import get_versions __version__ = get_versions()['version'] del get_versions """ +INIT_PY_SNIPPET = """ +from . import {0} +__version__ = {0}.get_versions()['version'] +""" + -def do_setup(): - """Main VCS-independent setup function for installing Versioneer.""" +def do_setup() -> int: + """Do main VCS-independent setup function for installing Versioneer.""" root = get_root() try: cfg = get_config_from_root(root) - except ( - EnvironmentError, - configparser.NoSectionError, - configparser.NoOptionError, - ) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", file=sys.stderr) + except (OSError, configparser.NoSectionError, + configparser.NoOptionError) as e: + if isinstance(e, (OSError, configparser.NoSectionError)): + print("Adding sample versioneer config to setup.cfg", + file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -1770,76 +2188,46 @@ def do_setup(): print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") + f.write(LONG % {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), + "__init__.py") + maybe_ipy: Optional[str] = ipy if os.path.exists(ipy): try: with open(ipy, "r") as f: old = f.read() - except EnvironmentError: + except OSError: old = "" - if INIT_PY_SNIPPET not in old: + module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] + snippet = INIT_PY_SNIPPET.format(module) + if OLD_SNIPPET in old: + print(" replacing boilerplate in %s" % ipy) + with open(ipy, "w") as f: + f.write(old.replace(OLD_SNIPPET, snippet)) + elif snippet not in old: print(" appending to %s" % ipy) with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) + f.write(snippet) else: print(" %s unmodified" % ipy) else: print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except EnvironmentError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. 
- if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print( - " appending versionfile_source ('%s') to MANIFEST.in" - % cfg.versionfile_source - ) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") + maybe_ipy = None # Make VCS-specific changes. For git, this means creating/changing # .gitattributes to mark _version.py for export-subst keyword # substitution. - do_vcs_install(manifest_in, cfg.versionfile_source, ipy) + do_vcs_install(cfg.versionfile_source, maybe_ipy) return 0 -def scan_setup_py(): +def scan_setup_py() -> int: """Validate the contents of setup.py against Versioneer's expectations.""" found = set() setters = False @@ -1876,10 +2264,14 @@ def scan_setup_py(): return errors +def setup_command() -> NoReturn: + """Set up Versioneer and exit with appropriate error code.""" + errors = do_setup() + errors += scan_setup_py() + sys.exit(1 if errors else 0) + + if __name__ == "__main__": cmd = sys.argv[1] if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1) + setup_command()