Skip to content

Commit 8835c26

Browse files
williamFalcon and SeanNaren
authored and committed
ref: unify slurm and TE under backendPlugin 2/n (#4580)
(cherry picked from commit bfaf014)
1 parent e8c00bc commit 8835c26

File tree

5 files changed

+7
-221
lines changed

5 files changed

+7
-221
lines changed

pytorch_lightning/accelerators/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
2121
from pytorch_lightning.accelerators.tpu_accelerator import TPUAccelerator
2222
from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator
23-
from pytorch_lightning.accelerators.ddp_slurm_accelerator import DDPSLURMAccelerator
24-
from pytorch_lightning.accelerators.ddp_torchelastic_accelerator import DDPTorchElasticAccelerator
23+
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
2524
from pytorch_lightning.accelerators.ddp_cpu_torchelastic_accelerator import DDPCPUTorchElasticAccelerator
2625
from pytorch_lightning.accelerators.ddp_cpu_slurm_accelerator import DDPCPUSLURMAccelerator
2726
from pytorch_lightning.accelerators.accelerator import Accelerator

pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def select_accelerator(self):
227227
)
228228

229229
elif use_slurm_ddp:
230-
accelerator_backend = accelerators.DDPSLURMAccelerator(
230+
accelerator_backend = accelerators.DDPHPCAccelerator(
231231
self.trainer,
232232
cluster_env,
233233
self.trainer.plugin_connector.ddp_plugin
@@ -241,7 +241,7 @@ def select_accelerator(self):
241241
)
242242

243243
elif use_torchelastic_ddp:
244-
accelerator_backend = accelerators.DDPTorchElasticAccelerator(
244+
accelerator_backend = accelerators.DDPHPCAccelerator(
245245
self.trainer,
246246
cluster_env,
247247
self.trainer.plugin_connector.ddp_plugin

pytorch_lightning/accelerators/ddp_slurm_accelerator.py renamed to pytorch_lightning/accelerators/ddp_hpc_accelerator.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pytorch_lightning.utilities import AMPType
2727
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
2828

29+
2930
try:
3031
from hydra.utils import to_absolute_path, get_original_cwd
3132
from hydra.core.hydra_config import HydraConfig
@@ -35,12 +36,7 @@
3536
HYDRA_AVAILABLE = True
3637

3738

38-
# -------------------------------------------
39-
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
40-
# TEMP CLASS WHILE WE DECOUPLE SLURM FROM DDP
41-
# !!!!!!!!!!!!!! NOTE !!!!!!!!!!!!!!!!!!!!!!
42-
# -------------------------------------------
43-
class DDPSLURMAccelerator(Accelerator):
39+
class DDPHPCAccelerator(Accelerator):
4440

4541
def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
4642
super().__init__(trainer, cluster_environment, ddp_plugin)

pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py

Lines changed: 0 additions & 209 deletions
This file was deleted.

tests/backends/test_accelerator_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_accelerator_choice_ddp_slurm(tmpdir):
111111
class CB(Callback):
112112
def on_fit_start(self, trainer, pl_module):
113113
assert trainer.use_ddp
114-
assert isinstance(trainer.accelerator_backend, accelerators.DDPSLURMAccelerator)
114+
assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator)
115115
assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
116116
assert trainer.accelerator_backend.task_idx == 10
117117
assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
@@ -172,7 +172,7 @@ def test_accelerator_choice_ddp_te(tmpdir):
172172
class CB(Callback):
173173
def on_fit_start(self, trainer, pl_module):
174174
assert trainer.use_ddp
175-
assert isinstance(trainer.accelerator_backend, accelerators.DDPTorchElasticAccelerator)
175+
assert isinstance(trainer.accelerator_backend, accelerators.DDPHPCAccelerator)
176176
assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
177177
assert trainer.accelerator_backend.task_idx == 10
178178
assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx

0 commit comments

Comments (0)