From d8d060812b6d8c6cc374daabd568d0fb1f27d2d6 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 26 Oct 2020 23:06:59 +0530 Subject: [PATCH 01/18] branch merge --- pytorch_lightning/core/lightning.py | 65 ++++++++++++++-------------- tests/callbacks/test_progress_bar.py | 2 +- tests/models/test_onnx.py | 11 ----- 3 files changed, 33 insertions(+), 45 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 22d63d0a03a74..3b236209aa30f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -20,7 +20,8 @@ import tempfile from abc import ABC from argparse import Namespace -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch from pytorch_lightning import _logger as log @@ -1490,12 +1491,13 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: else: self._hparams = hp - def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs): + @torch.no_grad() + def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs): """Saves the model in ONNX format Args: - file_path: The path of the file the model should be saved to. - input_sample: A sample of an input tensor for tracing. + file_path: The path of the file the onnx model should be saved to. + input_sample: An input for tracing. **kwargs: Will be passed to torch.onnx.export function. Example: @@ -1514,29 +1516,25 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg ... os.path.isfile(tmpfile.name) True """ - - if isinstance(input_sample, Tensor): + if input_sample is not None: input_data = input_sample elif self.example_input_array is not None: input_data = self.example_input_array else: - if input_sample is not None: - raise ValueError( - f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`" - ) - else: - raise ValueError( - "Could not export to ONNX since neither `input_sample` nor" - " `model.example_input_array` attribute is set." - ) - input_data = input_data.to(self.device) + raise ValueError( + "Could not export to ONNX since neither `input_sample` nor" + " `model.example_input_array` attribute is set." + ) + + input_data = self.transfer_batch_to_device(input_data, self.device) + if "example_outputs" not in kwargs: self.eval() - with torch.no_grad(): - kwargs["example_outputs"] = self(input_data) + kwargs["example_outputs"] = self(input_data) torch.onnx.export(self, input_data, file_path, **kwargs) + @torch.no_grad() def to_torchscript( self, file_path: Optional[str] = None, method: Optional[str] = 'script', example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs @@ -1551,7 +1549,7 @@ def to_torchscript( Args: file_path: Path where to save the torchscript. Default: None (no file saved). method: Whether to use TorchScript's script or trace method. Default: 'script' - example_inputs: Tensor to be used to do tracing when method is set to 'trace'. + example_inputs: Input sample to be used to do tracing when method is set to 'trace'. Default: None (Use self.example_input_array) **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or :func:`torch.jit.trace` function. @@ -1585,21 +1583,22 @@ def to_torchscript( This LightningModule as a torchscript, regardless of whether file_path is defined or not. 
""" - mode = self.training - with torch.no_grad(): - if method == 'script': - torchscript_module = torch.jit.script(self.eval(), **kwargs) - elif method == 'trace': - # if no example inputs are provided, try to see if model has example_input_array set - if example_inputs is None: - example_inputs = self.example_input_array - # automatically send example inputs to the right device and use trace - example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device) - torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) - else: - raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:" - f"{method}") + + if method == 'script': + torchscript_module = torch.jit.script(self.eval(), **kwargs) + elif method == 'trace': + # if no example inputs are provided, try to see if model has example_input_array set + if example_inputs is None: + example_inputs = self.example_input_array + + # automatically send example inputs to the right device and use trace + example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device) + torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) + else: + raise ValueError("The 'method' parameter only supports 'script' or 'trace'," + f" but value given was: {method}") + self.train(mode) if file_path is not None: diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index d354b59682240..221844244ad75 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -231,7 +231,7 @@ def on_validation_epoch_end(self, trainer, pl_module): default_root_dir=tmpdir, max_epochs=1, num_sanity_val_steps=2, - limit_train_batches=0, + limit_train_batches=1, limit_val_batches=limit_val_batches, callbacks=[progress_bar], logger=False, diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index 5d3cf7d6bdffc..3cedbdf8b4801 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -117,17 +117,6 @@ def test_error_if_no_input(tmpdir): model.to_onnx(file_path) -def test_error_if_input_sample_is_not_tensor(tmpdir): - """Test that an exception is thrown when there is no input tensor""" - model = EvalModelTemplate() - model.example_input_array = None - file_path = os.path.join(tmpdir, "model.onnx") - input_sample = np.random.randn(1, 28 * 28) - with pytest.raises(ValueError, match=f'Received `input_sample` of type {type(input_sample)}. Expected type is ' - f'`Tensor`'): - model.to_onnx(file_path, input_sample) - - def test_if_inference_output_is_valid(tmpdir): """Test that the output inferred from ONNX model is same as from PyTorch""" model = EvalModelTemplate() From 0037f284cc554f4c30354637650d19d5145c3d3a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 26 Oct 2020 23:11:12 +0530 Subject: [PATCH 02/18] sample --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3b236209aa30f..9baa99300b397 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1497,7 +1497,7 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non Args: file_path: The path of the file the onnx model should be saved to. - input_sample: An input for tracing. + input_sample: An input sample for tracing. **kwargs: Will be passed to torch.onnx.export function. 
Example: From 2da23cde83a5523cc7888197fd483966c3d68b1b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 31 Oct 2020 22:10:44 +0530 Subject: [PATCH 03/18] update with valid input tensors --- pytorch_lightning/core/lightning.py | 45 ++++++++++----- tests/models/test_onnx.py | 42 +++++++++++++- tests/models/test_torchscript.py | 89 ++++++++++++++++++++++++++--- 3 files changed, 154 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9baa99300b397..66f8e0374d056 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1492,12 +1492,12 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: self._hparams = hp @torch.no_grad() - def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs): + def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Union[Tensor, Tuple[Tensor]]] = None, **kwargs): """Saves the model in ONNX format Args: file_path: The path of the file the onnx model should be saved to. - input_sample: An input sample for tracing. + input_sample: An input tensor or tuple of tensors for tracing. Default: None (Use self.example_input_array) **kwargs: Will be passed to torch.onnx.export function. Example: @@ -1516,28 +1516,33 @@ def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = Non ... os.path.isfile(tmpfile.name) True """ - if input_sample is not None: - input_data = input_sample - elif self.example_input_array is not None: - input_data = self.example_input_array - else: + if input_sample is None: + if self.example_input_array is None: + raise ValueError( + "Could not export to ONNX since neither `input_sample` nor" + " `model.example_input_array` attribute is set." + ) + input_sample = self.example_input_array + + if isinstance(input_sample, Tensor): + input_sample = (input_sample,) + elif not (isinstance(input_sample, tuple) and all(isinstance(inp, Tensor) for inp in input_sample)): raise ValueError( - "Could not export to ONNX since neither `input_sample` nor" - " `model.example_input_array` attribute is set." + "Could not export to ONNX since input_sample is neither a Tensor nor tuple of Tensors" ) - input_data = self.transfer_batch_to_device(input_data, self.device) + input_sample = self.transfer_batch_to_device(input_sample, self.device) if "example_outputs" not in kwargs: self.eval() - kwargs["example_outputs"] = self(input_data) + kwargs["example_outputs"] = self(*input_sample) - torch.onnx.export(self, input_data, file_path, **kwargs) + torch.onnx.export(self, input_sample, file_path, **kwargs) @torch.no_grad() def to_torchscript( self, file_path: Optional[str] = None, method: Optional[str] = 'script', - example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs + example_inputs: Optional[Union[Tensor, Tuple[Tensor]]] = None, **kwargs ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. @@ -1549,7 +1554,7 @@ def to_torchscript( Args: file_path: Path where to save the torchscript. Default: None (no file saved). method: Whether to use TorchScript's script or trace method. Default: 'script' - example_inputs: Input sample to be used to do tracing when method is set to 'trace'. + example_inputs: An input tensor or tuple of tensors to be used to do tracing when method is set to 'trace'. 
Default: None (Use self.example_input_array) **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or :func:`torch.jit.trace` function. @@ -1590,8 +1595,20 @@ def to_torchscript( elif method == 'trace': # if no example inputs are provided, try to see if model has example_input_array set if example_inputs is None: + if self.example_input_array is None: + raise ValueError( + 'Choosing method=`trace` requires either `example_inputs`' + ' or `model.example_input_array` to be defined' + ) example_inputs = self.example_input_array + if isinstance(example_inputs, Tensor): + example_inputs = (example_inputs,) + elif not (isinstance(example_inputs, tuple) and all(isinstance(inp, Tensor) for inp in example_inputs)): + raise ValueError( + "Could not export to torchscript since example_inputs is neither a Tensor nor tuple of Tensors" + ) + # automatically send example inputs to the right device and use trace example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device) torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index 3cedbdf8b4801..aee74a9846109 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -108,7 +108,7 @@ def test_verbose_param(tmpdir, capsys): def test_error_if_no_input(tmpdir): - """Test that an exception is thrown when there is no input tensor""" + """Test that an error is thrown when there is no input tensor""" model = EvalModelTemplate() model.example_input_array = None file_path = os.path.join(tmpdir, "model.onnx") @@ -141,3 +141,43 @@ def to_numpy(tensor): # compare ONNX Runtime and PyTorch results assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) + + +def test_model_saves_with_tuple_input(tmpdir): + """Test that ONNX model saves when input is tuple of tensors""" + class CustomModel(EvalModelTemplate): + def forward(self, x, y=None): + return super().forward(x) + + model = CustomModel() + trainer = Trainer(max_epochs=1) + trainer.fit(model) + + file_path = os.path.join(tmpdir, "model.onnx") + input_sample = (torch.randn(1, 28 * 28), torch.randn(1, 28 * 28)) + model.to_onnx(file_path, input_sample) + assert os.path.exists(file_path) is True + + input_sample = (torch.randn(1, 28 * 28), np.random.randn(1, 28 * 28)) + with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): + model.to_onnx(file_path, input_sample) + + +def test_error_with_invalid_input(tmpdir): + """Test that an error is thrown with invalid input""" + class CustomModel(EvalModelTemplate): + def forward(self, x): + if isinstance(x, dict): + x = x['x'] + + return super().forward(x) + + model = CustomModel() + trainer = Trainer(max_epochs=1) + trainer.fit(model) + + file_path = os.path.join(tmpdir, "model.onnx") + input_sample = {'x': torch.randn((1, 28 * 28))} + + with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): + model.to_onnx(file_path, input_sample) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index bf2c34b8bfef5..d8eb9bfd3529c 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from distutils.version import LooseVersion +import numpy as np import pytest import torch @@ -31,9 +32,11 @@ def test_torchscript_input_output(modelclass): model = modelclass() script = model.to_torchscript() assert isinstance(script, torch.jit.ScriptModule) + model.eval() - model_output = model(model.example_input_array) - script_output = script(model.example_input_array) + with torch.no_grad(): + model_output = model(model.example_input_array) + script_output = script(model.example_input_array) assert torch.allclose(script_output, model_output) @@ -42,14 +45,30 @@ def test_torchscript_input_output(modelclass): ParityModuleRNN, BasicGAN, ]) -def test_torchscript_input_output_trace(modelclass): - """ Test that traced LightningModule forward works. """ +def test_torchscript_example_input_output_trace(modelclass): + """ Test that traced LightningModule forward works with example_input_array """ model = modelclass() script = model.to_torchscript(method='trace') assert isinstance(script, torch.jit.ScriptModule) + + model.eval() + with torch.no_grad(): + model_output = model(model.example_input_array) + script_output = script(model.example_input_array) + assert torch.allclose(script_output, model_output) + + +def test_torchscript_input_output_trace(): + """ Test that traced LightningModule forward works with example_inputs """ + model = EvalModelTemplate() + example_inputs = torch.randn(1, 28 * 28) + script = model.to_torchscript(example_inputs=example_inputs, method='trace') + assert isinstance(script, torch.jit.ScriptModule) + model.eval() - model_output = model(model.example_input_array) - script_output = script(model.example_input_array) + with torch.no_grad(): + model_output = model(example_inputs) + script_output = script(example_inputs) assert torch.allclose(script_output, model_output) @@ -109,9 +128,65 @@ def test_torchscript_properties(modelclass): reason="torch.save/load has bug loading script modules on torch <= 1.4", ) def test_torchscript_save_load(tmpdir, modelclass): - """ Test that scripted LightningModules is correctly saved and can be loaded. """ + """ Test that scripted LightningModule is correctly saved and can be loaded. 
""" model = modelclass() output_file = str(tmpdir / "model.pt") script = model.to_torchscript(file_path=output_file) loaded_script = torch.jit.load(output_file) assert torch.allclose(next(script.parameters()), next(loaded_script.parameters())) + + +def test_torchcript_invalid_method(tmpdir): + """Test that an error is thrown with invalid torchscript method""" + model = EvalModelTemplate() + model.train(True) + + with pytest.raises(ValueError, match="only supports 'script' or 'trace'"): + model.to_torchscript(method='temp') + + +def test_torchscript_with_no_input(tmpdir): + """Test that an error is thrown when there is no input tensor""" + model = EvalModelTemplate() + model.example_input_array = None + + with pytest.raises(ValueError, match='requires either `example_inputs` or `model.example_input_array`'): + model.to_torchscript(method='trace') + + +def test_torchscript_with_tuple_input(tmpdir): + """Test that traced LightningModule is created when input is tuple of tensors""" + class CustomModel(EvalModelTemplate): + def forward(self, x, y=None): + return super().forward(x) + + example_inputs = (torch.randn(1, 28 * 28), torch.randn(1, 28 * 28)) + model = CustomModel() + script = model.to_torchscript(example_inputs=example_inputs, method='trace') + assert isinstance(script, torch.jit.ScriptModule) + + model.eval() + with torch.no_grad(): + model_output = model(*example_inputs) + script_output = script(*example_inputs) + assert torch.allclose(script_output, model_output) + + example_inputs = (torch.randn(1, 28 * 28), np.random.randn(1, 28 * 28)) + with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): + model.to_torchscript(example_inputs=example_inputs, method='trace') + + +def test_error_with_invalid_input(tmpdir): + """Test that an error is thrown with invalid input""" + class CustomModel(EvalModelTemplate): + def forward(self, x): + if isinstance(x, dict): + x = x['x'] + + return super().forward(x) + + model = CustomModel() + example_inputs = {'x': torch.randn(1, 28 * 28)} + + with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): + model.to_torchscript(example_inputs=example_inputs, method='trace') From b45792c848332f681efa16e558f66b5bf90177b7 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 31 Oct 2020 22:13:44 +0530 Subject: [PATCH 04/18] pep --- pytorch_lightning/core/lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 66f8e0374d056..492a789026f9b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1492,7 +1492,8 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: self._hparams = hp @torch.no_grad() - def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Union[Tensor, Tuple[Tensor]]] = None, **kwargs): + def to_onnx(self, file_path: Union[str, Path], + input_sample: Optional[Union[Tensor, Tuple[Tensor]]] = None, **kwargs): """Saves the model in ONNX format Args: From e98cc33835d3aa59ef08ac5482d0318cc5b563c5 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 1 Nov 2020 00:14:40 +0530 Subject: [PATCH 05/18] pathlib --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 492a789026f9b..ce9c00f167e1e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1542,7 +1542,7 @@ def to_onnx(self, 
file_path: Union[str, Path], @torch.no_grad() def to_torchscript( - self, file_path: Optional[str] = None, method: Optional[str] = 'script', + self, file_path: Optional[Union[str, Path]] = None, method: Optional[str] = 'script', example_inputs: Optional[Union[Tensor, Tuple[Tensor]]] = None, **kwargs ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ From 3096054897797830d8121f0b7e54e1851a1117ca Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 2 Nov 2020 23:46:45 +0530 Subject: [PATCH 06/18] Updated with BoringModel and added more input types --- pytorch_lightning/core/lightning.py | 4 +- tests/base/boring_model.py | 1 + tests/models/test_onnx.py | 62 +++++++++++++++------------ tests/models/test_torchscript.py | 65 ++++++++++++++++++----------- 4 files changed, 79 insertions(+), 53 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index ce9c00f167e1e..4436e06584227 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1543,7 +1543,7 @@ def to_onnx(self, file_path: Union[str, Path], @torch.no_grad() def to_torchscript( self, file_path: Optional[Union[str, Path]] = None, method: Optional[str] = 'script', - example_inputs: Optional[Union[Tensor, Tuple[Tensor]]] = None, **kwargs + example_inputs: Optional[Union[Tensor, Sequence[Tensor]]] = None, **kwargs ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. @@ -1605,7 +1605,7 @@ def to_torchscript( if isinstance(example_inputs, Tensor): example_inputs = (example_inputs,) - elif not (isinstance(example_inputs, tuple) and all(isinstance(inp, Tensor) for inp in example_inputs)): + elif not (isinstance(example_inputs, collections.Sequence) and all(isinstance(inp, Tensor) for inp in example_inputs)): raise ValueError( "Could not export to torchscript since example_inputs is neither a Tensor nor tuple of Tensors" ) diff --git a/tests/base/boring_model.py b/tests/base/boring_model.py index 6ceffe8562372..c1d2b6b3fad49 100644 --- a/tests/base/boring_model.py +++ b/tests/base/boring_model.py @@ -76,6 +76,7 @@ def training_step(...): """ super().__init__() self.layer = torch.nn.Linear(32, 2) + self.example_input_array = torch.rand(5, 32) def forward(self, x): return self.layer(x) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index aee74a9846109..38654e02fb811 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +from collections import namedtuple import numpy as np import onnxruntime @@ -21,44 +22,44 @@ import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer -from tests.base import EvalModelTemplate +from tests.base import BoringModel def test_model_saves_with_input_sample(tmpdir): """Test that ONNX model saves with input sample and size is greater than 3 MB""" - model = EvalModelTemplate() + model = BoringModel() trainer = Trainer(max_epochs=1) trainer.fit(model) file_path = os.path.join(tmpdir, "model.onnx") - input_sample = torch.randn((1, 28 * 28)) + input_sample = torch.randn((1, 32)) model.to_onnx(file_path, input_sample) assert os.path.isfile(file_path) - assert os.path.getsize(file_path) > 3e+06 + assert os.path.getsize(file_path) > 4e2 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_model_saves_on_gpu(tmpdir): """Test that model saves on gpu""" - model = EvalModelTemplate() + model = BoringModel() trainer = Trainer(gpus=1, max_epochs=1) trainer.fit(model) file_path = os.path.join(tmpdir, "model.onnx") - input_sample = torch.randn((1, 28 * 28)) + input_sample = torch.randn((1, 32)) model.to_onnx(file_path, input_sample) assert os.path.isfile(file_path) - assert os.path.getsize(file_path) > 3e+06 + assert os.path.getsize(file_path) > 4e2 def test_model_saves_with_example_output(tmpdir): """Test that ONNX model saves when provided with example output""" - model = EvalModelTemplate() + model = BoringModel() trainer = Trainer(max_epochs=1) trainer.fit(model) file_path = os.path.join(tmpdir, "model.onnx") - input_sample = torch.randn((1, 28 * 28)) + input_sample = torch.randn((1, 32)) model.eval() example_outputs = model.forward(input_sample) model.to_onnx(file_path, input_sample, example_outputs=example_outputs) @@ -67,11 +68,11 @@ def test_model_saves_with_example_output(tmpdir): def test_model_saves_with_example_input_array(tmpdir): """Test that ONNX model saves with_example_input_array and size is greater than 3 MB""" - model = EvalModelTemplate() + model = BoringModel() file_path = os.path.join(tmpdir, "model.onnx") model.to_onnx(file_path) assert os.path.exists(file_path) is True - assert os.path.getsize(file_path) > 3e+06 + assert os.path.getsize(file_path) > 4e2 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -89,7 +90,7 @@ def test_model_saves_on_multi_gpu(tmpdir): progress_bar_refresh_rate=0 ) - model = EvalModelTemplate() + model = BoringModel() tpipes.run_model_test(trainer_options, model) @@ -100,7 +101,7 @@ def test_model_saves_on_multi_gpu(tmpdir): def test_verbose_param(tmpdir, capsys): """Test that output is present when verbose parameter is set""" - model = EvalModelTemplate() + model = BoringModel() file_path = os.path.join(tmpdir, "model.onnx") model.to_onnx(file_path, verbose=True) captured = capsys.readouterr() @@ -109,7 +110,7 @@ def test_verbose_param(tmpdir, capsys): def test_error_if_no_input(tmpdir): """Test that an error is thrown when there is no input tensor""" - model = EvalModelTemplate() + model = BoringModel() model.example_input_array = None file_path = os.path.join(tmpdir, "model.onnx") with pytest.raises(ValueError, match=r'Could not export to ONNX since neither `input_sample` nor' @@ -119,8 +120,8 @@ def test_error_if_no_input(tmpdir): def test_if_inference_output_is_valid(tmpdir): """Test that the output inferred from ONNX model is same as from PyTorch""" - model = 
EvalModelTemplate() - trainer = Trainer(max_epochs=5) + model = BoringModel() + trainer = Trainer(max_epochs=2) trainer.fit(model) model.eval() @@ -143,29 +144,38 @@ def to_numpy(tensor): assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) -def test_model_saves_with_tuple_input(tmpdir): +def test_model_saves_with_sequence_input(tmpdir): """Test that ONNX model saves when input is tuple of tensors""" - class CustomModel(EvalModelTemplate): + class CustomModel(BoringModel): def forward(self, x, y=None): return super().forward(x) + def test_onnx_export(model, input_sample): + file_path = os.path.join(tmpdir, "model.onnx") + model.to_onnx(file_path, input_sample) + assert os.path.exists(file_path) is True + model = CustomModel() trainer = Trainer(max_epochs=1) trainer.fit(model) - file_path = os.path.join(tmpdir, "model.onnx") - input_sample = (torch.randn(1, 28 * 28), torch.randn(1, 28 * 28)) - model.to_onnx(file_path, input_sample) - assert os.path.exists(file_path) is True + # tuple input + input_sample = (torch.randn(1, 32), torch.randn(1, 32)) + test_onnx_export(model, input_sample) + + # NamedTuple input + input_sample = namedtuple('sample', ['x', 'y']) + input_sample = input_sample(torch.randn(1, 32), torch.randn(1, 32)) + test_onnx_export(model, input_sample) - input_sample = (torch.randn(1, 28 * 28), np.random.randn(1, 28 * 28)) with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): - model.to_onnx(file_path, input_sample) + input_sample = (torch.randn(1, 32), np.random.randn(1, 32)) + test_onnx_export(model, input_sample) def test_error_with_invalid_input(tmpdir): """Test that an error is thrown with invalid input""" - class CustomModel(EvalModelTemplate): + class CustomModel(BoringModel): def forward(self, x): if isinstance(x, dict): x = x['x'] @@ -177,7 +187,7 @@ def forward(self, x): trainer.fit(model) file_path = os.path.join(tmpdir, "model.onnx") - input_sample = {'x': torch.randn((1, 28 * 28))} + input_sample = {'x': torch.randn((2, 32))} with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): model.to_onnx(file_path, input_sample) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index d8eb9bfd3529c..e180a1a9a9699 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -13,17 +13,18 @@ # limitations under the License. 
from distutils.version import LooseVersion import numpy as np +from collections import namedtuple import pytest import torch -from tests.base import EvalModelTemplate +from tests.base import BoringModel from tests.base.datamodules import TrialMNISTDataModule from tests.base.models import ParityModuleRNN, BasicGAN @pytest.mark.parametrize("modelclass", [ - EvalModelTemplate, + BoringModel, ParityModuleRNN, BasicGAN, ]) @@ -41,7 +42,7 @@ def test_torchscript_input_output(modelclass): @pytest.mark.parametrize("modelclass", [ - EvalModelTemplate, + BoringModel, ParityModuleRNN, BasicGAN, ]) @@ -60,8 +61,8 @@ def test_torchscript_example_input_output_trace(modelclass): def test_torchscript_input_output_trace(): """ Test that traced LightningModule forward works with example_inputs """ - model = EvalModelTemplate() - example_inputs = torch.randn(1, 28 * 28) + model = BoringModel() + example_inputs = torch.randn(1, 32) script = model.to_torchscript(example_inputs=example_inputs, method='trace') assert isinstance(script, torch.jit.ScriptModule) @@ -79,7 +80,7 @@ def test_torchscript_input_output_trace(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") def test_torchscript_device(device): """ Test that scripted module is on the correct device. """ - model = EvalModelTemplate().to(device) + model = BoringModel().to(device) script = model.to_torchscript() assert next(script.parameters()).device == device script_output = script(model.example_input_array.to(device)) @@ -88,7 +89,7 @@ def test_torchscript_device(device): def test_torchscript_retain_training_state(): """ Test that torchscript export does not alter the training mode of original model. """ - model = EvalModelTemplate() + model = BoringModel() model.train(True) script = model.to_torchscript() assert model.training @@ -100,7 +101,7 @@ def test_torchscript_retain_training_state(): @pytest.mark.parametrize("modelclass", [ - EvalModelTemplate, + BoringModel, ParityModuleRNN, BasicGAN, ]) @@ -119,7 +120,7 @@ def test_torchscript_properties(modelclass): @pytest.mark.parametrize("modelclass", [ - EvalModelTemplate, + BoringModel, ParityModuleRNN, BasicGAN, ]) @@ -138,7 +139,7 @@ def test_torchscript_save_load(tmpdir, modelclass): def test_torchcript_invalid_method(tmpdir): """Test that an error is thrown with invalid torchscript method""" - model = EvalModelTemplate() + model = BoringModel() model.train(True) with pytest.raises(ValueError, match="only supports 'script' or 'trace'"): @@ -147,38 +148,52 @@ def test_torchcript_invalid_method(tmpdir): def test_torchscript_with_no_input(tmpdir): """Test that an error is thrown when there is no input tensor""" - model = EvalModelTemplate() + model = BoringModel() model.example_input_array = None with pytest.raises(ValueError, match='requires either `example_inputs` or `model.example_input_array`'): model.to_torchscript(method='trace') -def test_torchscript_with_tuple_input(tmpdir): +def test_torchscript_with_sequence_input(tmpdir): """Test that traced LightningModule is created when input is tuple of tensors""" - class CustomModel(EvalModelTemplate): + class CustomModel(BoringModel): def forward(self, x, y=None): return super().forward(x) - example_inputs = (torch.randn(1, 28 * 28), torch.randn(1, 28 * 28)) + def test_torchscript_export(model, example_inputs): + script = model.to_torchscript(example_inputs=example_inputs, method='trace') + assert isinstance(script, torch.jit.ScriptModule) + + model.eval() + with torch.no_grad(): + model_output = model(*example_inputs) + 
script_output = script(*example_inputs) + assert torch.allclose(script_output, model_output) + model = CustomModel() - script = model.to_torchscript(example_inputs=example_inputs, method='trace') - assert isinstance(script, torch.jit.ScriptModule) - model.eval() - with torch.no_grad(): - model_output = model(*example_inputs) - script_output = script(*example_inputs) - assert torch.allclose(script_output, model_output) + # tuple input + example_inputs = (torch.randn(1, 32), torch.randn(1, 32)) + test_torchscript_export(model, example_inputs) + + # list input + example_inputs = [torch.randn(1, 32), torch.randn(1, 32)] + test_torchscript_export(model, example_inputs) + + # NamedTuple input + example_inputs = namedtuple('sample', ['x', 'y']) + example_inputs = example_inputs(torch.randn(1, 32), torch.randn(1, 32)) + test_torchscript_export(model, example_inputs) - example_inputs = (torch.randn(1, 28 * 28), np.random.randn(1, 28 * 28)) with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): - model.to_torchscript(example_inputs=example_inputs, method='trace') + example_inputs = (torch.randn(1, 32), np.random.randn(1, 32)) + test_torchscript_export(model, example_inputs) def test_error_with_invalid_input(tmpdir): """Test that an error is thrown with invalid input""" - class CustomModel(EvalModelTemplate): + class CustomModel(BoringModel): def forward(self, x): if isinstance(x, dict): x = x['x'] @@ -186,7 +201,7 @@ def forward(self, x): return super().forward(x) model = CustomModel() - example_inputs = {'x': torch.randn(1, 28 * 28)} + example_inputs = {'x': torch.randn(1, 32)} with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): model.to_torchscript(example_inputs=example_inputs, method='trace') From 62d5b848e153169bf1c2180857eb57d5776f3a81 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 3 Nov 2020 00:06:37 +0530 Subject: [PATCH 07/18] try fix --- tests/models/test_onnx.py | 2 +- tests/models/test_torchscript.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index 38654e02fb811..a6828f70a7848 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -165,7 +165,7 @@ def test_onnx_export(model, input_sample): # NamedTuple input input_sample = namedtuple('sample', ['x', 'y']) - input_sample = input_sample(torch.randn(1, 32), torch.randn(1, 32)) + input_sample = input_sample(x=torch.randn(1, 32), y=torch.randn(1, 32)) test_onnx_export(model, input_sample) with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index e180a1a9a9699..280ac2b81a442 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -183,7 +183,7 @@ def test_torchscript_export(model, example_inputs): # NamedTuple input example_inputs = namedtuple('sample', ['x', 'y']) - example_inputs = example_inputs(torch.randn(1, 32), torch.randn(1, 32)) + example_inputs = example_inputs(x=torch.randn(1, 32), y=torch.randn(1, 32)) test_torchscript_export(model, example_inputs) with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): From d05f145ecb201b1872766311045e03a8a0cfef36 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 3 Nov 2020 00:13:55 +0530 Subject: [PATCH 08/18] pep --- pytorch_lightning/core/lightning.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py 
b/pytorch_lightning/core/lightning.py index 4436e06584227..9b07e0d97ceeb 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1605,7 +1605,10 @@ def to_torchscript( if isinstance(example_inputs, Tensor): example_inputs = (example_inputs,) - elif not (isinstance(example_inputs, collections.Sequence) and all(isinstance(inp, Tensor) for inp in example_inputs)): + elif not ( + isinstance(example_inputs, collections.Sequence) + and all(isinstance(inp, Tensor) for inp in example_inputs) + ): raise ValueError( "Could not export to torchscript since example_inputs is neither a Tensor nor tuple of Tensors" ) From 9aab119ace8a7ffd7d5eb8f6e05ab38c3dad2589 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 3 Nov 2020 16:14:34 +0530 Subject: [PATCH 09/18] skip test with torch < 1.4 --- tests/models/test_onnx.py | 4 ++-- tests/models/test_torchscript.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index a6828f70a7848..a2e13db2767c8 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -22,7 +22,7 @@ import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer -from tests.base import BoringModel +from tests.base import BoringModel, EvalModelTemplate def test_model_saves_with_input_sample(tmpdir): @@ -90,7 +90,7 @@ def test_model_saves_on_multi_gpu(tmpdir): progress_bar_refresh_rate=0 ) - model = BoringModel() + model = EvalModelTemplate() tpipes.run_model_test(trainer_options, model) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 280ac2b81a442..8d736c555855a 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -155,6 +155,10 @@ def test_torchscript_with_no_input(tmpdir): model.to_torchscript(method='trace') +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("1.4.0"), + reason="torch has bug parsing namedtuples on torch < 1.4", +) def test_torchscript_with_sequence_input(tmpdir): """Test that traced LightningModule is created when input is tuple of tensors""" class CustomModel(BoringModel): From fb47563c57fef5adff7bf49fad0d880462977d84 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 3 Nov 2020 02:33:08 +0530 Subject: [PATCH 10/18] fix test --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d8c4b17c900cc..45651ed6ea76f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1575,7 +1575,7 @@ def to_torchscript( Args: file_path: Path where to save the torchscript. Default: None (no file saved). method: Whether to use TorchScript's script or trace method. Default: 'script' - example_inputs: An input tensor or tuple of tensors to be used to do tracing when method is set to 'trace'. + example_inputs: An input tensor or tuple/list of tensors to be used to do tracing when method is set to 'trace'. Default: None (Use self.example_input_array) **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or :func:`torch.jit.trace` function. 
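# --- Usage sketch (illustrative, not taken from the patches) ------------------
# At this point in the series, ``to_onnx`` accepts a Tensor or a tuple of
# Tensors and ``to_torchscript(method='trace')`` accepts a Tensor or a sequence
# of Tensors; both samples are moved to the model's device before export.
# ``TwoInputModel``, the file name, and the shapes below are hypothetical
# stand-ins mirroring the ``CustomModel`` used in the tests above.
import torch
from torch import nn
import pytorch_lightning as pl


class TwoInputModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x, y=None):
        # the second tensor is accepted but unused, as in the test CustomModel
        return self.layer(x)


model = TwoInputModel()
example = (torch.randn(1, 32), torch.randn(1, 32))

# ONNX export with a tuple of tensors (each element must be a Tensor here)
model.to_onnx("two_input.onnx", input_sample=example)

# TorchScript trace with the same tuple; a list or NamedTuple of tensors is
# also accepted at this stage (the tests exercise NamedTuples on torch >= 1.4)
script = model.to_torchscript(method="trace", example_inputs=example)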
From 9d799f17b3ace4287360e4fb20c078be076c9817 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 16 Nov 2020 10:56:54 +0100 Subject: [PATCH 11/18] Apply suggestions from code review --- pytorch_lightning/core/lightning.py | 15 +++++++++++---- tests/models/test_onnx.py | 12 ++++++------ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e10c04e7513f7..a015f8ac7da35 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1684,8 +1684,12 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: self._hparams = hp @torch.no_grad() - def to_onnx(self, file_path: Union[str, Path], - input_sample: Optional[Union[Tensor, Tuple[Tensor]]] = None, **kwargs): + def to_onnx( + self, + file_path: Union[str, Path], + input_sample: Optional[Union[Tensor, Tuple[Tensor]]] = None, + **kwargs, + ): """Saves the model in ONNX format Args: @@ -1734,8 +1738,11 @@ def to_onnx(self, file_path: Union[str, Path], @torch.no_grad() def to_torchscript( - self, file_path: Optional[Union[str, Path]] = None, method: Optional[str] = 'script', - example_inputs: Optional[Union[Tensor, Sequence[Tensor]]] = None, **kwargs + self, + file_path: Optional[Union[str, Path]] = None, + method: Optional[str] = 'script', + example_inputs: Optional[Union[Tensor, Sequence[Tensor]]] = None, + **kwargs, ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index a2e13db2767c8..fa83922f06415 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -150,27 +150,27 @@ class CustomModel(BoringModel): def forward(self, x, y=None): return super().forward(x) - def test_onnx_export(model, input_sample): + def _assert_onnx_export(model, input_sample): file_path = os.path.join(tmpdir, "model.onnx") model.to_onnx(file_path, input_sample) assert os.path.exists(file_path) is True model = CustomModel() - trainer = Trainer(max_epochs=1) + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) trainer.fit(model) # tuple input input_sample = (torch.randn(1, 32), torch.randn(1, 32)) - test_onnx_export(model, input_sample) + _assert_onnx_export(model, input_sample) # NamedTuple input input_sample = namedtuple('sample', ['x', 'y']) input_sample = input_sample(x=torch.randn(1, 32), y=torch.randn(1, 32)) - test_onnx_export(model, input_sample) + _assert_onnx_export(model, input_sample) with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): input_sample = (torch.randn(1, 32), np.random.randn(1, 32)) - test_onnx_export(model, input_sample) + _assert_onnx_export(model, input_sample) def test_error_with_invalid_input(tmpdir): @@ -183,7 +183,7 @@ def forward(self, x): return super().forward(x) model = CustomModel() - trainer = Trainer(max_epochs=1) + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) trainer.fit(model) file_path = os.path.join(tmpdir, "model.onnx") From 5770a75e1fb1ff92d489b4535c00feb77b0847de Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 16 Nov 2020 18:19:25 +0530 Subject: [PATCH 12/18] update tests --- pytorch_lightning/core/lightning.py | 6 +++--- tests/models/test_onnx.py | 10 +++++----- tests/models/test_torchscript.py | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a015f8ac7da35..08ec256c3c6eb 
100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -33,7 +33,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities import rank_zero_warn, AMPType +from pytorch_lightning.utilities import rank_zero_warn, AMPType, move_data_to_device from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -1728,7 +1728,7 @@ def to_onnx( "Could not export to ONNX since input_sample is neither a Tensor nor tuple of Tensors" ) - input_sample = self.transfer_batch_to_device(input_sample, self.device) + input_sample = move_data_to_device(input_sample, self.device) if "example_outputs" not in kwargs: self.eval() @@ -1813,7 +1813,7 @@ def to_torchscript( ) # automatically send example inputs to the right device and use trace - example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device) + example_inputs = move_data_to_device(example_inputs, device=self.device) torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) else: raise ValueError("The 'method' parameter only supports 'script' or 'trace'," diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index fa83922f06415..b96f414cd7a65 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -150,8 +150,8 @@ class CustomModel(BoringModel): def forward(self, x, y=None): return super().forward(x) - def _assert_onnx_export(model, input_sample): - file_path = os.path.join(tmpdir, "model.onnx") + def _assert_onnx_export(model, input_sample, filename): + file_path = os.path.join(tmpdir, filename) model.to_onnx(file_path, input_sample) assert os.path.exists(file_path) is True @@ -161,16 +161,16 @@ def _assert_onnx_export(model, input_sample): # tuple input input_sample = (torch.randn(1, 32), torch.randn(1, 32)) - _assert_onnx_export(model, input_sample) + _assert_onnx_export(model, input_sample, 'model_tuple.onnx') # NamedTuple input input_sample = namedtuple('sample', ['x', 'y']) input_sample = input_sample(x=torch.randn(1, 32), y=torch.randn(1, 32)) - _assert_onnx_export(model, input_sample) + _assert_onnx_export(model, input_sample, 'model_ntuple.onnx') with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): input_sample = (torch.randn(1, 32), np.random.randn(1, 32)) - _assert_onnx_export(model, input_sample) + _assert_onnx_export(model, input_sample, 'model_error.onnx') def test_error_with_invalid_input(tmpdir): diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 8d736c555855a..ceb7283b9ff81 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -165,7 +165,7 @@ class CustomModel(BoringModel): def forward(self, x, y=None): return super().forward(x) - def test_torchscript_export(model, example_inputs): + def _assert_torchscript_export(model, example_inputs): script = model.to_torchscript(example_inputs=example_inputs, method='trace') assert isinstance(script, torch.jit.ScriptModule) @@ -179,20 +179,20 @@ def test_torchscript_export(model, example_inputs): # tuple input example_inputs = (torch.randn(1, 32), torch.randn(1, 32)) - test_torchscript_export(model, example_inputs) + 
_assert_torchscript_export(model, example_inputs) # list input example_inputs = [torch.randn(1, 32), torch.randn(1, 32)] - test_torchscript_export(model, example_inputs) + _assert_torchscript_export(model, example_inputs) # NamedTuple input example_inputs = namedtuple('sample', ['x', 'y']) example_inputs = example_inputs(x=torch.randn(1, 32), y=torch.randn(1, 32)) - test_torchscript_export(model, example_inputs) + _assert_torchscript_export(model, example_inputs) with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): example_inputs = (torch.randn(1, 32), np.random.randn(1, 32)) - test_torchscript_export(model, example_inputs) + _assert_torchscript_export(model, example_inputs) def test_error_with_invalid_input(tmpdir): From 48ed87d2a320d9744a61d4f806956728c133c6e7 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 21 Nov 2020 00:37:56 +0530 Subject: [PATCH 13/18] Allow any input in to_onnx and to_torchscript --- pytorch_lightning/core/hooks.py | 5 ++- pytorch_lightning/core/lightning.py | 45 ++++++++-------------- tests/models/test_onnx.py | 50 ------------------------- tests/models/test_torchscript.py | 58 ----------------------------- 4 files changed, 19 insertions(+), 139 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 4272c0823bb19..2b555141c201f 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -14,7 +14,7 @@ """Various hooks to be used in the Lightning code.""" -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import torch from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn @@ -508,7 +508,7 @@ def val_dataloader(self): will have an argument ``dataloader_idx`` which matches the order here. """ - def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any: + def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: """ Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom data structure. 
@@ -556,6 +556,7 @@ def transfer_batch_to_device(self, batch, device) - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` """ + device = device or self.device return move_data_to_device(batch, device) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 08ec256c3c6eb..5fcff45683562 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -33,7 +33,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities import rank_zero_warn, AMPType, move_data_to_device +from pytorch_lightning.utilities import rank_zero_warn, AMPType from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -1685,16 +1685,17 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None: @torch.no_grad() def to_onnx( - self, - file_path: Union[str, Path], - input_sample: Optional[Union[Tensor, Tuple[Tensor]]] = None, - **kwargs, + self, + file_path: Union[str, Path], + input_sample: Optional[Any] = None, + **kwargs, ): - """Saves the model in ONNX format + """ + Saves the model in ONNX format Args: file_path: The path of the file the onnx model should be saved to. - input_sample: An input tensor or tuple of tensors for tracing. Default: None (Use self.example_input_array) + input_sample: An input for tracing. Default: None (Use self.example_input_array) **kwargs: Will be passed to torch.onnx.export function. Example: @@ -1713,6 +1714,8 @@ def to_onnx( ... os.path.isfile(tmpfile.name) True """ + mode = self.training + if input_sample is None: if self.example_input_array is None: raise ValueError( @@ -1721,27 +1724,21 @@ def to_onnx( ) input_sample = self.example_input_array - if isinstance(input_sample, Tensor): - input_sample = (input_sample,) - elif not (isinstance(input_sample, tuple) and all(isinstance(inp, Tensor) for inp in input_sample)): - raise ValueError( - "Could not export to ONNX since input_sample is neither a Tensor nor tuple of Tensors" - ) - - input_sample = move_data_to_device(input_sample, self.device) + input_sample = self.transfer_batch_to_device(input_sample) if "example_outputs" not in kwargs: self.eval() - kwargs["example_outputs"] = self(*input_sample) + kwargs["example_outputs"] = self(input_sample) torch.onnx.export(self, input_sample, file_path, **kwargs) + self.train(mode) @torch.no_grad() def to_torchscript( self, file_path: Optional[Union[str, Path]] = None, method: Optional[str] = 'script', - example_inputs: Optional[Union[Tensor, Sequence[Tensor]]] = None, + example_inputs: Optional[Any] = None, **kwargs, ) -> Union[ScriptModule, Dict[str, ScriptModule]]: """ @@ -1754,7 +1751,7 @@ def to_torchscript( Args: file_path: Path where to save the torchscript. Default: None (no file saved). method: Whether to use TorchScript's script or trace method. Default: 'script' - example_inputs: An input tensor or tuple/list of tensors to be used to do tracing when method is set to 'trace'. + example_inputs: An input to be used to do tracing when method is set to 'trace'. 
Default: None (Use self.example_input_array) **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or :func:`torch.jit.trace` function. @@ -1802,18 +1799,8 @@ def to_torchscript( ) example_inputs = self.example_input_array - if isinstance(example_inputs, Tensor): - example_inputs = (example_inputs,) - elif not ( - isinstance(example_inputs, collections.Sequence) - and all(isinstance(inp, Tensor) for inp in example_inputs) - ): - raise ValueError( - "Could not export to torchscript since example_inputs is neither a Tensor nor tuple of Tensors" - ) - # automatically send example inputs to the right device and use trace - example_inputs = move_data_to_device(example_inputs, device=self.device) + example_inputs = self.transfer_batch_to_device(example_inputs) torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs) else: raise ValueError("The 'method' parameter only supports 'script' or 'trace'," diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index b96f414cd7a65..7e45c75b0716a 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from collections import namedtuple import numpy as np import onnxruntime @@ -142,52 +141,3 @@ def to_numpy(tensor): # compare ONNX Runtime and PyTorch results assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) - - -def test_model_saves_with_sequence_input(tmpdir): - """Test that ONNX model saves when input is tuple of tensors""" - class CustomModel(BoringModel): - def forward(self, x, y=None): - return super().forward(x) - - def _assert_onnx_export(model, input_sample, filename): - file_path = os.path.join(tmpdir, filename) - model.to_onnx(file_path, input_sample) - assert os.path.exists(file_path) is True - - model = CustomModel() - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - trainer.fit(model) - - # tuple input - input_sample = (torch.randn(1, 32), torch.randn(1, 32)) - _assert_onnx_export(model, input_sample, 'model_tuple.onnx') - - # NamedTuple input - input_sample = namedtuple('sample', ['x', 'y']) - input_sample = input_sample(x=torch.randn(1, 32), y=torch.randn(1, 32)) - _assert_onnx_export(model, input_sample, 'model_ntuple.onnx') - - with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): - input_sample = (torch.randn(1, 32), np.random.randn(1, 32)) - _assert_onnx_export(model, input_sample, 'model_error.onnx') - - -def test_error_with_invalid_input(tmpdir): - """Test that an error is thrown with invalid input""" - class CustomModel(BoringModel): - def forward(self, x): - if isinstance(x, dict): - x = x['x'] - - return super().forward(x) - - model = CustomModel() - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - trainer.fit(model) - - file_path = os.path.join(tmpdir, "model.onnx") - input_sample = {'x': torch.randn((2, 32))} - - with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): - model.to_onnx(file_path, input_sample) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index ceb7283b9ff81..6bcedf7f7b436 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from distutils.version import LooseVersion -import numpy as np -from collections import namedtuple import pytest import torch @@ -153,59 +151,3 @@ def test_torchscript_with_no_input(tmpdir): with pytest.raises(ValueError, match='requires either `example_inputs` or `model.example_input_array`'): model.to_torchscript(method='trace') - - -@pytest.mark.skipif( - LooseVersion(torch.__version__) < LooseVersion("1.4.0"), - reason="torch has bug parsing namedtuples on torch < 1.4", -) -def test_torchscript_with_sequence_input(tmpdir): - """Test that traced LightningModule is created when input is tuple of tensors""" - class CustomModel(BoringModel): - def forward(self, x, y=None): - return super().forward(x) - - def _assert_torchscript_export(model, example_inputs): - script = model.to_torchscript(example_inputs=example_inputs, method='trace') - assert isinstance(script, torch.jit.ScriptModule) - - model.eval() - with torch.no_grad(): - model_output = model(*example_inputs) - script_output = script(*example_inputs) - assert torch.allclose(script_output, model_output) - - model = CustomModel() - - # tuple input - example_inputs = (torch.randn(1, 32), torch.randn(1, 32)) - _assert_torchscript_export(model, example_inputs) - - # list input - example_inputs = [torch.randn(1, 32), torch.randn(1, 32)] - _assert_torchscript_export(model, example_inputs) - - # NamedTuple input - example_inputs = namedtuple('sample', ['x', 'y']) - example_inputs = example_inputs(x=torch.randn(1, 32), y=torch.randn(1, 32)) - _assert_torchscript_export(model, example_inputs) - - with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): - example_inputs = (torch.randn(1, 32), np.random.randn(1, 32)) - _assert_torchscript_export(model, example_inputs) - - -def test_error_with_invalid_input(tmpdir): - """Test that an error is thrown with invalid input""" - class CustomModel(BoringModel): - def forward(self, x): - if isinstance(x, dict): - x = x['x'] - - return super().forward(x) - - model = CustomModel() - example_inputs = {'x': torch.randn(1, 32)} - - with pytest.raises(ValueError, match='neither a Tensor nor tuple of Tensors'): - model.to_torchscript(example_inputs=example_inputs, method='trace') From 689b63075d170634ee62bd193a419f0cf5abd3e6 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Sun, 22 Nov 2020 18:15:37 +0530 Subject: [PATCH 14/18] Update tests/models/test_torchscript.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- tests/models/test_torchscript.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 6bcedf7f7b436..1962fe0f50adf 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -35,7 +35,7 @@ def test_torchscript_input_output(modelclass): model.eval() with torch.no_grad(): model_output = model(model.example_input_array) - script_output = script(model.example_input_array) + script_output = script(model.example_input_array) assert torch.allclose(script_output, model_output) From e85dc074c3978a4a5e5d744dba885130dd855265 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 22 Nov 2020 18:21:10 +0530 Subject: [PATCH 15/18] no_grad --- tests/models/test_torchscript.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 1962fe0f50adf..bbe007c067c92 100644 --- a/tests/models/test_torchscript.py +++ 
b/tests/models/test_torchscript.py @@ -35,6 +35,7 @@ def test_torchscript_input_output(modelclass): model.eval() with torch.no_grad(): model_output = model(model.example_input_array) + script_output = script(model.example_input_array) assert torch.allclose(script_output, model_output) @@ -53,7 +54,8 @@ def test_torchscript_example_input_output_trace(modelclass): model.eval() with torch.no_grad(): model_output = model(model.example_input_array) - script_output = script(model.example_input_array) + + script_output = script(model.example_input_array) assert torch.allclose(script_output, model_output) @@ -67,7 +69,8 @@ def test_torchscript_input_output_trace(): model.eval() with torch.no_grad(): model_output = model(example_inputs) - script_output = script(example_inputs) + + script_output = script(example_inputs) assert torch.allclose(script_output, model_output) From 3c9ed6d4550279b834213076a451e4329ed35acf Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 1 Dec 2020 13:17:53 +0530 Subject: [PATCH 16/18] try fix random failing test --- tests/trainer/test_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 085d361952844..f98b2c0b72cdd 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -958,6 +958,7 @@ def test_gradient_clipping(tmpdir): """ Test gradient clipping """ + tutils.reset_seed() model = EvalModelTemplate() @@ -995,6 +996,7 @@ def test_gradient_clipping_fp16(tmpdir): """ Test gradient clipping with fp16 """ + tutils.reset_seed() model = EvalModelTemplate() From e7b1a0a3d56715ad720b90f6bed9485f10fa412e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 2 Dec 2020 01:35:54 +0530 Subject: [PATCH 17/18] rm example_input_array --- tests/base/boring_model.py | 1 - tests/models/test_onnx.py | 6 ++++++ tests/models/test_torchscript.py | 8 ++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/base/boring_model.py b/tests/base/boring_model.py index c1d2b6b3fad49..6ceffe8562372 100644 --- a/tests/base/boring_model.py +++ b/tests/base/boring_model.py @@ -76,7 +76,6 @@ def training_step(...): """ super().__init__() self.layer = torch.nn.Linear(32, 2) - self.example_input_array = torch.rand(5, 32) def forward(self, x): return self.layer(x) diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index 7e45c75b0716a..bbd142b652787 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -68,6 +68,8 @@ def test_model_saves_with_example_output(tmpdir): def test_model_saves_with_example_input_array(tmpdir): """Test that ONNX model saves with_example_input_array and size is greater than 3 MB""" model = BoringModel() + model.example_input_array = torch.randn(5, 32) + file_path = os.path.join(tmpdir, "model.onnx") model.to_onnx(file_path) assert os.path.exists(file_path) is True @@ -101,6 +103,8 @@ def test_model_saves_on_multi_gpu(tmpdir): def test_verbose_param(tmpdir, capsys): """Test that output is present when verbose parameter is set""" model = BoringModel() + model.example_input_array = torch.randn(5, 32) + file_path = os.path.join(tmpdir, "model.onnx") model.to_onnx(file_path, verbose=True) captured = capsys.readouterr() @@ -120,6 +124,8 @@ def test_error_if_no_input(tmpdir): def test_if_inference_output_is_valid(tmpdir): """Test that the output inferred from ONNX model is same as from PyTorch""" model = BoringModel() + model.example_input_array = torch.randn(5, 32) + trainer = Trainer(max_epochs=2) trainer.fit(model) diff --git 
a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index bbe007c067c92..5526fd31f4ff9 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -29,6 +29,10 @@ def test_torchscript_input_output(modelclass): """ Test that scripted LightningModule forward works. """ model = modelclass() + + if isinstance(model, BoringModel): + model.example_input_array = torch.randn(5, 32) + script = model.to_torchscript() assert isinstance(script, torch.jit.ScriptModule) @@ -48,6 +52,10 @@ def test_torchscript_input_output(modelclass): def test_torchscript_example_input_output_trace(modelclass): """ Test that traced LightningModule forward works with example_input_array """ model = modelclass() + + if isinstance(model, BoringModel): + model.example_input_array = torch.randn(5, 32) + script = model.to_torchscript(method='trace') assert isinstance(script, torch.jit.ScriptModule) From 341a8cd5c74365e1cf397dd5b90fee9a8ef51724 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 2 Dec 2020 02:30:43 +0530 Subject: [PATCH 18/18] rm example_input_array --- tests/models/test_torchscript.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index 5526fd31f4ff9..3c43b201f52e4 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -90,6 +90,8 @@ def test_torchscript_input_output_trace(): def test_torchscript_device(device): """ Test that scripted module is on the correct device. """ model = BoringModel().to(device) + model.example_input_array = torch.randn(5, 32) + script = model.to_torchscript() assert next(script.parameters()).device == device script_output = script(model.example_input_array.to(device))
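# --- End-state usage sketch (illustrative, not taken from the patches) --------
# After the full series, both exporters run under ``torch.no_grad()``, restore
# the original training mode, and route the provided sample through
# ``transfer_batch_to_device``, so ``input_sample`` / ``example_inputs`` may be
# any structure that hook can move. ``TinyModel`` and the file names below are
# hypothetical; BoringModel itself no longer sets ``example_input_array``, so
# callers set it explicitly as the updated tests do.
import torch
from torch import nn
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        # set explicitly, mirroring the updated tests
        self.example_input_array = torch.randn(5, 32)

    def forward(self, x):
        return self.layer(x)


model = TinyModel()

# uses example_input_array; gradients and training mode are left untouched
model.to_onnx("tiny.onnx")

# an explicit sample also works and is moved to the model's device first
model.to_onnx("tiny_explicit.onnx", input_sample=torch.randn(1, 32))

# scripting needs no inputs; tracing falls back to example_input_array
scripted = model.to_torchscript()
traced = model.to_torchscript(method="trace")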