
Commit 54d20dc

Refactor: clean trainer device & distrib getters (#5300)
Squashed commits: warnings, flake8, use_tpu, use_dp, use_ddp, use_horovod, plus unnamed "." fixup commits.
1 parent 2373858 commit 54d20dc

27 files changed: +143 -124 lines changed
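
For context: the refactor replaces the Trainer's scattered `use_dp` / `use_ddp` / `use_tpu`-style booleans with two enum-valued fields, `_distrib_type` and `_device_type`. A rough sketch of the enums the hunks below compare against, assuming the definitions in `pytorch_lightning/utilities/enums.py` at this point in the repo (both are string-backed, which is what makes the plain-string membership test in `overrides/data_parallel.py` further down legal):

    from enum import Enum

    class LightningEnum(str, Enum):
        """String-backed enum: members compare equal to their plain-string values."""

    class DistributedType(LightningEnum):
        DP = 'dp'
        DDP = 'ddp'
        DDP2 = 'ddp2'
        DDP_SPAWN = 'ddp_spawn'
        HOROVOD = 'horovod'

    class DeviceType(LightningEnum):
        CPU = 'CPU'
        GPU = 'GPU'
        TPU = 'TPU'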

pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 24 additions & 13 deletions
@@ -185,14 +185,21 @@ def select_accelerator(self):
         # ----------------------------------
         # choose an accelerator for the user
         # ----------------------------------
-        use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
+        use_slurm_ddp = (
+            self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+            and self.trainer.is_slurm_managing_tasks
+        )

         # torchelastic or general non_slurm ddp
         te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
-        use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed
+        use_torchelastic_ddp = (
+            self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN) and te_flags_passed
+        )

-        use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn"
-        use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu"
+        use_ddp_cpu_spawn = (
+            self.trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)
+            and self.trainer._device_type == DeviceType.CPU
+        )

         use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self._is_using_torchelastic()
         use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.trainer.is_slurm_managing_tasks

@@ -204,8 +211,9 @@ def select_accelerator(self):

         cluster_env = self._select_environment()

+        # TODO: clean-up this branching as most just select class and uses the very same arguments
         # choose the appropriate accelerator backend
-        if self.trainer.use_ddp2:
+        if self.trainer._distrib_type == DistributedType.DDP2:
             accelerator_backend = accelerators.DDP2Accelerator(
                 self.trainer,
                 cluster_env,

@@ -240,7 +248,7 @@ def select_accelerator(self):
                 self.trainer.plugin_connector.ddp_plugin
             )

-        elif use_ddp_spawn:
+        elif self.trainer._distrib_type == DistributedType.DDP_SPAWN:
             accelerator_backend = accelerators.DDPSpawnAccelerator(
                 self.trainer,
                 nprocs=self.trainer.num_processes,

@@ -263,16 +271,16 @@ def select_accelerator(self):
                 ddp_plugin=self.trainer.plugin_connector.ddp_plugin
             )

-        elif self.trainer.use_dp:
+        elif self.trainer._distrib_type == DistributedType.DP:
             accelerator_backend = accelerators.DataParallelAccelerator(self.trainer, cluster_env)

-        elif self.trainer.use_horovod:
+        elif self.trainer._distrib_type == DistributedType.HOROVOD:
             accelerator_backend = accelerators.HorovodAccelerator(self.trainer, cluster_env)

-        elif self.trainer.use_single_gpu:
+        elif self.trainer._device_type == DeviceType.GPU and self.trainer.num_gpus == 1:
             accelerator_backend = accelerators.GPUAccelerator(self.trainer, cluster_env)

-        elif self.trainer.use_tpu:
+        elif self.trainer._device_type == DeviceType.TPU:
             accelerator_backend = accelerators.TPUAccelerator(self.trainer, cluster_env)

         elif self.trainer.distributed_backend is None:

@@ -347,13 +355,16 @@ def set_distributed_mode(self):
             self._set_horovod_backend()

         # throw error to force user ddp or ddp2 choice
-        if self.trainer.num_nodes > 1 and self.trainer._distrib_type not in (DistributedType.DDP2, DistributedType.DDP):
+        _ddp = (DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        if (self.trainer.num_nodes > 1 and self.trainer._distrib_type not in _ddp):
             raise MisconfigurationException(
                 'DataParallel does not support num_nodes > 1. Switching to DistributedDataParallel for you. '
                 'To silence this warning set `accelerator="ddp"` or `accelerator="ddp2"`'
             )

-        rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}')
+        rank_zero_info(
+            f'GPU available: {torch.cuda.is_available()}, used: {self.trainer._device_type == DeviceType.GPU}'
+        )
         num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
         rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')

@@ -366,7 +377,7 @@ def _set_horovod_backend(self):

         # Initialize Horovod to get rank / size info
         hvd.init()
-        if self.trainer.on_gpu:
+        if self.trainer._device_type == DeviceType.GPU:
             # Horovod assigns one local GPU per process
             self.trainer.root_gpu = hvd.local_rank()
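
The hunks above set the pattern for the whole commit: each boolean getter becomes an explicit enum comparison. A hedged summary of the mapping as plain functions (illustrative only, not code from the commit; the "warnings" entries in the squashed message suggest the old getters survive as deprecation shims in the other changed files, which are not shown on this page):

    # Illustrative mapping; `trainer` is any object carrying the new
    # `_distrib_type` / `_device_type` attributes from this refactor.
    def use_ddp(trainer) -> bool:    # was trainer.use_ddp
        return trainer._distrib_type in (DistributedType.DDP, DistributedType.DDP_SPAWN)

    def use_dp(trainer) -> bool:     # was trainer.use_dp
        return trainer._distrib_type == DistributedType.DP

    def use_tpu(trainer) -> bool:    # was trainer.use_tpu
        return trainer._device_type == DeviceType.TPU

    def on_gpu(trainer) -> bool:     # was trainer.on_gpu
        return trainer._device_type == DeviceType.GPU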

pytorch_lightning/accelerators/horovod_accelerator.py

Lines changed: 4 additions & 4 deletions
@@ -19,7 +19,7 @@

 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
-from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType, DeviceType
 from pytorch_lightning.utilities.distributed import rank_zero_only

 if _HOROVOD_AVAILABLE:

@@ -46,7 +46,7 @@ def setup(self, model):
         # call setup after the ddp process has connected
         self.trainer.call_setup_hook(model)

-        if torch.cuda.is_available() and self.trainer.on_gpu:
+        if torch.cuda.is_available() and self.trainer._device_type == DeviceType.GPU:
             # Horovod: pin GPU to local rank
             assert self.trainer.root_gpu == hvd.local_rank()
             torch.cuda.set_device(self.trainer.root_gpu)

@@ -116,7 +116,7 @@ def train(self):
         return results

     def _step(self, model_step: Callable, args):
-        if self.trainer.on_gpu:
+        if self.trainer._device_type == DeviceType.GPU:
             args[0] = self.batch_to_device(args[0], hvd.local_rank())

         if self.trainer.amp_backend == AMPType.NATIVE:

@@ -141,7 +141,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         optimizer.synchronize()

     def on_train_epoch_end(self, outputs):
-        hvd.join(hvd.local_rank() if self.trainer.on_gpu else -1)
+        hvd.join(hvd.local_rank() if self.trainer._device_type == DeviceType.GPU else -1)

     def barrier(self, name: Optional[str] = None):
         hvd.join()
pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 from typing import Dict, List, Tuple

 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities import rank_zero_only, DeviceType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict

@@ -104,7 +104,7 @@ def on_train_start(self, trainer, *args, **kwargs):
                 'Cannot use GPUStatsMonitor callback with Trainer that has no logger.'
             )

-        if not trainer.on_gpu:
+        if trainer._device_type != DeviceType.GPU:
             raise MisconfigurationException(
                 'You are using GPUStatsMonitor but are not running on GPU'
                 f' since gpus attribute in Trainer is set to {trainer.gpus}.'
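
Behavior is unchanged for users: the callback still raises unless training actually runs on GPU, it just detects that via `_device_type` now. A minimal usage sketch with the standard Trainer API of this release:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import GPUStatsMonitor

    # Raises MisconfigurationException unless the Trainer is on GPU.
    trainer = Trainer(gpus=1, callbacks=[GPUStatsMonitor()])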

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 11 deletions
@@ -85,17 +85,8 @@ def __init__(self, *args, **kwargs):
         #: Pointer to the logger object
         self.logger = None

-        #: True if using dp
-        self.use_dp = False
-
-        #: True if using ddp
-        self.use_ddp = False
-
-        #: True if using ddp2
-        self.use_ddp2 = False
-
-        # True if on tpu
-        self.use_tpu = False
+        self._distrib_type = None
+        self._device_type = None

         #: True if using amp
         self.use_amp = False

pytorch_lightning/core/memory.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 import torch.nn as nn
 from torch.utils.hooks import RemovableHandle

-from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities import AMPType, DeviceType

 PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
 UNKNOWN_SIZE = "?"

@@ -229,7 +229,7 @@ def _forward_example_input(self) -> None:
         input_ = model.example_input_array
         input_ = model.transfer_batch_to_device(input_, model.device)

-        if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu:
+        if trainer is not None and trainer.amp_backend == AMPType.NATIVE and trainer._device_type != DeviceType.TPU:
             model.forward = torch.cuda.amp.autocast()(model.forward)

         mode = model.training

pytorch_lightning/core/optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@

 from torch.optim.optimizer import Optimizer

-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE, DeviceType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _TPU_AVAILABLE:

@@ -125,7 +125,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
         optimizer = self._optimizer
         model = trainer.get_model()

-        if trainer.on_tpu:
+        if trainer._device_type == DeviceType.TPU:
             with trainer.profiler.profile(profiler_name):
                 xm.optimizer_step(optimizer, optimizer_args={'closure': closure, **kwargs})

pytorch_lightning/overrides/data_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ def _worker(i, module, input, kwargs, device=None):
         if output is None:
             warn_missing_output(fx_called)

-        if output is not None and (module.use_dp or module.use_ddp2):
+        if output is not None and module._distrib_type in ('dp', 'ddp2'):
             auto_squeeze_dim_zeros(output)
         # ---------------
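
This hunk is the one place the refactor compares the enum against bare strings. That only works because `DistributedType` is string-backed (see the sketch at the top of this page); a quick check of the assumption:

    # Assumes the str-backed DistributedType sketched above.
    assert DistributedType.DP in ('dp', 'ddp2')      # a str-enum member equals its value
    assert DistributedType.DDP2 in ('dp', 'ddp2')
    assert DistributedType.DDP not in ('dp', 'ddp2')
    assert None not in ('dp', 'ddp2')                # an unset _distrib_type simply fails the test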

pytorch_lightning/plugins/ddp_plugin.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.plugins.plugin import LightningPlugin
+from pytorch_lightning.utilities import DeviceType


 class DDPPlugin(LightningPlugin):

@@ -95,7 +96,7 @@ def init_ddp_connection(
         os.environ["MASTER_ADDR"] = str(cluster_environment.master_address())
         os.environ["MASTER_PORT"] = str(cluster_environment.master_port())
         os.environ["WORLD_SIZE"] = str(cluster_environment.world_size())
-        torch_backend = "nccl" if trainer.on_gpu else "gloo"
+        torch_backend = "nccl" if trainer._device_type == DeviceType.GPU else "gloo"

         if not torch_distrib.is_initialized():
             log.info(
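
The backend choice itself is unchanged: NCCL for GPU collectives, Gloo as the portable CPU fallback; only the predicate moved off `trainer.on_gpu`. Restated as a self-contained helper (a sketch, not commit code):

    def select_torch_backend(device_type) -> str:
        # NCCL only handles CUDA tensors; Gloo works everywhere, so it is the
        # safe default for CPU (and ddp_cpu) runs.
        return "nccl" if device_type == DeviceType.GPU else "gloo"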

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 7 additions & 6 deletions
@@ -21,7 +21,8 @@

 import pytorch_lightning
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities import (
+    _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn, DeviceType)
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -50,26 +51,26 @@ def restore_weights(self, model: LightningModule) -> None:
         3. don't restore
         """
         # clear cache before restore
-        if self.trainer.on_gpu:
+        if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()

         # 1. Attempt to restore states from HPC checkpoint
         dir_path_hpc = str(self.trainer.weights_save_path)
         max_suffix = self.max_ckpt_in_folder(dir_path_hpc, "hpc_ckpt_")
         if max_suffix is not None:
             checkpoint_path = f'{dir_path_hpc}/hpc_ckpt_{max_suffix}.ckpt'
-            self.hpc_load(checkpoint_path, self.trainer.on_gpu)
+            self.hpc_load(checkpoint_path, self.trainer._device_type == DeviceType.GPU)
             rank_zero_info(f'restored hpc model from: {checkpoint_path}')

         # 2. Attempt to restore states from `resume_from_checkpoint` file
         elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing:
-            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
+            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer._device_type == DeviceType.GPU)

         # wait for all to catch up
         self.trainer.accelerator_backend.barrier('TrainerIOMixin.restore_weights')

         # clear cache after restore
-        if self.trainer.on_gpu:
+        if self.trainer._device_type == DeviceType.GPU:
             torch.cuda.empty_cache()

     def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:

@@ -291,7 +292,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

         # dump amp scaling
         if (self.trainer.amp_backend == AMPType.NATIVE
-                and not self.trainer.use_tpu
+                and self.trainer._device_type != DeviceType.TPU
                 and self.trainer.scaler is not None):
             checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
         elif self.trainer.amp_backend == AMPType.APEX:
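
The `dump_checkpoint` guard now reads as three independent conditions. Restated as a predicate for clarity (again a sketch, not commit code):

    def should_dump_native_amp_state(trainer) -> bool:
        # Scaler state exists only for native AMP, is meaningless on TPU,
        # and may be absent entirely (trainer.scaler is None).
        return (trainer.amp_backend == AMPType.NATIVE
                and trainer._device_type != DeviceType.TPU
                and trainer.scaler is not None)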

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 import torch

 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.utilities import DistributedType


 class LoggerStages(str, Enum):

@@ -343,7 +344,7 @@ def cache_result(self) -> None:
             hook_result.detach()
             if self.trainer.move_metrics_to_cpu:
                 hook_result.cpu()
-            elif self.trainer.use_dp:
+            elif self.trainer._distrib_type == DistributedType.DP:
                 hook_result.to(torch.device("cuda", self.trainer.root_gpu))

         self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info)
