Skip to content

Commit f989558

Browse files
Bordacarmocca
andcommitted
simplify training phase as Enum (#5419)
* simplify training phase as Enum * tests * . * . * rename * Apply suggestions from code review Co-authored-by: Carlos Mocholí <[email protected]> * Apply suggestions from code review Co-authored-by: Carlos Mocholí <[email protected]> * rename * flake8 Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 5968910 commit f989558

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

pytorch_lightning/trainer/deprecated_api.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
from pytorch_lightning.trainer.states import RunningStage
1515
from pytorch_lightning.utilities import DistributedType, DeviceType, rank_zero_warn
1616

1717

1818
class DeprecatedDistDeviceAttributes:
1919

2020
_distrib_type: DistributedType
2121
_device_type: DeviceType
22+
_running_stage: RunningStage
2223
num_gpus: int
2324

2425
@property
@@ -129,3 +130,27 @@ def use_single_gpu(self, val: bool) -> None:
129130
)
130131
if val:
131132
self._device_type = DeviceType.GPU
133+
134+
@property
135+
def training(self) -> bool:
136+
# todo: consider rename as `is_training`
137+
return self._running_stage == RunningStage.TRAINING
138+
139+
@training.setter
140+
def training(self, val: bool) -> None:
141+
if val:
142+
self._running_stage = RunningStage.TRAINING
143+
else:
144+
self._running_stage = None
145+
146+
@property
147+
def testing(self) -> bool:
148+
# todo: consider rename as `is_testing`
149+
return self._running_stage == RunningStage.TESTING
150+
151+
@testing.setter
152+
def testing(self, val: bool) -> None:
153+
if val:
154+
self._running_stage = RunningStage.TESTING
155+
else:
156+
self._running_stage = None

pytorch_lightning/trainer/states.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ class TrainerState(LightningEnum):
2323
""" State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
2424
to indicate what is currently or was executed.
2525
26-
>>> # you can math the type with string
26+
>>> # you can compare the type with a string
2727
>>> TrainerState.RUNNING == 'RUNNING'
2828
True
29-
>>> # which is case sensitive
29+
>>> # which is case insensitive
3030
>>> TrainerState.FINISHED == 'finished'
3131
True
3232
"""
@@ -36,6 +36,19 @@ class TrainerState(LightningEnum):
3636
INTERRUPTED = 'INTERRUPTED'
3737

3838

39+
class RunningStage(LightningEnum):
40+
"""Type of train phase.
41+
42+
>>> # you can match the Enum with string
43+
>>> RunningStage.TRAINING == 'train'
44+
True
45+
"""
46+
TRAINING = 'train'
47+
EVALUATING = 'eval'
48+
TESTING = 'test'
49+
TUNING = 'tune'
50+
51+
3952
def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
4053
""" Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
4154
which changes state to `entering` before the function execution and `exiting`

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def __init__(
294294
super().__init__()
295295
self._device_type = DeviceType.CPU
296296
self._distrib_type = None
297+
self._running_stage = None
297298

298299
# init connectors
299300
self.dev_debugger = InternalDebugger(self)

tests/deprecated_api/test_remove_1-4.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_v1_4_0_deprecated_imports():
4242
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils # noqa: F811 F401
4343

4444

45-
def test_v1_4_0_deprecated_trainer_attributes():
45+
def test_v1_4_0_deprecated_trainer_device_distrib():
4646
"""Test that Trainer attributes works fine."""
4747
trainer = Trainer()
4848
trainer._distrib_type = None
@@ -89,6 +89,22 @@ def test_v1_4_0_deprecated_trainer_attributes():
8989
assert trainer.use_horovod
9090

9191

92+
def test_v1_4_0_deprecated_trainer_phase():
93+
"""Test that Trainer attributes works fine."""
94+
trainer = Trainer()
95+
96+
assert not trainer.training
97+
assert not trainer.testing
98+
99+
trainer.training = True
100+
assert trainer.training
101+
assert not trainer.testing
102+
103+
trainer.testing = True
104+
assert not trainer.training
105+
assert trainer.testing
106+
107+
92108
def test_v1_4_0_deprecated_metrics():
93109
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
94110
with pytest.deprecated_call(match='will be removed in v1.4'):

0 commit comments

Comments
 (0)