Skip to content

Commit 409d6e1

Browse files
rohitgr7SkafteNicki
authored andcommitted
Add optimizer hooks in callbacks (#4379)
* Add optimizer hooks in callbacks * optimizer param * update test Co-authored-by: Nicki Skafte <[email protected]> (cherry picked from commit b26c71e)
1 parent e9c61cc commit 409d6e1

File tree

4 files changed

+45
-2
lines changed

4 files changed

+45
-2
lines changed

pytorch_lightning/callbacks/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,15 @@ def on_save_checkpoint(self, trainer, pl_module):
166166
def on_load_checkpoint(self, checkpointed_state):
167167
"""Called when loading a model checkpoint, use to reload state."""
168168
pass
169+
170+
def on_after_backward(self, trainer, pl_module):
171+
"""
172+
Called after loss.backward() and before optimizers do anything.
173+
"""
174+
pass
175+
176+
def on_before_zero_grad(self, trainer, pl_module, optimizer):
177+
"""
178+
Called after optimizer.step() and before optimizer.zero_grad().
179+
"""
180+
pass

pytorch_lightning/trainer/callback_hook.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,17 @@ def on_load_checkpoint(self, checkpoint):
209209
if state:
210210
state = deepcopy(state)
211211
callback.on_load_checkpoint(state)
212+
213+
def on_after_backward(self):
214+
"""
215+
Called after loss.backward() and before optimizers do anything.
216+
"""
217+
for callback in self.callbacks:
218+
callback.on_after_backward(self, self.get_model())
219+
220+
def on_before_zero_grad(self, optimizer):
221+
"""
222+
Called after optimizer.step() and before optimizer.zero_grad().
223+
"""
224+
for callback in self.callbacks:
225+
callback.on_before_zero_grad(self, self.get_model(), optimizer)

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,8 +463,7 @@ def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_
463463
)
464464

465465
def on_before_zero_grad(self, optimizer):
466-
model = self.trainer.get_model()
467-
model.on_before_zero_grad(optimizer)
466+
self.trainer.call_hook('on_before_zero_grad', optimizer)
468467

469468
def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
470469
self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

tests/callbacks/test_callbacks.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def __init__(self):
5555
self.on_validation_end_called = False
5656
self.on_test_start_called = False
5757
self.on_test_end_called = False
58+
self.on_after_backward_called = False
59+
self.on_before_zero_grad_called = False
5860

5961
def setup(self, trainer, pl_module, stage: str):
6062
assert isinstance(trainer, Trainer)
@@ -160,6 +162,14 @@ def on_test_end(self, trainer, pl_module):
160162
_check_args(trainer, pl_module)
161163
self.on_test_end_called = True
162164

165+
def on_after_backward(self, trainer, pl_module):
166+
_check_args(trainer, pl_module)
167+
self.on_after_backward_called = True
168+
169+
def on_before_zero_grad(self, trainer, pl_module, optimizer):
170+
_check_args(trainer, pl_module)
171+
self.on_before_zero_grad_called = True
172+
163173
test_callback = TestCallback()
164174

165175
trainer_options = dict(
@@ -197,6 +207,8 @@ def on_test_end(self, trainer, pl_module):
197207
assert not test_callback.on_validation_end_called
198208
assert not test_callback.on_test_start_called
199209
assert not test_callback.on_test_end_called
210+
assert not test_callback.on_after_backward_called
211+
assert not test_callback.on_before_zero_grad_called
200212

201213
# fit model
202214
trainer = Trainer(**trainer_options)
@@ -228,6 +240,8 @@ def on_test_end(self, trainer, pl_module):
228240
assert not test_callback.on_validation_end_called
229241
assert not test_callback.on_test_start_called
230242
assert not test_callback.on_test_end_called
243+
assert not test_callback.on_after_backward_called
244+
assert not test_callback.on_before_zero_grad_called
231245

232246
trainer.fit(model)
233247

@@ -257,6 +271,8 @@ def on_test_end(self, trainer, pl_module):
257271
assert not test_callback.on_test_batch_end_called
258272
assert not test_callback.on_test_start_called
259273
assert not test_callback.on_test_end_called
274+
assert test_callback.on_after_backward_called
275+
assert test_callback.on_before_zero_grad_called
260276

261277
# reset setup teardown callback
262278
test_callback.teardown_called = False
@@ -277,3 +293,5 @@ def on_test_end(self, trainer, pl_module):
277293
assert not test_callback.on_validation_end_called
278294
assert not test_callback.on_validation_batch_end_called
279295
assert not test_callback.on_validation_batch_start_called
296+
assert not test_callback.on_after_backward_called
297+
assert not test_callback.on_before_zero_grad_called

0 commit comments

Comments
 (0)