We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 012b95b commit 61a3e89Copy full SHA for 61a3e89
tests/plugins/test_double_plugin.py
@@ -125,5 +125,6 @@ def test_double_precision(tmpdir, boring_model):
125
trainer.fit(model)
126
trainer.test(model)
127
trainer.predict(model)
128
+ torch.set_grad_enabled(True) # trainer.predict kills gradient
129
130
assert model.training_step == original_training_step
0 commit comments