diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 6d67d2d58643a..409d3f51bd46f 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added prefix to log message in `seed_everything` with rank info ([#13290](https://github.com/Lightning-AI/lightning/issues/13290))
 
+- Added profiling to these hooks: `on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`, `configure_gradient_clipping`, `clip_gradients` ([#14069](https://github.com/Lightning-AI/lightning/pull/14069))
+
+
 -
diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index f58503edd88cb..612bcc72d2806 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -37,7 +37,6 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.loggers import Logger, LoggerCollection
-from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
@@ -291,16 +290,24 @@ def _apply_batch_transfer_handler(
         self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
     ) -> Any:
         device = device or self.device
-        datahook_selector = (
-            _DataHookSelector(self, None) if self._trainer is None else self.trainer._data_connector._datahook_selector
-        )
-        hook = datahook_selector.get_hook("on_before_batch_transfer")
-        batch = hook(batch, dataloader_idx)
-        hook = datahook_selector.get_hook("transfer_batch_to_device")
-        batch = hook(batch, device, dataloader_idx)
-        hook = datahook_selector.get_hook("on_after_batch_transfer")
-        batch = hook(batch, dataloader_idx)
+        def call_hook(hook_name, *args):
+            if self._trainer:
+                datahook_selector = self._trainer._data_connector._datahook_selector
+                obj = datahook_selector.get_instance(hook_name)
+                trainer_method = (
+                    self._trainer._call_lightning_module_hook
+                    if isinstance(obj, self.__class__)
+                    else self._trainer._call_lightning_datamodule_hook
+                )
+                return trainer_method(hook_name, *args)
+            else:
+                hook = getattr(self, hook_name)
+                return hook(*args)
+
+        batch = call_hook("on_before_batch_transfer", batch, dataloader_idx)
+        batch = call_hook("transfer_batch_to_device", batch, device, dataloader_idx)
+        batch = call_hook("on_after_batch_transfer", batch, dataloader_idx)
         return batch
 
     def print(self, *args, **kwargs) -> None:
diff --git a/src/pytorch_lightning/plugins/precision/precision_plugin.py b/src/pytorch_lightning/plugins/precision/precision_plugin.py
index 60dfb1ab6c92f..285a0f31e3955 100644
--- a/src/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/src/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -182,7 +182,9 @@ def _clip_gradients(
         if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
            # the configuration validator disallows clipping on manual
             return
-        model.configure_gradient_clipping(
+
+        model.trainer._call_lightning_module_hook(
+            "configure_gradient_clipping",
             optimizer,
             optimizer_idx,
             gradient_clip_val=clip_val,
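With the precision-plugin change above, `configure_gradient_clipping` is now dispatched through `trainer._call_lightning_module_hook(...)`, so a user override is profiled like any other hook. Below is a minimal, illustrative sketch of such an override against the 1.7-era signature visible in this diff (`optimizer_idx` still present); the `BoringModel` import path is an assumption and may differ between versions.

from pytorch_lightning.demos.boring_classes import BoringModel


class ClippedModel(BoringModel):
    def configure_gradient_clipping(
        self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None
    ):
        # Delegate to the built-in helper; with this change both this hook and the
        # `clip_gradients` call it makes are timed by the trainer's profiler.
        self.clip_gradients(
            optimizer,
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )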
diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py
index e1aca404722db..1de8bee90d18f 100644
--- a/src/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -14,7 +14,7 @@
 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Callable, Collection, List, Optional, Tuple, Union
+from typing import Any, Collection, List, Optional, Tuple, Union
 from weakref import proxy
 
 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
@@ -527,16 +527,16 @@ def is_module(self) -> bool:
 
 @dataclass
 class _DataHookSelector:
-    """Stores the info about the shared DataHooks within LightningModule and LightningDataModule.
+    """Stores the info about the shared DataHooks within ``LightningModule`` and ``LightningDataModule``.
 
-    The hook source can be
+    The hook source can be:
 
-    1. a method from the :class:`~pytorch_lightning.core.module.LightningModule`,
-    2. a method from the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
+    1. the :class:`~pytorch_lightning.core.module.LightningModule`,
+    2. the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
 
     Arguments:
-        model: A LightningModule
-        datamodule: A LightningDataModule
+        model: A ``LightningModule``
+        datamodule: A ``LightningDataModule``
     """
 
     model: "pl.LightningModule"
@@ -545,7 +545,7 @@ class _DataHookSelector:
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     )
 
-    def get_hook(self, hook_name: str) -> Callable:
+    def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
         if hook_name not in self._valid_hooks:
             raise ValueError(
                 f"`{hook_name}` is not a shared hook within `LightningModule` and `LightningDataModule`."
@@ -553,7 +553,7 @@
             )
 
         if self.datamodule is None:
-            return getattr(self.model, hook_name)
+            return self.model
 
         if is_overridden(hook_name, self.datamodule):
             if is_overridden(hook_name, self.model):
@@ -561,11 +561,11 @@
                     f"You have overridden `{hook_name}` in both `LightningModule` and `LightningDataModule`."
                     " It will use the implementation from `LightningDataModule` instance."
                 )
-            return getattr(self.datamodule, hook_name)
+            return self.datamodule
 
         if is_overridden(hook_name, self.model):
             warning_cache.warn(
                 f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
                 " `LightningDataModule`. It will use the implementation from `LightningModule` instance."
             )
-        return getattr(self.model, hook_name)
+        return self.model
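`get_instance` now returns the object that owns the hook (model or datamodule) instead of a bound method, leaving it to the caller to route the call through the trainer so it can be profiled. The sketch below illustrates the selection rules by using the private `_DataHookSelector` helper directly, the same way the tests further down do; the `BoringModel`/`BoringDataModule` import path is an assumption and may differ by version.

from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector


class OverridingDataModule(BoringDataModule):
    def on_after_batch_transfer(self, batch, dataloader_idx):
        # any override is enough for `is_overridden` to pick the datamodule
        return batch


model = BoringModel()

# No datamodule attached: the LightningModule owns the hook.
assert _DataHookSelector(model, None).get_instance("on_after_batch_transfer") is model

# Datamodule attached but nothing overridden: still the LightningModule.
assert _DataHookSelector(model, BoringDataModule()).get_instance("on_after_batch_transfer") is model

# Datamodule overrides the hook: the datamodule instance is returned.
dm = OverridingDataModule()
assert _DataHookSelector(model, dm).get_instance("on_after_batch_transfer") is dm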
diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
index 6f60ba6f1aa2f..56ad53ef4ba04 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -44,6 +44,8 @@ class _LogOptions(TypedDict):
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
         "lr_scheduler_step": None,
+        "configure_gradient_clipping": None,
+        "clip_gradients": None,
         "on_before_zero_grad": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
@@ -98,6 +100,9 @@ class _LogOptions(TypedDict):
         "on_epoch_end": _LogOptions(
             allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
         ),
+        "on_before_batch_transfer": None,
+        "transfer_batch_to_device": None,
+        "on_after_batch_transfer": None,
         "on_batch_start": _LogOptions(
             allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
         ),
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index 2650e46b7fa60..7273d7719834e 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -471,34 +471,34 @@ def test_no_datamodule_no_overridden(self, hook_name):
         model, _, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=None)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(model, hook_name)
+        assert instance is model
 
     def test_with_datamodule_no_overridden(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(model, hook_name)
+        assert instance is model
 
     def test_override_model_hook(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(model, hook_name)
+        assert instance is model
 
     def test_override_datamodule_hook(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(dm, hook_name, self.overridden_func)
         with no_warning_call(match=f"have overridden `{hook_name}` in"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(dm, hook_name)
+        assert instance is dm
 
     def test_override_both_model_and_datamodule(self, hook_name):
         model, dm, trainer = self.reset_instances()
@@ -506,24 +506,24 @@ def test_override_both_model_and_datamodule(self, hook_name):
         setattr(model, hook_name, self.overridden_func)
         setattr(dm, hook_name, self.overridden_func)
         with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(dm, hook_name)
+        assert instance is dm
 
     def test_with_datamodule_override_model(self, hook_name):
         model, dm, trainer = self.reset_instances()
         trainer._data_connector.attach_datamodule(model, datamodule=dm)
         setattr(model, hook_name, self.overridden_func)
         with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"):
-            hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
+            instance = trainer._data_connector._datahook_selector.get_instance(hook_name)
 
-        assert hook == getattr(model, hook_name)
+        assert instance is model
 
 
 def test_invalid_hook_passed_in_datahook_selector():
     dh_selector = _DataHookSelector(BoringModel(), None)
     with pytest.raises(ValueError, match="is not a shared hook"):
-        dh_selector.get_hook("setup")
+        dh_selector.get_instance("setup")
 
 
 def test_eval_distributed_sampler_warning(tmpdir):
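The fx_validator entries above register the new hook names but map them to `None`, meaning `self.log` is still not supported inside them; now that the hooks run through the trainer, the validator can enforce that consistently, which is what the test changes in the next file check. A hedged sketch of the user-visible behaviour, assuming the usual `MisconfigurationException` with a "You can't" message and the `BoringModel` import path used elsewhere in this diff:

import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LogInTransferModel(BoringModel):
    def on_before_batch_transfer(self, batch, dataloader_idx):
        # The hook is registered with `None` in the validator, so logging here is rejected.
        self.log("foo", 1.0)
        return batch


def test_logging_in_batch_transfer_hook_raises(tmp_path):
    trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=1)
    with pytest.raises(MisconfigurationException, match="You can't"):
        trainer.fit(LogInTransferModel())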
diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
index 760e8eea2a85c..c2be22c61244b 100644
--- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
+++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -187,11 +187,6 @@ def __init__(self, not_supported):
         {
             "log",
             "log_dict",
-            # the following are problematic as they do have `self._current_fx_name` defined some times but
-            # not others depending on where they were called. So we cannot reliably `self.log` in them
-            "on_before_batch_transfer",
-            "transfer_batch_to_device",
-            "on_after_batch_transfer",
         }
     )
     # remove `nn.Module` hooks
@@ -227,6 +222,9 @@ def test_fx_validator_integration(tmpdir):
         "on_pretrain_routine_end": "You can't",
         "train_dataloader": "You can't",
         "val_dataloader": "You can't",
+        "on_before_batch_transfer": "You can't",
+        "transfer_batch_to_device": "You can't",
+        "on_after_batch_transfer": "You can't",
         "on_validation_end": "You can't",
         "on_train_end": "You can't",
         "on_fit_end": "You can't",
@@ -238,6 +236,8 @@ def test_fx_validator_integration(tmpdir):
         "on_validation_model_eval": "You can't",
         "on_validation_model_train": "You can't",
         "lr_scheduler_step": "You can't",
+        "configure_gradient_clipping": "You can't",
+        "clip_gradients": "You can't",
         "on_save_checkpoint": "You can't",
         "on_load_checkpoint": "You can't",
         "on_exception": "You can't",
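Taken together, and as the CHANGELOG entry at the top states, the practical effect of this diff is that the five hooks are now wrapped by the profiler. A rough way to observe this, assuming the standard `Trainer` API and the `BoringModel` import path used here (report formatting varies by version):

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=2,
    limit_val_batches=2,
    gradient_clip_val=0.5,  # ensure the clipping hooks actually run
    profiler="simple",
)
trainer.fit(model)
# The "simple" profiler summary printed after `fit` should now contain rows for
# `on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`,
# `configure_gradient_clipping` and `clip_gradients`, which previously were not profiled.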