Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Rewrote `accelerator_connector` ([#11448](https://github.com/PyTorchLightning/pytorch-lightning/pull/11448))


- Disabled loading dataloaders if the corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576))


### Deprecated

- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
Expand Down Expand Up @@ -672,6 +676,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed passing `_ddp_params_and_buffers_to_ignore` ([#11949](https://github.com/PyTorchLightning/pytorch-lightning/pull/11949))


- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))


- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406))


- Fixed a spurious warning about overridden dataloader hooks being raised when no hook was actually overridden ([#12131](https://github.com/PyTorchLightning/pytorch-lightning/pull/12131))


## [1.5.10] - 2022-02-08

### Fixed
Expand All @@ -688,12 +701,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bug where the path for best checkpoints was not getting saved correctly when no metric was monitored which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))


- Fixed an `AttributeError` when calling `save_hyperparameters` and no parameters need saving ([#11827](https://github.com/PyTorchLightning/pytorch-lightning/pull/11827))


- Fixed environment variable priority for global rank determination ([#11406](https://github.com/PyTorchLightning/pytorch-lightning/pull/11406))


## [1.5.9] - 2022-01-20

### Fixed
Expand All @@ -708,9 +715,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))


- Disabled loading dataloaders if the corresponding `limit_batches=0` ([#11576](https://github.com/PyTorchLightning/pytorch-lightning/pull/11576))


## [1.5.8] - 2022-01-05

### Fixed
Expand Down
9 changes: 5 additions & 4 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,9 @@ def get_hook(self, hook_name: str) -> Callable:
)
return getattr(self.datamodule, hook_name)

warning_cache.warn(
f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
" `LightningDataModule`. It will use the implementation from `LightningModule` instance."
)
if is_overridden(hook_name, self.model):
warning_cache.warn(
f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
" `LightningDataModule`. It will use the implementation from `LightningModule` instance."
)
return getattr(self.model, hook_name)
13 changes: 6 additions & 7 deletions tests/trainer/connectors/test_data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests.deprecated_api import no_warning_call
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.boring_model import RandomDataset
from tests.helpers.utils import no_warning_call


class NoDataLoaderModel(BoringModel):
Expand Down Expand Up @@ -80,28 +80,29 @@ def overridden_func(self, batch, *args, **kwargs):
return batch

def reset_instances(self):
warning_cache.clear()
return BoringDataModule(), BoringModel(), Trainer()

def test_no_datamodule_no_overridden(self, hook_name):
model, _, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=None)
with no_warning_call(match="have overridden `{hook_name}` in both"):
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)

assert hook == getattr(model, hook_name)

def test_with_datamodule_no_overridden(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
with no_warning_call(match="have overridden `{hook_name}` in both"):
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)

assert hook == getattr(model, hook_name)

def test_override_model_hook(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
with no_warning_call(match="have overridden `{hook_name}` in both"):
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)

assert hook == getattr(model, hook_name)
Expand All @@ -110,7 +111,7 @@ def test_override_datamodule_hook(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
setattr(dm, hook_name, self.overridden_func)
with no_warning_call(match="have overridden `{hook_name}` in both"):
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)

assert hook == getattr(dm, hook_name)
Expand All @@ -123,7 +124,6 @@ def test_override_both_model_and_datamodule(self, hook_name):
with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)

warning_cache.clear()
assert hook == getattr(dm, hook_name)

def test_with_datamodule_override_model(self, hook_name):
Expand All @@ -133,7 +133,6 @@ def test_with_datamodule_override_model(self, hook_name):
with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)

warning_cache.clear()
assert hook == getattr(model, hook_name)


Expand Down