Skip to content

Commit 3004f13

Browse files
Lite: Fix DataLoader shuffling when using DistributedSampler (#15931)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 904323b commit 3004f13

File tree

4 files changed

+36
-8
lines changed

4 files changed

+36
-8
lines changed

src/lightning_lite/CHANGELOG.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -42,7 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed `shuffle=False` having no effect when using DDP/DistributedSampler ([#15931](https://github.com/Lightning-AI/lightning/issues/15931))
+
 
 
 ## [1.8.3] - 2022-11-22
```

src/lightning_lite/lite.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -25,7 +25,7 @@
 from lightning_utilities.core.rank_zero import rank_zero_warn
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import BatchSampler, DataLoader, DistributedSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler
 
 from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_lite.accelerators.accelerator import Accelerator
@@ -582,6 +582,7 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
 
     @staticmethod
     def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler:
+        kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler))
         kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
         return DistributedSamplerWrapper(dataloader.sampler, **kwargs)
```

tests/tests_lite/test_lite.py

Lines changed: 28 additions & 2 deletions

```diff
@@ -23,7 +23,7 @@
 from tests_lite.helpers.runif import RunIf
 from tests_lite.helpers.utils import no_warning_call
 from torch import nn
-from torch.utils.data import DataLoader, DistributedSampler, Sampler
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset
 
 from lightning_lite.lite import LightningLite
 from lightning_lite.plugins import Precision
@@ -40,7 +40,7 @@
 from lightning_lite.strategies.strategy import _Sharded
 from lightning_lite.utilities import _StrategyType
 from lightning_lite.utilities.exceptions import MisconfigurationException
-from lightning_lite.utilities.seed import pl_worker_init_function
+from lightning_lite.utilities.seed import pl_worker_init_function, seed_everything
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 
@@ -384,6 +384,32 @@ def test_setup_dataloaders_distributed_sampler_not_needed():
     assert lite_dataloader.sampler is custom_sampler
 
 
+def test_setup_dataloaders_distributed_sampler_shuffle():
+    """Test that the DataLoader(shuffle=True|False) setting gets carried over correctly into the distributed
+    sampler."""
+    lite = LightningLite(accelerator="cpu", strategy="ddp_spawn", devices=2)
+    # no lite.launch(): pretend we are on rank 0 now
+
+    dataset = TensorDataset(torch.arange(8))
+
+    # shuffling turned off
+    no_shuffle_dataloaders = [
+        DataLoader(dataset),
+        DataLoader(dataset, shuffle=False),
+        DataLoader(dataset, sampler=SequentialSampler(dataset)),
+    ]
+    for dataloader in no_shuffle_dataloaders:
+        dataloader = lite.setup_dataloaders(dataloader)
+        assert list(t[0].item() for t in iter(dataloader)) == [0, 2, 4, 6]
+
+    # shuffling turned on
+    shuffle_dataloaders = [DataLoader(dataset, shuffle=True), DataLoader(dataset, sampler=RandomSampler(dataset))]
+    for dataloader in shuffle_dataloaders:
+        seed_everything(1)
+        dataloader = lite.setup_dataloaders(dataloader)
+        assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1]
+
+
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_seed_everything():
     """Test that seed everything is static and sets the worker init function on the dataloader."""
```

tests/tests_lite/test_parity.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -203,7 +203,7 @@ def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator,
 )
 def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir):
     LightningLite.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 4))
+    train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True)
     model = BoringModel()
     num_epochs = 1
     state_dict = deepcopy(model.state_dict())
@@ -214,13 +214,13 @@ def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir
     lite_model_state_dict = model.state_dict()
 
     for w_pure, w_lite in zip(state_dict.values(), lite_model_state_dict.values()):
-        assert not torch.equal(w_pure.cpu(), w_lite.cpu())
+        assert not torch.allclose(w_pure.cpu(), w_lite.cpu())
 
     LightningLite.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 4))
+    train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True)
     model = BoringModel()
     run(lite.global_rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir)
     pure_model_state_dict = model.state_dict()
 
     for w_pure, w_lite in zip(pure_model_state_dict.values(), lite_model_state_dict.values()):
-        assert torch.equal(w_pure.cpu(), w_lite.cpu())
+        torch.testing.assert_close(w_pure.cpu(), w_lite.cpu())
```

0 commit comments

Comments
 (0)