Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 60 additions & 74 deletions pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,80 +24,7 @@
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types


class _DataModuleWrapper(type):
    """Metaclass that instruments ``LightningDataModule`` subclasses with hook-call tracking."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Per-class flag ensuring the hook wrappers below are installed at most once per
        # class; re-wrapping on every instantiation would stack the decorators
        # (see test_hooks_no_recursion_error in tests/core/test_datamodules.py).
        self.__has_added_checks = False

    def __call__(cls, *args, **kwargs):
        """A wrapper for LightningDataModule that:

        1. Runs user defined subclass's __init__
        2. Assures prepare_data() runs on rank 0
        3. Lets you check prepare_data and setup to see if they've been called
        """
        if not cls.__has_added_checks:
            cls.__has_added_checks = True
            # Track prepare_data calls and make sure it runs on rank zero
            cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
            # Track setup calls
            cls.setup = track_data_hook_calls(cls.setup)
            # Track teardown calls
            cls.teardown = track_data_hook_calls(cls.teardown)

        # Get instance of LightningDataModule by mocking its __init__ via __call__
        obj = type.__call__(cls, *args, **kwargs)

        return obj


def track_data_hook_calls(fn):
    """A decorator that checks if prepare_data/setup/teardown has been called.

    - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
    - When ``dm.setup()`` is called, ``dm.has_setup_{fit,validate,test}`` get set to True
    - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``,
      the corresponding ``dm.has_setup_{stage}`` attribute gets set to True
    - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``

    Args:
        fn (function): Function that will be tracked to see if it has been called.

    Returns:
        function: Decorated function that tracks its call status and saves it to private attrs in its obj instance.
    """

    @functools.wraps(fn)
    def wrapped_fn(*args, **kwargs):
        # ``fn`` is unbound here, so the datamodule instance is the first positional argument
        instance = args[0]
        hook_name = fn.__name__

        if hook_name in ("setup", "teardown"):
            # The stage may arrive positionally or as a keyword. When it is omitted, mark
            # 'fit', 'validate' and 'test' all done so __attach_datamodule in trainer.py
            # doesn't mistakenly call setup('test') on trainer.test()
            stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
            stages = ("fit", "validate", "test") if stage is None else (stage,)
            for current_stage in stages:
                setattr(instance, f"_has_{hook_name}_{current_stage}", True)
        elif hook_name == "prepare_data":
            instance._has_prepared_data = True

        return fn(*args, **kwargs)

    return wrapped_fn


class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper):
class LightningDataModule(CheckpointHooks, DataHooks):
"""
A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.
Expand Down Expand Up @@ -398,3 +325,62 @@ def test_dataloader():
if test_dataset is not None:
datamodule.test_dataloader = test_dataloader
return datamodule

def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule':
    instance = super().__new__(cls)
    # Wrap the `DataHooks` entry points so each call is recorded on the instance;
    # `prepare_data` is additionally restricted to the rank-zero process.
    instance.prepare_data = cls._track_data_hook_calls(instance, rank_zero_only(instance.prepare_data))
    for hook_name in ("setup", "teardown"):
        tracked = cls._track_data_hook_calls(instance, getattr(instance, hook_name))
        setattr(instance, hook_name, tracked)
    return instance

@staticmethod
def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable:
    """A decorator that checks if prepare_data/setup/teardown has been called.

    - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
    - When ``dm.setup()`` is called, ``dm.has_setup_{fit,validate,test}`` get set to True
    - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``,
      the corresponding ``dm.has_setup_{stage}`` attribute gets set to True
    - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``

    Args:
        obj: Object whose function will be tracked
        fn: Function that will be tracked to see if it has been called.

    Returns:
        Decorated function that tracks its call status and saves it to private attrs in its obj instance.
    """

    @functools.wraps(fn)
    def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any:
        hook_name = fn.__name__

        if hook_name in ("setup", "teardown"):
            # The stage may arrive positionally or as a keyword. When it is not provided,
            # mark 'fit', 'validate' and 'test' all done so __attach_datamodule in
            # trainer.py doesn't mistakenly call setup('test') on trainer.test()
            stage = args[0] if args else kwargs.get("stage", None)
            stages = ("fit", "validate", "test") if stage is None else (stage,)
            for current_stage in stages:
                setattr(obj, f"_has_{hook_name}_{current_stage}", True)
        elif hook_name == "prepare_data":
            obj._has_prepared_data = True

        return fn(*args, **kwargs)

    return wrapped_fn

def __getstate__(self) -> dict:
    # The wrapped hooks are bound to this instance, so pickling them raises
    # `_pickle.PicklingError: Can't pickle <...>: it's not the same object as <...>`;
    # drop them from the pickled state (presumably `__new__` re-installs them on
    # reconstruction — TODO confirm).
    state = self.__dict__.copy()
    for hook_name in ("prepare_data", "setup", "teardown"):
        # `pop` without a default keeps the original `del` behavior: missing keys raise KeyError
        state.pop(hook_name)
    return state
12 changes: 6 additions & 6 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_can_prepare_data(local_rank, node_rank):
assert trainer.data_connector.can_prepare_data()


def test_hooks_no_recursion_error(tmpdir):
def test_hooks_no_recursion_error():
# hooks were appended in cascade every time a new data module was instantiated, leading to a recursion error.
# See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652
class DummyDM(LightningDataModule):
Expand All @@ -108,20 +108,20 @@ def prepare_data(self, *args, **kwargs):
dm.prepare_data()


def test_helper_boringdatamodule():
    # smoke test: the standard datamodule hooks run end-to-end without error
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    datamodule.setup()


def test_helper_boringdatamodule_with_verbose_setup():
    # smoke test: stage-specific setup calls run end-to-end without error
    datamodule = BoringDataModule()
    datamodule.prepare_data()
    datamodule.setup('fit')
    datamodule.setup('test')


def test_data_hooks_called(tmpdir):
def test_data_hooks_called():
dm = BoringDataModule()
assert not dm.has_prepared_data
assert not dm.has_setup_fit
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_data_hooks_called(tmpdir):


@pytest.mark.parametrize("use_kwarg", (False, True))
def test_data_hooks_called_verbose(tmpdir, use_kwarg):
def test_data_hooks_called_verbose(use_kwarg):
dm = BoringDataModule()
dm.prepare_data()
assert not dm.has_setup_fit
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_dm_init_from_argparse_args(tmpdir):
assert dm.data_dir == args.data_dir == str(tmpdir)


def test_dm_pickle_after_init():
    # a freshly constructed datamodule must be picklable (wrapped hooks are dropped in __getstate__)
    datamodule = BoringDataModule()
    pickle.dumps(datamodule)

Expand Down