From 224e831b6336372a93caa6dcdc31eba6d22fa5ea Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Apr 2021 02:41:55 +0200 Subject: [PATCH 1/4] Remove `_DataModuleWrapper` --- pytorch_lightning/core/datamodule.py | 134 ++++++++++++--------------- tests/core/test_datamodules.py | 12 +-- 2 files changed, 66 insertions(+), 80 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index a1f1c02ef498d..b994d08ebd02b 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -24,80 +24,7 @@ from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types -class _DataModuleWrapper(type): - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.__has_added_checks = False - - def __call__(cls, *args, **kwargs): - """A wrapper for LightningDataModule that: - - 1. Runs user defined subclass's __init__ - 2. Assures prepare_data() runs on rank 0 - 3. Lets you check prepare_data and setup to see if they've been called - """ - if not cls.__has_added_checks: - cls.__has_added_checks = True - # Track prepare_data calls and make sure it runs on rank zero - cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) - # Track setup calls - cls.setup = track_data_hook_calls(cls.setup) - # Track teardown calls - cls.teardown = track_data_hook_calls(cls.teardown) - - # Get instance of LightningDataModule by mocking its __init__ via __call__ - obj = type.__call__(cls, *args, **kwargs) - - return obj - - -def track_data_hook_calls(fn): - """A decorator that checks if prepare_data/setup/teardown has been called. - - - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. 
- Its corresponding `dm_has_setup_{stage}` attribute gets set to True - - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` - - Args: - fn (function): Function that will be tracked to see if it has been called. - - Returns: - function: Decorated function that tracks its call status and saves it to private attrs in its obj instance. - """ - - @functools.wraps(fn) - def wrapped_fn(*args, **kwargs): - - # The object instance from which setup or prepare_data was called - obj = args[0] - name = fn.__name__ - - # If calling setup, we check the stage and assign stage-specific bool args - if name in ("setup", "teardown"): - - # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit', 'validate', and 'test' to True. - # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() - stage = args[1] if len(args) > 1 else kwargs.get("stage", None) - - if stage is None: - for s in ("fit", "validate", "test"): - setattr(obj, f"_has_{name}_{s}", True) - else: - setattr(obj, f"_has_{name}_{stage}", True) - - elif name == "prepare_data": - obj._has_prepared_data = True - - return fn(*args, **kwargs) - - return wrapped_fn - - -class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper): +class LightningDataModule(CheckpointHooks, DataHooks): """ A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models. 
@@ -398,3 +325,62 @@ def test_dataloader(): if test_dataset is not None: datamodule.test_dataloader = test_dataloader return datamodule + + def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule': + obj = super(LightningDataModule, cls).__new__(cls) + # save `args` and `kwargs` for `__reduce__` + obj.__args = args + obj.__kwargs = kwargs + # track `DataHooks` calls and run `prepare_data` only on rank zero + obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data)) + obj.setup = cls._track_data_hook_calls(obj, obj.setup) + obj.teardown = cls._track_data_hook_calls(obj, obj.teardown) + return obj + + @staticmethod + def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable: + """A decorator that checks if prepare_data/setup/teardown has been called. + + - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True + - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. + Its corresponding ``dm.has_setup_{stage}`` attribute gets set to True + - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` + + Args: + obj: Object whose function will be tracked + fn: Function that will be tracked to see if it has been called. + + Returns: + Decorated function that tracks its call status and saves it to private attrs in its obj instance. + """ + + @functools.wraps(fn) + def wrapped_fn(*args: str, **kwargs: Optional[str]) -> callable: + name = fn.__name__ + + # If calling setup, we check the stage and assign stage-specific bool args + if name in ("setup", "teardown"): + + # Get stage either by grabbing from args or checking kwargs. + # If not provided, set call status of 'fit', 'validate', and 'test' to True. 
# We do this so __attach_datamodule in trainer.py doesn't mistakenly call + # setup('test') on trainer.test() + stage = args[0] if len(args) else kwargs.get("stage", None) + + if stage is None: + for s in ("fit", "validate", "test"): + setattr(obj, f"_has_{name}_{s}", True) + else: + setattr(obj, f"_has_{name}_{stage}", True) + + elif name == "prepare_data": + obj._has_prepared_data = True + + return fn(*args, **kwargs) + + return wrapped_fn + + def __reduce__(self) -> Tuple[type, tuple, dict]: + # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same object as <...> + return self.__class__, self.__args, self.__kwargs diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 40c38b9d3af3c..c4eb076e04773 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -91,7 +91,7 @@ def test_can_prepare_data(local_rank, node_rank): assert trainer.data_connector.can_prepare_data() -def test_hooks_no_recursion_error(tmpdir): +def test_hooks_no_recursion_error(): # hooks were appended in cascade every time a new data module was instantiated leading to a recursion error. 
# See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652 class DummyDM(LightningDataModule): @@ -108,20 +108,20 @@ def prepare_data(self, *args, **kwargs): dm.prepare_data() -def test_helper_boringdatamodule(tmpdir): +def test_helper_boringdatamodule(): dm = BoringDataModule() dm.prepare_data() dm.setup() -def test_helper_boringdatamodule_with_verbose_setup(tmpdir): +def test_helper_boringdatamodule_with_verbose_setup(): dm = BoringDataModule() dm.prepare_data() dm.setup('fit') dm.setup('test') -def test_data_hooks_called(tmpdir): +def test_data_hooks_called(): dm = BoringDataModule() assert not dm.has_prepared_data assert not dm.has_setup_fit @@ -168,7 +168,7 @@ def test_data_hooks_called(tmpdir): @pytest.mark.parametrize("use_kwarg", (False, True)) -def test_data_hooks_called_verbose(tmpdir, use_kwarg): +def test_data_hooks_called_verbose(use_kwarg): dm = BoringDataModule() dm.prepare_data() assert not dm.has_setup_fit @@ -246,7 +246,7 @@ def test_dm_init_from_argparse_args(tmpdir): assert dm.data_dir == args.data_dir == str(tmpdir) -def test_dm_pickle_after_init(tmpdir): +def test_dm_pickle_after_init(): dm = BoringDataModule() pickle.dumps(dm) From 617a1d92a3718aa42abd8e8fb369a98ea2553ce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 30 Apr 2021 03:51:15 +0200 Subject: [PATCH 2/4] Update pytorch_lightning/core/datamodule.py --- pytorch_lightning/core/datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index b994d08ebd02b..dae00d5d7f18a 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -356,7 +356,7 @@ def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable """ @functools.wraps(fn) - def wrapped_fn(*args: str, **kwargs: Optional[str]) -> callable: + def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: name = fn.__name__ # If calling setup, 
we check the stage and assign stage-specific bool args From 9c3103adea341dee9dd1ba41d5785a06f8f821ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 30 Apr 2021 14:52:37 +0200 Subject: [PATCH 3/4] Update pytorch_lightning/core/datamodule.py --- pytorch_lightning/core/datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index dae00d5d7f18a..54af4090469ed 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -327,7 +327,7 @@ def test_dataloader(): return datamodule def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule': - obj = super(LightningDataModule, cls).__new__(cls) + obj = super().__new__(cls) # save `args` and `kwargs` for `__reduce__` obj.__args = args obj.__kwargs = kwargs From 905a0bf7e15b5363fe1db5bd57f8d74c8c80baab Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Apr 2021 15:26:06 +0200 Subject: [PATCH 4/4] Replace `__reduce__` with `__getstate__` --- pytorch_lightning/core/datamodule.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 54af4090469ed..a9add9fa47a71 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -328,9 +328,6 @@ def test_dataloader(): def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule': obj = super().__new__(cls) - # save `args` and `kwargs` for `__reduce__` - obj.__args = args - obj.__kwargs = kwargs # track `DataHooks` calls and run `prepare_data` only on rank zero obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data)) obj.setup = cls._track_data_hook_calls(obj, obj.setup) @@ -381,6 +378,9 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: return wrapped_fn - def __reduce__(self) -> Tuple[type, tuple, dict]: + def 
__getstate__(self) -> dict: # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same object as <...> - return self.__class__, self.__args, self.__kwargs + d = self.__dict__.copy() + for fn in ("prepare_data", "setup", "teardown"): + del d[fn] + return d