1414"""LightningDataModule for loading DataLoaders with ease."""
1515
1616import functools
17- from abc import abstractmethod
1817from argparse import ArgumentParser , Namespace
1918from typing import Any , List , Mapping , Optional , Sequence , Tuple , Union
2019
@@ -44,6 +43,8 @@ def __call__(cls, *args, **kwargs):
4443 cls .prepare_data = track_data_hook_calls (rank_zero_only (cls .prepare_data ))
4544 # Track setup calls
4645 cls .setup = track_data_hook_calls (cls .setup )
46+ # Track teardown calls
47+ cls .teardown = track_data_hook_calls (cls .teardown )
4748
4849 # Get instance of LightningDataModule by mocking its __init__ via __call__
4950 obj = type .__call__ (cls , * args , ** kwargs )
@@ -52,12 +53,13 @@ def __call__(cls, *args, **kwargs):
5253
5354
5455def track_data_hook_calls (fn ):
55- """A decorator that checks if prepare_data/setup have been called.
56+ """A decorator that checks if prepare_data/setup/teardown has been called.
5657
5758 - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
5859 - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
5960 - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
6061 Its corresponding `dm_has_setup_{stage}` attribute gets set to True
62+ - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
6163
6264 Args:
6365 fn (function): Function that will be tracked to see if it has been called.
@@ -71,9 +73,10 @@ def wrapped_fn(*args, **kwargs):
7173
7274 # The object instance from which setup or prepare_data was called
7375 obj = args [0 ]
76+ name = fn .__name__
7477
7578 # If calling setup, we check the stage and assign stage-specific bool args
76- if fn . __name__ == "setup" :
79+ if name in ( "setup" , "teardown" ) :
7780
7881 # Get stage either by grabbing from args or checking kwargs.
7982 # If not provided, set call status of 'fit', 'validate', and 'test' to True.
@@ -82,11 +85,11 @@ def wrapped_fn(*args, **kwargs):
8285
8386 if stage is None :
8487 for s in ("fit" , "validate" , "test" ):
85- setattr (obj , f"_has_setup_ { s } " , True )
88+ setattr (obj , f"_has_ { name } _ { s } " , True )
8689 else :
87- setattr (obj , f"_has_setup_ { stage } " , True )
90+ setattr (obj , f"_has_ { name } _ { stage } " , True )
8891
89- if fn . __name__ == "prepare_data" :
92+ elif name == "prepare_data" :
9093 obj ._has_prepared_data = True
9194
9295 return fn (* args , ** kwargs )
@@ -119,14 +122,18 @@ def val_dataloader(self):
119122 def test_dataloader(self):
120123 test_split = Dataset(...)
121124 return DataLoader(test_split)
125+ def teardown(self):
126+ # clean up after fit or test
127+ # called on every process in DDP
122128
123- A DataModule implements 5 key methods:
129+ A DataModule implements 6 key methods:
124130
125131 * **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode).
126132 * **setup** (things to do on every accelerator in distributed mode).
127133 * **train_dataloader** the training dataloader.
128134 * **val_dataloader** the val dataloader(s).
129135 * **test_dataloader** the test dataloader(s).
136+ * **teardown** (things to do on every accelerator in distributed mode when finished)
130137
131138
132139 This allows you to share a full dataset without explaining how to download,
@@ -154,11 +161,17 @@ def __init__(
154161
155162 # Private attrs to keep track of whether or not data hooks have been called yet
156163 self ._has_prepared_data = False
164+
157165 self ._has_setup_fit = False
158166 self ._has_setup_validate = False
159167 self ._has_setup_test = False
160168 self ._has_setup_predict = False
161169
170+ self ._has_teardown_fit = False
171+ self ._has_teardown_validate = False
172+ self ._has_teardown_test = False
173+ self ._has_teardown_predict = False
174+
162175 @property
163176 def train_transforms (self ):
164177 """
@@ -259,13 +272,41 @@ def has_setup_predict(self) -> bool:
259272 """
260273 return self ._has_setup_predict
261274
262- @abstractmethod
263- def prepare_data (self , * args , ** kwargs ) :
264- pass
275+ @property
276+ def has_teardown_fit (self ) -> bool :
277+ """Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not.
265278
266- @abstractmethod
267- def setup (self , stage : Optional [str ] = None ):
268- pass
279+ Returns:
280+ bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default.
281+ """
282+ return self ._has_teardown_fit
283+
284+ @property
285+ def has_teardown_validate (self ) -> bool :
286+ """Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not.
287+
288+ Returns:
289+ bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default.
290+ """
291+ return self ._has_teardown_validate
292+
293+ @property
294+ def has_teardown_test (self ) -> bool :
295+ """Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not.
296+
297+ Returns:
298+ bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default.
299+ """
300+ return self ._has_teardown_test
301+
302+ @property
303+ def has_teardown_predict (self ) -> bool :
304+ """Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not.
305+
306+ Returns:
307+ bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default.
308+ """
309+ return self ._has_teardown_predict
269310
270311 @classmethod
271312 def add_argparse_args (cls , parent_parser : ArgumentParser , ** kwargs ) -> ArgumentParser :
0 commit comments