10 changes: 6 additions & 4 deletions pytorch_lightning/trainer/distrib_data_parallel.py

@@ -372,14 +372,16 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
     def __set_random_port(self):
         """
         When running DDP NOT managed by SLURM, the ports might collide
-        :return:
         """
         try:
             default_port = os.environ['MASTER_PORT']
         except Exception:
-            import random
-            default_port = random.randint(10000, 19000)
-            os.environ['MASTER_PORT'] = str(default_port)
+            # use the process id as a seed to a generator for port only
+            pid = os.getpid()
+            rng1 = np.random.RandomState(pid)
+            default_port = rng1.randint(10000, 19999, 1)[0]
+
+        os.environ['MASTER_PORT'] = str(default_port)

     def spawn_ddp_children(self, model):
         self.__set_random_port()
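Side note on the new port logic: seeding a private numpy generator with the process id makes the fallback port deterministic for a given parent process, so the value can be recomputed instead of depending on shared random state. A minimal sketch of that behaviour, assuming numpy is installed (the helper name pick_master_port is illustrative, not from the PR):

```python
import os

import numpy as np


def pick_master_port() -> int:
    """Deterministically derive a rendezvous port from the current pid."""
    pid = os.getpid()
    rng = np.random.RandomState(pid)  # private generator, seeded per process
    # numpy's randint upper bound is exclusive, so this draws from 10000..19998
    return int(rng.randint(10000, 19999, 1)[0])


if __name__ == '__main__':
    # the draw is stable: the same pid always yields the same port
    assert pick_master_port() == pick_master_port()
    print(pick_master_port())
```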
11 changes: 10 additions & 1 deletion pytorch_lightning/trainer/trainer.py

@@ -32,7 +32,7 @@
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info
+from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only

 try:
     from apex import amp

@@ -322,6 +322,14 @@ def __init__(
         # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
         os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

+        # init the default rank if exists
+        # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
+        # this way we only show it on rank 0
+        if 'LOCAL_RANK' in os.environ:
+            rank_zero_only.rank = os.environ['LOCAL_RANK']
+        if 'SLURM_JOB_ID' in os.environ:
+            rank_zero_only.rank = os.environ['SLURM_JOB_ID']
Contributor comment on the SLURM_JOB_ID line above:

    I think this line here breaks logging in my setup. rank_zero_only.rank is set to a non-zero value (the SLURM job id), and thus none of the logging functions are ever executed. Did you mean SLURM_PROCID, which is the MPI rank?
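For context on why that assignment matters: rank_zero_only gates a function on its rank attribute, so any non-zero value (a SLURM job id included) silently drops every call. A rough sketch of the mechanism, simplified from the pytorch_lightning utilities (the exact implementation may differ):

```python
from functools import wraps


def rank_zero_only(fn):
    """Run fn only when rank_zero_only.rank is 0; drop the call otherwise."""
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
        # any other rank: the call is silently dropped
    return wrapped_fn


rank_zero_only.rank = 0  # default: behave like rank 0


@rank_zero_only
def log_info(msg):
    print(msg)


log_info('printed')            # rank 0 -> runs
rank_zero_only.rank = 1234567  # e.g. a SLURM job id rather than a rank
log_info('never printed')      # non-zero "rank" -> silently dropped
```

Note also that os.environ values are strings, so if the guard compares against the integer 0, even LOCAL_RANK=0 would fail the check unless the value is cast with int() first.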
+
         # Init callbacks
         self.prepare_data_per_node = prepare_data_per_node
         self.callbacks = callbacks or []

@@ -892,6 +900,7 @@ def fit(
             mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

         elif self.distributed_backend == 'ddp_spawn':
+            self.__set_random_port()
             model.share_memory()

             # spin up peers
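The ddp_spawn fix works because the port has to be chosen in the parent before the workers start: children inherit MASTER_PORT from the parent's environment, so all ranks rendezvous on the same port, whereas letting each child draw its own pid-seeded port would make them disagree. A small sketch of that inheritance, with an illustrative worker function:

```python
import multiprocessing as mp
import os


def worker(rank: int) -> None:
    # each child inherits the parent's environment, so every rank
    # sees the same rendezvous port
    print(f'rank {rank} -> MASTER_PORT={os.environ["MASTER_PORT"]}')


if __name__ == '__main__':
    os.environ.setdefault('MASTER_PORT', '12910')  # must be set before spawning
    mp.set_start_method('spawn')
    procs = [mp.Process(target=worker, args=(r,)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```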