|
23 | 23 | from tests_lite.helpers.runif import RunIf |
24 | 24 | from tests_lite.helpers.utils import no_warning_call |
25 | 25 | from torch import nn |
26 | | -from torch.utils.data import DataLoader, DistributedSampler, Sampler |
| 26 | +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset |
27 | 27 |
|
28 | 28 | from lightning_lite.lite import LightningLite |
29 | 29 | from lightning_lite.plugins import Precision |
|
40 | 40 | from lightning_lite.strategies.strategy import _Sharded |
41 | 41 | from lightning_lite.utilities import _StrategyType |
42 | 42 | from lightning_lite.utilities.exceptions import MisconfigurationException |
43 | | -from lightning_lite.utilities.seed import pl_worker_init_function |
| 43 | +from lightning_lite.utilities.seed import pl_worker_init_function, seed_everything |
44 | 44 | from lightning_lite.utilities.warnings import PossibleUserWarning |
45 | 45 | from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer |
46 | 46 |
|
@@ -384,6 +384,32 @@ def test_setup_dataloaders_distributed_sampler_not_needed(): |
384 | 384 | assert lite_dataloader.sampler is custom_sampler |
385 | 385 |
|
386 | 386 |
|
| 387 | +def test_setup_dataloaders_distributed_sampler_shuffle(): |
| 388 | + """Test that the DataLoader(shuffle=True|False) setting gets carried over correctly into the distributed |
| 389 | + sampler.""" |
| 390 | + lite = LightningLite(accelerator="cpu", strategy="ddp_spawn", devices=2) |
| 391 | + # no lite.launch(): pretend we are on rank 0 now |
| 392 | + |
| 393 | + dataset = TensorDataset(torch.arange(8)) |
| 394 | + |
| 395 | + # shuffling turned off |
| 396 | + no_shuffle_dataloaders = [ |
| 397 | + DataLoader(dataset), |
| 398 | + DataLoader(dataset, shuffle=False), |
| 399 | + DataLoader(dataset, sampler=SequentialSampler(dataset)), |
| 400 | + ] |
| 401 | + for dataloader in no_shuffle_dataloaders: |
| 402 | + dataloader = lite.setup_dataloaders(dataloader) |
| 403 | + assert list(t[0].item() for t in iter(dataloader)) == [0, 2, 4, 6] |
| 404 | + |
| 405 | + # shuffling turned on |
| 406 | + shuffle_dataloaders = [DataLoader(dataset, shuffle=True), DataLoader(dataset, sampler=RandomSampler(dataset))] |
| 407 | + for dataloader in shuffle_dataloaders: |
| 408 | + seed_everything(1) |
| 409 | + dataloader = lite.setup_dataloaders(dataloader) |
| 410 | + assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1] |
| 411 | + |
| 412 | + |
387 | 413 | @mock.patch.dict(os.environ, {}, clear=True) |
388 | 414 | def test_seed_everything(): |
389 | 415 | """Test that seed everything is static and sets the worker init function on the dataloader.""" |
|
0 commit comments