From 40f28d48cb09169b06c3a4ea611189c99d27c543 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 14 Oct 2019 13:23:31 +0200 Subject: [PATCH 01/15] remove download process --- torchvision/datasets/imagenet.py | 184 ++++++++++++++++++++----------- 1 file changed, 117 insertions(+), 67 deletions(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 14a256c66ab..17303c9c400 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -1,23 +1,23 @@ from __future__ import print_function +from contextlib import contextmanager import os import shutil import tempfile import torch from .folder import ImageFolder -from .utils import check_integrity, download_and_extract_archive, extract_archive, \ - verify_str_arg +from .utils import check_integrity, extract_archive, verify_str_arg ARCHIVE_DICT = { 'train': { - 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar', + 'file': 'ILSVRC2012_img_train.tar', 'md5': '1d675b47d978889d74fa0da5fadfb00e', }, 'val': { - 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar', + 'file': 'ILSVRC2012_img_val.tar', 'md5': '29b22e2961454d5413ddabcf34fc5622', }, 'devkit': { - 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz', + 'file': 'ILSVRC2012_devkit_t12.tar.gz', 'md5': 'fa75699e90414af021442c21a62c3abf', } } @@ -53,7 +53,7 @@ def __init__(self, root, split='train', download=False, **kwargs): if download: self.download() - wnid_to_classes = self._load_meta_file()[0] + wnid_to_classes = load_meta_file(self.root)[0] super(ImageNet, self).__init__(self.split_folder, **kwargs) self.root = root @@ -66,49 +66,38 @@ def __init__(self, root, split='train', download=False, **kwargs): for cls in clss} def download(self): - if not check_integrity(self.meta_file): - tmp_dir = tempfile.mkdtemp() + def check_archive(archive_dict): + archive = os.path.join(self.root, archive_dict["file"]) + md5 = archive_dict["md5"] + if not check_integrity(archive, md5): + self._raise_download_error() - archive_dict = ARCHIVE_DICT['devkit'] - download_and_extract_archive(archive_dict['url'], self.root, - extract_root=tmp_dir, - md5=archive_dict['md5']) - devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0] - meta = parse_devkit(os.path.join(tmp_dir, devkit_folder)) - self._save_meta_file(*meta) + return archive - shutil.rmtree(tmp_dir) + if not check_integrity(self.meta_file): + archive = check_archive(ARCHIVE_DICT['devkit']) + parse_devkit(archive) if not os.path.isdir(self.split_folder): - archive_dict = ARCHIVE_DICT[self.split] - download_and_extract_archive(archive_dict['url'], self.root, - extract_root=self.split_folder, - md5=archive_dict['md5']) + archive = check_archive(ARCHIVE_DICT[self.split]) if self.split == 'train': - prepare_train_folder(self.split_folder) + parse_train_archive(archive) elif self.split == 'val': - val_wnids = self._load_meta_file()[1] - prepare_val_folder(self.split_folder, val_wnids) + parse_val_archive(archive) else: msg = ("You set download=True, but a folder '{}' already exist in " "the root directory. If you want to re-download or re-extract the " "archive, delete the folder.") print(msg.format(self.split)) + def _raise_download_error(self): + # FIXME + raise RuntimeError + @property def meta_file(self): - return os.path.join(self.root, 'meta.bin') - - def _load_meta_file(self): - if check_integrity(self.meta_file): - return torch.load(self.meta_file) - else: - raise RuntimeError("Meta file not found or corrupted.", - "You can use download=True to create it.") - - def _save_meta_file(self, wnid_to_class, val_wnids): - torch.save((wnid_to_class, val_wnids), self.meta_file) + return os.path.join(self.root, "meta.bin") @property def split_folder(self): @@ -118,41 +107,111 @@ def extra_repr(self): return "Split: {split}".format(**self.__dict__) -def parse_devkit(root): - idx_to_wnid, wnid_to_classes = parse_meta(root) - val_idcs = parse_val_groundtruth(root) - val_wnids = [idx_to_wnid[idx] for idx in val_idcs] - return wnid_to_classes, val_wnids +@contextmanager +def tmpdir(): + tmpdir = tempfile.mkdtemp() + try: + yield tmpdir + except: + shutil.rmtree(tmpdir) -def parse_meta(devkit_root, path='data', filename='meta.mat'): - import scipy.io as sio +def _splitexts(root): + exts = [] + ext = '.' + while ext: + root, ext = os.path.splitext(root) + exts.append(ext) + return root, ''.join(reversed(exts)) - metafile = os.path.join(devkit_root, path, filename) - meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] - nums_children = list(zip(*meta))[4] - meta = [meta[idx] for idx, num_children in enumerate(nums_children) - if num_children == 0] - idcs, wnids, classes = list(zip(*meta))[:3] - classes = [tuple(clss.split(', ')) for clss in classes] - idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} - wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} - return idx_to_wnid, wnid_to_classes +def parse_devkit(archive, meta_file=None): + """ -def parse_val_groundtruth(devkit_root, path='data', - filename='ILSVRC2012_validation_ground_truth.txt'): - with open(os.path.join(devkit_root, path, filename), 'r') as txtfh: - val_idcs = txtfh.readlines() - return [int(val_idx) for val_idx in val_idcs] + Args: + archive: + meta_file: + """ + # FIXME + def parse_meta(devkit_root, path='data', filename='meta.mat'): + import scipy.io as sio + + metafile = os.path.join(devkit_root, path, filename) + meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] + nums_children = list(zip(*meta))[4] + meta = [meta[idx] for idx, num_children in enumerate(nums_children) + if num_children == 0] + idcs, wnids, classes = list(zip(*meta))[:3] + classes = [tuple(clss.split(', ')) for clss in classes] + idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} + wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} + return idx_to_wnid, wnid_to_classes + + def parse_val_groundtruth(devkit_root, path='data', + filename='ILSVRC2012_validation_ground_truth.txt'): + with open(os.path.join(devkit_root, path, filename), 'r') as txtfh: + val_idcs = txtfh.readlines() + return [int(val_idx) for val_idx in val_idcs] + + if meta_file is None: + meta_file = os.path.join(os.path.basename(archive), "meta.bin") + + with tmpdir() as devkit_root: + extract_archive(archive, devkit_root) + + idx_to_wnid, wnid_to_classes = parse_meta(devkit_root) + val_idcs = parse_val_groundtruth(devkit_root) + val_wnids = [idx_to_wnid[idx] for idx in val_idcs] + + torch.save((wnid_to_classes, val_wnids), meta_file) + + +def load_meta_file(root, filename="meta.bin"): + file = os.path.join(root, filename) + if check_integrity(file): + return torch.load(file) + else: + # FIXME + raise RuntimeError("Meta file not found.") + + +def parse_train_archive(archive, folder=None): + # FIXME + """ + + Args: + archive: + folder: + Returns: + + """ + if folder is None: + folder = os.path.join(os.path.basename(archive), "train") + + extract_archive(archive, folder) -def prepare_train_folder(folder): for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]: extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True) -def prepare_val_folder(folder, wnids): +def parse_val_archive(archive, wnids=None, folder=None): + # FIXME + """ + + Args: + archive: + wnids: + folder: + """ + root = os.path.basename(archive) + if wnids is None: + wnids = load_meta_file(root)[1] + if folder is None: + folder = os.path.join(root, "val") + + extract_archive(archive, folder) + img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)]) for wnid in set(wnids): @@ -160,12 +219,3 @@ def prepare_val_folder(folder, wnids): for wnid, img_file in zip(wnids, img_files): shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file))) - - -def _splitexts(root): - exts = [] - ext = '.' - while ext: - root, ext = os.path.splitext(root) - exts.append(ext) - return root, ''.join(reversed(exts)) From 85fffbeaa16717d199a15bfb18eade8a867cfdde Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 14 Oct 2019 14:26:38 +0200 Subject: [PATCH 02/15] address comments --- torchvision/datasets/imagenet.py | 100 +++++++++++++++++-------------- 1 file changed, 55 insertions(+), 45 deletions(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 17303c9c400..99bf559efeb 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -1,4 +1,4 @@ -from __future__ import print_function +import warnings from contextlib import contextmanager import os import shutil @@ -22,6 +22,8 @@ } } +META_FILE_NAME = "meta.bin" + class ImageNet(ImageFolder): """`ImageNet `_ 2012 Classification Dataset. @@ -29,9 +31,6 @@ class ImageNet(ImageFolder): Args: root (string): Root directory of the ImageNet Dataset. split (string, optional): The dataset split, supports ``train``, or ``val``. - download (bool, optional): If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the @@ -47,7 +46,16 @@ class ImageNet(ImageFolder): targets (list): The class_index value for each image in the dataset """ - def __init__(self, root, split='train', download=False, **kwargs): + def __init__(self, root, split='train', download=None, **kwargs): + if download is None: + msg = ("The use of the download flag is deprecated, since the public " + "download links were removed by the dataset authors. To use this " + "dataset, you need to download the archives externally. Afterwards " + "you can use the parse_{devkit|train|val}_archive() functions to " + "prepare them for usage.") + warnings.warn(msg, DeprecationWarning) + download = False + root = self.root = os.path.expanduser(root) self.split = verify_str_arg(split, "split", ("train", "val")) @@ -70,13 +78,13 @@ def check_archive(archive_dict): archive = os.path.join(self.root, archive_dict["file"]) md5 = archive_dict["md5"] if not check_integrity(archive, md5): - self._raise_download_error() + self._raise_download_error(archive) return archive if not check_integrity(self.meta_file): archive = check_archive(ARCHIVE_DICT['devkit']) - parse_devkit(archive) + parse_devkit_archive(archive) if not os.path.isdir(self.split_folder): archive = check_archive(ARCHIVE_DICT[self.split]) @@ -86,18 +94,18 @@ def check_archive(archive_dict): elif self.split == 'val': parse_val_archive(archive) else: - msg = ("You set download=True, but a folder '{}' already exist in " - "the root directory. If you want to re-download or re-extract the " - "archive, delete the folder.") - print(msg.format(self.split)) + msg = ("A folder '{}' already exist in the root directory. If you want to " + "re-extract the archive, delete the folder.") + warnings.warn(msg.format(self.split), RuntimeWarning) - def _raise_download_error(self): - # FIXME - raise RuntimeError + def _raise_download_error(self, file): + msg = ("The file {} is not present in the root directory and cannot be " + "downloaded anymore.") + raise RuntimeError(msg.format(file)) @property def meta_file(self): - return os.path.join(self.root, "meta.bin") + return os.path.join(self.root, META_FILE_NAME) @property def split_folder(self): @@ -107,25 +115,7 @@ def extra_repr(self): return "Split: {split}".format(**self.__dict__) -@contextmanager -def tmpdir(): - tmpdir = tempfile.mkdtemp() - try: - yield tmpdir - except: - shutil.rmtree(tmpdir) - - -def _splitexts(root): - exts = [] - ext = '.' - while ext: - root, ext = os.path.splitext(root) - exts.append(ext) - return root, ''.join(reversed(exts)) - - -def parse_devkit(archive, meta_file=None): +def parse_devkit_archive(archive, meta_file=None): """ Args: @@ -133,10 +123,10 @@ def parse_devkit(archive, meta_file=None): meta_file: """ # FIXME - def parse_meta(devkit_root, path='data', filename='meta.mat'): - import scipy.io as sio + import scipy.io as sio - metafile = os.path.join(devkit_root, path, filename) + def parse_meta(devkit_root): + metafile = os.path.join(devkit_root, "data", "meta.mat") meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] nums_children = list(zip(*meta))[4] meta = [meta[idx] for idx, num_children in enumerate(nums_children) @@ -147,16 +137,17 @@ def parse_meta(devkit_root, path='data', filename='meta.mat'): wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} return idx_to_wnid, wnid_to_classes - def parse_val_groundtruth(devkit_root, path='data', - filename='ILSVRC2012_validation_ground_truth.txt'): - with open(os.path.join(devkit_root, path, filename), 'r') as txtfh: + def parse_val_groundtruth(devkit_root): + file = os.path.join(devkit_root, "data", + "ILSVRC2012_validation_ground_truth.txt") + with open(file, 'r') as txtfh: val_idcs = txtfh.readlines() return [int(val_idx) for val_idx in val_idcs] if meta_file is None: - meta_file = os.path.join(os.path.basename(archive), "meta.bin") + meta_file = os.path.join(os.path.basename(archive), META_FILE_NAME) - with tmpdir() as devkit_root: + with _tmpdir() as devkit_root: extract_archive(archive, devkit_root) idx_to_wnid, wnid_to_classes = parse_meta(devkit_root) @@ -166,13 +157,14 @@ def parse_val_groundtruth(devkit_root, path='data', torch.save((wnid_to_classes, val_wnids), meta_file) -def load_meta_file(root, filename="meta.bin"): +def load_meta_file(root, filename=META_FILE_NAME): file = os.path.join(root, filename) if check_integrity(file): return torch.load(file) else: - # FIXME - raise RuntimeError("Meta file not found.") + msg = ("Meta file not found at {}. You can create it with the " + "parse_devkit_archive() function") + raise RuntimeError(msg.format(file)) def parse_train_archive(archive, folder=None): @@ -219,3 +211,21 @@ def parse_val_archive(archive, wnids=None, folder=None): for wnid, img_file in zip(wnids, img_files): shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file))) + + +def _splitexts(root): + exts = [] + ext = '.' + while ext: + root, ext = os.path.splitext(root) + exts.append(ext) + return root, ''.join(reversed(exts)) + + +@contextmanager +def _tmpdir(): + tmpdir = tempfile.mkdtemp() + try: + yield tmpdir + except: + shutil.rmtree(tmpdir) From daa9f99dd85ae08d9f2a75f38b7a26433221aac0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 14 Oct 2019 14:32:19 +0200 Subject: [PATCH 03/15] fix logic error --- torchvision/datasets/imagenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 99bf559efeb..14c8b028cd6 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -47,7 +47,7 @@ class ImageNet(ImageFolder): """ def __init__(self, root, split='train', download=None, **kwargs): - if download is None: + if download is not None: msg = ("The use of the download flag is deprecated, since the public " "download links were removed by the dataset authors. To use this " "dataset, you need to download the archives externally. Afterwards " From 0e62ef472769d73e61a96493c279bf61ee80dfd5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 14 Oct 2019 14:45:20 +0200 Subject: [PATCH 04/15] bug fixes --- torchvision/datasets/imagenet.py | 35 ++++++++++++++------------------ 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 14c8b028cd6..67b627416b2 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -47,14 +47,15 @@ class ImageNet(ImageFolder): """ def __init__(self, root, split='train', download=None, **kwargs): - if download is not None: + if download is None: + download = False + else: msg = ("The use of the download flag is deprecated, since the public " "download links were removed by the dataset authors. To use this " "dataset, you need to download the archives externally. Afterwards " "you can use the parse_{devkit|train|val}_archive() functions to " "prepare them for usage.") - warnings.warn(msg, DeprecationWarning) - download = False + warnings.warn(msg) root = self.root = os.path.expanduser(root) self.split = verify_str_arg(split, "split", ("train", "val")) @@ -145,17 +146,20 @@ def parse_val_groundtruth(devkit_root): return [int(val_idx) for val_idx in val_idcs] if meta_file is None: - meta_file = os.path.join(os.path.basename(archive), META_FILE_NAME) + meta_file = os.path.join(os.path.dirname(archive), META_FILE_NAME) - with _tmpdir() as devkit_root: - extract_archive(archive, devkit_root) + tmpdir = tempfile.mkdtemp() + extract_archive(archive, tmpdir) - idx_to_wnid, wnid_to_classes = parse_meta(devkit_root) - val_idcs = parse_val_groundtruth(devkit_root) - val_wnids = [idx_to_wnid[idx] for idx in val_idcs] + devkit_root = os.path.join(tmpdir, "ILSVRC2012_devkit_t12") + idx_to_wnid, wnid_to_classes = parse_meta(devkit_root) + val_idcs = parse_val_groundtruth(devkit_root) + val_wnids = [idx_to_wnid[idx] for idx in val_idcs] torch.save((wnid_to_classes, val_wnids), meta_file) + shutil.rmtree(tmpdir) + def load_meta_file(root, filename=META_FILE_NAME): file = os.path.join(root, filename) @@ -179,7 +183,7 @@ def parse_train_archive(archive, folder=None): """ if folder is None: - folder = os.path.join(os.path.basename(archive), "train") + folder = os.path.join(os.path.dirname(archive), "train") extract_archive(archive, folder) @@ -196,7 +200,7 @@ def parse_val_archive(archive, wnids=None, folder=None): wnids: folder: """ - root = os.path.basename(archive) + root = os.path.dirname(archive) if wnids is None: wnids = load_meta_file(root)[1] if folder is None: @@ -220,12 +224,3 @@ def _splitexts(root): root, ext = os.path.splitext(root) exts.append(ext) return root, ''.join(reversed(exts)) - - -@contextmanager -def _tmpdir(): - tmpdir = tempfile.mkdtemp() - try: - yield tmpdir - except: - shutil.rmtree(tmpdir) From 18d8b9fce8e59bc12fda6ab55721d79f5acdefe8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 14 Oct 2019 14:46:31 +0200 Subject: [PATCH 05/15] removed unused import --- torchvision/datasets/imagenet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 67b627416b2..d3f904eb30b 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -1,5 +1,4 @@ import warnings -from contextlib import contextmanager import os import shutil import tempfile From f0cc0251198616d23ff0f375fbc7278782e034b0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Oct 2019 08:54:49 +0200 Subject: [PATCH 06/15] add docstrings --- torchvision/datasets/imagenet.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index d3f904eb30b..51a5ee80c97 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -44,7 +44,6 @@ class ImageNet(ImageFolder): imgs (list): List of (image path, class_index) tuples targets (list): The class_index value for each image in the dataset """ - def __init__(self, root, split='train', download=None, **kwargs): if download is None: download = False @@ -116,13 +115,13 @@ def extra_repr(self): def parse_devkit_archive(archive, meta_file=None): - """ + """Parse the devkit archive of the ImageNet2012 classification dataset and save + the meta information in a binary file. Args: - archive: - meta_file: + archive (str): Path to the devkit archive + meta_file (str, optional): Optional name for the meta information file """ - # FIXME import scipy.io as sio def parse_meta(devkit_root): @@ -171,15 +170,12 @@ def load_meta_file(root, filename=META_FILE_NAME): def parse_train_archive(archive, folder=None): - # FIXME - """ + """Parse the train images archive of the ImageNet2012 classification dataset and + prepare it for usage with the ImageNet dataset. Args: - archive: - folder: - - Returns: - + archive (str): Path to the train images archive + folder (str, optional): Optional name for train images folder """ if folder is None: folder = os.path.join(os.path.dirname(archive), "train") @@ -191,13 +187,15 @@ def parse_train_archive(archive, folder=None): def parse_val_archive(archive, wnids=None, folder=None): - # FIXME - """ + """Parse the validation images archive of the ImageNet2012 classification dataset + and prepare it for usage with the ImageNet dataset. Args: - archive: - wnids: - folder: + archive (str): Path to the validation images archive + wnids (list, optional): List of WordNet IDs of the validation images. If None + is given, the IDs are tried to be loaded from the meta information binary + file in the same directory as the archive. + folder (str, optional): Optional name for validation images folder """ root = os.path.dirname(archive) if wnids is None: From bfacf307c399655c09eb368e574c638cd7bffdcc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Oct 2019 09:05:48 +0200 Subject: [PATCH 07/15] flake8 --- torchvision/datasets/imagenet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 51a5ee80c97..0d57f9c523d 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -182,7 +182,7 @@ def parse_train_archive(archive, folder=None): extract_archive(archive, folder) - for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]: + for archive in [os.path.join(folder, file) for file in os.listdir(folder)]: extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True) From c6b6be726c8d8fe116ca51fb9754e225f80119d1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Oct 2019 08:48:15 +0200 Subject: [PATCH 08/15] remove download BC --- torchvision/datasets/imagenet.py | 40 ++++++++++++++------------------ 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 0d57f9c523d..f0b349d87ad 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -44,22 +44,22 @@ class ImageNet(ImageFolder): imgs (list): List of (image path, class_index) tuples targets (list): The class_index value for each image in the dataset """ + def __init__(self, root, split='train', download=None, **kwargs): - if download is None: - download = False - else: - msg = ("The use of the download flag is deprecated, since the public " - "download links were removed by the dataset authors. To use this " - "dataset, you need to download the archives externally. Afterwards " - "you can use the parse_{devkit|train|val}_archive() functions to " - "prepare them for usage.") - warnings.warn(msg) + if download is True: + msg = ("The dataset is no longer publicly accessible. You need to " + "download the archives externally and place them in the root " + "directory.") + raise RuntimeError(msg) + elif download is False: + msg = ("The use of the download flag is deprecated, since the dataset " + "is no longer publicly accessible.") + warnings.warn(msg, RuntimeWarning) root = self.root = os.path.expanduser(root) self.split = verify_str_arg(split, "split", ("train", "val")) - if download: - self.download() + self.extract_archives() wnid_to_classes = load_meta_file(self.root)[0] super(ImageNet, self).__init__(self.split_folder, **kwargs) @@ -72,12 +72,15 @@ def __init__(self, root, split='train', download=None, **kwargs): for idx, clss in enumerate(self.classes) for cls in clss} - def download(self): + def extract_archives(self): def check_archive(archive_dict): - archive = os.path.join(self.root, archive_dict["file"]) + file = archive_dict["file"] md5 = archive_dict["md5"] + archive = os.path.join(self.root, file) if not check_integrity(archive, md5): - self._raise_download_error(archive) + msg = ("The file {} is not present in the root directory. You need to " + "download it externally and place it in {}.") + raise RuntimeError(msg.format(file, self.root)) return archive @@ -92,15 +95,6 @@ def check_archive(archive_dict): parse_train_archive(archive) elif self.split == 'val': parse_val_archive(archive) - else: - msg = ("A folder '{}' already exist in the root directory. If you want to " - "re-extract the archive, delete the folder.") - warnings.warn(msg.format(self.split), RuntimeWarning) - - def _raise_download_error(self, file): - msg = ("The file {} is not present in the root directory and cannot be " - "downloaded anymore.") - raise RuntimeError(msg.format(file)) @property def meta_file(self): From caaafbf52c8f5d57a302844cc6357e086db0c4f4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Oct 2019 09:06:44 +0200 Subject: [PATCH 09/15] fix test --- test/test_datasets.py | 8 ++++---- torchvision/datasets/imagenet.py | 26 +++++++++++++------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index f4ef4721370..3076945765d 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -108,14 +108,14 @@ def test_fashionmnist(self, mock_download_extract): img, target = dataset[0] self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) - @mock.patch('torchvision.datasets.utils.download_url') + @mock.patch('torchvision.datasets.imagenet.ImageNet._verify_archive') @unittest.skipIf(not HAS_SCIPY, "scipy unavailable") - def test_imagenet(self, mock_download): + def test_imagenet(self, mock_check): with imagenet_root() as root: - dataset = torchvision.datasets.ImageNet(root, split='train', download=True) + dataset = torchvision.datasets.ImageNet(root, split='train') self.generic_classification_dataset_test(dataset) - dataset = torchvision.datasets.ImageNet(root, split='val', download=True) + dataset = torchvision.datasets.ImageNet(root, split='val') self.generic_classification_dataset_test(dataset) @mock.patch('torchvision.datasets.cifar.check_integrity') diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index f0b349d87ad..4b529d5982f 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -73,29 +73,29 @@ def __init__(self, root, split='train', download=None, **kwargs): for cls in clss} def extract_archives(self): - def check_archive(archive_dict): - file = archive_dict["file"] - md5 = archive_dict["md5"] - archive = os.path.join(self.root, file) - if not check_integrity(archive, md5): - msg = ("The file {} is not present in the root directory. You need to " - "download it externally and place it in {}.") - raise RuntimeError(msg.format(file, self.root)) - - return archive - if not check_integrity(self.meta_file): - archive = check_archive(ARCHIVE_DICT['devkit']) + archive_dict = ARCHIVE_DICT['devkit'] + archive = os.path.join(self.root, archive_dict["file"]) + self._verify_archive(archive, archive_dict["md5"]) + parse_devkit_archive(archive) if not os.path.isdir(self.split_folder): - archive = check_archive(ARCHIVE_DICT[self.split]) + archive_dict = ARCHIVE_DICT[self.split] + archive = os.path.join(self.root, archive_dict["file"]) + self._verify_archive(archive, archive_dict["md5"]) if self.split == 'train': parse_train_archive(archive) elif self.split == 'val': parse_val_archive(archive) + def _verify_archive(self, archive, md5): + if not check_integrity(archive, md5): + msg = ("The file {} is not present in the root directory or corrupted. " + "You need to download it externally and place it in {}.") + raise RuntimeError(msg.format(os.path.basename(archive), self.root)) + @property def meta_file(self): return os.path.join(self.root, META_FILE_NAME) From 2548910911469bde92c308863eb45cd31701b415 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Oct 2019 09:11:48 +0200 Subject: [PATCH 10/15] removed unused code --- torchvision/datasets/imagenet.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 4b529d5982f..7cd6ea16b6a 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -207,11 +207,3 @@ def parse_val_archive(archive, wnids=None, folder=None): for wnid, img_file in zip(wnids, img_files): shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file))) - -def _splitexts(root): - exts = [] - ext = '.' - while ext: - root, ext = os.path.splitext(root) - exts.append(ext) - return root, ''.join(reversed(exts)) From ba4bbc6a50bf62722fc3ed105224e80006757be7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Oct 2019 15:41:30 +0200 Subject: [PATCH 11/15] flake 8 --- torchvision/datasets/imagenet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 7cd6ea16b6a..5bb2821b1b1 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -206,4 +206,3 @@ def parse_val_archive(archive, wnids=None, folder=None): for wnid, img_file in zip(wnids, img_files): shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file))) - From 6fc496f965167e819260072446dd2ba1a8e77816 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 17 Oct 2019 16:56:04 +0200 Subject: [PATCH 12/15] add MD5 verification before extraction --- torchvision/datasets/imagenet.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 5bb2821b1b1..47677ed0cac 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -4,7 +4,7 @@ import tempfile import torch from .folder import ImageFolder -from .utils import check_integrity, extract_archive, verify_str_arg +from .utils import check_integrity, extract_archive, verify_str_arg, check_md5 ARCHIVE_DICT = { 'train': { @@ -108,13 +108,24 @@ def extra_repr(self): return "Split: {split}".format(**self.__dict__) -def parse_devkit_archive(archive, meta_file=None): +def _verify_archive(archive, md5, force): + if not check_integrity(archive): + raise RuntimeError("The file {} doesn't exist.".format(archive)) + if not check_md5(archive, md5) and not force: + msg = ("The MD5 checksum of the file {} and the original archive do not match. " + "Use force=True to force an extraction") + raise RuntimeError(msg.format(archive)) + + +def parse_devkit_archive(archive, meta_file=None, force=False): """Parse the devkit archive of the ImageNet2012 classification dataset and save the meta information in a binary file. Args: archive (str): Path to the devkit archive meta_file (str, optional): Optional name for the meta information file + force (bool, optional). Force extraction if MD5 checksum does not match. + Defaults to False """ import scipy.io as sio @@ -140,6 +151,8 @@ def parse_val_groundtruth(devkit_root): if meta_file is None: meta_file = os.path.join(os.path.dirname(archive), META_FILE_NAME) + _verify_archive(archive, ARCHIVE_DICT["devkit"]["md5"], force) + tmpdir = tempfile.mkdtemp() extract_archive(archive, tmpdir) @@ -163,24 +176,27 @@ def load_meta_file(root, filename=META_FILE_NAME): raise RuntimeError(msg.format(file)) -def parse_train_archive(archive, folder=None): +def parse_train_archive(archive, folder=None, force=False): """Parse the train images archive of the ImageNet2012 classification dataset and prepare it for usage with the ImageNet dataset. Args: archive (str): Path to the train images archive folder (str, optional): Optional name for train images folder + force (bool, optional). Force extraction if MD5 checksum does not match. + Defaults to False """ if folder is None: folder = os.path.join(os.path.dirname(archive), "train") + _verify_archive(archive, ARCHIVE_DICT["train"]["md5"], force) extract_archive(archive, folder) for archive in [os.path.join(folder, file) for file in os.listdir(folder)]: extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True) -def parse_val_archive(archive, wnids=None, folder=None): +def parse_val_archive(archive, wnids=None, folder=None, force=False): """Parse the validation images archive of the ImageNet2012 classification dataset and prepare it for usage with the ImageNet dataset. @@ -190,6 +206,8 @@ def parse_val_archive(archive, wnids=None, folder=None): is given, the IDs are tried to be loaded from the meta information binary file in the same directory as the archive. folder (str, optional): Optional name for validation images folder + force (bool, optional). Force extraction if MD5 checksum does not match. + Defaults to False """ root = os.path.dirname(archive) if wnids is None: @@ -197,6 +215,7 @@ def parse_val_archive(archive, wnids=None, folder=None): if folder is None: folder = os.path.join(root, "val") + _verify_archive(archive, ARCHIVE_DICT["val"]["md5"], force) extract_archive(archive, folder) img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)]) From f717e99bb702253246fdb7ba49b7120540d5df1c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 18 Oct 2019 12:13:09 +0200 Subject: [PATCH 13/15] add mock to test --- test/test_datasets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 3076945765d..4febcbcb1f7 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -108,9 +108,10 @@ def test_fashionmnist(self, mock_download_extract): img, target = dataset[0] self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) + @mock.patch('torchvision.datasets.imagenet._verify_archive') @mock.patch('torchvision.datasets.imagenet.ImageNet._verify_archive') @unittest.skipIf(not HAS_SCIPY, "scipy unavailable") - def test_imagenet(self, mock_check): + def test_imagenet(self, mock_verify_external, mock_verify_internal): with imagenet_root() as root: dataset = torchvision.datasets.ImageNet(root, split='train') self.generic_classification_dataset_test(dataset) From eb1bd7a40012d5539987d6443518a0e54148683a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 21 Oct 2019 12:57:51 +0200 Subject: [PATCH 14/15] * unify _verify_archive() method and function * remove force flag for parse_*_archive functions * cleanup --- test/test_datasets.py | 3 +- torchvision/datasets/imagenet.py | 190 +++++++++++++++---------------- 2 files changed, 91 insertions(+), 102 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 4febcbcb1f7..2410f18de09 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -109,9 +109,8 @@ def test_fashionmnist(self, mock_download_extract): self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.imagenet._verify_archive') - @mock.patch('torchvision.datasets.imagenet.ImageNet._verify_archive') @unittest.skipIf(not HAS_SCIPY, "scipy unavailable") - def test_imagenet(self, mock_verify_external, mock_verify_internal): + def test_imagenet(self, mock_verify): with imagenet_root() as root: dataset = torchvision.datasets.ImageNet(root, split='train') self.generic_classification_dataset_test(dataset) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 47677ed0cac..db6adb14201 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -1,27 +1,19 @@ import warnings +from contextlib import contextmanager import os import shutil import tempfile import torch from .folder import ImageFolder -from .utils import check_integrity, extract_archive, verify_str_arg, check_md5 - -ARCHIVE_DICT = { - 'train': { - 'file': 'ILSVRC2012_img_train.tar', - 'md5': '1d675b47d978889d74fa0da5fadfb00e', - }, - 'val': { - 'file': 'ILSVRC2012_img_val.tar', - 'md5': '29b22e2961454d5413ddabcf34fc5622', - }, - 'devkit': { - 'file': 'ILSVRC2012_devkit_t12.tar.gz', - 'md5': 'fa75699e90414af021442c21a62c3abf', - } +from .utils import check_integrity, extract_archive, verify_str_arg + +ARCHIVE_META = { + 'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'), + 'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'), + 'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf') } -META_FILE_NAME = "meta.bin" +META_FILE = "meta.bin" class ImageNet(ImageFolder): @@ -59,7 +51,7 @@ def __init__(self, root, split='train', download=None, **kwargs): root = self.root = os.path.expanduser(root) self.split = verify_str_arg(split, "split", ("train", "val")) - self.extract_archives() + self.parse_archives() wnid_to_classes = load_meta_file(self.root)[0] super(ImageNet, self).__init__(self.split_folder, **kwargs) @@ -72,33 +64,15 @@ def __init__(self, root, split='train', download=None, **kwargs): for idx, clss in enumerate(self.classes) for cls in clss} - def extract_archives(self): - if not check_integrity(self.meta_file): - archive_dict = ARCHIVE_DICT['devkit'] - archive = os.path.join(self.root, archive_dict["file"]) - self._verify_archive(archive, archive_dict["md5"]) - - parse_devkit_archive(archive) + def parse_archives(self): + if not check_integrity(os.path.join(self.root, META_FILE)): + parse_devkit_archive(self.root) if not os.path.isdir(self.split_folder): - archive_dict = ARCHIVE_DICT[self.split] - archive = os.path.join(self.root, archive_dict["file"]) - self._verify_archive(archive, archive_dict["md5"]) - if self.split == 'train': - parse_train_archive(archive) + parse_train_archive(self.root) elif self.split == 'val': - parse_val_archive(archive) - - def _verify_archive(self, archive, md5): - if not check_integrity(archive, md5): - msg = ("The file {} is not present in the root directory or corrupted. " - "You need to download it externally and place it in {}.") - raise RuntimeError(msg.format(os.path.basename(archive), self.root)) - - @property - def meta_file(self): - return os.path.join(self.root, META_FILE_NAME) + parse_val_archive(self.root) @property def split_folder(self): @@ -108,28 +82,38 @@ def extra_repr(self): return "Split: {split}".format(**self.__dict__) -def _verify_archive(archive, md5, force): - if not check_integrity(archive): - raise RuntimeError("The file {} doesn't exist.".format(archive)) - if not check_md5(archive, md5) and not force: - msg = ("The MD5 checksum of the file {} and the original archive do not match. " - "Use force=True to force an extraction") - raise RuntimeError(msg.format(archive)) +def load_meta_file(root, file=None): + if file is None: + file = META_FILE + file = os.path.join(root, file) + + if check_integrity(file): + return torch.load(file) + else: + msg = ("The meta file {} is not present in the root directory or is corrupted. " + "This file is automatically created by the ImageNet dataset.") + raise RuntimeError(msg.format(file, root)) -def parse_devkit_archive(archive, meta_file=None, force=False): +def _verify_archive(root, file, md5): + if not check_integrity(os.path.join(root, file), md5): + msg = ("The archive {} is not present in the root directory or is corrupted. " + "You need to download it externally and place it in {}.") + raise RuntimeError(msg.format(file, root)) + + +def parse_devkit_archive(root, file=None): """Parse the devkit archive of the ImageNet2012 classification dataset and save the meta information in a binary file. Args: - archive (str): Path to the devkit archive - meta_file (str, optional): Optional name for the meta information file - force (bool, optional). Force extraction if MD5 checksum does not match. - Defaults to False + root (str): Root directory containing the devkit archive + file (str, optional): Name of devkit archive. Defaults to + 'ILSVRC2012_devkit_t12.tar.gz' """ import scipy.io as sio - def parse_meta(devkit_root): + def parse_meta_mat(devkit_root): metafile = os.path.join(devkit_root, "data", "meta.mat") meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] nums_children = list(zip(*meta))[4] @@ -141,87 +125,93 @@ def parse_meta(devkit_root): wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} return idx_to_wnid, wnid_to_classes - def parse_val_groundtruth(devkit_root): + def parse_val_groundtruth_txt(devkit_root): file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt") with open(file, 'r') as txtfh: val_idcs = txtfh.readlines() return [int(val_idx) for val_idx in val_idcs] - if meta_file is None: - meta_file = os.path.join(os.path.dirname(archive), META_FILE_NAME) + @contextmanager + def get_tmp_dir(): + tmp_dir = tempfile.mkdtemp() + try: + yield tmp_dir + finally: + shutil.rmtree(tmp_dir) - _verify_archive(archive, ARCHIVE_DICT["devkit"]["md5"], force) + archive_meta = ARCHIVE_META["devkit"] + if file is None: + file = archive_meta[0] + md5 = archive_meta[1] - tmpdir = tempfile.mkdtemp() - extract_archive(archive, tmpdir) + _verify_archive(root, file, md5) - devkit_root = os.path.join(tmpdir, "ILSVRC2012_devkit_t12") - idx_to_wnid, wnid_to_classes = parse_meta(devkit_root) - val_idcs = parse_val_groundtruth(devkit_root) - val_wnids = [idx_to_wnid[idx] for idx in val_idcs] + with get_tmp_dir() as tmp_dir: + extract_archive(os.path.join(root, file), tmp_dir) - torch.save((wnid_to_classes, val_wnids), meta_file) + devkit_root = os.path.join(tmp_dir, "ILSVRC2012_devkit_t12") + idx_to_wnid, wnid_to_classes = parse_meta_mat(devkit_root) + val_idcs = parse_val_groundtruth_txt(devkit_root) + val_wnids = [idx_to_wnid[idx] for idx in val_idcs] - shutil.rmtree(tmpdir) + torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE)) -def load_meta_file(root, filename=META_FILE_NAME): - file = os.path.join(root, filename) - if check_integrity(file): - return torch.load(file) - else: - msg = ("Meta file not found at {}. You can create it with the " - "parse_devkit_archive() function") - raise RuntimeError(msg.format(file)) - - -def parse_train_archive(archive, folder=None, force=False): +def parse_train_archive(root, file=None, folder="train"): """Parse the train images archive of the ImageNet2012 classification dataset and prepare it for usage with the ImageNet dataset. Args: - archive (str): Path to the train images archive - folder (str, optional): Optional name for train images folder - force (bool, optional). Force extraction if MD5 checksum does not match. - Defaults to False + root (str): Root directory containing the train images archive + file (str, optional): Name of train images archive. Defaults to + 'ILSVRC2012_img_train.tar' + folder (str, optional): Optional name for train images folder. Defaults to + 'train' """ - if folder is None: - folder = os.path.join(os.path.dirname(archive), "train") + archive_meta = ARCHIVE_META["train"] + if file is None: + file = archive_meta[0] + md5 = archive_meta[1] - _verify_archive(archive, ARCHIVE_DICT["train"]["md5"], force) - extract_archive(archive, folder) + _verify_archive(root, file, md5) - for archive in [os.path.join(folder, file) for file in os.listdir(folder)]: + train_root = os.path.join(root, folder) + extract_archive(os.path.join(root, file), train_root) + + for archive in [os.path.join(train_root, file) for file in os.listdir(train_root)]: extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True) -def parse_val_archive(archive, wnids=None, folder=None, force=False): +def parse_val_archive(root, file=None, wnids=None, folder="val"): """Parse the validation images archive of the ImageNet2012 classification dataset and prepare it for usage with the ImageNet dataset. Args: - archive (str): Path to the validation images archive + root (str): Root directory containing the validation images archive + file (str, optional): Name of validation images archive. Defaults to + 'ILSVRC2012_img_val.tar' wnids (list, optional): List of WordNet IDs of the validation images. If None - is given, the IDs are tried to be loaded from the meta information binary - file in the same directory as the archive. - folder (str, optional): Optional name for validation images folder - force (bool, optional). Force extraction if MD5 checksum does not match. - Defaults to False + is given, the IDs are loaded from the meta file in the root directory + folder (str, optional): Optional name for validation images folder. Defaults to + 'val' """ - root = os.path.dirname(archive) + archive_meta = ARCHIVE_META["val"] + if file is None: + file = archive_meta[0] + md5 = archive_meta[1] if wnids is None: wnids = load_meta_file(root)[1] - if folder is None: - folder = os.path.join(root, "val") - _verify_archive(archive, ARCHIVE_DICT["val"]["md5"], force) - extract_archive(archive, folder) + _verify_archive(root, file, md5) + + val_root = os.path.join(root, folder) + extract_archive(os.path.join(root, file), val_root) - img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)]) + img_files = sorted([os.path.join(val_root, file) for file in os.listdir(val_root)]) for wnid in set(wnids): - os.mkdir(os.path.join(folder, wnid)) + os.mkdir(os.path.join(val_root, wnid)) for wnid, img_file in zip(wnids, img_files): - shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file))) + shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file))) From b57b90e07bd247bbb6fa2a3462e70ba27a229386 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 21 Oct 2019 13:06:02 +0200 Subject: [PATCH 15/15] flake8 --- torchvision/datasets/imagenet.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index db6adb14201..a45ff3cd44b 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -179,7 +179,8 @@ def parse_train_archive(root, file=None, folder="train"): train_root = os.path.join(root, folder) extract_archive(os.path.join(root, file), train_root) - for archive in [os.path.join(train_root, file) for file in os.listdir(train_root)]: + archives = [os.path.join(train_root, archive) for archive in os.listdir(train_root)] + for archive in archives: extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True) @@ -208,10 +209,10 @@ def parse_val_archive(root, file=None, wnids=None, folder="val"): val_root = os.path.join(root, folder) extract_archive(os.path.join(root, file), val_root) - img_files = sorted([os.path.join(val_root, file) for file in os.listdir(val_root)]) + images = sorted([os.path.join(val_root, image) for image in os.listdir(val_root)]) for wnid in set(wnids): os.mkdir(os.path.join(val_root, wnid)) - for wnid, img_file in zip(wnids, img_files): + for wnid, img_file in zip(wnids, images): shutil.move(img_file, os.path.join(val_root, wnid, os.path.basename(img_file)))