Commit 40976e4

ananthsub, carmocca, and tchaton authored
Support teardown hook on DataModule (#4673)
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: chaton <[email protected]>
1 parent 92a1671 commit 40976e4

File tree

8 files changed: +352 −131 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


+- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673))
+
+
 - Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))


docs/source/common/lightning_module.rst

Lines changed: 2 additions & 2 deletions

@@ -1256,7 +1256,7 @@ prepare_data

 setup
 ~~~~~

-.. automethod:: pytorch_lightning.core.hooks.ModelHooks.setup
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup
     :noindex:

 tbptt_split_batch

@@ -1268,7 +1268,7 @@ tbptt_split_batch

 teardown
 ~~~~~~~~

-.. automethod:: pytorch_lightning.core.hooks.ModelHooks.teardown
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown
     :noindex:

 train_dataloader

docs/source/extensions/datamodules.rst

Lines changed: 12 additions & 1 deletion

@@ -94,6 +94,10 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa

     def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=self.batch_size)

+    def teardown(self, stage: Optional[str] = None):
+        # Used to clean-up when the run is finished
+        ...
+
 But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can
 let Lightning handle those details for you while making this dataset reusable so you can share with
 colleagues or use in different projects.

@@ -243,7 +247,10 @@ There are also data operations you might want to perform on every GPU. Use setup

         self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)


-.. warning:: `setup` is called from every process. Setting state here is okay.
+.. warning:: ``setup`` is called from every process. Setting state here is okay.
+
+
+.. note:: ``teardown`` can be used to clean up the state. It is also called from every process


 train_dataloader

@@ -411,10 +418,14 @@ You can of course use DataModules in plain PyTorch code as well.

     for batch in dm.val_dataloader():
         ...

+    dm.teardown(stage='fit')
+
     # lazy load test data
     dm.setup(stage='test')
     for batch in dm.test_dataloader():
         ...

+    dm.teardown(stage='test')
+
 But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified
 structure.
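For context on the docs change above, here is a minimal, runnable sketch (not part of this commit) of the pattern it documents: `setup` builds per-stage state, `teardown` releases it, and both can be driven from plain PyTorch code. `RandomDataModule` and the random tensors are illustrative placeholders, not Lightning APIs.

from typing import Optional

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningDataModule


class RandomDataModule(LightningDataModule):
    """Illustrative DataModule (not from this commit): setup() builds state, teardown() drops it."""

    def setup(self, stage: Optional[str] = None):
        # called on every process under DDP; build only what the requested stage needs
        if stage in (None, "fit"):
            self.train_set = TensorDataset(torch.randn(64, 28), torch.randint(0, 10, (64,)))
        if stage in (None, "test"):
            self.test_set = TensorDataset(torch.randn(16, 28), torch.randint(0, 10, (16,)))

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=8)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=8)

    def teardown(self, stage: Optional[str] = None):
        # mirror of setup(): drop the per-stage state so nothing leaks between runs
        if stage in (None, "fit"):
            self.train_set = None
        if stage in (None, "test"):
            self.test_set = None


# plain-PyTorch usage, matching the last hunk above
dm = RandomDataModule()
dm.setup(stage="fit")
for batch in dm.train_dataloader():
    ...
dm.teardown(stage="fit")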

pytorch_lightning/core/datamodule.py

Lines changed: 54 additions & 13 deletions

@@ -14,7 +14,6 @@

 """LightningDataModule for loading DataLoaders with ease."""

 import functools
-from abc import abstractmethod
 from argparse import ArgumentParser, Namespace
 from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

@@ -44,6 +43,8 @@ def __call__(cls, *args, **kwargs):

         cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
         # Track setup calls
         cls.setup = track_data_hook_calls(cls.setup)
+        # Track teardown calls
+        cls.teardown = track_data_hook_calls(cls.teardown)

         # Get instance of LightningDataModule by mocking its __init__ via __call__
         obj = type.__call__(cls, *args, **kwargs)

@@ -52,12 +53,13 @@ def __call__(cls, *args, **kwargs):


 def track_data_hook_calls(fn):
-    """A decorator that checks if prepare_data/setup have been called.
+    """A decorator that checks if prepare_data/setup/teardown has been called.

     - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
     - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
     - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
       Its corresponding `dm_has_setup_{stage}` attribute gets set to True
+    - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``

     Args:
         fn (function): Function that will be tracked to see if it has been called.

@@ -71,9 +73,10 @@ def wrapped_fn(*args, **kwargs):

         # The object instance from which setup or prepare_data was called
         obj = args[0]
+        name = fn.__name__

         # If calling setup, we check the stage and assign stage-specific bool args
-        if fn.__name__ == "setup":
+        if name in ("setup", "teardown"):

             # Get stage either by grabbing from args or checking kwargs.
             # If not provided, set call status of 'fit', 'validate', and 'test' to True.

@@ -82,11 +85,11 @@

             if stage is None:
                 for s in ("fit", "validate", "test"):
-                    setattr(obj, f"_has_setup_{s}", True)
+                    setattr(obj, f"_has_{name}_{s}", True)
             else:
-                setattr(obj, f"_has_setup_{stage}", True)
+                setattr(obj, f"_has_{name}_{stage}", True)

-        if fn.__name__ == "prepare_data":
+        elif name == "prepare_data":
             obj._has_prepared_data = True

         return fn(*args, **kwargs)

@@ -119,14 +122,18 @@ def val_dataloader(self):

             def test_dataloader(self):
                 test_split = Dataset(...)
                 return DataLoader(test_split)
+            def teardown(self):
+                # clean up after fit or test
+                # called on every process in DDP

-    A DataModule implements 5 key methods:
+    A DataModule implements 6 key methods:

     * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode).
     * **setup** (things to do on every accelerator in distributed mode).
     * **train_dataloader** the training dataloader.
     * **val_dataloader** the val dataloader(s).
     * **test_dataloader** the test dataloader(s).
+    * **teardown** (things to do on every accelerator in distributed mode when finished)

     This allows you to share a full dataset without explaining how to download,

@@ -154,11 +161,17 @@ def __init__(

         # Private attrs to keep track of whether or not data hooks have been called yet
         self._has_prepared_data = False
+
         self._has_setup_fit = False
         self._has_setup_validate = False
         self._has_setup_test = False
         self._has_setup_predict = False

+        self._has_teardown_fit = False
+        self._has_teardown_validate = False
+        self._has_teardown_test = False
+        self._has_teardown_predict = False
+
     @property
     def train_transforms(self):
         """

@@ -259,13 +272,41 @@ def has_setup_predict(self) -> bool:

         """
         return self._has_setup_predict

-    @abstractmethod
-    def prepare_data(self, *args, **kwargs):
-        pass
+    @property
+    def has_teardown_fit(self) -> bool:
+        """Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not.

-    @abstractmethod
-    def setup(self, stage: Optional[str] = None):
-        pass
+        Returns:
+            bool: True if ``datamodule.teardown(stage='fit')`` has been called. False by default.
+        """
+        return self._has_teardown_fit
+
+    @property
+    def has_teardown_validate(self) -> bool:
+        """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not.
+
+        Returns:
+            bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default.
+        """
+        return self._has_teardown_validate
+
+    @property
+    def has_teardown_test(self) -> bool:
+        """Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not.
+
+        Returns:
+            bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default.
+        """
+        return self._has_teardown_test
+
+    @property
+    def has_teardown_predict(self) -> bool:
+        """Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not.
+
+        Returns:
+            bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default.
+        """
+        return self._has_teardown_predict

     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
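To make the tracking logic above concrete, here is a small hedged sketch of the behavior it implies at this revision (`BoringDataModule` is just an illustrative empty subclass, not part of the commit): a stage-specific `teardown` call flips only that stage's flag, while a call with no stage marks `fit`, `validate` and `test` but leaves `predict` untouched.

from pytorch_lightning import LightningDataModule


class BoringDataModule(LightningDataModule):
    pass  # illustrative subclass; the metaclass wraps its hooks with track_data_hook_calls


dm = BoringDataModule()
assert dm.has_teardown_fit is False        # all flags start out False in __init__

dm.teardown(stage="test")                  # stage-specific call
assert dm.has_teardown_test is True
assert dm.has_teardown_fit is False

dm.teardown()                              # no stage: 'fit', 'validate' and 'test' are marked
assert dm.has_teardown_fit and dm.has_teardown_validate and dm.has_teardown_test
assert dm.has_teardown_predict is False    # 'predict' is only set when passed explicitly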

pytorch_lightning/core/hooks.py

Lines changed: 36 additions & 36 deletions

@@ -25,42 +25,6 @@

 class ModelHooks:
     """Hooks to be used in LightningModule."""

-    def setup(self, stage: Optional[str] = None) -> None:
-        """
-        Called at the beginning of fit (train + validate), validate, test, predict, or tune.
-        This is a good hook when you need to build models dynamically or adjust something about them.
-        This hook is called on every process when using DDP.
-
-        Args:
-            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
-
-        Example::
-
-            class LitModel(...):
-                def __init__(self):
-                    self.l1 = None
-
-                def prepare_data(self):
-                    download_data()
-                    tokenize()
-
-                    # don't do this
-                    self.something = else
-
-                def setup(stage):
-                    data = Load_data(...)
-                    self.l1 = nn.Linear(28, data.num_classes)
-
-        """
-
-    def teardown(self, stage: Optional[str] = None) -> None:
-        """
-        Called at the end of fit (train + validate), validate, test, predict, or tune.
-
-        Args:
-            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
-        """
-
     def on_fit_start(self) -> None:
         """
         Called at the very beginning of fit.

@@ -395,6 +359,42 @@ def prepare_data(self):

             model.test_dataloader()
         """

+    def setup(self, stage: Optional[str] = None) -> None:
+        """
+        Called at the beginning of fit (train + validate), validate, test, predict, or tune.
+        This is a good hook when you need to build models dynamically or adjust something about them.
+        This hook is called on every process when using DDP.
+
+        Args:
+            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
+
+        Example::
+
+            class LitModel(...):
+                def __init__(self):
+                    self.l1 = None
+
+                def prepare_data(self):
+                    download_data()
+                    tokenize()
+
+                    # don't do this
+                    self.something = else
+
+                def setup(stage):
+                    data = Load_data(...)
+                    self.l1 = nn.Linear(28, data.num_classes)
+
+        """
+
+    def teardown(self, stage: Optional[str] = None) -> None:
+        """
+        Called at the end of fit (train + validate), validate, test, predict, or tune.
+
+        Args:
+            stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
+        """
+
     def train_dataloader(self) -> Any:
         """
         Implement one or more PyTorch DataLoaders for training.
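Because the `setup`/`teardown` docstrings now live on `DataHooks`, both `LightningModule` (through its hook mix-ins) and `LightningDataModule` expose the same hook signature. A brief illustrative sketch, with made-up class and attribute names that are not part of this commit:

from typing import Optional

from torch import nn
from pytorch_lightning import LightningDataModule, LightningModule


class LitModel(LightningModule):
    def setup(self, stage: Optional[str] = None) -> None:
        # runs on every process at the start of fit/validate/test/predict
        self.head = nn.Linear(28, 10)

    def teardown(self, stage: Optional[str] = None) -> None:
        # runs on every process at the end of the same stages
        self.head = None


class LitData(LightningDataModule):
    def setup(self, stage: Optional[str] = None) -> None:
        self.cache = {}        # illustrative per-run state

    def teardown(self, stage: Optional[str] = None) -> None:
        self.cache.clear()     # release it when the run finishes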

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 0 deletions

@@ -1086,6 +1086,12 @@ def call_setup_hook(self, model: LightningModule) -> None:

     def call_teardown_hook(self, model: LightningModule) -> None:
         state = self._teardown_state
+
+        if self.datamodule is not None:
+            called = getattr(self.datamodule, f'has_teardown_{state}')
+            if not called:
+                self.datamodule.teardown(stage=state)
+
         self.profiler.teardown(stage=state)
         self.teardown(stage=state)
         model.teardown(stage=state)
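In practice, the guard above means the trainer tears the datamodule down for the current stage only if the user has not already done so (the `has_teardown_{stage}` flag set by the tracking decorator is the evidence), and it does this before the profiler, trainer and model teardown hooks run. A hedged sketch of the same check in isolation, with `'fit'` standing in for whatever `_teardown_state` holds and `BoringDataModule` an illustrative empty subclass:

from pytorch_lightning import LightningDataModule


class BoringDataModule(LightningDataModule):
    pass  # illustrative empty subclass, as in the earlier sketch


dm = BoringDataModule()

# first pass: the flag is still False, so the guard lets the call through
if dm is not None and not getattr(dm, "has_teardown_fit"):
    dm.teardown(stage="fit")       # runs once; the decorator flips has_teardown_fit

# second pass (e.g. the user already tore down manually): the guard now skips the call
if dm is not None and not getattr(dm, "has_teardown_fit"):
    dm.teardown(stage="fit")       # not reached a second time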
