diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 57979b73f2cb6..f24a4ce8beb8a 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -14,7 +14,7 @@
 """Various hooks to be used in the Lightning code."""
 
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
 from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
@@ -501,7 +501,7 @@ def val_dataloader(self):
             will have an argument ``dataloader_idx`` which matches the order here.
         """
 
-    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+    def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
         """
         Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
         wrapped in a custom data structure.
@@ -549,6 +549,7 @@ def transfer_batch_to_device(self, batch, device)
             - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
             - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
         """
+        device = device or self.device
         return move_data_to_device(batch, device)
 
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index ef05ce69c1828..358b24fe1f40c 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -22,6 +22,7 @@
 import tempfile
 from abc import ABC
 from argparse import Namespace
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import torch
@@ -1530,12 +1531,19 @@ def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
         else:
             self._hparams = hp
 
-    def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
-        """Saves the model in ONNX format
+    @torch.no_grad()
+    def to_onnx(
+        self,
+        file_path: Union[str, Path],
+        input_sample: Optional[Any] = None,
+        **kwargs,
+    ):
+        """
+        Saves the model in ONNX format
 
         Args:
-            file_path: The path of the file the model should be saved to.
-            input_sample: A sample of an input tensor for tracing.
+            file_path: The path of the file the ONNX model should be saved to.
+            input_sample: An input for tracing. Default: None (Use self.example_input_array)
             **kwargs: Will be passed to torch.onnx.export function.
 
         Example:
@@ -1554,31 +1562,32 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwarg
         ...     os.path.isfile(tmpfile.name)
         True
         """
+        mode = self.training
 
-        if isinstance(input_sample, Tensor):
-            input_data = input_sample
-        elif self.example_input_array is not None:
-            input_data = self.example_input_array
-        else:
-            if input_sample is not None:
+        if input_sample is None:
+            if self.example_input_array is None:
                 raise ValueError(
-                    f"Received `input_sample` of type {type(input_sample)}. Expected type is `Tensor`"
+                    "Could not export to ONNX since neither `input_sample` nor"
+                    " `model.example_input_array` attribute is set."
                 )
-            raise ValueError(
-                "Could not export to ONNX since neither `input_sample` nor"
-                " `model.example_input_array` attribute is set."
-            )
-        input_data = input_data.to(self.device)
+            input_sample = self.example_input_array
+
+        input_sample = self.transfer_batch_to_device(input_sample)
+
         if "example_outputs" not in kwargs:
             self.eval()
-            with torch.no_grad():
-                kwargs["example_outputs"] = self(input_data)
+            kwargs["example_outputs"] = self(input_sample)
 
-        torch.onnx.export(self, input_data, file_path, **kwargs)
+        torch.onnx.export(self, input_sample, file_path, **kwargs)
+        self.train(mode)
 
+    @torch.no_grad()
     def to_torchscript(
-        self, file_path: Optional[str] = None, method: Optional[str] = 'script',
-        example_inputs: Optional[Union[torch.Tensor, Tuple[torch.Tensor]]] = None, **kwargs
+        self,
+        file_path: Optional[Union[str, Path]] = None,
+        method: Optional[str] = 'script',
+        example_inputs: Optional[Any] = None,
+        **kwargs,
     ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
         """
         By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
@@ -1590,7 +1599,7 @@ def to_torchscript(
         Args:
             file_path: Path where to save the torchscript. Default: None (no file saved).
             method: Whether to use TorchScript's script or trace method. Default: 'script'
-            example_inputs: Tensor to be used to do tracing when method is set to 'trace'.
+            example_inputs: An input to be used to do tracing when method is set to 'trace'.
               Default: None (Use self.example_input_array)
             **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
               :func:`torch.jit.trace` function.
@@ -1624,21 +1633,27 @@ def to_torchscript(
             This LightningModule as a torchscript, regardless of whether file_path is defined or not.
         """
-        mode = self.training
-        with torch.no_grad():
-            if method == 'script':
-                torchscript_module = torch.jit.script(self.eval(), **kwargs)
-            elif method == 'trace':
-                # if no example inputs are provided, try to see if model has example_input_array set
-                if example_inputs is None:
-                    example_inputs = self.example_input_array
-                # automatically send example inputs to the right device and use trace
-                example_inputs = self.transfer_batch_to_device(example_inputs, device=self.device)
-                torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
-            else:
-                raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was:"
-                                 f"{method}")
+        mode = self.training
+
+        if method == 'script':
+            torchscript_module = torch.jit.script(self.eval(), **kwargs)
+        elif method == 'trace':
+            # if no example inputs are provided, try to see if model has example_input_array set
+            if example_inputs is None:
+                if self.example_input_array is None:
+                    raise ValueError(
+                        'Choosing method=`trace` requires either `example_inputs`'
+                        ' or `model.example_input_array` to be defined'
+                    )
+                example_inputs = self.example_input_array
+
+            # automatically send example inputs to the right device and use trace
+            example_inputs = self.transfer_batch_to_device(example_inputs)
+            torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
+        else:
+            raise ValueError("The 'method' parameter only supports 'script' or 'trace',"
+                             f" but value given was: {method}")
+
         self.train(mode)
 
         if file_path is not None:
diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py
index a3919a6a8a7dd..82727d37479b6 100644
--- a/tests/models/test_onnx.py
+++ b/tests/models/test_onnx.py
@@ -21,44 +21,44 @@
 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
-from tests.base import EvalModelTemplate
+from tests.base import BoringModel, EvalModelTemplate
 
 
 def test_model_saves_with_input_sample(tmpdir):
     """Test that ONNX model saves with input sample and size is greater than 3 MB"""
-    model = EvalModelTemplate()
+    model = BoringModel()
     trainer = Trainer(max_epochs=1)
     trainer.fit(model)
 
     file_path = os.path.join(tmpdir, "model.onnx")
-    input_sample = torch.randn((1, 28 * 28))
+    input_sample = torch.randn((1, 32))
     model.to_onnx(file_path, input_sample)
     assert os.path.isfile(file_path)
-    assert os.path.getsize(file_path) > 3e+06
+    assert os.path.getsize(file_path) > 4e2
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_model_saves_on_gpu(tmpdir):
     """Test that model saves on gpu"""
-    model = EvalModelTemplate()
+    model = BoringModel()
     trainer = Trainer(gpus=1, max_epochs=1)
     trainer.fit(model)
 
     file_path = os.path.join(tmpdir, "model.onnx")
-    input_sample = torch.randn((1, 28 * 28))
+    input_sample = torch.randn((1, 32))
     model.to_onnx(file_path, input_sample)
     assert os.path.isfile(file_path)
-    assert os.path.getsize(file_path) > 3e+06
+    assert os.path.getsize(file_path) > 4e2
 
 
 def test_model_saves_with_example_output(tmpdir):
     """Test that ONNX model saves when provided with example output"""
-    model = EvalModelTemplate()
+    model = BoringModel()
     trainer = Trainer(max_epochs=1)
     trainer.fit(model)
 
     file_path = os.path.join(tmpdir, "model.onnx")
-    input_sample = torch.randn((1, 28 * 28))
+    input_sample = torch.randn((1, 32))
     model.eval()
     example_outputs = model.forward(input_sample)
     model.to_onnx(file_path, input_sample, example_outputs=example_outputs)
@@ -67,11 +67,13 @@ def test_model_saves_with_example_output(tmpdir):
 
 def test_model_saves_with_example_input_array(tmpdir):
     """Test that ONNX model saves with_example_input_array and size is greater than 3 MB"""
-    model = EvalModelTemplate()
+    model = BoringModel()
+    model.example_input_array = torch.randn(5, 32)
+
     file_path = os.path.join(tmpdir, "model.onnx")
     model.to_onnx(file_path)
     assert os.path.exists(file_path) is True
-    assert os.path.getsize(file_path) > 3e+06
+    assert os.path.getsize(file_path) > 4e2
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -100,7 +102,9 @@ def test_model_saves_on_multi_gpu(tmpdir):
 
 def test_verbose_param(tmpdir, capsys):
     """Test that output is present when verbose parameter is set"""
-    model = EvalModelTemplate()
+    model = BoringModel()
+    model.example_input_array = torch.randn(5, 32)
+
     file_path = os.path.join(tmpdir, "model.onnx")
     model.to_onnx(file_path, verbose=True)
     captured = capsys.readouterr()
@@ -108,8 +112,8 @@ def test_verbose_param(tmpdir, capsys):
 
 
 def test_error_if_no_input(tmpdir):
-    """Test that an exception is thrown when there is no input tensor"""
-    model = EvalModelTemplate()
+    """Test that an error is thrown when there is no input tensor"""
+    model = BoringModel()
     model.example_input_array = None
     file_path = os.path.join(tmpdir, "model.onnx")
     with pytest.raises(ValueError, match=r'Could not export to ONNX since neither `input_sample` nor'
@@ -117,21 +121,12 @@ def test_error_if_no_input(tmpdir):
         model.to_onnx(file_path)
 
 
-def test_error_if_input_sample_is_not_tensor(tmpdir):
-    """Test that an exception is thrown when there is no input tensor"""
-    model = EvalModelTemplate()
-    model.example_input_array = None
-    file_path = os.path.join(tmpdir, "model.onnx")
-    input_sample = np.random.randn(1, 28 * 28)
-    with pytest.raises(ValueError, match=f'Received `input_sample` of type {type(input_sample)}. Expected type is '
-                                         f'`Tensor`'):
-        model.to_onnx(file_path, input_sample)
-
-
 def test_if_inference_output_is_valid(tmpdir):
     """Test that the output inferred from ONNX model is same as from PyTorch"""
-    model = EvalModelTemplate()
-    trainer = Trainer(max_epochs=5)
+    model = BoringModel()
+    model.example_input_array = torch.randn(5, 32)
+
+    trainer = Trainer(max_epochs=2)
     trainer.fit(model)
 
     model.eval()
diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py
index bf2c34b8bfef5..3c43b201f52e4 100644
--- a/tests/models/test_torchscript.py
+++ b/tests/models/test_torchscript.py
@@ -16,43 +16,72 @@
 import pytest
 import torch
 
-from tests.base import EvalModelTemplate
+from tests.base import BoringModel
 from tests.base.datamodules import TrialMNISTDataModule
 from tests.base.models import ParityModuleRNN, BasicGAN
 
 
 @pytest.mark.parametrize("modelclass", [
-    EvalModelTemplate,
+    BoringModel,
     ParityModuleRNN,
     BasicGAN,
 ])
 def test_torchscript_input_output(modelclass):
     """ Test that scripted LightningModule forward works. """
     model = modelclass()
+
+    if isinstance(model, BoringModel):
+        model.example_input_array = torch.randn(5, 32)
+
     script = model.to_torchscript()
     assert isinstance(script, torch.jit.ScriptModule)
+
     model.eval()
-    model_output = model(model.example_input_array)
+    with torch.no_grad():
+        model_output = model(model.example_input_array)
+
     script_output = script(model.example_input_array)
     assert torch.allclose(script_output, model_output)
 
 
 @pytest.mark.parametrize("modelclass", [
-    EvalModelTemplate,
+    BoringModel,
     ParityModuleRNN,
     BasicGAN,
 ])
-def test_torchscript_input_output_trace(modelclass):
-    """ Test that traced LightningModule forward works. """
+def test_torchscript_example_input_output_trace(modelclass):
+    """ Test that traced LightningModule forward works with example_input_array """
     model = modelclass()
+
+    if isinstance(model, BoringModel):
+        model.example_input_array = torch.randn(5, 32)
+
     script = model.to_torchscript(method='trace')
     assert isinstance(script, torch.jit.ScriptModule)
+
     model.eval()
-    model_output = model(model.example_input_array)
+    with torch.no_grad():
+        model_output = model(model.example_input_array)
+
     script_output = script(model.example_input_array)
     assert torch.allclose(script_output, model_output)
 
 
+def test_torchscript_input_output_trace():
+    """ Test that traced LightningModule forward works with example_inputs """
+    model = BoringModel()
+    example_inputs = torch.randn(1, 32)
+    script = model.to_torchscript(example_inputs=example_inputs, method='trace')
+    assert isinstance(script, torch.jit.ScriptModule)
+
+    model.eval()
+    with torch.no_grad():
+        model_output = model(example_inputs)
+
+    script_output = script(example_inputs)
+    assert torch.allclose(script_output, model_output)
+
+
 @pytest.mark.parametrize("device", [
     torch.device("cpu"),
     torch.device("cuda", 0)
@@ -60,7 +89,9 @@ def test_torchscript_input_output_trace(modelclass):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
 def test_torchscript_device(device):
     """ Test that scripted module is on the correct device. """
""" - model = EvalModelTemplate().to(device) + model = BoringModel().to(device) + model.example_input_array = torch.randn(5, 32) + script = model.to_torchscript() assert next(script.parameters()).device == device script_output = script(model.example_input_array.to(device)) @@ -69,7 +100,7 @@ def test_torchscript_device(device): def test_torchscript_retain_training_state(): """ Test that torchscript export does not alter the training mode of original model. """ - model = EvalModelTemplate() + model = BoringModel() model.train(True) script = model.to_torchscript() assert model.training @@ -81,7 +112,7 @@ def test_torchscript_retain_training_state(): @pytest.mark.parametrize("modelclass", [ - EvalModelTemplate, + BoringModel, ParityModuleRNN, BasicGAN, ]) @@ -100,7 +131,7 @@ def test_torchscript_properties(modelclass): @pytest.mark.parametrize("modelclass", [ - EvalModelTemplate, + BoringModel, ParityModuleRNN, BasicGAN, ]) @@ -109,9 +140,27 @@ def test_torchscript_properties(modelclass): reason="torch.save/load has bug loading script modules on torch <= 1.4", ) def test_torchscript_save_load(tmpdir, modelclass): - """ Test that scripted LightningModules is correctly saved and can be loaded. """ + """ Test that scripted LightningModule is correctly saved and can be loaded. """ model = modelclass() output_file = str(tmpdir / "model.pt") script = model.to_torchscript(file_path=output_file) loaded_script = torch.jit.load(output_file) assert torch.allclose(next(script.parameters()), next(loaded_script.parameters())) + + +def test_torchcript_invalid_method(tmpdir): + """Test that an error is thrown with invalid torchscript method""" + model = BoringModel() + model.train(True) + + with pytest.raises(ValueError, match="only supports 'script' or 'trace'"): + model.to_torchscript(method='temp') + + +def test_torchscript_with_no_input(tmpdir): + """Test that an error is thrown when there is no input tensor""" + model = BoringModel() + model.example_input_array = None + + with pytest.raises(ValueError, match='requires either `example_inputs` or `model.example_input_array`'): + model.to_torchscript(method='trace') diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 9b29d6ec2b1dd..c24f1f5421e5c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -958,6 +958,7 @@ def test_gradient_clipping(tmpdir): """ Test gradient clipping """ + tutils.reset_seed() model = EvalModelTemplate() @@ -995,6 +996,7 @@ def test_gradient_clipping_fp16(tmpdir): """ Test gradient clipping with fp16 """ + tutils.reset_seed() model = EvalModelTemplate()