Skip to content

Commit b79dc3c

Browse files
author
Sean Naren
committed
[Fix] Ensure we set the eval/train flag correctly on accelerator model (#6877)
* Ensure we move the model to eval mode before running evaluation * Ensure we set the flag appropriately across all stages * Add test, move hooks logic * Apply same fix to the validate loop * Update pytorch_lightning/trainer/trainer.py * Fix function name * Fix order, add predict * Shorten the name * Fix input dm, drop duplicate on predict start hook call, as it's called in the setup function * Use hook, remove double call (cherry picked from commit 742c48e)
1 parent 6f7cf59 commit b79dc3c

File tree

4 files changed

+48
-8
lines changed

4 files changed

+48
-8
lines changed

pytorch_lightning/core/hooks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ def on_validation_model_eval(self) -> None:
150150
"""
151151
Sets the model to eval during the val loop
152152
"""
153-
self.eval()
153+
self.trainer.model.eval()
154154

155155
def on_validation_model_train(self) -> None:
156156
"""
157157
Sets the model to train during the val loop
158158
"""
159-
self.train()
159+
self.trainer.model.train()
160160

161161
def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
162162
"""
@@ -208,19 +208,19 @@ def on_test_model_train(self) -> None:
208208
"""
209209
Sets the model to train during the test loop
210210
"""
211-
self.train()
211+
self.trainer.model.train()
212212

213213
def on_test_model_eval(self) -> None:
214214
"""
215215
Sets the model to eval during the test loop
216216
"""
217-
self.eval()
217+
self.trainer.model.eval()
218218

219219
def on_predict_model_eval(self) -> None:
220220
"""
221221
Sets the model to eval during the predict loop
222222
"""
223-
self.eval()
223+
self.trainer.model.eval()
224224

225225
def on_epoch_start(self) -> None:
226226
"""

pytorch_lightning/trainer/predict_loop.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def on_predict_model_eval(self, *_, **__):
4444
model_ref.on_predict_model_eval()
4545

4646
def setup(self, model, max_batches, dataloaders):
47+
4748
# copy properties for forward overrides
4849
self.trainer.model_connector.copy_trainer_model_properties(model)
4950

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,11 +612,11 @@ def run_train(self):
612612
self.checkpoint_connector.has_trained = False
613613

614614
# enable train mode
615-
model = self.lightning_module
616-
model.train()
615+
self.model.train()
617616
torch.set_grad_enabled(True)
618617

619618
# reload data when needed
619+
model = self.lightning_module
620620
self.train_loop.reset_train_val_dataloaders(model)
621621

622622
# hook
@@ -814,6 +814,9 @@ def run_predict(self):
814814
model.zero_grad()
815815
torch.set_grad_enabled(False)
816816

817+
# call hook
818+
self.predict_loop.on_predict_start()
819+
817820
# set up the eval loop
818821
self.predict_loop.setup(model, max_batches, dataloaders)
819822

tests/trainer/test_trainer.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1381,7 +1381,14 @@ def setup(self, model, stage):
13811381
)
13821382
@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_metrics")
13831383
def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, log_interval):
1384-
model = EvalModelTemplate()
1384+
1385+
class TestModel(BoringModel):
1386+
1387+
def training_step(self, *args, **kwargs):
1388+
self.log("foo", -1)
1389+
return super().training_step(*args, **kwargs)
1390+
1391+
model = TestModel()
13851392
trainer = Trainer(
13861393
default_root_dir=tmpdir,
13871394
log_every_n_steps=log_interval,
@@ -1932,3 +1939,32 @@ def forward(self, x):
19321939

19331940
with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
19341941
trainer.predict(model)
1942+
1943+
1944+
class TrainerStagesModel(BoringModel):
1945+
1946+
def on_train_start(self) -> None:
1947+
assert self.trainer.model.training
1948+
assert self.training
1949+
1950+
def on_validation_start(self) -> None:
1951+
assert not self.trainer.model.training
1952+
assert not self.training
1953+
1954+
def on_test_start(self) -> None:
1955+
assert not self.trainer.model.training
1956+
assert not self.training
1957+
1958+
def on_predict_start(self) -> None:
1959+
assert not self.trainer.model.training
1960+
assert not self.training
1961+
1962+
1963+
@pytest.mark.parametrize(['accelerator', 'num_processes'],
1964+
[(None, 1), pytest.param('ddp', 2, marks=RunIf(skip_windows=True))])
1965+
def test_model_in_correct_mode_during_stages(tmpdir, accelerator, num_processes):
1966+
model = TrainerStagesModel()
1967+
trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True)
1968+
trainer.fit(model)
1969+
trainer.test(model)
1970+
trainer.predict(model, model.val_dataloader())

0 commit comments

Comments
 (0)