
Commit e84ae3e

rohitgr7 authored and awaelchli committed
Add on_epoch_start to run at the beginning of every loop irrespective of train/val/test (#6498)
* update docs
* add hook and update docs
* update tests
* chlog
* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli <[email protected]>

* chlog

Co-authored-by: Adrian Wälchli <[email protected]>
1 parent ad6b20f commit e84ae3e

15 files changed: +135 additions, -32 deletions


CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-
+- Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498))
 
 
 ### Fixed
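The entry above is the user-visible effect of this commit: `on_epoch_start` and `on_epoch_end` now fire for validation and test epochs, not only training epochs. A minimal sketch of a LightningModule that relies on this (the model body and the counter are illustrative assumptions, not part of the diff):

import torch
from torch import nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        self.epoch_starts = 0  # counts train + val + test epoch starts

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def on_epoch_start(self):
        # After #6498 this runs at the start of train, validation and test epochs.
        self.epoch_starts += 1

    def on_epoch_end(self):
        # Likewise runs at the end of every train/val/test epoch.
        print(f"epochs (any kind) seen so far: {self.epoch_starts}")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)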

docs/source/common/lightning_module.rst

Lines changed: 83 additions & 8 deletions
@@ -1039,6 +1039,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
     teardown()
 
 def train_loop():
+    on_epoch_start()
     on_train_epoch_start()
     train_outs = []
     for train_batch in train_dataloader():
@@ -1062,12 +1063,15 @@ This is the pseudocode to describe how all the hooks are called during a call to
         val_loop()
 
     # end training epoch
-    logs = training_epoch_end(outs)
+    outs = training_epoch_end(outs)
+    on_train_epoch_end(outs)
+    on_epoch_end()
 
 def val_loop():
     model.eval()
     torch.set_grad_enabled(False)
 
+    on_epoch_start()
     on_validation_epoch_start()
     val_outs = []
     for val_batch in val_dataloader():
@@ -1081,6 +1085,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
 
     validation_epoch_end(val_outs)
     on_validation_epoch_end()
+    on_epoch_end()
 
     # set up for train
     model.train()
@@ -1108,12 +1113,12 @@ manual_backward
 on_after_backward
 ~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward
    :noindex:
 
 on_before_zero_grad
 ~~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad
    :noindex:
 
 on_fit_start
@@ -1132,15 +1137,38 @@ on_fit_end
 on_load_checkpoint
 ~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint
+.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint
    :noindex:
 
 on_save_checkpoint
 ~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint
+.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint
    :noindex:
 
+on_train_start
+~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start
+   :noindex:
+
+on_train_end
+~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end
+   :noindex:
+
+on_validation_start
+~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start
+   :noindex:
+
+on_validation_end
+~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end
+   :noindex:
 
 on_pretrain_routine_start
 ~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1178,6 +1206,11 @@ on_test_epoch_end
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
    :noindex:
 
+on_test_end
+~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end
+   :noindex:
 
 on_train_batch_start
 ~~~~~~~~~~~~~~~~~~~~
@@ -1191,6 +1224,18 @@ on_train_batch_end
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
    :noindex:
 
+on_epoch_start
+~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start
+   :noindex:
+
+on_epoch_end
+~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end
+   :noindex:
+
 on_train_epoch_start
 ~~~~~~~~~~~~~~~~~~~~
 
@@ -1227,6 +1272,36 @@ on_validation_epoch_end
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
    :noindex:
 
+on_post_move_to_device
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device
+   :noindex:
+
+on_validation_model_eval
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval
+   :noindex:
+
+on_validation_model_train
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train
+   :noindex:
+
+on_test_model_eval
+~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval
+   :noindex:
+
+on_test_model_train
+~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
+   :noindex:
+
 optimizer_step
 ~~~~~~~~~~~~~~
 
@@ -1266,19 +1341,19 @@ teardown
 train_dataloader
 ~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader
    :noindex:
 
 val_dataloader
 ~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader
    :noindex:
 
 test_dataloader
 ~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader
    :noindex:
 
 transfer_batch_to_device
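To make the updated pseudocode concrete, here is a hedged sketch (the print statements are illustrative, not from the diff) of a module overriding the epoch-level hooks; with this commit, `on_epoch_start` fires before `on_train_epoch_start` and also before `on_validation_epoch_start`:

import pytorch_lightning as pl


class HookOrderModule(pl.LightningModule):
    # Only the hook methods are shown; training/validation steps,
    # dataloaders and optimizers would be defined as usual.

    def on_epoch_start(self):
        print("on_epoch_start")        # fires for train, val and test epochs

    def on_train_epoch_start(self):
        print("on_train_epoch_start")  # train loop only

    def on_validation_epoch_start(self):
        print("on_validation_epoch_start")  # val loop only

    def on_validation_epoch_end(self):
        print("on_validation_epoch_end")

    def on_epoch_end(self):
        print("on_epoch_end")          # fires at the end of every loop's epoch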

docs/source/extensions/callbacks.rst

Lines changed: 12 additions & 0 deletions
@@ -349,3 +349,15 @@ on_load_checkpoint
 
 .. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
    :noindex:
+
+on_after_backward
+^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
+   :noindex:
+
+on_before_zero_grad
+^^^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad
+   :noindex:
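The two callback hooks documented above already exist on `Callback`; this commit only adds them to the docs. A minimal sketch of a callback using them (the gradient-norm logging is an illustrative assumption, not part of the diff):

import torch
from pytorch_lightning.callbacks import Callback


class GradNormCallback(Callback):
    """Logs the total gradient norm after backward, before gradients are zeroed."""

    def on_after_backward(self, trainer, pl_module):
        # Gradients are populated at this point; compute their global L2 norm.
        norms = [p.grad.norm() for p in pl_module.parameters() if p.grad is not None]
        if norms:
            pl_module.log("grad_norm", torch.stack(norms).norm())

    def on_before_zero_grad(self, trainer, pl_module, optimizer):
        # Last chance to inspect gradients before optimizer.zero_grad() runs.
        pass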

docs/source/extensions/logging.rst

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a
 .. note::
 
     - Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a
-      reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
+      reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
 
     - Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with
       suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor`
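For context, a hedged sketch of the logging call this note describes (the metric name is illustrative, and the model body is assumed to be defined elsewhere):

import pytorch_lightning as pl
import torch.nn.functional as F


class LitClassifier(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # on_step logs the raw value each step; on_epoch caches the values and
        # logs the reduced (mean) value at epoch end, yielding `train_loss_step`
        # and `train_loss_epoch` keys in the logger.
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        return loss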

pytorch_lightning/callbacks/base.py

Lines changed: 2 additions & 2 deletions
@@ -102,11 +102,11 @@ def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
         pass
 
     def on_epoch_start(self, trainer, pl_module: LightningModule) -> None:
-        """Called when the epoch begins."""
+        """Called when either of train/val/test epoch begins."""
         pass
 
     def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
-        """Called when the epoch ends."""
+        """Called when either of train/val/test epoch ends."""
         pass
 
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
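A hedged sketch of a user callback relying on the clarified semantics (the timing logic is illustrative, not part of the diff):

import time
from pytorch_lightning.callbacks import Callback


class EpochTimer(Callback):
    """Times every epoch, whether it belongs to the train, validation or test loop."""

    def on_epoch_start(self, trainer, pl_module):
        self._start = time.monotonic()

    def on_epoch_end(self, trainer, pl_module):
        print(f"epoch took {time.monotonic() - self._start:.2f}s")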

pytorch_lightning/callbacks/gradient_accumulation_scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def __init__(self, scheduling: Dict[int, int]):
     def going_to_accumulate_grad_batches(self):
         return any([v > 1 for v in self.scheduling.values()])
 
-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
         epoch = trainer.current_epoch
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
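Usage of this callback is unchanged by the rename; only the internal hook it listens to moves from `on_epoch_start` to `on_train_epoch_start`, so the schedule is no longer touched by validation/test epochs. A sketch of typical usage (the schedule values are illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# Accumulate 1 batch until epoch 4, then 2 batches from epoch 4, 8 from epoch 10.
accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 2, 10: 8})
trainer = Trainer(callbacks=[accumulator])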

pytorch_lightning/callbacks/progress.py

Lines changed: 3 additions & 3 deletions
@@ -192,7 +192,7 @@ def on_init_end(self, trainer):
     def on_train_start(self, trainer, pl_module):
         self._train_batch_idx = trainer.batch_idx
 
-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
         self._train_batch_idx = 0
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -383,8 +383,8 @@ def on_train_start(self, trainer, pl_module):
         super().on_train_start(trainer, pl_module)
         self.main_progress_bar = self.init_train_tqdm()
 
-    def on_epoch_start(self, trainer, pl_module):
-        super().on_epoch_start(trainer, pl_module)
+    def on_train_epoch_start(self, trainer, pl_module):
+        super().on_train_epoch_start(trainer, pl_module)
         total_train_batches = self.total_train_batches
         total_val_batches = self.total_val_batches
         if total_train_batches != float('inf'):

pytorch_lightning/core/hooks.py

Lines changed: 2 additions & 2 deletions
@@ -224,13 +224,13 @@ def on_predict_model_eval(self) -> None:
 
     def on_epoch_start(self) -> None:
         """
-        Called in the training loop at the very beginning of the epoch.
+        Called when either of train/val/test epoch begins.
         """
         # do something when the epoch starts
 
     def on_epoch_end(self) -> None:
         """
-        Called in the training loop at the very end of the epoch.
+        Called when either of train/val/test epoch ends.
        """
         # do something when the epoch ends
 

pytorch_lightning/core/lightning.py

Lines changed: 7 additions & 4 deletions
@@ -706,10 +706,13 @@ def validation_step(self, *args, **kwargs):
     .. code-block:: python
 
         # pseudocode of order
-        out = validation_step()
-        if defined('validation_step_end'):
-            out = validation_step_end(out)
-        out = validation_epoch_end(out)
+        val_outs = []
+        for val_batch in val_data:
+            out = validation_step(val_batch)
+            if defined('validation_step_end'):
+                out = validation_step_end(out)
+            val_outs.append(out)
+        val_outs = validation_epoch_end(val_outs)
 
 
     .. code-block:: python
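As a worked counterpart to that pseudocode, a hedged sketch of the three validation hooks in a LightningModule (the metric math is illustrative, and `forward` is assumed to be defined elsewhere):

import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)  # forward defined elsewhere
        return {"val_loss": F.cross_entropy(logits, y)}

    def validation_step_end(self, out):
        # Only needed with DP/DDP2-style batch splitting; here it passes through.
        return out

    def validation_epoch_end(self, val_outs):
        # Receives the list of per-batch outputs collected by the loop above.
        avg_loss = torch.stack([o["val_loss"] for o in val_outs]).mean()
        self.log("val_loss", avg_loss)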

pytorch_lightning/trainer/callback_hook.py

Lines changed: 2 additions & 2 deletions
@@ -105,12 +105,12 @@ def on_test_epoch_end(self):
             callback.on_test_epoch_end(self, self.lightning_module)
 
     def on_epoch_start(self):
-        """Called when the epoch begins."""
+        """Called when either of train/val/test epoch begins."""
         for callback in self.callbacks:
             callback.on_epoch_start(self, self.lightning_module)
 
     def on_epoch_end(self):
-        """Called when the epoch ends."""
+        """Called when either of train/val/test epoch ends."""
         for callback in self.callbacks:
             callback.on_epoch_end(self, self.lightning_module)
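End to end, the change means a callback attached to the Trainer sees `on_epoch_start`/`on_epoch_end` during validation and testing as well as training. A hedged sketch (model, dataloaders and trainer flags are placeholders):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class EpochAnnouncer(Callback):
    def on_epoch_start(self, trainer, pl_module):
        print("epoch starting")  # printed in the train, val and test loops


trainer = Trainer(max_epochs=1, callbacks=[EpochAnnouncer()])
# trainer.fit(model, train_dataloader, val_dataloader)  # fires for train & val epochs
# trainer.test(model, test_dataloaders=test_dataloader)  # fires for the test epoch too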
