From 5f6017dc8e00b6b90b71bad3d3cedbfadbcd334b Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 23 Jun 2022 22:39:50 +0200
Subject: [PATCH 1/7] call set_epoch for sampler and batch sampler

---
 .../loops/dataloader/evaluation_loop.py  | 11 +++--------
 .../loops/dataloader/prediction_loop.py  | 10 +++-------
 src/pytorch_lightning/loops/fit_loop.py  |  9 +++------
 src/pytorch_lightning/loops/utilities.py |  9 +++++++++
 4 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py
index 53dbce9cacdf1..48ccf7df1266d 100644
--- a/src/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/src/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -26,6 +26,7 @@
 from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
+from pytorch_lightning.loops.utilities import _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -161,14 +162,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         self._has_run = True
 
     def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
-        dataloader = self.current_dataloader
-        if (
-            dataloader is not None
-            and getattr(dataloader, "sampler", None)
-            and callable(getattr(dataloader.sampler, "set_epoch", None))
-        ):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
+        if self.current_dataloader is not None:
+            _set_sampler_epoch(self.current_dataloader, self.trainer.fit_loop.epoch_progress.current.processed)
 
         super().on_advance_start(*args, **kwargs)
 
diff --git a/src/pytorch_lightning/loops/dataloader/prediction_loop.py b/src/pytorch_lightning/loops/dataloader/prediction_loop.py
index 4ff6543064a6e..ce9ec9008c2db 100644
--- a/src/pytorch_lightning/loops/dataloader/prediction_loop.py
+++ b/src/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -5,6 +5,7 @@
 
 from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
 from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
+from pytorch_lightning.loops.utilities import _set_sampler_epoch
 from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
@@ -90,13 +91,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         """Predicts one entire dataloader."""
         void(*args, **kwargs)
         dataloader = self.current_dataloader
-        if (
-            dataloader is not None
-            and getattr(dataloader, "sampler", None)
-            and callable(getattr(dataloader.sampler, "set_epoch", None))
-        ):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
+        if dataloader is not None:
+            _set_sampler_epoch(dataloader, self.trainer.fit_loop.epoch_progress.current.processed)
 dataloader = self.trainer.strategy.process_dataloader(dataloader)
         dataloader_iter = enumerate(dataloader)
         dl_max_batches = self.max_batches[self.current_dataloader_idx]
diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py
index ac33390a97cec..0771a4a71de9f 100644
--- a/src/pytorch_lightning/loops/fit_loop.py
+++ b/src/pytorch_lightning/loops/fit_loop.py
@@ -21,7 +21,7 @@
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop
 from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
-from pytorch_lightning.loops.utilities import _is_max_limit_reached
+from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
@@ -232,11 +232,8 @@ def on_advance_start(self) -> None:  # type: ignore[override]
         # reset outputs here instead of in `reset` as they are not accumulated between epochs
         self._outputs = []
 
-        if self.trainer.train_dataloader is not None and callable(
-            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
-        ):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)
+        if self.trainer.train_dataloader is not None:
+            _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed)
 
         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
diff --git a/src/pytorch_lightning/loops/utilities.py b/src/pytorch_lightning/loops/utilities.py
index b5fefcd4b0011..360d54d5fb357 100644
--- a/src/pytorch_lightning/loops/utilities.py
+++ b/src/pytorch_lightning/loops/utilities.py
@@ -22,6 +22,7 @@
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
+from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 from pytorch_lightning.loops import Loop
@@ -220,3 +221,11 @@ def _reset_progress(loop: Loop) -> None:
 def _v1_8_output_format(fx: Callable) -> bool:
     parameters = inspect.signature(fx).parameters
     return "new_format" in parameters and parameters["new_format"].default is True
+
+
+def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
+    # set seed for distributed sampler (enables shuffling for each epoch)
+    for sampler_name in ("sampler", "batch_sampler"):
+        sampler = getattr(dataloader, sampler_name, None)
+        if sampler and callable(getattr(sampler, "set_epoch", None)):
+            sampler.set_epoch(epoch)

From 39a284cee1bdee61f67e5b0a3b56fac722475b21 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 23 Jun 2022 22:49:58 +0200
Subject: [PATCH 2/7] add docstring

---
 src/pytorch_lightning/loops/utilities.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/loops/utilities.py b/src/pytorch_lightning/loops/utilities.py
index 360d54d5fb357..1ec63c4b992ee 100644
--- a/src/pytorch_lightning/loops/utilities.py
+++ b/src/pytorch_lightning/loops/utilities.py
@@ -224,7 +224,12 @@ def _v1_8_output_format(fx: Callable) -> bool:
     parameters = inspect.signature(fx).parameters
     return "new_format" in parameters and parameters["new_format"].default is True
 
 
 def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
-    # set seed for distributed sampler (enables shuffling for each epoch)
+    """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
+
+    Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
+    :class:`~torch.utils.data.DistributedSampler`, ``set_epoch`` must be called at the beginning of every epoch to
+    ensure shuffling applies a new ordering. This has no effect if shuffling is off.
+    """
     for sampler_name in ("sampler", "batch_sampler"):
         sampler = getattr(dataloader, sampler_name, None)
         if sampler and callable(getattr(sampler, "set_epoch", None)):
             sampler.set_epoch(epoch)

From 4a472daccbb2f71ec263038fe01d6cda1ea25a80 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Fri, 24 Jun 2022 03:03:10 +0200
Subject: [PATCH 3/7] fixes and tests

---
 src/pytorch_lightning/loops/utilities.py    |  2 +-
 src/pytorch_lightning/trainer/supporters.py |  7 ++-
 .../loops/test_evaluation_loop.py           | 57 ++++++++++++++++---
 3 files changed, 55 insertions(+), 11 deletions(-)

diff --git a/src/pytorch_lightning/loops/utilities.py b/src/pytorch_lightning/loops/utilities.py
index 1ec63c4b992ee..53b8b018cb928 100644
--- a/src/pytorch_lightning/loops/utilities.py
+++ b/src/pytorch_lightning/loops/utilities.py
@@ -232,5 +232,5 @@ def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
     """
     for sampler_name in ("sampler", "batch_sampler"):
         sampler = getattr(dataloader, sampler_name, None)
-        if sampler and callable(getattr(sampler, "set_epoch", None)):
+        if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
             sampler.set_epoch(epoch)
diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py
index b8f688892b318..6d3ec88b0be6a 100644
--- a/src/pytorch_lightning/trainer/supporters.py
+++ b/src/pytorch_lightning/trainer/supporters.py
@@ -438,9 +438,14 @@ class DataLoaderDict(dict):
 
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
-        """Return a collections of samplers extracting from loaders."""
+        """Return a collection of samplers extracted from loaders."""
         return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "sampler", None)
 
+    @property
+    def batch_sampler(self) -> Union[Iterable, Sequence, Mapping]:
+        """Return a collection of batch samplers extracted from loaders."""
+        return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "batch_sampler", None)
+
     def _wrap_loaders_max_size_cycle(self) -> Any:
         """Wraps all loaders to make sure they are cycled until the longest loader is exhausted.
 
diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py
index cd531aaa2f80b..2f14deb37c836 100644
--- a/tests/tests_pytorch/loops/test_evaluation_loop.py
+++ b/tests/tests_pytorch/loops/test_evaluation_loop.py
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import call, Mock
 
 import torch
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.sampler import RandomSampler
+from torch.utils.data.sampler import BatchSampler, RandomSampler
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
@@ -44,9 +44,8 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
     assert eval_epoch_end_mock.call_count == 4
 
 
-def test_set_epoch_called_eval_predict(tmpdir):
-    """Tests that set_epoch (if the sampler has one) is called on the DataLoader during evaluation and
-    prediction."""
+def test_evaluation_loop_sampler_set_epoch_called(tmpdir):
+    """Tests that set_epoch is called on the dataloader's sampler (if any) during training and validation."""
 
     def _get_dataloader():
         dataset = RandomDataset(32, 64)
@@ -56,20 +55,60 @@ def _get_dataloader():
 
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2, enable_model_summary=False
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        logger=False,
+    )
+
+    train_dataloader = _get_dataloader()
+    val_dataloader = _get_dataloader()
+    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
+    # One for each epoch
+    assert train_dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]
+    # One for each epoch + sanity check
+    assert val_dataloader.sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]
+
+    val_dataloader = _get_dataloader()
+    trainer.validate(model, val_dataloader)
+    assert val_dataloader.sampler.set_epoch.call_args_list == [call(2)]
+
+
+def test_evaluation_loop_batch_sampler_set_epoch_called(tmpdir):
+    """Tests that set_epoch is called on the dataloader's batch sampler (if any) during training and validation."""
+
+    def _get_dataloader():
+        dataset = RandomDataset(32, 64)
+        sampler = RandomSampler(dataset)
+        batch_sampler = BatchSampler(sampler, 2, True)
+        batch_sampler.set_epoch = Mock()
+        return DataLoader(dataset, batch_sampler=batch_sampler)
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        logger=False,
     )
 
     train_dataloader = _get_dataloader()
     val_dataloader = _get_dataloader()
     trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
     # One for each epoch
-    assert train_dataloader.sampler.set_epoch.call_count == 2
+    assert train_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(1)]
     # One for each epoch + sanity check
-    assert val_dataloader.sampler.set_epoch.call_count == 3
+    assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]
 
     val_dataloader = _get_dataloader()
     trainer.validate(model, val_dataloader)
-    assert val_dataloader.sampler.set_epoch.call_count == 1
+    assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(2)]
 
 
 @mock.patch(

From 0ecefe1d15f07f9d3f4123f96aa84a744e4de8c7 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Fri, 24 Jun 2022 03:04:46 +0200
Subject: [PATCH 4/7] update changelog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fd8118986132f..30639b5ed001a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -268,7 +268,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))
 
--
+- The loops now call `.set_epoch()` also on batch samplers if the dataloader has one wrapped in a distributed sampler ([#13396](https://github.com/PyTorchLightning/pytorch-lightning/pull/13396))
 
 
 ## [1.6.4] - 2022-06-01

From 5f9ea64a90fe033b0c0aa8eecf1189c7d297d935 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Fri, 24 Jun 2022 03:24:02 +0200
Subject: [PATCH 5/7] add unit test

---
 tests/tests_pytorch/loops/test_utilities.py | 25 ++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/loops/test_utilities.py b/tests/tests_pytorch/loops/test_utilities.py
index c5d2e98d008b0..11698bb0c7243 100644
--- a/tests/tests_pytorch/loops/test_utilities.py
+++ b/tests/tests_pytorch/loops/test_utilities.py
@@ -11,10 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import MagicMock, Mock
+
 import pytest
 import torch
+from torch.utils.data import DataLoader
 
-from pytorch_lightning.loops.utilities import _extract_hiddens, _v1_8_output_format
+from pytorch_lightning.loops.utilities import _extract_hiddens, _set_sampler_epoch, _v1_8_output_format
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -61,3 +64,23 @@ def training_epoch_end(outputs, new_format=True):
         ...
 
     assert _v1_8_output_format(training_epoch_end)
+
+
+def test_set_sampler_epoch():
+    # No samplers
+    dataloader = Mock()
+    dataloader.sampler = None
+    dataloader.batch_sampler = None
+    _set_sampler_epoch(dataloader, 55)
+
+    # set_epoch not callable
+    dataloader = Mock()
+    dataloader.sampler.set_epoch = None
+    dataloader.batch_sampler.set_epoch = None
+    _set_sampler_epoch(dataloader, 55)
+
+    # set_epoch callable
+    dataloader = Mock()
+    _set_sampler_epoch(dataloader, 55)
+    dataloader.sampler.set_epoch.assert_called_once_with(55)
+    dataloader.batch_sampler.set_epoch.assert_called_once_with(55)

From 79a70e62a7bf5710bf96c43995b35fdbbee53645 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 28 Jun 2022 14:53:17 -0400
Subject: [PATCH 6/7] Update src/pytorch_lightning/loops/utilities.py

Co-authored-by: Rohit Gupta
---
 src/pytorch_lightning/loops/utilities.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/loops/utilities.py b/src/pytorch_lightning/loops/utilities.py
index 53b8b018cb928..02d9cc2c42552 100644
--- a/src/pytorch_lightning/loops/utilities.py
+++ b/src/pytorch_lightning/loops/utilities.py
@@ -227,8 +227,8 @@ def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
     """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
 
     Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
-    :class:`~torch.utils.data.DistributedSampler`, ``set_epoch`` must be called at the beginning of every epoch to
-    ensure shuffling applies a new ordering. This has no effect if shuffling is off.
+    :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
+    of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
     """
     for sampler_name in ("sampler", "batch_sampler"):
         sampler = getattr(dataloader, sampler_name, None)
         if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
             sampler.set_epoch(epoch)

From 15b82734b7678ec8fe602e7b1869f07244d85776 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 29 Jun 2022 17:45:48 +0200
Subject: [PATCH 7/7] unused imports

---
 tests/tests_pytorch/loops/test_utilities.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/tests_pytorch/loops/test_utilities.py b/tests/tests_pytorch/loops/test_utilities.py
index 11698bb0c7243..914c1de8e115b 100644
--- a/tests/tests_pytorch/loops/test_utilities.py
+++ b/tests/tests_pytorch/loops/test_utilities.py
@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import MagicMock, Mock
+from unittest.mock import Mock
 
 import pytest
 import torch
-from torch.utils.data import DataLoader
 
 from pytorch_lightning.loops.utilities import _extract_hiddens, _set_sampler_epoch, _v1_8_output_format
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
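
For context, a minimal runnable sketch of what the new `_set_sampler_epoch` helper accomplishes (an illustration, not part of the patches above): without a per-epoch `set_epoch` call, a shuffling `DistributedSampler` replays the same sample order every epoch. The toy dataset and the explicit two-replica, rank-0 setup are assumptions chosen so the snippet runs without initializing a process group; the helper body mirrors the one introduced in this series.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
    # Mirrors the helper added in this patch series: forward the epoch to whichever
    # (batch) sampler attribute exposes a callable ``set_epoch``.
    for sampler_name in ("sampler", "batch_sampler"):
        sampler = getattr(dataloader, sampler_name, None)
        if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
            sampler.set_epoch(epoch)


# Illustrative assumption: rank 0 of a hypothetical 2-process job, so no process group is required.
dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(2):
    _set_sampler_epoch(dataloader, epoch)  # skip this call and both epochs yield the same order
    order = [int(x) for batch in dataloader for x in batch[0]]
    print(f"epoch {epoch}: {order}")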