
Commit 922324c

awaelchli authored and Borda committed
Call set_epoch for distributed batch samplers (#13396)
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
(cherry picked from commit 2dd332f)
1 parent 032c9eb commit 922324c

File tree

8 files changed: +103 −32 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed bug with Python version check that prevented use with development versions of Python ([#13420](https://github.com/PyTorchLightning/pytorch-lightning/pull/13420))
 
 
+- The loops now call `.set_epoch()` also on batch samplers if the dataloader has one wrapped in a distributed sampler ([#13396](https://github.com/PyTorchLightning/pytorch-lightning/pull/13396))
+
+
 
 ## [1.6.4] - 2022-06-01

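To make the changelog entry concrete, here is a minimal sketch (not part of the diff) of the situation it describes: a dataloader built around a batch sampler that exposes `set_epoch`. The `DistributedBatchSampler` class below is a hypothetical illustration, not Lightning's own wrapper; with this change, the loops call `set_epoch` on `dataloader.batch_sampler` as well, not only on `dataloader.sampler`.

import torch
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, TensorDataset


class DistributedBatchSampler(BatchSampler):
    # hypothetical batch sampler that forwards `set_epoch` to the distributed sampler it wraps
    def set_epoch(self, epoch: int) -> None:
        self.sampler.set_epoch(epoch)


dataset = TensorDataset(torch.randn(64, 32))
# explicit num_replicas/rank so the snippet runs without initializing torch.distributed
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
batch_sampler = DistributedBatchSampler(sampler, batch_size=4, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

# what the trainer loops now do (in effect) at the start of epoch 0:
dataloader.batch_sampler.set_epoch(0)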
pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 3 additions & 8 deletions
@@ -26,6 +26,7 @@
 from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
+from pytorch_lightning.loops.utilities import _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -161,14 +162,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         self._has_run = True
 
     def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
-        dataloader = self.current_dataloader
-        if (
-            dataloader is not None
-            and getattr(dataloader, "sampler", None)
-            and callable(getattr(dataloader.sampler, "set_epoch", None))
-        ):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
+        if self.current_dataloader is not None:
+            _set_sampler_epoch(self.current_dataloader, self.trainer.fit_loop.epoch_progress.current.processed)
 
         super().on_advance_start(*args, **kwargs)

pytorch_lightning/loops/dataloader/prediction_loop.py

Lines changed: 3 additions & 7 deletions
@@ -5,6 +5,7 @@
 
 from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
 from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
+from pytorch_lightning.loops.utilities import _set_sampler_epoch
 from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
@@ -87,13 +88,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         """Predicts one entire dataloader."""
         void(*args, **kwargs)
         dataloader = self.current_dataloader
-        if (
-            dataloader is not None
-            and getattr(dataloader, "sampler", None)
-            and callable(getattr(dataloader.sampler, "set_epoch", None))
-        ):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
+        if dataloader is not None:
+            _set_sampler_epoch(dataloader, self.trainer.fit_loop.epoch_progress.current.processed)
         dataloader = self.trainer.strategy.process_dataloader(dataloader)
         dataloader_iter = enumerate(dataloader)
         dl_max_batches = self.max_batches[self.current_dataloader_idx]

pytorch_lightning/loops/fit_loop.py

Lines changed: 3 additions & 6 deletions
@@ -21,7 +21,7 @@
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop
 from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
-from pytorch_lightning.loops.utilities import _is_max_limit_reached
+from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
@@ -232,11 +232,8 @@ def on_advance_start(self) -> None:  # type: ignore[override]
         # reset outputs here instead of in `reset` as they are not accumulated between epochs
         self._outputs = []
 
-        if self.trainer.train_dataloader is not None and callable(
-            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
-        ):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)
+        if self.trainer.train_dataloader is not None:
+            _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed)
 
         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

pytorch_lightning/loops/utilities.py

Lines changed: 14 additions & 0 deletions
@@ -21,6 +21,7 @@
 import numpy as np
 import torch
 from torch.optim import Optimizer
+from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 from pytorch_lightning.loops import Loop
@@ -228,3 +229,16 @@ def _reset_progress(loop: Loop) -> None:
 def _v1_8_output_format(fx: Callable) -> bool:
     parameters = inspect.signature(fx).parameters
     return "new_format" in parameters and parameters["new_format"].default is True
+
+
+def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
+    """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
+
+    Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
+    :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
+    of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
+    """
+    for sampler_name in ("sampler", "batch_sampler"):
+        sampler = getattr(dataloader, sampler_name, None)
+        if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
+            sampler.set_epoch(epoch)

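For reference, a minimal usage sketch of the new `_set_sampler_epoch` helper (the helper and its import path come from this commit; the dataset and loop scaffolding are illustrative only):

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

from pytorch_lightning.loops.utilities import _set_sampler_epoch

dataset = TensorDataset(torch.randn(64, 32))
# explicit num_replicas/rank so the snippet runs without initializing torch.distributed
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=4)

for epoch in range(2):
    # checks both `dataloader.sampler` and `dataloader.batch_sampler` for a callable `set_epoch`
    _set_sampler_epoch(dataloader, epoch)
    for (batch,) in dataloader:
        pass  # training / evaluation step would go here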
pytorch_lightning/trainer/supporters.py

Lines changed: 6 additions & 1 deletion
@@ -438,9 +438,14 @@ class DataLoaderDict(dict):
 
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
-        """Return a collections of samplers extracting from loaders."""
+        """Return a collections of samplers extracted from loaders."""
         return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "sampler", None)
 
+    @property
+    def batch_sampler(self) -> Union[Iterable, Sequence, Mapping]:
+        """Return a collections of batch samplers extracted from loaders."""
+        return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "batch_sampler", None)
+
     def _wrap_loaders_max_size_cycle(self) -> Any:
         """Wraps all loaders to make sure they are cycled until the longest loader is exhausted.

tests/loops/test_evaluation_loop.py

Lines changed: 48 additions & 9 deletions
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import call, Mock
 
 import torch
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.sampler import RandomSampler
+from torch.utils.data.sampler import BatchSampler, RandomSampler
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loops import EvaluationEpochLoop
@@ -44,9 +44,8 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
     assert eval_epoch_end_mock.call_count == 4
 
 
-def test_set_epoch_called_eval_predict(tmpdir):
-    """Tests that set_epoch (if the sampler has one) is called on the DataLoader during evaluation and
-    prediction."""
+def test_evaluation_loop_sampler_set_epoch_called(tmpdir):
+    """Tests that set_epoch is called on the dataloader's sampler (if any) during training and validation."""
 
     def _get_dataloader():
         dataset = RandomDataset(32, 64)
@@ -56,20 +55,60 @@ def _get_dataloader():
 
     model = BoringModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2, enable_model_summary=False
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        logger=False,
+    )
+
+    train_dataloader = _get_dataloader()
+    val_dataloader = _get_dataloader()
+    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
+    # One for each epoch
+    assert train_dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]
+    # One for each epoch + sanity check
+    assert val_dataloader.sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]
+
+    val_dataloader = _get_dataloader()
+    trainer.validate(model, val_dataloader)
+    assert val_dataloader.sampler.set_epoch.call_args_list == [call(2)]
+
+
+def test_evaluation_loop_batch_sampler_set_epoch_called(tmpdir):
+    """Tests that set_epoch is called on the dataloader's batch sampler (if any) during training and validation."""
+
+    def _get_dataloader():
+        dataset = RandomDataset(32, 64)
+        sampler = RandomSampler(dataset)
+        batch_sampler = BatchSampler(sampler, 2, True)
+        batch_sampler.set_epoch = Mock()
+        return DataLoader(dataset, batch_sampler=batch_sampler)
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        enable_model_summary=False,
+        enable_checkpointing=False,
+        logger=False,
     )
 
     train_dataloader = _get_dataloader()
     val_dataloader = _get_dataloader()
     trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
     # One for each epoch
-    assert train_dataloader.sampler.set_epoch.call_count == 2
+    assert train_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(1)]
     # One for each epoch + sanity check
-    assert val_dataloader.sampler.set_epoch.call_count == 3
+    assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]
 
     val_dataloader = _get_dataloader()
     trainer.validate(model, val_dataloader)
-    assert val_dataloader.sampler.set_epoch.call_count == 1
+    assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(2)]
 
 
 @mock.patch(

tests/loops/test_utilities.py

Lines changed: 23 additions & 1 deletion
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest.mock import Mock
+
 import pytest
 import torch
 
-from pytorch_lightning.loops.utilities import _extract_hiddens, _v1_8_output_format
+from pytorch_lightning.loops.utilities import _extract_hiddens, _set_sampler_epoch, _v1_8_output_format
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -61,3 +63,23 @@ def training_epoch_end(outputs, new_format=True):
         ...
 
     assert _v1_8_output_format(training_epoch_end)
+
+
+def test_set_sampler_epoch():
+    # No samplers
+    dataloader = Mock()
+    dataloader.sampler = None
+    dataloader.batch_sampler = None
+    _set_sampler_epoch(dataloader, 55)
+
+    # set_epoch not callable
+    dataloader = Mock()
+    dataloader.sampler.set_epoch = None
+    dataloader.batch_sampler.set_epoch = None
+    _set_sampler_epoch(dataloader, 55)
+
+    # set_epoch callable
+    dataloader = Mock()
+    _set_sampler_epoch(dataloader, 55)
+    dataloader.sampler.set_epoch.assert_called_once_with(55)
+    dataloader.batch_sampler.set_epoch.assert_called_once_with(55)
