Commit 2e226d9

carmocca and awaelchli authored and committed
Automatically check DataModule.has_{setup,teardown,prepare_data} [2/2] (#7238)
* Automatically check `DataModule.has_{setup,teardown,prepare_data}`
* Use variable
* Spacing
* Docs
* Update CHANGELOG
* Remove `_DataModuleWrapper`
* Add test
* Update docs/source/extensions/datamodules.rst
* Bad merge
* add test for invalid name
* Remove ValueError

Co-authored-by: Adrian Wälchli <[email protected]>
1 parent b9b3ec5 commit 2e226d9

File tree: 8 files changed, +112 / -28 lines changed

CHANGELOG.md

Lines changed: 37 additions & 0 deletions

@@ -5,6 +5,43 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


+## [1.4.0] - 2021-MM-DD
+
+### Added
+
+
+- Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
+
+
+### Changed
+
+
+- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
+
+
+- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))
+
+
+- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))
+
+
+- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))
+
+
+- Changed `resolve_training_type_plugins` to allow setting `num_nodes` and `sync_batchnorm` from `Trainer` setting ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+
+
+### Deprecated
+
+
+- Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))
+
+
+- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+
+
+### Removed
+
 ## [1.3.1] - 2021-05-11

 ### Fixed

docs/source/extensions/datamodules.rst

Lines changed: 16 additions & 12 deletions

@@ -168,10 +168,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
     def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=32)

-
-.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
-
-
 ---------------

 LightningDataModule API

@@ -228,7 +224,7 @@ There are also data operations you might want to perform on every GPU. Use setup
     def setup(self, stage: Optional[str] = None):

         # Assign Train/val split(s) for use in Dataloaders
-        if stage == 'fit' or stage is None:
+        if stage in (None, 'fit'):
             mnist_full = MNIST(
                 self.data_dir,
                 train=True,

@@ -239,7 +235,7 @@ There are also data operations you might want to perform on every GPU. Use setup
             self.dims = self.mnist_train[0][0].shape

         # Assign Test split(s) for use in Dataloaders
-        if stage == 'test' or stage is None:
+        if stage in (None, 'test'):
             self.mnist_test = MNIST(
                 self.data_dir,
                 train=False,

@@ -249,10 +245,17 @@ There are also data operations you might want to perform on every GPU. Use setup
             self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)


-.. warning:: ``setup`` is called from every process. Setting state here is okay.
-
+:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects an ``stage: Optional[str]`` argument.
+It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``,
+we assume all stages have been set-up.

+.. note:: ``setup`` is called from every process. Setting state here is okay.
 .. note:: ``teardown`` can be used to clean up the state. It is also called from every process
+.. note::
+    ``{setup,teardown,prepare_data}`` call will be only called once for a specific stage.
+    If the stage was ``None`` then we assume ``{fit,validate,test}`` have been called. For example, this means that
+    any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can overwrite
+    ``dm._has_setup_fit = False``


 train_dataloader

@@ -396,11 +399,12 @@ The recommended way to use a DataModule is simply:
     dm = MNISTDataModule()
     model = Model()
     trainer.fit(model, dm)
-
     trainer.test(datamodule=dm)

-If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning
-still ensures the method runs on the correct devices)
+If you need information from the dataset to build your model, then run
+:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and
+:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures
+the method runs on the correct devices).

 .. code-block:: python

@@ -416,7 +420,7 @@

 ----------------

-Datamodules without Lightning
+DataModules without Lightning
 -----------------------------
 You can of course use DataModules in plain PyTorch code as well.
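
To make the note about duplicate calls concrete, here is a small sketch of the documented behaviour (not part of the diff); it assumes a PyTorch Lightning version that includes this change, and the class name and print messages are made up for illustration:

from pytorch_lightning import LightningDataModule


class PrintingDataModule(LightningDataModule):
    def setup(self, stage=None):
        print(f"setup ran for stage={stage}")


dm = PrintingDataModule()
dm.setup('fit')            # prints: setup ran for stage=fit
dm.setup('fit')            # no-op: the 'fit' stage is already marked as set up
dm.setup()                 # runs once more and also marks 'validate' and 'test'
dm.setup()                 # no-op: every stage has already been set up
dm._has_setup_fit = False  # reset the guard, as the docs note above suggests
dm.setup('fit')            # runs again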

docs/source/starter/introduction_guide.rst

Lines changed: 0 additions & 2 deletions

@@ -295,8 +295,6 @@ When your models need to know about the data, it's best to process the data befo
 1. use ``prepare_data()`` to download and process the dataset.
 2. use ``setup()`` to do splits, and build your model internals

-|
-
 An alternative to using a DataModule is to defer initialization of the models modules to the ``setup`` method of your LightningModule as follows:

 .. testcode::

docs/source/starter/new-project.rst

Lines changed: 2 additions & 2 deletions

@@ -658,10 +658,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
             transforms.Normalize((0.1307,), (0.3081,))
         ])
         # split dataset
-        if stage == 'fit':
+        if stage in (None, 'fit'):
             mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
             self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
-        if stage == 'test':
+        if stage in (None, 'test'):
             self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)

     # return the dataloader for each split

pytorch_lightning/core/datamodule.py

Lines changed: 11 additions & 3 deletions

@@ -355,6 +355,7 @@ def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable
     @functools.wraps(fn)
     def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
         name = fn.__name__
+        has_run = False

         # If calling setup, we check the stage and assign stage-specific bool args
         if name in ("setup", "teardown"):

@@ -366,15 +367,22 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
             stage = args[0] if len(args) else kwargs.get("stage", None)

             if stage is None:
+                has_run = True
                 for s in ("fit", "validate", "test"):
-                    setattr(obj, f"_has_{name}_{s}", True)
+                    attr = f"_has_{name}_{s}"
+                    has_run &= getattr(obj, attr)
+                    setattr(obj, attr, True)
             else:
-                setattr(obj, f"_has_{name}_{stage}", True)
+                attr = f"_has_{name}_{stage}"
+                has_run = getattr(obj, attr)
+                setattr(obj, attr, True)

         elif name == "prepare_data":
+            has_run = obj._has_prepared_data
             obj._has_prepared_data = True

-        return fn(*args, **kwargs)
+        if not has_run:
+            return fn(*args, **kwargs)

     return wrapped_fn
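
The guard above can be read on its own as a small decorator pattern: record per-stage flags on the object, and skip the wrapped hook when the flag is already set. The following self-contained sketch (hypothetical names, not Lightning's internal API; the flag attributes default to False here instead of being pre-initialised) shows the same idea outside of `LightningDataModule`:

import functools
from typing import Any, Callable


def track_hook_calls(obj: Any, fn: Callable) -> Callable:
    """Wrap setup/teardown/prepare_data so repeated calls for the same stage are no-ops."""

    @functools.wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Any:
        name = fn.__name__
        has_run = False

        if name in ("setup", "teardown"):
            stage = args[0] if args else kwargs.get("stage")
            if stage is None:
                # stage=None checks and marks every stage at once
                has_run = True
                for s in ("fit", "validate", "test"):
                    attr = f"_has_{name}_{s}"
                    has_run &= getattr(obj, attr, False)
                    setattr(obj, attr, True)
            else:
                attr = f"_has_{name}_{stage}"
                has_run = getattr(obj, attr, False)
                setattr(obj, attr, True)
        elif name == "prepare_data":
            has_run = getattr(obj, "_has_prepared_data", False)
            obj._has_prepared_data = True

        if not has_run:
            return fn(*args, **kwargs)

    return wrapped_fn


class ToyDataModule:
    def setup(self, stage=None):
        print(f"running setup for {stage}")


dm = ToyDataModule()
dm.setup = track_hook_calls(dm, dm.setup)
dm.setup("fit")   # runs
dm.setup("fit")   # skipped by the guard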

pytorch_lightning/core/hooks.py

Lines changed: 1 addition & 1 deletion

@@ -394,7 +394,7 @@ def prepare_data(self):

     def setup(self, stage: Optional[str] = None) -> None:
         """
-        Called at the beginning of fit (train + validate), validate, test, predict, or tune.
+        Called at the beginning of fit (train + validate), validate, test, and predict.
         This is a good hook when you need to build models dynamically or adjust something about them.
         This hook is called on every process when using DDP.

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 8 deletions

@@ -1156,10 +1156,7 @@ def call_setup_hook(self, model: LightningModule) -> None:
         self.accelerator.barrier("pre_setup")

         if self.datamodule is not None:
-            called = getattr(self.datamodule, f'has_setup_{fn}')
-            if not called:
-                self.datamodule.setup(stage=fn)
-
+            self.datamodule.setup(stage=fn)
         self.setup(model, stage=fn)
         model.setup(stage=fn)

@@ -1182,10 +1179,7 @@ def call_teardown_hook(self, model: LightningModule) -> None:
         fn = self.state.fn._setup_fn

         if self.datamodule is not None:
-            called = getattr(self.datamodule, f'has_teardown_{fn}')
-            if not called:
-                self.datamodule.teardown(stage=fn)
-
+            self.datamodule.teardown(stage=fn)
         self.profiler.teardown(stage=fn)
         self.teardown(stage=fn)
         model.teardown(stage=fn)
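
The `has_setup_*`/`has_teardown_*` checks can be dropped here because the DataModule now guards itself. A minimal, self-contained sketch of that interaction (not part of the diff; assumes a PyTorch Lightning version with this change applied, and every name below is made up for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningDataModule, LightningModule, Trainer


class RandomDataModule(LightningDataModule):
    def setup(self, stage=None):
        print(f"setup(stage={stage})")
        self.train_set = TensorDataset(torch.randn(64, 32), torch.randn(64, 2))

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=8)


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


dm = RandomDataModule()
dm.setup('fit')  # run manually, e.g. to inspect the data before building the model

trainer = Trainer(fast_dev_run=1)
# call_setup_hook() now calls dm.setup(stage='fit') unconditionally; the DataModule's
# own guard turns this second call into a no-op, so "setup(stage=fit)" prints only once.
trainer.fit(TinyModel(), datamodule=dm)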

tests/core/test_datamodules.py

Lines changed: 43 additions & 0 deletions

@@ -521,3 +521,46 @@ def test_dm_init_from_datasets_dataloaders(iterable):
         call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
         call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
     ])
+
+
+def test_datamodule_hooks_calls(tmpdir):
+    """Test that repeated calls to DataHooks' hooks have no effect"""
+
+    class TestDataModule(BoringDataModule):
+        setup_calls = []
+        teardown_calls = []
+        prepare_data_calls = 0
+
+        def setup(self, stage=None):
+            super().setup(stage=stage)
+            self.setup_calls.append(stage)
+
+        def teardown(self, stage=None):
+            super().teardown(stage=stage)
+            self.teardown_calls.append(stage)
+
+        def prepare_data(self):
+            super().prepare_data()
+            self.prepare_data_calls += 1
+
+    dm = TestDataModule()
+    dm.prepare_data()
+    dm.prepare_data()
+    dm.setup('fit')
+    dm.setup('fit')
+    dm.setup()
+    dm.setup()
+    dm.teardown('validate')
+    dm.teardown('validate')
+
+    assert dm.prepare_data_calls == 1
+    assert dm.setup_calls == ['fit', None]
+    assert dm.teardown_calls == ['validate']
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
+    trainer.test(BoringModel(), datamodule=dm)
+
+    # same number of calls
+    assert dm.prepare_data_calls == 1
+    assert dm.setup_calls == ['fit', None]
+    assert dm.teardown_calls == ['validate', 'test']
