From dbac70a1d8fc243e0261591fc63bd6b4f1e6d2aa Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 11 Mar 2021 15:00:16 +0530 Subject: [PATCH 1/2] hotfix for torchvision --- pl_examples/basic_examples/autoencoder.py | 5 +++-- pl_examples/basic_examples/backbone_image_classifier.py | 5 +++-- pl_examples/basic_examples/dali_image_classifier.py | 5 +++-- pl_examples/basic_examples/mnist_datamodule.py | 3 ++- pl_examples/domain_templates/generative_adversarial_net.py | 5 +++-- 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index b3188a21b7f04..a2010a89f4461 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -22,9 +22,10 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 01a5dca0de3c7..3546bee9ad129 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -21,9 +21,10 @@ import pytorch_lightning as pl from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index b4bf1407a9b26..da5b1e4fd9e9c 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -31,9 +31,10 @@ cli_lightning_logo, ) -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms - from torchvision.datasets.mnist import MNIST +if _TORCHVISION_MNIST_AVAILABLE: + from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index a50f67cdab301..a6d59c64d9aa0 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -20,8 +20,9 @@ from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE from pytorch_lightning import LightningDataModule -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib +if _TORCHVISION_MNIST_AVAILABLE: from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 285fba8b93f1b..e65ede17dac7a 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -32,9 +32,10 @@ from pytorch_lightning.core import LightningDataModule, LightningModule from pytorch_lightning.trainer import Trainer -if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE: +if _TORCHVISION_AVAILABLE: import torchvision - import torchvision.transforms as transforms + from torchvision import transforms +if _TORCHVISION_MNIST_AVAILABLE: from torchvision.datasets import MNIST else: from tests.helpers.datasets import MNIST From 97e59b220e25e233f1e497263d1004fe55b3c172 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 11 Mar 2021 15:48:42 +0530 Subject: [PATCH 2/2] fix tests --- tests/helpers/datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index 5af3fbfbc4a11..e7bdad0f1538c 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -69,6 +69,7 @@ def __init__( train: bool = True, normalize: tuple = (0.1307, 0.3081), download: bool = True, + **kwargs, ): super().__init__() self.root = root