diff --git a/CHANGELOG.md b/CHANGELOG.md
index ab74be455ad54..1c589733fb367 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))
+
+- Added `LambdaCallback` ([#5347](https://github.com/PyTorchLightning/pytorch-lightning/pull/5347))
+
+
 - Added `BackboneLambdaFinetuningCallback` ([#5377](https://github.com/PyTorchLightning/pytorch-lightning/pull/5377))
diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst
index dbc7651687f20..e955ad89fa998 100644
--- a/docs/source/callbacks.rst
+++ b/docs/source/callbacks.rst
@@ -98,6 +98,7 @@ Lightning has a few built-in callbacks.
     EarlyStopping
     GPUStatsMonitor
     GradientAccumulationScheduler
+    LambdaCallback
     LearningRateMonitor
     ModelCheckpoint
     ProgressBar
diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py
index 47fc865cadb66..a03dbcca85f7f 100644
--- a/pytorch_lightning/callbacks/__init__.py
+++ b/pytorch_lightning/callbacks/__init__.py
@@ -16,17 +16,19 @@
 from pytorch_lightning.callbacks.finetuning import BackboneLambdaFinetuningCallback, BaseFinetuningCallback
 from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
 from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
+from pytorch_lightning.callbacks.lambda_function import LambdaCallback
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase

 __all__ = [
+    'BackboneLambdaFinetuningCallback',
+    'BaseFinetuningCallback',
     'Callback',
     'EarlyStopping',
     'GPUStatsMonitor',
     'GradientAccumulationScheduler',
-    'BaseFinetuningCallback',
-    'BackboneLambdaFinetuningCallback',
+    'LambdaCallback',
     'LearningRateMonitor',
     'ModelCheckpoint',
     'ProgressBar',
diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
new file mode 100644
index 0000000000000..2d111e7da7acd
--- /dev/null
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -0,0 +1,158 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""
+Lambda Callback
+^^^^^^^^^^^^^^^
+
+Create a simple callback on the fly using lambda functions.
+
+"""
+
+from typing import Callable, Optional
+
+from pytorch_lightning.callbacks.base import Callback
+
+
+class LambdaCallback(Callback):
+    r"""
+    Create a simple callback on the fly using lambda functions.
+
+    Args:
+        **kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback`
+
+    Example::
+
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import LambdaCallback
+        >>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))])
+    """
+
+    def __init__(
+        self,
+        on_before_accelerator_backend_setup: Optional[Callable] = None,
+        setup: Optional[Callable] = None,
+        teardown: Optional[Callable] = None,
+        on_init_start: Optional[Callable] = None,
+        on_init_end: Optional[Callable] = None,
+        on_fit_start: Optional[Callable] = None,
+        on_fit_end: Optional[Callable] = None,
+        on_sanity_check_start: Optional[Callable] = None,
+        on_sanity_check_end: Optional[Callable] = None,
+        on_train_batch_start: Optional[Callable] = None,
+        on_train_batch_end: Optional[Callable] = None,
+        on_train_epoch_start: Optional[Callable] = None,
+        on_train_epoch_end: Optional[Callable] = None,
+        on_validation_epoch_start: Optional[Callable] = None,
+        on_validation_epoch_end: Optional[Callable] = None,
+        on_test_epoch_start: Optional[Callable] = None,
+        on_test_epoch_end: Optional[Callable] = None,
+        on_epoch_start: Optional[Callable] = None,
+        on_epoch_end: Optional[Callable] = None,
+        on_batch_start: Optional[Callable] = None,
+        on_validation_batch_start: Optional[Callable] = None,
+        on_validation_batch_end: Optional[Callable] = None,
+        on_test_batch_start: Optional[Callable] = None,
+        on_test_batch_end: Optional[Callable] = None,
+        on_batch_end: Optional[Callable] = None,
+        on_train_start: Optional[Callable] = None,
+        on_train_end: Optional[Callable] = None,
+        on_pretrain_routine_start: Optional[Callable] = None,
+        on_pretrain_routine_end: Optional[Callable] = None,
+        on_validation_start: Optional[Callable] = None,
+        on_validation_end: Optional[Callable] = None,
+        on_test_start: Optional[Callable] = None,
+        on_test_end: Optional[Callable] = None,
+        on_keyboard_interrupt: Optional[Callable] = None,
+        on_save_checkpoint: Optional[Callable] = None,
+        on_load_checkpoint: Optional[Callable] = None,
+        on_after_backward: Optional[Callable] = None,
+        on_before_zero_grad: Optional[Callable] = None,
+    ):
+        if on_before_accelerator_backend_setup is not None:
+            self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup
+        if setup is not None:
+            self.setup = setup
+        if teardown is not None:
+            self.teardown = teardown
+        if on_init_start is not None:
+            self.on_init_start = on_init_start
+        if on_init_end is not None:
+            self.on_init_end = on_init_end
+        if on_fit_start is not None:
+            self.on_fit_start = on_fit_start
+        if on_fit_end is not None:
+            self.on_fit_end = on_fit_end
+        if on_sanity_check_start is not None:
+            self.on_sanity_check_start = on_sanity_check_start
+        if on_sanity_check_end is not None:
+            self.on_sanity_check_end = on_sanity_check_end
+        if on_train_batch_start is not None:
+            self.on_train_batch_start = on_train_batch_start
+        if on_train_batch_end is not None:
+            self.on_train_batch_end = on_train_batch_end
+        if on_train_epoch_start is not None:
+            self.on_train_epoch_start = on_train_epoch_start
+        if on_train_epoch_end is not None:
+            self.on_train_epoch_end = on_train_epoch_end
+        if on_validation_epoch_start is not None:
+            self.on_validation_epoch_start = on_validation_epoch_start
+        if on_validation_epoch_end is not None:
+            self.on_validation_epoch_end = on_validation_epoch_end
+        if on_test_epoch_start is not None:
+            self.on_test_epoch_start = on_test_epoch_start
+        if on_test_epoch_end is not None:
+            self.on_test_epoch_end = on_test_epoch_end
+        if on_epoch_start is not None:
+            self.on_epoch_start = on_epoch_start
+        if on_epoch_end is not None:
+            self.on_epoch_end = on_epoch_end
+        if on_batch_start is not None:
+            self.on_batch_start = on_batch_start
+        if on_validation_batch_start is not None:
+            self.on_validation_batch_start = on_validation_batch_start
+        if on_validation_batch_end is not None:
+            self.on_validation_batch_end = on_validation_batch_end
+        if on_test_batch_start is not None:
+            self.on_test_batch_start = on_test_batch_start
+        if on_test_batch_end is not None:
+            self.on_test_batch_end = on_test_batch_end
+        if on_batch_end is not None:
+            self.on_batch_end = on_batch_end
+        if on_train_start is not None:
+            self.on_train_start = on_train_start
+        if on_train_end is not None:
+            self.on_train_end = on_train_end
+        if on_pretrain_routine_start is not None:
+            self.on_pretrain_routine_start = on_pretrain_routine_start
+        if on_pretrain_routine_end is not None:
+            self.on_pretrain_routine_end = on_pretrain_routine_end
+        if on_validation_start is not None:
+            self.on_validation_start = on_validation_start
+        if on_validation_end is not None:
+            self.on_validation_end = on_validation_end
+        if on_test_start is not None:
+            self.on_test_start = on_test_start
+        if on_test_end is not None:
+            self.on_test_end = on_test_end
+        if on_keyboard_interrupt is not None:
+            self.on_keyboard_interrupt = on_keyboard_interrupt
+        if on_save_checkpoint is not None:
+            self.on_save_checkpoint = on_save_checkpoint
+        if on_load_checkpoint is not None:
+            self.on_load_checkpoint = on_load_checkpoint
+        if on_after_backward is not None:
+            self.on_after_backward = on_after_backward
+        if on_before_zero_grad is not None:
+            self.on_before_zero_grad = on_before_zero_grad
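For orientation, a minimal usage sketch of the callback added above; the hook callables below are placeholders chosen for illustration and are not part of this patch::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LambdaCallback

    # Each keyword argument maps a Callback hook name to a callable; hooks that
    # are not passed keep the no-op implementations inherited from the base Callback.
    trainer = Trainer(
        callbacks=[
            LambdaCallback(
                on_train_start=lambda *args: print("training starts"),
                on_train_end=lambda *args: print("training ends"),
            )
        ]
    )
    # trainer.fit(model)  # `model` would be any LightningModule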
diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py
new file mode 100644
index 0000000000000..a22a03fa369ff
--- /dev/null
+++ b/tests/callbacks/test_lambda_function.py
@@ -0,0 +1,60 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+
+from pytorch_lightning import seed_everything, Trainer
+from pytorch_lightning.callbacks import Callback, LambdaCallback
+from tests.base.boring_model import BoringModel
+
+
+def test_lambda_call(tmpdir):
+    seed_everything(42)
+
+    class CustomModel(BoringModel):
+        def on_train_epoch_start(self):
+            if self.current_epoch > 1:
+                raise KeyboardInterrupt
+
+    checker = set()
+    hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
+    hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks}
+    hooks_args["on_save_checkpoint"] = (lambda x: lambda *args: [checker.add(x)])("on_save_checkpoint")
+
+    model = CustomModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        callbacks=[LambdaCallback(**hooks_args)],
+    )
+    results = trainer.fit(model)
+    assert results
+
+    model = CustomModel()
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        limit_test_batches=1,
+        resume_from_checkpoint=ckpt_path,
+        callbacks=[LambdaCallback(**hooks_args)],
+    )
+    results = trainer.fit(model)
+    trainer.test(model)
+
+    assert results
+    assert checker == set(hooks)
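A side note on the `(lambda x: lambda *args: checker.add(x))(h)` pattern in the test above: Python closures capture variables rather than values, so a plain `lambda *args: checker.add(h)` built inside the comprehension would record only the last hook name for every entry. Immediately calling an outer lambda with `h` freezes the current value. A standalone sketch of the difference (illustrative only, not taken from the patch)::

    names = ["setup", "teardown"]

    late = [lambda: n for n in names]                   # every lambda shares the variable `n`
    bound = [(lambda x: lambda: x)(n) for n in names]   # every lambda keeps its own value

    print([f() for f in late])   # ['teardown', 'teardown']
    print([f() for f in bound])  # ['setup', 'teardown']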