@@ -55,6 +55,8 @@ def __init__(self):
5555 self .on_validation_end_called = False
5656 self .on_test_start_called = False
5757 self .on_test_end_called = False
58+ self .on_after_backward_called = False
59+ self .on_before_zero_grad_called = False
5860
5961 def setup (self , trainer , pl_module , stage : str ):
6062 assert isinstance (trainer , Trainer )
@@ -160,6 +162,14 @@ def on_test_end(self, trainer, pl_module):
160162 _check_args (trainer , pl_module )
161163 self .on_test_end_called = True
162164
165+ def on_after_backward (self , trainer , pl_module ):
166+ _check_args (trainer , pl_module )
167+ self .on_after_backward_called = True
168+
169+ def on_before_zero_grad (self , trainer , pl_module , optimizer ):
170+ _check_args (trainer , pl_module )
171+ self .on_before_zero_grad_called = True
172+
163173 test_callback = TestCallback ()
164174
165175 trainer_options = dict (
@@ -197,6 +207,8 @@ def on_test_end(self, trainer, pl_module):
197207 assert not test_callback .on_validation_end_called
198208 assert not test_callback .on_test_start_called
199209 assert not test_callback .on_test_end_called
210+ assert not test_callback .on_after_backward_called
211+ assert not test_callback .on_before_zero_grad_called
200212
201213 # fit model
202214 trainer = Trainer (** trainer_options )
@@ -228,6 +240,8 @@ def on_test_end(self, trainer, pl_module):
228240 assert not test_callback .on_validation_end_called
229241 assert not test_callback .on_test_start_called
230242 assert not test_callback .on_test_end_called
243+ assert not test_callback .on_after_backward_called
244+ assert not test_callback .on_before_zero_grad_called
231245
232246 trainer .fit (model )
233247
@@ -257,6 +271,8 @@ def on_test_end(self, trainer, pl_module):
257271 assert not test_callback .on_test_batch_end_called
258272 assert not test_callback .on_test_start_called
259273 assert not test_callback .on_test_end_called
274+ assert test_callback .on_after_backward_called
275+ assert test_callback .on_before_zero_grad_called
260276
261277 # reset setup teardown callback
262278 test_callback .teardown_called = False
@@ -277,3 +293,5 @@ def on_test_end(self, trainer, pl_module):
277293 assert not test_callback .on_validation_end_called
278294 assert not test_callback .on_validation_batch_end_called
279295 assert not test_callback .on_validation_batch_start_called
296+ assert not test_callback .on_after_backward_called
297+ assert not test_callback .on_before_zero_grad_called
0 commit comments