2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -206,6 +206,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))

- Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accommodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))

- Fixed `apply_to_collection` to work on custom collections ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
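Context for the first entry above: `LightningDataModule.__new__` used to wrap `prepare_data` in `rank_zero_only`, which silently drops the call on every process whose global rank is non-zero, so `prepare_data_per_node=True` could never run the hook on other nodes. A minimal sketch of that failure mode, using a hypothetical standalone version of the decorator (the real one reads the rank from shared state rather than taking it as a parameter):

```python
from functools import wraps

# Hypothetical stand-in (not the library's exact code): run `fn` only when
# the process's *global* rank is 0; every other process gets a silent no-op.
def rank_zero_only(fn, global_rank: int):
    @wraps(fn)
    def wrapped(*args, **kwargs):
        if global_rank == 0:
            return fn(*args, **kwargs)
    return wrapped

def prepare_data():
    print("downloading dataset...")

# On a second node every process has a non-zero global rank, so the wrapped
# hook never runs there, even when `prepare_data_per_node=True` asks each
# node to prepare its own copy of the data.
node1_prepare_data = rank_zero_only(prepare_data, global_rank=8)
node1_prepare_data()  # no output
```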
4 changes: 2 additions & 2 deletions pytorch_lightning/core/datamodule.py
@@ -21,7 +21,7 @@

from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_deprecation


class LightningDataModule(CheckpointHooks, DataHooks):
@@ -381,7 +381,7 @@ def test_dataloader():
def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
obj = super().__new__(cls)
# track `DataHooks` calls (rank-zero gating for `prepare_data` now happens in the trainer's data connector)
obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
obj.prepare_data = cls._track_data_hook_calls(obj, obj.prepare_data)
obj.setup = cls._track_data_hook_calls(obj, obj.setup)
obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
return obj
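With the wrapper removed, deciding where `prepare_data` runs is left to the trainer's data connector, whose `can_prepare_data` check the updated tests below exercise. A minimal sketch of that gating, assuming only the `prepare_data_per_node` / `local_rank` / `node_rank` semantics the test asserts (standalone code, not the library's exact implementation):

```python
def should_prepare_data(prepare_data_per_node: bool, local_rank: int, node_rank: int) -> bool:
    """Decide whether the current process should run `prepare_data`."""
    if prepare_data_per_node:
        # Each node prepares its own copy: only the local-rank-0 process per node runs it.
        return local_rank == 0
    # Shared filesystem: only the single global-rank-0 process (node 0, local rank 0) runs it.
    return node_rank == 0 and local_rank == 0

# Mirrors the assertions in test_can_prepare_data below:
assert should_prepare_data(True, local_rank=0, node_rank=0)
assert not should_prepare_data(True, local_rank=1, node_rank=0)
assert should_prepare_data(False, local_rank=0, node_rank=0)
assert not should_prepare_data(False, local_rank=0, node_rank=1)
assert not should_prepare_data(False, local_rank=1, node_rank=0)
```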
25 changes: 25 additions & 0 deletions tests/core/test_datamodules.py
@@ -34,6 +34,7 @@
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):

model = BoringModel()
dm = BoringDataModule()
trainer = Trainer()
trainer.datamodule = dm
@@ -43,30 +44,54 @@ def test_can_prepare_data(local_rank, node_rank):
# local rank = 0 (True)
trainer.prepare_data_per_node = True

dm.random_full = None
dm._has_prepared_data = False
local_rank.return_value = 0
assert trainer.local_rank == 0
assert trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is not None

# local rank = 1 (False)
dm.random_full = None
dm._has_prepared_data = False
local_rank.return_value = 1
assert trainer.local_rank == 1
assert not trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is None

# prepare_data_per_node = False (prepare across all nodes)
# global rank = 0 (True)
dm.random_full = None
dm._has_prepared_data = False
trainer.prepare_data_per_node = False
node_rank.return_value = 0
local_rank.return_value = 0
assert trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is not None

# global rank = 1 (False)
dm.random_full = None
dm._has_prepared_data = False
node_rank.return_value = 1
local_rank.return_value = 0
assert not trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is None

node_rank.return_value = 0
local_rank.return_value = 1
assert not trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data(model)
assert dm.random_full is None

# 2 dm
# prepare_data_per_node = True
# local rank = 0 (True)
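A note on the test setup above: `Trainer.local_rank` and `Trainer.node_rank` are properties, which is why the test patches them with `new_callable=PropertyMock`. A self-contained sketch of that standard `unittest.mock` pattern, with a hypothetical `Trainer` stand-in:

```python
from unittest import mock
from unittest.mock import PropertyMock

class Trainer:
    """Hypothetical stand-in for the real Trainer."""

    @property
    def local_rank(self) -> int:
        return 0

# Patching the property on the class lets each test dial in whatever rank it needs.
with mock.patch(f"{__name__}.Trainer.local_rank", new_callable=PropertyMock) as local_rank:
    local_rank.return_value = 1
    assert Trainer().local_rank == 1  # the instance now reports the mocked rank
```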