Skip to content

Commit 01f5f99

Browse files
Deprecate callback hooks on_init_start and on_init_end (#10940)
1 parent aeb0b55 commit 01f5f99

File tree

11 files changed: +73 additions, -38 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
116116
- Deprecated the access to the attribute `IndexBatchSamplerWrapper.batch_indices` in favor of `IndexBatchSamplerWrapper.seen_batch_indices` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))
117117

118118

119+
- Deprecated `on_init_start` and `on_init_end` callback hooks ([#10940](https://github.com/PyTorchLightning/pytorch-lightning/pull/10940))
120+
121+
119122
- Deprecated `Trainer.call_hook` in favor of `Trainer._call_callback_hooks`, `Trainer._call_lightning_module_hook`, `Trainer._call_ttp_hook`, and `Trainer._call_accelerator_hook` ([#10979](https://github.com/PyTorchLightning/pytorch-lightning/pull/10979))
120123

121124

docs/source/extensions/callbacks.rst

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,15 @@ Example:
4646

4747

4848
class MyPrintingCallback(Callback):
49-
def on_init_start(self, trainer):
50-
print("Starting to initialize the trainer!")
51-
52-
def on_init_end(self, trainer):
53-
print("trainer is initialized now")
49+
def on_train_start(self, trainer, pl_module):
50+
print("Training is starting")
5451

5552
def on_train_end(self, trainer, pl_module):
56-
print("do something when training ends")
53+
print("Training is ending")
5754

5855

5956
trainer = Trainer(callbacks=[MyPrintingCallback()])
6057

61-
.. testoutput::
62-
63-
Starting to initialize the trainer!
64-
trainer is initialized now
65-
6658
We successfully extended functionality without polluting our super clean
6759
:doc:`lightning module <../common/lightning_module>` research code.
6860

docs/source/starter/introduction_guide.rst

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -956,27 +956,18 @@ for hooks that you might care about
956956

957957

958958
class MyPrintingCallback(Callback):
959-
def on_init_start(self, trainer):
960-
print("Starting to init trainer!")
961-
962-
def on_init_end(self, trainer):
963-
print("Trainer is init now")
959+
def on_train_start(self, trainer, pl_module):
960+
print("Training is starting")
964961

965962
def on_train_end(self, trainer, pl_module):
966-
print("do something when training ends")
963+
print("Training is ending")
967964

968965
And pass the callbacks into the trainer
969966

970967
.. testcode::
971968

972969
trainer = Trainer(callbacks=[MyPrintingCallback()])
973970

974-
.. testoutput::
975-
:hide:
976-
977-
Starting to init trainer!
978-
Trainer is init now
979-
980971
.. tip::
981972
See full list of 12+ hooks in the :doc:`callbacks <../extensions/callbacks>`.
982973

pytorch_lightning/callbacks/base.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,21 @@ def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage
7373
pass
7474

7575
def on_init_start(self, trainer: "pl.Trainer") -> None:
76-
"""Called when the trainer initialization begins, model has not yet been set."""
76+
r"""
77+
.. deprecated:: v1.6
78+
This callback hook was deprecated in v1.6 and will be removed in v1.8.
79+
80+
Called when the trainer initialization begins, model has not yet been set.
81+
"""
7782
pass
7883

7984
def on_init_end(self, trainer: "pl.Trainer") -> None:
80-
"""Called when the trainer initialization ends, model has not yet been set."""
85+
r"""
86+
.. deprecated:: v1.6
87+
This callback hook was deprecated in v1.6 and will be removed in v1.8.
88+
89+
Called when the trainer initialization ends, model has not yet been set.
90+
"""
8191
pass
8292

8393
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def __init__(
125125
def state_key(self) -> str:
126126
return self._generate_state_key(monitor=self.monitor, mode=self.mode)
127127

128-
def on_init_end(self, trainer: "pl.Trainer") -> None:
128+
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
129129
if self._check_on_train_epoch_end is None:
130130
# if the user runs validation multiple times per training epoch or multiple training epochs without
131131
# validation, then we run after validation instead of on train epoch end

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,14 +248,13 @@ def state_key(self) -> str:
248248
save_on_train_epoch_end=self._save_on_train_epoch_end,
249249
)
250250

251-
def on_init_end(self, trainer: "pl.Trainer") -> None:
251+
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
252+
"""When pretrain routine starts we build the ckpt dir on the fly."""
252253
if self._save_on_train_epoch_end is None:
253254
# if the user runs validation multiple times per training epoch or multiple training epochs without
254255
# validation, then we run after validation instead of on train epoch end
255256
self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
256257

257-
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
258-
"""When pretrain routine starts we build the ckpt dir on the fly."""
259258
self.__resolve_ckpt_dir(trainer)
260259
if trainer.is_global_zero:
261260
self.__warn_if_dir_not_empty(self.dirpath)

pytorch_lightning/callbacks/progress/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Dict, Union
14+
from typing import Dict, Optional, Union
1515

1616
import pytorch_lightning as pl
1717
from pytorch_lightning.callbacks import Callback
@@ -152,7 +152,7 @@ def print(self, *args, **kwargs):
152152
"""You should provide a way to print without breaking the progress bar."""
153153
print(*args, **kwargs)
154154

155-
def on_init_end(self, trainer):
155+
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
156156
self._trainer = trainer
157157

158158
def on_train_start(self, trainer, pl_module):

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,16 @@ def verify_loop_configurations(trainer: "pl.Trainer", model: "pl.LightningModule
4343

4444
__verify_dp_batch_transfer_support(trainer, model)
4545
_check_add_get_queue(model)
46-
# TODO(@daniellepintz): Delete _check_progress_bar in v1.7
46+
# TODO: Delete _check_progress_bar in v1.7
4747
_check_progress_bar(model)
4848
# TODO: Delete _check_on_post_move_to_device in v1.7
4949
_check_on_post_move_to_device(model)
5050
# TODO: Delete _check_on_keyboard_interrupt in v1.7
5151
_check_on_keyboard_interrupt(trainer)
5252
# TODO: Remove this in v1.7 (deprecation: #9816)
5353
_check_dl_idx_in_on_train_batch_hooks(trainer, model)
54+
# TODO: Remove this in v1.8
55+
_check_on_init_start_end(trainer)
5456

5557

5658
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -290,3 +292,14 @@ def _check_dl_idx_in_on_train_batch_hooks(trainer: "pl.Trainer", model: "pl.Ligh
290292
f"Base `Callback.{hook}` hook signature has changed in v1.5."
291293
" The `dataloader_idx` argument will be removed in v1.7."
292294
)
295+
296+
297+
def _check_on_init_start_end(trainer: "pl.Trainer") -> None:
298+
"""Checks if on_init_start/end are overridden and sends a deprecation warning."""
299+
for callback in trainer.callbacks:
300+
if is_overridden(method_name="on_init_start", instance=callback):
301+
rank_zero_deprecation(
302+
"The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
303+
)
304+
if is_overridden(method_name="on_init_end", instance=callback):
305+
rank_zero_deprecation("The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8.")

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1537,6 +1537,7 @@ def _call_callback_hooks(
15371537
*args: Any,
15381538
**kwargs: Any,
15391539
) -> None:
1540+
# TODO: remove if block in v1.8
15401541
if hook_name in ("on_init_start", "on_init_end"):
15411542
# these `Callback` hooks are the only ones that do not take a lightning module.
15421543
# we also don't profile bc profiler hasn't been set yet
@@ -1551,7 +1552,7 @@ def _call_callback_hooks(
15511552
prev_fx_name = pl_module._current_fx_name
15521553
pl_module._current_fx_name = hook_name
15531554

1554-
# TODO: remove if statement in v1.7
1555+
# TODO: remove if block in v1.7
15551556
if hook_name in ("on_train_batch_start", "on_train_batch_end"):
15561557
fn = getattr(self, hook_name)
15571558
if callable(fn):

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,6 @@ def test_tqdm_progress_bar_totals(tmpdir):
8181

8282
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
8383
bar = trainer.progress_bar_callback
84-
assert float("inf") == bar.total_train_batches
85-
assert 0 == bar.total_val_batches
86-
assert 0 == bar.total_test_batches
8784

8885
trainer.fit(model)
8986

@@ -584,7 +581,7 @@ def test_tqdm_progress_bar_main_bar_resume():
584581
trainer.num_val_batches = [3]
585582
trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3
586583

587-
bar.on_init_end(trainer)
584+
bar.setup(trainer, model)
588585
bar.on_train_start(trainer, model)
589586
bar.on_train_epoch_start(trainer, model)
590587

Comments: 0