66
77import  gzip 
88import  json 
9+ import  logging 
910import  time 
1011import  warnings 
1112from  collections .abc  import  Iterable 
1718from  matplotlib  import  colors  as  mcolors 
1819from  pytorch3d .implicitron .tools .vis_utils  import  get_visdom_connection 
1920
21+ logger  =  logging .getLogger (__name__ )
22+ 
2023
2124class  AverageMeter (object ):
2225    """Computes and stores the average and current value""" 
@@ -91,7 +94,9 @@ class Stats(object):
9194                # stats.update() automatically parses the 'objective' and 'top1e' from 
9295                # the "output" dict and stores this into the db 
9396                stats.update(output) 
94-                 stats.print() # prints the averages over given epoch 
97+                 # prints the metric averages over given epoch 
98+                 std_out = stats.get_status_string() 
99+                 logger.info(str_out) 
95100            # stores the training plots into '/tmp/epoch_stats.pdf' 
96101            # and plots into a visdom server running at localhost (if running) 
97102            stats.plot_stats(plot_file='/tmp/epoch_stats.pdf') 
@@ -101,7 +106,6 @@ class Stats(object):
101106    def  __init__ (
102107        self ,
103108        log_vars ,
104-         verbose = False ,
105109        epoch = - 1 ,
106110        visdom_env = "main" ,
107111        do_plot = True ,
@@ -110,7 +114,6 @@ def __init__(
110114        visdom_port = 8097 ,
111115    ):
112116
113-         self .verbose  =  verbose 
114117        self .log_vars  =  log_vars 
115118        self .visdom_env  =  visdom_env 
116119        self .visdom_server  =  visdom_server 
@@ -156,32 +159,29 @@ def __exit__(self, type, value, traceback):
156159        iserr  =  type  is  not   None  and  issubclass (type , Exception )
157160        iserr  =  iserr  or  (type  is  KeyboardInterrupt )
158161        if  iserr :
159-             print ("error inside 'with' block" )
162+             logger . error ("error inside 'with' block" )
160163            return 
161164        if  self .do_plot :
162165            self .plot_stats (self .visdom_env )
163166
164167    def  reset (self ):  # to be called after each epoch 
165168        stat_sets  =  list (self .stats .keys ())
166-         if  self .verbose :
167-             print ("stats: epoch %d - reset"  %  self .epoch )
169+         logger .debug (f"stats: epoch { self .epoch }   - reset" )
168170        self .it  =  {k : - 1  for  k  in  stat_sets }
169171        for  stat_set  in  stat_sets :
170172            for  stat  in  self .stats [stat_set ]:
171173                self .stats [stat_set ][stat ].reset ()
172174
173175    def  hard_reset (self , epoch = - 1 ):  # to be called during object __init__ 
174176        self .epoch  =  epoch 
175-         if  self .verbose :
176-             print ("stats: epoch %d - hard reset"  %  self .epoch )
177+         logger .debug (f"stats: epoch { self .epoch }   - hard reset" )
177178        self .stats  =  {}
178179
179180        # reset 
180181        self .reset ()
181182
182183    def  new_epoch (self ):
183-         if  self .verbose :
184-             print ("stats: new epoch %d"  %  (self .epoch  +  1 ))
184+         logger .debug (f"stats: new epoch { (self .epoch  +  1 )}  " )
185185        self .epoch  +=  1 
186186        self .reset ()  # zero the stats + increase epoch counter 
187187
@@ -193,18 +193,17 @@ def gather_value(self, val):
193193            val  =  float (val .sum ())
194194        return  val 
195195
196-     def  add_log_vars (self , added_log_vars ,  verbose = True ):
196+     def  add_log_vars (self , added_log_vars ):
197197        for  add_log_var  in  added_log_vars :
198198            if  add_log_var  not  in   self .stats :
199-                 if  verbose :
200-                     print (f"Adding { add_log_var }  " )
199+                 logger .debug (f"Adding { add_log_var }  " )
201200                self .log_vars .append (add_log_var )
202201
203202    def  update (self , preds , time_start = None , freeze_iter = False , stat_set = "train" ):
204203
205204        if  self .epoch  ==  - 1 :  # uninitialized 
206-             print (
207-                 "warning:  epoch==-1 means uninitialized stats structure -> new_epoch() called" 
205+             logger . warning (
206+                 "epoch==-1 means uninitialized stats structure -> new_epoch() called" 
208207            )
209208            self .new_epoch ()
210209
@@ -284,6 +283,12 @@ def print(
284283        skip_nan = False ,
285284        stat_format = lambda  s : s .replace ("loss_" , "" ).replace ("prev_stage_" , "ps_" ),
286285    ):
286+         """ 
287+         stats.print() is deprecated. Please use get_status_string() instead. 
288+         example: 
289+         std_out = stats.get_status_string() 
290+         logger.info(str_out) 
291+         """ 
287292
288293        epoch  =  self .epoch 
289294        stats  =  self .stats 
@@ -311,8 +316,30 @@ def print(
311316        if  get_str :
312317            return  str_out 
313318        else :
319+             warnings .warn (
320+                 "get_str=False is deprecated." 
321+                 "Please enable this flag to get receive the output string." ,
322+                 DeprecationWarning ,
323+             )
314324            print (str_out )
315325
326+     def  get_status_string (
327+         self ,
328+         max_it = None ,
329+         stat_set = "train" ,
330+         vars_print = None ,
331+         skip_nan = False ,
332+         stat_format = lambda  s : s .replace ("loss_" , "" ).replace ("prev_stage_" , "ps_" ),
333+     ):
334+         return  self .print (
335+             max_it = max_it ,
336+             stat_set = stat_set ,
337+             vars_print = vars_print ,
338+             get_str = True ,
339+             skip_nan = skip_nan ,
340+             stat_format = stat_format ,
341+         )
342+ 
316343    def  plot_stats (
317344        self , visdom_env = None , plot_file = None , visdom_server = None , visdom_port = None 
318345    ):
@@ -329,16 +356,15 @@ def plot_stats(
329356
330357        stat_sets  =  list (self .stats .keys ())
331358
332-         print (
333-             "printing charts to visdom env '%s' (%s:%d)" 
334-             %  (visdom_env , visdom_server , visdom_port )
359+         logger .debug (
360+             f"printing charts to visdom env '{ visdom_env }  ' ({ visdom_server }  :{ visdom_port }  )" 
335361        )
336362
337363        novisdom  =  False 
338364
339365        viz  =  get_visdom_connection (server = visdom_server , port = visdom_port )
340366        if  viz  is  None  or  not  viz .check_connection ():
341-             print ("no visdom server! -> skipping visdom plots" )
367+             logger . info ("no visdom server! -> skipping visdom plots" )
342368            novisdom  =  True 
343369
344370        lines  =  []
@@ -385,7 +411,7 @@ def plot_stats(
385411                    )
386412
387413        if  plot_file :
388-             print ( "exporting  stats to %s"   %   plot_file )
414+             logger . info ( f"plotting  stats to { plot_file } "  )
389415            ncol  =  3 
390416            nrow  =  int (np .ceil (float (len (lines )) /  ncol ))
391417            matplotlib .rcParams .update ({"font.size" : 5 })
@@ -423,15 +449,15 @@ def plot_stats(
423449            except  PermissionError :
424450                warnings .warn ("Cant dump stats due to insufficient permissions!" )
425451
426-     def  synchronize_logged_vars (self , log_vars , default_val = float ("NaN" ),  verbose = True ):
452+     def  synchronize_logged_vars (self , log_vars , default_val = float ("NaN" )):
427453
428454        stat_sets  =  list (self .stats .keys ())
429455
430456        # remove the additional log_vars 
431457        for  stat_set  in  stat_sets :
432458            for  stat  in  self .stats [stat_set ].keys ():
433459                if  stat  not  in   log_vars :
434-                     print ( "additional stat %s:%s  -> removing"   %  ( stat_set ,  stat ) )
460+                     logger . warning ( f "additional stat { stat_set } : { stat }   -> removing" )
435461
436462            self .stats [stat_set ] =  {
437463                stat : v  for  stat , v  in  self .stats [stat_set ].items () if  stat  in  log_vars 
@@ -442,21 +468,19 @@ def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=Tr
442468        for  stat_set  in  stat_sets :
443469            for  stat  in  log_vars :
444470                if  stat  not  in   self .stats [stat_set ]:
445-                     if  verbose :
446-                         print (
447-                             "missing stat %s:%s -> filling with default values (%1.2f)" 
448-                             %  (stat_set , stat , default_val )
449-                         )
471+                     logger .info (
472+                         "missing stat %s:%s -> filling with default values (%1.2f)" 
473+                         %  (stat_set , stat , default_val )
474+                     )
450475                elif  len (self .stats [stat_set ][stat ].history ) !=  self .epoch  +  1 :
451476                    h  =  self .stats [stat_set ][stat ].history 
452477                    if  len (h ) ==  0 :  # just never updated stat ... skip 
453478                        continue 
454479                    else :
455-                         if  verbose :
456-                             print (
457-                                 "incomplete stat %s:%s -> reseting with default values (%1.2f)" 
458-                                 %  (stat_set , stat , default_val )
459-                             )
480+                         logger .info (
481+                             "incomplete stat %s:%s -> reseting with default values (%1.2f)" 
482+                             %  (stat_set , stat , default_val )
483+                         )
460484                else :
461485                    continue 
462486
0 commit comments