Skip to content

Commit d502153

Browse files
Queuecumbercarmocca
authored andcommitted
Remove rank_zero_only on DataModule prepare_data (#7945)
Signed-off-by: Max Ehrlich <[email protected]>
1 parent d5c8c62 commit d502153

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

pytorch_lightning/core/datamodule.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from torch.utils.data import DataLoader, Dataset, IterableDataset
2121

2222
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
23-
from pytorch_lightning.utilities import rank_zero_only
2423
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
2524

2625

@@ -329,7 +328,7 @@ def test_dataloader():
329328
def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
330329
obj = super().__new__(cls)
331330
# track `DataHooks` calls and run `prepare_data` only on rank zero
332-
obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
331+
obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
333332
obj.setup = cls._track_data_hook_calls(obj, obj.setup)
334333
obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
335334
return obj

tests/core/test_datamodules.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
3535
def test_can_prepare_data(local_rank, node_rank):
3636

37+
model = BoringModel()
3738
dm = BoringDataModule()
3839
trainer = Trainer()
3940
trainer.datamodule = dm
@@ -43,30 +44,54 @@ def test_can_prepare_data(local_rank, node_rank):
4344
# local rank = 0 (True)
4445
trainer.prepare_data_per_node = True
4546

47+
dm.random_full = None
48+
dm._has_prepared_data = False
4649
local_rank.return_value = 0
4750
assert trainer.local_rank == 0
4851
assert trainer.data_connector.can_prepare_data()
4952

53+
trainer.data_connector.prepare_data(model)
54+
assert dm.random_full is not None
55+
5056
# local rank = 1 (False)
57+
dm.random_full = None
58+
dm._has_prepared_data = False
5159
local_rank.return_value = 1
5260
assert trainer.local_rank == 1
5361
assert not trainer.data_connector.can_prepare_data()
5462

63+
trainer.data_connector.prepare_data(model)
64+
assert dm.random_full is None
65+
5566
# prepare_data_per_node = False (prepare across all nodes)
5667
# global rank = 0 (True)
68+
dm.random_full = None
69+
dm._has_prepared_data = False
5770
trainer.prepare_data_per_node = False
5871
node_rank.return_value = 0
5972
local_rank.return_value = 0
6073
assert trainer.data_connector.can_prepare_data()
6174

75+
trainer.data_connector.prepare_data(model)
76+
assert dm.random_full is not None
77+
6278
# global rank = 1 (False)
79+
dm.random_full = None
80+
dm._has_prepared_data = False
6381
node_rank.return_value = 1
6482
local_rank.return_value = 0
6583
assert not trainer.data_connector.can_prepare_data()
84+
85+
trainer.data_connector.prepare_data(model)
86+
assert dm.random_full is None
87+
6688
node_rank.return_value = 0
6789
local_rank.return_value = 1
6890
assert not trainer.data_connector.can_prepare_data()
6991

92+
trainer.data_connector.prepare_data(model)
93+
assert dm.random_full is None
94+
7095
# 2 dm
7196
# prepar per node = True
7297
# local rank = 0 (True)

0 commit comments

Comments
 (0)