1414
1515from abc import ABC
1616from copy import deepcopy
17- from typing import Callable , List
17+ from typing import List
1818
1919from pytorch_lightning .callbacks import Callback
20+ from pytorch_lightning .core .lightning import LightningModule
2021
2122
2223class TrainerCallbackHookMixin (ABC ):
2324
2425 # this is just a summary on variables used in this abstract class,
2526 # the proper values/initialisation should be done in child class
2627 callbacks : List [Callback ] = []
27- get_model : Callable
28+ lightning_module : LightningModule
2829
2930 def on_before_accelerator_backend_setup (self , model ):
3031 """Called in the beginning of fit and test"""
@@ -39,7 +40,7 @@ def setup(self, model, stage: str):
3940 def teardown (self , stage : str ):
4041 """Called at the end of fit and test"""
4142 for callback in self .callbacks :
42- callback .teardown (self , self .get_model () , stage )
43+ callback .teardown (self , self .lightning_module , stage )
4344
4445 def on_init_start (self ):
4546 """Called when the trainer initialization begins, model has not yet been set."""
@@ -54,72 +55,72 @@ def on_init_end(self):
5455 def on_fit_start (self ):
5556 """Called when the trainer initialization begins, model has not yet been set."""
5657 for callback in self .callbacks :
57- callback .on_fit_start (self , self .get_model () )
58+ callback .on_fit_start (self , self .lightning_module )
5859
5960 def on_fit_end (self ):
6061 """Called when the trainer initialization begins, model has not yet been set."""
6162 for callback in self .callbacks :
62- callback .on_fit_end (self , self .get_model () )
63+ callback .on_fit_end (self , self .lightning_module )
6364
6465 def on_sanity_check_start (self ):
6566 """Called when the validation sanity check starts."""
6667 for callback in self .callbacks :
67- callback .on_sanity_check_start (self , self .get_model () )
68+ callback .on_sanity_check_start (self , self .lightning_module )
6869
6970 def on_sanity_check_end (self ):
7071 """Called when the validation sanity check ends."""
7172 for callback in self .callbacks :
72- callback .on_sanity_check_end (self , self .get_model () )
73+ callback .on_sanity_check_end (self , self .lightning_module )
7374
7475 def on_train_epoch_start (self ):
7576 """Called when the epoch begins."""
7677 for callback in self .callbacks :
77- callback .on_train_epoch_start (self , self .get_model () )
78+ callback .on_train_epoch_start (self , self .lightning_module )
7879
7980 def on_train_epoch_end (self , outputs ):
8081 """Called when the epoch ends."""
8182 for callback in self .callbacks :
82- callback .on_train_epoch_end (self , self .get_model () , outputs )
83+ callback .on_train_epoch_end (self , self .lightning_module , outputs )
8384
8485 def on_validation_epoch_start (self ):
8586 """Called when the epoch begins."""
8687 for callback in self .callbacks :
87- callback .on_validation_epoch_start (self , self .get_model () )
88+ callback .on_validation_epoch_start (self , self .lightning_module )
8889
8990 def on_validation_epoch_end (self ):
9091 """Called when the epoch ends."""
9192 for callback in self .callbacks :
92- callback .on_validation_epoch_end (self , self .get_model () )
93+ callback .on_validation_epoch_end (self , self .lightning_module )
9394
9495 def on_test_epoch_start (self ):
9596 """Called when the epoch begins."""
9697 for callback in self .callbacks :
97- callback .on_test_epoch_start (self , self .get_model () )
98+ callback .on_test_epoch_start (self , self .lightning_module )
9899
99100 def on_test_epoch_end (self ):
100101 """Called when the epoch ends."""
101102 for callback in self .callbacks :
102- callback .on_test_epoch_end (self , self .get_model () )
103+ callback .on_test_epoch_end (self , self .lightning_module )
103104
104105 def on_epoch_start (self ):
105106 """Called when the epoch begins."""
106107 for callback in self .callbacks :
107- callback .on_epoch_start (self , self .get_model () )
108+ callback .on_epoch_start (self , self .lightning_module )
108109
109110 def on_epoch_end (self ):
110111 """Called when the epoch ends."""
111112 for callback in self .callbacks :
112- callback .on_epoch_end (self , self .get_model () )
113+ callback .on_epoch_end (self , self .lightning_module )
113114
114115 def on_train_start (self ):
115116 """Called when the train begins."""
116117 for callback in self .callbacks :
117- callback .on_train_start (self , self .get_model () )
118+ callback .on_train_start (self , self .lightning_module )
118119
119120 def on_train_end (self ):
120121 """Called when the train ends."""
121122 for callback in self .callbacks :
122- callback .on_train_end (self , self .get_model () )
123+ callback .on_train_end (self , self .lightning_module )
123124
124125 def on_pretrain_routine_start (self , model ):
125126 """Called when the train begins."""
@@ -134,74 +135,74 @@ def on_pretrain_routine_end(self, model):
134135 def on_batch_start (self ):
135136 """Called when the training batch begins."""
136137 for callback in self .callbacks :
137- callback .on_batch_start (self , self .get_model () )
138+ callback .on_batch_start (self , self .lightning_module )
138139
139140 def on_batch_end (self ):
140141 """Called when the training batch ends."""
141142 for callback in self .callbacks :
142- callback .on_batch_end (self , self .get_model () )
143+ callback .on_batch_end (self , self .lightning_module )
143144
144145 def on_train_batch_start (self , batch , batch_idx , dataloader_idx ):
145146 """Called when the training batch begins."""
146147 for callback in self .callbacks :
147- callback .on_train_batch_start (self , self .get_model () , batch , batch_idx , dataloader_idx )
148+ callback .on_train_batch_start (self , self .lightning_module , batch , batch_idx , dataloader_idx )
148149
149150 def on_train_batch_end (self , outputs , batch , batch_idx , dataloader_idx ):
150151 """Called when the training batch ends."""
151152 for callback in self .callbacks :
152- callback .on_train_batch_end (self , self .get_model () , outputs , batch , batch_idx , dataloader_idx )
153+ callback .on_train_batch_end (self , self .lightning_module , outputs , batch , batch_idx , dataloader_idx )
153154
154155 def on_validation_batch_start (self , batch , batch_idx , dataloader_idx ):
155156 """Called when the validation batch begins."""
156157 for callback in self .callbacks :
157- callback .on_validation_batch_start (self , self .get_model () , batch , batch_idx , dataloader_idx )
158+ callback .on_validation_batch_start (self , self .lightning_module , batch , batch_idx , dataloader_idx )
158159
159160 def on_validation_batch_end (self , outputs , batch , batch_idx , dataloader_idx ):
160161 """Called when the validation batch ends."""
161162 for callback in self .callbacks :
162- callback .on_validation_batch_end (self , self .get_model () , outputs , batch , batch_idx , dataloader_idx )
163+ callback .on_validation_batch_end (self , self .lightning_module , outputs , batch , batch_idx , dataloader_idx )
163164
164165 def on_test_batch_start (self , batch , batch_idx , dataloader_idx ):
165166 """Called when the test batch begins."""
166167 for callback in self .callbacks :
167- callback .on_test_batch_start (self , self .get_model () , batch , batch_idx , dataloader_idx )
168+ callback .on_test_batch_start (self , self .lightning_module , batch , batch_idx , dataloader_idx )
168169
169170 def on_test_batch_end (self , outputs , batch , batch_idx , dataloader_idx ):
170171 """Called when the test batch ends."""
171172 for callback in self .callbacks :
172- callback .on_test_batch_end (self , self .get_model () , outputs , batch , batch_idx , dataloader_idx )
173+ callback .on_test_batch_end (self , self .lightning_module , outputs , batch , batch_idx , dataloader_idx )
173174
174175 def on_validation_start (self ):
175176 """Called when the validation loop begins."""
176177 for callback in self .callbacks :
177- callback .on_validation_start (self , self .get_model () )
178+ callback .on_validation_start (self , self .lightning_module )
178179
179180 def on_validation_end (self ):
180181 """Called when the validation loop ends."""
181182 for callback in self .callbacks :
182- callback .on_validation_end (self , self .get_model () )
183+ callback .on_validation_end (self , self .lightning_module )
183184
184185 def on_test_start (self ):
185186 """Called when the test begins."""
186187 for callback in self .callbacks :
187- callback .on_test_start (self , self .get_model () )
188+ callback .on_test_start (self , self .lightning_module )
188189
189190 def on_test_end (self ):
190191 """Called when the test ends."""
191192 for callback in self .callbacks :
192- callback .on_test_end (self , self .get_model () )
193+ callback .on_test_end (self , self .lightning_module )
193194
194195 def on_keyboard_interrupt (self ):
195196 """Called when the training is interrupted by KeyboardInterrupt."""
196197 for callback in self .callbacks :
197- callback .on_keyboard_interrupt (self , self .get_model () )
198+ callback .on_keyboard_interrupt (self , self .lightning_module )
198199
199200 def on_save_checkpoint (self ):
200201 """Called when saving a model checkpoint."""
201202 callback_states = {}
202203 for callback in self .callbacks :
203204 callback_class = type (callback )
204- state = callback .on_save_checkpoint (self , self .get_model () )
205+ state = callback .on_save_checkpoint (self , self .lightning_module )
205206 if state :
206207 callback_states [callback_class ] = state
207208 return callback_states
@@ -224,11 +225,11 @@ def on_after_backward(self):
224225 Called after loss.backward() and before optimizers do anything.
225226 """
226227 for callback in self .callbacks :
227- callback .on_after_backward (self , self .get_model () )
228+ callback .on_after_backward (self , self .lightning_module )
228229
229230 def on_before_zero_grad (self , optimizer ):
230231 """
231232 Called after optimizer.step() and before optimizer.zero_grad().
232233 """
233234 for callback in self .callbacks :
234- callback .on_before_zero_grad (self , self .get_model () , optimizer )
235+ callback .on_before_zero_grad (self , self .lightning_module , optimizer )
0 commit comments