Skip to content

Commit 6f2f475

Browse files
committed
drop LoggerStages
1 parent 9d165f6 commit 6f2f475

File tree

2 files changed

+11
-28
lines changed

2 files changed

+11
-28
lines changed

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,10 @@
1717
import torch
1818

1919
from pytorch_lightning.core.step_result import Result
20+
from pytorch_lightning.trainer.states import RunningStage
2021
from pytorch_lightning.utilities import DistributedType, LightningEnum
2122

2223

23-
class LoggerStages(LightningEnum):
24-
""" Train/validation/test phase in each training step.
25-
26-
>>> # you can math the type with string
27-
>>> LoggerStages.TRAIN == 'train'
28-
True
29-
"""
30-
TRAIN = "train"
31-
VAL = "validation"
32-
TEST = "test"
33-
34-
@staticmethod
35-
def determine_stage(stage_or_testing: Union[str, bool]) -> 'LoggerStages':
36-
if isinstance(stage_or_testing, str) and stage_or_testing in list(LoggerStages):
37-
return LoggerStages(stage_or_testing)
38-
if isinstance(stage_or_testing, (bool, int)):
39-
# stage_or_testing is trainer.testing
40-
return LoggerStages.TEST if bool(stage_or_testing) else LoggerStages.VAL
41-
raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given")
42-
43-
4424
class ResultStoreType(LightningEnum):
4525
INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
4626
OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"
@@ -276,7 +256,7 @@ class EpochResultStore:
276256

277257
def __init__(self, trainer, stage):
278258
self.trainer = trainer
279-
self._stage = stage
259+
self._stage = RunningStage(stage)
280260
self.reset()
281261

282262
def __getitem__(self, key: str) -> Any:
@@ -371,15 +351,14 @@ def update_logger_connector(self) -> None:
371351
callback_metrics = {}
372352
batch_pbar_metrics = {}
373353
batch_log_metrics = {}
374-
is_train = self._stage in LoggerStages.TRAIN.value
375354

376355
if not self._has_batch_loop_finished:
377356
# get pbar
378357
batch_pbar_metrics = self.get_latest_batch_pbar_metrics()
379358
logger_connector.add_progress_bar_metrics(batch_pbar_metrics)
380359
batch_log_metrics = self.get_latest_batch_log_metrics()
381360

382-
if is_train:
361+
if self._stage == RunningStage.TRAINING:
383362
# Only log and add to callback epoch step during evaluation, test.
384363
logger_connector._logged_metrics.update(batch_log_metrics)
385364
callback_metrics.update(batch_pbar_metrics)
@@ -401,7 +380,7 @@ def update_logger_connector(self) -> None:
401380
callback_metrics.update(epoch_log_metrics)
402381
callback_metrics.update(forked_metrics)
403382

404-
if not is_train and self.trainer.testing:
383+
if self._stage != RunningStage.TRAINING and self.trainer.testing:
405384
logger_connector.evaluation_callback_metrics.update(callback_metrics)
406385

407386
# update callback_metrics

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
from pytorch_lightning.core.step_result import Result
2323
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
2424
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
25-
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages
25+
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
2626
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
27+
from pytorch_lightning.trainer.states import RunningStage
2728
from pytorch_lightning.utilities import DeviceType, flatten_dict
2829
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2930
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -37,7 +38,7 @@ def __init__(self, trainer):
3738
self._logged_metrics = MetricsHolder()
3839
self._progress_bar_metrics = MetricsHolder()
3940
self.eval_loop_results = []
40-
self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages}
41+
self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage}
4142
self._callback_hook_validator = CallbackHookNameValidator()
4243
self._current_stage = None
4344

@@ -91,7 +92,10 @@ def set_metrics(self, key: str, val: Any) -> None:
9192
metrics_holder.reset(val)
9293

9394
def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None:
94-
self._current_stage = LoggerStages.determine_stage(stage_or_testing)
95+
if isinstance(stage_or_testing, (bool, int)):
96+
self._current_stage = RunningStage.TESTING if bool(stage_or_testing) else RunningStage.EVALUATING
97+
else:
98+
self._current_stage = RunningStage.from_str(stage_or_testing)
9599
if reset:
96100
self.cached_results.reset()
97101

0 commit comments

Comments
 (0)