Skip to content

Commit 460ab54

Browse files
Gen ddp support (#1961)
* updated docs * added mixed * added mixed
1 parent c967b88 commit 460ab54

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -856,17 +856,21 @@ def fit(
856856
if self.use_ddp2:
857857
if self.is_slurm_managing_tasks:
858858
task = int(os.environ['SLURM_LOCALID'])
859-
elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
859+
860+
# torchelastic or general non_slurm ddp2
861+
elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
860862
task = int(os.environ['LOCAL_RANK'])
861863
self.ddp_train(task, model)
862864
elif self.use_ddp:
863865
if self.is_slurm_managing_tasks:
864866
task = int(os.environ['SLURM_LOCALID'])
865867
self.ddp_train(task, model)
866-
# torchelastic
867-
elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
868+
869+
# torchelastic or general non_slurm ddp
870+
elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
868871
task = int(os.environ['LOCAL_RANK'])
869872
self.ddp_train(task, model)
873+
870874
else:
871875
self.__set_random_port()
872876
# track for predict

0 commit comments

Comments
 (0)