2828===========================================================================
2929
3030"""
31- from numpy import zeros , empty , nan , arange , interp , where , unique
31+ from numpy import zeros , empty , nan , arange , interp , where , unique , \
32+ ma
3233from netCDF4 import Dataset
3334from py_eddy_tracker .tools import distance_matrix
3435from . import VAR_DESCR , VAR_DESCR_inv
@@ -82,9 +83,11 @@ def __getitem__(self, attr):
8283 if attr in self .elements :
8384 return self .observations [attr ]
8485 raise KeyError ('%s unknown' % attr )
85-
86+
8687 @property
8788 def dtype (self ):
89+ """Return dtype to build numpy array
90+ """
8891 dtype = list ()
8992 for elt in self .elements :
9093 dtype .append ((elt , VAR_DESCR [elt ][
@@ -94,6 +97,8 @@ def dtype(self):
9497
9598 @property
9699 def elements (self ):
100+ """Return all variable name
101+ """
97102 elements = [
98103 'lon' , # 'centlon'
99104 'lat' , # 'centlat'
@@ -113,9 +118,13 @@ def elements(self):
113118 return elements
114119
115120 def coherence (self , other ):
121+ """Check coherence between two dataset
122+ """
116123 return self .track_extra_variables == other .track_extra_variables
117-
124+
118125 def merge (self , other ):
126+ """Merge two dataset
127+ """
119128 nb_obs_self = len (self )
120129 nb_obs = nb_obs_self + len (other )
121130 eddies = self .__class__ (size = nb_obs )
@@ -126,6 +135,8 @@ def merge(self, other):
126135
127136 @property
128137 def obs (self ):
138+ """returan array observations
139+ """
129140 return self .observations
130141
131142 def __len__ (self ):
@@ -136,6 +147,8 @@ def __iter__(self):
136147 yield obs
137148
138149 def insert_observations (self , other , index ):
150+ """Insert other obs in self at the index
151+ """
139152 if not self .coherence (other ):
140153 raise Exception ('Observations with no coherence' )
141154 insert_size = len (other .obs )
@@ -156,6 +169,8 @@ def insert_observations(self, other, index):
156169 return self
157170
158171 def append (self , other ):
172+ """Merge
173+ """
159174 return self + other
160175
161176 def __add__ (self , other ):
@@ -172,13 +187,15 @@ def distance(self, other):
172187 return dist_result
173188
174189 def index (self , index ):
190+ """Return obs from self at the index
191+ """
175192 size = 1
176193 if hasattr (index , '__iter__' ):
177194 size = len (index )
178195 eddies = self .__class__ (size , self .track_extra_variables )
179196 eddies .obs [:] = self .obs [index ]
180197 return eddies
181-
198+
182199 @staticmethod
183200 def load_from_netcdf (filename ):
184201 with Dataset (filename ) as h_nc :
@@ -187,22 +204,79 @@ def load_from_netcdf(filename):
187204 for variable in h_nc .variables :
188205 if variable == 'cyc' :
189206 continue
190- eddies .obs [VAR_DESCR_inv [variable ]] = h_nc .variables [variable ][:]
207+ eddies .obs [VAR_DESCR_inv [variable ]
208+ ] = h_nc .variables [variable ][:]
191209 eddies .sign_type = h_nc .variables ['cyc' ][0 ]
192210 return eddies
193211
194212 def tracking (self , other ):
195- """Track obs between from self to other
213+ """Track obs between self and other
196214 """
197- dist = self .distance (other )
198- i_self , i_other = where (dist < 20. )
215+ cost = self .distance (other )
216+ # Links available which respect a maximal cost
217+ mask_accept_cost = cost < 75
218+ cost = ma .array (cost , mask = - mask_accept_cost , dtype = 'i2' )
219+
220+ # Count number of link by self obs and other obs
221+ self_links = mask_accept_cost .sum (axis = 1 )
222+ other_links = mask_accept_cost .sum (axis = 0 )
223+ max_links = max (self_links .max (), other_links .max ())
224+ if max_links > 5 :
225+ logging .warning ('One observation have %d links' , max_links )
226+
227+ # If some obs have multiple link, we keep only one link by eddy
228+ eddies_separation = 1 < self_links
229+ eddies_merge = 1 < other_links
230+ test = eddies_separation .any () or eddies_merge .any ()
231+ if test :
232+ # We extract matrix which contains concflict
233+ obs_linking_to_self = mask_accept_cost [eddies_separation
234+ ].any (axis = 0 )
235+ obs_linking_to_other = mask_accept_cost [:, eddies_merge
236+ ].any (axis = 1 )
237+ i_self_keep = where (obs_linking_to_other + eddies_separation )[0 ]
238+ i_other_keep = where (obs_linking_to_self + eddies_merge )[0 ]
239+
240+ # Cost to resolve conflict
241+ cost_reduce = cost [i_self_keep ][:, i_other_keep ]
242+ shape = cost_reduce .shape
243+ logging .debug ('Shape conflict matrix : %s' , shape )
244+
245+ matrix_size = shape [0 ] * shape [1 ]
246+ if (matrix_size ) >= 20000 :
247+ logging .warning ('High number of conflict : %d (matrix_size)' ,
248+ matrix_size )
249+
250+ links_resolve = 0
251+ while False in cost_reduce .mask :
252+ i_min_value = cost_reduce .argmin ()
253+ i , j = i_min_value / shape [1 ], i_min_value % shape [1 ]
254+ # Set to False all link
255+ mask_accept_cost [i_self_keep [i ]] = False
256+ mask_accept_cost [:, i_other_keep [j ]] = False
257+ cost_reduce .mask [i ] = True
258+ cost_reduce .mask [:, j ] = True
259+ # we active only this link
260+ mask_accept_cost [i_self_keep [i ], i_other_keep [j ]] = True
261+ links_resolve += 1
262+ logging .debug ('%d links resolve' , links_resolve )
263+
264+ i_self , i_other = where (mask_accept_cost )
199265
200266 logging .debug ('%d matched with previous' , i_self .shape [0 ])
267+
268+ # Check
269+ if unique (i_other ).shape [0 ] != i_other .shape [0 ]:
270+ raise Exception ()
271+ if unique (i_self ).shape [0 ] != i_self .shape [0 ]:
272+ raise Exception ()
201273 return i_self , i_other
202274
203275
204276class VirtualEddiesObservations (EddiesObservations ):
205-
277+ """Class to work with virtual obs
278+ """
279+
206280 @property
207281 def elements (self ):
208282 elements = super (VirtualEddiesObservations , self ).elements
@@ -211,7 +285,9 @@ def elements(self):
211285
212286
213287class TrackEddiesObservations (EddiesObservations ):
214-
288+ """Class to practice Tracking on observations
289+ """
290+
215291 def filled_by_interpolation (self , mask ):
216292 """Filled selected values by interpolation
217293 """
@@ -228,42 +304,47 @@ def filled_by_interpolation(self, mask):
228304 self .obs [var ][- mask ])
229305
230306 def extract_longer_eddies (self , nb_min , nb_obs , compress_id = True ):
231- m = nb_obs >= nb_min
232- nb_obs_select = m .sum ()
307+ """Select eddies which are longer than nb_min
308+ """
309+ mask = nb_obs >= nb_min
310+ nb_obs_select = mask .sum ()
233311 logging .info ('Selection of %d observations' , nb_obs_select )
234312 eddies = TrackEddiesObservations (size = nb_obs_select )
235313 eddies .sign_type = self .sign_type
236314 for var , _ in eddies .obs .dtype .descr :
237- eddies .obs [var ] = self .obs [var ][m ]
315+ eddies .obs [var ] = self .obs [var ][mask ]
238316 if compress_id :
239317 list_id = unique (eddies .obs ['track' ])
240318 list_id .sort ()
241319 id_translate = arange (list_id .max () + 1 )
242320 id_translate [list_id ] = arange (len (list_id )) + 1
243321 eddies .obs ['track' ] = id_translate [eddies .obs ['track' ]]
244322 return eddies
245-
323+
246324 @property
247325 def elements (self ):
248326 elements = super (TrackEddiesObservations , self ).elements
249327 elements .extend (['track' , 'n' , 'virtual' ])
250328 return elements
251-
252- def create_variable (self , handler_nc , kwargs_variable ,
329+
330+ @staticmethod
331+ def create_variable (handler_nc , kwargs_variable ,
253332 attr_variable , data , scale_factor = None ):
333+ """Create variable
334+ """
254335 var = handler_nc .createVariable (
255336 zlib = True ,
256337 complevel = 1 ,
257338 ** kwargs_variable )
258339 for attr , attr_value in attr_variable .iteritems ():
259340 var .setncattr (attr , attr_value )
260-
341+
261342 var [:] = data
262-
343+
263344 #~ var.set_auto_maskandscale(False)
264345 if scale_factor is not None :
265346 var .scale_factor = scale_factor
266-
347+
267348 try :
268349 var .setncattr ('min' , var [:].min ())
269350 var .setncattr ('max' , var [:].max ())
@@ -291,7 +372,10 @@ def write_netcdf(self):
291372 dimensions = VAR_DESCR [name ]['nc_dims' ]),
292373 VAR_DESCR [name ]['nc_attr' ],
293374 self .observations [name ],
294- scale_factor = None if 'scale_factor' not in VAR_DESCR [name ] else VAR_DESCR [name ]['scale_factor' ])
375+ scale_factor = None
376+ if 'scale_factor' not in VAR_DESCR [name ] else
377+ VAR_DESCR [name ]['scale_factor' ]
378+ )
295379
296380 # Add cyclonic information
297381 self .create_variable (
@@ -305,7 +389,12 @@ def write_netcdf(self):
305389 self .set_global_attr_netcdf (h_nc )
306390
307391 def set_global_attr_netcdf (self , h_nc ):
308- h_nc .title = 'Cyclonic' if self .sign_type == - 1 else 'Anticyclonic' + ' eddy tracks'
392+ """Set global attr
393+ """
394+ if self .sign_type == - 1 :
395+ h_nc .title = 'Cyclonic'
396+ else :
397+ h_nc .title = 'Anticyclonic' + ' eddy tracks'
309398 #~ h_nc.grid_filename = self.grd.grid_filename
310399 #~ h_nc.grid_date = str(self.grd.grid_date)
311400 #~ h_nc.product = self.product
@@ -324,7 +413,7 @@ def set_global_attr_netcdf(self, h_nc):
324413 #~ h_nc.evolve_amp_max = self.evolve_amp_max
325414 #~ h_nc.evolve_area_min = self.evolve_area_min
326415 #~ h_nc.evolve_area_max = self.evolve_area_max
327- #~
416+
328417 #~ h_nc.llcrnrlon = self.grd.lonmin
329418 #~ h_nc.urcrnrlon = self.grd.lonmax
330419 #~ h_nc.llcrnrlat = self.grd.latmin
0 commit comments