diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e13686f9018d..d795ff3970a88 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -314,6 +314,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Rewrote `accelerator_connector` ([#11448](https://github.com/PyTorchLightning/pytorch-lightning/pull/11448)) + +- Disable loading dataloades if corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576)) + + ### Deprecated - Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141)) @@ -672,6 +676,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed passing `_ddp_params_and_buffers_to_ignore` ([#11949](https://github.com/PyTorchLightning/pytorch-lightning/pull/11949)) +- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827)) + + +- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406)) + + +- Fixed to avoid common hook warning if no hook is overridden ([#12131](https://github.com/PyTorchLightning/pytorch-lightning/pull/12131)) + + ## [1.5.10] - 2022-02-08 ### Fixed @@ -688,12 +701,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481)) -- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827)) - - -- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406)) - - ## [1.5.9] - 2022-01-20 ### Fixed @@ -708,9 +715,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507)) -- Disable loading dataloades if corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576)) - - ## [1.5.8] - 2022-01-05 ### Fixed diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index b79b095feca94..b0cf6a95fac35 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -599,8 +599,9 @@ def get_hook(self, hook_name: str) -> Callable: ) return getattr(self.datamodule, hook_name) - warning_cache.warn( - f"You have overridden `{hook_name}` in `LightningModule` but have passed in a" - " `LightningDataModule`. It will use the implementation from `LightningModule` instance." - ) + if is_overridden(hook_name, self.model): + warning_cache.warn( + f"You have overridden `{hook_name}` in `LightningModule` but have passed in a" + " `LightningDataModule`. It will use the implementation from `LightningModule` instance." + ) return getattr(self.model, hook_name) diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py index bb618dfa091e9..e22e846600122 100644 --- a/tests/trainer/connectors/test_data_connector.py +++ b/tests/trainer/connectors/test_data_connector.py @@ -20,9 +20,9 @@ from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.warnings import PossibleUserWarning -from tests.deprecated_api import no_warning_call from tests.helpers import BoringDataModule, BoringModel from tests.helpers.boring_model import RandomDataset +from tests.helpers.utils import no_warning_call class NoDataLoaderModel(BoringModel): @@ -80,12 +80,13 @@ def overridden_func(self, batch, *args, **kwargs): return batch def reset_instances(self): + warning_cache.clear() return BoringDataModule(), BoringModel(), Trainer() def test_no_datamodule_no_overridden(self, hook_name): model, _, trainer = self.reset_instances() trainer._data_connector.attach_datamodule(model, datamodule=None) - with no_warning_call(match="have overridden `{hook_name}` in both"): + with no_warning_call(match=f"have overridden `{hook_name}` in"): hook = trainer._data_connector._datahook_selector.get_hook(hook_name) assert hook == getattr(model, hook_name) @@ -93,7 +94,7 @@ def test_no_datamodule_no_overridden(self, hook_name): def test_with_datamodule_no_overridden(self, hook_name): model, dm, trainer = self.reset_instances() trainer._data_connector.attach_datamodule(model, datamodule=dm) - with no_warning_call(match="have overridden `{hook_name}` in both"): + with no_warning_call(match=f"have overridden `{hook_name}` in"): hook = trainer._data_connector._datahook_selector.get_hook(hook_name) assert hook == getattr(model, hook_name) @@ -101,7 +102,7 @@ def test_with_datamodule_no_overridden(self, hook_name): def test_override_model_hook(self, hook_name): model, dm, trainer = self.reset_instances() trainer._data_connector.attach_datamodule(model, datamodule=dm) - with no_warning_call(match="have overridden `{hook_name}` in both"): + with no_warning_call(match=f"have overridden `{hook_name}` in"): hook = trainer._data_connector._datahook_selector.get_hook(hook_name) assert hook == getattr(model, hook_name) @@ -110,7 +111,7 @@ def test_override_datamodule_hook(self, hook_name): model, dm, trainer = self.reset_instances() trainer._data_connector.attach_datamodule(model, datamodule=dm) setattr(dm, hook_name, self.overridden_func) - with no_warning_call(match="have overridden `{hook_name}` in both"): + with no_warning_call(match=f"have overridden `{hook_name}` in"): hook = trainer._data_connector._datahook_selector.get_hook(hook_name) assert hook == getattr(dm, hook_name) @@ -123,7 +124,6 @@ def test_override_both_model_and_datamodule(self, hook_name): with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"): hook = trainer._data_connector._datahook_selector.get_hook(hook_name) - warning_cache.clear() assert hook == getattr(dm, hook_name) def test_with_datamodule_override_model(self, hook_name): @@ -133,7 +133,6 @@ def test_with_datamodule_override_model(self, hook_name): with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"): hook = trainer._data_connector._datahook_selector.get_hook(hook_name) - warning_cache.clear() assert hook == getattr(model, hook_name)