Skip to content

Commit 2c21f7d

Browse files
ref: adding compute environments (2/n) (#3842)
* ref: adding compute environments (2/n) * ref: adding compute environments (2/n) * ref: adding compute environments (2/n) * ref: adding compute environments (2/n)
1 parent a628d18 commit 2c21f7d

18 files changed

+81
-31
lines changed

pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
88
from pytorch_lightning.utilities.exceptions import MisconfigurationException
99
from pytorch_lightning import _logger as log
10+
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
11+
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
1012

1113
try:
1214
import torch_xla
@@ -40,9 +42,12 @@ def on_trainer_init(
4042
sync_batchnorm,
4143
benchmark,
4244
replace_sampler_ddp,
43-
deterministic
45+
deterministic,
46+
cluster_environment
4447
):
4548
self.trainer.deterministic = deterministic
49+
self.cluster_environment = cluster_environment
50+
4651
torch.backends.cudnn.deterministic = self.trainer.deterministic
4752
if self.trainer.deterministic:
4853
# fixing non-deterministic part of horovod
@@ -123,6 +128,22 @@ def on_trainer_init(
123128

124129
self.trainer.replace_sampler_ddp = replace_sampler_ddp
125130

131+
# Choose the cluster environment object for this run.
# Returns the first match in the priority order below, or None when
# no branch applies (no user-supplied environment, torchelastic not
# detected, and the trainer does not report SLURM-managed tasks).
def _select_environment(self):
132+
env = None  # fallback result when none of the branches below match
133+
134+
# priority order: user-supplied environment, then torchelastic (which is a generic environment), then slurm
135+
if self.cluster_environment is not None:
136+
env = self.cluster_environment
137+
elif self._is_using_torchelastic():
138+
env = TorchElasticEnvironment()
139+
elif self.trainer.is_slurm_managing_tasks:
140+
env = SLURMEnvironment()
141+
return env
142+
143+
# Report whether the process environment carries the variables this code
# treats as the torchelastic signature: WORLD_SIZE must be set together
# with at least one of GROUP_RANK or NODE_RANK. Only presence is checked;
# the values themselves are not read.
# NOTE(review): presumably these variables are exported by the
# torchelastic launcher -- confirm against its documentation.
def _is_using_torchelastic(self):
144+
te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
145+
return te_flags_passed
146+
126147
def select_accelerator(self):
127148
if self.trainer.accelerator_backend is not None:
128149
return self.trainer.accelerator_backend

pytorch_lightning/accelerators/base_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323

2424
class Accelerator(object):
2525

26-
def __init__(self, trainer):
26+
def __init__(self, trainer, cluster_environment=None):
2727
self.trainer = trainer
28+
self.cluster_environment = cluster_environment
2829
self.dist = AttributeDict(rank=0, device=None)
2930

3031
def setup(self, model):

pytorch_lightning/accelerators/cpu_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
class CPUBackend(Accelerator):
2222

23-
def __init__(self, trainer):
24-
super().__init__(trainer)
23+
def __init__(self, trainer, cluster_environment=None):
24+
super().__init__(trainer, cluster_environment)
2525

2626
def setup(self, model):
2727
# run through amp wrapper

pytorch_lightning/accelerators/ddp2_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535

3636
class DDP2Backend(Accelerator):
3737

38-
def __init__(self, trainer):
39-
super().__init__(trainer)
38+
def __init__(self, trainer, cluster_environment=None):
39+
super().__init__(trainer, cluster_environment)
4040
self.task_idx = None
4141
self.dist = LightningDistributed()
4242

pytorch_lightning/accelerators/ddp_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242

4343
class DDPBackend(Accelerator):
4444

45-
def __init__(self, trainer):
46-
super().__init__(trainer)
45+
def __init__(self, trainer, cluster_environment=None):
46+
super().__init__(trainer, cluster_environment)
4747
self.task_idx = None
4848
self._has_spawned_children = False
4949
self.interactive_ddp_procs = []

pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838

3939
class DDPCPUSpawnBackend(Accelerator):
4040

41-
def __init__(self, trainer, nprocs):
42-
super().__init__(trainer)
41+
def __init__(self, trainer, nprocs, cluster_environment=None):
42+
super().__init__(trainer, cluster_environment)
4343
self.mp_queue = None
4444
self.nprocs = nprocs
4545
self.dist = LightningDistributed()

pytorch_lightning/accelerators/ddp_slurm_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
# -------------------------------------------
4141
class DDPSLURMBackend(Accelerator):
4242

43-
def __init__(self, trainer):
44-
super().__init__(trainer)
43+
def __init__(self, trainer, cluster_environment=None):
44+
super().__init__(trainer, cluster_environment)
4545
self.task_idx = None
4646
self._has_spawned_children = False
4747
self.dist = LightningDistributed()

pytorch_lightning/accelerators/ddp_spawn_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040

4141
class DDPSpawnBackend(Accelerator):
4242

43-
def __init__(self, trainer, nprocs):
44-
super().__init__(trainer)
43+
def __init__(self, trainer, nprocs, cluster_environment=None):
44+
super().__init__(trainer, cluster_environment)
4545
self.mp_queue = None
4646
self.nprocs = nprocs
4747
self.dist = LightningDistributed()

pytorch_lightning/accelerators/ddp_torchelastic_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
# -------------------------------------------
4141
class DDPTorchElasticBackend(Accelerator):
4242

43-
def __init__(self, trainer):
44-
super().__init__(trainer)
43+
def __init__(self, trainer, cluster_environment=None):
44+
super().__init__(trainer, cluster_environment)
4545
self.task_idx = None
4646
self._has_spawned_children = False
4747
self.dist = LightningDistributed()

pytorch_lightning/accelerators/dp_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
class DataParallelBackend(Accelerator):
2727

28-
def __init__(self, trainer):
29-
super().__init__(trainer)
28+
def __init__(self, trainer, cluster_environment=None):
29+
super().__init__(trainer, cluster_environment)
3030
self.model_autocast_original_forward = None
3131
self.dist = LightningDistributed()
3232

0 commit comments

Comments
 (0)