|
10 | 10 | import torchvision |
11 | 11 | from torchvision.datasets import utils |
12 | 12 | from common_utils import get_tmp_dir |
13 | | -from fakedata_generation import mnist_root, imagenet_root, \ |
| 13 | +from fakedata_generation import mnist_root, \ |
14 | 14 | cityscapes_root, svhn_root, places365_root, widerface_root, stl10_root |
15 | 15 | import xml.etree.ElementTree as ET |
16 | 16 | from urllib.request import Request, urlopen |
@@ -146,16 +146,6 @@ def test_fashionmnist(self, mock_download_extract): |
146 | 146 | img, target = dataset[0] |
147 | 147 | self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) |
148 | 148 |
|
149 | | - @mock.patch('torchvision.datasets.imagenet._verify_archive') |
150 | | - @unittest.skipIf(not HAS_SCIPY, "scipy unavailable") |
151 | | - def test_imagenet(self, mock_verify): |
152 | | - with imagenet_root() as root: |
153 | | - dataset = torchvision.datasets.ImageNet(root, split='train') |
154 | | - self.generic_classification_dataset_test(dataset) |
155 | | - |
156 | | - dataset = torchvision.datasets.ImageNet(root, split='val') |
157 | | - self.generic_classification_dataset_test(dataset) |
158 | | - |
159 | 149 | @mock.patch('torchvision.datasets.WIDERFace._check_integrity') |
160 | 150 | @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows') |
161 | 151 | def test_widerface(self, mock_check_integrity): |
@@ -490,6 +480,37 @@ def inject_fake_data(self, tmpdir, config): |
490 | 480 | return num_images_per_category * len(categories) |
491 | 481 |
|
492 | 482 |
|
| 483 | +class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): |
| 484 | + DATASET_CLASS = datasets.ImageNet |
| 485 | + REQUIRED_PACKAGES = ('scipy',) |
| 486 | + CONFIGS = datasets_utils.combinations_grid(split=('train', 'val')) |
| 487 | + |
| 488 | + def inject_fake_data(self, tmpdir, config): |
| 489 | + tmpdir = pathlib.Path(tmpdir) |
| 490 | + |
| 491 | + wnid = 'n01234567' |
| 492 | + if config['split'] == 'train': |
| 493 | + num_examples = 3 |
| 494 | + datasets_utils.create_image_folder( |
| 495 | + root=tmpdir, |
| 496 | + name=tmpdir / 'train' / wnid / wnid, |
| 497 | + file_name_fn=lambda image_idx: f"{wnid}_{image_idx}.JPEG", |
| 498 | + num_examples=num_examples, |
| 499 | + ) |
| 500 | + else: |
| 501 | + num_examples = 1 |
| 502 | + datasets_utils.create_image_folder( |
| 503 | + root=tmpdir, |
| 504 | + name=tmpdir / 'val' / wnid, |
| 505 | + file_name_fn=lambda image_ifx: "ILSVRC2012_val_0000000{image_idx}.JPEG", |
| 506 | + num_examples=num_examples, |
| 507 | + ) |
| 508 | + |
| 509 | + wnid_to_classes = {wnid: [1]} |
| 510 | + torch.save((wnid_to_classes, None), tmpdir / 'meta.bin') |
| 511 | + return num_examples |
| 512 | + |
| 513 | + |
493 | 514 | class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase): |
494 | 515 | DATASET_CLASS = datasets.CIFAR10 |
495 | 516 | CONFIGS = datasets_utils.combinations_grid(train=(True, False)) |
|
0 commit comments