From a15a29cc85ef7ec74851c1cc4284993cc71ca025 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 27 Apr 2021 23:13:34 -0700 Subject: [PATCH 1/7] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fa42a75c24829..76948428b3094 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1150,6 +1150,7 @@ def call_teardown_hook(self, model: LightningModule) -> None: self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) + model._current_fx_name = "" def _reset_result_and_set_hook_fx_name(self, hook_name: str) -> bool: # on_before_zero_grad is called within training_step From e8b21730d15942e65d3eaa99d9b3a08ba2fade49 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 27 Apr 2021 23:44:22 -0700 Subject: [PATCH 2/7] cleanup module properties in teardown --- pytorch_lightning/trainer/trainer.py | 3 +++ tests/trainer/test_trainer.py | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 76948428b3094..9bcbbc0e7e9a5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1151,6 +1151,9 @@ def call_teardown_hook(self, model: LightningModule) -> None: self.teardown(stage=state) model.teardown(stage=state) model._current_fx_name = "" + model._current_hook_fx_name = None + model._current_dataloader_idx = None + def _reset_result_and_set_hook_fx_name(self, hook_name: str) -> bool: # on_before_zero_grad is called within training_step diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3a35912fa7936..5bfda485f7e4b 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2047,3 +2047,25 @@ def test_fit_test_synchronization(tmpdir): trainer.fit(model) assert os.path.exists(checkpoint.best_model_path), f'Could not find checkpoint at rank {trainer.global_rank}' trainer.test() + + +def test_module_current_fx_attributes_reset(tmpdir): + """ Ensure that lightning module's attributes related to current hook fx are reset at the end of execution. """ + model = BoringModel() + model.validation_step = None + model.training_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + checkpoint_callback=False, + logger=False, + limit_val_batches=0, + ) + trainer.fit(model) + assert model._current_fx_name == "", f"module._current_fx_name not cleaned up after fit: {model._current_fx_name}" + assert model._current_hook_fx_name is None, f"{model._current_hook_fx_name}" + assert model._current_dataloader_idx is None, f"{model._current_dataloader_idx}" + trainer.test(model) + assert model._current_fx_name == "", f"module._current_fx_name not cleaned up after test: {model._current_fx_name}" + assert model._current_hook_fx_name is None, f"module._current_hook_fx_name not cleaned up after test: {model._current_hook_fx_name}" + assert model._current_dataloader_idx is None, f"module._current_dataloader_idx not cleaned up after test: {model._current_dataloader_idx}" From c327a35b651a078522e6fb2e6d032f9c7756be61 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 28 Apr 2021 00:17:36 -0700 Subject: [PATCH 3/7] Update test_trainer.py --- tests/trainer/test_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5bfda485f7e4b..676a264885ff1 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2062,10 +2062,10 @@ def test_module_current_fx_attributes_reset(tmpdir): limit_val_batches=0, ) trainer.fit(model) - assert model._current_fx_name == "", f"module._current_fx_name not cleaned up after fit: {model._current_fx_name}" - assert model._current_hook_fx_name is None, f"{model._current_hook_fx_name}" - assert model._current_dataloader_idx is None, f"{model._current_dataloader_idx}" + assert model._current_fx_name == "", f"_current_fx_name not reset after fit: {model._current_fx_name}" + assert model._current_hook_fx_name is None, f"_current_hook_fx_name not reset after fit: {model._current_hook_fx_name}" + assert model._current_dataloader_idx is None, f"_current_dataloader_idx not reset after fit: {model._current_dataloader_idx}" trainer.test(model) - assert model._current_fx_name == "", f"module._current_fx_name not cleaned up after test: {model._current_fx_name}" - assert model._current_hook_fx_name is None, f"module._current_hook_fx_name not cleaned up after test: {model._current_hook_fx_name}" - assert model._current_dataloader_idx is None, f"module._current_dataloader_idx not cleaned up after test: {model._current_dataloader_idx}" + assert model._current_fx_name == "", f"_current_fx_name not reset after test: {model._current_fx_name}" + assert model._current_hook_fx_name is None, f"_current_hook_fx_name not reset after test: {model._current_hook_fx_name}" + assert model._current_dataloader_idx is None, f"_current_dataloader_idx not reset after test: {model._current_dataloader_idx}" From fb8cc29895cb184bc6260e7dde88ac5e826386a7 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 28 Apr 2021 00:25:00 -0700 Subject: [PATCH 4/7] Update lightning.py --- pytorch_lightning/core/lightning.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 80af2fbc1a9bb..996b08522d166 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -92,19 +92,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._device_type = None #: True if using amp - self.use_amp = False + self.use_amp: bool = False #: The precision used - self.precision = 32 + self.precision: int = 32 # optionally can be set by user self._example_input_array = None self._datamodule = None self._results: Optional[Result] = None - self._current_fx_name = '' - self._running_manual_backward = False - self._current_hook_fx_name = None - self._current_dataloader_idx = None + self._current_fx_name: str = '' + self._running_manual_backward: bool = False + self._current_hook_fx_name: Optional[str] = None + self._current_dataloader_idx: Optional[int] = None self._automatic_optimization: bool = True self._param_requires_grad_state = dict() From dd6cf9ddf4883d76b9ddb35e73676e738082c13c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 28 Apr 2021 13:43:35 +0200 Subject: [PATCH 5/7] Formatting --- tests/trainer/test_trainer.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 676a264885ff1..0fb060ef31903 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2063,9 +2063,17 @@ def test_module_current_fx_attributes_reset(tmpdir): ) trainer.fit(model) assert model._current_fx_name == "", f"_current_fx_name not reset after fit: {model._current_fx_name}" - assert model._current_hook_fx_name is None, f"_current_hook_fx_name not reset after fit: {model._current_hook_fx_name}" - assert model._current_dataloader_idx is None, f"_current_dataloader_idx not reset after fit: {model._current_dataloader_idx}" + assert ( + model._current_hook_fx_name is None + ), f"_current_hook_fx_name not reset after fit: {model._current_hook_fx_name}" + assert ( + model._current_dataloader_idx is None + ), f"_current_dataloader_idx not reset after fit: {model._current_dataloader_idx}" trainer.test(model) assert model._current_fx_name == "", f"_current_fx_name not reset after test: {model._current_fx_name}" - assert model._current_hook_fx_name is None, f"_current_hook_fx_name not reset after test: {model._current_hook_fx_name}" - assert model._current_dataloader_idx is None, f"_current_dataloader_idx not reset after test: {model._current_dataloader_idx}" + assert ( + model._current_hook_fx_name is None + ), f"_current_hook_fx_name not reset after test: {model._current_hook_fx_name}" + assert ( + model._current_dataloader_idx is None + ), f"_current_dataloader_idx not reset after test: {model._current_dataloader_idx}" From 1a2080cb426e494f5e66be8a12a8e401b736ab16 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 28 Apr 2021 13:45:05 +0200 Subject: [PATCH 6/7] flake8 --- pytorch_lightning/trainer/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9bcbbc0e7e9a5..cd39e19eb7e53 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1154,7 +1154,6 @@ def call_teardown_hook(self, model: LightningModule) -> None: model._current_hook_fx_name = None model._current_dataloader_idx = None - def _reset_result_and_set_hook_fx_name(self, hook_name: str) -> bool: # on_before_zero_grad is called within training_step if "batch_start" in hook_name or hook_name in ("on_before_zero_grad", "on_after_backward"): From afb50bd69761959b4f6323cb402d1aa45a5b30d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 28 Apr 2021 20:27:23 +0200 Subject: [PATCH 7/7] Update pytorch_lightning/trainer/trainer.py --- pytorch_lightning/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cd39e19eb7e53..3731a6d0bd8cb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1150,6 +1150,7 @@ def call_teardown_hook(self, model: LightningModule) -> None: self.profiler.teardown(stage=state) self.teardown(stage=state) model.teardown(stage=state) + model._current_fx_name = "" model._current_hook_fx_name = None model._current_dataloader_idx = None