diff --git a/CHANGELOG.md b/CHANGELOG.md
index c90af8b9c97cd..57cb9f3f04a40 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -111,10 +111,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851))
 
 
--
-
-
--
+- Deprecated `prepare_data_per_node` flag on Trainer and set it as a property of `DataHooks`, accessible in the `LightningModule` and `LightningDataModule` [#8958](https://github.com/PyTorchLightning/pytorch-lightning/pull/8958)
 
 
 -
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 3b73ae418ffe2..7ff21885343a9 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -372,6 +372,16 @@ def configure_sharded_model(self) -> None:
 class DataHooks:
     """Hooks to be used for data related stuff."""
 
+    def __init__(self) -> None:
+        """
+        Attributes:
+            prepare_data_per_node:
+                If True, each LOCAL_RANK=0 will call prepare data.
+                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
+        """
+        super().__init__()
+        self.prepare_data_per_node: bool = True
+
     def prepare_data(self) -> None:
         """
         Use this to download and prepare data.
@@ -405,6 +415,10 @@ def prepare_data(self):
             # call on GLOBAL_RANK=0 (great for shared file systems)
             Trainer(prepare_data_per_node=False)
 
+        Note:
+            Setting ``prepare_data_per_node`` with the trainer flag is deprecated and will be removed in v1.7.0.
+            Please set ``prepare_data_per_node`` in LightningDataModule or LightningModule directly instead.
+
         This is called before requesting the dataloaders:
 
         .. code-block:: python
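For reference, a minimal sketch of the migration this introduces: `prepare_data_per_node` now defaults to `True` on `DataHooks` and is set directly on a `LightningModule` or `LightningDataModule` rather than through the Trainer flag. The `MyDataModule` class below is illustrative only and is not part of this patch.

.. code-block:: python

    from pytorch_lightning import LightningDataModule, Trainer


    class MyDataModule(LightningDataModule):  # hypothetical example, not from this diff
        def __init__(self):
            super().__init__()
            # new style: set the property on the module itself;
            # False means only NODE_RANK=0, LOCAL_RANK=0 prepares data (useful for shared filesystems)
            self.prepare_data_per_node = False

        def prepare_data(self):
            ...  # download / tokenize / write to disk here


    # old style, deprecated in v1.5 and scheduled for removal in v1.7:
    # trainer = Trainer(prepare_data_per_node=False)
    trainer = Trainer()  # the Trainer no longer needs the flag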
" + ) self.trainer.prepare_data_per_node = prepare_data_per_node if not isinstance(check_val_every_n_epoch, int): @@ -70,20 +76,40 @@ def get_profiled_train_dataloader(self, train_dataloader) -> Iterable: def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 - if self.can_prepare_data(): - if self.trainer.datamodule is not None: + local_rank_zero = self.trainer.local_rank == 0 + global_rank_zero = self.trainer.local_rank == 0 and self.trainer.node_rank == 0 + + datamodule = self.trainer.datamodule + lightning_module = self.trainer.lightning_module + # handle datamodule prepare data: + # check for prepare_data_per_node & datamodule lifecycle properties before calling datamodule.prepare_data + if datamodule is not None and not datamodule.has_prepared_data: + dm_prepare_data_per_node = datamodule.prepare_data_per_node + dm_eq_prepare_data = datamodule.prepare_data_per_node == self.trainer.prepare_data_per_node + if self.trainer.prepare_data_per_node is not None and not dm_eq_prepare_data: + raise MisconfigurationException( + "Inconsistent settings found for `prepare_data_per_node`." + f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`" + f" and `DataModule.prepare_data_per_node={datamodule.prepare_data_per_node}`." + " Move `prepare_data_per_node` setting to DataModule property." + ) + if (dm_prepare_data_per_node and local_rank_zero) or (not dm_prepare_data_per_node and global_rank_zero): self.trainer.datamodule.prepare_data() - self.trainer.call_hook("prepare_data") - self.trainer._is_data_prepared = True - - def can_prepare_data(self): - should_call_dm_prepare_data = True - if self.trainer.datamodule is not None and is_overridden("prepare_data", self.trainer.datamodule): - should_call_dm_prepare_data = not self.trainer.datamodule._has_prepared_data - - if self.trainer.prepare_data_per_node: - return self.trainer.local_rank == 0 and should_call_dm_prepare_data - return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data + # handle lightning module prepare data: + # check for prepare_data_per_node before calling lightning_module.prepare_data + if lightning_module is not None: + lm_prepare_data_per_node = lightning_module.prepare_data_per_node + lm_eq_prepare_data = lightning_module.prepare_data_per_node == self.trainer.prepare_data_per_node + if (self.trainer.prepare_data_per_node is not None) and not lm_eq_prepare_data: + raise MisconfigurationException( + "Inconsistent settings found for `prepare_data_per_node`." + f" Value was set with both `Trainer(prepare_data_per_node={self.trainer.prepare_data_per_node}.)`" + f" and `LightningModule.prepare_data_per_node={lightning_module.prepare_data_per_node}`." + " Move `prepare_data_per_node` setting to LightningModule property." 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index becf4d4cf2c4d..d3446bc1d4e7b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -151,7 +151,7 @@ def __init__(
         replace_sampler_ddp: bool = True,
         terminate_on_nan: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
-        prepare_data_per_node: bool = True,
+        prepare_data_per_node: Optional[bool] = None,
         plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
         amp_backend: str = "native",
         amp_level: str = "O2",
@@ -243,6 +243,10 @@ def __init__(
             prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                 Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
 
+                .. deprecated:: v1.5
+                    Deprecated in v1.5.0 and will be removed in v1.7.0
+                    Please set ``prepare_data_per_node`` in LightningDataModule or LightningModule directly instead.
+
             process_position: orders the progress bar when running multiple models on same machine.
 
             progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 5c33a2f68acf0..3bfe3aaa6cf80 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -23,6 +23,7 @@
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.datamodules import ClassifDataModule
@@ -41,13 +42,10 @@ def test_can_prepare_data(local_rank, node_rank):
     # 1 no DM
     # prepare_data_per_node = True
     # local rank = 0 (True)
-    trainer.prepare_data_per_node = True
-
     dm.random_full = None
     dm._has_prepared_data = False
     local_rank.return_value = 0
     assert trainer.local_rank == 0
-    assert trainer.data_connector.can_prepare_data()
 
     trainer.data_connector.prepare_data()
     assert dm.random_full is not None
@@ -57,7 +55,6 @@ def test_can_prepare_data(local_rank, node_rank):
     dm._has_prepared_data = False
     local_rank.return_value = 1
     assert trainer.local_rank == 1
-    assert not trainer.data_connector.can_prepare_data()
 
     trainer.data_connector.prepare_data()
     assert dm.random_full is None
@@ -66,10 +63,9 @@ def test_can_prepare_data(local_rank, node_rank):
     # global rank = 0 (True)
     dm.random_full = None
     dm._has_prepared_data = False
-    trainer.prepare_data_per_node = False
+    dm.prepare_data_per_node = False
     node_rank.return_value = 0
     local_rank.return_value = 0
-    assert trainer.data_connector.can_prepare_data()
 
     trainer.data_connector.prepare_data()
     assert dm.random_full is not None
@@ -79,14 +75,12 @@ def test_can_prepare_data(local_rank, node_rank):
     dm._has_prepared_data = False
     node_rank.return_value = 1
     local_rank.return_value = 0
-    assert not trainer.data_connector.can_prepare_data()
 
     trainer.data_connector.prepare_data()
     assert dm.random_full is None
 
     node_rank.return_value = 0
     local_rank.return_value = 1
-    assert not trainer.data_connector.can_prepare_data()
 
     trainer.data_connector.prepare_data()
     assert dm.random_full is None
@@ -94,24 +88,22 @@ def test_can_prepare_data(local_rank, node_rank):
     # 2 dm
     # prepar per node = True
     # local rank = 0 (True)
-    trainer.prepare_data_per_node = True
+    dm.prepare_data_per_node = True
     local_rank.return_value = 0
 
-    # is_overridden prepare data = True
-    # has been called
-    # False
-    dm._has_prepared_data = True
-    assert not trainer.data_connector.can_prepare_data()
-
-    # has not been called
-    # True
-    dm._has_prepared_data = False
-    assert trainer.data_connector.can_prepare_data()
+    with mock.patch.object(trainer.datamodule, "prepare_data") as dm_mock:
+        # is_overridden prepare data = True
+        # has been called
+        # False
+        dm._has_prepared_data = True
+        trainer.data_connector.prepare_data()
+        dm_mock.assert_not_called()
 
-    # is_overridden prepare data = False
-    # True
-    dm.prepare_data = None
-    assert trainer.data_connector.can_prepare_data()
+        # has not been called
+        # True
+        dm._has_prepared_data = False
+        trainer.data_connector.prepare_data()
+        dm_mock.assert_called_once()
 
 
 def test_hooks_no_recursion_error():
@@ -539,3 +531,13 @@ def __init__(self, arg0, arg1, kwarg0=None):
 def test_simple_hyperparameters_saving():
     data = DataModuleWithHparams(10, "foo", kwarg0="bar")
     assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})
+
+
+def test_inconsistent_prepare_data_per_node(tmpdir):
+    with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."):
+        model = BoringModel()
+        dm = BoringDataModule()
+        trainer = Trainer(prepare_data_per_node=False)
+        trainer.model = model
+        trainer.datamodule = dm
+        trainer.data_connector.prepare_data()
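As exercised by `test_inconsistent_prepare_data_per_node` above, passing the deprecated Trainer flag while the module property holds a different value now raises `MisconfigurationException`. A sketch of the failing configuration and its fix, assuming the test helper `BoringDataModule` and direct assignment to `trainer.datamodule` as used in the test; this is illustrative, not part of the patch.

.. code-block:: python

    from pytorch_lightning import Trainer
    from tests.helpers import BoringDataModule

    dm = BoringDataModule()  # DataHooks default: prepare_data_per_node = True
    trainer = Trainer(prepare_data_per_node=False)  # deprecated flag disagrees with the property
    trainer.datamodule = dm
    # trainer.data_connector.prepare_data()  # would raise MisconfigurationException

    # fix: drop the Trainer flag and set the property on the module instead
    dm.prepare_data_per_node = False
    trainer = Trainer()
    trainer.datamodule = dm
    trainer.data_connector.prepare_data()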
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index d836f1427a110..7581bf2b0c142 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -15,7 +15,7 @@
 
 import pytest
 
-from pytorch_lightning import LightningDataModule
+from pytorch_lightning import LightningDataModule, Trainer
 from tests.deprecated_api import _soft_unimport_module
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import MNISTDataModule
@@ -80,3 +80,10 @@ def test_v1_7_0_datamodule_dims_property(tmpdir):
         _ = dm.dims
     with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"):
         _ = LightningDataModule(dims=(1, 1, 1))
+
+
+def test_v1_7_0_trainer_prepare_data_per_node(tmpdir):
+    with pytest.deprecated_call(
+        match="Setting `prepare_data_per_node` with the trainer flag is deprecated and will be removed in v1.7.0!"
+    ):
+        _ = Trainer(prepare_data_per_node=False)
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index f0c1d7d49b586..a5e4e1d189aaa 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -470,6 +470,7 @@ def __init__(self, lr: float = 0.01, num_blocks: int = 5):
         super().__init__()
         self.lr = lr
         self.num_blocks = num_blocks
+        self.prepare_data_per_node = True
 
         self.train_acc = Accuracy()
         self.valid_acc = Accuracy()