Skip to content

Commit be8978b

Browse files
committed
remove old CIFAR tests and fake data generation
1 parent fc33c46 commit be8978b

File tree

2 files changed

+1
-88
lines changed

2 files changed

+1
-88
lines changed

test/fakedata_generation.py

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -88,61 +88,6 @@ def _make_label_file(filename, num_images):
8888
yield tmp_dir
8989

9090

91-
@contextlib.contextmanager
92-
def cifar_root(version):
93-
def _get_version_params(version):
94-
if version == 'CIFAR10':
95-
return {
96-
'base_folder': 'cifar-10-batches-py',
97-
'train_files': ['data_batch_{}'.format(batch) for batch in range(1, 6)],
98-
'test_file': 'test_batch',
99-
'target_key': 'labels',
100-
'meta_file': 'batches.meta',
101-
'classes_key': 'label_names',
102-
}
103-
elif version == 'CIFAR100':
104-
return {
105-
'base_folder': 'cifar-100-python',
106-
'train_files': ['train'],
107-
'test_file': 'test',
108-
'target_key': 'fine_labels',
109-
'meta_file': 'meta',
110-
'classes_key': 'fine_label_names',
111-
}
112-
else:
113-
raise ValueError
114-
115-
def _make_pickled_file(obj, file):
116-
with open(file, 'wb') as fh:
117-
pickle.dump(obj, fh, 2)
118-
119-
def _make_data_file(file, target_key):
120-
obj = {
121-
'data': np.zeros((1, 32 * 32 * 3), dtype=np.uint8),
122-
target_key: [0]
123-
}
124-
_make_pickled_file(obj, file)
125-
126-
def _make_meta_file(file, classes_key):
127-
obj = {
128-
classes_key: ['fakedata'],
129-
}
130-
_make_pickled_file(obj, file)
131-
132-
params = _get_version_params(version)
133-
with get_tmp_dir() as root:
134-
base_folder = os.path.join(root, params['base_folder'])
135-
os.mkdir(base_folder)
136-
137-
for file in list(params['train_files']) + [params['test_file']]:
138-
_make_data_file(os.path.join(base_folder, file), params['target_key'])
139-
140-
_make_meta_file(os.path.join(base_folder, params['meta_file']),
141-
params['classes_key'])
142-
143-
yield root
144-
145-
14691
@contextlib.contextmanager
14792
def imagenet_root():
14893
import scipy.io as sio

test/test_datasets.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torchvision
1111
from torchvision.datasets import utils
1212
from common_utils import get_tmp_dir
13-
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
13+
from fakedata_generation import mnist_root, imagenet_root, \
1414
cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
1515
import xml.etree.ElementTree as ET
1616
from urllib.request import Request, urlopen
@@ -171,38 +171,6 @@ def test_widerface(self, mock_check_integrity):
171171
img, target = dataset[0]
172172
self.assertTrue(isinstance(img, PIL.Image.Image))
173173

174-
@mock.patch('torchvision.datasets.cifar.check_integrity')
175-
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
176-
def test_cifar10(self, mock_ext_check, mock_int_check):
177-
mock_ext_check.return_value = True
178-
mock_int_check.return_value = True
179-
with cifar_root('CIFAR10') as root:
180-
dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
181-
self.generic_classification_dataset_test(dataset, num_images=5)
182-
img, target = dataset[0]
183-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
184-
185-
dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
186-
self.generic_classification_dataset_test(dataset)
187-
img, target = dataset[0]
188-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
189-
190-
@mock.patch('torchvision.datasets.cifar.check_integrity')
191-
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
192-
def test_cifar100(self, mock_ext_check, mock_int_check):
193-
mock_ext_check.return_value = True
194-
mock_int_check.return_value = True
195-
with cifar_root('CIFAR100') as root:
196-
dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
197-
self.generic_classification_dataset_test(dataset)
198-
img, target = dataset[0]
199-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
200-
201-
dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
202-
self.generic_classification_dataset_test(dataset)
203-
img, target = dataset[0]
204-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
205-
206174
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
207175
def test_cityscapes(self):
208176
with cityscapes_root() as root:

0 commit comments

Comments
 (0)