diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 00000000..d9437d16 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,70 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ master ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ master ] + schedule: + - cron: '41 16 * * 4' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Learn more about CodeQL language support at https://git.io/codeql-language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹ️ Command-line programs to run using the OS shell. 
+ # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 7c800b33..f2f4753e 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -1,29 +1,33 @@ # This workflow will install Python dependencies, run tests and lint with a single version of Python # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: Python application +name: Pytest & Flake8 -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] +on: [push, pull_request] jobs: build: - - runs-on: ubuntu-latest + strategy: + matrix: + # os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, windows-latest] + python_version: ['3.10', '3.11', '3.12'] + name: Run py eddy tracker build tests + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -l {0} steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python ${{ matrix.python_version }} uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: ${{ matrix.python_version }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest + pip install flake8 pytest pytest-cov if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Install package run: | @@ -34,6 +38,3 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - pytest diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 00000000..ba36f8ea --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,13 @@ +version: 2 +conda: + environment: doc/environment.yml +build: + os: ubuntu-lts-latest + tools: + python: "mambaforge-latest" +python: + install: + - method: pip + path: . +sphinx: + configuration: doc/conf.py \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 72b204c1..6d6d6a30 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,7 +7,115 @@ The format is based on `Keep a Changelog `_ and this project adheres to `Semantic Versioning `_. [Unreleased] ------------- +------------- +Changed +^^^^^^^ + +Fixed +^^^^^ + +Added +^^^^^ + +[3.6.2] - 2025-06-06 +-------------------- +Changed +^^^^^^^ + +- Remove dead end method for network will move dead end to the trash and not remove observations + +Fixed +^^^^^ + +- Fix matplotlib version + +[3.6.1] - 2022-10-14 +-------------------- +Changed +^^^^^^^ + +- Rewrite particle candidate to be easily parallelize + +Fixed +^^^^^ + +- Check strictly increasing coordinates for RegularGridDataset. 
+- Grid mask is checked to replace a mono-value mask by a 2D mask with a fixed value
+
+Added
+^^^^^
+
+- Add method to colorize contour with a field
+- Add option to force align on to return all steps for the reference dataset
+- Add method and property to network to easily select segments and networks
+- Add method to find the same track/segment/network in a dataset
+
+[3.6.0] - 2022-01-12
+--------------------
+Changed
+^^^^^^^
+
+- Time now allows second precision (instead of daily precision), stored on uint32 from 01/01/1950 to 01/01/2086.
+  New identifications are produced with this type, old files can still be loaded.
+  If you use old identifications for tracking, use the `--unraw` option to unpack old times and store data with the new format.
+- Amplitude is now stored with 0.1 mm precision (instead of 1 mm), same advice as for time.
+- Expose more parameters to users for the bash tools build_network & divide_network
+- Add warning when loading a file created with a previous version of py-eddy-tracker.
+
+Fixed
+^^^^^
+
+- Fix bug in convolution (filter): the lowest rows were replaced by zeros in the convolution computation.
+  Important impact for tiny kernels
+- Fix method of sampling before contour fitting
+- Fix bug when loading a dataset in zarr format, not all variables were correctly loaded
+- Fix bug when a zarr dataset has the same size for number of observations and contour size
+- Fix bug when tracking, previous_virtual_obs was not always loaded
+
+Added
+^^^^^
+
+- Allow replacing the mask by the isnan method to manage NaN data instead of masked data
+- Add drifter colocation example
+
+[3.5.0] - 2021-06-22
+--------------------
+
+Fixed
+^^^^^
+- GridCollection get_next_time_step & get_previous_time_step needed more files than necessary in the dataset list to work.
+  The loop explicitly needed self.dataset[i+-1] even when i==0, therefore the index went out of range
+
+[3.4.0] - 2021-03-29
+--------------------
+Changed
+^^^^^^^
+- `TrackEddiesObservations.filled_by_interpolation` method no longer normalizes longitude; to keep the same
+  behaviour you must call `TrackEddiesObservations.normalize_longitude` beforehand
+
+Fixed
+^^^^^
+- Use `safe_load` for yaml load
+- repr of EddiesObservation when the collection is empty (time attribute empty array)
+- display_timeline and event_timeline can now use colors according to 'y' values.
+- event_timeline now plots all merging events in one plot instead of one plot per merging (avoiding a cluttered legend). Same for splitting.
+
+Added
+^^^^^
+- Identification files can be loaded in memory before being read with the netCDF library, to speed up reading on slow disks
+- Add a filter option in EddyId to remove fine scales (like noise) with the same filter order as the large-scale
+  filter
+- Add **EddyQuickCompare** to get a few comparison figures about several datasets, based on the match function
+- Color and text field for contours in **EddyAnim** can be chosen
+- Save EddyAnim in mp4
+- Add method to get the eddy contour which encloses obs defined by (x,y) coordinates
+- Add **EddyNetworkSubSetter** to subset networks, which need special tools and operations after subsetting
+- Network:
+  - Add method to find relative segments
+  - Add method to get close networks in another atlas
+- Management of time cube data for advection

 [3.3.0] - 2020-12-03
 --------------------
diff --git a/README.md b/README.md
index 4f466b7e..0cc34894 100644
--- a/README.md
+++ b/README.md
@@ -1,24 +1,40 @@
 [![PyPI version](https://badge.fury.io/py/pyEddyTracker.svg)](https://badge.fury.io/py/pyEddyTracker)
+[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.6333988.svg)](https://doi.org/10.5281/zenodo.6333988)
 [![Documentation Status](https://readthedocs.org/projects/py-eddy-tracker/badge/?version=stable)](https://py-eddy-tracker.readthedocs.io/en/stable/?badge=stable)
 [![Gitter](https://badges.gitter.im/py-eddy-tracker/community.svg)](https://gitter.im/py-eddy-tracker/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
 [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/AntSimi/py-eddy-tracker/master?urlpath=lab/tree/notebooks/python_module/)
+[![pytest](https://github.com/AntSimi/py-eddy-tracker/actions/workflows/python-app.yml/badge.svg)](https://github.com/AntSimi/py-eddy-tracker/actions/workflows/python-app.yml)

 # README #

+### How to cite the code? ###
+
+Zenodo provides a DOI for each tagged version, [all DOIs are available here](https://doi.org/10.5281/zenodo.6333988)
+
 ### Method ###

 Method was described in :

+[Pegliasco, C., Delepoulle, A., Morrow, R., Faugère, Y., and Dibarboure, G.: META3.1exp : A new Global Mesoscale Eddy Trajectories Atlas derived from altimetry, Earth Syst. Sci. Data Discuss.](https://doi.org/10.5194/essd-14-1087-2022)
+
 [Mason, E., A. Pascual, and J. C. McWilliams, 2014: A new sea surface height–based code for oceanic mesoscale eddy tracking.](https://doi.org/10.1175/JTECH-D-14-00019.1)

 ### Use case ###

 Method is used in :
-
+
 [Mason, E., A. Pascual, P. Gaube, S.Ruiz, J. Pelegrí, A. Delepoulle, 2017: Subregional characterization of mesoscale eddies across the Brazil-Malvinas Confluence](https://doi.org/10.1002/2016JC012611)

 ### How do I get set up? ###

+#### Short story ####
+
+```bash
+pip install pyeddytracker
+```
+
+#### Long story ####
+
 To avoid problems with installation, use of the virtualenv Python virtual environment is recommended.

 Then use pip to install all dependencies (numpy, scipy, matplotlib, netCDF4, ...), e.g.:
@@ -27,12 +43,20 @@ Then use pip to install all dependencies (numpy, scipy, matplotlib, netCDF4, ...), e.g.:
pip install numpy scipy netCDF4 matplotlib opencv-python pyyaml pint polygon3 ``` -Then run the following to install the eddy tracker: +Clone : + +```bash +git clone https://github.com/AntSimi/py-eddy-tracker +``` + +Then run the following to install the eddy tracker : ```bash python setup.py install ``` + ### Tools gallery ### + Several examples based on py eddy tracker module are [here](https://py-eddy-tracker.readthedocs.io/en/latest/python_module/index.html). [![](https://py-eddy-tracker.readthedocs.io/en/latest/_static/logo.png)](https://py-eddy-tracker.readthedocs.io/en/latest/python_module/index.html) diff --git a/apt.txt b/apt.txt new file mode 100644 index 00000000..a72c3b87 --- /dev/null +++ b/apt.txt @@ -0,0 +1 @@ +libgl1-mesa-glx \ No newline at end of file diff --git a/check.sh b/check.sh index ddafab69..a402bf52 100644 --- a/check.sh +++ b/check.sh @@ -1,7 +1,5 @@ -isort src tests examples -black src tests examples -blackdoc src tests examples -flake8 tests examples src --count --select=E9,F63,F7,F82 --show-source --statistics -# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -flake8 tests examples src --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -pytest -vv --cov py_eddy_tracker --cov-report html +isort . +black . +blackdoc . +flake8 . +python -m pytest -vv --cov py_eddy_tracker --cov-report html diff --git a/doc/api.rst b/doc/api.rst index c463c7d0..866704f8 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -14,6 +14,7 @@ API reference py_eddy_tracker.observations.network py_eddy_tracker.observations.observation py_eddy_tracker.observations.tracking + py_eddy_tracker.observations.groups py_eddy_tracker.eddy_feature py_eddy_tracker.generic py_eddy_tracker.gui diff --git a/doc/conf.py b/doc/conf.py index add04862..0844d585 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -54,7 +54,7 @@ sphinx_gallery_conf = { "examples_dirs": "../examples", # path to your example scripts "gallery_dirs": "python_module", - "capture_repr": ("_repr_html_",), + "capture_repr": ("_repr_html_", "__repr__"), "backreferences_dir": "gen_modules/backreferences", "doc_module": ("py_eddy_tracker",), "reference_url": { @@ -69,7 +69,7 @@ "repo": "py-eddy-tracker", "branch": "master", "binderhub_url": "https://mybinder.org", - "dependencies": ["../requirements.txt"], + "dependencies": ["environment.yml"], # Optional keys "use_jupyter_lab": True, }, @@ -96,9 +96,9 @@ master_doc = "index" # General information about the project. -project = u"py-eddy-tracker" -copyright = u"2019, A. Delepoulle & E. Mason" -author = u"A. Delepoulle & E. Mason" +project = "py-eddy-tracker" +copyright = "2019, A. Delepoulle & E. Mason" +author = "A. Delepoulle & E. Mason" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -272,8 +272,8 @@ ( master_doc, "py-eddy-tracker.tex", - u"py-eddy-tracker Documentation", - u"A. Delepoulle \\& E. Mason", + "py-eddy-tracker Documentation", + "A. Delepoulle \\& E. Mason", "manual", ), ] @@ -304,7 +304,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, "py-eddy-tracker", u"py-eddy-tracker Documentation", [author], 1) + (master_doc, "py-eddy-tracker", "py-eddy-tracker Documentation", [author], 1) ] # If true, show URL addresses after external links. 
@@ -320,7 +320,7 @@ ( master_doc, "py-eddy-tracker", - u"py-eddy-tracker Documentation", + "py-eddy-tracker Documentation", author, "py-eddy-tracker", "One line description of project.", diff --git a/doc/custom_tracking.rst b/doc/custom_tracking.rst index f75ad2c5..f72f3e72 100644 --- a/doc/custom_tracking.rst +++ b/doc/custom_tracking.rst @@ -6,7 +6,6 @@ Customize tracking Code my own tracking ******************** -To use your own tracking method, you just need to create a class which inherit +To use your own tracking method, you just need to create a class which inherits from :meth:`py_eddy_tracker.observations.observation.EddiesObservations` and set this class in yaml file like we see in the previous topic. - diff --git a/doc/environment.yml b/doc/environment.yml new file mode 100644 index 00000000..063a60de --- /dev/null +++ b/doc/environment.yml @@ -0,0 +1,13 @@ +channels: + - conda-forge +dependencies: + - python>=3.10, <3.13 + - ffmpeg + - pip + - pip: + - -r ../requirements.txt + - git+https://github.com/AntSimi/py-eddy-tracker-sample-id.git + - sphinx-gallery + - sphinx_rtd_theme + - sphinx>=3.1 + - pyeddytrackersample diff --git a/doc/grid_identification.rst b/doc/grid_identification.rst index 4d681252..2cc3fb52 100644 --- a/doc/grid_identification.rst +++ b/doc/grid_identification.rst @@ -8,7 +8,7 @@ Run the identification process for a single day Shell/bash command ****************** -Bash command will allow to process one grid, it will apply a filter and an identification. +Bash command will allow you to process one grid, it will apply a filter and an identification. .. code-block:: bash @@ -18,14 +18,14 @@ Bash command will allow to process one grid, it will apply a filter and an ident out_directory -v DEBUG -Filter could be modify with options *--cut_wavelength* and *--filter_order*. You could also defined height between two isolines with *--isoline_step*, which could +Filter could be modified with options *--cut_wavelength* and *--filter_order*. You could also define height between two isolines with *--isoline_step*, which could improve speed profile quality and detect accurately tiny eddies. You could also use *--fit_errmax* to manage acceptable shape of eddies. An eddy identification will produce two files in the output directory, one for anticyclonic eddies and the other one for cyclonic. -In regional area which are away from the equator, current could be deduce from height, juste write *None None* inplace of *ugos vgos* +In regional areas which are away from the Equator, current could be deduced from height, just write *None None* in place of *ugos vgos* -In case of **datacube**, you need to specify index for each layer (time, depth, ...) wiht *--indexs* option like: +In case of **datacube**, you need to specify index for each layer (time, depth, ...) with *--indexs* option like: .. code-block:: bash @@ -40,45 +40,49 @@ In case of **datacube**, you need to specify index for each layer (time, depth, Python code *********** -If we want customize eddies identification, python module is here. +If we want to customize eddies identification, the Python module is here. Activate verbose .. code-block:: python from py_eddy_tracker import start_logger - start_logger().setLevel('DEBUG') # Available options: ERROR, WARNING, INFO, DEBUG + + start_logger().setLevel("DEBUG") # Available options: ERROR, WARNING, INFO, DEBUG Run identification .. 
code-block:: python from datetime import datetime + h = RegularGridDataset(grid_name, lon_name, lat_name) - h.bessel_high_filter('adt', 500, order=3) + h.bessel_high_filter("adt", 500, order=3) date = datetime(2019, 2, 23) a, c = h.eddy_identification( - 'adt', 'ugos', 'vgos', # Variables used for identification - date, # Date of identification - 0.002, # step between two isolines of detection (m) - pixel_limit=(5, 2000), # Min and max pixel count for valid contour - shape_error=55, # Error max (%) between ratio of circle fit and contour - ) + "adt", + "ugos", + "vgos", # Variables used for identification + date, # Date of identification + 0.002, # step between two isolines of detection (m) + pixel_limit=(5, 2000), # Min and max pixel count for valid contour + shape_error=55, # Error max (%) between ratio of circle fit and contour + ) Plot the resulting identification .. code-block:: python - fig = plt.figure(figsize=(15,7)) - ax = fig.add_axes([.03,.03,.94,.94]) - ax.set_title('Eddies detected -- Cyclonic(red) and Anticyclonic(blue)') - ax.set_ylim(-75,75) - ax.set_xlim(0,360) - ax.set_aspect('equal') - a.display(ax, color='b', linewidth=.5) - c.display(ax, color='r', linewidth=.5) + fig = plt.figure(figsize=(15, 7)) + ax = fig.add_axes([0.03, 0.03, 0.94, 0.94]) + ax.set_title("Eddies detected -- Cyclonic(red) and Anticyclonic(blue)") + ax.set_ylim(-75, 75) + ax.set_xlim(0, 360) + ax.set_aspect("equal") + a.display(ax, color="b", linewidth=0.5) + c.display(ax, color="r", linewidth=0.5) ax.grid() - fig.savefig('share/png/eddies.png') + fig.savefig("share/png/eddies.png") .. image:: ../share/png/eddies.png @@ -87,7 +91,8 @@ Save identification data .. code-block:: python from netCDF import Dataset - with Dataset(date.strftime('share/Anticyclonic_%Y%m%d.nc'), 'w') as h: + + with Dataset(date.strftime("share/Anticyclonic_%Y%m%d.nc"), "w") as h: a.to_netcdf(h) - with Dataset(date.strftime('share/Cyclonic_%Y%m%d.nc'), 'w') as h: + with Dataset(date.strftime("share/Cyclonic_%Y%m%d.nc"), "w") as h: c.to_netcdf(h) diff --git a/doc/grid_load_display.rst b/doc/grid_load_display.rst index 2e570274..2f0e3765 100644 --- a/doc/grid_load_display.rst +++ b/doc/grid_load_display.rst @@ -7,7 +7,12 @@ Loading grid .. code-block:: python from py_eddy_tracker.dataset.grid import RegularGridDataset - grid_name, lon_name, lat_name = 'share/nrt_global_allsat_phy_l4_20190223_20190226.nc', 'longitude', 'latitude' + + grid_name, lon_name, lat_name = ( + "share/nrt_global_allsat_phy_l4_20190223_20190226.nc", + "longitude", + "latitude", + ) h = RegularGridDataset(grid_name, lon_name, lat_name) Plotting grid @@ -15,14 +20,15 @@ Plotting grid .. code-block:: python from matplotlib import pyplot as plt + fig = plt.figure(figsize=(14, 12)) - ax = fig.add_axes([.02, .51, .9, .45]) - ax.set_title('ADT (m)') + ax = fig.add_axes([0.02, 0.51, 0.9, 0.45]) + ax.set_title("ADT (m)") ax.set_ylim(-75, 75) - ax.set_aspect('equal') - m = h.display(ax, name='adt', vmin=-1, vmax=1) + ax.set_aspect("equal") + m = h.display(ax, name="adt", vmin=-1, vmax=1) ax.grid(True) - plt.colorbar(m, cax=fig.add_axes([.94, .51, .01, .45])) + plt.colorbar(m, cax=fig.add_axes([0.94, 0.51, 0.01, 0.45])) Filtering @@ -30,27 +36,27 @@ Filtering .. code-block:: python h = RegularGridDataset(grid_name, lon_name, lat_name) - h.bessel_high_filter('adt', 500, order=3) + h.bessel_high_filter("adt", 500, order=3) Save grid .. code-block:: python - h.write('/tmp/grid.nc') + h.write("/tmp/grid.nc") Add second plot .. 
code-block:: python - ax = fig.add_axes([.02, .02, .9, .45]) - ax.set_title('ADT Filtered (m)') - ax.set_aspect('equal') + ax = fig.add_axes([0.02, 0.02, 0.9, 0.45]) + ax.set_title("ADT Filtered (m)") + ax.set_aspect("equal") ax.set_ylim(-75, 75) - m = h.display(ax, name='adt', vmin=-.1, vmax=.1) + m = h.display(ax, name="adt", vmin=-0.1, vmax=0.1) ax.grid(True) - plt.colorbar(m, cax=fig.add_axes([.94, .02, .01, .45])) - fig.savefig('share/png/filter.png') + plt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.45])) + fig.savefig("share/png/filter.png") .. image:: ../share/png/filter.png \ No newline at end of file diff --git a/doc/installation.rst b/doc/installation.rst index 40ce9ad5..b2bcb45c 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -2,7 +2,13 @@ How do I get set up ? ===================== -Source are available on github https://github.com/AntSimi/py-eddy-tracker +You could install stable version with pip + +.. code-block:: bash + + pip install pyEddyTracker + +Or with source which are available on github https://github.com/AntSimi/py-eddy-tracker Use python3. To avoid problems with installation, use of the virtualenv Python virtual environment is recommended or conda. diff --git a/doc/run_tracking.rst b/doc/run_tracking.rst index 95fa45f4..36290339 100644 --- a/doc/run_tracking.rst +++ b/doc/run_tracking.rst @@ -5,9 +5,9 @@ Tracking Requirements ************ -Before to run tracking, you will need to run identification on every time step of the period(Period of your study). +Before tracking, you will need to run identification on every time step of the period (period of your study). -**Advice** : Before to run tracking, display some identification file allow to learn a lot +**Advice** : Before tracking, displaying some identification files. 
You will learn a lot Default method ************** @@ -24,9 +24,9 @@ Example of conf.yaml FILES_PATTERN: MY_IDENTIFICATION_PATH/Anticyclonic*.nc SAVE_DIR: MY_OUTPUT_PATH - # Number of timestep for missing detection + # Number of consecutive timesteps with missing detection allowed VIRTUAL_LENGTH_MAX: 3 - # Minimal time to consider as a full track + # Minimal number of timesteps to considered as a long trajectory TRACK_DURATION_MIN: 10 To run: @@ -35,24 +35,24 @@ To run: EddyTracking conf.yaml -v DEBUG -It will use default tracker: +It will use the default tracker: -- No travel longer than 125 km between two observation -- Amplitude and speed radius must be close to previous observation -- In case of several candidate only closest is kept +- No travel longer than 125 km between two observations +- Amplitude and speed radius must be close to the previous observation +- In case of several candidates only the closest is kept It will produce 4 files by run: -- A file of correspondances which will contains all the information to merge all identifications file -- A file which will contains all the observations which are alone -- A file which will contains all the short track which are shorter than **TRACK_DURATION_MIN** -- A file which will contains all the long track which are longer than **TRACK_DURATION_MIN** +- A file of correspondences which will contain all the information to merge all identifications file +- A file which will contain all the observations which are alone +- A file which will contain all the short tracks which are shorter than **TRACK_DURATION_MIN** +- A file which will contain all the long tracks which are longer than **TRACK_DURATION_MIN** -Use python module +Use Python module ***************** -An example of tracking with python module is available in the gallery: +An example of tracking with the Python module is available in the gallery: :ref:`sphx_glr_python_module_08_tracking_manipulation_pet_run_a_tracking.py` Choose a tracker @@ -63,13 +63,13 @@ With yaml you could also select another tracker: .. code-block:: yaml PATHS: - # Files produces with EddyIdentification + # Files produced with EddyIdentification FILES_PATTERN: MY/IDENTIFICATION_PATH/Anticyclonic*.nc SAVE_DIR: MY_OUTPUT_PATH - # Number of timestep for missing detection + # Number of consecutive timesteps with missing detection allowed VIRTUAL_LENGTH_MAX: 3 - # Minimal time to consider as a full track + # Minimal number of timesteps to considered as a long trajectory TRACK_DURATION_MIN: 10 CLASS: @@ -80,5 +80,5 @@ With yaml you could also select another tracker: # py_eddy_tracker.observations.observation.EddiesObservations CLASS: CheltonTracker -This tracker is like described in CHELTON11[https://doi.org/10.1016/j.pocean.2011.01.002]. +This tracker is like the one described in CHELTON11[https://doi.org/10.1016/j.pocean.2011.01.002]. Code is here :meth:`py_eddy_tracker.featured_tracking.old_tracker_reference` diff --git a/doc/spectrum.rst b/doc/spectrum.rst index d751b909..f96e30a0 100644 --- a/doc/spectrum.rst +++ b/doc/spectrum.rst @@ -11,7 +11,7 @@ Load data raw = RegularGridDataset(grid_name, lon_name, lat_name) filtered = RegularGridDataset(grid_name, lon_name, lat_name) - filtered.bessel_low_filter('adt', 150, order=3) + filtered.bessel_low_filter("adt", 150, order=3) areas = dict( sud_pacific=dict(llcrnrlon=188, urcrnrlon=280, llcrnrlat=-64, urcrnrlat=-7), @@ -23,24 +23,33 @@ Compute and display spectrum .. 
code-block:: python - fig = plt.figure(figsize=(10,6)) + fig = plt.figure(figsize=(10, 6)) ax = fig.add_subplot(111) - ax.set_title('Spectrum') - ax.set_xlabel('km') + ax.set_title("Spectrum") + ax.set_xlabel("km") for name_area, area in areas.items(): - - lon_spec, lat_spec = raw.spectrum_lonlat('adt', area=area) - mappable = ax.loglog(*lat_spec, label='lat %s raw' % name_area)[0] - ax.loglog(*lon_spec, label='lon %s raw' % name_area, color=mappable.get_color(), linestyle='--') - - lon_spec, lat_spec = filtered.spectrum_lonlat('adt', area=area) - mappable = ax.loglog(*lat_spec, label='lat %s high' % name_area)[0] - ax.loglog(*lon_spec, label='lon %s high' % name_area, color=mappable.get_color(), linestyle='--') - - ax.set_xscale('log') + lon_spec, lat_spec = raw.spectrum_lonlat("adt", area=area) + mappable = ax.loglog(*lat_spec, label="lat %s raw" % name_area)[0] + ax.loglog( + *lon_spec, + label="lon %s raw" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + lon_spec, lat_spec = filtered.spectrum_lonlat("adt", area=area) + mappable = ax.loglog(*lat_spec, label="lat %s high" % name_area)[0] + ax.loglog( + *lon_spec, + label="lon %s high" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + ax.set_xscale("log") ax.legend() ax.grid() - fig.savefig('share/png/spectrum.png') + fig.savefig("share/png/spectrum.png") .. image:: ../share/png/spectrum.png @@ -49,18 +58,23 @@ Compute and display spectrum ratio .. code-block:: python - fig = plt.figure(figsize=(10,6)) + fig = plt.figure(figsize=(10, 6)) ax = fig.add_subplot(111) - ax.set_title('Spectrum ratio') - ax.set_xlabel('km') + ax.set_title("Spectrum ratio") + ax.set_xlabel("km") for name_area, area in areas.items(): - lon_spec, lat_spec = filtered.spectrum_lonlat('adt', area=area, ref=raw) - mappable = ax.plot(*lat_spec, label='lat %s high' % name_area)[0] - ax.plot(*lon_spec, label='lon %s high' % name_area, color=mappable.get_color(), linestyle='--') - - ax.set_xscale('log') + lon_spec, lat_spec = filtered.spectrum_lonlat("adt", area=area, ref=raw) + mappable = ax.plot(*lat_spec, label="lat %s high" % name_area)[0] + ax.plot( + *lon_spec, + label="lon %s high" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + ax.set_xscale("log") ax.legend() ax.grid() - fig.savefig('share/png/spectrum_ratio.png') + fig.savefig("share/png/spectrum_ratio.png") .. image:: ../share/png/spectrum_ratio.png diff --git a/environment.yml b/environment.yml new file mode 100644 index 00000000..e94c7bc1 --- /dev/null +++ b/environment.yml @@ -0,0 +1,11 @@ +name: binder-pyeddytracker +channels: + - conda-forge +dependencies: + - python>=3.10, <3.13 + - pip + - ffmpeg + - pip: + - -r requirements.txt + - pyeddytrackersample + - . diff --git a/examples/01_general_things/README.rst b/examples/01_general_things/README.rst new file mode 100644 index 00000000..5876c1b6 --- /dev/null +++ b/examples/01_general_things/README.rst @@ -0,0 +1,3 @@ +General features +================ + diff --git a/examples/01_general_things/pet_storage.py b/examples/01_general_things/pet_storage.py new file mode 100644 index 00000000..918ebbee --- /dev/null +++ b/examples/01_general_things/pet_storage.py @@ -0,0 +1,158 @@ +""" +How data is stored +================== + +General information about eddies storage. + +All files have the same structure, with more or less fields and possible different order. 
+
+There are 3 classes of files:
+
+- **Eddies collections** : contain a list of eddies without links between them
+- **Track eddies collections** :
+  manage eddies associated in trajectories, the ```track``` field allows to separate each trajectory
+- **Network eddies collections** :
+  manage eddies associated in networks, the ```track``` and ```segment``` fields allow to separate observations
+"""
+
+from matplotlib import pyplot as plt
+from numpy import arange, outer
+import py_eddy_tracker_sample
+
+from py_eddy_tracker.data import get_demo_path
+from py_eddy_tracker.observations.network import NetworkObservations
+from py_eddy_tracker.observations.observation import EddiesObservations, Table
+from py_eddy_tracker.observations.tracking import TrackEddiesObservations
+
+# %%
+# Eddies can be stored in 2 formats with the same structure:
+#
+# - zarr (https://zarr.readthedocs.io/en/stable/), which allows efficient IO, ...
+# - NetCDF4 (https://unidata.github.io/netcdf4-python/), a well-known format
+#
+# Each field is stored as a column and each row corresponds to 1 observation;
+# array-like fields such as contour/profile are 2D columns.
+
+# %%
+# Eddies files (zarr or netcdf) can be loaded with the ```load_file``` method:
+eddies_collections = EddiesObservations.load_file(get_demo_path("Cyclonic_20160515.nc"))
+eddies_collections.field_table()
+# offset and scale_factor are used only when data is stored in zarr or netCDF4
+
+# %%
+# Field access
+# ------------
+# To access the whole field, here ```amplitude```
+eddies_collections.amplitude
+
+# To access only a specific part of the field
+eddies_collections.amplitude[4:15]
+
+# %%
+# Data matrix is a numpy ndarray
+eddies_collections.obs
+# %%
+eddies_collections.obs.dtype
+
+
+# %%
+# Contour storage
+# ---------------
+# All contours are stored on the same number of points; they are resampled if needed before being stored
+
+# %%
+# Speed profile storage
+# ---------------------
+# The speed profile is an interpolation of the mean speed along each contour.
+# For each contour included in an eddy, we compute the mean speed along the contour,
+# and then we interpolate the mean speed array onto a fixed-size array.
+#
+# Several fields are available to understand "uavg_profile" :
+#   0. - num_contours : number of contours in the eddy, must be equal to the amplitude divided by the isoline step
+#   1. - height_inner_contour : height of the inner contour used
+#   2. - height_max_speed_contour : height of the max speed contour used
+#   3. - height_external_contour : height of the outer contour used
+#
+# The last value of "uavg_profile" is for the inner contour and the first value for the outer contour.
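+
+# %%
+# Purely illustrative sketch (not the library's internal routine) of how a variable-length
+# list of per-contour mean speeds could be resampled onto the fixed-size "uavg_profile"
+# array with a simple linear interpolation; `raw_means` below is made-up data.
+from numpy import interp, linspace
+
+raw_means = [0.05, 0.12, 0.20, 0.26, 0.30, 0.28]
+nb_sample = eddies_collections.uavg_profile.shape[1]
+fixed_profile = interp(
+    linspace(0, 1, nb_sample), linspace(0, 1, len(raw_means)), raw_means
+)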
+ +# Observations selection of "uavg_profile" with high number of contour(Eddy with high amplitude) +e = eddies_collections.extract_with_mask(eddies_collections.num_contours > 15) + +# %% + +# Raw display of profiles with more than 15 contours +ax = plt.subplot(111) +_ = ax.plot(e.uavg_profile.T, lw=0.5) + +# %% + +# Profile from inner to outter +ax = plt.subplot(111) +ax.plot(e.uavg_profile[:, ::-1].T, lw=0.5) +_ = ax.set_xlabel("From inner to outter contour"), ax.set_ylabel("Speed (m/s)") + +# %% + +# If we normalize indice of contour to set speed contour to 1 and inner contour to 0 +ax = plt.subplot(111) +h_in = e.height_inner_contour +h_s = e.height_max_speed_contour +h_e = e.height_external_contour +r = (h_e - h_in) / (h_s - h_in) +nb_pt = e.uavg_profile.shape[1] +# Create an x array for each profile +x = outer(arange(nb_pt) / nb_pt, r) + +ax.plot(x, e.uavg_profile[:, ::-1].T, lw=0.5) +_ = ax.set_xlabel("From inner to outter contour"), ax.set_ylabel("Speed (m/s)") + + +# %% +# Trajectories +# ------------ +# Tracks eddies collections add several fields : +# +# - **track** : Trajectory number +# - **observation_flag** : Flag indicating if the value is interpolated between two observations or not +# (0: observed eddy, 1: interpolated eddy)" +# - **observation_number** : Eddy temporal index in a trajectory, days starting at the eddy first detection +# - **cost_association** : result of the cost function to associate the eddy with the next observation +eddies_tracks = TrackEddiesObservations.load_file( + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") +) +# In this example some fields are removed (effective_contour_longitude,...) in order to save time for doc building +eddies_tracks.field_table() + +# %% +# Networks +# -------- +# Network files use some specific fields : +# +# - track : ID of network (ID 0 correspond to lonely eddies) +# - segment : ID of a segment within a network (from 1 to N) +# - previous_obs : Index of the previous observation in the full dataset, +# if -1 there are no previous observation (the segment starts) +# - next_obs : Index of the next observation in the full dataset, if -1 there are no next observation (the segment ends) +# - previous_cost : Result of the cost function (1 is a good association, 0 is bad) with previous observation +# - next_cost : Result of the cost function (1 is a good association, 0 is bad) with next observation +eddies_network = NetworkObservations.load_file(get_demo_path("network_med.nc")) +eddies_network.field_table() + +# %% +sl = slice(70, 100) +Table( + eddies_network.network(651).obs[sl][ + [ + "time", + "track", + "segment", + "previous_obs", + "previous_cost", + "next_obs", + "next_cost", + ] + ] +) + +# %% +# Networks are ordered by increasing network number (`track`), then increasing segment number, then increasing time diff --git a/examples/02_eddy_identification/README.rst b/examples/02_eddy_identification/README.rst index 30d23e5b..07ef8f44 100644 --- a/examples/02_eddy_identification/README.rst +++ b/examples/02_eddy_identification/README.rst @@ -1,2 +1,4 @@ Eddy detection ============== + +Method to detect eddies from grid and display, with various parameters. 
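+
+A typical identification call, sketched below only as a reminder (the grid file name is
+hypothetical; the scripts in this section give the full, runnable context):
+
+.. code-block:: python
+
+    from datetime import datetime
+
+    from py_eddy_tracker.dataset.grid import RegularGridDataset
+
+    g = RegularGridDataset("my_grid.nc", "longitude", "latitude")  # hypothetical file
+    a, c = g.eddy_identification(
+        "adt", "ugos", "vgos",  # variables used for identification
+        datetime(2019, 2, 23),  # date of identification
+        0.002,  # step between two isolines of detection (m)
+        shape_error=55,  # max error (%) between contour and fitted circle
+    )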
diff --git a/examples/02_eddy_identification/pet_contour_circle.py b/examples/02_eddy_identification/pet_contour_circle.py index 197ae357..03332285 100644 --- a/examples/02_eddy_identification/pet_contour_circle.py +++ b/examples/02_eddy_identification/pet_contour_circle.py @@ -11,7 +11,7 @@ # %% # Load detection files -a = EddiesObservations.load_file(data.get_path("Anticyclonic_20190223.nc")) +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) # %% # Plot the speed and effective (dashed) contours @@ -25,4 +25,4 @@ # Replace contours by circles using center and radius (effective is dashed) a.circle_contour() a.display(ax, label="Anticyclonic circle", color="g", lw=1) -ax.legend(loc="upper right") +_ = ax.legend(loc="upper right") diff --git a/examples/02_eddy_identification/pet_display_id.py b/examples/02_eddy_identification/pet_display_id.py index 37c2b863..57c59bc2 100644 --- a/examples/02_eddy_identification/pet_display_id.py +++ b/examples/02_eddy_identification/pet_display_id.py @@ -11,8 +11,8 @@ # %% # Load detection files -a = EddiesObservations.load_file(data.get_path("Anticyclonic_20190223.nc")) -c = EddiesObservations.load_file(data.get_path("Cyclonic_20190223.nc")) +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) +c = EddiesObservations.load_file(data.get_demo_path("Cyclonic_20190223.nc")) # %% # Fill effective contour with amplitude diff --git a/examples/02_eddy_identification/pet_eddy_detection.py b/examples/02_eddy_identification/pet_eddy_detection.py index d158e870..b1b2c1af 100644 --- a/examples/02_eddy_identification/pet_eddy_detection.py +++ b/examples/02_eddy_identification/pet_eddy_detection.py @@ -35,7 +35,9 @@ def update_axes(ax, mappable=None): # %% # Load Input grid, ADT is used to detect eddies g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) ax = start_axes("ADT (m)") @@ -91,12 +93,12 @@ def update_axes(ax, mappable=None): update_axes(ax) # %% -# Creteria for rejecting a contour -# 0. - Accepted (green) -# 1. - Rejection for shape error (red) -# 2. - Masked value within contour (blue) -# 3. - Under or over the pixel limit bounds (black) -# 4. - Amplitude criterion (yellow) +# Criteria for rejecting a contour: +# 0. - Accepted (green) +# 1. - Rejection for shape error (red) +# 2. - Masked value within contour (blue) +# 3. - Under or over the pixel limit bounds (black) +# 4. 
- Amplitude criterion (yellow) ax = start_axes("Contours' rejection criteria") g.contours.display(ax, only_unused=True, lw=0.5, display_criterion=True) update_axes(ax) @@ -143,10 +145,9 @@ def update_axes(ax, mappable=None): # %% # Display the speed radius of the detected eddies ax = start_axes("Speed Radius (km)") -a.scatter(ax, "radius_s", vmin=10, vmax=50, s=80, ref=-10, cmap="magma_r", factor=0.001) -m = c.scatter( - ax, "radius_s", vmin=10, vmax=50, s=80, ref=-10, cmap="magma_r", factor=0.001 -) +kwargs = dict(vmin=10, vmax=50, s=80, ref=-10, cmap="magma_r", factor=0.001) +a.scatter(ax, "radius_s", **kwargs) +m = c.scatter(ax, "radius_s", **kwargs) update_axes(ax, m) # %% @@ -154,7 +155,5 @@ def update_axes(ax, mappable=None): ax = start_axes("Effective Radius (km)") kwargs = dict(vmin=10, vmax=80, cmap="magma_r", factor=0.001, lut=14, ref=-10) a.filled(ax, "effective_radius", **kwargs) -m = c.filled( - ax, "radius_e", vmin=10, vmax=80, cmap="magma_r", factor=0.001, lut=14, ref=-10 -) +m = c.filled(ax, "radius_e", **kwargs) update_axes(ax, m) diff --git a/examples/02_eddy_identification/pet_eddy_detection_ACC.py b/examples/02_eddy_identification/pet_eddy_detection_ACC.py new file mode 100644 index 00000000..d12c62f3 --- /dev/null +++ b/examples/02_eddy_identification/pet_eddy_detection_ACC.py @@ -0,0 +1,207 @@ +""" +Eddy detection : Antartic Circumpolar Current +============================================= + +This script detect eddies on the ADT field, and compute u,v with the method add_uv (use it only if the Equator is avoided) + +Two detections are provided : with a filtered ADT and without filtering + +""" + +from datetime import datetime + +from matplotlib import pyplot as plt, style + +from py_eddy_tracker import data +from py_eddy_tracker.dataset.grid import RegularGridDataset + +pos_cb = [0.1, 0.52, 0.83, 0.015] +pos_cb2 = [0.1, 0.07, 0.4, 0.015] + + +def quad_axes(title): + style.use("default") + fig = plt.figure(figsize=(13, 10)) + fig.suptitle(title, weight="bold", fontsize=14) + axes = list() + + ax_pos = dict( + topleft=[0.1, 0.54, 0.4, 0.38], + topright=[0.53, 0.54, 0.4, 0.38], + botleft=[0.1, 0.09, 0.4, 0.38], + botright=[0.53, 0.09, 0.4, 0.38], + ) + + for key, position in ax_pos.items(): + ax = fig.add_axes(position) + ax.set_xlim(5, 45), ax.set_ylim(-60, -37) + ax.set_aspect("equal"), ax.grid(True) + axes.append(ax) + if "right" in key: + ax.set_yticklabels("") + return fig, axes + + +def set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight="semibold"): + for ax in fig.get_axes(): + ax.grid() + ax.grid(which="major", linestyle="-", linewidth="0.5", color="black") + if ax.get_ylabel() != "": + ax.set_ylabel(ax.get_ylabel(), fontsize=labelsize, fontweight=labelweight) + if ax.get_xlabel() != "": + ax.set_xlabel(ax.get_xlabel(), fontsize=labelsize, fontweight=labelweight) + if ax.get_title() != "": + ax.set_title(ax.get_title(), fontsize=labelsize, fontweight=labelweight) + ax.tick_params(labelsize=ticklabelsize) + + +# %% +# Load Input grid, ADT is used to detect eddies +margin = 30 + +kw_data = dict( + filename=data.get_demo_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), + x_name="longitude", + y_name="latitude", + # Manual area subset + indexs=dict( + latitude=slice(100 - margin, 220 + margin), + longitude=slice(0, 230 + margin), + ), +) +g_raw = RegularGridDataset(**kw_data) +g_raw.add_uv("adt") +g = RegularGridDataset(**kw_data) +g.copy("adt", "adt_low") +g.bessel_high_filter("adt", 700) +g.bessel_low_filter("adt_low", 700) +g.add_uv("adt") + 
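+
+# %%
+# Rough sketch (not the library's implementation) of the geostrophic balance behind
+# add_uv: u = -(g0 / f) * d(adt)/dy and v = (g0 / f) * d(adt)/dx, with f = 2 * omega * sin(lat),
+# which is why it should be avoided close to the Equator where f ~ 0.
+# Every name below (g0, omega, r_earth, u_sketch, v_sketch) is local to this sketch.
+from numpy import cos, deg2rad, gradient, sin
+
+g0, omega, r_earth = 9.81, 7.292e-5, 6.37e6
+f_cor = 2 * omega * sin(deg2rad(g.y_c))
+adt = g.grid("adt")
+d_adt_dx = gradient(adt, deg2rad(g.x_c), axis=0) / (r_earth * cos(deg2rad(g.y_c)))
+d_adt_dy = gradient(adt, deg2rad(g.y_c), axis=1) / r_earth
+u_sketch, v_sketch = -g0 / f_cor * d_adt_dy, g0 / f_cor * d_adt_dx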
+# %% +# Identification +# ^^^^^^^^^^^^^^ +# Run the identification step with slices of 2 mm +date = datetime(2019, 2, 23) +kw_ident = dict( + date=date, step=0.002, shape_error=70, sampling=30, uname="u", vname="v" +) +a, c = g.eddy_identification("adt", **kw_ident) +a_, c_ = g_raw.eddy_identification("adt", **kw_ident) + + +# %% +# Figures +# ------- +kw_adt = dict(vmin=-1.5, vmax=1.5, cmap=plt.get_cmap("RdBu_r", 30)) +fig, axs = quad_axes("General properties field") +g_raw.display(axs[0], "adt", **kw_adt) +axs[0].set_title("Total ADT (m)") +m = g.display(axs[1], "adt_low", **kw_adt) +axs[1].set_title("ADT (m) large scale, cutoff at 700 km") +m2 = g.display(axs[2], "adt", cmap=plt.get_cmap("RdBu_r", 20), vmin=-0.5, vmax=0.5) +axs[2].set_title("ADT (m) high-pass filtered, a cutoff at 700 km") +cb = plt.colorbar(m, cax=axs[0].figure.add_axes(pos_cb), orientation="horizontal") +cb.set_label("ADT (m)", labelpad=0) +cb2 = plt.colorbar(m2, cax=axs[2].figure.add_axes(pos_cb2), orientation="horizontal") +cb2.set_label("ADT (m)", labelpad=0) +set_fancy_labels(fig) + +# %% +# The large-scale North-South gradient is removed by the filtering step. + +# %% +fig, axs = quad_axes("") +axs[0].set_title("Without filter") +axs[0].set_ylabel("Contours used in eddies") +axs[1].set_title("With filter") +axs[2].set_ylabel("Closed contours but not used") +g_raw.contours.display(axs[0], lw=0.5, only_used=True) +g.contours.display(axs[1], lw=0.5, only_used=True) +g_raw.contours.display(axs[2], lw=0.5, only_unused=True) +g.contours.display(axs[3], lw=0.5, only_unused=True) +set_fancy_labels(fig) + +# %% +# Removing the large-scale North-South gradient reveals closed contours in the +# South-Western corner of the ewample region. + +# %% +kw = dict(ref=-10, linewidth=0.75) +kw_a = dict(color="r", label="Anticyclonic ({nb_obs} eddies)") +kw_c = dict(color="b", label="Cyclonic ({nb_obs} eddies)") +kw_filled = dict(vmin=0, vmax=100, cmap="Spectral_r", lut=20, intern=True, factor=100) +fig, axs = quad_axes("Comparison between two detections") +# Match with intern/inner contour +i_a, j_a, s_a = a_.match(a, intern=True, cmin=0.15) +i_c, j_c, s_c = c_.match(c, intern=True, cmin=0.15) + +a_.index(i_a).filled(axs[0], s_a, **kw_filled) +a.index(j_a).filled(axs[1], s_a, **kw_filled) +c_.index(i_c).filled(axs[0], s_c, **kw_filled) +m = c.index(j_c).filled(axs[1], s_c, **kw_filled) + +cb = plt.colorbar(m, cax=axs[0].figure.add_axes(pos_cb), orientation="horizontal") +cb.set_label("Similarity index (%)", labelpad=-5) +a_.display(axs[0], **kw, **kw_a), c_.display(axs[0], **kw, **kw_c) +a.display(axs[1], **kw, **kw_a), c.display(axs[1], **kw, **kw_c) + +axs[0].set_title("Without filter") +axs[0].set_ylabel("Detection") +axs[1].set_title("With filter") +axs[2].set_ylabel("Contours' rejection criteria") + +g_raw.contours.display(axs[2], lw=0.5, only_unused=True, display_criterion=True) +g.contours.display(axs[3], lw=0.5, only_unused=True, display_criterion=True) + +for ax in axs: + ax.legend() + +set_fancy_labels(fig) + +# %% +# Very similar eddies have Similarity Indexes >= 40% + +# %% +# Criteria for rejecting a contour : +# 0. Accepted (green) +# 1. Rejection for shape error (red) +# 2. Masked value within contour (blue) +# 3. Under or over the pixel limit bounds (black) +# 4. 
Amplitude criterion (yellow) + +# %% +i_a, j_a = i_a[s_a >= 0.4], j_a[s_a >= 0.4] +i_c, j_c = i_c[s_c >= 0.4], j_c[s_c >= 0.4] +fig = plt.figure(figsize=(12, 12)) +fig.suptitle(f"Scatter plot (A : {i_a.shape[0]}, C : {i_c.shape[0]} matches)") + +for i, (label, field, factor, stop) in enumerate( + ( + ("Speed radius (km)", "radius_s", 0.001, 120), + ("Effective radius (km)", "radius_e", 0.001, 120), + ("Amplitude (cm)", "amplitude", 100, 25), + ("Speed max (cm/s)", "speed_average", 100, 25), + ) +): + ax = fig.add_subplot(2, 2, i + 1, title=label) + ax.set_xlabel("Without filter") + ax.set_ylabel("With filter") + + ax.plot( + a_[field][i_a] * factor, + a[field][j_a] * factor, + "r.", + label="Anticyclonic", + ) + ax.plot( + c_[field][i_c] * factor, + c[field][j_c] * factor, + "b.", + label="Cyclonic", + ) + ax.set_aspect("equal"), ax.grid() + ax.plot((0, 1000), (0, 1000), "g") + ax.set_xlim(0, stop), ax.set_ylim(0, stop) + ax.legend() + +set_fancy_labels(fig) diff --git a/examples/02_eddy_identification/pet_eddy_detection_gulf_stream.py b/examples/02_eddy_identification/pet_eddy_detection_gulf_stream.py index 27ea77ac..55267b76 100644 --- a/examples/02_eddy_identification/pet_eddy_detection_gulf_stream.py +++ b/examples/02_eddy_identification/pet_eddy_detection_gulf_stream.py @@ -37,7 +37,7 @@ def update_axes(ax, mappable=None): # Load Input grid, ADT is used to detect eddies margin = 30 g = RegularGridDataset( - data.get_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), + data.get_demo_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), "longitude", "latitude", # Manual area subset diff --git a/examples/02_eddy_identification/pet_filter_and_detection.py b/examples/02_eddy_identification/pet_filter_and_detection.py index 2feffc3e..ec02a28c 100644 --- a/examples/02_eddy_identification/pet_filter_and_detection.py +++ b/examples/02_eddy_identification/pet_filter_and_detection.py @@ -33,7 +33,9 @@ def update_axes(ax, mappable=None): # Add a new filed to store the high-pass filtered ADT g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) g.add_uv("adt") g.copy("adt", "adt_high") diff --git a/examples/02_eddy_identification/pet_interp_grid_on_dataset.py b/examples/02_eddy_identification/pet_interp_grid_on_dataset.py index 4d9a67aa..fa27a3d1 100644 --- a/examples/02_eddy_identification/pet_interp_grid_on_dataset.py +++ b/examples/02_eddy_identification/pet_interp_grid_on_dataset.py @@ -30,18 +30,20 @@ def update_axes(ax, mappable=None): # %% # Load detection files and data to interp -a = EddiesObservations.load_file(data.get_path("Anticyclonic_20160515.nc")) -c = EddiesObservations.load_file(data.get_path("Cyclonic_20160515.nc")) +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20160515.nc")) +c = EddiesObservations.load_file(data.get_demo_path("Cyclonic_20160515.nc")) aviso_map = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) aviso_map.add_uv("adt") # %% # Compute and store eke in cm²/s² aviso_map.add_grid( - "eke", (aviso_map.grid("u") ** 2 + aviso_map.grid("v") ** 2) * 0.5 * (100 ** 2) + "eke", (aviso_map.grid("u") ** 2 + aviso_map.grid("v") ** 2) * 0.5 * (100**2) ) eke_kwargs = dict(vmin=1, vmax=1000, cmap="magma_r") diff --git 
a/examples/02_eddy_identification/pet_radius_vs_area.py b/examples/02_eddy_identification/pet_radius_vs_area.py index 0239ebe7..e34ad725 100644 --- a/examples/02_eddy_identification/pet_radius_vs_area.py +++ b/examples/02_eddy_identification/pet_radius_vs_area.py @@ -13,7 +13,7 @@ # %% # Load detection files -a = EddiesObservations.load_file(data.get_path("Anticyclonic_20190223.nc")) +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) areas = list() # For each contour area will be compute in local reference for i in a: diff --git a/examples/02_eddy_identification/pet_shape_gallery.py b/examples/02_eddy_identification/pet_shape_gallery.py index 4e37d727..ed8df83d 100644 --- a/examples/02_eddy_identification/pet_shape_gallery.py +++ b/examples/02_eddy_identification/pet_shape_gallery.py @@ -25,7 +25,9 @@ def build_circle(x0, y0, r): # %% # We iterate over closed contours and sort with regards of shape error g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) c = Contours(g.x_c, g.y_c, g.grid("adt") * 100, arange(-50, 50, 0.2)) contours = dict() diff --git a/examples/02_eddy_identification/pet_sla_and_adt.py b/examples/02_eddy_identification/pet_sla_and_adt.py index 307d8cca..29dcc0a7 100644 --- a/examples/02_eddy_identification/pet_sla_and_adt.py +++ b/examples/02_eddy_identification/pet_sla_and_adt.py @@ -31,7 +31,9 @@ def update_axes(ax, mappable=None): # Load Input grid, ADT will be used to detect eddies g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) g.add_uv("adt", "ugos", "vgos") g.add_uv("sla", "ugosa", "vgosa") diff --git a/examples/02_eddy_identification/pet_statistics_on_identification.py b/examples/02_eddy_identification/pet_statistics_on_identification.py new file mode 100644 index 00000000..dbd73c61 --- /dev/null +++ b/examples/02_eddy_identification/pet_statistics_on_identification.py @@ -0,0 +1,105 @@ +""" +Stastics on identification files +================================ + +Some statistics on raw identification without any tracking +""" +from matplotlib import pyplot as plt +from matplotlib.dates import date2num +import numpy as np + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.observations.observation import EddiesObservations + +start_logger().setLevel("ERROR") + + +# %% +def start_axes(title): + fig = plt.figure(figsize=(13, 5)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) + ax.set_aspect("equal") + ax.set_title(title) + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9])) + + +# %% +# We load demo sample and take only first year. +# +# Replace by a list of filename to apply on your own dataset. 
+file_objects = get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" +)[:365] + +# %% +# Merge all identification datasets in one object +all_a = EddiesObservations.concatenate( + [EddiesObservations.load_file(i) for i in file_objects] +) + +# %% +# We define polygon bound +x0, x1, y0, y1 = 15, 20, 33, 38 +xs = np.array([[x0, x1, x1, x0, x0]], dtype="f8") +ys = np.array([[y0, y0, y1, y1, y0]], dtype="f8") +# Polygon object created for the match function use. +polygon = dict(contour_lon_e=xs, contour_lat_e=ys, contour_lon_s=xs, contour_lat_s=ys) + +# %% +# Geographic frequency of eddies +step = 0.125 +ax = start_axes("") +# Count pixel encompassed in each contour +g_a = all_a.grid_count(bins=((-10, 37, step), (30, 46, step)), intern=True) +m = g_a.display( + ax, cmap="terrain_r", vmin=0, vmax=0.75, factor=1 / all_a.nb_days, name="count" +) +ax.plot(polygon["contour_lon_e"][0], polygon["contour_lat_e"][0], "r") +update_axes(ax, m) + +# %% +# We use the match function to count the number of eddies that intersect the polygon defined previously +# `p1_area` option allow to get in c_e/c_s output, precentage of area occupy by eddies in the polygon. +i_e, j_e, c_e = all_a.match(polygon, p1_area=True, intern=False) +i_s, j_s, c_s = all_a.match(polygon, p1_area=True, intern=True) + +# %% +dt = np.datetime64("1970-01-01") - np.datetime64("1950-01-01") +kw_hist = dict( + bins=date2num(np.arange(21900, 22300).astype("datetime64") - dt), histtype="step" +) +# translate julian day in datetime64 +t = all_a.time.astype("datetime64") - dt +# %% +# Number of eddies within a polygon +ax = plt.figure(figsize=(12, 6)).add_subplot(111) +ax.set_title("Different ways to count eddies within a polygon") +ax.set_ylabel("Count") +m = all_a.mask_from_polygons(((xs, ys),)) +ax.hist(t[m], label="Eddy Center in polygon", **kw_hist) +ax.hist(t[i_s[c_s > 0]], label="Intersection Speed contour and polygon", **kw_hist) +ax.hist(t[i_e[c_e > 0]], label="Intersection Effective contour and polygon", **kw_hist) +ax.legend() +ax.set_xlim(np.datetime64("2010"), np.datetime64("2011")) +ax.grid() + +# %% +# Percent of the area of interest occupied by eddies. +ax = plt.figure(figsize=(12, 6)).add_subplot(111) +ax.set_title("Percent of polygon occupied by an anticyclonic eddy") +ax.set_ylabel("Percent of polygon") +ax.hist(t[i_s[c_s > 0]], weights=c_s[c_s > 0] * 100.0, label="speed contour", **kw_hist) +ax.hist( + t[i_e[c_e > 0]], weights=c_e[c_e > 0] * 100.0, label="effective contour", **kw_hist +) +ax.legend(), ax.set_ylim(0, 25) +ax.set_xlim(np.datetime64("2010"), np.datetime64("2011")) +ax.grid() diff --git a/examples/06_grid_manipulation/README.rst b/examples/06_grid_manipulation/README.rst index 7a300f82..a885d867 100644 --- a/examples/06_grid_manipulation/README.rst +++ b/examples/06_grid_manipulation/README.rst @@ -1,2 +1,2 @@ Grid Manipulation -================= \ No newline at end of file +================= diff --git a/examples/06_grid_manipulation/pet_advect.py b/examples/06_grid_manipulation/pet_advect.py new file mode 100644 index 00000000..d7cc67e9 --- /dev/null +++ b/examples/06_grid_manipulation/pet_advect.py @@ -0,0 +1,178 @@ +""" +Grid advection +============== + +Dummy advection which use only static geostrophic current, which didn't solve the complex circulation of the ocean. 
+""" +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from numpy import arange, isnan, meshgrid, ones + +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.observation import EddiesObservations + +# %% +# Load Input grid ADT +g = RegularGridDataset( + get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" +) +# Compute u/v from height +g.add_uv("adt") + +# %% +# Load detection files +a = EddiesObservations.load_file(get_demo_path("Anticyclonic_20160515.nc")) +c = EddiesObservations.load_file(get_demo_path("Cyclonic_20160515.nc")) + + +# %% +# Quiver from u/v with eddies +fig = plt.figure(figsize=(10, 5)) +ax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) +ax.set_xlim(19, 30), ax.set_ylim(31, 36.5), ax.grid() +x, y = meshgrid(g.x_c, g.y_c) +a.filled(ax, facecolors="r", alpha=0.1), c.filled(ax, facecolors="b", alpha=0.1) +_ = ax.quiver(x.T, y.T, g.grid("u"), g.grid("v"), scale=20) + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +# Anim +# ---- +# Particles setup +step_p = 1 / 8 +x, y = meshgrid(arange(13, 36, step_p), arange(28, 40, step_p)) +x, y = x.reshape(-1), y.reshape(-1) +# Remove all original position that we can't advect at first place +m = ~isnan(g.interp("u", x, y)) +x0, y0 = x[m], y[m] +x, y = x0.copy(), y0.copy() + +# %% +# Movie properties +kwargs = dict(frames=arange(51), interval=100) +kw_p = dict(u_name="u", v_name="v", nb_step=2, time_step=21600) +frame_t = kw_p["nb_step"] * kw_p["time_step"] / 86400.0 + + +# %% +# Function +def anim_ax(**kw): + t = 0 + fig = plt.figure(figsize=(10, 5), dpi=55) + axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) + axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid() + a.filled(axes, facecolors="r", alpha=0.1), c.filled(axes, facecolors="b", alpha=0.1) + line = axes.plot([], [], "k", **kw)[0] + return fig, axes.text(21, 32.1, ""), line, t + + +def update(i_frame, t_step): + global t + x, y = p.__next__() + t += t_step + l.set_data(x, y) + txt.set_text(f"T0 + {t:.1f} days") + + +# %% +# Filament forward +# ^^^^^^^^^^^^^^^^ +# Draw 3 last position in one path for each particles., +# it could be run backward with `backward=True` option in filament method +p = g.filament(x, y, **kw_p, filament_size=3) +fig, txt, l, t = anim_ax(lw=0.5) +_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,)) + +# %% +# Particle forward +# ^^^^^^^^^^^^^^^^^ +# Forward advection of particles +p = g.advect(x, y, **kw_p) +fig, txt, l, t = anim_ax(ls="", marker=".", markersize=1) +_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,)) + +# %% +# We get last position and run backward until original position +p = g.advect(x, y, **kw_p, backward=True) +fig, txt, l, _ = anim_ax(ls="", marker=".", markersize=1) +_ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,)) + +# %% +# Particles stat +# -------------- + +# 
%% +# Time_step settings +# ^^^^^^^^^^^^^^^^^^ +# Dummy experiment to test advection precision, we run particles 50 days forward and backward with different time step +# and we measure distance between new positions and original positions. +fig = plt.figure() +ax = fig.add_subplot(111) +kw = dict( + bins=arange(0, 50, 0.001), + cumulative=True, + weights=ones(x0.shape) / x0.shape[0] * 100.0, + histtype="step", +) +for time_step in (10800, 21600, 43200, 86400): + x, y = x0.copy(), y0.copy() + kw_advect = dict( + nb_step=int(50 * 86400 / time_step), time_step=time_step, u_name="u", v_name="v" + ) + g.advect(x, y, **kw_advect).__next__() + g.advect(x, y, **kw_advect, backward=True).__next__() + d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 + ax.hist(d, **kw, label=f"{86400. / time_step:.0f} time step by day") +ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() +ax.set_title("Distance after 50 days forward and 50 days backward") +ax.set_xlabel("Distance between original position and final position (in degrees)") +_ = ax.set_ylabel("Percent of particles with distance lesser than") + +# %% +# Time duration +# ^^^^^^^^^^^^^ +# We keep same time_step but change time duration +fig = plt.figure() +ax = fig.add_subplot(111) +time_step = 10800 +for duration in (5, 50, 100): + x, y = x0.copy(), y0.copy() + kw_advect = dict( + nb_step=int(duration * 86400 / time_step), + time_step=time_step, + u_name="u", + v_name="v", + ) + g.advect(x, y, **kw_advect).__next__() + g.advect(x, y, **kw_advect, backward=True).__next__() + d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 + ax.hist(d, **kw, label=f"Time duration {duration} days") +ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() +ax.set_title( + "Distance after N days forward and N days backward\nwith a time step of 1/8 days" +) +ax.set_xlabel("Distance between original position and final position (in degrees)") +_ = ax.set_ylabel("Percent of particles with distance lesser than ") diff --git a/examples/06_grid_manipulation/pet_filter.py b/examples/06_grid_manipulation/pet_filter.py index 2975325d..ae4356d7 100644 --- a/examples/06_grid_manipulation/pet_filter.py +++ b/examples/06_grid_manipulation/pet_filter.py @@ -32,7 +32,9 @@ def update_axes(ax, mappable=None): # %% # All information will be for regular grid g = RegularGridDataset( - data.get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + data.get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), + "longitude", + "latitude", ) # %% # Kernel diff --git a/examples/06_grid_manipulation/pet_hide_pixel_out_eddies.py b/examples/06_grid_manipulation/pet_hide_pixel_out_eddies.py index 58a31374..388c9c7f 100644 --- a/examples/06_grid_manipulation/pet_hide_pixel_out_eddies.py +++ b/examples/06_grid_manipulation/pet_hide_pixel_out_eddies.py @@ -15,12 +15,12 @@ # %% # Load an eddy file which contains contours -a = EddiesObservations.load_file(data.get_path("Anticyclonic_20190223.nc")) +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) # %% # Load a grid where we want found pixels in eddies or out g = RegularGridDataset( - data.get_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), + data.get_demo_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), "longitude", "latitude", ) diff --git a/examples/06_grid_manipulation/pet_lavd.py b/examples/06_grid_manipulation/pet_lavd.py new file mode 100644 index 00000000..a3ea846e --- /dev/null +++ b/examples/06_grid_manipulation/pet_lavd.py @@ -0,0 
+1,177 @@ +""" +LAVD experiment +=============== + +Naive method to reproduce LAVD(Lagrangian-Averaged Vorticity deviation) method with a static velocity field. +In the current example we didn't remove a mean vorticity. + +Method are described here: + + - Abernathey, Ryan, and George Haller. "Transport by Lagrangian Vortices in the Eastern Pacific", + Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021, + https://doi.org/10.1175/JPO-D-17-0102.1 + - `Transport by Coherent Lagrangian Vortices`_, + R. Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019, + Talk a t the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop + +.. _Transport by Coherent Lagrangian Vortices: + https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf + +""" +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from numpy import arange, meshgrid, zeros + +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.observation import EddiesObservations + + +# %% +def start_ax(title="", dpi=90): + fig = plt.figure(figsize=(16, 9), dpi=dpi) + ax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) + ax.set_xlim(0, 32), ax.set_ylim(28, 46) + ax.set_title(title) + return fig, ax, ax.text(3, 32, "", fontsize=20) + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + cb = plt.colorbar( + mappable, + cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]), + orientation="horizontal", + ) + cb.set_label("Vorticity integration along trajectory at initial position") + return cb + + +kw_vorticity = dict(vmin=0, vmax=2e-5, cmap="viridis") + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +# Data +# ---- +# To compute vorticity (:math:`\omega`) we compute u/v field with a stencil and apply the following equation with stencil +# method : +# +# .. 
math:: +# \omega = \frac{\partial v}{\partial x} - \frac{\partial u}{\partial y} +g = RegularGridDataset( + get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" +) +g.add_uv("adt") +u_y = g.compute_stencil(g.grid("u"), vertical=True) +v_x = g.compute_stencil(g.grid("v")) +g.vars["vort"] = v_x - u_y + +# %% +# Display vorticity field +fig, ax, _ = start_ax() +mappable = g.display(ax, abs(g.grid("vort")), **kw_vorticity) +cb = update_axes(ax, mappable) +cb.set_label("Vorticity") + +# %% +# Particles +# --------- +# Particles specification +step = 1 / 32 +x_g, y_g = arange(0, 36, step), arange(28, 46, step) +x, y = meshgrid(x_g, y_g) +original_shape = x.shape +x, y = x.reshape(-1), y.reshape(-1) +print(f"{len(x)} particles advected") +# A frame every 8h +step_by_day = 3 +# Compute step of advection every 4h +nb_step = 2 +kw_p = dict( + nb_step=nb_step, time_step=86400 / step_by_day / nb_step, u_name="u", v_name="v" +) +# Start a generator which at each iteration return new position at next time step +particule = g.advect(x, y, **kw_p, rk4=True) + +# %% +# LAVD +# ---- +lavd = zeros(original_shape) +# Advection time +nb_days = 8 +# Nb frame +nb_time = step_by_day * nb_days +i = 0.0 + + +# %% +# Anim +# ^^^^ +# Movie of LAVD integration at each integration time step. +def update(i_frame): + global lavd, i + i += 1 + x, y = particule.__next__() + # Interp vorticity on new_position + lavd += abs(g.interp("vort", x, y).reshape(original_shape) * 1 / nb_time) + txt.set_text(f"T0 + {i / step_by_day:.2f} days of advection") + pcolormesh.set_array(lavd / i * nb_time) + return pcolormesh, txt + + +kw_video = dict(frames=arange(nb_time), interval=1000.0 / step_by_day / 2, blit=True) +fig, ax, txt = start_ax(dpi=60) +x_g_, y_g_ = ( + arange(0 - step / 2, 36 + step / 2, step), + arange(28 - step / 2, 46 + step / 2, step), +) +# pcolorfast will be faster than pcolormesh, we could use pcolorfast due to x and y are regular +pcolormesh = ax.pcolorfast(x_g_, y_g_, lavd, **kw_vorticity) +update_axes(ax, pcolormesh) +_ = VideoAnimation(ax.figure, update, **kw_video) + +# %% +# Final LAVD +# ^^^^^^^^^^ + +# %% +# Format LAVD data +lavd = RegularGridDataset.with_array( + coordinates=("lon", "lat"), datas=dict(lavd=lavd.T, lon=x_g, lat=y_g), centered=True +) + +# %% +# Display final LAVD with py eddy tracker detection. +# Period used for LAVD integration (8 days) is too short for a real use, but choose for example efficiency. +fig, ax, _ = start_ax() +mappable = lavd.display(ax, "lavd", **kw_vorticity) +EddiesObservations.load_file(get_demo_path("Anticyclonic_20160515.nc")).display( + ax, color="k" +) +EddiesObservations.load_file(get_demo_path("Cyclonic_20160515.nc")).display( + ax, color="k" +) +_ = update_axes(ax, mappable) diff --git a/examples/06_grid_manipulation/pet_okubo_weiss.py b/examples/06_grid_manipulation/pet_okubo_weiss.py index 577178c5..aa8a063e 100644 --- a/examples/06_grid_manipulation/pet_okubo_weiss.py +++ b/examples/06_grid_manipulation/pet_okubo_weiss.py @@ -1,8 +1,8 @@ r""" Get Okubo Weis -===================== +============== -.. math:: OW = S_n^2 + S_s^2 + \omega^2 +.. 
math:: OW = S_n^2 + S_s^2 - \omega^2 with normal strain (:math:`S_n`), shear strain (:math:`S_s`) and vorticity (:math:`\omega`) @@ -40,13 +40,13 @@ def update_axes(axes, mappable=None): # %% # Load detection files -a = EddiesObservations.load_file(data.get_path("Anticyclonic_20190223.nc")) -c = EddiesObservations.load_file(data.get_path("Cyclonic_20190223.nc")) +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) +c = EddiesObservations.load_file(data.get_demo_path("Cyclonic_20190223.nc")) # %% # Load Input grid, ADT will be used to detect eddies g = RegularGridDataset( - data.get_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), + data.get_demo_path("nrt_global_allsat_phy_l4_20190223_20190226.nc"), "longitude", "latitude", ) diff --git a/examples/07_cube_manipulation/README.rst b/examples/07_cube_manipulation/README.rst new file mode 100644 index 00000000..7cecfbd4 --- /dev/null +++ b/examples/07_cube_manipulation/README.rst @@ -0,0 +1,2 @@ +Time grid computation +===================== diff --git a/examples/07_cube_manipulation/pet_cube.py b/examples/07_cube_manipulation/pet_cube.py new file mode 100644 index 00000000..cba6c85b --- /dev/null +++ b/examples/07_cube_manipulation/pet_cube.py @@ -0,0 +1,152 @@ +""" +Time advection +============== + +Example which use CMEMS surface current with a Runge-Kutta 4 algorithm to advect particles. +""" +from datetime import datetime, timedelta + +# sphinx_gallery_thumbnail_number = 2 +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from numpy import arange, isnan, meshgrid, ones + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection +from py_eddy_tracker.gui import GUI_AXES + +start_logger().setLevel("ERROR") + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +# Data +# ---- +# Load Input time grid ADT +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + # To create U/V variable + heigth="adt", +) + +# %% +# Anim +# ---- +# Particles setup +step_p = 1 / 8 +x, y = meshgrid(arange(13, 36, step_p), arange(28, 40, step_p)) +x, y = x.reshape(-1), y.reshape(-1) +# Remove all original position that we can't advect at first place +t0 = 20181 +m = ~isnan(c[t0].interp("u", x, y)) +x0, y0 = x[m], y[m] +x, y = x0.copy(), y0.copy() + + +# %% +# Function +def anim_ax(**kw): + fig = plt.figure(figsize=(10, 5), dpi=55) + axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) + axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid() + line = axes.plot([], [], "k", **kw)[0] + return fig, axes.text(21, 32.1, ""), line + + +def update(_): + tt, xt, yt = f.__next__() + mappable.set_data(xt, yt) + d = timedelta(tt / 86400.0) + datetime(1950, 1, 1) + txt.set_text(f"{d:%Y/%m/%d-%H}") + + +# %% +f = c.filament(x, y, "u", "v", t_init=t0, nb_step=2, time_step=21600, filament_size=3) +fig, txt, mappable = 
anim_ax(lw=0.5) +ani = VideoAnimation(fig, update, frames=arange(160), interval=100) + + +# %% +# Particules stat +# --------------- +# Time_step settings +# ^^^^^^^^^^^^^^^^^^ +# Dummy experiment to test advection precision, we run particles 50 days forward and backward with different time step +# and we measure distance between new positions and original positions. +fig = plt.figure() +ax = fig.add_subplot(111) +kw = dict( + bins=arange(0, 50, 0.002), + cumulative=True, + weights=ones(x0.shape) / x0.shape[0] * 100.0, + histtype="step", +) +kw_p = dict(u_name="u", v_name="v", nb_step=1) +for time_step in (10800, 21600, 43200, 86400): + x, y = x0.copy(), y0.copy() + nb = int(30 * 86400 / time_step) + # Go forward + p = c.advect(x, y, time_step=time_step, t_init=20181.5, **kw_p) + for i in range(nb): + t_, _, _ = p.__next__() + # Go backward + p = c.advect(x, y, time_step=time_step, backward=True, t_init=t_ / 86400.0, **kw_p) + for i in range(nb): + t_, _, _ = p.__next__() + d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 + ax.hist(d, **kw, label=f"{86400. / time_step:.0f} time step by day") +ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() +ax.set_title("Distance after 50 days forward and 50 days backward") +ax.set_xlabel("Distance between original position and final position (in degrees)") +_ = ax.set_ylabel("Percent of particles with distance lesser than") + +# %% +# Time duration +# ^^^^^^^^^^^^^ +# We keep same time_step but change time duration +fig = plt.figure() +ax = fig.add_subplot(111) +time_step = 10800 +for duration in (10, 40, 80): + x, y = x0.copy(), y0.copy() + nb = int(duration * 86400 / time_step) + # Go forward + p = c.advect(x, y, time_step=time_step, t_init=20181.5, **kw_p) + for i in range(nb): + t_, _, _ = p.__next__() + # Go backward + p = c.advect(x, y, time_step=time_step, backward=True, t_init=t_ / 86400.0, **kw_p) + for i in range(nb): + t_, _, _ = p.__next__() + d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 + ax.hist(d, **kw, label=f"Time duration {duration} days") +ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() +ax.set_title( + "Distance after N days forward and N days backward\nwith a time step of 1/8 days" +) +ax.set_xlabel("Distance between original position and final position (in degrees)") +_ = ax.set_ylabel("Percent of particles with distance lesser than ") diff --git a/examples/07_cube_manipulation/pet_fsle_med.py b/examples/07_cube_manipulation/pet_fsle_med.py new file mode 100644 index 00000000..9d78ea02 --- /dev/null +++ b/examples/07_cube_manipulation/pet_fsle_med.py @@ -0,0 +1,200 @@ +""" +FSLE experiment in med +====================== + +Example to build Finite Size Lyapunov Exponents, parameter values must be adapted for your case. + +Example use a method similar to `AVISO flse`_ + +.. 
_AVISO flse: + https://www.aviso.altimetry.fr/en/data/products/value-added-products/ + fsle-finite-size-lyapunov-exponents/fsle-description.html + +""" + +from matplotlib import pyplot as plt +from numba import njit +from numpy import arange, arctan2, empty, isnan, log, ma, meshgrid, ones, pi, zeros + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset + +start_logger().setLevel("ERROR") + + +# %% +# ADT in med +# ---------- +# :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_cube` method is +# made for data stores in time cube, you could use also +# :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` method to +# load data-cube from multiple file. +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + # To create U/V variable + heigth="adt", +) + + +# %% +# Methods to compute FSLE +# ----------------------- +@njit(cache=True, fastmath=True) +def check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6): + """ + Check if distance between eastern or northern particle to center particle is bigger than `dist_max` + """ + nb_p = x.shape[0] // 3 + delta = dist_max**2 + for i in range(nb_p): + i0 = i * 3 + i_n = i0 + 1 + i_e = i0 + 2 + # If particle already set, we skip + if m[i0] or m[i_n] or m[i_e]: + continue + # Distance with north + dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n] + dn = dxn**2 + dyn**2 + # Distance with east + dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e] + de = dxe**2 + dye**2 + + if dn >= delta or de >= delta: + s1 = dn + de + at1 = 2 * (dxe * dxn + dye * dyn) + at2 = de - dn + s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * ( + (dxn - dye) ** 2 + (dxe + dyn) ** 2 + ) + flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init**2) * (s1 + s2**0.5)) + theta[i] = arctan2(at1, at2 + s2) * 180 / pi + # To know where value are set + m_set[i] = False + # To stop particle advection + m[i0], m[i_n], m[i_e] = True, True, True + + +@njit(cache=True) +def build_triplet(x, y, step=0.02): + """ + Triplet building for each position we add east and north point with defined step + """ + nb_x = x.shape[0] + x_ = empty(nb_x * 3, dtype=x.dtype) + y_ = empty(nb_x * 3, dtype=y.dtype) + for i in range(nb_x): + i0 = i * 3 + i_n, i_e = i0 + 1, i0 + 2 + x__, y__ = x[i], y[i] + x_[i0], y_[i0] = x__, y__ + x_[i_n], y_[i_n] = x__, y__ + step + x_[i_e], y_[i_e] = x__ + step, y__ + return x_, y_ + + +# %% +# Settings +# -------- + +# Step in degrees for ouput +step_grid_out = 1 / 25.0 +# Initial separation in degrees +dist_init = 1 / 50.0 +# Final separation in degrees +dist_max = 1 / 5.0 +# Time of start +t0 = 20268 +# Number of time step by days +time_step_by_days = 5 +# Maximal time of advection +# Here we limit because our data cube cover only 3 month +nb_days = 85 +# Backward or forward +backward = True + +# %% +# Particles +# --------- +x0_, y0_ = -5, 30 +lon_p = arange(x0_, x0_ + 43, step_grid_out) +lat_p = arange(y0_, y0_ + 16, step_grid_out) +y0, x0 = meshgrid(lat_p, lon_p) +grid_shape = x0.shape +x0, y0 = x0.reshape(-1), y0.reshape(-1) +# Identify all particle not on land +m = ~isnan(c[t0].interp("adt", x0, y0)) +x0, y0 = x0[m], y0[m] + +# %% +# FSLE +# ---- + +# Array to compute fsle +fsle = zeros(x0.shape[0], dtype="f4") +theta = zeros(x0.shape[0], dtype="f4") +mask = ones(x0.shape[0], dtype="f4") +x, y = build_triplet(x0, y0, dist_init) +used = zeros(x.shape[0], 
dtype="bool") + +# advection generator +kw = dict( + t_init=t0, nb_step=1, backward=backward, mask_particule=used, u_name="u", v_name="v" +) +p = c.advect(x, y, time_step=86400 / time_step_by_days, **kw) + +# We check at each step of advection if particle distance is over `dist_max` +for i in range(time_step_by_days * nb_days): + t, xt, yt = p.__next__() + dt = t / 86400.0 - t0 + check_p(xt, yt, fsle, theta, mask, used, dt, dist_max=dist_max, dist_init=dist_init) + +# Get index with original_position +i = ((x0 - x0_) / step_grid_out).astype("i4") +j = ((y0 - y0_) / step_grid_out).astype("i4") +fsle_ = empty(grid_shape, dtype="f4") +theta_ = empty(grid_shape, dtype="f4") +mask_ = ones(grid_shape, dtype="bool") +fsle_[i, j] = fsle +theta_[i, j] = theta +mask_[i, j] = mask +# Create a grid object +fsle_custom = RegularGridDataset.with_array( + coordinates=("lon", "lat"), + datas=dict( + fsle=ma.array(fsle_, mask=mask_), + theta=ma.array(theta_, mask=mask_), + lon=lon_p, + lat=lat_p, + ), + centered=True, +) + +# %% +# Display FSLE +# ------------ +fig = plt.figure(figsize=(13, 5), dpi=150) +ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) +ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) +ax.set_aspect("equal") +ax.set_title("Finite size lyapunov exponent", weight="bold") +kw = dict(cmap="viridis_r", vmin=-20, vmax=0) +m = fsle_custom.display(ax, 1 / fsle_custom.grid("fsle"), **kw) +ax.grid() +_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9])) +# %% +# Display Theta +# ------------- +fig = plt.figure(figsize=(13, 5), dpi=150) +ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) +ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) +ax.set_aspect("equal") +ax.set_title("Theta from finite size lyapunov exponent", weight="bold") +kw = dict(cmap="Spectral_r", vmin=-180, vmax=180) +m = fsle_custom.display(ax, fsle_custom.grid("theta"), **kw) +ax.grid() +_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9])) diff --git a/examples/07_cube_manipulation/pet_lavd_detection.py b/examples/07_cube_manipulation/pet_lavd_detection.py new file mode 100644 index 00000000..4dace120 --- /dev/null +++ b/examples/07_cube_manipulation/pet_lavd_detection.py @@ -0,0 +1,218 @@ +""" +LAVD detection and geometric detection +====================================== + +Naive method to reproduce LAVD(Lagrangian-Averaged Vorticity deviation). +In the current example we didn't remove a mean vorticity. + +Method are described here: + + - Abernathey, Ryan, and George Haller. "Transport by Lagrangian Vortices in the Eastern Pacific", + Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021, + https://doi.org/10.1175/JPO-D-17-0102.1 + - `Transport by Coherent Lagrangian Vortices`_, + R. Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019, + Talk a t the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop + +.. 
_Transport by Coherent Lagrangian Vortices: + https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf + +""" +from datetime import datetime + +from matplotlib import pyplot as plt +from numpy import arange, isnan, ma, meshgrid, zeros + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset +from py_eddy_tracker.gui import GUI_AXES + +start_logger().setLevel("ERROR") + + +# %% +class LAVDGrid(RegularGridDataset): + def init_speed_coef(self, uname="u", vname="v"): + """Hack to be able to identify eddy with LAVD field""" + self._speed_ev = self.grid("lavd") + + @classmethod + def from_(cls, x, y, z): + z.mask += isnan(z.data) + datas = dict(lavd=z, lon=x, lat=y) + return cls.with_array(coordinates=("lon", "lat"), datas=datas, centered=True) + + +# %% +def start_ax(title="", dpi=90): + fig = plt.figure(figsize=(12, 5), dpi=dpi) + ax = fig.add_axes([0.05, 0.08, 0.9, 0.9], projection=GUI_AXES) + ax.set_xlim(-6, 36), ax.set_ylim(31, 45) + ax.set_title(title) + return fig, ax, ax.text(3, 32, "", fontsize=20) + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + cb = plt.colorbar( + mappable, + cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]), + orientation="horizontal", + ) + cb.set_label("LAVD at initial position") + return cb + + +kw_lavd = dict(vmin=0, vmax=2e-5, cmap="viridis") + +# %% +# Data +# ---- + +# Load data cube of 3 month +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + heigth="adt", +) + +# Add vorticity at each time step +for g in c: + u_y = g.compute_stencil(g.grid("u"), vertical=True) + v_x = g.compute_stencil(g.grid("v")) + g.vars["vort"] = v_x - u_y + +# %% +# Particles +# --------- + +# Time properties, for example with advection only 25 days +nb_days, step_by_day = 25, 6 +nb_time = step_by_day * nb_days +kw_p = dict(nb_step=1, time_step=86400 / step_by_day, u_name="u", v_name="v") +t0 = 20236 +t0_grid = c[t0] +# Geographic properties, we use a coarser resolution for time consuming reasons +step = 1 / 32.0 +x_g, y_g = arange(-6, 36, step), arange(30, 46, step) +x0, y0 = meshgrid(x_g, y_g) +original_shape = x0.shape +x0, y0 = x0.reshape(-1), y0.reshape(-1) +# Get all particles in defined area +m = ~isnan(t0_grid.interp("vort", x0, y0)) +x0, y0 = x0[m], y0[m] +print(f"{x0.size} particles advected") +# Gridded mask +m = m.reshape(original_shape) + +# %% +# LAVD forward (dynamic field) +# ---------------------------- +lavd = zeros(original_shape) +lavd_ = lavd[m] +p = c.advect(x0.copy(), y0.copy(), t_init=t0, **kw_p) +for _ in range(nb_time): + t, x, y = p.__next__() + lavd_ += abs(c.interp("vort", t / 86400.0, x, y)) +lavd[m] = lavd_ / nb_time +# Put LAVD result in a standard py eddy tracker grid +lavd_forward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T) +# Display +fig, ax, _ = start_ax("LAVD with a forward advection") +mappable = lavd_forward.display(ax, "lavd", **kw_lavd) +_ = update_axes(ax, mappable) + +# %% +# LAVD backward (dynamic field) +# ----------------------------- +lavd = zeros(original_shape) +lavd_ = lavd[m] +p = c.advect(x0.copy(), y0.copy(), t_init=t0, backward=True, **kw_p) +for i in range(nb_time): + t, x, y = p.__next__() + lavd_ += abs(c.interp("vort", t / 86400.0, x, y)) +lavd[m] = lavd_ / nb_time +# Put LAVD result in a standard py eddy tracker grid +lavd_backward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, 
mask=~m).T) +# Display +fig, ax, _ = start_ax("LAVD with a backward advection") +mappable = lavd_backward.display(ax, "lavd", **kw_lavd) +_ = update_axes(ax, mappable) + +# %% +# LAVD forward (static field) +# --------------------------- +lavd = zeros(original_shape) +lavd_ = lavd[m] +p = t0_grid.advect(x0.copy(), y0.copy(), **kw_p) +for _ in range(nb_time): + x, y = p.__next__() + lavd_ += abs(t0_grid.interp("vort", x, y)) +lavd[m] = lavd_ / nb_time +# Put LAVD result in a standard py eddy tracker grid +lavd_forward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T) +# Display +fig, ax, _ = start_ax("LAVD with a forward advection on a static velocity field") +mappable = lavd_forward_static.display(ax, "lavd", **kw_lavd) +_ = update_axes(ax, mappable) + +# %% +# LAVD backward (static field) +# ---------------------------- +lavd = zeros(original_shape) +lavd_ = lavd[m] +p = t0_grid.advect(x0.copy(), y0.copy(), backward=True, **kw_p) +for i in range(nb_time): + x, y = p.__next__() + lavd_ += abs(t0_grid.interp("vort", x, y)) +lavd[m] = lavd_ / nb_time +# Put LAVD result in a standard py eddy tracker grid +lavd_backward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T) +# Display +fig, ax, _ = start_ax("LAVD with a backward advection on a static velocity field") +mappable = lavd_backward_static.display(ax, "lavd", **kw_lavd) +_ = update_axes(ax, mappable) + +# %% +# Contour detection +# ----------------- +# To extract contour from LAVD grid, we will used method design for SSH, with some hacks and adapted options. +# It will produce false amplitude and speed. +kw_ident = dict( + force_speed_unit="m/s", + force_height_unit="m", + pixel_limit=(40, 200000), + date=datetime(2005, 5, 18), + uname=None, + vname=None, + grid_height="lavd", + shape_error=70, + step=1e-6, +) +fig, ax, _ = start_ax("Detection of eddies with several method") +t0_grid.bessel_high_filter("adt", 700) +a, c = t0_grid.eddy_identification( + "adt", "u", "v", kw_ident["date"], step=0.002, shape_error=70 +) +kw_ed = dict(ax=ax, intern=True, ref=-10) +a.filled( + facecolors="#FFEFCD", label="Anticyclonic SSH detection {nb_obs} eddies", **kw_ed +) +c.filled(facecolors="#DEDEDE", label="Cyclonic SSH detection {nb_obs} eddies", **kw_ed) +kw_cont = dict(ax=ax, extern_only=True, ls="-", ref=-10) +forward, _ = lavd_forward.eddy_identification(**kw_ident) +forward.display(label="LAVD forward {nb_obs} eddies", color="g", **kw_cont) +backward, _ = lavd_backward.eddy_identification(**kw_ident) +backward.display(label="LAVD backward {nb_obs} eddies", color="r", **kw_cont) +forward, _ = lavd_forward_static.eddy_identification(**kw_ident) +forward.display(label="LAVD forward static {nb_obs} eddies", color="cyan", **kw_cont) +backward, _ = lavd_backward_static.eddy_identification(**kw_ident) +backward.display( + label="LAVD backward static {nb_obs} eddies", color="orange", **kw_cont +) +ax.legend() +update_axes(ax) diff --git a/examples/07_cube_manipulation/pet_particles_drift.py b/examples/07_cube_manipulation/pet_particles_drift.py new file mode 100644 index 00000000..3d7aa1a4 --- /dev/null +++ b/examples/07_cube_manipulation/pet_particles_drift.py @@ -0,0 +1,46 @@ +""" +Build path of particle drifting +=============================== + +""" + +from matplotlib import pyplot as plt +from numpy import arange, meshgrid + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection + +start_logger().setLevel("ERROR") + +# %% +# Load data 
cube +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + unset=True, +) + +# %% +# Advection properties +nb_days, step_by_day = 10, 6 +nb_time = step_by_day * nb_days +kw_p = dict(nb_step=1, time_step=86400 / step_by_day) +t0 = 20210 + +# %% +# Get paths +x0, y0 = meshgrid(arange(32, 35, 0.5), arange(32.5, 34.5, 0.5)) +x0, y0 = x0.reshape(-1), y0.reshape(-1) +t, x, y = c.path(x0, y0, h_name="adt", t_init=t0, **kw_p, nb_time=nb_time) + +# %% +# Plot paths +ax = plt.figure(figsize=(9, 6)).add_subplot(111, aspect="equal") +ax.plot(x0, y0, "k.", ms=20) +ax.plot(x, y, lw=3) +ax.set_title("10 days particle paths") +ax.set_xlim(31, 35), ax.set_ylim(32, 34.5) +ax.grid() diff --git a/examples/08_tracking_manipulation/README.rst b/examples/08_tracking_manipulation/README.rst index 2626f19d..a971049f 100644 --- a/examples/08_tracking_manipulation/README.rst +++ b/examples/08_tracking_manipulation/README.rst @@ -1,2 +1,4 @@ Tracking Manipulation ===================== + +Method to subset and display atlas. \ No newline at end of file diff --git a/examples/08_tracking_manipulation/pet_display_field.py b/examples/08_tracking_manipulation/pet_display_field.py index b1add536..b943a2ba 100644 --- a/examples/08_tracking_manipulation/pet_display_field.py +++ b/examples/08_tracking_manipulation/pet_display_field.py @@ -4,15 +4,15 @@ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load an experimental cyclonic atlas, we keep only eddies which are follow more than 180 days c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) c = c.extract_with_length((180, -1)) diff --git a/examples/08_tracking_manipulation/pet_display_track.py b/examples/08_tracking_manipulation/pet_display_track.py index 624395ac..b15d51d7 100644 --- a/examples/08_tracking_manipulation/pet_display_track.py +++ b/examples/08_tracking_manipulation/pet_display_track.py @@ -4,18 +4,20 @@ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load experimental atlas a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) print(a) diff --git a/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py b/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py new file mode 100644 index 00000000..8161ad81 --- /dev/null +++ b/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py @@ -0,0 +1,94 @@ +""" +Correspondances +=============== + +Correspondances is a mechanism to intend to continue tracking with new detection + +""" + +import logging + +# %% +from matplotlib import pyplot as plt +from netCDF4 import Dataset + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.featured_tracking.area_tracker 
import AreaTracker + +# In order to hide some warning +import py_eddy_tracker.observations.observation +from py_eddy_tracker.tracking import Correspondances + +py_eddy_tracker.observations.observation._display_check_warning = False + + +# %% +def plot_eddy(ed): + fig = plt.figure(figsize=(10, 5)) + ax = fig.add_axes([0.05, 0.03, 0.90, 0.94]) + ed.plot(ax, ref=-10, marker="x") + lc = ed.display_color(ax, field=ed.time, ref=-10, intern=True) + plt.colorbar(lc).set_label("Time in Julian days (from 1950/01/01)") + ax.set_xlim(4.5, 8), ax.set_ylim(36.8, 38.3) + ax.set_aspect("equal") + ax.grid() + + +# %% +# Get remote data, we will keep only 20 first days, +# `get_remote_demo_sample` function is only to get demo dataset, in your own case give a list of identification filename +# and don't mix cyclonic and anticyclonic files. +file_objects = get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" +)[:20] + +# %% +# We run a traking with a tracker which use contour overlap, on 10 first time step +c_first_run = Correspondances( + datasets=file_objects[:10], class_method=AreaTracker, virtual=4 +) +start_logger().setLevel("INFO") +c_first_run.track() +start_logger().setLevel("WARNING") +with Dataset("correspondances.nc", "w") as h: + c_first_run.to_netcdf(h) +# Next step are done only to build atlas and display it +c_first_run.prepare_merging() + +# We have now an eddy object +eddies_area_tracker = c_first_run.merge(raw_data=False) +eddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0 +eddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1) + +# %% +# Plot from first ten days +plot_eddy(eddies_area_tracker) + +# %% +# Restart from previous run +# ------------------------- +# We give all filenames, the new one and filename from previous run +c_second_run = Correspondances( + datasets=file_objects[:20], + # This parameter must be identical in each run + class_method=AreaTracker, + virtual=4, + # Previous saved correspondancs + previous_correspondance="correspondances.nc", +) +start_logger().setLevel("INFO") +c_second_run.track() +start_logger().setLevel("WARNING") +c_second_run.prepare_merging() +# We have now another eddy object +eddies_area_tracker_extend = c_second_run.merge(raw_data=False) +eddies_area_tracker_extend.virtual[:] = eddies_area_tracker_extend.time == 0 +eddies_area_tracker_extend.filled_by_interpolation( + eddies_area_tracker_extend.virtual == 1 +) + + +# %% +# Plot with time extension +plot_eddy(eddies_area_tracker_extend) diff --git a/examples/08_tracking_manipulation/pet_one_track.py b/examples/08_tracking_manipulation/pet_one_track.py index 710ab6cf..a2536c34 100644 --- a/examples/08_tracking_manipulation/pet_one_track.py +++ b/examples/08_tracking_manipulation/pet_one_track.py @@ -2,15 +2,17 @@ One Track =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load experimental atlas, and we select one eddy a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) eddy = a.extract_ids([9672]) eddy_f = a.extract_ids([9672]) diff --git a/examples/08_tracking_manipulation/pet_run_a_tracking.py b/examples/08_tracking_manipulation/pet_run_a_tracking.py index f7ac5ca9..15d8b18b 100644 --- 
a/examples/08_tracking_manipulation/pet_run_a_tracking.py +++ b/examples/08_tracking_manipulation/pet_run_a_tracking.py @@ -7,16 +7,16 @@ # %% -from py_eddy_tracker.data import get_remote_sample +from py_eddy_tracker.data import get_remote_demo_sample from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker from py_eddy_tracker.gui import GUI from py_eddy_tracker.tracking import Correspondances # %% # Get remote data, we will keep only 180 first days, -# `get_remote_sample` function is only to get demo dataset, in your own case give a list of identification filename +# `get_remote_demo_sample` function is only to get demo dataset, in your own case give a list of identification filename # and don't mix cyclonic and anticyclonic files. -file_objects = get_remote_sample( +file_objects = get_remote_demo_sample( "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" )[:180] diff --git a/examples/08_tracking_manipulation/pet_select_track_across_area.py b/examples/08_tracking_manipulation/pet_select_track_across_area.py index bd18585d..58184e1f 100644 --- a/examples/08_tracking_manipulation/pet_select_track_across_area.py +++ b/examples/08_tracking_manipulation/pet_select_track_across_area.py @@ -3,15 +3,15 @@ ============================ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load experimental atlas, we filter position to have nice display c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) c.position_filter(median_half_window=1, loess_half_window=5) diff --git a/examples/08_tracking_manipulation/pet_track_anim.py b/examples/08_tracking_manipulation/pet_track_anim.py index f65ad157..94e09ad3 100644 --- a/examples/08_tracking_manipulation/pet_track_anim.py +++ b/examples/08_tracking_manipulation/pet_track_anim.py @@ -2,7 +2,9 @@ Track animation =============== -Run in a terminal this script, which allow to watch eddy evolution +Run in a terminal this script, which allow to watch eddy evolution. + +You could use also *EddyAnim* script to display/save animation. """ import py_eddy_tracker_sample @@ -13,7 +15,9 @@ # %% # Load experimental atlas, and we select one eddy a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) # We get only 300 first step to save time of documentation builder eddy = a.extract_ids([9672]).index(slice(0, 300)) diff --git a/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py b/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py index 3fb04450..b686fd67 100644 --- a/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py +++ b/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py @@ -2,33 +2,59 @@ Track animation with standard matplotlib ======================================== -Run in a terminal this script, which allow to watch eddy evolution +Run in a terminal this script, which allow to watch eddy evolution. + +You could use also *EddyAnim* script to display/save animation. 
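+
+The ``VideoAnimation`` class defined below is only a thin helper around
+matplotlib ``FuncAnimation``: ``_repr_html_`` embeds an HTML5 player in the
+built documentation, and ``save`` writes an empty file instead of the gif
+thumbnail, which shortens the documentation build.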
""" -import py_eddy_tracker_sample +import re + from matplotlib.animation import FuncAnimation from numpy import arange +import py_eddy_tracker_sample from py_eddy_tracker.appli.gui import Anim from py_eddy_tracker.observations.tracking import TrackEddiesObservations +# sphinx_gallery_thumbnail_path = '_static/no_image.png' + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + # %% # Load experimental atlas, and we select one eddy a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) eddy = a.extract_ids([9672]) # %% # Run animation -a = Anim(eddy, intern=True, figsize=(8, 3.5), cmap="magma_r", nb_step=6) +a = Anim(eddy, intern=True, figsize=(8, 3.5), cmap="magma_r", nb_step=5, dpi=50) a.txt.set_position((17, 34.6)) a.ax.set_xlim(16.5, 23) a.ax.set_ylim(34.5, 37) # arguments to get full animation -# kwargs = dict(frames=arange(*a.period), interval=50) -# arguments to reduce compute cost for doucmentation, we display only every 10 days -kwargs = dict(frames=arange(*a.period)[200:800:10], save_count=60, interval=200) +kwargs = dict(frames=arange(*a.period)[300:800], interval=90) -ani = FuncAnimation(a.fig, a.func_animation, **kwargs) +ani = VideoAnimation(a.fig, a.func_animation, **kwargs) diff --git a/examples/10_tracking_diagnostics/README.rst b/examples/10_tracking_diagnostics/README.rst index 1c4e690a..2030c0cc 100644 --- a/examples/10_tracking_diagnostics/README.rst +++ b/examples/10_tracking_diagnostics/README.rst @@ -1,2 +1,4 @@ Tracking diagnostics ==================== + +Method to produce statistics with eddies atlas. 
\ No newline at end of file diff --git a/examples/10_tracking_diagnostics/pet_birth_and_death.py b/examples/10_tracking_diagnostics/pet_birth_and_death.py index 612c32f7..b67993a2 100644 --- a/examples/10_tracking_diagnostics/pet_birth_and_death.py +++ b/examples/10_tracking_diagnostics/pet_birth_and_death.py @@ -5,8 +5,8 @@ Following figures are based on https://doi.org/10.1016/j.pocean.2011.01.002 """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations @@ -41,10 +41,12 @@ def update_axes(ax, mappable=None): ) ) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) # %% diff --git a/examples/10_tracking_diagnostics/pet_center_count.py b/examples/10_tracking_diagnostics/pet_center_count.py index 295299cd..77a4dcda 100644 --- a/examples/10_tracking_diagnostics/pet_center_count.py +++ b/examples/10_tracking_diagnostics/pet_center_count.py @@ -2,21 +2,24 @@ Count center ============ -Do Geo stat with center and compare with frequency method show: :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_pixel_used.py` +Do Geo stat with center and compare with frequency method +show: :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_pixel_used.py` """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from matplotlib.colors import LogNorm +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) # %% @@ -24,7 +27,7 @@ step = 0.125 bins = ((-10, 37, step), (30, 46, step)) kwargs_pcolormesh = dict( - cmap="terrain_r", vmin=0, vmax=2, factor=1 / (a.nb_days * step ** 2), name="count" + cmap="terrain_r", vmin=0, vmax=2, factor=1 / (a.nb_days * step**2), name="count" ) @@ -61,7 +64,7 @@ g_c.vars["count"] = ratio m = g_c.display( - ax_ratio, name="count", vmin=0.1, vmax=10, norm=LogNorm(), cmap="coolwarm_r" + ax_ratio, name="count", norm=LogNorm(vmin=0.1, vmax=10), cmap="coolwarm_r" ) plt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.2])) diff --git a/examples/10_tracking_diagnostics/pet_geographic_stats.py b/examples/10_tracking_diagnostics/pet_geographic_stats.py index 137133b3..a2e3f6b5 100644 --- a/examples/10_tracking_diagnostics/pet_geographic_stats.py +++ b/examples/10_tracking_diagnostics/pet_geographic_stats.py @@ -4,8 +4,8 @@ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations @@ -22,10 +22,12 @@ def start_axes(title): # %% # Load an experimental med atlas over a period of 26 years (1993-2019), we merge the 2 datasets a 
= TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) a = a.merge(c) diff --git a/examples/10_tracking_diagnostics/pet_groups.py b/examples/10_tracking_diagnostics/pet_groups.py index e080310b..deedcc3f 100644 --- a/examples/10_tracking_diagnostics/pet_groups.py +++ b/examples/10_tracking_diagnostics/pet_groups.py @@ -3,16 +3,18 @@ =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, ones, percentile +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) # %% diff --git a/examples/10_tracking_diagnostics/pet_histo.py b/examples/10_tracking_diagnostics/pet_histo.py index cd7abf49..abf97c38 100644 --- a/examples/10_tracking_diagnostics/pet_histo.py +++ b/examples/10_tracking_diagnostics/pet_histo.py @@ -3,19 +3,21 @@ =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) kwargs_a = dict(label="Anticyclonic", color="r", histtype="step", density=True) kwargs_c = dict(label="Cyclonic", color="b", histtype="step", density=True) diff --git a/examples/10_tracking_diagnostics/pet_lifetime.py b/examples/10_tracking_diagnostics/pet_lifetime.py index 95f75718..4e2500fd 100644 --- a/examples/10_tracking_diagnostics/pet_lifetime.py +++ b/examples/10_tracking_diagnostics/pet_lifetime.py @@ -3,19 +3,21 @@ =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, ones +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) nb_year = (a.period[1] - a.period[0] + 1) / 365.25 diff --git a/examples/10_tracking_diagnostics/pet_normalised_lifetime.py 
b/examples/10_tracking_diagnostics/pet_normalised_lifetime.py new file mode 100644 index 00000000..1c84a8cc --- /dev/null +++ b/examples/10_tracking_diagnostics/pet_normalised_lifetime.py @@ -0,0 +1,78 @@ +""" +Normalised Eddy Lifetimes +========================= + +Example from Evan Mason +""" +from matplotlib import pyplot as plt +from numba import njit +from numpy import interp, linspace, zeros +from py_eddy_tracker_sample import get_demo_path + +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + + +# %% +@njit(cache=True) +def sum_profile(x_new, y, out): + """Will sum all interpolated given array""" + out += interp(x_new, linspace(0, 1, y.size), y) + + +class MyObs(TrackEddiesObservations): + def eddy_norm_lifetime(self, name, nb, factor=1): + """ + :param str,array name: Array or field name + :param int nb: size of output array + """ + y = self.parse_varname(name) + x = linspace(0, 1, nb) + out = zeros(nb, dtype=y.dtype) + nb_track = 0 + for i, b0, b1 in self.iter_on("track"): + y_ = y[i] + size_ = y_.size + if size_ == 0: + continue + sum_profile(x, y_, out) + nb_track += 1 + return x, out / nb_track * factor + + +# %% +# Load atlas +# ---------- +kw = dict(include_vars=("speed_radius", "amplitude", "track")) +a = MyObs.load_file( + get_demo_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr"), **kw +) +c = MyObs.load_file(get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr"), **kw) + +nb_max_a = a.nb_obs_by_track.max() +nb_max_c = c.nb_obs_by_track.max() + +# %% +# Compute normalised lifetime +# --------------------------- + +# Radius +AC_radius = a.eddy_norm_lifetime("speed_radius", nb=nb_max_a, factor=1e-3) +CC_radius = c.eddy_norm_lifetime("speed_radius", nb=nb_max_c, factor=1e-3) +# Amplitude +AC_amplitude = a.eddy_norm_lifetime("amplitude", nb=nb_max_a, factor=1e2) +CC_amplitude = c.eddy_norm_lifetime("amplitude", nb=nb_max_c, factor=1e2) + +# %% +# Figure +# ------ +fig, (ax0, ax1) = plt.subplots(nrows=2, figsize=(8, 6)) + +ax0.set_title("Normalised Mean Radius") +ax0.plot(*AC_radius), ax0.plot(*CC_radius) +ax0.set_ylabel("Radius (km)"), ax0.grid() +ax0.set_xlim(0, 1), ax0.set_ylim(0, None) + +ax1.set_title("Normalised Mean Amplitude") +ax1.plot(*AC_amplitude, label="AC"), ax1.plot(*CC_amplitude, label="CC") +ax1.set_ylabel("Amplitude (cm)"), ax1.grid(), ax1.legend() +_ = ax1.set_xlim(0, 1), ax1.set_ylim(0, None) diff --git a/examples/10_tracking_diagnostics/pet_pixel_used.py b/examples/10_tracking_diagnostics/pet_pixel_used.py index 43e59d68..75a826d6 100644 --- a/examples/10_tracking_diagnostics/pet_pixel_used.py +++ b/examples/10_tracking_diagnostics/pet_pixel_used.py @@ -2,21 +2,24 @@ Count pixel used ================ -Do Geo stat with frequency and compare with center count method: :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_center_count.py` +Do Geo stat with frequency and compare with center count +method: :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_center_count.py` """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from matplotlib.colors import LogNorm +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - 
py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) # %% @@ -60,7 +63,7 @@ g_c.vars["count"] = ratio m = g_c.display( - ax_ratio, name="count", vmin=0.1, vmax=10, norm=LogNorm(), cmap="coolwarm_r" + ax_ratio, name="count", norm=LogNorm(vmin=0.1, vmax=10), cmap="coolwarm_r" ) plt.colorbar(m, cax=fig.add_axes([0.95, 0.02, 0.01, 0.2])) diff --git a/examples/10_tracking_diagnostics/pet_propagation.py b/examples/10_tracking_diagnostics/pet_propagation.py index f13b03bb..e6bc6c1b 100644 --- a/examples/10_tracking_diagnostics/pet_propagation.py +++ b/examples/10_tracking_diagnostics/pet_propagation.py @@ -3,9 +3,9 @@ ===================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, ones +import py_eddy_tracker_sample from py_eddy_tracker.generic import cumsum_by_track from py_eddy_tracker.observations.tracking import TrackEddiesObservations @@ -13,10 +13,12 @@ # %% # Load an experimental med atlas over a period of 26 years (1993-2019) a = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr") + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) ) c = TrackEddiesObservations.load_file( - py_eddy_tracker_sample.get_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") ) nb_year = (a.period[1] - a.period[0] + 1) / 365.25 diff --git a/examples/12_external_data/README.rst b/examples/12_external_data/README.rst index cc9d4e2f..7ecbe30b 100644 --- a/examples/12_external_data/README.rst +++ b/examples/12_external_data/README.rst @@ -1,2 +1,4 @@ External data ============= + +Eddies comparison with external data diff --git a/examples/12_external_data/pet_SST_collocation.py b/examples/12_external_data/pet_SST_collocation.py index e6337754..defe00df 100644 --- a/examples/12_external_data/pet_SST_collocation.py +++ b/examples/12_external_data/pet_SST_collocation.py @@ -17,8 +17,10 @@ date = datetime(2016, 7, 7) -filename_alt = data.get_path(f"dt_blacksea_allsat_phy_l4_{date:%Y%m%d}_20200801.nc") -filename_sst = data.get_path( +filename_alt = data.get_demo_path( + f"dt_blacksea_allsat_phy_l4_{date:%Y%m%d}_20200801.nc" +) +filename_sst = data.get_demo_path( f"{date:%Y%m%d}000000-GOS-L4_GHRSST-SSTfnd-OISST_HR_REP-BLK-v02.0-fv01.0.nc" ) var_name_sst = "analysed_sst" @@ -27,10 +29,10 @@ # %% # Loading data -# ----------------------------- +# ------------ sst = RegularGridDataset(filename=filename_sst, x_name="lon", y_name="lat") alti = RegularGridDataset( - data.get_path(filename_alt), x_name="longitude", y_name="latitude" + data.get_demo_path(filename_alt), x_name="longitude", y_name="latitude" ) # We can use `Grid` tools to interpolate ADT on the sst grid sst.regrid(alti, "sla") @@ -58,14 +60,14 @@ def update_axes(ax, mappable=None, unit=""): # %% # ADT first display -# ----------------------------- +# ----------------- ax = start_axes("SLA", extent=extent) m = sst.display(ax, "sla", vmin=0.05, vmax=0.35) update_axes(ax, m, unit="[m]") # %% # SST first display -# ----------------------------- +# ----------------- # %% # We can now plot SST from `sst` diff --git a/examples/12_external_data/pet_drifter_loopers.py b/examples/12_external_data/pet_drifter_loopers.py new file mode 100644 index 00000000..5266db7b --- /dev/null +++ 
b/examples/12_external_data/pet_drifter_loopers.py @@ -0,0 +1,153 @@ +""" +Colocate looper with eddy from altimetry +======================================== + +All loopers data used in this example are a subset from the dataset described in this article +[Lumpkin, R. : Global characteristics of coherent vortices from surface drifter trajectories](https://doi.org/10.1002/2015JC011435) +""" + +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +import numpy as np +import py_eddy_tracker_sample + +from py_eddy_tracker import data +from py_eddy_tracker.appli.gui import Anim +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +def start_axes(title): + fig = plt.figure(figsize=(13, 5)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], aspect="equal") + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) + ax.set_title(title, weight="bold") + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) + + +# %% +# Load eddies dataset +cyclonic_eddies = TrackEddiesObservations.load_file( + py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr") +) +anticyclonic_eddies = TrackEddiesObservations.load_file( + py_eddy_tracker_sample.get_demo_path( + "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr" + ) +) + +# %% +# Load loopers dataset +loopers_med = TrackEddiesObservations.load_file( + data.get_demo_path("loopers_lumpkin_med.nc") +) + +# %% +# Global view +# =========== +ax = start_axes("All drifters available in Med from Lumpkin dataset") +loopers_med.plot(ax, lw=0.5, color="r", ref=-10) +update_axes(ax) + +# %% +# One segment of drifter +# ====================== +# +# Get a drifter segment (the indexes used have no correspondance with the original dataset). +looper = loopers_med.extract_ids((3588,)) +fig = plt.figure(figsize=(16, 6)) +ax = fig.add_subplot(111, aspect="equal") +looper.plot(ax, lw=0.5, label="Original position of drifter") +looper_filtered = looper.copy() +looper_filtered.position_filter(1, 13) +s = looper_filtered.scatter( + ax, + "time", + cmap=plt.get_cmap("Spectral_r", 20), + label="Filtered position of drifter", +) +plt.colorbar(s).set_label("time (days from 1/1/1950)") +ax.legend() +ax.grid() + +# %% +# Try to find a detected eddies with adt at same place. We used filtered track to simulate an eddy center +match = looper_filtered.close_tracks( + anticyclonic_eddies, method="close_center", delta=0.1, nb_obs_min=50 +) +fig = plt.figure(figsize=(16, 6)) +ax = fig.add_subplot(111, aspect="equal") +looper.plot(ax, lw=0.5, label="Original position of drifter") +looper_filtered.plot(ax, lw=1.5, label="Filtered position of drifter") +match.plot(ax, lw=1.5, label="Matched eddy") +ax.legend() +ax.grid() + +# %% +# Display radius of this 2 datasets. 
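+# ``radius_s`` is stored in meters, hence the division by 1e3 below to plot kilometers;
+# each series is also shown after a median + loess smoothing (13 day half window),
+# as indicated in the legend labels.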
+fig = plt.figure(figsize=(20, 8)) +ax = fig.add_subplot(111) +ax.plot(looper.time, looper.radius_s / 1e3, label="loopers") +looper_radius = looper.copy() +looper_radius.median_filter(1, "time", "radius_s", inplace=True) +looper_radius.loess_filter(13, "time", "radius_s", inplace=True) +ax.plot( + looper_radius.time, + looper_radius.radius_s / 1e3, + label="loopers (filtered half window 13 days)", +) +ax.plot(match.time, match.radius_s / 1e3, label="altimetry") +match_radius = match.copy() +match_radius.median_filter(1, "time", "radius_s", inplace=True) +match_radius.loess_filter(13, "time", "radius_s", inplace=True) +ax.plot( + match_radius.time, + match_radius.radius_s / 1e3, + label="altimetry (filtered half window 13 days)", +) +ax.set_ylabel("radius(km)"), ax.set_ylim(0, 100) +ax.legend() +ax.set_title("Radius from loopers and altimeter") +ax.grid() + + +# %% +# Animation of a drifter and its colocated eddy +def update(frame): + # We display last 5 days of loopers trajectory + m = (looper.time < frame) * (looper.time > (frame - 5)) + anim.func_animation(frame) + line.set_data(looper.lon[m], looper.lat[m]) + + +anim = Anim(match, intern=True, figsize=(8, 8), cmap="magma_r", nb_step=10, dpi=75) +# mappable to show drifter in red +line = anim.ax.plot([], [], "r", lw=4, zorder=100)[0] +anim.fig.suptitle("") +_ = VideoAnimation(anim.fig, update, frames=np.arange(*anim.period, 1), interval=125) diff --git a/examples/14_generic_tools/README.rst b/examples/14_generic_tools/README.rst new file mode 100644 index 00000000..295d55fe --- /dev/null +++ b/examples/14_generic_tools/README.rst @@ -0,0 +1,4 @@ +Polygon tools +============= + +Method to work with contour \ No newline at end of file diff --git a/examples/14_generic_tools/pet_fit_contour.py b/examples/14_generic_tools/pet_fit_contour.py new file mode 100644 index 00000000..2d3b6dc9 --- /dev/null +++ b/examples/14_generic_tools/pet_fit_contour.py @@ -0,0 +1,62 @@ +""" +Contour fit +=========== + +Two type of fit : + - Ellipse + - Circle + +In the two case we use a least square algorithm +""" + +from matplotlib import pyplot as plt +from numpy import cos, linspace, radians, sin + +from py_eddy_tracker import data +from py_eddy_tracker.generic import coordinates_to_local, local_to_coordinates +from py_eddy_tracker.observations.observation import EddiesObservations +from py_eddy_tracker.poly import fit_circle_, fit_ellipse + +# %% +# Load example identification file +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) + + +# %% +# Function to draw circle or ellipse from parameter +def build_circle(x0, y0, r): + angle = radians(linspace(0, 360, 50)) + x_norm, y_norm = cos(angle), sin(angle) + return local_to_coordinates(x_norm * r, y_norm * r, x0, y0) + + +def build_ellipse(x0, y0, a, b, theta): + angle = radians(linspace(0, 360, 50)) + x = a * cos(theta) * cos(angle) - b * sin(theta) * sin(angle) + y = a * sin(theta) * cos(angle) + b * cos(theta) * sin(angle) + return local_to_coordinates(x, y, x0, y0) + + +# %% +# Plot fitted circle or ellipse on stored contour +xs, ys = a.contour_lon_s, a.contour_lat_s + +fig = plt.figure(figsize=(15, 15)) + +j = 1 +for i in range(0, 800, 30): + x, y = xs[i], ys[i] + x0_, y0_ = x.mean(), y.mean() + x_, y_ = coordinates_to_local(x, y, x0_, y0_) + ax = fig.add_subplot(4, 4, j) + ax.grid(), ax.set_aspect("equal") + ax.plot(x, y, label="store", color="black") + x0, y0, a, b, theta = fit_ellipse(x_, y_) + x0, y0 = local_to_coordinates(x0, y0, x0_, y0_) + ax.plot(*build_ellipse(x0, 
y0, a, b, theta), label="ellipse", color="green") + x0, y0, radius, shape_error = fit_circle_(x_, y_) + x0, y0 = local_to_coordinates(x0, y0, x0_, y0_) + ax.plot(*build_circle(x0, y0, radius), label="circle", color="red", lw=0.5) + if j == 16: + break + j += 1 diff --git a/examples/14_generic_tools/pet_visvalingam.py b/examples/14_generic_tools/pet_visvalingam.py new file mode 100644 index 00000000..736e8852 --- /dev/null +++ b/examples/14_generic_tools/pet_visvalingam.py @@ -0,0 +1,96 @@ +""" +Visvalingam algorithm +===================== +""" +from matplotlib import pyplot as plt +import matplotlib.animation as animation +from numba import njit +from numpy import array, empty + +from py_eddy_tracker import data +from py_eddy_tracker.generic import uniform_resample +from py_eddy_tracker.observations.observation import EddiesObservations +from py_eddy_tracker.poly import vertice_overlap, visvalingam + + +@njit(cache=True) +def visvalingam_polys(x, y, nb_pt): + nb = x.shape[0] + x_new = empty((nb, nb_pt), dtype=x.dtype) + y_new = empty((nb, nb_pt), dtype=y.dtype) + for i in range(nb): + x_new[i], y_new[i] = visvalingam(x[i], y[i], nb_pt) + return x_new, y_new + + +@njit(cache=True) +def uniform_resample_polys(x, y, nb_pt): + nb = x.shape[0] + x_new = empty((nb, nb_pt), dtype=x.dtype) + y_new = empty((nb, nb_pt), dtype=y.dtype) + for i in range(nb): + x_new[i], y_new[i] = uniform_resample(x[i], y[i], fixed_size=nb_pt) + return x_new, y_new + + +def update_line(num): + nb = 50 - num - 20 + x_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb) + for i, (x_, y_) in enumerate(zip(x_v, y_v)): + lines_v[i].set_data(x_, y_) + x_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb) + for i, (x_, y_) in enumerate(zip(x_u, y_u)): + lines_u[i].set_data(x_, y_) + scores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0 + scores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0 + for i, (s_v, s_u) in enumerate(zip(scores_v, scores_u)): + texts[i].set_text(f"Score uniform {s_u:.1f} %\nScore visvalingam {s_v:.1f} %") + title.set_text(f"{nb} points by contour in place of 50") + return (title, *lines_u, *lines_v, *texts) + + +# %% +# Load detection files +a = EddiesObservations.load_file(data.get_demo_path("Anticyclonic_20190223.nc")) +a = a.extract_with_mask((abs(a.lat) < 66) * (abs(a.radius_e) > 80e3)) + +nb_pt = 10 +x_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb_pt) +x_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb_pt) +scores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0 +scores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0 +d_6 = scores_v - scores_u +nb_pt = 18 +x_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb_pt) +x_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb_pt) +scores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0 +scores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0 +d_12 = scores_v - scores_u +a = a.index(array((d_6.argmin(), d_6.argmax(), d_12.argmin(), d_12.argmax()))) + + +# %% +fig = plt.figure() +axs = [ + fig.add_subplot(221), + fig.add_subplot(222), + fig.add_subplot(223), + fig.add_subplot(224), +] +lines_u, lines_v, texts, score_text = list(), list(), list(), list() +for i, obs in enumerate(a): + axs[i].set_aspect("equal") + axs[i].grid() + axs[i].set_xticklabels([]), axs[i].set_yticklabels([]) + axs[i].plot( + obs["contour_lon_e"], 
obs["contour_lat_e"], "r", lw=6, label="Original contour" + ) + lines_v.append(axs[i].plot([], [], color="limegreen", lw=4, label="visvalingam")[0]) + lines_u.append( + axs[i].plot([], [], color="black", lw=2, label="uniform resampling")[0] + ) + texts.append(axs[i].set_title("", fontsize=8)) +axs[0].legend(fontsize=8) +title = fig.suptitle("") +anim = animation.FuncAnimation(fig, update_line, 27) +anim diff --git a/examples/16_network/README.rst b/examples/16_network/README.rst new file mode 100644 index 00000000..49bdc3ab --- /dev/null +++ b/examples/16_network/README.rst @@ -0,0 +1,6 @@ +Network +======= + +.. warning:: + + Network is under development, API could move quickly! diff --git a/examples/16_network/pet_atlas.py b/examples/16_network/pet_atlas.py new file mode 100644 index 00000000..48b374e2 --- /dev/null +++ b/examples/16_network/pet_atlas.py @@ -0,0 +1,187 @@ +""" +Network Analysis +================ +""" +from matplotlib import pyplot as plt +from numpy import ma + +from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.network import NetworkObservations + +n = NetworkObservations.load_file( + get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018_err70_filt500_order1/Anticyclonic_network.nc" + ) +) +# %% +# Parameters +step = 1 / 10.0 +bins = ((-10, 37, step), (30, 46, step)) +kw_time = dict(cmap="terrain_r", factor=100.0 / n.nb_days, name="count") +kw_ratio = dict(cmap=plt.get_cmap("YlGnBu_r", 10)) + + +# %% +# Functions +def start_axes(title): + fig = plt.figure(figsize=(13, 5)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES) + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) + ax.set_aspect("equal") + ax.set_title(title, weight="bold") + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) + + +# %% +# All +# --- +# Display the % of time each pixel (1/10°) is within an anticyclonic network +ax = start_axes("") +g_all = n.grid_count(bins) +m = g_all.display(ax, **kw_time, vmin=0, vmax=75) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Network longer than 10 days +# --------------------------- +# Display the % of time each pixel (1/10°) is within an anticyclonic network +# which total lifetime in longer than 10 days +ax = start_axes("") +n10 = n.longer_than(10) +g_10 = n10.grid_count(bins) +m = g_10.display(ax, **kw_time, vmin=0, vmax=75) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Ratio +# ^^^^^ +# Ratio between the longer and total presence +ax = start_axes("") +g_ = g_10.vars["count"] * 100.0 / g_all.vars["count"] +m = g_10.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# Blue = mostly short networks +# +# Network longer than 20 days +# --------------------------- +# Display the % of time each pixel (1/10°) is within an anticyclonic network +# which total lifetime is longer than 20 days +ax = start_axes("") +n20 = n.longer_than(20) +g_20 = n20.grid_count(bins) +m = g_20.display(ax, **kw_time, vmin=0, vmax=75) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Ratio +# ^^^^^ +# Ratio between the longer and total presence +ax = start_axes("") +g_ = g_20.vars["count"] * 100.0 / g_all.vars["count"] +m = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# Now we will hide 
pixel which are used less than 365 times +g_ = ma.array( + g_20.vars["count"] * 100.0 / g_all.vars["count"], mask=g_all.vars["count"] < 365 +) +ax = start_axes("") +m = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") +# %% +# Now we will hide pixel which are used more than 365 times +ax = start_axes("") +g_ = ma.array( + g_20.vars["count"] * 100.0 / g_all.vars["count"], mask=g_all.vars["count"] >= 365 +) +m = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# Coastal areas are mostly populated by short networks +# +# All merging +# ----------- +# Display the occurence of merging events +ax = start_axes("") +g_all_merging = n.merging_event().grid_count(bins) +m = g_all_merging.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Ratio merging events / eddy presence +ax = start_axes("") +g_ = g_all_merging.vars["count"] * 100.0 / g_all.vars["count"] +m = g_all_merging.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# Merging in networks longer than 10 days, with dead end remove (shorter than 10 observations) +# -------------------------------------------------------------------------------------------- +ax = start_axes("") +n10_ = n10.copy() +n10_.remove_dead_end(nobs=10) +merger = n10_.merging_event() +g_10_merging = merger.grid_count(bins) +m = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Merging in networks longer than 10 days +# --------------------------------------- +ax = start_axes("") +merger = n10.merging_event() +g_10_merging = merger.grid_count(bins) +m = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") +# %% +# Ratio merging events / eddy presence +ax = start_axes("") +g_ = ma.array( + g_10_merging.vars["count"] * 100.0 / g_10.vars["count"], + mask=g_10.vars["count"] < 365, +) +m = g_10_merging.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# All splitting +# ------------- +# Display the occurence of splitting events +ax = start_axes("") +g_all_splitting = n.splitting_event().grid_count(bins) +m = g_all_splitting.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") + +# %% +# Ratio splitting events / eddy presence +ax = start_axes("") +g_ = g_all_splitting.vars["count"] * 100.0 / g_all.vars["count"] +m = g_all_splitting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") + +# %% +# splitting in networks longer than 10 days +# ----------------------------------------- +ax = start_axes("") +g_10_splitting = n10.splitting_event().grid_count(bins) +m = g_10_splitting.display(ax, **kw_time, vmin=0, vmax=1) +update_axes(ax, m).set_label("Pixel used in % of time") +# %% +ax = start_axes("") +g_ = ma.array( + g_10_splitting.vars["count"] * 100.0 / g_10.vars["count"], + mask=g_10.vars["count"] < 365, +) +m = g_10_splitting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +update_axes(ax, m).set_label("Pixel used in % all atlas") diff --git a/examples/16_network/pet_follow_particle.py b/examples/16_network/pet_follow_particle.py new file mode 100644 index 00000000..6815fb6e --- /dev/null +++ 
b/examples/16_network/pet_follow_particle.py @@ -0,0 +1,186 @@ +""" +Follow particle +=============== + +""" +import re + +from matplotlib import colors, pyplot as plt +from matplotlib.animation import FuncAnimation +from numpy import arange, meshgrid, ones, unique, zeros + +from py_eddy_tracker import start_logger +from py_eddy_tracker.appli.gui import Anim +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection +from py_eddy_tracker.observations.groups import particle_candidate +from py_eddy_tracker.observations.network import NetworkObservations + +start_logger().setLevel("ERROR") + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651) +n = n.extract_with_mask((n.time >= 20180) * (n.time <= 20269)) +n.remove_dead_end(nobs=0, ndays=10) +n = n.remove_trash() +n.numbering_segment() +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + heigth="adt", +) + +# %% +# Schema +# ------ +fig = plt.figure(figsize=(12, 6)) +ax = fig.add_axes([0.05, 0.05, 0.9, 0.9]) +_ = n.display_timeline(ax, field="longitude", marker="+", lw=2, markersize=5) + +# %% +# Animation +# --------- +# Particle settings +t_snapshot = 20200 +step = 1 / 50.0 +x, y = meshgrid(arange(20, 36, step), arange(30, 46, step)) +N = 6 +x_f, y_f = x[::N, ::N].copy(), y[::N, ::N].copy() +x, y = x.reshape(-1), y.reshape(-1) +x_f, y_f = x_f.reshape(-1), y_f.reshape(-1) +n_ = n.extract_with_mask(n.time == t_snapshot) +index = n_.contains(x, y, intern=True) +m = index != -1 +index = n_.segment[index[m]] +index_ = unique(index) +x, y = x[m], y[m] +m = ~n_.inside(x_f, y_f, intern=True) +x_f, y_f = x_f[m], y_f[m] + +# %% +# Animation +cmap = colors.ListedColormap(list(n.COLORS), name="from_list", N=n.segment.max() + 1) +a = Anim( + n, + intern=False, + figsize=(12, 6), + nb_step=1, + dpi=60, + field_color="segment", + field_txt="segment", + cmap=cmap, +) +a.fig.suptitle(""), a.ax.set_xlim(24, 36), a.ax.set_ylim(30, 36) +a.txt.set_position((25, 31)) + +step = 0.25 +kw_p = dict( + nb_step=2, + time_step=86400 * step * 0.5, + t_init=t_snapshot - 2 * step, + u_name="u", + v_name="v", +) + +mappables = dict() +particules = c.advect(x, y, **kw_p) +filament = c.filament(x_f, y_f, **kw_p, filament_size=3) +kw = dict(ls="", marker=".", markersize=0.25) +for k in index_: + m = k == index + mappables[k] = a.ax.plot([], [], color=cmap(k), **kw)[0] +m_filament = a.ax.plot([], [], lw=0.25, color="gray")[0] + + +def update(frame): + tt, xt, yt = particules.__next__() + for k, mappable in mappables.items(): + m = index == k + mappable.set_data(xt[m], yt[m]) + tt, xt, yt = filament.__next__() + m_filament.set_data(xt, yt) + if frame % 1 == 0: + a.func_animation(frame) + + +ani = VideoAnimation(a.fig, update, frames=arange(20200, 20269, step), interval=200) + + +# %% +# Particle advection +# ^^^^^^^^^^^^^^^^^^ +# Advection from speed contour to 
speed contour (default) + +step = 1 / 60.0 + +t_start, t_end = int(n.period[0]), int(n.period[1]) +dt = 14 + +shape = (n.obs.size, 2) +# Forward run +i_target_f, pct_target_f = -ones(shape, dtype="i4"), zeros(shape, dtype="i1") +for t in arange(t_start, t_end - dt): + particle_candidate(c, n, step, t, i_target_f, pct_target_f, n_days=dt) + +# Backward run +i_target_b, pct_target_b = -ones(shape, dtype="i4"), zeros(shape, dtype="i1") +for t in arange(t_start + dt, t_end): + particle_candidate(c, n, step, t, i_target_b, pct_target_b, n_days=-dt) + +# %% +fig = plt.figure(figsize=(10, 10)) +ax_1st_b = fig.add_axes([0.05, 0.52, 0.45, 0.45]) +ax_2nd_b = fig.add_axes([0.05, 0.05, 0.45, 0.45]) +ax_1st_f = fig.add_axes([0.52, 0.52, 0.45, 0.45]) +ax_2nd_f = fig.add_axes([0.52, 0.05, 0.45, 0.45]) +ax_1st_b.set_title("Backward advection for each time step") +ax_1st_f.set_title("Forward advection for each time step") +ax_1st_b.set_ylabel("Color -> First target\nLatitude") +ax_2nd_b.set_ylabel("Color -> Secondary target\nLatitude") +ax_2nd_b.set_xlabel("Julian days"), ax_2nd_f.set_xlabel("Julian days") +ax_1st_f.set_yticks([]), ax_2nd_f.set_yticks([]) +ax_1st_f.set_xticks([]), ax_1st_b.set_xticks([]) + + +def color_alpha(target, pct, vmin=5, vmax=80): + color = cmap(n.segment[target]) + # We will hide under 5 % and from 80% to 100 % it will be 1 + alpha = (pct - vmin) / (vmax - vmin) + alpha[alpha < 0] = 0 + alpha[alpha > 1] = 1 + color[:, 3] = alpha + return color + + +kw = dict( + name=None, yfield="longitude", event=False, zorder=-100, s=(n.speed_area / 20e6) +) +n.scatter_timeline(ax_1st_b, c=color_alpha(i_target_b.T[0], pct_target_b.T[0]), **kw) +n.scatter_timeline(ax_2nd_b, c=color_alpha(i_target_b.T[1], pct_target_b.T[1]), **kw) +n.scatter_timeline(ax_1st_f, c=color_alpha(i_target_f.T[0], pct_target_f.T[0]), **kw) +n.scatter_timeline(ax_2nd_f, c=color_alpha(i_target_f.T[1], pct_target_f.T[1]), **kw) +for ax in (ax_1st_b, ax_2nd_b, ax_1st_f, ax_2nd_f): + n.display_timeline(ax, field="longitude", marker="+", lw=2, markersize=5) + ax.grid() diff --git a/examples/16_network/pet_group_anim.py b/examples/16_network/pet_group_anim.py new file mode 100644 index 00000000..f2d439ed --- /dev/null +++ b/examples/16_network/pet_group_anim.py @@ -0,0 +1,154 @@ +""" +Network group process +===================== +""" +from datetime import datetime + +# sphinx_gallery_thumbnail_number = 2 +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from matplotlib.colors import ListedColormap +from numba import njit +from numpy import arange, array, empty, ones + +from py_eddy_tracker import data +from py_eddy_tracker.generic import flatten_line_matrix +from py_eddy_tracker.observations.network import Network +from py_eddy_tracker.observations.observation import EddiesObservations + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +# %% +NETWORK_GROUPS = list() + + +@njit(cache=True) +def apply_replace(x, x0, x1): + nb = x.shape[0] + for i in range(nb): + if x[i] 
== x0: + x[i] = x1 + + +# %% +# Modified class to catch group process at each step in order to illustrate processing +class MyNetwork(Network): + def get_group_array(self, results, nb_obs): + """With a loop on all pair of index, we will label each obs with a group + number + """ + nb_obs = array(nb_obs, dtype="u4") + day_start = nb_obs.cumsum() - nb_obs + gr = empty(nb_obs.sum(), dtype="u4") + gr[:] = self.NOGROUP + + id_free = 1 + for i, j, ii, ij in results: + gr_i = gr[slice(day_start[i], day_start[i] + nb_obs[i])] + gr_j = gr[slice(day_start[j], day_start[j] + nb_obs[j])] + # obs with no groups + m = (gr_i[ii] == self.NOGROUP) * (gr_j[ij] == self.NOGROUP) + nb_new = m.sum() + gr_i[ii[m]] = gr_j[ij[m]] = arange(id_free, id_free + nb_new) + id_free += nb_new + # associate obs with no group with obs with group + m = (gr_i[ii] != self.NOGROUP) * (gr_j[ij] == self.NOGROUP) + gr_j[ij[m]] = gr_i[ii[m]] + m = (gr_i[ii] == self.NOGROUP) * (gr_j[ij] != self.NOGROUP) + gr_i[ii[m]] = gr_j[ij[m]] + # case where 2 obs have a different group + m = gr_i[ii] != gr_j[ij] + if m.any(): + # Merge of group, ref over etu + for i_, j_ in zip(ii[m], ij[m]): + g0, g1 = gr_i[i_], gr_j[j_] + apply_replace(gr, g0, g1) + NETWORK_GROUPS.append((i, j, gr.copy())) + return gr + + +# %% +# Movie period +t0 = (datetime(2005, 5, 1) - datetime(1950, 1, 1)).days +t1 = (datetime(2005, 6, 1) - datetime(1950, 1, 1)).days + +# %% +# Get data from period and area +e = EddiesObservations.load_file(data.get_demo_path("network_med.nc")) +e = e.extract_with_mask((e.time >= t0) * (e.time < t1)).extract_with_area( + dict(llcrnrlon=25, urcrnrlon=35, llcrnrlat=31, urcrnrlat=37.5) +) +# %% +# Reproduce individual daily identification(for demonstration) +EDDIES_BY_DAYS = list() +for i, b0, b1 in e.iter_on("time"): + EDDIES_BY_DAYS.append(e.index(i)) +# need for display +e = EddiesObservations.concatenate(EDDIES_BY_DAYS) + +# %% +# Run network building group to intercept every step +n = MyNetwork.from_eddiesobservations(EDDIES_BY_DAYS, window=7) +_ = n.group_observations(minimal_area=True) + + +# %% +def update(frame): + i_current, i_match, gr = NETWORK_GROUPS[frame] + current = EDDIES_BY_DAYS[i_current] + x = flatten_line_matrix(current.contour_lon_e) + y = flatten_line_matrix(current.contour_lat_e) + current_contour.set_data(x, y) + match = EDDIES_BY_DAYS[i_match] + x = flatten_line_matrix(match.contour_lon_e) + y = flatten_line_matrix(match.contour_lat_e) + matched_contour.set_data(x, y) + groups.set_array(gr) + txt.set_text(f"Day {i_current} match with day {i_match}") + s = 80 * ones(gr.shape) + s[gr == 0] = 4 + groups.set_sizes(s) + + +# %% +# Anim +# ---- +fig = plt.figure(figsize=(16, 9), dpi=50) +ax = fig.add_axes([0, 0, 1, 1]) +ax.set_aspect("equal"), ax.grid(), ax.set_xlim(26, 34), ax.set_ylim(31, 35.5) +cmap = ListedColormap(["gray", *e.COLORS[:-1]], name="from_list", N=30) +kw_s = dict(cmap=cmap, vmin=0, vmax=30) +groups = ax.scatter(e.lon, e.lat, c=NETWORK_GROUPS[0][2], **kw_s) +current_contour = ax.plot([], [], "k", lw=2, label="Current contour")[0] +matched_contour = ax.plot([], [], "r", lw=1, ls="--", label="Candidate contour")[0] +txt = ax.text(29, 35, "", fontsize=25) +ax.legend(fontsize=25) +ani = VideoAnimation(fig, update, frames=len(NETWORK_GROUPS), interval=220) + +# %% +# Final Result +# ------------ +fig = plt.figure(figsize=(16, 9)) +ax = fig.add_axes([0, 0, 1, 1]) +ax.set_aspect("equal"), ax.grid(), ax.set_xlim(26, 34), ax.set_ylim(31, 35.5) +_ = ax.scatter(e.lon, e.lat, c=NETWORK_GROUPS[-1][2], **kw_s) diff 
--git a/examples/16_network/pet_ioannou_2017_case.py b/examples/16_network/pet_ioannou_2017_case.py
new file mode 100644
index 00000000..56bec82e
--- /dev/null
+++ b/examples/16_network/pet_ioannou_2017_case.py
@@ -0,0 +1,237 @@
+"""
+Ioannou case
+============
+Figure 10 from https://doi.org/10.1002/2017JC013158
+
+We want to find the Ierapetra Eddy described above in a network demonstration run.
+"""
+
+from datetime import datetime, timedelta
+
+# %%
+import re
+
+from matplotlib import colors, pyplot as plt
+from matplotlib.animation import FuncAnimation
+from matplotlib.ticker import FuncFormatter
+from numpy import arange, array, pi, where
+
+from py_eddy_tracker.appli.gui import Anim
+from py_eddy_tracker.data import get_demo_path
+from py_eddy_tracker.generic import coordinates_to_local
+from py_eddy_tracker.gui import GUI_AXES
+from py_eddy_tracker.observations.network import NetworkObservations
+from py_eddy_tracker.poly import fit_ellipse
+
+# %%
+
+
+class VideoAnimation(FuncAnimation):
+    def _repr_html_(self, *args, **kwargs):
+        """To get video in html and have a player"""
+        content = self.to_html5_video()
+        return re.sub(
+            r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content
+        )
+
+    def save(self, *args, **kwargs):
+        if args[0].endswith("gif"):
+            # In this case the gif is only used to create a thumbnail, which is not used but takes as long as the video
+            # So we create an empty file to save time
+            with open(args[0], "w") as _:
+                pass
+            return
+        return super().save(*args, **kwargs)
+
+
+@FuncFormatter
+def formatter(x, pos):
+    return (timedelta(x) + datetime(1950, 1, 1)).strftime("%d/%m/%Y")
+
+
+def start_axes(title=""):
+    fig = plt.figure(figsize=(13, 6))
+    ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)
+    ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)
+    ax.set_aspect("equal")
+    ax.set_title(title, weight="bold")
+    return ax
+
+
+def timeline_axes(title=""):
+    fig = plt.figure(figsize=(15, 5))
+    ax = fig.add_axes([0.03, 0.06, 0.90, 0.88])
+    ax.set_title(title, weight="bold")
+    ax.xaxis.set_major_formatter(formatter), ax.grid()
+    return ax
+
+
+def update_axes(ax, mappable=None):
+    ax.grid(True)
+    if mappable:
+        return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))
+
+
+# %%
+# We know the network ID, so we can load it directly
+ioannou_case = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(
+    651
+)
+print(ioannou_case.infos())
+
+# %%
+# This network is huge! Our case is visible at 22°E, 33.5°N
+ax = start_axes()
+ioannou_case.plot(ax, color_cycle=ioannou_case.COLORS)
+update_axes(ax)
+
+# %%
+# Full Timeline
+# -------------
+# The network spans many years... How do we extract the interesting part?
+fig = plt.figure(figsize=(15, 5))
+ax = fig.add_axes([0.04, 0.05, 0.92, 0.92])
+ax.xaxis.set_major_formatter(formatter), ax.grid()
+_ = ioannou_case.display_timeline(ax)
+
+
+# %%
+# Sub network and new numbering
+# -----------------------------
+# Here we choose to keep only the segments within order 3 of our chosen eddy (see the standalone sketch just before the animation below)
+i = where(
+    (ioannou_case.lat > 33)
+    * (ioannou_case.lat < 34)
+    * (ioannou_case.lon > 22)
+    * (ioannou_case.lon < 23)
+    * (ioannou_case.time > 20630)
+    * (ioannou_case.time < 20650)
+)[0][0]
+close_to_i3 = ioannou_case.relative(i, order=3)
+close_to_i3.numbering_segment()
+
+# %%
+# Anim
+# ----
+# A quick movie to see the case better!
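+# Before the movie, a minimal standalone sketch, in plain Python, of what the
+# ``relative(i, order=3)`` call above does conceptually: starting from the segment
+# that owns observation ``i``, keep every segment reachable through at most
+# ``order`` splitting/merging links. This only illustrates the idea, not the
+# library implementation; ``links`` is a hypothetical adjacency mapping
+# segment -> set of connected segments.
+def segments_within_order(start_segment, links, order=3):
+    kept, frontier = {start_segment}, {start_segment}
+    for _ in range(order):
+        frontier = {n for s in frontier for n in links.get(s, ())} - kept
+        kept |= frontier
+    return kept
+
+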
+a = Anim( + close_to_i3, + figsize=(12, 4), + cmap=colors.ListedColormap( + list(close_to_i3.COLORS), name="from_list", N=close_to_i3.segment.max() + 1 + ), + nb_step=7, + dpi=70, + field_color="segment", + field_txt="segment", +) +a.ax.set_xlim(19, 30), a.ax.set_ylim(32, 35.25) +a.txt.set_position((21.5, 32.7)) +# We display in video only from the 100th day to the 500th +kwargs = dict(frames=arange(*a.period)[100:501], interval=100) +ani = VideoAnimation(a.fig, a.func_animation, **kwargs) + +# %% +# Classic display +# --------------- +ax = timeline_axes() +_ = close_to_i3.display_timeline(ax) + +# %% +ax = start_axes("") +n_copy = close_to_i3.copy() +n_copy.position_filter(2, 4) +n_copy.plot(ax, color_cycle=n_copy.COLORS) +update_axes(ax) + +# %% +# Latitude Timeline +# ----------------- +ax = timeline_axes(f"Close segments ({close_to_i3.infos()})") +n_copy = close_to_i3.copy() +n_copy.median_filter(15, "time", "latitude") +_ = n_copy.display_timeline(ax, field="lat", method="all") + +# %% +# Local radius timeline +# --------------------- +# Effective (bold) and Speed (thin) Radius together +n_copy.median_filter(2, "time", "radius_e") +n_copy.median_filter(2, "time", "radius_s") +for b0, b1 in [ + (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007) +]: + ref, delta = datetime(1950, 1, 1), 20 + b0_, b1_ = (b0 - ref).days, (b1 - ref).days + ax = timeline_axes() + ax.set_xlim(b0_ - delta, b1_ + delta) + ax.set_ylim(10, 115) + ax.axvline(b0_, color="k", lw=1.5, ls="--"), ax.axvline( + b1_, color="k", lw=1.5, ls="--" + ) + n_copy.display_timeline( + ax, field="radius_e", method="all", lw=4, markersize=8, factor=1e-3 + ) + n_copy.display_timeline( + ax, field="radius_s", method="all", lw=1, markersize=3, factor=1e-3 + ) + +# %% +# Parameters timeline +# ------------------- +# Effective Radius +kw = dict(s=35, cmap=plt.get_cmap("Spectral_r", 8), zorder=10) +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, "radius_e", factor=1e-3, vmin=20, vmax=100, **kw) +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective radius (km)") +# %% +# Shape error +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, "shape_error_e", vmin=14, vmax=70, **kw) +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective shape error") + +# %% +# Rotation angle +# -------------- +# For each obs, fit an ellipse to the contour, with theta the angle from the x-axis, +# a the semi ax in x direction and b the semi ax in y dimension + +theta_ = list() +a_ = list() +b_ = list() +for obs in close_to_i3: + x, y = obs["contour_lon_s"], obs["contour_lat_s"] + x0_, y0_ = x.mean(), y.mean() + x_, y_ = coordinates_to_local(x, y, x0_, y0_) + x0, y0, a, b, theta = fit_ellipse(x_, y_) + theta_.append(theta) + a_.append(a) + b_.append(b) +a_ = array(a_) +b_ = array(b_) + +# %% +# Theta +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, theta_, vmin=-pi / 2, vmax=pi / 2, cmap="hsv") +_ = update_axes(ax, m["scatter"]) + +# %% +# a +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, a_ * 1e-3, vmin=0, vmax=80, cmap="Spectral_r") +_ = update_axes(ax, m["scatter"]) + +# %% +# b +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, b_ * 1e-3, vmin=0, vmax=80, cmap="Spectral_r") +_ = update_axes(ax, m["scatter"]) + +# %% +# a/b +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, a_ / b_, vmin=1, vmax=2, cmap="Spectral_r") +_ = update_axes(ax, m["scatter"]) diff --git a/examples/16_network/pet_relative.py b/examples/16_network/pet_relative.py new file mode 100644 index 
00000000..dd97b538 --- /dev/null +++ b/examples/16_network/pet_relative.py @@ -0,0 +1,336 @@ +""" +Network basic manipulation +========================== +""" +from matplotlib import pyplot as plt +from numpy import where + +from py_eddy_tracker import data +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.network import NetworkObservations + +# %% +# Load data +# --------- +# Load data where observations are put in same network but no segmentation +n = NetworkObservations.load_file(data.get_demo_path("network_med.nc")).network(651) +i = where( + (n.lat > 33) + * (n.lat < 34) + * (n.lon > 22) + * (n.lon < 23) + * (n.time > 20630) + * (n.time < 20650) +)[0][0] +# For event use +n2 = n.relative(i, order=2) +n = n.relative(i, order=4) +n.numbering_segment() + +# %% +# Timeline +# -------- + +# %% +# Display timeline with events +# A segment generated by a splitting is marked with a star +# +# A segment merging in another is marked with an exagon +fig = plt.figure(figsize=(15, 6)) +ax = fig.add_axes([0.04, 0.04, 0.92, 0.92]) +_ = n.display_timeline(ax) + +# %% +# Display timeline without event +fig = plt.figure(figsize=(15, 6)) +ax = fig.add_axes([0.04, 0.04, 0.92, 0.92]) +_ = n.display_timeline(ax, event=False) + +# %% +# Timeline by mean latitude +# ------------------------- +# Display timeline with the mean latitude of the segments in yaxis +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.04, 0.92, 0.92]) +ax.set_ylabel("Latitude") +_ = n.display_timeline(ax, field="latitude") + +# %% +# Timeline by mean Effective Radius +# --------------------------------- +# The factor argument is applied on the chosen field +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.04, 0.92, 0.92]) +ax.set_ylabel("Effective Radius (km)") +_ = n.display_timeline(ax, field="radius_e", factor=1e-3) + +# %% +# Timeline by latitude +# -------------------- +# Use `method="all"` to display the consecutive values of the field +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.05, 0.92, 0.92]) +ax.set_ylabel("Latitude") +_ = n.display_timeline(ax, field="lat", method="all") + +# %% +# You can filter the data, here with a time window of 15 days +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.05, 0.92, 0.92]) +n_copy = n.copy() +n_copy.median_filter(15, "time", "latitude") +_ = n_copy.display_timeline(ax, field="lat", method="all") + +# %% +# Parameters timeline +# ------------------- +# Scatter is usefull to display the parameters' temporal evolution +# +# Effective Radius and Amplitude +kw = dict(s=25, cmap="Spectral_r", zorder=10) +fig = plt.figure(figsize=(15, 12)) +ax = fig.add_axes([0.04, 0.54, 0.90, 0.44]) +m = n.scatter_timeline(ax, "radius_e", factor=1e-3, vmin=50, vmax=150, **kw) +cb = plt.colorbar( + m["scatter"], cax=fig.add_axes([0.95, 0.54, 0.01, 0.44]), orientation="vertical" +) +cb.set_label("Effective radius (km)") + +ax = fig.add_axes([0.04, 0.04, 0.90, 0.44]) +m = n.scatter_timeline(ax, "amplitude", factor=100, vmin=0, vmax=15, **kw) +cb = plt.colorbar( + m["scatter"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.44]), orientation="vertical" +) +cb.set_label("Amplitude (cm)") + +# %% +# Speed +fig = plt.figure(figsize=(15, 6)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88]) +m = n.scatter_timeline(ax, "speed_average", factor=100, vmin=0, vmax=40, **kw) +cb = plt.colorbar( + m["scatter"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation="vertical" +) +cb.set_label("Maximum speed (cm/s)") + +# %% +# Speed Radius +fig = 
plt.figure(figsize=(15, 6))
+ax = fig.add_axes([0.04, 0.06, 0.90, 0.88])
+m = n.scatter_timeline(ax, "radius_s", factor=1e-3, vmin=20, vmax=100, **kw)
+cb = plt.colorbar(
+    m["scatter"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation="vertical"
+)
+cb.set_label("Speed radius (km)")
+
+# %%
+# Remove dead branches
+# --------------------
+# Remove all tiny segments with fewer than N observations that do not join two segments
+n_clean = n.copy()
+n_clean.remove_dead_end(nobs=5, ndays=10)
+n_clean = n_clean.remove_trash()
+fig = plt.figure(figsize=(15, 12))
+ax = fig.add_axes([0.04, 0.54, 0.90, 0.40])
+ax.set_title(f"Original network ({n.infos()})")
+n.display_timeline(ax)
+ax = fig.add_axes([0.04, 0.04, 0.90, 0.40])
+ax.set_title(f"Clean network ({n_clean.infos()})")
+_ = n_clean.display_timeline(ax)
+
+# %%
+# For the following figures we will use the cleaned network
+n = n_clean
+
+# %%
+# Change splitting-merging events
+# -------------------------------
+# Change events where segment A splits into B and then A merges into B, so that instead A splits into B and then B merges into A
+fig = plt.figure(figsize=(15, 12))
+ax = fig.add_axes([0.04, 0.54, 0.90, 0.40])
+ax.set_title(f"Clean network ({n.infos()})")
+n.display_timeline(ax)
+
+clean_modified = n.copy()
+# Apply only if it happens in less than 40 days
+clean_modified.correct_close_events(40)
+
+ax = fig.add_axes([0.04, 0.04, 0.90, 0.40])
+ax.set_title(f"Re-split network ({clean_modified.infos()})")
+_ = clean_modified.display_timeline(ax)
+
+# %%
+# Keep only observations where water could propagate from an observation
+# ----------------------------------------------------------------------
+i_observation = 600
+only_linked = n.find_link(i_observation)
+
+fig = plt.figure(figsize=(15, 12))
+ax1 = fig.add_axes([0.04, 0.54, 0.90, 0.40])
+ax2 = fig.add_axes([0.04, 0.04, 0.90, 0.40])
+
+kw = dict(marker="s", s=300, color="black", zorder=200, label="observation start")
+for ax, dataset in zip([ax1, ax2], [n, only_linked]):
+    dataset.display_timeline(ax, field="segment", lw=2, markersize=5, colors_mode="y")
+    ax.scatter(n.time[i_observation], n.segment[i_observation], **kw)
+    ax.legend()
+
+ax1.set_title(f"full example ({n.infos()})")
+ax2.set_title(f"only linked observations ({only_linked.infos()})")
+_ = ax2.set_xlim(ax1.get_xlim()), ax2.set_ylim(ax1.get_ylim())
+
+# %%
+# Keep close relatives
+# --------------------
+# When you want to investigate one particular observation and select only the closest segments
+
+# First choose an observation in the network
+i = 1100
+
+fig = plt.figure(figsize=(15, 6))
+ax = fig.add_axes([0.04, 0.06, 0.90, 0.88])
+n.display_timeline(ax)
+obs_args = n.time[i], n.segment[i]
+obs_kw = dict(color="black", markersize=30, marker=".")
+_ = ax.plot(*obs_args, **obs_kw)
+
+# %%
+# Colors show the relative order of the segments with regard to the chosen one
+fig = plt.figure(figsize=(15, 6))
+ax = fig.add_axes([0.04, 0.06, 0.90, 0.88])
+m = n.scatter_timeline(
+    ax, n.obs_relative_order(i), vmin=-1.5, vmax=6.5, cmap=plt.get_cmap("jet", 8), s=10
+)
+ax.plot(*obs_args, **obs_kw)
+cb = plt.colorbar(
+    m["scatter"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation="vertical"
+)
+cb.set_label("Relative order")
+# %%
+# Keep only the segments at order 1
+fig = plt.figure(figsize=(15, 5))
+ax = fig.add_axes([0.04, 0.06, 0.90, 0.88])
+close_to_i1 = n.relative(i, order=1)
+ax.set_title(f"Close segments ({close_to_i1.infos()})")
+_ = close_to_i1.display_timeline(ax)
+# %%
+# Keep the segments up to order 2
+fig = plt.figure(figsize=(15, 5))
+ax = 
fig.add_axes([0.04, 0.06, 0.90, 0.88]) +close_to_i2 = n.relative(i, order=2) +ax.set_title(f"Close segments ({close_to_i2.infos()})") +_ = close_to_i2.display_timeline(ax) +# %% +# You want to keep the segments until order 3 +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88]) +close_to_i3 = n.relative(i, order=3) +ax.set_title(f"Close segments ({close_to_i3.infos()})") +_ = close_to_i3.display_timeline(ax) + +# %% +# Keep relatives to an event +# -------------------------- +# When you want to investigate one particular event and select only the closest segments +# +# First choose a merging event in the network +after, before, stopped = n.merging_event(triplet=True, only_index=True) +i_event = 7 +# %% +# then see some order of relatives + +max_order = 1 +fig, axs = plt.subplots( + max_order + 2, 1, sharex=True, figsize=(15, 5 * (max_order + 2)) +) +# Original network +ax = axs[0] +ax.set_title("Full network", weight="bold") +n.display_timeline(axs[0], colors_mode="y") +ax.grid(), ax.legend() + +for k in range(0, max_order + 1): + ax = axs[k + 1] + ax.set_title(f"Relatives order={k}", weight="bold") + # Extract neighbours of event + sub_network = n.find_segments_relative(after[i_event], stopped[i_event], order=k) + sub_network.display_timeline(ax, colors_mode="y") + ax.legend(), ax.grid() + _ = ax.set_ylim(axs[0].get_ylim()) + +# %% +# Display track on map +# -------------------- + +# Get a simplified network +n = n2.copy() +n.remove_dead_end(nobs=50, recursive=1) +n = n.remove_trash() +n.numbering_segment() +# %% +# Only a map can be tricky to understand, with a timeline it's easier! +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES) +n.plot(ax, color_cycle=n.COLORS) +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +ax = fig.add_axes([0.08, 0.7, 0.7, 0.3]) +_ = n.display_timeline(ax) + + +# %% +# Get merging event +# ----------------- +# Display the position of the eddies after a merging +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) +n.plot(ax, color_cycle=n.COLORS) +m1, m0, m0_stop = n.merging_event(triplet=True) +m1.display(ax, color="violet", lw=2, label="Eddies after merging") +m0.display(ax, color="blueviolet", lw=2, label="Eddies before merging") +m0_stop.display(ax, color="black", lw=2, label="Eddies stopped by merging") +ax.plot(m1.lon, m1.lat, marker=".", color="purple", ls="") +ax.plot(m0.lon, m0.lat, marker=".", color="blueviolet", ls="") +ax.plot(m0_stop.lon, m0_stop.lat, marker=".", color="black", ls="") +ax.legend() +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +m1 + +# %% +# Get splitting event +# ------------------- +# Display the position of the eddies before a splitting +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) +n.plot(ax, color_cycle=n.COLORS) +s0, s1, s1_start = n.splitting_event(triplet=True) +s0.display(ax, color="violet", lw=2, label="Eddies before splitting") +s1.display(ax, color="blueviolet", lw=2, label="Eddies after splitting") +s1_start.display(ax, color="black", lw=2, label="Eddies starting by splitting") +ax.plot(s0.lon, s0.lat, marker=".", color="purple", ls="") +ax.plot(s1.lon, s1.lat, marker=".", color="blueviolet", ls="") +ax.plot(s1_start.lon, s1_start.lat, marker=".", color="black", ls="") +ax.legend() +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +s1 + +# %% +# Get birth event +# --------------- +# Display the starting position of 
non-splitted eddies +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) +birth = n.birth_event() +birth.display(ax) +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +birth + +# %% +# Get death event +# --------------- +# Display the last position of non-merged eddies +fig = plt.figure(figsize=(15, 8)) +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) +death = n.death_event() +death.display(ax) +ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() +death diff --git a/examples/16_network/pet_replay_segmentation.py b/examples/16_network/pet_replay_segmentation.py new file mode 100644 index 00000000..d909af7f --- /dev/null +++ b/examples/16_network/pet_replay_segmentation.py @@ -0,0 +1,176 @@ +""" +Replay segmentation +=================== +Case from figure 10 from https://doi.org/10.1002/2017JC013158 + +Again with the Ierapetra Eddy +""" +from datetime import datetime, timedelta + +from matplotlib import pyplot as plt +from matplotlib.ticker import FuncFormatter +from numpy import where + +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.network import NetworkObservations +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + + +@FuncFormatter +def formatter(x, pos): + return (timedelta(x) + datetime(1950, 1, 1)).strftime("%d/%m/%Y") + + +def start_axes(title=""): + fig = plt.figure(figsize=(13, 6)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES) + ax.set_xlim(19, 29), ax.set_ylim(31, 35.5) + ax.set_aspect("equal") + ax.set_title(title, weight="bold") + return ax + + +def timeline_axes(title=""): + fig = plt.figure(figsize=(15, 5)) + ax = fig.add_axes([0.04, 0.06, 0.89, 0.88]) + ax.set_title(title, weight="bold") + ax.xaxis.set_major_formatter(formatter), ax.grid() + return ax + + +def update_axes(ax, mappable=None): + ax.grid(True) + if mappable: + return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9])) + + +# %% +# Class for new_segmentation +# -------------------------- +# The oldest win +class MyTrackEddiesObservations(TrackEddiesObservations): + __slots__ = tuple() + + @classmethod + def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs): + """ + Method to overwrite behaviour in merging. 
+ + We will give the point to the older one instead of the maximum overlap ratio + """ + while i_next != -1: + # Flag + used[i_next] = True + # Assign id + ids["track"][i_next] = track_id + # Search next + i_next_ = cls.get_next_obs(i_next, ids, *args, **kwargs) + if i_next_ == -1: + break + ids["next_obs"][i_next] = i_next_ + # Target was previously used + if used[i_next_]: + i_next_ = -1 + else: + ids["previous_obs"][i_next_] = i_next + i_next = i_next_ + + +def get_obs(dataset): + "Function to isolate a specific obs" + return where( + (dataset.lat > 33) + * (dataset.lat < 34) + * (dataset.lon > 22) + * (dataset.lon < 23) + * (dataset.time > 20630) + * (dataset.time < 20650) + )[0][0] + + +# %% +# Get original network, we will isolate only relative at order *2* +n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651) +n_ = n.relative(get_obs(n), order=2) + +# %% +# Display the default segmentation +ax = start_axes(n_.infos()) +n_.plot(ax, color_cycle=n.COLORS) +update_axes(ax) +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.05, 0.92, 0.92]) +ax.xaxis.set_major_formatter(formatter), ax.grid() +_ = n_.display_timeline(ax) + +# %% +# Run a new segmentation +# ---------------------- +e = n.astype(MyTrackEddiesObservations) +e.obs.sort(order=("track", "time"), kind="stable") +split_matrix = e.split_network(intern=False, window=7) +n_ = NetworkObservations.from_split_network(e, split_matrix) +n_ = n_.relative(get_obs(n_), order=2) +n_.numbering_segment() + +# %% +# New segmentation +# ---------------- +# "The oldest wins" method produce a very long segment +ax = start_axes(n_.infos()) +n_.plot(ax, color_cycle=n_.COLORS) +update_axes(ax) +fig = plt.figure(figsize=(15, 5)) +ax = fig.add_axes([0.04, 0.05, 0.92, 0.92]) +ax.xaxis.set_major_formatter(formatter), ax.grid() +_ = n_.display_timeline(ax) + +# %% +# Parameters timeline +# ------------------- +kw = dict(s=35, cmap=plt.get_cmap("Spectral_r", 8), zorder=10) +ax = timeline_axes() +n_.median_filter(15, "time", "latitude") +m = n_.scatter_timeline(ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lat") +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective shape error") + +ax = timeline_axes() +n_.median_filter(15, "time", "latitude") +m = n_.scatter_timeline( + ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lat", method="all" +) +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective shape error") +ax.set_ylabel("Latitude") + +ax = timeline_axes() +n_.median_filter(15, "time", "latitude") +kw["s"] = (n_.radius_e * 1e-3) ** 2 / 30**2 * 20 +m = n_.scatter_timeline( + ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lon", method="all" +) +ax.set_ylabel("Longitude") +cb = update_axes(ax, m["scatter"]) +cb.set_label("Effective shape error") + +# %% +# Cost association plot +# --------------------- +n_copy = n_.copy() +n_copy.median_filter(2, "time", "next_cost") +for b0, b1 in [ + (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007, 2008) +]: + ref, delta = datetime(1950, 1, 1), 20 + b0_, b1_ = (b0 - ref).days, (b1 - ref).days + ax = timeline_axes() + ax.set_xlim(b0_ - delta, b1_ + delta) + ax.set_ylim(0, 1) + ax.axvline(b0_, color="k", lw=1.5, ls="--"), ax.axvline( + b1_, color="k", lw=1.5, ls="--" + ) + n_copy.display_timeline(ax, field="next_cost", method="all", lw=4, markersize=8) + + n_.display_timeline(ax, field="next_cost", method="all", lw=0.5, markersize=0) diff --git a/examples/16_network/pet_segmentation_anim.py 
b/examples/16_network/pet_segmentation_anim.py new file mode 100644 index 00000000..1fcb9ae1 --- /dev/null +++ b/examples/16_network/pet_segmentation_anim.py @@ -0,0 +1,125 @@ +""" +Network segmentation process +============================ +""" +# sphinx_gallery_thumbnail_number = 2 +import re + +from matplotlib import pyplot as plt +from matplotlib.animation import FuncAnimation +from matplotlib.colors import ListedColormap +from numpy import ones, where + +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.gui import GUI_AXES +from py_eddy_tracker.observations.network import NetworkObservations +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + + +# %% +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) + + +def get_obs(dataset): + "Function to isolate a specific obs" + return where( + (dataset.lat > 33) + * (dataset.lat < 34) + * (dataset.lon > 22) + * (dataset.lon < 23) + * (dataset.time > 20630) + * (dataset.time < 20650) + )[0][0] + + +# %% +# Hack to pick up each step of segmentation +TRACKS = list() +INDICES = list() + + +class MyTrack(TrackEddiesObservations): + @staticmethod + def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs): + TRACKS.append(ids["track"].copy()) + INDICES.append(i_current) + return TrackEddiesObservations.get_next_obs( + i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs + ) + + +# %% +# Load data +# --------- +# Load data where observations are put in same network but no segmentation + +# Get a known network for the demonstration +n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651) +# We keep only some segment +n = n.relative(get_obs(n), order=2) +print(len(n)) +# We convert and order object like segmentation was never happen on observations +e = n.astype(MyTrack) +e.obs.sort(order=("track", "time"), kind="stable") + +# %% +# Do segmentation +# --------------- +# Segmentation based on maximum overlap, temporal window for candidates = 5 days +matrix = e.split_network(intern=False, window=5) + + +# %% +# Anim +# ---- +def update(i_frame): + tr = TRACKS[i_frame] + mappable_tracks.set_array(tr) + s = 40 * ones(tr.shape) + s[tr == 0] = 4 + mappable_tracks.set_sizes(s) + + indices_frames = INDICES[i_frame] + mappable_CONTOUR.set_data( + e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames] + ) + mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)]) + return (mappable_tracks,) + + +fig = plt.figure(figsize=(16, 9), dpi=60) +ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES) +ax.set_title(f"{len(e)} observations to segment") +ax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid() +vmax = TRACKS[-1].max() +cmap = ListedColormap(["gray", *e.COLORS[:-1]], name="from_list", N=vmax) +mappable_tracks = ax.scatter( + e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20 +) +mappable_CONTOUR = ax.plot( + e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0] +)[0] +ani = 
VideoAnimation(fig, update, frames=range(1, len(TRACKS), 4), interval=125) + +# %% +# Final Result +# ------------ +fig = plt.figure(figsize=(16, 9)) +ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES) +ax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid() +_ = ax.scatter(e.lon, e.lat, c=TRACKS[-1], cmap=cmap, vmin=0, vmax=vmax, s=20) diff --git a/notebooks/README.md b/notebooks/README.md index 27bee160..fd8971aa 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -1,3 +1,3 @@ -# rm build/sphinx/ doc/python_module/ doc/gen_modules/ -rf +# rm build/sphinx/ doc/python_module/ doc/gen_modules/ doc/_autosummary/ -rf python setup.py build_sphinx rsync -vrltp doc/python_module notebooks/. --include '*/' --include '*.ipynb' --exclude '*' --prune-empty-dirs diff --git a/notebooks/python_module/01_general_things/pet_storage.ipynb b/notebooks/python_module/01_general_things/pet_storage.ipynb new file mode 100644 index 00000000..a56e4def --- /dev/null +++ b/notebooks/python_module/01_general_things/pet_storage.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# How data is stored\n\nGeneral information about eddies storage.\n\nAll files have the same structure, with more or less fields and possible different order.\n\nThere are 3 class of files:\n\n- **Eddies collections** : contain a list of eddies without link between them\n- **Track eddies collections** :\n manage eddies associated in trajectories, the ```track``` field allows to separate each trajectory\n- **Network eddies collections** :\n manage eddies associated in networks, the ```track``` and ```segment``` fields allow to separate observations\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom numpy import arange, outer\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.observation import EddiesObservations, Table\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Eddies can be stored in 2 formats with the same structure:\n\n- zarr (https://zarr.readthedocs.io/en/stable/), which allow efficiency in IO,...\n- NetCDF4 (https://unidata.github.io/netcdf4-python/), well-known format\n\nEach field are stored in column, each row corresponds at 1 observation,\narray field like contour/profile are 2D column.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Eddies files (zarr or netcdf) can be loaded with ```load_file``` method:\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_collections = EddiesObservations.load_file(get_demo_path(\"Cyclonic_20160515.nc\"))\neddies_collections.field_table()\n# offset and scale_factor are used only when data is stored in zarr or netCDF4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Field access\nTo access the total field, here ```amplitude```\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + 
"outputs": [], + "source": [ + "eddies_collections.amplitude\n\n# To access only a specific part of the field\neddies_collections.amplitude[4:15]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Data matrix is a numpy ndarray\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_collections.obs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_collections.obs.dtype" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Contour storage\nAll contours are stored on the same number of points, and are resampled if needed with an algorithm to be stored as objects\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Speed profile storage\nSpeed profile is an interpolation of speed mean along each contour.\nFor each contour included in eddy, we compute mean of speed along the contour,\nand after we interpolate speed mean array on a fixed size array.\n\nSeveral field are available to understand \"uavg_profile\" :\n 0. - num_contours : Number of contour in eddies, must be equal to amplitude divide by isoline step\n 1. - height_inner_contour : height of inner contour used\n 2. - height_max_speed_contour : height of max speed contour used\n 3. - height_external_contour : height of outter contour used\n\nLast value of \"uavg_profile\" is for inner contour and first value for outter contour.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Observations selection of \"uavg_profile\" with high number of contour(Eddy with high amplitude)\ne = eddies_collections.extract_with_mask(eddies_collections.num_contours > 15)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Raw display of profiles with more than 15 contours\nax = plt.subplot(111)\n_ = ax.plot(e.uavg_profile.T, lw=0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Profile from inner to outter\nax = plt.subplot(111)\nax.plot(e.uavg_profile[:, ::-1].T, lw=0.5)\n_ = ax.set_xlabel(\"From inner to outter contour\"), ax.set_ylabel(\"Speed (m/s)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# If we normalize indice of contour to set speed contour to 1 and inner contour to 0\nax = plt.subplot(111)\nh_in = e.height_inner_contour\nh_s = e.height_max_speed_contour\nh_e = e.height_external_contour\nr = (h_e - h_in) / (h_s - h_in)\nnb_pt = e.uavg_profile.shape[1]\n# Create an x array for each profile\nx = outer(arange(nb_pt) / nb_pt, r)\n\nax.plot(x, e.uavg_profile[:, ::-1].T, lw=0.5)\n_ = ax.set_xlabel(\"From inner to outter contour\"), ax.set_ylabel(\"Speed (m/s)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Trajectories\nTracks eddies collections add several fields :\n\n- **track** : Trajectory number\n- **observation_flag** : Flag indicating if the value is interpolated between two observations or not\n (0: observed eddy, 1: interpolated eddy)\"\n- **observation_number** : Eddy temporal index in a trajectory, days starting at the eddy first detection\n- **cost_association** : 
result of the cost function to associate the eddy with the next observation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_tracks = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\n# In this example some fields are removed (effective_contour_longitude,...) in order to save time for doc building\neddies_tracks.field_table()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Networks\nNetwork files use some specific fields :\n\n- track : ID of network (ID 0 correspond to lonely eddies)\n- segment : ID of a segment within a network (from 1 to N)\n- previous_obs : Index of the previous observation in the full dataset,\n if -1 there are no previous observation (the segment starts)\n- next_obs : Index of the next observation in the full dataset, if -1 there are no next observation (the segment ends)\n- previous_cost : Result of the cost function (1 is a good association, 0 is bad) with previous observation\n- next_cost : Result of the cost function (1 is a good association, 0 is bad) with next observation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "eddies_network = NetworkObservations.load_file(get_demo_path(\"network_med.nc\"))\neddies_network.field_table()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "sl = slice(70, 100)\nTable(\n eddies_network.network(651).obs[sl][\n [\n \"time\",\n \"track\",\n \"segment\",\n \"previous_obs\",\n \"previous_cost\",\n \"next_obs\",\n \"next_cost\",\n ]\n ]\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Networks are ordered by increasing network number (`track`), then increasing segment number, then increasing time\n\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb b/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb index aab4e0f6..2d924387 100644 --- a/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = EddiesObservations.load_file(data.get_path(\"Anticyclonic_20190223.nc\"))" + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))" ] }, { @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_aspect(\"equal\")\nax.set_xlim(10, 70)\nax.set_ylim(-50, -25)\na.display(ax, label=\"Anticyclonic contour\", color=\"r\", lw=1)\n\n# Replace contours by circles using center and radius (effective is dashed)\na.circle_contour()\na.display(ax, label=\"Anticyclonic circle\", color=\"g\", lw=1)\nax.legend(loc=\"upper right\")" + "fig = plt.figure(figsize=(15, 8))\nax = 
fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_aspect(\"equal\")\nax.set_xlim(10, 70)\nax.set_ylim(-50, -25)\na.display(ax, label=\"Anticyclonic contour\", color=\"r\", lw=1)\n\n# Replace contours by circles using center and radius (effective is dashed)\na.circle_contour()\na.display(ax, label=\"Anticyclonic circle\", color=\"g\", lw=1)\n_ = ax.legend(loc=\"upper right\")" ] } ], @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb b/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb index 31abe783..d59f9e15 100644 --- a/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = EddiesObservations.load_file(data.get_path(\"Anticyclonic_20190223.nc\"))\nc = EddiesObservations.load_file(data.get_path(\"Cyclonic_20190223.nc\"))" + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))\nc = EddiesObservations.load_file(data.get_demo_path(\"Cyclonic_20190223.nc\"))" ] }, { @@ -129,7 +129,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb b/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb index 72604a63..7469b034 100644 --- a/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb @@ -55,7 +55,7 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\n\nax = start_axes(\"ADT (m)\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15, cmap=\"RdBu_r\")\nupdate_axes(ax, m)" + "g = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\n\nax = start_axes(\"ADT (m)\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15, cmap=\"RdBu_r\")\nupdate_axes(ax, m)" ] }, { @@ -170,7 +170,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Creteria for rejecting a contour\n0. - Accepted (green)\n1. - Rejection for shape error (red)\n2. - Masked value within contour (blue)\n3. - Under or over the pixel limit bounds (black)\n4. - Amplitude criterion (yellow)\n\n" + "Criteria for rejecting a contour:\n 0. - Accepted (green)\n 1. - Rejection for shape error (red)\n 2. - Masked value within contour (blue)\n 3. - Under or over the pixel limit bounds (black)\n 4. 
- Amplitude criterion (yellow)\n\n" ] }, { @@ -235,7 +235,7 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"Detected Eddies\")\na.display(ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic ({nb_obs} eddies)\", ref=-10)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic ({nb_obs} eddies)\", ref=-10)\nax.legend()\nupdate_axes(ax)" + "ax = start_axes(\"Detected Eddies\")\na.display(\n ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic ({nb_obs} eddies)\", ref=-10\n)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic ({nb_obs} eddies)\", ref=-10)\nax.legend()\nupdate_axes(ax)" ] }, { @@ -253,7 +253,7 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"Speed Radius (km)\")\na.scatter(ax, \"radius_s\", vmin=10, vmax=50, s=80, ref=-10, cmap=\"magma_r\", factor=0.001)\nm = c.scatter(\n ax, \"radius_s\", vmin=10, vmax=50, s=80, ref=-10, cmap=\"magma_r\", factor=0.001\n)\nupdate_axes(ax, m)" + "ax = start_axes(\"Speed Radius (km)\")\nkwargs = dict(vmin=10, vmax=50, s=80, ref=-10, cmap=\"magma_r\", factor=0.001)\na.scatter(ax, \"radius_s\", **kwargs)\nm = c.scatter(ax, \"radius_s\", **kwargs)\nupdate_axes(ax, m)" ] }, { @@ -271,7 +271,7 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"Effective Radius (km)\")\nkwargs = dict(vmin=10, vmax=80, cmap=\"magma_r\", factor=0.001, lut=14, ref=-10)\na.filled(ax, \"effective_radius\", **kwargs)\nm = c.filled(\n ax, \"radius_e\", vmin=10, vmax=80, cmap=\"magma_r\", factor=0.001, lut=14, ref=-10\n)\nupdate_axes(ax, m)" + "ax = start_axes(\"Effective Radius (km)\")\nkwargs = dict(vmin=10, vmax=80, cmap=\"magma_r\", factor=0.001, lut=14, ref=-10)\na.filled(ax, \"effective_radius\", **kwargs)\nm = c.filled(ax, \"radius_e\", **kwargs)\nupdate_axes(ax, m)" ] } ], @@ -291,7 +291,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb new file mode 100644 index 00000000..6ac75cee --- /dev/null +++ b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb @@ -0,0 +1,169 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Eddy detection : Antartic Circumpolar Current\n\nThis script detect eddies on the ADT field, and compute u,v with the method add_uv (use it only if the Equator is avoided)\n\nTwo detections are provided : with a filtered ADT and without filtering\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib import style\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\n\npos_cb = [0.1, 0.52, 0.83, 0.015]\npos_cb2 = [0.1, 0.07, 0.4, 0.015]\n\n\ndef quad_axes(title):\n style.use(\"default\")\n fig = plt.figure(figsize=(13, 10))\n fig.suptitle(title, weight=\"bold\", fontsize=14)\n axes = list()\n\n ax_pos = dict(\n topleft=[0.1, 0.54, 0.4, 0.38],\n topright=[0.53, 0.54, 0.4, 0.38],\n botleft=[0.1, 0.09, 0.4, 0.38],\n botright=[0.53, 0.09, 0.4, 0.38],\n )\n\n for key, position in ax_pos.items():\n ax = fig.add_axes(position)\n ax.set_xlim(5, 
45), ax.set_ylim(-60, -37)\n ax.set_aspect(\"equal\"), ax.grid(True)\n axes.append(ax)\n if \"right\" in key:\n ax.set_yticklabels(\"\")\n return fig, axes\n\n\ndef set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight=\"semibold\"):\n for ax in fig.get_axes():\n ax.grid()\n ax.grid(which=\"major\", linestyle=\"-\", linewidth=\"0.5\", color=\"black\")\n if ax.get_ylabel() != \"\":\n ax.set_ylabel(ax.get_ylabel(), fontsize=labelsize, fontweight=labelweight)\n if ax.get_xlabel() != \"\":\n ax.set_xlabel(ax.get_xlabel(), fontsize=labelsize, fontweight=labelweight)\n if ax.get_title() != \"\":\n ax.set_title(ax.get_title(), fontsize=labelsize, fontweight=labelweight)\n ax.tick_params(labelsize=ticklabelsize)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load Input grid, ADT is used to detect eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "margin = 30\n\nkw_data = dict(\n filename=data.get_demo_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n x_name=\"longitude\",\n y_name=\"latitude\",\n # Manual area subset\n indexs=dict(\n latitude=slice(100 - margin, 220 + margin),\n longitude=slice(0, 230 + margin),\n ),\n)\ng_raw = RegularGridDataset(**kw_data)\ng_raw.add_uv(\"adt\")\ng = RegularGridDataset(**kw_data)\ng.copy(\"adt\", \"adt_low\")\ng.bessel_high_filter(\"adt\", 700)\ng.bessel_low_filter(\"adt_low\", 700)\ng.add_uv(\"adt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Identification\nRun the identification step with slices of 2 mm\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "date = datetime(2016, 5, 15)\nkw_ident = dict(\n date=date, step=0.002, shape_error=70, sampling=30, uname=\"u\", vname=\"v\"\n)\na, c = g.eddy_identification(\"adt\", **kw_ident)\na_, c_ = g_raw.eddy_identification(\"adt\", **kw_ident)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Figures\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw_adt = dict(vmin=-1.5, vmax=1.5, cmap=plt.get_cmap(\"RdBu_r\", 30))\nfig, axs = quad_axes(\"General properties field\")\ng_raw.display(axs[0], \"adt\", **kw_adt)\naxs[0].set_title(\"Total ADT (m)\")\nm = g.display(axs[1], \"adt_low\", **kw_adt)\naxs[1].set_title(\"ADT (m) large scale, cutoff at 700 km\")\nm2 = g.display(axs[2], \"adt\", cmap=plt.get_cmap(\"RdBu_r\", 20), vmin=-0.5, vmax=0.5)\naxs[2].set_title(\"ADT (m) high-pass filtered, a cutoff at 700 km\")\ncb = plt.colorbar(m, cax=axs[0].figure.add_axes(pos_cb), orientation=\"horizontal\")\ncb.set_label(\"ADT (m)\", labelpad=0)\ncb2 = plt.colorbar(m2, cax=axs[2].figure.add_axes(pos_cb2), orientation=\"horizontal\")\ncb2.set_label(\"ADT (m)\", labelpad=0)\nset_fancy_labels(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The large-scale North-South gradient is removed by the filtering step.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, axs = quad_axes(\"\")\naxs[0].set_title(\"Without filter\")\naxs[0].set_ylabel(\"Contours used in eddies\")\naxs[1].set_title(\"With filter\")\naxs[2].set_ylabel(\"Closed contours but not used\")\ng_raw.contours.display(axs[0], lw=0.5, 
only_used=True)\ng.contours.display(axs[1], lw=0.5, only_used=True)\ng_raw.contours.display(axs[2], lw=0.5, only_unused=True)\ng.contours.display(axs[3], lw=0.5, only_unused=True)\nset_fancy_labels(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Removing the large-scale North-South gradient reveals closed contours in the\nSouth-Western corner of the ewample region.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(ref=-10, linewidth=0.75)\nkw_a = dict(color=\"r\", label=\"Anticyclonic ({nb_obs} eddies)\")\nkw_c = dict(color=\"b\", label=\"Cyclonic ({nb_obs} eddies)\")\nkw_filled = dict(vmin=0, vmax=100, cmap=\"Spectral_r\", lut=20, intern=True, factor=100)\nfig, axs = quad_axes(\"Comparison between two detections\")\n# Match with intern/inner contour\ni_a, j_a, s_a = a_.match(a, intern=True, cmin=0.15)\ni_c, j_c, s_c = c_.match(c, intern=True, cmin=0.15)\n\na_.index(i_a).filled(axs[0], s_a, **kw_filled)\na.index(j_a).filled(axs[1], s_a, **kw_filled)\nc_.index(i_c).filled(axs[0], s_c, **kw_filled)\nm = c.index(j_c).filled(axs[1], s_c, **kw_filled)\n\ncb = plt.colorbar(m, cax=axs[0].figure.add_axes(pos_cb), orientation=\"horizontal\")\ncb.set_label(\"Similarity index (%)\", labelpad=-5)\na_.display(axs[0], **kw, **kw_a), c_.display(axs[0], **kw, **kw_c)\na.display(axs[1], **kw, **kw_a), c.display(axs[1], **kw, **kw_c)\n\naxs[0].set_title(\"Without filter\")\naxs[0].set_ylabel(\"Detection\")\naxs[1].set_title(\"With filter\")\naxs[2].set_ylabel(\"Contours' rejection criteria\")\n\ng_raw.contours.display(axs[2], lw=0.5, only_unused=True, display_criterion=True)\ng.contours.display(axs[3], lw=0.5, only_unused=True, display_criterion=True)\n\nfor ax in axs:\n ax.legend()\n\nset_fancy_labels(fig)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Very similar eddies have Similarity Indexes >= 40%\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Criteria for rejecting a contour :\n 0. Accepted (green)\n 1. Rejection for shape error (red)\n 2. Masked value within contour (blue)\n 3. Under or over the pixel limit bounds (black)\n 4. 
Amplitude criterion (yellow)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i_a, j_a = i_a[s_a >= 0.4], j_a[s_a >= 0.4]\ni_c, j_c = i_c[s_c >= 0.4], j_c[s_c >= 0.4]\nfig = plt.figure(figsize=(12, 12))\nfig.suptitle(f\"Scatter plot (A : {i_a.shape[0]}, C : {i_c.shape[0]} matches)\")\n\nfor i, (label, field, factor, stop) in enumerate(\n (\n (\"Speed radius (km)\", \"radius_s\", 0.001, 120),\n (\"Effective radius (km)\", \"radius_e\", 0.001, 120),\n (\"Amplitude (cm)\", \"amplitude\", 100, 25),\n (\"Speed max (cm/s)\", \"speed_average\", 100, 25),\n )\n):\n ax = fig.add_subplot(2, 2, i + 1, title=label)\n ax.set_xlabel(\"Without filter\")\n ax.set_ylabel(\"With filter\")\n\n ax.plot(\n a_[field][i_a] * factor,\n a[field][j_a] * factor,\n \"r.\",\n label=\"Anticyclonic\",\n )\n ax.plot(\n c_[field][i_c] * factor,\n c[field][j_c] * factor,\n \"b.\",\n label=\"Cyclonic\",\n )\n ax.set_aspect(\"equal\"), ax.grid()\n ax.plot((0, 1000), (0, 1000), \"g\")\n ax.set_xlim(0, stop), ax.set_ylim(0, stop)\n ax.legend()\n\nset_fancy_labels(fig)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb index 3e7567d8..49024327 100644 --- a/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb @@ -55,7 +55,7 @@ }, "outputs": [], "source": [ - "margin = 30\ng = RegularGridDataset(\n data.get_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n # Manual area subset\n indexs=dict(\n longitude=slice(1116 - margin, 1216 + margin),\n latitude=slice(476 - margin, 536 + margin),\n ),\n)\n\nax = start_axes(\"ADT (m)\")\nm = g.display(ax, \"adt\", vmin=-1, vmax=1, cmap=\"RdBu_r\")\n# Draw line on the gulf stream front\ngreat_current = Contours(g.x_c, g.y_c, g.grid(\"adt\"), levels=(0.35,), keep_unclose=True)\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax, m)" + "margin = 30\ng = RegularGridDataset(\n data.get_demo_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n # Manual area subset\n indexs=dict(\n longitude=slice(1116 - margin, 1216 + margin),\n latitude=slice(476 - margin, 536 + margin),\n ),\n)\n\nax = start_axes(\"ADT (m)\")\nm = g.display(ax, \"adt\", vmin=-1, vmax=1, cmap=\"RdBu_r\")\n# Draw line on the gulf stream front\ngreat_current = Contours(g.x_c, g.y_c, g.grid(\"adt\"), levels=(0.35,), keep_unclose=True)\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax, m)" ] }, { @@ -235,7 +235,7 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"Eddies detected\")\na.display(ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic ({nb_obs} eddies)\", ref=-10)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic ({nb_obs} eddies)\", ref=-10)\nax.legend()\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax)" + "ax = 
start_axes(\"Eddies detected\")\na.display(\n ax, color=\"r\", linewidth=0.75, label=\"Anticyclonic ({nb_obs} eddies)\", ref=-10\n)\nc.display(ax, color=\"b\", linewidth=0.75, label=\"Cyclonic ({nb_obs} eddies)\", ref=-10)\nax.legend()\ngreat_current.display(ax, color=\"k\")\nupdate_axes(ax)" ] }, { @@ -273,7 +273,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb b/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb index 06d65865..381aa8f6 100644 --- a/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb @@ -55,7 +55,7 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\ng.add_uv(\"adt\")\ng.copy(\"adt\", \"adt_high\")\nwavelength = 800\ng.bessel_high_filter(\"adt_high\", wavelength)\ndate = datetime(2016, 5, 15)" + "g = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\ng.add_uv(\"adt\")\ng.copy(\"adt\", \"adt_high\")\nwavelength = 800\ng.bessel_high_filter(\"adt_high\", wavelength)\ndate = datetime(2016, 5, 15)" ] }, { @@ -91,7 +91,7 @@ }, "outputs": [], "source": [ - "ax = start_axes(\"Eddies detected over ADT\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15)\nmerge_f.display(ax, lw=0.75, label=\"Eddies in the filtered grid ({nb_obs} eddies)\", ref=-10, color=\"k\")\nmerge_t.display(ax, lw=0.75, label=\"Eddies without filter ({nb_obs} eddies)\", ref=-10, color=\"r\")\nax.legend()\nupdate_axes(ax, m)" + "ax = start_axes(\"Eddies detected over ADT\")\nm = g.display(ax, \"adt\", vmin=-0.15, vmax=0.15)\nmerge_f.display(\n ax,\n lw=0.75,\n label=\"Eddies in the filtered grid ({nb_obs} eddies)\",\n ref=-10,\n color=\"k\",\n)\nmerge_t.display(\n ax, lw=0.75, label=\"Eddies without filter ({nb_obs} eddies)\", ref=-10, color=\"r\"\n)\nax.legend()\nupdate_axes(ax, m)" ] }, { @@ -176,7 +176,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb b/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb index 8207f8d1..0cfdc9a8 100644 --- a/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb @@ -55,7 +55,7 @@ }, "outputs": [], "source": [ - "a = EddiesObservations.load_file(data.get_path(\"Anticyclonic_20160515.nc\"))\nc = EddiesObservations.load_file(data.get_path(\"Cyclonic_20160515.nc\"))\n\naviso_map = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\naviso_map.add_uv(\"adt\")" + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20160515.nc\"))\nc = EddiesObservations.load_file(data.get_demo_path(\"Cyclonic_20160515.nc\"))\n\naviso_map = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\naviso_map.add_uv(\"adt\")" ] }, { @@ -111,7 +111,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - 
"version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb b/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb index 8ab39ebc..03eba8bf 100644 --- a/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = EddiesObservations.load_file(data.get_path(\"Anticyclonic_20190223.nc\"))\nareas = list()\n# For each contour area will be compute in local reference\nfor i in a:\n x, y = coordinates_to_local(\n i[\"contour_lon_s\"], i[\"contour_lat_s\"], i[\"lon\"], i[\"lat\"]\n )\n areas.append(poly_area(x, y))\nareas = array(areas)" + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))\nareas = list()\n# For each contour area will be compute in local reference\nfor i in a:\n x, y = coordinates_to_local(\n i[\"contour_lon_s\"], i[\"contour_lat_s\"], i[\"lon\"], i[\"lat\"]\n )\n areas.append(poly_area(x, y))\nareas = array(areas)" ] }, { @@ -107,7 +107,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb b/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb index f367d7e2..0ef03f6f 100644 --- a/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\nc = Contours(g.x_c, g.y_c, g.grid(\"adt\") * 100, arange(-50, 50, 0.2))\ncontours = dict()\nfor coll in c.iter():\n for current_contour in coll.get_paths():\n _, _, _, aerr = current_contour.fit_circle()\n i = int(aerr // 4) + 1\n if i not in contours:\n contours[i] = list()\n contours[i].append(current_contour)" + "g = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\nc = Contours(g.x_c, g.y_c, g.grid(\"adt\") * 100, arange(-50, 50, 0.2))\ncontours = dict()\nfor coll in c.iter():\n for current_contour in coll.get_paths():\n _, _, _, aerr = current_contour.fit_circle()\n i = int(aerr // 4) + 1\n if i not in contours:\n contours[i] = list()\n contours[i].append(current_contour)" ] }, { @@ -100,7 +100,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb b/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb index a7ff8efe..9b8b3951 100644 --- a/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb @@ -55,7 +55,7 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\ng.add_uv(\"adt\", \"ugos\", \"vgos\")\ng.add_uv(\"sla\", \"ugosa\", \"vgosa\")\nwavelength = 400\ng.copy(\"adt\", \"adt_raw\")\ng.copy(\"sla\", \"sla_raw\")\ng.bessel_high_filter(\"adt\", wavelength)\ng.bessel_high_filter(\"sla\", wavelength)\ndate = datetime(2016, 5, 15)" + "g = RegularGridDataset(\n 
data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)\ng.add_uv(\"adt\", \"ugos\", \"vgos\")\ng.add_uv(\"sla\", \"ugosa\", \"vgosa\")\nwavelength = 400\ng.copy(\"adt\", \"adt_raw\")\ng.copy(\"sla\", \"sla_raw\")\ng.bessel_high_filter(\"adt\", wavelength)\ng.bessel_high_filter(\"sla\", wavelength)\ndate = datetime(2016, 5, 15)" ] }, { @@ -66,7 +66,7 @@ }, "outputs": [], "source": [ - "kwargs_a_adt = dict(lw=0.5, label=\"Anticyclonic ADT ({nb_obs} eddies)\", ref=-10, color=\"k\")\nkwargs_c_adt = dict(lw=0.5, label=\"Cyclonic ADT ({nb_obs} eddies)\", ref=-10, color=\"r\")\nkwargs_a_sla = dict(lw=0.5, label=\"Anticyclonic SLA ({nb_obs} eddies)\", ref=-10, color=\"g\")\nkwargs_c_sla = dict(lw=0.5, label=\"Cyclonic SLA ({nb_obs} eddies)\", ref=-10, color=\"b\")" + "kwargs_a_adt = dict(\n lw=0.5, label=\"Anticyclonic ADT ({nb_obs} eddies)\", ref=-10, color=\"k\"\n)\nkwargs_c_adt = dict(lw=0.5, label=\"Cyclonic ADT ({nb_obs} eddies)\", ref=-10, color=\"r\")\nkwargs_a_sla = dict(\n lw=0.5, label=\"Anticyclonic SLA ({nb_obs} eddies)\", ref=-10, color=\"g\"\n)\nkwargs_c_sla = dict(lw=0.5, label=\"Cyclonic SLA ({nb_obs} eddies)\", ref=-10, color=\"b\")" ] }, { @@ -223,7 +223,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb b/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb new file mode 100644 index 00000000..7fa04435 --- /dev/null +++ b/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Stastics on identification files\n\nSome statistics on raw identification without any tracking\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import numpy as np\nfrom matplotlib import pyplot as plt\nfrom matplotlib.dates import date2num\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.observations.observation import EddiesObservations\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load demo sample and take only first year.\n\nReplace by a list of filename to apply on your own dataset.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:365]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Merge all 
identification dataset in one object\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "all_a = EddiesObservations.concatenate(\n [EddiesObservations.load_file(i) for i in file_objects]\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define the polygon bounds\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0, x1, y0, y1 = 15, 20, 33, 38\nxs = np.array([[x0, x1, x1, x0, x0]], dtype=\"f8\")\nys = np.array([[y0, y0, y1, y1, y0]], dtype=\"f8\")\n# The polygon object is created to be usable by the match function.\npolygon = dict(contour_lon_e=xs, contour_lat_e=ys, contour_lon_s=xs, contour_lat_s=ys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Geographic frequency of eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 0.125\nax = start_axes(\"\")\n# Count pixels used for each contour\ng_a = all_a.grid_count(bins=((-10, 37, step), (30, 46, step)), intern=True)\nm = g_a.display(\n ax, cmap=\"terrain_r\", vmin=0, vmax=0.75, factor=1 / all_a.nb_days, name=\"count\"\n)\nax.plot(polygon[\"contour_lon_e\"][0], polygon[\"contour_lat_e\"][0], \"r\")\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use the match function to count the number of eddies which intersect the polygon defined previously.\nThe `p1_area` option allows getting, in the c_e/c_s output, the percentage of the polygon area occupied by eddies.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i_e, j_e, c_e = all_a.match(polygon, p1_area=True, intern=False)\ni_s, j_s, c_s = all_a.match(polygon, p1_area=True, intern=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "dt = np.datetime64(\"1970-01-01\") - np.datetime64(\"1950-01-01\")\nkw_hist = dict(\n bins=date2num(np.arange(21900, 22300).astype(\"datetime64\") - dt), histtype=\"step\"\n)\n# translate julian days into datetime64\nt = all_a.time.astype(\"datetime64\") - dt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Count how many eddies are in the polygon\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(12, 6)).add_subplot(111)\nax.set_title(\"Different ways to count eddy presence in a polygon\")\nax.set_ylabel(\"Count\")\nm = all_a.mask_from_polygons(((xs, ys),))\nax.hist(t[m], label=\"center in polygon\", **kw_hist)\nax.hist(t[i_s[c_s > 0]], label=\"intersect speed contour with polygon\", **kw_hist)\nax.hist(t[i_e[c_e > 0]], label=\"intersect extern contour with polygon\", **kw_hist)\nax.legend()\nax.set_xlim(np.datetime64(\"2010\"), np.datetime64(\"2011\"))\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Percent of the area of interest occupied by eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(12, 6)).add_subplot(111)\nax.set_title(\"Percent of polygon occupied by an anticyclonic eddy\")\nax.set_ylabel(\"Percent of 
polygon\")\nax.hist(t[i_s[c_s > 0]], weights=c_s[c_s > 0] * 100.0, label=\"speed contour\", **kw_hist)\nax.hist(t[i_e[c_e > 0]], weights=c_e[c_e > 0] * 100.0, label=\"effective contour\", **kw_hist)\nax.legend(), ax.set_ylim(0, 25)\nax.set_xlim(np.datetime64(\"2010\"), np.datetime64(\"2011\"))\nax.grid()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb b/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb new file mode 100644 index 00000000..90ee1722 --- /dev/null +++ b/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb @@ -0,0 +1,270 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Grid advection\n\nDummy advection which use only static geostrophic current, which didn't solve the complex circulation of the ocean.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, isnan, meshgrid, ones\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load Input grid ADT\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "g = RegularGridDataset(\n get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\n# Compute u/v from height\ng.add_uv(\"adt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load detection files\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = EddiesObservations.load_file(get_demo_path(\"Anticyclonic_20160515.nc\"))\nc = EddiesObservations.load_file(get_demo_path(\"Cyclonic_20160515.nc\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Quiver from u/v with eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(10, 5))\nax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\nax.set_xlim(19, 30), ax.set_ylim(31, 36.5), ax.grid()\nx, y = meshgrid(g.x_c, g.y_c)\na.filled(ax, facecolors=\"r\", alpha=0.1), c.filled(ax, facecolors=\"b\", alpha=0.1)\n_ = ax.quiver(x.T, y.T, g.grid(\"u\"), g.grid(\"v\"), scale=20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video 
in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Anim\nParticles setup\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step_p = 1 / 8\nx, y = meshgrid(arange(13, 36, step_p), arange(28, 40, step_p))\nx, y = x.reshape(-1), y.reshape(-1)\n# Remove all original position that we can't advect at first place\nm = ~isnan(g.interp(\"u\", x, y))\nx0, y0 = x[m], y[m]\nx, y = x0.copy(), y0.copy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Movie properties\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kwargs = dict(frames=arange(51), interval=100)\nkw_p = dict(u_name=\"u\", v_name=\"v\", nb_step=2, time_step=21600)\nframe_t = kw_p[\"nb_step\"] * kw_p[\"time_step\"] / 86400.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Function\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def anim_ax(**kw):\n t = 0\n fig = plt.figure(figsize=(10, 5), dpi=55)\n axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\n axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid()\n a.filled(axes, facecolors=\"r\", alpha=0.1), c.filled(axes, facecolors=\"b\", alpha=0.1)\n line = axes.plot([], [], \"k\", **kw)[0]\n return fig, axes.text(21, 32.1, \"\"), line, t\n\n\ndef update(i_frame, t_step):\n global t\n x, y = p.__next__()\n t += t_step\n l.set_data(x, y)\n txt.set_text(f\"T0 + {t:.1f} days\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Filament forward\nDraw 3 last position in one path for each particles.,\nit could be run backward with `backward=True` option in filament method\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "p = g.filament(x, y, **kw_p, filament_size=3)\nfig, txt, l, t = anim_ax(lw=0.5)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Particle forward\nForward advection of particles\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "p = g.advect(x, y, **kw_p)\nfig, txt, l, t = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We get last position and run backward until original position\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "p = g.advect(x, y, **kw_p, backward=True)\nfig, txt, l, _ = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,))" + ] + }, + { + 
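For reuse outside the animation helpers above, here is a minimal, non-animated sketch of the same advection workflow. It relies only on calls already shown in these cells (RegularGridDataset, add_uv, interp, advect); the seeding box and the number of iterations are arbitrary illustrative choices, so treat it as an assumption-laden example rather than a reference implementation.

# Minimal static-advection sketch, assuming the demo ADT grid used in the cells above.
from numpy import arange, isnan, meshgrid

from py_eddy_tracker.data import get_demo_path
from py_eddy_tracker.dataset.grid import RegularGridDataset

g = RegularGridDataset(
    get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude"
)
g.add_uv("adt")  # derive geostrophic u/v from the ADT field

# Seed particles on a coarse regular mesh (arbitrary box) and drop the ones
# that fall outside the data domain, as done above with isnan on interpolated u.
x, y = meshgrid(arange(20, 25, 0.25), arange(32, 35, 0.25))
x, y = x.reshape(-1), y.reshape(-1)
keep = ~isnan(g.interp("u", x, y))
x, y = x[keep], y[keep]

# advect() is a generator: each iteration applies nb_step sub-steps of time_step seconds.
p = g.advect(x, y, u_name="u", v_name="v", nb_step=2, time_step=21600)
for _ in range(10):  # 10 iterations x 2 x 6 h = 5 days of forward advection
    x, y = next(p)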
"cell_type": "markdown", + "metadata": {}, + "source": [ + "## Particles stat\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Time_step settings\nDummy experiment to test advection precision, we run particles 50 days forward and backward with different time step\nand we measure distance between new positions and original positions.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\nax = fig.add_subplot(111)\nkw = dict(\n bins=arange(0, 50, 0.001),\n cumulative=True,\n weights=ones(x0.shape) / x0.shape[0] * 100.0,\n histtype=\"step\",\n)\nfor time_step in (10800, 21600, 43200, 86400):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(50 * 86400 / time_step), time_step=time_step, u_name=\"u\", v_name=\"v\")\n g.advect(x, y, **kw_advect).__next__()\n g.advect(x, y, **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"{86400. / time_step:.0f} time step by day\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\"Distance after 50 days forward and 50 days backward\")\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Time duration\nWe keep same time_step but change time duration\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\nax = fig.add_subplot(111)\ntime_step = 10800\nfor duration in (5, 50, 100):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(duration * 86400 / time_step), time_step=time_step, u_name=\"u\", v_name=\"v\")\n g.advect(x, y, **kw_advect).__next__()\n g.advect(x, y, **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"Time duration {duration} days\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\n \"Distance after N days forward and N days backward\\nwith a time step of 1/8 days\"\n)\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than \")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb b/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb index dcc9c0d3..2d6a7d3a 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)" + "g = RegularGridDataset(\n data.get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"),\n \"longitude\",\n \"latitude\",\n)" ] }, 
{ @@ -215,7 +215,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb b/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb index 9ef65cb5..f30076fa 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = EddiesObservations.load_file(data.get_path(\"Anticyclonic_20190223.nc\"))" + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))" ] }, { @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n)" + "g = RegularGridDataset(\n data.get_demo_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n)" ] }, { @@ -111,7 +111,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb b/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb new file mode 100644 index 00000000..cbe6de64 --- /dev/null +++ b/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb @@ -0,0 +1,209 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# LAVD experiment\n\nNaive method to reproduce LAVD(Lagrangian-Averaged Vorticity deviation) method with a static velocity field.\nIn the current example we didn't remove a mean vorticity.\n\nMethod are described here:\n\n - Abernathey, Ryan, and George Haller. \"Transport by Lagrangian Vortices in the Eastern Pacific\",\n Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021,\n https://doi.org/10.1175/JPO-D-17-0102.1\n - `Transport by Coherent Lagrangian Vortices`_,\n R. 
Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019,\n Talk a t the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop\n\n https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, meshgrid, zeros\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(16, 9), dpi=dpi)\n ax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\n ax.set_xlim(0, 32), ax.set_ylim(28, 46)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"Vorticity integration along trajectory at initial position\")\n return cb\n\n\nkw_vorticity = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data\nTo compute vorticity ($\\omega$) we compute u/v field with a stencil and apply the following equation with stencil\nmethod :\n\n\\begin{align}\\omega = \\frac{\\partial v}{\\partial x} - \\frac{\\partial u}{\\partial y}\\end{align}\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "g = RegularGridDataset(\n get_demo_path(\"dt_med_allsat_phy_l4_20160515_20190101.nc\"), \"longitude\", \"latitude\"\n)\ng.add_uv(\"adt\")\nu_y = g.compute_stencil(g.grid(\"u\"), vertical=True)\nv_x = g.compute_stencil(g.grid(\"v\"))\ng.vars[\"vort\"] = v_x - u_y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display vorticity field\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, ax, _ = start_ax()\nmappable = g.display(ax, abs(g.grid(\"vort\")), **kw_vorticity)\ncb = update_axes(ax, mappable)\ncb.set_label(\"Vorticity\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Particles\nParticles specification\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": 
[], + "source": [ + "step = 1 / 32\nx_g, y_g = arange(0, 36, step), arange(28, 46, step)\nx, y = meshgrid(x_g, y_g)\noriginal_shape = x.shape\nx, y = x.reshape(-1), y.reshape(-1)\nprint(f\"{len(x)} particles advected\")\n# A frame every 8h\nstep_by_day = 3\n# Compute step of advection every 4h\nnb_step = 2\nkw_p = dict(nb_step=nb_step, time_step=86400 / step_by_day / nb_step, u_name=\"u\", v_name=\"v\")\n# Start a generator which at each iteration return new position at next time step\nparticule = g.advect(x, y, **kw_p, rk4=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\n# Advection time\nnb_days = 8\n# Nb frame\nnb_time = step_by_day * nb_days\ni = 0.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Anim\nMovie of LAVD integration at each integration time step.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def update(i_frame):\n global lavd, i\n i += 1\n x, y = particule.__next__()\n # Interp vorticity on new_position\n lavd += abs(g.interp(\"vort\", x, y).reshape(original_shape) * 1 / nb_time)\n txt.set_text(f\"T0 + {i / step_by_day:.2f} days of advection\")\n pcolormesh.set_array(lavd / i * nb_time)\n return pcolormesh, txt\n\n\nkw_video = dict(frames=arange(nb_time), interval=1000.0 / step_by_day / 2, blit=True)\nfig, ax, txt = start_ax(dpi=60)\nx_g_, y_g_ = (\n arange(0 - step / 2, 36 + step / 2, step),\n arange(28 - step / 2, 46 + step / 2, step),\n)\n# pcolorfast will be faster than pcolormesh, we could use pcolorfast due to x and y are regular\npcolormesh = ax.pcolorfast(x_g_, y_g_, lavd, **kw_vorticity)\nupdate_axes(ax, pcolormesh)\n_ = VideoAnimation(ax.figure, update, **kw_video)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Final LAVD\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Format LAVD data\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"), datas=dict(lavd=lavd.T, lon=x_g, lat=y_g), centered=True\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display final LAVD with py eddy tracker detection.\nPeriod used for LAVD integration (8 days) is too short for a real use, but choose for example efficiency.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, ax, _ = start_ax()\nmappable = lavd.display(ax, \"lavd\", **kw_vorticity)\nEddiesObservations.load_file(get_demo_path(\"Anticyclonic_20160515.nc\")).display(\n ax, color=\"k\"\n)\nEddiesObservations.load_file(get_demo_path(\"Cyclonic_20160515.nc\")).display(\n ax, color=\"k\"\n)\n_ = update_axes(ax, mappable)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ 
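Stripped of the animation machinery, the LAVD accumulation above can be condensed into the sketch below. It reuses the calls from the preceding cells (add_uv, compute_stencil, advect with rk4, interp); the seeding step and the 8-day window are illustrative values taken from this example, not recommended settings.

# Minimal LAVD sketch: advect particles and average |vorticity| along their trajectories.
from numpy import arange, meshgrid, zeros

from py_eddy_tracker.data import get_demo_path
from py_eddy_tracker.dataset.grid import RegularGridDataset

g = RegularGridDataset(
    get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude"
)
g.add_uv("adt")
# vorticity = dv/dx - du/dy, from stencil derivatives as in the cells above
g.vars["vort"] = g.compute_stencil(g.grid("v")) - g.compute_stencil(g.grid("u"), vertical=True)

step = 1 / 8  # coarser seeding than the notebook, to keep the sketch cheap
x_g, y_g = arange(0, 36, step), arange(28, 46, step)
x, y = meshgrid(x_g, y_g)
shape = x.shape
x, y = x.reshape(-1), y.reshape(-1)

nb_time = 3 * 8  # 8 days with 3 snapshots per day, as in the animation above
p = g.advect(x, y, u_name="u", v_name="v", nb_step=2, time_step=86400 / 6, rk4=True)
lavd = zeros(shape)
for _ in range(nb_time):
    x, y = next(p)
    # accumulate |vorticity| interpolated at the new particle positions
    lavd += abs(g.interp("vort", x, y).reshape(shape)) / nb_time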
No newline at end of file diff --git a/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb b/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb index 73abbcda..ca4998ee 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb @@ -55,7 +55,7 @@ }, "outputs": [], "source": [ - "a = EddiesObservations.load_file(data.get_path(\"Anticyclonic_20190223.nc\"))\nc = EddiesObservations.load_file(data.get_path(\"Cyclonic_20190223.nc\"))" + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))\nc = EddiesObservations.load_file(data.get_demo_path(\"Cyclonic_20190223.nc\"))" ] }, { @@ -73,7 +73,7 @@ }, "outputs": [], "source": [ - "g = RegularGridDataset(\n data.get_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n)\n\nax = start_axes(\"ADT (cm)\")\nm = g.display(ax, \"adt\", vmin=-120, vmax=120, factor=100)\nupdate_axes(ax, m)" + "g = RegularGridDataset(\n data.get_demo_path(\"nrt_global_allsat_phy_l4_20190223_20190226.nc\"),\n \"longitude\",\n \"latitude\",\n)\n\nax = start_axes(\"ADT (cm)\")\nm = g.display(ax, \"adt\", vmin=-120, vmax=120, factor=100)\nupdate_axes(ax, m)" ] }, { @@ -201,7 +201,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb b/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb new file mode 100644 index 00000000..d4cdb187 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nTime advection\n==============\n\nExample which use CMEMS surface current with a Runge-Kutta 4 algorithm to advect particles.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# sphinx_gallery_thumbnail_number = 2\nimport re\nfrom datetime import datetime, timedelta\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, isnan, meshgrid, ones\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\nfrom py_eddy_tracker.gui import GUI_AXES\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Data\n----\nLoad Input time grid ADT\n\n" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n # To create U/V variable\n heigth=\"adt\",\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Anim\n----\nParticles setup\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step_p = 1 / 8\nx, y = meshgrid(arange(13, 36, step_p), arange(28, 40, step_p))\nx, y = x.reshape(-1), y.reshape(-1)\n# Remove all original position that we can't advect at first place\nt0 = 20181\nm = ~isnan(c[t0].interp(\"u\", x, y))\nx0, y0 = x[m], y[m]\nx, y = x0.copy(), y0.copy()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Function\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def anim_ax(**kw):\n fig = plt.figure(figsize=(10, 5), dpi=55)\n axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\n axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid()\n line = axes.plot([], [], \"k\", **kw)[0]\n return fig, axes.text(21, 32.1, \"\"), line\n\n\ndef update(_):\n tt, xt, yt = f.__next__()\n mappable.set_data(xt, yt)\n d = timedelta(tt / 86400.0) + datetime(1950, 1, 1)\n txt.set_text(f\"{d:%Y/%m/%d-%H}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "f = c.filament(x, y, \"u\", \"v\", t_init=t0, nb_step=2, time_step=21600, filament_size=3)\nfig, txt, mappable = anim_ax(lw=0.5)\nani = VideoAnimation(fig, update, frames=arange(160), interval=100)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Particules stat\n---------------\nTime_step settings\n^^^^^^^^^^^^^^^^^^\nDummy experiment to test advection precision, we run particles 50 days forward and backward with different time step\nand we measure distance between new positions and original positions.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\nax = fig.add_subplot(111)\nkw = dict(\n bins=arange(0, 50, 0.002),\n cumulative=True,\n weights=ones(x0.shape) / x0.shape[0] * 100.0,\n histtype=\"step\",\n)\nkw_p = dict(u_name=\"u\", v_name=\"v\", nb_step=1)\nfor time_step in (10800, 21600, 43200, 86400):\n x, y = x0.copy(), y0.copy()\n nb = int(30 * 86400 / time_step)\n # Go forward\n p = c.advect(x, y, time_step=time_step, t_init=20181.5, **kw_p)\n for i in range(nb):\n t_, _, _ = p.__next__()\n # Go backward\n p = c.advect(x, y, time_step=time_step, backward=True, t_init=t_ / 86400.0, **kw_p)\n for i in range(nb):\n t_, _, _ = p.__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"{86400. 
/ time_step:.0f} time step by day\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\"Distance after 50 days forward and 50 days backward\")\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Time duration\n^^^^^^^^^^^^^\nWe keep same time_step but change time duration\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\nax = fig.add_subplot(111)\ntime_step = 10800\nfor duration in (10, 40, 80):\n x, y = x0.copy(), y0.copy()\n nb = int(duration * 86400 / time_step)\n # Go forward\n p = c.advect(x, y, time_step=time_step, t_init=20181.5, **kw_p)\n for i in range(nb):\n t_, _, _ = p.__next__()\n # Go backward\n p = c.advect(x, y, time_step=time_step, backward=True, t_init=t_ / 86400.0, **kw_p)\n for i in range(nb):\n t_, _, _ = p.__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"Time duration {duration} days\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\n \"Distance after N days forward and N days backward\\nwith a time step of 1/8 days\"\n)\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than \")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb b/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb new file mode 100644 index 00000000..6f52e750 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb @@ -0,0 +1,180 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# FSLE experiment in med\n\nExample to build Finite Size Lyapunov Exponents, parameter values must be adapted for your case.\n\nExample use a method similar to `AVISO flse`_\n\n https://www.aviso.altimetry.fr/en/data/products/value-added-products/\n fsle-finite-size-lyapunov-exponents/fsle-description.html\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numba import njit\nfrom numpy import arange, arctan2, empty, isnan, log, ma, meshgrid, ones, pi, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ADT in med\n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_cube` method is\nmade for data stores in time cube, 
you could use also\n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` method to\nload data-cube from multiple file.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n # To create U/V variable\n heigth=\"adt\",\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Methods to compute FSLE\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "@njit(cache=True, fastmath=True)\ndef check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6):\n \"\"\"\n Check if distance between eastern or northern particle to center particle is bigger than `dist_max`\n \"\"\"\n nb_p = x.shape[0] // 3\n delta = dist_max**2\n for i in range(nb_p):\n i0 = i * 3\n i_n = i0 + 1\n i_e = i0 + 2\n # If particle already set, we skip\n if m[i0] or m[i_n] or m[i_e]:\n continue\n # Distance with north\n dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n]\n dn = dxn**2 + dyn**2\n # Distance with east\n dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e]\n de = dxe**2 + dye**2\n\n if dn >= delta or de >= delta:\n s1 = dn + de\n at1 = 2 * (dxe * dxn + dye * dyn)\n at2 = de - dn\n s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * (\n (dxn - dye) ** 2 + (dxe + dyn) ** 2\n )\n flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init**2) * (s1 + s2**0.5))\n theta[i] = arctan2(at1, at2 + s2) * 180 / pi\n # To know where value are set\n m_set[i] = False\n # To stop particle advection\n m[i0], m[i_n], m[i_e] = True, True, True\n\n\n@njit(cache=True)\ndef build_triplet(x, y, step=0.02):\n \"\"\"\n Triplet building for each position we add east and north point with defined step\n \"\"\"\n nb_x = x.shape[0]\n x_ = empty(nb_x * 3, dtype=x.dtype)\n y_ = empty(nb_x * 3, dtype=y.dtype)\n for i in range(nb_x):\n i0 = i * 3\n i_n, i_e = i0 + 1, i0 + 2\n x__, y__ = x[i], y[i]\n x_[i0], y_[i0] = x__, y__\n x_[i_n], y_[i_n] = x__, y__ + step\n x_[i_e], y_[i_e] = x__ + step, y__\n return x_, y_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Settings\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Step in degrees for ouput\nstep_grid_out = 1 / 25.0\n# Initial separation in degrees\ndist_init = 1 / 50.0\n# Final separation in degrees\ndist_max = 1 / 5.0\n# Time of start\nt0 = 20268\n# Number of time step by days\ntime_step_by_days = 5\n# Maximal time of advection\n# Here we limit because our data cube cover only 3 month\nnb_days = 85\n# Backward or forward\nbackward = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Particles\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0_, y0_ = -5, 30\nlon_p = arange(x0_, x0_ + 43, step_grid_out)\nlat_p = arange(y0_, y0_ + 16, step_grid_out)\ny0, x0 = meshgrid(lat_p, lon_p)\ngrid_shape = x0.shape\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\n# Identify all particle not on land\nm = ~isnan(c[t0].interp(\"adt\", x0, y0))\nx0, y0 = x0[m], y0[m]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## FSLE\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + 
"collapsed": false + }, + "outputs": [], + "source": [ + "# Array to compute fsle\nfsle = zeros(x0.shape[0], dtype=\"f4\")\ntheta = zeros(x0.shape[0], dtype=\"f4\")\nmask = ones(x0.shape[0], dtype=\"f4\")\nx, y = build_triplet(x0, y0, dist_init)\nused = zeros(x.shape[0], dtype=\"bool\")\n\n# advection generator\nkw = dict(t_init=t0, nb_step=1, backward=backward, mask_particule=used, u_name=\"u\", v_name=\"v\")\np = c.advect(x, y, time_step=86400 / time_step_by_days, **kw)\n\n# We check at each step of advection if particle distance is over `dist_max`\nfor i in range(time_step_by_days * nb_days):\n t, xt, yt = p.__next__()\n dt = t / 86400.0 - t0\n check_p(xt, yt, fsle, theta, mask, used, dt, dist_max=dist_max, dist_init=dist_init)\n\n# Get index with original_position\ni = ((x0 - x0_) / step_grid_out).astype(\"i4\")\nj = ((y0 - y0_) / step_grid_out).astype(\"i4\")\nfsle_ = empty(grid_shape, dtype=\"f4\")\ntheta_ = empty(grid_shape, dtype=\"f4\")\nmask_ = ones(grid_shape, dtype=\"bool\")\nfsle_[i, j] = fsle\ntheta_[i, j] = theta\nmask_[i, j] = mask\n# Create a grid object\nfsle_custom = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"),\n datas=dict(\n fsle=ma.array(fsle_, mask=mask_),\n theta=ma.array(theta_, mask=mask_),\n lon=lon_p,\n lat=lat_p,\n ),\n centered=True,\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display FSLE\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(13, 5), dpi=150)\nax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\nax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\nax.set_aspect(\"equal\")\nax.set_title(\"Finite size lyapunov exponent\", weight=\"bold\")\nkw = dict(cmap=\"viridis_r\", vmin=-20, vmax=0)\nm = fsle_custom.display(ax, 1 / fsle_custom.grid(\"fsle\"), **kw)\nax.grid()\n_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display Theta\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(13, 5), dpi=150)\nax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\nax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\nax.set_aspect(\"equal\")\nax.set_title(\"Theta from finite size lyapunov exponent\", weight=\"bold\")\nkw = dict(cmap=\"Spectral_r\", vmin=-180, vmax=180)\nm = fsle_custom.display(ax, fsle_custom.grid(\"theta\"), **kw)\nax.grid()\n_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb b/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb new file mode 100644 index 00000000..708d7024 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# LAVD detection and geometric detection\n\nNaive method to reproduce LAVD(Lagrangian-Averaged Vorticity deviation).\nIn the current example we didn't remove a mean vorticity.\n\nMethod are described here:\n\n - Abernathey, Ryan, and George Haller. \"Transport by Lagrangian Vortices in the Eastern Pacific\",\n Journal of Physical Oceanography 48, 3 (2018): 667-685, accessed Feb 16, 2021,\n https://doi.org/10.1175/JPO-D-17-0102.1\n - `Transport by Coherent Lagrangian Vortices`_,\n R. Abernathey, Sinha A., Tarshish N., Liu T., Zhang C., Haller G., 2019,\n Talk a t the Sources and Sinks of Ocean Mesoscale Eddy Energy CLIVAR Workshop\n\n https://usclivar.org/sites/default/files/meetings/2019/presentations/Aberernathey_CLIVAR.pdf\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom numpy import arange, isnan, ma, meshgrid, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset\nfrom py_eddy_tracker.gui import GUI_AXES\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class LAVDGrid(RegularGridDataset):\n def init_speed_coef(self, uname=\"u\", vname=\"v\"):\n \"\"\"Hack to be able to identify eddy with LAVD field\"\"\"\n self._speed_ev = self.grid(\"lavd\")\n\n @classmethod\n def from_(cls, x, y, z):\n z.mask += isnan(z.data)\n datas = dict(lavd=z, lon=x, lat=y)\n return cls.with_array(coordinates=(\"lon\", \"lat\"), datas=datas, centered=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(12, 5), dpi=dpi)\n ax = fig.add_axes([0.05, 0.08, 0.9, 0.9], projection=GUI_AXES)\n ax.set_xlim(-6, 36), ax.set_ylim(31, 45)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"LAVD at initial position\")\n return cb\n\n\nkw_lavd = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Load data cube of 3 month\nc = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n heigth=\"adt\",\n)\n\n# Add vorticity at each time step\nfor g in c:\n u_y = g.compute_stencil(g.grid(\"u\"), vertical=True)\n v_x = g.compute_stencil(g.grid(\"v\"))\n g.vars[\"vort\"] = v_x - u_y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Particles\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Time properties, for example with advection only 25 days\nnb_days, step_by_day = 25, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day, 
u_name=\"u\", v_name=\"v\")\nt0 = 20236\nt0_grid = c[t0]\n# Geographic properties, we use a coarser resolution for time consuming reasons\nstep = 1 / 32.0\nx_g, y_g = arange(-6, 36, step), arange(30, 46, step)\nx0, y0 = meshgrid(x_g, y_g)\noriginal_shape = x0.shape\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\n# Get all particles in defined area\nm = ~isnan(t0_grid.interp(\"vort\", x0, y0))\nx0, y0 = x0[m], y0[m]\nprint(f\"{x0.size} particles advected\")\n# Gridded mask\nm = m.reshape(original_shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD forward (dynamic field)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), t_init=t0, **kw_p)\nfor _ in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection\")\nmappable = lavd_forward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD backward (dynamic field)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), t_init=t0, backward=True, **kw_p)\nfor i in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection\")\nmappable = lavd_backward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD forward (static field)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), **kw_p)\nfor _ in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection on a static velocity field\")\nmappable = lavd_forward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LAVD backward (static field)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), backward=True, **kw_p)\nfor i in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection on a static velocity 
field\")\nmappable = lavd_backward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Contour detection\nTo extract contour from LAVD grid, we will used method design for SSH, with some hacks and adapted options.\nIt will produce false amplitude and speed.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw_ident = dict(\n force_speed_unit=\"m/s\",\n force_height_unit=\"m\",\n pixel_limit=(40, 200000),\n date=datetime(2005, 5, 18),\n uname=None,\n vname=None,\n grid_height=\"lavd\",\n shape_error=70,\n step=1e-6,\n)\nfig, ax, _ = start_ax(\"Detection of eddies with several method\")\nt0_grid.bessel_high_filter(\"adt\", 700)\na, c = t0_grid.eddy_identification(\n \"adt\", \"u\", \"v\", kw_ident[\"date\"], step=0.002, shape_error=70\n)\nkw_ed = dict(ax=ax, intern=True, ref=-10)\na.filled(\n facecolors=\"#FFEFCD\", label=\"Anticyclonic SSH detection {nb_obs} eddies\", **kw_ed\n)\nc.filled(facecolors=\"#DEDEDE\", label=\"Cyclonic SSH detection {nb_obs} eddies\", **kw_ed)\nkw_cont = dict(ax=ax, extern_only=True, ls=\"-\", ref=-10)\nforward, _ = lavd_forward.eddy_identification(**kw_ident)\nforward.display(label=\"LAVD forward {nb_obs} eddies\", color=\"g\", **kw_cont)\nbackward, _ = lavd_backward.eddy_identification(**kw_ident)\nbackward.display(label=\"LAVD backward {nb_obs} eddies\", color=\"r\", **kw_cont)\nforward, _ = lavd_forward_static.eddy_identification(**kw_ident)\nforward.display(label=\"LAVD forward static {nb_obs} eddies\", color=\"cyan\", **kw_cont)\nbackward, _ = lavd_backward_static.eddy_identification(**kw_ident)\nbackward.display(\n label=\"LAVD backward static {nb_obs} eddies\", color=\"orange\", **kw_cont\n)\nax.legend()\nupdate_axes(ax)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb b/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb new file mode 100644 index 00000000..b92c4d21 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb @@ -0,0 +1,126 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Build path of particle drifting\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import arange, meshgrid\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load data cube\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ 
+ "c = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n unset=True\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Advection properties\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "nb_days, step_by_day = 10, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day)\nt0 = 20210" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get paths\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0, y0 = meshgrid(arange(32, 35, 0.5), arange(32.5, 34.5, 0.5))\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\nt, x, y = c.path(x0, y0, h_name=\"adt\", t_init=t0, **kw_p, nb_time=nb_time)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot paths\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(9, 6)).add_subplot(111, aspect=\"equal\")\nax.plot(x0, y0, \"k.\", ms=20)\nax.plot(x, y, lw=3)\nax.set_title(\"10 days particle paths\")\nax.set_xlim(31, 35), ax.set_ylim(32, 34.5)\nax.grid()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb index 58e4ff06..6e43e9a4 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "c = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nc = c.extract_with_length((180, -1))" + "c = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nc = c.extract_with_length((180, -1))" ] }, { @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb index edb20deb..c98e53f0 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nprint(a)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n 
)\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nprint(a)" ] }, { @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb new file mode 100644 index 00000000..0681c0fc --- /dev/null +++ b/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Correspondances\n\nCorrespondances is a mechanism to intend to continue tracking with new detection\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import logging" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom netCDF4 import Dataset\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.featured_tracking.area_tracker import AreaTracker\n\n# In order to hide some warning\nimport py_eddy_tracker.observations.observation\nfrom py_eddy_tracker.tracking import Correspondances\n\npy_eddy_tracker.observations.observation._display_check_warning = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def plot_eddy(ed):\n fig = plt.figure(figsize=(10, 5))\n ax = fig.add_axes([0.05, 0.03, 0.90, 0.94])\n ed.plot(ax, ref=-10, marker=\"x\")\n lc = ed.display_color(ax, field=ed.time, ref=-10, intern=True)\n plt.colorbar(lc).set_label(\"Time in Julian days (from 1950/01/01)\")\n ax.set_xlim(4.5, 8), ax.set_ylim(36.8, 38.3)\n ax.set_aspect(\"equal\")\n ax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get remote data, we will keep only 20 first days,\n`get_remote_demo_sample` function is only to get demo dataset, in your own case give a list of identification filename\nand don't mix cyclonic and anticyclonic files.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:20]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We run a traking with a tracker which use contour overlap, on 10 first time step\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c_first_run = Correspondances(\n datasets=file_objects[:10], class_method=AreaTracker, virtual=4\n)\nstart_logger().setLevel(\"INFO\")\nc_first_run.track()\nstart_logger().setLevel(\"WARNING\")\nwith Dataset(\"correspondances.nc\", \"w\") as h:\n c_first_run.to_netcdf(h)\n# Next step are done only to build atlas and display it\nc_first_run.prepare_merging()\n\n# We have now an eddy object\neddies_area_tracker = c_first_run.merge(raw_data=False)\neddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0\neddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1)" + ] 
+ }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot from first ten days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "plot_eddy(eddies_area_tracker)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Restart from previous run\nWe give all filenames, the new one and filename from previous run\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c_second_run = Correspondances(\n datasets=file_objects[:20],\n # This parameter must be identical in each run\n class_method=AreaTracker,\n virtual=4,\n # Previous saved correspondancs\n previous_correspondance=\"correspondances.nc\",\n)\nstart_logger().setLevel(\"INFO\")\nc_second_run.track()\nstart_logger().setLevel(\"WARNING\")\nc_second_run.prepare_merging()\n# We have now another eddy object\neddies_area_tracker_extend = c_second_run.merge(raw_data=False)\neddies_area_tracker_extend.virtual[:] = eddies_area_tracker_extend.time == 0\neddies_area_tracker_extend.filled_by_interpolation(\n eddies_area_tracker_extend.virtual == 1\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot with time extension\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "plot_eddy(eddies_area_tracker_extend)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb index 13d5a9a3..95595a7a 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\neddy = a.extract_ids([9672])\neddy_f = a.extract_ids([9672])\neddy_f.position_filter(median_half_window=1, loess_half_window=5)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\neddy = a.extract_ids([9672])\neddy_f = a.extract_ids([9672])\neddy_f.position_filter(median_half_window=1, loess_half_window=5)" ] }, { @@ -93,7 +93,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb index 3b5cc8e5..d0a2e5b0 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb @@ -26,14 +26,14 @@ }, "outputs": [], "source": [ - "from py_eddy_tracker.data import get_remote_sample\nfrom 
py_eddy_tracker.featured_tracking.area_tracker import AreaTracker\nfrom py_eddy_tracker.gui import GUI\nfrom py_eddy_tracker.tracking import Correspondances" + "from py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.featured_tracking.area_tracker import AreaTracker\nfrom py_eddy_tracker.gui import GUI\nfrom py_eddy_tracker.tracking import Correspondances" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Get remote data, we will keep only 180 first days,\n`get_remote_sample` function is only to get demo dataset, in your own case give a list of identification filename\nand don't mix cyclonic and anticyclonic files.\n\n" + "Get remote data, we will keep only 180 first days,\n`get_remote_demo_sample` function is only to get demo dataset, in your own case give a list of identification filename\nand don't mix cyclonic and anticyclonic files.\n\n" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "file_objects = get_remote_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:180]" + "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:180]" ] }, { @@ -154,7 +154,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb index 36b2adc1..8e64b680 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "c = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nc.position_filter(median_half_window=1, loess_half_window=5)" + "c = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nc.position_filter(median_half_window=1, loess_half_window=5)" ] }, { @@ -80,7 +80,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 5))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_xlim(-1, 9)\nax.set_ylim(36, 40)\nax.set_aspect(\"equal\")\nax.grid()\nc.plot(ax, color=\"gray\", lw=0.1, ref=-10, label=\"All tracks ({nb_tracks} tracks)\")\nc_subset.plot(ax, color=\"red\", lw=0.2, ref=-10, label=\"selected tracks ({nb_tracks} tracks)\")\nax.plot(\n (x0, x0, x1, x1, x0),\n (y0, y1, y1, y0, y0),\n color=\"green\",\n lw=1.5,\n label=\"Box of selection\",\n)\nax.legend()" + "fig = plt.figure(figsize=(12, 5))\nax = fig.add_axes((0.05, 0.05, 0.9, 0.9))\nax.set_xlim(-1, 9)\nax.set_ylim(36, 40)\nax.set_aspect(\"equal\")\nax.grid()\nc.plot(ax, color=\"gray\", lw=0.1, ref=-10, label=\"All tracks ({nb_tracks} tracks)\")\nc_subset.plot(\n ax, color=\"red\", lw=0.2, ref=-10, label=\"selected tracks ({nb_tracks} tracks)\"\n)\nax.plot(\n (x0, x0, x1, x1, x0),\n (y0, y1, y1, y0, y0),\n color=\"green\",\n lw=1.5,\n label=\"Box of selection\",\n)\nax.legend()" ] } ], @@ -100,7 +100,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb index 9a2510b2..08364d16 100644 --- 
a/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Track animation\n\nRun in a terminal this script, which allow to watch eddy evolution\n" + "\nTrack animation\n===============\n\nRun in a terminal this script, which allow to watch eddy evolution.\n\nYou could use also *EddyAnim* script to display/save animation.\n" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\n# We get only 300 first step to save time of documentation builder\neddy = a.extract_ids([9672]).index(slice(0, 300))" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\n# We get only 300 first step to save time of documentation builder\neddy = a.extract_ids([9672]).index(slice(0, 300))" ] }, { diff --git a/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb index 259980a1..1fc4d082 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Track animation with standard matplotlib\n\nRun in a terminal this script, which allow to watch eddy evolution\n" + "\nTrack animation with standard matplotlib\n========================================\n\nRun in a terminal this script, which allow to watch eddy evolution.\n\nYou could use also *EddyAnim* script to display/save animation.\n" ] }, { @@ -26,7 +26,18 @@ }, "outputs": [], "source": [ - "import py_eddy_tracker_sample\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange\n\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + "import re\n\nimport py_eddy_tracker_sample\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange\n\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\n\n# sphinx_gallery_thumbnail_path = '_static/no_image.png'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" ] }, { @@ -44,7 +55,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\neddy = a.extract_ids([9672])" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n 
\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\neddy = a.extract_ids([9672])" ] }, { @@ -62,7 +73,7 @@ }, "outputs": [], "source": [ - "a = Anim(eddy, intern=True, figsize=(8, 3.5), cmap=\"magma_r\", nb_step=6)\na.txt.set_position((17, 34.6))\na.ax.set_xlim(16.5, 23)\na.ax.set_ylim(34.5, 37)\n\n# arguments to get full animation\n# kwargs = dict(frames=arange(*a.period), interval=50)\n# arguments to reduce compute cost for doucmentation, we display only every 10 days\nkwargs = dict(frames=arange(*a.period)[200:800:10], save_count=60, interval=200)\n\nani = FuncAnimation(a.fig, a.func_animation, **kwargs)" + "a = Anim(eddy, intern=True, figsize=(8, 3.5), cmap=\"magma_r\", nb_step=5, dpi=50)\na.txt.set_position((17, 34.6))\na.ax.set_xlim(16.5, 23)\na.ax.set_ylim(34.5, 37)\n\n# arguments to get full animation\nkwargs = dict(frames=arange(*a.period)[300:800], interval=90)\n\nani = VideoAnimation(a.fig, a.func_animation, **kwargs)" ] } ], @@ -87,4 +98,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb index 739b907a..635c6b5a 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb @@ -55,7 +55,7 @@ }, "outputs": [], "source": [ - "kwargs_load = dict(\n include_vars=(\n \"longitude\",\n \"latitude\",\n \"observation_number\",\n \"track\",\n \"time\",\n \"speed_contour_longitude\",\n \"speed_contour_latitude\",\n )\n)\na = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" + "kwargs_load = dict(\n include_vars=(\n \"longitude\",\n \"latitude\",\n \"observation_number\",\n \"track\",\n \"time\",\n \"speed_contour_longitude\",\n \"speed_contour_latitude\",\n )\n)\na = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" ] }, { @@ -144,7 +144,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb index 77e739bf..753cd625 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Count center\n\nDo Geo stat with center and compare with frequency method show: `sphx_glr_python_module_10_tracking_diagnostics_pet_pixel_used.py`\n" + "\n# Count center\n\nDo Geo stat with center and compare with frequency method\nshow: `sphx_glr_python_module_10_tracking_diagnostics_pet_pixel_used.py`\n" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n 
py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" ] }, { @@ -80,7 +80,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 18.5))\nax_a = fig.add_axes([0.03, 0.75, 0.90, 0.25])\nax_a.set_title(\"Anticyclonic center frequency\")\nax_c = fig.add_axes([0.03, 0.5, 0.90, 0.25])\nax_c.set_title(\"Cyclonic center frequency\")\nax_all = fig.add_axes([0.03, 0.25, 0.90, 0.25])\nax_all.set_title(\"All eddies center frequency\")\nax_ratio = fig.add_axes([0.03, 0.0, 0.90, 0.25])\nax_ratio.set_title(\"Ratio cyclonic / Anticyclonic\")\n\n# Count pixel used for each center\ng_a = a.grid_count(bins, intern=True, center=True)\ng_a.display(ax_a, **kwargs_pcolormesh)\ng_c = c.grid_count(bins, intern=True, center=True)\ng_c.display(ax_c, **kwargs_pcolormesh)\n# Compute a ratio Cyclonic / Anticyclonic\nratio = g_c.vars[\"count\"] / g_a.vars[\"count\"]\n\n# Mask manipulation to be able to sum the 2 grids\nm_c = g_c.vars[\"count\"].mask\nm = m_c & g_a.vars[\"count\"].mask\ng_c.vars[\"count\"][m_c] = 0\ng_c.vars[\"count\"] += g_a.vars[\"count\"]\ng_c.vars[\"count\"].mask = m\n\nm = g_c.display(ax_all, **kwargs_pcolormesh)\ncb = plt.colorbar(m, cax=fig.add_axes([0.94, 0.27, 0.01, 0.7]))\ncb.set_label(\"Eddies by 1\u00b0^2 by day\")\n\ng_c.vars[\"count\"] = ratio\nm = g_c.display(\n ax_ratio, name=\"count\", vmin=0.1, vmax=10, norm=LogNorm(), cmap=\"coolwarm_r\"\n)\nplt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.2]))\n\nfor ax in (ax_a, ax_c, ax_all, ax_ratio):\n ax.set_aspect(\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.grid()" + "fig = plt.figure(figsize=(12, 18.5))\nax_a = fig.add_axes([0.03, 0.75, 0.90, 0.25])\nax_a.set_title(\"Anticyclonic center frequency\")\nax_c = fig.add_axes([0.03, 0.5, 0.90, 0.25])\nax_c.set_title(\"Cyclonic center frequency\")\nax_all = fig.add_axes([0.03, 0.25, 0.90, 0.25])\nax_all.set_title(\"All eddies center frequency\")\nax_ratio = fig.add_axes([0.03, 0.0, 0.90, 0.25])\nax_ratio.set_title(\"Ratio cyclonic / Anticyclonic\")\n\n# Count pixel used for each center\ng_a = a.grid_count(bins, intern=True, center=True)\ng_a.display(ax_a, **kwargs_pcolormesh)\ng_c = c.grid_count(bins, intern=True, center=True)\ng_c.display(ax_c, **kwargs_pcolormesh)\n# Compute a ratio Cyclonic / Anticyclonic\nratio = g_c.vars[\"count\"] / g_a.vars[\"count\"]\n\n# Mask manipulation to be able to sum the 2 grids\nm_c = g_c.vars[\"count\"].mask\nm = m_c & g_a.vars[\"count\"].mask\ng_c.vars[\"count\"][m_c] = 0\ng_c.vars[\"count\"] += g_a.vars[\"count\"]\ng_c.vars[\"count\"].mask = m\n\nm = g_c.display(ax_all, **kwargs_pcolormesh)\ncb = plt.colorbar(m, cax=fig.add_axes([0.94, 0.27, 0.01, 0.7]))\ncb.set_label(\"Eddies by 1\u00b0^2 by day\")\n\ng_c.vars[\"count\"] = ratio\nm = g_c.display(\n ax_ratio, name=\"count\", norm=LogNorm(vmin=0.1, vmax=10), cmap=\"coolwarm_r\"\n)\nplt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.2]))\n\nfor ax in (ax_a, ax_c, ax_all, ax_ratio):\n ax.set_aspect(\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.grid()" ] }, { @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git 
a/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb index 28bf4579..df495703 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\na = a.merge(c)\n\nstep = 0.1" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\na = a.merge(c)\n\nstep = 0.1" ] }, { @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb index 944167fd..9f06e010 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)" ] }, { @@ -136,7 +136,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb index c28ff02a..81809d8b 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nkwargs_a = dict(label=\"Anticyclonic\", color=\"r\", histtype=\"step\", density=True)\nkwargs_c = dict(label=\"Cyclonic\", color=\"b\", histtype=\"step\", density=True)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nkwargs_a = dict(label=\"Anticyclonic\", color=\"r\", histtype=\"step\", density=True)\nkwargs_c = dict(label=\"Cyclonic\", color=\"b\", histtype=\"step\", density=True)" ] }, { @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb index 40cc8e73..ed8c0295 100644 --- 
a/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nnb_year = (a.period[1] - a.period[0] + 1) / 365.25" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nnb_year = (a.period[1] - a.period[0] + 1) / 365.25" ] }, { @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_normalised_lifetime.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_normalised_lifetime.ipynb new file mode 100644 index 00000000..f9fb474f --- /dev/null +++ b/notebooks/python_module/10_tracking_diagnostics/pet_normalised_lifetime.ipynb @@ -0,0 +1,119 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nNormalised Eddy Lifetimes\n=========================\n\nExample from Evan Mason\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numba import njit\nfrom numpy import interp, linspace, zeros\nfrom py_eddy_tracker_sample import get_demo_path\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "@njit(cache=True)\ndef sum_profile(x_new, y, out):\n \"\"\"Will sum all interpolated given array\"\"\"\n out += interp(x_new, linspace(0, 1, y.size), y)\n\n\nclass MyObs(TrackEddiesObservations):\n def eddy_norm_lifetime(self, name, nb, factor=1):\n \"\"\"\n :param str,array name: Array or field name\n :param int nb: size of output array\n \"\"\"\n y = self.parse_varname(name)\n x = linspace(0, 1, nb)\n out = zeros(nb, dtype=y.dtype)\n nb_track = 0\n for i, b0, b1 in self.iter_on(\"track\"):\n y_ = y[i]\n size_ = y_.size\n if size_ == 0:\n continue\n sum_profile(x, y_, out)\n nb_track += 1\n return x, out / nb_track * factor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load atlas\n----------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(include_vars=(\"speed_radius\", \"amplitude\", \"track\"))\na = MyObs.load_file(\n get_demo_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"), **kw\n)\nc = MyObs.load_file(get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\"), **kw)\n\nnb_max_a = a.nb_obs_by_track.max()\nnb_max_c = c.nb_obs_by_track.max()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compute normalised lifetime\n---------------------------\n\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Radius\nAC_radius = a.eddy_norm_lifetime(\"speed_radius\", nb=nb_max_a, factor=1e-3)\nCC_radius = c.eddy_norm_lifetime(\"speed_radius\", nb=nb_max_c, factor=1e-3)\n# Amplitude\nAC_amplitude = a.eddy_norm_lifetime(\"amplitude\", nb=nb_max_a, factor=1e2)\nCC_amplitude = c.eddy_norm_lifetime(\"amplitude\", nb=nb_max_c, factor=1e2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Figure\n------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, (ax0, ax1) = plt.subplots(nrows=2, figsize=(8, 6))\n\nax0.set_title(\"Normalised Mean Radius\")\nax0.plot(*AC_radius), ax0.plot(*CC_radius)\nax0.set_ylabel(\"Radius (km)\"), ax0.grid()\nax0.set_xlim(0, 1), ax0.set_ylim(0, None)\n\nax1.set_title(\"Normalised Mean Amplitude\")\nax1.plot(*AC_amplitude, label=\"AC\"), ax1.plot(*CC_amplitude, label=\"CC\")\nax1.set_ylabel(\"Amplitude (cm)\"), ax1.grid(), ax1.legend()\n_ = ax1.set_xlim(0, 1), ax1.set_ylim(0, None)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb index e15daf26..23f830d6 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Count pixel used\n\nDo Geo stat with frequency and compare with center count method: `sphx_glr_python_module_10_tracking_diagnostics_pet_center_count.py`\n" + "\n# Count pixel used\n\nDo Geo stat with frequency and compare with center count\nmethod: `sphx_glr_python_module_10_tracking_diagnostics_pet_center_count.py`\n" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)" ] }, { @@ -80,7 +80,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(12, 18.5))\nax_a = fig.add_axes([0.03, 0.75, 0.90, 0.25])\nax_a.set_title(\"Anticyclonic frequency\")\nax_c = fig.add_axes([0.03, 0.5, 0.90, 0.25])\nax_c.set_title(\"Cyclonic frequency\")\nax_all = fig.add_axes([0.03, 0.25, 0.90, 0.25])\nax_all.set_title(\"All eddies frequency\")\nax_ratio = fig.add_axes([0.03, 0.0, 0.90, 0.25])\nax_ratio.set_title(\"Ratio cyclonic / Anticyclonic\")\n\n# Count pixel used for each contour\ng_a = a.grid_count(bins, intern=True)\ng_a.display(ax_a, **kwargs_pcolormesh)\ng_c = c.grid_count(bins, 
intern=True)\ng_c.display(ax_c, **kwargs_pcolormesh)\n# Compute a ratio Cyclonic / Anticyclonic\nratio = g_c.vars[\"count\"] / g_a.vars[\"count\"]\n\n# Mask manipulation to be able to sum the 2 grids\nm_c = g_c.vars[\"count\"].mask\nm = m_c & g_a.vars[\"count\"].mask\ng_c.vars[\"count\"][m_c] = 0\ng_c.vars[\"count\"] += g_a.vars[\"count\"]\ng_c.vars[\"count\"].mask = m\n\nm = g_c.display(ax_all, **kwargs_pcolormesh)\nplt.colorbar(m, cax=fig.add_axes([0.95, 0.27, 0.01, 0.7]))\n\ng_c.vars[\"count\"] = ratio\nm = g_c.display(\n ax_ratio, name=\"count\", vmin=0.1, vmax=10, norm=LogNorm(), cmap=\"coolwarm_r\"\n)\nplt.colorbar(m, cax=fig.add_axes([0.95, 0.02, 0.01, 0.2]))\n\nfor ax in (ax_a, ax_c, ax_all, ax_ratio):\n ax.set_aspect(\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.grid()" + "fig = plt.figure(figsize=(12, 18.5))\nax_a = fig.add_axes([0.03, 0.75, 0.90, 0.25])\nax_a.set_title(\"Anticyclonic frequency\")\nax_c = fig.add_axes([0.03, 0.5, 0.90, 0.25])\nax_c.set_title(\"Cyclonic frequency\")\nax_all = fig.add_axes([0.03, 0.25, 0.90, 0.25])\nax_all.set_title(\"All eddies frequency\")\nax_ratio = fig.add_axes([0.03, 0.0, 0.90, 0.25])\nax_ratio.set_title(\"Ratio cyclonic / Anticyclonic\")\n\n# Count pixel used for each contour\ng_a = a.grid_count(bins, intern=True)\ng_a.display(ax_a, **kwargs_pcolormesh)\ng_c = c.grid_count(bins, intern=True)\ng_c.display(ax_c, **kwargs_pcolormesh)\n# Compute a ratio Cyclonic / Anticyclonic\nratio = g_c.vars[\"count\"] / g_a.vars[\"count\"]\n\n# Mask manipulation to be able to sum the 2 grids\nm_c = g_c.vars[\"count\"].mask\nm = m_c & g_a.vars[\"count\"].mask\ng_c.vars[\"count\"][m_c] = 0\ng_c.vars[\"count\"] += g_a.vars[\"count\"]\ng_c.vars[\"count\"].mask = m\n\nm = g_c.display(ax_all, **kwargs_pcolormesh)\nplt.colorbar(m, cax=fig.add_axes([0.95, 0.27, 0.01, 0.7]))\n\ng_c.vars[\"count\"] = ratio\nm = g_c.display(\n ax_ratio, name=\"count\", norm=LogNorm(vmin=0.1, vmax=10), cmap=\"coolwarm_r\"\n)\nplt.colorbar(m, cax=fig.add_axes([0.95, 0.02, 0.01, 0.2]))\n\nfor ax in (ax_a, ax_c, ax_all, ax_ratio):\n ax.set_aspect(\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.grid()" ] }, { @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb index 3cfb6140..9792f8f4 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\")\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nnb_year = (a.period[1] - a.period[0] + 1) / 365.25" + "a = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)\nc = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nnb_year = (a.period[1] - a.period[0] + 1) / 365.25" ] }, { @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git 
a/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb b/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb index 025b62d0..b30682a1 100644 --- a/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb +++ b/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\n\ndate = datetime(2016, 7, 7)\n\nfilename_alt = data.get_path(f\"dt_blacksea_allsat_phy_l4_{date:%Y%m%d}_20200801.nc\")\nfilename_sst = data.get_path(\n f\"{date:%Y%m%d}000000-GOS-L4_GHRSST-SSTfnd-OISST_HR_REP-BLK-v02.0-fv01.0.nc\"\n)\nvar_name_sst = \"analysed_sst\"\n\nextent = [27, 42, 40.5, 47]" + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\n\ndate = datetime(2016, 7, 7)\n\nfilename_alt = data.get_demo_path(\n f\"dt_blacksea_allsat_phy_l4_{date:%Y%m%d}_20200801.nc\"\n)\nfilename_sst = data.get_demo_path(\n f\"{date:%Y%m%d}000000-GOS-L4_GHRSST-SSTfnd-OISST_HR_REP-BLK-v02.0-fv01.0.nc\"\n)\nvar_name_sst = \"analysed_sst\"\n\nextent = [27, 42, 40.5, 47]" ] }, { @@ -44,7 +44,7 @@ }, "outputs": [], "source": [ - "sst = RegularGridDataset(filename=filename_sst, x_name=\"lon\", y_name=\"lat\")\nalti = RegularGridDataset(\n data.get_path(filename_alt), x_name=\"longitude\", y_name=\"latitude\"\n)\n# We can use `Grid` tools to interpolate ADT on the sst grid\nsst.regrid(alti, \"sla\")\nsst.add_uv(\"sla\")" + "sst = RegularGridDataset(filename=filename_sst, x_name=\"lon\", y_name=\"lat\")\nalti = RegularGridDataset(\n data.get_demo_path(filename_alt), x_name=\"longitude\", y_name=\"latitude\"\n)\n# We can use `Grid` tools to interpolate ADT on the sst grid\nsst.regrid(alti, \"sla\")\nsst.add_uv(\"sla\")" ] }, { @@ -226,7 +226,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/12_external_data/pet_drifter_loopers.ipynb b/notebooks/python_module/12_external_data/pet_drifter_loopers.ipynb new file mode 100644 index 00000000..7ba30914 --- /dev/null +++ b/notebooks/python_module/12_external_data/pet_drifter_loopers.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nColocate looper with eddy from altimetry\n========================================\n\nAll loopers data used in this example are a subset from the dataset described in this article\n[Lumpkin, R. 
: Global characteristics of coherent vortices from surface drifter trajectories](https://doi.org/10.1002/2015JC011435)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nimport numpy as np\nimport py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\ndef start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], aspect=\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load eddies dataset\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "cyclonic_eddies = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nanticyclonic_eddies = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load loopers dataset\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "loopers_med = TrackEddiesObservations.load_file(\n data.get_demo_path(\"loopers_lumpkin_med.nc\")\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Global view\n===========\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"All drifters available in Med from Lumpkin dataset\")\nloopers_med.plot(ax, lw=0.5, color=\"r\", ref=-10)\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One segment of drifter\n======================\n\nGet a drifter segment (the indexes used have no correspondance with the original dataset).\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "looper = loopers_med.extract_ids((3588,))\nfig = plt.figure(figsize=(16, 6))\nax = fig.add_subplot(111, aspect=\"equal\")\nlooper.plot(ax, lw=0.5, label=\"Original position of drifter\")\nlooper_filtered = looper.copy()\nlooper_filtered.position_filter(1, 13)\ns = 
looper_filtered.scatter(\n ax,\n \"time\",\n cmap=plt.get_cmap(\"Spectral_r\", 20),\n label=\"Filtered position of drifter\",\n)\nplt.colorbar(s).set_label(\"time (days from 1/1/1950)\")\nax.legend()\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Try to find a detected eddies with adt at same place. We used filtered track to simulate an eddy center\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "match = looper_filtered.close_tracks(\n anticyclonic_eddies, method=\"close_center\", delta=0.1, nb_obs_min=50\n)\nfig = plt.figure(figsize=(16, 6))\nax = fig.add_subplot(111, aspect=\"equal\")\nlooper.plot(ax, lw=0.5, label=\"Original position of drifter\")\nlooper_filtered.plot(ax, lw=1.5, label=\"Filtered position of drifter\")\nmatch.plot(ax, lw=1.5, label=\"Matched eddy\")\nax.legend()\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display radius of this 2 datasets.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(20, 8))\nax = fig.add_subplot(111)\nax.plot(looper.time, looper.radius_s / 1e3, label=\"loopers\")\nlooper_radius = looper.copy()\nlooper_radius.median_filter(1, \"time\", \"radius_s\", inplace=True)\nlooper_radius.loess_filter(13, \"time\", \"radius_s\", inplace=True)\nax.plot(\n looper_radius.time,\n looper_radius.radius_s / 1e3,\n label=\"loopers (filtered half window 13 days)\",\n)\nax.plot(match.time, match.radius_s / 1e3, label=\"altimetry\")\nmatch_radius = match.copy()\nmatch_radius.median_filter(1, \"time\", \"radius_s\", inplace=True)\nmatch_radius.loess_filter(13, \"time\", \"radius_s\", inplace=True)\nax.plot(\n match_radius.time,\n match_radius.radius_s / 1e3,\n label=\"altimetry (filtered half window 13 days)\",\n)\nax.set_ylabel(\"radius(km)\"), ax.set_ylim(0, 100)\nax.legend()\nax.set_title(\"Radius from loopers and altimeter\")\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Animation of a drifter and its colocated eddy\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def update(frame):\n # We display last 5 days of loopers trajectory\n m = (looper.time < frame) * (looper.time > (frame - 5))\n anim.func_animation(frame)\n line.set_data(looper.lon[m], looper.lat[m])\n\n\nanim = Anim(match, intern=True, figsize=(8, 8), cmap=\"magma_r\", nb_step=10, dpi=75)\n# mappable to show drifter in red\nline = anim.ax.plot([], [], \"r\", lw=4, zorder=100)[0]\nanim.fig.suptitle(\"\")\n_ = VideoAnimation(anim.fig, update, frames=np.arange(*anim.period, 1), interval=125)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb b/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb new file mode 100644 index 00000000..a46a7e22 --- /dev/null +++ 
b/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb @@ -0,0 +1,108 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Contour fit\n\nTwo type of fit :\n - Ellipse\n - Circle\n\nIn the two case we use a least square algorithm\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import cos, linspace, radians, sin\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.generic import coordinates_to_local, local_to_coordinates\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker.poly import fit_circle_, fit_ellipse" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load example identification file\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Function to draw circle or ellipse from parameter\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def build_circle(x0, y0, r):\n angle = radians(linspace(0, 360, 50))\n x_norm, y_norm = cos(angle), sin(angle)\n return local_to_coordinates(x_norm * r, y_norm * r, x0, y0)\n\n\ndef build_ellipse(x0, y0, a, b, theta):\n angle = radians(linspace(0, 360, 50))\n x = a * cos(theta) * cos(angle) - b * sin(theta) * sin(angle)\n y = a * sin(theta) * cos(angle) + b * cos(theta) * sin(angle)\n return local_to_coordinates(x, y, x0, y0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot fitted circle or ellipse on stored contour\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "xs, ys = a.contour_lon_s, a.contour_lat_s\n\nfig = plt.figure(figsize=(15, 15))\n\nj = 1\nfor i in range(0, 800, 30):\n x, y = xs[i], ys[i]\n x0_, y0_ = x.mean(), y.mean()\n x_, y_ = coordinates_to_local(x, y, x0_, y0_)\n ax = fig.add_subplot(4, 4, j)\n ax.grid(), ax.set_aspect(\"equal\")\n ax.plot(x, y, label=\"store\", color=\"black\")\n x0, y0, a, b, theta = fit_ellipse(x_, y_)\n x0, y0 = local_to_coordinates(x0, y0, x0_, y0_)\n ax.plot(*build_ellipse(x0, y0, a, b, theta), label=\"ellipse\", color=\"green\")\n x0, y0, radius, shape_error = fit_circle_(x_, y_)\n x0, y0 = local_to_coordinates(x0, y0, x0_, y0_)\n ax.plot(*build_circle(x0, y0, radius), label=\"circle\", color=\"red\", lw=0.5)\n if j == 16:\n break\n j += 1" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb b/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb new file 
mode 100644 index 00000000..69e49b57 --- /dev/null +++ b/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Visvalingam algorithm\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import matplotlib.animation as animation\nfrom matplotlib import pyplot as plt\nfrom numba import njit\nfrom numpy import array, empty\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.generic import uniform_resample\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker.poly import vertice_overlap, visvalingam\n\n\n@njit(cache=True)\ndef visvalingam_polys(x, y, nb_pt):\n nb = x.shape[0]\n x_new = empty((nb, nb_pt), dtype=x.dtype)\n y_new = empty((nb, nb_pt), dtype=y.dtype)\n for i in range(nb):\n x_new[i], y_new[i] = visvalingam(x[i], y[i], nb_pt)\n return x_new, y_new\n\n\n@njit(cache=True)\ndef uniform_resample_polys(x, y, nb_pt):\n nb = x.shape[0]\n x_new = empty((nb, nb_pt), dtype=x.dtype)\n y_new = empty((nb, nb_pt), dtype=y.dtype)\n for i in range(nb):\n x_new[i], y_new[i] = uniform_resample(x[i], y[i], fixed_size=nb_pt)\n return x_new, y_new\n\n\ndef update_line(num):\n nb = 50 - num - 20\n x_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb)\n for i, (x_, y_) in enumerate(zip(x_v, y_v)):\n lines_v[i].set_data(x_, y_)\n x_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb)\n for i, (x_, y_) in enumerate(zip(x_u, y_u)):\n lines_u[i].set_data(x_, y_)\n scores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0\n scores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0\n for i, (s_v, s_u) in enumerate(zip(scores_v, scores_u)):\n texts[i].set_text(f\"Score uniform {s_u:.1f} %\\nScore visvalingam {s_v:.1f} %\")\n title.set_text(f\"{nb} points by contour in place of 50\")\n return (title, *lines_u, *lines_v, *texts)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load detection files\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = EddiesObservations.load_file(data.get_demo_path(\"Anticyclonic_20190223.nc\"))\na = a.extract_with_mask((abs(a.lat) < 66) * (abs(a.radius_e) > 80e3))\n\nnb_pt = 10\nx_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb_pt)\nx_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb_pt)\nscores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0\nscores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0\nd_6 = scores_v - scores_u\nnb_pt = 18\nx_v, y_v = visvalingam_polys(a.contour_lon_e, a.contour_lat_e, nb_pt)\nx_u, y_u = uniform_resample_polys(a.contour_lon_e, a.contour_lat_e, nb_pt)\nscores_v = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_v, y_v) * 100.0\nscores_u = vertice_overlap(a.contour_lon_e, a.contour_lat_e, x_u, y_u) * 100.0\nd_12 = scores_v - scores_u\na = a.index(array((d_6.argmin(), d_6.argmax(), d_12.argmin(), d_12.argmax())))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure()\naxs = [\n 
fig.add_subplot(221),\n fig.add_subplot(222),\n fig.add_subplot(223),\n fig.add_subplot(224),\n]\nlines_u, lines_v, texts, score_text = list(), list(), list(), list()\nfor i, obs in enumerate(a):\n axs[i].set_aspect(\"equal\")\n axs[i].grid()\n axs[i].set_xticklabels([]), axs[i].set_yticklabels([])\n axs[i].plot(\n obs[\"contour_lon_e\"], obs[\"contour_lat_e\"], \"r\", lw=6, label=\"Original contour\"\n )\n lines_v.append(axs[i].plot([], [], color=\"limegreen\", lw=4, label=\"visvalingam\")[0])\n lines_u.append(\n axs[i].plot([], [], color=\"black\", lw=2, label=\"uniform resampling\")[0]\n )\n texts.append(axs[i].set_title(\"\", fontsize=8))\naxs[0].legend(fontsize=8)\ntitle = fig.suptitle(\"\")\nanim = animation.FuncAnimation(fig, update_line, 27)\nanim" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_atlas.ipynb b/notebooks/python_module/16_network/pet_atlas.ipynb new file mode 100644 index 00000000..31e3580f --- /dev/null +++ b/notebooks/python_module/16_network/pet_atlas.ipynb @@ -0,0 +1,371 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Network Analysis\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import ma\n\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\nn = NetworkObservations.load_file(\n get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018_err70_filt500_order1/Anticyclonic_network.nc\"\n )\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Parameters\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 1 / 10.0\nbins = ((-10, 37, step), (30, 46, step))\nkw_time = dict(cmap=\"terrain_r\", factor=100.0 / n.nb_days, name=\"count\")\nkw_ratio = dict(cmap=plt.get_cmap(\"YlGnBu_r\", 10))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Functions\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## All\nDisplay the % of time each pixel (1/10\u00b0) is within an anticyclonic network\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": 
{ + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_all = n.grid_count(bins)\nm = g_all.display(ax, **kw_time, vmin=0, vmax=75)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Network longer than 10 days\nDisplay the % of time each pixel (1/10\u00b0) is within an anticyclonic network\nwhich total lifetime in longer than 10 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nn10 = n.longer_than(10)\ng_10 = n10.grid_count(bins)\nm = g_10.display(ax, **kw_time, vmin=0, vmax=75)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ratio\nRatio between the longer and total presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = g_10.vars[\"count\"] * 100.0 / g_all.vars[\"count\"]\nm = g_10.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Blue = mostly short networks\n\n## Network longer than 20 days\nDisplay the % of time each pixel (1/10\u00b0) is within an anticyclonic network\nwhich total lifetime is longer than 20 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nn20 = n.longer_than(20)\ng_20 = n20.grid_count(bins)\nm = g_20.display(ax, **kw_time, vmin=0, vmax=75)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ratio\nRatio between the longer and total presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = g_20.vars[\"count\"] * 100.0 / g_all.vars[\"count\"]\nm = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we will hide pixel which are used less than 365 times\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "g_ = ma.array(\n g_20.vars[\"count\"] * 100.0 / g_all.vars[\"count\"], mask=g_all.vars[\"count\"] < 365\n)\nax = start_axes(\"\")\nm = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we will hide pixel which are used more than 365 times\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = ma.array(\n g_20.vars[\"count\"] * 100.0 / g_all.vars[\"count\"], mask=g_all.vars[\"count\"] >= 365\n)\nm = g_20.display(ax, **kw_ratio, vmin=50, vmax=100, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Coastal areas are mostly populated by short networks\n\n## All merging\nDisplay 
the occurence of merging events\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_all_merging = n.merging_event().grid_count(bins)\nm = g_all_merging.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ratio merging events / eddy presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = g_all_merging.vars[\"count\"] * 100.0 / g_all.vars[\"count\"]\nm = g_all_merging.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Merging in networks longer than 10 days, with dead end remove (shorter than 10 observations)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nmerger = n10.remove_dead_end(nobs=10).merging_event()\ng_10_merging = merger.grid_count(bins)\nm = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Merging in networks longer than 10 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nmerger = n10.merging_event()\ng_10_merging = merger.grid_count(bins)\nm = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ratio merging events / eddy presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = ma.array(\n g_10_merging.vars[\"count\"] * 100.0 / g_10.vars[\"count\"],\n mask=g_10.vars[\"count\"] < 365,\n)\nm = g_10_merging.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## All Spliting\nDisplay the occurence of spliting events\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_all_spliting = n.spliting_event().grid_count(bins)\nm = g_all_spliting.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ratio spliting events / eddy presence\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = g_all_spliting.vars[\"count\"] * 100.0 / g_all.vars[\"count\"]\nm = g_all_spliting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Spliting in networks longer than 10 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + 
"collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_10_spliting = n10.spliting_event().grid_count(bins)\nm = g_10_spliting.display(ax, **kw_time, vmin=0, vmax=1)\nupdate_axes(ax, m).set_label(\"Pixel used in % of time\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\ng_ = ma.array(\n g_10_spliting.vars[\"count\"] * 100.0 / g_10.vars[\"count\"],\n mask=g_10.vars[\"count\"] < 365,\n)\nm = g_10_spliting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_)\nupdate_axes(ax, m).set_label(\"Pixel used in % all atlas\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_follow_particle.ipynb b/notebooks/python_module/16_network/pet_follow_particle.ipynb new file mode 100644 index 00000000..a2a97944 --- /dev/null +++ b/notebooks/python_module/16_network/pet_follow_particle.ipynb @@ -0,0 +1,159 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nFollow particle\n===============\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nfrom matplotlib import colors\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, meshgrid, ones, unique, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\nfrom py_eddy_tracker.observations.groups import particle_candidate\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = NetworkObservations.load_file(get_demo_path(\"network_med.nc\")).network(651)\nn = n.extract_with_mask((n.time >= 20180) * (n.time <= 20269))\nn = n.remove_dead_end(nobs=0, ndays=10)\nn.numbering_segment()\nc = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n 
\"latitude\",\n \"time\",\n heigth=\"adt\",\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Schema\n------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(12, 6))\nax = fig.add_axes([0.05, 0.05, 0.9, 0.9])\n_ = n.display_timeline(ax, field=\"longitude\", marker=\"+\", lw=2, markersize=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Animation\n---------\nParticle settings\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "t_snapshot = 20200\nstep = 1 / 50.0\nx, y = meshgrid(arange(20, 36, step), arange(30, 46, step))\nN = 6\nx_f, y_f = x[::N, ::N].copy(), y[::N, ::N].copy()\nx, y = x.reshape(-1), y.reshape(-1)\nx_f, y_f = x_f.reshape(-1), y_f.reshape(-1)\nn_ = n.extract_with_mask(n.time == t_snapshot)\nindex = n_.contains(x, y, intern=True)\nm = index != -1\nindex = n_.segment[index[m]]\nindex_ = unique(index)\nx, y = x[m], y[m]\nm = ~n_.inside(x_f, y_f, intern=True)\nx_f, y_f = x_f[m], y_f[m]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Animation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "cmap = colors.ListedColormap(list(n.COLORS), name=\"from_list\", N=n.segment.max() + 1)\na = Anim(\n n,\n intern=False,\n figsize=(12, 6),\n nb_step=1,\n dpi=60,\n field_color=\"segment\",\n field_txt=\"segment\",\n cmap=cmap,\n)\na.fig.suptitle(\"\"), a.ax.set_xlim(24, 36), a.ax.set_ylim(30, 36)\na.txt.set_position((25, 31))\n\nstep = 0.25\nkw_p = dict(nb_step=2, time_step=86400 * step * 0.5, t_init=t_snapshot - 2 * step)\n\nmappables = dict()\nparticules = c.advect(x, y, \"u\", \"v\", **kw_p)\nfilament = c.filament(x_f, y_f, \"u\", \"v\", **kw_p, filament_size=3)\nkw = dict(ls=\"\", marker=\".\", markersize=0.25)\nfor k in index_:\n m = k == index\n mappables[k] = a.ax.plot([], [], color=cmap(k), **kw)[0]\nm_filament = a.ax.plot([], [], lw=0.25, color=\"gray\")[0]\n\n\ndef update(frame):\n tt, xt, yt = particules.__next__()\n for k, mappable in mappables.items():\n m = index == k\n mappable.set_data(xt[m], yt[m])\n tt, xt, yt = filament.__next__()\n m_filament.set_data(xt, yt)\n if frame % 1 == 0:\n a.func_animation(frame)\n\n\nani = VideoAnimation(a.fig, update, frames=arange(20200, 20269, step), interval=200)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Particle advection\n^^^^^^^^^^^^^^^^^^\nAdvection from speed contour to speed contour (default)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 1 / 60.0\n\nt_start, t_end = int(n.period[0]), int(n.period[1])\ndt = 14\n\nshape = (n.obs.size, 2)\n# Forward run\ni_target_f, pct_target_f = -ones(shape, dtype=\"i4\"), zeros(shape, dtype=\"i1\")\nfor t in arange(t_start, t_end - dt):\n particle_candidate(c, n, step, t, i_target_f, pct_target_f, n_days=dt)\n\n# Backward run\ni_target_b, pct_target_b = -ones(shape, dtype=\"i4\"), zeros(shape, dtype=\"i1\")\nfor t in arange(t_start + dt, t_end):\n particle_candidate(c, n, step, t, i_target_b, pct_target_b, n_days=-dt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = 
plt.figure(figsize=(10, 10))\nax_1st_b = fig.add_axes([0.05, 0.52, 0.45, 0.45])\nax_2nd_b = fig.add_axes([0.05, 0.05, 0.45, 0.45])\nax_1st_f = fig.add_axes([0.52, 0.52, 0.45, 0.45])\nax_2nd_f = fig.add_axes([0.52, 0.05, 0.45, 0.45])\nax_1st_b.set_title(\"Backward advection for each time step\")\nax_1st_f.set_title(\"Forward advection for each time step\")\nax_1st_b.set_ylabel(\"Color -> First target\\nLatitude\")\nax_2nd_b.set_ylabel(\"Color -> Secondary target\\nLatitude\")\nax_2nd_b.set_xlabel(\"Julian days\"), ax_2nd_f.set_xlabel(\"Julian days\")\nax_1st_f.set_yticks([]), ax_2nd_f.set_yticks([])\nax_1st_f.set_xticks([]), ax_1st_b.set_xticks([])\n\n\ndef color_alpha(target, pct, vmin=5, vmax=80):\n color = cmap(n.segment[target])\n # We will hide under 5 % and from 80% to 100 % it will be 1\n alpha = (pct - vmin) / (vmax - vmin)\n alpha[alpha < 0] = 0\n alpha[alpha > 1] = 1\n color[:, 3] = alpha\n return color\n\n\nkw = dict(\n name=None, yfield=\"longitude\", event=False, zorder=-100, s=(n.speed_area / 20e6)\n)\nn.scatter_timeline(ax_1st_b, c=color_alpha(i_target_b.T[0], pct_target_b.T[0]), **kw)\nn.scatter_timeline(ax_2nd_b, c=color_alpha(i_target_b.T[1], pct_target_b.T[1]), **kw)\nn.scatter_timeline(ax_1st_f, c=color_alpha(i_target_f.T[0], pct_target_f.T[0]), **kw)\nn.scatter_timeline(ax_2nd_f, c=color_alpha(i_target_f.T[1], pct_target_f.T[1]), **kw)\nfor ax in (ax_1st_b, ax_2nd_b, ax_1st_f, ax_2nd_f):\n n.display_timeline(ax, field=\"longitude\", marker=\"+\", lw=2, markersize=5)\n ax.grid()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_group_anim.ipynb b/notebooks/python_module/16_network/pet_group_anim.ipynb new file mode 100644 index 00000000..090170ff --- /dev/null +++ b/notebooks/python_module/16_network/pet_group_anim.ipynb @@ -0,0 +1,213 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nNetwork group process\n=====================\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# sphinx_gallery_thumbnail_number = 2\nimport re\nfrom datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.colors import ListedColormap\nfrom numba import njit\nfrom numpy import arange, array, empty, ones\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.generic import flatten_line_matrix\nfrom py_eddy_tracker.observations.network import Network\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 
'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "NETWORK_GROUPS = list()\n\n\n@njit(cache=True)\ndef apply_replace(x, x0, x1):\n nb = x.shape[0]\n for i in range(nb):\n if x[i] == x0:\n x[i] = x1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Modified class to catch group process at each step in order to illustrate processing\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class MyNetwork(Network):\n def get_group_array(self, results, nb_obs):\n \"\"\"With a loop on all pair of index, we will label each obs with a group\n number\n \"\"\"\n nb_obs = array(nb_obs, dtype=\"u4\")\n day_start = nb_obs.cumsum() - nb_obs\n gr = empty(nb_obs.sum(), dtype=\"u4\")\n gr[:] = self.NOGROUP\n\n id_free = 1\n for i, j, ii, ij in results:\n gr_i = gr[slice(day_start[i], day_start[i] + nb_obs[i])]\n gr_j = gr[slice(day_start[j], day_start[j] + nb_obs[j])]\n # obs with no groups\n m = (gr_i[ii] == self.NOGROUP) * (gr_j[ij] == self.NOGROUP)\n nb_new = m.sum()\n gr_i[ii[m]] = gr_j[ij[m]] = arange(id_free, id_free + nb_new)\n id_free += nb_new\n # associate obs with no group with obs with group\n m = (gr_i[ii] != self.NOGROUP) * (gr_j[ij] == self.NOGROUP)\n gr_j[ij[m]] = gr_i[ii[m]]\n m = (gr_i[ii] == self.NOGROUP) * (gr_j[ij] != self.NOGROUP)\n gr_i[ii[m]] = gr_j[ij[m]]\n # case where 2 obs have a different group\n m = gr_i[ii] != gr_j[ij]\n if m.any():\n # Merge of group, ref over etu\n for i_, j_ in zip(ii[m], ij[m]):\n g0, g1 = gr_i[i_], gr_j[j_]\n apply_replace(gr, g0, g1)\n NETWORK_GROUPS.append((i, j, gr.copy()))\n return gr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Movie period\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "t0 = (datetime(2005, 5, 1) - datetime(1950, 1, 1)).days\nt1 = (datetime(2005, 6, 1) - datetime(1950, 1, 1)).days" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get data from period and area\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "e = EddiesObservations.load_file(data.get_demo_path(\"network_med.nc\"))\ne = e.extract_with_mask((e.time >= t0) * (e.time < t1)).extract_with_area(\n dict(llcrnrlon=25, urcrnrlon=35, llcrnrlat=31, urcrnrlat=37.5)\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Reproduce individual daily identification(for demonstration)\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "EDDIES_BY_DAYS = list()\nfor i, b0, b1 in e.iter_on(\"time\"):\n EDDIES_BY_DAYS.append(e.index(i))\n# need for display\ne = EddiesObservations.concatenate(EDDIES_BY_DAYS)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run network building group to intercept every step\n\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = MyNetwork.from_eddiesobservations(EDDIES_BY_DAYS, window=7)\n_ = n.group_observations(minimal_area=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def update(frame):\n i_current, i_match, gr = NETWORK_GROUPS[frame]\n current = EDDIES_BY_DAYS[i_current]\n x = flatten_line_matrix(current.contour_lon_e)\n y = flatten_line_matrix(current.contour_lat_e)\n current_contour.set_data(x, y)\n match = EDDIES_BY_DAYS[i_match]\n x = flatten_line_matrix(match.contour_lon_e)\n y = flatten_line_matrix(match.contour_lat_e)\n matched_contour.set_data(x, y)\n groups.set_array(gr)\n txt.set_text(f\"Day {i_current} match with day {i_match}\")\n s = 80 * ones(gr.shape)\n s[gr == 0] = 4\n groups.set_sizes(s)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Anim\n----\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(16, 9), dpi=50)\nax = fig.add_axes([0, 0, 1, 1])\nax.set_aspect(\"equal\"), ax.grid(), ax.set_xlim(26, 34), ax.set_ylim(31, 35.5)\ncmap = ListedColormap([\"gray\", *e.COLORS[:-1]], name=\"from_list\", N=30)\nkw_s = dict(cmap=cmap, vmin=0, vmax=30)\ngroups = ax.scatter(e.lon, e.lat, c=NETWORK_GROUPS[0][2], **kw_s)\ncurrent_contour = ax.plot([], [], \"k\", lw=2, label=\"Current contour\")[0]\nmatched_contour = ax.plot([], [], \"r\", lw=1, ls=\"--\", label=\"Candidate contour\")[0]\ntxt = ax.text(29, 35, \"\", fontsize=25)\nax.legend(fontsize=25)\nani = VideoAnimation(fig, update, frames=len(NETWORK_GROUPS), interval=220)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Final Result\n------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(16, 9))\nax = fig.add_axes([0, 0, 1, 1])\nax.set_aspect(\"equal\"), ax.grid(), ax.set_xlim(26, 34), ax.set_ylim(31, 35.5)\n_ = ax.scatter(e.lon, e.lat, c=NETWORK_GROUPS[-1][2], **kw_s)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb b/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb new file mode 100644 index 00000000..9d659597 --- /dev/null +++ b/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb @@ -0,0 +1,346 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nIoannou case\n============\nFigure 10 from https://doi.org/10.1002/2017JC013158\n\nWe want to find the Ierapetra Eddy described above in a network demonstration run.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\nfrom 
datetime import datetime, timedelta\n\nfrom matplotlib import colors\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.ticker import FuncFormatter\nfrom numpy import arange, array, pi, where\n\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.generic import coordinates_to_local\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.poly import fit_ellipse" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\n@FuncFormatter\ndef formatter(x, pos):\n return (timedelta(x) + datetime(1950, 1, 1)).strftime(\"%d/%m/%Y\")\n\n\ndef start_axes(title=\"\"):\n fig = plt.figure(figsize=(13, 6))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)\n ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef timeline_axes(title=\"\"):\n fig = plt.figure(figsize=(15, 5))\n ax = fig.add_axes([0.03, 0.06, 0.90, 0.88])\n ax.set_title(title, weight=\"bold\")\n ax.xaxis.set_major_formatter(formatter), ax.grid()\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid(True)\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We know the network ID, we will get directly\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ioannou_case = NetworkObservations.load_file(get_demo_path(\"network_med.nc\")).network(\n 651\n)\nprint(ioannou_case.infos())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It seems that this network is huge! Our case is visible at 22E 33.5N\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes()\nioannou_case.plot(ax, color_cycle=ioannou_case.COLORS)\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Full Timeline\n-------------\nThe network span for many years... 
How to cut the interesting part?\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nax.xaxis.set_major_formatter(formatter), ax.grid()\n_ = ioannou_case.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Sub network and new numbering\n-----------------------------\nHere we chose to keep only the order 3 segments relatives to our chosen eddy\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i = where(\n (ioannou_case.lat > 33)\n * (ioannou_case.lat < 34)\n * (ioannou_case.lon > 22)\n * (ioannou_case.lon < 23)\n * (ioannou_case.time > 20630)\n * (ioannou_case.time < 20650)\n)[0][0]\nclose_to_i3 = ioannou_case.relative(i, order=3)\nclose_to_i3.numbering_segment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Anim\n----\nQuick movie to see better!\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "a = Anim(\n close_to_i3,\n figsize=(12, 4),\n cmap=colors.ListedColormap(\n list(close_to_i3.COLORS), name=\"from_list\", N=close_to_i3.segment.max() + 1\n ),\n nb_step=7,\n dpi=70,\n field_color=\"segment\",\n field_txt=\"segment\",\n)\na.ax.set_xlim(19, 30), a.ax.set_ylim(32, 35.25)\na.txt.set_position((21.5, 32.7))\n# We display in video only from the 100th day to the 500th\nkwargs = dict(frames=arange(*a.period)[100:501], interval=100)\nani = VideoAnimation(a.fig, a.func_animation, **kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Classic display\n---------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\n_ = close_to_i3.display_timeline(ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"\")\nn_copy = close_to_i3.copy()\nn_copy.position_filter(2, 4)\nn_copy.plot(ax, color_cycle=n_copy.COLORS)\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Latitude Timeline\n-----------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes(f\"Close segments ({close_to_i3.infos()})\")\nn_copy = close_to_i3.copy()\nn_copy.median_filter(15, \"time\", \"latitude\")\n_ = n_copy.display_timeline(ax, field=\"lat\", method=\"all\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Local radius timeline\n---------------------\nEffective (bold) and Speed (thin) Radius together\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n_copy.median_filter(2, \"time\", \"radius_e\")\nn_copy.median_filter(2, \"time\", \"radius_s\")\nfor b0, b1 in [\n (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007)\n]:\n ref, delta = datetime(1950, 1, 1), 20\n b0_, b1_ = (b0 - ref).days, (b1 - ref).days\n ax = timeline_axes()\n ax.set_xlim(b0_ - delta, b1_ + delta)\n ax.set_ylim(10, 115)\n ax.axvline(b0_, color=\"k\", lw=1.5, ls=\"--\"), ax.axvline(\n b1_, 
color=\"k\", lw=1.5, ls=\"--\"\n )\n n_copy.display_timeline(\n ax, field=\"radius_e\", method=\"all\", lw=4, markersize=8, factor=1e-3\n )\n n_copy.display_timeline(\n ax, field=\"radius_s\", method=\"all\", lw=1, markersize=3, factor=1e-3\n )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Parameters timeline\n-------------------\nEffective Radius\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(s=35, cmap=plt.get_cmap(\"Spectral_r\", 8), zorder=10)\nax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, \"radius_e\", factor=1e-3, vmin=20, vmax=100, **kw)\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective radius (km)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Shape error\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, \"shape_error_e\", vmin=14, vmax=70, **kw)\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective shape error\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Rotation angle\n--------------\nFor each obs, fit an ellipse to the contour, with theta the angle from the x-axis,\na the semi ax in x direction and b the semi ax in y dimension\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "theta_ = list()\na_ = list()\nb_ = list()\nfor obs in close_to_i3:\n x, y = obs[\"contour_lon_s\"], obs[\"contour_lat_s\"]\n x0_, y0_ = x.mean(), y.mean()\n x_, y_ = coordinates_to_local(x, y, x0_, y0_)\n x0, y0, a, b, theta = fit_ellipse(x_, y_)\n theta_.append(theta)\n a_.append(a)\n b_.append(b)\na_ = array(a_)\nb_ = array(b_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Theta\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, theta_, vmin=-pi / 2, vmax=pi / 2, cmap=\"hsv\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "a\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, a_ * 1e-3, vmin=0, vmax=80, cmap=\"Spectral_r\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "b\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, b_ * 1e-3, vmin=0, vmax=80, cmap=\"Spectral_r\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "a/b\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, a_ / b_, vmin=1, vmax=2, cmap=\"Spectral_r\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + 
"codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_relative.ipynb b/notebooks/python_module/16_network/pet_relative.ipynb new file mode 100644 index 00000000..9f3fd3d9 --- /dev/null +++ b/notebooks/python_module/16_network/pet_relative.ipynb @@ -0,0 +1,547 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Network basic manipulation\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import where\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data\nLoad data where observations are put in same network but no segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = NetworkObservations.load_file(data.get_demo_path(\"network_med.nc\")).network(651)\ni = where(\n (n.lat > 33)\n * (n.lat < 34)\n * (n.lon > 22)\n * (n.lon < 23)\n * (n.time > 20630)\n * (n.time < 20650)\n)[0][0]\n# For event use\nn2 = n.relative(i, order=2)\nn = n.relative(i, order=4)\nn.numbering_segment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timeline\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display timeline with events\nA segment generated by a splitting is marked with a star\n\nA segment merging in another is marked with an exagon\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.04, 0.92, 0.92])\n_ = n.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display timeline without event\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.04, 0.92, 0.92])\n_ = n.display_timeline(ax, event=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timeline by mean latitude\nDisplay timeline with the mean latitude of the segments in yaxis\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.04, 0.92, 0.92])\nax.set_ylabel(\"Latitude\")\n_ = n.display_timeline(ax, field=\"latitude\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timeline by mean Effective Radius\nThe factor argument is applied on the chosen field\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 
5))\nax = fig.add_axes([0.04, 0.04, 0.92, 0.92])\nax.set_ylabel(\"Effective Radius (km)\")\n_ = n.display_timeline(ax, field=\"radius_e\", factor=1e-3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Timeline by latitude\nUse `method=\"all\"` to display the consecutive values of the field\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nax.set_ylabel(\"Latitude\")\n_ = n.display_timeline(ax, field=\"lat\", method=\"all\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can filter the data, here with a time window of 15 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nn_copy = n.copy()\nn_copy.median_filter(15, \"time\", \"latitude\")\n_ = n_copy.display_timeline(ax, field=\"lat\", method=\"all\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameters timeline\nScatter is useful to display the parameters' temporal evolution\n\nEffective Radius and Amplitude\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(s=25, cmap=\"Spectral_r\", zorder=10)\nfig = plt.figure(figsize=(15, 12))\nax = fig.add_axes([0.04, 0.54, 0.90, 0.44])\nm = n.scatter_timeline(ax, \"radius_e\", factor=1e-3, vmin=50, vmax=150, **kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.54, 0.01, 0.44]), orientation=\"vertical\"\n)\ncb.set_label(\"Effective radius (km)\")\n\nax = fig.add_axes([0.04, 0.04, 0.90, 0.44])\nm = n.scatter_timeline(ax, \"amplitude\", factor=100, vmin=0, vmax=15, **kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.44]), orientation=\"vertical\"\n)\ncb.set_label(\"Amplitude (cm)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Speed\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nm = n.scatter_timeline(ax, \"speed_average\", factor=100, vmin=0, vmax=40, **kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation=\"vertical\"\n)\ncb.set_label(\"Maximum speed (cm/s)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Speed Radius\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nm = n.scatter_timeline(ax, \"radius_s\", factor=1e-3, vmin=20, vmax=100, **kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), orientation=\"vertical\"\n)\ncb.set_label(\"Speed radius (km)\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Remove dead branch\nRemove all tiny segments with fewer than N obs which didn't join two segments\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n_clean = n.remove_dead_end(nobs=5, ndays=10)\nfig = 
plt.figure(figsize=(15, 12))\nax = fig.add_axes([0.04, 0.54, 0.90, 0.40])\nax.set_title(f\"Original network ({n.infos()})\")\nn.display_timeline(ax)\nax = fig.add_axes([0.04, 0.04, 0.90, 0.40])\nax.set_title(f\"Clean network ({n_clean.infos()})\")\n_ = n_clean.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For further figures we will use the clean network\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = n_clean" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Change splitting-merging events\nChange events where segment A splits into B, then A merges into B, so that A splits into B, then B merges into A\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 12))\nax = fig.add_axes([0.04, 0.54, 0.90, 0.40])\nax.set_title(f\"Clean network ({n.infos()})\")\nn.display_timeline(ax)\n\nclean_modified = n.copy()\n# If it happens in less than 40 days\nclean_modified.correct_close_events(40)\n\nax = fig.add_axes([0.04, 0.04, 0.90, 0.40])\nax.set_title(f\"resplit network ({clean_modified.infos()})\")\n_ = clean_modified.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Keep only observations where water could propagate from an observation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i_observation = 600\nonly_linked = n.find_link(i_observation)\n\nfig = plt.figure(figsize=(15, 12))\nax1 = fig.add_axes([0.04, 0.54, 0.90, 0.40])\nax2 = fig.add_axes([0.04, 0.04, 0.90, 0.40])\n\nkw = dict(marker=\"s\", s=300, color=\"black\", zorder=200, label=\"observation start\")\nfor ax, dataset in zip([ax1, ax2], [n, only_linked]):\n dataset.display_timeline(ax, field=\"segment\", lw=2, markersize=5, colors_mode=\"y\")\n ax.scatter(n.time[i_observation], n.segment[i_observation], **kw)\n ax.legend()\n\nax1.set_title(f\"full example ({n.infos()})\")\nax2.set_title(f\"only linked observations ({only_linked.infos()})\")\n_ = ax2.set_xlim(ax1.get_xlim()), ax2.set_ylim(ax1.get_ylim())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Keep close relatives\nWhen you want to investigate one particular observation and select only the closest segments\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# First choose an observation in the network\ni = 1100\n\nfig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nn.display_timeline(ax)\nobs_args = n.time[i], n.segment[i]\nobs_kw = dict(color=\"black\", markersize=30, marker=\".\")\n_ = ax.plot(*obs_args, **obs_kw)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Colors show the relative order of the segment with regard to the chosen one\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 6))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nm = n.scatter_timeline(\n ax, n.obs_relative_order(i), vmin=-1.5, vmax=6.5, cmap=plt.get_cmap(\"jet\", 8), s=10\n)\nax.plot(*obs_args, **obs_kw)\ncb = plt.colorbar(\n m[\"scatter\"], cax=fig.add_axes([0.95, 0.04, 0.01, 0.92]), 
orientation=\"vertical\"\n)\ncb.set_label(\"Relative order\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You want to keep only the segments at the order 1\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nclose_to_i1 = n.relative(i, order=1)\nax.set_title(f\"Close segments ({close_to_i1.infos()})\")\n_ = close_to_i1.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You want to keep the segments until order 2\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nclose_to_i2 = n.relative(i, order=2)\nax.set_title(f\"Close segments ({close_to_i2.infos()})\")\n_ = close_to_i2.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You want to keep the segments until order 3\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88])\nclose_to_i3 = n.relative(i, order=3)\nax.set_title(f\"Close segments ({close_to_i3.infos()})\")\n_ = close_to_i3.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Keep relatives to an event\nWhen you want to investigate one particular event and select only the closest segments\n\nFirst choose a merging event in the network\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "after, before, stopped = n.merging_event(triplet=True, only_index=True)\ni_event = 7" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "then see some order of relatives\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "max_order = 1\nfig, axs = plt.subplots(\n max_order + 2, 1, sharex=True, figsize=(15, 5 * (max_order + 2))\n)\n# Original network\nax = axs[0]\nax.set_title(\"Full network\", weight=\"bold\")\nn.display_timeline(axs[0], colors_mode=\"y\")\nax.grid(), ax.legend()\n\nfor k in range(0, max_order + 1):\n ax = axs[k + 1]\n ax.set_title(f\"Relatives order={k}\", weight=\"bold\")\n # Extract neighbours of event\n sub_network = n.find_segments_relative(after[i_event], stopped[i_event], order=k)\n sub_network.display_timeline(ax, colors_mode=\"y\")\n ax.legend(), ax.grid()\n _ = ax.set_ylim(axs[0].get_ylim())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display track on map\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Get a simplified network\nn = n2.remove_dead_end(nobs=50, recursive=1)\nn.numbering_segment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Only a map can be tricky to understand, with a timeline it's easier!\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], 
projection=GUI_AXES)\nn.plot(ax, color_cycle=n.COLORS)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nax = fig.add_axes([0.08, 0.7, 0.7, 0.3])\n_ = n.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get merging event\nDisplay the position of the eddies after a merging\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\nn.plot(ax, color_cycle=n.COLORS)\nm1, m0, m0_stop = n.merging_event(triplet=True)\nm1.display(ax, color=\"violet\", lw=2, label=\"Eddies after merging\")\nm0.display(ax, color=\"blueviolet\", lw=2, label=\"Eddies before merging\")\nm0_stop.display(ax, color=\"black\", lw=2, label=\"Eddies stopped by merging\")\nax.plot(m1.lon, m1.lat, marker=\".\", color=\"purple\", ls=\"\")\nax.plot(m0.lon, m0.lat, marker=\".\", color=\"blueviolet\", ls=\"\")\nax.plot(m0_stop.lon, m0_stop.lat, marker=\".\", color=\"black\", ls=\"\")\nax.legend()\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nm1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get splitting event\nDisplay the position of the eddies before a splitting\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\nn.plot(ax, color_cycle=n.COLORS)\ns0, s1, s1_start = n.spliting_event(triplet=True)\ns0.display(ax, color=\"violet\", lw=2, label=\"Eddies before splitting\")\ns1.display(ax, color=\"blueviolet\", lw=2, label=\"Eddies after splitting\")\ns1_start.display(ax, color=\"black\", lw=2, label=\"Eddies starting by splitting\")\nax.plot(s0.lon, s0.lat, marker=\".\", color=\"purple\", ls=\"\")\nax.plot(s1.lon, s1.lat, marker=\".\", color=\"blueviolet\", ls=\"\")\nax.plot(s1_start.lon, s1_start.lat, marker=\".\", color=\"black\", ls=\"\")\nax.legend()\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\ns1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get birth event\nDisplay the starting position of non-split eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\nbirth = n.birth_event()\nbirth.display(ax)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nbirth" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get death event\nDisplay the last position of non-merged eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\ndeath = n.death_event()\ndeath.display(ax)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\ndeath" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + 
} + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_replay_segmentation.ipynb b/notebooks/python_module/16_network/pet_replay_segmentation.ipynb new file mode 100644 index 00000000..7c632138 --- /dev/null +++ b/notebooks/python_module/16_network/pet_replay_segmentation.ipynb @@ -0,0 +1,180 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Replay segmentation\nCase from figure 10 from https://doi.org/10.1002/2017JC013158\n\nAgain with the Ierapetra Eddy\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from datetime import datetime, timedelta\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.ticker import FuncFormatter\nfrom numpy import where\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\n\n\n@FuncFormatter\ndef formatter(x, pos):\n return (timedelta(x) + datetime(1950, 1, 1)).strftime(\"%d/%m/%Y\")\n\n\ndef start_axes(title=\"\"):\n fig = plt.figure(figsize=(13, 6))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)\n ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef timeline_axes(title=\"\"):\n fig = plt.figure(figsize=(15, 5))\n ax = fig.add_axes([0.04, 0.06, 0.89, 0.88])\n ax.set_title(title, weight=\"bold\")\n ax.xaxis.set_major_formatter(formatter), ax.grid()\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid(True)\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Class for new_segmentation\nThe oldest win\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class MyTrackEddiesObservations(TrackEddiesObservations):\n __slots__ = tuple()\n\n @classmethod\n def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs):\n \"\"\"\n Method to overwrite behaviour in merging.\n\n We will give the point to the older one instead of the maximum overlap ratio\n \"\"\"\n while i_next != -1:\n # Flag\n used[i_next] = True\n # Assign id\n ids[\"track\"][i_next] = track_id\n # Search next\n i_next_ = cls.get_next_obs(i_next, ids, *args, **kwargs)\n if i_next_ == -1:\n break\n ids[\"next_obs\"][i_next] = i_next_\n # Target was previously used\n if used[i_next_]:\n i_next_ = -1\n else:\n ids[\"previous_obs\"][i_next_] = i_next\n i_next = i_next_\n\n\ndef get_obs(dataset):\n \"Function to isolate a specific obs\"\n return where(\n (dataset.lat > 33)\n * (dataset.lat < 34)\n * (dataset.lon > 22)\n * (dataset.lon < 23)\n * (dataset.time > 20630)\n * (dataset.time < 20650)\n )[0][0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get original network, we will isolate only relative at order *2*\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n = 
NetworkObservations.load_file(get_demo_path(\"network_med.nc\")).network(651)\nn_ = n.relative(get_obs(n), order=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display the default segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(n_.infos())\nn_.plot(ax, color_cycle=n.COLORS)\nupdate_axes(ax)\nfig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nax.xaxis.set_major_formatter(formatter), ax.grid()\n_ = n_.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run a new segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "e = n.astype(MyTrackEddiesObservations)\ne.obs.sort(order=(\"track\", \"time\"), kind=\"stable\")\nsplit_matrix = e.split_network(intern=False, window=7)\nn_ = NetworkObservations.from_split_network(e, split_matrix)\nn_ = n_.relative(get_obs(n_), order=2)\nn_.numbering_segment()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## New segmentation\n\"The oldest wins\" method produce a very long segment\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(n_.infos())\nn_.plot(ax, color_cycle=n_.COLORS)\nupdate_axes(ax)\nfig = plt.figure(figsize=(15, 5))\nax = fig.add_axes([0.04, 0.05, 0.92, 0.92])\nax.xaxis.set_major_formatter(formatter), ax.grid()\n_ = n_.display_timeline(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameters timeline\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(s=35, cmap=plt.get_cmap(\"Spectral_r\", 8), zorder=10)\nax = timeline_axes()\nn_.median_filter(15, \"time\", \"latitude\")\nm = n_.scatter_timeline(ax, \"shape_error_e\", vmin=14, vmax=70, **kw, yfield=\"lat\")\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective shape error\")\n\nax = timeline_axes()\nn_.median_filter(15, \"time\", \"latitude\")\nm = n_.scatter_timeline(\n ax, \"shape_error_e\", vmin=14, vmax=70, **kw, yfield=\"lat\", method=\"all\"\n)\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective shape error\")\nax.set_ylabel(\"Latitude\")\n\nax = timeline_axes()\nn_.median_filter(15, \"time\", \"latitude\")\nkw[\"s\"] = (n_.radius_e * 1e-3) ** 2 / 30 ** 2 * 20\nm = n_.scatter_timeline(\n ax,\n \"shape_error_e\",\n vmin=14,\n vmax=70,\n **kw,\n yfield=\"lon\",\n method=\"all\",\n)\nax.set_ylabel(\"Longitude\")\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective shape error\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cost association plot\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "n_copy = n_.copy()\nn_copy.median_filter(2, \"time\", \"next_cost\")\nfor b0, b1 in [\n (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007, 2008)\n]:\n\n ref, delta = datetime(1950, 1, 1), 20\n b0_, b1_ = (b0 - ref).days, (b1 - ref).days\n ax = timeline_axes()\n ax.set_xlim(b0_ - delta, b1_ + delta)\n ax.set_ylim(0, 1)\n ax.axvline(b0_, color=\"k\", lw=1.5, ls=\"--\"), ax.axvline(\n b1_, color=\"k\", lw=1.5, ls=\"--\"\n )\n 
n_copy.display_timeline(ax, field=\"next_cost\", method=\"all\", lw=4, markersize=8)\n\n n_.display_timeline(ax, field=\"next_cost\", method=\"all\", lw=0.5, markersize=0)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_segmentation_anim.ipynb b/notebooks/python_module/16_network/pet_segmentation_anim.ipynb new file mode 100644 index 00000000..0a546832 --- /dev/null +++ b/notebooks/python_module/16_network/pet_segmentation_anim.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nNetwork segmentation process\n============================\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# sphinx_gallery_thumbnail_number = 2\nimport re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.colors import ListedColormap\nfrom numpy import ones, where\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\ndef get_obs(dataset):\n \"Function to isolate a specific obs\"\n return where(\n (dataset.lat > 33)\n * (dataset.lat < 34)\n * (dataset.lon > 22)\n * (dataset.lon < 23)\n * (dataset.time > 20630)\n * (dataset.time < 20650)\n )[0][0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Hack to pick up each step of segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "TRACKS = list()\nINDICES = list()\n\n\nclass MyTrack(TrackEddiesObservations):\n @staticmethod\n def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs):\n TRACKS.append(ids[\"track\"].copy())\n INDICES.append(i_current)\n return TrackEddiesObservations.get_next_obs(\n i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs\n )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load data\n---------\nLoad data where observations are put in same network but no 
segmentation\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Get a known network for the demonstration\nn = NetworkObservations.load_file(get_demo_path(\"network_med.nc\")).network(651)\n# We keep only some segment\nn = n.relative(get_obs(n), order=2)\nprint(len(n))\n# We convert and order object like segmentation was never happen on observations\ne = n.astype(MyTrack)\ne.obs.sort(order=(\"track\", \"time\"), kind=\"stable\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Do segmentation\n---------------\nSegmentation based on maximum overlap, temporal window for candidates = 5 days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "matrix = e.split_network(intern=False, window=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Anim\n----\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def update(i_frame):\n tr = TRACKS[i_frame]\n mappable_tracks.set_array(tr)\n s = 40 * ones(tr.shape)\n s[tr == 0] = 4\n mappable_tracks.set_sizes(s)\n\n indices_frames = INDICES[i_frame]\n mappable_CONTOUR.set_data(\n e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames],\n )\n mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)])\n return (mappable_tracks,)\n\n\nfig = plt.figure(figsize=(16, 9), dpi=60)\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES)\nax.set_title(f\"{len(e)} observations to segment\")\nax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid()\nvmax = TRACKS[-1].max()\ncmap = ListedColormap([\"gray\", *e.COLORS[:-1]], name=\"from_list\", N=vmax)\nmappable_tracks = ax.scatter(\n e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20\n)\nmappable_CONTOUR = ax.plot(\n e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0]\n)[0]\nani = VideoAnimation(fig, update, frames=range(1, len(TRACKS), 4), interval=125)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Final Result\n------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(16, 9))\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES)\nax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid()\n_ = ax.scatter(e.lon, e.lat, c=TRACKS[-1], cmap=cmap, vmin=0, vmax=vmax, s=20)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/16_network/pet_something_cool.ipynb b/notebooks/python_module/16_network/pet_something_cool.ipynb new file mode 100644 index 00000000..158852f9 --- /dev/null +++ b/notebooks/python_module/16_network/pet_something_cool.ipynb @@ -0,0 +1,65 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { 
+ "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# essai\n\non tente des trucs\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import cartopy.crs as ccrs\nimport cartopy.feature as cfeature\nimport numpy as np\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\n\ndef rect_from_extent(extent):\n rect_lon = [extent[0], extent[1], extent[1], extent[0], extent[0]]\n rect_lat = [extent[2], extent[2], extent[3], extent[3], extent[2]]\n return rect_lon, rect_lat\n\n\ndef indice_from_extent(lon, lat, extent):\n mask = (lon > extent[0]) * (lon < extent[1]) * (lat > extent[2]) * (lat < extent[3])\n return np.where(mask)[0]\n\n\nfichier = \"/data/adelepoulle/work/Eddies/20201217_network_build/big_network.nc\"\nnetwork = NetworkObservations.load_file(fichier)\nsub_network = network.network(1078566)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# extent_begin = [0, 2, -50, -15]\n# extent_end = [-42, -35, -40, -10]\n\nextent_begin = [2, 22, -50, -30]\ni_obs_begin = indice_from_extent(\n sub_network.longitude, sub_network.latitude, extent_begin\n)\nnetwork_begin = sub_network.find_link(i_obs_begin)\ntime_mini = network_begin.time.min()\ntime_maxi = network_begin.time.max()\n\nextent_end = [-52, -45, -37, -33]\ni_obs_end = indice_from_extent(\n (network_begin.longitude + 180) % 360 - 180, network_begin.latitude, extent_end\n)\nnetwork_end = network_begin.find_link(i_obs_end, forward=False, backward=True)\n\n\ndatasets = [network_begin, network_end]\nextents = [extent_begin, extent_end]\nfig, (ax1, ax2) = plt.subplots(\n 2, 1, figsize=(10, 9), dpi=140, subplot_kw={\"projection\": ccrs.PlateCarree()}\n)\n\nfor ax, dataset, extent in zip([ax1, ax2], datasets, extents):\n sca = dataset.scatter(\n ax,\n name=\"time\",\n cmap=\"Spectral_r\",\n label=\"observation dans le temps\",\n vmin=time_mini,\n vmax=time_maxi,\n )\n\n x, y = rect_from_extent(extent)\n ax.fill(x, y, color=\"grey\", alpha=0.3, label=\"observations choisies\")\n # ax.plot(x, y, marker='o')\n\n ax.legend()\n\n gridlines = ax.gridlines(\n alpha=0.2, color=\"black\", linestyle=\"dotted\", draw_labels=True, dms=True\n )\n\n gridlines.left_labels = False\n gridlines.top_labels = False\n\n ax.coastlines()\n ax.add_feature(cfeature.LAND)\n ax.add_feature(cfeature.LAKES, zorder=10)\n ax.add_feature(cfeature.BORDERS, lw=0.25)\n ax.add_feature(cfeature.OCEAN, alpha=0.2)\n\n\nax1.set_title(\n \"Recherche du d\u00e9placement de l'eau dans les eddies \u00e0 travers les observations choisies\"\n)\nax2.set_title(\"Recherche de la provenance de l'eau \u00e0 travers les observations choisies\")\nax2.set_extent(ax1.get_extent(), ccrs.PlateCarree())\n\nfig.subplots_adjust(right=0.87, left=0.02)\ncbar_ax = fig.add_axes([0.90, 0.1, 0.02, 0.8])\ncbar = fig.colorbar(sca[\"scatter\"], cax=cbar_ax, orientation=\"vertical\")\n_ = cbar.set_label(\"time (jj)\", rotation=270, labelpad=-65)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ 
No newline at end of file diff --git a/requirements.txt b/requirements.txt index eae54426..556cabbf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,11 @@ -matplotlib -netCDF4 -numba -numpy +matplotlib < 3.8 # need an update of contour management opencv-python pint polygon3 pyyaml requests scipy -zarr -# for binder -pyeddytrackersample \ No newline at end of file +zarr < 3.0 +netCDF4 +numpy +numba \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 00000000..a005c37d --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,7 @@ +-r requirements.txt +isort +black +blackdoc +flake8 +pytest +pytest-cov \ No newline at end of file diff --git a/requirements_doc.txt b/requirements_doc.txt deleted file mode 100644 index 6a4a2937..00000000 --- a/requirements_doc.txt +++ /dev/null @@ -1,15 +0,0 @@ -matplotlib -netCDF4 -numba -numpy -opencv-python -pint -polygon3 -pyyaml -scipy -zarr -# doc -sphinx-gallery -pyeddytrackersample -sphinx_rtd_theme -sphinx>=3.1 diff --git a/setup.cfg b/setup.cfg index dfed5c3b..7e773ae8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,29 @@ + +[yapf] +column_limit = 100 + [flake8] +max-line-length = 140 ignore = - E203, # whitespace before ':' - W503, # line break before binary operator + E203, + W503, +exclude= + build + doc + versioneer.py + +[isort] +combine_as_imports=True +force_grid_wrap=0 +force_sort_within_sections=True +force_to_top=typing +include_trailing_comma=True +line_length=140 +multi_line_output=3 +skip= + build + doc/conf.py + [versioneer] VCS = git @@ -13,4 +35,5 @@ parentdir_prefix = [tool:pytest] filterwarnings= - ignore:tostring.*is deprecated \ No newline at end of file + ignore:tostring.*is deprecated + diff --git a/setup.py b/setup.py index d3f5d1bd..7b836763 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- +from setuptools import find_packages, setup + import versioneer -from setuptools import setup, find_packages with open("README.md", "r") as fh: long_description = fh.read() @@ -9,7 +10,7 @@ setup( name="pyEddyTracker", - python_requires=">=3.7", + python_requires=">=3.10", version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), description="Py-Eddy-Tracker libraries", @@ -28,7 +29,6 @@ scripts=[ "src/scripts/EddySubSetter", "src/scripts/EddyTranslate", - "src/scripts/EddyTracking", "src/scripts/EddyFinalTracking", "src/scripts/EddyMergeCorrespondances", ], @@ -43,9 +43,13 @@ "EddyFrequency = py_eddy_tracker.appli.eddies:get_frequency_grid", "EddyInfos = py_eddy_tracker.appli.eddies:display_infos", "EddyCircle = py_eddy_tracker.appli.eddies:eddies_add_circle", + "EddyTracking = py_eddy_tracker.appli.eddies:eddies_tracking", + "EddyQuickCompare = py_eddy_tracker.appli.eddies:quick_compare", # network "EddyNetworkGroup = py_eddy_tracker.appli.network:build_network", "EddyNetworkBuildPath = py_eddy_tracker.appli.network:divide_network", + "EddyNetworkSubSetter = py_eddy_tracker.appli.network:subset_network", + "EddyNetworkQuickCompare = py_eddy_tracker.appli.network:quick_compare", # anim/gui "EddyAnim = py_eddy_tracker.appli.gui:anim", "GUIEddy = py_eddy_tracker.appli.gui:guieddy", diff --git a/share/fig.py b/share/fig.py index 8640abcb..80c7f12b 100644 --- a/share/fig.py +++ b/share/fig.py @@ -1,8 +1,10 @@ -from matplotlib import pyplot as plt -from py_eddy_tracker.dataset.grid import RegularGridDataset from datetime import datetime import logging +from matplotlib import pyplot as plt + +from py_eddy_tracker.dataset.grid import 
RegularGridDataset + grid_name, lon_name, lat_name = ( "nrt_global_allsat_phy_l4_20190223_20190226.nc", "longitude", diff --git a/share/nrt_global_allsat_phy_l4_20190223_20190226.nc b/share/nrt_global_allsat_phy_l4_20190223_20190226.nc new file mode 120000 index 00000000..077ce7e6 --- /dev/null +++ b/share/nrt_global_allsat_phy_l4_20190223_20190226.nc @@ -0,0 +1 @@ +../src/py_eddy_tracker/data/nrt_global_allsat_phy_l4_20190223_20190226.nc \ No newline at end of file diff --git a/share/tracking.yaml b/share/tracking.yaml index 0f8766b8..d6264104 100644 --- a/share/tracking.yaml +++ b/share/tracking.yaml @@ -1,13 +1,14 @@ PATHS: - # Files produces with EddyIdentification + # Files produced with EddyIdentification FILES_PATTERN: /home/emason/toto/Anticyclonic_*.nc - # Path for saving of outputs + # Path to save outputs SAVE_DIR: '/home/emason/toto/' -# Minimum number of observations to store eddy -TRACK_DURATION_MIN: 4 +# Number of consecutive timesteps with missing detection allowed VIRTUAL_LENGTH_MAX: 0 +# Minimal number of timesteps to considered as a long trajectory +TRACK_DURATION_MIN: 4 -#CLASS: -# MODULE: py_eddy_tracker.featured_tracking.old_tracker_reference -# CLASS: CheltonTracker +CLASS: + MODULE: py_eddy_tracker.featured_tracking.area_tracker + CLASS: AreaTracker diff --git a/src/py_eddy_tracker/__init__.py b/src/py_eddy_tracker/__init__.py index 9be2d280..7115bf67 100644 --- a/src/py_eddy_tracker/__init__.py +++ b/src/py_eddy_tracker/__init__.py @@ -20,8 +20,9 @@ """ -import logging from argparse import ArgumentParser +from datetime import datetime +import logging import zarr @@ -31,15 +32,13 @@ del get_versions -def start_logger(): - FORMAT_LOG = ( - "%(levelname)-8s %(asctime)s %(module)s.%(funcName)s :\n\t\t\t\t\t%(message)s" - ) +def start_logger(color=True): + FORMAT_LOG = "%(levelname)-8s %(asctime)s %(module)s.%(funcName)s :\n\t%(message)s" logger = logging.getLogger("pet") if len(logger.handlers) == 0: # set up logging to CONSOLE console = logging.StreamHandler() - console.setFormatter(ColoredFormatter(FORMAT_LOG)) + console.setFormatter(ColoredFormatter(FORMAT_LOG, color=color)) # add the handler to the root logger logger.addHandler(console) return logger @@ -54,13 +53,14 @@ class ColoredFormatter(logging.Formatter): DEBUG="\033[34m\t", ) - def __init__(self, message): + def __init__(self, message, color=True): super().__init__(message) + self.with_color = color def format(self, record): color = self.COLOR_LEVEL.get(record.levelname, "") color_reset = "\033[0m" - model = color + "%s" + color_reset + model = (color + "%s" + color_reset) if self.with_color else "%s" record.msg = model % record.msg record.funcName = model % record.funcName record.module = model % record.module @@ -85,6 +85,20 @@ def add_base_argument(self): help="Levels : DEBUG, INFO, WARNING," " ERROR, CRITICAL", ) + def memory_arg(self): + self.add_argument( + "--memory", + action="store_true", + help="Load file in memory before to read with netCDF library", + ) + + def contour_intern_arg(self): + self.add_argument( + "--intern", + action="store_true", + help="Use intern contour instead of outter contour", + ) + def parse_args(self, *args, **kwargs): logger = start_logger() # Parsing @@ -94,12 +108,26 @@ def parse_args(self, *args, **kwargs): return opts +TIME_MODELS = ["%Y%m%d", "%Y%m%d%H%M%S", "%Y%m%dT%H%M%S"] + + +def identify_time(str_date): + for model in TIME_MODELS: + try: + return datetime.strptime(str_date, model) + except ValueError: + pass + raise Exception("No time model found") + + VAR_DESCR 
= dict( time=dict( attr_name="time", nc_name="time", old_nc_name=["j1"], - nc_type="int32", + nc_type="float64", + output_type="uint32", + scale_factor=1 / 86400.0, nc_dims=("obs",), nc_attr=dict( standard_name="time", @@ -118,7 +146,7 @@ def parse_args(self, *args, **kwargs): nc_dims=("obs",), nc_attr=dict( long_name="Rotating sense of the eddy", - comment="Cyclonic -1; Anti-cyclonic +1", + comment="Cyclonic -1; Anticyclonic +1", ), ), segment_size=dict( @@ -183,7 +211,7 @@ def parse_args(self, *args, **kwargs): nc_attr=dict( units="degrees_east", axis="X", - comment="Longitude center of the fitted circle", + comment="Longitude center of the best fit circle", long_name="Eddy Center Longitude", standard_name="longitude", ), @@ -200,7 +228,7 @@ def parse_args(self, *args, **kwargs): axis="Y", long_name="Eddy Center Latitude", standard_name="latitude", - comment="Latitude center of the fitted circle", + comment="Latitude center of the best fit circle", ), ), lon_max=dict( @@ -215,6 +243,7 @@ def parse_args(self, *args, **kwargs): axis="X", long_name="Longitude of the SSH maximum", standard_name="longitude", + comment="Longitude of the inner contour", ), ), lat_max=dict( @@ -229,6 +258,7 @@ def parse_args(self, *args, **kwargs): axis="Y", long_name="Latitude of the SSH maximum", standard_name="latitude", + comment="Latitude of the inner contour", ), ), amplitude=dict( @@ -237,7 +267,7 @@ def parse_args(self, *args, **kwargs): old_nc_name=["A"], nc_type="float32", output_type="uint16", - scale_factor=0.001, + scale_factor=0.0001, nc_dims=("obs",), nc_attr=dict( long_name="Amplitude", @@ -254,7 +284,7 @@ def parse_args(self, *args, **kwargs): nc_attr=dict( long_name="Speed area", units="m^2", - comment="Area enclosed by speed contour in m^2", + comment="Area enclosed by the speed contour in m^2", ), ), effective_area=dict( @@ -265,7 +295,7 @@ def parse_args(self, *args, **kwargs): nc_attr=dict( long_name="Effective area", units="m^2", - comment="Area enclosed by effective contour in m^2", + comment="Area enclosed by the effective contour in m^2", ), ), speed_average=dict( @@ -293,7 +323,7 @@ def parse_args(self, *args, **kwargs): nc_attr=dict( long_name="Radial Speed Profile", units="m/s", - comment="Speed average values from effective contour inwards to smallest contour, evenly spaced points", + comment="Speed averaged values from the effective contour inwards to the smallest contour, evenly spaced points", ), ), i=dict( @@ -332,7 +362,7 @@ def parse_args(self, *args, **kwargs): nc_attr=dict( long_name="Effective Radius", units="m", - comment="Radius of a circle whose area is equal to that enclosed by the effective contour", + comment="Radius of the best fit circle corresponding to the effective contour", ), ), radius_s=dict( @@ -346,8 +376,7 @@ def parse_args(self, *args, **kwargs): nc_attr=dict( long_name="Speed Radius", units="m", - comment="Radius of a circle whose area is equal to that " - "enclosed by the contour of maximum circum-average speed", + comment="Radius of the best fit circle corresponding to the contour of maximum circum-average speed", ), ), track=dict( @@ -360,15 +389,55 @@ def parse_args(self, *args, **kwargs): long_name="Trajectory number", comment="Trajectory identification number" ), ), - sub_track=dict( + segment=dict( attr_name=None, - nc_name="sub_track", + nc_name="segment", nc_type="uint32", nc_dims=("obs",), nc_attr=dict( long_name="Segment Number", comment="Segment number inside a group" ), ), + previous_obs=dict( + attr_name=None, + nc_name="previous_obs", + 
nc_type="int32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Previous observation index", + comment="Index of previous observation in a splitting case", + ), + ), + next_obs=dict( + attr_name=None, + nc_name="next_obs", + nc_type="int32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Next observation index", + comment="Index of next observation in a merging case", + ), + ), + previous_cost=dict( + attr_name=None, + nc_name="previous_cost", + nc_type="float32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Previous cost for previous observation", + comment="", + ), + ), + next_cost=dict( + attr_name=None, + nc_name="next_cost", + nc_type="float32", + nc_dims=("obs",), + nc_attr=dict( + long_name="Next cost for next observation", + comment="", + ), + ), n=dict( attr_name=None, nc_name="observation_number", @@ -419,9 +488,9 @@ def parse_args(self, *args, **kwargs): nc_type="u2", nc_dims=("obs",), nc_attr=dict( - longname="number of point for effective contour", + long_name="number of points for effective contour", units="ordinal", - description="Number of point for effective contour, if greater than NbSample, there is a resampling", + description="Number of points for effective contour before resampling", ), ), contour_lon_s=dict( @@ -463,9 +532,9 @@ def parse_args(self, *args, **kwargs): nc_type="u2", nc_dims=("obs",), nc_attr=dict( - longname="number of point for speed contour", + long_name="number of points for speed contour", units="ordinal", - description="Number of point for speed contour, if greater than NbSample, there is a resampling", + description="Number of points for speed contour before resampling", ), ), shape_error_e=dict( @@ -478,8 +547,8 @@ def parse_args(self, *args, **kwargs): nc_dims=("obs",), nc_attr=dict( units="%", - comment="Error criterion between the effective contour and its fit with the circle of same effective radius", - long_name="Effective Contour Error", + comment="Error criterion between the effective contour and its best fit circle", + long_name="Effective Contour Shape Error", ), ), score=dict( @@ -512,8 +581,8 @@ def parse_args(self, *args, **kwargs): nc_dims=("obs",), nc_attr=dict( units="%", - comment="Error criterion between the speed contour and its fit with the circle of same speed radius", - long_name="Speed Contour Error", + comment="Error criterion between the speed contour and its best fit circle", + long_name="Speed Contour Shape Error", ), ), height_max_speed_contour=dict( @@ -628,3 +697,6 @@ def parse_args(self, *args, **kwargs): VAR_DESCR_inv[VAR_DESCR[key]["nc_name"]] = key for key_old in VAR_DESCR[key].get("old_nc_name", list()): VAR_DESCR_inv[key_old] = key + +from . import _version +__version__ = _version.get_versions()['version'] diff --git a/src/py_eddy_tracker/_version.py b/src/py_eddy_tracker/_version.py index 44367e3a..589e706f 100644 --- a/src/py_eddy_tracker/_version.py +++ b/src/py_eddy_tracker/_version.py @@ -1,11 +1,13 @@ + # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. 
+# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -14,9 +16,11 @@ import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -32,8 +36,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool + -def get_config(): +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -51,41 +62,50 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} - +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f - return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -96,18 +116,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, 
process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -116,64 +138,64 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. 
However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -186,11 +208,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -199,7 +221,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -207,30 +229,33 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -241,7 +266,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -249,33 +282,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -284,16 +341,17 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] + git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) return pieces # tag @@ -302,12 +360,10 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] + pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -318,26 +374,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -355,29 +412,78 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). 
+ """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -404,12 +510,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -426,7 +561,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -446,7 +581,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
@@ -466,26 +601,28 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -495,16 +632,12 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -515,7 +648,8 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) except NotThisMethod: pass @@ -524,16 +658,13 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
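The inversion below simply strips one directory level per path component of versionfile_source; a small sketch with assumed values (the paths are hypothetical):

    import os

    versionfile_source = "src/pkg/_version.py"   # path relative to the project root
    root = "/repo/src/pkg/_version.py"           # what os.path.realpath(__file__) would give

    # One dirname() per component of versionfile_source walks back up to the root.
    for _ in versionfile_source.split("/"):
        root = os.path.dirname(root)
    # root == "/repo"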
- for i in cfg.versionfile_source.split("/"): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -547,10 +678,6 @@ def get_versions(): except NotThisMethod: pass - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} diff --git a/src/py_eddy_tracker/appli/eddies.py b/src/py_eddy_tracker/appli/eddies.py index f59eab90..c1c7a90d 100644 --- a/src/py_eddy_tracker/appli/eddies.py +++ b/src/py_eddy_tracker/appli/eddies.py @@ -3,12 +3,23 @@ Applications on detection and tracking files """ import argparse +from datetime import datetime +from glob import glob +import logging +from os import mkdir +from os.path import basename, dirname, exists, join as join_path +from re import compile as re_compile from netCDF4 import Dataset +from numpy import bincount, bytes_, empty, in1d, unique +from yaml import safe_load -from .. import EddyParser -from ..observations.observation import EddiesObservations +from .. import EddyParser, identify_time +from ..observations.observation import EddiesObservations, reverse_index from ..observations.tracking import TrackEddiesObservations +from ..tracking import Correspondances + +logger = logging.getLogger("pet") def eddies_add_circle(): @@ -41,24 +52,22 @@ def merge_eddies(): parser.add_argument( "--include_var", nargs="+", type=str, help="use only listed variable" ) + parser.memory_arg() args = parser.parse_args() if args.include_var is None: with Dataset(args.filename[0]) as h: args.include_var = h.variables.keys() - obs = TrackEddiesObservations.load_file( - args.filename[0], raw_data=True, include_vars=args.include_var - ) - if args.add_rotation_variable: - obs = obs.add_rotation_type() - for filename in args.filename[1:]: - other = TrackEddiesObservations.load_file( + obs = list() + for filename in args.filename: + e = TrackEddiesObservations.load_file( filename, raw_data=True, include_vars=args.include_var ) if args.add_rotation_variable: - other = other.add_rotation_type() - obs = obs.merge(other) + e = e.add_rotation_type() + obs.append(e) + obs = TrackEddiesObservations.concatenate(obs) obs.write_file(filename=args.out) @@ -66,11 +75,7 @@ def get_frequency_grid(): parser = EddyParser("Compute eddy frequency") parser.add_argument("observations", help="Input observations to compute frequency") parser.add_argument("out", help="Grid output file") - parser.add_argument( - "--intern", - help="Use speed contour instead of effective contour", - action="store_true", - ) + parser.contour_intern_arg() parser.add_argument( "--xrange", nargs="+", type=float, help="Horizontal range : START,STOP,STEP" ) @@ -141,3 +146,437 @@ def display_infos(): ) e = e.extract_with_area(area) print(e) + + +def eddies_tracking(): + parser = EddyParser("Tool to use identification step to compute tracking") + parser.add_argument("yaml_file", help="Yaml file to configure py-eddy-tracker") + parser.add_argument("--correspondance_in", help="Filename of saved correspondance") + 
parser.add_argument("--correspondance_out", help="Filename to save correspondance") + parser.add_argument( + "--save_correspondance_and_stop", + action="store_true", + help="Stop tracking after correspondance computation," + " merging can be done with EddyFinalTracking", + ) + parser.add_argument( + "--zarr", action="store_true", help="Output will be wrote in zarr" + ) + parser.add_argument( + "--unraw", + action="store_true", + help="Load unraw data, use only for netcdf." + "If unraw is active, netcdf is loaded without apply scalefactor and add_offset.", + ) + parser.add_argument( + "--blank_period", + type=int, + default=0, + help="Nb of detection which will not use at the end of the period", + ) + parser.memory_arg() + args = parser.parse_args() + + # Read yaml configuration file + with open(args.yaml_file, "r") as stream: + config = safe_load(stream) + + if "CLASS" in config: + classname = config["CLASS"]["CLASS"] + obs_class = dict( + class_method=getattr( + __import__(config["CLASS"]["MODULE"], globals(), locals(), classname), + classname, + ), + class_kw=config["CLASS"].get("OPTIONS", dict()), + ) + else: + obs_class = dict() + + c_in, c_out = args.correspondance_in, args.correspondance_out + if c_in is None: + c_in = config["PATHS"].get("CORRESPONDANCES_IN", None) + y_c_out = config["PATHS"].get( + "CORRESPONDANCES_OUT", "{path}/{sign_type}_correspondances.nc" + ) + if c_out is None: + c_out = y_c_out + + # Create ouput folder if necessary + save_dir = config["PATHS"].get("SAVE_DIR", None) + if save_dir is not None and not exists(save_dir): + mkdir(save_dir) + + track( + pattern=config["PATHS"]["FILES_PATTERN"], + output_dir=save_dir, + c_out=c_out, + **obs_class, + virtual=int(config.get("VIRTUAL_LENGTH_MAX", 0)), + previous_correspondance=c_in, + memory=args.memory, + correspondances_only=args.save_correspondance_and_stop, + raw=not args.unraw, + zarr=args.zarr, + nb_obs_min=int(config.get("TRACK_DURATION_MIN", 10)), + blank_period=args.blank_period, + ) + + +def browse_dataset_in( + data_dir, + files_model, + date_regexp, + date_model=None, + start_date=None, + end_date=None, + sub_sampling_step=1, + files=None, +): + pattern_regexp = re_compile(".*/" + date_regexp) + if files is not None: + filenames = bytes_(files) + else: + full_path = join_path(data_dir, files_model) + logger.info("Search files : %s", full_path) + filenames = bytes_(glob(full_path)) + + dataset_list = empty( + len(filenames), + dtype=[("filename", "S500"), ("date", "datetime64[s]")], + ) + dataset_list["filename"] = filenames + + logger.info("%s grids available", dataset_list.shape[0]) + mode_attrs = False + if "(" not in date_regexp: + logger.debug("Attrs date : %s", date_regexp) + mode_attrs = date_regexp.strip().split(":") + else: + logger.debug("Pattern date : %s", date_regexp) + + for item in dataset_list: + str_date = None + if mode_attrs: + with Dataset(item["filename"].decode("utf-8")) as h: + if len(mode_attrs) == 1: + str_date = getattr(h, mode_attrs[0]) + else: + str_date = getattr(h.variables[mode_attrs[0]], mode_attrs[1]) + else: + result = pattern_regexp.match(str(item["filename"])) + if result: + str_date = result.groups()[0] + + if str_date is not None: + if date_model is None: + item["date"] = identify_time(str_date) + else: + item["date"] = datetime.strptime(str_date, date_model) + + dataset_list.sort(order=["date", "filename"]) + steps = unique(dataset_list["date"][1:] - dataset_list["date"][:-1]) + if len(steps) > 1: + raise Exception("Several timesteps in grid dataset %s" % steps) + + if 
sub_sampling_step != 1: + logger.info("Grid subsampling %d", sub_sampling_step) + dataset_list = dataset_list[::sub_sampling_step] + + if start_date is not None or end_date is not None: + logger.info( + "Available grid from %s to %s", + dataset_list[0]["date"], + dataset_list[-1]["date"], + ) + logger.info("Filtering grid by time %s, %s", start_date, end_date) + mask = (dataset_list["date"] >= start_date) * (dataset_list["date"] <= end_date) + + dataset_list = dataset_list[mask] + return dataset_list + + +def track( + pattern, + output_dir, + c_out, + nb_obs_min=10, + raw=True, + zarr=False, + blank_period=0, + correspondances_only=False, + **kw_c, +): + kw = dict(date_regexp=".*_([0-9]*?).[nz].*") + if isinstance(pattern, list): + kw.update(dict(data_dir=None, files_model=None, files=pattern)) + else: + kw.update(dict(data_dir=dirname(pattern), files_model=basename(pattern))) + datasets = browse_dataset_in(**kw) + if blank_period > 0: + datasets = datasets[:-blank_period] + logger.info("Last %d files will be pop", blank_period) + + if nb_obs_min > len(datasets): + raise Exception( + "Input file number (%s) is shorter than TRACK_DURATION_MIN (%s)." + % (len(datasets), nb_obs_min) + ) + + c = Correspondances(datasets=datasets["filename"], **kw_c) + c.track() + logger.info("Track finish") + kw_save = dict( + date_start=datasets["date"][0], + date_stop=datasets["date"][-1], + date_prod=datetime.now(), + path=output_dir, + sign_type=c.current_obs.sign_legend, + ) + + c.save(c_out, kw_save) + if correspondances_only: + return + + logger.info("Start merging") + c.prepare_merging() + logger.info("Longer track saved have %d obs", c.nb_obs_by_tracks.max()) + logger.info( + "The mean length is %d observations for all tracks", c.nb_obs_by_tracks.mean() + ) + + kw_write = dict(path=output_dir, zarr_flag=zarr) + + c.get_unused_data(raw_data=raw).write_file( + filename="%(path)s/%(sign_type)s_untracked.nc", **kw_write + ) + + short_c = c._copy() + short_c.shorter_than(size_max=nb_obs_min) + short_track = short_c.merge(raw_data=raw) + + if c.longer_than(size_min=nb_obs_min) is False: + long_track = short_track.empty_dataset() + else: + long_track = c.merge(raw_data=raw) + + # We flag obs + if c.virtual: + long_track["virtual"][:] = long_track["time"] == 0 + long_track.normalize_longitude() + long_track.filled_by_interpolation(long_track["virtual"] == 1) + short_track["virtual"][:] = short_track["time"] == 0 + short_track.normalize_longitude() + short_track.filled_by_interpolation(short_track["virtual"] == 1) + + logger.info("Longer track saved have %d obs", c.nb_obs_by_tracks.max()) + logger.info( + "The mean length is %d observations for long track", + c.nb_obs_by_tracks.mean(), + ) + + long_track.write_file(**kw_write) + short_track.write_file( + filename="%(path)s/%(sign_type)s_track_too_short.nc", **kw_write + ) + + +def get_group( + dataset1, + dataset2, + index1, + index2, + score, + invalid=2, + low=10, + high=60, +): + group1, group2 = dict(), dict() + m_valid = (score * 100) >= invalid + i1, i2, score = index1[m_valid], index2[m_valid], score[m_valid] * 100 + # Eddies with no association & scores < invalid + group1["nomatch"] = reverse_index(i1, len(dataset1)) + group2["nomatch"] = reverse_index(i2, len(dataset2)) + # Select all eddies involved in multiple associations + i1_, nb1 = unique(i1, return_counts=True) + i2_, nb2 = unique(i2, return_counts=True) + i1_multi = i1_[nb1 >= 2] + i2_multi = i2_[nb2 >= 2] + m_multi = in1d(i1, i1_multi) + in1d(i2, i2_multi) + + # Low scores + m_low = score < 
low + m_low *= ~m_multi + group1["low"] = i1[m_low] + group2["low"] = i2[m_low] + # Intermediate scores + m_i = (score >= low) * (score < high) + m_i *= ~m_multi + group1["intermediate"] = i1[m_i] + group2["intermediate"] = i2[m_i] + # High scores + m_high = score >= high + m_high *= ~m_multi + group1["high"] = i1[m_high] + group2["high"] = i2[m_high] + + # Here for a nice display order + group1["multi_match"] = unique(i1[m_multi]) + group2["multi_match"] = unique(i2[m_multi]) + + def get_twin(j2, j1): + # True only if j1 is used only one + m = bincount(j1)[j1] == 1 + # We keep only link of this mask j1 have exactly one parent + j2_ = j2[m] + # We count parent times + m_ = (bincount(j2_)[j2_] == 2) * (bincount(j2)[j2_] == 2) + # we fill first mask with second one + m[m] = m_ + return m + + m1 = get_twin(i1, i2) + m2 = get_twin(i2, i1) + group1["parent"] = unique(i1[m1]) + group2["parent"] = unique(i2[m2]) + group1["twin"] = i1[m2] + group2["twin"] = i2[m1] + + m = ~m1 * ~m2 * m_multi + group1["complex"] = unique(i1[m]) + group2["complex"] = unique(i2[m]) + + return group1, group2 + + +def run_compare(ref, others, invalid=1, low=20, high=80, intern=False, **kwargs): + groups_ref, groups_other = dict(), dict() + for i, (k, other) in enumerate(others.items()): + print(f"[{i}] {k} -> {len(other)} obs") + gr1, gr2 = get_group( + ref, + other, + *ref.match(other, intern=intern, **kwargs), + invalid=invalid, + low=low, + high=high, + ) + groups_ref[k] = gr1 + groups_other[k] = gr2 + return groups_ref, groups_other + + +def display_compare( + ref, others, invalid=1, low=20, high=80, area=False, intern=False, **kwargs +): + gr_ref, gr_others = run_compare( + ref, others, invalid=invalid, low=low, high=high, intern=intern, **kwargs + ) + + def display(value, ref=None): + outs = list() + for v in value: + if ref: + if area: + outs.append(f"{v / ref * 100:.1f}% ({v:.1f}Mkm²)") + else: + outs.append(f"{v/ref * 100:.1f}% ({v})") + else: + outs.append(v) + if area: + return "".join([f"{v:^16}" for v in outs]) + else: + return "".join([f"{v:^15}" for v in outs]) + + def get_values(v, dataset): + if area: + area_ = dataset["speed_area" if intern else "effective_area"] + return [area_[v_].sum() / 1e12 for v_ in v.values()] + else: + return [ + v_.sum() if v_.dtype == "bool" else v_.shape[0] for v_ in v.values() + ] + + labels = dict( + high=f"{high:0.0f} <= high", + low=f"{invalid:0.0f} <= low < {low:0.0f}", + ) + + keys = [labels.get(key, key) for key in list(gr_ref.values())[0].keys()] + print(" ", display(keys)) + if area: + ref_ = ref["speed_area" if intern else "effective_area"].sum() / 1e12 + else: + ref_ = len(ref) + for i, v in enumerate(gr_ref.values()): + print(f"[{i:2}] ", display(get_values(v, ref), ref=ref_)) + + print(" Point of view of study dataset") + print(" ", display(keys)) + for i, (k, v) in enumerate(gr_others.items()): + other = others[k] + if area: + ref_ = other["speed_area" if intern else "effective_area"].sum() / 1e12 + else: + ref_ = len(other) + print(f"[{i:2}] ", display(get_values(v, other), ref=ref_)) + + +def quick_compare(): + parser = EddyParser( + "Tool to have a quick comparison between several identification" + ) + parser.add_argument("ref", help="Identification file of reference") + parser.add_argument("others", nargs="+", help="Identifications files to compare") + help = "Display in percent of area instead percent of observation" + parser.add_argument("--area", action="store_true", help=help) + help = "Use minimal cost function" + parser.add_argument("--minimal_area", 
action="store_true", help=help) + parser.add_argument("--high", default=40, type=float) + parser.add_argument("--low", default=20, type=float) + parser.add_argument("--invalid", default=5, type=float) + parser.add_argument( + "--path_out", default=None, help="Save each group in separate file" + ) + parser.contour_intern_arg() + args = parser.parse_args() + + kw = dict( + include_vars=[ + "longitude", + *EddiesObservations.intern(args.intern, public_label=True), + ] + ) + if args.area: + kw["include_vars"].append("speed_area" if args.intern else "effective_area") + + if args.path_out is not None: + kw = dict() + + ref = EddiesObservations.load_file(args.ref, **kw) + print(f"[ref] {args.ref} -> {len(ref)} obs") + others = {other: EddiesObservations.load_file(other, **kw) for other in args.others} + + kwargs = dict( + invalid=args.invalid, + low=args.low, + high=args.high, + intern=args.intern, + minimal_area=args.minimal_area, + ) + if args.path_out is not None: + groups_ref, groups_other = run_compare(ref, others, **kwargs) + if not exists(args.path_out): + mkdir(args.path_out) + for i, other_ in enumerate(args.others): + dirname_ = f"{args.path_out}/{other_.replace('/', '_')}/" + if not exists(dirname_): + mkdir(dirname_) + for k, v in groups_other[other_].items(): + basename_ = f"other_{k}.nc" + others[other_].index(v).write_file(filename=f"{dirname_}/{basename_}") + for k, v in groups_ref[other_].items(): + basename_ = f"ref_{k}.nc" + ref.index(v).write_file(filename=f"{dirname_}/{basename_}") + return + display_compare(ref, others, **kwargs, area=args.area) diff --git a/src/py_eddy_tracker/appli/grid.py b/src/py_eddy_tracker/appli/grid.py index 39456643..099465ee 100644 --- a/src/py_eddy_tracker/appli/grid.py +++ b/src/py_eddy_tracker/appli/grid.py @@ -3,9 +3,8 @@ All entry point to manipulate grid """ from argparse import Action -from datetime import datetime -from .. import EddyParser +from .. 
import EddyParser, identify_time from ..dataset.grid import RegularGridDataset, UnRegularGridDataset @@ -69,34 +68,69 @@ def eddy_id(args=None): parser.add_argument("longitude") parser.add_argument("latitude") parser.add_argument("path_out") - help = "Wavelength for mesoscale filter in km" - parser.add_argument("--cut_wavelength", default=500, type=float, help=help) + help = ( + "Wavelength for mesoscale filter in km to remove low scale if 2 args is given first one will be" + "used to remove high scale and second value to remove low scale" + ) + parser.add_argument( + "--cut_wavelength", default=[500], type=float, help=help, nargs="+" + ) parser.add_argument("--filter_order", default=3, type=int) help = "Step between 2 isoline in m" parser.add_argument("--isoline_step", default=0.002, type=float, help=help) help = "Error max accepted to fit circle in percent" parser.add_argument("--fit_errmax", default=55, type=float, help=help) + parser.add_argument( + "--lat_max", default=85, type=float, help="Maximal latitude filtered" + ) parser.add_argument("--height_unit", default=None, help="Force height unit") parser.add_argument("--speed_unit", default=None, help="Force speed unit") parser.add_argument("--unregular", action="store_true", help="if grid is unregular") + help = "Array size used to build contour, first and last point will be the same" + parser.add_argument("--sampling", default=50, type=int, help=help) + parser.add_argument( + "--sampling_method", + default="visvalingam", + type=str, + choices=("visvalingam", "uniform"), + help="Method to resample contour", + ) help = "Output will be wrote in zarr" parser.add_argument("--zarr", action="store_true", help=help) help = "Indexs to select grid : --indexs time=2, will select third step along time dimensions" + parser.add_argument("--indexs", nargs="*", help=help, action=DictAction) + help = "Number of pixel of grid detection which could be in an eddies, you must specify MIN and MAX." 
parser.add_argument( - "--indexs", - nargs="*", - help=help, - action=DictAction, + "--pixel_limit", nargs="+", default=(5, 2000), type=int, help=help ) + help = "Minimal number of amplitude in number of step" + parser.add_argument("--nb_step_min", default=2, type=int, help=help) args = parser.parse_args(args) if args else parser.parse_args() - date = datetime.strptime(args.datetime, "%Y%m%d") + + if len(args.pixel_limit) != 2: + raise Exception( + "You must define two value minimal number of pixel and maximal number of pixel" + ) + + cut_wavelength = args.cut_wavelength + nb_cw = len(cut_wavelength) + if nb_cw > 2 or nb_cw == 0: + raise Exception("You must specify 1 or 2 values for cut wavelength.") + elif nb_cw == 1: + cut_wavelength = [0, *cut_wavelength] + inf_bnds, upper_bnds = cut_wavelength + + date = identify_time(args.datetime) kwargs = dict( step=args.isoline_step, shape_error=args.fit_errmax, - pixel_limit=(5, 2000), + pixel_limit=args.pixel_limit, force_height_unit=args.height_unit, force_speed_unit=args.speed_unit, + nb_step_to_be_mle=0, + nb_step_min=args.nb_step_min, ) + a, c = identification( args.filename, args.longitude, @@ -106,12 +140,16 @@ def eddy_id(args=None): args.u, args.v, unregular=args.unregular, - cut_wavelength=args.cut_wavelength, + cut_wavelength=upper_bnds, + cut_highwavelength=inf_bnds, + lat_max=args.lat_max, filter_order=args.filter_order, indexs=args.indexs, + sampling=args.sampling, + sampling_method=args.sampling_method, **kwargs, ) - out_name = date.strftime("%(path)s/%(sign_type)s_%Y%m%d.nc") + out_name = date.strftime("%(path)s/%(sign_type)s_%Y%m%dT%H%M%S.nc") a.write_file(path=args.path_out, filename=out_name, zarr_flag=args.zarr) c.write_file(path=args.path_out, filename=out_name, zarr_flag=args.zarr) @@ -126,6 +164,8 @@ def identification( v="None", unregular=False, cut_wavelength=500, + cut_highwavelength=0, + lat_max=85, filter_order=1, indexs=None, **kwargs @@ -135,6 +175,9 @@ def identification( if u == "None" and v == "None": grid.add_uv(h) u, v = "u", "v" + kw_filter = dict(order=filter_order, lat_max=lat_max) + if cut_highwavelength != 0: + grid.bessel_low_filter(h, cut_highwavelength, **kw_filter) if cut_wavelength != 0: - grid.bessel_high_filter(h, cut_wavelength, order=filter_order) + grid.bessel_high_filter(h, cut_wavelength, **kw_filter) return grid.eddy_identification(h, u, v, date, **kwargs) diff --git a/src/py_eddy_tracker/appli/gui.py b/src/py_eddy_tracker/appli/gui.py index ec843400..c3d7619b 100644 --- a/src/py_eddy_tracker/appli/gui.py +++ b/src/py_eddy_tracker/appli/gui.py @@ -3,18 +3,23 @@ Entry point of graphic user interface """ -from datetime import datetime +from datetime import datetime, timedelta +from itertools import chain +import logging from matplotlib import pyplot +from matplotlib.animation import FuncAnimation +from matplotlib.axes import Axes from matplotlib.collections import LineCollection -from numpy import arange, empty, where +from numpy import arange, where, nan from .. 
import EddyParser -from ..generic import flatten_line_matrix from ..gui import GUI from ..observations.tracking import TrackEddiesObservations from ..poly import create_vertice +logger = logging.getLogger("pet") + class Anim: def __init__( @@ -29,29 +34,77 @@ def __init__( self.period = self.eddy.period self.sleep_event = sleep_event self.mappables = list() + self.field_color = None + self.field_txt = None + self.time_field = False + self.txt = None + self.ax = None + self.kw_label = dict() self.setup(**kwargs) - def setup(self, cmap="jet", nb_step=25, figsize=(8, 6), **kwargs): - cmap = pyplot.get_cmap(cmap) - self.colors = cmap(arange(nb_step + 1) / nb_step) + def setup( + self, + cmap="jet", + lut=None, + field_color="time", + field_txt="track", + range_color=(None, None), + nb_step=25, + figsize=(8, 6), + position=(0.05, 0.05, 0.9, 0.9), + **kwargs, + ): + self.kw_label["fontsize"] = kwargs.pop("fontsize", 12) + self.kw_label["fontweight"] = kwargs.pop("fontweight", "demibold") + # To text each visible eddy + if field_txt: + if isinstance(field_txt,str): + self.field_txt = self.eddy[field_txt] + else : + self.field_txt=field_txt + if field_color: + # To color each visible eddy + self.field_color = self.eddy[field_color].astype("f4") + rg = range_color + if rg[0] is None and rg[1] is None and field_color == "time": + self.time_field = True + else: + rg = ( + self.field_color.min() if rg[0] is None else rg[0], + self.field_color.max() if rg[1] is None else rg[1], + ) + self.field_color = (self.field_color - rg[0]) / (rg[1] - rg[0]) + self.colors = pyplot.get_cmap(cmap, lut=lut) self.nb_step = nb_step - x_min, x_max = self.x_core.min() - 2, self.x_core.max() + 2 - d_x = x_max - x_min - y_min, y_max = self.y_core.min() - 2, self.y_core.max() + 2 - d_y = y_max - y_min # plot - self.fig = pyplot.figure(figsize=figsize, **kwargs) + if "figure" in kwargs: + self.fig = kwargs.pop("figure") + else: + self.fig = pyplot.figure(figsize=figsize, **kwargs) t0, t1 = self.period self.fig.suptitle(f"{t0} -> {t1}") - self.ax = self.fig.add_axes((0.05, 0.05, 0.9, 0.9)) - self.ax.set_xlim(x_min, x_max), self.ax.set_ylim(y_min, y_max) - self.ax.set_aspect("equal") - self.ax.grid() - # init mappable - self.txt = self.ax.text(x_min + 0.05 * d_x, y_min + 0.05 * d_y, "", zorder=10) + if isinstance(position, Axes): + self.ax = position + else: + x_min, x_max = self.x_core.min() - 2, self.x_core.max() + 2 + d_x = x_max - x_min + y_min, y_max = self.y_core.min() - 2, self.y_core.max() + 2 + d_y = y_max - y_min + self.ax = self.fig.add_axes(position, projection="full_axes") + self.ax.set_xlim(x_min, x_max), self.ax.set_ylim(y_min, y_max) + self.ax.set_aspect("equal") + self.ax.grid() + self.txt = self.ax.text( + x_min + 0.05 * d_x, y_min + 0.05 * d_y, "", zorder=10 + ) self.segs = list() - self.contour = LineCollection([], zorder=1) + self.t_segs = list() + self.c_segs = list() + if field_color is None: + self.contour = LineCollection([], zorder=1, color="gray") + else: + self.contour = LineCollection([], zorder=1) self.ax.add_collection(self.contour) self.fig.canvas.draw() @@ -89,8 +142,9 @@ def show(self, infinity_loop=False): if dt < 0: # self.sleep_event = dt_draw * 1.01 dt = 1e-10 + if dt == 0: + dt = 1e-10 self.fig.canvas.start_event_loop(dt) - if self.now > t1: break if infinity_loop: @@ -117,32 +171,63 @@ def func_animation(self, frame): def update(self): m = self.t == self.now + color = self.field_color is not None if m.sum(): - self.segs.append( - create_vertice( - flatten_line_matrix(self.x[m]), 
flatten_line_matrix(self.y[m]) + segs = list() + t = list() + c = list() + for i in where(m)[0]: + segs.append(create_vertice(self.x[i], self.y[i])) + if color: + c.append(self.field_color[i]) + t.append(self.now) + self.segs.append(segs) + if color: + self.c_segs.append(c) + self.t_segs.append(t) + self.contour.set_paths(chain(*self.segs)) + if color: + if self.time_field: + self.contour.set_color( + self.colors( + [ + (self.nb_step - self.now + i) / self.nb_step + for i in chain(*self.c_segs) + ] + ) ) - ) - else: - self.segs.append(empty((0, 2))) - self.contour.set_paths(self.segs) - self.contour.set_color(self.colors[-len(self.segs) :]) - self.contour.set_lw(arange(len(self.segs)) / len(self.segs) * 2.5) - txt = f"{self.now}" - if self.graphic_informations: - txt += f"- {1/self.sleep_event:.0f} frame/s" - self.txt.set_text(txt) - for i in where(m)[0]: - mappable = self.ax.text( - self.x_core[i], self.y_core[i], self.track[i], fontsize=8 - ) - self.mappables.append(mappable) - self.ax.draw_artist(mappable) + else: + self.contour.set_color(self.colors(list(chain(*self.c_segs)))) + # linewidth will be link to time delay + self.contour.set_lw( + [ + (1 - (self.now - i) / self.nb_step) * 2.5 if i <= self.now else 0 + for i in chain(*self.t_segs) + ] + ) + # Update date txt and info + if self.txt is not None: + txt = f"{(timedelta(int(self.now)) + datetime(1950,1,1)).strftime('%Y/%m/%d')}" + if self.graphic_informations: + txt += f"- {1/self.sleep_event:.0f} frame/s" + self.txt.set_text(txt) + self.ax.draw_artist(self.txt) + # Update id txt + if self.field_txt is not None: + for i in where(m)[0]: + mappable = self.ax.text( + self.x_core[i], self.y_core[i], self.field_txt[i], **self.kw_label + ) + self.mappables.append(mappable) + self.ax.draw_artist(mappable) self.ax.draw_artist(self.contour) - self.ax.draw_artist(self.txt) + # Remove first segment to keep only T contour if len(self.segs) > self.nb_step: self.segs.pop(0) + self.t_segs.pop(0) + if color: + self.c_segs.pop(0) def draw_contour(self): # select contour for this time step @@ -165,8 +250,13 @@ def keyboard(self, event): elif event.key == "right" and self.pause: self.next() elif event.key == "left" and self.pause: + # we remove 2 step to add 1 so we rewind of only one self.segs.pop(-1) self.segs.pop(-1) + self.t_segs.pop(-1) + self.t_segs.pop(-1) + self.c_segs.pop(-1) + self.c_segs.pop(-1) self.prev() @@ -177,11 +267,7 @@ def anim(): ) parser.add_argument("filename", help="eddy atlas") parser.add_argument("id", help="Track id to anim", type=int, nargs="*") - parser.add_argument( - "--intern", - action="store_true", - help="display intern contour inplace of outter contour", - ) + parser.contour_intern_arg() parser.add_argument( "--keep_step", default=25, help="number maximal of step displayed", type=int ) @@ -196,31 +282,74 @@ def anim(): parser.add_argument( "--infinity_loop", action="store_true", help="Press Escape key to stop loop" ) + parser.add_argument( + "--first_centered", + action="store_true", + help="Longitude will be centered on first obs.", + ) + parser.add_argument( + "--field", default="time", help="Field use to color contour instead of time" + ) + parser.add_argument("--txt_field", default="track", help="Field use to text eddy") + parser.add_argument( + "--vmin", default=None, type=float, help="Inferior bound to color contour" + ) + parser.add_argument( + "--vmax", default=None, type=float, help="Upper bound to color contour" + ) + parser.add_argument("--mp4", help="Filename to save animation (mp4)") args = 
parser.parse_args() - variables = ["time", "track", "longitude", "latitude"] + variables = list( + set(["time", "track", "longitude", "latitude", args.field, args.txt_field]) + ) variables.extend(TrackEddiesObservations.intern(args.intern, public_label=True)) - eddies = TrackEddiesObservations.load_file(args.filename, include_vars=variables) + eddies = TrackEddiesObservations.load_file( + args.filename, include_vars=set(variables) + ) if not args.all: if len(args.id) == 0: raise Exception( "You need to specify id to display or ask explicity all with --all option" ) eddies = eddies.extract_ids(args.id) + if args.first_centered: + # TODO: include to observation class + x0 = eddies.lon[0] + eddies.lon[:] = (eddies.lon - x0 + 180) % 360 + x0 - 180 + eddies.contour_lon_e[:] = ( + (eddies.contour_lon_e.T - eddies.lon + 180) % 360 + eddies.lon - 180 + ).T + + kw = dict() + if args.mp4: + kw["figsize"] = (16, 9) + kw["dpi"] = 120 a = Anim( eddies, intern=args.intern, sleep_event=args.time_sleep, cmap=args.cmap, nb_step=args.keep_step, + field_color=args.field, + field_txt=args.txt_field, + range_color=(args.vmin, args.vmax), + graphic_information=logger.getEffectiveLevel() == logging.DEBUG, + **kw, ) - a.show(infinity_loop=args.infinity_loop) + if args.mp4 is None: + a.show(infinity_loop=args.infinity_loop) + else: + kwargs = dict(frames=arange(*a.period), interval=50) + ani = FuncAnimation(a.fig, a.func_animation, **kwargs) + ani.save(args.mp4, fps=30, extra_args=["-vcodec", "libx264"]) def gui_parser(): parser = EddyParser("Eddy atlas GUI") parser.add_argument("atlas", nargs="+") parser.add_argument("--med", action="store_true") + parser.add_argument("--nopath", action="store_true", help="Don't draw path") return parser.parse_args() @@ -232,4 +361,5 @@ def guieddy(): g = GUI(**atlas) if args.med: g.med() + g.hide_path(not args.nopath) g.show() diff --git a/src/py_eddy_tracker/appli/network.py b/src/py_eddy_tracker/appli/network.py index c92f1d3d..0a3d06ca 100644 --- a/src/py_eddy_tracker/appli/network.py +++ b/src/py_eddy_tracker/appli/network.py @@ -5,15 +5,11 @@ import logging -from netCDF4 import Dataset -from numpy import arange, empty, zeros -from Polygon import Polygon +from numpy import in1d, zeros from .. 
import EddyParser -from ..generic import build_index -from ..observations.network import Network +from ..observations.network import Network, NetworkObservations from ..observations.tracking import TrackEddiesObservations -from ..poly import create_vertice_from_2darray, polygon_overlap logger = logging.getLogger("pet") @@ -27,277 +23,280 @@ def build_network(): parser.add_argument( "--window", "-w", type=int, help="Half time window to search eddy", default=1 ) + + parser.add_argument( + "--min-overlap", + "-p", + type=float, + help="minimum overlap area to associate observations", + default=0.2, + ) parser.add_argument( - "--intern", + "--minimal-area", action="store_true", - help="Use intern contour instead of outter contour", + help="If True, use intersection/little polygon, else intersection/union", ) + parser.add_argument( + "--hybrid-area", + action="store_true", + help="If True, use minimal-area method if overlap is under min overlap, else intersection/union", + ) + + parser.contour_intern_arg() + + parser.memory_arg() args = parser.parse_args() - n = Network(args.identification_regex, window=args.window, intern=args.intern) - group = n.group_observations(minimal_area=True) + n = Network( + args.identification_regex, + window=args.window, + intern=args.intern, + memory=args.memory, + ) + group = n.group_observations( + min_overlap=args.min_overlap, + minimal_area=args.minimal_area, + hybrid_area=args.hybrid_area, + ) n.build_dataset(group).write_file(filename=args.out) def divide_network(): - parser = EddyParser("Separate path for a same group") + parser = EddyParser("Separate path for a same group (network)") parser.add_argument("input", help="input network file") parser.add_argument("out", help="output file") + parser.contour_intern_arg() + parser.add_argument( + "--window", "-w", type=int, help="Half time window to search eddy", default=1 + ) parser.add_argument( - "--intern", + "--min-overlap", + "-p", + type=float, + help="minimum overlap area to associate observations", + default=0.2, + ) + parser.add_argument( + "--minimal-area", action="store_true", - help="Use intern contour instead of outter contour", + help="If True, use intersection/little polygon, else intersection/union", ) parser.add_argument( - "--window", "-w", type=int, help="Half time window to search eddy", default=1 + "--hybrid-area", + action="store_true", + help="If True, use minimal-area method if overlap is under min overlap, else intersection/union", ) args = parser.parse_args() contour_name = TrackEddiesObservations.intern(args.intern, public_label=True) e = TrackEddiesObservations.load_file( - args.input, include_vars=("time", "track", *contour_name) + args.input, + include_vars=("time", "track", "latitude", "longitude", *contour_name), ) - e.split_network(intern=args.intern, window=args.window) - # split_network(args.input, args.out) - - -def split_network(input, output): - """Divide each group in track""" - sl = slice(None) - with Dataset(input) as h: - group = h.variables["track"][sl] - track_s, track_e, track_ref = build_index(group) - # nb = track_e - track_s - # m = nb > 1500 - # print(group[track_s[m]]) - - track_id = 12003 - sls = [slice(track_s[track_id - track_ref], track_e[track_id - track_ref], None)] - for sl in sls: - - print(sl) - with Dataset(input) as h: - time = h.variables["time"][sl] - group = h.variables["track"][sl] - x = h.variables["effective_contour_longitude"][sl] - y = h.variables["effective_contour_latitude"][sl] - print(group[0]) - ids = empty( - time.shape, - dtype=[ - 
("group", group.dtype), - ("time", time.dtype), - ("track", "u2"), - ("previous_cost", "f4"), - ("next_cost", "f4"), - ("previous_observation", "i4"), - ("next_observation", "i4"), - ], - ) - ids["group"] = group - ids["time"] = time - # To store id track - ids["track"] = 0 - ids["previous_cost"] = 0 - ids["next_cost"] = 0 - ids["previous_observation"] = -1 - ids["next_observation"] = -1 - # Cost with previous - track_start, track_end, track_ref = build_index(group) - for i0, i1 in zip(track_start, track_end): - if (i1 - i0) == 0 or group[i0] == Network.NOGROUP: - continue - sl_group = slice(i0, i1) - set_tracks( - x[sl_group], - y[sl_group], - time[sl_group], - i0, - ids["track"][sl_group], - ids["previous_cost"][sl_group], - ids["next_cost"][sl_group], - ids["previous_observation"][sl_group], - ids["next_observation"][sl_group], - window=5, - ) + n = NetworkObservations.from_split_network( + TrackEddiesObservations.load_file(args.input, raw_data=True), + e.split_network( + intern=args.intern, + window=args.window, + min_overlap=args.min_overlap, + minimal_area=args.minimal_area, + hybrid_area=args.hybrid_area, + ), + ) + n.write_file(filename=args.out) + + +def subset_network(): + parser = EddyParser("Subset network") + parser.add_argument("input", help="input network file") + parser.add_argument("out", help="output file") + parser.add_argument( + "-l", + "--length", + nargs=2, + type=int, + help="Nb of days that must be covered by the network, first minimum number of day and last maximum number of day," + "if value is negative, this bound won't be used", + ) + parser.add_argument( + "--remove_dead_end", + nargs=2, + type=int, + help="Remove short dead end, first is for minimal obs number and second for minimal segment time to keep", + ) + parser.add_argument( + "--remove_trash", + action="store_true", + help="Remove trash (network id == 0)", + ) + parser.add_argument( + "-i", "--ids", nargs="+", type=int, help="List of network which will be extract" + ) + parser.add_argument( + "-p", + "--period", + nargs=2, + type=int, + help="Start day and end day, if it's a negative value we will add to day min and add to day max," + "if 0 it is not used", + ) + args = parser.parse_args() + n = NetworkObservations.load_file(args.input, raw_data=True) + if args.ids is not None: + n = n.networks(args.ids) + if args.length is not None: + n = n.longer_than(*args.length) + if args.remove_dead_end is not None: + n = n.remove_dead_end(*args.remove_dead_end) + if args.period is not None: + n = n.extract_with_period(args.period) + n.write_file(filename=args.out) + + +def quick_compare(): + parser = EddyParser( + """Tool to have a quick comparison between several network: + - N : network + - S : segment + - Obs : observations + """ + ) + parser.add_argument("ref", help="Identification file of reference") + parser.add_argument("others", nargs="+", help="Identifications files to compare") + parser.add_argument( + "--path_out", default=None, help="Save each group in separate file" + ) + args = parser.parse_args() - new_i = ids.argsort(order=("group", "track", "time")) - ids_sort = ids[new_i] - # To be able to follow indices sorting - reverse_sort = empty(new_i.shape[0], dtype="u4") - reverse_sort[new_i] = arange(new_i.shape[0]) - # Redirect indices - m = ids_sort["next_observation"] != -1 - ids_sort["next_observation"][m] = reverse_sort[ids_sort["next_observation"][m]] - m = ids_sort["previous_observation"] != -1 - ids_sort["previous_observation"][m] = reverse_sort[ - ids_sort["previous_observation"][m] + kw = 
dict( + include_vars=[ + "longitude", + "latitude", + "time", + "track", + "segment", + "next_obs", + "previous_obs", ] - # print(ids_sort) - display_network( - x[new_i], - y[new_i], - ids_sort["track"], - ids_sort["time"], - ids_sort["next_cost"], - ) + ) + if args.path_out is not None: + kw = dict() -def next_obs( - i_current, next_cost, previous_cost, polygons, t, t_start, t_end, t_ref, window -): - t_max = t_end.shape[0] - 1 - t_cur = t[i_current] - t0, t1 = t_cur + 1 - t_ref, t_cur + window - t_ref - if t0 > t_max: - return -1 - t1 = min(t1, t_max) - for t_step in range(t0, t1 + 1): - i0, i1 = t_start[t_step], t_end[t_step] - # No observation at the time step ! - if i0 == i1: - continue - sl = slice(i0, i1) - # Intersection / union, to be able to separte in case of multiple inside - c = polygon_overlap(polygons[i_current], polygons[sl]) - # We remove low overlap - if (c > 0.1).sum() > 1: - print(c) - c[c < 0.1] = 0 - # We get index of maximal overlap - i = c.argmax() - c_i = c[i] - # No overlap found - if c_i == 0: - continue - target = i0 + i - # Check if candidate is already used - c_target = previous_cost[target] - if (c_target != 0 and c_target < c_i) or c_target == 0: - previous_cost[target] = c_i - next_cost[i_current] = c_i - return target - return -1 - - -def set_tracks( - x, - y, - t, - ref_index, - track, - previous_cost, - next_cost, - previous_observation, - next_observation, - window, -): - # Will split one group in tracks - t_start, t_end, t_ref = build_index(t) - nb = x.shape[0] - used = zeros(nb, dtype="bool") - current_track = 1 - # build all polygon (need to check if wrap is needed) - polygons = list() - for i in range(nb): - polygons.append(Polygon(create_vertice_from_2darray(x, y, i))) - - for i in range(nb): - # If observation already in one track, we go to the next one - if used[i]: - continue - build_track( - i, - current_track, - used, - track, - previous_observation, - next_observation, - ref_index, - next_cost, - previous_cost, - polygons, - t, - t_start, - t_end, - t_ref, - window, - ) - current_track += 1 - - -def build_track( - first_index, - track_id, - used, - track, - previous_observation, - next_observation, - ref_index, - next_cost, - previous_cost, - *args, -): - i_next = first_index - while i_next != -1: - # Flag - used[i_next] = True - # Assign id - track[i_next] = track_id - # Search next - i_next_ = next_obs(i_next, next_cost, previous_cost, *args) - if i_next_ == -1: - break - next_observation[i_next] = i_next_ + ref_index - if not used[i_next_]: - previous_observation[i_next_] = i_next + ref_index - # Target was previously used - if used[i_next_]: - if next_cost[i_next] == previous_cost[i_next_]: - m = track[i_next_:] == track[i_next_] - track[i_next_:][m] = track_id - previous_observation[i_next_] = i_next + ref_index - i_next_ = -1 - i_next = i_next_ - - -def display_network(x, y, tr, t, c): - tr0, tr1, t_ref = build_index(tr) - import matplotlib.pyplot as plt - - cmap = plt.get_cmap("jet") - from ..generic import flatten_line_matrix - - fig = plt.figure(figsize=(20, 10)) - ax = fig.add_subplot(121, aspect="equal") - ax.grid() - ax_time = fig.add_subplot(122) - ax_time.grid() - i = 0 - for s, e in zip(tr0, tr1): - if s == e: - continue - sl = slice(s, e) - color = cmap((tr[s] - tr[tr0[0]]) / (tr[tr0[-1]] - tr[tr0[0]])) - ax.plot( - flatten_line_matrix(x[sl]), - flatten_line_matrix(y[sl]), - color=color, - label=f"{tr[s]} - {e-s} obs from {t[s]} to {t[e-1]}", - ) - i += 1 - ax_time.plot( - t[sl], - tr[s].repeat(e - s) + c[sl], - color=color, - 
label=f"{tr[s]} - {e-s} obs", - lw=0.5, - ) - ax_time.plot(t[sl], tr[s].repeat(e - s), color=color, lw=1, marker="+") - ax_time.text(t[s], tr[s] + 0.15, f"{x[s].mean():.2f}, {y[s].mean():.2f}") - ax_time.axvline(t[s], color=".75", lw=0.5, ls="--", zorder=-10) - ax_time.text( - t[e - 1], tr[e - 1] - 0.25, f"{x[e-1].mean():.2f}, {y[e-1].mean():.2f}" + ref = NetworkObservations.load_file(args.ref, **kw) + print( + f"[ref] {args.ref} -> {ref.nb_network} network / {ref.nb_segment} segment / {len(ref)} obs " + f"-> {ref.network_size(0)} trash obs, " + f"{len(ref.merging_event())} merging, {len(ref.splitting_event())} spliting" + ) + others = { + other: NetworkObservations.load_file(other, **kw) for other in args.others + } + + # if args.path_out is not None: + # groups_ref, groups_other = run_compare(ref, others, **kwargs) + # if not exists(args.path_out): + # mkdir(args.path_out) + # for i, other_ in enumerate(args.others): + # dirname_ = f"{args.path_out}/{other_.replace('/', '_')}/" + # if not exists(dirname_): + # mkdir(dirname_) + # for k, v in groups_other[other_].items(): + # basename_ = f"other_{k}.nc" + # others[other_].index(v).write_file(filename=f"{dirname_}/{basename_}") + # for k, v in groups_ref[other_].items(): + # basename_ = f"ref_{k}.nc" + # ref.index(v).write_file(filename=f"{dirname_}/{basename_}") + # return + display_compare(ref, others) + + +def run_compare(ref, others): + outs = dict() + for i, (k, other) in enumerate(others.items()): + out = dict() + print( + f"[{i}] {k} -> {other.nb_network} network / {other.nb_segment} segment / {len(other)} obs " + f"-> {other.network_size(0)} trash obs, " + f"{len(other.merging_event())} merging, {len(other.splitting_event())} spliting" ) - ax.legend() - ax_time.legend() - plt.show() + ref_id, other_id = ref.identify_in(other, size_min=2) + m = other_id != -1 + ref_id, other_id = ref_id[m], other_id[m] + out["same N(N)"] = m.sum() + out["same N(Obs)"] = ref.network_size(ref_id).sum() + + # For network which have same obs + ref_, other_ = ref.networks(ref_id), other.networks(other_id) + ref_segu, other_segu = ref_.identify_in(other_, segment=True) + m = other_segu == -1 + ref_track_no_match, _ = ref_.unique_segment_to_id(ref_segu[m]) + ref_segu, other_segu = ref_segu[~m], other_segu[~m] + m = ~in1d(ref_id, ref_track_no_match) + out["same NS(N)"] = m.sum() + out["same NS(Obs)"] = ref.network_size(ref_id[m]).sum() + + # Check merge/split + def follow_obs(d, i_follow): + m = i_follow != -1 + i_follow = i_follow[m] + t, x, y = ( + zeros(m.size, d.time.dtype), + zeros(m.size, d.longitude.dtype), + zeros(m.size, d.latitude.dtype), + ) + t[m], x[m], y[m] = ( + d.time[i_follow], + d.longitude[i_follow], + d.latitude[i_follow], + ) + return t, x, y + + def next_obs(d, i_seg): + last_i = d.index_segment_track[1][i_seg] - 1 + return follow_obs(d, d.next_obs[last_i]) + + def previous_obs(d, i_seg): + first_i = d.index_segment_track[0][i_seg] + return follow_obs(d, d.previous_obs[first_i]) + + tref, xref, yref = next_obs(ref_, ref_segu) + tother, xother, yother = next_obs(other_, other_segu) + + m = (tref == tother) & (xref == xother) & (yref == yother) + print(m.sum(), m.size, ref_segu.size, ref_track_no_match.size) + + tref, xref, yref = previous_obs(ref_, ref_segu) + tother, xother, yother = previous_obs(other_, other_segu) + + m = (tref == tother) & (xref == xother) & (yref == yother) + print(m.sum(), m.size, ref_segu.size, ref_track_no_match.size) + + ref_segu, other_segu = ref.identify_in(other, segment=True) + m = other_segu != -1 + 
out["same S(S)"] = m.sum() + out["same S(Obs)"] = ref.segment_size()[ref_segu[m]].sum() + + outs[k] = out + return outs + + +def display_compare(ref, others): + def display(value, ref=None): + if ref: + outs = [f"{v / ref[k] * 100:.1f}% ({v})" for k, v in value.items()] + else: + outs = value + return "".join([f"{v:^18}" for v in outs]) + + datas = run_compare(ref, others) + ref_ = { + "same N(N)": ref.nb_network, + "same N(Obs)": len(ref), + "same NS(N)": ref.nb_network, + "same NS(Obs)": len(ref), + "same S(S)": ref.nb_segment, + "same S(Obs)": len(ref), + } + print(" ", display(ref_.keys())) + for i, (_, v) in enumerate(datas.items()): + print(f"[{i:2}] ", display(v, ref=ref_)) diff --git a/src/py_eddy_tracker/data/Anticyclonic_20190223.nc b/src/py_eddy_tracker/data/Anticyclonic_20190223.nc index ce48c8d6..4ab8f226 100644 Binary files a/src/py_eddy_tracker/data/Anticyclonic_20190223.nc and b/src/py_eddy_tracker/data/Anticyclonic_20190223.nc differ diff --git a/src/py_eddy_tracker/data/__init__.py b/src/py_eddy_tracker/data/__init__.py index 644cf95d..bf062983 100644 --- a/src/py_eddy_tracker/data/__init__.py +++ b/src/py_eddy_tracker/data/__init__.py @@ -8,26 +8,41 @@ 20160515 adt None None longitude latitude . \ --cut 800 --fil 1 """ + import io import lzma -import tarfile from os import path +import tarfile import requests -def get_path(name): +def get_demo_path(name): return path.join(path.dirname(__file__), name) -def get_remote_sample(path): - url = ( - f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}.tar.xz" - ) - - content = requests.get(url).content +def get_remote_demo_sample(path): + if path.startswith("/") or path.startswith("."): + content = open(path, "rb").read() + if path.endswith(".nc"): + return io.BytesIO(content) + else: + try: + import py_eddy_tracker_sample_id + if path.endswith(".nc"): + return py_eddy_tracker_sample_id.get_remote_demo_sample(path) + content = open(py_eddy_tracker_sample_id.get_remote_demo_sample(f"{path}.tar.xz"), "rb").read() + except: + if path.endswith(".nc"): + content = requests.get( + f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}" + ).content + return io.BytesIO(content) + content = requests.get( + f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}.tar.xz" + ).content - # Tar module could manage lzma tar, but it will apply un compress for each extractfile + # Tar module could manage lzma tar, but it will apply uncompress for each extractfile tar = tarfile.open(mode="r", fileobj=io.BytesIO(lzma.decompress(content))) # tar = tarfile.open(mode="r:xz", fileobj=io.BytesIO(content)) files_content = list() diff --git a/src/py_eddy_tracker/data/dt_med_allsat_phy_l4_2005T2.nc b/src/py_eddy_tracker/data/dt_med_allsat_phy_l4_2005T2.nc new file mode 100644 index 00000000..cff2e2c7 Binary files /dev/null and b/src/py_eddy_tracker/data/dt_med_allsat_phy_l4_2005T2.nc differ diff --git a/src/py_eddy_tracker/data/loopers_lumpkin_med.nc b/src/py_eddy_tracker/data/loopers_lumpkin_med.nc new file mode 100644 index 00000000..cf817424 Binary files /dev/null and b/src/py_eddy_tracker/data/loopers_lumpkin_med.nc differ diff --git a/src/py_eddy_tracker/data/network_med.nc b/src/py_eddy_tracker/data/network_med.nc new file mode 100644 index 00000000..d695b09b Binary files /dev/null and b/src/py_eddy_tracker/data/network_med.nc differ diff --git a/src/py_eddy_tracker/dataset/grid.py b/src/py_eddy_tracker/dataset/grid.py index 3fd48f89..f15503b2 100644 --- a/src/py_eddy_tracker/dataset/grid.py +++ 
b/src/py_eddy_tracker/dataset/grid.py @@ -2,14 +2,14 @@ """ Class to load and manipulate RegularGrid and UnRegularGrid """ -import logging from datetime import datetime +import logging from cv2 import filter2D from matplotlib.path import Path as BasePath from netCDF4 import Dataset -from numba import njit -from numba import types as numba_types +from numba import njit, prange, types as numba_types +import numpy as np from numpy import ( arange, array, @@ -21,36 +21,36 @@ errstate, exp, float_, + floor, histogram2d, - int8, int_, interp, isnan, linspace, ma, -) -from numpy import mean as np_mean -from numpy import ( + mean as np_mean, meshgrid, nan, nanmean, ones, percentile, pi, - round_, + radians, sin, sinc, + sqrt, where, zeros, ) from pint import UnitRegistry from scipy.interpolate import RectBivariateSpline, interp1d -from scipy.ndimage import convolve, gaussian_filter +from scipy.ndimage import gaussian_filter from scipy.signal import welch from scipy.spatial import cKDTree from scipy.special import j1 from .. import VAR_DESCR +from ..data import get_demo_path from ..eddy_feature import Amplitude, Contours from ..generic import ( bbox_indice_regular, @@ -68,6 +68,7 @@ get_pixel_in_regular, poly_area, poly_contain_poly, + visvalingam, winding_number_poly, ) @@ -123,7 +124,7 @@ def value_on_regular_contour(x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size @njit(cache=True) def mean_on_regular_contour( - x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=None, nan_remove=False + x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=-1, nan_remove=False ): x_val, y_val = vertices[:, 0], vertices[:, 1] x_new, y_new = uniform_resample(x_val, y_val, num_fac, fixed_size) @@ -151,7 +152,7 @@ def _circle_from_equal_area(vertice): # last coordinates == first lon0, lat0 = lons[1:].mean(), lats[1:].mean() c_x, c_y = coordinates_to_local(lons, lats, lon0, lat0) - # Some time, edge is only a dot of few coordinates + # Sometimes, edge is only a dot of few coordinates d_lon = lons.max() - lons.min() d_lat = lats.max() - lats.min() if d_lon < 1e-7 and d_lat < 1e-7: @@ -237,19 +238,15 @@ def nb_pixel(self): class GridDataset(object): """ - Class to have basic tool on NetCDF Grid + Class for basic tools on NetCDF Grid """ __slots__ = ( - "_x_var", - "_y_var", "x_c", "y_c", "x_bounds", "y_bounds", "centered", - "xinterp", - "yinterp", "x_dim", "y_dim", "coordinates", @@ -259,9 +256,8 @@ class GridDataset(object): "variables_description", "global_attrs", "vars", - "interpolators", - "speed_coef", "contours", + "nan_mask", ) GRAVITY = 9.807 @@ -271,15 +267,23 @@ class GridDataset(object): N = 1 def __init__( - self, filename, x_name, y_name, centered=None, indexs=None, unset=False + self, + filename, + x_name, + y_name, + centered=None, + indexs=None, + unset=False, + nan_masking=False, ): """ :param str filename: Filename to load :param str x_name: Name of longitude coordinates :param str y_name: Name of latitude coordinates :param bool,None centered: Allow to know how coordinates could be used with pixel - :param dict indexs: A dictionary which set indexs to use for non-coordinate dimensions + :param dict indexs: A dictionary that sets indexes to use for non-coordinate dimensions :param bool unset: Set to True to create an empty grid object without file + :param bool nan_masking: Set to True to replace data.mask with isnan method result """ self.dimensions = None self.variables_description = None @@ -290,23 +294,38 @@ def __init__( self.y_bounds = None self.x_dim = None self.y_dim = None + self.nan_mask 
= nan_masking self.centered = centered self.contours = None - self.xinterp = None - self.yinterp = None self.filename = filename self.coordinates = x_name, y_name self.vars = dict() self.indexs = dict() if indexs is None else indexs - self.interpolators = dict() if centered is None: logger.warning( - "We assume pixel position of grid is center for %s", filename + "We assume pixel position of grid is centered for %s", filename ) if not unset: + self.populate() + + def populate(self): + if self.dimensions is None: self.load_general_features() self.load() + def clean(self): + self.dimensions = None + self.variables_description = None + self.global_attrs = None + self.x_c = None + self.y_c = None + self.x_bounds = None + self.y_bounds = None + self.x_dim = None + self.y_dim = None + self.contours = None + self.vars = dict() + @property def is_centered(self): """Give True if pixel is described with its center's position or @@ -321,7 +340,7 @@ def is_centered(self): return self.centered def load_general_features(self): - """Load attrs to be stored in object""" + """Load attrs to be stored in object""" logger.debug( "Load general feature from %(filename)s", dict(filename=self.filename) ) @@ -398,14 +417,21 @@ def load(self): self.vars[y_name] = h.variables[y_name][sl_y] self.setup_coordinates() - self.init_pos_interpolator() + + @staticmethod + def get_mask(a): + if len(a.mask.shape): + m = a.mask + else: + m = ones(a.shape, dtype="bool") if a.mask else zeros(a.shape, dtype="bool") + return m @staticmethod def c_to_bounds(c): """ - Centred coordinates to bounds coordinates + Centered coordinates to bounds coordinates - :param array c: centred coordinates to translate + :param array c: centered coordinates to translate :return: bounds coordinates """ bounds = concatenate((c, (2 * c[-1] - c[-2],))) @@ -417,9 +443,9 @@ def c_to_bounds(c): def setup_coordinates(self): x_name, y_name = self.coordinates if self.is_centered: - logger.info("Grid center") - self.x_c = self.vars[x_name].astype("float64") - self.y_c = self.vars[y_name].astype("float64") + # logger.info("Grid center") + self.x_c = array(self.vars[x_name].astype("float64")) + self.y_c = array(self.vars[y_name].astype("float64")) self.x_bounds = concatenate((self.x_c, (2 * self.x_c[-1] - self.x_c[-2],))) self.y_bounds = concatenate((self.y_c, (2 * self.y_c[-1] - self.y_c[-2],))) @@ -431,8 +457,8 @@ def setup_coordinates(self): self.y_bounds[-1] -= d_y[-1] / 2 else: - self.x_bounds = self.vars[x_name].astype("float64") - self.y_bounds = self.vars[y_name].astype("float64") + self.x_bounds = array(self.vars[x_name].astype("float64")) + self.y_bounds = array(self.vars[y_name].astype("float64")) if len(self.x_dim) == 1: self.x_c = self.x_bounds.copy() @@ -527,6 +553,11 @@ def grid(self, varname, indexs=None): if i_x > i_y: self.variables_description[varname]["infos"]["transpose"] = True self.vars[varname] = self.vars[varname].T + if self.nan_mask: + self.vars[varname] = ma.array( + self.vars[varname], + mask=isnan(self.vars[varname]), + ) if not hasattr(self.vars[varname], "mask"): self.vars[varname] = ma.array( self.vars[varname], @@ -566,7 +597,7 @@ def grid_tiles(self, varname, slice_x, slice_y): return data def high_filter(self, grid_name, w_cut, **kwargs): - """Return the grid high-pass filtered, by substracting to the grid the low-pass filter (default: order=1) + """Return the high-pass filtered grid, by substracting to the initial grid the low-pass filtered grid (default: order=1) :param grid_name: the name of the grid :param int, w_cut: the 
half-power wavelength cutoff (km) @@ -575,7 +606,7 @@ def high_filter(self, grid_name, w_cut, **kwargs): self.vars[grid_name] -= result def low_filter(self, grid_name, w_cut, **kwargs): - """Return the grid low-pass filtered (default: order=1) + """Return the low-pass filtered grid (default: order=1) :param grid_name: the name of the grid :param int, w_cut: the half-power wavelength cutoff (km) @@ -601,7 +632,9 @@ def eddy_identification( date, step=0.005, shape_error=55, + presampling_multiplier=10, sampling=50, + sampling_method="visvalingam", pixel_limit=None, precision=None, force_height_unit=None, @@ -609,30 +642,36 @@ def eddy_identification( **kwargs, ): """ - Compute eddy identification on the pecified grid + Compute eddy identification on the specified grid :param str grid_height: Grid name of Sea Surface Height :param str uname: Grid name of u speed component :param str vname: Grid name of v speed component - :param datetime.datetime date: Date which will be stored in object to date data + :param datetime.datetime date: Date to be stored in object to date data :param float,int step: Height between two layers in m - :param float,int shape_error: Maximal error allowed for outter contour in % + :param float,int shape_error: Maximal error allowed for outermost contour in % + :param int presampling_multiplier: + Evenly oversample the initial number of points in the contour by nb_pts x presampling_multiplier to fit circles :param int sampling: Number of points to store contours and speed profile + :param str sampling_method: Method to resample the stored contours, 'uniform' or 'visvalingam' :param (int,int),None pixel_limit: - Min and max number of pixels inside the inner and the outer contour to be considered as an eddy + Min and max number of pixels inside the inner and the outermost contour to be considered as an eddy :param float,None precision: Truncate values at the defined precision in m :param str force_height_unit: Unit used for height unit :param str force_speed_unit: Unit used for speed unit - :param dict kwargs: Argument given to amplitude + :param dict kwargs: Arguments given to amplitude (mle, nb_step_min, nb_step_to_be_mle). + Look at :py:meth:`py_eddy_tracker.eddy_feature.Amplitude` + The amplitude threshold is given by `step*nb_step_min` + - :return: Return a list of 2 elements: Anticyclone and Cyclone + :return: Return a list of 2 elements: Anticyclones and Cyclones :rtype: py_eddy_tracker.observations.observation.EddiesObservations .. 
minigallery:: py_eddy_tracker.GridDataset.eddy_identification """ if not isinstance(date, datetime): - raise Exception("Date argument be a datetime object") - # The inf limit must be in pixel and sup limit in surface + raise Exception("Date argument must be a datetime object") + # The inf limit must be in pixel and sup limit in surface if pixel_limit is None: pixel_limit = (4, 1000) @@ -640,7 +679,6 @@ def eddy_identification( self.init_speed_coef(uname, vname) # Get unit of h grid - h_units = ( self.units(grid_height) if force_height_unit is None else force_height_unit ) @@ -658,10 +696,10 @@ def eddy_identification( # Get ssh grid data = self.grid(grid_height).astype("f8") - # In case of a reduce mask + # In case of a reduced mask if len(data.mask.shape) == 0 and not data.mask: data.mask = zeros(data.shape, dtype="bool") - # we remove noisy information + # we remove noisy data if precision is not None: data = (data / precision).round() * precision # Compute levels for ssh @@ -684,6 +722,7 @@ def eddy_identification( ) z_min, z_max = z_min_p, z_max_p + logger.debug("Levels from %f to %f", z_min, z_max) levels = arange(z_min - z_min % step, z_max - z_max % step + 2 * step, step) # Get x and y values @@ -693,6 +732,7 @@ def eddy_identification( self.contours = Contours(x, y, data, levels, wrap_x=self.is_circular()) out_sampling = dict(fixed_size=sampling) + resample = visvalingam if sampling_method == "visvalingam" else uniform_resample track_extra_variables = [ "height_max_speed_contour", "height_external_contour", @@ -735,7 +775,8 @@ def eddy_identification( for contour in contour_paths: if contour.used: continue - # FIXME : center could be not in contour and fit on raw sampling + # FIXME : center could be outside the contour due to the fit + # FIXME : warning : the fit is made on raw sampling _, _, _, aerr = contour.fit_circle() # Filter for shape @@ -754,11 +795,12 @@ def eddy_identification( # Test of the rotating sense: cyclone or anticyclone if has_value( - data, i_x_in, i_y_in, cvalues, below=anticyclonic_search + data.data, i_x_in, i_y_in, cvalues, below=anticyclonic_search ): continue - # FIXME : Maybe limit max must be replace with a maximum of surface + # Test the number of pixels within the outermost contour + # FIXME : Maybe limit max must be replaced with a maximum of surface if ( contour.nb_pixel < pixel_limit[0] or contour.nb_pixel > pixel_limit[1] @@ -766,6 +808,9 @@ def eddy_identification( contour.reject = 3 continue + # Here the considered contour passed shape_error test, masked_pixels test, + # values strictly above (AEs) or below (CEs) the contour, number_pixels test) + # Compute amplitude reset_centroid, amp = self.get_amplitude( contour, @@ -781,13 +826,12 @@ def eddy_identification( contour.reject = 4 continue if reset_centroid: - if self.is_circular(): centi = self.normalize_x_indice(reset_centroid[0]) else: centi = reset_centroid[0] centj = reset_centroid[1] - # To move in regular and unregular grid + # FIXME : To move in regular and unregular grid if len(x.shape) == 1: centlon_e = x[centi] centlat_e = y[centj] @@ -795,7 +839,7 @@ def eddy_identification( centlon_e = x[centi, centj] centlat_e = y[centi, centj] - # centlat_e and centlon_e must be index of maximum, we will loose some inner contour if it's not + # centlat_e and centlon_e must be indexes of maximum, we will loose some inner contour if it's not ( max_average_speed, speed_contour, @@ -813,7 +857,7 @@ def eddy_identification( pixel_min=pixel_limit[0], ) - # FIXME : Instantiate new EddyObservation object 
(high cost need to be reviewed) + # FIXME : Instantiate new EddyObservation object (high cost, need to be reviewed) obs = EddiesObservations( size=1, track_extra_variables=track_extra_variables, @@ -832,42 +876,59 @@ def eddy_identification( obs.amplitude[:] = amp.amplitude obs.speed_average[:] = max_average_speed obs.num_point_e[:] = contour.lon.shape[0] - xy_e = uniform_resample(contour.lon, contour.lat, **out_sampling) - obs.contour_lon_e[:], obs.contour_lat_e[:] = xy_e obs.num_point_s[:] = speed_contour.lon.shape[0] + + # Evenly resample contours with nb_pts = nb_pts_original x presampling_multiplier + xy_i = uniform_resample( + inner_contour.lon, + inner_contour.lat, + num_fac=presampling_multiplier, + ) + xy_e = uniform_resample( + contour.lon, + contour.lat, + num_fac=presampling_multiplier, + ) xy_s = uniform_resample( - speed_contour.lon, speed_contour.lat, **out_sampling + speed_contour.lon, + speed_contour.lat, + num_fac=presampling_multiplier, ) - obs.contour_lon_s[:], obs.contour_lat_s[:] = xy_s - # FIXME : we use a contour without resampling - # First, get position based on innermost contour - centlon_i, centlat_i, _, _ = _fit_circle_path( - create_vertice(inner_contour.lon, inner_contour.lat) - ) - # Second, get speed-based radius based on contour of max uavg + # First, get position of max SSH based on best fit circle with resampled innermost contour + centlon_i, centlat_i, _, _ = _fit_circle_path(create_vertice(*xy_i)) + obs.lon_max[:] = centlon_i + obs.lat_max[:] = centlat_i + + # Second, get speed-based radius, shape error, eddy center, area based on resampled contour of max uavg centlon_s, centlat_s, eddy_radius_s, aerr_s = _fit_circle_path( create_vertice(*xy_s) ) - # Compute again to use resampled contour - _, _, eddy_radius_e, aerr_e = _fit_circle_path( - create_vertice(*xy_e) - ) - obs.radius_s[:] = eddy_radius_s - obs.radius_e[:] = eddy_radius_e - obs.shape_error_e[:] = aerr_e obs.shape_error_s[:] = aerr_s obs.speed_area[:] = poly_area( *coordinates_to_local(*xy_s, lon0=centlon_s, lat0=centlat_s) ) + obs.lon[:] = centlon_s + obs.lat[:] = centlat_s + + # Third, compute effective radius, shape error, area from resampled effective contour + _, _, eddy_radius_e, aerr_e = _fit_circle_path( + create_vertice(*xy_e) + ) + obs.radius_e[:] = eddy_radius_e + obs.shape_error_e[:] = aerr_e obs.effective_area[:] = poly_area( *coordinates_to_local(*xy_e, lon0=centlon_s, lat0=centlat_s) ) - obs.lon[:] = centlon_s - obs.lat[:] = centlat_s - obs.lon_max[:] = centlon_i - obs.lat_max[:] = centlat_i + + # Finally, resample contours with output parameters + xy_e_f = resample(*xy_e, **out_sampling) + xy_s_f = resample(*xy_s, **out_sampling) + + obs.contour_lon_s[:], obs.contour_lat_s[:] = xy_s_f + obs.contour_lon_e[:], obs.contour_lat_e[:] = xy_e_f + if aerr > 99.9 or aerr_s > 99.9: logger.warning( "Strange shape at this step! 
shape_error : %f, %f", @@ -929,7 +990,7 @@ def get_uavg( pixel_min=3, ): """ - Calculate geostrophic speed around successive contours + Compute geostrophic speed around successive contours Returns the average """ # Init max speed to search maximum @@ -1041,7 +1102,7 @@ def load(self): @property def bounds(self): - """Give bound""" + """Give bounds""" return self.x_c.min(), self.x_c.max(), self.y_c.min(), self.y_c.max() def bbox_indice(self, vertices): @@ -1073,7 +1134,7 @@ def compute_pixel_path(self, x0, y0, x1, y1): pass def init_pos_interpolator(self): - logger.debug("Create a KdTree could be long ...") + logger.debug("Create a KdTree, could be long ...") self.index_interp = cKDTree( create_vertice(self.x_c.reshape(-1), self.y_c.reshape(-1)) ) @@ -1091,7 +1152,7 @@ def _low_filter(self, grid_name, w_cut, factor=8.0): bins = (x_array, y_array) x_flat, y_flat, z_flat = x.reshape((-1,)), y.reshape((-1,)), data.reshape((-1,)) - m = ~z_flat.mask + m = ~self.get_mask(z_flat) x_flat, y_flat, z_flat = x_flat[m], y_flat[m], z_flat[m] nb_value, _, _ = histogram2d(x_flat, y_flat, bins=bins) @@ -1154,11 +1215,24 @@ def __init__(self, *args, **kwargs): def setup_coordinates(self): super().setup_coordinates() self.x_size = self.x_c.shape[0] + if len(self.x_c.shape) != 1: + raise Exception( + "Coordinates in RegularGridDataset must be 1D array, or think to use UnRegularGridDataset" + ) + dx = self.x_bounds[1:] - self.x_bounds[:-1] + dy = self.y_bounds[1:] - self.y_bounds[:-1] + if (dx < 0).any() or (dy < 0).any(): + raise Exception( + "Coordinates in RegularGridDataset must be strictly increasing" + ) self._x_step = (self.x_c[1:] - self.x_c[:-1]).mean() self._y_step = (self.y_c[1:] - self.y_c[:-1]).mean() @classmethod def with_array(cls, coordinates, datas, variables_description=None, **kwargs): + """ + Geo matrix data must be ordered like this (X,Y) and masked with numpy.ma.array + """ vd = dict() if variables_description is None else variables_description x_name, y_name = coordinates[0], coordinates[1] obj = cls("array", x_name, y_name, unset=True, **kwargs) @@ -1182,11 +1256,6 @@ def with_array(cls, coordinates, datas, variables_description=None, **kwargs): obj.setup_coordinates() return obj - def init_pos_interpolator(self): - """Create function to have a quick index interpolator""" - self.xinterp = arange(self.x_bounds.shape[0]) - self.yinterp = arange(self.y_bounds.shape[0]) - def bbox_indice(self, vertices): return bbox_indice_regular( vertices, @@ -1201,10 +1270,10 @@ def bbox_indice(self, vertices): def get_pixels_in(self, contour): """ - Get indices of pixels in contour. + Get indexes of pixels in contour. 
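        For instance (illustrative sketch only; ``grid`` is assumed to be a :py:class:`RegularGridDataset` and ``"adt"`` a hypothetical variable name):

        .. code-block:: python

            # Indexes of the pixels enclosed by the contour (lon/lat vertices or a Path)
            i, j = grid.get_pixels_in(contour)
            # Grid values inside the contour, e.g. to test them against a threshold
            values_inside = grid.grid("adt")[i, j]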
- :param vertice,Path contour: Contour which enclosed some pixels - :return: Indices of grid in contour + :param vertice,Path contour: Contour that encloses some pixels + :return: Indexes of grid in contour :rtype: array[int],array[int] """ if isinstance(contour, BasePath): @@ -1237,7 +1306,7 @@ def ystep(self): return self._y_step def compute_pixel_path(self, x0, y0, x1, y1): - """Give a series of indexes which describe the path between to position""" + """Give a series of indexes describing the path between two positions""" return compute_pixel_path( x0, y0, @@ -1250,9 +1319,13 @@ def compute_pixel_path(self, x0, y0, x1, y1): self.x_size, ) - def clean_land(self): + def clean_land(self, name): """Function to remove all land pixel""" - pass + mask_land = self.__class__(get_demo_path("mask_1_60.nc"), "lon", "lat") + x, y = meshgrid(self.x_c, self.y_c) + m = mask_land.interp("mask", x.reshape(-1), y.reshape(-1), "nearest") + data = self.grid(name) + self.vars[name] = ma.array(data, mask=m.reshape(x.shape).T) def is_circular(self): """Check if the grid is circular""" @@ -1325,6 +1398,24 @@ def kernel_lanczos(self, lat, wave_length, order=1): kernel[dist_norm > order] = 0 return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt) + def kernel_loess(self, lat, wave_length, order=1): + """ + https://fr.wikipedia.org/wiki/R%C3%A9gression_locale + """ + order = self.check_order(order) + half_x_pt, half_y_pt, dist_norm = self.estimate_kernel_shape( + lat, wave_length, order + ) + + def inc_func(xdist): + f = zeros(xdist.size) + f[abs(xdist) < 1] = 1 + return f + + kernel = (1 - abs(dist_norm) ** 3) ** 3 + kernel[abs(dist_norm) > order] = 0 + return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt) + def kernel_bessel(self, lat, wave_length, order=1): """wave_length in km order must be int @@ -1400,7 +1491,8 @@ def convolve_filter_with_dynamic_kernel( tmp_matrix = ma.zeros((2 * d_lon + data.shape[0], k_shape[1])) tmp_matrix.mask = ones(tmp_matrix.shape, dtype=bool) # Slice to apply on input data - sl_lat_data = slice(max(0, i - d_lat), min(i + d_lat, data.shape[1])) + # +1 for upper bound, to take in acount this column + sl_lat_data = slice(max(0, i - d_lat), min(i + d_lat + 1, data.shape[1])) # slice to apply on temporary matrix to store input data sl_lat_in = slice( d_lat - (i - sl_lat_data.start), d_lat + (sl_lat_data.stop - i) @@ -1480,13 +1572,13 @@ def bessel_high_filter(self, grid_name, wave_length, order=1, lat_max=85, **kwar :param str grid_name: grid to filter, data will replace original one :param float wave_length: in km :param int order: order to use, if > 1 negative values of the cardinal sinus are present in kernel - :param float lat_max: absolute latitude above no filtering apply + :param float lat_max: absolute latitude, no filtering above :param dict kwargs: look at :py:meth:`RegularGridDataset.convolve_filter_with_dynamic_kernel` .. 
minigallery:: py_eddy_tracker.RegularGridDataset.bessel_high_filter """ logger.debug( - "Run filtering with wave of %(wave_length)s km and order of %(order)s ...", + "Run filtering with wavelength of %(wave_length)s km and order of %(order)s ...", dict(wave_length=wave_length, order=order), ) data_out = self.convolve_filter_with_dynamic_kernel( @@ -1578,7 +1670,7 @@ def spectrum_lonlat(self, grid_name, area=None, ref=None, **kwargs): (lat_content[0], lat_content[1] / ref_lat_content[1]), ) - def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=False): + def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=False, second=False): if not isinstance(schema, int) and schema < 1: raise Exception("schema must be a positive int") @@ -1601,12 +1693,17 @@ def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=Fal data1[-schema:] = nan data2[:schema] = nan + # Distance for one degree d = self.EARTH_RADIUS * 2 * pi / 360 * 2 * schema + # Mulitply by 2 step if vertical: - d *= self.ystep + d *= self.ystep else: d *= self.xstep * cos(deg2rad(self.y_c)) - return (data1 - data2) / d + if second: + return (data1 + data2 - 2 * data) / (d ** 2 / 4) + else: + return (data1 - data2) / d def compute_stencil( self, data, stencil_halfwidth=4, mode="reflect", vertical=False @@ -1679,85 +1776,28 @@ def compute_stencil( ... - """ stencil_halfwidth = max(min(int(stencil_halfwidth), 4), 1) logger.debug("Stencil half width apply : %d", stencil_halfwidth) - # output - grad = None - - weights = [ - array((3, -32, 168, -672, 0, 672, -168, 32, -3)) / 840.0, - array((-1, 9, -45, 0, 45, -9, 1)) / 60.0, - array((1, -8, 0, 8, -1)) / 12.0, - array((-1, 0, 1)) / 2.0, - # uncentered kernel - # like array((0, -1, 1)) but left value could be default value - array((-1, 1)), - # like array((-1, 1, 0)) but right value could be default value - (1, array((-1, 1))), - ] - # reduce to stencil selected - weights = weights[4 - stencil_halfwidth :] - if vertical: - data = data.T - # Iteration from larger stencil to smaller (to fill matrix) - for weight in weights: - if isinstance(weight, tuple): - # In the case of unbalanced diff - shift, weight = weight - data_ = data.copy() - data_[shift:] = data[:-shift] - if not vertical: - data_[:shift] = data[-shift:] - else: - data_ = data - # Delta h - d_h = convolve(data_, weights=weight.reshape((-1, 1)), mode=mode) - mask = convolve( - int8(data_.mask), weights=ones(weight.shape).reshape((-1, 1)), mode=mode - ) - d_h = ma.array(d_h, mask=mask != 0) - - # Delta d - if vertical: - d_h = d_h.T - d = self.EARTH_RADIUS * 2 * pi / 360 * convolve(self.y_c, weight) - else: - if mode == "wrap": - # Along x axis, we need to close - # we will compute in two part - x = self.x_c % 360 - d_degrees = convolve(x, weight, mode=mode) - d_degrees_180 = convolve((x + 180) % 360 - 180, weight, mode=mode) - # Arbitrary, to be sure to be far far away of bound - m = (x < 90) + (x > 270) - d_degrees[m] = d_degrees_180[m] - else: - d_degrees = convolve(self.x_c, weight, mode=mode) - d = ( - self.EARTH_RADIUS - * 2 - * pi - / 360 - * d_degrees.reshape((-1, 1)) - * cos(deg2rad(self.y_c)) - ) - if grad is None: - # First Gradient - grad = d_h / d - else: - # Fill hole - grad[grad.mask] = (d_h / d)[grad.mask] - return grad + g, m = compute_stencil( + self.x_c, + self.y_c, + data.data, + self.get_mask(data), + self.EARTH_RADIUS, + vertical=vertical, + stencil_halfwidth=stencil_halfwidth, + ) + return ma.array(g, mask=m) - def add_uv_lagerloef(self, grid_height, 
uname="u", vname="v", schema=15): - self.add_uv(grid_height, uname, vname) + def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15, **kwargs): + self.add_uv(grid_height, uname, vname, **kwargs) latmax = 5 - _, (i_start, i_end) = self.nearest_grd_indice((0, 0), (-latmax, latmax)) + _, i_start = self.nearest_grd_indice(0, -latmax) + _, i_end = self.nearest_grd_indice(0, latmax) sl = slice(i_start, i_end) # Divide by sideral day - lat = self.y_c[sl] + lat = self.y_c gob = ( cos(deg2rad(lat)) * ones((self.x_c.shape[0], 1)) @@ -1771,48 +1811,48 @@ def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15): mode = "wrap" if self.is_circular() else "reflect" # fill data to compute a finite difference on all point - data = self.convolve_filter_with_dynamic_kernel( - grid_height, - self.kernel_bessel, - lat_max=10, - wave_length=500, - order=1, - extend=0.1, - ) - data = self.convolve_filter_with_dynamic_kernel( - data, self.kernel_bessel, lat_max=10, wave_length=500, order=1, extend=0.1 - ) - data = self.convolve_filter_with_dynamic_kernel( - data, self.kernel_bessel, lat_max=10, wave_length=500, order=1, extend=0.1 - ) + kw_filter = dict(kernel_func=self.kernel_bessel, order=1, extend=.1) + data = self.convolve_filter_with_dynamic_kernel(grid_height, wave_length=500, **kw_filter, lat_max=6+5+2+3) v_lagerloef = ( self.compute_finite_difference( - self.compute_finite_difference(data, mode=mode, schema=schema), - mode=mode, - schema=schema, - )[:, sl] - * gob - ) - u_lagerloef = ( - -self.compute_finite_difference( - self.compute_finite_difference(data, vertical=True, schema=schema), - vertical=True, - schema=schema, - )[:, sl] + self.compute_finite_difference(data, mode=mode, schema=1), + vertical=True, schema=1 + ) * gob ) - w = 1 - exp(-((lat / 2.2) ** 2)) - self.vars[vname][:, sl] = self.vars[vname][:, sl] * w + v_lagerloef * (1 - w) - self.vars[uname][:, sl] = self.vars[uname][:, sl] * w + u_lagerloef * (1 - w) + u_lagerloef = -self.compute_finite_difference(data, vertical=True, schema=schema, second=True) * gob + + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=195, **kw_filter, lat_max=6 + 5 +2) + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=416, **kw_filter, lat_max=6 + 5) + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=416, **kw_filter, lat_max=6) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=195, **kw_filter, lat_max=6 + 5 +2) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=416, **kw_filter, lat_max=6 + 5) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=416, **kw_filter, lat_max=6) + w = 1 - exp(-((lat[sl] / 2.2) ** 2)) + self.vars[vname][:, sl] = self.vars[vname][:, sl] * w + v_lagerloef[:, sl] * (1 - w) + self.vars[uname][:, sl] = self.vars[uname][:, sl] * w + u_lagerloef[:, sl] * (1 - w) def add_uv(self, grid_height, uname="u", vname="v", stencil_halfwidth=4): - """Compute a u and v grid + r"""Compute a u and v grid :param str grid_height: grid name where the funtion will apply stencil method :param str uname: future name of u :param str vname: future name of v :param int stencil_halfwidth: largest stencil could be apply (max: 4) + .. math:: + u = \frac{g}{f} \frac{dh}{dy} + + v = -\frac{g}{f} \frac{dh}{dx} + + where + + .. math:: + g = gravity + + f = 2 \Omega sin(\phi) + + .. 
minigallery:: py_eddy_tracker.RegularGridDataset.add_uv """ logger.info("Add u/v variable with stencil method") @@ -1857,13 +1897,13 @@ def add_uv(self, grid_height, uname="u", vname="v", stencil_halfwidth=4): ) def speed_coef_mean(self, contour): - """Some nan can be computed over contour if we are near border, + """Some nan can be computed over contour if we are near borders, something to explore """ return mean_on_regular_contour( self.x_c, self.y_c, - self._speed_ev, + self._speed_ev.data, self._speed_ev.mask, contour.vertices, nan_remove=True, @@ -1871,14 +1911,15 @@ def speed_coef_mean(self, contour): def init_speed_coef(self, uname="u", vname="v"): """Draft""" - self._speed_ev = (self.grid(uname) ** 2 + self.grid(vname) ** 2) ** 0.5 + u, v = self.grid(uname), self.grid(vname) + self._speed_ev = sqrt(u * u + v * v) def display(self, ax, name, factor=1, ref=None, **kwargs): """ - :param matplotlib.axes.Axes ax: matplotlib axes use to draw + :param matplotlib.axes.Axes ax: matplotlib axes used to draw :param str,array name: variable to display, could be an array :param float factor: multiply grid by - :param float,None ref: if define use like west bound + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.pcolormesh` .. minigallery:: py_eddy_tracker.RegularGridDataset.display @@ -1897,10 +1938,10 @@ def display(self, ax, name, factor=1, ref=None, **kwargs): def contour(self, ax, name, factor=1, ref=None, **kwargs): """ - :param matplotlib.axes.Axes ax: matplotlib axes use to draw + :param matplotlib.axes.Axes ax: matplotlib axes used to draw :param str,array name: variable to display, could be an array :param float factor: multiply grid by - :param float,None ref: if define use like west bound + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.contour` .. 
minigallery:: py_eddy_tracker.RegularGridDataset.contour @@ -1951,53 +1992,293 @@ def interp(self, grid_name, lons, lats, method="bilinear"): :return: new z """ g = self.grid(grid_name) - if len(g.mask.shape): - m = g.mask - else: - m = ones(g.shape) if g.mask else zeros(g.shape) + m = self.get_mask(g) return interp2d_geo( - self.x_c, self.y_c, g, m, lons, lats, nearest=method == "nearest" + self.x_c, self.y_c, g.data, m, lons, lats, nearest=method == "nearest" ) + def uv_for_advection( + self, + u_name=None, + v_name=None, + time_step=600, + h_name=None, + backward=False, + factor=1, + ): + """ + Get U,V to be used in degrees with precomputed time step + + :param None,str,array u_name: U field to advect obs, if h_name is None + :param None,str,array v_name: V field to advect obs, if h_name is None + :param None,str,array h_name: H field to compute UV to advect obs, if u_name and v_name are None + :param int time_step: Number of second for each advection + """ + if h_name is not None: + u_name, v_name = "u", "v" + if u_name not in self.vars: + self.add_uv(h_name) + self.vars.pop(h_name, None) + + u = (self.grid(u_name) if isinstance(u_name, str) else u_name).copy() + v = (self.grid(v_name) if isinstance(v_name, str) else v_name).copy() + # N seconds / 1 degrees in m + coef = time_step * 180 / pi / self.EARTH_RADIUS * factor + u *= coef / cos(radians(self.y_c)) + v *= coef + if backward: + u = -u + v = -v + m = u.mask + v.mask + return u.data, v.data, m + + def advect(self, x, y, u_name, v_name, nb_step=10, rk4=True, **kw): + """ + At each call it will update position in place with u & v field + + It's a dummy advection using only one layer of current + + :param array x: Longitude of obs to move + :param array y: Latitude of obs to move + :param str,array u_name: U field to advect obs + :param str,array v_name: V field to advect obs + :param int nb_step: Number of iterations before releasing data + + .. minigallery:: py_eddy_tracker.GridDataset.advect + """ + u, v, m = self.uv_for_advection(u_name, v_name, **kw) + m_p = isnan(x) + isnan(y) + advect_ = advect_rk4 if rk4 else advect + while True: + advect_(self.x_c, self.y_c, u, v, m, x, y, m_p, nb_step) + yield x, y + + def filament( + self, x, y, u_name, v_name, nb_step=10, filament_size=6, rk4=True, **kw + ): + """ + Produce filament with concatenation of advection + + It's a dummy advection using only one layer of current + + :param array x: Longitude of obs to move + :param array y: Latitude of obs to move + :param str,array u_name: U field to advect obs + :param str,array v_name: V field to advect obs + :param int nb_step: Number of iteration before releasing data + :param int filament_size: Number of point by filament + :return: x,y for a line + + .. 
minigallery:: py_eddy_tracker.GridDataset.filament + """ + u, v, m = self.uv_for_advection(u_name, v_name, **kw) + x, y = x.copy(), y.copy() + nb = x.shape[0] + filament_size_ = filament_size + 1 + f_x = empty(nb * filament_size_, dtype="f4") + f_y = empty(nb * filament_size_, dtype="f4") + f_x[:] = nan + f_y[:] = nan + f_x[::filament_size_] = x + f_y[::filament_size_] = y + mp = isnan(x) + isnan(y) + advect_ = advect_rk4 if rk4 else advect + while True: + # Shift position + f_x[1:] = f_x[:-1] + f_y[1:] = f_y[:-1] + # Remove last position + f_x[filament_size::filament_size_] = nan + f_y[filament_size::filament_size_] = nan + advect_(self.x_c, self.y_c, u, v, m, x, y, mp, nb_step) + f_x[::filament_size_] = x + f_y[::filament_size_] = y + yield f_x, f_y + + +@njit(cache=True) +def advect_rk4(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): + # Grid coordinates + x_ref, y_ref = x_g[0], y_g[0] + x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + nb_x_ = x_g.size + nb_x = nb_x_ if is_circular else 0 + # cache + i_cache, j_cache = -1000000, -1000000 + masked = False + u00, u01, u10, u11 = 0.0, 0.0, 0.0, 0.0 + v00, v01, v10, v11 = 0.0, 0.0, 0.0, 0.0 + # On each particle + for i in prange(x.size): + # If particle is not valid => continue + if m[i]: + continue + x_, y_ = x[i], y[i] + # Iterate on whole steps + for _ in range(nb_step): + # k1, slope at origin + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x_, y_, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + ii_, jj_, u_g, v_g, m_g, nb_x + ) + # The 3 following could be in cache operation but this one must be tested in any case + if masked: + x_, y_ = nan, nan + m[i] = True + break + u1, v1 = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # k2, slope at middle with first guess position + x1, y1 = x_ + u1 * 0.5, y_ + v1 * 0.5 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x1, y1, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + ii_, jj_, u_g, v_g, m_g, nb_x + ) + if masked: + x_, y_ = nan, nan + m[i] = True + break + u2, v2 = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # k3, slope at middle with updated guess position + x2, y2 = x_ + u2 * 0.5, y_ + v2 * 0.5 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x2, y2, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + ii_, jj_, u_g, v_g, m_g, nb_x + ) + if masked: + x_, y_ = nan, nan + m[i] = True + break + u3, v3 = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # k4, slope at end with updated guess position + x3, y3 = x_ + u3, y_ + v3 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x3, y3, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + ii_, jj_, u_g, v_g, m_g, nb_x + ) + if masked: + x_, y_ = nan, nan + m[i] = True + break + u4, v4 = 
interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # RK4 compute + dx = (u1 + 2 * u2 + 2 * u3 + u4) / 6 + dy = (v1 + 2 * v2 + 2 * v3 + v4) / 6 + # Compute new x,y + x_ += dx + y_ += dy + x[i] = x_ + y[i] = y_ + + +@njit(cache=True) +def advect(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): + # Grid coordinates + x_ref, y_ref = x_g[0], y_g[0] + x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + nb_x_ = x_g.size + nb_x = nb_x_ if is_circular else 0 + # Indexes which should be never exist + i0_old, j0_old = -100000, -100000 + masked = False + u00, u01, u10, u11 = 0.0, 0.0, 0.0, 0.0 + v00, v01, v10, v11 = 0.0, 0.0, 0.0, 0.0 + # On each particule + for i in prange(x.size): + # If particule is not valid => continue + if m[i]: + continue + # Iterate on whole steps + for _ in range(nb_step): + i0, j0, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x[i], y[i], nb_x + ) + # corners are the same, need only a new xd and yd + if i0 != i0_old or j0 != j0_old: + # Need to be stored only on change + i0_old, j0_old = i0, j0 + if not is_circular and (i0 < 0 or i0 > nb_x_): + masked = True + else: + masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( + i0, j0, u_g, v_g, m_g, nb_x + ) + if masked: + x[i], y[i] = nan, nan + m[i] = True + break + u, v = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) + # Compute new x,y + x[i] += u + y[i] += v + @njit(cache=True, fastmath=True) def compute_pixel_path(x0, y0, x1, y1, x_ori, y_ori, x_step, y_step, nb_x): - """Give a serie of indexes describing the path between two position""" + """Give a serie of indexes describing the path between two positions""" # index nx = x0.shape[0] i_x0 = empty(nx, dtype=numba_types.int_) i_x1 = empty(nx, dtype=numba_types.int_) i_y0 = empty(nx, dtype=numba_types.int_) i_y1 = empty(nx, dtype=numba_types.int_) - # Because round_ is not accepted with array in numba for i in range(nx): - i_x0[i] = round_(((x0[i] - x_ori) % 360) / x_step) - i_x1[i] = round_(((x1[i] - x_ori) % 360) / x_step) - i_y0[i] = round_((y0[i] - y_ori) / y_step) - i_y1[i] = round_((y1[i] - y_ori) / y_step) + i_x0[i] = np.round(((x0[i] - x_ori) % 360) / x_step) + i_x1[i] = np.round(((x1[i] - x_ori) % 360) / x_step) + i_y0[i] = np.round((y0[i] - y_ori) / y_step) + i_y1[i] = np.round((y1[i] - y_ori) / y_step) # Delta index of x d_x = i_x1 - i_x0 d_x = (d_x + nb_x // 2) % nb_x - (nb_x // 2) i_x1 = i_x0 + d_x # Delta index of y d_y = i_y1 - i_y0 - # max and abs sum doesn't work on array? + # max and abs sum do not work on array? 
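    # For each start/end pair, the path length is the Chebyshev distance
    # max(|d_x|, |d_y|); each pair contributes d_max + 1 pixels (both endpoints
    # included), and nb_value accumulates the total to size the output index arrays.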
d_max = empty(nx, dtype=numba_types.int32) nb_value = 0 for i in range(nx): d_max[i] = max(abs(d_x[i]), abs(d_y[i])) - # Compute number of pixel which we go trought + # Compute number of pixel we go trought nb_value += d_max[i] + 1 - # Create an empty array to store value of pixel across the travel + # Create an empty array to store value of pixel across the path i_g = empty(nb_value, dtype=numba_types.int32) j_g = empty(nb_value, dtype=numba_types.int32) # Index to determine the position in the global array ii = 0 - # Iteration on each travel + # Iteration on each path for i, delta in enumerate(d_max): - # If the travel don't cross multiple pixel + # If the path doesn't cross multiple pixels if delta == 0: i_g[ii : ii + delta + 1] = i_x0[i] j_g[ii : ii + delta + 1] = i_y0[i] @@ -2011,7 +2292,7 @@ def compute_pixel_path(x0, y0, x1, y1, x_ori, y_ori, x_step, y_step, nb_x): sup = -1 if d_x[i] < 0 else 1 i_g[ii : ii + delta + 1] = arange(i_x0[i], i_x1[i] + sup, sup) j_g[ii : ii + delta + 1] = i_y0[i] - # In case of multiple direction + # In case of multiple directions else: a = (i_x1[i] - i_x0[i]) / float(i_y1[i] - i_y0[i]) if abs(d_x[i]) >= abs(d_y[i]): @@ -2047,3 +2328,676 @@ def has_value(grid, i_x, i_y, value, below=False): if grid[i, j] > value: return True return False + + +class GridCollection: + def __init__(self): + self.datasets = list() + + @classmethod + def from_netcdf_cube(cls, filename, x_name, y_name, t_name, heigth=None, **kwargs): + new = cls() + with Dataset(filename) as h: + for i, t in enumerate(h.variables[t_name][:]): + d = RegularGridDataset( + filename, x_name, y_name, indexs={t_name: i}, **kwargs + ) + if heigth is not None: + d.add_uv(heigth) + new.datasets.append((t, d)) + return new + + @classmethod + def from_netcdf_list( + cls, filenames, t, x_name, y_name, indexs=None, heigth=None, **kwargs + ): + new = cls() + for i, _t in enumerate(t): + filename = filenames[i] + logger.debug(f"load file {i:02d}/{len(t)} t={_t} : {filename}") + d = RegularGridDataset(filename, x_name, y_name, indexs=indexs, **kwargs) + if heigth is not None: + d.add_uv(heigth) + new.datasets.append((_t, d)) + return new + + @property + def are_loaded(self): + return ~array([d.dimensions is None for _, d in self.datasets]) + + def __repr__(self): + nb_dataset = len(self.datasets) + return f"{self.are_loaded.sum()}/{nb_dataset} datasets loaded" + + def shift_files(self, t, filename, heigth=None, **rgd_kwargs): + """Add next file to the list and remove the oldest""" + + self.datasets = self.datasets[1:] + + d = RegularGridDataset(filename, **rgd_kwargs) + if heigth is not None: + d.add_uv(heigth) + self.datasets.append((t, d)) + logger.debug(f"shift and adding i={len(self.datasets)} t={t} : {filename}") + + def interp(self, grid_name, t, lons, lats, method="bilinear"): + """ + Compute z over lons, lats + + :param str grid_name: Grid to be interpolated + :param float, t: time for interpolation + :param lons: new x + :param lats: new y + :param str method: Could be 'bilinear' or 'nearest' + + :return: new z + """ + # FIXME: we do assumption on time step + t0 = int(t) + t1 = t0 + 1 + h0, h1 = self[t0], self[t1] + g0, g1 = h0.grid(grid_name), h1.grid(grid_name) + m0, m1 = h0.get_mask(g0), h0.get_mask(g1) + kw = dict(x=lons, y=lats, nearest=method == "nearest") + v0 = interp2d_geo(h0.x_c, h0.y_c, g0, m0, **kw) + v1 = interp2d_geo(h1.x_c, h1.y_c, g1, m1, **kw) + w = (t - t0) / (t1 - t0) + return v1 * w + v0 * (1 - w) + + def __iter__(self): + for _, d in self.datasets: + yield d + + @property + def 
time(self): + return array([t for t, _ in self.datasets]) + + @property + def period(self): + t = self.time + return t.min(), t.max() + + def __getitem__(self, item): + for t, d in self.datasets: + if t == item: + d.populate() + return d + raise KeyError(item) + + def filament( + self, + x, + y, + u_name, + v_name, + t_init, + nb_step=10, + time_step=600, + filament_size=6, + rk4=True, + **kw, + ): + """ + Produce filament with concatenation of advection + + :param array x: Longitude of obs to move + :param array y: Latitude of obs to move + :param str,array u_name: U field to advect obs + :param str,array v_name: V field to advect obs + :param int nb_step: Number of iteration before to release data + :param int time_step: Number of second for each advection + :param int filament_size: Number of point by filament + :return: x,y for a line + + .. minigallery:: py_eddy_tracker.GridCollection.filament + """ + x, y = x.copy(), y.copy() + nb = x.shape[0] + filament_size_ = filament_size + 1 + f_x = empty(nb * filament_size_, dtype="f4") + f_y = empty(nb * filament_size_, dtype="f4") + f_x[:] = nan + f_y[:] = nan + f_x[::filament_size_] = x + f_y[::filament_size_] = y + + backward = kw.get("backward", False) + if backward: + generator = self.get_previous_time_step(t_init) + dt = -nb_step * time_step + t_step = -time_step + else: + generator = self.get_next_time_step(t_init) + dt = nb_step * time_step + t_step = time_step + t0, d0 = generator.__next__() + u0, v0, m0 = d0.uv_for_advection(u_name, v_name, time_step, **kw) + t1, d1 = generator.__next__() + u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw) + t0 = t0 * 86400 + t1 = t1 * 86400 + t = t_init * 86400 + mp = isnan(x) + isnan(y) + advect_ = advect_t_rk4 if rk4 else advect_t + while True: + # Shift position + f_x[1:] = f_x[:-1] + f_y[1:] = f_y[:-1] + # Remove last position + f_x[filament_size::filament_size_] = nan + f_y[filament_size::filament_size_] = nan + + if (backward and t <= t1) or (not backward and t >= t1): + t0, u0, v0, m0 = t1, u1, v1, m1 + t1, d1 = generator.__next__() + t1 = t1 * 86400 + u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw) + w = 1 - (arange(t, t + dt, t_step) - t0) / (t1 - t0) + half_w = t_step / 2.0 / (t1 - t0) + advect_(d0.x_c, d0.y_c, u0, v0, m0, u1, v1, m1, x, y, mp, w, half_w=half_w) + f_x[::filament_size_] = x + f_y[::filament_size_] = y + t += dt + yield t, f_x, f_y + + def reset_grids(self, N=None): + if N is not None: + m = self.are_loaded + if m.sum() > N: + for i in where(m)[0]: + self.datasets[i][1].clean() + + def advect( + self, + x, + y, + t_init, + mask_particule=None, + nb_step=10, + time_step=600, + rk4=True, + reset_grid=None, + **kw, + ): + """ + At each call it will update position in place with u & v field + + :param array x: Longitude of obs to move + :param array y: Latitude of obs to move + :param float t_init: time to start advection + :param array,None mask_particule: advect only i mask is True + :param int nb_step: Number of iteration before to release data + :param int time_step: Number of second for each advection + :param bool rk4: Use rk4 algorithm instead of finite difference + :param int reset_grid: Delete all loaded data in cube if there are more than N grid loaded + + :return: t,x,y position + + .. 
minigallery:: py_eddy_tracker.GridCollection.advect + """ + self.reset_grids(reset_grid) + backward = kw.get("backward", False) + if backward: + generator = self.get_previous_time_step(t_init) + dt = -nb_step * time_step + t_step = -time_step + else: + generator = self.get_next_time_step(t_init) + dt = nb_step * time_step + t_step = time_step + t0, d0 = generator.__next__() + u0, v0, m0 = d0.uv_for_advection(time_step=time_step, **kw) + t1, d1 = generator.__next__() + u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw) + t0 = t0 * 86400 + t1 = t1 * 86400 + t = t_init * 86400 + advect_ = advect_t_rk4 if rk4 else advect_t + if mask_particule is None: + mask_particule = isnan(x) + isnan(y) + else: + mask_particule += isnan(x) + isnan(y) + while True: + logger.debug(f"advect : t={t/86400}") + if (backward and t <= t1) or (not backward and t >= t1): + t0, u0, v0, m0 = t1, u1, v1, m1 + t1, d1 = generator.__next__() + t1 = t1 * 86400 + u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw) + w = 1 - (arange(t, t + dt, t_step) - t0) / (t1 - t0) + half_w = t_step / 2.0 / (t1 - t0) + advect_( + d0.x_c, + d0.y_c, + u0, + v0, + m0, + u1, + v1, + m1, + x, + y, + mask_particule, + w, + half_w=half_w, + ) + t += dt + yield t, x, y + + def get_next_time_step(self, t_init): + for i, (t, dataset) in enumerate(self.datasets): + if t < t_init: + continue + dataset.populate() + logger.debug(f"i={i}, t={t}, dataset={dataset}") + yield t, dataset + + def get_previous_time_step(self, t_init): + i = len(self.datasets) + for t, dataset in reversed(self.datasets): + i -= 1 + if t > t_init: + continue + dataset.populate() + logger.debug(f"i={i}, t={t}, dataset={dataset}") + yield t, dataset + + def path(self, x0, y0, *args, nb_time=2, **kwargs): + """ + At each call it will update position in place with u & v field + + :param array x0: Longitude of obs to move + :param array y0: Latitude of obs to move + :param int nb_time: Number of iteration for particle + :param dict kwargs: look at :py:meth:`GridCollection.advect` + + :return: t,x,y + + .. 
minigallery:: py_eddy_tracker.GridCollection.path + """ + particles = self.advect(x0.copy(), y0.copy(), *args, **kwargs) + t = empty(nb_time + 1, dtype="f8") + x = empty((nb_time + 1, x0.size), dtype=x0.dtype) + y = empty(x.shape, dtype=y0.dtype) + t[0], x[0], y[0] = kwargs.get("t_init"), x0, y0 + for i in range(nb_time): + t[i + 1], x[i + 1], y[i + 1] = particles.__next__() + return t, x, y + + +@njit(cache=True) +def advect_t(x_g, y_g, u_g0, v_g0, m_g0, u_g1, v_g1, m_g1, x, y, m, weigths, half_w=0): + # Grid coordinates + x_ref, y_ref = x_g[0], y_g[0] + x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + nb_x_ = x_g.size + nb_x = nb_x_ if is_circular else 0 + # Indexes that should never exist + i0_old, j0_old = -100000, -100000 + m0, m1 = False, False + u000, u001, u010, u011 = 0.0, 0.0, 0.0, 0.0 + v000, v001, v010, v011 = 0.0, 0.0, 0.0, 0.0 + u100, u101, u110, u111 = 0.0, 0.0, 0.0, 0.0 + v100, v101, v110, v111 = 0.0, 0.0, 0.0, 0.0 + # On each particle + for i in prange(x.size): + # If particle is not valid => continue + if m[i]: + continue + # Iterate on whole steps + for w in weigths: + i0, j0, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x[i], y[i], nb_x + ) + if i0 != i0_old or j0 != j0_old: + # Need to be stored only on change + i0_old, j0_old = i0, j0 + if not is_circular and (i0 < 0 or i0 > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + i0, j0, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + i0, j0, u_g1, v_g1, m_g1, nb_x + ) + if m0 or m1: + x[i], y[i] = nan, nan + m[i] = True + break + # Compute distance + xd_i, yd_i = 1 - xd, 1 - yd + # Compute new x,y + dx0 = (u000 * xd_i + u010 * xd) * yd_i + (u001 * xd_i + u011 * xd) * yd + dx1 = (u100 * xd_i + u110 * xd) * yd_i + (u101 * xd_i + u111 * xd) * yd + dy0 = (v000 * xd_i + v010 * xd) * yd_i + (v001 * xd_i + v011 * xd) * yd + dy1 = (v100 * xd_i + v110 * xd) * yd_i + (v101 * xd_i + v111 * xd) * yd + x[i] += dx0 * w + dx1 * (1 - w) + y[i] += dy0 * w + dy1 * (1 - w) + + +@njit(cache=True, fastmath=True) +def get_uv_quad(i0, j0, u, v, m, nb_x=0): + """ + Return u/v for (i0, j0), (i1, j0), (i0, j1), (i1, j1) + + :param int i0: indexes of longitude + :param int j0: indexes of latitude + :param array[float] u: current along x axis + :param array[float] v: current along y axis + :param array[bool] m: flag to know if position is valid + :param int nb_x: If different of 0 we check if wrapping is needed + + :return: if cell is valid 4 u, 4 v + :rtype: bool,float,float,float,float,float,float,float,float + """ + i1, j1 = i0 + 1, j0 + 1 + if nb_x != 0: + i1 %= nb_x + i_max, j_max = m.shape + + if i1 >= i_max or j1 >= j_max: + return True, nan, nan, nan, nan, nan, nan, nan, nan + + if m[i0, j0] or m[i0, j1] or m[i1, j0] or m[i1, j1]: + return True, nan, nan, nan, nan, nan, nan, nan, nan + # Extract value for u and v + u00, u01, u10, u11 = u[i0, j0], u[i0, j1], u[i1, j0], u[i1, j1] + v00, v01, v10, v11 = v[i0, j0], v[i0, j1], v[i1, j0], v[i1, j1] + return False, u00, u01, u10, u11, v00, v01, v10, v11 + + +@njit(cache=True, fastmath=True) +def get_grid_indices(x0, y0, x_step, y_step, x, y, nb_x=0): + """ + Return grid indexes and weight + + :param float x0: first longitude + :param float y0: first latitude + :param float x_step: longitude grid step + :param float y_step: latitude grid step + :param float x: longitude to interpolate + :param float y: 
latitude to interpolate + :param int nb_x: If different of 0 we check if wrapping is needed + + :return: indexes and weight + :rtype: int,int,float,float + """ + i, j = (x - x0) / x_step, (y - y0) / y_step + i0, j0 = int(floor(i)), int(floor(j)) + xd, yd = i - i0, j - j0 + if nb_x != 0: + i0 %= nb_x + return i0, j0, xd, yd + + +@njit(cache=True, fastmath=True) +def interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11): + """ + Return u/v interpolated in cell + + :param float xd: x weight + :param float yd: y weight + :param float u00: u lower left + :param float u01: u upper left + :param float u10: u lower right + :param float u11: u upper right + :param float v00: v lower left + :param float v01: v upper left + :param float v10: v lower right + :param float v11: v upper right + """ + xd_i, yd_i = 1 - xd, 1 - yd + u = (u00 * xd_i + u10 * xd) * yd_i + (u01 * xd_i + u11 * xd) * yd + v = (v00 * xd_i + v10 * xd) * yd_i + (v01 * xd_i + v11 * xd) * yd + return u, v + + +@njit(cache=True, fastmath=True) +def advect_t_rk4( + x_g, y_g, u_g0, v_g0, m_g0, u_g1, v_g1, m_g1, x, y, m, weigths, half_w +): + # Grid coordinates + x_ref, y_ref = x_g[0], y_g[0] + x_step, y_step = x_g[1] - x_ref, y_g[1] - y_ref + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + nb_x_ = x_g.size + nb_x = nb_x_ if is_circular else 0 + # cache + i_cache, j_cache = -1000000, -1000000 + m0, m1 = False, False + u000, u001, u010, u011 = 0.0, 0.0, 0.0, 0.0 + v000, v001, v010, v011 = 0.0, 0.0, 0.0, 0.0 + u100, u101, u110, u111 = 0.0, 0.0, 0.0, 0.0 + v100, v101, v110, v111 = 0.0, 0.0, 0.0, 0.0 + # On each particle + for i in prange(x.size): + # If particle is not valid => continue + if m[i]: + continue + x_, y_ = x[i], y[i] + # Iterate on whole steps + for w in weigths: + # k1, slope at origin + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x_, y_, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + ii_, jj_, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + ii_, jj_, u_g1, v_g1, m_g1, nb_x + ) + # The 3 following could be in cache operation but this one must be tested in any case + if m0 or m1: + x_, y_ = nan, nan + m[i] = True + break + u0_, v0_ = interp_uv(xd, yd, u000, u001, u010, u011, v000, v001, v010, v011) + u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) + u1, v1 = u0_ * w + u1_ * (1 - w), v0_ * w + v1_ * (1 - w) + # k2, slope at middle with first guess position + x1, y1 = x_ + u1 * 0.5, y_ + v1 * 0.5 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x1, y1, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + ii_, jj_, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + ii_, jj_, u_g1, v_g1, m_g1, nb_x + ) + if m0 or m1: + x_, y_ = nan, nan + m[i] = True + break + u0_, v0_ = interp_uv(xd, yd, u000, u001, u010, u011, v000, v001, v010, v011) + u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) + w_ = w - half_w + u2, v2 = u0_ * w_ + u1_ * (1 - w_), v0_ * w_ + v1_ * (1 - w_) + # k3, slope at middle with updated guess position + x2, y2 = x_ + u2 * 0.5, y_ + v2 * 
0.5 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x2, y2, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + ii_, jj_, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + ii_, jj_, u_g1, v_g1, m_g1, nb_x + ) + if m0 or m1: + x_, y_ = nan, nan + m[i] = True + break + u0_, v0_ = interp_uv(xd, yd, u000, u001, u010, u011, v000, v001, v010, v011) + u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) + u3, v3 = u0_ * w_ + u1_ * (1 - w_), v0_ * w_ + v1_ * (1 - w_) + # k4, slope at end with updated guess position + x3, y3 = x_ + u3, y_ + v3 + ii_, jj_, xd, yd = get_grid_indices( + x_ref, y_ref, x_step, y_step, x3, y3, nb_x + ) + if ii_ != i_cache or jj_ != j_cache: + i_cache, j_cache = ii_, jj_ + if not is_circular and (ii_ < 0 or ii_ > nb_x_): + m0, m1 = True, True + else: + (m0, u000, u001, u010, u011, v000, v001, v010, v011) = get_uv_quad( + ii_, jj_, u_g0, v_g0, m_g0, nb_x + ) + (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( + ii_, jj_, u_g1, v_g1, m_g1, nb_x + ) + if m0 or m1: + x_, y_ = nan, nan + m[i] = True + break + u0_, v0_ = interp_uv(xd, yd, u000, u001, u010, u011, v000, v001, v010, v011) + u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) + w_ -= half_w + u4, v4 = u0_ * w_ + u1_ * (1 - w_), v0_ * w_ + v1_ * (1 - w_) + # RK4 compute + dx = (u1 + 2 * u2 + 2 * u3 + u4) / 6 + dy = (v1 + 2 * v2 + 2 * v3 + v4) / 6 + x_ += dx + y_ += dy + x[i], y[i] = x_, y_ + + +@njit( + [ + "Tuple((f8[:,:],b1[:,:]))(f8[:],f8[:],f8[:,:],b1[:,:],f8,b1,i1)", + "Tuple((f4[:,:],b1[:,:]))(f8[:],f8[:],f4[:,:],b1[:,:],f8,b1,i1)", + ], + cache=True, + fastmath=True, +) +def compute_stencil(x, y, h, m, earth_radius, vertical=False, stencil_halfwidth=4): + """ + Compute stencil on RegularGrid + + :param array x: longitude coordinates + :param array y: latitude coordinates + :param array h: 2D array to derivate + :param array m: mask associated to h to know where are invalid data + :param float earth_radius: Earth radius in m + :param bool vertical: if True stencil will be vertical (along y) + :param int stencil_halfwidth: from 1 to 4 to specify maximal kernel usable + + + stencil_halfwidth: + + - (1) : + + - (-1, 1, 0) + - (0, -1, 1) + - (-1, 0, 1) / 2 + + - (2) : (1, -8, 0, 8, 1) / 12 + - (3) : (-1, 9, -45, 0, 45, -9, 1) / 60 + - (4) : (3, -32, 168, -672, 0, 672, -168, 32, 3) / 840 + """ + if vertical: + # If vertical we transpose matrix and inverse coordinates + h = h.T + m = m.T + x, y = y, x + shape = h.shape + nb_x, nb_y = shape + # Out array + m_out = empty(shape, dtype=numba_types.bool_) + grad = empty(shape, dtype=h.dtype) + # Distance step in degrees + d_step = x[1] - x[0] + if vertical: + is_circular = False + else: + # Test if matrix is circular + is_circular = abs(x[-1] % 360 - (x[0] - d_step) % 360) < 1e-5 + + # Compute caracteristic distance, constant when vertical + d_ = 360 / (d_step * pi * 2 * earth_radius) + for j in range(nb_y): + # Buffer of maximal size of stencil (9) + if is_circular: + h_3, h_2, h_1, h0 = h[-4, j], h[-3, j], h[-2, j], h[-1, j] + m_3, m_2, m_1, m0 = m[-4, j], m[-3, j], m[-2, j], m[-1, j] + else: + m_3, m_2, m_1, m0 = True, True, True, True + h1, h2, h3, h4 = h[0, j], h[1, j], h[2, j], h[3, j] + m1, m2, m3, m4 = m[0, j], m[1, j], m[2, j], m[3, j] + for i in 
range(nb_x): + # Roll value and only last + h_4, h_3, h_2, h_1, h0, h1, h2, h3 = h_3, h_2, h_1, h0, h1, h2, h3, h4 + m_4, m_3, m_2, m_1, m0, m1, m2, m3 = m_3, m_2, m_1, m0, m1, m2, m3, m4 + i_ = i + 4 + if i_ >= nb_x: + if is_circular: + i_ = i_ % nb_x + m4 = m[i_, j] + h4 = h[i_, j] + else: + # When we are out + m4 = False + else: + m4 = m[i_, j] + h4 = h[i_, j] + + # Current value not defined + if m0: + m_out[i, j] = True + continue + if not vertical: + # For each row we compute distance + d_ = 360 / (d_step * cos(deg2rad(y[j])) * pi * 2 * earth_radius) + if m1 ^ m_1: + # unbalanced kernel + if m_1: + grad[i, j] = (h1 - h0) * d_ + m_out[i, j] = False + continue + if m1: + grad[i, j] = (h0 - h_1) * d_ + m_out[i, j] = False + continue + continue + if m2 or m_2 or stencil_halfwidth == 1: + grad[i, j] = (h1 - h_1) / 2 * d_ + m_out[i, j] = False + continue + if m3 or m_3 or stencil_halfwidth == 2: + grad[i, j] = (h_2 - h2 + 8 * (h1 - h_1)) / 12 * d_ + m_out[i, j] = False + continue + if m4 or m_4 or stencil_halfwidth == 3: + grad[i, j] = (h3 - h_3 + 9 * (h_2 - h2) + 45 * (h1 - h_1)) / 60 * d_ + m_out[i, j] = False + continue + # If all values of buffer are available + grad[i, j] = ( + (3 * (h_4 - h4) + 32 * (h3 - h_3) + 168 * (h_2 - h2) + 672 * (h1 - h_1)) + / 840 + * d_ + ) + m_out[i, j] = False + if vertical: + return grad.T, m_out.T + else: + return grad, m_out diff --git a/src/py_eddy_tracker/eddy_feature.py b/src/py_eddy_tracker/eddy_feature.py index 6d929ea0..8bc139ab 100644 --- a/src/py_eddy_tracker/eddy_feature.py +++ b/src/py_eddy_tracker/eddy_feature.py @@ -8,8 +8,7 @@ from matplotlib.cm import get_cmap from matplotlib.colors import Normalize from matplotlib.figure import Figure -from numba import njit -from numba import types as numba_types +from numba import njit, types as numba_types from numpy import ( array, concatenate, @@ -31,7 +30,7 @@ class Amplitude(object): """ Class to calculate *amplitude* and counts of *local maxima/minima* - within a closed region of a sea level anomaly field. + within a closed region of a sea surface height field. """ EPSILON = 1e-8 @@ -61,13 +60,13 @@ def __init__( """ Create amplitude object - :param Contours contour: - :param float contour_height: - :param array data: - :param float interval: - :param int mle: maximum number of local maxima in contour - :param int nb_step_min: number of interval to consider like an eddy - :param int nb_step_to_be_mle: number of interval to be consider like another maxima + :param Contours contour: usefull class defined below + :param float contour_height: field value of the contour + :param array data: grid + :param float interval: step between two contours + :param int mle: maximum number of local extrema in contour + :param int nb_step_min: minimum number of intervals to consider the contour as an eddy + :param int nb_step_to_be_mle: number of intervals to be considered as another extrema """ # Height of the contour @@ -102,8 +101,9 @@ def __init__( self.nb_pixel = i_x.shape[0] # Only pixel in contour + # FIXME : change sla by ssh as the grid can be adt? self.sla = data[contour.pixels_index] - # Amplitude which will be provide + # Amplitude which will be provided self.amplitude = 0 # Maximum local extrema accepted self.mle = mle @@ -117,7 +117,7 @@ def all_pixels_below_h0(self, level): Check CSS11 criterion 1: The SSH values of all of the pixels are below a given SSH threshold for cyclonic eddies. 
""" - # In some case pixel value must be very near of contour bounds + # In some cases pixel value may be very close to the contour bounds if self.sla.mask.any() or ((self.sla.data - self.h_0) > self.EPSILON).any(): return False else: @@ -173,6 +173,7 @@ def all_pixels_above_h0(self, level): self.mle, -1, ) + # After we use grid.data because index are in contour and we check before than no pixel are hide nb = len(lmi_i) if nb == 0: logger.warning( @@ -292,10 +293,10 @@ class Contours(object): Attributes: contour: - A matplotlib contour object of high-pass filtered SLA + A matplotlib contour object of high-pass filtered SSH eddy: - A tracklist object holding the SLA data + A tracklist object holding the SSH data grd: A grid object @@ -405,7 +406,7 @@ def __init__(self, x, y, z, levels, wrap_x=False, keep_unclose=False): fig = Figure() ax = fig.add_subplot(111) if wrap_x: - logger.debug("wrapping activate to compute contour") + logger.debug("wrapping activated to compute contour") x = concatenate((x, x[:1] + 360)) z = ma.concatenate((z, z[:1])) logger.debug("X shape : %s", x.shape) @@ -599,10 +600,10 @@ def display( 4. - Amplitude criterion (yellow) :param str field: Must be 'shape_error', 'x', 'y' or 'radius'. - If define display_criterion is not use. - bins argument must be define - :param array bins: bins use to colorize contour - :param str cmap: Name of cmap to use for field display + If defined display_criterion is not use. + bins argument must be defined + :param array bins: bins used to colorize contour + :param str cmap: Name of cmap for field display :param dict kwargs: look at :py:meth:`matplotlib.collections.LineCollection` .. minigallery:: py_eddy_tracker.Contours.display @@ -644,14 +645,20 @@ def display( paths.append(i.vertices) local_kwargs = kwargs.copy() if "color" not in kwargs: - local_kwargs["color"] = collection.get_color() + local_kwargs["color"] = collection.get_edgecolor() local_kwargs.pop("label", None) elif j != 0: local_kwargs.pop("label", None) if not overide_color: ax.add_collection(LineCollection(paths, **local_kwargs)) if display_criterion: - colors = {0: "g", 1: "r", 2: "b", 3: "k", 4: "y"} + colors = { + 0: "limegreen", + 1: "red", + 2: "mediumblue", + 3: "black", + 4: "gold", + } for k, v in paths.items(): local_kwargs = kwargs.copy() local_kwargs.pop("label", None) @@ -681,7 +688,7 @@ def display( ax.autoscale_view() def label_contour_unused_which_contain_eddies(self, eddies): - """Select contour which contain several eddies""" + """Select contour containing several eddies""" if eddies.sign_type == 1: # anticyclonic sl = slice(None, -1) @@ -776,7 +783,7 @@ def index_from_nearest_path_with_pt_in_bbox_( d_x = x_value[i_elt_pt] - xpt_ if abs(d_x) > 180: d_x = (d_x + 180) % 360 - 180 - dist = d_x ** 2 + (y_value[i_elt_pt] - ypt) ** 2 + dist = d_x**2 + (y_value[i_elt_pt] - ypt) ** 2 if dist < dist_ref: dist_ref = dist i_ref = i_elt_c diff --git a/src/py_eddy_tracker/featured_tracking/area_tracker.py b/src/py_eddy_tracker/featured_tracking/area_tracker.py index 5aa8e43c..9e676fc1 100644 --- a/src/py_eddy_tracker/featured_tracking/area_tracker.py +++ b/src/py_eddy_tracker/featured_tracking/area_tracker.py @@ -1,6 +1,7 @@ import logging -from numpy import ma +from numba import njit +from numpy import empty, ma, ones from ..observations.observation import EddiesObservations as Model @@ -8,6 +9,13 @@ class AreaTracker(Model): + """ + Area Tracker will used overlap to track eddy. 
+ + This tracking will used :py:meth:`~py_eddy_tracker.observations.observation.EddiesObservations.match` method + to get a similarity index, which could be between [0:1]. + You could setup this class with `cmin` option to set a minimal value to valid an association. + """ __slots__ = ("cmin",) @@ -15,6 +23,11 @@ def __init__(self, *args, cmin=0.2, **kwargs): super().__init__(*args, **kwargs) self.cmin = cmin + def merge(self, *args, **kwargs): + eddies = super().merge(*args, **kwargs) + eddies.cmin = self.cmin + return eddies + @classmethod def needed_variable(cls): vars = ["longitude", "latitude"] @@ -22,13 +35,13 @@ def needed_variable(cls): return vars def tracking(self, other): + """ + Core method to track + """ shape = (self.shape[0], other.shape[0]) i, j, c = self.match(other, intern=False) - cost_mat = ma.empty(shape, dtype="f4") - cost_mat.mask = ma.ones(shape, dtype="bool") - m = c > self.cmin - i, j, c = i[m], j[m], c[m] - cost_mat[i, j] = 1 - c + cost_mat = ma.array(empty(shape, dtype="f4"), mask=ones(shape, dtype="bool")) + mask_cmin(i, j, c, self.cmin, cost_mat.data, cost_mat.mask) i_self, i_other = self.solve_function(cost_mat) i_self, i_other = self.post_process_link(other, i_self, i_other) @@ -50,3 +63,13 @@ def propagate( if nb_virtual_extend > 0: virtual[key][nb_dead:] = obs_to_extend[key] return virtual + + +@njit(cache=True) +def mask_cmin(i, j, c, cmin, cost_mat, mask): + for k in range(c.shape[0]): + c_ = c[k] + if c_ > cmin: + i_, j_ = i[k], j[k] + cost_mat[i_, j_] = 1 - c_ + mask[i_, j_] = False diff --git a/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py b/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py index 7baaffd3..b0d4abfa 100644 --- a/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py +++ b/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py @@ -8,7 +8,6 @@ class CheltonTracker(Model): - __slots__ = tuple() GROUND = RegularGridDataset( @@ -21,13 +20,13 @@ def cost_function(records_in, records_out, distance): return distance def mask_function(self, other, distance): - """We mask link with ellips and ratio""" - # Compute Parameter of ellips + """We mask link with ellipse and ratio""" + # Compute Parameter of ellipse minor, major = 1.05, 1.5 - y = self.basic_formula_ellips_major_axis( + y = self.basic_formula_ellipse_major_axis( self.lat, degrees=True, c0=minor, cmin=minor, cmax=major, lat1=23, lat2=5 ) - # mask from ellips + # mask from ellipse mask = self.shifted_ellipsoid_degrees_mask( other, minor=minor, major=y # Minor can be bigger than major?? ) diff --git a/src/py_eddy_tracker/generic.py b/src/py_eddy_tracker/generic.py index a1da6b46..2fdb737a 100644 --- a/src/py_eddy_tracker/generic.py +++ b/src/py_eddy_tracker/generic.py @@ -3,8 +3,7 @@ Tool method which use mostly numba """ -from numba import njit, prange -from numba import types as numba_types +from numba import njit, prange, types as numba_types from numpy import ( absolute, arcsin, @@ -30,7 +29,7 @@ @njit(cache=True) def count_consecutive(mask): """ - Count consecutive event every False flag count restart + Count consecutive events every False flag count restart :param array[bool] mask: event to count :return: count when consecutive event @@ -50,7 +49,7 @@ def count_consecutive(mask): @njit(cache=True) def reverse_index(index, nb): """ - Compute a list of index, which are not in index. + Compute a list of indices, which are not in index. 
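# --- editor's illustrative sketch (not part of the patch) --------------------
# The cost matrix built by AreaTracker.tracking / mask_cmin above can be
# expressed with a plain numpy masked array: a similarity c in [0, 1] becomes
# a cost 1 - c, and pairs with c <= cmin stay masked. The helper name below
# is hypothetical.
import numpy as np
from numpy import ma


def overlap_cost_matrix(i, j, c, shape, cmin=0.2):
    """Masked cost matrix; only pairs with similarity above cmin are usable."""
    cost = ma.array(np.empty(shape, dtype="f4"), mask=np.ones(shape, dtype=bool))
    keep = c > cmin
    cost[i[keep], j[keep]] = 1 - c[keep]
    return cost
# -----------------------------------------------------------------------------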
:param array index: index of group which will be set to False :param array nb: Count for each group @@ -65,12 +64,18 @@ def reverse_index(index, nb): @njit(cache=True) def build_index(groups): - """We expected that variable is monotonous, and return index for each step change. + """We expect that variable is monotonous, and return index for each step change. - :param array groups: array which contain group to be separated - :return: (first_index of each group, last_index of each group, value to shift group) + :param array groups: array that contains groups to be separated + :return: (first_index of each group, last_index of each group, value to shift groups) :rtype: (array, array, int) + + :Example: + + >>> build_index(array((1, 1, 3, 4, 4))) + (array([0, 2, 2, 3]), array([2, 2, 3, 5]), 1) """ + i0, i1 = groups.min(), groups.max() amplitude = i1 - i0 + 1 # Index of first observation for each group @@ -78,32 +83,32 @@ def build_index(groups): for i, group in enumerate(groups[:-1]): # Get next value to compare next_group = groups[i + 1] - # if different we need to set index + # if different we need to set index for all groups between the 2 values if group != next_group: first_index[group - i0 + 1 : next_group - i0 + 1] = i + 1 last_index = zeros(amplitude, dtype=numba_types.int_) last_index[:-1] = first_index[1:] - last_index[-1] = i + 2 + last_index[-1] = len(groups) return first_index, last_index, i0 @njit(cache=True) def hist_numba(x, bins): - """Call numba histogram to speed up.""" + """Call numba histogram to speed up.""" return histogram(x, bins) @njit(cache=True, fastmath=True, parallel=False) def distance_grid(lon0, lat0, lon1, lat1): """ - Get distance for every couple of point. + Get distance for every couple of points. :param array lon0: :param array lat0: :param array lon1: :param array lat1: - :return: nan value for far away point, and km for other + :return: nan value for far away points, and km for other :rtype: array """ nb_0 = lon0.shape[0] @@ -126,8 +131,8 @@ def distance_grid(lon0, lat0, lon1, lat1): sin_dlon = sin((dlon) * 0.5 * D2R) cos_lat1 = cos(lat0[i] * D2R) cos_lat2 = cos(lat1[j] * D2R) - a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2 - dist[i, j] = 6370.997 * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat1 * cos_lat2 + sin_dlat**2 + dist[i, j] = 6370.997 * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) return dist @@ -148,8 +153,8 @@ def distance(lon0, lat0, lon1, lat1): sin_dlon = sin((lon1 - lon0) * 0.5 * D2R) cos_lat1 = cos(lat0 * D2R) cos_lat2 = cos(lat1 * D2R) - a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2 - return 6370997.0 * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat1 * cos_lat2 + sin_dlat**2 + return 6370997.0 * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) @njit(cache=True) @@ -158,7 +163,7 @@ def cumsum_by_track(field, track): Cumsum by track. 
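# --- editor's illustrative sketch (not part of the patch) --------------------
# Pure-numpy equivalent of the cumsum-with-reset behaviour documented above;
# the helper name is hypothetical and assumes `track` is grouped by trajectory.
import numpy as np


def cumsum_by_track_py(field, track):
    out = np.empty_like(field, dtype=float)
    total, current = 0.0, track[0]
    for k in range(field.size):
        if track[k] != current:
            total, current = 0.0, track[k]
        total += field[k]
        out[k] = total
    return out

# cumsum_by_track_py(np.array([1., 2., 3., 4.]), np.array([0, 0, 1, 1]))
# -> array([1., 3., 3., 7.])
# -----------------------------------------------------------------------------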
:param array field: data to sum - :pram array(int) track: id of track to separate data + :pram array(int) track: id of trajectories to separate data :return: cumsum with a reset at each start of track :rtype: array """ @@ -186,11 +191,29 @@ def interp2d_geo(x_g, y_g, z_g, m_g, x, y, nearest=False): :param array m_g: Boolean grid, True if value is masked :param array x: coordinate where interpolate z :param array y: coordinate where interpolate z - :param bool nearest: if true we will take nearest pixel + :param bool nearest: if True we will take nearest pixel + :return: z interpolated + :rtype: array + """ + if nearest: + return interp2d_nearest(x_g, y_g, z_g, x, y) + else: + return interp2d_bilinear(x_g, y_g, z_g, m_g, x, y) + + +@njit(cache=True, fastmath=True) +def interp2d_nearest(x_g, y_g, z_g, x, y): + """ + Nearest interpolation with wrapping if circular + + :param array x_g: coordinates of grid + :param array y_g: coordinates of grid + :param array z_g: Grid value + :param array x: coordinate where interpolate z + :param array y: coordinate where interpolate z :return: z interpolated :rtype: array """ - # TODO : Maybe test if we are out of bounds x_ref = x_g[0] y_ref = y_g[0] x_step = x_g[1] - x_ref @@ -199,57 +222,93 @@ def interp2d_geo(x_g, y_g, z_g, m_g, x, y, nearest=False): nb_y = y_g.shape[0] is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 z = empty(x.shape, dtype=z_g.dtype) + for i in prange(x.size): + i0 = int(round((x[i] - x_ref) / x_step)) + j0 = int(round((y[i] - y_ref) / y_step)) + if is_circular: + i0 %= nb_x + if i0 >= nb_x or i0 < 0 or j0 < 0 or j0 >= nb_y: + z[i] = nan + continue + z[i] = z_g[i0, j0] + return z + + +@njit(cache=True, fastmath=True) +def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y): + """ + Bilinear interpolation with wrapping if circular + + :param array x_g: coordinates of grid + :param array y_g: coordinates of grid + :param array z_g: Grid value + :param array m_g: Boolean grid, True if value is masked + :param array x: coordinate where interpolate z + :param array y: coordinate where interpolate z + :return: z interpolated + :rtype: array + """ + x_ref = x_g[0] + y_ref = y_g[0] + x_step = x_g[1] - x_ref + y_step = y_g[1] - y_ref + nb_x = x_g.shape[0] + nb_y = y_g.shape[0] + is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 + # Indexes that should never exist + i0_old, j0_old, masked = -100000000, -10000000, False + z = empty(x.shape, dtype=z_g.dtype) for i in prange(x.size): x_ = (x[i] - x_ref) / x_step y_ = (y[i] - y_ref) / y_step i0 = int(floor(x_)) - i1 = i0 + 1 - xd = x_ - i0 + # To keep original values if wrapping applied to compute xd + i0_ = i0 j0 = int(floor(y_)) - j1 = j0 + 1 - if is_circular: - i0 %= nb_x - i1 %= nb_x - else: + # corners are the same need only a new xd and yd + if i0 != i0_old or j0 != j0_old: + i1 = i0 + 1 + j1 = j0 + 1 + if is_circular: + i0 %= nb_x + i1 %= nb_x if i1 >= nb_x or i0 < 0 or j0 < 0 or j1 >= nb_y: - z[i] = nan - continue - - yd = y_ - j0 - z00 = z_g[i0, j0] - z01 = z_g[i0, j1] - z10 = z_g[i1, j0] - z11 = z_g[i1, j1] - if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]: + masked = True + else: + masked = False + if not masked: + if m_g[i0, j0] or m_g[i0, j1] or m_g[i1, j0] or m_g[i1, j1]: + masked = True + else: + z00, z01, z10, z11 = ( + z_g[i0, j0], + z_g[i0, j1], + z_g[i1, j0], + z_g[i1, j1], + ) + masked = False + # Need to be stored only on change + i0_old, j0_old = i0, j0 + if masked: z[i] = nan else: - if nearest: - if xd <= 0.5: - if yd <= 0.5: - z[i] = 
z00 - else: - z[i] = z01 - else: - if yd <= 0.5: - z[i] = z10 - else: - z[i] = z11 - else: - z[i] = (z00 * (1 - xd) + (z10 * xd)) * (1 - yd) + ( - z01 * (1 - xd) + z11 * xd - ) * yd + xd = x_ - i0_ + yd = y_ - j0 + z[i] = (z00 * (1 - xd) + (z10 * xd)) * (1 - yd) + ( + z01 * (1 - xd) + z11 * xd + ) * yd return z @njit(cache=True, fastmath=True) -def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None): +def uniform_resample(x_val, y_val, num_fac=2, fixed_size=-1): """ Resample contours to have (nearly) equal spacing. :param array_like x_val: input x contour coordinates :param array_like y_val: input y contour coordinates :param int num_fac: factor to increase lengths of output coordinates - :param int,None fixed_size: if define, it will used to set sampling + :param int fixed_size: if > -1, will be used to set sampling """ nb = x_val.shape[0] # Get distances @@ -260,7 +319,7 @@ def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None): dist[1:][dist[1:] < 1e-3] = 1e-3 dist = dist.cumsum() # Get uniform distances - if fixed_size is None: + if fixed_size == -1: fixed_size = dist.size * num_fac d_uniform = linspace(0, dist[-1], fixed_size) x_new = interp(d_uniform, dist, x_val) @@ -299,17 +358,17 @@ def flatten_line_matrix(l_matrix): @njit(cache=True) def simplify(x, y, precision=0.1): """ - Will remove all middle/end point which are closer than precision. + Will remove all middle/end points closer than precision. :param array x: :param array y: - :param float precision: if two points have distance inferior to precision with remove next point + :param float precision: if two points have distance inferior to precision we remove next point :return: (x,y) :rtype: (array,array) """ - precision2 = precision ** 2 + precision2 = precision**2 nb = x.shape[0] - # will be True for value keep + # will be True for kept values mask = ones(nb, dtype=bool_) for j in range(0, nb): x_previous, y_previous = x[j], y[j] @@ -339,7 +398,7 @@ def simplify(x, y, precision=0.1): if d_y > precision: x_previous, y_previous = x_, y_ continue - d2 = d_x ** 2 + d_y ** 2 + d2 = d_x**2 + d_y**2 if d2 > precision2: x_previous, y_previous = x_, y_ continue @@ -363,11 +422,11 @@ def split_line(x, y, i): :param y: array :param i: array of int at each i change, we cut x, y - :return: x and y separate by nan at each i jump + :return: x and y separated by nan at each i jump """ nb_jump = len(where(i[1:] - i[:-1] != 0)[0]) nb_value = x.shape[0] - final_size = (nb_jump - 1) + nb_value + final_size = nb_jump + nb_value new_x = empty(final_size, dtype=x.dtype) new_y = empty(final_size, dtype=y.dtype) new_j = 0 @@ -385,11 +444,11 @@ def split_line(x, y, i): @njit(cache=True) def wrap_longitude(x, y, ref, cut=False): """ - Will wrap contiguous longitude with reference as west bound. + Will wrap contiguous longitude with reference as western boundary. 
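# --- editor's illustrative sketch (not part of the patch) --------------------
# The wrapping rule described above maps any longitude into [ref, ref + 360);
# the helper name is hypothetical.
import numpy as np


def wrap_lon(lon, ref):
    return (np.asarray(lon, dtype=float) - ref) % 360 + ref

# wrap_lon([-10.0, 370.0, 180.0], ref=0.0) -> array([350.,  10., 180.])
# -----------------------------------------------------------------------------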
:param array x: :param array y: - :param float ref: longitude of reference, all the new value will be between ref and ref + 360 + :param float ref: longitude of reference, all the new values will be between ref and ref + 360 :param bool cut: if True line will be cut at the bounds :return: lon,lat :rtype: (array,array) @@ -397,17 +456,18 @@ def wrap_longitude(x, y, ref, cut=False): if cut: indexs = list() nb = x.shape[0] - new_previous = (x[0] - ref) % 360 + + new_x_previous = (x[0] - ref) % 360 + ref x_previous = x[0] for i in range(1, nb): x_ = x[i] - new_x = (x_ - ref) % 360 + new_x = (x_ - ref) % 360 + ref if not isnan(x_) and not isnan(x_previous): - d_new = new_x - new_previous + d_new = new_x - new_x_previous d = x_ - x_previous if abs(d - d_new) > 1e-5: indexs.append(i) - x_previous, new_previous = x_, new_x + x_previous, new_x_previous = x_, new_x nb_indexs = len(indexs) new_size = nb + nb_indexs * 3 @@ -418,6 +478,7 @@ def wrap_longitude(x, y, ref, cut=False): for i in range(nb): if j < nb_indexs and i == indexs[j]: j += 1 + # FIXME need check cor = 360 if x[i - 1] > x[i] else -360 out_x[i + i_] = (x[i] - ref) % 360 + ref - cor out_y[i + i_] = y[i] @@ -457,8 +518,8 @@ def coordinates_to_local(lon, lat, lon0, lat0): sin_dlon = sin(dlon * 0.5) cos_lat0 = cos(lat0 * D2R) cos_lat = cos(lat * D2R) - a_val = sin_dlon ** 2 * cos_lat0 * cos_lat + sin_dlat ** 2 - module = R * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat0 * cos_lat + sin_dlat**2 + module = R * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) azimuth = pi / 2 - arctan2( cos_lat * sin(dlon), @@ -481,7 +542,7 @@ def local_to_coordinates(x, y, lon0, lat0): """ D2R = pi / 180.0 R = 6370997 - d = (x ** 2 + y ** 2) ** 0.5 / R + d = (x**2 + y**2) ** 0.5 / R a = -(arctan2(y, x) - pi / 2) lat = arcsin(sin(lat0 * D2R) * cos(d) + cos(lat0 * D2R) * sin(d) * cos(a)) lon = ( @@ -497,7 +558,7 @@ def local_to_coordinates(x, y, lon0, lat0): @njit(cache=True, fastmath=True) def nearest_grd_indice(x, y, x0, y0, xstep, ystep): """ - Get nearest grid indice from a position. + Get nearest grid index from a position. :param x: longitude :param y: latitude @@ -515,7 +576,7 @@ def nearest_grd_indice(x, y, x0, y0, xstep, ystep): @njit(cache=True) def bbox_indice_regular(vertices, x0, y0, xstep, ystep, N, circular, x_size): """ - Get bbox indice of a contour in a regular grid. + Get bbox index of a contour in a regular grid. :param vertices: vertice of contour :param float x0: first grid longitude @@ -552,3 +613,48 @@ def build_circle(x0, y0, r): angle = radians(linspace(0, 360, 50)) x_norm, y_norm = cos(angle), sin(angle) return x_norm * r + x0, y_norm * r + y0 + + +def window_index(x, x0, half_window=1): + """ + Give for a fixed half_window each start and end index for each x0, in + an unsorted array. 
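# --- editor's illustrative sketch (not part of the patch) --------------------
# A rough searchsorted-based equivalent of window_index, useful as a sanity
# check; boundary handling of the numba loop may differ slightly at the
# window edges.
import numpy as np


def window_index_np(x, x0, half_window=1.0):
    order = np.argsort(x, kind="mergesort")
    xs = x[order]
    first = np.searchsorted(xs, x0 - half_window, side="right")
    last = np.searchsorted(xs, x0 + half_window, side="right")
    # x[order[first[k]:last[k]]] are the values inside the k-th window
    return order, first, last
# -----------------------------------------------------------------------------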
+ + :param array x: array of value + :param array x0: array of window center + :param float half_window: half window + """ + # Sort array, bounds will be sort also + i_ordered = x.argsort(kind="mergesort") + return window_index_(x, i_ordered, x0, half_window) + + +@njit(cache=True) +def window_index_(x, i_ordered, x0, half_window=1): + nb_x, nb_pt = x.size, x0.size + first_index = empty(nb_pt, dtype=i_ordered.dtype) + last_index = empty(nb_pt, dtype=i_ordered.dtype) + # First bound to find + j_min, j_max = 0, 0 + x_min = x0[j_min] - half_window + x_max = x0[j_max] + half_window + # We iterate on ordered x + for i, i_x in enumerate(i_ordered): + x_ = x[i_x] + # if x bigger than x_min , we found bound and search next one + while x_ > x_min and j_min < nb_pt: + first_index[j_min] = i + j_min += 1 + x_min = x0[j_min] - half_window + # if x bigger than x_max , we found bound and search next one + while x_ > x_max and j_max < nb_pt: + last_index[j_max] = i + j_max += 1 + x_max = x0[j_max] + half_window + if j_max == nb_pt: + break + for i in range(j_min, nb_pt): + first_index[i] = nb_x + for i in range(j_max, nb_pt): + last_index[i] = nb_x + return i_ordered, first_index, last_index diff --git a/src/py_eddy_tracker/gui.py b/src/py_eddy_tracker/gui.py index 1485c08c..a85e9c18 100644 --- a/src/py_eddy_tracker/gui.py +++ b/src/py_eddy_tracker/gui.py @@ -4,13 +4,16 @@ """ from datetime import datetime, timedelta +import logging +from matplotlib.projections import register_projection import matplotlib.pyplot as plt import numpy as np -from matplotlib.projections import register_projection from .generic import flatten_line_matrix, split_line +logger = logging.getLogger("pet") + try: from pylook.axes import PlatCarreAxes except ImportError: @@ -26,12 +29,15 @@ def __init__(self, *args, **kwargs): self.set_aspect("equal") +GUI_AXES = "full_axes" + + class GUIAxes(PlatCarreAxes): """ - Axes which will use full space available + Axes that uses full space available """ - name = "full_axes" + name = GUI_AXES def end_pan(self, *args, **kwargs): (x0, x1), (y0, y1) = self.get_xlim(), self.get_ylim() @@ -88,7 +94,7 @@ def set_initial_values(self): for dataset in self.datasets.values(): t0_, t1_ = dataset.period t0, t1 = min(t0, t0_), max(t1, t1_) - + logger.debug("period detected %f -> %f", t0, t1) self.settings = dict(period=(t0, t1), now=t1) @property @@ -125,7 +131,7 @@ def med(self): def setup(self): self.figure = plt.figure() # map - self.map = self.figure.add_axes((0, 0.25, 1, 0.75), projection="full_axes") + self.map = self.figure.add_axes((0, 0.25, 1, 0.75), projection=GUI_AXES) self.map.grid() self.map.tick_params("both", pad=-22) # self.map.tick_params("y", pad=-22) @@ -146,6 +152,11 @@ def setup(self): # param self.param_ax = self.figure.add_axes((0, 0, 1, 0.15), facecolor="0.2") + def hide_path(self, state): + for name in self.datasets: + self.m[name]["path_previous"].set_visible(state) + self.m[name]["path_future"].set_visible(state) + def draw(self): self.m["mini_ax"] = self.figure.add_axes((0.3, 0.85, 0.4, 0.15), zorder=80) self.m["mini_ax"].grid() @@ -283,8 +294,8 @@ def get_infos(self, name, index): i_first = d.index_from_track[tr] track = d.obs[i_first : i_first + nb] nb -= 1 - t0 = timedelta(days=int(track[0]["time"])) + datetime(1950, 1, 1) - t1 = timedelta(days=int(track[-1]["time"])) + datetime(1950, 1, 1) + t0 = timedelta(days=track[0]["time"]) + datetime(1950, 1, 1) + t1 = timedelta(days=track[-1]["time"]) + datetime(1950, 1, 1) txt = f"--{name}--\n" txt += f" {t0} -> {t1}\n" txt += f" 
Tracks : {tr} {now['n']}/{nb} ({now['n'] / nb * 100:.2f} %)\n" diff --git a/src/py_eddy_tracker/misc.py b/src/py_eddy_tracker/misc.py new file mode 100644 index 00000000..647bfba3 --- /dev/null +++ b/src/py_eddy_tracker/misc.py @@ -0,0 +1,21 @@ +import re + +from matplotlib.animation import FuncAnimation + + +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) diff --git a/src/py_eddy_tracker/observations/groups.py b/src/py_eddy_tracker/observations/groups.py new file mode 100644 index 00000000..81929e1e --- /dev/null +++ b/src/py_eddy_tracker/observations/groups.py @@ -0,0 +1,476 @@ +from abc import ABC, abstractmethod +import logging + +from numba import njit, types as nb_types +from numpy import arange, full, int32, interp, isnan, median, where, zeros + +from ..generic import window_index +from ..poly import create_meshed_particles, poly_indexs +from .observation import EddiesObservations + +logger = logging.getLogger("pet") + + +@njit(cache=True) +def get_missing_indices( + array_time, array_track, dt=1, flag_untrack=True, indice_untrack=0 +): + """Return indexes where values are missing + + :param np.array(int) array_time : array of strictly increasing int representing time + :param np.array(int) array_track: N° track where observations belong + :param int,float dt: theorical timedelta between 2 observations + :param bool flag_untrack: if True, ignore observations where n°track equal `indice_untrack` + :param int indice_untrack: n° representing where observations are untracked + + + ex : array_time = np.array([67, 68, 70, 71, 74, 75]) + array_track= np.array([ 1, 1, 1, 1, 1, 1]) + return : np.array([2, 4, 4]) + """ + + t0 = array_time[0] + t1 = t0 + + tr0 = array_track[0] + tr1 = tr0 + + nbr_step = zeros(array_time.shape, dtype=int32) + + for i in range(array_time.size - 1): + t0 = t1 + tr0 = tr1 + + t1 = array_time[i + 1] + tr1 = array_track[i + 1] + + if flag_untrack & (tr1 == indice_untrack): + continue + + if tr1 != tr0: + continue + + diff = t1 - t0 + if diff > dt: + nbr_step[i] = int(diff / dt) - 1 + + indices = zeros(nbr_step.sum(), dtype=int32) + + j = 0 + for i in range(array_time.size - 1): + nbr_missing = nbr_step[i] + + if nbr_missing != 0: + for k in range(nbr_missing): + indices[j] = i + 1 + j += 1 + return indices + + +def advect(x, y, c, t0, n_days, u_name="u", v_name="v"): + """ + Advect particles from t0 to t0 + n_days, with data cube. 
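# --- editor's illustrative sketch (not part of the patch) --------------------
# For a single track, the gap detection documented above for
# get_missing_indices reduces to a diff: every gap of n*dt produces n-1
# copies of the index following the gap. Track boundaries and untracked
# observations are ignored in this sketch.
import numpy as np


def missing_indices_single_track(array_time, dt=1):
    gaps = np.diff(array_time) // dt - 1
    return np.repeat(np.arange(1, array_time.size), np.maximum(gaps, 0))

# missing_indices_single_track(np.array([67, 68, 70, 71, 74, 75]))
# -> array([2, 4, 4])
# -----------------------------------------------------------------------------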
+ + :param np.array(float) x: longitude of particles + :param np.array(float) y: latitude of particles + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param int t0: julian day of advection start + :param int n_days: number of days to advect + :param str u_name: variable name for u component + :param str v_name: variable name for v component + """ + + kw = dict(nb_step=6, time_step=86400 / 6) + if n_days < 0: + kw["backward"] = True + n_days = -n_days + p = c.advect(x, y, u_name=u_name, v_name=v_name, t_init=t0, **kw) + for _ in range(n_days): + t, x, y = p.__next__() + return t, x, y + + +def particle_candidate_step( + t_start, contours_start, contours_end, space_step, dt, c, day_fraction=6, **kwargs +): + """Select particles within eddies, advect them, return target observation and associated percentages. + For one time step. + + :param int t_start: julian day of the advection + :param (np.array(float),np.array(float)) contours_start: origin contour + :param (np.array(float),np.array(float)) contours_end: destination contour + :param float space_step: step between 2 particles + :param int dt: duration of advection + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param int day_fraction: fraction of day + :params dict kwargs: dict of params given to advection + :return (np.array,np.array): return target index and percent associate + """ + # In case of zarr array + contours_start = [i[:] for i in contours_start] + contours_end = [i[:] for i in contours_end] + # Create particles in start contour + x, y, i_start = create_meshed_particles(*contours_start, space_step) + # Advect particles + kw = dict(nb_step=day_fraction, time_step=86400 / day_fraction) + p = c.advect(x, y, t_init=t_start, **kwargs, **kw) + for _ in range(abs(dt)): + _, x, y = p.__next__() + m = ~(isnan(x) + isnan(y)) + i_end = full(x.shape, -1, dtype="i4") + if m.any(): + # Id eddies for each alive particle in start contour + i_end[m] = poly_indexs(x[m], y[m], *contours_end) + shape = (contours_start[0].shape[0], 2) + # Get target for each contour + i_target, pct_target = full(shape, -1, dtype="i4"), zeros(shape, dtype="f8") + nb_end = contours_end[0].shape[0] + get_targets(i_start, i_end, i_target, pct_target, nb_end) + return i_target, pct_target.astype("i1") + + +def particle_candidate( + c, + eddies, + step_mesh, + t_start, + i_target, + pct, + contour_start="speed", + contour_end="effective", + **kwargs +): + """Select particles within eddies, advect them, return target observation and associated percentages + + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param GroupEddiesObservations eddies: GroupEddiesObservations considered + :param int t_start: julian day of the advection + :param np.array(int) i_target: corresponding obs where particles are advected + :param np.array(int) pct: corresponding percentage of avected particles + :param str contour_start: contour where particles are injected + :param str contour_end: contour where particles are counted after advection + :params dict kwargs: dict of params given to `advect` + + """ + # Obs from initial time + m_start = eddies.time == t_start + e = eddies.extract_with_mask(m_start) + + # to be able to get global index + translate_start = where(m_start)[0] + + # Create particles in specified contour + intern = False if contour_start == "effective" else True + x, y, i_start = e.create_particles(step_mesh, 
intern=intern) + # Advection + t_end, x, y = advect(x, y, c, t_start, **kwargs) + + # eddies at last date + m_end = eddies.time == t_end / 86400 + e_end = eddies.extract_with_mask(m_end) + + # to be able to get global index + translate_end = where(m_end)[0] + + # Id eddies for each alive particle in specified contour + intern = False if contour_end == "effective" else True + i_end = e_end.contains(x, y, intern=intern) + + # compute matrix and fill target array + get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct) + + +@njit(cache=True) +def get_targets(i_start, i_end, i_target, pct, nb_end): + """Compute target observation and associated percentages + + :param array(int) i_start: indices in time 0 + :param array(int) i_end: indices in time N + :param array(int) i_target: corresponding obs where particles are advected + :param array(int) pct: corresponding percentage of avected particles + :param int nb_end: number of contour at time N + """ + nb_start = i_target.shape[0] + # Matrix which will store count for every couple + counts = zeros((nb_start, nb_end), dtype=nb_types.int32) + # Number of particles in each origin observation + ref = zeros(nb_start, dtype=nb_types.int32) + # For each particle + for i in range(i_start.size): + i_end_ = i_end[i] + i_start_ = i_start[i] + ref[i_start_] += 1 + if i_end_ != -1: + counts[i_start_, i_end_] += 1 + # From i to j + for i in range(nb_start): + for j in range(nb_end): + count = counts[i, j] + if count == 0: + continue + pct_ = count / ref[i] * 100 + pct_0 = pct[i, 0] + # If percent is higher than previous stored in rank 0 + if pct_ > pct_0: + pct[i, 1] = pct_0 + pct[i, 0] = pct_ + i_target[i, 1] = i_target[i, 0] + i_target[i, 0] = j + # If percent is higher than previous stored in rank 1 + elif pct_ > pct[i, 1]: + pct[i, 1] = pct_ + i_target[i, 1] = j + + +@njit(cache=True) +def get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct): + """Compute target observation and associated percentages + + :param np.array(int) i_start: indices of associated contours at starting advection day + :param np.array(int) i_end: indices of associated contours after advection + :param np.array(int) translate_start: corresponding global indices at starting advection day + :param np.array(int) translate_end: corresponding global indices after advection + :param np.array(int) i_target: corresponding obs where particles are advected + :param np.array(int) pct: corresponding percentage of avected particles + """ + + nb_start, nb_end = translate_start.size, translate_end.size + # Matrix which will store count for every couple + count = zeros((nb_start, nb_end), dtype=nb_types.int32) + # Number of particles in each origin observation + ref = zeros(nb_start, dtype=nb_types.int32) + # For each particle + for i in range(i_start.size): + i_end_ = i_end[i] + i_start_ = i_start[i] + if i_end_ != -1: + count[i_start_, i_end_] += 1 + ref[i_start_] += 1 + for i in range(nb_start): + for j in range(nb_end): + pct_ = count[i, j] + # If there are particles from i to j + if pct_ != 0: + # Get percent + pct_ = pct_ / ref[i] * 100.0 + # Get indices in full dataset + i_, j_ = translate_start[i], translate_end[j] + pct_0 = pct[i_, 0] + if pct_ > pct_0: + pct[i_, 1] = pct_0 + pct[i_, 0] = pct_ + i_target[i_, 1] = i_target[i_, 0] + i_target[i_, 0] = j_ + elif pct_ > pct[i_, 1]: + pct[i_, 1] = pct_ + i_target[i_, 1] = j_ + return i_target, pct + + +class GroupEddiesObservations(EddiesObservations, ABC): + @abstractmethod + def 
fix_next_previous_obs(self): + pass + + @abstractmethod + def get_missing_indices(self, dt): + "Find indexes where observations are missing" + pass + + def filled_by_interpolation(self, mask): + """Fill selected values by interpolation + + :param array(bool) mask: True if must be filled by interpolation + + .. minigallery:: py_eddy_tracker.TrackEddiesObservations.filled_by_interpolation + """ + if self.track.size == 0: + return + nb_filled = mask.sum() + logger.info("%d obs will be filled (unobserved)", nb_filled) + + nb_obs = len(self) + index = arange(nb_obs) + + for field in self.fields: + if ( + field in ["n", "virtual", "track", "cost_association"] + or field in self.array_variables + ): + continue + self.obs[field][mask] = interp( + index[mask], index[~mask], self.obs[field][~mask] + ) + + def insert_virtual(self): + """Insert virtual observations on segments where observations are missing""" + + dt_theorical = median(self.time[1:] - self.time[:-1]) + indices = self.get_missing_indices(dt_theorical) + + logger.info("%d virtual observation will be added", indices.size) + + # new observations size + size_obs_corrected = self.time.size + indices.size + + # correction of indexes for new size + indices_corrected = indices + arange(indices.size) + + # creating mask with indexes + mask = zeros(size_obs_corrected, dtype=bool) + mask[indices_corrected] = 1 + + new_TEO = self.new_like(self, size_obs_corrected) + new_TEO.obs[~mask] = self.obs + new_TEO.filled_by_interpolation(mask) + new_TEO.virtual[:] = mask + new_TEO.fix_next_previous_obs() + return new_TEO + + def keep_tracks_by_date(self, date, nb_days): + """ + Find tracks that exist at date `date` and lasted at least `nb_days` after. + + :param int,float date: date where the tracks must exist + :param int,float nb_days: number of times the tracks must exist. 
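# --- editor's illustrative sketch (not part of the patch) --------------------
# Core of filled_by_interpolation above: masked positions receive values
# interpolated from the unmasked neighbours along the observation index.
import numpy as np


def fill_masked_by_interpolation(values, mask):
    idx = np.arange(values.size)
    out = values.astype(float)
    out[mask] = np.interp(idx[mask], idx[~mask], values[~mask])
    return out

# fill_masked_by_interpolation(np.array([1.0, 0.0, 3.0]),
#                              np.array([False, True, False]))
# -> array([1., 2., 3.])
# -----------------------------------------------------------------------------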
Can be negative + + If nb_days is negative, it searches a track that exists at the date, + but existed at least `nb_days` before the date + """ + + time = self.time + + mask = zeros(time.shape, dtype=bool) + + for i, b0, b1 in self.iter_on(self.tracks): + _time = time[i] + + if date in _time and (date + nb_days) in _time: + mask[i] = True + + return self.extract_with_mask(mask) + + def particle_candidate_atlas( + self, + cube, + space_step, + dt, + start_intern=False, + end_intern=False, + callback_coherence=None, + finalize_coherence=None, + **kwargs + ): + """Select particles within eddies, advect them, return target observation and associated percentages + + :param `~py_eddy_tracker.dataset.grid.GridCollection` cube: GridCollection with speed for particles + :param float space_step: step between 2 particles + :param int dt: duration of advection + :param bool start_intern: Use intern or extern contour at injection, defaults to False + :param bool end_intern: Use intern or extern contour at end of advection, defaults to False + :param dict kwargs: dict of params given to advection + :param func callback_coherence: if None we will use cls.fill_coherence + :param func finalize_coherence: to apply on results of callback_coherence + :return (np.array,np.array): return target index and percent associate + """ + t_start, t_end = int(self.period[0]), int(self.period[1]) + # Pre-compute to get time index + i_sort, i_start, i_end = window_index( + self.time, arange(t_start, t_end + 1), 0.5 + ) + # Out shape + shape = (len(self), 2) + i_target, pct = full(shape, -1, dtype="i4"), zeros(shape, dtype="i1") + # Backward or forward + times = arange(t_start, t_end - dt) if dt > 0 else arange(t_start - dt, t_end) + + if callback_coherence is None: + callback_coherence = self.fill_coherence + indexs = dict() + results = list() + kw_coherence = dict(space_step=space_step, dt=dt, c=cube) + kw_coherence.update(kwargs) + for t in times: + logger.info( + "Coherence for time step : %s in [%s:%s]", t, times[0], times[-1] + ) + # Get index for origin + i = t - t_start + indexs0 = i_sort[i_start[i] : i_end[i]] + # Get index for end + i = t + dt - t_start + indexs1 = i_sort[i_start[i] : i_end[i]] + if indexs0.size == 0 or indexs1.size == 0: + continue + + results.append( + callback_coherence( + self, + i_target, + pct, + indexs0, + indexs1, + start_intern, + end_intern, + t_start=t, + **kw_coherence + ) + ) + indexs[results[-1]] = indexs0, indexs1 + + if finalize_coherence is not None: + finalize_coherence(results, indexs, i_target, pct) + return i_target, pct + + @classmethod + def fill_coherence( + cls, + network, + i_targets, + percents, + i_origin, + i_end, + start_intern, + end_intern, + **kwargs + ): + """_summary_ + + :param array i_targets: global target + :param array percents: + :param array i_origin: indices of origins + :param array i_end: indices of ends + :param bool start_intern: Use intern or extern contour at injection + :param bool end_intern: Use intern or extern contour at end of advection + """ + # Get contour data + contours_start = [ + network[label][i_origin] for label in cls.intern(start_intern) + ] + contours_end = [network[label][i_end] for label in cls.intern(end_intern)] + # Compute local coherence + i_local_targets, local_percents = particle_candidate_step( + contours_start=contours_start, contours_end=contours_end, **kwargs + ) + # Store + cls.merge_particle_result( + i_targets, percents, i_local_targets, local_percents, i_origin, i_end + ) + + @staticmethod + def 
merge_particle_result( + i_targets, percents, i_local_targets, local_percents, i_origin, i_end + ): + """Copy local result in merged result with global indexation + + :param array i_targets: global target + :param array percents: + :param array i_local_targets: local index target + :param array local_percents: + :param array i_origin: indices of origins + :param array i_end: indices of ends + """ + m = i_local_targets != -1 + i_local_targets[m] = i_end[i_local_targets[m]] + i_targets[i_origin] = i_local_targets + percents[i_origin] = local_percents diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py index 16577bf4..f0b9d7cc 100644 --- a/src/py_eddy_tracker/observations/network.py +++ b/src/py_eddy_tracker/observations/network.py @@ -2,51 +2,2011 @@ """ Class to create network of observations """ -import logging from glob import glob +import logging +import time +from datetime import timedelta, datetime +import os +import netCDF4 +from numba import njit, types as nb_types +from numba.typed import List +import numpy as np +from numpy import ( + arange, + array, + bincount, + bool_, + concatenate, -from numba import njit -from numpy import arange, array, bincount, empty, uint32, unique + empty, + nan, + ones, + percentile, + uint16, + uint32, + unique, + where, + zeros, +) +import zarr +from ..dataset.grid import GridCollection +from ..generic import build_index, wrap_longitude from ..poly import bbox_intersection, vertice_overlap +from .groups import GroupEddiesObservations, get_missing_indices, particle_candidate from .observation import EddiesObservations -from .tracking import TrackEddiesObservations +from .tracking import TrackEddiesObservations, track_loess_filter, track_median_filter logger = logging.getLogger("pet") -class Network: - __slots__ = ("window", "filenames", "contour_name", "nb_input", "xname", "yname") - # To be used like a buffer +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +class Buffer(metaclass=Singleton): + __slots__ = ( + "buffersize", + "contour_name", + "xname", + "yname", + "memory", + ) DATA = dict() FLIST = list() - NOGROUP = TrackEddiesObservations.NOGROUP - def __init__(self, input_regex, window=5, intern=False): - self.window = window + def __init__(self, buffersize, intern=False, memory=False): + self.buffersize = buffersize self.contour_name = EddiesObservations.intern(intern, public_label=True) self.xname, self.yname = EddiesObservations.intern(intern) - self.filenames = glob(input_regex) - self.filenames.sort() - self.nb_input = len(self.filenames) + self.memory = memory def load_contour(self, filename): + if isinstance(filename, EddiesObservations): + return filename[self.xname], filename[self.yname] if filename not in self.DATA: - if len(self.FLIST) > self.window: + if len(self.FLIST) > self.buffersize: self.DATA.pop(self.FLIST.pop(0)) - e = EddiesObservations.load_file(filename, include_vars=self.contour_name) + if self.memory: + # Only if netcdf + with open(filename, "rb") as h: + e = EddiesObservations.load_file(h, include_vars=self.contour_name) + else: + e = EddiesObservations.load_file( + filename, include_vars=self.contour_name + ) + self.FLIST.append(filename) self.DATA[filename] = e[self.xname], e[self.yname] return self.DATA[filename] + +@njit(cache=True) +def fix_next_previous_obs(next_obs, previous_obs, flag_virtual): + """When an 
observation is virtual, we have to fix the previous and next obs + + :param np.array(int) next_obs : index of next observation from network + :param np.array(int previous_obs: index of previous observation from network + :param np.array(bool) flag_virtual: if observation is virtual or not + """ + + for i_o in range(next_obs.size): + if not flag_virtual[i_o]: + continue + + # if there are several consecutive virtuals, some values are written multiple times. + # but it should not be slow + next_obs[i_o - 1] = i_o + next_obs[i_o] = i_o + 1 + previous_obs[i_o] = i_o - 1 + previous_obs[i_o + 1] = i_o + + +class NetworkObservations(GroupEddiesObservations): + __slots__ = ("_index_network", "_index_segment_track", "_segment_track_array") + NOGROUP = 0 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.reset_index() + + def __repr__(self): + m_event, s_event = ( + self.merging_event(only_index=True, triplet=True)[0], + self.splitting_event(only_index=True, triplet=True)[0], + ) + period = (self.period[1] - self.period[0]) / 365.25 + nb_by_network = self.network_size() + nb_trash = 0 if self.ref_index != 0 else nb_by_network[0] + lifetime=self.lifetime + big = 50_000 + infos = [ + f"Atlas with {self.nb_network} networks ({self.nb_network / period:0.0f} networks/year)," + f" {self.nb_segment} segments ({self.nb_segment / period:0.0f} segments/year), {len(self)} observations ({len(self) / period:0.0f} observations/year)", + f" {m_event.size} merging ({m_event.size / period:0.0f} merging/year), {s_event.size} splitting ({s_event.size / period:0.0f} splitting/year)", + f" with {(nb_by_network > big).sum()} network with more than {big} obs and the biggest have {nb_by_network.max()} observations ({nb_by_network[nb_by_network > big].sum()} observations cumulate)", + f" {nb_trash} observations in trash", + f" {lifetime.max()} days max of lifetime", + ] + return "\n".join(infos) + + def reset_index(self): + self._index_network = None + self._index_segment_track = None + self._segment_track_array = None + + def find_segments_relative(self, obs, stopped=None, order=1): + """ + Find all relative segments from obs linked with merging/splitting events at a specific order. + + :param int obs: index of observation after the event + :param int stopped: index of observation before the event + :param int order: order of relatives accepted + :return: all relative segments + :rtype: EddiesObservations + """ + + # extraction of network where the event is + network_id = self.tracks[obs] + nw = self.network(network_id) + + # indice of observation in new subnetwork + i_obs = where(nw.segment == self.segment[obs])[0][0] + + if stopped is None: + return nw.relatives(i_obs, order=order) + + else: + i_stopped = where(nw.segment == self.segment[stopped])[0][0] + return nw.relatives([i_obs, i_stopped], order=order) + + def get_missing_indices(self, dt): + """Find indices where observations are missing. + + As network have all untracked observation in tracknumber `self.NOGROUP`, + we don't compute them + + :param int,float dt: theorical delta time between 2 observations + """ + return get_missing_indices( + self.time, self.track, dt=dt, flag_untrack=True, indice_untrack=self.NOGROUP + ) + + def fix_next_previous_obs(self): + """Function used after 'insert_virtual', to correct next_obs and + previous obs. 
+ """ + + fix_next_previous_obs(self.next_obs, self.previous_obs, self.virtual) + + @property + def index_network(self): + if self._index_network is None: + self._index_network = build_index(self.track) + return self._index_network + + @property + def index_segment_track(self): + if self._index_segment_track is None: + self._index_segment_track = build_index(self.segment_track_array) + return self._index_segment_track + + def segment_size(self): + return self.index_segment_track[1] - self.index_segment_track[0] + + @property + def ref_segment_track_index(self): + return self.index_segment_track[2] + + @property + def ref_index(self): + return self.index_network[2] + + @property + def lifetime(self): + """Return lifetime for each observation""" + lt=self.networks_period.astype("int") + nb_by_network=self.network_size() + return lt.repeat(nb_by_network) + + def network_segment_size(self, id_networks=None): + """Get number of segment by network + + :return array: + """ + i0, i1, ref = build_index(self.track[self.index_segment_track[0]]) + if id_networks is None: + return i1 - i0 + else: + i = id_networks - ref + return i1[i] - i0[i] + + def network_size(self, id_networks=None): + """ + Return size for specified network + + :param list,array, None id_networks: ids to identify network + """ + if id_networks is None: + return self.index_network[1] - self.index_network[0] + else: + i = id_networks - self.index_network[2] + return self.index_network[1][i] - self.index_network[0][i] + + @property + def networks_period(self): + """ + Return period for each network + """ + return get_period_with_index(self.time, *self.index_network[:2]) + + + + def unique_segment_to_id(self, id_unique): + """Return id network and id segment for a unique id + + :param array id_unique: + """ + i = self.index_segment_track[0][id_unique] - self.ref_segment_track_index + return self.track[i], self.segment[i] + + def segment_slice(self, id_network, id_segment): + """ + Return slice for one segment + + :param int id_network: id to identify network + :param int id_segment: id to identify segment + """ + raise Exception("need to be implemented") + + def network_slice(self, id_network): + """ + Return slice for one network + + :param int id_network: id to identify network + """ + i = id_network - self.index_network[2] + i_start, i_stop = self.index_network[0][i], self.index_network[1][i] + return slice(i_start, i_stop) + + @property + def elements(self): + elements = super().elements + elements.extend( + [ + "track", + "segment", + "next_obs", + "previous_obs", + "next_cost", + "previous_cost", + ] + ) + return list(set(elements)) + + def astype(self, cls): + new = cls.new_like(self, self.shape) + for k in new.fields: + if k in self.fields: + new[k][:] = self[k][:] + new.sign_type = self.sign_type + return new + + def longer_than(self, nb_day_min=-1, nb_day_max=-1): + """ + Select network on time duration + + :param int nb_day_min: Minimal number of days covered by one network, if negative -> not used + :param int nb_day_max: Maximal number of days covered by one network, if negative -> not used + """ + return self.extract_with_mask(self.mask_longer_than(nb_day_min, nb_day_max)) + + def mask_longer_than(self, nb_day_min=-1, nb_day_max=-1): + """ + Select network on time duration + + :param int nb_day_min: Minimal number of days covered by one network, if negative -> not used + :param int nb_day_max: Maximal number of days covered by one network, if negative -> not used + """ + if nb_day_max < 0: + nb_day_max = 1000000000000 + 
mask = zeros(self.shape, dtype="bool") + t = self.time + for i, _, _ in self.iter_on(self.track): + nb = i.stop - i.start + if nb == 0: + continue + if nb_day_min <= (ptp(t[i]) + 1) <= nb_day_max: + mask[i] = True + return mask + + @classmethod + def from_split_network(cls, group_dataset, indexs, **kwargs): + """ + Build a NetworkObservations object with Group dataset and indices + + :param TrackEddiesObservations group_dataset: Group dataset + :param indexs: result from split_network + :return: NetworkObservations + """ + index_order = indexs.argsort(order=("group", "track", "time")) + network = cls.new_like(group_dataset, len(group_dataset), **kwargs) + network.sign_type = group_dataset.sign_type + for field in group_dataset.elements: + if field not in network.elements: + continue + network[field][:] = group_dataset[field][index_order] + network.segment[:] = indexs["track"][index_order] + # n & p must be re-indexed + n, p = indexs["next_obs"][index_order], indexs["previous_obs"][index_order] + # we add 2 for -1 index return index -1 + translate = -ones(index_order.max() + 2, dtype="i4") + translate[index_order] = arange(index_order.shape[0]) + network.next_obs[:] = translate[n] + network.previous_obs[:] = translate[p] + network.next_cost[:] = indexs["next_cost"][index_order] + network.previous_cost[:] = indexs["previous_cost"][index_order] + return network + + def infos(self, label=""): + return f"{len(self)} obs {unique(self.segment).shape[0]} segments" + + def correct_close_events(self, nb_days_max=20): + """ + Transform event where + segment A splits from segment B, then x days after segment B merges with A + to + segment A splits from segment B then x days after segment A merges with B (B will be longer) + These events have to last less than `nb_days_max` to be changed. 
+ + + ------------------- A + / / + B -------------------- + to + --A-- + / \ + B ----------------------------------- + + :param float nb_days_max: maximum time to search for splitting-merging event + """ + + _time = self.time + # segment used to correct and track changes + segment = self.segment_track_array.copy() + # final segment used to copy into self.segment + segment_copy = self.segment + + segments_connexion = dict() + + previous_obs, next_obs = self.previous_obs, self.next_obs + + # record for every segment the slice, index of next obs & index of previous obs + for i, seg, _ in self.iter_on(segment): + if i.start == i.stop: + continue + + i_p, i_n = previous_obs[i.start], next_obs[i.stop - 1] + segments_connexion[seg] = [i, i_p, i_n] + + for seg in sorted(segments_connexion.keys()): + seg_slice, _, i_seg_n = segments_connexion[seg] + + # the segment ID has to be corrected, because we may have changed it since + seg_corrected = segment[seg_slice.stop - 1] + + # we keep the real segment number + seg_corrected_copy = segment_copy[seg_slice.stop - 1] + if i_seg_n == -1: + continue + + # if segment is split + n_seg = segment[i_seg_n] + + seg2_slice, i2_seg_p, _ = segments_connexion[n_seg] + if i2_seg_p == -1: + continue + p2_seg = segment[i2_seg_p] + + # if it merges on the first in a certain time + if (p2_seg == seg_corrected) and ( + _time[i_seg_n] - _time[i2_seg_p] < nb_days_max + ): + my_slice = slice(i_seg_n, seg2_slice.stop) + # correct the factice segment + segment[my_slice] = seg_corrected + # correct the good segment + segment_copy[my_slice] = seg_corrected_copy + previous_obs[i_seg_n] = seg_slice.stop - 1 + + segments_connexion[seg_corrected][0] = my_slice + + return self.sort() + + def sort(self, order=("track", "segment", "time")): + """ + Sort observations + + :param tuple order: order or sorting. Given to :func:`numpy.argsort` + """ + index_order = self.obs.argsort(order=order, kind="mergesort") + self.reset_index() + for field in self.fields: + self[field][:] = self[field][index_order] + + nb_obs = len(self) + # we add 1 for -1 index return index -1 + translate = -ones(nb_obs + 1, dtype="i4") + translate[index_order] = arange(nb_obs) + # next & previous must be re-indexed + self.next_obs[:] = translate[self.next_obs] + self.previous_obs[:] = translate[self.previous_obs] + return index_order, translate + + def obs_relative_order(self, i_obs): + self.only_one_network() + return self.segment_relative_order(self.segment[i_obs]) + + def find_link(self, i_observations, forward=True, backward=False): + """ + Find all observations where obs `i_observation` could be + in future or past. + + If forward=True, search all observations where water + from obs "i_observation" could go + + If backward=True, search all observation + where water from obs `i_observation` could come from + + :param int,iterable(int) i_observation: + indices of observation. Can be + int, or iterable of int. + :param bool forward, backward: + if forward, search observations after obs. 
+ else mode==backward search before obs + + """ + + i_obs = ( + [i_observations] + if not hasattr(i_observations, "__iter__") + else i_observations + ) + + segment = self.segment_track_array + previous_obs, next_obs = self.previous_obs, self.next_obs + + segments_connexion = dict() + + for i_slice, seg, _ in self.iter_on(segment): + if i_slice.start == i_slice.stop: + continue + + i_p, i_n = previous_obs[i_slice.start], next_obs[i_slice.stop - 1] + p_seg, n_seg = segment[i_p], segment[i_n] + + # dumping slice into dict + if seg not in segments_connexion: + segments_connexion[seg] = [i_slice, [], []] + else: + segments_connexion[seg][0] = i_slice + + if i_p != -1: + if p_seg not in segments_connexion: + segments_connexion[p_seg] = [None, [], []] + + # backward + segments_connexion[seg][2].append((i_slice.start, i_p, p_seg)) + # forward + segments_connexion[p_seg][1].append((i_p, i_slice.start, seg)) + + if i_n != -1: + if n_seg not in segments_connexion: + segments_connexion[n_seg] = [None, [], []] + + # forward + segments_connexion[seg][1].append((i_slice.stop - 1, i_n, n_seg)) + # backward + segments_connexion[n_seg][2].append((i_n, i_slice.stop - 1, seg)) + + mask = zeros(segment.size, dtype=bool) + + def func_forward(seg, indice): + seg_slice, _forward, _ = segments_connexion[seg] + + mask[indice : seg_slice.stop] = True + for i_begin, i_end, seg2 in _forward: + if i_begin < indice: + continue + + if not mask[i_end]: + func_forward(seg2, i_end) + + def func_backward(seg, indice): + seg_slice, _, _backward = segments_connexion[seg] + + mask[seg_slice.start : indice + 1] = True + for i_begin, i_end, seg2 in _backward: + if i_begin > indice: + continue + + if not mask[i_end]: + func_backward(seg2, i_end) + + for indice in i_obs: + if forward: + func_forward(segment[indice], indice) + + if backward: + func_backward(segment[indice], indice) + + return self.extract_with_mask(mask) + + def connexions(self, multi_network=False): + """Create dictionnary for each segment, gives the segments in interaction with + + :param bool multi_network: use segment_track_array instead of segment, defaults to False + :return dict: Return dict of set, for each seg id we get set of segment which have event with him + """ + if multi_network: + segment = self.segment_track_array + else: + self.only_one_network() + segment = self.segment + segments_connexion = dict() + + def add_seg(s1, s2): + if s1 not in segments_connexion: + segments_connexion[s1] = set() + if s2 not in segments_connexion: + segments_connexion[s2] = set() + segments_connexion[s1].add(s2), segments_connexion[s2].add(s1) + + # Get index for each segment + i0, i1, _ = self.index_segment_track + i1 = i1 - 1 + # Check if segment merge + i_next = self.next_obs[i1] + m_n = i_next != -1 + # Check if segment come from splitting + i_previous = self.previous_obs[i0] + m_p = i_previous != -1 + # For each split + for s1, s2 in zip(segment[i_previous[m_p]], segment[i0[m_p]]): + add_seg(s1, s2) + # For each merge + for s1, s2 in zip(segment[i_next[m_n]], segment[i1[m_n]]): + add_seg(s1, s2) + return segments_connexion + + @classmethod + def __close_segment(cls, father, shift, connexions, distance): + i_father = father - shift + if distance[i_father] == -1: + distance[i_father] = 0 + d_target = distance[i_father] + 1 + for son in connexions.get(father, list()): + i_son = son - shift + d_son = distance[i_son] + if d_son == -1 or d_son > d_target: + distance[i_son] = d_target + else: + continue + cls.__close_segment(son, shift, connexions, distance) + + def 
segment_relative_order(self, seg_origine): + """ + Compute the relative order of each segment to the chosen segment + """ + self.only_one_network() + i_s, i_e, i_ref = build_index(self.segment) + segment_connexions = self.connexions() + relative_tr = -ones(i_s.shape, dtype="i4") + self.__close_segment(seg_origine, i_ref, segment_connexions, relative_tr) + d = -ones(self.shape) + for i0, i1, v in zip(i_s, i_e, relative_tr): + if i0 == i1: + continue + d[i0:i1] = v + return d + + def relatives(self, obs, order=2): + """ + Extract the segments at a certain order from multiple observations. + + :param iterable,int obs: + indices of observation for relatives computation. Can be one observation (int) + or collection of observations (iterable(int)) + :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ... + :return: all segments' relatives + :rtype: EddiesObservations + """ + segment = self.segment_track_array + previous_obs, next_obs = self.previous_obs, self.next_obs + + segments_connexion = dict() + + for i_slice, seg, _ in self.iter_on(segment): + if i_slice.start == i_slice.stop: + continue + + i_p, i_n = previous_obs[i_slice.start], next_obs[i_slice.stop - 1] + p_seg, n_seg = segment[i_p], segment[i_n] + + # dumping slice into dict + if seg not in segments_connexion: + segments_connexion[seg] = [i_slice, []] + else: + segments_connexion[seg][0] = i_slice + + if i_p != -1: + if p_seg not in segments_connexion: + segments_connexion[p_seg] = [None, []] + + # backward + segments_connexion[seg][1].append(p_seg) + segments_connexion[p_seg][1].append(seg) + + if i_n != -1: + if n_seg not in segments_connexion: + segments_connexion[n_seg] = [None, []] + + # forward + segments_connexion[seg][1].append(n_seg) + segments_connexion[n_seg][1].append(seg) + + i_obs = [obs] if not hasattr(obs, "__iter__") else obs + distance = zeros(segment.size, dtype=uint16) - 1 + + def loop(seg, dist=1): + i_slice, links = segments_connexion[seg] + d = distance[i_slice.start] + + if dist < d and dist <= order: + distance[i_slice] = dist + for _seg in links: + loop(_seg, dist + 1) + + for indice in i_obs: + loop(segment[indice], 0) + + return self.extract_with_mask(distance <= order) + + # keep old names, for backward compatibility + relative = relatives + + def close_network(self, other, nb_obs_min=10, **kwargs): + """ + Get close network from another atlas. + + :param self other: Atlas to compare + :param int nb_obs_min: Minimal number of overlap for one trajectory + :param dict kwargs: keyword arguments for match function + :return: return other atlas reduced to common tracks with self + + .. 
warning:: + It could be a costly operation for huge dataset + """ + p0, p1 = self.period + indexs = list() + for i_self, i_other, t0, t1 in self.align_on(other, bins=range(p0, p1 + 2)): + i, j, s = self.match(other, i_self=i_self, i_other=i_other, **kwargs) + indexs.append(other.re_reference_index(j, i_other)) + indexs = concatenate(indexs) + tr, nb = unique(other.track[indexs], return_counts=True) + m = zeros(other.track.shape, dtype=bool) + for i in tr[nb >= nb_obs_min]: + m[other.network_slice(i)] = True + return other.extract_with_mask(m) + + def normalize_longitude(self): + """Normalize all longitudes + + Normalize longitude field and in the same range : + - longitude_max + - contour_lon_e (how to do if in raw) + - contour_lon_s (how to do if in raw) + """ + i_start, i_stop, _ = self.index_network + lon0 = (self.lon[i_start] - 180).repeat(i_stop - i_start) + logger.debug("Normalize longitude") + self.lon[:] = (self.lon - lon0) % 360 + lon0 + if "lon_max" in self.fields: + logger.debug("Normalize longitude_max") + self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180 + if not self.raw_data: + if "contour_lon_e" in self.fields: + logger.debug("Normalize effective contour longitude") + self.contour_lon_e[:] = ( + (self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180 + ).T + if "contour_lon_s" in self.fields: + logger.debug("Normalize speed contour longitude") + self.contour_lon_s[:] = ( + (self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180 + ).T + + def numbering_segment(self, start=0): + """ + New numbering of segment + """ + for i, _, _ in self.iter_on("track"): + new_numbering(self.segment[i], start) + + def numbering_network(self, start=1): + """ + New numbering of network + """ + new_numbering(self.track, start) + + def only_one_network(self): + """ + Raise a warning or error? + if there are more than one network + """ + _, i_start, _ = self.index_network + if i_start.size > 1: + raise Exception("Several networks") + + def position_filter(self, median_half_window, loess_half_window): + self.median_filter(median_half_window, "time", "lon").loess_filter( + loess_half_window, "time", "lon" + ) + self.median_filter(median_half_window, "time", "lat").loess_filter( + loess_half_window, "time", "lat" + ) + + def loess_filter(self, half_window, xfield, yfield, inplace=True): + result = track_loess_filter( + half_window, self.obs[xfield], self.obs[yfield], self.segment_track_array + ) + if inplace: + self.obs[yfield] = result + return self + return result + + def median_filter(self, half_window, xfield, yfield, inplace=True): + result = track_median_filter( + half_window, self[xfield], self[yfield], self.segment_track_array + ) + if inplace: + self[yfield][:] = result + return self + return result + + def display_timeline( + self, + ax, + event=True, + field=None, + method=None, + factor=1, + colors_mode="roll", + **kwargs, + ): + """ + Plot the timeline of a network. + Must be called on only one network. + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param bool event: if True, draw the splitting and merging events + :param str,array field: yaxis values, if None, segments are used + :param str method: if None, mean values are used + :param float factor: to multiply field + :param str colors_mode: + color of lines. 
"roll" means looping through colors, + "y" means color adapt the y values (for matching color plots) + :return: plot mappable + """ + self.only_one_network() + j = 0 + line_kw = dict( + ls="-", + marker="+", + markersize=6, + zorder=1, + lw=3, + ) + line_kw.update(kwargs) + mappables = dict(lines=list()) + + if event: + mappables.update( + self.event_timeline( + ax, + field=field, + method=method, + factor=factor, + colors_mode=colors_mode, + ) + ) + if field is not None: + field = self.parse_varname(field) + for i, b0, b1 in self.iter_on("segment"): + x = self.time_datetime64[i] + if x.shape[0] == 0: + continue + if field is None: + y = b0 * ones(x.shape) + else: + if method == "all": + y = field[i] * factor + else: + y = field[i].mean() * ones(x.shape) * factor + + if colors_mode == "roll": + _color = self.get_color(j) + elif colors_mode == "y": + _color = self.get_color(b0 - 1) + else: + raise NotImplementedError(f"colors_mode '{colors_mode}' not defined") + + line = ax.plot(x, y, **line_kw, color=_color)[0] + mappables["lines"].append(line) + j += 1 + + return mappables + + def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="roll"): + """Mark events in plot""" + j = 0 + events = dict(splitting=[], merging=[]) + + # TODO : fill mappables dict + y_seg = dict() + _time = self.time_datetime64 + + if field is not None and method != "all": + for i, b0, _ in self.iter_on("segment"): + y = self.parse_varname(field)[i] + if y.shape[0] != 0: + y_seg[b0] = y.mean() * factor + mappables = dict() + for i, b0, b1 in self.iter_on("segment"): + x = _time[i] + if x.shape[0] == 0: + continue + + if colors_mode == "roll": + _color = self.get_color(j) + elif colors_mode == "y": + _color = self.get_color(b0 - 1) + else: + raise NotImplementedError(f"colors_mode '{colors_mode}' not defined") + + event_kw = dict(color=_color, ls="-", zorder=1) + + i_n, i_p = ( + self.next_obs[i.stop - 1], + self.previous_obs[i.start], + ) + if field is None: + y0 = b0 + else: + if method == "all": + y0 = self.parse_varname(field)[i.stop - 1] * factor + else: + y0 = y_seg[b0] + if i_n != -1: + seg_next = self.segment[i_n] + y1 = ( + seg_next + if field is None + else ( + self.parse_varname(field)[i_n] * factor + if method == "all" + else y_seg[seg_next] + ) + ) + ax.plot((x[-1], _time[i_n]), (y0, y1), **event_kw)[0] + events["merging"].append((x[-1], y0)) + + if i_p != -1: + seg_previous = self.segment[i_p] + if field is not None and method == "all": + y0 = self[field][i.start] * factor + y1 = ( + seg_previous + if field is None + else ( + self.parse_varname(field)[i_p] * factor + if method == "all" + else y_seg[seg_previous] + ) + ) + ax.plot((x[0], _time[i_p]), (y0, y1), **event_kw)[0] + events["splitting"].append((x[0], y0)) + + j += 1 + + kwargs = dict(color="k", zorder=-1, linestyle=" ") + if len(events["splitting"]) > 0: + X, Y = list(zip(*events["splitting"])) + ref = ax.plot( + X, Y, marker="*", markersize=12, label="splitting events", **kwargs + )[0] + mappables.setdefault("events", []).append(ref) + + if len(events["merging"]) > 0: + X, Y = list(zip(*events["merging"])) + ref = ax.plot( + X, Y, marker="H", markersize=10, label="merging events", **kwargs + )[0] + mappables.setdefault("events", []).append(ref) + + return mappables + + def mean_by_segment(self, y, **kw): + kw["dtype"] = y.dtype + return self.map_segment(lambda x: x.mean(), y, **kw) + + def map_segment(self, method, y, same=True, **kw): + if same: + out = empty(y.shape, **kw) + else: + out = list() + for i, _, _ in 
self.iter_on(self.segment_track_array): + res = method(y[i]) + if same: + out[i] = res + else: + if isinstance(i, slice): + if i.start == i.stop: + continue + elif len(i) == 0: + continue + out.append(res) + if not same: + out = array(out) + return out + + def map_network(self, method, y, same=True, return_dict=False, **kw): + """ + Transform data `y` with method `method` for each track. + + :param Callable method: method to apply on each track + :param np.array y: data where to apply method + :param bool same: if True, return an array with the same size than y. Else, return a list with the edited tracks + :param bool return_dict: if None, mean values are used + :param float kw: to multiply field + :return: array or dict of result from method for each network + """ + + if same and return_dict: + raise NotImplementedError( + "both conditions 'same' and 'return_dict' should no be true" + ) + + if same: + out = empty(y.shape, **kw) + + elif return_dict: + out = dict() + + else: + out = list() + + for i, b0, b1 in self.iter_on(self.track): + res = method(y[i]) + if same: + out[i] = res + + elif return_dict: + out[b0] = res + + else: + if isinstance(i, slice): + if i.start == i.stop: + continue + elif len(i) == 0: + continue + out.append(res) + + if not same and not return_dict: + out = array(out) + return out + + def scatter_timeline( + self, + ax, + name, + factor=1, + event=True, + yfield=None, + yfactor=1, + method=None, + **kwargs, + ): + """ + Must be called on only one network + """ + self.only_one_network() + y = (self.segment if yfield is None else self.parse_varname(yfield)) * yfactor + if method == "all": + pass + else: + y = self.mean_by_segment(y) + mappables = dict() + if event: + mappables.update( + self.event_timeline(ax, field=yfield, method=method, factor=yfactor) + ) + if "c" not in kwargs: + v = self.parse_varname(name) + kwargs["c"] = v * factor + mappables["scatter"] = ax.scatter(self.time_datetime64, y, **kwargs) + return mappables + + def event_map(self, ax, **kwargs): + """Add the merging and splitting events to a map""" + j = 0 + mappables = dict() + symbol_kw = dict( + markersize=10, + color="k", + ) + symbol_kw.update(kwargs) + symbol_kw_split = symbol_kw.copy() + symbol_kw_split["markersize"] += 4 + for i, b0, b1 in self.iter_on("segment"): + nb = i.stop - i.start + if nb == 0: + continue + event_kw = dict(color=self.COLORS[j % self.NB_COLORS], ls="-", **kwargs) + i_n, i_p = ( + self.next_obs[i.stop - 1], + self.previous_obs[i.start], + ) + + if i_n != -1: + y0, y1 = self.lat[i.stop - 1], self.lat[i_n] + x0, x1 = self.lon[i.stop - 1], self.lon[i_n] + ax.plot((x0, x1), (y0, y1), **event_kw)[0] + ax.plot(x0, y0, marker="H", **symbol_kw)[0] + if i_p != -1: + y0, y1 = self.lat[i.start], self.lat[i_p] + x0, x1 = self.lon[i.start], self.lon[i_p] + ax.plot((x0, x1), (y0, y1), **event_kw)[0] + ax.plot(x0, y0, marker="*", **symbol_kw_split)[0] + + j += 1 + return mappables + + def scatter( + self, + ax, + name="time", + factor=1, + ref=None, + edgecolor_cycle=None, + **kwargs, + ): + """ + This function scatters the path of each network, with the merging and splitting events + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param str,array,None name: + variable used to fill the contours, if None all elements have the same color + :param float,None ref: if defined, ref is used as western boundary + :param float factor: multiply value by + :param list edgecolor_cycle: list of colors + :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.scatter` + 
:return: a dict of scattered mappables + """ + mappables = dict() + nb_colors = len(edgecolor_cycle) if edgecolor_cycle else None + x = self.longitude + if ref is not None: + x = (x - ref) % 360 + ref + kwargs = kwargs.copy() + if nb_colors: + edgecolors = list() + seg_previous = self.segment[0] + j = 0 + for seg in self.segment: + if seg != seg_previous: + j += 1 + edgecolors.append(edgecolor_cycle[j % nb_colors]) + seg_previous = seg + mappables["edges"] = ax.scatter( + x, self.latitude, edgecolor=edgecolors, **kwargs + ) + kwargs.pop("linewidths", None) + kwargs["lw"] = 0 + if name is not None and "c" not in kwargs: + v = self.parse_varname(name) + kwargs["c"] = v * factor + mappables["scatter"] = ax.scatter(x, self.latitude, **kwargs) + return mappables + + def extract_event(self, indices): + nb = len(indices) + new = EddiesObservations( + nb, + track_extra_variables=self.track_extra_variables, + track_array_variables=self.track_array_variables, + array_variables=self.array_variables, + only_variables=self.only_variables, + raw_data=self.raw_data, + ) + + for k in new.fields: + new[k][:] = self[k][indices] + new.sign_type = self.sign_type + return new + + @property + def segment_track_array(self): + """Return a unique segment id when multiple networks are considered""" + if self._segment_track_array is None: + self._segment_track_array = build_unique_array(self.segment, self.track) + return self._segment_track_array + + def birth_event(self, only_index=False): + """Extract birth events.""" + i_start, _, _ = self.index_segment_track + indices = i_start[self.previous_obs[i_start] == -1] + if self.first_is_trash(): + indices = indices[1:] + if only_index: + return indices + else : + return self.extract_event(indices) + + generation_event = birth_event + + def death_event(self, only_index=False): + """Extract death events.""" + _, i_stop, _ = self.index_segment_track + indices = i_stop[self.next_obs[i_stop - 1] == -1] - 1 + if self.first_is_trash(): + indices = indices[1:] + if only_index: + return indices + else : + return self.extract_event(indices) + + dissipation_event = death_event + + def merging_event(self, triplet=False, only_index=False): + """Return observation after a merging event. + + If `triplet=True` return the eddy after a merging event, the eddy before the merging event, + and the eddy stopped due to merging. + """ + # Get start and stop for each segment, there is no empty segment + _, i1, _ = self.index_segment_track + # Get last index for each segment + i_stop = i1 - 1 + # Get target index + idx_m1 = self.next_obs[i_stop] + # Get mask and valid target + m = idx_m1 != -1 + idx_m1 = idx_m1[m] + # Sort by time event + i = self.time[idx_m1].argsort() + idx_m1 = idx_m1[i] + if triplet: + # Get obs before target + idx_m0_stop = i_stop[m][i] + idx_m0 = self.previous_obs[idx_m1].copy() + + if triplet: + if only_index: + return idx_m1, idx_m0, idx_m0_stop + else: + return ( + self.extract_event(idx_m1), + self.extract_event(idx_m0), + self.extract_event(idx_m0_stop), + ) + else: + idx_m1 = unique(idx_m1) + if only_index: + return idx_m1 + else: + return self.extract_event(idx_m1) + + def splitting_event(self, triplet=False, only_index=False): + """Return observation before a splitting event. + + If `triplet=True` return the eddy before a splitting event, the eddy after the splitting event, + and the eddy starting due to splitting. 
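+
+        A minimal usage sketch (``n`` is a hypothetical, already built
+        NetworkObservations instance; names are illustrative only):
+
+        .. code-block:: python
+
+            # observations just before each splitting event
+            before = n.splitting_event()
+            # indices of the obs before the event, the obs after it, and the
+            # first obs of the segment created by the splitting
+            i_before, i_after, i_new_start = n.splitting_event(
+                triplet=True, only_index=True
+            )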
+ """ + # Get start and stop for each segment, there is no empty segment + i_start, _, _ = self.index_segment_track + # Get target index + idx_s0 = self.previous_obs[i_start] + # Get mask and valid target + m = idx_s0 != -1 + idx_s0 = idx_s0[m] + # Sort by time event + i = self.time[idx_s0].argsort() + idx_s0 = idx_s0[i] + if triplet: + # Get obs after target + idx_s1_start = i_start[m][i] + idx_s1 = self.next_obs[idx_s0].copy() + + if triplet: + if only_index: + return idx_s0, idx_s1, idx_s1_start + else: + return ( + self.extract_event(idx_s0), + self.extract_event(idx_s1), + self.extract_event(idx_s1_start), + ) + + else: + idx_s0 = unique(idx_s0) + if only_index: + return idx_s0 + else: + return self.extract_event(idx_s0) + + def dissociate_network(self): + """ + Dissociate networks with no known interaction (splitting/merging) + """ + tags = self.tag_segment() + if self.track[0] == 0: + tags -= 1 + self.track[:] = tags[self.segment_track_array] + return self.sort() + + def network_segment(self, id_network, id_segment): + return self.extract_with_mask(self.segment_slice(id_network, id_segment)) + + def network(self, id_network): + return self.extract_with_mask(self.network_slice(id_network)) + + def networks_mask(self, id_networks, segment=False): + if segment: + return generate_mask_from_ids( + id_networks, self.track.size, *self.index_segment_track + ) + else: + return generate_mask_from_ids( + id_networks, self.track.size, *self.index_network + ) + + def networks(self, id_networks): + return self.extract_with_mask( + generate_mask_from_ids( + array(id_networks), self.track.size, *self.index_network + ) + ) + + @property + def nb_network(self): + """ + Count and return number of network + """ + return (self.network_size() != 0).sum() + + @property + def nb_segment(self): + """ + Count and return number of segment in all network + """ + return self.index_segment_track[0].size + + def identify_in(self, other, size_min=1, segment=False): + """ + Return couple of segment or network which are equal + + :param other: other atlas to compare + :param int size_min: number of observation in network/segment + :param bool segment: segment mode + """ + if segment: + counts = self.segment_size(), other.segment_size() + i_self_ref, i_other_ref = ( + self.ref_segment_track_index, + other.ref_segment_track_index, + ) + var_id = "segment" + else: + counts = self.network_size(), other.network_size() + i_self_ref, i_other_ref = self.ref_index, other.ref_index + var_id = "track" + # object to contain index of couple + in_self, in_other = list(), list() + # We iterate on item of same size + for i_self, i_other, i0, _ in self.align_on(other, counts, all_ref=True): + if i0 < size_min: + continue + if isinstance(i_other, slice): + i_other = arange(i_other.start, i_other.stop) + # All_ref will give all item of self, sometime there is no things to compare with other + if i_other.size == 0: + id_self = i_self + i_self_ref + in_self.append(id_self) + in_other.append(-ones(id_self.shape, dtype=id_self.dtype)) + continue + if isinstance(i_self, slice): + i_self = arange(i_self.start, i_self.stop) + # We get absolute id + id_self, id_other = i_self + i_self_ref, i_other + i_other_ref + # We compute mask to select data + m_self, m_other = self.networks_mask(id_self, segment), other.networks_mask( + id_other, segment + ) + + # We extract obs + obs_self, obs_other = self.obs[m_self], other.obs[m_other] + x1, y1, t1 = obs_self["lon"], obs_self["lat"], obs_self["time"] + x2, y2, t2 = obs_other["lon"], obs_other["lat"], 
obs_other["time"] + + if segment: + ids1 = build_unique_array(obs_self["segment"], obs_self["track"]) + ids2 = build_unique_array(obs_other["segment"], obs_other["track"]) + label1 = self.segment_track_array[m_self] + label2 = other.segment_track_array[m_other] + else: + label1, label2 = ids1, ids2 = obs_self[var_id], obs_other[var_id] + # For each item we get index to sort + i01, indexs1, id1 = list(), List(), list() + for sl_self, id_, _ in self.iter_on(ids1): + i01.append(sl_self.start) + indexs1.append(obs_self[sl_self].argsort(order=["time", "lon", "lat"])) + id1.append(label1[sl_self.start]) + i02, indexs2, id2 = list(), List(), list() + for sl_other, _, _ in other.iter_on(ids2): + i02.append(sl_other.start) + indexs2.append( + obs_other[sl_other].argsort(order=["time", "lon", "lat"]) + ) + id2.append(label2[sl_other.start]) + + id1, id2 = array(id1), array(id2) + # We search item from self in item of others + i_local_target = same_position( + x1, y1, t1, x2, y2, t2, array(i01), array(i02), indexs1, indexs2 + ) + + # -1 => no item found in other dataset + m = i_local_target != -1 + in_self.append(id1) + track2_ = -ones(id1.shape, dtype="i4") + track2_[m] = id2[i_local_target[m]] + in_other.append(track2_) + + return concatenate(in_self), concatenate(in_other) + + @classmethod + def __tag_segment(cls, seg, tag, groups, connexions): + """ + Will set same temporary ID for each connected segment. + + :param int seg: current ID of segment + :param ing tag: temporary ID to set for segment and its connexion + :param array[int] groups: array where tag is stored + :param dict connexions: gives for one ID of segment all connected segments + """ + # If segments are already used we stop recursivity + if groups[seg] != 0: + return + # We set tag for this segment + groups[seg] = tag + # Get all connexions of this segment + segs = connexions.get(seg, None) + if segs is not None: + for seg in segs: + # For each connexion we apply same function + cls.__tag_segment(seg, tag, groups, connexions) + + def tag_segment(self): + """For each segment, method give a new network id, and all segment are connected + + :return array: for each unique seg id, it return new network id + """ + nb = self.segment_track_array[-1] + 1 + sub_group = zeros(nb, dtype="u4") + c = self.connexions(multi_network=True) + j = 1 + # for each available id + for i in range(nb): + # No connexions, no need to explore + if i not in c: + sub_group[i] = j + j += 1 + continue + # Skip if already set + if sub_group[i] != 0: + continue + # we tag an unset segments and explore all connexions + self.__tag_segment(i, j, sub_group, c) + j += 1 + return sub_group + + def fully_connected(self): + """Suspicious""" + raise Exception("Must be check") + self.only_one_network() + return self.tag_segment().shape[0] == 1 + + def first_is_trash(self): + """Check if first network is Trash + + :return bool: True if first network is trash + """ + i_start, i_stop, _ = self.index_segment_track + sl = slice(i_start[0], i_stop[0]) + return (self.previous_obs[sl] == -1).all() and (self.next_obs[sl] == -1).all() + + def remove_trash(self): + """ + Remove the lonely eddies (only 1 obs in segment, associated network number is 0) + """ + if self.first_is_trash(): + return self.extract_with_mask(self.track != 0) + else: + return self + + def plot(self, ax, ref=None, color_cycle=None, **kwargs): + """ + This function draws the path of each trajectory + + :param matplotlib.axes.Axes ax: ax to draw + :param float,int ref: if defined, all coordinates are wrapped with ref 
as western boundary + :param dict kwargs: keyword arguments for Axes.plot + :return: a list of matplotlib mappables + """ + kwargs = kwargs.copy() + if color_cycle is None: + color_cycle = self.COLORS + nb_colors = len(color_cycle) + mappables = list() + if "label" in kwargs: + kwargs["label"] = self.format_label(kwargs["label"]) + j = 0 + for i, _, _ in self.iter_on(self.segment_track_array): + nb = i.stop - i.start + if nb == 0: + continue + if nb_colors: + kwargs["color"] = color_cycle[j % nb_colors] + x, y = self.lon[i], self.lat[i] + if ref is not None: + x, y = wrap_longitude(x, y, ref, cut=True) + mappables.append(ax.plot(x, y, **kwargs)[0]) + j += 1 + return mappables + + def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None, return_mask=False): + """ + Remove short segments that don't connect several segments + + :param int nobs: Minimal number of observation to keep a segment + :param int ndays: Minimal number of days to keep a segment + :param int recursive: Run method N times more + :param int mask: if one or more observation of the segment are selected by mask, the segment is kept + + .. warning:: + It will remove short segment that splits from then merges with the same segment + """ + connexions = self.connexions(multi_network=True) + i0, i1, _ = self.index_segment_track + dt = self.time[i1 - 1] - self.time[i0] + 1 + nb = i1 - i0 + m = (dt >= ndays) * (nb >= nobs) + nb_connexions = array([len(connexions.get(i, tuple())) for i in where(~m)[0]]) + m[~m] = nb_connexions >= 2 + segments_keep = where(m)[0] + if mask is not None: + segments_keep = unique( + concatenate((segments_keep, self.segment_track_array[mask])) + ) + # get mask for selected obs + m = ~self.segment_mask(segments_keep) + if return_mask: + return ~m + self.track[m] = 0 + self.segment[m] = 0 + self.previous_obs[m] = -1 + self.previous_cost[m] = 0 + self.next_obs[m] = -1 + self.next_cost[m] = 0 + + m_previous = m[self.previous_obs] + self.previous_obs[m_previous] = -1 + self.previous_cost[m_previous] = 0 + m_next = m[self.next_obs] + self.next_obs[m_next] = -1 + self.next_cost[m_next] = 0 + + self.sort() + if recursive > 0: + self.remove_dead_end(nobs, ndays, recursive - 1) + + + + def extract_segment(self, segments, absolute=False): + """Extract given segments + + :param array,tuple,list segments: list of segment to extract + :param bool absolute: keep for compatibility, defaults to False + :return NetworkObservations: Return observations from selected segment + """ + if not absolute: + raise Exception("Not implemented") + return self.extract_with_mask(self.segment_mask(segments)) + + def segment_mask(self, segments): + """Get mask from list of segment + + :param list,array segments: absolute id of segment + """ + return generate_mask_from_ids( + array(segments), len(self), *self.index_segment_track + ) + + def get_mask_with_period(self, period): + """ + obtain mask within a time period + + :param (int,int) period: two dates to define the period, must be specified from 1/1/1950 + :return: mask where period is defined + :rtype: np.array(bool) + + """ + dataset_period = self.period + p_min, p_max = period + if p_min > 0: + mask = self.time >= p_min + elif p_min < 0: + mask = self.time >= (dataset_period[0] - p_min) + else: + mask = ones(self.time.shape, dtype=bool_) + if p_max > 0: + mask *= self.time <= p_max + elif p_max < 0: + mask *= self.time <= (dataset_period[1] + p_max) + return mask + + def extract_with_period(self, period): + """ + Extract within a time period + + :param (int,int) period: two 
dates to define the period, must be specified from 1/1/1950 + :return: Return all eddy trajectories in period + :rtype: NetworkObservations + + .. minigallery:: py_eddy_tracker.NetworkObservations.extract_with_period + """ + + return self.extract_with_mask(self.get_mask_with_period(period)) + + def extract_light_with_mask(self, mask, track_extra_variables=[]): + """extract data with mask, but only with variables used for coherence, aka self.array_variables + + :param mask: mask used to extract + :type mask: np.array(bool) + :return: new EddiesObservation with data wanted + :rtype: self + """ + + if isinstance(mask, slice): + nb_obs = mask.stop - mask.start + else: + nb_obs = mask.sum() + + # only time & contour_lon/lat_e/s + variables = ["time"] + self.array_variables + new = self.__class__( + size=nb_obs, + track_extra_variables=track_extra_variables, + track_array_variables=self.track_array_variables, + array_variables=self.array_variables, + only_variables=variables, + raw_data=self.raw_data, + ) + new.sign_type = self.sign_type + if nb_obs == 0: + logger.info("Empty dataset will be created") + else: + logger.info( + f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})" + ) + + for field in variables + track_extra_variables: + logger.debug("Copy of field %s ...", field) + new.obs[field] = self.obs[field][mask] + + if ( + "previous_obs" in track_extra_variables + and "next_obs" in track_extra_variables + ): + # n & p must be re-index + n, p = self.next_obs[mask], self.previous_obs[mask] + # we add 2 for -1 index return index -1 + translate = -ones(len(self) + 1, dtype="i4") + translate[:-1][mask] = arange(nb_obs) + new.next_obs[:] = translate[n] + new.previous_obs[:] = translate[p] + + return new + + def extract_with_mask(self, mask): + """ + Extract a subset of observations. + + :param array(bool) mask: mask to select observations + :return: same object with selected observations + :rtype: self + """ + if isinstance(mask, slice): + nb_obs = mask.stop - mask.start + else: + nb_obs = mask.sum() + new = self.__class__.new_like(self, nb_obs) + new.sign_type = self.sign_type + if nb_obs == 0: + logger.info("Empty dataset will be created") + else: + logger.debug( + f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})" + ) + for field in self.fields: + if field in ("next_obs", "previous_obs"): + continue + logger.debug("Copy of field %s ...", field) + new.obs[field] = self.obs[field][mask] + # n & p must be re-index + n, p = self.next_obs[mask], self.previous_obs[mask] + # we add 2 for -1 index return index -1 + translate = -ones(len(self) + 1, dtype="i4") + translate[:-1][mask] = arange(nb_obs) + new.next_obs[:] = translate[n] + new.previous_obs[:] = translate[p] + return new + + def analysis_coherence( + self, + date_function, + uv_params, + advection_mode="both", + n_days=14, + step_mesh=1.0 / 50, + output_name=None, + dissociate_network=False, + correct_close_events=0, + remove_dead_end=0, + ): + """Global function to analyse segments coherence, with network preprocessing. 
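+
+        A hedged sketch of a typical call, where ``n`` stands for an existing
+        NetworkObservations and ``date2file`` / ``uv_params`` are user-provided
+        inputs as described in the parameters below:
+
+        .. code-block:: python
+
+            network_clean, res = n.analysis_coherence(
+                date2file, uv_params, n_days=14, output_name=None
+            )
+            # with advection_mode="both" (default), res holds
+            # [target_forward, pct_forward, target_backward, pct_backward]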
+ :param callable date_function: python function, takes as param `int` (julian day) and return + data filename associated to the date + :param dict uv_params: dict of parameters used by + :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` + :param int n_days: nuber of days for advection + :param float step_mesh: step for particule mesh in degrees + :param str output_name: path/name for the output (without extension) to store the clean + network in .nc and the coherence results in .zarr. Works only for advection_mode = "both" + :param bool dissociate_network: If True apply + :py:meth:`~py_eddy_tracker.observation.network.NetworkObservations.dissociate_network` + :param int correct_close_events: Number of days in + :py:meth:`~py_eddy_tracker.observation.network.NetworkObservations.correct_close_events` + :param int remove_dead_end: Number of days in + :py:meth:`~py_eddy_tracker.observation.network.NetworkObservations.remove_dead_end` + :return target_forward, target_bakward: 2D numpy.array with the eddy observation the + particles ended in after advection + :return target_forward, target_bakward: percentage of ending particles within the + eddy observation with regards to the starting number + """ + + if dissociate_network: + self.dissociate_network() + + if correct_close_events > 0: + self.correct_close_events(nb_days_max=correct_close_events) + + if remove_dead_end > 0: + network_clean = self.remove_dead_end(nobs=0, ndays=remove_dead_end) + else: + network_clean = self + + network_clean.numbering_segment() + + res = [] + if (advection_mode == "both") | (advection_mode == "forward"): + target_forward, pct_forward = network_clean.segment_coherence_forward( + date_function=date_function, + uv_params=uv_params, + n_days=n_days, + step_mesh=step_mesh, + ) + res = res + [target_forward, pct_forward] + + if (advection_mode == "both") | (advection_mode == "backward"): + target_backward, pct_backward = network_clean.segment_coherence_backward( + date_function=date_function, + uv_params=uv_params, + n_days=n_days, + step_mesh=step_mesh, + ) + res = res + [target_backward, pct_backward] + + if (output_name is not None) & (advection_mode == "both"): + # TODO : put some path verification? + # Save the clean network in netcdf + with netCDF4.Dataset(output_name + ".nc", "w") as fh: + network_clean.to_netcdf(fh) + # Save the results of particles advection in zarr + # zarr compression parameters + # TODO : check size? compression? + params_seg = dict() + params_pct = dict() + zg = zarr.open(output_name + ".zarr", mode="w") + zg.array("target_forward", target_forward, **params_seg) + zg.array("pct_forward", pct_forward, **params_pct) + zg.array("target_backward", target_backward, **params_seg) + zg.array("pct_backward", pct_backward, **params_pct) + + return network_clean, res + + def segment_coherence_backward( + self, + date_function, + uv_params, + n_days=14, + step_mesh=1.0 / 50, + contour_start="speed", + contour_end="speed", + ): + """ + Percentage of particules and their targets after backward advection from a specific eddy. + + :param callable date_function: python function, takes as param `int` (julian day) and return + data filename associated to the date (see note) + :param dict uv_params: dict of parameters used by + :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` + :param int n_days: days for advection + :param float step_mesh: step for particule mesh in degrees + :return: observations matchs, and percents + + .. 
note:: the param `date_function` should be something like : + + .. code-block:: python + + def date2file(julian_day): + date = datetime.timedelta(days=julian_day) + datetime.datetime( + 1950, 1, 1 + ) + + return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc" + """ + shape = len(self), 2 + itb_final = -ones(shape, dtype="i4") + ptb_final = zeros(shape, dtype="i1") + + t_start, t_end = int(self.period[0]), int(self.period[1]) + + # dates = arange(t_start, t_start + n_days + 1) + dates = arange(t_start, min(t_start + n_days + 1, t_end + 1)) + first_files = [date_function(x) for x in dates] + + c = GridCollection.from_netcdf_list(first_files, dates, **uv_params) + first = True + range_start = t_start + n_days + range_end = t_end + 1 + + for _t in range(t_start + n_days, t_end + 1): + _timestamp = time.time() + t_shift = _t + + # skip first shift, because already included + if first: + first = False + else: + # add next date to GridCollection and delete last date + c.shift_files(t_shift, date_function(int(t_shift)), **uv_params) + particle_candidate( + c, + self, + step_mesh, + _t, + itb_final, + ptb_final, + n_days=-n_days, + contour_start=contour_start, + contour_end=contour_end, + ) + logger.info( + ( + f"coherence {_t} / {range_end - 1} ({(_t - range_start) / (range_end - range_start - 1):.1%})" + f" : {time.time() - _timestamp:5.2f}s" + ) + ) + + return itb_final, ptb_final + + def segment_coherence_forward( + self, + date_function, + uv_params, + n_days=14, + step_mesh=1.0 / 50, + contour_start="speed", + contour_end="speed", + **kwargs, + ): + """ + Percentage of particules and their targets after forward advection from a specific eddy. + + :param callable date_function: python function, takes as param `int` (julian day) and return + data filename associated to the date (see note) + :param dict uv_params: dict of parameters used by + :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` + :param int n_days: days for advection + :param float step_mesh: step for particule mesh in degrees + :return: observations matchs, and percents + + .. note:: the param `date_function` should be something like : + + .. 
code-block:: python + + def date2file(julian_day): + date = datetime.timedelta(days=julian_day) + datetime.datetime( + 1950, 1, 1 + ) + + return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc" + """ + shape = len(self), 2 + itf_final = -ones(shape, dtype="i4") + ptf_final = zeros(shape, dtype="i1") + + t_start, t_end = int(self.period[0]), int(self.period[1]) + + dates = arange(t_start, min(t_start + n_days + 1, t_end + 1)) + first_files = [date_function(x) for x in dates] + + c = GridCollection.from_netcdf_list(first_files, dates, **uv_params) + first = True + range_start = t_start + range_end = t_end - n_days + 1 + + for _t in range(range_start, range_end): + _timestamp = time.time() + t_shift = _t + n_days + + # skip first shift, because already included + if first: + first = False + else: + # add next date to GridCollection and delete last date + c.shift_files(t_shift, date_function(int(t_shift)), **uv_params) + particle_candidate( + c, + self, + step_mesh, + _t, + itf_final, + ptf_final, + n_days=n_days, + contour_start=contour_start, + contour_end=contour_end, + **kwargs, + ) + logger.info( + ( + f"coherence {_t} / {range_end - 1} ({(_t - range_start) / (range_end - range_start - 1):.1%})" + f" : {time.time() - _timestamp:5.2f}s" + ) + ) + return itf_final, ptf_final + + def mask_obs_close_event(self, merging=True, spliting=True, dt=3): + """Build a mask of close observation from event + + :param n: Network + :param bool merging: select merging event, defaults to True + :param bool spliting: select splitting event, defaults to True + :param int dt: delta of time max , defaults to 3 + :return array: mask + """ + m = zeros(len(self), dtype="bool") + if merging: + i_target, ip1, ip2 = self.merging_event(triplet=True, only_index=True) + mask_follow_obs(m, self.previous_obs, self.time, ip1, dt) + mask_follow_obs(m, self.previous_obs, self.time, ip2, dt) + mask_follow_obs(m, self.next_obs, self.time, i_target, dt) + if spliting: + i_target, in1, in2 = self.splitting_event(triplet=True, only_index=True) + mask_follow_obs(m, self.next_obs, self.time, in1, dt) + mask_follow_obs(m, self.next_obs, self.time, in2, dt) + mask_follow_obs(m, self.previous_obs, self.time, i_target, dt) + return m + + def swap_track( + self, + length_main_max_after_event=2, + length_secondary_min_after_event=10, + delta_pct_max=-0.2, + ): + events = self.splitting_event(triplet=True, only_index=True) + count = 0 + for i_main, i1, i2 in zip(*events): + seg_main, _, seg2 = ( + self.segment_track_array[i_main], + self.segment_track_array[i1], + self.segment_track_array[i2], + ) + i_start, i_end, i0 = self.index_segment_track + # For splitting + last_index_main = i_end[seg_main - i0] - 1 + last_index_secondary = i_end[seg2 - i0] - 1 + last_main_next_obs = self.next_obs[last_index_main] + t_event, t_main_end, t_secondary_start, t_secondary_end = ( + self.time[i_main], + self.time[last_index_main], + self.time[i2], + self.time[last_index_secondary], + ) + dt_main, dt_secondary = ( + t_main_end - t_event, + t_secondary_end - t_secondary_start, + ) + delta_cost = self.previous_cost[i2] - self.previous_cost[i1] + if ( + dt_main <= length_main_max_after_event + and dt_secondary >= length_secondary_min_after_event + and last_main_next_obs == -1 + and delta_cost > delta_pct_max + ): + self.segment[i1 : last_index_main + 1] = self.segment[i2] + self.segment[i2 : last_index_secondary + 1] = self.segment[i_main] + count += 1 + logger.info("%d segmnent swap on %d", count, len(events[0])) + return self.sort() + + +class Network: + 
__slots__ = ( + "window", + "filenames", + "nb_input", + "buffer", + "memory", + ) + + NOGROUP = TrackEddiesObservations.NOGROUP + + def __init__(self, input_regex, window=5, intern=False, memory=False): + """ + Class to group observations by network + """ + self.window = window + self.buffer = Buffer(window, intern, memory) + self.memory = memory + + self.filenames = glob(input_regex) + self.filenames.sort() + self.nb_input = len(self.filenames) + + @classmethod + def from_eddiesobservations(cls, observations, *args, **kwargs): + new = cls("", *args, **kwargs) + new.filenames = observations + new.nb_input = len(new.filenames) + return new + def get_group_array(self, results, nb_obs): """With a loop on all pair of index, we will label each obs with a group number """ - nb_obs = array(nb_obs) + nb_obs = array(nb_obs, dtype="u4") day_start = nb_obs.cumsum() - nb_obs gr = empty(nb_obs.sum(), dtype="u4") gr[:] = self.NOGROUP + merge_id = list() id_free = 1 for i, j, ii, ij in results: gr_i = gr[slice(day_start[i], day_start[i] + nb_obs[i])] @@ -66,41 +2026,87 @@ def get_group_array(self, results, nb_obs): if m.any(): # Merge of group, ref over etu for i_, j_ in zip(ii[m], ij[m]): - gr_i_, gr_j_ = gr_i[i_], gr_j[j_] - gr[gr == gr_i_] = gr_j_ - return gr + g0, g1 = gr_i[i_], gr_j[j_] + if g0 > g1: + g0, g1 = g1, g0 + merge_id.append((g0, g1)) + gr_transfer = self.group_translator(id_free, set(merge_id)) + return gr_transfer[gr] + + @staticmethod + def group_translator(nb, duos): + """ + Create a translator with all duos + + :param int nb: size of translator + :param set((int, int)) duos: set of all groups that must be joined + + :Example: + + >>> NetworkObservations.group_translator(5, ((0, 1), (0, 2), (1, 3))) + [3, 3, 3, 3, 5] + """ + translate = arange(nb, dtype="u4") + for i, j in sorted(duos): + gr_i, gr_j = translate[i], translate[j] + if gr_i != gr_j: + apply_replace(translate, gr_i, gr_j) + return translate + + def group_observations(self, min_overlap=0.2, minimal_area=False, **kwargs): + """Store every interaction between identifications + + :param bool minimal_area: If True, function will compute intersection/little polygon, else intersection/union, by default False + :param float min_overlap: minimum overlap area to associate observations, by default 0.2 + + :return: + :rtype: TrackEddiesObservations + """ - def group_observations(self, **kwargs): results, nb_obs = list(), list() # To display print only in INFO display_iteration = logger.getEffectiveLevel() == logging.INFO + for i, filename in enumerate(self.filenames): if display_iteration: print(f"{filename} compared to {self.window} next", end="\r") - # Load observations with function to buffered observations - xi, yi = self.load_contour(filename) + # Load observations with function to buffer observations + xi, yi = self.buffer.load_contour(filename) # Append number of observations by filename nb_obs.append(xi.shape[0]) for j in range(i + 1, min(self.window + i + 1, self.nb_input)): - xj, yj = self.load_contour(self.filenames[j]) + xj, yj = self.buffer.load_contour(self.filenames[j]) ii, ij = bbox_intersection(xi, yi, xj, yj) - m = vertice_overlap(xi[ii], yi[ii], xj[ij], yj[ij], **kwargs) > 0.2 + m = ( + vertice_overlap( + xi[ii], + yi[ii], + xj[ij], + yj[ij], + minimal_area=minimal_area, + min_overlap=min_overlap, + **kwargs, + ) + != 0 + ) results.append((i, j, ii[m], ij[m])) if display_iteration: print() gr = self.get_group_array(results, nb_obs) + nb_alone, nb_obs, nb_gr = (gr == self.NOGROUP).sum(), len(gr), len(unique(gr)) 
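+        # Note: nb_alone counts observations never matched with any neighbour (still
+        # flagged NOGROUP); the obs/group ratio logged below excludes them and drops
+        # one group id from the denominator (presumably the NOGROUP label itself).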
logger.info( - f"{(gr == self.NOGROUP).sum()} alone / {len(gr)} obs, {len(unique(gr))} groups" + f"{nb_alone} alone / {nb_obs} obs, {nb_gr} groups, " + f"{nb_alone * 100. / nb_obs:.2f} % alone, {(nb_obs - nb_alone) / (nb_gr - 1):.1f} obs/group" ) return gr - def build_dataset(self, group): + def build_dataset(self, group, raw_data=True): nb_obs = group.shape[0] - model = EddiesObservations.load_file(self.filenames[-1], raw_data=True) + model = TrackEddiesObservations.load_file(self.filenames[-1], raw_data=raw_data) eddies = TrackEddiesObservations.new_like(model, nb_obs) eddies.sign_type = model.sign_type - # Get new index to re-order observation by group + # Get new index to re-order observations by groups new_i = get_next_index(group) display_iteration = logger.getEffectiveLevel() == logging.INFO elements = eddies.elements @@ -109,7 +2115,12 @@ def build_dataset(self, group): for filename in self.filenames: if display_iteration: print(f"Load {filename} to copy", end="\r") - e = EddiesObservations.load_file(filename, raw_data=True) + if self.memory: + # Only if netcdf + with open(filename, "rb") as h: + e = TrackEddiesObservations.load_file(h, raw_data=raw_data) + else: + e = TrackEddiesObservations.load_file(filename, raw_data=raw_data) stop = i + len(e) sl = slice(i, stop) for element in elements: @@ -117,14 +2128,101 @@ def build_dataset(self, group): i = stop if display_iteration: print() - eddies = eddies.add_fields(("track",)) eddies.track[new_i] = group return eddies +@njit(cache=True) +def get_percentile_on_following_obs( + i, indexs, percents, follow_obs, t, segment, i_target, window, q=50, nb_min=1 +): + """Get stat on a part of segment close of an event + + :param int i: index to follow + :param array indexs: indexs from coherence + :param array percents: percent from coherence + :param array[int] follow_obs: give index for the following observation + :param array t: time for each observation + :param array segment: segment for each observation + :param int i_target: index of target + :param int window: time window of search + :param int q: Percentile from 0 to 100, defaults to 50 + :param int nb_min: Number minimal of observation to provide statistics, defaults to 1 + :return float : return statistic + """ + last_t, segment_follow = t[i], segment[i] + segment_target = segment[i_target] + percent_target = empty(window, dtype=percents.dtype) + j = 0 + while abs(last_t - t[i]) < window and i != -1 and segment_follow == segment[i]: + # Iter on primary & secondary + for index, percent in zip(indexs[i], percents[i]): + if index != -1 and segment[index] == segment_target: + percent_target[j] = percent + j += 1 + i = follow_obs[i] + if j < nb_min: + return nan + return percentile(percent_target[:j], q) + + +@njit(cache=True) +def get_percentile_around_event( + i, + i1, + i2, + ind, + pct, + follow_obs, + t, + segment, + window=10, + follow_parent=False, + q=50, + nb_min=1, +): + """Get stat around event + + :param array[int] i: Indexs of target + :param array[int] i1: Indexs of primary origin + :param array[int] i2: Indexs of secondary origin + :param array ind: indexs from coherence + :param array pct: percent from coherence + :param array[int] follow_obs: give index for the following observation + :param array t: time for each observation + :param array segment: segment for each observation + :param int window: time window of search, defaults to 10 + :param bool follow_parent: Follow parent instead of child, defaults to False + :param int q: Percentile from 0 to 100, defaults to 50 
+ :param int nb_min: Number minimal of observation to provide statistics, defaults to 1 + :return (array,array) : statistic for each event + """ + stat1 = empty(i.size, dtype=nb_types.float32) + stat2 = empty(i.size, dtype=nb_types.float32) + # iter on event + for j, (i_, i1_, i2_) in enumerate(zip(i, i1, i2)): + if follow_parent: + # We follow parent + stat1[j] = get_percentile_on_following_obs( + i_, ind, pct, follow_obs, t, segment, i1_, window, q, nb_min + ) + stat2[j] = get_percentile_on_following_obs( + i_, ind, pct, follow_obs, t, segment, i2_, window, q, nb_min + ) + else: + # We follow child + stat1[j] = get_percentile_on_following_obs( + i1_, ind, pct, follow_obs, t, segment, i_, window, q, nb_min + ) + stat2[j] = get_percentile_on_following_obs( + i2_, ind, pct, follow_obs, t, segment, i_, window, q, nb_min + ) + return stat1, stat2 + + @njit(cache=True) def get_next_index(gr): - """Return for each obs index the new position to join all group""" + """Return for each obs index the new position to join all groups""" nb_obs_gr = bincount(gr) i_gr = nb_obs_gr.cumsum() - nb_obs_gr new_index = empty(gr.shape, dtype=uint32) @@ -132,3 +2230,139 @@ def get_next_index(gr): new_index[i] = i_gr[g] i_gr[g] += 1 return new_index + + +@njit(cache=True) +def apply_replace(x, x0, x1): + nb = x.shape[0] + for i in range(nb): + if x[i] == x0: + x[i] = x1 + + +@njit(cache=True) +def build_unique_array(id1, id2): + """Give a unique id for each (id1, id2) with id1 and id2 increasing monotonically""" + k = 0 + new_id = empty(id1.shape, dtype=id1.dtype) + id1_previous = id1[0] + id2_previous = id2[0] + for i in range(id1.shape[0]): + id1_, id2_ = id1[i], id2[i] + if id1_ != id1_previous or id2_ != id2_previous: + k += 1 + new_id[i] = k + id1_previous, id2_previous = id1_, id2_ + return new_id + + +@njit(cache=True) +def new_numbering(segs, start=0): + nb = len(segs) + s0 = segs[0] + j = start + for i in range(nb): + if segs[i] != s0: + s0 = segs[i] + j += 1 + segs[i] = j + + +@njit(cache=True) +def ptp(values): + return values.max() - values.min() + + +@njit(cache=True) +def generate_mask_from_ids(id_networks, nb, istart, iend, i0): + """From list of id, we generate a mask + + :param array id_networks: list of ids + :param int nb: size of mask + :param array istart: first index for each id from :py:meth:`~py_eddy_tracker.generic.build_index` + :param array iend: last index for each id from :py:meth:`~py_eddy_tracker.generic.build_index` + :param int i0: ref index from :py:meth:`~py_eddy_tracker.generic.build_index` + :return array: return a mask + """ + m = zeros(nb, dtype="bool") + for i in id_networks: + for j in range(istart[i - i0], iend[i - i0]): + m[j] = True + return m + + +@njit(cache=True) +def same_position(x0, y0, t0, x1, y1, t1, i00, i01, i0, i1): + """Return index of track/segment found in other dataset + + :param array x0: + :param array y0: + :param array t0: + :param array x1: + :param array y1: + :param array t1: + :param array i00: First index of track/segment/network in dataset0 + :param array i01: First index of track/segment/network in dataset1 + :param List(array) i0: list of array which contain index to order dataset0 + :param List(array) i1: list of array which contain index to order dataset1 + :return array: index of dataset1 which match with dataset0, -1 => no match + """ + nb0, nb1 = i00.size, i01.size + i_target = -ones(nb0, dtype="i4") + # To avoid to compare multiple time, if already match + used1 = zeros(nb1, dtype="bool") + for j0 in range(nb0): + for j1 in range(nb1): + 
if used1[j1]: + continue + test = True + for i0_, i1_ in zip(i0[j0], i1[j1]): + i0_ += i00[j0] + i1_ += i01[j1] + if t0[i0_] != t1[i1_] or x0[i0_] != x1[i1_] or y0[i0_] != y1[i1_]: + test = False + break + if test: + i_target[j0] = j1 + used1[j1] = True + break + return i_target + + +@njit(cache=True) +def mask_follow_obs(m, next_obs, time, indexs, dt=3): + """Generate a mask to select close obs in time from index + + :param array m: mask to fill with True + :param array next_obs: index of the next observation + :param array time: time of each obs + :param array indexs: index to start follow + :param int dt: delta of time max from index, defaults to 3 + """ + for i in indexs: + t0 = time[i] + m[i] = True + i_next = next_obs[i] + dt_ = abs(time[i_next] - t0) + while dt_ < dt and i_next != -1: + m[i_next] = True + i_next = next_obs[i_next] + dt_ = abs(time[i_next] - t0) + + +@njit(cache=True) +def get_period_with_index(t, i0, i1): + """Return peek to peek cover by each slice define by i0 and i1 + + :param array t: array which contain values to estimate spread + :param array i0: index which determine start of slice + :param array i1: index which determine end of slice + :return array: Peek to peek of t + """ + periods = np.empty(i0.size, t.dtype) + for i in range(i0.size): + if i1[i] == i0[i]: + periods[i] = 0 + continue + periods[i] = t[i0[i] : i1[i]].ptp() + return periods diff --git a/src/py_eddy_tracker/observations/observation.py b/src/py_eddy_tracker/observations/observation.py index 005de056..b39f7f83 100644 --- a/src/py_eddy_tracker/observations/observation.py +++ b/src/py_eddy_tracker/observations/observation.py @@ -2,18 +2,18 @@ """ Base class to manage eddy observation """ -import logging from datetime import datetime +from io import BufferedReader, BytesIO +import logging from tarfile import ExFileObject from tokenize import TokenError -import zarr +from Polygon import Polygon from matplotlib.cm import get_cmap -from matplotlib.collections import PolyCollection +from matplotlib.collections import LineCollection, PolyCollection from matplotlib.colors import Normalize from netCDF4 import Dataset -from numba import njit -from numba import types as numba_types +from numba import njit, types as numba_types from numpy import ( absolute, arange, @@ -22,6 +22,7 @@ ceil, concatenate, cos, + datetime64, digitize, empty, errstate, @@ -32,6 +33,7 @@ isnan, linspace, ma, + nan, ndarray, ones, percentile, @@ -41,9 +43,10 @@ where, zeros, ) +import packaging.version from pint import UnitRegistry from pint.errors import UndefinedUnitError -from Polygon import Polygon +import zarr from .. import VAR_DESCR, VAR_DESCR_inv, __version__ from ..generic import ( @@ -55,20 +58,49 @@ hist_numba, local_to_coordinates, reverse_index, + window_index, wrap_longitude, ) from ..poly import ( bbox_intersection, close_center, convexs, + create_meshed_particles, create_vertice, get_pixel_in_regular, + insidepoly, + poly_indexs, + reduce_size, vertice_overlap, - winding_number_poly, ) logger = logging.getLogger("pet") +# keep only major and minor version number +_software_version_reduced = packaging.version.Version( + "{v.major}.{v.minor}".format(v=packaging.version.parse(__version__)) +) +_display_check_warning = True + + +def _check_versions(version): + """Check if version of py_eddy_tracker used to create the file is compatible with software version + + if not, warn user with both versions + + :param version: string version of software used to create the file. 
If None, version was not provided + :type version: str, None + """ + if not _display_check_warning: + return + file_version = packaging.version.parse(version) if version is not None else None + if file_version is None or file_version < _software_version_reduced: + logger.warning( + "File was created with py-eddy-tracker version '%s' but software version is '%s'", + file_version, + _software_version_reduced, + ) + @njit(cache=True, fastmath=True) def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5): @@ -83,7 +115,7 @@ def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5 # Focal f_right = lon0 f_left = f_right - (c - minor) - # Ellips center + # Ellipse center x_c = (f_left + f_right) * 0.5 nb_0, nb_1 = lat0.shape[0], lat1.shape[0] @@ -101,11 +133,33 @@ def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5 if dx > major[j]: m[j, i] = False continue - d_normalize = dx ** 2 / major[j] ** 2 + dy ** 2 / minor ** 2 + d_normalize = dx**2 / major[j] ** 2 + dy**2 / minor**2 m[j, i] = d_normalize < 1.0 return m +class Table(object): + def __init__(self, values): + self.values = values + + def _repr_html_(self): + rows = list() + if isinstance(self.values, ndarray): + row = "\n".join([f"{v}" for v in self.values.dtype.names]) + rows.append(f"{row}") + for row in self.values: + row = "\n".join([f"{v}" for v in row]) + rows.append(f"{row}") + rows = "\n".join(rows) + return ( + f'' + f'' + f"{rows}" + f"
" + f"
" + ) + + class EddiesObservations(object): """ Class to store eddy observations. @@ -142,6 +196,27 @@ class EddiesObservations(object): "height_inner_contour", ] + COLORS = [ + "sienna", + "red", + "darkorange", + "gold", + "palegreen", + "limegreen", + "forestgreen", + "mediumblue", + "dodgerblue", + "lightskyblue", + "violet", + "blueviolet", + "darkmagenta", + "darkgrey", + "dimgrey", + "steelblue", + ] + + NB_COLORS = len(COLORS) + def __init__( self, size=0, @@ -176,6 +251,10 @@ def __eq__(self, other): return False return array_equal(self.obs, other.obs) + def get_color(self, i): + """Return colors as a cyclic list""" + return self.COLORS[i % self.NB_COLORS] + @property def sign_legend(self): return "Cyclonic" if self.sign_type != 1 else "Anticyclonic" @@ -189,7 +268,7 @@ def get_infos(self): bins_lat=(-90, -60, -15, 15, 60, 90), bins_amplitude=array((0, 1, 2, 3, 4, 5, 10, 500)), bins_radius=array((0, 15, 30, 45, 60, 75, 100, 200, 2000)), - nb_obs=self.observations.shape[0], + nb_obs=len(self), ) t0, t1 = self.period infos["t0"], infos["t1"] = t0, t1 @@ -200,33 +279,60 @@ def _repr_html_(self): infos = self.get_infos() return f"""{infos['nb_obs']} observations from {infos['t0']} to {infos['t1']} """ + def parse_varname(self, name): + return self[name] if isinstance(name, str) else name + def hist(self, varname, x, bins, percent=False, mean=False, nb=False): """Build histograms. - :param str varname: variable to use to compute stat - :param str x: variable to use to know in which bins + :param str,array varname: variable to use to compute stat + :param str,array x: variable to use to know in which bins :param array bins: - :param bool percent: normalize by sum of all bins + :param bool percent: normalized by sum of all bins :param bool mean: compute mean by bins :param bool nb: only count by bins :return: value by bins :rtype: array """ + x = self.parse_varname(x) if nb: - v = hist_numba(self[x], bins=bins)[0] + v = hist_numba(x, bins=bins)[0] else: - v = histogram(self[x], bins=bins, weights=self[varname])[0] + v = histogram(x, bins=bins, weights=self.parse_varname(varname))[0] if percent: v = v.astype("f4") / v.sum() * 100 elif mean: - v /= hist_numba(self[x], bins=bins)[0] + v /= hist_numba(x, bins=bins)[0] return v @staticmethod def box_display(value): - """Return value evenly spaced with few numbers""" + """Return values evenly spaced with few numbers""" return "".join([f"{v_:10.2f}" for v_ in value]) + @property + def fields(self): + return list(self.obs.dtype.names) + + def field_table(self): + """ + Produce description table of the fields available in this object + """ + rows = [("Name (Unit)", "Long name", "Scale factor", "Offset")] + names = self.fields + names.sort() + for field in names: + infos = VAR_DESCR[field] + rows.append( + ( + f"{infos.get('nc_name', field)} ({infos['nc_attr'].get('units', '')})", + infos["nc_attr"].get("long_name", "").capitalize(), + infos.get("scale_factor", ""), + infos.get("add_offset", ""), + ) + ) + return Table(rows) + def __repr__(self): """ Return general informations on dataset as strings. 
@@ -239,7 +345,7 @@ def __repr__(self): bins_lat = (-90, -60, -15, 15, 60, 90) bins_amplitude = array((0, 1, 2, 3, 4, 5, 10, 500)) bins_radius = array((0, 15, 30, 45, 60, 75, 100, 200, 2000)) - nb_obs = self.observations.shape[0] + nb_obs = len(self) return f""" | {nb_obs} observations from {t0} to {t1} ({period} days, ~{nb_obs / period:.0f} obs/day) | Speed area : {self.speed_area.sum() / period / 1e12:.2f} Mkm²/day @@ -310,11 +416,34 @@ def obs_dimension(cls, handler): if candidate in handler.dimensions.keys(): return candidate + def remove_fields(self, *fields): + """ + Copy with fields listed remove + """ + nb_obs = len(self) + fields = set(fields) + only_variables = set(self.fields) - fields + track_extra_variables = set(self.track_extra_variables) - fields + array_variables = set(self.array_variables) - fields + new = self.__class__( + size=nb_obs, + track_extra_variables=track_extra_variables, + track_array_variables=self.track_array_variables, + array_variables=array_variables, + only_variables=only_variables, + raw_data=self.raw_data, + ) + new.sign_type = self.sign_type + for name in new.fields: + logger.debug("Copy of field %s ...", name) + new.obs[name] = self.obs[name] + return new + def add_fields(self, fields=list(), array_fields=list()): """ Add a new field. """ - nb_obs = self.obs.shape[0] + nb_obs = len(self) new = self.__class__( size=nb_obs, track_extra_variables=list( @@ -322,44 +451,41 @@ def add_fields(self, fields=list(), array_fields=list()): ), track_array_variables=self.track_array_variables, array_variables=list(concatenate((self.array_variables, array_fields))), - only_variables=list( - concatenate((self.obs.dtype.names, fields, array_fields)) - ), + only_variables=list(concatenate((self.fields, fields, array_fields))), raw_data=self.raw_data, ) new.sign_type = self.sign_type - for field in self.obs.dtype.descr: - logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var] + for name in self.fields: + logger.debug("Copy of field %s ...", name) + new.obs[name] = self.obs[name] return new def add_rotation_type(self): new = self.add_fields(("type_cyc",)) - new.type_cyc = self.sign_type + new.type_cyc[:] = self.sign_type return new - def circle_contour(self, only_virtual=False): + def circle_contour(self, only_virtual=False, factor=1): """ - Set contours as a circles from radius and center data. + Set contours as circles from radius and center data. .. 
minigallery:: py_eddy_tracker.EddiesObservations.circle_contour """ angle = radians(linspace(0, 360, self.track_array_variables)) x_norm, y_norm = cos(angle), sin(angle) - radius_s = "contour_lon_s" in self.obs.dtype.names - radius_e = "contour_lon_e" in self.obs.dtype.names + radius_s = "contour_lon_s" in self.fields + radius_e = "contour_lon_e" in self.fields for i, obs in enumerate(self): if only_virtual and not obs["virtual"]: continue x, y = obs["lon"], obs["lat"] if radius_s: - r_s = obs["radius_s"] + r_s = obs["radius_s"] * factor obs["contour_lon_s"], obs["contour_lat_s"] = local_to_coordinates( x_norm * r_s, y_norm * r_s, x, y ) if radius_e: - r_e = obs["radius_e"] + r_e = obs["radius_e"] * factor obs["contour_lon_e"], obs["contour_lat_e"] = local_to_coordinates( x_norm * r_e, y_norm * r_e, x, y ) @@ -423,9 +549,9 @@ def merge(self, other): nb_obs_self = len(self) nb_obs = nb_obs_self + len(other) eddies = self.new_like(self, nb_obs) - other_keys = other.obs.dtype.fields.keys() - self_keys = self.obs.dtype.fields.keys() - for key in eddies.obs.dtype.fields.keys(): + other_keys = other.fields + self_keys = self.fields + for key in eddies.fields: eddies.obs[key][:nb_obs_self] = self.obs[key][:] if key in other_keys: eddies.obs[key][nb_obs_self:] = other.obs[key][:] @@ -450,48 +576,76 @@ def __iter__(self): for obs in self.obs: yield obs - def iter_on(self, xname: str, bins=None): + def iter_on(self, xname, window=None, bins=None): """ Yield observation group for each bin. - :param str xname: - :param array bins: bounds of each bin , - :return: Group observations - :rtype: self.__class__ + :param str,array xname: + :param float,None window: if defined we use a moving window with value like half window + :param array bins: bounds of each bin + :yield array,float,float: index in self, lower bound, upper bound + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.iter_on """ - x = self[xname] - d = x[1:] - x[:-1] - if bins is None: - bins = arange(x.min(), x.max() + 2) - nb_bins = len(bins) - 1 - i = digitize(x, bins) - 1 - # Not monotonous - if (d < 0).any(): - i_sort = i.argsort() - i0, i1, _ = build_index(i[i_sort]) - m = ~(i0 == i1) - i0, i1 = i0[m], i1[m] - for i0_, i1_ in zip(i0, i1): - i_bins = i[i_sort[i0_]] - if i_bins == -1 or i_bins == nb_bins: - continue - yield i_sort[i0_:i1_], bins[i_bins], bins[i_bins + 1] + x = self.parse_varname(xname) + if window is not None: + x0 = arange(x.min(), x.max()) if bins is None else array(bins) + i_ordered, first_index, last_index = window_index(x, x0, window) + for x_, i0, i1 in zip(x0, first_index, last_index): + yield i_ordered[i0:i1], x_ - window, x_ + window else: - i0, i1, _ = build_index(i) - m = ~(i0 == i1) - i0, i1 = i0[m], i1[m] - for i0_, i1_ in zip(i0, i1): - i_bins = i[i0_] - yield slice(i0_, i1_), bins[i_bins], bins[i_bins + 1] - - def align_on(self, other, var_name="time", **kwargs): + d = x[1:] - x[:-1] + if bins is None: + bins = arange(x.min(), x.max() + 2) + elif not isinstance(bins, ndarray): + bins = array(bins) + nb_bins = len(bins) - 1 + + # Not monotonous + if (d < 0).any(): + # If bins cover a small part of value + test, translate, x = iter_mode_reduce(x, bins) + # convert value in bins number + i = numba_digitize(x, bins) - 1 + # Order by bins + i_sort = i.argsort() + # If in reduced mode we will translate i_sort in full array index + i_sort_ = translate[i_sort] if test else i_sort + # Bound for each bins in sorting view + i0, i1, _ = build_index(i[i_sort]) + m = ~(i0 == i1) + i0, i1 = i0[m], i1[m] + for i0_, i1_ in zip(i0, i1): + i_bins = i[i_sort[i0_]] + if i_bins == -1 or i_bins == nb_bins: + continue + yield i_sort_[i0_:i1_], bins[i_bins], bins[i_bins + 1] + else: + i = numba_digitize(x, bins) - 1 + i0, i1, _ = build_index(i) + m = ~(i0 == i1) + i0, i1 = i0[m], i1[m] + for i0_, i1_ in zip(i0, i1): + i_bins = i[i0_] + yield slice(i0_, i1_), bins[i_bins], bins[i_bins + 1] + + def align_on(self, other, var_name="time", all_ref=False, **kwargs): """ - Align the time indexes of two datasets. + Align the variable indices of two datasets. + + :param other: other compare with self + :param str,tuple var_name: variable name to align or two array, defaults to "time" + :param bool all_ref: yield all value of ref, if false only common value, defaults to False + :yield array,array,float,float: index in self, index in other, lower bound, upper bound + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.align_on """ - iter_self, iter_other = ( - self.iter_on(var_name, **kwargs), - other.iter_on(var_name, **kwargs), - ) + if isinstance(var_name, str): + iter_self = self.iter_on(var_name, **kwargs) + iter_other = other.iter_on(var_name, **kwargs) + else: + iter_self = self.iter_on(var_name[0], **kwargs) + iter_other = other.iter_on(var_name[1], **kwargs) indexs_other, b0_other, b1_other = iter_other.__next__() for indexs_self, b0_self, b1_self in iter_self: if b0_self > b0_other: @@ -501,15 +655,19 @@ def align_on(self, other, var_name="time", **kwargs): except StopIteration: break if b0_self < b0_other: + if all_ref: + yield indexs_self, empty( + 0, dtype=indexs_self.dtype + ), b0_self, b1_self continue yield indexs_self, indexs_other, b0_self, b1_self def insert_observations(self, other, index): - """Insert other obs in self at the index.""" + """Insert other obs in self at the given index.""" if not self.coherence(other): raise Exception("Observations with no coherence") - insert_size = len(other.obs) - self_size = len(self.obs) + insert_size = len(other) + self_size = len(self) new_size = self_size + insert_size if self_size == 0: self.observations = other.obs @@ -539,7 +697,7 @@ def distance(self, other): def __copy__(self): eddies = self.new_like(self, len(self)) - for k in self.obs.dtype.names: + for k in self.fields: eddies[k][:] = self[k][:] eddies.sign_type = self.sign_type return eddies @@ -547,9 +705,9 @@ def __copy__(self): def copy(self): return self.__copy__() - @staticmethod - def new_like(eddies, new_size: int): - return eddies.__class__( + @classmethod + def new_like(cls, eddies, new_size: int): + return cls( new_size, track_extra_variables=eddies.track_extra_variables, track_array_variables=eddies.track_array_variables, @@ -578,10 +736,13 @@ def zarr_dimension(filename): h = filename else: h = zarr.open(filename) + dims = list() for varname in h: - dims.extend(list(getattr(h, varname).shape)) - return set(dims) + shape = getattr(h, varname).shape + if len(shape) > len(dims): + dims = shape + return dims @classmethod def load_file(cls, filename, **kwargs): @@ -610,8 +771,13 @@ def load_file(cls, filename, **kwargs): ) if isinstance(filename, zarr.storage.MutableMapping): return cls.load_from_zarr(filename, **kwargs) - end = b".zarr" if isinstance(filename_, bytes) else ".zarr" - if filename_.endswith(end): + if isinstance(filename, (bytes, str)): + end = b".zarr" if isinstance(filename_, bytes) else ".zarr" + zarr_file = filename_.endswith(end) + else: + zarr_file = False + logger.info(f"loading file '{filename_}'") + if zarr_file: return cls.load_from_zarr(filename, **kwargs) else: return cls.load_from_netcdf(filename, **kwargs) @@ -630,26 +796,29 @@ def load_from_zarr( """Load data from zarr. 
:param str,store filename: path or store to load data - :param bool raw_data: If true load data without apply scale_factor and add_offset - :param None,list(str) remove_vars: List of variable name which will be not loaded + :param bool raw_data: If true load data without scale_factor and add_offset + :param None,list(str) remove_vars: List of variable name that will be not loaded :param None,list(str) include_vars: If defined only this variable will be loaded - :param None,dict indexs: Indexs to laad only a slice of data + :param None,dict indexs: Indexes to load only a slice of data :param int buffer_size: Size of buffer used to load zarr data :param class_kwargs: argument to set up observations class :return: Obsevations selected :return type: class """ - # FIXME must be investigate, in zarr no dimensions name (or could be add in attr) - array_dim = 50 + # FIXME if isinstance(filename, zarr.storage.MutableMapping): h_zarr = filename else: if not isinstance(filename, str): filename = filename.astype(str) h_zarr = zarr.open(filename) + + _check_versions(h_zarr.attrs.get("framework_version", None)) var_list = cls.build_var_list(list(h_zarr.keys()), remove_vars, include_vars) nb_obs = getattr(h_zarr, var_list[0]).shape[0] + track_array_variables = h_zarr.attrs["track_array_variables"] + if indexs is not None and "obs" in indexs: sl = indexs["obs"] sl = slice(sl.start, min(sl.stop, nb_obs)) @@ -662,29 +831,34 @@ def load_from_zarr( logger.warning("step of slice won't be use") logger.debug("%d observations will be load", nb_obs) kwargs = dict() - dims = cls.zarr_dimension(filename) - if array_dim in dims: - kwargs["track_array_variables"] = array_dim - kwargs["array_variables"] = list() - for variable in var_list: - if array_dim in h_zarr[variable].shape: - var_inv = VAR_DESCR_inv[variable] - kwargs["array_variables"].append(var_inv) - array_variables = kwargs.get("array_variables", list()) - kwargs["track_extra_variables"] = [] + + kwargs["track_array_variables"] = h_zarr.attrs.get( + "track_array_variables", track_array_variables + ) + + array_variables = list() + for variable in var_list: + if len(h_zarr[variable].shape) > 1: + var_inv = VAR_DESCR_inv[variable] + array_variables.append(var_inv) + kwargs["array_variables"] = array_variables + track_extra_variables = [] + for variable in var_list: var_inv = VAR_DESCR_inv[variable] if var_inv not in cls.ELEMENTS and var_inv not in array_variables: - kwargs["track_extra_variables"].append(var_inv) + track_extra_variables.append(var_inv) + kwargs["track_extra_variables"] = track_extra_variables kwargs["raw_data"] = raw_data kwargs["only_variables"] = ( None if include_vars is None else [VAR_DESCR_inv[i] for i in include_vars] ) kwargs.update(class_kwargs) eddies = cls(size=nb_obs, **kwargs) - for variable in var_list: + + for i_var, variable in enumerate(var_list): var_inv = VAR_DESCR_inv[variable] - logger.debug("%s will be loaded", variable) + logger.debug("%s will be loaded (%d/%d)", variable, i_var, len(var_list)) # find unit factor input_unit = h_zarr[variable].attrs.get("unit", None) if input_unit is None: @@ -729,10 +903,10 @@ def copy_data_to_zarr( :param zarr_dataset handler_zarr: :param array handler_eddies: :param slice zarr_dataset sl_obs: - :param int zarr_dataset buffer_size: - :param float zarr_dataset factor: - :param bool zarr_dataset raw_data: - :param None,float zarr_dataset scale_factor: + :param int buffer_size: + :param float factor: + :param bool raw_data: + :param None,float scale_factor: :param None,float add_offset: """ 
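A usage sketch for the loader options documented above: load_file dispatches on the ".zarr" suffix, while include_vars and indexs allow reading only part of a file. The path and the selected variable names are illustrative assumptions.

    from py_eddy_tracker.observations.observation import EddiesObservations

    # hypothetical path; a '.zarr' suffix routes to load_from_zarr, anything
    # else to load_from_netcdf
    eddies = EddiesObservations.load_file(
        "eddies_anticyclonic.zarr",
        include_vars=("time", "longitude", "latitude", "amplitude"),
        # read only the first 10000 observations
        indexs=dict(obs=slice(0, 10000)),
    )
    print(len(eddies), eddies.fields)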
i_start, i_stop = sl_obs.start, sl_obs.stop @@ -740,6 +914,7 @@ def copy_data_to_zarr( i_start = 0 if i_stop is None: i_stop = handler_zarr.shape[0] + for i in range(i_start, i_stop, buffer_size): sl_in = slice(i, min(i + buffer_size, i_stop)) data = handler_zarr[sl_in] @@ -750,6 +925,7 @@ def copy_data_to_zarr( data -= add_offset if scale_factor is not None: data /= scale_factor + sl_out = slice(i - i_start, i - i_start + buffer_size) handler_eddies[sl_out] = data @@ -769,7 +945,7 @@ def load_from_netcdf( :param bool raw_data: If true load data without apply scale_factor and add_offset :param None,list(str) remove_vars: List of variable name which will be not loaded :param None,list(str) include_vars: If defined only this variable will be loaded - :param None,dict indexs: Indexs to laad only a slice of data + :param None,dict indexs: Indexes to load only a slice of data :param class_kwargs: argument to set up observations class :return: Obsevations selected :return type: class @@ -777,12 +953,14 @@ def load_from_netcdf( array_dim = "NbSample" if isinstance(filename, bytes): filename = filename.astype(str) - if isinstance(filename, ExFileObject): + if isinstance(filename, (ExFileObject, BufferedReader, BytesIO)): filename.seek(0) args, kwargs = ("in-mem-file",), dict(memory=filename.read()) else: args, kwargs = (filename,), dict() with Dataset(*args, **kwargs) as h_nc: + _check_versions(getattr(h_nc, "framework_version", None)) + var_list = cls.build_var_list( list(h_nc.variables.keys()), remove_vars, include_vars ) @@ -897,6 +1075,17 @@ def compare_units(input_unit, output_unit, name): input_unit, output_unit, ) + return factor + else: + return 1 + + @classmethod + def from_array(cls, arrays, **kwargs): + nb = arrays["time"].size + eddies = cls(size=nb, **kwargs) + for k, v in arrays.items(): + eddies.obs[k] = v + return eddies @classmethod def from_zarr(cls, handler): @@ -914,6 +1103,7 @@ def from_zarr(cls, handler): eddies.obs[variable] = handler.variables[variable][:] else: eddies.obs[VAR_DESCR_inv[variable]] = handler.variables[variable][:] + eddies.sign_type = handler.rotation_type return eddies @classmethod @@ -932,13 +1122,14 @@ def from_netcdf(cls, handler): eddies.obs[variable] = handler.variables[variable][:] else: eddies.obs[VAR_DESCR_inv[variable]] = handler.variables[variable][:] + eddies.sign_type = handler.rotation_type return eddies def propagate( self, previous_obs, current_obs, obs_to_extend, dead_track, nb_next, model ): """ - Filled virtual obs (C). + Fill virtual obs (C). :param previous_obs: previous obs from current (A) :param current_obs: previous obs from virtual (B) @@ -996,44 +1187,67 @@ def intern(flag, public_label=False): labels = [VAR_DESCR[label]["nc_name"] for label in labels] return labels - def match(self, other, method="overlap", intern=False, cmin=0, **kwargs): + def match( + self, + other, + i_self=None, + i_other=None, + method="overlap", + intern=False, + cmin=0, + **kwargs, + ): """Return index and score computed on the effective contour. :param EddiesObservations other: Observations to compare + :param array[bool,int],None i_self: + Index or mask to subset observations, it could avoid to build a specific dataset. + :param array[bool,int],None i_other: + Index or mask to subset observations, it could avoid to build a specific dataset. 
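The new from_array constructor shown above builds a dataset directly from numpy arrays keyed by internal field names; a minimal sketch follows. The field values are invented for illustration, and any field not provided keeps its uninitialised default.

    from numpy import array
    from py_eddy_tracker.observations.observation import EddiesObservations

    # "time" is required (it fixes the number of observations); other keys must
    # be valid field names of the observation dtype
    arrays = dict(
        time=array([18262, 18263, 18264]),
        lon=array([10.5, 10.7, 10.9]),
        lat=array([-32.0, -32.1, -32.3]),
        amplitude=array([0.05, 0.07, 0.06]),
    )
    eddies = EddiesObservations.from_array(arrays)
    print(len(eddies), eddies.fields)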
:param str method: - "overlap": the score is computed with contours; - "circle": circles are computed and used for score (TODO) :param bool intern: if True, speed contour is used (default = effective contour) :param float cmin: 0 < cmin < 1, return only couples with score >= cmin :param dict kwargs: look at :py:meth:`vertice_overlap` - :return: return the indexes of the eddies in self coupled with eddies in + :return: return the indices of the eddies in self coupled with eddies in other and their associated score :rtype: (array(int), array(int), array(float)) .. minigallery:: py_eddy_tracker.EddiesObservations.match """ - # if method is "overlap" method will use contour to compute score, - # if method is "circle" method will apply a formula of circle overlap x_name, y_name = self.intern(intern) + if i_self is None: + i_self = slice(None) + if i_other is None: + i_other = slice(None) if method == "overlap": - i, j = bbox_intersection( - self[x_name], self[y_name], other[x_name], other[y_name] - ) - c = vertice_overlap( - self[x_name][i], - self[y_name][i], - other[x_name][j], - other[y_name][j], - **kwargs, - ) + x0, y0 = self[x_name][i_self], self[y_name][i_self] + x1, y1 = other[x_name][i_other], other[y_name][i_other] + i, j = bbox_intersection(x0, y0, x1, y1) + c = vertice_overlap(x0[i], y0[i], x1[j], y1[j], **kwargs) elif method == "close_center": - i, j, c = close_center( - self.latitude, self.longitude, other.latitude, other.longitude, **kwargs - ) - + x0, y0 = self.longitude[i_self], self.latitude[i_self] + x1, y1 = other.longitude[i_other], other.latitude[i_other] + i, j, c = close_center(x0, y0, x1, y1, **kwargs) m = c >= cmin # ajout >= pour garder la cmin dans la sélection return i[m], j[m], c[m] + @staticmethod + def re_reference_index(index, ref): + """ + Shift index with ref + + :param array,int index: local index to re ref + :param slice,array ref: + reference could be a slice in this case we juste add start to index + or could be indices and in this case we need to translate + """ + if isinstance(ref, slice): + return index + ref.start + else: + return ref[index] + @classmethod def cost_function_common_area(cls, xy_in, xy_out, distance, intern=False): """How does it work on x bound ? 
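A sketch of the subsetted matching enabled by the new i_self/i_other arguments, together with re_reference_index to map the local indices returned by match back to the full datasets. atlas_a and atlas_b are assumed to be already loaded EddiesObservations, and the day value is arbitrary.

    from numpy import where

    # compare only the observations of one day without building copies
    i_a = where(atlas_a.time == 20000)[0]
    i_b = where(atlas_b.time == 20000)[0]
    i, j, score = atlas_a.match(atlas_b, i_self=i_a, i_other=i_b, intern=False, cmin=0.1)
    # back to indices in the complete datasets
    i_full = atlas_a.re_reference_index(i, i_a)
    j_full = atlas_a.re_reference_index(j, i_b)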
@@ -1109,7 +1323,7 @@ def shifted_ellipsoid_degrees_mask(self, other, minor=1.5, major=1.5): ) def fixed_ellipsoid_mask( - self, other, minor=50, major=100, only_east=False, shifted_ellips=False + self, other, minor=50, major=100, only_east=False, shifted_ellipse=False ): dist = self.distance(other).T accepted = dist < minor @@ -1128,18 +1342,18 @@ def fixed_ellipsoid_mask( if isinstance(minor, ndarray): minor = minor[index_self] # focal distance - f_degree = ((major ** 2 - minor ** 2) ** 0.5) / ( + f_degree = ((major**2 - minor**2) ** 0.5) / ( 111.2 * cos(radians(self.lat[index_self])) ) lon_self = self.lon[index_self] - if shifted_ellips: - x_center_ellips = lon_self - (major - minor) / 2 + if shifted_ellipse: + x_center_ellipse = lon_self - (major - minor) / 2 else: - x_center_ellips = lon_self + x_center_ellipse = lon_self - lon_left_f = x_center_ellips - f_degree - lon_right_f = x_center_ellips + f_degree + lon_left_f = x_center_ellipse - f_degree + lon_right_f = x_center_ellipse + f_degree dist_left_f = distance( lon_left_f, @@ -1163,7 +1377,7 @@ def fixed_ellipsoid_mask( return accepted.T @staticmethod - def basic_formula_ellips_major_axis( + def basic_formula_ellipse_major_axis( lats, cmin=1.5, cmax=10.0, c0=1.5, lat1=13.5, lat2=5.0, degrees=False ): """Give major axis in km with a given latitude""" @@ -1189,21 +1403,27 @@ def solve_conflict(cost): @staticmethod def solve_simultaneous(cost): - """Write something (TODO)""" + """Deduce link from cost matrix. + + :param array(float) cost: Cost for each available link + :return: return a boolean mask array, True for each valid couple + :rtype: array(bool) + """ mask = ~cost.mask - # Count number of link by self obs and other obs - self_links = mask.sum(axis=1) - other_links = mask.sum(axis=0) + if mask.size == 0: + return mask + # Count number of links by self obs and other obs + self_links, other_links = sum_row_column(mask) max_links = max(self_links.max(), other_links.max()) if max_links > 5: logger.warning("One observation have %d links", max_links) - # If some obs have multiple link, we keep only one link by eddy + # If some obs have multiple links, we keep only one link by eddy eddies_separation = 1 < self_links eddies_merge = 1 < other_links test = eddies_separation.any() or eddies_merge.any() if test: - # We extract matrix which contains concflict + # We extract matrix that contains conflict obs_linking_to_self = mask[eddies_separation].any(axis=0) obs_linking_to_other = mask[:, eddies_merge].any(axis=1) i_self_keep = where(obs_linking_to_other + eddies_separation)[0] @@ -1226,13 +1446,13 @@ def solve_simultaneous(cost): security_increment = 0 while False in cost_reduce.mask: if security_increment > max_iteration: - # Maybe check if the size decrease if not rise an exception + # Maybe check if the size decreases if not rise an exception # x_i, y_i = where(-cost_reduce.mask) raise Exception("To many iteration: %d" % security_increment) security_increment += 1 i_min_value = cost_reduce.argmin() i, j = floor(i_min_value / shape[1]).astype(int), i_min_value % shape[1] - # Set to False all link + # Set to False all links mask[i_self_keep[i]] = False mask[:, i_other_keep[j]] = False cost_reduce.mask[i] = True @@ -1246,19 +1466,19 @@ def solve_simultaneous(cost): @staticmethod def solve_first(cost, multiple_link=False): mask = ~cost.mask - # Count number of link by self obs and other obs + # Count number of links by self obs and other obs self_links = mask.sum(axis=1) other_links = mask.sum(axis=0) max_links = max(self_links.max(), 
other_links.max()) if max_links > 5: logger.warning("One observation have %d links", max_links) - # If some obs have multiple link, we keep only one link by eddy + # If some obs have multiple links, we keep only one link by eddy eddies_separation = 1 < self_links eddies_merge = 1 < other_links test = eddies_separation.any() or eddies_merge.any() if test: - # We extract matrix which contains concflict + # We extract matrix that contains conflict obs_linking_to_self = mask[eddies_separation].any(axis=0) obs_linking_to_other = mask[:, eddies_merge].any(axis=1) i_self_keep = where(obs_linking_to_other + eddies_separation)[0] @@ -1294,7 +1514,7 @@ def solve_first(cost, multiple_link=False): return mask def solve_function(self, cost_matrix): - return where(self.solve_simultaneous(cost_matrix)) + return numba_where(self.solve_simultaneous(cost_matrix)) def post_process_link(self, other, i_self, i_other): if unique(i_other).shape[0] != i_other.shape[0]: @@ -1331,8 +1551,7 @@ def to_zarr(self, handler, **kwargs): handler.attrs["track_array_variables"] = self.track_array_variables handler.attrs["array_variables"] = ",".join(self.array_variables) # Iter on variables to create: - fields = [field[0] for field in self.observations.dtype.descr] - for ori_name in fields: + for ori_name in self.fields: # Patch for a transition name = ori_name # @@ -1377,12 +1596,9 @@ def to_netcdf(self, handler, **kwargs): handler.track_array_variables = self.track_array_variables handler.array_variables = ",".join(self.array_variables) # Iter on variables to create: - fields = [field[0] for field in self.observations.dtype.descr] - fields_ = array( - [VAR_DESCR[field[0]]["nc_name"] for field in self.observations.dtype.descr] - ) + fields_ = array([VAR_DESCR[field]["nc_name"] for field in self.fields]) i = fields_.argsort() - for ori_name in array(fields)[i]: + for ori_name in array(self.fields)[i]: # Patch for a transition name = ori_name # @@ -1449,6 +1665,33 @@ def create_variable( except ValueError: logger.warning("Data is empty") + @staticmethod + def get_filters_zarr(name): + """Get filters to store in zarr for known variable + + :param str name: private variable name + :return list: filters list + """ + content = VAR_DESCR.get(name) + filters = list() + store_dtype = content["output_type"] + scale_factor, add_offset = content.get("scale_factor", None), content.get( + "add_offset", None + ) + if scale_factor is not None or add_offset is not None: + if add_offset is None: + add_offset = 0 + filters.append( + zarr.FixedScaleOffset( + offset=add_offset, + scale=1 / scale_factor, + dtype=content["nc_type"], + astype=store_dtype, + ) + ) + filters.extend(content.get("filters", [])) + return filters + def create_variable_zarr( self, handler_zarr, @@ -1533,12 +1776,13 @@ def write_file( filename = filename.replace(".nc", ".zarr") if filename.endswith(".zarr"): zarr_flag = True - logger.info("Store in %s", filename) + logger.info("Store in %s (%d observations)", filename, len(self)) if zarr_flag: handler = zarr.open(filename, "w") self.to_zarr(handler, **kwargs) else: - with Dataset(filename, "w", format="NETCDF4") as handler: + nc_format = kwargs.pop("format", "NETCDF4") + with Dataset(filename, "w", format=nc_format) as handler: self.to_netcdf(handler, **kwargs) @property @@ -1560,13 +1804,33 @@ def set_global_attr_netcdf(self, h_nc): for key, item in self.global_attr.items(): h_nc.setncattr(key, item) + def mask_from_polygons(self, polygons): + """ + Return mask for all observations in one of polygons list + + :param 
list((array,array)) polygons: list of x/y array which be used to identify observations + """ + x, y = polygons[0] + m = insidepoly( + self.longitude, self.latitude, x.reshape((1, -1)), y.reshape((1, -1)) + ) + for x, y in polygons[1:]: + m_ = ~m + m[m_] = insidepoly( + self.longitude[m_], + self.latitude[m_], + x.reshape((1, -1)), + y.reshape((1, -1)), + ) + return m + def extract_with_area(self, area, **kwargs): """ Extract geographically with a bounding box. :param dict area: 4 coordinates in a dictionary to specify bounding box (lower left corner and upper right corner) :param dict kwargs: look at :py:meth:`extract_with_mask` - :return: Return all eddy tracks which are in bounds + :return: Return all eddy trajetories in bounds :rtype: EddiesObservations .. code-block:: python @@ -1575,12 +1839,29 @@ def extract_with_area(self, area, **kwargs): .. minigallery:: py_eddy_tracker.EddiesObservations.extract_with_area """ - mask = (self.latitude > area["llcrnrlat"]) * (self.latitude < area["urcrnrlat"]) + lat0 = area.get("llcrnrlat", -90) + lat1 = area.get("urcrnrlat", 90) + mask = (self.latitude > lat0) * (self.latitude < lat1) lon0 = area["llcrnrlon"] lon = (self.longitude - lon0) % 360 + lon0 mask *= (lon > lon0) * (lon < area["urcrnrlon"]) return self.extract_with_mask(mask, **kwargs) + @property + def time_datetime64(self): + dt = (datetime64("1970-01-01") - datetime64("1950-01-01")).astype("i8") + return (self.time - dt).astype("datetime64[D]") + + def time_sub_sample(self, t0, time_step): + """ + Time sub sampling + + :param int,float t0: reference time that will be keep + :param int,float time_step: keep every observation spaced by time_step + """ + mask = (self.time - t0) % time_step == 0 + return self.extract_with_mask(mask) + def extract_with_mask(self, mask): """ Extract a subset of observations. @@ -1589,11 +1870,6 @@ def extract_with_mask(self, mask): :return: same object with selected observations :rtype: self """ - # ça n'existe plus ça?? 
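A sketch combining the selection helpers above: a geographic bounding box with extract_with_area, then one observation out of every 7 days with time_sub_sample. atlas is assumed to be an already loaded dataset and the box values are illustrative.

    area = dict(llcrnrlon=0, llcrnrlat=-45, urcrnrlon=40, urcrnrlat=-20)
    subset = atlas.extract_with_area(area)
    # keep only observations every 7 days, counted from the first date
    t0 = subset.time.min()
    weekly = subset.time_sub_sample(t0, 7)
    print(len(atlas), len(subset), len(weekly))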
- # full_path=False, - # remove_incomplete=False, - # compress_id=False, - # reject_virtual=False nb_obs = mask.sum() new = self.__class__.new_like(self, nb_obs) @@ -1601,10 +1877,9 @@ def extract_with_mask(self, mask): if nb_obs == 0: logger.warning("Empty dataset will be created") else: - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var][mask] + new.obs[field] = self.obs[field][mask] return new def scatter(self, ax, name=None, ref=None, factor=1, **kwargs): @@ -1614,7 +1889,7 @@ def scatter(self, ax, name=None, ref=None, factor=1, **kwargs): :param matplotlib.axes.Axes ax: matplotlib axe used to draw :param str,array,None name: variable used to fill the contour, if None all elements have the same color - :param float,None ref: if define use like west bound + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param float factor: multiply value by :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.scatter` :return: scatter mappable @@ -1626,7 +1901,7 @@ def scatter(self, ax, name=None, ref=None, factor=1, **kwargs): x = (x - ref) % 360 + ref kwargs = kwargs.copy() if name is not None and "c" not in kwargs: - v = self[name] if isinstance(name, str) else name + v = self.parse_varname(name) kwargs["c"] = v * factor return ax.scatter(x, self.latitude, **kwargs) @@ -1646,7 +1921,7 @@ def filled( """ :param matplotlib.axes.Axes ax: matplotlib axe used to draw :param str,array,None varname: variable used to fill the contours, or an array of same size than obs - :param float,None ref: if define use like west bound? + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param bool intern: if True draw speed contours instead of effective contours :param str cmap: matplotlib colormap name :param int,None lut: Number of colors in the colormap @@ -1671,7 +1946,7 @@ def filled( if "facecolors" not in kwargs: kwargs = kwargs.copy() cmap = get_cmap(cmap, lut) - v = (self[varname] if isinstance(varname, str) else varname) * factor + v = self.parse_varname(varname) * factor if vmin is None: vmin = v.min() if vmax is None: @@ -1679,6 +1954,8 @@ def filled( v = (v - vmin) / (vmax - vmin) colors = [cmap(v_) for v_ in v] kwargs["facecolors"] = colors + if "label" in kwargs: + kwargs["label"] = self.format_label(kwargs["label"]) c = PolyCollection(verts, **kwargs) ax.add_collection(c) c.cmap = cmap @@ -1783,7 +2060,7 @@ def bins_stat(self, xname, bins=None, yname=None, method=None, mask=None): .. 
minigallery:: py_eddy_tracker.EddiesObservations.bins_stat """ - v = self[xname] if isinstance(xname, str) else xname + v = self.parse_varname(xname) mask = self.merge_filters(mask) v = v[mask] if bins is None: @@ -1791,7 +2068,7 @@ def bins_stat(self, xname, bins=None, yname=None, method=None, mask=None): y, x = hist_numba(v, bins=bins) x = (x[1:] + x[:-1]) / 2 if method == "mean": - y_v = self[yname] if isinstance(yname, str) else yname + y_v = self.parse_varname(yname) y_v = y_v[mask] y_, _ = histogram(v, bins=bins, weights=y_v) with errstate(divide="ignore", invalid="ignore"): @@ -1806,11 +2083,43 @@ def format_label(self, label): nb_obs=len(self), ) + def display_color(self, ax, field, ref=None, intern=False, **kwargs): + """Plot colored contour of eddies + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param str,array field: color field + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary + :param bool intern: if True, draw the speed contour + :param dict kwargs: look at :py:meth:`matplotlib.collections.LineCollection` + + .. minigallery:: py_eddy_tracker.EddiesObservations.display_color + """ + xname, yname = self.intern(intern) + x, y = self[xname], self[yname] + + if ref is not None: + # TODO : maybe buggy with global display + shape_out = x.shape + x, y = wrap_longitude(x.reshape(-1), y.reshape(-1), ref) + x, y = x.reshape(shape_out), y.reshape(shape_out) + + c = self.parse_varname(field) + cmap = get_cmap(kwargs.pop("cmap", "Spectral_r")) + cmin, cmax = kwargs.pop("vmin", c.min()), kwargs.pop("vmax", c.max()) + colors = cmap((c - cmin) / (cmax - cmin)) + lines = LineCollection( + [create_vertice(i, j) for i, j in zip(x, y)], colors=colors, **kwargs + ) + ax.add_collection(lines) + lines.cmap = cmap + lines.norm = Normalize(vmin=cmin, vmax=cmax) + return lines + def display(self, ax, ref=None, extern_only=False, intern_only=False, **kwargs): """Plot the speed and effective (dashed) contour of the eddies :param matplotlib.axes.Axes ax: matplotlib axe used to draw - :param float,None ref: western longitude reference used + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param bool extern_only: if True, draw only the effective contour :param bool intern_only: if True, draw only the speed contour :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.plot` @@ -1876,6 +2185,24 @@ def is_convex(self, intern=False): xname, yname = self.intern(intern) return convexs(self[xname], self[yname]) + def contains(self, x, y, intern=False): + """ + Return index of contour containing (x,y) + + :param array x: longitude + :param array y: latitude + :param bool intern: If true use speed contour instead of effective contour + :return: indexs, -1 if no index + :rtype: array[int32] + """ + xname, yname = self.intern(intern) + m = ~(isnan(x) + isnan(y)) + i = -ones(x.shape, dtype="i4") + + if x.size != 0 and m.any(): + i[m] = poly_indexs(x[m], y[m], self[xname], self[yname]) + return i + def inside(self, x, y, intern=False): """ True for each point inside the effective contour of an eddy @@ -1929,7 +2256,7 @@ def grid_count(self, bins, intern=False, center=False, filter=slice(None)): x_ref = ((self.longitude[filter] - x0) % 360 + x0 - 180).reshape(-1, 1) x_contour, y_contour = self[x_name][filter], self[y_name][filter] grid_count_pixel_in( - grid, + grid.data, x_contour, y_contour, x_ref, @@ -1948,10 +2275,10 @@ def grid_count(self, bins, intern=False, center=False, filter=slice(None)): def 
grid_box_stat(self, bins, varname, method=50, data=None, filter=slice(None)): """ - Compute mean of eddies in each bin + Get percentile of eddies in each bin :param (numpy.array,numpy.array) bins: bins (grid) to count - :param str varname: variable to apply the method + :param str varname: variable to apply the method if data is None and will be output name :param str,float method: method to apply. If float, use ? :param array data: Array used to compute stat if defined :param array,mask,slice filter: keep the data selected with the filter @@ -1999,7 +2326,7 @@ def grid_stat(self, bins, varname, data=None): Return the mean of the eddies' variable in each bin :param (numpy.array,numpy.array) bins: bins (grid) to compute the mean on - :param str varname: name of variable to compute the mean on + :param str varname: name of variable to compute the mean on and output grid_name :param array data: Array used to compute stat if defined :return: return the gridde mean variable :rtype: py_eddy_tracker.dataset.grid.RegularGridDataset @@ -2032,7 +2359,7 @@ def grid_stat(self, bins, varname, data=None): return regular_grid def interp_grid( - self, grid_object, varname, method="center", dtype=None, intern=None + self, grid_object, varname, i=None, method="center", dtype=None, intern=None ): """ Interpolate a grid on a center or contour with mean, min or max method @@ -2040,24 +2367,34 @@ def interp_grid( :param grid_object: Handler of grid to interp :type grid_object: py_eddy_tracker.dataset.grid.RegularGridDataset :param str varname: Name of variable to use + :param array[bool,int],None i: + Index or mask to subset observations, it could avoid to build a specific dataset. :param str method: 'center', 'mean', 'max', 'min', 'nearest' :param str dtype: if None we use var dtype :param bool intern: Use extern or intern contour + + .. minigallery:: py_eddy_tracker.EddiesObservations.interp_grid """ if method in ("center", "nearest"): - return grid_object.interp(varname, self.longitude, self.latitude, method) + x, y = self.longitude, self.latitude + if i is not None: + x, y = x[i], y[i] + return grid_object.interp(varname, x, y, method) elif method in ("min", "max", "mean", "count"): x0 = grid_object.x_bounds[0] x_name, y_name = self.intern(False if intern is None else intern) x_ref = ((self.longitude - x0) % 360 + x0 - 180).reshape(-1, 1) x, y = (self[x_name] - x_ref) % 360 + x_ref, self[y_name] + if i is not None: + x, y = x[i], y[i] grid = grid_object.grid(varname) - result = empty(self.shape, dtype=grid.dtype if dtype is None else dtype) + result = empty(x.shape[0], dtype=grid.dtype if dtype is None else dtype) min_method = method == "min" grid_stat( grid_object.x_c, grid_object.y_c, - -grid if min_method else grid, + -grid.data if min_method else grid.data, + grid.mask, x, y, result, @@ -2071,29 +2408,48 @@ def interp_grid( @property def period(self): """ - Give the time coverage + Give the time coverage. If collection is empty, return nan,nan :return: first and last date :rtype: (int,int) """ if self.period_ is None: - self.period_ = self.time.min(), self.time.max() + if self.time.size < 1: + self.period_ = nan, nan + else: + self.period_ = self.time.min(), self.time.max() return self.period_ @property def nb_days(self): - """Return period days cover by dataset + """Return period in days covered by the dataset :return: Number of days :rtype: int """ return self.period[1] - self.period[0] + 1 + def create_particles(self, step, intern=True): + """Create particles inside contour (Default : speed contour). 
Avoid creating too large numpy arrays, only to be masked + + :param step: step for particles + :type step: float + :param bool intern: If true use speed contour instead of effective contour + :return: lon, lat and indices of particles + :rtype: tuple(np.array) + """ + + xname, yname = self.intern(intern) + return create_meshed_particles(self[xname], self[yname], step) + + def empty_dataset(self): + return self.new_like(self, 0) + @njit(cache=True) def grid_count_(grid, i, j): """ - Add one to each index + Add 1 to each index """ for i_, j_ in zip(i, j): grid[i_, j_] += 1 @@ -2116,7 +2472,7 @@ def grid_count_pixel_in( y_c, ): """ - Count how many time a pixel is used. + Count how many times a pixel is used. :param array grid: :param array x: x for all contour @@ -2136,6 +2492,7 @@ def grid_count_pixel_in( for i_ in range(nb): x_, y_, x_ref_ = x[i_], y[i_], x_ref[i_] x_ = (x_ - x_ref_) % 360 + x_ref_ + x_, y_ = reduce_size(x_, y_) v = create_vertice(x_, y_) (x_start, x_stop), (y_start, y_stop) = bbox_indice_regular( v, @@ -2147,38 +2504,10 @@ def grid_count_pixel_in( is_circular, x_size, ) - i, j = get_pixel_in_regular(v, x_c, y_c, x_start, x_stop, y_start, y_stop) grid_count_(grid, i, j) -@njit(cache=True) -def insidepoly(x_p, y_p, x_c, y_c): - """ - True for each postion inside a contour - - :param array x_p: longitude to test - :param array y_p: latitude to test - :param array x_c: longitude of contours - :param array y_c: latitude of contours - """ - nb_p = x_p.shape[0] - nb_c = x_c.shape[0] - flag = zeros(nb_p, dtype=numba_types.bool_) - for i in range(nb_c): - x_c_min, y_c_min = x_c[i].min(), y_c[i].min() - x_c_max, y_c_max = x_c[i].max(), y_c[i].max() - v = create_vertice(x_c[i], y_c[i]) - for j in range(nb_p): - if flag[j]: - continue - x, y = x_p[j], y_p[j] - if x > x_c_min and x < x_c_max and y > y_c_min and y < y_c_max: - if winding_number_poly(x, y, v) != 0: - flag[j] = True - return flag - - @njit(cache=True) def grid_box_stat(x_c, y_c, grid, mask, x, y, value, circular=False, method=50): """ @@ -2223,7 +2552,6 @@ def grid_box_stat(x_c, y_c, grid, mask, x, y, value, circular=False, method=50): if i_ != i0 or j_ != j0: # apply method and store result grid[i_, j_] = percentile(values, method) - # grid[i_, j_] = len(values) mask[i_, j_] = False # start new group i0, j0 = i_, j_ @@ -2233,25 +2561,28 @@ def grid_box_stat(x_c, y_c, grid, mask, x, y, value, circular=False, method=50): @njit(cache=True) -def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method="mean"): +def grid_stat(x_c, y_c, grid, mask, x, y, result, circular=False, method="mean"): """ Compute the mean or the max of the grid for each contour :param array_like x_c: the grid longitude coordinates :param array_like y_c: the grid latitude coordinates :param array_like grid: grid value + :param array[bool] mask: mask for invalid value :param array_like x: longitude of contours :param array_like y: latitude of contours :param array_like result: return values :param bool circular: True if grid is wrappable :param str method: 'mean', 'max' """ + # FIXME : how does it work on grid bound nb = result.shape[0] xstep, ystep = x_c[1] - x_c[0], y_c[1] - y_c[0] x0, y0 = x_c - xstep / 2.0, y_c - ystep / 2.0 nb_x = x_c.shape[0] max_method = "max" == method mean_method = "mean" == method + count_method = "count" == method for elt in range(nb): v = create_vertice(x[elt], y[elt]) (x_start, x_stop), (y_start, y_stop) = bbox_indice_regular( @@ -2259,15 +2590,27 @@ def grid_stat(x_c, y_c, grid, x, y, result, circular=False, 
method="mean"): ) i, j = get_pixel_in_regular(v, x_c, y_c, x_start, x_stop, y_start, y_stop) - if mean_method: + if count_method: + result[elt] = i.shape[0] + elif mean_method: v_sum = 0 + nb_ = 0 for i_, j_ in zip(i, j): + if mask[i_, j_]: + continue v_sum += grid[i_, j_] - result[elt] = v_sum / i.shape[0] + nb_ += 1 + # FIXME : how does it work on grid bound, + if nb_ == 0: + result[elt] = nan + else: + result[elt] = v_sum / nb_ elif max_method: v_max = -1e40 for i_, j_ in zip(i, j): - v_max = max(v_max, grid[i_, j_]) + values = grid[i_, j_] + # FIXME must use mask + v_max = max(v_max, values) result[elt] = v_max @@ -2281,3 +2624,90 @@ def elements(self): elements = super().elements elements.extend(["track", "segment_size", "dlon", "dlat"]) return list(set(elements)) + + +@njit(cache=True) +def numba_where(mask): + """Usefull when mask is close to be empty""" + return where(mask) + + +@njit(cache=True) +def sum_row_column(mask): + """ + Compute sum on row and column at same time + """ + nb_x, nb_y = mask.shape + row_sum = zeros(nb_x, dtype=numba_types.int32) + column_sum = zeros(nb_y, dtype=numba_types.int32) + for i in range(nb_x): + for j in range(nb_y): + if mask[i, j]: + row_sum[i] += 1 + column_sum[j] += 1 + return row_sum, column_sum + + +@njit(cache=True) +def numba_digitize(values, bins): + # Check if bins are regular + nb_bins = bins.shape[0] + step = bins[1] - bins[0] + bin_previous = bins[1] + for i in range(2, nb_bins): + bin_current = bins[i] + if step != (bin_current - bin_previous): + # If bins are not regular + return digitize(values, bins) + bin_previous = bin_current + nb_values = values.shape[0] + out = empty(nb_values, dtype=numba_types.int64) + up, down = bins[0], bins[-1] + for i in range(nb_values): + v_ = values[i] + if v_ >= down: + out[i] = nb_bins + continue + if v_ < up: + out[i] = 0 + continue + out[i] = (v_ - bins[0]) / step + 1 + return out + + +@njit(cache=True) +def iter_mode_reduce(x, bins): + """ + Test if we could use a reduce mode + + :param array x: array to divide in group + :param array bins: array which defined bounds between each group + :return: If reduce mode, translator, and reduce x + """ + nb = x.shape[0] + # If we use less than half value + limit = nb // 2 + # low and up + x0, x1 = bins[0], bins[-1] + m = empty(nb, dtype=numba_types.bool_) + # To count number of value cover by bins + c = 0 + for i in range(nb): + x_ = x[i] + test = (x_ >= x0) * (x_ <= x1) + m[i] = test + if test: + c += 1 + # If number value exceed limit + if c > limit: + return False, empty(0, dtype=numba_types.int_), x + # Indices to be able to translate in full index array + indices = empty(c, dtype=numba_types.int_) + x_ = empty(c, dtype=x.dtype) + j = 0 + for i in range(nb): + if m[i]: + indices[j] = i + x_[j] = x[i] + j += 1 + return True, indices, x_ diff --git a/src/py_eddy_tracker/observations/tracking.py b/src/py_eddy_tracker/observations/tracking.py index c67d130a..fa1c1f93 100644 --- a/src/py_eddy_tracker/observations/tracking.py +++ b/src/py_eddy_tracker/observations/tracking.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ -Class to manage observations gathered in track +Class to manage observations gathered in trajectories """ -import logging from datetime import datetime, timedelta +import logging from numba import njit from numpy import ( @@ -16,27 +16,25 @@ degrees, empty, histogram, - interp, + int_, median, nan, ones, radians, sin, unique, - where, zeros, ) -from Polygon import Polygon from .. 
import VAR_DESCR_inv, __version__ from ..generic import build_index, cumsum_by_track, distance, split_line, wrap_longitude -from ..poly import create_vertice_from_2darray, merge, polygon_overlap -from .observation import EddiesObservations +from ..poly import bbox_intersection, merge, vertice_overlap +from .groups import GroupEddiesObservations, get_missing_indices logger = logging.getLogger("pet") -class TrackEddiesObservations(EddiesObservations): +class TrackEddiesObservations(GroupEddiesObservations): """Class to practice Tracking on observations""" __slots__ = ("__obs_by_track", "__first_index_of_track", "__nb_track") @@ -70,6 +68,10 @@ def __init__(self, *args, **kwargs): self.__obs_by_track = None self.__nb_track = None + def track_slice(self, track): + i0 = self.index_from_track[track] + return slice(i0, i0 + self.nb_obs_by_track[track]) + def iter_track(self): """ Yield track @@ -79,10 +81,30 @@ def iter_track(self): continue yield self.index(slice(i0, i0 + nb)) + def get_missing_indices(self, dt): + """Find indices where observations are missing. + + :param int,float dt: theorical delta time between 2 observations + """ + return get_missing_indices( + self.time, + self.track, + dt=dt, + flag_untrack=False, + indice_untrack=self.NOGROUP, + ) + + def fix_next_previous_obs(self): + """Function used after 'insert_virtual', to correct next_obs and + previous obs. + """ + + pass + @property def nb_tracks(self): """ - Will count and send number of track + Count and return number of track """ if self.__nb_track is None: if len(self) == 0: @@ -96,7 +118,7 @@ def __repr__(self): t0, t1 = self.period period = t1 - t0 + 1 nb = self.nb_obs_by_track - nb_obs = self.observations.shape[0] + nb_obs = len(self) m = self.virtual.astype("bool") nb_m = m.sum() bins_t = (1, 30, 90, 180, 270, 365, 1000, 10000) @@ -124,8 +146,8 @@ def __repr__(self): return content def add_distance(self): - """Add a field of distance (m) between to consecutive observation, 0 for the last observation of each track""" - if "distance_next" in self.observations.dtype.descr: + """Add a field of distance (m) between two consecutive observations, 0 for the last observation of each track""" + if "distance_next" in self.fields: return self new = self.add_fields(("distance_next",)) new["distance_next"][:1] = self.distance_to_next() @@ -133,7 +155,7 @@ def add_distance(self): def distance_to_next(self): """ - :return: array of distance in m, 0 when next obs if from another track + :return: array of distance in m, 0 when next obs is from another track :rtype: array """ d = distance( @@ -148,50 +170,44 @@ def distance_to_next(self): d_[-1] = 0 return d_ - def filled_by_interpolation(self, mask): - """Filled selected values by interpolation + def normalize_longitude(self): + """Normalize all longitudes - :param array(bool) mask: True if must be filled by interpolation - - .. 
minigallery:: py_eddy_tracker.TrackEddiesObservations.filled_by_interpolation + Normalize longitude field and in the same range : + - longitude_max + - contour_lon_e (how to do if in raw) + - contour_lon_s (how to do if in raw) """ - nb_filled = mask.sum() - logger.info("%d obs will be filled (unobserved)", nb_filled) - - nb_obs = len(self) - index = arange(nb_obs) - - for field in self.obs.dtype.descr: - var = field[0] - if ( - var in ["n", "virtual", "track", "cost_association"] - or var in self.array_variables - ): - continue - # to normalize longitude before interpolation - if var == "lon": - lon = self.lon - first = where(self.n == 0)[0] - nb_obs = empty(first.shape, dtype="u4") - nb_obs[:-1] = first[1:] - first[:-1] - nb_obs[-1] = lon.shape[0] - first[-1] - lon0 = (lon[first] - 180).repeat(nb_obs) - self.lon[:] = (lon - lon0) % 360 + lon0 - self.obs[var][mask] = interp( - index[mask], index[~mask], self.obs[var][~mask] - ) + if self.lon.size == 0: + return + lon0 = (self.lon[self.index_from_track] - 180).repeat(self.nb_obs_by_track) + logger.debug("Normalize longitude") + self.lon[:] = (self.lon - lon0) % 360 + lon0 + if "lon_max" in self.fields: + logger.debug("Normalize longitude_max") + self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180 + if not self.raw_data: + if "contour_lon_e" in self.fields: + logger.debug("Normalize effective contour longitude") + self.contour_lon_e[:] = ( + (self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180 + ).T + if "contour_lon_s" in self.fields: + logger.debug("Normalize speed contour longitude") + self.contour_lon_s[:] = ( + (self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180 + ).T def extract_longer_eddies(self, nb_min, nb_obs, compress_id=True): - """Select eddies which are longer than nb_min""" + """Select the trajectories longer than nb_min""" mask = nb_obs >= nb_min nb_obs_select = mask.sum() logger.info("Selection of %d observations", nb_obs_select) eddies = self.__class__.new_like(self, nb_obs_select) eddies.sign_type = self.sign_type - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - eddies.obs[var] = self.obs[var][mask] + eddies.obs[field] = self.obs[field][mask] if compress_id: list_id = unique(eddies.obs.track) list_id.sort() @@ -207,7 +223,7 @@ def elements(self): return list(set(elements)) def set_global_attr_netcdf(self, h_nc): - """Set global attr""" + """Set global attributes""" h_nc.title = "Cyclonic" if self.sign_type == -1 else "Anticyclonic" h_nc.Metadata_Conventions = "Unidata Dataset Discovery v1.0" h_nc.comment = "Surface product; mesoscale eddies" @@ -218,20 +234,21 @@ def set_global_attr_netcdf(self, h_nc): ) h_nc.date_created = datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ") t = h_nc.variables[VAR_DESCR_inv["j1"]] - delta = t.max - t.min + 1 - h_nc.time_coverage_duration = "P%dD" % delta - d_start = datetime(1950, 1, 1) + timedelta(int(t.min)) - d_end = datetime(1950, 1, 1) + timedelta(int(t.max)) - h_nc.time_coverage_start = d_start.strftime("%Y-%m-%dT00:00:00Z") - h_nc.time_coverage_end = d_end.strftime("%Y-%m-%dT00:00:00Z") + if t.size: + delta = t.max - t.min + 1 + h_nc.time_coverage_duration = "P%dD" % delta + d_start = datetime(1950, 1, 1) + timedelta(int(t.min)) + d_end = datetime(1950, 1, 1) + timedelta(int(t.max)) + h_nc.time_coverage_start = d_start.strftime("%Y-%m-%dT00:00:00Z") + h_nc.time_coverage_end = d_end.strftime("%Y-%m-%dT00:00:00Z") def extract_with_period(self, period, **kwargs): """ - 
Extract with a period + Extract within a time period - :param (int,int) period: two date to define period, must be specify from 1/1/1950 + :param (int,int) period: two dates to define the period, must be specified from 1/1/1950 :param dict kwargs: look at :py:meth:`extract_with_mask` - :return: Return all eddy tracks which are in bounds + :return: Return all eddy tracks in period :rtype: TrackEddiesObservations .. minigallery:: py_eddy_tracker.TrackEddiesObservations.extract_with_period @@ -252,11 +269,11 @@ def extract_with_period(self, period, **kwargs): def get_azimuth(self, equatorward=False): """ - Return azimuth for each tracks. + Return azimuth for each track. - Azimuth is compute with first and last observation + Azimuth is computed with first and last observations - :param bool equatorward: If True, Poleward are positive and equatorward negative + :param bool equatorward: If True, Poleward is positive and Equatorward negative :rtype: array """ i0, nb = self.index_from_track, self.nb_obs_by_track @@ -286,7 +303,7 @@ def compute_index(self): """ if self.__first_index_of_track is None: s = self.tracks.max() + 1 - # Doesn't work => core dump with numba, maybe he wait i8 instead of u4 + # Doesn't work => core dump with numba, maybe he wants i8 instead of u4 # self.__first_index_of_track = -ones(s, self.tracks.dtype) # self.__obs_by_track = zeros(s, self.observation_number.dtype) self.__first_index_of_track = -ones(s, "i8") @@ -334,12 +351,12 @@ def nb_obs_by_track(self): @property def lifetime(self): - """Return for each observation lifetime""" + """Return lifetime for each observation""" return self.nb_obs_by_track.repeat(self.nb_obs_by_track) @property def age(self): - """Return for each observation age in %, will be [0:100]""" + """Return age in % for each observation, will be [0:100]""" return self.n.astype("f4") / (self.lifetime - 1) * 100.0 def extract_ids(self, tracks): @@ -348,10 +365,10 @@ def extract_ids(self, tracks): def extract_toward_direction(self, west=True, delta_lon=None): """ - Get eddy which go in same direction + Get trajectories going in the same direction - :param bool west: Only eastward eddy if True return westward - :param None,float delta_lon: Only eddy with more than delta_lon span in longitude + :param bool west: Only eastward eddies if True return westward + :param None,float delta_lon: Only eddies with more than delta_lon span in longitude :return: Only eastern eddy :rtype: __class__ @@ -363,19 +380,17 @@ def extract_toward_direction(self, west=True, delta_lon=None): d_lon = lon[i1] - lon[i0] m = d_lon < 0 if west else d_lon > 0 if delta_lon is not None: - m *= delta_lon < d_lon + m *= delta_lon < abs(d_lon) m = m.repeat(nb) return self.extract_with_mask(m) def extract_first_obs_in_box(self, res): - data = empty( - self.obs.shape, dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")] - ) + data = empty(len(self), dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")]) data["lon"] = self.longitude - self.longitude % res data["lat"] = self.latitude - self.latitude % res data["track"] = self.track _, indexs = unique(data, return_index=True) - mask = zeros(self.obs.shape, dtype="bool") + mask = zeros(len(self), dtype="bool") mask[indexs] = True return self.extract_with_mask(mask) @@ -398,10 +413,10 @@ def extract_in_direction(self, direction, value=0): def extract_with_length(self, bounds): """ - Return all observations in [b0:b1] + Return the observations within trajectories lasting between [b0:b1] - :param (int,int) bounds: length min and max of selected 
eddies, if use of -1 this bound is not used - :return: Return all eddy tracks which have length between bounds + :param (int,int) bounds: length min and max of the desired trajectories, if -1 this bound is not used + :return: Return all trajectories having length between bounds :rtype: TrackEddiesObservations .. minigallery:: py_eddy_tracker.TrackEddiesObservations.extract_with_length @@ -417,12 +432,9 @@ def extract_with_length(self, bounds): track_mask = self.nb_obs_by_track >= b0 else: logger.warning("No valid value for bounds") - raise Exception("One bounds must be positiv") + raise Exception("One bound must be positive") return self.extract_with_mask(track_mask.repeat(self.nb_obs_by_track)) - def empty_dataset(self): - return self.new_like(self, 0) - def loess_filter(self, half_window, xfield, yfield, inplace=True): track = self.track x = self.obs[xfield] @@ -431,15 +443,16 @@ def loess_filter(self, half_window, xfield, yfield, inplace=True): if inplace: self.obs[yfield] = result return self + return result def median_filter(self, half_window, xfield, yfield, inplace=True): - track = self.track - x = self.obs[xfield] - y = self.obs[yfield] - result = track_median_filter(half_window, x, y, track) + result = track_median_filter( + half_window, self[xfield], self[yfield], self.track + ) if inplace: - self.obs[yfield] = result + self[yfield][:] = result return self + return result def position_filter(self, median_half_window, loess_half_window): self.median_filter(median_half_window, "time", "lon").loess_filter( @@ -461,11 +474,11 @@ def extract_with_mask( Extract a subset of observations :param array(bool) mask: mask to select observations - :param bool full_path: extract full path if only one part is selected - :param bool remove_incomplete: delete path which are not fully selected - :param bool compress_id: resample track number to use a little range - :param bool reject_virtual: if track are only virtual in selection we remove track - :return: same object with selected observations + :param bool full_path: extract the full trajectory if only one part is selected + :param bool remove_incomplete: delete trajectory if not fully selected + :param bool compress_id: resample trajectory number to use a smaller range + :param bool reject_virtual: if only virtuals are selected, the trajectory is removed + :return: same object with the selected observations :rtype: self.__class__ """ if full_path and remove_incomplete: @@ -487,12 +500,11 @@ def extract_with_mask( new = self.__class__.new_like(self, nb_obs) new.sign_type = self.sign_type if nb_obs == 0: - logger.warning("Empty dataset will be created") + logger.info("Empty dataset will be created") else: - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var][mask] + new.obs[field] = self.obs[field][mask] if compress_id: list_id = unique(new.track) list_id.sort() @@ -501,16 +513,11 @@ def extract_with_mask( new.track = id_translate[new.track] return new - @staticmethod - def re_reference_index(index, ref): - if isinstance(ref, slice): - return index + ref.start - else: - return ref[index] - def shape_polygon(self, intern=False): """ - Get polygons which enclosed each track + Get the polygon enclosing each trajectory. 
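A sketch of the trajectory selection and smoothing helpers above: keep only long trajectories, then apply the running median plus LOESS chain that position_filter wraps. tracks is assumed to be a TrackEddiesObservations loaded elsewhere; the window values are illustrative.

    # trajectories with at least 28 observations (-1 disables the upper bound)
    long_lived = tracks.extract_with_length((28, -1))
    smoothed = long_lived.copy()
    smoothed.position_filter(median_half_window=1, loess_half_window=5)
    # the same chain, done by hand for longitude only
    raw = long_lived.copy()
    raw.median_filter(1, "time", "lon").loess_filter(5, "time", "lon")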
+ + The polygon merges the non-overlapping bounds of the specified contours :param bool intern: If True use speed contour instead of effective contour :rtype: list(array, array) @@ -520,10 +527,10 @@ def shape_polygon(self, intern=False): def display_shape(self, ax, ref=None, intern=False, **kwargs): """ - This function will draw shape of each track + This function draws the shape of each trajectory - :param matplotlib.axes.Axes ax: ax where drawed - :param float,int ref: if defined all coordinates will be wrapped with ref like west boundary + :param matplotlib.axes.Axes ax: ax to draw + :param float,int ref: if defined, all coordinates are wrapped with ref as western boundary :param bool intern: If True use speed contour instead of effective contour :param dict kwargs: keyword arguments for Axes.plot :return: matplotlib mappable @@ -547,20 +554,23 @@ def display_shape(self, ax, ref=None, intern=False, **kwargs): def close_tracks(self, other, nb_obs_min=10, **kwargs): """ - Get close from another atlas. + Get close trajectories from another atlas. :param self other: Atlas to compare - :param int nb_obs_min: Minimal number of overlap for one track + :param int nb_obs_min: Minimal number of overlap for one trajectory :param dict kwargs: keyword arguments for match function - :return: return other atlas reduce to common track with self + :return: return other atlas reduced to common trajectories with self .. warning:: It could be a costly operation for huge dataset """ p0, p1 = self.period + p0_other, p1_other = other.period + if p1_other < p0 or p1 < p0_other: + return other.__class__.new_like(other, 0) indexs = list() - for i_self, i_other, t0, t1 in self.align_on(other, bins=range(p0, p1 + 2)): - i, j, s = self.index(i_self).match(other.index(i_other), **kwargs) + for i_self, i_other, t0, t1 in self.align_on(other, bins=arange(p0, p1 + 2)): + i, j, s = self.match(other, i_self=i_self, i_other=i_other, **kwargs) indexs.append(other.re_reference_index(j, i_other)) indexs = concatenate(indexs) tr, nb = unique(other.track[indexs], return_counts=True) @@ -577,10 +587,10 @@ def format_label(self, label): def plot(self, ax, ref=None, **kwargs): """ - This function will draw path of each track + This function will draw path of each trajectory - :param matplotlib.axes.Axes ax: ax where drawed - :param float,int ref: if defined all coordinates will be wrapped with ref like west boundary + :param matplotlib.axes.Axes ax: ax to draw + :param float,int ref: if defined, all coordinates are wrapped with ref as western boundary :param dict kwargs: keyword arguments for Axes.plot :return: matplotlib mappable """ @@ -595,91 +605,91 @@ def plot(self, ax, ref=None, **kwargs): return ax.plot(x, y, **kwargs) def split_network(self, intern=True, **kwargs): - """Divide each group in track""" + """Return each group (network) divided in segments""" + # Find timestep of dataset + # FIXME : how to know exact time sampling + t = unique(self.time) + dts = t[1:] - t[:-1] + timestep = median(dts) + track_s, track_e, track_ref = build_index(self.tracks) ids = empty( len(self), dtype=[ ("group", self.tracks.dtype), ("time", self.time.dtype), - ("track", "u2"), + ("track", "u4"), ("previous_cost", "f4"), ("next_cost", "f4"), ("previous_obs", "i4"), ("next_obs", "i4"), ], ) - ids["group"], ids["time"] = self.tracks, self.time - # To store id track + ids["group"], ids["time"] = self.tracks, int_(self.time / timestep) + # Initialisation + # To store the id of the segments, the backward and forward cost associations 
ids["track"], ids["previous_cost"], ids["next_cost"] = 0, 0, 0 + # To store the indices of the backward and forward observations associated ids["previous_obs"], ids["next_obs"] = -1, -1 + # At the end, ids["previous_obs"] == -1 means the start of a non-split segment + # and ids["next_obs"] == -1 means the end of a non-merged segment xname, yname = self.intern(intern) + display_iteration = logger.getEffectiveLevel() == logging.INFO for i_s, i_e in zip(track_s, track_e): if i_s == i_e or self.tracks[i_s] == self.NOGROUP: continue + if display_iteration: + print(f"Network obs from {i_s} to {i_e} on {track_e[-1]}", end="\r") sl = slice(i_s, i_e) local_ids = ids[sl] + # built segments with local indices self.set_tracks(self[xname][sl], self[yname][sl], local_ids, **kwargs) - m = local_ids["previous_obs"] == -1 + # shift the local indices to the total indexation for the used observations + m = local_ids["previous_obs"] != -1 local_ids["previous_obs"][m] += i_s - m = local_ids["next_obs"] == -1 + m = local_ids["next_obs"] != -1 local_ids["next_obs"][m] += i_s + if display_iteration: + print() + ids["time"] *= timestep return ids - # ids_sort = ids[new_i] - # # To be able to follow indices sorting - # reverse_sort = empty(new_i.shape[0], dtype="u4") - # reverse_sort[new_i] = arange(new_i.shape[0]) - # # Redirect indices - # m = ids_sort["next_obs"] != -1 - # ids_sort["next_obs"][m] = reverse_sort[ - # ids_sort["next_obs"][m] - # ] - # m = ids_sort["previous_obs"] != -1 - # ids_sort["previous_obs"][m] = reverse_sort[ - # ids_sort["previous_obs"][m] - # ] - # # print(ids_sort) - # display_network( - # x[new_i], - # y[new_i], - # ids_sort["track"], - # ids_sort["time"], - # ids_sort["next_cost"], - # ) - - def set_tracks(self, x, y, ids, window): - """ - Will split one group in tracks + + def set_tracks(self, x, y, ids, window, **kwargs): + """ + Split one group (network) in segments :param array x: coordinates of group :param array y: coordinates of group :param ndarray ids: several fields like time, group, ... 
- :param int windows: number of days where observations could missed + :param int window: number of days where observations could missed """ - - time_index = build_index(ids["time"]) + time_index = build_index((ids["time"]).astype("i4")) nb = x.shape[0] used = zeros(nb, dtype="bool") track_id = 1 - # build all polygon (need to check if wrap is needed) - polygons = [Polygon(create_vertice_from_2darray(x, y, i)) for i in range(nb)] + # build all polygons (need to check if wrap is needed) for i in range(nb): - # If observation already in one track, we go to the next one + # If the observation is already in one track, we go to the next one if used[i]: continue - self.follow_obs(i, track_id, used, ids, polygons, *time_index, window) + # Search a possible continuation (forward) + self.follow_obs(i, track_id, used, ids, x, y, *time_index, window, **kwargs) track_id += 1 + # Search a possible ancestor (backward) + self.get_previous_obs(i, ids, x, y, *time_index, window, **kwargs) @classmethod - def follow_obs(cls, i_next, track_id, used, ids, *args): + def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs): + """Associate the observations to the segments""" + while i_next != -1: # Flag used[i_next] = True # Assign id ids["track"][i_next] = track_id # Search next - i_next_ = cls.next_obs(i_next, ids, *args) + i_next_ = cls.get_next_obs(i_next, ids, *args, **kwargs) if i_next_ == -1: break ids["next_obs"][i_next] = i_next_ @@ -695,9 +705,61 @@ def follow_obs(cls, i_next, track_id, used, ids, *args): i_next = i_next_ @staticmethod - def next_obs(i_current, ids, polygons, time_s, time_e, time_ref, window): + def get_previous_obs( + i_current, + ids, + x, + y, + time_s, + time_e, + time_ref, + window, + min_overlap=0.2, + **kwargs, + ): + """Backward association of observations to the segments""" + time_cur = int_(ids["time"][i_current]) + t0, t1 = time_cur - 1 - time_ref, max(time_cur - window - time_ref, 0) + for t_step in range(t0, t1 - 1, -1): + i0, i1 = time_s[t_step], time_e[t_step] + # No observation at the time step + if i0 == i1: + continue + # Search for overlaps + xi, yi, xj, yj = x[[i_current]], y[[i_current]], x[i0:i1], y[i0:i1] + ii, ij = bbox_intersection(xi, yi, xj, yj) + if len(ii) == 0: + continue + c = zeros(len(xj)) + c[ij] = vertice_overlap( + xi[ii], yi[ii], xj[ij], yj[ij], min_overlap=min_overlap, **kwargs + ) + # We get index of maximal overlap + i = c.argmax() + c_i = c[i] + # No overlap found + if c_i == 0: + continue + ids["previous_cost"][i_current] = c_i + ids["previous_obs"][i_current] = i0 + i + break + + @staticmethod + def get_next_obs( + i_current, + ids, + x, + y, + time_s, + time_e, + time_ref, + window, + min_overlap=0.2, + **kwargs, + ): + """Forward association of observations to the segments""" time_max = time_e.shape[0] - 1 - time_cur = ids["time"][i_current] + time_cur = int_(ids["time"][i_current]) t0, t1 = time_cur + 1 - time_ref, min(time_cur + window - time_ref, time_max) if t0 > time_max: return -1 @@ -706,10 +768,15 @@ def next_obs(i_current, ids, polygons, time_s, time_e, time_ref, window): # No observation at the time step if i0 == i1: continue - # Intersection / union, to be able to separte in case of multiple inside - c = polygon_overlap(polygons[i_current], polygons[i0:i1]) - # We remove low overlap - c[c < 0.1] = 0 + # Search for overlaps + xi, yi, xj, yj = x[[i_current]], y[[i_current]], x[i0:i1], y[i0:i1] + ii, ij = bbox_intersection(xi, yi, xj, yj) + if len(ii) == 0: + continue + c = zeros(len(xj)) + c[ij] = vertice_overlap( + xi[ii], 
yi[ii], xj[ij], yj[ij], min_overlap=min_overlap, **kwargs + ) # We get index of maximal overlap i = c.argmax() c_i = c[i] @@ -754,10 +821,10 @@ def track_loess_filter(half_window, x, y, track): """ Apply a loess filter on y field - :param int,float window: parameter of smoother + :param int,float half_window: parameter of smoother :param array_like x: must be growing for each track but could be irregular :param array_like y: field to smooth - :param array_like track: field which allow to separate path + :param array_like track: field that allows to separate path :return: Array smoothed :rtype: array_like diff --git a/src/py_eddy_tracker/poly.py b/src/py_eddy_tracker/poly.py index 08792241..b5849610 100644 --- a/src/py_eddy_tracker/poly.py +++ b/src/py_eddy_tracker/poly.py @@ -5,11 +5,12 @@ import heapq -from numba import njit, prange -from numba import types as numba_types -from numpy import array, concatenate, empty, nan, ones, pi, where -from numpy.linalg import lstsq from Polygon import Polygon +from numba import njit, prange, types as numba_types +from numpy import arctan, array, concatenate, empty, nan, ones, pi, where, zeros +from numpy.linalg import lstsq + +from .generic import build_index @njit(cache=True) @@ -84,7 +85,7 @@ def poly_area_vertice(v): @njit(cache=True) def poly_area(x, y): """ - Must be call with local coordinates (in m, to get an area in m²). + Must be called with local coordinates (in m, to get an area in m²). :param array x: :param array y: @@ -209,6 +210,7 @@ def winding_number_poly(x, y, xy_poly): # loop through all edges of the polygon for i_elt in range(nb_elt): if i_elt + 1 == nb_elt: + # We close polygon with first value (no need to duplicate first value) x_next = xy_poly[0, 0] y_next = xy_poly[0, 1] else: @@ -276,7 +278,10 @@ def close_center(x0, y0, x1, y1, delta=0.1): for i0 in range(nb0): xi0, yi0 = x0[i0], y0[i0] for i1 in range(nb1): - if abs(x1[i1] - xi0) > delta: + d_x = x1[i1] - xi0 + if abs(d_x) > 180: + d_x = (d_x + 180) % 360 - 180 + if abs(d_x) > delta: continue if abs(y1[i1] - yi0) > delta: continue @@ -284,6 +289,27 @@ def close_center(x0, y0, x1, y1, delta=0.1): return array(i), array(j), array(c) +@njit(cache=True) +def create_meshed_particles(lons, lats, step): + x_out, y_out, i_out = list(), list(), list() + nb = lons.shape[0] + for i in range(nb): + lon, lat = lons[i], lats[i] + vertice = create_vertice(*reduce_size(lon, lat)) + lon_min, lon_max = lon.min(), lon.max() + lat_min, lat_max = lat.min(), lat.max() + y0 = lat_min - lat_min % step + x = lon_min - lon_min % step + while x <= lon_max: + y = y0 + while y <= lat_max: + if winding_number_poly(x, y, vertice): + x_out.append(x), y_out.append(y), i_out.append(i) + y += step + x += step + return array(x_out), array(y_out), array(i_out) + + @njit(cache=True, fastmath=True) def bbox_intersection(x0, y0, x1, y1): """ @@ -321,7 +347,7 @@ def bbox_intersection(x0, y0, x1, y1): continue i.append(i0) j.append(i1) - return array(i), array(j) + return array(i, dtype=numba_types.int32), array(j, dtype=numba_types.int32) @njit(cache=True) @@ -408,7 +434,9 @@ def merge(x, y): return concatenate(x), concatenate(y) -def vertice_overlap(x0, y0, x1, y1, minimal_area=False): +def vertice_overlap( + x0, y0, x1, y1, minimal_area=False, p1_area=False, hybrid_area=False, min_overlap=0 +): r""" Return percent of overlap for each item. 
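For readers of the `vertice_overlap` changes in the next hunk (new `p1_area` and `min_overlap` options alongside `minimal_area`), here is a rough sketch of the three cost variants, using shapely as a stand-in for the Polygon package the library actually uses; the polygons and threshold are illustrative only, and the `hybrid_area` fallback is omitted for brevity:

```python
from shapely.geometry import Polygon

def overlap_cost(p0, p1, minimal_area=False, p1_area=False, min_overlap=0.0):
    inter = p0.intersection(p1).area
    if inter == 0:
        return 0.0
    if minimal_area:       # intersection / smaller polygon
        cost = inter / min(p0.area, p1.area)
    elif p1_area:          # intersection / p1 polygon
        cost = inter / p1.area
    else:                  # intersection / union
        cost = inter / (p0.area + p1.area - inter)
    return cost if cost >= min_overlap else 0.0

a = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)])
b = Polygon([(1, 1), (3, 1), (3, 3), (1, 3)])
print(overlap_cost(a, b))                     # 1 / 7 ~ 0.143
print(overlap_cost(a, b, minimal_area=True))  # 1 / 4 = 0.25
```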
@@ -417,6 +445,10 @@ def vertice_overlap(x0, y0, x1, y1, minimal_area=False): :param array x1: x for polygon list 1 :param array y1: y for polygon list 1 :param bool minimal_area: If True, function will compute intersection/little polygon, else intersection/union + :param bool p1_area: If True, function will compute intersection/p1 polygon, else intersection/union + :param bool hybrid_area: If True, function will compute like union, + but if cost is under min_overlap, obs is kept in case of fully included + :param float min_overlap: under this value cost is set to zero :return: Result of cost function :rtype: array @@ -427,6 +459,10 @@ def vertice_overlap(x0, y0, x1, y1, minimal_area=False): If minimal area: .. math:: Score = \frac{Intersection(P_0,P_1)_{area}}{min(P_{0 area},P_{1 area})} + + If P1 area: + + .. math:: Score = \frac{Intersection(P_0,P_1)_{area}}{P_{1 area}} """ nb = x0.shape[0] cost = empty(nb) @@ -438,11 +474,29 @@ def vertice_overlap(x0, y0, x1, y1, minimal_area=False): # Area of intersection intersection = (p0 & p1).area() # we divide intersection with the little one result from 0 to 1 + if intersection == 0: + cost[i] = 0 + continue + p0_area_, p1_area_ = p0.area(), p1.area() if minimal_area: - cost[i] = intersection / min(p0.area(), p1.area()) + cost_ = intersection / min(p0_area_, p1_area_) + # we divide intersection with p1 + elif p1_area: + cost_ = intersection / p1_area_ # we divide intersection with polygon merging result from 0 to 1 else: - cost[i] = intersection / (p0 + p1).area() + cost_ = intersection / (p0_area_ + p1_area_ - intersection) + if cost_ >= min_overlap: + cost[i] = cost_ + else: + if ( + hybrid_area + and cost_ != 0 + and (intersection / min(p0_area_, p1_area_)) > 0.99 + ): + cost[i] = cost_ + else: + cost[i] = 0 return cost @@ -452,7 +506,7 @@ def polygon_overlap(p0, p1, minimal_area=False): :param list(Polygon) p0: List of polygon to compare with p1 list :param list(Polygon) p1: List of polygon to compare with p0 list - :param bool minimal_area: If True, function will compute intersection/little polygon, else intersection/union + :param bool minimal_area: If True, function will compute intersection/smaller polygon, else intersection/union :return: Result of cost function :rtype: array """ @@ -462,21 +516,22 @@ def polygon_overlap(p0, p1, minimal_area=False): p_ = p1[i] # Area of intersection intersection = (p0 & p_).area() - # we divide intersection with the little one result from 0 to 1 + # we divide the intersection by the smaller area, result from 0 to 1 if minimal_area: cost[i] = intersection / min(p0.area(), p_.area()) - # we divide intersection with polygon merging result from 0 to 1 + # we divide the intersection by the merged polygons area, result from 0 to 1 else: cost[i] = intersection / (p0 + p_).area() return cost +# FIXME: only one function is needed @njit(cache=True) def fit_circle(x, y): """ From a polygon, function will fit a circle. - Must be call with local coordinates (in m, to get a radius in m). + Must be called with local coordinates (in m, to get a radius in m). 
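The `fit_circle` / `fit_circle_` hunks in this file rely on the usual algebraic least-squares circle fit: expanding (x - x0)^2 + (y - y0)^2 = r^2 gives the linear system x^2 + y^2 = a*x + b*y + c with a = 2*x0, b = 2*y0, c = r^2 - x0^2 - y0^2. A small sketch of that system on synthetic points (not the library's numba code):

```python
import numpy as np

# Points on a circle of centre (3, -2) and radius 5, with light noise
theta = np.linspace(0, 2 * np.pi, 50, endpoint=False)
x = 3 + 5 * np.cos(theta) + np.random.normal(0, 0.01, theta.size)
y = -2 + 5 * np.sin(theta) + np.random.normal(0, 0.01, theta.size)

# Solve x**2 + y**2 ~ a*x + b*y + c in the least-squares sense
A = np.column_stack((x, y, np.ones_like(x)))
(a, b, c), *_ = np.linalg.lstsq(A, x**2 + y**2, rcond=None)
x0, y0 = a / 2.0, b / 2.0
radius = (c + x0**2 + y0**2) ** 0.5
print(x0, y0, radius)  # ~ (3, -2, 5)
```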
:param array x: x of polygon :param array y: y of polygon @@ -491,7 +546,7 @@ def fit_circle(x, y): norme = (x[1:] - x_mean) ** 2 + (y[1:] - y_mean) ** 2 norme_max = norme.max() - scale = norme_max ** 0.5 + scale = norme_max**0.5 # Form matrix equation and solve it # Maybe put f4 @@ -502,21 +557,59 @@ def fit_circle(x, y): (x0, y0, radius), _, _, _ = lstsq(datas, norme / norme_max) # Unscale data and get circle variables - radius += x0 ** 2 + y0 ** 2 + radius += x0**2 + y0**2 radius **= 0.5 x0 *= scale y0 *= scale - # radius of fitted circle + # radius of fit circle radius *= scale - # center X-position of fitted circle + # center X-position of fit circle x0 += x_mean - # center Y-position of fitted circle + # center Y-position of fit circle y0 += y_mean err = shape_error(x, y, x0, y0, radius) return x0, y0, radius, err +@njit(cache=True) +def fit_ellipse(x, y): + r""" + From a polygon, function will fit an ellipse. + + Must be call with local coordinates (in m, to get a radius in m). + + .. math:: (\frac{x - x_0}{a})^2 + (\frac{y - y_0}{b})^2 = 1 + + .. math:: (\frac{x^2 - 2 * x * x_0 + x_0 ^2}{a^2}) + \frac{y^2 - 2 * y * y_0 + y_0 ^2}{b^2}) = 1 + + In case of angle + https://en.wikipedia.org/wiki/Ellipse + + """ + nb = x.shape[0] + datas = ones((nb, 5), dtype=x.dtype) + datas[:, 0] = x**2 + datas[:, 1] = x * y + datas[:, 2] = y**2 + datas[:, 3] = x + datas[:, 4] = y + (a, b, c, d, e), _, _, _ = lstsq(datas, ones(nb, dtype=x.dtype)) + det = b**2 - 4 * a * c + if det > 0: + print(det) + x0 = (2 * c * d - b * e) / det + y0 = (2 * a * e - b * d) / det + + AB1 = 2 * (a * e**2 + c * d**2 - b * d * e - det) + AB2 = a + c + AB3 = ((a - c) ** 2 + b**2) ** 0.5 + A = -((AB1 * (AB2 + AB3)) ** 0.5) / det + B = -((AB1 * (AB2 - AB3)) ** 0.5) / det + theta = arctan((c - a - AB3) / b) + return x0, y0, A, B, theta + + @njit(cache=True) def fit_circle_(x, y): r""" @@ -571,7 +664,7 @@ def fit_circle_(x, y): # Linear regression (a, b, c), _, _, _ = lstsq(datas, x[1:] ** 2 + y[1:] ** 2) x0, y0 = a / 2.0, b / 2.0 - radius = (c + x0 ** 2 + y0 ** 2) ** 0.5 + radius = (c + x0**2 + y0**2) ** 0.5 err = shape_error(x, y, x0, y0, radius) return x0, y0, radius, err @@ -596,14 +689,14 @@ def shape_error(x, y, x0, y0, r): :rtype: float """ # circle area - c_area = (r ** 2) * pi + c_area = (r**2) * pi p_area = poly_area(x, y) nb = x.shape[0] x, y = x.copy(), y.copy() # Find distance between circle center and polygon for i in range(nb): dx, dy = x[i] - x0, y[i] - y0 - rd = r / (dx ** 2 + dy ** 2) ** 0.5 + rd = r / (dx**2 + dy**2) ** 0.5 if rd < 1: x[i] = x0 + dx * rd y[i] = y0 + dy * rd @@ -671,56 +764,209 @@ def tri_area2(x, y, i0, i1, i2): @njit(cache=True) -def visvalingam(x, y, nb_pt=18): +def visvalingam(x, y, fixed_size=18): """Polygon simplification with visvalingam algorithm + X, Y are considered like a polygon, the next point after the last one is the first one + :param array x: :param array y: - :param int nb_pt: array size of out - :return: New (x, y) array + :param int fixed_size: array size of out + :return: + New (x, y) array, last position will be equal to first one, if array size is 6, + there is only 5 point. :rtype: array,array + + .. 
plot:: + + import matplotlib.pyplot as plt + import numpy as np + from py_eddy_tracker.poly import visvalingam + + x = np.array([1, 2, 3, 4, 5, 6.75, 6, 1]) + y = np.array([-0.5, -1.5, -1, -1.75, -1, -1, -0.5, -0.5]) + ax = plt.subplot(111) + ax.set_aspect("equal") + ax.grid(True), ax.set_ylim(-2, -.2) + ax.plot(x, y, "r", lw=5) + ax.plot(*visvalingam(x,y,6), "b", lw=2) + plt.show() """ + # TODO : in case of original size lesser than fixed size, jump at the end nb = x.shape[0] - i0, i1 = nb - 3, nb - 2 + nb_ori = nb + # Get indice of first triangle + i0, i1 = nb - 2, nb - 1 + # Init heap with first area and tiangle h = [(tri_area2(x, y, i0, i1, 0), (i0, i1, 0))] + # Roll index for next one i0 = i1 i1 = 0 - i_previous = empty(nb - 1, dtype=numba_types.int32) - i_next = empty(nb - 1, dtype=numba_types.int32) + # Index of previous valid point + i_previous = empty(nb, dtype=numba_types.int64) + # Index of next valid point + i_next = empty(nb, dtype=numba_types.int64) + # Mask of removed + removed = zeros(nb, dtype=numba_types.bool_) i_previous[0] = -1 i_next[0] = -1 - for i in range(1, nb - 1): + for i in range(1, nb): i_previous[i] = -1 i_next[i] = -1 + # We add triangle area for all triangle heapq.heappush(h, (tri_area2(x, y, i0, i1, i), (i0, i1, i))) i0 = i1 i1 = i # we continue until we are equal to nb_pt - while len(h) >= nb_pt: + while nb >= fixed_size: # We pop lower area _, (i0, i1, i2) = heapq.heappop(h) # We check if triangle is valid(i0 or i2 not removed) - i_p, i_n = i_previous[i0], i_next[i2] - if i_p == -1 and i_n == -1: - # We store reference of delete point - i_previous[i1] = i0 - i_next[i1] = i2 + if removed[i0] or removed[i2]: + # In this cas nothing to do continue - elif i_p == -1: - i2 = i_n - elif i_n == -1: - i0 = i_p - else: - # in this case we replace two point - i0, i2 = i_p, i_n - heapq.heappush(h, (tri_area2(x, y, i0, i1, i2), (i0, i1, i2))) - x_new, y_new = empty(nb_pt, dtype=x.dtype), empty(nb_pt, dtype=y.dtype) + # Flag obs like removed + removed[i1] = True + # We count point still valid + nb -= 1 + # Modify index for the next and previous, we jump over i1 + i_previous[i2] = i0 + i_next[i0] = i2 + # We insert 2 triangles which are modified by the deleted point + # Previous triangle + i_1 = i_previous[i0] + if i_1 == -1: + i_1 = (i0 - 1) % nb_ori + heapq.heappush(h, (tri_area2(x, y, i_1, i0, i2), (i_1, i0, i2))) + # Previous triangle + i3 = i_next[i2] + if i3 == -1: + i3 = (i2 + 1) % nb_ori + heapq.heappush(h, (tri_area2(x, y, i0, i2, i3), (i0, i2, i3))) + x_new, y_new = empty(fixed_size, dtype=x.dtype), empty(fixed_size, dtype=y.dtype) j = 0 - for i, i_n in enumerate(i_next): - if i_n == -1: + for i, flag in enumerate(removed): + if not flag: x_new[j] = x[i] y_new[j] = y[i] j += 1 - x_new[j] = x_new[0] - y_new[j] = y_new[0] + # we copy first value to fill array end + x_new[j:] = x_new[0] + y_new[j:] = y_new[0] return x_new, y_new + + +@njit(cache=True) +def reduce_size(x, y): + """ + Reduce array size if last position is repeated, in order to save compute time + + :param array x: longitude + :param array y: latitude + + :return: reduce arrays x,y + :rtype: ndarray,ndarray + """ + i = x.shape[0] + x0, y0 = x[0], y[0] + while True: + i -= 1 + if x[i] != x0 or y[i] != y0: + i += 1 + # In case of virtual obs all value could be fill with same value, to avoid empty array + i = max(3, i) + return x[:i], y[:i] + + +@njit(cache=True) +def group_obs(x, y, step, nb_x): + """Get index k_box for each box, and indexes to sort""" + nb = x.size + i = empty(nb, 
dtype=numba_types.uint32) + for k in range(nb): + i[k] = box_index(x[k], y[k], step, nb_x) + return i, i.argsort(kind="mergesort") + + +@njit(cache=True) +def box_index(x, y, step, nb_x): + """Return k_box index for each value""" + return numba_types.uint32((x % 360) // step + nb_x * ((y + 90) // step)) + + +@njit(cache=True) +def box_indexes(x, y, step): + """Return i_box,j_box index for each value""" + return numba_types.uint32((x % 360) // step), numba_types.uint32((y + 90) // step) + + +@njit(cache=True) +def poly_indexs(x_p, y_p, x_c, y_c): + """ + Index of contour for each postion inside a contour, -1 in case of no contour + + :param array x_p: longitude to test (must be defined, no nan) + :param array y_p: latitude to test (must be defined, no nan) + :param array x_c: longitude of contours + :param array y_c: latitude of contours + """ + nb_x = 360 + step = 1.0 + i, i_order = group_obs(x_p, y_p, step, nb_x) + nb_p = x_p.shape[0] + nb_c = x_c.shape[0] + indexs = -ones(nb_p, dtype=numba_types.int32) + # Adress table to get test bloc + start_index, end_index, i_first = build_index(i[i_order]) + nb_bloc = end_index.size + for i_contour in range(nb_c): + # Build vertice and box included contour + x_, y_ = reduce_size(x_c[i_contour], y_c[i_contour]) + x_c_min, y_c_min = x_.min(), y_.min() + x_c_max, y_c_max = x_.max(), y_.max() + v = create_vertice(x_, y_) + i0, j0 = box_indexes(x_c_min, y_c_min, step) + i1, j1 = box_indexes(x_c_max, y_c_max, step) + # i0 could be greater than i1, (x_c is always continious) so you could have a contour over bound + if i0 > i1: + i1 += nb_x + for i_x in range(i0, i1 + 1): + # we force i_x in 0 360 range + i_x %= nb_x + for i_y in range(j0, j1 + 1): + # Get box indices + i_box = i_x + nb_x * i_y - i_first + # Indice must be in table range + if i_box < 0 or i_box >= nb_bloc: + continue + for i_p_ordered in range(start_index[i_box], end_index[i_box]): + i_p = i_order[i_p_ordered] + if indexs[i_p] != -1: + continue + y = y_p[i_p] + if y > y_c_max: + continue + if y < y_c_min: + continue + # Normalize longitude at +-180° around x_c_min + x = (x_p[i_p] - x_c_min + 180) % 360 + x_c_min - 180 + if x > x_c_max: + continue + if x < x_c_min: + continue + if winding_number_poly(x, y, v) != 0: + indexs[i_p] = i_contour + return indexs + + +@njit(cache=True) +def insidepoly(x_p, y_p, x_c, y_c): + """ + True for each postion inside a contour + + :param array x_p: longitude to test + :param array y_p: latitude to test + :param array x_c: longitude of contours + :param array y_c: latitude of contours + """ + return poly_indexs(x_p, y_p, x_c, y_c) != -1 diff --git a/src/py_eddy_tracker/tracking.py b/src/py_eddy_tracker/tracking.py index ae0a2524..b64b6fcc 100644 --- a/src/py_eddy_tracker/tracking.py +++ b/src/py_eddy_tracker/tracking.py @@ -2,15 +2,14 @@ """ Class to store link between observations """ - +from datetime import datetime, timedelta import json import logging import platform -from datetime import datetime, timedelta +from tarfile import ExFileObject from netCDF4 import Dataset, default_fillvals -from numba import njit -from numba import types as numba_types +from numba import njit, types as numba_types from numpy import ( arange, array, @@ -67,6 +66,7 @@ def __init__( class_method=None, class_kw=None, previous_correspondance=None, + memory=False, ): """Initiate tracking @@ -74,6 +74,7 @@ def __init__( :param class class_method: A class which tell how to track :param dict class_kw: keyword argument to setup class :param Correspondances previous_correspondance: A 
previous correspondance object if you want continue tracking + :param bool memory: identification file are load in memory before to be open with netcdf """ super().__init__() # Correspondance dtype @@ -88,6 +89,7 @@ def __init__( else: self.class_method = class_method self.class_kw = dict() if class_kw is None else class_kw + self.memory = memory # To count ID self.current_id = 0 @@ -158,10 +160,10 @@ def period(self): """ date_start = datetime(1950, 1, 1) + timedelta( - int(self.class_method.load_file(self.datasets[0]).time[0]) + self.class_method.load_file(self.datasets[0]).time[0] ) date_stop = datetime(1950, 1, 1) + timedelta( - int(self.class_method.load_file(self.datasets[-1]).time[0]) + self.class_method.load_file(self.datasets[-1]).time[0] ) return date_start, date_stop @@ -171,7 +173,11 @@ def swap_dataset(self, dataset, *args, **kwargs): self.previous_obs = self.current_obs kwargs = kwargs.copy() kwargs.update(self.class_kw) - self.current_obs = self.class_method.load_file(dataset, *args, **kwargs) + if self.memory: + with open(dataset, "rb") as h: + self.current_obs = self.class_method.load_file(h, *args, **kwargs) + else: + self.current_obs = self.class_method.load_file(dataset, *args, **kwargs) def merge_correspondance(self, other): # Verify compliance of file @@ -343,6 +349,9 @@ def load_state(self): self.virtual_obs = VirtualEddiesObservations.from_netcdf( general_handler.groups["LastVirtualObs"] ) + self.previous_virtual_obs = VirtualEddiesObservations.from_netcdf( + general_handler.groups["LastPreviousVirtualObs"] + ) # Load and last previous virtual obs to be merge with current => will be previous2_obs # TODO : Need to rethink this line ?? self.current_obs = self.current_obs.merge( @@ -366,7 +375,10 @@ def track(self): # We begin with second file, first one is in previous for file_name in self.datasets[first_dataset:]: self.swap_dataset(file_name, **kwargs) - logger.info("%s match with previous state", file_name) + filename_ = ( + file_name.filename if isinstance(file_name, ExFileObject) else file_name + ) + logger.info("%s match with previous state", filename_) logger.debug("%d obs to match", len(self.current_obs)) nb_real_obs = len(self.previous_obs) @@ -400,14 +412,14 @@ def to_netcdf(self, handler): logger.debug('Create Dimensions "Nstep" : %d', nb_step) handler.createDimension("Nstep", nb_step) var_file_in = handler.createVariable( - zlib=True, + zlib=False, complevel=1, varname="FileIn", datatype="S1024", dimensions="Nstep", ) var_file_out = handler.createVariable( - zlib=True, + zlib=False, complevel=1, varname="FileOut", datatype="S1024", @@ -577,7 +589,10 @@ def prepare_merging(self): def longer_than(self, size_min): """Remove from correspondance table all association for shorter eddies than size_min""" # Identify eddies longer than - i_keep_track = where(self.nb_obs_by_tracks >= size_min)[0] + mask = self.nb_obs_by_tracks >= size_min + if not mask.any(): + return False + i_keep_track = where(mask)[0] # Reduce array self.nb_obs_by_tracks = self.nb_obs_by_tracks[i_keep_track] self.i_current_by_tracks = ( @@ -646,7 +661,7 @@ def merge(self, until=-1, raw_data=True): # Set type of eddy with first file eddies.sign_type = self.current_obs.sign_type # Fields to copy - fields = self.current_obs.obs.dtype.names + fields = self.current_obs.fields # To know if the track start first_obs_save_in_tracks = zeros(self.i_current_by_tracks.shape, dtype=bool_) @@ -695,7 +710,7 @@ def merge(self, until=-1, raw_data=True): # Index in the current file index_current = self[i]["out"] - 
if "cost_association" in eddies.obs.dtype.names: + if "cost_association" in eddies.fields: eddies["cost_association"][index_final - 1] = self[i]["cost_value"] # Copy all variable for field in fields: @@ -712,53 +727,29 @@ def get_unused_data(self, raw_data=False): Returns: Unused Eddies """ - self.reset_dataset_cache() - self.swap_dataset(self.datasets[0], raw_data=raw_data) - nb_dataset = len(self.datasets) - # Get the number of obs unused - nb_obs = 0 - list_mask = list() has_virtual = "virtual" in self[0].dtype.names - logger.debug("Count unused data ...") - for i, filename in enumerate(self.datasets): + eddies = list() + for i, dataset in enumerate(self.datasets): last_dataset = i == (nb_dataset - 1) if has_virtual and not last_dataset: m_in = ~self[i]["virtual"] else: m_in = slice(None) if i == 0: - eddies_used = self[i]["in"] + index_used = self[i]["in"] elif last_dataset: - eddies_used = self[i - 1]["out"] + index_used = self[i - 1]["out"] else: - eddies_used = unique( + index_used = unique( concatenate((self[i - 1]["out"], self[i]["in"][m_in])) ) - if not isinstance(filename, str): - filename = filename.astype(str) - with Dataset(filename) as h: - nb_obs_day = len(h.dimensions["obs"]) - m = ones(nb_obs_day, dtype="bool") - m[eddies_used] = False - list_mask.append(m) - nb_obs += m.sum() - logger.debug("Count unused data OK") - eddies = EddiesObservations( - size=nb_obs, - track_extra_variables=self.current_obs.track_extra_variables, - track_array_variables=self.current_obs.track_array_variables, - array_variables=self.current_obs.array_variables, - raw_data=raw_data, - ) - j = 0 - for i, dataset in enumerate(self.datasets): - logger.debug("Loaf file : (%d) %s", i, dataset) - current_obs = self.class_method.load_file(dataset, raw_data=raw_data) - if i == 0: - eddies.sign_type = current_obs.sign_type - unused_obs = current_obs.obs[list_mask[i]] - nb = unused_obs.shape[0] - eddies.obs[j : j + nb] = unused_obs - j += nb - return eddies + + logger.debug("Load file : %s", dataset) + if self.memory: + with open(dataset, "rb") as h: + current_obs = self.class_method.load_file(h, raw_data=raw_data) + else: + current_obs = self.class_method.load_file(dataset, raw_data=raw_data) + eddies.append(current_obs.index(index_used, reverse=True)) + return EddiesObservations.concatenate(eddies) diff --git a/src/scripts/EddySubSetter b/src/scripts/EddySubSetter index 73250834..6cace388 100644 --- a/src/scripts/EddySubSetter +++ b/src/scripts/EddySubSetter @@ -22,7 +22,7 @@ def id_parser(): "--period", nargs=2, type=int, - help="Start day and end day, if it s negative value we will add to day min and add to day max, if 0 it s not use", + help="Start day and end day, if it's negative value we will add to day min and add to day max, if 0 it s not use", ) group.add_argument( "-l", diff --git a/src/scripts/EddyTracking b/src/scripts/EddyTracking deleted file mode 100644 index 28c946fa..00000000 --- a/src/scripts/EddyTracking +++ /dev/null @@ -1,274 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -Track eddy with Identification file produce with EddyIdentification -""" -import logging -from datetime import datetime -from glob import glob -from os import mkdir -from os.path import basename, dirname, exists -from os.path import join as join_path -from re import compile as re_compile - -from netCDF4 import Dataset -from numpy import bytes_, empty, unique -from yaml import load as yaml_load - -from py_eddy_tracker import EddyParser -from py_eddy_tracker.tracking import Correspondances - -logger = 
logging.getLogger("pet") - - -def browse_dataset_in( - data_dir, - files_model, - date_regexp, - date_model, - start_date=None, - end_date=None, - sub_sampling_step=1, - files=None, -): - pattern_regexp = re_compile(".*/" + date_regexp) - if files is not None: - filenames = bytes_(files) - else: - full_path = join_path(data_dir, files_model) - logger.info("Search files : %s", full_path) - filenames = bytes_(glob(full_path)) - - dataset_list = empty( - len(filenames), dtype=[("filename", "S500"), ("date", "datetime64[D]"),] - ) - dataset_list["filename"] = filenames - - logger.info("%s grids available", dataset_list.shape[0]) - mode_attrs = False - if "(" not in date_regexp: - logger.debug("Attrs date : %s", date_regexp) - mode_attrs = date_regexp.strip().split(":") - else: - logger.debug("Pattern date : %s", date_regexp) - - for item in dataset_list: - str_date = None - if mode_attrs: - with Dataset(item["filename"].decode("utf-8")) as h: - if len(mode_attrs) == 1: - str_date = getattr(h, mode_attrs[0]) - else: - str_date = getattr(h.variables[mode_attrs[0]], mode_attrs[1]) - else: - result = pattern_regexp.match(str(item["filename"])) - if result: - str_date = result.groups()[0] - - if str_date is not None: - item["date"] = datetime.strptime(str_date, date_model).date() - - dataset_list.sort(order=["date", "filename"]) - - steps = unique(dataset_list["date"][1:] - dataset_list["date"][:-1]) - if len(steps) > 1: - raise Exception("Several days steps in grid dataset %s" % steps) - - if sub_sampling_step != 1: - logger.info("Grid subsampling %d", sub_sampling_step) - dataset_list = dataset_list[::sub_sampling_step] - - if start_date is not None or end_date is not None: - logger.info( - "Available grid from %s to %s", - dataset_list[0]["date"], - dataset_list[-1]["date"], - ) - logger.info("Filtering grid by time %s, %s", start_date, end_date) - mask = (dataset_list["date"] >= start_date) * (dataset_list["date"] <= end_date) - - dataset_list = dataset_list[mask] - return dataset_list - - -def usage(): - """Usage - """ - # Run using: - parser = EddyParser("Tool to use identification step to compute tracking") - parser.add_argument("yaml_file", help="Yaml file to configure py-eddy-tracker") - parser.add_argument("--correspondance_in", help="Filename of saved correspondance") - parser.add_argument("--correspondance_out", help="Filename to save correspondance") - parser.add_argument( - "--save_correspondance_and_stop", - action="store_true", - help="Stop tracking after correspondance computation," - " merging can be done with EddyFinalTracking", - ) - parser.add_argument( - "--zarr", action="store_true", help="Output will be wrote in zarr" - ) - parser.add_argument("--unraw", action="store_true", help="Load unraw data") - parser.add_argument( - "--blank_period", - type=int, - default=0, - help="Nb of detection which will not use at the end of the period", - ) - args = parser.parse_args() - - # Read yaml configuration file - with open(args.yaml_file, "r") as stream: - config = yaml_load(stream) - if args.correspondance_in is not None and not exists(args.correspondance_in): - args.correspondance_in = None - return ( - config, - args.save_correspondance_and_stop, - args.correspondance_in, - args.correspondance_out, - args.blank_period, - args.zarr, - not args.unraw, - ) - - -if __name__ == "__main__": - ( - CONFIG, - SAVE_STOP, - CORRESPONDANCES_IN, - CORRESPONDANCES_OUT, - BLANK_PERIOD, - ZARR, - RAW, - ) = usage() - # Create output directory - SAVE_DIR = CONFIG["PATHS"].get("SAVE_DIR", None) - if 
SAVE_DIR is not None and not exists(SAVE_DIR): - mkdir(SAVE_DIR) - - YAML_CORRESPONDANCES_OUT = CONFIG["PATHS"].get("CORRESPONDANCES_OUT", None) - if CORRESPONDANCES_IN is None: - CORRESPONDANCES_IN = CONFIG["PATHS"].get("CORRESPONDANCES_IN", None) - if CORRESPONDANCES_OUT is None: - CORRESPONDANCES_OUT = YAML_CORRESPONDANCES_OUT - if YAML_CORRESPONDANCES_OUT is None and CORRESPONDANCES_OUT is None: - CORRESPONDANCES_OUT = "{path}/{sign_type}_correspondances.nc" - - CLASS = None - CLASS_KW = dict() - if "CLASS" in CONFIG: - CLASSNAME = CONFIG["CLASS"]["CLASS"] - CLASS = getattr( - __import__(CONFIG["CLASS"]["MODULE"], globals(), locals(), CLASSNAME), - CLASSNAME, - ) - CLASS_KW = CONFIG["CLASS"].get("OPTIONS", dict()) - - NB_VIRTUAL_OBS_MAX_BY_SEGMENT = int(CONFIG.get("VIRTUAL_LENGTH_MAX", 0)) - - if isinstance(CONFIG["PATHS"]["FILES_PATTERN"], list): - DATASET_LIST = browse_dataset_in( - data_dir=None, - files_model=None, - files=CONFIG["PATHS"]["FILES_PATTERN"], - date_regexp=".*_([0-9]*?).[nz].*", - date_model="%Y%m%d", - ) - else: - DATASET_LIST = browse_dataset_in( - data_dir=dirname(CONFIG["PATHS"]["FILES_PATTERN"]), - files_model=basename(CONFIG["PATHS"]["FILES_PATTERN"]), - date_regexp=".*_([0-9]*?).[nz].*", - date_model="%Y%m%d", - ) - - if BLANK_PERIOD > 0: - DATASET_LIST = DATASET_LIST[:-BLANK_PERIOD] - logger.info("Last %d files will be pop", BLANK_PERIOD) - - START_TIME = datetime.now() - logger.info("Start tracking on %d files", len(DATASET_LIST)) - - NB_OBS_MIN = int(CONFIG.get("TRACK_DURATION_MIN", 14)) - if NB_OBS_MIN > len(DATASET_LIST): - raise Exception( - "Input file number (%s) is shorter than TRACK_DURATION_MIN (%s)." - % (len(DATASET_LIST), NB_OBS_MIN) - ) - - CORRESPONDANCES = Correspondances( - datasets=DATASET_LIST["filename"], - virtual=NB_VIRTUAL_OBS_MAX_BY_SEGMENT, - class_method=CLASS, - class_kw=CLASS_KW, - previous_correspondance=CORRESPONDANCES_IN, - ) - CORRESPONDANCES.track() - logger.info("Track finish") - - logger.info("Start merging") - DATE_START, DATE_STOP = CORRESPONDANCES.period - DICT_COMPLETION = dict( - date_start=DATE_START, - date_stop=DATE_STOP, - date_prod=START_TIME, - path=SAVE_DIR, - sign_type=CORRESPONDANCES.current_obs.sign_legend, - ) - - CORRESPONDANCES.save(CORRESPONDANCES_OUT, DICT_COMPLETION) - if SAVE_STOP: - exit() - - # Merge correspondance, only do if we stop and store just after compute of correspondance - CORRESPONDANCES.prepare_merging() - - logger.info( - "Longer track saved have %d obs", CORRESPONDANCES.nb_obs_by_tracks.max() - ) - logger.info( - "The mean length is %d observations before filtering", - CORRESPONDANCES.nb_obs_by_tracks.mean(), - ) - - CORRESPONDANCES.get_unused_data(raw_data=RAW).write_file( - path=SAVE_DIR, filename="%(path)s/%(sign_type)s_untracked.nc", zarr_flag=ZARR - ) - - SHORT_CORRESPONDANCES = CORRESPONDANCES._copy() - SHORT_CORRESPONDANCES.shorter_than(size_max=NB_OBS_MIN) - - CORRESPONDANCES.longer_than(size_min=NB_OBS_MIN) - - FINAL_EDDIES = CORRESPONDANCES.merge(raw_data=RAW) - SHORT_TRACK = SHORT_CORRESPONDANCES.merge(raw_data=RAW) - - # We flag obs - if CORRESPONDANCES.virtual: - FINAL_EDDIES["virtual"][:] = FINAL_EDDIES["time"] == 0 - FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES["virtual"] == 1) - SHORT_TRACK["virtual"][:] = SHORT_TRACK["time"] == 0 - SHORT_TRACK.filled_by_interpolation(SHORT_TRACK["virtual"] == 1) - - # Total running time - FULL_TIME = datetime.now() - START_TIME - logger.info("Mean duration by loop : %s", FULL_TIME / (len(DATASET_LIST) - 1)) - logger.info("Duration : 
%s", FULL_TIME) - - logger.info( - "Longer track saved have %d obs", CORRESPONDANCES.nb_obs_by_tracks.max() - ) - logger.info( - "The mean length is %d observations after filtering", - CORRESPONDANCES.nb_obs_by_tracks.mean(), - ) - - FINAL_EDDIES.write_file(path=SAVE_DIR, zarr_flag=ZARR) - SHORT_TRACK.write_file( - filename="%(path)s/%(sign_type)s_track_too_short.nc", - path=SAVE_DIR, - zarr_flag=ZARR, - ) - diff --git a/src/scripts/EddyTranslate b/src/scripts/EddyTranslate index 94142132..a0060e9b 100644 --- a/src/scripts/EddyTranslate +++ b/src/scripts/EddyTranslate @@ -3,8 +3,8 @@ """ Translate eddy Dataset """ -import zarr from netCDF4 import Dataset +import zarr from py_eddy_tracker import EddyParser from py_eddy_tracker.observations.observation import EddiesObservations @@ -16,6 +16,8 @@ def id_parser(): ) parser.add_argument("filename_in") parser.add_argument("filename_out") + parser.add_argument("--unraw", action="store_true", help="Load unraw data, use only for netcdf." + "If unraw is active, netcdf is loaded without apply scalefactor and add_offset.") return parser @@ -32,10 +34,10 @@ def get_variable_name(filename): return list(h.keys()) -def get_variable(filename, varname): +def get_variable(filename, varname, raw=True): if is_nc(filename): dataset = EddiesObservations.load_from_netcdf( - filename, raw_data=True, include_vars=(varname,) + filename, raw_data=raw, include_vars=(varname,) ) else: dataset = EddiesObservations.load_from_zarr(filename, include_vars=(varname,)) @@ -49,8 +51,8 @@ if __name__ == "__main__": if not is_nc(args.filename_out): h = zarr.open(args.filename_out, "w") for varname in variables: - get_variable(args.filename_in, varname).to_zarr(h) + get_variable(args.filename_in, varname, raw=not args.unraw).to_zarr(h) else: with Dataset(args.filename_out, "w") as h: for varname in variables: - get_variable(args.filename_in, varname).to_netcdf(h) + get_variable(args.filename_in, varname, raw=not args.unraw).to_netcdf(h) diff --git a/tests/test_generic.py b/tests/test_generic.py index ab3832cc..ee2d7881 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,6 +1,6 @@ from numpy import arange, array, nan, ones, zeros -from py_eddy_tracker.generic import cumsum_by_track, simplify +from py_eddy_tracker.generic import cumsum_by_track, simplify, wrap_longitude def test_simplify(): @@ -30,3 +30,22 @@ def test_cumsum_by_track(): a = ones(10, dtype="i4") * 2 track = array([1, 1, 2, 2, 2, 2, 44, 44, 44, 48]) assert (cumsum_by_track(a, track) == [2, 4, 2, 4, 6, 8, 2, 4, 6, 2]).all() + + +def test_wrapping(): + y = x = arange(-5, 5, dtype="f4") + x_, _ = wrap_longitude(x, y, ref=-10) + assert (x_ == x).all() + x_, _ = wrap_longitude(x, y, ref=1) + assert x.size == x_.size + assert (x_[6:] == x[6:]).all() + assert (x_[:6] == x[:6] + 360).all() + x_, _ = wrap_longitude(x, y, ref=1, cut=True) + assert x.size + 3 == x_.size + assert (x_[6 + 3 :] == x[6:]).all() + assert (x_[:7] == x[:7] + 360).all() + + # FIXME Need evolution in wrap_longitude + # x %= 360 + # x_, _ = wrap_longitude(x, y, ref=-10, cut=True) + # assert x.size == x_.size diff --git a/tests/test_grid.py b/tests/test_grid.py index 5da5390e..0e6dd586 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,11 +1,11 @@ from matplotlib.path import Path -from numpy import array, ma +from numpy import arange, array, isnan, ma, nan, ones, zeros from pytest import approx -from py_eddy_tracker.data import get_path +from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.dataset.grid import 
RegularGridDataset -G = RegularGridDataset(get_path("mask_1_60.nc"), "lon", "lat") +G = RegularGridDataset(get_demo_path("mask_1_60.nc"), "lon", "lat") X = 0.025 contour = Path( ( @@ -69,9 +69,56 @@ def test_interp(): ) x0, y0 = array((10,)), array((5,)) x1, y1 = array((15,)), array((5,)) + # outside but usable with nearest + x2, y2 = array((25,)), array((5,)) + # Outside for any interpolation + x3, y3 = array((25,)), array((16,)) + x4, y4 = array((55,)), array((25,)) # Interp nearest assert g.interp("z", x0, y0, method="nearest") == 0 assert g.interp("z", x1, y1, method="nearest") == 2 + assert isnan(g.interp("z", x4, y4, method="nearest")) + assert g.interp("z", x2, y2, method="nearest") == 2 + assert isnan(g.interp("z", x3, y3, method="nearest")) + # Interp bilinear assert g.interp("z", x0, y0) == 1.5 assert g.interp("z", x1, y1) == 2 + assert isnan(g.interp("z", x2, y2)) + + +def test_convolution(): + """ + Add some dummy check on convolution filter + """ + # Fake grid + z = ma.array( + arange(12).reshape((-1, 1)) * arange(10).reshape((1, -1)), + mask=zeros((12, 10), dtype="bool"), + dtype="f4", + ) + g = RegularGridDataset.with_array( + coordinates=("x", "y"), + datas=dict( + z=z, + x=arange(0, 6, 0.5), + y=arange(0, 5, 0.5), + ), + centered=True, + ) + + def kernel_func(lat): + return ones((3, 3)) + + # After transpose we must get same result + d = g.convolve_filter_with_dynamic_kernel("z", kernel_func) + assert (d.T[:9, :9] == d[:9, :9]).all() + # We mask one value and check convolution result + z.mask[2, 2] = True + d = g.convolve_filter_with_dynamic_kernel("z", kernel_func) + assert d[1, 1] == z[:3, :3].sum() / 8 + # Add nan and check only nearest value is contaminate + z[2, 2] = nan + d = g.convolve_filter_with_dynamic_kernel("z", kernel_func) + assert not isnan(d[0, 0]) + assert isnan(d[1:4, 1:4]).all() diff --git a/tests/test_id.py b/tests/test_id.py index cedcdff8..c69a5a26 100644 --- a/tests/test_id.py +++ b/tests/test_id.py @@ -1,10 +1,10 @@ from datetime import datetime -from py_eddy_tracker.data import get_path +from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.dataset.grid import RegularGridDataset g = RegularGridDataset( - get_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" + get_demo_path("dt_med_allsat_phy_l4_20160515_20190101.nc"), "longitude", "latitude" ) diff --git a/tests/test_network.py b/tests/test_network.py new file mode 100644 index 00000000..5cd9b4cd --- /dev/null +++ b/tests/test_network.py @@ -0,0 +1,15 @@ +from py_eddy_tracker.observations.network import Network + + +def test_group_translate(): + translate = Network.group_translator(5, ((0, 1), (0, 2), (1, 3))) + assert (translate == [3, 3, 3, 3, 4]).all() + + translate = Network.group_translator(5, ((1, 3), (0, 1), (0, 2))) + assert (translate == [3, 3, 3, 3, 4]).all() + + translate = Network.group_translator(8, ((1, 3), (2, 3), (2, 4), (5, 6), (4, 5))) + assert (translate == [0, 6, 6, 6, 6, 6, 6, 7]).all() + + translate = Network.group_translator(6, ((0, 1), (0, 2), (1, 3), (4, 5))) + assert (translate == [3, 3, 3, 3, 5, 5]).all() diff --git a/tests/test_obs.py b/tests/test_obs.py index 59a1edab..a912e06b 100644 --- a/tests/test_obs.py +++ b/tests/test_obs.py @@ -1,11 +1,11 @@ import zarr -from py_eddy_tracker.data import get_path +from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.observations.observation import EddiesObservations a_filename, c_filename = ( - get_path("Anticyclonic_20190223.nc"), - get_path("Cyclonic_20190223.nc"), + 
get_demo_path("Anticyclonic_20190223.nc"), + get_demo_path("Cyclonic_20190223.nc"), ) a = EddiesObservations.load_file(a_filename) a_raw = EddiesObservations.load_file(a_filename, raw_data=True) diff --git a/tests/test_poly.py b/tests/test_poly.py index b2aacb73..a780f64d 100644 --- a/tests/test_poly.py +++ b/tests/test_poly.py @@ -1,7 +1,13 @@ -from numpy import array, pi +from numpy import array, pi, roll from pytest import approx -from py_eddy_tracker.poly import convex, fit_circle, get_convex_hull, poly_area_vertice +from py_eddy_tracker.poly import ( + convex, + fit_circle, + get_convex_hull, + poly_area_vertice, + visvalingam, +) # Vertices for next test V = array(((2, 2, 3, 3, 2), (-10, -9, -9, -10, -10))) @@ -16,7 +22,7 @@ def test_fit_circle(): x0, y0, r, err = fit_circle(*V) assert x0 == approx(2.5, rel=1e-10) assert y0 == approx(-9.5, rel=1e-10) - assert r == approx(2 ** 0.5 / 2, rel=1e-10) + assert r == approx(2**0.5 / 2, rel=1e-10) assert err == approx((1 - 2 / pi) * 100, rel=1e-10) @@ -29,3 +35,19 @@ def test_convex(): def test_convex_hull(): assert convex(*get_convex_hull(*V_concave)) is True + + +def test_visvalingam(): + x = array([1, 2, 3, 4, 5, 6.75, 6, 1]) + y = array([-0.5, -1.5, -1, -1.75, -1, -1, -0.5, -0.5]) + x_target = [1, 2, 3, 4, 6, 1] + y_target = [-0.5, -1.5, -1, -1.75, -0.5, -0.5] + x_, y_ = visvalingam(x, y, 6) + assert (x_target == x_).all() + assert (y_target == y_).all() + x_, y_ = visvalingam(x[:-1], y[:-1], 6) + assert (x_target == x_).all() + assert (y_target == y_).all() + x_, y_ = visvalingam(roll(x, 2), roll(y, 2), 6) + assert (x_target[:-1] == x_[1:]).all() + assert (y_target[:-1] == y_[1:]).all() diff --git a/tests/test_track.py b/tests/test_track.py index f1d5903e..f7e83786 100644 --- a/tests/test_track.py +++ b/tests/test_track.py @@ -1,12 +1,12 @@ -import zarr from netCDF4 import Dataset +import zarr -from py_eddy_tracker.data import get_path +from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker from py_eddy_tracker.observations.observation import EddiesObservations from py_eddy_tracker.tracking import Correspondances -filename = get_path("Anticyclonic_20190223.nc") +filename = get_demo_path("Anticyclonic_20190223.nc") a0 = EddiesObservations.load_file(filename) a1 = a0.copy() diff --git a/versioneer.py b/versioneer.py index 2b545405..1e3753e6 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,5 @@ -# Version: 0.18 + +# Version: 0.29 """The Versioneer - like a rocketeer, but for versions. @@ -6,18 +7,14 @@ ============== * like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer +* https://github.com/python-versioneer/python-versioneer * Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based +* License: Public Domain (Unlicense) +* Compatible with: Python 3.7, 3.8, 3.9, 3.10, 3.11 and pypy3 +* [![Latest Version][pypi-image]][pypi-url] +* [![Build Status][travis-image]][travis-url] + +This is a tool for managing a recorded version number in setuptools-based python projects. 
The goal is to remove the tedious and error-prone "update the embedded version string" step from your release process. Making a new release should be as easy as recording a new tag in your version-control @@ -26,9 +23,38 @@ ## Quick Install -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results +Versioneer provides two installation modes. The "classic" vendored mode installs +a copy of versioneer into your repository. The experimental build-time dependency mode +is intended to allow you to skip this step and simplify the process of upgrading. + +### Vendored mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) + * Note that you will need to add `tomli; python_version < "3.11"` to your + build-time dependencies if you use `pyproject.toml` +* run `versioneer install --vendor` in your source tree, commit the results +* verify version information with `python setup.py version` + +### Build-time dependency mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) +* add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`) + to the `requires` key of the `build-system` table in `pyproject.toml`: + ```toml + [build-system] + requires = ["setuptools", "versioneer[toml]"] + build-backend = "setuptools.build_meta" + ``` +* run `versioneer install --no-vendor` in your source tree, commit the results +* verify version information with `python setup.py version` ## Version Identifiers @@ -60,7 +86,7 @@ for example `git describe --tags --dirty --always` reports things like "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. +uncommitted changes). The version identifier is used for multiple purposes: @@ -165,7 +191,7 @@ Some situations are known to cause problems for Versioneer. This details the most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). +[issues page](https://github.com/python-versioneer/python-versioneer/issues). ### Subprojects @@ -179,7 +205,7 @@ `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI distributions (and upload multiple independently-installable tarballs). * Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. + provide bindings to Python (and perhaps other languages) in subdirectories. Versioneer will look for `.git` in parent directories, and most operations should get the right version string. However `pip` and `setuptools` have bugs @@ -193,9 +219,9 @@ Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in some later version. 
-[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking +[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the +[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the issue from the Versioneer side in more detail. [pip PR#3176](https://github.com/pypa/pip/pull/3176) and [pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve @@ -223,31 +249,20 @@ cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into a different virtualenv), so this can be surprising. -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes +[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes this one, but upgrading to a newer version of setuptools should probably resolve it. -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - ## Updating Versioneer To upgrade your project to a new release of Versioneer, do the following: * install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace +* edit `setup.cfg` and `pyproject.toml`, if necessary, + to include any new configuration settings indicated by the release notes. + See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install --[no-]vendor` in your source tree, to replace `SRC/_version.py` * commit any changed files @@ -264,36 +279,70 @@ direction and include code from all supported VCS systems, reducing the number of intermediate scripts. +## Similar projects + +* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time + dependency +* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of + versioneer +* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools + plugin ## License To make Versioneer easier to embed, all its code is dedicated to the public domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . +Specifically, both are released under the "Unlicense", as described in +https://unlicense.org/. 
-""" +[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg +[pypi-url]: https://pypi.python.org/pypi/versioneer/ +[travis-image]: +https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg +[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer -from __future__ import print_function +""" +# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring +# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements +# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error +# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with +# pylint:disable=attribute-defined-outside-init,too-many-arguments -try: - import configparser -except ImportError: - import ConfigParser as configparser +import configparser import errno import json import os import re import subprocess import sys +from pathlib import Path +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import NoReturn +import functools + +have_tomllib = True +if sys.version_info >= (3, 11): + import tomllib +else: + try: + import tomli as tomllib + except ImportError: + have_tomllib = False class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + versionfile_source: str + versionfile_build: Optional[str] + parentdir_prefix: Optional[str] + verbose: Optional[bool] -def get_root(): + +def get_root() -> str: """Get the project root directory. We require that all commands are run from the project root, i.e. the @@ -301,20 +350,28 @@ def get_root(): """ root = os.path.realpath(os.path.abspath(os.getcwd())) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): # allow 'python path/to/setup.py COMMAND' root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ( - "Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND')." - ) + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): + err = ("Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND').") raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -323,46 +380,62 @@ def get_root(): # module-import table will cache the first one. So we can't use # os.path.dirname(__file__), as that will find whichever # versioneer.py was first imported, even in later projects. 
- me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) + my_path = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print( - "Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py) - ) + if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): + print("Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py)) except NameError: pass return root -def get_config_from_root(root): +def get_config_from_root(root: str) -> VersioneerConfig: """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or + # This might raise OSError (if setup.cfg is missing), or # configparser.NoSectionError (if it lacks a [versioneer] section), or # configparser.NoOptionError (if it lacks "VCS="). See the docstring at # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None + root_pth = Path(root) + pyproject_toml = root_pth / "pyproject.toml" + setup_cfg = root_pth / "setup.cfg" + section: Union[Dict[str, Any], configparser.SectionProxy, None] = None + if pyproject_toml.exists() and have_tomllib: + try: + with open(pyproject_toml, 'rb') as fobj: + pp = tomllib.load(fobj) + section = pp['tool']['versioneer'] + except (tomllib.TOMLDecodeError, KeyError) as e: + print(f"Failed to load config from {pyproject_toml}: {e}") + print("Try to load it from setup.cfg") + if not section: + parser = configparser.ConfigParser() + with open(setup_cfg) as cfg_file: + parser.read_file(cfg_file) + parser.get("versioneer", "VCS") # raise error if missing + + section = parser["versioneer"] + + # `cast`` really shouldn't be used, but its simplest for the + # common VersioneerConfig users at the moment. 
We verify against + # `None` values elsewhere where it matters cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): + cfg.VCS = section['VCS'] + cfg.style = section.get("style", "") + cfg.versionfile_source = cast(str, section.get("versionfile_source")) + cfg.versionfile_build = section.get("versionfile_build") + cfg.tag_prefix = cast(str, section.get("tag_prefix")) + if cfg.tag_prefix in ("''", '""', None): cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") + cfg.parentdir_prefix = section.get("parentdir_prefix") + if isinstance(section, configparser.SectionProxy): + # Make sure configparser translates to bool + cfg.verbose = section.getboolean("verbose") + else: + cfg.verbose = section.get("verbose") + return cfg @@ -371,41 +444,48 @@ class NotThisMethod(Exception): # these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f + HANDLERS.setdefault(vcs, {})[method] = f return f - return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -416,28 +496,25 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, 
p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -LONG_VERSION_PY[ - "git" -] = ''' +LONG_VERSION_PY['git'] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -446,9 +523,11 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -464,8 +543,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool + -def get_config(): +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -483,13 +569,13 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} @@ -498,22 +584,35 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + 
stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -524,18 +623,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, if verbose: print("unable to find command, tried %%s" %% (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %%s (error)" %% dispcmd) print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -544,15 +645,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print("Tried directories %%s but none started with prefix %%s" %% @@ -561,41 +661,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
- keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -608,11 +715,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %%d @@ -621,7 +728,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%%s', no digits" %% ",".join(refs - tags)) if verbose: @@ -630,6 +737,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %%s" %% r) return {"version": r, @@ -645,7 +757,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -656,8 +773,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %%s not under git control" %% root) @@ -665,24 +789,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -699,7 +856,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%%s'" %% describe_out) return pieces @@ -724,26 +881,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() + date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -768,23 +926,71 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). 
+ """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%%d" %% (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] + rendered = "0.post0.dev%%d" %% pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -811,12 +1017,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -833,7 +1068,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -853,7 +1088,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
@@ -873,7 +1108,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -887,10 +1122,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -905,7 +1144,7 @@ def render(pieces, style): "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -926,7 +1165,7 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: return {"version": "0+unknown", "full-revisionid": None, @@ -953,41 +1192,48 @@ def get_versions(): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. 
Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -1000,11 +1246,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1013,7 +1259,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1021,30 +1267,33 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -1055,7 +1304,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1063,33 +1320,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -1098,16 +1379,17 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] + git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) return pieces # tag @@ -1116,12 +1398,10 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] + pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1132,19 +1412,20 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def do_vcs_install(manifest_in, versionfile_source, ipy): +def do_vcs_install(versionfile_source: str, ipy: Optional[str]) -> None: """Git-specific installation logic for Versioneer. For Git, this means creating/changing .gitattributes to mark _version.py @@ -1153,36 +1434,40 @@ def do_vcs_install(manifest_in, versionfile_source, ipy): GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] + files = [versionfile_source] if ipy: files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) + if "VERSIONEER_PEP518" not in globals(): + try: + my_path = __file__ + if my_path.endswith((".pyc", ".pyo")): + my_path = os.path.splitext(my_path)[0] + ".py" + versioneer_file = os.path.relpath(my_path) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) present = False try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: + with open(".gitattributes", "r") as fobj: + for line in fobj: + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + break + except OSError: pass if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() + with open(".gitattributes", "a+") as fobj: + fobj.write(f"{versionfile_source} export-subst\n") files.append(".gitattributes") run_command(GITS, ["add", "--"] + files) -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from 
the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -1191,30 +1476,23 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from +# This file was generated by 'versioneer.py' (0.29) from # revision-control system data, or from the parent directory name of an # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. @@ -1231,43 +1509,41 @@ def get_versions(): """ -def versions_from_file(filename): +def versions_from_file(filename: str) -> Dict[str, Any]: """Try to determine the version from _version.py if present.""" try: with open(filename) as f: contents = f.read() - except EnvironmentError: + except OSError: raise NotThisMethod("unable to read _version.py") - mo = re.search( - r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) + mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) if not mo: - mo = re.search( - r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) + mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) -def write_to_version_file(filename, versions): +def write_to_version_file(filename: str, versions: Dict[str, Any]) -> None: """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, + indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) print("set %s to '%s'" % (filename, versions["version"])) -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -1285,29 +1561,78 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. 
+def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -1334,12 +1659,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -1356,7 +1710,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. 
@@ -1376,7 +1730,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -1396,26 +1750,28 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -1425,20 +1781,16 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} class VersioneerBadRootError(Exception): """The project root directory is unknown or missing key files.""" -def get_versions(verbose=False): +def get_versions(verbose: bool = False) -> Dict[str, Any]: """Get the project version from whatever source is available. Returns dict with two keys: 'version' and 'full'. @@ -1453,10 +1805,9 @@ def get_versions(verbose=False): assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert ( - cfg.versionfile_source is not None - ), "please set versioneer.versionfile_source" + verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` + assert cfg.versionfile_source is not None, \ + "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1510,22 +1861,22 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, "error": "unable to compute version", + "date": None} -def get_version(): +def get_version() -> str: """Get the short version string for this project.""" return get_versions()["version"] -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" +def get_cmdclass(cmdclass: Optional[Dict[str, Any]] = None): + """Get the custom setuptools subclasses used by Versioneer. + + If the package uses a different cmdclass (e.g. one from numpy), it + should be provide as an argument. 
+ """ if "versioneer" in sys.modules: del sys.modules["versioneer"] # this fixes the "python setup.py develop" case (also 'install' and @@ -1539,25 +1890,25 @@ def get_cmdclass(): # parent is protected against the child's "import versioneer". By # removing ourselves from sys.modules here, before the child build # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 + # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - cmds = {} + cmds = {} if cmdclass is None else cmdclass.copy() - # we add "version" to both distutils and setuptools - from distutils.core import Command + # we add "version" to setuptools + from setuptools import Command class cmd_version(Command): description = "report generated version string" - user_options = [] - boolean_options = [] + user_options: List[Tuple[str, str, str]] = [] + boolean_options: List[str] = [] - def initialize_options(self): + def initialize_options(self) -> None: pass - def finalize_options(self): + def finalize_options(self) -> None: pass - def run(self): + def run(self) -> None: vers = get_versions(verbose=True) print("Version: %s" % vers["version"]) print(" full-revisionid: %s" % vers.get("full-revisionid")) @@ -1565,10 +1916,9 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - # we override "build_py" in both distutils and setuptools + # we override "build_py" in setuptools # # most invocation pathways end up running build_py: # distutils/build -> build_py @@ -1583,30 +1933,68 @@ def run(self): # then does setup.py bdist_wheel, or sometimes setup.py install # setup.py egg_info -> ? + # pip install -e . and setuptool/editable_wheel will invoke build_py + # but the build_py command is not expected to copy any files. + # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py + if 'build_py' in cmds: + _build_py: Any = cmds['build_py'] else: - from distutils.command.build_py import build_py as _build_py + from setuptools.command.build_py import build_py as _build_py class cmd_build_py(_build_py): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() _build_py.run(self) + if getattr(self, "editable_mode", False): + # During editable installs `.py` and data files are + # not copied to build_lib + return # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe + if 'build_ext' in cmds: + _build_ext: Any = cmds['build_ext'] + else: + from setuptools.command.build_ext import build_ext as _build_ext + + class cmd_build_ext(_build_ext): + def run(self) -> None: + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_ext.run(self) + if self.inplace: + # build_ext --inplace will only build extensions in + # build/lib<..> dir with no _version.py to write to. 
+ # As in place builds will already have a _version.py + # in the module dir, we do not need to write one. + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if not cfg.versionfile_build: + return + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + if not os.path.exists(target_versionfile): + print(f"Warning: {target_versionfile} does not exist, skipping " + "version update. This can happen if you are running build_ext " + "without first running build_py.") + return + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_ext"] = cmd_build_ext + if "cx_Freeze" in sys.modules: # cx_freeze enabled? + from cx_Freeze.dist import build_exe as _build_exe # type: ignore # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1615,7 +2003,7 @@ def run(self): # ... class cmd_build_exe(_build_exe): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() @@ -1627,28 +2015,24 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if "py2exe" in sys.modules: # py2exe enabled? + if 'py2exe' in sys.modules: # py2exe enabled? 
         try:
-            from py2exe.distutils_buildexe import py2exe as _py2exe  # py3
+            from py2exe.setuptools_buildexe import py2exe as _py2exe  # type: ignore
         except ImportError:
-            from py2exe.build_exe import py2exe as _py2exe  # py2
+            from py2exe.distutils_buildexe import py2exe as _py2exe  # type: ignore

         class cmd_py2exe(_py2exe):
-            def run(self):
+            def run(self) -> None:
                 root = get_root()
                 cfg = get_config_from_root(root)
                 versions = get_versions()
@@ -1660,27 +2044,60 @@ def run(self):
                 os.unlink(target_versionfile)
                 with open(cfg.versionfile_source, "w") as f:
                     LONG = LONG_VERSION_PY[cfg.VCS]
-                    f.write(
-                        LONG
-                        % {
-                            "DOLLAR": "$",
-                            "STYLE": cfg.style,
-                            "TAG_PREFIX": cfg.tag_prefix,
-                            "PARENTDIR_PREFIX": cfg.parentdir_prefix,
-                            "VERSIONFILE_SOURCE": cfg.versionfile_source,
-                        }
-                    )
-
+                    f.write(LONG %
+                            {"DOLLAR": "$",
+                             "STYLE": cfg.style,
+                             "TAG_PREFIX": cfg.tag_prefix,
+                             "PARENTDIR_PREFIX": cfg.parentdir_prefix,
+                             "VERSIONFILE_SOURCE": cfg.versionfile_source,
+                             })
         cmds["py2exe"] = cmd_py2exe

+    # sdist farms its file list building out to egg_info
+    if 'egg_info' in cmds:
+        _egg_info: Any = cmds['egg_info']
+    else:
+        from setuptools.command.egg_info import egg_info as _egg_info
+
+    class cmd_egg_info(_egg_info):
+        def find_sources(self) -> None:
+            # egg_info.find_sources builds the manifest list and writes it
+            # in one shot
+            super().find_sources()
+
+            # Modify the filelist and normalize it
+            root = get_root()
+            cfg = get_config_from_root(root)
+            self.filelist.append('versioneer.py')
+            if cfg.versionfile_source:
+                # There are rare cases where versionfile_source might not be
+                # included by default, so we must be explicit
+                self.filelist.append(cfg.versionfile_source)
+            self.filelist.sort()
+            self.filelist.remove_duplicates()
+
+            # The write method is hidden in the manifest_maker instance that
+            # generated the filelist and was thrown away
+            # We will instead replicate their final normalization (to unicode,
+            # and POSIX-style paths)
+            from setuptools import unicode_utils
+            normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/')
+                          for f in self.filelist.files]
+
+            manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt')
+            with open(manifest_filename, 'w') as fobj:
+                fobj.write('\n'.join(normalized))
+
+    cmds['egg_info'] = cmd_egg_info
+
     # we override different "sdist" commands for both environments
-    if "setuptools" in sys.modules:
-        from setuptools.command.sdist import sdist as _sdist
+    if 'sdist' in cmds:
+        _sdist: Any = cmds['sdist']
     else:
-        from distutils.command.sdist import sdist as _sdist
+        from setuptools.command.sdist import sdist as _sdist

     class cmd_sdist(_sdist):
-        def run(self):
+        def run(self) -> None:
             versions = get_versions()
             self._versioneer_generated_versions = versions
             # unless we update this, the command will keep using the old
@@ -1688,7 +2105,7 @@ def run(self):
             self.distribution.metadata.version = versions["version"]
             return _sdist.run(self)

-        def make_release_tree(self, base_dir, files):
+        def make_release_tree(self, base_dir: str, files: List[str]) -> None:
             root = get_root()
             cfg = get_config_from_root(root)
             _sdist.make_release_tree(self, base_dir, files)
@@ -1697,10 +2114,8 @@ def make_release_tree(self, base_dir, files):
             # updated value
             target_versionfile = os.path.join(base_dir, cfg.versionfile_source)
             print("UPDATING %s" % target_versionfile)
-            write_to_version_file(
-                target_versionfile, self._versioneer_generated_versions
-            )
-
+            write_to_version_file(target_versionfile,
+                                  self._versioneer_generated_versions)
     cmds["sdist"] = cmd_sdist

     return cmds
@@ -1743,25 +2158,28 @@ def make_release_tree(self, base_dir, files):
 """

-INIT_PY_SNIPPET = """
+OLD_SNIPPET = """
 from ._version import get_versions
 __version__ = get_versions()['version']
 del get_versions
 """

+INIT_PY_SNIPPET = """
+from . import {0}
+__version__ = {0}.get_versions()['version']
+"""
+

-def do_setup():
-    """Main VCS-independent setup function for installing Versioneer."""
+def do_setup() -> int:
+    """Do main VCS-independent setup function for installing Versioneer."""
     root = get_root()
     try:
         cfg = get_config_from_root(root)
-    except (
-        EnvironmentError,
-        configparser.NoSectionError,
-        configparser.NoOptionError,
-    ) as e:
-        if isinstance(e, (EnvironmentError, configparser.NoSectionError)):
-            print("Adding sample versioneer config to setup.cfg", file=sys.stderr)
+    except (OSError, configparser.NoSectionError,
+            configparser.NoOptionError) as e:
+        if isinstance(e, (OSError, configparser.NoSectionError)):
+            print("Adding sample versioneer config to setup.cfg",
+                  file=sys.stderr)
         with open(os.path.join(root, "setup.cfg"), "a") as f:
             f.write(SAMPLE_CONFIG)
         print(CONFIG_ERROR, file=sys.stderr)
@@ -1770,76 +2188,46 @@ def do_setup():
         print(" creating %s" % cfg.versionfile_source)
         with open(cfg.versionfile_source, "w") as f:
             LONG = LONG_VERSION_PY[cfg.VCS]
-            f.write(
-                LONG
-                % {
-                    "DOLLAR": "$",
-                    "STYLE": cfg.style,
-                    "TAG_PREFIX": cfg.tag_prefix,
-                    "PARENTDIR_PREFIX": cfg.parentdir_prefix,
-                    "VERSIONFILE_SOURCE": cfg.versionfile_source,
-                }
-            )
-
-    ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py")
+            f.write(LONG % {"DOLLAR": "$",
+                            "STYLE": cfg.style,
+                            "TAG_PREFIX": cfg.tag_prefix,
+                            "PARENTDIR_PREFIX": cfg.parentdir_prefix,
+                            "VERSIONFILE_SOURCE": cfg.versionfile_source,
+                            })
+
+    ipy = os.path.join(os.path.dirname(cfg.versionfile_source),
+                       "__init__.py")
+    maybe_ipy: Optional[str] = ipy
     if os.path.exists(ipy):
         try:
             with open(ipy, "r") as f:
                 old = f.read()
-        except EnvironmentError:
+        except OSError:
             old = ""
-        if INIT_PY_SNIPPET not in old:
+        module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0]
+        snippet = INIT_PY_SNIPPET.format(module)
+        if OLD_SNIPPET in old:
+            print(" replacing boilerplate in %s" % ipy)
+            with open(ipy, "w") as f:
+                f.write(old.replace(OLD_SNIPPET, snippet))
+        elif snippet not in old:
             print(" appending to %s" % ipy)
             with open(ipy, "a") as f:
-                f.write(INIT_PY_SNIPPET)
+                f.write(snippet)
         else:
             print(" %s unmodified" % ipy)
     else:
         print(" %s doesn't exist, ok" % ipy)
-        ipy = None
-
-    # Make sure both the top-level "versioneer.py" and versionfile_source
-    # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so
-    # they'll be copied into source distributions. Pip won't be able to
-    # install the package without this.
-    manifest_in = os.path.join(root, "MANIFEST.in")
-    simple_includes = set()
-    try:
-        with open(manifest_in, "r") as f:
-            for line in f:
-                if line.startswith("include "):
-                    for include in line.split()[1:]:
-                        simple_includes.add(include)
-    except EnvironmentError:
-        pass
-    # That doesn't cover everything MANIFEST.in can do
-    # (http://docs.python.org/2/distutils/sourcedist.html#commands), so
-    # it might give some false negatives. Appending redundant 'include'
-    # lines is safe, though.
- if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print( - " appending versionfile_source ('%s') to MANIFEST.in" - % cfg.versionfile_source - ) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") + maybe_ipy = None # Make VCS-specific changes. For git, this means creating/changing # .gitattributes to mark _version.py for export-subst keyword # substitution. - do_vcs_install(manifest_in, cfg.versionfile_source, ipy) + do_vcs_install(cfg.versionfile_source, maybe_ipy) return 0 -def scan_setup_py(): +def scan_setup_py() -> int: """Validate the contents of setup.py against Versioneer's expectations.""" found = set() setters = False @@ -1876,10 +2264,14 @@ def scan_setup_py(): return errors +def setup_command() -> NoReturn: + """Set up Versioneer and exit with appropriate error code.""" + errors = do_setup() + errors += scan_setup_py() + sys.exit(1 if errors else 0) + + if __name__ == "__main__": cmd = sys.argv[1] if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1) + setup_command()