@@ -36,6 +36,7 @@ class LoggerConnector:
3636 def __init__ (self , trainer ):
3737 self .trainer = trainer
3838 self .callback_metrics = {}
39+ self .evaluation_callback_metrics = {}
3940 self .logged_metrics = {}
4041 self .progress_bar_metrics = {}
4142 self .eval_loop_results = []
@@ -59,10 +60,9 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc
5960 on_epoch = on_epoch )
6061
6162 def on_evaluation_batch_start (self , testing , batch , dataloader_idx , num_dataloaders ):
62- # reset the result of the PL module
6363 model = self .trainer .get_model ()
64+ # set dataloader_idx only if multiple ones
6465 model ._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
65-
6666 # track batch_size
6767 self .cached_results ._batch_size = Result .extract_batch_size (batch )
6868
@@ -226,19 +226,41 @@ def add_progress_bar_metrics(self, metrics):
226226
227227 self .trainer .dev_debugger .track_pbar_metrics_history (metrics )
228228
229- def on_evaluation_epoch_end (self , deprecated_eval_results , epoch_logs , using_eval_result , test_mode ):
229+ def track_metrics_deprecated (self , deprecated_eval_results , using_eval_result , test_mode ):
230230 self ._track_callback_metrics (deprecated_eval_results , using_eval_result )
231-
232- # TODO: deprecate parts of this for 1.0 (when removing results)
233231 self .__process_eval_epoch_end_results_and_log_legacy (deprecated_eval_results , test_mode )
234232
235- self ._log_on_evaluation_epoch_end_metrics (epoch_logs )
233+ def evaluation_epoch_end (self , testing ):
234+ # reset dataloader idx
235+ model_ref = self .trainer .get_model ()
236+ model_ref ._current_dataloader_idx = None
237+
238+ # setting `has_batch_loop_finished` to True
239+ # will perform Results reduction accross entire epoch.
240+ self .cached_results .has_batch_loop_finished = True
241+
242+ def add_to_eval_loop_results (self , dl_idx , has_been_initialized ):
243+ callback_metrics = deepcopy (self .evaluation_callback_metrics )
244+ for key in list (callback_metrics .keys ()):
245+ if "dataloader_idx" in key :
246+ if f"dataloader_idx_{ dl_idx } " not in key :
247+ # remove dl_idx from self.callback_metrics not belonging to this dataset.
248+ del callback_metrics [key ]
249+ if has_been_initialized :
250+ self .eval_loop_results [dl_idx ].update (callback_metrics )
251+ else :
252+ self .eval_loop_results .append (callback_metrics )
236253
237- # get the final loop results
238- eval_loop_results = self ._get_evaluate_epoch_results (test_mode )
239- return eval_loop_results
254+ def prepare_eval_loop_results (self ):
255+ num_dataloaders = self .trainer .evaluation_loop .num_dataloaders
256+ has_been_initialized = len (self .eval_loop_results ) == num_dataloaders
257+ for dl_idx in range (self .trainer .evaluation_loop .num_dataloaders ):
258+ self .add_to_eval_loop_results (dl_idx , has_been_initialized )
259+
260+ def get_evaluate_epoch_results (self , test_mode ):
261+
262+ self .prepare_eval_loop_results ()
240263
241- def _get_evaluate_epoch_results (self , test_mode ):
242264 # log results of test
243265 if test_mode and self .trainer .is_global_zero and self .trainer .verbose_test :
244266 print ('-' * 80 )
@@ -253,106 +275,6 @@ def _get_evaluate_epoch_results(self, test_mode):
253275 self .eval_loop_results = []
254276 return results
255277
256- def _log_on_evaluation_epoch_end_metrics (self , epoch_logs ):
257- step_metrics = self .trainer .evaluation_loop .step_metrics
258-
259- num_loaders = len (step_metrics )
260-
261- # clear mem
262- self .trainer .evaluation_loop .step_metrics = []
263-
264- if self .trainer .running_sanity_check :
265- return
266-
267- # track all metrics we want to log
268- metrics_to_log = []
269-
270- # ---------------------------
271- # UPDATE EPOCH LOGGED METRICS
272- # ---------------------------
273- # (ie: in methods at the val_epoch_end level)
274- # union the epoch logs with whatever was returned from loaders and reduced
275- epoch_logger_metrics = epoch_logs .get_epoch_log_metrics ()
276- epoch_pbar_metrics = epoch_logs .get_epoch_pbar_metrics ()
277-
278- self .logged_metrics .update (epoch_logger_metrics )
279- self .add_progress_bar_metrics (epoch_pbar_metrics )
280-
281- # enable the metrics to be monitored
282- self .callback_metrics .update (epoch_logger_metrics )
283- self .callback_metrics .update (epoch_pbar_metrics )
284-
285- if len (epoch_logger_metrics ) > 0 :
286- metrics_to_log .append (epoch_logger_metrics )
287-
288- # --------------------------------
289- # UPDATE METRICS PER DATALOADER
290- # --------------------------------
291- # each dataloader aggregated metrics
292- # now we log all of them
293- for dl_idx , dl_metrics in enumerate (step_metrics ):
294- if len (dl_metrics ) == 0 :
295- # Ensure custom logged metrics are included if not included with step metrics
296- if len (epoch_logger_metrics ) > 0 :
297- self .eval_loop_results .append (epoch_logger_metrics )
298- continue
299-
300- reduced_epoch_metrics = dl_metrics [0 ].__class__ .reduce_on_epoch_end (dl_metrics )
301- # track the metrics
302- logger_metrics = reduced_epoch_metrics .get_epoch_log_metrics ()
303- pbar_metrics = reduced_epoch_metrics .get_epoch_pbar_metrics ()
304- forked_metrics = reduced_epoch_metrics .get_forked_metrics ()
305-
306- # make the keys 'k/dl'
307- logger_metrics = self .__rename_keys_by_dataloader_idx (logger_metrics , dl_idx , num_loaders )
308- pbar_metrics = self .__rename_keys_by_dataloader_idx (pbar_metrics , dl_idx , num_loaders )
309- forked_metrics = self .__rename_keys_by_dataloader_idx (forked_metrics , dl_idx , num_loaders )
310-
311- self .logged_metrics .update (logger_metrics )
312- self .add_progress_bar_metrics (pbar_metrics )
313-
314- # enable the metrics to be monitored
315- self .callback_metrics .update (logger_metrics )
316- self .callback_metrics .update (pbar_metrics )
317-
318- # forked metrics were dropped, enable them for callbacks
319- self .callback_metrics .update (forked_metrics )
320-
321- # track the final results for the dataloader
322- self .add_to_eval_loop_results (dl_idx , num_loaders )
323-
324- # actually log
325- if len (logger_metrics ) > 0 :
326- metrics_to_log .append (logger_metrics )
327-
328- # log all the metrics as a s single dict
329- metrics_to_log = dict (ChainMap (* metrics_to_log ))
330- if len (metrics_to_log ) > 0 :
331- self .log_metrics (metrics_to_log , {})
332-
333- def add_to_eval_loop_results (self , dl_idx , num_loaders ):
334- callback_metrics = deepcopy (self .callback_metrics )
335- if num_loaders == 1 :
336- if len (self .eval_loop_results ) > 0 :
337- self .eval_loop_results [0 ].update (callback_metrics )
338- else :
339- self .eval_loop_results .append (callback_metrics )
340- return
341-
342- for key in list (callback_metrics .keys ()):
343- if "dataloader_idx" in key :
344- if f"dataloader_idx_{ dl_idx } " not in key :
345- # remove dl_idx from self.callback_metrics not belonging to this dataset.
346- del callback_metrics [key ]
347- self .eval_loop_results .append (callback_metrics )
348-
349- def __rename_keys_by_dataloader_idx (self , metrics , dataloader_idx , num_loaders ):
350- if num_loaders == 1 :
351- return metrics
352-
353- result = {f'{ k } /dataloader_idx_{ dataloader_idx } ' : v for k , v in metrics .items ()}
354- return result
355-
356278 def _track_callback_metrics (self , eval_results , using_eval_result ):
357279 if (
358280 len (eval_results ) > 0 and
@@ -364,8 +286,10 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
364286 if isinstance (eval_results , list ):
365287 for eval_result in eval_results :
366288 self .trainer .logger_connector .callback_metrics .update (eval_result .callback_metrics )
289+ self .trainer .logger_connector .evaluation_callback_metrics .update (eval_result .callback_metrics )
367290 else :
368291 self .trainer .logger_connector .callback_metrics .update (eval_results .callback_metrics )
292+ self .trainer .logger_connector .evaluation_callback_metrics .update (eval_results .callback_metrics )
369293 else :
370294 flat = {}
371295 if isinstance (eval_results , list ):
@@ -381,6 +305,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
381305 flat ['checkpoint_on' ] = flat ['val_loss' ]
382306 flat ['early_stop_on' ] = flat ['val_loss' ]
383307 self .trainer .logger_connector .callback_metrics .update (flat )
308+ self .trainer .logger_connector .evaluation_callback_metrics .update (flat )
384309 else :
385310 # with a scalar return, auto set it to "val_loss" for callbacks
386311 if isinstance (eval_results , torch .Tensor ):
@@ -393,6 +318,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
393318 flat ['checkpoint_on' ] = flat ['val_loss' ]
394319 flat ['early_stop_on' ] = flat ['val_loss' ]
395320 self .trainer .logger_connector .callback_metrics .update (flat )
321+ self .trainer .logger_connector .evaluation_callback_metrics .update (flat )
396322
397323 def __process_eval_epoch_end_results_and_log_legacy_update (self , prog_bar_metrics , log_metrics , callback_metrics ):
398324 # eval loop returns all metrics
@@ -406,9 +332,10 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric
406332 self .trainer .logger_connector .log_metrics (log_metrics , {})
407333
408334 # track metrics for callbacks (all prog bar, logged and callback metrics)
335+ callback_metrics .update (log_metrics )
336+ callback_metrics .update (prog_bar_metrics )
409337 self .trainer .logger_connector .callback_metrics .update (callback_metrics )
410- self .trainer .logger_connector .callback_metrics .update (log_metrics )
411- self .trainer .logger_connector .callback_metrics .update (prog_bar_metrics )
338+ self .trainer .logger_connector .evaluation_callback_metrics .update (callback_metrics )
412339
413340 if len (dataloader_result_metrics ) > 0 :
414341 self .eval_loop_results .append (dataloader_result_metrics )
0 commit comments