diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 577bdb2eb32..12f761c070e 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -312,7 +312,8 @@ def create_dataset( patch_checks = inject_fake_data special_kwargs, other_kwargs = self._split_kwargs(kwargs) - if "download" in self._HAS_SPECIAL_KWARG: + if "download" in self._HAS_SPECIAL_KWARG and special_kwargs.get("download", False): + # override the download param to False if its default is truthy special_kwargs["download"] = False config.update(other_kwargs) diff --git a/test/fakedata_generation.py b/test/fakedata_generation.py index dac415df110..473c15d19c4 100644 --- a/test/fakedata_generation.py +++ b/test/fakedata_generation.py @@ -143,76 +143,6 @@ def _make_meta_file(file, classes_key): yield root -@contextlib.contextmanager -def imagenet_root(): - import scipy.io as sio - - WNID = 'n01234567' - CLS = 'fakedata' - - def _make_image(file): - PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file) - - def _make_tar(archive, content, arcname=None, compress=False): - mode = 'w:gz' if compress else 'w' - if arcname is None: - arcname = os.path.basename(content) - with tarfile.open(archive, mode) as fh: - fh.add(content, arcname=arcname) - - def _make_train_archive(root): - with get_tmp_dir() as tmp: - wnid_dir = os.path.join(tmp, WNID) - os.mkdir(wnid_dir) - - _make_image(os.path.join(wnid_dir, WNID + '_1.JPEG')) - - wnid_archive = wnid_dir + '.tar' - _make_tar(wnid_archive, wnid_dir) - - train_archive = os.path.join(root, 'ILSVRC2012_img_train.tar') - _make_tar(train_archive, wnid_archive) - - def _make_val_archive(root): - with get_tmp_dir() as tmp: - val_image = os.path.join(tmp, 'ILSVRC2012_val_00000001.JPEG') - _make_image(val_image) - - val_archive = os.path.join(root, 'ILSVRC2012_img_val.tar') - _make_tar(val_archive, val_image) - - def _make_devkit_archive(root): - with get_tmp_dir() as tmp: - data_dir = os.path.join(tmp, 'data') - os.mkdir(data_dir) - - 
meta_file = os.path.join(data_dir, 'meta.mat') - synsets = np.core.records.fromarrays([ - (0.0, 1.0), - (WNID, ''), - (CLS, ''), - ('fakedata for the torchvision testsuite', ''), - (0.0, 1.0), - ], names=['ILSVRC2012_ID', 'WNID', 'words', 'gloss', 'num_children']) - sio.savemat(meta_file, {'synsets': synsets}) - - groundtruth_file = os.path.join(data_dir, - 'ILSVRC2012_validation_ground_truth.txt') - with open(groundtruth_file, 'w') as fh: - fh.write('0\n') - - devkit_name = 'ILSVRC2012_devkit_t12' - devkit_archive = os.path.join(root, devkit_name + '.tar.gz') - _make_tar(devkit_archive, tmp, arcname=devkit_name, compress=True) - - with get_tmp_dir() as root: - _make_train_archive(root) - _make_val_archive(root) - _make_devkit_archive(root) - - yield root - - @contextlib.contextmanager def widerface_root(): """ diff --git a/test/test_datasets.py b/test/test_datasets.py index 56f0d6707bc..11114ae5b36 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -10,7 +10,7 @@ import torchvision from torchvision.datasets import utils from common_utils import get_tmp_dir -from fakedata_generation import mnist_root, imagenet_root, \ +from fakedata_generation import mnist_root, \ cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root import xml.etree.ElementTree as ET from urllib.request import Request, urlopen @@ -146,16 +146,6 @@ def test_fashionmnist(self, mock_download_extract): img, target = dataset[0] self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) - @mock.patch('torchvision.datasets.imagenet._verify_archive') - @unittest.skipIf(not HAS_SCIPY, "scipy unavailable") - def test_imagenet(self, mock_verify): - with imagenet_root() as root: - dataset = torchvision.datasets.ImageNet(root, split='train') - self.generic_classification_dataset_test(dataset) - - dataset = torchvision.datasets.ImageNet(root, split='val') - self.generic_classification_dataset_test(dataset) - @mock.patch('torchvision.datasets.WIDERFace._check_integrity') 
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') def test_widerface(self, mock_check_integrity): @@ -490,6 +480,37 @@ def inject_fake_data(self, tmpdir, config): return num_images_per_category * len(categories) +class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.ImageNet + REQUIRED_PACKAGES = ('scipy',) + CONFIGS = datasets_utils.combinations_grid(split=('train', 'val')) + + def inject_fake_data(self, tmpdir, config): + tmpdir = pathlib.Path(tmpdir) + + wnid = 'n01234567' + if config['split'] == 'train': + num_examples = 3 + datasets_utils.create_image_folder( + root=tmpdir, + name=tmpdir / 'train' / wnid / wnid, + file_name_fn=lambda image_idx: f"{wnid}_{image_idx}.JPEG", + num_examples=num_examples, + ) + else: + num_examples = 1 + datasets_utils.create_image_folder( + root=tmpdir, + name=tmpdir / 'val' / wnid, + file_name_fn=lambda image_idx: f"ILSVRC2012_val_0000000{image_idx}.JPEG", + num_examples=num_examples, + ) + + wnid_to_classes = {wnid: [1]} + torch.save((wnid_to_classes, None), tmpdir / 'meta.bin') + return num_examples + + class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.CIFAR10 CONFIGS = datasets_utils.combinations_grid(train=(True, False))