diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c6262225b2f4..a1c0551703af2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -278,6 +278,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +- Avoid the deprecated `onnx.export(example_outputs=...)` in torch 1.10 ([#11116](https://github.com/PyTorchLightning/pytorch-lightning/pull/11116)) + + + - Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 38010f7acf0a1..fd285f7139953 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1807,7 +1807,7 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non input_sample = self._apply_batch_transfer_handler(input_sample) - if "example_outputs" not in kwargs: + if not _TORCH_GREATER_EQUAL_1_10 and "example_outputs" not in kwargs: self.eval() if isinstance(input_sample, Tuple): kwargs["example_outputs"] = self(*input_sample) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index 7ab425dd12ea6..d111b266fb115 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -53,6 +53,7 @@ def test_model_saves_on_gpu(tmpdir): assert os.path.getsize(file_path) > 4e2 +@RunIf(max_torch="1.10") def test_model_saves_with_example_output(tmpdir): """Test that ONNX model saves when provided with example output.""" model = BoringModel()