Skip to content

Commit 2b81ad8

Browse files
Philip Meierfmassa
authored andcommitted
Always pass transform and target_transform to abstract dataset (#1126)
* fixed call to the VisionDataset constructor * change call from keyword arguments to positional * changed order of arguments * removed transforms argument once again * Fixed call to constructor of parent class * fixed LSUN * fixed Caltech256
1 parent 2cae950 commit 2b81ad8

File tree

14 files changed

+55
-79
lines changed

14 files changed

+55
-79
lines changed

torchvision/datasets/caltech.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,16 @@ class Caltech101(VisionDataset):
2626
downloaded again.
2727
"""
2828

29-
def __init__(self, root, target_type="category",
30-
transform=None, target_transform=None,
31-
download=False):
32-
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
29+
def __init__(self, root, target_type="category", transform=None,
30+
target_transform=None, download=False):
31+
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'),
32+
transform=transform,
33+
target_transform=target_transform)
3334
makedir_exist_ok(self.root)
3435
if isinstance(target_type, list):
3536
self.target_type = target_type
3637
else:
3738
self.target_type = [target_type]
38-
self.transform = transform
39-
self.target_transform = target_transform
4039

4140
if download:
4241
self.download()
@@ -143,13 +142,11 @@ class Caltech256(VisionDataset):
143142
downloaded again.
144143
"""
145144

146-
def __init__(self, root,
147-
transform=None, target_transform=None,
148-
download=False):
149-
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
145+
def __init__(self, root, transform=None, target_transform=None, download=False):
146+
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'),
147+
transform=transform,
148+
target_transform=target_transform)
150149
makedir_exist_ok(self.root)
151-
self.transform = transform
152-
self.target_transform = target_transform
153150

154151
if download:
155152
self.download()

torchvision/datasets/celeba.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,16 @@ class CelebA(VisionDataset):
4848
("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
4949
]
5050

51-
def __init__(self, root,
52-
split="train",
53-
target_type="attr",
54-
transform=None, target_transform=None,
55-
download=False):
51+
def __init__(self, root, split="train", target_type="attr", transform=None,
52+
target_transform=None, download=False):
5653
import pandas
57-
super(CelebA, self).__init__(root)
54+
super(CelebA, self).__init__(root, transform=transform,
55+
target_transform=target_transform)
5856
self.split = split
5957
if isinstance(target_type, list):
6058
self.target_type = target_type
6159
else:
6260
self.target_type = [target_type]
63-
self.transform = transform
64-
self.target_transform = target_transform
6561

6662
if download:
6763
self.download()
@@ -70,9 +66,6 @@ def __init__(self, root,
7066
raise RuntimeError('Dataset not found or corrupted.' +
7167
' You can use download=True to download it')
7268

73-
self.transform = transform
74-
self.target_transform = target_transform
75-
7669
if split.lower() == "train":
7770
split = 0
7871
elif split.lower() == "valid":

torchvision/datasets/cifar.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,11 @@ class CIFAR10(VisionDataset):
5252
'md5': '5ff9c542aee3614f3951f8cda6e48888',
5353
}
5454

55-
def __init__(self, root, train=True,
56-
transform=None, target_transform=None,
55+
def __init__(self, root, train=True, transform=None, target_transform=None,
5756
download=False):
5857

59-
super(CIFAR10, self).__init__(root)
60-
self.transform = transform
61-
self.target_transform = target_transform
58+
super(CIFAR10, self).__init__(root, transform=transform,
59+
target_transform=target_transform)
6260

6361
self.train = train # training set or test set
6462

torchvision/datasets/flickr.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ class Flickr8k(VisionDataset):
6363
"""
6464

6565
def __init__(self, root, ann_file, transform=None, target_transform=None):
66-
super(Flickr8k, self).__init__(root)
67-
self.transform = transform
68-
self.target_transform = target_transform
66+
super(Flickr8k, self).__init__(root, transform=transform,
67+
target_transform=target_transform)
6968
self.ann_file = os.path.expanduser(ann_file)
7069

7170
# Read annotations and store in a dict
@@ -115,9 +114,8 @@ class Flickr30k(VisionDataset):
115114
"""
116115

117116
def __init__(self, root, ann_file, transform=None, target_transform=None):
118-
super(Flickr30k, self).__init__(root)
119-
self.transform = transform
120-
self.target_transform = target_transform
117+
super(Flickr30k, self).__init__(root, transform=transform,
118+
target_transform=target_transform)
121119
self.ann_file = os.path.expanduser(ann_file)
122120

123121
# Read annotations and store in a dict

torchvision/datasets/folder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ class DatasetFolder(VisionDataset):
8686
targets (list): The class_index value for each image in the dataset
8787
"""
8888

89-
def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
90-
super(DatasetFolder, self).__init__(root)
91-
self.transform = transform
92-
self.target_transform = target_transform
89+
def __init__(self, root, loader, extensions=None, transform=None,
90+
target_transform=None, is_valid_file=None):
91+
super(DatasetFolder, self).__init__(root, transform=transform,
92+
target_transform=target_transform)
9393
classes, class_to_idx = self._find_classes(self.root)
9494
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
9595
if len(samples) == 0:

torchvision/datasets/lsun.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
class LSUNClass(VisionDataset):
1616
def __init__(self, root, transform=None, target_transform=None):
1717
import lmdb
18-
super(LSUNClass, self).__init__(root)
19-
self.transform = transform
20-
self.target_transform = target_transform
18+
super(LSUNClass, self).__init__(root, transform=transform,
19+
target_transform=target_transform)
2120

2221
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
2322
readahead=False, meminit=False)
@@ -68,11 +67,9 @@ class LSUN(VisionDataset):
6867
target and transforms it.
6968
"""
7069

71-
def __init__(self, root, classes='train',
72-
transform=None, target_transform=None):
73-
super(LSUN, self).__init__(root)
74-
self.transform = transform
75-
self.target_transform = target_transform
70+
def __init__(self, root, classes='train', transform=None, target_transform=None):
71+
super(LSUN, self).__init__(root, transform=transform,
72+
target_transform=target_transform)
7673
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
7774
'conference_room', 'dining_room', 'kitchen',
7875
'living_room', 'restaurant', 'tower']

torchvision/datasets/mnist.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def test_data(self):
5757
warnings.warn("test_data has been renamed data")
5858
return self.data
5959

60-
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
61-
super(MNIST, self).__init__(root)
62-
self.transform = transform
63-
self.target_transform = target_transform
60+
def __init__(self, root, train=True, transform=None, target_transform=None,
61+
download=False):
62+
super(MNIST, self).__init__(root, transform=transform,
63+
target_transform=target_transform)
6464
self.train = train # training set or test set
6565

6666
if download:

torchvision/datasets/omniglot.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,10 @@ class Omniglot(VisionDataset):
2828
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
2929
}
3030

31-
def __init__(self, root, background=True,
32-
transform=None, target_transform=None,
31+
def __init__(self, root, background=True, transform=None, target_transform=None,
3332
download=False):
34-
super(Omniglot, self).__init__(join(root, self.folder))
35-
self.transform = transform
36-
self.target_transform = target_transform
33+
super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
34+
target_transform=target_transform)
3735
self.background = background
3836

3937
if download:

torchvision/datasets/phototour.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ class PhotoTour(VisionDataset):
6565
matches_files = 'm50_100000_100000_0.txt'
6666

6767
def __init__(self, root, name, train=True, transform=None, download=False):
68-
super(PhotoTour, self).__init__(root)
69-
self.transform = transform
68+
super(PhotoTour, self).__init__(root, transform=transform)
7069
self.name = name
7170
self.data_dir = os.path.join(self.root, name)
7271
self.data_down = os.path.join(self.root, '{}.zip'.format(name))

torchvision/datasets/sbu.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ class SBU(VisionDataset):
2424
filename = "SBUCaptionedPhotoDataset.tar.gz"
2525
md5_checksum = '9aec147b3488753cf758b4d493422285'
2626

27-
def __init__(self, root, transform=None, target_transform=None,
28-
download=True):
29-
super(SBU, self).__init__(root)
30-
self.transform = transform
31-
self.target_transform = target_transform
27+
def __init__(self, root, transform=None, target_transform=None, download=True):
28+
super(SBU, self).__init__(root, transform=transform,
29+
target_transform=target_transform)
3230

3331
if download:
3432
self.download()

0 commit comments

Comments
 (0)