From e530fb5f0f3f578aeb93dacbea9ac7a208479086 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 22 May 2021 15:47:31 -0700 Subject: [PATCH 1/7] Deprecate DataModule lifecycle properties --- pytorch_lightning/core/datamodule.py | 64 +++++++++++++++++++++++++ tests/deprecated_api/test_remove_1-6.py | 24 +++++++++- 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 84210e9d7b667..c56cbf9507fa3 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -22,6 +22,7 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types +from pytorch_lightning.utilities.distributed import rank_zero_deprecation class LightningDataModule(CheckpointHooks, DataHooks): @@ -160,7 +161,14 @@ def has_prepared_data(self) -> bool: Returns: bool: True if ``datamodule.prepare_data()`` has been called. False by default. + + .. deprecated::v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_prepared_data` was deprecated in v1.4' + ' and will be removed in v1.6.' + ) return self._has_prepared_data @property @@ -169,7 +177,14 @@ def has_setup_fit(self) -> bool: Returns: bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default. + + .. deprecated::v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_setup_fit` was deprecated in v1.4' + ' and will be removed in v1.6.' + ) return self._has_setup_fit @property @@ -178,7 +193,14 @@ def has_setup_validate(self) -> bool: Returns: bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default. + + .. deprecated::v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_setup_validate` was deprecated in v1.4' + ' and will be removed in v1.6.' + ) return self._has_setup_validate @property @@ -187,7 +209,14 @@ def has_setup_test(self) -> bool: Returns: bool: True if ``datamodule.setup(stage='test')`` has been called. False by default. + + .. deprecated::v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_setup_test` was deprecated in v1.4' + ' and will be removed in v1.6.' + ) return self._has_setup_test @property @@ -196,7 +225,14 @@ def has_setup_predict(self) -> bool: Returns: bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default. + + .. deprecated::v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_setup_predict` was deprecated in v1.4' + ' and will be removed in v1.6.' + ) return self._has_setup_predict @property @@ -205,7 +241,14 @@ def has_teardown_fit(self) -> bool: Returns: bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default. + + .. deprecated::v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_teardown_fit` was deprecated in v1.4' + ' and will be removed in v1.6.' + ) return self._has_teardown_fit @property @@ -214,7 +257,14 @@ def has_teardown_validate(self) -> bool: Returns: bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. + + .. deprecated::v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_teardown_validate` was deprecated in v1.4' + ' and will be removed in v1.6.' + ) return self._has_teardown_validate @property @@ -223,7 +273,14 @@ def has_teardown_test(self) -> bool: Returns: bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. + + .. deprecated::v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_teardown_test` was deprecated in v1.4' + ' and will be removed in v1.6.' + ) return self._has_teardown_test @property @@ -232,7 +289,14 @@ def has_teardown_predict(self) -> bool: Returns: bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. + + .. deprecated::v1.4 + Will be removed in v1.6.0. """ + rank_zero_deprecation( + 'DataModule property `has_teardown_predict` was deprecated in v1.4' + ' and will be removed in v1.6.' + ) return self._has_teardown_predict @classmethod diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 7ca0939fd60d2..8061d8c28a30a 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -16,7 +16,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin -from tests.helpers import BoringModel +from tests.helpers import BoringDataModule, BoringModel def test_v1_6_0_trainer_model_hook_mixin(tmpdir): @@ -86,3 +86,25 @@ def training_step(self, *args): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) with pytest.deprecated_call(match=r"tbptt_pad_token=...\)` is no longer supported"): trainer.fit(TestModel()) + + +def test_v1_6_0_datamodule_lifecycle_properties(tmpdir): + dm = BoringDataModule() + with pytest.deprecated_call(match=r"DataModule property `has_prepared_data` was deprecated in v1.4"): + dm.has_prepared_data + with pytest.deprecated_call(match=r"DataModule property `has_setup_fit` was deprecated in v1.4"): + dm.has_setup_fit + with pytest.deprecated_call(match=r"DataModule property `has_setup_validate` was deprecated in v1.4"): + dm.has_setup_validate + with pytest.deprecated_call(match=r"DataModule property `has_setup_test` was deprecated in v1.4"): + dm.has_setup_test + with pytest.deprecated_call(match=r"DataModule property `has_setup_predict` was deprecated in v1.4"): + dm.has_setup_predict + with pytest.deprecated_call(match=r"DataModule property `has_teardown_fit` was deprecated in v1.4"): + dm.has_teardown_fit + with pytest.deprecated_call(match=r"DataModule property `has_teardown_validate` was deprecated in v1.4"): + dm.has_teardown_validate + with pytest.deprecated_call(match=r"DataModule property `has_teardown_test` was deprecated in v1.4"): + dm.has_teardown_test + with pytest.deprecated_call(match=r"DataModule property `has_teardown_predict` was deprecated in v1.4"): + dm.has_teardown_predict From 9415e441be32aa4544b07c00bcea9e8db4a66dc2 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sat, 22 May 2021 15:53:04 -0700 Subject: [PATCH 2/7] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b2d79ba0fa4a..83bc3838f6602 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -143,6 +143,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `DataModule` properties: `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test`, `has_teardown_predict` ([#7657](https://github.com/PyTorchLightning/pytorch-lightning/pull/7657/)) + + - Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422)) From 660ac8fbb7c88e99bf90adb5d3000eaf3c78d49a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 26 May 2021 11:39:04 +0200 Subject: [PATCH 3/7] Apply suggestions from code review --- pytorch_lightning/core/datamodule.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index c56cbf9507fa3..e5a65e2847bb4 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -166,8 +166,7 @@ def has_prepared_data(self) -> bool: Will be removed in v1.6.0. """ rank_zero_deprecation( - 'DataModule property `has_prepared_data` was deprecated in v1.4' - ' and will be removed in v1.6.' + 'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.' ) return self._has_prepared_data @@ -182,8 +181,7 @@ def has_setup_fit(self) -> bool: Will be removed in v1.6.0. """ rank_zero_deprecation( - 'DataModule property `has_setup_fit` was deprecated in v1.4' - ' and will be removed in v1.6.' + 'DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.' ) return self._has_setup_fit @@ -198,8 +196,7 @@ def has_setup_validate(self) -> bool: Will be removed in v1.6.0. """ rank_zero_deprecation( - 'DataModule property `has_setup_validate` was deprecated in v1.4' - ' and will be removed in v1.6.' + 'DataModule property `has_setup_validate` was deprecated in v1.4 and will be removed in v1.6.' ) return self._has_setup_validate @@ -214,8 +211,7 @@ def has_setup_test(self) -> bool: Will be removed in v1.6.0. """ rank_zero_deprecation( - 'DataModule property `has_setup_test` was deprecated in v1.4' - ' and will be removed in v1.6.' + 'DataModule property `has_setup_test` was deprecated in v1.4 and will be removed in v1.6.' ) return self._has_setup_test @@ -230,8 +226,7 @@ def has_setup_predict(self) -> bool: Will be removed in v1.6.0. """ rank_zero_deprecation( - 'DataModule property `has_setup_predict` was deprecated in v1.4' - ' and will be removed in v1.6.' + 'DataModule property `has_setup_predict` was deprecated in v1.4 and will be removed in v1.6.' ) return self._has_setup_predict @@ -246,8 +241,7 @@ def has_teardown_fit(self) -> bool: Will be removed in v1.6.0. """ rank_zero_deprecation( - 'DataModule property `has_teardown_fit` was deprecated in v1.4' - ' and will be removed in v1.6.' + 'DataModule property `has_teardown_fit` was deprecated in v1.4 and will be removed in v1.6.' ) return self._has_teardown_fit @@ -262,8 +256,7 @@ def has_teardown_validate(self) -> bool: Will be removed in v1.6.0. """ rank_zero_deprecation( - 'DataModule property `has_teardown_validate` was deprecated in v1.4' - ' and will be removed in v1.6.' + 'DataModule property `has_teardown_validate` was deprecated in v1.4 and will be removed in v1.6.' ) return self._has_teardown_validate @@ -278,8 +271,7 @@ def has_teardown_test(self) -> bool: Will be removed in v1.6.0. """ rank_zero_deprecation( - 'DataModule property `has_teardown_test` was deprecated in v1.4' - ' and will be removed in v1.6.' + 'DataModule property `has_teardown_test` was deprecated in v1.4 and will be removed in v1.6.' ) return self._has_teardown_test @@ -294,8 +286,7 @@ def has_teardown_predict(self) -> bool: Will be removed in v1.6.0. """ rank_zero_deprecation( - 'DataModule property `has_teardown_predict` was deprecated in v1.4' - ' and will be removed in v1.6.' + 'DataModule property `has_teardown_predict` was deprecated in v1.4 and will be removed in v1.6.' ) return self._has_teardown_predict From 9f25eb16f521b13edab5b2d258943d3b803bba3c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 26 May 2021 14:20:24 +0200 Subject: [PATCH 4/7] Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pytorch_lightning/core/datamodule.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index e5a65e2847bb4..641bce35453e4 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -162,7 +162,7 @@ def has_prepared_data(self) -> bool: Returns: bool: True if ``datamodule.prepare_data()`` has been called. False by default. - .. deprecated::v1.4 + .. deprecated:: v1.4 Will be removed in v1.6.0. """ rank_zero_deprecation( @@ -177,7 +177,7 @@ def has_setup_fit(self) -> bool: Returns: bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default. - .. deprecated::v1.4 + .. deprecated:: v1.4 Will be removed in v1.6.0. """ rank_zero_deprecation( @@ -192,7 +192,7 @@ def has_setup_validate(self) -> bool: Returns: bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default. - .. deprecated::v1.4 + .. deprecated:: v1.4 Will be removed in v1.6.0. """ rank_zero_deprecation( @@ -207,7 +207,7 @@ def has_setup_test(self) -> bool: Returns: bool: True if ``datamodule.setup(stage='test')`` has been called. False by default. - .. deprecated::v1.4 + .. deprecated:: v1.4 Will be removed in v1.6.0. """ rank_zero_deprecation( @@ -222,7 +222,7 @@ def has_setup_predict(self) -> bool: Returns: bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default. - .. deprecated::v1.4 + .. deprecated:: v1.4 Will be removed in v1.6.0. """ rank_zero_deprecation( @@ -237,7 +237,7 @@ def has_teardown_fit(self) -> bool: Returns: bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default. - .. deprecated::v1.4 + .. deprecated:: v1.4 Will be removed in v1.6.0. """ rank_zero_deprecation( @@ -252,7 +252,7 @@ def has_teardown_validate(self) -> bool: Returns: bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default. - .. deprecated::v1.4 + .. deprecated:: v1.4 Will be removed in v1.6.0. """ rank_zero_deprecation( @@ -267,7 +267,7 @@ def has_teardown_test(self) -> bool: Returns: bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default. - .. deprecated::v1.4 + .. deprecated:: v1.4 Will be removed in v1.6.0. """ rank_zero_deprecation( @@ -282,7 +282,7 @@ def has_teardown_predict(self) -> bool: Returns: bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default. - .. deprecated::v1.4 + .. deprecated:: v1.4 Will be removed in v1.6.0. """ rank_zero_deprecation( From fbb2ecf43e4c779f4356866dc5ea2657fd1e7242 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 May 2021 12:21:04 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/datamodule.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 641bce35453e4..7d0cb49eb35fe 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -180,9 +180,7 @@ def has_setup_fit(self) -> bool: .. deprecated:: v1.4 Will be removed in v1.6.0. """ - rank_zero_deprecation( - 'DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.' - ) + rank_zero_deprecation('DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.') return self._has_setup_fit @property From 9f3cf8334978696c989c79959bcb054dce281efe Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 26 May 2021 17:18:52 -0700 Subject: [PATCH 6/7] Update datamodule.py --- pytorch_lightning/core/datamodule.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 7d0cb49eb35fe..ae2aa2f442633 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -20,9 +20,8 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks -from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types -from pytorch_lightning.utilities.distributed import rank_zero_deprecation +from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only, rank_zero_warn class LightningDataModule(CheckpointHooks, DataHooks): @@ -434,8 +433,13 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: has_run = obj._has_prepared_data obj._has_prepared_data = True - if not has_run: - return fn(*args, **kwargs) + if has_run: + rank_zero_warn( + f"DataModule.{name} has already been called, so it will not be called again. " + f"In v1.6 this behavior will change to always call DataModule.{name}." + ) + else: + fn(*args, **kwargs) return wrapped_fn From 066571b91e93783cd7111865397e0a56f4000d8b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 7 Jun 2021 10:58:44 -0700 Subject: [PATCH 7/7] add-hooks-deprecation-test --- pytorch_lightning/core/datamodule.py | 4 +- tests/core/test_datamodules.py | 43 --------------------- tests/deprecated_api/test_remove_1-6.py | 51 +++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 45 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index ae2aa2f442633..afa1238786490 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -21,7 +21,7 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types -from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only class LightningDataModule(CheckpointHooks, DataHooks): @@ -434,7 +434,7 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: obj._has_prepared_data = True if has_run: - rank_zero_warn( + rank_zero_deprecation( f"DataModule.{name} has already been called, so it will not be called again. " f"In v1.6 this behavior will change to always call DataModule.{name}." ) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index e6500a15eeed1..d4e1a3ff0e3ae 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -524,46 +524,3 @@ def test_dm_init_from_datasets_dataloaders(iterable): call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True) ]) - - -def test_datamodule_hooks_calls(tmpdir): - """Test that repeated calls to DataHooks' hooks have no effect""" - - class TestDataModule(BoringDataModule): - setup_calls = [] - teardown_calls = [] - prepare_data_calls = 0 - - def setup(self, stage=None): - super().setup(stage=stage) - self.setup_calls.append(stage) - - def teardown(self, stage=None): - super().teardown(stage=stage) - self.teardown_calls.append(stage) - - def prepare_data(self): - super().prepare_data() - self.prepare_data_calls += 1 - - dm = TestDataModule() - dm.prepare_data() - dm.prepare_data() - dm.setup('fit') - dm.setup('fit') - dm.setup() - dm.setup() - dm.teardown('validate') - dm.teardown('validate') - - assert dm.prepare_data_calls == 1 - assert dm.setup_calls == ['fit', None] - assert dm.teardown_calls == ['validate'] - - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) - trainer.test(BoringModel(), datamodule=dm) - - # same number of calls - assert dm.prepare_data_calls == 1 - assert dm.setup_calls == ['fit', None] - assert dm.teardown_calls == ['validate', 'test'] diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 8061d8c28a30a..2851a81e968d2 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -108,3 +108,54 @@ def test_v1_6_0_datamodule_lifecycle_properties(tmpdir): dm.has_teardown_test with pytest.deprecated_call(match=r"DataModule property `has_teardown_predict` was deprecated in v1.4"): dm.has_teardown_predict + + +def test_v1_6_0_datamodule_hooks_calls(tmpdir): + """Test that repeated calls to DataHooks' hooks show a warning about the coming API change.""" + + class TestDataModule(BoringDataModule): + setup_calls = [] + teardown_calls = [] + prepare_data_calls = 0 + + def setup(self, stage=None): + super().setup(stage=stage) + self.setup_calls.append(stage) + + def teardown(self, stage=None): + super().teardown(stage=stage) + self.teardown_calls.append(stage) + + def prepare_data(self): + super().prepare_data() + self.prepare_data_calls += 1 + + dm = TestDataModule() + dm.prepare_data() + dm.prepare_data() + dm.setup('fit') + with pytest.deprecated_call( + match=r"DataModule.setup has already been called, so it will not be called again. " + "In v1.6 this behavior will change to always call DataModule.setup" + ): + dm.setup('fit') + dm.setup() + dm.setup() + dm.teardown('validate') + with pytest.deprecated_call( + match=r"DataModule.teardown has already been called, so it will not be called again. " + "In v1.6 this behavior will change to always call DataModule.teardown" + ): + dm.teardown('validate') + + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate'] + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + trainer.test(BoringModel(), datamodule=dm) + + # same number of calls + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate', 'test']