@@ -44,25 +44,14 @@ def __init__(self, trainer):
         self._callback_hook_validator = CallbackHookNameValidator()
         self._current_stage = None
 
-    def cached_results(self, stage_or_testing: Union[str, bool]) -> Union[EpochResultStore, None]:
-        """ Function to access cached_results using str or bool. Bool is used only for testing"""
-        stage_or_testing = str(stage_or_testing)
-        stages = self._stages
-        if stage_or_testing in self._stages:
-            return self._cached_results[stage_or_testing]
-        if stage_or_testing in LOOKUP_TABLE:
-            # Acces using trainer.testing
-            stage = LOOKUP_TABLE[stage_or_testing]
-            return self._cached_results[stage]
-        raise MisconfigurationException(
-            f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {self._stages}"
-            f" or {LOOKUP_TABLE.keys()}"
-        )
+    @property
+    def cached_results(self) -> Union[EpochResultStore, None]:
+        return self._cached_results[self._current_stage]
 
     def set_stage(self, stage_or_testing: str, reset: bool = False) -> None:
         self._current_stage = self._determine_stage(stage_or_testing)
         if reset:
-            self.cached_results(stage_or_testing).reset()
+            self.cached_results.reset()
 
     def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None:
         self._callback_hook_validator.check_logging_in_callbacks(current_hook_fx_name=hook_fx_name,
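Note: `cached_results` is now a plain property keyed off `self._current_stage`, so callers select the stage once via `set_stage` and then use attribute access. A minimal sketch of the intended call pattern (the `connector` object is illustrative; the "train" key, `set_stage`, and `_batch_size` come from this diff):

# Sketch only: assumes a LoggerConnector-like object whose _cached_results dict
# already holds one EpochResultStore per stage.
connector.set_stage("train", reset=True)  # select the train store and clear it
store = connector.cached_results          # property access, no stage argument
store._batch_size = 32                    # the hooks below record batch metadata this way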
@@ -75,17 +64,17 @@ def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataload
         model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
 
         # track batch_size
-        self.cached_results(testing)._batch_size = Result.extract_batch_size(batch)
+        self.cached_results._batch_size = Result.extract_batch_size(batch)
 
-    def on_batch_start(self, split_idx: int, opt_idx: int, split_batch) -> None:
-        self._cached_results["train"]._split_idx = split_idx
-        self._cached_results["train"]._opt_idx = opt_idx
-        self._cached_results["train"]._batch_size = Result.extract_batch_size(split_batch)
+    def on_train_split_start(self, split_idx: int, opt_idx: int, split_batch) -> None:
+        self.cached_results._split_idx = split_idx
+        self.cached_results._opt_idx = opt_idx
+        self.cached_results._batch_size = Result.extract_batch_size(split_batch)
 
     def on_train_batch_end(self) -> None:
-        self._cached_results["train"]._split_idx = None
-        self._cached_results["train"]._opt_idx = None
-        self._cached_results["train"]._batch_size = None
+        self.cached_results._split_idx = None
+        self.cached_results._opt_idx = None
+        self.cached_results._batch_size = None
 
     def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str:
         stage_or_testing = str(stage_or_testing)
@@ -112,6 +101,16 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
         self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
         self.trainer.log_every_n_steps = log_every_n_steps
 
+    @property
+    def should_flush_logs(self):
+        should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
+        return should_flush or self.trainer.should_stop
+
+    @property
+    def should_update_logs(self):
+        should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
+        return should_log_every_n_steps or self.trainer.should_stop
+
     def configure_logger(self, logger):
         if logger is True:
             version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
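The two new properties centralize the `(global_step + 1) % n == 0` gating that was previously inlined in `log_train_step_metrics` (see the final hunk), with `should_stop` forcing a final log/flush. A standalone sketch of the same arithmetic, shown only to illustrate the off-by-one handling (the helper name is made up):

# Sketch of the gating rule behind should_update_logs / should_flush_logs.
def should_update(global_step: int, every_n_steps: int, should_stop: bool) -> bool:
    # global_step is 0-indexed, so step 49 is the 50th optimization step
    return (global_step + 1) % every_n_steps == 0 or should_stop

assert should_update(global_step=49, every_n_steps=50, should_stop=False)      # 50th step -> log
assert not should_update(global_step=50, every_n_steps=50, should_stop=False)  # 51st step -> skip
assert should_update(global_step=50, every_n_steps=50, should_stop=True)       # stopping -> always log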
@@ -130,6 +129,53 @@ def configure_logger(self, logger):
         else:
             self.trainer.logger = logger
 
+    def cache_training_step_metrics(self, opt_closure_result):
133+ """
134+ This function is responsible to update
135+ logger_connector internals metrics holder based for depreceated logging
136+ """
+        using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
+
+        # temporary dict to collect metrics
+        logged_metrics_tmp = {}
+        pbar_metrics_tmp = {}
+        callback_metrics_tmp = {}
+
+        if using_results_obj:
+            batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics(
+                include_forked_originals=False
+            )
+            logged_metrics_tmp.update(batch_log_metrics)
+
+            batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics(
+                include_forked_originals=False
+            )
+            pbar_metrics_tmp.update(batch_pbar_metrics)
+
+            forked_metrics = opt_closure_result.training_step_output.get_forked_metrics()
+            callback_metrics_tmp.update(forked_metrics)
+            callback_metrics_tmp.update(logged_metrics_tmp)
+
+        else:
+            batch_log_metrics = opt_closure_result.training_step_output.log_metrics
+            logged_metrics_tmp.update(batch_log_metrics)
+
+            callback_metrics = opt_closure_result.training_step_output.callback_metrics
+            callback_metrics_tmp.update(callback_metrics)
+
+            batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
+            pbar_metrics_tmp.update(batch_pbar_metrics)
+
+        # track progress bar metrics
+        if len(pbar_metrics_tmp) > 0:
+            self.add_progress_bar_metrics(pbar_metrics_tmp)
+
+        self.callback_metrics.update(callback_metrics_tmp)
+
+        # save legacy log metrics
+        self.logged_metrics.update(logged_metrics_tmp)
+        self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)
+
     def log_metrics(self, metrics, grad_norm_dic, step=None):
         """Logs the metric dict passed in.
         If `step` parameter is None and `step` key is presented is metrics,
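`cache_training_step_metrics` above handles both `training_step` output formats: a `Result` object (current API) and the legacy dict-style output, funnelling both into the same holders (`logged_metrics`, progress bar metrics, `callback_metrics`). A dict-only simplification of the `Result` branch, purely to show the merge precedence (the real code reads these from the `Result` accessors; the function name is made up):

# Simplified stand-in for the Result-branch merge above; inputs are plain dicts.
def merge_step_metrics(batch_log: dict, batch_pbar: dict, forked: dict) -> dict:
    logged_metrics_tmp = dict(batch_log)
    pbar_metrics_tmp = dict(batch_pbar)
    # forked metrics first, then logged metrics, so logged values win on key collisions
    callback_metrics_tmp = {**forked, **logged_metrics_tmp}
    return {"logged": logged_metrics_tmp, "pbar": pbar_metrics_tmp, "callback": callback_metrics_tmp}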
@@ -396,8 +442,9 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod
         if num_loaders == 1:
             self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics, callback_metrics)
 
-    def on_train_epoch_end(self, epoch_output):
-        pass
+    def on_train_epoch_end(self):
+        # inform cached logger connector epoch finished
+        self.cached_results.has_batch_loop_finished = True
 
     def log_train_epoch_end_metrics(self,
                                     epoch_output,
@@ -441,12 +488,10 @@ def log_train_epoch_end_metrics(self,
         # ------------------
         if is_1_0_result:
             # lightning module hook
-            epoch_end_log_result = self.training_epoch_end(model, epoch_output, num_optimizers)
+            self.training_epoch_end(model, epoch_output, num_optimizers)
 
             # log/aggregate metrics automatically
             epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
-            epoch_log_metrics.update(epoch_end_log_result.get_epoch_log_metrics())
-            epoch_progress_bar_metrics.update(epoch_end_log_result.get_epoch_pbar_metrics())
 
         # TODO: deprecate 1.0
         else:
@@ -459,6 +504,14 @@ def log_train_epoch_end_metrics(self,
             )
             epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out
 
+        # it will perform reduction over epoch and return log metrics
+        cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
+        cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics()
+
+        # update
+        epoch_log_metrics.update(cached_epoch_log_metrics)
+        epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics)
+
         # --------------------------
         # track results
         # --------------------------
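Because plain `dict.update` is used, any key produced by both paths ends up with the value from the cached store, i.e. the new logging API overrides what was computed from `epoch_output`. A toy illustration with made-up numbers:

# Illustrative values only; shows dict.update precedence for colliding keys.
epoch_log_metrics = {"loss_epoch": 0.41}          # from __auto_reduce_results_on_epoch_end
cached_epoch_log_metrics = {"loss_epoch": 0.39}   # from cached_results.get_epoch_log_metrics()
epoch_log_metrics.update(cached_epoch_log_metrics)
assert epoch_log_metrics["loss_epoch"] == 0.39    # cached (new API) value wins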
@@ -475,15 +528,16 @@ def log_train_epoch_end_metrics(self,
             self.add_progress_bar_metrics(epoch_progress_bar_metrics)
             self.callback_metrics.update(epoch_progress_bar_metrics)
 
+        # reset epoch loop result for next epoch
+        self.cached_results.reset()
+
     def training_epoch_end(self, model, epoch_output, num_optimizers):
         if not is_overridden('training_epoch_end', model=model):
-            return Result()
+            return
 
         # run training_epoch_end
         # refresh the result for custom logging at the epoch level
         model._current_fx_name = 'training_epoch_end'
-        model._results = Result()
-
         epoch_output = self.__prepare_epoch_end_inputs(epoch_output)
 
         if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
@@ -492,15 +546,11 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):
             # lightningmodule hook
             epoch_output = model.training_epoch_end(epoch_output)
 
-        model._current_fx_name = ''
-
         if epoch_output is not None:
             raise MisconfigurationException('training_epoch_end expects a return of None. '
                                             'HINT: remove the return statement in training_epoch_end')
-
-        # user can ALSO log at the end of an epoch
-        new_epoch_end_logs = model._results
-        return new_epoch_end_logs
+        # capture logging
+        self.trainer.logger_connector.cache_logged_metrics()
 
     def __run_legacy_training_epoch_end(
             self,
@@ -527,8 +577,12 @@ def __run_legacy_training_epoch_end(
 
         # run training_epoch_end
         # a list with a result per optimizer index
+        model._current_fx_name = 'training_epoch_end'
         epoch_output = model.training_epoch_end(epoch_output)
 
+        # capture logging
+        self.trainer.logger_connector.cache_logged_metrics()
+
         if isinstance(epoch_output, Result):
             epoch_log_metrics = epoch_output.epoch_log_metrics
             epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
@@ -563,7 +617,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output):
             # reduce across training steps
             opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)
 
-            # with manual opt need 1+ metrics because meta is always there
+            # with manual opt need 1 + metrics because meta is always there
             if opt_outputs.minimize is not None:
                 opt_outputs.minimize = opt_outputs.minimize.mean()
             epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
@@ -623,12 +677,9 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):
 
     def log_train_step_metrics(self, batch_output):
         # when metrics should be logged
-        should_log_metrics = (
-            (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 or self.trainer.should_stop
-        )
-        if should_log_metrics or self.trainer.fast_dev_run:
+        if self.should_update_logs or self.trainer.fast_dev_run:
             # logs user requested information to logger
-            metrics = batch_output.batch_log_metrics
+            metrics = self.cached_results.get_latest_batch_log_metrics()
             grad_norm_dic = batch_output.grad_norm_dic
             if len(metrics) > 0 or len(grad_norm_dic) > 0:
                 self.log_metrics(metrics, grad_norm_dic)
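Net effect of this last hunk: both the gating condition and the metric source move off the raw batch output and onto the connector's own state. A condensed view of the resulting flow (names as in the diff; the surrounding trainer wiring is assumed):

# Condensed sketch of the new step-logging path.
if self.should_update_logs or self.trainer.fast_dev_run:
    metrics = self.cached_results.get_latest_batch_log_metrics()  # from the per-stage store
    grad_norm_dic = batch_output.grad_norm_dic
    if len(metrics) > 0 or len(grad_norm_dic) > 0:
        self.log_metrics(metrics, grad_norm_dic)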