Skip to content

Commit e3503a1

Browse files
carmoccarohitgr7
authored andcommitted
Avoid the deprecated onnx.export(example_outputs=...) in torch 1.10 (#11116)
1 parent 5f4d639 commit e3503a1

File tree

3 files changed

+3
-1
lines changed

3 files changed

+3
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010

1111
- Fixed `NeptuneLogger` when using DDP ([#11030](https://github.com/PyTorchLightning/pytorch-lightning/pull/11030))
1212
- Fixed a bug to disable logging hyperparameters in logger if there are no hparams ([#11105](https://github.com/PyTorchLightning/pytorch-lightning/issues/11105))
13+
- Avoid the deprecated `onnx.export(example_outputs=...)` in torch 1.10 ([#11116](https://github.com/PyTorchLightning/pytorch-lightning/pull/11116))
1314
- Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078))
1415
- Fixed an `AttributeError` occuring when using a `CombinedLoader` (multiple dataloaders) for prediction ([#11111](https://github.com/PyTorchLightning/pytorch-lightning/pull/11111))
1516

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1889,7 +1889,7 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non
18891889

18901890
input_sample = self._apply_batch_transfer_handler(input_sample)
18911891

1892-
if "example_outputs" not in kwargs:
1892+
if not _TORCH_GREATER_EQUAL_1_10 and "example_outputs" not in kwargs:
18931893
self.eval()
18941894
if isinstance(input_sample, Tuple):
18951895
kwargs["example_outputs"] = self(*input_sample)

tests/models/test_onnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def test_model_saves_on_gpu(tmpdir):
5353
assert os.path.getsize(file_path) > 4e2
5454

5555

56+
@RunIf(max_torch="1.10")
5657
def test_model_saves_with_example_output(tmpdir):
5758
"""Test that ONNX model saves when provided with example output."""
5859
model = BoringModel()

0 commit comments

Comments
 (0)