From 6c834e7287ec7df5e48e72bc1ac2f19436f57d13 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 17 Dec 2021 01:26:40 +0100 Subject: [PATCH 1/2] Avoid the deprecated `onnx.export(example_outputs=...)` in torch 1.10 --- pytorch_lightning/core/lightning.py | 2 +- tests/models/test_onnx.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 38010f7acf0a1..fd285f7139953 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1807,7 +1807,7 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non input_sample = self._apply_batch_transfer_handler(input_sample) - if "example_outputs" not in kwargs: + if not _TORCH_GREATER_EQUAL_1_10 and "example_outputs" not in kwargs: self.eval() if isinstance(input_sample, Tuple): kwargs["example_outputs"] = self(*input_sample) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index 7ab425dd12ea6..d111b266fb115 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -53,6 +53,7 @@ def test_model_saves_on_gpu(tmpdir): assert os.path.getsize(file_path) > 4e2 +@RunIf(max_torch="1.10") def test_model_saves_with_example_output(tmpdir): """Test that ONNX model saves when provided with example output.""" model = BoringModel() From af083d90eb964dc0c2969c5779e091a165e8e6da Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 17 Dec 2021 01:29:40 +0100 Subject: [PATCH 2/2] CHANGEOG --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c6262225b2f4..a1c0551703af2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -278,6 +278,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +- Avoid the deprecated `onnx.export(example_outputs=...)` in torch 1.10 ([#11116](https://github.com/PyTorchLightning/pytorch-lightning/pull/11116)) + + + - Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))