@@ -13,12 +13,13 @@
 # limitations under the License.
 import os
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import Mock, patch

+import numpy
 import pytest
 import torch
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import IterableDataset, Subset
+from torch.utils.data.dataset import Dataset, IterableDataset, Subset
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import SequentialSampler

@@ -635,6 +636,109 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
     trainer.fit(model, train_dataloader=train_multi_dl, val_dataloaders=val_multi_dl)


+class NumpyRandomDataset(Dataset):
+    # this dataset uses numpy instead of torch to produce random numbers
+    size = 16
+
+    def __getitem__(self, index):
+        return numpy.random.randint(0, 100, 3)
+
+    def __len__(self):
+        return self.size
+
+
+def _user_worker_init_fn(_):
+    pass
+
+
+def test_missing_worker_init_fn():
+    """ Test that naive worker seed initialization leads to undesired random state in subprocesses. """
+    dataset = NumpyRandomDataset()
+
+    seed_everything(0)
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
+    batches0 = torch.cat([batch for batch in dataloader])
+
+    seed_everything(0)
+    dataloader = DataLoader(dataset, batch_size=2, num_workers=2, shuffle=False)
+    batches1 = torch.cat([batch for batch in dataloader])
+
+    is_duplicated = len(torch.unique(batches1, dim=0)) < len(dataset)
+    is_deterministic = torch.eq(batches0, batches1).all()
+
+    # depending on the OS, we either have
+    # 1) the same seed in all worker processes, producing duplicate samples / augmentations, or
+    # 2) different seeds in each worker process, but they are not derived from the seed of the main process
+    assert not is_deterministic or is_duplicated
+
+
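The test above demonstrates the problem this PR targets: NumPy's global random state is not re-seeded per DataLoader worker. The standard remedy is a `worker_init_fn` that derives a NumPy seed from the per-worker torch seed. The sketch below is illustrative only, not the function Lightning attaches; `example_worker_init_fn` is a hypothetical name:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader


def example_worker_init_fn(worker_id: int) -> None:
    # Inside a worker process, torch.initial_seed() already differs per worker
    # (base seed + worker id), but numpy does not inherit it. Derive a numpy
    # seed from it, truncated to numpy's 32-bit seed range.
    np.random.seed(torch.initial_seed() % 2**32)


# usage: each worker now draws different, yet reproducible, random numbers
# loader = DataLoader(dataset, batch_size=2, num_workers=2, worker_init_fn=example_worker_init_fn)
```
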
+def test_auto_add_worker_init_fn():
+    """ Test Trainer adds a default worker_init_fn to the dataloader when seed_everything() is used. """
+    dataset = Mock()
+    dataloader = DataLoader(dataset)
+    trainer = Trainer()
+
+    # without pl.seed_everything()
+    trainer.auto_add_worker_init_fn(dataloader)
+    assert dataloader.worker_init_fn is None
+
+    # with forcefully avoiding it
+    seed_everything(0, workers=False)
+    trainer.auto_add_worker_init_fn(dataloader)
+    assert dataloader.worker_init_fn is None
+
+    # when user already has a worker_init_fn
+    user_function = _user_worker_init_fn
+    dataloader.worker_init_fn = user_function
+    trainer.auto_add_worker_init_fn(dataloader)
+    assert dataloader.worker_init_fn is user_function
+    dataloader.worker_init_fn = None
+
+    # main use case
+    seed_everything(0, workers=True)
+    trainer.auto_add_worker_init_fn(dataloader)
+    assert dataloader.worker_init_fn is not None
+
+
+class MultiProcessModel(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.batches_seen = []
+
+    def training_step(self, batch, batch_idx):
+        self.batches_seen.append(batch)
+
+    def training_epoch_end(self, outputs):
+        world_size = 2
+        num_samples = NumpyRandomDataset.size
+        all_batches = torch.cat(self.batches_seen)
+        all_batches = self.all_gather(all_batches)
+        assert all_batches.shape[0] == world_size
+        all_batches = all_batches.view(-1, 3)
+        assert len(torch.unique(all_batches, dim=0)) == num_samples
+
+
+@RunIf(min_gpus=2)
+def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch):
+    """ Test that the lightning worker_init_fn takes care of dataloaders in multi-gpu/multi-node training. """
+    dataset = NumpyRandomDataset()
+    num_workers = 2
+    batch_size = 2
+
+    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
+    seed_everything(0, workers=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        gpus=2,
+        accelerator="ddp_spawn",
+    )
+    model = MultiProcessModel()
+    model.val_dataloader = None
+    trainer.fit(model, train_dataloader=dataloader)
+
+
 def test_warning_with_iterable_dataset_and_len(tmpdir):
     """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
     model = BoringModel()
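Taken together, the user-facing flow these tests exercise is: opt in via `seed_everything(..., workers=True)` and let the Trainer attach its seeding `worker_init_fn` to any dataloader that does not already define one. A minimal usage sketch under that assumption, reusing `NumpyRandomDataset` and `BoringModel` from the tests above:

```python
from pytorch_lightning import Trainer, seed_everything
from torch.utils.data import DataLoader

# workers=True opts in to Lightning-managed worker seeding; without it the
# DataLoader workers keep the default (potentially duplicated) numpy state
seed_everything(0, workers=True)

train_loader = DataLoader(NumpyRandomDataset(), batch_size=2, num_workers=2)

trainer = Trainer(max_epochs=1)
trainer.fit(BoringModel(), train_dataloader=train_loader)
```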