|
7 | 7 | from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info |
8 | 8 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
9 | 9 | from pytorch_lightning import _logger as log |
| 10 | +from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment |
| 11 | +from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment |
10 | 12 |
|
11 | 13 | try: |
12 | 14 | import torch_xla |
@@ -40,9 +42,12 @@ def on_trainer_init( |
40 | 42 | sync_batchnorm, |
41 | 43 | benchmark, |
42 | 44 | replace_sampler_ddp, |
43 | | - deterministic |
| 45 | + deterministic, |
| 46 | + cluster_environment |
44 | 47 | ): |
45 | 48 | self.trainer.deterministic = deterministic |
| 49 | + self.cluster_environment = cluster_environment |
| 50 | + |
46 | 51 | torch.backends.cudnn.deterministic = self.trainer.deterministic |
47 | 52 | if self.trainer.deterministic: |
48 | 53 | # fixing non-deterministic part of horovod |
@@ -123,6 +128,22 @@ def on_trainer_init( |
123 | 128 |
|
124 | 129 | self.trainer.replace_sampler_ddp = replace_sampler_ddp |
125 | 130 |
|
| 131 | + def _select_environment(self): |
| 132 | + env = None |
| 133 | + |
| 134 | + # in priority: user environment, torchelastic (which is a generic environment), slurm |
| 135 | + if self.cluster_environment is not None: |
| 136 | + env = self.cluster_environment |
| 137 | + elif self._is_using_torchelastic(): |
| 138 | + env = TorchElasticEnvironment() |
| 139 | + elif self.trainer.is_slurm_managing_tasks: |
| 140 | + env = SLURMEnvironment() |
| 141 | + return env |
| 142 | + |
| 143 | + def _is_using_torchelastic(self): |
| 144 | + te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ) |
| 145 | + return te_flags_passed |
| 146 | + |
126 | 147 | def select_accelerator(self): |
127 | 148 | if self.trainer.accelerator_backend is not None: |
128 | 149 | return self.trainer.accelerator_backend |
|
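For context, here is a minimal standalone sketch of the selection priority the new `_select_environment` / `_is_using_torchelastic` helpers implement (user-supplied environment first, then torchelastic detected from its environment variables, then SLURM). The function and the string return values below are stand-ins for illustration only, not the real Lightning classes:

```python
import os

# Hedged sketch: reproduces the priority logic above with plain strings
# instead of the real SLURMEnvironment / TorchElasticEnvironment objects.
def select_environment(user_env, slurm_managing_tasks, environ=os.environ):
    # 1. an explicitly supplied cluster environment always wins
    if user_env is not None:
        return user_env
    # 2. torchelastic is detected purely from the env vars it sets on workers
    if 'WORLD_SIZE' in environ and ('GROUP_RANK' in environ or 'NODE_RANK' in environ):
        return 'torchelastic'
    # 3. otherwise fall back to SLURM when the trainer is managing SLURM tasks
    if slurm_managing_tasks:
        return 'slurm'
    return None

# a torchelastic launch exports WORLD_SIZE plus GROUP_RANK (or NODE_RANK)
assert select_environment(None, False, {'WORLD_SIZE': '8', 'GROUP_RANK': '0'}) == 'torchelastic'
# a user-provided environment takes precedence over everything else
assert select_environment('my_env', True, {'WORLD_SIZE': '8', 'GROUP_RANK': '0'}) == 'my_env'
assert select_environment(None, True, {}) == 'slurm'
```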