
Commit 60c1c8f

awaelchli, carmocca and ananthsub authored
Auto-set DataLoader.worker_init_fn with seed_everything (#6960)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: ananthsub <[email protected]>
1 parent d1529c2 commit 60c1c8f

File tree: 6 files changed (+166 -12 lines)

docs/source/common/trainer.rst

Lines changed: 5 additions & 1 deletion
@@ -184,12 +184,16 @@ Example::

     from pytorch_lightning import Trainer, seed_everything

-    seed_everything(42)
+    seed_everything(42, workers=True)
     # sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
     model = Model()
     trainer = Trainer(deterministic=True)


+By setting ``workers=True`` in :func:`~pytorch_lightning.utilities.seed.seed_everything`, Lightning derives
+unique seeds across all dataloader workers and processes for :mod:`torch`, :mod:`numpy` and stdlib
+:mod:`random` number generators. When turned on, it ensures that e.g. data augmentations are not repeated across workers.
+
 -------

 Trainer flags
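
With the documented flag, a user only needs to pass ``workers=True``; the Trainer then attaches the seeding hook to every ``DataLoader`` it manages. A minimal end-to-end sketch of that usage (the ``TensorDataset`` and the commented-out ``fit`` call are illustrative placeholders, not part of this commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import Trainer, seed_everything

    # derive distinct, reproducible seeds for every dataloader worker and every process
    seed_everything(42, workers=True)

    # any map-style dataset works; a tiny tensor dataset keeps the sketch self-contained
    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    train_loader = DataLoader(dataset, batch_size=8, num_workers=4)

    # because PL_SEED_WORKERS=1 and no worker_init_fn was supplied, the Trainer
    # fills it in automatically when the loader is (re)set for fitting
    trainer = Trainer(max_epochs=1, deterministic=True)
    # trainer.fit(model, train_dataloader=train_loader)  # `model` is any LightningModule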

pytorch_lightning/trainer/data_loading.py

Lines changed: 12 additions & 0 deletions
@@ -16,6 +16,7 @@
 import os
 from abc import ABC
 from copy import deepcopy
+from functools import partial
 from typing import Iterable, List, Optional, Tuple, Union

 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
@@ -31,6 +32,7 @@
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.seed import pl_worker_init_function


 class TrainerDataLoadingMixin(ABC):
@@ -101,6 +103,10 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
                 f' in the `DataLoader` init to improve performance.'
             )

+    def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
+        if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
+            dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
+
     def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

         # don't do anything if it's not a dataloader
@@ -234,6 +240,9 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
         # check the workers recursively
         apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')

+        # add worker_init_fn for correct seeding in worker processes
+        apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)
+
         # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
         self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)
@@ -332,6 +341,9 @@ def _reset_eval_dataloader(
         # add samplers
         dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None]

+        # add worker_init_fn for correct seeding in worker processes
+        apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn)
+
         loader_num_batches = []

         # determine number of batches
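
For a loader that the Trainer never sees (for example in a standalone evaluation script), the same behaviour can be replicated by hand. A rough sketch of what ``auto_add_worker_init_fn`` does, using only pieces introduced by this commit; the ``TensorDataset`` and the hard-coded ``rank=0`` are assumptions for a single-process run:

    import os
    from functools import partial

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import seed_everything
    from pytorch_lightning.utilities.seed import pl_worker_init_function

    seed_everything(7, workers=True)  # also exports PL_SEED_WORKERS=1

    dataset = TensorDataset(torch.randn(16, 3))
    loader = DataLoader(dataset, batch_size=2, num_workers=2)

    # mirror Trainer.auto_add_worker_init_fn: only attach the hook when the env flag
    # is set and the user has not already provided a worker_init_fn
    if int(os.environ.get("PL_SEED_WORKERS", 0)) and loader.worker_init_fn is None:
        loader.worker_init_fn = partial(pl_worker_init_function, rank=0)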

pytorch_lightning/utilities/seed.py

Lines changed: 39 additions & 4 deletions
@@ -21,22 +21,29 @@
 import numpy as np
 import torch

-from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
+from pytorch_lightning.utilities.distributed import rank_zero_only

 log = logging.getLogger(__name__)


-def seed_everything(seed: Optional[int] = None) -> int:
+def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
     """
     Function that sets seed for pseudo-random number generators in:
     pytorch, numpy, python.random
-    In addition, sets the env variable `PL_GLOBAL_SEED` which will be passed to
-    spawned subprocesses (e.g. ddp_spawn backend).
+    In addition, sets the following environment variables:
+
+    - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
+    - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.

     Args:
         seed: the integer value seed for global random state in Lightning.
             If `None`, will read seed from `PL_GLOBAL_SEED` env variable
             or select it randomly.
+        workers: if set to ``True``, will properly configure all dataloaders passed to the
+            Trainer with a ``worker_init_fn``. If the user already provides such a function
+            for their dataloaders, setting this argument will have no influence. See also:
+            :func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`.
     """
     max_seed_value = np.iinfo(np.uint32).max
     min_seed_value = np.iinfo(np.uint32).min
@@ -61,8 +68,36 @@ def seed_everything(seed: Optional[int] = None) -> int:
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
+
     return seed


 def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
     return random.randint(min_seed_value, max_seed_value)
+
+
+def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
+    """
+    The worker_init_fn that Lightning automatically adds to your dataloader if you previously
+    set the seed with ``seed_everything(seed, workers=True)``.
+    See also the PyTorch documentation on
+    `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
+    """
+    # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
+    global_rank = rank if rank is not None else rank_zero_only.rank
+    process_seed = torch.initial_seed()
+    # back out the base seed so we can use all the bits
+    base_seed = process_seed - worker_id
+    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
+    # use 128 bits (4 x 32-bit words)
+    np.random.seed(ss.generate_state(4))
+    # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
+    torch_ss, stdlib_ss = ss.spawn(2)
+    # PyTorch 1.7 and above takes a 64-bit seed
+    dtype = np.uint64 if _TORCH_GREATER_EQUAL_1_7 else np.uint32
+    torch.manual_seed(torch_ss.generate_state(1, dtype=dtype)[0])
+    # use 128 bits expressed as an integer
+    stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
+    random.seed(stdlib_seed)
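
The core of ``pl_worker_init_function`` is that mixing the per-process base seed, the worker id and the global rank into a ``numpy.random.SeedSequence`` yields reproducible but statistically independent streams per worker and per rank. A standalone sketch of just that derivation (not Lightning code, merely the recipe restated):

    import numpy as np

    def derive_numpy_state(base_seed: int, worker_id: int, global_rank: int) -> np.ndarray:
        # entropy pool = (base seed, worker id, global rank), as in pl_worker_init_function
        ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
        # 128 bits of state, expressed as 4 x 32-bit words
        return ss.generate_state(4)

    # distinct workers and distinct ranks get unrelated states,
    # while the same triple always reproduces the same state
    assert not np.array_equal(derive_numpy_state(42, 0, 0), derive_numpy_state(42, 1, 0))
    assert not np.array_equal(derive_numpy_state(42, 0, 0), derive_numpy_state(42, 0, 1))
    assert np.array_equal(derive_numpy_state(42, 0, 0), derive_numpy_state(42, 0, 0))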

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # the default package dependencies

-numpy>=1.16.6
+numpy>=1.17.2
 torch>=1.4
 future>=0.17.1  # required for builtins in setup.py
 # pyyaml>=3.13

tests/trainer/test_data_loading.py

Lines changed: 3 additions & 4 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import pytest
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import BatchSampler, SequentialSampler
@@ -72,7 +71,7 @@ def test_dataloader(self):
         return [self.create_dataset()] * self._numbers_test_dataloaders


-def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode):
+def check_replace_distributed_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode):
     num_processes = 2
     limit_test_batches = 2
     trainer_args = {
@@ -100,8 +99,8 @@ def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator,

 @RunIf(min_gpus=2, special=True)
 @pytest.mark.parametrize("mode", [1, 2])
-def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
-    check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)
+def test_replace_distributed_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
+    check_replace_distributed_sampler(tmpdir, True, "ddp", 2, 2, mode)


 @pytest.mark.parametrize("num_workers", [0, 1])

tests/trainer/test_dataloaders.py

Lines changed: 106 additions & 2 deletions
@@ -13,12 +13,13 @@
 # limitations under the License.
 import os
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import Mock, patch

+import numpy
 import pytest
 import torch
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import IterableDataset, Subset
+from torch.utils.data.dataset import Dataset, IterableDataset, Subset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import SequentialSampler
@@ -635,6 +636,109 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
     trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)


+class NumpyRandomDataset(Dataset):
+    # this dataset uses numpy instead of torch to produce random numbers
+    size = 16
+
+    def __getitem__(self, index):
+        return numpy.random.randint(0, 100, 3)
+
+    def __len__(self):
+        return self.size
+
+
+def _user_worker_init_fn(_):
+    pass
+
+
+def test_missing_worker_init_fn():
+    """ Test that naive worker seed initialization leads to undesired random state in subprocesses. """
+    dataset = NumpyRandomDataset()
+
+    seed_everything(0)
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
+    batches0 = torch.cat([batch for batch in dataloader])
+
+    seed_everything(0)
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
+    batches1 = torch.cat([batch for batch in dataloader])
+
+    is_duplicated = len(torch.unique(batches1, dim=0)) < len(dataset)
+    is_deterministic = torch.eq(batches0, batches1).all()
+
+    # depending on the OS, we either have
+    # 1) the same seed in all worker processes, producing duplicate samples / augmentations, or
+    # 2) different seeds in each worker process, but they are not derived from the seed of the main process
+    assert not is_deterministic or is_duplicated
+
+
+def test_auto_add_worker_init_fn():
+    """ Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. """
+    dataset = Mock()
+    dataloader = DataLoader(dataset)
+    trainer = Trainer()
+
+    # without pl.seed_everything()
+    trainer.auto_add_worker_init_fn(dataloader)
+    assert dataloader.worker_init_fn is None
+
+    # with forcefully avoiding it
+    seed_everything(0, workers=False)
+    trainer.auto_add_worker_init_fn(dataloader)
+    assert dataloader.worker_init_fn is None
+
+    # when user already has a worker_init_fn
+    user_function = _user_worker_init_fn
+    dataloader.worker_init_fn = user_function
+    trainer.auto_add_worker_init_fn(dataloader)
+    assert dataloader.worker_init_fn is user_function
+    dataloader.worker_init_fn = None
+
+    # main use case
+    seed_everything(0, workers=True)
+    trainer.auto_add_worker_init_fn(dataloader)
+    assert dataloader.worker_init_fn is not None
+
+
+class MultiProcessModel(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.batches_seen = []
+
+    def training_step(self, batch, batch_idx):
+        self.batches_seen.append(batch)
+
+    def training_epoch_end(self, outputs):
+        world_size = 2
+        num_samples = NumpyRandomDataset.size
+        all_batches = torch.cat(self.batches_seen)
+        all_batches = self.all_gather(all_batches)
+        assert all_batches.shape[0] == world_size
+        all_batches = all_batches.view(-1, 3)
+        assert len(torch.unique(all_batches, dim=0)) == num_samples
+
+
+@RunIf(min_gpus=2)
+def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch):
+    """ Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training. """
+    dataset = NumpyRandomDataset()
+    num_workers = 2
+    batch_size = 2
+
+    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
+    seed_everything(0, workers=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        gpus=2,
+        accelerator="ddp_spawn",
+    )
+    model = MultiProcessModel()
+    model.val_dataloader = None
+    trainer.fit(model, train_dataloader=dataloader)
+
+
 def test_warning_with_iterable_dataset_and_len(tmpdir):
     """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
     model = BoringModel()
