
Commit 0f64f15

ref: unify slurm and TE under backendPlugin 1/n (#4578)
* ref: unify slurm and TE under backendPlugin
1 parent 09a5169 commit 0f64f15

File tree

9 files changed: 38 additions, 33 deletions
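
Taken together, the diffs below replace per-accelerator parsing of scheduler environment variables with a single call into the injected cluster environment. Here is a minimal sketch of the resulting pattern; SomeDDPAccelerator is a simplified stand-in for the accelerators touched in this commit, and the surrounding Trainer wiring is assumed:

import os


class SLURMEnvironment:
    def local_rank(self):
        return int(os.environ['SLURM_LOCALID'])


class TorchElasticEnvironment:
    def local_rank(self):
        return int(os.environ['LOCAL_RANK'])


class SomeDDPAccelerator:
    """Simplified stand-in for the DDP/DDP2 accelerators in this commit."""

    def __init__(self, trainer, cluster_environment):
        self.trainer = trainer
        self.cluster_environment = cluster_environment
        self.task_idx = None

    def setup(self, model):
        self.trainer.model = model
        # Before this commit each accelerator read SLURM_LOCALID or LOCAL_RANK
        # directly; now the cluster environment decides which variable applies.
        self.task_idx = self.cluster_environment.local_rank()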

pytorch_lightning/accelerators/ddp2_accelerator.py

Lines changed: 1 addition & 12 deletions
@@ -46,19 +46,8 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
         self.nickname = 'ddp2'
 
     def setup(self, model):
-        self._resolve_task_idx()
         self.trainer.model = model
-
-    def _resolve_task_idx(self):
-        if self.trainer.is_slurm_managing_tasks:
-            self.task_idx = int(os.environ['SLURM_LOCALID'])
-        else:
-            # torchelastic or general non_slurm ddp2
-            try:
-                self.task_idx = int(os.environ['LOCAL_RANK'])
-            except Exception as exp:
-                m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags'
-                raise MisconfigurationException(m) from exp
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model

pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['SLURM_LOCALID'])
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -118,7 +118,7 @@ def ddp_train(self, process_idx, model):
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank
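
The progress-bar check in ddp_train also changes: rather than comparing the global rank, the callback is disabled on every process that is not local process 0 on node 0 (the process that ends up with global rank 0). A hedged sketch of the new condition; disable_progress_bar_on_non_zero_ranks is a hypothetical helper, not part of the diff:

def disable_progress_bar_on_non_zero_ranks(trainer, process_idx):
    # Keep the progress bar only on local process 0 of node 0; every other
    # process disables its callback so output is not duplicated.
    bar = trainer.progress_bar_callback
    if (trainer.node_rank != 0 or process_idx != 0) and bar is not None:
        bar.disable()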

pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
5252

5353
def setup(self, model):
5454
self.trainer.model = model
55-
self.task_idx = int(os.environ['LOCAL_RANK'])
55+
self.task_idx = self.cluster_environment.local_rank()
5656

5757
def train(self):
5858
model = self.trainer.model
@@ -117,7 +117,7 @@ def ddp_train(self, process_idx, model):
117117
self.set_world_ranks(process_idx)
118118

119119
# toggle prog bar
120-
if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
120+
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
121121
self.trainer.progress_bar_callback.disable()
122122

123123
# set warning rank

pytorch_lightning/accelerators/ddp_slurm_accelerator.py

Lines changed: 3 additions & 8 deletions
@@ -25,7 +25,6 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
-from pytorch_lightning.utilities.seed import seed_everything
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -52,7 +51,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
 
     def setup(self, model):
         self.trainer.model = model
-        self.task_idx = int(os.environ['SLURM_LOCALID'])
+        self.task_idx = self.cluster_environment.local_rank()
 
     def train(self):
         model = self.trainer.model
@@ -88,7 +87,7 @@ def test_step(self, args):
         output = self.training_step(args)
         return output
 
-    def barrier(self, name: str = None):
+    def barrier(self, name: Optional[str] = None):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()
 
@@ -115,15 +114,11 @@ def ddp_train(self, process_idx, model):
             Dict with evaluation results
 
         """
-        seed = os.environ.get("PL_GLOBAL_SEED")
-        if seed is not None:
-            seed_everything(int(seed))
-
         # determine which process we are and world size
         self.set_world_ranks(process_idx)
 
         # toggle prog bar
-        if self.trainer.global_rank != 0 and self.trainer.progress_bar_callback is not None:
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
 
         # set warning rank

pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
from pytorch_lightning.core.lightning import LightningModule
2525
from pytorch_lightning.distributed.dist import LightningDistributed
2626
from pytorch_lightning.utilities import AMPType
27-
from pytorch_lightning.utilities.distributed import rank_zero_only
28-
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
27+
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
2928

3029

3130
try:
@@ -53,7 +52,7 @@ def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
5352

5453
def setup(self, model):
5554
self.trainer.model = model
56-
self.task_idx = int(os.environ['LOCAL_RANK'])
55+
self.task_idx = self.cluster_environment.local_rank()
5756

5857
def train(self):
5958
model = self.trainer.model
@@ -120,7 +119,7 @@ def ddp_train(self, process_idx, model):
120119
self.set_world_ranks(process_idx)
121120

122121
# toggle prog bar
123-
if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
122+
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
124123
self.trainer.progress_bar_callback.disable()
125124

126125
# set warning rank

pytorch_lightning/cluster_environments/cluster_environment.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 class ClusterEnvironment:
 
     def __init__(self):
@@ -25,3 +26,6 @@ def master_port(self):
 
     def world_size(self):
         return self._world_size
+
+    def local_rank(self):
+        pass
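
The base class now exposes local_rank() as part of the cluster-environment interface (the default body is a no-op). A hypothetical custom environment would override it to point at its own scheduler variable; MY_LOCAL_RANK below is made up for illustration:

import os

from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment


class MyClusterEnvironment(ClusterEnvironment):
    def local_rank(self):
        # MY_LOCAL_RANK is an illustrative placeholder; real schedulers export
        # their own equivalent (SLURM_LOCALID, LOCAL_RANK, ...).
        return int(os.environ['MY_LOCAL_RANK'])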

pytorch_lightning/cluster_environments/slurm_environment.py

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ def master_port(self):
     def world_size(self):
         return self._world_size
 
+    def local_rank(self):
+        return int(os.environ['SLURM_LOCALID'])
+
     def _resolve_root_node_address(self, root_node):
         if '[' in root_node:
             name, numbers = root_node.split('[', maxsplit=1)

pytorch_lightning/cluster_environments/torchelastic_environment.py

Lines changed: 3 additions & 0 deletions
@@ -46,3 +46,6 @@ def master_port(self):
 
     def world_size(self):
         return os.environ.get('WORLD_SIZE')
+
+    def local_rank(self):
+        return int(os.environ['LOCAL_RANK'])
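
With both built-in environments implementing the same accessor, callers never branch on the scheduler. A small usage sketch, assuming both environments can be constructed without arguments as in this version of the code:

import os

from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment

# Each scheduler exports the local rank under a different variable name,
# but both environments surface it through the same local_rank() call.
os.environ['SLURM_LOCALID'] = '3'
os.environ['LOCAL_RANK'] = '3'

assert SLURMEnvironment().local_rank() == 3
assert TorchElasticEnvironment().local_rank() == 3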

tests/backends/test_accelerator_connector.py

Lines changed: 17 additions & 5 deletions
@@ -104,7 +104,7 @@ def on_fit_start(self, trainer, pl_module):
     "SLURM_NTASKS": "2",
     "SLURM_JOB_NAME": "SOME_NAME",
     "SLURM_NODEID": "0",
-    "SLURM_LOCALID": "0"
+    "SLURM_LOCALID": "10"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
 def test_accelerator_choice_ddp_slurm(tmpdir):
@@ -113,6 +113,8 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
             assert isinstance(trainer.accelerator_backend, accelerators.DDPSLURMAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
             raise SystemExit()
 
     model = BoringModel()
@@ -133,7 +135,7 @@ def on_fit_start(self, trainer, pl_module):
     "SLURM_JOB_NAME": "SOME_NAME",
     "SLURM_NODEID": "0",
     "LOCAL_RANK": "0",
-    "SLURM_LOCALID": "0"
+    "SLURM_LOCALID": "10"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
 def test_accelerator_choice_ddp2_slurm(tmpdir):
@@ -142,6 +144,9 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp2
             assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, SLURMEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
+
             raise SystemExit()
 
     model = BoringModel()
@@ -159,7 +164,7 @@ def on_fit_start(self, trainer, pl_module):
 @mock.patch.dict(os.environ, {
     "CUDA_VISIBLE_DEVICES": "0,1",
     "WORLD_SIZE": "2",
-    "LOCAL_RANK": "0",
+    "LOCAL_RANK": "10",
     "NODE_RANK": "0"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
@@ -169,6 +174,8 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
             assert isinstance(trainer.accelerator_backend, accelerators.DDPTorchElasticAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
             raise SystemExit()
 
     model = BoringModel()
@@ -186,7 +193,7 @@ def on_fit_start(self, trainer, pl_module):
 @mock.patch.dict(os.environ, {
     "CUDA_VISIBLE_DEVICES": "0,1",
     "WORLD_SIZE": "2",
-    "LOCAL_RANK": "0",
+    "LOCAL_RANK": "10",
     "NODE_RANK": "0"
 })
 @mock.patch('torch.cuda.device_count', return_value=2)
@@ -196,6 +203,8 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp2
             assert isinstance(trainer.accelerator_backend, accelerators.DDP2Accelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
             raise SystemExit()
 
     model = BoringModel()
@@ -212,7 +221,7 @@ def on_fit_start(self, trainer, pl_module):
 
 @mock.patch.dict(os.environ, {
     "WORLD_SIZE": "1",
-    "LOCAL_RANK": "0",
+    "LOCAL_RANK": "10",
     "NODE_RANK": "0"
 })
 @mock.patch('torch.cuda.device_count', return_value=0)
@@ -222,6 +231,9 @@ def on_fit_start(self, trainer, pl_module):
             assert trainer.use_ddp
             assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUTorchElasticAccelerator)
             assert isinstance(trainer.accelerator_backend.cluster_environment, TorchElasticEnvironment)
+            assert trainer.accelerator_backend.task_idx == 10
+            assert trainer.accelerator_backend.cluster_environment.local_rank() == trainer.accelerator_backend.task_idx
+
             raise SystemExit()
 
     model = BoringModel()
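
The tests set the mocked local rank to "10" instead of the default "0", so a task_idx that silently came from the wrong place (for example a hard-coded zero) would fail the assertion. A trimmed sketch of that pattern, reduced to the environment class itself rather than the full Trainer round-trip used above:

import os
from unittest import mock

from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment


@mock.patch.dict(os.environ, {"SLURM_LOCALID": "10"})
def test_slurm_local_rank_is_read_from_env():
    # A non-zero value makes it obvious if the rank is read from the wrong source.
    assert SLURMEnvironment().local_rank() == 10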
