2929"""
3030from matplotlib .dates import julian2num , num2date
3131
32- from py_eddy_tracker .observations import EddiesObservations , \
33- VirtualEddiesObservations , TrackEddiesObservations
34- from numpy import bool_ , array , arange , ones , setdiff1d , zeros , uint16 , \
35- where , empty
32+ from py_eddy_tracker .observations import EddiesObservations , VirtualEddiesObservations , TrackEddiesObservations
33+ from numpy import bool_ , array , arange , ones , setdiff1d , zeros , uint16 , where , empty , isin
3634from netCDF4 import Dataset
3735import logging
3836
@@ -237,17 +235,13 @@ def recense_dead_id_to_extend(self):
237235 # Get the ids of tracks which have already been dead for a few time steps
238236 nb_virtual_extend = 0
239237 if self .virtual_obs is not None :
240- virtual_dead_id = setdiff1d (self .virtual_obs ['track' ],
241- self [- 1 ]['id' ])
238+ virtual_dead_id = setdiff1d (self .virtual_obs ['track' ], self [- 1 ]['id' ])
242239 list_previous_virtual_id = self .virtual_obs ['track' ].tolist ()
243- i_virtual_dead_id = [
244- list_previous_virtual_id .index (i ) for i in virtual_dead_id ]
240+ i_virtual_dead_id = [list_previous_virtual_id .index (i ) for i in virtual_dead_id ]
245241 # Virtual obs which can be prolonged
246- alive_virtual_obs = self .virtual_obs ['segment_size'
247- ][i_virtual_dead_id ] < self .nb_virtual
242+ alive_virtual_obs = self .virtual_obs ['segment_size' ][i_virtual_dead_id ] < self .nb_virtual
248243 nb_virtual_extend = alive_virtual_obs .sum ()
249- logging .debug ('%d virtual obs will be prolongate on the '
250- 'next step' , nb_virtual_extend )
244+ logging .debug ('%d virtual obs will be prolongate on the next step' , nb_virtual_extend )
251245
252246 # Save previous state to count virtual obs
253247 self .previous_virtual_obs = self .virtual_obs
@@ -270,35 +264,27 @@ def recense_dead_id_to_extend(self):
270264 # Position N-1 : B
271265 # Virtual Position : C
272266 # New position C = B + AB
273- for key in obs_b .dtype .fields .keys ():
274- if key in ['lon' , 'lat' , 'time' , 'track' , 'segment_size' ,
275- 'dlon' , 'dlat' ] or 'contour_' in key :
267+ for key in self .previous_obs .elements :
268+ if key in ['lon' , 'lat' , 'time' ] or 'contour_' in key :
276269 continue
277270 self .virtual_obs [key ][:nb_dead ] = obs_b [key ]
278271 self .virtual_obs ['dlon' ][:nb_dead ] = obs_b ['lon' ] - obs_a ['lon' ]
279272 self .virtual_obs ['dlat' ][:nb_dead ] = obs_b ['lat' ] - obs_a ['lat' ]
280- self .virtual_obs ['lon' ][:nb_dead
281- ] = obs_b ['lon' ] + self .virtual_obs ['dlon' ][:nb_dead ]
282- self .virtual_obs ['lat' ][:nb_dead
283- ] = obs_b ['lat' ] + self .virtual_obs ['dlat' ][:nb_dead ]
273+ self .virtual_obs ['lon' ][:nb_dead ] = obs_b ['lon' ] + self .virtual_obs ['dlon' ][:nb_dead ]
274+ self .virtual_obs ['lat' ][:nb_dead ] = obs_b ['lat' ] + self .virtual_obs ['dlat' ][:nb_dead ]
284275 # Id which are extended
285276 self .virtual_obs ['track' ][:nb_dead ] = dead_id
286277 # Add previous virtual
287278 if nb_virtual_extend > 0 :
288- obs_to_extend = self .previous_virtual_obs .obs [i_virtual_dead_id
289- ][alive_virtual_obs ]
290- for key in obs_b .dtype .fields .keys ():
291- if key in ['lon' , 'lat' , 'time' , 'track' , 'segment_size' ,
292- 'dlon' , 'dlat' ] or 'contour_' in key :
279+ obs_to_extend = self .previous_virtual_obs .obs [i_virtual_dead_id ][alive_virtual_obs ]
280+ for key in self .virtual_obs .elements :
281+ if key in ['lon' , 'lat' , 'time' , 'track' , 'segment_size' ] or 'contour_' in key :
293282 continue
294283 self .virtual_obs [key ][nb_dead :] = obs_to_extend [key ]
295- self .virtual_obs ['lon' ][nb_dead :
296- ] = obs_to_extend ['lon' ] + obs_to_extend ['dlon' ]
297- self .virtual_obs ['lat' ][nb_dead :
298- ] = obs_to_extend ['lat' ] + obs_to_extend ['dlat' ]
284+ self .virtual_obs ['lon' ][nb_dead :] = obs_to_extend ['lon' ] + obs_to_extend ['dlon' ]
285+ self .virtual_obs ['lat' ][nb_dead :] = obs_to_extend ['lat' ] + obs_to_extend ['dlat' ]
299286 self .virtual_obs ['track' ][nb_dead :] = obs_to_extend ['track' ]
300- self .virtual_obs ['segment_size' ][nb_dead :
301- ] = obs_to_extend ['segment_size' ]
287+ self .virtual_obs ['segment_size' ][nb_dead :] = obs_to_extend ['segment_size' ]
302288 # Count
303289 self .virtual_obs ['segment_size' ][:] += 1
304290
@@ -335,11 +321,9 @@ def track(self):
335321
336322 nb_real_obs = len (self .previous_obs )
337323 if flg_virtual :
338- logging .debug ('%d virtual obs will be add to previous' ,
339- len (self .virtual_obs ))
324+ logging .debug ('%d virtual obs will be add to previous' , len (self .virtual_obs ))
340325 self .previous_obs = self .previous_obs .merge (self .virtual_obs )
341- i_previous , i_current = self .previous_obs .tracking (
342- self .current_obs )
326+ i_previous , i_current = self .previous_obs .tracking (self .current_obs )
343327
344328 # return true if the first time (previous2obs is none)
345329 if self .store_correspondance (i_previous , i_current , nb_real_obs ):
@@ -455,59 +439,63 @@ def load(cls, filename):
455439 def prepare_merging (self ):
456440 # count obs by tracks (we add directly one, because correspondance
457441 # is an interval)
458- self .nb_obs_by_tracks = zeros (self .current_id , dtype = self .N_DTYPE ) + 1
442+ self .nb_obs_by_tracks = ones (self .current_id , dtype = self .N_DTYPE )
459443 for correspondance in self :
460444 self .nb_obs_by_tracks [correspondance ['id' ]] += 1
461445 if self .virtual :
462446 # When start is virtual, we don't have a previous
463447 # correspondance
464- self .nb_obs_by_tracks [
465- correspondance ['id' ][correspondance ['virtual' ]]
466- ] += correspondance ['virtual_length' ][
467- correspondance ['virtual' ]]
448+ self .nb_obs_by_tracks [correspondance ['id' ][correspondance ['virtual' ]]
449+ ] += correspondance ['virtual_length' ][correspondance ['virtual' ]]
468450
469451 # Compute index of each tracks
470- self .i_current_by_tracks = \
471- self .nb_obs_by_tracks .cumsum () - self .nb_obs_by_tracks
452+ self .i_current_by_tracks = self .nb_obs_by_tracks .cumsum () - self .nb_obs_by_tracks
472453 # Number of global obs
473454 self .nb_obs = self .nb_obs_by_tracks .sum ()
474455 logging .info ('%d tracks identified' , self .current_id )
475456 logging .info ('%d observations will be join' , self .nb_obs )
476457
477- def merge (self , until = - 1 ):
458+ def merge (self , until = - 1 , size_min = None ):
478459 """Merge all the correspondance in one array with all fields
479460 """
480461 # Start loading identification again to save in the finals tracks
481462 # Load first file
463+ self .reset_dataset_cache ()
482464 self .swap_dataset (self .datasets [0 ])
483465
484466 # Start create netcdf to agglomerate all eddy
485467 logging .debug ('We will create an array (size %d)' , self .nb_obs )
468+ i_keep_track = slice (None )
469+ if size_min is not None :
470+ i_keep_track = where (self .nb_obs_by_tracks >= size_min )
471+ self .nb_obs_by_tracks = self .nb_obs_by_tracks [i_keep_track ]
472+ self .i_current_by_tracks [i_keep_track ] = self .nb_obs_by_tracks .cumsum () - self .nb_obs_by_tracks
473+ self .nb_obs = self .nb_obs_by_tracks .sum ()
474+ # current_id now counts only the tracks kept (TODO: confirm this renumbering is intended)
475+ self .current_id = self .nb_obs_by_tracks .shape [0 ]
486476 eddies = TrackEddiesObservations (
487477 size = self .nb_obs ,
488478 track_extra_variables = self .current_obs .track_extra_variables ,
489479 track_array_variables = self .current_obs .track_array_variables ,
490- array_variables = self .current_obs .array_variables ,
491- )
480+ array_variables = self .current_obs .array_variables )
492481
493482 # Calculate the index in each tracks, we compute in u4 and translate
494483 # in u2 (which are limited to 65535)
495484 logging .debug ('Compute global index array (N)' )
496485 eddies ['n' ][:] = uint16 (
497- arange (self .nb_obs , dtype = 'u4' )
498- - self .i_current_by_tracks .repeat (self .nb_obs_by_tracks ))
486+ arange (self .nb_obs , dtype = 'u4' ) - self .i_current_by_tracks [i_keep_track ].repeat (self .nb_obs_by_tracks ))
499487 logging .debug ('Compute global track array' )
500- eddies ['track' ][:] = arange (self .current_id
501- ).repeat (self .nb_obs_by_tracks )
488+ eddies ['track' ][:] = arange (self .current_id ).repeat (self .nb_obs_by_tracks )
489+ if size_min is not None :
490+ eddies ['track' ][:] += 1
502491
503492 # Set type of eddy with first file
504493 eddies .sign_type = self .current_obs .sign_type
505494 # Fields to copy
506495 fields = self .current_obs .obs .dtype .descr
507496
508497 # Flag, per track, to know whether its first observation has already been saved
509- first_obs_save_in_tracks = zeros (self .i_current_by_tracks .shape ,
510- dtype = bool_ )
498+ first_obs_save_in_tracks = zeros (self .i_current_by_tracks .shape , dtype = bool_ )
511499
512500 for i , file_name in enumerate (self .datasets [1 :]):
513501 if until != - 1 and i >= until :
@@ -517,19 +505,23 @@ def merge(self, until=-1):
517505 self .swap_dataset (file_name )
518506 # We select the list of ids which are involved in the correspondance
519507 i_id = self [i ]['id' ]
520- # Index where we will write in the final object
508+ if size_min is not None :
509+ m_id = isin (i_id , i_keep_track )
510+ i_id = i_id [m_id ]
511+ else :
512+ m_id = slice (None )
513+ # Index where we will write in the final object
521514 index_final = self .i_current_by_tracks [i_id ]
522515
523516 # First obs of eddies
524517 m_first_obs = ~ first_obs_save_in_tracks [i_id ]
525518 if m_first_obs .any ():
526519 # Index in the previous file
527- index_in = self [i ]['in' ][m_first_obs ]
520+ index_in = self [i ]['in' ][m_id ][ m_first_obs ]
528521 # Copy all variable
529522 for field in fields :
530523 var = field [0 ]
531- eddies [var ][index_final [m_first_obs ]
532- ] = self .previous_obs [var ][index_in ]
524+ eddies [var ][index_final [m_first_obs ]] = self .previous_obs [var ][index_in ]
533525 # Increment
534526 self .i_current_by_tracks [i_id [m_first_obs ]] += 1
535527 # Set this flag: there is only one first observation per track
@@ -539,22 +531,20 @@ def merge(self, until=-1):
539531 if self .virtual :
540532 # If the flag virtual in correspondance is active,
541533 # the previous is virtual
542- m_virtual = self [i ]['virtual' ]
534+ m_virtual = self [i ]['virtual' ][ m_id ]
543535 if m_virtual .any ():
544536 # Incrementing index
545- self .i_current_by_tracks [i_id [m_virtual ]
546- ] += self [i ]['virtual_length' ][m_virtual ]
537+ self .i_current_by_tracks [i_id [m_virtual ]] += self [i ]['virtual_length' ][m_id ][m_virtual ]
547538 # Get new index
548539 index_final = self .i_current_by_tracks [i_id ]
549540
550541 # Index in the current file
551- index_current = self [i ]['out' ]
542+ index_current = self [i ]['out' ][ m_id ]
552543
553544 # Copy all variable
554545 for field in fields :
555546 var = field [0 ]
556- eddies [var ][index_final
557- ] = self .current_obs [var ][index_current ]
547+ eddies [var ][index_final ] = self .current_obs [var ][index_current ]
558548
559549 # Add increment for each index used
560550 self .i_current_by_tracks [i_id ] += 1
0 commit comments