2727===========================================================================
2828
2929"""
30- from numpy import zeros , empty , nan , arange , interp , where , unique , \
30+ from numpy import zeros , empty , nan , arange , where , unique , \
3131 ma , concatenate , cos , radians , isnan , ones , ndarray , meshgrid , \
32- bincount , bool_ , array , interp , int_ , int32 , round , maximum , floor
32+ array , interp , int_ , int32 , round , maximum , floor
3333from scipy .interpolate import interp1d
3434from netCDF4 import Dataset
3535from py_eddy_tracker .tools import distance_matrix , distance_vector
3838from . import VAR_DESCR , VAR_DESCR_inv
3939import logging
4040from datetime import datetime
41- from scipy .interpolate import RectBivariateSpline
4241
4342
4443class GridDataset (object ):
@@ -185,6 +184,10 @@ def __init__(self, size=0, track_extra_variables=None,
185184 self .active = True
186185 self .sign_type = None
187186
187+ @property
188+ def sign_legend (self ):
189+ return 'Cyclonic' if self .sign_type == - 1 else 'Anticyclonic'
190+
188191 @property
189192 def shape (self ):
190193 return self .observations .shape
@@ -222,7 +225,7 @@ def elements(self):
222225
223226 if len (self .track_extra_variables ):
224227 elements += self .track_extra_variables
225- return elements
228+ return list ( set ( elements ))
226229
227230 def coherence (self , other ):
228231 """Check coherence between two dataset
@@ -341,6 +344,8 @@ def index(self, index):
341344 @classmethod
342345 def load_from_netcdf (cls , filename ):
343346 array_dim = 'NbSample'
347+ if not isinstance (filename , str ):
348+ filename = filename .astype (str )
344349 with Dataset (filename ) as h_nc :
345350 nb_obs = len (h_nc .dimensions ['Nobs' ])
346351 kwargs = dict ()
@@ -386,6 +391,53 @@ def from_netcdf(cls, handler):
386391 eddies .obs [VAR_DESCR_inv [variable ]] = handler .variables [variable ][:]
387392 return eddies
388393
394+ @staticmethod
395+ def propagate (previous_obs , current_obs , obs_to_extend , dead_track , nb_next , model ):
396+ """
397+ Filled virtual obs (C)
398+ Args:
399+ previous_obs: previous obs from current (A)
400+ current_obs: previous obs from virtual (B)
401+ obs_to_extend:
402+ dead_track:
403+ nb_next:
404+ model:
405+
406+ Returns:
407+ New position C = B + AB
408+ """
409+ next_obs = VirtualEddiesObservations (
410+ size = nb_next ,
411+ track_extra_variables = model .track_extra_variables ,
412+ track_array_variables = model .track_array_variables ,
413+ array_variables = model .array_variables )
414+ nb_dead = len (previous_obs )
415+ nb_virtual_extend = nb_next - nb_dead
416+
417+ for key in model .elements :
418+ if key in ['lon' , 'lat' , 'time' ] or 'contour_' in key :
419+ continue
420+ next_obs [key ][:nb_dead ] = current_obs [key ]
421+ next_obs ['dlon' ][:nb_dead ] = current_obs ['lon' ] - previous_obs ['lon' ]
422+ next_obs ['dlat' ][:nb_dead ] = current_obs ['lat' ] - previous_obs ['lat' ]
423+ next_obs ['lon' ][:nb_dead ] = current_obs ['lon' ] + next_obs ['dlon' ][:nb_dead ]
424+ next_obs ['lat' ][:nb_dead ] = current_obs ['lat' ] + next_obs ['dlat' ][:nb_dead ]
425+ # Id which are extended
426+ next_obs ['track' ][:nb_dead ] = dead_track
427+ # Add previous virtual
428+ if nb_virtual_extend > 0 :
429+ for key in next_obs .elements :
430+ if key in ['lon' , 'lat' , 'time' , 'track' , 'segment_size' ] or 'contour_' in key :
431+ continue
432+ next_obs [key ][nb_dead :] = obs_to_extend [key ]
433+ next_obs ['lon' ][nb_dead :] = obs_to_extend ['lon' ] + obs_to_extend ['dlon' ]
434+ next_obs ['lat' ][nb_dead :] = obs_to_extend ['lat' ] + obs_to_extend ['dlat' ]
435+ next_obs ['track' ][nb_dead :] = obs_to_extend ['track' ]
436+ next_obs ['segment_size' ][nb_dead :] = obs_to_extend ['segment_size' ]
437+ # Count
438+ next_obs ['segment_size' ][:] += 1
439+ return next_obs
440+
389441 @staticmethod
390442 def cost_function2 (records_in , records_out , distance ):
391443 nb_records = records_in .shape [0 ]
@@ -715,8 +767,7 @@ def write_netcdf(self, path='./', filename='%(path)s/%(sign_type)s.nc'):
715767 """Write a netcdf with eddy obs
716768 """
717769 eddy_size = len (self .observations )
718- sign_type = 'Cyclonic' if self .sign_type == - 1 else 'Anticyclonic'
719- filename = filename % dict (path = path , sign_type = sign_type , prod_time = datetime .now ().strftime ('%Y%m%d' ))
770+ filename = filename % dict (path = path , sign_type = self .sign_legend , prod_time = datetime .now ().strftime ('%Y%m%d' ))
720771 logging .info ('Store in %s' , filename )
721772 with Dataset (filename , 'w' , format = 'NETCDF4' ) as h_nc :
722773 logging .info ('Create file %s' , filename )
@@ -763,7 +814,7 @@ class VirtualEddiesObservations(EddiesObservations):
763814 def elements (self ):
764815 elements = super (VirtualEddiesObservations , self ).elements
765816 elements .extend (['track' , 'segment_size' , 'dlon' , 'dlat' ])
766- return elements
817+ return list ( set ( elements ))
767818
768819
769820class TrackEddiesObservations (EddiesObservations ):
@@ -783,8 +834,16 @@ def filled_by_interpolation(self, mask):
783834 var = field [0 ]
784835 if var in ['n' , 'virtual' , 'track' ] or var in self .array_variables :
785836 continue
786- self .obs [var ][mask ] = interp (index [mask ], index [~ mask ],
787- self .obs [var ][~ mask ])
837+ # to normalize longitude before interpolation
838+ if var == 'lon' :
839+ lon = self .obs [var ]
840+ first = where (self .obs ['n' ] == 0 )[0 ]
841+ nb_obs = empty (first .shape , dtype = 'u4' )
842+ nb_obs [:- 1 ] = first [1 :] - first [:- 1 ]
843+ nb_obs [- 1 ] = lon .shape [0 ] - first [- 1 ]
844+ lon0 = (lon [first ] - 180 ).repeat (nb_obs )
845+ self .obs [var ] = (lon - lon0 ) % 360 + lon0
846+ self .obs [var ][mask ] = interp (index [mask ], index [~ mask ], self .obs [var ][~ mask ])
788847
789848 def extract_longer_eddies (self , nb_min , nb_obs , compress_id = True ):
790849 """Select eddies which are longer than nb_min
0 commit comments