Skip to content

Commit 22d8266

Browse files
Seppo Enarvicarmocca
andauthored
Seed all workers when using DDP (#7942)
* Seed all workers when using DDP * Fix to dataloader seeding * Make argument name explicit Co-authored-by: Carlos Mocholí <[email protected]> * Use f-strings when logging * Removed a redundant log message Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 436fc53 commit 22d8266

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
214214

215215
### Fixed
216216

217+
- Fixed setting `worker_init_fn` to seed dataloaders correctly when using DDP ([#7942](https://github.com/PyTorchLightning/pytorch-lightning/pull/7942))
218+
219+
217220
- Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))
218221

219222

pytorch_lightning/utilities/seed.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ def reset_seed() -> None:
8484
If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing.
8585
"""
8686
seed = os.environ.get("PL_GLOBAL_SEED", None)
87+
workers = os.environ.get("PL_SEED_WORKERS", False)
8788
if seed is not None:
88-
seed_everything(int(seed))
89+
seed_everything(int(seed), workers=bool(workers))
8990

9091

9192
def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover
@@ -100,6 +101,9 @@ def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # p
100101
process_seed = torch.initial_seed()
101102
# back out the base seed so we can use all the bits
102103
base_seed = process_seed - worker_id
104+
log.debug(
105+
f'Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}'
106+
)
103107
ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
104108
# use 128 bits (4 x 32-bit words)
105109
np.random.seed(ss.generate_state(4))

0 commit comments

Comments
 (0)