Skip to content

Commit a7a2ee7

Browse files
authored
remove old CIFAR tests and fake data generation (#3447)
1 parent 5266a72 commit a7a2ee7

File tree

2 files changed

+1
-88
lines changed

2 files changed

+1
-88
lines changed

test/fakedata_generation.py

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -88,61 +88,6 @@ def _make_label_file(filename, num_images):
8888
yield tmp_dir
8989

9090

91-
@contextlib.contextmanager
92-
def cifar_root(version):
93-
def _get_version_params(version):
94-
if version == 'CIFAR10':
95-
return {
96-
'base_folder': 'cifar-10-batches-py',
97-
'train_files': ['data_batch_{}'.format(batch) for batch in range(1, 6)],
98-
'test_file': 'test_batch',
99-
'target_key': 'labels',
100-
'meta_file': 'batches.meta',
101-
'classes_key': 'label_names',
102-
}
103-
elif version == 'CIFAR100':
104-
return {
105-
'base_folder': 'cifar-100-python',
106-
'train_files': ['train'],
107-
'test_file': 'test',
108-
'target_key': 'fine_labels',
109-
'meta_file': 'meta',
110-
'classes_key': 'fine_label_names',
111-
}
112-
else:
113-
raise ValueError
114-
115-
def _make_pickled_file(obj, file):
116-
with open(file, 'wb') as fh:
117-
pickle.dump(obj, fh, 2)
118-
119-
def _make_data_file(file, target_key):
120-
obj = {
121-
'data': np.zeros((1, 32 * 32 * 3), dtype=np.uint8),
122-
target_key: [0]
123-
}
124-
_make_pickled_file(obj, file)
125-
126-
def _make_meta_file(file, classes_key):
127-
obj = {
128-
classes_key: ['fakedata'],
129-
}
130-
_make_pickled_file(obj, file)
131-
132-
params = _get_version_params(version)
133-
with get_tmp_dir() as root:
134-
base_folder = os.path.join(root, params['base_folder'])
135-
os.mkdir(base_folder)
136-
137-
for file in list(params['train_files']) + [params['test_file']]:
138-
_make_data_file(os.path.join(base_folder, file), params['target_key'])
139-
140-
_make_meta_file(os.path.join(base_folder, params['meta_file']),
141-
params['classes_key'])
142-
143-
yield root
144-
145-
14691
@contextlib.contextmanager
14792
def imagenet_root():
14893
import scipy.io as sio

test/test_datasets.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torchvision
1111
from torchvision.datasets import utils
1212
from common_utils import get_tmp_dir
13-
from fakedata_generation import mnist_root, cifar_root, imagenet_root, \
13+
from fakedata_generation import mnist_root, imagenet_root, \
1414
cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root
1515
import xml.etree.ElementTree as ET
1616
from urllib.request import Request, urlopen
@@ -173,38 +173,6 @@ def test_widerface(self, mock_check_integrity):
173173
img, target = dataset[0]
174174
self.assertTrue(isinstance(img, PIL.Image.Image))
175175

176-
@mock.patch('torchvision.datasets.cifar.check_integrity')
177-
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
178-
def test_cifar10(self, mock_ext_check, mock_int_check):
179-
mock_ext_check.return_value = True
180-
mock_int_check.return_value = True
181-
with cifar_root('CIFAR10') as root:
182-
dataset = torchvision.datasets.CIFAR10(root, train=True, download=True)
183-
self.generic_classification_dataset_test(dataset, num_images=5)
184-
img, target = dataset[0]
185-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
186-
187-
dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
188-
self.generic_classification_dataset_test(dataset)
189-
img, target = dataset[0]
190-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
191-
192-
@mock.patch('torchvision.datasets.cifar.check_integrity')
193-
@mock.patch('torchvision.datasets.cifar.CIFAR10._check_integrity')
194-
def test_cifar100(self, mock_ext_check, mock_int_check):
195-
mock_ext_check.return_value = True
196-
mock_int_check.return_value = True
197-
with cifar_root('CIFAR100') as root:
198-
dataset = torchvision.datasets.CIFAR100(root, train=True, download=True)
199-
self.generic_classification_dataset_test(dataset)
200-
img, target = dataset[0]
201-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
202-
203-
dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
204-
self.generic_classification_dataset_test(dataset)
205-
img, target = dataset[0]
206-
self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
207-
208176
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
209177
def test_cityscapes(self):
210178
with cityscapes_root() as root:

0 commit comments

Comments
 (0)