Skip to content

Commit 4bf0675

Browse files
kaushikb11lexierule
authored andcommitted
Hotfix for torchvision (#6476)
(cherry picked from commit 079fe9b)
1 parent b546431 commit 4bf0675

File tree

6 files changed

+15
-9
lines changed

6 files changed

+15
-9
lines changed

pl_examples/basic_examples/autoencoder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
import pytorch_lightning as pl
2323
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
2424

25-
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
25+
if _TORCHVISION_AVAILABLE:
2626
from torchvision import transforms
27-
from torchvision.datasets.mnist import MNIST
27+
if _TORCHVISION_MNIST_AVAILABLE:
28+
from torchvision.datasets import MNIST
2829
else:
2930
from tests.helpers.datasets import MNIST
3031

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
import pytorch_lightning as pl
2222
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
2323

24-
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
24+
if _TORCHVISION_AVAILABLE:
2525
from torchvision import transforms
26-
from torchvision.datasets.mnist import MNIST
26+
if _TORCHVISION_MNIST_AVAILABLE:
27+
from torchvision.datasets import MNIST
2728
else:
2829
from tests.helpers.datasets import MNIST
2930

pl_examples/basic_examples/dali_image_classifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@
3131
cli_lightning_logo,
3232
)
3333

34-
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
34+
if _TORCHVISION_AVAILABLE:
3535
from torchvision import transforms
36-
from torchvision.datasets.mnist import MNIST
36+
if _TORCHVISION_MNIST_AVAILABLE:
37+
from torchvision.datasets import MNIST
3738
else:
3839
from tests.helpers.datasets import MNIST
3940

pl_examples/basic_examples/mnist_datamodule.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from pl_examples import _DATASETS_PATH, _TORCHVISION_AVAILABLE, _TORCHVISION_MNIST_AVAILABLE
2121
from pytorch_lightning import LightningDataModule
2222

23-
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
23+
if _TORCHVISION_AVAILABLE:
2424
from torchvision import transforms as transform_lib
25+
if _TORCHVISION_MNIST_AVAILABLE:
2526
from torchvision.datasets import MNIST
2627
else:
2728
from tests.helpers.datasets import MNIST

pl_examples/domain_templates/generative_adversarial_net.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@
3232
from pytorch_lightning.core import LightningDataModule, LightningModule
3333
from pytorch_lightning.trainer import Trainer
3434

35-
if _TORCHVISION_AVAILABLE and _TORCHVISION_MNIST_AVAILABLE:
35+
if _TORCHVISION_AVAILABLE:
3636
import torchvision
37-
import torchvision.transforms as transforms
37+
from torchvision import transforms
38+
if _TORCHVISION_MNIST_AVAILABLE:
3839
from torchvision.datasets import MNIST
3940
else:
4041
from tests.helpers.datasets import MNIST

tests/helpers/datasets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
train: bool = True,
7070
normalize: tuple = (0.1307, 0.3081),
7171
download: bool = True,
72+
**kwargs,
7273
):
7374
super().__init__()
7475
self.root = root

0 commit comments

Comments
 (0)