diff --git a/CHANGELOG.md b/CHANGELOG.md
index d29c6d6d5c7d7..8be044ba41fbd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
 
+- Deprecated `on_init_start` and `on_init_end` callback hooks ([#10940](https://github.com/PyTorchLightning/pytorch-lightning/pull/10940))
+
+
 - Deprecated `Trainer.call_hook` in favor of `Trainer._call_callback_hooks`, `Trainer._call_lightning_module_hook`, `Trainer._call_ttp_hook`, and `Trainer._call_accelerator_hook` ([#10979](https://github.com/PyTorchLightning/pytorch-lightning/pull/10979))
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index 7f8e2a8585920..d872dd3210754 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -46,23 +46,15 @@ Example:
 
     class MyPrintingCallback(Callback):
-        def on_init_start(self, trainer):
-            print("Starting to initialize the trainer!")
-
-        def on_init_end(self, trainer):
-            print("trainer is initialized now")
+        def on_train_start(self, trainer, pl_module):
+            print("Training is starting")
 
         def on_train_end(self, trainer, pl_module):
-            print("do something when training ends")
+            print("Training is ending")
 
 
     trainer = Trainer(callbacks=[MyPrintingCallback()])
 
-.. testoutput::
-
-    Starting to initialize the trainer!
-    trainer is initialized now
-
 We successfully extended functionality without polluting our super clean
 :doc:`lightning module <../common/lightning_module>` research code.
diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst
index 41eb0b974cc2c..01275530684e2 100644
--- a/docs/source/starter/introduction_guide.rst
+++ b/docs/source/starter/introduction_guide.rst
@@ -956,14 +956,11 @@ for hooks that you might care about
 
     class MyPrintingCallback(Callback):
-        def on_init_start(self, trainer):
-            print("Starting to init trainer!")
-
-        def on_init_end(self, trainer):
-            print("Trainer is init now")
+        def on_train_start(self, trainer, pl_module):
+            print("Training is starting")
 
         def on_train_end(self, trainer, pl_module):
-            print("do something when training ends")
+            print("Training is ending")
 
 And pass the callbacks into the trainer
@@ -971,12 +968,6 @@ And pass the callbacks into the trainer
 
     trainer = Trainer(callbacks=[MyPrintingCallback()])
 
-.. testoutput::
-    :hide:
-
-    Starting to init trainer!
-    Trainer is init now
-
 .. tip::
     See full list of 12+ hooks in the :doc:`callbacks <../extensions/callbacks>`.
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 42d5baf9c036b..df61f90d84d67 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -73,11 +73,21 @@ def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage
         pass
 
     def on_init_start(self, trainer: "pl.Trainer") -> None:
-        """Called when the trainer initialization begins, model has not yet been set."""
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8.
+
+        Called when the trainer initialization begins, model has not yet been set.
+        """
         pass
 
     def on_init_end(self, trainer: "pl.Trainer") -> None:
-        """Called when the trainer initialization ends, model has not yet been set."""
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8.
+
+        Called when the trainer initialization ends, model has not yet been set.
+        """
         pass
 
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
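For user callbacks that still implement the deprecated hooks, the docs changes above point at the replacement. A minimal migration sketch (the callback name and prints are illustrative, not from this PR): `setup` is a natural substitute because it runs at the start of fit/validate/test/predict rather than inside `Trainer.__init__`, and unlike the deprecated hooks it also receives the `LightningModule` and the stage.

```python
from pytorch_lightning import Callback, Trainer


class MyMigratedCallback(Callback):
    # Previously: `def on_init_start(self, trainer)` / `def on_init_end(self, trainer)`.
    # `setup` runs once the trainer has a model attached, before any loops start.
    def setup(self, trainer, pl_module, stage=None):
        print(f"Trainer is set up for stage={stage}")

    def on_train_start(self, trainer, pl_module):
        print("Training is starting")


trainer = Trainer(callbacks=[MyMigratedCallback()])
```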
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index e292cd961711a..a8eb484ac4b34 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -125,7 +125,7 @@ def __init__(
     def state_key(self) -> str:
         return self._generate_state_key(monitor=self.monitor, mode=self.mode)
 
-    def on_init_end(self, trainer: "pl.Trainer") -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         if self._check_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 33f872f3a9f9b..cd307d18bc03a 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -248,14 +248,13 @@ def state_key(self) -> str:
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )
 
-    def on_init_end(self, trainer: "pl.Trainer") -> None:
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """When pretrain routine starts we build the ckpt dir on the fly."""
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end
             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """When pretrain routine starts we build the ckpt dir on the fly."""
         self.__resolve_ckpt_dir(trainer)
         if trainer.is_global_zero:
             self.__warn_if_dir_not_empty(self.dirpath)
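Both built-in callbacks above lean on hook ordering inside `fit`: `setup` fires before `on_pretrain_routine_start`, which fires before training begins, so the resolution logic that moved out of `on_init_end` still runs before it is needed. A small probe (hypothetical callback name, for illustration only) can confirm that ordering on a given Lightning version:

```python
from pytorch_lightning import Callback


class HookOrderProbe(Callback):
    """Prints the order in which the relevant hooks fire during `trainer.fit`."""

    def setup(self, trainer, pl_module, stage=None):
        print("1. setup")

    def on_pretrain_routine_start(self, trainer, pl_module):
        print("2. on_pretrain_routine_start")

    def on_train_start(self, trainer, pl_module):
        print("3. on_train_start")
```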
diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index 07cc3136fc7e2..1ec3c2b9a33f1 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Union
+from typing import Dict, Optional, Union
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
@@ -152,7 +152,7 @@ def print(self, *args, **kwargs):
         """You should provide a way to print without breaking the progress bar."""
         print(*args, **kwargs)
 
-    def on_init_end(self, trainer):
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         self._trainer = trainer
 
     def on_train_start(self, trainer, pl_module):
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 8f61a050cfdf7..a917fc2594674 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -43,7 +43,7 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
     __verify_dp_batch_transfer_support(trainer, model)
     _check_add_get_queue(model)
-    # TODO(@daniellepintz): Delete _check_progress_bar in v1.7
+    # TODO: Delete _check_progress_bar in v1.7
     _check_progress_bar(model)
     # TODO: Delete _check_on_post_move_to_device in v1.7
     _check_on_post_move_to_device(model)
@@ -51,6 +51,8 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
     _check_on_keyboard_interrupt(trainer)
     # TODO: Remove this in v1.7 (deprecation: #9816)
     _check_dl_idx_in_on_train_batch_hooks(trainer, model)
+    # TODO: Remove this in v1.8
+    _check_on_init_start_end(trainer)
 
 
 def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -290,3 +292,14 @@ def _check_dl_idx_in_on_train_batch_hooks(trainer: "pl.Trainer", model: "pl.Ligh
             f"Base `Callback.{hook}` hook signature has changed in v1.5."
             " The `dataloader_idx` argument will be removed in v1.7."
         )
+
+
+def _check_on_init_start_end(trainer: "pl.Trainer") -> None:
+    """Checks if on_init_start/end are overridden and sends a deprecation warning."""
+    for callback in trainer.callbacks:
+        if is_overridden(method_name="on_init_start", instance=callback):
+            rank_zero_deprecation(
+                "The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
+            )
+        if is_overridden(method_name="on_init_end", instance=callback):
+            rank_zero_deprecation("The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8.")
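The new `_check_on_init_start_end` validator relies on `is_overridden` to detect user overrides. Roughly, it asks whether the method resolved on the callback's class differs from the default defined on the `Callback` base class; a simplified sketch of that idea (not the library's exact implementation, which does additional bookkeeping):

```python
from pytorch_lightning import Callback


def is_overridden_sketch(method_name: str, instance: Callback) -> bool:
    # If the attribute looked up on the instance's class is not the very same
    # function object as the one on the base class, the user overrode it.
    return getattr(type(instance), method_name, None) is not getattr(Callback, method_name, None)
```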
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 7687b5ecf148c..d6a5910a630db 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1537,6 +1537,7 @@ def _call_callback_hooks(
         *args: Any,
         **kwargs: Any,
     ) -> None:
+        # TODO: remove if block in v1.8
         if hook_name in ("on_init_start", "on_init_end"):
             # these `Callback` hooks are the only ones that do not take a lightning module.
             # we also don't profile bc profiler hasn't been set yet
@@ -1551,7 +1552,7 @@ def _call_callback_hooks(
             prev_fx_name = pl_module._current_fx_name
             pl_module._current_fx_name = hook_name
 
-        # TODO: remove if statement in v1.7
+        # TODO: remove if block in v1.7
         if hook_name in ("on_train_batch_start", "on_train_batch_end"):
             fn = getattr(self, hook_name)
             if callable(fn):
diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py
index ba66ad169f473..057d4dc7421bb 100644
--- a/tests/callbacks/test_tqdm_progress_bar.py
+++ b/tests/callbacks/test_tqdm_progress_bar.py
@@ -81,9 +81,6 @@ def test_tqdm_progress_bar_totals(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
     bar = trainer.progress_bar_callback
-    assert float("inf") == bar.total_train_batches
-    assert 0 == bar.total_val_batches
-    assert 0 == bar.total_test_batches
 
     trainer.fit(model)
@@ -584,7 +581,7 @@ def test_tqdm_progress_bar_main_bar_resume():
     trainer.num_val_batches = [3]
     trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3
 
-    bar.on_init_end(trainer)
+    bar.setup(trainer, model)
     bar.on_train_start(trainer, model)
     bar.on_train_epoch_start(trainer, model)
diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py
index e58e7927641c3..e68cce4dfbaa1 100644
--- a/tests/deprecated_api/test_remove_1-8.py
+++ b/tests/deprecated_api/test_remove_1-8.py
@@ -16,10 +16,11 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import DeviceType, DistributedType
 from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
+from tests.helpers.boring_model import BoringModel
 from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
 
 
@@ -44,6 +45,34 @@ def test_v1_8_0_deprecated_torchtext_batch():
         _ = move_data_to_device(batch=batch, device=torch.device("cpu"))
 
 
+def test_v1_8_0_on_init_start_end(tmpdir):
+    class TestCallback(Callback):
+        def on_init_start(self, trainer):
+            print("Starting to init trainer!")
+
+        def on_init_end(self, trainer):
+            print("Trainer is init now")
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        callbacks=[TestCallback()],
+        max_epochs=1,
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+    )
+    with pytest.deprecated_call(
+        match="The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.fit(model)
+    with pytest.deprecated_call(
+        match="The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.validate(model)
+
+
 def test_v1_8_0_deprecated_call_hook():
     trainer = Trainer(
         max_epochs=1,
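The `if` block kept in `Trainer._call_callback_hooks` (tagged for removal in v1.8) exists because `on_init_start`/`on_init_end` fire from the `Trainer` constructor, before a `LightningModule` is attached or the profiler is configured. A condensed sketch of that dispatch shape, as the hunk comments describe it (simplified pseudocode of the control flow, not the file's verbatim contents):

```python
def _call_callback_hooks(trainer, hook_name, *args, **kwargs):
    # TODO-style special case: these two hooks go away together in v1.8.
    if hook_name in ("on_init_start", "on_init_end"):
        # Called from Trainer.__init__: no LightningModule yet, and the
        # profiler has not been set up, so invoke the callbacks directly.
        for callback in trainer.callbacks:
            getattr(callback, hook_name)(trainer)
        return

    # Every other callback hook receives the LightningModule and is profiled.
    for callback in trainer.callbacks:
        fn = getattr(callback, hook_name, None)
        if callable(fn):
            with trainer.profiler.profile(f"[Callback]{hook_name}"):
                fn(trainer, trainer.lightning_module, *args, **kwargs)
```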