3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
 
 
+- Deprecated `on_init_start` and `on_init_end` callback hooks ([#10940](https://github.com/PyTorchLightning/pytorch-lightning/pull/10940))
+
+
 - Deprecated `Trainer.call_hook` in favor of `Trainer._call_callback_hooks`, `Trainer._call_lightning_module_hook`, `Trainer._call_ttp_hook`, and `Trainer._call_accelerator_hook` ([#10979](https://github.com/PyTorchLightning/pytorch-lightning/pull/10979))
 
 
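Note: the changelog entries above do not name a replacement, but the rest of this PR migrates the built-in callbacks from the init hooks to `setup`, which receives the module and fires at the start of `fit`/`validate`/`test`/`predict` rather than at `Trainer()` construction. A minimal migration sketch (class names are illustrative):

    from pytorch_lightning import Callback

    # Before (deprecated in v1.6, scheduled for removal in v1.8):
    class LegacyCallback(Callback):
        def on_init_start(self, trainer):
            print("Starting to initialize the trainer!")

    # After: `setup` runs once the model is attached, at the start of a stage.
    class MigratedCallback(Callback):
        def setup(self, trainer, pl_module, stage=None):
            print(f"Setting up for stage: {stage}")
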
14 changes: 3 additions & 11 deletions docs/source/extensions/callbacks.rst
@@ -46,23 +46,15 @@ Example:
 
 
     class MyPrintingCallback(Callback):
-        def on_init_start(self, trainer):
-            print("Starting to initialize the trainer!")
-
-        def on_init_end(self, trainer):
-            print("trainer is initialized now")
+        def on_train_start(self, trainer, pl_module):
+            print("Training is starting")
 
         def on_train_end(self, trainer, pl_module):
-            print("do something when training ends")
+            print("Training is ending")
 
 
     trainer = Trainer(callbacks=[MyPrintingCallback()])
 
-.. testoutput::
-
-    Starting to initialize the trainer!
-    trainer is initialized now
-
 We successfully extended functionality without polluting our super clean
 :doc:`lightning module <../common/lightning_module>` research code.
 
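With the init hooks removed, the example no longer prints anything at `Trainer` construction, which is why the `testoutput` block above was dropped; the remaining messages only appear once a fit loop actually runs. A quick sketch (assuming `model` is any minimal `LightningModule`):

    trainer = Trainer(callbacks=[MyPrintingCallback()], fast_dev_run=True)
    trainer.fit(model)  # prints "Training is starting", then "Training is ending"
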
15 changes: 3 additions & 12 deletions docs/source/starter/introduction_guide.rst
@@ -956,27 +956,18 @@ for hooks that you might care about
 
 
     class MyPrintingCallback(Callback):
-        def on_init_start(self, trainer):
-            print("Starting to init trainer!")
-
-        def on_init_end(self, trainer):
-            print("Trainer is init now")
+        def on_train_start(self, trainer, pl_module):
+            print("Training is starting")
 
         def on_train_end(self, trainer, pl_module):
-            print("do something when training ends")
+            print("Training is ending")
 
 And pass the callbacks into the trainer
 
 .. testcode::
 
     trainer = Trainer(callbacks=[MyPrintingCallback()])
 
-.. testoutput::
-    :hide:
-
-    Starting to init trainer!
-    Trainer is init now
-
 .. tip::
     See full list of 12+ hooks in the :doc:`callbacks <../extensions/callbacks>`.
 
14 changes: 12 additions & 2 deletions pytorch_lightning/callbacks/base.py
@@ -73,11 +73,21 @@ def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage
         pass
 
     def on_init_start(self, trainer: "pl.Trainer") -> None:
-        """Called when the trainer initialization begins, model has not yet been set."""
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8.
+
+        Called when the trainer initialization begins, model has not yet been set.
+        """
         pass
 
     def on_init_end(self, trainer: "pl.Trainer") -> None:
-        """Called when the trainer initialization ends, model has not yet been set."""
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8.
+
+        Called when the trainer initialization ends, model has not yet been set.
+        """
         pass
 
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
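For orientation, a rough sketch of where the deprecated hooks sit relative to their neighbors (simplified; not the exhaustive hook sequence):

    # Trainer(...)        -> on_init_start, on_init_end  (deprecated: no model attached yet)
    # trainer.fit(model)  -> setup(stage="fit")          (model attached; the usual replacement)
    #                     -> on_fit_start
    #                     -> on_pretrain_routine_start
    #                     -> on_train_start ... on_train_end
    #                     -> teardown(stage="fit")
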
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -125,7 +125,7 @@ def __init__(
     def state_key(self) -> str:
         return self._generate_state_key(monitor=self.monitor, mode=self.mode)
 
-    def on_init_end(self, trainer: "pl.Trainer") -> None:
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         if self._check_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end
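A consequence of moving this logic into `setup`: an `EarlyStopping` subclass that also overrides `setup` must now chain to the parent, or `_check_on_train_epoch_end` never gets resolved. A sketch (hypothetical subclass):

    class MyEarlyStopping(EarlyStopping):
        def setup(self, trainer, pl_module, stage=None):
            super().setup(trainer, pl_module, stage)  # resolves _check_on_train_epoch_end
            ...  # custom setup logic
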
5 changes: 2 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -248,14 +248,13 @@ def state_key(self) -> str:
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )
 
-    def on_init_end(self, trainer: "pl.Trainer") -> None:
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """When pretrain routine starts we build the ckpt dir on the fly."""
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end
             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """When pretrain routine starts we build the ckpt dir on the fly."""
         self.__resolve_ckpt_dir(trainer)
         if trainer.is_global_zero:
             self.__warn_if_dir_not_empty(self.dirpath)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/progress/base.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Union
+from typing import Dict, Optional, Union
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
@@ -152,7 +152,7 @@ def print(self, *args, **kwargs):
         """You should provide a way to print without breaking the progress bar."""
         print(*args, **kwargs)
 
-    def on_init_end(self, trainer):
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         self._trainer = trainer
 
     def on_train_start(self, trainer, pl_module):
15 changes: 14 additions & 1 deletion pytorch_lightning/trainer/configuration_validator.py
@@ -43,14 +43,16 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
 
     __verify_dp_batch_transfer_support(trainer, model)
     _check_add_get_queue(model)
-    # TODO(@daniellepintz): Delete _check_progress_bar in v1.7
+    # TODO: Delete _check_progress_bar in v1.7
     _check_progress_bar(model)
     # TODO: Delete _check_on_post_move_to_device in v1.7
     _check_on_post_move_to_device(model)
     # TODO: Delete _check_on_keyboard_interrupt in v1.7
     _check_on_keyboard_interrupt(trainer)
     # TODO: Remove this in v1.7 (deprecation: #9816)
     _check_dl_idx_in_on_train_batch_hooks(trainer, model)
+    # TODO: Remove this in v1.8
+    _check_on_init_start_end(trainer)
 
 
 def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -290,3 +292,14 @@ def _check_dl_idx_in_on_train_batch_hooks(trainer: "pl.Trainer", model: "pl.Ligh
                 f"Base `Callback.{hook}` hook signature has changed in v1.5."
                 " The `dataloader_idx` argument will be removed in v1.7."
             )
+
+
+def _check_on_init_start_end(trainer: "pl.Trainer") -> None:
+    """Checks if on_init_start/end are overridden and sends a deprecation warning."""
+    for callback in trainer.callbacks:
+        if is_overridden(method_name="on_init_start", instance=callback):
+            rank_zero_deprecation(
+                "The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
+            )
+        if is_overridden(method_name="on_init_end", instance=callback):
+            rank_zero_deprecation("The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8.")
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1537,6 +1537,7 @@ def _call_callback_hooks(
         *args: Any,
         **kwargs: Any,
     ) -> None:
+        # TODO: remove if block in v1.8
         if hook_name in ("on_init_start", "on_init_end"):
             # these `Callback` hooks are the only ones that do not take a lightning module.
             # we also don't profile bc profiler hasn't been set yet
@@ -1551,7 +1552,7 @@
             prev_fx_name = pl_module._current_fx_name
             pl_module._current_fx_name = hook_name
 
-        # TODO: remove if statement in v1.7
+        # TODO: remove if block in v1.7
         if hook_name in ("on_train_batch_start", "on_train_batch_end"):
             fn = getattr(self, hook_name)
             if callable(fn):
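For context, the branch referenced by the new TODO dispatches roughly like this (a simplified sketch of the inside of `Trainer._call_callback_hooks`, not the verbatim implementation):

    if hook_name in ("on_init_start", "on_init_end"):
        # these hooks run before a LightningModule is attached and before the
        # profiler is set up, so call each callback directly and return early
        for callback in self.callbacks:
            fn = getattr(callback, hook_name)
            if callable(fn):
                fn(self, *args, **kwargs)
        return
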
5 changes: 1 addition & 4 deletions tests/callbacks/test_tqdm_progress_bar.py
@@ -81,9 +81,6 @@ def test_tqdm_progress_bar_totals(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
     bar = trainer.progress_bar_callback
-    assert float("inf") == bar.total_train_batches
-    assert 0 == bar.total_val_batches
-    assert 0 == bar.total_test_batches
 
     trainer.fit(model)
 
@@ -584,7 +581,7 @@ def test_tqdm_progress_bar_main_bar_resume():
     trainer.num_val_batches = [3]
     trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3
 
-    bar.on_init_end(trainer)
+    bar.setup(trainer, model)
     bar.on_train_start(trainer, model)
     bar.on_train_epoch_start(trainer, model)
31 changes: 30 additions & 1 deletion tests/deprecated_api/test_remove_1-8.py
@@ -16,10 +16,11 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import DeviceType, DistributedType
 from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
+from tests.helpers.boring_model import BoringModel
 from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
 
 
@@ -44,6 +45,34 @@ def test_v1_8_0_deprecated_torchtext_batch():
         _ = move_data_to_device(batch=batch, device=torch.device("cpu"))
 
 
+def test_v1_8_0_on_init_start_end(tmpdir):
+    class TestCallback(Callback):
+        def on_init_start(self, trainer):
+            print("Starting to init trainer!")
+
+        def on_init_end(self, trainer):
+            print("Trainer is init now")
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        callbacks=[TestCallback()],
+        max_epochs=1,
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+    )
+    with pytest.deprecated_call(
+        match="The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.fit(model)
+    with pytest.deprecated_call(
+        match="The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.validate(model)
+
+
 def test_v1_8_0_deprecated_call_hook():
     trainer = Trainer(
         max_epochs=1,