diff --git a/CHANGELOG.md b/CHANGELOG.md index e1106189e0c17..71ab8b457b0b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434)) +- Fixed a bug where gradients were disabled after calling `Trainer.predict` ([#6657](https://github.com/PyTorchLightning/pytorch-lightning/pull/6657)) + + ## [1.2.4] - 2021-03-16 ### Changed diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bb5d6919964e5..dbc493aa76e04 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -800,6 +800,10 @@ def run_predict(self): results = self.predict_loop.on_predict_epoch_end() self.predict_loop.on_predict_end() + + # re-enable grads + torch.set_grad_enabled(True) + return results def run_sanity_check(self, ref_model): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index d461d9d152e74..490f205a7bbec 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1450,6 +1450,19 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): predict(tmpdir, None, None, 1, model=CustomBoringModel()) +def test_trainer_predict_grad(tmpdir): + class CustomBoringModel(BoringModel): + + def predict_step(self, batch, batch_idx, dataloader_idx=None): + assert batch.expand_as(batch).grad_fn is None + return super().predict_step(batch, batch_idx, dataloader_idx) + + predict(tmpdir, None, None, 1, model=CustomBoringModel()) + + x = torch.zeros(1, requires_grad=True) + assert x.expand_as(x).grad_fn is not None + + @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule)