Merged
CHANGELOG.md (3 additions, 0 deletions)
@@ -442,6 +442,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `on_pretrain_routine_start` and `on_pretrain_routine_end` callback hooks in favor of `on_fit_start` ([#11794](https://github.com/PyTorchLightning/pytorch-lightning/pull/11794))


+- Deprecated `LightningModule.on_pretrain_routine_start` and `LightningModule.on_pretrain_routine_end` hooks in favor of `on_fit_start` ([#12122](https://github.com/PyTorchLightning/pytorch-lightning/pull/12122))
+
+
 - Deprecated `agg_key_funcs` and `agg_default_func` parameters from `LightningLoggerBase` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871))


docs/source/common/lightning_module.rst (0 additions, 15 deletions)
@@ -1225,9 +1225,6 @@ for more information.
 setup("fit")
 configure_optimizers()

-on_pretrain_routine_start()
-on_pretrain_routine_end()
-
 # the sanity check runs here

 on_train_start()
@@ -1391,18 +1388,6 @@ on_validation_end
 .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_validation_end
    :noindex:

-on_pretrain_routine_start
-~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_start
-   :noindex:
-
-on_pretrain_routine_end
-~~~~~~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_pretrain_routine_end
-   :noindex:
-
 on_test_batch_start
 ~~~~~~~~~~~~~~~~~~~

pytorch_lightning/core/hooks.py (8 additions, 0 deletions)
@@ -68,6 +68,10 @@ def on_pretrain_routine_start(self) -> None:
         - pretrain_routine start
         - pretrain_routine end
         - training_start
+
+        .. deprecated:: v1.6
+            :meth:`on_pretrain_routine_start` has been deprecated in v1.6 and will be removed in v1.8.
+            Use ``on_fit_start`` instead.
         """

     def on_pretrain_routine_end(self) -> None:
@@ -77,6 +81,10 @@ def on_pretrain_routine_end(self) -> None:
         - pretrain_routine start
         - pretrain_routine end
         - training_start
+
+        .. deprecated:: v1.6
+            :meth:`on_pretrain_routine_end` has been deprecated in v1.6 and will be removed in v1.8.
+            Use ``on_fit_start`` instead.
         """

     def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]:
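
A note on migration: the docstrings above point users at ``on_fit_start``, and the move is mechanical: relocate the hook body. A minimal sketch follows; the module and its `_init_metric_buffers` helper are hypothetical, not part of this PR. Per the hook order listed in the docstrings, `on_fit_start` fires before the old pretrain routine did, so logic that had to run immediately before training may warrant a second look.

    import pytorch_lightning as pl

    class MigratedModel(pl.LightningModule):
        # Before (deprecated in v1.6, removed in v1.8):
        #     def on_pretrain_routine_start(self):
        #         self._init_metric_buffers()
        #
        # After: run the same setup when fit starts.
        def on_fit_start(self):
            self._init_metric_buffers()  # hypothetical setup helper

        def _init_metric_buffers(self):
            ...
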
pytorch_lightning/trainer/configuration_validator.py (12 additions, 0 deletions)
@@ -60,6 +60,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     _check_on_epoch_start_end(model)
     # TODO: Delete CheckpointHooks off PrecisionPlugin in v1.8
     _check_precision_plugin_checkpoint_hooks(trainer)
+    # TODO: Delete on_pretrain_routine_start/end hooks in v1.8
+    _check_on_pretrain_routine(model)


 def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -306,6 +308,16 @@ def _check_on_epoch_start_end(model: "pl.LightningModule") -> None:
     )


+def _check_on_pretrain_routine(model: "pl.LightningModule") -> None:
+    hooks = (("on_pretrain_routine_start", "on_fit_start"), ("on_pretrain_routine_end", "on_fit_start"))
+    for hook, alternative_hook in hooks:
+        if is_overridden(hook, model):
+            rank_zero_deprecation(
+                f"The `LightningModule.{hook}` hook was deprecated in v1.6 and"
+                f" will be removed in v1.8. Please use `LightningModule.{alternative_hook}` instead."
+            )
+
+
 def _check_dl_idx_in_on_train_batch_hooks(model: "pl.LightningModule") -> None:
     for hook in ("on_train_batch_start", "on_train_batch_end"):
         if is_param_in_hook_signature(getattr(model, hook), "dataloader_idx", explicit=True):
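
The new check leans on Lightning's `is_overridden` utility, used throughout this module, to detect whether the user's subclass actually redefines a hook. A simplified sketch of the idea, not Lightning's exact implementation:

    def is_overridden_sketch(method_name: str, instance: object, parent: type) -> bool:
        # A hook counts as overridden when the instance's class supplies a
        # different function object than the one defined on the parent class.
        instance_attr = getattr(type(instance), method_name, None)
        parent_attr = getattr(parent, method_name, None)
        return instance_attr is not None and instance_attr is not parent_attr
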
tests/deprecated_api/test_remove_1-8.py (31 additions, 0 deletions)
@@ -451,6 +451,37 @@ def on_epoch_end(self, *args, **kwargs):
     trainer.fit(model)


+def test_v1_8_0_remove_on_pretrain_routine_start_end_lightning_module(tmpdir):
+    class CustomModel(BoringModel):
+        def on_pretrain_routine_start(self, *args, **kwargs):
+            print("foo")
+
+    model = CustomModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        default_root_dir=tmpdir,
+    )
+    with pytest.deprecated_call(
+        match="The `LightningModule.on_pretrain_routine_start` hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.fit(model)
+
+    class CustomModel(BoringModel):
+        def on_pretrain_routine_end(self, *args, **kwargs):
+            print("foo")
+
+    trainer = Trainer(
+        fast_dev_run=True,
+        default_root_dir=tmpdir,
+    )
+
+    model = CustomModel()
+    with pytest.deprecated_call(
+        match="The `LightningModule.on_pretrain_routine_end` hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.fit(model)
+
+
 def test_v1_8_0_rank_zero_imports():

     import warnings
tests/models/test_restore.py (7 additions, 15 deletions)
@@ -28,7 +28,7 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.states import RunningStage, TrainerFn
+from pytorch_lightning.trainer.states import TrainerFn
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
@@ -253,7 +253,7 @@ def test_correct_step_and_epoch(tmpdir):
     assert trainer.global_step == 0

     class TestModel(BoringModel):
-        def on_pretrain_routine_end(self) -> None:
+        def on_train_start(self) -> None:
             assert self.trainer.current_epoch == first_max_epochs
             # TODO(@carmocca): should not need `+1`
             assert self.trainer.global_step == first_max_epochs * train_batches + 1
@@ -610,26 +610,18 @@ def test_dp_resume(tmpdir):
     class CustomModel(CustomClassificationModelDP):
         def __init__(self):
             super().__init__()
-            self.on_pretrain_routine_end_called = False
+            self.on_train_start_called = False

-        # set the epoch start hook so we can predict before the model does the full training
-        def on_pretrain_routine_end(self):
+        def on_validation_start(self):
             assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

             # if model and state loaded correctly, predictions will be good even though we
             # haven't trained with the new loaded model
-            new_trainer.state.stage = RunningStage.VALIDATING
-
-            dataloader = dm.train_dataloader()
+            dataloader = dm.val_dataloader()
             tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
-            self.on_pretrain_routine_end_called = True

     # new model
     model = CustomModel()

-    # fit new model which should load hpc weights
-    new_trainer.fit(model, datamodule=dm)
-    assert model.on_pretrain_routine_end_called
+    # validate new model which should load hpc weights
+    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)

     # test freeze on gpu
     model.freeze()
tests/trainer/test_data_loading.py (1 addition, 1 deletion)
@@ -104,7 +104,7 @@ def __init__(self, num_workers):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), num_workers=self.num_workers)

-    def on_pretrain_routine_start(self):
+    def on_fit_start(self):
         self._resout = StringIO()
         self.ctx = redirect_stderr(self._resout)
         self.ctx.__enter__()