Skip to content

Commit 4eadb89

Browse files
archsyscallBordaSkafteNickirohitgr7tchaton
committed
Add LambdaCallback (#5347)
* Add LambdaCallback * docs * add pr link # Conflicts: # CHANGELOG.md * convention * Fix Callback Typo * Update pytorch_lightning/callbacks/lambda_cb.py Co-authored-by: Nicki Skafte <[email protected]> * Update pytorch_lightning/callbacks/lambda_cb.py Co-authored-by: Nicki Skafte <[email protected]> * Update pytorch_lightning/callbacks/lambda_cb.py Co-authored-by: Nicki Skafte <[email protected]> * use Misconfigureation * update docs * sort export * use inspect * string fill * use fast dev run * isort * remove unused import * sort * hilightning * highlighting * highlighting * remove debug log * eq * res * results * add misconfig exception test * use pytest raises * fix * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> * Update pytorch_lightning/callbacks/lambda_cb.py Co-authored-by: Rohit Gupta <[email protected]> * hc * rm pt * fix * try fix * whitespace * new hook * add raise * fix * remove unused * rename Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Nicki Skafte <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: chaton <[email protected]>
1 parent 4e827eb commit 4eadb89

File tree

5 files changed

+227
-2
lines changed

5 files changed

+227
-2
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2929

3030
- Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))
3131

32+
33+
- Added `LambdaCallback` ([#5347](https://github.com/PyTorchLightning/pytorch-lightning/pull/5347))
34+
35+
3236
- Added `BackboneLambdaFinetuningCallback` ([#5377](https://github.com/PyTorchLightning/pytorch-lightning/pull/5377))
3337

3438

docs/source/callbacks.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ Lightning has a few built-in callbacks.
9898
EarlyStopping
9999
GPUStatsMonitor
100100
GradientAccumulationScheduler
101+
LambdaCallback
101102
LearningRateMonitor
102103
ModelCheckpoint
103104
ProgressBar

pytorch_lightning/callbacks/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@
1616
from pytorch_lightning.callbacks.finetuning import BackboneLambdaFinetuningCallback, BaseFinetuningCallback
1717
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
1818
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
19+
from pytorch_lightning.callbacks.lambda_function import LambdaCallback
1920
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
2021
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
2122
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
2223

2324
__all__ = [
25+
'BackboneLambdaFinetuningCallback',
26+
'BaseFinetuningCallback',
2427
'Callback',
2528
'EarlyStopping',
2629
'GPUStatsMonitor',
2730
'GradientAccumulationScheduler',
28-
'BaseFinetuningCallback',
29-
'BackboneLambdaFinetuningCallback',
31+
'LambdaCallback',
3032
'LearningRateMonitor',
3133
'ModelCheckpoint',
3234
'ProgressBar',
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
r"""
16+
Lambda Callback
17+
^^^^^^^^^^^^^^^
18+
19+
Create a simple callback on the fly using lambda functions.
20+
21+
"""
22+
23+
from typing import Callable, Optional
24+
25+
from pytorch_lightning.callbacks.base import Callback
26+
27+
28+
class LambdaCallback(Callback):
    r"""
    Create a simple callback on the fly using lambda functions.

    Every keyword argument mirrors a hook of
    :class:`~pytorch_lightning.callbacks.base.Callback`. Each callable that is
    provided is bound onto the instance under the hook's name, shadowing the
    inherited no-op implementation, so the ``Trainer`` invokes it at the
    corresponding point of the loop. Hooks left as ``None`` keep the inherited
    (no-op) behaviour.

    Args:
        **kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback`

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import LambdaCallback
        >>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
    """

    def __init__(
        self,
        on_before_accelerator_backend_setup: Optional[Callable] = None,
        setup: Optional[Callable] = None,
        teardown: Optional[Callable] = None,
        on_init_start: Optional[Callable] = None,
        on_init_end: Optional[Callable] = None,
        on_fit_start: Optional[Callable] = None,
        on_fit_end: Optional[Callable] = None,
        on_sanity_check_start: Optional[Callable] = None,
        on_sanity_check_end: Optional[Callable] = None,
        on_train_batch_start: Optional[Callable] = None,
        on_train_batch_end: Optional[Callable] = None,
        on_train_epoch_start: Optional[Callable] = None,
        on_train_epoch_end: Optional[Callable] = None,
        on_validation_epoch_start: Optional[Callable] = None,
        on_validation_epoch_end: Optional[Callable] = None,
        on_test_epoch_start: Optional[Callable] = None,
        on_test_epoch_end: Optional[Callable] = None,
        on_epoch_start: Optional[Callable] = None,
        on_epoch_end: Optional[Callable] = None,
        on_batch_start: Optional[Callable] = None,
        on_validation_batch_start: Optional[Callable] = None,
        on_validation_batch_end: Optional[Callable] = None,
        on_test_batch_start: Optional[Callable] = None,
        on_test_batch_end: Optional[Callable] = None,
        on_batch_end: Optional[Callable] = None,
        on_train_start: Optional[Callable] = None,
        on_train_end: Optional[Callable] = None,
        on_pretrain_routine_start: Optional[Callable] = None,
        on_pretrain_routine_end: Optional[Callable] = None,
        on_validation_start: Optional[Callable] = None,
        on_validation_end: Optional[Callable] = None,
        on_test_start: Optional[Callable] = None,
        on_test_end: Optional[Callable] = None,
        on_keyboard_interrupt: Optional[Callable] = None,
        on_save_checkpoint: Optional[Callable] = None,
        on_load_checkpoint: Optional[Callable] = None,
        on_after_backward: Optional[Callable] = None,
        on_before_zero_grad: Optional[Callable] = None,
    ):
        # At this point ``locals()`` holds exactly ``self`` plus every hook
        # parameter, so one loop replaces ~80 hand-written
        # ``if hook is not None: self.hook = hook`` lines (same behaviour,
        # far less copy-paste error surface). ``setattr`` on the instance
        # shadows the no-op hook inherited from ``Callback``.
        for hook_name, hook in list(locals().items()):
            if hook_name != "self" and hook is not None:
                setattr(self, hook_name, hook)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import inspect
15+
16+
from pytorch_lightning import seed_everything, Trainer
17+
from pytorch_lightning.callbacks import Callback, LambdaCallback
18+
from tests.base.boring_model import BoringModel
19+
20+
21+
def test_lambda_call(tmpdir):
    """Check that every ``Callback`` hook handed to ``LambdaCallback`` gets invoked."""
    seed_everything(42)

    class InterruptModel(BoringModel):
        # Abort after two epochs so checkpoint resume + KeyboardInterrupt
        # handling are exercised as well.
        def on_train_epoch_start(self):
            if self.current_epoch > 1:
                raise KeyboardInterrupt

    invoked = set()

    def make_recorder(hook_name):
        # Bind the name eagerly to dodge Python's late-binding closures.
        return lambda *args: invoked.add(hook_name)

    hook_names = [name for name, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
    overrides = {name: make_recorder(name) for name in hook_names}
    # ``on_save_checkpoint`` must return a value; record the call inside a list literal.
    overrides["on_save_checkpoint"] = lambda *args: [invoked.add("on_save_checkpoint")]

    model = InterruptModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[LambdaCallback(**overrides)],
    )
    first_run = trainer.fit(model)
    assert first_run

    best_ckpt = trainer.checkpoint_callback.best_model_path
    model = InterruptModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        resume_from_checkpoint=best_ckpt,
        callbacks=[LambdaCallback(**overrides)],
    )
    resumed_run = trainer.fit(model)
    trainer.test(model)

    assert resumed_run
    # Every hook discovered on the base ``Callback`` must have fired.
    assert invoked == set(hook_names)

0 commit comments

Comments
 (0)