diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 00000000..d9437d16 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,70 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ master ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ master ] + schedule: + - cron: '41 16 * * 4' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'python' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] + # Learn more about CodeQL language support at https://git.io/codeql-language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹ️ Command-line programs to run using the OS shell. 
+ # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 7c800b33..f2f4753e 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -1,29 +1,33 @@ # This workflow will install Python dependencies, run tests and lint with a single version of Python # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: Python application +name: Pytest & Flake8 -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] +on: [push, pull_request] jobs: build: - - runs-on: ubuntu-latest + strategy: + matrix: + # os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, windows-latest] + python_version: ['3.10', '3.11', '3.12'] + name: Run py eddy tracker build tests + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -l {0} steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python ${{ matrix.python_version }} uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: ${{ matrix.python_version }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest + pip install flake8 pytest pytest-cov if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Install package run: | @@ -34,6 +38,3 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - pytest diff --git a/.readthedocs.yml b/.readthedocs.yml index 1299f38e..ba36f8ea 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,7 +1,13 @@ version: 2 conda: environment: doc/environment.yml +build: + os: ubuntu-lts-latest + tools: + python: "mambaforge-latest" python: - install: - - method: setuptools - path: . + install: + - method: pip + path: . +sphinx: + configuration: doc/conf.py \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b974e3bd..6d6d6a30 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,8 +7,89 @@ The format is based on `Keep a Changelog `_ and this project adheres to `Semantic Versioning `_. [Unreleased] ------------- +------------- +Changed +^^^^^^^ + +Fixed +^^^^^ + +Added +^^^^^ + +[3.6.2] - 2025-06-06 +-------------------- +Changed +^^^^^^^ + +- Remove dead end method for network will move dead end to the trash and not remove observations + +Fixed +^^^^^ + +- Fix matplotlib version + +[3.6.1] - 2022-10-14 +-------------------- +Changed +^^^^^^^ + +- Rewrite particle candidate to be easily parallelize + +Fixed +^^^^^ + +- Check strictly increasing coordinates for RegularGridDataset. 
+- Grid mask is check to replace mask monovalue by 2D mask with fixed value +Added +^^^^^ + +- Add method to colorize contour with a field +- Add option to force align on to return all step for reference dataset +- Add method and property to network to easily select segment and network +- Add method to found same track/segment/network in dataset + +[3.6.0] - 2022-01-12 +-------------------- +Changed +^^^^^^^ + +- Now time allows second precision (instead of daily precision) in storage on uint32 from 01/01/1950 to 01/01/2086 + New identifications are produced with this type, old files could still be loaded. + If you use old identifications for tracking use the `--unraw` option to unpack old times and store data with the new format. +- Now amplitude is stored with .1 mm of precision (instead of 1 mm), same advice as for time. +- Expose more parameters to users for bash tools build_network & divide_network +- Add warning when loading a file created from a previous version of py-eddy-tracker. + + + +Fixed +^^^^^ + +- Fix bug in convolution(filter), lowest rows was replace by zeros in convolution computation. + Important impact for tiny kernel +- Fix method of sampling before contour fitting +- Fix bug when loading dataset in zarr format, not all variables were correctly loaded +- Fix bug when zarr dataset has same size for number of observations and contour size +- Fix bug when tracking, previous_virtual_obs was not always loaded + +Added +^^^^^ + +- Allow to replace mask by isnan method to manage nan data instead of masked data +- Add drifter colocation example + +[3.5.0] - 2021-06-22 +-------------------- + +Fixed +^^^^^ +- GridCollection get_next_time_step & get_previous_time_step needed more files to work in the dataset list. + The loop needed explicitly self.dataset[i+-1] even when i==0, therefore indice went out of range + +[3.4.0] - 2021-03-29 +-------------------- Changed ^^^^^^^ - `TrackEddiesObservations.filled_by_interpolation` method stop to normalize longitude, to continue to have same diff --git a/README.md b/README.md index 95536d1d..0cc34894 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,38 @@ [![PyPI version](https://badge.fury.io/py/pyEddyTracker.svg)](https://badge.fury.io/py/pyEddyTracker) +[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.6333988.svg)](https://doi.org/10.5281/zenodo.6333988) [![Documentation Status](https://readthedocs.org/projects/py-eddy-tracker/badge/?version=stable)](https://py-eddy-tracker.readthedocs.io/en/stable/?badge=stable) [![Gitter](https://badges.gitter.im/py-eddy-tracker/community.svg)](https://gitter.im/py-eddy-tracker/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/AntSimi/py-eddy-tracker/master?urlpath=lab/tree/notebooks/python_module/) +[![pytest](https://github.com/AntSimi/py-eddy-tracker/actions/workflows/python-app.yml/badge.svg)](https://github.com/AntSimi/py-eddy-tracker/actions/workflows/python-app.yml) # README # +### How to cite code? ### + +Zenodo provide DOI for each tagged version, [all DOI are available here](https://doi.org/10.5281/zenodo.6333988) + ### Method ### Method was described in : +[Pegliasco, C., Delepoulle, A., Morrow, R., Faugère, Y., and Dibarboure, G.: META3.1exp : A new Global Mesoscale Eddy Trajectories Atlas derived from altimetry, Earth Syst. Sci. Data Discuss.](https://doi.org/10.5194/essd-14-1087-2022) + [Mason, E., A. Pascual, and J. C. 
McWilliams, 2014: A new sea surface height–based code for oceanic mesoscale eddy tracking.](https://doi.org/10.1175/JTECH-D-14-00019.1) ### Use case ### Method is used in : - + [Mason, E., A. Pascual, P. Gaube, S.Ruiz, J. Pelegrí, A. Delepoulle, 2017: Subregional characterization of mesoscale eddies across the Brazil-Malvinas Confluence](https://doi.org/10.1002/2016JC012611) ### How do I get set up? ### #### Short story #### + ```bash pip install pyeddytracker ``` + #### Long story #### To avoid problems with installation, use of the virtualenv Python virtual environment is recommended. @@ -33,12 +43,20 @@ Then use pip to install all dependencies (numpy, scipy, matplotlib, netCDF4, ... pip install numpy scipy netCDF4 matplotlib opencv-python pyyaml pint polygon3 ``` -Then run the following to install the eddy tracker: +Clone : + +```bash +git clone https://github.com/AntSimi/py-eddy-tracker +``` + +Then run the following to install the eddy tracker : ```bash python setup.py install ``` + ### Tools gallery ### + Several examples based on py eddy tracker module are [here](https://py-eddy-tracker.readthedocs.io/en/latest/python_module/index.html). [![](https://py-eddy-tracker.readthedocs.io/en/latest/_static/logo.png)](https://py-eddy-tracker.readthedocs.io/en/latest/python_module/index.html) diff --git a/check.sh b/check.sh index ddafab69..a402bf52 100644 --- a/check.sh +++ b/check.sh @@ -1,7 +1,5 @@ -isort src tests examples -black src tests examples -blackdoc src tests examples -flake8 tests examples src --count --select=E9,F63,F7,F82 --show-source --statistics -# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -flake8 tests examples src --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -pytest -vv --cov py_eddy_tracker --cov-report html +isort . +black . +blackdoc . +flake8 . +python -m pytest -vv --cov py_eddy_tracker --cov-report html diff --git a/doc/conf.py b/doc/conf.py index ccf26e4e..0844d585 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -96,9 +96,9 @@ master_doc = "index" # General information about the project. -project = u"py-eddy-tracker" -copyright = u"2019, A. Delepoulle & E. Mason" -author = u"A. Delepoulle & E. Mason" +project = "py-eddy-tracker" +copyright = "2019, A. Delepoulle & E. Mason" +author = "A. Delepoulle & E. Mason" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -272,8 +272,8 @@ ( master_doc, "py-eddy-tracker.tex", - u"py-eddy-tracker Documentation", - u"A. Delepoulle \\& E. Mason", + "py-eddy-tracker Documentation", + "A. Delepoulle \\& E. Mason", "manual", ), ] @@ -304,7 +304,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, "py-eddy-tracker", u"py-eddy-tracker Documentation", [author], 1) + (master_doc, "py-eddy-tracker", "py-eddy-tracker Documentation", [author], 1) ] # If true, show URL addresses after external links. 
@@ -320,7 +320,7 @@ ( master_doc, "py-eddy-tracker", - u"py-eddy-tracker Documentation", + "py-eddy-tracker Documentation", author, "py-eddy-tracker", "One line description of project.", diff --git a/doc/environment.yml b/doc/environment.yml index db50b528..063a60de 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -1,10 +1,12 @@ channels: - conda-forge - - defaults dependencies: - - python=3.7 + - python>=3.10, <3.13 - ffmpeg + - pip - pip: + - -r ../requirements.txt + - git+https://github.com/AntSimi/py-eddy-tracker-sample-id.git - sphinx-gallery - sphinx_rtd_theme - sphinx>=3.1 diff --git a/doc/grid_identification.rst b/doc/grid_identification.rst index c645f80c..2cc3fb52 100644 --- a/doc/grid_identification.rst +++ b/doc/grid_identification.rst @@ -47,38 +47,42 @@ Activate verbose .. code-block:: python from py_eddy_tracker import start_logger - start_logger().setLevel('DEBUG') # Available options: ERROR, WARNING, INFO, DEBUG + + start_logger().setLevel("DEBUG") # Available options: ERROR, WARNING, INFO, DEBUG Run identification .. code-block:: python from datetime import datetime + h = RegularGridDataset(grid_name, lon_name, lat_name) - h.bessel_high_filter('adt', 500, order=3) + h.bessel_high_filter("adt", 500, order=3) date = datetime(2019, 2, 23) a, c = h.eddy_identification( - 'adt', 'ugos', 'vgos', # Variables used for identification - date, # Date of identification - 0.002, # step between two isolines of detection (m) - pixel_limit=(5, 2000), # Min and max pixel count for valid contour - shape_error=55, # Error max (%) between ratio of circle fit and contour - ) + "adt", + "ugos", + "vgos", # Variables used for identification + date, # Date of identification + 0.002, # step between two isolines of detection (m) + pixel_limit=(5, 2000), # Min and max pixel count for valid contour + shape_error=55, # Error max (%) between ratio of circle fit and contour + ) Plot the resulting identification .. code-block:: python - fig = plt.figure(figsize=(15,7)) - ax = fig.add_axes([.03,.03,.94,.94]) - ax.set_title('Eddies detected -- Cyclonic(red) and Anticyclonic(blue)') - ax.set_ylim(-75,75) - ax.set_xlim(0,360) - ax.set_aspect('equal') - a.display(ax, color='b', linewidth=.5) - c.display(ax, color='r', linewidth=.5) + fig = plt.figure(figsize=(15, 7)) + ax = fig.add_axes([0.03, 0.03, 0.94, 0.94]) + ax.set_title("Eddies detected -- Cyclonic(red) and Anticyclonic(blue)") + ax.set_ylim(-75, 75) + ax.set_xlim(0, 360) + ax.set_aspect("equal") + a.display(ax, color="b", linewidth=0.5) + c.display(ax, color="r", linewidth=0.5) ax.grid() - fig.savefig('share/png/eddies.png') + fig.savefig("share/png/eddies.png") .. image:: ../share/png/eddies.png @@ -87,7 +91,8 @@ Save identification data .. code-block:: python from netCDF import Dataset - with Dataset(date.strftime('share/Anticyclonic_%Y%m%d.nc'), 'w') as h: + + with Dataset(date.strftime("share/Anticyclonic_%Y%m%d.nc"), "w") as h: a.to_netcdf(h) - with Dataset(date.strftime('share/Cyclonic_%Y%m%d.nc'), 'w') as h: + with Dataset(date.strftime("share/Cyclonic_%Y%m%d.nc"), "w") as h: c.to_netcdf(h) diff --git a/doc/grid_load_display.rst b/doc/grid_load_display.rst index 2e570274..2f0e3765 100644 --- a/doc/grid_load_display.rst +++ b/doc/grid_load_display.rst @@ -7,7 +7,12 @@ Loading grid .. 
code-block:: python from py_eddy_tracker.dataset.grid import RegularGridDataset - grid_name, lon_name, lat_name = 'share/nrt_global_allsat_phy_l4_20190223_20190226.nc', 'longitude', 'latitude' + + grid_name, lon_name, lat_name = ( + "share/nrt_global_allsat_phy_l4_20190223_20190226.nc", + "longitude", + "latitude", + ) h = RegularGridDataset(grid_name, lon_name, lat_name) Plotting grid @@ -15,14 +20,15 @@ Plotting grid .. code-block:: python from matplotlib import pyplot as plt + fig = plt.figure(figsize=(14, 12)) - ax = fig.add_axes([.02, .51, .9, .45]) - ax.set_title('ADT (m)') + ax = fig.add_axes([0.02, 0.51, 0.9, 0.45]) + ax.set_title("ADT (m)") ax.set_ylim(-75, 75) - ax.set_aspect('equal') - m = h.display(ax, name='adt', vmin=-1, vmax=1) + ax.set_aspect("equal") + m = h.display(ax, name="adt", vmin=-1, vmax=1) ax.grid(True) - plt.colorbar(m, cax=fig.add_axes([.94, .51, .01, .45])) + plt.colorbar(m, cax=fig.add_axes([0.94, 0.51, 0.01, 0.45])) Filtering @@ -30,27 +36,27 @@ Filtering .. code-block:: python h = RegularGridDataset(grid_name, lon_name, lat_name) - h.bessel_high_filter('adt', 500, order=3) + h.bessel_high_filter("adt", 500, order=3) Save grid .. code-block:: python - h.write('/tmp/grid.nc') + h.write("/tmp/grid.nc") Add second plot .. code-block:: python - ax = fig.add_axes([.02, .02, .9, .45]) - ax.set_title('ADT Filtered (m)') - ax.set_aspect('equal') + ax = fig.add_axes([0.02, 0.02, 0.9, 0.45]) + ax.set_title("ADT Filtered (m)") + ax.set_aspect("equal") ax.set_ylim(-75, 75) - m = h.display(ax, name='adt', vmin=-.1, vmax=.1) + m = h.display(ax, name="adt", vmin=-0.1, vmax=0.1) ax.grid(True) - plt.colorbar(m, cax=fig.add_axes([.94, .02, .01, .45])) - fig.savefig('share/png/filter.png') + plt.colorbar(m, cax=fig.add_axes([0.94, 0.02, 0.01, 0.45])) + fig.savefig("share/png/filter.png") .. image:: ../share/png/filter.png \ No newline at end of file diff --git a/doc/run_tracking.rst b/doc/run_tracking.rst index 0a03848d..36290339 100644 --- a/doc/run_tracking.rst +++ b/doc/run_tracking.rst @@ -5,9 +5,9 @@ Tracking Requirements ************ -Before to run tracking, you will need to run identification on every time step of the period (period of your study). +Before tracking, you will need to run identification on every time step of the period (period of your study). -**Advice** : Before to run tracking, displaying some identification file allows to learn a lot +**Advice** : Before tracking, displaying some identification files. You will learn a lot Default method ************** @@ -24,9 +24,9 @@ Example of conf.yaml FILES_PATTERN: MY_IDENTIFICATION_PATH/Anticyclonic*.nc SAVE_DIR: MY_OUTPUT_PATH - # Number of timestep for missing detection + # Number of consecutive timesteps with missing detection allowed VIRTUAL_LENGTH_MAX: 3 - # Minimal time to consider as a full track + # Minimal number of timesteps to considered as a long trajectory TRACK_DURATION_MIN: 10 To run: @@ -63,13 +63,13 @@ With yaml you could also select another tracker: .. 
code-block:: yaml PATHS: - # Files produces with EddyIdentification + # Files produced with EddyIdentification FILES_PATTERN: MY/IDENTIFICATION_PATH/Anticyclonic*.nc SAVE_DIR: MY_OUTPUT_PATH - # Number of timesteps for missing detection + # Number of consecutive timesteps with missing detection allowed VIRTUAL_LENGTH_MAX: 3 - # Minimal time to consider as a full track + # Minimal number of timesteps to considered as a long trajectory TRACK_DURATION_MIN: 10 CLASS: diff --git a/doc/spectrum.rst b/doc/spectrum.rst index d751b909..f96e30a0 100644 --- a/doc/spectrum.rst +++ b/doc/spectrum.rst @@ -11,7 +11,7 @@ Load data raw = RegularGridDataset(grid_name, lon_name, lat_name) filtered = RegularGridDataset(grid_name, lon_name, lat_name) - filtered.bessel_low_filter('adt', 150, order=3) + filtered.bessel_low_filter("adt", 150, order=3) areas = dict( sud_pacific=dict(llcrnrlon=188, urcrnrlon=280, llcrnrlat=-64, urcrnrlat=-7), @@ -23,24 +23,33 @@ Compute and display spectrum .. code-block:: python - fig = plt.figure(figsize=(10,6)) + fig = plt.figure(figsize=(10, 6)) ax = fig.add_subplot(111) - ax.set_title('Spectrum') - ax.set_xlabel('km') + ax.set_title("Spectrum") + ax.set_xlabel("km") for name_area, area in areas.items(): - - lon_spec, lat_spec = raw.spectrum_lonlat('adt', area=area) - mappable = ax.loglog(*lat_spec, label='lat %s raw' % name_area)[0] - ax.loglog(*lon_spec, label='lon %s raw' % name_area, color=mappable.get_color(), linestyle='--') - - lon_spec, lat_spec = filtered.spectrum_lonlat('adt', area=area) - mappable = ax.loglog(*lat_spec, label='lat %s high' % name_area)[0] - ax.loglog(*lon_spec, label='lon %s high' % name_area, color=mappable.get_color(), linestyle='--') - - ax.set_xscale('log') + lon_spec, lat_spec = raw.spectrum_lonlat("adt", area=area) + mappable = ax.loglog(*lat_spec, label="lat %s raw" % name_area)[0] + ax.loglog( + *lon_spec, + label="lon %s raw" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + lon_spec, lat_spec = filtered.spectrum_lonlat("adt", area=area) + mappable = ax.loglog(*lat_spec, label="lat %s high" % name_area)[0] + ax.loglog( + *lon_spec, + label="lon %s high" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + ax.set_xscale("log") ax.legend() ax.grid() - fig.savefig('share/png/spectrum.png') + fig.savefig("share/png/spectrum.png") .. image:: ../share/png/spectrum.png @@ -49,18 +58,23 @@ Compute and display spectrum ratio .. code-block:: python - fig = plt.figure(figsize=(10,6)) + fig = plt.figure(figsize=(10, 6)) ax = fig.add_subplot(111) - ax.set_title('Spectrum ratio') - ax.set_xlabel('km') + ax.set_title("Spectrum ratio") + ax.set_xlabel("km") for name_area, area in areas.items(): - lon_spec, lat_spec = filtered.spectrum_lonlat('adt', area=area, ref=raw) - mappable = ax.plot(*lat_spec, label='lat %s high' % name_area)[0] - ax.plot(*lon_spec, label='lon %s high' % name_area, color=mappable.get_color(), linestyle='--') - - ax.set_xscale('log') + lon_spec, lat_spec = filtered.spectrum_lonlat("adt", area=area, ref=raw) + mappable = ax.plot(*lat_spec, label="lat %s high" % name_area)[0] + ax.plot( + *lon_spec, + label="lon %s high" % name_area, + color=mappable.get_color(), + linestyle="--" + ) + + ax.set_xscale("log") ax.legend() ax.grid() - fig.savefig('share/png/spectrum_ratio.png') + fig.savefig("share/png/spectrum_ratio.png") .. 
image:: ../share/png/spectrum_ratio.png diff --git a/environment.yml b/environment.yml new file mode 100644 index 00000000..e94c7bc1 --- /dev/null +++ b/environment.yml @@ -0,0 +1,11 @@ +name: binder-pyeddytracker +channels: + - conda-forge +dependencies: + - python>=3.10, <3.13 + - pip + - ffmpeg + - pip: + - -r requirements.txt + - pyeddytrackersample + - . diff --git a/examples/01_general_things/pet_storage.py b/examples/01_general_things/pet_storage.py index 28f0f76e..918ebbee 100644 --- a/examples/01_general_things/pet_storage.py +++ b/examples/01_general_things/pet_storage.py @@ -15,9 +15,11 @@ manage eddies associated in networks, the ```track``` and ```segment``` fields allow to separate observations """ +from matplotlib import pyplot as plt +from numpy import arange, outer import py_eddy_tracker_sample -from py_eddy_tracker.data import get_demo_path, get_remote_demo_sample +from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.observations.network import NetworkObservations from py_eddy_tracker.observations.observation import EddiesObservations, Table from py_eddy_tracker.observations.tracking import TrackEddiesObservations @@ -32,7 +34,7 @@ # array field like contour/profile are 2D column. # %% -# Eddies files (zarr or netcdf) could be loaded with ```load_file``` method: +# Eddies files (zarr or netcdf) can be loaded with ```load_file``` method: eddies_collections = EddiesObservations.load_file(get_demo_path("Cyclonic_20160515.nc")) eddies_collections.field_table() # offset and scale_factor are used only when data is stored in zarr or netCDF4 @@ -58,6 +60,53 @@ # --------------- # All contours are stored on the same number of points, and are resampled if needed with an algorithm to be stored as objects +# %% +# Speed profile storage +# --------------------- +# Speed profile is an interpolation of speed mean along each contour. +# For each contour included in eddy, we compute mean of speed along the contour, +# and after we interpolate speed mean array on a fixed size array. +# +# Several field are available to understand "uavg_profile" : +# 0. - num_contours : Number of contour in eddies, must be equal to amplitude divide by isoline step +# 1. - height_inner_contour : height of inner contour used +# 2. - height_max_speed_contour : height of max speed contour used +# 3. - height_external_contour : height of outter contour used +# +# Last value of "uavg_profile" is for inner contour and first value for outter contour. 
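+# Illustrative aside (not from the original example), assuming the profile is
+# sampled linearly in height, consistent with the normalisation shown further
+# below: rebuild a height axis for one observation, running from the external
+# contour (first value of "uavg_profile") to the inner contour (last value).
+from numpy import linspace
+
+h_axis = linspace(
+    eddies_collections.height_external_contour[0],
+    eddies_collections.height_inner_contour[0],
+    eddies_collections.uavg_profile.shape[1],
+)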
+ +# Observations selection of "uavg_profile" with high number of contour(Eddy with high amplitude) +e = eddies_collections.extract_with_mask(eddies_collections.num_contours > 15) + +# %% + +# Raw display of profiles with more than 15 contours +ax = plt.subplot(111) +_ = ax.plot(e.uavg_profile.T, lw=0.5) + +# %% + +# Profile from inner to outter +ax = plt.subplot(111) +ax.plot(e.uavg_profile[:, ::-1].T, lw=0.5) +_ = ax.set_xlabel("From inner to outter contour"), ax.set_ylabel("Speed (m/s)") + +# %% + +# If we normalize indice of contour to set speed contour to 1 and inner contour to 0 +ax = plt.subplot(111) +h_in = e.height_inner_contour +h_s = e.height_max_speed_contour +h_e = e.height_external_contour +r = (h_e - h_in) / (h_s - h_in) +nb_pt = e.uavg_profile.shape[1] +# Create an x array for each profile +x = outer(arange(nb_pt) / nb_pt, r) + +ax.plot(x, e.uavg_profile[:, ::-1].T, lw=0.5) +_ = ax.set_xlabel("From inner to outter contour"), ax.set_ylabel("Speed (m/s)") + + # %% # Trajectories # ------------ @@ -86,11 +135,7 @@ # - next_obs : Index of the next observation in the full dataset, if -1 there are no next observation (the segment ends) # - previous_cost : Result of the cost function (1 is a good association, 0 is bad) with previous observation # - next_cost : Result of the cost function (1 is a good association, 0 is bad) with next observation -eddies_network = NetworkObservations.load_file( - get_remote_demo_sample( - "eddies_med_adt_allsat_dt2018_err70_filt500_order1/Anticyclonic_network.nc" - ) -) +eddies_network = NetworkObservations.load_file(get_demo_path("network_med.nc")) eddies_network.field_table() # %% diff --git a/examples/02_eddy_identification/pet_eddy_detection_ACC.py b/examples/02_eddy_identification/pet_eddy_detection_ACC.py index c799a45e..d12c62f3 100644 --- a/examples/02_eddy_identification/pet_eddy_detection_ACC.py +++ b/examples/02_eddy_identification/pet_eddy_detection_ACC.py @@ -7,10 +7,10 @@ Two detections are provided : with a filtered ADT and without filtering """ + from datetime import datetime -from matplotlib import pyplot as plt -from matplotlib import style +from matplotlib import pyplot as plt, style from py_eddy_tracker import data from py_eddy_tracker.dataset.grid import RegularGridDataset @@ -81,7 +81,7 @@ def set_fancy_labels(fig, ticklabelsize=14, labelsize=14, labelweight="semibold" # Identification # ^^^^^^^^^^^^^^ # Run the identification step with slices of 2 mm -date = datetime(2016, 5, 15) +date = datetime(2019, 2, 23) kw_ident = dict( date=date, step=0.002, shape_error=70, sampling=30, uname="u", vname="v" ) diff --git a/examples/02_eddy_identification/pet_interp_grid_on_dataset.py b/examples/02_eddy_identification/pet_interp_grid_on_dataset.py index f9e5d4c3..fa27a3d1 100644 --- a/examples/02_eddy_identification/pet_interp_grid_on_dataset.py +++ b/examples/02_eddy_identification/pet_interp_grid_on_dataset.py @@ -43,7 +43,7 @@ def update_axes(ax, mappable=None): # %% # Compute and store eke in cm²/s² aviso_map.add_grid( - "eke", (aviso_map.grid("u") ** 2 + aviso_map.grid("v") ** 2) * 0.5 * (100 ** 2) + "eke", (aviso_map.grid("u") ** 2 + aviso_map.grid("v") ** 2) * 0.5 * (100**2) ) eke_kwargs = dict(vmin=1, vmax=1000, cmap="magma_r") diff --git a/examples/02_eddy_identification/pet_statistics_on_identification.py b/examples/02_eddy_identification/pet_statistics_on_identification.py new file mode 100644 index 00000000..dbd73c61 --- /dev/null +++ b/examples/02_eddy_identification/pet_statistics_on_identification.py @@ -0,0 +1,105 @@ +""" 
+Stastics on identification files +================================ + +Some statistics on raw identification without any tracking +""" +from matplotlib import pyplot as plt +from matplotlib.dates import date2num +import numpy as np + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.observations.observation import EddiesObservations + +start_logger().setLevel("ERROR") + + +# %% +def start_axes(title): + fig = plt.figure(figsize=(13, 5)) + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94]) + ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) + ax.set_aspect("equal") + ax.set_title(title) + return ax + + +def update_axes(ax, mappable=None): + ax.grid() + if mappable: + plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9])) + + +# %% +# We load demo sample and take only first year. +# +# Replace by a list of filename to apply on your own dataset. +file_objects = get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" +)[:365] + +# %% +# Merge all identification datasets in one object +all_a = EddiesObservations.concatenate( + [EddiesObservations.load_file(i) for i in file_objects] +) + +# %% +# We define polygon bound +x0, x1, y0, y1 = 15, 20, 33, 38 +xs = np.array([[x0, x1, x1, x0, x0]], dtype="f8") +ys = np.array([[y0, y0, y1, y1, y0]], dtype="f8") +# Polygon object created for the match function use. +polygon = dict(contour_lon_e=xs, contour_lat_e=ys, contour_lon_s=xs, contour_lat_s=ys) + +# %% +# Geographic frequency of eddies +step = 0.125 +ax = start_axes("") +# Count pixel encompassed in each contour +g_a = all_a.grid_count(bins=((-10, 37, step), (30, 46, step)), intern=True) +m = g_a.display( + ax, cmap="terrain_r", vmin=0, vmax=0.75, factor=1 / all_a.nb_days, name="count" +) +ax.plot(polygon["contour_lon_e"][0], polygon["contour_lat_e"][0], "r") +update_axes(ax, m) + +# %% +# We use the match function to count the number of eddies that intersect the polygon defined previously +# `p1_area` option allow to get in c_e/c_s output, precentage of area occupy by eddies in the polygon. +i_e, j_e, c_e = all_a.match(polygon, p1_area=True, intern=False) +i_s, j_s, c_s = all_a.match(polygon, p1_area=True, intern=True) + +# %% +dt = np.datetime64("1970-01-01") - np.datetime64("1950-01-01") +kw_hist = dict( + bins=date2num(np.arange(21900, 22300).astype("datetime64") - dt), histtype="step" +) +# translate julian day in datetime64 +t = all_a.time.astype("datetime64") - dt +# %% +# Number of eddies within a polygon +ax = plt.figure(figsize=(12, 6)).add_subplot(111) +ax.set_title("Different ways to count eddies within a polygon") +ax.set_ylabel("Count") +m = all_a.mask_from_polygons(((xs, ys),)) +ax.hist(t[m], label="Eddy Center in polygon", **kw_hist) +ax.hist(t[i_s[c_s > 0]], label="Intersection Speed contour and polygon", **kw_hist) +ax.hist(t[i_e[c_e > 0]], label="Intersection Effective contour and polygon", **kw_hist) +ax.legend() +ax.set_xlim(np.datetime64("2010"), np.datetime64("2011")) +ax.grid() + +# %% +# Percent of the area of interest occupied by eddies. 
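+# Illustrative aside (not from the original example): with ``p1_area=True`` the
+# ``c_e``/``c_s`` values returned by ``match`` are fractions of the polygon
+# area, hence the ``* 100`` weighting in the histograms below. A quick numeric
+# summary using only names already defined above:
+print("Mean polygon fraction covered (speed contours): %.3f" % c_s[c_s > 0].mean())
+print("Mean polygon fraction covered (effective contours): %.3f" % c_e[c_e > 0].mean())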
+ax = plt.figure(figsize=(12, 6)).add_subplot(111) +ax.set_title("Percent of polygon occupied by an anticyclonic eddy") +ax.set_ylabel("Percent of polygon") +ax.hist(t[i_s[c_s > 0]], weights=c_s[c_s > 0] * 100.0, label="speed contour", **kw_hist) +ax.hist( + t[i_e[c_e > 0]], weights=c_e[c_e > 0] * 100.0, label="effective contour", **kw_hist +) +ax.legend(), ax.set_ylim(0, 25) +ax.set_xlim(np.datetime64("2010"), np.datetime64("2011")) +ax.grid() diff --git a/examples/06_grid_manipulation/pet_advect.py b/examples/06_grid_manipulation/pet_advect.py index 13052af5..d7cc67e9 100644 --- a/examples/06_grid_manipulation/pet_advect.py +++ b/examples/06_grid_manipulation/pet_advect.py @@ -10,9 +10,9 @@ from matplotlib.animation import FuncAnimation from numpy import arange, isnan, meshgrid, ones -import py_eddy_tracker.gui from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.gui import GUI_AXES from py_eddy_tracker.observations.observation import EddiesObservations # %% @@ -32,7 +32,7 @@ # %% # Quiver from u/v with eddies fig = plt.figure(figsize=(10, 5)) -ax = fig.add_axes([0, 0, 1, 1], projection="full_axes") +ax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) ax.set_xlim(19, 30), ax.set_ylim(31, 36.5), ax.grid() x, y = meshgrid(g.x_c, g.y_c) a.filled(ax, facecolors="r", alpha=0.1), c.filled(ax, facecolors="b", alpha=0.1) @@ -50,7 +50,7 @@ def _repr_html_(self, *args, **kwargs): def save(self, *args, **kwargs): if args[0].endswith("gif"): - # In this case gif is use to create thumbnail which are not use but consume same time than video + # In this case gif is used to create thumbnail which is not used but consume same time than video # So we create an empty file, to save time with open(args[0], "w") as _: pass @@ -73,7 +73,7 @@ def save(self, *args, **kwargs): # %% # Movie properties kwargs = dict(frames=arange(51), interval=100) -kw_p = dict(nb_step=2, time_step=21600) +kw_p = dict(u_name="u", v_name="v", nb_step=2, time_step=21600) frame_t = kw_p["nb_step"] * kw_p["time_step"] / 86400.0 @@ -82,7 +82,7 @@ def save(self, *args, **kwargs): def anim_ax(**kw): t = 0 fig = plt.figure(figsize=(10, 5), dpi=55) - axes = fig.add_axes([0, 0, 1, 1], projection="full_axes") + axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid() a.filled(axes, facecolors="r", alpha=0.1), c.filled(axes, facecolors="b", alpha=0.1) line = axes.plot([], [], "k", **kw)[0] @@ -102,7 +102,7 @@ def update(i_frame, t_step): # ^^^^^^^^^^^^^^^^ # Draw 3 last position in one path for each particles., # it could be run backward with `backward=True` option in filament method -p = g.filament(x, y, "u", "v", **kw_p, filament_size=3) +p = g.filament(x, y, **kw_p, filament_size=3) fig, txt, l, t = anim_ax(lw=0.5) _ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,)) @@ -110,13 +110,13 @@ def update(i_frame, t_step): # Particle forward # ^^^^^^^^^^^^^^^^^ # Forward advection of particles -p = g.advect(x, y, "u", "v", **kw_p) +p = g.advect(x, y, **kw_p) fig, txt, l, t = anim_ax(ls="", marker=".", markersize=1) _ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,)) # %% # We get last position and run backward until original position -p = g.advect(x, y, "u", "v", **kw_p, backward=True) +p = g.advect(x, y, **kw_p, backward=True) fig, txt, l, _ = anim_ax(ls="", marker=".", markersize=1) _ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,)) @@ -139,9 +139,11 @@ def update(i_frame, t_step): ) for 
time_step in (10800, 21600, 43200, 86400): x, y = x0.copy(), y0.copy() - kw_advect = dict(nb_step=int(50 * 86400 / time_step), time_step=time_step) - g.advect(x, y, "u", "v", **kw_advect).__next__() - g.advect(x, y, "u", "v", **kw_advect, backward=True).__next__() + kw_advect = dict( + nb_step=int(50 * 86400 / time_step), time_step=time_step, u_name="u", v_name="v" + ) + g.advect(x, y, **kw_advect).__next__() + g.advect(x, y, **kw_advect, backward=True).__next__() d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 ax.hist(d, **kw, label=f"{86400. / time_step:.0f} time step by day") ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() @@ -158,9 +160,14 @@ def update(i_frame, t_step): time_step = 10800 for duration in (5, 50, 100): x, y = x0.copy(), y0.copy() - kw_advect = dict(nb_step=int(duration * 86400 / time_step), time_step=time_step) - g.advect(x, y, "u", "v", **kw_advect).__next__() - g.advect(x, y, "u", "v", **kw_advect, backward=True).__next__() + kw_advect = dict( + nb_step=int(duration * 86400 / time_step), + time_step=time_step, + u_name="u", + v_name="v", + ) + g.advect(x, y, **kw_advect).__next__() + g.advect(x, y, **kw_advect, backward=True).__next__() d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5 ax.hist(d, **kw, label=f"Time duration {duration} days") ax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc="lower right"), ax.grid() diff --git a/examples/06_grid_manipulation/pet_lavd.py b/examples/06_grid_manipulation/pet_lavd.py index e0dbbb54..a3ea846e 100644 --- a/examples/06_grid_manipulation/pet_lavd.py +++ b/examples/06_grid_manipulation/pet_lavd.py @@ -24,16 +24,16 @@ from matplotlib.animation import FuncAnimation from numpy import arange, meshgrid, zeros -import py_eddy_tracker.gui from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.dataset.grid import RegularGridDataset +from py_eddy_tracker.gui import GUI_AXES from py_eddy_tracker.observations.observation import EddiesObservations # %% def start_ax(title="", dpi=90): fig = plt.figure(figsize=(16, 9), dpi=dpi) - ax = fig.add_axes([0, 0, 1, 1], projection="full_axes") + ax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) ax.set_xlim(0, 32), ax.set_ylim(28, 46) ax.set_title(title) return fig, ax, ax.text(3, 32, "", fontsize=20) @@ -65,7 +65,7 @@ def _repr_html_(self, *args, **kwargs): def save(self, *args, **kwargs): if args[0].endswith("gif"): - # In this case gif is use to create thumbnail which are not use but consume same time than video + # In this case gif is used to create thumbnail which is not used but consume same time than video # So we create an empty file, to save time with open(args[0], "w") as _: pass @@ -110,9 +110,11 @@ def save(self, *args, **kwargs): step_by_day = 3 # Compute step of advection every 4h nb_step = 2 -kw_p = dict(nb_step=nb_step, time_step=86400 / step_by_day / nb_step) +kw_p = dict( + nb_step=nb_step, time_step=86400 / step_by_day / nb_step, u_name="u", v_name="v" +) # Start a generator which at each iteration return new position at next time step -particule = g.advect(x, y, "u", "v", **kw_p, rk4=True) +particule = g.advect(x, y, **kw_p, rk4=True) # %% # LAVD @@ -142,8 +144,9 @@ def update(i_frame): kw_video = dict(frames=arange(nb_time), interval=1000.0 / step_by_day / 2, blit=True) fig, ax, txt = start_ax(dpi=60) -x_g_, y_g_ = arange(0 - step / 2, 36 + step / 2, step), arange( - 28 - step / 2, 46 + step / 2, step +x_g_, y_g_ = ( + arange(0 - step / 2, 36 + step / 2, step), + arange(28 - step / 2, 46 + step / 2, step), ) # pcolorfast will be 
faster than pcolormesh, we could use pcolorfast due to x and y are regular pcolormesh = ax.pcolorfast(x_g_, y_g_, lavd, **kw_vorticity) @@ -157,13 +160,7 @@ def update(i_frame): # %% # Format LAVD data lavd = RegularGridDataset.with_array( - coordinates=("lon", "lat"), - datas=dict( - lavd=lavd.T, - lon=x_g, - lat=y_g, - ), - centered=True, + coordinates=("lon", "lat"), datas=dict(lavd=lavd.T, lon=x_g, lat=y_g), centered=True ) # %% diff --git a/examples/06_grid_manipulation/pet_okubo_weiss.py b/examples/06_grid_manipulation/pet_okubo_weiss.py index c60a92fb..aa8a063e 100644 --- a/examples/06_grid_manipulation/pet_okubo_weiss.py +++ b/examples/06_grid_manipulation/pet_okubo_weiss.py @@ -1,8 +1,8 @@ r""" Get Okubo Weis -===================== +============== -.. math:: OW = S_n^2 + S_s^2 + \omega^2 +.. math:: OW = S_n^2 + S_s^2 - \omega^2 with normal strain (:math:`S_n`), shear strain (:math:`S_s`) and vorticity (:math:`\omega`) diff --git a/examples/07_cube_manipulation/README.rst b/examples/07_cube_manipulation/README.rst index 147ce3f3..7cecfbd4 100644 --- a/examples/07_cube_manipulation/README.rst +++ b/examples/07_cube_manipulation/README.rst @@ -1,6 +1,2 @@ Time grid computation ===================== - -.. warning:: - - Time grid is under development, API could move quickly! diff --git a/examples/07_cube_manipulation/pet_cube.py b/examples/07_cube_manipulation/pet_cube.py index 6c0db253..cba6c85b 100644 --- a/examples/07_cube_manipulation/pet_cube.py +++ b/examples/07_cube_manipulation/pet_cube.py @@ -4,18 +4,19 @@ Example which use CMEMS surface current with a Runge-Kutta 4 algorithm to advect particles. """ +from datetime import datetime, timedelta + # sphinx_gallery_thumbnail_number = 2 import re -from datetime import datetime, timedelta from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation from numpy import arange, isnan, meshgrid, ones -import py_eddy_tracker.gui from py_eddy_tracker import start_logger from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.dataset.grid import GridCollection +from py_eddy_tracker.gui import GUI_AXES start_logger().setLevel("ERROR") @@ -31,7 +32,7 @@ def _repr_html_(self, *args, **kwargs): def save(self, *args, **kwargs): if args[0].endswith("gif"): - # In this case gif is use to create thumbnail which are not use but consume same time than video + # In this case gif is used to create thumbnail which is not used but consume same time than video # So we create an empty file, to save time with open(args[0], "w") as _: pass @@ -70,7 +71,7 @@ def save(self, *args, **kwargs): # Function def anim_ax(**kw): fig = plt.figure(figsize=(10, 5), dpi=55) - axes = fig.add_axes([0, 0, 1, 1], projection="full_axes") + axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES) axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid() line = axes.plot([], [], "k", **kw)[0] return fig, axes.text(21, 32.1, ""), line diff --git a/examples/07_cube_manipulation/pet_fsle_med.py b/examples/07_cube_manipulation/pet_fsle_med.py index 46c5fdcc..9d78ea02 100644 --- a/examples/07_cube_manipulation/pet_fsle_med.py +++ b/examples/07_cube_manipulation/pet_fsle_med.py @@ -14,7 +14,7 @@ from matplotlib import pyplot as plt from numba import njit -from numpy import arange, arctan2, empty, isnan, log2, ma, meshgrid, ones, pi, zeros +from numpy import arange, arctan2, empty, isnan, log, ma, meshgrid, ones, pi, zeros from py_eddy_tracker import start_logger from py_eddy_tracker.data import get_demo_path @@ -26,6 +26,10 @@ # %% # ADT in med # 
---------- +# :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_cube` method is +# made for data stores in time cube, you could use also +# :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` method to +# load data-cube from multiple file. c = GridCollection.from_netcdf_cube( get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), "longitude", @@ -45,7 +49,7 @@ def check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6): Check if distance between eastern or northern particle to center particle is bigger than `dist_max` """ nb_p = x.shape[0] // 3 - delta = dist_max ** 2 + delta = dist_max**2 for i in range(nb_p): i0 = i * 3 i_n = i0 + 1 @@ -55,10 +59,10 @@ def check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6): continue # Distance with north dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n] - dn = dxn ** 2 + dyn ** 2 + dn = dxn**2 + dyn**2 # Distance with east dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e] - de = dxe ** 2 + dye ** 2 + de = dxe**2 + dye**2 if dn >= delta or de >= delta: s1 = dn + de @@ -67,7 +71,7 @@ def check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6): s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * ( (dxn - dye) ** 2 + (dxe + dyn) ** 2 ) - flse[i] = 1 / (2 * dt) * log2(1 / (2 * dist_init ** 2) * (s1 + s2 ** 0.5)) + flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init**2) * (s1 + s2**0.5)) theta[i] = arctan2(at1, at2 + s2) * 180 / pi # To know where value are set m_set[i] = False @@ -138,8 +142,10 @@ def build_triplet(x, y, step=0.02): used = zeros(x.shape[0], dtype="bool") # advection generator -kw = dict(t_init=t0, nb_step=1, backward=backward, mask_particule=used) -p = c.advect(x, y, "u", "v", time_step=86400 / time_step_by_days, **kw) +kw = dict( + t_init=t0, nb_step=1, backward=backward, mask_particule=used, u_name="u", v_name="v" +) +p = c.advect(x, y, time_step=86400 / time_step_by_days, **kw) # We check at each step of advection if particle distance is over `dist_max` for i in range(time_step_by_days * nb_days): @@ -176,7 +182,7 @@ def build_triplet(x, y, step=0.02): ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) ax.set_aspect("equal") ax.set_title("Finite size lyapunov exponent", weight="bold") -kw = dict(cmap="viridis_r", vmin=-15, vmax=0) +kw = dict(cmap="viridis_r", vmin=-20, vmax=0) m = fsle_custom.display(ax, 1 / fsle_custom.grid("fsle"), **kw) ax.grid() _ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9])) diff --git a/examples/07_cube_manipulation/pet_lavd_detection.py b/examples/07_cube_manipulation/pet_lavd_detection.py index dc3a83ff..4dace120 100644 --- a/examples/07_cube_manipulation/pet_lavd_detection.py +++ b/examples/07_cube_manipulation/pet_lavd_detection.py @@ -23,10 +23,10 @@ from matplotlib import pyplot as plt from numpy import arange, isnan, ma, meshgrid, zeros -import py_eddy_tracker.gui from py_eddy_tracker import start_logger from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset +from py_eddy_tracker.gui import GUI_AXES start_logger().setLevel("ERROR") @@ -47,7 +47,7 @@ def from_(cls, x, y, z): # %% def start_ax(title="", dpi=90): fig = plt.figure(figsize=(12, 5), dpi=dpi) - ax = fig.add_axes([0.05, 0.08, 0.9, 0.9], projection="full_axes") + ax = fig.add_axes([0.05, 0.08, 0.9, 0.9], projection=GUI_AXES) ax.set_xlim(-6, 36), ax.set_ylim(31, 45) ax.set_title(title) return fig, ax, ax.text(3, 32, "", fontsize=20) @@ -93,7 +93,7 @@ def update_axes(ax, mappable=None): # Time properties, for example with 
advection only 25 days nb_days, step_by_day = 25, 6 nb_time = step_by_day * nb_days -kw_p = dict(nb_step=1, time_step=86400 / step_by_day) +kw_p = dict(nb_step=1, time_step=86400 / step_by_day, u_name="u", v_name="v") t0 = 20236 t0_grid = c[t0] # Geographic properties, we use a coarser resolution for time consuming reasons @@ -114,7 +114,7 @@ def update_axes(ax, mappable=None): # ---------------------------- lavd = zeros(original_shape) lavd_ = lavd[m] -p = c.advect(x0.copy(), y0.copy(), "u", "v", t_init=t0, **kw_p) +p = c.advect(x0.copy(), y0.copy(), t_init=t0, **kw_p) for _ in range(nb_time): t, x, y = p.__next__() lavd_ += abs(c.interp("vort", t / 86400.0, x, y)) @@ -131,7 +131,7 @@ def update_axes(ax, mappable=None): # ----------------------------- lavd = zeros(original_shape) lavd_ = lavd[m] -p = c.advect(x0.copy(), y0.copy(), "u", "v", t_init=t0, backward=True, **kw_p) +p = c.advect(x0.copy(), y0.copy(), t_init=t0, backward=True, **kw_p) for i in range(nb_time): t, x, y = p.__next__() lavd_ += abs(c.interp("vort", t / 86400.0, x, y)) @@ -148,7 +148,7 @@ def update_axes(ax, mappable=None): # --------------------------- lavd = zeros(original_shape) lavd_ = lavd[m] -p = t0_grid.advect(x0.copy(), y0.copy(), "u", "v", **kw_p) +p = t0_grid.advect(x0.copy(), y0.copy(), **kw_p) for _ in range(nb_time): x, y = p.__next__() lavd_ += abs(t0_grid.interp("vort", x, y)) @@ -165,7 +165,7 @@ def update_axes(ax, mappable=None): # ---------------------------- lavd = zeros(original_shape) lavd_ = lavd[m] -p = t0_grid.advect(x0.copy(), y0.copy(), "u", "v", backward=True, **kw_p) +p = t0_grid.advect(x0.copy(), y0.copy(), backward=True, **kw_p) for i in range(nb_time): x, y = p.__next__() lavd_ += abs(t0_grid.interp("vort", x, y)) diff --git a/examples/07_cube_manipulation/pet_particles_drift.py b/examples/07_cube_manipulation/pet_particles_drift.py new file mode 100644 index 00000000..3d7aa1a4 --- /dev/null +++ b/examples/07_cube_manipulation/pet_particles_drift.py @@ -0,0 +1,46 @@ +""" +Build path of particle drifting +=============================== + +""" + +from matplotlib import pyplot as plt +from numpy import arange, meshgrid + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.dataset.grid import GridCollection + +start_logger().setLevel("ERROR") + +# %% +# Load data cube +c = GridCollection.from_netcdf_cube( + get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), + "longitude", + "latitude", + "time", + unset=True, +) + +# %% +# Advection properties +nb_days, step_by_day = 10, 6 +nb_time = step_by_day * nb_days +kw_p = dict(nb_step=1, time_step=86400 / step_by_day) +t0 = 20210 + +# %% +# Get paths +x0, y0 = meshgrid(arange(32, 35, 0.5), arange(32.5, 34.5, 0.5)) +x0, y0 = x0.reshape(-1), y0.reshape(-1) +t, x, y = c.path(x0, y0, h_name="adt", t_init=t0, **kw_p, nb_time=nb_time) + +# %% +# Plot paths +ax = plt.figure(figsize=(9, 6)).add_subplot(111, aspect="equal") +ax.plot(x0, y0, "k.", ms=20) +ax.plot(x, y, lw=3) +ax.set_title("10 days particle paths") +ax.set_xlim(31, 35), ax.set_ylim(32, 34.5) +ax.grid() diff --git a/examples/08_tracking_manipulation/pet_display_field.py b/examples/08_tracking_manipulation/pet_display_field.py index 30ad75a6..b943a2ba 100644 --- a/examples/08_tracking_manipulation/pet_display_field.py +++ b/examples/08_tracking_manipulation/pet_display_field.py @@ -4,8 +4,8 @@ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking 
import TrackEddiesObservations diff --git a/examples/08_tracking_manipulation/pet_display_track.py b/examples/08_tracking_manipulation/pet_display_track.py index 13a8d3ad..b15d51d7 100644 --- a/examples/08_tracking_manipulation/pet_display_track.py +++ b/examples/08_tracking_manipulation/pet_display_track.py @@ -4,8 +4,8 @@ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py b/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py new file mode 100644 index 00000000..8161ad81 --- /dev/null +++ b/examples/08_tracking_manipulation/pet_how_to_use_correspondances.py @@ -0,0 +1,94 @@ +""" +Correspondances +=============== + +Correspondances is a mechanism to intend to continue tracking with new detection + +""" + +import logging + +# %% +from matplotlib import pyplot as plt +from netCDF4 import Dataset + +from py_eddy_tracker import start_logger +from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker + +# In order to hide some warning +import py_eddy_tracker.observations.observation +from py_eddy_tracker.tracking import Correspondances + +py_eddy_tracker.observations.observation._display_check_warning = False + + +# %% +def plot_eddy(ed): + fig = plt.figure(figsize=(10, 5)) + ax = fig.add_axes([0.05, 0.03, 0.90, 0.94]) + ed.plot(ax, ref=-10, marker="x") + lc = ed.display_color(ax, field=ed.time, ref=-10, intern=True) + plt.colorbar(lc).set_label("Time in Julian days (from 1950/01/01)") + ax.set_xlim(4.5, 8), ax.set_ylim(36.8, 38.3) + ax.set_aspect("equal") + ax.grid() + + +# %% +# Get remote data, we will keep only 20 first days, +# `get_remote_demo_sample` function is only to get demo dataset, in your own case give a list of identification filename +# and don't mix cyclonic and anticyclonic files. 
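+# Illustrative note (the path below is a placeholder, not from the original
+# example): with your own identification files the list can be built directly,
+# e.g.
+#   from glob import glob
+#   file_objects = sorted(glob("MY_IDENTIFICATION_PATH/Anticyclonic_*.nc"))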
+file_objects = get_remote_demo_sample( + "eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012" +)[:20] + +# %% +# We run a traking with a tracker which use contour overlap, on 10 first time step +c_first_run = Correspondances( + datasets=file_objects[:10], class_method=AreaTracker, virtual=4 +) +start_logger().setLevel("INFO") +c_first_run.track() +start_logger().setLevel("WARNING") +with Dataset("correspondances.nc", "w") as h: + c_first_run.to_netcdf(h) +# Next step are done only to build atlas and display it +c_first_run.prepare_merging() + +# We have now an eddy object +eddies_area_tracker = c_first_run.merge(raw_data=False) +eddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0 +eddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1) + +# %% +# Plot from first ten days +plot_eddy(eddies_area_tracker) + +# %% +# Restart from previous run +# ------------------------- +# We give all filenames, the new one and filename from previous run +c_second_run = Correspondances( + datasets=file_objects[:20], + # This parameter must be identical in each run + class_method=AreaTracker, + virtual=4, + # Previous saved correspondancs + previous_correspondance="correspondances.nc", +) +start_logger().setLevel("INFO") +c_second_run.track() +start_logger().setLevel("WARNING") +c_second_run.prepare_merging() +# We have now another eddy object +eddies_area_tracker_extend = c_second_run.merge(raw_data=False) +eddies_area_tracker_extend.virtual[:] = eddies_area_tracker_extend.time == 0 +eddies_area_tracker_extend.filled_by_interpolation( + eddies_area_tracker_extend.virtual == 1 +) + + +# %% +# Plot with time extension +plot_eddy(eddies_area_tracker_extend) diff --git a/examples/08_tracking_manipulation/pet_one_track.py b/examples/08_tracking_manipulation/pet_one_track.py index 9f930281..a2536c34 100644 --- a/examples/08_tracking_manipulation/pet_one_track.py +++ b/examples/08_tracking_manipulation/pet_one_track.py @@ -2,8 +2,8 @@ One Track =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/08_tracking_manipulation/pet_select_track_across_area.py b/examples/08_tracking_manipulation/pet_select_track_across_area.py index b88f37e1..58184e1f 100644 --- a/examples/08_tracking_manipulation/pet_select_track_across_area.py +++ b/examples/08_tracking_manipulation/pet_select_track_across_area.py @@ -3,8 +3,8 @@ ============================ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/08_tracking_manipulation/pet_track_anim.py b/examples/08_tracking_manipulation/pet_track_anim.py index 0c18a0ba..94e09ad3 100644 --- a/examples/08_tracking_manipulation/pet_track_anim.py +++ b/examples/08_tracking_manipulation/pet_track_anim.py @@ -2,7 +2,9 @@ Track animation =============== -Run in a terminal this script, which allow to watch eddy evolution +Run in a terminal this script, which allow to watch eddy evolution. + +You could use also *EddyAnim* script to display/save animation. 
""" import py_eddy_tracker_sample diff --git a/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py b/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py index 6776b47e..b686fd67 100644 --- a/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py +++ b/examples/08_tracking_manipulation/pet_track_anim_matplotlib_animation.py @@ -2,14 +2,16 @@ Track animation with standard matplotlib ======================================== -Run in a terminal this script, which allow to watch eddy evolution +Run in a terminal this script, which allow to watch eddy evolution. + +You could use also *EddyAnim* script to display/save animation. """ import re -import py_eddy_tracker_sample from matplotlib.animation import FuncAnimation from numpy import arange +import py_eddy_tracker_sample from py_eddy_tracker.appli.gui import Anim from py_eddy_tracker.observations.tracking import TrackEddiesObservations @@ -28,7 +30,7 @@ def _repr_html_(self, *args, **kwargs): def save(self, *args, **kwargs): if args[0].endswith("gif"): - # In this case gif is use to create thumbnail which are not use but consume same time than video + # In this case gif is used to create thumbnail which is not used but consume same time than video # So we create an empty file, to save time with open(args[0], "w") as _: pass diff --git a/examples/10_tracking_diagnostics/pet_birth_and_death.py b/examples/10_tracking_diagnostics/pet_birth_and_death.py index d917efbd..b67993a2 100644 --- a/examples/10_tracking_diagnostics/pet_birth_and_death.py +++ b/examples/10_tracking_diagnostics/pet_birth_and_death.py @@ -5,8 +5,8 @@ Following figures are based on https://doi.org/10.1016/j.pocean.2011.01.002 """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_center_count.py b/examples/10_tracking_diagnostics/pet_center_count.py index 6d9fa417..77a4dcda 100644 --- a/examples/10_tracking_diagnostics/pet_center_count.py +++ b/examples/10_tracking_diagnostics/pet_center_count.py @@ -5,9 +5,9 @@ Do Geo stat with center and compare with frequency method show: :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_pixel_used.py` """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from matplotlib.colors import LogNorm +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations @@ -27,7 +27,7 @@ step = 0.125 bins = ((-10, 37, step), (30, 46, step)) kwargs_pcolormesh = dict( - cmap="terrain_r", vmin=0, vmax=2, factor=1 / (a.nb_days * step ** 2), name="count" + cmap="terrain_r", vmin=0, vmax=2, factor=1 / (a.nb_days * step**2), name="count" ) diff --git a/examples/10_tracking_diagnostics/pet_geographic_stats.py b/examples/10_tracking_diagnostics/pet_geographic_stats.py index d2a7e90d..a2e3f6b5 100644 --- a/examples/10_tracking_diagnostics/pet_geographic_stats.py +++ b/examples/10_tracking_diagnostics/pet_geographic_stats.py @@ -4,8 +4,8 @@ """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_groups.py b/examples/10_tracking_diagnostics/pet_groups.py index f6e800ae..deedcc3f 100644 --- a/examples/10_tracking_diagnostics/pet_groups.py +++ b/examples/10_tracking_diagnostics/pet_groups.py @@ -3,9 +3,9 @@ 
=================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, ones, percentile +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_histo.py b/examples/10_tracking_diagnostics/pet_histo.py index b2eff842..abf97c38 100644 --- a/examples/10_tracking_diagnostics/pet_histo.py +++ b/examples/10_tracking_diagnostics/pet_histo.py @@ -3,9 +3,9 @@ =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_lifetime.py b/examples/10_tracking_diagnostics/pet_lifetime.py index 9f84e790..4e2500fd 100644 --- a/examples/10_tracking_diagnostics/pet_lifetime.py +++ b/examples/10_tracking_diagnostics/pet_lifetime.py @@ -3,9 +3,9 @@ =================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, ones +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_normalised_lifetime.py b/examples/10_tracking_diagnostics/pet_normalised_lifetime.py new file mode 100644 index 00000000..1c84a8cc --- /dev/null +++ b/examples/10_tracking_diagnostics/pet_normalised_lifetime.py @@ -0,0 +1,78 @@ +""" +Normalised Eddy Lifetimes +========================= + +Example from Evan Mason +""" +from matplotlib import pyplot as plt +from numba import njit +from numpy import interp, linspace, zeros +from py_eddy_tracker_sample import get_demo_path + +from py_eddy_tracker.observations.tracking import TrackEddiesObservations + + +# %% +@njit(cache=True) +def sum_profile(x_new, y, out): + """Will sum all interpolated given array""" + out += interp(x_new, linspace(0, 1, y.size), y) + + +class MyObs(TrackEddiesObservations): + def eddy_norm_lifetime(self, name, nb, factor=1): + """ + :param str,array name: Array or field name + :param int nb: size of output array + """ + y = self.parse_varname(name) + x = linspace(0, 1, nb) + out = zeros(nb, dtype=y.dtype) + nb_track = 0 + for i, b0, b1 in self.iter_on("track"): + y_ = y[i] + size_ = y_.size + if size_ == 0: + continue + sum_profile(x, y_, out) + nb_track += 1 + return x, out / nb_track * factor + + +# %% +# Load atlas +# ---------- +kw = dict(include_vars=("speed_radius", "amplitude", "track")) +a = MyObs.load_file( + get_demo_path("eddies_med_adt_allsat_dt2018/Anticyclonic.zarr"), **kw +) +c = MyObs.load_file(get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr"), **kw) + +nb_max_a = a.nb_obs_by_track.max() +nb_max_c = c.nb_obs_by_track.max() + +# %% +# Compute normalised lifetime +# --------------------------- + +# Radius +AC_radius = a.eddy_norm_lifetime("speed_radius", nb=nb_max_a, factor=1e-3) +CC_radius = c.eddy_norm_lifetime("speed_radius", nb=nb_max_c, factor=1e-3) +# Amplitude +AC_amplitude = a.eddy_norm_lifetime("amplitude", nb=nb_max_a, factor=1e2) +CC_amplitude = c.eddy_norm_lifetime("amplitude", nb=nb_max_c, factor=1e2) + +# %% +# Figure +# ------ +fig, (ax0, ax1) = plt.subplots(nrows=2, figsize=(8, 6)) + +ax0.set_title("Normalised Mean Radius") +ax0.plot(*AC_radius), ax0.plot(*CC_radius) +ax0.set_ylabel("Radius (km)"), ax0.grid() +ax0.set_xlim(0, 1), ax0.set_ylim(0, None) + +ax1.set_title("Normalised Mean Amplitude") +ax1.plot(*AC_amplitude, label="AC"), 
ax1.plot(*CC_amplitude, label="CC") +ax1.set_ylabel("Amplitude (cm)"), ax1.grid(), ax1.legend() +_ = ax1.set_xlim(0, 1), ax1.set_ylim(0, None) diff --git a/examples/10_tracking_diagnostics/pet_pixel_used.py b/examples/10_tracking_diagnostics/pet_pixel_used.py index 3907ce19..75a826d6 100644 --- a/examples/10_tracking_diagnostics/pet_pixel_used.py +++ b/examples/10_tracking_diagnostics/pet_pixel_used.py @@ -5,9 +5,9 @@ Do Geo stat with frequency and compare with center count method: :ref:`sphx_glr_python_module_10_tracking_diagnostics_pet_center_count.py` """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from matplotlib.colors import LogNorm +import py_eddy_tracker_sample from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/10_tracking_diagnostics/pet_propagation.py b/examples/10_tracking_diagnostics/pet_propagation.py index 6a65a212..e6bc6c1b 100644 --- a/examples/10_tracking_diagnostics/pet_propagation.py +++ b/examples/10_tracking_diagnostics/pet_propagation.py @@ -3,9 +3,9 @@ ===================== """ -import py_eddy_tracker_sample from matplotlib import pyplot as plt from numpy import arange, ones +import py_eddy_tracker_sample from py_eddy_tracker.generic import cumsum_by_track from py_eddy_tracker.observations.tracking import TrackEddiesObservations diff --git a/examples/12_external_data/pet_drifter_loopers.py b/examples/12_external_data/pet_drifter_loopers.py new file mode 100644 index 00000000..5266db7b --- /dev/null +++ b/examples/12_external_data/pet_drifter_loopers.py @@ -0,0 +1,153 @@ +""" +Colocate looper with eddy from altimetry +======================================== + +All loopers data used in this example are a subset from the dataset described in this article +[Lumpkin, R. 
: Global characteristics of coherent vortices from surface drifter trajectories](https://doi.org/10.1002/2015JC011435)
+"""
+
+import re
+
+from matplotlib import pyplot as plt
+from matplotlib.animation import FuncAnimation
+import numpy as np
+import py_eddy_tracker_sample
+
+from py_eddy_tracker import data
+from py_eddy_tracker.appli.gui import Anim
+from py_eddy_tracker.observations.tracking import TrackEddiesObservations
+
+
+# %%
+class VideoAnimation(FuncAnimation):
+    def _repr_html_(self, *args, **kwargs):
+        """To get video in html and have a player"""
+        content = self.to_html5_video()
+        return re.sub(
+            r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content
+        )
+
+    def save(self, *args, **kwargs):
+        if args[0].endswith("gif"):
+            # In this case the gif is used to create a thumbnail which is not used but consumes the same time as the video
+            # So we create an empty file, to save time
+            with open(args[0], "w") as _:
+                pass
+            return
+        return super().save(*args, **kwargs)
+
+
+def start_axes(title):
+    fig = plt.figure(figsize=(13, 5))
+    ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], aspect="equal")
+    ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)
+    ax.set_title(title, weight="bold")
+    return ax
+
+
+def update_axes(ax, mappable=None):
+    ax.grid()
+    if mappable:
+        plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))
+
+
+# %%
+# Load eddies dataset
+cyclonic_eddies = TrackEddiesObservations.load_file(
+    py_eddy_tracker_sample.get_demo_path("eddies_med_adt_allsat_dt2018/Cyclonic.zarr")
+)
+anticyclonic_eddies = TrackEddiesObservations.load_file(
+    py_eddy_tracker_sample.get_demo_path(
+        "eddies_med_adt_allsat_dt2018/Anticyclonic.zarr"
+    )
+)
+
+# %%
+# Load loopers dataset
+loopers_med = TrackEddiesObservations.load_file(
+    data.get_demo_path("loopers_lumpkin_med.nc")
+)
+
+# %%
+# Global view
+# ===========
+ax = start_axes("All drifters available in Med from Lumpkin dataset")
+loopers_med.plot(ax, lw=0.5, color="r", ref=-10)
+update_axes(ax)
+
+# %%
+# One segment of drifter
+# ======================
+#
+# Get a drifter segment (the indices used have no correspondence with the original dataset).
+looper = loopers_med.extract_ids((3588,))
+fig = plt.figure(figsize=(16, 6))
+ax = fig.add_subplot(111, aspect="equal")
+looper.plot(ax, lw=0.5, label="Original position of drifter")
+looper_filtered = looper.copy()
+looper_filtered.position_filter(1, 13)
+s = looper_filtered.scatter(
+    ax,
+    "time",
+    cmap=plt.get_cmap("Spectral_r", 20),
+    label="Filtered position of drifter",
+)
+plt.colorbar(s).set_label("time (days from 1/1/1950)")
+ax.legend()
+ax.grid()
+
+# %%
+# Try to find detected eddies from ADT at the same place. We use the filtered track to simulate an eddy center
+match = looper_filtered.close_tracks(
+    anticyclonic_eddies, method="close_center", delta=0.1, nb_obs_min=50
+)
+fig = plt.figure(figsize=(16, 6))
+ax = fig.add_subplot(111, aspect="equal")
+looper.plot(ax, lw=0.5, label="Original position of drifter")
+looper_filtered.plot(ax, lw=1.5, label="Filtered position of drifter")
+match.plot(ax, lw=1.5, label="Matched eddy")
+ax.legend()
+ax.grid()
+
+# %%
+# Display the radius of these two datasets.
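+# The copies below are smoothed in time before plotting: a 1-day median filter followed
+# by a 13-day loess filter on "radius_s" (half windows expressed in days, as suggested by
+# the "filtered half window 13 days" legend labels), so the looper and altimetry radius
+# curves can be compared visually while the raw radii stay untouched.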
+fig = plt.figure(figsize=(20, 8)) +ax = fig.add_subplot(111) +ax.plot(looper.time, looper.radius_s / 1e3, label="loopers") +looper_radius = looper.copy() +looper_radius.median_filter(1, "time", "radius_s", inplace=True) +looper_radius.loess_filter(13, "time", "radius_s", inplace=True) +ax.plot( + looper_radius.time, + looper_radius.radius_s / 1e3, + label="loopers (filtered half window 13 days)", +) +ax.plot(match.time, match.radius_s / 1e3, label="altimetry") +match_radius = match.copy() +match_radius.median_filter(1, "time", "radius_s", inplace=True) +match_radius.loess_filter(13, "time", "radius_s", inplace=True) +ax.plot( + match_radius.time, + match_radius.radius_s / 1e3, + label="altimetry (filtered half window 13 days)", +) +ax.set_ylabel("radius(km)"), ax.set_ylim(0, 100) +ax.legend() +ax.set_title("Radius from loopers and altimeter") +ax.grid() + + +# %% +# Animation of a drifter and its colocated eddy +def update(frame): + # We display last 5 days of loopers trajectory + m = (looper.time < frame) * (looper.time > (frame - 5)) + anim.func_animation(frame) + line.set_data(looper.lon[m], looper.lat[m]) + + +anim = Anim(match, intern=True, figsize=(8, 8), cmap="magma_r", nb_step=10, dpi=75) +# mappable to show drifter in red +line = anim.ax.plot([], [], "r", lw=4, zorder=100)[0] +anim.fig.suptitle("") +_ = VideoAnimation(anim.fig, update, frames=np.arange(*anim.period, 1), interval=125) diff --git a/examples/14_generic_tools/pet_fit_contour.py b/examples/14_generic_tools/pet_fit_contour.py index 9c3f9183..2d3b6dc9 100644 --- a/examples/14_generic_tools/pet_fit_contour.py +++ b/examples/14_generic_tools/pet_fit_contour.py @@ -15,7 +15,7 @@ from py_eddy_tracker import data from py_eddy_tracker.generic import coordinates_to_local, local_to_coordinates from py_eddy_tracker.observations.observation import EddiesObservations -from py_eddy_tracker.poly import fit_circle_, fit_ellips +from py_eddy_tracker.poly import fit_circle_, fit_ellipse # %% # Load example identification file @@ -23,14 +23,14 @@ # %% -# Function to draw circle or ellips from parameter +# Function to draw circle or ellipse from parameter def build_circle(x0, y0, r): angle = radians(linspace(0, 360, 50)) x_norm, y_norm = cos(angle), sin(angle) return local_to_coordinates(x_norm * r, y_norm * r, x0, y0) -def build_ellips(x0, y0, a, b, theta): +def build_ellipse(x0, y0, a, b, theta): angle = radians(linspace(0, 360, 50)) x = a * cos(theta) * cos(angle) - b * sin(theta) * sin(angle) y = a * sin(theta) * cos(angle) + b * cos(theta) * sin(angle) @@ -38,7 +38,7 @@ def build_ellips(x0, y0, a, b, theta): # %% -# Plot fitted circle or ellips on stored contour +# Plot fitted circle or ellipse on stored contour xs, ys = a.contour_lon_s, a.contour_lat_s fig = plt.figure(figsize=(15, 15)) @@ -51,9 +51,9 @@ def build_ellips(x0, y0, a, b, theta): ax = fig.add_subplot(4, 4, j) ax.grid(), ax.set_aspect("equal") ax.plot(x, y, label="store", color="black") - x0, y0, a, b, theta = fit_ellips(x_, y_) + x0, y0, a, b, theta = fit_ellipse(x_, y_) x0, y0 = local_to_coordinates(x0, y0, x0_, y0_) - ax.plot(*build_ellips(x0, y0, a, b, theta), label="ellips", color="green") + ax.plot(*build_ellipse(x0, y0, a, b, theta), label="ellipse", color="green") x0, y0, radius, shape_error = fit_circle_(x_, y_) x0, y0 = local_to_coordinates(x0, y0, x0_, y0_) ax.plot(*build_circle(x0, y0, radius), label="circle", color="red", lw=0.5) diff --git a/examples/14_generic_tools/pet_visvalingam.py b/examples/14_generic_tools/pet_visvalingam.py index f7b29c10..736e8852 
100644 --- a/examples/14_generic_tools/pet_visvalingam.py +++ b/examples/14_generic_tools/pet_visvalingam.py @@ -2,8 +2,8 @@ Visvalingam algorithm ===================== """ -import matplotlib.animation as animation from matplotlib import pyplot as plt +import matplotlib.animation as animation from numba import njit from numpy import array, empty diff --git a/examples/16_network/pet_atlas.py b/examples/16_network/pet_atlas.py index 540d312d..48b374e2 100644 --- a/examples/16_network/pet_atlas.py +++ b/examples/16_network/pet_atlas.py @@ -5,8 +5,8 @@ from matplotlib import pyplot as plt from numpy import ma -import py_eddy_tracker.gui from py_eddy_tracker.data import get_remote_demo_sample +from py_eddy_tracker.gui import GUI_AXES from py_eddy_tracker.observations.network import NetworkObservations n = NetworkObservations.load_file( @@ -26,7 +26,7 @@ # Functions def start_axes(title): fig = plt.figure(figsize=(13, 5)) - ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection="full_axes") + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES) ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46) ax.set_aspect("equal") ax.set_title(title, weight="bold") @@ -129,7 +129,9 @@ def update_axes(ax, mappable=None): # Merging in networks longer than 10 days, with dead end remove (shorter than 10 observations) # -------------------------------------------------------------------------------------------- ax = start_axes("") -merger = n10.remove_dead_end(nobs=10).merging_event() +n10_ = n10.copy() +n10_.remove_dead_end(nobs=10) +merger = n10_.merging_event() g_10_merging = merger.grid_count(bins) m = g_10_merging.display(ax, **kw_time, vmin=0, vmax=1) update_axes(ax, m).set_label("Pixel used in % of time") @@ -153,33 +155,33 @@ def update_axes(ax, mappable=None): update_axes(ax, m).set_label("Pixel used in % all atlas") # %% -# All Spliting -# ------------ -# Display the occurence of spliting events +# All splitting +# ------------- +# Display the occurence of splitting events ax = start_axes("") -g_all_spliting = n.spliting_event().grid_count(bins) -m = g_all_spliting.display(ax, **kw_time, vmin=0, vmax=1) +g_all_splitting = n.splitting_event().grid_count(bins) +m = g_all_splitting.display(ax, **kw_time, vmin=0, vmax=1) update_axes(ax, m).set_label("Pixel used in % of time") # %% -# Ratio spliting events / eddy presence +# Ratio splitting events / eddy presence ax = start_axes("") -g_ = g_all_spliting.vars["count"] * 100.0 / g_all.vars["count"] -m = g_all_spliting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +g_ = g_all_splitting.vars["count"] * 100.0 / g_all.vars["count"] +m = g_all_splitting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) update_axes(ax, m).set_label("Pixel used in % all atlas") # %% -# Spliting in networks longer than 10 days -# ---------------------------------------- +# splitting in networks longer than 10 days +# ----------------------------------------- ax = start_axes("") -g_10_spliting = n10.spliting_event().grid_count(bins) -m = g_10_spliting.display(ax, **kw_time, vmin=0, vmax=1) +g_10_splitting = n10.splitting_event().grid_count(bins) +m = g_10_splitting.display(ax, **kw_time, vmin=0, vmax=1) update_axes(ax, m).set_label("Pixel used in % of time") # %% ax = start_axes("") g_ = ma.array( - g_10_spliting.vars["count"] * 100.0 / g_10.vars["count"], + g_10_splitting.vars["count"] * 100.0 / g_10.vars["count"], mask=g_10.vars["count"] < 365, ) -m = g_10_spliting.display(ax, **kw_ratio, vmin=0, vmax=5, name=g_) +m = g_10_splitting.display(ax, **kw_ratio, vmin=0, vmax=5, 
name=g_) update_axes(ax, m).set_label("Pixel used in % all atlas") diff --git a/examples/16_network/pet_follow_particle.py b/examples/16_network/pet_follow_particle.py index 65746015..6815fb6e 100644 --- a/examples/16_network/pet_follow_particle.py +++ b/examples/16_network/pet_follow_particle.py @@ -5,19 +5,16 @@ """ import re -from matplotlib import colors -from matplotlib import pyplot as plt +from matplotlib import colors, pyplot as plt from matplotlib.animation import FuncAnimation -from numba import njit -from numba import types as nb_types -from numpy import arange, meshgrid, ones, unique, where, zeros +from numpy import arange, meshgrid, ones, unique, zeros from py_eddy_tracker import start_logger from py_eddy_tracker.appli.gui import Anim from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.dataset.grid import GridCollection +from py_eddy_tracker.observations.groups import particle_candidate from py_eddy_tracker.observations.network import NetworkObservations -from py_eddy_tracker.poly import group_obs start_logger().setLevel("ERROR") @@ -33,7 +30,7 @@ def _repr_html_(self, *args, **kwargs): def save(self, *args, **kwargs): if args[0].endswith("gif"): - # In this case gif is use to create thumbnail which are not use but consume same time than video + # In this case gif is used to create thumbnail which is not used but consume same time than video # So we create an empty file, to save time with open(args[0], "w") as _: pass @@ -44,7 +41,8 @@ def save(self, *args, **kwargs): # %% n = NetworkObservations.load_file(get_demo_path("network_med.nc")).network(651) n = n.extract_with_mask((n.time >= 20180) * (n.time <= 20269)) -n = n.remove_dead_end(nobs=0, ndays=10) +n.remove_dead_end(nobs=0, ndays=10) +n = n.remove_trash() n.numbering_segment() c = GridCollection.from_netcdf_cube( get_demo_path("dt_med_allsat_phy_l4_2005T2.nc"), @@ -98,11 +96,17 @@ def save(self, *args, **kwargs): a.txt.set_position((25, 31)) step = 0.25 -kw_p = dict(nb_step=2, time_step=86400 * step * 0.5, t_init=t_snapshot - 2 * step) +kw_p = dict( + nb_step=2, + time_step=86400 * step * 0.5, + t_init=t_snapshot - 2 * step, + u_name="u", + v_name="v", +) mappables = dict() -particules = c.advect(x, y, "u", "v", **kw_p) -filament = c.filament(x_f, y_f, "u", "v", **kw_p, filament_size=3) +particules = c.advect(x, y, **kw_p) +filament = c.filament(x_f, y_f, **kw_p, filament_size=3) kw = dict(ls="", marker=".", markersize=0.25) for k in index_: m = k == index @@ -124,105 +128,26 @@ def update(frame): ani = VideoAnimation(a.fig, update, frames=arange(20200, 20269, step), interval=200) -# %% -# In which observations are the particle -# -------------------------------------- -def advect(x, y, c, t0, delta_t): - """ - Advect particle from t0 to t0 + delta_t, with data cube. 
- """ - kw = dict(nb_step=6, time_step=86400 / 6) - if delta_t < 0: - kw["backward"] = True - delta_t = -delta_t - p = c.advect(x, y, "u", "v", t_init=t0, **kw) - for _ in range(delta_t): - t, x, y = p.__next__() - return t, x, y - - -def particle_candidate(x, y, c, eddies, t_start, i_target, pct, **kwargs): - # Obs from initial time - m_start = eddies.time == t_start - e = eddies.extract_with_mask(m_start) - # to be able to get global index - translate_start = where(m_start)[0] - # Identify particle in eddies(only in core) - i_start = e.contains(x, y, intern=True) - m = i_start != -1 - x, y, i_start = x[m], y[m], i_start[m] - # Advect - t_end, x, y = advect(x, y, c, t_start, **kwargs) - # eddies at last date - m_end = eddies.time == t_end / 86400 - e_end = eddies.extract_with_mask(m_end) - # to be able to get global index - translate_end = where(m_end)[0] - # Id eddies for each alive particle(in core and extern) - i_end = e_end.contains(x, y) - # compute matrix and filled target array - get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct) - - -@njit(cache=True) -def get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct): - nb_start, nb_end = translate_start.size, translate_end.size - # Matrix which will store count for every couple - count = zeros((nb_start, nb_end), dtype=nb_types.int32) - # Number of particle in each origin observation - ref = zeros(nb_start, dtype=nb_types.int32) - # For each particle - for i in range(i_start.size): - i_end_ = i_end[i] - i_start_ = i_start[i] - if i_end_ != -1: - count[i_start_, i_end_] += 1 - ref[i_start_] += 1 - for i in range(nb_start): - for j in range(nb_end): - pct_ = count[i, j] - # If there are particle from i to j - if pct_ != 0: - # Get percent - pct_ = pct_ / ref[i] * 100.0 - # Get indices in full dataset - i_, j_ = translate_start[i], translate_end[j] - pct_0 = pct[i_, 0] - if pct_ > pct_0: - pct[i_, 1] = pct_0 - pct[i_, 0] = pct_ - i_target[i_, 1] = i_target[i_, 0] - i_target[i_, 0] = j_ - elif pct_ > pct[i_, 1]: - pct[i_, 1] = pct_ - i_target[i_, 1] = j_ - return i_target, pct - - # %% # Particle advection # ^^^^^^^^^^^^^^^^^^ -step = 1 / 60.0 +# Advection from speed contour to speed contour (default) -x, y = meshgrid(arange(24, 36, step), arange(31, 36, step)) -x0, y0 = x.reshape(-1), y.reshape(-1) -# Pre-order to speed up -_, i = group_obs(x0, y0, 1, 360) -x0, y0 = x0[i], y0[i] +step = 1 / 60.0 -t_start, t_end = n.period +t_start, t_end = int(n.period[0]), int(n.period[1]) dt = 14 shape = (n.obs.size, 2) # Forward run i_target_f, pct_target_f = -ones(shape, dtype="i4"), zeros(shape, dtype="i1") -for t in range(t_start, t_end - dt): - particle_candidate(x0, y0, c, n, t, i_target_f, pct_target_f, delta_t=dt) +for t in arange(t_start, t_end - dt): + particle_candidate(c, n, step, t, i_target_f, pct_target_f, n_days=dt) # Backward run i_target_b, pct_target_b = -ones(shape, dtype="i4"), zeros(shape, dtype="i1") -for t in range(t_start + dt, t_end): - particle_candidate(x0, y0, c, n, t, i_target_b, pct_target_b, delta_t=-dt) +for t in arange(t_start + dt, t_end): + particle_candidate(c, n, step, t, i_target_b, pct_target_b, n_days=-dt) # %% fig = plt.figure(figsize=(10, 10)) @@ -232,6 +157,11 @@ def get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct): ax_2nd_f = fig.add_axes([0.52, 0.05, 0.45, 0.45]) ax_1st_b.set_title("Backward advection for each time step") ax_1st_f.set_title("Forward advection for each time step") +ax_1st_b.set_ylabel("Color -> First target\nLatitude") 
+ax_2nd_b.set_ylabel("Color -> Secondary target\nLatitude") +ax_2nd_b.set_xlabel("Julian days"), ax_2nd_f.set_xlabel("Julian days") +ax_1st_f.set_yticks([]), ax_2nd_f.set_yticks([]) +ax_1st_f.set_xticks([]), ax_1st_b.set_xticks([]) def color_alpha(target, pct, vmin=5, vmax=80): diff --git a/examples/16_network/pet_group_anim.py b/examples/16_network/pet_group_anim.py index 8ecee534..f2d439ed 100644 --- a/examples/16_network/pet_group_anim.py +++ b/examples/16_network/pet_group_anim.py @@ -2,9 +2,10 @@ Network group process ===================== """ +from datetime import datetime + # sphinx_gallery_thumbnail_number = 2 import re -from datetime import datetime from matplotlib import pyplot as plt from matplotlib.animation import FuncAnimation @@ -29,7 +30,7 @@ def _repr_html_(self, *args, **kwargs): def save(self, *args, **kwargs): if args[0].endswith("gif"): - # In this case gif is use to create thumbnail which are not use but consume same time than video + # In this case gif is used to create thumbnail which is not used but consume same time than video # So we create an empty file, to save time with open(args[0], "w") as _: pass diff --git a/examples/16_network/pet_ioannou_2017_case.py b/examples/16_network/pet_ioannou_2017_case.py index 7669b010..56bec82e 100644 --- a/examples/16_network/pet_ioannou_2017_case.py +++ b/examples/16_network/pet_ioannou_2017_case.py @@ -6,23 +6,26 @@ We want to find the Ierapetra Eddy described above in a network demonstration run. """ +from datetime import datetime, timedelta + # %% import re -from datetime import datetime, timedelta -from matplotlib import colors -from matplotlib import pyplot as plt +from matplotlib import colors, pyplot as plt from matplotlib.animation import FuncAnimation from matplotlib.ticker import FuncFormatter -from numpy import arange, where +from numpy import arange, array, pi, where -import py_eddy_tracker.gui from py_eddy_tracker.appli.gui import Anim from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.generic import coordinates_to_local +from py_eddy_tracker.gui import GUI_AXES from py_eddy_tracker.observations.network import NetworkObservations - +from py_eddy_tracker.poly import fit_ellipse # %% + + class VideoAnimation(FuncAnimation): def _repr_html_(self, *args, **kwargs): """To get video in html and have a player""" @@ -33,7 +36,7 @@ def _repr_html_(self, *args, **kwargs): def save(self, *args, **kwargs): if args[0].endswith("gif"): - # In this case gif is use to create thumbnail which are not use but consume same time than video + # In this case gif is used to create thumbnail which is not used but consume same time than video # So we create an empty file, to save time with open(args[0], "w") as _: pass @@ -48,7 +51,7 @@ def formatter(x, pos): def start_axes(title=""): fig = plt.figure(figsize=(13, 6)) - ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection="full_axes") + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES) ax.set_xlim(19, 29), ax.set_ylim(31, 35.5) ax.set_aspect("equal") ax.set_title(title, weight="bold") @@ -188,3 +191,47 @@ def update_axes(ax, mappable=None): m = close_to_i3.scatter_timeline(ax, "shape_error_e", vmin=14, vmax=70, **kw) cb = update_axes(ax, m["scatter"]) cb.set_label("Effective shape error") + +# %% +# Rotation angle +# -------------- +# For each obs, fit an ellipse to the contour, with theta the angle from the x-axis, +# a the semi ax in x direction and b the semi ax in y dimension + +theta_ = list() +a_ = list() +b_ = list() +for obs in close_to_i3: + x, 
y = obs["contour_lon_s"], obs["contour_lat_s"] + x0_, y0_ = x.mean(), y.mean() + x_, y_ = coordinates_to_local(x, y, x0_, y0_) + x0, y0, a, b, theta = fit_ellipse(x_, y_) + theta_.append(theta) + a_.append(a) + b_.append(b) +a_ = array(a_) +b_ = array(b_) + +# %% +# Theta +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, theta_, vmin=-pi / 2, vmax=pi / 2, cmap="hsv") +_ = update_axes(ax, m["scatter"]) + +# %% +# a +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, a_ * 1e-3, vmin=0, vmax=80, cmap="Spectral_r") +_ = update_axes(ax, m["scatter"]) + +# %% +# b +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, b_ * 1e-3, vmin=0, vmax=80, cmap="Spectral_r") +_ = update_axes(ax, m["scatter"]) + +# %% +# a/b +ax = timeline_axes() +m = close_to_i3.scatter_timeline(ax, a_ / b_, vmin=1, vmax=2, cmap="Spectral_r") +_ = update_axes(ax, m["scatter"]) diff --git a/examples/16_network/pet_relative.py b/examples/16_network/pet_relative.py index 2759edb4..dd97b538 100644 --- a/examples/16_network/pet_relative.py +++ b/examples/16_network/pet_relative.py @@ -5,8 +5,8 @@ from matplotlib import pyplot as plt from numpy import where -import py_eddy_tracker.gui from py_eddy_tracker import data +from py_eddy_tracker.gui import GUI_AXES from py_eddy_tracker.observations.network import NetworkObservations # %% @@ -127,7 +127,9 @@ # Remove dead branch # ------------------ # Remove all tiny segments with less than N obs which didn't join two segments -n_clean = n.remove_dead_end(nobs=5, ndays=10) +n_clean = n.copy() +n_clean.remove_dead_end(nobs=5, ndays=10) +n_clean = n_clean.remove_trash() fig = plt.figure(figsize=(15, 12)) ax = fig.add_axes([0.04, 0.54, 0.90, 0.40]) ax.set_title(f"Original network ({n.infos()})") @@ -261,12 +263,14 @@ # -------------------- # Get a simplified network -n = n2.remove_dead_end(nobs=50, recursive=1) +n = n2.copy() +n.remove_dead_end(nobs=50, recursive=1) +n = n.remove_trash() n.numbering_segment() # %% # Only a map can be tricky to understand, with a timeline it's easier! 
fig = plt.figure(figsize=(15, 8)) -ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection="full_axes") +ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES) n.plot(ax, color_cycle=n.COLORS) ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() ax = fig.add_axes([0.08, 0.7, 0.7, 0.3]) @@ -278,7 +282,7 @@ # ----------------- # Display the position of the eddies after a merging fig = plt.figure(figsize=(15, 8)) -ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection="full_axes") +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) n.plot(ax, color_cycle=n.COLORS) m1, m0, m0_stop = n.merging_event(triplet=True) m1.display(ax, color="violet", lw=2, label="Eddies after merging") @@ -292,13 +296,13 @@ m1 # %% -# Get spliting event -# ------------------ +# Get splitting event +# ------------------- # Display the position of the eddies before a splitting fig = plt.figure(figsize=(15, 8)) -ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection="full_axes") +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) n.plot(ax, color_cycle=n.COLORS) -s0, s1, s1_start = n.spliting_event(triplet=True) +s0, s1, s1_start = n.splitting_event(triplet=True) s0.display(ax, color="violet", lw=2, label="Eddies before splitting") s1.display(ax, color="blueviolet", lw=2, label="Eddies after splitting") s1_start.display(ax, color="black", lw=2, label="Eddies starting by splitting") @@ -314,7 +318,7 @@ # --------------- # Display the starting position of non-splitted eddies fig = plt.figure(figsize=(15, 8)) -ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection="full_axes") +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) birth = n.birth_event() birth.display(ax) ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() @@ -325,7 +329,7 @@ # --------------- # Display the last position of non-merged eddies fig = plt.figure(figsize=(15, 8)) -ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection="full_axes") +ax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES) death = n.death_event() death.display(ax) ax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid() diff --git a/examples/16_network/pet_replay_segmentation.py b/examples/16_network/pet_replay_segmentation.py index c33028fc..d909af7f 100644 --- a/examples/16_network/pet_replay_segmentation.py +++ b/examples/16_network/pet_replay_segmentation.py @@ -11,8 +11,8 @@ from matplotlib.ticker import FuncFormatter from numpy import where -import py_eddy_tracker.gui from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.gui import GUI_AXES from py_eddy_tracker.observations.network import NetworkObservations from py_eddy_tracker.observations.tracking import TrackEddiesObservations @@ -24,7 +24,7 @@ def formatter(x, pos): def start_axes(title=""): fig = plt.figure(figsize=(13, 6)) - ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection="full_axes") + ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES) ax.set_xlim(19, 29), ax.set_ylim(31, 35.5) ax.set_aspect("equal") ax.set_title(title, weight="bold") @@ -147,15 +147,9 @@ def get_obs(dataset): ax = timeline_axes() n_.median_filter(15, "time", "latitude") -kw["s"] = (n_.radius_e * 1e-3) ** 2 / 30 ** 2 * 20 +kw["s"] = (n_.radius_e * 1e-3) ** 2 / 30**2 * 20 m = n_.scatter_timeline( - ax, - "shape_error_e", - vmin=14, - vmax=70, - **kw, - yfield="lon", - method="all", + ax, "shape_error_e", vmin=14, vmax=70, **kw, yfield="lon", method="all" ) ax.set_ylabel("Longitude") cb = update_axes(ax, m["scatter"]) @@ -169,7 +163,6 @@ def 
get_obs(dataset): for b0, b1 in [ (datetime(i, 1, 1), datetime(i, 12, 31)) for i in (2004, 2005, 2006, 2007, 2008) ]: - ref, delta = datetime(1950, 1, 1), 20 b0_, b1_ = (b0 - ref).days, (b1 - ref).days ax = timeline_axes() diff --git a/examples/16_network/pet_segmentation_anim.py b/examples/16_network/pet_segmentation_anim.py index b2757809..1fcb9ae1 100644 --- a/examples/16_network/pet_segmentation_anim.py +++ b/examples/16_network/pet_segmentation_anim.py @@ -10,8 +10,8 @@ from matplotlib.colors import ListedColormap from numpy import ones, where -import py_eddy_tracker.gui from py_eddy_tracker.data import get_demo_path +from py_eddy_tracker.gui import GUI_AXES from py_eddy_tracker.observations.network import NetworkObservations from py_eddy_tracker.observations.tracking import TrackEddiesObservations @@ -27,7 +27,7 @@ def _repr_html_(self, *args, **kwargs): def save(self, *args, **kwargs): if args[0].endswith("gif"): - # In this case gif is use to create thumbnail which are not use but consume same time than video + # In this case gif is used to create thumbnail which is not used but consume same time than video # So we create an empty file, to save time with open(args[0], "w") as _: pass @@ -96,15 +96,14 @@ def update(i_frame): indices_frames = INDICES[i_frame] mappable_CONTOUR.set_data( - e.contour_lon_e[indices_frames], - e.contour_lat_e[indices_frames], + e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames] ) mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)]) return (mappable_tracks,) fig = plt.figure(figsize=(16, 9), dpi=60) -ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection="full_axes") +ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES) ax.set_title(f"{len(e)} observations to segment") ax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid() vmax = TRACKS[-1].max() @@ -121,6 +120,6 @@ def update(i_frame): # Final Result # ------------ fig = plt.figure(figsize=(16, 9)) -ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection="full_axes") +ax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES) ax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid() _ = ax.scatter(e.lon, e.lat, c=TRACKS[-1], cmap=cmap, vmin=0, vmax=vmax, s=20) diff --git a/notebooks/python_module/01_general_things/pet_storage.ipynb b/notebooks/python_module/01_general_things/pet_storage.ipynb index 4b4a6630..a56e4def 100644 --- a/notebooks/python_module/01_general_things/pet_storage.ipynb +++ b/notebooks/python_module/01_general_things/pet_storage.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "import py_eddy_tracker_sample\n\nfrom py_eddy_tracker.data import get_demo_path, get_remote_demo_sample\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.observation import EddiesObservations, Table\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + "import py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom numpy import arange, outer\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.observation import EddiesObservations, Table\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -40,7 +40,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Eddies files (zarr or netcdf) could be loaded with ```load_file``` method:\n\n" + "Eddies files (zarr or netcdf) can be loaded with ```load_file``` method:\n\n" ] }, { @@ -108,6 
+108,57 @@ "## Contour storage\nAll contours are stored on the same number of points, and are resampled if needed with an algorithm to be stored as objects\n\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Speed profile storage\nSpeed profile is an interpolation of speed mean along each contour.\nFor each contour included in eddy, we compute mean of speed along the contour,\nand after we interpolate speed mean array on a fixed size array.\n\nSeveral field are available to understand \"uavg_profile\" :\n 0. - num_contours : Number of contour in eddies, must be equal to amplitude divide by isoline step\n 1. - height_inner_contour : height of inner contour used\n 2. - height_max_speed_contour : height of max speed contour used\n 3. - height_external_contour : height of outter contour used\n\nLast value of \"uavg_profile\" is for inner contour and first value for outter contour.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Observations selection of \"uavg_profile\" with high number of contour(Eddy with high amplitude)\ne = eddies_collections.extract_with_mask(eddies_collections.num_contours > 15)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Raw display of profiles with more than 15 contours\nax = plt.subplot(111)\n_ = ax.plot(e.uavg_profile.T, lw=0.5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Profile from inner to outter\nax = plt.subplot(111)\nax.plot(e.uavg_profile[:, ::-1].T, lw=0.5)\n_ = ax.set_xlabel(\"From inner to outter contour\"), ax.set_ylabel(\"Speed (m/s)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# If we normalize indice of contour to set speed contour to 1 and inner contour to 0\nax = plt.subplot(111)\nh_in = e.height_inner_contour\nh_s = e.height_max_speed_contour\nh_e = e.height_external_contour\nr = (h_e - h_in) / (h_s - h_in)\nnb_pt = e.uavg_profile.shape[1]\n# Create an x array for each profile\nx = outer(arange(nb_pt) / nb_pt, r)\n\nax.plot(x, e.uavg_profile[:, ::-1].T, lw=0.5)\n_ = ax.set_xlabel(\"From inner to outter contour\"), ax.set_ylabel(\"Speed (m/s)\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -141,7 +192,7 @@ }, "outputs": [], "source": [ - "eddies_network = NetworkObservations.load_file(\n get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018_err70_filt500_order1/Anticyclonic_network.nc\"\n )\n)\neddies_network.field_table()" + "eddies_network = NetworkObservations.load_file(get_demo_path(\"network_med.nc\"))\neddies_network.field_table()" ] }, { @@ -179,7 +230,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb b/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb index 36989357..2d924387 100644 --- a/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_contour_circle.ipynb @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git 
a/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb b/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb index 6d40974f..d59f9e15 100644 --- a/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_display_id.ipynb @@ -129,7 +129,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb b/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb index fb6c17f8..7469b034 100644 --- a/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_eddy_detection.ipynb @@ -291,7 +291,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb index c2a3648d..6ac75cee 100644 --- a/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_ACC.ipynb @@ -161,7 +161,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb index c39bc011..49024327 100644 --- a/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_eddy_detection_gulf_stream.ipynb @@ -273,7 +273,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb b/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb index 63e763ff..381aa8f6 100644 --- a/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_filter_and_detection.ipynb @@ -176,7 +176,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb b/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb index 94e61b30..0cfdc9a8 100644 --- a/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_interp_grid_on_dataset.ipynb @@ -111,7 +111,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb b/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb index c70f7dd6..03eba8bf 100644 --- a/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_radius_vs_area.ipynb @@ -107,7 +107,7 @@ "name": "python", "nbconvert_exporter": 
"python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb b/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb index ffa58c1f..0ef03f6f 100644 --- a/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_shape_gallery.ipynb @@ -100,7 +100,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb b/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb index efbfcc76..9b8b3951 100644 --- a/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb +++ b/notebooks/python_module/02_eddy_identification/pet_sla_and_adt.ipynb @@ -223,7 +223,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb b/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb new file mode 100644 index 00000000..7fa04435 --- /dev/null +++ b/notebooks/python_module/02_eddy_identification/pet_statistics_on_identification.ipynb @@ -0,0 +1,202 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Stastics on identification files\n\nSome statistics on raw identification without any tracking\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import numpy as np\nfrom matplotlib import pyplot as plt\nfrom matplotlib.dates import date2num\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.observations.observation import EddiesObservations\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title)\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.95, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load demo sample and take only first year.\n\nReplace by a list of filename to apply on your own dataset.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:365]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Merge all identification dataset in one object\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "all_a = EddiesObservations.concatenate(\n [EddiesObservations.load_file(i) for i in file_objects]\n)" + 
] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define polygon bound\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0, x1, y0, y1 = 15, 20, 33, 38\nxs = np.array([[x0, x1, x1, x0, x0]], dtype=\"f8\")\nys = np.array([[y0, y0, y1, y1, y0]], dtype=\"f8\")\n# Polygon object is create to be usable by match function.\npolygon = dict(contour_lon_e=xs, contour_lat_e=ys, contour_lon_s=xs, contour_lat_s=ys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Geographic frequency of eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "step = 0.125\nax = start_axes(\"\")\n# Count pixel used for each contour\ng_a = all_a.grid_count(bins=((-10, 37, step), (30, 46, step)), intern=True)\nm = g_a.display(\n ax, cmap=\"terrain_r\", vmin=0, vmax=0.75, factor=1 / all_a.nb_days, name=\"count\"\n)\nax.plot(polygon[\"contour_lon_e\"][0], polygon[\"contour_lat_e\"][0], \"r\")\nupdate_axes(ax, m)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We use match function to count number of eddies which intersect the polygon defined previously.\n`p1_area` option allow to get in c_e/c_s output, precentage of area occupy by eddies in the polygon.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "i_e, j_e, c_e = all_a.match(polygon, p1_area=True, intern=False)\ni_s, j_s, c_s = all_a.match(polygon, p1_area=True, intern=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "dt = np.datetime64(\"1970-01-01\") - np.datetime64(\"1950-01-01\")\nkw_hist = dict(\n bins=date2num(np.arange(21900, 22300).astype(\"datetime64\") - dt), histtype=\"step\"\n)\n# translate julian day in datetime64\nt = all_a.time.astype(\"datetime64\") - dt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Count how many are in polygon\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(12, 6)).add_subplot(111)\nax.set_title(\"Different way to count eddies presence in a polygon\")\nax.set_ylabel(\"Count\")\nm = all_a.mask_from_polygons(((xs, ys),))\nax.hist(t[m], label=\"center in polygon\", **kw_hist)\nax.hist(t[i_s[c_s > 0]], label=\"intersect speed contour with polygon\", **kw_hist)\nax.hist(t[i_e[c_e > 0]], label=\"intersect extern contour with polygon\", **kw_hist)\nax.legend()\nax.set_xlim(np.datetime64(\"2010\"), np.datetime64(\"2011\"))\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Percent of are of interest occupy by eddies\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(12, 6)).add_subplot(111)\nax.set_title(\"Percent of polygon occupy by an anticyclonic eddy\")\nax.set_ylabel(\"Percent of polygon\")\nax.hist(t[i_s[c_s > 0]], weights=c_s[c_s > 0] * 100.0, label=\"speed contour\", **kw_hist)\nax.hist(t[i_e[c_e > 0]], weights=c_e[c_e > 0] * 100.0, label=\"effective contour\", **kw_hist)\nax.legend(), ax.set_ylim(0, 25)\nax.set_xlim(np.datetime64(\"2010\"), np.datetime64(\"2011\"))\nax.grid()" + ] + 
} + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb b/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb index 53725a05..90ee1722 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_advect.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "import re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, isnan, meshgrid, ones\n\nimport py_eddy_tracker.gui\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + "import re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, isnan, meshgrid, ones\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.observation import EddiesObservations" ] }, { @@ -80,7 +80,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(10, 5))\nax = fig.add_axes([0, 0, 1, 1], projection=\"full_axes\")\nax.set_xlim(19, 30), ax.set_ylim(31, 36.5), ax.grid()\nx, y = meshgrid(g.x_c, g.y_c)\na.filled(ax, facecolors=\"r\", alpha=0.1), c.filled(ax, facecolors=\"b\", alpha=0.1)\n_ = ax.quiver(x.T, y.T, g.grid(\"u\"), g.grid(\"v\"), scale=20)" + "fig = plt.figure(figsize=(10, 5))\nax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\nax.set_xlim(19, 30), ax.set_ylim(31, 36.5), ax.grid()\nx, y = meshgrid(g.x_c, g.y_c)\na.filled(ax, facecolors=\"r\", alpha=0.1), c.filled(ax, facecolors=\"b\", alpha=0.1)\n_ = ax.quiver(x.T, y.T, g.grid(\"u\"), g.grid(\"v\"), scale=20)" ] }, { @@ -91,7 +91,7 @@ }, "outputs": [], "source": [ - "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is use to create thumbnail which are not use but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" ] }, { @@ -127,7 +127,7 @@ }, "outputs": [], "source": [ - "kwargs = 
dict(frames=arange(51), interval=100)\nkw_p = dict(nb_step=2, time_step=21600)\nframe_t = kw_p[\"nb_step\"] * kw_p[\"time_step\"] / 86400.0" + "kwargs = dict(frames=arange(51), interval=100)\nkw_p = dict(u_name=\"u\", v_name=\"v\", nb_step=2, time_step=21600)\nframe_t = kw_p[\"nb_step\"] * kw_p[\"time_step\"] / 86400.0" ] }, { @@ -145,7 +145,7 @@ }, "outputs": [], "source": [ - "def anim_ax(**kw):\n t = 0\n fig = plt.figure(figsize=(10, 5), dpi=55)\n axes = fig.add_axes([0, 0, 1, 1], projection=\"full_axes\")\n axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid()\n a.filled(axes, facecolors=\"r\", alpha=0.1), c.filled(axes, facecolors=\"b\", alpha=0.1)\n line = axes.plot([], [], \"k\", **kw)[0]\n return fig, axes.text(21, 32.1, \"\"), line, t\n\n\ndef update(i_frame, t_step):\n global t\n x, y = p.__next__()\n t += t_step\n l.set_data(x, y)\n txt.set_text(f\"T0 + {t:.1f} days\")" + "def anim_ax(**kw):\n t = 0\n fig = plt.figure(figsize=(10, 5), dpi=55)\n axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\n axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid()\n a.filled(axes, facecolors=\"r\", alpha=0.1), c.filled(axes, facecolors=\"b\", alpha=0.1)\n line = axes.plot([], [], \"k\", **kw)[0]\n return fig, axes.text(21, 32.1, \"\"), line, t\n\n\ndef update(i_frame, t_step):\n global t\n x, y = p.__next__()\n t += t_step\n l.set_data(x, y)\n txt.set_text(f\"T0 + {t:.1f} days\")" ] }, { @@ -163,7 +163,7 @@ }, "outputs": [], "source": [ - "p = g.filament(x, y, \"u\", \"v\", **kw_p, filament_size=3)\nfig, txt, l, t = anim_ax(lw=0.5)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" + "p = g.filament(x, y, **kw_p, filament_size=3)\nfig, txt, l, t = anim_ax(lw=0.5)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" ] }, { @@ -181,7 +181,7 @@ }, "outputs": [], "source": [ - "p = g.advect(x, y, \"u\", \"v\", **kw_p)\nfig, txt, l, t = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" + "p = g.advect(x, y, **kw_p)\nfig, txt, l, t = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(frame_t,))" ] }, { @@ -199,7 +199,7 @@ }, "outputs": [], "source": [ - "p = g.advect(x, y, \"u\", \"v\", **kw_p, backward=True)\nfig, txt, l, _ = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,))" + "p = g.advect(x, y, **kw_p, backward=True)\nfig, txt, l, _ = anim_ax(ls=\"\", marker=\".\", markersize=1)\n_ = VideoAnimation(fig, update, **kwargs, fargs=(-frame_t,))" ] }, { @@ -224,7 +224,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure()\nax = fig.add_subplot(111)\nkw = dict(\n bins=arange(0, 50, 0.001),\n cumulative=True,\n weights=ones(x0.shape) / x0.shape[0] * 100.0,\n histtype=\"step\",\n)\nfor time_step in (10800, 21600, 43200, 86400):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(50 * 86400 / time_step), time_step=time_step)\n g.advect(x, y, \"u\", \"v\", **kw_advect).__next__()\n g.advect(x, y, \"u\", \"v\", **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"{86400. 
/ time_step:.0f} time step by day\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\"Distance after 50 days forward and 50 days backward\")\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than\")" + "fig = plt.figure()\nax = fig.add_subplot(111)\nkw = dict(\n bins=arange(0, 50, 0.001),\n cumulative=True,\n weights=ones(x0.shape) / x0.shape[0] * 100.0,\n histtype=\"step\",\n)\nfor time_step in (10800, 21600, 43200, 86400):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(50 * 86400 / time_step), time_step=time_step, u_name=\"u\", v_name=\"v\")\n g.advect(x, y, **kw_advect).__next__()\n g.advect(x, y, **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"{86400. / time_step:.0f} time step by day\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\"Distance after 50 days forward and 50 days backward\")\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than\")" ] }, { @@ -242,7 +242,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure()\nax = fig.add_subplot(111)\ntime_step = 10800\nfor duration in (5, 50, 100):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(duration * 86400 / time_step), time_step=time_step)\n g.advect(x, y, \"u\", \"v\", **kw_advect).__next__()\n g.advect(x, y, \"u\", \"v\", **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"Time duration {duration} days\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\n \"Distance after N days forward and N days backward\\nwith a time step of 1/8 days\"\n)\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than \")" + "fig = plt.figure()\nax = fig.add_subplot(111)\ntime_step = 10800\nfor duration in (5, 50, 100):\n x, y = x0.copy(), y0.copy()\n kw_advect = dict(nb_step=int(duration * 86400 / time_step), time_step=time_step, u_name=\"u\", v_name=\"v\")\n g.advect(x, y, **kw_advect).__next__()\n g.advect(x, y, **kw_advect, backward=True).__next__()\n d = ((x - x0) ** 2 + (y - y0) ** 2) ** 0.5\n ax.hist(d, **kw, label=f\"Time duration {duration} days\")\nax.set_xlim(0, 0.25), ax.set_ylim(0, 100), ax.legend(loc=\"lower right\"), ax.grid()\nax.set_title(\n \"Distance after N days forward and N days backward\\nwith a time step of 1/8 days\"\n)\nax.set_xlabel(\"Distance between original position and final position (in degrees)\")\n_ = ax.set_ylabel(\"Percent of particles with distance lesser than \")" ] } ], @@ -262,7 +262,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb b/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb index 74a266c2..2d6a7d3a 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_filter.ipynb @@ -215,7 +215,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git 
a/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb b/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb index c9bca31e..f30076fa 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_hide_pixel_out_eddies.ipynb @@ -111,7 +111,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb b/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb index 6cef91bc..cbe6de64 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_lavd.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "import re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, meshgrid, zeros\n\nimport py_eddy_tracker.gui\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.observations.observation import EddiesObservations" + "import re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, meshgrid, zeros\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import RegularGridDataset\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.observation import EddiesObservations" ] }, { @@ -37,7 +37,7 @@ }, "outputs": [], "source": [ - "def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(16, 9), dpi=dpi)\n ax = fig.add_axes([0, 0, 1, 1], projection=\"full_axes\")\n ax.set_xlim(0, 32), ax.set_ylim(28, 46)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"Vorticity integration along trajectory at initial position\")\n return cb\n\n\nkw_vorticity = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")" + "def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(16, 9), dpi=dpi)\n ax = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\n ax.set_xlim(0, 32), ax.set_ylim(28, 46)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"Vorticity integration along trajectory at initial position\")\n return cb\n\n\nkw_vorticity = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")" ] }, { @@ -48,7 +48,7 @@ }, "outputs": [], "source": [ - "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is use to create thumbnail which are not use but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n 
\"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" ] }, { @@ -102,7 +102,7 @@ }, "outputs": [], "source": [ - "step = 1 / 32\nx_g, y_g = arange(0, 36, step), arange(28, 46, step)\nx, y = meshgrid(x_g, y_g)\noriginal_shape = x.shape\nx, y = x.reshape(-1), y.reshape(-1)\nprint(f\"{len(x)} particles advected\")\n# A frame every 8h\nstep_by_day = 3\n# Compute step of advection every 4h\nnb_step = 2\nkw_p = dict(nb_step=nb_step, time_step=86400 / step_by_day / nb_step)\n# Start a generator which at each iteration return new position at next time step\nparticule = g.advect(x, y, \"u\", \"v\", **kw_p, rk4=True)" + "step = 1 / 32\nx_g, y_g = arange(0, 36, step), arange(28, 46, step)\nx, y = meshgrid(x_g, y_g)\noriginal_shape = x.shape\nx, y = x.reshape(-1), y.reshape(-1)\nprint(f\"{len(x)} particles advected\")\n# A frame every 8h\nstep_by_day = 3\n# Compute step of advection every 4h\nnb_step = 2\nkw_p = dict(nb_step=nb_step, time_step=86400 / step_by_day / nb_step, u_name=\"u\", v_name=\"v\")\n# Start a generator which at each iteration return new position at next time step\nparticule = g.advect(x, y, **kw_p, rk4=True)" ] }, { @@ -138,7 +138,7 @@ }, "outputs": [], "source": [ - "def update(i_frame):\n global lavd, i\n i += 1\n x, y = particule.__next__()\n # Interp vorticity on new_position\n lavd += abs(g.interp(\"vort\", x, y).reshape(original_shape) * 1 / nb_time)\n txt.set_text(f\"T0 + {i / step_by_day:.2f} days of advection\")\n pcolormesh.set_array(lavd / i * nb_time)\n return pcolormesh, txt\n\n\nkw_video = dict(frames=arange(nb_time), interval=1000.0 / step_by_day / 2, blit=True)\nfig, ax, txt = start_ax(dpi=60)\nx_g_, y_g_ = arange(0 - step / 2, 36 + step / 2, step), arange(\n 28 - step / 2, 46 + step / 2, step\n)\n# pcolorfast will be faster than pcolormesh, we could use pcolorfast due to x and y are regular\npcolormesh = ax.pcolorfast(x_g_, y_g_, lavd, **kw_vorticity)\nupdate_axes(ax, pcolormesh)\n_ = VideoAnimation(ax.figure, update, **kw_video)" + "def update(i_frame):\n global lavd, i\n i += 1\n x, y = particule.__next__()\n # Interp vorticity on new_position\n lavd += abs(g.interp(\"vort\", x, y).reshape(original_shape) * 1 / nb_time)\n txt.set_text(f\"T0 + {i / step_by_day:.2f} days of advection\")\n pcolormesh.set_array(lavd / i * nb_time)\n return pcolormesh, txt\n\n\nkw_video = dict(frames=arange(nb_time), interval=1000.0 / step_by_day / 2, blit=True)\nfig, ax, txt = start_ax(dpi=60)\nx_g_, y_g_ = (\n arange(0 - step / 2, 36 + step / 2, step),\n arange(28 - step / 2, 46 + step / 2, step),\n)\n# pcolorfast will be faster than pcolormesh, we could use pcolorfast due to x and y are regular\npcolormesh = ax.pcolorfast(x_g_, y_g_, lavd, **kw_vorticity)\nupdate_axes(ax, pcolormesh)\n_ = VideoAnimation(ax.figure, update, **kw_video)" ] }, { @@ -163,7 +163,7 @@ }, "outputs": [], "source": [ - "lavd = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"),\n datas=dict(\n lavd=lavd.T,\n lon=x_g,\n lat=y_g,\n ),\n centered=True,\n)" + "lavd = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"), datas=dict(lavd=lavd.T, lon=x_g, 
lat=y_g), centered=True\n)" ] }, { @@ -201,7 +201,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb b/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb index b410be0a..ca4998ee 100644 --- a/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb +++ b/notebooks/python_module/06_grid_manipulation/pet_okubo_weiss.ipynb @@ -201,7 +201,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb b/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb index 63cd36dc..d4cdb187 100644 --- a/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb +++ b/notebooks/python_module/07_cube_manipulation/pet_cube.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Time advection\n\nExample which use CMEMS surface current with a Runge-Kutta 4 algorithm to advect particles.\n" + "\nTime advection\n==============\n\nExample which use CMEMS surface current with a Runge-Kutta 4 algorithm to advect particles.\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "# sphinx_gallery_thumbnail_number = 2\nimport re\nfrom datetime import datetime, timedelta\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, isnan, meshgrid, ones\n\nimport py_eddy_tracker.gui\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\n\nstart_logger().setLevel(\"ERROR\")" + "# sphinx_gallery_thumbnail_number = 2\nimport re\nfrom datetime import datetime, timedelta\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, isnan, meshgrid, ones\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\nfrom py_eddy_tracker.gui import GUI_AXES\n\nstart_logger().setLevel(\"ERROR\")" ] }, { @@ -37,14 +37,14 @@ }, "outputs": [], "source": [ - "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is use to create thumbnail which are not use but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 
Data\nLoad Input time grid ADT\n\n" + "Data\n----\nLoad Input time grid ADT\n\n" ] }, { @@ -62,7 +62,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Anim\nParticles setup\n\n" + "Anim\n----\nParticles setup\n\n" ] }, { @@ -91,7 +91,7 @@ }, "outputs": [], "source": [ - "def anim_ax(**kw):\n fig = plt.figure(figsize=(10, 5), dpi=55)\n axes = fig.add_axes([0, 0, 1, 1], projection=\"full_axes\")\n axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid()\n line = axes.plot([], [], \"k\", **kw)[0]\n return fig, axes.text(21, 32.1, \"\"), line\n\n\ndef update(_):\n tt, xt, yt = f.__next__()\n mappable.set_data(xt, yt)\n d = timedelta(tt / 86400.0) + datetime(1950, 1, 1)\n txt.set_text(f\"{d:%Y/%m/%d-%H}\")" + "def anim_ax(**kw):\n fig = plt.figure(figsize=(10, 5), dpi=55)\n axes = fig.add_axes([0, 0, 1, 1], projection=GUI_AXES)\n axes.set_xlim(19, 30), axes.set_ylim(31, 36.5), axes.grid()\n line = axes.plot([], [], \"k\", **kw)[0]\n return fig, axes.text(21, 32.1, \"\"), line\n\n\ndef update(_):\n tt, xt, yt = f.__next__()\n mappable.set_data(xt, yt)\n d = timedelta(tt / 86400.0) + datetime(1950, 1, 1)\n txt.set_text(f\"{d:%Y/%m/%d-%H}\")" ] }, { @@ -109,7 +109,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Particules stat\nTime_step settings\n^^^^^^^^^^^^^^^^^^\nDummy experiment to test advection precision, we run particles 50 days forward and backward with different time step\nand we measure distance between new positions and original positions.\n\n" + "Particules stat\n---------------\nTime_step settings\n^^^^^^^^^^^^^^^^^^\nDummy experiment to test advection precision, we run particles 50 days forward and backward with different time step\nand we measure distance between new positions and original positions.\n\n" ] }, { @@ -127,7 +127,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Time duration\nWe keep same time_step but change time duration\n\n" + "Time duration\n^^^^^^^^^^^^^\nWe keep same time_step but change time duration\n\n" ] }, { diff --git a/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb b/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb index e821df6d..6f52e750 100644 --- a/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb +++ b/notebooks/python_module/07_cube_manipulation/pet_fsle_med.ipynb @@ -26,14 +26,14 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom numba import njit\nfrom numpy import arange, arctan2, empty, isnan, log2, ma, meshgrid, ones, pi, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset\n\nstart_logger().setLevel(\"ERROR\")" + "from matplotlib import pyplot as plt\nfrom numba import njit\nfrom numpy import arange, arctan2, empty, isnan, log, ma, meshgrid, ones, pi, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset\n\nstart_logger().setLevel(\"ERROR\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## ADT in med\n\n" + "## ADT in med\n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_cube` method is\nmade for data stores in time cube, you could use also\n:py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` method to\nload data-cube from multiple file.\n\n" ] }, { @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "@njit(cache=True, 
fastmath=True)\ndef check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6):\n \"\"\"\n Check if distance between eastern or northern particle to center particle is bigger than `dist_max`\n \"\"\"\n nb_p = x.shape[0] // 3\n delta = dist_max ** 2\n for i in range(nb_p):\n i0 = i * 3\n i_n = i0 + 1\n i_e = i0 + 2\n # If particle already set, we skip\n if m[i0] or m[i_n] or m[i_e]:\n continue\n # Distance with north\n dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n]\n dn = dxn ** 2 + dyn ** 2\n # Distance with east\n dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e]\n de = dxe ** 2 + dye ** 2\n\n if dn >= delta or de >= delta:\n s1 = dn + de\n at1 = 2 * (dxe * dxn + dye * dyn)\n at2 = de - dn\n s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * (\n (dxn - dye) ** 2 + (dxe + dyn) ** 2\n )\n flse[i] = 1 / (2 * dt) * log2(1 / (2 * dist_init ** 2) * (s1 + s2 ** 0.5))\n theta[i] = arctan2(at1, at2 + s2) * 180 / pi\n # To know where value are set\n m_set[i] = False\n # To stop particle advection\n m[i0], m[i_n], m[i_e] = True, True, True\n\n\n@njit(cache=True)\ndef build_triplet(x, y, step=0.02):\n \"\"\"\n Triplet building for each position we add east and north point with defined step\n \"\"\"\n nb_x = x.shape[0]\n x_ = empty(nb_x * 3, dtype=x.dtype)\n y_ = empty(nb_x * 3, dtype=y.dtype)\n for i in range(nb_x):\n i0 = i * 3\n i_n, i_e = i0 + 1, i0 + 2\n x__, y__ = x[i], y[i]\n x_[i0], y_[i0] = x__, y__\n x_[i_n], y_[i_n] = x__, y__ + step\n x_[i_e], y_[i_e] = x__ + step, y__\n return x_, y_" + "@njit(cache=True, fastmath=True)\ndef check_p(x, y, flse, theta, m_set, m, dt, dist_init=0.02, dist_max=0.6):\n \"\"\"\n Check if distance between eastern or northern particle to center particle is bigger than `dist_max`\n \"\"\"\n nb_p = x.shape[0] // 3\n delta = dist_max**2\n for i in range(nb_p):\n i0 = i * 3\n i_n = i0 + 1\n i_e = i0 + 2\n # If particle already set, we skip\n if m[i0] or m[i_n] or m[i_e]:\n continue\n # Distance with north\n dxn, dyn = x[i0] - x[i_n], y[i0] - y[i_n]\n dn = dxn**2 + dyn**2\n # Distance with east\n dxe, dye = x[i0] - x[i_e], y[i0] - y[i_e]\n de = dxe**2 + dye**2\n\n if dn >= delta or de >= delta:\n s1 = dn + de\n at1 = 2 * (dxe * dxn + dye * dyn)\n at2 = de - dn\n s2 = ((dxn + dye) ** 2 + (dxe - dyn) ** 2) * (\n (dxn - dye) ** 2 + (dxe + dyn) ** 2\n )\n flse[i] = 1 / (2 * dt) * log(1 / (2 * dist_init**2) * (s1 + s2**0.5))\n theta[i] = arctan2(at1, at2 + s2) * 180 / pi\n # To know where value are set\n m_set[i] = False\n # To stop particle advection\n m[i0], m[i_n], m[i_e] = True, True, True\n\n\n@njit(cache=True)\ndef build_triplet(x, y, step=0.02):\n \"\"\"\n Triplet building for each position we add east and north point with defined step\n \"\"\"\n nb_x = x.shape[0]\n x_ = empty(nb_x * 3, dtype=x.dtype)\n y_ = empty(nb_x * 3, dtype=y.dtype)\n for i in range(nb_x):\n i0 = i * 3\n i_n, i_e = i0 + 1, i0 + 2\n x__, y__ = x[i], y[i]\n x_[i0], y_[i0] = x__, y__\n x_[i_n], y_[i_n] = x__, y__ + step\n x_[i_e], y_[i_e] = x__ + step, y__\n return x_, y_" ] }, { @@ -116,7 +116,7 @@ }, "outputs": [], "source": [ - "# Array to compute fsle\nfsle = zeros(x0.shape[0], dtype=\"f4\")\ntheta = zeros(x0.shape[0], dtype=\"f4\")\nmask = ones(x0.shape[0], dtype=\"f4\")\nx, y = build_triplet(x0, y0, dist_init)\nused = zeros(x.shape[0], dtype=\"bool\")\n\n# advection generator\nkw = dict(t_init=t0, nb_step=1, backward=backward, mask_particule=used)\np = c.advect(x, y, \"u\", \"v\", time_step=86400 / time_step_by_days, **kw)\n\n# We check at each step of advection if particle distance is over 
`dist_max`\nfor i in range(time_step_by_days * nb_days):\n t, xt, yt = p.__next__()\n dt = t / 86400.0 - t0\n check_p(xt, yt, fsle, theta, mask, used, dt, dist_max=dist_max, dist_init=dist_init)\n\n# Get index with original_position\ni = ((x0 - x0_) / step_grid_out).astype(\"i4\")\nj = ((y0 - y0_) / step_grid_out).astype(\"i4\")\nfsle_ = empty(grid_shape, dtype=\"f4\")\ntheta_ = empty(grid_shape, dtype=\"f4\")\nmask_ = ones(grid_shape, dtype=\"bool\")\nfsle_[i, j] = fsle\ntheta_[i, j] = theta\nmask_[i, j] = mask\n# Create a grid object\nfsle_custom = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"),\n datas=dict(\n fsle=ma.array(fsle_, mask=mask_),\n theta=ma.array(theta_, mask=mask_),\n lon=lon_p,\n lat=lat_p,\n ),\n centered=True,\n)" + "# Array to compute fsle\nfsle = zeros(x0.shape[0], dtype=\"f4\")\ntheta = zeros(x0.shape[0], dtype=\"f4\")\nmask = ones(x0.shape[0], dtype=\"f4\")\nx, y = build_triplet(x0, y0, dist_init)\nused = zeros(x.shape[0], dtype=\"bool\")\n\n# advection generator\nkw = dict(t_init=t0, nb_step=1, backward=backward, mask_particule=used, u_name=\"u\", v_name=\"v\")\np = c.advect(x, y, time_step=86400 / time_step_by_days, **kw)\n\n# We check at each step of advection if particle distance is over `dist_max`\nfor i in range(time_step_by_days * nb_days):\n t, xt, yt = p.__next__()\n dt = t / 86400.0 - t0\n check_p(xt, yt, fsle, theta, mask, used, dt, dist_max=dist_max, dist_init=dist_init)\n\n# Get index with original_position\ni = ((x0 - x0_) / step_grid_out).astype(\"i4\")\nj = ((y0 - y0_) / step_grid_out).astype(\"i4\")\nfsle_ = empty(grid_shape, dtype=\"f4\")\ntheta_ = empty(grid_shape, dtype=\"f4\")\nmask_ = ones(grid_shape, dtype=\"bool\")\nfsle_[i, j] = fsle\ntheta_[i, j] = theta\nmask_[i, j] = mask\n# Create a grid object\nfsle_custom = RegularGridDataset.with_array(\n coordinates=(\"lon\", \"lat\"),\n datas=dict(\n fsle=ma.array(fsle_, mask=mask_),\n theta=ma.array(theta_, mask=mask_),\n lon=lon_p,\n lat=lat_p,\n ),\n centered=True,\n)" ] }, { @@ -134,7 +134,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(13, 5), dpi=150)\nax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\nax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\nax.set_aspect(\"equal\")\nax.set_title(\"Finite size lyapunov exponent\", weight=\"bold\")\nkw = dict(cmap=\"viridis_r\", vmin=-15, vmax=0)\nm = fsle_custom.display(ax, 1 / fsle_custom.grid(\"fsle\"), **kw)\nax.grid()\n_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9]))" + "fig = plt.figure(figsize=(13, 5), dpi=150)\nax = fig.add_axes([0.03, 0.03, 0.90, 0.94])\nax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\nax.set_aspect(\"equal\")\nax.set_title(\"Finite size lyapunov exponent\", weight=\"bold\")\nkw = dict(cmap=\"viridis_r\", vmin=-20, vmax=0)\nm = fsle_custom.display(ax, 1 / fsle_custom.grid(\"fsle\"), **kw)\nax.grid()\n_ = plt.colorbar(m, cax=fig.add_axes([0.94, 0.05, 0.01, 0.9]))" ] }, { @@ -172,7 +172,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb b/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb index e123abe4..708d7024 100644 --- a/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb +++ b/notebooks/python_module/07_cube_manipulation/pet_lavd_detection.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom numpy 
import arange, isnan, ma, meshgrid, zeros\n\nimport py_eddy_tracker.gui\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset\n\nstart_logger().setLevel(\"ERROR\")" + "from datetime import datetime\n\nfrom matplotlib import pyplot as plt\nfrom numpy import arange, isnan, ma, meshgrid, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection, RegularGridDataset\nfrom py_eddy_tracker.gui import GUI_AXES\n\nstart_logger().setLevel(\"ERROR\")" ] }, { @@ -48,7 +48,7 @@ }, "outputs": [], "source": [ - "def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(12, 5), dpi=dpi)\n ax = fig.add_axes([0.05, 0.08, 0.9, 0.9], projection=\"full_axes\")\n ax.set_xlim(-6, 36), ax.set_ylim(31, 45)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"LAVD at initial position\")\n return cb\n\n\nkw_lavd = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")" + "def start_ax(title=\"\", dpi=90):\n fig = plt.figure(figsize=(12, 5), dpi=dpi)\n ax = fig.add_axes([0.05, 0.08, 0.9, 0.9], projection=GUI_AXES)\n ax.set_xlim(-6, 36), ax.set_ylim(31, 45)\n ax.set_title(title)\n return fig, ax, ax.text(3, 32, \"\", fontsize=20)\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n cb = plt.colorbar(\n mappable,\n cax=ax.figure.add_axes([0.05, 0.1, 0.9, 0.01]),\n orientation=\"horizontal\",\n )\n cb.set_label(\"LAVD at initial position\")\n return cb\n\n\nkw_lavd = dict(vmin=0, vmax=2e-5, cmap=\"viridis\")" ] }, { @@ -84,7 +84,7 @@ }, "outputs": [], "source": [ - "# Time properties, for example with advection only 25 days\nnb_days, step_by_day = 25, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day)\nt0 = 20236\nt0_grid = c[t0]\n# Geographic properties, we use a coarser resolution for time consuming reasons\nstep = 1 / 32.0\nx_g, y_g = arange(-6, 36, step), arange(30, 46, step)\nx0, y0 = meshgrid(x_g, y_g)\noriginal_shape = x0.shape\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\n# Get all particles in defined area\nm = ~isnan(t0_grid.interp(\"vort\", x0, y0))\nx0, y0 = x0[m], y0[m]\nprint(f\"{x0.size} particles advected\")\n# Gridded mask\nm = m.reshape(original_shape)" + "# Time properties, for example with advection only 25 days\nnb_days, step_by_day = 25, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day, u_name=\"u\", v_name=\"v\")\nt0 = 20236\nt0_grid = c[t0]\n# Geographic properties, we use a coarser resolution for time consuming reasons\nstep = 1 / 32.0\nx_g, y_g = arange(-6, 36, step), arange(30, 46, step)\nx0, y0 = meshgrid(x_g, y_g)\noriginal_shape = x0.shape\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\n# Get all particles in defined area\nm = ~isnan(t0_grid.interp(\"vort\", x0, y0))\nx0, y0 = x0[m], y0[m]\nprint(f\"{x0.size} particles advected\")\n# Gridded mask\nm = m.reshape(original_shape)" ] }, { @@ -102,7 +102,7 @@ }, "outputs": [], "source": [ - "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), \"u\", \"v\", t_init=t0, **kw_p)\nfor _ in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD 
result in a standard py eddy tracker grid\nlavd_forward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection\")\nmappable = lavd_forward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), t_init=t0, **kw_p)\nfor _ in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection\")\nmappable = lavd_forward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" ] }, { @@ -120,7 +120,7 @@ }, "outputs": [], "source": [ - "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), \"u\", \"v\", t_init=t0, backward=True, **kw_p)\nfor i in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection\")\nmappable = lavd_backward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = c.advect(x0.copy(), y0.copy(), t_init=t0, backward=True, **kw_p)\nfor i in range(nb_time):\n t, x, y = p.__next__()\n lavd_ += abs(c.interp(\"vort\", t / 86400.0, x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection\")\nmappable = lavd_backward.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" ] }, { @@ -138,7 +138,7 @@ }, "outputs": [], "source": [ - "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), \"u\", \"v\", **kw_p)\nfor _ in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection on a static velocity field\")\nmappable = lavd_forward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), **kw_p)\nfor _ in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_forward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a forward advection on a static velocity field\")\nmappable = lavd_forward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" ] }, { @@ -156,7 +156,7 @@ }, "outputs": [], "source": [ - "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), \"u\", \"v\", backward=True, **kw_p)\nfor i in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = 
start_ax(\"LAVD with a backward advection on a static velocity field\")\nmappable = lavd_backward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" + "lavd = zeros(original_shape)\nlavd_ = lavd[m]\np = t0_grid.advect(x0.copy(), y0.copy(), backward=True, **kw_p)\nfor i in range(nb_time):\n x, y = p.__next__()\n lavd_ += abs(t0_grid.interp(\"vort\", x, y))\nlavd[m] = lavd_ / nb_time\n# Put LAVD result in a standard py eddy tracker grid\nlavd_backward_static = LAVDGrid.from_(x_g, y_g, ma.array(lavd, mask=~m).T)\n# Display\nfig, ax, _ = start_ax(\"LAVD with a backward advection on a static velocity field\")\nmappable = lavd_backward_static.display(ax, \"lavd\", **kw_lavd)\n_ = update_axes(ax, mappable)" ] }, { @@ -194,7 +194,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb b/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb new file mode 100644 index 00000000..b92c4d21 --- /dev/null +++ b/notebooks/python_module/07_cube_manipulation/pet_particles_drift.ipynb @@ -0,0 +1,126 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Build path of particle drifting\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numpy import arange, meshgrid\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\n\nstart_logger().setLevel(\"ERROR\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load data cube\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c = GridCollection.from_netcdf_cube(\n get_demo_path(\"dt_med_allsat_phy_l4_2005T2.nc\"),\n \"longitude\",\n \"latitude\",\n \"time\",\n unset=True\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Advection properties\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "nb_days, step_by_day = 10, 6\nnb_time = step_by_day * nb_days\nkw_p = dict(nb_step=1, time_step=86400 / step_by_day)\nt0 = 20210" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get paths\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "x0, y0 = meshgrid(arange(32, 35, 0.5), arange(32.5, 34.5, 0.5))\nx0, y0 = x0.reshape(-1), y0.reshape(-1)\nt, x, y = c.path(x0, y0, h_name=\"adt\", t_init=t0, **kw_p, nb_time=nb_time)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot paths\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = plt.figure(figsize=(9, 6)).add_subplot(111, aspect=\"equal\")\nax.plot(x0, y0, \"k.\", ms=20)\nax.plot(x, y, lw=3)\nax.set_title(\"10 days particle paths\")\nax.set_xlim(31, 35), ax.set_ylim(32, 34.5)\nax.grid()" + ] + } + ], + "metadata": { + "kernelspec": 
{ + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb index bf924b36..6e43e9a4 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_display_field.ipynb @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb index 1af7b49a..c98e53f0 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_display_track.ipynb @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb new file mode 100644 index 00000000..0681c0fc --- /dev/null +++ b/notebooks/python_module/08_tracking_manipulation/pet_how_to_use_correspondances.ipynb @@ -0,0 +1,155 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Correspondances\n\nCorrespondances is a mechanism to intend to continue tracking with new detection\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import logging" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom netCDF4 import Dataset\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.featured_tracking.area_tracker import AreaTracker\n\n# In order to hide some warning\nimport py_eddy_tracker.observations.observation\nfrom py_eddy_tracker.tracking import Correspondances\n\npy_eddy_tracker.observations.observation._display_check_warning = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def plot_eddy(ed):\n fig = plt.figure(figsize=(10, 5))\n ax = fig.add_axes([0.05, 0.03, 0.90, 0.94])\n ed.plot(ax, ref=-10, marker=\"x\")\n lc = ed.display_color(ax, field=ed.time, ref=-10, intern=True)\n plt.colorbar(lc).set_label(\"Time in Julian days (from 1950/01/01)\")\n ax.set_xlim(4.5, 8), ax.set_ylim(36.8, 38.3)\n ax.set_aspect(\"equal\")\n ax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get remote data, we will keep only 20 first days,\n`get_remote_demo_sample` function is only to get demo dataset, in your own case give a list of identification filename\nand don't mix cyclonic and anticyclonic files.\n\n" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "file_objects = get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic_2010_2011_2012\"\n)[:20]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We run a traking with a tracker which use contour overlap, on 10 first time step\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c_first_run = Correspondances(\n datasets=file_objects[:10], class_method=AreaTracker, virtual=4\n)\nstart_logger().setLevel(\"INFO\")\nc_first_run.track()\nstart_logger().setLevel(\"WARNING\")\nwith Dataset(\"correspondances.nc\", \"w\") as h:\n c_first_run.to_netcdf(h)\n# Next step are done only to build atlas and display it\nc_first_run.prepare_merging()\n\n# We have now an eddy object\neddies_area_tracker = c_first_run.merge(raw_data=False)\neddies_area_tracker.virtual[:] = eddies_area_tracker.time == 0\neddies_area_tracker.filled_by_interpolation(eddies_area_tracker.virtual == 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot from first ten days\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "plot_eddy(eddies_area_tracker)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Restart from previous run\nWe give all filenames, the new one and filename from previous run\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "c_second_run = Correspondances(\n datasets=file_objects[:20],\n # This parameter must be identical in each run\n class_method=AreaTracker,\n virtual=4,\n # Previous saved correspondancs\n previous_correspondance=\"correspondances.nc\",\n)\nstart_logger().setLevel(\"INFO\")\nc_second_run.track()\nstart_logger().setLevel(\"WARNING\")\nc_second_run.prepare_merging()\n# We have now another eddy object\neddies_area_tracker_extend = c_second_run.merge(raw_data=False)\neddies_area_tracker_extend.virtual[:] = eddies_area_tracker_extend.time == 0\neddies_area_tracker_extend.filled_by_interpolation(\n eddies_area_tracker_extend.virtual == 1\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Plot with time extension\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "plot_eddy(eddies_area_tracker_extend)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb index 2749f7e9..95595a7a 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_one_track.ipynb @@ -93,7 +93,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + 
"version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb index e8871283..d0a2e5b0 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_run_a_tracking.ipynb @@ -154,7 +154,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb index 5ba0d481..8e64b680 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_select_track_across_area.ipynb @@ -100,7 +100,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb index 041c8987..08364d16 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_track_anim.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Track animation\n\nRun in a terminal this script, which allow to watch eddy evolution\n" + "\nTrack animation\n===============\n\nRun in a terminal this script, which allow to watch eddy evolution.\n\nYou could use also *EddyAnim* script to display/save animation.\n" ] }, { diff --git a/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb b/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb index 9f77dbae..1fc4d082 100644 --- a/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb +++ b/notebooks/python_module/08_tracking_manipulation/pet_track_anim_matplotlib_animation.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Track animation with standard matplotlib\n\nRun in a terminal this script, which allow to watch eddy evolution\n" + "\nTrack animation with standard matplotlib\n========================================\n\nRun in a terminal this script, which allow to watch eddy evolution.\n\nYou could use also *EddyAnim* script to display/save animation.\n" ] }, { @@ -37,7 +37,7 @@ }, "outputs": [], "source": [ - "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is use to create thumbnail which are not use but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n 
)\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" ] }, { @@ -98,4 +98,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb index d9a2ef2b..635c6b5a 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_birth_and_death.ipynb @@ -144,7 +144,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb index b6bb15bd..753cd625 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_center_count.ipynb @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb index 3e884552..df495703 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_geographic_stats.ipynb @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb index 85e32c6a..9f06e010 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_groups.ipynb @@ -136,7 +136,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb index 851c6ca4..81809d8b 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_histo.ipynb @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb index 4a3ff0af..ed8c0295 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_lifetime.ipynb @@ -82,7 +82,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_normalised_lifetime.ipynb 
b/notebooks/python_module/10_tracking_diagnostics/pet_normalised_lifetime.ipynb new file mode 100644 index 00000000..f9fb474f --- /dev/null +++ b/notebooks/python_module/10_tracking_diagnostics/pet_normalised_lifetime.ipynb @@ -0,0 +1,119 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nNormalised Eddy Lifetimes\n=========================\n\nExample from Evan Mason\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\nfrom numba import njit\nfrom numpy import interp, linspace, zeros\nfrom py_eddy_tracker_sample import get_demo_path\n\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "@njit(cache=True)\ndef sum_profile(x_new, y, out):\n \"\"\"Will sum all interpolated given array\"\"\"\n out += interp(x_new, linspace(0, 1, y.size), y)\n\n\nclass MyObs(TrackEddiesObservations):\n def eddy_norm_lifetime(self, name, nb, factor=1):\n \"\"\"\n :param str,array name: Array or field name\n :param int nb: size of output array\n \"\"\"\n y = self.parse_varname(name)\n x = linspace(0, 1, nb)\n out = zeros(nb, dtype=y.dtype)\n nb_track = 0\n for i, b0, b1 in self.iter_on(\"track\"):\n y_ = y[i]\n size_ = y_.size\n if size_ == 0:\n continue\n sum_profile(x, y_, out)\n nb_track += 1\n return x, out / nb_track * factor" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load atlas\n----------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "kw = dict(include_vars=(\"speed_radius\", \"amplitude\", \"track\"))\na = MyObs.load_file(\n get_demo_path(\"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"), **kw\n)\nc = MyObs.load_file(get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\"), **kw)\n\nnb_max_a = a.nb_obs_by_track.max()\nnb_max_c = c.nb_obs_by_track.max()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Compute normalised lifetime\n---------------------------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# Radius\nAC_radius = a.eddy_norm_lifetime(\"speed_radius\", nb=nb_max_a, factor=1e-3)\nCC_radius = c.eddy_norm_lifetime(\"speed_radius\", nb=nb_max_c, factor=1e-3)\n# Amplitude\nAC_amplitude = a.eddy_norm_lifetime(\"amplitude\", nb=nb_max_a, factor=1e2)\nCC_amplitude = c.eddy_norm_lifetime(\"amplitude\", nb=nb_max_c, factor=1e2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Figure\n------\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig, (ax0, ax1) = plt.subplots(nrows=2, figsize=(8, 6))\n\nax0.set_title(\"Normalised Mean Radius\")\nax0.plot(*AC_radius), ax0.plot(*CC_radius)\nax0.set_ylabel(\"Radius (km)\"), ax0.grid()\nax0.set_xlim(0, 1), ax0.set_ylim(0, None)\n\nax1.set_title(\"Normalised Mean Amplitude\")\nax1.plot(*AC_amplitude, label=\"AC\"), ax1.plot(*CC_amplitude, label=\"CC\")\nax1.set_ylabel(\"Amplitude (cm)\"), ax1.grid(), ax1.legend()\n_ 
= ax1.set_xlim(0, 1), ax1.set_ylim(0, None)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb index 81bed372..23f830d6 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_pixel_used.ipynb @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb b/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb index e0d1f2d2..9792f8f4 100644 --- a/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb +++ b/notebooks/python_module/10_tracking_diagnostics/pet_propagation.ipynb @@ -118,7 +118,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb b/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb index 05b0413c..b30682a1 100644 --- a/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb +++ b/notebooks/python_module/12_external_data/pet_SST_collocation.ipynb @@ -226,7 +226,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/12_external_data/pet_drifter_loopers.ipynb b/notebooks/python_module/12_external_data/pet_drifter_loopers.ipynb new file mode 100644 index 00000000..7ba30914 --- /dev/null +++ b/notebooks/python_module/12_external_data/pet_drifter_loopers.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\nColocate looper with eddy from altimetry\n========================================\n\nAll loopers data used in this example are a subset from the dataset described in this article\n[Lumpkin, R. 
: Global characteristics of coherent vortices from surface drifter trajectories](https://doi.org/10.1002/2015JC011435)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import re\n\nimport numpy as np\nimport py_eddy_tracker_sample\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\ndef start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], aspect=\"equal\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load eddies dataset\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "cyclonic_eddies = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\"eddies_med_adt_allsat_dt2018/Cyclonic.zarr\")\n)\nanticyclonic_eddies = TrackEddiesObservations.load_file(\n py_eddy_tracker_sample.get_demo_path(\n \"eddies_med_adt_allsat_dt2018/Anticyclonic.zarr\"\n )\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load loopers dataset\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "loopers_med = TrackEddiesObservations.load_file(\n data.get_demo_path(\"loopers_lumpkin_med.nc\")\n)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Global view\n===========\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = start_axes(\"All drifters available in Med from Lumpkin dataset\")\nloopers_med.plot(ax, lw=0.5, color=\"r\", ref=-10)\nupdate_axes(ax)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One segment of drifter\n======================\n\nGet a drifter segment (the indexes used have no correspondance with the original dataset).\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "looper = loopers_med.extract_ids((3588,))\nfig = plt.figure(figsize=(16, 6))\nax = fig.add_subplot(111, aspect=\"equal\")\nlooper.plot(ax, lw=0.5, label=\"Original position of drifter\")\nlooper_filtered = looper.copy()\nlooper_filtered.position_filter(1, 13)\ns = 
looper_filtered.scatter(\n ax,\n \"time\",\n cmap=plt.get_cmap(\"Spectral_r\", 20),\n label=\"Filtered position of drifter\",\n)\nplt.colorbar(s).set_label(\"time (days from 1/1/1950)\")\nax.legend()\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Try to find a detected eddies with adt at same place. We used filtered track to simulate an eddy center\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "match = looper_filtered.close_tracks(\n anticyclonic_eddies, method=\"close_center\", delta=0.1, nb_obs_min=50\n)\nfig = plt.figure(figsize=(16, 6))\nax = fig.add_subplot(111, aspect=\"equal\")\nlooper.plot(ax, lw=0.5, label=\"Original position of drifter\")\nlooper_filtered.plot(ax, lw=1.5, label=\"Filtered position of drifter\")\nmatch.plot(ax, lw=1.5, label=\"Matched eddy\")\nax.legend()\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Display radius of this 2 datasets.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(20, 8))\nax = fig.add_subplot(111)\nax.plot(looper.time, looper.radius_s / 1e3, label=\"loopers\")\nlooper_radius = looper.copy()\nlooper_radius.median_filter(1, \"time\", \"radius_s\", inplace=True)\nlooper_radius.loess_filter(13, \"time\", \"radius_s\", inplace=True)\nax.plot(\n looper_radius.time,\n looper_radius.radius_s / 1e3,\n label=\"loopers (filtered half window 13 days)\",\n)\nax.plot(match.time, match.radius_s / 1e3, label=\"altimetry\")\nmatch_radius = match.copy()\nmatch_radius.median_filter(1, \"time\", \"radius_s\", inplace=True)\nmatch_radius.loess_filter(13, \"time\", \"radius_s\", inplace=True)\nax.plot(\n match_radius.time,\n match_radius.radius_s / 1e3,\n label=\"altimetry (filtered half window 13 days)\",\n)\nax.set_ylabel(\"radius(km)\"), ax.set_ylim(0, 100)\nax.legend()\nax.set_title(\"Radius from loopers and altimeter\")\nax.grid()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Animation of a drifter and its colocated eddy\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "def update(frame):\n # We display last 5 days of loopers trajectory\n m = (looper.time < frame) * (looper.time > (frame - 5))\n anim.func_animation(frame)\n line.set_data(looper.lon[m], looper.lat[m])\n\n\nanim = Anim(match, intern=True, figsize=(8, 8), cmap=\"magma_r\", nb_step=10, dpi=75)\n# mappable to show drifter in red\nline = anim.ax.plot([], [], \"r\", lw=4, zorder=100)[0]\nanim.fig.suptitle(\"\")\n_ = VideoAnimation(anim.fig, update, frames=np.arange(*anim.period, 1), interval=125)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb b/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb index 4cec72b2..a46a7e22 100644 --- a/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb +++ 
b/notebooks/python_module/14_generic_tools/pet_fit_contour.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom numpy import cos, linspace, radians, sin\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.generic import coordinates_to_local, local_to_coordinates\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker.poly import fit_circle_, fit_ellips" + "from matplotlib import pyplot as plt\nfrom numpy import cos, linspace, radians, sin\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.generic import coordinates_to_local, local_to_coordinates\nfrom py_eddy_tracker.observations.observation import EddiesObservations\nfrom py_eddy_tracker.poly import fit_circle_, fit_ellipse" ] }, { @@ -51,7 +51,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Function to draw circle or ellips from parameter\n\n" + "Function to draw circle or ellipse from parameter\n\n" ] }, { @@ -62,14 +62,14 @@ }, "outputs": [], "source": [ - "def build_circle(x0, y0, r):\n angle = radians(linspace(0, 360, 50))\n x_norm, y_norm = cos(angle), sin(angle)\n return local_to_coordinates(x_norm * r, y_norm * r, x0, y0)\n\n\ndef build_ellips(x0, y0, a, b, theta):\n angle = radians(linspace(0, 360, 50))\n x = a * cos(theta) * cos(angle) - b * sin(theta) * sin(angle)\n y = a * sin(theta) * cos(angle) + b * cos(theta) * sin(angle)\n return local_to_coordinates(x, y, x0, y0)" + "def build_circle(x0, y0, r):\n angle = radians(linspace(0, 360, 50))\n x_norm, y_norm = cos(angle), sin(angle)\n return local_to_coordinates(x_norm * r, y_norm * r, x0, y0)\n\n\ndef build_ellipse(x0, y0, a, b, theta):\n angle = radians(linspace(0, 360, 50))\n x = a * cos(theta) * cos(angle) - b * sin(theta) * sin(angle)\n y = a * sin(theta) * cos(angle) + b * cos(theta) * sin(angle)\n return local_to_coordinates(x, y, x0, y0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Plot fitted circle or ellips on stored contour\n\n" + "Plot fitted circle or ellipse on stored contour\n\n" ] }, { @@ -80,7 +80,7 @@ }, "outputs": [], "source": [ - "xs, ys = a.contour_lon_s, a.contour_lat_s\n\nfig = plt.figure(figsize=(15, 15))\n\nj = 1\nfor i in range(0, 800, 30):\n x, y = xs[i], ys[i]\n x0_, y0_ = x.mean(), y.mean()\n x_, y_ = coordinates_to_local(x, y, x0_, y0_)\n ax = fig.add_subplot(4, 4, j)\n ax.grid(), ax.set_aspect(\"equal\")\n ax.plot(x, y, label=\"store\", color=\"black\")\n x0, y0, a, b, theta = fit_ellips(x_, y_)\n x0, y0 = local_to_coordinates(x0, y0, x0_, y0_)\n ax.plot(*build_ellips(x0, y0, a, b, theta), label=\"ellips\", color=\"green\")\n x0, y0, radius, shape_error = fit_circle_(x_, y_)\n x0, y0 = local_to_coordinates(x0, y0, x0_, y0_)\n ax.plot(*build_circle(x0, y0, radius), label=\"circle\", color=\"red\", lw=0.5)\n if j == 16:\n break\n j += 1" + "xs, ys = a.contour_lon_s, a.contour_lat_s\n\nfig = plt.figure(figsize=(15, 15))\n\nj = 1\nfor i in range(0, 800, 30):\n x, y = xs[i], ys[i]\n x0_, y0_ = x.mean(), y.mean()\n x_, y_ = coordinates_to_local(x, y, x0_, y0_)\n ax = fig.add_subplot(4, 4, j)\n ax.grid(), ax.set_aspect(\"equal\")\n ax.plot(x, y, label=\"store\", color=\"black\")\n x0, y0, a, b, theta = fit_ellipse(x_, y_)\n x0, y0 = local_to_coordinates(x0, y0, x0_, y0_)\n ax.plot(*build_ellipse(x0, y0, a, b, theta), label=\"ellipse\", color=\"green\")\n x0, y0, radius, shape_error = fit_circle_(x_, y_)\n x0, y0 = local_to_coordinates(x0, y0, x0_, y0_)\n ax.plot(*build_circle(x0, y0, radius), label=\"circle\", 
color=\"red\", lw=0.5)\n if j == 16:\n break\n j += 1" ] } ], @@ -100,7 +100,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb b/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb index 0183abde..69e49b57 100644 --- a/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb +++ b/notebooks/python_module/14_generic_tools/pet_visvalingam.ipynb @@ -75,7 +75,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/16_network/pet_atlas.ipynb b/notebooks/python_module/16_network/pet_atlas.ipynb index 514317a6..31e3580f 100644 --- a/notebooks/python_module/16_network/pet_atlas.ipynb +++ b/notebooks/python_module/16_network/pet_atlas.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom numpy import ma\n\nimport py_eddy_tracker.gui\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\nn = NetworkObservations.load_file(\n get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018_err70_filt500_order1/Anticyclonic_network.nc\"\n )\n)" + "from matplotlib import pyplot as plt\nfrom numpy import ma\n\nfrom py_eddy_tracker.data import get_remote_demo_sample\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\nn = NetworkObservations.load_file(\n get_remote_demo_sample(\n \"eddies_med_adt_allsat_dt2018_err70_filt500_order1/Anticyclonic_network.nc\"\n )\n)" ] }, { @@ -62,7 +62,7 @@ }, "outputs": [], "source": [ - "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=\"full_axes\")\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + "def start_axes(title):\n fig = plt.figure(figsize=(13, 5))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)\n ax.set_xlim(-6, 36.5), ax.set_ylim(30, 46)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid()\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" ] }, { @@ -363,7 +363,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/16_network/pet_follow_particle.ipynb b/notebooks/python_module/16_network/pet_follow_particle.ipynb index 30c85a49..a2a97944 100644 --- a/notebooks/python_module/16_network/pet_follow_particle.ipynb +++ b/notebooks/python_module/16_network/pet_follow_particle.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Follow particle\n" + "\nFollow particle\n===============\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "import re\n\nfrom matplotlib import colors\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numba import njit\nfrom numba import types as nb_types\nfrom numpy import arange, meshgrid, ones, unique, where, 
zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.poly import group_obs\n\nstart_logger().setLevel(\"ERROR\")" + "import re\n\nfrom matplotlib import colors\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom numpy import arange, meshgrid, ones, unique, zeros\n\nfrom py_eddy_tracker import start_logger\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.dataset.grid import GridCollection\nfrom py_eddy_tracker.observations.groups import particle_candidate\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\nstart_logger().setLevel(\"ERROR\")" ] }, { @@ -37,7 +37,7 @@ }, "outputs": [], "source": [ - "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is use to create thumbnail which are not use but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" ] }, { @@ -55,7 +55,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Schema\n\n" + "Schema\n------\n\n" ] }, { @@ -73,7 +73,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Animation\nParticle settings\n\n" + "Animation\n---------\nParticle settings\n\n" ] }, { @@ -109,7 +109,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## In which observations are the particle\n\n" + "Particle advection\n^^^^^^^^^^^^^^^^^^\nAdvection from speed contour to speed contour (default)\n\n" ] }, { @@ -120,25 +120,7 @@ }, "outputs": [], "source": [ - "def advect(x, y, c, t0, delta_t):\n \"\"\"\n Advect particle from t0 to t0 + delta_t, with data cube.\n \"\"\"\n kw = dict(nb_step=6, time_step=86400 / 6)\n if delta_t < 0:\n kw[\"backward\"] = True\n delta_t = -delta_t\n p = c.advect(x, y, \"u\", \"v\", t_init=t0, **kw)\n for _ in range(delta_t):\n t, x, y = p.__next__()\n return t, x, y\n\n\ndef particle_candidate(x, y, c, eddies, t_start, i_target, pct, **kwargs):\n # Obs from initial time\n m_start = eddies.time == t_start\n e = eddies.extract_with_mask(m_start)\n # to be able to get global index\n translate_start = where(m_start)[0]\n # Identify particle in eddies(only in core)\n i_start = e.contains(x, y, intern=True)\n m = i_start != -1\n x, y, i_start = x[m], y[m], i_start[m]\n # Advect\n t_end, x, y = advect(x, y, c, t_start, **kwargs)\n # eddies at last date\n m_end = eddies.time 
== t_end / 86400\n e_end = eddies.extract_with_mask(m_end)\n # to be able to get global index\n translate_end = where(m_end)[0]\n # Id eddies for each alive particle(in core and extern)\n i_end = e_end.contains(x, y)\n # compute matrix and filled target array\n get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct)\n\n\n@njit(cache=True)\ndef get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct):\n nb_start, nb_end = translate_start.size, translate_end.size\n # Matrix which will store count for every couple\n count = zeros((nb_start, nb_end), dtype=nb_types.int32)\n # Number of particle in each origin observation\n ref = zeros(nb_start, dtype=nb_types.int32)\n # For each particle\n for i in range(i_start.size):\n i_end_ = i_end[i]\n i_start_ = i_start[i]\n if i_end_ != -1:\n count[i_start_, i_end_] += 1\n ref[i_start_] += 1\n for i in range(nb_start):\n for j in range(nb_end):\n pct_ = count[i, j]\n # If there are particle from i to j\n if pct_ != 0:\n # Get percent\n pct_ = pct_ / ref[i] * 100.0\n # Get indices in full dataset\n i_, j_ = translate_start[i], translate_end[j]\n pct_0 = pct[i_, 0]\n if pct_ > pct_0:\n pct[i_, 1] = pct_0\n pct[i_, 0] = pct_\n i_target[i_, 1] = i_target[i_, 0]\n i_target[i_, 0] = j_\n elif pct_ > pct[i_, 1]:\n pct[i_, 1] = pct_\n i_target[i_, 1] = j_\n return i_target, pct" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Particle advection\n\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": false - }, - "outputs": [], - "source": [ - "step = 1 / 60.0\n\nx, y = meshgrid(arange(24, 36, step), arange(31, 36, step))\nx0, y0 = x.reshape(-1), y.reshape(-1)\n# Pre-order to speed up\n_, i = group_obs(x0, y0, 1, 360)\nx0, y0 = x0[i], y0[i]\n\nt_start, t_end = n.period\ndt = 14\n\nshape = (n.obs.size, 2)\n# Forward run\ni_target_f, pct_target_f = -ones(shape, dtype=\"i4\"), zeros(shape, dtype=\"i1\")\nfor t in range(t_start, t_end - dt):\n particle_candidate(x0, y0, c, n, t, i_target_f, pct_target_f, delta_t=dt)\n\n# Backward run\ni_target_b, pct_target_b = -ones(shape, dtype=\"i4\"), zeros(shape, dtype=\"i1\")\nfor t in range(t_start + dt, t_end):\n particle_candidate(x0, y0, c, n, t, i_target_b, pct_target_b, delta_t=-dt)" + "step = 1 / 60.0\n\nt_start, t_end = int(n.period[0]), int(n.period[1])\ndt = 14\n\nshape = (n.obs.size, 2)\n# Forward run\ni_target_f, pct_target_f = -ones(shape, dtype=\"i4\"), zeros(shape, dtype=\"i1\")\nfor t in arange(t_start, t_end - dt):\n particle_candidate(c, n, step, t, i_target_f, pct_target_f, n_days=dt)\n\n# Backward run\ni_target_b, pct_target_b = -ones(shape, dtype=\"i4\"), zeros(shape, dtype=\"i1\")\nfor t in arange(t_start + dt, t_end):\n particle_candidate(c, n, step, t, i_target_b, pct_target_b, n_days=-dt)" ] }, { @@ -149,7 +131,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(10, 10))\nax_1st_b = fig.add_axes([0.05, 0.52, 0.45, 0.45])\nax_2nd_b = fig.add_axes([0.05, 0.05, 0.45, 0.45])\nax_1st_f = fig.add_axes([0.52, 0.52, 0.45, 0.45])\nax_2nd_f = fig.add_axes([0.52, 0.05, 0.45, 0.45])\nax_1st_b.set_title(\"Backward advection for each time step\")\nax_1st_f.set_title(\"Forward advection for each time step\")\n\n\ndef color_alpha(target, pct, vmin=5, vmax=80):\n color = cmap(n.segment[target])\n # We will hide under 5 % and from 80% to 100 % it will be 1\n alpha = (pct - vmin) / (vmax - vmin)\n alpha[alpha < 0] = 0\n alpha[alpha > 1] = 1\n color[:, 3] = alpha\n return color\n\n\nkw = dict(\n 
name=None, yfield=\"longitude\", event=False, zorder=-100, s=(n.speed_area / 20e6)\n)\nn.scatter_timeline(ax_1st_b, c=color_alpha(i_target_b.T[0], pct_target_b.T[0]), **kw)\nn.scatter_timeline(ax_2nd_b, c=color_alpha(i_target_b.T[1], pct_target_b.T[1]), **kw)\nn.scatter_timeline(ax_1st_f, c=color_alpha(i_target_f.T[0], pct_target_f.T[0]), **kw)\nn.scatter_timeline(ax_2nd_f, c=color_alpha(i_target_f.T[1], pct_target_f.T[1]), **kw)\nfor ax in (ax_1st_b, ax_2nd_b, ax_1st_f, ax_2nd_f):\n n.display_timeline(ax, field=\"longitude\", marker=\"+\", lw=2, markersize=5)\n ax.grid()" + "fig = plt.figure(figsize=(10, 10))\nax_1st_b = fig.add_axes([0.05, 0.52, 0.45, 0.45])\nax_2nd_b = fig.add_axes([0.05, 0.05, 0.45, 0.45])\nax_1st_f = fig.add_axes([0.52, 0.52, 0.45, 0.45])\nax_2nd_f = fig.add_axes([0.52, 0.05, 0.45, 0.45])\nax_1st_b.set_title(\"Backward advection for each time step\")\nax_1st_f.set_title(\"Forward advection for each time step\")\nax_1st_b.set_ylabel(\"Color -> First target\\nLatitude\")\nax_2nd_b.set_ylabel(\"Color -> Secondary target\\nLatitude\")\nax_2nd_b.set_xlabel(\"Julian days\"), ax_2nd_f.set_xlabel(\"Julian days\")\nax_1st_f.set_yticks([]), ax_2nd_f.set_yticks([])\nax_1st_f.set_xticks([]), ax_1st_b.set_xticks([])\n\n\ndef color_alpha(target, pct, vmin=5, vmax=80):\n color = cmap(n.segment[target])\n # We will hide under 5 % and from 80% to 100 % it will be 1\n alpha = (pct - vmin) / (vmax - vmin)\n alpha[alpha < 0] = 0\n alpha[alpha > 1] = 1\n color[:, 3] = alpha\n return color\n\n\nkw = dict(\n name=None, yfield=\"longitude\", event=False, zorder=-100, s=(n.speed_area / 20e6)\n)\nn.scatter_timeline(ax_1st_b, c=color_alpha(i_target_b.T[0], pct_target_b.T[0]), **kw)\nn.scatter_timeline(ax_2nd_b, c=color_alpha(i_target_b.T[1], pct_target_b.T[1]), **kw)\nn.scatter_timeline(ax_1st_f, c=color_alpha(i_target_f.T[0], pct_target_f.T[0]), **kw)\nn.scatter_timeline(ax_2nd_f, c=color_alpha(i_target_f.T[1], pct_target_f.T[1]), **kw)\nfor ax in (ax_1st_b, ax_2nd_b, ax_1st_f, ax_2nd_f):\n n.display_timeline(ax, field=\"longitude\", marker=\"+\", lw=2, markersize=5)\n ax.grid()" ] } ], diff --git a/notebooks/python_module/16_network/pet_group_anim.ipynb b/notebooks/python_module/16_network/pet_group_anim.ipynb index ffb9dd17..090170ff 100644 --- a/notebooks/python_module/16_network/pet_group_anim.ipynb +++ b/notebooks/python_module/16_network/pet_group_anim.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Network group process\n" + "\nNetwork group process\n=====================\n" ] }, { @@ -37,7 +37,7 @@ }, "outputs": [], "source": [ - "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is use to create thumbnail which are not use but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case 
gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)" ] }, { @@ -156,7 +156,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Anim\n\n" + "Anim\n----\n\n" ] }, { @@ -174,7 +174,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Final Result\n\n" + "Final Result\n------------\n\n" ] }, { diff --git a/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb b/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb index e803df5f..9d659597 100644 --- a/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb +++ b/notebooks/python_module/16_network/pet_ioannou_2017_case.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Ioannou case\nFigure 10 from https://doi.org/10.1002/2017JC013158\n\nWe want to find the Ierapetra Eddy described above in a network demonstration run.\n" + "\nIoannou case\n============\nFigure 10 from https://doi.org/10.1002/2017JC013158\n\nWe want to find the Ierapetra Eddy described above in a network demonstration run.\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "import re\nfrom datetime import datetime, timedelta\n\nfrom matplotlib import colors\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.ticker import FuncFormatter\nfrom numpy import arange, where\n\nimport py_eddy_tracker.gui\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.observations.network import NetworkObservations" + "import re\nfrom datetime import datetime, timedelta\n\nfrom matplotlib import colors\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.ticker import FuncFormatter\nfrom numpy import arange, array, pi, where\n\nfrom py_eddy_tracker.appli.gui import Anim\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.generic import coordinates_to_local\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.poly import fit_ellipse" ] }, { @@ -37,7 +37,7 @@ }, "outputs": [], "source": [ - "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is use to create thumbnail which are not use but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\n@FuncFormatter\ndef formatter(x, pos):\n return (timedelta(x) + datetime(1950, 1, 1)).strftime(\"%d/%m/%Y\")\n\n\ndef start_axes(title=\"\"):\n fig = plt.figure(figsize=(13, 6))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=\"full_axes\")\n ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef timeline_axes(title=\"\"):\n fig = plt.figure(figsize=(15, 5))\n ax = fig.add_axes([0.03, 0.06, 0.90, 0.88])\n ax.set_title(title, weight=\"bold\")\n ax.xaxis.set_major_formatter(formatter), ax.grid()\n return ax\n\n\ndef update_axes(ax, mappable=None):\n 
ax.grid(True)\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\n@FuncFormatter\ndef formatter(x, pos):\n return (timedelta(x) + datetime(1950, 1, 1)).strftime(\"%d/%m/%Y\")\n\n\ndef start_axes(title=\"\"):\n fig = plt.figure(figsize=(13, 6))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)\n ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef timeline_axes(title=\"\"):\n fig = plt.figure(figsize=(15, 5))\n ax = fig.add_axes([0.03, 0.06, 0.90, 0.88])\n ax.set_title(title, weight=\"bold\")\n ax.xaxis.set_major_formatter(formatter), ax.grid()\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid(True)\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" ] }, { @@ -80,7 +80,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Full Timeline\nThe network span for many years... How to cut the interesting part?\n\n" + "Full Timeline\n-------------\nThe network span for many years... How to cut the interesting part?\n\n" ] }, { @@ -98,7 +98,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Sub network and new numbering\nHere we chose to keep only the order 3 segments relatives to our chosen eddy\n\n" + "Sub network and new numbering\n-----------------------------\nHere we chose to keep only the order 3 segments relatives to our chosen eddy\n\n" ] }, { @@ -116,7 +116,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Anim\nQuick movie to see better!\n\n" + "Anim\n----\nQuick movie to see better!\n\n" ] }, { @@ -134,7 +134,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Classic display\n\n" + "Classic display\n---------------\n\n" ] }, { @@ -163,7 +163,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Latitude Timeline\n\n" + "Latitude Timeline\n-----------------\n\n" ] }, { @@ -181,7 +181,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Local radius timeline\nEffective (bold) and Speed (thin) Radius together\n\n" + "Local radius timeline\n---------------------\nEffective (bold) and Speed (thin) Radius together\n\n" ] }, { @@ -199,7 +199,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Parameters timeline\nEffective Radius\n\n" + "Parameters timeline\n-------------------\nEffective Radius\n\n" ] }, { @@ -230,6 +230,96 @@ "source": [ "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, \"shape_error_e\", vmin=14, vmax=70, **kw)\ncb = update_axes(ax, m[\"scatter\"])\ncb.set_label(\"Effective shape error\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Rotation angle\n--------------\nFor each obs, fit an ellipse to the contour, with theta the angle from the x-axis,\na the semi ax in x direction and b the semi ax in y dimension\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { 
+ "collapsed": false + }, + "outputs": [], + "source": [ + "theta_ = list()\na_ = list()\nb_ = list()\nfor obs in close_to_i3:\n x, y = obs[\"contour_lon_s\"], obs[\"contour_lat_s\"]\n x0_, y0_ = x.mean(), y.mean()\n x_, y_ = coordinates_to_local(x, y, x0_, y0_)\n x0, y0, a, b, theta = fit_ellipse(x_, y_)\n theta_.append(theta)\n a_.append(a)\n b_.append(b)\na_ = array(a_)\nb_ = array(b_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Theta\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, theta_, vmin=-pi / 2, vmax=pi / 2, cmap=\"hsv\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "a\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, a_ * 1e-3, vmin=0, vmax=80, cmap=\"Spectral_r\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "b\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, b_ * 1e-3, vmin=0, vmax=80, cmap=\"Spectral_r\")\n_ = update_axes(ax, m[\"scatter\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "a/b\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "ax = timeline_axes()\nm = close_to_i3.scatter_timeline(ax, a_ / b_, vmin=1, vmax=2, cmap=\"Spectral_r\")\n_ = update_axes(ax, m[\"scatter\"])" + ] } ], "metadata": { diff --git a/notebooks/python_module/16_network/pet_relative.ipynb b/notebooks/python_module/16_network/pet_relative.ipynb index 23537375..9f3fd3d9 100644 --- a/notebooks/python_module/16_network/pet_relative.ipynb +++ b/notebooks/python_module/16_network/pet_relative.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\nfrom numpy import where\n\nimport py_eddy_tracker.gui\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.observations.network import NetworkObservations" + "from matplotlib import pyplot as plt\nfrom numpy import where\n\nfrom py_eddy_tracker import data\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations" ] }, { @@ -447,7 +447,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=\"full_axes\")\nn.plot(ax, color_cycle=n.COLORS)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nax = fig.add_axes([0.08, 0.7, 0.7, 0.3])\n_ = n.display_timeline(ax)" + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES)\nn.plot(ax, color_cycle=n.COLORS)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nax = fig.add_axes([0.08, 0.7, 0.7, 0.3])\n_ = n.display_timeline(ax)" ] }, { @@ -465,7 +465,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=\"full_axes\")\nn.plot(ax, color_cycle=n.COLORS)\nm1, m0, m0_stop = n.merging_event(triplet=True)\nm1.display(ax, color=\"violet\", lw=2, label=\"Eddies after merging\")\nm0.display(ax, 
color=\"blueviolet\", lw=2, label=\"Eddies before merging\")\nm0_stop.display(ax, color=\"black\", lw=2, label=\"Eddies stopped by merging\")\nax.plot(m1.lon, m1.lat, marker=\".\", color=\"purple\", ls=\"\")\nax.plot(m0.lon, m0.lat, marker=\".\", color=\"blueviolet\", ls=\"\")\nax.plot(m0_stop.lon, m0_stop.lat, marker=\".\", color=\"black\", ls=\"\")\nax.legend()\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nm1" + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\nn.plot(ax, color_cycle=n.COLORS)\nm1, m0, m0_stop = n.merging_event(triplet=True)\nm1.display(ax, color=\"violet\", lw=2, label=\"Eddies after merging\")\nm0.display(ax, color=\"blueviolet\", lw=2, label=\"Eddies before merging\")\nm0_stop.display(ax, color=\"black\", lw=2, label=\"Eddies stopped by merging\")\nax.plot(m1.lon, m1.lat, marker=\".\", color=\"purple\", ls=\"\")\nax.plot(m0.lon, m0.lat, marker=\".\", color=\"blueviolet\", ls=\"\")\nax.plot(m0_stop.lon, m0_stop.lat, marker=\".\", color=\"black\", ls=\"\")\nax.legend()\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nm1" ] }, { @@ -483,7 +483,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=\"full_axes\")\nn.plot(ax, color_cycle=n.COLORS)\ns0, s1, s1_start = n.spliting_event(triplet=True)\ns0.display(ax, color=\"violet\", lw=2, label=\"Eddies before splitting\")\ns1.display(ax, color=\"blueviolet\", lw=2, label=\"Eddies after splitting\")\ns1_start.display(ax, color=\"black\", lw=2, label=\"Eddies starting by splitting\")\nax.plot(s0.lon, s0.lat, marker=\".\", color=\"purple\", ls=\"\")\nax.plot(s1.lon, s1.lat, marker=\".\", color=\"blueviolet\", ls=\"\")\nax.plot(s1_start.lon, s1_start.lat, marker=\".\", color=\"black\", ls=\"\")\nax.legend()\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\ns1" + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\nn.plot(ax, color_cycle=n.COLORS)\ns0, s1, s1_start = n.spliting_event(triplet=True)\ns0.display(ax, color=\"violet\", lw=2, label=\"Eddies before splitting\")\ns1.display(ax, color=\"blueviolet\", lw=2, label=\"Eddies after splitting\")\ns1_start.display(ax, color=\"black\", lw=2, label=\"Eddies starting by splitting\")\nax.plot(s0.lon, s0.lat, marker=\".\", color=\"purple\", ls=\"\")\nax.plot(s1.lon, s1.lat, marker=\".\", color=\"blueviolet\", ls=\"\")\nax.plot(s1_start.lon, s1_start.lat, marker=\".\", color=\"black\", ls=\"\")\nax.legend()\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\ns1" ] }, { @@ -501,7 +501,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=\"full_axes\")\nbirth = n.birth_event()\nbirth.display(ax)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nbirth" + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\nbirth = n.birth_event()\nbirth.display(ax)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\nbirth" ] }, { @@ -519,7 +519,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=\"full_axes\")\ndeath = n.death_event()\ndeath.display(ax)\nax.set_xlim(17.5, 27.5), ax.set_ylim(31, 36), ax.grid()\ndeath" + "fig = plt.figure(figsize=(15, 8))\nax = fig.add_axes([0.04, 0.06, 0.90, 0.88], projection=GUI_AXES)\ndeath = n.death_event()\ndeath.display(ax)\nax.set_xlim(17.5, 27.5), 
ax.set_ylim(31, 36), ax.grid()\ndeath" ] } ], @@ -539,7 +539,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/16_network/pet_replay_segmentation.ipynb b/notebooks/python_module/16_network/pet_replay_segmentation.ipynb index 7acb3f51..7c632138 100644 --- a/notebooks/python_module/16_network/pet_replay_segmentation.ipynb +++ b/notebooks/python_module/16_network/pet_replay_segmentation.ipynb @@ -26,7 +26,7 @@ }, "outputs": [], "source": [ - "from datetime import datetime, timedelta\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.ticker import FuncFormatter\nfrom numpy import where\n\nimport py_eddy_tracker.gui\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\n\n\n@FuncFormatter\ndef formatter(x, pos):\n return (timedelta(x) + datetime(1950, 1, 1)).strftime(\"%d/%m/%Y\")\n\n\ndef start_axes(title=\"\"):\n fig = plt.figure(figsize=(13, 6))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=\"full_axes\")\n ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef timeline_axes(title=\"\"):\n fig = plt.figure(figsize=(15, 5))\n ax = fig.add_axes([0.04, 0.06, 0.89, 0.88])\n ax.set_title(title, weight=\"bold\")\n ax.xaxis.set_major_formatter(formatter), ax.grid()\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid(True)\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" + "from datetime import datetime, timedelta\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.ticker import FuncFormatter\nfrom numpy import where\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations\n\n\n@FuncFormatter\ndef formatter(x, pos):\n return (timedelta(x) + datetime(1950, 1, 1)).strftime(\"%d/%m/%Y\")\n\n\ndef start_axes(title=\"\"):\n fig = plt.figure(figsize=(13, 6))\n ax = fig.add_axes([0.03, 0.03, 0.90, 0.94], projection=GUI_AXES)\n ax.set_xlim(19, 29), ax.set_ylim(31, 35.5)\n ax.set_aspect(\"equal\")\n ax.set_title(title, weight=\"bold\")\n return ax\n\n\ndef timeline_axes(title=\"\"):\n fig = plt.figure(figsize=(15, 5))\n ax = fig.add_axes([0.04, 0.06, 0.89, 0.88])\n ax.set_title(title, weight=\"bold\")\n ax.xaxis.set_major_formatter(formatter), ax.grid()\n return ax\n\n\ndef update_axes(ax, mappable=None):\n ax.grid(True)\n if mappable:\n return plt.colorbar(mappable, cax=ax.figure.add_axes([0.94, 0.05, 0.01, 0.9]))" ] }, { @@ -172,7 +172,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.7.9" } }, "nbformat": 4, diff --git a/notebooks/python_module/16_network/pet_segmentation_anim.ipynb b/notebooks/python_module/16_network/pet_segmentation_anim.ipynb index 18da8478..0a546832 100644 --- a/notebooks/python_module/16_network/pet_segmentation_anim.ipynb +++ b/notebooks/python_module/16_network/pet_segmentation_anim.ipynb @@ -15,7 +15,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n# Network segmentation process\n" + "\nNetwork segmentation process\n============================\n" ] }, { @@ -26,7 +26,7 @@ }, "outputs": [], 
"source": [ - "# sphinx_gallery_thumbnail_number = 2\nimport re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.colors import ListedColormap\nfrom numpy import ones, where\n\nimport py_eddy_tracker.gui\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" + "# sphinx_gallery_thumbnail_number = 2\nimport re\n\nfrom matplotlib import pyplot as plt\nfrom matplotlib.animation import FuncAnimation\nfrom matplotlib.colors import ListedColormap\nfrom numpy import ones, where\n\nfrom py_eddy_tracker.data import get_demo_path\nfrom py_eddy_tracker.gui import GUI_AXES\nfrom py_eddy_tracker.observations.network import NetworkObservations\nfrom py_eddy_tracker.observations.tracking import TrackEddiesObservations" ] }, { @@ -37,7 +37,7 @@ }, "outputs": [], "source": [ - "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is use to create thumbnail which are not use but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\ndef get_obs(dataset):\n \"Function to isolate a specific obs\"\n return where(\n (dataset.lat > 33)\n * (dataset.lat < 34)\n * (dataset.lon > 22)\n * (dataset.lon < 23)\n * (dataset.time > 20630)\n * (dataset.time < 20650)\n )[0][0]" + "class VideoAnimation(FuncAnimation):\n def _repr_html_(self, *args, **kwargs):\n \"\"\"To get video in html and have a player\"\"\"\n content = self.to_html5_video()\n return re.sub(\n r'width=\"[0-9]*\"\\sheight=\"[0-9]*\"', 'width=\"100%\" height=\"100%\"', content\n )\n\n def save(self, *args, **kwargs):\n if args[0].endswith(\"gif\"):\n # In this case gif is used to create thumbnail which is not used but consume same time than video\n # So we create an empty file, to save time\n with open(args[0], \"w\") as _:\n pass\n return\n return super().save(*args, **kwargs)\n\n\ndef get_obs(dataset):\n \"Function to isolate a specific obs\"\n return where(\n (dataset.lat > 33)\n * (dataset.lat < 34)\n * (dataset.lon > 22)\n * (dataset.lon < 23)\n * (dataset.time > 20630)\n * (dataset.time < 20650)\n )[0][0]" ] }, { @@ -62,7 +62,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Load data\nLoad data where observations are put in same network but no segmentation\n\n" + "Load data\n---------\nLoad data where observations are put in same network but no segmentation\n\n" ] }, { @@ -80,7 +80,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Do segmentation\nSegmentation based on maximum overlap, temporal window for candidates = 5 days\n\n" + "Do segmentation\n---------------\nSegmentation based on maximum overlap, temporal window for candidates = 5 days\n\n" ] }, { @@ -98,7 +98,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Anim\n\n" + "Anim\n----\n\n" ] }, { @@ -109,14 +109,14 @@ }, "outputs": [], "source": [ - "def update(i_frame):\n tr = TRACKS[i_frame]\n mappable_tracks.set_array(tr)\n s = 40 * ones(tr.shape)\n s[tr == 0] = 4\n mappable_tracks.set_sizes(s)\n\n indices_frames = INDICES[i_frame]\n 
mappable_CONTOUR.set_data(\n e.contour_lon_e[indices_frames],\n e.contour_lat_e[indices_frames],\n )\n mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)])\n return (mappable_tracks,)\n\n\nfig = plt.figure(figsize=(16, 9), dpi=60)\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=\"full_axes\")\nax.set_title(f\"{len(e)} observations to segment\")\nax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid()\nvmax = TRACKS[-1].max()\ncmap = ListedColormap([\"gray\", *e.COLORS[:-1]], name=\"from_list\", N=vmax)\nmappable_tracks = ax.scatter(\n e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20\n)\nmappable_CONTOUR = ax.plot(\n e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0]\n)[0]\nani = VideoAnimation(fig, update, frames=range(1, len(TRACKS), 4), interval=125)" + "def update(i_frame):\n tr = TRACKS[i_frame]\n mappable_tracks.set_array(tr)\n s = 40 * ones(tr.shape)\n s[tr == 0] = 4\n mappable_tracks.set_sizes(s)\n\n indices_frames = INDICES[i_frame]\n mappable_CONTOUR.set_data(\n e.contour_lon_e[indices_frames], e.contour_lat_e[indices_frames],\n )\n mappable_CONTOUR.set_color(cmap.colors[tr[indices_frames] % len(cmap.colors)])\n return (mappable_tracks,)\n\n\nfig = plt.figure(figsize=(16, 9), dpi=60)\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES)\nax.set_title(f\"{len(e)} observations to segment\")\nax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid()\nvmax = TRACKS[-1].max()\ncmap = ListedColormap([\"gray\", *e.COLORS[:-1]], name=\"from_list\", N=vmax)\nmappable_tracks = ax.scatter(\n e.lon, e.lat, c=TRACKS[0], cmap=cmap, vmin=0, vmax=vmax, s=20\n)\nmappable_CONTOUR = ax.plot(\n e.contour_lon_e[INDICES[0]], e.contour_lat_e[INDICES[0]], color=cmap.colors[0]\n)[0]\nani = VideoAnimation(fig, update, frames=range(1, len(TRACKS), 4), interval=125)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Final Result\n\n" + "Final Result\n------------\n\n" ] }, { @@ -127,7 +127,7 @@ }, "outputs": [], "source": [ - "fig = plt.figure(figsize=(16, 9))\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=\"full_axes\")\nax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid()\n_ = ax.scatter(e.lon, e.lat, c=TRACKS[-1], cmap=cmap, vmin=0, vmax=vmax, s=20)" + "fig = plt.figure(figsize=(16, 9))\nax = fig.add_axes([0.04, 0.06, 0.94, 0.88], projection=GUI_AXES)\nax.set_xlim(19, 29), ax.set_ylim(31, 35.5), ax.grid()\n_ = ax.scatter(e.lon, e.lat, c=TRACKS[-1], cmap=cmap, vmin=0, vmax=vmax, s=20)" ] } ], diff --git a/notebooks/python_module/16_network/pet_something_cool.ipynb b/notebooks/python_module/16_network/pet_something_cool.ipynb new file mode 100644 index 00000000..158852f9 --- /dev/null +++ b/notebooks/python_module/16_network/pet_something_cool.ipynb @@ -0,0 +1,65 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# essai\n\non tente des trucs\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import cartopy.crs as ccrs\nimport cartopy.feature as cfeature\nimport numpy as np\nfrom matplotlib import pyplot as plt\n\nfrom py_eddy_tracker.observations.network import NetworkObservations\n\n\ndef rect_from_extent(extent):\n rect_lon = [extent[0], extent[1], extent[1], extent[0], extent[0]]\n rect_lat = [extent[2], extent[2], 
extent[3], extent[3], extent[2]]\n return rect_lon, rect_lat\n\n\ndef indice_from_extent(lon, lat, extent):\n mask = (lon > extent[0]) * (lon < extent[1]) * (lat > extent[2]) * (lat < extent[3])\n return np.where(mask)[0]\n\n\nfichier = \"/data/adelepoulle/work/Eddies/20201217_network_build/big_network.nc\"\nnetwork = NetworkObservations.load_file(fichier)\nsub_network = network.network(1078566)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# extent_begin = [0, 2, -50, -15]\n# extent_end = [-42, -35, -40, -10]\n\nextent_begin = [2, 22, -50, -30]\ni_obs_begin = indice_from_extent(\n sub_network.longitude, sub_network.latitude, extent_begin\n)\nnetwork_begin = sub_network.find_link(i_obs_begin)\ntime_mini = network_begin.time.min()\ntime_maxi = network_begin.time.max()\n\nextent_end = [-52, -45, -37, -33]\ni_obs_end = indice_from_extent(\n (network_begin.longitude + 180) % 360 - 180, network_begin.latitude, extent_end\n)\nnetwork_end = network_begin.find_link(i_obs_end, forward=False, backward=True)\n\n\ndatasets = [network_begin, network_end]\nextents = [extent_begin, extent_end]\nfig, (ax1, ax2) = plt.subplots(\n 2, 1, figsize=(10, 9), dpi=140, subplot_kw={\"projection\": ccrs.PlateCarree()}\n)\n\nfor ax, dataset, extent in zip([ax1, ax2], datasets, extents):\n sca = dataset.scatter(\n ax,\n name=\"time\",\n cmap=\"Spectral_r\",\n label=\"observation dans le temps\",\n vmin=time_mini,\n vmax=time_maxi,\n )\n\n x, y = rect_from_extent(extent)\n ax.fill(x, y, color=\"grey\", alpha=0.3, label=\"observations choisies\")\n # ax.plot(x, y, marker='o')\n\n ax.legend()\n\n gridlines = ax.gridlines(\n alpha=0.2, color=\"black\", linestyle=\"dotted\", draw_labels=True, dms=True\n )\n\n gridlines.left_labels = False\n gridlines.top_labels = False\n\n ax.coastlines()\n ax.add_feature(cfeature.LAND)\n ax.add_feature(cfeature.LAKES, zorder=10)\n ax.add_feature(cfeature.BORDERS, lw=0.25)\n ax.add_feature(cfeature.OCEAN, alpha=0.2)\n\n\nax1.set_title(\n \"Recherche du d\u00e9placement de l'eau dans les eddies \u00e0 travers les observations choisies\"\n)\nax2.set_title(\"Recherche de la provenance de l'eau \u00e0 travers les observations choisies\")\nax2.set_extent(ax1.get_extent(), ccrs.PlateCarree())\n\nfig.subplots_adjust(right=0.87, left=0.02)\ncbar_ax = fig.add_axes([0.90, 0.1, 0.02, 0.8])\ncbar = fig.colorbar(sca[\"scatter\"], cax=cbar_ax, orientation=\"vertical\")\n_ = cbar.set_label(\"time (jj)\", rotation=270, labelpad=-65)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index eae54426..556cabbf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,11 @@ -matplotlib -netCDF4 -numba -numpy +matplotlib < 3.8 # need an update of contour management opencv-python pint polygon3 pyyaml requests scipy -zarr -# for binder -pyeddytrackersample \ No newline at end of file +zarr < 3.0 +netCDF4 +numpy +numba \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 00000000..a005c37d --- /dev/null +++ 
b/requirements_dev.txt @@ -0,0 +1,7 @@ +-r requirements.txt +isort +black +blackdoc +flake8 +pytest +pytest-cov \ No newline at end of file diff --git a/requirements_doc.txt b/requirements_doc.txt deleted file mode 100644 index 0d926b32..00000000 --- a/requirements_doc.txt +++ /dev/null @@ -1,16 +0,0 @@ -matplotlib -netCDF4 -numba -numpy -opencv-python -pint -polygon3 -pyyaml -requests -scipy -zarr -# doc -sphinx-gallery -pyeddytrackersample -sphinx_rtd_theme -sphinx>=3.1 diff --git a/setup.cfg b/setup.cfg index dfed5c3b..7e773ae8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,29 @@ + +[yapf] +column_limit = 100 + [flake8] +max-line-length = 140 ignore = - E203, # whitespace before ':' - W503, # line break before binary operator + E203, + W503, +exclude= + build + doc + versioneer.py + +[isort] +combine_as_imports=True +force_grid_wrap=0 +force_sort_within_sections=True +force_to_top=typing +include_trailing_comma=True +line_length=140 +multi_line_output=3 +skip= + build + doc/conf.py + [versioneer] VCS = git @@ -13,4 +35,5 @@ parentdir_prefix = [tool:pytest] filterwarnings= - ignore:tostring.*is deprecated \ No newline at end of file + ignore:tostring.*is deprecated + diff --git a/setup.py b/setup.py index 06432bd1..7b836763 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- +from setuptools import find_packages, setup + import versioneer -from setuptools import setup, find_packages with open("README.md", "r") as fh: long_description = fh.read() @@ -9,7 +10,7 @@ setup( name="pyEddyTracker", - python_requires=">=3.7", + python_requires=">=3.10", version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), description="Py-Eddy-Tracker libraries", @@ -48,6 +49,7 @@ "EddyNetworkGroup = py_eddy_tracker.appli.network:build_network", "EddyNetworkBuildPath = py_eddy_tracker.appli.network:divide_network", "EddyNetworkSubSetter = py_eddy_tracker.appli.network:subset_network", + "EddyNetworkQuickCompare = py_eddy_tracker.appli.network:quick_compare", # anim/gui "EddyAnim = py_eddy_tracker.appli.gui:anim", "GUIEddy = py_eddy_tracker.appli.gui:guieddy", diff --git a/share/fig.py b/share/fig.py index 8640abcb..80c7f12b 100644 --- a/share/fig.py +++ b/share/fig.py @@ -1,8 +1,10 @@ -from matplotlib import pyplot as plt -from py_eddy_tracker.dataset.grid import RegularGridDataset from datetime import datetime import logging +from matplotlib import pyplot as plt + +from py_eddy_tracker.dataset.grid import RegularGridDataset + grid_name, lon_name, lat_name = ( "nrt_global_allsat_phy_l4_20190223_20190226.nc", "longitude", diff --git a/share/nrt_global_allsat_phy_l4_20190223_20190226.nc b/share/nrt_global_allsat_phy_l4_20190223_20190226.nc new file mode 120000 index 00000000..077ce7e6 --- /dev/null +++ b/share/nrt_global_allsat_phy_l4_20190223_20190226.nc @@ -0,0 +1 @@ +../src/py_eddy_tracker/data/nrt_global_allsat_phy_l4_20190223_20190226.nc \ No newline at end of file diff --git a/share/tracking.yaml b/share/tracking.yaml index 0f8766b8..d6264104 100644 --- a/share/tracking.yaml +++ b/share/tracking.yaml @@ -1,13 +1,14 @@ PATHS: - # Files produces with EddyIdentification + # Files produced with EddyIdentification FILES_PATTERN: /home/emason/toto/Anticyclonic_*.nc - # Path for saving of outputs + # Path to save outputs SAVE_DIR: '/home/emason/toto/' -# Minimum number of observations to store eddy -TRACK_DURATION_MIN: 4 +# Number of consecutive timesteps with missing detection allowed VIRTUAL_LENGTH_MAX: 0 +# Minimal number of timesteps to be considered as a long trajectory 
+TRACK_DURATION_MIN: 4 -#CLASS: -# MODULE: py_eddy_tracker.featured_tracking.old_tracker_reference -# CLASS: CheltonTracker +CLASS: + MODULE: py_eddy_tracker.featured_tracking.area_tracker + CLASS: AreaTracker diff --git a/src/py_eddy_tracker/__init__.py b/src/py_eddy_tracker/__init__.py index 46946e77..7115bf67 100644 --- a/src/py_eddy_tracker/__init__.py +++ b/src/py_eddy_tracker/__init__.py @@ -20,8 +20,9 @@ """ -import logging from argparse import ArgumentParser +from datetime import datetime +import logging import zarr @@ -31,15 +32,13 @@ del get_versions -def start_logger(): - FORMAT_LOG = ( - "%(levelname)-8s %(asctime)s %(module)s.%(funcName)s :\n\t\t\t\t\t%(message)s" - ) +def start_logger(color=True): + FORMAT_LOG = "%(levelname)-8s %(asctime)s %(module)s.%(funcName)s :\n\t%(message)s" logger = logging.getLogger("pet") if len(logger.handlers) == 0: # set up logging to CONSOLE console = logging.StreamHandler() - console.setFormatter(ColoredFormatter(FORMAT_LOG)) + console.setFormatter(ColoredFormatter(FORMAT_LOG, color=color)) # add the handler to the root logger logger.addHandler(console) return logger @@ -54,13 +53,14 @@ class ColoredFormatter(logging.Formatter): DEBUG="\033[34m\t", ) - def __init__(self, message): + def __init__(self, message, color=True): super().__init__(message) + self.with_color = color def format(self, record): color = self.COLOR_LEVEL.get(record.levelname, "") color_reset = "\033[0m" - model = color + "%s" + color_reset + model = (color + "%s" + color_reset) if self.with_color else "%s" record.msg = model % record.msg record.funcName = model % record.funcName record.module = model % record.module @@ -108,12 +108,26 @@ def parse_args(self, *args, **kwargs): return opts +TIME_MODELS = ["%Y%m%d", "%Y%m%d%H%M%S", "%Y%m%dT%H%M%S"] + + +def identify_time(str_date): + for model in TIME_MODELS: + try: + return datetime.strptime(str_date, model) + except ValueError: + pass + raise Exception("No time model found") + + VAR_DESCR = dict( time=dict( attr_name="time", nc_name="time", old_nc_name=["j1"], - nc_type="int32", + nc_type="float64", + output_type="uint32", + scale_factor=1 / 86400.0, nc_dims=("obs",), nc_attr=dict( standard_name="time", @@ -253,7 +267,7 @@ def parse_args(self, *args, **kwargs): old_nc_name=["A"], nc_type="float32", output_type="uint16", - scale_factor=0.001, + scale_factor=0.0001, nc_dims=("obs",), nc_attr=dict( long_name="Amplitude", @@ -391,7 +405,7 @@ def parse_args(self, *args, **kwargs): nc_dims=("obs",), nc_attr=dict( long_name="Previous observation index", - comment="Index of previous observation in a spliting case", + comment="Index of previous observation in a splitting case", ), ), next_obs=dict( @@ -683,3 +697,6 @@ def parse_args(self, *args, **kwargs): VAR_DESCR_inv[VAR_DESCR[key]["nc_name"]] = key for key_old in VAR_DESCR[key].get("old_nc_name", list()): VAR_DESCR_inv[key_old] = key + +from . import _version +__version__ = _version.get_versions()['version'] diff --git a/src/py_eddy_tracker/_version.py b/src/py_eddy_tracker/_version.py index 44367e3a..589e706f 100644 --- a/src/py_eddy_tracker/_version.py +++ b/src/py_eddy_tracker/_version.py @@ -1,11 +1,13 @@ + # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. 
-# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -14,9 +16,11 @@ import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -32,8 +36,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool + -def get_config(): +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -51,41 +62,50 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} - +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f - return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -96,18 +116,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable 
to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -116,64 +138,64 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. 
However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -186,11 +208,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -199,7 +221,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -207,30 +229,33 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -241,7 +266,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -249,33 +282,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -284,16 +341,17 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] + git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) return pieces # tag @@ -302,12 +360,10 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] + pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -318,26 +374,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -355,29 +412,78 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). 
+ """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -404,12 +510,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -426,7 +561,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -446,7 +581,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
@@ -466,26 +601,28 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -495,16 +632,12 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -515,7 +648,8 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) except NotThisMethod: pass @@ -524,16 +658,13 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
- for i in cfg.versionfile_source.split("/"): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -547,10 +678,6 @@ def get_versions(): except NotThisMethod: pass - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} diff --git a/src/py_eddy_tracker/appli/eddies.py b/src/py_eddy_tracker/appli/eddies.py index d5a727f9..c1c7a90d 100644 --- a/src/py_eddy_tracker/appli/eddies.py +++ b/src/py_eddy_tracker/appli/eddies.py @@ -3,19 +3,18 @@ Applications on detection and tracking files """ import argparse -import logging from datetime import datetime from glob import glob +import logging from os import mkdir -from os.path import basename, dirname, exists -from os.path import join as join_path +from os.path import basename, dirname, exists, join as join_path from re import compile as re_compile from netCDF4 import Dataset from numpy import bincount, bytes_, empty, in1d, unique from yaml import safe_load -from .. import EddyParser +from .. import EddyParser, identify_time from ..observations.observation import EddiesObservations, reverse_index from ..observations.tracking import TrackEddiesObservations from ..tracking import Correspondances @@ -163,7 +162,12 @@ def eddies_tracking(): parser.add_argument( "--zarr", action="store_true", help="Output will be wrote in zarr" ) - parser.add_argument("--unraw", action="store_true", help="Load unraw data") + parser.add_argument( + "--unraw", + action="store_true", + help="Load unraw data, use only for netcdf." 
+ "If unraw is active, netcdf is loaded without apply scalefactor and add_offset.", + ) parser.add_argument( "--blank_period", type=int, @@ -223,7 +227,7 @@ def browse_dataset_in( data_dir, files_model, date_regexp, - date_model, + date_model=None, start_date=None, end_date=None, sub_sampling_step=1, @@ -239,10 +243,7 @@ def browse_dataset_in( dataset_list = empty( len(filenames), - dtype=[ - ("filename", "S500"), - ("date", "datetime64[D]"), - ], + dtype=[("filename", "S500"), ("date", "datetime64[s]")], ) dataset_list["filename"] = filenames @@ -268,13 +269,15 @@ def browse_dataset_in( str_date = result.groups()[0] if str_date is not None: - item["date"] = datetime.strptime(str_date, date_model).date() + if date_model is None: + item["date"] = identify_time(str_date) + else: + item["date"] = datetime.strptime(str_date, date_model) dataset_list.sort(order=["date", "filename"]) - steps = unique(dataset_list["date"][1:] - dataset_list["date"][:-1]) if len(steps) > 1: - raise Exception("Several days steps in grid dataset %s" % steps) + raise Exception("Several timesteps in grid dataset %s" % steps) if sub_sampling_step != 1: logger.info("Grid subsampling %d", sub_sampling_step) @@ -304,7 +307,7 @@ def track( correspondances_only=False, **kw_c, ): - kw = dict(date_regexp=".*_([0-9]*?).[nz].*", date_model="%Y%m%d") + kw = dict(date_regexp=".*_([0-9]*?).[nz].*") if isinstance(pattern, list): kw.update(dict(data_dir=None, files_model=None, files=pattern)) else: @@ -323,10 +326,9 @@ def track( c = Correspondances(datasets=datasets["filename"], **kw_c) c.track() logger.info("Track finish") - t0, t1 = c.period kw_save = dict( - date_start=t0, - date_stop=t1, + date_start=datasets["date"][0], + date_stop=datasets["date"][-1], date_prod=datetime.now(), path=output_dir, sign_type=c.current_obs.sign_legend, @@ -351,11 +353,13 @@ def track( short_c = c._copy() short_c.shorter_than(size_max=nb_obs_min) - c.longer_than(size_min=nb_obs_min) - - long_track = c.merge(raw_data=raw) short_track = short_c.merge(raw_data=raw) + if c.longer_than(size_min=nb_obs_min) is False: + long_track = short_track.empty_dataset() + else: + long_track = c.merge(raw_data=raw) + # We flag obs if c.virtual: long_track["virtual"][:] = long_track["time"] == 0 diff --git a/src/py_eddy_tracker/appli/grid.py b/src/py_eddy_tracker/appli/grid.py index 7f2b9610..099465ee 100644 --- a/src/py_eddy_tracker/appli/grid.py +++ b/src/py_eddy_tracker/appli/grid.py @@ -3,9 +3,8 @@ All entry point to manipulate grid """ from argparse import Action -from datetime import datetime -from .. import EddyParser +from .. 
import EddyParser, identify_time from ..dataset.grid import RegularGridDataset, UnRegularGridDataset @@ -121,7 +120,7 @@ def eddy_id(args=None): cut_wavelength = [0, *cut_wavelength] inf_bnds, upper_bnds = cut_wavelength - date = datetime.strptime(args.datetime, "%Y%m%d") + date = identify_time(args.datetime) kwargs = dict( step=args.isoline_step, shape_error=args.fit_errmax, @@ -150,7 +149,7 @@ def eddy_id(args=None): sampling_method=args.sampling_method, **kwargs, ) - out_name = date.strftime("%(path)s/%(sign_type)s_%Y%m%d.nc") + out_name = date.strftime("%(path)s/%(sign_type)s_%Y%m%dT%H%M%S.nc") a.write_file(path=args.path_out, filename=out_name, zarr_flag=args.zarr) c.write_file(path=args.path_out, filename=out_name, zarr_flag=args.zarr) diff --git a/src/py_eddy_tracker/appli/gui.py b/src/py_eddy_tracker/appli/gui.py index 427db24b..c3d7619b 100644 --- a/src/py_eddy_tracker/appli/gui.py +++ b/src/py_eddy_tracker/appli/gui.py @@ -3,15 +3,15 @@ Entry point of graphic user interface """ -import logging from datetime import datetime, timedelta from itertools import chain +import logging from matplotlib import pyplot from matplotlib.animation import FuncAnimation from matplotlib.axes import Axes from matplotlib.collections import LineCollection -from numpy import arange, where +from numpy import arange, where, nan from .. import EddyParser from ..gui import GUI @@ -58,7 +58,10 @@ def setup( self.kw_label["fontweight"] = kwargs.pop("fontweight", "demibold") # To text each visible eddy if field_txt: - self.field_txt = self.eddy[field_txt] + if isinstance(field_txt,str): + self.field_txt = self.eddy[field_txt] + else : + self.field_txt=field_txt if field_color: # To color each visible eddy self.field_color = self.eddy[field_color].astype("f4") diff --git a/src/py_eddy_tracker/appli/network.py b/src/py_eddy_tracker/appli/network.py index 90450078..0a3d06ca 100644 --- a/src/py_eddy_tracker/appli/network.py +++ b/src/py_eddy_tracker/appli/network.py @@ -5,6 +5,8 @@ import logging +from numpy import in1d, zeros + from .. 
import EddyParser from ..observations.network import Network, NetworkObservations from ..observations.tracking import TrackEddiesObservations @@ -21,6 +23,25 @@ def build_network(): parser.add_argument( "--window", "-w", type=int, help="Half time window to search eddy", default=1 ) + + parser.add_argument( + "--min-overlap", + "-p", + type=float, + help="minimum overlap area to associate observations", + default=0.2, + ) + parser.add_argument( + "--minimal-area", + action="store_true", + help="If True, use intersection/little polygon, else intersection/union", + ) + parser.add_argument( + "--hybrid-area", + action="store_true", + help="If True, use minimal-area method if overlap is under min overlap, else intersection/union", + ) + parser.contour_intern_arg() parser.memory_arg() @@ -32,18 +53,39 @@ def build_network(): intern=args.intern, memory=args.memory, ) - group = n.group_observations(minimal_area=True) + group = n.group_observations( + min_overlap=args.min_overlap, + minimal_area=args.minimal_area, + hybrid_area=args.hybrid_area, + ) n.build_dataset(group).write_file(filename=args.out) def divide_network(): - parser = EddyParser("Separate path for a same group(network)") + parser = EddyParser("Separate path for a same group (network)") parser.add_argument("input", help="input network file") parser.add_argument("out", help="output file") parser.contour_intern_arg() parser.add_argument( "--window", "-w", type=int, help="Half time window to search eddy", default=1 ) + parser.add_argument( + "--min-overlap", + "-p", + type=float, + help="minimum overlap area to associate observations", + default=0.2, + ) + parser.add_argument( + "--minimal-area", + action="store_true", + help="If True, use intersection/little polygon, else intersection/union", + ) + parser.add_argument( + "--hybrid-area", + action="store_true", + help="If True, use minimal-area method if overlap is under min overlap, else intersection/union", + ) args = parser.parse_args() contour_name = TrackEddiesObservations.intern(args.intern, public_label=True) e = TrackEddiesObservations.load_file( @@ -52,7 +94,13 @@ def divide_network(): ) n = NetworkObservations.from_split_network( TrackEddiesObservations.load_file(args.input, raw_data=True), - e.split_network(intern=args.intern, window=args.window), + e.split_network( + intern=args.intern, + window=args.window, + min_overlap=args.min_overlap, + minimal_area=args.minimal_area, + hybrid_area=args.hybrid_area, + ), ) n.write_file(filename=args.out) @@ -66,7 +114,7 @@ def subset_network(): "--length", nargs=2, type=int, - help="Nb of day which must be cover by network, first minimum number of day and last maximum number of day," + help="Nb of days that must be covered by the network, first minimum number of day and last maximum number of day," "if value is negative, this bound won't be used", ) parser.add_argument( @@ -80,16 +128,21 @@ def subset_network(): action="store_true", help="Remove trash (network id == 0)", ) + parser.add_argument( + "-i", "--ids", nargs="+", type=int, help="List of network which will be extract" + ) parser.add_argument( "-p", "--period", nargs=2, type=int, - help="Start day and end day, if it's negative value we will add to day min and add to day max," - "if 0 it s not use", + help="Start day and end day, if it's a negative value we will add to day min and add to day max," + "if 0 it is not used", ) args = parser.parse_args() n = NetworkObservations.load_file(args.input, raw_data=True) + if args.ids is not None: + n = n.networks(args.ids) if args.length 
is not None: n = n.longer_than(*args.length) if args.remove_dead_end is not None: @@ -97,3 +150,153 @@ def subset_network(): if args.period is not None: n = n.extract_with_period(args.period) n.write_file(filename=args.out) + + +def quick_compare(): + parser = EddyParser( + """Tool to have a quick comparison between several network: + - N : network + - S : segment + - Obs : observations + """ + ) + parser.add_argument("ref", help="Identification file of reference") + parser.add_argument("others", nargs="+", help="Identifications files to compare") + parser.add_argument( + "--path_out", default=None, help="Save each group in separate file" + ) + args = parser.parse_args() + + kw = dict( + include_vars=[ + "longitude", + "latitude", + "time", + "track", + "segment", + "next_obs", + "previous_obs", + ] + ) + + if args.path_out is not None: + kw = dict() + + ref = NetworkObservations.load_file(args.ref, **kw) + print( + f"[ref] {args.ref} -> {ref.nb_network} network / {ref.nb_segment} segment / {len(ref)} obs " + f"-> {ref.network_size(0)} trash obs, " + f"{len(ref.merging_event())} merging, {len(ref.splitting_event())} spliting" + ) + others = { + other: NetworkObservations.load_file(other, **kw) for other in args.others + } + + # if args.path_out is not None: + # groups_ref, groups_other = run_compare(ref, others, **kwargs) + # if not exists(args.path_out): + # mkdir(args.path_out) + # for i, other_ in enumerate(args.others): + # dirname_ = f"{args.path_out}/{other_.replace('/', '_')}/" + # if not exists(dirname_): + # mkdir(dirname_) + # for k, v in groups_other[other_].items(): + # basename_ = f"other_{k}.nc" + # others[other_].index(v).write_file(filename=f"{dirname_}/{basename_}") + # for k, v in groups_ref[other_].items(): + # basename_ = f"ref_{k}.nc" + # ref.index(v).write_file(filename=f"{dirname_}/{basename_}") + # return + display_compare(ref, others) + + +def run_compare(ref, others): + outs = dict() + for i, (k, other) in enumerate(others.items()): + out = dict() + print( + f"[{i}] {k} -> {other.nb_network} network / {other.nb_segment} segment / {len(other)} obs " + f"-> {other.network_size(0)} trash obs, " + f"{len(other.merging_event())} merging, {len(other.splitting_event())} spliting" + ) + ref_id, other_id = ref.identify_in(other, size_min=2) + m = other_id != -1 + ref_id, other_id = ref_id[m], other_id[m] + out["same N(N)"] = m.sum() + out["same N(Obs)"] = ref.network_size(ref_id).sum() + + # For network which have same obs + ref_, other_ = ref.networks(ref_id), other.networks(other_id) + ref_segu, other_segu = ref_.identify_in(other_, segment=True) + m = other_segu == -1 + ref_track_no_match, _ = ref_.unique_segment_to_id(ref_segu[m]) + ref_segu, other_segu = ref_segu[~m], other_segu[~m] + m = ~in1d(ref_id, ref_track_no_match) + out["same NS(N)"] = m.sum() + out["same NS(Obs)"] = ref.network_size(ref_id[m]).sum() + + # Check merge/split + def follow_obs(d, i_follow): + m = i_follow != -1 + i_follow = i_follow[m] + t, x, y = ( + zeros(m.size, d.time.dtype), + zeros(m.size, d.longitude.dtype), + zeros(m.size, d.latitude.dtype), + ) + t[m], x[m], y[m] = ( + d.time[i_follow], + d.longitude[i_follow], + d.latitude[i_follow], + ) + return t, x, y + + def next_obs(d, i_seg): + last_i = d.index_segment_track[1][i_seg] - 1 + return follow_obs(d, d.next_obs[last_i]) + + def previous_obs(d, i_seg): + first_i = d.index_segment_track[0][i_seg] + return follow_obs(d, d.previous_obs[first_i]) + + tref, xref, yref = next_obs(ref_, ref_segu) + tother, xother, yother = next_obs(other_, 
other_segu) + + m = (tref == tother) & (xref == xother) & (yref == yother) + print(m.sum(), m.size, ref_segu.size, ref_track_no_match.size) + + tref, xref, yref = previous_obs(ref_, ref_segu) + tother, xother, yother = previous_obs(other_, other_segu) + + m = (tref == tother) & (xref == xother) & (yref == yother) + print(m.sum(), m.size, ref_segu.size, ref_track_no_match.size) + + ref_segu, other_segu = ref.identify_in(other, segment=True) + m = other_segu != -1 + out["same S(S)"] = m.sum() + out["same S(Obs)"] = ref.segment_size()[ref_segu[m]].sum() + + outs[k] = out + return outs + + +def display_compare(ref, others): + def display(value, ref=None): + if ref: + outs = [f"{v / ref[k] * 100:.1f}% ({v})" for k, v in value.items()] + else: + outs = value + return "".join([f"{v:^18}" for v in outs]) + + datas = run_compare(ref, others) + ref_ = { + "same N(N)": ref.nb_network, + "same N(Obs)": len(ref), + "same NS(N)": ref.nb_network, + "same NS(Obs)": len(ref), + "same S(S)": ref.nb_segment, + "same S(Obs)": len(ref), + } + print(" ", display(ref_.keys())) + for i, (_, v) in enumerate(datas.items()): + print(f"[{i:2}] ", display(v, ref=ref_)) diff --git a/src/py_eddy_tracker/data/Anticyclonic_20190223.nc b/src/py_eddy_tracker/data/Anticyclonic_20190223.nc index ce48c8d6..4ab8f226 100644 Binary files a/src/py_eddy_tracker/data/Anticyclonic_20190223.nc and b/src/py_eddy_tracker/data/Anticyclonic_20190223.nc differ diff --git a/src/py_eddy_tracker/data/__init__.py b/src/py_eddy_tracker/data/__init__.py index 4702af8f..bf062983 100644 --- a/src/py_eddy_tracker/data/__init__.py +++ b/src/py_eddy_tracker/data/__init__.py @@ -8,10 +8,11 @@ 20160515 adt None None longitude latitude . \ --cut 800 --fil 1 """ + import io import lzma -import tarfile from os import path +import tarfile import requests @@ -26,14 +27,20 @@ def get_remote_demo_sample(path): if path.endswith(".nc"): return io.BytesIO(content) else: - if path.endswith(".nc"): + try: + import py_eddy_tracker_sample_id + if path.endswith(".nc"): + return py_eddy_tracker_sample_id.get_remote_demo_sample(path) + content = open(py_eddy_tracker_sample_id.get_remote_demo_sample(f"{path}.tar.xz"), "rb").read() + except: + if path.endswith(".nc"): + content = requests.get( + f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}" + ).content + return io.BytesIO(content) content = requests.get( - f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}" + f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}.tar.xz" ).content - return io.BytesIO(content) - content = requests.get( - f"https://github.com/AntSimi/py-eddy-tracker-sample-id/raw/master/{path}.tar.xz" - ).content # Tar module could manage lzma tar, but it will apply uncompress for each extractfile tar = tarfile.open(mode="r", fileobj=io.BytesIO(lzma.decompress(content))) diff --git a/src/py_eddy_tracker/data/loopers_lumpkin_med.nc b/src/py_eddy_tracker/data/loopers_lumpkin_med.nc new file mode 100644 index 00000000..cf817424 Binary files /dev/null and b/src/py_eddy_tracker/data/loopers_lumpkin_med.nc differ diff --git a/src/py_eddy_tracker/dataset/grid.py b/src/py_eddy_tracker/dataset/grid.py index 6a4624dd..f15503b2 100644 --- a/src/py_eddy_tracker/dataset/grid.py +++ b/src/py_eddy_tracker/dataset/grid.py @@ -2,14 +2,14 @@ """ Class to load and manipulate RegularGrid and UnRegularGrid """ -import logging from datetime import datetime +import logging from cv2 import filter2D from matplotlib.path import Path as BasePath from netCDF4 import 
Dataset -from numba import njit, prange -from numba import types as numba_types +from numba import njit, prange, types as numba_types +import numpy as np from numpy import ( arange, array, @@ -28,9 +28,7 @@ isnan, linspace, ma, -) -from numpy import mean as np_mean -from numpy import ( + mean as np_mean, meshgrid, nan, nanmean, @@ -38,9 +36,9 @@ percentile, pi, radians, - round_, sin, sinc, + sqrt, where, zeros, ) @@ -52,6 +50,7 @@ from scipy.special import j1 from .. import VAR_DESCR +from ..data import get_demo_path from ..eddy_feature import Amplitude, Contours from ..generic import ( bbox_indice_regular, @@ -125,7 +124,7 @@ def value_on_regular_contour(x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size @njit(cache=True) def mean_on_regular_contour( - x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=None, nan_remove=False + x_g, y_g, z_g, m_g, vertices, num_fac=2, fixed_size=-1, nan_remove=False ): x_val, y_val = vertices[:, 0], vertices[:, 1] x_new, y_new = uniform_resample(x_val, y_val, num_fac, fixed_size) @@ -153,7 +152,7 @@ def _circle_from_equal_area(vertice): # last coordinates == first lon0, lat0 = lons[1:].mean(), lats[1:].mean() c_x, c_y = coordinates_to_local(lons, lats, lon0, lat0) - # Some time, edge is only a dot of few coordinates + # Sometimes, edge is only a dot of few coordinates d_lon = lons.max() - lons.min() d_lat = lats.max() - lats.min() if d_lon < 1e-7 and d_lat < 1e-7: @@ -239,7 +238,7 @@ def nb_pixel(self): class GridDataset(object): """ - Class to have basic tool on NetCDF Grid + Class for basic tools on NetCDF Grid """ __slots__ = ( @@ -258,6 +257,7 @@ class GridDataset(object): "global_attrs", "vars", "contours", + "nan_mask", ) GRAVITY = 9.807 @@ -267,15 +267,23 @@ class GridDataset(object): N = 1 def __init__( - self, filename, x_name, y_name, centered=None, indexs=None, unset=False + self, + filename, + x_name, + y_name, + centered=None, + indexs=None, + unset=False, + nan_masking=False, ): """ :param str filename: Filename to load :param str x_name: Name of longitude coordinates :param str y_name: Name of latitude coordinates :param bool,None centered: Allow to know how coordinates could be used with pixel - :param dict indexs: A dictionary which set indexs to use for non-coordinate dimensions + :param dict indexs: A dictionary that sets indexes to use for non-coordinate dimensions :param bool unset: Set to True to create an empty grid object without file + :param bool nan_masking: Set to True to replace data.mask with isnan method result """ self.dimensions = None self.variables_description = None @@ -286,6 +294,7 @@ def __init__( self.y_bounds = None self.x_dim = None self.y_dim = None + self.nan_mask = nan_masking self.centered = centered self.contours = None self.filename = filename @@ -294,12 +303,29 @@ def __init__( self.indexs = dict() if indexs is None else indexs if centered is None: logger.warning( - "We assume pixel position of grid is center for %s", filename + "We assume pixel position of grid is centered for %s", filename ) if not unset: + self.populate() + + def populate(self): + if self.dimensions is None: self.load_general_features() self.load() + def clean(self): + self.dimensions = None + self.variables_description = None + self.global_attrs = None + self.x_c = None + self.y_c = None + self.x_bounds = None + self.y_bounds = None + self.x_dim = None + self.y_dim = None + self.contours = None + self.vars = dict() + @property def is_centered(self): """Give True if pixel is described with its center's position or @@ -314,7 +340,7 @@ 
def is_centered(self): return self.centered def load_general_features(self): - """Load attrs to be stored in object""" + """Load attrs to be stored in object""" logger.debug( "Load general feature from %(filename)s", dict(filename=self.filename) ) @@ -392,12 +418,20 @@ def load(self): self.setup_coordinates() + @staticmethod + def get_mask(a): + if len(a.mask.shape): + m = a.mask + else: + m = ones(a.shape, dtype="bool") if a.mask else zeros(a.shape, dtype="bool") + return m + @staticmethod def c_to_bounds(c): """ - Centred coordinates to bounds coordinates + Centered coordinates to bounds coordinates - :param array c: centred coordinates to translate + :param array c: centered coordinates to translate :return: bounds coordinates """ bounds = concatenate((c, (2 * c[-1] - c[-2],))) @@ -409,9 +443,9 @@ def c_to_bounds(c): def setup_coordinates(self): x_name, y_name = self.coordinates if self.is_centered: - logger.info("Grid center") - self.x_c = self.vars[x_name].astype("float64") - self.y_c = self.vars[y_name].astype("float64") + # logger.info("Grid center") + self.x_c = array(self.vars[x_name].astype("float64")) + self.y_c = array(self.vars[y_name].astype("float64")) self.x_bounds = concatenate((self.x_c, (2 * self.x_c[-1] - self.x_c[-2],))) self.y_bounds = concatenate((self.y_c, (2 * self.y_c[-1] - self.y_c[-2],))) @@ -423,8 +457,8 @@ def setup_coordinates(self): self.y_bounds[-1] -= d_y[-1] / 2 else: - self.x_bounds = self.vars[x_name].astype("float64") - self.y_bounds = self.vars[y_name].astype("float64") + self.x_bounds = array(self.vars[x_name].astype("float64")) + self.y_bounds = array(self.vars[y_name].astype("float64")) if len(self.x_dim) == 1: self.x_c = self.x_bounds.copy() @@ -519,6 +553,11 @@ def grid(self, varname, indexs=None): if i_x > i_y: self.variables_description[varname]["infos"]["transpose"] = True self.vars[varname] = self.vars[varname].T + if self.nan_mask: + self.vars[varname] = ma.array( + self.vars[varname], + mask=isnan(self.vars[varname]), + ) if not hasattr(self.vars[varname], "mask"): self.vars[varname] = ma.array( self.vars[varname], @@ -558,7 +597,7 @@ def grid_tiles(self, varname, slice_x, slice_y): return data def high_filter(self, grid_name, w_cut, **kwargs): - """Return the grid high-pass filtered, by substracting to the grid the low-pass filter (default: order=1) + """Return the high-pass filtered grid, by substracting to the initial grid the low-pass filtered grid (default: order=1) :param grid_name: the name of the grid :param int, w_cut: the half-power wavelength cutoff (km) @@ -567,7 +606,7 @@ def high_filter(self, grid_name, w_cut, **kwargs): self.vars[grid_name] -= result def low_filter(self, grid_name, w_cut, **kwargs): - """Return the grid low-pass filtered (default: order=1) + """Return the low-pass filtered grid (default: order=1) :param grid_name: the name of the grid :param int, w_cut: the half-power wavelength cutoff (km) @@ -593,6 +632,7 @@ def eddy_identification( date, step=0.005, shape_error=55, + presampling_multiplier=10, sampling=50, sampling_method="visvalingam", pixel_limit=None, @@ -607,26 +647,31 @@ def eddy_identification( :param str grid_height: Grid name of Sea Surface Height :param str uname: Grid name of u speed component :param str vname: Grid name of v speed component - :param datetime.datetime date: Date which will be stored in object to date data + :param datetime.datetime date: Date to be stored in object to date data :param float,int step: Height between two layers in m - :param float,int shape_error: Maximal error 
allowed for outter contour in % + :param float,int shape_error: Maximal error allowed for outermost contour in % + :param int presampling_multiplier: + Evenly oversample the initial number of points in the contour by nb_pts x presampling_multiplier to fit circles :param int sampling: Number of points to store contours and speed profile - :param str sampling_method: Method to resample 'uniform' or 'visvalingam' + :param str sampling_method: Method to resample the stored contours, 'uniform' or 'visvalingam' :param (int,int),None pixel_limit: - Min and max number of pixels inside the inner and the outer contour to be considered as an eddy + Min and max number of pixels inside the inner and the outermost contour to be considered as an eddy :param float,None precision: Truncate values at the defined precision in m :param str force_height_unit: Unit used for height unit :param str force_speed_unit: Unit used for speed unit - :param dict kwargs: Argument given to amplitude + :param dict kwargs: Arguments given to amplitude (mle, nb_step_min, nb_step_to_be_mle). + Look at :py:meth:`py_eddy_tracker.eddy_feature.Amplitude` + The amplitude threshold is given by `step*nb_step_min` - :return: Return a list of 2 elements: Anticyclone and Cyclone + + :return: Return a list of 2 elements: Anticyclones and Cyclones :rtype: py_eddy_tracker.observations.observation.EddiesObservations .. minigallery:: py_eddy_tracker.GridDataset.eddy_identification """ if not isinstance(date, datetime): - raise Exception("Date argument be a datetime object") - # The inf limit must be in pixel and sup limit in surface + raise Exception("Date argument must be a datetime object") + # The inf limit must be in pixel and sup limit in surface if pixel_limit is None: pixel_limit = (4, 1000) @@ -651,10 +696,10 @@ def eddy_identification( # Get ssh grid data = self.grid(grid_height).astype("f8") - # In case of a reduce mask + # In case of a reduced mask if len(data.mask.shape) == 0 and not data.mask: data.mask = zeros(data.shape, dtype="bool") - # we remove noisy information + # we remove noisy data if precision is not None: data = (data / precision).round() * precision # Compute levels for ssh @@ -677,6 +722,7 @@ def eddy_identification( ) z_min, z_max = z_min_p, z_max_p + logger.debug("Levels from %f to %f", z_min, z_max) levels = arange(z_min - z_min % step, z_max - z_max % step + 2 * step, step) # Get x and y values @@ -729,7 +775,8 @@ def eddy_identification( for contour in contour_paths: if contour.used: continue - # FIXME : center could be not in contour and fit on raw sampling + # FIXME : center could be outside the contour due to the fit + # FIXME : warning : the fit is made on raw sampling _, _, _, aerr = contour.fit_circle() # Filter for shape @@ -748,11 +795,12 @@ def eddy_identification( # Test of the rotating sense: cyclone or anticyclone if has_value( - data, i_x_in, i_y_in, cvalues, below=anticyclonic_search + data.data, i_x_in, i_y_in, cvalues, below=anticyclonic_search ): continue - # FIXME : Maybe limit max must be replace with a maximum of surface + # Test the number of pixels within the outermost contour + # FIXME : Maybe limit max must be replaced with a maximum of surface if ( contour.nb_pixel < pixel_limit[0] or contour.nb_pixel > pixel_limit[1] @@ -760,6 +808,9 @@ def eddy_identification( contour.reject = 3 continue + # Here the considered contour passed shape_error test, masked_pixels test, + # values strictly above (AEs) or below (CEs) the contour, number_pixels test) + # Compute amplitude reset_centroid, amp 
= self.get_amplitude( contour, @@ -775,13 +826,12 @@ def eddy_identification( contour.reject = 4 continue if reset_centroid: - if self.is_circular(): centi = self.normalize_x_indice(reset_centroid[0]) else: centi = reset_centroid[0] centj = reset_centroid[1] - # To move in regular and unregular grid + # FIXME : To move in regular and unregular grid if len(x.shape) == 1: centlon_e = x[centi] centlat_e = y[centj] @@ -789,7 +839,7 @@ def eddy_identification( centlon_e = x[centi, centj] centlat_e = y[centi, centj] - # centlat_e and centlon_e must be index of maximum, we will loose some inner contour if it's not + # centlat_e and centlon_e must be indexes of maximum, we will loose some inner contour if it's not ( max_average_speed, speed_contour, @@ -807,7 +857,7 @@ def eddy_identification( pixel_min=pixel_limit[0], ) - # FIXME : Instantiate new EddyObservation object (high cost need to be reviewed) + # FIXME : Instantiate new EddyObservation object (high cost, need to be reviewed) obs = EddiesObservations( size=1, track_extra_variables=track_extra_variables, @@ -826,42 +876,59 @@ def eddy_identification( obs.amplitude[:] = amp.amplitude obs.speed_average[:] = max_average_speed obs.num_point_e[:] = contour.lon.shape[0] - xy_e = resample(contour.lon, contour.lat, **out_sampling) - obs.contour_lon_e[:], obs.contour_lat_e[:] = xy_e obs.num_point_s[:] = speed_contour.lon.shape[0] - xy_s = resample( - speed_contour.lon, speed_contour.lat, **out_sampling - ) - obs.contour_lon_s[:], obs.contour_lat_s[:] = xy_s - # FIXME : we use a contour without resampling - # First, get position based on innermost contour - centlon_i, centlat_i, _, _ = _fit_circle_path( - create_vertice(inner_contour.lon, inner_contour.lat) + # Evenly resample contours with nb_pts = nb_pts_original x presampling_multiplier + xy_i = uniform_resample( + inner_contour.lon, + inner_contour.lat, + num_fac=presampling_multiplier, ) - # Second, get speed-based radius based on contour of max uavg - centlon_s, centlat_s, eddy_radius_s, aerr_s = _fit_circle_path( - create_vertice(*xy_s) + xy_e = uniform_resample( + contour.lon, + contour.lat, + num_fac=presampling_multiplier, ) - # Compute again to use resampled contour - _, _, eddy_radius_e, aerr_e = _fit_circle_path( - create_vertice(*xy_e) + xy_s = uniform_resample( + speed_contour.lon, + speed_contour.lat, + num_fac=presampling_multiplier, ) + # First, get position of max SSH based on best fit circle with resampled innermost contour + centlon_i, centlat_i, _, _ = _fit_circle_path(create_vertice(*xy_i)) + obs.lon_max[:] = centlon_i + obs.lat_max[:] = centlat_i + + # Second, get speed-based radius, shape error, eddy center, area based on resampled contour of max uavg + centlon_s, centlat_s, eddy_radius_s, aerr_s = _fit_circle_path( + create_vertice(*xy_s) + ) obs.radius_s[:] = eddy_radius_s - obs.radius_e[:] = eddy_radius_e - obs.shape_error_e[:] = aerr_e obs.shape_error_s[:] = aerr_s obs.speed_area[:] = poly_area( *coordinates_to_local(*xy_s, lon0=centlon_s, lat0=centlat_s) ) + obs.lon[:] = centlon_s + obs.lat[:] = centlat_s + + # Third, compute effective radius, shape error, area from resampled effective contour + _, _, eddy_radius_e, aerr_e = _fit_circle_path( + create_vertice(*xy_e) + ) + obs.radius_e[:] = eddy_radius_e + obs.shape_error_e[:] = aerr_e obs.effective_area[:] = poly_area( *coordinates_to_local(*xy_e, lon0=centlon_s, lat0=centlat_s) ) - obs.lon[:] = centlon_s - obs.lat[:] = centlat_s - obs.lon_max[:] = centlon_i - obs.lat_max[:] = centlat_i + + # Finally, resample contours 
with output parameters + xy_e_f = resample(*xy_e, **out_sampling) + xy_s_f = resample(*xy_s, **out_sampling) + + obs.contour_lon_s[:], obs.contour_lat_s[:] = xy_s_f + obs.contour_lon_e[:], obs.contour_lat_e[:] = xy_e_f + if aerr > 99.9 or aerr_s > 99.9: logger.warning( "Strange shape at this step! shape_error : %f, %f", @@ -923,7 +990,7 @@ def get_uavg( pixel_min=3, ): """ - Calculate geostrophic speed around successive contours + Compute geostrophic speed around successive contours Returns the average """ # Init max speed to search maximum @@ -1035,7 +1102,7 @@ def load(self): @property def bounds(self): - """Give bound""" + """Give bounds""" return self.x_c.min(), self.x_c.max(), self.y_c.min(), self.y_c.max() def bbox_indice(self, vertices): @@ -1067,7 +1134,7 @@ def compute_pixel_path(self, x0, y0, x1, y1): pass def init_pos_interpolator(self): - logger.debug("Create a KdTree could be long ...") + logger.debug("Create a KdTree, could be long ...") self.index_interp = cKDTree( create_vertice(self.x_c.reshape(-1), self.y_c.reshape(-1)) ) @@ -1085,7 +1152,7 @@ def _low_filter(self, grid_name, w_cut, factor=8.0): bins = (x_array, y_array) x_flat, y_flat, z_flat = x.reshape((-1,)), y.reshape((-1,)), data.reshape((-1,)) - m = ~z_flat.mask + m = ~self.get_mask(z_flat) x_flat, y_flat, z_flat = x_flat[m], y_flat[m], z_flat[m] nb_value, _, _ = histogram2d(x_flat, y_flat, bins=bins) @@ -1152,6 +1219,12 @@ def setup_coordinates(self): raise Exception( "Coordinates in RegularGridDataset must be 1D array, or think to use UnRegularGridDataset" ) + dx = self.x_bounds[1:] - self.x_bounds[:-1] + dy = self.y_bounds[1:] - self.y_bounds[:-1] + if (dx < 0).any() or (dy < 0).any(): + raise Exception( + "Coordinates in RegularGridDataset must be strictly increasing" + ) self._x_step = (self.x_c[1:] - self.x_c[:-1]).mean() self._y_step = (self.y_c[1:] - self.y_c[:-1]).mean() @@ -1197,10 +1270,10 @@ def bbox_indice(self, vertices): def get_pixels_in(self, contour): """ - Get indices of pixels in contour. + Get indexes of pixels in contour. 
- :param vertice,Path contour: Contour which enclosed some pixels - :return: Indices of grid in contour + :param vertice,Path contour: Contour that encloses some pixels + :return: Indexes of grid in contour :rtype: array[int],array[int] """ if isinstance(contour, BasePath): @@ -1233,7 +1306,7 @@ def ystep(self): return self._y_step def compute_pixel_path(self, x0, y0, x1, y1): - """Give a series of indexes which describe the path between to position""" + """Give a series of indexes describing the path between two positions""" return compute_pixel_path( x0, y0, @@ -1246,9 +1319,13 @@ def compute_pixel_path(self, x0, y0, x1, y1): self.x_size, ) - def clean_land(self): + def clean_land(self, name): """Function to remove all land pixel""" - pass + mask_land = self.__class__(get_demo_path("mask_1_60.nc"), "lon", "lat") + x, y = meshgrid(self.x_c, self.y_c) + m = mask_land.interp("mask", x.reshape(-1), y.reshape(-1), "nearest") + data = self.grid(name) + self.vars[name] = ma.array(data, mask=m.reshape(x.shape).T) def is_circular(self): """Check if the grid is circular""" @@ -1321,6 +1398,24 @@ def kernel_lanczos(self, lat, wave_length, order=1): kernel[dist_norm > order] = 0 return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt) + def kernel_loess(self, lat, wave_length, order=1): + """ + https://fr.wikipedia.org/wiki/R%C3%A9gression_locale + """ + order = self.check_order(order) + half_x_pt, half_y_pt, dist_norm = self.estimate_kernel_shape( + lat, wave_length, order + ) + + def inc_func(xdist): + f = zeros(xdist.size) + f[abs(xdist) < 1] = 1 + return f + + kernel = (1 - abs(dist_norm) ** 3) ** 3 + kernel[abs(dist_norm) > order] = 0 + return self.finalize_kernel(kernel, order, half_x_pt, half_y_pt) + def kernel_bessel(self, lat, wave_length, order=1): """wave_length in km order must be int @@ -1396,7 +1491,8 @@ def convolve_filter_with_dynamic_kernel( tmp_matrix = ma.zeros((2 * d_lon + data.shape[0], k_shape[1])) tmp_matrix.mask = ones(tmp_matrix.shape, dtype=bool) # Slice to apply on input data - sl_lat_data = slice(max(0, i - d_lat), min(i + d_lat, data.shape[1])) + # +1 for upper bound, to take in acount this column + sl_lat_data = slice(max(0, i - d_lat), min(i + d_lat + 1, data.shape[1])) # slice to apply on temporary matrix to store input data sl_lat_in = slice( d_lat - (i - sl_lat_data.start), d_lat + (sl_lat_data.stop - i) @@ -1476,13 +1572,13 @@ def bessel_high_filter(self, grid_name, wave_length, order=1, lat_max=85, **kwar :param str grid_name: grid to filter, data will replace original one :param float wave_length: in km :param int order: order to use, if > 1 negative values of the cardinal sinus are present in kernel - :param float lat_max: absolute latitude above no filtering apply + :param float lat_max: absolute latitude, no filtering above :param dict kwargs: look at :py:meth:`RegularGridDataset.convolve_filter_with_dynamic_kernel` .. 
minigallery:: py_eddy_tracker.RegularGridDataset.bessel_high_filter """ logger.debug( - "Run filtering with wave of %(wave_length)s km and order of %(order)s ...", + "Run filtering with wavelength of %(wave_length)s km and order of %(order)s ...", dict(wave_length=wave_length, order=order), ) data_out = self.convolve_filter_with_dynamic_kernel( @@ -1574,7 +1670,7 @@ def spectrum_lonlat(self, grid_name, area=None, ref=None, **kwargs): (lat_content[0], lat_content[1] / ref_lat_content[1]), ) - def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=False): + def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=False, second=False): if not isinstance(schema, int) and schema < 1: raise Exception("schema must be a positive int") @@ -1597,12 +1693,17 @@ def compute_finite_difference(self, data, schema=1, mode="reflect", vertical=Fal data1[-schema:] = nan data2[:schema] = nan + # Distance for one degree d = self.EARTH_RADIUS * 2 * pi / 360 * 2 * schema + # Mulitply by 2 step if vertical: - d *= self.ystep + d *= self.ystep else: d *= self.xstep * cos(deg2rad(self.y_c)) - return (data1 - data2) / d + if second: + return (data1 + data2 - 2 * data) / (d ** 2 / 4) + else: + return (data1 - data2) / d def compute_stencil( self, data, stencil_halfwidth=4, mode="reflect", vertical=False @@ -1682,20 +1783,21 @@ def compute_stencil( self.x_c, self.y_c, data.data, - data.mask, + self.get_mask(data), self.EARTH_RADIUS, vertical=vertical, stencil_halfwidth=stencil_halfwidth, ) return ma.array(g, mask=m) - def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15): - self.add_uv(grid_height, uname, vname) + def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15, **kwargs): + self.add_uv(grid_height, uname, vname, **kwargs) latmax = 5 - _, (i_start, i_end) = self.nearest_grd_indice((0, 0), (-latmax, latmax)) + _, i_start = self.nearest_grd_indice(0, -latmax) + _, i_end = self.nearest_grd_indice(0, latmax) sl = slice(i_start, i_end) # Divide by sideral day - lat = self.y_c[sl] + lat = self.y_c gob = ( cos(deg2rad(lat)) * ones((self.x_c.shape[0], 1)) @@ -1709,39 +1811,26 @@ def add_uv_lagerloef(self, grid_height, uname="u", vname="v", schema=15): mode = "wrap" if self.is_circular() else "reflect" # fill data to compute a finite difference on all point - data = self.convolve_filter_with_dynamic_kernel( - grid_height, - self.kernel_bessel, - lat_max=10, - wave_length=500, - order=1, - extend=0.1, - ) - data = self.convolve_filter_with_dynamic_kernel( - data, self.kernel_bessel, lat_max=10, wave_length=500, order=1, extend=0.1 - ) - data = self.convolve_filter_with_dynamic_kernel( - data, self.kernel_bessel, lat_max=10, wave_length=500, order=1, extend=0.1 - ) + kw_filter = dict(kernel_func=self.kernel_bessel, order=1, extend=.1) + data = self.convolve_filter_with_dynamic_kernel(grid_height, wave_length=500, **kw_filter, lat_max=6+5+2+3) v_lagerloef = ( self.compute_finite_difference( - self.compute_finite_difference(data, mode=mode, schema=schema), - mode=mode, - schema=schema, - )[:, sl] - * gob - ) - u_lagerloef = ( - -self.compute_finite_difference( - self.compute_finite_difference(data, vertical=True, schema=schema), - vertical=True, - schema=schema, - )[:, sl] + self.compute_finite_difference(data, mode=mode, schema=1), + vertical=True, schema=1 + ) * gob ) - w = 1 - exp(-((lat / 2.2) ** 2)) - self.vars[vname][:, sl] = self.vars[vname][:, sl] * w + v_lagerloef * (1 - w) - self.vars[uname][:, sl] = self.vars[uname][:, sl] * w + 
u_lagerloef * (1 - w) + u_lagerloef = -self.compute_finite_difference(data, vertical=True, schema=schema, second=True) * gob + + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=195, **kw_filter, lat_max=6 + 5 +2) + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=416, **kw_filter, lat_max=6 + 5) + v_lagerloef = self.convolve_filter_with_dynamic_kernel(v_lagerloef, wave_length=416, **kw_filter, lat_max=6) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=195, **kw_filter, lat_max=6 + 5 +2) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=416, **kw_filter, lat_max=6 + 5) + u_lagerloef = self.convolve_filter_with_dynamic_kernel(u_lagerloef, wave_length=416, **kw_filter, lat_max=6) + w = 1 - exp(-((lat[sl] / 2.2) ** 2)) + self.vars[vname][:, sl] = self.vars[vname][:, sl] * w + v_lagerloef[:, sl] * (1 - w) + self.vars[uname][:, sl] = self.vars[uname][:, sl] * w + u_lagerloef[:, sl] * (1 - w) def add_uv(self, grid_height, uname="u", vname="v", stencil_halfwidth=4): r"""Compute a u and v grid @@ -1808,13 +1897,13 @@ def add_uv(self, grid_height, uname="u", vname="v", stencil_halfwidth=4): ) def speed_coef_mean(self, contour): - """Some nan can be computed over contour if we are near border, + """Some nan can be computed over contour if we are near borders, something to explore """ return mean_on_regular_contour( self.x_c, self.y_c, - self._speed_ev, + self._speed_ev.data, self._speed_ev.mask, contour.vertices, nan_remove=True, @@ -1822,14 +1911,15 @@ def speed_coef_mean(self, contour): def init_speed_coef(self, uname="u", vname="v"): """Draft""" - self._speed_ev = (self.grid(uname) ** 2 + self.grid(vname) ** 2) ** 0.5 + u, v = self.grid(uname), self.grid(vname) + self._speed_ev = sqrt(u * u + v * v) def display(self, ax, name, factor=1, ref=None, **kwargs): """ - :param matplotlib.axes.Axes ax: matplotlib axes use to draw + :param matplotlib.axes.Axes ax: matplotlib axes used to draw :param str,array name: variable to display, could be an array :param float factor: multiply grid by - :param float,None ref: if define use like west bound + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.pcolormesh` .. minigallery:: py_eddy_tracker.RegularGridDataset.display @@ -1848,10 +1938,10 @@ def display(self, ax, name, factor=1, ref=None, **kwargs): def contour(self, ax, name, factor=1, ref=None, **kwargs): """ - :param matplotlib.axes.Axes ax: matplotlib axes use to draw + :param matplotlib.axes.Axes ax: matplotlib axes used to draw :param str,array name: variable to display, could be an array :param float factor: multiply grid by - :param float,None ref: if define use like west bound + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.contour` .. minigallery:: py_eddy_tracker.RegularGridDataset.contour @@ -1890,14 +1980,6 @@ def regrid(self, other, grid_name, new_name=None): # self.variables_description[new_name]['infos'] = False # self.variables_description[new_name]['kwargs']['dimensions'] = ... 
- @staticmethod - def get_mask(a): - if len(a.mask.shape): - m = a.mask - else: - m = ones(a.shape) if a.mask else zeros(a.shape) - return m - def interp(self, grid_name, lons, lats, method="bilinear"): """ Compute z over lons, lats @@ -1912,17 +1994,32 @@ def interp(self, grid_name, lons, lats, method="bilinear"): g = self.grid(grid_name) m = self.get_mask(g) return interp2d_geo( - self.x_c, self.y_c, g, m, lons, lats, nearest=method == "nearest" + self.x_c, self.y_c, g.data, m, lons, lats, nearest=method == "nearest" ) - def uv_for_advection(self, u_name, v_name, time_step=600, backward=False, factor=1): + def uv_for_advection( + self, + u_name=None, + v_name=None, + time_step=600, + h_name=None, + backward=False, + factor=1, + ): """ Get U,V to be used in degrees with precomputed time step - :param str,array u_name: U field to advect obs - :param str,array v_name: V field to advect obs + :param None,str,array u_name: U field to advect obs, if h_name is None + :param None,str,array v_name: V field to advect obs, if h_name is None + :param None,str,array h_name: H field to compute UV to advect obs, if u_name and v_name are None :param int time_step: Number of second for each advection """ + if h_name is not None: + u_name, v_name = "u", "v" + if u_name not in self.vars: + self.add_uv(h_name) + self.vars.pop(h_name, None) + u = (self.grid(u_name) if isinstance(u_name, str) else u_name).copy() v = (self.grid(v_name) if isinstance(v_name, str) else v_name).copy() # N seconds / 1 degrees in m @@ -1933,19 +2030,19 @@ def uv_for_advection(self, u_name, v_name, time_step=600, backward=False, factor u = -u v = -v m = u.mask + v.mask - return u, v, m + return u.data, v.data, m def advect(self, x, y, u_name, v_name, nb_step=10, rk4=True, **kw): """ At each call it will update position in place with u & v field - It's a dummy advection which use only one layer of current + It's a dummy advection using only one layer of current :param array x: Longitude of obs to move :param array y: Latitude of obs to move :param str,array u_name: U field to advect obs :param str,array v_name: V field to advect obs - :param int nb_step: Number of iteration before to release data + :param int nb_step: Number of iterations before releasing data .. 
minigallery:: py_eddy_tracker.GridDataset.advect """ @@ -1962,13 +2059,13 @@ def filament( """ Produce filament with concatenation of advection - It's a dummy advection which use only one layer of current + It's a dummy advection using only one layer of current :param array x: Longitude of obs to move :param array y: Latitude of obs to move :param str,array u_name: U field to advect obs :param str,array v_name: V field to advect obs - :param int nb_step: Number of iteration before to release data + :param int nb_step: Number of iteration before releasing data :param int filament_size: Number of point by filament :return: x,y for a line @@ -2014,7 +2111,7 @@ def advect_rk4(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): v00, v01, v10, v11 = 0.0, 0.0, 0.0, 0.0 # On each particle for i in prange(x.size): - # If particle are not valid => continue + # If particle is not valid => continue if m[i]: continue x_, y_ = x[i], y[i] @@ -2032,7 +2129,7 @@ def advect_rk4(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): masked, u00, u01, u10, u11, v00, v01, v10, v11 = get_uv_quad( ii_, jj_, u_g, v_g, m_g, nb_x ) - # The 3 following could be in cache operation but this one must be test in any case + # The 3 following could be in cache operation but this one must be tested in any case if masked: x_, y_ = nan, nan m[i] = True @@ -2056,7 +2153,7 @@ def advect_rk4(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): m[i] = True break u2, v2 = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) - # k3, slope at middle with update guess position + # k3, slope at middle with updated guess position x2, y2 = x_ + u2 * 0.5, y_ + v2 * 0.5 ii_, jj_, xd, yd = get_grid_indices( x_ref, y_ref, x_step, y_step, x2, y2, nb_x @@ -2074,7 +2171,7 @@ def advect_rk4(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): m[i] = True break u3, v3 = interp_uv(xd, yd, u00, u01, u10, u11, v00, v01, v10, v11) - # k4, slope at end with update guess position + # k4, slope at end with updated guess position x3, y3 = x_ + u3, y_ + v3 ii_, jj_, xd, yd = get_grid_indices( x_ref, y_ref, x_step, y_step, x3, y3, nb_x @@ -2110,14 +2207,14 @@ def advect(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 nb_x_ = x_g.size nb_x = nb_x_ if is_circular else 0 - # Indices which should be never exist + # Indexes which should be never exist i0_old, j0_old = -100000, -100000 masked = False u00, u01, u10, u11 = 0.0, 0.0, 0.0, 0.0 v00, v01, v10, v11 = 0.0, 0.0, 0.0, 0.0 # On each particule for i in prange(x.size): - # If particule are not valid => continue + # If particule is not valid => continue if m[i]: continue # Iterate on whole steps @@ -2125,9 +2222,9 @@ def advect(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): i0, j0, xd, yd = get_grid_indices( x_ref, y_ref, x_step, y_step, x[i], y[i], nb_x ) - # corner are the same need only a new xd and yd + # corners are the same, need only a new xd and yd if i0 != i0_old or j0 != j0_old: - # Need to be store only on change + # Need to be stored only on change i0_old, j0_old = i0, j0 if not is_circular and (i0 < 0 or i0 > nb_x_): masked = True @@ -2147,42 +2244,41 @@ def advect(x_g, y_g, u_g, v_g, m_g, x, y, m, nb_step): @njit(cache=True, fastmath=True) def compute_pixel_path(x0, y0, x1, y1, x_ori, y_ori, x_step, y_step, nb_x): - """Give a serie of indexes describing the path between two position""" + """Give a serie of indexes describing the path between two positions""" # index nx = x0.shape[0] i_x0 = empty(nx, dtype=numba_types.int_) i_x1 = empty(nx, dtype=numba_types.int_) 
i_y0 = empty(nx, dtype=numba_types.int_) i_y1 = empty(nx, dtype=numba_types.int_) - # Because round_ is not accepted with array in numba for i in range(nx): - i_x0[i] = round_(((x0[i] - x_ori) % 360) / x_step) - i_x1[i] = round_(((x1[i] - x_ori) % 360) / x_step) - i_y0[i] = round_((y0[i] - y_ori) / y_step) - i_y1[i] = round_((y1[i] - y_ori) / y_step) + i_x0[i] = np.round(((x0[i] - x_ori) % 360) / x_step) + i_x1[i] = np.round(((x1[i] - x_ori) % 360) / x_step) + i_y0[i] = np.round((y0[i] - y_ori) / y_step) + i_y1[i] = np.round((y1[i] - y_ori) / y_step) # Delta index of x d_x = i_x1 - i_x0 d_x = (d_x + nb_x // 2) % nb_x - (nb_x // 2) i_x1 = i_x0 + d_x # Delta index of y d_y = i_y1 - i_y0 - # max and abs sum doesn't work on array? + # max and abs sum do not work on array? d_max = empty(nx, dtype=numba_types.int32) nb_value = 0 for i in range(nx): d_max[i] = max(abs(d_x[i]), abs(d_y[i])) - # Compute number of pixel which we go trought + # Compute number of pixel we go trought nb_value += d_max[i] + 1 - # Create an empty array to store value of pixel across the travel + # Create an empty array to store value of pixel across the path i_g = empty(nb_value, dtype=numba_types.int32) j_g = empty(nb_value, dtype=numba_types.int32) # Index to determine the position in the global array ii = 0 - # Iteration on each travel + # Iteration on each path for i, delta in enumerate(d_max): - # If the travel don't cross multiple pixel + # If the path doesn't cross multiple pixels if delta == 0: i_g[ii : ii + delta + 1] = i_x0[i] j_g[ii : ii + delta + 1] = i_y0[i] @@ -2196,7 +2292,7 @@ def compute_pixel_path(x0, y0, x1, y1, x_ori, y_ori, x_step, y_step, nb_x): sup = -1 if d_x[i] < 0 else 1 i_g[ii : ii + delta + 1] = arange(i_x0[i], i_x1[i] + sup, sup) j_g[ii : ii + delta + 1] = i_y0[i] - # In case of multiple direction + # In case of multiple directions else: a = (i_x1[i] - i_x0[i]) / float(i_y1[i] - i_y0[i]) if abs(d_x[i]) >= abs(d_y[i]): @@ -2239,26 +2335,51 @@ def __init__(self): self.datasets = list() @classmethod - def from_netcdf_cube(cls, filename, x_name, y_name, t_name, heigth=None): + def from_netcdf_cube(cls, filename, x_name, y_name, t_name, heigth=None, **kwargs): new = cls() with Dataset(filename) as h: for i, t in enumerate(h.variables[t_name][:]): - d = RegularGridDataset(filename, x_name, y_name, indexs={t_name: i}) + d = RegularGridDataset( + filename, x_name, y_name, indexs={t_name: i}, **kwargs + ) if heigth is not None: d.add_uv(heigth) new.datasets.append((t, d)) return new @classmethod - def from_netcdf_list(cls, filenames, t, x_name, y_name, indexs=None, heigth=None): + def from_netcdf_list( + cls, filenames, t, x_name, y_name, indexs=None, heigth=None, **kwargs + ): new = cls() - for i, t in enumerate(t): - d = RegularGridDataset(filenames[i], x_name, y_name, indexs=indexs) + for i, _t in enumerate(t): + filename = filenames[i] + logger.debug(f"load file {i:02d}/{len(t)} t={_t} : {filename}") + d = RegularGridDataset(filename, x_name, y_name, indexs=indexs, **kwargs) if heigth is not None: d.add_uv(heigth) - new.datasets.append((t, d)) + new.datasets.append((_t, d)) return new + @property + def are_loaded(self): + return ~array([d.dimensions is None for _, d in self.datasets]) + + def __repr__(self): + nb_dataset = len(self.datasets) + return f"{self.are_loaded.sum()}/{nb_dataset} datasets loaded" + + def shift_files(self, t, filename, heigth=None, **rgd_kwargs): + """Add next file to the list and remove the oldest""" + + self.datasets = self.datasets[1:] + + d = 
RegularGridDataset(filename, **rgd_kwargs) + if heigth is not None: + d.add_uv(heigth) + self.datasets.append((t, d)) + logger.debug(f"shift and adding i={len(self.datasets)} t={t} : {filename}") + def interp(self, grid_name, t, lons, lats, method="bilinear"): """ Compute z over lons, lats @@ -2287,9 +2408,19 @@ def __iter__(self): for _, d in self.datasets: yield d + @property + def time(self): + return array([t for t, _ in self.datasets]) + + @property + def period(self): + t = self.time + return t.min(), t.max() + def __getitem__(self, item): for t, d in self.datasets: if t == item: + d.populate() return d raise KeyError(item) @@ -2369,17 +2500,23 @@ def filament( t += dt yield t, f_x, f_y + def reset_grids(self, N=None): + if N is not None: + m = self.are_loaded + if m.sum() > N: + for i in where(m)[0]: + self.datasets[i][1].clean() + def advect( self, x, y, - u_name, - v_name, t_init, mask_particule=None, nb_step=10, time_step=600, rk4=True, + reset_grid=None, **kw, ): """ @@ -2387,15 +2524,18 @@ def advect( :param array x: Longitude of obs to move :param array y: Latitude of obs to move - :param str,array u_name: U field to advect obs - :param str,array v_name: V field to advect obs + :param float t_init: time to start advection + :param array,None mask_particule: advect only i mask is True :param int nb_step: Number of iteration before to release data :param int time_step: Number of second for each advection + :param bool rk4: Use rk4 algorithm instead of finite difference + :param int reset_grid: Delete all loaded data in cube if there are more than N grid loaded - :return: x,y position + :return: t,x,y position .. minigallery:: py_eddy_tracker.GridCollection.advect """ + self.reset_grids(reset_grid) backward = kw.get("backward", False) if backward: generator = self.get_previous_time_step(t_init) @@ -2406,9 +2546,9 @@ def advect( dt = nb_step * time_step t_step = time_step t0, d0 = generator.__next__() - u0, v0, m0 = d0.uv_for_advection(u_name, v_name, time_step, **kw) + u0, v0, m0 = d0.uv_for_advection(time_step=time_step, **kw) t1, d1 = generator.__next__() - u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw) + u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw) t0 = t0 * 86400 t1 = t1 * 86400 t = t_init * 86400 @@ -2418,11 +2558,12 @@ def advect( else: mask_particule += isnan(x) + isnan(y) while True: + logger.debug(f"advect : t={t/86400}") if (backward and t <= t1) or (not backward and t >= t1): t0, u0, v0, m0 = t1, u1, v1, m1 t1, d1 = generator.__next__() t1 = t1 * 86400 - u1, v1, m1 = d1.uv_for_advection(u_name, v_name, time_step, **kw) + u1, v1, m1 = d1.uv_for_advection(time_step=time_step, **kw) w = 1 - (arange(t, t + dt, t_step) - t0) / (t1 - t0) half_w = t_step / 2.0 / (t1 - t0) advect_( @@ -2444,27 +2585,45 @@ def advect( yield t, x, y def get_next_time_step(self, t_init): - first = True for i, (t, dataset) in enumerate(self.datasets): if t < t_init: continue - if first: - first = False - yield self.datasets[i - 1] + dataset.populate() + logger.debug(f"i={i}, t={t}, dataset={dataset}") yield t, dataset def get_previous_time_step(self, t_init): - first = True i = len(self.datasets) for t, dataset in reversed(self.datasets): i -= 1 if t > t_init: continue - if first: - first = False - yield self.datasets[i + 1] + dataset.populate() + logger.debug(f"i={i}, t={t}, dataset={dataset}") yield t, dataset + def path(self, x0, y0, *args, nb_time=2, **kwargs): + """ + At each call it will update position in place with u & v field + + :param array x0: 
Longitude of obs to move + :param array y0: Latitude of obs to move + :param int nb_time: Number of iteration for particle + :param dict kwargs: look at :py:meth:`GridCollection.advect` + + :return: t,x,y + + .. minigallery:: py_eddy_tracker.GridCollection.path + """ + particles = self.advect(x0.copy(), y0.copy(), *args, **kwargs) + t = empty(nb_time + 1, dtype="f8") + x = empty((nb_time + 1, x0.size), dtype=x0.dtype) + y = empty(x.shape, dtype=y0.dtype) + t[0], x[0], y[0] = kwargs.get("t_init"), x0, y0 + for i in range(nb_time): + t[i + 1], x[i + 1], y[i + 1] = particles.__next__() + return t, x, y + @njit(cache=True) def advect_t(x_g, y_g, u_g0, v_g0, m_g0, u_g1, v_g1, m_g1, x, y, m, weigths, half_w=0): @@ -2474,7 +2633,7 @@ def advect_t(x_g, y_g, u_g0, v_g0, m_g0, u_g1, v_g1, m_g1, x, y, m, weigths, hal is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 nb_x_ = x_g.size nb_x = nb_x_ if is_circular else 0 - # Indices which should be never exist + # Indexes that should never exist i0_old, j0_old = -100000, -100000 m0, m1 = False, False u000, u001, u010, u011 = 0.0, 0.0, 0.0, 0.0 @@ -2483,7 +2642,7 @@ def advect_t(x_g, y_g, u_g0, v_g0, m_g0, u_g1, v_g1, m_g1, x, y, m, weigths, hal v100, v101, v110, v111 = 0.0, 0.0, 0.0, 0.0 # On each particle for i in prange(x.size): - # If particle are not valid => continue + # If particle is not valid => continue if m[i]: continue # Iterate on whole steps @@ -2492,7 +2651,7 @@ def advect_t(x_g, y_g, u_g0, v_g0, m_g0, u_g1, v_g1, m_g1, x, y, m, weigths, hal x_ref, y_ref, x_step, y_step, x[i], y[i], nb_x ) if i0 != i0_old or j0 != j0_old: - # Need to be store only on change + # Need to be stored only on change i0_old, j0_old = i0, j0 if not is_circular and (i0 < 0 or i0 > nb_x_): m0, m1 = True, True @@ -2523,8 +2682,8 @@ def get_uv_quad(i0, j0, u, v, m, nb_x=0): """ Return u/v for (i0, j0), (i1, j0), (i0, j1), (i1, j1) - :param int i0: indices of longitude - :param int j0: indices of latitude + :param int i0: indexes of longitude + :param int j0: indexes of latitude :param array[float] u: current along x axis :param array[float] v: current along y axis :param array[bool] m: flag to know if position is valid @@ -2536,6 +2695,11 @@ def get_uv_quad(i0, j0, u, v, m, nb_x=0): i1, j1 = i0 + 1, j0 + 1 if nb_x != 0: i1 %= nb_x + i_max, j_max = m.shape + + if i1 >= i_max or j1 >= j_max: + return True, nan, nan, nan, nan, nan, nan, nan, nan + if m[i0, j0] or m[i0, j1] or m[i1, j0] or m[i1, j1]: return True, nan, nan, nan, nan, nan, nan, nan, nan # Extract value for u and v @@ -2547,7 +2711,7 @@ def get_uv_quad(i0, j0, u, v, m, nb_x=0): @njit(cache=True, fastmath=True) def get_grid_indices(x0, y0, x_step, y_step, x, y, nb_x=0): """ - Return grid indices and weight + Return grid indexes and weight :param float x0: first longitude :param float y0: first latitude @@ -2557,7 +2721,7 @@ def get_grid_indices(x0, y0, x_step, y_step, x, y, nb_x=0): :param float y: latitude to interpolate :param int nb_x: If different of 0 we check if wrapping is needed - :return: indices and weight + :return: indexes and weight :rtype: int,int,float,float """ i, j = (x - x0) / x_step, (y - y0) / y_step @@ -2609,7 +2773,7 @@ def advect_t_rk4( v100, v101, v110, v111 = 0.0, 0.0, 0.0, 0.0 # On each particle for i in prange(x.size): - # If particle are not valid => continue + # If particle is not valid => continue if m[i]: continue x_, y_ = x[i], y[i] @@ -2630,7 +2794,7 @@ def advect_t_rk4( (m1, u100, u101, u110, u111, v100, v101, v110, v111) = get_uv_quad( ii_, jj_, u_g1, v_g1, 
m_g1, nb_x ) - # The 3 following could be in cache operation but this one must be test in any case + # The 3 following could be in cache operation but this one must be tested in any case if m0 or m1: x_, y_ = nan, nan m[i] = True @@ -2662,7 +2826,7 @@ def advect_t_rk4( u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) w_ = w - half_w u2, v2 = u0_ * w_ + u1_ * (1 - w_), v0_ * w_ + v1_ * (1 - w_) - # k3, slope at middle with update guess position + # k3, slope at middle with updated guess position x2, y2 = x_ + u2 * 0.5, y_ + v2 * 0.5 ii_, jj_, xd, yd = get_grid_indices( x_ref, y_ref, x_step, y_step, x2, y2, nb_x @@ -2685,7 +2849,7 @@ def advect_t_rk4( u0_, v0_ = interp_uv(xd, yd, u000, u001, u010, u011, v000, v001, v010, v011) u1_, v1_ = interp_uv(xd, yd, u100, u101, u110, u111, v100, v101, v110, v111) u3, v3 = u0_ * w_ + u1_ * (1 - w_), v0_ * w_ + v1_ * (1 - w_) - # k4, slope at end with update guess position + # k4, slope at end with updated guess position x3, y3 = x_ + u3, y_ + v3 ii_, jj_, xd, yd = get_grid_indices( x_ref, y_ref, x_step, y_step, x3, y3, nb_x @@ -2732,7 +2896,7 @@ def compute_stencil(x, y, h, m, earth_radius, vertical=False, stencil_halfwidth= :param array x: longitude coordinates :param array y: latitude coordinates :param array h: 2D array to derivate - :param array m: mask associate to h to know where are invalid data + :param array m: mask associated to h to know where are invalid data :param float earth_radius: Earth radius in m :param bool vertical: if True stencil will be vertical (along y) :param int stencil_halfwidth: from 1 to 4 to specify maximal kernel usable @@ -2776,7 +2940,7 @@ def compute_stencil(x, y, h, m, earth_radius, vertical=False, stencil_halfwidth= h_3, h_2, h_1, h0 = h[-4, j], h[-3, j], h[-2, j], h[-1, j] m_3, m_2, m_1, m0 = m[-4, j], m[-3, j], m[-2, j], m[-1, j] else: - m_3, m_2, m_1, m0 = False, False, False, False + m_3, m_2, m_1, m0 = True, True, True, True h1, h2, h3, h4 = h[0, j], h[1, j], h[2, j], h[3, j] m1, m2, m3, m4 = m[0, j], m[1, j], m[2, j], m[3, j] for i in range(nb_x): @@ -2826,7 +2990,7 @@ def compute_stencil(x, y, h, m, earth_radius, vertical=False, stencil_halfwidth= grad[i, j] = (h3 - h_3 + 9 * (h_2 - h2) + 45 * (h1 - h_1)) / 60 * d_ m_out[i, j] = False continue - # If all value of buffer are available + # If all values of buffer are available grad[i, j] = ( (3 * (h_4 - h4) + 32 * (h3 - h_3) + 168 * (h_2 - h2) + 672 * (h1 - h_1)) / 840 diff --git a/src/py_eddy_tracker/eddy_feature.py b/src/py_eddy_tracker/eddy_feature.py index 3e6c86ea..8bc139ab 100644 --- a/src/py_eddy_tracker/eddy_feature.py +++ b/src/py_eddy_tracker/eddy_feature.py @@ -8,8 +8,7 @@ from matplotlib.cm import get_cmap from matplotlib.colors import Normalize from matplotlib.figure import Figure -from numba import njit -from numba import types as numba_types +from numba import njit, types as numba_types from numpy import ( array, concatenate, @@ -31,7 +30,7 @@ class Amplitude(object): """ Class to calculate *amplitude* and counts of *local maxima/minima* - within a closed region of a sea level anomaly field. + within a closed region of a sea surface height field. 
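# Side note on the widest branch of compute_stencil above: the weights
# (3, 32, 168, 672) / 840 are the classical 8th-order centred-difference
# coefficients (1/280, 4/105, 1/5, 4/5).  A standalone numpy check on an
# analytic profile, independent of the package:
import numpy as np

dx = 0.1
h = np.sin((np.arange(9) - 4) * dx)            # 9 samples centred on x = 0
d = (3 * (h[0] - h[8]) + 32 * (h[7] - h[1])
     + 168 * (h[2] - h[6]) + 672 * (h[5] - h[3])) / 840 / dx
print(d)                                        # ~1.0, i.e. cos(0), the exact derivative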
""" EPSILON = 1e-8 @@ -61,13 +60,13 @@ def __init__( """ Create amplitude object - :param Contours contour: - :param float contour_height: - :param array data: - :param float interval: - :param int mle: maximum number of local maxima in contour - :param int nb_step_min: number of interval to consider like an eddy - :param int nb_step_to_be_mle: number of interval to be consider like another maxima + :param Contours contour: usefull class defined below + :param float contour_height: field value of the contour + :param array data: grid + :param float interval: step between two contours + :param int mle: maximum number of local extrema in contour + :param int nb_step_min: minimum number of intervals to consider the contour as an eddy + :param int nb_step_to_be_mle: number of intervals to be considered as another extrema """ # Height of the contour @@ -102,8 +101,9 @@ def __init__( self.nb_pixel = i_x.shape[0] # Only pixel in contour + # FIXME : change sla by ssh as the grid can be adt? self.sla = data[contour.pixels_index] - # Amplitude which will be provide + # Amplitude which will be provided self.amplitude = 0 # Maximum local extrema accepted self.mle = mle @@ -117,7 +117,7 @@ def all_pixels_below_h0(self, level): Check CSS11 criterion 1: The SSH values of all of the pixels are below a given SSH threshold for cyclonic eddies. """ - # In some case pixel value must be very near of contour bounds + # In some cases pixel value may be very close to the contour bounds if self.sla.mask.any() or ((self.sla.data - self.h_0) > self.EPSILON).any(): return False else: @@ -293,10 +293,10 @@ class Contours(object): Attributes: contour: - A matplotlib contour object of high-pass filtered SLA + A matplotlib contour object of high-pass filtered SSH eddy: - A tracklist object holding the SLA data + A tracklist object holding the SSH data grd: A grid object @@ -406,7 +406,7 @@ def __init__(self, x, y, z, levels, wrap_x=False, keep_unclose=False): fig = Figure() ax = fig.add_subplot(111) if wrap_x: - logger.debug("wrapping activate to compute contour") + logger.debug("wrapping activated to compute contour") x = concatenate((x, x[:1] + 360)) z = ma.concatenate((z, z[:1])) logger.debug("X shape : %s", x.shape) @@ -600,10 +600,10 @@ def display( 4. - Amplitude criterion (yellow) :param str field: Must be 'shape_error', 'x', 'y' or 'radius'. - If define display_criterion is not use. - bins argument must be define - :param array bins: bins use to colorize contour - :param str cmap: Name of cmap to use for field display + If defined display_criterion is not use. + bins argument must be defined + :param array bins: bins used to colorize contour + :param str cmap: Name of cmap for field display :param dict kwargs: look at :py:meth:`matplotlib.collections.LineCollection` .. 
minigallery:: py_eddy_tracker.Contours.display @@ -645,7 +645,7 @@ def display( paths.append(i.vertices) local_kwargs = kwargs.copy() if "color" not in kwargs: - local_kwargs["color"] = collection.get_color() + local_kwargs["color"] = collection.get_edgecolor() local_kwargs.pop("label", None) elif j != 0: local_kwargs.pop("label", None) @@ -688,7 +688,7 @@ def display( ax.autoscale_view() def label_contour_unused_which_contain_eddies(self, eddies): - """Select contour which contain several eddies""" + """Select contour containing several eddies""" if eddies.sign_type == 1: # anticyclonic sl = slice(None, -1) @@ -783,7 +783,7 @@ def index_from_nearest_path_with_pt_in_bbox_( d_x = x_value[i_elt_pt] - xpt_ if abs(d_x) > 180: d_x = (d_x + 180) % 360 - 180 - dist = d_x ** 2 + (y_value[i_elt_pt] - ypt) ** 2 + dist = d_x**2 + (y_value[i_elt_pt] - ypt) ** 2 if dist < dist_ref: dist_ref = dist i_ref = i_elt_c diff --git a/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py b/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py index 7baaffd3..b0d4abfa 100644 --- a/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py +++ b/src/py_eddy_tracker/featured_tracking/old_tracker_reference.py @@ -8,7 +8,6 @@ class CheltonTracker(Model): - __slots__ = tuple() GROUND = RegularGridDataset( @@ -21,13 +20,13 @@ def cost_function(records_in, records_out, distance): return distance def mask_function(self, other, distance): - """We mask link with ellips and ratio""" - # Compute Parameter of ellips + """We mask link with ellipse and ratio""" + # Compute Parameter of ellipse minor, major = 1.05, 1.5 - y = self.basic_formula_ellips_major_axis( + y = self.basic_formula_ellipse_major_axis( self.lat, degrees=True, c0=minor, cmin=minor, cmax=major, lat1=23, lat2=5 ) - # mask from ellips + # mask from ellipse mask = self.shifted_ellipsoid_degrees_mask( other, minor=minor, major=y # Minor can be bigger than major?? ) diff --git a/src/py_eddy_tracker/generic.py b/src/py_eddy_tracker/generic.py index 35c23817..2fdb737a 100644 --- a/src/py_eddy_tracker/generic.py +++ b/src/py_eddy_tracker/generic.py @@ -3,8 +3,7 @@ Tool method which use mostly numba """ -from numba import njit, prange -from numba import types as numba_types +from numba import njit, prange, types as numba_types from numpy import ( absolute, arcsin, @@ -30,7 +29,7 @@ @njit(cache=True) def count_consecutive(mask): """ - Count consecutive event every False flag count restart + Count consecutive events every False flag count restart :param array[bool] mask: event to count :return: count when consecutive event @@ -50,7 +49,7 @@ def count_consecutive(mask): @njit(cache=True) def reverse_index(index, nb): """ - Compute a list of index, which are not in index. + Compute a list of indices, which are not in index. :param array index: index of group which will be set to False :param array nb: Count for each group @@ -65,17 +64,18 @@ def reverse_index(index, nb): @njit(cache=True) def build_index(groups): - """We expected that variable is monotonous, and return index for each step change. + """We expect that variable is monotonous, and return index for each step change. 
- :param array groups: array which contain group to be separated - :return: (first_index of each group, last_index of each group, value to shift group) + :param array groups: array that contains groups to be separated + :return: (first_index of each group, last_index of each group, value to shift groups) :rtype: (array, array, int) - Examples - -------- + :Example: + >>> build_index(array((1, 1, 3, 4, 4))) (array([0, 2, 2, 3]), array([2, 2, 3, 5]), 1) """ + i0, i1 = groups.min(), groups.max() amplitude = i1 - i0 + 1 # Index of first observation for each group @@ -83,33 +83,32 @@ def build_index(groups): for i, group in enumerate(groups[:-1]): # Get next value to compare next_group = groups[i + 1] - # if different we need to set index for all group between the 2 values + # if different we need to set index for all groups between the 2 values if group != next_group: first_index[group - i0 + 1 : next_group - i0 + 1] = i + 1 last_index = zeros(amplitude, dtype=numba_types.int_) last_index[:-1] = first_index[1:] - # + 2 because we iterate only until -2 and we want upper bound ( 1 + 1) - last_index[-1] = i + 2 + last_index[-1] = len(groups) return first_index, last_index, i0 @njit(cache=True) def hist_numba(x, bins): - """Call numba histogram to speed up.""" + """Call numba histogram to speed up.""" return histogram(x, bins) @njit(cache=True, fastmath=True, parallel=False) def distance_grid(lon0, lat0, lon1, lat1): """ - Get distance for every couple of point. + Get distance for every couple of points. :param array lon0: :param array lat0: :param array lon1: :param array lat1: - :return: nan value for far away point, and km for other + :return: nan value for far away points, and km for other :rtype: array """ nb_0 = lon0.shape[0] @@ -132,8 +131,8 @@ def distance_grid(lon0, lat0, lon1, lat1): sin_dlon = sin((dlon) * 0.5 * D2R) cos_lat1 = cos(lat0[i] * D2R) cos_lat2 = cos(lat1[j] * D2R) - a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2 - dist[i, j] = 6370.997 * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat1 * cos_lat2 + sin_dlat**2 + dist[i, j] = 6370.997 * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) return dist @@ -154,8 +153,8 @@ def distance(lon0, lat0, lon1, lat1): sin_dlon = sin((lon1 - lon0) * 0.5 * D2R) cos_lat1 = cos(lat0 * D2R) cos_lat2 = cos(lat1 * D2R) - a_val = sin_dlon ** 2 * cos_lat1 * cos_lat2 + sin_dlat ** 2 - return 6370997.0 * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat1 * cos_lat2 + sin_dlat**2 + return 6370997.0 * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) @njit(cache=True) @@ -164,7 +163,7 @@ def cumsum_by_track(field, track): Cumsum by track. 
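# Quick numerical sanity check for the haversine helper above (module path as
# in this repository, values purely illustrative):
from py_eddy_tracker.generic import distance

d = distance(0.0, 0.0, 1.0, 0.0)   # one degree of longitude along the equator
print(d)                            # ~111195 m, i.e. ~111.2 km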
:param array field: data to sum - :pram array(int) track: id of track to separate data + :pram array(int) track: id of trajectories to separate data :return: cumsum with a reset at each start of track :rtype: array """ @@ -192,7 +191,7 @@ def interp2d_geo(x_g, y_g, z_g, m_g, x, y, nearest=False): :param array m_g: Boolean grid, True if value is masked :param array x: coordinate where interpolate z :param array y: coordinate where interpolate z - :param bool nearest: if true we will take nearest pixel + :param bool nearest: if True we will take nearest pixel :return: z interpolated :rtype: array """ @@ -256,17 +255,17 @@ def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y): nb_x = x_g.shape[0] nb_y = y_g.shape[0] is_circular = abs(x_g[-1] % 360 - (x_g[0] - x_step) % 360) < 1e-5 - # Indices which should be never exist + # Indexes that should never exist i0_old, j0_old, masked = -100000000, -10000000, False z = empty(x.shape, dtype=z_g.dtype) for i in prange(x.size): x_ = (x[i] - x_ref) / x_step y_ = (y[i] - y_ref) / y_step i0 = int(floor(x_)) - # To keep original value if wrapping apply to compute xd + # To keep original values if wrapping applied to compute xd i0_ = i0 j0 = int(floor(y_)) - # corner are the same need only a new xd and yd + # corners are the same need only a new xd and yd if i0 != i0_old or j0 != j0_old: i1 = i0 + 1 j1 = j0 + 1 @@ -288,7 +287,7 @@ def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y): z_g[i1, j1], ) masked = False - # Need to be store only on change + # Need to be stored only on change i0_old, j0_old = i0, j0 if masked: z[i] = nan @@ -302,14 +301,14 @@ def interp2d_bilinear(x_g, y_g, z_g, m_g, x, y): @njit(cache=True, fastmath=True) -def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None): +def uniform_resample(x_val, y_val, num_fac=2, fixed_size=-1): """ Resample contours to have (nearly) equal spacing. :param array_like x_val: input x contour coordinates :param array_like y_val: input y contour coordinates :param int num_fac: factor to increase lengths of output coordinates - :param int,None fixed_size: if define, it will used to set sampling + :param int fixed_size: if > -1, will be used to set sampling """ nb = x_val.shape[0] # Get distances @@ -320,7 +319,7 @@ def uniform_resample(x_val, y_val, num_fac=2, fixed_size=None): dist[1:][dist[1:] < 1e-3] = 1e-3 dist = dist.cumsum() # Get uniform distances - if fixed_size is None: + if fixed_size == -1: fixed_size = dist.size * num_fac d_uniform = linspace(0, dist[-1], fixed_size) x_new = interp(d_uniform, dist, x_val) @@ -359,17 +358,17 @@ def flatten_line_matrix(l_matrix): @njit(cache=True) def simplify(x, y, precision=0.1): """ - Will remove all middle/end point which are closer than precision. + Will remove all middle/end points closer than precision. 
:param array x: :param array y: - :param float precision: if two points have distance inferior to precision with remove next point + :param float precision: if two points have distance inferior to precision we remove next point :return: (x,y) :rtype: (array,array) """ - precision2 = precision ** 2 + precision2 = precision**2 nb = x.shape[0] - # will be True for value keep + # will be True for kept values mask = ones(nb, dtype=bool_) for j in range(0, nb): x_previous, y_previous = x[j], y[j] @@ -399,7 +398,7 @@ def simplify(x, y, precision=0.1): if d_y > precision: x_previous, y_previous = x_, y_ continue - d2 = d_x ** 2 + d_y ** 2 + d2 = d_x**2 + d_y**2 if d2 > precision2: x_previous, y_previous = x_, y_ continue @@ -423,11 +422,11 @@ def split_line(x, y, i): :param y: array :param i: array of int at each i change, we cut x, y - :return: x and y separate by nan at each i jump + :return: x and y separated by nan at each i jump """ nb_jump = len(where(i[1:] - i[:-1] != 0)[0]) nb_value = x.shape[0] - final_size = (nb_jump - 1) + nb_value + final_size = nb_jump + nb_value new_x = empty(final_size, dtype=x.dtype) new_y = empty(final_size, dtype=y.dtype) new_j = 0 @@ -445,11 +444,11 @@ def split_line(x, y, i): @njit(cache=True) def wrap_longitude(x, y, ref, cut=False): """ - Will wrap contiguous longitude with reference as west bound. + Will wrap contiguous longitude with reference as western boundary. :param array x: :param array y: - :param float ref: longitude of reference, all the new value will be between ref and ref + 360 + :param float ref: longitude of reference, all the new values will be between ref and ref + 360 :param bool cut: if True line will be cut at the bounds :return: lon,lat :rtype: (array,array) @@ -457,17 +456,18 @@ def wrap_longitude(x, y, ref, cut=False): if cut: indexs = list() nb = x.shape[0] - new_previous = (x[0] - ref) % 360 + + new_x_previous = (x[0] - ref) % 360 + ref x_previous = x[0] for i in range(1, nb): x_ = x[i] - new_x = (x_ - ref) % 360 + new_x = (x_ - ref) % 360 + ref if not isnan(x_) and not isnan(x_previous): - d_new = new_x - new_previous + d_new = new_x - new_x_previous d = x_ - x_previous if abs(d - d_new) > 1e-5: indexs.append(i) - x_previous, new_previous = x_, new_x + x_previous, new_x_previous = x_, new_x nb_indexs = len(indexs) new_size = nb + nb_indexs * 3 @@ -478,6 +478,7 @@ def wrap_longitude(x, y, ref, cut=False): for i in range(nb): if j < nb_indexs and i == indexs[j]: j += 1 + # FIXME need check cor = 360 if x[i - 1] > x[i] else -360 out_x[i + i_] = (x[i] - ref) % 360 + ref - cor out_y[i + i_] = y[i] @@ -517,8 +518,8 @@ def coordinates_to_local(lon, lat, lon0, lat0): sin_dlon = sin(dlon * 0.5) cos_lat0 = cos(lat0 * D2R) cos_lat = cos(lat * D2R) - a_val = sin_dlon ** 2 * cos_lat0 * cos_lat + sin_dlat ** 2 - module = R * 2 * arctan2(a_val ** 0.5, (1 - a_val) ** 0.5) + a_val = sin_dlon**2 * cos_lat0 * cos_lat + sin_dlat**2 + module = R * 2 * arctan2(a_val**0.5, (1 - a_val) ** 0.5) azimuth = pi / 2 - arctan2( cos_lat * sin(dlon), @@ -541,7 +542,7 @@ def local_to_coordinates(x, y, lon0, lat0): """ D2R = pi / 180.0 R = 6370997 - d = (x ** 2 + y ** 2) ** 0.5 / R + d = (x**2 + y**2) ** 0.5 / R a = -(arctan2(y, x) - pi / 2) lat = arcsin(sin(lat0 * D2R) * cos(d) + cos(lat0 * D2R) * sin(d) * cos(a)) lon = ( @@ -557,7 +558,7 @@ def local_to_coordinates(x, y, lon0, lat0): @njit(cache=True, fastmath=True) def nearest_grd_indice(x, y, x0, y0, xstep, ystep): """ - Get nearest grid indice from a position. + Get nearest grid index from a position. 
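# Round-trip sketch for the local tangent-plane helpers above, assuming they
# exchange metre offsets (x, y) around the given origin; values are illustrative.
from py_eddy_tracker.generic import coordinates_to_local, local_to_coordinates

lon0, lat0 = 5.0, 40.0                              # projection origin
x, y = coordinates_to_local(5.5, 40.2, lon0, lat0)  # offsets in metres
lon, lat = local_to_coordinates(x, y, lon0, lat0)
print(lon, lat)                                     # ~5.5, ~40.2 recovered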
:param x: longitude :param y: latitude @@ -575,7 +576,7 @@ def nearest_grd_indice(x, y, x0, y0, xstep, ystep): @njit(cache=True) def bbox_indice_regular(vertices, x0, y0, xstep, ystep, N, circular, x_size): """ - Get bbox indice of a contour in a regular grid. + Get bbox index of a contour in a regular grid. :param vertices: vertice of contour :param float x0: first grid longitude @@ -612,3 +613,48 @@ def build_circle(x0, y0, r): angle = radians(linspace(0, 360, 50)) x_norm, y_norm = cos(angle), sin(angle) return x_norm * r + x0, y_norm * r + y0 + + +def window_index(x, x0, half_window=1): + """ + Give for a fixed half_window each start and end index for each x0, in + an unsorted array. + + :param array x: array of value + :param array x0: array of window center + :param float half_window: half window + """ + # Sort array, bounds will be sort also + i_ordered = x.argsort(kind="mergesort") + return window_index_(x, i_ordered, x0, half_window) + + +@njit(cache=True) +def window_index_(x, i_ordered, x0, half_window=1): + nb_x, nb_pt = x.size, x0.size + first_index = empty(nb_pt, dtype=i_ordered.dtype) + last_index = empty(nb_pt, dtype=i_ordered.dtype) + # First bound to find + j_min, j_max = 0, 0 + x_min = x0[j_min] - half_window + x_max = x0[j_max] + half_window + # We iterate on ordered x + for i, i_x in enumerate(i_ordered): + x_ = x[i_x] + # if x bigger than x_min , we found bound and search next one + while x_ > x_min and j_min < nb_pt: + first_index[j_min] = i + j_min += 1 + x_min = x0[j_min] - half_window + # if x bigger than x_max , we found bound and search next one + while x_ > x_max and j_max < nb_pt: + last_index[j_max] = i + j_max += 1 + x_max = x0[j_max] + half_window + if j_max == nb_pt: + break + for i in range(j_min, nb_pt): + first_index[i] = nb_x + for i in range(j_max, nb_pt): + last_index[i] = nb_x + return i_ordered, first_index, last_index diff --git a/src/py_eddy_tracker/gui.py b/src/py_eddy_tracker/gui.py index 423ff306..a85e9c18 100644 --- a/src/py_eddy_tracker/gui.py +++ b/src/py_eddy_tracker/gui.py @@ -4,13 +4,16 @@ """ from datetime import datetime, timedelta +import logging +from matplotlib.projections import register_projection import matplotlib.pyplot as plt import numpy as np -from matplotlib.projections import register_projection from .generic import flatten_line_matrix, split_line +logger = logging.getLogger("pet") + try: from pylook.axes import PlatCarreAxes except ImportError: @@ -26,12 +29,15 @@ def __init__(self, *args, **kwargs): self.set_aspect("equal") +GUI_AXES = "full_axes" + + class GUIAxes(PlatCarreAxes): """ - Axes which will use full space available + Axes that uses full space available """ - name = "full_axes" + name = GUI_AXES def end_pan(self, *args, **kwargs): (x0, x1), (y0, y1) = self.get_xlim(), self.get_ylim() @@ -88,7 +94,7 @@ def set_initial_values(self): for dataset in self.datasets.values(): t0_, t1_ = dataset.period t0, t1 = min(t0, t0_), max(t1, t1_) - + logger.debug("period detected %f -> %f", t0, t1) self.settings = dict(period=(t0, t1), now=t1) @property @@ -125,7 +131,7 @@ def med(self): def setup(self): self.figure = plt.figure() # map - self.map = self.figure.add_axes((0, 0.25, 1, 0.75), projection="full_axes") + self.map = self.figure.add_axes((0, 0.25, 1, 0.75), projection=GUI_AXES) self.map.grid() self.map.tick_params("both", pad=-22) # self.map.tick_params("y", pad=-22) @@ -288,8 +294,8 @@ def get_infos(self, name, index): i_first = d.index_from_track[tr] track = d.obs[i_first : i_first + nb] nb -= 1 - t0 = 
timedelta(days=int(track[0]["time"])) + datetime(1950, 1, 1) - t1 = timedelta(days=int(track[-1]["time"])) + datetime(1950, 1, 1) + t0 = timedelta(days=track[0]["time"]) + datetime(1950, 1, 1) + t1 = timedelta(days=track[-1]["time"]) + datetime(1950, 1, 1) txt = f"--{name}--\n" txt += f" {t0} -> {t1}\n" txt += f" Tracks : {tr} {now['n']}/{nb} ({now['n'] / nb * 100:.2f} %)\n" diff --git a/src/py_eddy_tracker/misc.py b/src/py_eddy_tracker/misc.py new file mode 100644 index 00000000..647bfba3 --- /dev/null +++ b/src/py_eddy_tracker/misc.py @@ -0,0 +1,21 @@ +import re + +from matplotlib.animation import FuncAnimation + + +class VideoAnimation(FuncAnimation): + def _repr_html_(self, *args, **kwargs): + """To get video in html and have a player""" + content = self.to_html5_video() + return re.sub( + r'width="[0-9]*"\sheight="[0-9]*"', 'width="100%" height="100%"', content + ) + + def save(self, *args, **kwargs): + if args[0].endswith("gif"): + # In this case gif is used to create thumbnail which is not used but consume same time than video + # So we create an empty file, to save time + with open(args[0], "w") as _: + pass + return + return super().save(*args, **kwargs) diff --git a/src/py_eddy_tracker/observations/groups.py b/src/py_eddy_tracker/observations/groups.py index 5a01d452..81929e1e 100644 --- a/src/py_eddy_tracker/observations/groups.py +++ b/src/py_eddy_tracker/observations/groups.py @@ -1,9 +1,11 @@ -import logging from abc import ABC, abstractmethod +import logging -from numba import njit -from numpy import arange, int32, interp, median, zeros +from numba import njit, types as nb_types +from numpy import arange, full, int32, interp, isnan, median, where, zeros +from ..generic import window_index +from ..poly import create_meshed_particles, poly_indexs from .observation import EddiesObservations logger = logging.getLogger("pet") @@ -13,17 +15,17 @@ def get_missing_indices( array_time, array_track, dt=1, flag_untrack=True, indice_untrack=0 ): - """return indices where it misses values + """Return indexes where values are missing :param np.array(int) array_time : array of strictly increasing int representing time - :param np.array(int) array_track: N° track where observation belong - :param int,float dt: theorical timedelta between 2 observation + :param np.array(int) array_track: N° track where observations belong + :param int,float dt: theorical timedelta between 2 observations :param bool flag_untrack: if True, ignore observations where n°track equal `indice_untrack` - :param int indice_untrack: n° representing where observations are untrack + :param int indice_untrack: n° representing where observations are untracked ex : array_time = np.array([67, 68, 70, 71, 74, 75]) - array_track= np.array([ 1, 1, 1, 1, 1, 1]) + array_track= np.array([ 1, 1, 1, 1, 1, 1]) return : np.array([2, 4, 4]) """ @@ -65,6 +67,206 @@ def get_missing_indices( return indices +def advect(x, y, c, t0, n_days, u_name="u", v_name="v"): + """ + Advect particles from t0 to t0 + n_days, with data cube. 
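# Hypothetical end-to-end sketch for the advection helper sketched here:
# build a GridCollection from daily current files and advect a few particles.
# Paths, variable names and the 1950-based julian dates are placeholders, and
# the files are assumed to already contain "u"/"v" grids.
from numpy import arange, linspace, meshgrid
from py_eddy_tracker.dataset.grid import GridCollection
from py_eddy_tracker.observations.groups import advect

dates = arange(24000, 24011)                  # days since 1950-01-01 (assumed)
files = [f"/data/uv_{t}.nc" for t in dates]   # placeholder file names
c = GridCollection.from_netcdf_list(files, dates, "longitude", "latitude")

x, y = meshgrid(linspace(5.0, 6.0, 10), linspace(35.0, 36.0, 10))
x, y = x.ravel(), y.ravel()                   # seed particles on a small patch
t, x, y = advect(x, y, c, t0=int(dates[0]), n_days=5)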
+ + :param np.array(float) x: longitude of particles + :param np.array(float) y: latitude of particles + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param int t0: julian day of advection start + :param int n_days: number of days to advect + :param str u_name: variable name for u component + :param str v_name: variable name for v component + """ + + kw = dict(nb_step=6, time_step=86400 / 6) + if n_days < 0: + kw["backward"] = True + n_days = -n_days + p = c.advect(x, y, u_name=u_name, v_name=v_name, t_init=t0, **kw) + for _ in range(n_days): + t, x, y = p.__next__() + return t, x, y + + +def particle_candidate_step( + t_start, contours_start, contours_end, space_step, dt, c, day_fraction=6, **kwargs +): + """Select particles within eddies, advect them, return target observation and associated percentages. + For one time step. + + :param int t_start: julian day of the advection + :param (np.array(float),np.array(float)) contours_start: origin contour + :param (np.array(float),np.array(float)) contours_end: destination contour + :param float space_step: step between 2 particles + :param int dt: duration of advection + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param int day_fraction: fraction of day + :params dict kwargs: dict of params given to advection + :return (np.array,np.array): return target index and percent associate + """ + # In case of zarr array + contours_start = [i[:] for i in contours_start] + contours_end = [i[:] for i in contours_end] + # Create particles in start contour + x, y, i_start = create_meshed_particles(*contours_start, space_step) + # Advect particles + kw = dict(nb_step=day_fraction, time_step=86400 / day_fraction) + p = c.advect(x, y, t_init=t_start, **kwargs, **kw) + for _ in range(abs(dt)): + _, x, y = p.__next__() + m = ~(isnan(x) + isnan(y)) + i_end = full(x.shape, -1, dtype="i4") + if m.any(): + # Id eddies for each alive particle in start contour + i_end[m] = poly_indexs(x[m], y[m], *contours_end) + shape = (contours_start[0].shape[0], 2) + # Get target for each contour + i_target, pct_target = full(shape, -1, dtype="i4"), zeros(shape, dtype="f8") + nb_end = contours_end[0].shape[0] + get_targets(i_start, i_end, i_target, pct_target, nb_end) + return i_target, pct_target.astype("i1") + + +def particle_candidate( + c, + eddies, + step_mesh, + t_start, + i_target, + pct, + contour_start="speed", + contour_end="effective", + **kwargs +): + """Select particles within eddies, advect them, return target observation and associated percentages + + :param `~py_eddy_tracker.dataset.grid.GridCollection` c: GridCollection with speed for particles + :param GroupEddiesObservations eddies: GroupEddiesObservations considered + :param int t_start: julian day of the advection + :param np.array(int) i_target: corresponding obs where particles are advected + :param np.array(int) pct: corresponding percentage of avected particles + :param str contour_start: contour where particles are injected + :param str contour_end: contour where particles are counted after advection + :params dict kwargs: dict of params given to `advect` + + """ + # Obs from initial time + m_start = eddies.time == t_start + e = eddies.extract_with_mask(m_start) + + # to be able to get global index + translate_start = where(m_start)[0] + + # Create particles in specified contour + intern = False if contour_start == "effective" else True + x, y, i_start = e.create_particles(step_mesh, 
intern=intern) + # Advection + t_end, x, y = advect(x, y, c, t_start, **kwargs) + + # eddies at last date + m_end = eddies.time == t_end / 86400 + e_end = eddies.extract_with_mask(m_end) + + # to be able to get global index + translate_end = where(m_end)[0] + + # Id eddies for each alive particle in specified contour + intern = False if contour_end == "effective" else True + i_end = e_end.contains(x, y, intern=intern) + + # compute matrix and fill target array + get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct) + + +@njit(cache=True) +def get_targets(i_start, i_end, i_target, pct, nb_end): + """Compute target observation and associated percentages + + :param array(int) i_start: indices in time 0 + :param array(int) i_end: indices in time N + :param array(int) i_target: corresponding obs where particles are advected + :param array(int) pct: corresponding percentage of avected particles + :param int nb_end: number of contour at time N + """ + nb_start = i_target.shape[0] + # Matrix which will store count for every couple + counts = zeros((nb_start, nb_end), dtype=nb_types.int32) + # Number of particles in each origin observation + ref = zeros(nb_start, dtype=nb_types.int32) + # For each particle + for i in range(i_start.size): + i_end_ = i_end[i] + i_start_ = i_start[i] + ref[i_start_] += 1 + if i_end_ != -1: + counts[i_start_, i_end_] += 1 + # From i to j + for i in range(nb_start): + for j in range(nb_end): + count = counts[i, j] + if count == 0: + continue + pct_ = count / ref[i] * 100 + pct_0 = pct[i, 0] + # If percent is higher than previous stored in rank 0 + if pct_ > pct_0: + pct[i, 1] = pct_0 + pct[i, 0] = pct_ + i_target[i, 1] = i_target[i, 0] + i_target[i, 0] = j + # If percent is higher than previous stored in rank 1 + elif pct_ > pct[i, 1]: + pct[i, 1] = pct_ + i_target[i, 1] = j + + +@njit(cache=True) +def get_matrix(i_start, i_end, translate_start, translate_end, i_target, pct): + """Compute target observation and associated percentages + + :param np.array(int) i_start: indices of associated contours at starting advection day + :param np.array(int) i_end: indices of associated contours after advection + :param np.array(int) translate_start: corresponding global indices at starting advection day + :param np.array(int) translate_end: corresponding global indices after advection + :param np.array(int) i_target: corresponding obs where particles are advected + :param np.array(int) pct: corresponding percentage of avected particles + """ + + nb_start, nb_end = translate_start.size, translate_end.size + # Matrix which will store count for every couple + count = zeros((nb_start, nb_end), dtype=nb_types.int32) + # Number of particles in each origin observation + ref = zeros(nb_start, dtype=nb_types.int32) + # For each particle + for i in range(i_start.size): + i_end_ = i_end[i] + i_start_ = i_start[i] + if i_end_ != -1: + count[i_start_, i_end_] += 1 + ref[i_start_] += 1 + for i in range(nb_start): + for j in range(nb_end): + pct_ = count[i, j] + # If there are particles from i to j + if pct_ != 0: + # Get percent + pct_ = pct_ / ref[i] * 100.0 + # Get indices in full dataset + i_, j_ = translate_start[i], translate_end[j] + pct_0 = pct[i_, 0] + if pct_ > pct_0: + pct[i_, 1] = pct_0 + pct[i_, 0] = pct_ + i_target[i_, 1] = i_target[i_, 0] + i_target[i_, 0] = j_ + elif pct_ > pct[i_, 1]: + pct[i_, 1] = pct_ + i_target[i_, 1] = j_ + return i_target, pct + + class GroupEddiesObservations(EddiesObservations, ABC): @abstractmethod def fix_next_previous_obs(self): @@ 
-72,50 +274,49 @@ def fix_next_previous_obs(self): @abstractmethod def get_missing_indices(self, dt): - "find indices where observations is missing" + "Find indexes where observations are missing" pass def filled_by_interpolation(self, mask): - """Filled selected values by interpolation + """Fill selected values by interpolation :param array(bool) mask: True if must be filled by interpolation .. minigallery:: py_eddy_tracker.TrackEddiesObservations.filled_by_interpolation """ - + if self.track.size == 0: + return nb_filled = mask.sum() logger.info("%d obs will be filled (unobserved)", nb_filled) nb_obs = len(self) index = arange(nb_obs) - for field in self.obs.dtype.descr: - # print(f"field : {field}") - var = field[0] + for field in self.fields: if ( - var in ["n", "virtual", "track", "cost_association"] - or var in self.array_variables + field in ["n", "virtual", "track", "cost_association"] + or field in self.array_variables ): continue - self.obs[var][mask] = interp( - index[mask], index[~mask], self.obs[var][~mask] + self.obs[field][mask] = interp( + index[mask], index[~mask], self.obs[field][~mask] ) def insert_virtual(self): - """insert virtual observation on segments where observations were not found""" + """Insert virtual observations on segments where observations are missing""" dt_theorical = median(self.time[1:] - self.time[:-1]) indices = self.get_missing_indices(dt_theorical) logger.info("%d virtual observation will be added", indices.size) - # new observation size + # new observations size size_obs_corrected = self.time.size + indices.size - # correction of indices for new size + # correction of indexes for new size indices_corrected = indices + arange(indices.size) - # creating mask with indices + # creating mask with indexes mask = zeros(size_obs_corrected, dtype=bool) mask[indices_corrected] = 1 @@ -128,12 +329,12 @@ def insert_virtual(self): def keep_tracks_by_date(self, date, nb_days): """ - Find tracks which exist at date `date` and lasted at least `nb_days` after. + Find tracks that exist at date `date` and lasted at least `nb_days` after. :param int,float date: date where the tracks must exist - :param int,float nb_days: number of time where the tracks must exist. Can be negative + :param int,float nb_days: number of times the tracks must exist. 
Can be negative - If nb_days is negative, it search a tracks which exist at the date, + If nb_days is negative, it searches a track that exists at the date, but existed at least `nb_days` before the date """ @@ -148,3 +349,128 @@ def keep_tracks_by_date(self, date, nb_days): mask[i] = True return self.extract_with_mask(mask) + + def particle_candidate_atlas( + self, + cube, + space_step, + dt, + start_intern=False, + end_intern=False, + callback_coherence=None, + finalize_coherence=None, + **kwargs + ): + """Select particles within eddies, advect them, return target observation and associated percentages + + :param `~py_eddy_tracker.dataset.grid.GridCollection` cube: GridCollection with speed for particles + :param float space_step: step between 2 particles + :param int dt: duration of advection + :param bool start_intern: Use intern or extern contour at injection, defaults to False + :param bool end_intern: Use intern or extern contour at end of advection, defaults to False + :param dict kwargs: dict of params given to advection + :param func callback_coherence: if None we will use cls.fill_coherence + :param func finalize_coherence: to apply on results of callback_coherence + :return (np.array,np.array): return target index and percent associate + """ + t_start, t_end = int(self.period[0]), int(self.period[1]) + # Pre-compute to get time index + i_sort, i_start, i_end = window_index( + self.time, arange(t_start, t_end + 1), 0.5 + ) + # Out shape + shape = (len(self), 2) + i_target, pct = full(shape, -1, dtype="i4"), zeros(shape, dtype="i1") + # Backward or forward + times = arange(t_start, t_end - dt) if dt > 0 else arange(t_start - dt, t_end) + + if callback_coherence is None: + callback_coherence = self.fill_coherence + indexs = dict() + results = list() + kw_coherence = dict(space_step=space_step, dt=dt, c=cube) + kw_coherence.update(kwargs) + for t in times: + logger.info( + "Coherence for time step : %s in [%s:%s]", t, times[0], times[-1] + ) + # Get index for origin + i = t - t_start + indexs0 = i_sort[i_start[i] : i_end[i]] + # Get index for end + i = t + dt - t_start + indexs1 = i_sort[i_start[i] : i_end[i]] + if indexs0.size == 0 or indexs1.size == 0: + continue + + results.append( + callback_coherence( + self, + i_target, + pct, + indexs0, + indexs1, + start_intern, + end_intern, + t_start=t, + **kw_coherence + ) + ) + indexs[results[-1]] = indexs0, indexs1 + + if finalize_coherence is not None: + finalize_coherence(results, indexs, i_target, pct) + return i_target, pct + + @classmethod + def fill_coherence( + cls, + network, + i_targets, + percents, + i_origin, + i_end, + start_intern, + end_intern, + **kwargs + ): + """_summary_ + + :param array i_targets: global target + :param array percents: + :param array i_origin: indices of origins + :param array i_end: indices of ends + :param bool start_intern: Use intern or extern contour at injection + :param bool end_intern: Use intern or extern contour at end of advection + """ + # Get contour data + contours_start = [ + network[label][i_origin] for label in cls.intern(start_intern) + ] + contours_end = [network[label][i_end] for label in cls.intern(end_intern)] + # Compute local coherence + i_local_targets, local_percents = particle_candidate_step( + contours_start=contours_start, contours_end=contours_end, **kwargs + ) + # Store + cls.merge_particle_result( + i_targets, percents, i_local_targets, local_percents, i_origin, i_end + ) + + @staticmethod + def merge_particle_result( + i_targets, percents, i_local_targets, 
local_percents, i_origin, i_end + ): + """Copy local result in merged result with global indexation + + :param array i_targets: global target + :param array percents: + :param array i_local_targets: local index target + :param array local_percents: + :param array i_origin: indices of origins + :param array i_end: indices of ends + """ + m = i_local_targets != -1 + i_local_targets[m] = i_end[i_local_targets[m]] + i_targets[i_origin] = i_local_targets + percents[i_origin] = local_percents diff --git a/src/py_eddy_tracker/observations/network.py b/src/py_eddy_tracker/observations/network.py index 53af9f30..f0b9d7cc 100644 --- a/src/py_eddy_tracker/observations/network.py +++ b/src/py_eddy_tracker/observations/network.py @@ -2,28 +2,38 @@ """ Class to create network of observations """ -import logging from glob import glob - -from numba import njit +import logging +import time +from datetime import timedelta, datetime +import os +import netCDF4 +from numba import njit, types as nb_types +from numba.typed import List +import numpy as np from numpy import ( arange, array, bincount, bool_, concatenate, + empty, - in1d, + nan, ones, + percentile, + uint16, uint32, unique, where, zeros, ) +import zarr +from ..dataset.grid import GridCollection from ..generic import build_index, wrap_longitude from ..poly import bbox_intersection, vertice_overlap -from .groups import GroupEddiesObservations, get_missing_indices +from .groups import GroupEddiesObservations, get_missing_indices, particle_candidate from .observation import EddiesObservations from .tracking import TrackEddiesObservations, track_loess_filter, track_median_filter @@ -77,10 +87,10 @@ def load_contour(self, filename): @njit(cache=True) def fix_next_previous_obs(next_obs, previous_obs, flag_virtual): - """when an observation is virtual, we have to fix the previous and next obs + """When an observation is virtual, we have to fix the previous and next obs - :param np.array(int) next_obs : indice of next observation from network - :param np.array(int previous_obs: indice of previous observation from network + :param np.array(int) next_obs : index of next observation from network + :param np.array(int previous_obs: index of previous observation from network :param np.array(bool) flag_virtual: if observation is virtual or not """ @@ -88,8 +98,8 @@ def fix_next_previous_obs(next_obs, previous_obs, flag_virtual): if not flag_virtual[i_o]: continue - # if there is many virtual side by side, there is some values writted multiple times. - # but it should not be slow + # if there are several consecutive virtuals, some values are written multiple times. 
+ # but it should not be slow next_obs[i_o - 1] = i_o next_obs[i_o] = i_o + 1 previous_obs[i_o] = i_o - 1 @@ -97,24 +107,46 @@ def fix_next_previous_obs(next_obs, previous_obs, flag_virtual): class NetworkObservations(GroupEddiesObservations): - - __slots__ = ("_index_network",) - + __slots__ = ("_index_network", "_index_segment_track", "_segment_track_array") NOGROUP = 0 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.reset_index() + + def __repr__(self): + m_event, s_event = ( + self.merging_event(only_index=True, triplet=True)[0], + self.splitting_event(only_index=True, triplet=True)[0], + ) + period = (self.period[1] - self.period[0]) / 365.25 + nb_by_network = self.network_size() + nb_trash = 0 if self.ref_index != 0 else nb_by_network[0] + lifetime=self.lifetime + big = 50_000 + infos = [ + f"Atlas with {self.nb_network} networks ({self.nb_network / period:0.0f} networks/year)," + f" {self.nb_segment} segments ({self.nb_segment / period:0.0f} segments/year), {len(self)} observations ({len(self) / period:0.0f} observations/year)", + f" {m_event.size} merging ({m_event.size / period:0.0f} merging/year), {s_event.size} splitting ({s_event.size / period:0.0f} splitting/year)", + f" with {(nb_by_network > big).sum()} network with more than {big} obs and the biggest have {nb_by_network.max()} observations ({nb_by_network[nb_by_network > big].sum()} observations cumulate)", + f" {nb_trash} observations in trash", + f" {lifetime.max()} days max of lifetime", + ] + return "\n".join(infos) + + def reset_index(self): self._index_network = None + self._index_segment_track = None + self._segment_track_array = None def find_segments_relative(self, obs, stopped=None, order=1): """ - find all relative segments within an event from an order. + Find all relative segments from obs linked with merging/splitting events at a specific order. - :param int obs: indice of event after the event - :param int stopped: indice of event before the event + :param int obs: index of observation after the event + :param int stopped: index of observation before the event :param int order: order of relatives accepted - - :return: all segments relatives + :return: all relative segments :rtype: EddiesObservations """ @@ -133,9 +165,9 @@ def find_segments_relative(self, obs, stopped=None, order=1): return nw.relatives([i_obs, i_stopped], order=order) def get_missing_indices(self, dt): - """find indices where observations is missing. + """Find indices where observations are missing. - As network have all untrack observation in tracknumber `self.NOGROUP`, + As network have all untracked observation in tracknumber `self.NOGROUP`, we don't compute them :param int,float dt: theorical delta time between 2 observations @@ -145,7 +177,7 @@ def get_missing_indices(self, dt): ) def fix_next_previous_obs(self): - """function used after 'insert_virtual', to correct next_obs and + """Function used after 'insert_virtual', to correct next_obs and previous obs. 
""" @@ -157,6 +189,80 @@ def index_network(self): self._index_network = build_index(self.track) return self._index_network + @property + def index_segment_track(self): + if self._index_segment_track is None: + self._index_segment_track = build_index(self.segment_track_array) + return self._index_segment_track + + def segment_size(self): + return self.index_segment_track[1] - self.index_segment_track[0] + + @property + def ref_segment_track_index(self): + return self.index_segment_track[2] + + @property + def ref_index(self): + return self.index_network[2] + + @property + def lifetime(self): + """Return lifetime for each observation""" + lt=self.networks_period.astype("int") + nb_by_network=self.network_size() + return lt.repeat(nb_by_network) + + def network_segment_size(self, id_networks=None): + """Get number of segment by network + + :return array: + """ + i0, i1, ref = build_index(self.track[self.index_segment_track[0]]) + if id_networks is None: + return i1 - i0 + else: + i = id_networks - ref + return i1[i] - i0[i] + + def network_size(self, id_networks=None): + """ + Return size for specified network + + :param list,array, None id_networks: ids to identify network + """ + if id_networks is None: + return self.index_network[1] - self.index_network[0] + else: + i = id_networks - self.index_network[2] + return self.index_network[1][i] - self.index_network[0][i] + + @property + def networks_period(self): + """ + Return period for each network + """ + return get_period_with_index(self.time, *self.index_network[:2]) + + + + def unique_segment_to_id(self, id_unique): + """Return id network and id segment for a unique id + + :param array id_unique: + """ + i = self.index_segment_track[0][id_unique] - self.ref_segment_track_index + return self.track[i], self.segment[i] + + def segment_slice(self, id_network, id_segment): + """ + Return slice for one segment + + :param int id_network: id to identify network + :param int id_segment: id to identify segment + """ + raise Exception("need to be implemented") + def network_slice(self, id_network): """ Return slice for one network @@ -184,40 +290,48 @@ def elements(self): def astype(self, cls): new = cls.new_like(self, self.shape) - print() - for k in new.obs.dtype.names: - if k in self.obs.dtype.names: + for k in new.fields: + if k in self.fields: new[k][:] = self[k][:] new.sign_type = self.sign_type return new - + def longer_than(self, nb_day_min=-1, nb_day_max=-1): """ Select network on time duration - :param int nb_day_min: Minimal number of day covered by one network, if negative -> not used - :param int nb_day_max: Maximal number of day covered by one network, if negative -> not used + :param int nb_day_min: Minimal number of days covered by one network, if negative -> not used + :param int nb_day_max: Maximal number of days covered by one network, if negative -> not used + """ + return self.extract_with_mask(self.mask_longer_than(nb_day_min, nb_day_max)) + + def mask_longer_than(self, nb_day_min=-1, nb_day_max=-1): + """ + Select network on time duration + + :param int nb_day_min: Minimal number of days covered by one network, if negative -> not used + :param int nb_day_max: Maximal number of days covered by one network, if negative -> not used """ if nb_day_max < 0: nb_day_max = 1000000000000 mask = zeros(self.shape, dtype="bool") t = self.time - for i, b0, b1 in self.iter_on(self.track): + for i, _, _ in self.iter_on(self.track): nb = i.stop - i.start if nb == 0: continue - if nb_day_min <= ptp(t[i]) <= nb_day_max: + if nb_day_min <= 
(ptp(t[i]) + 1) <= nb_day_max: mask[i] = True - return self.extract_with_mask(mask) + return mask @classmethod def from_split_network(cls, group_dataset, indexs, **kwargs): """ - Build a NetworkObservations object with Group dataset and indexes + Build a NetworkObservations object with Group dataset and indices :param TrackEddiesObservations group_dataset: Group dataset :param indexs: result from split_network - return NetworkObservations + :return: NetworkObservations """ index_order = indexs.argsort(order=("group", "track", "time")) network = cls.new_like(group_dataset, len(group_dataset), **kwargs) @@ -227,7 +341,7 @@ def from_split_network(cls, group_dataset, indexs, **kwargs): continue network[field][:] = group_dataset[field][index_order] network.segment[:] = indexs["track"][index_order] - # n & p must be re-index + # n & p must be re-indexed n, p = indexs["next_obs"][index_order], indexs["previous_obs"][index_order] # we add 2 for -1 index return index -1 translate = -ones(index_order.max() + 2, dtype="i4") @@ -243,14 +357,20 @@ def infos(self, label=""): def correct_close_events(self, nb_days_max=20): """ - transform event where - segment A split to B, then A merge into B - + Transform event where + segment A splits from segment B, then x days after segment B merges with A to + segment A splits from segment B then x days after segment A merges with B (B will be longer) + These events have to last less than `nb_days_max` to be changed. - segment A split to B, then B merge to A - these events are filtered with `nb_days_max`, which the event have to take place in less than `nb_days_max` + ------------------- A + / / + B -------------------- + to + --A-- + / \ + B ----------------------------------- :param float nb_days_max: maximum time to search for splitting-merging event """ @@ -265,7 +385,7 @@ def correct_close_events(self, nb_days_max=20): previous_obs, next_obs = self.previous_obs, self.next_obs - # record for every segments, the slice, indice of next obs & indice of previous obs + # record for every segment the slice, index of next obs & index of previous obs for i, seg, _ in self.iter_on(segment): if i.start == i.stop: continue @@ -274,54 +394,58 @@ def correct_close_events(self, nb_days_max=20): segments_connexion[seg] = [i, i_p, i_n] for seg in sorted(segments_connexion.keys()): - seg_slice, i_seg_p, i_seg_n = segments_connexion[seg] + seg_slice, _, i_seg_n = segments_connexion[seg] # the segment ID has to be corrected, because we may have changed it since seg_corrected = segment[seg_slice.stop - 1] # we keep the real segment number seg_corrected_copy = segment_copy[seg_slice.stop - 1] + if i_seg_n == -1: + continue + # if segment is split n_seg = segment[i_seg_n] - # if segment has splitting - if i_seg_n != -1: - seg2_slice, i2_seg_p, i2_seg_n = segments_connexion[n_seg] - p2_seg = segment[i2_seg_p] - - # if it merge on the first in a certain time - if (p2_seg == seg_corrected) and ( - _time[i_seg_n] - _time[i2_seg_p] < nb_days_max - ): - my_slice = slice(i_seg_n, seg2_slice.stop) - # correct the factice segment - segment[my_slice] = seg_corrected - # correct the good segment - segment_copy[my_slice] = seg_corrected_copy - previous_obs[i_seg_n] = seg_slice.stop - 1 + seg2_slice, i2_seg_p, _ = segments_connexion[n_seg] + if i2_seg_p == -1: + continue + p2_seg = segment[i2_seg_p] - segments_connexion[seg_corrected][0] = my_slice + # if it merges on the first in a certain time + if (p2_seg == seg_corrected) and ( + _time[i_seg_n] - _time[i2_seg_p] < nb_days_max + ): + my_slice 
= slice(i_seg_n, seg2_slice.stop) + # correct the factice segment + segment[my_slice] = seg_corrected + # correct the good segment + segment_copy[my_slice] = seg_corrected_copy + previous_obs[i_seg_n] = seg_slice.stop - 1 - self.segment[:] = segment_copy - self.previous_obs[:] = previous_obs + segments_connexion[seg_corrected][0] = my_slice - self.sort() + return self.sort() def sort(self, order=("track", "segment", "time")): """ - sort observations + Sort observations - :param tuple order: order or sorting. Passed to `np.argsort` + :param tuple order: order or sorting. Given to :func:`numpy.argsort` """ - - index_order = self.obs.argsort(order=order) - for field in self.elements: + index_order = self.obs.argsort(order=order, kind="mergesort") + self.reset_index() + for field in self.fields: self[field][:] = self[field][index_order] - translate = -ones(index_order.max() + 2, dtype="i4") - translate[index_order] = arange(index_order.shape[0]) + nb_obs = len(self) + # we add 1 for -1 index return index -1 + translate = -ones(nb_obs + 1, dtype="i4") + translate[index_order] = arange(nb_obs) + # next & previous must be re-indexed self.next_obs[:] = translate[self.next_obs] self.previous_obs[:] = translate[self.previous_obs] + return index_order, translate def obs_relative_order(self, i_obs): self.only_one_network() @@ -329,13 +453,13 @@ def obs_relative_order(self, i_obs): def find_link(self, i_observations, forward=True, backward=False): """ - find all observations where obs `i_observation` could be + Find all observations where obs `i_observation` could be in future or past. - if forward=True, search all observation where water + If forward=True, search all observations where water from obs "i_observation" could go - if backward=True, search all observation + If backward=True, search all observation where water from obs `i_observation` could come from :param int,iterable(int) i_observation: @@ -372,7 +496,6 @@ def find_link(self, i_observations, forward=True, backward=False): segments_connexion[seg][0] = i_slice if i_p != -1: - if p_seg not in segments_connexion: segments_connexion[p_seg] = [None, [], []] @@ -424,8 +547,10 @@ def func_backward(seg, indice): return self.extract_with_mask(mask) def connexions(self, multi_network=False): - """ - create dictionnary for each segments, gives the segments which interact with + """Create dictionnary for each segment, gives the segments in interaction with + + :param bool multi_network: use segment_track_array instead of segment, defaults to False + :return dict: Return dict of set, for each seg id we get set of segment which have event with him """ if multi_network: segment = self.segment_track_array @@ -434,25 +559,28 @@ def connexions(self, multi_network=False): segment = self.segment segments_connexion = dict() - def add_seg(father, child): - if father not in segments_connexion: - segments_connexion[father] = set() - segments_connexion[father].add(child) - - previous_obs, next_obs = self.previous_obs, self.next_obs - for i, seg, _ in self.iter_on(segment): - if i.start == i.stop: - continue - i_p, i_n = previous_obs[i.start], next_obs[i.stop - 1] - # segment of interaction - p_seg, n_seg = segment[i_p], segment[i_n] - # Where segment are called - if i_p != -1: - add_seg(p_seg, seg) - add_seg(seg, p_seg) - if i_n != -1: - add_seg(n_seg, seg) - add_seg(seg, n_seg) + def add_seg(s1, s2): + if s1 not in segments_connexion: + segments_connexion[s1] = set() + if s2 not in segments_connexion: + segments_connexion[s2] = set() + 
segments_connexion[s1].add(s2), segments_connexion[s2].add(s1) + + # Get index for each segment + i0, i1, _ = self.index_segment_track + i1 = i1 - 1 + # Check if segment merge + i_next = self.next_obs[i1] + m_n = i_next != -1 + # Check if segment come from splitting + i_previous = self.previous_obs[i0] + m_p = i_previous != -1 + # For each split + for s1, s2 in zip(segment[i_previous[m_p]], segment[i0[m_p]]): + add_seg(s1, s2) + # For each merge + for s1, s2 in zip(segment[i_next[m_n]], segment[i1[m_n]]): + add_seg(s1, s2) return segments_connexion @classmethod @@ -474,6 +602,7 @@ def segment_relative_order(self, seg_origine): """ Compute the relative order of each segment to the chosen segment """ + self.only_one_network() i_s, i_e, i_ref = build_index(self.segment) segment_connexions = self.connexions() relative_tr = -ones(i_s.shape, dtype="i4") @@ -485,39 +614,70 @@ def segment_relative_order(self, seg_origine): d[i0:i1] = v return d - def relative(self, i_obs, order=2, direct=True, only_past=False, only_future=False): + def relatives(self, obs, order=2): """ - Extract the segments at a certain order from one observation. + Extract the segments at a certain order from multiple observations. - :param list obs: indice of observation for relative computation + :param iterable,int obs: + indices of observation for relatives computation. Can be one observation (int) + or collection of observations (iterable(int)) :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ... - - :return: all segments relatives + :return: all segments' relatives :rtype: EddiesObservations """ + segment = self.segment_track_array + previous_obs, next_obs = self.previous_obs, self.next_obs - d = self.segment_relative_order(self.segment[i_obs]) - m = (d <= order) * (d != -1) - return self.extract_with_mask(m) + segments_connexion = dict() - def relatives(self, obs, order=2, direct=True, only_past=False, only_future=False): - """ - Extract the segments at a certain order from multiple observations. + for i_slice, seg, _ in self.iter_on(segment): + if i_slice.start == i_slice.stop: + continue - :param list obs: indices of observation for relatives computation - :param int order: order of relatives wanted. 0 means only observations in obs, 1 means direct relatives, ... 
+ i_p, i_n = previous_obs[i_slice.start], next_obs[i_slice.stop - 1] + p_seg, n_seg = segment[i_p], segment[i_n] - :return: all segments relatives - :rtype: EddiesObservations - """ + # dumping slice into dict + if seg not in segments_connexion: + segments_connexion[seg] = [i_slice, []] + else: + segments_connexion[seg][0] = i_slice - mask = zeros(self.segment.shape, dtype=bool) + if i_p != -1: + if p_seg not in segments_connexion: + segments_connexion[p_seg] = [None, []] - for i_obs in obs: - d = self.segment_relative_order(self.segment[i_obs]) - mask += (d <= order) * (d != -1) + # backward + segments_connexion[seg][1].append(p_seg) + segments_connexion[p_seg][1].append(seg) - return self.extract_with_mask(mask) + if i_n != -1: + if n_seg not in segments_connexion: + segments_connexion[n_seg] = [None, []] + + # forward + segments_connexion[seg][1].append(n_seg) + segments_connexion[n_seg][1].append(seg) + + i_obs = [obs] if not hasattr(obs, "__iter__") else obs + distance = zeros(segment.size, dtype=uint16) - 1 + + def loop(seg, dist=1): + i_slice, links = segments_connexion[seg] + d = distance[i_slice.start] + + if dist < d and dist <= order: + distance[i_slice] = dist + for _seg in links: + loop(_seg, dist + 1) + + for indice in i_obs: + loop(segment[indice], 0) + + return self.extract_with_mask(distance <= order) + + # keep old names, for backward compatibility + relative = relatives def close_network(self, other, nb_obs_min=10, **kwargs): """ @@ -526,7 +686,7 @@ def close_network(self, other, nb_obs_min=10, **kwargs): :param self other: Atlas to compare :param int nb_obs_min: Minimal number of overlap for one trajectory :param dict kwargs: keyword arguments for match function - :return: return other atlas reduce to common track with self + :return: return other atlas reduced to common tracks with self .. 
warning:: It could be a costly operation for huge dataset @@ -544,7 +704,7 @@ def close_network(self, other, nb_obs_min=10, **kwargs): return other.extract_with_mask(m) def normalize_longitude(self): - """Normalize all longitude + """Normalize all longitudes Normalize longitude field and in the same range : - longitude_max @@ -555,16 +715,16 @@ def normalize_longitude(self): lon0 = (self.lon[i_start] - 180).repeat(i_stop - i_start) logger.debug("Normalize longitude") self.lon[:] = (self.lon - lon0) % 360 + lon0 - if "lon_max" in self.obs.dtype.names: + if "lon_max" in self.fields: logger.debug("Normalize longitude_max") self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180 if not self.raw_data: - if "contour_lon_e" in self.obs.dtype.names: + if "contour_lon_e" in self.fields: logger.debug("Normalize effective contour longitude") self.contour_lon_e[:] = ( (self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180 ).T - if "contour_lon_s" in self.obs.dtype.names: + if "contour_lon_s" in self.fields: logger.debug("Normalize speed contour longitude") self.contour_lon_s[:] = ( (self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180 @@ -575,7 +735,7 @@ def numbering_segment(self, start=0): New numbering of segment """ for i, _, _ in self.iter_on("track"): - new_numbering(self.segment[i]) + new_numbering(self.segment[i], start) def numbering_network(self, start=1): """ @@ -589,7 +749,7 @@ def only_one_network(self): if there are more than one network """ _, i_start, _ = self.index_network - if len(i_start) > 1: + if i_start.size > 1: raise Exception("Several networks") def position_filter(self, median_half_window, loess_half_window): @@ -629,7 +789,7 @@ def display_timeline( **kwargs, ): """ - Plot a timeline of a network. + Plot the timeline of a network. Must be called on only one network. 
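
        A minimal usage sketch (``n`` is an assumed :class:`NetworkObservations` atlas;
        the network id and field name are illustrative):

        .. code-block:: python

            n1078 = n.network(1078)
            n1078.display_timeline(ax, field="latitude", method="all")
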
:param matplotlib.axes.Axes ax: matplotlib axe used to draw @@ -664,17 +824,19 @@ def display_timeline( colors_mode=colors_mode, ) ) + if field is not None: + field = self.parse_varname(field) for i, b0, b1 in self.iter_on("segment"): - x = self.time[i] + x = self.time_datetime64[i] if x.shape[0] == 0: continue if field is None: y = b0 * ones(x.shape) else: if method == "all": - y = self[field][i] * factor + y = field[i] * factor else: - y = self[field][i].mean() * ones(x.shape) * factor + y = field[i].mean() * ones(x.shape) * factor if colors_mode == "roll": _color = self.get_color(j) @@ -690,17 +852,17 @@ def display_timeline( return mappables def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="roll"): - """mark events in plot""" + """Mark events in plot""" j = 0 - events = dict(spliting=[], merging=[]) + events = dict(splitting=[], merging=[]) # TODO : fill mappables dict y_seg = dict() - _time = self.time + _time = self.time_datetime64 if field is not None and method != "all": for i, b0, _ in self.iter_on("segment"): - y = self[field][i] + y = self.parse_varname(field)[i] if y.shape[0] != 0: y_seg[b0] = y.mean() * factor mappables = dict() @@ -726,7 +888,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol y0 = b0 else: if method == "all": - y0 = self[field][i.stop - 1] * factor + y0 = self.parse_varname(field)[i.stop - 1] * factor else: y0 = y_seg[b0] if i_n != -1: @@ -735,7 +897,7 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol seg_next if field is None else ( - self[field][i_n] * factor + self.parse_varname(field)[i_n] * factor if method == "all" else y_seg[seg_next] ) @@ -751,21 +913,21 @@ def event_timeline(self, ax, field=None, method=None, factor=1, colors_mode="rol seg_previous if field is None else ( - self[field][i_p] * factor + self.parse_varname(field)[i_p] * factor if method == "all" else y_seg[seg_previous] ) ) ax.plot((x[0], _time[i_p]), (y0, y1), **event_kw)[0] - events["spliting"].append((x[0], y0)) + events["splitting"].append((x[0], y0)) j += 1 kwargs = dict(color="k", zorder=-1, linestyle=" ") - if len(events["spliting"]) > 0: - X, Y = list(zip(*events["spliting"])) + if len(events["splitting"]) > 0: + X, Y = list(zip(*events["splitting"])) ref = ax.plot( - X, Y, marker="*", markersize=12, label="spliting events", **kwargs + X, Y, marker="*", markersize=12, label="splitting events", **kwargs )[0] mappables.setdefault("events", []).append(ref) @@ -787,7 +949,7 @@ def map_segment(self, method, y, same=True, **kw): out = empty(y.shape, **kw) else: out = list() - for i, b0, b1 in self.iter_on(self.segment_track_array): + for i, _, _ in self.iter_on(self.segment_track_array): res = method(y[i]) if same: out[i] = res @@ -804,11 +966,11 @@ def map_segment(self, method, y, same=True, **kw): def map_network(self, method, y, same=True, return_dict=False, **kw): """ - transform data `y` with method `method` for each track. + Transform data `y` with method `method` for each track. - :param Callable method: method to apply on each tracks + :param Callable method: method to apply on each track :param np.array y: data where to apply method - :param bool same: if True, return array same size from y. else, return list with track edited + :param bool same: if True, return an array with the same size than y. 
Else, return a list with the edited tracks :param bool return_dict: if None, mean values are used :param float kw: to multiply field :return: array or dict of result from method for each network @@ -816,7 +978,7 @@ def map_network(self, method, y, same=True, return_dict=False, **kw): if same and return_dict: raise NotImplementedError( - "both condition 'same' and 'return_dict' should no be true" + "both conditions 'same' and 'return_dict' should no be true" ) if same: @@ -860,7 +1022,7 @@ def scatter_timeline( **kwargs, ): """ - Must be call on only one network + Must be called on only one network """ self.only_one_network() y = (self.segment if yfield is None else self.parse_varname(yfield)) * yfactor @@ -876,11 +1038,11 @@ def scatter_timeline( if "c" not in kwargs: v = self.parse_varname(name) kwargs["c"] = v * factor - mappables["scatter"] = ax.scatter(self.time, y, **kwargs) + mappables["scatter"] = ax.scatter(self.time_datetime64, y, **kwargs) return mappables def event_map(self, ax, **kwargs): - """Add the merging and splitting events """ + """Add the merging and splitting events to a map""" j = 0 mappables = dict() symbol_kw = dict( @@ -924,12 +1086,12 @@ def scatter( **kwargs, ): """ - This function will scatter the path of each network, with the merging and splitting events + This function scatters the path of each network, with the merging and splitting events :param matplotlib.axes.Axes ax: matplotlib axe used to draw :param str,array,None name: - variable used to fill the contour, if None all elements have the same color - :param float,None ref: if define use like west bound + variable used to fill the contours, if None all elements have the same color + :param float,None ref: if defined, ref is used as western boundary :param float factor: multiply value by :param list edgecolor_cycle: list of colors :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.scatter` @@ -972,7 +1134,7 @@ def extract_event(self, indices): raw_data=self.raw_data, ) - for k in new.obs.dtype.names: + for k in new.fields: new[k][:] = self[k][indices] new.sign_type = self.sign_type return new @@ -980,33 +1142,35 @@ def extract_event(self, indices): @property def segment_track_array(self): """Return a unique segment id when multiple networks are considered""" - return build_unique_array(self.segment, self.track) - - def birth_event(self): - # FIXME how to manage group 0 - indices = list() - previous_obs = self.previous_obs - for i, _, _ in self.iter_on(self.segment_track_array): - nb = i.stop - i.start - if nb == 0: - continue - i_p = previous_obs[i.start] - if i_p == -1: - indices.append(i.start) - return self.extract_event(list(set(indices))) - - def death_event(self): - # FIXME how to manage group 0 - indices = list() - next_obs = self.next_obs - for i, _, _ in self.iter_on(self.segment_track_array): - nb = i.stop - i.start - if nb == 0: - continue - i_n = next_obs[i.stop - 1] - if i_n == -1: - indices.append(i.stop - 1) - return self.extract_event(list(set(indices))) + if self._segment_track_array is None: + self._segment_track_array = build_unique_array(self.segment, self.track) + return self._segment_track_array + + def birth_event(self, only_index=False): + """Extract birth events.""" + i_start, _, _ = self.index_segment_track + indices = i_start[self.previous_obs[i_start] == -1] + if self.first_is_trash(): + indices = indices[1:] + if only_index: + return indices + else : + return self.extract_event(indices) + + generation_event = birth_event + + def death_event(self, only_index=False): + 
"""Extract death events.""" + _, i_stop, _ = self.index_segment_track + indices = i_stop[self.next_obs[i_stop - 1] == -1] - 1 + if self.first_is_trash(): + indices = indices[1:] + if only_index: + return indices + else : + return self.extract_event(indices) + + dissipation_event = death_event def merging_event(self, triplet=False, only_index=False): """Return observation after a merging event. @@ -1014,25 +1178,26 @@ def merging_event(self, triplet=False, only_index=False): If `triplet=True` return the eddy after a merging event, the eddy before the merging event, and the eddy stopped due to merging. """ - idx_m1 = list() + # Get start and stop for each segment, there is no empty segment + _, i1, _ = self.index_segment_track + # Get last index for each segment + i_stop = i1 - 1 + # Get target index + idx_m1 = self.next_obs[i_stop] + # Get mask and valid target + m = idx_m1 != -1 + idx_m1 = idx_m1[m] + # Sort by time event + i = self.time[idx_m1].argsort() + idx_m1 = idx_m1[i] if triplet: - idx_m0_stop = list() - idx_m0 = list() - next_obs, previous_obs = self.next_obs, self.previous_obs - for i, _, _ in self.iter_on(self.segment_track_array): - nb = i.stop - i.start - if nb == 0: - continue - i_n = next_obs[i.stop - 1] - if i_n != -1: - if triplet: - idx_m0_stop.append(i.stop - 1) - idx_m0.append(previous_obs[i_n]) - idx_m1.append(i_n) + # Get obs before target + idx_m0_stop = i_stop[m][i] + idx_m0 = self.previous_obs[idx_m1].copy() if triplet: if only_index: - return (idx_m1, idx_m0, idx_m0_stop) + return idx_m1, idx_m0, idx_m0_stop else: return ( self.extract_event(idx_m1), @@ -1040,46 +1205,45 @@ def merging_event(self, triplet=False, only_index=False): self.extract_event(idx_m0_stop), ) else: - idx_m1 = list(set(idx_m1)) + idx_m1 = unique(idx_m1) if only_index: return idx_m1 else: return self.extract_event(idx_m1) - def spliting_event(self, triplet=False, only_index=False): + def splitting_event(self, triplet=False, only_index=False): """Return observation before a splitting event. If `triplet=True` return the eddy before a splitting event, the eddy after the splitting event, and the eddy starting due to splitting. 
""" - idx_s0 = list() + # Get start and stop for each segment, there is no empty segment + i_start, _, _ = self.index_segment_track + # Get target index + idx_s0 = self.previous_obs[i_start] + # Get mask and valid target + m = idx_s0 != -1 + idx_s0 = idx_s0[m] + # Sort by time event + i = self.time[idx_s0].argsort() + idx_s0 = idx_s0[i] if triplet: - idx_s1_start = list() - idx_s1 = list() - next_obs, previous_obs = self.next_obs, self.previous_obs - for i, _, _ in self.iter_on(self.segment_track_array): - nb = i.stop - i.start - if nb == 0: - continue - i_p = previous_obs[i.start] - if i_p != -1: - if triplet: - idx_s1_start.append(i.start) - idx_s1.append(next_obs[i_p]) - idx_s0.append(i_p) + # Get obs after target + idx_s1_start = i_start[m][i] + idx_s1 = self.next_obs[idx_s0].copy() if triplet: if only_index: - return (idx_s0, idx_s1, idx_s1_start) + return idx_s0, idx_s1, idx_s1_start else: return ( - self.extract_event(list(idx_s0)), - self.extract_event(list(idx_s1)), - self.extract_event(list(idx_s1_start)), + self.extract_event(idx_s0), + self.extract_event(idx_s1), + self.extract_event(idx_s1_start), ) else: - idx_s0 = list(set(idx_s0)) + idx_s0 = unique(idx_s0) if only_index: return idx_s0 else: @@ -1087,70 +1251,172 @@ def spliting_event(self, triplet=False, only_index=False): def dissociate_network(self): """ - Dissociate network with no known interaction (spliting/merging) + Dissociate networks with no known interaction (splitting/merging) """ - - tags = self.tag_segment(multi_network=True) + tags = self.tag_segment() if self.track[0] == 0: tags -= 1 - self.track[:] = tags[self.segment_track_array] + return self.sort() - i_sort = self.obs.argsort(order=("track", "segment", "time"), kind="mergesort") - # Sort directly obs, with hope to save memory - self.obs.sort(order=("track", "segment", "time"), kind="mergesort") - self._index_network = None - - # n & p must be re-index - n, p = self.next_obs, self.previous_obs - # we add 2 for -1 index return index -1 - nb_obs = len(self) - translate = -ones(nb_obs + 1, dtype="i4") - translate[:-1][i_sort] = arange(nb_obs) - self.next_obs[:] = translate[n] - self.previous_obs[:] = translate[p] + def network_segment(self, id_network, id_segment): + return self.extract_with_mask(self.segment_slice(id_network, id_segment)) def network(self, id_network): return self.extract_with_mask(self.network_slice(id_network)) + def networks_mask(self, id_networks, segment=False): + if segment: + return generate_mask_from_ids( + id_networks, self.track.size, *self.index_segment_track + ) + else: + return generate_mask_from_ids( + id_networks, self.track.size, *self.index_network + ) + def networks(self, id_networks): - m = zeros(self.track.shape, dtype=bool) - for tr in id_networks: - m[self.network_slice(tr)] = True - return self.extract_with_mask(m) + return self.extract_with_mask( + generate_mask_from_ids( + array(id_networks), self.track.size, *self.index_network + ) + ) + + @property + def nb_network(self): + """ + Count and return number of network + """ + return (self.network_size() != 0).sum() + + @property + def nb_segment(self): + """ + Count and return number of segment in all network + """ + return self.index_segment_track[0].size + + def identify_in(self, other, size_min=1, segment=False): + """ + Return couple of segment or network which are equal + + :param other: other atlas to compare + :param int size_min: number of observation in network/segment + :param bool segment: segment mode + """ + if segment: + counts = self.segment_size(), 
other.segment_size() + i_self_ref, i_other_ref = ( + self.ref_segment_track_index, + other.ref_segment_track_index, + ) + var_id = "segment" + else: + counts = self.network_size(), other.network_size() + i_self_ref, i_other_ref = self.ref_index, other.ref_index + var_id = "track" + # object to contain index of couple + in_self, in_other = list(), list() + # We iterate on item of same size + for i_self, i_other, i0, _ in self.align_on(other, counts, all_ref=True): + if i0 < size_min: + continue + if isinstance(i_other, slice): + i_other = arange(i_other.start, i_other.stop) + # All_ref will give all item of self, sometime there is no things to compare with other + if i_other.size == 0: + id_self = i_self + i_self_ref + in_self.append(id_self) + in_other.append(-ones(id_self.shape, dtype=id_self.dtype)) + continue + if isinstance(i_self, slice): + i_self = arange(i_self.start, i_self.stop) + # We get absolute id + id_self, id_other = i_self + i_self_ref, i_other + i_other_ref + # We compute mask to select data + m_self, m_other = self.networks_mask(id_self, segment), other.networks_mask( + id_other, segment + ) + + # We extract obs + obs_self, obs_other = self.obs[m_self], other.obs[m_other] + x1, y1, t1 = obs_self["lon"], obs_self["lat"], obs_self["time"] + x2, y2, t2 = obs_other["lon"], obs_other["lat"], obs_other["time"] + + if segment: + ids1 = build_unique_array(obs_self["segment"], obs_self["track"]) + ids2 = build_unique_array(obs_other["segment"], obs_other["track"]) + label1 = self.segment_track_array[m_self] + label2 = other.segment_track_array[m_other] + else: + label1, label2 = ids1, ids2 = obs_self[var_id], obs_other[var_id] + # For each item we get index to sort + i01, indexs1, id1 = list(), List(), list() + for sl_self, id_, _ in self.iter_on(ids1): + i01.append(sl_self.start) + indexs1.append(obs_self[sl_self].argsort(order=["time", "lon", "lat"])) + id1.append(label1[sl_self.start]) + i02, indexs2, id2 = list(), List(), list() + for sl_other, _, _ in other.iter_on(ids2): + i02.append(sl_other.start) + indexs2.append( + obs_other[sl_other].argsort(order=["time", "lon", "lat"]) + ) + id2.append(label2[sl_other.start]) + + id1, id2 = array(id1), array(id2) + # We search item from self in item of others + i_local_target = same_position( + x1, y1, t1, x2, y2, t2, array(i01), array(i02), indexs1, indexs2 + ) + + # -1 => no item found in other dataset + m = i_local_target != -1 + in_self.append(id1) + track2_ = -ones(id1.shape, dtype="i4") + track2_[m] = id2[i_local_target[m]] + in_other.append(track2_) + + return concatenate(in_self), concatenate(in_other) @classmethod def __tag_segment(cls, seg, tag, groups, connexions): """ Will set same temporary ID for each connected segment. 
- :param int seg: current ID of seg - :param ing tag: temporary ID to set for seg and its connexion - :param array[int] groups: array where tag will be stored - :param dict connexions: gives for one ID of seg all seg connected + :param int seg: current ID of segment + :param ing tag: temporary ID to set for segment and its connexion + :param array[int] groups: array where tag is stored + :param dict connexions: gives for one ID of segment all connected segments """ - # If seg are already used we stop recursivity + # If segments are already used we stop recursivity if groups[seg] != 0: return - # We set tag for this seg + # We set tag for this segment groups[seg] = tag - # Get all connexions of this seg + # Get all connexions of this segment segs = connexions.get(seg, None) if segs is not None: for seg in segs: # For each connexion we apply same function cls.__tag_segment(seg, tag, groups, connexions) - def tag_segment(self, multi_network=False): - if multi_network: - nb = self.segment_track_array[-1] + 1 - else: - nb = self.segment.max() + 1 + def tag_segment(self): + """For each segment, method give a new network id, and all segment are connected + + :return array: for each unique seg id, it return new network id + """ + nb = self.segment_track_array[-1] + 1 sub_group = zeros(nb, dtype="u4") - c = self.connexions(multi_network=multi_network) + c = self.connexions(multi_network=True) j = 1 # for each available id for i in range(nb): + # No connexions, no need to explore + if i not in c: + sub_group[i] = j + j += 1 + continue # Skip if already set if sub_group[i] != 0: continue @@ -1160,25 +1426,42 @@ def tag_segment(self, multi_network=False): return sub_group def fully_connected(self): + """Suspicious""" + raise Exception("Must be check") self.only_one_network() return self.tag_segment().shape[0] == 1 + def first_is_trash(self): + """Check if first network is Trash + + :return bool: True if first network is trash + """ + i_start, i_stop, _ = self.index_segment_track + sl = slice(i_start[0], i_stop[0]) + return (self.previous_obs[sl] == -1).all() and (self.next_obs[sl] == -1).all() + def remove_trash(self): - return self.extract_with_mask(self.track != 0) + """ + Remove the lonely eddies (only 1 obs in segment, associated network number is 0) + """ + if self.first_is_trash(): + return self.extract_with_mask(self.track != 0) + else: + return self def plot(self, ax, ref=None, color_cycle=None, **kwargs): """ - This function will draw path of each trajectory + This function draws the path of each trajectory :param matplotlib.axes.Axes ax: ax to draw - :param float,int ref: if defined, all coordinates will be wrapped with ref like west boundary + :param float,int ref: if defined, all coordinates are wrapped with ref as western boundary :param dict kwargs: keyword arguments for Axes.plot :return: a list of matplotlib mappables """ - nb_colors = 0 - if color_cycle is not None: - kwargs = kwargs.copy() - nb_colors = len(color_cycle) + kwargs = kwargs.copy() + if color_cycle is None: + color_cycle = self.COLORS + nb_colors = len(color_cycle) mappables = list() if "label" in kwargs: kwargs["label"] = self.format_label(kwargs["label"]) @@ -1196,57 +1479,82 @@ def plot(self, ax, ref=None, color_cycle=None, **kwargs): j += 1 return mappables - def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None): + def remove_dead_end(self, nobs=3, ndays=0, recursive=0, mask=None, return_mask=False): """ - Remove short segment which didn't connect several segment + Remove short segments that don't 
connect several segments - :param int nobs: Minimal number of observation to keep segment - :param int ndays: Minimal number of days to keep segment - :param int recursive: Run method N times more - :param int mask: if one or more observation of segment are select by mask, the segment is keep + :param int nobs: Minimal number of observation to keep a segment + :param int ndays: Minimal number of days to keep a segment + :param int recursive: Run method N times more + :param int mask: if one or more observation of the segment are selected by mask, the segment is kept .. warning:: - It will remove short segment which splits than merges with same segment + It will remove short segment that splits from then merges with the same segment """ - segments_keep = list() connexions = self.connexions(multi_network=True) - t = self.time - for i, b0, _ in self.iter_on(self.segment_track_array): - if mask and mask[i].any(): - segments_keep.append(b0) - continue - nb = i.stop - i.start - dt = t[i.stop - 1] - t[i.start] - if (nb < nobs or dt < ndays) and len(connexions.get(b0, tuple())) < 2: - continue - segments_keep.append(b0) - if recursive > 0: - return self.extract_segment(segments_keep, absolute=True).remove_dead_end( - nobs, ndays, recursive - 1 + i0, i1, _ = self.index_segment_track + dt = self.time[i1 - 1] - self.time[i0] + 1 + nb = i1 - i0 + m = (dt >= ndays) * (nb >= nobs) + nb_connexions = array([len(connexions.get(i, tuple())) for i in where(~m)[0]]) + m[~m] = nb_connexions >= 2 + segments_keep = where(m)[0] + if mask is not None: + segments_keep = unique( + concatenate((segments_keep, self.segment_track_array[mask])) ) - return self.extract_segment(segments_keep, absolute=True) + # get mask for selected obs + m = ~self.segment_mask(segments_keep) + if return_mask: + return ~m + self.track[m] = 0 + self.segment[m] = 0 + self.previous_obs[m] = -1 + self.previous_cost[m] = 0 + self.next_obs[m] = -1 + self.next_cost[m] = 0 + + m_previous = m[self.previous_obs] + self.previous_obs[m_previous] = -1 + self.previous_cost[m_previous] = 0 + m_next = m[self.next_obs] + self.next_obs[m_next] = -1 + self.next_cost[m_next] = 0 + + self.sort() + if recursive > 0: + self.remove_dead_end(nobs, ndays, recursive - 1) + + def extract_segment(self, segments, absolute=False): - mask = ones(self.shape, dtype="bool") - segments = array(segments) - values = self.segment_track_array if absolute else "segment" - keep = ones(values.max() + 1, dtype="bool") - v = unique(values) - keep[v] = in1d(v, segments) - for i, b0, b1 in self.iter_on(values): - if not keep[b0]: - mask[i] = False - return self.extract_with_mask(mask) + """Extract given segments - def extract_with_period(self, period): + :param array,tuple,list segments: list of segment to extract + :param bool absolute: keep for compatibility, defaults to False + :return NetworkObservations: Return observations from selected segment """ - Extract within a time period + if not absolute: + raise Exception("Not implemented") + return self.extract_with_mask(self.segment_mask(segments)) - :param (int,int) period: two dates to define the period, must be specify from 1/1/1950 - :return: Return all eddy tracks which are in bounds - :rtype: NetworkObservations + def segment_mask(self, segments): + """Get mask from list of segment + + :param list,array segments: absolute id of segment + """ + return generate_mask_from_ids( + array(segments), len(self), *self.index_segment_track + ) + + def get_mask_with_period(self, period): + """ + obtain mask within a time period + + :param 
(int,int) period: two dates to define the period, must be specified from 1/1/1950 + :return: mask where period is defined + :rtype: np.array(bool) - .. minigallery:: py_eddy_tracker.NetworkObservations.extract_with_period """ dataset_period = self.period p_min, p_max = period @@ -1260,7 +1568,70 @@ def extract_with_period(self, period): mask *= self.time <= p_max elif p_max < 0: mask *= self.time <= (dataset_period[1] + p_max) - return self.extract_with_mask(mask) + return mask + + def extract_with_period(self, period): + """ + Extract within a time period + + :param (int,int) period: two dates to define the period, must be specified from 1/1/1950 + :return: Return all eddy trajectories in period + :rtype: NetworkObservations + + .. minigallery:: py_eddy_tracker.NetworkObservations.extract_with_period + """ + + return self.extract_with_mask(self.get_mask_with_period(period)) + + def extract_light_with_mask(self, mask, track_extra_variables=[]): + """extract data with mask, but only with variables used for coherence, aka self.array_variables + + :param mask: mask used to extract + :type mask: np.array(bool) + :return: new EddiesObservation with data wanted + :rtype: self + """ + + if isinstance(mask, slice): + nb_obs = mask.stop - mask.start + else: + nb_obs = mask.sum() + + # only time & contour_lon/lat_e/s + variables = ["time"] + self.array_variables + new = self.__class__( + size=nb_obs, + track_extra_variables=track_extra_variables, + track_array_variables=self.track_array_variables, + array_variables=self.array_variables, + only_variables=variables, + raw_data=self.raw_data, + ) + new.sign_type = self.sign_type + if nb_obs == 0: + logger.info("Empty dataset will be created") + else: + logger.info( + f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})" + ) + + for field in variables + track_extra_variables: + logger.debug("Copy of field %s ...", field) + new.obs[field] = self.obs[field][mask] + + if ( + "previous_obs" in track_extra_variables + and "next_obs" in track_extra_variables + ): + # n & p must be re-index + n, p = self.next_obs[mask], self.previous_obs[mask] + # we add 2 for -1 index return index -1 + translate = -ones(len(self) + 1, dtype="i4") + translate[:-1][mask] = arange(nb_obs) + new.next_obs[:] = translate[n] + new.previous_obs[:] = translate[p] + + return new def extract_with_mask(self, mask): """ @@ -1277,17 +1648,16 @@ def extract_with_mask(self, mask): new = self.__class__.new_like(self, nb_obs) new.sign_type = self.sign_type if nb_obs == 0: - logger.warning("Empty dataset will be created") + logger.info("Empty dataset will be created") else: - logger.info( - f"{nb_obs} observations will be extract ({nb_obs / self.shape[0]:.3%})" + logger.debug( + f"{nb_obs} observations will be extracted ({nb_obs / self.shape[0]:.3%})" ) - for field in self.obs.dtype.descr: + for field in self.fields: if field in ("next_obs", "previous_obs"): continue logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var][mask] + new.obs[field] = self.obs[field][mask] # n & p must be re-index n, p = self.next_obs[mask], self.previous_obs[mask] # we add 2 for -1 index return index -1 @@ -1297,6 +1667,305 @@ def extract_with_mask(self, mask): new.previous_obs[:] = translate[p] return new + def analysis_coherence( + self, + date_function, + uv_params, + advection_mode="both", + n_days=14, + step_mesh=1.0 / 50, + output_name=None, + dissociate_network=False, + correct_close_events=0, + remove_dead_end=0, + ): + """Global function to analyse 
segment coherence, with network preprocessing.
+        :param callable date_function: python function, takes as param `int` (julian day) and returns
+            data filename associated to the date
+        :param dict uv_params: dict of parameters used by
+            :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list`
+        :param int n_days: number of days for advection
+        :param float step_mesh: step for particle mesh in degrees
+        :param str output_name: path/name for the output (without extension) to store the clean
+            network in .nc and the coherence results in .zarr. Works only for advection_mode = "both"
+        :param bool dissociate_network: If True apply
+            :py:meth:`~py_eddy_tracker.observations.network.NetworkObservations.dissociate_network`
+        :param int correct_close_events: Number of days in
+            :py:meth:`~py_eddy_tracker.observations.network.NetworkObservations.correct_close_events`
+        :param int remove_dead_end: Number of days in
+            :py:meth:`~py_eddy_tracker.observations.network.NetworkObservations.remove_dead_end`
+        :return target_forward, target_backward: 2D numpy.array with the eddy observation in which
+            the particles ended after advection
+        :return pct_forward, pct_backward: percentage of ending particles within the
+            eddy observation with regard to the starting number
+        """
+
+        if dissociate_network:
+            self.dissociate_network()
+
+        if correct_close_events > 0:
+            self.correct_close_events(nb_days_max=correct_close_events)
+
+        if remove_dead_end > 0:
+            network_clean = self.remove_dead_end(nobs=0, ndays=remove_dead_end)
+        else:
+            network_clean = self
+
+        network_clean.numbering_segment()
+
+        res = []
+        if (advection_mode == "both") | (advection_mode == "forward"):
+            target_forward, pct_forward = network_clean.segment_coherence_forward(
+                date_function=date_function,
+                uv_params=uv_params,
+                n_days=n_days,
+                step_mesh=step_mesh,
+            )
+            res = res + [target_forward, pct_forward]
+
+        if (advection_mode == "both") | (advection_mode == "backward"):
+            target_backward, pct_backward = network_clean.segment_coherence_backward(
+                date_function=date_function,
+                uv_params=uv_params,
+                n_days=n_days,
+                step_mesh=step_mesh,
+            )
+            res = res + [target_backward, pct_backward]
+
+        if (output_name is not None) & (advection_mode == "both"):
+            # TODO : put some path verification?
+            # Save the clean network in netcdf
+            with netCDF4.Dataset(output_name + ".nc", "w") as fh:
+                network_clean.to_netcdf(fh)
+            # Save the results of particles advection in zarr
+            # zarr compression parameters
+            # TODO : check size? compression?
+            params_seg = dict()
+            params_pct = dict()
+            zg = zarr.open(output_name + ".zarr", mode="w")
+            zg.array("target_forward", target_forward, **params_seg)
+            zg.array("pct_forward", pct_forward, **params_pct)
+            zg.array("target_backward", target_backward, **params_seg)
+            zg.array("pct_backward", pct_backward, **params_pct)
+
+        return network_clean, res
+
+    def segment_coherence_backward(
+        self,
+        date_function,
+        uv_params,
+        n_days=14,
+        step_mesh=1.0 / 50,
+        contour_start="speed",
+        contour_end="speed",
+    ):
+        """
+        Percentage of particles and their targets after backward advection from a specific eddy.
+ + :param callable date_function: python function, takes as param `int` (julian day) and return + data filename associated to the date (see note) + :param dict uv_params: dict of parameters used by + :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` + :param int n_days: days for advection + :param float step_mesh: step for particule mesh in degrees + :return: observations matchs, and percents + + .. note:: the param `date_function` should be something like : + + .. code-block:: python + + def date2file(julian_day): + date = datetime.timedelta(days=julian_day) + datetime.datetime( + 1950, 1, 1 + ) + + return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc" + """ + shape = len(self), 2 + itb_final = -ones(shape, dtype="i4") + ptb_final = zeros(shape, dtype="i1") + + t_start, t_end = int(self.period[0]), int(self.period[1]) + + # dates = arange(t_start, t_start + n_days + 1) + dates = arange(t_start, min(t_start + n_days + 1, t_end + 1)) + first_files = [date_function(x) for x in dates] + + c = GridCollection.from_netcdf_list(first_files, dates, **uv_params) + first = True + range_start = t_start + n_days + range_end = t_end + 1 + + for _t in range(t_start + n_days, t_end + 1): + _timestamp = time.time() + t_shift = _t + + # skip first shift, because already included + if first: + first = False + else: + # add next date to GridCollection and delete last date + c.shift_files(t_shift, date_function(int(t_shift)), **uv_params) + particle_candidate( + c, + self, + step_mesh, + _t, + itb_final, + ptb_final, + n_days=-n_days, + contour_start=contour_start, + contour_end=contour_end, + ) + logger.info( + ( + f"coherence {_t} / {range_end - 1} ({(_t - range_start) / (range_end - range_start - 1):.1%})" + f" : {time.time() - _timestamp:5.2f}s" + ) + ) + + return itb_final, ptb_final + + def segment_coherence_forward( + self, + date_function, + uv_params, + n_days=14, + step_mesh=1.0 / 50, + contour_start="speed", + contour_end="speed", + **kwargs, + ): + """ + Percentage of particules and their targets after forward advection from a specific eddy. + + :param callable date_function: python function, takes as param `int` (julian day) and return + data filename associated to the date (see note) + :param dict uv_params: dict of parameters used by + :py:meth:`~py_eddy_tracker.dataset.grid.GridCollection.from_netcdf_list` + :param int n_days: days for advection + :param float step_mesh: step for particule mesh in degrees + :return: observations matchs, and percents + + .. note:: the param `date_function` should be something like : + + .. 
code-block:: python + + def date2file(julian_day): + date = datetime.timedelta(days=julian_day) + datetime.datetime( + 1950, 1, 1 + ) + + return f"/tmp/dt_global_{date.strftime('%Y%m%d')}.nc" + """ + shape = len(self), 2 + itf_final = -ones(shape, dtype="i4") + ptf_final = zeros(shape, dtype="i1") + + t_start, t_end = int(self.period[0]), int(self.period[1]) + + dates = arange(t_start, min(t_start + n_days + 1, t_end + 1)) + first_files = [date_function(x) for x in dates] + + c = GridCollection.from_netcdf_list(first_files, dates, **uv_params) + first = True + range_start = t_start + range_end = t_end - n_days + 1 + + for _t in range(range_start, range_end): + _timestamp = time.time() + t_shift = _t + n_days + + # skip first shift, because already included + if first: + first = False + else: + # add next date to GridCollection and delete last date + c.shift_files(t_shift, date_function(int(t_shift)), **uv_params) + particle_candidate( + c, + self, + step_mesh, + _t, + itf_final, + ptf_final, + n_days=n_days, + contour_start=contour_start, + contour_end=contour_end, + **kwargs, + ) + logger.info( + ( + f"coherence {_t} / {range_end - 1} ({(_t - range_start) / (range_end - range_start - 1):.1%})" + f" : {time.time() - _timestamp:5.2f}s" + ) + ) + return itf_final, ptf_final + + def mask_obs_close_event(self, merging=True, spliting=True, dt=3): + """Build a mask of close observation from event + + :param n: Network + :param bool merging: select merging event, defaults to True + :param bool spliting: select splitting event, defaults to True + :param int dt: delta of time max , defaults to 3 + :return array: mask + """ + m = zeros(len(self), dtype="bool") + if merging: + i_target, ip1, ip2 = self.merging_event(triplet=True, only_index=True) + mask_follow_obs(m, self.previous_obs, self.time, ip1, dt) + mask_follow_obs(m, self.previous_obs, self.time, ip2, dt) + mask_follow_obs(m, self.next_obs, self.time, i_target, dt) + if spliting: + i_target, in1, in2 = self.splitting_event(triplet=True, only_index=True) + mask_follow_obs(m, self.next_obs, self.time, in1, dt) + mask_follow_obs(m, self.next_obs, self.time, in2, dt) + mask_follow_obs(m, self.previous_obs, self.time, i_target, dt) + return m + + def swap_track( + self, + length_main_max_after_event=2, + length_secondary_min_after_event=10, + delta_pct_max=-0.2, + ): + events = self.splitting_event(triplet=True, only_index=True) + count = 0 + for i_main, i1, i2 in zip(*events): + seg_main, _, seg2 = ( + self.segment_track_array[i_main], + self.segment_track_array[i1], + self.segment_track_array[i2], + ) + i_start, i_end, i0 = self.index_segment_track + # For splitting + last_index_main = i_end[seg_main - i0] - 1 + last_index_secondary = i_end[seg2 - i0] - 1 + last_main_next_obs = self.next_obs[last_index_main] + t_event, t_main_end, t_secondary_start, t_secondary_end = ( + self.time[i_main], + self.time[last_index_main], + self.time[i2], + self.time[last_index_secondary], + ) + dt_main, dt_secondary = ( + t_main_end - t_event, + t_secondary_end - t_secondary_start, + ) + delta_cost = self.previous_cost[i2] - self.previous_cost[i1] + if ( + dt_main <= length_main_max_after_event + and dt_secondary >= length_secondary_min_after_event + and last_main_next_obs == -1 + and delta_cost > delta_pct_max + ): + self.segment[i1 : last_index_main + 1] = self.segment[i2] + self.segment[i2 : last_index_secondary + 1] = self.segment[i_main] + count += 1 + logger.info("%d segmnent swap on %d", count, len(events[0])) + return self.sort() + class Network: 
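
    # A hedged usage sketch; the constructor arguments are assumed (the signature is not
    # shown here), only group_observations and build_dataset are visible in this module:
    #
    #     n = Network("eddies_????????.nc", window=5)   # hypothetical arguments
    #     group = n.group_observations(min_overlap=0.2, minimal_area=False)
    #     eddies = n.build_dataset(group, raw_data=True)
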
__slots__ = ( @@ -1370,10 +2039,10 @@ def group_translator(nb, duos): Create a translator with all duos :param int nb: size of translator - :param set((int, int)) duos: set of all group which must be join + :param set((int, int)) duos: set of all groups that must be joined + + :Example: - Examples - -------- >>> NetworkObservations.group_translator(5, ((0, 1), (0, 2), (1, 3))) [3, 3, 3, 3, 5] """ @@ -1384,21 +2053,42 @@ def group_translator(nb, duos): apply_replace(translate, gr_i, gr_j) return translate - def group_observations(self, **kwargs): + def group_observations(self, min_overlap=0.2, minimal_area=False, **kwargs): + """Store every interaction between identifications + + :param bool minimal_area: If True, function will compute intersection/little polygon, else intersection/union, by default False + :param float min_overlap: minimum overlap area to associate observations, by default 0.2 + + :return: + :rtype: TrackEddiesObservations + """ + results, nb_obs = list(), list() # To display print only in INFO display_iteration = logger.getEffectiveLevel() == logging.INFO + for i, filename in enumerate(self.filenames): if display_iteration: print(f"{filename} compared to {self.window} next", end="\r") - # Load observations with function to buffered observations + # Load observations with function to buffer observations xi, yi = self.buffer.load_contour(filename) # Append number of observations by filename nb_obs.append(xi.shape[0]) for j in range(i + 1, min(self.window + i + 1, self.nb_input)): xj, yj = self.buffer.load_contour(self.filenames[j]) ii, ij = bbox_intersection(xi, yi, xj, yj) - m = vertice_overlap(xi[ii], yi[ii], xj[ij], yj[ij], **kwargs) > 0.2 + m = ( + vertice_overlap( + xi[ii], + yi[ii], + xj[ij], + yj[ij], + minimal_area=minimal_area, + min_overlap=min_overlap, + **kwargs, + ) + != 0 + ) results.append((i, j, ii[m], ij[m])) if display_iteration: print() @@ -1407,7 +2097,7 @@ def group_observations(self, **kwargs): nb_alone, nb_obs, nb_gr = (gr == self.NOGROUP).sum(), len(gr), len(unique(gr)) logger.info( f"{nb_alone} alone / {nb_obs} obs, {nb_gr} groups, " - f"{nb_alone *100./nb_obs:.2f} % alone, {(nb_obs - nb_alone) / (nb_gr - 1):.1f} obs/group" + f"{nb_alone * 100. 
/ nb_obs:.2f} % alone, {(nb_obs - nb_alone) / (nb_gr - 1):.1f} obs/group"
        )
        return gr
@@ -1416,7 +2106,7 @@ def build_dataset(self, group, raw_data=True):
        model = TrackEddiesObservations.load_file(self.filenames[-1], raw_data=raw_data)
        eddies = TrackEddiesObservations.new_like(model, nb_obs)
        eddies.sign_type = model.sign_type
-        # Get new index to re-order observation by group
+        # Get new index to re-order observations by groups
        new_i = get_next_index(group)
        display_iteration = logger.getEffectiveLevel() == logging.INFO
        elements = eddies.elements
@@ -1442,9 +2132,97 @@
        return eddies
+@njit(cache=True)
+def get_percentile_on_following_obs(
+    i, indexs, percents, follow_obs, t, segment, i_target, window, q=50, nb_min=1
+):
+    """Get statistics on the part of a segment close to an event
+
+    :param int i: index to follow
+    :param array indexs: indices from coherence
+    :param array percents: percent from coherence
+    :param array[int] follow_obs: gives the index of the following observation
+    :param array t: time for each observation
+    :param array segment: segment for each observation
+    :param int i_target: index of target
+    :param int window: time window of search
+    :param int q: Percentile from 0 to 100, defaults to 50
+    :param int nb_min: Minimal number of observations needed to provide statistics, defaults to 1
+    :return float: the computed statistic
+    """
+    last_t, segment_follow = t[i], segment[i]
+    segment_target = segment[i_target]
+    percent_target = empty(window, dtype=percents.dtype)
+    j = 0
+    while abs(last_t - t[i]) < window and i != -1 and segment_follow == segment[i]:
+        # Iterate on primary & secondary
+        for index, percent in zip(indexs[i], percents[i]):
+            if index != -1 and segment[index] == segment_target:
+                percent_target[j] = percent
+                j += 1
+        i = follow_obs[i]
+    if j < nb_min:
+        return nan
+    return percentile(percent_target[:j], q)
+
+
+@njit(cache=True)
+def get_percentile_around_event(
+    i,
+    i1,
+    i2,
+    ind,
+    pct,
+    follow_obs,
+    t,
+    segment,
+    window=10,
+    follow_parent=False,
+    q=50,
+    nb_min=1,
+):
+    """Get statistics around an event
+
+    :param array[int] i: Indices of targets
+    :param array[int] i1: Indices of primary origins
+    :param array[int] i2: Indices of secondary origins
+    :param array ind: indices from coherence
+    :param array pct: percent from coherence
+    :param array[int] follow_obs: gives the index of the following observation
+    :param array t: time for each observation
+    :param array segment: segment for each observation
+    :param int window: time window of search, defaults to 10
+    :param bool follow_parent: Follow parent instead of child, defaults to False
+    :param int q: Percentile from 0 to 100, defaults to 50
+    :param int nb_min: Minimal number of observations needed to provide statistics, defaults to 1
+    :return (array,array): statistic for each event
+    """
+    stat1 = empty(i.size, dtype=nb_types.float32)
+    stat2 = empty(i.size, dtype=nb_types.float32)
+    # iterate on events
+    for j, (i_, i1_, i2_) in enumerate(zip(i, i1, i2)):
+        if follow_parent:
+            # We follow parent
+            stat1[j] = get_percentile_on_following_obs(
+                i_, ind, pct, follow_obs, t, segment, i1_, window, q, nb_min
+            )
+            stat2[j] = get_percentile_on_following_obs(
+                i_, ind, pct, follow_obs, t, segment, i2_, window, q, nb_min
+            )
+        else:
+            # We follow child
+            stat1[j] = get_percentile_on_following_obs(
+                i1_, ind, pct, follow_obs, t, segment, i_, window, q, nb_min
+            )
+            stat2[j] = get_percentile_on_following_obs(
+                i2_, ind, pct, follow_obs, t, segment, i_, window, q, nb_min
+            )
+    return stat1, stat2
+
+
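A self-contained sketch (plain numpy, illustrative values) of the counting-sort placement that
`get_next_index` below relies on: `bincount` gives the size of each group, the shifted `cumsum`
gives the first slot of each group, and every observation is then dropped into the next free slot
of its group, so that writing observations at the resulting positions gathers each group together.

    import numpy as np

    gr = np.array([2, 0, 2, 1, 0])                 # group id of each observation
    nb_obs_gr = np.bincount(gr)                    # observations per group: [2, 1, 2]
    i_gr = nb_obs_gr.cumsum() - nb_obs_gr          # first slot of each group: [0, 2, 3]
    new_index = np.empty(gr.shape, dtype=np.uint32)
    for i, g in enumerate(gr):
        new_index[i] = i_gr[g]                     # next free slot of group g
        i_gr[g] += 1
    # new_index -> [3, 0, 4, 2, 1]; writing obs i at position new_index[i] groups them by id
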
@njit(cache=True) def get_next_index(gr): - """Return for each obs index the new position to join all group""" + """Return for each obs index the new position to join all groups""" nb_obs_gr = bincount(gr) i_gr = nb_obs_gr.cumsum() - nb_obs_gr new_index = empty(gr.shape, dtype=uint32) @@ -1493,3 +2271,98 @@ def new_numbering(segs, start=0): @njit(cache=True) def ptp(values): return values.max() - values.min() + + +@njit(cache=True) +def generate_mask_from_ids(id_networks, nb, istart, iend, i0): + """From list of id, we generate a mask + + :param array id_networks: list of ids + :param int nb: size of mask + :param array istart: first index for each id from :py:meth:`~py_eddy_tracker.generic.build_index` + :param array iend: last index for each id from :py:meth:`~py_eddy_tracker.generic.build_index` + :param int i0: ref index from :py:meth:`~py_eddy_tracker.generic.build_index` + :return array: return a mask + """ + m = zeros(nb, dtype="bool") + for i in id_networks: + for j in range(istart[i - i0], iend[i - i0]): + m[j] = True + return m + + +@njit(cache=True) +def same_position(x0, y0, t0, x1, y1, t1, i00, i01, i0, i1): + """Return index of track/segment found in other dataset + + :param array x0: + :param array y0: + :param array t0: + :param array x1: + :param array y1: + :param array t1: + :param array i00: First index of track/segment/network in dataset0 + :param array i01: First index of track/segment/network in dataset1 + :param List(array) i0: list of array which contain index to order dataset0 + :param List(array) i1: list of array which contain index to order dataset1 + :return array: index of dataset1 which match with dataset0, -1 => no match + """ + nb0, nb1 = i00.size, i01.size + i_target = -ones(nb0, dtype="i4") + # To avoid to compare multiple time, if already match + used1 = zeros(nb1, dtype="bool") + for j0 in range(nb0): + for j1 in range(nb1): + if used1[j1]: + continue + test = True + for i0_, i1_ in zip(i0[j0], i1[j1]): + i0_ += i00[j0] + i1_ += i01[j1] + if t0[i0_] != t1[i1_] or x0[i0_] != x1[i1_] or y0[i0_] != y1[i1_]: + test = False + break + if test: + i_target[j0] = j1 + used1[j1] = True + break + return i_target + + +@njit(cache=True) +def mask_follow_obs(m, next_obs, time, indexs, dt=3): + """Generate a mask to select close obs in time from index + + :param array m: mask to fill with True + :param array next_obs: index of the next observation + :param array time: time of each obs + :param array indexs: index to start follow + :param int dt: delta of time max from index, defaults to 3 + """ + for i in indexs: + t0 = time[i] + m[i] = True + i_next = next_obs[i] + dt_ = abs(time[i_next] - t0) + while dt_ < dt and i_next != -1: + m[i_next] = True + i_next = next_obs[i_next] + dt_ = abs(time[i_next] - t0) + + +@njit(cache=True) +def get_period_with_index(t, i0, i1): + """Return peek to peek cover by each slice define by i0 and i1 + + :param array t: array which contain values to estimate spread + :param array i0: index which determine start of slice + :param array i1: index which determine end of slice + :return array: Peek to peek of t + """ + periods = np.empty(i0.size, t.dtype) + for i in range(i0.size): + if i1[i] == i0[i]: + periods[i] = 0 + continue + periods[i] = t[i0[i] : i1[i]].ptp() + return periods diff --git a/src/py_eddy_tracker/observations/observation.py b/src/py_eddy_tracker/observations/observation.py index 73df5734..b39f7f83 100644 --- a/src/py_eddy_tracker/observations/observation.py +++ b/src/py_eddy_tracker/observations/observation.py @@ -2,19 
+2,18 @@ """ Base class to manage eddy observation """ -import logging from datetime import datetime from io import BufferedReader, BytesIO +import logging from tarfile import ExFileObject from tokenize import TokenError -import zarr +from Polygon import Polygon from matplotlib.cm import get_cmap -from matplotlib.collections import PolyCollection +from matplotlib.collections import LineCollection, PolyCollection from matplotlib.colors import Normalize from netCDF4 import Dataset -from numba import njit -from numba import types as numba_types +from numba import njit, types as numba_types from numpy import ( absolute, arange, @@ -23,6 +22,7 @@ ceil, concatenate, cos, + datetime64, digitize, empty, errstate, @@ -43,9 +43,10 @@ where, zeros, ) +import packaging.version from pint import UnitRegistry from pint.errors import UndefinedUnitError -from Polygon import Polygon +import zarr from .. import VAR_DESCR, VAR_DESCR_inv, __version__ from ..generic import ( @@ -57,12 +58,14 @@ hist_numba, local_to_coordinates, reverse_index, + window_index, wrap_longitude, ) from ..poly import ( bbox_intersection, close_center, convexs, + create_meshed_particles, create_vertice, get_pixel_in_regular, insidepoly, @@ -73,6 +76,31 @@ logger = logging.getLogger("pet") +# keep only major and minor version number +_software_version_reduced = packaging.version.Version( + "{v.major}.{v.minor}".format(v=packaging.version.parse(__version__)) +) +_display_check_warning = True + + +def _check_versions(version): + """Check if version of py_eddy_tracker used to create the file is compatible with software version + + if not, warn user with both versions + + :param version: string version of software used to create the file. If None, version was not provided + :type version: str, None + """ + if not _display_check_warning: + return + file_version = packaging.version.parse(version) if version is not None else None + if file_version is None or file_version < _software_version_reduced: + logger.warning( + "File was created with py-eddy-tracker version '%s' but software version is '%s'", + file_version, + _software_version_reduced, + ) + @njit(cache=True, fastmath=True) def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5): @@ -87,7 +115,7 @@ def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5 # Focal f_right = lon0 f_left = f_right - (c - minor) - # Ellips center + # Ellipse center x_c = (f_left + f_right) * 0.5 nb_0, nb_1 = lat0.shape[0], lat1.shape[0] @@ -105,7 +133,7 @@ def shifted_ellipsoid_degrees_mask2(lon0, lat0, lon1, lat1, minor=1.5, major=1.5 if dx > major[j]: m[j, i] = False continue - d_normalize = dx ** 2 / major[j] ** 2 + dy ** 2 / minor ** 2 + d_normalize = dx**2 / major[j] ** 2 + dy**2 / minor**2 m[j, i] = d_normalize < 1.0 return m @@ -224,7 +252,7 @@ def __eq__(self, other): return array_equal(self.obs, other.obs) def get_color(self, i): - """Return colors like a cyclic list""" + """Return colors as a cyclic list""" return self.COLORS[i % self.NB_COLORS] @property @@ -240,7 +268,7 @@ def get_infos(self): bins_lat=(-90, -60, -15, 15, 60, 90), bins_amplitude=array((0, 1, 2, 3, 4, 5, 10, 500)), bins_radius=array((0, 15, 30, 45, 60, 75, 100, 200, 2000)), - nb_obs=self.observations.shape[0], + nb_obs=len(self), ) t0, t1 = self.period infos["t0"], infos["t1"] = t0, t1 @@ -260,7 +288,7 @@ def hist(self, varname, x, bins, percent=False, mean=False, nb=False): :param str,array varname: variable to use to compute stat :param str,array x: variable to use to know in 
which bins :param array bins: - :param bool percent: normalize by sum of all bins + :param bool percent: normalized by sum of all bins :param bool mean: compute mean by bins :param bool nb: only count by bins :return: value by bins @@ -279,15 +307,19 @@ def hist(self, varname, x, bins, percent=False, mean=False, nb=False): @staticmethod def box_display(value): - """Return value evenly spaced with few numbers""" + """Return values evenly spaced with few numbers""" return "".join([f"{v_:10.2f}" for v_ in value]) + @property + def fields(self): + return list(self.obs.dtype.names) + def field_table(self): """ Produce description table of the fields available in this object """ rows = [("Name (Unit)", "Long name", "Scale factor", "Offset")] - names = list(self.obs.dtype.names) + names = self.fields names.sort() for field in names: infos = VAR_DESCR[field] @@ -313,7 +345,7 @@ def __repr__(self): bins_lat = (-90, -60, -15, 15, 60, 90) bins_amplitude = array((0, 1, 2, 3, 4, 5, 10, 500)) bins_radius = array((0, 15, 30, 45, 60, 75, 100, 200, 2000)) - nb_obs = self.observations.shape[0] + nb_obs = len(self) return f""" | {nb_obs} observations from {t0} to {t1} ({period} days, ~{nb_obs / period:.0f} obs/day) | Speed area : {self.speed_area.sum() / period / 1e12:.2f} Mkm²/day @@ -388,9 +420,9 @@ def remove_fields(self, *fields): """ Copy with fields listed remove """ - nb_obs = self.obs.shape[0] + nb_obs = len(self) fields = set(fields) - only_variables = set(self.obs.dtype.names) - fields + only_variables = set(self.fields) - fields track_extra_variables = set(self.track_extra_variables) - fields array_variables = set(self.array_variables) - fields new = self.__class__( @@ -402,7 +434,7 @@ def remove_fields(self, *fields): raw_data=self.raw_data, ) new.sign_type = self.sign_type - for name in new.obs.dtype.names: + for name in new.fields: logger.debug("Copy of field %s ...", name) new.obs[name] = self.obs[name] return new @@ -411,7 +443,7 @@ def add_fields(self, fields=list(), array_fields=list()): """ Add a new field. """ - nb_obs = self.obs.shape[0] + nb_obs = len(self) new = self.__class__( size=nb_obs, track_extra_variables=list( @@ -419,13 +451,11 @@ def add_fields(self, fields=list(), array_fields=list()): ), track_array_variables=self.track_array_variables, array_variables=list(concatenate((self.array_variables, array_fields))), - only_variables=list( - concatenate((self.obs.dtype.names, fields, array_fields)) - ), + only_variables=list(concatenate((self.fields, fields, array_fields))), raw_data=self.raw_data, ) new.sign_type = self.sign_type - for name in self.obs.dtype.names: + for name in self.fields: logger.debug("Copy of field %s ...", name) new.obs[name] = self.obs[name] return new @@ -437,14 +467,14 @@ def add_rotation_type(self): def circle_contour(self, only_virtual=False, factor=1): """ - Set contours as a circles from radius and center data. + Set contours as circles from radius and center data. .. 
minigallery:: py_eddy_tracker.EddiesObservations.circle_contour """ angle = radians(linspace(0, 360, self.track_array_variables)) x_norm, y_norm = cos(angle), sin(angle) - radius_s = "contour_lon_s" in self.obs.dtype.names - radius_e = "contour_lon_e" in self.obs.dtype.names + radius_s = "contour_lon_s" in self.fields + radius_e = "contour_lon_e" in self.fields for i, obs in enumerate(self): if only_virtual and not obs["virtual"]: continue @@ -519,9 +549,9 @@ def merge(self, other): nb_obs_self = len(self) nb_obs = nb_obs_self + len(other) eddies = self.new_like(self, nb_obs) - other_keys = other.obs.dtype.fields.keys() - self_keys = self.obs.dtype.fields.keys() - for key in eddies.obs.dtype.fields.keys(): + other_keys = other.fields + self_keys = self.fields + for key in eddies.fields: eddies.obs[key][:nb_obs_self] = self.obs[key][:] if key in other_keys: eddies.obs[key][nb_obs_self:] = other.obs[key][:] @@ -546,56 +576,76 @@ def __iter__(self): for obs in self.obs: yield obs - def iter_on(self, xname, bins=None): + def iter_on(self, xname, window=None, bins=None): """ Yield observation group for each bin. :param str,array xname: - :param array bins: bounds of each bin , - :return: index or mask, bound low, bound up - """ - x = self[xname] if isinstance(xname, str) else xname - d = x[1:] - x[:-1] - if bins is None: - bins = arange(x.min(), x.max() + 2) - elif not isinstance(bins, ndarray): - bins = array(bins) - nb_bins = len(bins) - 1 - - # Not monotonous - if (d < 0).any(): - # If bins cover a small part of value - test, translate, x = iter_mode_reduce(x, bins) - # convert value in bins number - i = numba_digitize(x, bins) - 1 - # Order by bins - i_sort = i.argsort() - # If in reduce mode we will translate i_sort in full array index - i_sort_ = translate[i_sort] if test else i_sort - # Bound for each bins in sorting view - i0, i1, _ = build_index(i[i_sort]) - m = ~(i0 == i1) - i0, i1 = i0[m], i1[m] - for i0_, i1_ in zip(i0, i1): - i_bins = i[i_sort[i0_]] - if i_bins == -1 or i_bins == nb_bins: - continue - yield i_sort_[i0_:i1_], bins[i_bins], bins[i_bins + 1] + :param float,None window: if defined we use a moving window with value like half window + :param array bins: bounds of each bin + :yield array,float,float: index in self, lower bound, upper bound + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.iter_on + """ + x = self.parse_varname(xname) + if window is not None: + x0 = arange(x.min(), x.max()) if bins is None else array(bins) + i_ordered, first_index, last_index = window_index(x, x0, window) + for x_, i0, i1 in zip(x0, first_index, last_index): + yield i_ordered[i0:i1], x_ - window, x_ + window else: - i = numba_digitize(x, bins) - 1 - i0, i1, _ = build_index(i) - m = ~(i0 == i1) - i0, i1 = i0[m], i1[m] - for i0_, i1_ in zip(i0, i1): - i_bins = i[i0_] - yield slice(i0_, i1_), bins[i_bins], bins[i_bins + 1] + d = x[1:] - x[:-1] + if bins is None: + bins = arange(x.min(), x.max() + 2) + elif not isinstance(bins, ndarray): + bins = array(bins) + nb_bins = len(bins) - 1 + + # Not monotonous + if (d < 0).any(): + # If bins cover a small part of value + test, translate, x = iter_mode_reduce(x, bins) + # convert value in bins number + i = numba_digitize(x, bins) - 1 + # Order by bins + i_sort = i.argsort() + # If in reduced mode we will translate i_sort in full array index + i_sort_ = translate[i_sort] if test else i_sort + # Bound for each bins in sorting view + i0, i1, _ = build_index(i[i_sort]) + m = ~(i0 == i1) + i0, i1 = i0[m], i1[m] + for i0_, i1_ in zip(i0, i1): + i_bins = i[i_sort[i0_]] + if i_bins == -1 or i_bins == nb_bins: + continue + yield i_sort_[i0_:i1_], bins[i_bins], bins[i_bins + 1] + else: + i = numba_digitize(x, bins) - 1 + i0, i1, _ = build_index(i) + m = ~(i0 == i1) + i0, i1 = i0[m], i1[m] + for i0_, i1_ in zip(i0, i1): + i_bins = i[i0_] + yield slice(i0_, i1_), bins[i_bins], bins[i_bins + 1] - def align_on(self, other, var_name="time", **kwargs): + def align_on(self, other, var_name="time", all_ref=False, **kwargs): """ - Align the time indexes of two datasets. + Align the variable indices of two datasets. + + :param other: other compare with self + :param str,tuple var_name: variable name to align or two array, defaults to "time" + :param bool all_ref: yield all value of ref, if false only common value, defaults to False + :yield array,array,float,float: index in self, index in other, lower bound, upper bound + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.align_on """ - iter_self = self.iter_on(var_name, **kwargs) - iter_other = other.iter_on(var_name, **kwargs) + if isinstance(var_name, str): + iter_self = self.iter_on(var_name, **kwargs) + iter_other = other.iter_on(var_name, **kwargs) + else: + iter_self = self.iter_on(var_name[0], **kwargs) + iter_other = other.iter_on(var_name[1], **kwargs) indexs_other, b0_other, b1_other = iter_other.__next__() for indexs_self, b0_self, b1_self in iter_self: if b0_self > b0_other: @@ -605,15 +655,19 @@ def align_on(self, other, var_name="time", **kwargs): except StopIteration: break if b0_self < b0_other: + if all_ref: + yield indexs_self, empty( + 0, dtype=indexs_self.dtype + ), b0_self, b1_self continue yield indexs_self, indexs_other, b0_self, b1_self def insert_observations(self, other, index): - """Insert other obs in self at the index.""" + """Insert other obs in self at the given index.""" if not self.coherence(other): raise Exception("Observations with no coherence") - insert_size = len(other.obs) - self_size = len(self.obs) + insert_size = len(other) + self_size = len(self) new_size = self_size + insert_size if self_size == 0: self.observations = other.obs @@ -643,7 +697,7 @@ def distance(self, other): def __copy__(self): eddies = self.new_like(self, len(self)) - for k in self.obs.dtype.names: + for k in self.fields: eddies[k][:] = self[k][:] eddies.sign_type = self.sign_type return eddies @@ -682,10 +736,13 @@ def zarr_dimension(filename): h = filename else: h = zarr.open(filename) + dims = list() for varname in h: - dims.extend(list(getattr(h, varname).shape)) - return set(dims) + shape = getattr(h, varname).shape + if len(shape) > len(dims): + dims = shape + return dims @classmethod def load_file(cls, filename, **kwargs): @@ -719,6 +776,7 @@ def load_file(cls, filename, **kwargs): zarr_file = filename_.endswith(end) else: zarr_file = False + logger.info(f"loading file '{filename_}'") if zarr_file: return cls.load_from_zarr(filename, **kwargs) else: @@ -738,30 +796,29 @@ def load_from_zarr( """Load data from zarr. 
:param str,store filename: path or store to load data - :param bool raw_data: If true load data without apply scale_factor and add_offset - :param None,list(str) remove_vars: List of variable name which will be not loaded + :param bool raw_data: If true load data without scale_factor and add_offset + :param None,list(str) remove_vars: List of variable name that will be not loaded :param None,list(str) include_vars: If defined only this variable will be loaded - :param None,dict indexs: Indexs to laad only a slice of data + :param None,dict indexs: Indexes to load only a slice of data :param int buffer_size: Size of buffer used to load zarr data :param class_kwargs: argument to set up observations class :return: Obsevations selected :return type: class """ # FIXME - array_dim = -1 if isinstance(filename, zarr.storage.MutableMapping): h_zarr = filename else: if not isinstance(filename, str): filename = filename.astype(str) h_zarr = zarr.open(filename) + + _check_versions(h_zarr.attrs.get("framework_version", None)) var_list = cls.build_var_list(list(h_zarr.keys()), remove_vars, include_vars) nb_obs = getattr(h_zarr, var_list[0]).shape[0] - dims = list(cls.zarr_dimension(filename)) - if len(dims) == 2 and nb_obs in dims: - # FIXME must be investigate, in zarr no dimensions name (or could be add in attr) - array_dim = dims[1] if nb_obs == dims[0] else dims[0] + track_array_variables = h_zarr.attrs["track_array_variables"] + if indexs is not None and "obs" in indexs: sl = indexs["obs"] sl = slice(sl.start, min(sl.stop, nb_obs)) @@ -775,28 +832,33 @@ def load_from_zarr( logger.debug("%d observations will be load", nb_obs) kwargs = dict() - if array_dim in dims: - kwargs["track_array_variables"] = array_dim - kwargs["array_variables"] = list() - for variable in var_list: - if array_dim in h_zarr[variable].shape: - var_inv = VAR_DESCR_inv[variable] - kwargs["array_variables"].append(var_inv) - array_variables = kwargs.get("array_variables", list()) - kwargs["track_extra_variables"] = [] + kwargs["track_array_variables"] = h_zarr.attrs.get( + "track_array_variables", track_array_variables + ) + + array_variables = list() + for variable in var_list: + if len(h_zarr[variable].shape) > 1: + var_inv = VAR_DESCR_inv[variable] + array_variables.append(var_inv) + kwargs["array_variables"] = array_variables + track_extra_variables = [] + for variable in var_list: var_inv = VAR_DESCR_inv[variable] if var_inv not in cls.ELEMENTS and var_inv not in array_variables: - kwargs["track_extra_variables"].append(var_inv) + track_extra_variables.append(var_inv) + kwargs["track_extra_variables"] = track_extra_variables kwargs["raw_data"] = raw_data kwargs["only_variables"] = ( None if include_vars is None else [VAR_DESCR_inv[i] for i in include_vars] ) kwargs.update(class_kwargs) eddies = cls(size=nb_obs, **kwargs) - for variable in var_list: + + for i_var, variable in enumerate(var_list): var_inv = VAR_DESCR_inv[variable] - logger.debug("%s will be loaded", variable) + logger.debug("%s will be loaded (%d/%d)", variable, i_var, len(var_list)) # find unit factor input_unit = h_zarr[variable].attrs.get("unit", None) if input_unit is None: @@ -852,6 +914,7 @@ def copy_data_to_zarr( i_start = 0 if i_stop is None: i_stop = handler_zarr.shape[0] + for i in range(i_start, i_stop, buffer_size): sl_in = slice(i, min(i + buffer_size, i_stop)) data = handler_zarr[sl_in] @@ -862,6 +925,7 @@ def copy_data_to_zarr( data -= add_offset if scale_factor is not None: data /= scale_factor + sl_out = slice(i - i_start, i - i_start + 
buffer_size) handler_eddies[sl_out] = data @@ -881,7 +945,7 @@ def load_from_netcdf( :param bool raw_data: If true load data without apply scale_factor and add_offset :param None,list(str) remove_vars: List of variable name which will be not loaded :param None,list(str) include_vars: If defined only this variable will be loaded - :param None,dict indexs: Indexs to laad only a slice of data + :param None,dict indexs: Indexes to load only a slice of data :param class_kwargs: argument to set up observations class :return: Obsevations selected :return type: class @@ -895,6 +959,8 @@ def load_from_netcdf( else: args, kwargs = (filename,), dict() with Dataset(*args, **kwargs) as h_nc: + _check_versions(getattr(h_nc, "framework_version", None)) + var_list = cls.build_var_list( list(h_nc.variables.keys()), remove_vars, include_vars ) @@ -1009,6 +1075,17 @@ def compare_units(input_unit, output_unit, name): input_unit, output_unit, ) + return factor + else: + return 1 + + @classmethod + def from_array(cls, arrays, **kwargs): + nb = arrays["time"].size + eddies = cls(size=nb, **kwargs) + for k, v in arrays.items(): + eddies.obs[k] = v + return eddies @classmethod def from_zarr(cls, handler): @@ -1026,6 +1103,7 @@ def from_zarr(cls, handler): eddies.obs[variable] = handler.variables[variable][:] else: eddies.obs[VAR_DESCR_inv[variable]] = handler.variables[variable][:] + eddies.sign_type = handler.rotation_type return eddies @classmethod @@ -1044,13 +1122,14 @@ def from_netcdf(cls, handler): eddies.obs[variable] = handler.variables[variable][:] else: eddies.obs[VAR_DESCR_inv[variable]] = handler.variables[variable][:] + eddies.sign_type = handler.rotation_type return eddies def propagate( self, previous_obs, current_obs, obs_to_extend, dead_track, nb_next, model ): """ - Filled virtual obs (C). + Fill virtual obs (C). 
:param previous_obs: previous obs from current (A) :param current_obs: previous obs from virtual (B) @@ -1131,7 +1210,7 @@ def match( :param bool intern: if True, speed contour is used (default = effective contour) :param float cmin: 0 < cmin < 1, return only couples with score >= cmin :param dict kwargs: look at :py:meth:`vertice_overlap` - :return: return the indexes of the eddies in self coupled with eddies in + :return: return the indices of the eddies in self coupled with eddies in other and their associated score :rtype: (array(int), array(int), array(float)) @@ -1162,7 +1241,7 @@ def re_reference_index(index, ref): :param array,int index: local index to re ref :param slice,array ref: reference could be a slice in this case we juste add start to index - or could be indexs and in this case we need to translate + or could be indices and in this case we need to translate """ if isinstance(ref, slice): return index + ref.start @@ -1244,7 +1323,7 @@ def shifted_ellipsoid_degrees_mask(self, other, minor=1.5, major=1.5): ) def fixed_ellipsoid_mask( - self, other, minor=50, major=100, only_east=False, shifted_ellips=False + self, other, minor=50, major=100, only_east=False, shifted_ellipse=False ): dist = self.distance(other).T accepted = dist < minor @@ -1263,18 +1342,18 @@ def fixed_ellipsoid_mask( if isinstance(minor, ndarray): minor = minor[index_self] # focal distance - f_degree = ((major ** 2 - minor ** 2) ** 0.5) / ( + f_degree = ((major**2 - minor**2) ** 0.5) / ( 111.2 * cos(radians(self.lat[index_self])) ) lon_self = self.lon[index_self] - if shifted_ellips: - x_center_ellips = lon_self - (major - minor) / 2 + if shifted_ellipse: + x_center_ellipse = lon_self - (major - minor) / 2 else: - x_center_ellips = lon_self + x_center_ellipse = lon_self - lon_left_f = x_center_ellips - f_degree - lon_right_f = x_center_ellips + f_degree + lon_left_f = x_center_ellipse - f_degree + lon_right_f = x_center_ellipse + f_degree dist_left_f = distance( lon_left_f, @@ -1298,7 +1377,7 @@ def fixed_ellipsoid_mask( return accepted.T @staticmethod - def basic_formula_ellips_major_axis( + def basic_formula_ellipse_major_axis( lats, cmin=1.5, cmax=10.0, c0=1.5, lat1=13.5, lat2=5.0, degrees=False ): """Give major axis in km with a given latitude""" @@ -1324,20 +1403,27 @@ def solve_conflict(cost): @staticmethod def solve_simultaneous(cost): - """Write something (TODO)""" + """Deduce link from cost matrix. 
+ + :param array(float) cost: Cost for each available link + :return: return a boolean mask array, True for each valid couple + :rtype: array(bool) + """ mask = ~cost.mask - # Count number of link by self obs and other obs + if mask.size == 0: + return mask + # Count number of links by self obs and other obs self_links, other_links = sum_row_column(mask) max_links = max(self_links.max(), other_links.max()) if max_links > 5: logger.warning("One observation have %d links", max_links) - # If some obs have multiple link, we keep only one link by eddy + # If some obs have multiple links, we keep only one link by eddy eddies_separation = 1 < self_links eddies_merge = 1 < other_links test = eddies_separation.any() or eddies_merge.any() if test: - # We extract matrix which contains concflict + # We extract matrix that contains conflict obs_linking_to_self = mask[eddies_separation].any(axis=0) obs_linking_to_other = mask[:, eddies_merge].any(axis=1) i_self_keep = where(obs_linking_to_other + eddies_separation)[0] @@ -1360,13 +1446,13 @@ def solve_simultaneous(cost): security_increment = 0 while False in cost_reduce.mask: if security_increment > max_iteration: - # Maybe check if the size decrease if not rise an exception + # Maybe check if the size decreases if not rise an exception # x_i, y_i = where(-cost_reduce.mask) raise Exception("To many iteration: %d" % security_increment) security_increment += 1 i_min_value = cost_reduce.argmin() i, j = floor(i_min_value / shape[1]).astype(int), i_min_value % shape[1] - # Set to False all link + # Set to False all links mask[i_self_keep[i]] = False mask[:, i_other_keep[j]] = False cost_reduce.mask[i] = True @@ -1380,19 +1466,19 @@ def solve_simultaneous(cost): @staticmethod def solve_first(cost, multiple_link=False): mask = ~cost.mask - # Count number of link by self obs and other obs + # Count number of links by self obs and other obs self_links = mask.sum(axis=1) other_links = mask.sum(axis=0) max_links = max(self_links.max(), other_links.max()) if max_links > 5: logger.warning("One observation have %d links", max_links) - # If some obs have multiple link, we keep only one link by eddy + # If some obs have multiple links, we keep only one link by eddy eddies_separation = 1 < self_links eddies_merge = 1 < other_links test = eddies_separation.any() or eddies_merge.any() if test: - # We extract matrix which contains concflict + # We extract matrix that contains conflict obs_linking_to_self = mask[eddies_separation].any(axis=0) obs_linking_to_other = mask[:, eddies_merge].any(axis=1) i_self_keep = where(obs_linking_to_other + eddies_separation)[0] @@ -1465,8 +1551,7 @@ def to_zarr(self, handler, **kwargs): handler.attrs["track_array_variables"] = self.track_array_variables handler.attrs["array_variables"] = ",".join(self.array_variables) # Iter on variables to create: - fields = [field[0] for field in self.observations.dtype.descr] - for ori_name in fields: + for ori_name in self.fields: # Patch for a transition name = ori_name # @@ -1511,12 +1596,9 @@ def to_netcdf(self, handler, **kwargs): handler.track_array_variables = self.track_array_variables handler.array_variables = ",".join(self.array_variables) # Iter on variables to create: - fields = [field[0] for field in self.observations.dtype.descr] - fields_ = array( - [VAR_DESCR[field[0]]["nc_name"] for field in self.observations.dtype.descr] - ) + fields_ = array([VAR_DESCR[field]["nc_name"] for field in self.fields]) i = fields_.argsort() - for ori_name in array(fields)[i]: + for ori_name in 
array(self.fields)[i]: # Patch for a transition name = ori_name # @@ -1583,6 +1665,33 @@ def create_variable( except ValueError: logger.warning("Data is empty") + @staticmethod + def get_filters_zarr(name): + """Get filters to store in zarr for known variable + + :param str name: private variable name + :return list: filters list + """ + content = VAR_DESCR.get(name) + filters = list() + store_dtype = content["output_type"] + scale_factor, add_offset = content.get("scale_factor", None), content.get( + "add_offset", None + ) + if scale_factor is not None or add_offset is not None: + if add_offset is None: + add_offset = 0 + filters.append( + zarr.FixedScaleOffset( + offset=add_offset, + scale=1 / scale_factor, + dtype=content["nc_type"], + astype=store_dtype, + ) + ) + filters.extend(content.get("filters", [])) + return filters + def create_variable_zarr( self, handler_zarr, @@ -1672,7 +1781,8 @@ def write_file( handler = zarr.open(filename, "w") self.to_zarr(handler, **kwargs) else: - with Dataset(filename, "w", format="NETCDF4") as handler: + nc_format = kwargs.pop("format", "NETCDF4") + with Dataset(filename, "w", format=nc_format) as handler: self.to_netcdf(handler, **kwargs) @property @@ -1696,7 +1806,7 @@ def set_global_attr_netcdf(self, h_nc): def mask_from_polygons(self, polygons): """ - Return mask for all observation in one of polygons list + Return mask for all observations in one of polygons list :param list((array,array)) polygons: list of x/y array which be used to identify observations """ @@ -1720,7 +1830,7 @@ def extract_with_area(self, area, **kwargs): :param dict area: 4 coordinates in a dictionary to specify bounding box (lower left corner and upper right corner) :param dict kwargs: look at :py:meth:`extract_with_mask` - :return: Return all eddy tracks which are in bounds + :return: Return all eddy trajetories in bounds :rtype: EddiesObservations .. 
code-block:: python @@ -1737,11 +1847,16 @@ def extract_with_area(self, area, **kwargs): mask *= (lon > lon0) * (lon < area["urcrnrlon"]) return self.extract_with_mask(mask, **kwargs) + @property + def time_datetime64(self): + dt = (datetime64("1970-01-01") - datetime64("1950-01-01")).astype("i8") + return (self.time - dt).astype("datetime64[D]") + def time_sub_sample(self, t0, time_step): """ Time sub sampling - :param int,float t0: reference time which will be keep + :param int,float t0: reference time that will be keep :param int,float time_step: keep every observation spaced by time_step """ mask = (self.time - t0) % time_step == 0 @@ -1762,10 +1877,9 @@ def extract_with_mask(self, mask): if nb_obs == 0: logger.warning("Empty dataset will be created") else: - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var][mask] + new.obs[field] = self.obs[field][mask] return new def scatter(self, ax, name=None, ref=None, factor=1, **kwargs): @@ -1775,7 +1889,7 @@ def scatter(self, ax, name=None, ref=None, factor=1, **kwargs): :param matplotlib.axes.Axes ax: matplotlib axe used to draw :param str,array,None name: variable used to fill the contour, if None all elements have the same color - :param float,None ref: if define use like west bound + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param float factor: multiply value by :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.scatter` :return: scatter mappable @@ -1807,7 +1921,7 @@ def filled( """ :param matplotlib.axes.Axes ax: matplotlib axe used to draw :param str,array,None varname: variable used to fill the contours, or an array of same size than obs - :param float,None ref: if define use like west bound? + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param bool intern: if True draw speed contours instead of effective contours :param str cmap: matplotlib colormap name :param int,None lut: Number of colors in the colormap @@ -1969,11 +2083,43 @@ def format_label(self, label): nb_obs=len(self), ) + def display_color(self, ax, field, ref=None, intern=False, **kwargs): + """Plot colored contour of eddies + + :param matplotlib.axes.Axes ax: matplotlib axe used to draw + :param str,array field: color field + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary + :param bool intern: if True, draw the speed contour + :param dict kwargs: look at :py:meth:`matplotlib.collections.LineCollection` + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.display_color + """ + xname, yname = self.intern(intern) + x, y = self[xname], self[yname] + + if ref is not None: + # TODO : maybe buggy with global display + shape_out = x.shape + x, y = wrap_longitude(x.reshape(-1), y.reshape(-1), ref) + x, y = x.reshape(shape_out), y.reshape(shape_out) + + c = self.parse_varname(field) + cmap = get_cmap(kwargs.pop("cmap", "Spectral_r")) + cmin, cmax = kwargs.pop("vmin", c.min()), kwargs.pop("vmax", c.max()) + colors = cmap((c - cmin) / (cmax - cmin)) + lines = LineCollection( + [create_vertice(i, j) for i, j in zip(x, y)], colors=colors, **kwargs + ) + ax.add_collection(lines) + lines.cmap = cmap + lines.norm = Normalize(vmin=cmin, vmax=cmax) + return lines + def display(self, ax, ref=None, extern_only=False, intern_only=False, **kwargs): """Plot the speed and effective (dashed) contour of the eddies :param matplotlib.axes.Axes ax: matplotlib axe used to draw - :param float,None ref: western longitude reference used + :param float,None ref: if defined, all coordinates are wrapped with ref as western boundary :param bool extern_only: if True, draw only the effective contour :param bool intern_only: if True, draw only the speed contour :param dict kwargs: look at :py:meth:`matplotlib.axes.Axes.plot` @@ -2041,7 +2187,7 @@ def is_convex(self, intern=False): def contains(self, x, y, intern=False): """ - Return index of contour which contain (x,y) + Return index of contour containing (x,y) :param array x: longitude :param array y: latitude @@ -2050,7 +2196,12 @@ def contains(self, x, y, intern=False): :rtype: array[int32] """ xname, yname = self.intern(intern) - return poly_indexs(x, y, self[xname], self[yname]) + m = ~(isnan(x) + isnan(y)) + i = -ones(x.shape, dtype="i4") + + if x.size != 0 and m.any(): + i[m] = poly_indexs(x[m], y[m], self[xname], self[yname]) + return i def inside(self, x, y, intern=False): """ @@ -2063,7 +2214,6 @@ def inside(self, x, y, intern=False): :rtype: array[bool] """ xname, yname = self.intern(intern) - # FIXME: wrapping return insidepoly(x, y, self[xname], self[yname]) def grid_count(self, bins, intern=False, center=False, filter=slice(None)): @@ -2106,7 +2256,7 @@ def grid_count(self, bins, intern=False, center=False, filter=slice(None)): x_ref = ((self.longitude[filter] - x0) % 360 + x0 - 180).reshape(-1, 1) x_contour, y_contour = self[x_name][filter], self[y_name][filter] grid_count_pixel_in( - grid, + grid.data, x_contour, y_contour, x_ref, @@ -2209,7 +2359,7 @@ def grid_stat(self, bins, varname, data=None): return regular_grid def interp_grid( - self, grid_object, varname, method="center", dtype=None, intern=None + self, grid_object, varname, i=None, method="center", dtype=None, intern=None ): """ Interpolate a grid on a center or contour with mean, min or max method @@ -2217,24 +2367,34 @@ def interp_grid( :param grid_object: Handler of grid to interp :type grid_object: py_eddy_tracker.dataset.grid.RegularGridDataset :param str varname: Name of variable to use + :param array[bool,int],None i: + Index or mask to subset observations, it could avoid to build a specific dataset. :param str method: 'center', 'mean', 'max', 'min', 'nearest' :param str dtype: if None we use var dtype :param bool intern: Use extern or intern contour + + .. 
minigallery:: py_eddy_tracker.EddiesObservations.interp_grid """ if method in ("center", "nearest"): - return grid_object.interp(varname, self.longitude, self.latitude, method) + x, y = self.longitude, self.latitude + if i is not None: + x, y = x[i], y[i] + return grid_object.interp(varname, x, y, method) elif method in ("min", "max", "mean", "count"): x0 = grid_object.x_bounds[0] x_name, y_name = self.intern(False if intern is None else intern) x_ref = ((self.longitude - x0) % 360 + x0 - 180).reshape(-1, 1) x, y = (self[x_name] - x_ref) % 360 + x_ref, self[y_name] + if i is not None: + x, y = x[i], y[i] grid = grid_object.grid(varname) - result = empty(self.shape, dtype=grid.dtype if dtype is None else dtype) + result = empty(x.shape[0], dtype=grid.dtype if dtype is None else dtype) min_method = method == "min" grid_stat( grid_object.x_c, grid_object.y_c, - -grid if min_method else grid, + -grid.data if min_method else grid.data, + grid.mask, x, y, result, @@ -2262,18 +2422,34 @@ def period(self): @property def nb_days(self): - """Return period days cover by dataset + """Return period in days covered by the dataset :return: Number of days :rtype: int """ return self.period[1] - self.period[0] + 1 + def create_particles(self, step, intern=True): + """Create particles inside contour (Default : speed contour). Avoid creating too large numpy arrays, only to be masked + + :param step: step for particles + :type step: float + :param bool intern: If true use speed contour instead of effective contour + :return: lon, lat and indices of particles + :rtype: tuple(np.array) + """ + + xname, yname = self.intern(intern) + return create_meshed_particles(self[xname], self[yname], step) + + def empty_dataset(self): + return self.new_like(self, 0) + @njit(cache=True) def grid_count_(grid, i, j): """ - Add one to each index + Add 1 to each index """ for i_, j_ in zip(i, j): grid[i_, j_] += 1 @@ -2296,7 +2472,7 @@ def grid_count_pixel_in( y_c, ): """ - Count how many time a pixel is used. + Count how many times a pixel is used. 
:param array grid: :param array x: x for all contour @@ -2385,13 +2561,14 @@ def grid_box_stat(x_c, y_c, grid, mask, x, y, value, circular=False, method=50): @njit(cache=True) -def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method="mean"): +def grid_stat(x_c, y_c, grid, mask, x, y, result, circular=False, method="mean"): """ Compute the mean or the max of the grid for each contour :param array_like x_c: the grid longitude coordinates :param array_like y_c: the grid latitude coordinates :param array_like grid: grid value + :param array[bool] mask: mask for invalid value :param array_like x: longitude of contours :param array_like y: latitude of contours :param array_like result: return values @@ -2417,9 +2594,12 @@ def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method="mean"): result[elt] = i.shape[0] elif mean_method: v_sum = 0 + nb_ = 0 for i_, j_ in zip(i, j): + if mask[i_, j_]: + continue v_sum += grid[i_, j_] - nb_ = i.shape[0] + nb_ += 1 # FIXME : how does it work on grid bound, if nb_ == 0: result[elt] = nan @@ -2428,7 +2608,9 @@ def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method="mean"): elif max_method: v_max = -1e40 for i_, j_ in zip(i, j): - v_max = max(v_max, grid[i_, j_]) + values = grid[i_, j_] + # FIXME must use mask + v_max = max(v_max, values) result[elt] = v_max diff --git a/src/py_eddy_tracker/observations/tracking.py b/src/py_eddy_tracker/observations/tracking.py index 5902462d..fa1c1f93 100644 --- a/src/py_eddy_tracker/observations/tracking.py +++ b/src/py_eddy_tracker/observations/tracking.py @@ -2,8 +2,8 @@ """ Class to manage observations gathered in trajectories """ -import logging from datetime import datetime, timedelta +import logging from numba import njit from numpy import ( @@ -16,6 +16,7 @@ degrees, empty, histogram, + int_, median, nan, ones, @@ -67,6 +68,10 @@ def __init__(self, *args, **kwargs): self.__obs_by_track = None self.__nb_track = None + def track_slice(self, track): + i0 = self.index_from_track[track] + return slice(i0, i0 + self.nb_obs_by_track[track]) + def iter_track(self): """ Yield track @@ -77,7 +82,7 @@ def iter_track(self): yield self.index(slice(i0, i0 + nb)) def get_missing_indices(self, dt): - """find indices where observations is missing. + """Find indices where observations are missing. :param int,float dt: theorical delta time between 2 observations """ @@ -90,7 +95,7 @@ def get_missing_indices(self, dt): ) def fix_next_previous_obs(self): - """function used after 'insert_virtual', to correct next_obs and + """Function used after 'insert_virtual', to correct next_obs and previous obs. 
""" @@ -99,7 +104,7 @@ def fix_next_previous_obs(self): @property def nb_tracks(self): """ - Will count and send number of track + Count and return number of track """ if self.__nb_track is None: if len(self) == 0: @@ -113,7 +118,7 @@ def __repr__(self): t0, t1 = self.period period = t1 - t0 + 1 nb = self.nb_obs_by_track - nb_obs = self.observations.shape[0] + nb_obs = len(self) m = self.virtual.astype("bool") nb_m = m.sum() bins_t = (1, 30, 90, 180, 270, 365, 1000, 10000) @@ -142,7 +147,7 @@ def __repr__(self): def add_distance(self): """Add a field of distance (m) between two consecutive observations, 0 for the last observation of each track""" - if "distance_next" in self.observations.dtype.descr: + if "distance_next" in self.fields: return self new = self.add_fields(("distance_next",)) new["distance_next"][:1] = self.distance_to_next() @@ -150,7 +155,7 @@ def add_distance(self): def distance_to_next(self): """ - :return: array of distance in m, 0 when next obs if from another track + :return: array of distance in m, 0 when next obs is from another track :rtype: array """ d = distance( @@ -166,26 +171,28 @@ def distance_to_next(self): return d_ def normalize_longitude(self): - """Normalize all longitude + """Normalize all longitudes Normalize longitude field and in the same range : - longitude_max - contour_lon_e (how to do if in raw) - contour_lon_s (how to do if in raw) """ + if self.lon.size == 0: + return lon0 = (self.lon[self.index_from_track] - 180).repeat(self.nb_obs_by_track) logger.debug("Normalize longitude") self.lon[:] = (self.lon - lon0) % 360 + lon0 - if "lon_max" in self.obs.dtype.names: + if "lon_max" in self.fields: logger.debug("Normalize longitude_max") self.lon_max[:] = (self.lon_max - self.lon + 180) % 360 + self.lon - 180 if not self.raw_data: - if "contour_lon_e" in self.obs.dtype.names: + if "contour_lon_e" in self.fields: logger.debug("Normalize effective contour longitude") self.contour_lon_e[:] = ( (self.contour_lon_e.T - self.lon + 180) % 360 + self.lon - 180 ).T - if "contour_lon_s" in self.obs.dtype.names: + if "contour_lon_s" in self.fields: logger.debug("Normalize speed contour longitude") self.contour_lon_s[:] = ( (self.contour_lon_s.T - self.lon + 180) % 360 + self.lon - 180 @@ -198,10 +205,9 @@ def extract_longer_eddies(self, nb_min, nb_obs, compress_id=True): logger.info("Selection of %d observations", nb_obs_select) eddies = self.__class__.new_like(self, nb_obs_select) eddies.sign_type = self.sign_type - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - eddies.obs[var] = self.obs[var][mask] + eddies.obs[field] = self.obs[field][mask] if compress_id: list_id = unique(eddies.obs.track) list_id.sort() @@ -217,7 +223,7 @@ def elements(self): return list(set(elements)) def set_global_attr_netcdf(self, h_nc): - """Set global attr""" + """Set global attributes""" h_nc.title = "Cyclonic" if self.sign_type == -1 else "Anticyclonic" h_nc.Metadata_Conventions = "Unidata Dataset Discovery v1.0" h_nc.comment = "Surface product; mesoscale eddies" @@ -228,20 +234,21 @@ def set_global_attr_netcdf(self, h_nc): ) h_nc.date_created = datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ") t = h_nc.variables[VAR_DESCR_inv["j1"]] - delta = t.max - t.min + 1 - h_nc.time_coverage_duration = "P%dD" % delta - d_start = datetime(1950, 1, 1) + timedelta(int(t.min)) - d_end = datetime(1950, 1, 1) + timedelta(int(t.max)) - h_nc.time_coverage_start = d_start.strftime("%Y-%m-%dT00:00:00Z") - h_nc.time_coverage_end = 
d_end.strftime("%Y-%m-%dT00:00:00Z") + if t.size: + delta = t.max - t.min + 1 + h_nc.time_coverage_duration = "P%dD" % delta + d_start = datetime(1950, 1, 1) + timedelta(int(t.min)) + d_end = datetime(1950, 1, 1) + timedelta(int(t.max)) + h_nc.time_coverage_start = d_start.strftime("%Y-%m-%dT00:00:00Z") + h_nc.time_coverage_end = d_end.strftime("%Y-%m-%dT00:00:00Z") def extract_with_period(self, period, **kwargs): """ Extract within a time period - :param (int,int) period: two dates to define the period, must be specify from 1/1/1950 + :param (int,int) period: two dates to define the period, must be specified from 1/1/1950 :param dict kwargs: look at :py:meth:`extract_with_mask` - :return: Return all eddy tracks which are in bounds + :return: Return all eddy tracks in period :rtype: TrackEddiesObservations .. minigallery:: py_eddy_tracker.TrackEddiesObservations.extract_with_period @@ -264,9 +271,9 @@ def get_azimuth(self, equatorward=False): """ Return azimuth for each track. - Azimuth is computed with first and last observation + Azimuth is computed with first and last observations - :param bool equatorward: If True, Poleward are positive and equatorward negative + :param bool equatorward: If True, Poleward is positive and Equatorward negative :rtype: array """ i0, nb = self.index_from_track, self.nb_obs_by_track @@ -373,19 +380,17 @@ def extract_toward_direction(self, west=True, delta_lon=None): d_lon = lon[i1] - lon[i0] m = d_lon < 0 if west else d_lon > 0 if delta_lon is not None: - m *= delta_lon < d_lon + m *= delta_lon < abs(d_lon) m = m.repeat(nb) return self.extract_with_mask(m) def extract_first_obs_in_box(self, res): - data = empty( - self.obs.shape, dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")] - ) + data = empty(len(self), dtype=[("lon", "f4"), ("lat", "f4"), ("track", "i4")]) data["lon"] = self.longitude - self.longitude % res data["lat"] = self.latitude - self.latitude % res data["track"] = self.track _, indexs = unique(data, return_index=True) - mask = zeros(self.obs.shape, dtype="bool") + mask = zeros(len(self), dtype="bool") mask[indexs] = True return self.extract_with_mask(mask) @@ -427,12 +432,9 @@ def extract_with_length(self, bounds): track_mask = self.nb_obs_by_track >= b0 else: logger.warning("No valid value for bounds") - raise Exception("One bounds must be positiv") + raise Exception("One bound must be positive") return self.extract_with_mask(track_mask.repeat(self.nb_obs_by_track)) - def empty_dataset(self): - return self.new_like(self, 0) - def loess_filter(self, half_window, xfield, yfield, inplace=True): track = self.track x = self.obs[xfield] @@ -441,6 +443,7 @@ def loess_filter(self, half_window, xfield, yfield, inplace=True): if inplace: self.obs[yfield] = result return self + return result def median_filter(self, half_window, xfield, yfield, inplace=True): result = track_median_filter( @@ -474,7 +477,7 @@ def extract_with_mask( :param bool full_path: extract the full trajectory if only one part is selected :param bool remove_incomplete: delete trajectory if not fully selected :param bool compress_id: resample trajectory number to use a smaller range - :param bool reject_virtual: if only virtual are selected, the trajectory is removed + :param bool reject_virtual: if only virtuals are selected, the trajectory is removed :return: same object with the selected observations :rtype: self.__class__ """ @@ -497,12 +500,11 @@ def extract_with_mask( new = self.__class__.new_like(self, nb_obs) new.sign_type = self.sign_type if nb_obs == 0: - 
logger.warning("Empty dataset will be created") + logger.info("Empty dataset will be created") else: - for field in self.obs.dtype.descr: + for field in self.fields: logger.debug("Copy of field %s ...", field) - var = field[0] - new.obs[var] = self.obs[var][mask] + new.obs[field] = self.obs[field][mask] if compress_id: list_id = unique(new.track) list_id.sort() @@ -525,10 +527,10 @@ def shape_polygon(self, intern=False): def display_shape(self, ax, ref=None, intern=False, **kwargs): """ - This function will draw the shape of each trajectory + This function draws the shape of each trajectory :param matplotlib.axes.Axes ax: ax to draw - :param float,int ref: if defined all coordinates will be wrapped with ref like west boundary + :param float,int ref: if defined, all coordinates are wrapped with ref as western boundary :param bool intern: If True use speed contour instead of effective contour :param dict kwargs: keyword arguments for Axes.plot :return: matplotlib mappable @@ -557,14 +559,17 @@ def close_tracks(self, other, nb_obs_min=10, **kwargs): :param self other: Atlas to compare :param int nb_obs_min: Minimal number of overlap for one trajectory :param dict kwargs: keyword arguments for match function - :return: return other atlas reduce to common track with self + :return: return other atlas reduced to common trajectories with self .. warning:: It could be a costly operation for huge dataset """ p0, p1 = self.period + p0_other, p1_other = other.period + if p1_other < p0 or p1 < p0_other: + return other.__class__.new_like(other, 0) indexs = list() - for i_self, i_other, t0, t1 in self.align_on(other, bins=range(p0, p1 + 2)): + for i_self, i_other, t0, t1 in self.align_on(other, bins=arange(p0, p1 + 2)): i, j, s = self.match(other, i_self=i_self, i_other=i_other, **kwargs) indexs.append(other.re_reference_index(j, i_other)) indexs = concatenate(indexs) @@ -585,7 +590,7 @@ def plot(self, ax, ref=None, **kwargs): This function will draw path of each trajectory :param matplotlib.axes.Axes ax: ax to draw - :param float,int ref: if defined, all coordinates will be wrapped with ref like west boundary + :param float,int ref: if defined, all coordinates are wrapped with ref as western boundary :param dict kwargs: keyword arguments for Axes.plot :return: matplotlib mappable """ @@ -601,6 +606,12 @@ def plot(self, ax, ref=None, **kwargs): def split_network(self, intern=True, **kwargs): """Return each group (network) divided in segments""" + # Find timestep of dataset + # FIXME : how to know exact time sampling + t = unique(self.time) + dts = t[1:] - t[:-1] + timestep = median(dts) + track_s, track_e, track_ref = build_index(self.tracks) ids = empty( len(self), @@ -614,11 +625,11 @@ def split_network(self, intern=True, **kwargs): ("next_obs", "i4"), ], ) - ids["group"], ids["time"] = self.tracks, self.time + ids["group"], ids["time"] = self.tracks, int_(self.time / timestep) # Initialisation # To store the id of the segments, the backward and forward cost associations ids["track"], ids["previous_cost"], ids["next_cost"] = 0, 0, 0 - # To store the indexes of the backward and forward observations associated + # To store the indices of the backward and forward observations associated ids["previous_obs"], ids["next_obs"] = -1, -1 # At the end, ids["previous_obs"] == -1 means the start of a non-split segment # and ids["next_obs"] == -1 means the end of a non-merged segment @@ -641,19 +652,19 @@ def split_network(self, intern=True, **kwargs): local_ids["next_obs"][m] += i_s if display_iteration: print() 
+ ids["time"] *= timestep return ids def set_tracks(self, x, y, ids, window, **kwargs): """ - Will split one group (network) in segments + Split one group (network) in segments :param array x: coordinates of group :param array y: coordinates of group :param ndarray ids: several fields like time, group, ... - :param int windows: number of days where observations could missed + :param int window: number of days where observations could missed """ - - time_index = build_index(ids["time"]) + time_index = build_index((ids["time"]).astype("i4")) nb = x.shape[0] used = zeros(nb, dtype="bool") track_id = 1 @@ -695,11 +706,19 @@ def follow_obs(cls, i_next, track_id, used, ids, *args, **kwargs): @staticmethod def get_previous_obs( - i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs + i_current, + ids, + x, + y, + time_s, + time_e, + time_ref, + window, + min_overlap=0.2, + **kwargs, ): """Backward association of observations to the segments""" - - time_cur = ids["time"][i_current] + time_cur = int_(ids["time"][i_current]) t0, t1 = time_cur - 1 - time_ref, max(time_cur - window - time_ref, 0) for t_step in range(t0, t1 - 1, -1): i0, i1 = time_s[t_step], time_e[t_step] @@ -712,9 +731,9 @@ def get_previous_obs( if len(ii) == 0: continue c = zeros(len(xj)) - c[ij] = vertice_overlap(xi[ii], yi[ii], xj[ij], yj[ij], **kwargs) - # We remove low overlap - c[c < 0.01] = 0 + c[ij] = vertice_overlap( + xi[ii], yi[ii], xj[ij], yj[ij], min_overlap=min_overlap, **kwargs + ) # We get index of maximal overlap i = c.argmax() c_i = c[i] @@ -726,10 +745,21 @@ def get_previous_obs( break @staticmethod - def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwargs): + def get_next_obs( + i_current, + ids, + x, + y, + time_s, + time_e, + time_ref, + window, + min_overlap=0.2, + **kwargs, + ): """Forward association of observations to the segments""" time_max = time_e.shape[0] - 1 - time_cur = ids["time"][i_current] + time_cur = int_(ids["time"][i_current]) t0, t1 = time_cur + 1 - time_ref, min(time_cur + window - time_ref, time_max) if t0 > time_max: return -1 @@ -744,9 +774,9 @@ def get_next_obs(i_current, ids, x, y, time_s, time_e, time_ref, window, **kwarg if len(ii) == 0: continue c = zeros(len(xj)) - c[ij] = vertice_overlap(xi[ii], yi[ii], xj[ij], yj[ij], **kwargs) - # We remove low overlap - c[c < 0.01] = 0 + c[ij] = vertice_overlap( + xi[ii], yi[ii], xj[ij], yj[ij], min_overlap=min_overlap, **kwargs + ) # We get index of maximal overlap i = c.argmax() c_i = c[i] @@ -791,10 +821,10 @@ def track_loess_filter(half_window, x, y, track): """ Apply a loess filter on y field - :param int,float window: parameter of smoother + :param int,float half_window: parameter of smoother :param array_like x: must be growing for each track but could be irregular :param array_like y: field to smooth - :param array_like track: field which allow to separate path + :param array_like track: field that allows to separate path :return: Array smoothed :rtype: array_like diff --git a/src/py_eddy_tracker/poly.py b/src/py_eddy_tracker/poly.py index 2bf34509..b5849610 100644 --- a/src/py_eddy_tracker/poly.py +++ b/src/py_eddy_tracker/poly.py @@ -5,11 +5,10 @@ import heapq -from numba import njit, prange -from numba import types as numba_types +from Polygon import Polygon +from numba import njit, prange, types as numba_types from numpy import arctan, array, concatenate, empty, nan, ones, pi, where, zeros from numpy.linalg import lstsq -from Polygon import Polygon from .generic import build_index @@ -86,7 +85,7 
@@ def poly_area_vertice(v): @njit(cache=True) def poly_area(x, y): """ - Must be call with local coordinates (in m, to get an area in m²). + Must be called with local coordinates (in m, to get an area in m²). :param array x: :param array y: @@ -279,7 +278,10 @@ def close_center(x0, y0, x1, y1, delta=0.1): for i0 in range(nb0): xi0, yi0 = x0[i0], y0[i0] for i1 in range(nb1): - if abs(x1[i1] - xi0) > delta: + d_x = x1[i1] - xi0 + if abs(d_x) > 180: + d_x = (d_x + 180) % 360 - 180 + if abs(d_x) > delta: continue if abs(y1[i1] - yi0) > delta: continue @@ -287,6 +289,27 @@ def close_center(x0, y0, x1, y1, delta=0.1): return array(i), array(j), array(c) +@njit(cache=True) +def create_meshed_particles(lons, lats, step): + x_out, y_out, i_out = list(), list(), list() + nb = lons.shape[0] + for i in range(nb): + lon, lat = lons[i], lats[i] + vertice = create_vertice(*reduce_size(lon, lat)) + lon_min, lon_max = lon.min(), lon.max() + lat_min, lat_max = lat.min(), lat.max() + y0 = lat_min - lat_min % step + x = lon_min - lon_min % step + while x <= lon_max: + y = y0 + while y <= lat_max: + if winding_number_poly(x, y, vertice): + x_out.append(x), y_out.append(y), i_out.append(i) + y += step + x += step + return array(x_out), array(y_out), array(i_out) + + @njit(cache=True, fastmath=True) def bbox_intersection(x0, y0, x1, y1): """ @@ -411,7 +434,9 @@ def merge(x, y): return concatenate(x), concatenate(y) -def vertice_overlap(x0, y0, x1, y1, minimal_area=False): +def vertice_overlap( + x0, y0, x1, y1, minimal_area=False, p1_area=False, hybrid_area=False, min_overlap=0 +): r""" Return percent of overlap for each item. @@ -420,6 +445,10 @@ def vertice_overlap(x0, y0, x1, y1, minimal_area=False): :param array x1: x for polygon list 1 :param array y1: y for polygon list 1 :param bool minimal_area: If True, function will compute intersection/little polygon, else intersection/union + :param bool p1_area: If True, function will compute intersection/p1 polygon, else intersection/union + :param bool hybrid_area: If True, function will compute like union, + but if cost is under min_overlap, obs is kept in case of fully included + :param float min_overlap: under this value cost is set to zero :return: Result of cost function :rtype: array @@ -430,6 +459,10 @@ def vertice_overlap(x0, y0, x1, y1, minimal_area=False): If minimal area: .. math:: Score = \frac{Intersection(P_0,P_1)_{area}}{min(P_{0 area},P_{1 area})} + + If P1 area: + + .. 
math:: Score = \frac{Intersection(P_0,P_1)_{area}}{P_{1 area}} """ nb = x0.shape[0] cost = empty(nb) @@ -441,11 +474,29 @@ def vertice_overlap(x0, y0, x1, y1, minimal_area=False): # Area of intersection intersection = (p0 & p1).area() # we divide intersection with the little one result from 0 to 1 + if intersection == 0: + cost[i] = 0 + continue + p0_area_, p1_area_ = p0.area(), p1.area() if minimal_area: - cost[i] = intersection / min(p0.area(), p1.area()) + cost_ = intersection / min(p0_area_, p1_area_) + # we divide intersection with p1 + elif p1_area: + cost_ = intersection / p1_area_ # we divide intersection with polygon merging result from 0 to 1 else: - cost[i] = intersection / (p0 + p1).area() + cost_ = intersection / (p0_area_ + p1_area_ - intersection) + if cost_ >= min_overlap: + cost[i] = cost_ + else: + if ( + hybrid_area + and cost_ != 0 + and (intersection / min(p0_area_, p1_area_)) > 0.99 + ): + cost[i] = cost_ + else: + cost[i] = 0 return cost @@ -474,13 +525,13 @@ def polygon_overlap(p0, p1, minimal_area=False): return cost -# FIXME: only one function are needed +# FIXME: only one function is needed @njit(cache=True) def fit_circle(x, y): """ From a polygon, function will fit a circle. - Must be call with local coordinates (in m, to get a radius in m). + Must be called with local coordinates (in m, to get a radius in m). :param array x: x of polygon :param array y: y of polygon @@ -495,7 +546,7 @@ def fit_circle(x, y): norme = (x[1:] - x_mean) ** 2 + (y[1:] - y_mean) ** 2 norme_max = norme.max() - scale = norme_max ** 0.5 + scale = norme_max**0.5 # Form matrix equation and solve it # Maybe put f4 @@ -506,15 +557,15 @@ def fit_circle(x, y): (x0, y0, radius), _, _, _ = lstsq(datas, norme / norme_max) # Unscale data and get circle variables - radius += x0 ** 2 + y0 ** 2 + radius += x0**2 + y0**2 radius **= 0.5 x0 *= scale y0 *= scale - # radius of fitted circle + # radius of fit circle radius *= scale - # center X-position of fitted circle + # center X-position of fit circle x0 += x_mean - # center Y-position of fitted circle + # center Y-position of fit circle y0 += y_mean err = shape_error(x, y, x0, y0, radius) @@ -522,9 +573,9 @@ def fit_circle(x, y): @njit(cache=True) -def fit_ellips(x, y): +def fit_ellipse(x, y): r""" - From a polygon, function will fit an ellips. + From a polygon, function will fit an ellipse. Must be call with local coordinates (in m, to get a radius in m). 
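The `vertice_overlap` changes in the poly.py hunks above add `p1_area`, `hybrid_area` and `min_overlap` options on top of the existing union and minimal-area scores. A minimal sketch of that scoring arithmetic, using made-up areas for two boxes where the smaller polygon is fully included in the larger one (plain arithmetic only, not the Polygon-based implementation):

```python
# Illustrative arithmetic only, not the Polygon-based implementation.
area_p0, area_p1 = 1.0, 0.25          # made-up areas, P1 fully included in P0
intersection = 0.25

score_union = intersection / (area_p0 + area_p1 - intersection)  # 0.25
score_min = intersection / min(area_p0, area_p1)                 # 1.0  (minimal_area=True)
score_p1 = intersection / area_p1                                 # 1.0  (p1_area=True)

min_overlap = 0.5
if score_union >= min_overlap:
    cost = score_union
elif score_union != 0 and score_min > 0.99:
    # hybrid_area=True: observation kept because the small polygon is fully included
    cost = score_union
else:
    cost = 0.0

print(score_union, score_min, score_p1, cost)  # 0.25 1.0 1.0 0.25
```

With these numbers the union score falls below `min_overlap`, but `hybrid_area` keeps the cost because the smaller polygon is entirely covered, which mirrors the branch added in `vertice_overlap`.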
@@ -536,24 +587,23 @@ def fit_ellips(x, y): https://en.wikipedia.org/wiki/Ellipse """ - # x,y = x[1:],y[1:] nb = x.shape[0] datas = ones((nb, 5), dtype=x.dtype) - datas[:, 0] = x ** 2 + datas[:, 0] = x**2 datas[:, 1] = x * y - datas[:, 2] = y ** 2 + datas[:, 2] = y**2 datas[:, 3] = x datas[:, 4] = y (a, b, c, d, e), _, _, _ = lstsq(datas, ones(nb, dtype=x.dtype)) - det = b ** 2 - 4 * a * c + det = b**2 - 4 * a * c if det > 0: print(det) x0 = (2 * c * d - b * e) / det y0 = (2 * a * e - b * d) / det - AB1 = 2 * (a * e ** 2 + c * d ** 2 - b * d * e - det) + AB1 = 2 * (a * e**2 + c * d**2 - b * d * e - det) AB2 = a + c - AB3 = ((a - c) ** 2 + b ** 2) ** 0.5 + AB3 = ((a - c) ** 2 + b**2) ** 0.5 A = -((AB1 * (AB2 + AB3)) ** 0.5) / det B = -((AB1 * (AB2 - AB3)) ** 0.5) / det theta = arctan((c - a - AB3) / b) @@ -614,7 +664,7 @@ def fit_circle_(x, y): # Linear regression (a, b, c), _, _, _ = lstsq(datas, x[1:] ** 2 + y[1:] ** 2) x0, y0 = a / 2.0, b / 2.0 - radius = (c + x0 ** 2 + y0 ** 2) ** 0.5 + radius = (c + x0**2 + y0**2) ** 0.5 err = shape_error(x, y, x0, y0, radius) return x0, y0, radius, err @@ -639,14 +689,14 @@ def shape_error(x, y, x0, y0, r): :rtype: float """ # circle area - c_area = (r ** 2) * pi + c_area = (r**2) * pi p_area = poly_area(x, y) nb = x.shape[0] x, y = x.copy(), y.copy() # Find distance between circle center and polygon for i in range(nb): dx, dy = x[i] - x0, y[i] - y0 - rd = r / (dx ** 2 + dy ** 2) ** 0.5 + rd = r / (dx**2 + dy**2) ** 0.5 if rd < 1: x[i] = x0 + dx * rd y[i] = y0 + dy * rd @@ -717,50 +767,86 @@ def tri_area2(x, y, i0, i1, i2): def visvalingam(x, y, fixed_size=18): """Polygon simplification with visvalingam algorithm + X, Y are considered like a polygon, the next point after the last one is the first one + :param array x: :param array y: :param int fixed_size: array size of out - :return: New (x, y) array + :return: + New (x, y) array, last position will be equal to first one, if array size is 6, + there is only 5 point. :rtype: array,array + + .. 
plot:: + + import matplotlib.pyplot as plt + import numpy as np + from py_eddy_tracker.poly import visvalingam + + x = np.array([1, 2, 3, 4, 5, 6.75, 6, 1]) + y = np.array([-0.5, -1.5, -1, -1.75, -1, -1, -0.5, -0.5]) + ax = plt.subplot(111) + ax.set_aspect("equal") + ax.grid(True), ax.set_ylim(-2, -.2) + ax.plot(x, y, "r", lw=5) + ax.plot(*visvalingam(x,y,6), "b", lw=2) + plt.show() """ + # TODO : in case of original size less than fixed size, jump to the end nb = x.shape[0] - i0, i1 = nb - 3, nb - 2 + nb_ori = nb + # Get indices of the first triangle + i0, i1 = nb - 2, nb - 1 + # Init heap with first area and triangle h = [(tri_area2(x, y, i0, i1, 0), (i0, i1, 0))] + # Roll index for next one i0 = i1 i1 = 0 - i_previous = empty(nb - 1, dtype=numba_types.int32) - i_next = empty(nb - 1, dtype=numba_types.int32) + # Index of previous valid point + i_previous = empty(nb, dtype=numba_types.int64) + # Index of next valid point + i_next = empty(nb, dtype=numba_types.int64) + # Mask of removed points + removed = zeros(nb, dtype=numba_types.bool_) i_previous[0] = -1 i_next[0] = -1 - for i in range(1, nb - 1): + for i in range(1, nb): i_previous[i] = -1 i_next[i] = -1 + # We add triangle area for all triangles heapq.heappush(h, (tri_area2(x, y, i0, i1, i), (i0, i1, i))) i0 = i1 i1 = i # we continue until we are equal to nb_pt - while len(h) >= fixed_size: + while nb >= fixed_size: # We pop lower area _, (i0, i1, i2) = heapq.heappop(h) # We check if triangle is valid(i0 or i2 not removed) - i_p, i_n = i_previous[i0], i_next[i2] - if i_p == -1 and i_n == -1: - # We store reference of delete point - i_previous[i1] = i0 - i_next[i1] = i2 + if removed[i0] or removed[i2]: + # In this case nothing to do continue - elif i_p == -1: - i2 = i_n - elif i_n == -1: - i0 = i_p - else: - # in this case we replace two point - i0, i2 = i_p, i_n - heapq.heappush(h, (tri_area2(x, y, i0, i1, i2), (i0, i1, i2))) + # Flag obs as removed + removed[i1] = True + # We count points still valid + nb -= 1 + # Modify next and previous indexes, we jump over i1 + i_previous[i2] = i0 + i_next[i0] = i2 + # We insert 2 triangles which are modified by the deleted point + # Previous triangle + i_1 = i_previous[i0] + if i_1 == -1: + i_1 = (i0 - 1) % nb_ori + heapq.heappush(h, (tri_area2(x, y, i_1, i0, i2), (i_1, i0, i2))) + # Next triangle + i3 = i_next[i2] + if i3 == -1: + i3 = (i2 + 1) % nb_ori + heapq.heappush(h, (tri_area2(x, y, i0, i2, i3), (i0, i2, i3))) x_new, y_new = empty(fixed_size, dtype=x.dtype), empty(fixed_size, dtype=y.dtype) j = 0 - for i, i_n in enumerate(i_next): - if i_n == -1: + for i, flag in enumerate(removed): + if not flag: x_new[j] = x[i] y_new[j] = y[i] j += 1 @@ -815,12 +901,12 @@ def box_indexes(x, y, step): @njit(cache=True) -def poly_indexs_(x_p, y_p, x_c, y_c): +def poly_indexs(x_p, y_p, x_c, y_c): """ Index of contour for each postion inside a contour, -1 in case of no contour - :param array x_p: longitude to test (must be define, no nan) - :param array y_p: latitude to test (must be define, no nan) + :param array x_p: longitude to test (must be defined, no nan) + :param array y_p: latitude to test (must be defined, no nan) :param array x_c: longitude of contours :param array y_c: latitude of contours """ @@ -830,7 +916,7 @@ def poly_indexs_(x_p, y_p, x_c, y_c): nb_p = x_p.shape[0] nb_c = x_c.shape[0] indexs = -ones(nb_p, dtype=numba_types.int32) - # Adress table to get particle bloc + # Address table to get test bloc start_index, end_index, i_first = build_index(i[i_order]) nb_bloc = end_index.size for
i_contour in range(nb_c): @@ -873,34 +959,6 @@ def poly_indexs_(x_p, y_p, x_c, y_c): return indexs -@njit(cache=True) -def poly_indexs(x_p, y_p, x_c, y_c): - """ - index of contour for each postion inside a contour, -1 in case of no contour - - :param array x_p: longitude to test - :param array y_p: latitude to test - :param array x_c: longitude of contours - :param array y_c: latitude of contours - """ - nb_p = x_p.shape[0] - nb_c = x_c.shape[0] - indexs = -ones(nb_p, dtype=numba_types.int32) - for i in range(nb_c): - x_, y_ = reduce_size(x_c[i], y_c[i]) - x_c_min, y_c_min = x_.min(), y_.min() - x_c_max, y_c_max = x_.max(), y_.max() - v = create_vertice(x_, y_) - for j in range(nb_p): - if indexs[j] != -1: - continue - x, y = x_p[j], y_p[j] - if x > x_c_min and x < x_c_max and y > y_c_min and y < y_c_max: - if winding_number_poly(x, y, v) != 0: - indexs[j] = i - return indexs - - @njit(cache=True) def insidepoly(x_p, y_p, x_c, y_c): """ @@ -911,20 +969,4 @@ def insidepoly(x_p, y_p, x_c, y_c): :param array x_c: longitude of contours :param array y_c: latitude of contours """ - # TODO must be optimize like poly_index - nb_p = x_p.shape[0] - nb_c = x_c.shape[0] - flag = zeros(nb_p, dtype=numba_types.bool_) - for i in range(nb_c): - x_, y_ = reduce_size(x_c[i], y_c[i]) - x_c_min, y_c_min = x_.min(), y_.min() - x_c_max, y_c_max = x_.max(), y_.max() - v = create_vertice(x_, y_) - for j in range(nb_p): - if flag[j]: - continue - x, y = x_p[j], y_p[j] - if x > x_c_min and x < x_c_max and y > y_c_min and y < y_c_max: - if winding_number_poly(x, y, v) != 0: - flag[j] = True - return flag + return poly_indexs(x_p, y_p, x_c, y_c) != -1 diff --git a/src/py_eddy_tracker/tracking.py b/src/py_eddy_tracker/tracking.py index cba9985e..b64b6fcc 100644 --- a/src/py_eddy_tracker/tracking.py +++ b/src/py_eddy_tracker/tracking.py @@ -2,15 +2,14 @@ """ Class to store link between observations """ - +from datetime import datetime, timedelta import json import logging import platform -from datetime import datetime, timedelta +from tarfile import ExFileObject from netCDF4 import Dataset, default_fillvals -from numba import njit -from numba import types as numba_types +from numba import njit, types as numba_types from numpy import ( arange, array, @@ -161,10 +160,10 @@ def period(self): """ date_start = datetime(1950, 1, 1) + timedelta( - int(self.class_method.load_file(self.datasets[0]).time[0]) + self.class_method.load_file(self.datasets[0]).time[0] ) date_stop = datetime(1950, 1, 1) + timedelta( - int(self.class_method.load_file(self.datasets[-1]).time[0]) + self.class_method.load_file(self.datasets[-1]).time[0] ) return date_start, date_stop @@ -350,6 +349,9 @@ def load_state(self): self.virtual_obs = VirtualEddiesObservations.from_netcdf( general_handler.groups["LastVirtualObs"] ) + self.previous_virtual_obs = VirtualEddiesObservations.from_netcdf( + general_handler.groups["LastPreviousVirtualObs"] + ) # Load and last previous virtual obs to be merge with current => will be previous2_obs # TODO : Need to rethink this line ?? 
self.current_obs = self.current_obs.merge( @@ -373,7 +375,10 @@ def track(self): # We begin with second file, first one is in previous for file_name in self.datasets[first_dataset:]: self.swap_dataset(file_name, **kwargs) - logger.info("%s match with previous state", file_name) + filename_ = ( + file_name.filename if isinstance(file_name, ExFileObject) else file_name + ) + logger.info("%s match with previous state", filename_) logger.debug("%d obs to match", len(self.current_obs)) nb_real_obs = len(self.previous_obs) @@ -407,14 +412,14 @@ def to_netcdf(self, handler): logger.debug('Create Dimensions "Nstep" : %d', nb_step) handler.createDimension("Nstep", nb_step) var_file_in = handler.createVariable( - zlib=True, + zlib=False, complevel=1, varname="FileIn", datatype="S1024", dimensions="Nstep", ) var_file_out = handler.createVariable( - zlib=True, + zlib=False, complevel=1, varname="FileOut", datatype="S1024", @@ -584,7 +589,10 @@ def prepare_merging(self): def longer_than(self, size_min): """Remove from correspondance table all association for shorter eddies than size_min""" # Identify eddies longer than - i_keep_track = where(self.nb_obs_by_tracks >= size_min)[0] + mask = self.nb_obs_by_tracks >= size_min + if not mask.any(): + return False + i_keep_track = where(mask)[0] # Reduce array self.nb_obs_by_tracks = self.nb_obs_by_tracks[i_keep_track] self.i_current_by_tracks = ( @@ -653,7 +661,7 @@ def merge(self, until=-1, raw_data=True): # Set type of eddy with first file eddies.sign_type = self.current_obs.sign_type # Fields to copy - fields = self.current_obs.obs.dtype.names + fields = self.current_obs.fields # To know if the track start first_obs_save_in_tracks = zeros(self.i_current_by_tracks.shape, dtype=bool_) @@ -702,7 +710,7 @@ def merge(self, until=-1, raw_data=True): # Index in the current file index_current = self[i]["out"] - if "cost_association" in eddies.obs.dtype.names: + if "cost_association" in eddies.fields: eddies["cost_association"][index_final - 1] = self[i]["cost_value"] # Copy all variable for field in fields: diff --git a/src/scripts/EddyTranslate b/src/scripts/EddyTranslate index 94142132..a0060e9b 100644 --- a/src/scripts/EddyTranslate +++ b/src/scripts/EddyTranslate @@ -3,8 +3,8 @@ """ Translate eddy Dataset """ -import zarr from netCDF4 import Dataset +import zarr from py_eddy_tracker import EddyParser from py_eddy_tracker.observations.observation import EddiesObservations @@ -16,6 +16,8 @@ def id_parser(): ) parser.add_argument("filename_in") parser.add_argument("filename_out") + parser.add_argument("--unraw", action="store_true", help="Load unraw data, use only for netcdf." 
+ "If unraw is active, netcdf is loaded without apply scalefactor and add_offset.") return parser @@ -32,10 +34,10 @@ def get_variable_name(filename): return list(h.keys()) -def get_variable(filename, varname): +def get_variable(filename, varname, raw=True): if is_nc(filename): dataset = EddiesObservations.load_from_netcdf( - filename, raw_data=True, include_vars=(varname,) + filename, raw_data=raw, include_vars=(varname,) ) else: dataset = EddiesObservations.load_from_zarr(filename, include_vars=(varname,)) @@ -49,8 +51,8 @@ if __name__ == "__main__": if not is_nc(args.filename_out): h = zarr.open(args.filename_out, "w") for varname in variables: - get_variable(args.filename_in, varname).to_zarr(h) + get_variable(args.filename_in, varname, raw=not args.unraw).to_zarr(h) else: with Dataset(args.filename_out, "w") as h: for varname in variables: - get_variable(args.filename_in, varname).to_netcdf(h) + get_variable(args.filename_in, varname, raw=not args.unraw).to_netcdf(h) diff --git a/tests/test_generic.py b/tests/test_generic.py index ab3832cc..ee2d7881 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,6 +1,6 @@ from numpy import arange, array, nan, ones, zeros -from py_eddy_tracker.generic import cumsum_by_track, simplify +from py_eddy_tracker.generic import cumsum_by_track, simplify, wrap_longitude def test_simplify(): @@ -30,3 +30,22 @@ def test_cumsum_by_track(): a = ones(10, dtype="i4") * 2 track = array([1, 1, 2, 2, 2, 2, 44, 44, 44, 48]) assert (cumsum_by_track(a, track) == [2, 4, 2, 4, 6, 8, 2, 4, 6, 2]).all() + + +def test_wrapping(): + y = x = arange(-5, 5, dtype="f4") + x_, _ = wrap_longitude(x, y, ref=-10) + assert (x_ == x).all() + x_, _ = wrap_longitude(x, y, ref=1) + assert x.size == x_.size + assert (x_[6:] == x[6:]).all() + assert (x_[:6] == x[:6] + 360).all() + x_, _ = wrap_longitude(x, y, ref=1, cut=True) + assert x.size + 3 == x_.size + assert (x_[6 + 3 :] == x[6:]).all() + assert (x_[:7] == x[:7] + 360).all() + + # FIXME Need evolution in wrap_longitude + # x %= 360 + # x_, _ = wrap_longitude(x, y, ref=-10, cut=True) + # assert x.size == x_.size diff --git a/tests/test_grid.py b/tests/test_grid.py index 2c89550a..0e6dd586 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,5 +1,5 @@ from matplotlib.path import Path -from numpy import array, isnan, ma +from numpy import arange, array, isnan, ma, nan, ones, zeros from pytest import approx from py_eddy_tracker.data import get_demo_path @@ -85,3 +85,40 @@ def test_interp(): assert g.interp("z", x0, y0) == 1.5 assert g.interp("z", x1, y1) == 2 assert isnan(g.interp("z", x2, y2)) + + +def test_convolution(): + """ + Add some dummy check on convolution filter + """ + # Fake grid + z = ma.array( + arange(12).reshape((-1, 1)) * arange(10).reshape((1, -1)), + mask=zeros((12, 10), dtype="bool"), + dtype="f4", + ) + g = RegularGridDataset.with_array( + coordinates=("x", "y"), + datas=dict( + z=z, + x=arange(0, 6, 0.5), + y=arange(0, 5, 0.5), + ), + centered=True, + ) + + def kernel_func(lat): + return ones((3, 3)) + + # After transpose we must get same result + d = g.convolve_filter_with_dynamic_kernel("z", kernel_func) + assert (d.T[:9, :9] == d[:9, :9]).all() + # We mask one value and check convolution result + z.mask[2, 2] = True + d = g.convolve_filter_with_dynamic_kernel("z", kernel_func) + assert d[1, 1] == z[:3, :3].sum() / 8 + # Add nan and check only nearest value is contaminate + z[2, 2] = nan + d = g.convolve_filter_with_dynamic_kernel("z", kernel_func) + assert not isnan(d[0, 0]) + assert 
isnan(d[1:4, 1:4]).all() diff --git a/tests/test_poly.py b/tests/test_poly.py index b2aacb73..a780f64d 100644 --- a/tests/test_poly.py +++ b/tests/test_poly.py @@ -1,7 +1,13 @@ -from numpy import array, pi +from numpy import array, pi, roll from pytest import approx -from py_eddy_tracker.poly import convex, fit_circle, get_convex_hull, poly_area_vertice +from py_eddy_tracker.poly import ( + convex, + fit_circle, + get_convex_hull, + poly_area_vertice, + visvalingam, +) # Vertices for next test V = array(((2, 2, 3, 3, 2), (-10, -9, -9, -10, -10))) @@ -16,7 +22,7 @@ def test_fit_circle(): x0, y0, r, err = fit_circle(*V) assert x0 == approx(2.5, rel=1e-10) assert y0 == approx(-9.5, rel=1e-10) - assert r == approx(2 ** 0.5 / 2, rel=1e-10) + assert r == approx(2**0.5 / 2, rel=1e-10) assert err == approx((1 - 2 / pi) * 100, rel=1e-10) @@ -29,3 +35,19 @@ def test_convex(): def test_convex_hull(): assert convex(*get_convex_hull(*V_concave)) is True + + +def test_visvalingam(): + x = array([1, 2, 3, 4, 5, 6.75, 6, 1]) + y = array([-0.5, -1.5, -1, -1.75, -1, -1, -0.5, -0.5]) + x_target = [1, 2, 3, 4, 6, 1] + y_target = [-0.5, -1.5, -1, -1.75, -0.5, -0.5] + x_, y_ = visvalingam(x, y, 6) + assert (x_target == x_).all() + assert (y_target == y_).all() + x_, y_ = visvalingam(x[:-1], y[:-1], 6) + assert (x_target == x_).all() + assert (y_target == y_).all() + x_, y_ = visvalingam(roll(x, 2), roll(y, 2), 6) + assert (x_target[:-1] == x_[1:]).all() + assert (y_target[:-1] == y_[1:]).all() diff --git a/tests/test_track.py b/tests/test_track.py index 4f362a26..f7e83786 100644 --- a/tests/test_track.py +++ b/tests/test_track.py @@ -1,5 +1,5 @@ -import zarr from netCDF4 import Dataset +import zarr from py_eddy_tracker.data import get_demo_path from py_eddy_tracker.featured_tracking.area_tracker import AreaTracker diff --git a/versioneer.py b/versioneer.py index 2b545405..1e3753e6 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,5 @@ -# Version: 0.18 + +# Version: 0.29 """The Versioneer - like a rocketeer, but for versions. @@ -6,18 +7,14 @@ ============== * like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer +* https://github.com/python-versioneer/python-versioneer * Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based +* License: Public Domain (Unlicense) +* Compatible with: Python 3.7, 3.8, 3.9, 3.10, 3.11 and pypy3 +* [![Latest Version][pypi-image]][pypi-url] +* [![Build Status][travis-image]][travis-url] + +This is a tool for managing a recorded version number in setuptools-based python projects. The goal is to remove the tedious and error-prone "update the embedded version string" step from your release process. Making a new release should be as easy as recording a new tag in your version-control @@ -26,9 +23,38 @@ ## Quick Install -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results +Versioneer provides two installation modes. The "classic" vendored mode installs +a copy of versioneer into your repository. 
The experimental build-time dependency mode +is intended to allow you to skip this step and simplify the process of upgrading. + +### Vendored mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) + * Note that you will need to add `tomli; python_version < "3.11"` to your + build-time dependencies if you use `pyproject.toml` +* run `versioneer install --vendor` in your source tree, commit the results +* verify version information with `python setup.py version` + +### Build-time dependency mode + +* `pip install versioneer` to somewhere in your $PATH + * A [conda-forge recipe](https://github.com/conda-forge/versioneer-feedstock) is + available, so you can also use `conda install -c conda-forge versioneer` +* add a `[tool.versioneer]` section to your `pyproject.toml` or a + `[versioneer]` section to your `setup.cfg` (see [Install](INSTALL.md)) +* add `versioneer` (with `[toml]` extra, if configuring in `pyproject.toml`) + to the `requires` key of the `build-system` table in `pyproject.toml`: + ```toml + [build-system] + requires = ["setuptools", "versioneer[toml]"] + build-backend = "setuptools.build_meta" + ``` +* run `versioneer install --no-vendor` in your source tree, commit the results +* verify version information with `python setup.py version` ## Version Identifiers @@ -60,7 +86,7 @@ for example `git describe --tags --dirty --always` reports things like "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. +uncommitted changes). The version identifier is used for multiple purposes: @@ -165,7 +191,7 @@ Some situations are known to cause problems for Versioneer. This details the most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). +[issues page](https://github.com/python-versioneer/python-versioneer/issues). ### Subprojects @@ -179,7 +205,7 @@ `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI distributions (and upload multiple independently-installable tarballs). * Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. + provide bindings to Python (and perhaps other languages) in subdirectories. Versioneer will look for `.git` in parent directories, and most operations should get the right version string. However `pip` and `setuptools` have bugs @@ -193,9 +219,9 @@ Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in some later version. -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking +[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the +[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the issue from the Versioneer side in more detail. 
[pip PR#3176](https://github.com/pypa/pip/pull/3176) and [pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve @@ -223,31 +249,20 @@ cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into a different virtualenv), so this can be surprising. -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes +[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes this one, but upgrading to a newer version of setuptools should probably resolve it. -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - ## Updating Versioneer To upgrade your project to a new release of Versioneer, do the following: * install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace +* edit `setup.cfg` and `pyproject.toml`, if necessary, + to include any new configuration settings indicated by the release notes. + See [UPGRADING](./UPGRADING.md) for details. +* re-run `versioneer install --[no-]vendor` in your source tree, to replace `SRC/_version.py` * commit any changed files @@ -264,36 +279,70 @@ direction and include code from all supported VCS systems, reducing the number of intermediate scripts. +## Similar projects + +* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time + dependency +* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of + versioneer +* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools + plugin ## License To make Versioneer easier to embed, all its code is dedicated to the public domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . +Specifically, both are released under the "Unlicense", as described in +https://unlicense.org/. 
-""" +[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg +[pypi-url]: https://pypi.python.org/pypi/versioneer/ +[travis-image]: +https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg +[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer -from __future__ import print_function +""" +# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring +# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements +# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error +# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with +# pylint:disable=attribute-defined-outside-init,too-many-arguments -try: - import configparser -except ImportError: - import ConfigParser as configparser +import configparser import errno import json import os import re import subprocess import sys +from pathlib import Path +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union +from typing import NoReturn +import functools + +have_tomllib = True +if sys.version_info >= (3, 11): + import tomllib +else: + try: + import tomli as tomllib + except ImportError: + have_tomllib = False class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + versionfile_source: str + versionfile_build: Optional[str] + parentdir_prefix: Optional[str] + verbose: Optional[bool] -def get_root(): + +def get_root() -> str: """Get the project root directory. We require that all commands are run from the project root, i.e. the @@ -301,20 +350,28 @@ def get_root(): """ root = os.path.realpath(os.path.abspath(os.getcwd())) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): # allow 'python path/to/setup.py COMMAND' root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) setup_py = os.path.join(root, "setup.py") + pyproject_toml = os.path.join(root, "pyproject.toml") versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ( - "Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND')." - ) + if not ( + os.path.exists(setup_py) + or os.path.exists(pyproject_toml) + or os.path.exists(versioneer_py) + ): + err = ("Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND').") raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -323,46 +380,62 @@ def get_root(): # module-import table will cache the first one. So we can't use # os.path.dirname(__file__), as that will find whichever # versioneer.py was first imported, even in later projects. 
- me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) + my_path = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(my_path)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print( - "Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py) - ) + if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): + print("Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py)) except NameError: pass return root -def get_config_from_root(root): +def get_config_from_root(root: str) -> VersioneerConfig: """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or + # This might raise OSError (if setup.cfg is missing), or # configparser.NoSectionError (if it lacks a [versioneer] section), or # configparser.NoOptionError (if it lacks "VCS="). See the docstring at # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None + root_pth = Path(root) + pyproject_toml = root_pth / "pyproject.toml" + setup_cfg = root_pth / "setup.cfg" + section: Union[Dict[str, Any], configparser.SectionProxy, None] = None + if pyproject_toml.exists() and have_tomllib: + try: + with open(pyproject_toml, 'rb') as fobj: + pp = tomllib.load(fobj) + section = pp['tool']['versioneer'] + except (tomllib.TOMLDecodeError, KeyError) as e: + print(f"Failed to load config from {pyproject_toml}: {e}") + print("Try to load it from setup.cfg") + if not section: + parser = configparser.ConfigParser() + with open(setup_cfg) as cfg_file: + parser.read_file(cfg_file) + parser.get("versioneer", "VCS") # raise error if missing + + section = parser["versioneer"] + + # `cast`` really shouldn't be used, but its simplest for the + # common VersioneerConfig users at the moment. 
We verify against + # `None` values elsewhere where it matters cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): + cfg.VCS = section['VCS'] + cfg.style = section.get("style", "") + cfg.versionfile_source = cast(str, section.get("versionfile_source")) + cfg.versionfile_build = section.get("versionfile_build") + cfg.tag_prefix = cast(str, section.get("tag_prefix")) + if cfg.tag_prefix in ("''", '""', None): cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") + cfg.parentdir_prefix = section.get("parentdir_prefix") + if isinstance(section, configparser.SectionProxy): + # Make sure configparser translates to bool + cfg.verbose = section.getboolean("verbose") + else: + cfg.verbose = section.get("verbose") + return cfg @@ -371,41 +444,48 @@ class NotThisMethod(Exception): # these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f + HANDLERS.setdefault(vcs, {})[method] = f return f - return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -416,28 +496,25 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, 
p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -LONG_VERSION_PY[ - "git" -] = ''' +LONG_VERSION_PY['git'] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" @@ -446,9 +523,11 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import functools -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must @@ -464,8 +543,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool + -def get_config(): +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -483,13 +569,13 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} @@ -498,22 +584,35 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + stdout=subprocess.PIPE, + 
stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -524,18 +623,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, if verbose: print("unable to find command, tried %%s" %% (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %%s (error)" %% dispcmd) print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -544,15 +645,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print("Tried directories %%s but none started with prefix %%s" %% @@ -561,41 +661,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
- keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -608,11 +715,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %%d @@ -621,7 +728,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%%s', no digits" %% ",".join(refs - tags)) if verbose: @@ -630,6 +737,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %%s" %% r) return {"version": r, @@ -645,7 +757,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -656,8 +773,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %%s not under git control" %% root) @@ -665,24 +789,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. 
+ branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -699,7 +856,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%%s'" %% describe_out) return pieces @@ -724,26 +881,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() + date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -768,23 +926,71 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). 
+ """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%%d" %% (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] + rendered = "0.post0.dev%%d" %% pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -811,12 +1017,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -833,7 +1068,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -853,7 +1088,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. 
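The versioneer hunks above introduce `pep440_split_post` and rework `render_pep440_pre` so that an existing post-release tag gets its post segment bumped rather than duplicated. A small sketch of that behaviour, restating the logic shown in the hunks with an invented tag and distance:

```python
# Re-statement of the pep440_split_post / render_pep440_pre logic shown above,
# for illustration only; tag and distance are invented.
from typing import Optional, Tuple

def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]:
    vc = ver.split(".post")
    return vc[0], int(vc[1] or 0) if len(vc) == 2 else None

print(pep440_split_post("1.2.3"))        # ('1.2.3', None)
print(pep440_split_post("1.2.3.post2"))  # ('1.2.3', 2)

# With closest-tag "1.2.3.post2" and 4 commits since the tag, the
# "pep440-pre" style renders "1.2.3.post3.dev4" (post segment bumped).
tag, post = pep440_split_post("1.2.3.post2")
distance = 4
rendered = tag + (".post%d.dev%d" % (post + 1, distance) if post is not None
                  else ".post0.dev%d" % distance)
print(rendered)  # 1.2.3.post3.dev4
```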
@@ -873,7 +1108,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", @@ -887,10 +1122,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -905,7 +1144,7 @@ def render(pieces, style): "date": pieces.get("date")} -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some @@ -926,7 +1165,7 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: return {"version": "0+unknown", "full-revisionid": None, @@ -953,41 +1192,48 @@ def get_versions(): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. 
Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -1000,11 +1246,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1013,7 +1259,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) + tags = {r for r in refs if re.search(r'\d', r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1021,30 +1267,33 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue if verbose: print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, + root: str, + verbose: bool, + runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -1055,7 +1304,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1063,33 +1320,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -1098,16 +1379,17 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] + git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%s'" + % describe_out) return pieces # tag @@ -1116,12 +1398,10 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) + pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" + % (full_tag, tag_prefix)) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] + pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1132,19 +1412,20 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def do_vcs_install(manifest_in, versionfile_source, ipy): +def do_vcs_install(versionfile_source: str, ipy: Optional[str]) -> None: """Git-specific installation logic for Versioneer. For Git, this means creating/changing .gitattributes to mark _version.py @@ -1153,36 +1434,40 @@ def do_vcs_install(manifest_in, versionfile_source, ipy): GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] + files = [versionfile_source] if ipy: files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) + if "VERSIONEER_PEP518" not in globals(): + try: + my_path = __file__ + if my_path.endswith((".pyc", ".pyo")): + my_path = os.path.splitext(my_path)[0] + ".py" + versioneer_file = os.path.relpath(my_path) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) present = False try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: + with open(".gitattributes", "r") as fobj: + for line in fobj: + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + break + except OSError: pass if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() + with open(".gitattributes", "a+") as fobj: + fobj.write(f"{versionfile_source} export-subst\n") files.append(".gitattributes") run_command(GITS, ["add", "--"] + files) -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from 
the parent directory name. Source tarballs conventionally unpack into a directory that includes both @@ -1191,30 +1476,23 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) + print("Tried directories %s but none started with prefix %s" % + (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from +# This file was generated by 'versioneer.py' (0.29) from # revision-control system data, or from the parent directory name of an # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. @@ -1231,43 +1509,41 @@ def get_versions(): """ -def versions_from_file(filename): +def versions_from_file(filename: str) -> Dict[str, Any]: """Try to determine the version from _version.py if present.""" try: with open(filename) as f: contents = f.read() - except EnvironmentError: + except OSError: raise NotThisMethod("unable to read _version.py") - mo = re.search( - r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) + mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) if not mo: - mo = re.search( - r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S - ) + mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, re.M | re.S) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) -def write_to_version_file(filename, versions): +def write_to_version_file(filename: str, versions: Dict[str, Any]) -> None: """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, + indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) print("set %s to '%s'" % (filename, versions["version"])) -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -1285,29 +1561,78 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. 
+def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -1334,12 +1659,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -1356,7 +1710,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. 
@@ -1376,7 +1730,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -1396,26 +1750,28 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -1425,20 +1781,16 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} class VersioneerBadRootError(Exception): """The project root directory is unknown or missing key files.""" -def get_versions(verbose=False): +def get_versions(verbose: bool = False) -> Dict[str, Any]: """Get the project version from whatever source is available. Returns dict with two keys: 'version' and 'full'. @@ -1453,10 +1805,9 @@ def get_versions(verbose=False): assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert ( - cfg.versionfile_source is not None - ), "please set versioneer.versionfile_source" + verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` + assert cfg.versionfile_source is not None, \ + "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1510,22 +1861,22 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, "error": "unable to compute version", + "date": None} -def get_version(): +def get_version() -> str: """Get the short version string for this project.""" return get_versions()["version"] -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" +def get_cmdclass(cmdclass: Optional[Dict[str, Any]] = None): + """Get the custom setuptools subclasses used by Versioneer. + + If the package uses a different cmdclass (e.g. one from numpy), it + should be provide as an argument. 
+ """ if "versioneer" in sys.modules: del sys.modules["versioneer"] # this fixes the "python setup.py develop" case (also 'install' and @@ -1539,25 +1890,25 @@ def get_cmdclass(): # parent is protected against the child's "import versioneer". By # removing ourselves from sys.modules here, before the child build # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 + # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - cmds = {} + cmds = {} if cmdclass is None else cmdclass.copy() - # we add "version" to both distutils and setuptools - from distutils.core import Command + # we add "version" to setuptools + from setuptools import Command class cmd_version(Command): description = "report generated version string" - user_options = [] - boolean_options = [] + user_options: List[Tuple[str, str, str]] = [] + boolean_options: List[str] = [] - def initialize_options(self): + def initialize_options(self) -> None: pass - def finalize_options(self): + def finalize_options(self) -> None: pass - def run(self): + def run(self) -> None: vers = get_versions(verbose=True) print("Version: %s" % vers["version"]) print(" full-revisionid: %s" % vers.get("full-revisionid")) @@ -1565,10 +1916,9 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - # we override "build_py" in both distutils and setuptools + # we override "build_py" in setuptools # # most invocation pathways end up running build_py: # distutils/build -> build_py @@ -1583,30 +1933,68 @@ def run(self): # then does setup.py bdist_wheel, or sometimes setup.py install # setup.py egg_info -> ? + # pip install -e . and setuptool/editable_wheel will invoke build_py + # but the build_py command is not expected to copy any files. + # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py + if 'build_py' in cmds: + _build_py: Any = cmds['build_py'] else: - from distutils.command.build_py import build_py as _build_py + from setuptools.command.build_py import build_py as _build_py class cmd_build_py(_build_py): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() _build_py.run(self) + if getattr(self, "editable_mode", False): + # During editable installs `.py` and data files are + # not copied to build_lib + return # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe + if 'build_ext' in cmds: + _build_ext: Any = cmds['build_ext'] + else: + from setuptools.command.build_ext import build_ext as _build_ext + + class cmd_build_ext(_build_ext): + def run(self) -> None: + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_ext.run(self) + if self.inplace: + # build_ext --inplace will only build extensions in + # build/lib<..> dir with no _version.py to write to. 
+ # As in place builds will already have a _version.py + # in the module dir, we do not need to write one. + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if not cfg.versionfile_build: + return + target_versionfile = os.path.join(self.build_lib, + cfg.versionfile_build) + if not os.path.exists(target_versionfile): + print(f"Warning: {target_versionfile} does not exist, skipping " + "version update. This can happen if you are running build_ext " + "without first running build_py.") + return + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + cmds["build_ext"] = cmd_build_ext + if "cx_Freeze" in sys.modules: # cx_freeze enabled? + from cx_Freeze.dist import build_exe as _build_exe # type: ignore # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1615,7 +2003,7 @@ def run(self): # ... class cmd_build_exe(_build_exe): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() @@ -1627,28 +2015,24 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if "py2exe" in sys.modules: # py2exe enabled? + if 'py2exe' in sys.modules: # py2exe enabled? 
try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 + from py2exe.setuptools_buildexe import py2exe as _py2exe # type: ignore except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 + from py2exe.distutils_buildexe import py2exe as _py2exe # type: ignore class cmd_py2exe(_py2exe): - def run(self): + def run(self) -> None: root = get_root() cfg = get_config_from_root(root) versions = get_versions() @@ -1660,27 +2044,60 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - + f.write(LONG % + {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) cmds["py2exe"] = cmd_py2exe + # sdist farms its file list building out to egg_info + if 'egg_info' in cmds: + _egg_info: Any = cmds['egg_info'] + else: + from setuptools.command.egg_info import egg_info as _egg_info + + class cmd_egg_info(_egg_info): + def find_sources(self) -> None: + # egg_info.find_sources builds the manifest list and writes it + # in one shot + super().find_sources() + + # Modify the filelist and normalize it + root = get_root() + cfg = get_config_from_root(root) + self.filelist.append('versioneer.py') + if cfg.versionfile_source: + # There are rare cases where versionfile_source might not be + # included by default, so we must be explicit + self.filelist.append(cfg.versionfile_source) + self.filelist.sort() + self.filelist.remove_duplicates() + + # The write method is hidden in the manifest_maker instance that + # generated the filelist and was thrown away + # We will instead replicate their final normalization (to unicode, + # and POSIX-style paths) + from setuptools import unicode_utils + normalized = [unicode_utils.filesys_decode(f).replace(os.sep, '/') + for f in self.filelist.files] + + manifest_filename = os.path.join(self.egg_info, 'SOURCES.txt') + with open(manifest_filename, 'w') as fobj: + fobj.write('\n'.join(normalized)) + + cmds['egg_info'] = cmd_egg_info + # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist + if 'sdist' in cmds: + _sdist: Any = cmds['sdist'] else: - from distutils.command.sdist import sdist as _sdist + from setuptools.command.sdist import sdist as _sdist class cmd_sdist(_sdist): - def run(self): + def run(self) -> None: versions = get_versions() self._versioneer_generated_versions = versions # unless we update this, the command will keep using the old @@ -1688,7 +2105,7 @@ def run(self): self.distribution.metadata.version = versions["version"] return _sdist.run(self) - def make_release_tree(self, base_dir, files): + def make_release_tree(self, base_dir: str, files: List[str]) -> None: root = get_root() cfg = get_config_from_root(root) _sdist.make_release_tree(self, base_dir, files) @@ -1697,10 +2114,8 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file( - target_versionfile, self._versioneer_generated_versions - ) - + write_to_version_file(target_versionfile, + self._versioneer_generated_versions) cmds["sdist"] = cmd_sdist return cmds @@ -1743,25 
+2158,28 @@ def make_release_tree(self, base_dir, files): """ -INIT_PY_SNIPPET = """ +OLD_SNIPPET = """ from ._version import get_versions __version__ = get_versions()['version'] del get_versions """ +INIT_PY_SNIPPET = """ +from . import {0} +__version__ = {0}.get_versions()['version'] +""" + -def do_setup(): - """Main VCS-independent setup function for installing Versioneer.""" +def do_setup() -> int: + """Do main VCS-independent setup function for installing Versioneer.""" root = get_root() try: cfg = get_config_from_root(root) - except ( - EnvironmentError, - configparser.NoSectionError, - configparser.NoOptionError, - ) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", file=sys.stderr) + except (OSError, configparser.NoSectionError, + configparser.NoOptionError) as e: + if isinstance(e, (OSError, configparser.NoSectionError)): + print("Adding sample versioneer config to setup.cfg", + file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -1770,76 +2188,46 @@ def do_setup(): print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") + f.write(LONG % {"DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + }) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), + "__init__.py") + maybe_ipy: Optional[str] = ipy if os.path.exists(ipy): try: with open(ipy, "r") as f: old = f.read() - except EnvironmentError: + except OSError: old = "" - if INIT_PY_SNIPPET not in old: + module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] + snippet = INIT_PY_SNIPPET.format(module) + if OLD_SNIPPET in old: + print(" replacing boilerplate in %s" % ipy) + with open(ipy, "w") as f: + f.write(old.replace(OLD_SNIPPET, snippet)) + elif snippet not in old: print(" appending to %s" % ipy) with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) + f.write(snippet) else: print(" %s unmodified" % ipy) else: print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except EnvironmentError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. 
- if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print( - " appending versionfile_source ('%s') to MANIFEST.in" - % cfg.versionfile_source - ) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") + maybe_ipy = None # Make VCS-specific changes. For git, this means creating/changing # .gitattributes to mark _version.py for export-subst keyword # substitution. - do_vcs_install(manifest_in, cfg.versionfile_source, ipy) + do_vcs_install(cfg.versionfile_source, maybe_ipy) return 0 -def scan_setup_py(): +def scan_setup_py() -> int: """Validate the contents of setup.py against Versioneer's expectations.""" found = set() setters = False @@ -1876,10 +2264,14 @@ def scan_setup_py(): return errors +def setup_command() -> NoReturn: + """Set up Versioneer and exit with appropriate error code.""" + errors = do_setup() + errors += scan_setup_py() + sys.exit(1 if errors else 0) + + if __name__ == "__main__": cmd = sys.argv[1] if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1) + setup_command()