From df96be7591c040b06070db9d73e6ad0284c9e204 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 27 Apr 2021 16:36:54 +0200
Subject: [PATCH 01/11] Automatically check
 `DataModule.has_{setup,teardown,prepare_data}`

---
 pytorch_lightning/core/datamodule.py | 10 +++++++++-
 pytorch_lightning/trainer/trainer.py |  8 ++------
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index 79aa6dc40b5d8..df77db14a96d2 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -84,15 +84,23 @@ def wrapped_fn(*args, **kwargs):
             stage = args[1] if len(args) > 1 else kwargs.get("stage", None)
 
             if stage is None:
+                has_run = True
                 for s in ("fit", "validate", "test"):
+                    has_run &= getattr(obj, f"_has_{name}_{s}")
                     setattr(obj, f"_has_{name}_{s}", True)
             else:
+                has_run = getattr(obj, f"_has_{name}_{stage}")
                 setattr(obj, f"_has_{name}_{stage}", True)
 
         elif name == "prepare_data":
+            has_run = obj._has_prepared_data
             obj._has_prepared_data = True
 
-        return fn(*args, **kwargs)
+        else:
+            raise ValueError(name)
+
+        if not has_run:
+            return fn(*args, **kwargs)
 
     return wrapped_fn
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 3c1cdc0a42748..f32075aff6c8f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1115,9 +1115,7 @@ def call_setup_hook(self, model: LightningModule) -> None:
         state = self._setup_state
 
         if self.datamodule is not None:
-            called = getattr(self.datamodule, f'has_setup_{state}')
-            if not called:
-                self.datamodule.setup(stage=state)
+            self.datamodule.setup(stage=state)
 
         self.setup(model, stage=state)
         model.setup(stage=state)
@@ -1139,9 +1137,7 @@ def call_teardown_hook(self, model: LightningModule) -> None:
         state = self._teardown_state
 
         if self.datamodule is not None:
-            called = getattr(self.datamodule, f'has_teardown_{state}')
-            if not called:
-                self.datamodule.teardown(stage=state)
+            self.datamodule.teardown(stage=state)
 
         self.profiler.teardown(stage=state)
         self.teardown(stage=state)

From 204741bc23bc221ac1a9b17689240c385ead221f Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 27 Apr 2021 16:39:26 +0200
Subject: [PATCH 02/11] Use variable

---
 pytorch_lightning/core/datamodule.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index df77db14a96d2..2b7c37136a9ca 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -86,11 +86,13 @@ def wrapped_fn(*args, **kwargs):
             if stage is None:
                 has_run = True
                 for s in ("fit", "validate", "test"):
-                    has_run &= getattr(obj, f"_has_{name}_{s}")
-                    setattr(obj, f"_has_{name}_{s}", True)
+                    attr = f"_has_{name}_{s}"
+                    has_run &= getattr(obj, attr)
+                    setattr(obj, attr, True)
             else:
-                has_run = getattr(obj, f"_has_{name}_{stage}")
-                setattr(obj, f"_has_{name}_{stage}", True)
+                attr = f"_has_{name}_{stage}"
+                has_run = getattr(obj, attr)
+                setattr(obj, attr, True)
 
         elif name == "prepare_data":
             has_run = obj._has_prepared_data
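The guard introduced by the two commits above boils down to one pattern: record a per-stage ``_has_<hook>_<stage>`` flag the first time a hook runs, and make any later call for an already-seen stage a no-op. A minimal, framework-free sketch of that pattern (the names ``track_calls`` and ``DemoDataModule`` are illustrative only, not part of Lightning):

.. code-block:: python

    import functools


    def track_calls(fn):
        """Record a `_has_<hook>_<stage>` flag per stage and skip repeated calls."""

        @functools.wraps(fn)
        def wrapped(obj, stage=None):
            stages = ("fit", "validate", "test") if stage is None else (stage,)
            has_run = all(getattr(obj, f"_has_{fn.__name__}_{s}", False) for s in stages)
            for s in stages:
                setattr(obj, f"_has_{fn.__name__}_{s}", True)
            if not has_run:
                return fn(obj, stage)

        return wrapped


    class DemoDataModule:

        @track_calls
        def setup(self, stage=None):
            print(f"setup(stage={stage})")


    dm = DemoDataModule()
    dm.setup("fit")  # runs
    dm.setup("fit")  # no-op: the 'fit' flag is already set
    dm.setup()       # runs once more, because 'validate' and 'test' were never set up

PATCH 02 above is only a readability tweak on the same logic: the flag name is computed once into ``attr`` before it is read and then set.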
From a886e7e99f189e2f46b831ee9abe26e4984be2ac Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 27 Apr 2021 16:57:53 +0200
Subject: [PATCH 03/11] Spacing

---
 pytorch_lightning/trainer/trainer.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f32075aff6c8f..caba3bf406ed2 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1116,7 +1116,6 @@ def call_setup_hook(self, model: LightningModule) -> None:
 
         if self.datamodule is not None:
             self.datamodule.setup(stage=state)
-
         self.setup(model, stage=state)
         model.setup(stage=state)
 
@@ -1138,7 +1137,6 @@ def call_teardown_hook(self, model: LightningModule) -> None:
 
         if self.datamodule is not None:
             self.datamodule.teardown(stage=state)
-
         self.profiler.teardown(stage=state)
         self.teardown(stage=state)
         model.teardown(stage=state)

From e4c7a747051eb5807d1c1c4085a4ee2b18a70ea2 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 27 Apr 2021 18:07:38 +0200
Subject: [PATCH 04/11] Docs

---
 docs/source/extensions/datamodules.rst     | 28 ++++++++++++----------
 docs/source/starter/introduction_guide.rst |  2 --
 docs/source/starter/new-project.rst        |  4 ++--
 pytorch_lightning/core/hooks.py            |  2 +-
 4 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst
index a602a75b0f877..978d9ffb07b7e 100644
--- a/docs/source/extensions/datamodules.rst
+++ b/docs/source/extensions/datamodules.rst
@@ -168,10 +168,6 @@ Here's a more realistic, complex DataModule that shows how much more reusable th
     def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=32)
 
-
-.. note:: ``setup`` expects a string arg ``stage``. It is used to separate setup logic for ``trainer.fit`` and ``trainer.test``.
-
-
 ---------------
 
 LightningDataModule API
@@ -228,7 +224,7 @@ There are also data operations you might want to perform on every GPU. Use setup
     def setup(self, stage: Optional[str] = None):
 
         # Assign Train/val split(s) for use in Dataloaders
-        if stage == 'fit' or stage is None:
+        if stage in (None, 'fit'):
             mnist_full = MNIST(
                 self.data_dir,
                 train=True,
@@ -239,7 +235,7 @@ There are also data operations you might want to perform on every GPU. Use setup
             self.dims = self.mnist_train[0][0].shape
 
         # Assign Test split(s) for use in Dataloaders
-        if stage == 'test' or stage is None:
+        if stage in (None, 'test'):
             self.mnist_test = MNIST(
                 self.data_dir,
                 train=False,
@@ -249,10 +245,17 @@ There are also data operations you might want to perform on every GPU. Use setup
             self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)
 
 
-.. warning:: ``setup`` is called from every process. Setting state here is okay.
-
+:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` expects a ``stage: Optional[str]`` argument.
+It is used to separate setup logic for ``trainer.{fit,validate,test}``. If ``setup`` is called with ``stage = None``,
+we assume all stages have been set-up.
 
+.. note:: ``setup`` is called from every process. Setting state here is okay.
 
 .. note:: ``teardown`` can be used to clean up the state. It is also called from every process
+.. note::
+    ``{setup,teardown,prepare_data}`` call will be only called once for a specific stage.
+    If the stage was ``None`` then we assume ``{fit,validate,test}`` have been called. For example, this means that
+    any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can overwrite
+    ``dm.has_setup_fit = False``
 
 train_dataloader
@@ -396,11 +399,12 @@ The recommended way to use a DataModule is simply:
 
     dm = MNISTDataModule()
     model = Model()
     trainer.fit(model, dm)
-    trainer.test(datamodule=dm)
 
-If you need information from the dataset to build your model, then run `prepare_data` and `setup` manually (Lightning
-still ensures the method runs on the correct devices)
+If you need information from the dataset to build your model, then run
+:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.prepare_data` and
+:meth:`~pytorch_lightning.core.datamodule.LightningDataModule.setup` manually (Lightning ensures
+the method runs on the correct devices).
 
 .. code-block:: python
 
@@ -416,7 +420,7 @@ still ensures the method runs on the correct devices)
 
 ----------------
 
-Datamodules without Lightning
+DataModules without Lightning
 -----------------------------
 
 You can of course use DataModules in plain PyTorch code as well.
diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst
index 8d35e27185649..680a388ee118b 100644
--- a/docs/source/starter/introduction_guide.rst
+++ b/docs/source/starter/introduction_guide.rst
@@ -295,8 +295,6 @@ When your models need to know about the data, it's best to process the data befo
 1. use ``prepare_data()`` to download and process the dataset.
 2. use ``setup()`` to do splits, and build your model internals
 
-|
-
 An alternative to using a DataModule is to defer initialization of the models modules to the ``setup`` method of your LightningModule as follows:
 
 .. testcode::
diff --git a/docs/source/starter/new-project.rst b/docs/source/starter/new-project.rst
index ff2c91580203d..46b1cc4d5d453 100644
--- a/docs/source/starter/new-project.rst
+++ b/docs/source/starter/new-project.rst
@@ -658,10 +658,10 @@ Make your data code reusable by organizing it into a :class:`~pytorch_lightning.
                 transforms.Normalize((0.1307,), (0.3081,))
             ])
             # split dataset
-            if stage == 'fit':
+            if stage in (None, 'fit'):
                 mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
                 self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
-            if stage == 'test':
+            if stage in (None, 'test'):
                 self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)
 
         # return the dataloader for each split
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index b55a8258e03fa..3759929e33e7f 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -389,7 +389,7 @@ def prepare_data(self):
 
     def setup(self, stage: Optional[str] = None) -> None:
         """
-        Called at the beginning of fit (train + validate), validate, test, predict, or tune.
+        Called at the beginning of fit (train + validate), validate, test, and predict.
 
         This is a good hook when you need to build models dynamically or adjust something about them.
         This hook is called on every process when using DDP.

From 3c7c43f4f76e071202c5289fb71434f34d29a8c7 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 27 Apr 2021 18:11:06 +0200
Subject: [PATCH 05/11] Update CHANGELOG

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2c7d7297975f2..6e2427e20bbf5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -152,6 +152,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `pl.seed_everything` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024)) +- `DataModules` now avoid duplicate {setup,teardown,prepare_data} calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238)) + + - Changed default setting for communication of multi-node training using `DDPShardedPlugin` ([#6937](https://github.com/PyTorchLightning/pytorch-lightning/pull/6937)) From 224e831b6336372a93caa6dcdc31eba6d22fa5ea Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Apr 2021 02:41:55 +0200 Subject: [PATCH 06/11] Remove `_DataModuleWrapper` --- pytorch_lightning/core/datamodule.py | 134 ++++++++++++--------------- tests/core/test_datamodules.py | 12 +-- 2 files changed, 66 insertions(+), 80 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index a1f1c02ef498d..b994d08ebd02b 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -24,80 +24,7 @@ from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types -class _DataModuleWrapper(type): - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.__has_added_checks = False - - def __call__(cls, *args, **kwargs): - """A wrapper for LightningDataModule that: - - 1. Runs user defined subclass's __init__ - 2. Assures prepare_data() runs on rank 0 - 3. Lets you check prepare_data and setup to see if they've been called - """ - if not cls.__has_added_checks: - cls.__has_added_checks = True - # Track prepare_data calls and make sure it runs on rank zero - cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data)) - # Track setup calls - cls.setup = track_data_hook_calls(cls.setup) - # Track teardown calls - cls.teardown = track_data_hook_calls(cls.teardown) - - # Get instance of LightningDataModule by mocking its __init__ via __call__ - obj = type.__call__(cls, *args, **kwargs) - - return obj - - -def track_data_hook_calls(fn): - """A decorator that checks if prepare_data/setup/teardown has been called. - - - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True - - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True - - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. - Its corresponding `dm_has_setup_{stage}` attribute gets set to True - - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` - - Args: - fn (function): Function that will be tracked to see if it has been called. - - Returns: - function: Decorated function that tracks its call status and saves it to private attrs in its obj instance. - """ - - @functools.wraps(fn) - def wrapped_fn(*args, **kwargs): - - # The object instance from which setup or prepare_data was called - obj = args[0] - name = fn.__name__ - - # If calling setup, we check the stage and assign stage-specific bool args - if name in ("setup", "teardown"): - - # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit', 'validate', and 'test' to True. 
- # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() - stage = args[1] if len(args) > 1 else kwargs.get("stage", None) - - if stage is None: - for s in ("fit", "validate", "test"): - setattr(obj, f"_has_{name}_{s}", True) - else: - setattr(obj, f"_has_{name}_{stage}", True) - - elif name == "prepare_data": - obj._has_prepared_data = True - - return fn(*args, **kwargs) - - return wrapped_fn - - -class LightningDataModule(CheckpointHooks, DataHooks, metaclass=_DataModuleWrapper): +class LightningDataModule(CheckpointHooks, DataHooks): """ A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models. @@ -398,3 +325,62 @@ def test_dataloader(): if test_dataset is not None: datamodule.test_dataloader = test_dataloader return datamodule + + def __new__(cls, *args: Any, **kwargs: Any) -> 'LightningDataModule': + obj = super(LightningDataModule, cls).__new__(cls) + # save `args` and `kwargs` for `__reduce__` + obj.__args = args + obj.__kwargs = kwargs + # track `DataHooks` calls and run `prepare_data` only on rank zero + obj.prepare_data = cls._track_data_hook_calls(obj, rank_zero_only(obj.prepare_data)) + obj.setup = cls._track_data_hook_calls(obj, obj.setup) + obj.teardown = cls._track_data_hook_calls(obj, obj.teardown) + return obj + + @staticmethod + def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable: + """A decorator that checks if prepare_data/setup/teardown has been called. + + - When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True + - When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True + - When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``. + Its corresponding `dm_has_setup_{stage}` attribute gets set to True + - ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup`` + + Args: + obj: Object whose function will be tracked + fn: Function that will be tracked to see if it has been called. + + Returns: + Decorated function that tracks its call status and saves it to private attrs in its obj instance. + """ + + @functools.wraps(fn) + def wrapped_fn(*args: str, **kwargs: Optional[str]) -> callable: + name = fn.__name__ + + # If calling setup, we check the stage and assign stage-specific bool args + if name in ("setup", "teardown"): + + # Get stage either by grabbing from args or checking kwargs. + # If not provided, set call status of 'fit', 'validate', and 'test' to True. 
+ # We do this so __attach_datamodule in trainer.py doesn't mistakenly call + # setup('test') on trainer.test() + stage = args[0] if len(args) else kwargs.get("stage", None) + + if stage is None: + for s in ("fit", "validate", "test"): + setattr(obj, f"_has_{name}_{s}", True) + else: + setattr(obj, f"_has_{name}_{stage}", True) + + elif name == "prepare_data": + obj._has_prepared_data = True + + return fn(*args, **kwargs) + + return wrapped_fn + + def __reduce__(self) -> Tuple[type, tuple, dict]: + # avoids _pickle.PicklingError: Can't pickle <...>: it's not the same object as <...> + return self.__class__, self.__args, self.__kwargs diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 40c38b9d3af3c..c4eb076e04773 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -91,7 +91,7 @@ def test_can_prepare_data(local_rank, node_rank): assert trainer.data_connector.can_prepare_data() -def test_hooks_no_recursion_error(tmpdir): +def test_hooks_no_recursion_error(): # hooks were appended in cascade every tine a new data module was instantiated leading to a recursion error. # See https://github.com/PyTorchLightning/pytorch-lightning/issues/3652 class DummyDM(LightningDataModule): @@ -108,20 +108,20 @@ def prepare_data(self, *args, **kwargs): dm.prepare_data() -def test_helper_boringdatamodule(tmpdir): +def test_helper_boringdatamodule(): dm = BoringDataModule() dm.prepare_data() dm.setup() -def test_helper_boringdatamodule_with_verbose_setup(tmpdir): +def test_helper_boringdatamodule_with_verbose_setup(): dm = BoringDataModule() dm.prepare_data() dm.setup('fit') dm.setup('test') -def test_data_hooks_called(tmpdir): +def test_data_hooks_called(): dm = BoringDataModule() assert not dm.has_prepared_data assert not dm.has_setup_fit @@ -168,7 +168,7 @@ def test_data_hooks_called(tmpdir): @pytest.mark.parametrize("use_kwarg", (False, True)) -def test_data_hooks_called_verbose(tmpdir, use_kwarg): +def test_data_hooks_called_verbose(use_kwarg): dm = BoringDataModule() dm.prepare_data() assert not dm.has_setup_fit @@ -246,7 +246,7 @@ def test_dm_init_from_argparse_args(tmpdir): assert dm.data_dir == args.data_dir == str(tmpdir) -def test_dm_pickle_after_init(tmpdir): +def test_dm_pickle_after_init(): dm = BoringDataModule() pickle.dumps(dm) From de9dcc1595be5149cf606b119559b61b04bde88f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 30 Apr 2021 03:38:04 +0200 Subject: [PATCH 07/11] Add test --- tests/core/test_datamodules.py | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index c4eb076e04773..37b908bc5acec 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -522,3 +522,46 @@ def test_dm_init_from_datasets_dataloaders(iterable): call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True), call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True) ]) + + +def test_datamodule_hooks_calls(tmpdir): + """Test that repeated calls to DataHooks' hooks have no effect""" + + class TestDataModule(BoringDataModule): + setup_calls = [] + teardown_calls = [] + prepare_data_calls = 0 + + def setup(self, stage=None): + super().setup(stage=stage) + self.setup_calls.append(stage) + + def teardown(self, stage=None): + super().teardown(stage=stage) + self.teardown_calls.append(stage) + + def prepare_data(self): + super().prepare_data() + self.prepare_data_calls += 1 + + dm = 
TestDataModule() + dm.prepare_data() + dm.prepare_data() + dm.setup('fit') + dm.setup('fit') + dm.setup() + dm.setup() + dm.teardown('validate') + dm.teardown('validate') + + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate'] + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + trainer.test(BoringModel(), datamodule=dm) + + # same number of calls + assert dm.prepare_data_calls == 1 + assert dm.setup_calls == ['fit', None] + assert dm.teardown_calls == ['validate', 'test'] From 074308101a8e40c93d10d0db1584bf462144771d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 4 May 2021 12:52:09 +0200 Subject: [PATCH 08/11] Update docs/source/extensions/datamodules.rst --- docs/source/extensions/datamodules.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index 978d9ffb07b7e..fbb19e10a8e1e 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -255,7 +255,7 @@ we assume all stages have been set-up. ``{setup,teardown,prepare_data}`` call will be only called once for a specific stage. If the stage was ``None`` then we assume ``{fit,validate,test}`` have been called. For example, this means that any duplicate ``dm.setup('fit')`` calls will be a no-op. To avoid this, you can overwrite - ``dm.has_setup_fit = False`` + ``dm._has_setup_fit = False`` train_dataloader From acdc61b1347efffd8707407203ebc748b844790e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 4 May 2021 12:57:08 +0200 Subject: [PATCH 09/11] Bad merge --- pytorch_lightning/trainer/trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4e4c448a3911d..6793ce08aed0a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1156,8 +1156,8 @@ def call_setup_hook(self, model: LightningModule) -> None: if self.datamodule is not None: self.datamodule.setup(stage=fn) - self.setup(model, stage=state) - model.setup(stage=state) + self.setup(model, stage=fn) + model.setup(stage=fn) self.accelerator.barrier("post_setup") @@ -1179,9 +1179,9 @@ def call_teardown_hook(self, model: LightningModule) -> None: if self.datamodule is not None: self.datamodule.teardown(stage=fn) - self.profiler.teardown(stage=state) - self.teardown(stage=state) - model.teardown(stage=state) + self.profiler.teardown(stage=fn) + self.teardown(stage=fn) + model.teardown(stage=fn) model._current_fx_name = "" model._current_hook_fx_name = None From 23ab37fb33d9b19f991453a571836f113cb633cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 8 May 2021 01:32:16 +0200 Subject: [PATCH 10/11] add test for invalid name --- tests/core/test_datamodules.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 7cfa569115550..8815549a32239 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -564,3 +564,9 @@ def prepare_data(self): assert dm.prepare_data_calls == 1 assert dm.setup_calls == ['fit', None] assert dm.teardown_calls == ['validate', 'test'] + + with pytest.raises(ValueError, match="sunflower"): + dm.setup("sunflower") + + with pytest.raises(ValueError, match="sunflower"): + dm.teardown("sunflower") From ad8221b09b5e2d7d3df3537aa26299cd399fb548 Mon Sep 17 
00:00:00 2001 From: Carlos Mocholi Date: Mon, 10 May 2021 13:38:48 +0200 Subject: [PATCH 11/11] Remove ValueError --- pytorch_lightning/core/datamodule.py | 4 +--- tests/core/test_datamodules.py | 6 ------ 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 1e85f81310aff..23626ed9cbeae 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -355,6 +355,7 @@ def _track_data_hook_calls(obj: 'LightningDataModule', fn: callable) -> callable @functools.wraps(fn) def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: name = fn.__name__ + has_run = False # If calling setup, we check the stage and assign stage-specific bool args if name in ("setup", "teardown"): @@ -380,9 +381,6 @@ def wrapped_fn(*args: str, **kwargs: Optional[str]) -> Any: has_run = obj._has_prepared_data obj._has_prepared_data = True - else: - raise ValueError(name) - if not has_run: return fn(*args, **kwargs) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 8815549a32239..7cfa569115550 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -564,9 +564,3 @@ def prepare_data(self): assert dm.prepare_data_calls == 1 assert dm.setup_calls == ['fit', None] assert dm.teardown_calls == ['validate', 'test'] - - with pytest.raises(ValueError, match="sunflower"): - dm.setup("sunflower") - - with pytest.raises(ValueError, match="sunflower"): - dm.teardown("sunflower")
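Taken together, the series makes the ``LightningDataModule`` hooks idempotent per stage. A short sketch of the behaviour expected once these commits are applied (``CountingDM`` is an illustrative subclass, not a test taken from the repository):

.. code-block:: python

    from pytorch_lightning import LightningDataModule


    class CountingDM(LightningDataModule):

        def __init__(self):
            super().__init__()
            self.setup_calls = 0

        def setup(self, stage=None):
            self.setup_calls += 1


    dm = CountingDM()
    dm.setup("fit")
    dm.setup("fit")  # duplicate call for the same stage: no-op
    assert dm.setup_calls == 1
    assert dm.has_setup_fit

    dm.setup()  # 'validate' and 'test' were never set up, so this still runs
    assert dm.setup_calls == 2

    dm._has_setup_fit = False  # opt back in, as described in the datamodules docs
    dm.setup("fit")
    assert dm.setup_calls == 3

``teardown`` and ``prepare_data`` follow the same rule, with ``prepare_data`` additionally restricted to rank zero.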