12 changes: 0 additions & 12 deletions docs/source-pytorch/extensions/callbacks.rst
@@ -171,18 +171,6 @@ teardown
.. automethod:: pytorch_lightning.callbacks.Callback.teardown
:noindex:

on_init_start
^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_init_start
:noindex:

on_init_end
^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_init_end
:noindex:

on_fit_start
^^^^^^^^^^^^

2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -248,6 +248,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the deprecated `Trainer.use_amp` and `LightningModule.use_amp` attributes ([#14832](https://github.com/Lightning-AI/lightning/pull/14832))

- Removed the deprecated callback hooks `Callback.on_init_start` and `Callback.on_init_end` ([#14867](https://github.com/Lightning-AI/lightning/pull/14867))


- Removed the deprecated `Trainer.run_stage` in favor of `Trainer.{fit,validate,test,predict}` ([#14870](https://github.com/Lightning-AI/lightning/pull/14870))

16 changes: 0 additions & 16 deletions src/pytorch_lightning/callbacks/callback.py
@@ -78,22 +78,6 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
"""Called when fit, validate, test, predict, or tune ends."""

def on_init_start(self, trainer: "pl.Trainer") -> None:
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8.

Called when the trainer initialization begins, model has not yet been set.
"""

def on_init_end(self, trainer: "pl.Trainer") -> None:
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8.

Called when the trainer initialization ends, model has not yet been set.
"""

def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when fit begins."""

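With the two init hooks gone, logic that previously ran in `on_init_start`/`on_init_end` can move into the callback's constructor or the `setup` hook. A minimal migration sketch (the class name and print statements are illustrative, not part of this PR):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class InitAwareCallback(Callback):
    def __init__(self) -> None:
        super().__init__()
        # Work that used to run in `on_init_start` can run at construction
        # time instead -- the Trainer does not exist yet at this point.
        print("Callback constructed; Trainer not yet available")

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        # `setup` is the earliest remaining hook that receives the Trainer.
        # Unlike the removed hooks, the model is already attached here.
        print(f"Entering stage '{stage}' with {type(pl_module).__name__}")
```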
2 changes: 0 additions & 2 deletions src/pytorch_lightning/callbacks/lambda_function.py
@@ -44,8 +44,6 @@ def __init__(
setup: Optional[Callable] = None,
on_configure_sharded_model: Optional[Callable] = None,
teardown: Optional[Callable] = None,
on_init_start: Optional[Callable] = None,
on_init_end: Optional[Callable] = None,
on_fit_start: Optional[Callable] = None,
on_fit_end: Optional[Callable] = None,
on_sanity_check_start: Optional[Callable] = None,
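`LambdaCallback` maps each keyword argument onto the callback hook of the same name, so dropping these two parameters means passing them now fails at construction time with a `TypeError`. A small usage sketch under that assumption:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LambdaCallback
from pytorch_lightning.demos.boring_classes import BoringModel

# After this change, `LambdaCallback(on_init_start=...)` raises a TypeError
# because the keyword argument no longer exists.
callback = LambdaCallback(
    on_fit_start=lambda trainer, pl_module: print("fit starting"),
    on_fit_end=lambda trainer, pl_module: print("fit finished"),
)

trainer = Trainer(fast_dev_run=True, callbacks=[callback])
trainer.fit(BoringModel())
```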
4 changes: 0 additions & 4 deletions src/pytorch_lightning/core/module.py
@@ -1180,10 +1180,6 @@ def configure_callbacks(self):
early_stop = EarlyStopping(monitor="val_acc", mode="max")
checkpoint = ModelCheckpoint(monitor="val_loss")
return [early_stop, checkpoint]

Note:
Certain callback methods like :meth:`~pytorch_lightning.callbacks.base.Callback.on_init_start`
will never be invoked on the new callbacks returned here.
"""
return []

7 changes: 0 additions & 7 deletions src/pytorch_lightning/trainer/configuration_validator.py
@@ -210,13 +210,6 @@ def _check_on_pretrain_routine(model: "pl.LightningModule") -> None:

def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
for callback in trainer.callbacks:
if is_overridden(method_name="on_init_start", instance=callback):
rank_zero_deprecation(
"The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
)
if is_overridden(method_name="on_init_end", instance=callback):
rank_zero_deprecation("The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8.")

if is_overridden(method_name="on_configure_sharded_model", instance=callback):
rank_zero_deprecation(
"The `on_configure_sharded_model` callback hook was deprecated in"
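The validator's remaining checks all follow the same pattern: `is_overridden` reports whether a callback instance actually implements a hook (versus inheriting the base `Callback` no-op), and `rank_zero_deprecation` emits the warning once on rank zero. A condensed sketch of that pattern (the helper name is illustrative):

```python
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


def _warn_if_overridden(callback, method_name: str) -> None:
    # True only when the instance's method differs from the `Callback`
    # base-class default, i.e. the user really implemented the hook.
    if is_overridden(method_name=method_name, instance=callback):
        rank_zero_deprecation(
            f"The `{method_name}` callback hook was deprecated in v1.6 and will be removed in v1.8."
        )
```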
@@ -57,8 +57,6 @@ class _LogOptions(TypedDict):
"optimizer_zero_grad": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_init_start": None,
"on_init_end": None,
"on_fit_start": None,
"on_fit_end": None,
"on_sanity_check_start": None,
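In this table, a `None` entry means `self.log` is disallowed inside that hook, so deleting the two keys drops the hooks from the logger's vocabulary entirely. Roughly how such a table is consulted (a sketch of the pattern, not the exact `_FxValidator` code):

```python
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def check_logging(functions: dict, fx_name: str) -> None:
    if fx_name not in functions:
        # The hook is unknown -- e.g. `on_init_start` after this PR.
        raise RuntimeError(f"Logging inside `{fx_name}` is not implemented.")
    if functions[fx_name] is None:
        # The hook exists but does not support `self.log`.
        raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`.")
```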
15 changes: 0 additions & 15 deletions src/pytorch_lightning/trainer/trainer.py
@@ -453,9 +453,6 @@ def __init__(
accumulate_grad_batches,
)

# hook
self._call_callback_hooks("on_init_start")

# init data flags
self.check_val_every_n_epoch: Optional[int]
self._data_connector.on_trainer_init(
@@ -523,9 +520,6 @@ def __init__(
num_sanity_val_steps,
)

# Callback system
self._call_callback_hooks("on_init_end")

def _setup_on_init(self) -> None:
setup._log_device_info(self)

@@ -1333,15 +1327,6 @@ def _call_callback_hooks(
**kwargs: Any,
) -> None:
log.debug(f"{self.__class__.__name__}: calling callback hook: {hook_name}")
# TODO: remove if block in v1.8
if hook_name in ("on_init_start", "on_init_end"):
# these `Callback` hooks are the only ones that do not take a lightning module.
# we also don't profile bc profiler hasn't been set yet
for callback in self.callbacks:
fn = getattr(callback, hook_name)
if callable(fn):
fn(self, *args, **kwargs)
return

pl_module = self.lightning_module
if pl_module:
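Those were the only callback hooks invoked without a `LightningModule` and before the profiler was set up; with them gone, every callback hook goes through the one generic dispatch path. A simplified sketch of that remaining path (condensed from the surrounding Trainer code, not verbatim):

```python
def _call_callback_hooks(trainer, hook_name: str, *args, **kwargs) -> None:
    pl_module = trainer.lightning_module
    for callback in trainer.callbacks:
        fn = getattr(callback, hook_name)
        if callable(fn):
            # Every remaining hook receives (trainer, pl_module, ...) and can
            # be profiled, since the profiler exists by the time it fires.
            with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
                fn(trainer, pl_module, *args, **kwargs)
```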
9 changes: 2 additions & 7 deletions tests/tests_pytorch/callbacks/test_callbacks.py
@@ -13,7 +13,7 @@
# limitations under the License.
from pathlib import Path
from re import escape
from unittest.mock import call, Mock
from unittest.mock import Mock

import pytest

@@ -39,13 +39,8 @@ def configure_callbacks(self):
)

def assert_expected_calls(_trainer, model_callback, trainer_callback):
# some methods in callbacks configured through model won't get called
uncalled_methods = [call.on_init_start(_trainer), call.on_init_end(_trainer)]
for uncalled in uncalled_methods:
assert uncalled not in model_callback.method_calls

# assert that the rest of calls are the same as for trainer callbacks
expected_calls = [m for m in trainer_callback.method_calls if m not in uncalled_methods]
expected_calls = [m for m in trainer_callback.method_calls if m]
assert expected_calls
assert model_callback.method_calls == expected_calls

8 changes: 4 additions & 4 deletions tests/tests_pytorch/callbacks/test_lambda_function.py
@@ -49,7 +49,7 @@ def call(hook, *_, **__):
callbacks=[LambdaCallback(**hooks_args)],
)
with pytest.deprecated_call(
match="`on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
match="`on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path
@@ -65,15 +65,15 @@ def call(hook, *_, **__):
callbacks=[LambdaCallback(**hooks_args)],
)
with pytest.deprecated_call(
match="`on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
match="`on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.fit(model, ckpt_path=ckpt_path)
with pytest.deprecated_call(
match="`on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
match="`on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.test(model)
with pytest.deprecated_call(
match="`on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
match="`on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.predict(model)

28 changes: 0 additions & 28 deletions tests/tests_pytorch/deprecated_api/test_remove_1-8.py
@@ -21,34 +21,6 @@
from pytorch_lightning.demos.boring_classes import BoringModel


def test_v1_8_0_on_init_start_end(tmpdir):
class TestCallback(Callback):
def on_init_start(self, trainer):
print("Starting to init trainer!")

def on_init_end(self, trainer):
print("Trainer is init now")

model = BoringModel()

trainer = Trainer(
callbacks=[TestCallback()],
max_epochs=1,
fast_dev_run=True,
enable_progress_bar=False,
logger=False,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.fit(model)
with pytest.deprecated_call(
match="The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.validate(model)


def test_v_1_8_0_deprecated_device_stats_monitor_prefix_metric_keys():
from pytorch_lightning.callbacks.device_stats_monitor import prefix_metric_keys

30 changes: 0 additions & 30 deletions tests/tests_pytorch/models/test_hooks.py
@@ -504,10 +504,6 @@ def training_step(self, batch, batch_idx):
track_grad_norm=1,
**kwargs,
)
assert called == [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
trainer.fit(model)
saved_ckpt = {
"callbacks": ANY,
@@ -523,8 +519,6 @@ def training_step(self, batch, batch_idx):
saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY
device = torch.device("cuda:0" if "accelerator" in kwargs and kwargs["accelerator"] == "gpu" else "cpu")
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
@@ -621,10 +615,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
callbacks=[callback],
track_grad_norm=1,
)
assert called == [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]

# resume from checkpoint with HookedModel
model = HookedModel(called)
@@ -641,8 +631,6 @@
}
saved_ckpt = {**loaded_ckpt, "global_step": 4, "epoch": 1}
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
@@ -718,10 +706,6 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
callbacks=[callback],
track_grad_norm=1,
)
assert called == [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]

trainer.fit(model, ckpt_path=best_model_path)
loaded_ckpt = {
@@ -736,8 +720,6 @@
}
saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload}
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
@@ -799,10 +781,6 @@ def test_trainer_model_hook_system_eval(tmpdir, batches, verb, noun, dataloader,
enable_model_summary=False,
callbacks=[callback],
)
assert called == [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
fn = getattr(trainer, verb)
fn(model, verbose=False)
hooks = [
@@ -819,8 +797,6 @@
dict(name=f"on_{noun}_model_train"),
]
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
Expand All @@ -843,14 +819,8 @@ def test_trainer_model_hook_system_predict(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir, limit_predict_batches=batches, enable_progress_bar=False, callbacks=[callback]
)
assert called == [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
]
trainer.predict(model)
expected = [
dict(name="Callback.on_init_start", args=(trainer,)),
dict(name="Callback.on_init_end", args=(trainer,)),
dict(name="configure_callbacks"),
dict(name="prepare_data"),
dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
@@ -46,8 +46,6 @@ def test_fx_validator():
"on_fit_end",
"on_configure_sharded_model",
"on_fit_start",
"on_init_end",
"on_init_start",
"on_exception",
"on_load_checkpoint",
"load_state_dict",
@@ -90,8 +88,6 @@ def test_fx_validator():
"on_fit_end",
"on_fit_start",
"on_configure_sharded_model",
"on_init_end",
"on_init_start",
"on_exception",
"on_load_checkpoint",
"load_state_dict",
@@ -164,10 +160,6 @@ def call(hook, trainer=None, model=None, *_, **__):
return

lightning_module = trainer.lightning_module or model
if lightning_module is None:
# `on_init_{start,end}` do not have the `LightningModule` available
assert hook in ("on_init_start", "on_init_end")
return

if hook in not_supported:
with pytest.raises(MisconfigurationException, match=not_supported[hook]):