Skip to content

Commit 7449ce2

Browse files
Bordacarmocca
andauthored
simplify training phase as Enum (#5419)
* simplify training phase as Enum * tests * . * . * rename * Apply suggestions from code review Co-authored-by: Carlos Mocholí <[email protected]> * Apply suggestions from code review Co-authored-by: Carlos Mocholí <[email protected]> * rename * flake8 Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 7e4d6cb commit 7449ce2

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

pytorch_lightning/trainer/deprecated_api.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
from pytorch_lightning.trainer.states import RunningStage
1515
from pytorch_lightning.utilities import DistributedType, DeviceType, rank_zero_warn
1616

1717

1818
class DeprecatedDistDeviceAttributes:
1919

2020
_distrib_type: DistributedType
2121
_device_type: DeviceType
22+
_running_stage: RunningStage
2223
num_gpus: int
2324

2425
@property
@@ -129,3 +130,27 @@ def use_single_gpu(self, val: bool) -> None:
129130
)
130131
if val:
131132
self._device_type = DeviceType.GPU
133+
134+
@property
135+
def training(self) -> bool:
136+
# todo: consider rename as `is_training`
137+
return self._running_stage == RunningStage.TRAINING
138+
139+
@training.setter
140+
def training(self, val: bool) -> None:
141+
if val:
142+
self._running_stage = RunningStage.TRAINING
143+
else:
144+
self._running_stage = None
145+
146+
@property
147+
def testing(self) -> bool:
148+
# todo: consider rename as `is_testing`
149+
return self._running_stage == RunningStage.TESTING
150+
151+
@testing.setter
152+
def testing(self, val: bool) -> None:
153+
if val:
154+
self._running_stage = RunningStage.TESTING
155+
else:
156+
self._running_stage = None

pytorch_lightning/trainer/states.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ class TrainerState(LightningEnum):
2323
""" State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
2424
to indicate what is currently or was executed.
2525
26-
>>> # you can math the type with string
26+
>>> # you can compare the type with a string
2727
>>> TrainerState.RUNNING == 'RUNNING'
2828
True
29-
>>> # which is case sensitive
29+
>>> # which is case insensitive
3030
>>> TrainerState.FINISHED == 'finished'
3131
True
3232
"""
@@ -36,6 +36,19 @@ class TrainerState(LightningEnum):
3636
INTERRUPTED = 'INTERRUPTED'
3737

3838

39+
class RunningStage(LightningEnum):
40+
"""Type of train phase.
41+
42+
>>> # you can match the Enum with string
43+
>>> RunningStage.TRAINING == 'train'
44+
True
45+
"""
46+
TRAINING = 'train'
47+
EVALUATING = 'eval'
48+
TESTING = 'test'
49+
TUNING = 'tune'
50+
51+
3952
def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
4053
""" Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
4154
which changes state to `entering` before the function execution and `exiting`

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def __init__(
294294
super().__init__()
295295
self._device_type = DeviceType.CPU
296296
self._distrib_type = None
297+
self._running_stage = None
297298

298299
# init connectors
299300
self.dev_debugger = InternalDebugger(self)

tests/deprecated_api/test_remove_1-4.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_v1_4_0_deprecated_imports():
3737
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils # noqa: F811 F401
3838

3939

40-
def test_v1_4_0_deprecated_trainer_attributes():
40+
def test_v1_4_0_deprecated_trainer_device_distrib():
4141
"""Test that Trainer attributes works fine."""
4242
trainer = Trainer()
4343
trainer._distrib_type = None
@@ -84,6 +84,22 @@ def test_v1_4_0_deprecated_trainer_attributes():
8484
assert trainer.use_horovod
8585

8686

87+
def test_v1_4_0_deprecated_trainer_phase():
88+
"""Test that Trainer attributes works fine."""
89+
trainer = Trainer()
90+
91+
assert not trainer.training
92+
assert not trainer.testing
93+
94+
trainer.training = True
95+
assert trainer.training
96+
assert not trainer.testing
97+
98+
trainer.testing = True
99+
assert not trainer.training
100+
assert trainer.testing
101+
102+
87103
def test_v1_4_0_deprecated_metrics():
88104
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
89105
with pytest.deprecated_call(match='will be removed in v1.4'):

0 commit comments

Comments
 (0)