Merged

36 commits
3aa2f0e
add prepare_data_per_node property to datahooks
ninginthecloud Aug 17, 2021
cfad7ad
add __init__()
ninginthecloud Aug 17, 2021
9b246b5
update prepare_data_per_node in data_connector
ninginthecloud Aug 17, 2021
e8da75e
mark 'prepare_data_per_node' optional in trainer
ninginthecloud Aug 17, 2021
f3a30f8
update failed test_datamodules
ninginthecloud Aug 17, 2021
3244cd2
move deprecation warning in data_connector
ninginthecloud Aug 18, 2021
fac2d0b
update test_remove_1-7.py
ninginthecloud Aug 18, 2021
c53fa4a
update when datamodule is not defined
ninginthecloud Aug 18, 2021
e5107a4
update hook docs
ninginthecloud Aug 18, 2021
f9bd342
update CHANGELOG.md
ninginthecloud Aug 18, 2021
f9706b2
refactor `prepare_data()` in data_connector
ninginthecloud Aug 18, 2021
78a81c8
use `has_prepared_data` property in datamodule
ninginthecloud Aug 18, 2021
7db98ce
minor - comment update
ninginthecloud Aug 18, 2021
8e77b04
add MisconfigurationException
ninginthecloud Aug 18, 2021
153d16c
update MisconfigurationException
ninginthecloud Aug 18, 2021
fe6d2fb
add unit test for MisconfigurationException
ninginthecloud Aug 18, 2021
3f140b6
move item to CHANGELOG.md deprecation session
ninginthecloud Aug 19, 2021
42f500f
update BoringModel and BoringDataModule
ninginthecloud Aug 19, 2021
fcbff81
set prepare_data_per_node as unused property
ninginthecloud Aug 19, 2021
61a3b71
update __jit_unused_properties__ in lightning
ninginthecloud Aug 19, 2021
056ca6a
fix error
ninginthecloud Aug 19, 2021
25f1c45
Update comments in pytorch_lightning/trainer/trainer.py
ninginthecloud Aug 19, 2021
5da20d8
Update CHANGELOG.md
ninginthecloud Aug 19, 2021
d89f947
Update pytorch_lightning/trainer/connectors/data_connector.py
ninginthecloud Aug 19, 2021
630796f
Update pytorch_lightning/trainer/connectors/data_connector.py
ninginthecloud Aug 19, 2021
ca37a83
Update pytorch_lightning/trainer/connectors/data_connector.py
ninginthecloud Aug 19, 2021
7b52795
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2021
8ec10f6
remove `prepare_data_per_node` from boring_model
ninginthecloud Aug 19, 2021
5a81b85
update test_datamodules to avoid noop
ninginthecloud Aug 19, 2021
b66f371
Update comment pytorch_lightning/core/hooks.py
ninginthecloud Aug 20, 2021
1693e63
Update comment pytorch_lightning/trainer/trainer.py
ninginthecloud Aug 20, 2021
13e64e9
Update pytorch_lightning/trainer/connectors/data_connector.py
ninginthecloud Aug 20, 2021
ca6c829
update property in DataHooks
ninginthecloud Aug 21, 2021
e5442b7
add call_hook
ninginthecloud Aug 21, 2021
a72c929
fix test `test_datamodules.py`
ninginthecloud Aug 21, 2021
702cdc8
update prepare_data_per_node as attributes
ninginthecloud Aug 23, 2021
5 changes: 1 addition & 4 deletions CHANGELOG.md
@@ -111,10 +111,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))


-


-
- Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` [#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958)


-
14 changes: 14 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -372,6 +372,16 @@ def configure_sharded_model(self) -> None:
class DataHooks:
"""Hooks to be used for data related stuff."""

def __init__(self) -> None:
"""
Attributes:
prepare_data_per_node:
If True, each LOCAL_RANK=0 will call prepare data.
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
"""
super().__init__()
self.prepare_data_per_node: bool = True

def prepare_data(self) -> None:
"""
Use this to download and prepare data.
@@ -405,6 +415,10 @@ def prepare_data(self):
# call on GLOBAL_RANK=0 (great for shared file systems)
Trainer(prepare_data_per_node=False)

Note:
Setting ``prepare_data_per_node`` with the trainer flag is deprecated and will be removed in v1.7.0.
Please set ``prepare_data_per_node`` in LightningDataModule or LightningModule directly instead.

This is called before requesting the dataloaders:

.. code-block:: python
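For illustration, a minimal sketch (not part of this diff) of the usage this change enables; `MyDataModule` is a hypothetical user-defined datamodule:

import pytorch_lightning as pl


class MyDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        # With this PR the flag is an attribute of the DataModule (or LightningModule)
        # rather than a Trainer argument.
        self.prepare_data_per_node = False  # only NODE_RANK=0, LOCAL_RANK=0 downloads

    def prepare_data(self):
        # download / tokenize / write to disk here; do not assign state
        ...


dm = MyDataModule()
trainer = pl.Trainer()  # no `prepare_data_per_node` argument needed anymore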
54 changes: 40 additions & 14 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -33,9 +33,15 @@ def on_trainer_init(
check_val_every_n_epoch: int,
reload_dataloaders_every_n_epochs: int,
reload_dataloaders_every_epoch: bool,
prepare_data_per_node: bool,
prepare_data_per_node: Optional[bool] = None,
) -> None:
self.trainer.datamodule = None

if prepare_data_per_node is not None:
rank_zero_deprecation(
"Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0! "
"Please set `prepare_data_per_node` in LightningDataModule or LightningModule directly instead. "
)
self.trainer.prepare_data_per_node = prepare_data_per_node

if not isinstance(check_val_every_n_epoch, int):
@@ -70,20 +76,40 @@ def get_profiled_train_dataloader(self, train_dataloader) -> Iterable:
def prepare_data(self) -> None:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
if self.can_prepare_data():
if self.trainer.datamodule is not None:
local_rank_zero = self.trainer.local_rank == 0
global_rank_zero = self.trainer.local_rank == 0 and self.trainer.node_rank == 0

datamodule = self.trainer.datamodule
lightning_module = self.trainer.lightning_module
# handle datamodule prepare data:
# check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data
if datamodule is not None and not datamodule.has_prepared_data:
dm_prepare_data_per_node = datamodule.prepare_data_per_node
dm_eq_prepare_data = datamodule.prepare_data_per_node == self.trainer.prepare_data_per_node
if self.trainer.prepare_data_per_node is not None and not dm_eq_prepare_data:
raise MisconfigurationException(
"Inconsistent settings found for `prepare_data_per_node`."
f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`"
f" and `DataModule.prepare_data_per_node={datamodule.prepare_data_per_node}`."
" Move `prepare_data_per_node` setting to DataModule property."
)
if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero):
self.trainer.datamodule.prepare_data()
self.trainer.call_hook("prepare_data")
self.trainer._is_data_prepared = True

def can_prepare_data(self):
should_call_dm_prepare_data = True
if self.trainer.datamodule is not None and is_overridden("prepare_data", self.trainer.datamodule):
should_call_dm_prepare_data = not self.trainer.datamodule._has_prepared_data

if self.trainer.prepare_data_per_node:
return self.trainer.local_rank == 0 and should_call_dm_prepare_data
return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data
# handle lightning module prepare data:
# check for prepare_data_per_node before calling lightning_module.prepare_data
if lightning_module is not None:
lm_prepare_data_per_node = lightning_module.prepare_data_per_node
lm_eq_prepare_data = lightning_module.prepare_data_per_node == self.trainer.prepare_data_per_node
if (self.trainer.prepare_data_per_node is not None) and not lm_eq_prepare_data:
raise MisconfigurationException(
"Inconsistent settings found for `prepare_data_per_node`."
f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`"
f" and `LightningModule.prepare_data_per_node={lightning_module.prepare_data_per_node}`."
" Move `prepare_data_per_node` setting to LightningModule property."
)
if (lm_prepare_data_per_node and local_rank_zero) or (not lm_prepare_data_per_node and global_rank_zero):
self.trainer.call_hook("prepare_data")
self.trainer._is_data_prepared = True

def attach_data(
self,
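To illustrate the new guard, a sketch mirroring the unit test added below (`test_inconsistent_prepare_data_per_node`); combining the deprecated Trainer flag with a conflicting module attribute raises `MisconfigurationException`:

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringDataModule, BoringModel

trainer = Trainer(prepare_data_per_node=False)  # deprecated flag, emits a deprecation warning
trainer.model = BoringModel()
trainer.datamodule = BoringDataModule()  # module attribute defaults to True -> inconsistent

try:
    trainer.data_connector.prepare_data()
except MisconfigurationException as err:
    print(err)  # Inconsistent settings found for `prepare_data_per_node`. ...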
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -151,7 +151,7 @@ def __init__(
replace_sampler_ddp: bool = True,
terminate_on_nan: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: bool = True,
prepare_data_per_node: Optional[bool] = None,
plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
amp_backend: str = "native",
amp_level: str = "O2",
@@ -243,6 +243,10 @@ def __init__(
prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

.. deprecated:: v1.5
Deprecated in v1.5.0 and will be removed in v1.7.0
Please set ``prepare_data_per_node`` in LightningDataModule or LightningModule directly instead.

process_position: orders the progress bar when running multiple models on same machine.

progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
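A short migration sketch for the deprecation above (using the `BoringModel` test helper; any LightningModule or LightningDataModule works the same way):

from pytorch_lightning import Trainer
from tests.helpers import BoringModel

# Before (deprecated in v1.5, removal planned for v1.7): configure on the Trainer.
trainer = Trainer(prepare_data_per_node=False)

# After: set the attribute on the LightningModule (or LightningDataModule) instead.
model = BoringModel()
model.prepare_data_per_node = False
trainer = Trainer()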
48 changes: 25 additions & 23 deletions tests/core/test_datamodules.py
@@ -23,6 +23,7 @@
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.datamodules import ClassifDataModule
@@ -41,13 +42,10 @@ def test_can_prepare_data(local_rank, node_rank):
# 1 no DM
# prepare_data_per_node = True
# local rank = 0 (True)
trainer.prepare_data_per_node = True

dm.random_full = None
dm._has_prepared_data = False
local_rank.return_value = 0
assert trainer.local_rank == 0
assert trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data()
assert dm.random_full is not None
@@ -57,7 +55,6 @@
dm._has_prepared_data = False
local_rank.return_value = 1
assert trainer.local_rank == 1
assert not trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data()
assert dm.random_full is None
@@ -66,10 +63,9 @@
# global rank = 0 (True)
dm.random_full = None
dm._has_prepared_data = False
trainer.prepare_data_per_node = False
dm.prepare_data_per_node = False
node_rank.return_value = 0
local_rank.return_value = 0
assert trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data()
assert dm.random_full is not None
@@ -79,39 +75,35 @@
dm._has_prepared_data = False
node_rank.return_value = 1
local_rank.return_value = 0
assert not trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data()
assert dm.random_full is None

node_rank.return_value = 0
local_rank.return_value = 1
assert not trainer.data_connector.can_prepare_data()

trainer.data_connector.prepare_data()
assert dm.random_full is None

# 2 dm
# prepare_data_per_node = True
# local rank = 0 (True)
trainer.prepare_data_per_node = True
dm.prepare_data_per_node = True
local_rank.return_value = 0

# is_overridden prepare data = True
# has been called
# False
dm._has_prepared_data = True
assert not trainer.data_connector.can_prepare_data()

# has not been called
# True
dm._has_prepared_data = False
assert trainer.data_connector.can_prepare_data()
with mock.patch.object(trainer.datamodule, "prepare_data") as dm_mock:
# is_overridden prepare data = True
# has been called
# False
dm._has_prepared_data = True
trainer.data_connector.prepare_data()
dm_mock.assert_not_called()

# is_overridden prepare data = False
# True
dm.prepare_data = None
assert trainer.data_connector.can_prepare_data()
# has not been called
# True
dm._has_prepared_data = False
trainer.data_connector.prepare_data()
dm_mock.assert_called_once()


def test_hooks_no_recursion_error():
@@ -539,3 +531,13 @@ def __init__(self, arg0, arg1, kwarg0=None):
def test_simple_hyperparameters_saving():
data = DataModuleWithHparams(10, "foo", kwarg0="bar")
assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})


def test_inconsistent_prepare_data_per_node(tmpdir):
with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."):
model = BoringModel()
dm = BoringDataModule()
trainer = Trainer(prepare_data_per_node=False)
trainer.model = model
trainer.datamodule = dm
trainer.data_connector.prepare_data()
9 changes: 8 additions & 1 deletion tests/deprecated_api/test_remove_1-7.py
@@ -15,7 +15,7 @@

import pytest

from pytorch_lightning import LightningDataModule
from pytorch_lightning import LightningDataModule, Trainer
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
@@ -80,3 +80,10 @@ def test_v1_7_0_datamodule_dims_property(tmpdir):
_ = dm.dims
with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
_ = LightningDataModule(dims=(1, 1, 1))


def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
with pytest.deprecated_call(
match="Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0!"
):
_ = Trainer(prepare_data_per_node=False)
1 change: 1 addition & 0 deletions tests/plugins/test_deepspeed_plugin.py
@@ -470,6 +470,7 @@ def __init__(self, lr: float = 0.01, num_blocks: int = 5):
super().__init__()
self.lr = lr
self.num_blocks = num_blocks
self.prepare_data_per_node = True

self.train_acc = Accuracy()
self.valid_acc = Accuracy()