1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414r"""
15- Abstract base class used to build new callbacks.
15+ Base class used to build new callbacks.
1616
1717"""
1818
19- import abc
2019from typing import Any , Dict , List , Optional , Type
2120
2221import torch
2625from pytorch_lightning .utilities .types import STEP_OUTPUT
2726
2827
29- class Callback ( abc . ABC ) :
28+ class Callback :
3029 r"""
3130 Abstract base class used to build new callbacks.
3231
@@ -62,15 +61,12 @@ def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.Light
6261
6362 def on_before_accelerator_backend_setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
6463 """Called before accelerator is being setup."""
65- pass
6664
6765 def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : Optional [str ] = None ) -> None :
6866 """Called when fit, validate, test, predict, or tune begins."""
69- pass
7067
7168 def teardown (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : Optional [str ] = None ) -> None :
7269 """Called when fit, validate, test, predict, or tune ends."""
73- pass
7470
7571 def on_init_start (self , trainer : "pl.Trainer" ) -> None :
7672 r"""
@@ -79,7 +75,6 @@ def on_init_start(self, trainer: "pl.Trainer") -> None:
7975
8076 Called when the trainer initialization begins, model has not yet been set.
8177 """
82- pass
8378
8479 def on_init_end (self , trainer : "pl.Trainer" ) -> None :
8580 r"""
@@ -88,23 +83,18 @@ def on_init_end(self, trainer: "pl.Trainer") -> None:
8883
8984 Called when the trainer initialization ends, model has not yet been set.
9085 """
91- pass
9286
9387 def on_fit_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
9488 """Called when fit begins."""
95- pass
9689
9790 def on_fit_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
9891 """Called when fit ends."""
99- pass
10092
10193 def on_sanity_check_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
10294 """Called when the validation sanity check starts."""
103- pass
10495
10596 def on_sanity_check_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
10697 """Called when the validation sanity check ends."""
107- pass
10898
10999 def on_train_batch_start (
110100 self ,
@@ -115,7 +105,6 @@ def on_train_batch_start(
115105 unused : int = 0 ,
116106 ) -> None :
117107 """Called when the train batch begins."""
118- pass
119108
120109 def on_train_batch_end (
121110 self ,
@@ -127,11 +116,9 @@ def on_train_batch_end(
127116 unused : int = 0 ,
128117 ) -> None :
129118 """Called when the train batch ends."""
130- pass
131119
132120 def on_train_epoch_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
133121 """Called when the train epoch begins."""
134- pass
135122
136123 def on_train_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
137124 """Called when the train epoch ends.
@@ -141,53 +128,41 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
141128 1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR
142129 2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
143130 """
144- pass
145131
146132 def on_validation_epoch_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
147133 """Called when the val epoch begins."""
148- pass
149134
150135 def on_validation_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
151136 """Called when the val epoch ends."""
152- pass
153137
154138 def on_test_epoch_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
155139 """Called when the test epoch begins."""
156- pass
157140
158141 def on_test_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
159142 """Called when the test epoch ends."""
160- pass
161143
162144 def on_predict_epoch_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
163145 """Called when the predict epoch begins."""
164- pass
165146
166147 def on_predict_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , outputs : List [Any ]) -> None :
167148 """Called when the predict epoch ends."""
168- pass
169149
170150 def on_epoch_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
171151 """Called when either of train/val/test epoch begins."""
172- pass
173152
174153 def on_epoch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
175154 """Called when either of train/val/test epoch ends."""
176- pass
177155
178156 def on_batch_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
179157 """Called when the training batch begins."""
180- pass
181158
182159 def on_batch_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
183160 """Called when the training batch ends."""
184- pass
185161
186162 def on_validation_batch_start (
187163 self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int , dataloader_idx : int
188164 ) -> None :
189165 """Called when the validation batch begins."""
190- pass
191166
192167 def on_validation_batch_end (
193168 self ,
@@ -199,13 +174,11 @@ def on_validation_batch_end(
199174 dataloader_idx : int ,
200175 ) -> None :
201176 """Called when the validation batch ends."""
202- pass
203177
204178 def on_test_batch_start (
205179 self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int , dataloader_idx : int
206180 ) -> None :
207181 """Called when the test batch begins."""
208- pass
209182
210183 def on_test_batch_end (
211184 self ,
@@ -217,13 +190,11 @@ def on_test_batch_end(
217190 dataloader_idx : int ,
218191 ) -> None :
219192 """Called when the test batch ends."""
220- pass
221193
222194 def on_predict_batch_start (
223195 self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , batch : Any , batch_idx : int , dataloader_idx : int
224196 ) -> None :
225197 """Called when the predict batch begins."""
226- pass
227198
228199 def on_predict_batch_end (
229200 self ,
@@ -235,47 +206,36 @@ def on_predict_batch_end(
235206 dataloader_idx : int ,
236207 ) -> None :
237208 """Called when the predict batch ends."""
238- pass
239209
240210 def on_train_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
241211 """Called when the train begins."""
242- pass
243212
244213 def on_train_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
245214 """Called when the train ends."""
246- pass
247215
248216 def on_pretrain_routine_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
249217 """Called when the pretrain routine begins."""
250- pass
251218
252219 def on_pretrain_routine_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
253220 """Called when the pretrain routine ends."""
254- pass
255221
256222 def on_validation_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
257223 """Called when the validation loop begins."""
258- pass
259224
260225 def on_validation_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
261226 """Called when the validation loop ends."""
262- pass
263227
264228 def on_test_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
265229 """Called when the test begins."""
266- pass
267230
268231 def on_test_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
269232 """Called when the test ends."""
270- pass
271233
272234 def on_predict_start (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
273235 """Called when the predict begins."""
274- pass
275236
276237 def on_predict_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
277238 """Called when predict ends."""
278- pass
279239
280240 def on_keyboard_interrupt (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
281241 r"""
@@ -284,11 +244,9 @@ def on_keyboard_interrupt(self, trainer: "pl.Trainer", pl_module: "pl.LightningM
284244
285245 Called when any trainer execution is interrupted by KeyboardInterrupt.
286246 """
287- pass
288247
289248 def on_exception (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , exception : BaseException ) -> None :
290249 """Called when any trainer execution is interrupted by an exception."""
291- pass
292250
293251 def on_save_checkpoint (
294252 self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , checkpoint : Dict [str , Any ]
@@ -303,7 +261,6 @@ def on_save_checkpoint(
303261 Returns:
304262 The callback state.
305263 """
306- pass
307264
308265 def on_load_checkpoint (
309266 self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , callback_state : Dict [str , Any ]
@@ -320,22 +277,17 @@ def on_load_checkpoint(
320277 If your ``on_load_checkpoint`` hook behavior doesn't rely on a state,
321278 you will still need to override ``on_save_checkpoint`` to return a ``dummy state``.
322279 """
323- pass
324280
325281 def on_before_backward (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , loss : torch .Tensor ) -> None :
326282 """Called before ``loss.backward()``."""
327- pass
328283
329284 def on_after_backward (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
330285 """Called after ``loss.backward()`` and before optimizers are stepped."""
331- pass
332286
333287 def on_before_optimizer_step (
334288 self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , optimizer : Optimizer , opt_idx : int
335289 ) -> None :
336290 """Called before ``optimizer.step()``."""
337- pass
338291
339292 def on_before_zero_grad (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , optimizer : Optimizer ) -> None :
340293 """Called before ``optimizer.zero_grad()``."""
341- pass
0 commit comments