diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md new file mode 100644 index 00000000..5d77d4ca --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -0,0 +1,37 @@ +--- +name: Bug report +about: Report a bug or issue with py_eddy_tracker. +# This template is based on the matplotlib template +--- + + + +### Bug report + +**Bug summary** + + + +**Code for reproduction** + + + +```python +# Paste your code here +# +# +``` + +**Actual outcome** + +``` +# If applicable, paste the console output here +# +# +``` + +**Expected outcome** + + diff --git a/.github/ISSUE_TEMPLATE/need_documentation.md b/.github/ISSUE_TEMPLATE/need_documentation.md new file mode 100644 index 00000000..cd46ed63 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/need_documentation.md @@ -0,0 +1,27 @@ +--- +name: Documentation improvement +about: Create a report to help us improve the documentation +labels: documentation +# This template is based on the matplotlib template +--- + +### Problem + + + + +### Suggested Improvement + + \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/questions.md b/.github/ISSUE_TEMPLATE/questions.md new file mode 100644 index 00000000..ae684376 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/questions.md @@ -0,0 +1,13 @@ +--- +name: Questions +about: If you have a usage question +# This template is based on the matplotlib template +--- + + diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 00000000..d9437d16 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,70 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ master ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ master ] + schedule: + - cron: '41 16 * * 4' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Learn more about CodeQL language support at https://git.io/codeql-language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹ️ Command-line programs to run using the OS shell. 
+ # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 7c800b33..f2f4753e 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -1,29 +1,33 @@ # This workflow will install Python dependencies, run tests and lint with a single version of Python # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: Python application +name: Pytest & Flake8 -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] +on: [push, pull_request] jobs: build: - - runs-on: ubuntu-latest + strategy: + matrix: + # os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, windows-latest] + python_version: ['3.10', '3.11', '3.12'] + name: Run py eddy tracker build tests + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -l {0} steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python ${{ matrix.python_version }} uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: ${{ matrix.python_version }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest + pip install flake8 pytest pytest-cov if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Install package run: | @@ -34,6 +38,3 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - pytest diff --git a/.gitignore b/.gitignore index d5d00735..fc3e1bf3 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,8 @@ instance/ # Sphinx documentation docs/_build/ +doc/gen_modules/ +doc/_autosummary/ # PyBuilder target/ diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 00000000..ba36f8ea --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,13 @@ +version: 2 +conda: + environment: doc/environment.yml +build: + os: ubuntu-lts-latest + tools: + python: "mambaforge-latest" +python: + install: + - method: pip + path: . +sphinx: + configuration: doc/conf.py \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst new file mode 100644 index 00000000..6d6d6a30 --- /dev/null +++ b/CHANGELOG.rst @@ -0,0 +1,147 @@ +Changelog +========= + +All notable changes to this project will be documented in this file. + +The format is based on `Keep a Changelog `_ +and this project adheres to `Semantic Versioning `_. + +[Unreleased] +------------- +Changed +^^^^^^^ + +Fixed +^^^^^ + +Added +^^^^^ + +[3.6.2] - 2025-06-06 +-------------------- +Changed +^^^^^^^ + +- Remove dead end method for network will move dead end to the trash and not remove observations + +Fixed +^^^^^ + +- Fix matplotlib version + +[3.6.1] - 2022-10-14 +-------------------- +Changed +^^^^^^^ + +- Rewrite particle candidate to be easily parallelize + +Fixed +^^^^^ + +- Check strictly increasing coordinates for RegularGridDataset. 
+- Grid mask is checked to replace a single-value mask by a 2D mask with fixed value
+
+Added
+^^^^^
+
+- Add method to colorize contours with a field
+- Add option to force align on to return all steps for the reference dataset
+- Add method and property to network to easily select segments and networks
+- Add method to find the same track/segment/network in a dataset
+
+[3.6.0] - 2022-01-12
+--------------------
+Changed
+^^^^^^^
+
+- Now time allows second precision (instead of daily precision) in storage on uint32 from 01/01/1950 to 01/01/2086.
+  New identifications are produced with this type, old files can still be loaded.
+  If you use old identifications for tracking, use the `--unraw` option to unpack old times and store data with the new format.
+- Now amplitude is stored with 0.1 mm precision (instead of 1 mm), same advice as for time.
+- Expose more parameters to users for the bash tools build_network & divide_network
+- Add warning when loading a file created with a previous version of py-eddy-tracker.
+
+Fixed
+^^^^^
+
+- Fix bug in convolution (filter): the lowest rows were replaced by zeros in the convolution computation.
+  Important impact for tiny kernels
+- Fix method of sampling before contour fitting
+- Fix bug when loading a dataset in zarr format, not all variables were correctly loaded
+- Fix bug when a zarr dataset has the same size for number of observations and contour size
+- Fix bug when tracking, previous_virtual_obs was not always loaded
+
+Added
+^^^^^
+
+- Allow replacing the mask by an isnan method to manage NaN data instead of masked data
+- Add drifter colocation example
+
+[3.5.0] - 2021-06-22
+--------------------
+
+Fixed
+^^^^^
+- GridCollection get_next_time_step & get_previous_time_step needed more files to work in the dataset list.
+  The loop explicitly needed self.dataset[i+-1] even when i==0, therefore the index went out of range
+
+[3.4.0] - 2021-03-29
+--------------------
+Changed
+^^^^^^^
+- `TrackEddiesObservations.filled_by_interpolation` method stops normalizing longitude; to keep the same
+  behaviour you must call `TrackEddiesObservations.normalize_longitude` beforehand
+
+Fixed
+^^^^^
+- Use `safe_load` for yaml load
+- Fix repr of EddiesObservation when the collection is empty (time attribute is an empty array)
+- display_timeline and event_timeline can now use colors according to 'y' values.
+- event_timeline now plots all merging events in one plot, instead of one plot per merging. Same for splitting.
(avoid bad legend) + +Added +^^^^^ +- Identification file could be load in memory before to be read with netcdf library to get speed up in case of slow disk +- Add a filter option in EddyId to be able to remove fine scale (like noise) with same filter order than high scale + filter +- Add **EddyQuickCompare** to have few figures about several datasets in comparison based on match function +- Color and text field for contour in **EddyAnim** could be choose +- Save EddyAnim in mp4 +- Add method to get eddy contour which enclosed obs defined with (x,y) coordinates +- Add **EddyNetworkSubSetter** to subset network which need special tool and operation after subset +- Network: + - Add method to find relatives segments + - Add method to get cloase network in an other atlas +- Management of time cube data for advection + +[3.3.0] - 2020-12-03 +-------------------- +Added +^^^^^ +- Add an implementation of visvalingam algorithm to simplify polygons with low modification +- Add method to found close tracks in an other atlas +- Allow to give a x reference when we display grid to be able to change xlim +- Add option to EddyId to select data index like `--indexs time=5 depth=2` +- Add a method to merge several indexs type for eddy obs +- Get dataset variable like attribute, and lifetime/age are available for all observations +- Add **EddyInfos** application to get general information about eddies dataset +- Add method to inspect contour rejection (which are not in eddies) +- Grid interp could be "nearest" or "bilinear" + +Changed +^^^^^^^ +- Now to have object informations in plot label used python ```format``` style, several key are available : + + - "t0" + - "t1" + - "nb_obs" + - "nb_tracks" (only for tracked eddies) + +[3.2.0] - 2020-09-16 +-------------------- + +[3.1.0] - 2020-06-25 +-------------------- diff --git a/README.md b/README.md index 71db6750..0cc34894 100644 --- a/README.md +++ b/README.md @@ -1,38 +1,65 @@ -[![PyPI version](https://badge.fury.io/py/pyEddyTracker.svg)](https://badge.fury.io/py/pyEddyTracker) [![Documentation Status](https://readthedocs.org/projects/py-eddy-tracker/badge/?version=stable)](https://py-eddy-tracker.readthedocs.io/en/stable/?badge=stable) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/AntSimi/py-eddy-tracker/master?urlpath=lab/tree/notebooks/python_module/) +[![PyPI version](https://badge.fury.io/py/pyEddyTracker.svg)](https://badge.fury.io/py/pyEddyTracker) +[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.6333988.svg)](https://doi.org/10.5281/zenodo.6333988) +[![Documentation Status](https://readthedocs.org/projects/py-eddy-tracker/badge/?version=stable)](https://py-eddy-tracker.readthedocs.io/en/stable/?badge=stable) +[![Gitter](https://badges.gitter.im/py-eddy-tracker/community.svg)](https://gitter.im/py-eddy-tracker/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) +[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/AntSimi/py-eddy-tracker/master?urlpath=lab/tree/notebooks/python_module/) +[![pytest](https://github.com/AntSimi/py-eddy-tracker/actions/workflows/python-app.yml/badge.svg)](https://github.com/AntSimi/py-eddy-tracker/actions/workflows/python-app.yml) # README # +### How to cite code? 
### + +Zenodo provide DOI for each tagged version, [all DOI are available here](https://doi.org/10.5281/zenodo.6333988) + ### Method ### Method was described in : +[Pegliasco, C., Delepoulle, A., Morrow, R., Faugère, Y., and Dibarboure, G.: META3.1exp : A new Global Mesoscale Eddy Trajectories Atlas derived from altimetry, Earth Syst. Sci. Data Discuss.](https://doi.org/10.5194/essd-14-1087-2022) + [Mason, E., A. Pascual, and J. C. McWilliams, 2014: A new sea surface height–based code for oceanic mesoscale eddy tracking.](https://doi.org/10.1175/JTECH-D-14-00019.1) ### Use case ### Method is used in : - + [Mason, E., A. Pascual, P. Gaube, S.Ruiz, J. Pelegrí, A. Delepoulle, 2017: Subregional characterization of mesoscale eddies across the Brazil-Malvinas Confluence](https://doi.org/10.1002/2016JC012611) ### How do I get set up? ### +#### Short story #### + +```bash +pip install pyeddytracker +``` + +#### Long story #### + To avoid problems with installation, use of the virtualenv Python virtual environment is recommended. -Then use pip to install all dependencies (numpy, scipy, matplotlib, netCDF4, pyproj, ...), e.g.: +Then use pip to install all dependencies (numpy, scipy, matplotlib, netCDF4, ...), e.g.: ```bash -pip install numpy scipy netCDF4 matplotlib opencv-python pyyaml pyproj pint polygon3 +pip install numpy scipy netCDF4 matplotlib opencv-python pyyaml pint polygon3 ``` -Then run the following to install the eddy tracker: +Clone : + +```bash +git clone https://github.com/AntSimi/py-eddy-tracker +``` + +Then run the following to install the eddy tracker : ```bash python setup.py install ``` + ### Tools gallery ### + Several examples based on py eddy tracker module are [here](https://py-eddy-tracker.readthedocs.io/en/latest/python_module/index.html). -![](https://py-eddy-tracker.readthedocs.io/en/latest/_static/logo.png) +[![](https://py-eddy-tracker.readthedocs.io/en/latest/_static/logo.png)](https://py-eddy-tracker.readthedocs.io/en/latest/python_module/index.html) ### Quick use ### diff --git a/apt.txt b/apt.txt new file mode 100644 index 00000000..a72c3b87 --- /dev/null +++ b/apt.txt @@ -0,0 +1 @@ +libgl1-mesa-glx \ No newline at end of file diff --git a/check.sh b/check.sh new file mode 100644 index 00000000..a402bf52 --- /dev/null +++ b/check.sh @@ -0,0 +1,5 @@ +isort . +black . +blackdoc . +flake8 . +python -m pytest -vv --cov py_eddy_tracker --cov-report html diff --git a/doc/.templates/custom-class-template.rst b/doc/.templates/custom-class-template.rst new file mode 100644 index 00000000..ad7dfcb5 --- /dev/null +++ b/doc/.templates/custom-class-template.rst @@ -0,0 +1,33 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :undoc-members: + :show-inheritance: + + {% block methods %} + {% if methods %} + .. rubric:: Methods + + .. autosummary:: + :nosignatures: + {% for item in methods %} + {%- if not item.startswith('_') %} + ~{{ name }}.{{ item }} + {%- endif -%} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block attributes %} + {% if attributes %} + .. rubric:: Attributes + + .. autosummary:: + {% for item in attributes %} + ~{{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/doc/.templates/custom-module-template.rst b/doc/.templates/custom-module-template.rst new file mode 100644 index 00000000..b8786a2d --- /dev/null +++ b/doc/.templates/custom-module-template.rst @@ -0,0 +1,66 @@ +{{ fullname | escape | underline}} + +.. 
automodule:: {{ fullname }} + + {% block attributes %} + {% if attributes %} + .. rubric:: Module attributes + + .. autosummary:: + :toctree: + {% for item in attributes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: Functions + + .. autosummary:: + :toctree: + :nosignatures: + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: Classes + + .. autosummary:: + :toctree: + :template: custom-class-template.rst + :nosignatures: + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: Exceptions + + .. autosummary:: + :toctree: + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + +{% block modules %} +{% if modules %} +.. autosummary:: + :toctree: + :template: custom-module-template.rst + :recursive: +{% for item in modules %} + {{ item }} +{%- endfor %} +{% endif %} +{% endblock %} diff --git a/doc/api.rst b/doc/api.rst new file mode 100644 index 00000000..866704f8 --- /dev/null +++ b/doc/api.rst @@ -0,0 +1,22 @@ +============= +API reference +============= + + +.. autosummary:: + :toctree: _autosummary + :template: custom-module-template.rst + :recursive: + + py_eddy_tracker.appli + py_eddy_tracker.dataset.grid + py_eddy_tracker.featured_tracking + py_eddy_tracker.observations.network + py_eddy_tracker.observations.observation + py_eddy_tracker.observations.tracking + py_eddy_tracker.observations.groups + py_eddy_tracker.eddy_feature + py_eddy_tracker.generic + py_eddy_tracker.gui + py_eddy_tracker.poly + py_eddy_tracker.tracking diff --git a/doc/autodoc/eddy_feature.rst b/doc/autodoc/eddy_feature.rst deleted file mode 100644 index dec7ef87..00000000 --- a/doc/autodoc/eddy_feature.rst +++ /dev/null @@ -1,8 +0,0 @@ -Eddy Features -============= - -.. automodule:: py_eddy_tracker.eddy_feature - :members: - :undoc-members: - :show-inheritance: - diff --git a/doc/autodoc/featured_tracking.rst b/doc/autodoc/featured_tracking.rst deleted file mode 100644 index d9e8b51e..00000000 --- a/doc/autodoc/featured_tracking.rst +++ /dev/null @@ -1,7 +0,0 @@ -Featured tracking -================= - -.. automodule:: py_eddy_tracker.featured_tracking.old_tracker_reference - :members: - :undoc-members: - :show-inheritance: \ No newline at end of file diff --git a/doc/autodoc/grid.rst b/doc/autodoc/grid.rst deleted file mode 100644 index e915edf3..00000000 --- a/doc/autodoc/grid.rst +++ /dev/null @@ -1,8 +0,0 @@ -Grid -==== - -.. automodule:: py_eddy_tracker.dataset.grid - :members: - :undoc-members: - :show-inheritance: - diff --git a/doc/autodoc/observations.rst b/doc/autodoc/observations.rst deleted file mode 100644 index 9b54f3a3..00000000 --- a/doc/autodoc/observations.rst +++ /dev/null @@ -1,13 +0,0 @@ -Observations -============ - -.. automodule:: py_eddy_tracker.observations.observation - :members: - :undoc-members: - :show-inheritance: - -.. automodule:: py_eddy_tracker.observations.tracking - :members: - :undoc-members: - :show-inheritance: - diff --git a/doc/autodoc/poly.rst b/doc/autodoc/poly.rst deleted file mode 100644 index fa6d5964..00000000 --- a/doc/autodoc/poly.rst +++ /dev/null @@ -1,8 +0,0 @@ -Polygon function -================ - -.. 
automodule:: py_eddy_tracker.poly - :members: - :undoc-members: - :show-inheritance: - diff --git a/doc/changelog.rst b/doc/changelog.rst new file mode 100644 index 00000000..4d7817ae --- /dev/null +++ b/doc/changelog.rst @@ -0,0 +1 @@ +.. include:: ../CHANGELOG.rst \ No newline at end of file diff --git a/doc/conf.py b/doc/conf.py index b14a8941..0844d585 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -15,7 +15,7 @@ # import sys # import os import sphinx_rtd_theme - +import py_eddy_tracker # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -33,29 +33,54 @@ extensions = [ "sphinx.ext.autodoc", "sphinx.ext.doctest", + "sphinx.ext.autosummary", "sphinx.ext.intersphinx", "sphinx.ext.viewcode", "sphinx_gallery.gen_gallery", + "matplotlib.sphinxext.plot_directive", ] +# autodoc conf +autoclass_content = "both" + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + "numpy": ("https://numpy.org/doc/stable/", None), + "python": ("https://docs.python.org/3/", None), + "matplotlib": ("https://matplotlib.org/", None), +} + + sphinx_gallery_conf = { "examples_dirs": "../examples", # path to your example scripts "gallery_dirs": "python_module", - "capture_repr": ("_repr_html_",), + "capture_repr": ("_repr_html_", "__repr__"), + "backreferences_dir": "gen_modules/backreferences", + "doc_module": ("py_eddy_tracker",), + "reference_url": { + "py_eddy_tracker": None, + }, "line_numbers": False, "filename_pattern": "/pet", + "matplotlib_animations": True, "binder": { # Required keys "org": "AntSimi", "repo": "py-eddy-tracker", "branch": "master", "binderhub_url": "https://mybinder.org", - "dependencies": ["../requirements.txt"], + "dependencies": ["environment.yml"], # Optional keys "use_jupyter_lab": True, }, } +# matplotlib conf +plot_include_source = True + +# Active autosummary +autosummary_generate = True + # Add any paths that contain templates here, relative to this directory. templates_path = [".templates"] @@ -71,18 +96,18 @@ master_doc = "index" # General information about the project. -project = u"py-eddy-tracker" -copyright = u"2019, A. Delepoulle & E. Mason" -author = u"A. Delepoulle & E. Mason" +project = "py-eddy-tracker" +copyright = "2019, A. Delepoulle & E. Mason" +author = "A. Delepoulle & E. Mason" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = u"3.0" +version = py_eddy_tracker.__version__ # The full version, including alpha/beta/rc tags. -release = u"3.0" +release = py_eddy_tracker.__version__ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -247,8 +272,8 @@ ( master_doc, "py-eddy-tracker.tex", - u"py-eddy-tracker Documentation", - u"A. Delepoulle \\& E. Mason", + "py-eddy-tracker Documentation", + "A. Delepoulle \\& E. Mason", "manual", ), ] @@ -279,7 +304,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, "py-eddy-tracker", u"py-eddy-tracker Documentation", [author], 1) + (master_doc, "py-eddy-tracker", "py-eddy-tracker Documentation", [author], 1) ] # If true, show URL addresses after external links. 
@@ -295,7 +320,7 @@ ( master_doc, "py-eddy-tracker", - u"py-eddy-tracker Documentation", + "py-eddy-tracker Documentation", author, "py-eddy-tracker", "One line description of project.", @@ -314,7 +339,3 @@ # If true, do not generate a @detailmenu in the "Top" node's menu. # texinfo_no_detailmenu = False - - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {"https://docs.python.org/": None} diff --git a/doc/custom_tracking.rst b/doc/custom_tracking.rst index f75ad2c5..f72f3e72 100644 --- a/doc/custom_tracking.rst +++ b/doc/custom_tracking.rst @@ -6,7 +6,6 @@ Customize tracking Code my own tracking ******************** -To use your own tracking method, you just need to create a class which inherit +To use your own tracking method, you just need to create a class which inherits from :meth:`py_eddy_tracker.observations.observation.EddiesObservations` and set this class in yaml file like we see in the previous topic. - diff --git a/doc/environment.yml b/doc/environment.yml new file mode 100644 index 00000000..063a60de --- /dev/null +++ b/doc/environment.yml @@ -0,0 +1,13 @@ +channels: + - conda-forge +dependencies: + - python>=3.10, <3.13 + - ffmpeg + - pip + - pip: + - -r ../requirements.txt + - git+https://github.com/AntSimi/py-eddy-tracker-sample-id.git + - sphinx-gallery + - sphinx_rtd_theme + - sphinx>=3.1 + - pyeddytrackersample diff --git a/doc/grid_identification.rst b/doc/grid_identification.rst index 313f26ab..2cc3fb52 100644 --- a/doc/grid_identification.rst +++ b/doc/grid_identification.rst @@ -8,7 +8,7 @@ Run the identification process for a single day Shell/bash command ****************** -Bash command will allow to process one grid, it will apply a filter and an identification. +Bash command will allow you to process one grid, it will apply a filter and an identification. .. code-block:: bash @@ -18,55 +18,71 @@ Bash command will allow to process one grid, it will apply a filter and an ident out_directory -v DEBUG -Filter could be modify with options *--cut_wavelength* and *--filter_order*. You could also defined height between two isolines with *--isoline_step*, which could +Filter could be modified with options *--cut_wavelength* and *--filter_order*. You could also define height between two isolines with *--isoline_step*, which could improve speed profile quality and detect accurately tiny eddies. You could also use *--fit_errmax* to manage acceptable shape of eddies. An eddy identification will produce two files in the output directory, one for anticyclonic eddies and the other one for cyclonic. -In regional area which are away from the equator, current could be deduce from height, juste write *None None* inplace of *ugos vgos* +In regional areas which are away from the Equator, current could be deduced from height, just write *None None* in place of *ugos vgos* + +In case of **datacube**, you need to specify index for each layer (time, depth, ...) with *--indexs* option like: + +.. code-block:: bash + + EddyId share/nrt_global_allsat_phy_l4_20190223_20190226.nc 20190223 \ + adt ugos vgos longitude latitude \ + out_directory -v DEBUG --indexs time=0 + +.. warning:: + If no index are specified, you will apply identification only on dataset first layer, which could be + a problem for datacube. Date set in command is used only for output storage. Python code *********** -If we want customize eddies identification, python module is here. +If we want to customize eddies identification, the Python module is here. 
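+
+The snippets below assume the grid class, matplotlib and the demo grid are already set up; a minimal
+sketch (same imports and demo file as used elsewhere in this documentation):
+
+.. code-block:: python
+
+    from matplotlib import pyplot as plt
+
+    from py_eddy_tracker.dataset.grid import RegularGridDataset
+
+    # Demo grid also used by the EddyId example above
+    grid_name, lon_name, lat_name = (
+        "share/nrt_global_allsat_phy_l4_20190223_20190226.nc",
+        "longitude",
+        "latitude",
+    )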
Activate verbose .. code-block:: python from py_eddy_tracker import start_logger - start_logger().setLevel('DEBUG') # Available options: ERROR, WARNING, INFO, DEBUG + + start_logger().setLevel("DEBUG") # Available options: ERROR, WARNING, INFO, DEBUG Run identification .. code-block:: python from datetime import datetime + h = RegularGridDataset(grid_name, lon_name, lat_name) - h.bessel_high_filter('adt', 500, order=3) + h.bessel_high_filter("adt", 500, order=3) date = datetime(2019, 2, 23) a, c = h.eddy_identification( - 'adt', 'ugos', 'vgos', # Variables used for identification - date, # Date of identification - 0.002, # step between two isolines of detection (m) - pixel_limit=(5, 2000), # Min and max pixel count for valid contour - shape_error=55, # Error max (%) between ratio of circle fit and contour - ) + "adt", + "ugos", + "vgos", # Variables used for identification + date, # Date of identification + 0.002, # step between two isolines of detection (m) + pixel_limit=(5, 2000), # Min and max pixel count for valid contour + shape_error=55, # Error max (%) between ratio of circle fit and contour + ) Plot the resulting identification .. code-block:: python - fig = plt.figure(figsize=(15,7)) - ax = fig.add_axes([.03,.03,.94,.94]) - ax.set_title('Eddies detected -- Cyclonic(red) and Anticyclonic(blue)') - ax.set_ylim(-75,75) - ax.set_xlim(0,360) - ax.set_aspect('equal') - a.display(ax, color='b', linewidth=.5) - c.display(ax, color='r', linewidth=.5) + fig = plt.figure(figsize=(15, 7)) + ax = fig.add_axes([0.03, 0.03, 0.94, 0.94]) + ax.set_title("Eddies detected -- Cyclonic(red) and Anticyclonic(blue)") + ax.set_ylim(-75, 75) + ax.set_xlim(0, 360) + ax.set_aspect("equal") + a.display(ax, color="b", linewidth=0.5) + c.display(ax, color="r", linewidth=0.5) ax.grid() - fig.savefig('share/png/eddies.png') + fig.savefig("share/png/eddies.png") .. image:: ../share/png/eddies.png @@ -75,7 +91,8 @@ Save identification data .. code-block:: python from netCDF import Dataset - with Dataset(date.strftime('share/Anticyclonic_%Y%m%d.nc'), 'w') as h: + + with Dataset(date.strftime("share/Anticyclonic_%Y%m%d.nc"), "w") as h: a.to_netcdf(h) - with Dataset(date.strftime('share/Cyclonic_%Y%m%d.nc'), 'w') as h: + with Dataset(date.strftime("share/Cyclonic_%Y%m%d.nc"), "w") as h: c.to_netcdf(h) diff --git a/doc/grid_load_display.rst b/doc/grid_load_display.rst index 2e570274..2f0e3765 100644 --- a/doc/grid_load_display.rst +++ b/doc/grid_load_display.rst @@ -7,7 +7,12 @@ Loading grid .. code-block:: python from py_eddy_tracker.dataset.grid import RegularGridDataset - grid_name, lon_name, lat_name = 'share/nrt_global_allsat_phy_l4_20190223_20190226.nc', 'longitude', 'latitude' + + grid_name, lon_name, lat_name = ( + "share/nrt_global_allsat_phy_l4_20190223_20190226.nc", + "longitude", + "latitude", + ) h = RegularGridDataset(grid_name, lon_name, lat_name) Plotting grid @@ -15,14 +20,15 @@ Plotting grid .. code-block:: python from matplotlib import pyplot as plt + fig = plt.figure(figsize=(14, 12)) - ax = fig.add_axes([.02, .51, .9, .45]) - ax.set_title('ADT (m)') + ax = fig.add_axes([0.02, 0.51, 0.9, 0.45]) + ax.set_title("ADT (m)") ax.set_ylim(-75, 75) - ax.set_aspect('equal') - m = h.display(ax, name='adt', vmin=-1, vmax=1) + ax.set_aspect("equal") + m = h.display(ax, name="adt", vmin=-1, vmax=1) ax.grid(True) - plt.colorbar(m, cax=fig.add_axes([.94, .51, .01, .45])) + plt.colorbar(m, cax=fig.add_axes([0.94, 0.51, 0.01, 0.45])) Filtering @@ -30,27 +36,27 @@ Filtering .. 
code-block:: python h = RegularGridDataset(grid_name, lon_name, lat_name) - h.bessel_high_filter('adt', 500, order=3) + h.bessel_high_filter("adt", 500, order=3) Save grid .. code-block:: python - h.write('/tmp/grid.nc') + h.write("/tmp/grid.nc") Add second plot .. code-block:: python - ax = fig.add_axes([.02, .02, .9, .45]) - ax.set_title('ADT Filtered (m)') - ax.set_aspect('equal') + ax = fig.add_axes([0.02, 0.02, 0.9, 0.45]) + ax.set_title("ADT Filtered (m)") + ax.set_aspect("equal") ax.set_ylim(-75, 75) - m = h.display(ax, name='adt', vmin=-.1, vmax=.1) + m = h.display(ax, name="adt", vmin=-0.1, vmax=0.1) ax.grid(True) - plt.colorbar(m, cax=fig.add_axes([.94, .02, .01, .45])) - fig.savefig('share/png/filter.png') + plt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.45])) + fig.savefig("share/png/filter.png") .. image:: ../share/png/filter.png \ No newline at end of file diff --git a/doc/index.rst b/doc/index.rst index 4e6cfa46..36b6d8c3 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -35,14 +35,12 @@ Welcome to py-eddy-tracker's documentation! custom_tracking .. toctree:: - :maxdepth: 2 - :caption: code + :maxdepth: 1 + :caption: Code - autodoc/grid - autodoc/observations - autodoc/eddy_feature - autodoc/featured_tracking - autodoc/poly + api + Source (Git) + changelog Indices and tables diff --git a/doc/installation.rst b/doc/installation.rst index 7e56f063..b2bcb45c 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -2,16 +2,22 @@ How do I get set up ? ===================== -Source are available on github https://github.com/AntSimi/py-eddy-tracker +You could install stable version with pip + +.. code-block:: bash + + pip install pyEddyTracker + +Or with source which are available on github https://github.com/AntSimi/py-eddy-tracker Use python3. To avoid problems with installation, use of the virtualenv Python virtual environment is recommended or conda. -Then use pip to install all dependencies (numpy, scipy, matplotlib, netCDF4, pyproj, ...), e.g.: +Then use pip to install all dependencies (numpy, scipy, matplotlib, netCDF4, ...), e.g.: .. code-block:: bash - pip install numpy scipy netCDF4 matplotlib opencv-python pyyaml pyproj pint polygon3 + pip install numpy scipy netCDF4 matplotlib opencv-python pyyaml pint polygon3 Then run the following to install the eddy tracker: diff --git a/doc/run_tracking.rst b/doc/run_tracking.rst index 17621729..36290339 100644 --- a/doc/run_tracking.rst +++ b/doc/run_tracking.rst @@ -2,24 +2,31 @@ Tracking ======== +Requirements +************ + +Before tracking, you will need to run identification on every time step of the period (period of your study). + +**Advice** : Before tracking, displaying some identification files. You will learn a lot Default method ************** To run a tracking just create an yaml file with minimal specification (*FILES_PATTERN* and *SAVE_DIR*). +You will run tracking separately between Cyclonic eddies and Anticyclonic eddies. -Example of yaml +Example of conf.yaml .. 
code-block:: yaml PATHS: # Files produces with EddyIdentification - FILES_PATTERN: MY/IDENTIFICATION_PATH/Anticyclonic*.nc + FILES_PATTERN: MY_IDENTIFICATION_PATH/Anticyclonic*.nc SAVE_DIR: MY_OUTPUT_PATH - # Number of timestep for missing detection + # Number of consecutive timesteps with missing detection allowed VIRTUAL_LENGTH_MAX: 3 - # Minimal time to consider as a full track + # Minimal number of timesteps to considered as a long trajectory TRACK_DURATION_MIN: 10 To run: @@ -28,20 +35,25 @@ To run: EddyTracking conf.yaml -v DEBUG -It will use default tracker: +It will use the default tracker: -- No travel longer than 125 km between two observation -- Amplitude and speed radius must be close to previous observation -- In case of several candidate only closest is kept +- No travel longer than 125 km between two observations +- Amplitude and speed radius must be close to the previous observation +- In case of several candidates only the closest is kept It will produce 4 files by run: -- A file of correspondances which will contains all the information to merge all identifications file -- A file which will contains all the observations which are alone -- A file which will contains all the short track which are shorter than TRACK_DURATION_MIN -- A file which will contains all the long track which are longer than TRACK_DURATION_MIN +- A file of correspondences which will contain all the information to merge all identifications file +- A file which will contain all the observations which are alone +- A file which will contain all the short tracks which are shorter than **TRACK_DURATION_MIN** +- A file which will contain all the long tracks which are longer than **TRACK_DURATION_MIN** +Use Python module +***************** + +An example of tracking with the Python module is available in the gallery: +:ref:`sphx_glr_python_module_08_tracking_manipulation_pet_run_a_tracking.py` Choose a tracker **************** @@ -51,13 +63,13 @@ With yaml you could also select another tracker: .. code-block:: yaml PATHS: - # Files produces with EddyIdentification + # Files produced with EddyIdentification FILES_PATTERN: MY/IDENTIFICATION_PATH/Anticyclonic*.nc SAVE_DIR: MY_OUTPUT_PATH - # Number of timestep for missing detection + # Number of consecutive timesteps with missing detection allowed VIRTUAL_LENGTH_MAX: 3 - # Minimal time to consider as a full track + # Minimal number of timesteps to considered as a long trajectory TRACK_DURATION_MIN: 10 CLASS: @@ -68,5 +80,5 @@ With yaml you could also select another tracker: # py_eddy_tracker.observations.observation.EddiesObservations CLASS: CheltonTracker -This tracker is like described in CHELTON11[https://doi.org/10.1016/j.pocean.2011.01.002]. +This tracker is like the one described in CHELTON11[https://doi.org/10.1016/j.pocean.2011.01.002]. Code is here :meth:`py_eddy_tracker.featured_tracking.old_tracker_reference` diff --git a/doc/spectrum.rst b/doc/spectrum.rst index d751b909..f96e30a0 100644 --- a/doc/spectrum.rst +++ b/doc/spectrum.rst @@ -11,7 +11,7 @@ Load data raw = RegularGridDataset(grid_name, lon_name, lat_name) filtered = RegularGridDataset(grid_name, lon_name, lat_name) - filtered.bessel_low_filter('adt', 150, order=3) + filtered.bessel_low_filter("adt", 150, order=3) areas = dict( sud_pacific=dict(llcrnrlon=188, urcrnrlon=280, llcrnrlat=-64, urcrnrlat=-7), @@ -23,24 +23,33 @@ Compute and display spectrum .. 
code-block:: python - fig = plt.figure(figsize=(10,6)) + fig = plt.figure(figsize=(10, 6)) ax = fig.add_subplot(111) - ax.set_title('Spectrum') - ax.set_xlabel('km') + ax.set_title("Spectrum") + ax.set_xlabel("km") for name_area, area in areas.items(): - - lon_spec, lat_spec = raw.spectrum_lonlat('adt', area=area) - mappable = ax.loglog(*lat_spec, label='lat %s raw' % name_area)[0] - ax.loglog(*lon_spec, label='lon %s raw' % name_area, color=mappable.get_color(), linestyle='--') - - lon_spec, lat_spec = filtered.spectrum_lonlat('adt', area=area) - mappable = ax.loglog(*lat_spec, label='lat %s high' % name_area)[0] - ax.loglog(*lon_spec, label='lon %s high' % name_area, color=mappable.get_color(), linestyle='--') - - ax.set_xscale('log') + lon_spec, lat_spec = raw.spectrum_lonlat("adt", area=area) + mappable = ax.loglog(*lat_spec, label="lat %s raw" % name_area)[0] + ax.loglog( + *lon_spec, + label="lon %s raw" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + lon_spec, lat_spec = filtered.spectrum_lonlat("adt", area=area) + mappable = ax.loglog(*lat_spec, label="lat %s high" % name_area)[0] + ax.loglog( + *lon_spec, + label="lon %s high" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + ax.set_xscale("log") ax.legend() ax.grid() - fig.savefig('share/png/spectrum.png') + fig.savefig("share/png/spectrum.png") .. image:: ../share/png/spectrum.png @@ -49,18 +58,23 @@ Compute and display spectrum ratio .. code-block:: python - fig = plt.figure(figsize=(10,6)) + fig = plt.figure(figsize=(10, 6)) ax = fig.add_subplot(111) - ax.set_title('Spectrum ratio') - ax.set_xlabel('km') + ax.set_title("Spectrum ratio") + ax.set_xlabel("km") for name_area, area in areas.items(): - lon_spec, lat_spec = filtered.spectrum_lonlat('adt', area=area, ref=raw) - mappable = ax.plot(*lat_spec, label='lat %s high' % name_area)[0] - ax.plot(*lon_spec, label='lon %s high' % name_area, color=mappable.get_color(), linestyle='--') - - ax.set_xscale('log') + lon_spec, lat_spec = filtered.spectrum_lonlat("adt", area=area, ref=raw) + mappable = ax.plot(*lat_spec, label="lat %s high" % name_area)[0] + ax.plot( + *lon_spec, + label="lon %s high" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + ax.set_xscale("log") ax.legend() ax.grid() - fig.savefig('share/png/spectrum_ratio.png') + fig.savefig("share/png/spectrum_ratio.png") .. image:: ../share/png/spectrum_ratio.png diff --git a/environment.yml b/environment.yml new file mode 100644 index 00000000..e94c7bc1 --- /dev/null +++ b/environment.yml @@ -0,0 +1,11 @@ +name: binder-pyeddytracker +channels: + - conda-forge +dependencies: + - python>=3.10, <3.13 + - pip + - ffmpeg + - pip: + - -r requirements.txt + - pyeddytrackersample + - . diff --git a/examples/01_general_things/README.rst b/examples/01_general_things/README.rst new file mode 100644 index 00000000..5876c1b6 --- /dev/null +++ b/examples/01_general_things/README.rst @@ -0,0 +1,3 @@ +General features +================ + diff --git a/examples/01_general_things/pet_storage.py b/examples/01_general_things/pet_storage.py new file mode 100644 index 00000000..918ebbee --- /dev/null +++ b/examples/01_general_things/pet_storage.py @@ -0,0 +1,158 @@ +""" +How data is stored +================== + +General information about eddies storage. + +All files have the same structure, with more or less fields and possible different order. 
+
+There are 3 classes of files:
+
+- **Eddies collections** : contain a list of eddies without links between them
+- **Track eddies collections** :
+  manage eddies associated in trajectories, the ```track``` field allows to separate each trajectory
+- **Network eddies collections** :
+  manage eddies associated in networks, the ```track``` and ```segment``` fields allow to separate observations
+"""
+
+from matplotlib import pyplot as plt
+from numpy import arange, outer
+import py_eddy_tracker_sample
+
+from py_eddy_tracker.data import get_demo_path
+from py_eddy_tracker.observations.network import NetworkObservations
+from py_eddy_tracker.observations.observation import EddiesObservations, Table
+from py_eddy_tracker.observations.tracking import TrackEddiesObservations
+
+# %%
+# Eddies can be stored in 2 formats with the same structure:
+#
+# - zarr (https://zarr.readthedocs.io/en/stable/), which allows efficient IO, ...
+# - NetCDF4 (https://unidata.github.io/netcdf4-python/), a well-known format
+#
+# Each field is stored in a column, each row corresponds to one observation,
+# array fields like contour/profile are 2D columns.
+
+# %%
+# Eddies files (zarr or netcdf) can be loaded with the ```load_file``` method:
+eddies_collections = EddiesObservations.load_file(get_demo_path("Cyclonic_20160515.nc"))
+eddies_collections.field_table()
+# offset and scale_factor are used only when data is stored in zarr or netCDF4
+
+# %%
+# Field access
+# ------------
+# To access the whole field, here ```amplitude```
+eddies_collections.amplitude
+
+# To access only a specific part of the field
+eddies_collections.amplitude[4:15]
+
+# %%
+# The data matrix is a numpy ndarray
+eddies_collections.obs
+# %%
+eddies_collections.obs.dtype
+
+
+# %%
+# Contour storage
+# ---------------
+# All contours are stored with the same number of points, and are resampled if needed with an algorithm to be stored as objects
+
+# %%
+# Speed profile storage
+# ---------------------
+# The speed profile is an interpolation of the mean speed along each contour.
+# For each contour included in an eddy, we compute the mean speed along the contour,
+# and then we interpolate the mean speed array onto a fixed size array.
+#
+# Several fields are available to understand "uavg_profile" :
+# 0. - num_contours : number of contours in the eddy, must be equal to amplitude divided by the isoline step
+# 1. - height_inner_contour : height of the inner contour used
+# 2. - height_max_speed_contour : height of the max speed contour used
+# 3. - height_external_contour : height of the outer contour used
+#
+# The last value of "uavg_profile" is for the inner contour and the first value for the outer contour.
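+
+# %%
+# Quick sanity check of the relation stated above: amplitude divided by the isoline
+# step used at identification should roughly match ```num_contours```. The 0.002 m
+# step is an assumption here, it depends on the settings used to build the file.
+step = 0.002
+print(eddies_collections.amplitude / step - eddies_collections.num_contours)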
+ +# Observations selection of "uavg_profile" with high number of contour(Eddy with high amplitude) +e = eddies_collections.extract_with_mask(eddies_collections.num_contours > 15) + +# %% + +# Raw display of profiles with more than 15 contours +ax = plt.subplot(111) +_ = ax.plot(e.uavg_profile.T, lw=0.5) + +# %% + +# Profile from inner to outter +ax = plt.subplot(111) +ax.plot(e.uavg_profile[:, ::-1].T, lw=0.5) +_ = ax.set_xlabel("From inner to outter contour"), ax.set_ylabel("Speed (m/s)") + +# %% + +# If we normalize indice of contour to set speed contour to 1 and inner contour to 0 +ax = plt.subplot(111) +h_in = e.height_inner_contour +h_s = e.height_max_speed_contour +h_e = e.height_external_contour +r = (h_e - h_in) / (h_s - h_in) +nb_pt = e.uavg_profile.shape[1] +# Create an x array for each profile +x = outer(arange(nb_pt) / nb_pt, r) + +ax.plot(x, e.uavg_profile[:, ::-1].T, lw=0.5) +_ = ax.set_xlabel("From inner to outter contour"), ax.set_ylabel("Speed (m/s)") + + +# %% +# Trajectories +# ------------ +# Tracks eddies collections add several fields : +# +# - **track** : Trajectory number +# - **observation_flag** : Flag indicating if the value is interpolated between two observations or not +# (0: observed eddy, 1: interpolated eddy)" +# - **observation_number** : Eddy temporal index in a trajectory, days starting at the eddy first detection +# - **cost_association** : result of the cost function to associate the eddy with the next observation +eddies_tracks = TrackEddiesObservations.load_file( + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") +) +# In this example some fields are removed (effective_contour_longitude,...) in order to save time for doc building +eddies_tracks.field_table() + +# %% +# Networks +# -------- +# Network files use some specific fields : +# +# - track : ID of network (ID 0 correspond to lonely eddies) +# - segment : ID of a segment within a network (from 1 to N) +# - previous_obs : Index of the previous observation in the full dataset, +# if -1 there are no previous observation (the segment starts) +# - next_obs : Index of the next observation in the full dataset, if -1 there are no next observation (the segment ends) +# - previous_cost : Result of the cost function (1 is a good association, 0 is bad) with previous observation +# - next_cost : Result of the cost function (1 is a good association, 0 is bad) with next observation +eddies_network = NetworkObservations.load_file(get_demo_path("network_med.nc")) +eddies_network.field_table() + +# %% +sl = slice(70, 100) +Table( + eddies_network.network(651).obs[sl][ + [ + "time", + "track", + "segment", + "previous_obs", + "previous_cost", + "next_obs", + "next_cost", + ] + ] +) + +# %% +# Networks are ordered by increasing network number (`track`), then increasing segment number, then increasing time diff --git a/examples/02_eddy_identification/README.rst b/examples/02_eddy_identification/README.rst index 30d23e5b..07ef8f44 100644 --- a/examples/02_eddy_identification/README.rst +++ b/examples/02_eddy_identification/README.rst @@ -1,2 +1,4 @@ Eddy detection ============== + +Method to detect eddies from grid and display, with various parameters. 
diff --git a/examples/02_eddy_identification/pet_contour_circle.py b/examples/02_eddy_identification/pet_contour_circle.py index 5cbba26b..03332285 100644 --- a/examples/02_eddy_identification/pet_contour_circle.py +++ b/examples/02_eddy_identification/pet_contour_circle.py @@ -5,15 +5,16 @@ """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.observation import EddiesObservations + from py_eddy_tracker import data +from py_eddy_tracker.observations.observation import EddiesObservations # %% # Load detection files -a = EddiesObservations.load_file(data.get_path("Anticyclonic_20190223.nc")) +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) # %% -# Plot +# Plot the speed and effective (dashed) contours fig = plt.figure(figsize=(15, 8)) ax = fig.add_axes((0.05, 0.05, 0.9, 0.9)) ax.set_aspect("equal") @@ -21,7 +22,7 @@ ax.set_ylim(-50, -25) a.display(ax, label="Anticyclonic contour", color="r", lw=1) -# Replace contour by circle +# Replace contours by circles using center and radius (effective is dashed) a.circle_contour() a.display(ax, label="Anticyclonic circle", color="g", lw=1) -ax.legend(loc="upper right") +_ = ax.legend(loc="upper right") diff --git a/examples/02_eddy_identification/pet_display_id.py b/examples/02_eddy_identification/pet_display_id.py index ed1dadcb..57c59bc2 100644 --- a/examples/02_eddy_identification/pet_display_id.py +++ b/examples/02_eddy_identification/pet_display_id.py @@ -5,21 +5,42 @@ """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.observation import EddiesObservations + from py_eddy_tracker import data +from py_eddy_tracker.observations.observation import EddiesObservations # %% # Load detection files -a = EddiesObservations.load_file(data.get_path("Anticyclonic_20190223.nc")) -c = EddiesObservations.load_file(data.get_path("Cyclonic_20190223.nc")) +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) +c = EddiesObservations.load_file(data.get_demo_path("Cyclonic_20190223.nc")) + +# %% +# Fill effective contour with amplitude +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) +ax.set_aspect("equal") +ax.set_xlim(0, 140) +ax.set_ylim(-80, 0) +kwargs = dict(extern_only=True, color="k", lw=1) +a.display(ax, **kwargs), c.display(ax, **kwargs) +a.filled(ax, "amplitude", cmap="magma_r", vmin=0, vmax=0.5) +m = c.filled(ax, "amplitude", cmap="magma_r", vmin=0, vmax=0.5) +colorbar = plt.colorbar(m, cax=ax.figure.add_axes([0.95, 0.03, 0.02, 0.94])) +colorbar.set_label("Amplitude (m)") # %% -# Plot +# Draw speed contours fig = plt.figure(figsize=(15, 8)) -ax = fig.add_subplot(111) +ax = fig.add_axes([0.03, 0.03, 0.94, 0.94]) ax.set_aspect("equal") ax.set_xlim(0, 360) ax.set_ylim(-80, 80) -a.display(ax, label="Anticyclonic", color="r", lw=1) -c.display(ax, label="Cyclonic", color="b", lw=1) +a.display(ax, label="Anticyclonic ({nb_obs} eddies)", color="r", lw=1) +c.display(ax, label="Cyclonic ({nb_obs} eddies)", color="b", lw=1) ax.legend(loc="upper right") + +# %% +# Get general informations +print(a) +# %% +print(c) diff --git a/examples/02_eddy_identification/pet_eddy_detection.py b/examples/02_eddy_identification/pet_eddy_detection.py index 25be4901..b1b2c1af 100644 --- a/examples/02_eddy_identification/pet_eddy_detection.py +++ b/examples/02_eddy_identification/pet_eddy_detection.py @@ -1,6 +1,6 @@ """ -Eddy detection -============== +Eddy detection : Med +==================== Script will detect eddies on adt field, and compute u,v 
with method add_uv(which could use, only if equator is avoid) @@ -8,9 +8,12 @@ """ from datetime import datetime + from matplotlib import pyplot as plt -from py_eddy_tracker.dataset.grid import RegularGridDataset +from numpy import arange + from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset # %% @@ -19,34 +22,36 @@ def start_axes(title): ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) ax.set_aspect("equal") - ax.set_title(title) + ax.set_title(title, weight="bold") return ax def update_axes(ax, mappable=None): ax.grid() if mappable: - plt.colorbar(m, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9])) + plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) # %% -# Load Input grid, ADT will be used to detect eddies +# Load Input grid, ADT is used to detect eddies g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) ax = start_axes("ADT (m)") -m = g.display(ax, "adt", vmin=-0.15, vmax=0.15) +m = g.display(ax, "adt", vmin=-0.15, vmax=0.15, cmap="RdBu_r") update_axes(ax, m) # %% -# Get u/v -# ------- -# U/V are deduced from ADT, this algortihm are not usable around equator (~+- 2°) +# Get geostrophic speed u,v +# ------------------------- +# U/V are deduced from ADT, this algortihm is not ok near the equator (~+- 2°) g.add_uv("adt") ax = start_axes("U/V deduce from ADT (m)") ax.set_xlim(2.5, 9), ax.set_ylim(37.5, 40) -m = g.display(ax, "adt", vmin=-0.15, vmax=0.15) +m = g.display(ax, "adt", vmin=-0.15, vmax=0.15, cmap="RdBu_r") u, v = g.grid("u").T, g.grid("v").T ax.quiver(g.x_c, g.y_c, u, v, scale=10) update_axes(ax, m) @@ -54,65 +59,101 @@ def update_axes(ax, mappable=None): # %% # Pre-processings # --------------- -# Apply high filter to remove long scale to highlight mesoscale +# Apply a high-pass filter to remove the large scale and highlight the mesoscale g.bessel_high_filter("adt", 500) ax = start_axes("ADT (m) filtered (500km)") -m = g.display(ax, "adt", vmin=-0.15, vmax=0.15) +m = g.display(ax, "adt", vmin=-0.15, vmax=0.15, cmap="RdBu_r") update_axes(ax, m) # %% # Identification # -------------- -# run identification with slice of 2 mm +# Run the identification step with slices of 2 mm date = datetime(2016, 5, 15) -a, c = g.eddy_identification("adt", "u", "v", date, 0.002) +a, c = g.eddy_identification("adt", "u", "v", date, 0.002, shape_error=55) # %% -# All closed contour found in this input grid (Display only 1 contour every 4) -ax = start_axes("ADT closed contour (only 1 / 4 levels)") +# Display of all closed contours found in the grid (only 1 contour every 4) +ax = start_axes("ADT closed contours (only 1 / 4 levels)") g.contours.display(ax, step=4) update_axes(ax) # %% -# Contours include in eddies -ax = start_axes("ADT contour used as eddies") +# Contours included in eddies +ax = start_axes("ADT contours used as eddies") g.contours.display(ax, only_used=True) update_axes(ax) # %% -# Contours reject from several origin (shape error to high, several extremum in contour, ...) -ax = start_axes("ADT contour reject") +# Post analysis +# ------------- +# Contours can be rejected for several reasons (shape error to high, several extremum in contour, ...) 
+ax = start_axes("ADT rejected contours") g.contours.display(ax, only_unused=True) update_axes(ax) # %% -# Contours closed which contains several eddies -ax = start_axes("ADT contour reject but which contain eddies") +# Criteria for rejecting a contour: +# 0. - Accepted (green) +# 1. - Rejection for shape error (red) +# 2. - Masked value within contour (blue) +# 3. - Under or over the pixel limit bounds (black) +# 4. - Amplitude criterion (yellow) +ax = start_axes("Contours' rejection criteria") +g.contours.display(ax, only_unused=True, lw=0.5, display_criterion=True) +update_axes(ax) + +# %% +# Display the shape error of each tested contour, the limit of shape error is set to 55 % +ax = start_axes("Contour shape error") +m = g.contours.display( + ax, lw=0.5, field="shape_error", bins=arange(20, 90.1, 5), cmap="PRGn_r" +) +update_axes(ax, m) + +# %% +# Some closed contours contains several eddies (aka, more than one extremum) +ax = start_axes("ADT rejected contours containing eddies") g.contours.label_contour_unused_which_contain_eddies(a) g.contours.label_contour_unused_which_contain_eddies(c) g.contours.display( - ax, only_contain_eddies=True, color="k", lw=1, label="Could be interaction contour" + ax, + only_contain_eddies=True, + color="k", + lw=1, + label="Could be a contour of interaction", ) -a.display(ax, color="r", linewidth=0.5, label="Anticyclonic", ref=-10) -c.display(ax, color="b", linewidth=0.5, label="Cyclonic", ref=-10) +a.display(ax, color="r", linewidth=0.75, label="Anticyclonic", ref=-10) +c.display(ax, color="b", linewidth=0.75, label="Cyclonic", ref=-10) ax.legend() update_axes(ax) # %% # Output # ------ -# Display detected eddies, dashed lines represent effective contour -# and solid lines represent contour of maximum of speed. See figure 1 of https://doi.org/10.1175/JTECH-D-14-00019.1 +# When displaying the detected eddies, dashed lines are for effective contour, solide lines for the contour of +# the maximum mean speed. 
See figure 1 of https://doi.org/10.1175/JTECH-D-14-00019.1 -ax = start_axes("Eddies detected") -a.display(ax, color="r", linewidth=0.5, label="Anticyclonic", ref=-10) -c.display(ax, color="b", linewidth=0.5, label="Cyclonic", ref=-10) +ax = start_axes("Detected Eddies") +a.display( + ax, color="r", linewidth=0.75, label="Anticyclonic ({nb_obs} eddies)", ref=-10 +) +c.display(ax, color="b", linewidth=0.75, label="Cyclonic ({nb_obs} eddies)", ref=-10) ax.legend() update_axes(ax) # %% -# Display speed radius of eddies detected -ax = start_axes("Eddies speed radius (km)") -a.scatter(ax, "radius_s", vmin=10, vmax=50, s=80, ref=-10, cmap="jet", factor=0.001) -m = c.scatter(ax, "radius_s", vmin=10, vmax=50, s=80, ref=-10, cmap="jet", factor=0.001) +# Display the speed radius of the detected eddies +ax = start_axes("Speed Radius (km)") +kwargs = dict(vmin=10, vmax=50, s=80, ref=-10, cmap="magma_r", factor=0.001) +a.scatter(ax, "radius_s", **kwargs) +m = c.scatter(ax, "radius_s", **kwargs) +update_axes(ax, m) + +# %% +# Filling the effective radius contours with the effective radius values +ax = start_axes("Effective Radius (km)") +kwargs = dict(vmin=10, vmax=80, cmap="magma_r", factor=0.001, lut=14, ref=-10) +a.filled(ax, "effective_radius", **kwargs) +m = c.filled(ax, "radius_e", **kwargs) update_axes(ax, m) diff --git a/examples/02_eddy_identification/pet_eddy_detection_ACC.py b/examples/02_eddy_identification/pet_eddy_detection_ACC.py new file mode 100644 index 00000000..d12c62f3 --- /dev/null +++ b/examples/02_eddy_identification/pet_eddy_detection_ACC.py @@ -0,0 +1,207 @@ +""" +Eddy detection : Antartic Circumpolar Current +============================================= + +This script detect eddies on the ADT field, and compute u,v with the method add_uv (use it only if the Equator is avoided) + +Two detections are provided : with a filtered ADT and without filtering + +""" + +from datetime import datetime + +from matplotlib import pyplot as plt, style + +from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset + +pos_cb = [0.1, 0.52, 0.83, 0.015] +pos_cb2 = [0.1, 0.07, 0.4, 0.015] + + +def quad_axes(title): + style.use("default") + fig = plt.figure(figsize=(13, 10)) + fig.suptitle(title, weight="bold", fontsize=14) + axes = list() + + ax_pos = dict( + topleft=[0.1, 0.54, 0.4, 0.38], + topright=[0.53, 0.54, 0.4, 0.38], + botleft=[0.1, 0.09, 0.4, 0.38], + botright=[0.53, 0.09, 0.4, 0.38], + ) + + for key, position in ax_pos.items(): + ax = fig.add_axes(position) + ax.set_xlim(5, 45), ax.set_ylim(-60, -37) + ax.set_aspect("equal"), ax.grid(True) + axes.append(ax) + if "right" in key: + ax.set_yticklabels("") + return fig, axes + + +def set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight="semibold"): + for ax in fig.get_axes(): + ax.grid() + ax.grid(which="major", linestyle="-", linewidth="0.5", color="black") + if ax.get_ylabel() != "": + ax.set_ylabel(ax.get_ylabel(), fontsize=labelsize, fontweight=labelweight) + if ax.get_xlabel() != "": + ax.set_xlabel(ax.get_xlabel(), fontsize=labelsize, fontweight=labelweight) + if ax.get_title() != "": + ax.set_title(ax.get_title(), fontsize=labelsize, fontweight=labelweight) + ax.tick_params(labelsize=ticklabelsize) + + +# %% +# Load Input grid, ADT is used to detect eddies +margin = 30 + +kw_data = dict( + filename=data.get_demo_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), + x_name="longitude", + y_name="latitude", + # Manual area subset + indexs=dict( + latitude=slice(100 - margin, 220 + 
margin), + longitude=slice(0, 230 + margin), + ), +) +g_raw = RegularGridDataset(**kw_data) +g_raw.add_uv("adt") +g = RegularGridDataset(**kw_data) +g.copy("adt", "adt_low") +g.bessel_high_filter("adt", 700) +g.bessel_low_filter("adt_low", 700) +g.add_uv("adt") + +# %% +# Identification +# ^^^^^^^^^^^^^^ +# Run the identification step with slices of 2 mm +date = datetime(2019, 2, 23) +kw_ident = dict( + date=date, step=0.002, shape_error=70, sampling=30, uname="u", vname="v" +) +a, c = g.eddy_identification("adt", **kw_ident) +a_, c_ = g_raw.eddy_identification("adt", **kw_ident) + + +# %% +# Figures +# ------- +kw_adt = dict(vmin=-1.5, vmax=1.5, cmap=plt.get_cmap("RdBu_r", 30)) +fig, axs = quad_axes("General properties field") +g_raw.display(axs[0], "adt", **kw_adt) +axs[0].set_title("Total ADT (m)") +m = g.display(axs[1], "adt_low", **kw_adt) +axs[1].set_title("ADT (m) large scale, cutoff at 700 km") +m2 = g.display(axs[2], "adt", cmap=plt.get_cmap("RdBu_r", 20), vmin=-0.5, vmax=0.5) +axs[2].set_title("ADT (m) high-pass filtered, a cutoff at 700 km") +cb = plt.colorbar(m, cax=axs[0].figure.add_axes(pos_cb), orientation="horizontal") +cb.set_label("ADT (m)", labelpad=0) +cb2 = plt.colorbar(m2, cax=axs[2].figure.add_axes(pos_cb2), orientation="horizontal") +cb2.set_label("ADT (m)", labelpad=0) +set_fancy_labels(fig) + +# %% +# The large-scale North-South gradient is removed by the filtering step. + +# %% +fig, axs = quad_axes("") +axs[0].set_title("Without filter") +axs[0].set_ylabel("Contours used in eddies") +axs[1].set_title("With filter") +axs[2].set_ylabel("Closed contours but not used") +g_raw.contours.display(axs[0], lw=0.5, only_used=True) +g.contours.display(axs[1], lw=0.5, only_used=True) +g_raw.contours.display(axs[2], lw=0.5, only_unused=True) +g.contours.display(axs[3], lw=0.5, only_unused=True) +set_fancy_labels(fig) + +# %% +# Removing the large-scale North-South gradient reveals closed contours in the +# South-Western corner of the ewample region. + +# %% +kw = dict(ref=-10, linewidth=0.75) +kw_a = dict(color="r", label="Anticyclonic ({nb_obs} eddies)") +kw_c = dict(color="b", label="Cyclonic ({nb_obs} eddies)") +kw_filled = dict(vmin=0, vmax=100, cmap="Spectral_r", lut=20, intern=True, factor=100) +fig, axs = quad_axes("Comparison between two detections") +# Match with intern/inner contour +i_a, j_a, s_a = a_.match(a, intern=True, cmin=0.15) +i_c, j_c, s_c = c_.match(c, intern=True, cmin=0.15) + +a_.index(i_a).filled(axs[0], s_a, **kw_filled) +a.index(j_a).filled(axs[1], s_a, **kw_filled) +c_.index(i_c).filled(axs[0], s_c, **kw_filled) +m = c.index(j_c).filled(axs[1], s_c, **kw_filled) + +cb = plt.colorbar(m, cax=axs[0].figure.add_axes(pos_cb), orientation="horizontal") +cb.set_label("Similarity index (%)", labelpad=-5) +a_.display(axs[0], **kw, **kw_a), c_.display(axs[0], **kw, **kw_c) +a.display(axs[1], **kw, **kw_a), c.display(axs[1], **kw, **kw_c) + +axs[0].set_title("Without filter") +axs[0].set_ylabel("Detection") +axs[1].set_title("With filter") +axs[2].set_ylabel("Contours' rejection criteria") + +g_raw.contours.display(axs[2], lw=0.5, only_unused=True, display_criterion=True) +g.contours.display(axs[3], lw=0.5, only_unused=True, display_criterion=True) + +for ax in axs: + ax.legend() + +set_fancy_labels(fig) + +# %% +# Very similar eddies have Similarity Indexes >= 40% + +# %% +# Criteria for rejecting a contour : +# 0. Accepted (green) +# 1. Rejection for shape error (red) +# 2. Masked value within contour (blue) +# 3. 
Under or over the pixel limit bounds (black) +# 4. Amplitude criterion (yellow) + +# %% +i_a, j_a = i_a[s_a >= 0.4], j_a[s_a >= 0.4] +i_c, j_c = i_c[s_c >= 0.4], j_c[s_c >= 0.4] +fig = plt.figure(figsize=(12, 12)) +fig.suptitle(f"Scatter plot (A : {i_a.shape[0]}, C : {i_c.shape[0]} matches)") + +for i, (label, field, factor, stop) in enumerate( + ( + ("Speed radius (km)", "radius_s", 0.001, 120), + ("Effective radius (km)", "radius_e", 0.001, 120), + ("Amplitude (cm)", "amplitude", 100, 25), + ("Speed max (cm/s)", "speed_average", 100, 25), + ) +): + ax = fig.add_subplot(2, 2, i + 1, title=label) + ax.set_xlabel("Without filter") + ax.set_ylabel("With filter") + + ax.plot( + a_[field][i_a] * factor, + a[field][j_a] * factor, + "r.", + label="Anticyclonic", + ) + ax.plot( + c_[field][i_c] * factor, + c[field][j_c] * factor, + "b.", + label="Cyclonic", + ) + ax.set_aspect("equal"), ax.grid() + ax.plot((0, 1000), (0, 1000), "g") + ax.set_xlim(0, stop), ax.set_ylim(0, stop) + ax.legend() + +set_fancy_labels(fig) diff --git a/examples/02_eddy_identification/pet_eddy_detection_gulf_stream.py b/examples/02_eddy_identification/pet_eddy_detection_gulf_stream.py new file mode 100644 index 00000000..55267b76 --- /dev/null +++ b/examples/02_eddy_identification/pet_eddy_detection_gulf_stream.py @@ -0,0 +1,161 @@ +""" +Eddy detection : Gulf stream +============================ + +Script will detect eddies on adt field, and compute u,v with method add_uv(which could use, only if equator is avoid) + +Figures will show different step to detect eddies. + +""" +from datetime import datetime + +from matplotlib import pyplot as plt +from numpy import arange + +from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.eddy_feature import Contours + + +# %% +def start_axes(title): + fig = plt.figure(figsize=(13, 8)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) + ax.set_xlim(279, 304), ax.set_ylim(29, 44) + ax.set_aspect("equal") + ax.set_title(title, weight="bold") + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) + + +# %% +# Load Input grid, ADT is used to detect eddies +margin = 30 +g = RegularGridDataset( + data.get_demo_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), + "longitude", + "latitude", + # Manual area subset + indexs=dict( + longitude=slice(1116 - margin, 1216 + margin), + latitude=slice(476 - margin, 536 + margin), + ), +) + +ax = start_axes("ADT (m)") +m = g.display(ax, "adt", vmin=-1, vmax=1, cmap="RdBu_r") +# Draw line on the gulf stream front +great_current = Contours(g.x_c, g.y_c, g.grid("adt"), levels=(0.35,), keep_unclose=True) +great_current.display(ax, color="k") +update_axes(ax, m) + +# %% +# Get geostrophic speed u,v +# ------------------------- +# U/V are deduced from ADT, this algortihm is not ok near the equator (~+- 2°) +g.add_uv("adt") + +# %% +# Pre-processings +# --------------- +# Apply a high-pass filter to remove the large scale and highlight the mesoscale +g.bessel_high_filter("adt", 700) +ax = start_axes("ADT (m) filtered (700km)") +m = g.display(ax, "adt", vmin=-0.4, vmax=0.4, cmap="RdBu_r") +great_current.display(ax, color="k") +update_axes(ax, m) + +# %% +# Identification +# -------------- +# Run the identification step with slices of 2 mm +date = datetime(2016, 5, 15) +a, c = g.eddy_identification("adt", "u", "v", date, 0.002, shape_error=55) + +# %% +# Display of all closed contours found in the grid 
(only 1 contour every 5) +ax = start_axes("ADT closed contours (only 1 / 5 levels)") +g.contours.display(ax, step=5, lw=1) +great_current.display(ax, color="k") +update_axes(ax) + +# %% +# Contours included in eddies +ax = start_axes("ADT contours used as eddies") +g.contours.display(ax, only_used=True, lw=0.25) +great_current.display(ax, color="k") +update_axes(ax) + +# %% +# Post analysis +# ------------- +# Contours can be rejected for several reasons (shape error to high, several extremum in contour, ...) +ax = start_axes("ADT rejected contours") +g.contours.display(ax, only_unused=True, lw=0.25) +great_current.display(ax, color="k") +update_axes(ax) + +# %% +# Criteria for rejecting a contour : +# 0. Accepted (green) +# 1. Rejection for shape error (red) +# 2. Masked value within contour (blue) +# 3. Under or over the pixel limit bounds (black) +# 4. Amplitude criterion (yellow) +ax = start_axes("Contours' rejection criteria") +g.contours.display(ax, only_unused=True, lw=0.5, display_criterion=True) +update_axes(ax) + +# %% +# Display the shape error of each tested contour, the limit of shape error is set to 55 % +ax = start_axes("Contour shape error") +m = g.contours.display( + ax, lw=0.5, field="shape_error", bins=arange(20, 90.1, 5), cmap="PRGn_r" +) +update_axes(ax, m) + +# %% +# Some closed contours contains several eddies (aka, more than one extremum) +ax = start_axes("ADT rejected contours containing eddies") +g.contours.label_contour_unused_which_contain_eddies(a) +g.contours.label_contour_unused_which_contain_eddies(c) +g.contours.display( + ax, + only_contain_eddies=True, + color="k", + lw=1, + label="Could be a contour of interaction", +) +a.display(ax, color="r", linewidth=0.75, label="Anticyclonic", ref=-10) +c.display(ax, color="b", linewidth=0.75, label="Cyclonic", ref=-10) +ax.legend() +update_axes(ax) + +# %% +# Output +# ------ +# When displaying the detected eddies, dashed lines are for effective contour, solide lines for the contour of the +# maximum mean speed. 
See figure 1 of https://doi.org/10.1175/JTECH-D-14-00019.1 + +ax = start_axes("Eddies detected") +a.display( + ax, color="r", linewidth=0.75, label="Anticyclonic ({nb_obs} eddies)", ref=-10 +) +c.display(ax, color="b", linewidth=0.75, label="Cyclonic ({nb_obs} eddies)", ref=-10) +ax.legend() +great_current.display(ax, color="k") +update_axes(ax) + + +# %% +# Display the effective radius of the detected eddies +ax = start_axes("Effective radius (km)") +a.filled(ax, "radius_e", vmin=10, vmax=150, cmap="magma_r", factor=0.001, lut=14) +m = c.filled(ax, "radius_e", vmin=10, vmax=150, cmap="magma_r", factor=0.001, lut=14) +great_current.display(ax, color="k") +update_axes(ax, m) diff --git a/examples/02_eddy_identification/pet_filter_and_detection.py b/examples/02_eddy_identification/pet_filter_and_detection.py index a329d603..ec02a28c 100644 --- a/examples/02_eddy_identification/pet_filter_and_detection.py +++ b/examples/02_eddy_identification/pet_filter_and_detection.py @@ -4,11 +4,13 @@ """ from datetime import datetime + from matplotlib import pyplot as plt -from py_eddy_tracker.dataset.grid import RegularGridDataset -from py_eddy_tracker import data from numpy import arange +from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset + # %% def start_axes(title): @@ -16,65 +18,76 @@ def start_axes(title): ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) ax.set_aspect("equal") - ax.set_title(title) + ax.set_title(title, weight="bold") return ax def update_axes(ax, mappable=None): ax.grid() if mappable: - plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9])) + plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) # %% -# Load Input grid, ADT will be used to detect eddies +# Load Input grid, ADT is used to detect eddies. 
+# Add a new filed to store the high-pass filtered ADT g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude", + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) g.add_uv("adt") g.copy("adt", "adt_high") -wavelength = 400 +wavelength = 800 g.bessel_high_filter("adt_high", wavelength) date = datetime(2016, 5, 15) # %% -# Run algorithm of detection -a_f, c_f = g.eddy_identification("adt_high", "u", "v", date, 0.002) -merge_f = a_f.merge(c_f) -a_r, c_r = g.eddy_identification("adt", "u", "v", date, 0.002) -merge_r = a_r.merge(c_r) +# Run the detection for the total grid and the filtered grid +a_filtered, c_filtered = g.eddy_identification("adt_high", "u", "v", date, 0.002) +merge_f = a_filtered.merge(c_filtered) +a_tot, c_tot = g.eddy_identification("adt", "u", "v", date, 0.002) +merge_t = a_tot.merge(c_tot) # %% -# Display detection +# Display the two detections ax = start_axes("Eddies detected over ADT") m = g.display(ax, "adt", vmin=-0.15, vmax=0.15) -merge_f.display(ax, lw=0.5, label="Eddy from filtered grid", ref=-10, color="k") -merge_r.display(ax, lw=0.5, label="Eddy from raw grid", ref=-10, color="r") +merge_f.display( + ax, + lw=0.75, + label="Eddies in the filtered grid ({nb_obs} eddies)", + ref=-10, + color="k", +) +merge_t.display( + ax, lw=0.75, label="Eddies without filter ({nb_obs} eddies)", ref=-10, color="r" +) ax.legend() update_axes(ax, m) # %% -# Parameters distribution -# ----------------------- +# Amplitude and Speed Radius distributions +# ---------------------------------------- fig = plt.figure(figsize=(12, 5)) -ax_a = plt.subplot(121, xlabel="amplitdue(cm)") -ax_r = plt.subplot(122, xlabel="speed radius (km)") +ax_a = fig.add_subplot(121, xlabel="Amplitude (cm)") +ax_r = fig.add_subplot(122, xlabel="Speed Radius (km)") ax_a.hist( - merge_f["amplitude"] * 100, + merge_f.amplitude * 100, bins=arange(0.0005, 100, 1), - label="Eddy from filtered grid", + label="Eddies in the filtered grid", histtype="step", ) ax_a.hist( - merge_r["amplitude"] * 100, + merge_t.amplitude * 100, bins=arange(0.0005, 100, 1), - label="Eddy from raw grid", + label="Eddies without filter", histtype="step", ) ax_a.set_xlim(0, 10) -ax_r.hist(merge_f["radius_s"] / 1000.0, bins=arange(0, 300, 5), histtype="step") -ax_r.hist(merge_r["radius_s"] / 1000.0, bins=arange(0, 300, 5), histtype="step") +ax_r.hist(merge_f.radius_s / 1000.0, bins=arange(0, 300, 5), histtype="step") +ax_r.hist(merge_t.radius_s / 1000.0, bins=arange(0, 300, 5), histtype="step") ax_r.set_xlim(0, 100) ax_a.legend() @@ -82,16 +95,16 @@ def update_axes(ax, mappable=None): # Match detection and compare # --------------------------- -i_, j_, c = merge_f.match(merge_r, cmin=0.1) +i_, j_, c = merge_f.match(merge_t, cmin=0.1) # %% -# where is lonely eddies -kwargs_f = dict(lw=1.5, label="Lonely eddy from filtered grid", ref=-10, color="k") -kwargs_r = dict(lw=1.5, label="Lonely eddy from raw grid", ref=-10, color="r") +# Where are the lonely eddies? 
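+# A quick count first, as a minimal check that only uses the arrays returned by
+# ``match`` above (``c`` holds the similarity score of each matched pair).
+print(f"{i_.shape[0]} matched pairs, mean similarity {c.mean():.2f}")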
+kwargs_f = dict(lw=1.5, label="Lonely eddies in the filtered grid", ref=-10, color="k") +kwargs_t = dict(lw=1.5, label="Lonely eddies without filter", ref=-10, color="r") ax = start_axes("Eddies with no match, over filtered ADT") mappable = g.display(ax, "adt_high", vmin=-0.15, vmax=0.15) merge_f.index(i_, reverse=True).display(ax, **kwargs_f) -merge_r.index(j_, reverse=True).display(ax, **kwargs_r) +merge_t.index(j_, reverse=True).display(ax, **kwargs_t) ax.legend() update_axes(ax, mappable) @@ -101,26 +114,26 @@ def update_axes(ax, mappable=None): u, v = g.grid("u").T, g.grid("v").T ax.quiver(g.x_c, g.y_c, u, v, scale=10, pivot="mid", color="gray") merge_f.index(i_, reverse=True).display(ax, **kwargs_f) -merge_r.index(j_, reverse=True).display(ax, **kwargs_r) +merge_t.index(j_, reverse=True).display(ax, **kwargs_t) ax.legend() update_axes(ax, mappable) # %% fig = plt.figure(figsize=(12, 12)) -fig.suptitle(f"Scatter plot ({i_.shape[0]} matches)") +fig.suptitle(f"Scatter plot ({i_.shape[0]} matches)", weight="bold") for i, (label, field, factor, stop) in enumerate( ( - ("speed radius (km)", "radius_s", 0.001, 80), - ("outter radius (km)", "radius_e", 0.001, 120), - ("amplitude (cm)", "amplitude", 100, 25), - ("speed max (cm/s)", "speed_average", 100, 25), + ("Speed radius (km)", "radius_s", 0.001, 80), + ("Effective radius (km)", "radius_e", 0.001, 120), + ("Amplitude (cm)", "amplitude", 100, 25), + ("Maximum Speed (cm/s)", "speed_average", 100, 25), ) ): ax = fig.add_subplot( - 2, 2, i + 1, xlabel="filtered grid", ylabel="raw grid", title=label + 2, 2, i + 1, xlabel="Filtered grid", ylabel="Without filter", title=label ) - ax.plot(merge_f[field][i_] * factor, merge_r[field][j_] * factor, ".") + ax.plot(merge_f[field][i_] * factor, merge_t[field][j_] * factor, ".") ax.set_aspect("equal"), ax.grid() ax.plot((0, 1000), (0, 1000), "r") ax.set_xlim(0, stop), ax.set_ylim(0, stop) diff --git a/examples/02_eddy_identification/pet_interp_grid_on_dataset.py b/examples/02_eddy_identification/pet_interp_grid_on_dataset.py new file mode 100644 index 00000000..fa27a3d1 --- /dev/null +++ b/examples/02_eddy_identification/pet_interp_grid_on_dataset.py @@ -0,0 +1,67 @@ +""" +Get mean of grid in each eddies +=============================== + +""" + +from matplotlib import pyplot as plt + +from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.observations.observation import EddiesObservations + + +# %% +def start_axes(title): + fig = plt.figure(figsize=(13, 5)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) + ax.set_aspect("equal") + ax.set_title(title) + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + ax.legend() + if mappable: + plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9])) + + +# %% +# Load detection files and data to interp +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20160515.nc")) +c = EddiesObservations.load_file(data.get_demo_path("Cyclonic_20160515.nc")) + +aviso_map = RegularGridDataset( + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", +) +aviso_map.add_uv("adt") + +# %% +# Compute and store eke in cm²/s² +aviso_map.add_grid( + "eke", (aviso_map.grid("u") ** 2 + aviso_map.grid("v") ** 2) * 0.5 * (100**2) +) + +eke_kwargs = dict(vmin=1, vmax=1000, cmap="magma_r") + +ax = start_axes("EKE (cm²/s²)") +m = aviso_map.display(ax, "eke", **eke_kwargs) +a.display(ax, color="r", 
linewidth=0.5, label="Anticyclonic", ref=-10) +c.display(ax, color="b", linewidth=0.5, label="Cyclonic", ref=-10) +update_axes(ax, m) + +# %% +# Get mean of eke in each effective contour + +ax = start_axes("EKE mean (cm²/s²)") +a.display(ax, color="r", linewidth=0.5, label="Anticyclonic ({nb_obs} eddies)", ref=-10) +c.display(ax, color="b", linewidth=0.5, label="Cyclonic ({nb_obs} eddies)", ref=-10) +eke = a.interp_grid(aviso_map, "eke", method="mean", intern=False) +a.filled(ax, eke, ref=-10, **eke_kwargs) +eke = c.interp_grid(aviso_map, "eke", method="mean", intern=False) +m = c.filled(ax, eke, ref=-10, **eke_kwargs) +update_axes(ax, m) diff --git a/examples/02_eddy_identification/pet_radius_vs_area.py b/examples/02_eddy_identification/pet_radius_vs_area.py new file mode 100644 index 00000000..e34ad725 --- /dev/null +++ b/examples/02_eddy_identification/pet_radius_vs_area.py @@ -0,0 +1,48 @@ +""" +Radius vs area +============== + +""" +from matplotlib import pyplot as plt +from numpy import array, pi + +from py_eddy_tracker import data +from py_eddy_tracker.generic import coordinates_to_local +from py_eddy_tracker.observations.observation import EddiesObservations +from py_eddy_tracker.poly import poly_area + +# %% +# Load detection files +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) +areas = list() +# For each contour area will be compute in local reference +for i in a: + x, y = coordinates_to_local( + i["contour_lon_s"], i["contour_lat_s"], i["lon"], i["lat"] + ) + areas.append(poly_area(x, y)) +areas = array(areas) + +# %% +# Radius provided by eddy detection is computed with :func:`~py_eddy_tracker.poly.fit_circle` method. +# This radius will be compared with an equivalent radius deduced from polygon area. +ax = plt.subplot(111) +ax.set_aspect("equal") +ax.grid() +ax.set_xlabel("Speed radius computed with fit_circle") +ax.set_ylabel("Radius deduced from area\nof contour_lon_s/contour_lat_s") +ax.set_title("Area vs radius") +ax.plot(a["radius_s"] / 1000.0, (areas / pi) ** 0.5 / 1000.0, ".") +ax.plot((0, 250), (0, 250), "r") + +# %% +# Fit circle give a radius bigger than polygon area + +# %% +# When error is tiny, radius are very close. 
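+# A minimal way to quantify it: the mean ratio between the ``fit_circle`` radius
+# and the equivalent radius sqrt(area / pi) deduced from the contour areas
+# computed above.
+ratio = a["radius_s"] / (areas / pi) ** 0.5
+print(f"Mean radius ratio: {ratio.mean():.2f}")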
+ax = plt.subplot(111) +ax.grid() +ax.set_xlabel("Radius ratio") +ax.set_ylabel("Shape error") +ax.set_title("err = f(radius_ratio)") +ax.plot(a["radius_s"] / (areas / pi) ** 0.5, a["shape_error_s"], ".") diff --git a/examples/02_eddy_identification/pet_shape_gallery.py b/examples/02_eddy_identification/pet_shape_gallery.py index cc8d2242..ed8df83d 100644 --- a/examples/02_eddy_identification/pet_shape_gallery.py +++ b/examples/02_eddy_identification/pet_shape_gallery.py @@ -6,9 +6,10 @@ """ from matplotlib import pyplot as plt -from numpy import arange, radians, linspace, cos, sin -from py_eddy_tracker.dataset.grid import RegularGridDataset +from numpy import arange, cos, linspace, radians, sin + from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset from py_eddy_tracker.eddy_feature import Contours from py_eddy_tracker.generic import local_to_coordinates @@ -24,7 +25,9 @@ def build_circle(x0, y0, r): # %% # We iterate over closed contours and sort with regards of shape error g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) c = Contours(g.x_c, g.y_c, g.grid("adt") * 100, arange(-50, 50, 0.2)) contours = dict() diff --git a/examples/02_eddy_identification/pet_sla_and_adt.py b/examples/02_eddy_identification/pet_sla_and_adt.py index 7d8fb288..29dcc0a7 100644 --- a/examples/02_eddy_identification/pet_sla_and_adt.py +++ b/examples/02_eddy_identification/pet_sla_and_adt.py @@ -4,9 +4,11 @@ """ from datetime import datetime + from matplotlib import pyplot as plt -from py_eddy_tracker.dataset.grid import RegularGridDataset + from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset # %% @@ -29,7 +31,9 @@ def update_axes(ax, mappable=None): # Load Input grid, ADT will be used to detect eddies g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude", + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) g.add_uv("adt", "ugos", "vgos") g.add_uv("sla", "ugosa", "vgosa") @@ -41,10 +45,14 @@ def update_axes(ax, mappable=None): date = datetime(2016, 5, 15) # %% -kwargs_a_adt = dict(lw=0.5, label="Anticyclonic ADT", ref=-10, color="k") -kwargs_c_adt = dict(lw=0.5, label="Cyclonic ADT", ref=-10, color="r") -kwargs_a_sla = dict(lw=0.5, label="Anticyclonic SLA", ref=-10, color="g") -kwargs_c_sla = dict(lw=0.5, label="Cyclonic SLA", ref=-10, color="b") +kwargs_a_adt = dict( + lw=0.5, label="Anticyclonic ADT ({nb_obs} eddies)", ref=-10, color="k" +) +kwargs_c_adt = dict(lw=0.5, label="Cyclonic ADT ({nb_obs} eddies)", ref=-10, color="r") +kwargs_a_sla = dict( + lw=0.5, label="Anticyclonic SLA ({nb_obs} eddies)", ref=-10, color="g" +) +kwargs_c_sla = dict(lw=0.5, label="Cyclonic SLA ({nb_obs} eddies)", ref=-10, color="b") # %% # Run algorithm of detection @@ -133,8 +141,19 @@ def update_axes(ax, mappable=None): ax.set_xlabel("Absolute Dynamic Topography") ax.set_ylabel("Sea Level Anomaly") - ax.plot(a_adt[field][i_a_adt] * factor, a_sla[field][i_a_sla] * factor, "r.") - ax.plot(c_adt[field][i_c_adt] * factor, c_sla[field][i_c_sla] * factor, "b.") + ax.plot( + a_adt[field][i_a_adt] * factor, + a_sla[field][i_a_sla] * factor, + "r.", + label="Anticyclonic", + ) + ax.plot( + c_adt[field][i_c_adt] * factor, + c_sla[field][i_c_sla] * factor, + "b.", + label="Cyclonic", + ) 
ax.set_aspect("equal"), ax.grid() ax.plot((0, 1000), (0, 1000), "g") ax.set_xlim(0, stop), ax.set_ylim(0, stop) + ax.legend() diff --git a/examples/02_eddy_identification/pet_statistics_on_identification.py b/examples/02_eddy_identification/pet_statistics_on_identification.py new file mode 100644 index 00000000..dbd73c61 --- /dev/null +++ b/examples/02_eddy_identification/pet_statistics_on_identification.py @@ -0,0 +1,105 @@ +""" +Stastics on identification files +================================ + +Some statistics on raw identification without any tracking +""" +from matplotlib import pyplot as plt +from matplotlib.dates import date2num +import numpy as np + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.observations.observation import EddiesObservations + +start_logger().setLevel("ERROR") + + +# %% +def start_axes(title): + fig = plt.figure(figsize=(13, 5)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) + ax.set_aspect("equal") + ax.set_title(title) + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9])) + + +# %% +# We load demo sample and take only first year. +# +# Replace by a list of filename to apply on your own dataset. +file_objects = get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" +)[:365] + +# %% +# Merge all identification datasets in one object +all_a = EddiesObservations.concatenate( + [EddiesObservations.load_file(i) for i in file_objects] +) + +# %% +# We define polygon bound +x0, x1, y0, y1 = 15, 20, 33, 38 +xs = np.array([[x0, x1, x1, x0, x0]], dtype="f8") +ys = np.array([[y0, y0, y1, y1, y0]], dtype="f8") +# Polygon object created for the match function use. +polygon = dict(contour_lon_e=xs, contour_lat_e=ys, contour_lon_s=xs, contour_lat_s=ys) + +# %% +# Geographic frequency of eddies +step = 0.125 +ax = start_axes("") +# Count pixel encompassed in each contour +g_a = all_a.grid_count(bins=((-10, 37, step), (30, 46, step)), intern=True) +m = g_a.display( + ax, cmap="terrain_r", vmin=0, vmax=0.75, factor=1 / all_a.nb_days, name="count" +) +ax.plot(polygon["contour_lon_e"][0], polygon["contour_lat_e"][0], "r") +update_axes(ax, m) + +# %% +# We use the match function to count the number of eddies that intersect the polygon defined previously +# `p1_area` option allow to get in c_e/c_s output, precentage of area occupy by eddies in the polygon. +i_e, j_e, c_e = all_a.match(polygon, p1_area=True, intern=False) +i_s, j_s, c_s = all_a.match(polygon, p1_area=True, intern=True) + +# %% +dt = np.datetime64("1970-01-01") - np.datetime64("1950-01-01") +kw_hist = dict( + bins=date2num(np.arange(21900, 22300).astype("datetime64") - dt), histtype="step" +) +# translate julian day in datetime64 +t = all_a.time.astype("datetime64") - dt +# %% +# Number of eddies within a polygon +ax = plt.figure(figsize=(12, 6)).add_subplot(111) +ax.set_title("Different ways to count eddies within a polygon") +ax.set_ylabel("Count") +m = all_a.mask_from_polygons(((xs, ys),)) +ax.hist(t[m], label="Eddy Center in polygon", **kw_hist) +ax.hist(t[i_s[c_s > 0]], label="Intersection Speed contour and polygon", **kw_hist) +ax.hist(t[i_e[c_e > 0]], label="Intersection Effective contour and polygon", **kw_hist) +ax.legend() +ax.set_xlim(np.datetime64("2010"), np.datetime64("2011")) +ax.grid() + +# %% +# Percent of the area of interest occupied by eddies. 
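+# Each matching eddy contributes, on its day, the fraction of the polygon
+# covered by its contour (the ``c_s`` / ``c_e`` ratios obtained above with
+# ``p1_area=True``), so the weighted histograms below give the percentage of
+# the polygon occupied per day. As a minimal sanity check, no ratio should
+# exceed 1:
+print(f"max ratios: speed {c_s.max():.2f}, effective {c_e.max():.2f}")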
+ax = plt.figure(figsize=(12, 6)).add_subplot(111) +ax.set_title("Percent of polygon occupied by an anticyclonic eddy") +ax.set_ylabel("Percent of polygon") +ax.hist(t[i_s[c_s > 0]], weights=c_s[c_s > 0] * 100.0, label="speed contour", **kw_hist) +ax.hist( + t[i_e[c_e > 0]], weights=c_e[c_e > 0] * 100.0, label="effective contour", **kw_hist +) +ax.legend(), ax.set_ylim(0, 25) +ax.set_xlim(np.datetime64("2010"), np.datetime64("2011")) +ax.grid() diff --git a/examples/06_grid_manipulation/README.rst b/examples/06_grid_manipulation/README.rst index 16e64b89..a885d867 100644 --- a/examples/06_grid_manipulation/README.rst +++ b/examples/06_grid_manipulation/README.rst @@ -1,2 +1,2 @@ Grid Manipulation -======================= +================= diff --git a/examples/06_grid_manipulation/pet_advect.py b/examples/06_grid_manipulation/pet_advect.py new file mode 100644 index 00000000..d7cc67e9 --- /dev/null +++ b/examples/06_grid_manipulation/pet_advect.py @@ -0,0 +1,178 @@ +""" +Grid advection +============== + +Dummy advection which use only static geostrophic current, which didn't solve the complex circulation of the ocean. +""" +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from numpy import arange, isnan, meshgrid, ones + +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.observation import EddiesObservations + +# %% +# Load Input grid ADT +g = RegularGridDataset( + get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" +) +# Compute u/v from height +g.add_uv("adt") + +# %% +# Load detection files +a = EddiesObservations.load_file(get_demo_path("Anticyclonic_20160515.nc")) +c = EddiesObservations.load_file(get_demo_path("Cyclonic_20160515.nc")) + + +# %% +# Quiver from u/v with eddies +fig = plt.figure(figsize=(10, 5)) +ax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) +ax.set_xlim(19, 30), ax.set_ylim(31, 36.5), ax.grid() +x, y = meshgrid(g.x_c, g.y_c) +a.filled(ax, facecolors="r", alpha=0.1), c.filled(ax, facecolors="b", alpha=0.1) +_ = ax.quiver(x.T, y.T, g.grid("u"), g.grid("v"), scale=20) + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +# Anim +# ---- +# Particles setup +step_p = 1 / 8 +x, y = meshgrid(arange(13, 36, step_p), arange(28, 40, step_p)) +x, y = x.reshape(-1), y.reshape(-1) +# Remove all original position that we can't advect at first place +m = ~isnan(g.interp("u", x, y)) +x0, y0 = x[m], y[m] +x, y = x0.copy(), y0.copy() + +# %% +# Movie properties +kwargs = dict(frames=arange(51), interval=100) +kw_p = dict(u_name="u", v_name="v", nb_step=2, time_step=21600) +frame_t = kw_p["nb_step"] * kw_p["time_step"] / 86400.0 + + +# %% +# Function +def anim_ax(**kw): + t = 0 + fig = plt.figure(figsize=(10, 5), dpi=55) + axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) + axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid() + a.filled(axes, 
facecolors="r", alpha=0.1), c.filled(axes, facecolors="b", alpha=0.1) + line = axes.plot([], [], "k", **kw)[0] + return fig, axes.text(21, 32.1, ""), line, t + + +def update(i_frame, t_step): + global t + x, y = p.__next__() + t += t_step + l.set_data(x, y) + txt.set_text(f"T0 + {t:.1f} days") + + +# %% +# Filament forward +# ^^^^^^^^^^^^^^^^ +# Draw 3 last position in one path for each particles., +# it could be run backward with `backward=True` option in filament method +p = g.filament(x, y, **kw_p, filament_size=3) +fig, txt, l, t = anim_ax(lw=0.5) +_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,)) + +# %% +# Particle forward +# ^^^^^^^^^^^^^^^^^ +# Forward advection of particles +p = g.advect(x, y, **kw_p) +fig, txt, l, t = anim_ax(ls="", marker=".", markersize=1) +_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,)) + +# %% +# We get last position and run backward until original position +p = g.advect(x, y, **kw_p, backward=True) +fig, txt, l, _ = anim_ax(ls="", marker=".", markersize=1) +_ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,)) + +# %% +# Particles stat +# -------------- + +# %% +# Time_step settings +# ^^^^^^^^^^^^^^^^^^ +# Dummy experiment to test advection precision, we run particles 50 days forward and backward with different time step +# and we measure distance between new positions and original positions. +fig = plt.figure() +ax = fig.add_subplot(111) +kw = dict( + bins=arange(0, 50, 0.001), + cumulative=True, + weights=ones(x0.shape) / x0.shape[0] * 100.0, + histtype="step", +) +for time_step in (10800, 21600, 43200, 86400): + x, y = x0.copy(), y0.copy() + kw_advect = dict( + nb_step=int(50 * 86400 / time_step), time_step=time_step, u_name="u", v_name="v" + ) + g.advect(x, y, **kw_advect).__next__() + g.advect(x, y, **kw_advect, backward=True).__next__() + d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 + ax.hist(d, **kw, label=f"{86400. / time_step:.0f} time step by day") +ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() +ax.set_title("Distance after 50 days forward and 50 days backward") +ax.set_xlabel("Distance between original position and final position (in degrees)") +_ = ax.set_ylabel("Percent of particles with distance lesser than") + +# %% +# Time duration +# ^^^^^^^^^^^^^ +# We keep same time_step but change time duration +fig = plt.figure() +ax = fig.add_subplot(111) +time_step = 10800 +for duration in (5, 50, 100): + x, y = x0.copy(), y0.copy() + kw_advect = dict( + nb_step=int(duration * 86400 / time_step), + time_step=time_step, + u_name="u", + v_name="v", + ) + g.advect(x, y, **kw_advect).__next__() + g.advect(x, y, **kw_advect, backward=True).__next__() + d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 + ax.hist(d, **kw, label=f"Time duration {duration} days") +ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() +ax.set_title( + "Distance after N days forward and N days backward\nwith a time step of 1/8 days" +) +ax.set_xlabel("Distance between original position and final position (in degrees)") +_ = ax.set_ylabel("Percent of particles with distance lesser than ") diff --git a/examples/06_grid_manipulation/pet_filter.py b/examples/06_grid_manipulation/pet_filter.py index 83d687af..ae4356d7 100644 --- a/examples/06_grid_manipulation/pet_filter.py +++ b/examples/06_grid_manipulation/pet_filter.py @@ -7,11 +7,12 @@ We code a specific filter in order to filter grid with same wavelength at each pixel. 
""" -from py_eddy_tracker.dataset.grid import RegularGridDataset -from py_eddy_tracker import data from matplotlib import pyplot as plt from numpy import arange +from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset + def start_axes(title): fig = plt.figure(figsize=(13, 5)) @@ -25,13 +26,15 @@ def start_axes(title): def update_axes(ax, mappable=None): ax.grid() if mappable: - plt.colorbar(m, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9])) + plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9])) # %% # All information will be for regular grid g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude", + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) # %% # Kernel diff --git a/examples/06_grid_manipulation/pet_hide_pixel_out_eddies.py b/examples/06_grid_manipulation/pet_hide_pixel_out_eddies.py index 2c007dfa..388c9c7f 100644 --- a/examples/06_grid_manipulation/pet_hide_pixel_out_eddies.py +++ b/examples/06_grid_manipulation/pet_hide_pixel_out_eddies.py @@ -7,19 +7,20 @@ from matplotlib import pyplot as plt from matplotlib.path import Path from numpy import ones -from py_eddy_tracker.observations.observation import EddiesObservations + +from py_eddy_tracker import data from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.observations.observation import EddiesObservations from py_eddy_tracker.poly import create_vertice -from py_eddy_tracker import data # %% # Load an eddy file which contains contours -a = EddiesObservations.load_file(data.get_path("Anticyclonic_20190223.nc")) +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) # %% # Load a grid where we want found pixels in eddies or out g = RegularGridDataset( - data.get_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), + data.get_demo_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), "longitude", "latitude", ) @@ -34,7 +35,7 @@ # We will used the outter contour x_name, y_name = a.intern(False) adt = g.grid("adt") -mask = ones(adt.shape, dtype='bool') +mask = ones(adt.shape, dtype="bool") for eddy in a: i, j = Path(create_vertice(eddy[x_name], eddy[y_name])).pixels_in(g) mask[i, j] = False diff --git a/examples/06_grid_manipulation/pet_lavd.py b/examples/06_grid_manipulation/pet_lavd.py new file mode 100644 index 00000000..a3ea846e --- /dev/null +++ b/examples/06_grid_manipulation/pet_lavd.py @@ -0,0 +1,177 @@ +""" +LAVD experiment +=============== + +Naive method to reproduce LAVD(Lagrangian-Averaged Vorticity deviation) method with a static velocity field. +In the current example we didn't remove a mean vorticity. + +Method are described here: + + - Abernathey, Ryan, and George Haller. "Transport by Lagrangian Vortices in the Eastern Pacific", + Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021, + https://doi.org/10.1175/JPO-D-17-0102.1 + - `Transport by Coherent Lagrangian Vortices`_, + R. Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019, + Talk a t the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop + +.. 
_Transport by Coherent Lagrangian Vortices: + https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf + +""" +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from numpy import arange, meshgrid, zeros + +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.observation import EddiesObservations + + +# %% +def start_ax(title="", dpi=90): + fig = plt.figure(figsize=(16, 9), dpi=dpi) + ax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) + ax.set_xlim(0, 32), ax.set_ylim(28, 46) + ax.set_title(title) + return fig, ax, ax.text(3, 32, "", fontsize=20) + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + cb = plt.colorbar( + mappable, + cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]), + orientation="horizontal", + ) + cb.set_label("Vorticity integration along trajectory at initial position") + return cb + + +kw_vorticity = dict(vmin=0, vmax=2e-5, cmap="viridis") + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +# Data +# ---- +# To compute vorticity (:math:`\omega`) we compute u/v field with a stencil and apply the following equation with stencil +# method : +# +# .. math:: +# \omega = \frac{\partial v}{\partial x} - \frac{\partial u}{\partial y} +g = RegularGridDataset( + get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" +) +g.add_uv("adt") +u_y = g.compute_stencil(g.grid("u"), vertical=True) +v_x = g.compute_stencil(g.grid("v")) +g.vars["vort"] = v_x - u_y + +# %% +# Display vorticity field +fig, ax, _ = start_ax() +mappable = g.display(ax, abs(g.grid("vort")), **kw_vorticity) +cb = update_axes(ax, mappable) +cb.set_label("Vorticity") + +# %% +# Particles +# --------- +# Particles specification +step = 1 / 32 +x_g, y_g = arange(0, 36, step), arange(28, 46, step) +x, y = meshgrid(x_g, y_g) +original_shape = x.shape +x, y = x.reshape(-1), y.reshape(-1) +print(f"{len(x)} particles advected") +# A frame every 8h +step_by_day = 3 +# Compute step of advection every 4h +nb_step = 2 +kw_p = dict( + nb_step=nb_step, time_step=86400 / step_by_day / nb_step, u_name="u", v_name="v" +) +# Start a generator which at each iteration return new position at next time step +particule = g.advect(x, y, **kw_p, rk4=True) + +# %% +# LAVD +# ---- +lavd = zeros(original_shape) +# Advection time +nb_days = 8 +# Nb frame +nb_time = step_by_day * nb_days +i = 0.0 + + +# %% +# Anim +# ^^^^ +# Movie of LAVD integration at each integration time step. 
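+# Each frame advects the particles one step and accumulates the absolute value
+# of the vorticity interpolated at their new positions; since no mean vorticity
+# is removed here (see the note above), the mapped field is simply the mean of
+# |vort| along each trajectory over the nb_time advection steps.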
+def update(i_frame): + global lavd, i + i += 1 + x, y = particule.__next__() + # Interp vorticity on new_position + lavd += abs(g.interp("vort", x, y).reshape(original_shape) * 1 / nb_time) + txt.set_text(f"T0 + {i / step_by_day:.2f} days of advection") + pcolormesh.set_array(lavd / i * nb_time) + return pcolormesh, txt + + +kw_video = dict(frames=arange(nb_time), interval=1000.0 / step_by_day / 2, blit=True) +fig, ax, txt = start_ax(dpi=60) +x_g_, y_g_ = ( + arange(0 - step / 2, 36 + step / 2, step), + arange(28 - step / 2, 46 + step / 2, step), +) +# pcolorfast will be faster than pcolormesh, we could use pcolorfast due to x and y are regular +pcolormesh = ax.pcolorfast(x_g_, y_g_, lavd, **kw_vorticity) +update_axes(ax, pcolormesh) +_ = VideoAnimation(ax.figure, update, **kw_video) + +# %% +# Final LAVD +# ^^^^^^^^^^ + +# %% +# Format LAVD data +lavd = RegularGridDataset.with_array( + coordinates=("lon", "lat"), datas=dict(lavd=lavd.T, lon=x_g, lat=y_g), centered=True +) + +# %% +# Display final LAVD with py eddy tracker detection. +# Period used for LAVD integration (8 days) is too short for a real use, but choose for example efficiency. +fig, ax, _ = start_ax() +mappable = lavd.display(ax, "lavd", **kw_vorticity) +EddiesObservations.load_file(get_demo_path("Anticyclonic_20160515.nc")).display( + ax, color="k" +) +EddiesObservations.load_file(get_demo_path("Cyclonic_20160515.nc")).display( + ax, color="k" +) +_ = update_axes(ax, mappable) diff --git a/examples/06_grid_manipulation/pet_okubo_weiss.py b/examples/06_grid_manipulation/pet_okubo_weiss.py new file mode 100644 index 00000000..aa8a063e --- /dev/null +++ b/examples/06_grid_manipulation/pet_okubo_weiss.py @@ -0,0 +1,154 @@ +r""" +Get Okubo Weis +============== + +.. math:: OW = S_n^2 + S_s^2 - \omega^2 + +with normal strain (:math:`S_n`), shear strain (:math:`S_s`) and vorticity (:math:`\omega`) + +.. 
math:: + S_n = \frac{\partial u}{\partial x} - \frac{\partial v}{\partial y}, + S_s = \frac{\partial v}{\partial x} + \frac{\partial u}{\partial y}, + \omega = \frac{\partial v}{\partial x} - \frac{\partial u}{\partial y} + +""" +from matplotlib import pyplot as plt +from numpy import arange, ma, where + +from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.observations.observation import EddiesObservations + + +# %% +def start_axes(title, zoom=False): + fig = plt.figure(figsize=(12, 6)) + axes = fig.add_axes([0.03, 0.03, 0.90, 0.94]) + axes.set_xlim(0, 360), axes.set_ylim(-80, 80) + if zoom: + axes.set_xlim(270, 340), axes.set_ylim(20, 50) + axes.set_aspect("equal") + axes.set_title(title) + return axes + + +def update_axes(axes, mappable=None): + axes.grid() + if mappable: + plt.colorbar(mappable, cax=axes.figure.add_axes([0.94, 0.05, 0.01, 0.9])) + + +# %% +# Load detection files +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) +c = EddiesObservations.load_file(data.get_demo_path("Cyclonic_20190223.nc")) + +# %% +# Load Input grid, ADT will be used to detect eddies +g = RegularGridDataset( + data.get_demo_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), + "longitude", + "latitude", +) + +ax = start_axes("ADT (cm)") +m = g.display(ax, "adt", vmin=-120, vmax=120, factor=100) +update_axes(ax, m) + +# %% +# Get parameter for ow +u_x = g.compute_stencil(g.grid("ugos")) +u_y = g.compute_stencil(g.grid("ugos"), vertical=True) +v_x = g.compute_stencil(g.grid("vgos")) +v_y = g.compute_stencil(g.grid("vgos"), vertical=True) +ow = g.vars["ow"] = (u_x - v_y) ** 2 + (v_x + u_y) ** 2 - (v_x - u_y) ** 2 + +ax = start_axes("Okubo weis") +m = g.display(ax, "ow", vmin=-1e-10, vmax=1e-10, cmap="bwr") +update_axes(ax, m) + +# %% +# Gulf stream zoom +ax = start_axes("Okubo weis, Gulf stream", zoom=True) +m = g.display(ax, "ow", vmin=-1e-10, vmax=1e-10, cmap="bwr") +kw_ed = dict(intern_only=True, color="k", lw=1) +a.display(ax, **kw_ed), c.display(ax, **kw_ed) +update_axes(ax, m) + +# %% +# only negative OW +ax = start_axes("Okubo weis, Gulf stream", zoom=True) +threshold = ow.std() * -0.2 +ow = ma.array(ow, mask=ow > threshold) +m = g.display(ax, ow, vmin=-1e-10, vmax=1e-10, cmap="bwr") +a.display(ax, **kw_ed), c.display(ax, **kw_ed) +update_axes(ax, m) + +# %% +# Get okubo-weiss mean/min/center in eddies +plt.figure(figsize=(8, 6)) +ax = plt.subplot(111) +ax.set_xlabel("Okubo-Weiss parameter") +kw_hist = dict(bins=arange(-20e-10, 20e-10, 50e-12), histtype="step") +for method in ("mean", "center", "min"): + kw_interp = dict(grid_object=g, varname="ow", method=method, intern=True) + _, _, m = ax.hist( + a.interp_grid(**kw_interp), label=f"Anticyclonic - OW {method}", **kw_hist + ) + ax.hist( + c.interp_grid(**kw_interp), + label=f"Cyclonic - OW {method}", + color=m[0].get_edgecolor(), + ls="--", + **kw_hist, + ) +ax.axvline(threshold, color="r") +ax.set_yscale("log") +ax.grid() +ax.set_ylim(1, 1e4) +ax.set_xlim(-15e-10, 15e-10) +ax.legend() + +# %% +# Catch eddies with bad OW +ax = start_axes("Eddies with a min OW in speed contour over threshold") +ow_min = a.interp_grid(**kw_interp) +a_bad_ow = a.index(where(ow_min > threshold)[0]) +a_bad_ow.display(ax, color="r", label="Anticyclonic") +ow_min = c.interp_grid(**kw_interp) +c_bad_ow = c.index(where(ow_min > threshold)[0]) +c_bad_ow.display(ax, color="b", label="Cyclonic") +ax.legend() + +# %% +# Display Radius and amplitude of eddies +fig = 
plt.figure(figsize=(12, 5)) +fig.suptitle( + "Parameter distribution (solid line) and cumulative distribution (dashed line)" +) +ax_amp, ax_rad = fig.add_subplot(121), fig.add_subplot(122) +ax_amp_c, ax_rad_c = ax_amp.twinx(), ax_rad.twinx() +ax_amp_c.set_ylim(0, 1), ax_rad_c.set_ylim(0, 1) +kw_a = dict(xname="amplitude", bins=arange(0, 2, 0.002).astype("f4")) +kw_r = dict(xname="radius_s", bins=arange(0, 500e6, 2e3).astype("f4")) +for d, label, color in ( + (a, "Anticyclonic all", "r"), + (a_bad_ow, "Anticyclonic bad OW", "orange"), + (c, "Cyclonic all", "blue"), + (c_bad_ow, "Cyclonic bad OW", "lightblue"), +): + x, y = d.bins_stat(**kw_a) + ax_amp.plot(x * 100, y, label=label, color=color) + ax_amp_c.plot( + x * 100, y.cumsum() / y.sum(), label=label, color=color, ls="-.", lw=0.5 + ) + x, y = d.bins_stat(**kw_r) + ax_rad.plot(x * 1e-3, y, label=label, color=color) + ax_rad_c.plot( + x * 1e-3, y.cumsum() / y.sum(), label=label, color=color, ls="-.", lw=0.5 + ) + +ax_amp.set_xlim(0, 12.5), ax_amp.grid(), ax_amp.set_ylim(0), ax_amp.legend() +ax_rad.set_xlim(0, 120), ax_rad.grid(), ax_rad.set_ylim(0) +ax_amp.set_xlabel("Amplitude (cm)"), ax_amp.set_ylabel("Nb eddies") +ax_rad.set_xlabel("Speed radius (km)") diff --git a/examples/07_cube_manipulation/README.rst b/examples/07_cube_manipulation/README.rst new file mode 100644 index 00000000..7cecfbd4 --- /dev/null +++ b/examples/07_cube_manipulation/README.rst @@ -0,0 +1,2 @@ +Time grid computation +===================== diff --git a/examples/07_cube_manipulation/pet_cube.py b/examples/07_cube_manipulation/pet_cube.py new file mode 100644 index 00000000..cba6c85b --- /dev/null +++ b/examples/07_cube_manipulation/pet_cube.py @@ -0,0 +1,152 @@ +""" +Time advection +============== + +Example which use CMEMS surface current with a Runge-Kutta 4 algorithm to advect particles. 
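+
+Particles are seeded on a regular 1/8 degree grid over the Mediterranean and
+advected forward (then backward) with the time-varying geostrophic currents
+derived from the ADT cube.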
+""" +from datetime import datetime, timedelta + +# sphinx_gallery_thumbnail_number = 2 +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from numpy import arange, isnan, meshgrid, ones + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection +from py_eddy_tracker.gui import GUI_AXES + +start_logger().setLevel("ERROR") + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +# Data +# ---- +# Load Input time grid ADT +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + # To create U/V variable + heigth="adt", +) + +# %% +# Anim +# ---- +# Particles setup +step_p = 1 / 8 +x, y = meshgrid(arange(13, 36, step_p), arange(28, 40, step_p)) +x, y = x.reshape(-1), y.reshape(-1) +# Remove all original position that we can't advect at first place +t0 = 20181 +m = ~isnan(c[t0].interp("u", x, y)) +x0, y0 = x[m], y[m] +x, y = x0.copy(), y0.copy() + + +# %% +# Function +def anim_ax(**kw): + fig = plt.figure(figsize=(10, 5), dpi=55) + axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) + axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid() + line = axes.plot([], [], "k", **kw)[0] + return fig, axes.text(21, 32.1, ""), line + + +def update(_): + tt, xt, yt = f.__next__() + mappable.set_data(xt, yt) + d = timedelta(tt / 86400.0) + datetime(1950, 1, 1) + txt.set_text(f"{d:%Y/%m/%d-%H}") + + +# %% +f = c.filament(x, y, "u", "v", t_init=t0, nb_step=2, time_step=21600, filament_size=3) +fig, txt, mappable = anim_ax(lw=0.5) +ani = VideoAnimation(fig, update, frames=arange(160), interval=100) + + +# %% +# Particules stat +# --------------- +# Time_step settings +# ^^^^^^^^^^^^^^^^^^ +# Dummy experiment to test advection precision, we run particles 50 days forward and backward with different time step +# and we measure distance between new positions and original positions. +fig = plt.figure() +ax = fig.add_subplot(111) +kw = dict( + bins=arange(0, 50, 0.002), + cumulative=True, + weights=ones(x0.shape) / x0.shape[0] * 100.0, + histtype="step", +) +kw_p = dict(u_name="u", v_name="v", nb_step=1) +for time_step in (10800, 21600, 43200, 86400): + x, y = x0.copy(), y0.copy() + nb = int(30 * 86400 / time_step) + # Go forward + p = c.advect(x, y, time_step=time_step, t_init=20181.5, **kw_p) + for i in range(nb): + t_, _, _ = p.__next__() + # Go backward + p = c.advect(x, y, time_step=time_step, backward=True, t_init=t_ / 86400.0, **kw_p) + for i in range(nb): + t_, _, _ = p.__next__() + d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 + ax.hist(d, **kw, label=f"{86400. 
/ time_step:.0f} time step by day") +ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() +ax.set_title("Distance after 50 days forward and 50 days backward") +ax.set_xlabel("Distance between original position and final position (in degrees)") +_ = ax.set_ylabel("Percent of particles with distance lesser than") + +# %% +# Time duration +# ^^^^^^^^^^^^^ +# We keep same time_step but change time duration +fig = plt.figure() +ax = fig.add_subplot(111) +time_step = 10800 +for duration in (10, 40, 80): + x, y = x0.copy(), y0.copy() + nb = int(duration * 86400 / time_step) + # Go forward + p = c.advect(x, y, time_step=time_step, t_init=20181.5, **kw_p) + for i in range(nb): + t_, _, _ = p.__next__() + # Go backward + p = c.advect(x, y, time_step=time_step, backward=True, t_init=t_ / 86400.0, **kw_p) + for i in range(nb): + t_, _, _ = p.__next__() + d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 + ax.hist(d, **kw, label=f"Time duration {duration} days") +ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() +ax.set_title( + "Distance after N days forward and N days backward\nwith a time step of 1/8 days" +) +ax.set_xlabel("Distance between original position and final position (in degrees)") +_ = ax.set_ylabel("Percent of particles with distance lesser than ") diff --git a/examples/07_cube_manipulation/pet_fsle_med.py b/examples/07_cube_manipulation/pet_fsle_med.py new file mode 100644 index 00000000..9d78ea02 --- /dev/null +++ b/examples/07_cube_manipulation/pet_fsle_med.py @@ -0,0 +1,200 @@ +""" +FSLE experiment in med +====================== + +Example to build Finite Size Lyapunov Exponents, parameter values must be adapted for your case. + +Example use a method similar to `AVISO flse`_ + +.. _AVISO flse: + https://www.aviso.altimetry.fr/en/data/products/value-added-products/ + fsle-finite-size-lyapunov-exponents/fsle-description.html + +""" + +from matplotlib import pyplot as plt +from numba import njit +from numpy import arange, arctan2, empty, isnan, log, ma, meshgrid, ones, pi, zeros + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset + +start_logger().setLevel("ERROR") + + +# %% +# ADT in med +# ---------- +# :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_cube` method is +# made for data stores in time cube, you could use also +# :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` method to +# load data-cube from multiple file. 
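+#
+# As a reminder of the quantity computed in this example: for a pair of
+# particles initially separated by ``dist_init`` (set below), the FSLE is
+# roughly log(dist_max / dist_init) divided by the time needed to reach the
+# separation ``dist_max`` (classical definition); ``check_p`` further down
+# evaluates it from the deformation of a triplet of particles, following the
+# AVISO-like method cited above.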
+c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + # To create U/V variable + heigth="adt", +) + + +# %% +# Methods to compute FSLE +# ----------------------- +@njit(cache=True, fastmath=True) +def check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6): + """ + Check if distance between eastern or northern particle to center particle is bigger than `dist_max` + """ + nb_p = x.shape[0] // 3 + delta = dist_max**2 + for i in range(nb_p): + i0 = i * 3 + i_n = i0 + 1 + i_e = i0 + 2 + # If particle already set, we skip + if m[i0] or m[i_n] or m[i_e]: + continue + # Distance with north + dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n] + dn = dxn**2 + dyn**2 + # Distance with east + dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e] + de = dxe**2 + dye**2 + + if dn >= delta or de >= delta: + s1 = dn + de + at1 = 2 * (dxe * dxn + dye * dyn) + at2 = de - dn + s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * ( + (dxn - dye) ** 2 + (dxe + dyn) ** 2 + ) + flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init**2) * (s1 + s2**0.5)) + theta[i] = arctan2(at1, at2 + s2) * 180 / pi + # To know where value are set + m_set[i] = False + # To stop particle advection + m[i0], m[i_n], m[i_e] = True, True, True + + +@njit(cache=True) +def build_triplet(x, y, step=0.02): + """ + Triplet building for each position we add east and north point with defined step + """ + nb_x = x.shape[0] + x_ = empty(nb_x * 3, dtype=x.dtype) + y_ = empty(nb_x * 3, dtype=y.dtype) + for i in range(nb_x): + i0 = i * 3 + i_n, i_e = i0 + 1, i0 + 2 + x__, y__ = x[i], y[i] + x_[i0], y_[i0] = x__, y__ + x_[i_n], y_[i_n] = x__, y__ + step + x_[i_e], y_[i_e] = x__ + step, y__ + return x_, y_ + + +# %% +# Settings +# -------- + +# Step in degrees for ouput +step_grid_out = 1 / 25.0 +# Initial separation in degrees +dist_init = 1 / 50.0 +# Final separation in degrees +dist_max = 1 / 5.0 +# Time of start +t0 = 20268 +# Number of time step by days +time_step_by_days = 5 +# Maximal time of advection +# Here we limit because our data cube cover only 3 month +nb_days = 85 +# Backward or forward +backward = True + +# %% +# Particles +# --------- +x0_, y0_ = -5, 30 +lon_p = arange(x0_, x0_ + 43, step_grid_out) +lat_p = arange(y0_, y0_ + 16, step_grid_out) +y0, x0 = meshgrid(lat_p, lon_p) +grid_shape = x0.shape +x0, y0 = x0.reshape(-1), y0.reshape(-1) +# Identify all particle not on land +m = ~isnan(c[t0].interp("adt", x0, y0)) +x0, y0 = x0[m], y0[m] + +# %% +# FSLE +# ---- + +# Array to compute fsle +fsle = zeros(x0.shape[0], dtype="f4") +theta = zeros(x0.shape[0], dtype="f4") +mask = ones(x0.shape[0], dtype="f4") +x, y = build_triplet(x0, y0, dist_init) +used = zeros(x.shape[0], dtype="bool") + +# advection generator +kw = dict( + t_init=t0, nb_step=1, backward=backward, mask_particule=used, u_name="u", v_name="v" +) +p = c.advect(x, y, time_step=86400 / time_step_by_days, **kw) + +# We check at each step of advection if particle distance is over `dist_max` +for i in range(time_step_by_days * nb_days): + t, xt, yt = p.__next__() + dt = t / 86400.0 - t0 + check_p(xt, yt, fsle, theta, mask, used, dt, dist_max=dist_max, dist_init=dist_init) + +# Get index with original_position +i = ((x0 - x0_) / step_grid_out).astype("i4") +j = ((y0 - y0_) / step_grid_out).astype("i4") +fsle_ = empty(grid_shape, dtype="f4") +theta_ = empty(grid_shape, dtype="f4") +mask_ = ones(grid_shape, dtype="bool") +fsle_[i, j] = fsle +theta_[i, j] = theta +mask_[i, j] = mask +# Create a grid object +fsle_custom = 
RegularGridDataset.with_array( + coordinates=("lon", "lat"), + datas=dict( + fsle=ma.array(fsle_, mask=mask_), + theta=ma.array(theta_, mask=mask_), + lon=lon_p, + lat=lat_p, + ), + centered=True, +) + +# %% +# Display FSLE +# ------------ +fig = plt.figure(figsize=(13, 5), dpi=150) +ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) +ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) +ax.set_aspect("equal") +ax.set_title("Finite size lyapunov exponent", weight="bold") +kw = dict(cmap="viridis_r", vmin=-20, vmax=0) +m = fsle_custom.display(ax, 1 / fsle_custom.grid("fsle"), **kw) +ax.grid() +_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9])) +# %% +# Display Theta +# ------------- +fig = plt.figure(figsize=(13, 5), dpi=150) +ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) +ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) +ax.set_aspect("equal") +ax.set_title("Theta from finite size lyapunov exponent", weight="bold") +kw = dict(cmap="Spectral_r", vmin=-180, vmax=180) +m = fsle_custom.display(ax, fsle_custom.grid("theta"), **kw) +ax.grid() +_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9])) diff --git a/examples/07_cube_manipulation/pet_lavd_detection.py b/examples/07_cube_manipulation/pet_lavd_detection.py new file mode 100644 index 00000000..4dace120 --- /dev/null +++ b/examples/07_cube_manipulation/pet_lavd_detection.py @@ -0,0 +1,218 @@ +""" +LAVD detection and geometric detection +====================================== + +Naive method to reproduce LAVD(Lagrangian-Averaged Vorticity deviation). +In the current example we didn't remove a mean vorticity. + +Method are described here: + + - Abernathey, Ryan, and George Haller. "Transport by Lagrangian Vortices in the Eastern Pacific", + Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021, + https://doi.org/10.1175/JPO-D-17-0102.1 + - `Transport by Coherent Lagrangian Vortices`_, + R. Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019, + Talk a t the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop + +.. 
_Transport by Coherent Lagrangian Vortices: + https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf + +""" +from datetime import datetime + +from matplotlib import pyplot as plt +from numpy import arange, isnan, ma, meshgrid, zeros + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset +from py_eddy_tracker.gui import GUI_AXES + +start_logger().setLevel("ERROR") + + +# %% +class LAVDGrid(RegularGridDataset): + def init_speed_coef(self, uname="u", vname="v"): + """Hack to be able to identify eddy with LAVD field""" + self._speed_ev = self.grid("lavd") + + @classmethod + def from_(cls, x, y, z): + z.mask += isnan(z.data) + datas = dict(lavd=z, lon=x, lat=y) + return cls.with_array(coordinates=("lon", "lat"), datas=datas, centered=True) + + +# %% +def start_ax(title="", dpi=90): + fig = plt.figure(figsize=(12, 5), dpi=dpi) + ax = fig.add_axes([0.05, 0.08, 0.9, 0.9], projection=GUI_AXES) + ax.set_xlim(-6, 36), ax.set_ylim(31, 45) + ax.set_title(title) + return fig, ax, ax.text(3, 32, "", fontsize=20) + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + cb = plt.colorbar( + mappable, + cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]), + orientation="horizontal", + ) + cb.set_label("LAVD at initial position") + return cb + + +kw_lavd = dict(vmin=0, vmax=2e-5, cmap="viridis") + +# %% +# Data +# ---- + +# Load data cube of 3 month +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + heigth="adt", +) + +# Add vorticity at each time step +for g in c: + u_y = g.compute_stencil(g.grid("u"), vertical=True) + v_x = g.compute_stencil(g.grid("v")) + g.vars["vort"] = v_x - u_y + +# %% +# Particles +# --------- + +# Time properties, for example with advection only 25 days +nb_days, step_by_day = 25, 6 +nb_time = step_by_day * nb_days +kw_p = dict(nb_step=1, time_step=86400 / step_by_day, u_name="u", v_name="v") +t0 = 20236 +t0_grid = c[t0] +# Geographic properties, we use a coarser resolution for time consuming reasons +step = 1 / 32.0 +x_g, y_g = arange(-6, 36, step), arange(30, 46, step) +x0, y0 = meshgrid(x_g, y_g) +original_shape = x0.shape +x0, y0 = x0.reshape(-1), y0.reshape(-1) +# Get all particles in defined area +m = ~isnan(t0_grid.interp("vort", x0, y0)) +x0, y0 = x0[m], y0[m] +print(f"{x0.size} particles advected") +# Gridded mask +m = m.reshape(original_shape) + +# %% +# LAVD forward (dynamic field) +# ---------------------------- +lavd = zeros(original_shape) +lavd_ = lavd[m] +p = c.advect(x0.copy(), y0.copy(), t_init=t0, **kw_p) +for _ in range(nb_time): + t, x, y = p.__next__() + lavd_ += abs(c.interp("vort", t / 86400.0, x, y)) +lavd[m] = lavd_ / nb_time +# Put LAVD result in a standard py eddy tracker grid +lavd_forward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T) +# Display +fig, ax, _ = start_ax("LAVD with a forward advection") +mappable = lavd_forward.display(ax, "lavd", **kw_lavd) +_ = update_axes(ax, mappable) + +# %% +# LAVD backward (dynamic field) +# ----------------------------- +lavd = zeros(original_shape) +lavd_ = lavd[m] +p = c.advect(x0.copy(), y0.copy(), t_init=t0, backward=True, **kw_p) +for i in range(nb_time): + t, x, y = p.__next__() + lavd_ += abs(c.interp("vort", t / 86400.0, x, y)) +lavd[m] = lavd_ / nb_time +# Put LAVD result in a standard py eddy tracker grid +lavd_backward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, 
mask=~m).T) +# Display +fig, ax, _ = start_ax("LAVD with a backward advection") +mappable = lavd_backward.display(ax, "lavd", **kw_lavd) +_ = update_axes(ax, mappable) + +# %% +# LAVD forward (static field) +# --------------------------- +lavd = zeros(original_shape) +lavd_ = lavd[m] +p = t0_grid.advect(x0.copy(), y0.copy(), **kw_p) +for _ in range(nb_time): + x, y = p.__next__() + lavd_ += abs(t0_grid.interp("vort", x, y)) +lavd[m] = lavd_ / nb_time +# Put LAVD result in a standard py eddy tracker grid +lavd_forward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T) +# Display +fig, ax, _ = start_ax("LAVD with a forward advection on a static velocity field") +mappable = lavd_forward_static.display(ax, "lavd", **kw_lavd) +_ = update_axes(ax, mappable) + +# %% +# LAVD backward (static field) +# ---------------------------- +lavd = zeros(original_shape) +lavd_ = lavd[m] +p = t0_grid.advect(x0.copy(), y0.copy(), backward=True, **kw_p) +for i in range(nb_time): + x, y = p.__next__() + lavd_ += abs(t0_grid.interp("vort", x, y)) +lavd[m] = lavd_ / nb_time +# Put LAVD result in a standard py eddy tracker grid +lavd_backward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T) +# Display +fig, ax, _ = start_ax("LAVD with a backward advection on a static velocity field") +mappable = lavd_backward_static.display(ax, "lavd", **kw_lavd) +_ = update_axes(ax, mappable) + +# %% +# Contour detection +# ----------------- +# To extract contour from LAVD grid, we will used method design for SSH, with some hacks and adapted options. +# It will produce false amplitude and speed. +kw_ident = dict( + force_speed_unit="m/s", + force_height_unit="m", + pixel_limit=(40, 200000), + date=datetime(2005, 5, 18), + uname=None, + vname=None, + grid_height="lavd", + shape_error=70, + step=1e-6, +) +fig, ax, _ = start_ax("Detection of eddies with several method") +t0_grid.bessel_high_filter("adt", 700) +a, c = t0_grid.eddy_identification( + "adt", "u", "v", kw_ident["date"], step=0.002, shape_error=70 +) +kw_ed = dict(ax=ax, intern=True, ref=-10) +a.filled( + facecolors="#FFEFCD", label="Anticyclonic SSH detection {nb_obs} eddies", **kw_ed +) +c.filled(facecolors="#DEDEDE", label="Cyclonic SSH detection {nb_obs} eddies", **kw_ed) +kw_cont = dict(ax=ax, extern_only=True, ls="-", ref=-10) +forward, _ = lavd_forward.eddy_identification(**kw_ident) +forward.display(label="LAVD forward {nb_obs} eddies", color="g", **kw_cont) +backward, _ = lavd_backward.eddy_identification(**kw_ident) +backward.display(label="LAVD backward {nb_obs} eddies", color="r", **kw_cont) +forward, _ = lavd_forward_static.eddy_identification(**kw_ident) +forward.display(label="LAVD forward static {nb_obs} eddies", color="cyan", **kw_cont) +backward, _ = lavd_backward_static.eddy_identification(**kw_ident) +backward.display( + label="LAVD backward static {nb_obs} eddies", color="orange", **kw_cont +) +ax.legend() +update_axes(ax) diff --git a/examples/07_cube_manipulation/pet_particles_drift.py b/examples/07_cube_manipulation/pet_particles_drift.py new file mode 100644 index 00000000..3d7aa1a4 --- /dev/null +++ b/examples/07_cube_manipulation/pet_particles_drift.py @@ -0,0 +1,46 @@ +""" +Build path of particle drifting +=============================== + +""" + +from matplotlib import pyplot as plt +from numpy import arange, meshgrid + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection + +start_logger().setLevel("ERROR") + +# %% +# Load data 
cube +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + unset=True, +) + +# %% +# Advection properties +nb_days, step_by_day = 10, 6 +nb_time = step_by_day * nb_days +kw_p = dict(nb_step=1, time_step=86400 / step_by_day) +t0 = 20210 + +# %% +# Get paths +x0, y0 = meshgrid(arange(32, 35, 0.5), arange(32.5, 34.5, 0.5)) +x0, y0 = x0.reshape(-1), y0.reshape(-1) +t, x, y = c.path(x0, y0, h_name="adt", t_init=t0, **kw_p, nb_time=nb_time) + +# %% +# Plot paths +ax = plt.figure(figsize=(9, 6)).add_subplot(111, aspect="equal") +ax.plot(x0, y0, "k.", ms=20) +ax.plot(x, y, lw=3) +ax.set_title("10 days particle paths") +ax.set_xlim(31, 35), ax.set_ylim(32, 34.5) +ax.grid() diff --git a/examples/08_tracking_manipulation/README.rst b/examples/08_tracking_manipulation/README.rst index 73d43181..a971049f 100644 --- a/examples/08_tracking_manipulation/README.rst +++ b/examples/08_tracking_manipulation/README.rst @@ -1,2 +1,4 @@ Tracking Manipulation -======================= +===================== + +Method to subset and display atlas. \ No newline at end of file diff --git a/examples/08_tracking_manipulation/pet_display_field.py b/examples/08_tracking_manipulation/pet_display_field.py index c21c25f4..b943a2ba 100644 --- a/examples/08_tracking_manipulation/pet_display_field.py +++ b/examples/08_tracking_manipulation/pet_display_field.py @@ -5,13 +5,14 @@ """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.tracking import TrackEddiesObservations import py_eddy_tracker_sample +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + # %% # Load an experimental cyclonic atlas, we keep only eddies which are follow more than 180 days c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) c = c.extract_with_length((180, -1)) diff --git a/examples/08_tracking_manipulation/pet_display_track.py b/examples/08_tracking_manipulation/pet_display_track.py index ef74218a..b15d51d7 100644 --- a/examples/08_tracking_manipulation/pet_display_track.py +++ b/examples/08_tracking_manipulation/pet_display_track.py @@ -5,19 +5,27 @@ """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.tracking import TrackEddiesObservations import py_eddy_tracker_sample +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + # %% -# Load experimental atlas, and keep only eddies longer than 20 weeks +# Load experimental atlas a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) +print(a) + +# %% +# keep only eddies longer than 20 weeks, use -1 to have no upper limit a = a.extract_with_length((7 * 20, -1)) c = c.extract_with_length((7 * 20, -1)) +print(a) # %% # Position filtering for nice display @@ -30,7 +38,7 @@ ax = fig.add_axes((0.05, 0.1, 0.9, 0.9)) ax.set_aspect("equal") ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) -a.plot(ax, ref=-10, label="Anticyclonic", color="r", lw=0.1) -c.plot(ax, ref=-10, label="Cyclonic", color="b", lw=0.1) +a.plot(ax, ref=-10, 
label="Anticyclonic ({nb_tracks} tracks)", color="r", lw=0.1) +c.plot(ax, ref=-10, label="Cyclonic ({nb_tracks} tracks)", color="b", lw=0.1) ax.legend() ax.grid() diff --git a/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py b/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py new file mode 100644 index 00000000..8161ad81 --- /dev/null +++ b/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py @@ -0,0 +1,94 @@ +""" +Correspondances +=============== + +Correspondances is a mechanism to intend to continue tracking with new detection + +""" + +import logging + +# %% +from matplotlib import pyplot as plt +from netCDF4 import Dataset + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker + +# In order to hide some warning +import py_eddy_tracker.observations.observation +from py_eddy_tracker.tracking import Correspondances + +py_eddy_tracker.observations.observation._display_check_warning = False + + +# %% +def plot_eddy(ed): + fig = plt.figure(figsize=(10, 5)) + ax = fig.add_axes([0.05, 0.03, 0.90, 0.94]) + ed.plot(ax, ref=-10, marker="x") + lc = ed.display_color(ax, field=ed.time, ref=-10, intern=True) + plt.colorbar(lc).set_label("Time in Julian days (from 1950/01/01)") + ax.set_xlim(4.5, 8), ax.set_ylim(36.8, 38.3) + ax.set_aspect("equal") + ax.grid() + + +# %% +# Get remote data, we will keep only 20 first days, +# `get_remote_demo_sample` function is only to get demo dataset, in your own case give a list of identification filename +# and don't mix cyclonic and anticyclonic files. +file_objects = get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" +)[:20] + +# %% +# We run a traking with a tracker which use contour overlap, on 10 first time step +c_first_run = Correspondances( + datasets=file_objects[:10], class_method=AreaTracker, virtual=4 +) +start_logger().setLevel("INFO") +c_first_run.track() +start_logger().setLevel("WARNING") +with Dataset("correspondances.nc", "w") as h: + c_first_run.to_netcdf(h) +# Next step are done only to build atlas and display it +c_first_run.prepare_merging() + +# We have now an eddy object +eddies_area_tracker = c_first_run.merge(raw_data=False) +eddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0 +eddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1) + +# %% +# Plot from first ten days +plot_eddy(eddies_area_tracker) + +# %% +# Restart from previous run +# ------------------------- +# We give all filenames, the new one and filename from previous run +c_second_run = Correspondances( + datasets=file_objects[:20], + # This parameter must be identical in each run + class_method=AreaTracker, + virtual=4, + # Previous saved correspondancs + previous_correspondance="correspondances.nc", +) +start_logger().setLevel("INFO") +c_second_run.track() +start_logger().setLevel("WARNING") +c_second_run.prepare_merging() +# We have now another eddy object +eddies_area_tracker_extend = c_second_run.merge(raw_data=False) +eddies_area_tracker_extend.virtual[:] = eddies_area_tracker_extend.time == 0 +eddies_area_tracker_extend.filled_by_interpolation( + eddies_area_tracker_extend.virtual == 1 +) + + +# %% +# Plot with time extension +plot_eddy(eddies_area_tracker_extend) diff --git a/examples/08_tracking_manipulation/pet_one_track.py b/examples/08_tracking_manipulation/pet_one_track.py index 86097483..a2536c34 100644 --- 
a/examples/08_tracking_manipulation/pet_one_track.py +++ b/examples/08_tracking_manipulation/pet_one_track.py @@ -3,13 +3,16 @@ =================== """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.tracking import TrackEddiesObservations import py_eddy_tracker_sample +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + # %% # Load experimental atlas, and we select one eddy a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) eddy = a.extract_ids([9672]) eddy_f = a.extract_ids([9672]) @@ -35,6 +38,6 @@ ax.grid() eddy.plot(ax, color="r", lw=0.5, label="track") eddy.index(range(0, len(eddy), 40)).display( - ax, intern_only=True, label="observations every 40" + ax, intern_only=True, label="observations every 40 days" ) ax.legend() diff --git a/examples/08_tracking_manipulation/pet_run_a_tracking.py b/examples/08_tracking_manipulation/pet_run_a_tracking.py index 8349bb05..15d8b18b 100644 --- a/examples/08_tracking_manipulation/pet_run_a_tracking.py +++ b/examples/08_tracking_manipulation/pet_run_a_tracking.py @@ -7,15 +7,16 @@ # %% -from py_eddy_tracker.data import get_remote_sample -from py_eddy_tracker.tracking import Correspondances +from py_eddy_tracker.data import get_remote_demo_sample from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker from py_eddy_tracker.gui import GUI - +from py_eddy_tracker.tracking import Correspondances # %% -# Get remote data, we will keep only 180 first days -file_objects = get_remote_sample( +# Get remote data, we will keep only 180 first days, +# `get_remote_demo_sample` function is only to get demo dataset, in your own case give a list of identification filename +# and don't mix cyclonic and anticyclonic files. 
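For a run on your own detections, the `datasets` list is simply a list of identification files built from disk instead of the demo sample; a minimal sketch (the directory and filename pattern are hypothetical):

```python
# Hypothetical sketch: use your own identification files instead of the demo sample.
# Keep cyclonic and anticyclonic detections in separate runs.
from glob import glob

file_objects = sorted(glob("/my/identification/run/Anticyclonic_*.nc"))
```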
+file_objects = get_remote_demo_sample( "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" )[:180] @@ -26,8 +27,8 @@ c.prepare_merging() # We have now an eddy object eddies_area_tracker = c.merge(raw_data=False) -eddies_area_tracker["virtual"][:] = eddies_area_tracker["time"] == 0 -eddies_area_tracker.filled_by_interpolation(eddies_area_tracker["virtual"] == 1) +eddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0 +eddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1) # %% # We run a traking with default tracker @@ -35,8 +36,8 @@ c.track() c.prepare_merging() eddies_default_tracker = c.merge(raw_data=False) -eddies_default_tracker["virtual"][:] = eddies_default_tracker["time"] == 0 -eddies_default_tracker.filled_by_interpolation(eddies_default_tracker["virtual"] == 1) +eddies_default_tracker.virtual[:] = eddies_default_tracker.time == 0 +eddies_default_tracker.filled_by_interpolation(eddies_default_tracker.virtual == 1) # %% # Start GUI to compare tracking diff --git a/examples/08_tracking_manipulation/pet_select_track_across_area.py b/examples/08_tracking_manipulation/pet_select_track_across_area.py index ca3c3c72..58184e1f 100644 --- a/examples/08_tracking_manipulation/pet_select_track_across_area.py +++ b/examples/08_tracking_manipulation/pet_select_track_across_area.py @@ -4,13 +4,14 @@ """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.tracking import TrackEddiesObservations import py_eddy_tracker_sample +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + # %% # Load experimental atlas, we filter position to have nice display c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) c.position_filter(median_half_window=1, loess_half_window=5) @@ -28,11 +29,13 @@ ax.set_ylim(36, 40) ax.set_aspect("equal") ax.grid() -c.plot(ax, color="gray", lw=0.1, ref=-10, label="all tracks") -c_subset.plot(ax, color="red", lw=0.2, ref=-10, label="selected tracks") +c.plot(ax, color="gray", lw=0.1, ref=-10, label="All tracks ({nb_tracks} tracks)") +c_subset.plot( + ax, color="red", lw=0.2, ref=-10, label="selected tracks ({nb_tracks} tracks)" +) ax.plot( - (x0, x0, x1, x1, x0,), - (y0, y1, y1, y0, y0,), + (x0, x0, x1, x1, x0), + (y0, y1, y1, y0, y0), color="green", lw=1.5, label="Box of selection", diff --git a/examples/08_tracking_manipulation/pet_track_anim.py b/examples/08_tracking_manipulation/pet_track_anim.py index 36d847a3..94e09ad3 100644 --- a/examples/08_tracking_manipulation/pet_track_anim.py +++ b/examples/08_tracking_manipulation/pet_track_anim.py @@ -2,29 +2,35 @@ Track animation =============== -Run in a terminal this script, which allow to watch eddy evolution +Run in a terminal this script, which allow to watch eddy evolution. + +You could use also *EddyAnim* script to display/save animation. 
""" -from py_eddy_tracker.observations.tracking import TrackEddiesObservations -from py_eddy_tracker.appli.gui import Anim import py_eddy_tracker_sample +from py_eddy_tracker.appli.gui import Anim +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + # %% # Load experimental atlas, and we select one eddy a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) -eddy = a.extract_ids([9672]) +# We get only 300 first step to save time of documentation builder +eddy = a.extract_ids([9672]).index(slice(0, 300)) # %% # Run animation -# Key shortcut -# Escape => exit -# SpaceBar => pause -# left arrow => t - 1 -# right arrow => t + 1 -# + => speed increase of 10 % -# - => speed decrease of 10 % +# Key shortcut : +# * Escape => exit +# * SpaceBar => pause +# * left arrow => t - 1 +# * right arrow => t + 1 +# * \+ => speed increase of 10 % +# * \- => speed decrease of 10 % a = Anim(eddy, sleep_event=1e-10, intern=True, figsize=(8, 3.5), cmap="viridis") a.txt.set_position((17, 34.6)) a.ax.set_xlim(16.5, 23) diff --git a/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py b/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py new file mode 100644 index 00000000..b686fd67 --- /dev/null +++ b/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py @@ -0,0 +1,60 @@ +""" +Track animation with standard matplotlib +======================================== + +Run in a terminal this script, which allow to watch eddy evolution. + +You could use also *EddyAnim* script to display/save animation. + +""" +import re + +from matplotlib.animation import FuncAnimation +from numpy import arange +import py_eddy_tracker_sample + +from py_eddy_tracker.appli.gui import Anim +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + +# sphinx_gallery_thumbnail_path = '_static/no_image.png' + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +# Load experimental atlas, and we select one eddy +a = TrackEddiesObservations.load_file( + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) +) +eddy = a.extract_ids([9672]) + +# %% +# Run animation +a = Anim(eddy, intern=True, figsize=(8, 3.5), cmap="magma_r", nb_step=5, dpi=50) +a.txt.set_position((17, 34.6)) +a.ax.set_xlim(16.5, 23) +a.ax.set_ylim(34.5, 37) + +# arguments to get full animation +kwargs = dict(frames=arange(*a.period)[300:800], interval=90) + +ani = VideoAnimation(a.fig, a.func_animation, **kwargs) diff --git a/examples/10_tracking_diagnostics/README.rst b/examples/10_tracking_diagnostics/README.rst index 00178ba7..2030c0cc 100644 --- a/examples/10_tracking_diagnostics/README.rst +++ b/examples/10_tracking_diagnostics/README.rst @@ -1,2 +1,4 @@ Tracking diagnostics -======================= +==================== + +Method to produce statistics with 
an eddy atlas.
\ No newline at end of file
diff --git a/examples/10_tracking_diagnostics/pet_birth_and_death.py b/examples/10_tracking_diagnostics/pet_birth_and_death.py
new file mode 100644
index 00000000..b67993a2
--- /dev/null
+++ b/examples/10_tracking_diagnostics/pet_birth_and_death.py
@@ -0,0 +1,84 @@
+"""
+Birth and death
+===============
+
+The following figures are based on https://doi.org/10.1016/j.pocean.2011.01.002
+
+"""
+from matplotlib import pyplot as plt
+import py_eddy_tracker_sample
+
+from py_eddy_tracker.observations.tracking import TrackEddiesObservations
+
+
+# %%
+def start_axes(title):
+    fig = plt.figure(figsize=(13, 5))
+    ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])
+    ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)
+    ax.set_aspect("equal")
+    ax.set_title(title)
+    return ax
+
+
+def update_axes(ax, mappable=None):
+    ax.grid()
+    if mappable:
+        plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))
+
+
+# %%
+# Load an experimental med atlas over a period of 26 years (1993-2019)
+kwargs_load = dict(
+    include_vars=(
+        "longitude",
+        "latitude",
+        "observation_number",
+        "track",
+        "time",
+        "speed_contour_longitude",
+        "speed_contour_latitude",
+    )
+)
+a = TrackEddiesObservations.load_file(
+    py_eddy_tracker_sample.get_demo_path(
+        "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr"
+    )
+)
+c = TrackEddiesObservations.load_file(
+    py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr")
+)
+
+# %%
+t0, t1 = a.period
+step = 0.125
+bins = ((-10, 37, step), (30, 46, step))
+kwargs = dict(cmap="terrain_r", factor=100 / (t1 - t0), name="count", vmin=0, vmax=1)
+
+# %%
+# Cyclonic
+# --------
+ax = start_axes("Birth cyclonic frequency (%)")
+g_c_first = c.first_obs().grid_count(bins, intern=True)
+m = g_c_first.display(ax, **kwargs)
+update_axes(ax, m)
+
+# %%
+ax = start_axes("Death cyclonic frequency (%)")
+g_c_last = c.last_obs().grid_count(bins, intern=True)
+m = g_c_last.display(ax, **kwargs)
+update_axes(ax, m)
+
+# %%
+# Anticyclonic
+# ------------
+ax = start_axes("Birth anticyclonic frequency (%)")
+g_a_first = a.first_obs().grid_count(bins, intern=True)
+m = g_a_first.display(ax, **kwargs)
+update_axes(ax, m)
+
+# %%
+ax = start_axes("Death anticyclonic frequency (%)")
+g_a_last = a.last_obs().grid_count(bins, intern=True)
+m = g_a_last.display(ax, **kwargs)
+update_axes(ax, m)
diff --git a/examples/10_tracking_diagnostics/pet_center_count.py b/examples/10_tracking_diagnostics/pet_center_count.py
index cdc8189c..77a4dcda 100644
--- a/examples/10_tracking_diagnostics/pet_center_count.py
+++ b/examples/10_tracking_diagnostics/pet_center_count.py
@@ -1,28 +1,33 @@
 """
 Count center
-======================
+============
+Do geographic statistics with eddy centers and compare with the frequency method
+shown in :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_pixel_used.py`
 """
 from matplotlib import pyplot as plt
 from matplotlib.colors import LogNorm
-from py_eddy_tracker.observations.tracking import TrackEddiesObservations
 import py_eddy_tracker_sample
 
+from py_eddy_tracker.observations.tracking import TrackEddiesObservations
+
 # %%
 # Load an experimental med atlas over a period of 26 years (1993-2019)
 a = TrackEddiesObservations.load_file(
-    py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr")
+    py_eddy_tracker_sample.get_demo_path(
+        "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr"
+    )
 )
 c = TrackEddiesObservations.load_file(
-
py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) +# %% # Parameters -t0, t1 = a.period -step = 0.1 +step = 0.125 bins = ((-10, 37, step), (30, 46, step)) kwargs_pcolormesh = dict( - cmap="terrain_r", vmin=0, vmax=2, factor=1 / (step ** 2 * (t1 - t0)), name="count" + cmap="terrain_r", vmin=0, vmax=2, factor=1 / (a.nb_days * step**2), name="count" ) @@ -58,10 +63,32 @@ cb.set_label("Eddies by 1°^2 by day") g_c.vars["count"] = ratio -m = g_c.display(ax_ratio, name="count", vmin=0.1, vmax=10, norm=LogNorm(), cmap='coolwarm_r') +m = g_c.display( + ax_ratio, name="count", norm=LogNorm(vmin=0.1, vmax=10), cmap="coolwarm_r" +) plt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.2])) for ax in (ax_a, ax_c, ax_all, ax_ratio): ax.set_aspect("equal") ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) ax.grid() + +# %% +# Count Anticyclones as a function of lifetime +# -------------------------------------------- +# Count at the center's position + +fig = plt.figure(figsize=(12, 10)) +mask = a.lifetime >= 60 +ax_long = fig.add_axes([0.03, 0.53, 0.90, 0.45]) +g_a = a.grid_count(bins, center=True, filter=mask) +g_a.display(ax_long, **kwargs_pcolormesh) +ax_long.set_title(f"Anticyclones with lifetime >= 60 days ({mask.sum()} Obs)") +ax_short = fig.add_axes([0.03, 0.03, 0.90, 0.45]) +g_a = a.grid_count(bins, center=True, filter=~mask) +m = g_a.display(ax_short, **kwargs_pcolormesh) +ax_short.set_title(f"Anticyclones with lifetime < 60 days ({(~mask).sum()} Obs)") +for ax in (ax_short, ax_long): + ax.set_aspect("equal"), ax.grid() + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) +cb = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.015, 0.9])) diff --git a/examples/10_tracking_diagnostics/pet_geographic_stats.py b/examples/10_tracking_diagnostics/pet_geographic_stats.py index d7b79a43..a2e3f6b5 100644 --- a/examples/10_tracking_diagnostics/pet_geographic_stats.py +++ b/examples/10_tracking_diagnostics/pet_geographic_stats.py @@ -5,9 +5,10 @@ """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.tracking import TrackEddiesObservations import py_eddy_tracker_sample +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + def start_axes(title): fig = plt.figure(figsize=(13.5, 5)) @@ -21,10 +22,12 @@ def start_axes(title): # %% # Load an experimental med atlas over a period of 26 years (1993-2019), we merge the 2 datasets a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) a = a.merge(c) diff --git a/examples/10_tracking_diagnostics/pet_groups.py b/examples/10_tracking_diagnostics/pet_groups.py new file mode 100644 index 00000000..deedcc3f --- /dev/null +++ b/examples/10_tracking_diagnostics/pet_groups.py @@ -0,0 +1,81 @@ +""" +Groups distribution +=================== + +""" +from matplotlib import pyplot as plt +from numpy import arange, ones, percentile +import py_eddy_tracker_sample + +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + +# %% +# Load an experimental med atlas over a period of 26 years (1993-2019) +a = TrackEddiesObservations.load_file( + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) +) + 
+# %% +# Group distribution +groups = dict() +bins_time = [10, 20, 30, 60, 90, 180, 360, 100000] +for t0, t1 in zip(bins_time[:-1], bins_time[1:]): + groups[f"lifetime_{t0}_{t1}"] = lambda dataset, t0=t0, t1=t1: ( + dataset.lifetime >= t0 + ) * (dataset.lifetime < t1) +bins_percentile = arange(0, 100.0001, 5) + + +# %% +# Function to build stats +def stats_compilation(dataset, groups, field, bins, filter=None): + datas = dict(ref=dataset.bins_stat(field, bins=bins, mask=filter)[1], y=dict()) + for k, index in groups.items(): + i = dataset.merge_filters(filter, index) + x, datas["y"][k] = dataset.bins_stat(field, bins=bins, mask=i) + datas["x"], datas["bins"] = x, bins + return datas + + +def plot_stats(ax, bins, x, y, ref, box=False, cmap=None, percentiles=None, **kw): + base, ref = ones(x.shape) * 100.0, ref / 100.0 + x = arange(bins.shape[0]).repeat(2)[1:-1] if box else x + y0 = base + if cmap is not None: + cmap, nb_groups = plt.get_cmap(cmap), len(y) + keys = tuple(y.keys()) + for i, k in enumerate(keys[::-1]): + y1 = y0 - y[k] / ref + args = (y0.repeat(2), y1.repeat(2)) if box else (y0, y1) + if cmap is not None: + kw["color"] = cmap(1 - i / (nb_groups - 1)) + ax.fill_between(x, *args, label=k, **kw) + y0 = y1 + if percentiles: + for b in bins: + ax.axvline(b, **percentiles) + + +# %% +# Speed radius by track period +stats = stats_compilation( + a, groups, "radius_s", percentile(a.radius_s, bins_percentile) +) +fig = plt.figure() +ax = fig.add_subplot(111) +plot_stats(ax, **stats, cmap="magma", percentiles=dict(color="gray", ls="-.", lw=0.4)) +ax.set_xlabel("Speed radius (m)"), ax.set_ylabel("% of class"), ax.set_ylim(0, 100) +ax.grid(), ax.legend() + +# %% +# Amplitude by track period +stats = stats_compilation( + a, groups, "amplitude", percentile(a.amplitude, bins_percentile) +) +fig = plt.figure() +ax = fig.add_subplot(111) +plot_stats(ax, **stats, cmap="magma") +ax.set_xlabel("Amplitude (m)"), ax.set_ylabel("% of class"), ax.set_ylim(0, 100) +ax.grid(), ax.legend() diff --git a/examples/10_tracking_diagnostics/pet_histo.py b/examples/10_tracking_diagnostics/pet_histo.py index 68050016..abf97c38 100644 --- a/examples/10_tracking_diagnostics/pet_histo.py +++ b/examples/10_tracking_diagnostics/pet_histo.py @@ -4,17 +4,20 @@ """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.tracking import TrackEddiesObservations -import py_eddy_tracker_sample from numpy import arange +import py_eddy_tracker_sample + +from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) kwargs_a = dict(label="Anticyclonic", color="r", histtype="step", density=True) kwargs_c = dict(label="Cyclonic", color="b", histtype="step", density=True) @@ -25,7 +28,7 @@ for x0, name, title, xmax, factor, bins in zip( (0.4, 0.72, 0.08), - ("radius_s", "speed_average", "amplitude"), + ("speed_radius", "speed_average", "amplitude"), ("Speed radius (km)", "Speed average (cm/s)", "Amplitude (cm)"), (100, 50, 20), (0.001, 100, 100), diff --git a/examples/10_tracking_diagnostics/pet_lifetime.py 
b/examples/10_tracking_diagnostics/pet_lifetime.py index 7f2a22c9..4e2500fd 100644 --- a/examples/10_tracking_diagnostics/pet_lifetime.py +++ b/examples/10_tracking_diagnostics/pet_lifetime.py @@ -4,51 +4,56 @@ """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.tracking import TrackEddiesObservations +from numpy import arange, ones import py_eddy_tracker_sample -from numpy import arange + +from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) +nb_year = (a.period[1] - a.period[0] + 1) / 365.25 # %% -# Plot -fig = plt.figure() -ax_lifetime = fig.add_axes([0.05, 0.55, 0.4, 0.4]) -ax_cum_lifetime = fig.add_axes([0.55, 0.55, 0.4, 0.4]) -ax_ratio_lifetime = fig.add_axes([0.05, 0.05, 0.4, 0.4]) -ax_ratio_cum_lifetime = fig.add_axes([0.55, 0.05, 0.4, 0.4]) - -cum_a, bins, _ = ax_cum_lifetime.hist( - a["n"], histtype="step", bins=arange(0, 800, 1), label="Anticyclonic", color="r" -) -cum_c, bins, _ = ax_cum_lifetime.hist( - c["n"], histtype="step", bins=arange(0, 800, 1), label="Cyclonic", color="b" -) - -x = (bins[1:] + bins[:-1]) / 2.0 -ax_ratio_cum_lifetime.plot(x, cum_c / cum_a) - -nb_a, nb_c = cum_a[:-1] - cum_a[1:], cum_c[:-1] - cum_c[1:] -ax_lifetime.plot(x[1:], nb_a, label="Anticyclonic", color="r") -ax_lifetime.plot(x[1:], nb_c, label="Cyclonic", color="b") - -ax_ratio_lifetime.plot(x[1:], nb_c / nb_a) - -for ax in (ax_lifetime, ax_cum_lifetime, ax_ratio_cum_lifetime, ax_ratio_lifetime): - ax.set_xlim(0, 365) - if ax in (ax_lifetime, ax_cum_lifetime): - ax.set_ylim(1, None) - ax.set_yscale("log") - ax.legend() +# Setup axes +figure = plt.figure(figsize=(12, 8)) +ax_ratio_cum = figure.add_axes([0.55, 0.06, 0.42, 0.34]) +ax_ratio = figure.add_axes([0.07, 0.06, 0.46, 0.34]) +ax_cum = figure.add_axes([0.55, 0.43, 0.42, 0.54]) +ax = figure.add_axes([0.07, 0.43, 0.46, 0.54]) +ax.set_ylabel("Eddies by year") +ax_ratio.set_ylabel("Ratio Cyclonic/Anticyclonic") +for ax_ in (ax, ax_cum, ax_ratio_cum, ax_ratio): + ax_.set_xlim(0, 400) + if ax_ in (ax, ax_cum): + ax_.set_ylim(1e-1, 1e4), ax_.set_yscale("log") else: - ax.set_ylim(0, 2) - ax.set_ylabel("Ratio Cyclonic/Anticyclonic") - ax.set_xlabel("Lifetime (days)") - ax.grid() + ax_.set_xlabel("Lifetime in days (by week bins)") + ax_.set_ylim(0, 2) + ax_.axhline(1, color="g", lw=2) + ax_.grid() +ax_cum.xaxis.set_ticklabels([]), ax_cum.yaxis.set_ticklabels([]) +ax.xaxis.set_ticklabels([]), ax_ratio_cum.yaxis.set_ticklabels([]) + +# plot data +bin_hist = arange(7, 2000, 7) +x = (bin_hist[1:] + bin_hist[:-1]) / 2.0 +a_nb, c_nb = a.nb_obs_by_track, c.nb_obs_by_track +a_nb, c_nb = a_nb[a_nb != 0], c_nb[c_nb != 0] +w_a, w_c = ones(a_nb.shape) / nb_year, ones(c_nb.shape) / nb_year +kwargs_a = dict(histtype="step", bins=bin_hist, x=a_nb, color="r", weights=w_a) +kwargs_c = dict(histtype="step", bins=bin_hist, x=c_nb, color="b", weights=w_c) +cum_a, _, _ = ax_cum.hist(cumulative=-1, **kwargs_a) +cum_c, _, _ = ax_cum.hist(cumulative=-1, **kwargs_c) +nb_a, _, _ = ax.hist(label="Anticyclonic", **kwargs_a) +nb_c, _, _ = 
ax.hist(label="Cyclonic", **kwargs_c) +ax_ratio_cum.plot(x, cum_c / cum_a) +ax_ratio.plot(x, nb_c / nb_a) +ax.legend() diff --git a/examples/10_tracking_diagnostics/pet_normalised_lifetime.py b/examples/10_tracking_diagnostics/pet_normalised_lifetime.py new file mode 100644 index 00000000..1c84a8cc --- /dev/null +++ b/examples/10_tracking_diagnostics/pet_normalised_lifetime.py @@ -0,0 +1,78 @@ +""" +Normalised Eddy Lifetimes +========================= + +Example from Evan Mason +""" +from matplotlib import pyplot as plt +from numba import njit +from numpy import interp, linspace, zeros +from py_eddy_tracker_sample import get_demo_path + +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + + +# %% +@njit(cache=True) +def sum_profile(x_new, y, out): + """Will sum all interpolated given array""" + out += interp(x_new, linspace(0, 1, y.size), y) + + +class MyObs(TrackEddiesObservations): + def eddy_norm_lifetime(self, name, nb, factor=1): + """ + :param str,array name: Array or field name + :param int nb: size of output array + """ + y = self.parse_varname(name) + x = linspace(0, 1, nb) + out = zeros(nb, dtype=y.dtype) + nb_track = 0 + for i, b0, b1 in self.iter_on("track"): + y_ = y[i] + size_ = y_.size + if size_ == 0: + continue + sum_profile(x, y_, out) + nb_track += 1 + return x, out / nb_track * factor + + +# %% +# Load atlas +# ---------- +kw = dict(include_vars=("speed_radius", "amplitude", "track")) +a = MyObs.load_file( + get_demo_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr"), **kw +) +c = MyObs.load_file(get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr"), **kw) + +nb_max_a = a.nb_obs_by_track.max() +nb_max_c = c.nb_obs_by_track.max() + +# %% +# Compute normalised lifetime +# --------------------------- + +# Radius +AC_radius = a.eddy_norm_lifetime("speed_radius", nb=nb_max_a, factor=1e-3) +CC_radius = c.eddy_norm_lifetime("speed_radius", nb=nb_max_c, factor=1e-3) +# Amplitude +AC_amplitude = a.eddy_norm_lifetime("amplitude", nb=nb_max_a, factor=1e2) +CC_amplitude = c.eddy_norm_lifetime("amplitude", nb=nb_max_c, factor=1e2) + +# %% +# Figure +# ------ +fig, (ax0, ax1) = plt.subplots(nrows=2, figsize=(8, 6)) + +ax0.set_title("Normalised Mean Radius") +ax0.plot(*AC_radius), ax0.plot(*CC_radius) +ax0.set_ylabel("Radius (km)"), ax0.grid() +ax0.set_xlim(0, 1), ax0.set_ylim(0, None) + +ax1.set_title("Normalised Mean Amplitude") +ax1.plot(*AC_amplitude, label="AC"), ax1.plot(*CC_amplitude, label="CC") +ax1.set_ylabel("Amplitude (cm)"), ax1.grid(), ax1.legend() +_ = ax1.set_xlim(0, 1), ax1.set_ylim(0, None) diff --git a/examples/10_tracking_diagnostics/pet_pixel_used.py b/examples/10_tracking_diagnostics/pet_pixel_used.py index f08c41e4..75a826d6 100644 --- a/examples/10_tracking_diagnostics/pet_pixel_used.py +++ b/examples/10_tracking_diagnostics/pet_pixel_used.py @@ -1,28 +1,36 @@ """ Count pixel used -====================== +================ +Do Geo stat with frequency and compare with center count +method: :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_center_count.py` """ from matplotlib import pyplot as plt from matplotlib.colors import LogNorm -from py_eddy_tracker.observations.tracking import TrackEddiesObservations import py_eddy_tracker_sample +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + 
py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) -t0, t1 = a.period -step = 0.1 + +# %% +# Parameters +step = 0.125 bins = ((-10, 37, step), (30, 46, step)) kwargs_pcolormesh = dict( - cmap="terrain_r", vmin=0, vmax=0.75, factor=1 / (t1 - t0), name="count" + cmap="terrain_r", vmin=0, vmax=0.75, factor=1 / a.nb_days, name="count" ) + # %% # Plot fig = plt.figure(figsize=(12, 18.5)) @@ -54,10 +62,30 @@ plt.colorbar(m, cax=fig.add_axes([0.95, 0.27, 0.01, 0.7])) g_c.vars["count"] = ratio -m = g_c.display(ax_ratio, name="count", vmin=0.1, vmax=10, norm=LogNorm(), cmap='coolwarm_r') +m = g_c.display( + ax_ratio, name="count", norm=LogNorm(vmin=0.1, vmax=10), cmap="coolwarm_r" +) plt.colorbar(m, cax=fig.add_axes([0.95, 0.02, 0.01, 0.2])) for ax in (ax_a, ax_c, ax_all, ax_ratio): ax.set_aspect("equal") ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) ax.grid() + +# %% +# Count Anticyclones as a function of lifetime +# -------------------------------------------- +fig = plt.figure(figsize=(12, 10)) +mask = a.lifetime >= 60 +ax_long = fig.add_axes([0.03, 0.53, 0.90, 0.45]) +g_a = a.grid_count(bins, intern=True, filter=mask) +g_a.display(ax_long, **kwargs_pcolormesh) +ax_long.set_title(f"Anticyclones with lifetime >= 60 days ({mask.sum()} Obs)") +ax_short = fig.add_axes([0.03, 0.03, 0.90, 0.45]) +g_a = a.grid_count(bins, intern=True, filter=~mask) +m = g_a.display(ax_short, **kwargs_pcolormesh) +ax_short.set_title(f"Anticyclones with lifetime < 60 days ({(~mask).sum()} Obs)") +for ax in (ax_short, ax_long): + ax.set_aspect("equal"), ax.grid() + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) +cb = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.015, 0.9])) diff --git a/examples/10_tracking_diagnostics/pet_propagation.py b/examples/10_tracking_diagnostics/pet_propagation.py index 49a3b532..e6bc6c1b 100644 --- a/examples/10_tracking_diagnostics/pet_propagation.py +++ b/examples/10_tracking_diagnostics/pet_propagation.py @@ -1,42 +1,26 @@ """ Propagation Histogram -=================== +===================== """ from matplotlib import pyplot as plt -from py_eddy_tracker.observations.tracking import TrackEddiesObservations -from py_eddy_tracker.generic import distance +from numpy import arange, ones import py_eddy_tracker_sample -from numpy import arange, empty -from numba import njit - - -# %% -# We will create a function compile with numba, to compute a field which contains curvilign distance -@njit(cache=True) -def cum_distance_by_track(distance, track): - tr_previous = 0 - d_cum = 0 - new_distance = empty(track.shape, dtype=distance.dtype) - for i in range(distance.shape[0]): - tr = track[i] - if i != 0 and tr != tr_previous: - d_cum = 0 - new_distance[i] = d_cum - d_cum += distance[i] - tr_previous = tr - new_distance[i + 1] = d_cum - return new_distance +from py_eddy_tracker.generic import cumsum_by_track +from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - 
py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) +nb_year = (a.period[1] - a.period[0] + 1) / 365.25 # %% # Filtering position to remove noisy position @@ -45,49 +29,42 @@ def cum_distance_by_track(distance, track): # %% # Compute curvilign distance -d_a = distance(a.longitude[:-1], a.latitude[:-1], a.longitude[1:], a.latitude[1:]) -d_c = distance(c.longitude[:-1], c.latitude[:-1], c.longitude[1:], c.latitude[1:]) -d_a = cum_distance_by_track(d_a, a["track"]) / 1000.0 -d_c = cum_distance_by_track(d_c, c["track"]) / 1000.0 +i0, nb = a.index_from_track, a.nb_obs_by_track +d_a = cumsum_by_track(a.distance_to_next(), a.tracks)[(i0 - 1 + nb)[nb != 0]] / 1000.0 +i0, nb = c.index_from_track, c.nb_obs_by_track +d_c = cumsum_by_track(c.distance_to_next(), c.tracks)[(i0 - 1 + nb)[nb != 0]] / 1000.0 # %% -# Plot -fig = plt.figure() -ax_propagation = fig.add_axes([0.05, 0.55, 0.4, 0.4]) -ax_cum_propagation = fig.add_axes([0.55, 0.55, 0.4, 0.4]) -ax_ratio_propagation = fig.add_axes([0.05, 0.05, 0.4, 0.4]) -ax_ratio_cum_propagation = fig.add_axes([0.55, 0.05, 0.4, 0.4]) - -bins = arange(0, 1500, 10) -cum_a, bins, _ = ax_cum_propagation.hist( - d_a, histtype="step", bins=bins, label="Anticyclonic", color="r" -) -cum_c, bins, _ = ax_cum_propagation.hist( - d_c, histtype="step", bins=bins, label="Cyclonic", color="b" -) - -x = (bins[1:] + bins[:-1]) / 2.0 -ax_ratio_cum_propagation.plot(x, cum_c / cum_a) - -nb_a, nb_c = cum_a[:-1] - cum_a[1:], cum_c[:-1] - cum_c[1:] -ax_propagation.plot(x[1:], nb_a, label="Anticyclonic", color="r") -ax_propagation.plot(x[1:], nb_c, label="Cyclonic", color="b") - -ax_ratio_propagation.plot(x[1:], nb_c / nb_a) - -for ax in ( - ax_propagation, - ax_cum_propagation, - ax_ratio_cum_propagation, - ax_ratio_propagation, -): - ax.set_xlim(0, 1000) - if ax in (ax_propagation, ax_cum_propagation): - ax.set_ylim(1, None) - ax.set_yscale("log") - ax.legend() +# Setup axes +figure = plt.figure(figsize=(12, 8)) +ax_ratio_cum = figure.add_axes([0.55, 0.06, 0.42, 0.34]) +ax_ratio = figure.add_axes([0.07, 0.06, 0.46, 0.34]) +ax_cum = figure.add_axes([0.55, 0.43, 0.42, 0.54]) +ax = figure.add_axes([0.07, 0.43, 0.46, 0.54]) +ax.set_ylabel("Eddies by year") +ax_ratio.set_ylabel("Ratio Cyclonic/Anticyclonic") +for ax_ in (ax, ax_cum, ax_ratio_cum, ax_ratio): + ax_.set_xlim(0, 1000) + if ax_ in (ax, ax_cum): + ax_.set_ylim(1e-1, 1e4), ax_.set_yscale("log") else: - ax.set_ylim(0, 2) - ax.set_ylabel("Ratio Cyclonic/Anticyclonic") - ax.set_xlabel("Propagation (km)") - ax.grid() + ax_.set_xlabel("Propagation in km (with bins of 20 km)") + ax_.set_ylim(0, 2) + ax_.axhline(1, color="g", lw=2) + ax_.grid() +ax_cum.xaxis.set_ticklabels([]), ax_cum.yaxis.set_ticklabels([]) +ax.xaxis.set_ticklabels([]), ax_ratio_cum.yaxis.set_ticklabels([]) + +# plot data +bin_hist = arange(0, 2000, 20) +x = (bin_hist[1:] + bin_hist[:-1]) / 2.0 +w_a, w_c = ones(d_a.shape) / nb_year, ones(d_c.shape) / nb_year +kwargs_a = dict(histtype="step", bins=bin_hist, x=d_a, color="r", weights=w_a) +kwargs_c = dict(histtype="step", bins=bin_hist, x=d_c, color="b", weights=w_c) +cum_a, _, _ = ax_cum.hist(cumulative=-1, **kwargs_a) +cum_c, _, _ = ax_cum.hist(cumulative=-1, **kwargs_c) +nb_a, _, _ = ax.hist(label="Anticyclonic", **kwargs_a) +nb_c, _, _ = ax.hist(label="Cyclonic", **kwargs_c) +ax_ratio_cum.plot(x, cum_c / cum_a) +ax_ratio.plot(x, nb_c / nb_a) +ax.legend() diff --git a/examples/12_external_data/README.rst 
b/examples/12_external_data/README.rst new file mode 100644 index 00000000..7ecbe30b --- /dev/null +++ b/examples/12_external_data/README.rst @@ -0,0 +1,4 @@ +External data +============= + +Eddies comparison with external data diff --git a/examples/12_external_data/pet_SST_collocation.py b/examples/12_external_data/pet_SST_collocation.py new file mode 100644 index 00000000..defe00df --- /dev/null +++ b/examples/12_external_data/pet_SST_collocation.py @@ -0,0 +1,128 @@ +""" +Collocating external data +========================= + +Script will use py-eddy-tracker methods to upload external data (sea surface temperature, SST) +in a common structure with altimetry. + +Figures higlights the different steps. +""" + +from datetime import datetime + +from matplotlib import pyplot as plt + +from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset + +date = datetime(2016, 7, 7) + +filename_alt = data.get_demo_path( + f"dt_blacksea_allsat_phy_l4_{date:%Y%m%d}_20200801.nc" +) +filename_sst = data.get_demo_path( + f"{date:%Y%m%d}000000-GOS-L4_GHRSST-SSTfnd-OISST_HR_REP-BLK-v02.0-fv01.0.nc" +) +var_name_sst = "analysed_sst" + +extent = [27, 42, 40.5, 47] + +# %% +# Loading data +# ------------ +sst = RegularGridDataset(filename=filename_sst, x_name="lon", y_name="lat") +alti = RegularGridDataset( + data.get_demo_path(filename_alt), x_name="longitude", y_name="latitude" +) +# We can use `Grid` tools to interpolate ADT on the sst grid +sst.regrid(alti, "sla") +sst.add_uv("sla") + + +# %% +# Functions to initiate figure axes +def start_axes(title, extent=extent): + fig = plt.figure(figsize=(13, 6), dpi=120) + ax = fig.add_axes([0.03, 0.05, 0.89, 0.91]) + ax.set_xlim(extent[0], extent[1]) + ax.set_ylim(extent[2], extent[3]) + ax.set_title(title) + ax.set_aspect("equal") + return ax + + +def update_axes(ax, mappable=None, unit=""): + ax.grid() + if mappable: + cax = ax.figure.add_axes([0.93, 0.05, 0.01, 0.9], title=unit) + plt.colorbar(mappable, cax=cax) + + +# %% +# ADT first display +# ----------------- +ax = start_axes("SLA", extent=extent) +m = sst.display(ax, "sla", vmin=0.05, vmax=0.35) +update_axes(ax, m, unit="[m]") + +# %% +# SST first display +# ----------------- + +# %% +# We can now plot SST from `sst` +ax = start_axes("SST") +m = sst.display(ax, "analysed_sst", vmin=295, vmax=300) +update_axes(ax, m, unit="[°K]") + +# %% +ax = start_axes("SST") +m = sst.display(ax, "analysed_sst", vmin=295, vmax=300) +u, v = sst.grid("u").T, sst.grid("v").T +ax.quiver(sst.x_c[::3], sst.y_c[::3], u[::3, ::3], v[::3, ::3], scale=10) +update_axes(ax, m, unit="[°K]") + +# %% +# Now, with eddy contours, and displaying SST anomaly +sst.bessel_high_filter("analysed_sst", 400) + +# %% +# Eddy detection +sst.bessel_high_filter("sla", 400) +# ADT filtered +ax = start_axes("SLA", extent=extent) +m = sst.display(ax, "sla", vmin=-0.1, vmax=0.1) +update_axes(ax, m, unit="[m]") +a, c = sst.eddy_identification("sla", "u", "v", date, 0.002) + +# %% +kwargs_a = dict(lw=2, label="Anticyclonic", ref=-10, color="b") +kwargs_c = dict(lw=2, label="Cyclonic", ref=-10, color="r") +ax = start_axes("SST anomaly") +m = sst.display(ax, "analysed_sst", vmin=-1, vmax=1) +a.display(ax, **kwargs_a), c.display(ax, **kwargs_c) +ax.legend() +update_axes(ax, m, unit="[°K]") + +# %% +# Example of post-processing +# -------------------------- +# Get mean of sst anomaly_high in each internal contour +anom_a = a.interp_grid(sst, "analysed_sst", method="mean", intern=True) +anom_c = c.interp_grid(sst, "analysed_sst", 
method="mean", intern=True) + +# %% +# Are cyclonic (resp. anticyclonic) eddies generally associated with positive (resp. negative) SST anomaly ? +fig = plt.figure(figsize=(7, 5)) +ax = fig.add_axes([0.05, 0.05, 0.90, 0.90]) +ax.set_xlabel("SST anomaly") +ax.set_xlim([-1, 1]) +ax.set_title("Histograms of SST anomalies") +ax.hist( + anom_a, 5, alpha=0.5, color="b", label="Anticyclonic (mean:%s)" % (anom_a.mean()) +) +ax.hist(anom_c, 5, alpha=0.5, color="r", label="Cyclonic (mean:%s)" % (anom_c.mean())) +ax.legend() + +# %% +# Not clearly so in that case .. diff --git a/examples/12_external_data/pet_drifter_loopers.py b/examples/12_external_data/pet_drifter_loopers.py new file mode 100644 index 00000000..5266db7b --- /dev/null +++ b/examples/12_external_data/pet_drifter_loopers.py @@ -0,0 +1,153 @@ +""" +Colocate looper with eddy from altimetry +======================================== + +All loopers data used in this example are a subset from the dataset described in this article +[Lumpkin, R. : Global characteristics of coherent vortices from surface drifter trajectories](https://doi.org/10.1002/2015JC011435) +""" + +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +import numpy as np +import py_eddy_tracker_sample + +from py_eddy_tracker import data +from py_eddy_tracker.appli.gui import Anim +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +def start_axes(title): + fig = plt.figure(figsize=(13, 5)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], aspect="equal") + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) + ax.set_title(title, weight="bold") + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) + + +# %% +# Load eddies dataset +cyclonic_eddies = TrackEddiesObservations.load_file( + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") +) +anticyclonic_eddies = TrackEddiesObservations.load_file( + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) +) + +# %% +# Load loopers dataset +loopers_med = TrackEddiesObservations.load_file( + data.get_demo_path("loopers_lumpkin_med.nc") +) + +# %% +# Global view +# =========== +ax = start_axes("All drifters available in Med from Lumpkin dataset") +loopers_med.plot(ax, lw=0.5, color="r", ref=-10) +update_axes(ax) + +# %% +# One segment of drifter +# ====================== +# +# Get a drifter segment (the indexes used have no correspondance with the original dataset). 
+looper = loopers_med.extract_ids((3588,)) +fig = plt.figure(figsize=(16, 6)) +ax = fig.add_subplot(111, aspect="equal") +looper.plot(ax, lw=0.5, label="Original position of drifter") +looper_filtered = looper.copy() +looper_filtered.position_filter(1, 13) +s = looper_filtered.scatter( + ax, + "time", + cmap=plt.get_cmap("Spectral_r", 20), + label="Filtered position of drifter", +) +plt.colorbar(s).set_label("time (days from 1/1/1950)") +ax.legend() +ax.grid() + +# %% +# Try to find a detected eddies with adt at same place. We used filtered track to simulate an eddy center +match = looper_filtered.close_tracks( + anticyclonic_eddies, method="close_center", delta=0.1, nb_obs_min=50 +) +fig = plt.figure(figsize=(16, 6)) +ax = fig.add_subplot(111, aspect="equal") +looper.plot(ax, lw=0.5, label="Original position of drifter") +looper_filtered.plot(ax, lw=1.5, label="Filtered position of drifter") +match.plot(ax, lw=1.5, label="Matched eddy") +ax.legend() +ax.grid() + +# %% +# Display radius of this 2 datasets. +fig = plt.figure(figsize=(20, 8)) +ax = fig.add_subplot(111) +ax.plot(looper.time, looper.radius_s / 1e3, label="loopers") +looper_radius = looper.copy() +looper_radius.median_filter(1, "time", "radius_s", inplace=True) +looper_radius.loess_filter(13, "time", "radius_s", inplace=True) +ax.plot( + looper_radius.time, + looper_radius.radius_s / 1e3, + label="loopers (filtered half window 13 days)", +) +ax.plot(match.time, match.radius_s / 1e3, label="altimetry") +match_radius = match.copy() +match_radius.median_filter(1, "time", "radius_s", inplace=True) +match_radius.loess_filter(13, "time", "radius_s", inplace=True) +ax.plot( + match_radius.time, + match_radius.radius_s / 1e3, + label="altimetry (filtered half window 13 days)", +) +ax.set_ylabel("radius(km)"), ax.set_ylim(0, 100) +ax.legend() +ax.set_title("Radius from loopers and altimeter") +ax.grid() + + +# %% +# Animation of a drifter and its colocated eddy +def update(frame): + # We display last 5 days of loopers trajectory + m = (looper.time < frame) * (looper.time > (frame - 5)) + anim.func_animation(frame) + line.set_data(looper.lon[m], looper.lat[m]) + + +anim = Anim(match, intern=True, figsize=(8, 8), cmap="magma_r", nb_step=10, dpi=75) +# mappable to show drifter in red +line = anim.ax.plot([], [], "r", lw=4, zorder=100)[0] +anim.fig.suptitle("") +_ = VideoAnimation(anim.fig, update, frames=np.arange(*anim.period, 1), interval=125) diff --git a/examples/14_generic_tools/README.rst b/examples/14_generic_tools/README.rst new file mode 100644 index 00000000..295d55fe --- /dev/null +++ b/examples/14_generic_tools/README.rst @@ -0,0 +1,4 @@ +Polygon tools +============= + +Method to work with contour \ No newline at end of file diff --git a/examples/14_generic_tools/pet_fit_contour.py b/examples/14_generic_tools/pet_fit_contour.py new file mode 100644 index 00000000..2d3b6dc9 --- /dev/null +++ b/examples/14_generic_tools/pet_fit_contour.py @@ -0,0 +1,62 @@ +""" +Contour fit +=========== + +Two type of fit : + - Ellipse + - Circle + +In the two case we use a least square algorithm +""" + +from matplotlib import pyplot as plt +from numpy import cos, linspace, radians, sin + +from py_eddy_tracker import data +from py_eddy_tracker.generic import coordinates_to_local, local_to_coordinates +from py_eddy_tracker.observations.observation import EddiesObservations +from py_eddy_tracker.poly import fit_circle_, fit_ellipse + +# %% +# Load example identification file +a = 
EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) + + +# %% +# Function to draw circle or ellipse from parameter +def build_circle(x0, y0, r): + angle = radians(linspace(0, 360, 50)) + x_norm, y_norm = cos(angle), sin(angle) + return local_to_coordinates(x_norm * r, y_norm * r, x0, y0) + + +def build_ellipse(x0, y0, a, b, theta): + angle = radians(linspace(0, 360, 50)) + x = a * cos(theta) * cos(angle) - b * sin(theta) * sin(angle) + y = a * sin(theta) * cos(angle) + b * cos(theta) * sin(angle) + return local_to_coordinates(x, y, x0, y0) + + +# %% +# Plot fitted circle or ellipse on stored contour +xs, ys = a.contour_lon_s, a.contour_lat_s + +fig = plt.figure(figsize=(15, 15)) + +j = 1 +for i in range(0, 800, 30): + x, y = xs[i], ys[i] + x0_, y0_ = x.mean(), y.mean() + x_, y_ = coordinates_to_local(x, y, x0_, y0_) + ax = fig.add_subplot(4, 4, j) + ax.grid(), ax.set_aspect("equal") + ax.plot(x, y, label="store", color="black") + x0, y0, a, b, theta = fit_ellipse(x_, y_) + x0, y0 = local_to_coordinates(x0, y0, x0_, y0_) + ax.plot(*build_ellipse(x0, y0, a, b, theta), label="ellipse", color="green") + x0, y0, radius, shape_error = fit_circle_(x_, y_) + x0, y0 = local_to_coordinates(x0, y0, x0_, y0_) + ax.plot(*build_circle(x0, y0, radius), label="circle", color="red", lw=0.5) + if j == 16: + break + j += 1 diff --git a/examples/14_generic_tools/pet_visvalingam.py b/examples/14_generic_tools/pet_visvalingam.py new file mode 100644 index 00000000..736e8852 --- /dev/null +++ b/examples/14_generic_tools/pet_visvalingam.py @@ -0,0 +1,96 @@ +""" +Visvalingam algorithm +===================== +""" +from matplotlib import pyplot as plt +import matplotlib.animation as animation +from numba import njit +from numpy import array, empty + +from py_eddy_tracker import data +from py_eddy_tracker.generic import uniform_resample +from py_eddy_tracker.observations.observation import EddiesObservations +from py_eddy_tracker.poly import vertice_overlap, visvalingam + + +@njit(cache=True) +def visvalingam_polys(x, y, nb_pt): + nb = x.shape[0] + x_new = empty((nb, nb_pt), dtype=x.dtype) + y_new = empty((nb, nb_pt), dtype=y.dtype) + for i in range(nb): + x_new[i], y_new[i] = visvalingam(x[i], y[i], nb_pt) + return x_new, y_new + + +@njit(cache=True) +def uniform_resample_polys(x, y, nb_pt): + nb = x.shape[0] + x_new = empty((nb, nb_pt), dtype=x.dtype) + y_new = empty((nb, nb_pt), dtype=y.dtype) + for i in range(nb): + x_new[i], y_new[i] = uniform_resample(x[i], y[i], fixed_size=nb_pt) + return x_new, y_new + + +def update_line(num): + nb = 50 - num - 20 + x_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb) + for i, (x_, y_) in enumerate(zip(x_v, y_v)): + lines_v[i].set_data(x_, y_) + x_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb) + for i, (x_, y_) in enumerate(zip(x_u, y_u)): + lines_u[i].set_data(x_, y_) + scores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0 + scores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0 + for i, (s_v, s_u) in enumerate(zip(scores_v, scores_u)): + texts[i].set_text(f"Score uniform {s_u:.1f} %\nScore visvalingam {s_v:.1f} %") + title.set_text(f"{nb} points by contour in place of 50") + return (title, *lines_u, *lines_v, *texts) + + +# %% +# Load detection files +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) +a = a.extract_with_mask((abs(a.lat) < 66) * (abs(a.radius_e) > 80e3)) + +nb_pt = 10 +x_v, y_v = visvalingam_polys(a.contour_lon_e, 
a.contour_lat_e, nb_pt) +x_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb_pt) +scores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0 +scores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0 +d_6 = scores_v - scores_u +nb_pt = 18 +x_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb_pt) +x_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb_pt) +scores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0 +scores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0 +d_12 = scores_v - scores_u +a = a.index(array((d_6.argmin(), d_6.argmax(), d_12.argmin(), d_12.argmax()))) + + +# %% +fig = plt.figure() +axs = [ + fig.add_subplot(221), + fig.add_subplot(222), + fig.add_subplot(223), + fig.add_subplot(224), +] +lines_u, lines_v, texts, score_text = list(), list(), list(), list() +for i, obs in enumerate(a): + axs[i].set_aspect("equal") + axs[i].grid() + axs[i].set_xticklabels([]), axs[i].set_yticklabels([]) + axs[i].plot( + obs["contour_lon_e"], obs["contour_lat_e"], "r", lw=6, label="Original contour" + ) + lines_v.append(axs[i].plot([], [], color="limegreen", lw=4, label="visvalingam")[0]) + lines_u.append( + axs[i].plot([], [], color="black", lw=2, label="uniform resampling")[0] + ) + texts.append(axs[i].set_title("", fontsize=8)) +axs[0].legend(fontsize=8) +title = fig.suptitle("") +anim = animation.FuncAnimation(fig, update_line, 27) +anim diff --git a/examples/16_network/README.rst b/examples/16_network/README.rst new file mode 100644 index 00000000..49bdc3ab --- /dev/null +++ b/examples/16_network/README.rst @@ -0,0 +1,6 @@ +Network +======= + +.. warning:: + + Network is under development, API could move quickly! 
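+
+Most examples below start from a network file. A minimal loading sketch, using
+the ``network_med.nc`` demo file shipped with the package:
+
+.. code-block:: python
+
+    from py_eddy_tracker.data import get_demo_path
+    from py_eddy_tracker.observations.network import NetworkObservations
+
+    n = NetworkObservations.load_file(get_demo_path("network_med.nc"))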
diff --git a/examples/16_network/pet_atlas.py b/examples/16_network/pet_atlas.py new file mode 100644 index 00000000..48b374e2 --- /dev/null +++ b/examples/16_network/pet_atlas.py @@ -0,0 +1,187 @@ +""" +Network Analysis +================ +""" +from matplotlib import pyplot as plt +from numpy import ma + +from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.network import NetworkObservations + +n = NetworkObservations.load_file( + get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018_err70_filt500_order1/Anticyclonic_network.nc" + ) +) +# %% +# Parameters +step = 1 / 10.0 +bins = ((-10, 37, step), (30, 46, step)) +kw_time = dict(cmap="terrain_r", factor=100.0 / n.nb_days, name="count") +kw_ratio = dict(cmap=plt.get_cmap("YlGnBu_r", 10)) + + +# %% +# Functions +def start_axes(title): + fig = plt.figure(figsize=(13, 5)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES) + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) + ax.set_aspect("equal") + ax.set_title(title, weight="bold") + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) + + +# %% +# All +# --- +# Display the % of time each pixel (1/10°) is within an anticyclonic network +ax = start_axes("") +g_all = n.grid_count(bins) +m = g_all.display(ax, **kw_time, vmin=0, vmax=75) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Network longer than 10 days +# --------------------------- +# Display the % of time each pixel (1/10°) is within an anticyclonic network +# which total lifetime in longer than 10 days +ax = start_axes("") +n10 = n.longer_than(10) +g_10 = n10.grid_count(bins) +m = g_10.display(ax, **kw_time, vmin=0, vmax=75) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Ratio +# ^^^^^ +# Ratio between the longer and total presence +ax = start_axes("") +g_ = g_10.vars["count"] * 100.0 / g_all.vars["count"] +m = g_10.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# Blue = mostly short networks +# +# Network longer than 20 days +# --------------------------- +# Display the % of time each pixel (1/10°) is within an anticyclonic network +# which total lifetime is longer than 20 days +ax = start_axes("") +n20 = n.longer_than(20) +g_20 = n20.grid_count(bins) +m = g_20.display(ax, **kw_time, vmin=0, vmax=75) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Ratio +# ^^^^^ +# Ratio between the longer and total presence +ax = start_axes("") +g_ = g_20.vars["count"] * 100.0 / g_all.vars["count"] +m = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# Now we will hide pixel which are used less than 365 times +g_ = ma.array( + g_20.vars["count"] * 100.0 / g_all.vars["count"], mask=g_all.vars["count"] < 365 +) +ax = start_axes("") +m = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") +# %% +# Now we will hide pixel which are used more than 365 times +ax = start_axes("") +g_ = ma.array( + g_20.vars["count"] * 100.0 / g_all.vars["count"], mask=g_all.vars["count"] >= 365 +) +m = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# Coastal areas are mostly populated by short networks +# +# All merging +# 
----------- +# Display the occurence of merging events +ax = start_axes("") +g_all_merging = n.merging_event().grid_count(bins) +m = g_all_merging.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Ratio merging events / eddy presence +ax = start_axes("") +g_ = g_all_merging.vars["count"] * 100.0 / g_all.vars["count"] +m = g_all_merging.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# Merging in networks longer than 10 days, with dead end remove (shorter than 10 observations) +# -------------------------------------------------------------------------------------------- +ax = start_axes("") +n10_ = n10.copy() +n10_.remove_dead_end(nobs=10) +merger = n10_.merging_event() +g_10_merging = merger.grid_count(bins) +m = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Merging in networks longer than 10 days +# --------------------------------------- +ax = start_axes("") +merger = n10.merging_event() +g_10_merging = merger.grid_count(bins) +m = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") +# %% +# Ratio merging events / eddy presence +ax = start_axes("") +g_ = ma.array( + g_10_merging.vars["count"] * 100.0 / g_10.vars["count"], + mask=g_10.vars["count"] < 365, +) +m = g_10_merging.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# All splitting +# ------------- +# Display the occurence of splitting events +ax = start_axes("") +g_all_splitting = n.splitting_event().grid_count(bins) +m = g_all_splitting.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Ratio splitting events / eddy presence +ax = start_axes("") +g_ = g_all_splitting.vars["count"] * 100.0 / g_all.vars["count"] +m = g_all_splitting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# splitting in networks longer than 10 days +# ----------------------------------------- +ax = start_axes("") +g_10_splitting = n10.splitting_event().grid_count(bins) +m = g_10_splitting.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") +# %% +ax = start_axes("") +g_ = ma.array( + g_10_splitting.vars["count"] * 100.0 / g_10.vars["count"], + mask=g_10.vars["count"] < 365, +) +m = g_10_splitting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") diff --git a/examples/16_network/pet_follow_particle.py b/examples/16_network/pet_follow_particle.py new file mode 100644 index 00000000..6815fb6e --- /dev/null +++ b/examples/16_network/pet_follow_particle.py @@ -0,0 +1,186 @@ +""" +Follow particle +=============== + +""" +import re + +from matplotlib import colors, pyplot as plt +from matplotlib.animation import FuncAnimation +from numpy import arange, meshgrid, ones, unique, zeros + +from py_eddy_tracker import start_logger +from py_eddy_tracker.appli.gui import Anim +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection +from py_eddy_tracker.observations.groups import particle_candidate +from py_eddy_tracker.observations.network import NetworkObservations + +start_logger().setLevel("ERROR") + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, 
**kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651) +n = n.extract_with_mask((n.time >= 20180) * (n.time <= 20269)) +n.remove_dead_end(nobs=0, ndays=10) +n = n.remove_trash() +n.numbering_segment() +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + heigth="adt", +) + +# %% +# Schema +# ------ +fig = plt.figure(figsize=(12, 6)) +ax = fig.add_axes([0.05, 0.05, 0.9, 0.9]) +_ = n.display_timeline(ax, field="longitude", marker="+", lw=2, markersize=5) + +# %% +# Animation +# --------- +# Particle settings +t_snapshot = 20200 +step = 1 / 50.0 +x, y = meshgrid(arange(20, 36, step), arange(30, 46, step)) +N = 6 +x_f, y_f = x[::N, ::N].copy(), y[::N, ::N].copy() +x, y = x.reshape(-1), y.reshape(-1) +x_f, y_f = x_f.reshape(-1), y_f.reshape(-1) +n_ = n.extract_with_mask(n.time == t_snapshot) +index = n_.contains(x, y, intern=True) +m = index != -1 +index = n_.segment[index[m]] +index_ = unique(index) +x, y = x[m], y[m] +m = ~n_.inside(x_f, y_f, intern=True) +x_f, y_f = x_f[m], y_f[m] + +# %% +# Animation +cmap = colors.ListedColormap(list(n.COLORS), name="from_list", N=n.segment.max() + 1) +a = Anim( + n, + intern=False, + figsize=(12, 6), + nb_step=1, + dpi=60, + field_color="segment", + field_txt="segment", + cmap=cmap, +) +a.fig.suptitle(""), a.ax.set_xlim(24, 36), a.ax.set_ylim(30, 36) +a.txt.set_position((25, 31)) + +step = 0.25 +kw_p = dict( + nb_step=2, + time_step=86400 * step * 0.5, + t_init=t_snapshot - 2 * step, + u_name="u", + v_name="v", +) + +mappables = dict() +particules = c.advect(x, y, **kw_p) +filament = c.filament(x_f, y_f, **kw_p, filament_size=3) +kw = dict(ls="", marker=".", markersize=0.25) +for k in index_: + m = k == index + mappables[k] = a.ax.plot([], [], color=cmap(k), **kw)[0] +m_filament = a.ax.plot([], [], lw=0.25, color="gray")[0] + + +def update(frame): + tt, xt, yt = particules.__next__() + for k, mappable in mappables.items(): + m = index == k + mappable.set_data(xt[m], yt[m]) + tt, xt, yt = filament.__next__() + m_filament.set_data(xt, yt) + if frame % 1 == 0: + a.func_animation(frame) + + +ani = VideoAnimation(a.fig, update, frames=arange(20200, 20269, step), interval=200) + + +# %% +# Particle advection +# ^^^^^^^^^^^^^^^^^^ +# Advection from speed contour to speed contour (default) + +step = 1 / 60.0 + +t_start, t_end = int(n.period[0]), int(n.period[1]) +dt = 14 + +shape = (n.obs.size, 2) +# Forward run +i_target_f, pct_target_f = -ones(shape, dtype="i4"), zeros(shape, dtype="i1") +for t in arange(t_start, t_end - dt): + particle_candidate(c, n, step, t, i_target_f, pct_target_f, n_days=dt) + +# Backward run +i_target_b, pct_target_b = -ones(shape, dtype="i4"), zeros(shape, dtype="i1") +for t in arange(t_start + dt, t_end): + particle_candidate(c, n, step, t, i_target_b, pct_target_b, n_days=-dt) + +# %% +fig = plt.figure(figsize=(10, 10)) +ax_1st_b = fig.add_axes([0.05, 0.52, 0.45, 0.45]) +ax_2nd_b = fig.add_axes([0.05, 0.05, 0.45, 0.45]) +ax_1st_f = 
fig.add_axes([0.52, 0.52, 0.45, 0.45]) +ax_2nd_f = fig.add_axes([0.52, 0.05, 0.45, 0.45]) +ax_1st_b.set_title("Backward advection for each time step") +ax_1st_f.set_title("Forward advection for each time step") +ax_1st_b.set_ylabel("Color -> First target\nLatitude") +ax_2nd_b.set_ylabel("Color -> Secondary target\nLatitude") +ax_2nd_b.set_xlabel("Julian days"), ax_2nd_f.set_xlabel("Julian days") +ax_1st_f.set_yticks([]), ax_2nd_f.set_yticks([]) +ax_1st_f.set_xticks([]), ax_1st_b.set_xticks([]) + + +def color_alpha(target, pct, vmin=5, vmax=80): + color = cmap(n.segment[target]) + # We will hide under 5 % and from 80% to 100 % it will be 1 + alpha = (pct - vmin) / (vmax - vmin) + alpha[alpha < 0] = 0 + alpha[alpha > 1] = 1 + color[:, 3] = alpha + return color + + +kw = dict( + name=None, yfield="longitude", event=False, zorder=-100, s=(n.speed_area / 20e6) +) +n.scatter_timeline(ax_1st_b, c=color_alpha(i_target_b.T[0], pct_target_b.T[0]), **kw) +n.scatter_timeline(ax_2nd_b, c=color_alpha(i_target_b.T[1], pct_target_b.T[1]), **kw) +n.scatter_timeline(ax_1st_f, c=color_alpha(i_target_f.T[0], pct_target_f.T[0]), **kw) +n.scatter_timeline(ax_2nd_f, c=color_alpha(i_target_f.T[1], pct_target_f.T[1]), **kw) +for ax in (ax_1st_b, ax_2nd_b, ax_1st_f, ax_2nd_f): + n.display_timeline(ax, field="longitude", marker="+", lw=2, markersize=5) + ax.grid() diff --git a/examples/16_network/pet_group_anim.py b/examples/16_network/pet_group_anim.py new file mode 100644 index 00000000..f2d439ed --- /dev/null +++ b/examples/16_network/pet_group_anim.py @@ -0,0 +1,154 @@ +""" +Network group process +===================== +""" +from datetime import datetime + +# sphinx_gallery_thumbnail_number = 2 +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from matplotlib.colors import ListedColormap +from numba import njit +from numpy import arange, array, empty, ones + +from py_eddy_tracker import data +from py_eddy_tracker.generic import flatten_line_matrix +from py_eddy_tracker.observations.network import Network +from py_eddy_tracker.observations.observation import EddiesObservations + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +NETWORK_GROUPS = list() + + +@njit(cache=True) +def apply_replace(x, x0, x1): + nb = x.shape[0] + for i in range(nb): + if x[i] == x0: + x[i] = x1 + + +# %% +# Modified class to catch group process at each step in order to illustrate processing +class MyNetwork(Network): + def get_group_array(self, results, nb_obs): + """With a loop on all pair of index, we will label each obs with a group + number + """ + nb_obs = array(nb_obs, dtype="u4") + day_start = nb_obs.cumsum() - nb_obs + gr = empty(nb_obs.sum(), dtype="u4") + gr[:] = self.NOGROUP + + id_free = 1 + for i, j, ii, ij in results: + gr_i = gr[slice(day_start[i], day_start[i] + nb_obs[i])] + gr_j = gr[slice(day_start[j], day_start[j] + nb_obs[j])] + # obs with no groups + m = (gr_i[ii] == self.NOGROUP) * (gr_j[ij] == self.NOGROUP) + nb_new = m.sum() + gr_i[ii[m]] = 
gr_j[ij[m]] = arange(id_free, id_free + nb_new) + id_free += nb_new + # associate obs with no group with obs with group + m = (gr_i[ii] != self.NOGROUP) * (gr_j[ij] == self.NOGROUP) + gr_j[ij[m]] = gr_i[ii[m]] + m = (gr_i[ii] == self.NOGROUP) * (gr_j[ij] != self.NOGROUP) + gr_i[ii[m]] = gr_j[ij[m]] + # case where 2 obs have a different group + m = gr_i[ii] != gr_j[ij] + if m.any(): + # Merge of group, ref over etu + for i_, j_ in zip(ii[m], ij[m]): + g0, g1 = gr_i[i_], gr_j[j_] + apply_replace(gr, g0, g1) + NETWORK_GROUPS.append((i, j, gr.copy())) + return gr + + +# %% +# Movie period +t0 = (datetime(2005, 5, 1) - datetime(1950, 1, 1)).days +t1 = (datetime(2005, 6, 1) - datetime(1950, 1, 1)).days + +# %% +# Get data from period and area +e = EddiesObservations.load_file(data.get_demo_path("network_med.nc")) +e = e.extract_with_mask((e.time >= t0) * (e.time < t1)).extract_with_area( + dict(llcrnrlon=25, urcrnrlon=35, llcrnrlat=31, urcrnrlat=37.5) +) +# %% +# Reproduce individual daily identification(for demonstration) +EDDIES_BY_DAYS = list() +for i, b0, b1 in e.iter_on("time"): + EDDIES_BY_DAYS.append(e.index(i)) +# need for display +e = EddiesObservations.concatenate(EDDIES_BY_DAYS) + +# %% +# Run network building group to intercept every step +n = MyNetwork.from_eddiesobservations(EDDIES_BY_DAYS, window=7) +_ = n.group_observations(minimal_area=True) + + +# %% +def update(frame): + i_current, i_match, gr = NETWORK_GROUPS[frame] + current = EDDIES_BY_DAYS[i_current] + x = flatten_line_matrix(current.contour_lon_e) + y = flatten_line_matrix(current.contour_lat_e) + current_contour.set_data(x, y) + match = EDDIES_BY_DAYS[i_match] + x = flatten_line_matrix(match.contour_lon_e) + y = flatten_line_matrix(match.contour_lat_e) + matched_contour.set_data(x, y) + groups.set_array(gr) + txt.set_text(f"Day {i_current} match with day {i_match}") + s = 80 * ones(gr.shape) + s[gr == 0] = 4 + groups.set_sizes(s) + + +# %% +# Anim +# ---- +fig = plt.figure(figsize=(16, 9), dpi=50) +ax = fig.add_axes([0, 0, 1, 1]) +ax.set_aspect("equal"), ax.grid(), ax.set_xlim(26, 34), ax.set_ylim(31, 35.5) +cmap = ListedColormap(["gray", *e.COLORS[:-1]], name="from_list", N=30) +kw_s = dict(cmap=cmap, vmin=0, vmax=30) +groups = ax.scatter(e.lon, e.lat, c=NETWORK_GROUPS[0][2], **kw_s) +current_contour = ax.plot([], [], "k", lw=2, label="Current contour")[0] +matched_contour = ax.plot([], [], "r", lw=1, ls="--", label="Candidate contour")[0] +txt = ax.text(29, 35, "", fontsize=25) +ax.legend(fontsize=25) +ani = VideoAnimation(fig, update, frames=len(NETWORK_GROUPS), interval=220) + +# %% +# Final Result +# ------------ +fig = plt.figure(figsize=(16, 9)) +ax = fig.add_axes([0, 0, 1, 1]) +ax.set_aspect("equal"), ax.grid(), ax.set_xlim(26, 34), ax.set_ylim(31, 35.5) +_ = ax.scatter(e.lon, e.lat, c=NETWORK_GROUPS[-1][2], **kw_s) diff --git a/examples/16_network/pet_ioannou_2017_case.py b/examples/16_network/pet_ioannou_2017_case.py new file mode 100644 index 00000000..56bec82e --- /dev/null +++ b/examples/16_network/pet_ioannou_2017_case.py @@ -0,0 +1,237 @@ +""" +Ioannou case +============ +Figure 10 from https://doi.org/10.1002/2017JC013158 + +We want to find the Ierapetra Eddy described above in a network demonstration run. 
+""" + +from datetime import datetime, timedelta + +# %% +import re + +from matplotlib import colors, pyplot as plt +from matplotlib.animation import FuncAnimation +from matplotlib.ticker import FuncFormatter +from numpy import arange, array, pi, where + +from py_eddy_tracker.appli.gui import Anim +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.generic import coordinates_to_local +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.network import NetworkObservations +from py_eddy_tracker.poly import fit_ellipse + +# %% + + +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +@FuncFormatter +def formatter(x, pos): + return (timedelta(x) + datetime(1950, 1, 1)).strftime("%d/%m/%Y") + + +def start_axes(title=""): + fig = plt.figure(figsize=(13, 6)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES) + ax.set_xlim(19, 29), ax.set_ylim(31, 35.5) + ax.set_aspect("equal") + ax.set_title(title, weight="bold") + return ax + + +def timeline_axes(title=""): + fig = plt.figure(figsize=(15, 5)) + ax = fig.add_axes([0.03, 0.06, 0.90, 0.88]) + ax.set_title(title, weight="bold") + ax.xaxis.set_major_formatter(formatter), ax.grid() + return ax + + +def update_axes(ax, mappable=None): + ax.grid(True) + if mappable: + return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) + + +# %% +# We know the network ID, we will get directly +ioannou_case = NetworkObservations.load_file(get_demo_path("network_med.nc")).network( + 651 +) +print(ioannou_case.infos()) + +# %% +# It seems that this network is huge! Our case is visible at 22E 33.5N +ax = start_axes() +ioannou_case.plot(ax, color_cycle=ioannou_case.COLORS) +update_axes(ax) + +# %% +# Full Timeline +# ------------- +# The network span for many years... How to cut the interesting part? +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.05, 0.92, 0.92]) +ax.xaxis.set_major_formatter(formatter), ax.grid() +_ = ioannou_case.display_timeline(ax) + + +# %% +# Sub network and new numbering +# ----------------------------- +# Here we chose to keep only the order 3 segments relatives to our chosen eddy +i = where( + (ioannou_case.lat > 33) + * (ioannou_case.lat < 34) + * (ioannou_case.lon > 22) + * (ioannou_case.lon < 23) + * (ioannou_case.time > 20630) + * (ioannou_case.time < 20650) +)[0][0] +close_to_i3 = ioannou_case.relative(i, order=3) +close_to_i3.numbering_segment() + +# %% +# Anim +# ---- +# Quick movie to see better! 
+a = Anim( + close_to_i3, + figsize=(12, 4), + cmap=colors.ListedColormap( + list(close_to_i3.COLORS), name="from_list", N=close_to_i3.segment.max() + 1 + ), + nb_step=7, + dpi=70, + field_color="segment", + field_txt="segment", +) +a.ax.set_xlim(19, 30), a.ax.set_ylim(32, 35.25) +a.txt.set_position((21.5, 32.7)) +# We display in video only from the 100th day to the 500th +kwargs = dict(frames=arange(*a.period)[100:501], interval=100) +ani = VideoAnimation(a.fig, a.func_animation, **kwargs) + +# %% +# Classic display +# --------------- +ax = timeline_axes() +_ = close_to_i3.display_timeline(ax) + +# %% +ax = start_axes("") +n_copy = close_to_i3.copy() +n_copy.position_filter(2, 4) +n_copy.plot(ax, color_cycle=n_copy.COLORS) +update_axes(ax) + +# %% +# Latitude Timeline +# ----------------- +ax = timeline_axes(f"Close segments ({close_to_i3.infos()})") +n_copy = close_to_i3.copy() +n_copy.median_filter(15, "time", "latitude") +_ = n_copy.display_timeline(ax, field="lat", method="all") + +# %% +# Local radius timeline +# --------------------- +# Effective (bold) and Speed (thin) Radius together +n_copy.median_filter(2, "time", "radius_e") +n_copy.median_filter(2, "time", "radius_s") +for b0, b1 in [ + (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007) +]: + ref, delta = datetime(1950, 1, 1), 20 + b0_, b1_ = (b0 - ref).days, (b1 - ref).days + ax = timeline_axes() + ax.set_xlim(b0_ - delta, b1_ + delta) + ax.set_ylim(10, 115) + ax.axvline(b0_, color="k", lw=1.5, ls="--"), ax.axvline( + b1_, color="k", lw=1.5, ls="--" + ) + n_copy.display_timeline( + ax, field="radius_e", method="all", lw=4, markersize=8, factor=1e-3 + ) + n_copy.display_timeline( + ax, field="radius_s", method="all", lw=1, markersize=3, factor=1e-3 + ) + +# %% +# Parameters timeline +# ------------------- +# Effective Radius +kw = dict(s=35, cmap=plt.get_cmap("Spectral_r", 8), zorder=10) +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, "radius_e", factor=1e-3, vmin=20, vmax=100, **kw) +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective radius (km)") +# %% +# Shape error +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, "shape_error_e", vmin=14, vmax=70, **kw) +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective shape error") + +# %% +# Rotation angle +# -------------- +# For each obs, fit an ellipse to the contour, with theta the angle from the x-axis, +# a the semi ax in x direction and b the semi ax in y dimension + +theta_ = list() +a_ = list() +b_ = list() +for obs in close_to_i3: + x, y = obs["contour_lon_s"], obs["contour_lat_s"] + x0_, y0_ = x.mean(), y.mean() + x_, y_ = coordinates_to_local(x, y, x0_, y0_) + x0, y0, a, b, theta = fit_ellipse(x_, y_) + theta_.append(theta) + a_.append(a) + b_.append(b) +a_ = array(a_) +b_ = array(b_) + +# %% +# Theta +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, theta_, vmin=-pi / 2, vmax=pi / 2, cmap="hsv") +_ = update_axes(ax, m["scatter"]) + +# %% +# a +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, a_ * 1e-3, vmin=0, vmax=80, cmap="Spectral_r") +_ = update_axes(ax, m["scatter"]) + +# %% +# b +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, b_ * 1e-3, vmin=0, vmax=80, cmap="Spectral_r") +_ = update_axes(ax, m["scatter"]) + +# %% +# a/b +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, a_ / b_, vmin=1, vmax=2, cmap="Spectral_r") +_ = update_axes(ax, m["scatter"]) diff --git a/examples/16_network/pet_relative.py b/examples/16_network/pet_relative.py new file mode 100644 index 
00000000..dd97b538 --- /dev/null +++ b/examples/16_network/pet_relative.py @@ -0,0 +1,336 @@ +""" +Network basic manipulation +========================== +""" +from matplotlib import pyplot as plt +from numpy import where + +from py_eddy_tracker import data +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.network import NetworkObservations + +# %% +# Load data +# --------- +# Load data where observations are put in same network but no segmentation +n = NetworkObservations.load_file(data.get_demo_path("network_med.nc")).network(651) +i = where( + (n.lat > 33) + * (n.lat < 34) + * (n.lon > 22) + * (n.lon < 23) + * (n.time > 20630) + * (n.time < 20650) +)[0][0] +# For event use +n2 = n.relative(i, order=2) +n = n.relative(i, order=4) +n.numbering_segment() + +# %% +# Timeline +# -------- + +# %% +# Display timeline with events +# A segment generated by a splitting is marked with a star +# +# A segment merging in another is marked with an exagon +fig = plt.figure(figsize=(15, 6)) +ax = fig.add_axes([0.04, 0.04, 0.92, 0.92]) +_ = n.display_timeline(ax) + +# %% +# Display timeline without event +fig = plt.figure(figsize=(15, 6)) +ax = fig.add_axes([0.04, 0.04, 0.92, 0.92]) +_ = n.display_timeline(ax, event=False) + +# %% +# Timeline by mean latitude +# ------------------------- +# Display timeline with the mean latitude of the segments in yaxis +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.04, 0.92, 0.92]) +ax.set_ylabel("Latitude") +_ = n.display_timeline(ax, field="latitude") + +# %% +# Timeline by mean Effective Radius +# --------------------------------- +# The factor argument is applied on the chosen field +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.04, 0.92, 0.92]) +ax.set_ylabel("Effective Radius (km)") +_ = n.display_timeline(ax, field="radius_e", factor=1e-3) + +# %% +# Timeline by latitude +# -------------------- +# Use `method="all"` to display the consecutive values of the field +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.05, 0.92, 0.92]) +ax.set_ylabel("Latitude") +_ = n.display_timeline(ax, field="lat", method="all") + +# %% +# You can filter the data, here with a time window of 15 days +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.05, 0.92, 0.92]) +n_copy = n.copy() +n_copy.median_filter(15, "time", "latitude") +_ = n_copy.display_timeline(ax, field="lat", method="all") + +# %% +# Parameters timeline +# ------------------- +# Scatter is usefull to display the parameters' temporal evolution +# +# Effective Radius and Amplitude +kw = dict(s=25, cmap="Spectral_r", zorder=10) +fig = plt.figure(figsize=(15, 12)) +ax = fig.add_axes([0.04, 0.54, 0.90, 0.44]) +m = n.scatter_timeline(ax, "radius_e", factor=1e-3, vmin=50, vmax=150, **kw) +cb = plt.colorbar( + m["scatter"], cax=fig.add_axes([0.95, 0.54, 0.01, 0.44]), orientation="vertical" +) +cb.set_label("Effective radius (km)") + +ax = fig.add_axes([0.04, 0.04, 0.90, 0.44]) +m = n.scatter_timeline(ax, "amplitude", factor=100, vmin=0, vmax=15, **kw) +cb = plt.colorbar( + m["scatter"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.44]), orientation="vertical" +) +cb.set_label("Amplitude (cm)") + +# %% +# Speed +fig = plt.figure(figsize=(15, 6)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88]) +m = n.scatter_timeline(ax, "speed_average", factor=100, vmin=0, vmax=40, **kw) +cb = plt.colorbar( + m["scatter"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation="vertical" +) +cb.set_label("Maximum speed (cm/s)") + +# %% +# Speed Radius +fig = 
plt.figure(figsize=(15, 6)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88]) +m = n.scatter_timeline(ax, "radius_s", factor=1e-3, vmin=20, vmax=100, **kw) +cb = plt.colorbar( + m["scatter"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation="vertical" +) +cb.set_label("Speed radius (km)") + +# %% +# Remove dead branch +# ------------------ +# Remove all tiny segments with less than N obs which didn't join two segments +n_clean = n.copy() +n_clean.remove_dead_end(nobs=5, ndays=10) +n_clean = n_clean.remove_trash() +fig = plt.figure(figsize=(15, 12)) +ax = fig.add_axes([0.04, 0.54, 0.90, 0.40]) +ax.set_title(f"Original network ({n.infos()})") +n.display_timeline(ax) +ax = fig.add_axes([0.04, 0.04, 0.90, 0.40]) +ax.set_title(f"Clean network ({n_clean.infos()})") +_ = n_clean.display_timeline(ax) + +# %% +# For further figure we will use clean path +n = n_clean + +# %% +# Change splitting-merging events +# ------------------------------- +# change event where seg A split to B, then A merge into B, to A split to B then B merge into A +fig = plt.figure(figsize=(15, 12)) +ax = fig.add_axes([0.04, 0.54, 0.90, 0.40]) +ax.set_title(f"Clean network ({n.infos()})") +n.display_timeline(ax) + +clean_modified = n.copy() +# If it's happen in less than 40 days +clean_modified.correct_close_events(40) + +ax = fig.add_axes([0.04, 0.04, 0.90, 0.40]) +ax.set_title(f"resplitted network ({clean_modified.infos()})") +_ = clean_modified.display_timeline(ax) + +# %% +# Keep only observations where water could propagate from an observation +# ---------------------------------------------------------------------- +i_observation = 600 +only_linked = n.find_link(i_observation) + +fig = plt.figure(figsize=(15, 12)) +ax1 = fig.add_axes([0.04, 0.54, 0.90, 0.40]) +ax2 = fig.add_axes([0.04, 0.04, 0.90, 0.40]) + +kw = dict(marker="s", s=300, color="black", zorder=200, label="observation start") +for ax, dataset in zip([ax1, ax2], [n, only_linked]): + dataset.display_timeline(ax, field="segment", lw=2, markersize=5, colors_mode="y") + ax.scatter(n.time[i_observation], n.segment[i_observation], **kw) + ax.legend() + +ax1.set_title(f"full example ({n.infos()})") +ax2.set_title(f"only linked observations ({only_linked.infos()})") +_ = ax2.set_xlim(ax1.get_xlim()), ax2.set_ylim(ax1.get_ylim()) + +# %% +# Keep close relative +# ------------------- +# When you want to investigate one particular observation and select only the closest segments + +# First choose an observation in the network +i = 1100 + +fig = plt.figure(figsize=(15, 6)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88]) +n.display_timeline(ax) +obs_args = n.time[i], n.segment[i] +obs_kw = dict(color="black", markersize=30, marker=".") +_ = ax.plot(*obs_args, **obs_kw) + +# %% +# Colors show the relative order of the segment with regards to the chosen one +fig = plt.figure(figsize=(15, 6)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88]) +m = n.scatter_timeline( + ax, n.obs_relative_order(i), vmin=-1.5, vmax=6.5, cmap=plt.get_cmap("jet", 8), s=10 +) +ax.plot(*obs_args, **obs_kw) +cb = plt.colorbar( + m["scatter"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation="vertical" +) +cb.set_label("Relative order") +# %% +# You want to keep only the segments at the order 1 +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88]) +close_to_i1 = n.relative(i, order=1) +ax.set_title(f"Close segments ({close_to_i1.infos()})") +_ = close_to_i1.display_timeline(ax) +# %% +# You want to keep the segments until order 2 +fig = plt.figure(figsize=(15, 5)) +ax = 
fig.add_axes([0.04, 0.06, 0.90, 0.88]) +close_to_i2 = n.relative(i, order=2) +ax.set_title(f"Close segments ({close_to_i2.infos()})") +_ = close_to_i2.display_timeline(ax) +# %% +# You want to keep the segments until order 3 +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88]) +close_to_i3 = n.relative(i, order=3) +ax.set_title(f"Close segments ({close_to_i3.infos()})") +_ = close_to_i3.display_timeline(ax) + +# %% +# Keep relatives to an event +# -------------------------- +# When you want to investigate one particular event and select only the closest segments +# +# First choose a merging event in the network +after, before, stopped = n.merging_event(triplet=True, only_index=True) +i_event = 7 +# %% +# then see some order of relatives + +max_order = 1 +fig, axs = plt.subplots( + max_order + 2, 1, sharex=True, figsize=(15, 5 * (max_order + 2)) +) +# Original network +ax = axs[0] +ax.set_title("Full network", weight="bold") +n.display_timeline(axs[0], colors_mode="y") +ax.grid(), ax.legend() + +for k in range(0, max_order + 1): + ax = axs[k + 1] + ax.set_title(f"Relatives order={k}", weight="bold") + # Extract neighbours of event + sub_network = n.find_segments_relative(after[i_event], stopped[i_event], order=k) + sub_network.display_timeline(ax, colors_mode="y") + ax.legend(), ax.grid() + _ = ax.set_ylim(axs[0].get_ylim()) + +# %% +# Display track on map +# -------------------- + +# Get a simplified network +n = n2.copy() +n.remove_dead_end(nobs=50, recursive=1) +n = n.remove_trash() +n.numbering_segment() +# %% +# Only a map can be tricky to understand, with a timeline it's easier! +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES) +n.plot(ax, color_cycle=n.COLORS) +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +ax = fig.add_axes([0.08, 0.7, 0.7, 0.3]) +_ = n.display_timeline(ax) + + +# %% +# Get merging event +# ----------------- +# Display the position of the eddies after a merging +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) +n.plot(ax, color_cycle=n.COLORS) +m1, m0, m0_stop = n.merging_event(triplet=True) +m1.display(ax, color="violet", lw=2, label="Eddies after merging") +m0.display(ax, color="blueviolet", lw=2, label="Eddies before merging") +m0_stop.display(ax, color="black", lw=2, label="Eddies stopped by merging") +ax.plot(m1.lon, m1.lat, marker=".", color="purple", ls="") +ax.plot(m0.lon, m0.lat, marker=".", color="blueviolet", ls="") +ax.plot(m0_stop.lon, m0_stop.lat, marker=".", color="black", ls="") +ax.legend() +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +m1 + +# %% +# Get splitting event +# ------------------- +# Display the position of the eddies before a splitting +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) +n.plot(ax, color_cycle=n.COLORS) +s0, s1, s1_start = n.splitting_event(triplet=True) +s0.display(ax, color="violet", lw=2, label="Eddies before splitting") +s1.display(ax, color="blueviolet", lw=2, label="Eddies after splitting") +s1_start.display(ax, color="black", lw=2, label="Eddies starting by splitting") +ax.plot(s0.lon, s0.lat, marker=".", color="purple", ls="") +ax.plot(s1.lon, s1.lat, marker=".", color="blueviolet", ls="") +ax.plot(s1_start.lon, s1_start.lat, marker=".", color="black", ls="") +ax.legend() +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +s1 + +# %% +# Get birth event +# --------------- +# Display the starting position of 
non-splitted eddies +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) +birth = n.birth_event() +birth.display(ax) +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +birth + +# %% +# Get death event +# --------------- +# Display the last position of non-merged eddies +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) +death = n.death_event() +death.display(ax) +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +death diff --git a/examples/16_network/pet_replay_segmentation.py b/examples/16_network/pet_replay_segmentation.py new file mode 100644 index 00000000..d909af7f --- /dev/null +++ b/examples/16_network/pet_replay_segmentation.py @@ -0,0 +1,176 @@ +""" +Replay segmentation +=================== +Case from figure 10 from https://doi.org/10.1002/2017JC013158 + +Again with the Ierapetra Eddy +""" +from datetime import datetime, timedelta + +from matplotlib import pyplot as plt +from matplotlib.ticker import FuncFormatter +from numpy import where + +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.network import NetworkObservations +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + + +@FuncFormatter +def formatter(x, pos): + return (timedelta(x) + datetime(1950, 1, 1)).strftime("%d/%m/%Y") + + +def start_axes(title=""): + fig = plt.figure(figsize=(13, 6)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES) + ax.set_xlim(19, 29), ax.set_ylim(31, 35.5) + ax.set_aspect("equal") + ax.set_title(title, weight="bold") + return ax + + +def timeline_axes(title=""): + fig = plt.figure(figsize=(15, 5)) + ax = fig.add_axes([0.04, 0.06, 0.89, 0.88]) + ax.set_title(title, weight="bold") + ax.xaxis.set_major_formatter(formatter), ax.grid() + return ax + + +def update_axes(ax, mappable=None): + ax.grid(True) + if mappable: + return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) + + +# %% +# Class for new_segmentation +# -------------------------- +# The oldest win +class MyTrackEddiesObservations(TrackEddiesObservations): + __slots__ = tuple() + + @classmethod + def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs): + """ + Method to overwrite behaviour in merging. 
+ + We will give the point to the older one instead of the maximum overlap ratio + """ + while i_next != -1: + # Flag + used[i_next] = True + # Assign id + ids["track"][i_next] = track_id + # Search next + i_next_ = cls.get_next_obs(i_next, ids, *args, **kwargs) + if i_next_ == -1: + break + ids["next_obs"][i_next] = i_next_ + # Target was previously used + if used[i_next_]: + i_next_ = -1 + else: + ids["previous_obs"][i_next_] = i_next + i_next = i_next_ + + +def get_obs(dataset): + "Function to isolate a specific obs" + return where( + (dataset.lat > 33) + * (dataset.lat < 34) + * (dataset.lon > 22) + * (dataset.lon < 23) + * (dataset.time > 20630) + * (dataset.time < 20650) + )[0][0] + + +# %% +# Get original network, we will isolate only relative at order *2* +n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651) +n_ = n.relative(get_obs(n), order=2) + +# %% +# Display the default segmentation +ax = start_axes(n_.infos()) +n_.plot(ax, color_cycle=n.COLORS) +update_axes(ax) +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.05, 0.92, 0.92]) +ax.xaxis.set_major_formatter(formatter), ax.grid() +_ = n_.display_timeline(ax) + +# %% +# Run a new segmentation +# ---------------------- +e = n.astype(MyTrackEddiesObservations) +e.obs.sort(order=("track", "time"), kind="stable") +split_matrix = e.split_network(intern=False, window=7) +n_ = NetworkObservations.from_split_network(e, split_matrix) +n_ = n_.relative(get_obs(n_), order=2) +n_.numbering_segment() + +# %% +# New segmentation +# ---------------- +# "The oldest wins" method produce a very long segment +ax = start_axes(n_.infos()) +n_.plot(ax, color_cycle=n_.COLORS) +update_axes(ax) +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.05, 0.92, 0.92]) +ax.xaxis.set_major_formatter(formatter), ax.grid() +_ = n_.display_timeline(ax) + +# %% +# Parameters timeline +# ------------------- +kw = dict(s=35, cmap=plt.get_cmap("Spectral_r", 8), zorder=10) +ax = timeline_axes() +n_.median_filter(15, "time", "latitude") +m = n_.scatter_timeline(ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lat") +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective shape error") + +ax = timeline_axes() +n_.median_filter(15, "time", "latitude") +m = n_.scatter_timeline( + ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lat", method="all" +) +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective shape error") +ax.set_ylabel("Latitude") + +ax = timeline_axes() +n_.median_filter(15, "time", "latitude") +kw["s"] = (n_.radius_e * 1e-3) ** 2 / 30**2 * 20 +m = n_.scatter_timeline( + ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lon", method="all" +) +ax.set_ylabel("Longitude") +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective shape error") + +# %% +# Cost association plot +# --------------------- +n_copy = n_.copy() +n_copy.median_filter(2, "time", "next_cost") +for b0, b1 in [ + (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007, 2008) +]: + ref, delta = datetime(1950, 1, 1), 20 + b0_, b1_ = (b0 - ref).days, (b1 - ref).days + ax = timeline_axes() + ax.set_xlim(b0_ - delta, b1_ + delta) + ax.set_ylim(0, 1) + ax.axvline(b0_, color="k", lw=1.5, ls="--"), ax.axvline( + b1_, color="k", lw=1.5, ls="--" + ) + n_copy.display_timeline(ax, field="next_cost", method="all", lw=4, markersize=8) + + n_.display_timeline(ax, field="next_cost", method="all", lw=0.5, markersize=0) diff --git a/examples/16_network/pet_segmentation_anim.py 
b/examples/16_network/pet_segmentation_anim.py new file mode 100644 index 00000000..1fcb9ae1 --- /dev/null +++ b/examples/16_network/pet_segmentation_anim.py @@ -0,0 +1,125 @@ +""" +Network segmentation process +============================ +""" +# sphinx_gallery_thumbnail_number = 2 +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from matplotlib.colors import ListedColormap +from numpy import ones, where + +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.network import NetworkObservations +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +def get_obs(dataset): + "Function to isolate a specific obs" + return where( + (dataset.lat > 33) + * (dataset.lat < 34) + * (dataset.lon > 22) + * (dataset.lon < 23) + * (dataset.time > 20630) + * (dataset.time < 20650) + )[0][0] + + +# %% +# Hack to pick up each step of segmentation +TRACKS = list() +INDICES = list() + + +class MyTrack(TrackEddiesObservations): + @staticmethod + def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs): + TRACKS.append(ids["track"].copy()) + INDICES.append(i_current) + return TrackEddiesObservations.get_next_obs( + i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs + ) + + +# %% +# Load data +# --------- +# Load data where observations are put in same network but no segmentation + +# Get a known network for the demonstration +n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651) +# We keep only some segment +n = n.relative(get_obs(n), order=2) +print(len(n)) +# We convert and order object like segmentation was never happen on observations +e = n.astype(MyTrack) +e.obs.sort(order=("track", "time"), kind="stable") + +# %% +# Do segmentation +# --------------- +# Segmentation based on maximum overlap, temporal window for candidates = 5 days +matrix = e.split_network(intern=False, window=5) + + +# %% +# Anim +# ---- +def update(i_frame): + tr = TRACKS[i_frame] + mappable_tracks.set_array(tr) + s = 40 * ones(tr.shape) + s[tr == 0] = 4 + mappable_tracks.set_sizes(s) + + indices_frames = INDICES[i_frame] + mappable_CONTOUR.set_data( + e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames] + ) + mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)]) + return (mappable_tracks,) + + +fig = plt.figure(figsize=(16, 9), dpi=60) +ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES) +ax.set_title(f"{len(e)} observations to segment") +ax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid() +vmax = TRACKS[-1].max() +cmap = ListedColormap(["gray", *e.COLORS[:-1]], name="from_list", N=vmax) +mappable_tracks = ax.scatter( + e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20 +) +mappable_CONTOUR = ax.plot( + e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0] +)[0] +ani = 
VideoAnimation(fig, update, frames=range(1, len(TRACKS), 4), interval=125) + +# %% +# Final Result +# ------------ +fig = plt.figure(figsize=(16, 9)) +ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES) +ax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid() +_ = ax.scatter(e.lon, e.lat, c=TRACKS[-1], cmap=cmap, vmin=0, vmax=vmax, s=20) diff --git a/notebooks/README.md b/notebooks/README.md index 0e481b17..fd8971aa 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -1,2 +1,3 @@ +# rm build/sphinx/ doc/python_module/ doc/gen_modules/ doc/_autosummary/ -rf python setup.py build_sphinx -rsync -vrltp doc/python_module notebooks/. --include '*/' --include '*.ipynb' --exclude '*' --prune-empty-dirs \ No newline at end of file +rsync -vrltp doc/python_module notebooks/. --include '*/' --include '*.ipynb' --exclude '*' --prune-empty-dirs diff --git a/notebooks/python_module/01_general_things/pet_storage.ipynb b/notebooks/python_module/01_general_things/pet_storage.ipynb new file mode 100644 index 00000000..a56e4def --- /dev/null +++ b/notebooks/python_module/01_general_things/pet_storage.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# How data is stored\n\nGeneral information about eddies storage.\n\nAll files have the same structure, with more or less fields and possible different order.\n\nThere are 3 class of files:\n\n- **Eddies collections** : contain a list of eddies without link between them\n- **Track eddies collections** :\n manage eddies associated in trajectories, the ```track``` field allows to separate each trajectory\n- **Network eddies collections** :\n manage eddies associated in networks, the ```track``` and ```segment``` fields allow to separate observations\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom numpy import arange, outer\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.observation import EddiesObservations, Table\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Eddies can be stored in 2 formats with the same structure:\n\n- zarr (https://zarr.readthedocs.io/en/stable/), which allow efficiency in IO,...\n- NetCDF4 (https://unidata.github.io/netcdf4-python/), well-known format\n\nEach field are stored in column, each row corresponds at 1 observation,\narray field like contour/profile are 2D column.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Eddies files (zarr or netcdf) can be loaded with ```load_file``` method:\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_collections = EddiesObservations.load_file(get_demo_path(\"Cyclonic_20160515.nc\"))\neddies_collections.field_table()\n# offset and scale_factor are used only when data is stored in zarr or netCDF4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Field access\nTo access the total field, here ```amplitude```\n\n" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_collections.amplitude\n\n# To access only a specific part of the field\neddies_collections.amplitude[4:15]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Data matrix is a numpy ndarray\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_collections.obs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_collections.obs.dtype" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Contour storage\nAll contours are stored on the same number of points, and are resampled if needed with an algorithm to be stored as objects\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Speed profile storage\nSpeed profile is an interpolation of speed mean along each contour.\nFor each contour included in eddy, we compute mean of speed along the contour,\nand after we interpolate speed mean array on a fixed size array.\n\nSeveral field are available to understand \"uavg_profile\" :\n 0. - num_contours : Number of contour in eddies, must be equal to amplitude divide by isoline step\n 1. - height_inner_contour : height of inner contour used\n 2. - height_max_speed_contour : height of max speed contour used\n 3. - height_external_contour : height of outter contour used\n\nLast value of \"uavg_profile\" is for inner contour and first value for outter contour.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Observations selection of \"uavg_profile\" with high number of contour(Eddy with high amplitude)\ne = eddies_collections.extract_with_mask(eddies_collections.num_contours > 15)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Raw display of profiles with more than 15 contours\nax = plt.subplot(111)\n_ = ax.plot(e.uavg_profile.T, lw=0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Profile from inner to outter\nax = plt.subplot(111)\nax.plot(e.uavg_profile[:, ::-1].T, lw=0.5)\n_ = ax.set_xlabel(\"From inner to outter contour\"), ax.set_ylabel(\"Speed (m/s)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# If we normalize indice of contour to set speed contour to 1 and inner contour to 0\nax = plt.subplot(111)\nh_in = e.height_inner_contour\nh_s = e.height_max_speed_contour\nh_e = e.height_external_contour\nr = (h_e - h_in) / (h_s - h_in)\nnb_pt = e.uavg_profile.shape[1]\n# Create an x array for each profile\nx = outer(arange(nb_pt) / nb_pt, r)\n\nax.plot(x, e.uavg_profile[:, ::-1].T, lw=0.5)\n_ = ax.set_xlabel(\"From inner to outter contour\"), ax.set_ylabel(\"Speed (m/s)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Trajectories\nTracks eddies collections add several fields :\n\n- **track** : Trajectory number\n- **observation_flag** : Flag indicating if the value is interpolated between two observations or not\n (0: observed eddy, 1: interpolated eddy)\"\n- **observation_number** : Eddy temporal index in a 
trajectory, days starting at the eddy first detection\n- **cost_association** : result of the cost function to associate the eddy with the next observation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_tracks = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\n# In this example some fields are removed (effective_contour_longitude,...) in order to save time for doc building\neddies_tracks.field_table()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Networks\nNetwork files use some specific fields :\n\n- track : ID of network (ID 0 correspond to lonely eddies)\n- segment : ID of a segment within a network (from 1 to N)\n- previous_obs : Index of the previous observation in the full dataset,\n if -1 there are no previous observation (the segment starts)\n- next_obs : Index of the next observation in the full dataset, if -1 there are no next observation (the segment ends)\n- previous_cost : Result of the cost function (1 is a good association, 0 is bad) with previous observation\n- next_cost : Result of the cost function (1 is a good association, 0 is bad) with next observation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_network = NetworkObservations.load_file(get_demo_path(\"network_med.nc\"))\neddies_network.field_table()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sl = slice(70, 100)\nTable(\n eddies_network.network(651).obs[sl][\n [\n \"time\",\n \"track\",\n \"segment\",\n \"previous_obs\",\n \"previous_cost\",\n \"next_obs\",\n \"next_cost\",\n ]\n ]\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Networks are ordered by increasing network number (`track`), then increasing segment number, then increasing time\n\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb b/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb index 4f372a24..2d924387 100644 --- a/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nDisplay contour & circle\n========================\n" + "\n# Display contour & circle\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker import data" + "from matplotlib import pyplot as plt\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.observations.observation import EddiesObservations" ] }, { @@ -44,14 +44,14 @@ }, "outputs": [], "source": [ - "a = EddiesObservations.load_file(data.get_path(\"Anticyclonic_20190223.nc\"))" + "a 
= EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Plot\n\n" + "Plot the speed and effective (dashed) contours\n\n" ] }, { @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_aspect(\"equal\")\nax.set_xlim(10, 70)\nax.set_ylim(-50, -25)\na.display(ax, label=\"Anticyclonic contour\", color=\"r\", lw=1)\n\n# Replace contour by circle\na.circle_contour()\na.display(ax, label=\"Anticyclonic circle\", color=\"g\", lw=1)\nax.legend(loc=\"upper right\")" + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_aspect(\"equal\")\nax.set_xlim(10, 70)\nax.set_ylim(-50, -25)\na.display(ax, label=\"Anticyclonic contour\", color=\"r\", lw=1)\n\n# Replace contours by circles using center and radius (effective is dashed)\na.circle_contour()\na.display(ax, label=\"Anticyclonic circle\", color=\"g\", lw=1)\n_ = ax.legend(loc=\"upper right\")" ] } ], @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb b/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb index 57ca57d2..d59f9e15 100644 --- a/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nDisplay identification\n======================\n" + "\n# Display identification\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker import data" + "from matplotlib import pyplot as plt\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.observations.observation import EddiesObservations" ] }, { @@ -44,14 +44,14 @@ }, "outputs": [], "source": [ - "a = EddiesObservations.load_file(data.get_path(\"Anticyclonic_20190223.nc\"))\nc = EddiesObservations.load_file(data.get_path(\"Cyclonic_20190223.nc\"))" + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))\nc = EddiesObservations.load_file(data.get_demo_path(\"Cyclonic_20190223.nc\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Plot\n\n" + "Fill effective contour with amplitude\n\n" ] }, { @@ -62,7 +62,54 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(15, 8))\nax = fig.add_subplot(111)\nax.set_aspect(\"equal\")\nax.set_xlim(0, 360)\nax.set_ylim(-80, 80)\na.display(ax, label=\"Anticyclonic\", color=\"r\", lw=1)\nc.display(ax, label=\"Cyclonic\", color=\"b\", lw=1)\nax.legend(loc=\"upper right\")" + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\nax.set_aspect(\"equal\")\nax.set_xlim(0, 140)\nax.set_ylim(-80, 0)\nkwargs = dict(extern_only=True, color=\"k\", lw=1)\na.display(ax, **kwargs), c.display(ax, **kwargs)\na.filled(ax, \"amplitude\", cmap=\"magma_r\", vmin=0, vmax=0.5)\nm = c.filled(ax, \"amplitude\", cmap=\"magma_r\", vmin=0, vmax=0.5)\ncolorbar = plt.colorbar(m, cax=ax.figure.add_axes([0.95, 0.03, 0.02, 0.94]))\ncolorbar.set_label(\"Amplitude (m)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Draw speed contours\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.03, 0.03, 0.94, 0.94])\nax.set_aspect(\"equal\")\nax.set_xlim(0, 360)\nax.set_ylim(-80, 80)\na.display(ax, label=\"Anticyclonic ({nb_obs} eddies)\", color=\"r\", lw=1)\nc.display(ax, label=\"Cyclonic ({nb_obs} eddies)\", color=\"b\", lw=1)\nax.legend(loc=\"upper right\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get general informations\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "print(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "print(c)" ] } ], @@ -82,7 +129,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb b/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb index 8a6708a3..7469b034 100644 --- a/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nEddy detection\n==============\n\nScript will detect eddies on adt field, and compute u,v with method add_uv(which could use, only if equator is avoid)\n\nFigures will show different step to detect eddies.\n" + "\n# Eddy detection : Med\n\nScript will detect eddies on adt field, and compute u,v with method add_uv(which could use, only if equator is avoid)\n\nFigures will show different step to detect eddies.\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from datetime import datetime\nfrom matplotlib import pyplot as plt\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker import data" + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom numpy import arange\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset" ] }, { @@ -37,14 +37,14 @@ }, "outputs": [], "source": [ - "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(m, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Load Input grid, ADT will be used to detect eddies\n\n" + "Load Input grid, ADT is used to detect eddies\n\n" ] }, { @@ -55,14 +55,14 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\n\nax = start_axes(\"ADT (m)\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15)\nupdate_axes(ax, m)" + "g = RegularGridDataset(\n 
data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\n\nax = start_axes(\"ADT (m)\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15, cmap=\"RdBu_r\")\nupdate_axes(ax, m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Get u/v\n-------\nU/V are deduced from ADT, this algortihm are not usable around equator (~+- 2\u00b0)\n\n" + "## Get geostrophic speed u,v\nU/V are deduced from ADT, this algortihm is not ok near the equator (~+- 2\u00b0)\n\n" ] }, { @@ -73,14 +73,14 @@ }, "outputs": [], "source": [ - "g.add_uv(\"adt\")\nax = start_axes(\"U/V deduce from ADT (m)\")\nax.set_xlim(2.5, 9), ax.set_ylim(37.5, 40)\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15)\nu, v = g.grid(\"u\").T, g.grid(\"v\").T\nax.quiver(g.x_c, g.y_c, u, v, scale=10)\nupdate_axes(ax, m)" + "g.add_uv(\"adt\")\nax = start_axes(\"U/V deduce from ADT (m)\")\nax.set_xlim(2.5, 9), ax.set_ylim(37.5, 40)\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15, cmap=\"RdBu_r\")\nu, v = g.grid(\"u\").T, g.grid(\"v\").T\nax.quiver(g.x_c, g.y_c, u, v, scale=10)\nupdate_axes(ax, m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Pre-processings\n---------------\nApply high filter to remove long scale to highlight mesoscale\n\n" + "## Pre-processings\nApply a high-pass filter to remove the large scale and highlight the mesoscale\n\n" ] }, { @@ -91,14 +91,14 @@ }, "outputs": [], "source": [ - "g.bessel_high_filter(\"adt\", 500)\nax = start_axes(\"ADT (m) filtered (500km)\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15)\nupdate_axes(ax, m)" + "g.bessel_high_filter(\"adt\", 500)\nax = start_axes(\"ADT (m) filtered (500km)\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15, cmap=\"RdBu_r\")\nupdate_axes(ax, m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Identification\n--------------\nrun identification with slice of 2 mm\n\n" + "## Identification\nRun the identification step with slices of 2 mm\n\n" ] }, { @@ -109,14 +109,14 @@ }, "outputs": [], "source": [ - "date = datetime(2016, 5, 15)\na, c = g.eddy_identification(\"adt\", \"u\", \"v\", date, 0.002)" + "date = datetime(2016, 5, 15)\na, c = g.eddy_identification(\"adt\", \"u\", \"v\", date, 0.002, shape_error=55)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "All closed contour found in this input grid (Display only 1 contour every 4)\n\n" + "Display of all closed contours found in the grid (only 1 contour every 4)\n\n" ] }, { @@ -127,14 +127,14 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"ADT closed contour (only 1 / 4 levels)\")\ng.contours.display(ax, step=4)\nupdate_axes(ax)" + "ax = start_axes(\"ADT closed contours (only 1 / 4 levels)\")\ng.contours.display(ax, step=4)\nupdate_axes(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Contours include in eddies\n\n" + "Contours included in eddies\n\n" ] }, { @@ -145,14 +145,14 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"ADT contour used as eddies\")\ng.contours.display(ax, only_used=True)\nupdate_axes(ax)" + "ax = start_axes(\"ADT contours used as eddies\")\ng.contours.display(ax, only_used=True)\nupdate_axes(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Contours reject from several origin (shape error to high, several extremum in contour, ...)\n\n" + "## Post analysis\nContours can be rejected for several reasons (shape error to high, several extremum in contour, ...)\n\n" ] }, { @@ -163,14 +163,14 @@ }, "outputs": [], "source": [ - "ax = 
start_axes(\"ADT contour reject\")\ng.contours.display(ax, only_unused=True)\nupdate_axes(ax)" + "ax = start_axes(\"ADT rejected contours\")\ng.contours.display(ax, only_unused=True)\nupdate_axes(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Contours closed which contains several eddies\n\n" + "Criteria for rejecting a contour:\n 0. - Accepted (green)\n 1. - Rejection for shape error (red)\n 2. - Masked value within contour (blue)\n 3. - Under or over the pixel limit bounds (black)\n 4. - Amplitude criterion (yellow)\n\n" ] }, { @@ -181,14 +181,14 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"ADT contour reject but which contain eddies\")\ng.contours.label_contour_unused_which_contain_eddies(a)\ng.contours.label_contour_unused_which_contain_eddies(c)\ng.contours.display(\n ax, only_contain_eddies=True, color=\"k\", lw=1, label=\"Could be interaction contour\"\n)\na.display(ax, color=\"r\", linewidth=0.5, label=\"Anticyclonic\", ref=-10)\nc.display(ax, color=\"b\", linewidth=0.5, label=\"Cyclonic\", ref=-10)\nax.legend()\nupdate_axes(ax)" + "ax = start_axes(\"Contours' rejection criteria\")\ng.contours.display(ax, only_unused=True, lw=0.5, display_criterion=True)\nupdate_axes(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Output\n------\nDisplay detected eddies, dashed lines represent effective contour\nand solid lines represent contour of maximum of speed. See figure 1 of https://doi.org/10.1175/JTECH-D-14-00019.1\n\n" + "Display the shape error of each tested contour, the limit of shape error is set to 55 %\n\n" ] }, { @@ -199,14 +199,14 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"Eddies detected\")\na.display(ax, color=\"r\", linewidth=0.5, label=\"Anticyclonic\", ref=-10)\nc.display(ax, color=\"b\", linewidth=0.5, label=\"Cyclonic\", ref=-10)\nax.legend()\nupdate_axes(ax)" + "ax = start_axes(\"Contour shape error\")\nm = g.contours.display(\n ax, lw=0.5, field=\"shape_error\", bins=arange(20, 90.1, 5), cmap=\"PRGn_r\"\n)\nupdate_axes(ax, m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Display speed radius of eddies detected\n\n" + "Some closed contours contains several eddies (aka, more than one extremum)\n\n" ] }, { @@ -217,7 +217,61 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"Eddies speed radius (km)\")\na.scatter(ax, \"radius_s\", vmin=10, vmax=50, s=80, ref=-10, cmap=\"jet\", factor=0.001)\nm = c.scatter(ax, \"radius_s\", vmin=10, vmax=50, s=80, ref=-10, cmap=\"jet\", factor=0.001)\nupdate_axes(ax, m)" + "ax = start_axes(\"ADT rejected contours containing eddies\")\ng.contours.label_contour_unused_which_contain_eddies(a)\ng.contours.label_contour_unused_which_contain_eddies(c)\ng.contours.display(\n ax,\n only_contain_eddies=True,\n color=\"k\",\n lw=1,\n label=\"Could be a contour of interaction\",\n)\na.display(ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic\", ref=-10)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic\", ref=-10)\nax.legend()\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Output\nWhen displaying the detected eddies, dashed lines are for effective contour, solide lines for the contour of\nthe maximum mean speed. 
See figure 1 of https://doi.org/10.1175/JTECH-D-14-00019.1\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Detected Eddies\")\na.display(\n ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic ({nb_obs} eddies)\", ref=-10\n)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic ({nb_obs} eddies)\", ref=-10)\nax.legend()\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the speed radius of the detected eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Speed Radius (km)\")\nkwargs = dict(vmin=10, vmax=50, s=80, ref=-10, cmap=\"magma_r\", factor=0.001)\na.scatter(ax, \"radius_s\", **kwargs)\nm = c.scatter(ax, \"radius_s\", **kwargs)\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Filling the effective radius contours with the effective radius values\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Effective Radius (km)\")\nkwargs = dict(vmin=10, vmax=80, cmap=\"magma_r\", factor=0.001, lut=14, ref=-10)\na.filled(ax, \"effective_radius\", **kwargs)\nm = c.filled(ax, \"radius_e\", **kwargs)\nupdate_axes(ax, m)" ] } ], @@ -237,7 +291,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb new file mode 100644 index 00000000..6ac75cee --- /dev/null +++ b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb @@ -0,0 +1,169 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Eddy detection : Antartic Circumpolar Current\n\nThis script detect eddies on the ADT field, and compute u,v with the method add_uv (use it only if the Equator is avoided)\n\nTwo detections are provided : with a filtered ADT and without filtering\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib import style\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\n\npos_cb = [0.1, 0.52, 0.83, 0.015]\npos_cb2 = [0.1, 0.07, 0.4, 0.015]\n\n\ndef quad_axes(title):\n style.use(\"default\")\n fig = plt.figure(figsize=(13, 10))\n fig.suptitle(title, weight=\"bold\", fontsize=14)\n axes = list()\n\n ax_pos = dict(\n topleft=[0.1, 0.54, 0.4, 0.38],\n topright=[0.53, 0.54, 0.4, 0.38],\n botleft=[0.1, 0.09, 0.4, 0.38],\n botright=[0.53, 0.09, 0.4, 0.38],\n )\n\n for key, position in ax_pos.items():\n ax = fig.add_axes(position)\n ax.set_xlim(5, 45), ax.set_ylim(-60, -37)\n ax.set_aspect(\"equal\"), ax.grid(True)\n axes.append(ax)\n if \"right\" in key:\n ax.set_yticklabels(\"\")\n return fig, axes\n\n\ndef set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight=\"semibold\"):\n for ax in 
fig.get_axes():\n ax.grid()\n ax.grid(which=\"major\", linestyle=\"-\", linewidth=\"0.5\", color=\"black\")\n if ax.get_ylabel() != \"\":\n ax.set_ylabel(ax.get_ylabel(), fontsize=labelsize, fontweight=labelweight)\n if ax.get_xlabel() != \"\":\n ax.set_xlabel(ax.get_xlabel(), fontsize=labelsize, fontweight=labelweight)\n if ax.get_title() != \"\":\n ax.set_title(ax.get_title(), fontsize=labelsize, fontweight=labelweight)\n ax.tick_params(labelsize=ticklabelsize)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load Input grid, ADT is used to detect eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "margin = 30\n\nkw_data = dict(\n filename=data.get_demo_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n x_name=\"longitude\",\n y_name=\"latitude\",\n # Manual area subset\n indexs=dict(\n latitude=slice(100 - margin, 220 + margin),\n longitude=slice(0, 230 + margin),\n ),\n)\ng_raw = RegularGridDataset(**kw_data)\ng_raw.add_uv(\"adt\")\ng = RegularGridDataset(**kw_data)\ng.copy(\"adt\", \"adt_low\")\ng.bessel_high_filter(\"adt\", 700)\ng.bessel_low_filter(\"adt_low\", 700)\ng.add_uv(\"adt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Identification\nRun the identification step with slices of 2 mm\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "date = datetime(2016, 5, 15)\nkw_ident = dict(\n date=date, step=0.002, shape_error=70, sampling=30, uname=\"u\", vname=\"v\"\n)\na, c = g.eddy_identification(\"adt\", **kw_ident)\na_, c_ = g_raw.eddy_identification(\"adt\", **kw_ident)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Figures\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw_adt = dict(vmin=-1.5, vmax=1.5, cmap=plt.get_cmap(\"RdBu_r\", 30))\nfig, axs = quad_axes(\"General properties field\")\ng_raw.display(axs[0], \"adt\", **kw_adt)\naxs[0].set_title(\"Total ADT (m)\")\nm = g.display(axs[1], \"adt_low\", **kw_adt)\naxs[1].set_title(\"ADT (m) large scale, cutoff at 700 km\")\nm2 = g.display(axs[2], \"adt\", cmap=plt.get_cmap(\"RdBu_r\", 20), vmin=-0.5, vmax=0.5)\naxs[2].set_title(\"ADT (m) high-pass filtered, a cutoff at 700 km\")\ncb = plt.colorbar(m, cax=axs[0].figure.add_axes(pos_cb), orientation=\"horizontal\")\ncb.set_label(\"ADT (m)\", labelpad=0)\ncb2 = plt.colorbar(m2, cax=axs[2].figure.add_axes(pos_cb2), orientation=\"horizontal\")\ncb2.set_label(\"ADT (m)\", labelpad=0)\nset_fancy_labels(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The large-scale North-South gradient is removed by the filtering step.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, axs = quad_axes(\"\")\naxs[0].set_title(\"Without filter\")\naxs[0].set_ylabel(\"Contours used in eddies\")\naxs[1].set_title(\"With filter\")\naxs[2].set_ylabel(\"Closed contours but not used\")\ng_raw.contours.display(axs[0], lw=0.5, only_used=True)\ng.contours.display(axs[1], lw=0.5, only_used=True)\ng_raw.contours.display(axs[2], lw=0.5, only_unused=True)\ng.contours.display(axs[3], lw=0.5, only_unused=True)\nset_fancy_labels(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + 
"Removing the large-scale North-South gradient reveals closed contours in the\nSouth-Western corner of the ewample region.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(ref=-10, linewidth=0.75)\nkw_a = dict(color=\"r\", label=\"Anticyclonic ({nb_obs} eddies)\")\nkw_c = dict(color=\"b\", label=\"Cyclonic ({nb_obs} eddies)\")\nkw_filled = dict(vmin=0, vmax=100, cmap=\"Spectral_r\", lut=20, intern=True, factor=100)\nfig, axs = quad_axes(\"Comparison between two detections\")\n# Match with intern/inner contour\ni_a, j_a, s_a = a_.match(a, intern=True, cmin=0.15)\ni_c, j_c, s_c = c_.match(c, intern=True, cmin=0.15)\n\na_.index(i_a).filled(axs[0], s_a, **kw_filled)\na.index(j_a).filled(axs[1], s_a, **kw_filled)\nc_.index(i_c).filled(axs[0], s_c, **kw_filled)\nm = c.index(j_c).filled(axs[1], s_c, **kw_filled)\n\ncb = plt.colorbar(m, cax=axs[0].figure.add_axes(pos_cb), orientation=\"horizontal\")\ncb.set_label(\"Similarity index (%)\", labelpad=-5)\na_.display(axs[0], **kw, **kw_a), c_.display(axs[0], **kw, **kw_c)\na.display(axs[1], **kw, **kw_a), c.display(axs[1], **kw, **kw_c)\n\naxs[0].set_title(\"Without filter\")\naxs[0].set_ylabel(\"Detection\")\naxs[1].set_title(\"With filter\")\naxs[2].set_ylabel(\"Contours' rejection criteria\")\n\ng_raw.contours.display(axs[2], lw=0.5, only_unused=True, display_criterion=True)\ng.contours.display(axs[3], lw=0.5, only_unused=True, display_criterion=True)\n\nfor ax in axs:\n ax.legend()\n\nset_fancy_labels(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Very similar eddies have Similarity Indexes >= 40%\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Criteria for rejecting a contour :\n 0. Accepted (green)\n 1. Rejection for shape error (red)\n 2. Masked value within contour (blue)\n 3. Under or over the pixel limit bounds (black)\n 4. 
Amplitude criterion (yellow)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i_a, j_a = i_a[s_a >= 0.4], j_a[s_a >= 0.4]\ni_c, j_c = i_c[s_c >= 0.4], j_c[s_c >= 0.4]\nfig = plt.figure(figsize=(12, 12))\nfig.suptitle(f\"Scatter plot (A : {i_a.shape[0]}, C : {i_c.shape[0]} matches)\")\n\nfor i, (label, field, factor, stop) in enumerate(\n (\n (\"Speed radius (km)\", \"radius_s\", 0.001, 120),\n (\"Effective radius (km)\", \"radius_e\", 0.001, 120),\n (\"Amplitude (cm)\", \"amplitude\", 100, 25),\n (\"Speed max (cm/s)\", \"speed_average\", 100, 25),\n )\n):\n ax = fig.add_subplot(2, 2, i + 1, title=label)\n ax.set_xlabel(\"Without filter\")\n ax.set_ylabel(\"With filter\")\n\n ax.plot(\n a_[field][i_a] * factor,\n a[field][j_a] * factor,\n \"r.\",\n label=\"Anticyclonic\",\n )\n ax.plot(\n c_[field][i_c] * factor,\n c[field][j_c] * factor,\n \"b.\",\n label=\"Cyclonic\",\n )\n ax.set_aspect(\"equal\"), ax.grid()\n ax.plot((0, 1000), (0, 1000), \"g\")\n ax.set_xlim(0, stop), ax.set_ylim(0, stop)\n ax.legend()\n\nset_fancy_labels(fig)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb new file mode 100644 index 00000000..49024327 --- /dev/null +++ b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Eddy detection : Gulf stream\n\nScript will detect eddies on adt field, and compute u,v with method add_uv(which could use, only if equator is avoid)\n\nFigures will show different step to detect eddies.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom numpy import arange\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.eddy_feature import Contours" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title):\n fig = plt.figure(figsize=(13, 8))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(279, 304), ax.set_ylim(29, 44)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load Input grid, ADT is used to detect eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": 
[ + "margin = 30\ng = RegularGridDataset(\n data.get_demo_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n # Manual area subset\n indexs=dict(\n longitude=slice(1116 - margin, 1216 + margin),\n latitude=slice(476 - margin, 536 + margin),\n ),\n)\n\nax = start_axes(\"ADT (m)\")\nm = g.display(ax, \"adt\", vmin=-1, vmax=1, cmap=\"RdBu_r\")\n# Draw line on the gulf stream front\ngreat_current = Contours(g.x_c, g.y_c, g.grid(\"adt\"), levels=(0.35,), keep_unclose=True)\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get geostrophic speed u,v\nU/V are deduced from ADT, this algortihm is not ok near the equator (~+- 2\u00b0)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "g.add_uv(\"adt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-processings\nApply a high-pass filter to remove the large scale and highlight the mesoscale\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "g.bessel_high_filter(\"adt\", 700)\nax = start_axes(\"ADT (m) filtered (700km)\")\nm = g.display(ax, \"adt\", vmin=-0.4, vmax=0.4, cmap=\"RdBu_r\")\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Identification\nRun the identification step with slices of 2 mm\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "date = datetime(2016, 5, 15)\na, c = g.eddy_identification(\"adt\", \"u\", \"v\", date, 0.002, shape_error=55)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display of all closed contours found in the grid (only 1 contour every 5)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"ADT closed contours (only 1 / 5 levels)\")\ng.contours.display(ax, step=5, lw=1)\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Contours included in eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"ADT contours used as eddies\")\ng.contours.display(ax, only_used=True, lw=0.25)\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Post analysis\nContours can be rejected for several reasons (shape error to high, several extremum in contour, ...)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"ADT rejected contours\")\ng.contours.display(ax, only_unused=True, lw=0.25)\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Criteria for rejecting a contour :\n 0. Accepted (green)\n 1. Rejection for shape error (red)\n 2. Masked value within contour (blue)\n 3. Under or over the pixel limit bounds (black)\n 4. 
Amplitude criterion (yellow)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Contours' rejection criteria\")\ng.contours.display(ax, only_unused=True, lw=0.5, display_criterion=True)\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the shape error of each tested contour, the limit of shape error is set to 55 %\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Contour shape error\")\nm = g.contours.display(\n ax, lw=0.5, field=\"shape_error\", bins=arange(20, 90.1, 5), cmap=\"PRGn_r\"\n)\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some closed contours contains several eddies (aka, more than one extremum)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"ADT rejected contours containing eddies\")\ng.contours.label_contour_unused_which_contain_eddies(a)\ng.contours.label_contour_unused_which_contain_eddies(c)\ng.contours.display(\n ax,\n only_contain_eddies=True,\n color=\"k\",\n lw=1,\n label=\"Could be a contour of interaction\",\n)\na.display(ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic\", ref=-10)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic\", ref=-10)\nax.legend()\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Output\nWhen displaying the detected eddies, dashed lines are for effective contour, solide lines for the contour of the\nmaximum mean speed. 
See figure 1 of https://doi.org/10.1175/JTECH-D-14-00019.1\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Eddies detected\")\na.display(\n ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic ({nb_obs} eddies)\", ref=-10\n)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic ({nb_obs} eddies)\", ref=-10)\nax.legend()\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the effective radius of the detected eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Effective radius (km)\")\na.filled(ax, \"radius_e\", vmin=10, vmax=150, cmap=\"magma_r\", factor=0.001, lut=14)\nm = c.filled(ax, \"radius_e\", vmin=10, vmax=150, cmap=\"magma_r\", factor=0.001, lut=14)\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax, m)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb b/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb index 34c15b95..381aa8f6 100644 --- a/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nEddy detection and filter\n=========================\n" + "\n# Eddy detection and filter\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from datetime import datetime\nfrom matplotlib import pyplot as plt\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker import data\nfrom numpy import arange" + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom numpy import arange\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset" ] }, { @@ -37,14 +37,14 @@ }, "outputs": [], "source": [ - "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Load Input grid, ADT will be used to detect eddies\n\n" + "Load Input grid, ADT is used to detect eddies.\nAdd a new filed to store the high-pass filtered ADT\n\n" ] }, { @@ -55,14 +55,14 @@ }, "outputs": [], 
"source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\",\n)\ng.add_uv(\"adt\")\ng.copy(\"adt\", \"adt_high\")\nwavelength = 400\ng.bessel_high_filter(\"adt_high\", wavelength)\ndate = datetime(2016, 5, 15)" + "g = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\ng.add_uv(\"adt\")\ng.copy(\"adt\", \"adt_high\")\nwavelength = 800\ng.bessel_high_filter(\"adt_high\", wavelength)\ndate = datetime(2016, 5, 15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Run algorithm of detection\n\n" + "Run the detection for the total grid and the filtered grid\n\n" ] }, { @@ -73,14 +73,14 @@ }, "outputs": [], "source": [ - "a_f, c_f = g.eddy_identification(\"adt_high\", \"u\", \"v\", date, 0.002)\nmerge_f = a_f.merge(c_f)\na_r, c_r = g.eddy_identification(\"adt\", \"u\", \"v\", date, 0.002)\nmerge_r = a_r.merge(c_r)" + "a_filtered, c_filtered = g.eddy_identification(\"adt_high\", \"u\", \"v\", date, 0.002)\nmerge_f = a_filtered.merge(c_filtered)\na_tot, c_tot = g.eddy_identification(\"adt\", \"u\", \"v\", date, 0.002)\nmerge_t = a_tot.merge(c_tot)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Display detection\n\n" + "Display the two detections\n\n" ] }, { @@ -91,14 +91,14 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"Eddies detected over ADT\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15)\nmerge_f.display(ax, lw=0.5, label=\"Eddy from filtered grid\", ref=-10, color=\"k\")\nmerge_r.display(ax, lw=0.5, label=\"Eddy from raw grid\", ref=-10, color=\"r\")\nax.legend()\nupdate_axes(ax, m)" + "ax = start_axes(\"Eddies detected over ADT\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15)\nmerge_f.display(\n ax,\n lw=0.75,\n label=\"Eddies in the filtered grid ({nb_obs} eddies)\",\n ref=-10,\n color=\"k\",\n)\nmerge_t.display(\n ax, lw=0.75, label=\"Eddies without filter ({nb_obs} eddies)\", ref=-10, color=\"r\"\n)\nax.legend()\nupdate_axes(ax, m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Parameters distribution\n-----------------------\n\n" + "## Amplitude and Speed Radius distributions\n\n" ] }, { @@ -109,14 +109,14 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 5))\nax_a = plt.subplot(121, xlabel=\"amplitdue(cm)\")\nax_r = plt.subplot(122, xlabel=\"speed radius (km)\")\nax_a.hist(\n merge_f[\"amplitude\"] * 100,\n bins=arange(0.0005, 100, 1),\n label=\"Eddy from filtered grid\",\n histtype=\"step\",\n)\nax_a.hist(\n merge_r[\"amplitude\"] * 100,\n bins=arange(0.0005, 100, 1),\n label=\"Eddy from raw grid\",\n histtype=\"step\",\n)\nax_a.set_xlim(0, 10)\nax_r.hist(merge_f[\"radius_s\"] / 1000.0, bins=arange(0, 300, 5), histtype=\"step\")\nax_r.hist(merge_r[\"radius_s\"] / 1000.0, bins=arange(0, 300, 5), histtype=\"step\")\nax_r.set_xlim(0, 100)\nax_a.legend()" + "fig = plt.figure(figsize=(12, 5))\nax_a = fig.add_subplot(121, xlabel=\"Amplitude (cm)\")\nax_r = fig.add_subplot(122, xlabel=\"Speed Radius (km)\")\nax_a.hist(\n merge_f.amplitude * 100,\n bins=arange(0.0005, 100, 1),\n label=\"Eddies in the filtered grid\",\n histtype=\"step\",\n)\nax_a.hist(\n merge_t.amplitude * 100,\n bins=arange(0.0005, 100, 1),\n label=\"Eddies without filter\",\n histtype=\"step\",\n)\nax_a.set_xlim(0, 10)\nax_r.hist(merge_f.radius_s / 1000.0, bins=arange(0, 300, 5), histtype=\"step\")\nax_r.hist(merge_t.radius_s / 1000.0, bins=arange(0, 300, 5), histtype=\"step\")\nax_r.set_xlim(0, 
100)\nax_a.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Match detection and compare\n---------------------------\n\n" + "## Match detection and compare\n\n" ] }, { @@ -127,14 +127,14 @@ }, "outputs": [], "source": [ - "i_, j_, c = merge_f.match(merge_r, 0.1)" + "i_, j_, c = merge_f.match(merge_t, cmin=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "where is lonely eddies\n\n" + "Where are the lonely eddies?\n\n" ] }, { @@ -145,7 +145,7 @@ }, "outputs": [], "source": [ - "kwargs_f = dict(lw=1.5, label=\"Lonely eddy from filtered grid\", ref=-10, color=\"k\")\nkwargs_r = dict(lw=1.5, label=\"Lonely eddy from raw grid\", ref=-10, color=\"r\")\nax = start_axes(\"Eddies with no match, over filtered ADT\")\nmappable = g.display(ax, \"adt_high\", vmin=-0.15, vmax=0.15)\nmerge_f.index(i_, reverse=True).display(ax, **kwargs_f)\nmerge_r.index(j_, reverse=True).display(ax, **kwargs_r)\nax.legend()\nupdate_axes(ax, mappable)\n\nax = start_axes(\"Eddies with no match, over filtered ADT (zoom)\")\nax.set_xlim(25, 36), ax.set_ylim(31, 35.25)\nmappable = g.display(ax, \"adt_high\", vmin=-0.15, vmax=0.15)\nu, v = g.grid(\"u\").T, g.grid(\"v\").T\nax.quiver(g.x_c, g.y_c, u, v, scale=10, pivot=\"mid\", color=\"gray\")\nmerge_f.index(i_, reverse=True).display(ax, **kwargs_f)\nmerge_r.index(j_, reverse=True).display(ax, **kwargs_r)\nax.legend()\nupdate_axes(ax, mappable)" + "kwargs_f = dict(lw=1.5, label=\"Lonely eddies in the filtered grid\", ref=-10, color=\"k\")\nkwargs_t = dict(lw=1.5, label=\"Lonely eddies without filter\", ref=-10, color=\"r\")\nax = start_axes(\"Eddies with no match, over filtered ADT\")\nmappable = g.display(ax, \"adt_high\", vmin=-0.15, vmax=0.15)\nmerge_f.index(i_, reverse=True).display(ax, **kwargs_f)\nmerge_t.index(j_, reverse=True).display(ax, **kwargs_t)\nax.legend()\nupdate_axes(ax, mappable)\n\nax = start_axes(\"Eddies with no match, over filtered ADT (zoom)\")\nax.set_xlim(25, 36), ax.set_ylim(31, 35.25)\nmappable = g.display(ax, \"adt_high\", vmin=-0.15, vmax=0.15)\nu, v = g.grid(\"u\").T, g.grid(\"v\").T\nax.quiver(g.x_c, g.y_c, u, v, scale=10, pivot=\"mid\", color=\"gray\")\nmerge_f.index(i_, reverse=True).display(ax, **kwargs_f)\nmerge_t.index(j_, reverse=True).display(ax, **kwargs_t)\nax.legend()\nupdate_axes(ax, mappable)" ] }, { @@ -156,7 +156,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 12))\nfig.suptitle(f\"Scatter plot ({i_.shape[0]} matches)\")\n\nfor i, (label, field, factor, stop) in enumerate(\n (\n (\"speed radius (km)\", \"radius_s\", 0.001, 80),\n (\"outter radius (km)\", \"radius_e\", 0.001, 120),\n (\"amplitude (cm)\", \"amplitude\", 100, 25),\n (\"speed max (cm/s)\", \"speed_average\", 100, 25),\n )\n):\n ax = fig.add_subplot(\n 2, 2, i + 1, xlabel=\"filtered grid\", ylabel=\"raw grid\", title=label\n )\n ax.plot(merge_f[field][i_] * factor, merge_r[field][j_] * factor, \".\")\n ax.set_aspect(\"equal\"), ax.grid()\n ax.plot((0, 1000), (0, 1000), \"r\")\n ax.set_xlim(0, stop), ax.set_ylim(0, stop)" + "fig = plt.figure(figsize=(12, 12))\nfig.suptitle(f\"Scatter plot ({i_.shape[0]} matches)\", weight=\"bold\")\n\nfor i, (label, field, factor, stop) in enumerate(\n (\n (\"Speed radius (km)\", \"radius_s\", 0.001, 80),\n (\"Effective radius (km)\", \"radius_e\", 0.001, 120),\n (\"Amplitude (cm)\", \"amplitude\", 100, 25),\n (\"Maximum Speed (cm/s)\", \"speed_average\", 100, 25),\n )\n):\n ax = fig.add_subplot(\n 2, 2, i + 1, xlabel=\"Filtered grid\", ylabel=\"Without filter\", title=label\n 
)\n ax.plot(merge_f[field][i_] * factor, merge_t[field][j_] * factor, \".\")\n ax.set_aspect(\"equal\"), ax.grid()\n ax.plot((0, 1000), (0, 1000), \"r\")\n ax.set_xlim(0, stop), ax.set_ylim(0, stop)" ] } ], @@ -176,7 +176,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb b/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb new file mode 100644 index 00000000..0cfdc9a8 --- /dev/null +++ b/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb @@ -0,0 +1,119 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Get mean of grid in each eddies\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n ax.legend()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load detection files and data to interp\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20160515.nc\"))\nc = EddiesObservations.load_file(data.get_demo_path(\"Cyclonic_20160515.nc\"))\n\naviso_map = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\naviso_map.add_uv(\"adt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compute and store eke in cm\u00b2/s\u00b2\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "aviso_map.add_grid(\n \"eke\", (aviso_map.grid(\"u\") ** 2 + aviso_map.grid(\"v\") ** 2) * 0.5 * (100 ** 2)\n)\n\neke_kwargs = dict(vmin=1, vmax=1000, cmap=\"magma_r\")\n\nax = start_axes(\"EKE (cm\u00b2/s\u00b2)\")\nm = aviso_map.display(ax, \"eke\", **eke_kwargs)\na.display(ax, color=\"r\", linewidth=0.5, label=\"Anticyclonic\", ref=-10)\nc.display(ax, color=\"b\", linewidth=0.5, label=\"Cyclonic\", ref=-10)\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get mean of eke in each effective contour\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"EKE mean (cm\u00b2/s\u00b2)\")\na.display(ax, color=\"r\", linewidth=0.5, label=\"Anticyclonic ({nb_obs} eddies)\", ref=-10)\nc.display(ax, 
color=\"b\", linewidth=0.5, label=\"Cyclonic ({nb_obs} eddies)\", ref=-10)\neke = a.interp_grid(aviso_map, \"eke\", method=\"mean\", intern=False)\na.filled(ax, eke, ref=-10, **eke_kwargs)\neke = c.interp_grid(aviso_map, \"eke\", method=\"mean\", intern=False)\nm = c.filled(ax, eke, ref=-10, **eke_kwargs)\nupdate_axes(ax, m)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb b/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb new file mode 100644 index 00000000..03eba8bf --- /dev/null +++ b/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb @@ -0,0 +1,115 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Radius vs area\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import array, pi\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.generic import coordinates_to_local\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker.poly import poly_area" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load detection files\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))\nareas = list()\n# For each contour area will be compute in local reference\nfor i in a:\n x, y = coordinates_to_local(\n i[\"contour_lon_s\"], i[\"contour_lat_s\"], i[\"lon\"], i[\"lat\"]\n )\n areas.append(poly_area(x, y))\nareas = array(areas)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Radius provided by eddy detection is computed with :func:`~py_eddy_tracker.poly.fit_circle` method.\nThis radius will be compared with an equivalent radius deduced from polygon area.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.subplot(111)\nax.set_aspect(\"equal\")\nax.grid()\nax.set_xlabel(\"Speed radius computed with fit_circle\")\nax.set_ylabel(\"Radius deduced from area\\nof contour_lon_s/contour_lat_s\")\nax.set_title(\"Area vs radius\")\nax.plot(a[\"radius_s\"] / 1000.0, (areas / pi) ** 0.5 / 1000.0, \".\")\nax.plot((0, 250), (0, 250), \"r\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Fit circle give a radius bigger than polygon area\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When error is tiny, radius are very close.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.subplot(111)\nax.grid()\nax.set_xlabel(\"Radius 
ratio\")\nax.set_ylabel(\"Shape error\")\nax.set_title(\"err = f(radius_ratio)\")\nax.plot(a[\"radius_s\"] / (areas / pi) ** 0.5, a[\"shape_error_s\"], \".\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb b/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb index 34687a19..0ef03f6f 100644 --- a/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nShape error gallery\n===================\n\nGallery of contours with shape error\n" + "\n# Shape error gallery\n\nGallery of contours with shape error\n" ] }, { @@ -26,14 +26,14 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom numpy import arange, radians, linspace, cos, sin\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.eddy_feature import Contours\nfrom py_eddy_tracker.generic import local_to_coordinates" + "from matplotlib import pyplot as plt\nfrom numpy import arange, cos, linspace, radians, sin\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.eddy_feature import Contours\nfrom py_eddy_tracker.generic import local_to_coordinates" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Method to compute circle path\n\n" + "Method to built circle from center coordinates\n\n" ] }, { @@ -62,14 +62,14 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\nc = Contours(g.x_c, g.y_c, g.grid(\"adt\") * 100, arange(-50, 50, 0.2))\ncontours = dict()\nfor coll in c.iter():\n for current_contour in coll.get_paths():\n _, _, _, aerr = current_contour.fit_circle()\n i = int(aerr // 4) + 1\n if i not in contours:\n contours[i] = list()\n contours[i].append(current_contour)" + "g = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\nc = Contours(g.x_c, g.y_c, g.grid(\"adt\") * 100, arange(-50, 50, 0.2))\ncontours = dict()\nfor coll in c.iter():\n for current_contour in coll.get_paths():\n _, _, _, aerr = current_contour.fit_circle()\n i = int(aerr // 4) + 1\n if i not in contours:\n contours[i] = list()\n contours[i].append(current_contour)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Shape error gallery\n-------------------\nFor each contour display, we display circle fitted, we work at different latitude circle could have distorsion\n\n" + "## Shape error gallery\nFor each contour display, we display circle fitted, we work at different latitude circle could have distorsion\n\n" ] }, { @@ -100,7 +100,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git 
a/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb b/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb index 8e166895..9b8b3951 100644 --- a/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nEddy detection on SLA and ADT\n=============================\n" + "\n# Eddy detection on SLA and ADT\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from datetime import datetime\nfrom matplotlib import pyplot as plt\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker import data" + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset" ] }, { @@ -55,7 +55,7 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\",\n)\ng.add_uv(\"adt\", \"ugos\", \"vgos\")\ng.add_uv(\"sla\", \"ugosa\", \"vgosa\")\nwavelength = 400\ng.copy(\"adt\", \"adt_raw\")\ng.copy(\"sla\", \"sla_raw\")\ng.bessel_high_filter(\"adt\", wavelength)\ng.bessel_high_filter(\"sla\", wavelength)\ndate = datetime(2016, 5, 15)" + "g = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\ng.add_uv(\"adt\", \"ugos\", \"vgos\")\ng.add_uv(\"sla\", \"ugosa\", \"vgosa\")\nwavelength = 400\ng.copy(\"adt\", \"adt_raw\")\ng.copy(\"sla\", \"sla_raw\")\ng.bessel_high_filter(\"adt\", wavelength)\ng.bessel_high_filter(\"sla\", wavelength)\ndate = datetime(2016, 5, 15)" ] }, { @@ -66,7 +66,7 @@ }, "outputs": [], "source": [ - "kwargs_a_adt = dict(lw=0.5, label=\"Anticyclonic ADT\", ref=-10, color=\"k\")\nkwargs_c_adt = dict(lw=0.5, label=\"Cyclonic ADT\", ref=-10, color=\"r\")\nkwargs_a_sla = dict(lw=0.5, label=\"Anticyclonic SLA\", ref=-10, color=\"g\")\nkwargs_c_sla = dict(lw=0.5, label=\"Cyclonic SLA\", ref=-10, color=\"b\")" + "kwargs_a_adt = dict(\n lw=0.5, label=\"Anticyclonic ADT ({nb_obs} eddies)\", ref=-10, color=\"k\"\n)\nkwargs_c_adt = dict(lw=0.5, label=\"Cyclonic ADT ({nb_obs} eddies)\", ref=-10, color=\"r\")\nkwargs_a_sla = dict(\n lw=0.5, label=\"Anticyclonic SLA ({nb_obs} eddies)\", ref=-10, color=\"g\"\n)\nkwargs_c_sla = dict(lw=0.5, label=\"Cyclonic SLA ({nb_obs} eddies)\", ref=-10, color=\"b\")" ] }, { @@ -145,7 +145,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Match\n-----------------------\nWhere cyclone meet anticyclone\n\n" + "## Match\nWhere cyclone meet anticyclone\n\n" ] }, { @@ -156,14 +156,14 @@ }, "outputs": [], "source": [ - "i_c_adt, i_a_sla, c = c_adt.match(a_sla, 0.1)\ni_a_adt, i_c_sla, c = a_adt.match(c_sla, 0.1)\n\nax = start_axes(\"Cyclone share area with anticyclone\")\na_adt.index(i_a_adt).display(ax, **kwargs_a_adt)\nc_adt.index(i_c_adt).display(ax, **kwargs_c_adt)\na_sla.index(i_a_sla).display(ax, **kwargs_a_sla)\nc_sla.index(i_c_sla).display(ax, **kwargs_c_sla)\nax.legend()\nupdate_axes(ax)" + "i_c_adt, i_a_sla, c = c_adt.match(a_sla, cmin=0.01)\ni_a_adt, i_c_sla, c = a_adt.match(c_sla, cmin=0.01)\n\nax = start_axes(\"Cyclone share area with anticyclone\")\na_adt.index(i_a_adt).display(ax, **kwargs_a_adt)\nc_adt.index(i_c_adt).display(ax, **kwargs_c_adt)\na_sla.index(i_a_sla).display(ax, **kwargs_a_sla)\nc_sla.index(i_c_sla).display(ax, 
**kwargs_c_sla)\nax.legend()\nupdate_axes(ax)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Scatter plot\n------------\n\n" + "## Scatter plot\n\n" ] }, { @@ -174,7 +174,7 @@ }, "outputs": [], "source": [ - "i_a_adt, i_a_sla, c = a_adt.match(a_sla, 0.1)\ni_c_adt, i_c_sla, c = c_adt.match(c_sla, 0.1)" + "i_a_adt, i_a_sla, c = a_adt.match(a_sla, cmin=0.1)\ni_c_adt, i_c_sla, c = c_adt.match(c_sla, cmin=0.1)" ] }, { @@ -203,7 +203,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 12))\nfig.suptitle(f\"Scatter plot (A : {i_a_adt.shape[0]}, C : {i_c_adt.shape[0]} matches)\")\n\nfor i, (label, field, factor, stop) in enumerate(\n (\n (\"speed radius (km)\", \"radius_s\", 0.001, 80),\n (\"outter radius (km)\", \"radius_e\", 0.001, 120),\n (\"amplitude (cm)\", \"amplitude\", 100, 25),\n (\"speed max (cm/s)\", \"speed_average\", 100, 25),\n )\n):\n ax = fig.add_subplot(2, 2, i + 1, title=label)\n ax.set_xlabel(\"Absolute Dynamic Topography\")\n ax.set_ylabel(\"Sea Level Anomaly\")\n\n ax.plot(a_adt[field][i_a_adt] * factor, a_sla[field][i_a_sla] * factor, \"r.\")\n ax.plot(c_adt[field][i_c_adt] * factor, c_sla[field][i_c_sla] * factor, \"b.\")\n ax.set_aspect(\"equal\"), ax.grid()\n ax.plot((0, 1000), (0, 1000), \"g\")\n ax.set_xlim(0, stop), ax.set_ylim(0, stop)" + "fig = plt.figure(figsize=(12, 12))\nfig.suptitle(f\"Scatter plot (A : {i_a_adt.shape[0]}, C : {i_c_adt.shape[0]} matches)\")\n\nfor i, (label, field, factor, stop) in enumerate(\n (\n (\"speed radius (km)\", \"radius_s\", 0.001, 80),\n (\"outter radius (km)\", \"radius_e\", 0.001, 120),\n (\"amplitude (cm)\", \"amplitude\", 100, 25),\n (\"speed max (cm/s)\", \"speed_average\", 100, 25),\n )\n):\n ax = fig.add_subplot(2, 2, i + 1, title=label)\n ax.set_xlabel(\"Absolute Dynamic Topography\")\n ax.set_ylabel(\"Sea Level Anomaly\")\n\n ax.plot(\n a_adt[field][i_a_adt] * factor,\n a_sla[field][i_a_sla] * factor,\n \"r.\",\n label=\"Anticyclonic\",\n )\n ax.plot(\n c_adt[field][i_c_adt] * factor,\n c_sla[field][i_c_sla] * factor,\n \"b.\",\n label=\"Cyclonic\",\n )\n ax.set_aspect(\"equal\"), ax.grid()\n ax.plot((0, 1000), (0, 1000), \"g\")\n ax.set_xlim(0, stop), ax.set_ylim(0, stop)\n ax.legend()" ] } ], @@ -223,7 +223,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb b/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb new file mode 100644 index 00000000..7fa04435 --- /dev/null +++ b/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Stastics on identification files\n\nSome statistics on raw identification without any tracking\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import numpy as np\nfrom matplotlib import pyplot as plt\nfrom matplotlib.dates import date2num\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.observations.observation import EddiesObservations\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load demo sample and take only first year.\n\nReplace by a list of filename to apply on your own dataset.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:365]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Merge all identification dataset in one object\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "all_a = EddiesObservations.concatenate(\n [EddiesObservations.load_file(i) for i in file_objects]\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define polygon bound\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0, x1, y0, y1 = 15, 20, 33, 38\nxs = np.array([[x0, x1, x1, x0, x0]], dtype=\"f8\")\nys = np.array([[y0, y0, y1, y1, y0]], dtype=\"f8\")\n# Polygon object is create to be usable by match function.\npolygon = dict(contour_lon_e=xs, contour_lat_e=ys, contour_lon_s=xs, contour_lat_s=ys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Geographic frequency of eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 0.125\nax = start_axes(\"\")\n# Count pixel used for each contour\ng_a = all_a.grid_count(bins=((-10, 37, step), (30, 46, step)), intern=True)\nm = g_a.display(\n ax, cmap=\"terrain_r\", vmin=0, vmax=0.75, factor=1 / all_a.nb_days, name=\"count\"\n)\nax.plot(polygon[\"contour_lon_e\"][0], polygon[\"contour_lat_e\"][0], \"r\")\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use match function to count number of eddies which intersect the polygon defined previously.\n`p1_area` option allow to get in c_e/c_s output, precentage of area occupy by eddies in the polygon.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i_e, j_e, c_e = all_a.match(polygon, p1_area=True, intern=False)\ni_s, j_s, c_s = all_a.match(polygon, p1_area=True, intern=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "dt = np.datetime64(\"1970-01-01\") - np.datetime64(\"1950-01-01\")\nkw_hist = dict(\n bins=date2num(np.arange(21900, 22300).astype(\"datetime64\") - dt), histtype=\"step\"\n)\n# translate julian day in datetime64\nt = all_a.time.astype(\"datetime64\") - dt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Count how many are in polygon\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": 
{ + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(12, 6)).add_subplot(111)\nax.set_title(\"Different ways to count eddy presence in a polygon\")\nax.set_ylabel(\"Count\")\nm = all_a.mask_from_polygons(((xs, ys),))\nax.hist(t[m], label=\"center in polygon\", **kw_hist)\nax.hist(t[i_s[c_s > 0]], label=\"intersect speed contour with polygon\", **kw_hist)\nax.hist(t[i_e[c_e > 0]], label=\"intersect extern contour with polygon\", **kw_hist)\nax.legend()\nax.set_xlim(np.datetime64(\"2010\"), np.datetime64(\"2011\"))\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Percent of the area of interest occupied by eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(12, 6)).add_subplot(111)\nax.set_title(\"Percent of polygon occupied by an anticyclonic eddy\")\nax.set_ylabel(\"Percent of polygon\")\nax.hist(t[i_s[c_s > 0]], weights=c_s[c_s > 0] * 100.0, label=\"speed contour\", **kw_hist)\nax.hist(t[i_e[c_e > 0]], weights=c_e[c_e > 0] * 100.0, label=\"effective contour\", **kw_hist)\nax.legend(), ax.set_ylim(0, 25)\nax.set_xlim(np.datetime64(\"2010\"), np.datetime64(\"2011\"))\nax.grid()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb b/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb new file mode 100644 index 00000000..90ee1722 --- /dev/null +++ b/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Grid advection\n\nDummy advection which uses only a static geostrophic current, and therefore does not resolve the complex circulation of the ocean.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, isnan, meshgrid, ones\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load Input grid ADT\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "g = RegularGridDataset(\n get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\n# Compute u/v from height\ng.add_uv(\"adt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load detection files\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], +
"source": [ + "a = EddiesObservations.load_file(get_demo_path(\"Anticyclonic_20160515.nc\"))\nc = EddiesObservations.load_file(get_demo_path(\"Cyclonic_20160515.nc\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Quiver from u/v with eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(10, 5))\nax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\nax.set_xlim(19, 30), ax.set_ylim(31, 36.5), ax.grid()\nx, y = meshgrid(g.x_c, g.y_c)\na.filled(ax, facecolors=\"r\", alpha=0.1), c.filled(ax, facecolors=\"b\", alpha=0.1)\n_ = ax.quiver(x.T, y.T, g.grid(\"u\"), g.grid(\"v\"), scale=20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Anim\nParticles setup\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step_p = 1 / 8\nx, y = meshgrid(arange(13, 36, step_p), arange(28, 40, step_p))\nx, y = x.reshape(-1), y.reshape(-1)\n# Remove all original position that we can't advect at first place\nm = ~isnan(g.interp(\"u\", x, y))\nx0, y0 = x[m], y[m]\nx, y = x0.copy(), y0.copy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Movie properties\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kwargs = dict(frames=arange(51), interval=100)\nkw_p = dict(u_name=\"u\", v_name=\"v\", nb_step=2, time_step=21600)\nframe_t = kw_p[\"nb_step\"] * kw_p[\"time_step\"] / 86400.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Function\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def anim_ax(**kw):\n t = 0\n fig = plt.figure(figsize=(10, 5), dpi=55)\n axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\n axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid()\n a.filled(axes, facecolors=\"r\", alpha=0.1), c.filled(axes, facecolors=\"b\", alpha=0.1)\n line = axes.plot([], [], \"k\", **kw)[0]\n return fig, axes.text(21, 32.1, \"\"), line, t\n\n\ndef update(i_frame, t_step):\n global t\n x, y = p.__next__()\n t += t_step\n l.set_data(x, y)\n txt.set_text(f\"T0 + {t:.1f} days\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Filament forward\nDraw 3 last position in one path for each particles.,\nit could be run backward with `backward=True` option in filament method\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "p = g.filament(x, y, **kw_p, filament_size=3)\nfig, txt, l, 
t = anim_ax(lw=0.5)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Particle forward\nForward advection of particles\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "p = g.advect(x, y, **kw_p)\nfig, txt, l, t = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We get last position and run backward until original position\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "p = g.advect(x, y, **kw_p, backward=True)\nfig, txt, l, _ = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Particles stat\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Time_step settings\nDummy experiment to test advection precision, we run particles 50 days forward and backward with different time step\nand we measure distance between new positions and original positions.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\nax = fig.add_subplot(111)\nkw = dict(\n bins=arange(0, 50, 0.001),\n cumulative=True,\n weights=ones(x0.shape) / x0.shape[0] * 100.0,\n histtype=\"step\",\n)\nfor time_step in (10800, 21600, 43200, 86400):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(50 * 86400 / time_step), time_step=time_step, u_name=\"u\", v_name=\"v\")\n g.advect(x, y, **kw_advect).__next__()\n g.advect(x, y, **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"{86400. 
/ time_step:.0f} time step by day\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\"Distance after 50 days forward and 50 days backward\")\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Time duration\nWe keep same time_step but change time duration\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\nax = fig.add_subplot(111)\ntime_step = 10800\nfor duration in (5, 50, 100):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(duration * 86400 / time_step), time_step=time_step, u_name=\"u\", v_name=\"v\")\n g.advect(x, y, **kw_advect).__next__()\n g.advect(x, y, **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"Time duration {duration} days\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\n \"Distance after N days forward and N days backward\\nwith a time step of 1/8 days\"\n)\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than \")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb b/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb index f4f798a6..2d6a7d3a 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nGrid filtering in PET\n=====================\n\nHow filter work in py eddy tracker. This implementation maybe doesn't respect state art, but ...\n\nWe code a specific filter in order to filter grid with same wavelength at each pixel.\n" + "\n# Grid filtering in PET\n\nHow filter work in py eddy tracker. 
This implementation maybe doesn't respect state art, but ...\n\nWe code a specific filter in order to filter grid with same wavelength at each pixel.\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker import data\nfrom matplotlib import pyplot as plt\nfrom numpy import arange\n\n\ndef start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(m, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" + "from matplotlib import pyplot as plt\nfrom numpy import arange\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\n\n\ndef start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" ] }, { @@ -44,14 +44,14 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\",\n)" + "g = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Kernel\n------\nShape of kernel will increase in x, when latitude increase\n\n" + "## Kernel\nShape of kernel will increase in x, when latitude increase\n\n" ] }, { @@ -87,7 +87,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Kernel applying\n---------------\nOriginal grid\n\n" + "## Kernel applying\nOriginal grid\n\n" ] }, { @@ -141,7 +141,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Clues\n-----\nwavelength : 80km\n\n" + "## Clues\nwavelength : 80km\n\n" ] }, { @@ -195,7 +195,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Old filter\n----------\nTo do ...\n\n" + "## Old filter\nTo do ...\n\n" ] } ], @@ -215,7 +215,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb b/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb index 823c37c4..f30076fa 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nSelect pixel in eddies\n======================\n" + "\n# Select pixel in eddies\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom matplotlib.path import Path\nfrom numpy import ones\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.poly import create_vertice\nfrom py_eddy_tracker import data" + "from matplotlib import pyplot as plt\nfrom matplotlib.path import Path\nfrom numpy import ones\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom 
py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker.poly import create_vertice" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = EddiesObservations.load_file(data.get_path(\"Anticyclonic_20190223.nc\"))" + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))" ] }, { @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n)" + "g = RegularGridDataset(\n data.get_demo_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n)" ] }, { @@ -80,7 +80,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 6))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_aspect(\"equal\")\nax.set_xlim(10, 70)\nax.set_ylim(-50, -25)\n# We will used the outter contour\nx_name, y_name = a.intern(False)\nadt = g.grid(\"adt\")\nmask = ones(adt.shape, dtype='bool')\nfor eddy in a:\n i, j = Path(create_vertice(eddy[x_name], eddy[y_name])).pixels_in(g)\n mask[i, j] = False\nadt.mask[:] += ~mask\ng.display(ax, \"adt\")\na.display(ax, label=\"Anticyclonic\", color=\"g\", lw=1, extern_only=True)" + "fig = plt.figure(figsize=(12, 6))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_aspect(\"equal\")\nax.set_xlim(10, 70)\nax.set_ylim(-50, -25)\n# We will used the outter contour\nx_name, y_name = a.intern(False)\nadt = g.grid(\"adt\")\nmask = ones(adt.shape, dtype=\"bool\")\nfor eddy in a:\n i, j = Path(create_vertice(eddy[x_name], eddy[y_name])).pixels_in(g)\n mask[i, j] = False\nadt.mask[:] += ~mask\ng.display(ax, \"adt\")\na.display(ax, label=\"Anticyclonic\", color=\"g\", lw=1, extern_only=True)" ] }, { @@ -111,7 +111,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb b/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb new file mode 100644 index 00000000..cbe6de64 --- /dev/null +++ b/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb @@ -0,0 +1,209 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# LAVD experiment\n\nNaive method to reproduce LAVD(Lagrangian-Averaged Vorticity deviation) method with a static velocity field.\nIn the current example we didn't remove a mean vorticity.\n\nMethod are described here:\n\n - Abernathey, Ryan, and George Haller. \"Transport by Lagrangian Vortices in the Eastern Pacific\",\n Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021,\n https://doi.org/10.1175/JPO-D-17-0102.1\n - `Transport by Coherent Lagrangian Vortices`_,\n R. 
Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019,\n Talk a t the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop\n\n https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, meshgrid, zeros\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(16, 9), dpi=dpi)\n ax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\n ax.set_xlim(0, 32), ax.set_ylim(28, 46)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"Vorticity integration along trajectory at initial position\")\n return cb\n\n\nkw_vorticity = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data\nTo compute vorticity ($\\omega$) we compute u/v field with a stencil and apply the following equation with stencil\nmethod :\n\n\\begin{align}\\omega = \\frac{\\partial v}{\\partial x} - \\frac{\\partial u}{\\partial y}\\end{align}\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "g = RegularGridDataset(\n get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\ng.add_uv(\"adt\")\nu_y = g.compute_stencil(g.grid(\"u\"), vertical=True)\nv_x = g.compute_stencil(g.grid(\"v\"))\ng.vars[\"vort\"] = v_x - u_y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display vorticity field\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, ax, _ = start_ax()\nmappable = g.display(ax, abs(g.grid(\"vort\")), **kw_vorticity)\ncb = update_axes(ax, mappable)\ncb.set_label(\"Vorticity\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Particles\nParticles specification\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": 
[], + "source": [ + "step = 1 / 32\nx_g, y_g = arange(0, 36, step), arange(28, 46, step)\nx, y = meshgrid(x_g, y_g)\noriginal_shape = x.shape\nx, y = x.reshape(-1), y.reshape(-1)\nprint(f\"{len(x)} particles advected\")\n# A frame every 8h\nstep_by_day = 3\n# Compute step of advection every 4h\nnb_step = 2\nkw_p = dict(nb_step=nb_step, time_step=86400 / step_by_day / nb_step, u_name=\"u\", v_name=\"v\")\n# Start a generator which at each iteration return new position at next time step\nparticule = g.advect(x, y, **kw_p, rk4=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\n# Advection time\nnb_days = 8\n# Nb frame\nnb_time = step_by_day * nb_days\ni = 0.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Anim\nMovie of LAVD integration at each integration time step.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def update(i_frame):\n global lavd, i\n i += 1\n x, y = particule.__next__()\n # Interp vorticity on new_position\n lavd += abs(g.interp(\"vort\", x, y).reshape(original_shape) * 1 / nb_time)\n txt.set_text(f\"T0 + {i / step_by_day:.2f} days of advection\")\n pcolormesh.set_array(lavd / i * nb_time)\n return pcolormesh, txt\n\n\nkw_video = dict(frames=arange(nb_time), interval=1000.0 / step_by_day / 2, blit=True)\nfig, ax, txt = start_ax(dpi=60)\nx_g_, y_g_ = (\n arange(0 - step / 2, 36 + step / 2, step),\n arange(28 - step / 2, 46 + step / 2, step),\n)\n# pcolorfast will be faster than pcolormesh, we could use pcolorfast due to x and y are regular\npcolormesh = ax.pcolorfast(x_g_, y_g_, lavd, **kw_vorticity)\nupdate_axes(ax, pcolormesh)\n_ = VideoAnimation(ax.figure, update, **kw_video)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Final LAVD\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Format LAVD data\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"), datas=dict(lavd=lavd.T, lon=x_g, lat=y_g), centered=True\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display final LAVD with py eddy tracker detection.\nPeriod used for LAVD integration (8 days) is too short for a real use, but choose for example efficiency.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, ax, _ = start_ax()\nmappable = lavd.display(ax, \"lavd\", **kw_vorticity)\nEddiesObservations.load_file(get_demo_path(\"Anticyclonic_20160515.nc\")).display(\n ax, color=\"k\"\n)\nEddiesObservations.load_file(get_demo_path(\"Cyclonic_20160515.nc\")).display(\n ax, color=\"k\"\n)\n_ = update_axes(ax, mappable)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ 
No newline at end of file diff --git a/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb b/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb new file mode 100644 index 00000000..ca4998ee --- /dev/null +++ b/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb @@ -0,0 +1,209 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Get Okubo Weis\n\n\\begin{align}OW = S_n^2 + S_s^2 + \\omega^2\\end{align}\n\nwith normal strain ($S_n$), shear strain ($S_s$) and vorticity ($\\omega$)\n\n\\begin{align}S_n = \\frac{\\partial u}{\\partial x} - \\frac{\\partial v}{\\partial y},\n S_s = \\frac{\\partial v}{\\partial x} + \\frac{\\partial u}{\\partial y},\n \\omega = \\frac{\\partial v}{\\partial x} - \\frac{\\partial u}{\\partial y}\\end{align}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import arange, ma, where\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title, zoom=False):\n fig = plt.figure(figsize=(12, 6))\n axes = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n axes.set_xlim(0, 360), axes.set_ylim(-80, 80)\n if zoom:\n axes.set_xlim(270, 340), axes.set_ylim(20, 50)\n axes.set_aspect(\"equal\")\n axes.set_title(title)\n return axes\n\n\ndef update_axes(axes, mappable=None):\n axes.grid()\n if mappable:\n plt.colorbar(mappable, cax=axes.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load detection files\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))\nc = EddiesObservations.load_file(data.get_demo_path(\"Cyclonic_20190223.nc\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load Input grid, ADT will be used to detect eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "g = RegularGridDataset(\n data.get_demo_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n)\n\nax = start_axes(\"ADT (cm)\")\nm = g.display(ax, \"adt\", vmin=-120, vmax=120, factor=100)\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get parameter for ow\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "u_x = g.compute_stencil(g.grid(\"ugos\"))\nu_y = g.compute_stencil(g.grid(\"ugos\"), vertical=True)\nv_x = g.compute_stencil(g.grid(\"vgos\"))\nv_y = g.compute_stencil(g.grid(\"vgos\"), vertical=True)\now = g.vars[\"ow\"] = (u_x - v_y) ** 2 + (v_x + u_y) ** 2 - (v_x - u_y) ** 2\n\nax = start_axes(\"Okubo weis\")\nm = g.display(ax, \"ow\", vmin=-1e-10, vmax=1e-10, cmap=\"bwr\")\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "Gulf stream zoom\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Okubo weis, Gulf stream\", zoom=True)\nm = g.display(ax, \"ow\", vmin=-1e-10, vmax=1e-10, cmap=\"bwr\")\nkw_ed = dict(intern_only=True, color=\"k\", lw=1)\na.display(ax, **kw_ed), c.display(ax, **kw_ed)\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "only negative OW\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Okubo weis, Gulf stream\", zoom=True)\nthreshold = ow.std() * -0.2\now = ma.array(ow, mask=ow > threshold)\nm = g.display(ax, ow, vmin=-1e-10, vmax=1e-10, cmap=\"bwr\")\na.display(ax, **kw_ed), c.display(ax, **kw_ed)\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get okubo-weiss mean/min/center in eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "plt.figure(figsize=(8, 6))\nax = plt.subplot(111)\nax.set_xlabel(\"Okubo-Weiss parameter\")\nkw_hist = dict(bins=arange(-20e-10, 20e-10, 50e-12), histtype=\"step\")\nfor method in (\"mean\", \"center\", \"min\"):\n kw_interp = dict(grid_object=g, varname=\"ow\", method=method, intern=True)\n _, _, m = ax.hist(\n a.interp_grid(**kw_interp), label=f\"Anticyclonic - OW {method}\", **kw_hist\n )\n ax.hist(\n c.interp_grid(**kw_interp),\n label=f\"Cyclonic - OW {method}\",\n color=m[0].get_edgecolor(),\n ls=\"--\",\n **kw_hist,\n )\nax.axvline(threshold, color=\"r\")\nax.set_yscale(\"log\")\nax.grid()\nax.set_ylim(1, 1e4)\nax.set_xlim(-15e-10, 15e-10)\nax.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Catch eddies with bad OW\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Eddies with a min OW in speed contour over threshold\")\now_min = a.interp_grid(**kw_interp)\na_bad_ow = a.index(where(ow_min > threshold)[0])\na_bad_ow.display(ax, color=\"r\", label=\"Anticyclonic\")\now_min = c.interp_grid(**kw_interp)\nc_bad_ow = c.index(where(ow_min > threshold)[0])\nc_bad_ow.display(ax, color=\"b\", label=\"Cyclonic\")\nax.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display Radius and amplitude of eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(12, 5))\nfig.suptitle(\n \"Parameter distribution (solid line) and cumulative distribution (dashed line)\"\n)\nax_amp, ax_rad = fig.add_subplot(121), fig.add_subplot(122)\nax_amp_c, ax_rad_c = ax_amp.twinx(), ax_rad.twinx()\nax_amp_c.set_ylim(0, 1), ax_rad_c.set_ylim(0, 1)\nkw_a = dict(xname=\"amplitude\", bins=arange(0, 2, 0.002).astype(\"f4\"))\nkw_r = dict(xname=\"radius_s\", bins=arange(0, 500e6, 2e3).astype(\"f4\"))\nfor d, label, color in (\n (a, \"Anticyclonic all\", \"r\"),\n (a_bad_ow, \"Anticyclonic bad OW\", \"orange\"),\n (c, \"Cyclonic all\", \"blue\"),\n (c_bad_ow, \"Cyclonic bad OW\", \"lightblue\"),\n):\n x, y = d.bins_stat(**kw_a)\n ax_amp.plot(x * 100, y, label=label, color=color)\n ax_amp_c.plot(\n x * 100, y.cumsum() / y.sum(), label=label, 
color=color, ls=\"-.\", lw=0.5\n )\n x, y = d.bins_stat(**kw_r)\n ax_rad.plot(x * 1e-3, y, label=label, color=color)\n ax_rad_c.plot(\n x * 1e-3, y.cumsum() / y.sum(), label=label, color=color, ls=\"-.\", lw=0.5\n )\n\nax_amp.set_xlim(0, 12.5), ax_amp.grid(), ax_amp.set_ylim(0), ax_amp.legend()\nax_rad.set_xlim(0, 120), ax_rad.grid(), ax_rad.set_ylim(0)\nax_amp.set_xlabel(\"Amplitude (cm)\"), ax_amp.set_ylabel(\"Nb eddies\")\nax_rad.set_xlabel(\"Speed radius (km)\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb b/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb new file mode 100644 index 00000000..d4cdb187 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nTime advection\n==============\n\nExample which use CMEMS surface current with a Runge-Kutta 4 algorithm to advect particles.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# sphinx_gallery_thumbnail_number = 2\nimport re\nfrom datetime import datetime, timedelta\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, isnan, meshgrid, ones\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\nfrom py_eddy_tracker.gui import GUI_AXES\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Data\n----\nLoad Input time grid ADT\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n # To create U/V variable\n heigth=\"adt\",\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Anim\n----\nParticles setup\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + 
"outputs": [], + "source": [ + "step_p = 1 / 8\nx, y = meshgrid(arange(13, 36, step_p), arange(28, 40, step_p))\nx, y = x.reshape(-1), y.reshape(-1)\n# Remove all original position that we can't advect at first place\nt0 = 20181\nm = ~isnan(c[t0].interp(\"u\", x, y))\nx0, y0 = x[m], y[m]\nx, y = x0.copy(), y0.copy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Function\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def anim_ax(**kw):\n fig = plt.figure(figsize=(10, 5), dpi=55)\n axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\n axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid()\n line = axes.plot([], [], \"k\", **kw)[0]\n return fig, axes.text(21, 32.1, \"\"), line\n\n\ndef update(_):\n tt, xt, yt = f.__next__()\n mappable.set_data(xt, yt)\n d = timedelta(tt / 86400.0) + datetime(1950, 1, 1)\n txt.set_text(f\"{d:%Y/%m/%d-%H}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "f = c.filament(x, y, \"u\", \"v\", t_init=t0, nb_step=2, time_step=21600, filament_size=3)\nfig, txt, mappable = anim_ax(lw=0.5)\nani = VideoAnimation(fig, update, frames=arange(160), interval=100)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Particules stat\n---------------\nTime_step settings\n^^^^^^^^^^^^^^^^^^\nDummy experiment to test advection precision, we run particles 50 days forward and backward with different time step\nand we measure distance between new positions and original positions.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\nax = fig.add_subplot(111)\nkw = dict(\n bins=arange(0, 50, 0.002),\n cumulative=True,\n weights=ones(x0.shape) / x0.shape[0] * 100.0,\n histtype=\"step\",\n)\nkw_p = dict(u_name=\"u\", v_name=\"v\", nb_step=1)\nfor time_step in (10800, 21600, 43200, 86400):\n x, y = x0.copy(), y0.copy()\n nb = int(30 * 86400 / time_step)\n # Go forward\n p = c.advect(x, y, time_step=time_step, t_init=20181.5, **kw_p)\n for i in range(nb):\n t_, _, _ = p.__next__()\n # Go backward\n p = c.advect(x, y, time_step=time_step, backward=True, t_init=t_ / 86400.0, **kw_p)\n for i in range(nb):\n t_, _, _ = p.__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"{86400. 
/ time_step:.0f} time step by day\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\"Distance after 50 days forward and 50 days backward\")\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Time duration\n^^^^^^^^^^^^^\nWe keep same time_step but change time duration\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\nax = fig.add_subplot(111)\ntime_step = 10800\nfor duration in (10, 40, 80):\n x, y = x0.copy(), y0.copy()\n nb = int(duration * 86400 / time_step)\n # Go forward\n p = c.advect(x, y, time_step=time_step, t_init=20181.5, **kw_p)\n for i in range(nb):\n t_, _, _ = p.__next__()\n # Go backward\n p = c.advect(x, y, time_step=time_step, backward=True, t_init=t_ / 86400.0, **kw_p)\n for i in range(nb):\n t_, _, _ = p.__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"Time duration {duration} days\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\n \"Distance after N days forward and N days backward\\nwith a time step of 1/8 days\"\n)\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than \")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb b/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb new file mode 100644 index 00000000..6f52e750 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb @@ -0,0 +1,180 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# FSLE experiment in med\n\nExample to build Finite Size Lyapunov Exponents, parameter values must be adapted for your case.\n\nExample use a method similar to `AVISO flse`_\n\n https://www.aviso.altimetry.fr/en/data/products/value-added-products/\n fsle-finite-size-lyapunov-exponents/fsle-description.html\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numba import njit\nfrom numpy import arange, arctan2, empty, isnan, log, ma, meshgrid, ones, pi, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ADT in med\n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_cube` method is\nmade for data stores in time cube, 
you could use also\n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` method to\nload data-cube from multiple file.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n # To create U/V variable\n heigth=\"adt\",\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Methods to compute FSLE\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "@njit(cache=True, fastmath=True)\ndef check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6):\n \"\"\"\n Check if distance between eastern or northern particle to center particle is bigger than `dist_max`\n \"\"\"\n nb_p = x.shape[0] // 3\n delta = dist_max**2\n for i in range(nb_p):\n i0 = i * 3\n i_n = i0 + 1\n i_e = i0 + 2\n # If particle already set, we skip\n if m[i0] or m[i_n] or m[i_e]:\n continue\n # Distance with north\n dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n]\n dn = dxn**2 + dyn**2\n # Distance with east\n dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e]\n de = dxe**2 + dye**2\n\n if dn >= delta or de >= delta:\n s1 = dn + de\n at1 = 2 * (dxe * dxn + dye * dyn)\n at2 = de - dn\n s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * (\n (dxn - dye) ** 2 + (dxe + dyn) ** 2\n )\n flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init**2) * (s1 + s2**0.5))\n theta[i] = arctan2(at1, at2 + s2) * 180 / pi\n # To know where value are set\n m_set[i] = False\n # To stop particle advection\n m[i0], m[i_n], m[i_e] = True, True, True\n\n\n@njit(cache=True)\ndef build_triplet(x, y, step=0.02):\n \"\"\"\n Triplet building for each position we add east and north point with defined step\n \"\"\"\n nb_x = x.shape[0]\n x_ = empty(nb_x * 3, dtype=x.dtype)\n y_ = empty(nb_x * 3, dtype=y.dtype)\n for i in range(nb_x):\n i0 = i * 3\n i_n, i_e = i0 + 1, i0 + 2\n x__, y__ = x[i], y[i]\n x_[i0], y_[i0] = x__, y__\n x_[i_n], y_[i_n] = x__, y__ + step\n x_[i_e], y_[i_e] = x__ + step, y__\n return x_, y_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Settings\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Step in degrees for ouput\nstep_grid_out = 1 / 25.0\n# Initial separation in degrees\ndist_init = 1 / 50.0\n# Final separation in degrees\ndist_max = 1 / 5.0\n# Time of start\nt0 = 20268\n# Number of time step by days\ntime_step_by_days = 5\n# Maximal time of advection\n# Here we limit because our data cube cover only 3 month\nnb_days = 85\n# Backward or forward\nbackward = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Particles\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0_, y0_ = -5, 30\nlon_p = arange(x0_, x0_ + 43, step_grid_out)\nlat_p = arange(y0_, y0_ + 16, step_grid_out)\ny0, x0 = meshgrid(lat_p, lon_p)\ngrid_shape = x0.shape\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\n# Identify all particle not on land\nm = ~isnan(c[t0].interp(\"adt\", x0, y0))\nx0, y0 = x0[m], y0[m]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FSLE\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + 
"collapsed": false + }, + "outputs": [], + "source": [ + "# Array to compute fsle\nfsle = zeros(x0.shape[0], dtype=\"f4\")\ntheta = zeros(x0.shape[0], dtype=\"f4\")\nmask = ones(x0.shape[0], dtype=\"f4\")\nx, y = build_triplet(x0, y0, dist_init)\nused = zeros(x.shape[0], dtype=\"bool\")\n\n# advection generator\nkw = dict(t_init=t0, nb_step=1, backward=backward, mask_particule=used, u_name=\"u\", v_name=\"v\")\np = c.advect(x, y, time_step=86400 / time_step_by_days, **kw)\n\n# We check at each step of advection if particle distance is over `dist_max`\nfor i in range(time_step_by_days * nb_days):\n t, xt, yt = p.__next__()\n dt = t / 86400.0 - t0\n check_p(xt, yt, fsle, theta, mask, used, dt, dist_max=dist_max, dist_init=dist_init)\n\n# Get index with original_position\ni = ((x0 - x0_) / step_grid_out).astype(\"i4\")\nj = ((y0 - y0_) / step_grid_out).astype(\"i4\")\nfsle_ = empty(grid_shape, dtype=\"f4\")\ntheta_ = empty(grid_shape, dtype=\"f4\")\nmask_ = ones(grid_shape, dtype=\"bool\")\nfsle_[i, j] = fsle\ntheta_[i, j] = theta\nmask_[i, j] = mask\n# Create a grid object\nfsle_custom = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"),\n datas=dict(\n fsle=ma.array(fsle_, mask=mask_),\n theta=ma.array(theta_, mask=mask_),\n lon=lon_p,\n lat=lat_p,\n ),\n centered=True,\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display FSLE\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(13, 5), dpi=150)\nax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\nax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\nax.set_aspect(\"equal\")\nax.set_title(\"Finite size lyapunov exponent\", weight=\"bold\")\nkw = dict(cmap=\"viridis_r\", vmin=-20, vmax=0)\nm = fsle_custom.display(ax, 1 / fsle_custom.grid(\"fsle\"), **kw)\nax.grid()\n_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display Theta\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(13, 5), dpi=150)\nax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\nax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\nax.set_aspect(\"equal\")\nax.set_title(\"Theta from finite size lyapunov exponent\", weight=\"bold\")\nkw = dict(cmap=\"Spectral_r\", vmin=-180, vmax=180)\nm = fsle_custom.display(ax, fsle_custom.grid(\"theta\"), **kw)\nax.grid()\n_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb b/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb new file mode 100644 index 00000000..708d7024 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# LAVD detection and geometric detection\n\nNaive method to reproduce LAVD(Lagrangian-Averaged Vorticity deviation).\nIn the current example we didn't remove a mean vorticity.\n\nMethod are described here:\n\n - Abernathey, Ryan, and George Haller. \"Transport by Lagrangian Vortices in the Eastern Pacific\",\n Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021,\n https://doi.org/10.1175/JPO-D-17-0102.1\n - `Transport by Coherent Lagrangian Vortices`_,\n R. Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019,\n Talk a t the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop\n\n https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom numpy import arange, isnan, ma, meshgrid, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset\nfrom py_eddy_tracker.gui import GUI_AXES\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class LAVDGrid(RegularGridDataset):\n def init_speed_coef(self, uname=\"u\", vname=\"v\"):\n \"\"\"Hack to be able to identify eddy with LAVD field\"\"\"\n self._speed_ev = self.grid(\"lavd\")\n\n @classmethod\n def from_(cls, x, y, z):\n z.mask += isnan(z.data)\n datas = dict(lavd=z, lon=x, lat=y)\n return cls.with_array(coordinates=(\"lon\", \"lat\"), datas=datas, centered=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(12, 5), dpi=dpi)\n ax = fig.add_axes([0.05, 0.08, 0.9, 0.9], projection=GUI_AXES)\n ax.set_xlim(-6, 36), ax.set_ylim(31, 45)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"LAVD at initial position\")\n return cb\n\n\nkw_lavd = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Load data cube of 3 month\nc = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n heigth=\"adt\",\n)\n\n# Add vorticity at each time step\nfor g in c:\n u_y = g.compute_stencil(g.grid(\"u\"), vertical=True)\n v_x = g.compute_stencil(g.grid(\"v\"))\n g.vars[\"vort\"] = v_x - u_y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Particles\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Time properties, for example with advection only 25 days\nnb_days, step_by_day = 25, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day, 
u_name=\"u\", v_name=\"v\")\nt0 = 20236\nt0_grid = c[t0]\n# Geographic properties, we use a coarser resolution for time consuming reasons\nstep = 1 / 32.0\nx_g, y_g = arange(-6, 36, step), arange(30, 46, step)\nx0, y0 = meshgrid(x_g, y_g)\noriginal_shape = x0.shape\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\n# Get all particles in defined area\nm = ~isnan(t0_grid.interp(\"vort\", x0, y0))\nx0, y0 = x0[m], y0[m]\nprint(f\"{x0.size} particles advected\")\n# Gridded mask\nm = m.reshape(original_shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD forward (dynamic field)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), t_init=t0, **kw_p)\nfor _ in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection\")\nmappable = lavd_forward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD backward (dynamic field)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), t_init=t0, backward=True, **kw_p)\nfor i in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection\")\nmappable = lavd_backward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD forward (static field)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), **kw_p)\nfor _ in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection on a static velocity field\")\nmappable = lavd_forward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD backward (static field)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), backward=True, **kw_p)\nfor i in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection on a static velocity 
field\")\nmappable = lavd_backward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Contour detection\nTo extract contour from LAVD grid, we will used method design for SSH, with some hacks and adapted options.\nIt will produce false amplitude and speed.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw_ident = dict(\n force_speed_unit=\"m/s\",\n force_height_unit=\"m\",\n pixel_limit=(40, 200000),\n date=datetime(2005, 5, 18),\n uname=None,\n vname=None,\n grid_height=\"lavd\",\n shape_error=70,\n step=1e-6,\n)\nfig, ax, _ = start_ax(\"Detection of eddies with several method\")\nt0_grid.bessel_high_filter(\"adt\", 700)\na, c = t0_grid.eddy_identification(\n \"adt\", \"u\", \"v\", kw_ident[\"date\"], step=0.002, shape_error=70\n)\nkw_ed = dict(ax=ax, intern=True, ref=-10)\na.filled(\n facecolors=\"#FFEFCD\", label=\"Anticyclonic SSH detection {nb_obs} eddies\", **kw_ed\n)\nc.filled(facecolors=\"#DEDEDE\", label=\"Cyclonic SSH detection {nb_obs} eddies\", **kw_ed)\nkw_cont = dict(ax=ax, extern_only=True, ls=\"-\", ref=-10)\nforward, _ = lavd_forward.eddy_identification(**kw_ident)\nforward.display(label=\"LAVD forward {nb_obs} eddies\", color=\"g\", **kw_cont)\nbackward, _ = lavd_backward.eddy_identification(**kw_ident)\nbackward.display(label=\"LAVD backward {nb_obs} eddies\", color=\"r\", **kw_cont)\nforward, _ = lavd_forward_static.eddy_identification(**kw_ident)\nforward.display(label=\"LAVD forward static {nb_obs} eddies\", color=\"cyan\", **kw_cont)\nbackward, _ = lavd_backward_static.eddy_identification(**kw_ident)\nbackward.display(\n label=\"LAVD backward static {nb_obs} eddies\", color=\"orange\", **kw_cont\n)\nax.legend()\nupdate_axes(ax)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb b/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb new file mode 100644 index 00000000..b92c4d21 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb @@ -0,0 +1,126 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Build path of particle drifting\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import arange, meshgrid\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load data cube\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ 
+ "c = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n unset=True\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Advection properties\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "nb_days, step_by_day = 10, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day)\nt0 = 20210" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get paths\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0, y0 = meshgrid(arange(32, 35, 0.5), arange(32.5, 34.5, 0.5))\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\nt, x, y = c.path(x0, y0, h_name=\"adt\", t_init=t0, **kw_p, nb_time=nb_time)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot paths\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(9, 6)).add_subplot(111, aspect=\"equal\")\nax.plot(x0, y0, \"k.\", ms=20)\nax.plot(x, y, lw=3)\nax.set_title(\"10 days particle paths\")\nax.set_xlim(31, 35), ax.set_ylim(32, 34.5)\nax.grid()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb index f260e137..6e43e9a4 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nDisplay fields\n==============\n" + "\n# Display fields\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nimport py_eddy_tracker_sample" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "c = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nc = c.extract_with_length((180, -1))" + "c = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nc = c.extract_with_length((180, -1))" ] }, { @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb index fbc3a93c..c98e53f0 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb +++ 
b/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nDisplay Tracks\n======================\n" + "\n# Display Tracks\n" ] }, { @@ -26,14 +26,14 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nimport py_eddy_tracker_sample" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Load experimental atlas, and keep only eddies longer than 20 weeks\n\n" + "Load experimental atlas\n\n" ] }, { @@ -44,7 +44,25 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\na = a.extract_with_length((7 * 20, -1))\nc = c.extract_with_length((7 * 20, -1))" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nprint(a)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "keep only eddies longer than 20 weeks, use -1 to have no upper limit\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = a.extract_with_length((7 * 20, -1))\nc = c.extract_with_length((7 * 20, -1))\nprint(a)" ] }, { @@ -80,7 +98,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 5))\nax = fig.add_axes((0.05, 0.1, 0.9, 0.9))\nax.set_aspect(\"equal\")\nax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\na.plot(ax, ref=-10, label=\"Anticyclonic\", color=\"r\", lw=0.1)\nc.plot(ax, ref=-10, label=\"Cyclonic\", color=\"b\", lw=0.1)\nax.legend()\nax.grid()" + "fig = plt.figure(figsize=(12, 5))\nax = fig.add_axes((0.05, 0.1, 0.9, 0.9))\nax.set_aspect(\"equal\")\nax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\na.plot(ax, ref=-10, label=\"Anticyclonic ({nb_tracks} tracks)\", color=\"r\", lw=0.1)\nc.plot(ax, ref=-10, label=\"Cyclonic ({nb_tracks} tracks)\", color=\"b\", lw=0.1)\nax.legend()\nax.grid()" ] } ], @@ -100,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb new file mode 100644 index 00000000..0681c0fc --- /dev/null +++ b/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Correspondances\n\nCorrespondances is a mechanism to intend to continue tracking with new detection\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import logging" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom 
netCDF4 import Dataset\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.featured_tracking.area_tracker import AreaTracker\n\n# In order to hide some warnings\nimport py_eddy_tracker.observations.observation\nfrom py_eddy_tracker.tracking import Correspondances\n\npy_eddy_tracker.observations.observation._display_check_warning = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def plot_eddy(ed):\n fig = plt.figure(figsize=(10, 5))\n ax = fig.add_axes([0.05, 0.03, 0.90, 0.94])\n ed.plot(ax, ref=-10, marker=\"x\")\n lc = ed.display_color(ax, field=ed.time, ref=-10, intern=True)\n plt.colorbar(lc).set_label(\"Time in Julian days (from 1950/01/01)\")\n ax.set_xlim(4.5, 8), ax.set_ylim(36.8, 38.3)\n ax.set_aspect(\"equal\")\n ax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get remote data; we will keep only the first 20 days.\nThe `get_remote_demo_sample` function is only used to get the demo dataset; in your own case give a list of identification filenames\nand don't mix cyclonic and anticyclonic files.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:20]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We run a tracking with a tracker which uses contour overlap, on the first 10 time steps\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c_first_run = Correspondances(\n datasets=file_objects[:10], class_method=AreaTracker, virtual=4\n)\nstart_logger().setLevel(\"INFO\")\nc_first_run.track()\nstart_logger().setLevel(\"WARNING\")\nwith Dataset(\"correspondances.nc\", \"w\") as h:\n c_first_run.to_netcdf(h)\n# Next steps are done only to build the atlas and display it\nc_first_run.prepare_merging()\n\n# We have now an eddy object\neddies_area_tracker = c_first_run.merge(raw_data=False)\neddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0\neddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot from the first ten days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "plot_eddy(eddies_area_tracker)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Restart from previous run\nWe give all filenames, the new ones and the filenames from the previous run\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c_second_run = Correspondances(\n datasets=file_objects[:20],\n # This parameter must be identical in each run\n class_method=AreaTracker,\n virtual=4,\n # Previously saved correspondances\n previous_correspondance=\"correspondances.nc\",\n)\nstart_logger().setLevel(\"INFO\")\nc_second_run.track()\nstart_logger().setLevel(\"WARNING\")\nc_second_run.prepare_merging()\n# We have now another eddy object\neddies_area_tracker_extend = c_second_run.merge(raw_data=False)\neddies_area_tracker_extend.virtual[:] = eddies_area_tracker_extend.time == 0\neddies_area_tracker_extend.filled_by_interpolation(\n 
eddies_area_tracker_extend.virtual == 1\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot with time extension\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "plot_eddy(eddies_area_tracker_extend)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb index 21455f51..95595a7a 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nOne Track\n===================\n" + "\n# One Track\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nimport py_eddy_tracker_sample" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\neddy = a.extract_ids([9672])\neddy_f = a.extract_ids([9672])\neddy_f.position_filter(median_half_window=1, loess_half_window=5)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\neddy = a.extract_ids([9672])\neddy_f = a.extract_ids([9672])\neddy_f.position_filter(median_half_window=1, loess_half_window=5)" ] }, { @@ -73,7 +73,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 5))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_xlim(17, 23)\nax.set_ylim(34.5, 37)\nax.set_aspect(\"equal\")\nax.grid()\neddy.plot(ax, color=\"r\", lw=0.5, label=\"track\")\neddy.index(range(0, len(eddy), 40)).display(\n ax, intern_only=True, label=\"observations every 40\"\n)\nax.legend()" + "fig = plt.figure(figsize=(12, 5))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_xlim(17, 23)\nax.set_ylim(34.5, 37)\nax.set_aspect(\"equal\")\nax.grid()\neddy.plot(ax, color=\"r\", lw=0.5, label=\"track\")\neddy.index(range(0, len(eddy), 40)).display(\n ax, intern_only=True, label=\"observations every 40 days\"\n)\nax.legend()" ] } ], @@ -93,7 +93,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb index 3f00ff96..d0a2e5b0 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nTrack in python\n===============\n\nThis 
example didn't replace EddyTracking, we remove check that application do and also postprocessing step.\n" + "\n# Track in python\n\nThis example didn't replace EddyTracking, we remove check that application do and also postprocessing step.\n" ] }, { @@ -26,14 +26,14 @@ }, "outputs": [], "source": [ - "from py_eddy_tracker.data import get_remote_sample\nfrom py_eddy_tracker.tracking import Correspondances\nfrom py_eddy_tracker.featured_tracking.area_tracker import AreaTracker\nfrom py_eddy_tracker.gui import GUI" + "from py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.featured_tracking.area_tracker import AreaTracker\nfrom py_eddy_tracker.gui import GUI\nfrom py_eddy_tracker.tracking import Correspondances" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Get remote data, we will keep only 180 first days\n\n" + "Get remote data, we will keep only 180 first days,\n`get_remote_demo_sample` function is only to get demo dataset, in your own case give a list of identification filename\nand don't mix cyclonic and anticyclonic files.\n\n" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "file_objects = get_remote_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:180]" + "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:180]" ] }, { @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "c = Correspondances(datasets=file_objects, class_method=AreaTracker, virtual=3)\nc.track()\nc.prepare_merging()\n# We have now an eddy object\neddies_area_tracker = c.merge(raw_data=False)\neddies_area_tracker[\"virtual\"][:] = eddies_area_tracker[\"time\"] == 0\neddies_area_tracker.filled_by_interpolation(eddies_area_tracker[\"virtual\"] == 1)" + "c = Correspondances(datasets=file_objects, class_method=AreaTracker, virtual=3)\nc.track()\nc.prepare_merging()\n# We have now an eddy object\neddies_area_tracker = c.merge(raw_data=False)\neddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0\neddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1)" ] }, { @@ -80,7 +80,7 @@ }, "outputs": [], "source": [ - "c = Correspondances(datasets=file_objects, virtual=3)\nc.track()\nc.prepare_merging()\neddies_default_tracker = c.merge(raw_data=False)\neddies_default_tracker[\"virtual\"][:] = eddies_default_tracker[\"time\"] == 0\neddies_default_tracker.filled_by_interpolation(eddies_default_tracker[\"virtual\"] == 1)" + "c = Correspondances(datasets=file_objects, virtual=3)\nc.track()\nc.prepare_merging()\neddies_default_tracker = c.merge(raw_data=False)\neddies_default_tracker.virtual[:] = eddies_default_tracker.time == 0\neddies_default_tracker.filled_by_interpolation(eddies_default_tracker.virtual == 1)" ] }, { @@ -154,7 +154,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb index 1cdefe1f..8e64b680 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nTracks which go through area\n============================\n" + "\n# Tracks which go through area\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], 
"source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nimport py_eddy_tracker_sample" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "c = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nc.position_filter(median_half_window=1, loess_half_window=5)" + "c = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nc.position_filter(median_half_window=1, loess_half_window=5)" ] }, { @@ -80,7 +80,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 5))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_xlim(-1, 9)\nax.set_ylim(36, 40)\nax.set_aspect(\"equal\")\nax.grid()\nc.plot(ax, color=\"gray\", lw=0.1, ref=-10, label=\"all tracks\")\nc_subset.plot(ax, color=\"red\", lw=0.2, ref=-10, label=\"selected tracks\")\nax.plot(\n (x0, x0, x1, x1, x0,),\n (y0, y1, y1, y0, y0,),\n color=\"green\",\n lw=1.5,\n label=\"Box of selection\",\n)\nax.legend()" + "fig = plt.figure(figsize=(12, 5))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_xlim(-1, 9)\nax.set_ylim(36, 40)\nax.set_aspect(\"equal\")\nax.grid()\nc.plot(ax, color=\"gray\", lw=0.1, ref=-10, label=\"All tracks ({nb_tracks} tracks)\")\nc_subset.plot(\n ax, color=\"red\", lw=0.2, ref=-10, label=\"selected tracks ({nb_tracks} tracks)\"\n)\nax.plot(\n (x0, x0, x1, x1, x0),\n (y0, y1, y1, y0, y0),\n color=\"green\",\n lw=1.5,\n label=\"Box of selection\",\n)\nax.legend()" ] } ], @@ -100,7 +100,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb index e9c9e4dc..08364d16 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nTrack animation\n===============\n\nRun in a terminal this script, which allow to watch eddy evolution\n" + "\nTrack animation\n===============\n\nRun in a terminal this script, which allow to watch eddy evolution.\n\nYou could use also *EddyAnim* script to display/save animation.\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from py_eddy_tracker.observations.tracking import TrackEddiesObservations\nfrom py_eddy_tracker.appli.gui import Anim\nimport py_eddy_tracker_sample" + "import py_eddy_tracker_sample\n\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -44,14 +44,14 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\neddy = a.extract_ids([9672])" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\n# We get only 300 first step to save time of documentation builder\neddy = a.extract_ids([9672]).index(slice(0, 300))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Run animation\nKey shortcut\n Escape => exit\n 
SpaceBar => pause\n left arrow => t - 1\n right arrow => t + 1\n + => speed increase of 10 %\n - => speed decrease of 10 %\n\n" + "Run animation\nKey shortcut :\n * Escape => exit\n * SpaceBar => pause\n * left arrow => t - 1\n * right arrow => t + 1\n * \\+ => speed increase of 10 %\n * \\- => speed decrease of 10 %\n\n" ] }, { diff --git a/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb new file mode 100644 index 00000000..1fc4d082 --- /dev/null +++ b/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb @@ -0,0 +1,101 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nTrack animation with standard matplotlib\n========================================\n\nRun in a terminal this script, which allow to watch eddy evolution.\n\nYou could use also *EddyAnim* script to display/save animation.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nimport py_eddy_tracker_sample\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange\n\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\n\n# sphinx_gallery_thumbnail_path = '_static/no_image.png'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load experimental atlas, and we select one eddy\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\neddy = a.extract_ids([9672])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run animation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = Anim(eddy, intern=True, figsize=(8, 3.5), cmap=\"magma_r\", nb_step=5, dpi=50)\na.txt.set_position((17, 34.6))\na.ax.set_xlim(16.5, 23)\na.ax.set_ylim(34.5, 37)\n\n# arguments to get full animation\nkwargs = dict(frames=arange(*a.period)[300:800], interval=90)\n\nani = VideoAnimation(a.fig, a.func_animation, **kwargs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": 
".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb new file mode 100644 index 00000000..635c6b5a --- /dev/null +++ b/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb @@ -0,0 +1,152 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Birth and death\n\nFollowing figures are based on https://doi.org/10.1016/j.pocean.2011.01.002\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load an experimental med atlas over a period of 26 years (1993-2019)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kwargs_load = dict(\n include_vars=(\n \"longitude\",\n \"latitude\",\n \"observation_number\",\n \"track\",\n \"time\",\n \"speed_contour_longitude\",\n \"speed_contour_latitude\",\n )\n)\na = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "t0, t1 = a.period\nstep = 0.125\nbins = ((-10, 37, step), (30, 46, step))\nkwargs = dict(cmap=\"terrain_r\", factor=100 / (t1 - t0), name=\"count\", vmin=0, vmax=1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cyclonic\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Birth cyclonic frenquency (%)\")\ng_c_first = c.first_obs().grid_count(bins, intern=True)\nm = g_c_first.display(ax, **kwargs)\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Death cyclonic frenquency (%)\")\ng_c_last = c.last_obs().grid_count(bins, intern=True)\nm = g_c_last.display(ax, **kwargs)\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Anticyclonic\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Birth anticyclonic frequency (%)\")\ng_a_first = a.first_obs().grid_count(bins, intern=True)\nm = g_a_first.display(ax, **kwargs)\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"Death anticyclonic frequency (%)\")\ng_a_last = a.last_obs().grid_count(bins, intern=True)\nm = g_a_last.display(ax, **kwargs)\nupdate_axes(ax, m)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb index 644a9277..753cd625 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nCount center\n======================\n" + "\n# Count center\n\nDo Geo stat with center and compare with frequency method\nshow: `sphx_glr_python_module_10_tracking_diagnostics_pet_pixel_used.py`\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom matplotlib.colors import LogNorm\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nimport py_eddy_tracker_sample" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom matplotlib.colors import LogNorm\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -44,7 +44,25 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\n\n# Parameters\nt0, t1 = a.period\nstep = 0.1\nbins = ((-10, 37, step), (30, 46, step))\nkwargs_pcolormesh = dict(\n cmap=\"terrain_r\", vmin=0, vmax=2, factor=1 / (step ** 2 * (t1 - t0)), name=\"count\"\n)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Parameters\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 0.125\nbins = ((-10, 37, step), (30, 46, step))\nkwargs_pcolormesh = dict(\n cmap=\"terrain_r\", vmin=0, vmax=2, factor=1 / (a.nb_days * step ** 2), name=\"count\"\n)" ] }, { @@ -62,7 +80,25 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 18.5))\nax_a = fig.add_axes([0.03, 0.75, 0.90, 0.25])\nax_a.set_title(\"Anticyclonic center frequency\")\nax_c = fig.add_axes([0.03, 0.5, 0.90, 0.25])\nax_c.set_title(\"Cyclonic center frequency\")\nax_all = 
fig.add_axes([0.03, 0.25, 0.90, 0.25])\nax_all.set_title(\"All eddies center frequency\")\nax_ratio = fig.add_axes([0.03, 0.0, 0.90, 0.25])\nax_ratio.set_title(\"Ratio cyclonic / Anticyclonic\")\n\n# Count pixel used for each center\ng_a = a.grid_count(bins, intern=True, center=True)\ng_a.display(ax_a, **kwargs_pcolormesh)\ng_c = c.grid_count(bins, intern=True, center=True)\ng_c.display(ax_c, **kwargs_pcolormesh)\n# Compute a ratio Cyclonic / Anticyclonic\nratio = g_c.vars[\"count\"] / g_a.vars[\"count\"]\n\n# Mask manipulation to be able to sum the 2 grids\nm_c = g_c.vars[\"count\"].mask\nm = m_c & g_a.vars[\"count\"].mask\ng_c.vars[\"count\"][m_c] = 0\ng_c.vars[\"count\"] += g_a.vars[\"count\"]\ng_c.vars[\"count\"].mask = m\n\nm = g_c.display(ax_all, **kwargs_pcolormesh)\ncb = plt.colorbar(m, cax=fig.add_axes([0.94, 0.27, 0.01, 0.7]))\ncb.set_label(\"Eddies by 1\u00b0^2 by day\")\n\ng_c.vars[\"count\"] = ratio\nm = g_c.display(ax_ratio, name=\"count\", vmin=0.1, vmax=10, norm=LogNorm())\nplt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.2]))\n\nfor ax in (ax_a, ax_c, ax_all, ax_ratio):\n ax.set_aspect(\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.grid()" + "fig = plt.figure(figsize=(12, 18.5))\nax_a = fig.add_axes([0.03, 0.75, 0.90, 0.25])\nax_a.set_title(\"Anticyclonic center frequency\")\nax_c = fig.add_axes([0.03, 0.5, 0.90, 0.25])\nax_c.set_title(\"Cyclonic center frequency\")\nax_all = fig.add_axes([0.03, 0.25, 0.90, 0.25])\nax_all.set_title(\"All eddies center frequency\")\nax_ratio = fig.add_axes([0.03, 0.0, 0.90, 0.25])\nax_ratio.set_title(\"Ratio cyclonic / Anticyclonic\")\n\n# Count pixel used for each center\ng_a = a.grid_count(bins, intern=True, center=True)\ng_a.display(ax_a, **kwargs_pcolormesh)\ng_c = c.grid_count(bins, intern=True, center=True)\ng_c.display(ax_c, **kwargs_pcolormesh)\n# Compute a ratio Cyclonic / Anticyclonic\nratio = g_c.vars[\"count\"] / g_a.vars[\"count\"]\n\n# Mask manipulation to be able to sum the 2 grids\nm_c = g_c.vars[\"count\"].mask\nm = m_c & g_a.vars[\"count\"].mask\ng_c.vars[\"count\"][m_c] = 0\ng_c.vars[\"count\"] += g_a.vars[\"count\"]\ng_c.vars[\"count\"].mask = m\n\nm = g_c.display(ax_all, **kwargs_pcolormesh)\ncb = plt.colorbar(m, cax=fig.add_axes([0.94, 0.27, 0.01, 0.7]))\ncb.set_label(\"Eddies by 1\u00b0^2 by day\")\n\ng_c.vars[\"count\"] = ratio\nm = g_c.display(\n ax_ratio, name=\"count\", norm=LogNorm(vmin=0.1, vmax=10), cmap=\"coolwarm_r\"\n)\nplt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.2]))\n\nfor ax in (ax_a, ax_c, ax_all, ax_ratio):\n ax.set_aspect(\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Count Anticyclones as a function of lifetime\nCount at the center's position\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(12, 10))\nmask = a.lifetime >= 60\nax_long = fig.add_axes([0.03, 0.53, 0.90, 0.45])\ng_a = a.grid_count(bins, center=True, filter=mask)\ng_a.display(ax_long, **kwargs_pcolormesh)\nax_long.set_title(f\"Anticyclones with lifetime >= 60 days ({mask.sum()} Obs)\")\nax_short = fig.add_axes([0.03, 0.03, 0.90, 0.45])\ng_a = a.grid_count(bins, center=True, filter=~mask)\nm = g_a.display(ax_short, **kwargs_pcolormesh)\nax_short.set_title(f\"Anticyclones with lifetime < 60 days ({(~mask).sum()} Obs)\")\nfor ax in (ax_short, ax_long):\n ax.set_aspect(\"equal\"), ax.grid()\n 
ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\ncb = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.015, 0.9]))" ] } ], @@ -82,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb index e4d31e85..df495703 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nGeographical statistics\n=======================\n" + "\n# Geographical statistics\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nimport py_eddy_tracker_sample\n\n\ndef start_axes(title):\n fig = plt.figure(figsize=(13.5, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\n\n\ndef start_axes(title):\n fig = plt.figure(figsize=(13.5, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\na = a.merge(c)\n\nstep = 0.1" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\na = a.merge(c)\n\nstep = 0.1" ] }, { @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb new file mode 100644 index 00000000..9f06e010 --- /dev/null +++ b/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb @@ -0,0 +1,144 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Groups distribution\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom numpy import arange, ones, percentile\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load an experimental med atlas over a period of 26 years (1993-2019)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + 
"outputs": [], + "source": [ + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Group distribution\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "groups = dict()\nbins_time = [10, 20, 30, 60, 90, 180, 360, 100000]\nfor t0, t1 in zip(bins_time[:-1], bins_time[1:]):\n groups[f\"lifetime_{t0}_{t1}\"] = lambda dataset, t0=t0, t1=t1: (\n dataset.lifetime >= t0\n ) * (dataset.lifetime < t1)\nbins_percentile = arange(0, 100.0001, 5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Function to build stats\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def stats_compilation(dataset, groups, field, bins, filter=None):\n datas = dict(ref=dataset.bins_stat(field, bins=bins, mask=filter)[1], y=dict())\n for k, index in groups.items():\n i = dataset.merge_filters(filter, index)\n x, datas[\"y\"][k] = dataset.bins_stat(field, bins=bins, mask=i)\n datas[\"x\"], datas[\"bins\"] = x, bins\n return datas\n\n\ndef plot_stats(ax, bins, x, y, ref, box=False, cmap=None, percentiles=None, **kw):\n base, ref = ones(x.shape) * 100.0, ref / 100.0\n x = arange(bins.shape[0]).repeat(2)[1:-1] if box else x\n y0 = base\n if cmap is not None:\n cmap, nb_groups = plt.get_cmap(cmap), len(y)\n keys = tuple(y.keys())\n for i, k in enumerate(keys[::-1]):\n y1 = y0 - y[k] / ref\n args = (y0.repeat(2), y1.repeat(2)) if box else (y0, y1)\n if cmap is not None:\n kw[\"color\"] = cmap(1 - i / (nb_groups - 1))\n ax.fill_between(x, *args, label=k, **kw)\n y0 = y1\n if percentiles:\n for b in bins:\n ax.axvline(b, **percentiles)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Speed radius by track period\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "stats = stats_compilation(\n a, groups, \"radius_s\", percentile(a.radius_s, bins_percentile)\n)\nfig = plt.figure()\nax = fig.add_subplot(111)\nplot_stats(ax, **stats, cmap=\"magma\", percentiles=dict(color=\"gray\", ls=\"-.\", lw=0.4))\nax.set_xlabel(\"Speed radius (m)\"), ax.set_ylabel(\"% of class\"), ax.set_ylim(0, 100)\nax.grid(), ax.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Amplitude by track period\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "stats = stats_compilation(\n a, groups, \"amplitude\", percentile(a.amplitude, bins_percentile)\n)\nfig = plt.figure()\nax = fig.add_subplot(111)\nplot_stats(ax, **stats, cmap=\"magma\")\nax.set_xlabel(\"Amplitude (m)\"), ax.set_ylabel(\"% of class\"), ax.set_ylim(0, 100)\nax.grid(), ax.legend()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git 
a/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb index 77317b16..81809d8b 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nParameter Histogram\n===================\n" + "\n# Parameter Histogram\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nimport py_eddy_tracker_sample\nfrom numpy import arange" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom numpy import arange\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nkwargs_a = dict(label=\"Anticyclonic\", color=\"r\", histtype=\"step\", density=True)\nkwargs_c = dict(label=\"Cyclonic\", color=\"b\", histtype=\"step\", density=True)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nkwargs_a = dict(label=\"Anticyclonic\", color=\"r\", histtype=\"step\", density=True)\nkwargs_c = dict(label=\"Cyclonic\", color=\"b\", histtype=\"step\", density=True)" ] }, { @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 7))\n\nfor x0, name, title, xmax, factor, bins in zip(\n (0.4, 0.72, 0.08),\n (\"radius_s\", \"speed_average\", \"amplitude\"),\n (\"Speed radius (km)\", \"Speed average (cm/s)\", \"Amplitude (cm)\"),\n (100, 50, 20),\n (0.001, 100, 100),\n (arange(0, 2000, 1), arange(0, 1000, 0.5), arange(0.0005, 1000, 0.2)),\n):\n ax_hist = fig.add_axes((x0, 0.24, 0.27, 0.35))\n nb_a, _, _ = ax_hist.hist(a[name] * factor, bins=bins, **kwargs_a)\n nb_c, _, _ = ax_hist.hist(c[name] * factor, bins=bins, **kwargs_c)\n ax_hist.set_xticklabels([])\n ax_hist.set_xlim(0, xmax)\n ax_hist.grid()\n\n ax_cum = fig.add_axes((x0, 0.62, 0.27, 0.35))\n ax_cum.hist(a[name] * factor, bins=bins, cumulative=-1, **kwargs_a)\n ax_cum.hist(c[name] * factor, bins=bins, cumulative=-1, **kwargs_c)\n ax_cum.set_xticklabels([])\n ax_cum.set_title(title)\n ax_cum.set_xlim(0, xmax)\n ax_cum.set_ylim(0, 1)\n ax_cum.grid()\n\n ax_ratio = fig.add_axes((x0, 0.06, 0.27, 0.15))\n ax_ratio.set_xlim(0, xmax)\n ax_ratio.set_ylim(0, 2)\n ax_ratio.plot((bins[1:] + bins[:-1]) / 2, nb_c / nb_a)\n ax_ratio.axhline(1, color=\"k\")\n ax_ratio.grid()\n ax_ratio.set_xlabel(title)\n\nax_cum.set_ylabel(\"Cumulative\\npercent distribution\")\nax_hist.set_ylabel(\"Percent of observations\")\nax_ratio.set_ylabel(\"Ratio percent\\nCyc/Acyc\")\nax_cum.legend()" + "fig = plt.figure(figsize=(12, 7))\n\nfor x0, name, title, xmax, factor, bins in zip(\n (0.4, 0.72, 0.08),\n (\"speed_radius\", \"speed_average\", \"amplitude\"),\n (\"Speed radius (km)\", \"Speed average (cm/s)\", \"Amplitude (cm)\"),\n (100, 50, 20),\n (0.001, 100, 100),\n (arange(0, 2000, 1), arange(0, 1000, 0.5), arange(0.0005, 1000, 0.2)),\n):\n ax_hist = 
fig.add_axes((x0, 0.24, 0.27, 0.35))\n nb_a, _, _ = ax_hist.hist(a[name] * factor, bins=bins, **kwargs_a)\n nb_c, _, _ = ax_hist.hist(c[name] * factor, bins=bins, **kwargs_c)\n ax_hist.set_xticklabels([])\n ax_hist.set_xlim(0, xmax)\n ax_hist.grid()\n\n ax_cum = fig.add_axes((x0, 0.62, 0.27, 0.35))\n ax_cum.hist(a[name] * factor, bins=bins, cumulative=-1, **kwargs_a)\n ax_cum.hist(c[name] * factor, bins=bins, cumulative=-1, **kwargs_c)\n ax_cum.set_xticklabels([])\n ax_cum.set_title(title)\n ax_cum.set_xlim(0, xmax)\n ax_cum.set_ylim(0, 1)\n ax_cum.grid()\n\n ax_ratio = fig.add_axes((x0, 0.06, 0.27, 0.15))\n ax_ratio.set_xlim(0, xmax)\n ax_ratio.set_ylim(0, 2)\n ax_ratio.plot((bins[1:] + bins[:-1]) / 2, nb_c / nb_a)\n ax_ratio.axhline(1, color=\"k\")\n ax_ratio.grid()\n ax_ratio.set_xlabel(title)\n\nax_cum.set_ylabel(\"Cumulative\\npercent distribution\")\nax_hist.set_ylabel(\"Percent of observations\")\nax_ratio.set_ylabel(\"Ratio percent\\nCyc/Acyc\")\nax_cum.legend()" ] } ], @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb index 874b60ce..ed8c0295 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nLifetime Histogram\n===================\n" + "\n# Lifetime Histogram\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nimport py_eddy_tracker_sample\nfrom numpy import arange" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom numpy import arange, ones\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -44,14 +44,14 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nnb_year = (a.period[1] - a.period[0] + 1) / 365.25" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Plot\n\n" + "Setup axes\n\n" ] }, { @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure()\nax_lifetime = fig.add_axes([0.05, 0.55, 0.4, 0.4])\nax_cum_lifetime = fig.add_axes([0.55, 0.55, 0.4, 0.4])\nax_ratio_lifetime = fig.add_axes([0.05, 0.05, 0.4, 0.4])\nax_ratio_cum_lifetime = fig.add_axes([0.55, 0.05, 0.4, 0.4])\n\ncum_a, bins, _ = ax_cum_lifetime.hist(\n a[\"n\"], histtype=\"step\", bins=arange(0, 800, 1), label=\"Anticyclonic\", color=\"r\"\n)\ncum_c, bins, _ = ax_cum_lifetime.hist(\n c[\"n\"], histtype=\"step\", bins=arange(0, 800, 1), label=\"Cyclonic\", color=\"b\"\n)\n\nx = (bins[1:] + bins[:-1]) / 2.0\nax_ratio_cum_lifetime.plot(x, cum_c / cum_a)\n\nnb_a, nb_c = cum_a[:-1] - cum_a[1:], cum_c[:-1] - cum_c[1:]\nax_lifetime.plot(x[1:], nb_a, label=\"Anticyclonic\", 
color=\"r\")\nax_lifetime.plot(x[1:], nb_c, label=\"Cyclonic\", color=\"b\")\n\nax_ratio_lifetime.plot(x[1:], nb_c / nb_a)\n\nfor ax in (ax_lifetime, ax_cum_lifetime, ax_ratio_cum_lifetime, ax_ratio_lifetime):\n ax.set_xlim(0, 365)\n if ax in (ax_lifetime, ax_cum_lifetime):\n ax.set_ylim(1, None)\n ax.set_yscale(\"log\")\n ax.legend()\n else:\n ax.set_ylim(0, 2)\n ax.set_ylabel(\"Ratio Cyclonic/Anticyclonic\")\n ax.set_xlabel(\"Lifetime (days)\")\n ax.grid()" + "figure = plt.figure(figsize=(12, 8))\nax_ratio_cum = figure.add_axes([0.55, 0.06, 0.42, 0.34])\nax_ratio = figure.add_axes([0.07, 0.06, 0.46, 0.34])\nax_cum = figure.add_axes([0.55, 0.43, 0.42, 0.54])\nax = figure.add_axes([0.07, 0.43, 0.46, 0.54])\nax.set_ylabel(\"Eddies by year\")\nax_ratio.set_ylabel(\"Ratio Cyclonic/Anticyclonic\")\nfor ax_ in (ax, ax_cum, ax_ratio_cum, ax_ratio):\n ax_.set_xlim(0, 400)\n if ax_ in (ax, ax_cum):\n ax_.set_ylim(1e-1, 1e4), ax_.set_yscale(\"log\")\n else:\n ax_.set_xlabel(\"Lifetime in days (by week bins)\")\n ax_.set_ylim(0, 2)\n ax_.axhline(1, color=\"g\", lw=2)\n ax_.grid()\nax_cum.xaxis.set_ticklabels([]), ax_cum.yaxis.set_ticklabels([])\nax.xaxis.set_ticklabels([]), ax_ratio_cum.yaxis.set_ticklabels([])\n\n# plot data\nbin_hist = arange(7, 2000, 7)\nx = (bin_hist[1:] + bin_hist[:-1]) / 2.0\na_nb, c_nb = a.nb_obs_by_track, c.nb_obs_by_track\na_nb, c_nb = a_nb[a_nb != 0], c_nb[c_nb != 0]\nw_a, w_c = ones(a_nb.shape) / nb_year, ones(c_nb.shape) / nb_year\nkwargs_a = dict(histtype=\"step\", bins=bin_hist, x=a_nb, color=\"r\", weights=w_a)\nkwargs_c = dict(histtype=\"step\", bins=bin_hist, x=c_nb, color=\"b\", weights=w_c)\ncum_a, _, _ = ax_cum.hist(cumulative=-1, **kwargs_a)\ncum_c, _, _ = ax_cum.hist(cumulative=-1, **kwargs_c)\nnb_a, _, _ = ax.hist(label=\"Anticyclonic\", **kwargs_a)\nnb_c, _, _ = ax.hist(label=\"Cyclonic\", **kwargs_c)\nax_ratio_cum.plot(x, cum_c / cum_a)\nax_ratio.plot(x, nb_c / nb_a)\nax.legend()" ] } ], @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_normalised_lifetime.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_normalised_lifetime.ipynb new file mode 100644 index 00000000..f9fb474f --- /dev/null +++ b/notebooks/python_module/10_tracking_diagnostics/pet_normalised_lifetime.ipynb @@ -0,0 +1,119 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nNormalised Eddy Lifetimes\n=========================\n\nExample from Evan Mason\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numba import njit\nfrom numpy import interp, linspace, zeros\nfrom py_eddy_tracker_sample import get_demo_path\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "@njit(cache=True)\ndef sum_profile(x_new, y, out):\n \"\"\"Will sum all interpolated given array\"\"\"\n out += interp(x_new, linspace(0, 1, y.size), y)\n\n\nclass MyObs(TrackEddiesObservations):\n def eddy_norm_lifetime(self, name, nb, factor=1):\n \"\"\"\n :param 
str,array name: Array or field name\n :param int nb: size of output array\n \"\"\"\n y = self.parse_varname(name)\n x = linspace(0, 1, nb)\n out = zeros(nb, dtype=y.dtype)\n nb_track = 0\n for i, b0, b1 in self.iter_on(\"track\"):\n y_ = y[i]\n size_ = y_.size\n if size_ == 0:\n continue\n sum_profile(x, y_, out)\n nb_track += 1\n return x, out / nb_track * factor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load atlas\n----------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(include_vars=(\"speed_radius\", \"amplitude\", \"track\"))\na = MyObs.load_file(\n get_demo_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"), **kw\n)\nc = MyObs.load_file(get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\"), **kw)\n\nnb_max_a = a.nb_obs_by_track.max()\nnb_max_c = c.nb_obs_by_track.max()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compute normalised lifetime\n---------------------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Radius\nAC_radius = a.eddy_norm_lifetime(\"speed_radius\", nb=nb_max_a, factor=1e-3)\nCC_radius = c.eddy_norm_lifetime(\"speed_radius\", nb=nb_max_c, factor=1e-3)\n# Amplitude\nAC_amplitude = a.eddy_norm_lifetime(\"amplitude\", nb=nb_max_a, factor=1e2)\nCC_amplitude = c.eddy_norm_lifetime(\"amplitude\", nb=nb_max_c, factor=1e2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Figure\n------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, (ax0, ax1) = plt.subplots(nrows=2, figsize=(8, 6))\n\nax0.set_title(\"Normalised Mean Radius\")\nax0.plot(*AC_radius), ax0.plot(*CC_radius)\nax0.set_ylabel(\"Radius (km)\"), ax0.grid()\nax0.set_xlim(0, 1), ax0.set_ylim(0, None)\n\nax1.set_title(\"Normalised Mean Amplitude\")\nax1.plot(*AC_amplitude, label=\"AC\"), ax1.plot(*CC_amplitude, label=\"CC\")\nax1.set_ylabel(\"Amplitude (cm)\"), ax1.grid(), ax1.legend()\n_ = ax1.set_xlim(0, 1), ax1.set_ylim(0, None)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb index bf930464..23f830d6 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nCount pixel used\n======================\n" + "\n# Count pixel used\n\nDo Geo stat with frequency and compare with center count\nmethod: `sphx_glr_python_module_10_tracking_diagnostics_pet_center_count.py`\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom matplotlib.colors import LogNorm\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nimport 
py_eddy_tracker_sample" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom matplotlib.colors import LogNorm\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -44,7 +44,25 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nt0, t1 = a.period\nstep = 0.1\nbins = ((-10, 37, step), (30, 46, step))\nkwargs_pcolormesh = dict(\n cmap=\"terrain_r\", vmin=0, vmax=0.75, factor=1 / (t1 - t0), name=\"count\"\n)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Parameters\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 0.125\nbins = ((-10, 37, step), (30, 46, step))\nkwargs_pcolormesh = dict(\n cmap=\"terrain_r\", vmin=0, vmax=0.75, factor=1 / a.nb_days, name=\"count\"\n)" ] }, { @@ -62,7 +80,25 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 18.5))\nax_a = fig.add_axes([0.03, 0.75, 0.90, 0.25])\nax_a.set_title(\"Anticyclonic frequency\")\nax_c = fig.add_axes([0.03, 0.5, 0.90, 0.25])\nax_c.set_title(\"Cyclonic frequency\")\nax_all = fig.add_axes([0.03, 0.25, 0.90, 0.25])\nax_all.set_title(\"All eddies frequency\")\nax_ratio = fig.add_axes([0.03, 0.0, 0.90, 0.25])\nax_ratio.set_title(\"Ratio cyclonic / Anticyclonic\")\n\n# Count pixel used for each contour\ng_a = a.grid_count(bins, intern=True)\ng_a.display(ax_a, **kwargs_pcolormesh)\ng_c = c.grid_count(bins, intern=True)\ng_c.display(ax_c, **kwargs_pcolormesh)\n# Compute a ratio Cyclonic / Anticyclonic\nratio = g_c.vars[\"count\"] / g_a.vars[\"count\"]\n\n# Mask manipulation to be able to sum the 2 grids\nm_c = g_c.vars[\"count\"].mask\nm = m_c & g_a.vars[\"count\"].mask\ng_c.vars[\"count\"][m_c] = 0\ng_c.vars[\"count\"] += g_a.vars[\"count\"]\ng_c.vars[\"count\"].mask = m\n\nm = g_c.display(ax_all, **kwargs_pcolormesh)\nplt.colorbar(m, cax=fig.add_axes([0.95, 0.27, 0.01, 0.7]))\n\ng_c.vars[\"count\"] = ratio\nm = g_c.display(ax_ratio, name=\"count\", vmin=0.1, vmax=10, norm=LogNorm())\nplt.colorbar(m, cax=fig.add_axes([0.95, 0.02, 0.01, 0.2]))\n\nfor ax in (ax_a, ax_c, ax_all, ax_ratio):\n ax.set_aspect(\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.grid()" + "fig = plt.figure(figsize=(12, 18.5))\nax_a = fig.add_axes([0.03, 0.75, 0.90, 0.25])\nax_a.set_title(\"Anticyclonic frequency\")\nax_c = fig.add_axes([0.03, 0.5, 0.90, 0.25])\nax_c.set_title(\"Cyclonic frequency\")\nax_all = fig.add_axes([0.03, 0.25, 0.90, 0.25])\nax_all.set_title(\"All eddies frequency\")\nax_ratio = fig.add_axes([0.03, 0.0, 0.90, 0.25])\nax_ratio.set_title(\"Ratio cyclonic / Anticyclonic\")\n\n# Count pixel used for each contour\ng_a = a.grid_count(bins, intern=True)\ng_a.display(ax_a, **kwargs_pcolormesh)\ng_c = c.grid_count(bins, intern=True)\ng_c.display(ax_c, **kwargs_pcolormesh)\n# Compute a ratio Cyclonic / Anticyclonic\nratio = g_c.vars[\"count\"] / g_a.vars[\"count\"]\n\n# Mask manipulation to be able to sum the 2 grids\nm_c = 
g_c.vars[\"count\"].mask\nm = m_c & g_a.vars[\"count\"].mask\ng_c.vars[\"count\"][m_c] = 0\ng_c.vars[\"count\"] += g_a.vars[\"count\"]\ng_c.vars[\"count\"].mask = m\n\nm = g_c.display(ax_all, **kwargs_pcolormesh)\nplt.colorbar(m, cax=fig.add_axes([0.95, 0.27, 0.01, 0.7]))\n\ng_c.vars[\"count\"] = ratio\nm = g_c.display(\n ax_ratio, name=\"count\", norm=LogNorm(vmin=0.1, vmax=10), cmap=\"coolwarm_r\"\n)\nplt.colorbar(m, cax=fig.add_axes([0.95, 0.02, 0.01, 0.2]))\n\nfor ax in (ax_a, ax_c, ax_all, ax_ratio):\n ax.set_aspect(\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Count Anticyclones as a function of lifetime\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(12, 10))\nmask = a.lifetime >= 60\nax_long = fig.add_axes([0.03, 0.53, 0.90, 0.45])\ng_a = a.grid_count(bins, intern=True, filter=mask)\ng_a.display(ax_long, **kwargs_pcolormesh)\nax_long.set_title(f\"Anticyclones with lifetime >= 60 days ({mask.sum()} Obs)\")\nax_short = fig.add_axes([0.03, 0.03, 0.90, 0.45])\ng_a = a.grid_count(bins, intern=True, filter=~mask)\nm = g_a.display(ax_short, **kwargs_pcolormesh)\nax_short.set_title(f\"Anticyclones with lifetime < 60 days ({(~mask).sum()} Obs)\")\nfor ax in (ax_short, ax_long):\n ax.set_aspect(\"equal\"), ax.grid()\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\ncb = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.015, 0.9]))" ] } ], @@ -82,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb index 288c7bf4..9792f8f4 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\nPropagation Histogram\n===================\n" + "\n# Propagation Histogram\n" ] }, { @@ -26,25 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\nfrom py_eddy_tracker.generic import distance\nimport py_eddy_tracker_sample\nfrom numpy import arange, empty\nfrom numba import njit" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will create a function compile with numba, to compute a field which contains curvilign distance\n\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "@njit(cache=True)\ndef cum_distance_by_track(distance, track):\n tr_previous = 0\n d_cum = 0\n new_distance = empty(track.shape, dtype=distance.dtype)\n for i in range(distance.shape[0]):\n tr = track[i]\n if i != 0 and tr != tr_previous:\n d_cum = 0\n new_distance[i] = d_cum\n d_cum += distance[i]\n tr_previous = tr\n new_distance[i + 1] = d_cum\n return new_distance" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom numpy import arange, ones\n\nfrom py_eddy_tracker.generic import cumsum_by_track\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -62,7 +44,7 @@ }, "outputs": [], "source": [ - "a = 
TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nnb_year = (a.period[1] - a.period[0] + 1) / 365.25" ] }, { @@ -98,14 +80,14 @@ }, "outputs": [], "source": [ - "d_a = distance(a.longitude[:-1], a.latitude[:-1], a.longitude[1:], a.latitude[1:])\nd_c = distance(c.longitude[:-1], c.latitude[:-1], c.longitude[1:], c.latitude[1:])\nd_a = cum_distance_by_track(d_a, a[\"track\"]) / 1000.0\nd_c = cum_distance_by_track(d_c, c[\"track\"]) / 1000.0" + "i0, nb = a.index_from_track, a.nb_obs_by_track\nd_a = cumsum_by_track(a.distance_to_next(), a.tracks)[(i0 - 1 + nb)[nb != 0]] / 1000.0\ni0, nb = c.index_from_track, c.nb_obs_by_track\nd_c = cumsum_by_track(c.distance_to_next(), c.tracks)[(i0 - 1 + nb)[nb != 0]] / 1000.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Plot\n\n" + "Setup axes\n\n" ] }, { @@ -116,7 +98,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure()\nax_propagation = fig.add_axes([0.05, 0.55, 0.4, 0.4])\nax_cum_propagation = fig.add_axes([0.55, 0.55, 0.4, 0.4])\nax_ratio_propagation = fig.add_axes([0.05, 0.05, 0.4, 0.4])\nax_ratio_cum_propagation = fig.add_axes([0.55, 0.05, 0.4, 0.4])\n\nbins = arange(0, 1500, 10)\ncum_a, bins, _ = ax_cum_propagation.hist(\n d_a, histtype=\"step\", bins=bins, label=\"Anticyclonic\", color=\"r\"\n)\ncum_c, bins, _ = ax_cum_propagation.hist(\n d_c, histtype=\"step\", bins=bins, label=\"Cyclonic\", color=\"b\"\n)\n\nx = (bins[1:] + bins[:-1]) / 2.0\nax_ratio_cum_propagation.plot(x, cum_c / cum_a)\n\nnb_a, nb_c = cum_a[:-1] - cum_a[1:], cum_c[:-1] - cum_c[1:]\nax_propagation.plot(x[1:], nb_a, label=\"Anticyclonic\", color=\"r\")\nax_propagation.plot(x[1:], nb_c, label=\"Cyclonic\", color=\"b\")\n\nax_ratio_propagation.plot(x[1:], nb_c / nb_a)\n\nfor ax in (\n ax_propagation,\n ax_cum_propagation,\n ax_ratio_cum_propagation,\n ax_ratio_propagation,\n):\n ax.set_xlim(0, 1000)\n if ax in (ax_propagation, ax_cum_propagation):\n ax.set_ylim(1, None)\n ax.set_yscale(\"log\")\n ax.legend()\n else:\n ax.set_ylim(0, 2)\n ax.set_ylabel(\"Ratio Cyclonic/Anticyclonic\")\n ax.set_xlabel(\"Propagation (km)\")\n ax.grid()" + "figure = plt.figure(figsize=(12, 8))\nax_ratio_cum = figure.add_axes([0.55, 0.06, 0.42, 0.34])\nax_ratio = figure.add_axes([0.07, 0.06, 0.46, 0.34])\nax_cum = figure.add_axes([0.55, 0.43, 0.42, 0.54])\nax = figure.add_axes([0.07, 0.43, 0.46, 0.54])\nax.set_ylabel(\"Eddies by year\")\nax_ratio.set_ylabel(\"Ratio Cyclonic/Anticyclonic\")\nfor ax_ in (ax, ax_cum, ax_ratio_cum, ax_ratio):\n ax_.set_xlim(0, 1000)\n if ax_ in (ax, ax_cum):\n ax_.set_ylim(1e-1, 1e4), ax_.set_yscale(\"log\")\n else:\n ax_.set_xlabel(\"Propagation in km (with bins of 20 km)\")\n ax_.set_ylim(0, 2)\n ax_.axhline(1, color=\"g\", lw=2)\n ax_.grid()\nax_cum.xaxis.set_ticklabels([]), ax_cum.yaxis.set_ticklabels([])\nax.xaxis.set_ticklabels([]), ax_ratio_cum.yaxis.set_ticklabels([])\n\n# plot data\nbin_hist = arange(0, 2000, 20)\nx = (bin_hist[1:] + bin_hist[:-1]) / 2.0\nw_a, w_c = ones(d_a.shape) / nb_year, ones(d_c.shape) / nb_year\nkwargs_a = dict(histtype=\"step\", bins=bin_hist, x=d_a, 
color=\"r\", weights=w_a)\nkwargs_c = dict(histtype=\"step\", bins=bin_hist, x=d_c, color=\"b\", weights=w_c)\ncum_a, _, _ = ax_cum.hist(cumulative=-1, **kwargs_a)\ncum_c, _, _ = ax_cum.hist(cumulative=-1, **kwargs_c)\nnb_a, _, _ = ax.hist(label=\"Anticyclonic\", **kwargs_a)\nnb_c, _, _ = ax.hist(label=\"Cyclonic\", **kwargs_c)\nax_ratio_cum.plot(x, cum_c / cum_a)\nax_ratio.plot(x, nb_c / nb_a)\nax.legend()" ] } ], @@ -136,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb b/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb new file mode 100644 index 00000000..b30682a1 --- /dev/null +++ b/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb @@ -0,0 +1,234 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Collocating external data\n\nScript will use py-eddy-tracker methods to upload external data (sea surface temperature, SST)\nin a common structure with altimetry.\n\nFigures higlights the different steps.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\n\ndate = datetime(2016, 7, 7)\n\nfilename_alt = data.get_demo_path(\n f\"dt_blacksea_allsat_phy_l4_{date:%Y%m%d}_20200801.nc\"\n)\nfilename_sst = data.get_demo_path(\n f\"{date:%Y%m%d}000000-GOS-L4_GHRSST-SSTfnd-OISST_HR_REP-BLK-v02.0-fv01.0.nc\"\n)\nvar_name_sst = \"analysed_sst\"\n\nextent = [27, 42, 40.5, 47]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading data\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sst = RegularGridDataset(filename=filename_sst, x_name=\"lon\", y_name=\"lat\")\nalti = RegularGridDataset(\n data.get_demo_path(filename_alt), x_name=\"longitude\", y_name=\"latitude\"\n)\n# We can use `Grid` tools to interpolate ADT on the sst grid\nsst.regrid(alti, \"sla\")\nsst.add_uv(\"sla\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Functions to initiate figure axes\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title, extent=extent):\n fig = plt.figure(figsize=(13, 6), dpi=120)\n ax = fig.add_axes([0.03, 0.05, 0.89, 0.91])\n ax.set_xlim(extent[0], extent[1])\n ax.set_ylim(extent[2], extent[3])\n ax.set_title(title)\n ax.set_aspect(\"equal\")\n return ax\n\n\ndef update_axes(ax, mappable=None, unit=\"\"):\n ax.grid()\n if mappable:\n cax = ax.figure.add_axes([0.93, 0.05, 0.01, 0.9], title=unit)\n plt.colorbar(mappable, cax=cax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ADT first display\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"SLA\", extent=extent)\nm = sst.display(ax, \"sla\", vmin=0.05, vmax=0.35)\nupdate_axes(ax, m, unit=\"[m]\")" + ] 
+ }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SST first display\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now plot SST from `sst`\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"SST\")\nm = sst.display(ax, \"analysed_sst\", vmin=295, vmax=300)\nupdate_axes(ax, m, unit=\"[\u00b0K]\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"SST\")\nm = sst.display(ax, \"analysed_sst\", vmin=295, vmax=300)\nu, v = sst.grid(\"u\").T, sst.grid(\"v\").T\nax.quiver(sst.x_c[::3], sst.y_c[::3], u[::3, ::3], v[::3, ::3], scale=10)\nupdate_axes(ax, m, unit=\"[\u00b0K]\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, with eddy contours, and displaying SST anomaly\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sst.bessel_high_filter(\"analysed_sst\", 400)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Eddy detection\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sst.bessel_high_filter(\"sla\", 400)\n# ADT filtered\nax = start_axes(\"SLA\", extent=extent)\nm = sst.display(ax, \"sla\", vmin=-0.1, vmax=0.1)\nupdate_axes(ax, m, unit=\"[m]\")\na, c = sst.eddy_identification(\"sla\", \"u\", \"v\", date, 0.002)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kwargs_a = dict(lw=2, label=\"Anticyclonic\", ref=-10, color=\"b\")\nkwargs_c = dict(lw=2, label=\"Cyclonic\", ref=-10, color=\"r\")\nax = start_axes(\"SST anomaly\")\nm = sst.display(ax, \"analysed_sst\", vmin=-1, vmax=1)\na.display(ax, **kwargs_a), c.display(ax, **kwargs_c)\nax.legend()\nupdate_axes(ax, m, unit=\"[\u00b0K]\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example of post-processing\nGet mean of sst anomaly_high in each internal contour\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "anom_a = a.interp_grid(sst, \"analysed_sst\", method=\"mean\", intern=True)\nanom_c = c.interp_grid(sst, \"analysed_sst\", method=\"mean\", intern=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Are cyclonic (resp. anticyclonic) eddies generally associated with positive (resp. 
negative) SST anomaly ?\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(7, 5))\nax = fig.add_axes([0.05, 0.05, 0.90, 0.90])\nax.set_xlabel(\"SST anomaly\")\nax.set_xlim([-1, 1])\nax.set_title(\"Histograms of SST anomalies\")\nax.hist(\n anom_a, 5, alpha=0.5, color=\"b\", label=\"Anticyclonic (mean:%s)\" % (anom_a.mean())\n)\nax.hist(anom_c, 5, alpha=0.5, color=\"r\", label=\"Cyclonic (mean:%s)\" % (anom_c.mean()))\nax.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Not clearly so in that case ..\n\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/12_external_data/pet_drifter_loopers.ipynb b/notebooks/python_module/12_external_data/pet_drifter_loopers.ipynb new file mode 100644 index 00000000..7ba30914 --- /dev/null +++ b/notebooks/python_module/12_external_data/pet_drifter_loopers.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nColocate looper with eddy from altimetry\n========================================\n\nAll loopers data used in this example are a subset from the dataset described in this article\n[Lumpkin, R. 
: Global characteristics of coherent vortices from surface drifter trajectories](https://doi.org/10.1002/2015JC011435)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nimport numpy as np\nimport py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\ndef start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], aspect=\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load eddies dataset\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "cyclonic_eddies = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nanticyclonic_eddies = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load loopers dataset\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "loopers_med = TrackEddiesObservations.load_file(\n data.get_demo_path(\"loopers_lumpkin_med.nc\")\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Global view\n===========\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"All drifters available in Med from Lumpkin dataset\")\nloopers_med.plot(ax, lw=0.5, color=\"r\", ref=-10)\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One segment of drifter\n======================\n\nGet a drifter segment (the indexes used have no correspondance with the original dataset).\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "looper = loopers_med.extract_ids((3588,))\nfig = plt.figure(figsize=(16, 6))\nax = fig.add_subplot(111, aspect=\"equal\")\nlooper.plot(ax, lw=0.5, label=\"Original position of drifter\")\nlooper_filtered = looper.copy()\nlooper_filtered.position_filter(1, 13)\ns = 
looper_filtered.scatter(\n ax,\n \"time\",\n cmap=plt.get_cmap(\"Spectral_r\", 20),\n label=\"Filtered position of drifter\",\n)\nplt.colorbar(s).set_label(\"time (days from 1/1/1950)\")\nax.legend()\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Try to find a detected eddies with adt at same place. We used filtered track to simulate an eddy center\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "match = looper_filtered.close_tracks(\n anticyclonic_eddies, method=\"close_center\", delta=0.1, nb_obs_min=50\n)\nfig = plt.figure(figsize=(16, 6))\nax = fig.add_subplot(111, aspect=\"equal\")\nlooper.plot(ax, lw=0.5, label=\"Original position of drifter\")\nlooper_filtered.plot(ax, lw=1.5, label=\"Filtered position of drifter\")\nmatch.plot(ax, lw=1.5, label=\"Matched eddy\")\nax.legend()\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display radius of this 2 datasets.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(20, 8))\nax = fig.add_subplot(111)\nax.plot(looper.time, looper.radius_s / 1e3, label=\"loopers\")\nlooper_radius = looper.copy()\nlooper_radius.median_filter(1, \"time\", \"radius_s\", inplace=True)\nlooper_radius.loess_filter(13, \"time\", \"radius_s\", inplace=True)\nax.plot(\n looper_radius.time,\n looper_radius.radius_s / 1e3,\n label=\"loopers (filtered half window 13 days)\",\n)\nax.plot(match.time, match.radius_s / 1e3, label=\"altimetry\")\nmatch_radius = match.copy()\nmatch_radius.median_filter(1, \"time\", \"radius_s\", inplace=True)\nmatch_radius.loess_filter(13, \"time\", \"radius_s\", inplace=True)\nax.plot(\n match_radius.time,\n match_radius.radius_s / 1e3,\n label=\"altimetry (filtered half window 13 days)\",\n)\nax.set_ylabel(\"radius(km)\"), ax.set_ylim(0, 100)\nax.legend()\nax.set_title(\"Radius from loopers and altimeter\")\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Animation of a drifter and its colocated eddy\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def update(frame):\n # We display last 5 days of loopers trajectory\n m = (looper.time < frame) * (looper.time > (frame - 5))\n anim.func_animation(frame)\n line.set_data(looper.lon[m], looper.lat[m])\n\n\nanim = Anim(match, intern=True, figsize=(8, 8), cmap=\"magma_r\", nb_step=10, dpi=75)\n# mappable to show drifter in red\nline = anim.ax.plot([], [], \"r\", lw=4, zorder=100)[0]\nanim.fig.suptitle(\"\")\n_ = VideoAnimation(anim.fig, update, frames=np.arange(*anim.period, 1), interval=125)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb b/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb new file mode 100644 index 00000000..a46a7e22 --- /dev/null +++ 
b/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb @@ -0,0 +1,108 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Contour fit\n\nTwo type of fit :\n - Ellipse\n - Circle\n\nIn the two case we use a least square algorithm\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import cos, linspace, radians, sin\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.generic import coordinates_to_local, local_to_coordinates\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker.poly import fit_circle_, fit_ellipse" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load example identification file\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Function to draw circle or ellipse from parameter\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def build_circle(x0, y0, r):\n angle = radians(linspace(0, 360, 50))\n x_norm, y_norm = cos(angle), sin(angle)\n return local_to_coordinates(x_norm * r, y_norm * r, x0, y0)\n\n\ndef build_ellipse(x0, y0, a, b, theta):\n angle = radians(linspace(0, 360, 50))\n x = a * cos(theta) * cos(angle) - b * sin(theta) * sin(angle)\n y = a * sin(theta) * cos(angle) + b * cos(theta) * sin(angle)\n return local_to_coordinates(x, y, x0, y0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot fitted circle or ellipse on stored contour\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "xs, ys = a.contour_lon_s, a.contour_lat_s\n\nfig = plt.figure(figsize=(15, 15))\n\nj = 1\nfor i in range(0, 800, 30):\n x, y = xs[i], ys[i]\n x0_, y0_ = x.mean(), y.mean()\n x_, y_ = coordinates_to_local(x, y, x0_, y0_)\n ax = fig.add_subplot(4, 4, j)\n ax.grid(), ax.set_aspect(\"equal\")\n ax.plot(x, y, label=\"store\", color=\"black\")\n x0, y0, a, b, theta = fit_ellipse(x_, y_)\n x0, y0 = local_to_coordinates(x0, y0, x0_, y0_)\n ax.plot(*build_ellipse(x0, y0, a, b, theta), label=\"ellipse\", color=\"green\")\n x0, y0, radius, shape_error = fit_circle_(x_, y_)\n x0, y0 = local_to_coordinates(x0, y0, x0_, y0_)\n ax.plot(*build_circle(x0, y0, radius), label=\"circle\", color=\"red\", lw=0.5)\n if j == 16:\n break\n j += 1" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb b/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb new file 
mode 100644 index 00000000..69e49b57 --- /dev/null +++ b/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Visvalingam algorithm\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import matplotlib.animation as animation\nfrom matplotlib import pyplot as plt\nfrom numba import njit\nfrom numpy import array, empty\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.generic import uniform_resample\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker.poly import vertice_overlap, visvalingam\n\n\n@njit(cache=True)\ndef visvalingam_polys(x, y, nb_pt):\n nb = x.shape[0]\n x_new = empty((nb, nb_pt), dtype=x.dtype)\n y_new = empty((nb, nb_pt), dtype=y.dtype)\n for i in range(nb):\n x_new[i], y_new[i] = visvalingam(x[i], y[i], nb_pt)\n return x_new, y_new\n\n\n@njit(cache=True)\ndef uniform_resample_polys(x, y, nb_pt):\n nb = x.shape[0]\n x_new = empty((nb, nb_pt), dtype=x.dtype)\n y_new = empty((nb, nb_pt), dtype=y.dtype)\n for i in range(nb):\n x_new[i], y_new[i] = uniform_resample(x[i], y[i], fixed_size=nb_pt)\n return x_new, y_new\n\n\ndef update_line(num):\n nb = 50 - num - 20\n x_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb)\n for i, (x_, y_) in enumerate(zip(x_v, y_v)):\n lines_v[i].set_data(x_, y_)\n x_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb)\n for i, (x_, y_) in enumerate(zip(x_u, y_u)):\n lines_u[i].set_data(x_, y_)\n scores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0\n scores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0\n for i, (s_v, s_u) in enumerate(zip(scores_v, scores_u)):\n texts[i].set_text(f\"Score uniform {s_u:.1f} %\\nScore visvalingam {s_v:.1f} %\")\n title.set_text(f\"{nb} points by contour in place of 50\")\n return (title, *lines_u, *lines_v, *texts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load detection files\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))\na = a.extract_with_mask((abs(a.lat) < 66) * (abs(a.radius_e) > 80e3))\n\nnb_pt = 10\nx_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb_pt)\nx_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb_pt)\nscores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0\nscores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0\nd_6 = scores_v - scores_u\nnb_pt = 18\nx_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb_pt)\nx_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb_pt)\nscores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0\nscores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0\nd_12 = scores_v - scores_u\na = a.index(array((d_6.argmin(), d_6.argmax(), d_12.argmin(), d_12.argmax())))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\naxs = [\n 
fig.add_subplot(221),\n fig.add_subplot(222),\n fig.add_subplot(223),\n fig.add_subplot(224),\n]\nlines_u, lines_v, texts, score_text = list(), list(), list(), list()\nfor i, obs in enumerate(a):\n axs[i].set_aspect(\"equal\")\n axs[i].grid()\n axs[i].set_xticklabels([]), axs[i].set_yticklabels([])\n axs[i].plot(\n obs[\"contour_lon_e\"], obs[\"contour_lat_e\"], \"r\", lw=6, label=\"Original contour\"\n )\n lines_v.append(axs[i].plot([], [], color=\"limegreen\", lw=4, label=\"visvalingam\")[0])\n lines_u.append(\n axs[i].plot([], [], color=\"black\", lw=2, label=\"uniform resampling\")[0]\n )\n texts.append(axs[i].set_title(\"\", fontsize=8))\naxs[0].legend(fontsize=8)\ntitle = fig.suptitle(\"\")\nanim = animation.FuncAnimation(fig, update_line, 27)\nanim" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_atlas.ipynb b/notebooks/python_module/16_network/pet_atlas.ipynb new file mode 100644 index 00000000..31e3580f --- /dev/null +++ b/notebooks/python_module/16_network/pet_atlas.ipynb @@ -0,0 +1,371 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Network Analysis\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import ma\n\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\nn = NetworkObservations.load_file(\n get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018_err70_filt500_order1/Anticyclonic_network.nc\"\n )\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Parameters\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 1 / 10.0\nbins = ((-10, 37, step), (30, 46, step))\nkw_time = dict(cmap=\"terrain_r\", factor=100.0 / n.nb_days, name=\"count\")\nkw_ratio = dict(cmap=plt.get_cmap(\"YlGnBu_r\", 10))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Functions\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## All\nDisplay the % of time each pixel (1/10\u00b0) is within an anticyclonic network\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": 
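The kw_time factor defined above (factor=100.0 / n.nb_days) simply rescales the per-pixel observation count into a percentage of time; a short sketch with made-up numbers (the real nb_days comes from the atlas period):

import numpy.ma as ma

nb_days = 9000                       # hypothetical atlas length in days
count = ma.array([0, 900, 4500, 9000], mask=[True, False, False, False])
pct_time = count * 100.0 / nb_days   # % of days the pixel lies inside a network
print(pct_time)                      # [-- 10.0 50.0 100.0]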
{ + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_all = n.grid_count(bins)\nm = g_all.display(ax, **kw_time, vmin=0, vmax=75)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Network longer than 10 days\nDisplay the % of time each pixel (1/10\u00b0) is within an anticyclonic network\nwhich total lifetime in longer than 10 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nn10 = n.longer_than(10)\ng_10 = n10.grid_count(bins)\nm = g_10.display(ax, **kw_time, vmin=0, vmax=75)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ratio\nRatio between the longer and total presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = g_10.vars[\"count\"] * 100.0 / g_all.vars[\"count\"]\nm = g_10.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Blue = mostly short networks\n\n## Network longer than 20 days\nDisplay the % of time each pixel (1/10\u00b0) is within an anticyclonic network\nwhich total lifetime is longer than 20 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nn20 = n.longer_than(20)\ng_20 = n20.grid_count(bins)\nm = g_20.display(ax, **kw_time, vmin=0, vmax=75)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ratio\nRatio between the longer and total presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = g_20.vars[\"count\"] * 100.0 / g_all.vars[\"count\"]\nm = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we will hide pixel which are used less than 365 times\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "g_ = ma.array(\n g_20.vars[\"count\"] * 100.0 / g_all.vars[\"count\"], mask=g_all.vars[\"count\"] < 365\n)\nax = start_axes(\"\")\nm = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we will hide pixel which are used more than 365 times\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = ma.array(\n g_20.vars[\"count\"] * 100.0 / g_all.vars[\"count\"], mask=g_all.vars[\"count\"] >= 365\n)\nm = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Coastal areas are mostly populated by short networks\n\n## All merging\nDisplay 
the occurence of merging events\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_all_merging = n.merging_event().grid_count(bins)\nm = g_all_merging.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ratio merging events / eddy presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = g_all_merging.vars[\"count\"] * 100.0 / g_all.vars[\"count\"]\nm = g_all_merging.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Merging in networks longer than 10 days, with dead end remove (shorter than 10 observations)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nmerger = n10.remove_dead_end(nobs=10).merging_event()\ng_10_merging = merger.grid_count(bins)\nm = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Merging in networks longer than 10 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nmerger = n10.merging_event()\ng_10_merging = merger.grid_count(bins)\nm = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ratio merging events / eddy presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = ma.array(\n g_10_merging.vars[\"count\"] * 100.0 / g_10.vars[\"count\"],\n mask=g_10.vars[\"count\"] < 365,\n)\nm = g_10_merging.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## All Spliting\nDisplay the occurence of spliting events\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_all_spliting = n.spliting_event().grid_count(bins)\nm = g_all_spliting.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ratio spliting events / eddy presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = g_all_spliting.vars[\"count\"] * 100.0 / g_all.vars[\"count\"]\nm = g_all_spliting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Spliting in networks longer than 10 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + 
"collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_10_spliting = n10.spliting_event().grid_count(bins)\nm = g_10_spliting.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = ma.array(\n g_10_spliting.vars[\"count\"] * 100.0 / g_10.vars[\"count\"],\n mask=g_10.vars[\"count\"] < 365,\n)\nm = g_10_spliting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_follow_particle.ipynb b/notebooks/python_module/16_network/pet_follow_particle.ipynb new file mode 100644 index 00000000..a2a97944 --- /dev/null +++ b/notebooks/python_module/16_network/pet_follow_particle.ipynb @@ -0,0 +1,159 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nFollow particle\n===============\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nfrom matplotlib import colors\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, meshgrid, ones, unique, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\nfrom py_eddy_tracker.observations.groups import particle_candidate\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = NetworkObservations.load_file(get_demo_path(\"network_med.nc\")).network(651)\nn = n.extract_with_mask((n.time >= 20180) * (n.time <= 20269))\nn = n.remove_dead_end(nobs=0, ndays=10)\nn.numbering_segment()\nc = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n 
\"latitude\",\n \"time\",\n heigth=\"adt\",\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Schema\n------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(12, 6))\nax = fig.add_axes([0.05, 0.05, 0.9, 0.9])\n_ = n.display_timeline(ax, field=\"longitude\", marker=\"+\", lw=2, markersize=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Animation\n---------\nParticle settings\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "t_snapshot = 20200\nstep = 1 / 50.0\nx, y = meshgrid(arange(20, 36, step), arange(30, 46, step))\nN = 6\nx_f, y_f = x[::N, ::N].copy(), y[::N, ::N].copy()\nx, y = x.reshape(-1), y.reshape(-1)\nx_f, y_f = x_f.reshape(-1), y_f.reshape(-1)\nn_ = n.extract_with_mask(n.time == t_snapshot)\nindex = n_.contains(x, y, intern=True)\nm = index != -1\nindex = n_.segment[index[m]]\nindex_ = unique(index)\nx, y = x[m], y[m]\nm = ~n_.inside(x_f, y_f, intern=True)\nx_f, y_f = x_f[m], y_f[m]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Animation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "cmap = colors.ListedColormap(list(n.COLORS), name=\"from_list\", N=n.segment.max() + 1)\na = Anim(\n n,\n intern=False,\n figsize=(12, 6),\n nb_step=1,\n dpi=60,\n field_color=\"segment\",\n field_txt=\"segment\",\n cmap=cmap,\n)\na.fig.suptitle(\"\"), a.ax.set_xlim(24, 36), a.ax.set_ylim(30, 36)\na.txt.set_position((25, 31))\n\nstep = 0.25\nkw_p = dict(nb_step=2, time_step=86400 * step * 0.5, t_init=t_snapshot - 2 * step)\n\nmappables = dict()\nparticules = c.advect(x, y, \"u\", \"v\", **kw_p)\nfilament = c.filament(x_f, y_f, \"u\", \"v\", **kw_p, filament_size=3)\nkw = dict(ls=\"\", marker=\".\", markersize=0.25)\nfor k in index_:\n m = k == index\n mappables[k] = a.ax.plot([], [], color=cmap(k), **kw)[0]\nm_filament = a.ax.plot([], [], lw=0.25, color=\"gray\")[0]\n\n\ndef update(frame):\n tt, xt, yt = particules.__next__()\n for k, mappable in mappables.items():\n m = index == k\n mappable.set_data(xt[m], yt[m])\n tt, xt, yt = filament.__next__()\n m_filament.set_data(xt, yt)\n if frame % 1 == 0:\n a.func_animation(frame)\n\n\nani = VideoAnimation(a.fig, update, frames=arange(20200, 20269, step), interval=200)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Particle advection\n^^^^^^^^^^^^^^^^^^\nAdvection from speed contour to speed contour (default)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 1 / 60.0\n\nt_start, t_end = int(n.period[0]), int(n.period[1])\ndt = 14\n\nshape = (n.obs.size, 2)\n# Forward run\ni_target_f, pct_target_f = -ones(shape, dtype=\"i4\"), zeros(shape, dtype=\"i1\")\nfor t in arange(t_start, t_end - dt):\n particle_candidate(c, n, step, t, i_target_f, pct_target_f, n_days=dt)\n\n# Backward run\ni_target_b, pct_target_b = -ones(shape, dtype=\"i4\"), zeros(shape, dtype=\"i1\")\nfor t in arange(t_start + dt, t_end):\n particle_candidate(c, n, step, t, i_target_b, pct_target_b, n_days=-dt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = 
plt.figure(figsize=(10, 10))\nax_1st_b = fig.add_axes([0.05, 0.52, 0.45, 0.45])\nax_2nd_b = fig.add_axes([0.05, 0.05, 0.45, 0.45])\nax_1st_f = fig.add_axes([0.52, 0.52, 0.45, 0.45])\nax_2nd_f = fig.add_axes([0.52, 0.05, 0.45, 0.45])\nax_1st_b.set_title(\"Backward advection for each time step\")\nax_1st_f.set_title(\"Forward advection for each time step\")\nax_1st_b.set_ylabel(\"Color -> First target\\nLatitude\")\nax_2nd_b.set_ylabel(\"Color -> Secondary target\\nLatitude\")\nax_2nd_b.set_xlabel(\"Julian days\"), ax_2nd_f.set_xlabel(\"Julian days\")\nax_1st_f.set_yticks([]), ax_2nd_f.set_yticks([])\nax_1st_f.set_xticks([]), ax_1st_b.set_xticks([])\n\n\ndef color_alpha(target, pct, vmin=5, vmax=80):\n color = cmap(n.segment[target])\n # We will hide under 5 % and from 80% to 100 % it will be 1\n alpha = (pct - vmin) / (vmax - vmin)\n alpha[alpha < 0] = 0\n alpha[alpha > 1] = 1\n color[:, 3] = alpha\n return color\n\n\nkw = dict(\n name=None, yfield=\"longitude\", event=False, zorder=-100, s=(n.speed_area / 20e6)\n)\nn.scatter_timeline(ax_1st_b, c=color_alpha(i_target_b.T[0], pct_target_b.T[0]), **kw)\nn.scatter_timeline(ax_2nd_b, c=color_alpha(i_target_b.T[1], pct_target_b.T[1]), **kw)\nn.scatter_timeline(ax_1st_f, c=color_alpha(i_target_f.T[0], pct_target_f.T[0]), **kw)\nn.scatter_timeline(ax_2nd_f, c=color_alpha(i_target_f.T[1], pct_target_f.T[1]), **kw)\nfor ax in (ax_1st_b, ax_2nd_b, ax_1st_f, ax_2nd_f):\n n.display_timeline(ax, field=\"longitude\", marker=\"+\", lw=2, markersize=5)\n ax.grid()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_group_anim.ipynb b/notebooks/python_module/16_network/pet_group_anim.ipynb new file mode 100644 index 00000000..090170ff --- /dev/null +++ b/notebooks/python_module/16_network/pet_group_anim.ipynb @@ -0,0 +1,213 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nNetwork group process\n=====================\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# sphinx_gallery_thumbnail_number = 2\nimport re\nfrom datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.colors import ListedColormap\nfrom numba import njit\nfrom numpy import arange, array, empty, ones\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.generic import flatten_line_matrix\nfrom py_eddy_tracker.observations.network import Network\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 
'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "NETWORK_GROUPS = list()\n\n\n@njit(cache=True)\ndef apply_replace(x, x0, x1):\n nb = x.shape[0]\n for i in range(nb):\n if x[i] == x0:\n x[i] = x1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Modified class to catch group process at each step in order to illustrate processing\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class MyNetwork(Network):\n def get_group_array(self, results, nb_obs):\n \"\"\"With a loop on all pair of index, we will label each obs with a group\n number\n \"\"\"\n nb_obs = array(nb_obs, dtype=\"u4\")\n day_start = nb_obs.cumsum() - nb_obs\n gr = empty(nb_obs.sum(), dtype=\"u4\")\n gr[:] = self.NOGROUP\n\n id_free = 1\n for i, j, ii, ij in results:\n gr_i = gr[slice(day_start[i], day_start[i] + nb_obs[i])]\n gr_j = gr[slice(day_start[j], day_start[j] + nb_obs[j])]\n # obs with no groups\n m = (gr_i[ii] == self.NOGROUP) * (gr_j[ij] == self.NOGROUP)\n nb_new = m.sum()\n gr_i[ii[m]] = gr_j[ij[m]] = arange(id_free, id_free + nb_new)\n id_free += nb_new\n # associate obs with no group with obs with group\n m = (gr_i[ii] != self.NOGROUP) * (gr_j[ij] == self.NOGROUP)\n gr_j[ij[m]] = gr_i[ii[m]]\n m = (gr_i[ii] == self.NOGROUP) * (gr_j[ij] != self.NOGROUP)\n gr_i[ii[m]] = gr_j[ij[m]]\n # case where 2 obs have a different group\n m = gr_i[ii] != gr_j[ij]\n if m.any():\n # Merge of group, ref over etu\n for i_, j_ in zip(ii[m], ij[m]):\n g0, g1 = gr_i[i_], gr_j[j_]\n apply_replace(gr, g0, g1)\n NETWORK_GROUPS.append((i, j, gr.copy()))\n return gr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Movie period\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "t0 = (datetime(2005, 5, 1) - datetime(1950, 1, 1)).days\nt1 = (datetime(2005, 6, 1) - datetime(1950, 1, 1)).days" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get data from period and area\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "e = EddiesObservations.load_file(data.get_demo_path(\"network_med.nc\"))\ne = e.extract_with_mask((e.time >= t0) * (e.time < t1)).extract_with_area(\n dict(llcrnrlon=25, urcrnrlon=35, llcrnrlat=31, urcrnrlat=37.5)\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Reproduce individual daily identification(for demonstration)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "EDDIES_BY_DAYS = list()\nfor i, b0, b1 in e.iter_on(\"time\"):\n EDDIES_BY_DAYS.append(e.index(i))\n# need for display\ne = EddiesObservations.concatenate(EDDIES_BY_DAYS)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run network building group to intercept every step\n\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = MyNetwork.from_eddiesobservations(EDDIES_BY_DAYS, window=7)\n_ = n.group_observations(minimal_area=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def update(frame):\n i_current, i_match, gr = NETWORK_GROUPS[frame]\n current = EDDIES_BY_DAYS[i_current]\n x = flatten_line_matrix(current.contour_lon_e)\n y = flatten_line_matrix(current.contour_lat_e)\n current_contour.set_data(x, y)\n match = EDDIES_BY_DAYS[i_match]\n x = flatten_line_matrix(match.contour_lon_e)\n y = flatten_line_matrix(match.contour_lat_e)\n matched_contour.set_data(x, y)\n groups.set_array(gr)\n txt.set_text(f\"Day {i_current} match with day {i_match}\")\n s = 80 * ones(gr.shape)\n s[gr == 0] = 4\n groups.set_sizes(s)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Anim\n----\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(16, 9), dpi=50)\nax = fig.add_axes([0, 0, 1, 1])\nax.set_aspect(\"equal\"), ax.grid(), ax.set_xlim(26, 34), ax.set_ylim(31, 35.5)\ncmap = ListedColormap([\"gray\", *e.COLORS[:-1]], name=\"from_list\", N=30)\nkw_s = dict(cmap=cmap, vmin=0, vmax=30)\ngroups = ax.scatter(e.lon, e.lat, c=NETWORK_GROUPS[0][2], **kw_s)\ncurrent_contour = ax.plot([], [], \"k\", lw=2, label=\"Current contour\")[0]\nmatched_contour = ax.plot([], [], \"r\", lw=1, ls=\"--\", label=\"Candidate contour\")[0]\ntxt = ax.text(29, 35, \"\", fontsize=25)\nax.legend(fontsize=25)\nani = VideoAnimation(fig, update, frames=len(NETWORK_GROUPS), interval=220)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Final Result\n------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(16, 9))\nax = fig.add_axes([0, 0, 1, 1])\nax.set_aspect(\"equal\"), ax.grid(), ax.set_xlim(26, 34), ax.set_ylim(31, 35.5)\n_ = ax.scatter(e.lon, e.lat, c=NETWORK_GROUPS[-1][2], **kw_s)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb b/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb new file mode 100644 index 00000000..9d659597 --- /dev/null +++ b/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb @@ -0,0 +1,346 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nIoannou case\n============\nFigure 10 from https://doi.org/10.1002/2017JC013158\n\nWe want to find the Ierapetra Eddy described above in a network demonstration run.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\nfrom 
datetime import datetime, timedelta\n\nfrom matplotlib import colors\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.ticker import FuncFormatter\nfrom numpy import arange, array, pi, where\n\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.generic import coordinates_to_local\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.poly import fit_ellipse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\n@FuncFormatter\ndef formatter(x, pos):\n return (timedelta(x) + datetime(1950, 1, 1)).strftime(\"%d/%m/%Y\")\n\n\ndef start_axes(title=\"\"):\n fig = plt.figure(figsize=(13, 6))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)\n ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef timeline_axes(title=\"\"):\n fig = plt.figure(figsize=(15, 5))\n ax = fig.add_axes([0.03, 0.06, 0.90, 0.88])\n ax.set_title(title, weight=\"bold\")\n ax.xaxis.set_major_formatter(formatter), ax.grid()\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid(True)\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We know the network ID, we will get directly\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ioannou_case = NetworkObservations.load_file(get_demo_path(\"network_med.nc\")).network(\n 651\n)\nprint(ioannou_case.infos())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It seems that this network is huge! Our case is visible at 22E 33.5N\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes()\nioannou_case.plot(ax, color_cycle=ioannou_case.COLORS)\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Full Timeline\n-------------\nThe network span for many years... 
How to cut the interesting part?\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nax.xaxis.set_major_formatter(formatter), ax.grid()\n_ = ioannou_case.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sub network and new numbering\n-----------------------------\nHere we chose to keep only the order 3 segments relatives to our chosen eddy\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i = where(\n (ioannou_case.lat > 33)\n * (ioannou_case.lat < 34)\n * (ioannou_case.lon > 22)\n * (ioannou_case.lon < 23)\n * (ioannou_case.time > 20630)\n * (ioannou_case.time < 20650)\n)[0][0]\nclose_to_i3 = ioannou_case.relative(i, order=3)\nclose_to_i3.numbering_segment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Anim\n----\nQuick movie to see better!\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = Anim(\n close_to_i3,\n figsize=(12, 4),\n cmap=colors.ListedColormap(\n list(close_to_i3.COLORS), name=\"from_list\", N=close_to_i3.segment.max() + 1\n ),\n nb_step=7,\n dpi=70,\n field_color=\"segment\",\n field_txt=\"segment\",\n)\na.ax.set_xlim(19, 30), a.ax.set_ylim(32, 35.25)\na.txt.set_position((21.5, 32.7))\n# We display in video only from the 100th day to the 500th\nkwargs = dict(frames=arange(*a.period)[100:501], interval=100)\nani = VideoAnimation(a.fig, a.func_animation, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Classic display\n---------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\n_ = close_to_i3.display_timeline(ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nn_copy = close_to_i3.copy()\nn_copy.position_filter(2, 4)\nn_copy.plot(ax, color_cycle=n_copy.COLORS)\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Latitude Timeline\n-----------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes(f\"Close segments ({close_to_i3.infos()})\")\nn_copy = close_to_i3.copy()\nn_copy.median_filter(15, \"time\", \"latitude\")\n_ = n_copy.display_timeline(ax, field=\"lat\", method=\"all\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Local radius timeline\n---------------------\nEffective (bold) and Speed (thin) Radius together\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n_copy.median_filter(2, \"time\", \"radius_e\")\nn_copy.median_filter(2, \"time\", \"radius_s\")\nfor b0, b1 in [\n (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007)\n]:\n ref, delta = datetime(1950, 1, 1), 20\n b0_, b1_ = (b0 - ref).days, (b1 - ref).days\n ax = timeline_axes()\n ax.set_xlim(b0_ - delta, b1_ + delta)\n ax.set_ylim(10, 115)\n ax.axvline(b0_, color=\"k\", lw=1.5, ls=\"--\"), ax.axvline(\n b1_, 
color=\"k\", lw=1.5, ls=\"--\"\n )\n n_copy.display_timeline(\n ax, field=\"radius_e\", method=\"all\", lw=4, markersize=8, factor=1e-3\n )\n n_copy.display_timeline(\n ax, field=\"radius_s\", method=\"all\", lw=1, markersize=3, factor=1e-3\n )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Parameters timeline\n-------------------\nEffective Radius\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(s=35, cmap=plt.get_cmap(\"Spectral_r\", 8), zorder=10)\nax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, \"radius_e\", factor=1e-3, vmin=20, vmax=100, **kw)\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective radius (km)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Shape error\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, \"shape_error_e\", vmin=14, vmax=70, **kw)\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective shape error\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Rotation angle\n--------------\nFor each obs, fit an ellipse to the contour, with theta the angle from the x-axis,\na the semi ax in x direction and b the semi ax in y dimension\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "theta_ = list()\na_ = list()\nb_ = list()\nfor obs in close_to_i3:\n x, y = obs[\"contour_lon_s\"], obs[\"contour_lat_s\"]\n x0_, y0_ = x.mean(), y.mean()\n x_, y_ = coordinates_to_local(x, y, x0_, y0_)\n x0, y0, a, b, theta = fit_ellipse(x_, y_)\n theta_.append(theta)\n a_.append(a)\n b_.append(b)\na_ = array(a_)\nb_ = array(b_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Theta\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, theta_, vmin=-pi / 2, vmax=pi / 2, cmap=\"hsv\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "a\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, a_ * 1e-3, vmin=0, vmax=80, cmap=\"Spectral_r\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "b\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, b_ * 1e-3, vmin=0, vmax=80, cmap=\"Spectral_r\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "a/b\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, a_ / b_, vmin=1, vmax=2, cmap=\"Spectral_r\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + 
"codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_relative.ipynb b/notebooks/python_module/16_network/pet_relative.ipynb new file mode 100644 index 00000000..9f3fd3d9 --- /dev/null +++ b/notebooks/python_module/16_network/pet_relative.ipynb @@ -0,0 +1,547 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Network basic manipulation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import where\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data\nLoad data where observations are put in same network but no segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = NetworkObservations.load_file(data.get_demo_path(\"network_med.nc\")).network(651)\ni = where(\n (n.lat > 33)\n * (n.lat < 34)\n * (n.lon > 22)\n * (n.lon < 23)\n * (n.time > 20630)\n * (n.time < 20650)\n)[0][0]\n# For event use\nn2 = n.relative(i, order=2)\nn = n.relative(i, order=4)\nn.numbering_segment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timeline\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display timeline with events\nA segment generated by a splitting is marked with a star\n\nA segment merging in another is marked with an exagon\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.04, 0.92, 0.92])\n_ = n.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display timeline without event\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.04, 0.92, 0.92])\n_ = n.display_timeline(ax, event=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timeline by mean latitude\nDisplay timeline with the mean latitude of the segments in yaxis\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.04, 0.92, 0.92])\nax.set_ylabel(\"Latitude\")\n_ = n.display_timeline(ax, field=\"latitude\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timeline by mean Effective Radius\nThe factor argument is applied on the chosen field\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 
5))\nax = fig.add_axes([0.04, 0.04, 0.92, 0.92])\nax.set_ylabel(\"Effective Radius (km)\")\n_ = n.display_timeline(ax, field=\"radius_e\", factor=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timeline by latitude\nUse `method=\"all\"` to display the consecutive values of the field\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nax.set_ylabel(\"Latitude\")\n_ = n.display_timeline(ax, field=\"lat\", method=\"all\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can filter the data, here with a time window of 15 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nn_copy = n.copy()\nn_copy.median_filter(15, \"time\", \"latitude\")\n_ = n_copy.display_timeline(ax, field=\"lat\", method=\"all\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameters timeline\nScatter is usefull to display the parameters' temporal evolution\n\nEffective Radius and Amplitude\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(s=25, cmap=\"Spectral_r\", zorder=10)\nfig = plt.figure(figsize=(15, 12))\nax = fig.add_axes([0.04, 0.54, 0.90, 0.44])\nm = n.scatter_timeline(ax, \"radius_e\", factor=1e-3, vmin=50, vmax=150, **kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.54, 0.01, 0.44]), orientation=\"vertical\"\n)\ncb.set_label(\"Effective radius (km)\")\n\nax = fig.add_axes([0.04, 0.04, 0.90, 0.44])\nm = n.scatter_timeline(ax, \"amplitude\", factor=100, vmin=0, vmax=15, **kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.44]), orientation=\"vertical\"\n)\ncb.set_label(\"Amplitude (cm)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Speed\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nm = n.scatter_timeline(ax, \"speed_average\", factor=100, vmin=0, vmax=40, **kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation=\"vertical\"\n)\ncb.set_label(\"Maximum speed (cm/s)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Speed Radius\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nm = n.scatter_timeline(ax, \"radius_s\", factor=1e-3, vmin=20, vmax=100, **kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation=\"vertical\"\n)\ncb.set_label(\"Speed radius (km)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Remove dead branch\nRemove all tiny segments with less than N obs which didn't join two segments\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n_clean = n.remove_dead_end(nobs=5, ndays=10)\nfig = 
plt.figure(figsize=(15, 12))\nax = fig.add_axes([0.04, 0.54, 0.90, 0.40])\nax.set_title(f\"Original network ({n.infos()})\")\nn.display_timeline(ax)\nax = fig.add_axes([0.04, 0.04, 0.90, 0.40])\nax.set_title(f\"Clean network ({n_clean.infos()})\")\n_ = n_clean.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For further figure we will use clean path\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = n_clean" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Change splitting-merging events\nchange event where seg A split to B, then A merge into B, to A split to B then B merge into A\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 12))\nax = fig.add_axes([0.04, 0.54, 0.90, 0.40])\nax.set_title(f\"Clean network ({n.infos()})\")\nn.display_timeline(ax)\n\nclean_modified = n.copy()\n# If it's happen in less than 40 days\nclean_modified.correct_close_events(40)\n\nax = fig.add_axes([0.04, 0.04, 0.90, 0.40])\nax.set_title(f\"resplitted network ({clean_modified.infos()})\")\n_ = clean_modified.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Keep only observations where water could propagate from an observation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i_observation = 600\nonly_linked = n.find_link(i_observation)\n\nfig = plt.figure(figsize=(15, 12))\nax1 = fig.add_axes([0.04, 0.54, 0.90, 0.40])\nax2 = fig.add_axes([0.04, 0.04, 0.90, 0.40])\n\nkw = dict(marker=\"s\", s=300, color=\"black\", zorder=200, label=\"observation start\")\nfor ax, dataset in zip([ax1, ax2], [n, only_linked]):\n dataset.display_timeline(ax, field=\"segment\", lw=2, markersize=5, colors_mode=\"y\")\n ax.scatter(n.time[i_observation], n.segment[i_observation], **kw)\n ax.legend()\n\nax1.set_title(f\"full example ({n.infos()})\")\nax2.set_title(f\"only linked observations ({only_linked.infos()})\")\n_ = ax2.set_xlim(ax1.get_xlim()), ax2.set_ylim(ax1.get_ylim())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Keep close relative\nWhen you want to investigate one particular observation and select only the closest segments\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# First choose an observation in the network\ni = 1100\n\nfig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nn.display_timeline(ax)\nobs_args = n.time[i], n.segment[i]\nobs_kw = dict(color=\"black\", markersize=30, marker=\".\")\n_ = ax.plot(*obs_args, **obs_kw)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Colors show the relative order of the segment with regards to the chosen one\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nm = n.scatter_timeline(\n ax, n.obs_relative_order(i), vmin=-1.5, vmax=6.5, cmap=plt.get_cmap(\"jet\", 8), s=10\n)\nax.plot(*obs_args, **obs_kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), 
orientation=\"vertical\"\n)\ncb.set_label(\"Relative order\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You want to keep only the segments at the order 1\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nclose_to_i1 = n.relative(i, order=1)\nax.set_title(f\"Close segments ({close_to_i1.infos()})\")\n_ = close_to_i1.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You want to keep the segments until order 2\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nclose_to_i2 = n.relative(i, order=2)\nax.set_title(f\"Close segments ({close_to_i2.infos()})\")\n_ = close_to_i2.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You want to keep the segments until order 3\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nclose_to_i3 = n.relative(i, order=3)\nax.set_title(f\"Close segments ({close_to_i3.infos()})\")\n_ = close_to_i3.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Keep relatives to an event\nWhen you want to investigate one particular event and select only the closest segments\n\nFirst choose a merging event in the network\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "after, before, stopped = n.merging_event(triplet=True, only_index=True)\ni_event = 7" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "then see some order of relatives\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "max_order = 1\nfig, axs = plt.subplots(\n max_order + 2, 1, sharex=True, figsize=(15, 5 * (max_order + 2))\n)\n# Original network\nax = axs[0]\nax.set_title(\"Full network\", weight=\"bold\")\nn.display_timeline(axs[0], colors_mode=\"y\")\nax.grid(), ax.legend()\n\nfor k in range(0, max_order + 1):\n ax = axs[k + 1]\n ax.set_title(f\"Relatives order={k}\", weight=\"bold\")\n # Extract neighbours of event\n sub_network = n.find_segments_relative(after[i_event], stopped[i_event], order=k)\n sub_network.display_timeline(ax, colors_mode=\"y\")\n ax.legend(), ax.grid()\n _ = ax.set_ylim(axs[0].get_ylim())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display track on map\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Get a simplified network\nn = n2.remove_dead_end(nobs=50, recursive=1)\nn.numbering_segment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Only a map can be tricky to understand, with a timeline it's easier!\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], 
projection=GUI_AXES)\nn.plot(ax, color_cycle=n.COLORS)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nax = fig.add_axes([0.08, 0.7, 0.7, 0.3])\n_ = n.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get merging event\nDisplay the position of the eddies after a merging\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\nn.plot(ax, color_cycle=n.COLORS)\nm1, m0, m0_stop = n.merging_event(triplet=True)\nm1.display(ax, color=\"violet\", lw=2, label=\"Eddies after merging\")\nm0.display(ax, color=\"blueviolet\", lw=2, label=\"Eddies before merging\")\nm0_stop.display(ax, color=\"black\", lw=2, label=\"Eddies stopped by merging\")\nax.plot(m1.lon, m1.lat, marker=\".\", color=\"purple\", ls=\"\")\nax.plot(m0.lon, m0.lat, marker=\".\", color=\"blueviolet\", ls=\"\")\nax.plot(m0_stop.lon, m0_stop.lat, marker=\".\", color=\"black\", ls=\"\")\nax.legend()\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nm1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get spliting event\nDisplay the position of the eddies before a splitting\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\nn.plot(ax, color_cycle=n.COLORS)\ns0, s1, s1_start = n.spliting_event(triplet=True)\ns0.display(ax, color=\"violet\", lw=2, label=\"Eddies before splitting\")\ns1.display(ax, color=\"blueviolet\", lw=2, label=\"Eddies after splitting\")\ns1_start.display(ax, color=\"black\", lw=2, label=\"Eddies starting by splitting\")\nax.plot(s0.lon, s0.lat, marker=\".\", color=\"purple\", ls=\"\")\nax.plot(s1.lon, s1.lat, marker=\".\", color=\"blueviolet\", ls=\"\")\nax.plot(s1_start.lon, s1_start.lat, marker=\".\", color=\"black\", ls=\"\")\nax.legend()\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\ns1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get birth event\nDisplay the starting position of non-splitted eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\nbirth = n.birth_event()\nbirth.display(ax)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nbirth" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get death event\nDisplay the last position of non-merged eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\ndeath = n.death_event()\ndeath.display(ax)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\ndeath" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + 
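Editor's note: once merging, splitting, birth, and death observations are extracted as above, their `time` field (days since 1950-01-01, the same reference used by the `formatter` in these notebooks) can be binned into simple statistics. The helper and the day counts below are made up for illustration; with the notebook objects one could pass e.g. `birth.time`.

```python
from datetime import datetime, timedelta

import numpy as np

def events_per_year(days_since_1950):
    """Count events per calendar year from a day-count array (reference 1950-01-01)."""
    ref = datetime(1950, 1, 1)
    years = np.array([(ref + timedelta(days=float(d))).year for d in days_since_1950])
    values, counts = np.unique(years, return_counts=True)
    return dict(zip(values.tolist(), counts.tolist()))

print(events_per_year(np.array([19725, 19800, 20090, 20500, 20505])))
# -> {2004: 2, 2005: 1, 2006: 2}
```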
} + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_replay_segmentation.ipynb b/notebooks/python_module/16_network/pet_replay_segmentation.ipynb new file mode 100644 index 00000000..7c632138 --- /dev/null +++ b/notebooks/python_module/16_network/pet_replay_segmentation.ipynb @@ -0,0 +1,180 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Replay segmentation\nCase from figure 10 from https://doi.org/10.1002/2017JC013158\n\nAgain with the Ierapetra Eddy\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from datetime import datetime, timedelta\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.ticker import FuncFormatter\nfrom numpy import where\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\n\n\n@FuncFormatter\ndef formatter(x, pos):\n return (timedelta(x) + datetime(1950, 1, 1)).strftime(\"%d/%m/%Y\")\n\n\ndef start_axes(title=\"\"):\n fig = plt.figure(figsize=(13, 6))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)\n ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef timeline_axes(title=\"\"):\n fig = plt.figure(figsize=(15, 5))\n ax = fig.add_axes([0.04, 0.06, 0.89, 0.88])\n ax.set_title(title, weight=\"bold\")\n ax.xaxis.set_major_formatter(formatter), ax.grid()\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid(True)\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Class for new_segmentation\nThe oldest win\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class MyTrackEddiesObservations(TrackEddiesObservations):\n __slots__ = tuple()\n\n @classmethod\n def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs):\n \"\"\"\n Method to overwrite behaviour in merging.\n\n We will give the point to the older one instead of the maximum overlap ratio\n \"\"\"\n while i_next != -1:\n # Flag\n used[i_next] = True\n # Assign id\n ids[\"track\"][i_next] = track_id\n # Search next\n i_next_ = cls.get_next_obs(i_next, ids, *args, **kwargs)\n if i_next_ == -1:\n break\n ids[\"next_obs\"][i_next] = i_next_\n # Target was previously used\n if used[i_next_]:\n i_next_ = -1\n else:\n ids[\"previous_obs\"][i_next_] = i_next\n i_next = i_next_\n\n\ndef get_obs(dataset):\n \"Function to isolate a specific obs\"\n return where(\n (dataset.lat > 33)\n * (dataset.lat < 34)\n * (dataset.lon > 22)\n * (dataset.lon < 23)\n * (dataset.time > 20630)\n * (dataset.time < 20650)\n )[0][0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get original network, we will isolate only relative at order *2*\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = 
NetworkObservations.load_file(get_demo_path(\"network_med.nc\")).network(651)\nn_ = n.relative(get_obs(n), order=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the default segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(n_.infos())\nn_.plot(ax, color_cycle=n.COLORS)\nupdate_axes(ax)\nfig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nax.xaxis.set_major_formatter(formatter), ax.grid()\n_ = n_.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run a new segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "e = n.astype(MyTrackEddiesObservations)\ne.obs.sort(order=(\"track\", \"time\"), kind=\"stable\")\nsplit_matrix = e.split_network(intern=False, window=7)\nn_ = NetworkObservations.from_split_network(e, split_matrix)\nn_ = n_.relative(get_obs(n_), order=2)\nn_.numbering_segment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## New segmentation\n\"The oldest wins\" method produce a very long segment\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(n_.infos())\nn_.plot(ax, color_cycle=n_.COLORS)\nupdate_axes(ax)\nfig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nax.xaxis.set_major_formatter(formatter), ax.grid()\n_ = n_.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameters timeline\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(s=35, cmap=plt.get_cmap(\"Spectral_r\", 8), zorder=10)\nax = timeline_axes()\nn_.median_filter(15, \"time\", \"latitude\")\nm = n_.scatter_timeline(ax, \"shape_error_e\", vmin=14, vmax=70, **kw, yfield=\"lat\")\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective shape error\")\n\nax = timeline_axes()\nn_.median_filter(15, \"time\", \"latitude\")\nm = n_.scatter_timeline(\n ax, \"shape_error_e\", vmin=14, vmax=70, **kw, yfield=\"lat\", method=\"all\"\n)\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective shape error\")\nax.set_ylabel(\"Latitude\")\n\nax = timeline_axes()\nn_.median_filter(15, \"time\", \"latitude\")\nkw[\"s\"] = (n_.radius_e * 1e-3) ** 2 / 30 ** 2 * 20\nm = n_.scatter_timeline(\n ax,\n \"shape_error_e\",\n vmin=14,\n vmax=70,\n **kw,\n yfield=\"lon\",\n method=\"all\",\n)\nax.set_ylabel(\"Longitude\")\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective shape error\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cost association plot\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n_copy = n_.copy()\nn_copy.median_filter(2, \"time\", \"next_cost\")\nfor b0, b1 in [\n (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007, 2008)\n]:\n\n ref, delta = datetime(1950, 1, 1), 20\n b0_, b1_ = (b0 - ref).days, (b1 - ref).days\n ax = timeline_axes()\n ax.set_xlim(b0_ - delta, b1_ + delta)\n ax.set_ylim(0, 1)\n ax.axvline(b0_, color=\"k\", lw=1.5, ls=\"--\"), ax.axvline(\n b1_, color=\"k\", lw=1.5, ls=\"--\"\n )\n 
n_copy.display_timeline(ax, field=\"next_cost\", method=\"all\", lw=4, markersize=8)\n\n n_.display_timeline(ax, field=\"next_cost\", method=\"all\", lw=0.5, markersize=0)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_segmentation_anim.ipynb b/notebooks/python_module/16_network/pet_segmentation_anim.ipynb new file mode 100644 index 00000000..0a546832 --- /dev/null +++ b/notebooks/python_module/16_network/pet_segmentation_anim.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nNetwork segmentation process\n============================\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# sphinx_gallery_thumbnail_number = 2\nimport re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.colors import ListedColormap\nfrom numpy import ones, where\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\ndef get_obs(dataset):\n \"Function to isolate a specific obs\"\n return where(\n (dataset.lat > 33)\n * (dataset.lat < 34)\n * (dataset.lon > 22)\n * (dataset.lon < 23)\n * (dataset.time > 20630)\n * (dataset.time < 20650)\n )[0][0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Hack to pick up each step of segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "TRACKS = list()\nINDICES = list()\n\n\nclass MyTrack(TrackEddiesObservations):\n @staticmethod\n def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs):\n TRACKS.append(ids[\"track\"].copy())\n INDICES.append(i_current)\n return TrackEddiesObservations.get_next_obs(\n i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs\n )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load data\n---------\nLoad data where observations are put in same network but no 
segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Get a known network for the demonstration\nn = NetworkObservations.load_file(get_demo_path(\"network_med.nc\")).network(651)\n# We keep only some segment\nn = n.relative(get_obs(n), order=2)\nprint(len(n))\n# We convert and order object like segmentation was never happen on observations\ne = n.astype(MyTrack)\ne.obs.sort(order=(\"track\", \"time\"), kind=\"stable\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Do segmentation\n---------------\nSegmentation based on maximum overlap, temporal window for candidates = 5 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "matrix = e.split_network(intern=False, window=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Anim\n----\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def update(i_frame):\n tr = TRACKS[i_frame]\n mappable_tracks.set_array(tr)\n s = 40 * ones(tr.shape)\n s[tr == 0] = 4\n mappable_tracks.set_sizes(s)\n\n indices_frames = INDICES[i_frame]\n mappable_CONTOUR.set_data(\n e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames],\n )\n mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)])\n return (mappable_tracks,)\n\n\nfig = plt.figure(figsize=(16, 9), dpi=60)\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES)\nax.set_title(f\"{len(e)} observations to segment\")\nax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid()\nvmax = TRACKS[-1].max()\ncmap = ListedColormap([\"gray\", *e.COLORS[:-1]], name=\"from_list\", N=vmax)\nmappable_tracks = ax.scatter(\n e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20\n)\nmappable_CONTOUR = ax.plot(\n e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0]\n)[0]\nani = VideoAnimation(fig, update, frames=range(1, len(TRACKS), 4), interval=125)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Final Result\n------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(16, 9))\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES)\nax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid()\n_ = ax.scatter(e.lon, e.lat, c=TRACKS[-1], cmap=cmap, vmin=0, vmax=vmax, s=20)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_something_cool.ipynb b/notebooks/python_module/16_network/pet_something_cool.ipynb new file mode 100644 index 00000000..158852f9 --- /dev/null +++ b/notebooks/python_module/16_network/pet_something_cool.ipynb @@ -0,0 +1,65 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { 
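Editor's note: the segmentation run above (`split_network(intern=False, window=5)`) associates each observation with the candidate, within the temporal window, that maximises contour overlap. As a rough illustration of that cost, the sketch below computes an intersection-over-union ratio for axis-aligned boxes; the library itself works on the resampled eddy contours (polygons), so this shows only the idea, not the implementation.

```python
def overlap_ratio(box_a, box_b):
    """Intersection area over union area for two (xmin, xmax, ymin, ymax) boxes."""
    dx = min(box_a[1], box_b[1]) - max(box_a[0], box_b[0])
    dy = min(box_a[3], box_b[3]) - max(box_a[2], box_b[2])
    inter = max(dx, 0.0) * max(dy, 0.0)
    area_a = (box_a[1] - box_a[0]) * (box_a[3] - box_a[2])
    area_b = (box_b[1] - box_b[0]) * (box_b[3] - box_b[2])
    return inter / (area_a + area_b - inter)

# Eddy at day D versus two candidates at day D+1: keep the largest ratio
eddy = (22.0, 23.0, 33.0, 34.0)
candidates = [(22.4, 23.4, 33.2, 34.2), (24.0, 25.0, 33.0, 34.0)]
ratios = [overlap_ratio(eddy, c) for c in candidates]
print(ratios, "-> best candidate:", ratios.index(max(ratios)))
```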
+ "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# essai\n\non tente des trucs\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import cartopy.crs as ccrs\nimport cartopy.feature as cfeature\nimport numpy as np\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\n\ndef rect_from_extent(extent):\n rect_lon = [extent[0], extent[1], extent[1], extent[0], extent[0]]\n rect_lat = [extent[2], extent[2], extent[3], extent[3], extent[2]]\n return rect_lon, rect_lat\n\n\ndef indice_from_extent(lon, lat, extent):\n mask = (lon > extent[0]) * (lon < extent[1]) * (lat > extent[2]) * (lat < extent[3])\n return np.where(mask)[0]\n\n\nfichier = \"/data/adelepoulle/work/Eddies/20201217_network_build/big_network.nc\"\nnetwork = NetworkObservations.load_file(fichier)\nsub_network = network.network(1078566)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# extent_begin = [0, 2, -50, -15]\n# extent_end = [-42, -35, -40, -10]\n\nextent_begin = [2, 22, -50, -30]\ni_obs_begin = indice_from_extent(\n sub_network.longitude, sub_network.latitude, extent_begin\n)\nnetwork_begin = sub_network.find_link(i_obs_begin)\ntime_mini = network_begin.time.min()\ntime_maxi = network_begin.time.max()\n\nextent_end = [-52, -45, -37, -33]\ni_obs_end = indice_from_extent(\n (network_begin.longitude + 180) % 360 - 180, network_begin.latitude, extent_end\n)\nnetwork_end = network_begin.find_link(i_obs_end, forward=False, backward=True)\n\n\ndatasets = [network_begin, network_end]\nextents = [extent_begin, extent_end]\nfig, (ax1, ax2) = plt.subplots(\n 2, 1, figsize=(10, 9), dpi=140, subplot_kw={\"projection\": ccrs.PlateCarree()}\n)\n\nfor ax, dataset, extent in zip([ax1, ax2], datasets, extents):\n sca = dataset.scatter(\n ax,\n name=\"time\",\n cmap=\"Spectral_r\",\n label=\"observation dans le temps\",\n vmin=time_mini,\n vmax=time_maxi,\n )\n\n x, y = rect_from_extent(extent)\n ax.fill(x, y, color=\"grey\", alpha=0.3, label=\"observations choisies\")\n # ax.plot(x, y, marker='o')\n\n ax.legend()\n\n gridlines = ax.gridlines(\n alpha=0.2, color=\"black\", linestyle=\"dotted\", draw_labels=True, dms=True\n )\n\n gridlines.left_labels = False\n gridlines.top_labels = False\n\n ax.coastlines()\n ax.add_feature(cfeature.LAND)\n ax.add_feature(cfeature.LAKES, zorder=10)\n ax.add_feature(cfeature.BORDERS, lw=0.25)\n ax.add_feature(cfeature.OCEAN, alpha=0.2)\n\n\nax1.set_title(\n \"Recherche du d\u00e9placement de l'eau dans les eddies \u00e0 travers les observations choisies\"\n)\nax2.set_title(\"Recherche de la provenance de l'eau \u00e0 travers les observations choisies\")\nax2.set_extent(ax1.get_extent(), ccrs.PlateCarree())\n\nfig.subplots_adjust(right=0.87, left=0.02)\ncbar_ax = fig.add_axes([0.90, 0.1, 0.02, 0.8])\ncbar = fig.colorbar(sca[\"scatter\"], cax=cbar_ax, orientation=\"vertical\")\n_ = cbar.set_label(\"time (jj)\", rotation=270, labelpad=-65)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ 
No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1edf267d..556cabbf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,11 @@ -matplotlib -netCDF4 -numba -numpy +matplotlib < 3.8 # need an update of contour management opencv-python pint polygon3 -pyproj pyyaml requests scipy -zarr -# for binder -pyeddytrackersample \ No newline at end of file +zarr < 3.0 +netCDF4 +numpy +numba \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 00000000..a005c37d --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,7 @@ +-r requirements.txt +isort +black +blackdoc +flake8 +pytest +pytest-cov \ No newline at end of file diff --git a/requirements_doc.txt b/requirements_doc.txt deleted file mode 100644 index 4f07fe28..00000000 --- a/requirements_doc.txt +++ /dev/null @@ -1,15 +0,0 @@ -matplotlib -netCDF4 -numba -numpy -opencv-python -pint -polygon3 -pyproj -pyyaml -scipy -zarr -# doc -sphinx-gallery -pyeddytrackersample -sphinx_rtd_theme \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 9821c1ad..7e773ae8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,29 @@ + +[yapf] +column_limit = 100 + [flake8] +max-line-length = 140 ignore = - E203, # whitespace before ':' - W503, # line break before binary operator + E203, + W503, +exclude= + build + doc + versioneer.py + +[isort] +combine_as_imports=True +force_grid_wrap=0 +force_sort_within_sections=True +force_to_top=typing +include_trailing_comma=True +line_length=140 +multi_line_output=3 +skip= + build + doc/conf.py + [versioneer] VCS = git @@ -10,3 +32,8 @@ versionfile_source = src/py_eddy_tracker/_version.py versionfile_build = py_eddy_tracker/_version.py tag_prefix = parentdir_prefix = + +[tool:pytest] +filterwarnings= + ignore:tostring.*is deprecated + diff --git a/setup.py b/setup.py index 80aa0e53..7b836763 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- +from setuptools import find_packages, setup + import versioneer -from setuptools import setup, find_packages with open("README.md", "r") as fh: long_description = fh.read() @@ -9,6 +10,7 @@ setup( name="pyEddyTracker", + python_requires=">=3.10", version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), description="Py-Eddy-Tracker libraries", @@ -27,10 +29,8 @@ scripts=[ "src/scripts/EddySubSetter", "src/scripts/EddyTranslate", - "src/scripts/EddyTracking", "src/scripts/EddyFinalTracking", "src/scripts/EddyMergeCorrespondances", - "src/scripts/GUIEddy", ], zip_safe=False, entry_points=dict( @@ -40,11 +40,19 @@ "EddyId = py_eddy_tracker.appli.grid:eddy_id", # eddies "MergeEddies = py_eddy_tracker.appli.eddies:merge_eddies", + "EddyFrequency = py_eddy_tracker.appli.eddies:get_frequency_grid", + "EddyInfos = py_eddy_tracker.appli.eddies:display_infos", + "EddyCircle = py_eddy_tracker.appli.eddies:eddies_add_circle", + "EddyTracking = py_eddy_tracker.appli.eddies:eddies_tracking", + "EddyQuickCompare = py_eddy_tracker.appli.eddies:quick_compare", # network "EddyNetworkGroup = py_eddy_tracker.appli.network:build_network", "EddyNetworkBuildPath = py_eddy_tracker.appli.network:divide_network", + "EddyNetworkSubSetter = py_eddy_tracker.appli.network:subset_network", + "EddyNetworkQuickCompare = py_eddy_tracker.appli.network:quick_compare", # anim/gui "EddyAnim = py_eddy_tracker.appli.gui:anim", + "GUIEddy = py_eddy_tracker.appli.gui:guieddy", # misc "ZarrDump = py_eddy_tracker.appli.misc:zarrdump", ] diff --git a/share/fig.py b/share/fig.py index 
8640abcb..80c7f12b 100644 --- a/share/fig.py +++ b/share/fig.py @@ -1,8 +1,10 @@ -from matplotlib import pyplot as plt -from py_eddy_tracker.dataset.grid import RegularGridDataset from datetime import datetime import logging +from matplotlib import pyplot as plt + +from py_eddy_tracker.dataset.grid import RegularGridDataset + grid_name, lon_name, lat_name = ( "nrt_global_allsat_phy_l4_20190223_20190226.nc", "longitude", diff --git a/share/nrt_global_allsat_phy_l4_20190223_20190226.nc b/share/nrt_global_allsat_phy_l4_20190223_20190226.nc new file mode 120000 index 00000000..077ce7e6 --- /dev/null +++ b/share/nrt_global_allsat_phy_l4_20190223_20190226.nc @@ -0,0 +1 @@ +../src/py_eddy_tracker/data/nrt_global_allsat_phy_l4_20190223_20190226.nc \ No newline at end of file diff --git a/share/tracking.yaml b/share/tracking.yaml index 0f8766b8..d6264104 100644 --- a/share/tracking.yaml +++ b/share/tracking.yaml @@ -1,13 +1,14 @@ PATHS: - # Files produces with EddyIdentification + # Files produced with EddyIdentification FILES_PATTERN: /home/emason/toto/Anticyclonic_*.nc - # Path for saving of outputs + # Path to save outputs SAVE_DIR: '/home/emason/toto/' -# Minimum number of observations to store eddy -TRACK_DURATION_MIN: 4 +# Number of consecutive timesteps with missing detection allowed VIRTUAL_LENGTH_MAX: 0 +# Minimal number of timesteps to considered as a long trajectory +TRACK_DURATION_MIN: 4 -#CLASS: -# MODULE: py_eddy_tracker.featured_tracking.old_tracker_reference -# CLASS: CheltonTracker +CLASS: + MODULE: py_eddy_tracker.featured_tracking.area_tracker + CLASS: AreaTracker diff --git a/src/py_eddy_tracker/__init__.py b/src/py_eddy_tracker/__init__.py index 5897c46a..7115bf67 100644 --- a/src/py_eddy_tracker/__init__.py +++ b/src/py_eddy_tracker/__init__.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- """ -=========================================================================== This file is part of py-eddy-tracker. py-eddy-tracker is free software: you can redistribute it and/or modify @@ -16,26 +15,30 @@ You should have received a copy of the GNU General Public License along with py-eddy-tracker. If not, see . 
-Copyright (c) 2014-2020 by Evan Mason +Copyright (c) 2014-2020 by Evan Mason and Antoine Delepoulle Email: evanmason@gmail.com -=========================================================================== + """ from argparse import ArgumentParser +from datetime import datetime import logging -import numpy + import zarr +from ._version import get_versions -def start_logger(): - FORMAT_LOG = ( - "%(levelname)-8s %(asctime)s %(module)s.%(funcName)s :\n\t\t\t\t\t%(message)s" - ) +__version__ = get_versions()["version"] +del get_versions + + +def start_logger(color=True): + FORMAT_LOG = "%(levelname)-8s %(asctime)s %(module)s.%(funcName)s :\n\t%(message)s" logger = logging.getLogger("pet") if len(logger.handlers) == 0: # set up logging to CONSOLE console = logging.StreamHandler() - console.setFormatter(ColoredFormatter(FORMAT_LOG)) + console.setFormatter(ColoredFormatter(FORMAT_LOG, color=color)) # add the handler to the root logger logger.addHandler(console) return logger @@ -50,13 +53,14 @@ class ColoredFormatter(logging.Formatter): DEBUG="\033[34m\t", ) - def __init__(self, message): + def __init__(self, message, color=True): super().__init__(message) + self.with_color = color def format(self, record): color = self.COLOR_LEVEL.get(record.levelname, "") color_reset = "\033[0m" - model = color + "%s" + color_reset + model = (color + "%s" + color_reset) if self.with_color else "%s" record.msg = model % record.msg record.funcName = model % record.funcName record.module = model % record.module @@ -65,16 +69,14 @@ def format(self, record): class EddyParser(ArgumentParser): - """General parser for applications - """ + """General parser for applications""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.add_base_argument() def add_base_argument(self): - """Base arguments - """ + """Base arguments""" self.add_argument( "-v", "--verbose", @@ -83,6 +85,20 @@ def add_base_argument(self): help="Levels : DEBUG, INFO, WARNING," " ERROR, CRITICAL", ) + def memory_arg(self): + self.add_argument( + "--memory", + action="store_true", + help="Load file in memory before to read with netCDF library", + ) + + def contour_intern_arg(self): + self.add_argument( + "--intern", + action="store_true", + help="Use intern contour instead of outter contour", + ) + def parse_args(self, *args, **kwargs): logger = start_logger() # Parsing @@ -92,12 +108,26 @@ def parse_args(self, *args, **kwargs): return opts +TIME_MODELS = ["%Y%m%d", "%Y%m%d%H%M%S", "%Y%m%dT%H%M%S"] + + +def identify_time(str_date): + for model in TIME_MODELS: + try: + return datetime.strptime(str_date, model) + except ValueError: + pass + raise Exception("No time model found") + + VAR_DESCR = dict( time=dict( attr_name="time", nc_name="time", old_nc_name=["j1"], - nc_type="int32", + nc_type="float64", + output_type="uint32", + scale_factor=1 / 86400.0, nc_dims=("obs",), nc_attr=dict( standard_name="time", @@ -116,7 +146,7 @@ def parse_args(self, *args, **kwargs): nc_dims=("obs",), nc_attr=dict( long_name="Rotating sense of the eddy", - comment="Cyclonic -1; Anti-cyclonic +1", + comment="Cyclonic -1; Anticyclonic +1", ), ), segment_size=dict( @@ -140,6 +170,15 @@ def parse_args(self, *args, **kwargs): nc_dims=("obs",), nc_attr=dict(), ), + distance_next=dict( + attr_name=None, + nc_name="distance_next", + nc_type="float32", + output_type="uint16", + scale_factor=50.0, + nc_dims=("obs",), + nc_attr=dict(long_name="Distance to next position", units="m"), + ), virtual=dict( attr_name=None, nc_name="observation_flag", @@ -172,7 +211,7 
@@ def parse_args(self, *args, **kwargs): nc_attr=dict( units="degrees_east", axis="X", - comment="Longitude center of the fitted circle", + comment="Longitude center of the best fit circle", long_name="Eddy Center Longitude", standard_name="longitude", ), @@ -189,7 +228,7 @@ def parse_args(self, *args, **kwargs): axis="Y", long_name="Eddy Center Latitude", standard_name="latitude", - comment="Latitude center of the fitted circle", + comment="Latitude center of the best fit circle", ), ), lon_max=dict( @@ -204,6 +243,7 @@ def parse_args(self, *args, **kwargs): axis="X", long_name="Longitude of the SSH maximum", standard_name="longitude", + comment="Longitude of the inner contour", ), ), lat_max=dict( @@ -218,6 +258,7 @@ def parse_args(self, *args, **kwargs): axis="Y", long_name="Latitude of the SSH maximum", standard_name="latitude", + comment="Latitude of the inner contour", ), ), amplitude=dict( @@ -226,7 +267,7 @@ def parse_args(self, *args, **kwargs): old_nc_name=["A"], nc_type="float32", output_type="uint16", - scale_factor=0.001, + scale_factor=0.0001, nc_dims=("obs",), nc_attr=dict( long_name="Amplitude", @@ -235,6 +276,28 @@ def parse_args(self, *args, **kwargs): "the eddy and the SSH around the effective contour defining the eddy edge", ), ), + speed_area=dict( + attr_name="speed_area", + nc_name="speed_area", + nc_type="float32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Speed area", + units="m^2", + comment="Area enclosed by the speed contour in m^2", + ), + ), + effective_area=dict( + attr_name="effective_area", + nc_name="effective_area", + nc_type="float32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Effective area", + units="m^2", + comment="Area enclosed by the effective contour in m^2", + ), + ), speed_average=dict( attr_name="speed_average", scale_factor=0.0001, @@ -260,7 +323,7 @@ def parse_args(self, *args, **kwargs): nc_attr=dict( long_name="Radial Speed Profile", units="m/s", - comment="Speed average values from effective contour inwards to smallest contour, evenly spaced points", + comment="Speed averaged values from the effective contour inwards to the smallest contour, evenly spaced points", ), ), i=dict( @@ -268,18 +331,14 @@ def parse_args(self, *args, **kwargs): nc_name="i", nc_type="uint16", nc_dims=("obs",), - nc_attr=dict( - long_name="Longitude index in the grid of the detection", - ), + nc_attr=dict(long_name="Longitude index in the grid of the detection"), ), j=dict( attr_name="j", nc_name="j", nc_type="uint16", nc_dims=("obs",), - nc_attr=dict( - long_name="Latitude index in the grid of the detection", - ), + nc_attr=dict(long_name="Latitude index in the grid of the detection"), ), eke=dict( attr_name="eke", @@ -303,7 +362,7 @@ def parse_args(self, *args, **kwargs): nc_attr=dict( long_name="Effective Radius", units="m", - comment="Radius of a circle whose area is equal to that enclosed by the effective contour", + comment="Radius of the best fit circle corresponding to the effective contour", ), ), radius_s=dict( @@ -317,8 +376,7 @@ def parse_args(self, *args, **kwargs): nc_attr=dict( long_name="Speed Radius", units="m", - comment="Radius of a circle whose area is equal to that \ - enclosed by the contour of maximum circum-average speed", + comment="Radius of the best fit circle corresponding to the contour of maximum circum-average speed", ), ), track=dict( @@ -328,18 +386,56 @@ def parse_args(self, *args, **kwargs): nc_type="uint32", nc_dims=("obs",), nc_attr=dict( - long_name="Trajectory number", - comment="Trajectory identification 
number", + long_name="Trajectory number", comment="Trajectory identification number" ), ), - sub_track=dict( + segment=dict( attr_name=None, - nc_name="sub_track", + nc_name="segment", nc_type="uint32", nc_dims=("obs",), nc_attr=dict( - long_name="Segment Number", - comment="Segment number inside a group", + long_name="Segment Number", comment="Segment number inside a group" + ), + ), + previous_obs=dict( + attr_name=None, + nc_name="previous_obs", + nc_type="int32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Previous observation index", + comment="Index of previous observation in a splitting case", + ), + ), + next_obs=dict( + attr_name=None, + nc_name="next_obs", + nc_type="int32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Next observation index", + comment="Index of next observation in a merging case", + ), + ), + previous_cost=dict( + attr_name=None, + nc_name="previous_cost", + nc_type="float32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Previous cost for previous observation", + comment="", + ), + ), + next_cost=dict( + attr_name=None, + nc_name="next_cost", + nc_type="float32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Next cost for next observation", + comment="", ), ), n=dict( @@ -360,8 +456,8 @@ def parse_args(self, *args, **kwargs): nc_type="f4", filters=[zarr.Delta("i2")], output_type="i2", - scale_factor=numpy.float32(0.01), - add_offset=180, + scale_factor=0.01, + add_offset=180.0, nc_dims=("obs", "NbSample"), nc_attr=dict( long_name="Effective Contour Longitudes", @@ -377,7 +473,7 @@ def parse_args(self, *args, **kwargs): nc_type="f4", filters=[zarr.Delta("i2")], output_type="i2", - scale_factor=numpy.float32(0.01), + scale_factor=0.01, nc_dims=("obs", "NbSample"), nc_attr=dict( long_name="Effective Contour Latitudes", @@ -392,9 +488,9 @@ def parse_args(self, *args, **kwargs): nc_type="u2", nc_dims=("obs",), nc_attr=dict( - longname="number of point for effective contour", + long_name="number of points for effective contour", units="ordinal", - description="Number of point for effective contour, if greater than NbSample, there is a resampling", + description="Number of points for effective contour before resampling", ), ), contour_lon_s=dict( @@ -404,8 +500,8 @@ def parse_args(self, *args, **kwargs): nc_type="f4", filters=[zarr.Delta("i2")], output_type="i2", - scale_factor=numpy.float32(0.01), - add_offset=180, + scale_factor=0.01, + add_offset=180.0, nc_dims=("obs", "NbSample"), nc_attr=dict( long_name="Speed Contour Longitudes", @@ -421,7 +517,7 @@ def parse_args(self, *args, **kwargs): nc_type="f4", filters=[zarr.Delta("i2")], output_type="i2", - scale_factor=numpy.float32(0.01), + scale_factor=0.01, nc_dims=("obs", "NbSample"), nc_attr=dict( long_name="Speed Contour Latitudes", @@ -436,9 +532,9 @@ def parse_args(self, *args, **kwargs): nc_type="u2", nc_dims=("obs",), nc_attr=dict( - longname="number of point for speed contour", + long_name="number of points for speed contour", units="ordinal", - description="Number of point for speed contour, if greater than NbSample, there is a resampling", + description="Number of points for speed contour before resampling", ), ), shape_error_e=dict( @@ -451,8 +547,8 @@ def parse_args(self, *args, **kwargs): nc_dims=("obs",), nc_attr=dict( units="%", - comment="Error criterion between the effective contour and its fit with the circle of same effective radius", - long_name="Effective Contour Error", + comment="Error criterion between the effective contour and its best fit circle", + long_name="Effective 
Contour Shape Error", ), ), score=dict( @@ -462,7 +558,7 @@ def parse_args(self, *args, **kwargs): output_type="u1", scale_factor=0.4, nc_dims=("obs",), - nc_attr=dict(units="%", comment="score", long_name="Score",), + nc_attr=dict(units="%", comment="score", long_name="Score"), ), index_other=dict( attr_name=None, @@ -485,8 +581,8 @@ def parse_args(self, *args, **kwargs): nc_dims=("obs",), nc_attr=dict( units="%", - comment="Error criterion between the speed contour and its fit with the circle of same speed radius", - long_name="Speed Contour Error", + comment="Error criterion between the speed contour and its best fit circle", + long_name="Speed Contour Shape Error", ), ), height_max_speed_contour=dict( @@ -531,7 +627,7 @@ def parse_args(self, *args, **kwargs): old_nc_name=["Chl"], nc_type="f4", nc_dims=("obs",), - nc_attr=dict(long_name="Log base 10 chlorophyll", units="Log(Chl/[mg/m^3])",), + nc_attr=dict(long_name="Log base 10 chlorophyll", units="Log(Chl/[mg/m^3])"), ), dchl=dict( attr_name=None, @@ -551,7 +647,8 @@ def parse_args(self, *args, **kwargs): nc_type="f4", nc_dims=("obs",), nc_attr=dict( - long_name="Log base 10 background chlorophyll", units="Log(Chl/[mg/m^3])", + long_name="Log base 10 background chlorophyll", + units="Log(Chl/[mg/m^3])", ), ), year=dict( @@ -560,7 +657,7 @@ def parse_args(self, *args, **kwargs): old_nc_name=["Year"], nc_type="u2", nc_dims=("obs",), - nc_attr=dict(long_name="Year", units="year",), + nc_attr=dict(long_name="Year", units="year"), ), month=dict( attr_name=None, @@ -568,7 +665,7 @@ def parse_args(self, *args, **kwargs): old_nc_name=["Month"], nc_type="u1", nc_dims=("obs",), - nc_attr=dict(long_name="Month", units="month",), + nc_attr=dict(long_name="Month", units="month"), ), day=dict( attr_name=None, @@ -576,7 +673,7 @@ def parse_args(self, *args, **kwargs): old_nc_name=["Day"], nc_type="u1", nc_dims=("obs",), - nc_attr=dict(long_name="Day", units="day",), + nc_attr=dict(long_name="Day", units="day"), ), nb_contour_selected=dict( attr_name=None, @@ -601,7 +698,5 @@ def parse_args(self, *args, **kwargs): for key_old in VAR_DESCR[key].get("old_nc_name", list()): VAR_DESCR_inv[key_old] = key -from ._version import get_versions - -__version__ = get_versions()["version"] -del get_versions +from . import _version +__version__ = _version.get_versions()['version'] diff --git a/src/py_eddy_tracker/_version.py b/src/py_eddy_tracker/_version.py index aadd0910..589e706f 100644 --- a/src/py_eddy_tracker/_version.py +++ b/src/py_eddy_tracker/_version.py @@ -5,8 +5,9 @@ # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -15,9 +16,11 @@ import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. 
# setup.py/versioneer.py will grep for the variable names, so they must @@ -33,8 +36,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool -def get_config(): + +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -52,13 +62,13 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} @@ -67,22 +77,35 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -93,18 +116,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. 
Source tarballs conventionally unpack into a directory that includes both @@ -113,15 +138,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print("Tried directories %s but none started with prefix %s" % @@ -130,41 +154,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -177,11 +208,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. 
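+    # Illustrative example (hypothetical values): a refnames string such as
+    # "tag: 3.6.2, origin/master, HEAD" gives refs {"tag: 3.6.2", "origin/master", "HEAD"};
+    # stripping the "tag: " prefix below leaves the candidate tags {"3.6.2"}.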
TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -190,7 +221,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -199,6 +230,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) return {"version": r, @@ -214,7 +250,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -225,8 +266,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -234,24 +282,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -268,7 +349,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? 
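+        # For example (hypothetical values), "3.6.1-14-g1a2b3c4" would match above as
+        # closest tag "3.6.1", a distance of 14 commits and short hash "1a2b3c4";
+        # any describe output that does not fit this pattern falls through to the error below.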
pieces["error"] = ("unable to parse git-describe output: '%s'" % describe_out) return pieces @@ -293,26 +374,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -337,23 +419,71 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -380,12 +510,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -402,7 +561,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -422,7 +581,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -442,7 +601,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -456,10 +615,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -474,7 +637,7 @@ def render(pieces, style): "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. 
Some @@ -495,7 +658,7 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: return {"version": "0+unknown", "full-revisionid": None, diff --git a/src/py_eddy_tracker/appli/__init__.py b/src/py_eddy_tracker/appli/__init__.py index e69de29b..721f5a41 100644 --- a/src/py_eddy_tracker/appli/__init__.py +++ b/src/py_eddy_tracker/appli/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +""" +Entry point +""" diff --git a/src/py_eddy_tracker/appli/eddies.py b/src/py_eddy_tracker/appli/eddies.py index c750792f..c1c7a90d 100644 --- a/src/py_eddy_tracker/appli/eddies.py +++ b/src/py_eddy_tracker/appli/eddies.py @@ -1,28 +1,45 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. +Applications on detection and tracking files +""" +import argparse +from datetime import datetime +from glob import glob +import logging +from os import mkdir +from os.path import basename, dirname, exists, join as join_path +from re import compile as re_compile - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. +from netCDF4 import Dataset +from numpy import bincount, bytes_, empty, in1d, unique +from yaml import safe_load - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. +from .. import EddyParser, identify_time +from ..observations.observation import EddiesObservations, reverse_index +from ..observations.tracking import TrackEddiesObservations +from ..tracking import Correspondances - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . +logger = logging.getLogger("pet") -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== -""" -from netCDF4 import Dataset -from .. 
import EddyParser -from ..observations.tracking import TrackEddiesObservations + +def eddies_add_circle(): + parser = EddyParser("Add or replace contour with radius parameter") + parser.add_argument("filename", help="input file") + parser.add_argument("out", help="output file") + args = parser.parse_args() + obs = EddiesObservations.load_file(args.filename) + if obs.track_array_variables == 0: + obs.track_array_variables = 50 + obs = obs.add_fields( + array_fields=( + "contour_lon_e", + "contour_lat_e", + "contour_lon_s", + "contour_lat_s", + ) + ) + obs.circle_contour() + obs.write_file(filename=args.out) def merge_eddies(): @@ -35,22 +52,531 @@ def merge_eddies(): parser.add_argument( "--include_var", nargs="+", type=str, help="use only listed variable" ) + parser.memory_arg() args = parser.parse_args() if args.include_var is None: with Dataset(args.filename[0]) as h: args.include_var = h.variables.keys() - obs = TrackEddiesObservations.load_file( - args.filename[0], raw_data=True, include_vars=args.include_var - ) - if args.add_rotation_variable: - obs = obs.add_rotation_type() - for filename in args.filename[1:]: - other = TrackEddiesObservations.load_file( + obs = list() + for filename in args.filename: + e = TrackEddiesObservations.load_file( filename, raw_data=True, include_vars=args.include_var ) if args.add_rotation_variable: - other = other.add_rotation_type() - obs = obs.merge(other) + e = e.add_rotation_type() + obs.append(e) + obs = TrackEddiesObservations.concatenate(obs) obs.write_file(filename=args.out) + + +def get_frequency_grid(): + parser = EddyParser("Compute eddy frequency") + parser.add_argument("observations", help="Input observations to compute frequency") + parser.add_argument("out", help="Grid output file") + parser.contour_intern_arg() + parser.add_argument( + "--xrange", nargs="+", type=float, help="Horizontal range : START,STOP,STEP" + ) + parser.add_argument( + "--yrange", nargs="+", type=float, help="Vertical range : START,STOP,STEP" + ) + args = parser.parse_args() + + if (args.xrange is None or len(args.xrange) not in (3,)) or ( + args.yrange is None or len(args.yrange) not in (3,) + ): + raise Exception("Use START/STOP/STEP for --xrange and --yrange") + + var_to_load = ["longitude"] + var_to_load.extend(EddiesObservations.intern(args.intern, public_label=True)) + e = EddiesObservations.load_file(args.observations, include_vars=var_to_load) + + bins = args.xrange, args.yrange + g = e.grid_count(bins, intern=args.intern) + g.write(args.out) + + +def display_infos(): + parser = EddyParser("Display general information") + parser.add_argument( + "observations", nargs="+", help="Input observations to display" + ) + parser.add_argument("--vars", nargs="+", help=argparse.SUPPRESS) + parser.add_argument( + "--area", + nargs=4, + type=float, + metavar=("llcrnrlon", "llcrnrlat", "urcrnrlon", "urcrnrlat"), + help="Bounding box", + ) + args = parser.parse_args() + if args.vars: + vars = args.vars + else: + vars = [ + "amplitude", + "speed_radius", + "speed_area", + "effective_radius", + "effective_area", + "time", + "latitude", + "longitude", + ] + filenames = args.observations + filenames.sort() + for filename in filenames: + with Dataset(filename) as h: + track = "track" in h.variables + print(f"-- {filename} -- ") + if track: + vars_ = vars.copy() + vars_.extend(("track", "observation_number", "observation_flag")) + e = TrackEddiesObservations.load_file(filename, include_vars=vars_) + else: + e = EddiesObservations.load_file(filename,
include_vars=vars) + if args.area is not None: + area = dict( + llcrnrlon=args.area[0], + llcrnrlat=args.area[1], + urcrnrlon=args.area[2], + urcrnrlat=args.area[3], + ) + e = e.extract_with_area(area) + print(e) + + +def eddies_tracking(): + parser = EddyParser("Tool to use identification step to compute tracking") + parser.add_argument("yaml_file", help="Yaml file to configure py-eddy-tracker") + parser.add_argument("--correspondance_in", help="Filename of saved correspondance") + parser.add_argument("--correspondance_out", help="Filename to save correspondance") + parser.add_argument( + "--save_correspondance_and_stop", + action="store_true", + help="Stop tracking after correspondance computation," + " merging can be done with EddyFinalTracking", + ) + parser.add_argument( + "--zarr", action="store_true", help="Output will be written in zarr" + ) + parser.add_argument( + "--unraw", + action="store_true", + help="Load unraw data, use only for netcdf. " + "If unraw is active, netcdf is loaded without applying scale_factor and add_offset.", + ) + parser.add_argument( + "--blank_period", + type=int, + default=0, + help="Number of detections at the end of the period which will not be used", + ) + parser.memory_arg() + args = parser.parse_args() + + # Read yaml configuration file + with open(args.yaml_file, "r") as stream: + config = safe_load(stream) + + if "CLASS" in config: + classname = config["CLASS"]["CLASS"] + obs_class = dict( + class_method=getattr( + __import__(config["CLASS"]["MODULE"], globals(), locals(), classname), + classname, + ), + class_kw=config["CLASS"].get("OPTIONS", dict()), + ) + else: + obs_class = dict() + + c_in, c_out = args.correspondance_in, args.correspondance_out + if c_in is None: + c_in = config["PATHS"].get("CORRESPONDANCES_IN", None) + y_c_out = config["PATHS"].get( + "CORRESPONDANCES_OUT", "{path}/{sign_type}_correspondances.nc" + ) + if c_out is None: + c_out = y_c_out + + # Create output folder if necessary + save_dir = config["PATHS"].get("SAVE_DIR", None) + if save_dir is not None and not exists(save_dir): + mkdir(save_dir) + + track( + pattern=config["PATHS"]["FILES_PATTERN"], + output_dir=save_dir, + c_out=c_out, + **obs_class, + virtual=int(config.get("VIRTUAL_LENGTH_MAX", 0)), + previous_correspondance=c_in, + memory=args.memory, + correspondances_only=args.save_correspondance_and_stop, + raw=not args.unraw, + zarr=args.zarr, + nb_obs_min=int(config.get("TRACK_DURATION_MIN", 10)), + blank_period=args.blank_period, + ) + + +def browse_dataset_in( + data_dir, + files_model, + date_regexp, + date_model=None, + start_date=None, + end_date=None, + sub_sampling_step=1, + files=None, +): + pattern_regexp = re_compile(".*/" + date_regexp) + if files is not None: + filenames = bytes_(files) + else: + full_path = join_path(data_dir, files_model) + logger.info("Search files : %s", full_path) + filenames = bytes_(glob(full_path)) + + dataset_list = empty( + len(filenames), + dtype=[("filename", "S500"), ("date", "datetime64[s]")], + ) + dataset_list["filename"] = filenames + + logger.info("%s grids available", dataset_list.shape[0]) + mode_attrs = False + if "(" not in date_regexp: + logger.debug("Attrs date : %s", date_regexp) + mode_attrs = date_regexp.strip().split(":") + else: + logger.debug("Pattern date : %s", date_regexp) + + for item in dataset_list: + str_date = None + if mode_attrs: + with Dataset(item["filename"].decode("utf-8")) as h: + if len(mode_attrs) == 1: + str_date = getattr(h, mode_attrs[0]) + else: + str_date = getattr(h.variables[mode_attrs[0]],
mode_attrs[1]) + else: + result = pattern_regexp.match(str(item["filename"])) + if result: + str_date = result.groups()[0] + + if str_date is not None: + if date_model is None: + item["date"] = identify_time(str_date) + else: + item["date"] = datetime.strptime(str_date, date_model) + + dataset_list.sort(order=["date", "filename"]) + steps = unique(dataset_list["date"][1:] - dataset_list["date"][:-1]) + if len(steps) > 1: + raise Exception("Several timesteps in grid dataset %s" % steps) + + if sub_sampling_step != 1: + logger.info("Grid subsampling %d", sub_sampling_step) + dataset_list = dataset_list[::sub_sampling_step] + + if start_date is not None or end_date is not None: + logger.info( + "Available grid from %s to %s", + dataset_list[0]["date"], + dataset_list[-1]["date"], + ) + logger.info("Filtering grid by time %s, %s", start_date, end_date) + mask = (dataset_list["date"] >= start_date) * (dataset_list["date"] <= end_date) + + dataset_list = dataset_list[mask] + return dataset_list + + +def track( + pattern, + output_dir, + c_out, + nb_obs_min=10, + raw=True, + zarr=False, + blank_period=0, + correspondances_only=False, + **kw_c, +): + kw = dict(date_regexp=".*_([0-9]*?).[nz].*") + if isinstance(pattern, list): + kw.update(dict(data_dir=None, files_model=None, files=pattern)) + else: + kw.update(dict(data_dir=dirname(pattern), files_model=basename(pattern))) + datasets = browse_dataset_in(**kw) + if blank_period > 0: + datasets = datasets[:-blank_period] + logger.info("Last %d files will be pop", blank_period) + + if nb_obs_min > len(datasets): + raise Exception( + "Input file number (%s) is shorter than TRACK_DURATION_MIN (%s)." + % (len(datasets), nb_obs_min) + ) + + c = Correspondances(datasets=datasets["filename"], **kw_c) + c.track() + logger.info("Track finish") + kw_save = dict( + date_start=datasets["date"][0], + date_stop=datasets["date"][-1], + date_prod=datetime.now(), + path=output_dir, + sign_type=c.current_obs.sign_legend, + ) + + c.save(c_out, kw_save) + if correspondances_only: + return + + logger.info("Start merging") + c.prepare_merging() + logger.info("Longer track saved have %d obs", c.nb_obs_by_tracks.max()) + logger.info( + "The mean length is %d observations for all tracks", c.nb_obs_by_tracks.mean() + ) + + kw_write = dict(path=output_dir, zarr_flag=zarr) + + c.get_unused_data(raw_data=raw).write_file( + filename="%(path)s/%(sign_type)s_untracked.nc", **kw_write + ) + + short_c = c._copy() + short_c.shorter_than(size_max=nb_obs_min) + short_track = short_c.merge(raw_data=raw) + + if c.longer_than(size_min=nb_obs_min) is False: + long_track = short_track.empty_dataset() + else: + long_track = c.merge(raw_data=raw) + + # We flag obs + if c.virtual: + long_track["virtual"][:] = long_track["time"] == 0 + long_track.normalize_longitude() + long_track.filled_by_interpolation(long_track["virtual"] == 1) + short_track["virtual"][:] = short_track["time"] == 0 + short_track.normalize_longitude() + short_track.filled_by_interpolation(short_track["virtual"] == 1) + + logger.info("Longer track saved have %d obs", c.nb_obs_by_tracks.max()) + logger.info( + "The mean length is %d observations for long track", + c.nb_obs_by_tracks.mean(), + ) + + long_track.write_file(**kw_write) + short_track.write_file( + filename="%(path)s/%(sign_type)s_track_too_short.nc", **kw_write + ) + + +def get_group( + dataset1, + dataset2, + index1, + index2, + score, + invalid=2, + low=10, + high=60, +): + group1, group2 = dict(), dict() + m_valid = (score * 100) >= invalid + i1, i2, score = 
index1[m_valid], index2[m_valid], score[m_valid] * 100 + # Eddies with no association & scores < invalid + group1["nomatch"] = reverse_index(i1, len(dataset1)) + group2["nomatch"] = reverse_index(i2, len(dataset2)) + # Select all eddies involved in multiple associations + i1_, nb1 = unique(i1, return_counts=True) + i2_, nb2 = unique(i2, return_counts=True) + i1_multi = i1_[nb1 >= 2] + i2_multi = i2_[nb2 >= 2] + m_multi = in1d(i1, i1_multi) + in1d(i2, i2_multi) + + # Low scores + m_low = score < low + m_low *= ~m_multi + group1["low"] = i1[m_low] + group2["low"] = i2[m_low] + # Intermediate scores + m_i = (score >= low) * (score < high) + m_i *= ~m_multi + group1["intermediate"] = i1[m_i] + group2["intermediate"] = i2[m_i] + # High scores + m_high = score >= high + m_high *= ~m_multi + group1["high"] = i1[m_high] + group2["high"] = i2[m_high] + + # Here for a nice display order + group1["multi_match"] = unique(i1[m_multi]) + group2["multi_match"] = unique(i2[m_multi]) + + def get_twin(j2, j1): + # True only if j1 is used only one + m = bincount(j1)[j1] == 1 + # We keep only link of this mask j1 have exactly one parent + j2_ = j2[m] + # We count parent times + m_ = (bincount(j2_)[j2_] == 2) * (bincount(j2)[j2_] == 2) + # we fill first mask with second one + m[m] = m_ + return m + + m1 = get_twin(i1, i2) + m2 = get_twin(i2, i1) + group1["parent"] = unique(i1[m1]) + group2["parent"] = unique(i2[m2]) + group1["twin"] = i1[m2] + group2["twin"] = i2[m1] + + m = ~m1 * ~m2 * m_multi + group1["complex"] = unique(i1[m]) + group2["complex"] = unique(i2[m]) + + return group1, group2 + + +def run_compare(ref, others, invalid=1, low=20, high=80, intern=False, **kwargs): + groups_ref, groups_other = dict(), dict() + for i, (k, other) in enumerate(others.items()): + print(f"[{i}] {k} -> {len(other)} obs") + gr1, gr2 = get_group( + ref, + other, + *ref.match(other, intern=intern, **kwargs), + invalid=invalid, + low=low, + high=high, + ) + groups_ref[k] = gr1 + groups_other[k] = gr2 + return groups_ref, groups_other + + +def display_compare( + ref, others, invalid=1, low=20, high=80, area=False, intern=False, **kwargs +): + gr_ref, gr_others = run_compare( + ref, others, invalid=invalid, low=low, high=high, intern=intern, **kwargs + ) + + def display(value, ref=None): + outs = list() + for v in value: + if ref: + if area: + outs.append(f"{v / ref * 100:.1f}% ({v:.1f}Mkm²)") + else: + outs.append(f"{v/ref * 100:.1f}% ({v})") + else: + outs.append(v) + if area: + return "".join([f"{v:^16}" for v in outs]) + else: + return "".join([f"{v:^15}" for v in outs]) + + def get_values(v, dataset): + if area: + area_ = dataset["speed_area" if intern else "effective_area"] + return [area_[v_].sum() / 1e12 for v_ in v.values()] + else: + return [ + v_.sum() if v_.dtype == "bool" else v_.shape[0] for v_ in v.values() + ] + + labels = dict( + high=f"{high:0.0f} <= high", + low=f"{invalid:0.0f} <= low < {low:0.0f}", + ) + + keys = [labels.get(key, key) for key in list(gr_ref.values())[0].keys()] + print(" ", display(keys)) + if area: + ref_ = ref["speed_area" if intern else "effective_area"].sum() / 1e12 + else: + ref_ = len(ref) + for i, v in enumerate(gr_ref.values()): + print(f"[{i:2}] ", display(get_values(v, ref), ref=ref_)) + + print(" Point of view of study dataset") + print(" ", display(keys)) + for i, (k, v) in enumerate(gr_others.items()): + other = others[k] + if area: + ref_ = other["speed_area" if intern else "effective_area"].sum() / 1e12 + else: + ref_ = len(other) + print(f"[{i:2}] ", display(get_values(v, 
other), ref=ref_)) + + +def quick_compare(): + parser = EddyParser( + "Tool to have a quick comparison between several identifications" + ) + parser.add_argument("ref", help="Identification file of reference") + parser.add_argument("others", nargs="+", help="Identification files to compare") + help = "Display percent of area instead of percent of observations" + parser.add_argument("--area", action="store_true", help=help) + help = "Use minimal cost function" + parser.add_argument("--minimal_area", action="store_true", help=help) + parser.add_argument("--high", default=40, type=float) + parser.add_argument("--low", default=20, type=float) + parser.add_argument("--invalid", default=5, type=float) + parser.add_argument( + "--path_out", default=None, help="Save each group in a separate file" + ) + parser.contour_intern_arg() + args = parser.parse_args() + + kw = dict( + include_vars=[ + "longitude", + *EddiesObservations.intern(args.intern, public_label=True), + ] + ) + if args.area: + kw["include_vars"].append("speed_area" if args.intern else "effective_area") + + if args.path_out is not None: + kw = dict() + + ref = EddiesObservations.load_file(args.ref, **kw) + print(f"[ref] {args.ref} -> {len(ref)} obs") + others = {other: EddiesObservations.load_file(other, **kw) for other in args.others} + + kwargs = dict( + invalid=args.invalid, + low=args.low, + high=args.high, + intern=args.intern, + minimal_area=args.minimal_area, + ) + if args.path_out is not None: + groups_ref, groups_other = run_compare(ref, others, **kwargs) + if not exists(args.path_out): + mkdir(args.path_out) + for i, other_ in enumerate(args.others): + dirname_ = f"{args.path_out}/{other_.replace('/', '_')}/" + if not exists(dirname_): + mkdir(dirname_) + for k, v in groups_other[other_].items(): + basename_ = f"other_{k}.nc" + others[other_].index(v).write_file(filename=f"{dirname_}/{basename_}") + for k, v in groups_ref[other_].items(): + basename_ = f"ref_{k}.nc" + ref.index(v).write_file(filename=f"{dirname_}/{basename_}") + return + display_compare(ref, others, **kwargs, area=args.area) diff --git a/src/py_eddy_tracker/appli/grid.py b/src/py_eddy_tracker/appli/grid.py index 4223a236..099465ee 100644 --- a/src/py_eddy_tracker/appli/grid.py +++ b/src/py_eddy_tracker/appli/grid.py @@ -1,27 +1,10 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . - -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== +All entry points to manipulate grids """ -from datetime import datetime -from .. +from argparse import Action + +from ..
import EddyParser, identify_time from ..dataset.grid import RegularGridDataset, UnRegularGridDataset @@ -64,6 +47,17 @@ def grid_filtering(): h.write(args.filename_out) +class DictAction(Action): + def __call__(self, parser, namespace, values, option_string=None): + indexs = None + if len(values): + indexs = dict() + for value in values: + k, v = value.split("=") + indexs[k] = int(v) + setattr(namespace, self.dest, indexs) + + def eddy_id(args=None): parser = EddyParser("Eddy Identification") parser.add_argument("filename") @@ -74,28 +68,69 @@ def eddy_id(args=None): parser.add_argument("longitude") parser.add_argument("latitude") parser.add_argument("path_out") - help = "Wavelength for mesoscale filter in km" - parser.add_argument("--cut_wavelength", default=500, type=float, help=help) + help = ( + "Wavelength for mesoscale filter in km to remove low scale; if 2 args are given, " + "the first one will be used to remove high scale and the second value to remove low scale" + ) + parser.add_argument( + "--cut_wavelength", default=[500], type=float, help=help, nargs="+" + ) parser.add_argument("--filter_order", default=3, type=int) help = "Step between 2 isoline in m" parser.add_argument("--isoline_step", default=0.002, type=float, help=help) help = "Error max accepted to fit circle in percent" parser.add_argument("--fit_errmax", default=55, type=float, help=help) + parser.add_argument( + "--lat_max", default=85, type=float, help="Maximal latitude filtered" + ) parser.add_argument("--height_unit", default=None, help="Force height unit") parser.add_argument("--speed_unit", default=None, help="Force speed unit") parser.add_argument("--unregular", action="store_true", help="if grid is unregular") + help = "Array size used to build contour, first and last point will be the same" + parser.add_argument("--sampling", default=50, type=int, help=help) + parser.add_argument( + "--sampling_method", + default="visvalingam", + type=str, + choices=("visvalingam", "uniform"), + help="Method to resample contour", + ) help = "Output will be wrote in zarr" parser.add_argument("--zarr", action="store_true", help=help) + help = "Indexes to select grid: --indexs time=2 will select the third step along the time dimension" + parser.add_argument("--indexs", nargs="*", help=help, action=DictAction) + help = "Number of pixels of the detection grid which can be inside an eddy, you must specify MIN and MAX."
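+    # For example, with the default below only detections covering between 5 and 2000
+    # grid pixels are kept; "--pixel_limit 10 1000" would narrow that range.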
+ parser.add_argument( + "--pixel_limit", nargs="+", default=(5, 2000), type=int, help=help + ) + help = "Minimal number of amplitude in number of step" + parser.add_argument("--nb_step_min", default=2, type=int, help=help) args = parser.parse_args(args) if args else parser.parse_args() - date = datetime.strptime(args.datetime, "%Y%m%d") + if len(args.pixel_limit) != 2: + raise Exception( + "You must define two value minimal number of pixel and maximal number of pixel" + ) + + cut_wavelength = args.cut_wavelength + nb_cw = len(cut_wavelength) + if nb_cw > 2 or nb_cw == 0: + raise Exception("You must specify 1 or 2 values for cut wavelength.") + elif nb_cw == 1: + cut_wavelength = [0, *cut_wavelength] + inf_bnds, upper_bnds = cut_wavelength + + date = identify_time(args.datetime) kwargs = dict( step=args.isoline_step, shape_error=args.fit_errmax, - pixel_limit=(5, 2000), + pixel_limit=args.pixel_limit, force_height_unit=args.height_unit, force_speed_unit=args.speed_unit, + nb_step_to_be_mle=0, + nb_step_min=args.nb_step_min, ) + a, c = identification( args.filename, args.longitude, @@ -105,11 +140,16 @@ def eddy_id(args=None): args.u, args.v, unregular=args.unregular, - cut_wavelength=args.cut_wavelength, + cut_wavelength=upper_bnds, + cut_highwavelength=inf_bnds, + lat_max=args.lat_max, filter_order=args.filter_order, - **kwargs + indexs=args.indexs, + sampling=args.sampling, + sampling_method=args.sampling_method, + **kwargs, ) - out_name = date.strftime("%(path)s/%(sign_type)s_%Y%m%d.nc") + out_name = date.strftime("%(path)s/%(sign_type)s_%Y%m%dT%H%M%S.nc") a.write_file(path=args.path_out, filename=out_name, zarr_flag=args.zarr) c.write_file(path=args.path_out, filename=out_name, zarr_flag=args.zarr) @@ -124,14 +164,20 @@ def identification( v="None", unregular=False, cut_wavelength=500, + cut_highwavelength=0, + lat_max=85, filter_order=1, + indexs=None, **kwargs ): grid_class = UnRegularGridDataset if unregular else RegularGridDataset - grid = grid_class(filename, lon, lat) + grid = grid_class(filename, lon, lat, indexs=indexs) if u == "None" and v == "None": grid.add_uv(h) u, v = "u", "v" + kw_filter = dict(order=filter_order, lat_max=lat_max) + if cut_highwavelength != 0: + grid.bessel_low_filter(h, cut_highwavelength, **kw_filter) if cut_wavelength != 0: - grid.bessel_high_filter(h, cut_wavelength, order=filter_order) + grid.bessel_high_filter(h, cut_wavelength, **kw_filter) return grid.eddy_identification(h, u, v, date, **kwargs) diff --git a/src/py_eddy_tracker/appli/gui.py b/src/py_eddy_tracker/appli/gui.py index 996f43fe..c3d7619b 100644 --- a/src/py_eddy_tracker/appli/gui.py +++ b/src/py_eddy_tracker/appli/gui.py @@ -1,75 +1,127 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . 
- -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== +Entry point of graphic user interface """ -from numpy import arange, empty +from datetime import datetime, timedelta +from itertools import chain +import logging + from matplotlib import pyplot +from matplotlib.animation import FuncAnimation +from matplotlib.axes import Axes from matplotlib.collections import LineCollection -from datetime import datetime -from ..poly import create_vertice -from ..generic import flatten_line_matrix +from numpy import arange, where, nan + from .. import EddyParser +from ..gui import GUI from ..observations.tracking import TrackEddiesObservations +from ..poly import create_vertice + +logger = logging.getLogger("pet") class Anim: - def __init__(self, eddy, intern=False, sleep_event=0.1, **kwargs): + def __init__( + self, eddy, intern=False, sleep_event=0.1, graphic_information=False, **kwargs + ): self.eddy = eddy x_name, y_name = eddy.intern(intern) self.t, self.x, self.y = eddy.time, eddy[x_name], eddy[y_name] + self.x_core, self.y_core, self.track = eddy["lon"], eddy["lat"], eddy["track"] + self.graphic_informations = graphic_information self.pause = False self.period = self.eddy.period self.sleep_event = sleep_event + self.mappables = list() + self.field_color = None + self.field_txt = None + self.time_field = False + self.txt = None + self.ax = None + self.kw_label = dict() self.setup(**kwargs) - def setup(self, cmap="jet", nb_step=25, figsize=(8, 6)): - cmap = pyplot.get_cmap(cmap) - self.colors = cmap(arange(nb_step + 1) / nb_step) + def setup( + self, + cmap="jet", + lut=None, + field_color="time", + field_txt="track", + range_color=(None, None), + nb_step=25, + figsize=(8, 6), + position=(0.05, 0.05, 0.9, 0.9), + **kwargs, + ): + self.kw_label["fontsize"] = kwargs.pop("fontsize", 12) + self.kw_label["fontweight"] = kwargs.pop("fontweight", "demibold") + # To text each visible eddy + if field_txt: + if isinstance(field_txt,str): + self.field_txt = self.eddy[field_txt] + else : + self.field_txt=field_txt + if field_color: + # To color each visible eddy + self.field_color = self.eddy[field_color].astype("f4") + rg = range_color + if rg[0] is None and rg[1] is None and field_color == "time": + self.time_field = True + else: + rg = ( + self.field_color.min() if rg[0] is None else rg[0], + self.field_color.max() if rg[1] is None else rg[1], + ) + self.field_color = (self.field_color - rg[0]) / (rg[1] - rg[0]) + self.colors = pyplot.get_cmap(cmap, lut=lut) self.nb_step = nb_step - x_min, x_max = self.x.min(), self.x.max() - d_x = x_max - x_min - x_min -= 0.05 * d_x - x_max += 0.05 * d_x - y_min, y_max = self.y.min(), self.y.max() - d_y = y_max - y_min - y_min -= 0.05 * d_y - y_max += 0.05 * d_y - # plot - self.fig = pyplot.figure(figsize=figsize) + if "figure" in kwargs: + self.fig = kwargs.pop("figure") + else: + self.fig = pyplot.figure(figsize=figsize, **kwargs) t0, t1 = self.period - self.fig.suptitle(f'{t0} -> {t1}') - self.ax = self.fig.add_axes((0.05, 0.05, 0.9, 0.9)) - self.ax.set_xlim(x_min, x_max), self.ax.set_ylim(y_min, y_max) - self.ax.set_aspect("equal") - self.ax.grid() - # init mappable - self.txt = self.ax.text(x_min + 0.05 * d_x, y_min + 0.05 * d_y, "", zorder=10) - self.contour = LineCollection([], zorder=1) + self.fig.suptitle(f"{t0} -> {t1}") + if isinstance(position, Axes): + self.ax = position + else: + x_min, x_max = self.x_core.min() - 2, self.x_core.max() + 2 + d_x = x_max - x_min + 
y_min, y_max = self.y_core.min() - 2, self.y_core.max() + 2 + d_y = y_max - y_min + self.ax = self.fig.add_axes(position, projection="full_axes") + self.ax.set_xlim(x_min, x_max), self.ax.set_ylim(y_min, y_max) + self.ax.set_aspect("equal") + self.ax.grid() + self.txt = self.ax.text( + x_min + 0.05 * d_x, y_min + 0.05 * d_y, "", zorder=10 + ) + self.segs = list() + self.t_segs = list() + self.c_segs = list() + if field_color is None: + self.contour = LineCollection([], zorder=1, color="gray") + else: + self.contour = LineCollection([], zorder=1) self.ax.add_collection(self.contour) self.fig.canvas.draw() self.fig.canvas.mpl_connect("key_press_event", self.keyboard) + self.fig.canvas.mpl_connect("resize_event", self.reset_bliting) + + def reset_bliting(self, event): + self.contour.set_visible(False) + self.txt.set_visible(False) + for m in self.mappables: + m.set_visible(False) + self.fig.canvas.draw() + self.bg_cache = self.fig.canvas.copy_from_bbox(self.ax.bbox) + self.contour.set_visible(True) + self.txt.set_visible(True) + for m in self.mappables: + m.set_visible(True) def show(self, infinity_loop=False): pyplot.show(block=False) @@ -79,7 +131,6 @@ def show(self, infinity_loop=False): loop = True t0, t1 = self.period while loop: - self.segs = list() self.now = t0 while True: dt = self.sleep_event @@ -91,8 +142,9 @@ def show(self, infinity_loop=False): if dt < 0: # self.sleep_event = dt_draw * 1.01 dt = 1e-10 + if dt == 0: + dt = 1e-10 self.fig.canvas.start_event_loop(dt) - if self.now > t1: break if infinity_loop: @@ -108,28 +160,81 @@ def prev(self): self.now -= 1 return self.draw_contour() - def draw_contour(self): - t0, t1 = self.period - # select contour for this time step + def func_animation(self, frame): + while self.mappables: + self.mappables.pop().remove() + self.now = frame + self.update() + artists = [self.contour, self.txt] + artists.extend(self.mappables) + return artists + + def update(self): m = self.t == self.now - self.ax.figure.canvas.restore_region(self.bg_cache) + color = self.field_color is not None if m.sum(): - self.segs.append( - create_vertice( - flatten_line_matrix(self.x[m]), flatten_line_matrix(self.y[m]) + segs = list() + t = list() + c = list() + for i in where(m)[0]: + segs.append(create_vertice(self.x[i], self.y[i])) + if color: + c.append(self.field_color[i]) + t.append(self.now) + self.segs.append(segs) + if color: + self.c_segs.append(c) + self.t_segs.append(t) + self.contour.set_paths(chain(*self.segs)) + if color: + if self.time_field: + self.contour.set_color( + self.colors( + [ + (self.nb_step - self.now + i) / self.nb_step + for i in chain(*self.c_segs) + ] + ) ) - ) - else: - self.segs.append(empty((0, 2))) - self.contour.set_paths(self.segs) - self.contour.set_color(self.colors[-len(self.segs) :]) - self.contour.set_lw(arange(len(self.segs)) / len(self.segs) * 2.5) - self.txt.set_text(f"{self.now} - {1/self.sleep_event:.0f} frame/s") + else: + self.contour.set_color(self.colors(list(chain(*self.c_segs)))) + # linewidth will be link to time delay + self.contour.set_lw( + [ + (1 - (self.now - i) / self.nb_step) * 2.5 if i <= self.now else 0 + for i in chain(*self.t_segs) + ] + ) + # Update date txt and info + if self.txt is not None: + txt = f"{(timedelta(int(self.now)) + datetime(1950,1,1)).strftime('%Y/%m/%d')}" + if self.graphic_informations: + txt += f"- {1/self.sleep_event:.0f} frame/s" + self.txt.set_text(txt) + self.ax.draw_artist(self.txt) + # Update id txt + if self.field_txt is not None: + for i in where(m)[0]: + mappable = 
self.ax.text( + self.x_core[i], self.y_core[i], self.field_txt[i], **self.kw_label + ) + self.mappables.append(mappable) + self.ax.draw_artist(mappable) self.ax.draw_artist(self.contour) - self.ax.draw_artist(self.txt) + # Remove first segment to keep only T contour if len(self.segs) > self.nb_step: self.segs.pop(0) + self.t_segs.pop(0) + if color: + self.c_segs.pop(0) + + def draw_contour(self): + # select contour for this time step + while self.mappables: + self.mappables.pop().remove() + self.ax.figure.canvas.restore_region(self.bg_cache) + self.update() # paint updated artist self.ax.figure.canvas.blit(self.ax.bbox) @@ -145,8 +250,13 @@ def keyboard(self, event): elif event.key == "right" and self.pause: self.next() elif event.key == "left" and self.pause: + # we remove 2 step to add 1 so we rewind of only one self.segs.pop(-1) self.segs.pop(-1) + self.t_segs.pop(-1) + self.t_segs.pop(-1) + self.c_segs.pop(-1) + self.c_segs.pop(-1) self.prev() @@ -156,16 +266,13 @@ def anim(): left arrow => t - 1, right arrow => t + 1, + => speed increase of 10 %, - => speed decrease of 10 %""" ) parser.add_argument("filename", help="eddy atlas") - parser.add_argument("id", help="Track id to anim", type=int) - parser.add_argument( - "--intern", - action="store_true", - help="display intern contour inplace of outter contour", - ) + parser.add_argument("id", help="Track id to anim", type=int, nargs="*") + parser.contour_intern_arg() parser.add_argument( "--keep_step", default=25, help="number maximal of step displayed", type=int ) parser.add_argument("--cmap", help="matplotlib colormap used") + parser.add_argument("--all", help="All eddies will be drawed", action="store_true") parser.add_argument( "--time_sleep", type=float, @@ -175,17 +282,84 @@ def anim(): parser.add_argument( "--infinity_loop", action="store_true", help="Press Escape key to stop loop" ) + parser.add_argument( + "--first_centered", + action="store_true", + help="Longitude will be centered on first obs.", + ) + parser.add_argument( + "--field", default="time", help="Field use to color contour instead of time" + ) + parser.add_argument("--txt_field", default="track", help="Field use to text eddy") + parser.add_argument( + "--vmin", default=None, type=float, help="Inferior bound to color contour" + ) + parser.add_argument( + "--vmax", default=None, type=float, help="Upper bound to color contour" + ) + parser.add_argument("--mp4", help="Filename to save animation (mp4)") args = parser.parse_args() - variables = ["time", "track"] + variables = list( + set(["time", "track", "longitude", "latitude", args.field, args.txt_field]) + ) variables.extend(TrackEddiesObservations.intern(args.intern, public_label=True)) - atlas = TrackEddiesObservations.load_file(args.filename, include_vars=variables) - eddy = atlas.extract_ids([args.id]) + eddies = TrackEddiesObservations.load_file( + args.filename, include_vars=set(variables) + ) + if not args.all: + if len(args.id) == 0: + raise Exception( + "You need to specify id to display or ask explicity all with --all option" + ) + eddies = eddies.extract_ids(args.id) + if args.first_centered: + # TODO: include to observation class + x0 = eddies.lon[0] + eddies.lon[:] = (eddies.lon - x0 + 180) % 360 + x0 - 180 + eddies.contour_lon_e[:] = ( + (eddies.contour_lon_e.T - eddies.lon + 180) % 360 + eddies.lon - 180 + ).T + + kw = dict() + if args.mp4: + kw["figsize"] = (16, 9) + kw["dpi"] = 120 a = Anim( - eddy, + eddies, intern=args.intern, sleep_event=args.time_sleep, cmap=args.cmap, nb_step=args.keep_step, + 
field_color=args.field, + field_txt=args.txt_field, + range_color=(args.vmin, args.vmax), + graphic_information=logger.getEffectiveLevel() == logging.DEBUG, + **kw, ) - a.show(infinity_loop=args.infinity_loop) + if args.mp4 is None: + a.show(infinity_loop=args.infinity_loop) + else: + kwargs = dict(frames=arange(*a.period), interval=50) + ani = FuncAnimation(a.fig, a.func_animation, **kwargs) + ani.save(args.mp4, fps=30, extra_args=["-vcodec", "libx264"]) + + +def gui_parser(): + parser = EddyParser("Eddy atlas GUI") + parser.add_argument("atlas", nargs="+") + parser.add_argument("--med", action="store_true") + parser.add_argument("--nopath", action="store_true", help="Don't draw path") + return parser.parse_args() + + +def guieddy(): + args = gui_parser() + atlas = { + dataset: TrackEddiesObservations.load_file(dataset) for dataset in args.atlas + } + g = GUI(**atlas) + if args.med: + g.med() + g.hide_path(not args.nopath) + g.show() diff --git a/src/py_eddy_tracker/appli/misc.py b/src/py_eddy_tracker/appli/misc.py index f3a292ba..ad7a71e5 100644 --- a/src/py_eddy_tracker/appli/misc.py +++ b/src/py_eddy_tracker/appli/misc.py @@ -1,26 +1,9 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . - -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== +Entry point with no direct link with eddies """ import argparse + import zarr diff --git a/src/py_eddy_tracker/appli/network.py b/src/py_eddy_tracker/appli/network.py index 191fc65d..0a3d06ca 100644 --- a/src/py_eddy_tracker/appli/network.py +++ b/src/py_eddy_tracker/appli/network.py @@ -1,35 +1,15 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . - -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== +Entry point to create and manipulate observations network """ import logging -from netCDF4 import Dataset -from numpy import empty, arange, zeros -from Polygon import Polygon -from ..poly import polygon_overlap, create_vertice_from_2darray + +from numpy import in1d, zeros + from .. 
import EddyParser -from ..observations.network import Network +from ..observations.network import Network, NetworkObservations from ..observations.tracking import TrackEddiesObservations -from ..generic import build_index logger = logging.getLogger("pet") @@ -37,286 +17,286 @@ def build_network(): parser = EddyParser("Merge eddies") parser.add_argument( - "identification_regex", help="Give an expression which will use with glob", + "identification_regex", help="Give an expression which will use with glob" ) parser.add_argument("out", help="output file") parser.add_argument( "--window", "-w", type=int, help="Half time window to search eddy", default=1 ) + parser.add_argument( - "--intern", + "--min-overlap", + "-p", + type=float, + help="minimum overlap area to associate observations", + default=0.2, + ) + parser.add_argument( + "--minimal-area", action="store_true", - help="Use intern contour instead of outter contour", + help="If True, use intersection/little polygon, else intersection/union", ) + parser.add_argument( + "--hybrid-area", + action="store_true", + help="If True, use minimal-area method if overlap is under min overlap, else intersection/union", + ) + + parser.contour_intern_arg() + + parser.memory_arg() args = parser.parse_args() - n = Network(args.identification_regex, window=args.window, intern=args.intern) - group = n.group_observations(minimal_area=True) + n = Network( + args.identification_regex, + window=args.window, + intern=args.intern, + memory=args.memory, + ) + group = n.group_observations( + min_overlap=args.min_overlap, + minimal_area=args.minimal_area, + hybrid_area=args.hybrid_area, + ) n.build_dataset(group).write_file(filename=args.out) def divide_network(): - parser = EddyParser("Separate path for a same group") + parser = EddyParser("Separate path for a same group (network)") parser.add_argument("input", help="input network file") parser.add_argument("out", help="output file") + parser.contour_intern_arg() + parser.add_argument( + "--window", "-w", type=int, help="Half time window to search eddy", default=1 + ) parser.add_argument( - "--intern", + "--min-overlap", + "-p", + type=float, + help="minimum overlap area to associate observations", + default=0.2, + ) + parser.add_argument( + "--minimal-area", action="store_true", - help="Use intern contour instead of outter contour", + help="If True, use intersection/little polygon, else intersection/union", ) parser.add_argument( - "--window", "-w", type=int, help="Half time window to search eddy", default=1 + "--hybrid-area", + action="store_true", + help="If True, use minimal-area method if overlap is under min overlap, else intersection/union", ) args = parser.parse_args() contour_name = TrackEddiesObservations.intern(args.intern, public_label=True) e = TrackEddiesObservations.load_file( - args.input, include_vars=("time", "track", *contour_name) + args.input, + include_vars=("time", "track", "latitude", "longitude", *contour_name), ) - e.split_network(intern=args.intern, window=args.window) - # split_network(args.input, args.out) - - -def split_network(input, output): - """Divide each group in track - """ - sl = slice(None) - with Dataset(input) as h: - group = h.variables["track"][sl] - track_s, track_e, track_ref = build_index(group) - # nb = track_e - track_s - # m = nb > 1500 - # print(group[track_s[m]]) - - track_id = 12003 - sls = [slice(track_s[track_id - track_ref], track_e[track_id - track_ref], None)] - for sl in sls: - - print(sl) - with Dataset(input) as h: - time = h.variables["time"][sl] - group 
= h.variables["track"][sl] - x = h.variables["effective_contour_longitude"][sl] - y = h.variables["effective_contour_latitude"][sl] - print(group[0]) - ids = empty( - time.shape, - dtype=[ - ("group", group.dtype), - ("time", time.dtype), - ("track", "u2"), - ("previous_cost", "f4"), - ("next_cost", "f4"), - ("previous_observation", "i4"), - ("next_observation", "i4"), - ], - ) - ids["group"] = group - ids["time"] = time - # To store id track - ids["track"] = 0 - ids["previous_cost"] = 0 - ids["next_cost"] = 0 - ids["previous_observation"] = -1 - ids["next_observation"] = -1 - # Cost with previous - track_start, track_end, track_ref = build_index(group) - for i0, i1 in zip(track_start, track_end): - if (i1 - i0) == 0 or group[i0] == Network.NOGROUP: - continue - sl_group = slice(i0, i1) - set_tracks( - x[sl_group], - y[sl_group], - time[sl_group], - i0, - ids["track"][sl_group], - ids["previous_cost"][sl_group], - ids["next_cost"][sl_group], - ids["previous_observation"][sl_group], - ids["next_observation"][sl_group], - window=5, - ) + n = NetworkObservations.from_split_network( + TrackEddiesObservations.load_file(args.input, raw_data=True), + e.split_network( + intern=args.intern, + window=args.window, + min_overlap=args.min_overlap, + minimal_area=args.minimal_area, + hybrid_area=args.hybrid_area, + ), + ) + n.write_file(filename=args.out) - new_i = ids.argsort(order=("group", "track", "time")) - ids_sort = ids[new_i] - # To be able to follow indices sorting - reverse_sort = empty(new_i.shape[0], dtype="u4") - reverse_sort[new_i] = arange(new_i.shape[0]) - # Redirect indices - m = ids_sort["next_observation"] != -1 - ids_sort["next_observation"][m] = reverse_sort[ids_sort["next_observation"][m]] - m = ids_sort["previous_observation"] != -1 - ids_sort["previous_observation"][m] = reverse_sort[ - ids_sort["previous_observation"][m] + +def subset_network(): + parser = EddyParser("Subset network") + parser.add_argument("input", help="input network file") + parser.add_argument("out", help="output file") + parser.add_argument( + "-l", + "--length", + nargs=2, + type=int, + help="Nb of days that must be covered by the network, first minimum number of day and last maximum number of day," + "if value is negative, this bound won't be used", + ) + parser.add_argument( + "--remove_dead_end", + nargs=2, + type=int, + help="Remove short dead end, first is for minimal obs number and second for minimal segment time to keep", + ) + parser.add_argument( + "--remove_trash", + action="store_true", + help="Remove trash (network id == 0)", + ) + parser.add_argument( + "-i", "--ids", nargs="+", type=int, help="List of network which will be extract" + ) + parser.add_argument( + "-p", + "--period", + nargs=2, + type=int, + help="Start day and end day, if it's a negative value we will add to day min and add to day max," + "if 0 it is not used", + ) + args = parser.parse_args() + n = NetworkObservations.load_file(args.input, raw_data=True) + if args.ids is not None: + n = n.networks(args.ids) + if args.length is not None: + n = n.longer_than(*args.length) + if args.remove_dead_end is not None: + n = n.remove_dead_end(*args.remove_dead_end) + if args.period is not None: + n = n.extract_with_period(args.period) + n.write_file(filename=args.out) + + +def quick_compare(): + parser = EddyParser( + """Tool to have a quick comparison between several network: + - N : network + - S : segment + - Obs : observations + """ + ) + parser.add_argument("ref", help="Identification file of reference") + 
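+    # the first file is taken as the reference atlas; every following file is compared against it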
parser.add_argument("others", nargs="+", help="Identifications files to compare") + parser.add_argument( + "--path_out", default=None, help="Save each group in separate file" + ) + args = parser.parse_args() + + kw = dict( + include_vars=[ + "longitude", + "latitude", + "time", + "track", + "segment", + "next_obs", + "previous_obs", ] - # print(ids_sort) - display_network( - x[new_i], - y[new_i], - ids_sort["track"], - ids_sort["time"], - ids_sort["next_cost"], - ) + ) + if args.path_out is not None: + kw = dict() -def next_obs( - i_current, next_cost, previous_cost, polygons, t, t_start, t_end, t_ref, window -): - t_max = t_end.shape[0] - 1 - t_cur = t[i_current] - t0, t1 = t_cur + 1 - t_ref, t_cur + window - t_ref - if t0 > t_max: - return -1 - t1 = min(t1, t_max) - for t_step in range(t0, t1 + 1): - i0, i1 = t_start[t_step], t_end[t_step] - # No observation at the time step ! - if i0 == i1: - continue - sl = slice(i0, i1) - # Intersection / union, to be able to separte in case of multiple inside - c = polygon_overlap(polygons[i_current], polygons[sl]) - # We remove low overlap - if (c > 0.1).sum() > 1: - print(c) - c[c < 0.1] = 0 - # We get index of maximal overlap - i = c.argmax() - c_i = c[i] - # No overlap found - if c_i == 0: - continue - target = i0 + i - # Check if candidate is already used - c_target = previous_cost[target] - if (c_target != 0 and c_target < c_i) or c_target == 0: - previous_cost[target] = c_i - next_cost[i_current] = c_i - return target - return -1 - - -def set_tracks( - x, - y, - t, - ref_index, - track, - previous_cost, - next_cost, - previous_observation, - next_observation, - window, -): - # Will split one group in tracks - t_start, t_end, t_ref = build_index(t) - nb = x.shape[0] - used = zeros(nb, dtype="bool") - current_track = 1 - # build all polygon (need to check if wrap is needed) - polygons = list() - for i in range(nb): - polygons.append(Polygon(create_vertice_from_2darray(x, y, i))) - - for i in range(nb): - # If observation already in one track, we go to the next one - if used[i]: - continue - build_track( - i, - current_track, - used, - track, - previous_observation, - next_observation, - ref_index, - next_cost, - previous_cost, - polygons, - t, - t_start, - t_end, - t_ref, - window, - ) - current_track += 1 - - -def build_track( - first_index, - track_id, - used, - track, - previous_observation, - next_observation, - ref_index, - next_cost, - previous_cost, - *args, -): - i_next = first_index - while i_next != -1: - # Flag - used[i_next] = True - # Assign id - track[i_next] = track_id - # Search next - i_next_ = next_obs(i_next, next_cost, previous_cost, *args) - if i_next_ == -1: - break - next_observation[i_next] = i_next_ + ref_index - if not used[i_next_]: - previous_observation[i_next_] = i_next + ref_index - # Target was previously used - if used[i_next_]: - if next_cost[i_next] == previous_cost[i_next_]: - m = track[i_next_:] == track[i_next_] - track[i_next_:][m] = track_id - previous_observation[i_next_] = i_next + ref_index - i_next_ = -1 - i_next = i_next_ - - -def display_network(x, y, tr, t, c): - tr0, tr1, t_ref = build_index(tr) - import matplotlib.pyplot as plt - - cmap = plt.get_cmap("jet") - from ..generic import flatten_line_matrix - - fig = plt.figure(figsize=(20, 10)) - ax = fig.add_subplot(121, aspect="equal") - ax.grid() - ax_time = fig.add_subplot(122) - ax_time.grid() - i = 0 - for s, e in zip(tr0, tr1): - if s == e: - continue - sl = slice(s, e) - color = cmap((tr[s] - tr[tr0[0]]) / (tr[tr0[-1]] - tr[tr0[0]])) - ax.plot( 
- flatten_line_matrix(x[sl]), - flatten_line_matrix(y[sl]), - color=color, - label=f"{tr[s]} - {e-s} obs from {t[s]} to {t[e-1]}", - ) - i += 1 - ax_time.plot( - t[sl], - tr[s].repeat(e - s) + c[sl], - color=color, - label=f"{tr[s]} - {e-s} obs", - lw=0.5, - ) - ax_time.plot( - t[sl], tr[s].repeat(e - s), color=color, lw=1, marker="+", - ) - ax_time.text(t[s], tr[s] + 0.15, f"{x[s].mean():.2f}, {y[s].mean():.2f}") - ax_time.axvline(t[s], color=".75", lw=0.5, ls="--", zorder=-10) - ax_time.text( - t[e - 1], tr[e - 1] - 0.25, f"{x[e-1].mean():.2f}, {y[e-1].mean():.2f}" + ref = NetworkObservations.load_file(args.ref, **kw) + print( + f"[ref] {args.ref} -> {ref.nb_network} network / {ref.nb_segment} segment / {len(ref)} obs " + f"-> {ref.network_size(0)} trash obs, " + f"{len(ref.merging_event())} merging, {len(ref.splitting_event())} spliting" + ) + others = { + other: NetworkObservations.load_file(other, **kw) for other in args.others + } + + # if args.path_out is not None: + # groups_ref, groups_other = run_compare(ref, others, **kwargs) + # if not exists(args.path_out): + # mkdir(args.path_out) + # for i, other_ in enumerate(args.others): + # dirname_ = f"{args.path_out}/{other_.replace('/', '_')}/" + # if not exists(dirname_): + # mkdir(dirname_) + # for k, v in groups_other[other_].items(): + # basename_ = f"other_{k}.nc" + # others[other_].index(v).write_file(filename=f"{dirname_}/{basename_}") + # for k, v in groups_ref[other_].items(): + # basename_ = f"ref_{k}.nc" + # ref.index(v).write_file(filename=f"{dirname_}/{basename_}") + # return + display_compare(ref, others) + + +def run_compare(ref, others): + outs = dict() + for i, (k, other) in enumerate(others.items()): + out = dict() + print( + f"[{i}] {k} -> {other.nb_network} network / {other.nb_segment} segment / {len(other)} obs " + f"-> {other.network_size(0)} trash obs, " + f"{len(other.merging_event())} merging, {len(other.splitting_event())} spliting" ) - ax.legend() - ax_time.legend() - plt.show() + ref_id, other_id = ref.identify_in(other, size_min=2) + m = other_id != -1 + ref_id, other_id = ref_id[m], other_id[m] + out["same N(N)"] = m.sum() + out["same N(Obs)"] = ref.network_size(ref_id).sum() + + # For network which have same obs + ref_, other_ = ref.networks(ref_id), other.networks(other_id) + ref_segu, other_segu = ref_.identify_in(other_, segment=True) + m = other_segu == -1 + ref_track_no_match, _ = ref_.unique_segment_to_id(ref_segu[m]) + ref_segu, other_segu = ref_segu[~m], other_segu[~m] + m = ~in1d(ref_id, ref_track_no_match) + out["same NS(N)"] = m.sum() + out["same NS(Obs)"] = ref.network_size(ref_id[m]).sum() + + # Check merge/split + def follow_obs(d, i_follow): + m = i_follow != -1 + i_follow = i_follow[m] + t, x, y = ( + zeros(m.size, d.time.dtype), + zeros(m.size, d.longitude.dtype), + zeros(m.size, d.latitude.dtype), + ) + t[m], x[m], y[m] = ( + d.time[i_follow], + d.longitude[i_follow], + d.latitude[i_follow], + ) + return t, x, y + + def next_obs(d, i_seg): + last_i = d.index_segment_track[1][i_seg] - 1 + return follow_obs(d, d.next_obs[last_i]) + + def previous_obs(d, i_seg): + first_i = d.index_segment_track[0][i_seg] + return follow_obs(d, d.previous_obs[first_i]) + + tref, xref, yref = next_obs(ref_, ref_segu) + tother, xother, yother = next_obs(other_, other_segu) + + m = (tref == tother) & (xref == xother) & (yref == yother) + print(m.sum(), m.size, ref_segu.size, ref_track_no_match.size) + + tref, xref, yref = previous_obs(ref_, ref_segu) + tother, xother, yother = previous_obs(other_, other_segu) 
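+        # same consistency check for the observation linked upstream of each segment (previous_obs):
+        # time and position must match in both atlases for the link to be considered identical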
+ + m = (tref == tother) & (xref == xother) & (yref == yother) + print(m.sum(), m.size, ref_segu.size, ref_track_no_match.size) + + ref_segu, other_segu = ref.identify_in(other, segment=True) + m = other_segu != -1 + out["same S(S)"] = m.sum() + out["same S(Obs)"] = ref.segment_size()[ref_segu[m]].sum() + + outs[k] = out + return outs + + +def display_compare(ref, others): + def display(value, ref=None): + if ref: + outs = [f"{v / ref[k] * 100:.1f}% ({v})" for k, v in value.items()] + else: + outs = value + return "".join([f"{v:^18}" for v in outs]) + + datas = run_compare(ref, others) + ref_ = { + "same N(N)": ref.nb_network, + "same N(Obs)": len(ref), + "same NS(N)": ref.nb_network, + "same NS(Obs)": len(ref), + "same S(S)": ref.nb_segment, + "same S(Obs)": len(ref), + } + print(" ", display(ref_.keys())) + for i, (_, v) in enumerate(datas.items()): + print(f"[{i:2}] ", display(v, ref=ref_)) diff --git a/src/py_eddy_tracker/data/20160707000000-GOS-L4_GHRSST-SSTfnd-OISST_HR_REP-BLK-v02.0-fv01.0.nc b/src/py_eddy_tracker/data/20160707000000-GOS-L4_GHRSST-SSTfnd-OISST_HR_REP-BLK-v02.0-fv01.0.nc new file mode 100644 index 00000000..cdc2f59f Binary files /dev/null and b/src/py_eddy_tracker/data/20160707000000-GOS-L4_GHRSST-SSTfnd-OISST_HR_REP-BLK-v02.0-fv01.0.nc differ diff --git a/src/py_eddy_tracker/data/Anticyclonic_20160515.nc b/src/py_eddy_tracker/data/Anticyclonic_20160515.nc new file mode 100644 index 00000000..a1c2a922 Binary files /dev/null and b/src/py_eddy_tracker/data/Anticyclonic_20160515.nc differ diff --git a/src/py_eddy_tracker/data/Anticyclonic_20190223.nc b/src/py_eddy_tracker/data/Anticyclonic_20190223.nc index ddedbbd0..4ab8f226 100644 Binary files a/src/py_eddy_tracker/data/Anticyclonic_20190223.nc and b/src/py_eddy_tracker/data/Anticyclonic_20190223.nc differ diff --git a/src/py_eddy_tracker/data/Cyclonic_20160515.nc b/src/py_eddy_tracker/data/Cyclonic_20160515.nc new file mode 100644 index 00000000..4af9c7af Binary files /dev/null and b/src/py_eddy_tracker/data/Cyclonic_20160515.nc differ diff --git a/src/py_eddy_tracker/data/Cyclonic_20190223.nc b/src/py_eddy_tracker/data/Cyclonic_20190223.nc index 4c3cb56f..a133e4ff 100644 Binary files a/src/py_eddy_tracker/data/Cyclonic_20190223.nc and b/src/py_eddy_tracker/data/Cyclonic_20190223.nc differ diff --git a/src/py_eddy_tracker/data/__init__.py b/src/py_eddy_tracker/data/__init__.py index f5d3aaae..bf062983 100644 --- a/src/py_eddy_tracker/data/__init__.py +++ b/src/py_eddy_tracker/data/__init__.py @@ -1,22 +1,48 @@ -from os import path -import requests +""" +EddyId \ + nrt_global_allsat_phy_l4_20190223_20190226.nc \ + 20190223 adt ugos vgos longitude latitude . \ + --cut 800 --fil 1 +EddyId \ + dt_med_allsat_phy_l4_20160515_20190101.nc \ + 20160515 adt None None longitude latitude . 
\ + --cut 800 --fil 1 +""" + import io -import tarfile import lzma +from os import path +import tarfile + +import requests -def get_path(name): +def get_demo_path(name): return path.join(path.dirname(__file__), name) -def get_remote_sample(path): - url = ( - f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}.tar.xz" - ) - - content = requests.get(url).content +def get_remote_demo_sample(path): + if path.startswith("/") or path.startswith("."): + content = open(path, "rb").read() + if path.endswith(".nc"): + return io.BytesIO(content) + else: + try: + import py_eddy_tracker_sample_id + if path.endswith(".nc"): + return py_eddy_tracker_sample_id.get_remote_demo_sample(path) + content = open(py_eddy_tracker_sample_id.get_remote_demo_sample(f"{path}.tar.xz"), "rb").read() + except: + if path.endswith(".nc"): + content = requests.get( + f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}" + ).content + return io.BytesIO(content) + content = requests.get( + f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}.tar.xz" + ).content - # Tar module could manage lzma tar, but it will apply un compress for each extractfile + # Tar module could manage lzma tar, but it will apply uncompress for each extractfile tar = tarfile.open(mode="r", fileobj=io.BytesIO(lzma.decompress(content))) # tar = tarfile.open(mode="r:xz", fileobj=io.BytesIO(content)) files_content = list() diff --git a/src/py_eddy_tracker/data/dt_blacksea_allsat_phy_l4_20160707_20200801.nc b/src/py_eddy_tracker/data/dt_blacksea_allsat_phy_l4_20160707_20200801.nc new file mode 100644 index 00000000..2b22e6ba Binary files /dev/null and b/src/py_eddy_tracker/data/dt_blacksea_allsat_phy_l4_20160707_20200801.nc differ diff --git a/src/py_eddy_tracker/data/dt_med_allsat_phy_l4_2005T2.nc b/src/py_eddy_tracker/data/dt_med_allsat_phy_l4_2005T2.nc new file mode 100644 index 00000000..cff2e2c7 Binary files /dev/null and b/src/py_eddy_tracker/data/dt_med_allsat_phy_l4_2005T2.nc differ diff --git a/src/py_eddy_tracker/data/loopers_lumpkin_med.nc b/src/py_eddy_tracker/data/loopers_lumpkin_med.nc new file mode 100644 index 00000000..cf817424 Binary files /dev/null and b/src/py_eddy_tracker/data/loopers_lumpkin_med.nc differ diff --git a/src/py_eddy_tracker/data/network_med.nc b/src/py_eddy_tracker/data/network_med.nc new file mode 100644 index 00000000..d695b09b Binary files /dev/null and b/src/py_eddy_tracker/data/network_med.nc differ diff --git a/src/py_eddy_tracker/dataset/grid.py b/src/py_eddy_tracker/dataset/grid.py index 033f94f2..f15503b2 100644 --- a/src/py_eddy_tracker/dataset/grid.py +++ b/src/py_eddy_tracker/dataset/grid.py @@ -1,67 +1,75 @@ # -*- coding: utf-8 -*- """ +Class to load and manipulate RegularGrid and UnRegularGrid """ +from datetime import datetime import logging + +from cv2 import filter2D +from matplotlib.path import Path as BasePath +from netCDF4 import Dataset +from numba import njit, prange, types as numba_types +import numpy as np from numpy import ( - concatenate, - empty, - where, + arange, array, - sin, - deg2rad, - pi, - ones, + ceil, + concatenate, cos, - ma, - int8, - histogram2d, - arange, - float_, - linspace, + deg2rad, + empty, errstate, + exp, + float_, + floor, + histogram2d, int_, interp, + isnan, + linspace, + ma, + mean as np_mean, meshgrid, nan, - ceil, - sinc, - isnan, + nanmean, + ones, percentile, + pi, + radians, + sin, + sinc, + sqrt, + where, zeros, - round_, - nanmean, - exp, - mean as np_mean, ) -from datetime import datetime -from 
scipy.special import j1 -from netCDF4 import Dataset -from scipy.ndimage import gaussian_filter, convolve +from pint import UnitRegistry from scipy.interpolate import RectBivariateSpline, interp1d -from scipy.spatial import cKDTree +from scipy.ndimage import gaussian_filter from scipy.signal import welch -from cv2 import filter2D -from numba import njit, types as numba_types -from matplotlib.path import Path as BasePath -from pyproj import Proj -from pint import UnitRegistry -from ..observations.observation import EddiesObservations -from ..eddy_feature import Amplitude, Contours +from scipy.spatial import cKDTree +from scipy.special import j1 + from .. import VAR_DESCR +from ..data import get_demo_path +from ..eddy_feature import Amplitude, Contours from ..generic import ( + bbox_indice_regular, + coordinates_to_local, distance, interp2d_geo, - fit_circle, - uniform_resample, - coordinates_to_local, local_to_coordinates, + nearest_grd_indice, + uniform_resample, ) +from ..observations.observation import EddiesObservations from ..poly import ( - poly_contain_poly, - winding_number_grid_in_poly, - winding_number_poly, create_vertice, + fit_circle, + get_pixel_in_regular, poly_area, + poly_contain_poly, + visvalingam, + winding_number_poly, ) logger = logging.getLogger("pet") @@ -116,7 +124,7 @@ def value_on_regular_contour(x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size @njit(cache=True) def mean_on_regular_contour( - x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=None, nan_remove=False + x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=-1, nan_remove=False ): x_val, y_val = vertices[:, 0], vertices[:, 1] x_new, y_new = uniform_resample(x_val, y_val, num_fac, fixed_size) @@ -144,7 +152,7 @@ def _circle_from_equal_area(vertice): # last coordinates == first lon0, lat0 = lons[1:].mean(), lats[1:].mean() c_x, c_y = coordinates_to_local(lons, lats, lon0, lat0) - # Some time, edge is only a dot of few coordinates + # Sometimes, edge is only a dot of few coordinates d_lon = lons.max() - lons.min() d_lat = lats.max() - lats.min() if d_lon < 1e-7 and d_lat < 1e-7: @@ -152,7 +160,7 @@ def _circle_from_equal_area(vertice): # logger.debug('%d coordinates %s,%s', len(lons),lons, # lats) return 0, -90, nan, nan - return lon0, lat0, (poly_area(create_vertice(c_x, c_y)) / pi) ** 0.5, nan + return lon0, lat0, (poly_area(c_x, c_y) / pi) ** 0.5, nan @njit(cache=True, fastmath=True) @@ -175,36 +183,6 @@ def _fit_circle_path(vertice): return centlon, centlat, eddy_radius, err -@njit(cache=True, fastmath=True) -def _get_pixel_in_regular(vertices, x_c, y_c, x_start, x_stop, y_start, y_stop): - if x_stop < x_start: - x_ref = vertices[0, 0] - x_array = ( - (concatenate((x_c[x_start:], x_c[:x_stop])) - x_ref + 180) % 360 - + x_ref - - 180 - ) - return winding_number_grid_in_poly( - x_array, - y_c[y_start:y_stop], - x_start, - x_stop, - x_c.shape[0], - y_start, - vertices, - ) - else: - return winding_number_grid_in_poly( - x_c[x_start:x_stop], - y_c[y_start:y_stop], - x_start, - x_stop, - x_c.shape[0], - y_start, - vertices, - ) - - @njit(cache=True, fastmath=True) def _get_pixel_in_unregular(vertices, x_c, y_c, x_start, x_stop, y_start, y_stop): nb_x, nb_y = x_stop - x_start, y_stop - y_start @@ -260,19 +238,15 @@ def nb_pixel(self): class GridDataset(object): """ - Class to have basic tool on NetCDF Grid + Class for basic tools on NetCDF Grid """ __slots__ = ( - "_x_var", - "_y_var", "x_c", "y_c", "x_bounds", "y_bounds", "centered", - "xinterp", - "yinterp", "x_dim", "y_dim", "coordinates", @@ -282,9 
+256,8 @@ class GridDataset(object): "variables_description", "global_attrs", "vars", - "interpolators", - "speed_coef", "contours", + "nan_mask", ) GRAVITY = 9.807 @@ -294,8 +267,24 @@ class GridDataset(object): N = 1 def __init__( - self, filename, x_name, y_name, centered=None, indexs=None, unset=False + self, + filename, + x_name, + y_name, + centered=None, + indexs=None, + unset=False, + nan_masking=False, ): + """ + :param str filename: Filename to load + :param str x_name: Name of longitude coordinates + :param str y_name: Name of latitude coordinates + :param bool,None centered: Allow to know how coordinates could be used with pixel + :param dict indexs: A dictionary that sets indexes to use for non-coordinate dimensions + :param bool unset: Set to True to create an empty grid object without file + :param bool nan_masking: Set to True to replace data.mask with isnan method result + """ self.dimensions = None self.variables_description = None self.global_attrs = None @@ -305,27 +294,45 @@ def __init__( self.y_bounds = None self.x_dim = None self.y_dim = None + self.nan_mask = nan_masking self.centered = centered self.contours = None - self.xinterp = None - self.yinterp = None self.filename = filename self.coordinates = x_name, y_name self.vars = dict() self.indexs = dict() if indexs is None else indexs - self.interpolators = dict() if centered is None: logger.warning( - "We assume pixel position of grid is center for %s", filename, + "We assume pixel position of grid is centered for %s", filename ) if not unset: + self.populate() + + def populate(self): + if self.dimensions is None: self.load_general_features() self.load() + def clean(self): + self.dimensions = None + self.variables_description = None + self.global_attrs = None + self.x_c = None + self.y_c = None + self.x_bounds = None + self.y_bounds = None + self.x_dim = None + self.y_dim = None + self.contours = None + self.vars = dict() + @property def is_centered(self): - """Give information if pixel is describe with center position or + """Give True if pixel is described with its center's position or a corner + + :return: True if centered + :rtype: bool """ if self.centered is None: return True @@ -333,8 +340,7 @@ def is_centered(self): return self.centered def load_general_features(self): - """Load attrs - """ + """Load attrs to be stored in object""" logger.debug( "Load general feature from %(filename)s", dict(filename=self.filename) ) @@ -344,7 +350,7 @@ def load_general_features(self): self.variables_description = dict() for i, v in h.variables.items(): args = (i, v.datatype) - kwargs = dict(dimensions=v.dimensions, zlib=True,) + kwargs = dict(dimensions=v.dimensions, zlib=True) if hasattr(v, "_FillValue"): kwargs["fill_value"] = (v._FillValue,) attrs = dict() @@ -360,7 +366,9 @@ def load_general_features(self): self.global_attrs = {attr: getattr(h, attr) for attr in h.ncattrs()} def write(self, filename): - """Write dataset output with same format like input + """Write dataset output with same format as input + + :param str filename: filename used to save the grid """ with Dataset(filename, "w") as h_out: for dimension, size in self.dimensions.items(): @@ -394,7 +402,9 @@ def write(self, filename): setattr(h_out, attr, value) def load(self): - """Load variable (data) + """ + Load variable (data). 
+ Get coordinates and setup coordinates function """ x_name, y_name = self.coordinates with Dataset(self.filename) as h: @@ -407,14 +417,35 @@ def load(self): self.vars[y_name] = h.variables[y_name][sl_y] self.setup_coordinates() - self.init_pos_interpolator() + + @staticmethod + def get_mask(a): + if len(a.mask.shape): + m = a.mask + else: + m = ones(a.shape, dtype="bool") if a.mask else zeros(a.shape, dtype="bool") + return m + + @staticmethod + def c_to_bounds(c): + """ + Centered coordinates to bounds coordinates + + :param array c: centered coordinates to translate + :return: bounds coordinates + """ + bounds = concatenate((c, (2 * c[-1] - c[-2],))) + d = bounds[1:] - bounds[:-1] + bounds[:-1] -= d / 2 + bounds[-1] -= d[-1] / 2 + return bounds def setup_coordinates(self): x_name, y_name = self.coordinates if self.is_centered: - logger.info("Grid center") - self.x_c = self.vars[x_name].astype("float64") - self.y_c = self.vars[y_name].astype("float64") + # logger.info("Grid center") + self.x_c = array(self.vars[x_name].astype("float64")) + self.y_c = array(self.vars[y_name].astype("float64")) self.x_bounds = concatenate((self.x_c, (2 * self.x_c[-1] - self.x_c[-2],))) self.y_bounds = concatenate((self.y_c, (2 * self.y_c[-1] - self.y_c[-2],))) @@ -426,8 +457,8 @@ def setup_coordinates(self): self.y_bounds[-1] -= d_y[-1] / 2 else: - self.x_bounds = self.vars[x_name].astype("float64") - self.y_bounds = self.vars[y_name].astype("float64") + self.x_bounds = array(self.vars[x_name].astype("float64")) + self.y_bounds = array(self.vars[y_name].astype("float64")) if len(self.x_dim) == 1: self.x_c = self.x_bounds.copy() @@ -442,11 +473,11 @@ def setup_coordinates(self): raise Exception("not write") def is_circular(self): - """Check grid circularity - """ + """Check grid circularity""" return False def units(self, varname): + """Get unit from variable""" stored_units = self.variables_description[varname]["attrs"].get("units", None) if stored_units is not None: return stored_units @@ -455,14 +486,16 @@ def units(self, varname): if hasattr(var, "units"): return var.units + @property + def variables(self): + return self.variables_description.keys() + def copy(self, grid_in, grid_out): """ - Duplicate a variable - Args: - grid_in: - grid_out: + Duplicate the variable from grid_in in grid_out - Returns: + :param grid_in: + :param grid_out: """ h_dict = self.variables_description[grid_in] @@ -474,8 +507,24 @@ def copy(self, grid_in, grid_out): ) self.vars[grid_out] = self.grid(grid_in).copy() + def add_grid(self, varname, grid): + """ + Add a grid in handler + + :param str varname: name of the future grid + :param array grid: grid array + """ + self.vars[varname] = grid + def grid(self, varname, indexs=None): - """give grid required + """Give the grid required + + :param str varname: Variable to get + :param dict,None indexs: If defined dict must have dimensions name as key + :return: array asked, reduced by the indexes + :rtype: array + + .. 
minigallery:: py_eddy_tracker.GridDataset.grid """ if indexs is None: indexs = dict() @@ -504,6 +553,11 @@ def grid(self, varname, indexs=None): if i_x > i_y: self.variables_description[varname]["infos"]["transpose"] = True self.vars[varname] = self.vars[varname].T + if self.nan_mask: + self.vars[varname] = ma.array( + self.vars[varname], + mask=isnan(self.vars[varname]), + ) if not hasattr(self.vars[varname], "mask"): self.vars[varname] = ma.array( self.vars[varname], @@ -512,8 +566,7 @@ def grid(self, varname, indexs=None): return self.vars[varname] def grid_tiles(self, varname, slice_x, slice_y): - """give grid tiles required, without buffer system - """ + """Give the grid tiles required, without buffer system""" coordinates_dims = list(self.x_dim) coordinates_dims.extend(list(self.y_dim)) logger.debug( @@ -544,21 +597,26 @@ def grid_tiles(self, varname, slice_x, slice_y): return data def high_filter(self, grid_name, w_cut, **kwargs): - """create a high filter with a low one + """Return the high-pass filtered grid, by substracting to the initial grid the low-pass filtered grid (default: order=1) + + :param grid_name: the name of the grid + :param int, w_cut: the half-power wavelength cutoff (km) """ result = self._low_filter(grid_name, w_cut, **kwargs) self.vars[grid_name] -= result def low_filter(self, grid_name, w_cut, **kwargs): - """low filtering + """Return the low-pass filtered grid (default: order=1) + + :param grid_name: the name of the grid + :param int, w_cut: the half-power wavelength cutoff (km) """ result = self._low_filter(grid_name, w_cut, **kwargs) self.vars[grid_name] -= self.vars[grid_name] - result @property def bounds(self): - """Give bound - """ + """Give bounds""" return ( self.x_bounds.min(), self.x_bounds.max(), @@ -574,31 +632,46 @@ def eddy_identification( date, step=0.005, shape_error=55, + presampling_multiplier=10, sampling=50, + sampling_method="visvalingam", pixel_limit=None, precision=None, force_height_unit=None, force_speed_unit=None, + **kwargs, ): """ - - Args: - grid_height: - uname: - vname: - date: - step: must be in meter (m) - shape_error: must be in percent (%) - sampling: - pixel_limit: - precision: must be in meter(m) - - Returns: - + Compute eddy identification on the specified grid + + :param str grid_height: Grid name of Sea Surface Height + :param str uname: Grid name of u speed component + :param str vname: Grid name of v speed component + :param datetime.datetime date: Date to be stored in object to date data + :param float,int step: Height between two layers in m + :param float,int shape_error: Maximal error allowed for outermost contour in % + :param int presampling_multiplier: + Evenly oversample the initial number of points in the contour by nb_pts x presampling_multiplier to fit circles + :param int sampling: Number of points to store contours and speed profile + :param str sampling_method: Method to resample the stored contours, 'uniform' or 'visvalingam' + :param (int,int),None pixel_limit: + Min and max number of pixels inside the inner and the outermost contour to be considered as an eddy + :param float,None precision: Truncate values at the defined precision in m + :param str force_height_unit: Unit used for height unit + :param str force_speed_unit: Unit used for speed unit + :param dict kwargs: Arguments given to amplitude (mle, nb_step_min, nb_step_to_be_mle). 
+ Look at :py:meth:`py_eddy_tracker.eddy_feature.Amplitude` + The amplitude threshold is given by `step*nb_step_min` + + + :return: Return a list of 2 elements: Anticyclones and Cyclones + :rtype: py_eddy_tracker.observations.observation.EddiesObservations + + .. minigallery:: py_eddy_tracker.GridDataset.eddy_identification """ if not isinstance(date, datetime): - raise Exception("Date argument be a datetime object") - # The inf limit must be in pixel and sup limit in surface + raise Exception("Date argument must be a datetime object") + # The inf limit must be in pixel and sup limit in surface if pixel_limit is None: pixel_limit = (4, 1000) @@ -606,7 +679,6 @@ def eddy_identification( self.init_speed_coef(uname, vname) # Get unit of h grid - h_units = ( self.units(grid_height) if force_height_unit is None else force_height_unit ) @@ -622,12 +694,12 @@ def eddy_identification( if precision is not None: precision /= factor - # Get h grid + # Get ssh grid data = self.grid(grid_height).astype("f8") - # In case of a reduce mask + # In case of a reduced mask if len(data.mask.shape) == 0 and not data.mask: data.mask = zeros(data.shape, dtype="bool") - # we remove noisy information + # we remove noisy data if precision is not None: data = (data / precision).round() * precision # Compute levels for ssh @@ -650,6 +722,7 @@ def eddy_identification( ) z_min, z_max = z_min_p, z_max_p + logger.debug("Levels from %f to %f", z_min, z_max) levels = arange(z_min - z_min % step, z_max - z_max % step + 2 * step, step) # Get x and y values @@ -659,6 +732,7 @@ def eddy_identification( self.contours = Contours(x, y, data, levels, wrap_x=self.is_circular()) out_sampling = dict(fixed_size=sampling) + resample = visvalingam if sampling_method == "visvalingam" else uniform_resample track_extra_variables = [ "height_max_speed_contour", "height_external_contour", @@ -673,7 +747,7 @@ def eddy_identification( "contour_lat_s", "uavg_profile", ] - # Compute cyclonic and anticylonic research: + # Complete cyclonic and anticylonic research: a_and_c = list() for anticyclonic_search in [True, False]: eddies = list() @@ -701,34 +775,42 @@ def eddy_identification( for contour in contour_paths: if contour.used: continue - centlon_e, centlat_e, eddy_radius_e, aerr = contour.fit_circle() + # FIXME : center could be outside the contour due to the fit + # FIXME : warning : the fit is made on raw sampling + _, _, _, aerr = contour.fit_circle() # Filter for shape if aerr < 0 or aerr > shape_error or isnan(aerr): - continue - # Get indices of centroid - # Give only 1D array of lon and lat not 2D data - i_x, i_y = self.nearest_grd_indice(centlon_e, centlat_e) - i_x = self.normalize_x_indice(i_x) - - # Check if centroid is on define value - if data.mask[i_x, i_y]: - continue - # Test to know cyclone or anticyclone - acyc_not_cyc = data[i_x, i_y] >= cvalues - if anticyclonic_search != acyc_not_cyc: + contour.reject = 1 continue # Find all pixels in the contour i_x_in, i_y_in = contour.pixels_in(self) - # Maybe limit max must be replace with a maximum of surface + # Check if pixels in contour are masked + if has_masked_value(data.mask, i_x_in, i_y_in): + if contour.reject == 0: + contour.reject = 2 + continue + + # Test of the rotating sense: cyclone or anticyclone + if has_value( + data.data, i_x_in, i_y_in, cvalues, below=anticyclonic_search + ): + continue + + # Test the number of pixels within the outermost contour + # FIXME : Maybe limit max must be replaced with a maximum of surface if ( contour.nb_pixel < pixel_limit[0] or 
contour.nb_pixel > pixel_limit[1] ): + contour.reject = 3 continue + # Here the considered contour passed shape_error test, masked_pixels test, + # values strictly above (AEs) or below (CEs) the contour, number_pixels test) + # Compute amplitude reset_centroid, amp = self.get_amplitude( contour, @@ -736,20 +818,20 @@ def eddy_identification( data, anticyclonic_search=anticyclonic_search, level=self.contours.levels[corrected_coll_index], - step=step, + interval=step, + **kwargs, ) # If we have a valid amplitude if (not amp.within_amplitude_limits()) or (amp.amplitude == 0): + contour.reject = 4 continue - if reset_centroid: - if self.is_circular(): centi = self.normalize_x_indice(reset_centroid[0]) else: centi = reset_centroid[0] centj = reset_centroid[1] - # To move in regular and unregular grid + # FIXME : To move in regular and unregular grid if len(x.shape) == 1: centlon_e = x[centi] centlat_e = y[centj] @@ -757,7 +839,7 @@ def eddy_identification( centlon_e = x[centi, centj] centlat_e = y[centi, centj] - # centlat_e and centlon_e must be index of maximum, we will loose some inner contour, if it's not + # centlat_e and centlon_e must be indexes of maximum, we will loose some inner contour if it's not ( max_average_speed, speed_contour, @@ -775,66 +857,78 @@ def eddy_identification( pixel_min=pixel_limit[0], ) - # Use azimuth equal projection for radius - proj = Proj( - "+proj=aeqd +ellps=WGS84 +lat_0={1} +lon_0={0}".format( - *inner_contour.mean_coordinates - ) - ) - # First, get position based on innermost - # contour - centx_i, centy_i, _, _ = fit_circle( - *proj(inner_contour.lon, inner_contour.lat) - ) - centlon_i, centlat_i = proj(centx_i, centy_i, inverse=True) - # Second, get speed-based radius based on - # contour of max uavg - centx_s, centy_s, eddy_radius_s, aerr_s = fit_circle( - *proj(speed_contour.lon, speed_contour.lat) - ) - # Computed again to be coherent with speed_radius, we will be compute in same reference - _, _, eddy_radius_e, aerr_e = fit_circle( - *proj(contour.lon, contour.lat) - ) - centlon_s, centlat_s = proj(centx_s, centy_s, inverse=True) - - # Instantiate new EddyObservation object (high cost need to be review) + # FIXME : Instantiate new EddyObservation object (high cost, need to be reviewed) obs = EddiesObservations( size=1, track_extra_variables=track_extra_variables, track_array_variables=sampling, array_variables=array_variables, ) - - obs.obs["height_max_speed_contour"] = self.contours.cvalues[ - i_max_speed - ] - obs.obs["height_external_contour"] = cvalues - obs.obs["height_inner_contour"] = self.contours.cvalues[i_inner] + obs.height_max_speed_contour[:] = self.contours.cvalues[i_max_speed] + obs.height_external_contour[:] = cvalues + obs.height_inner_contour[:] = self.contours.cvalues[i_inner] array_size = speed_array.shape[0] - obs.obs["nb_contour_selected"] = array_size + obs.nb_contour_selected[:] = array_size if speed_array.shape[0] == 1: - obs.obs["uavg_profile"][:] = speed_array[0] + obs.uavg_profile[:] = speed_array[0] else: - obs.obs["uavg_profile"] = raw_resample(speed_array, sampling) - obs.obs["amplitude"] = amp.amplitude - obs.obs["radius_s"] = eddy_radius_s - obs.obs["speed_average"] = max_average_speed - obs.obs["radius_e"] = eddy_radius_e - obs.obs["shape_error_e"] = aerr_e - obs.obs["shape_error_s"] = aerr_s - obs.obs["lon"] = centlon_s - obs.obs["lat"] = centlat_s - obs.obs["lon_max"] = centlon_i - obs.obs["lat_max"] = centlat_i - obs.obs["num_point_e"] = contour.lon.shape[0] - xy = uniform_resample(contour.lon, contour.lat, 
**out_sampling) - obs.obs["contour_lon_e"], obs.obs["contour_lat_e"] = xy - obs.obs["num_point_s"] = speed_contour.lon.shape[0] - xy = uniform_resample( - speed_contour.lon, speed_contour.lat, **out_sampling + obs.uavg_profile[:] = raw_resample(speed_array, sampling) + obs.amplitude[:] = amp.amplitude + obs.speed_average[:] = max_average_speed + obs.num_point_e[:] = contour.lon.shape[0] + obs.num_point_s[:] = speed_contour.lon.shape[0] + + # Evenly resample contours with nb_pts = nb_pts_original x presampling_multiplier + xy_i = uniform_resample( + inner_contour.lon, + inner_contour.lat, + num_fac=presampling_multiplier, + ) + xy_e = uniform_resample( + contour.lon, + contour.lat, + num_fac=presampling_multiplier, + ) + xy_s = uniform_resample( + speed_contour.lon, + speed_contour.lat, + num_fac=presampling_multiplier, + ) + + # First, get position of max SSH based on best fit circle with resampled innermost contour + centlon_i, centlat_i, _, _ = _fit_circle_path(create_vertice(*xy_i)) + obs.lon_max[:] = centlon_i + obs.lat_max[:] = centlat_i + + # Second, get speed-based radius, shape error, eddy center, area based on resampled contour of max uavg + centlon_s, centlat_s, eddy_radius_s, aerr_s = _fit_circle_path( + create_vertice(*xy_s) + ) + obs.radius_s[:] = eddy_radius_s + obs.shape_error_s[:] = aerr_s + obs.speed_area[:] = poly_area( + *coordinates_to_local(*xy_s, lon0=centlon_s, lat0=centlat_s) + ) + obs.lon[:] = centlon_s + obs.lat[:] = centlat_s + + # Third, compute effective radius, shape error, area from resampled effective contour + _, _, eddy_radius_e, aerr_e = _fit_circle_path( + create_vertice(*xy_e) ) - obs.obs["contour_lon_s"], obs.obs["contour_lat_s"] = xy + obs.radius_e[:] = eddy_radius_e + obs.shape_error_e[:] = aerr_e + obs.effective_area[:] = poly_area( + *coordinates_to_local(*xy_e, lon0=centlon_s, lat0=centlat_s) + ) + + # Finally, resample contours with output parameters + xy_e_f = resample(*xy_e, **out_sampling) + xy_s_f = resample(*xy_s, **out_sampling) + + obs.contour_lon_s[:], obs.contour_lat_s[:] = xy_s_f + obs.contour_lon_e[:], obs.contour_lat_e[:] = xy_e_f + if aerr > 99.9 or aerr_s > 99.9: logger.warning( "Strange shape at this step! shape_error : %f, %f", @@ -854,18 +948,14 @@ def eddy_identification( else: eddies = EddiesObservations.concatenate(eddies) eddies.sign_type = 1 if anticyclonic_search else -1 - eddies.obs["time"] = (date - datetime(1950, 1, 1)).total_seconds() / 86400.0 + eddies.time[:] = (date - datetime(1950, 1, 1)).total_seconds() / 86400.0 # normalization longitude between 0 - 360, because storage have an offset on 180 - eddies.obs["lon_max"] %= 360 - eddies.obs["lon"] %= 360 - ref = eddies.obs["lon"] - 180 - eddies.obs["contour_lon_e"] = ( - (eddies.obs["contour_lon_e"].T - ref) % 360 + ref - ).T - eddies.obs["contour_lon_s"] = ( - (eddies.obs["contour_lon_s"].T - ref) % 360 + ref - ).T + eddies.lon_max[:] %= 360 + eddies.lon[:] %= 360 + ref = eddies.lon - 180 + eddies.contour_lon_e[:] = ((eddies.contour_lon_e.T - ref) % 360 + ref).T + eddies.contour_lon_s[:] = ((eddies.contour_lon_s.T - ref) % 360 + ref).T a_and_c.append(eddies) if in_h_unit is not None: @@ -900,7 +990,7 @@ def get_uavg( pixel_min=3, ): """ - Calculate geostrophic speed around successive contours + Compute geostrophic speed around successive contours Returns the average """ # Init max speed to search maximum @@ -924,7 +1014,7 @@ def get_uavg( if not poly_contain_poly(original_contour.vertices, level_contour.vertices): break # 3. 
Respect size range (for max speed) - # nb_pixel properties need call of pixels_in before with a grid of pixel + # nb_pixel properties need to call pixels_in before with a grid of pixel level_contour.pixels_in(self) # Interpolate uspd to seglon, seglat, then get mean level_average_speed = self.speed_coef_mean(level_contour) @@ -954,8 +1044,7 @@ def get_uavg( @staticmethod def _gaussian_filter(data, sigma, mode="reflect"): - """Standard gaussian filter - """ + """Standard gaussian filter""" local_data = data.copy() local_data[data.mask] = 0 @@ -967,7 +1056,7 @@ def _gaussian_filter(data, sigma, mode="reflect"): @staticmethod def get_amplitude( - contour, contour_height, data, anticyclonic_search=True, level=None, step=None + contour, contour_height, data, anticyclonic_search=True, level=None, **kwargs ): # Instantiate Amplitude object amp = Amplitude( @@ -977,21 +1066,17 @@ def get_amplitude( contour_height=contour_height, # All grid data=data, - # Step by level - interval=step, + **kwargs, ) - if anticyclonic_search: reset_centroid = amp.all_pixels_above_h0(level) else: reset_centroid = amp.all_pixels_below_h0(level) - return reset_centroid, amp class UnRegularGridDataset(GridDataset): - """Class which manage unregular grid - """ + """Class managing unregular grid""" __slots__ = ( "index_interp", @@ -999,8 +1084,7 @@ class UnRegularGridDataset(GridDataset): ) def load(self): - """Load variable (data) - """ + """Load variable (data)""" x_name, y_name = self.coordinates with Dataset(self.filename) as h: self.x_dim = h.variables[x_name].dimensions @@ -1018,8 +1102,7 @@ def load(self): @property def bounds(self): - """Give bound - """ + """Give bounds""" return self.x_c.min(), self.x_c.max(), self.y_c.min(), self.y_c.max() def bbox_indice(self, vertices): @@ -1051,7 +1134,7 @@ def compute_pixel_path(self, x0, y0, x1, y1): pass def init_pos_interpolator(self): - logger.debug("Create a KdTree could be long ...") + logger.debug("Create a KdTree, could be long ...") self.index_interp = cKDTree( create_vertice(self.x_c.reshape(-1), self.y_c.reshape(-1)) ) @@ -1069,7 +1152,7 @@ def _low_filter(self, grid_name, w_cut, factor=8.0): bins = (x_array, y_array) x_flat, y_flat, z_flat = x.reshape((-1,)), y.reshape((-1,)), data.reshape((-1,)) - m = ~z_flat.mask + m = ~self.get_mask(z_flat) x_flat, y_flat, z_flat = x_flat[m], y_flat[m], z_flat[m] nb_value, _, _ = histogram2d(x_flat, y_flat, bins=bins) @@ -1115,8 +1198,7 @@ def init_speed_coef(self, uname="u", vname="v"): class RegularGridDataset(GridDataset): - """Class only for regular grid - """ + """Class only for regular grid""" __slots__ = ( "_speed_ev", @@ -1133,11 +1215,24 @@ def __init__(self, *args, **kwargs): def setup_coordinates(self): super().setup_coordinates() self.x_size = self.x_c.shape[0] + if len(self.x_c.shape) != 1: + raise Exception( + "Coordinates in RegularGridDataset must be 1D array, or think to use UnRegularGridDataset" + ) + dx = self.x_bounds[1:] - self.x_bounds[:-1] + dy = self.y_bounds[1:] - self.y_bounds[:-1] + if (dx < 0).any() or (dy < 0).any(): + raise Exception( + "Coordinates in RegularGridDataset must be strictly increasing" + ) self._x_step = (self.x_c[1:] - self.x_c[:-1]).mean() self._y_step = (self.y_c[1:] - self.y_c[:-1]).mean() @classmethod def with_array(cls, coordinates, datas, variables_description=None, **kwargs): + """ + Geo matrix data must be ordered like this (X,Y) and masked with numpy.ma.array + """ vd = dict() if variables_description is None else variables_description x_name, y_name = coordinates[0], 
coordinates[1] obj = cls("array", x_name, y_name, unset=True, **kwargs) @@ -1161,12 +1256,6 @@ def with_array(cls, coordinates, datas, variables_description=None, **kwargs): obj.setup_coordinates() return obj - def init_pos_interpolator(self): - """Create function to have a quick index interpolator - """ - self.xinterp = arange(self.x_bounds.shape[0]) - self.yinterp = arange(self.y_bounds.shape[0]) - def bbox_indice(self, vertices): return bbox_indice_regular( vertices, @@ -1180,34 +1269,44 @@ def bbox_indice(self, vertices): ) def get_pixels_in(self, contour): - (x_start, x_stop), (y_start, y_stop) = contour.bbox_slice - return _get_pixel_in_regular( - contour.vertices, self.x_c, self.y_c, x_start, x_stop, y_start, y_stop - ) + """ + Get indexes of pixels in contour. + + :param vertice,Path contour: Contour that encloses some pixels + :return: Indexes of grid in contour + :rtype: array[int],array[int] + """ + if isinstance(contour, BasePath): + (x_start, x_stop), (y_start, y_stop) = contour.bbox_slice + return get_pixel_in_regular( + contour.vertices, self.x_c, self.y_c, x_start, x_stop, y_start, y_stop + ) + else: + (x_start, x_stop), (y_start, y_stop) = self.bbox_indice(contour) + return get_pixel_in_regular( + contour, self.x_c, self.y_c, x_start, x_stop, y_start, y_stop + ) def normalize_x_indice(self, indices): return indices % self.x_size def nearest_grd_indice(self, x, y): - return _nearest_grd_indice( + return nearest_grd_indice( x, y, self.x_bounds, self.y_bounds, self.xstep, self.ystep ) @property def xstep(self): - """Only for regular grid with no step variation - """ + """Only for regular grid with no step variation""" return self._x_step @property def ystep(self): - """Only for regular grid with no step variation - """ + """Only for regular grid with no step variation""" return self._y_step def compute_pixel_path(self, x0, y0, x1, y1): - """Give a series of index which describe the path between to position - """ + """Give a series of indexes describing the path between two positions""" return compute_pixel_path( x0, y0, @@ -1220,14 +1319,16 @@ def compute_pixel_path(self, x0, y0, x1, y1): self.x_size, ) - def clean_land(self): - """Function to remove all land pixel - """ - pass + def clean_land(self, name): + """Function to remove all land pixel""" + mask_land = self.__class__(get_demo_path("mask_1_60.nc"), "lon", "lat") + x, y = meshgrid(self.x_c, self.y_c) + m = mask_land.interp("mask", x.reshape(-1), y.reshape(-1), "nearest") + data = self.grid(name) + self.vars[name] = ma.array(data, mask=m.reshape(x.shape).T) def is_circular(self): - """Check if grid is circular - """ + """Check if the grid is circular""" if self._is_circular is None: self._is_circular = ( abs((self.x_bounds[0] % 360) - (self.x_bounds[-1] % 360)) < 0.0001 @@ -1297,6 +1398,24 @@ def kernel_lanczos(self, lat, wave_length, order=1): kernel[dist_norm > order] = 0 return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt) + def kernel_loess(self, lat, wave_length, order=1): + """ + https://fr.wikipedia.org/wiki/R%C3%A9gression_locale + """ + order = self.check_order(order) + half_x_pt, half_y_pt, dist_norm = self.estimate_kernel_shape( + lat, wave_length, order + ) + + def inc_func(xdist): + f = zeros(xdist.size) + f[abs(xdist) < 1] = 1 + return f + + kernel = (1 - abs(dist_norm) ** 3) ** 3 + kernel[abs(dist_norm) > order] = 0 + return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt) + def kernel_bessel(self, lat, wave_length, order=1): """wave_length in km order must be int @@ -1312,8 +1431,7 
@@ def kernel_bessel(self, lat, wave_length, order=1): return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt) def _low_filter(self, grid_name, w_cut, **kwargs): - """low filtering - """ + """low filtering""" return self.convolve_filter_with_dynamic_kernel( grid_name, self.kernel_bessel, wave_length=w_cut, **kwargs ) @@ -1321,6 +1439,15 @@ def _low_filter(self, grid_name, w_cut, **kwargs): def convolve_filter_with_dynamic_kernel( self, grid, kernel_func, lat_max=85, extend=False, **kwargs_func ): + """ + :param str grid: grid name + :param func kernel_func: function of kernel to use + :param float lat_max: absolute latitude above no filtering apply + :param bool extend: if False, only non masked value will return a filtered value + :param dict kwargs_func: look at kernel_func + :return: filtered value + :rtype: array + """ if (abs(self.y_c) > lat_max).any(): logger.warning("No filtering above %f degrees of latitude", lat_max) if isinstance(grid, str): @@ -1364,7 +1491,8 @@ def convolve_filter_with_dynamic_kernel( tmp_matrix = ma.zeros((2 * d_lon + data.shape[0], k_shape[1])) tmp_matrix.mask = ones(tmp_matrix.shape, dtype=bool) # Slice to apply on input data - sl_lat_data = slice(max(0, i - d_lat), min(i + d_lat, data.shape[1])) + # +1 for upper bound, to take in acount this column + sl_lat_data = slice(max(0, i - d_lat), min(i + d_lat + 1, data.shape[1])) # slice to apply on temporary matrix to store input data sl_lat_in = slice( d_lat - (i - sl_lat_data.start), d_lat + (sl_lat_data.stop - i) @@ -1382,7 +1510,7 @@ def convolve_filter_with_dynamic_kernel( demi_x, demi_y = k_shape[0] // 2, k_shape[1] // 2 values_sum = filter2D(tmp_matrix.data, -1, kernel)[demi_x:-demi_x, demi_y] kernel_sum = filter2D(m.astype(float), -1, kernel)[demi_x:-demi_x, demi_y] - with errstate(invalid="ignore"): + with errstate(invalid="ignore", divide="ignore"): if extend: data_out[:, i] = ma.array( values_sum / kernel_sum, @@ -1413,7 +1541,7 @@ def lanczos_high_filter( lat_max=lat_max, wave_length=wave_length, order=order, - **kwargs + **kwargs, ) self.vars[grid_name] -= data_out @@ -1425,7 +1553,7 @@ def lanczos_low_filter(self, grid_name, wave_length, order=1, lat_max=85, **kwar lat_max=lat_max, wave_length=wave_length, order=order, - **kwargs + **kwargs, ) self.vars[grid_name] = data_out @@ -1440,8 +1568,17 @@ def bessel_band_filter(self, grid_name, wave_length_inf, wave_length_sup, **kwar self.vars[grid_name] -= data_out def bessel_high_filter(self, grid_name, wave_length, order=1, lat_max=85, **kwargs): + """ + :param str grid_name: grid to filter, data will replace original one + :param float wave_length: in km + :param int order: order to use, if > 1 negative values of the cardinal sinus are present in kernel + :param float lat_max: absolute latitude, no filtering above + :param dict kwargs: look at :py:meth:`RegularGridDataset.convolve_filter_with_dynamic_kernel` + + .. 
minigallery:: py_eddy_tracker.RegularGridDataset.bessel_high_filter + """ logger.debug( - "Run filtering with wave of %(wave_length)s km and order of %(order)s ...", + "Run filtering with wavelength of %(wave_length)s km and order of %(order)s ...", dict(wave_length=wave_length, order=order), ) data_out = self.convolve_filter_with_dynamic_kernel( @@ -1450,7 +1587,7 @@ def bessel_high_filter(self, grid_name, wave_length, order=1, lat_max=85, **kwar lat_max=lat_max, wave_length=wave_length, order=order, - **kwargs + **kwargs, ) logger.debug("Filtering done") self.vars[grid_name] -= data_out @@ -1462,7 +1599,7 @@ def bessel_low_filter(self, grid_name, wave_length, order=1, lat_max=85, **kwarg lat_max=lat_max, wave_length=wave_length, order=order, - **kwargs + **kwargs, ) self.vars[grid_name] = data_out @@ -1533,7 +1670,7 @@ def spectrum_lonlat(self, grid_name, area=None, ref=None, **kwargs): (lat_content[0], lat_content[1] / ref_lat_content[1]), ) - def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=False): + def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=False, second=False): if not isinstance(schema, int) and schema < 1: raise Exception("schema must be a positive int") @@ -1556,93 +1693,111 @@ def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=Fal data1[-schema:] = nan data2[:schema] = nan + # Distance for one degree d = self.EARTH_RADIUS * 2 * pi / 360 * 2 * schema + # Mulitply by 2 step if vertical: - d *= self.ystep + d *= self.ystep else: d *= self.xstep * cos(deg2rad(self.y_c)) - return (data1 - data2) / d + if second: + return (data1 + data2 - 2 * data) / (d ** 2 / 4) + else: + return (data1 - data2) / d def compute_stencil( self, data, stencil_halfwidth=4, mode="reflect", vertical=False ): + r""" + Apply stencil ponderation on field. + + :param array data: array where apply stencil + :param int stencil_halfwidth: from 1 t0 4, maximal stencil used + :param str mode: convolution mode + :param bool vertical: if True, method apply a vertical convolution + :return: gradient array from stencil application + :rtype: array + + Short story, how to get stencil coefficient for stencil (3 points, 5 points and 7 points) + + Taylor's theorem: + + .. math:: + f(x \pm h) = f(x) \pm f'(x)h + + \frac{f''(x)h^2}{2!} \pm \frac{f^{(3)}(x)h^3}{3!} + + \frac{f^{(4)}(x)h^4}{4!} \pm \frac{f^{(5)}(x)h^5}{5!} + + O(h^6) + + If we stop at `O(h^2)`, we get classic differenciation (stencil 3 points): + + .. math:: f(x+h) - f(x-h) = f(x) - f(x) + 2 f'(x)h + O(h^2) + + .. math:: f'(x) = \frac{f(x+h) - f(x-h)}{2h} + O(h^2) + + If we stop at `O(h^4)`, we will get stencil 5 points: + + .. math:: + f(x+h) - f(x-h) = 2 f'(x)h + 2 \frac{f^{(3)}(x)h^3}{3!} + O(h^4) + :label: E1 + + .. math:: + f(x+2h) - f(x-2h) = 4 f'(x)h + 16 \frac{f^{(3)}(x)h^3}{3!} + O(h^4) + :label: E2 + + If we multiply equation :eq:`E1` by 8 and substract equation :eq:`E2`, we get: + + .. math:: 8(f(x+h) - f(x-h)) - (f(x+2h) - f(x-2h)) = 16 f'(x)h - 4 f'(x)h + O(h^4) + + .. math:: f'(x) = \frac{f(x-2h) - 8f(x-h) + 8f(x+h) - f(x+2h)}{12h} + O(h^4) + + If we stop at `O(h^6)`, we will get stencil 7 points: + + .. math:: + f(x+h) - f(x-h) = 2 f'(x)h + 2 \frac{f^{(3)}(x)h^3}{3!} + 2 \frac{f^{(5)}(x)h^5}{5!} + O(h^6) + :label: E3 + + .. math:: + f(x+2h) - f(x-2h) = 4 f'(x)h + 16 \frac{f^{(3)}(x)h^3}{3!} + 64 \frac{f^{(5)}(x)h^5}{5!} + O(h^6) + :label: E4 + + .. 
math:: + f(x+3h) - f(x-3h) = 6 f'(x)h + 54 \frac{f^{(3)}(x)h^3}{3!} + 486 \frac{f^{(5)}(x)h^5}{5!} + O(h^6) + :label: E5 + + If we multiply equation :eq:`E3` by 45 and substract equation :eq:`E4` multiply by 9 + and add equation :eq:`E5`, we get: + + .. math:: + 45(f(x+h) - f(x-h)) - 9(f(x+2h) - f(x-2h)) + (f(x+3h) - f(x-3h)) = + 90 f'(x)h - 36 f'(x)h + 6 f'(x)h + O(h^6) + + .. math:: + f'(x) = \frac{-f(x-3h) + 9f(x-2h) - 45f(x-h) + 45f(x+h) - 9f(x+2h) +f(x+3h)}{60h} + O(h^6) + + ... + + """ stencil_halfwidth = max(min(int(stencil_halfwidth), 4), 1) logger.debug("Stencil half width apply : %d", stencil_halfwidth) - # output - grad = None - - weights = [ - array((3, -32, 168, -672, 0, 672, -168, 32, -3)) / 840.0, - array((-1, 9, -45, 0, 45, -9, 1)) / 60.0, - array((1, -8, 0, 8, -1)) / 12.0, - array((-1, 0, 1)) / 2.0, - # uncentered kernel - # like array((0, -1, 1)) but left value could be default value - array((-1, 1)), - # like array((-1, 1, 0)) but right value could be default value - (1, array((-1, 1))), - ] - # reduce to stencil selected - weights = weights[4 - stencil_halfwidth :] - if vertical: - data = data.T - # Iteration from larger stencil to smaller (to fill matrix) - for weight in weights: - if isinstance(weight, tuple): - # In the case of unbalanced diff - shift, weight = weight - data_ = data.copy() - data_[shift:] = data[:-shift] - if not vertical: - data_[:shift] = data[-shift:] - else: - data_ = data - # Delta h - d_h = convolve(data_, weights=weight.reshape((-1, 1)), mode=mode) - mask = convolve( - int8(data_.mask), weights=ones(weight.shape).reshape((-1, 1)), mode=mode - ) - d_h = ma.array(d_h, mask=mask != 0) - - # Delta d - if vertical: - d_h = d_h.T - d = self.EARTH_RADIUS * 2 * pi / 360 * convolve(self.y_c, weight) - else: - if mode == "wrap": - # Along x axis, we need to close - # we will compute in two part - x = self.x_c % 360 - d_degrees = convolve(x, weight, mode=mode) - d_degrees_180 = convolve((x + 180) % 360 - 180, weight, mode=mode) - # Arbitrary, to be sure to be far far away of bound - m = (x < 90) + (x > 270) - d_degrees[m] = d_degrees_180[m] - else: - d_degrees = convolve(self.x_c, weight, mode=mode) - d = ( - self.EARTH_RADIUS - * 2 - * pi - / 360 - * d_degrees.reshape((-1, 1)) - * cos(deg2rad(self.y_c)) - ) - if grad is None: - # First Gradient - grad = d_h / d - else: - # Fill hole - grad[grad.mask] = (d_h / d)[grad.mask] - return grad + g, m = compute_stencil( + self.x_c, + self.y_c, + data.data, + self.get_mask(data), + self.EARTH_RADIUS, + vertical=vertical, + stencil_halfwidth=stencil_halfwidth, + ) + return ma.array(g, mask=m) - def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15): - self.add_uv(grid_height, uname, vname) + def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15, **kwargs): + self.add_uv(grid_height, uname, vname, **kwargs) latmax = 5 - _, (i_start, i_end) = self.nearest_grd_indice((0, 0), (-latmax, latmax)) + _, i_start = self.nearest_grd_indice(0, -latmax) + _, i_end = self.nearest_grd_indice(0, latmax) sl = slice(i_start, i_end) # Divide by sideral day - lat = self.y_c[sl] + lat = self.y_c gob = ( cos(deg2rad(lat)) * ones((self.x_c.shape[0], 1)) @@ -1656,43 +1811,50 @@ def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15): mode = "wrap" if self.is_circular() else "reflect" # fill data to compute a finite difference on all point - data = self.convolve_filter_with_dynamic_kernel( - grid_height, - self.kernel_bessel, - lat_max=10, - wave_length=500, - order=1, - extend=0.1, - 
) - data = self.convolve_filter_with_dynamic_kernel( - data, self.kernel_bessel, lat_max=10, wave_length=500, order=1, extend=0.1 - ) - data = self.convolve_filter_with_dynamic_kernel( - data, self.kernel_bessel, lat_max=10, wave_length=500, order=1, extend=0.1 - ) + kw_filter = dict(kernel_func=self.kernel_bessel, order=1, extend=.1) + data = self.convolve_filter_with_dynamic_kernel(grid_height, wave_length=500, **kw_filter, lat_max=6+5+2+3) v_lagerloef = ( self.compute_finite_difference( - self.compute_finite_difference(data, mode=mode, schema=schema), - mode=mode, - schema=schema, - )[:, sl] - * gob - ) - u_lagerloef = ( - -self.compute_finite_difference( - self.compute_finite_difference(data, vertical=True, schema=schema), - vertical=True, - schema=schema, - )[:, sl] + self.compute_finite_difference(data, mode=mode, schema=1), + vertical=True, schema=1 + ) * gob ) - w = 1 - exp(-((lat / 2.2) ** 2)) - self.vars[vname][:, sl] = self.vars[vname][:, sl] * w + v_lagerloef * (1 - w) - self.vars[uname][:, sl] = self.vars[uname][:, sl] * w + u_lagerloef * (1 - w) + u_lagerloef = -self.compute_finite_difference(data, vertical=True, schema=schema, second=True) * gob + + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=195, **kw_filter, lat_max=6 + 5 +2) + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=416, **kw_filter, lat_max=6 + 5) + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=416, **kw_filter, lat_max=6) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=195, **kw_filter, lat_max=6 + 5 +2) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=416, **kw_filter, lat_max=6 + 5) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=416, **kw_filter, lat_max=6) + w = 1 - exp(-((lat[sl] / 2.2) ** 2)) + self.vars[vname][:, sl] = self.vars[vname][:, sl] * w + v_lagerloef[:, sl] * (1 - w) + self.vars[uname][:, sl] = self.vars[uname][:, sl] * w + u_lagerloef[:, sl] * (1 - w) def add_uv(self, grid_height, uname="u", vname="v", stencil_halfwidth=4): - """Compute a u and v grid - """ + r"""Compute a u and v grid + + :param str grid_height: grid name where the funtion will apply stencil method + :param str uname: future name of u + :param str vname: future name of v + :param int stencil_halfwidth: largest stencil could be apply (max: 4) + + .. math:: + u = \frac{g}{f} \frac{dh}{dy} + + v = -\frac{g}{f} \frac{dh}{dx} + + where + + .. math:: + g = gravity + + f = 2 \Omega sin(\phi) + + + .. 
minigallery:: py_eddy_tracker.RegularGridDataset.add_uv + """ logger.info("Add u/v variable with stencil method") data = self.grid(grid_height) h_dict = self.variables_description[grid_height] @@ -1735,84 +1897,388 @@ def add_uv(self, grid_height, uname="u", vname="v", stencil_halfwidth=4): ) def speed_coef_mean(self, contour): - """some nan can be compute over contour if we are near border, + """Some nan can be computed over contour if we are near borders, something to explore """ return mean_on_regular_contour( self.x_c, self.y_c, - self._speed_ev, + self._speed_ev.data, self._speed_ev.mask, contour.vertices, nan_remove=True, ) def init_speed_coef(self, uname="u", vname="v"): - """Draft + """Draft""" + u, v = self.grid(uname), self.grid(vname) + self._speed_ev = sqrt(u * u + v * v) + + def display(self, ax, name, factor=1, ref=None, **kwargs): """ - self._speed_ev = (self.grid(uname) ** 2 + self.grid(vname) ** 2) ** 0.5 + :param matplotlib.axes.Axes ax: matplotlib axes used to draw + :param str,array name: variable to display, could be an array + :param float factor: multiply grid by + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary + :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.pcolormesh` - def display(self, ax, name, factor=1, **kwargs): + .. minigallery:: py_eddy_tracker.RegularGridDataset.display + """ if "cmap" not in kwargs: kwargs["cmap"] = "coolwarm" - return ax.pcolormesh( - self.x_bounds, self.y_bounds, self.grid(name).T * factor, **kwargs - ) + data = self.grid(name) if isinstance(name, str) else name + if ref is None: + x = self.x_bounds + else: + x = (self.x_c - ref) % 360 + ref + i = x.argsort() + x = self.c_to_bounds(x[i]) + data = data[i] + return ax.pcolormesh(x, self.y_bounds, data.T * factor, **kwargs) + + def contour(self, ax, name, factor=1, ref=None, **kwargs): + """ + :param matplotlib.axes.Axes ax: matplotlib axes used to draw + :param str,array name: variable to display, could be an array + :param float factor: multiply grid by + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary + :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.contour` - def interp(self, grid_name, lons, lats): + .. minigallery:: py_eddy_tracker.RegularGridDataset.contour + """ + data = self.grid(name) if isinstance(name, str) else name + if ref is None: + x = self.x_c + else: + x = (self.x_c - ref) % 360 + ref + i = x.argsort() + x = x[i] + data = data[i] + return ax.contour(x, self.y_c, data.T * factor, **kwargs) + + def regrid(self, other, grid_name, new_name=None): + """ + Interpolate another grid at the current grid position + + :param RegularGridDataset other: + :param str grid_name: variable name to interpolate + :param str new_name: name used to store, if None method will use current ont + + .. minigallery:: py_eddy_tracker.RegularGridDataset.regrid + """ + if new_name is None: + new_name = grid_name + x, y = meshgrid(self.x_c, self.y_c) + # interp and reshape + v_interp = ( + other.interp(grid_name, x.reshape(-1), y.reshape(-1)).reshape(x.shape).T + ) + v_interp = ma.array(v_interp, mask=isnan(v_interp)) + # and add it to self + self.add_grid(new_name, v_interp) + self.variables_description[new_name] = other.variables_description[grid_name] + # self.variables_description[new_name]['infos'] = False + # self.variables_description[new_name]['kwargs']['dimensions'] = ... 
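The new `regrid` helper above is essentially "evaluate the other grid at this grid's cell centers, then mask whatever falls outside the source". Below is a minimal standalone sketch of that idea using `scipy.interpolate.RegularGridInterpolator` rather than the library's `interp2d_geo` kernel; the grids and values are invented for illustration only.

```python
# Standalone sketch of what ``regrid`` does: evaluate a source grid at the
# cell centers of a target grid, then mask points that fall outside it.
# Uses scipy instead of the library's numba kernel; data here are synthetic.
import numpy as np
from scipy.interpolate import RegularGridInterpolator

# Source grid: coarse 1-degree field, ordered (X, Y) as RegularGridDataset expects
x_src = np.arange(0.0, 360.0, 1.0)
y_src = np.arange(-80.0, 80.5, 1.0)
z_src = np.cos(np.radians(y_src))[np.newaxis, :] * np.sin(np.radians(x_src))[:, np.newaxis]

# Target grid: finer 1/4-degree positions
x_dst = np.arange(0.0, 360.0, 0.25)
y_dst = np.arange(-80.0, 80.25, 0.25)
xx, yy = np.meshgrid(x_dst, y_dst)

# Bilinear evaluation at every target cell center (same role as other.interp)
interpolator = RegularGridInterpolator(
    (x_src, y_src), z_src, method="linear", bounds_error=False, fill_value=np.nan
)
z_dst = interpolator(np.column_stack((xx.reshape(-1), yy.reshape(-1)))).reshape(xx.shape).T

# Mask invalid points, mirroring ``ma.array(v_interp, mask=isnan(v_interp))``
z_dst = np.ma.array(z_dst, mask=np.isnan(z_dst))
print(z_dst.shape)  # (1440, 641), ordered (X, Y) like the source grid
```

The `interp` method defined next is the piece `regrid` delegates to; it exposes the same bilinear/nearest choice on a per-call basis.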
+ + def interp(self, grid_name, lons, lats, method="bilinear"): """ Compute z over lons, lats - Args: - grid_name: Grid which will be interp - lons: new x - lats: new y - Returns: - new z + :param str grid_name: Grid to be interpolated + :param lons: new x + :param lats: new y + :param str method: Could be 'bilinear' or 'nearest' + + :return: new z """ g = self.grid(grid_name) - return interp2d_geo(self.x_c, self.y_c, g, g.mask, lons, lats) + m = self.get_mask(g) + return interp2d_geo( + self.x_c, self.y_c, g.data, m, lons, lats, nearest=method == "nearest" + ) + + def uv_for_advection( + self, + u_name=None, + v_name=None, + time_step=600, + h_name=None, + backward=False, + factor=1, + ): + """ + Get U,V to be used in degrees with precomputed time step + + :param None,str,array u_name: U field to advect obs, if h_name is None + :param None,str,array v_name: V field to advect obs, if h_name is None + :param None,str,array h_name: H field to compute UV to advect obs, if u_name and v_name are None + :param int time_step: Number of second for each advection + """ + if h_name is not None: + u_name, v_name = "u", "v" + if u_name not in self.vars: + self.add_uv(h_name) + self.vars.pop(h_name, None) + + u = (self.grid(u_name) if isinstance(u_name, str) else u_name).copy() + v = (self.grid(v_name) if isinstance(v_name, str) else v_name).copy() + # N seconds / 1 degrees in m + coef = time_step * 180 / pi / self.EARTH_RADIUS * factor + u *= coef / cos(radians(self.y_c)) + v *= coef + if backward: + u = -u + v = -v + m = u.mask + v.mask + return u.data, v.data, m + + def advect(self, x, y, u_name, v_name, nb_step=10, rk4=True, **kw): + """ + At each call it will update position in place with u & v field + + It's a dummy advection using only one layer of current + + :param array x: Longitude of obs to move + :param array y: Latitude of obs to move + :param str,array u_name: U field to advect obs + :param str,array v_name: V field to advect obs + :param int nb_step: Number of iterations before releasing data + + .. minigallery:: py_eddy_tracker.GridDataset.advect + """ + u, v, m = self.uv_for_advection(u_name, v_name, **kw) + m_p = isnan(x) + isnan(y) + advect_ = advect_rk4 if rk4 else advect + while True: + advect_(self.x_c, self.y_c, u, v, m, x, y, m_p, nb_step) + yield x, y + + def filament( + self, x, y, u_name, v_name, nb_step=10, filament_size=6, rk4=True, **kw + ): + """ + Produce filament with concatenation of advection + + It's a dummy advection using only one layer of current + + :param array x: Longitude of obs to move + :param array y: Latitude of obs to move + :param str,array u_name: U field to advect obs + :param str,array v_name: V field to advect obs + :param int nb_step: Number of iteration before releasing data + :param int filament_size: Number of point by filament + :return: x,y for a line + + .. 
minigallery:: py_eddy_tracker.GridDataset.filament + """ + u, v, m = self.uv_for_advection(u_name, v_name, **kw) + x, y = x.copy(), y.copy() + nb = x.shape[0] + filament_size_ = filament_size + 1 + f_x = empty(nb * filament_size_, dtype="f4") + f_y = empty(nb * filament_size_, dtype="f4") + f_x[:] = nan + f_y[:] = nan + f_x[::filament_size_] = x + f_y[::filament_size_] = y + mp = isnan(x) + isnan(y) + advect_ = advect_rk4 if rk4 else advect + while True: + # Shift position + f_x[1:] = f_x[:-1] + f_y[1:] = f_y[:-1] + # Remove last position + f_x[filament_size::filament_size_] = nan + f_y[filament_size::filament_size_] = nan + advect_(self.x_c, self.y_c, u, v, m, x, y, mp, nb_step) + f_x[::filament_size_] = x + f_y[::filament_size_] = y + yield f_x, f_y + + +@njit(cache=True) +def advect_rk4(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): + # Grid coordinates + x_ref, y_ref = x_g[0], y_g[0] + x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + nb_x_ = x_g.size + nb_x = nb_x_ if is_circular else 0 + # cache + i_cache, j_cache = -1000000, -1000000 + masked = False + u00, u01, u10, u11 = 0.0, 0.0, 0.0, 0.0 + v00, v01, v10, v11 = 0.0, 0.0, 0.0, 0.0 + # On each particle + for i in prange(x.size): + # If particle is not valid => continue + if m[i]: + continue + x_, y_ = x[i], y[i] + # Iterate on whole steps + for _ in range(nb_step): + # k1, slope at origin + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x_, y_, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + ii_, jj_, u_g, v_g, m_g, nb_x + ) + # The 3 following could be in cache operation but this one must be tested in any case + if masked: + x_, y_ = nan, nan + m[i] = True + break + u1, v1 = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # k2, slope at middle with first guess position + x1, y1 = x_ + u1 * 0.5, y_ + v1 * 0.5 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x1, y1, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + ii_, jj_, u_g, v_g, m_g, nb_x + ) + if masked: + x_, y_ = nan, nan + m[i] = True + break + u2, v2 = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # k3, slope at middle with updated guess position + x2, y2 = x_ + u2 * 0.5, y_ + v2 * 0.5 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x2, y2, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + ii_, jj_, u_g, v_g, m_g, nb_x + ) + if masked: + x_, y_ = nan, nan + m[i] = True + break + u3, v3 = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # k4, slope at end with updated guess position + x3, y3 = x_ + u3, y_ + v3 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x3, y3, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + ii_, jj_, u_g, v_g, m_g, nb_x + ) + if masked: + x_, y_ = nan, nan + m[i] = True + break + u4, v4 = 
interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # RK4 compute + dx = (u1 + 2 * u2 + 2 * u3 + u4) / 6 + dy = (v1 + 2 * v2 + 2 * v3 + v4) / 6 + # Compute new x,y + x_ += dx + y_ += dy + x[i] = x_ + y[i] = y_ + + +@njit(cache=True) +def advect(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): + # Grid coordinates + x_ref, y_ref = x_g[0], y_g[0] + x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + nb_x_ = x_g.size + nb_x = nb_x_ if is_circular else 0 + # Indexes which should be never exist + i0_old, j0_old = -100000, -100000 + masked = False + u00, u01, u10, u11 = 0.0, 0.0, 0.0, 0.0 + v00, v01, v10, v11 = 0.0, 0.0, 0.0, 0.0 + # On each particule + for i in prange(x.size): + # If particule is not valid => continue + if m[i]: + continue + # Iterate on whole steps + for _ in range(nb_step): + i0, j0, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x[i], y[i], nb_x + ) + # corners are the same, need only a new xd and yd + if i0 != i0_old or j0 != j0_old: + # Need to be stored only on change + i0_old, j0_old = i0, j0 + if not is_circular and (i0 < 0 or i0 > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + i0, j0, u_g, v_g, m_g, nb_x + ) + if masked: + x[i], y[i] = nan, nan + m[i] = True + break + u, v = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # Compute new x,y + x[i] += u + y[i] += v @njit(cache=True, fastmath=True) def compute_pixel_path(x0, y0, x1, y1, x_ori, y_ori, x_step, y_step, nb_x): - """Give a series of index which describe the path between to position - """ + """Give a serie of indexes describing the path between two positions""" # index nx = x0.shape[0] i_x0 = empty(nx, dtype=numba_types.int_) i_x1 = empty(nx, dtype=numba_types.int_) i_y0 = empty(nx, dtype=numba_types.int_) i_y1 = empty(nx, dtype=numba_types.int_) - # Because round_ is not accepted with array in numba for i in range(nx): - i_x0[i] = round_(((x0[i] - x_ori) % 360) / x_step) - i_x1[i] = round_(((x1[i] - x_ori) % 360) / x_step) - i_y0[i] = round_((y0[i] - y_ori) / y_step) - i_y1[i] = round_((y1[i] - y_ori) / y_step) + i_x0[i] = np.round(((x0[i] - x_ori) % 360) / x_step) + i_x1[i] = np.round(((x1[i] - x_ori) % 360) / x_step) + i_y0[i] = np.round((y0[i] - y_ori) / y_step) + i_y1[i] = np.round((y1[i] - y_ori) / y_step) # Delta index of x d_x = i_x1 - i_x0 d_x = (d_x + nb_x // 2) % nb_x - (nb_x // 2) i_x1 = i_x0 + d_x # Delta index of y d_y = i_y1 - i_y0 - # max and abs sum doesn't work on array? + # max and abs sum do not work on array? 
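Stepping back from the two kernels above (`advect_rk4` and its plain Euler counterpart `advect`): per particle and per step, the RK4 variant samples the velocity at the origin, at two midpoints and at the endpoint, then combines the four slopes with weights 1-2-2-1. The following bare-bones sketch shows that update on a steady analytic field, with the grid lookup, wrapping and masking of the real kernels stripped out; the velocity field and numbers are purely illustrative.

```python
# Conceptual sketch of one RK4 advection step, as performed per particle and
# per step in ``advect_rk4`` above. The analytic field is made up; the real
# kernel interpolates u/v from the grid and handles masks and wrapping.
import numpy as np

def velocity(x, y):
    """Steady solid-body rotation around (180, 0), in degrees per step."""
    return -0.01 * y, 0.01 * (x - 180.0)

def rk4_step(x, y):
    u1, v1 = velocity(x, y)                        # k1: slope at origin
    u2, v2 = velocity(x + u1 * 0.5, y + v1 * 0.5)  # k2: slope at first midpoint
    u3, v3 = velocity(x + u2 * 0.5, y + v2 * 0.5)  # k3: slope at second midpoint
    u4, v4 = velocity(x + u3, y + v3)              # k4: slope at endpoint
    dx = (u1 + 2 * u2 + 2 * u3 + u4) / 6
    dy = (v1 + 2 * v2 + 2 * v3 + v4) / 6
    return x + dx, y + dy

x, y = 181.0, 1.0
for _ in range(100):
    x, y = rk4_step(x, y)
print(round(x, 3), round(y, 3))  # particle turns around (180, 0) instead of drifting off
```

The numba kernels perform the same arithmetic but re-interpolate u/v bilinearly at each of the four sample positions and flag the particle as soon as any surrounding grid corner is masked.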
d_max = empty(nx, dtype=numba_types.int32) nb_value = 0 for i in range(nx): d_max[i] = max(abs(d_x[i]), abs(d_y[i])) - # Compute number of pixel which we go trought + # Compute number of pixel we go trought nb_value += d_max[i] + 1 - # Create an empty array to store value of pixel across the travel + # Create an empty array to store value of pixel across the path i_g = empty(nb_value, dtype=numba_types.int32) j_g = empty(nb_value, dtype=numba_types.int32) # Index to determine the position in the global array ii = 0 - # Iteration on each travel + # Iteration on each path for i, delta in enumerate(d_max): - # If the travel don't cross multiple pixel + # If the path doesn't cross multiple pixels if delta == 0: i_g[ii : ii + delta + 1] = i_x0[i] j_g[ii : ii + delta + 1] = i_y0[i] @@ -1826,7 +2292,7 @@ def compute_pixel_path(x0, y0, x1, y1, x_ori, y_ori, x_step, y_step, nb_x): sup = -1 if d_x[i] < 0 else 1 i_g[ii : ii + delta + 1] = arange(i_x0[i], i_x1[i] + sup, sup) j_g[ii : ii + delta + 1] = i_y0[i] - # In case of multiple direction + # In case of multiple directions else: a = (i_x1[i] - i_x0[i]) / float(i_y1[i] - i_y0[i]) if abs(d_x[i]) >= abs(d_y[i]): @@ -1845,23 +2311,693 @@ def compute_pixel_path(x0, y0, x1, y1, x_ori, y_ori, x_step, y_step, nb_x): @njit(cache=True) -def bbox_indice_regular(vertices, x0, y0, xstep, ystep, N, circular, x_size): - lon, lat = vertices[:, 0], vertices[:, 1] - lon_min, lon_max = lon.min(), lon.max() - lat_min, lat_max = lat.min(), lat.max() - i_x0, i_y0 = _nearest_grd_indice(lon_min, lat_min, x0, y0, xstep, ystep) - i_x1, i_y1 = _nearest_grd_indice(lon_max, lat_max, x0, y0, xstep, ystep) - if circular: - slice_x = (i_x0 - N) % x_size, (i_x1 + N + 1) % x_size - else: - slice_x = max(i_x0 - N, 0), i_x1 + N + 1 - slice_y = i_y0 - N, i_y1 + N + 1 - return slice_x, slice_y +def has_masked_value(grid, i_x, i_y): + for i, j in zip(i_x, i_y): + if grid[i, j]: + return True + return False + + +@njit(cache=True) +def has_value(grid, i_x, i_y, value, below=False): + for i, j in zip(i_x, i_y): + if below: + if grid[i, j] < value: + return True + else: + if grid[i, j] > value: + return True + return False + + +class GridCollection: + def __init__(self): + self.datasets = list() + + @classmethod + def from_netcdf_cube(cls, filename, x_name, y_name, t_name, heigth=None, **kwargs): + new = cls() + with Dataset(filename) as h: + for i, t in enumerate(h.variables[t_name][:]): + d = RegularGridDataset( + filename, x_name, y_name, indexs={t_name: i}, **kwargs + ) + if heigth is not None: + d.add_uv(heigth) + new.datasets.append((t, d)) + return new + + @classmethod + def from_netcdf_list( + cls, filenames, t, x_name, y_name, indexs=None, heigth=None, **kwargs + ): + new = cls() + for i, _t in enumerate(t): + filename = filenames[i] + logger.debug(f"load file {i:02d}/{len(t)} t={_t} : {filename}") + d = RegularGridDataset(filename, x_name, y_name, indexs=indexs, **kwargs) + if heigth is not None: + d.add_uv(heigth) + new.datasets.append((_t, d)) + return new + + @property + def are_loaded(self): + return ~array([d.dimensions is None for _, d in self.datasets]) + + def __repr__(self): + nb_dataset = len(self.datasets) + return f"{self.are_loaded.sum()}/{nb_dataset} datasets loaded" + + def shift_files(self, t, filename, heigth=None, **rgd_kwargs): + """Add next file to the list and remove the oldest""" + + self.datasets = self.datasets[1:] + + d = RegularGridDataset(filename, **rgd_kwargs) + if heigth is not None: + d.add_uv(heigth) + self.datasets.append((t, d)) + 
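In practice a `GridCollection` is built once from a stack of files and, for long runs, kept small by shifting the window with `shift_files`. A hypothetical sketch follows; the file names, variable names and import path are assumptions rather than values taken from this patch, `heigth` keeps the spelling used by the constructors above, and the `x_name`/`y_name` keywords passed to `shift_files` assume the `RegularGridDataset` constructor accepts them by name, as its positional use elsewhere in this file suggests.

```python
# Hypothetical rolling use of GridCollection: keep a short window of grids in
# memory and shift it forward one file at a time. Paths, variable names and
# the import path are placeholders, not real data.
from py_eddy_tracker.dataset.grid import GridCollection

days = list(range(0, 5))
filenames = [f"dt_global_adt_day{d:03d}.nc" for d in days]

cube = GridCollection.from_netcdf_list(
    filenames, days, "longitude", "latitude",
    heigth="adt",  # when given, u/v are derived from this height field via add_uv
)
for day in range(5, 30):
    # Drop the oldest dataset and append the next one, recomputing u/v from "adt"
    cube.shift_files(day, f"dt_global_adt_day{day:03d}.nc",
                     x_name="longitude", y_name="latitude", heigth="adt")
    print(cube.period)  # (day - 4, day): a 5-grid window sliding through time
```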
logger.debug(f"shift and adding i={len(self.datasets)} t={t} : {filename}") + + def interp(self, grid_name, t, lons, lats, method="bilinear"): + """ + Compute z over lons, lats + + :param str grid_name: Grid to be interpolated + :param float, t: time for interpolation + :param lons: new x + :param lats: new y + :param str method: Could be 'bilinear' or 'nearest' + + :return: new z + """ + # FIXME: we do assumption on time step + t0 = int(t) + t1 = t0 + 1 + h0, h1 = self[t0], self[t1] + g0, g1 = h0.grid(grid_name), h1.grid(grid_name) + m0, m1 = h0.get_mask(g0), h0.get_mask(g1) + kw = dict(x=lons, y=lats, nearest=method == "nearest") + v0 = interp2d_geo(h0.x_c, h0.y_c, g0, m0, **kw) + v1 = interp2d_geo(h1.x_c, h1.y_c, g1, m1, **kw) + w = (t - t0) / (t1 - t0) + return v1 * w + v0 * (1 - w) + + def __iter__(self): + for _, d in self.datasets: + yield d + + @property + def time(self): + return array([t for t, _ in self.datasets]) + + @property + def period(self): + t = self.time + return t.min(), t.max() + + def __getitem__(self, item): + for t, d in self.datasets: + if t == item: + d.populate() + return d + raise KeyError(item) + + def filament( + self, + x, + y, + u_name, + v_name, + t_init, + nb_step=10, + time_step=600, + filament_size=6, + rk4=True, + **kw, + ): + """ + Produce filament with concatenation of advection + + :param array x: Longitude of obs to move + :param array y: Latitude of obs to move + :param str,array u_name: U field to advect obs + :param str,array v_name: V field to advect obs + :param int nb_step: Number of iteration before to release data + :param int time_step: Number of second for each advection + :param int filament_size: Number of point by filament + :return: x,y for a line + + .. minigallery:: py_eddy_tracker.GridCollection.filament + """ + x, y = x.copy(), y.copy() + nb = x.shape[0] + filament_size_ = filament_size + 1 + f_x = empty(nb * filament_size_, dtype="f4") + f_y = empty(nb * filament_size_, dtype="f4") + f_x[:] = nan + f_y[:] = nan + f_x[::filament_size_] = x + f_y[::filament_size_] = y + + backward = kw.get("backward", False) + if backward: + generator = self.get_previous_time_step(t_init) + dt = -nb_step * time_step + t_step = -time_step + else: + generator = self.get_next_time_step(t_init) + dt = nb_step * time_step + t_step = time_step + t0, d0 = generator.__next__() + u0, v0, m0 = d0.uv_for_advection(u_name, v_name, time_step, **kw) + t1, d1 = generator.__next__() + u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw) + t0 = t0 * 86400 + t1 = t1 * 86400 + t = t_init * 86400 + mp = isnan(x) + isnan(y) + advect_ = advect_t_rk4 if rk4 else advect_t + while True: + # Shift position + f_x[1:] = f_x[:-1] + f_y[1:] = f_y[:-1] + # Remove last position + f_x[filament_size::filament_size_] = nan + f_y[filament_size::filament_size_] = nan + + if (backward and t <= t1) or (not backward and t >= t1): + t0, u0, v0, m0 = t1, u1, v1, m1 + t1, d1 = generator.__next__() + t1 = t1 * 86400 + u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw) + w = 1 - (arange(t, t + dt, t_step) - t0) / (t1 - t0) + half_w = t_step / 2.0 / (t1 - t0) + advect_(d0.x_c, d0.y_c, u0, v0, m0, u1, v1, m1, x, y, mp, w, half_w=half_w) + f_x[::filament_size_] = x + f_y[::filament_size_] = y + t += dt + yield t, f_x, f_y + + def reset_grids(self, N=None): + if N is not None: + m = self.are_loaded + if m.sum() > N: + for i in where(m)[0]: + self.datasets[i][1].clean() + + def advect( + self, + x, + y, + t_init, + mask_particule=None, + nb_step=10, + time_step=600, + 
rk4=True, + reset_grid=None, + **kw, + ): + """ + At each call it will update position in place with u & v field + + :param array x: Longitude of obs to move + :param array y: Latitude of obs to move + :param float t_init: time to start advection + :param array,None mask_particule: advect only i mask is True + :param int nb_step: Number of iteration before to release data + :param int time_step: Number of second for each advection + :param bool rk4: Use rk4 algorithm instead of finite difference + :param int reset_grid: Delete all loaded data in cube if there are more than N grid loaded + + :return: t,x,y position + + .. minigallery:: py_eddy_tracker.GridCollection.advect + """ + self.reset_grids(reset_grid) + backward = kw.get("backward", False) + if backward: + generator = self.get_previous_time_step(t_init) + dt = -nb_step * time_step + t_step = -time_step + else: + generator = self.get_next_time_step(t_init) + dt = nb_step * time_step + t_step = time_step + t0, d0 = generator.__next__() + u0, v0, m0 = d0.uv_for_advection(time_step=time_step, **kw) + t1, d1 = generator.__next__() + u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw) + t0 = t0 * 86400 + t1 = t1 * 86400 + t = t_init * 86400 + advect_ = advect_t_rk4 if rk4 else advect_t + if mask_particule is None: + mask_particule = isnan(x) + isnan(y) + else: + mask_particule += isnan(x) + isnan(y) + while True: + logger.debug(f"advect : t={t/86400}") + if (backward and t <= t1) or (not backward and t >= t1): + t0, u0, v0, m0 = t1, u1, v1, m1 + t1, d1 = generator.__next__() + t1 = t1 * 86400 + u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw) + w = 1 - (arange(t, t + dt, t_step) - t0) / (t1 - t0) + half_w = t_step / 2.0 / (t1 - t0) + advect_( + d0.x_c, + d0.y_c, + u0, + v0, + m0, + u1, + v1, + m1, + x, + y, + mask_particule, + w, + half_w=half_w, + ) + t += dt + yield t, x, y + + def get_next_time_step(self, t_init): + for i, (t, dataset) in enumerate(self.datasets): + if t < t_init: + continue + dataset.populate() + logger.debug(f"i={i}, t={t}, dataset={dataset}") + yield t, dataset + + def get_previous_time_step(self, t_init): + i = len(self.datasets) + for t, dataset in reversed(self.datasets): + i -= 1 + if t > t_init: + continue + dataset.populate() + logger.debug(f"i={i}, t={t}, dataset={dataset}") + yield t, dataset + + def path(self, x0, y0, *args, nb_time=2, **kwargs): + """ + At each call it will update position in place with u & v field + + :param array x0: Longitude of obs to move + :param array y0: Latitude of obs to move + :param int nb_time: Number of iteration for particle + :param dict kwargs: look at :py:meth:`GridCollection.advect` + + :return: t,x,y + + .. 
minigallery:: py_eddy_tracker.GridCollection.path + """ + particles = self.advect(x0.copy(), y0.copy(), *args, **kwargs) + t = empty(nb_time + 1, dtype="f8") + x = empty((nb_time + 1, x0.size), dtype=x0.dtype) + y = empty(x.shape, dtype=y0.dtype) + t[0], x[0], y[0] = kwargs.get("t_init"), x0, y0 + for i in range(nb_time): + t[i + 1], x[i + 1], y[i + 1] = particles.__next__() + return t, x, y + + +@njit(cache=True) +def advect_t(x_g, y_g, u_g0, v_g0, m_g0, u_g1, v_g1, m_g1, x, y, m, weigths, half_w=0): + # Grid coordinates + x_ref, y_ref = x_g[0], y_g[0] + x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + nb_x_ = x_g.size + nb_x = nb_x_ if is_circular else 0 + # Indexes that should never exist + i0_old, j0_old = -100000, -100000 + m0, m1 = False, False + u000, u001, u010, u011 = 0.0, 0.0, 0.0, 0.0 + v000, v001, v010, v011 = 0.0, 0.0, 0.0, 0.0 + u100, u101, u110, u111 = 0.0, 0.0, 0.0, 0.0 + v100, v101, v110, v111 = 0.0, 0.0, 0.0, 0.0 + # On each particle + for i in prange(x.size): + # If particle is not valid => continue + if m[i]: + continue + # Iterate on whole steps + for w in weigths: + i0, j0, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x[i], y[i], nb_x + ) + if i0 != i0_old or j0 != j0_old: + # Need to be stored only on change + i0_old, j0_old = i0, j0 + if not is_circular and (i0 < 0 or i0 > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + i0, j0, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + i0, j0, u_g1, v_g1, m_g1, nb_x + ) + if m0 or m1: + x[i], y[i] = nan, nan + m[i] = True + break + # Compute distance + xd_i, yd_i = 1 - xd, 1 - yd + # Compute new x,y + dx0 = (u000 * xd_i + u010 * xd) * yd_i + (u001 * xd_i + u011 * xd) * yd + dx1 = (u100 * xd_i + u110 * xd) * yd_i + (u101 * xd_i + u111 * xd) * yd + dy0 = (v000 * xd_i + v010 * xd) * yd_i + (v001 * xd_i + v011 * xd) * yd + dy1 = (v100 * xd_i + v110 * xd) * yd_i + (v101 * xd_i + v111 * xd) * yd + x[i] += dx0 * w + dx1 * (1 - w) + y[i] += dy0 * w + dy1 * (1 - w) @njit(cache=True, fastmath=True) -def _nearest_grd_indice(x, y, x0, y0, xstep, ystep): - return ( - numba_types.int32(round(((x - x0[0]) % 360.0) / xstep)), - numba_types.int32(round((y - y0[0]) / ystep)), - ) +def get_uv_quad(i0, j0, u, v, m, nb_x=0): + """ + Return u/v for (i0, j0), (i1, j0), (i0, j1), (i1, j1) + + :param int i0: indexes of longitude + :param int j0: indexes of latitude + :param array[float] u: current along x axis + :param array[float] v: current along y axis + :param array[bool] m: flag to know if position is valid + :param int nb_x: If different of 0 we check if wrapping is needed + + :return: if cell is valid 4 u, 4 v + :rtype: bool,float,float,float,float,float,float,float,float + """ + i1, j1 = i0 + 1, j0 + 1 + if nb_x != 0: + i1 %= nb_x + i_max, j_max = m.shape + + if i1 >= i_max or j1 >= j_max: + return True, nan, nan, nan, nan, nan, nan, nan, nan + + if m[i0, j0] or m[i0, j1] or m[i1, j0] or m[i1, j1]: + return True, nan, nan, nan, nan, nan, nan, nan, nan + # Extract value for u and v + u00, u01, u10, u11 = u[i0, j0], u[i0, j1], u[i1, j0], u[i1, j1] + v00, v01, v10, v11 = v[i0, j0], v[i0, j1], v[i1, j0], v[i1, j1] + return False, u00, u01, u10, u11, v00, v01, v10, v11 + + +@njit(cache=True, fastmath=True) +def get_grid_indices(x0, y0, x_step, y_step, x, y, nb_x=0): + """ + Return grid indexes and weight + + :param float x0: first longitude + :param float 
y0: first latitude + :param float x_step: longitude grid step + :param float y_step: latitude grid step + :param float x: longitude to interpolate + :param float y: latitude to interpolate + :param int nb_x: If different of 0 we check if wrapping is needed + + :return: indexes and weight + :rtype: int,int,float,float + """ + i, j = (x - x0) / x_step, (y - y0) / y_step + i0, j0 = int(floor(i)), int(floor(j)) + xd, yd = i - i0, j - j0 + if nb_x != 0: + i0 %= nb_x + return i0, j0, xd, yd + + +@njit(cache=True, fastmath=True) +def interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11): + """ + Return u/v interpolated in cell + + :param float xd: x weight + :param float yd: y weight + :param float u00: u lower left + :param float u01: u upper left + :param float u10: u lower right + :param float u11: u upper right + :param float v00: v lower left + :param float v01: v upper left + :param float v10: v lower right + :param float v11: v upper right + """ + xd_i, yd_i = 1 - xd, 1 - yd + u = (u00 * xd_i + u10 * xd) * yd_i + (u01 * xd_i + u11 * xd) * yd + v = (v00 * xd_i + v10 * xd) * yd_i + (v01 * xd_i + v11 * xd) * yd + return u, v + + +@njit(cache=True, fastmath=True) +def advect_t_rk4( + x_g, y_g, u_g0, v_g0, m_g0, u_g1, v_g1, m_g1, x, y, m, weigths, half_w +): + # Grid coordinates + x_ref, y_ref = x_g[0], y_g[0] + x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + nb_x_ = x_g.size + nb_x = nb_x_ if is_circular else 0 + # cache + i_cache, j_cache = -1000000, -1000000 + m0, m1 = False, False + u000, u001, u010, u011 = 0.0, 0.0, 0.0, 0.0 + v000, v001, v010, v011 = 0.0, 0.0, 0.0, 0.0 + u100, u101, u110, u111 = 0.0, 0.0, 0.0, 0.0 + v100, v101, v110, v111 = 0.0, 0.0, 0.0, 0.0 + # On each particle + for i in prange(x.size): + # If particle is not valid => continue + if m[i]: + continue + x_, y_ = x[i], y[i] + # Iterate on whole steps + for w in weigths: + # k1, slope at origin + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x_, y_, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + ii_, jj_, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + ii_, jj_, u_g1, v_g1, m_g1, nb_x + ) + # The 3 following could be in cache operation but this one must be tested in any case + if m0 or m1: + x_, y_ = nan, nan + m[i] = True + break + u0_, v0_ = interp_uv(xd, yd, u000, u001, u010, u011, v000, v001, v010, v011) + u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) + u1, v1 = u0_ * w + u1_ * (1 - w), v0_ * w + v1_ * (1 - w) + # k2, slope at middle with first guess position + x1, y1 = x_ + u1 * 0.5, y_ + v1 * 0.5 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x1, y1, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + ii_, jj_, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + ii_, jj_, u_g1, v_g1, m_g1, nb_x + ) + if m0 or m1: + x_, y_ = nan, nan + m[i] = True + break + u0_, v0_ = interp_uv(xd, yd, u000, u001, u010, u011, v000, v001, v010, v011) + u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) + w_ = 
w - half_w + u2, v2 = u0_ * w_ + u1_ * (1 - w_), v0_ * w_ + v1_ * (1 - w_) + # k3, slope at middle with updated guess position + x2, y2 = x_ + u2 * 0.5, y_ + v2 * 0.5 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x2, y2, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + ii_, jj_, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + ii_, jj_, u_g1, v_g1, m_g1, nb_x + ) + if m0 or m1: + x_, y_ = nan, nan + m[i] = True + break + u0_, v0_ = interp_uv(xd, yd, u000, u001, u010, u011, v000, v001, v010, v011) + u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) + u3, v3 = u0_ * w_ + u1_ * (1 - w_), v0_ * w_ + v1_ * (1 - w_) + # k4, slope at end with updated guess position + x3, y3 = x_ + u3, y_ + v3 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x3, y3, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + ii_, jj_, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + ii_, jj_, u_g1, v_g1, m_g1, nb_x + ) + if m0 or m1: + x_, y_ = nan, nan + m[i] = True + break + u0_, v0_ = interp_uv(xd, yd, u000, u001, u010, u011, v000, v001, v010, v011) + u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) + w_ -= half_w + u4, v4 = u0_ * w_ + u1_ * (1 - w_), v0_ * w_ + v1_ * (1 - w_) + # RK4 compute + dx = (u1 + 2 * u2 + 2 * u3 + u4) / 6 + dy = (v1 + 2 * v2 + 2 * v3 + v4) / 6 + x_ += dx + y_ += dy + x[i], y[i] = x_, y_ + + +@njit( + [ + "Tuple((f8[:,:],b1[:,:]))(f8[:],f8[:],f8[:,:],b1[:,:],f8,b1,i1)", + "Tuple((f4[:,:],b1[:,:]))(f8[:],f8[:],f4[:,:],b1[:,:],f8,b1,i1)", + ], + cache=True, + fastmath=True, +) +def compute_stencil(x, y, h, m, earth_radius, vertical=False, stencil_halfwidth=4): + """ + Compute stencil on RegularGrid + + :param array x: longitude coordinates + :param array y: latitude coordinates + :param array h: 2D array to derivate + :param array m: mask associated to h to know where are invalid data + :param float earth_radius: Earth radius in m + :param bool vertical: if True stencil will be vertical (along y) + :param int stencil_halfwidth: from 1 to 4 to specify maximal kernel usable + + + stencil_halfwidth: + + - (1) : + + - (-1, 1, 0) + - (0, -1, 1) + - (-1, 0, 1) / 2 + + - (2) : (1, -8, 0, 8, 1) / 12 + - (3) : (-1, 9, -45, 0, 45, -9, 1) / 60 + - (4) : (3, -32, 168, -672, 0, 672, -168, 32, 3) / 840 + """ + if vertical: + # If vertical we transpose matrix and inverse coordinates + h = h.T + m = m.T + x, y = y, x + shape = h.shape + nb_x, nb_y = shape + # Out array + m_out = empty(shape, dtype=numba_types.bool_) + grad = empty(shape, dtype=h.dtype) + # Distance step in degrees + d_step = x[1] - x[0] + if vertical: + is_circular = False + else: + # Test if matrix is circular + is_circular = abs(x[-1] % 360 - (x[0] - d_step) % 360) < 1e-5 + + # Compute caracteristic distance, constant when vertical + d_ = 360 / (d_step * pi * 2 * earth_radius) + for j in range(nb_y): + # Buffer of maximal size of stencil (9) + if is_circular: + h_3, h_2, h_1, h0 = h[-4, j], h[-3, j], h[-2, j], h[-1, j] + m_3, m_2, m_1, m0 = m[-4, j], m[-3, j], m[-2, j], m[-1, j] + else: + 
m_3, m_2, m_1, m0 = True, True, True, True + h1, h2, h3, h4 = h[0, j], h[1, j], h[2, j], h[3, j] + m1, m2, m3, m4 = m[0, j], m[1, j], m[2, j], m[3, j] + for i in range(nb_x): + # Roll value and only last + h_4, h_3, h_2, h_1, h0, h1, h2, h3 = h_3, h_2, h_1, h0, h1, h2, h3, h4 + m_4, m_3, m_2, m_1, m0, m1, m2, m3 = m_3, m_2, m_1, m0, m1, m2, m3, m4 + i_ = i + 4 + if i_ >= nb_x: + if is_circular: + i_ = i_ % nb_x + m4 = m[i_, j] + h4 = h[i_, j] + else: + # When we are out + m4 = False + else: + m4 = m[i_, j] + h4 = h[i_, j] + + # Current value not defined + if m0: + m_out[i, j] = True + continue + if not vertical: + # For each row we compute distance + d_ = 360 / (d_step * cos(deg2rad(y[j])) * pi * 2 * earth_radius) + if m1 ^ m_1: + # unbalanced kernel + if m_1: + grad[i, j] = (h1 - h0) * d_ + m_out[i, j] = False + continue + if m1: + grad[i, j] = (h0 - h_1) * d_ + m_out[i, j] = False + continue + continue + if m2 or m_2 or stencil_halfwidth == 1: + grad[i, j] = (h1 - h_1) / 2 * d_ + m_out[i, j] = False + continue + if m3 or m_3 or stencil_halfwidth == 2: + grad[i, j] = (h_2 - h2 + 8 * (h1 - h_1)) / 12 * d_ + m_out[i, j] = False + continue + if m4 or m_4 or stencil_halfwidth == 3: + grad[i, j] = (h3 - h_3 + 9 * (h_2 - h2) + 45 * (h1 - h_1)) / 60 * d_ + m_out[i, j] = False + continue + # If all values of buffer are available + grad[i, j] = ( + (3 * (h_4 - h4) + 32 * (h3 - h_3) + 168 * (h_2 - h2) + 672 * (h1 - h_1)) + / 840 + * d_ + ) + m_out[i, j] = False + if vertical: + return grad.T, m_out.T + else: + return grad, m_out diff --git a/src/py_eddy_tracker/eddy_feature.py b/src/py_eddy_tracker/eddy_feature.py index 0cf7adb5..8bc139ab 100644 --- a/src/py_eddy_tracker/eddy_feature.py +++ b/src/py_eddy_tracker/eddy_feature.py @@ -1,30 +1,27 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . - -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== +Class to compute Amplitude and average speed profile """ import logging -from numpy import empty, array, concatenate, ma, zeros, unique, round, ones, int_ + +from matplotlib.cm import get_cmap +from matplotlib.colors import Normalize from matplotlib.figure import Figure from numba import njit, types as numba_types +from numpy import ( + array, + concatenate, + digitize, + empty, + int_, + ma, + ones, + round, + unique, + zeros, +) + from .poly import winding_number_poly logger = logging.getLogger("pet") @@ -33,7 +30,7 @@ class Amplitude(object): """ Class to calculate *amplitude* and counts of *local maxima/minima* - within a closed region of a sea level anomaly field. + within a closed region of a sea surface height field. 
""" EPSILON = 1e-8 @@ -45,15 +42,38 @@ class Amplitude(object): "sla", "contour", "interval_min", + "interval_min_secondary", "amplitude", "mle", ) - def __init__(self, contour, contour_height, data, interval): + def __init__( + self, + contour, + contour_height, + data, + interval, + mle=1, + nb_step_min=2, + nb_step_to_be_mle=2, + ): + """ + Create amplitude object + + :param Contours contour: usefull class defined below + :param float contour_height: field value of the contour + :param array data: grid + :param float interval: step between two contours + :param int mle: maximum number of local extrema in contour + :param int nb_step_min: minimum number of intervals to consider the contour as an eddy + :param int nb_step_to_be_mle: number of intervals to be considered as another extrema + """ + # Height of the contour self.h_0 = contour_height # Step minimal to consider amplitude - self.interval_min = interval * 2 + self.interval_min = interval * nb_step_min + self.interval_min_secondary = interval * nb_step_to_be_mle # Indices of all pixels in contour self.contour = contour # Link on original grid (local view) or copy if it's on bound @@ -81,15 +101,15 @@ def __init__(self, contour, contour_height, data, interval): self.nb_pixel = i_x.shape[0] # Only pixel in contour + # FIXME : change sla by ssh as the grid can be adt? self.sla = data[contour.pixels_index] - # Amplitude which will be provide + # Amplitude which will be provided self.amplitude = 0 # Maximum local extrema accepted - self.mle = 1 + self.mle = mle def within_amplitude_limits(self): - """Need update - """ + """Need update""" return self.interval_min <= self.amplitude def all_pixels_below_h0(self, level): @@ -97,7 +117,7 @@ def all_pixels_below_h0(self, level): Check CSS11 criterion 1: The SSH values of all of the pixels are below a given SSH threshold for cyclonic eddies. 
""" - # In some case pixel value must be very near of contour bounds + # In some cases pixel value may be very close to the contour bounds if self.sla.mask.any() or ((self.sla.data - self.h_0) > self.EPSILON).any(): return False else: @@ -123,7 +143,8 @@ def all_pixels_below_h0(self, level): else: # Verify if several extrema are seriously below contour nb_real_extrema = ( - (level - self.grid_extract.data[lmi_i, lmi_j]) >= self.interval_min + (level - self.grid_extract.data[lmi_i, lmi_j]) + >= self.interval_min_secondary ).sum() if nb_real_extrema > self.mle: return False @@ -152,6 +173,7 @@ def all_pixels_above_h0(self, level): self.mle, -1, ) + # After we use grid.data because index are in contour and we check before than no pixel are hide nb = len(lmi_i) if nb == 0: logger.warning( @@ -165,7 +187,8 @@ def all_pixels_above_h0(self, level): else: # Verify if several extrema are seriously above contour nb_real_extrema = ( - (self.grid_extract.data[lmi_i, lmi_j] - level) >= self.interval_min + (self.grid_extract.data[lmi_i, lmi_j] - level) + >= self.interval_min_secondary ).sum() if nb_real_extrema > self.mle: return False @@ -270,10 +293,10 @@ class Contours(object): Attributes: contour: - A matplotlib contour object of high-pass filtered SLA + A matplotlib contour object of high-pass filtered SSH eddy: - A tracklist object holding the SLA data + A tracklist object holding the SSH data grd: A grid object @@ -383,7 +406,7 @@ def __init__(self, x, y, z, levels, wrap_x=False, keep_unclose=False): fig = Figure() ax = fig.add_subplot(111) if wrap_x: - logger.debug("wrapping activate to compute contour") + logger.debug("wrapping activated to compute contour") x = concatenate((x, x[:1] + 360)) z = ma.concatenate((z, z[:1])) logger.debug("X shape : %s", x.shape) @@ -409,8 +432,8 @@ def __init__(self, x, y, z, levels, wrap_x=False, keep_unclose=False): closed_contours = 0 # Count level and contour for i, collection in enumerate(self.contours.collections): - collection.get_nearest_path_bbox_contain_pt = lambda x, y, i=i: self.get_index_nearest_path_bbox_contain_pt( - i, x, y + collection.get_nearest_path_bbox_contain_pt = ( + lambda x, y, i=i: self.get_index_nearest_path_bbox_contain_pt(i, x, y) ) nb_level += 1 @@ -420,17 +443,14 @@ def __init__(self, x, y, z, levels, wrap_x=False, keep_unclose=False): # Contour with less vertices than 4 are popped if contour.vertices.shape[0] < 4: continue - if keep_unclose: - keep_path.append(contour) - continue # Remove unclosed path d_closed = ( (contour.vertices[0, 0] - contour.vertices[-1, 0]) ** 2 + (contour.vertices[0, 1] - contour.vertices[-1, 1]) ** 2 ) ** 0.5 - if d_closed > self.DELTA_SUP: + if d_closed > self.DELTA_SUP and not keep_unclose: continue - elif d_closed != 0: + elif d_closed != 0 and d_closed <= self.DELTA_SUP: # Repair almost closed contour if d_closed > self.DELTA_PREC: almost_closed_contours += 1 @@ -451,6 +471,7 @@ def __init__(self, x, y, z, levels, wrap_x=False, keep_unclose=False): collection._paths = keep_path for contour in collection.get_paths(): contour.used = False + contour.reject = 0 nb_contour += 1 nb_pt += contour.vertices.shape[0] logger.info( @@ -555,12 +576,51 @@ def display( only_used=False, only_unused=False, only_contain_eddies=False, + display_criterion=False, + field=None, + bins=None, + cmap="Spectral_r", **kwargs ): + """ + Display contour + + :param matplotlib.axes.Axes ax: + :param int step: display only contour every step + :param bool only_used: display only contour used in an eddy + :param bool only_unused: 
display only contour unused in an eddy + :param bool only_contain_eddies: display only contour which enclosed an eddiy + :param bool display_criterion: + display only unused contour with criterion color + + 0. - Accepted (green) + 1. - Reject for shape error (red) + 2. - Masked value in contour (blue) + 3. - Under or over pixel limit bound (black) + 4. - Amplitude criterion (yellow) + :param str field: + Must be 'shape_error', 'x', 'y' or 'radius'. + If defined display_criterion is not use. + bins argument must be defined + :param array bins: bins used to colorize contour + :param str cmap: Name of cmap for field display + :param dict kwargs: look at :py:meth:`matplotlib.collections.LineCollection` + + .. minigallery:: py_eddy_tracker.Contours.display + """ from matplotlib.collections import LineCollection + overide_color = display_criterion or field is not None + if display_criterion: + paths = {0: list(), 1: list(), 2: list(), 3: list(), 4: list()} + elif field is not None: + paths = dict() + for i in range(len(bins)): + paths[i] = list() + paths[i + 1] = list() for j, collection in enumerate(self.contours.collections[::step]): - paths = list() + if not overide_color: + paths = list() for i in collection.get_paths(): if only_used and not i.used: continue @@ -568,21 +628,67 @@ def display( continue elif only_contain_eddies and not i.contain_eddies: continue - paths.append(i.vertices) + if display_criterion: + paths[i.reject].append(i.vertices) + elif field is not None: + x, y, radius, shape_error = i.fit_circle() + if field == "shape_error": + i_ = digitize(shape_error, bins) + elif field == "radius": + i_ = digitize(radius, bins) + elif field == "x": + i_ = digitize(x, bins) + elif field == "y": + i_ = digitize(y, bins) + paths[i_].append(i.vertices) + else: + paths.append(i.vertices) local_kwargs = kwargs.copy() if "color" not in kwargs: - local_kwargs["color"] = collection.get_color() + local_kwargs["color"] = collection.get_edgecolor() local_kwargs.pop("label", None) elif j != 0: local_kwargs.pop("label", None) - ax.add_collection(LineCollection(paths, **local_kwargs)) - - if hasattr(self.contours, "_mins"): - ax.update_datalim([self.contours._mins, self.contours._maxs]) - ax.autoscale_view() + if not overide_color: + ax.add_collection(LineCollection(paths, **local_kwargs)) + if display_criterion: + colors = { + 0: "limegreen", + 1: "red", + 2: "mediumblue", + 3: "black", + 4: "gold", + } + for k, v in paths.items(): + local_kwargs = kwargs.copy() + local_kwargs.pop("label", None) + local_kwargs["colors"] = colors[k] + ax.add_collection(LineCollection(v, **local_kwargs)) + elif field is not None: + nb_bins = len(bins) - 1 + cmap = get_cmap(cmap, lut=nb_bins) + for k, v in paths.items(): + local_kwargs = kwargs.copy() + local_kwargs.pop("label", None) + if k == 0: + local_kwargs["colors"] = cmap(0.0) + elif k > nb_bins: + local_kwargs["colors"] = cmap(1.0) + else: + local_kwargs["colors"] = cmap((k - 1.0) / nb_bins) + mappable = LineCollection(v, **local_kwargs) + ax.add_collection(mappable) + mappable.cmap = cmap + mappable.norm = Normalize(vmin=bins[0], vmax=bins[-1]) + # TODO : need to create an object with all collections + return mappable + else: + if hasattr(self.contours, "_mins"): + ax.update_datalim([self.contours._mins, self.contours._maxs]) + ax.autoscale_view() def label_contour_unused_which_contain_eddies(self, eddies): - """Select contour which contain several eddies""" + """Select contour containing several eddies""" if eddies.sign_type == 1: # anticyclonic sl = 
slice(None, -1) @@ -634,8 +740,7 @@ def index_from_nearest_path_with_pt_in_bbox_( xpt, ypt, ): - """Get index from nearest path in edge bbox contain pt - """ + """Get index from nearest path in edge bbox contain pt""" # Nb contour in level if nb_c_per_l[level_index] == 0: return -1 @@ -678,7 +783,7 @@ def index_from_nearest_path_with_pt_in_bbox_( d_x = x_value[i_elt_pt] - xpt_ if abs(d_x) > 180: d_x = (d_x + 180) % 360 - 180 - dist = d_x ** 2 + (y_value[i_elt_pt] - ypt) ** 2 + dist = d_x**2 + (y_value[i_elt_pt] - ypt) ** 2 if dist < dist_ref: dist_ref = dist i_ref = i_elt_c diff --git a/src/py_eddy_tracker/featured_tracking/area_tracker.py b/src/py_eddy_tracker/featured_tracking/area_tracker.py index 8a1cfcdc..9e676fc1 100644 --- a/src/py_eddy_tracker/featured_tracking/area_tracker.py +++ b/src/py_eddy_tracker/featured_tracking/area_tracker.py @@ -1,27 +1,59 @@ -from ..observations.observation import EddiesObservations as Model -from numpy import ma import logging +from numba import njit +from numpy import empty, ma, ones + +from ..observations.observation import EddiesObservations as Model + logger = logging.getLogger("pet") class AreaTracker(Model): + """ + Area Tracker will used overlap to track eddy. + + This tracking will used :py:meth:`~py_eddy_tracker.observations.observation.EddiesObservations.match` method + to get a similarity index, which could be between [0:1]. + You could setup this class with `cmin` option to set a minimal value to valid an association. + """ + + __slots__ = ("cmin",) + + def __init__(self, *args, cmin=0.2, **kwargs): + super().__init__(*args, **kwargs) + self.cmin = cmin + + def merge(self, *args, **kwargs): + eddies = super().merge(*args, **kwargs) + eddies.cmin = self.cmin + return eddies + + @classmethod + def needed_variable(cls): + vars = ["longitude", "latitude"] + vars.extend(cls.intern(False, public_label=True)) + return vars + def tracking(self, other): + """ + Core method to track + """ shape = (self.shape[0], other.shape[0]) i, j, c = self.match(other, intern=False) - cost_mat = ma.empty(shape, dtype="f4") - cost_mat.mask = ma.ones(shape, dtype="bool") - m = c > .2 - i, j, c = i[m], j[m], c[m] - cost_mat[i, j] = 1 - c + cost_mat = ma.array(empty(shape, dtype="f4"), mask=ones(shape, dtype="bool")) + mask_cmin(i, j, c, self.cmin, cost_mat.data, cost_mat.mask) i_self, i_other = self.solve_function(cost_mat) i_self, i_other = self.post_process_link(other, i_self, i_other) logger.debug("%d matched with previous", i_self.shape[0]) return i_self, i_other, cost_mat[i_self, i_other] - def propagate(self, previous_obs, current_obs, obs_to_extend, dead_track, nb_next, model): - virtual = super().propagate(previous_obs, current_obs, obs_to_extend, dead_track, nb_next, model) + def propagate( + self, previous_obs, current_obs, obs_to_extend, dead_track, nb_next, model + ): + virtual = super().propagate( + previous_obs, current_obs, obs_to_extend, dead_track, nb_next, model + ) nb_dead = len(previous_obs) nb_virtual_extend = nb_next - nb_dead for key in model.elements: @@ -31,3 +63,13 @@ def propagate(self, previous_obs, current_obs, obs_to_extend, dead_track, nb_nex if nb_virtual_extend > 0: virtual[key][nb_dead:] = obs_to_extend[key] return virtual + + +@njit(cache=True) +def mask_cmin(i, j, c, cmin, cost_mat, mask): + for k in range(c.shape[0]): + c_ = c[k] + if c_ > cmin: + i_, j_ = i[k], j[k] + cost_mat[i_, j_] = 1 - c_ + mask[i_, j_] = False diff --git a/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py 
b/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py index 908f4554..b0d4abfa 100644 --- a/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py +++ b/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py @@ -1,63 +1,60 @@ -from ..observations.observation import EddiesObservations as Model -from ..dataset.grid import RegularGridDataset -from numpy import where, bincount, ones, unique, bool_, arange -from numba import njit from os import path +from numba import njit +from numpy import arange, bincount, bool_, ones, unique, where + +from ..dataset.grid import RegularGridDataset +from ..observations.observation import EddiesObservations as Model + class CheltonTracker(Model): - GROUND = RegularGridDataset(path.join(path.dirname(__file__), '../data/mask_1_60.nc'), 'lon', 'lat') + __slots__ = tuple() + + GROUND = RegularGridDataset( + path.join(path.dirname(__file__), "../data/mask_1_60.nc"), "lon", "lat" + ) @staticmethod def cost_function(records_in, records_out, distance): - """We minimize on distance between two obs - """ + """We minimize on distance between two obs""" return distance def mask_function(self, other, distance): - """We mask link with ellips and ratio - """ - # Compute Parameter of ellips + """We mask link with ellipse and ratio""" + # Compute Parameter of ellipse minor, major = 1.05, 1.5 - y = self.basic_formula_ellips_major_axis( - self.obs['lat'], - degrees=True, - c0=minor, - cmin=minor, - cmax=major, - lat1=23, - lat2=5, + y = self.basic_formula_ellipse_major_axis( + self.lat, degrees=True, c0=minor, cmin=minor, cmax=major, lat1=23, lat2=5 ) - # mask from ellips + # mask from ellipse mask = self.shifted_ellipsoid_degrees_mask( - other, - minor=minor, # Minor can be bigger than major?? - major=y) + other, minor=minor, major=y # Minor can be bigger than major?? 
+ ) # We check ratio (maybe not usefull) - check_ratio(mask, self.obs['amplitude'], other.obs['amplitude'], self.obs['radius_e'], other.obs['radius_e']) + check_ratio( + mask, self.amplitude, other.amplitude, self.radius_e, other.radius_e + ) indexs_closest = where(mask) - mask[indexs_closest] = self.across_ground(self.obs[indexs_closest[0]], other.obs[indexs_closest[1]]) + mask[indexs_closest] = self.across_ground( + self.obs[indexs_closest[0]], other.obs[indexs_closest[1]] + ) return mask @classmethod def across_ground(cls, record0, record1): i, j, d_pix = cls.GROUND.compute_pixel_path( - x0=record0['lon'], - y0=record0['lat'], - x1=record1['lon'], - y1=record1['lat'], + x0=record0["lon"], y0=record0["lat"], x1=record1["lon"], y1=record1["lat"] ) - data = cls.GROUND.grid('mask')[i, j] + data = cls.GROUND.grid("mask")[i, j] i_ground = unique(arange(len(record0)).repeat(d_pix + 1)[data == 1]) - mask = ones(record1.shape, dtype='bool') + mask = ones(record1.shape, dtype="bool") mask[i_ground] = False return mask def solve_function(self, cost_matrix): - """Give the best link for each self obs - """ + """Give the best link for each self obs""" return where(self.solve_first(cost_matrix, multiple_link=True)) def post_process_link(self, other, i_self, i_other): @@ -70,7 +67,7 @@ def post_process_link(self, other, i_self, i_other): for i in where(nb_link > 1)[0]: m = i == i_other multiple_in = i_self[m] - i_keep = self.obs['amplitude'][multiple_in].argmax() + i_keep = self.amplitude[multiple_in].argmax() m[where(m)[0][i_keep]] = False mask[m] = False @@ -80,9 +77,12 @@ def post_process_link(self, other, i_self, i_other): @njit(cache=True) -def check_ratio(current_mask, self_amplitude, other_amplitude, self_radius, other_radius): +def check_ratio( + current_mask, self_amplitude, other_amplitude, self_radius, other_radius +): """ Only very few case are remove with selection + :param current_mask: :param self_amplitude: :param other_amplitude: diff --git a/src/py_eddy_tracker/generic.py b/src/py_eddy_tracker/generic.py index 800a3f9f..2fdb737a 100644 --- a/src/py_eddy_tracker/generic.py +++ b/src/py_eddy_tracker/generic.py @@ -1,50 +1,61 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . 
- -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== +Tool method which use mostly numba """ +from numba import njit, prange, types as numba_types from numpy import ( - sin, - pi, - cos, - arctan2, + absolute, arcsin, + arctan2, + bool_, + cos, empty, - nan, - absolute, floor, - ones, - linspace, + histogram, interp, + isnan, + linspace, + nan, + ones, + pi, + radians, + sin, where, zeros, - isnan, - bool_, ) -from numba import njit, prange, types as numba_types -from numpy.linalg import lstsq + + +@njit(cache=True) +def count_consecutive(mask): + """ + Count consecutive events every False flag count restart + + :param array[bool] mask: event to count + :return: count when consecutive event + :rtype: array + """ + count = 0 + output = zeros(mask.shape, dtype=numba_types.int_) + for i in range(mask.shape[0]): + if not mask[i]: + count = 0 + continue + count += 1 + output[i] = count + return output @njit(cache=True) def reverse_index(index, nb): + """ + Compute a list of indices, which are not in index. + + :param array index: index of group which will be set to False + :param array nb: Count for each group + :return: mask of value selected + :rtype: array + """ m = ones(nb, dtype=numba_types.bool_) for i in index: m[i] = False @@ -53,8 +64,18 @@ def reverse_index(index, nb): @njit(cache=True) def build_index(groups): - """We expected that variable is monotonous, and return index for each step change + """We expect that variable is monotonous, and return index for each step change. + + :param array groups: array that contains groups to be separated + :return: (first_index of each group, last_index of each group, value to shift groups) + :rtype: (array, array, int) + + :Example: + + >>> build_index(array((1, 1, 3, 4, 4))) + (array([0, 2, 2, 3]), array([2, 2, 3, 5]), 1) """ + i0, i1 = groups.min(), groups.max() amplitude = i1 - i0 + 1 # Index of first observation for each group @@ -62,26 +83,33 @@ def build_index(groups): for i, group in enumerate(groups[:-1]): # Get next value to compare next_group = groups[i + 1] - # if different we need to set index + # if different we need to set index for all groups between the 2 values if group != next_group: first_index[group - i0 + 1 : next_group - i0 + 1] = i + 1 last_index = zeros(amplitude, dtype=numba_types.int_) last_index[:-1] = first_index[1:] - last_index[-1] = i + 2 + last_index[-1] = len(groups) return first_index, last_index, i0 +@njit(cache=True) +def hist_numba(x, bins): + """Call numba histogram to speed up.""" + return histogram(x, bins) + + @njit(cache=True, fastmath=True, parallel=False) def distance_grid(lon0, lat0, lon1, lat1): """ - Args: - lon0: - lat0: - lon1: - lat1: + Get distance for every couple of points. 
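The distance helpers below share the same haversine formulation; a standalone numpy check of that formula, using the 6 370 997 m Earth radius that appears in this module (input values illustrative):

```python
# Hedged, standalone check of the haversine formula used by the distance
# helpers in this module (numpy only).
import numpy as np

def haversine_m(lon0, lat0, lon1, lat1):
    d2r = np.pi / 180.0
    sin_dlat = np.sin((lat1 - lat0) * 0.5 * d2r)
    sin_dlon = np.sin((lon1 - lon0) * 0.5 * d2r)
    a = sin_dlon ** 2 * np.cos(lat0 * d2r) * np.cos(lat1 * d2r) + sin_dlat ** 2
    return 6370997.0 * 2 * np.arctan2(a ** 0.5, (1 - a) ** 0.5)

# One degree of longitude at the equator is roughly 111 km
print(round(haversine_m(0.0, 0.0, 1.0, 0.0) / 1e3, 1))  # 111.2
```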
- Returns: - nan value for far away point, and km for other + :param array lon0: + :param array lat0: + :param array lon1: + :param array lat1: + + :return: nan value for far away points, and km for other + :rtype: array """ nb_0 = lon0.shape[0] nb_1 = lon1.shape[0] @@ -103,177 +131,184 @@ def distance_grid(lon0, lat0, lon1, lat1): sin_dlon = sin((dlon) * 0.5 * D2R) cos_lat1 = cos(lat0[i] * D2R) cos_lat2 = cos(lat1[j] * D2R) - a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2 - dist[i, j] = 6370.997 * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat1 * cos_lat2 + sin_dlat**2 + dist[i, j] = 6370.997 * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) return dist @njit(cache=True, fastmath=True) def distance(lon0, lat0, lon1, lat1): + """ + Compute distance between points from each line. + + :param float lon0: + :param float lat0: + :param float lon1: + :param float lat1: + :return: distance (in m) + :rtype: array + """ D2R = pi / 180.0 sin_dlat = sin((lat1 - lat0) * 0.5 * D2R) sin_dlon = sin((lon1 - lon0) * 0.5 * D2R) cos_lat1 = cos(lat0 * D2R) cos_lat2 = cos(lat1 * D2R) - a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2 - return 6370997.0 * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat1 * cos_lat2 + sin_dlat**2 + return 6370997.0 * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) @njit(cache=True) -def distance_vincenty(lon0, lat0, lon1, lat1): - """ better than haversine but buggy ??""" - D2R = pi / 180.0 - dlon = (lon1 - lon0) * D2R - cos_dlon = cos(dlon) - cos_lat1 = cos(lat0 * D2R) - cos_lat2 = cos(lat1 * D2R) - sin_lat1 = sin(lat0 * D2R) - sin_lat2 = sin(lat1 * D2R) - return 6370997.0 * arctan2( - ( - (cos_lat2 * sin(dlon) ** 2) - + (cos_lat1 * sin_lat2 - sin_lat1 * cos_lat2 * cos_dlon) ** 2 - ) - ** 0.5, - sin_lat1 * sin_lat2 + cos_lat1 * cos_lat2 * cos_dlon, - ) +def cumsum_by_track(field, track): + """ + Cumsum by track. + + :param array field: data to sum + :pram array(int) track: id of trajectories to separate data + :return: cumsum with a reset at each start of track + :rtype: array + """ + tr_previous = 0 + d_cum = 0 + cumsum_array = empty(track.shape, dtype=field.dtype) + for i in range(field.shape[0]): + tr = track[i] + if tr != tr_previous: + d_cum = 0 + d_cum += field[i] + cumsum_array[i] = d_cum + tr_previous = tr + return cumsum_array + + +@njit(cache=True, fastmath=True) +def interp2d_geo(x_g, y_g, z_g, m_g, x, y, nearest=False): + """ + For geographic grid, test of cicularity. 
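The bilinear variant documented below mixes the four surrounding grid values with the fractional offsets xd and yd; a minimal standalone sketch of that rule on a 2 x 2 grid (no wrapping, no mask handling, values illustrative):

```python
# Plain-numpy sketch of the bilinear rule applied by the interpolators below.
import numpy as np

x_g, y_g = np.array([0.0, 1.0]), np.array([0.0, 1.0])
z_g = np.array([[0.0, 1.0],   # z_g[i, j]: i indexes x, j indexes y (module convention)
                [2.0, 3.0]])

x, y = 0.25, 0.75             # point to interpolate
xd = (x - x_g[0]) / (x_g[1] - x_g[0])
yd = (y - y_g[0]) / (y_g[1] - y_g[0])
z = (z_g[0, 0] * (1 - xd) + z_g[1, 0] * xd) * (1 - yd) \
    + (z_g[0, 1] * (1 - xd) + z_g[1, 1] * xd) * yd
print(z)  # 1.25
```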
+ + :param array x_g: coordinates of grid + :param array y_g: coordinates of grid + :param array z_g: Grid value + :param array m_g: Boolean grid, True if value is masked + :param array x: coordinate where interpolate z + :param array y: coordinate where interpolate z + :param bool nearest: if True we will take nearest pixel + :return: z interpolated + :rtype: array + """ + if nearest: + return interp2d_nearest(x_g, y_g, z_g, x, y) + else: + return interp2d_bilinear(x_g, y_g, z_g, m_g, x, y) + + +@njit(cache=True, fastmath=True) +def interp2d_nearest(x_g, y_g, z_g, x, y): + """ + Nearest interpolation with wrapping if circular + + :param array x_g: coordinates of grid + :param array y_g: coordinates of grid + :param array z_g: Grid value + :param array x: coordinate where interpolate z + :param array y: coordinate where interpolate z + :return: z interpolated + :rtype: array + """ + x_ref = x_g[0] + y_ref = y_g[0] + x_step = x_g[1] - x_ref + y_step = y_g[1] - y_ref + nb_x = x_g.shape[0] + nb_y = y_g.shape[0] + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + z = empty(x.shape, dtype=z_g.dtype) + for i in prange(x.size): + i0 = int(round((x[i] - x_ref) / x_step)) + j0 = int(round((y[i] - y_ref) / y_step)) + if is_circular: + i0 %= nb_x + if i0 >= nb_x or i0 < 0 or j0 < 0 or j0 >= nb_y: + z[i] = nan + continue + z[i] = z_g[i0, j0] + return z @njit(cache=True, fastmath=True) -def interp2d_geo(x_g, y_g, z_g, m_g, x, y): - """For geographic grid, test of cicularity - Maybe test if we are out of bounds +def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y): + """ + Bilinear interpolation with wrapping if circular + + :param array x_g: coordinates of grid + :param array y_g: coordinates of grid + :param array z_g: Grid value + :param array m_g: Boolean grid, True if value is masked + :param array x: coordinate where interpolate z + :param array y: coordinate where interpolate z + :return: z interpolated + :rtype: array """ x_ref = x_g[0] y_ref = y_g[0] x_step = x_g[1] - x_ref y_step = y_g[1] - y_ref nb_x = x_g.shape[0] - is_circular = (x_g[-1] + x_step) % 360 == x_g[0] % 360 + nb_y = y_g.shape[0] + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + # Indexes that should never exist + i0_old, j0_old, masked = -100000000, -10000000, False z = empty(x.shape, dtype=z_g.dtype) for i in prange(x.size): x_ = (x[i] - x_ref) / x_step y_ = (y[i] - y_ref) / y_step i0 = int(floor(x_)) - i1 = i0 + 1 - xd = x_ - i0 - if is_circular: - i0 %= nb_x - i1 %= nb_x + # To keep original values if wrapping applied to compute xd + i0_ = i0 j0 = int(floor(y_)) - j1 = j0 + 1 - yd = y_ - j0 - z00 = z_g[i0, j0] - z01 = z_g[i0, j1] - z10 = z_g[i1, j0] - z11 = z_g[i1, j1] - if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]: + # corners are the same need only a new xd and yd + if i0 != i0_old or j0 != j0_old: + i1 = i0 + 1 + j1 = j0 + 1 + if is_circular: + i0 %= nb_x + i1 %= nb_x + if i1 >= nb_x or i0 < 0 or j0 < 0 or j1 >= nb_y: + masked = True + else: + masked = False + if not masked: + if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]: + masked = True + else: + z00, z01, z10, z11 = ( + z_g[i0, j0], + z_g[i0, j1], + z_g[i1, j0], + z_g[i1, j1], + ) + masked = False + # Need to be stored only on change + i0_old, j0_old = i0, j0 + if masked: z[i] = nan else: + xd = x_ - i0_ + yd = y_ - j0 z[i] = (z00 * (1 - xd) + (z10 * xd)) * (1 - yd) + ( z01 * (1 - xd) + z11 * xd ) * yd return z -@njit(cache=True, fastmath=True, parallel=False) -def custom_convolution(data, mask, kernel): - 
"""do sortin at high lattitude big part of value are masked""" - nb_x = kernel.shape[0] - demi_x = int((nb_x - 1) / 2) - demi_y = int((kernel.shape[1] - 1) / 2) - out = empty(data.shape[0] - nb_x + 1) - for i in prange(out.shape[0]): - if mask[i + demi_x, demi_y] == 1: - w = (mask[i : i + nb_x] * kernel).sum() - if w != 0: - out[i] = (data[i : i + nb_x] * kernel).sum() / w - else: - out[i] = nan - else: - out[i] = nan - return out - - -@njit(cache=True) -def fit_circle(x_vec, y_vec): - nb_elt = x_vec.shape[0] - p_inon_x = empty(nb_elt) - p_inon_y = empty(nb_elt) - - # last coordinates == first - x_mean = x_vec[1:].mean() - y_mean = y_vec[1:].mean() - - norme = (x_vec[1:] - x_mean) ** 2 + (y_vec[1:] - y_mean) ** 2 - norme_max = norme.max() - scale = norme_max ** 0.5 - - # Form matrix equation and solve it - # Maybe put f4 - datas = ones((nb_elt - 1, 3)) - datas[:, 0] = 2.0 * (x_vec[1:] - x_mean) / scale - datas[:, 1] = 2.0 * (y_vec[1:] - y_mean) / scale - - (center_x, center_y, radius), _, _, _ = lstsq(datas, norme / norme_max) - - # Unscale data and get circle variables - radius += center_x ** 2 + center_y ** 2 - radius **= 0.5 - center_x *= scale - center_y *= scale - # radius of fitted circle - radius *= scale - # center X-position of fitted circle - center_x += x_mean - # center Y-position of fitted circle - center_y += y_mean - - # area of fitted circle - c_area = (radius ** 2) * pi - # Find distance between circle center and contour points_inside_poly - for i_elt in range(nb_elt): - # Find distance between circle center and contour points_inside_poly - dist_poly = ( - (x_vec[i_elt] - center_x) ** 2 + (y_vec[i_elt] - center_y) ** 2 - ) ** 0.5 - # Indices of polygon points outside circle - # p_inon_? : polygon x or y points inside & on the circle - if dist_poly > radius: - p_inon_y[i_elt] = center_y + radius * (y_vec[i_elt] - center_y) / dist_poly - p_inon_x[i_elt] = center_x - (center_x - x_vec[i_elt]) * ( - center_y - p_inon_y[i_elt] - ) / (center_y - y_vec[i_elt]) - else: - p_inon_x[i_elt] = x_vec[i_elt] - p_inon_y[i_elt] = y_vec[i_elt] - - # Area of closed contour/polygon enclosed by the circle - p_area_incirc = 0 - p_area = 0 - for i_elt in range(nb_elt - 1): - # Indices of polygon points outside circle - # p_inon_? : polygon x or y points inside & on the circle - p_area_incirc += ( - p_inon_x[i_elt] * p_inon_y[1 + i_elt] - - p_inon_x[i_elt + 1] * p_inon_y[i_elt] - ) - # Shape test - # Area and centroid of closed contour/polygon - p_area += x_vec[i_elt] * y_vec[1 + i_elt] - x_vec[1 + i_elt] * y_vec[i_elt] - p_area = abs(p_area) * 0.5 - p_area_incirc = abs(p_area_incirc) * 0.5 - - a_err = (c_area - 2 * p_area_incirc + p_area) * 100.0 / c_area - return center_x, center_y, radius, a_err - - @njit(cache=True, fastmath=True) -def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None): +def uniform_resample(x_val, y_val, num_fac=2, fixed_size=-1): """ - Resample contours to have (nearly) equal spacing - x_val, y_val : input contour coordinates - num_fac : factor to increase lengths of output coordinates + Resample contours to have (nearly) equal spacing. 
+ + :param array_like x_val: input x contour coordinates + :param array_like y_val: input y contour coordinates + :param int num_fac: factor to increase lengths of output coordinates + :param int fixed_size: if > -1, will be used to set sampling """ nb = x_val.shape[0] # Get distances @@ -284,7 +319,7 @@ def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None): dist[1:][dist[1:] < 1e-3] = 1e-3 dist = dist.cumsum() # Get uniform distances - if fixed_size is None: + if fixed_size == -1: fixed_size = dist.size * num_fac d_uniform = linspace(0, dist[-1], fixed_size) x_new = interp(d_uniform, dist, x_val) @@ -295,15 +330,21 @@ def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None): @njit(cache=True) def flatten_line_matrix(l_matrix): """ - Flat matrix and add on between each line - Args: - l_matrix: matrix of position + Flat matrix and add on between each line. - Returns: array with nan between line + :param l_matrix: matrix of position + :return: array with nan between line """ nb_line, sampling = l_matrix.shape final_size = (nb_line - 1) + nb_line * sampling + empty_dataset = False + if final_size < 1: + empty_dataset = True + final_size = 1 out = empty(final_size, dtype=l_matrix.dtype) + if empty_dataset: + out[:] = nan + return out inc = 0 for i in range(nb_line): for j in range(sampling): @@ -316,12 +357,39 @@ def flatten_line_matrix(l_matrix): @njit(cache=True) def simplify(x, y, precision=0.1): - precision2 = precision ** 2 + """ + Will remove all middle/end points closer than precision. + + :param array x: + :param array y: + :param float precision: if two points have distance inferior to precision we remove next point + :return: (x,y) + :rtype: (array,array) + """ + precision2 = precision**2 nb = x.shape[0] - x_previous, y_previous = x[0], y[0] + # will be True for kept values mask = ones(nb, dtype=bool_) - for i in range(1, nb): + for j in range(0, nb): + x_previous, y_previous = x[j], y[j] + if isnan(x_previous) or isnan(y_previous): + mask[j] = False + continue + break + # Only nan + if j == (nb - 1): + return zeros(0, dtype=x.dtype), zeros(0, dtype=x.dtype) + + last_nan = False + for i in range(j + 1, nb): x_, y_ = x[i], y[i] + if isnan(x_) or isnan(y_): + if last_nan: + mask[i] = False + else: + last_nan = True + continue + last_nan = False d_x = x_ - x_previous if d_x > precision: x_previous, y_previous = x_, y_ @@ -330,7 +398,7 @@ def simplify(x, y, precision=0.1): if d_y > precision: x_previous, y_previous = x_, y_ continue - d2 = d_x ** 2 + d_y ** 2 + d2 = d_x**2 + d_y**2 if d2 > precision2: x_previous, y_previous = x_, y_ continue @@ -348,17 +416,17 @@ def simplify(x, y, precision=0.1): @njit(cache=True) def split_line(x, y, i): """ - Split x and y at each i change - Args: - x: array - y: array - i: array of int at each i change, we cut x, y + Split x and y at each i change. - Returns: x and y separate by nan at each i jump + :param x: array + :param y: array + :param i: array of int at each i change, we cut x, y + + :return: x and y separated by nan at each i jump """ nb_jump = len(where(i[1:] - i[:-1] != 0)[0]) nb_value = x.shape[0] - final_size = (nb_jump - 1) + nb_value + final_size = nb_jump + nb_value new_x = empty(final_size, dtype=x.dtype) new_y = empty(final_size, dtype=y.dtype) new_j = 0 @@ -375,20 +443,31 @@ def split_line(x, y, i): @njit(cache=True) def wrap_longitude(x, y, ref, cut=False): + """ + Will wrap contiguous longitude with reference as western boundary. 
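The wrapping rule documented above is a simple modulo against the reference longitude; a quick standalone illustration (ref = -180 chosen arbitrarily):

```python
# Every longitude is mapped into [ref, ref + 360) relative to the chosen
# western boundary; illustrative values only.
import numpy as np

lon = np.array([-170.0, 10.0, 190.0, 350.0])
ref = -180.0
print((lon - ref) % 360 + ref)  # [-170.   10. -170.  -10.]
```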
+ + :param array x: + :param array y: + :param float ref: longitude of reference, all the new values will be between ref and ref + 360 + :param bool cut: if True line will be cut at the bounds + :return: lon,lat + :rtype: (array,array) + """ if cut: indexs = list() nb = x.shape[0] - new_previous = (x[0] - ref) % 360 + + new_x_previous = (x[0] - ref) % 360 + ref x_previous = x[0] for i in range(1, nb): x_ = x[i] - new_x = (x_ - ref) % 360 + new_x = (x_ - ref) % 360 + ref if not isnan(x_) and not isnan(x_previous): - d_new = new_x - new_previous + d_new = new_x - new_x_previous d = x_ - x_previous if abs(d - d_new) > 1e-5: indexs.append(i) - x_previous, new_previous = x_, new_x + x_previous, new_x_previous = x_, new_x nb_indexs = len(indexs) new_size = nb + nb_indexs * 3 @@ -399,6 +478,7 @@ def wrap_longitude(x, y, ref, cut=False): for i in range(nb): if j < nb_indexs and i == indexs[j]: j += 1 + # FIXME need check cor = 360 if x[i - 1] > x[i] else -360 out_x[i + i_] = (x[i] - ref) % 360 + ref - cor out_y[i + i_] = y[i] @@ -421,6 +501,16 @@ def wrap_longitude(x, y, ref, cut=False): @njit(cache=True, fastmath=True) def coordinates_to_local(lon, lat, lon0, lat0): + """ + Take latlong coordinates to transform in local coordinates (in m). + + :param array x: coordinates to transform + :param array y: coordinates to transform + :param float lon0: longitude of local reference + :param float lat0: latitude of local reference + :return: x,y + :retype: (array, array) + """ D2R = pi / 180.0 R = 6370997 dlon = (lon - lon0) * D2R @@ -428,8 +518,8 @@ def coordinates_to_local(lon, lat, lon0, lat0): sin_dlon = sin(dlon * 0.5) cos_lat0 = cos(lat0 * D2R) cos_lat = cos(lat * D2R) - a_val = sin_dlon ** 2 * cos_lat0 * cos_lat + sin_dlat ** 2 - module = R * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat0 * cos_lat + sin_dlat**2 + module = R * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) azimuth = pi / 2 - arctan2( cos_lat * sin(dlon), @@ -440,9 +530,19 @@ def coordinates_to_local(lon, lat, lon0, lat0): @njit(cache=True, fastmath=True) def local_to_coordinates(x, y, lon0, lat0): + """ + Take local coordinates (in m) to transform to latlong. + + :param array x: coordinates to transform + :param array y: coordinates to transform + :param float lon0: longitude of local reference + :param float lat0: latitude of local reference + :return: lon,lat + :retype: (array, array) + """ D2R = pi / 180.0 R = 6370997 - d = (x ** 2 + y ** 2) ** 0.5 / R + d = (x**2 + y**2) ** 0.5 / R a = -(arctan2(y, x) - pi / 2) lat = arcsin(sin(lat0 * D2R) * cos(d) + cos(lat0 * D2R) * sin(d) * cos(a)) lon = ( @@ -453,3 +553,108 @@ def local_to_coordinates(x, y, lon0, lat0): / D2R ) return lon, lat / D2R + + +@njit(cache=True, fastmath=True) +def nearest_grd_indice(x, y, x0, y0, xstep, ystep): + """ + Get nearest grid index from a position. + + :param x: longitude + :param y: latitude + :param float x0: first grid longitude + :param float y0: first grid latitude + :param float xstep: step between two longitude + :param float ystep: step between two latitude + """ + return ( + numba_types.int32(round(((x - x0[0]) % 360.0) / xstep)), + numba_types.int32(round((y - y0[0]) / ystep)), + ) + + +@njit(cache=True) +def bbox_indice_regular(vertices, x0, y0, xstep, ystep, N, circular, x_size): + """ + Get bbox index of a contour in a regular grid. 
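The bbox computation relies on rounding positions against the grid origin and step, as in nearest_grd_indice above; a standalone sketch on an assumed 0.25 degree grid (origin and point values are illustrative, not taken from the library):

```python
# Nearest grid index from a position on a regular grid, as used by the bbox
# helper above; hypothetical 0.25-degree grid starting at 0E, 80S.
x0, y0, xstep, ystep = 0.0, -80.0, 0.25, 0.25
lon, lat = 32.6, 41.1
i_x = int(round(((lon - x0) % 360.0) / xstep))
i_y = int(round((lat - y0) / ystep))
print(i_x, i_y)  # 130 484
```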
+ + :param vertices: vertice of contour + :param float x0: first grid longitude + :param float y0: first grid latitude + :param float xstep: step between two longitude + :param float ystep: step between two latitude + :param int N: shift of index to enlarge window + :param bool circular: To know if grid is wrappable + :param int x_size: Number of longitude + """ + lon, lat = vertices[:, 0], vertices[:, 1] + lon_min, lon_max = lon.min(), lon.max() + lat_min, lat_max = lat.min(), lat.max() + i_x0, i_y0 = nearest_grd_indice(lon_min, lat_min, x0, y0, xstep, ystep) + i_x1, i_y1 = nearest_grd_indice(lon_max, lat_max, x0, y0, xstep, ystep) + if circular: + slice_x = (i_x0 - N) % x_size, (i_x1 + N + 1) % x_size + else: + slice_x = max(i_x0 - N, 0), i_x1 + N + 1 + slice_y = i_y0 - N, i_y1 + N + 1 + return slice_x, slice_y + + +def build_circle(x0, y0, r): + """ + Build circle from center coordinates. + + :param float x0: center coordinate + :param float y0: center coordinate + :param float r: radius i meter + :return: x,y + :rtype: (array,array) + """ + angle = radians(linspace(0, 360, 50)) + x_norm, y_norm = cos(angle), sin(angle) + return x_norm * r + x0, y_norm * r + y0 + + +def window_index(x, x0, half_window=1): + """ + Give for a fixed half_window each start and end index for each x0, in + an unsorted array. + + :param array x: array of value + :param array x0: array of window center + :param float half_window: half window + """ + # Sort array, bounds will be sort also + i_ordered = x.argsort(kind="mergesort") + return window_index_(x, i_ordered, x0, half_window) + + +@njit(cache=True) +def window_index_(x, i_ordered, x0, half_window=1): + nb_x, nb_pt = x.size, x0.size + first_index = empty(nb_pt, dtype=i_ordered.dtype) + last_index = empty(nb_pt, dtype=i_ordered.dtype) + # First bound to find + j_min, j_max = 0, 0 + x_min = x0[j_min] - half_window + x_max = x0[j_max] + half_window + # We iterate on ordered x + for i, i_x in enumerate(i_ordered): + x_ = x[i_x] + # if x bigger than x_min , we found bound and search next one + while x_ > x_min and j_min < nb_pt: + first_index[j_min] = i + j_min += 1 + x_min = x0[j_min] - half_window + # if x bigger than x_max , we found bound and search next one + while x_ > x_max and j_max < nb_pt: + last_index[j_max] = i + j_max += 1 + x_max = x0[j_max] + half_window + if j_max == nb_pt: + break + for i in range(j_min, nb_pt): + first_index[i] = nb_x + for i in range(j_max, nb_pt): + last_index[i] = nb_x + return i_ordered, first_index, last_index diff --git a/src/py_eddy_tracker/gui.py b/src/py_eddy_tracker/gui.py index c9118200..a85e9c18 100644 --- a/src/py_eddy_tracker/gui.py +++ b/src/py_eddy_tracker/gui.py @@ -1,34 +1,18 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . 
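Stepping back to the window_index helper added to generic.py just above this hunk: expressed with plain numpy it amounts to one sort followed by per-centre bound searches. A sketch of that behaviour under the same half-window convention (values illustrative; the numba version in the diff uses its own loop rather than searchsorted):

```python
# Plain-numpy approximation of what window_index computes: for each window
# centre in x0, the slice bounds (in sort order) of the values of x that fall
# in (centre - half_window, centre + half_window].
import numpy as np

x = np.array([3.0, 0.5, 2.0, 5.0, 1.0])   # unsorted values (e.g. observation times)
x0 = np.array([1.0, 4.0])                  # window centres
half_window = 1.0

i_ordered = x.argsort(kind="mergesort")
xs = x[i_ordered]
first = np.searchsorted(xs, x0 - half_window, side="right")
last = np.searchsorted(xs, x0 + half_window, side="right")
for centre, i0, i1 in zip(x0, first, last):
    print(centre, x[i_ordered[i0:i1]])  # 1.0 [0.5 1. 2.]  then  4.0 [5.]
```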
- -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== +GUI class """ -import numpy as np -from datetime import datetime -import matplotlib.pyplot as plt +from datetime import datetime, timedelta +import logging + from matplotlib.projections import register_projection -import py_eddy_tracker_sample as sample +import matplotlib.pyplot as plt +import numpy as np + from .generic import flatten_line_matrix, split_line -from .observations.tracking import TrackEddiesObservations +logger = logging.getLogger("pet") try: from pylook.axes import PlatCarreAxes @@ -36,13 +20,24 @@ from matplotlib.axes import Axes class PlatCarreAxes(Axes): + """ + Class to replace missing pylook class + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.set_aspect("equal") +GUI_AXES = "full_axes" + + class GUIAxes(PlatCarreAxes): - name = "full_axes" + """ + Axes that uses full space available + """ + + name = GUI_AXES def end_pan(self, *args, **kwargs): (x0, x1), (y0, y1) = self.get_xlim(), self.get_ylim() @@ -65,10 +60,6 @@ def end_pan(self, *args, **kwargs): register_projection(GUIAxes) -class A(TrackEddiesObservations): - pass - - def no(*args, **kwargs): return False @@ -81,6 +72,7 @@ class GUI: "time_ax", "param_ax", "settings", + "d_indexs", "m", "last_event", ) @@ -89,6 +81,7 @@ class GUI: def __init__(self, **datasets): self.datasets = datasets + self.d_indexs = dict() self.m = dict() self.set_initial_values() self.setup() @@ -101,8 +94,8 @@ def set_initial_values(self): for dataset in self.datasets.values(): t0_, t1_ = dataset.period t0, t1 = min(t0, t0_), max(t1, t1_) - - self.settings = dict(period=(t0, t1), now=t1,) + logger.debug("period detected %f -> %f", t0, t1) + self.settings = dict(period=(t0, t1), now=t1) @property def now(self): @@ -138,10 +131,11 @@ def med(self): def setup(self): self.figure = plt.figure() # map - self.map = self.figure.add_axes((0, 0.25, 1, 0.75), projection="full_axes") + self.map = self.figure.add_axes((0, 0.25, 1, 0.75), projection=GUI_AXES) self.map.grid() - self.map.tick_params("x", pad=-12) - self.map.tick_params("y", pad=-22) + self.map.tick_params("both", pad=-22) + # self.map.tick_params("y", pad=-22) + self.map.bg_cache = None # time ax self.time_ax = self.figure.add_axes((0, 0.15, 1, 0.1), facecolor=".95") self.time_ax.can_pan @@ -158,41 +152,71 @@ def setup(self): # param self.param_ax = self.figure.add_axes((0, 0, 1, 0.15), facecolor="0.2") + def hide_path(self, state): + for name in self.datasets: + self.m[name]["path_previous"].set_visible(state) + self.m[name]["path_future"].set_visible(state) + def draw(self): + self.m["mini_ax"] = self.figure.add_axes((0.3, 0.85, 0.4, 0.15), zorder=80) + self.m["mini_ax"].grid() # map for i, (name, dataset) in enumerate(self.datasets.items()): + kwargs = dict(color=self.COLORS[i]) self.m[name] = dict( - contour_s=self.map.plot( - [], [], color=self.COLORS[i], lw=0.5, label=name - )[0], - contour_e=self.map.plot([], [], color=self.COLORS[i], lw=0.5)[0], - path_previous=self.map.plot([], [], color=self.COLORS[i], lw=0.5)[0], - path_future=self.map.plot([], [], color=self.COLORS[i], lw=0.2, ls=":")[ - 0 - ], + contour_s=self.map.plot([], [], lw=1, label=name, **kwargs)[0], + contour_e=self.map.plot([], [], lw=0.5, ls="-.", **kwargs)[0], + path_previous=self.map.plot([], [], lw=0.5, **kwargs)[0], + path_future=self.map.plot([], [], lw=0.2, ls=":", **kwargs)[0], + mini_line=self.m["mini_ax"].plot([], [], 
**kwargs, lw=1)[0], ) - self.m["title"] = self.map.set_title("") # time_ax + self.m["annotate"] = self.map.annotate( + "", + (0, 0), + xycoords="figure pixels", + zorder=100, + fontsize=9, + bbox=dict(boxstyle="round", facecolor="w", edgecolor="0.5", alpha=0.85), + ) + self.m["mini_ax"].set_visible(False) + self.m["annotate"].set_visible(False) + self.m["time_vline"] = self.time_ax.axvline(0, color="k", lw=1) self.m["time_text"] = self.time_ax.text( - 0, 0, "", fontsize=8, bbox=dict(facecolor="w", alpha=0.75) + 0, + 0, + "", + fontsize=8, + bbox=dict(facecolor="w", alpha=0.75), + verticalalignment="bottom", ) def update(self): - # text = [] + time_text = [ + (timedelta(days=int(self.now)) + datetime(1950, 1, 1)).strftime("%d/%m/%Y") + ] # map xs, ys, ns = list(), list(), list() for j, (name, dataset) in enumerate(self.datasets.items()): i = self.indexs(dataset) + self.d_indexs[name] = i self.m[name]["contour_s"].set_label(f"{name} {len(i)} eddies") if len(i) == 0: self.m[name]["contour_s"].set_data([], []) + self.m[name]["contour_e"].set_data([], []) else: - self.m[name]["contour_s"].set_data( - flatten_line_matrix(dataset["contour_lon_s"][i]), - flatten_line_matrix(dataset["contour_lat_s"][i]), - ) - # text.append(f"{i.shape[0]}") + if "contour_lon_s" in dataset.elements: + self.m[name]["contour_s"].set_data( + flatten_line_matrix(dataset["contour_lon_s"][i]), + flatten_line_matrix(dataset["contour_lat_s"][i]), + ) + if "contour_lon_e" in dataset.elements: + self.m[name]["contour_e"].set_data( + flatten_line_matrix(dataset["contour_lon_e"][i]), + flatten_line_matrix(dataset["contour_lat_e"][i]), + ) + time_text.append(f"{i.shape[0]}") local_path = dataset.extract_ids(dataset["track"][i]) x, y, t, n, tr = ( local_path.longitude, @@ -228,12 +252,11 @@ def update(self): self.map.text(x_, y_, n_) for x_, y_, n_ in zip(x, y, n) if n_ >= n_min ] - self.m["title"].set_text(self.now) self.map.legend() # time ax x, y = self.m["time_vline"].get_data() self.m["time_vline"].set_data(self.now, y) - # self.m["time_text"].set_text("\n".join(text)) + self.m["time_text"].set_text("\n".join(time_text)) self.m["time_text"].set_position((self.now, 0)) # force update self.map.figure.canvas.draw() @@ -262,6 +285,30 @@ def press(self, event): self.time_ax.press = True self.time_ax.bg_cache = self.figure.canvas.copy_from_bbox(self.time_ax.bbox) + def get_infos(self, name, index): + i = self.d_indexs[name][index] + d = self.datasets[name] + now = d.obs[i] + tr = now["track"] + nb = d.nb_obs_by_track[tr] + i_first = d.index_from_track[tr] + track = d.obs[i_first : i_first + nb] + nb -= 1 + t0 = timedelta(days=track[0]["time"]) + datetime(1950, 1, 1) + t1 = timedelta(days=track[-1]["time"]) + datetime(1950, 1, 1) + txt = f"--{name}--\n" + txt += f" {t0} -> {t1}\n" + txt += f" Tracks : {tr} {now['n']}/{nb} ({now['n'] / nb * 100:.2f} %)\n" + for label, n, f, u in ( + ("Amp.", "amplitude", 100, "cm"), + ("S. radius", "radius_s", 1e-3, "km"), + ("E. 
radius", "radius_e", 1e-3, "km"), + ): + v = track[n] * f + min_, max_, mean_, std_ = v.min(), v.max(), v.mean(), v.std() + txt += f" {label} : {now[n] * f:.1f} {u} ({min_:.1f} <-{mean_:.1f}+-{std_:.1f}-> {max_:.1f})\n" + return track, txt.strip() + def move(self, event): if event.inaxes == self.time_ax and self.time_ax.press: x, y = self.m["time_vline"].get_data() @@ -270,6 +317,49 @@ def move(self, event): self.time_ax.draw_artist(self.m["time_vline"]) self.figure.canvas.blit(self.time_ax.bbox) + if event.inaxes == self.map: + touch = dict() + for name in self.datasets.keys(): + flag, data = self.m[name]["contour_s"].contains(event) + if flag: + # 51 is for contour on 50 point must be rewrote + touch[name] = data["ind"][0] // 51 + a = self.m["annotate"] + ax = self.m["mini_ax"] + if touch: + if not a.get_visible(): + self.map.bg_cache = self.figure.canvas.copy_from_bbox(self.map.bbox) + a.set_visible(True) + ax.set_visible(True) + else: + self.figure.canvas.restore_region(self.map.bg_cache) + a.set_x(event.x), a.set_y(event.y) + txt = list() + x0_, x1_, y1_ = list(), list(), list() + for name in self.datasets.keys(): + if name in touch: + track, txt_ = self.get_infos(name, touch[name]) + txt.append(txt_) + x, y = track["time"], track["radius_s"] / 1e3 + self.m[name]["mini_line"].set_data(x, y) + x0_.append(x.min()), x1_.append(x.max()), y1_.append(y.max()) + else: + self.m[name]["mini_line"].set_data([], []) + ax.set_xlim(min(x0_), max(x1_)), ax.set_ylim(0, max(y1_)) + a.set_text("\n".join(txt)) + + self.map.draw_artist(a) + self.map.draw_artist(ax) + self.figure.canvas.blit(self.map.bbox) + if not flag and self.map.bg_cache is not None and a.get_visible(): + a.set_visible(False) + ax.set_visible(False) + self.figure.canvas.restore_region(self.map.bg_cache) + self.map.draw_artist(a) + self.map.draw_artist(ax) + self.figure.canvas.blit(self.map.bbox) + self.map.bg_cache = None + def release(self, event): if self.time_ax.press: self.time_ax.press = False @@ -292,21 +382,3 @@ def adjust(self, event=None): def show(self): self.update() plt.show() - - -if __name__ == "__main__": - - # a_ = A.load_file( - # "/home/toto/dev/work/pet/20200611_example_dataset/tracking/Anticyclonic_track_too_short.nc" - # ) - # c_ = A.load_file( - # "/home/toto/dev/work/pet/20200611_example_dataset/tracking/Cyclonic_track_too_short.nc" - # ) - a = A.load_file(sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr")) - # c = A.load_file(sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr")) - # g = GUI(Acyc=a, Cyc=c, Acyc_short=a_, Cyc_short=c_) - g = GUI(Acyc=a) - # g = GUI(Acyc_short=a_) - # g = GUI(Acyc_short=a_, Cyc_short=c_) - g.med() - g.show() diff --git a/src/py_eddy_tracker/misc.py b/src/py_eddy_tracker/misc.py new file mode 100644 index 00000000..647bfba3 --- /dev/null +++ b/src/py_eddy_tracker/misc.py @@ -0,0 +1,21 @@ +import re + +from matplotlib.animation import FuncAnimation + + +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) diff --git a/src/py_eddy_tracker/observations/groups.py 
b/src/py_eddy_tracker/observations/groups.py new file mode 100644 index 00000000..81929e1e --- /dev/null +++ b/src/py_eddy_tracker/observations/groups.py @@ -0,0 +1,476 @@ +from abc import ABC, abstractmethod +import logging + +from numba import njit, types as nb_types +from numpy import arange, full, int32, interp, isnan, median, where, zeros + +from ..generic import window_index +from ..poly import create_meshed_particles, poly_indexs +from .observation import EddiesObservations + +logger = logging.getLogger("pet") + + +@njit(cache=True) +def get_missing_indices( + array_time, array_track, dt=1, flag_untrack=True, indice_untrack=0 +): + """Return indexes where values are missing + + :param np.array(int) array_time : array of strictly increasing int representing time + :param np.array(int) array_track: N° track where observations belong + :param int,float dt: theorical timedelta between 2 observations + :param bool flag_untrack: if True, ignore observations where n°track equal `indice_untrack` + :param int indice_untrack: n° representing where observations are untracked + + + ex : array_time = np.array([67, 68, 70, 71, 74, 75]) + array_track= np.array([ 1, 1, 1, 1, 1, 1]) + return : np.array([2, 4, 4]) + """ + + t0 = array_time[0] + t1 = t0 + + tr0 = array_track[0] + tr1 = tr0 + + nbr_step = zeros(array_time.shape, dtype=int32) + + for i in range(array_time.size - 1): + t0 = t1 + tr0 = tr1 + + t1 = array_time[i + 1] + tr1 = array_track[i + 1] + + if flag_untrack & (tr1 == indice_untrack): + continue + + if tr1 != tr0: + continue + + diff = t1 - t0 + if diff > dt: + nbr_step[i] = int(diff / dt) - 1 + + indices = zeros(nbr_step.sum(), dtype=int32) + + j = 0 + for i in range(array_time.size - 1): + nbr_missing = nbr_step[i] + + if nbr_missing != 0: + for k in range(nbr_missing): + indices[j] = i + 1 + j += 1 + return indices + + +def advect(x, y, c, t0, n_days, u_name="u", v_name="v"): + """ + Advect particles from t0 to t0 + n_days, with data cube. + + :param np.array(float) x: longitude of particles + :param np.array(float) y: latitude of particles + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param int t0: julian day of advection start + :param int n_days: number of days to advect + :param str u_name: variable name for u component + :param str v_name: variable name for v component + """ + + kw = dict(nb_step=6, time_step=86400 / 6) + if n_days < 0: + kw["backward"] = True + n_days = -n_days + p = c.advect(x, y, u_name=u_name, v_name=v_name, t_init=t0, **kw) + for _ in range(n_days): + t, x, y = p.__next__() + return t, x, y + + +def particle_candidate_step( + t_start, contours_start, contours_end, space_step, dt, c, day_fraction=6, **kwargs +): + """Select particles within eddies, advect them, return target observation and associated percentages. + For one time step. 
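The get_missing_indices helper added earlier in this file can be mimicked with a short plain-Python loop; the sketch below reproduces its docstring example without numba (dt = 1, untracked-flag handling left out):

```python
# Gaps inside a track are reported as the insertion positions of the missing
# observations, one entry per missing time step.
import numpy as np

array_time = np.array([67, 68, 70, 71, 74, 75])
array_track = np.array([1, 1, 1, 1, 1, 1])

indices = []
for i in range(array_time.size - 1):
    if array_track[i + 1] != array_track[i]:
        continue                      # change of track: not a gap to fill
    gap = int(array_time[i + 1] - array_time[i])
    indices.extend([i + 1] * (gap - 1))
print(np.array(indices))  # [2 4 4], as in the docstring example
```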
+ + :param int t_start: julian day of the advection + :param (np.array(float),np.array(float)) contours_start: origin contour + :param (np.array(float),np.array(float)) contours_end: destination contour + :param float space_step: step between 2 particles + :param int dt: duration of advection + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param int day_fraction: fraction of day + :params dict kwargs: dict of params given to advection + :return (np.array,np.array): return target index and percent associate + """ + # In case of zarr array + contours_start = [i[:] for i in contours_start] + contours_end = [i[:] for i in contours_end] + # Create particles in start contour + x, y, i_start = create_meshed_particles(*contours_start, space_step) + # Advect particles + kw = dict(nb_step=day_fraction, time_step=86400 / day_fraction) + p = c.advect(x, y, t_init=t_start, **kwargs, **kw) + for _ in range(abs(dt)): + _, x, y = p.__next__() + m = ~(isnan(x) + isnan(y)) + i_end = full(x.shape, -1, dtype="i4") + if m.any(): + # Id eddies for each alive particle in start contour + i_end[m] = poly_indexs(x[m], y[m], *contours_end) + shape = (contours_start[0].shape[0], 2) + # Get target for each contour + i_target, pct_target = full(shape, -1, dtype="i4"), zeros(shape, dtype="f8") + nb_end = contours_end[0].shape[0] + get_targets(i_start, i_end, i_target, pct_target, nb_end) + return i_target, pct_target.astype("i1") + + +def particle_candidate( + c, + eddies, + step_mesh, + t_start, + i_target, + pct, + contour_start="speed", + contour_end="effective", + **kwargs +): + """Select particles within eddies, advect them, return target observation and associated percentages + + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param GroupEddiesObservations eddies: GroupEddiesObservations considered + :param int t_start: julian day of the advection + :param np.array(int) i_target: corresponding obs where particles are advected + :param np.array(int) pct: corresponding percentage of avected particles + :param str contour_start: contour where particles are injected + :param str contour_end: contour where particles are counted after advection + :params dict kwargs: dict of params given to `advect` + + """ + # Obs from initial time + m_start = eddies.time == t_start + e = eddies.extract_with_mask(m_start) + + # to be able to get global index + translate_start = where(m_start)[0] + + # Create particles in specified contour + intern = False if contour_start == "effective" else True + x, y, i_start = e.create_particles(step_mesh, intern=intern) + # Advection + t_end, x, y = advect(x, y, c, t_start, **kwargs) + + # eddies at last date + m_end = eddies.time == t_end / 86400 + e_end = eddies.extract_with_mask(m_end) + + # to be able to get global index + translate_end = where(m_end)[0] + + # Id eddies for each alive particle in specified contour + intern = False if contour_end == "effective" else True + i_end = e_end.contains(x, y, intern=intern) + + # compute matrix and fill target array + get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct) + + +@njit(cache=True) +def get_targets(i_start, i_end, i_target, pct, nb_end): + """Compute target observation and associated percentages + + :param array(int) i_start: indices in time 0 + :param array(int) i_end: indices in time N + :param array(int) i_target: corresponding obs where particles are advected + :param array(int) pct: corresponding percentage of 
avected particles + :param int nb_end: number of contour at time N + """ + nb_start = i_target.shape[0] + # Matrix which will store count for every couple + counts = zeros((nb_start, nb_end), dtype=nb_types.int32) + # Number of particles in each origin observation + ref = zeros(nb_start, dtype=nb_types.int32) + # For each particle + for i in range(i_start.size): + i_end_ = i_end[i] + i_start_ = i_start[i] + ref[i_start_] += 1 + if i_end_ != -1: + counts[i_start_, i_end_] += 1 + # From i to j + for i in range(nb_start): + for j in range(nb_end): + count = counts[i, j] + if count == 0: + continue + pct_ = count / ref[i] * 100 + pct_0 = pct[i, 0] + # If percent is higher than previous stored in rank 0 + if pct_ > pct_0: + pct[i, 1] = pct_0 + pct[i, 0] = pct_ + i_target[i, 1] = i_target[i, 0] + i_target[i, 0] = j + # If percent is higher than previous stored in rank 1 + elif pct_ > pct[i, 1]: + pct[i, 1] = pct_ + i_target[i, 1] = j + + +@njit(cache=True) +def get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct): + """Compute target observation and associated percentages + + :param np.array(int) i_start: indices of associated contours at starting advection day + :param np.array(int) i_end: indices of associated contours after advection + :param np.array(int) translate_start: corresponding global indices at starting advection day + :param np.array(int) translate_end: corresponding global indices after advection + :param np.array(int) i_target: corresponding obs where particles are advected + :param np.array(int) pct: corresponding percentage of avected particles + """ + + nb_start, nb_end = translate_start.size, translate_end.size + # Matrix which will store count for every couple + count = zeros((nb_start, nb_end), dtype=nb_types.int32) + # Number of particles in each origin observation + ref = zeros(nb_start, dtype=nb_types.int32) + # For each particle + for i in range(i_start.size): + i_end_ = i_end[i] + i_start_ = i_start[i] + if i_end_ != -1: + count[i_start_, i_end_] += 1 + ref[i_start_] += 1 + for i in range(nb_start): + for j in range(nb_end): + pct_ = count[i, j] + # If there are particles from i to j + if pct_ != 0: + # Get percent + pct_ = pct_ / ref[i] * 100.0 + # Get indices in full dataset + i_, j_ = translate_start[i], translate_end[j] + pct_0 = pct[i_, 0] + if pct_ > pct_0: + pct[i_, 1] = pct_0 + pct[i_, 0] = pct_ + i_target[i_, 1] = i_target[i_, 0] + i_target[i_, 0] = j_ + elif pct_ > pct[i_, 1]: + pct[i_, 1] = pct_ + i_target[i_, 1] = j_ + return i_target, pct + + +class GroupEddiesObservations(EddiesObservations, ABC): + @abstractmethod + def fix_next_previous_obs(self): + pass + + @abstractmethod + def get_missing_indices(self, dt): + "Find indexes where observations are missing" + pass + + def filled_by_interpolation(self, mask): + """Fill selected values by interpolation + + :param array(bool) mask: True if must be filled by interpolation + + .. 
minigallery:: py_eddy_tracker.TrackEddiesObservations.filled_by_interpolation + """ + if self.track.size == 0: + return + nb_filled = mask.sum() + logger.info("%d obs will be filled (unobserved)", nb_filled) + + nb_obs = len(self) + index = arange(nb_obs) + + for field in self.fields: + if ( + field in ["n", "virtual", "track", "cost_association"] + or field in self.array_variables + ): + continue + self.obs[field][mask] = interp( + index[mask], index[~mask], self.obs[field][~mask] + ) + + def insert_virtual(self): + """Insert virtual observations on segments where observations are missing""" + + dt_theorical = median(self.time[1:] - self.time[:-1]) + indices = self.get_missing_indices(dt_theorical) + + logger.info("%d virtual observation will be added", indices.size) + + # new observations size + size_obs_corrected = self.time.size + indices.size + + # correction of indexes for new size + indices_corrected = indices + arange(indices.size) + + # creating mask with indexes + mask = zeros(size_obs_corrected, dtype=bool) + mask[indices_corrected] = 1 + + new_TEO = self.new_like(self, size_obs_corrected) + new_TEO.obs[~mask] = self.obs + new_TEO.filled_by_interpolation(mask) + new_TEO.virtual[:] = mask + new_TEO.fix_next_previous_obs() + return new_TEO + + def keep_tracks_by_date(self, date, nb_days): + """ + Find tracks that exist at date `date` and lasted at least `nb_days` after. + + :param int,float date: date where the tracks must exist + :param int,float nb_days: number of times the tracks must exist. Can be negative + + If nb_days is negative, it searches a track that exists at the date, + but existed at least `nb_days` before the date + """ + + time = self.time + + mask = zeros(time.shape, dtype=bool) + + for i, b0, b1 in self.iter_on(self.tracks): + _time = time[i] + + if date in _time and (date + nb_days) in _time: + mask[i] = True + + return self.extract_with_mask(mask) + + def particle_candidate_atlas( + self, + cube, + space_step, + dt, + start_intern=False, + end_intern=False, + callback_coherence=None, + finalize_coherence=None, + **kwargs + ): + """Select particles within eddies, advect them, return target observation and associated percentages + + :param `~py_eddy_tracker.dataset.grid.GridCollection` cube: GridCollection with speed for particles + :param float space_step: step between 2 particles + :param int dt: duration of advection + :param bool start_intern: Use intern or extern contour at injection, defaults to False + :param bool end_intern: Use intern or extern contour at end of advection, defaults to False + :param dict kwargs: dict of params given to advection + :param func callback_coherence: if None we will use cls.fill_coherence + :param func finalize_coherence: to apply on results of callback_coherence + :return (np.array,np.array): return target index and percent associate + """ + t_start, t_end = int(self.period[0]), int(self.period[1]) + # Pre-compute to get time index + i_sort, i_start, i_end = window_index( + self.time, arange(t_start, t_end + 1), 0.5 + ) + # Out shape + shape = (len(self), 2) + i_target, pct = full(shape, -1, dtype="i4"), zeros(shape, dtype="i1") + # Backward or forward + times = arange(t_start, t_end - dt) if dt > 0 else arange(t_start - dt, t_end) + + if callback_coherence is None: + callback_coherence = self.fill_coherence + indexs = dict() + results = list() + kw_coherence = dict(space_step=space_step, dt=dt, c=cube) + kw_coherence.update(kwargs) + for t in times: + logger.info( + "Coherence for time step : %s in [%s:%s]", t, times[0], 
times[-1] + ) + # Get index for origin + i = t - t_start + indexs0 = i_sort[i_start[i] : i_end[i]] + # Get index for end + i = t + dt - t_start + indexs1 = i_sort[i_start[i] : i_end[i]] + if indexs0.size == 0 or indexs1.size == 0: + continue + + results.append( + callback_coherence( + self, + i_target, + pct, + indexs0, + indexs1, + start_intern, + end_intern, + t_start=t, + **kw_coherence + ) + ) + indexs[results[-1]] = indexs0, indexs1 + + if finalize_coherence is not None: + finalize_coherence(results, indexs, i_target, pct) + return i_target, pct + + @classmethod + def fill_coherence( + cls, + network, + i_targets, + percents, + i_origin, + i_end, + start_intern, + end_intern, + **kwargs + ): + """_summary_ + + :param array i_targets: global target + :param array percents: + :param array i_origin: indices of origins + :param array i_end: indices of ends + :param bool start_intern: Use intern or extern contour at injection + :param bool end_intern: Use intern or extern contour at end of advection + """ + # Get contour data + contours_start = [ + network[label][i_origin] for label in cls.intern(start_intern) + ] + contours_end = [network[label][i_end] for label in cls.intern(end_intern)] + # Compute local coherence + i_local_targets, local_percents = particle_candidate_step( + contours_start=contours_start, contours_end=contours_end, **kwargs + ) + # Store + cls.merge_particle_result( + i_targets, percents, i_local_targets, local_percents, i_origin, i_end + ) + + @staticmethod + def merge_particle_result( + i_targets, percents, i_local_targets, local_percents, i_origin, i_end + ): + """Copy local result in merged result with global indexation + + :param array i_targets: global target + :param array percents: + :param array i_local_targets: local index target + :param array local_percents: + :param array i_origin: indices of origins + :param array i_end: indices of ends + """ + m = i_local_targets != -1 + i_local_targets[m] = i_end[i_local_targets[m]] + i_targets[i_origin] = i_local_targets + percents[i_origin] = local_percents diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py index 96009220..f0b9d7cc 100644 --- a/src/py_eddy_tracker/observations/network.py +++ b/src/py_eddy_tracker/observations/network.py @@ -1,75 +1,2012 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . 
+Class to create network of observations +""" +from glob import glob +import logging +import time +from datetime import timedelta, datetime +import os +import netCDF4 +from numba import njit, types as nb_types +from numba.typed import List +import numpy as np +from numpy import ( + arange, + array, + bincount, + bool_, + concatenate, -Copyright (c) 2014-2017 by Evan Mason and Antoine Delepoulle -Email: emason@imedea.uib-csic.es -=========================================================================== + empty, + nan, + ones, + percentile, + uint16, + uint32, + unique, + where, + zeros, +) +import zarr -tracking.py +from ..dataset.grid import GridCollection +from ..generic import build_index, wrap_longitude +from ..poly import bbox_intersection, vertice_overlap +from .groups import GroupEddiesObservations, get_missing_indices, particle_candidate +from .observation import EddiesObservations +from .tracking import TrackEddiesObservations, track_loess_filter, track_median_filter -Version 3.0.0 +logger = logging.getLogger("pet") -=========================================================================== -""" -import logging -from glob import glob -from numpy import array, empty, arange, unique, bincount, uint32 -from numba import njit -from .observation import EddiesObservations -from .tracking import TrackEddiesObservations -from ..poly import bbox_intersection, vertice_overlap +class Singleton(type): + _instances = {} -logger = logging.getLogger("pet") + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] -class Network: - __slots__ = ("window", "filenames", "contour_name", "nb_input", "xname", "yname") - # To be used like a buffer +class Buffer(metaclass=Singleton): + __slots__ = ( + "buffersize", + "contour_name", + "xname", + "yname", + "memory", + ) DATA = dict() FLIST = list() - NOGROUP = TrackEddiesObservations.NOGROUP - def __init__(self, input_regex, window=5, intern=False): - self.window = window + def __init__(self, buffersize, intern=False, memory=False): + self.buffersize = buffersize self.contour_name = EddiesObservations.intern(intern, public_label=True) self.xname, self.yname = EddiesObservations.intern(intern) - self.filenames = glob(input_regex) - self.filenames.sort() - self.nb_input = len(self.filenames) + self.memory = memory def load_contour(self, filename): + if isinstance(filename, EddiesObservations): + return filename[self.xname], filename[self.yname] if filename not in self.DATA: - if len(self.FLIST) > self.window: + if len(self.FLIST) > self.buffersize: self.DATA.pop(self.FLIST.pop(0)) - e = EddiesObservations.load_file(filename, include_vars=self.contour_name) + if self.memory: + # Only if netcdf + with open(filename, "rb") as h: + e = EddiesObservations.load_file(h, include_vars=self.contour_name) + else: + e = EddiesObservations.load_file( + filename, include_vars=self.contour_name + ) + self.FLIST.append(filename) self.DATA[filename] = e[self.xname], e[self.yname] return self.DATA[filename] + +@njit(cache=True) +def fix_next_previous_obs(next_obs, previous_obs, flag_virtual): + """When an observation is virtual, we have to fix the previous and next obs + + :param np.array(int) next_obs : index of next observation from network + :param np.array(int previous_obs: index of previous observation from network + :param np.array(bool) flag_virtual: if observation is virtual or not + """ + + for i_o in range(next_obs.size): + if not flag_virtual[i_o]: + continue + + 
# if there are several consecutive virtuals, some values are written multiple times. + # but it should not be slow + next_obs[i_o - 1] = i_o + next_obs[i_o] = i_o + 1 + previous_obs[i_o] = i_o - 1 + previous_obs[i_o + 1] = i_o + + +class NetworkObservations(GroupEddiesObservations): + __slots__ = ("_index_network", "_index_segment_track", "_segment_track_array") + NOGROUP = 0 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.reset_index() + + def __repr__(self): + m_event, s_event = ( + self.merging_event(only_index=True, triplet=True)[0], + self.splitting_event(only_index=True, triplet=True)[0], + ) + period = (self.period[1] - self.period[0]) / 365.25 + nb_by_network = self.network_size() + nb_trash = 0 if self.ref_index != 0 else nb_by_network[0] + lifetime=self.lifetime + big = 50_000 + infos = [ + f"Atlas with {self.nb_network} networks ({self.nb_network / period:0.0f} networks/year)," + f" {self.nb_segment} segments ({self.nb_segment / period:0.0f} segments/year), {len(self)} observations ({len(self) / period:0.0f} observations/year)", + f" {m_event.size} merging ({m_event.size / period:0.0f} merging/year), {s_event.size} splitting ({s_event.size / period:0.0f} splitting/year)", + f" with {(nb_by_network > big).sum()} network with more than {big} obs and the biggest have {nb_by_network.max()} observations ({nb_by_network[nb_by_network > big].sum()} observations cumulate)", + f" {nb_trash} observations in trash", + f" {lifetime.max()} days max of lifetime", + ] + return "\n".join(infos) + + def reset_index(self): + self._index_network = None + self._index_segment_track = None + self._segment_track_array = None + + def find_segments_relative(self, obs, stopped=None, order=1): + """ + Find all relative segments from obs linked with merging/splitting events at a specific order. + + :param int obs: index of observation after the event + :param int stopped: index of observation before the event + :param int order: order of relatives accepted + :return: all relative segments + :rtype: EddiesObservations + """ + + # extraction of network where the event is + network_id = self.tracks[obs] + nw = self.network(network_id) + + # indice of observation in new subnetwork + i_obs = where(nw.segment == self.segment[obs])[0][0] + + if stopped is None: + return nw.relatives(i_obs, order=order) + + else: + i_stopped = where(nw.segment == self.segment[stopped])[0][0] + return nw.relatives([i_obs, i_stopped], order=order) + + def get_missing_indices(self, dt): + """Find indices where observations are missing. + + As network have all untracked observation in tracknumber `self.NOGROUP`, + we don't compute them + + :param int,float dt: theorical delta time between 2 observations + """ + return get_missing_indices( + self.time, self.track, dt=dt, flag_untrack=True, indice_untrack=self.NOGROUP + ) + + def fix_next_previous_obs(self): + """Function used after 'insert_virtual', to correct next_obs and + previous obs. 
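+
+        For a virtual observation at index ``i``, the helper relinks the chain as
+        ``next_obs[i - 1] = i``, ``next_obs[i] = i + 1``, ``previous_obs[i] = i - 1``
+        and ``previous_obs[i + 1] = i``.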
+ """ + + fix_next_previous_obs(self.next_obs, self.previous_obs, self.virtual) + + @property + def index_network(self): + if self._index_network is None: + self._index_network = build_index(self.track) + return self._index_network + + @property + def index_segment_track(self): + if self._index_segment_track is None: + self._index_segment_track = build_index(self.segment_track_array) + return self._index_segment_track + + def segment_size(self): + return self.index_segment_track[1] - self.index_segment_track[0] + + @property + def ref_segment_track_index(self): + return self.index_segment_track[2] + + @property + def ref_index(self): + return self.index_network[2] + + @property + def lifetime(self): + """Return lifetime for each observation""" + lt=self.networks_period.astype("int") + nb_by_network=self.network_size() + return lt.repeat(nb_by_network) + + def network_segment_size(self, id_networks=None): + """Get number of segment by network + + :return array: + """ + i0, i1, ref = build_index(self.track[self.index_segment_track[0]]) + if id_networks is None: + return i1 - i0 + else: + i = id_networks - ref + return i1[i] - i0[i] + + def network_size(self, id_networks=None): + """ + Return size for specified network + + :param list,array, None id_networks: ids to identify network + """ + if id_networks is None: + return self.index_network[1] - self.index_network[0] + else: + i = id_networks - self.index_network[2] + return self.index_network[1][i] - self.index_network[0][i] + + @property + def networks_period(self): + """ + Return period for each network + """ + return get_period_with_index(self.time, *self.index_network[:2]) + + + + def unique_segment_to_id(self, id_unique): + """Return id network and id segment for a unique id + + :param array id_unique: + """ + i = self.index_segment_track[0][id_unique] - self.ref_segment_track_index + return self.track[i], self.segment[i] + + def segment_slice(self, id_network, id_segment): + """ + Return slice for one segment + + :param int id_network: id to identify network + :param int id_segment: id to identify segment + """ + raise Exception("need to be implemented") + + def network_slice(self, id_network): + """ + Return slice for one network + + :param int id_network: id to identify network + """ + i = id_network - self.index_network[2] + i_start, i_stop = self.index_network[0][i], self.index_network[1][i] + return slice(i_start, i_stop) + + @property + def elements(self): + elements = super().elements + elements.extend( + [ + "track", + "segment", + "next_obs", + "previous_obs", + "next_cost", + "previous_cost", + ] + ) + return list(set(elements)) + + def astype(self, cls): + new = cls.new_like(self, self.shape) + for k in new.fields: + if k in self.fields: + new[k][:] = self[k][:] + new.sign_type = self.sign_type + return new + + def longer_than(self, nb_day_min=-1, nb_day_max=-1): + """ + Select network on time duration + + :param int nb_day_min: Minimal number of days covered by one network, if negative -> not used + :param int nb_day_max: Maximal number of days covered by one network, if negative -> not used + """ + return self.extract_with_mask(self.mask_longer_than(nb_day_min, nb_day_max)) + + def mask_longer_than(self, nb_day_min=-1, nb_day_max=-1): + """ + Select network on time duration + + :param int nb_day_min: Minimal number of days covered by one network, if negative -> not used + :param int nb_day_max: Maximal number of days covered by one network, if negative -> not used + """ + if nb_day_max < 0: + nb_day_max = 1000000000000 + 
mask = zeros(self.shape, dtype="bool") + t = self.time + for i, _, _ in self.iter_on(self.track): + nb = i.stop - i.start + if nb == 0: + continue + if nb_day_min <= (ptp(t[i]) + 1) <= nb_day_max: + mask[i] = True + return mask + + @classmethod + def from_split_network(cls, group_dataset, indexs, **kwargs): + """ + Build a NetworkObservations object with Group dataset and indices + + :param TrackEddiesObservations group_dataset: Group dataset + :param indexs: result from split_network + :return: NetworkObservations + """ + index_order = indexs.argsort(order=("group", "track", "time")) + network = cls.new_like(group_dataset, len(group_dataset), **kwargs) + network.sign_type = group_dataset.sign_type + for field in group_dataset.elements: + if field not in network.elements: + continue + network[field][:] = group_dataset[field][index_order] + network.segment[:] = indexs["track"][index_order] + # n & p must be re-indexed + n, p = indexs["next_obs"][index_order], indexs["previous_obs"][index_order] + # we add 2 for -1 index return index -1 + translate = -ones(index_order.max() + 2, dtype="i4") + translate[index_order] = arange(index_order.shape[0]) + network.next_obs[:] = translate[n] + network.previous_obs[:] = translate[p] + network.next_cost[:] = indexs["next_cost"][index_order] + network.previous_cost[:] = indexs["previous_cost"][index_order] + return network + + def infos(self, label=""): + return f"{len(self)} obs {unique(self.segment).shape[0]} segments" + + def correct_close_events(self, nb_days_max=20): + """ + Transform event where + segment A splits from segment B, then x days after segment B merges with A + to + segment A splits from segment B then x days after segment A merges with B (B will be longer) + These events have to last less than `nb_days_max` to be changed. 
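+        In other words, the delay between the splitting and the following merging
+        must be shorter than `nb_days_max` days.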
+ + + ------------------- A + / / + B -------------------- + to + --A-- + / \ + B ----------------------------------- + + :param float nb_days_max: maximum time to search for splitting-merging event + """ + + _time = self.time + # segment used to correct and track changes + segment = self.segment_track_array.copy() + # final segment used to copy into self.segment + segment_copy = self.segment + + segments_connexion = dict() + + previous_obs, next_obs = self.previous_obs, self.next_obs + + # record for every segment the slice, index of next obs & index of previous obs + for i, seg, _ in self.iter_on(segment): + if i.start == i.stop: + continue + + i_p, i_n = previous_obs[i.start], next_obs[i.stop - 1] + segments_connexion[seg] = [i, i_p, i_n] + + for seg in sorted(segments_connexion.keys()): + seg_slice, _, i_seg_n = segments_connexion[seg] + + # the segment ID has to be corrected, because we may have changed it since + seg_corrected = segment[seg_slice.stop - 1] + + # we keep the real segment number + seg_corrected_copy = segment_copy[seg_slice.stop - 1] + if i_seg_n == -1: + continue + + # if segment is split + n_seg = segment[i_seg_n] + + seg2_slice, i2_seg_p, _ = segments_connexion[n_seg] + if i2_seg_p == -1: + continue + p2_seg = segment[i2_seg_p] + + # if it merges on the first in a certain time + if (p2_seg == seg_corrected) and ( + _time[i_seg_n] - _time[i2_seg_p] < nb_days_max + ): + my_slice = slice(i_seg_n, seg2_slice.stop) + # correct the factice segment + segment[my_slice] = seg_corrected + # correct the good segment + segment_copy[my_slice] = seg_corrected_copy + previous_obs[i_seg_n] = seg_slice.stop - 1 + + segments_connexion[seg_corrected][0] = my_slice + + return self.sort() + + def sort(self, order=("track", "segment", "time")): + """ + Sort observations + + :param tuple order: order or sorting. Given to :func:`numpy.argsort` + """ + index_order = self.obs.argsort(order=order, kind="mergesort") + self.reset_index() + for field in self.fields: + self[field][:] = self[field][index_order] + + nb_obs = len(self) + # we add 1 for -1 index return index -1 + translate = -ones(nb_obs + 1, dtype="i4") + translate[index_order] = arange(nb_obs) + # next & previous must be re-indexed + self.next_obs[:] = translate[self.next_obs] + self.previous_obs[:] = translate[self.previous_obs] + return index_order, translate + + def obs_relative_order(self, i_obs): + self.only_one_network() + return self.segment_relative_order(self.segment[i_obs]) + + def find_link(self, i_observations, forward=True, backward=False): + """ + Find all observations where obs `i_observation` could be + in future or past. + + If forward=True, search all observations where water + from obs "i_observation" could go + + If backward=True, search all observation + where water from obs `i_observation` could come from + + :param int,iterable(int) i_observation: + indices of observation. Can be + int, or iterable of int. + :param bool forward, backward: + if forward, search observations after obs. 
+ else mode==backward search before obs + + """ + + i_obs = ( + [i_observations] + if not hasattr(i_observations, "__iter__") + else i_observations + ) + + segment = self.segment_track_array + previous_obs, next_obs = self.previous_obs, self.next_obs + + segments_connexion = dict() + + for i_slice, seg, _ in self.iter_on(segment): + if i_slice.start == i_slice.stop: + continue + + i_p, i_n = previous_obs[i_slice.start], next_obs[i_slice.stop - 1] + p_seg, n_seg = segment[i_p], segment[i_n] + + # dumping slice into dict + if seg not in segments_connexion: + segments_connexion[seg] = [i_slice, [], []] + else: + segments_connexion[seg][0] = i_slice + + if i_p != -1: + if p_seg not in segments_connexion: + segments_connexion[p_seg] = [None, [], []] + + # backward + segments_connexion[seg][2].append((i_slice.start, i_p, p_seg)) + # forward + segments_connexion[p_seg][1].append((i_p, i_slice.start, seg)) + + if i_n != -1: + if n_seg not in segments_connexion: + segments_connexion[n_seg] = [None, [], []] + + # forward + segments_connexion[seg][1].append((i_slice.stop - 1, i_n, n_seg)) + # backward + segments_connexion[n_seg][2].append((i_n, i_slice.stop - 1, seg)) + + mask = zeros(segment.size, dtype=bool) + + def func_forward(seg, indice): + seg_slice, _forward, _ = segments_connexion[seg] + + mask[indice : seg_slice.stop] = True + for i_begin, i_end, seg2 in _forward: + if i_begin < indice: + continue + + if not mask[i_end]: + func_forward(seg2, i_end) + + def func_backward(seg, indice): + seg_slice, _, _backward = segments_connexion[seg] + + mask[seg_slice.start : indice + 1] = True + for i_begin, i_end, seg2 in _backward: + if i_begin > indice: + continue + + if not mask[i_end]: + func_backward(seg2, i_end) + + for indice in i_obs: + if forward: + func_forward(segment[indice], indice) + + if backward: + func_backward(segment[indice], indice) + + return self.extract_with_mask(mask) + + def connexions(self, multi_network=False): + """Create dictionnary for each segment, gives the segments in interaction with + + :param bool multi_network: use segment_track_array instead of segment, defaults to False + :return dict: Return dict of set, for each seg id we get set of segment which have event with him + """ + if multi_network: + segment = self.segment_track_array + else: + self.only_one_network() + segment = self.segment + segments_connexion = dict() + + def add_seg(s1, s2): + if s1 not in segments_connexion: + segments_connexion[s1] = set() + if s2 not in segments_connexion: + segments_connexion[s2] = set() + segments_connexion[s1].add(s2), segments_connexion[s2].add(s1) + + # Get index for each segment + i0, i1, _ = self.index_segment_track + i1 = i1 - 1 + # Check if segment merge + i_next = self.next_obs[i1] + m_n = i_next != -1 + # Check if segment come from splitting + i_previous = self.previous_obs[i0] + m_p = i_previous != -1 + # For each split + for s1, s2 in zip(segment[i_previous[m_p]], segment[i0[m_p]]): + add_seg(s1, s2) + # For each merge + for s1, s2 in zip(segment[i_next[m_n]], segment[i1[m_n]]): + add_seg(s1, s2) + return segments_connexion + + @classmethod + def __close_segment(cls, father, shift, connexions, distance): + i_father = father - shift + if distance[i_father] == -1: + distance[i_father] = 0 + d_target = distance[i_father] + 1 + for son in connexions.get(father, list()): + i_son = son - shift + d_son = distance[i_son] + if d_son == -1 or d_son > d_target: + distance[i_son] = d_target + else: + continue + cls.__close_segment(son, shift, connexions, distance) + + def 
segment_relative_order(self, seg_origine): + """ + Compute the relative order of each segment to the chosen segment + """ + self.only_one_network() + i_s, i_e, i_ref = build_index(self.segment) + segment_connexions = self.connexions() + relative_tr = -ones(i_s.shape, dtype="i4") + self.__close_segment(seg_origine, i_ref, segment_connexions, relative_tr) + d = -ones(self.shape) + for i0, i1, v in zip(i_s, i_e, relative_tr): + if i0 == i1: + continue + d[i0:i1] = v + return d + + def relatives(self, obs, order=2): + """ + Extract the segments at a certain order from multiple observations. + + :param iterable,int obs: + indices of observation for relatives computation. Can be one observation (int) + or collection of observations (iterable(int)) + :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ... + :return: all segments' relatives + :rtype: EddiesObservations + """ + segment = self.segment_track_array + previous_obs, next_obs = self.previous_obs, self.next_obs + + segments_connexion = dict() + + for i_slice, seg, _ in self.iter_on(segment): + if i_slice.start == i_slice.stop: + continue + + i_p, i_n = previous_obs[i_slice.start], next_obs[i_slice.stop - 1] + p_seg, n_seg = segment[i_p], segment[i_n] + + # dumping slice into dict + if seg not in segments_connexion: + segments_connexion[seg] = [i_slice, []] + else: + segments_connexion[seg][0] = i_slice + + if i_p != -1: + if p_seg not in segments_connexion: + segments_connexion[p_seg] = [None, []] + + # backward + segments_connexion[seg][1].append(p_seg) + segments_connexion[p_seg][1].append(seg) + + if i_n != -1: + if n_seg not in segments_connexion: + segments_connexion[n_seg] = [None, []] + + # forward + segments_connexion[seg][1].append(n_seg) + segments_connexion[n_seg][1].append(seg) + + i_obs = [obs] if not hasattr(obs, "__iter__") else obs + distance = zeros(segment.size, dtype=uint16) - 1 + + def loop(seg, dist=1): + i_slice, links = segments_connexion[seg] + d = distance[i_slice.start] + + if dist < d and dist <= order: + distance[i_slice] = dist + for _seg in links: + loop(_seg, dist + 1) + + for indice in i_obs: + loop(segment[indice], 0) + + return self.extract_with_mask(distance <= order) + + # keep old names, for backward compatibility + relative = relatives + + def close_network(self, other, nb_obs_min=10, **kwargs): + """ + Get close network from another atlas. + + :param self other: Atlas to compare + :param int nb_obs_min: Minimal number of overlap for one trajectory + :param dict kwargs: keyword arguments for match function + :return: return other atlas reduced to common tracks with self + + .. 
warning:: + It could be a costly operation for huge dataset + """ + p0, p1 = self.period + indexs = list() + for i_self, i_other, t0, t1 in self.align_on(other, bins=range(p0, p1 + 2)): + i, j, s = self.match(other, i_self=i_self, i_other=i_other, **kwargs) + indexs.append(other.re_reference_index(j, i_other)) + indexs = concatenate(indexs) + tr, nb = unique(other.track[indexs], return_counts=True) + m = zeros(other.track.shape, dtype=bool) + for i in tr[nb >= nb_obs_min]: + m[other.network_slice(i)] = True + return other.extract_with_mask(m) + + def normalize_longitude(self): + """Normalize all longitudes + + Normalize longitude field and in the same range : + - longitude_max + - contour_lon_e (how to do if in raw) + - contour_lon_s (how to do if in raw) + """ + i_start, i_stop, _ = self.index_network + lon0 = (self.lon[i_start] - 180).repeat(i_stop - i_start) + logger.debug("Normalize longitude") + self.lon[:] = (self.lon - lon0) % 360 + lon0 + if "lon_max" in self.fields: + logger.debug("Normalize longitude_max") + self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180 + if not self.raw_data: + if "contour_lon_e" in self.fields: + logger.debug("Normalize effective contour longitude") + self.contour_lon_e[:] = ( + (self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180 + ).T + if "contour_lon_s" in self.fields: + logger.debug("Normalize speed contour longitude") + self.contour_lon_s[:] = ( + (self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180 + ).T + + def numbering_segment(self, start=0): + """ + New numbering of segment + """ + for i, _, _ in self.iter_on("track"): + new_numbering(self.segment[i], start) + + def numbering_network(self, start=1): + """ + New numbering of network + """ + new_numbering(self.track, start) + + def only_one_network(self): + """ + Raise a warning or error? + if there are more than one network + """ + _, i_start, _ = self.index_network + if i_start.size > 1: + raise Exception("Several networks") + + def position_filter(self, median_half_window, loess_half_window): + self.median_filter(median_half_window, "time", "lon").loess_filter( + loess_half_window, "time", "lon" + ) + self.median_filter(median_half_window, "time", "lat").loess_filter( + loess_half_window, "time", "lat" + ) + + def loess_filter(self, half_window, xfield, yfield, inplace=True): + result = track_loess_filter( + half_window, self.obs[xfield], self.obs[yfield], self.segment_track_array + ) + if inplace: + self.obs[yfield] = result + return self + return result + + def median_filter(self, half_window, xfield, yfield, inplace=True): + result = track_median_filter( + half_window, self[xfield], self[yfield], self.segment_track_array + ) + if inplace: + self[yfield][:] = result + return self + return result + + def display_timeline( + self, + ax, + event=True, + field=None, + method=None, + factor=1, + colors_mode="roll", + **kwargs, + ): + """ + Plot the timeline of a network. + Must be called on only one network. + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param bool event: if True, draw the splitting and merging events + :param str,array field: yaxis values, if None, segments are used + :param str method: if None, mean values are used + :param float factor: to multiply field + :param str colors_mode: + color of lines. 
"roll" means looping through colors, + "y" means color adapt the y values (for matching color plots) + :return: plot mappable + """ + self.only_one_network() + j = 0 + line_kw = dict( + ls="-", + marker="+", + markersize=6, + zorder=1, + lw=3, + ) + line_kw.update(kwargs) + mappables = dict(lines=list()) + + if event: + mappables.update( + self.event_timeline( + ax, + field=field, + method=method, + factor=factor, + colors_mode=colors_mode, + ) + ) + if field is not None: + field = self.parse_varname(field) + for i, b0, b1 in self.iter_on("segment"): + x = self.time_datetime64[i] + if x.shape[0] == 0: + continue + if field is None: + y = b0 * ones(x.shape) + else: + if method == "all": + y = field[i] * factor + else: + y = field[i].mean() * ones(x.shape) * factor + + if colors_mode == "roll": + _color = self.get_color(j) + elif colors_mode == "y": + _color = self.get_color(b0 - 1) + else: + raise NotImplementedError(f"colors_mode '{colors_mode}' not defined") + + line = ax.plot(x, y, **line_kw, color=_color)[0] + mappables["lines"].append(line) + j += 1 + + return mappables + + def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="roll"): + """Mark events in plot""" + j = 0 + events = dict(splitting=[], merging=[]) + + # TODO : fill mappables dict + y_seg = dict() + _time = self.time_datetime64 + + if field is not None and method != "all": + for i, b0, _ in self.iter_on("segment"): + y = self.parse_varname(field)[i] + if y.shape[0] != 0: + y_seg[b0] = y.mean() * factor + mappables = dict() + for i, b0, b1 in self.iter_on("segment"): + x = _time[i] + if x.shape[0] == 0: + continue + + if colors_mode == "roll": + _color = self.get_color(j) + elif colors_mode == "y": + _color = self.get_color(b0 - 1) + else: + raise NotImplementedError(f"colors_mode '{colors_mode}' not defined") + + event_kw = dict(color=_color, ls="-", zorder=1) + + i_n, i_p = ( + self.next_obs[i.stop - 1], + self.previous_obs[i.start], + ) + if field is None: + y0 = b0 + else: + if method == "all": + y0 = self.parse_varname(field)[i.stop - 1] * factor + else: + y0 = y_seg[b0] + if i_n != -1: + seg_next = self.segment[i_n] + y1 = ( + seg_next + if field is None + else ( + self.parse_varname(field)[i_n] * factor + if method == "all" + else y_seg[seg_next] + ) + ) + ax.plot((x[-1], _time[i_n]), (y0, y1), **event_kw)[0] + events["merging"].append((x[-1], y0)) + + if i_p != -1: + seg_previous = self.segment[i_p] + if field is not None and method == "all": + y0 = self[field][i.start] * factor + y1 = ( + seg_previous + if field is None + else ( + self.parse_varname(field)[i_p] * factor + if method == "all" + else y_seg[seg_previous] + ) + ) + ax.plot((x[0], _time[i_p]), (y0, y1), **event_kw)[0] + events["splitting"].append((x[0], y0)) + + j += 1 + + kwargs = dict(color="k", zorder=-1, linestyle=" ") + if len(events["splitting"]) > 0: + X, Y = list(zip(*events["splitting"])) + ref = ax.plot( + X, Y, marker="*", markersize=12, label="splitting events", **kwargs + )[0] + mappables.setdefault("events", []).append(ref) + + if len(events["merging"]) > 0: + X, Y = list(zip(*events["merging"])) + ref = ax.plot( + X, Y, marker="H", markersize=10, label="merging events", **kwargs + )[0] + mappables.setdefault("events", []).append(ref) + + return mappables + + def mean_by_segment(self, y, **kw): + kw["dtype"] = y.dtype + return self.map_segment(lambda x: x.mean(), y, **kw) + + def map_segment(self, method, y, same=True, **kw): + if same: + out = empty(y.shape, **kw) + else: + out = list() + for i, _, _ in 
self.iter_on(self.segment_track_array): + res = method(y[i]) + if same: + out[i] = res + else: + if isinstance(i, slice): + if i.start == i.stop: + continue + elif len(i) == 0: + continue + out.append(res) + if not same: + out = array(out) + return out + + def map_network(self, method, y, same=True, return_dict=False, **kw): + """ + Transform data `y` with method `method` for each track. + + :param Callable method: method to apply on each track + :param np.array y: data where to apply method + :param bool same: if True, return an array with the same size than y. Else, return a list with the edited tracks + :param bool return_dict: if None, mean values are used + :param float kw: to multiply field + :return: array or dict of result from method for each network + """ + + if same and return_dict: + raise NotImplementedError( + "both conditions 'same' and 'return_dict' should no be true" + ) + + if same: + out = empty(y.shape, **kw) + + elif return_dict: + out = dict() + + else: + out = list() + + for i, b0, b1 in self.iter_on(self.track): + res = method(y[i]) + if same: + out[i] = res + + elif return_dict: + out[b0] = res + + else: + if isinstance(i, slice): + if i.start == i.stop: + continue + elif len(i) == 0: + continue + out.append(res) + + if not same and not return_dict: + out = array(out) + return out + + def scatter_timeline( + self, + ax, + name, + factor=1, + event=True, + yfield=None, + yfactor=1, + method=None, + **kwargs, + ): + """ + Must be called on only one network + """ + self.only_one_network() + y = (self.segment if yfield is None else self.parse_varname(yfield)) * yfactor + if method == "all": + pass + else: + y = self.mean_by_segment(y) + mappables = dict() + if event: + mappables.update( + self.event_timeline(ax, field=yfield, method=method, factor=yfactor) + ) + if "c" not in kwargs: + v = self.parse_varname(name) + kwargs["c"] = v * factor + mappables["scatter"] = ax.scatter(self.time_datetime64, y, **kwargs) + return mappables + + def event_map(self, ax, **kwargs): + """Add the merging and splitting events to a map""" + j = 0 + mappables = dict() + symbol_kw = dict( + markersize=10, + color="k", + ) + symbol_kw.update(kwargs) + symbol_kw_split = symbol_kw.copy() + symbol_kw_split["markersize"] += 4 + for i, b0, b1 in self.iter_on("segment"): + nb = i.stop - i.start + if nb == 0: + continue + event_kw = dict(color=self.COLORS[j % self.NB_COLORS], ls="-", **kwargs) + i_n, i_p = ( + self.next_obs[i.stop - 1], + self.previous_obs[i.start], + ) + + if i_n != -1: + y0, y1 = self.lat[i.stop - 1], self.lat[i_n] + x0, x1 = self.lon[i.stop - 1], self.lon[i_n] + ax.plot((x0, x1), (y0, y1), **event_kw)[0] + ax.plot(x0, y0, marker="H", **symbol_kw)[0] + if i_p != -1: + y0, y1 = self.lat[i.start], self.lat[i_p] + x0, x1 = self.lon[i.start], self.lon[i_p] + ax.plot((x0, x1), (y0, y1), **event_kw)[0] + ax.plot(x0, y0, marker="*", **symbol_kw_split)[0] + + j += 1 + return mappables + + def scatter( + self, + ax, + name="time", + factor=1, + ref=None, + edgecolor_cycle=None, + **kwargs, + ): + """ + This function scatters the path of each network, with the merging and splitting events + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param str,array,None name: + variable used to fill the contours, if None all elements have the same color + :param float,None ref: if defined, ref is used as western boundary + :param float factor: multiply value by + :param list edgecolor_cycle: list of colors + :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.scatter` + 
:return: a dict of scattered mappables + """ + mappables = dict() + nb_colors = len(edgecolor_cycle) if edgecolor_cycle else None + x = self.longitude + if ref is not None: + x = (x - ref) % 360 + ref + kwargs = kwargs.copy() + if nb_colors: + edgecolors = list() + seg_previous = self.segment[0] + j = 0 + for seg in self.segment: + if seg != seg_previous: + j += 1 + edgecolors.append(edgecolor_cycle[j % nb_colors]) + seg_previous = seg + mappables["edges"] = ax.scatter( + x, self.latitude, edgecolor=edgecolors, **kwargs + ) + kwargs.pop("linewidths", None) + kwargs["lw"] = 0 + if name is not None and "c" not in kwargs: + v = self.parse_varname(name) + kwargs["c"] = v * factor + mappables["scatter"] = ax.scatter(x, self.latitude, **kwargs) + return mappables + + def extract_event(self, indices): + nb = len(indices) + new = EddiesObservations( + nb, + track_extra_variables=self.track_extra_variables, + track_array_variables=self.track_array_variables, + array_variables=self.array_variables, + only_variables=self.only_variables, + raw_data=self.raw_data, + ) + + for k in new.fields: + new[k][:] = self[k][indices] + new.sign_type = self.sign_type + return new + + @property + def segment_track_array(self): + """Return a unique segment id when multiple networks are considered""" + if self._segment_track_array is None: + self._segment_track_array = build_unique_array(self.segment, self.track) + return self._segment_track_array + + def birth_event(self, only_index=False): + """Extract birth events.""" + i_start, _, _ = self.index_segment_track + indices = i_start[self.previous_obs[i_start] == -1] + if self.first_is_trash(): + indices = indices[1:] + if only_index: + return indices + else : + return self.extract_event(indices) + + generation_event = birth_event + + def death_event(self, only_index=False): + """Extract death events.""" + _, i_stop, _ = self.index_segment_track + indices = i_stop[self.next_obs[i_stop - 1] == -1] - 1 + if self.first_is_trash(): + indices = indices[1:] + if only_index: + return indices + else : + return self.extract_event(indices) + + dissipation_event = death_event + + def merging_event(self, triplet=False, only_index=False): + """Return observation after a merging event. + + If `triplet=True` return the eddy after a merging event, the eddy before the merging event, + and the eddy stopped due to merging. + """ + # Get start and stop for each segment, there is no empty segment + _, i1, _ = self.index_segment_track + # Get last index for each segment + i_stop = i1 - 1 + # Get target index + idx_m1 = self.next_obs[i_stop] + # Get mask and valid target + m = idx_m1 != -1 + idx_m1 = idx_m1[m] + # Sort by time event + i = self.time[idx_m1].argsort() + idx_m1 = idx_m1[i] + if triplet: + # Get obs before target + idx_m0_stop = i_stop[m][i] + idx_m0 = self.previous_obs[idx_m1].copy() + + if triplet: + if only_index: + return idx_m1, idx_m0, idx_m0_stop + else: + return ( + self.extract_event(idx_m1), + self.extract_event(idx_m0), + self.extract_event(idx_m0_stop), + ) + else: + idx_m1 = unique(idx_m1) + if only_index: + return idx_m1 + else: + return self.extract_event(idx_m1) + + def splitting_event(self, triplet=False, only_index=False): + """Return observation before a splitting event. + + If `triplet=True` return the eddy before a splitting event, the eddy after the splitting event, + and the eddy starting due to splitting. 
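+
+        Illustrative sketch (``network`` is any ``NetworkObservations`` instance):
+
+        .. code-block:: python
+
+            # indices of: the eddy before the split, the eddy after the split,
+            # and the eddy starting due to the split
+            i_before, i_after, i_started = network.splitting_event(
+                triplet=True, only_index=True
+            )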
+ """ + # Get start and stop for each segment, there is no empty segment + i_start, _, _ = self.index_segment_track + # Get target index + idx_s0 = self.previous_obs[i_start] + # Get mask and valid target + m = idx_s0 != -1 + idx_s0 = idx_s0[m] + # Sort by time event + i = self.time[idx_s0].argsort() + idx_s0 = idx_s0[i] + if triplet: + # Get obs after target + idx_s1_start = i_start[m][i] + idx_s1 = self.next_obs[idx_s0].copy() + + if triplet: + if only_index: + return idx_s0, idx_s1, idx_s1_start + else: + return ( + self.extract_event(idx_s0), + self.extract_event(idx_s1), + self.extract_event(idx_s1_start), + ) + + else: + idx_s0 = unique(idx_s0) + if only_index: + return idx_s0 + else: + return self.extract_event(idx_s0) + + def dissociate_network(self): + """ + Dissociate networks with no known interaction (splitting/merging) + """ + tags = self.tag_segment() + if self.track[0] == 0: + tags -= 1 + self.track[:] = tags[self.segment_track_array] + return self.sort() + + def network_segment(self, id_network, id_segment): + return self.extract_with_mask(self.segment_slice(id_network, id_segment)) + + def network(self, id_network): + return self.extract_with_mask(self.network_slice(id_network)) + + def networks_mask(self, id_networks, segment=False): + if segment: + return generate_mask_from_ids( + id_networks, self.track.size, *self.index_segment_track + ) + else: + return generate_mask_from_ids( + id_networks, self.track.size, *self.index_network + ) + + def networks(self, id_networks): + return self.extract_with_mask( + generate_mask_from_ids( + array(id_networks), self.track.size, *self.index_network + ) + ) + + @property + def nb_network(self): + """ + Count and return number of network + """ + return (self.network_size() != 0).sum() + + @property + def nb_segment(self): + """ + Count and return number of segment in all network + """ + return self.index_segment_track[0].size + + def identify_in(self, other, size_min=1, segment=False): + """ + Return couple of segment or network which are equal + + :param other: other atlas to compare + :param int size_min: number of observation in network/segment + :param bool segment: segment mode + """ + if segment: + counts = self.segment_size(), other.segment_size() + i_self_ref, i_other_ref = ( + self.ref_segment_track_index, + other.ref_segment_track_index, + ) + var_id = "segment" + else: + counts = self.network_size(), other.network_size() + i_self_ref, i_other_ref = self.ref_index, other.ref_index + var_id = "track" + # object to contain index of couple + in_self, in_other = list(), list() + # We iterate on item of same size + for i_self, i_other, i0, _ in self.align_on(other, counts, all_ref=True): + if i0 < size_min: + continue + if isinstance(i_other, slice): + i_other = arange(i_other.start, i_other.stop) + # All_ref will give all item of self, sometime there is no things to compare with other + if i_other.size == 0: + id_self = i_self + i_self_ref + in_self.append(id_self) + in_other.append(-ones(id_self.shape, dtype=id_self.dtype)) + continue + if isinstance(i_self, slice): + i_self = arange(i_self.start, i_self.stop) + # We get absolute id + id_self, id_other = i_self + i_self_ref, i_other + i_other_ref + # We compute mask to select data + m_self, m_other = self.networks_mask(id_self, segment), other.networks_mask( + id_other, segment + ) + + # We extract obs + obs_self, obs_other = self.obs[m_self], other.obs[m_other] + x1, y1, t1 = obs_self["lon"], obs_self["lat"], obs_self["time"] + x2, y2, t2 = obs_other["lon"], obs_other["lat"], 
obs_other["time"] + + if segment: + ids1 = build_unique_array(obs_self["segment"], obs_self["track"]) + ids2 = build_unique_array(obs_other["segment"], obs_other["track"]) + label1 = self.segment_track_array[m_self] + label2 = other.segment_track_array[m_other] + else: + label1, label2 = ids1, ids2 = obs_self[var_id], obs_other[var_id] + # For each item we get index to sort + i01, indexs1, id1 = list(), List(), list() + for sl_self, id_, _ in self.iter_on(ids1): + i01.append(sl_self.start) + indexs1.append(obs_self[sl_self].argsort(order=["time", "lon", "lat"])) + id1.append(label1[sl_self.start]) + i02, indexs2, id2 = list(), List(), list() + for sl_other, _, _ in other.iter_on(ids2): + i02.append(sl_other.start) + indexs2.append( + obs_other[sl_other].argsort(order=["time", "lon", "lat"]) + ) + id2.append(label2[sl_other.start]) + + id1, id2 = array(id1), array(id2) + # We search item from self in item of others + i_local_target = same_position( + x1, y1, t1, x2, y2, t2, array(i01), array(i02), indexs1, indexs2 + ) + + # -1 => no item found in other dataset + m = i_local_target != -1 + in_self.append(id1) + track2_ = -ones(id1.shape, dtype="i4") + track2_[m] = id2[i_local_target[m]] + in_other.append(track2_) + + return concatenate(in_self), concatenate(in_other) + + @classmethod + def __tag_segment(cls, seg, tag, groups, connexions): + """ + Will set same temporary ID for each connected segment. + + :param int seg: current ID of segment + :param ing tag: temporary ID to set for segment and its connexion + :param array[int] groups: array where tag is stored + :param dict connexions: gives for one ID of segment all connected segments + """ + # If segments are already used we stop recursivity + if groups[seg] != 0: + return + # We set tag for this segment + groups[seg] = tag + # Get all connexions of this segment + segs = connexions.get(seg, None) + if segs is not None: + for seg in segs: + # For each connexion we apply same function + cls.__tag_segment(seg, tag, groups, connexions) + + def tag_segment(self): + """For each segment, method give a new network id, and all segment are connected + + :return array: for each unique seg id, it return new network id + """ + nb = self.segment_track_array[-1] + 1 + sub_group = zeros(nb, dtype="u4") + c = self.connexions(multi_network=True) + j = 1 + # for each available id + for i in range(nb): + # No connexions, no need to explore + if i not in c: + sub_group[i] = j + j += 1 + continue + # Skip if already set + if sub_group[i] != 0: + continue + # we tag an unset segments and explore all connexions + self.__tag_segment(i, j, sub_group, c) + j += 1 + return sub_group + + def fully_connected(self): + """Suspicious""" + raise Exception("Must be check") + self.only_one_network() + return self.tag_segment().shape[0] == 1 + + def first_is_trash(self): + """Check if first network is Trash + + :return bool: True if first network is trash + """ + i_start, i_stop, _ = self.index_segment_track + sl = slice(i_start[0], i_stop[0]) + return (self.previous_obs[sl] == -1).all() and (self.next_obs[sl] == -1).all() + + def remove_trash(self): + """ + Remove the lonely eddies (only 1 obs in segment, associated network number is 0) + """ + if self.first_is_trash(): + return self.extract_with_mask(self.track != 0) + else: + return self + + def plot(self, ax, ref=None, color_cycle=None, **kwargs): + """ + This function draws the path of each trajectory + + :param matplotlib.axes.Axes ax: ax to draw + :param float,int ref: if defined, all coordinates are wrapped with ref 
as western boundary + :param dict kwargs: keyword arguments for Axes.plot + :return: a list of matplotlib mappables + """ + kwargs = kwargs.copy() + if color_cycle is None: + color_cycle = self.COLORS + nb_colors = len(color_cycle) + mappables = list() + if "label" in kwargs: + kwargs["label"] = self.format_label(kwargs["label"]) + j = 0 + for i, _, _ in self.iter_on(self.segment_track_array): + nb = i.stop - i.start + if nb == 0: + continue + if nb_colors: + kwargs["color"] = color_cycle[j % nb_colors] + x, y = self.lon[i], self.lat[i] + if ref is not None: + x, y = wrap_longitude(x, y, ref, cut=True) + mappables.append(ax.plot(x, y, **kwargs)[0]) + j += 1 + return mappables + + def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None, return_mask=False): + """ + Remove short segments that don't connect several segments + + :param int nobs: Minimal number of observation to keep a segment + :param int ndays: Minimal number of days to keep a segment + :param int recursive: Run method N times more + :param int mask: if one or more observation of the segment are selected by mask, the segment is kept + + .. warning:: + It will remove short segment that splits from then merges with the same segment + """ + connexions = self.connexions(multi_network=True) + i0, i1, _ = self.index_segment_track + dt = self.time[i1 - 1] - self.time[i0] + 1 + nb = i1 - i0 + m = (dt >= ndays) * (nb >= nobs) + nb_connexions = array([len(connexions.get(i, tuple())) for i in where(~m)[0]]) + m[~m] = nb_connexions >= 2 + segments_keep = where(m)[0] + if mask is not None: + segments_keep = unique( + concatenate((segments_keep, self.segment_track_array[mask])) + ) + # get mask for selected obs + m = ~self.segment_mask(segments_keep) + if return_mask: + return ~m + self.track[m] = 0 + self.segment[m] = 0 + self.previous_obs[m] = -1 + self.previous_cost[m] = 0 + self.next_obs[m] = -1 + self.next_cost[m] = 0 + + m_previous = m[self.previous_obs] + self.previous_obs[m_previous] = -1 + self.previous_cost[m_previous] = 0 + m_next = m[self.next_obs] + self.next_obs[m_next] = -1 + self.next_cost[m_next] = 0 + + self.sort() + if recursive > 0: + self.remove_dead_end(nobs, ndays, recursive - 1) + + + + def extract_segment(self, segments, absolute=False): + """Extract given segments + + :param array,tuple,list segments: list of segment to extract + :param bool absolute: keep for compatibility, defaults to False + :return NetworkObservations: Return observations from selected segment + """ + if not absolute: + raise Exception("Not implemented") + return self.extract_with_mask(self.segment_mask(segments)) + + def segment_mask(self, segments): + """Get mask from list of segment + + :param list,array segments: absolute id of segment + """ + return generate_mask_from_ids( + array(segments), len(self), *self.index_segment_track + ) + + def get_mask_with_period(self, period): + """ + obtain mask within a time period + + :param (int,int) period: two dates to define the period, must be specified from 1/1/1950 + :return: mask where period is defined + :rtype: np.array(bool) + + """ + dataset_period = self.period + p_min, p_max = period + if p_min > 0: + mask = self.time >= p_min + elif p_min < 0: + mask = self.time >= (dataset_period[0] - p_min) + else: + mask = ones(self.time.shape, dtype=bool_) + if p_max > 0: + mask *= self.time <= p_max + elif p_max < 0: + mask *= self.time <= (dataset_period[1] + p_max) + return mask + + def extract_with_period(self, period): + """ + Extract within a time period + + :param (int,int) period: two 
dates to define the period, must be specified from 1/1/1950 + :return: Return all eddy trajectories in period + :rtype: NetworkObservations + + .. minigallery:: py_eddy_tracker.NetworkObservations.extract_with_period + """ + + return self.extract_with_mask(self.get_mask_with_period(period)) + + def extract_light_with_mask(self, mask, track_extra_variables=[]): + """extract data with mask, but only with variables used for coherence, aka self.array_variables + + :param mask: mask used to extract + :type mask: np.array(bool) + :return: new EddiesObservation with data wanted + :rtype: self + """ + + if isinstance(mask, slice): + nb_obs = mask.stop - mask.start + else: + nb_obs = mask.sum() + + # only time & contour_lon/lat_e/s + variables = ["time"] + self.array_variables + new = self.__class__( + size=nb_obs, + track_extra_variables=track_extra_variables, + track_array_variables=self.track_array_variables, + array_variables=self.array_variables, + only_variables=variables, + raw_data=self.raw_data, + ) + new.sign_type = self.sign_type + if nb_obs == 0: + logger.info("Empty dataset will be created") + else: + logger.info( + f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})" + ) + + for field in variables + track_extra_variables: + logger.debug("Copy of field %s ...", field) + new.obs[field] = self.obs[field][mask] + + if ( + "previous_obs" in track_extra_variables + and "next_obs" in track_extra_variables + ): + # n & p must be re-index + n, p = self.next_obs[mask], self.previous_obs[mask] + # we add 2 for -1 index return index -1 + translate = -ones(len(self) + 1, dtype="i4") + translate[:-1][mask] = arange(nb_obs) + new.next_obs[:] = translate[n] + new.previous_obs[:] = translate[p] + + return new + + def extract_with_mask(self, mask): + """ + Extract a subset of observations. + + :param array(bool) mask: mask to select observations + :return: same object with selected observations + :rtype: self + """ + if isinstance(mask, slice): + nb_obs = mask.stop - mask.start + else: + nb_obs = mask.sum() + new = self.__class__.new_like(self, nb_obs) + new.sign_type = self.sign_type + if nb_obs == 0: + logger.info("Empty dataset will be created") + else: + logger.debug( + f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})" + ) + for field in self.fields: + if field in ("next_obs", "previous_obs"): + continue + logger.debug("Copy of field %s ...", field) + new.obs[field] = self.obs[field][mask] + # n & p must be re-index + n, p = self.next_obs[mask], self.previous_obs[mask] + # we add 2 for -1 index return index -1 + translate = -ones(len(self) + 1, dtype="i4") + translate[:-1][mask] = arange(nb_obs) + new.next_obs[:] = translate[n] + new.previous_obs[:] = translate[p] + return new + + def analysis_coherence( + self, + date_function, + uv_params, + advection_mode="both", + n_days=14, + step_mesh=1.0 / 50, + output_name=None, + dissociate_network=False, + correct_close_events=0, + remove_dead_end=0, + ): + """Global function to analyse segments coherence, with network preprocessing. 
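+        Optional cleaning steps (``dissociate_network``, ``correct_close_events``,
+        ``remove_dead_end``) are applied first, then forward and/or backward coherence
+        is computed with ``segment_coherence_forward`` / ``segment_coherence_backward``.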
+ :param callable date_function: python function, takes as param `int` (julian day) and return + data filename associated to the date + :param dict uv_params: dict of parameters used by + :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` + :param int n_days: nuber of days for advection + :param float step_mesh: step for particule mesh in degrees + :param str output_name: path/name for the output (without extension) to store the clean + network in .nc and the coherence results in .zarr. Works only for advection_mode = "both" + :param bool dissociate_network: If True apply + :py:meth:`~py_eddy_tracker.observation.network.NetworkObservations.dissociate_network` + :param int correct_close_events: Number of days in + :py:meth:`~py_eddy_tracker.observation.network.NetworkObservations.correct_close_events` + :param int remove_dead_end: Number of days in + :py:meth:`~py_eddy_tracker.observation.network.NetworkObservations.remove_dead_end` + :return target_forward, target_bakward: 2D numpy.array with the eddy observation the + particles ended in after advection + :return target_forward, target_bakward: percentage of ending particles within the + eddy observation with regards to the starting number + """ + + if dissociate_network: + self.dissociate_network() + + if correct_close_events > 0: + self.correct_close_events(nb_days_max=correct_close_events) + + if remove_dead_end > 0: + network_clean = self.remove_dead_end(nobs=0, ndays=remove_dead_end) + else: + network_clean = self + + network_clean.numbering_segment() + + res = [] + if (advection_mode == "both") | (advection_mode == "forward"): + target_forward, pct_forward = network_clean.segment_coherence_forward( + date_function=date_function, + uv_params=uv_params, + n_days=n_days, + step_mesh=step_mesh, + ) + res = res + [target_forward, pct_forward] + + if (advection_mode == "both") | (advection_mode == "backward"): + target_backward, pct_backward = network_clean.segment_coherence_backward( + date_function=date_function, + uv_params=uv_params, + n_days=n_days, + step_mesh=step_mesh, + ) + res = res + [target_backward, pct_backward] + + if (output_name is not None) & (advection_mode == "both"): + # TODO : put some path verification? + # Save the clean network in netcdf + with netCDF4.Dataset(output_name + ".nc", "w") as fh: + network_clean.to_netcdf(fh) + # Save the results of particles advection in zarr + # zarr compression parameters + # TODO : check size? compression? + params_seg = dict() + params_pct = dict() + zg = zarr.open(output_name + ".zarr", mode="w") + zg.array("target_forward", target_forward, **params_seg) + zg.array("pct_forward", pct_forward, **params_pct) + zg.array("target_backward", target_backward, **params_seg) + zg.array("pct_backward", pct_backward, **params_pct) + + return network_clean, res + + def segment_coherence_backward( + self, + date_function, + uv_params, + n_days=14, + step_mesh=1.0 / 50, + contour_start="speed", + contour_end="speed", + ): + """ + Percentage of particules and their targets after backward advection from a specific eddy. + + :param callable date_function: python function, takes as param `int` (julian day) and return + data filename associated to the date (see note) + :param dict uv_params: dict of parameters used by + :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` + :param int n_days: days for advection + :param float step_mesh: step for particule mesh in degrees + :return: observations matchs, and percents + + .. 
note:: the param `date_function` should be something like : + + .. code-block:: python + + def date2file(julian_day): + date = datetime.timedelta(days=julian_day) + datetime.datetime( + 1950, 1, 1 + ) + + return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc" + """ + shape = len(self), 2 + itb_final = -ones(shape, dtype="i4") + ptb_final = zeros(shape, dtype="i1") + + t_start, t_end = int(self.period[0]), int(self.period[1]) + + # dates = arange(t_start, t_start + n_days + 1) + dates = arange(t_start, min(t_start + n_days + 1, t_end + 1)) + first_files = [date_function(x) for x in dates] + + c = GridCollection.from_netcdf_list(first_files, dates, **uv_params) + first = True + range_start = t_start + n_days + range_end = t_end + 1 + + for _t in range(t_start + n_days, t_end + 1): + _timestamp = time.time() + t_shift = _t + + # skip first shift, because already included + if first: + first = False + else: + # add next date to GridCollection and delete last date + c.shift_files(t_shift, date_function(int(t_shift)), **uv_params) + particle_candidate( + c, + self, + step_mesh, + _t, + itb_final, + ptb_final, + n_days=-n_days, + contour_start=contour_start, + contour_end=contour_end, + ) + logger.info( + ( + f"coherence {_t} / {range_end - 1} ({(_t - range_start) / (range_end - range_start - 1):.1%})" + f" : {time.time() - _timestamp:5.2f}s" + ) + ) + + return itb_final, ptb_final + + def segment_coherence_forward( + self, + date_function, + uv_params, + n_days=14, + step_mesh=1.0 / 50, + contour_start="speed", + contour_end="speed", + **kwargs, + ): + """ + Percentage of particules and their targets after forward advection from a specific eddy. + + :param callable date_function: python function, takes as param `int` (julian day) and return + data filename associated to the date (see note) + :param dict uv_params: dict of parameters used by + :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` + :param int n_days: days for advection + :param float step_mesh: step for particule mesh in degrees + :return: observations matchs, and percents + + .. note:: the param `date_function` should be something like : + + .. 
code-block:: python + + def date2file(julian_day): + date = datetime.timedelta(days=julian_day) + datetime.datetime( + 1950, 1, 1 + ) + + return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc" + """ + shape = len(self), 2 + itf_final = -ones(shape, dtype="i4") + ptf_final = zeros(shape, dtype="i1") + + t_start, t_end = int(self.period[0]), int(self.period[1]) + + dates = arange(t_start, min(t_start + n_days + 1, t_end + 1)) + first_files = [date_function(x) for x in dates] + + c = GridCollection.from_netcdf_list(first_files, dates, **uv_params) + first = True + range_start = t_start + range_end = t_end - n_days + 1 + + for _t in range(range_start, range_end): + _timestamp = time.time() + t_shift = _t + n_days + + # skip first shift, because already included + if first: + first = False + else: + # add next date to GridCollection and delete last date + c.shift_files(t_shift, date_function(int(t_shift)), **uv_params) + particle_candidate( + c, + self, + step_mesh, + _t, + itf_final, + ptf_final, + n_days=n_days, + contour_start=contour_start, + contour_end=contour_end, + **kwargs, + ) + logger.info( + ( + f"coherence {_t} / {range_end - 1} ({(_t - range_start) / (range_end - range_start - 1):.1%})" + f" : {time.time() - _timestamp:5.2f}s" + ) + ) + return itf_final, ptf_final + + def mask_obs_close_event(self, merging=True, spliting=True, dt=3): + """Build a mask of close observation from event + + :param n: Network + :param bool merging: select merging event, defaults to True + :param bool spliting: select splitting event, defaults to True + :param int dt: delta of time max , defaults to 3 + :return array: mask + """ + m = zeros(len(self), dtype="bool") + if merging: + i_target, ip1, ip2 = self.merging_event(triplet=True, only_index=True) + mask_follow_obs(m, self.previous_obs, self.time, ip1, dt) + mask_follow_obs(m, self.previous_obs, self.time, ip2, dt) + mask_follow_obs(m, self.next_obs, self.time, i_target, dt) + if spliting: + i_target, in1, in2 = self.splitting_event(triplet=True, only_index=True) + mask_follow_obs(m, self.next_obs, self.time, in1, dt) + mask_follow_obs(m, self.next_obs, self.time, in2, dt) + mask_follow_obs(m, self.previous_obs, self.time, i_target, dt) + return m + + def swap_track( + self, + length_main_max_after_event=2, + length_secondary_min_after_event=10, + delta_pct_max=-0.2, + ): + events = self.splitting_event(triplet=True, only_index=True) + count = 0 + for i_main, i1, i2 in zip(*events): + seg_main, _, seg2 = ( + self.segment_track_array[i_main], + self.segment_track_array[i1], + self.segment_track_array[i2], + ) + i_start, i_end, i0 = self.index_segment_track + # For splitting + last_index_main = i_end[seg_main - i0] - 1 + last_index_secondary = i_end[seg2 - i0] - 1 + last_main_next_obs = self.next_obs[last_index_main] + t_event, t_main_end, t_secondary_start, t_secondary_end = ( + self.time[i_main], + self.time[last_index_main], + self.time[i2], + self.time[last_index_secondary], + ) + dt_main, dt_secondary = ( + t_main_end - t_event, + t_secondary_end - t_secondary_start, + ) + delta_cost = self.previous_cost[i2] - self.previous_cost[i1] + if ( + dt_main <= length_main_max_after_event + and dt_secondary >= length_secondary_min_after_event + and last_main_next_obs == -1 + and delta_cost > delta_pct_max + ): + self.segment[i1 : last_index_main + 1] = self.segment[i2] + self.segment[i2 : last_index_secondary + 1] = self.segment[i_main] + count += 1 + logger.info("%d segmnent swap on %d", count, len(events[0])) + return self.sort() + + +class Network: + 
__slots__ = ( + "window", + "filenames", + "nb_input", + "buffer", + "memory", + ) + + NOGROUP = TrackEddiesObservations.NOGROUP + + def __init__(self, input_regex, window=5, intern=False, memory=False): + """ + Class to group observations by network + """ + self.window = window + self.buffer = Buffer(window, intern, memory) + self.memory = memory + + self.filenames = glob(input_regex) + self.filenames.sort() + self.nb_input = len(self.filenames) + + @classmethod + def from_eddiesobservations(cls, observations, *args, **kwargs): + new = cls("", *args, **kwargs) + new.filenames = observations + new.nb_input = len(new.filenames) + return new + def get_group_array(self, results, nb_obs): """With a loop on all pair of index, we will label each obs with a group number """ - nb_obs = array(nb_obs) + nb_obs = array(nb_obs, dtype="u4") day_start = nb_obs.cumsum() - nb_obs gr = empty(nb_obs.sum(), dtype="u4") gr[:] = self.NOGROUP + merge_id = list() id_free = 1 for i, j, ii, ij in results: gr_i = gr[slice(day_start[i], day_start[i] + nb_obs[i])] @@ -89,41 +2026,87 @@ def get_group_array(self, results, nb_obs): if m.any(): # Merge of group, ref over etu for i_, j_ in zip(ii[m], ij[m]): - gr_i_, gr_j_ = gr_i[i_], gr_j[j_] - gr[gr == gr_i_] = gr_j_ - return gr + g0, g1 = gr_i[i_], gr_j[j_] + if g0 > g1: + g0, g1 = g1, g0 + merge_id.append((g0, g1)) + gr_transfer = self.group_translator(id_free, set(merge_id)) + return gr_transfer[gr] + + @staticmethod + def group_translator(nb, duos): + """ + Create a translator with all duos + + :param int nb: size of translator + :param set((int, int)) duos: set of all groups that must be joined + + :Example: + + >>> NetworkObservations.group_translator(5, ((0, 1), (0, 2), (1, 3))) + [3, 3, 3, 3, 5] + """ + translate = arange(nb, dtype="u4") + for i, j in sorted(duos): + gr_i, gr_j = translate[i], translate[j] + if gr_i != gr_j: + apply_replace(translate, gr_i, gr_j) + return translate + + def group_observations(self, min_overlap=0.2, minimal_area=False, **kwargs): + """Store every interaction between identifications + + :param bool minimal_area: If True, function will compute intersection/little polygon, else intersection/union, by default False + :param float min_overlap: minimum overlap area to associate observations, by default 0.2 + + :return: + :rtype: TrackEddiesObservations + """ - def group_observations(self, **kwargs): results, nb_obs = list(), list() # To display print only in INFO display_iteration = logger.getEffectiveLevel() == logging.INFO + for i, filename in enumerate(self.filenames): if display_iteration: print(f"{filename} compared to {self.window} next", end="\r") - # Load observations with function to buffered observations - xi, yi = self.load_contour(filename) + # Load observations with function to buffer observations + xi, yi = self.buffer.load_contour(filename) # Append number of observations by filename nb_obs.append(xi.shape[0]) for j in range(i + 1, min(self.window + i + 1, self.nb_input)): - xj, yj = self.load_contour(self.filenames[j]) + xj, yj = self.buffer.load_contour(self.filenames[j]) ii, ij = bbox_intersection(xi, yi, xj, yj) - m = vertice_overlap(xi[ii], yi[ii], xj[ij], yj[ij], **kwargs) > 0.2 + m = ( + vertice_overlap( + xi[ii], + yi[ii], + xj[ij], + yj[ij], + minimal_area=minimal_area, + min_overlap=min_overlap, + **kwargs, + ) + != 0 + ) results.append((i, j, ii[m], ij[m])) if display_iteration: print() gr = self.get_group_array(results, nb_obs) + nb_alone, nb_obs, nb_gr = (gr == self.NOGROUP).sum(), len(gr), len(unique(gr)) 
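A minimal usage sketch of the grouping workflow this class implements, under stated assumptions: the file pattern is invented and the import path is assumed to be `py_eddy_tracker.observations.network`. `group_observations` yields one group id per observation and `build_dataset` reorders the observations by group.

```python
# Hedged sketch only: pattern and window length are assumptions.
from py_eddy_tracker.observations.network import Network

n = Network("/tmp/eddies_*.nc", window=5)      # hypothetical daily identification files
gr = n.group_observations(min_overlap=0.2)     # one group id per observation
eddies = n.build_dataset(gr, raw_data=True)    # observations reordered by group
```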
logger.info( - f"{(gr == self.NOGROUP).sum()} alone / {len(gr)} obs, {len(unique(gr))} groups" + f"{nb_alone} alone / {nb_obs} obs, {nb_gr} groups, " + f"{nb_alone * 100. / nb_obs:.2f} % alone, {(nb_obs - nb_alone) / (nb_gr - 1):.1f} obs/group" ) return gr - def build_dataset(self, group): + def build_dataset(self, group, raw_data=True): nb_obs = group.shape[0] - model = EddiesObservations.load_file(self.filenames[-1], raw_data=True) + model = TrackEddiesObservations.load_file(self.filenames[-1], raw_data=raw_data) eddies = TrackEddiesObservations.new_like(model, nb_obs) eddies.sign_type = model.sign_type - # Get new index to re-order observation by group + # Get new index to re-order observations by groups new_i = get_next_index(group) display_iteration = logger.getEffectiveLevel() == logging.INFO elements = eddies.elements @@ -132,7 +2115,12 @@ def build_dataset(self, group): for filename in self.filenames: if display_iteration: print(f"Load {filename} to copy", end="\r") - e = EddiesObservations.load_file(filename, raw_data=True) + if self.memory: + # Only if netcdf + with open(filename, "rb") as h: + e = TrackEddiesObservations.load_file(h, raw_data=raw_data) + else: + e = TrackEddiesObservations.load_file(filename, raw_data=raw_data) stop = i + len(e) sl = slice(i, stop) for element in elements: @@ -140,15 +2128,101 @@ def build_dataset(self, group): i = stop if display_iteration: print() - eddies = eddies.add_fields(('track',)) - eddies.obs['track'][new_i] = group + eddies.track[new_i] = group return eddies @njit(cache=True) -def get_next_index(gr): - """Return for each obs index the new position to join all group +def get_percentile_on_following_obs( + i, indexs, percents, follow_obs, t, segment, i_target, window, q=50, nb_min=1 +): + """Get stat on a part of segment close of an event + + :param int i: index to follow + :param array indexs: indexs from coherence + :param array percents: percent from coherence + :param array[int] follow_obs: give index for the following observation + :param array t: time for each observation + :param array segment: segment for each observation + :param int i_target: index of target + :param int window: time window of search + :param int q: Percentile from 0 to 100, defaults to 50 + :param int nb_min: Number minimal of observation to provide statistics, defaults to 1 + :return float : return statistic """ + last_t, segment_follow = t[i], segment[i] + segment_target = segment[i_target] + percent_target = empty(window, dtype=percents.dtype) + j = 0 + while abs(last_t - t[i]) < window and i != -1 and segment_follow == segment[i]: + # Iter on primary & secondary + for index, percent in zip(indexs[i], percents[i]): + if index != -1 and segment[index] == segment_target: + percent_target[j] = percent + j += 1 + i = follow_obs[i] + if j < nb_min: + return nan + return percentile(percent_target[:j], q) + + +@njit(cache=True) +def get_percentile_around_event( + i, + i1, + i2, + ind, + pct, + follow_obs, + t, + segment, + window=10, + follow_parent=False, + q=50, + nb_min=1, +): + """Get stat around event + + :param array[int] i: Indexs of target + :param array[int] i1: Indexs of primary origin + :param array[int] i2: Indexs of secondary origin + :param array ind: indexs from coherence + :param array pct: percent from coherence + :param array[int] follow_obs: give index for the following observation + :param array t: time for each observation + :param array segment: segment for each observation + :param int window: time window of search, defaults to 10 + :param 
bool follow_parent: Follow parent instead of child, defaults to False + :param int q: Percentile from 0 to 100, defaults to 50 + :param int nb_min: Number minimal of observation to provide statistics, defaults to 1 + :return (array,array) : statistic for each event + """ + stat1 = empty(i.size, dtype=nb_types.float32) + stat2 = empty(i.size, dtype=nb_types.float32) + # iter on event + for j, (i_, i1_, i2_) in enumerate(zip(i, i1, i2)): + if follow_parent: + # We follow parent + stat1[j] = get_percentile_on_following_obs( + i_, ind, pct, follow_obs, t, segment, i1_, window, q, nb_min + ) + stat2[j] = get_percentile_on_following_obs( + i_, ind, pct, follow_obs, t, segment, i2_, window, q, nb_min + ) + else: + # We follow child + stat1[j] = get_percentile_on_following_obs( + i1_, ind, pct, follow_obs, t, segment, i_, window, q, nb_min + ) + stat2[j] = get_percentile_on_following_obs( + i2_, ind, pct, follow_obs, t, segment, i_, window, q, nb_min + ) + return stat1, stat2 + + +@njit(cache=True) +def get_next_index(gr): + """Return for each obs index the new position to join all groups""" nb_obs_gr = bincount(gr) i_gr = nb_obs_gr.cumsum() - nb_obs_gr new_index = empty(gr.shape, dtype=uint32) @@ -156,3 +2230,139 @@ def get_next_index(gr): new_index[i] = i_gr[g] i_gr[g] += 1 return new_index + + +@njit(cache=True) +def apply_replace(x, x0, x1): + nb = x.shape[0] + for i in range(nb): + if x[i] == x0: + x[i] = x1 + + +@njit(cache=True) +def build_unique_array(id1, id2): + """Give a unique id for each (id1, id2) with id1 and id2 increasing monotonically""" + k = 0 + new_id = empty(id1.shape, dtype=id1.dtype) + id1_previous = id1[0] + id2_previous = id2[0] + for i in range(id1.shape[0]): + id1_, id2_ = id1[i], id2[i] + if id1_ != id1_previous or id2_ != id2_previous: + k += 1 + new_id[i] = k + id1_previous, id2_previous = id1_, id2_ + return new_id + + +@njit(cache=True) +def new_numbering(segs, start=0): + nb = len(segs) + s0 = segs[0] + j = start + for i in range(nb): + if segs[i] != s0: + s0 = segs[i] + j += 1 + segs[i] = j + + +@njit(cache=True) +def ptp(values): + return values.max() - values.min() + + +@njit(cache=True) +def generate_mask_from_ids(id_networks, nb, istart, iend, i0): + """From list of id, we generate a mask + + :param array id_networks: list of ids + :param int nb: size of mask + :param array istart: first index for each id from :py:meth:`~py_eddy_tracker.generic.build_index` + :param array iend: last index for each id from :py:meth:`~py_eddy_tracker.generic.build_index` + :param int i0: ref index from :py:meth:`~py_eddy_tracker.generic.build_index` + :return array: return a mask + """ + m = zeros(nb, dtype="bool") + for i in id_networks: + for j in range(istart[i - i0], iend[i - i0]): + m[j] = True + return m + + +@njit(cache=True) +def same_position(x0, y0, t0, x1, y1, t1, i00, i01, i0, i1): + """Return index of track/segment found in other dataset + + :param array x0: + :param array y0: + :param array t0: + :param array x1: + :param array y1: + :param array t1: + :param array i00: First index of track/segment/network in dataset0 + :param array i01: First index of track/segment/network in dataset1 + :param List(array) i0: list of array which contain index to order dataset0 + :param List(array) i1: list of array which contain index to order dataset1 + :return array: index of dataset1 which match with dataset0, -1 => no match + """ + nb0, nb1 = i00.size, i01.size + i_target = -ones(nb0, dtype="i4") + # To avoid to compare multiple time, if already match + used1 = zeros(nb1, 
dtype="bool") + for j0 in range(nb0): + for j1 in range(nb1): + if used1[j1]: + continue + test = True + for i0_, i1_ in zip(i0[j0], i1[j1]): + i0_ += i00[j0] + i1_ += i01[j1] + if t0[i0_] != t1[i1_] or x0[i0_] != x1[i1_] or y0[i0_] != y1[i1_]: + test = False + break + if test: + i_target[j0] = j1 + used1[j1] = True + break + return i_target + + +@njit(cache=True) +def mask_follow_obs(m, next_obs, time, indexs, dt=3): + """Generate a mask to select close obs in time from index + + :param array m: mask to fill with True + :param array next_obs: index of the next observation + :param array time: time of each obs + :param array indexs: index to start follow + :param int dt: delta of time max from index, defaults to 3 + """ + for i in indexs: + t0 = time[i] + m[i] = True + i_next = next_obs[i] + dt_ = abs(time[i_next] - t0) + while dt_ < dt and i_next != -1: + m[i_next] = True + i_next = next_obs[i_next] + dt_ = abs(time[i_next] - t0) + + +@njit(cache=True) +def get_period_with_index(t, i0, i1): + """Return peek to peek cover by each slice define by i0 and i1 + + :param array t: array which contain values to estimate spread + :param array i0: index which determine start of slice + :param array i1: index which determine end of slice + :return array: Peek to peek of t + """ + periods = np.empty(i0.size, t.dtype) + for i in range(i0.size): + if i1[i] == i0[i]: + periods[i] = 0 + continue + periods[i] = t[i0[i] : i1[i]].ptp() + return periods diff --git a/src/py_eddy_tracker/observations/observation.py b/src/py_eddy_tracker/observations/observation.py index f764ca87..b39f7f83 100644 --- a/src/py_eddy_tracker/observations/observation.py +++ b/src/py_eddy_tracker/observations/observation.py @@ -1,83 +1,111 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . 
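A small hedged illustration of `get_period_with_index`, defined at the end of network.py above, with invented values: each `[i0, i1)` slice of `t` is reduced to its peak-to-peak spread.

```python
import numpy as np

t = np.array([10.0, 12.0, 11.0, 20.0, 25.0])
i0 = np.array([0, 3])
i1 = np.array([3, 5])
# get_period_with_index(t, i0, i1) -> array([2., 5.])
# slice 0:3 spans 10..12 (spread 2), slice 3:5 spans 20..25 (spread 5)
```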
- -Copyright (c) 2014-2017 by Evan Mason and Antoine Delepoulle -Email: emason@imedea.uib-csic.es -=========================================================================== - -observation.py - -Version 3.0.0 - -=========================================================================== - +Base class to manage eddy observation """ +from datetime import datetime +from io import BufferedReader, BytesIO import logging -import zarr +from tarfile import ExFileObject +from tokenize import TokenError + +from Polygon import Polygon +from matplotlib.cm import get_cmap +from matplotlib.collections import LineCollection, PolyCollection +from matplotlib.colors import Normalize +from netCDF4 import Dataset +from numba import njit, types as numba_types from numpy import ( - zeros, - where, - unique, - ma, - cos, - radians, - isnan, - ones, - ndarray, - floor, - array, - empty, absolute, - concatenate, - float64, - ceil, arange, + array, + array_equal, + ceil, + concatenate, + cos, + datetime64, + digitize, + empty, + errstate, + floor, + histogram, histogram2d, + in1d, + isnan, linspace, + ma, + nan, + ndarray, + ones, + percentile, + radians, sin, + unique, + where, + zeros, ) -from netCDF4 import Dataset -from datetime import datetime -from numba import njit, types as numba_types -from Polygon import Polygon +import packaging.version from pint import UnitRegistry from pint.errors import UndefinedUnitError -from tokenize import TokenError -from tarfile import ExFileObject -from matplotlib.path import Path as BasePath -from .. import VAR_DESCR, VAR_DESCR_inv +import zarr + +from .. import VAR_DESCR, VAR_DESCR_inv, __version__ from ..generic import ( - distance_grid, + bbox_indice_regular, + build_index, distance, + distance_grid, flatten_line_matrix, - wrap_longitude, + hist_numba, local_to_coordinates, reverse_index, + window_index, + wrap_longitude, +) +from ..poly import ( + bbox_intersection, + close_center, + convexs, + create_meshed_particles, + create_vertice, + get_pixel_in_regular, + insidepoly, + poly_indexs, + reduce_size, + vertice_overlap, ) -from ..poly import bbox_intersection, vertice_overlap, create_vertice logger = logging.getLogger("pet") +# keep only major and minor version number +_software_version_reduced = packaging.version.Version( + "{v.major}.{v.minor}".format(v=packaging.version.parse(__version__)) +) +_display_check_warning = True + + +def _check_versions(version): + """Check if version of py_eddy_tracker used to create the file is compatible with software version + + if not, warn user with both versions + + :param version: string version of software used to create the file. 
If None, version was not provided + :type version: str, None + """ + if not _display_check_warning: + return + file_version = packaging.version.parse(version) if version is not None else None + if file_version is None or file_version < _software_version_reduced: + logger.warning( + "File was created with py-eddy-tracker version '%s' but software version is '%s'", + file_version, + _software_version_reduced, + ) + @njit(cache=True, fastmath=True) def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5): """ - work only if major is an array but faster * 6 + Work only if major is an array but faster * 6 """ # c = (major ** 2 - minor ** 2) ** .5 + major c = major @@ -87,7 +115,7 @@ def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5 # Focal f_right = lon0 f_left = f_right - (c - minor) - # Ellips center + # Ellipse center x_c = (f_left + f_right) * 0.5 nb_0, nb_1 = lat0.shape[0], lat1.shape[0] @@ -105,16 +133,36 @@ def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5 if dx > major[j]: m[j, i] = False continue - d_normalize = dx ** 2 / major[j] ** 2 + dy ** 2 / minor ** 2 + d_normalize = dx**2 / major[j] ** 2 + dy**2 / minor**2 m[j, i] = d_normalize < 1.0 return m +class Table(object): + def __init__(self, values): + self.values = values + + def _repr_html_(self): + rows = list() + if isinstance(self.values, ndarray): + row = "\n".join([f"{v}" for v in self.values.dtype.names]) + rows.append(f"{row}") + for row in self.values: + row = "\n".join([f"{v}" for v in row]) + rows.append(f"{row}") + rows = "\n".join(rows) + return ( + f'' + f'' + f"{rows}" + f"
" + f"
" + ) + + class EddiesObservations(object): """ - Class to hold eddy properties *amplitude* and counts of - *local maxima/minima* within a closed region of a sea level anomaly field. - + Class to store eddy observations. """ __slots__ = ( @@ -125,6 +173,7 @@ class EddiesObservations(object): "observations", "sign_type", "raw_data", + "period_", ) ELEMENTS = [ @@ -137,6 +186,8 @@ class EddiesObservations(object): "time", "shape_error_e", "shape_error_s", + "speed_area", + "effective_area", "nb_contour_selected", "num_point_e", "num_point_s", @@ -145,6 +196,27 @@ class EddiesObservations(object): "height_inner_contour", ] + COLORS = [ + "sienna", + "red", + "darkorange", + "gold", + "palegreen", + "limegreen", + "forestgreen", + "mediumblue", + "dodgerblue", + "lightskyblue", + "violet", + "blueviolet", + "darkmagenta", + "darkgrey", + "dimgrey", + "steelblue", + ] + + NB_COLORS = len(COLORS) + def __init__( self, size=0, @@ -154,6 +226,7 @@ def __init__( only_variables=None, raw_data=False, ): + self.period_ = None self.only_variables = only_variables self.raw_data = raw_data self.track_extra_variables = ( @@ -167,25 +240,20 @@ def __init__( self.observations = zeros(size, dtype=self.dtype) self.sign_type = None - @property - def longitude(self): - return self.observations["lon"] - - @property - def latitude(self): - return self.observations["lat"] - - @property - def time(self): - return self.observations["time"] - @property def tracks(self): - return self.observations["track"] + return self.track - @property - def observation_number(self): - return self.observations["n"] + def __eq__(self, other): + if self.sign_type != other.sign_type: + return False + if self.dtype != other.dtype: + return False + return array_equal(self.obs, other.obs) + + def get_color(self, i): + """Return colors as a cyclic list""" + return self.COLORS[i % self.NB_COLORS] @property def sign_legend(self): @@ -195,68 +263,236 @@ def sign_legend(self): def shape(self): return self.observations.shape + def get_infos(self): + infos = dict( + bins_lat=(-90, -60, -15, 15, 60, 90), + bins_amplitude=array((0, 1, 2, 3, 4, 5, 10, 500)), + bins_radius=array((0, 15, 30, 45, 60, 75, 100, 200, 2000)), + nb_obs=len(self), + ) + t0, t1 = self.period + infos["t0"], infos["t1"] = t0, t1 + infos["period"] = t1 - t0 + 1 + return infos + + def _repr_html_(self): + infos = self.get_infos() + return f"""{infos['nb_obs']} observations from {infos['t0']} to {infos['t1']} """ + + def parse_varname(self, name): + return self[name] if isinstance(name, str) else name + + def hist(self, varname, x, bins, percent=False, mean=False, nb=False): + """Build histograms. 
+ + :param str,array varname: variable to use to compute stat + :param str,array x: variable to use to know in which bins + :param array bins: + :param bool percent: normalized by sum of all bins + :param bool mean: compute mean by bins + :param bool nb: only count by bins + :return: value by bins + :rtype: array + """ + x = self.parse_varname(x) + if nb: + v = hist_numba(x, bins=bins)[0] + else: + v = histogram(x, bins=bins, weights=self.parse_varname(varname))[0] + if percent: + v = v.astype("f4") / v.sum() * 100 + elif mean: + v /= hist_numba(x, bins=bins)[0] + return v + + @staticmethod + def box_display(value): + """Return values evenly spaced with few numbers""" + return "".join([f"{v_:10.2f}" for v_ in value]) + + @property + def fields(self): + return list(self.obs.dtype.names) + + def field_table(self): + """ + Produce description table of the fields available in this object + """ + rows = [("Name (Unit)", "Long name", "Scale factor", "Offset")] + names = self.fields + names.sort() + for field in names: + infos = VAR_DESCR[field] + rows.append( + ( + f"{infos.get('nc_name', field)} ({infos['nc_attr'].get('units', '')})", + infos["nc_attr"].get("long_name", "").capitalize(), + infos.get("scale_factor", ""), + infos.get("add_offset", ""), + ) + ) + return Table(rows) + def __repr__(self): - return "%d observations" % len(self.observations) + """ + Return general informations on dataset as strings. - def __getitem__(self, attr): + :return: informations on datasets + :rtype: str + """ + t0, t1 = self.period + period = t1 - t0 + 1 + bins_lat = (-90, -60, -15, 15, 60, 90) + bins_amplitude = array((0, 1, 2, 3, 4, 5, 10, 500)) + bins_radius = array((0, 15, 30, 45, 60, 75, 100, 200, 2000)) + nb_obs = len(self) + + return f""" | {nb_obs} observations from {t0} to {t1} ({period} days, ~{nb_obs / period:.0f} obs/day) + | Speed area : {self.speed_area.sum() / period / 1e12:.2f} Mkm²/day + | Effective area : {self.effective_area.sum() / period / 1e12:.2f} Mkm²/day + ----Distribution in Amplitude: + | Amplitude bounds (cm) {self.box_display(bins_amplitude)} + | Percent of eddies : { + self.box_display(self.hist('time', 'amplitude', bins_amplitude / 100., percent=True, nb=True))} + ----Distribution in Radius: + | Speed radius (km) {self.box_display(bins_radius)} + | Percent of eddies : { + self.box_display(self.hist('time', 'radius_s', bins_radius * 1000., percent=True, nb=True))} + | Effective radius (km) {self.box_display(bins_radius)} + | Percent of eddies : { + self.box_display(self.hist('time', 'radius_e', bins_radius * 1000., percent=True, nb=True))} + ----Distribution in Latitude + Latitude bounds {self.box_display(bins_lat)} + Percent of eddies : {self.box_display(self.hist('time', 'lat', bins_lat, percent=True, nb=True))} + Percent of speed area : {self.box_display(self.hist('speed_area', 'lat', bins_lat, percent=True))} + Percent of effective area : {self.box_display(self.hist('effective_area', 'lat', bins_lat, percent=True))} + Mean speed radius (km) : {self.box_display(self.hist('radius_s', 'lat', bins_lat, mean=True) / 1000.)} + Mean effective radius (km): {self.box_display(self.hist('radius_e', 'lat', bins_lat, mean=True) / 1000.)} + Mean amplitude (cm) : {self.box_display(self.hist('amplitude', 'lat', bins_lat, mean=True) * 100.)}""" + + def __dir__(self): + """Provide method name lookup and completion.""" + base = set(dir(type(self))) + intern_name = set(self.elements) + extern_name = set([VAR_DESCR[k]["nc_name"] for k in intern_name]) + # Must be check in init not here + if base 
& intern_name: + logger.warning( + "Some variable name have a common name with class attrs: %s", + base & intern_name, + ) + if base & extern_name: + logger.warning( + "Some variable name have a common name with class attrs: %s", + base & extern_name, + ) + return sorted(base.union(intern_name).union(extern_name)) + + def __getitem__(self, attr: str): if attr in self.elements: - return self.observations[attr] + return self.obs[attr] + elif attr in VAR_DESCR_inv: + return self.obs[VAR_DESCR_inv[attr]] + elif attr in ("lifetime", "age"): + return getattr(self, attr) raise KeyError("%s unknown" % attr) + def __getattr__(self, attr): + if attr in self.elements: + return self.obs[attr] + elif attr in VAR_DESCR_inv: + return self.obs[VAR_DESCR_inv[attr]] + raise AttributeError( + "{!r} object has no attribute {!r}".format(type(self).__name__, attr) + ) + + @classmethod + def needed_variable(cls): + return None + @classmethod def obs_dimension(cls, handler): for candidate in ("obs", "Nobs", "observation", "i"): if candidate in handler.dimensions.keys(): return candidate - def add_fields(self, fields): + def remove_fields(self, *fields): + """ + Copy with fields listed remove + """ + nb_obs = len(self) + fields = set(fields) + only_variables = set(self.fields) - fields + track_extra_variables = set(self.track_extra_variables) - fields + array_variables = set(self.array_variables) - fields + new = self.__class__( + size=nb_obs, + track_extra_variables=track_extra_variables, + track_array_variables=self.track_array_variables, + array_variables=array_variables, + only_variables=only_variables, + raw_data=self.raw_data, + ) + new.sign_type = self.sign_type + for name in new.fields: + logger.debug("Copy of field %s ...", name) + new.obs[name] = self.obs[name] + return new + + def add_fields(self, fields=list(), array_fields=list()): """ - Add a new field + Add a new field. """ - nb_obs = self.obs.shape[0] + nb_obs = len(self) new = self.__class__( size=nb_obs, track_extra_variables=list( concatenate((self.track_extra_variables, fields)) ), track_array_variables=self.track_array_variables, - array_variables=self.array_variables, - only_variables=list(concatenate((self.obs.dtype.names, fields))), + array_variables=list(concatenate((self.array_variables, array_fields))), + only_variables=list(concatenate((self.fields, fields, array_fields))), raw_data=self.raw_data, ) new.sign_type = self.sign_type - for field in self.obs.dtype.descr: - logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var] + for name in self.fields: + logger.debug("Copy of field %s ...", name) + new.obs[name] = self.obs[name] return new def add_rotation_type(self): new = self.add_fields(("type_cyc",)) - new.observations["type_cyc"] = self.sign_type + new.type_cyc[:] = self.sign_type return new - def circle_contour(self): + def circle_contour(self, only_virtual=False, factor=1): + """ + Set contours as circles from radius and center data. + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.circle_contour + """ angle = radians(linspace(0, 360, self.track_array_variables)) x_norm, y_norm = cos(angle), sin(angle) + radius_s = "contour_lon_s" in self.fields + radius_e = "contour_lon_e" in self.fields for i, obs in enumerate(self): - r_s, r_e, x, y = ( - obs["radius_s"], - obs["radius_e"], - obs["lon"], - obs["lat"], - ) - obs["contour_lon_s"], obs["contour_lat_s"] = local_to_coordinates( - x_norm * r_s, y_norm * r_s, x, y - ) - obs["contour_lon_e"], obs["contour_lat_e"] = local_to_coordinates( - x_norm * r_e, y_norm * r_e, x, y - ) + if only_virtual and not obs["virtual"]: + continue + x, y = obs["lon"], obs["lat"] + if radius_s: + r_s = obs["radius_s"] * factor + obs["contour_lon_s"], obs["contour_lat_s"] = local_to_coordinates( + x_norm * r_s, y_norm * r_s, x, y + ) + if radius_e: + r_e = obs["radius_e"] * factor + obs["contour_lon_e"], obs["contour_lat_e"] = local_to_coordinates( + x_norm * r_e, y_norm * r_e, x, y + ) @property def dtype(self): - """Return dtype to build numpy array - """ + """Return dtype to build numpy array.""" dtype = list() for elt in self.elements: data_type = ( @@ -272,8 +508,7 @@ def dtype(self): @property def elements(self): - """Return all variable name - """ + """Return all the names of the variables.""" elements = [i for i in self.ELEMENTS] if self.track_array_variables > 0: elements += self.array_variables @@ -285,8 +520,7 @@ def elements(self): return list(set(elements)) def coherence(self, other): - """Check coherence between two dataset - """ + """Check coherence between two datasets.""" test = self.track_extra_variables == other.track_extra_variables test *= self.track_array_variables == other.track_array_variables test *= self.array_variables == other.array_variables @@ -311,20 +545,19 @@ def concatenate(cls, observations): return eddies def merge(self, other): - """Merge two dataset - """ + """Merge two datasets.""" nb_obs_self = len(self) nb_obs = nb_obs_self + len(other) eddies = self.new_like(self, nb_obs) - other_keys = other.obs.dtype.fields.keys() - self_keys = self.obs.dtype.fields.keys() - for key in eddies.obs.dtype.fields.keys(): + other_keys = other.fields + self_keys = self.fields + for key in eddies.fields: eddies.obs[key][:nb_obs_self] = self.obs[key][:] if key in other_keys: eddies.obs[key][nb_obs_self:] = other.obs[key][:] if "track" in other_keys and "track" in self_keys: - last_track = eddies.obs["track"][nb_obs_self - 1] + 1 - eddies.obs["track"][nb_obs_self:] += last_track + last_track = eddies.track[nb_obs_self - 1] + 1 + eddies.track[nb_obs_self:] += last_track eddies.sign_type = self.sign_type return eddies @@ -333,8 +566,7 @@ def reset(self): @property def obs(self): - """return an array observations - """ + """Return observations.""" return self.observations def __len__(self): @@ -344,13 +576,98 @@ def __iter__(self): for obs in self.obs: yield obs - def insert_observations(self, other, index): - """Insert other obs in self at the index + def iter_on(self, xname, window=None, bins=None): + """ + Yield observation group for each bin. + + :param str,array xname: + :param float,None window: if defined we use a moving window with value like half window + :param array bins: bounds of each bin + :yield array,float,float: index in self, lower bound, upper bound + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.iter_on + """ + x = self.parse_varname(xname) + if window is not None: + x0 = arange(x.min(), x.max()) if bins is None else array(bins) + i_ordered, first_index, last_index = window_index(x, x0, window) + for x_, i0, i1 in zip(x0, first_index, last_index): + yield i_ordered[i0:i1], x_ - window, x_ + window + else: + d = x[1:] - x[:-1] + if bins is None: + bins = arange(x.min(), x.max() + 2) + elif not isinstance(bins, ndarray): + bins = array(bins) + nb_bins = len(bins) - 1 + + # Not monotonous + if (d < 0).any(): + # If bins cover a small part of value + test, translate, x = iter_mode_reduce(x, bins) + # convert value in bins number + i = numba_digitize(x, bins) - 1 + # Order by bins + i_sort = i.argsort() + # If in reduced mode we will translate i_sort in full array index + i_sort_ = translate[i_sort] if test else i_sort + # Bound for each bins in sorting view + i0, i1, _ = build_index(i[i_sort]) + m = ~(i0 == i1) + i0, i1 = i0[m], i1[m] + for i0_, i1_ in zip(i0, i1): + i_bins = i[i_sort[i0_]] + if i_bins == -1 or i_bins == nb_bins: + continue + yield i_sort_[i0_:i1_], bins[i_bins], bins[i_bins + 1] + else: + i = numba_digitize(x, bins) - 1 + i0, i1, _ = build_index(i) + m = ~(i0 == i1) + i0, i1 = i0[m], i1[m] + for i0_, i1_ in zip(i0, i1): + i_bins = i[i0_] + yield slice(i0_, i1_), bins[i_bins], bins[i_bins + 1] + + def align_on(self, other, var_name="time", all_ref=False, **kwargs): + """ + Align the variable indices of two datasets. + + :param other: other compare with self + :param str,tuple var_name: variable name to align or two array, defaults to "time" + :param bool all_ref: yield all value of ref, if false only common value, defaults to False + :yield array,array,float,float: index in self, index in other, lower bound, upper bound + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.align_on """ + if isinstance(var_name, str): + iter_self = self.iter_on(var_name, **kwargs) + iter_other = other.iter_on(var_name, **kwargs) + else: + iter_self = self.iter_on(var_name[0], **kwargs) + iter_other = other.iter_on(var_name[1], **kwargs) + indexs_other, b0_other, b1_other = iter_other.__next__() + for indexs_self, b0_self, b1_self in iter_self: + if b0_self > b0_other: + try: + while b0_other < b0_self: + indexs_other, b0_other, b1_other = iter_other.__next__() + except StopIteration: + break + if b0_self < b0_other: + if all_ref: + yield indexs_self, empty( + 0, dtype=indexs_self.dtype + ), b0_self, b1_self + continue + yield indexs_self, indexs_other, b0_self, b1_self + + def insert_observations(self, other, index): + """Insert other obs in self at the given index.""" if not self.coherence(other): raise Exception("Observations with no coherence") - insert_size = len(other.obs) - self_size = len(self.obs) + insert_size = len(other) + self_size = len(self) new_size = self_size + insert_size if self_size == 0: self.observations = other.obs @@ -367,23 +684,30 @@ def insert_observations(self, other, index): return self def append(self, other): - """Merge - """ + """Merge.""" return self + other def __add__(self, other): return self.insert_observations(other, -1) def distance(self, other): - """ Use haversine distance for distance matrix between every self and - other eddies""" - return distance_grid( - self.obs["lon"], self.obs["lat"], other.obs["lon"], other.obs["lat"] - ) + """Use haversine distance for distance matrix between every self and + other eddies.""" + return distance_grid(self.lon, self.lat, other.lon, other.lat) + + def __copy__(self): + eddies = self.new_like(self, len(self)) + for k in self.fields: + eddies[k][:] = self[k][:] + eddies.sign_type = self.sign_type + return eddies - @staticmethod - def new_like(eddies, new_size): - return eddies.__class__( + def copy(self): + return self.__copy__() + + @classmethod + def new_like(cls, eddies, new_size: int): + return cls( new_size, track_extra_variables=eddies.track_extra_variables, track_array_variables=eddies.track_array_variables, @@ -393,153 +717,266 @@ def new_like(eddies, new_size): ) def index(self, index, reverse=False): - """Return obs from self at the index - """ + """Return obs from self at the index.""" if reverse: index = reverse_index(index, len(self)) size = 1 if hasattr(index, "__iter__"): size = len(index) + elif isinstance(index, slice): + size = index.stop - index.start eddies = self.new_like(self, size) eddies.obs[:] = self.obs[index] + eddies.sign_type = self.sign_type return eddies @staticmethod def zarr_dimension(filename): - h = zarr.open(filename) + if isinstance(filename, zarr.storage.MutableMapping): + h = filename + else: + h = zarr.open(filename) + dims = list() for varname in h: - dims.extend(list(getattr(h, varname).shape)) - return set(dims) + shape = getattr(h, varname).shape + if len(shape) > len(dims): + dims = shape + return dims @classmethod def load_file(cls, filename, **kwargs): + """ + Load the netcdf or the zarr file. + + Load only latitude and longitude on the first 300 obs : + + .. 
code-block:: python + + kwargs_latlon_300 = dict( + include_vars=[ + "longitude", + "latitude", + ], + indexs=dict(obs=slice(0, 300)), + ) + small_dataset = TrackEddiesObservations.load_file( + filename, **kwargs_latlon_300 + ) + + For `**kwargs` look at :py:meth:`load_from_zarr` or :py:meth:`load_from_netcdf` + """ filename_ = ( filename.filename if isinstance(filename, ExFileObject) else filename ) - end = b".zarr" if isinstance(filename_, bytes) else ".zarr" - if filename_.endswith(end): + if isinstance(filename, zarr.storage.MutableMapping): + return cls.load_from_zarr(filename, **kwargs) + if isinstance(filename, (bytes, str)): + end = b".zarr" if isinstance(filename_, bytes) else ".zarr" + zarr_file = filename_.endswith(end) + else: + zarr_file = False + logger.info(f"loading file '{filename_}'") + if zarr_file: return cls.load_from_zarr(filename, **kwargs) else: return cls.load_from_netcdf(filename, **kwargs) @classmethod def load_from_zarr( - cls, filename, raw_data=False, remove_vars=None, include_vars=None + cls, + filename, + raw_data=False, + remove_vars=None, + include_vars=None, + indexs=None, + buffer_size=5000000, + **class_kwargs, ): - # FIXME must be investigate, in zarr no dimensions name (or could be add in attr) - array_dim = 50 - BLOC = 5000000 - if not isinstance(filename, str): - filename = filename.astype(str) - h_zarr = zarr.open(filename) - var_list = list(h_zarr.keys()) - if include_vars is not None: - var_list = [i for i in var_list if i in include_vars] - elif remove_vars is not None: - var_list = [i for i in var_list if i not in remove_vars] + """Load data from zarr. + + :param str,store filename: path or store to load data + :param bool raw_data: If true load data without scale_factor and add_offset + :param None,list(str) remove_vars: List of variable name that will be not loaded + :param None,list(str) include_vars: If defined only this variable will be loaded + :param None,dict indexs: Indexes to load only a slice of data + :param int buffer_size: Size of buffer used to load zarr data + :param class_kwargs: argument to set up observations class + :return: Obsevations selected + :return type: class + """ + # FIXME + if isinstance(filename, zarr.storage.MutableMapping): + h_zarr = filename + else: + if not isinstance(filename, str): + filename = filename.astype(str) + h_zarr = zarr.open(filename) + + _check_versions(h_zarr.attrs.get("framework_version", None)) + var_list = cls.build_var_list(list(h_zarr.keys()), remove_vars, include_vars) nb_obs = getattr(h_zarr, var_list[0]).shape[0] + track_array_variables = h_zarr.attrs["track_array_variables"] + + if indexs is not None and "obs" in indexs: + sl = indexs["obs"] + sl = slice(sl.start, min(sl.stop, nb_obs)) + if sl.stop is not None: + nb_obs = sl.stop + if sl.start is not None: + nb_obs -= sl.start + if sl.step is not None: + indexs["obs"] = slice(sl.start, sl.stop) + logger.warning("step of slice won't be use") logger.debug("%d observations will be load", nb_obs) kwargs = dict() - dims = cls.zarr_dimension(filename) - if array_dim in dims: - kwargs["track_array_variables"] = array_dim - kwargs["array_variables"] = list() - for variable in var_list: - if array_dim in h_zarr[variable].shape: - var_inv = VAR_DESCR_inv[variable] - kwargs["array_variables"].append(var_inv) - array_variables = kwargs.get("array_variables", list()) - kwargs["track_extra_variables"] = [] + + kwargs["track_array_variables"] = h_zarr.attrs.get( + "track_array_variables", track_array_variables + ) + + array_variables = list() + for 
variable in var_list: + if len(h_zarr[variable].shape) > 1: + var_inv = VAR_DESCR_inv[variable] + array_variables.append(var_inv) + kwargs["array_variables"] = array_variables + track_extra_variables = [] + for variable in var_list: var_inv = VAR_DESCR_inv[variable] if var_inv not in cls.ELEMENTS and var_inv not in array_variables: - kwargs["track_extra_variables"].append(var_inv) + track_extra_variables.append(var_inv) + kwargs["track_extra_variables"] = track_extra_variables kwargs["raw_data"] = raw_data kwargs["only_variables"] = ( None if include_vars is None else [VAR_DESCR_inv[i] for i in include_vars] ) + kwargs.update(class_kwargs) eddies = cls(size=nb_obs, **kwargs) - for variable in var_list: + + for i_var, variable in enumerate(var_list): var_inv = VAR_DESCR_inv[variable] - logger.debug("%s will be loaded", variable) + logger.debug("%s will be loaded (%d/%d)", variable, i_var, len(var_list)) # find unit factor - factor = 1 input_unit = h_zarr[variable].attrs.get("unit", None) if input_unit is None: input_unit = h_zarr[variable].attrs.get("units", None) output_unit = VAR_DESCR[var_inv]["nc_attr"].get("units", None) - if ( - output_unit is not None - and input_unit is not None - and output_unit != input_unit - ): - units = UnitRegistry() - try: - input_unit = units.parse_expression( - input_unit, case_sensitive=False - ) - output_unit = units.parse_expression( - output_unit, case_sensitive=False - ) - except UndefinedUnitError: - input_unit = None - except TokenError: - input_unit = None - if input_unit is not None: - factor = input_unit.to(output_unit).to_tuple()[0] - # If we are able to find a conversion - if factor != 1: - logger.info( - "%s will be multiply by %f to take care of units(%s->%s)", - variable, - factor, - input_unit, - output_unit, - ) - nb = h_zarr[variable].shape[0] - + factor = cls.compare_units(input_unit, output_unit, variable) + sl_obs = slice(None) if indexs is None else indexs.get("obs", slice(None)) scale_factor = VAR_DESCR[var_inv].get("scale_factor", None) add_offset = VAR_DESCR[var_inv].get("add_offset", None) - for i in range(0, nb, BLOC): - sl = slice(i, i + BLOC) - data = h_zarr[variable][sl] - if factor != 1: - data *= factor - if raw_data: - if add_offset is not None: - data -= add_offset - if scale_factor is not None: - data /= scale_factor - eddies.obs[var_inv][sl] = data - - eddies.sign_type = h_zarr.attrs.get("rotation_type", 0) + cls.copy_data_to_zarr( + h_zarr[variable], + eddies.obs[var_inv], + sl_obs, + buffer_size, + factor, + raw_data, + scale_factor, + add_offset, + ) + + eddies.sign_type = int(h_zarr.attrs.get("rotation_type", 0)) if eddies.sign_type == 0: logger.debug("File come from another algorithm of identification") eddies.sign_type = -1 return eddies + @staticmethod + def copy_data_to_zarr( + handler_zarr, + handler_eddies, + sl_obs, + buffer_size, + factor, + raw_data, + scale_factor, + add_offset, + ): + """ + Copy with buffer for zarr. 
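A hedged sketch of the zarr loading path shown above; the store path, variable subset and sizes are illustrative assumptions, and the call follows the `load_from_zarr` signature in this hunk.

```python
from py_eddy_tracker.observations.observation import EddiesObservations

eddies = EddiesObservations.load_from_zarr(
    "eddies.zarr",                        # hypothetical store path
    include_vars=["longitude", "latitude"],
    indexs=dict(obs=slice(0, 100_000)),   # read only the first 100000 observations
    buffer_size=1_000_000,                # copy through a smaller buffer than the default
)
```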
+ + Zarr need to get real value, and size could be huge, so we use a buffer to manage memory + :param zarr_dataset handler_zarr: + :param array handler_eddies: + :param slice zarr_dataset sl_obs: + :param int buffer_size: + :param float factor: + :param bool raw_data: + :param None,float scale_factor: + :param None,float add_offset: + """ + i_start, i_stop = sl_obs.start, sl_obs.stop + if i_start is None: + i_start = 0 + if i_stop is None: + i_stop = handler_zarr.shape[0] + + for i in range(i_start, i_stop, buffer_size): + sl_in = slice(i, min(i + buffer_size, i_stop)) + data = handler_zarr[sl_in] + if factor != 1: + data *= factor + if raw_data: + if add_offset is not None: + data -= add_offset + if scale_factor is not None: + data /= scale_factor + + sl_out = slice(i - i_start, i - i_start + buffer_size) + handler_eddies[sl_out] = data + @classmethod def load_from_netcdf( - cls, filename, raw_data=False, remove_vars=None, include_vars=None + cls, + filename, + raw_data=False, + remove_vars=None, + include_vars=None, + indexs=None, + **class_kwargs, ): + """Load data from netcdf. + + :param str,ExFileObject filename: path or handler to load data + :param bool raw_data: If true load data without apply scale_factor and add_offset + :param None,list(str) remove_vars: List of variable name which will be not loaded + :param None,list(str) include_vars: If defined only this variable will be loaded + :param None,dict indexs: Indexes to load only a slice of data + :param class_kwargs: argument to set up observations class + :return: Obsevations selected + :return type: class + """ array_dim = "NbSample" if isinstance(filename, bytes): filename = filename.astype(str) - if isinstance(filename, ExFileObject): + if isinstance(filename, (ExFileObject, BufferedReader, BytesIO)): filename.seek(0) args, kwargs = ("in-mem-file",), dict(memory=filename.read()) else: args, kwargs = (filename,), dict() with Dataset(*args, **kwargs) as h_nc: - var_list = list(h_nc.variables.keys()) - if include_vars is not None: - var_list = [i for i in var_list if i in include_vars] - elif remove_vars is not None: - var_list = [i for i in var_list if i not in remove_vars] + _check_versions(getattr(h_nc, "framework_version", None)) + + var_list = cls.build_var_list( + list(h_nc.variables.keys()), remove_vars, include_vars + ) - nb_obs = len(h_nc.dimensions[cls.obs_dimension(h_nc)]) + obs_dim = cls.obs_dimension(h_nc) + nb_obs = len(h_nc.dimensions[obs_dim]) + if indexs is not None and obs_dim in indexs: + sl = indexs[obs_dim] + sl = slice(sl.start, min(sl.stop, nb_obs)) + if sl.stop is not None: + nb_obs = sl.stop + if sl.start is not None: + nb_obs -= sl.start + if sl.step is not None: + indexs[obs_dim] = slice(sl.start, sl.stop) + logger.warning("step of slice won't be use") logger.debug("%d observations will be load", nb_obs) kwargs = dict() if array_dim in h_nc.dimensions: @@ -561,6 +998,7 @@ def load_from_netcdf( if include_vars is None else [VAR_DESCR_inv[i] for i in include_vars] ) + kwargs.update(class_kwargs) eddies = cls(size=nb_obs, **kwargs) for variable in var_list: var_inv = VAR_DESCR_inv[variable] @@ -578,38 +1016,17 @@ def load_from_netcdf( if input_unit is None: input_unit = getattr(h_nc.variables[variable], "units", None) output_unit = VAR_DESCR[var_inv]["nc_attr"].get("units", None) - if ( - output_unit is not None - and input_unit is not None - and output_unit != input_unit - ): - units = UnitRegistry() - try: - input_unit = units.parse_expression( - input_unit, case_sensitive=False - ) - output_unit = 
units.parse_expression( - output_unit, case_sensitive=False - ) - except UndefinedUnitError: - input_unit = None - except TokenError: - input_unit = None - if input_unit is not None: - factor = input_unit.to(output_unit).to_tuple()[0] - # If we are able to find a conversion - if factor != 1: - logger.info( - "%s will be multiply by %f to take care of units(%s->%s)", - variable, - factor, - input_unit, - output_unit, - ) + factor = cls.compare_units(input_unit, output_unit, variable) + if indexs is None: + indexs = dict() + var_sl = [ + indexs.get(dim, slice(None)) + for dim in h_nc.variables[variable].dimensions + ] if factor != 1: - eddies.obs[var_inv] = h_nc.variables[variable][:] * factor + eddies.obs[var_inv] = h_nc.variables[variable][var_sl] * factor else: - eddies.obs[var_inv] = h_nc.variables[variable][:] + eddies.obs[var_inv] = h_nc.variables[variable][var_sl] for variable in var_list: var_inv = VAR_DESCR_inv[variable] @@ -627,6 +1044,49 @@ def load_from_netcdf( return eddies + @staticmethod + def build_var_list(var_list, remove_vars, include_vars): + if include_vars is not None: + var_list = [i for i in var_list if i in include_vars] + elif remove_vars is not None: + var_list = [i for i in var_list if i not in remove_vars] + return var_list + + @staticmethod + def compare_units(input_unit, output_unit, name): + if output_unit is None or input_unit is None or output_unit == input_unit: + return 1 + units = UnitRegistry() + try: + input_unit = units.parse_expression(input_unit, case_sensitive=False) + output_unit = units.parse_expression(output_unit, case_sensitive=False) + except UndefinedUnitError: + input_unit = None + except TokenError: + input_unit = None + if input_unit is not None: + factor = input_unit.to(output_unit).to_tuple()[0] + # If we are able to find a conversion + if factor != 1: + logger.info( + "%s will be multiply by %f to take care of units(%s->%s)", + name, + factor, + input_unit, + output_unit, + ) + return factor + else: + return 1 + + @classmethod + def from_array(cls, arrays, **kwargs): + nb = arrays["time"].size + eddies = cls(size=nb, **kwargs) + for k, v in arrays.items(): + eddies.obs[k] = v + return eddies + @classmethod def from_zarr(cls, handler): nb_obs = len(handler.dimensions[cls.obs_dimension(handler)]) @@ -643,6 +1103,7 @@ def from_zarr(cls, handler): eddies.obs[variable] = handler.variables[variable][:] else: eddies.obs[VAR_DESCR_inv[variable]] = handler.variables[variable][:] + eddies.sign_type = handler.rotation_type return eddies @classmethod @@ -661,23 +1122,23 @@ def from_netcdf(cls, handler): eddies.obs[variable] = handler.variables[variable][:] else: eddies.obs[VAR_DESCR_inv[variable]] = handler.variables[variable][:] + eddies.sign_type = handler.rotation_type return eddies def propagate( self, previous_obs, current_obs, obs_to_extend, dead_track, nb_next, model ): """ - Filled virtual obs (C) - Args: - previous_obs: previous obs from current (A) - current_obs: previous obs from virtual (B) - obs_to_extend: - dead_track: - nb_next: - model: + Fill virtual obs (C). 
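Restating the extrapolation rule for virtual observations as a tiny numeric sketch (coordinates invented): with A the observation before last and B the last one, the virtual position is C = B + AB.

```python
lon_a, lat_a = 5.0, 35.0    # A: previous obs from current
lon_b, lat_b = 5.2, 35.1    # B: previous obs from virtual
# C = B + AB, i.e. B shifted once more by the displacement from A to B
lon_c = lon_b + (lon_b - lon_a)   # ~5.4
lat_c = lat_b + (lat_b - lat_a)   # ~35.2
```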
- Returns: - New position C = B + AB + :param previous_obs: previous obs from current (A) + :param current_obs: previous obs from virtual (B) + :param obs_to_extend: + :param dead_track: + :param nb_next: + :param model: + + :return: New position C = B + AB """ next_obs = VirtualEddiesObservations( size=nb_next, @@ -726,28 +1187,75 @@ def intern(flag, public_label=False): labels = [VAR_DESCR[label]["nc_name"] for label in labels] return labels - def match(self, other, intern=False, cmin=0): - """return index and score compute with area + def match( + self, + other, + i_self=None, + i_other=None, + method="overlap", + intern=False, + cmin=0, + **kwargs, + ): + """Return index and score computed on the effective contour. + + :param EddiesObservations other: Observations to compare + :param array[bool,int],None i_self: + Index or mask to subset observations, it could avoid to build a specific dataset. + :param array[bool,int],None i_other: + Index or mask to subset observations, it could avoid to build a specific dataset. + :param str method: + - "overlap": the score is computed with contours; + - "circle": circles are computed and used for score (TODO) + :param bool intern: if True, speed contour is used (default = effective contour) + :param float cmin: 0 < cmin < 1, return only couples with score >= cmin + :param dict kwargs: look at :py:meth:`vertice_overlap` + :return: return the indices of the eddies in self coupled with eddies in + other and their associated score + :rtype: (array(int), array(int), array(float)) + + .. minigallery:: py_eddy_tracker.EddiesObservations.match """ x_name, y_name = self.intern(intern) - i, j = bbox_intersection( - self[x_name], self[y_name], other[x_name], other[y_name] - ) - c = vertice_overlap( - self[x_name][i], self[y_name][i], other[x_name][j], other[y_name][j] - ) - m = c > cmin + if i_self is None: + i_self = slice(None) + if i_other is None: + i_other = slice(None) + if method == "overlap": + x0, y0 = self[x_name][i_self], self[y_name][i_self] + x1, y1 = other[x_name][i_other], other[y_name][i_other] + i, j = bbox_intersection(x0, y0, x1, y1) + c = vertice_overlap(x0[i], y0[i], x1[j], y1[j], **kwargs) + elif method == "close_center": + x0, y0 = self.longitude[i_self], self.latitude[i_self] + x1, y1 = other.longitude[i_other], other.latitude[i_other] + i, j, c = close_center(x0, y0, x1, y1, **kwargs) + m = c >= cmin # ajout >= pour garder la cmin dans la sélection return i[m], j[m], c[m] + @staticmethod + def re_reference_index(index, ref): + """ + Shift index with ref + + :param array,int index: local index to re ref + :param slice,array ref: + reference could be a slice in this case we juste add start to index + or could be indices and in this case we need to translate + """ + if isinstance(ref, slice): + return index + ref.start + else: + return ref[index] + @classmethod def cost_function_common_area(cls, xy_in, xy_out, distance, intern=False): - """ How does it work on x bound ? - Args: - xy_in: - xy_out: - distance: - intern: - Returns: + """How does it work on x bound ? + + :param xy_in: + :param xy_out: + :param distance: + :param bool intern: """ x_name, y_name = cls.intern(intern) @@ -782,6 +1290,20 @@ def mask_function(self, other, distance): @staticmethod def cost_function(records_in, records_out, distance): + r"""Return the cost function between two obs. + + .. 
math:: + + cost = \sqrt{({Amp_{_{in}} - Amp_{_{out}} \over Amp_{_{in}}}) ^2 + + ({Rspeed_{_{in}} - Rspeed_{_{out}} \over Rspeed_{_{in}}}) ^2 + + ({distance \over 125}) ^2 + } + + :param records_in: starting observations + :param records_out: observations to associate + :param distance: computed between in and out + + """ cost = ( (records_in["amplitude"] - records_out["amplitude"]) / records_in["amplitude"] @@ -797,16 +1319,11 @@ def cost_function(records_in, records_out, distance): def shifted_ellipsoid_degrees_mask(self, other, minor=1.5, major=1.5): return shifted_ellipsoid_degrees_mask2( - self.obs["lon"], - self.obs["lat"], - other.obs["lon"], - other.obs["lat"], - minor, - major, + self.lon, self.lat, other.lon, other.lat, minor, major ) def fixed_ellipsoid_mask( - self, other, minor=50, major=100, only_east=False, shifted_ellips=False + self, other, minor=50, major=100, only_east=False, shifted_ellipse=False ): dist = self.distance(other).T accepted = dist < minor @@ -825,46 +1342,45 @@ def fixed_ellipsoid_mask( if isinstance(minor, ndarray): minor = minor[index_self] # focal distance - f_degree = ((major ** 2 - minor ** 2) ** 0.5) / ( - 111.2 * cos(radians(self.obs["lat"][index_self])) + f_degree = ((major**2 - minor**2) ** 0.5) / ( + 111.2 * cos(radians(self.lat[index_self])) ) - lon_self = self.obs["lon"][index_self] - if shifted_ellips: - x_center_ellips = lon_self - (major - minor) / 2 + lon_self = self.lon[index_self] + if shifted_ellipse: + x_center_ellipse = lon_self - (major - minor) / 2 else: - x_center_ellips = lon_self + x_center_ellipse = lon_self - lon_left_f = x_center_ellips - f_degree - lon_right_f = x_center_ellips + f_degree + lon_left_f = x_center_ellipse - f_degree + lon_right_f = x_center_ellipse + f_degree dist_left_f = distance( lon_left_f, - self.obs["lat"][index_self], - other.obs["lon"][index_other], - other.obs["lat"][index_other], + self.lat[index_self], + other.lon[index_other], + other.lat[index_other], ) dist_right_f = distance( lon_right_f, - self.obs["lat"][index_self], - other.obs["lon"][index_other], - other.obs["lat"][index_other], + self.lat[index_self], + other.lon[index_other], + other.lat[index_other], ) dist_2a = (dist_left_f + dist_right_f) / 1000 accepted[index_other, index_self] = dist_2a < (2 * major) if only_east: - d_lon = (other.obs["lon"][index_other] - lon_self + 180) % 360 - 180 + d_lon = (other.lon[index_other] - lon_self + 180) % 360 - 180 mask = d_lon < 0 accepted[index_other[mask], index_self[mask]] = False return accepted.T @staticmethod - def basic_formula_ellips_major_axis( + def basic_formula_ellipse_major_axis( lats, cmin=1.5, cmax=10.0, c0=1.5, lat1=13.5, lat2=5.0, degrees=False ): - """Give major axis in km with a given latitude - """ + """Give major axis in km with a given latitude""" # Straight line between lat1 and lat2: # y = a * x + b a = (cmin - cmax) / (lat1 - lat2) @@ -887,20 +1403,27 @@ def solve_conflict(cost): @staticmethod def solve_simultaneous(cost): + """Deduce link from cost matrix. 
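A hedged numeric sketch of the association cost written out in `cost_function` above, with invented values: relative amplitude change, relative speed-radius change and distance normalised by 125, combined quadratically.

```python
from math import sqrt

amp_in, amp_out = 0.10, 0.08   # hypothetical amplitudes
r_in, r_out = 40.0e3, 45.0e3   # hypothetical speed radii
dist = 25.0                    # hypothetical separation, same unit as the 125 normalisation
cost = sqrt(
    ((amp_in - amp_out) / amp_in) ** 2
    + ((r_in - r_out) / r_in) ** 2
    + (dist / 125.0) ** 2
)
# sqrt(0.2**2 + 0.125**2 + 0.2**2) ~ 0.31
```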
+ + :param array(float) cost: Cost for each available link + :return: return a boolean mask array, True for each valid couple + :rtype: array(bool) + """ mask = ~cost.mask - # Count number of link by self obs and other obs - self_links = mask.sum(axis=1) - other_links = mask.sum(axis=0) + if mask.size == 0: + return mask + # Count number of links by self obs and other obs + self_links, other_links = sum_row_column(mask) max_links = max(self_links.max(), other_links.max()) if max_links > 5: logger.warning("One observation have %d links", max_links) - # If some obs have multiple link, we keep only one link by eddy + # If some obs have multiple links, we keep only one link by eddy eddies_separation = 1 < self_links eddies_merge = 1 < other_links test = eddies_separation.any() or eddies_merge.any() if test: - # We extract matrix which contains concflict + # We extract matrix that contains conflict obs_linking_to_self = mask[eddies_separation].any(axis=0) obs_linking_to_other = mask[:, eddies_merge].any(axis=1) i_self_keep = where(obs_linking_to_other + eddies_separation)[0] @@ -923,13 +1446,13 @@ def solve_simultaneous(cost): security_increment = 0 while False in cost_reduce.mask: if security_increment > max_iteration: - # Maybe check if the size decrease if not rise an exception + # Maybe check if the size decreases if not rise an exception # x_i, y_i = where(-cost_reduce.mask) raise Exception("To many iteration: %d" % security_increment) security_increment += 1 i_min_value = cost_reduce.argmin() i, j = floor(i_min_value / shape[1]).astype(int), i_min_value % shape[1] - # Set to False all link + # Set to False all links mask[i_self_keep[i]] = False mask[:, i_other_keep[j]] = False cost_reduce.mask[i] = True @@ -943,19 +1466,19 @@ def solve_simultaneous(cost): @staticmethod def solve_first(cost, multiple_link=False): mask = ~cost.mask - # Count number of link by self obs and other obs + # Count number of links by self obs and other obs self_links = mask.sum(axis=1) other_links = mask.sum(axis=0) max_links = max(self_links.max(), other_links.max()) if max_links > 5: logger.warning("One observation have %d links", max_links) - # If some obs have multiple link, we keep only one link by eddy + # If some obs have multiple links, we keep only one link by eddy eddies_separation = 1 < self_links eddies_merge = 1 < other_links test = eddies_separation.any() or eddies_merge.any() if test: - # We extract matrix which contains concflict + # We extract matrix that contains conflict obs_linking_to_self = mask[eddies_separation].any(axis=0) obs_linking_to_other = mask[:, eddies_merge].any(axis=1) i_self_keep = where(obs_linking_to_other + eddies_separation)[0] @@ -991,7 +1514,7 @@ def solve_first(cost, multiple_link=False): return mask def solve_function(self, cost_matrix): - return where(self.solve_simultaneous(cost_matrix)) + return numba_where(self.solve_simultaneous(cost_matrix)) def post_process_link(self, other, i_self, i_other): if unique(i_other).shape[0] != i_other.shape[0]: @@ -999,8 +1522,7 @@ def post_process_link(self, other, i_self, i_other): return i_self, i_other def tracking(self, other): - """Track obs between self and other - """ + """Track obs between self and other""" dist = self.distance(other) mask_accept_dist = self.mask_function(other, dist) indexs_closest = where(mask_accept_dist) @@ -1029,8 +1551,7 @@ def to_zarr(self, handler, **kwargs): handler.attrs["track_array_variables"] = self.track_array_variables handler.attrs["array_variables"] = ",".join(self.array_variables) # Iter on 
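Once the conflicting sub-matrix is extracted, the resolution loop in `solve_simultaneous` above boils down to a greedy rule: take the cheapest remaining pair, then discard its row and column. A self-contained numpy sketch of that idea (not the library routine itself, which also restricts the loop to the conflicting observations):

```python
import numpy as np
import numpy.ma as ma

def greedy_links(cost, forbidden=None):
    """Pick one-to-one links by repeatedly taking the smallest remaining
    cost and masking out its row and column.
    `forbidden` optionally marks impossible pairs."""
    if forbidden is None:
        forbidden = np.zeros(cost.shape, dtype=bool)
    cost = ma.array(cost, mask=forbidden.copy())
    links = np.zeros(cost.shape, dtype=bool)
    while not cost.mask.all():
        i, j = np.unravel_index(cost.argmin(), cost.shape)
        links[i, j] = True
        cost[i, :] = ma.masked
        cost[:, j] = ma.masked
    return links

# Two candidate "in" eddies competing for two "out" eddies
print(greedy_links(np.array([[0.2, 0.9], [0.3, 0.4]])))
# [[ True False]
#  [False  True]]
```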
variables to create: - fields = [field[0] for field in self.observations.dtype.descr] - for ori_name in fields: + for ori_name in self.fields: # Patch for a transition name = ori_name # @@ -1044,7 +1565,7 @@ def to_zarr(self, handler, **kwargs): dimensions=VAR_DESCR[name]["nc_dims"], ), VAR_DESCR[name]["nc_attr"], - self.observations[ori_name], + self.obs[ori_name], scale_factor=VAR_DESCR[name].get("scale_factor", None), add_offset=VAR_DESCR[name].get("add_offset", None), filters=VAR_DESCR[name].get("filters", None), @@ -1075,12 +1596,9 @@ def to_netcdf(self, handler, **kwargs): handler.track_array_variables = self.track_array_variables handler.array_variables = ",".join(self.array_variables) # Iter on variables to create: - fields = [field[0] for field in self.observations.dtype.descr] - fields_ = array( - [VAR_DESCR[field[0]]["nc_name"] for field in self.observations.dtype.descr] - ) + fields_ = array([VAR_DESCR[field]["nc_name"] for field in self.fields]) i = fields_.argsort() - for ori_name in array(fields)[i]: + for ori_name in array(self.fields)[i]: # Patch for a transition name = ori_name # @@ -1093,7 +1611,7 @@ def to_netcdf(self, handler, **kwargs): dimensions=VAR_DESCR[name]["nc_dims"], ), VAR_DESCR[name]["nc_attr"], - self.observations[ori_name], + self.obs[ori_name], scale_factor=VAR_DESCR[name].get("scale_factor", None), add_offset=VAR_DESCR[name].get("add_offset", None), **kwargs, @@ -1147,6 +1665,33 @@ def create_variable( except ValueError: logger.warning("Data is empty") + @staticmethod + def get_filters_zarr(name): + """Get filters to store in zarr for known variable + + :param str name: private variable name + :return list: filters list + """ + content = VAR_DESCR.get(name) + filters = list() + store_dtype = content["output_type"] + scale_factor, add_offset = content.get("scale_factor", None), content.get( + "add_offset", None + ) + if scale_factor is not None or add_offset is not None: + if add_offset is None: + add_offset = 0 + filters.append( + zarr.FixedScaleOffset( + offset=add_offset, + scale=1 / scale_factor, + dtype=content["nc_type"], + astype=store_dtype, + ) + ) + filters.extend(content.get("filters", [])) + return filters + def create_variable_zarr( self, handler_zarr, @@ -1157,6 +1702,7 @@ def create_variable_zarr( add_offset=None, filters=None, compressor=None, + chunck_size=2500000, ): kwargs_variable["shape"] = data.shape kwargs_variable["compressor"] = ( @@ -1169,8 +1715,8 @@ def create_variable_zarr( add_offset = 0 kwargs_variable["filters"].append( zarr.FixedScaleOffset( - offset=float64(add_offset), - scale=1 / float64(scale_factor), + offset=add_offset, + scale=1 / scale_factor, dtype=kwargs_variable["dtype"], astype=store_dtype, ) @@ -1180,10 +1726,10 @@ def create_variable_zarr( dims = kwargs_variable.get("dimensions", None) # Manage chunk in 2d case if len(dims) == 1: - kwargs_variable["chunks"] = (2500000,) + kwargs_variable["chunks"] = (chunck_size,) if len(dims) == 2: second_dim = data.shape[1] - kwargs_variable["chunks"] = (200000, second_dim) + kwargs_variable["chunks"] = (chunck_size // second_dim, second_dim) kwargs_variable.pop("dimensions") v = handler_zarr.create_dataset(**kwargs_variable) @@ -1211,9 +1757,15 @@ def create_variable_zarr( logger.warning("Data is empty") def write_file( - self, path="./", filename="%(path)s/%(sign_type)s.nc", zarr_flag=False + self, path="./", filename="%(path)s/%(sign_type)s.nc", zarr_flag=False, **kwargs ): - """Write a netcdf with eddy obs + """Write a netcdf or zarr with eddy obs. 
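The `scale_factor`/`add_offset` pair used above for both netCDF and zarr storage follows the usual packed-value convention; a minimal numpy sketch of the round trip (toy values, hypothetical variable):

```python
import numpy as np

def pack(values, scale_factor, add_offset, dtype="i2"):
    # stored = round((value - add_offset) / scale_factor)
    return np.round((values - add_offset) / scale_factor).astype(dtype)

def unpack(stored, scale_factor, add_offset):
    # value = stored * scale_factor + add_offset
    return stored * scale_factor + add_offset

amplitude = np.array([0.012, 0.250, 1.100])          # metres
packed = pack(amplitude, scale_factor=0.001, add_offset=0.0)
print(packed)                                        # [  12  250 1100]
print(unpack(packed, 0.001, 0.0))                    # back to ~[0.012 0.25 1.1]
```

The zarr branch expresses the same transform through a `FixedScaleOffset` filter built with `scale=1/scale_factor` and `offset=add_offset`.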
+ Zarr is usefull for large dataset > 10M observations + + :param str path: set path variable + :param str filename: model to store file + :param bool zarr_flag: If True, method will use zarr format instead of netcdf + :param dict kwargs: look at :py:meth:`to_zarr` or :py:meth:`to_netcdf` """ filename = filename % dict( path=path, @@ -1224,13 +1776,14 @@ def write_file( filename = filename.replace(".nc", ".zarr") if filename.endswith(".zarr"): zarr_flag = True - logger.info("Store in %s", filename) + logger.info("Store in %s (%d observations)", filename, len(self)) if zarr_flag: handler = zarr.open(filename, "w") - self.to_zarr(handler) + self.to_zarr(handler, **kwargs) else: - with Dataset(filename, "w", format="NETCDF4") as handler: - self.to_netcdf(handler) + nc_format = kwargs.pop("format", "NETCDF4") + with Dataset(filename, "w", format=nc_format) as handler: + self.to_netcdf(handler, **kwargs) @property def global_attr(self): @@ -1238,6 +1791,7 @@ def global_attr(self): Metadata_Conventions="Unidata Dataset Discovery v1.0", comment="Surface product; mesoscale eddies", framework_used="https://github.com/AntSimi/py-eddy-tracker", + framework_version=__version__, standard_name_vocabulary="NetCDF Climate and Forecast (CF) Metadata Convention Standard Name Table", rotation_type=self.sign_type, ) @@ -1250,37 +1804,432 @@ def set_global_attr_netcdf(self, h_nc): for key, item in self.global_attr.items(): h_nc.setncattr(key, item) - def scatter(self, ax, name, ref=None, factor=1, **kwargs): + def mask_from_polygons(self, polygons): + """ + Return mask for all observations in one of polygons list + + :param list((array,array)) polygons: list of x/y array which be used to identify observations + """ + x, y = polygons[0] + m = insidepoly( + self.longitude, self.latitude, x.reshape((1, -1)), y.reshape((1, -1)) + ) + for x, y in polygons[1:]: + m_ = ~m + m[m_] = insidepoly( + self.longitude[m_], + self.latitude[m_], + x.reshape((1, -1)), + y.reshape((1, -1)), + ) + return m + + def extract_with_area(self, area, **kwargs): + """ + Extract geographically with a bounding box. + + :param dict area: 4 coordinates in a dictionary to specify bounding box (lower left corner and upper right corner) + :param dict kwargs: look at :py:meth:`extract_with_mask` + :return: Return all eddy trajetories in bounds + :rtype: EddiesObservations + + .. code-block:: python + + area = dict(llcrnrlon=x0, llcrnrlat=y0, urcrnrlon=x1, urcrnrlat=y1) + + .. minigallery:: py_eddy_tracker.EddiesObservations.extract_with_area + """ + lat0 = area.get("llcrnrlat", -90) + lat1 = area.get("urcrnrlat", 90) + mask = (self.latitude > lat0) * (self.latitude < lat1) + lon0 = area["llcrnrlon"] + lon = (self.longitude - lon0) % 360 + lon0 + mask *= (lon > lon0) * (lon < area["urcrnrlon"]) + return self.extract_with_mask(mask, **kwargs) + + @property + def time_datetime64(self): + dt = (datetime64("1970-01-01") - datetime64("1950-01-01")).astype("i8") + return (self.time - dt).astype("datetime64[D]") + + def time_sub_sample(self, t0, time_step): + """ + Time sub sampling + + :param int,float t0: reference time that will be keep + :param int,float time_step: keep every observation spaced by time_step + """ + mask = (self.time - t0) % time_step == 0 + return self.extract_with_mask(mask) + + def extract_with_mask(self, mask): + """ + Extract a subset of observations. 
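The longitude handling in `extract_with_area` above is what lets a bounding box straddle the 0° (or 180°) meridian; the same mask logic on toy centres:

```python
import numpy as np

lon = np.array([2.0, 355.0, 120.0])       # 355°E is -5°E once wrapped
lat = np.array([38.0, 36.0, -40.0])
area = dict(llcrnrlon=-10, llcrnrlat=30, urcrnrlon=10, urcrnrlat=45)

mask = (lat > area["llcrnrlat"]) & (lat < area["urcrnrlat"])
# Wrap longitudes with the lower-left corner as the reference
lon_wrapped = (lon - area["llcrnrlon"]) % 360 + area["llcrnrlon"]
mask &= (lon_wrapped > area["llcrnrlon"]) & (lon_wrapped < area["urcrnrlon"])
print(mask)                                # [ True  True False]
```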
+ + :param array(bool) mask: mask to select observations + :return: same object with selected observations + :rtype: self + """ + + nb_obs = mask.sum() + new = self.__class__.new_like(self, nb_obs) + new.sign_type = self.sign_type + if nb_obs == 0: + logger.warning("Empty dataset will be created") + else: + for field in self.fields: + logger.debug("Copy of field %s ...", field) + new.obs[field] = self.obs[field][mask] + return new + + def scatter(self, ax, name=None, ref=None, factor=1, **kwargs): + """ + Scatter data. + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param str,array,None name: + variable used to fill the contour, if None all elements have the same color + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary + :param float factor: multiply value by + :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.scatter` + :return: scatter mappable + + .. minigallery:: py_eddy_tracker.EddiesObservations.scatter + """ x = self.longitude if ref is not None: x = (x - ref) % 360 + ref - return ax.scatter(x, self.latitude, c=self[name] * factor, **kwargs) + kwargs = kwargs.copy() + if name is not None and "c" not in kwargs: + v = self.parse_varname(name) + kwargs["c"] = v * factor + return ax.scatter(x, self.latitude, **kwargs) - def display( - self, ax, ref=None, extern_only=False, intern_only=False, nobs=True, **kwargs + def filled( + self, + ax, + varname=None, + ref=None, + intern=False, + cmap="magma_r", + lut=10, + vmin=None, + vmax=None, + factor=1, + **kwargs, ): + """ + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param str,array,None varname: variable used to fill the contours, or an array of same size than obs + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary + :param bool intern: if True draw speed contours instead of effective contours + :param str cmap: matplotlib colormap name + :param int,None lut: Number of colors in the colormap + :param float,None vmin: Min value of the colorbar + :param float,None vmax: Max value of the colorbar + :param float factor: multiply value by + :return: Collection drawed + :rtype: matplotlib.collections.PolyCollection + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.filled + """ + x_name, y_name = self.intern(intern) + x, y = self[x_name], self[y_name] + if ref is not None: + # TODO : maybe buggy with global display + shape_out = x.shape + x, y = wrap_longitude(x.reshape(-1), y.reshape(-1), ref) + x, y = x.reshape(shape_out), y.reshape(shape_out) + verts = list() + for x_, y_ in zip(x, y): + verts.append(create_vertice(x_, y_)) + if "facecolors" not in kwargs: + kwargs = kwargs.copy() + cmap = get_cmap(cmap, lut) + v = self.parse_varname(varname) * factor + if vmin is None: + vmin = v.min() + if vmax is None: + vmax = v.max() + v = (v - vmin) / (vmax - vmin) + colors = [cmap(v_) for v_ in v] + kwargs["facecolors"] = colors + if "label" in kwargs: + kwargs["label"] = self.format_label(kwargs["label"]) + c = PolyCollection(verts, **kwargs) + ax.add_collection(c) + c.cmap = cmap + c.norm = Normalize(vmin=vmin, vmax=vmax) + return c + + def __merge_filters__(self, *filters): + """ + Compute an intersection between all filters after to evaluate each of them + + :param list(slice,array[int],array[bool]) filters: + + :return: Return applicable object to numpy.array + :rtype: slice, index, mask + """ + filter1 = filters[0] + if len(filters) > 2: + filter2 = self.__merge_filters__(*filters[1:]) + elif len(filters) == 2: + filter2 = filters[1] + # Merge indexs and filter + if isinstance(filter1, slice): + reject = ones(len(self), dtype="bool") + reject[filter1] = False + if isinstance(filter2, slice): + reject[filter2] = False + return ~reject + # Mask case + elif filter2.dtype == bool: + return ~reject * filter2 + # index case + else: + return filter2[~reject[filter2]] + # mask case + elif filter1.dtype == bool: + if isinstance(filter2, slice): + select = zeros(len(self), dtype="bool") + select[filter2] = True + return select * filter1 + # Mask case + elif filter2.dtype == bool: + return filter2 * filter1 + # index case + else: + return filter2[filter1[filter2]] + # index case + else: + if isinstance(filter2, slice): + select = zeros(len(self), dtype="bool") + select[filter2] = True + return filter1[select[filter1]] + # Mask case + elif filter2.dtype == bool: + return filter1[filter2[filter1]] + # index case + else: + return filter1[in1d(filter1, filter2)] + + def merge_filters(self, *filters): + """ + Compute an intersection between all filters after to evaluate each of them + + :param list(callable,None,slice,array[int],array[bool]) filters: + + :return: Return applicable object to numpy.array + :rtype: slice, index, mask + """ + if len(filters) == 1 and isinstance(filters[0], list): + filters = filters[0] + filters_ = list() + # Remove all filter which select all obs + for filter in filters: + if callable(filter): + filter = filter(self) + if filter is None: + continue + if isinstance(filter, slice): + if filter == slice(None): + continue + elif filter.dtype == "bool": + if filter.all(): + continue + if not filter.any(): + return empty(0, dtype=int) + filters_.append(filter) + if len(filters_) == 1: + return filters_[0] + elif len(filters_) == 0: + return slice(None) + else: + return self.__merge_filters__(*filters_) + + def bins_stat(self, xname, bins=None, yname=None, method=None, mask=None): + """ + :param str,array xname: variable to compute stats on + :param array, None bins: bins to perform statistics, if None bins = arange(variable.min(), variable.max() + 2) + :param None,str,array yname: variable used to apply method + :param None,str method: If None method counts the number of observations in each 
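`merge_filters` above intersects heterogeneous selectors (callables, slices, boolean masks, index arrays) without always materialising them; conceptually the result is just the intersection of the selections, as in this naive and much less efficient sketch:

```python
import numpy as np

n = 10
values = np.arange(n) * 10

def as_mask(f, n):
    """Turn a slice, boolean mask or index array into a boolean mask."""
    m = np.zeros(n, dtype=bool)
    m[f] = True
    return m

filters = [slice(2, 8), np.arange(n) % 2 == 0, np.array([0, 4, 6, 9])]
mask = np.logical_and.reduce([as_mask(f, n) for f in filters])
print(values[mask])   # [40 60]: only indices 4 and 6 pass all three filters
```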
bin, can be "mean", "std" + :param None,array(bool) mask: If defined use only True position + :return: x array and y array + :rtype: array,array + + .. minigallery:: py_eddy_tracker.EddiesObservations.bins_stat + """ + v = self.parse_varname(xname) + mask = self.merge_filters(mask) + v = v[mask] + if bins is None: + bins = arange(v.min(), v.max() + 2) + y, x = hist_numba(v, bins=bins) + x = (x[1:] + x[:-1]) / 2 + if method == "mean": + y_v = self.parse_varname(yname) + y_v = y_v[mask] + y_, _ = histogram(v, bins=bins, weights=y_v) + with errstate(divide="ignore", invalid="ignore"): + y = y_ / y + return x, y + + def format_label(self, label): + t0, t1 = self.period + return label.format( + t0=t0, + t1=t1, + nb_obs=len(self), + ) + + def display_color(self, ax, field, ref=None, intern=False, **kwargs): + """Plot colored contour of eddies + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param str,array field: color field + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary + :param bool intern: if True, draw the speed contour + :param dict kwargs: look at :py:meth:`matplotlib.collections.LineCollection` + + .. minigallery:: py_eddy_tracker.EddiesObservations.display_color + """ + xname, yname = self.intern(intern) + x, y = self[xname], self[yname] + + if ref is not None: + # TODO : maybe buggy with global display + shape_out = x.shape + x, y = wrap_longitude(x.reshape(-1), y.reshape(-1), ref) + x, y = x.reshape(shape_out), y.reshape(shape_out) + + c = self.parse_varname(field) + cmap = get_cmap(kwargs.pop("cmap", "Spectral_r")) + cmin, cmax = kwargs.pop("vmin", c.min()), kwargs.pop("vmax", c.max()) + colors = cmap((c - cmin) / (cmax - cmin)) + lines = LineCollection( + [create_vertice(i, j) for i, j in zip(x, y)], colors=colors, **kwargs + ) + ax.add_collection(lines) + lines.cmap = cmap + lines.norm = Normalize(vmin=cmin, vmax=cmax) + return lines + + def display(self, ax, ref=None, extern_only=False, intern_only=False, **kwargs): + """Plot the speed and effective (dashed) contour of the eddies + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary + :param bool extern_only: if True, draw only the effective contour + :param bool intern_only: if True, draw only the speed contour + :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.plot` + + .. minigallery:: py_eddy_tracker.EddiesObservations.display + """ if not extern_only: - lon_s = flatten_line_matrix(self.obs["contour_lon_s"]) - lat_s = flatten_line_matrix(self.obs["contour_lat_s"]) + lon_s = flatten_line_matrix(self.contour_lon_s) + lat_s = flatten_line_matrix(self.contour_lat_s) if not intern_only: - lon_e = flatten_line_matrix(self.obs["contour_lon_e"]) - lat_e = flatten_line_matrix(self.obs["contour_lat_e"]) - if nobs and "label" in kwargs: - kwargs["label"] += " (%s observations)" % len(self) + lon_e = flatten_line_matrix(self.contour_lon_e) + lat_e = flatten_line_matrix(self.contour_lat_e) + if "label" in kwargs: + kwargs["label"] = self.format_label(kwargs["label"]) kwargs_e = kwargs.copy() + if "ls" not in kwargs_e and "linestyle" not in kwargs_e: + kwargs_e["linestyle"] = "-." 
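`bins_stat` above is essentially a pair of histograms (a weighted one divided by a plain count); an equivalent toy computation with plain numpy:

```python
import numpy as np

radius = np.array([12.0, 18.0, 22.0, 35.0, 36.0, 80.0])     # binned variable
amplitude = np.array([0.02, 0.03, 0.05, 0.04, 0.06, 0.10])  # averaged variable

bins = np.arange(0, 101, 20)
count, edges = np.histogram(radius, bins=bins)
summed, _ = np.histogram(radius, bins=bins, weights=amplitude)
with np.errstate(divide="ignore", invalid="ignore"):
    mean = summed / count
x = (edges[1:] + edges[:-1]) / 2      # bin centres, as returned by bins_stat
print(x)      # [10. 30. 50. 70. 90.]
print(mean)   # ~[0.025 0.05 nan nan 0.1]
```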
if not extern_only: kwargs_e.pop("label", None) + mappables = list() if not extern_only: if ref is not None: lon_s, lat_s = wrap_longitude(lon_s, lat_s, ref, cut=True) - ax.plot(lon_s, lat_s, **kwargs) + mappables.append(ax.plot(lon_s, lat_s, **kwargs)[0]) if not intern_only: if ref is not None: lon_e, lat_e = wrap_longitude(lon_e, lat_e, ref, cut=True) - ax.plot(lon_e, lat_e, linestyle="-.", **kwargs_e) + mappables.append(ax.plot(lon_e, lat_e, **kwargs_e)[0]) + return mappables + + def first_obs(self): + """ + Get first obs of each trajectory. + + :rtype: __class__ + + .. minigallery:: py_eddy_tracker.EddiesObservations.first_obs + """ + return self.extract_with_mask(self.n == 0) + + def last_obs(self): + """ + Get Last obs of each trajectory. - def grid_count(self, bins, intern=False, center=False): + :rtype: __class__ + + .. minigallery:: py_eddy_tracker.EddiesObservations.last_obs + """ + m = zeros(len(self), dtype="bool") + m[-1] = True + m[:-1][self.n[1:] == 0] = True + return self.extract_with_mask(m) + + def is_convex(self, intern=False): + """ + Get flag of the eddy's convexity + + :param bool intern: If True use speed contour instead of effective contour + :return: True if the contour is convex + :rtype: array[bool] + """ + xname, yname = self.intern(intern) + return convexs(self[xname], self[yname]) + + def contains(self, x, y, intern=False): + """ + Return index of contour containing (x,y) + + :param array x: longitude + :param array y: latitude + :param bool intern: If true use speed contour instead of effective contour + :return: indexs, -1 if no index + :rtype: array[int32] + """ + xname, yname = self.intern(intern) + m = ~(isnan(x) + isnan(y)) + i = -ones(x.shape, dtype="i4") + + if x.size != 0 and m.any(): + i[m] = poly_indexs(x[m], y[m], self[xname], self[yname]) + return i + + def inside(self, x, y, intern=False): + """ + True for each point inside the effective contour of an eddy + + :param array x: longitude + :param array y: latitude + :param bool intern: If true use speed contour instead of effective contour + :return: flag + :rtype: array[bool] + """ + xname, yname = self.intern(intern) + return insidepoly(x, y, self[xname], self[yname]) + + def grid_count(self, bins, intern=False, center=False, filter=slice(None)): + """ + Count the eddies in each bin (use all pixels in each contour) + + :param (numpy.array,numpy.array) bins: bins (grid) to count + :param bool intern: if True use speed contour only + :param bool center: if True use of center to count + :param array,mask,slice filter: keep the data selected with the filter + :return: return the grid of counts + :rtype: py_eddy_tracker.dataset.grid.RegularGridDataset + + .. 
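`first_obs`/`last_obs` above rely only on the per-track observation number `n`; the boolean trick for the last observation is worth spelling out on a toy array:

```python
import numpy as np

# n restarts at 0 for each trajectory: here 3 tracks of length 3, 2 and 4
n = np.array([0, 1, 2, 0, 1, 0, 1, 2, 3])

first = n == 0
last = np.zeros(n.size, dtype=bool)
last[-1] = True               # the dataset's last obs always ends a track
last[:-1][n[1:] == 0] = True  # an obs followed by n == 0 ends its track
print(np.flatnonzero(first))  # [0 3 5]
print(np.flatnonzero(last))   # [2 4 8]
```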
minigallery:: py_eddy_tracker.EddiesObservations.grid_count + """ + filter = self.merge_filters(filter) x_name, y_name = self.intern(intern) x_bins, y_bins = arange(*bins[0]), arange(*bins[1]) x0 = bins[0][0] @@ -1295,51 +2244,378 @@ def grid_count(self, bins, intern=False, center=False): lat=(y_bins[1:] + y_bins[:-1]) / 2, ), variables_description=dict( - count=dict(long_name="Number of times pixel is in eddies") + count=dict(long_name="Number of times the pixel is within an eddy") ), centered=True, ) if center: - x, y = (self.longitude - x0) % 360 + x0, self.latitude + x, y = (self.longitude[filter] - x0) % 360 + x0, self.latitude[filter] grid[:] = histogram2d(x, y, (x_bins, y_bins))[0] grid.mask = grid.data == 0 else: - x_ref = ((self.longitude - x0) % 360 + x0 - 180).reshape(-1, 1) - x, y = (self[x_name] - x_ref) % 360 + x_ref, self[y_name] - for x_, y_ in zip(x, y): - i, j = BasePath(create_vertice(x_, y_)).pixels_in(regular_grid) - grid_count_(grid, i, j) + x_ref = ((self.longitude[filter] - x0) % 360 + x0 - 180).reshape(-1, 1) + x_contour, y_contour = self[x_name][filter], self[y_name][filter] + grid_count_pixel_in( + grid.data, + x_contour, + y_contour, + x_ref, + regular_grid.x_bounds, + regular_grid.y_bounds, + regular_grid.xstep, + regular_grid.ystep, + regular_grid.N, + regular_grid.is_circular(), + regular_grid.x_size, + regular_grid.x_c, + regular_grid.y_c, + ) grid.mask = grid == 0 return regular_grid - def grid_stat(self, bins, varname): + def grid_box_stat(self, bins, varname, method=50, data=None, filter=slice(None)): + """ + Get percentile of eddies in each bin + + :param (numpy.array,numpy.array) bins: bins (grid) to count + :param str varname: variable to apply the method if data is None and will be output name + :param str,float method: method to apply. If float, use ? + :param array data: Array used to compute stat if defined + :param array,mask,slice filter: keep the data selected with the filter + :return: return grid of method + :rtype: py_eddy_tracker.dataset.grid.RegularGridDataset + + .. 
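With `center=True`, `grid_count` above is a plain 2-D histogram of the wrapped eddy centres; a toy version with numpy (the contour branch instead rasterises each polygon through `grid_count_pixel_in`):

```python
import numpy as np

lon = np.array([10.2, 10.7, 200.4, 359.9])
lat = np.array([35.1, 35.3, -12.0, 0.5])

x0 = 0
x_bins = np.arange(0, 361, 1)              # 1 degree bins
y_bins = np.arange(-90, 91, 1)

x = (lon - x0) % 360 + x0                  # wrap on the grid reference
count = np.histogram2d(x, lat, (x_bins, y_bins))[0]
count = np.ma.masked_equal(count, 0)       # mask empty pixels, as above
print(count.sum(), count.count())          # 4 observations spread over 3 pixels
```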
minigallery:: py_eddy_tracker.EddiesObservations.grid_box_stat + """ x_bins, y_bins = arange(*bins[0]), arange(*bins[1]) x0 = bins[0][0] x, y = (self.longitude - x0) % 360 + x0, self.latitude - sum_obs = histogram2d(x, y, (x_bins, y_bins), weights=self[varname])[0] - nb_obs = histogram2d(x, y, (x_bins, y_bins))[0] + data = self[varname] if data is None else data + if hasattr(data, "mask"): + filter = self.merge_filters(~data.mask, self.merge_filters(filter)) + else: + filter = self.merge_filters(filter) + x, y, data = x[filter], y[filter], data[filter] + from ..dataset.grid import RegularGridDataset + shape = (x_bins.shape[0] - 1, y_bins.shape[0] - 1) + grid = ma.empty(shape, dtype=data.dtype) + grid.mask = ones(shape, dtype="bool") regular_grid = RegularGridDataset.with_array( coordinates=("x", "y"), - datas={ - varname: ma.array(sum_obs / nb_obs, mask=nb_obs == 0), - "x": x_bins[:-1], - "y": y_bins[:-1], - }, + datas={varname: grid, "x": x_bins[:-1], "y": y_bins[:-1]}, + centered=False, ) + grid_box_stat( + regular_grid.x_c, + regular_grid.y_c, + grid.data, + grid.mask, + x, + y, + data, + regular_grid.is_circular(), + method, + ) + return regular_grid + def grid_stat(self, bins, varname, data=None): + """ + Return the mean of the eddies' variable in each bin + + :param (numpy.array,numpy.array) bins: bins (grid) to compute the mean on + :param str varname: name of variable to compute the mean on and output grid_name + :param array data: Array used to compute stat if defined + :return: return the gridde mean variable + :rtype: py_eddy_tracker.dataset.grid.RegularGridDataset + + .. minigallery:: py_eddy_tracker.EddiesObservations.grid_stat + """ + x_bins, y_bins = arange(*bins[0]), arange(*bins[1]) + x0 = bins[0][0] + x, y = (self.longitude - x0) % 360 + x0, self.latitude + data = self[varname] if data is None else data + if hasattr(data, "mask"): + m = ~data.mask + sum_obs = histogram2d(x[m], y[m], (x_bins, y_bins), weights=data[m])[0] + nb_obs = histogram2d(x[m], y[m], (x_bins, y_bins))[0] + else: + sum_obs = histogram2d(x, y, (x_bins, y_bins), weights=data)[0] + nb_obs = histogram2d(x, y, (x_bins, y_bins))[0] + from ..dataset.grid import RegularGridDataset + + with errstate(divide="ignore", invalid="ignore"): + regular_grid = RegularGridDataset.with_array( + coordinates=("x", "y"), + datas={ + varname: ma.array(sum_obs / nb_obs, mask=nb_obs == 0), + "x": x_bins[:-1], + "y": y_bins[:-1], + }, + centered=False, + ) + return regular_grid + + def interp_grid( + self, grid_object, varname, i=None, method="center", dtype=None, intern=None + ): + """ + Interpolate a grid on a center or contour with mean, min or max method + + :param grid_object: Handler of grid to interp + :type grid_object: py_eddy_tracker.dataset.grid.RegularGridDataset + :param str varname: Name of variable to use + :param array[bool,int],None i: + Index or mask to subset observations, it could avoid to build a specific dataset. + :param str method: 'center', 'mean', 'max', 'min', 'nearest' + :param str dtype: if None we use var dtype + :param bool intern: Use extern or intern contour + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.interp_grid + """ + if method in ("center", "nearest"): + x, y = self.longitude, self.latitude + if i is not None: + x, y = x[i], y[i] + return grid_object.interp(varname, x, y, method) + elif method in ("min", "max", "mean", "count"): + x0 = grid_object.x_bounds[0] + x_name, y_name = self.intern(False if intern is None else intern) + x_ref = ((self.longitude - x0) % 360 + x0 - 180).reshape(-1, 1) + x, y = (self[x_name] - x_ref) % 360 + x_ref, self[y_name] + if i is not None: + x, y = x[i], y[i] + grid = grid_object.grid(varname) + result = empty(x.shape[0], dtype=grid.dtype if dtype is None else dtype) + min_method = method == "min" + grid_stat( + grid_object.x_c, + grid_object.y_c, + -grid.data if min_method else grid.data, + grid.mask, + x, + y, + result, + grid_object.is_circular(), + method="max" if min_method else method, + ) + return -result if min_method else result + else: + raise Exception(f'method "{method}" unknown') + + @property + def period(self): + """ + Give the time coverage. If collection is empty, return nan,nan + + :return: first and last date + :rtype: (int,int) + """ + if self.period_ is None: + if self.time.size < 1: + self.period_ = nan, nan + else: + self.period_ = self.time.min(), self.time.max() + return self.period_ + + @property + def nb_days(self): + """Return period in days covered by the dataset + + :return: Number of days + :rtype: int + """ + return self.period[1] - self.period[0] + 1 + + def create_particles(self, step, intern=True): + """Create particles inside contour (Default : speed contour). Avoid creating too large numpy arrays, only to be masked + + :param step: step for particles + :type step: float + :param bool intern: If true use speed contour instead of effective contour + :return: lon, lat and indices of particles + :rtype: tuple(np.array) + """ + + xname, yname = self.intern(intern) + return create_meshed_particles(self[xname], self[yname], step) + + def empty_dataset(self): + return self.new_like(self, 0) + @njit(cache=True) def grid_count_(grid, i, j): + """ + Add 1 to each index + """ for i_, j_ in zip(i, j): grid[i_, j_] += 1 -class VirtualEddiesObservations(EddiesObservations): - """Class to work with virtual obs +@njit(cache=True) +def grid_count_pixel_in( + grid, + x, + y, + x_ref, + x_bounds, + y_bounds, + xstep, + ystep, + N, + is_circular, + x_size, + x_c, + y_c, +): + """ + Count how many times a pixel is used. 
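A small note on the `min` branch of `interp_grid` above: the minimum is obtained by feeding the negated grid to the `max` kernel and negating the result, which is the standard identity checked below.

```python
import numpy as np

grid = np.array([[3.0, -1.0], [7.0, 2.0]])
# min over the pixels == -(max over the negated grid)
assert -np.max(-grid) == np.min(grid) == -1.0
```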
+ + :param array grid: + :param array x: x for all contour + :param array y: y for all contour + :param array x_ref: x reference for wrapping + :param array x_bounds: grid longitude + :param array y_bounds: grid latitude + :param float xstep: step between two longitude + :param float ystep: step between two latitude + :param int N: shift of index to enlarge window + :param bool is_circular: To know if grid is wrappable + :param int x_size: Number of longitude + :param array x_c: longitude coordinate of grid + :param array y_c: latitude coordinate of grid + """ + nb = x_ref.shape[0] + for i_ in range(nb): + x_, y_, x_ref_ = x[i_], y[i_], x_ref[i_] + x_ = (x_ - x_ref_) % 360 + x_ref_ + x_, y_ = reduce_size(x_, y_) + v = create_vertice(x_, y_) + (x_start, x_stop), (y_start, y_stop) = bbox_indice_regular( + v, + x_bounds, + y_bounds, + xstep, + ystep, + N, + is_circular, + x_size, + ) + i, j = get_pixel_in_regular(v, x_c, y_c, x_start, x_stop, y_start, y_stop) + grid_count_(grid, i, j) + + +@njit(cache=True) +def grid_box_stat(x_c, y_c, grid, mask, x, y, value, circular=False, method=50): + """ + Compute method on each set (one set by box) + + :param array_like x_c: grid longitude coordinates + :param array_like y_c: grid latitude coordinates + :param array_like grid: grid to store the result + :param array[bool] mask: grid to store unused boxes + :param array_like x: longitude of observations + :param array_like y: latitude of observations + :param array_like value: value to group to apply method + :param bool circular: True if grid is wrappable + :param float method: percentile + """ + xstep, ystep = x_c[1] - x_c[0], y_c[1] - y_c[0] + x0, y0 = x_c[0] - xstep / 2.0, y_c[0] - ystep / 2.0 + nb_x = x_c.shape[0] + nb_y = y_c.shape[0] + i, j = ( + ((x - x0) // xstep).astype(numba_types.int32), + ((y - y0) // ystep).astype(numba_types.int32), + ) + if circular: + i %= nb_x + else: + if (i < 0).any(): + raise Exception("x indices underflow") + if (i >= nb_x).any(): + raise Exception("x indices overflow") + if (j < 0).any(): + raise Exception("y indices underflow") + if (j >= nb_y).any(): + raise Exception("y indices overflow") + abs_i = j * nb_x + i + k_sort = abs_i.argsort() + i0, j0 = i[k_sort[0]], j[k_sort[0]] + values = list() + for k in k_sort: + i_, j_ = i[k], j[k] + # group change + if i_ != i0 or j_ != j0: + # apply method and store result + grid[i_, j_] = percentile(values, method) + mask[i_, j_] = False + # start new group + i0, j0 = i_, j_ + # reset list + values.clear() + values.append(value[k]) + + +@njit(cache=True) +def grid_stat(x_c, y_c, grid, mask, x, y, result, circular=False, method="mean"): """ + Compute the mean or the max of the grid for each contour + + :param array_like x_c: the grid longitude coordinates + :param array_like y_c: the grid latitude coordinates + :param array_like grid: grid value + :param array[bool] mask: mask for invalid value + :param array_like x: longitude of contours + :param array_like y: latitude of contours + :param array_like result: return values + :param bool circular: True if grid is wrappable + :param str method: 'mean', 'max' + """ + # FIXME : how does it work on grid bound + nb = result.shape[0] + xstep, ystep = x_c[1] - x_c[0], y_c[1] - y_c[0] + x0, y0 = x_c - xstep / 2.0, y_c - ystep / 2.0 + nb_x = x_c.shape[0] + max_method = "max" == method + mean_method = "mean" == method + count_method = "count" == method + for elt in range(nb): + v = create_vertice(x[elt], y[elt]) + (x_start, x_stop), (y_start, y_stop) = bbox_indice_regular( + v, x0, y0, 
xstep, ystep, 1, circular, nb_x + ) + i, j = get_pixel_in_regular(v, x_c, y_c, x_start, x_stop, y_start, y_stop) + + if count_method: + result[elt] = i.shape[0] + elif mean_method: + v_sum = 0 + nb_ = 0 + for i_, j_ in zip(i, j): + if mask[i_, j_]: + continue + v_sum += grid[i_, j_] + nb_ += 1 + # FIXME : how does it work on grid bound, + if nb_ == 0: + result[elt] = nan + else: + result[elt] = v_sum / nb_ + elif max_method: + v_max = -1e40 + for i_, j_ in zip(i, j): + values = grid[i_, j_] + # FIXME must use mask + v_max = max(v_max, values) + result[elt] = v_max + + +class VirtualEddiesObservations(EddiesObservations): + """Class to work with virtual obs""" __slots__ = () @@ -1348,3 +2624,90 @@ def elements(self): elements = super().elements elements.extend(["track", "segment_size", "dlon", "dlat"]) return list(set(elements)) + + +@njit(cache=True) +def numba_where(mask): + """Usefull when mask is close to be empty""" + return where(mask) + + +@njit(cache=True) +def sum_row_column(mask): + """ + Compute sum on row and column at same time + """ + nb_x, nb_y = mask.shape + row_sum = zeros(nb_x, dtype=numba_types.int32) + column_sum = zeros(nb_y, dtype=numba_types.int32) + for i in range(nb_x): + for j in range(nb_y): + if mask[i, j]: + row_sum[i] += 1 + column_sum[j] += 1 + return row_sum, column_sum + + +@njit(cache=True) +def numba_digitize(values, bins): + # Check if bins are regular + nb_bins = bins.shape[0] + step = bins[1] - bins[0] + bin_previous = bins[1] + for i in range(2, nb_bins): + bin_current = bins[i] + if step != (bin_current - bin_previous): + # If bins are not regular + return digitize(values, bins) + bin_previous = bin_current + nb_values = values.shape[0] + out = empty(nb_values, dtype=numba_types.int64) + up, down = bins[0], bins[-1] + for i in range(nb_values): + v_ = values[i] + if v_ >= down: + out[i] = nb_bins + continue + if v_ < up: + out[i] = 0 + continue + out[i] = (v_ - bins[0]) / step + 1 + return out + + +@njit(cache=True) +def iter_mode_reduce(x, bins): + """ + Test if we could use a reduce mode + + :param array x: array to divide in group + :param array bins: array which defined bounds between each group + :return: If reduce mode, translator, and reduce x + """ + nb = x.shape[0] + # If we use less than half value + limit = nb // 2 + # low and up + x0, x1 = bins[0], bins[-1] + m = empty(nb, dtype=numba_types.bool_) + # To count number of value cover by bins + c = 0 + for i in range(nb): + x_ = x[i] + test = (x_ >= x0) * (x_ <= x1) + m[i] = test + if test: + c += 1 + # If number value exceed limit + if c > limit: + return False, empty(0, dtype=numba_types.int_), x + # Indices to be able to translate in full index array + indices = empty(c, dtype=numba_types.int_) + x_ = empty(c, dtype=x.dtype) + j = 0 + for i in range(nb): + if m[i]: + indices[j] = i + x_[j] = x[i] + j += 1 + return True, indices, x_ diff --git a/src/py_eddy_tracker/observations/tracking.py b/src/py_eddy_tracker/observations/tracking.py index 05098cd0..fa1c1f93 100644 --- a/src/py_eddy_tracker/observations/tracking.py +++ b/src/py_eddy_tracker/observations/tracking.py @@ -1,68 +1,51 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. 
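`numba_digitize` above only takes its shortcut when the bins are evenly spaced; the arithmetic of that fast path, checked against `numpy.digitize` on toy values:

```python
import numpy as np

values = np.array([-0.5, 0.0, 2.3, 7.9, 12.0])
bins = np.arange(0.0, 10.0, 1.0)                  # regular bins, step = 1

step = bins[1] - bins[0]
fast = np.empty(values.size, dtype=np.int64)
below, above = values < bins[0], values >= bins[-1]
fast[below], fast[above] = 0, bins.size
inside = ~(below | above)
fast[inside] = ((values[inside] - bins[0]) // step + 1).astype(np.int64)

assert (fast == np.digitize(values, bins)).all()  # same result, no search needed
```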
- - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . - -Copyright (c) 2014-2017 by Evan Mason and Antoine Delepoulle -Email: emason@imedea.uib-csic.es -=========================================================================== - -tracking.py - -Version 3.0.0 - -=========================================================================== - +Class to manage observations gathered in trajectories """ +from datetime import datetime, timedelta import logging + +from numba import njit from numpy import ( - empty, arange, - where, - unique, - interp, - ones, - bool_, - zeros, + arctan2, array, + bool_, + concatenate, + cos, + degrees, + empty, + histogram, + int_, median, + nan, + ones, + radians, + sin, + unique, + zeros, ) -from datetime import datetime, timedelta -from numba import njit -from Polygon import Polygon -from .observation import EddiesObservations -from .. import VAR_DESCR_inv -from ..generic import split_line, wrap_longitude, build_index -from ..poly import polygon_overlap, create_vertice_from_2darray +from .. import VAR_DESCR_inv, __version__ +from ..generic import build_index, cumsum_by_track, distance, split_line, wrap_longitude +from ..poly import bbox_intersection, merge, vertice_overlap +from .groups import GroupEddiesObservations, get_missing_indices logger = logging.getLogger("pet") -class TrackEddiesObservations(EddiesObservations): - """Class to practice Tracking on observations - """ +class TrackEddiesObservations(GroupEddiesObservations): + """Class to practice Tracking on observations""" - __slots__ = ("__obs_by_track", "__first_index_of_track") + __slots__ = ("__obs_by_track", "__first_index_of_track", "__nb_track") ELEMENTS = [ "lon", "lat", "radius_s", "radius_e", + "speed_area", + "effective_area", "amplitude", "speed_average", "time", @@ -83,54 +66,154 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__first_index_of_track = None self.__obs_by_track = None + self.__nb_track = None + + def track_slice(self, track): + i0 = self.index_from_track[track] + return slice(i0, i0 + self.nb_obs_by_track[track]) - def filled_by_interpolation(self, mask): - """Filled selected values by interpolation + def iter_track(self): """ - nb_filled = mask.sum() - logger.info("%d obs will be filled (unobserved)", nb_filled) + Yield track + """ + for i0, nb in zip(self.index_from_track, self.nb_obs_by_track): + if nb == 0: + continue + yield self.index(slice(i0, i0 + nb)) + def get_missing_indices(self, dt): + """Find indices where observations are missing. + + :param int,float dt: theorical delta time between 2 observations + """ + return get_missing_indices( + self.time, + self.track, + dt=dt, + flag_untrack=False, + indice_untrack=self.NOGROUP, + ) + + def fix_next_previous_obs(self): + """Function used after 'insert_virtual', to correct next_obs and + previous obs. 
+ """ + + pass + + @property + def nb_tracks(self): + """ + Count and return number of track + """ + if self.__nb_track is None: + if len(self) == 0: + self.__nb_track = 0 + else: + self.__nb_track = (self.nb_obs_by_track != 0).sum() + return self.__nb_track + + def __repr__(self): + content = super().__repr__() + t0, t1 = self.period + period = t1 - t0 + 1 + nb = self.nb_obs_by_track nb_obs = len(self) - index = arange(nb_obs) + m = self.virtual.astype("bool") + nb_m = m.sum() + bins_t = (1, 30, 90, 180, 270, 365, 1000, 10000) + nb_tracks_by_t = histogram(nb, bins=bins_t)[0] + nb_obs_by_t = histogram(nb, bins=bins_t, weights=nb)[0] + pct_tracks_by_t = nb_tracks_by_t / nb_tracks_by_t.sum() * 100.0 + pct_obs_by_t = nb_obs_by_t / nb_obs_by_t.sum() * 100.0 + d = self.distance_to_next() / 1000.0 + cum_d = cumsum_by_track(d, self.tracks) + m_last = ones(d.shape, dtype="bool") + m_last[-1] = False + m_last[self.index_from_track[1:] - 1] = False + content += f""" + | {self.nb_tracks} tracks ({ + nb_obs / self.nb_tracks:.2f} obs/tracks, shorter {nb[nb!=0].min()} obs, longer {nb.max()} obs) + | {nb_m} filled observations ({nb_m / self.nb_tracks:.2f} obs/tracks, {nb_m / nb_obs * 100:.2f} % of total) + | Intepolated speed area : {self.speed_area[m].sum() / period / 1e12:.2f} Mkm²/day + | Intepolated effective area : {self.effective_area[m].sum() / period / 1e12:.2f} Mkm²/day + | Distance by day : Mean {d[m_last].mean():.2f} , Median {median(d[m_last]):.2f} km/day + | Distance by track : Mean {cum_d[~m_last].mean():.2f} , Median {median(cum_d[~m_last]):.2f} km/track + ----Distribution in lifetime: + | Lifetime (days ) {self.box_display(bins_t)} + | Percent of tracks : {self.box_display(pct_tracks_by_t)} + | Percent of eddies : {self.box_display(pct_obs_by_t)}""" + return content + + def add_distance(self): + """Add a field of distance (m) between two consecutive observations, 0 for the last observation of each track""" + if "distance_next" in self.fields: + return self + new = self.add_fields(("distance_next",)) + new["distance_next"][:1] = self.distance_to_next() + return new - for field in self.obs.dtype.descr: - var = field[0] - if ( - var in ["n", "virtual", "track", "cost_association"] - or var in self.array_variables - ): - continue - # to normalize longitude before interpolation - if var == "lon": - lon = self.obs[var] - first = where(self.obs["n"] == 0)[0] - nb_obs = empty(first.shape, dtype="u4") - nb_obs[:-1] = first[1:] - first[:-1] - nb_obs[-1] = lon.shape[0] - first[-1] - lon0 = (lon[first] - 180).repeat(nb_obs) - self.obs[var] = (lon - lon0) % 360 + lon0 - self.obs[var][mask] = interp( - index[mask], index[~mask], self.obs[var][~mask] - ) + def distance_to_next(self): + """ + :return: array of distance in m, 0 when next obs is from another track + :rtype: array + """ + d = distance( + self.longitude[:-1], + self.latitude[:-1], + self.longitude[1:], + self.latitude[1:], + ) + d[self.index_from_track[1:] - 1] = 0 + d_ = empty(d.shape[0] + 1, dtype=d.dtype) + d_[:-1] = d + d_[-1] = 0 + return d_ + + def normalize_longitude(self): + """Normalize all longitudes + + Normalize longitude field and in the same range : + - longitude_max + - contour_lon_e (how to do if in raw) + - contour_lon_s (how to do if in raw) + """ + if self.lon.size == 0: + return + lon0 = (self.lon[self.index_from_track] - 180).repeat(self.nb_obs_by_track) + logger.debug("Normalize longitude") + self.lon[:] = (self.lon - lon0) % 360 + lon0 + if "lon_max" in self.fields: + logger.debug("Normalize longitude_max") + 
self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180 + if not self.raw_data: + if "contour_lon_e" in self.fields: + logger.debug("Normalize effective contour longitude") + self.contour_lon_e[:] = ( + (self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180 + ).T + if "contour_lon_s" in self.fields: + logger.debug("Normalize speed contour longitude") + self.contour_lon_s[:] = ( + (self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180 + ).T def extract_longer_eddies(self, nb_min, nb_obs, compress_id=True): - """Select eddies which are longer than nb_min - """ + """Select the trajectories longer than nb_min""" mask = nb_obs >= nb_min nb_obs_select = mask.sum() logger.info("Selection of %d observations", nb_obs_select) eddies = self.__class__.new_like(self, nb_obs_select) eddies.sign_type = self.sign_type - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - eddies.obs[var] = self.obs[var][mask] + eddies.obs[field] = self.obs[field][mask] if compress_id: - list_id = unique(eddies.obs["track"]) + list_id = unique(eddies.obs.track) list_id.sort() id_translate = arange(list_id.max() + 1) id_translate[list_id] = arange(len(list_id)) + 1 - eddies.obs["track"] = id_translate[eddies.obs["track"]] + eddies.track = id_translate[eddies.track] return eddies @property @@ -140,49 +223,35 @@ def elements(self): return list(set(elements)) def set_global_attr_netcdf(self, h_nc): - """Set global attr - """ + """Set global attributes""" h_nc.title = "Cyclonic" if self.sign_type == -1 else "Anticyclonic" h_nc.Metadata_Conventions = "Unidata Dataset Discovery v1.0" h_nc.comment = "Surface product; mesoscale eddies" h_nc.framework_used = "https://github.com/AntSimi/py-eddy-tracker" + h_nc.framework_version = __version__ h_nc.standard_name_vocabulary = ( "NetCDF Climate and Forecast (CF) Metadata Convention Standard Name Table" ) h_nc.date_created = datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ") t = h_nc.variables[VAR_DESCR_inv["j1"]] - delta = t.max - t.min + 1 - h_nc.time_coverage_duration = "P%dD" % delta - d_start = datetime(1950, 1, 1) + timedelta(int(t.min)) - d_end = datetime(1950, 1, 1) + timedelta(int(t.max)) - h_nc.time_coverage_start = d_start.strftime("%Y-%m-%dT00:00:00Z") - h_nc.time_coverage_end = d_end.strftime("%Y-%m-%dT00:00:00Z") - - def extract_with_area(self, area, **kwargs): - """ - Extract with a bounding box - Args: - area: 4 coordinates in a dictionary to specify bounding box (lower left corner and upper right corner) - **kwargs: - - Returns: - - """ - mask = (self.latitude > area["llcrnrlat"]) * (self.latitude < area["urcrnrlat"]) - lon0 = area["llcrnrlon"] - lon = (self.longitude - lon0) % 360 + lon0 - mask *= (lon > lon0) * (lon < area["urcrnrlon"]) - return self.__extract_with_mask(mask, **kwargs) + if t.size: + delta = t.max - t.min + 1 + h_nc.time_coverage_duration = "P%dD" % delta + d_start = datetime(1950, 1, 1) + timedelta(int(t.min)) + d_end = datetime(1950, 1, 1) + timedelta(int(t.max)) + h_nc.time_coverage_start = d_start.strftime("%Y-%m-%dT00:00:00Z") + h_nc.time_coverage_end = d_end.strftime("%Y-%m-%dT00:00:00Z") def extract_with_period(self, period, **kwargs): """ - Extract with a period - Args: - period: two date to define period, must be specify from 1/1/1950 - **kwargs: directly give to __extract_with_mask + Extract within a time period + + :param (int,int) period: two dates to define the period, must be specified from 1/1/1950 + :param dict kwargs: look at 
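`normalize_longitude` above keeps each trajectory continuous by referencing it to its first longitude; the core of the wrap on a toy track crossing the antimeridian:

```python
import numpy as np

lon = np.array([179.0, -179.0, -177.0, 10.0, 12.0])   # track 1 jumps by 360 degrees
first_index = np.array([0, 3])                         # first obs of each track
nb_obs = np.array([3, 2])                              # obs count per track

lon0 = (lon[first_index] - 180).repeat(nb_obs)
print((lon - lon0) % 360 + lon0)   # [179. 181. 183.  10.  12.] -> continuous
```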
:py:meth:`extract_with_mask` + :return: Return all eddy tracks in period + :rtype: TrackEddiesObservations - Returns: - same object with selected data + .. minigallery:: py_eddy_tracker.TrackEddiesObservations.extract_with_period """ dataset_period = self.period p_min, p_max = period @@ -196,15 +265,32 @@ def extract_with_period(self, period, **kwargs): mask *= self.time <= p_max elif p_max < 0: mask *= self.time <= (dataset_period[1] + p_max) - return self.__extract_with_mask(mask, **kwargs) + return self.extract_with_mask(mask, **kwargs) - @property - def period(self): + def get_azimuth(self, equatorward=False): """ - Give time coverage - Returns: 2 date + Return azimuth for each track. + + Azimuth is computed with first and last observations + + :param bool equatorward: If True, Poleward is positive and Equatorward negative + :rtype: array """ - return self.time.min(), self.time.max() + i0, nb = self.index_from_track, self.nb_obs_by_track + i0 = i0[nb != 0] + i1 = i0 - 1 + nb[nb != 0] + lat0, lon0 = self.latitude[i0], self.longitude[i0] + lat1, lon1 = self.latitude[i1], self.longitude[i1] + lat0, lon0 = radians(lat0), radians(lon0) + lat1, lon1 = radians(lat1), radians(lon1) + dlon = lon1 - lon0 + x = cos(lat0) * sin(lat1) - sin(lat0) * cos(lat1) * cos(dlon) + y = sin(dlon) * cos(lat1) + azimuth = degrees(arctan2(y, x)) + 90 + if equatorward: + south = lat0 < 0 + azimuth[south] *= -1 + return azimuth def get_mask_from_id(self, tracks): mask = zeros(self.tracks.shape, dtype=bool_) @@ -212,9 +298,12 @@ def get_mask_from_id(self, tracks): return mask def compute_index(self): + """ + If obs are not sorted by track, __first_index_of_track will be unusable + """ if self.__first_index_of_track is None: s = self.tracks.max() + 1 - # Doesn't work => core dump with numba, maybe he wait i8 instead of u4 + # Doesn't work => core dump with numba, maybe he wants i8 instead of u4 # self.__first_index_of_track = -ones(s, self.tracks.dtype) # self.__obs_by_track = zeros(s, self.observation_number.dtype) self.__first_index_of_track = -ones(s, "i8") @@ -223,6 +312,33 @@ def compute_index(self): compute_index(self.tracks, self.__first_index_of_track, self.__obs_by_track) logger.debug("... 
OK") + @classmethod + def concatenate(cls, observations): + eddies = super().concatenate(observations) + last_track = 0 + i_start = 0 + for obs in observations: + nb_obs = len(obs) + sl = slice(i_start, i_start + nb_obs) + new_track = obs.track + last_track + eddies.track[sl] = new_track + last_track = new_track.max() + 1 + i_start += nb_obs + return eddies + + def count_by_track(self, mask): + """ + Count by track + + :param array[bool] mask: Mask of boolean count +1 if true + :return: Return count by track + :rtype: array + """ + s = self.tracks.max() + 1 + obs_by_track = zeros(s, "i4") + count_by_track(self.tracks, mask, obs_by_track) + return obs_by_track + @property def index_from_track(self): self.compute_index() @@ -233,21 +349,50 @@ def nb_obs_by_track(self): self.compute_index() return self.__obs_by_track + @property + def lifetime(self): + """Return lifetime for each observation""" + return self.nb_obs_by_track.repeat(self.nb_obs_by_track) + + @property + def age(self): + """Return age in % for each observation, will be [0:100]""" + return self.n.astype("f4") / (self.lifetime - 1) * 100.0 + def extract_ids(self, tracks): mask = self.get_mask_from_id(array(tracks)) - return self.__extract_with_mask(mask) + return self.extract_with_mask(mask) + + def extract_toward_direction(self, west=True, delta_lon=None): + """ + Get trajectories going in the same direction + + :param bool west: Only eastward eddies if True return westward + :param None,float delta_lon: Only eddies with more than delta_lon span in longitude + :return: Only eastern eddy + :rtype: __class__ + + .. minigallery:: py_eddy_tracker.TrackEddiesObservations.extract_toward_direction + """ + lon = self.longitude + i0, nb = self.index_from_track, self.nb_obs_by_track + i1 = i0 - 1 + nb + d_lon = lon[i1] - lon[i0] + m = d_lon < 0 if west else d_lon > 0 + if delta_lon is not None: + m *= delta_lon < abs(d_lon) + m = m.repeat(nb) + return self.extract_with_mask(m) def extract_first_obs_in_box(self, res): - data = empty( - self.obs.shape, dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")] - ) + data = empty(len(self), dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")]) data["lon"] = self.longitude - self.longitude % res data["lat"] = self.latitude - self.latitude % res - data["track"] = self.obs["track"] + data["track"] = self.track _, indexs = unique(data, return_index=True) - mask = zeros(self.obs.shape, dtype="bool") + mask = zeros(len(self), dtype="bool") mask[indexs] = True - return self.__extract_with_mask(mask) + return self.extract_with_mask(mask) def extract_in_direction(self, direction, value=0): nb_obs = self.nb_obs_by_track @@ -264,38 +409,50 @@ def extract_in_direction(self, direction, value=0): mask = d_lon < 0 if "W" == direction else d_lon > 0 mask &= abs(d_lon) > value mask = mask.repeat(nb_obs) - return self.__extract_with_mask(mask) + return self.extract_with_mask(mask) def extract_with_length(self, bounds): + """ + Return the observations within trajectories lasting between [b0:b1] + + :param (int,int) bounds: length min and max of the desired trajectories, if -1 this bound is not used + :return: Return all trajectories having length between bounds + :rtype: TrackEddiesObservations + + .. 
minigallery:: py_eddy_tracker.TrackEddiesObservations.extract_with_length + """ + if len(self) == 0: + return self.empty_dataset() b0, b1 = bounds - if b0 >= 0 and b1 >= 0: + if b0 >= 0 and b1 != -1: track_mask = (self.nb_obs_by_track >= b0) * (self.nb_obs_by_track <= b1) - elif b0 < 0 and b1 >= 0: + elif b0 == -1 and b1 >= 0: track_mask = self.nb_obs_by_track <= b1 - elif b0 >= 0 and b1 < 0: - track_mask = self.nb_obs_by_track > b0 + elif b0 >= 0 and b1 == -1: + track_mask = self.nb_obs_by_track >= b0 else: logger.warning("No valid value for bounds") - raise Exception("One bounds must be positiv") - return self.__extract_with_mask(track_mask.repeat(self.nb_obs_by_track)) + raise Exception("One bound must be positive") + return self.extract_with_mask(track_mask.repeat(self.nb_obs_by_track)) def loess_filter(self, half_window, xfield, yfield, inplace=True): - track = self.obs["track"] + track = self.track x = self.obs[xfield] y = self.obs[yfield] result = track_loess_filter(half_window, x, y, track) if inplace: self.obs[yfield] = result return self + return result def median_filter(self, half_window, xfield, yfield, inplace=True): - track = self.obs["track"] - x = self.obs[xfield] - y = self.obs[yfield] - result = track_median_filter(half_window, x, y, track) + result = track_median_filter( + half_window, self[xfield], self[yfield], self.track + ) if inplace: - self.obs[yfield] = result + self[yfield][:] = result return self + return result def position_filter(self, median_half_window, loess_half_window): self.median_filter(median_half_window, "time", "lon").loess_filter( @@ -305,7 +462,7 @@ def position_filter(self, median_half_window, loess_half_window): loess_half_window, "time", "lat" ) - def __extract_with_mask( + def extract_with_mask( self, mask, full_path=False, @@ -315,15 +472,14 @@ def __extract_with_mask( ): """ Extract a subset of observations - Args: - mask: mask to select observations - full_path: extract full path if only one part is selected - remove_incomplete: delete path which are not fully selected - compress_id: resample track number to use a little range - reject_virtual: if track are only virtual in selection we remove track - Returns: - same object with selected observations + :param array(bool) mask: mask to select observations + :param bool full_path: extract the full trajectory if only one part is selected + :param bool remove_incomplete: delete trajectory if not fully selected + :param bool compress_id: resample trajectory number to use a smaller range + :param bool reject_virtual: if only virtuals are selected, the trajectory is removed + :return: same object with the selected observations + :rtype: self.__class__ """ if full_path and remove_incomplete: logger.warning( @@ -333,7 +489,7 @@ def __extract_with_mask( if full_path: if reject_virtual: - mask *= ~self.obs["virtual"].astype("bool") + mask *= ~self.virtual.astype("bool") tracks = unique(self.tracks[mask]) mask = self.get_mask_from_id(tracks) elif remove_incomplete: @@ -344,107 +500,196 @@ def __extract_with_mask( new = self.__class__.new_like(self, nb_obs) new.sign_type = self.sign_type if nb_obs == 0: - logger.warning("Empty dataset will be created") + logger.info("Empty dataset will be created") else: - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var][mask] + new.obs[field] = self.obs[field][mask] if compress_id: - list_id = unique(new.obs["track"]) + list_id = unique(new.track) list_id.sort() 
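The bound convention of `extract_with_length` above (-1 disables a bound) can be summarised with a small helper; the last line recalls that the per-track mask is then expanded per observation with `repeat`:

```python
import numpy as np

def track_mask(nb_obs_by_track, b0, b1):
    """Mirror the bound handling of extract_with_length: -1 disables a bound."""
    if b0 >= 0 and b1 != -1:
        return (nb_obs_by_track >= b0) & (nb_obs_by_track <= b1)
    if b0 == -1 and b1 >= 0:
        return nb_obs_by_track <= b1
    if b0 >= 0 and b1 == -1:
        return nb_obs_by_track >= b0
    raise ValueError("One bound must be positive")

nb = np.array([5, 40, 400, 3])
print(track_mask(nb, 10, 100))   # [False  True False False]
print(track_mask(nb, 10, -1))    # [False  True  True False]
obs_mask = track_mask(nb, 10, -1).repeat(nb)   # one flag per observation
```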
id_translate = arange(list_id.max() + 1) id_translate[list_id] = arange(len(list_id)) + 1 - new.obs["track"] = id_translate[new.obs["track"]] + new.track = id_translate[new.track] return new + def shape_polygon(self, intern=False): + """ + Get the polygon enclosing each trajectory. + + The polygon merges the non-overlapping bounds of the specified contours + + :param bool intern: If True use speed contour instead of effective contour + :rtype: list(array, array) + """ + xname, yname = self.intern(intern) + return [merge(track[xname], track[yname]) for track in self.iter_track()] + + def display_shape(self, ax, ref=None, intern=False, **kwargs): + """ + This function draws the shape of each trajectory + + :param matplotlib.axes.Axes ax: ax to draw + :param float,int ref: if defined, all coordinates are wrapped with ref as western boundary + :param bool intern: If True use speed contour instead of effective contour + :param dict kwargs: keyword arguments for Axes.plot + :return: matplotlib mappable + """ + if "label" in kwargs: + kwargs["label"] = self.format_label(kwargs["label"]) + if len(self) == 0: + x, y = [], [] + else: + polygons = self.shape_polygon(intern) + x, y = list(), list() + for p_ in polygons: + x.append((nan,)) + y.append((nan,)) + x.append(p_[0]) + y.append(p_[1]) + x, y = concatenate(x), concatenate(y) + if ref is not None: + x, y = wrap_longitude(x, y, ref, cut=True) + return ax.plot(x, y, **kwargs) + + def close_tracks(self, other, nb_obs_min=10, **kwargs): + """ + Get close trajectories from another atlas. + + :param self other: Atlas to compare + :param int nb_obs_min: Minimal number of overlap for one trajectory + :param dict kwargs: keyword arguments for match function + :return: return other atlas reduced to common trajectories with self + + .. 
warning:: + It could be a costly operation for huge dataset + """ + p0, p1 = self.period + p0_other, p1_other = other.period + if p1_other < p0 or p1 < p0_other: + return other.__class__.new_like(other, 0) + indexs = list() + for i_self, i_other, t0, t1 in self.align_on(other, bins=arange(p0, p1 + 2)): + i, j, s = self.match(other, i_self=i_self, i_other=i_other, **kwargs) + indexs.append(other.re_reference_index(j, i_other)) + indexs = concatenate(indexs) + tr, nb = unique(other.track[indexs], return_counts=True) + return other.extract_ids(tr[nb >= nb_obs_min]) + + def format_label(self, label): + t0, t1 = self.period + return label.format( + t0=t0, + t1=t1, + nb_obs=len(self), + nb_tracks=(self.nb_obs_by_track != 0).sum(), + ) + def plot(self, ax, ref=None, **kwargs): + """ + This function will draw path of each trajectory + + :param matplotlib.axes.Axes ax: ax to draw + :param float,int ref: if defined, all coordinates are wrapped with ref as western boundary + :param dict kwargs: keyword arguments for Axes.plot + :return: matplotlib mappable + """ if "label" in kwargs: - kwargs["label"] += " (%s eddies)" % (self.nb_obs_by_track != 0).sum() - x, y = split_line(self.longitude, self.latitude, self.tracks) - if ref is not None: - x, y = wrap_longitude(x, y, ref, cut=True) + kwargs["label"] = self.format_label(kwargs["label"]) + if len(self) == 0: + x, y = [], [] + else: + x, y = split_line(self.longitude, self.latitude, self.tracks) + if ref is not None: + x, y = wrap_longitude(x, y, ref, cut=True) return ax.plot(x, y, **kwargs) def split_network(self, intern=True, **kwargs): - """Divide each group in track - """ + """Return each group (network) divided in segments""" + # Find timestep of dataset + # FIXME : how to know exact time sampling + t = unique(self.time) + dts = t[1:] - t[:-1] + timestep = median(dts) + track_s, track_e, track_ref = build_index(self.tracks) ids = empty( len(self), dtype=[ ("group", self.tracks.dtype), ("time", self.time.dtype), - ("track", "u2"), + ("track", "u4"), ("previous_cost", "f4"), ("next_cost", "f4"), ("previous_obs", "i4"), ("next_obs", "i4"), ], ) - ids["group"], ids["time"] = self.tracks, self.time - # To store id track + ids["group"], ids["time"] = self.tracks, int_(self.time / timestep) + # Initialisation + # To store the id of the segments, the backward and forward cost associations ids["track"], ids["previous_cost"], ids["next_cost"] = 0, 0, 0 + # To store the indices of the backward and forward observations associated ids["previous_obs"], ids["next_obs"] = -1, -1 + # At the end, ids["previous_obs"] == -1 means the start of a non-split segment + # and ids["next_obs"] == -1 means the end of a non-merged segment xname, yname = self.intern(intern) + display_iteration = logger.getEffectiveLevel() == logging.INFO for i_s, i_e in zip(track_s, track_e): if i_s == i_e or self.tracks[i_s] == self.NOGROUP: continue + if display_iteration: + print(f"Network obs from {i_s} to {i_e} on {track_e[-1]}", end="\r") sl = slice(i_s, i_e) local_ids = ids[sl] + # built segments with local indices self.set_tracks(self[xname][sl], self[yname][sl], local_ids, **kwargs) - m = local_ids["previous_obs"] == -1 + # shift the local indices to the total indexation for the used observations + m = local_ids["previous_obs"] != -1 local_ids["previous_obs"][m] += i_s - m = local_ids["next_obs"] == -1 + m = local_ids["next_obs"] != -1 local_ids["next_obs"][m] += i_s + if display_iteration: + print() + ids["time"] *= timestep return ids - # ids_sort = ids[new_i] - # # To be able to 
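To make the returned structure easier to picture: each observation of a network ends up with a segment id and forward/backward links. A hand-built sketch of the same record layout with invented values (the real array is produced by `split_network`, not built like this):

```python
import numpy as np

# Invented 'ids'-like array mimicking the dtype built in split_network
ids = np.zeros(4, dtype=[("group", "u4"), ("time", "i4"), ("track", "u4"),
                         ("previous_cost", "f4"), ("next_cost", "f4"),
                         ("previous_obs", "i4"), ("next_obs", "i4")])
ids["track"] = [1, 1, 1, 2]
ids["time"] = [0, 1, 2, 1]
ids["previous_obs"] = [-1, 0, 1, -1]   # -1: segment start that is not a splitting
ids["next_obs"] = [1, 2, -1, -1]       # -1: segment end that is not a merging

def walk_segment(ids, i0):
    """Follow one segment forward from observation i0 through next_obs."""
    chain = [i0]
    while ids["next_obs"][chain[-1]] != -1:
        chain.append(int(ids["next_obs"][chain[-1]]))
    return chain

print(walk_segment(ids, 0))  # [0, 1, 2]
```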
follow indices sorting - # reverse_sort = empty(new_i.shape[0], dtype="u4") - # reverse_sort[new_i] = arange(new_i.shape[0]) - # # Redirect indices - # m = ids_sort["next_obs"] != -1 - # ids_sort["next_obs"][m] = reverse_sort[ - # ids_sort["next_obs"][m] - # ] - # m = ids_sort["previous_obs"] != -1 - # ids_sort["previous_obs"][m] = reverse_sort[ - # ids_sort["previous_obs"][m] - # ] - # # print(ids_sort) - # display_network( - # x[new_i], - # y[new_i], - # ids_sort["track"], - # ids_sort["time"], - # ids_sort["next_cost"], - # ) - - def set_tracks(self, x, y, ids, window): - # Will split one group in tracks - time_index = build_index(ids["time"]) + + def set_tracks(self, x, y, ids, window, **kwargs): + """ + Split one group (network) in segments + + :param array x: coordinates of group + :param array y: coordinates of group + :param ndarray ids: several fields like time, group, ... + :param int window: number of days where observations could missed + """ + time_index = build_index((ids["time"]).astype("i4")) nb = x.shape[0] used = zeros(nb, dtype="bool") track_id = 1 - # build all polygon (need to check if wrap is needed) - polygons = [Polygon(create_vertice_from_2darray(x, y, i)) for i in range(nb)] + # build all polygons (need to check if wrap is needed) for i in range(nb): - # If observation already in one track, we go to the next one + # If the observation is already in one track, we go to the next one if used[i]: continue - self.follow_obs(i, track_id, used, ids, polygons, *time_index, window) + # Search a possible continuation (forward) + self.follow_obs(i, track_id, used, ids, x, y, *time_index, window, **kwargs) track_id += 1 + # Search a possible ancestor (backward) + self.get_previous_obs(i, ids, x, y, *time_index, window, **kwargs) @classmethod - def follow_obs(cls, i_next, track_id, used, ids, *args): + def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs): + """Associate the observations to the segments""" + while i_next != -1: # Flag used[i_next] = True # Assign id ids["track"][i_next] = track_id # Search next - i_next_ = cls.next_obs(i_next, ids, *args) + i_next_ = cls.get_next_obs(i_next, ids, *args, **kwargs) if i_next_ == -1: break ids["next_obs"][i_next] = i_next_ @@ -460,9 +705,61 @@ def follow_obs(cls, i_next, track_id, used, ids, *args): i_next = i_next_ @staticmethod - def next_obs(i_current, ids, polygons, time_s, time_e, time_ref, window): + def get_previous_obs( + i_current, + ids, + x, + y, + time_s, + time_e, + time_ref, + window, + min_overlap=0.2, + **kwargs, + ): + """Backward association of observations to the segments""" + time_cur = int_(ids["time"][i_current]) + t0, t1 = time_cur - 1 - time_ref, max(time_cur - window - time_ref, 0) + for t_step in range(t0, t1 - 1, -1): + i0, i1 = time_s[t_step], time_e[t_step] + # No observation at the time step + if i0 == i1: + continue + # Search for overlaps + xi, yi, xj, yj = x[[i_current]], y[[i_current]], x[i0:i1], y[i0:i1] + ii, ij = bbox_intersection(xi, yi, xj, yj) + if len(ii) == 0: + continue + c = zeros(len(xj)) + c[ij] = vertice_overlap( + xi[ii], yi[ii], xj[ij], yj[ij], min_overlap=min_overlap, **kwargs + ) + # We get index of maximal overlap + i = c.argmax() + c_i = c[i] + # No overlap found + if c_i == 0: + continue + ids["previous_cost"][i_current] = c_i + ids["previous_obs"][i_current] = i0 + i + break + + @staticmethod + def get_next_obs( + i_current, + ids, + x, + y, + time_s, + time_e, + time_ref, + window, + min_overlap=0.2, + **kwargs, + ): + """Forward association of observations to 
the segments""" time_max = time_e.shape[0] - 1 - time_cur = ids["time"][i_current] + time_cur = int_(ids["time"][i_current]) t0, t1 = time_cur + 1 - time_ref, min(time_cur + window - time_ref, time_max) if t0 > time_max: return -1 @@ -471,10 +768,15 @@ def next_obs(i_current, ids, polygons, time_s, time_e, time_ref, window): # No observation at the time step if i0 == i1: continue - # Intersection / union, to be able to separte in case of multiple inside - c = polygon_overlap(polygons[i_current], polygons[i0:i1]) - # We remove low overlap - c[c < 0.1] = 0 + # Search for overlaps + xi, yi, xj, yj = x[[i_current]], y[[i_current]], x[i0:i1], y[i0:i1] + ii, ij = bbox_intersection(xi, yi, xj, yj) + if len(ii) == 0: + continue + c = zeros(len(xj)) + c[ij] = vertice_overlap( + xi[ii], yi[ii], xj[ij], yj[ij], min_overlap=min_overlap, **kwargs + ) # We get index of maximal overlap i = c.argmax() c_i = c[i] @@ -501,6 +803,13 @@ def compute_index(tracks, index, number): previous_track = track +@njit(cache=True) +def count_by_track(tracks, mask, number): + for track, test in zip(tracks, mask): + if test: + number[track] += 1 + + @njit(cache=True) def compute_mask_from_id(tracks, first_index, number_of_obs, mask): for track in tracks: @@ -511,13 +820,14 @@ def compute_mask_from_id(tracks, first_index, number_of_obs, mask): def track_loess_filter(half_window, x, y, track): """ Apply a loess filter on y field - Args: - window: parameter of smoother - x: must be growing for each track but could be irregular - y: field to smooth - track: field which allow to separate path - Returns: + :param int,float half_window: parameter of smoother + :param array_like x: must be growing for each track but could be irregular + :param array_like y: field to smooth + :param array_like track: field that allows to separate path + + :return: Array smoothed + :rtype: array_like """ nb = y.shape[0] @@ -554,14 +864,15 @@ def track_loess_filter(half_window, x, y, track): @njit(cache=True) def track_median_filter(half_window, x, y, track): """ - Apply a loess filter on y field - Args: - window: parameter of smoother - x: must be growing for each track but could be irregular - y: field to smooth - track: field which allow to separate path + Apply a median filter on y field + + :param int,float half_window: parameter of smoother + :param array_like x: must be growing for each track but could be irregular + :param array_like y: field to smooth + :param array_like track: field which allow to separate path - Returns: + :return: Array smoothed + :rtype: array_like """ nb = y.shape[0] diff --git a/src/py_eddy_tracker/poly.py b/src/py_eddy_tracker/poly.py index 3752f201..b5849610 100644 --- a/src/py_eddy_tracker/poly.py +++ b/src/py_eddy_tracker/poly.py @@ -1,41 +1,44 @@ # -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . 
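The forward/backward association above boils down to a bounding-box pruning followed by an exact intersection-over-union cost, both provided by `py_eddy_tracker.poly` later in this diff. A minimal sketch of that step on two invented contour lists, assuming the package and its `Polygon` dependency are installed:

```python
import numpy as np
from py_eddy_tracker.poly import bbox_intersection, vertice_overlap

# One contour at time t, two candidate contours at t+1 (one row per closed polygon, invented)
x0 = np.array([[0.0, 1.0, 1.0, 0.0, 0.0]])
y0 = np.array([[0.0, 0.0, 1.0, 1.0, 0.0]])
x1 = np.array([[0.5, 1.5, 1.5, 0.5, 0.5],
               [5.0, 6.0, 6.0, 5.0, 5.0]])
y1 = np.array([[0.0, 0.0, 1.0, 1.0, 0.0],
               [5.0, 5.0, 6.0, 6.0, 5.0]])

# 1. Cheap pruning: index pairs whose bounding boxes overlap
i, j = bbox_intersection(x0, y0, x1, y1)
# 2. Exact cost (intersection / union) for the surviving pairs only
cost = np.zeros(x1.shape[0])
cost[j] = vertice_overlap(x0[i], y0[i], x1[j], y1[j], min_overlap=0.2)
# 3. Keep the best candidate, as get_previous_obs / get_next_obs do
print(cost.argmax(), cost.max())  # expect candidate 0 with a cost around 1/3
```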
- -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== +Method for polygon """ -from numpy import empty, where, array -from numba import njit, prange, types as numba_types +import heapq + from Polygon import Polygon +from numba import njit, prange, types as numba_types +from numpy import arctan, array, concatenate, empty, nan, ones, pi, where, zeros +from numpy.linalg import lstsq + +from .generic import build_index @njit(cache=True) -def is_left(x_line_0, y_line_0, x_line_1, y_line_1, x_test, y_test): +def is_left( + x_line_0: float, + y_line_0: float, + x_line_1: float, + y_line_1: float, + x_test: float, + y_test: float, +) -> bool: """ + Test if point is left of an infinit line. + http://geomalgorithms.com/a03-_inclusion.html - isLeft(): tests if a point is Left|On|Right of an infinite line. - Input: three points P0, P1, and P2 - Return: >0 for P2 left of the line through P0 and P1 - =0 for P2 on the line - <0 for P2 right of the line See: Algorithm 1 "Area of Triangles and Polygons" + + :param float x_line_0: + :param float y_line_0: + :param float x_line_1: + :param float y_line_1: + :param float x_test: + :param float y_test: + :return: > 0 for P2 left of the line through P0 and P1 + = 0 for P2 on the line + < 0 for P2 right of the line + :rtype: bool + """ # Vector product product = (x_line_1 - x_line_0) * (y_test - y_line_0) - (x_test - x_line_0) * ( @@ -46,6 +49,14 @@ def is_left(x_line_0, y_line_0, x_line_1, y_line_1, x_test, y_test): @njit(cache=True) def poly_contain_poly(xy_poly_out, xy_poly_in): + """ + Check if poly_in is include in poly_out. + + :param vertice xy_poly_out: + :param vertice xy_poly_in: + :return: True if poly_in is in poly_out + :rtype: bool + """ nb_elt = xy_poly_in.shape[0] x = xy_poly_in[:, 0] x_ref = xy_poly_out[0, 0] @@ -60,21 +71,146 @@ def poly_contain_poly(xy_poly_out, xy_poly_in): @njit(cache=True) -def poly_area(vertice): - p_area = 0 - nb_elt = vertice.shape[0] - for i_elt in range(1, nb_elt - 1): - p_area += vertice[i_elt, 0] * (vertice[1 + i_elt, 1] - vertice[i_elt - 1, 1]) +def poly_area_vertice(v): + """ + Compute area from vertice. + + :param vertice v: polygon vertice + :return: area of polygon in coordinates unit + :rtype: float + """ + return poly_area(v[:, 0], v[:, 1]) + + +@njit(cache=True) +def poly_area(x, y): + """ + Must be called with local coordinates (in m, to get an area in m²). 
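For readers new to the geometry helpers: `is_left` is the primitive behind the winding-number and convexity tests below. A one-line sketch with invented coordinates, assuming the package is installed:

```python
from py_eddy_tracker.poly import is_left

# Is (0, 1) on the left of the line going from (0, 0) to (1, 1)?
# A positive (truthy) result means "left", zero "on the line", negative "right".
print(is_left(0.0, 0.0, 1.0, 1.0, 0.0, 1.0))
```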
+ + :param array x: + :param array y: + :return: area of polygon in coordinates unit + :rtype: float + """ + p_area = x[0] * (y[1] - y[-2]) + nb = x.shape[0] + for i in range(1, nb - 1): + p_area += x[i] * (y[1 + i] - y[i - 1]) return abs(p_area) * 0.5 +@njit(cache=True) +def convexs(x, y): + """ + Check if polygons are convex + + :param array[float] x: + :param array[float] y: + :return: True if convex + :rtype: array[bool] + """ + nb_poly = x.shape[0] + flag = empty(nb_poly, dtype=numba_types.bool_) + for i in range(nb_poly): + flag[i] = convex(x[i], y[i]) + return flag + + +@njit(cache=True) +def convex(x, y): + """ + Check if polygon is convex + + :param array[float] x: + :param array[float] y: + :return: True if convex + :rtype: bool + """ + nb = x.shape[0] + x0, y0, x1, y1, x2, y2 = x[-2], y[-2], x[-1], y[-1], x[1], y[1] + # if first is left it must be always left if it's right it must be always right + ref = is_left(x0, y0, x1, y1, x2, y2) + # We skip 0 because it's same than -1 + # We skip 1 because we tested previously + for i in range(2, nb): + # shift position + x0, y0, x1, y1 = x1, y1, x2, y2 + x2, y2 = x[i], y[i] + # test + if ref != is_left(x0, y0, x1, y1, x2, y2): + return False + return True + + +@njit(cache=True) +def get_convex_hull(x, y): + """ + Get convex polygon which enclosed current polygon + + Work only if contour is describe anti-clockwise + + :param array[float] x: + :param array[float] y: + :return: a convex polygon + :rtype: array,array + """ + nb = x.shape[0] - 1 + indices = list() + # leftmost point + i_first = x[:-1].argmin() + indices.append(i_first) + i_next = (i_first + 1) % nb + # Will define bounds line + x0, y0, x1, y1 = x[i_first], y[i_first], x[i_next], y[i_next] + xf, yf = x0, y0 + # we will check if no point are right + while True: + i_test = (i_next + 1) % nb + # value to test + xt, yt = x[i_test], y[i_test] + # We will test all the position until we touch first one, + # If all next position are on the left we keep x1, y1 + # if not we will replace by xt,yt which are more outter + while is_left(x0, y0, x1, y1, xt, yt): + i_test += 1 + i_test %= nb + if i_test == i_first: + x0, y0 = x1, y1 + indices.append(i_next) + i_next += 1 + i_next %= nb + x1, y1 = x[i_next], y[i_next] + break + xt, yt = x[i_test], y[i_test] + if i_test != i_first: + i_next = i_test + x1, y1 = x[i_next], y[i_next] + if i_next == (i_first - 1) % nb: + if is_left(x0, y0, x1, y1, xf, yf): + indices.append(i_next) + break + indices.append(i_first) + indices = array(indices) + return x[indices], y[indices] + + @njit(cache=True) def winding_number_poly(x, y, xy_poly): + """ + Check if x,y is in poly. + + :param float x: x to test + :param float y: y to test + :param vertice xy_poly: vertice of polygon + :return: wn == 0 if x,y is not in poly + :retype: int + """ nb_elt = xy_poly.shape[0] wn = 0 # loop through all edges of the polygon for i_elt in range(nb_elt): if i_elt + 1 == nb_elt: + # We close polygon with first value (no need to duplicate first value) x_next = xy_poly[0, 0] y_next = xy_poly[0, 1] else: @@ -96,13 +232,20 @@ def winding_number_poly(x, y, xy_poly): @njit(cache=True) def winding_number_grid_in_poly(x_1d, y_1d, i_x0, i_x1, x_size, i_y0, xy_poly): """ + Return index for each grid coordinates within contour. 
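A quick sanity check of the two helpers documented above, on an invented 1 x 1 closed square (last vertex repeating the first), assuming the package is installed:

```python
import numpy as np
from py_eddy_tracker.poly import poly_area, winding_number_poly

x = np.array([0.0, 1.0, 1.0, 0.0, 0.0])
y = np.array([0.0, 0.0, 1.0, 1.0, 0.0])
print(poly_area(x, y))  # expect 1.0, in squared coordinate units

# Winding number is 0 only when the point lies outside the contour
vertices = np.column_stack((x, y))
print(winding_number_poly(0.5, 0.5, vertices))  # non-zero: inside
print(winding_number_poly(2.0, 0.5, vertices))  # 0: outside
```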
+ http://geomalgorithms.com/a03-_inclusion.html - wn_PnPoly(): winding number test for a point in a polygon - Input: P = a point, - V[] = vertex points of a polygon V[n+1] with V[n]=V[0] - Return: wn = the winding number (=0 only when P is outside) + + :param array x_1d: x of local grid + :param array y_1d: y of local grid + :param int i_x0: int to add at x index to have index in global grid + :param int i_x1: last index in global grid + :param int x_size: number of x in global grid + :param int i_y0: int to add at y index to have index in global grid + :param vertice xy_poly: vertices of polygon which must contain pixel + :return: Return index in xy_poly + :rtype: (int,int) """ - # the winding number counter nb_x, nb_y = len(x_1d), len(y_1d) wn = empty((nb_x, nb_y), dtype=numba_types.bool_) for i in prange(nb_x): @@ -118,9 +261,66 @@ def winding_number_grid_in_poly(x_1d, y_1d, i_x0, i_x1, x_size, i_y0, xy_poly): return i_x, i_y +@njit(cache=True, fastmath=True) +def close_center(x0, y0, x1, y1, delta=0.1): + """ + Compute an overlap with circle parameter and return a percentage + + :param array x0: x centers of dataset 0 + :param array y0: y centers of dataset 0 + :param array x1: x centers of dataset 1 + :param array y1: y centers of dataset 1 + :return: Result of cost function + :rtype: array + """ + nb0, nb1 = x0.shape[0], x1.shape[0] + i, j, c = list(), list(), list() + for i0 in range(nb0): + xi0, yi0 = x0[i0], y0[i0] + for i1 in range(nb1): + d_x = x1[i1] - xi0 + if abs(d_x) > 180: + d_x = (d_x + 180) % 360 - 180 + if abs(d_x) > delta: + continue + if abs(y1[i1] - yi0) > delta: + continue + i.append(i0), j.append(i1), c.append(1) + return array(i), array(j), array(c) + + +@njit(cache=True) +def create_meshed_particles(lons, lats, step): + x_out, y_out, i_out = list(), list(), list() + nb = lons.shape[0] + for i in range(nb): + lon, lat = lons[i], lats[i] + vertice = create_vertice(*reduce_size(lon, lat)) + lon_min, lon_max = lon.min(), lon.max() + lat_min, lat_max = lat.min(), lat.max() + y0 = lat_min - lat_min % step + x = lon_min - lon_min % step + while x <= lon_max: + y = y0 + while y <= lat_max: + if winding_number_poly(x, y, vertice): + x_out.append(x), y_out.append(y), i_out.append(i) + y += step + x += step + return array(x_out), array(y_out), array(i_out) + + @njit(cache=True, fastmath=True) def bbox_intersection(x0, y0, x1, y1): - """compute bbox to check if there are a bbox intersection + """ + Compute bbox to check if there are a bbox intersection. + + :param array x0: x for polygon list 0 + :param array y0: y for polygon list 0 + :param array x1: x for polygon list 1 + :param array y1: y for polygon list 1 + :return: index of each polygon bbox which have an intersection + :rtype: (int, int) """ nb0 = x0.shape[0] nb1 = x1.shape[0] @@ -147,11 +347,19 @@ def bbox_intersection(x0, y0, x1, y1): continue i.append(i0) j.append(i1) - return array(i), array(j) + return array(i, dtype=numba_types.int32), array(j, dtype=numba_types.int32) @njit(cache=True) def create_vertice(x, y): + """ + Return polygon vertice. + + :param array x: + :param array y: + :return: Return polygon vertice + :rtype: vertice + """ nb = x.shape[0] v = empty((nb, 2), dtype=x.dtype) for i in range(nb): @@ -162,6 +370,15 @@ def create_vertice(x, y): @njit(cache=True) def create_vertice_from_2darray(x, y, index): + """ + Choose a polygon in x,y list and return vertice. 
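`close_center` is the coarse, centre-distance counterpart of the polygon overlap cost. A small sketch with invented centres, assuming the package is installed:

```python
import numpy as np
from py_eddy_tracker.poly import close_center

# Invented eddy centres (degrees) for two dates
x0, y0 = np.array([10.0, 50.0]), np.array([0.0, 20.0])
x1, y1 = np.array([10.05, 120.0]), np.array([0.02, -30.0])

# Pairs of centres closer than 0.1 degree in both longitude and latitude
i, j, c = close_center(x0, y0, x1, y1, 0.1)
print(i, j, c)  # expect the single pair (0, 0) with cost 1
```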
+ + :param array x: + :param array y: + :param int index: + :return: Return the vertice of polygon + :rtype: vertice + """ _, nb = x.shape v = empty((nb, 2), dtype=x.dtype) for i in range(nb): @@ -172,6 +389,17 @@ def create_vertice_from_2darray(x, y, index): @njit(cache=True) def get_wrap_vertice(x0, y0, x1, y1, i): + """ + Return a vertice for each polygon and check that use same reference coordinates. + + :param array x0: x for polygon list 0 + :param array y0: y for polygon list 0 + :param array x1: x for polygon list 1 + :param array y1: y for polygon list 1 + :param int i: index to use fot the 2 list + :return: return two compatible vertice + :rtype: (vertice, vertice) + """ x0_, x1_ = x0[i], x1[i] if abs(x0_[0] - x1_[0]) > 180: ref = x0_[0] - x0.dtype.type(180) @@ -179,7 +407,63 @@ def get_wrap_vertice(x0, y0, x1, y1, i): return create_vertice(x0_, y0[i]), create_vertice(x1_, y1[i]) -def vertice_overlap(x0, y0, x1, y1, minimal_area=False): +def merge(x, y): + """ + Merge all polygon of the list + + :param array x: 2D array for a list of polygon + :param array y: 2D array for a list of polygon + :return: Polygons which enclosed all + :rtype: array, array + """ + nb = x.shape[0] + p = None + for i in range(nb): + p_ = Polygon(create_vertice(x[i], y[i])) + if p is None: + p = p_ + else: + p += p_ + x, y = list(), list() + for p_ in p: + p_ = array(p_).T + x.append((nan,)) + y.append((nan,)) + x.append(p_[0]) + y.append(p_[1]) + return concatenate(x), concatenate(y) + + +def vertice_overlap( + x0, y0, x1, y1, minimal_area=False, p1_area=False, hybrid_area=False, min_overlap=0 +): + r""" + Return percent of overlap for each item. + + :param array x0: x for polygon list 0 + :param array y0: y for polygon list 0 + :param array x1: x for polygon list 1 + :param array y1: y for polygon list 1 + :param bool minimal_area: If True, function will compute intersection/little polygon, else intersection/union + :param bool p1_area: If True, function will compute intersection/p1 polygon, else intersection/union + :param bool hybrid_area: If True, function will compute like union, + but if cost is under min_overlap, obs is kept in case of fully included + :param float min_overlap: under this value cost is set to zero + :return: Result of cost function + :rtype: array + + By default + + .. math:: Score = \frac{Intersection(P_0,P_1)_{area}}{Union(P_0,P_1)_{area}} + + If minimal area: + + .. math:: Score = \frac{Intersection(P_0,P_1)_{area}}{min(P_{0 area},P_{1 area})} + + If P1 area: + + .. 
math:: Score = \frac{Intersection(P_0,P_1)_{area}}{P_{1 area}} + """ nb = x0.shape[0] cost = empty(nb) for i in range(nb): @@ -190,25 +474,499 @@ def vertice_overlap(x0, y0, x1, y1, minimal_area=False): # Area of intersection intersection = (p0 & p1).area() # we divide intersection with the little one result from 0 to 1 + if intersection == 0: + cost[i] = 0 + continue + p0_area_, p1_area_ = p0.area(), p1.area() if minimal_area: - cost[i] = intersection / min(p0.area(), p1.area()) + cost_ = intersection / min(p0_area_, p1_area_) + # we divide intersection with p1 + elif p1_area: + cost_ = intersection / p1_area_ # we divide intersection with polygon merging result from 0 to 1 else: - cost[i] = intersection / (p0 + p1).area() + cost_ = intersection / (p0_area_ + p1_area_ - intersection) + if cost_ >= min_overlap: + cost[i] = cost_ + else: + if ( + hybrid_area + and cost_ != 0 + and (intersection / min(p0_area_, p1_area_)) > 0.99 + ): + cost[i] = cost_ + else: + cost[i] = 0 return cost def polygon_overlap(p0, p1, minimal_area=False): + """ + Return percent of overlap for each item. + + :param list(Polygon) p0: List of polygon to compare with p1 list + :param list(Polygon) p1: List of polygon to compare with p0 list + :param bool minimal_area: If True, function will compute intersection/smaller polygon, else intersection/union + :return: Result of cost function + :rtype: array + """ nb = len(p1) cost = empty(nb) for i in range(nb): p_ = p1[i] # Area of intersection intersection = (p0 & p_).area() - # we divide intersection with the little one result from 0 to 1 + # we divide the intersection by the smaller area, result from 0 to 1 if minimal_area: cost[i] = intersection / min(p0.area(), p_.area()) - # we divide intersection with polygon merging result from 0 to 1 + # we divide the intersection by the merged polygons area, result from 0 to 1 else: cost[i] = intersection / (p0 + p_).area() return cost + + +# FIXME: only one function is needed +@njit(cache=True) +def fit_circle(x, y): + """ + From a polygon, function will fit a circle. + + Must be called with local coordinates (in m, to get a radius in m). + + :param array x: x of polygon + :param array y: y of polygon + :return: x0, y0, radius, shape_error + :rtype: (float,float,float,float) + """ + nb_elt = x.shape[0] + + # last coordinates == first + x_mean = x[1:].mean() + y_mean = y[1:].mean() + + norme = (x[1:] - x_mean) ** 2 + (y[1:] - y_mean) ** 2 + norme_max = norme.max() + scale = norme_max**0.5 + + # Form matrix equation and solve it + # Maybe put f4 + datas = ones((nb_elt - 1, 3)) + datas[:, 0] = 2.0 * (x[1:] - x_mean) / scale + datas[:, 1] = 2.0 * (y[1:] - y_mean) / scale + + (x0, y0, radius), _, _, _ = lstsq(datas, norme / norme_max) + + # Unscale data and get circle variables + radius += x0**2 + y0**2 + radius **= 0.5 + x0 *= scale + y0 *= scale + # radius of fit circle + radius *= scale + # center X-position of fit circle + x0 += x_mean + # center Y-position of fit circle + y0 += y_mean + + err = shape_error(x, y, x0, y0, radius) + return x0, y0, radius, err + + +@njit(cache=True) +def fit_ellipse(x, y): + r""" + From a polygon, function will fit an ellipse. + + Must be call with local coordinates (in m, to get a radius in m). + + .. math:: (\frac{x - x_0}{a})^2 + (\frac{y - y_0}{b})^2 = 1 + + .. 
math:: (\frac{x^2 - 2 * x * x_0 + x_0 ^2}{a^2}) + \frac{y^2 - 2 * y * y_0 + y_0 ^2}{b^2}) = 1 + + In case of angle + https://en.wikipedia.org/wiki/Ellipse + + """ + nb = x.shape[0] + datas = ones((nb, 5), dtype=x.dtype) + datas[:, 0] = x**2 + datas[:, 1] = x * y + datas[:, 2] = y**2 + datas[:, 3] = x + datas[:, 4] = y + (a, b, c, d, e), _, _, _ = lstsq(datas, ones(nb, dtype=x.dtype)) + det = b**2 - 4 * a * c + if det > 0: + print(det) + x0 = (2 * c * d - b * e) / det + y0 = (2 * a * e - b * d) / det + + AB1 = 2 * (a * e**2 + c * d**2 - b * d * e - det) + AB2 = a + c + AB3 = ((a - c) ** 2 + b**2) ** 0.5 + A = -((AB1 * (AB2 + AB3)) ** 0.5) / det + B = -((AB1 * (AB2 - AB3)) ** 0.5) / det + theta = arctan((c - a - AB3) / b) + return x0, y0, A, B, theta + + +@njit(cache=True) +def fit_circle_(x, y): + r""" + From a polygon, function will fit a circle. + + Must be call with local coordinates (in m, to get a radius in m). + + .. math:: (x_i - x_0)^2 + (y_i - y_0)^2 = r^2 + .. math:: x_i^2 - 2 x_i x_0 + x_0^2 + y_i^2 - 2 y_i y_0 + y_0^2 = r^2 + .. math:: 2 x_0 x_i + 2 y_0 y_i + r^2 - x_0^2 - y_0^2 = x_i^2 + y_i^2 + + we get this linear equation + + .. math:: a X + b Y + c = Z + + where : + + .. math:: a = 2 x_0 , b = 2 y_0 , c = r^2 - x_0^2 - y_0^2 + .. math:: X = x_i , Y = y_i , Z = x_i^2 + y_i^2 + + Solutions: + + .. math:: x_0 = a / 2 , y_0 = b / 2 , r = \sqrt{c + x_0^2 + y_0^2} + + + :param array x: x of polygon + :param array y: y of polygon + :return: x0, y0, radius, shape_error + :rtype: (float,float,float,float) + + .. plot:: + + import matplotlib.pyplot as plt + import numpy as np + from py_eddy_tracker.poly import fit_circle_ + from py_eddy_tracker.generic import build_circle + + V = np.array(((2, 2, 3, 3, 2), (-10, -9, -9, -10, -10)), dtype="f4") + x0, y0, radius, err = fit_circle_(V[0], V[1]) + ax = plt.subplot(111) + ax.set_aspect("equal") + ax.grid(True) + ax.plot(*build_circle(x0, y0, radius), "r") + ax.plot(x0, y0, "r+") + ax.plot(*V, "b.") + plt.show() + """ + datas = ones((x.shape[0] - 1, 3), dtype=x.dtype) + # we skip first position which are the same than the last + datas[:, 0] = x[1:] + datas[:, 1] = y[1:] + # Linear regression + (a, b, c), _, _, _ = lstsq(datas, x[1:] ** 2 + y[1:] ** 2) + x0, y0 = a / 2.0, b / 2.0 + radius = (c + x0**2 + y0**2) ** 0.5 + err = shape_error(x, y, x0, y0, radius) + return x0, y0, radius, err + + +@njit(cache=True, fastmath=True) +def shape_error(x, y, x0, y0, r): + r""" + With a polygon(x,y) in local coordinates. + + and circle properties(x0, y0, r), function compute a shape error: + + .. math:: ShapeError = \frac{Polygon_{area} + Circle_{area} - 2 * Intersection_{area}}{Circle_{area}} * 100 + + When error > 100, area of difference is bigger than circle area + + :param array x: x of polygon + :param array y: y of polygon + :param float x0: x center of circle + :param float y0: y center of circle + :param float r: radius of circle + :return: shape error + :rtype: float + """ + # circle area + c_area = (r**2) * pi + p_area = poly_area(x, y) + nb = x.shape[0] + x, y = x.copy(), y.copy() + # Find distance between circle center and polygon + for i in range(nb): + dx, dy = x[i] - x0, y[i] - y0 + rd = r / (dx**2 + dy**2) ** 0.5 + if rd < 1: + x[i] = x0 + dx * rd + y[i] = y0 + dy * rd + return 100 + (p_area - 2 * poly_area(x, y)) / c_area * 100 + + +@njit(cache=True, fastmath=True) +def get_pixel_in_regular(vertices, x_c, y_c, x_start, x_stop, y_start, y_stop): + """ + Get a pixel list of a regular grid contain in a contour. 
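As a numeric complement to the plot example above: on a square, the least-squares circle passes exactly through the four corners, and `shape_error` quantifies how far the contour is from that circle. A sketch with an invented closed unit square in local metric coordinates, assuming the package is installed:

```python
import numpy as np
from py_eddy_tracker.poly import fit_circle_

x = np.array([0.0, 1.0, 1.0, 0.0, 0.0])
y = np.array([0.0, 0.0, 1.0, 1.0, 0.0])
x0, y0, radius, err = fit_circle_(x, y)
# Expect a centre near (0.5, 0.5), a radius near 0.707 m and a shape error
# around 36 %: the square only covers about 64 % of the fitted circle.
print(x0, y0, radius, err)
```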
+ + :param array_like vertices: contour vertice (N,2) + :param array_like x_c: longitude coordinate of grid + :param array_like y_c: latitude coordinate of grid + :param int x_start: west index of contour + :param int y_start: east index of contour + :param int x_stop: south index of contour + :param int y_stop: north index of contour + """ + if x_stop < x_start: + x_ref = vertices[0, 0] + x_array = ( + (concatenate((x_c[x_start:], x_c[:x_stop])) - x_ref + 180) % 360 + + x_ref + - 180 + ) + return winding_number_grid_in_poly( + x_array, + y_c[y_start:y_stop], + x_start, + x_stop, + x_c.shape[0], + y_start, + vertices, + ) + else: + return winding_number_grid_in_poly( + x_c[x_start:x_stop], + y_c[y_start:y_stop], + x_start, + x_stop, + x_c.shape[0], + y_start, + vertices, + ) + + +@njit(cache=True) +def tri_area2(x, y, i0, i1, i2): + """Double area of triangle + + :param array x: + :param array y: + :param int i0: indice of first point + :param int i1: indice of second point + :param int i2: indice of third point + :return: area + :rtype: float + """ + x0, y0 = x[i0], y[i0] + x1, y1 = x[i1], y[i1] + x2, y2 = x[i2], y[i2] + p_area2 = (x0 - x2) * (y1 - y0) - (x0 - x1) * (y2 - y0) + return abs(p_area2) + + +@njit(cache=True) +def visvalingam(x, y, fixed_size=18): + """Polygon simplification with visvalingam algorithm + + X, Y are considered like a polygon, the next point after the last one is the first one + + :param array x: + :param array y: + :param int fixed_size: array size of out + :return: + New (x, y) array, last position will be equal to first one, if array size is 6, + there is only 5 point. + :rtype: array,array + + .. plot:: + + import matplotlib.pyplot as plt + import numpy as np + from py_eddy_tracker.poly import visvalingam + + x = np.array([1, 2, 3, 4, 5, 6.75, 6, 1]) + y = np.array([-0.5, -1.5, -1, -1.75, -1, -1, -0.5, -0.5]) + ax = plt.subplot(111) + ax.set_aspect("equal") + ax.grid(True), ax.set_ylim(-2, -.2) + ax.plot(x, y, "r", lw=5) + ax.plot(*visvalingam(x,y,6), "b", lw=2) + plt.show() + """ + # TODO : in case of original size lesser than fixed size, jump at the end + nb = x.shape[0] + nb_ori = nb + # Get indice of first triangle + i0, i1 = nb - 2, nb - 1 + # Init heap with first area and tiangle + h = [(tri_area2(x, y, i0, i1, 0), (i0, i1, 0))] + # Roll index for next one + i0 = i1 + i1 = 0 + # Index of previous valid point + i_previous = empty(nb, dtype=numba_types.int64) + # Index of next valid point + i_next = empty(nb, dtype=numba_types.int64) + # Mask of removed + removed = zeros(nb, dtype=numba_types.bool_) + i_previous[0] = -1 + i_next[0] = -1 + for i in range(1, nb): + i_previous[i] = -1 + i_next[i] = -1 + # We add triangle area for all triangle + heapq.heappush(h, (tri_area2(x, y, i0, i1, i), (i0, i1, i))) + i0 = i1 + i1 = i + # we continue until we are equal to nb_pt + while nb >= fixed_size: + # We pop lower area + _, (i0, i1, i2) = heapq.heappop(h) + # We check if triangle is valid(i0 or i2 not removed) + if removed[i0] or removed[i2]: + # In this cas nothing to do + continue + # Flag obs like removed + removed[i1] = True + # We count point still valid + nb -= 1 + # Modify index for the next and previous, we jump over i1 + i_previous[i2] = i0 + i_next[i0] = i2 + # We insert 2 triangles which are modified by the deleted point + # Previous triangle + i_1 = i_previous[i0] + if i_1 == -1: + i_1 = (i0 - 1) % nb_ori + heapq.heappush(h, (tri_area2(x, y, i_1, i0, i2), (i_1, i0, i2))) + # Previous triangle + i3 = i_next[i2] + if i3 == -1: + i3 = (i2 + 1) % nb_ori + 
heapq.heappush(h, (tri_area2(x, y, i0, i2, i3), (i0, i2, i3))) + x_new, y_new = empty(fixed_size, dtype=x.dtype), empty(fixed_size, dtype=y.dtype) + j = 0 + for i, flag in enumerate(removed): + if not flag: + x_new[j] = x[i] + y_new[j] = y[i] + j += 1 + # we copy first value to fill array end + x_new[j:] = x_new[0] + y_new[j:] = y_new[0] + return x_new, y_new + + +@njit(cache=True) +def reduce_size(x, y): + """ + Reduce array size if last position is repeated, in order to save compute time + + :param array x: longitude + :param array y: latitude + + :return: reduce arrays x,y + :rtype: ndarray,ndarray + """ + i = x.shape[0] + x0, y0 = x[0], y[0] + while True: + i -= 1 + if x[i] != x0 or y[i] != y0: + i += 1 + # In case of virtual obs all value could be fill with same value, to avoid empty array + i = max(3, i) + return x[:i], y[:i] + + +@njit(cache=True) +def group_obs(x, y, step, nb_x): + """Get index k_box for each box, and indexes to sort""" + nb = x.size + i = empty(nb, dtype=numba_types.uint32) + for k in range(nb): + i[k] = box_index(x[k], y[k], step, nb_x) + return i, i.argsort(kind="mergesort") + + +@njit(cache=True) +def box_index(x, y, step, nb_x): + """Return k_box index for each value""" + return numba_types.uint32((x % 360) // step + nb_x * ((y + 90) // step)) + + +@njit(cache=True) +def box_indexes(x, y, step): + """Return i_box,j_box index for each value""" + return numba_types.uint32((x % 360) // step), numba_types.uint32((y + 90) // step) + + +@njit(cache=True) +def poly_indexs(x_p, y_p, x_c, y_c): + """ + Index of contour for each postion inside a contour, -1 in case of no contour + + :param array x_p: longitude to test (must be defined, no nan) + :param array y_p: latitude to test (must be defined, no nan) + :param array x_c: longitude of contours + :param array y_c: latitude of contours + """ + nb_x = 360 + step = 1.0 + i, i_order = group_obs(x_p, y_p, step, nb_x) + nb_p = x_p.shape[0] + nb_c = x_c.shape[0] + indexs = -ones(nb_p, dtype=numba_types.int32) + # Adress table to get test bloc + start_index, end_index, i_first = build_index(i[i_order]) + nb_bloc = end_index.size + for i_contour in range(nb_c): + # Build vertice and box included contour + x_, y_ = reduce_size(x_c[i_contour], y_c[i_contour]) + x_c_min, y_c_min = x_.min(), y_.min() + x_c_max, y_c_max = x_.max(), y_.max() + v = create_vertice(x_, y_) + i0, j0 = box_indexes(x_c_min, y_c_min, step) + i1, j1 = box_indexes(x_c_max, y_c_max, step) + # i0 could be greater than i1, (x_c is always continious) so you could have a contour over bound + if i0 > i1: + i1 += nb_x + for i_x in range(i0, i1 + 1): + # we force i_x in 0 360 range + i_x %= nb_x + for i_y in range(j0, j1 + 1): + # Get box indices + i_box = i_x + nb_x * i_y - i_first + # Indice must be in table range + if i_box < 0 or i_box >= nb_bloc: + continue + for i_p_ordered in range(start_index[i_box], end_index[i_box]): + i_p = i_order[i_p_ordered] + if indexs[i_p] != -1: + continue + y = y_p[i_p] + if y > y_c_max: + continue + if y < y_c_min: + continue + # Normalize longitude at +-180° around x_c_min + x = (x_p[i_p] - x_c_min + 180) % 360 + x_c_min - 180 + if x > x_c_max: + continue + if x < x_c_min: + continue + if winding_number_poly(x, y, v) != 0: + indexs[i_p] = i_contour + return indexs + + +@njit(cache=True) +def insidepoly(x_p, y_p, x_c, y_c): + """ + True for each postion inside a contour + + :param array x_p: longitude to test + :param array y_p: latitude to test + :param array x_c: longitude of contours + :param array y_c: latitude of contours + 
""" + return poly_indexs(x_p, y_p, x_c, y_c) != -1 diff --git a/src/py_eddy_tracker/tracking.py b/src/py_eddy_tracker/tracking.py index b0d78191..b64b6fcc 100644 --- a/src/py_eddy_tracker/tracking.py +++ b/src/py_eddy_tracker/tracking.py @@ -1,36 +1,36 @@ # -*- coding: utf-8 -*- -# -*- coding: utf-8 -*- """ -=========================================================================== -This file is part of py-eddy-tracker. - - py-eddy-tracker is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - py-eddy-tracker is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with py-eddy-tracker. If not, see . - -Copyright (c) 2014-2020 by Evan Mason -Email: evanmason@gmail.com -=========================================================================== +Class to store link between observations """ - -from datetime import timedelta, datetime -from py_eddy_tracker.observations.observation import EddiesObservations, VirtualEddiesObservations -from py_eddy_tracker.observations.tracking import TrackEddiesObservations -from numpy import bool_, array, arange, ones, setdiff1d, zeros, uint16, where, empty, isin, unique, concatenate, \ - ma -from netCDF4 import Dataset, default_fillvals +from datetime import datetime, timedelta +import json import logging import platform +from tarfile import ExFileObject + +from netCDF4 import Dataset, default_fillvals from numba import njit, types as numba_types +from numpy import ( + arange, + array, + bool_, + concatenate, + empty, + isin, + ma, + ones, + setdiff1d, + uint16, + unique, + where, + zeros, +) + +from py_eddy_tracker.observations.observation import ( + EddiesObservations, + VirtualEddiesObservations, +) +from py_eddy_tracker.observations.tracking import TrackEddiesObservations logger = logging.getLogger("pet") @@ -50,28 +50,46 @@ class Correspondances(list): """Object to store correspondances And run tracking """ + UINT32_MAX = 4294967295 # Prolongation limit to 255 - VIRTUAL_DTYPE = 'u1' + VIRTUAL_DTYPE = "u1" # ID limit to 4294967295 - ID_DTYPE = 'u4' + ID_DTYPE = "u4" # Track limit to 65535 - N_DTYPE = 'u2' - - def __init__(self, datasets, virtual=0, class_method=None, previous_correspondance=None): + N_DTYPE = "u2" + + def __init__( + self, + datasets, + virtual=0, + class_method=None, + class_kw=None, + previous_correspondance=None, + memory=False, + ): """Initiate tracking + + :param list(str) datasets: A sorted list of filename which contains eddy observations to track + :param class class_method: A class which tell how to track + :param dict class_kw: keyword argument to setup class + :param Correspondances previous_correspondance: A previous correspondance object if you want continue tracking + :param bool memory: identification file are load in memory before to be open with netcdf """ super().__init__() # Correspondance dtype - self.correspondance_dtype = [('in', 'u2'), - ('out', 'u2'), - ('id', self.ID_DTYPE), - ('cost_value', 'f4') - ] + self.correspondance_dtype = [ + ("in", "u2"), + ("out", "u2"), + ("id", self.ID_DTYPE), + ("cost_value", "f4"), + ] if class_method is None: self.class_method = EddiesObservations else: 
self.class_method = class_method + self.class_kw = dict() if class_kw is None else class_kw + self.memory = memory # To count ID self.current_id = 0 @@ -93,15 +111,18 @@ def __init__(self, datasets, virtual=0, class_method=None, previous_correspondan # Correspondance to prolongate self.filename_previous_correspondance = previous_correspondance - self.previous_correspondance = self.load_compatible(self.filename_previous_correspondance) + self.previous_correspondance = self.load_compatible( + self.filename_previous_correspondance + ) if self.virtual: # Add field to dtype to follow virtual observations self.correspondance_dtype += [ # True if it isn't a real obs - ('virtual', bool_), + ("virtual", bool_), # Length of virtual segment - ('virtual_length', self.VIRTUAL_DTYPE)] + ("virtual_length", self.VIRTUAL_DTYPE), + ] # Array to simply merged self.nb_obs_by_tracks = None @@ -114,14 +135,16 @@ def _copy(self): datasets=self.datasets, virtual=self.nb_virtual, class_method=self.class_method, - previous_correspondance=self.filename_previous_correspondance) + class_kw=self.class_kw, + previous_correspondance=self.filename_previous_correspondance, + ) for i in self: new.append(i) new.current_id = self.current_id new.nb_link_max = self.nb_link_max new.nb_obs = self.nb_obs new.prepare_merging() - logger.debug('Copy done') + logger.debug("Copy done") return new def reset_dataset_cache(self): @@ -137,35 +160,42 @@ def period(self): """ date_start = datetime(1950, 1, 1) + timedelta( - int(self.class_method.load_file(self.datasets[0]).obs['time'][0])) + self.class_method.load_file(self.datasets[0]).time[0] + ) date_stop = datetime(1950, 1, 1) + timedelta( - int(self.class_method.load_file(self.datasets[-1]).obs['time'][0])) + self.class_method.load_file(self.datasets[-1]).time[0] + ) return date_start, date_stop - def swap_dataset(self, dataset, raw_data=False): - """ Swap to next dataset - """ + def swap_dataset(self, dataset, *args, **kwargs): + """Swap to next dataset""" self.previous2_obs = self.previous_obs self.previous_obs = self.current_obs - self.current_obs = self.class_method.load_file(dataset, raw_data=raw_data) + kwargs = kwargs.copy() + kwargs.update(self.class_kw) + if self.memory: + with open(dataset, "rb") as h: + self.current_obs = self.class_method.load_file(h, *args, **kwargs) + else: + self.current_obs = self.class_method.load_file(dataset, *args, **kwargs) def merge_correspondance(self, other): # Verify compliance of file if self.nb_virtual != other.nb_virtual: - raise Exception('Different method of tracking') + raise Exception("Different method of tracking") # Determine junction i = where(other.datasets == array(self.datasets[-1]))[0] if len(i) != 1: - raise Exception('More than one intersection') + raise Exception("More than one intersection") # Merge # Create a hash table - translate = empty(other.current_id, dtype='u4') + translate = empty(other.current_id, dtype="u4") translate[:] = self.UINT32_MAX - translate[other[i[0] - 1]['id']] = self[-1]['id'] + translate[other[i[0] - 1]["id"]] = self[-1]["id"] - nb_max = other[i[0] - 1]['id'].max() + nb_max = other[i[0] - 1]["id"].max() mask = translate == self.UINT32_MAX # We won't translate previous id mask[:nb_max] = False @@ -173,71 +203,75 @@ def merge_correspondance(self, other): translate[mask] = arange(mask.sum()) + self.current_id # Translate - for items in other[i[0]:]: - items['id'] = translate[items['id']] + for items in other[i[0] :]: + items["id"] = translate[items["id"]] # Extend with other obs - self.extend(other[i[0]:]) 
+ self.extend(other[i[0] :]) # Extend datasets list, which are bounds so we add one - self.datasets.extend(other.datasets[i[0] + 1:]) + self.datasets.extend(other.datasets[i[0] + 1 :]) # We set new id available self.current_id = translate[-1] + 1 - def store_correspondance(self, i_previous, i_current, nb_real_obs, association_cost): - """Storing correspondance in an array - """ + def store_correspondance( + self, i_previous, i_current, nb_real_obs, association_cost + ): + """Storing correspondance in an array""" # Create array to store correspondance data correspondance = array(i_previous, dtype=self.correspondance_dtype) if self.virtual: - correspondance['virtual_length'][:] = 255 + correspondance["virtual_length"][:] = 255 # index from current_obs - correspondance['out'] = i_current - correspondance['cost_value'] = association_cost + correspondance["out"] = i_current + correspondance["cost_value"] = association_cost if self.virtual: # if index in previous dataset is bigger than real obs number # it's a virtual data - correspondance['virtual'] = i_previous >= nb_real_obs + correspondance["virtual"] = i_previous >= nb_real_obs if self.previous2_obs is None: # First time we set ID (Program starting) nb_match = i_previous.shape[0] # Set an id for each match - correspondance['id'] = self.id_generator(nb_match) + correspondance["id"] = self.id_generator(nb_match) self.append(correspondance) return True # We set all id to UINT32_MAX - id_previous = ones(len(self.previous_obs), - dtype=self.ID_DTYPE) * self.UINT32_MAX + id_previous = ( + ones(len(self.previous_obs), dtype=self.ID_DTYPE) * self.UINT32_MAX + ) # We get old id for previously eddies tracked - id_previous[self[-1]['out']] = self[-1]['id'] + id_previous[self[-1]["out"]] = self[-1]["id"] # We store ID in correspondance if the ID is UINT32_MAX, we never # track it before - correspondance['id'] = id_previous[correspondance['in']] + correspondance["id"] = id_previous[correspondance["in"]] # We set correspondance data for virtual obs : ID/LENGTH if self.previous2_obs is not None and self.virtual: - nb_rebirth = correspondance['virtual'].sum() + nb_rebirth = correspondance["virtual"].sum() if nb_rebirth != 0: - logger.debug('%d re-birth due to prolongation with' - ' virtual observations', nb_rebirth) + logger.debug( + "%d re-birth due to prolongation with" " virtual observations", + nb_rebirth, + ) # Set id for virtual # get correspondance mask using virtual obs - m_virtual = correspondance['virtual'] + m_virtual = correspondance["virtual"] # index of virtual in virtual obs - i_virtual = correspondance['in'][m_virtual] - nb_real_obs - correspondance['id'][m_virtual] = \ - self.virtual_obs['track'][i_virtual] - correspondance['virtual_length'][m_virtual] = \ - self.virtual_obs['segment_size'][i_virtual] + i_virtual = correspondance["in"][m_virtual] - nb_real_obs + correspondance["id"][m_virtual] = self.virtual_obs["track"][i_virtual] + correspondance["virtual_length"][m_virtual] = self.virtual_obs[ + "segment_size" + ][i_virtual] # new_id is equal to UINT32_MAX we must add a new ones # we count the number of new - mask_new_id = correspondance['id'] == self.UINT32_MAX + mask_new_id = correspondance["id"] == self.UINT32_MAX nb_new_tracks = mask_new_id.sum() - logger.debug('%d birth in this step', nb_new_tracks) + logger.debug("%d birth in this step", nb_new_tracks) # Set new id - correspondance['id'][mask_new_id] = self.id_generator(nb_new_tracks) + correspondance["id"][mask_new_id] = self.id_generator(nb_new_tracks) 
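The id bookkeeping above leans on `UINT32_MAX` as a "never tracked before" sentinel. A compact NumPy sketch of that translation step, with invented indices (this is not the class itself):

```python
import numpy as np

UINT32_MAX = 4294967295

# Previous correspondance: eddies 2 and 5 of the previous file carried ids 7 and 8 (invented)
prev_out = np.array([2, 5])
prev_id = np.array([7, 8], dtype="u4")

# Current correspondance links previous eddies [2, 3, 5] to the new file
cur_in = np.array([2, 3, 5])

# Sentinel-filled lookup: anything still at UINT32_MAX after the scatter is a new track
id_previous = np.full(6, UINT32_MAX, dtype="u4")
id_previous[prev_out] = prev_id
cur_id = id_previous[cur_in]
mask_new_id = cur_id == UINT32_MAX
print(cur_id, mask_new_id)  # the middle link has no past id and needs a freshly generated one
```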
self.append(correspondance) @@ -248,49 +282,55 @@ def append(self, *args, **kwargs): super().append(*args, **kwargs) def id_generator(self, nb_id): - """Generation id and incrementation - """ + """Generation id and incrementation""" values = arange(self.current_id, self.current_id + nb_id) self.current_id += nb_id return values def recense_dead_id_to_extend(self): - """Recense dead id to extend in virtual observation - """ + """Recense dead id to extend in virtual observation""" # List previous id which are not use in the next step - dead_id = setdiff1d(self[-2]['id'], self[-1]['id']) + dead_id = setdiff1d(self[-2]["id"], self[-1]["id"]) nb_dead = dead_id.shape[0] - logger.debug('%d death of real obs in this step', nb_dead) + logger.debug("%d death of real obs in this step", nb_dead) if not self.virtual: return # get id already dead from few time nb_virtual_extend = 0 if self.virtual_obs is not None: - virtual_dead_id = setdiff1d(self.virtual_obs['track'], self[-1]['id']) - i_virtual_dead_id = index(self.virtual_obs['track'], virtual_dead_id) + virtual_dead_id = setdiff1d(self.virtual_obs["track"], self[-1]["id"]) + i_virtual_dead_id = index(self.virtual_obs["track"], virtual_dead_id) # Virtual obs which can be prolongate - alive_virtual_obs = self.virtual_obs['segment_size'][i_virtual_dead_id] < self.nb_virtual + alive_virtual_obs = ( + self.virtual_obs["segment_size"][i_virtual_dead_id] < self.nb_virtual + ) nb_virtual_extend = alive_virtual_obs.sum() - logger.debug('%d virtual obs will be prolongate on the next step', nb_virtual_extend) + logger.debug( + "%d virtual obs will be prolongate on the next step", nb_virtual_extend + ) # Save previous state to count virtual obs self.previous_virtual_obs = self.virtual_obs # Find mask/index on previous correspondance to extrapolate # position - i_dead_id = index(self[-2]['id'], dead_id) + i_dead_id = index(self[-2]["id"], dead_id) # Selection of observations on N-2 and N-1 - obs_a = self.previous2_obs.obs[self[-2][i_dead_id]['in']] - obs_b = self.previous_obs.obs[self[-2][i_dead_id]['out']] + obs_a = self.previous2_obs.obs[self[-2][i_dead_id]["in"]] + obs_b = self.previous_obs.obs[self[-2][i_dead_id]["out"]] self.virtual_obs = self.previous_obs.propagate( - obs_a, obs_b, - self.previous_virtual_obs.obs[i_virtual_dead_id][alive_virtual_obs] if nb_virtual_extend > 0 else None, + obs_a, + obs_b, + self.previous_virtual_obs.obs[i_virtual_dead_id][alive_virtual_obs] + if nb_virtual_extend > 0 + else None, dead_track=dead_id, nb_next=nb_dead + nb_virtual_extend, - model=self.previous_obs) + model=self.previous_obs, + ) def load_state(self): # If we have a previous file of correspondance, we will replay only recent part @@ -298,41 +338,63 @@ def load_state(self): first_dataset = len(self.previous_correspondance.datasets) for correspondance in self.previous_correspondance[:first_dataset]: self.append(correspondance) - self.current_obs = self.class_method.load_file(self.datasets[first_dataset - 2]) + self.current_obs = self.class_method.load_file( + self.datasets[first_dataset - 2], **self.class_kw + ) flg_virtual = self.previous_correspondance.virtual with Dataset(self.filename_previous_correspondance) as general_handler: self.current_id = general_handler.last_current_id if flg_virtual: # Load last virtual obs - self.virtual_obs = VirtualEddiesObservations.from_netcdf(general_handler.groups['LastVirtualObs']) + self.virtual_obs = VirtualEddiesObservations.from_netcdf( + general_handler.groups["LastVirtualObs"] + ) + self.previous_virtual_obs = 
VirtualEddiesObservations.from_netcdf( + general_handler.groups["LastPreviousVirtualObs"] + ) # Load and last previous virtual obs to be merge with current => will be previous2_obs # TODO : Need to rethink this line ?? self.current_obs = self.current_obs.merge( - VirtualEddiesObservations.from_netcdf(general_handler.groups['LastPreviousVirtualObs'])) + VirtualEddiesObservations.from_netcdf( + general_handler.groups["LastPreviousVirtualObs"] + ) + ) return first_dataset, flg_virtual return 1, False def track(self): - """Run tracking - """ + """Run tracking""" self.reset_dataset_cache() first_dataset, flg_virtual = self.load_state() - self.swap_dataset(self.datasets[first_dataset - 1]) + kwargs = dict() + needed_variable = self.class_method.needed_variable() + if needed_variable is not None: + kwargs["include_vars"] = needed_variable + self.swap_dataset(self.datasets[first_dataset - 1], **kwargs) # We begin with second file, first one is in previous for file_name in self.datasets[first_dataset:]: - self.swap_dataset(file_name) - logger.info('%s match with previous state', file_name) - logger.debug('%d obs to match', len(self.current_obs)) + self.swap_dataset(file_name, **kwargs) + filename_ = ( + file_name.filename if isinstance(file_name, ExFileObject) else file_name + ) + logger.info("%s match with previous state", filename_) + logger.debug("%d obs to match", len(self.current_obs)) nb_real_obs = len(self.previous_obs) if flg_virtual: - logger.debug('%d virtual obs will be add to previous', len(self.virtual_obs)) + logger.debug( + "%d virtual obs will be add to previous", len(self.virtual_obs) + ) self.previous_obs = self.previous_obs.merge(self.virtual_obs) - i_previous, i_current, association_cost = self.previous_obs.tracking(self.current_obs) + i_previous, i_current, association_cost = self.previous_obs.tracking( + self.current_obs + ) # return true if the first time (previous2obs is none) - if self.store_correspondance(i_previous, i_current, nb_real_obs, association_cost): + if self.store_correspondance( + i_previous, i_current, nb_real_obs, association_cost + ): continue self.recense_dead_id_to_extend() @@ -340,122 +402,166 @@ def track(self): if self.virtual: flg_virtual = True + def to_netcdf(self, handler): + nb_step = len(self.datasets) - 1 + logger.info("Create correspondance file") + # Create dimensions + logger.debug('Create Dimensions "Nlink" : %d', self.nb_link_max) + handler.createDimension("Nlink", self.nb_link_max) + + logger.debug('Create Dimensions "Nstep" : %d', nb_step) + handler.createDimension("Nstep", nb_step) + var_file_in = handler.createVariable( + zlib=False, + complevel=1, + varname="FileIn", + datatype="S1024", + dimensions="Nstep", + ) + var_file_out = handler.createVariable( + zlib=False, + complevel=1, + varname="FileOut", + datatype="S1024", + dimensions="Nstep", + ) + + def get_filename(dataset): + if not isinstance(dataset, str) or not isinstance(dataset, bytes): + return "In memory file" + return dataset + + for i, dataset in enumerate(self.datasets[:-1]): + var_file_in[i] = get_filename(dataset) + var_file_out[i] = get_filename(self.datasets[i + 1]) + + var_nb_link = handler.createVariable( + zlib=True, + complevel=1, + varname="nb_link", + datatype="u2", + dimensions="Nstep", + ) + + datas = dict() + for name, dtype in self.correspondance_dtype: + if dtype is bool_: + dtype = "u1" + kwargs_cv = dict() + if "u1" in dtype: + kwargs_cv["fill_value"] = (255,) + handler.createVariable( + zlib=True, + complevel=1, + varname=name, + datatype=dtype, + 
dimensions=("Nstep", "Nlink"), + **kwargs_cv + ) + datas[name] = ma.empty((nb_step, self.nb_link_max), dtype=dtype) + datas[name].mask = datas[name] == datas[name] + + for i, correspondance in enumerate(self): + logger.debug("correspondance %d", i) + nb_elt = correspondance.shape[0] + var_nb_link[i] = nb_elt + for name, _ in self.correspondance_dtype: + datas[name][i, :nb_elt] = correspondance[name] + for name, data in datas.items(): + h_v = handler.variables[name] + h_v[:] = data + if "File" not in name: + h_v.min = h_v[:].min() + h_v.max = h_v[:].max() + + handler.virtual_use = str(self.virtual) + handler.virtual_max_segment = self.nb_virtual + handler.last_current_id = self.current_id + if self.virtual_obs is not None: + group = handler.createGroup("LastVirtualObs") + self.virtual_obs.to_netcdf(group) + group = handler.createGroup("LastPreviousVirtualObs") + self.previous_virtual_obs.to_netcdf(group) + handler.module = self.class_method.__module__ + handler.classname = self.class_method.__qualname__ + handler.class_kw = json.dumps(self.class_kw) + handler.node = platform.node() + logger.info("Create correspondance file done") + def save(self, filename, dict_completion=None): self.prepare_merging() - nb_step = len(self.datasets) - 1 if isinstance(dict_completion, dict): filename = filename.format(**dict_completion) - logger.info('Create correspondance file %s', filename) - with Dataset(filename, 'w', format='NETCDF4') as h_nc: - # Create dimensions - logger.debug('Create Dimensions "Nlink" : %d', self.nb_link_max) - h_nc.createDimension('Nlink', self.nb_link_max) - - logger.debug('Create Dimensions "Nstep" : %d', nb_step) - h_nc.createDimension('Nstep', nb_step) - var_file_in = h_nc.createVariable( - zlib=True, complevel=1, - varname='FileIn', datatype='S1024', dimensions='Nstep') - var_file_out = h_nc.createVariable( - zlib=True, complevel=1, - varname='FileOut', datatype='S1024', dimensions='Nstep') - for i, dataset in enumerate(self.datasets[:-1]): - var_file_in[i] = dataset - var_file_out[i] = self.datasets[i + 1] - - var_nb_link = h_nc.createVariable( - zlib=True, complevel=1, - varname='nb_link', datatype='u2', dimensions='Nstep') - - datas = dict() - for name, dtype in self.correspondance_dtype: - if dtype is bool_: - dtype = 'u1' - kwargs_cv = dict() - if 'u1' in dtype: - kwargs_cv['fill_value'] = 255, - h_nc.createVariable(zlib=True, - complevel=1, - varname=name, - datatype=dtype, - dimensions=('Nstep', 'Nlink'), - **kwargs_cv - ) - datas[name] = ma.empty((nb_step, self.nb_link_max), dtype=dtype) - datas[name].mask = datas[name] == datas[name] - - for i, correspondance in enumerate(self): - logger.debug('correspondance %d', i) - nb_elt = correspondance.shape[0] - var_nb_link[i] = nb_elt - for name, _ in self.correspondance_dtype: - datas[name][i, :nb_elt] = correspondance[name] - for name, data in datas.items(): - h_v = h_nc.variables[name] - h_v[:] = data - if 'File' not in name: - h_v.min = h_v[:].min() - h_v.max = h_v[:].max() - - h_nc.virtual_use = str(self.virtual) - h_nc.virtual_max_segment = self.nb_virtual - h_nc.last_current_id = self.current_id - if self.virtual_obs is not None: - group = h_nc.createGroup('LastVirtualObs') - self.virtual_obs.to_netcdf(group) - group = h_nc.createGroup('LastPreviousVirtualObs') - self.previous_virtual_obs.to_netcdf(group) - h_nc.module = self.class_method.__module__ - h_nc.classname = self.class_method.__qualname__ - h_nc.node = platform.node() - logger.info('Create correspondance file done') + with Dataset(filename, "w", 
format="NETCDF4") as h_nc: + self.to_netcdf(h_nc) def load_compatible(self, filename): if filename is None: return None previous_correspondance = Correspondances.load(filename) if self.nb_virtual != previous_correspondance.nb_virtual: - raise Exception('File of correspondance IN contains a different virtual segment size : file(%d), yaml(%d)' % - (previous_correspondance.nb_virtual, self.nb_virtual)) + raise Exception( + "File of correspondance IN contains a different virtual segment size : file(%d), yaml(%d)" + % (previous_correspondance.nb_virtual, self.nb_virtual) + ) if self.class_method != previous_correspondance.class_method: - raise Exception('File of correspondance IN contains a different class method: file(%s), yaml(%s)' % - (previous_correspondance.class_method, self.class_method)) + raise Exception( + "File of correspondance IN contains a different class method: file(%s), yaml(%s)" + % (previous_correspondance.class_method, self.class_method) + ) return previous_correspondance @classmethod - def load(cls, filename): - logger.info('Try load %s', filename) - with Dataset(filename, 'r', format='NETCDF4') as h_nc: - datas = {varname: data[:] for varname, data in h_nc.variables.items()} - - datasets = list(datas['FileIn']) - datasets.append(datas['FileOut'][-1]) - - if hasattr(h_nc, 'module'): - class_method = getattr(__import__(h_nc.module, globals(), locals(), h_nc.classname), h_nc.classname) - else: - class_method = None - logger.info('File %s load with class %s', filename, class_method) - obj = cls(datasets, h_nc.virtual_max_segment, class_method=class_method) + def from_netcdf(cls, handler): + datas = {varname: data[:] for varname, data in handler.variables.items()} + + datasets = list(datas["FileIn"]) + datasets.append(datas["FileOut"][-1]) + + if hasattr(handler, "module"): + class_method = getattr( + __import__(handler.module, globals(), locals(), handler.classname), + handler.classname, + ) + class_kw = getattr(handler, "class_kw", dict()) + if isinstance(class_kw, str): + class_kw = json.loads(class_kw) + else: + class_method = None + class_kw = dict() + logger.info("File load with class %s(%s)", class_method, class_kw) + obj = cls( + datasets, + handler.virtual_max_segment, + class_method=class_method, + class_kw=class_kw, + ) + + id_max = 0 + for i, nb_elt in enumerate(datas["nb_link"][:]): + logger.debug( + "Link between %s and %s", datas["FileIn"][i], datas["FileOut"][i] + ) + correspondance = array( + datas["in"][i, :nb_elt], dtype=obj.correspondance_dtype + ) + for name, _ in obj.correspondance_dtype: + if name == "in": + continue + if name == "virtual_length": + correspondance[name] = 255 + correspondance[name] = datas[name][i, :nb_elt] + id_max = max(id_max, correspondance["id"].max()) + obj.append(correspondance) + obj.current_id = id_max + 1 + return obj - id_max = 0 - for i, nb_elt in enumerate(datas['nb_link'][:]): - logger.debug( - 'Link between %s and %s', - datas['FileIn'][i], - datas['FileOut'][i]) - correspondance = array(datas['in'][i, :nb_elt], - dtype=obj.correspondance_dtype) - for name, _ in obj.correspondance_dtype: - if name == 'in': - continue - if name == 'virtual_length': - correspondance[name] = 255 - correspondance[name] = datas[name][i, :nb_elt] - id_max = max(id_max, correspondance['id'].max()) - obj.append(correspondance) - obj.current_id = id_max + 1 + @classmethod + def load(cls, filename): + logger.info("Loading %s", filename) + with Dataset(filename, "r", format="NETCDF4") as h_nc: + obj = cls.from_netcdf(h_nc) return obj def 
prepare_merging(self): @@ -463,88 +569,99 @@ def prepare_merging(self): # is an interval) self.nb_obs_by_tracks = ones(self.current_id, dtype=self.N_DTYPE) for correspondance in self: - self.nb_obs_by_tracks[correspondance['id']] += 1 + self.nb_obs_by_tracks[correspondance["id"]] += 1 if self.virtual: # When start is virtual, we don't have a previous # correspondance - self.nb_obs_by_tracks[correspondance['id'][correspondance['virtual']]] += \ - correspondance['virtual_length'][correspondance['virtual']] + self.nb_obs_by_tracks[ + correspondance["id"][correspondance["virtual"]] + ] += correspondance["virtual_length"][correspondance["virtual"]] # Compute index of each tracks - self.i_current_by_tracks = self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks + self.i_current_by_tracks = ( + self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks + ) # Number of global obs self.nb_obs = self.nb_obs_by_tracks.sum() - logger.info('%d tracks identified', self.current_id) - logger.info('%d observations will be join', self.nb_obs) + logger.info("%d tracks identified", self.current_id) + logger.info("%d observations will be join", self.nb_obs) def longer_than(self, size_min): - """Remove from correspondance table all association for shorter eddies than size_min - """ + """Remove from correspondance table all association for shorter eddies than size_min""" # Identify eddies longer than - i_keep_track = where(self.nb_obs_by_tracks >= size_min)[0] + mask = self.nb_obs_by_tracks >= size_min + if not mask.any(): + return False + i_keep_track = where(mask)[0] # Reduce array self.nb_obs_by_tracks = self.nb_obs_by_tracks[i_keep_track] - self.i_current_by_tracks = self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks + self.i_current_by_tracks = ( + self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks + ) self.nb_obs = self.nb_obs_by_tracks.sum() # Give the last id used self.current_id = self.nb_obs_by_tracks.shape[0] - translate = empty(i_keep_track.max() + 1, dtype='u4') + translate = empty(i_keep_track.max() + 1, dtype="u4") translate[i_keep_track] = arange(self.current_id) for i, correspondance in enumerate(self): - m_keep = isin(correspondance['id'], i_keep_track) + m_keep = isin(correspondance["id"], i_keep_track) self[i] = correspondance[m_keep] - self[i]['id'] = translate[self[i]['id']] - logger.debug('Select longer than %d done', size_min) + self[i]["id"] = translate[self[i]["id"]] + logger.debug("Select longer than %d done", size_min) def shorter_than(self, size_max): - """Remove from correspondance table all association for longer eddies than size_max - """ + """Remove from correspondance table all association for longer eddies than size_max""" # Identify eddies longer than i_keep_track = where(self.nb_obs_by_tracks < size_max)[0] # Reduce array self.nb_obs_by_tracks = self.nb_obs_by_tracks[i_keep_track] - self.i_current_by_tracks = self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks + self.i_current_by_tracks = ( + self.nb_obs_by_tracks.cumsum() - self.nb_obs_by_tracks + ) self.nb_obs = self.nb_obs_by_tracks.sum() # Give the last id used self.current_id = self.nb_obs_by_tracks.shape[0] - translate = empty(i_keep_track.max() + 1, dtype='u4') + translate = empty(i_keep_track.max() + 1, dtype="u4") translate[i_keep_track] = arange(self.current_id) for i, correspondance in enumerate(self): - m_keep = isin(correspondance['id'], i_keep_track) + m_keep = isin(correspondance["id"], i_keep_track) self[i] = correspondance[m_keep] - self[i]['id'] = translate[self[i]['id']] - logger.debug('Select 
shorter than %d done', size_max) + self[i]["id"] = translate[self[i]["id"]] + logger.debug("Select shorter than %d done", size_max) def merge(self, until=-1, raw_data=True): - """Merge all the correspondance in one array with all fields - """ + """Merge all the correspondance in one array with all fields""" # Start loading identification again to save in the finals tracks # Load first file self.reset_dataset_cache() self.swap_dataset(self.datasets[0], raw_data=raw_data) # Start create netcdf to agglomerate all eddy - logger.debug('We will create an array (size %d)', self.nb_obs) + logger.debug("We will create an array (size %d)", self.nb_obs) eddies = TrackEddiesObservations( size=self.nb_obs, track_extra_variables=self.current_obs.track_extra_variables, track_array_variables=self.current_obs.track_array_variables, - array_variables=self.current_obs.array_variables, raw_data=raw_data) + array_variables=self.current_obs.array_variables, + raw_data=raw_data, + ) # All the value put at nan, necessary only for all end of track - eddies['cost_association'][:] = default_fillvals['f4'] + eddies["cost_association"][:] = default_fillvals["f4"] # Calculate the index in each tracks, we compute in u4 and translate # in u2 (which are limited to 65535) - logger.debug('Compute global index array (N)') - eddies['n'][:] = uint16( - arange(self.nb_obs, dtype='u4') - self.i_current_by_tracks.repeat(self.nb_obs_by_tracks)) - logger.debug('Compute global track array') - eddies['track'][:] = arange(self.current_id).repeat(self.nb_obs_by_tracks) + logger.debug("Compute global index array (N)") + eddies["n"][:] = uint16( + arange(self.nb_obs, dtype="u4") + - self.i_current_by_tracks.repeat(self.nb_obs_by_tracks) + ) + logger.debug("Compute global track array") + eddies["track"][:] = arange(self.current_id).repeat(self.nb_obs_by_tracks) # Set type of eddy with first file eddies.sign_type = self.current_obs.sign_type # Fields to copy - fields = self.current_obs.obs.dtype.descr + fields = self.current_obs.fields # To know if the track start first_obs_save_in_tracks = zeros(self.i_current_by_tracks.shape, dtype=bool_) @@ -552,11 +669,11 @@ def merge(self, until=-1, raw_data=True): for i, file_name in enumerate(self.datasets[1:]): if until != -1 and i >= until: break - logger.debug('Merge data from %s', file_name) + logger.debug("Merge data from %s", file_name) # Load current file (we begin with second one) self.swap_dataset(file_name, raw_data=raw_data) # We select the list of id which are involve in the correspondance - i_id = self[i]['id'] + i_id = self[i]["id"] # Index where we will write in the final object index_final = self.i_current_by_tracks[i_id] @@ -564,13 +681,14 @@ def merge(self, until=-1, raw_data=True): m_first_obs = ~first_obs_save_in_tracks[i_id] if m_first_obs.any(): # Index in the previous file - index_in = self[i]['in'][m_first_obs] + index_in = self[i]["in"][m_first_obs] # Copy all variable for field in fields: - var = field[0] - if var == 'cost_association': + if field == "cost_association": continue - eddies[var][index_final[m_first_obs]] = self.previous_obs[var][index_in] + eddies[field][index_final[m_first_obs]] = self.previous_obs[field][ + index_in + ] # Increment self.i_current_by_tracks[i_id[m_first_obs]] += 1 # Active this flag, we have only one first by tracks @@ -580,23 +698,23 @@ def merge(self, until=-1, raw_data=True): if self.virtual: # If the flag virtual in correspondance is active, # the previous is virtual - m_virtual = self[i]['virtual'] + m_virtual = self[i]["virtual"] if 
m_virtual.any(): # Incrementing index - self.i_current_by_tracks[i_id[m_virtual]] += self[i]['virtual_length'][m_virtual] + self.i_current_by_tracks[i_id[m_virtual]] += self[i][ + "virtual_length" + ][m_virtual] # Get new index index_final = self.i_current_by_tracks[i_id] # Index in the current file - index_current = self[i]['out'] + index_current = self[i]["out"] + if "cost_association" in eddies.fields: + eddies["cost_association"][index_final - 1] = self[i]["cost_value"] # Copy all variable for field in fields: - var = field[0] - if var == 'cost_association': - eddies[var][index_final - 1] = self[i]['cost_value'] - else: - eddies[var][index_final] = self.current_obs[var][index_current] + eddies[field][index_final] = self.current_obs[field][index_current] # Add increment for each index used self.i_current_by_tracks[i_id] += 1 @@ -609,49 +727,29 @@ def get_unused_data(self, raw_data=False): Returns: Unused Eddies """ - self.reset_dataset_cache() - self.swap_dataset(self.datasets[0], raw_data=raw_data) - nb_dataset = len(self.datasets) - # Get the number of obs unused - nb_obs = 0 - list_mask = list() - has_virtual = 'virtual' in self[0].dtype.names - logger.debug('Count unused data ...') - for i, filename in enumerate(self.datasets): + has_virtual = "virtual" in self[0].dtype.names + eddies = list() + for i, dataset in enumerate(self.datasets): last_dataset = i == (nb_dataset - 1) if has_virtual and not last_dataset: - m_in = ~self[i]['virtual'] + m_in = ~self[i]["virtual"] else: m_in = slice(None) if i == 0: - eddies_used = self[i]['in'] + index_used = self[i]["in"] elif last_dataset: - eddies_used = self[i - 1]['out'] + index_used = self[i - 1]["out"] else: - eddies_used = unique(concatenate((self[i - 1]['out'], self[i]['in'][m_in]))) - if not isinstance(filename, str): - filename = filename.astype(str) - with Dataset(filename) as h: - nb_obs_day = len(h.dimensions['obs']) - m = ones(nb_obs_day, dtype='bool') - m[eddies_used] = False - list_mask.append(m) - nb_obs += m.sum() - logger.debug('Count unused data OK') - eddies = EddiesObservations( - size=nb_obs, - track_extra_variables=self.current_obs.track_extra_variables, - track_array_variables=self.current_obs.track_array_variables, - array_variables=self.current_obs.array_variables, raw_data=raw_data) - j = 0 - for i, dataset in enumerate(self.datasets): - logger.debug('Loaf file : (%d) %s', i, dataset) - current_obs = self.class_method.load_file(dataset, raw_data=raw_data) - if i == 0: - eddies.sign_type = current_obs.sign_type - unused_obs = current_obs.observations[list_mask[i]] - nb = unused_obs.shape[0] - eddies.observations[j:j + nb] = unused_obs - j += nb - return eddies + index_used = unique( + concatenate((self[i - 1]["out"], self[i]["in"][m_in])) + ) + + logger.debug("Load file : %s", dataset) + if self.memory: + with open(dataset, "rb") as h: + current_obs = self.class_method.load_file(h, raw_data=raw_data) + else: + current_obs = self.class_method.load_file(dataset, raw_data=raw_data) + eddies.append(current_obs.index(index_used, reverse=True)) + return EddiesObservations.concatenate(eddies) diff --git a/src/scripts/EddyFinalTracking b/src/scripts/EddyFinalTracking index f147828d..fad4b02c 100644 --- a/src/scripts/EddyFinalTracking +++ b/src/scripts/EddyFinalTracking @@ -3,12 +3,13 @@ """ Track eddy with Identification file produce with EddyIdentification """ +import datetime as dt +import logging +from os import mkdir +from os.path import exists + from py_eddy_tracker import EddyParser from py_eddy_tracker.tracking import 
Correspondances -from os.path import exists -from os import mkdir -import logging -import datetime as dt logger = logging.getLogger("pet") @@ -17,27 +18,24 @@ def usage(): """Usage """ # Run using: - parser = EddyParser( - "Tool to use identification step to compute tracking") - parser.add_argument('nc_file', - help='File of correspondances to reload link ' - 'without tracking computation') - parser.add_argument('--path_out', - default='./', - help='Path, where to write file') - - parser.add_argument('--eddies_long_model', default=None) - parser.add_argument('--eddies_short_model', default=None) - parser.add_argument('--eddies_untracked_model', default=None) - - parser.add_argument('--nb_obs_min', - type=int, - default=28, - help='Minimal length of tracks') + parser = EddyParser("Tool to use identification step to compute tracking") + parser.add_argument( + "nc_file", + help="File of correspondances to reload link " "without tracking computation", + ) + parser.add_argument("--path_out", default="./", help="Path, where to write file") + + parser.add_argument("--eddies_long_model", default=None) + parser.add_argument("--eddies_short_model", default=None) + parser.add_argument("--eddies_untracked_model", default=None) + + parser.add_argument( + "--nb_obs_min", type=int, default=28, help="Minimal length of tracks" + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": CONFIG = usage() # Create output directory @@ -49,16 +47,26 @@ if __name__ == '__main__': CORRESPONDANCES = Correspondances.load(CONFIG.nc_file) - logger.info('Start merging') + logger.info("Start merging") CORRESPONDANCES.prepare_merging() - logger.info('The longest tracks have %d observations', CORRESPONDANCES.nb_obs_by_tracks.max()) - logger.info('The mean length is %d observations before filtering', CORRESPONDANCES.nb_obs_by_tracks.mean()) + logger.info( + "The longest tracks have %d observations", + CORRESPONDANCES.nb_obs_by_tracks.max(), + ) + logger.info( + "The mean length is %d observations before filtering", + CORRESPONDANCES.nb_obs_by_tracks.mean(), + ) if CONFIG.eddies_untracked_model is None: - CONFIG.eddies_untracked_model = '%(path)s/%(sign_type)s_%(prod_time)s_untracked.nc' - CORRESPONDANCES.get_unused_data(raw_data=True).write_file(path=SAVE_DIR, filename=CONFIG.eddies_untracked_model) + CONFIG.eddies_untracked_model = ( + "%(path)s/%(sign_type)s_%(prod_time)s_untracked.nc" + ) + CORRESPONDANCES.get_unused_data(raw_data=True).write_file( + path=SAVE_DIR, filename=CONFIG.eddies_untracked_model + ) SHORT_CORRESPONDANCES = CORRESPONDANCES._copy() SHORT_CORRESPONDANCES.shorter_than(size_max=NB_OBS_MIN) @@ -70,21 +78,28 @@ if __name__ == '__main__': # We flag obs if CORRESPONDANCES.virtual: - FINAL_EDDIES['virtual'][:] = FINAL_EDDIES['time'] == 0 - FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES['virtual'] == 1) - SHORT_TRACK['virtual'][:] = SHORT_TRACK['time'] == 0 - SHORT_TRACK.filled_by_interpolation(SHORT_TRACK['virtual'] == 1) + FINAL_EDDIES["virtual"][:] = FINAL_EDDIES["time"] == 0 + FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES["virtual"] == 1) + SHORT_TRACK["virtual"][:] = SHORT_TRACK["time"] == 0 + SHORT_TRACK.filled_by_interpolation(SHORT_TRACK["virtual"] == 1) # Total running time FULL_TIME = dt.datetime.now() - START_TIME - logger.info('Duration : %s', FULL_TIME) + logger.info("Duration : %s", FULL_TIME) - logger.info('Longer track saved have %d obs', CORRESPONDANCES.nb_obs_by_tracks.max()) - logger.info('The mean length is %d observations after filtering', 
CORRESPONDANCES.nb_obs_by_tracks.mean()) + logger.info( + "Longer track saved have %d obs", CORRESPONDANCES.nb_obs_by_tracks.max() + ) + logger.info( + "The mean length is %d observations after filtering", + CORRESPONDANCES.nb_obs_by_tracks.mean(), + ) if CONFIG.eddies_long_model is None: - CONFIG.eddies_long_model = '%(path)s/%(sign_type)s_%(prod_time)s.nc' + CONFIG.eddies_long_model = "%(path)s/%(sign_type)s_%(prod_time)s.nc" if CONFIG.eddies_short_model is None: - CONFIG.eddies_short_model = '%(path)s/%(sign_type)s_%(prod_time)s_track_too_short.nc' + CONFIG.eddies_short_model = ( + "%(path)s/%(sign_type)s_%(prod_time)s_track_too_short.nc" + ) FINAL_EDDIES.write_file(filename=CONFIG.eddies_long_model, path=SAVE_DIR) SHORT_TRACK.write_file(filename=CONFIG.eddies_short_model, path=SAVE_DIR) diff --git a/src/scripts/EddyMergeCorrespondances b/src/scripts/EddyMergeCorrespondances index aa6321eb..7598b47a 100644 --- a/src/scripts/EddyMergeCorrespondances +++ b/src/scripts/EddyMergeCorrespondances @@ -3,12 +3,13 @@ """ Track eddy with Identification file produce with EddyIdentification """ +import datetime as dt +import logging +from os import mkdir +from os.path import dirname, exists + from py_eddy_tracker import EddyParser from py_eddy_tracker.tracking import Correspondances -from os.path import exists, dirname -from os import mkdir -import logging -import datetime as dt logger = logging.getLogger("pet") @@ -17,17 +18,17 @@ def usage(): """Usage """ # Run using: - parser = EddyParser( - "Tool to use identification step to compute tracking") - parser.add_argument('nc_file', - nargs='+', - help='File of correspondances to reload link ' - 'without tracking computation') - parser.add_argument('path_out', help='Path, where to write file') + parser = EddyParser("Tool to use identification step to compute tracking") + parser.add_argument( + "nc_file", + nargs="+", + help="File of correspondances to reload link " "without tracking computation", + ) + parser.add_argument("path_out", help="Path, where to write file") return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": CONFIG = usage() # Create output directory @@ -36,7 +37,7 @@ if __name__ == '__main__': START_TIME = dt.datetime.now() CORRESPONDANCES = Correspondances.load(CONFIG.nc_file[0]) - logger.info('Start merging') + logger.info("Start merging") for i in CONFIG.nc_file[1:]: CORRESPONDANCES.merge_correspondance(Correspondances.load(i)) diff --git a/src/scripts/EddySubSetter b/src/scripts/EddySubSetter index 3a4b0f12..6cace388 100644 --- a/src/scripts/EddySubSetter +++ b/src/scripts/EddySubSetter @@ -3,52 +3,102 @@ """ Subset eddy Dataset """ +import logging + from py_eddy_tracker import EddyParser from py_eddy_tracker.observations.tracking import TrackEddiesObservations -import logging logger = logging.getLogger("pet") def id_parser(): - parser = EddyParser('Eddy Subsetter') - parser.add_argument('filename') - parser.add_argument('filename_out') - - group = parser.add_argument_group('Extraction options') - group.add_argument('-p', '--period', nargs=2, type=int, - help='Start day and end day, if it s negative value we will add to day min and add to day max, if 0 it s not use') - group.add_argument('-l', '--length', nargs=2, type=int, - help='Minimal and maximal quantity of observation for one track, ones bounds could be negative, it will be not use') - group.add_argument('-f', '--full_path', action='store_true', - help='Extract path, if one obs or more are selected') - group.add_argument('-d', 
'--remove_incomplete', action='store_true', - help='Extract path only if all obs are selected') - group.add_argument('--reject_virtual', action='store_true', - help="If there are only virtual observation in selection, we don't select track") - group.add_argument('-a', '--area', nargs=4, type=float, - metavar=('llcrnrlon', 'llcrnrlat', 'urcrnrlon', 'urcrnrlat'), - help='Coordinates of bounding to extract' - ) - group.add_argument('--direction', choices=['E', 'W', 'S', 'N'], - help='Select only track which have an end point which go in this direction') - group.add_argument('--minimal_degrees_displacement_in_direction', type=float, - help='Minimal displacement in direction specified in --directio options') - group.add_argument('--select_first_observation_in_box', type=float, - help='Select only the first obs in each box for each tracks, value specified must be resolution') - group.add_argument('--remove_var', nargs='+', type=str, help='remove all listed variable') - group.add_argument('--include_var', nargs='+', type=str, help='use only listed variable, remove_var will be ignored') - group.add_argument('-i', '--ids', nargs='+', type=int, help='List of tracks which will be extract') - - group = parser.add_argument_group('General options') - group.add_argument('--sort_time', action='store_true', help='sort all observation with time') - - parser.add_argument('-n', '--no_raw_mode', action='store_true', - help='Uncompress all data, could be create a memory error for huge file, but is safer for extern file of py eddy tracker') + parser = EddyParser("Eddy Subsetter") + parser.add_argument("filename") + parser.add_argument("filename_out") + + group = parser.add_argument_group("Extraction options") + group.add_argument( + "-p", + "--period", + nargs=2, + type=int, + help="Start day and end day; a negative value is added to the dataset first/last day, 0 means the bound is not used", + ) + group.add_argument( + "-l", + "--length", + nargs=2, + type=int, + help="Minimal and maximal number of observations for one track; a negative bound is ignored", + ) + group.add_argument( + "-f", + "--full_path", + action="store_true", + help="Extract the full path if one or more obs are selected", + ) + group.add_argument( + "-d", + "--remove_incomplete", + action="store_true", + help="Extract the path only if all obs are selected", + ) + group.add_argument( + "--reject_virtual", + action="store_true", + help="If the selection contains only virtual observations, the track is not selected", + ) + group.add_argument( + "-a", + "--area", + nargs=4, + type=float, + metavar=("llcrnrlon", "llcrnrlat", "urcrnrlon", "urcrnrlat"), + help="Coordinates of the bounding box to extract", + ) + group.add_argument( + "--direction", + choices=["E", "W", "S", "N"], + help="Select only tracks whose end point goes in this direction", + ) + group.add_argument( + "--minimal_degrees_displacement_in_direction", + type=float, + help="Minimal displacement in the direction specified by the --direction option", + ) + group.add_argument( + "--select_first_observation_in_box", + type=float, + help="Select only the first obs in each box for each track; the value specified is the box resolution", + ) + group.add_argument( + "--remove_var", nargs="+", type=str, help="remove all listed variables" + ) + group.add_argument( + "--include_var", + nargs="+", + type=str, + help="use only listed variables; remove_var will be ignored", + ) + group.add_argument( + "-i", "--ids", nargs="+", type=int, help="List of track ids to extract" + ) + + group = 
parser.add_argument_group("General options") + group.add_argument( + "--sort_time", action="store_true", help="sort all observation with time" + ) + + parser.add_argument( + "-n", + "--no_raw_mode", + action="store_true", + help="Uncompress all data, could be create a memory error for huge file, but is safer for extern file of py eddy tracker", + ) return parser -if __name__ == '__main__': +if __name__ == "__main__": args = id_parser().parse_args() # Original dataset @@ -69,37 +119,46 @@ if __name__ == '__main__': # Select with a start date and end date if args.period is not None: - dataset = dataset.extract_with_period(args.period, full_path=args.full_path, - remove_incomplete=args.remove_incomplete, - reject_virtual=args.reject_virtual) + dataset = dataset.extract_with_period( + args.period, + full_path=args.full_path, + remove_incomplete=args.remove_incomplete, + reject_virtual=args.reject_virtual, + ) # Select track which go through an area if args.area is not None: - area = dict(llcrnrlon=args.area[0], - llcrnrlat=args.area[1], - urcrnrlon=args.area[2], - urcrnrlat=args.area[3], - ) - dataset = dataset.extract_with_area(area, full_path=args.full_path, - remove_incomplete=args.remove_incomplete, - reject_virtual=args.reject_virtual) + area = dict( + llcrnrlon=args.area[0], + llcrnrlat=args.area[1], + urcrnrlon=args.area[2], + urcrnrlat=args.area[3], + ) + dataset = dataset.extract_with_area( + area, + full_path=args.full_path, + remove_incomplete=args.remove_incomplete, + reject_virtual=args.reject_virtual, + ) # Select only track which go in the direction specified if args.direction: if args.minimal_degrees_displacement_in_direction: dataset = dataset.extract_in_direction( - args.direction, - value=args.minimal_degrees_displacement_in_direction) + args.direction, value=args.minimal_degrees_displacement_in_direction + ) else: dataset = dataset.extract_in_direction(args.direction) if args.select_first_observation_in_box: - dataset = dataset.extract_first_obs_in_box(res=args.select_first_observation_in_box) + dataset = dataset.extract_first_obs_in_box( + res=args.select_first_observation_in_box + ) if args.sort_time: - logger.debug('start sorting ...') - dataset.obs.sort(order=['time', 'lon', 'lat']) - logger.debug('end sorting') + logger.debug("start sorting ...") + dataset.obs.sort(order=["time", "lon", "lat"]) + logger.debug("end sorting") # if no data, no output will be written if len(dataset) == 0: diff --git a/src/scripts/EddyTracking b/src/scripts/EddyTracking deleted file mode 100644 index e8c23f12..00000000 --- a/src/scripts/EddyTracking +++ /dev/null @@ -1,270 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Track eddy with Identification file produce with EddyIdentification -""" -from py_eddy_tracker import EddyParser -from yaml import load as yaml_load -from py_eddy_tracker.tracking import Correspondances -from os.path import exists, dirname, basename -from os import mkdir -from re import compile as re_compile -from os.path import join as join_path -from numpy import bytes_, empty, unique -from netCDF4 import Dataset -from datetime import datetime -from glob import glob -import logging - - -logger = logging.getLogger("pet") - - -def browse_dataset_in( - data_dir, - files_model, - date_regexp, - date_model, - start_date=None, - end_date=None, - sub_sampling_step=1, - files=None, -): - pattern_regexp = re_compile(".*/" + date_regexp) - if files is not None: - filenames = bytes_(files) - else: - full_path = join_path(data_dir, files_model) - logger.info("Search 
files : %s", full_path) - filenames = bytes_(glob(full_path)) - - dataset_list = empty( - len(filenames), dtype=[("filename", "S500"), ("date", "datetime64[D]"),] - ) - dataset_list["filename"] = filenames - - logger.info("%s grids available", dataset_list.shape[0]) - mode_attrs = False - if "(" not in date_regexp: - logger.debug("Attrs date : %s", date_regexp) - mode_attrs = date_regexp.strip().split(":") - else: - logger.debug("Pattern date : %s", date_regexp) - - for item in dataset_list: - str_date = None - if mode_attrs: - with Dataset(item["filename"].decode("utf-8")) as h: - if len(mode_attrs) == 1: - str_date = getattr(h, mode_attrs[0]) - else: - str_date = getattr(h.variables[mode_attrs[0]], mode_attrs[1]) - else: - result = pattern_regexp.match(str(item["filename"])) - if result: - str_date = result.groups()[0] - - if str_date is not None: - item["date"] = datetime.strptime(str_date, date_model).date() - - dataset_list.sort(order=["date", "filename"]) - - steps = unique(dataset_list["date"][1:] - dataset_list["date"][:-1]) - if len(steps) > 1: - raise Exception("Several days steps in grid dataset %s" % steps) - - if sub_sampling_step != 1: - logger.info("Grid subsampling %d", sub_sampling_step) - dataset_list = dataset_list[::sub_sampling_step] - - if start_date is not None or end_date is not None: - logger.info( - "Available grid from %s to %s", - dataset_list[0]["date"], - dataset_list[-1]["date"], - ) - logger.info("Filtering grid by time %s, %s", start_date, end_date) - mask = (dataset_list["date"] >= start_date) * (dataset_list["date"] <= end_date) - - dataset_list = dataset_list[mask] - return dataset_list - - -def usage(): - """Usage - """ - # Run using: - parser = EddyParser("Tool to use identification step to compute tracking") - parser.add_argument("yaml_file", help="Yaml file to configure py-eddy-tracker") - parser.add_argument("--correspondance_in", help="Filename of saved correspondance") - parser.add_argument("--correspondance_out", help="Filename to save correspondance") - parser.add_argument( - "--save_correspondance_and_stop", - action="store_true", - help="Stop tracking after correspondance computation," - " merging can be done with EddyFinalTracking", - ) - parser.add_argument( - "--zarr", action="store_true", help="Output will be wrote in zarr" - ) - parser.add_argument("--unraw", action="store_true", help="Load unraw data") - parser.add_argument( - "--blank_period", - type=int, - default=0, - help="Nb of detection which will not use at the end of the period", - ) - args = parser.parse_args() - - # Read yaml configuration file - with open(args.yaml_file, "r") as stream: - config = yaml_load(stream) - if args.correspondance_in is not None and not exists(args.correspondance_in): - args.correspondance_in = None - return ( - config, - args.save_correspondance_and_stop, - args.correspondance_in, - args.correspondance_out, - args.blank_period, - args.zarr, - not args.unraw, - ) - - -if __name__ == "__main__": - ( - CONFIG, - SAVE_STOP, - CORRESPONDANCES_IN, - CORRESPONDANCES_OUT, - BLANK_PERIOD, - ZARR, - RAW, - ) = usage() - # Create output directory - SAVE_DIR = CONFIG["PATHS"].get("SAVE_DIR", None) - if SAVE_DIR is not None and not exists(SAVE_DIR): - mkdir(SAVE_DIR) - - YAML_CORRESPONDANCES_OUT = CONFIG["PATHS"].get("CORRESPONDANCES_OUT", None) - if CORRESPONDANCES_IN is None: - CORRESPONDANCES_IN = CONFIG["PATHS"].get("CORRESPONDANCES_IN", None) - if CORRESPONDANCES_OUT is None: - CORRESPONDANCES_OUT = YAML_CORRESPONDANCES_OUT - if YAML_CORRESPONDANCES_OUT is 
None and CORRESPONDANCES_OUT is None: - CORRESPONDANCES_OUT = "{path}/{sign_type}_correspondances.nc" - - CLASS = None - if "CLASS" in CONFIG: - CLASSNAME = CONFIG["CLASS"]["CLASS"] - CLASS = getattr( - __import__(CONFIG["CLASS"]["MODULE"], globals(), locals(), CLASSNAME), - CLASSNAME, - ) - - NB_VIRTUAL_OBS_MAX_BY_SEGMENT = int(CONFIG.get("VIRTUAL_LENGTH_MAX", 0)) - - if isinstance(CONFIG["PATHS"]["FILES_PATTERN"], list): - DATASET_LIST = browse_dataset_in( - data_dir=None, - files_model=None, - files=CONFIG["PATHS"]["FILES_PATTERN"], - date_regexp=".*c_([0-9]*?).[nz].*", - date_model="%Y%m%d", - ) - else: - DATASET_LIST = browse_dataset_in( - data_dir=dirname(CONFIG["PATHS"]["FILES_PATTERN"]), - files_model=basename(CONFIG["PATHS"]["FILES_PATTERN"]), - date_regexp=".*c_([0-9]*?).[nz].*", - date_model="%Y%m%d", - ) - - if BLANK_PERIOD > 0: - DATASET_LIST = DATASET_LIST[:-BLANK_PERIOD] - logger.info("Last %d files will be pop", BLANK_PERIOD) - - START_TIME = datetime.now() - logger.info("Start tracking on %d files", len(DATASET_LIST)) - - NB_OBS_MIN = int(CONFIG.get("TRACK_DURATION_MIN", 14)) - if NB_OBS_MIN > len(DATASET_LIST): - raise Exception( - "Input file number (%s) is shorter than TRACK_DURATION_MIN (%s)." - % (len(DATASET_LIST), NB_OBS_MIN) - ) - - CORRESPONDANCES = Correspondances( - datasets=DATASET_LIST["filename"], - virtual=NB_VIRTUAL_OBS_MAX_BY_SEGMENT, - class_method=CLASS, - previous_correspondance=CORRESPONDANCES_IN, - ) - CORRESPONDANCES.track() - logger.info("Track finish") - - logger.info("Start merging") - DATE_START, DATE_STOP = CORRESPONDANCES.period - DICT_COMPLETION = dict( - date_start=DATE_START, - date_stop=DATE_STOP, - date_prod=START_TIME, - path=SAVE_DIR, - sign_type=CORRESPONDANCES.current_obs.sign_legend, - ) - - CORRESPONDANCES.save(CORRESPONDANCES_OUT, DICT_COMPLETION) - if SAVE_STOP: - exit() - - # Merge correspondance, only do if we stop and store just after compute of correspondance - CORRESPONDANCES.prepare_merging() - - logger.info( - "Longer track saved have %d obs", CORRESPONDANCES.nb_obs_by_tracks.max() - ) - logger.info( - "The mean length is %d observations before filtering", - CORRESPONDANCES.nb_obs_by_tracks.mean(), - ) - - CORRESPONDANCES.get_unused_data(raw_data=RAW).write_file( - path=SAVE_DIR, filename="%(path)s/%(sign_type)s_untracked.nc", zarr_flag=ZARR - ) - - SHORT_CORRESPONDANCES = CORRESPONDANCES._copy() - SHORT_CORRESPONDANCES.shorter_than(size_max=NB_OBS_MIN) - - CORRESPONDANCES.longer_than(size_min=NB_OBS_MIN) - - FINAL_EDDIES = CORRESPONDANCES.merge(raw_data=RAW) - SHORT_TRACK = SHORT_CORRESPONDANCES.merge(raw_data=RAW) - - # We flag obs - if CORRESPONDANCES.virtual: - FINAL_EDDIES["virtual"][:] = FINAL_EDDIES["time"] == 0 - FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES["virtual"] == 1) - SHORT_TRACK["virtual"][:] = SHORT_TRACK["time"] == 0 - SHORT_TRACK.filled_by_interpolation(SHORT_TRACK["virtual"] == 1) - - # Total running time - FULL_TIME = datetime.now() - START_TIME - logger.info("Mean duration by loop : %s", FULL_TIME / (len(DATASET_LIST) - 1)) - logger.info("Duration : %s", FULL_TIME) - - logger.info( - "Longer track saved have %d obs", CORRESPONDANCES.nb_obs_by_tracks.max() - ) - logger.info( - "The mean length is %d observations after filtering", - CORRESPONDANCES.nb_obs_by_tracks.mean(), - ) - - FINAL_EDDIES.write_file(path=SAVE_DIR, zarr_flag=ZARR) - SHORT_TRACK.write_file( - filename="%(path)s/%(sign_type)s_track_too_short.nc", - path=SAVE_DIR, - zarr_flag=ZARR, - ) - diff --git a/src/scripts/EddyTranslate 
b/src/scripts/EddyTranslate index 8d9fea3e..a0060e9b 100644 --- a/src/scripts/EddyTranslate +++ b/src/scripts/EddyTranslate @@ -3,21 +3,26 @@ """ Translate eddy Dataset """ -from py_eddy_tracker import EddyParser -from py_eddy_tracker.observations.observation import EddiesObservations from netCDF4 import Dataset import zarr +from py_eddy_tracker import EddyParser +from py_eddy_tracker.observations.observation import EddiesObservations + def id_parser(): - parser = EddyParser('Eddy Translate, Translate eddies from netcdf to zarr or from zarr to netcdf') - parser.add_argument('filename_in') - parser.add_argument('filename_out') + parser = EddyParser( + "Eddy Translate, Translate eddies from netcdf to zarr or from zarr to netcdf" + ) + parser.add_argument("filename_in") + parser.add_argument("filename_out") + parser.add_argument("--unraw", action="store_true", help="Load unraw data, use only for netcdf." + "If unraw is active, netcdf is loaded without apply scalefactor and add_offset.") return parser def is_nc(filename): - return filename.endswith('.nc') + return filename.endswith(".nc") def get_variable_name(filename): @@ -29,23 +34,25 @@ def get_variable_name(filename): return list(h.keys()) -def get_variable(filename, varname): +def get_variable(filename, varname, raw=True): if is_nc(filename): - dataset = EddiesObservations.load_from_netcdf(filename, raw_data=True, include_vars=(varname,)) + dataset = EddiesObservations.load_from_netcdf( + filename, raw_data=raw, include_vars=(varname,) + ) else: dataset = EddiesObservations.load_from_zarr(filename, include_vars=(varname,)) return dataset -if __name__ == '__main__': +if __name__ == "__main__": args = id_parser().parse_args() variables = get_variable_name(args.filename_in) if not is_nc(args.filename_out): - h = zarr.open(args.filename_out, 'w') + h = zarr.open(args.filename_out, "w") for varname in variables: - get_variable(args.filename_in, varname).to_zarr(h) + get_variable(args.filename_in, varname, raw=not args.unraw).to_zarr(h) else: - with Dataset(args.filename_out, 'w') as h: + with Dataset(args.filename_out, "w") as h: for varname in variables: - get_variable(args.filename_in, varname).to_netcdf(h) \ No newline at end of file + get_variable(args.filename_in, varname, raw=not args.unraw).to_netcdf(h) diff --git a/src/scripts/GUIEddy b/src/scripts/GUIEddy deleted file mode 100644 index 6f9673f4..00000000 --- a/src/scripts/GUIEddy +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Gui for eddy atlas -""" -from argparse import ArgumentParser -from py_eddy_tracker.observations.tracking import TrackEddiesObservations -from py_eddy_tracker.gui import GUI - - -def parser(): - parser = ArgumentParser("Eddy atlas GUI") - parser.add_argument("atlas", nargs="+") - parser.add_argument("--med", action='store_true') - return parser.parse_args() - - -if __name__ == "__main__": - args = parser() - atlas = { - dataset: TrackEddiesObservations.load_file(dataset) for dataset in args.atlas - } - g = GUI(**atlas) - if args.med: - g.med() - g.show() diff --git a/tests/test_generic.py b/tests/test_generic.py new file mode 100644 index 00000000..ee2d7881 --- /dev/null +++ b/tests/test_generic.py @@ -0,0 +1,51 @@ +from numpy import arange, array, nan, ones, zeros + +from py_eddy_tracker.generic import cumsum_by_track, simplify, wrap_longitude + + +def test_simplify(): + x = arange(10, dtype="f4") + y = zeros(10, dtype="f4") + # Will jump one value on two + x_, y_ = simplify(x, y, precision=1) + assert x_.shape[0] == 5 + x_, 
y_ = simplify(x, y, precision=0.99) + assert x_.shape[0] == 10 + # check nan management + x[4] = nan + x_, y_ = simplify(x, y, precision=1) + assert x_.shape[0] == 6 + x[3] = nan + x_, y_ = simplify(x, y, precision=1) + assert x_.shape[0] == 6 + x[:4] = nan + x_, y_ = simplify(x, y, precision=1) + assert x_.shape[0] == 3 + x[:] = nan + x_, y_ = simplify(x, y, precision=1) + assert x_.shape[0] == 0 + + +def test_cumsum_by_track(): + a = ones(10, dtype="i4") * 2 + track = array([1, 1, 2, 2, 2, 2, 44, 44, 44, 48]) + assert (cumsum_by_track(a, track) == [2, 4, 2, 4, 6, 8, 2, 4, 6, 2]).all() + + +def test_wrapping(): + y = x = arange(-5, 5, dtype="f4") + x_, _ = wrap_longitude(x, y, ref=-10) + assert (x_ == x).all() + x_, _ = wrap_longitude(x, y, ref=1) + assert x.size == x_.size + assert (x_[6:] == x[6:]).all() + assert (x_[:6] == x[:6] + 360).all() + x_, _ = wrap_longitude(x, y, ref=1, cut=True) + assert x.size + 3 == x_.size + assert (x_[6 + 3 :] == x[6:]).all() + assert (x_[:7] == x[:7] + 360).all() + + # FIXME Need evolution in wrap_longitude + # x %= 360 + # x_, _ = wrap_longitude(x, y, ref=-10, cut=True) + # assert x.size == x_.size diff --git a/tests/test_grid.py b/tests/test_grid.py index 32ae2721..0e6dd586 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,11 +1,21 @@ -from py_eddy_tracker.dataset.grid import RegularGridDataset -from py_eddy_tracker.data import get_path from matplotlib.path import Path +from numpy import arange, array, isnan, ma, nan, ones, zeros from pytest import approx -G = RegularGridDataset(get_path("mask_1_60.nc"), "lon", "lat") +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import RegularGridDataset + +G = RegularGridDataset(get_demo_path("mask_1_60.nc"), "lon", "lat") X = 0.025 -contour = Path(((-X, 0), (X, 0), (X, X), (-X, X), (-X, 0),)) +contour = Path( + ( + (-X, 0), + (X, 0), + (X, X), + (-X, X), + (-X, 0), + ) +) # contour @@ -44,3 +54,71 @@ def test_bounds(): x0, x1, y0, y1 = G.bounds assert x0 == -1 / 120.0 and x1 == 360 - 1 / 120 assert y0 == approx(-90 - 1 / 120.0) and y1 == approx(90 - 1 / 120) + + +def test_interp(): + # Fake grid + g = RegularGridDataset.with_array( + coordinates=("x", "y"), + datas=dict( + z=ma.array(((0, 1), (2, 3)), dtype="f4"), + x=array((0, 20)), + y=array((0, 10)), + ), + centered=True, + ) + x0, y0 = array((10,)), array((5,)) + x1, y1 = array((15,)), array((5,)) + # outside but usable with nearest + x2, y2 = array((25,)), array((5,)) + # Outside for any interpolation + x3, y3 = array((25,)), array((16,)) + x4, y4 = array((55,)), array((25,)) + # Interp nearest + assert g.interp("z", x0, y0, method="nearest") == 0 + assert g.interp("z", x1, y1, method="nearest") == 2 + assert isnan(g.interp("z", x4, y4, method="nearest")) + assert g.interp("z", x2, y2, method="nearest") == 2 + assert isnan(g.interp("z", x3, y3, method="nearest")) + + # Interp bilinear + assert g.interp("z", x0, y0) == 1.5 + assert g.interp("z", x1, y1) == 2 + assert isnan(g.interp("z", x2, y2)) + + +def test_convolution(): + """ + Add some dummy check on convolution filter + """ + # Fake grid + z = ma.array( + arange(12).reshape((-1, 1)) * arange(10).reshape((1, -1)), + mask=zeros((12, 10), dtype="bool"), + dtype="f4", + ) + g = RegularGridDataset.with_array( + coordinates=("x", "y"), + datas=dict( + z=z, + x=arange(0, 6, 0.5), + y=arange(0, 5, 0.5), + ), + centered=True, + ) + + def kernel_func(lat): + return ones((3, 3)) + + # After transpose we must get same result + d = 
g.convolve_filter_with_dynamic_kernel("z", kernel_func) + assert (d.T[:9, :9] == d[:9, :9]).all() + # We mask one value and check convolution result + z.mask[2, 2] = True + d = g.convolve_filter_with_dynamic_kernel("z", kernel_func) + assert d[1, 1] == z[:3, :3].sum() / 8 + # Add nan and check only nearest value is contaminate + z[2, 2] = nan + d = g.convolve_filter_with_dynamic_kernel("z", kernel_func) + assert not isnan(d[0, 0]) + assert isnan(d[1:4, 1:4]).all() diff --git a/tests/test_id.py b/tests/test_id.py index 1469cad0..c69a5a26 100644 --- a/tests/test_id.py +++ b/tests/test_id.py @@ -1,14 +1,15 @@ -from py_eddy_tracker.dataset.grid import RegularGridDataset -from py_eddy_tracker.data import get_path from datetime import datetime +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import RegularGridDataset + g = RegularGridDataset( - get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" ) def test_id(): g.add_uv("adt") a, c = g.eddy_identification("adt", "u", "v", datetime(2019, 2, 23)) - assert len(a) == 35 + assert len(a) == 36 assert len(c) == 36 diff --git a/tests/test_network.py b/tests/test_network.py new file mode 100644 index 00000000..5cd9b4cd --- /dev/null +++ b/tests/test_network.py @@ -0,0 +1,15 @@ +from py_eddy_tracker.observations.network import Network + + +def test_group_translate(): + translate = Network.group_translator(5, ((0, 1), (0, 2), (1, 3))) + assert (translate == [3, 3, 3, 3, 4]).all() + + translate = Network.group_translator(5, ((1, 3), (0, 1), (0, 2))) + assert (translate == [3, 3, 3, 3, 4]).all() + + translate = Network.group_translator(8, ((1, 3), (2, 3), (2, 4), (5, 6), (4, 5))) + assert (translate == [0, 6, 6, 6, 6, 6, 6, 7]).all() + + translate = Network.group_translator(6, ((0, 1), (0, 2), (1, 3), (4, 5))) + assert (translate == [3, 3, 3, 3, 5, 5]).all() diff --git a/tests/test_obs.py b/tests/test_obs.py index 8a38a96a..a912e06b 100644 --- a/tests/test_obs.py +++ b/tests/test_obs.py @@ -1,8 +1,20 @@ +import zarr + +from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.observations.observation import EddiesObservations -from py_eddy_tracker.data import get_path -a = EddiesObservations.load_file(get_path("Anticyclonic_20190223.nc")) -c = EddiesObservations.load_file(get_path("Cyclonic_20190223.nc")) +a_filename, c_filename = ( + get_demo_path("Anticyclonic_20190223.nc"), + get_demo_path("Cyclonic_20190223.nc"), +) +a = EddiesObservations.load_file(a_filename) +a_raw = EddiesObservations.load_file(a_filename, raw_data=True) +memory_store = zarr.group() +# Dataset was raw loaded from netcdf and save in zarr +a_raw.to_zarr(memory_store, chunck_size=100000) +# We load zarr data without raw option +a_zarr = EddiesObservations.load_from_zarr(memory_store) +c = EddiesObservations.load_file(c_filename) def test_merge(): @@ -10,5 +22,15 @@ def test_merge(): assert len(new) == len(a) + len(c) -# def test_write(): -# with Dataset +def test_zarr_raw(): + assert a == a_zarr + + +def test_index(): + a_nc_subset = EddiesObservations.load_file( + a_filename, indexs=dict(obs=slice(500, 1000)) + ) + a_zarr_subset = EddiesObservations.load_from_zarr( + memory_store, indexs=dict(obs=slice(500, 1000)), buffer_size=50 + ) + assert a_nc_subset == a_zarr_subset diff --git a/tests/test_poly.py b/tests/test_poly.py new file mode 100644 index 00000000..a780f64d --- /dev/null +++ b/tests/test_poly.py @@ -0,0 +1,53 @@ +from 
numpy import array, pi, roll +from pytest import approx + +from py_eddy_tracker.poly import ( + convex, + fit_circle, + get_convex_hull, + poly_area_vertice, + visvalingam, +) + +# Vertices for next test +V = array(((2, 2, 3, 3, 2), (-10, -9, -9, -10, -10))) +V_concave = array(((2, 2, 2.5, 3, 3, 2), (-10, -9, -9.5, -9, -10, -10))) + + +def test_poly_area(): + assert 1 == poly_area_vertice(V.T) + + +def test_fit_circle(): + x0, y0, r, err = fit_circle(*V) + assert x0 == approx(2.5, rel=1e-10) + assert y0 == approx(-9.5, rel=1e-10) + assert r == approx(2**0.5 / 2, rel=1e-10) + assert err == approx((1 - 2 / pi) * 100, rel=1e-10) + + +def test_convex(): + assert convex(*V) is True + assert convex(*V[::-1]) is True + assert convex(*V_concave) is False + assert convex(*V_concave[::-1]) is False + + +def test_convex_hull(): + assert convex(*get_convex_hull(*V_concave)) is True + + +def test_visvalingam(): + x = array([1, 2, 3, 4, 5, 6.75, 6, 1]) + y = array([-0.5, -1.5, -1, -1.75, -1, -1, -0.5, -0.5]) + x_target = [1, 2, 3, 4, 6, 1] + y_target = [-0.5, -1.5, -1, -1.75, -0.5, -0.5] + x_, y_ = visvalingam(x, y, 6) + assert (x_target == x_).all() + assert (y_target == y_).all() + x_, y_ = visvalingam(x[:-1], y[:-1], 6) + assert (x_target == x_).all() + assert (y_target == y_).all() + x_, y_ = visvalingam(roll(x, 2), roll(y, 2), 6) + assert (x_target[:-1] == x_[1:]).all() + assert (y_target[:-1] == y_[1:]).all() diff --git a/tests/test_track.py b/tests/test_track.py new file mode 100644 index 00000000..f7e83786 --- /dev/null +++ b/tests/test_track.py @@ -0,0 +1,50 @@ +from netCDF4 import Dataset +import zarr + +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker +from py_eddy_tracker.observations.observation import EddiesObservations +from py_eddy_tracker.tracking import Correspondances + +filename = get_demo_path("Anticyclonic_20190223.nc") +a0 = EddiesObservations.load_file(filename) +a1 = a0.copy() + + +def test_area_tracking_parameter(): + delta = 0.2 + # All eddies will be shift of delta in longitude and latitude + for k in ( + "lon", + "lon_max", + "contour_lon_s", + "contour_lon_e", + "lat", + "lat_max", + "contour_lat_s", + "contour_lat_e", + ): + a1[k][:] -= delta + a1.time[:] += 1 + # wrote in memory a0 and a1 + h0, h1 = zarr.group(), zarr.group() + a0.to_zarr(h0), a1.to_zarr(h1) + cmin = 0.5 + class_kw = dict(cmin=cmin) + c = Correspondances(datasets=(h0, h1), class_method=AreaTracker, class_kw=class_kw) + c.track() + c.prepare_merging() + # We have now an eddy object + eddies_tracked = c.merge(raw_data=False) + cost = eddies_tracked.cost_association + m = cost < 1 + assert cost[m].max() <= (1 - cmin) + + # Try to save netcdf + with Dataset("tata", mode="w", diskless=True) as h: + c.to_netcdf(h) + c_reloaded = Correspondances.from_netcdf(h) + assert class_kw == c_reloaded.class_kw + + # test access to the lifetime (item) + eddies_tracked["lifetime"] diff --git a/versioneer.py b/versioneer.py index 2b545405..1e3753e6 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,5 @@ -# Version: 0.18 + +# Version: 0.29 """The Versioneer - like a rocketeer, but for versions. @@ -6,18 +7,14 @@ ============== * like a rocketeer, but for versions! 
-* https://github.com/warner/python-versioneer +* https://github.com/python-versioneer/python-versioneer * Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based +* License: Public Domain (Unlicense) +* Compatible with: Python 3.7, 3.8, 3.9, 3.10, 3.11 and pypy3 +* [![Latest Version][pypi-image]][pypi-url] +* [![Build Status][travis-image]][travis-url] + +This is a tool for managing a recorded version number in setuptools-based python projects. The goal is to remove the tedious and error-prone "update the embedded version string" step from your release process. Making a new release should be as easy as recording a new tag in your version-control @@ -26,9 +23,38 @@ ## Quick Install -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results +Versioneer provides two installation modes. The "classic" vendored mode installs +a copy of versioneer into your repository. The experimental build-time dependency mode +is intended to allow you to skip this step and simplify the process of upgrading. + +### Vendored mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) + * Note that you will need to add `tomli; python_version < "3.11"` to your + build-time dependencies if you use `pyproject.toml` +* run `versioneer install --vendor` in your source tree, commit the results +* verify version information with `python setup.py version` + +### Build-time dependency mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) +* add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`) + to the `requires` key of the `build-system` table in `pyproject.toml`: + ```toml + [build-system] + requires = ["setuptools", "versioneer[toml]"] + build-backend = "setuptools.build_meta" + ``` +* run `versioneer install --no-vendor` in your source tree, commit the results +* verify version information with `python setup.py version` ## Version Identifiers @@ -60,7 +86,7 @@ for example `git describe --tags --dirty --always` reports things like "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. +uncommitted changes). The version identifier is used for multiple purposes: @@ -165,7 +191,7 @@ Some situations are known to cause problems for Versioneer. This details the most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). 
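As an aside to the Quick Install notes above: the `[tool.versioneer]` / `[versioneer]` settings mentioned there use the same keys that `get_config_from_root()` parses further down in this file. The snippet below is only an illustrative sketch with placeholder paths and values, not this repository's actual configuration:

```toml
# Illustrative sketch only -- keys mirror VersioneerConfig; paths are placeholders.
[tool.versioneer]
VCS = "git"
style = "pep440"
versionfile_source = "src/py_eddy_tracker/_version.py"
versionfile_build = "py_eddy_tracker/_version.py"
tag_prefix = ""
parentdir_prefix = "py-eddy-tracker-"
```

With the `pep440` style, a checkout that `git describe` reports as `0.7-1-g574ab98-dirty` (one commit past the `0.7` tag, with uncommitted changes) would typically be rendered as something like `0.7+1.g574ab98.dirty`.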
+[issues page](https://github.com/python-versioneer/python-versioneer/issues). ### Subprojects @@ -179,7 +205,7 @@ `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI distributions (and upload multiple independently-installable tarballs). * Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. + provide bindings to Python (and perhaps other languages) in subdirectories. Versioneer will look for `.git` in parent directories, and most operations should get the right version string. However `pip` and `setuptools` have bugs @@ -193,9 +219,9 @@ Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in some later version. -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking +[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the +[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the issue from the Versioneer side in more detail. [pip PR#3176](https://github.com/pypa/pip/pull/3176) and [pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve @@ -223,31 +249,20 @@ cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into a different virtualenv), so this can be surprising. -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes +[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes this one, but upgrading to a newer version of setuptools should probably resolve it. -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - ## Updating Versioneer To upgrade your project to a new release of Versioneer, do the following: * install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace +* edit `setup.cfg` and `pyproject.toml`, if necessary, + to include any new configuration settings indicated by the release notes. + See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install --[no-]vendor` in your source tree, to replace `SRC/_version.py` * commit any changed files @@ -264,36 +279,70 @@ direction and include code from all supported VCS systems, reducing the number of intermediate scripts. +## Similar projects + +* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time + dependency +* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of + versioneer +* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools + plugin ## License To make Versioneer easier to embed, all its code is dedicated to the public domain. The `_version.py` that it creates is also in the public domain. 
-Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . +Specifically, both are released under the "Unlicense", as described in +https://unlicense.org/. -""" +[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg +[pypi-url]: https://pypi.python.org/pypi/versioneer/ +[travis-image]: +https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg +[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer -from __future__ import print_function +""" +# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring +# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements +# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error +# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with +# pylint:disable=attribute-defined-outside-init,too-many-arguments -try: - import configparser -except ImportError: - import ConfigParser as configparser +import configparser import errno import json import os import re import subprocess import sys +from pathlib import Path +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import NoReturn +import functools + +have_tomllib = True +if sys.version_info >= (3, 11): + import tomllib +else: + try: + import tomli as tomllib + except ImportError: + have_tomllib = False class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + versionfile_source: str + versionfile_build: Optional[str] + parentdir_prefix: Optional[str] + verbose: Optional[bool] -def get_root(): + +def get_root() -> str: """Get the project root directory. We require that all commands are run from the project root, i.e. the @@ -301,20 +350,28 @@ def get_root(): """ root = os.path.realpath(os.path.abspath(os.getcwd())) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): # allow 'python path/to/setup.py COMMAND' root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ( - "Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND')." - ) + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): + err = ("Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND').") raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -323,46 +380,62 @@ def get_root(): # module-import table will cache the first one. 
So we can't use # os.path.dirname(__file__), as that will find whichever # versioneer.py was first imported, even in later projects. - me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) + my_path = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print( - "Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py) - ) + if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): + print("Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py)) except NameError: pass return root -def get_config_from_root(root): +def get_config_from_root(root: str) -> VersioneerConfig: """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or + # This might raise OSError (if setup.cfg is missing), or # configparser.NoSectionError (if it lacks a [versioneer] section), or # configparser.NoOptionError (if it lacks "VCS="). See the docstring at # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None + root_pth = Path(root) + pyproject_toml = root_pth / "pyproject.toml" + setup_cfg = root_pth / "setup.cfg" + section: Union[Dict[str, Any], configparser.SectionProxy, None] = None + if pyproject_toml.exists() and have_tomllib: + try: + with open(pyproject_toml, 'rb') as fobj: + pp = tomllib.load(fobj) + section = pp['tool']['versioneer'] + except (tomllib.TOMLDecodeError, KeyError) as e: + print(f"Failed to load config from {pyproject_toml}: {e}") + print("Try to load it from setup.cfg") + if not section: + parser = configparser.ConfigParser() + with open(setup_cfg) as cfg_file: + parser.read_file(cfg_file) + parser.get("versioneer", "VCS") # raise error if missing + + section = parser["versioneer"] + + # `cast`` really shouldn't be used, but its simplest for the + # common VersioneerConfig users at the moment. 
We verify against + # `None` values elsewhere where it matters cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): + cfg.VCS = section['VCS'] + cfg.style = section.get("style", "") + cfg.versionfile_source = cast(str, section.get("versionfile_source")) + cfg.versionfile_build = section.get("versionfile_build") + cfg.tag_prefix = cast(str, section.get("tag_prefix")) + if cfg.tag_prefix in ("''", '""', None): cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") + cfg.parentdir_prefix = section.get("parentdir_prefix") + if isinstance(section, configparser.SectionProxy): + # Make sure configparser translates to bool + cfg.verbose = section.getboolean("verbose") + else: + cfg.verbose = section.get("verbose") + return cfg @@ -371,41 +444,48 @@ class NotThisMethod(Exception): # these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f + HANDLERS.setdefault(vcs, {})[method] = f return f - return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -416,28 +496,25 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, 
p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -LONG_VERSION_PY[ - "git" -] = ''' +LONG_VERSION_PY['git'] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -446,9 +523,11 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -464,8 +543,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool + -def get_config(): +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -483,13 +569,13 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} @@ -498,22 +584,35 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + 
stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -524,18 +623,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, if verbose: print("unable to find command, tried %%s" %% (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %%s (error)" %% dispcmd) print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -544,15 +645,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print("Tried directories %%s but none started with prefix %%s" %% @@ -561,41 +661,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
- keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -608,11 +715,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %%d @@ -621,7 +728,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%%s', no digits" %% ",".join(refs - tags)) if verbose: @@ -630,6 +737,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %%s" %% r) return {"version": r, @@ -645,7 +757,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -656,8 +773,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %%s not under git control" %% root) @@ -665,24 +789,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -699,7 +856,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%%s'" %% describe_out) return pieces @@ -724,26 +881,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() + date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -768,23 +926,71 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). 
+ """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%%d" %% (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] + rendered = "0.post0.dev%%d" %% pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -811,12 +1017,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -833,7 +1068,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -853,7 +1088,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
@@ -873,7 +1108,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -887,10 +1122,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -905,7 +1144,7 @@ def render(pieces, style): "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -926,7 +1165,7 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: return {"version": "0+unknown", "full-revisionid": None, @@ -953,41 +1192,48 @@ def get_versions(): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. 
Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -1000,11 +1246,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1013,7 +1259,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1021,30 +1267,33 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -1055,7 +1304,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1063,33 +1320,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -1098,16 +1379,17 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] + git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) return pieces # tag @@ -1116,12 +1398,10 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] + pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1132,19 +1412,20 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def do_vcs_install(manifest_in, versionfile_source, ipy): +def do_vcs_install(versionfile_source: str, ipy: Optional[str]) -> None: """Git-specific installation logic for Versioneer. For Git, this means creating/changing .gitattributes to mark _version.py @@ -1153,36 +1434,40 @@ def do_vcs_install(manifest_in, versionfile_source, ipy): GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] + files = [versionfile_source] if ipy: files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) + if "VERSIONEER_PEP518" not in globals(): + try: + my_path = __file__ + if my_path.endswith((".pyc", ".pyo")): + my_path = os.path.splitext(my_path)[0] + ".py" + versioneer_file = os.path.relpath(my_path) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) present = False try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: + with open(".gitattributes", "r") as fobj: + for line in fobj: + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + break + except OSError: pass if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() + with open(".gitattributes", "a+") as fobj: + fobj.write(f"{versionfile_source} export-subst\n") files.append(".gitattributes") run_command(GITS, ["add", "--"] + files) -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from 
the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -1191,30 +1476,23 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from +# This file was generated by 'versioneer.py' (0.29) from # revision-control system data, or from the parent directory name of an # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. @@ -1231,43 +1509,41 @@ def get_versions(): """ -def versions_from_file(filename): +def versions_from_file(filename: str) -> Dict[str, Any]: """Try to determine the version from _version.py if present.""" try: with open(filename) as f: contents = f.read() - except EnvironmentError: + except OSError: raise NotThisMethod("unable to read _version.py") - mo = re.search( - r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) + mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) if not mo: - mo = re.search( - r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) + mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) -def write_to_version_file(filename, versions): +def write_to_version_file(filename: str, versions: Dict[str, Any]) -> None: """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, + indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) print("set %s to '%s'" % (filename, versions["version"])) -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -1285,29 +1561,78 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. 
+def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -1334,12 +1659,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -1356,7 +1710,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. 
@@ -1376,7 +1730,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -1396,26 +1750,28 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -1425,20 +1781,16 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} class VersioneerBadRootError(Exception): """The project root directory is unknown or missing key files.""" -def get_versions(verbose=False): +def get_versions(verbose: bool = False) -> Dict[str, Any]: """Get the project version from whatever source is available. Returns dict with two keys: 'version' and 'full'. @@ -1453,10 +1805,9 @@ def get_versions(verbose=False): assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert ( - cfg.versionfile_source is not None - ), "please set versioneer.versionfile_source" + verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` + assert cfg.versionfile_source is not None, \ + "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1510,22 +1861,22 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, "error": "unable to compute version", + "date": None} -def get_version(): +def get_version() -> str: """Get the short version string for this project.""" return get_versions()["version"] -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" +def get_cmdclass(cmdclass: Optional[Dict[str, Any]] = None): + """Get the custom setuptools subclasses used by Versioneer. + + If the package uses a different cmdclass (e.g. one from numpy), it + should be provide as an argument. 
+ """ if "versioneer" in sys.modules: del sys.modules["versioneer"] # this fixes the "python setup.py develop" case (also 'install' and @@ -1539,25 +1890,25 @@ def get_cmdclass(): # parent is protected against the child's "import versioneer". By # removing ourselves from sys.modules here, before the child build # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 + # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - cmds = {} + cmds = {} if cmdclass is None else cmdclass.copy() - # we add "version" to both distutils and setuptools - from distutils.core import Command + # we add "version" to setuptools + from setuptools import Command class cmd_version(Command): description = "report generated version string" - user_options = [] - boolean_options = [] + user_options: List[Tuple[str, str, str]] = [] + boolean_options: List[str] = [] - def initialize_options(self): + def initialize_options(self) -> None: pass - def finalize_options(self): + def finalize_options(self) -> None: pass - def run(self): + def run(self) -> None: vers = get_versions(verbose=True) print("Version: %s" % vers["version"]) print(" full-revisionid: %s" % vers.get("full-revisionid")) @@ -1565,10 +1916,9 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - # we override "build_py" in both distutils and setuptools + # we override "build_py" in setuptools # # most invocation pathways end up running build_py: # distutils/build -> build_py @@ -1583,30 +1933,68 @@ def run(self): # then does setup.py bdist_wheel, or sometimes setup.py install # setup.py egg_info -> ? + # pip install -e . and setuptool/editable_wheel will invoke build_py + # but the build_py command is not expected to copy any files. + # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py + if 'build_py' in cmds: + _build_py: Any = cmds['build_py'] else: - from distutils.command.build_py import build_py as _build_py + from setuptools.command.build_py import build_py as _build_py class cmd_build_py(_build_py): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() _build_py.run(self) + if getattr(self, "editable_mode", False): + # During editable installs `.py` and data files are + # not copied to build_lib + return # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe + if 'build_ext' in cmds: + _build_ext: Any = cmds['build_ext'] + else: + from setuptools.command.build_ext import build_ext as _build_ext + + class cmd_build_ext(_build_ext): + def run(self) -> None: + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_ext.run(self) + if self.inplace: + # build_ext --inplace will only build extensions in + # build/lib<..> dir with no _version.py to write to. 
+                # As in place builds will already have a _version.py
+                # in the module dir, we do not need to write one.
+                return
+            # now locate _version.py in the new build/ directory and replace
+            # it with an updated value
+            if not cfg.versionfile_build:
+                return
+            target_versionfile = os.path.join(self.build_lib,
+                                              cfg.versionfile_build)
+            if not os.path.exists(target_versionfile):
+                print(f"Warning: {target_versionfile} does not exist, skipping "
+                      "version update. This can happen if you are running build_ext "
+                      "without first running build_py.")
+                return
+            print("UPDATING %s" % target_versionfile)
+            write_to_version_file(target_versionfile, versions)
+    cmds["build_ext"] = cmd_build_ext
+
+    if "cx_Freeze" in sys.modules:  # cx_freeze enabled?
+        from cx_Freeze.dist import build_exe as _build_exe  # type: ignore
 
         # nczeczulin reports that py2exe won't like the pep440-style string
         # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g.
         # setup(console=[{
@@ -1615,7 +2003,7 @@ def run(self):
         #   ...
 
         class cmd_build_exe(_build_exe):
-            def run(self):
+            def run(self) -> None:
                 root = get_root()
                 cfg = get_config_from_root(root)
                 versions = get_versions()
@@ -1627,28 +2015,24 @@ def run(self):
                     os.unlink(target_versionfile)
                 with open(cfg.versionfile_source, "w") as f:
                     LONG = LONG_VERSION_PY[cfg.VCS]
-                    f.write(
-                        LONG
-                        % {
-                            "DOLLAR": "$",
-                            "STYLE": cfg.style,
-                            "TAG_PREFIX": cfg.tag_prefix,
-                            "PARENTDIR_PREFIX": cfg.parentdir_prefix,
-                            "VERSIONFILE_SOURCE": cfg.versionfile_source,
-                        }
-                    )
-
+                    f.write(LONG %
+                            {"DOLLAR": "$",
+                             "STYLE": cfg.style,
+                             "TAG_PREFIX": cfg.tag_prefix,
+                             "PARENTDIR_PREFIX": cfg.parentdir_prefix,
+                             "VERSIONFILE_SOURCE": cfg.versionfile_source,
+                             })
         cmds["build_exe"] = cmd_build_exe
         del cmds["build_py"]
 
-    if "py2exe" in sys.modules:  # py2exe enabled?
+    if 'py2exe' in sys.modules:  # py2exe enabled?
         try:
-            from py2exe.distutils_buildexe import py2exe as _py2exe  # py3
+            from py2exe.setuptools_buildexe import py2exe as _py2exe  # type: ignore
         except ImportError:
-            from py2exe.build_exe import py2exe as _py2exe  # py2
+            from py2exe.distutils_buildexe import py2exe as _py2exe  # type: ignore
 
         class cmd_py2exe(_py2exe):
-            def run(self):
+            def run(self) -> None:
                 root = get_root()
                 cfg = get_config_from_root(root)
                 versions = get_versions()
@@ -1660,27 +2044,60 @@ def run(self):
                 os.unlink(target_versionfile)
                 with open(cfg.versionfile_source, "w") as f:
                     LONG = LONG_VERSION_PY[cfg.VCS]
-                    f.write(
-                        LONG
-                        % {
-                            "DOLLAR": "$",
-                            "STYLE": cfg.style,
-                            "TAG_PREFIX": cfg.tag_prefix,
-                            "PARENTDIR_PREFIX": cfg.parentdir_prefix,
-                            "VERSIONFILE_SOURCE": cfg.versionfile_source,
-                        }
-                    )
-
+                    f.write(LONG %
+                            {"DOLLAR": "$",
+                             "STYLE": cfg.style,
+                             "TAG_PREFIX": cfg.tag_prefix,
+                             "PARENTDIR_PREFIX": cfg.parentdir_prefix,
+                             "VERSIONFILE_SOURCE": cfg.versionfile_source,
+                             })
         cmds["py2exe"] = cmd_py2exe
 
+    # sdist farms its file list building out to egg_info
+    if 'egg_info' in cmds:
+        _egg_info: Any = cmds['egg_info']
+    else:
+        from setuptools.command.egg_info import egg_info as _egg_info
+
+    class cmd_egg_info(_egg_info):
+        def find_sources(self) -> None:
+            # egg_info.find_sources builds the manifest list and writes it
+            # in one shot
+            super().find_sources()
+
+            # Modify the filelist and normalize it
+            root = get_root()
+            cfg = get_config_from_root(root)
+            self.filelist.append('versioneer.py')
+            if cfg.versionfile_source:
+                # There are rare cases where versionfile_source might not be
+                # included by default, so we must be explicit
+                self.filelist.append(cfg.versionfile_source)
+            self.filelist.sort()
+            self.filelist.remove_duplicates()
+
+            # The write method is hidden in the manifest_maker instance that
+            # generated the filelist and was thrown away
+            # We will instead replicate their final normalization (to unicode,
+            # and POSIX-style paths)
+            from setuptools import unicode_utils
+            normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/')
+                          for f in self.filelist.files]
+
+            manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt')
+            with open(manifest_filename, 'w') as fobj:
+                fobj.write('\n'.join(normalized))
+
+    cmds['egg_info'] = cmd_egg_info
+
     # we override different "sdist" commands for both environments
-    if "setuptools" in sys.modules:
-        from setuptools.command.sdist import sdist as _sdist
+    if 'sdist' in cmds:
+        _sdist: Any = cmds['sdist']
     else:
-        from distutils.command.sdist import sdist as _sdist
+        from setuptools.command.sdist import sdist as _sdist
 
     class cmd_sdist(_sdist):
-        def run(self):
+        def run(self) -> None:
             versions = get_versions()
             self._versioneer_generated_versions = versions
             # unless we update this, the command will keep using the old
@@ -1688,7 +2105,7 @@ def run(self):
             self.distribution.metadata.version = versions["version"]
             return _sdist.run(self)
 
-        def make_release_tree(self, base_dir, files):
+        def make_release_tree(self, base_dir: str, files: List[str]) -> None:
             root = get_root()
             cfg = get_config_from_root(root)
             _sdist.make_release_tree(self, base_dir, files)
@@ -1697,10 +2114,8 @@ def make_release_tree(self, base_dir, files):
             # updated value
             target_versionfile = os.path.join(base_dir, cfg.versionfile_source)
             print("UPDATING %s" % target_versionfile)
-            write_to_version_file(
-                target_versionfile, self._versioneer_generated_versions
-            )
-
+            write_to_version_file(target_versionfile,
+                                  self._versioneer_generated_versions)
         cmds["sdist"] = cmd_sdist
 
     return cmds
@@ -1743,25 +2158,28 @@ def make_release_tree(self, base_dir, files):
 
 """
 
-INIT_PY_SNIPPET = """
+OLD_SNIPPET = """
 from ._version import get_versions
 __version__ = get_versions()['version']
 del get_versions
 """
 
+INIT_PY_SNIPPET = """
+from . import {0}
+__version__ = {0}.get_versions()['version']
+"""
+
 
-def do_setup():
-    """Main VCS-independent setup function for installing Versioneer."""
+def do_setup() -> int:
+    """Do main VCS-independent setup function for installing Versioneer."""
     root = get_root()
     try:
         cfg = get_config_from_root(root)
-    except (
-        EnvironmentError,
-        configparser.NoSectionError,
-        configparser.NoOptionError,
-    ) as e:
-        if isinstance(e, (EnvironmentError, configparser.NoSectionError)):
-            print("Adding sample versioneer config to setup.cfg", file=sys.stderr)
+    except (OSError, configparser.NoSectionError,
+            configparser.NoOptionError) as e:
+        if isinstance(e, (OSError, configparser.NoSectionError)):
+            print("Adding sample versioneer config to setup.cfg",
+                  file=sys.stderr)
             with open(os.path.join(root, "setup.cfg"), "a") as f:
                 f.write(SAMPLE_CONFIG)
         print(CONFIG_ERROR, file=sys.stderr)
@@ -1770,76 +2188,46 @@ def do_setup():
     print(" creating %s" % cfg.versionfile_source)
     with open(cfg.versionfile_source, "w") as f:
         LONG = LONG_VERSION_PY[cfg.VCS]
-        f.write(
-            LONG
-            % {
-                "DOLLAR": "$",
-                "STYLE": cfg.style,
-                "TAG_PREFIX": cfg.tag_prefix,
-                "PARENTDIR_PREFIX": cfg.parentdir_prefix,
-                "VERSIONFILE_SOURCE": cfg.versionfile_source,
-            }
-        )
-
-    ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py")
+        f.write(LONG % {"DOLLAR": "$",
+                        "STYLE": cfg.style,
+                        "TAG_PREFIX": cfg.tag_prefix,
+                        "PARENTDIR_PREFIX": cfg.parentdir_prefix,
+                        "VERSIONFILE_SOURCE": cfg.versionfile_source,
+                        })
+
+    ipy = os.path.join(os.path.dirname(cfg.versionfile_source),
+                       "__init__.py")
+    maybe_ipy: Optional[str] = ipy
     if os.path.exists(ipy):
         try:
             with open(ipy, "r") as f:
                 old = f.read()
-        except EnvironmentError:
+        except OSError:
             old = ""
-        if INIT_PY_SNIPPET not in old:
+        module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0]
+        snippet = INIT_PY_SNIPPET.format(module)
+        if OLD_SNIPPET in old:
+            print(" replacing boilerplate in %s" % ipy)
+            with open(ipy, "w") as f:
+                f.write(old.replace(OLD_SNIPPET, snippet))
+        elif snippet not in old:
             print(" appending to %s" % ipy)
             with open(ipy, "a") as f:
-                f.write(INIT_PY_SNIPPET)
+                f.write(snippet)
         else:
             print(" %s unmodified" % ipy)
     else:
         print(" %s doesn't exist, ok" % ipy)
-        ipy = None
-
-    # Make sure both the top-level "versioneer.py" and versionfile_source
-    # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so
-    # they'll be copied into source distributions. Pip won't be able to
-    # install the package without this.
-    manifest_in = os.path.join(root, "MANIFEST.in")
-    simple_includes = set()
-    try:
-        with open(manifest_in, "r") as f:
-            for line in f:
-                if line.startswith("include "):
-                    for include in line.split()[1:]:
-                        simple_includes.add(include)
-    except EnvironmentError:
-        pass
-    # That doesn't cover everything MANIFEST.in can do
-    # (http://docs.python.org/2/distutils/sourcedist.html#commands), so
-    # it might give some false negatives. Appending redundant 'include'
-    # lines is safe, though.
-    if "versioneer.py" not in simple_includes:
-        print(" appending 'versioneer.py' to MANIFEST.in")
-        with open(manifest_in, "a") as f:
-            f.write("include versioneer.py\n")
-    else:
-        print(" 'versioneer.py' already in MANIFEST.in")
-    if cfg.versionfile_source not in simple_includes:
-        print(
-            " appending versionfile_source ('%s') to MANIFEST.in"
-            % cfg.versionfile_source
-        )
-        with open(manifest_in, "a") as f:
-            f.write("include %s\n" % cfg.versionfile_source)
-    else:
-        print(" versionfile_source already in MANIFEST.in")
+        maybe_ipy = None
 
     # Make VCS-specific changes. For git, this means creating/changing
     # .gitattributes to mark _version.py for export-subst keyword
     # substitution.
-    do_vcs_install(manifest_in, cfg.versionfile_source, ipy)
+    do_vcs_install(cfg.versionfile_source, maybe_ipy)
     return 0
 
 
-def scan_setup_py():
+def scan_setup_py() -> int:
     """Validate the contents of setup.py against Versioneer's expectations."""
     found = set()
     setters = False
@@ -1876,10 +2264,14 @@ def scan_setup_py():
     return errors
 
 
+def setup_command() -> NoReturn:
+    """Set up Versioneer and exit with appropriate error code."""
+    errors = do_setup()
+    errors += scan_setup_py()
+    sys.exit(1 if errors else 0)
+
+
 if __name__ == "__main__":
     cmd = sys.argv[1]
     if cmd == "setup":
-        errors = do_setup()
-        errors += scan_setup_py()
-        if errors:
-            sys.exit(1)
+        setup_command()