1717import torch
1818
1919from pytorch_lightning .core .step_result import Result
20+ from pytorch_lightning .trainer .states import RunningStage
2021from pytorch_lightning .utilities import DistributedType , LightningEnum
2122
2223
23- class LoggerStages (LightningEnum ):
24- """ Train/validation/test phase in each training step.
25-
26- >>> # you can math the type with string
27- >>> LoggerStages.TRAIN == 'train'
28- True
29- """
30- TRAIN = "train"
31- VAL = "validation"
32- TEST = "test"
33-
34- @staticmethod
35- def determine_stage (stage_or_testing : Union [str , bool ]) -> 'LoggerStages' :
36- if isinstance (stage_or_testing , str ) and stage_or_testing in list (LoggerStages ):
37- return LoggerStages (stage_or_testing )
38- if isinstance (stage_or_testing , (bool , int )):
39- # stage_or_testing is trainer.testing
40- return LoggerStages .TEST if bool (stage_or_testing ) else LoggerStages .VAL
41- raise RuntimeError (f"Invalid stage { stage_or_testing } of type { type (stage_or_testing )} given" )
42-
43-
4424class ResultStoreType (LightningEnum ):
4525 INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
4626 OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"
@@ -276,7 +256,7 @@ class EpochResultStore:
276256
277257 def __init__ (self , trainer , stage ):
278258 self .trainer = trainer
279- self ._stage = stage
259+ self ._stage = RunningStage ( stage )
280260 self .reset ()
281261
282262 def __getitem__ (self , key : str ) -> Any :
@@ -371,15 +351,14 @@ def update_logger_connector(self) -> None:
371351 callback_metrics = {}
372352 batch_pbar_metrics = {}
373353 batch_log_metrics = {}
374- is_train = self ._stage in LoggerStages .TRAIN .value
375354
376355 if not self ._has_batch_loop_finished :
377356 # get pbar
378357 batch_pbar_metrics = self .get_latest_batch_pbar_metrics ()
379358 logger_connector .add_progress_bar_metrics (batch_pbar_metrics )
380359 batch_log_metrics = self .get_latest_batch_log_metrics ()
381360
382- if is_train :
361+ if self . _stage == RunningStage . TRAINING :
383362 # Only log and add to callback epoch step during evaluation, test.
384363 logger_connector ._logged_metrics .update (batch_log_metrics )
385364 callback_metrics .update (batch_pbar_metrics )
@@ -401,7 +380,7 @@ def update_logger_connector(self) -> None:
401380 callback_metrics .update (epoch_log_metrics )
402381 callback_metrics .update (forked_metrics )
403382
404- if not is_train and self .trainer .testing :
383+ if self . _stage != RunningStage . TRAINING and self .trainer .testing :
405384 logger_connector .evaluation_callback_metrics .update (callback_metrics )
406385
407386 # update callback_metrics
0 commit comments