From 0df6bd978ff5e8c7e18cd708a703fa6d56690e92 Mon Sep 17 00:00:00 2001 From: marload Date: Mon, 4 Jan 2021 21:12:49 +0900 Subject: [PATCH 01/39] Add LambdaCallback --- pytorch_lightning/callbacks/__init__.py | 19 +++++----- pytorch_lightning/callbacks/lambda_cb.py | 45 ++++++++++++++++++++++++ tests/callbacks/test_lambda_cb.py | 42 ++++++++++++++++++++++ 3 files changed, 97 insertions(+), 9 deletions(-) create mode 100644 pytorch_lightning/callbacks/lambda_cb.py create mode 100644 tests/callbacks/test_lambda_cb.py diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index d91ab92bdfc08..4595fef7a08a1 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -15,18 +15,19 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler +from pytorch_lightning.callbacks.lambda_cb import LambdaCallback from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase - __all__ = [ - 'Callback', - 'EarlyStopping', - 'GPUStatsMonitor', - 'GradientAccumulationScheduler', - 'LearningRateMonitor', - 'ModelCheckpoint', - 'ProgressBar', - 'ProgressBarBase', + "Callback", + "EarlyStopping", + "GPUStatsMonitor", + "GradientAccumulationScheduler", + "LearningRateMonitor", + "ModelCheckpoint", + "ProgressBar", + "ProgressBarBase", + "LambdaCallback", ] diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py new file mode 100644 index 0000000000000..c7f3bc6792372 --- /dev/null +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -0,0 +1,45 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r""" +Lambda Callback +^^^^^^^^^^^^^^^ + +Create a simple callback on the fly. + +""" + +from pytorch_lightning.callbacks.base import Callback + + +class LambdaCallback(Callback): + r""" + Create a simple callback on the fly. + + Args: + **kwargs: event listener supported by ``Callback`` + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import LambdaCallback + >>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))]) + """ + + def __init__(self, **kwargs): + listeners = [m for m in dir(Callback) if not m.startswith("_")] + for k, v in kwargs.items(): + if k not in listeners: + raise ValueError(f"Invalid argument `{k}`") + setattr(self, k, v) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py new file mode 100644 index 0000000000000..4a99e023c7bd0 --- /dev/null +++ b/tests/callbacks/test_lambda_cb.py @@ -0,0 +1,42 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import LambdaCallback +from tests.base.boring_model import BoringModel + + +def test_lambda_call(tmpdir): + seed_everything(42) + + model = BoringModel() + checker = set() + + callback_dicts = {"setup": lambda *args: checker.add("setup")} + test_callback = LambdaCallback(**callback_dicts) + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=1, + limit_train_batches=3, + limit_test_batches=2, + progress_bar_refresh_rate=0, + enable_pl_optimizer=True, + callbacks=[test_callback], + ) + + trainer.fit(model) + + for name in ("setup",): + assert name in checker From aa13ddf3f007e6ffb3d4cb0dcf5ea178ddc0f449 Mon Sep 17 00:00:00 2001 From: marload Date: Mon, 4 Jan 2021 13:18:43 +0100 Subject: [PATCH 02/39] docs --- CHANGELOG.md | 2 ++ docs/source/callbacks.rst | 1 + 2 files changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 83f32bbba29ba..09fcd6afcaa18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241)) +- Add LambdaCallback ([#]) + ### Changed diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index dbc7651687f20..a5caeea9943fe 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -102,6 +102,7 @@ Lightning has a few built-in callbacks. ModelCheckpoint ProgressBar ProgressBarBase + LambdaMonitor ---------- From 01bd0a5b2c3b60a84781c208ceeb11abd9fd3a1a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 4 Jan 2021 14:13:37 +0100 Subject: [PATCH 03/39] add pr link # Conflicts: # CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09fcd6afcaa18..4dc39dd95eeef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241)) -- Add LambdaCallback ([#]) +- Added `LambdaCallback` ([#5347](https://github.com/PyTorchLightning/pytorch-lightning/pull/5347)) ### Changed From b0953dd2180d6791a0482a8ed12ff7575f085a55 Mon Sep 17 00:00:00 2001 From: marload Date: Mon, 4 Jan 2021 21:21:29 +0900 Subject: [PATCH 04/39] convention --- pytorch_lightning/callbacks/__init__.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index 4595fef7a08a1..2ee8427019fb2 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -21,13 +21,13 @@ from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase __all__ = [ - "Callback", - "EarlyStopping", - "GPUStatsMonitor", - "GradientAccumulationScheduler", - "LearningRateMonitor", - "ModelCheckpoint", - "ProgressBar", - "ProgressBarBase", - "LambdaCallback", + 'Callback', + 'EarlyStopping', + 'GPUStatsMonitor', + 'GradientAccumulationScheduler', + 'LearningRateMonitor', + 'ModelCheckpoint', + 'ProgressBar', + 'ProgressBarBase', + 'LambdaCallback', ] From 7863e67fe564cab243d397aa2a401a9c9c2882b6 Mon Sep 17 00:00:00 2001 From: marload Date: Mon, 4 Jan 2021 21:39:19 +0900 Subject: [PATCH 05/39] Fix Callback Typo --- docs/source/callbacks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index a5caeea9943fe..54957caab16db 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -102,7 +102,7 @@ Lightning has a few built-in callbacks. ModelCheckpoint ProgressBar ProgressBarBase - LambdaMonitor + LambdaCallback ---------- From 6792408d060a23f3f13d517b3eb8a28ae6c7d491 Mon Sep 17 00:00:00 2001 From: Wansoo Kim Date: Mon, 4 Jan 2021 23:42:51 +0900 Subject: [PATCH 06/39] Update pytorch_lightning/callbacks/lambda_cb.py Co-authored-by: Nicki Skafte --- pytorch_lightning/callbacks/lambda_cb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index c7f3bc6792372..70da1a9c0e526 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -28,7 +28,7 @@ class LambdaCallback(Callback): Create a simple callback on the fly. Args: - **kwargs: event listener supported by ``Callback`` + **kwargs: hooks supported by ``Callback`` Example:: From d934a234fb1d497b1e3acfad8aa3b702e1b417d8 Mon Sep 17 00:00:00 2001 From: Wansoo Kim Date: Mon, 4 Jan 2021 23:42:55 +0900 Subject: [PATCH 07/39] Update pytorch_lightning/callbacks/lambda_cb.py Co-authored-by: Nicki Skafte --- pytorch_lightning/callbacks/lambda_cb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index 70da1a9c0e526..b2c89d33ea112 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -38,7 +38,7 @@ class LambdaCallback(Callback): """ def __init__(self, **kwargs): - listeners = [m for m in dir(Callback) if not m.startswith("_")] + hooks = [m for m in dir(Callback) if not m.startswith("_")] for k, v in kwargs.items(): if k not in listeners: raise ValueError(f"Invalid argument `{k}`") From 9fc981a636ebdb776c296eba37c0e961ad46e80a Mon Sep 17 00:00:00 2001 From: Wansoo Kim Date: Mon, 4 Jan 2021 23:43:03 +0900 Subject: [PATCH 08/39] Update pytorch_lightning/callbacks/lambda_cb.py Co-authored-by: Nicki Skafte --- pytorch_lightning/callbacks/lambda_cb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index b2c89d33ea112..14f44c31c957a 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -40,6 +40,6 @@ class LambdaCallback(Callback): def __init__(self, **kwargs): hooks = [m for m in dir(Callback) if not m.startswith("_")] for k, v in kwargs.items(): - if k not in listeners: + if k not in hooks: raise ValueError(f"Invalid argument `{k}`") setattr(self, k, v) From a93e468a42d78996dc273cff6237c269e0c3a008 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 10:22:42 +0900 Subject: [PATCH 09/39] use Misconfigureation --- pytorch_lightning/callbacks/lambda_cb.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index 14f44c31c957a..6b0616105af54 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -21,6 +21,7 @@ """ from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.exceptions import MisconfigurationException class LambdaCallback(Callback): @@ -41,5 +42,7 @@ def __init__(self, **kwargs): hooks = [m for m in dir(Callback) if not m.startswith("_")] for k, v in kwargs.items(): if k not in hooks: - raise ValueError(f"Invalid argument `{k}`") + raise MisconfigurationException( + f"The event function: `{k}` doesn't exist in supported callbacks function. Currently, Callback implements the following functions {dir(Callback)}" + ) setattr(self, k, v) From 2ef199f149b6c0b5c8b3891ec5f4eb88aaf54387 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 10:23:29 +0900 Subject: [PATCH 10/39] update docs --- pytorch_lightning/callbacks/lambda_cb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index 6b0616105af54..cb488dfce0da3 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -16,7 +16,7 @@ Lambda Callback ^^^^^^^^^^^^^^^ -Create a simple callback on the fly. +Create a simple callback on the fly using lambda functions. """ @@ -26,7 +26,7 @@ class LambdaCallback(Callback): r""" - Create a simple callback on the fly. + Create a simple callback on the fly using lambda functions. Args: **kwargs: hooks supported by ``Callback`` From cb294e0a7eb08ab58caa5482efe7c022f287e454 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 10:23:52 +0900 Subject: [PATCH 11/39] sort export --- pytorch_lightning/callbacks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index 2ee8427019fb2..48f3888cabd2a 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -25,9 +25,9 @@ 'EarlyStopping', 'GPUStatsMonitor', 'GradientAccumulationScheduler', + 'LambdaCallback', 'LearningRateMonitor', 'ModelCheckpoint', 'ProgressBar', 'ProgressBarBase', - 'LambdaCallback', ] From aadde9e3ffcc498d885c775a749ba4d3f32225fb Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 10:24:43 +0900 Subject: [PATCH 12/39] use inspect --- pytorch_lightning/callbacks/lambda_cb.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index cb488dfce0da3..eddeca01c1fd2 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -20,6 +20,8 @@ """ +import inspect + from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -39,7 +41,7 @@ class LambdaCallback(Callback): """ def __init__(self, **kwargs): - hooks = [m for m in dir(Callback) if not m.startswith("_")] + hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] for k, v in kwargs.items(): if k not in hooks: raise MisconfigurationException( From 8c10b1471d488858ad5eba4ac94da86b1d5d7eb7 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 10:25:54 +0900 Subject: [PATCH 13/39] string fill --- pytorch_lightning/callbacks/lambda_cb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index eddeca01c1fd2..aed9be771a4a2 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -45,6 +45,7 @@ def __init__(self, **kwargs): for k, v in kwargs.items(): if k not in hooks: raise MisconfigurationException( - f"The event function: `{k}` doesn't exist in supported callbacks function. Currently, Callback implements the following functions {dir(Callback)}" + f"The event function: `{k}` doesn't exist in supported callbacks function." + f"Currently, Callback implements the following functions {hooks}" ) setattr(self, k, v) From 39b197039e74cd4b0795037e50db2e8aa761a484 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 10:26:28 +0900 Subject: [PATCH 14/39] use fast dev run --- tests/callbacks/test_lambda_cb.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index 4a99e023c7bd0..c16e1658e9ca5 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -26,13 +26,7 @@ def test_lambda_call(tmpdir): test_callback = LambdaCallback(**callback_dicts) trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=1, - limit_train_batches=3, - limit_test_batches=2, - progress_bar_refresh_rate=0, - enable_pl_optimizer=True, + fast_dev_run=True, callbacks=[test_callback], ) From dc117670ce6d5aedbbdf4f0f99bdbc5c7b70f7f2 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 10:51:42 +0900 Subject: [PATCH 15/39] isort --- tests/callbacks/test_lambda_cb.py | 54 ++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index c16e1658e9ca5..c065291c752e7 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -11,26 +11,58 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.callbacks import LambdaCallback +import inspect + +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.callbacks import Callback, LambdaCallback from tests.base.boring_model import BoringModel def test_lambda_call(tmpdir): seed_everything(42) - model = BoringModel() checker = set() - callback_dicts = {"setup": lambda *args: checker.add("setup")} - test_callback = LambdaCallback(**callback_dicts) + hooks = [ + "setup", + "teardown", + "on_init_start", + "on_init_end", + "on_fit_start", + "on_fit_end", + "on_train_batch_start", + "on_train_batch_end", + "on_train_epoch_start", + "on_train_epoch_end", + "on_validation_epoch_start", + "on_validation_epoch_end", + "on_test_epoch_start", + "on_test_epoch_end", + "on_epoch_start", + "on_epoch_end", + "on_batch_start", + "on_batch_end", + "on_validation_batch_start", + "on_validation_batch_end", + "on_test_batch_start", + "on_test_batch_end", + "on_train_start", + "on_train_end", + "on_test_start", + "on_test_end", + ] + model = BoringModel() - trainer = Trainer( - fast_dev_run=True, - callbacks=[test_callback], - ) + hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks} + test_callback = LambdaCallback(**hooks_args) + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[test_callback]) trainer.fit(model) + trainer.test(model) + + print(len(hooks)) + print(len(checker)) - for name in ("setup",): - assert name in checker + for h in hooks: + assert h in checker From 0cfef5990e833198edb612e62b9e897aee98ed40 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 11:03:47 +0900 Subject: [PATCH 16/39] remove unused import --- tests/callbacks/test_lambda_cb.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index c065291c752e7..209bfc9c7a46d 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect - from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.callbacks import Callback, LambdaCallback +from pytorch_lightning.callbacks import LambdaCallback from tests.base.boring_model import BoringModel From 68357716ba2b1a0089767072438c39d01a7db863 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 14:52:30 +0900 Subject: [PATCH 17/39] sort --- docs/source/callbacks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index 54957caab16db..e955ad89fa998 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -98,11 +98,11 @@ Lightning has a few built-in callbacks. EarlyStopping GPUStatsMonitor GradientAccumulationScheduler + LambdaCallback LearningRateMonitor ModelCheckpoint ProgressBar ProgressBarBase - LambdaCallback ---------- From 0263e3ab87d781a82b425c20541804e22f96299a Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 14:53:10 +0900 Subject: [PATCH 18/39] hilightning --- pytorch_lightning/callbacks/lambda_cb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index aed9be771a4a2..21eb9cc8e98c2 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -46,6 +46,6 @@ def __init__(self, **kwargs): if k not in hooks: raise MisconfigurationException( f"The event function: `{k}` doesn't exist in supported callbacks function." - f"Currently, Callback implements the following functions {hooks}" + f"Currently, `Callback` implements the following functions {hooks}" ) setattr(self, k, v) From 7249a1055a08082e8c2a60536c3184b8babe3d89 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 14:53:38 +0900 Subject: [PATCH 19/39] highlighting --- pytorch_lightning/callbacks/lambda_cb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index 21eb9cc8e98c2..5c0ac2bb2fa0c 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -46,6 +46,6 @@ def __init__(self, **kwargs): if k not in hooks: raise MisconfigurationException( f"The event function: `{k}` doesn't exist in supported callbacks function." - f"Currently, `Callback` implements the following functions {hooks}" + f"Currently, Callback` implements the following functions {hooks}" ) setattr(self, k, v) From 3038d2f5db3aa41c0260fed7d71f203aa62a82f3 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 14:53:44 +0900 Subject: [PATCH 20/39] highlighting --- pytorch_lightning/callbacks/lambda_cb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index 5c0ac2bb2fa0c..21eb9cc8e98c2 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -46,6 +46,6 @@ def __init__(self, **kwargs): if k not in hooks: raise MisconfigurationException( f"The event function: `{k}` doesn't exist in supported callbacks function." - f"Currently, Callback` implements the following functions {hooks}" + f"Currently, `Callback` implements the following functions {hooks}" ) setattr(self, k, v) From c400b9893188d0884b83e75e8f3350ef9bbf138a Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 14:54:15 +0900 Subject: [PATCH 21/39] remove debug log --- tests/callbacks/test_lambda_cb.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index 209bfc9c7a46d..5e3205bd81fde 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -59,8 +59,5 @@ def test_lambda_call(tmpdir): trainer.fit(model) trainer.test(model) - print(len(hooks)) - print(len(checker)) - for h in hooks: assert h in checker From 8518382d0c9bce4a34ee995dde4e240181f2557f Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 14:55:39 +0900 Subject: [PATCH 22/39] eq --- tests/callbacks/test_lambda_cb.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index 5e3205bd81fde..93ed24a1e1b3c 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -59,5 +59,4 @@ def test_lambda_call(tmpdir): trainer.fit(model) trainer.test(model) - for h in hooks: - assert h in checker + assert checker == set(hooks) From 8bfe53ed6f5cf74abfcc7a084baab942951627d8 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 14:58:04 +0900 Subject: [PATCH 23/39] res --- tests/callbacks/test_lambda_cb.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index 93ed24a1e1b3c..a72c670f71d63 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -56,7 +56,8 @@ def test_lambda_call(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[test_callback]) - trainer.fit(model) + result = trainer.fit(model) trainer.test(model) + assert result assert checker == set(hooks) From 9fd4c6b41424361a89b2b3cade8c3d25274dbc2e Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 14:58:37 +0900 Subject: [PATCH 24/39] results --- tests/callbacks/test_lambda_cb.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index a72c670f71d63..b2b1dcb91e61a 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -56,8 +56,8 @@ def test_lambda_call(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[test_callback]) - result = trainer.fit(model) + results = trainer.fit(model) trainer.test(model) - assert result + assert results assert checker == set(hooks) From c4563c721cc67916c9c74b9e47a071f984d10575 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 15:05:44 +0900 Subject: [PATCH 25/39] add misconfig exception test --- tests/callbacks/test_lambda_cb.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index b2b1dcb91e61a..ea7b419885d2d 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -13,9 +13,18 @@ # limitations under the License. from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import LambdaCallback +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel +def test_lambda_raise_misconfiguration(): + try: + LambdaCallback(invalid=lambda *args: args) + assert False + except MisconfigurationException: + assert True + + def test_lambda_call(tmpdir): seed_everything(42) From a329d4a00f94af4e1a2079e94a7b0dd362aa8391 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 15:24:45 +0900 Subject: [PATCH 26/39] use pytest raises --- tests/callbacks/test_lambda_cb.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index ea7b419885d2d..718efeff0b230 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest + from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import LambdaCallback from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -18,11 +20,8 @@ def test_lambda_raise_misconfiguration(): - try: - LambdaCallback(invalid=lambda *args: args) - assert False - except MisconfigurationException: - assert True + with pytest.raises(MisconfigurationException): + LambdaCallback(setup=lambda *args: args) def test_lambda_call(tmpdir): From d1f8d4a94122c3e6cb1af258148bc58e3af8d7aa Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 5 Jan 2021 18:36:01 +0900 Subject: [PATCH 27/39] fix --- tests/callbacks/test_lambda_cb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index 718efeff0b230..88037352f52c5 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -21,7 +21,7 @@ def test_lambda_raise_misconfiguration(): with pytest.raises(MisconfigurationException): - LambdaCallback(setup=lambda *args: args) + LambdaCallback(invalid=lambda *args: args) def test_lambda_call(tmpdir): From 72931158d9d21ec00b2f8b83ac095b8167e93992 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 6 Jan 2021 01:25:33 +0100 Subject: [PATCH 28/39] Apply suggestions from code review Co-authored-by: Rohit Gupta --- pytorch_lightning/callbacks/lambda_cb.py | 4 ++-- tests/callbacks/test_lambda_cb.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index 21eb9cc8e98c2..b0f6e229891c1 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -45,7 +45,7 @@ def __init__(self, **kwargs): for k, v in kwargs.items(): if k not in hooks: raise MisconfigurationException( - f"The event function: `{k}` doesn't exist in supported callbacks function." - f"Currently, `Callback` implements the following functions {hooks}" + f"The event function: `{k}` does not exist in supported callbacks function." + f" Currently, `Callback` implements the following functions {hooks}" ) setattr(self, k, v) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index 88037352f52c5..fedaa177b0e7d 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -20,7 +20,7 @@ def test_lambda_raise_misconfiguration(): - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match='does not exist in supported callbacks function'): LambdaCallback(invalid=lambda *args: args) From 4d85f59ad72bbbf2ef01dd1505a6d1667af99ee5 Mon Sep 17 00:00:00 2001 From: Wansoo Kim Date: Wed, 6 Jan 2021 09:46:58 +0900 Subject: [PATCH 29/39] Update pytorch_lightning/callbacks/lambda_cb.py Co-authored-by: Rohit Gupta --- pytorch_lightning/callbacks/lambda_cb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index b0f6e229891c1..3020b9b1d0938 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -31,7 +31,7 @@ class LambdaCallback(Callback): Create a simple callback on the fly using lambda functions. Args: - **kwargs: hooks supported by ``Callback`` + **kwargs: hooks supported by :class:`~pytorch_lightning.callbacks.base.Callback` Example:: From c9ecb8aaab82578aa43be36fc6f3d05542e9e32e Mon Sep 17 00:00:00 2001 From: marload Date: Wed, 6 Jan 2021 12:09:19 +0900 Subject: [PATCH 30/39] hc --- pytorch_lightning/callbacks/lambda_cb.py | 126 +++++++++++++++++++++-- tests/callbacks/test_lambda_cb.py | 6 -- 2 files changed, 115 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index 3020b9b1d0938..ee5d87e830f33 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -20,10 +20,9 @@ """ -import inspect +from typing import Callable, Optional from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities.exceptions import MisconfigurationException class LambdaCallback(Callback): @@ -40,12 +39,117 @@ class LambdaCallback(Callback): >>> trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args: print('setup'))]) """ - def __init__(self, **kwargs): - hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] - for k, v in kwargs.items(): - if k not in hooks: - raise MisconfigurationException( - f"The event function: `{k}` does not exist in supported callbacks function." - f" Currently, `Callback` implements the following functions {hooks}" - ) - setattr(self, k, v) + def __init__( + self, + setup: Optional[Callable] = None, + teardown: Optional[Callable] = None, + on_init_start: Optional[Callable] = None, + on_init_end: Optional[Callable] = None, + on_fit_start: Optional[Callable] = None, + on_fit_end: Optional[Callable] = None, + on_sanity_check_start: Optional[Callable] = None, + on_sanity_check_end: Optional[Callable] = None, + on_train_batch_start: Optional[Callable] = None, + on_train_batch_end: Optional[Callable] = None, + on_train_epoch_start: Optional[Callable] = None, + on_train_epoch_end: Optional[Callable] = None, + on_validation_epoch_start: Optional[Callable] = None, + on_validation_epoch_end: Optional[Callable] = None, + on_test_epoch_start: Optional[Callable] = None, + on_test_epoch_end: Optional[Callable] = None, + on_epoch_start: Optional[Callable] = None, + on_epoch_end: Optional[Callable] = None, + on_batch_start: Optional[Callable] = None, + on_validation_batch_start: Optional[Callable] = None, + on_validation_batch_end: Optional[Callable] = None, + on_test_batch_start: Optional[Callable] = None, + on_test_batch_end: Optional[Callable] = None, + on_batch_end: Optional[Callable] = None, + on_train_start: Optional[Callable] = None, + on_train_end: Optional[Callable] = None, + on_pretrain_routine_start: Optional[Callable] = None, + on_pretrain_routine_end: Optional[Callable] = None, + on_validation_start: Optional[Callable] = None, + on_validation_end: Optional[Callable] = None, + on_test_start: Optional[Callable] = None, + on_test_end: Optional[Callable] = None, + on_keyboard_interrupt: Optional[Callable] = None, + on_save_checkpoint: Optional[Callable] = None, + on_load_checkpoint: Optional[Callable] = None, + on_after_backward: Optional[Callable] = None, + on_before_zero_grad: Optional[Callable] = None, + ): + if setup is not None: + self.setup = setup + if teardown is not None: + self.teardown = teardown + if on_init_start is not None: + self.on_init_start = on_init_start + if on_init_end is not None: + self.on_init_end = on_init_end + if on_fit_start is not None: + self.on_fit_start = on_fit_start + if on_fit_end is not None: + self.on_fit_end = on_fit_end + if on_sanity_check_start is not None: + self.on_sanity_check_start = on_sanity_check_start + if on_sanity_check_end is not None: + self.on_sanity_check_end = on_sanity_check_end + if on_train_batch_start is not None: + self.on_train_batch_start = on_train_batch_start + if on_train_batch_end is not None: + self.on_train_batch_end = on_train_batch_end + if on_train_epoch_start is not None: + self.on_train_epoch_start = on_train_epoch_start + if on_train_epoch_end is not None: + self.on_train_epoch_end = on_train_epoch_end + if on_validation_epoch_start is not None: + self.on_validation_epoch_start = on_validation_epoch_start + if on_validation_epoch_end is not None: + self.on_validation_epoch_end = on_validation_epoch_end + if on_test_epoch_start is not None: + self.on_test_epoch_start = on_test_epoch_start + if on_test_epoch_end is not None: + self.on_test_epoch_end = on_test_epoch_end + if on_epoch_start is not None: + self.on_epoch_start = on_epoch_start + if on_epoch_end is not None: + self.on_epoch_end = on_epoch_end + if on_batch_start is not None: + self.on_batch_start = on_batch_start + if on_validation_batch_start is not None: + self.on_validation_batch_start = on_validation_batch_start + if on_validation_batch_end is not None: + self.on_validation_batch_end = on_validation_batch_end + if on_test_batch_start is not None: + self.on_test_batch_start = on_test_batch_start + if on_test_batch_end is not None: + self.on_test_batch_end = on_test_batch_end + if on_batch_end is not None: + self.on_batch_end = on_batch_end + if on_train_start is not None: + self.on_train_start = on_train_start + if on_train_end is not None: + self.on_train_end = on_train_end + if on_pretrain_routine_start is not None: + self.on_pretrain_routine_start = on_pretrain_routine_start + if on_pretrain_routine_end is not None: + self.on_pretrain_routine_end = on_pretrain_routine_end + if on_validation_start is not None: + self.on_validation_start = on_validation_start + if on_validation_end is not None: + self.on_validation_end = on_validation_end + if on_test_start is not None: + self.on_test_start = on_test_start + if on_test_end is not None: + self.on_test_end = on_test_end + if on_keyboard_interrupt is not None: + self.on_keyboard_interrupt = on_keyboard_interrupt + if on_save_checkpoint is not None: + self.on_save_checkpoint = on_save_checkpoint + if on_load_checkpoint is not None: + self.on_load_checkpoint = on_load_checkpoint + if on_after_backward is not None: + self.on_after_backward = on_after_backward + if on_before_zero_grad is not None: + self.on_before_zero_grad = on_before_zero_grad diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index fedaa177b0e7d..ce5ef7edbd04c 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -15,15 +15,9 @@ from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import LambdaCallback -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel -def test_lambda_raise_misconfiguration(): - with pytest.raises(MisconfigurationException, match='does not exist in supported callbacks function'): - LambdaCallback(invalid=lambda *args: args) - - def test_lambda_call(tmpdir): seed_everything(42) From 204429149d57700c274238e809b841f5cf5eda43 Mon Sep 17 00:00:00 2001 From: marload Date: Wed, 6 Jan 2021 12:14:03 +0900 Subject: [PATCH 31/39] rm pt --- tests/callbacks/test_lambda_cb.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index ce5ef7edbd04c..b2b1dcb91e61a 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytest - from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import LambdaCallback from tests.base.boring_model import BoringModel From 556ea0928549a7573b8bd4daa4cac9da7ec74e3d Mon Sep 17 00:00:00 2001 From: marload Date: Fri, 8 Jan 2021 19:22:06 +0900 Subject: [PATCH 32/39] fix --- tests/callbacks/test_lambda_cb.py | 43 ++++++++----------------------- 1 file changed, 11 insertions(+), 32 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index b2b1dcb91e61a..f6f2110048e7e 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning import Trainer, seed_everything -from pytorch_lightning.callbacks import LambdaCallback +import inspect + +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import Callback, LambdaCallback from tests.base.boring_model import BoringModel @@ -21,41 +23,18 @@ def test_lambda_call(tmpdir): checker = set() - hooks = [ - "setup", - "teardown", - "on_init_start", - "on_init_end", - "on_fit_start", - "on_fit_end", - "on_train_batch_start", - "on_train_batch_end", - "on_train_epoch_start", - "on_train_epoch_end", - "on_validation_epoch_start", - "on_validation_epoch_end", - "on_test_epoch_start", - "on_test_epoch_end", - "on_epoch_start", - "on_epoch_end", - "on_batch_start", - "on_batch_end", - "on_validation_batch_start", - "on_validation_batch_end", - "on_test_batch_start", - "on_test_batch_end", - "on_train_start", - "on_train_end", - "on_test_start", - "on_test_end", - ] + hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] model = BoringModel() hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks} test_callback = LambdaCallback(**hooks_args) - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=[test_callback]) - + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=1, + max_epochs=1, + callbacks=[test_callback] + ) results = trainer.fit(model) trainer.test(model) From d190e15fa68ab4dbe0d39121bd24ef0d73a2bfa3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 10 Jan 2021 01:07:22 +0530 Subject: [PATCH 33/39] try fix --- tests/callbacks/test_lambda_cb.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index f6f2110048e7e..fc76f3258f6ee 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -21,19 +21,37 @@ def test_lambda_call(tmpdir): seed_everything(42) - checker = set() + class CustomModel(BoringModel): + def on_train_epoch_start(self): + if self.current_epoch > 1: + raise KeyboardInterrupt + checker = set() hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] - model = BoringModel() - hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks} - test_callback = LambdaCallback(**hooks_args) + hooks_args['on_save_checkpoint'] = (lambda x: lambda *args: [checker.add(x)])('on_save_checkpoint') + model = CustomModel() trainer = Trainer( default_root_dir=tmpdir, - num_sanity_val_steps=1, max_epochs=1, - callbacks=[test_callback] + limit_train_batches=1, + limit_val_batches=1, + callbacks=[LambdaCallback(**hooks_args)], + ) + results = trainer.fit(model) + assert results + + model = CustomModel() + ckpt_path = trainer.checkpoint_callback.best_model_path + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + resume_from_checkpoint=ckpt_path, + callbacks=[LambdaCallback(**hooks_args)], ) results = trainer.fit(model) trainer.test(model) From a27dbff4355e3372eb1da842f404fac223dd1ac5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 10 Jan 2021 01:13:03 +0530 Subject: [PATCH 34/39] whitespace --- pytorch_lightning/callbacks/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index e1e0c2cd6fad4..701e45ab17323 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -22,8 +22,8 @@ from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase __all__ = [ - 'BackboneLambdaFinetuningCallback', - 'BaseFinetuningCallback', + 'BackboneLambdaFinetuningCallback', + 'BaseFinetuningCallback', 'Callback', 'EarlyStopping', 'GPUStatsMonitor', From d1bd19a1430d057ac4c6f348e5c5d260077675a2 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 10 Jan 2021 01:29:23 +0530 Subject: [PATCH 35/39] new hook --- pytorch_lightning/callbacks/lambda_cb.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_cb.py index ee5d87e830f33..2d111e7da7acd 100644 --- a/pytorch_lightning/callbacks/lambda_cb.py +++ b/pytorch_lightning/callbacks/lambda_cb.py @@ -41,6 +41,7 @@ class LambdaCallback(Callback): def __init__( self, + on_before_accelerator_backend_setup: Optional[Callable] = None, setup: Optional[Callable] = None, teardown: Optional[Callable] = None, on_init_start: Optional[Callable] = None, @@ -79,6 +80,8 @@ def __init__( on_after_backward: Optional[Callable] = None, on_before_zero_grad: Optional[Callable] = None, ): + if on_before_accelerator_backend_setup is not None: + self.on_before_accelerator_backend_setup = on_before_accelerator_backend_setup if setup is not None: self.setup = setup if teardown is not None: From afe018afd844ae8af5f09f822638b898c37488a1 Mon Sep 17 00:00:00 2001 From: marload Date: Sun, 10 Jan 2021 14:44:39 +0900 Subject: [PATCH 36/39] add raise --- tests/callbacks/test_lambda_cb.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index fc76f3258f6ee..a8d5f49805461 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -13,11 +13,19 @@ # limitations under the License. import inspect +import pytest + from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import Callback, LambdaCallback +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel +def test_lambda_raise_misconfiguration(): + with pytest.raises(MisconfigurationException, match="does not exist in supported callbacks function"): + LambdaCallback(invalid=lambda *args: args) + + def test_lambda_call(tmpdir): seed_everything(42) @@ -29,7 +37,7 @@ def on_train_epoch_start(self): checker = set() hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)] hooks_args = {h: (lambda x: lambda *args: checker.add(x))(h) for h in hooks} - hooks_args['on_save_checkpoint'] = (lambda x: lambda *args: [checker.add(x)])('on_save_checkpoint') + hooks_args["on_save_checkpoint"] = (lambda x: lambda *args: [checker.add(x)])("on_save_checkpoint") model = CustomModel() trainer = Trainer( From 709fb5b7a09dd8d266964faf08d117953b54d02d Mon Sep 17 00:00:00 2001 From: marload Date: Sun, 10 Jan 2021 15:04:30 +0900 Subject: [PATCH 37/39] fix --- tests/callbacks/test_lambda_cb.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index a8d5f49805461..ef03af16224f8 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -21,11 +21,6 @@ from tests.base.boring_model import BoringModel -def test_lambda_raise_misconfiguration(): - with pytest.raises(MisconfigurationException, match="does not exist in supported callbacks function"): - LambdaCallback(invalid=lambda *args: args) - - def test_lambda_call(tmpdir): seed_everything(42) From 9b93a2c1c674785c957c2bba2339cdd664b9c3d8 Mon Sep 17 00:00:00 2001 From: marload Date: Sun, 10 Jan 2021 15:07:29 +0900 Subject: [PATCH 38/39] remove unused --- tests/callbacks/test_lambda_cb.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_cb.py index ef03af16224f8..a22a03fa369ff 100644 --- a/tests/callbacks/test_lambda_cb.py +++ b/tests/callbacks/test_lambda_cb.py @@ -13,11 +13,8 @@ # limitations under the License. import inspect -import pytest - from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import Callback, LambdaCallback -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel From 72f3f0c86179613ef65e806b92ef779cc80e4835 Mon Sep 17 00:00:00 2001 From: marload Date: Tue, 12 Jan 2021 19:09:01 +0900 Subject: [PATCH 39/39] rename --- pytorch_lightning/callbacks/__init__.py | 2 +- .../callbacks/{lambda_cb.py => lambda_function.py} | 0 tests/callbacks/{test_lambda_cb.py => test_lambda_function.py} | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename pytorch_lightning/callbacks/{lambda_cb.py => lambda_function.py} (100%) rename tests/callbacks/{test_lambda_cb.py => test_lambda_function.py} (100%) diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index 701e45ab17323..a03dbcca85f7f 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -16,7 +16,7 @@ from pytorch_lightning.callbacks.finetuning import BackboneLambdaFinetuningCallback, BaseFinetuningCallback from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler -from pytorch_lightning.callbacks.lambda_cb import LambdaCallback +from pytorch_lightning.callbacks.lambda_function import LambdaCallback from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase diff --git a/pytorch_lightning/callbacks/lambda_cb.py b/pytorch_lightning/callbacks/lambda_function.py similarity index 100% rename from pytorch_lightning/callbacks/lambda_cb.py rename to pytorch_lightning/callbacks/lambda_function.py diff --git a/tests/callbacks/test_lambda_cb.py b/tests/callbacks/test_lambda_function.py similarity index 100% rename from tests/callbacks/test_lambda_cb.py rename to tests/callbacks/test_lambda_function.py