Commit 3fdb61a

Replace _DataModuleWrapper with __new__ [1/2] (#7289)
* Remove `_DataModuleWrapper`
* Update pytorch_lightning/core/datamodule.py
* Update pytorch_lightning/core/datamodule.py
* Replace `__reduce__` with `__getstate__`

1 parent 597b309
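
The gist of the change: hook tracking used to be bolted onto the class by a custom metaclass, `_DataModuleWrapper.__call__`, which patched the hooks once per class behind a `__has_added_checks` flag. A custom metaclass conflicts with any other metaclass a user might mix in and, as the last commit bullet hints, complicates pickling. The replacement does the wrapping per instance inside `__new__`, with no metaclass at all. A minimal sketch of the two patterns on toy classes (`_track`, `_Wrapper`, `OldStyle` and `NewStyle` are illustrative names, not Lightning's API):

```python
import functools

def _track(fn):
    """Toy tracker: flag the owning instance once the hook has run."""
    @functools.wraps(fn)
    def wrapped(self, *args, **kwargs):
        self._has_setup = True
        return fn(self, *args, **kwargs)
    return wrapped

# Before: a metaclass patches the class once, at first instantiation.
class _Wrapper(type):
    def __call__(cls, *args, **kwargs):
        if not cls.__dict__.get("_patched", False):
            cls._patched = True
            cls.setup = _track(cls.setup)  # class-level patch, shared by every instance
        return super().__call__(*args, **kwargs)

class OldStyle(metaclass=_Wrapper):
    def setup(self, stage=None):
        pass

# After: __new__ attaches the wrapped hook to each instance instead.
class NewStyle:
    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        # bind the wrapper to this instance; the attribute shadows the class method
        obj.setup = functools.partial(_track(type(obj).setup), obj)
        return obj

    def setup(self, stage=None):
        pass

for dm in (OldStyle(), NewStyle()):
    dm.setup()
    assert dm._has_setup
```

Because `__new__` runs on every construction, each instance gets freshly wrapped hooks and nothing is mutated at class level, which also avoids the cascade of re-wrapped hooks that issue #3652 (referenced in the tests below) ran into.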

File tree

2 files changed: +66 -80 lines

pytorch_lightning/core/datamodule.py

Lines changed: 60 additions & 74 deletions
```diff
@@ -24,80 +24,7 @@
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types


-class _DataModuleWrapper(type):
-
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
-        self.__has_added_checks = False
-
-    def __call__(cls, *args, **kwargs):
-        """A wrapper for LightningDataModule that:
-
-        1. Runs user defined subclass's __init__
-        2. Assures prepare_data() runs on rank 0
-        3. Lets you check prepare_data and setup to see if they've been called
-        """
-        if not cls.__has_added_checks:
-            cls.__has_added_checks = True
-            # Track prepare_data calls and make sure it runs on rank zero
-            cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
-            # Track setup calls
-            cls.setup = track_data_hook_calls(cls.setup)
-            # Track teardown calls
-            cls.teardown = track_data_hook_calls(cls.teardown)
-
-        # Get instance of LightningDataModule by mocking its __init__ via __call__
-        obj = type.__call__(cls, *args, **kwargs)
-
-        return obj
-
-
-def track_data_hook_calls(fn):
-    """A decorator that checks if prepare_data/setup/teardown has been called.
-
-    - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
-    - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
-    - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
-      Its corresponding `dm_has_setup_{stage}` attribute gets set to True
-    - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
-
-    Args:
-        fn (function): Function that will be tracked to see if it has been called.
-
-    Returns:
-        function: Decorated function that tracks its call status and saves it to private attrs in its obj instance.
-    """
-
-    @functools.wraps(fn)
-    def wrapped_fn(*args, **kwargs):
-
-        # The object instance from which setup or prepare_data was called
-        obj = args[0]
-        name = fn.__name__
-
-        # If calling setup, we check the stage and assign stage-specific bool args
-        if name in ("setup", "teardown"):
-
-            # Get stage either by grabbing from args or checking kwargs.
-            # If not provided, set call status of 'fit', 'validate', and 'test' to True.
-            # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
-            stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
-
-            if stage is None:
-                for s in ("fit", "validate", "test"):
-                    setattr(obj, f"_has_{name}_{s}", True)
-            else:
-                setattr(obj, f"_has_{name}_{stage}", True)
-
-        elif name == "prepare_data":
-            obj._has_prepared_data = True
-
-        return fn(*args, **kwargs)
-
-    return wrapped_fn
-
-
-class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper):
+class LightningDataModule(CheckpointHooks, DataHooks):
     """
     A DataModule standardizes the training, val, test splits, data preparation and transforms.
     The main advantage is consistent data splits, data preparation and transforms across models.
```
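
A detail worth noticing before the next hunk: the removed `track_data_hook_calls` above wrapped plain functions on the class, so the instance arrived as `args[0]` and the stage had to be read from `args[1]`; the replacement added below closes over `obj` and wraps already-bound methods, so the stage shifts to `args[0]`. A toy illustration of the index shift (illustrative names, not Lightning's code):

```python
# Wrapping the unbound class attribute: calls arrive as fn(self, stage, ...),
# so `self` sits at args[0] and the stage at args[1].
def wrapped_unbound(*args, **kwargs):
    obj = args[0]
    stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
    return obj, stage

# Wrapping the bound method obj.fn: `self` is already baked in,
# so the stage (if given) is the first positional argument.
def wrapped_bound(*args, **kwargs):
    stage = args[0] if len(args) else kwargs.get("stage", None)
    return stage
```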
```diff
@@ -398,3 +325,62 @@ def test_dataloader():
         if test_dataset is not None:
             datamodule.test_dataloader = test_dataloader
         return datamodule
+
+    def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
+        obj = super().__new__(cls)
+        # track `DataHooks` calls and run `prepare_data` only on rank zero
+        obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data))
+        obj.setup = cls._track_data_hook_calls(obj, obj.setup)
+        obj.teardown = cls._track_data_hook_calls(obj, obj.teardown)
+        return obj
+
+    @staticmethod
+    def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable:
+        """A decorator that checks if prepare_data/setup/teardown has been called.
+
+        - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
+        - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
+        - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
+          Its corresponding `dm_has_setup_{stage}` attribute gets set to True
+        - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``
+
+        Args:
+            obj: Object whose function will be tracked
+            fn: Function that will be tracked to see if it has been called.
+
+        Returns:
+            Decorated function that tracks its call status and saves it to private attrs in its obj instance.
+        """

+
+        @functools.wraps(fn)
+        def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
+            name = fn.__name__
+
+            # If calling setup, we check the stage and assign stage-specific bool args
+            if name in ("setup", "teardown"):
+
+                # Get stage either by grabbing from args or checking kwargs.
+                # If not provided, set call status of 'fit', 'validate', and 'test' to True.
+                # We do this so __attach_datamodule in trainer.py doesn't mistakenly call
+                # setup('test') on trainer.test()
+                stage = args[0] if len(args) else kwargs.get("stage", None)
+
+                if stage is None:
+                    for s in ("fit", "validate", "test"):
+                        setattr(obj, f"_has_{name}_{s}", True)
+                else:
+                    setattr(obj, f"_has_{name}_{stage}", True)
+
+            elif name == "prepare_data":
+                obj._has_prepared_data = True
+
+            return fn(*args, **kwargs)
+
+        return wrapped_fn
+
+    def __getstate__(self) -> dict:
+        # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same object as <...>
+        d = self.__dict__.copy()
+        for fn in ("prepare_data", "setup", "teardown"):
+            del d[fn]
+        return d
```
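
Why the new `__getstate__` is needed at all: the attributes assigned in `__new__` shadow the class methods with locally defined wrappers, and pickle serializes functions by qualified name, so it looks up `LightningDataModule.prepare_data`, notices it is not the same object as the wrapper stored on the instance, and raises the `PicklingError` quoted in the comment. Dropping the three keys from the pickled state sidesteps this; on unpickling, `__new__` runs again and re-attaches fresh wrappers. A self-contained toy repro of the failure and the fix (the `DM`/`tracked` names are illustrative, not Lightning's):

```python
import functools
import pickle

def tracked(obj, fn):
    """Toy stand-in for the hook tracker: mark `obj` when `fn` runs."""
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        obj._called = True
        return fn(*args, **kwargs)
    return wrapped

class DM:
    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        obj.setup = tracked(obj, obj.setup)  # instance attribute shadows the class method
        return obj

    def setup(self, stage=None):
        pass

    def __getstate__(self):
        # Without this, pickle.dumps(DM()) raises:
        #   _pickle.PicklingError: Can't pickle <function DM.setup ...>:
        #   it's not the same object as __main__.DM.setup
        d = self.__dict__.copy()
        d.pop("setup", None)
        return d

restored = pickle.loads(pickle.dumps(DM()))  # __new__ re-wraps on reconstruction
restored.setup()
assert restored._called
```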

tests/core/test_datamodules.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -91,7 +91,7 @@ def test_can_prepare_data(local_rank, node_rank):
     assert trainer.data_connector.can_prepare_data()


-def test_hooks_no_recursion_error(tmpdir):
+def test_hooks_no_recursion_error():
     # hooks were appended in cascade every time a new data module was instantiated, leading to a recursion error.
     # See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652
     class DummyDM(LightningDataModule):
@@ -108,20 +108,20 @@ def prepare_data(self, *args, **kwargs):
     dm.prepare_data()


-def test_helper_boringdatamodule(tmpdir):
+def test_helper_boringdatamodule():
     dm = BoringDataModule()
     dm.prepare_data()
     dm.setup()


-def test_helper_boringdatamodule_with_verbose_setup(tmpdir):
+def test_helper_boringdatamodule_with_verbose_setup():
     dm = BoringDataModule()
     dm.prepare_data()
     dm.setup('fit')
     dm.setup('test')


-def test_data_hooks_called(tmpdir):
+def test_data_hooks_called():
     dm = BoringDataModule()
     assert not dm.has_prepared_data
     assert not dm.has_setup_fit
@@ -168,7 +168,7 @@ def test_data_hooks_called(tmpdir):


 @pytest.mark.parametrize("use_kwarg", (False, True))
-def test_data_hooks_called_verbose(tmpdir, use_kwarg):
+def test_data_hooks_called_verbose(use_kwarg):
     dm = BoringDataModule()
     dm.prepare_data()
     assert not dm.has_setup_fit
@@ -246,7 +246,7 @@ def test_dm_init_from_argparse_args(tmpdir):
     assert dm.data_dir == args.data_dir == str(tmpdir)


-def test_dm_pickle_after_init(tmpdir):
+def test_dm_pickle_after_init():
     dm = BoringDataModule()
     pickle.dumps(dm)
```
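
For reference, the behaviour these tests pin down, written out as a usage sketch. It assumes the `has_prepared_data`/`has_setup_*` properties described in the docstring above; `MyDM` is an illustrative subclass, and `BoringDataModule` in the tests is a Lightning test helper playing the same role:

```python
from pytorch_lightning import LightningDataModule

class MyDM(LightningDataModule):
    def prepare_data(self):
        pass  # download, tokenize, etc. -- wrapped to run on rank zero only

    def setup(self, stage=None):
        pass  # build splits; called with 'fit', 'validate', 'test' or None

dm = MyDM()
assert not dm.has_prepared_data and not dm.has_setup_fit

dm.prepare_data()   # flips has_prepared_data
dm.setup('fit')     # flips has_setup_fit only
dm.setup()          # no stage: flips fit, validate and test at once

assert dm.has_prepared_data and dm.has_setup_fit and dm.has_setup_test
```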
