Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated


- Deprecated `DataModule` properties: `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test`, `has_teardown_predict` ([#7657](https://github.com/PyTorchLightning/pytorch-lightning/pull/7657/))


- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))


Expand Down
63 changes: 60 additions & 3 deletions pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from torch.utils.data import DataLoader, Dataset, IterableDataset

from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only


class LightningDataModule(CheckpointHooks, DataHooks):
Expand Down Expand Up @@ -160,7 +160,13 @@ def has_prepared_data(self) -> bool:

Returns:
bool: True if ``datamodule.prepare_data()`` has been called. False by default.

.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_prepared_data` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_prepared_data

@property
Expand All @@ -169,7 +175,11 @@ def has_setup_fit(self) -> bool:

Returns:
bool: True ``if datamodule.setup(stage='fit')`` has been called. False by default.

.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation('DataModule property `has_setup_fit` was deprecated in v1.4 and will be removed in v1.6.')
return self._has_setup_fit

@property
Expand All @@ -178,7 +188,13 @@ def has_setup_validate(self) -> bool:

Returns:
bool: True if ``datamodule.setup(stage='validate')`` has been called. False by default.

.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_setup_validate` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_setup_validate

@property
Expand All @@ -187,7 +203,13 @@ def has_setup_test(self) -> bool:

Returns:
bool: True if ``datamodule.setup(stage='test')`` has been called. False by default.

.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_setup_test` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_setup_test

@property
Expand All @@ -196,7 +218,13 @@ def has_setup_predict(self) -> bool:

Returns:
bool: True if ``datamodule.setup(stage='predict')`` has been called. False by default.

.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_setup_predict` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_setup_predict

@property
Expand All @@ -205,7 +233,13 @@ def has_teardown_fit(self) -> bool:

Returns:
bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default.

.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_teardown_fit` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_teardown_fit

@property
Expand All @@ -214,7 +248,13 @@ def has_teardown_validate(self) -> bool:

Returns:
bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default.

.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_teardown_validate` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_teardown_validate

@property
Expand All @@ -223,7 +263,13 @@ def has_teardown_test(self) -> bool:

Returns:
bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default.

.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_teardown_test` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_teardown_test

@property
Expand All @@ -232,7 +278,13 @@ def has_teardown_predict(self) -> bool:

Returns:
bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default.

.. deprecated:: v1.4
Will be removed in v1.6.0.
"""
rank_zero_deprecation(
'DataModule property `has_teardown_predict` was deprecated in v1.4 and will be removed in v1.6.'
)
return self._has_teardown_predict

@classmethod
Expand Down Expand Up @@ -381,8 +433,13 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
has_run = obj._has_prepared_data
obj._has_prepared_data = True

if not has_run:
return fn(*args, **kwargs)
if has_run:
rank_zero_deprecation(
f"DataModule.{name} has already been called, so it will not be called again. "
f"In v1.6 this behavior will change to always call DataModule.{name}."
)
else:
fn(*args, **kwargs)

return wrapped_fn

Expand Down
43 changes: 0 additions & 43 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,46 +524,3 @@ def test_dm_init_from_datasets_dataloaders(iterable):
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
])


def test_datamodule_hooks_calls(tmpdir):
"""Test that repeated calls to DataHooks' hooks have no effect"""

class TestDataModule(BoringDataModule):
setup_calls = []
teardown_calls = []
prepare_data_calls = 0

def setup(self, stage=None):
super().setup(stage=stage)
self.setup_calls.append(stage)

def teardown(self, stage=None):
super().teardown(stage=stage)
self.teardown_calls.append(stage)

def prepare_data(self):
super().prepare_data()
self.prepare_data_calls += 1

dm = TestDataModule()
dm.prepare_data()
dm.prepare_data()
dm.setup('fit')
dm.setup('fit')
dm.setup()
dm.setup()
dm.teardown('validate')
dm.teardown('validate')

assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate']

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
trainer.test(BoringModel(), datamodule=dm)

# same number of calls
assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate', 'test']
75 changes: 74 additions & 1 deletion tests/deprecated_api/test_remove_1-6.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
from tests.helpers import BoringModel
from tests.helpers import BoringDataModule, BoringModel


def test_v1_6_0_trainer_model_hook_mixin(tmpdir):
Expand Down Expand Up @@ -86,3 +86,76 @@ def training_step(self, *args):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.deprecated_call(match=r"tbptt_pad_token=...\)` is no longer supported"):
trainer.fit(TestModel())


def test_v1_6_0_datamodule_lifecycle_properties(tmpdir):
dm = BoringDataModule()
with pytest.deprecated_call(match=r"DataModule property `has_prepared_data` was deprecated in v1.4"):
dm.has_prepared_data
with pytest.deprecated_call(match=r"DataModule property `has_setup_fit` was deprecated in v1.4"):
dm.has_setup_fit
with pytest.deprecated_call(match=r"DataModule property `has_setup_validate` was deprecated in v1.4"):
dm.has_setup_validate
with pytest.deprecated_call(match=r"DataModule property `has_setup_test` was deprecated in v1.4"):
dm.has_setup_test
with pytest.deprecated_call(match=r"DataModule property `has_setup_predict` was deprecated in v1.4"):
dm.has_setup_predict
with pytest.deprecated_call(match=r"DataModule property `has_teardown_fit` was deprecated in v1.4"):
dm.has_teardown_fit
with pytest.deprecated_call(match=r"DataModule property `has_teardown_validate` was deprecated in v1.4"):
dm.has_teardown_validate
with pytest.deprecated_call(match=r"DataModule property `has_teardown_test` was deprecated in v1.4"):
dm.has_teardown_test
with pytest.deprecated_call(match=r"DataModule property `has_teardown_predict` was deprecated in v1.4"):
dm.has_teardown_predict


def test_v1_6_0_datamodule_hooks_calls(tmpdir):
"""Test that repeated calls to DataHooks' hooks show a warning about the coming API change."""

class TestDataModule(BoringDataModule):
setup_calls = []
teardown_calls = []
prepare_data_calls = 0

def setup(self, stage=None):
super().setup(stage=stage)
self.setup_calls.append(stage)

def teardown(self, stage=None):
super().teardown(stage=stage)
self.teardown_calls.append(stage)

def prepare_data(self):
super().prepare_data()
self.prepare_data_calls += 1

dm = TestDataModule()
dm.prepare_data()
dm.prepare_data()
dm.setup('fit')
with pytest.deprecated_call(
match=r"DataModule.setup has already been called, so it will not be called again. "
"In v1.6 this behavior will change to always call DataModule.setup"
):
dm.setup('fit')
dm.setup()
dm.setup()
dm.teardown('validate')
with pytest.deprecated_call(
match=r"DataModule.teardown has already been called, so it will not be called again. "
"In v1.6 this behavior will change to always call DataModule.teardown"
):
dm.teardown('validate')

assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate']

trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
trainer.test(BoringModel(), datamodule=dm)

# same number of calls
assert dm.prepare_data_calls == 1
assert dm.setup_calls == ['fit', None]
assert dm.teardown_calls == ['validate', 'test']