Skip to content

Commit 075de93

Browse files
ananthsub and carmocca authored
Reset current_fx properties on lightning module in teardown (#7247)
* Update trainer.py * cleanup module properties in teardown * Update test_trainer.py * Update lightning.py * Formatting * flake8 * Update pytorch_lightning/trainer/trainer.py Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 40f8023 commit 075de93

File tree

3 files changed

+40
-6
lines changed

3 files changed

+40
-6
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,19 +92,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
9292
self._device_type = None
9393

9494
#: True if using amp
95-
self.use_amp = False
95+
self.use_amp: bool = False
9696

9797
#: The precision used
98-
self.precision = 32
98+
self.precision: int = 32
9999

100100
# optionally can be set by user
101101
self._example_input_array = None
102102
self._datamodule = None
103103
self._results: Optional[Result] = None
104-
self._current_fx_name = ''
105-
self._running_manual_backward = False
106-
self._current_hook_fx_name = None
107-
self._current_dataloader_idx = None
104+
self._current_fx_name: str = ''
105+
self._running_manual_backward: bool = False
106+
self._current_hook_fx_name: Optional[str] = None
107+
self._current_dataloader_idx: Optional[int] = None
108108
self._automatic_optimization: bool = True
109109
self._param_requires_grad_state = dict()
110110

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,6 +1147,10 @@ def call_teardown_hook(self, model: LightningModule) -> None:
11471147
self.teardown(stage=state)
11481148
model.teardown(stage=state)
11491149

1150+
model._current_fx_name = ""
1151+
model._current_hook_fx_name = None
1152+
model._current_dataloader_idx = None
1153+
11501154
def _reset_result_and_set_hook_fx_name(self, hook_name: str) -> bool:
11511155
# on_before_zero_grad is called within training_step
11521156
if "batch_start" in hook_name or hook_name in ("on_before_zero_grad", "on_after_backward"):

tests/trainer/test_trainer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,3 +2041,33 @@ def test_fit_test_synchronization(tmpdir):
20412041
trainer.fit(model)
20422042
assert os.path.exists(checkpoint.best_model_path), f'Could not find checkpoint at rank {trainer.global_rank}'
20432043
trainer.test()
2044+
2045+
2046+
def test_module_current_fx_attributes_reset(tmpdir):
2047+
""" Ensure that lightning module's attributes related to current hook fx are reset at the end of execution. """
2048+
model = BoringModel()
2049+
model.validation_step = None
2050+
model.training_epoch_end = None
2051+
trainer = Trainer(
2052+
default_root_dir=tmpdir,
2053+
max_epochs=1,
2054+
checkpoint_callback=False,
2055+
logger=False,
2056+
limit_val_batches=0,
2057+
)
2058+
trainer.fit(model)
2059+
assert model._current_fx_name == "", f"_current_fx_name not reset after fit: {model._current_fx_name}"
2060+
assert (
2061+
model._current_hook_fx_name is None
2062+
), f"_current_hook_fx_name not reset after fit: {model._current_hook_fx_name}"
2063+
assert (
2064+
model._current_dataloader_idx is None
2065+
), f"_current_dataloader_idx not reset after fit: {model._current_dataloader_idx}"
2066+
trainer.test(model)
2067+
assert model._current_fx_name == "", f"_current_fx_name not reset after test: {model._current_fx_name}"
2068+
assert (
2069+
model._current_hook_fx_name is None
2070+
), f"_current_hook_fx_name not reset after test: {model._current_hook_fx_name}"
2071+
assert (
2072+
model._current_dataloader_idx is None
2073+
), f"_current_dataloader_idx not reset after test: {model._current_dataloader_idx}"

0 commit comments

Comments (0)