Commit ce91795

Borda and chaton authored
ref: clean config [1/n] add intermediate setters (#4990)
* add intermediate setters
* show inputs
* fix options
* move
* fix
* less talk
* fix
* talk less
* str
* cases
* rename

Co-authored-by: chaton <[email protected]>
1 parent 068502f commit ce91795

File tree: 8 files changed (+228, -27 lines)


pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 11 deletions

```diff
@@ -256,14 +256,4 @@ def block_ddp_plugin_sync_behaviour(self):
         yield cm


-# TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos...
-class BackendType(Enum):
-    DP = 'dp'
-    DDP = 'ddp'
-    DDP2 = 'ddp2'
-    DDP_SPAWN = 'ddp_spawn'
-    # decuple distrib and device
-    DDP_CPU = 'ddp_cpu'
-    HOROVOD = 'horovod'
-    # this is rather device
-    TPU = 'tpu'
+
```

pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -335,6 +335,7 @@ def set_distributed_mode(self):
             self.trainer.use_ddp = True
             self.trainer.data_parallel_device_ids = None
             self.trainer.on_gpu = False
+            self.trainer.on_cpu = True
         elif self.trainer.distributed_backend == "horovod":
             self._set_horovod_backend()
```

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -21,6 +21,12 @@


 class LoggerStages(str, Enum):
+    """ Train/validation/test phase in each training step.
+
+    >>> # you can match the type with a string
+    >>> LoggerStages.TRAIN == 'train'
+    True
+    """
     TRAIN = "train"
     VAL = "validation"
     TEST = "test"
@@ -35,7 +41,7 @@ def determine_stage(stage_or_testing: Union[str, bool]) -> 'LoggerStages':
         raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given")


-class ResultStoreType(Enum):
+class ResultStoreType(str, Enum):
     INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
     OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"
```
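For context on the `(str, Enum)` change: mixing in `str` makes every member an actual string, so plain equality and dict lookups work against raw values. A minimal sketch outside Lightning (standard library only, illustrative):

```python
from enum import Enum


class ResultStoreType(str, Enum):
    INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
    OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"


# each member *is* a str, so equality against the raw value holds
assert ResultStoreType.INSIDE_BATCH_TRAIN_LOOP == "inside_batch_train_loop"

# and a member hashes like its string value, so dicts keyed by strings
# keep working when enum members are passed in
counts = {"inside_batch_train_loop": 3}
assert counts[ResultStoreType.INSIDE_BATCH_TRAIN_LOOP] == 3
```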

pytorch_lightning/trainer/deprecated_api.py

Lines changed: 135 additions & 0 deletions

```diff
@@ -0,0 +1,135 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytorch_lightning.utilities import rank_zero_warn, DistributedType, DeviceType
+
+
+class DeprecatedDistDeviceAttributes:
+
+    _distrib_type: DistributedType
+    _device_type: DeviceType
+    num_gpus: int
+
+    @property
+    def on_cpu(self) -> bool:
+        # rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._device_type == DeviceType.CPU
+
+    @on_cpu.setter
+    def on_cpu(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        if val:
+            self._device_type = DeviceType.CPU
+
+    @property
+    def on_tpu(self) -> bool:
+        # rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._device_type == DeviceType.TPU
+
+    @on_tpu.setter
+    def on_tpu(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        # todo: add logic so it cannot be set if TPU is missing
+        if val:
+            self._device_type = DeviceType.TPU
+
+    @property
+    def use_tpu(self) -> bool:
+        # rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._device_type == DeviceType.TPU
+
+    @use_tpu.setter
+    def use_tpu(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        # todo: add logic so it cannot be set if TPU is missing
+        if val:
+            self._device_type = DeviceType.TPU
+
+    @property
+    def on_gpu(self) -> bool:
+        # rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._device_type == DeviceType.GPU
+
+    @on_gpu.setter
+    def on_gpu(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        # todo: add logic so it cannot be set if GPU is missing
+        if val:
+            self._device_type = DeviceType.GPU
+
+    @property
+    def use_dp(self) -> bool:
+        # rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._distrib_type == DistributedType.DP
+
+    @use_dp.setter
+    def use_dp(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        if val:
+            self._distrib_type = DistributedType.DP
+
+    @property
+    def use_ddp(self) -> bool:
+        # rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._distrib_type == DistributedType.DDP
+
+    @use_ddp.setter
+    def use_ddp(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        if val:
+            self._distrib_type = DistributedType.DDP
+
+    @property
+    def use_ddp2(self) -> bool:
+        # rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        return self._device_type and self._distrib_type == DistributedType.DDP2
+
+    @use_ddp2.setter
+    def use_ddp2(self, val: bool) -> None:
+        # rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
+        if val:
+            self._distrib_type = DistributedType.DDP2
+
+    @property
+    def use_horovod(self) -> bool:
+        # rank_zero_warn(
+        #     "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
+        # )
+        return self._device_type and self._distrib_type == DistributedType.HOROVOD
+
+    @use_horovod.setter
+    def use_horovod(self, val: bool) -> None:
+        # rank_zero_warn(
+        #     "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
+        # )
+        if val:
+            self._distrib_type = DistributedType.HOROVOD
+
+    @property
+    def use_single_gpu(self) -> bool:
+        # rank_zero_warn(
+        #     "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
+        # )
+        # todo: limiting to exclude DDP2 is not clear, but it comes from the connectors...
+        return (self._device_type and self._device_type == DeviceType.GPU
+                and self.num_gpus == 1
+                and self._distrib_type not in (DistributedType.DDP2, ))
+
+    @use_single_gpu.setter
+    def use_single_gpu(self, val: bool) -> None:
+        # rank_zero_warn(
+        #     "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
+        # )
+        if val:
+            self._device_type = DeviceType.GPU
```
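These property pairs are the "intermediate setters" from the commit title: reads and writes through the old boolean attributes are redirected onto the two new enum fields. A minimal, self-contained sketch of the pattern (standard library only; `MiniTrainer` is an illustrative name, not the Lightning API):

```python
from enum import Enum


class DeviceType(str, Enum):
    CPU = 'CPU'
    GPU = 'GPU'


class MiniTrainer:
    def __init__(self) -> None:
        # the enum field is the single source of truth
        self._device_type = DeviceType.CPU

    @property
    def on_gpu(self) -> bool:
        # deprecated read path, derived from the enum
        return self._device_type == DeviceType.GPU

    @on_gpu.setter
    def on_gpu(self, val: bool) -> None:
        # deprecated write path; like the real setters above, only a
        # truthy value flips the enum (False is silently ignored)
        if val:
            self._device_type = DeviceType.GPU


t = MiniTrainer()
t.on_gpu = True                            # legacy boolean write...
assert t._device_type is DeviceType.GPU    # ...lands in the enum field
assert t.on_gpu                            # legacy read still agrees
```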

pytorch_lightning/trainer/states.py

Lines changed: 10 additions & 2 deletions

```diff
@@ -19,9 +19,17 @@
 import pytorch_lightning


-class TrainerState(Enum):
+class TrainerState(str, Enum):
     """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
-    to indicate what is currently or was executed. """
+    to indicate what is currently or was executed.
+
+    >>> # you can match the type with a string
+    >>> TrainerState.RUNNING == 'RUNNING'
+    True
+    >>> # which is case sensitive
+    >>> TrainerState.FINISHED == 'finished'
+    False
+    """
     INITIALIZING = 'INITIALIZING'
     RUNNING = 'RUNNING'
     FINISHED = 'FINISHED'
```

pytorch_lightning/trainer/trainer.py

Lines changed: 7 additions & 5 deletions

```diff
@@ -24,11 +24,10 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
-from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
+from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.plugins.plugin_connector import PluginConnector
@@ -53,11 +52,11 @@
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.properties import TrainerProperties
-from pytorch_lightning.trainer.states import TrainerState, trainer_state
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn, DeviceType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -78,6 +77,7 @@ class Trainer(
     TrainerLoggingMixin,
     TrainerTrainingTricksMixin,
     TrainerDataLoadingMixin,
+    DeprecatedDistDeviceAttributes,
 ):
     @overwrite_by_env_vars
     def __init__(
@@ -284,6 +284,8 @@ def __init__(
         handle AMP, TPU, accumulated_gradients, etc..
         """
         super().__init__()
+        self._device_type = DeviceType.CPU
+        self._distrib_type = None

         # init connectors
         self.dev_debugger = InternalDebugger(self)
```
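Together with the `DeprecatedDistDeviceAttributes` mixin added to the base classes, the two seeded fields mean a freshly constructed `Trainer` reports its device through the old attributes without storing them separately. Roughly the expected behavior after this commit (a hedged sketch of a default CPU `Trainer`, not a test from the repo):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import DeviceType

trainer = Trainer()
# __init__ seeds the new enum-backed fields...
assert trainer._device_type == DeviceType.CPU
assert trainer._distrib_type is None
# ...and the deprecated boolean attributes are now views over them
assert trainer.on_cpu
assert not trainer.on_gpu
```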

pytorch_lightning/utilities/__init__.py

Lines changed: 58 additions & 1 deletion

```diff
@@ -16,6 +16,7 @@
 import platform
 from distutils.version import LooseVersion
 from enum import Enum
+from typing import Union

 import numpy
 import torch
@@ -66,6 +67,62 @@ def _module_available(module_path: str) -> bool:
 FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps


-class AMPType(Enum):
+class LightningEnum(str, Enum):
+    """ Type of any enumerator with allowed comparison to string, invariant to case. """
+
+    @classmethod
+    def from_str(cls, value: str) -> 'LightningEnum':
+        statuses = [status for status in dir(cls) if not status.startswith('_')]
+        for st in statuses:
+            if st.lower() == value.lower():
+                return getattr(cls, st)
+        return None
+
+    def __eq__(self, other: Union[str, Enum]) -> bool:
+        other = other.value if isinstance(other, Enum) else str(other)
+        return self.value.lower() == other.lower()
+
+
+class AMPType(LightningEnum):
+    """Type of Automatic Mixed Precision used for training.
+
+    >>> # you can match the type with a string
+    >>> AMPType.APEX == 'apex'
+    True
+    """
     APEX = 'apex'
     NATIVE = 'native'
+
+
+class DistributedType(LightningEnum):
+    """ Define the type of distributed computing.
+
+    >>> # you can match the type with a string
+    >>> DistributedType.DDP == 'ddp'
+    True
+    >>> # which is case invariant
+    >>> DistributedType.DDP2 == 'DDP2'
+    True
+    """
+    DP = 'dp'
+    DDP = 'ddp'
+    DDP2 = 'ddp2'
+    DDP_SPAWN = 'ddp_spawn'
+    HOROVOD = 'horovod'
+
+
+class DeviceType(LightningEnum):
+    """ Define the device type by its nature - accelerators.
+
+    >>> DeviceType.CPU == DeviceType.from_str('cpu')
+    True
+    >>> # you can match the type with a string
+    >>> DeviceType.GPU == 'GPU'
+    True
+    >>> # which is case invariant
+    >>> DeviceType.TPU == 'tpu'
+    True
+    """
+    CPU = 'CPU'
+    GPU = 'GPU'
+    TPU = 'TPU'
```
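A quick usage sketch of the new helpers (runnable against this commit's definitions):

```python
from pytorch_lightning.utilities import DistributedType, DeviceType

# equality is case-insensitive thanks to LightningEnum.__eq__
assert DistributedType.DDP == 'ddp'
assert DistributedType.DDP == 'DDP'

# from_str resolves user-supplied strings to members, case-insensitively,
# and falls back to None for unknown values
assert DeviceType.from_str('gpu') is DeviceType.GPU
assert DeviceType.from_str('quantum') is None
```

One design caveat worth flagging: since `LightningEnum` defines `__eq__` without a matching `__hash__`, Python marks its members unhashable, so they cannot serve as dict or set keys unless a `__hash__` is reinstated later.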

tests/trainer/test_trainer.py

Lines changed: 9 additions & 7 deletions

```diff
@@ -1332,15 +1332,17 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
         ),
     ],
 )
+# Todo: mock the number of GPUs so all these tests can run on any device
+# todo: think about simplification, so that the expected values are just a list of use_xxx flags which shall be true...
 def test_trainer_config(trainer_kwargs, expected):
     trainer = Trainer(**trainer_kwargs)
-    assert trainer.use_dp is expected["use_dp"]
-    assert trainer.use_ddp is expected["use_ddp"]
-    assert trainer.use_ddp2 is expected["use_ddp2"]
-    assert trainer.num_gpus == expected["num_gpus"]
-    assert trainer.on_gpu is expected["on_gpu"]
-    assert trainer.use_single_gpu is expected["use_single_gpu"]
-    assert trainer.num_processes == expected["num_processes"]
+    assert trainer.use_dp is expected["use_dp"], 'for input: %s' % trainer_kwargs
+    assert trainer.use_ddp is expected["use_ddp"], 'for input: %s' % trainer_kwargs
+    assert trainer.use_ddp2 is expected["use_ddp2"], 'for input: %s' % trainer_kwargs
+    assert trainer.num_gpus == expected["num_gpus"], 'for input: %s' % trainer_kwargs
+    assert trainer.on_gpu is expected["on_gpu"], 'for input: %s' % trainer_kwargs
+    assert trainer.use_single_gpu is expected["use_single_gpu"], 'for input: %s' % trainer_kwargs
+    assert trainer.num_processes == expected["num_processes"], 'for input: %s' % trainer_kwargs


 def test_trainer_subclassing():
```
