diff --git a/test/test_datasets.py b/test/test_datasets.py
index f4ef4721370..2410f18de09 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -108,14 +108,14 @@ def test_fashionmnist(self, mock_download_extract):
             img, target = dataset[0]
             self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target)
 
-    @mock.patch('torchvision.datasets.utils.download_url')
+    @mock.patch('torchvision.datasets.imagenet._verify_archive')
     @unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
-    def test_imagenet(self, mock_download):
+    def test_imagenet(self, mock_verify):
         with imagenet_root() as root:
-            dataset = torchvision.datasets.ImageNet(root, split='train', download=True)
+            dataset = torchvision.datasets.ImageNet(root, split='train')
             self.generic_classification_dataset_test(dataset)
 
-            dataset = torchvision.datasets.ImageNet(root, split='val', download=True)
+            dataset = torchvision.datasets.ImageNet(root, split='val')
             self.generic_classification_dataset_test(dataset)
 
     @mock.patch('torchvision.datasets.cifar.check_integrity')
diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py
index 14a256c66ab..a45ff3cd44b 100644
--- a/torchvision/datasets/imagenet.py
+++ b/torchvision/datasets/imagenet.py
@@ -1,27 +1,20 @@
-from __future__ import print_function
+import warnings
+from contextlib import contextmanager
 import os
 import shutil
 import tempfile
 import torch
 
 from .folder import ImageFolder
-from .utils import check_integrity, download_and_extract_archive, extract_archive, \
-    verify_str_arg
-
-ARCHIVE_DICT = {
-    'train': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar',
-        'md5': '1d675b47d978889d74fa0da5fadfb00e',
-    },
-    'val': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar',
-        'md5': '29b22e2961454d5413ddabcf34fc5622',
-    },
-    'devkit': {
-        'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz',
-        'md5': 'fa75699e90414af021442c21a62c3abf',
-    }
+from .utils import check_integrity, extract_archive, verify_str_arg
+
+ARCHIVE_META = {
+    'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'),
+    'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'),
+    'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf')
 }
+META_FILE = "meta.bin"
+
 
 class ImageNet(ImageFolder):
     """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
@@ -29,9 +22,6 @@ class ImageNet(ImageFolder):
     Args:
         root (string): Root directory of the ImageNet Dataset.
         split (string, optional): The dataset split, supports ``train``, or ``val``.
-        download (bool, optional): If true, downloads the dataset from the internet and
-            puts it in root directory. If dataset is already downloaded, it is not
-            downloaded again.
         transform (callable, optional): A function/transform that takes in an PIL image
             and returns a transformed version. E.g, ``transforms.RandomCrop``
         target_transform (callable, optional): A function/transform that takes in the
@@ -47,13 +37,22 @@
         targets (list): The class_index value for each image in the dataset
     """
 
-    def __init__(self, root, split='train', download=False, **kwargs):
+    def __init__(self, root, split='train', download=None, **kwargs):
+        if download is True:
+            msg = ("The dataset is no longer publicly accessible. You need to "
You need to " + "download the archives externally and place them in the root " + "directory.") + raise RuntimeError(msg) + elif download is False: + msg = ("The use of the download flag is deprecated, since the dataset " + "is no longer publicly accessible.") + warnings.warn(msg, RuntimeWarning) + root = self.root = os.path.expanduser(root) self.split = verify_str_arg(split, "split", ("train", "val")) - if download: - self.download() - wnid_to_classes = self._load_meta_file()[0] + self.parse_archives() + wnid_to_classes = load_meta_file(self.root)[0] super(ImageNet, self).__init__(self.split_folder, **kwargs) self.root = root @@ -65,50 +64,15 @@ def __init__(self, root, split='train', download=False, **kwargs): for idx, clss in enumerate(self.classes) for cls in clss} - def download(self): - if not check_integrity(self.meta_file): - tmp_dir = tempfile.mkdtemp() - - archive_dict = ARCHIVE_DICT['devkit'] - download_and_extract_archive(archive_dict['url'], self.root, - extract_root=tmp_dir, - md5=archive_dict['md5']) - devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0] - meta = parse_devkit(os.path.join(tmp_dir, devkit_folder)) - self._save_meta_file(*meta) - - shutil.rmtree(tmp_dir) + def parse_archives(self): + if not check_integrity(os.path.join(self.root, META_FILE)): + parse_devkit_archive(self.root) if not os.path.isdir(self.split_folder): - archive_dict = ARCHIVE_DICT[self.split] - download_and_extract_archive(archive_dict['url'], self.root, - extract_root=self.split_folder, - md5=archive_dict['md5']) - if self.split == 'train': - prepare_train_folder(self.split_folder) + parse_train_archive(self.root) elif self.split == 'val': - val_wnids = self._load_meta_file()[1] - prepare_val_folder(self.split_folder, val_wnids) - else: - msg = ("You set download=True, but a folder '{}' already exist in " - "the root directory. If you want to re-download or re-extract the " - "archive, delete the folder.") - print(msg.format(self.split)) - - @property - def meta_file(self): - return os.path.join(self.root, 'meta.bin') - - def _load_meta_file(self): - if check_integrity(self.meta_file): - return torch.load(self.meta_file) - else: - raise RuntimeError("Meta file not found or corrupted.", - "You can use download=True to create it.") - - def _save_meta_file(self, wnid_to_class, val_wnids): - torch.save((wnid_to_class, val_wnids), self.meta_file) + parse_val_archive(self.root) @property def split_folder(self): @@ -118,54 +82,137 @@ def extra_repr(self): return "Split: {split}".format(**self.__dict__) -def parse_devkit(root): - idx_to_wnid, wnid_to_classes = parse_meta(root) - val_idcs = parse_val_groundtruth(root) - val_wnids = [idx_to_wnid[idx] for idx in val_idcs] - return wnid_to_classes, val_wnids +def load_meta_file(root, file=None): + if file is None: + file = META_FILE + file = os.path.join(root, file) + + if check_integrity(file): + return torch.load(file) + else: + msg = ("The meta file {} is not present in the root directory or is corrupted. " + "This file is automatically created by the ImageNet dataset.") + raise RuntimeError(msg.format(file, root)) + +def _verify_archive(root, file, md5): + if not check_integrity(os.path.join(root, file), md5): + msg = ("The archive {} is not present in the root directory or is corrupted. 
" + "You need to download it externally and place it in {}.") + raise RuntimeError(msg.format(file, root)) -def parse_meta(devkit_root, path='data', filename='meta.mat'): + +def parse_devkit_archive(root, file=None): + """Parse the devkit archive of the ImageNet2012 classification dataset and save + the meta information in a binary file. + + Args: + root (str): Root directory containing the devkit archive + file (str, optional): Name of devkit archive. Defaults to + 'ILSVRC2012_devkit_t12.tar.gz' + """ import scipy.io as sio - metafile = os.path.join(devkit_root, path, filename) - meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] - nums_children = list(zip(*meta))[4] - meta = [meta[idx] for idx, num_children in enumerate(nums_children) - if num_children == 0] - idcs, wnids, classes = list(zip(*meta))[:3] - classes = [tuple(clss.split(', ')) for clss in classes] - idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} - wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} - return idx_to_wnid, wnid_to_classes + def parse_meta_mat(devkit_root): + metafile = os.path.join(devkit_root, "data", "meta.mat") + meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] + nums_children = list(zip(*meta))[4] + meta = [meta[idx] for idx, num_children in enumerate(nums_children) + if num_children == 0] + idcs, wnids, classes = list(zip(*meta))[:3] + classes = [tuple(clss.split(', ')) for clss in classes] + idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} + wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} + return idx_to_wnid, wnid_to_classes + + def parse_val_groundtruth_txt(devkit_root): + file = os.path.join(devkit_root, "data", + "ILSVRC2012_validation_ground_truth.txt") + with open(file, 'r') as txtfh: + val_idcs = txtfh.readlines() + return [int(val_idx) for val_idx in val_idcs] + + @contextmanager + def get_tmp_dir(): + tmp_dir = tempfile.mkdtemp() + try: + yield tmp_dir + finally: + shutil.rmtree(tmp_dir) + + archive_meta = ARCHIVE_META["devkit"] + if file is None: + file = archive_meta[0] + md5 = archive_meta[1] + + _verify_archive(root, file, md5) + + with get_tmp_dir() as tmp_dir: + extract_archive(os.path.join(root, file), tmp_dir) + + devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12") + idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root) + val_idcs = parse_val_groundtruth_txt(devkit_root) + val_wnids = [idx_to_wnid[idx] for idx in val_idcs] + + torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE)) + + +def parse_train_archive(root, file=None, folder="train"): + """Parse the train images archive of the ImageNet2012 classification dataset and + prepare it for usage with the ImageNet dataset. + Args: + root (str): Root directory containing the train images archive + file (str, optional): Name of train images archive. Defaults to + 'ILSVRC2012_img_train.tar' + folder (str, optional): Optional name for train images folder. 
+            'train'
+    """
+    archive_meta = ARCHIVE_META["train"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
 
-def parse_val_groundtruth(devkit_root, path='data',
-                          filename='ILSVRC2012_validation_ground_truth.txt'):
-    with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
-        val_idcs = txtfh.readlines()
-    return [int(val_idx) for val_idx in val_idcs]
+    _verify_archive(root, file, md5)
 
+    train_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), train_root)
 
-def prepare_train_folder(folder):
-    for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
+    archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)]
+    for archive in archives:
         extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
 
 
-def prepare_val_folder(folder, wnids):
-    img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)])
+def parse_val_archive(root, file=None, wnids=None, folder="val"):
+    """Parse the validation images archive of the ImageNet2012 classification dataset
+    and prepare it for usage with the ImageNet dataset.
 
-    for wnid in set(wnids):
-        os.mkdir(os.path.join(folder, wnid))
+    Args:
+        root (str): Root directory containing the validation images archive
+        file (str, optional): Name of validation images archive. Defaults to
+            'ILSVRC2012_img_val.tar'
+        wnids (list, optional): List of WordNet IDs of the validation images. If None
+            is given, the IDs are loaded from the meta file in the root directory
+        folder (str, optional): Optional name for validation images folder. Defaults to
+            'val'
+    """
+    archive_meta = ARCHIVE_META["val"]
+    if file is None:
+        file = archive_meta[0]
+    md5 = archive_meta[1]
+    if wnids is None:
+        wnids = load_meta_file(root)[1]
+
+    _verify_archive(root, file, md5)
 
-    for wnid, img_file in zip(wnids, img_files):
-        shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))
+    val_root = os.path.join(root, folder)
+    extract_archive(os.path.join(root, file), val_root)
+
+    images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)])
+
+    for wnid in set(wnids):
+        os.mkdir(os.path.join(val_root, wnid))
 
-def _splitexts(root):
-    exts = []
-    ext = '.'
-    while ext:
-        root, ext = os.path.splitext(root)
-        exts.append(ext)
-    return root, ''.join(reversed(exts))
+    for wnid, img_file in zip(wnids, images):
+        shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))
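
Usage sketch (not part of the patch): the workflow this change assumes, shown with a
hypothetical root path. The three archives listed in ARCHIVE_META must be obtained
manually and placed in the root directory; the first ImageNet instantiation then
verifies their checksums, writes meta.bin from the devkit archive, and extracts the
requested split, while later instantiations reuse the extracted folders.

    import torchvision

    root = "/data/imagenet"  # hypothetical; must already contain the three archives

    # download=True now raises a RuntimeError, download=False only warns:
    # torchvision.datasets.ImageNet(root, split='train', download=True)

    # First use triggers parse_archives(): meta.bin is created via
    # parse_devkit_archive() and the split folder via parse_val_archive().
    dataset = torchvision.datasets.ImageNet(root, split='val')
    img, target = dataset[0]

    # The parsers can also be run ahead of time, independently of the dataset class:
    from torchvision.datasets.imagenet import parse_devkit_archive, parse_train_archive

    parse_devkit_archive(root)  # writes root/meta.bin
    parse_train_archive(root)   # extracts root/train/ and its per-class archives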