Commit f6d5782

ethanwharris authored and lexierule committed
Fix disabled grads after call to predict (#6657)
1 parent 5bb1838 commit f6d5782
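
For context: the predict loop turns autograd off for inference but, before this commit, never turned it back on, so gradient tracking stayed disabled process-wide after `Trainer.predict` returned. A minimal sketch of the symptom using only core PyTorch (the Lightning calls are elided; this illustrates the effect, it is not code from the commit):

```python
import torch

torch.set_grad_enabled(False)  # what the predict loop does on entry
# ... prediction runs ...
# (missing before this fix) torch.set_grad_enabled(True)

# Any later autograd work is silently broken:
x = torch.zeros(1, requires_grad=True)
y = x * 2
print(y.grad_fn)  # None -- no graph is recorded, so y.backward() would fail
```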

File tree

3 files changed: +22 −4 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
 - Fixed error on TPUs when there was no `ModelCheckpoint` ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
 - Fixed `trainer.test` freeze on TPUs ([#6654](https://github.com/PyTorchLightning/pytorch-lightning/pull/6654))
+- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657))
 
 
 ## [1.2.5] - 2021-03-23
```

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -834,6 +834,10 @@ def run_predict(self):
                 self.predict_loop.predict(batch, batch_idx, dataloader_idx)
 
         results = self.predict_loop.on_predict_epoch_end()
+
+        # re-enable grads
+        torch.set_grad_enabled(True)
+
         return results
 
     def run_sanity_check(self, ref_model):
```
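
The fix itself is a single call: `torch.set_grad_enabled(True)` restores the process-wide autograd mode once the predict epoch ends. A standalone sketch of that core PyTorch behavior (nothing Lightning-specific here):

```python
import torch

a = torch.ones(2, requires_grad=True)

torch.set_grad_enabled(False)       # the predict loop disables grads on entry
assert (a * 3).grad_fn is None      # no graph is recorded while disabled

torch.set_grad_enabled(True)        # what run_predict now restores on exit
assert (a * 3).grad_fn is not None  # graph recording works again
```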

tests/trainer/test_trainer.py

Lines changed: 17 additions & 4 deletions
```diff
@@ -1410,12 +1410,12 @@ def predict_dataloader(self):
         return self._dataloaders
 
 
-def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
+def predict(tmpdir, accelerator, gpus, num_processes, model=None, plugins=None, datamodule=True):
 
     dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))]
 
-    model = BoringModel()
-    datamodule = TestLightningDataModule(dataloaders)
+    model = model or BoringModel()
+    dm = TestLightningDataModule(dataloaders)
 
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -1428,7 +1428,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
         plugins=plugins,
     )
     if datamodule:
-        results = trainer.predict(model, datamodule=datamodule)
+        results = trainer.predict(model, datamodule=dm)
     else:
         results = trainer.predict(model, dataloaders=dataloaders)
 
@@ -1439,6 +1439,19 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
     assert results[0][0].shape == torch.Size([1, 2])
 
 
+def test_trainer_predict_grad(tmpdir):
+    class CustomBoringModel(BoringModel):
+
+        def predict_step(self, batch, batch_idx, dataloader_idx=None):
+            assert batch.expand_as(batch).grad_fn is None
+            return super().predict_step(batch, batch_idx, dataloader_idx)
+
+    predict(tmpdir, None, None, 1, model=CustomBoringModel())
+
+    x = torch.zeros(1, requires_grad=True)
+    assert x.expand_as(x).grad_fn is not None
+
+
 @pytest.mark.parametrize('datamodule', [False, True])
 def test_trainer_predict_cpu(tmpdir, datamodule):
     predict(tmpdir, None, None, 1, datamodule=datamodule)
```
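
The test uses `expand_as(...).grad_fn` as a probe for the global grad mode: on a tensor that requires grad, the result carries a `grad_fn` only while autograd is enabled, so it must be `None` inside `predict_step` and populated again after `predict` returns. The same probe in isolation, outside any Lightning machinery:

```python
import torch

x = torch.zeros(1, requires_grad=True)

with torch.no_grad():
    assert x.expand_as(x).grad_fn is None  # no graph inside no-grad mode

assert x.expand_as(x).grad_fn is not None  # graph is recorded again outside it
```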
