Skip to content

Commit 742c48e

Browse files
author
Sean Naren
authored
[Fix] Ensure we set the eval/train flag correctly on accelerator model (#6877)
* Ensure we move the model to eval mode before running evaluation
* Ensure we set the flag appropriately across all stages
* Add test, move hooks logic
* Apply same fix to the validate loop
* Update pytorch_lightning/trainer/trainer.py
* Fix function name
* Fix order, add predict
* Shorten the name
* Fix input dm, drop duplicate on predict start hook call, as it's called in the setup function
* Use hook, remove double call
1 parent 851fd7f commit 742c48e

File tree

4 files changed

+42
-10
lines changed

4 files changed

+42
-10
lines changed

pytorch_lightning/core/hooks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,13 @@ def on_validation_model_eval(self) -> None:
114114
"""
115115
Sets the model to eval during the val loop
116116
"""
117-
self.eval()
117+
self.trainer.model.eval()
118118

119119
def on_validation_model_train(self) -> None:
120120
"""
121121
Sets the model to train during the val loop
122122
"""
123-
self.train()
123+
self.trainer.model.train()
124124

125125
def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
126126
"""
@@ -172,19 +172,19 @@ def on_test_model_train(self) -> None:
172172
"""
173173
Sets the model to train during the test loop
174174
"""
175-
self.train()
175+
self.trainer.model.train()
176176

177177
def on_test_model_eval(self) -> None:
178178
"""
179179
Sets the model to eval during the test loop
180180
"""
181-
self.eval()
181+
self.trainer.model.eval()
182182

183183
def on_predict_model_eval(self) -> None:
184184
"""
185185
Sets the model to eval during the predict loop
186186
"""
187-
self.eval()
187+
self.trainer.model.eval()
188188

189189
def on_epoch_start(self) -> None:
190190
"""

pytorch_lightning/trainer/predict_loop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def on_predict_model_eval(self, *_, **__):
4444
model_ref.on_predict_model_eval()
4545

4646
def setup(self, model, max_batches, dataloaders):
47-
self.trainer.call_hook("on_predict_start")
4847

4948
# copy properties for forward overrides
5049
self.trainer.model_connector.copy_trainer_model_properties(model)

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,11 +582,11 @@ def run_train(self) -> None:
582582
self.checkpoint_connector.has_trained = False
583583

584584
# enable train mode
585-
model = self.lightning_module
586-
model.train()
585+
self.model.train()
587586
torch.set_grad_enabled(True)
588587

589588
# reload data when needed
589+
model = self.lightning_module
590590
self.train_loop.reset_train_val_dataloaders(model)
591591

592592
# hook
@@ -772,8 +772,6 @@ def run_evaluate(self):
772772
return eval_loop_results
773773

774774
def run_predict(self):
775-
self.predict_loop.on_predict_start()
776-
777775
# prepare dataloaders
778776
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()
779777

@@ -789,6 +787,9 @@ def run_predict(self):
789787
model.zero_grad()
790788
torch.set_grad_enabled(False)
791789

790+
# call hook
791+
self.predict_loop.on_predict_start()
792+
792793
# set up the eval loop
793794
self.predict_loop.setup(model, max_batches, dataloaders)
794795

tests/trainer/test_trainer.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,7 +1438,9 @@ def setup(self, model, stage):
14381438
)
14391439
@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics")
14401440
def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval):
1441+
14411442
class TestModel(BoringModel):
1443+
14421444
def training_step(self, *args, **kwargs):
14431445
self.log("foo", -1)
14441446
return super().training_step(*args, **kwargs)
@@ -1888,3 +1890,33 @@ def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
18881890
trainer.validate()
18891891
with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
18901892
trainer.test()
1893+
1894+
1895+
class TrainerStagesModel(BoringModel):
1896+
1897+
def on_train_start(self) -> None:
1898+
assert self.trainer.model.training
1899+
assert self.training
1900+
1901+
def on_validation_start(self) -> None:
1902+
assert not self.trainer.model.training
1903+
assert not self.training
1904+
1905+
def on_test_start(self) -> None:
1906+
assert not self.trainer.model.training
1907+
assert not self.training
1908+
1909+
def on_predict_start(self) -> None:
1910+
assert not self.trainer.model.training
1911+
assert not self.training
1912+
1913+
1914+
@pytest.mark.parametrize(['accelerator', 'num_processes'],
1915+
[(None, 1), pytest.param('ddp', 2, marks=RunIf(skip_windows=True))])
1916+
def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes):
1917+
model = TrainerStagesModel()
1918+
trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True)
1919+
trainer.fit(model)
1920+
trainer.validate(model)
1921+
trainer.test(model)
1922+
trainer.predict(model, model.val_dataloader())

0 commit comments

Comments
 (0)