Merged

44 commits
0df6bd9
Add LambdaCallback
archsyscall Jan 4, 2021
aa13ddf
docs
archsyscall Jan 4, 2021
01bd0a5
add pr link
Borda Jan 4, 2021
b0953dd
convention
archsyscall Jan 4, 2021
7863e67
Fix Callback Typo
archsyscall Jan 4, 2021
6792408
Update pytorch_lightning/callbacks/lambda_cb.py
archsyscall Jan 4, 2021
d934a23
Update pytorch_lightning/callbacks/lambda_cb.py
archsyscall Jan 4, 2021
9fc981a
Update pytorch_lightning/callbacks/lambda_cb.py
archsyscall Jan 4, 2021
a93e468
use Misconfigureation
archsyscall Jan 5, 2021
2ef199f
update docs
archsyscall Jan 5, 2021
cb294e0
sort export
archsyscall Jan 5, 2021
aadde9e
use inspect
archsyscall Jan 5, 2021
8c10b14
string fill
archsyscall Jan 5, 2021
39b1970
use fast dev run
archsyscall Jan 5, 2021
dc11767
isort
archsyscall Jan 5, 2021
0cfef59
remove unused import
archsyscall Jan 5, 2021
6835771
sort
archsyscall Jan 5, 2021
0263e3a
hilightning
archsyscall Jan 5, 2021
7249a10
highlighting
archsyscall Jan 5, 2021
3038d2f
highlighting
archsyscall Jan 5, 2021
c400b98
remove debug log
archsyscall Jan 5, 2021
8518382
eq
archsyscall Jan 5, 2021
8bfe53e
res
archsyscall Jan 5, 2021
9fd4c6b
results
archsyscall Jan 5, 2021
c4563c7
add misconfig exception test
archsyscall Jan 5, 2021
a329d4a
use pytest raises
archsyscall Jan 5, 2021
571b941
Merge remote-tracking branch 'upstream/release/1.2-dev' into feature/…
archsyscall Jan 5, 2021
d1f8d4a
fix
archsyscall Jan 5, 2021
7293115
Apply suggestions from code review
Borda Jan 6, 2021
4d85f59
Update pytorch_lightning/callbacks/lambda_cb.py
archsyscall Jan 6, 2021
c9ecb8a
hc
archsyscall Jan 6, 2021
2044291
rm pt
archsyscall Jan 6, 2021
5359ce6
Merge branch 'release/1.2-dev' into feature/lambdacallback
tchaton Jan 6, 2021
556ea09
fix
archsyscall Jan 8, 2021
d190e15
try fix
rohitgr7 Jan 9, 2021
d7bfc4a
Merge branch 'release/1.2-dev' into feature/lambdacallback
rohitgr7 Jan 9, 2021
a27dbff
whitespace
rohitgr7 Jan 9, 2021
d1bd19a
new hook
rohitgr7 Jan 9, 2021
afe018a
add raise
archsyscall Jan 10, 2021
709fb5b
fix
archsyscall Jan 10, 2021
9b93a2c
remove unused
archsyscall Jan 10, 2021
72f3f0c
rename
archsyscall Jan 12, 2021
7ed3eea
Merge branch 'release/1.2-dev' into feature/lambdacallback
SkafteNicki Jan 12, 2021
2ce0131
Merge branch 'release/1.2-dev' into feature/lambdacallback
SkafteNicki Jan 13, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))
 
+
+- Added `LambdaCallback` ([#5347](https://github.com/PyTorchLightning/pytorch-lightning/pull/5347))
+
+
 - Added `BackboneLambdaFinetuningCallback` ([#5377](https://github.com/PyTorchLightning/pytorch-lightning/pull/5377))
 
 
1 change: 1 addition & 0 deletions docs/source/callbacks.rst
@@ -98,6 +98,7 @@ Lightning has a few built-in callbacks.
     EarlyStopping
     GPUStatsMonitor
     GradientAccumulationScheduler
+    LambdaCallback
     LearningRateMonitor
     ModelCheckpoint
     ProgressBar
6 changes: 4 additions & 2 deletions pytorch_lightning/callbacks/__init__.py
@@ -16,17 +16,19 @@
 from pytorch_lightning.callbacks.finetuning import BackboneLambdaFinetuningCallback, BaseFinetuningCallback
 from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
 from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
+from pytorch_lightning.callbacks.lambda_function import LambdaCallback
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
 
 __all__ = [
+    'BackboneLambdaFinetuningCallback',
+    'BaseFinetuningCallback',
     'Callback',
     'EarlyStopping',
     'GPUStatsMonitor',
     'GradientAccumulationScheduler',
-    'BaseFinetuningCallback',
-    'BackboneLambdaFinetuningCallback',
+    'LambdaCallback',
     'LearningRateMonitor',
     'ModelCheckpoint',
     'ProgressBar',
158 changes: 158 additions & 0 deletions pytorch_lightning/callbacks/lambda_function.py
@@ -0,0 +1,158 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Lambda Callback
^^^^^^^^^^^^^^^

Create a simple callback on the fly using lambda functions.

"""

from typing import Callable, Optional

from pytorch_lightning.callbacks.base import Callback


class LambdaCallback(Callback):
r"""
Create a simple callback on the fly using lambda functions.

Args:
**kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback`

Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import LambdaCallback
>>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
"""

def __init__(
self,
on_before_accelerator_backend_setup: Optional[Callable] = None,
setup: Optional[Callable] = None,
teardown: Optional[Callable] = None,
on_init_start: Optional[Callable] = None,
on_init_end: Optional[Callable] = None,
on_fit_start: Optional[Callable] = None,
on_fit_end: Optional[Callable] = None,
on_sanity_check_start: Optional[Callable] = None,
on_sanity_check_end: Optional[Callable] = None,
on_train_batch_start: Optional[Callable] = None,
on_train_batch_end: Optional[Callable] = None,
on_train_epoch_start: Optional[Callable] = None,
on_train_epoch_end: Optional[Callable] = None,
on_validation_epoch_start: Optional[Callable] = None,
on_validation_epoch_end: Optional[Callable] = None,
on_test_epoch_start: Optional[Callable] = None,
on_test_epoch_end: Optional[Callable] = None,
on_epoch_start: Optional[Callable] = None,
on_epoch_end: Optional[Callable] = None,
on_batch_start: Optional[Callable] = None,
on_validation_batch_start: Optional[Callable] = None,
on_validation_batch_end: Optional[Callable] = None,
on_test_batch_start: Optional[Callable] = None,
on_test_batch_end: Optional[Callable] = None,
on_batch_end: Optional[Callable] = None,
on_train_start: Optional[Callable] = None,
on_train_end: Optional[Callable] = None,
on_pretrain_routine_start: Optional[Callable] = None,
on_pretrain_routine_end: Optional[Callable] = None,
on_validation_start: Optional[Callable] = None,
on_validation_end: Optional[Callable] = None,
on_test_start: Optional[Callable] = None,
on_test_end: Optional[Callable] = None,
on_keyboard_interrupt: Optional[Callable] = None,
on_save_checkpoint: Optional[Callable] = None,
on_load_checkpoint: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
):
if on_before_accelerator_backend_setup is not None:
self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup
if setup is not None:
self.setup = setup
if teardown is not None:
self.teardown = teardown
if on_init_start is not None:
self.on_init_start = on_init_start
if on_init_end is not None:
self.on_init_end = on_init_end
if on_fit_start is not None:
self.on_fit_start = on_fit_start
if on_fit_end is not None:
self.on_fit_end = on_fit_end
if on_sanity_check_start is not None:
self.on_sanity_check_start = on_sanity_check_start
if on_sanity_check_end is not None:
self.on_sanity_check_end = on_sanity_check_end
if on_train_batch_start is not None:
self.on_train_batch_start = on_train_batch_start
if on_train_batch_end is not None:
self.on_train_batch_end = on_train_batch_end
if on_train_epoch_start is not None:
self.on_train_epoch_start = on_train_epoch_start
if on_train_epoch_end is not None:
self.on_train_epoch_end = on_train_epoch_end
if on_validation_epoch_start is not None:
self.on_validation_epoch_start = on_validation_epoch_start
if on_validation_epoch_end is not None:
self.on_validation_epoch_end = on_validation_epoch_end
if on_test_epoch_start is not None:
self.on_test_epoch_start = on_test_epoch_start
if on_test_epoch_end is not None:
self.on_test_epoch_end = on_test_epoch_end
if on_epoch_start is not None:
self.on_epoch_start = on_epoch_start
if on_epoch_end is not None:
self.on_epoch_end = on_epoch_end
if on_batch_start is not None:
self.on_batch_start = on_batch_start
if on_validation_batch_start is not None:
self.on_validation_batch_start = on_validation_batch_start
if on_validation_batch_end is not None:
self.on_validation_batch_end = on_validation_batch_end
if on_test_batch_start is not None:
self.on_test_batch_start = on_test_batch_start
if on_test_batch_end is not None:
self.on_test_batch_end = on_test_batch_end
if on_batch_end is not None:
self.on_batch_end = on_batch_end
if on_train_start is not None:
self.on_train_start = on_train_start
if on_train_end is not None:
self.on_train_end = on_train_end
if on_pretrain_routine_start is not None:
self.on_pretrain_routine_start = on_pretrain_routine_start
if on_pretrain_routine_end is not None:
self.on_pretrain_routine_end = on_pretrain_routine_end
if on_validation_start is not None:
self.on_validation_start = on_validation_start
if on_validation_end is not None:
self.on_validation_end = on_validation_end
if on_test_start is not None:
self.on_test_start = on_test_start
if on_test_end is not None:
self.on_test_end = on_test_end
if on_keyboard_interrupt is not None:
self.on_keyboard_interrupt = on_keyboard_interrupt
if on_save_checkpoint is not None:
self.on_save_checkpoint = on_save_checkpoint
if on_load_checkpoint is not None:
self.on_load_checkpoint = on_load_checkpoint
if on_after_backward is not None:
self.on_after_backward = on_after_backward
if on_before_zero_grad is not None:
self.on_before_zero_grad = on_before_zero_grad
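
A minimal usage sketch for reviewers (not part of the diff): since the provided functions are stored as plain instance attributes, they are called with the same arguments as the corresponding `Callback` methods but without `self`. The printed messages below are illustrative only.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LambdaCallback

# Attach ad-hoc behaviour to several hooks without subclassing Callback.
trainer = Trainer(
    callbacks=[
        LambdaCallback(
            on_train_start=lambda trainer, pl_module: print("training started"),
            on_train_end=lambda trainer, pl_module: print("training finished"),
        )
    ]
)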
60 changes: 60 additions & 0 deletions tests/callbacks/test_lambda_function.py
@@ -0,0 +1,60 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, LambdaCallback
from tests.base.boring_model import BoringModel


def test_lambda_call(tmpdir):
    seed_everything(42)

    class CustomModel(BoringModel):
        def on_train_epoch_start(self):
            if self.current_epoch > 1:
                raise KeyboardInterrupt

    checker = set()
    hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
    # The outer lambda binds each hook name eagerly; a plain closure over ``h``
    # would late-bind and record only the last hook name.
    hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks}
    # ``on_save_checkpoint`` is expected to return a value, so wrap the side effect.
    hooks_args["on_save_checkpoint"] = (lambda x: lambda *args: [checker.add(x)])("on_save_checkpoint")

    model = CustomModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=1,
        callbacks=[LambdaCallback(**hooks_args)],
    )
    results = trainer.fit(model)
    assert results

    model = CustomModel()
    ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        resume_from_checkpoint=ckpt_path,
        callbacks=[LambdaCallback(**hooks_args)],
    )
    results = trainer.fit(model)
    trainer.test(model)

    assert results
assert checker == set(hooks)
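
A side note on the `(lambda x: lambda *args: checker.add(x))(h)` construction in the test: it works around Python's late-binding closures. A self-contained sketch of the difference (the names here are illustrative, not from the PR):

checker = set()
hooks = ["setup", "teardown"]

# Late binding: every lambda closes over the loop variable itself,
# so each one sees its final value ("teardown").
late = {h: (lambda *args: checker.add(h)) for h in hooks}

# Eager binding: calling the outer lambda freezes each ``h`` at definition time.
eager = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks}

late["setup"]()   # adds "teardown"
eager["setup"]()  # adds "setup"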