Skip to content

Commit 553f076

Browse files
committed
add-hooks-deprecation-test
1 parent 677d8bd commit 553f076

File tree

3 files changed

+53
-45
lines changed

3 files changed

+53
-45
lines changed

pytorch_lightning/core/datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
2323
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
24-
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only, rank_zero_warn
24+
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_only
2525

2626

2727
class LightningDataModule(CheckpointHooks, DataHooks):
@@ -434,7 +434,7 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
434434
obj._has_prepared_data = True
435435

436436
if has_run:
437-
rank_zero_warn(
437+
rank_zero_deprecation(
438438
f"DataModule.{name} has already been called, so it will not be called again. "
439439
f"In v1.6 this behavior will change to always call DataModule.{name}."
440440
)

tests/core/test_datamodules.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -524,46 +524,3 @@ def test_dm_init_from_datasets_dataloaders(iterable):
524524
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
525525
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
526526
])
527-
528-
529-
def test_datamodule_hooks_calls(tmpdir):
530-
"""Test that repeated calls to DataHooks' hooks have no effect"""
531-
532-
class TestDataModule(BoringDataModule):
533-
setup_calls = []
534-
teardown_calls = []
535-
prepare_data_calls = 0
536-
537-
def setup(self, stage=None):
538-
super().setup(stage=stage)
539-
self.setup_calls.append(stage)
540-
541-
def teardown(self, stage=None):
542-
super().teardown(stage=stage)
543-
self.teardown_calls.append(stage)
544-
545-
def prepare_data(self):
546-
super().prepare_data()
547-
self.prepare_data_calls += 1
548-
549-
dm = TestDataModule()
550-
dm.prepare_data()
551-
dm.prepare_data()
552-
dm.setup('fit')
553-
dm.setup('fit')
554-
dm.setup()
555-
dm.setup()
556-
dm.teardown('validate')
557-
dm.teardown('validate')
558-
559-
assert dm.prepare_data_calls == 1
560-
assert dm.setup_calls == ['fit', None]
561-
assert dm.teardown_calls == ['validate']
562-
563-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
564-
trainer.test(BoringModel(), datamodule=dm)
565-
566-
# same number of calls
567-
assert dm.prepare_data_calls == 1
568-
assert dm.setup_calls == ['fit', None]
569-
assert dm.teardown_calls == ['validate', 'test']

tests/deprecated_api/test_remove_1-6.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,54 @@ def test_v1_6_0_datamodule_lifecycle_properties(tmpdir):
108108
dm.has_teardown_test
109109
with pytest.deprecated_call(match=r"DataModule property `has_teardown_predict` was deprecated in v1.4"):
110110
dm.has_teardown_predict
111+
112+
113+
def test_v1_6_0_datamodule_hooks_calls(tmpdir):
114+
"""Test that repeated calls to DataHooks' hooks show a warning about the coming API change."""
115+
116+
class TestDataModule(BoringDataModule):
117+
setup_calls = []
118+
teardown_calls = []
119+
prepare_data_calls = 0
120+
121+
def setup(self, stage=None):
122+
super().setup(stage=stage)
123+
self.setup_calls.append(stage)
124+
125+
def teardown(self, stage=None):
126+
super().teardown(stage=stage)
127+
self.teardown_calls.append(stage)
128+
129+
def prepare_data(self):
130+
super().prepare_data()
131+
self.prepare_data_calls += 1
132+
133+
dm = TestDataModule()
134+
dm.prepare_data()
135+
dm.prepare_data()
136+
dm.setup('fit')
137+
with pytest.deprecated_call(
138+
match=r"DataModule.setup has already been called, so it will not be called again. "
139+
"In v1.6 this behavior will change to always call DataModule.setup"
140+
):
141+
dm.setup('fit')
142+
dm.setup()
143+
dm.setup()
144+
dm.teardown('validate')
145+
with pytest.deprecated_call(
146+
match=r"DataModule.teardown has already been called, so it will not be called again. "
147+
"In v1.6 this behavior will change to always call DataModule.teardown"
148+
):
149+
dm.teardown('validate')
150+
151+
assert dm.prepare_data_calls == 1
152+
assert dm.setup_calls == ['fit', None]
153+
assert dm.teardown_calls == ['validate']
154+
155+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
156+
trainer.test(BoringModel(), datamodule=dm)
157+
158+
# same number of calls
159+
assert dm.prepare_data_calls == 1
160+
assert dm.setup_calls == ['fit', None]
161+
assert dm.teardown_calls == ['validate', 'test']

0 commit comments

Comments
 (0)