From decf7eb100bb8351a7079077042627d1ba7607d8 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Wed, 2 Jun 2021 16:32:39 +0200 Subject: [PATCH 1/5] Fixed support for torch Module type hints in LightningCLI --- CHANGELOG.md | 3 +++ requirements/extra.txt | 2 +- tests/utilities/test_cli.py | 49 +++++++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cbe3efd12c06e..8c3585c57dac4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -182,6 +182,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924)) +- Fixed support for torch Module type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807)) + + ## [1.3.2] - 2021-05-18 ### Changed diff --git a/requirements/extra.txt b/requirements/extra.txt index c41f464ef383b..e75b20e2ca401 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -jsonargparse[signatures]>=3.12.0 +jsonargparse[signatures]>=3.13.0 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 5780a83e75db8..a0e57465116f3 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -20,9 +20,11 @@ from argparse import Namespace from contextlib import redirect_stdout from io import StringIO +from typing import List, Optional from unittest import mock import pytest +import torch import yaml from pytorch_lightning import LightningDataModule, LightningModule, Trainer @@ -30,6 +32,7 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback +from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel @@ -443,3 +446,49 @@ def __init__( assert cli.model.submodule2 == cli.config_init['model']['submodule2'] assert isinstance(cli.config_init['model']['submodule1'], BoringModel) assert isinstance(cli.config_init['model']['submodule2'], BoringModel) + + +@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason='torchvision is required') +def test_lightning_cli_torch_modules(tmpdir): + + class MainModule(BoringModel): + + def __init__( + self, + activation: torch.nn.Module = None, + transform: Optional[List[torch.nn.Module]] = None, + ): + super().__init__() + self.activation = activation + self.transform = transform + + config = """model: + activation: + class_path: torch.nn.LeakyReLU + init_args: + negative_slope: 0.2 + transform: + - class_path: torchvision.transforms.Resize + init_args: + size: 64 + - class_path: torchvision.transforms.CenterCrop + init_args: + size: 64 + """ + config_path = tmpdir / 'config.yaml' + with open(config_path, 'w') as f: + f.write(config) + + cli_args = [ + f'--trainer.default_root_dir={tmpdir}', + '--trainer.max_epochs=1', + f'--config={str(config_path)}', + ] + + with mock.patch('sys.argv', ['any.py'] + cli_args): + cli = LightningCLI(MainModule) + + assert isinstance(cli.model.activation, torch.nn.LeakyReLU) + assert cli.model.activation.negative_slope == 0.2 + assert len(cli.model.transform) == 2 + assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform) From 4bf801748df711a20584f835c7cd0877e7ae3a09 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 3 Jun 2021 07:48:32 +0200 Subject: [PATCH 2/5] - Fix issue with serializing values when type hint is Any. - Run unit test only on newer torchvision versions in which the base class is Module. --- requirements/extra.txt | 2 +- tests/utilities/test_cli.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/requirements/extra.txt b/requirements/extra.txt index e75b20e2ca401..cb9515beefb9a 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -jsonargparse[signatures]>=3.13.0 +jsonargparse[signatures]>=3.13.1 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index a0e57465116f3..db2daa2c73e8c 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -20,6 +20,7 @@ from argparse import Namespace from contextlib import redirect_stdout from io import StringIO +from packaging import version from typing import List, Optional from unittest import mock @@ -36,6 +37,11 @@ from tests.helpers import BoringDataModule, BoringModel +torchvision_version = version.parse('0') +if _TORCHVISION_AVAILABLE: + torchvision_version = version.parse(__import__('torchvision').__version__) + + @mock.patch('argparse.ArgumentParser.parse_args') def test_default_args(mock_argparse, tmpdir): """Tests default argument parser for Trainer""" @@ -448,7 +454,7 @@ def __init__( assert isinstance(cli.config_init['model']['submodule2'], BoringModel) -@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason='torchvision is required') +@pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required') def test_lightning_cli_torch_modules(tmpdir): class MainModule(BoringModel): From 5e2f91949883bb075a66684da0d759569f7a73a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Jun 2021 05:49:44 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_cli.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index db2daa2c73e8c..a4a1854ccd298 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -20,13 +20,13 @@ from argparse import Namespace from contextlib import redirect_stdout from io import StringIO -from packaging import version from typing import List, Optional from unittest import mock import pytest import torch import yaml +from packaging import version from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint @@ -36,7 +36,6 @@ from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel - torchvision_version = version.parse('0') if _TORCHVISION_AVAILABLE: torchvision_version = version.parse(__import__('torchvision').__version__) From fd844f1798e4116f0c9ede37033291314083c762 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 3 Jun 2021 07:56:24 +0200 Subject: [PATCH 4/5] Minor change --- tests/utilities/test_cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index a4a1854ccd298..c1eabca5d663d 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -456,7 +456,7 @@ def __init__( @pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required') def test_lightning_cli_torch_modules(tmpdir): - class MainModule(BoringModel): + class TestModule(BoringModel): def __init__( self, @@ -491,7 +491,7 @@ def __init__( ] with mock.patch('sys.argv', ['any.py'] + cli_args): - cli = LightningCLI(MainModule) + cli = LightningCLI(TestModule) assert isinstance(cli.model.activation, torch.nn.LeakyReLU) assert cli.model.activation.negative_slope == 0.2 From c3db752edf811b8bb1c97beacfc75350c7d7a631 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 3 Jun 2021 13:15:10 +0200 Subject: [PATCH 5/5] Update CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c3585c57dac4..48da9eba1acf0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -182,7 +182,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924)) -- Fixed support for torch Module type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807)) +- Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807)) ## [1.3.2] - 2021-05-18