Skip to content

Commit 6386f45

Browse files
authored
clarify Trainer running state atribs. (#5589)
* update Trainer is_ attributes * tests * more * isort * split * rename * check * fix
1 parent 671887f commit 6386f45

File tree

5 files changed

+46
-43
lines changed

5 files changed

+46
-43
lines changed

pytorch_lightning/plugins/sharded_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _optim_state_dict(self, optimizer):
5757

5858
def _wrap_optimizers(self, model):
5959
trainer = model.trainer
60-
if trainer.testing is True:
60+
if trainer.testing:
6161
return
6262

6363
self._reinit_with_fairscale_oss(trainer)

pytorch_lightning/trainer/deprecated_api.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -130,27 +130,3 @@ def use_single_gpu(self, val: bool) -> None:
130130
)
131131
if val:
132132
self._device_type = DeviceType.GPU
133-
134-
@property
135-
def training(self) -> bool:
136-
# todo: consider rename as `is_training`
137-
return self._running_stage == RunningStage.TRAINING
138-
139-
@training.setter
140-
def training(self, val: bool) -> None:
141-
if val:
142-
self._running_stage = RunningStage.TRAINING
143-
else:
144-
self._running_stage = None
145-
146-
@property
147-
def testing(self) -> bool:
148-
# todo: consider rename as `is_testing`
149-
return self._running_stage == RunningStage.TESTING
150-
151-
@testing.setter
152-
def testing(self, val: bool) -> None:
153-
if val:
154-
self._running_stage = RunningStage.TESTING
155-
else:
156-
self._running_stage = None

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def on_trainer_init(self):
3838
self.trainer.test_dataloaders = None
3939
self.trainer.val_dataloaders = None
4040
self.trainer.running_sanity_check = False
41-
self.trainer.testing = False
4241

4342
# when .test() is called, it sets this
4443
self.trainer.tested_ckpt_path = None

pytorch_lightning/trainer/trainer.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
5454
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
5555
from pytorch_lightning.trainer.properties import TrainerProperties
56-
from pytorch_lightning.trainer.states import TrainerState
56+
from pytorch_lightning.trainer.states import RunningStage, TrainerState
5757
from pytorch_lightning.trainer.training_loop import TrainLoop
5858
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
5959
from pytorch_lightning.tuner.tuning import Tuner
@@ -921,3 +921,47 @@ def available_plugins():
921921
Returns: List of all available plugins that are supported as string arguments.
922922
"""
923923
return PluginConnector.available_plugins()
924+
925+
@property
926+
def training(self) -> bool:
927+
return self._running_stage == RunningStage.TRAINING
928+
929+
@training.setter
930+
def training(self, val: bool) -> None:
931+
if val:
932+
self._running_stage = RunningStage.TRAINING
933+
elif self.training:
934+
self._running_stage = None
935+
936+
@property
937+
def testing(self) -> bool:
938+
return self._running_stage == RunningStage.TESTING
939+
940+
@testing.setter
941+
def testing(self, val: bool) -> None:
942+
if val:
943+
self._running_stage = RunningStage.TESTING
944+
elif self.testing:
945+
self._running_stage = None
946+
947+
@property
948+
def tuning(self) -> bool:
949+
return self._running_stage == RunningStage.TUNING
950+
951+
@tuning.setter
952+
def tuning(self, val: bool) -> None:
953+
if val:
954+
self._running_stage = RunningStage.TUNING
955+
elif self.tuning:
956+
self._running_stage = None
957+
958+
@property
959+
def evaluating(self) -> bool:
960+
return self._running_stage == RunningStage.EVALUATING
961+
962+
@evaluating.setter
963+
def evaluating(self, val: bool) -> None:
964+
if val:
965+
self._running_stage = RunningStage.EVALUATING
966+
elif self.evaluating:
967+
self._running_stage = None

tests/deprecated_api/test_remove_1-4.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,22 +89,6 @@ def test_v1_4_0_deprecated_trainer_device_distrib():
8989
assert trainer.use_horovod
9090

9191

92-
def test_v1_4_0_deprecated_trainer_phase():
93-
"""Test that Trainer attributes works fine."""
94-
trainer = Trainer()
95-
96-
assert not trainer.training
97-
assert not trainer.testing
98-
99-
trainer.training = True
100-
assert trainer.training
101-
assert not trainer.testing
102-
103-
trainer.testing = True
104-
assert not trainer.training
105-
assert trainer.testing
106-
107-
10892
def test_v1_4_0_deprecated_metrics():
10993
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
11094
with pytest.deprecated_call(match='will be removed in v1.4'):

0 commit comments

Comments
 (0)