2828=========================================================================== 
2929
3030""" 
31- from  numpy  import  zeros , empty , nan , arange , interp , where , unique 
31+ from  numpy  import  zeros , empty , nan , arange , interp , where , unique , \
32+     ma 
3233from  netCDF4  import  Dataset 
3334from  py_eddy_tracker .tools  import  distance_matrix 
3435from  . import  VAR_DESCR , VAR_DESCR_inv 
@@ -82,9 +83,11 @@ def __getitem__(self, attr):
8283        if  attr  in  self .elements :
8384            return  self .observations [attr ]
8485        raise  KeyError ('%s unknown'  %  attr )
85-      
86+ 
8687    @property  
8788    def  dtype (self ):
89+         """Return dtype to build numpy array 
90+         """ 
8891        dtype  =  list ()
8992        for  elt  in  self .elements :
9093            dtype .append ((elt , VAR_DESCR [elt ][
@@ -94,6 +97,8 @@ def dtype(self):
9497
9598    @property  
9699    def  elements (self ):
100+         """Return all variable name 
101+         """ 
97102        elements  =  [
98103            'lon' ,  # 'centlon' 
99104            'lat' ,  # 'centlat' 
@@ -113,9 +118,13 @@ def elements(self):
113118        return  elements 
114119
115120    def  coherence (self , other ):
121+         """Check coherence between two dataset 
122+         """ 
116123        return  self .track_extra_variables  ==  other .track_extra_variables 
117-      
124+ 
118125    def  merge (self , other ):
126+         """Merge two dataset 
127+         """ 
119128        nb_obs_self  =  len (self )
120129        nb_obs  =  nb_obs_self  +  len (other )
121130        eddies  =  self .__class__ (size = nb_obs )
@@ -126,6 +135,8 @@ def merge(self, other):
126135
127136    @property  
128137    def  obs (self ):
138+         """returan array observations 
139+         """ 
129140        return  self .observations 
130141
131142    def  __len__ (self ):
@@ -136,6 +147,8 @@ def __iter__(self):
136147            yield  obs 
137148
138149    def  insert_observations (self , other , index ):
150+         """Insert other obs in self at the index 
151+         """ 
139152        if  not  self .coherence (other ):
140153            raise  Exception ('Observations with no coherence' )
141154        insert_size  =  len (other .obs )
@@ -156,6 +169,8 @@ def insert_observations(self, other, index):
156169        return  self 
157170
158171    def  append (self , other ):
172+         """Merge 
173+         """ 
159174        return  self  +  other 
160175
161176    def  __add__ (self , other ):
@@ -172,13 +187,15 @@ def distance(self, other):
172187        return  dist_result 
173188
174189    def  index (self , index ):
190+         """Return obs from self at the index 
191+         """ 
175192        size  =  1 
176193        if  hasattr (index , '__iter__' ):
177194            size  =  len (index )
178195        eddies  =  self .__class__ (size , self .track_extra_variables )
179196        eddies .obs [:] =  self .obs [index ]
180197        return  eddies 
181-      
198+ 
182199    @staticmethod  
183200    def  load_from_netcdf (filename ):
184201        with  Dataset (filename ) as  h_nc :
@@ -187,22 +204,79 @@ def load_from_netcdf(filename):
187204            for  variable  in  h_nc .variables :
188205                if  variable  ==  'cyc' :
189206                    continue 
190-                 eddies .obs [VAR_DESCR_inv [variable ]] =  h_nc .variables [variable ][:]
207+                 eddies .obs [VAR_DESCR_inv [variable ]
208+                            ] =  h_nc .variables [variable ][:]
191209            eddies .sign_type  =  h_nc .variables ['cyc' ][0 ]
192210        return  eddies 
193211
194212    def  tracking (self , other ):
195-         """Track obs between from  self to  other 
213+         """Track obs between self and  other 
196214        """ 
197-         dist  =  self .distance (other )
198-         i_self , i_other  =  where (dist  <  20. )
215+         cost  =  self .distance (other )
216+         # Links available which respect a maximal cost 
217+         mask_accept_cost  =  cost  <  75 
218+         cost  =  ma .array (cost , mask = - mask_accept_cost , dtype = 'i2' )
219+ 
220+         # Count number of link by self obs and other obs 
221+         self_links  =  mask_accept_cost .sum (axis = 1 )
222+         other_links  =  mask_accept_cost .sum (axis = 0 )
223+         max_links  =  max (self_links .max (), other_links .max ())
224+         if  max_links  >  5 :
225+             logging .warning ('One observation have %d links' , max_links )
226+ 
227+         # If some obs have multiple link, we keep only one link by eddy 
228+         eddies_separation  =  1  <  self_links 
229+         eddies_merge  =  1  <  other_links 
230+         test  =  eddies_separation .any () or  eddies_merge .any ()
231+         if  test :
232+             # We extract matrix which contains concflict 
233+             obs_linking_to_self  =  mask_accept_cost [eddies_separation 
234+                                                    ].any (axis = 0 )
235+             obs_linking_to_other  =  mask_accept_cost [:, eddies_merge 
236+                                                     ].any (axis = 1 )
237+             i_self_keep  =  where (obs_linking_to_other  +  eddies_separation )[0 ]
238+             i_other_keep  =  where (obs_linking_to_self  +  eddies_merge )[0 ]
239+ 
240+             # Cost to resolve conflict 
241+             cost_reduce  =  cost [i_self_keep ][:, i_other_keep ]
242+             shape  =  cost_reduce .shape 
243+             logging .debug ('Shape conflict matrix : %s' , shape )
244+ 
245+             matrix_size  =  shape [0 ] *  shape [1 ]
246+             if  (matrix_size ) >=  20000 :
247+                 logging .warning ('High number of conflict : %d (matrix_size)' ,
248+                                 matrix_size )
249+ 
250+             links_resolve  =  0 
251+             while  False  in  cost_reduce .mask :
252+                 i_min_value  =  cost_reduce .argmin ()
253+                 i , j  =  i_min_value  /  shape [1 ], i_min_value  %  shape [1 ]
254+                 # Set to False all link 
255+                 mask_accept_cost [i_self_keep [i ]] =  False 
256+                 mask_accept_cost [:, i_other_keep [j ]] =  False 
257+                 cost_reduce .mask [i ] =  True 
258+                 cost_reduce .mask [:, j ] =  True 
259+                 # we active only this link 
260+                 mask_accept_cost [i_self_keep [i ], i_other_keep [j ]] =  True 
261+                 links_resolve  +=  1 
262+             logging .debug ('%d links resolve' , links_resolve )
263+ 
264+         i_self , i_other  =  where (mask_accept_cost )
199265
200266        logging .debug ('%d matched with previous' , i_self .shape [0 ])
267+ 
268+         # Check 
269+         if  unique (i_other ).shape [0 ] !=  i_other .shape [0 ]:
270+             raise  Exception ()
271+         if  unique (i_self ).shape [0 ] !=  i_self .shape [0 ]:
272+             raise  Exception ()
201273        return  i_self , i_other 
202274
203275
204276class  VirtualEddiesObservations (EddiesObservations ):
205-     
277+     """Class to work with virtual obs 
278+     """ 
279+ 
206280    @property  
207281    def  elements (self ):
208282        elements  =  super (VirtualEddiesObservations , self ).elements 
@@ -211,7 +285,9 @@ def elements(self):
211285
212286
213287class  TrackEddiesObservations (EddiesObservations ):
214-     
288+     """Class to practice Tracking on observations 
289+     """ 
290+ 
215291    def  filled_by_interpolation (self , mask ):
216292        """Filled selected values by interpolation 
217293        """ 
@@ -228,42 +304,47 @@ def filled_by_interpolation(self, mask):
228304                                         self .obs [var ][- mask ])
229305
230306    def  extract_longer_eddies (self , nb_min , nb_obs , compress_id = True ):
231-         m  =  nb_obs  >=  nb_min 
232-         nb_obs_select  =  m .sum ()
307+         """Select eddies which are longer than nb_min 
308+         """ 
309+         mask  =  nb_obs  >=  nb_min 
310+         nb_obs_select  =  mask .sum ()
233311        logging .info ('Selection of %d observations' , nb_obs_select )
234312        eddies  =  TrackEddiesObservations (size = nb_obs_select )
235313        eddies .sign_type  =  self .sign_type 
236314        for  var , _  in  eddies .obs .dtype .descr :
237-             eddies .obs [var ] =  self .obs [var ][m ]
315+             eddies .obs [var ] =  self .obs [var ][mask ]
238316        if  compress_id :
239317            list_id  =  unique (eddies .obs ['track' ])
240318            list_id .sort ()
241319            id_translate  =  arange (list_id .max () +  1 )
242320            id_translate [list_id ] =  arange (len (list_id )) +  1 
243321            eddies .obs ['track' ] =  id_translate [eddies .obs ['track' ]]
244322        return  eddies 
245-      
323+ 
246324    @property  
247325    def  elements (self ):
248326        elements  =  super (TrackEddiesObservations , self ).elements 
249327        elements .extend (['track' , 'n' , 'virtual' ])
250328        return  elements 
251-     
252-     def  create_variable (self , handler_nc , kwargs_variable ,
329+ 
330+     @staticmethod  
331+     def  create_variable (handler_nc , kwargs_variable ,
253332                        attr_variable , data , scale_factor = None ):
333+         """Create variable 
334+         """ 
254335        var  =  handler_nc .createVariable (
255336            zlib = True ,
256337            complevel = 1 ,
257338            ** kwargs_variable )
258339        for  attr , attr_value  in  attr_variable .iteritems ():
259340            var .setncattr (attr , attr_value )
260-              
341+ 
261342        var [:] =  data 
262-          
343+ 
263344        #~ var.set_auto_maskandscale(False) 
264345        if  scale_factor  is  not None :
265346            var .scale_factor  =  scale_factor 
266-              
347+ 
267348        try :
268349            var .setncattr ('min' , var [:].min ())
269350            var .setncattr ('max' , var [:].max ())
@@ -291,7 +372,10 @@ def write_netcdf(self):
291372                         dimensions = VAR_DESCR [name ]['nc_dims' ]),
292373                    VAR_DESCR [name ]['nc_attr' ],
293374                    self .observations [name ],
294-                     scale_factor = None  if  'scale_factor'  not  in VAR_DESCR [name ] else  VAR_DESCR [name ]['scale_factor' ])
375+                     scale_factor = None 
376+                         if  'scale_factor'  not  in VAR_DESCR [name ] else 
377+                         VAR_DESCR [name ]['scale_factor' ]
378+                     )
295379
296380            # Add cyclonic information 
297381            self .create_variable (
@@ -305,7 +389,12 @@ def write_netcdf(self):
305389            self .set_global_attr_netcdf (h_nc )
306390
307391    def  set_global_attr_netcdf (self , h_nc ):
308-         h_nc .title  =  'Cyclonic'  if  self .sign_type  ==  - 1  else  'Anticyclonic'  +  ' eddy tracks' 
392+         """Set global attr 
393+         """ 
394+         if  self .sign_type  ==  - 1 :
395+             h_nc .title  =  'Cyclonic' 
396+         else :
397+             h_nc .title  =  'Anticyclonic'  +  ' eddy tracks' 
309398        #~ h_nc.grid_filename = self.grd.grid_filename 
310399        #~ h_nc.grid_date = str(self.grd.grid_date) 
311400        #~ h_nc.product = self.product 
@@ -324,7 +413,7 @@ def set_global_attr_netcdf(self, h_nc):
324413        #~ h_nc.evolve_amp_max = self.evolve_amp_max 
325414        #~ h_nc.evolve_area_min = self.evolve_area_min 
326415        #~ h_nc.evolve_area_max = self.evolve_area_max 
327- #~  
416+ 
328417        #~ h_nc.llcrnrlon = self.grd.lonmin 
329418        #~ h_nc.urcrnrlon = self.grd.lonmax 
330419        #~ h_nc.llcrnrlat = self.grd.latmin 
0 commit comments