From 18e61ee4102c32995079bd639d30557956f42bc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 11 Apr 2021 14:52:10 +0200 Subject: [PATCH 01/38] example --- pl_examples/bug_report_model.py | 122 +++++--------------------------- 1 file changed, 16 insertions(+), 106 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index b8e45512bb7c7..edac5a45a5944 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -1,76 +1,25 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -------------------------------------------- -# -------------------------------------------- -# -------------------------------------------- -# USE THIS MODEL TO REPRODUCE A BUG YOU REPORT -# -------------------------------------------- -# -------------------------------------------- -# -------------------------------------------- import os import torch from torch.utils.data import Dataset -from pl_examples import cli_lightning_logo -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import LightningModule, Trainer, seed_everything +import numpy as np +from torch.utils.data import Dataset -class RandomDataset(Dataset): - """ - >>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS - <...bug_report_model.RandomDataset object at ...> - """ - - def __init__(self, size, length): - self.len = length - self.data = torch.randn(length, size) +class RandomDataset(Dataset): def __getitem__(self, index): - return self.data[index] + return np.random.randint(0, 10, 3) def __len__(self): - return self.len + return 16 class BoringModel(LightningModule): - """ - >>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - BoringModel( - (layer): Linear(...) 
- ) - """ def __init__(self): - """ - Testing PL Module - - Use as follows: - - subclass - - modify the behavior for what you want - - class TestModel(BaseTestModel): - def training_step(...): - # do your own thing - - or: - - model = BaseTestModel() - model.training_epoch_end = None - - """ super().__init__() self.layer = torch.nn.Linear(32, 2) @@ -87,31 +36,8 @@ def step(self, x): return out def training_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - return {"loss": loss} - - def training_step_end(self, training_step_outputs): - return training_step_outputs - - def training_epoch_end(self, outputs) -> None: - torch.stack([x["loss"] for x in outputs]).mean() - - def validation_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - return {"x": loss} - - def validation_epoch_end(self, outputs) -> None: - torch.stack([x['x'] for x in outputs]).mean() - - def test_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - return {"y": loss} - - def test_epoch_end(self, outputs) -> None: - torch.stack([x["y"] for x in outputs]).mean() + print(batch) + return None def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -119,39 +45,23 @@ def configure_optimizers(self): return [optimizer], [lr_scheduler] -# NOTE: If you are using a cmd line to run your script, -# provide the cmd line as below. -# opt = "--max_epochs 1 --limit_train_batches 1".split(" ") -# parser = ArgumentParser() -# args = parser.parse_args(opt) - - -class TestModel(BoringModel): - - def on_train_epoch_start(self) -> None: - print('override any method to prove your bug') - - -def test_run(): +def run(): # fake data - train_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) - val_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) - test_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) + train_data = torch.utils.data.DataLoader(RandomDataset(), batch_size=2, num_workers=4) # model - model = TestModel() + model = BoringModel() trainer = Trainer( default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, + limit_train_batches=4, max_epochs=1, weights_summary=None, + progress_bar_refresh_rate=0, ) - trainer.fit(model, train_data, val_data) - trainer.test(test_dataloaders=test_data) + trainer.fit(model, train_data) if __name__ == '__main__': - cli_lightning_logo() - test_run() + seed_everything(1) + run() From 588f5813ef038487bcfa9708f51bdd604ddb211d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 11 Apr 2021 16:37:12 +0200 Subject: [PATCH 02/38] auto add worker fn --- pl_examples/bug_report_model.py | 14 +++++++-- pytorch_lightning/trainer/data_loading.py | 14 +++++++++ pytorch_lightning/utilities/seed.py | 15 ++++++++++ tests/trainer/test_data_loading.py | 7 ++--- tests/trainer/test_dataloaders.py | 35 +++++++++++++++++++++-- 5 files changed, 77 insertions(+), 8 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index edac5a45a5944..993b4a1955b91 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -1,5 +1,7 @@ import os +import random +import numpy import torch from torch.utils.data import Dataset @@ -48,15 +50,23 @@ def configure_optimizers(self): def run(): # fake data - train_data = torch.utils.data.DataLoader(RandomDataset(), batch_size=2, num_workers=4) + train_data = torch.utils.data.DataLoader(RandomDataset(), 
batch_size=2, num_workers=2) + # + # def worker_fn(worker_id): + # worker_seed = torch.initial_seed() % (2 ** 32) + # numpy.random.seed(worker_seed) + # random.seed(worker_seed) + + # train_data.worker_init_fn = worker_fn # model model = BoringModel() trainer = Trainer( default_root_dir=os.getcwd(), limit_train_batches=4, - max_epochs=1, + max_epochs=2, weights_summary=None, + # reload_dataloaders_every_epoch=True, progress_bar_refresh_rate=0, ) trainer.fit(model, train_data) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 59944dada330c..5e570caf580a3 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect import multiprocessing +import os from abc import ABC from copy import deepcopy from typing import Iterable, List, Tuple, Union @@ -30,6 +31,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.seed import pl_worker_init_function class TrainerDataLoadingMixin(ABC): @@ -100,6 +102,12 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: f' in the `DataLoader` init to improve performance.' ) + @staticmethod + def auto_add_worker_init_fn(dataloader: DataLoader) -> None: + if dataloader.worker_init_fn is not None or "PL_GLOBAL_SEED" not in os.environ: + return + dataloader.worker_init_fn = pl_worker_init_function + def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: # don't do anything if it's not a dataloader @@ -231,6 +239,9 @@ def reset_train_dataloader(self, model: LightningModule) -> None: # check the workers recursively apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader') + # add worker_init_fn for correct seeding in worker processes + apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn) + # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode) @@ -329,6 +340,9 @@ def _reset_eval_dataloader( # add samplers dataloaders = [self.auto_add_sampler(dl, shuffle=False) for dl in dataloaders if dl is not None] + # add worker_init_fn for correct seeding in worker processes + apply_to_collection(dataloaders, dtype=DataLoader, function=self.auto_add_worker_init_fn) + loader_num_batches = [] # determine number of batches diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index c3da02b7d2cdb..145d6a9bd9009 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -18,6 +18,7 @@ import random from typing import Optional +import numpy import numpy as np import torch @@ -66,3 +67,17 @@ def seed_everything(seed: Optional[int] = None) -> int: def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: return random.randint(min_seed_value, max_seed_value) + + +def pl_worker_init_function(worker_id: int) -> None: + """ + The worker_init_fn that Lightning automatically adds to your dataloader if you previously set + set the seed with :func:`~pytorch_lightning.utilities.seed.seed_everything`. 
+ + See Also + `Randomness in Dataloades `_ + """ + # + worker_seed = torch.initial_seed() % (2 ** 32) + numpy.random.seed(worker_seed) + random.seed(worker_seed) \ No newline at end of file diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index 382311c107958..831fc474336b6 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pytest from torch.utils.data import DataLoader from torch.utils.data.sampler import BatchSampler, SequentialSampler @@ -72,7 +71,7 @@ def test_dataloader(self): return [self.create_dataset()] * self._numbers_test_dataloaders -def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode): +def check_replace_distributed_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode): num_processes = 2 limit_test_batches = 2 trainer_args = { @@ -100,8 +99,8 @@ def check_replace_distrubuted_sampler(tmpdir, save_preds_on_dl_idx, accelerator, @RunIf(min_gpus=2, special=True) @pytest.mark.parametrize("mode", [1, 2]) -def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode): - check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode) +def test_replace_distributed_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode): + check_replace_distributed_sampler(tmpdir, True, "ddp", 2, 2, mode) @pytest.mark.parametrize("num_workers", [0, 1]) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 7f9cf6210ce7c..8bf9970d311cb 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -15,18 +15,20 @@ from unittest import mock from unittest.mock import patch +import numpy import pytest import torch from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import IterableDataset, Subset +from torch.utils.data.dataset import IterableDataset, Subset, Dataset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import SequentialSampler import tests.helpers.pipelines as tpipes -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Callback, Trainer, seed_everything from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.seed import pl_worker_init_function from tests.base import EvalModelTemplate from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -634,6 +636,35 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) +class NumpyRandomDataset(Dataset): + def __getitem__(self, index): + return numpy.random.randint(0, 100, 3) + + def __len__(self): + return 16 + + +def test_auto_add_worker_init_fn(tmpdir): + """ Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. 
""" + dataset = NumpyRandomDataset() + num_samples = len(dataset) + num_workers = 2 + batch_size = 2 + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + trainer = Trainer(default_root_dir=tmpdir) + + # without pl.seed_everything() + trainer.auto_add_worker_init_fn(dataloader) + assert dataloader.worker_init_fn is None + + # with pl.seed_everything() + seed_everything(0) + trainer.auto_add_worker_init_fn(dataloader) + assert dataloader.worker_init_fn is pl_worker_init_function + unique_batches = set(tuple(batch.view(-1).tolist()) for batch in dataloader) + assert len(unique_batches) > (num_samples // (batch_size * num_workers)) + + def test_warning_with_iterable_dataset_and_len(tmpdir): """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """ model = BoringModel() From bf927930c1efa5c669a267df1eeedaebf4cf64f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 11 Apr 2021 16:37:47 +0200 Subject: [PATCH 03/38] Revert "example" This reverts commit 18e61ee4102c32995079bd639d30557956f42bc8. --- pl_examples/bug_report_model.py | 130 ++++++++++++++++++++++++++------ 1 file changed, 107 insertions(+), 23 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 993b4a1955b91..9a5a09468a135 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -1,3 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -------------------------------------------- +# -------------------------------------------- +# -------------------------------------------- +# USE THIS MODEL TO REPRODUCE A BUG YOU REPORT +# -------------------------------------------- +# -------------------------------------------- +# -------------------------------------------- import os import random @@ -5,23 +26,53 @@ import torch from torch.utils.data import Dataset -from pytorch_lightning import LightningModule, Trainer, seed_everything - -import numpy as np -from torch.utils.data import Dataset +from pl_examples import cli_lightning_logo +from pytorch_lightning import LightningModule, Trainer class RandomDataset(Dataset): + """ + >>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS + <...bug_report_model.RandomDataset object at ...> + """ + + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + def __getitem__(self, index): - return np.random.randint(0, 10, 3) + return self.data[index] def __len__(self): - return 16 + return self.len class BoringModel(LightningModule): + """ + >>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + BoringModel( + (layer): Linear(...) 
+ ) + """ def __init__(self): + """ + Testing PL Module + + Use as follows: + - subclass + - modify the behavior for what you want + + class TestModel(BaseTestModel): + def training_step(...): + # do your own thing + + or: + + model = BaseTestModel() + model.training_epoch_end = None + + """ super().__init__() self.layer = torch.nn.Linear(32, 2) @@ -38,8 +89,31 @@ def step(self, x): return out def training_step(self, batch, batch_idx): - print(batch) - return None + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def training_step_end(self, training_step_outputs): + return training_step_outputs + + def training_epoch_end(self, outputs) -> None: + torch.stack([x["loss"] for x in outputs]).mean() + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"x": loss} + + def validation_epoch_end(self, outputs) -> None: + torch.stack([x['x'] for x in outputs]).mean() + + def test_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"y": loss} + + def test_epoch_end(self, outputs) -> None: + torch.stack([x["y"] for x in outputs]).mean() def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -47,31 +121,41 @@ def configure_optimizers(self): return [optimizer], [lr_scheduler] -def run(): +# NOTE: If you are using a cmd line to run your script, +# provide the cmd line as below. +# opt = "--max_epochs 1 --limit_train_batches 1".split(" ") +# parser = ArgumentParser() +# args = parser.parse_args(opt) + + +class TestModel(BoringModel): + + def on_train_epoch_start(self) -> None: + print('override any method to prove your bug') + + +def test_run(): # fake data - train_data = torch.utils.data.DataLoader(RandomDataset(), batch_size=2, num_workers=2) - # - # def worker_fn(worker_id): - # worker_seed = torch.initial_seed() % (2 ** 32) - # numpy.random.seed(worker_seed) - # random.seed(worker_seed) + train_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) + val_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) + test_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) # train_data.worker_init_fn = worker_fn # model - model = BoringModel() + model = TestModel() trainer = Trainer( default_root_dir=os.getcwd(), - limit_train_batches=4, - max_epochs=2, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, weights_summary=None, - # reload_dataloaders_every_epoch=True, - progress_bar_refresh_rate=0, ) - trainer.fit(model, train_data) + trainer.fit(model, train_data, val_data) + trainer.test(test_dataloaders=test_data) if __name__ == '__main__': - seed_everything(1) - run() + cli_lightning_logo() + test_run() From 49f320096a85d0073b54d4e71c93a0e3c10bff75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 11 Apr 2021 16:49:17 +0200 Subject: [PATCH 04/38] revert --- pl_examples/bug_report_model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 9a5a09468a135..b8e45512bb7c7 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -20,9 +20,7 @@ # -------------------------------------------- # -------------------------------------------- import os -import random -import numpy import torch from torch.utils.data import Dataset @@ -141,8 +139,6 @@ def test_run(): val_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) test_data = 
torch.utils.data.DataLoader(RandomDataset(32, 64)) - # train_data.worker_init_fn = worker_fn - # model model = TestModel() trainer = Trainer( From 74893bd348b0a4285d22eab9c4859ac856f6fdf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 11 Apr 2021 16:51:08 +0200 Subject: [PATCH 05/38] typo --- pytorch_lightning/utilities/seed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 145d6a9bd9009..f89bf5524406c 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -75,7 +75,7 @@ def pl_worker_init_function(worker_id: int) -> None: set the seed with :func:`~pytorch_lightning.utilities.seed.seed_everything`. See Also - `Randomness in Dataloades `_ + `Randomness in DataLoaders `_ """ # worker_seed = torch.initial_seed() % (2 ** 32) From 49c1a64fa41499350f05b95579babc8cec9677f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 11 Apr 2021 16:56:58 +0200 Subject: [PATCH 06/38] flake --- pytorch_lightning/utilities/seed.py | 4 ++-- tests/trainer/test_dataloaders.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index f89bf5524406c..2aa83d79732a3 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -78,6 +78,6 @@ def pl_worker_init_function(worker_id: int) -> None: `Randomness in DataLoaders `_ """ # - worker_seed = torch.initial_seed() % (2 ** 32) + worker_seed = torch.initial_seed() % (2**32) numpy.random.seed(worker_seed) - random.seed(worker_seed) \ No newline at end of file + random.seed(worker_seed) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 8bf9970d311cb..0d7f185419259 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -637,6 +637,7 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): class NumpyRandomDataset(Dataset): + def __getitem__(self, index): return numpy.random.randint(0, 100, 3) From 50d35adf63177ad45144fd219ee8a7d61678a9f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 11 Apr 2021 18:36:29 +0200 Subject: [PATCH 07/38] add worker_id --- pytorch_lightning/utilities/seed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 2aa83d79732a3..b55909b7d4e41 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -78,6 +78,6 @@ def pl_worker_init_function(worker_id: int) -> None: `Randomness in DataLoaders `_ """ # - worker_seed = torch.initial_seed() % (2**32) + worker_seed = (torch.initial_seed() + worker_id) % (2**32) numpy.random.seed(worker_seed) random.seed(worker_seed) From 873584e74b0a73eed5977192c4ff96faf1bb6df8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 11 Apr 2021 18:36:47 +0200 Subject: [PATCH 08/38] typo --- pytorch_lightning/utilities/seed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index b55909b7d4e41..3db59677b1e87 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -77,7 +77,6 @@ def pl_worker_init_function(worker_id: int) -> None: See Also `Randomness in DataLoaders `_ """ - # worker_seed = (torch.initial_seed() + worker_id) % 
(2**32) numpy.random.seed(worker_seed) random.seed(worker_seed) From b71d16cf93d8046e41c6f852034133d744c61b80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 11 Apr 2021 19:14:36 +0200 Subject: [PATCH 09/38] workers argument in seed_everyting --- pytorch_lightning/trainer/data_loading.py | 5 ++--- pytorch_lightning/utilities/seed.py | 21 +++++++++++++++------ tests/trainer/test_dataloaders.py | 20 ++++++++++++++++++-- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 5e570caf580a3..6353316d1e567 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -104,9 +104,8 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: @staticmethod def auto_add_worker_init_fn(dataloader: DataLoader) -> None: - if dataloader.worker_init_fn is not None or "PL_GLOBAL_SEED" not in os.environ: - return - dataloader.worker_init_fn = pl_worker_init_function + if dataloader.worker_init_fn is None and os.environ.get("PL_SEED_WORKERS", False): + dataloader.worker_init_fn = pl_worker_init_function def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 3db59677b1e87..0f6ec0413daf6 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -27,17 +27,23 @@ log = logging.getLogger(__name__) -def seed_everything(seed: Optional[int] = None) -> int: +def seed_everything(seed: Optional[int] = None, workers: bool = True) -> int: """ Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random - In addition, sets the env variable `PL_GLOBAL_SEED` which will be passed to - spawned subprocesses (e.g. ddp_spawn backend). + In addition, sets the following environment variables: + + - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend). + - `PL_SEED_WORKERS`: (optional) is set to 1 if ```workers=True``. Args: seed: the integer value seed for global random state in Lightning. If `None`, will read seed from `PL_GLOBAL_SEED` env variable or select it randomly. + workers: if set to ``True``, will properly configure all dataloaders passed to the + Trainer with a ``worker_init_fn``. If the user already provides such a function + for their dataloaders, setting this argument will have no influence. See also: + :func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`. """ max_seed_value = np.iinfo(np.uint32).max min_seed_value = np.iinfo(np.uint32).min @@ -62,6 +68,10 @@ def seed_everything(seed: Optional[int] = None) -> int: np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + + if workers: + os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" + return seed @@ -73,9 +83,8 @@ def pl_worker_init_function(worker_id: int) -> None: """ The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed with :func:`~pytorch_lightning.utilities.seed.seed_everything`. - - See Also - `Randomness in DataLoaders `_ + See also the PyTorch documentation on + `randomness in DataLoaders `_. 
""" worker_seed = (torch.initial_seed() + worker_id) % (2**32) numpy.random.seed(worker_seed) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 0d7f185419259..f3966c4d20431 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -645,6 +645,10 @@ def __len__(self): return 16 +def _user_worker_init_fn(_): + pass + + def test_auto_add_worker_init_fn(tmpdir): """ Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. """ dataset = NumpyRandomDataset() @@ -658,8 +662,20 @@ def test_auto_add_worker_init_fn(tmpdir): trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is None - # with pl.seed_everything() - seed_everything(0) + # with forcefully avoiding it + seed_everything(0, workers=False) + trainer.auto_add_worker_init_fn(dataloader) + assert dataloader.worker_init_fn is None + + # when user already has a worker_init_fn + user_function = _user_worker_init_fn + dataloader.worker_init_fn = user_function + trainer.auto_add_worker_init_fn(dataloader) + assert dataloader.worker_init_fn is user_function + dataloader.worker_init_fn = None + + # main use case + seed_everything(0, workers=True) trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is pl_worker_init_function unique_batches = set(tuple(batch.view(-1).tolist()) for batch in dataloader) From 23f28f35c199aeeb0a2f28619e8bb5e74674a4c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 12 Apr 2021 00:49:26 +0200 Subject: [PATCH 10/38] add global rank for worker_init_fn --- pytorch_lightning/utilities/seed.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 0f6ec0413daf6..1e168a13f2ca0 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -21,6 +21,7 @@ import numpy import numpy as np import torch +from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities import rank_zero_warn @@ -86,6 +87,7 @@ def pl_worker_init_function(worker_id: int) -> None: See also the PyTorch documentation on `randomness in DataLoaders `_. 
""" - worker_seed = (torch.initial_seed() + worker_id) % (2**32) + global_rank = rank_zero_only.rank + worker_seed = (torch.initial_seed() + worker_id + global_rank) % (2**32) numpy.random.seed(worker_seed) random.seed(worker_seed) From 59416c9c34867576c547cf7e08749feccfb3123a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 12 Apr 2021 00:55:12 +0200 Subject: [PATCH 11/38] include torch.manual_seed --- pytorch_lightning/utilities/seed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 1e168a13f2ca0..9644fae6975ba 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -89,5 +89,6 @@ def pl_worker_init_function(worker_id: int) -> None: """ global_rank = rank_zero_only.rank worker_seed = (torch.initial_seed() + worker_id + global_rank) % (2**32) + torch.manual_seed(worker_seed) numpy.random.seed(worker_seed) random.seed(worker_seed) From 3b247b81381d454e1cfc427babb6f5087fa8ec02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 12 Apr 2021 13:11:05 +0200 Subject: [PATCH 12/38] fix env var access --- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/utilities/seed.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 6353316d1e567..1243cc0c1bdbc 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -104,7 +104,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: @staticmethod def auto_add_worker_init_fn(dataloader: DataLoader) -> None: - if dataloader.worker_init_fn is None and os.environ.get("PL_SEED_WORKERS", False): + if dataloader.worker_init_fn is None and int(os.environ.get("PL_SEED_WORKERS", "0")): dataloader.worker_init_fn = pl_worker_init_function def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 9644fae6975ba..77652f104a408 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -70,8 +70,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = True) -> int: torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - if workers: - os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" + os.environ["PL_SEED_WORKERS"] = f"{int(workers)}" return seed From 5ed6791f8493aa22e948900980acd856d6022a96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 12 Apr 2021 15:25:05 +0200 Subject: [PATCH 13/38] ddp test --- tests/trainer/test_dataloaders.py | 32 +++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f3966c4d20431..1e51a8f0adb3e 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -26,6 +26,7 @@ import tests.helpers.pipelines as tpipes from pytorch_lightning import Callback, Trainer, seed_everything from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import pl_worker_init_function @@ -649,39 +650,58 @@ def _user_worker_init_fn(_): pass 
-def test_auto_add_worker_init_fn(tmpdir): +def test_auto_add_worker_init_fn(): """ Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. """ dataset = NumpyRandomDataset() num_samples = len(dataset) num_workers = 2 batch_size = 2 dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) - trainer = Trainer(default_root_dir=tmpdir) # without pl.seed_everything() - trainer.auto_add_worker_init_fn(dataloader) + Trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is None # with forcefully avoiding it seed_everything(0, workers=False) - trainer.auto_add_worker_init_fn(dataloader) + Trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is None # when user already has a worker_init_fn user_function = _user_worker_init_fn dataloader.worker_init_fn = user_function - trainer.auto_add_worker_init_fn(dataloader) + Trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is user_function dataloader.worker_init_fn = None # main use case seed_everything(0, workers=True) - trainer.auto_add_worker_init_fn(dataloader) + Trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is pl_worker_init_function unique_batches = set(tuple(batch.view(-1).tolist()) for batch in dataloader) assert len(unique_batches) > (num_samples // (batch_size * num_workers)) +def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): + """ Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training. """ + dataset = NumpyRandomDataset() + num_workers = 2 + batch_size = 2 + world_size = 2 + num_samples = len(dataset) + + # simulate distributed processes by setting rank and collecting the batches + unique_batches_world = set() + for current_rank in range(world_size): + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + seed_everything(0, workers=True) + monkeypatch.setattr(rank_zero_only, "rank", current_rank) + Trainer.auto_add_worker_init_fn(dataloader) + unique_batches_world |= set(tuple(batch.view(-1).tolist()) for batch in dataloader) + + assert len(unique_batches_world) > (num_samples // (batch_size * num_workers * world_size)) + + def test_warning_with_iterable_dataset_and_len(tmpdir): """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """ model = BoringModel() From 08e552449f455790ca1c67b5afefdd4fa9c6421c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Apr 2021 13:59:29 +0200 Subject: [PATCH 14/38] suggestion by r.kern avoid collision --- pytorch_lightning/utilities/seed.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 77652f104a408..f5d7c2846b49b 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -86,8 +86,12 @@ def pl_worker_init_function(worker_id: int) -> None: See also the PyTorch documentation on `randomness in DataLoaders `_. 
""" + # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 global_rank = rank_zero_only.rank - worker_seed = (torch.initial_seed() + worker_id + global_rank) % (2**32) - torch.manual_seed(worker_seed) - numpy.random.seed(worker_seed) - random.seed(worker_seed) + process_seed = torch.initial_seed() + # back out the base seed so we can use all the bits + base_seed = process_seed - worker_id + ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) + # use 128 bits (4 x 32-bit words) + np.random.seed(ss.generate_state(4)) + torch.manual_seed((process_seed * global_rank) % (2**64)) From b7828ecf0f05b4b127e51891d9d618a9e70324b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 14 Apr 2021 21:58:39 +0200 Subject: [PATCH 15/38] incorporate seedsequence for torch and stdlib seed --- pytorch_lightning/utilities/seed.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index f5d7c2846b49b..c9ab2edfa7597 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -94,4 +94,10 @@ def pl_worker_init_function(worker_id: int) -> None: ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) # use 128 bits (4 x 32-bit words) np.random.seed(ss.generate_state(4)) - torch.manual_seed((process_seed * global_rank) % (2**64)) + # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module + torch_ss, stdlib_ss = ss.spawn(2) + # PyTorch takes a 64-bit seed + torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0]) + # use 128 bits expressed as an integer + stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() + random.seed(stdlib_seed) From b3af3d16749de919466859d509296eeb59f3baa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Apr 2021 13:46:56 +0200 Subject: [PATCH 16/38] strict assert --- tests/trainer/test_dataloaders.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 1e51a8f0adb3e..26cb31c278aa4 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -678,8 +678,10 @@ def test_auto_add_worker_init_fn(): seed_everything(0, workers=True) Trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is pl_worker_init_function - unique_batches = set(tuple(batch.view(-1).tolist()) for batch in dataloader) - assert len(unique_batches) > (num_samples // (batch_size * num_workers)) + all_batches = torch.cat([batch for batch in dataloader]) + assert all_batches.shape[0] == num_samples + unique_samples = set([tuple(sample.tolist()) for sample in all_batches]) + assert len(unique_samples) == num_samples def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): From e4cda361c4e6fc27de6246c91202126c257b0f40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Apr 2021 14:27:13 +0200 Subject: [PATCH 17/38] strict assert --- tests/trainer/test_dataloaders.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 26cb31c278aa4..ca3011082fc3a 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -693,15 +693,19 @@ def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): num_samples = 
len(dataset) # simulate distributed processes by setting rank and collecting the batches - unique_batches_world = set() + all_batches = [] for current_rank in range(world_size): - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, worker_init_fn=pl_worker_init_function) seed_everything(0, workers=True) monkeypatch.setattr(rank_zero_only, "rank", current_rank) + assert rank_zero_only.rank == current_rank Trainer.auto_add_worker_init_fn(dataloader) - unique_batches_world |= set(tuple(batch.view(-1).tolist()) for batch in dataloader) + all_batches.extend([batch for batch in dataloader]) - assert len(unique_batches_world) > (num_samples // (batch_size * num_workers * world_size)) + all_batches = torch.cat(all_batches) + print(all_batches) + assert all_batches.shape[0] == num_samples * world_size + assert len(torch.unique(all_batches, dim=0)) == num_samples * world_size def test_warning_with_iterable_dataset_and_len(tmpdir): From 2f461e7b7af22909781c69c3bd597c4bf70a0fcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Apr 2021 14:31:18 +0200 Subject: [PATCH 18/38] remove print statement --- tests/trainer/test_dataloaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index ca3011082fc3a..9b2b4707a00c3 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -703,7 +703,6 @@ def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): all_batches.extend([batch for batch in dataloader]) all_batches = torch.cat(all_batches) - print(all_batches) assert all_batches.shape[0] == num_samples * world_size assert len(torch.unique(all_batches, dim=0)) == num_samples * world_size From 398ddef2b644f4537500d8aaad86e89e578df585 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Apr 2021 16:48:49 +0200 Subject: [PATCH 19/38] fix global rank issues --- pytorch_lightning/trainer/data_loading.py | 6 +- pytorch_lightning/utilities/seed.py | 4 +- tests/trainer/test_dataloaders.py | 93 +++++++++++++---------- 3 files changed, 58 insertions(+), 45 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 1243cc0c1bdbc..43f367ee2ed3b 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,6 +16,7 @@ import os from abc import ABC from copy import deepcopy +from functools import partial from typing import Iterable, List, Tuple, Union from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler @@ -102,10 +103,9 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: f' in the `DataLoader` init to improve performance.' 
) - @staticmethod - def auto_add_worker_init_fn(dataloader: DataLoader) -> None: + def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None: if dataloader.worker_init_fn is None and int(os.environ.get("PL_SEED_WORKERS", "0")): - dataloader.worker_init_fn = pl_worker_init_function + dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank) def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index c9ab2edfa7597..662eb4ece8951 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -79,7 +79,7 @@ def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> return random.randint(min_seed_value, max_seed_value) -def pl_worker_init_function(worker_id: int) -> None: +def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: """ The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed with :func:`~pytorch_lightning.utilities.seed.seed_everything`. @@ -87,7 +87,7 @@ def pl_worker_init_function(worker_id: int) -> None: `randomness in DataLoaders `_. """ # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 - global_rank = rank_zero_only.rank + global_rank = rank if rank is not None else rank_zero_only.rank process_seed = torch.initial_seed() # back out the base seed so we can use all the bits base_seed = process_seed - worker_id diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 9b2b4707a00c3..0310554e31f61 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from functools import partial from unittest import mock -from unittest.mock import patch +from unittest.mock import patch, Mock import numpy import pytest @@ -637,74 +638,86 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) -class NumpyRandomDataset(Dataset): - - def __getitem__(self, index): - return numpy.random.randint(0, 100, 3) - - def __len__(self): - return 16 - - def _user_worker_init_fn(_): pass def test_auto_add_worker_init_fn(): """ Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. 
""" - dataset = NumpyRandomDataset() - num_samples = len(dataset) - num_workers = 2 - batch_size = 2 - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + dataset = Mock() + dataloader = DataLoader(dataset) + trainer = Trainer() # without pl.seed_everything() - Trainer.auto_add_worker_init_fn(dataloader) + trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is None # with forcefully avoiding it seed_everything(0, workers=False) - Trainer.auto_add_worker_init_fn(dataloader) + trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is None # when user already has a worker_init_fn user_function = _user_worker_init_fn dataloader.worker_init_fn = user_function - Trainer.auto_add_worker_init_fn(dataloader) + trainer.auto_add_worker_init_fn(dataloader) assert dataloader.worker_init_fn is user_function dataloader.worker_init_fn = None # main use case seed_everything(0, workers=True) - Trainer.auto_add_worker_init_fn(dataloader) - assert dataloader.worker_init_fn is pl_worker_init_function - all_batches = torch.cat([batch for batch in dataloader]) - assert all_batches.shape[0] == num_samples - unique_samples = set([tuple(sample.tolist()) for sample in all_batches]) - assert len(unique_samples) == num_samples + trainer.auto_add_worker_init_fn(dataloader) + assert dataloader.worker_init_fn is not None +class NumpyRandomDataset(Dataset): + size = 16 + + def __getitem__(self, index): + return numpy.random.randint(0, 100, 3) + + def __len__(self): + return self.size + + +class MultiProcessModel(BoringModel): + + def __init__(self): + super().__init__() + self.batches_seen = [] + + def training_step(self, batch, batch_idx): + self.batches_seen.append(batch) + + def training_epoch_end(self, outputs): + world_size = 2 + num_samples = NumpyRandomDataset.size + all_batches = torch.cat(self.batches_seen) + all_batches = self.all_gather(all_batches) + assert all_batches.shape[0] == world_size + all_batches = all_batches.view(-1, 3) + assert len(torch.unique(all_batches, dim=0)) == num_samples + + +@RunIf(min_gpus=2) def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): """ Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training. 
""" dataset = NumpyRandomDataset() num_workers = 2 batch_size = 2 - world_size = 2 - num_samples = len(dataset) - - # simulate distributed processes by setting rank and collecting the batches - all_batches = [] - for current_rank in range(world_size): - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, worker_init_fn=pl_worker_init_function) - seed_everything(0, workers=True) - monkeypatch.setattr(rank_zero_only, "rank", current_rank) - assert rank_zero_only.rank == current_rank - Trainer.auto_add_worker_init_fn(dataloader) - all_batches.extend([batch for batch in dataloader]) - - all_batches = torch.cat(all_batches) - assert all_batches.shape[0] == num_samples * world_size - assert len(torch.unique(all_batches, dim=0)) == num_samples * world_size + + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) + seed_everything(0, workers=True) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + gpus=2, + accelerator="ddp_spawn", + ) + model = MultiProcessModel() + model.train_dataloader = None + model.val_dataloader = None + trainer.fit(model, train_dataloader=dataloader) def test_warning_with_iterable_dataset_and_len(tmpdir): From becb26b79f9eb2ec0803c50274c025093101dd3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Apr 2021 16:51:00 +0200 Subject: [PATCH 20/38] unused import --- pytorch_lightning/utilities/seed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 662eb4ece8951..52217460910fe 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -18,7 +18,6 @@ import random from typing import Optional -import numpy import numpy as np import torch from pytorch_lightning.utilities.distributed import rank_zero_only From a999b95cb81ae7836a56019f6a41101020fc475b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Apr 2021 17:09:16 +0200 Subject: [PATCH 21/38] ignore coverage for worker function --- pytorch_lightning/utilities/seed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 52217460910fe..b9ab041314681 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -78,7 +78,7 @@ def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> return random.randint(min_seed_value, max_seed_value) -def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: +def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover """ The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed with :func:`~pytorch_lightning.utilities.seed.seed_everything`. 
From 896bed7e942c2b69b4c9d324146adf123f8d3b13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Apr 2021 17:37:54 +0200 Subject: [PATCH 22/38] Update tests/trainer/test_dataloaders.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- tests/trainer/test_dataloaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 0310554e31f61..243235f9b38bd 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -715,7 +715,6 @@ def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): accelerator="ddp_spawn", ) model = MultiProcessModel() - model.train_dataloader = None model.val_dataloader = None trainer.fit(model, train_dataloader=dataloader) From 5b656aa24c612eeff6554adaa3adba83884206cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Apr 2021 17:38:38 +0200 Subject: [PATCH 23/38] Update pytorch_lightning/trainer/data_loading.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/trainer/data_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 26b8eb5aa851e..76c70c9ea71bb 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -105,7 +105,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: ) def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None: - if dataloader.worker_init_fn is None and int(os.environ.get("PL_SEED_WORKERS", "0")): + if dataloader.worker_init_fn is None and int(os.environ.get("PL_SEED_WORKERS", 0)): dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank) def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: From 69deaf8e3d3209cc77640eb8ae58ce42418cf55c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 15 Apr 2021 19:08:07 +0200 Subject: [PATCH 24/38] Update pytorch_lightning/trainer/data_loading.py Co-authored-by: ananthsub --- pytorch_lightning/trainer/data_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 76c70c9ea71bb..f52c60655f512 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -105,7 +105,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: ) def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None: - if dataloader.worker_init_fn is None and int(os.environ.get("PL_SEED_WORKERS", 0)): + if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank) def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: From 60e0669ecdb977b26dda779066288d3e6c2e7df4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 01:09:23 +0200 Subject: [PATCH 25/38] update req --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3faed306a488a..027f776f145a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # the default package 
dependencies -numpy>=1.16.6 +numpy>=1.17 torch>=1.4 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 From 9e0ac38eca4982028a6cf268523b6b877ddd751f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 04:56:51 +0200 Subject: [PATCH 26/38] 32-bit seeding for pytorch prior 1.7 --- pytorch_lightning/utilities/seed.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index b9ab041314681..13185900c30a0 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -20,9 +20,9 @@ import numpy as np import torch -from pytorch_lightning.utilities.distributed import rank_zero_only -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, rank_zero_warn +from pytorch_lightning.utilities.distributed import rank_zero_only log = logging.getLogger(__name__) @@ -95,8 +95,9 @@ def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # p np.random.seed(ss.generate_state(4)) # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module torch_ss, stdlib_ss = ss.spawn(2) - # PyTorch takes a 64-bit seed - torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0]) + # PyTorch 1.7 and above takes a 64-bit seed + dtype = np.uint64 if _TORCH_GREATER_EQUAL_1_7 else np.uint32 + torch.manual_seed(torch_ss.generate_state(1, dtype=dtype)[0]) # use 128 bits expressed as an integer stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() random.seed(stdlib_seed) From e56d1b0c1023548439ff4ef03479b0aa6fcce349 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 05:01:10 +0200 Subject: [PATCH 27/38] flake8 --- tests/trainer/test_dataloaders.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 243235f9b38bd..d1150b9a3438b 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os -from functools import partial from unittest import mock from unittest.mock import patch, Mock @@ -27,10 +26,8 @@ import tests.helpers.pipelines as tpipes from pytorch_lightning import Callback, Trainer, seed_everything from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.seed import pl_worker_init_function from tests.base import EvalModelTemplate from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf From e8844b8bb4699f15393acc486387418b22d177f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 05:21:03 +0200 Subject: [PATCH 28/38] test duplicates --- pytorch_lightning/trainer/data_loading.py | 1 - tests/trainer/test_dataloaders.py | 39 +++++++++++++++-------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index f52c60655f512..db9ac2196fcb2 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,7 +16,6 @@ import os from abc import ABC from copy import deepcopy - from functools import partial from typing import Iterable, List, Optional, Tuple, Union diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d1150b9a3438b..d6ca906e2259c 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import sys from unittest import mock -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch import numpy import pytest import torch from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import IterableDataset, Subset, Dataset +from torch.utils.data.dataset import Dataset, IterableDataset, Subset from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import SequentialSampler import tests.helpers.pipelines as tpipes -from pytorch_lightning import Callback, Trainer, seed_everything +from pytorch_lightning import Callback, seed_everything, Trainer from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -635,10 +636,32 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl) +class NumpyRandomDataset(Dataset): + # this datset uses numpy instead of torch to produce random numbers + size = 16 + + def __getitem__(self, index): + return numpy.random.randint(0, 100, 3) + + def __len__(self): + return self.size + + def _user_worker_init_fn(_): pass +@pytest.mark.skipif(not sys.platform.startswith('linux'), reason="only on platforms that fork") +def test_missing_worker_init_fn(): + """ Test that the dataloader workers produce duplicates when we use numpy but don't initialize the worker seed. 
""" + seed_everything(0) + dataset = NumpyRandomDataset() + dataloader = DataLoader(dataset, batch_size=2, num_workers=2) + batches = [batch for batch in dataloader] + all_batches = torch.cat(batches) + assert len(torch.unique(all_batches, dim=0)) < len(dataset) + + def test_auto_add_worker_init_fn(): """ Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. """ dataset = Mock() @@ -667,16 +690,6 @@ def test_auto_add_worker_init_fn(): assert dataloader.worker_init_fn is not None -class NumpyRandomDataset(Dataset): - size = 16 - - def __getitem__(self, index): - return numpy.random.randint(0, 100, 3) - - def __len__(self): - return self.size - - class MultiProcessModel(BoringModel): def __init__(self): From cc9f8182cfa6beab9dcc0135776db4511066827f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 05:28:37 +0200 Subject: [PATCH 29/38] update docs --- docs/source/common/trainer.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index eea68cbd460c5..881b3dbec87bf 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -211,12 +211,16 @@ Example:: from pytorch_lightning import Trainer, seed_everything - seed_everything(42) + seed_everything(42, workers=True) # sets seeds for numpy, torch, python.random and PYTHONHASHSEED. model = Model() trainer = Trainer(deterministic=True) +In Lightning version 1.3 and above, ``seed_everything`` also guarantees unique seeds across all dataloader worker +processes. This is turned on by default and ensures that e.g. data augmentations are not repeated across workers. +If the old behavior is desired, set ``seed_everything(42, workers=False)``. 
+ ------- Trainer flags From 4d6d56baad96dd888d969781c161361ed3d659df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 16 Apr 2021 13:32:37 +0200 Subject: [PATCH 30/38] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 027f776f145a9..95773a540e7b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # the default package dependencies -numpy>=1.17 +numpy>=1.17.5 torch>=1.4 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 From 0a62d55fc8ea5afbd69084a458f2ae0ee2f118f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 16 Apr 2021 13:41:21 +0200 Subject: [PATCH 31/38] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 95773a540e7b9..b4dfef5ca4a3f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # the default package dependencies -numpy>=1.17.5 +numpy>=1.17.3 torch>=1.4 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 From 1fa020f44cf2d34b209260b72bafac14fe03d8c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 16 Apr 2021 13:50:56 +0200 Subject: [PATCH 32/38] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b4dfef5ca4a3f..2bd1233be34f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # the default package dependencies -numpy>=1.17.3 +numpy>=1.17.1 torch>=1.4 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 From 2124ef5e9228ec16aeb0fa05ce88b5de24410ef5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 16 Apr 2021 14:16:41 +0200 Subject: [PATCH 33/38] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2bd1233be34f9..3438b1ea2189b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # the default package dependencies -numpy>=1.17.1 +numpy>=1.17.2 torch>=1.4 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 From 4456bf82bbd5f929ce86a2febfdc2b7ab53c9fc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 19:00:31 +0200 Subject: [PATCH 34/38] change default --- docs/source/common/trainer.rst | 6 +++--- pytorch_lightning/utilities/seed.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 881b3dbec87bf..7c855ddc47522 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -217,9 +217,9 @@ Example:: trainer = Trainer(deterministic=True) -In Lightning version 1.3 and above, ``seed_everything`` also guarantees unique seeds across all dataloader worker -processes. This is turned on by default and ensures that e.g. data augmentations are not repeated across workers. -If the old behavior is desired, set ``seed_everything(42, workers=False)``. +By setting ``workers=True`` in :func:`~pytorch_lightning.utilities.seed.seed_everything`, Lightning derives +unique seeds across all dataloader workers and processes for :mod:`torch`, :mod:`numpy` and stdlib +:mod:`random` number generators. When turned on, it ensures that e.g. data augmentations are not repeated across workers. 
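Illustration only, not part of the diff: since this commit makes per-worker seeding an explicit opt-in, a minimal end-to-end sketch of the opt-in looks as follows. It assumes the test-helper BoringModel used elsewhere in this PR is importable; the stand-in dataset is invented for the example.

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import Trainer, seed_everything
from tests.helpers.boring_model import BoringModel  # helper model used throughout these tests

class Random32Dataset(Dataset):
    # minimal stand-in producing the 32-dim float samples BoringModel expects
    def __getitem__(self, index):
        return torch.randn(32)

    def __len__(self):
        return 64

seed_everything(42, workers=True)  # explicit opt-in: also derives per-worker seeds
trainer = Trainer(max_epochs=1, deterministic=True, weights_summary=None)
trainer.fit(BoringModel(), DataLoader(Random32Dataset(), batch_size=4, num_workers=2))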
------- diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 13185900c30a0..b7eaba72c1b02 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -27,7 +27,7 @@ log = logging.getLogger(__name__) -def seed_everything(seed: Optional[int] = None, workers: bool = True) -> int: +def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: """ Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random @@ -81,7 +81,7 @@ def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover """ The worker_init_fn that Lightning automatically adds to your dataloader if you previously set - set the seed with :func:`~pytorch_lightning.utilities.seed.seed_everything`. + set the seed with ``seed_everything(seed, workers=True)``. See also the PyTorch documentation on `randomness in DataLoaders `_. """ From 708581a67f0b66ec01f8c1628f7483382d0df628 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 19:01:03 +0200 Subject: [PATCH 35/38] remove sanity test --- tests/trainer/test_dataloaders.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d6ca906e2259c..46aa51acc1e24 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -651,17 +651,6 @@ def _user_worker_init_fn(_): pass -@pytest.mark.skipif(not sys.platform.startswith('linux'), reason="only on platforms that fork") -def test_missing_worker_init_fn(): - """ Test that the dataloader workers produce duplicates when we use numpy but don't initialize the worker seed. """ - seed_everything(0) - dataset = NumpyRandomDataset() - dataloader = DataLoader(dataset, batch_size=2, num_workers=2) - batches = [batch for batch in dataloader] - all_batches = torch.cat(batches) - assert len(torch.unique(all_batches, dim=0)) < len(dataset) - - def test_auto_add_worker_init_fn(): """ Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. """ dataset = Mock() From fd59239911f553310265327d5483b16c85370287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 19:02:48 +0200 Subject: [PATCH 36/38] unused import --- tests/trainer/test_dataloaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 46aa51acc1e24..7084b3d6fc101 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import sys from unittest import mock from unittest.mock import Mock, patch From 99799a003a77a06c9772660078a59dedb51fbf8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 16 Apr 2021 19:01:03 +0200 Subject: [PATCH 37/38] Revert "remove sanity test" This reverts commit 708581a67f0b66ec01f8c1628f7483382d0df628. 
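Illustration only, not part of PATCH 37: pl_worker_init_function (touched by PATCH 34 above) can also be attached to a hand-built DataLoader outside of the Trainer. Whether the rank argument may be left at its default is not shown in this diff, so the sketch pins it explicitly.

from functools import partial

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import seed_everything
from pytorch_lightning.utilities.seed import pl_worker_init_function

class ToyDataset(Dataset):
    def __getitem__(self, index):
        return torch.randn(3)

    def __len__(self):
        return 16

seed_everything(7)  # global seeding only; the init fn below handles the workers
loader = DataLoader(
    ToyDataset(),
    batch_size=2,
    num_workers=2,
    worker_init_fn=partial(pl_worker_init_function, rank=0),  # rank pinned for a single-process run
)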
---
 tests/trainer/test_dataloaders.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 7084b3d6fc101..54a36a5b351d7 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -650,6 +650,17 @@ def _user_worker_init_fn(_):
     pass


+@pytest.mark.skipif(not sys.platform.startswith('linux'), reason="only on platforms that fork")
+def test_missing_worker_init_fn():
+    """ Test that the dataloader workers produce duplicates when we use numpy but don't initialize the worker seed. """
+    seed_everything(0)
+    dataset = NumpyRandomDataset()
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
+    batches = [batch for batch in dataloader]
+    all_batches = torch.cat(batches)
+    assert len(torch.unique(all_batches, dim=0)) < len(dataset)
+
+
 def test_auto_add_worker_init_fn():
     """ Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. """
     dataset = Mock()

From 851482165a9b5298d18b1aee8e555a5de105987f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 16 Apr 2021 20:13:46 +0200
Subject: [PATCH 38/38] better sanity check

---
 tests/trainer/test_dataloaders.py | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 54a36a5b351d7..b558651d8d3c3 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -650,15 +650,25 @@ def _user_worker_init_fn(_):
     pass


-@pytest.mark.skipif(not sys.platform.startswith('linux'), reason="only on platforms that fork")
 def test_missing_worker_init_fn():
-    """ Test that the dataloader workers produce duplicates when we use numpy but don't initialize the worker seed. """
+    """ Test that naive worker seed initialization leads to undesired random state in subprocesses. """
-    seed_everything(0)
     dataset = NumpyRandomDataset()
+
+    seed_everything(0)
-    dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
-    batches = [batch for batch in dataloader]
-    all_batches = torch.cat(batches)
-    assert len(torch.unique(all_batches, dim=0)) < len(dataset)
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
+    batches0 = torch.cat([batch for batch in dataloader])
+
+    seed_everything(0)
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
+    batches1 = torch.cat([batch for batch in dataloader])
+
+    is_duplicated = len(torch.unique(batches1, dim=0)) < len(dataset)
+    is_deterministic = torch.eq(batches0, batches1).all()
+
+    # depending on the OS, we either have
+    # 1) the same seed in all worker processes, producing duplicate samples / augmentations, or
+    # 2) different seeds in each worker process, but they are not derived from the seed of the main process
+    assert not is_deterministic or is_duplicated


 def test_auto_add_worker_init_fn():
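Illustration only, not part of the patch series: the sanity check above asserts the broken behaviour; the complementary property, reproducible and duplicate-free batches once a per-worker init function is in place, could be verified with a sketch like this. Here pl_worker_init_function is wired up by hand, roughly as the Trainer does when workers=True is requested.

from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import seed_everything
from pytorch_lightning.utilities.seed import pl_worker_init_function

class NumpyRandomDataset(Dataset):
    size = 16

    def __getitem__(self, index):
        return np.random.randint(0, 100, 3)

    def __len__(self):
        return self.size

def collect_batches():
    seed_everything(0, workers=True)
    loader = DataLoader(
        NumpyRandomDataset(),
        batch_size=2,
        num_workers=2,
        shuffle=False,
        worker_init_fn=partial(pl_worker_init_function, rank=0),  # attached by hand for this sketch
    )
    return torch.cat([batch for batch in loader])

if __name__ == "__main__":
    batches0 = collect_batches()
    batches1 = collect_batches()
    assert torch.eq(batches0, batches1).all()  # deterministic across re-seeded runs
    # with distinct per-worker seeds, duplicated samples are vanishingly unlikely
    assert len(torch.unique(batches0, dim=0)) == NumpyRandomDataset.size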