From 4a71561454670ca07185deacc0e4301d7a1d45ad Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Feb 2021 16:45:40 +0100 Subject: [PATCH 1/8] remove caching from (Fashion|K)?MNIST --- test/test_datasets.py | 9 +++-- torchvision/datasets/mnist.py | 67 +++++++++++++++++++---------------- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 8ec5be7de19..6050734ac5d 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -111,7 +111,8 @@ def test_imagefolder_empty(self): ) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') - def test_mnist(self, mock_download_extract): + @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True) + def test_mnist(self, mock_download_extract, mock_check_integrity): num_examples = 30 with mnist_root(num_examples, "MNIST") as root: dataset = torchvision.datasets.MNIST(root, download=True) @@ -120,7 +121,8 @@ def test_mnist(self, mock_download_extract): self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') - def test_kmnist(self, mock_download_extract): + @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True) + def test_kmnist(self, mock_download_extract, mock_check_integrity): num_examples = 30 with mnist_root(num_examples, "KMNIST") as root: dataset = torchvision.datasets.KMNIST(root, download=True) @@ -129,7 +131,8 @@ def test_kmnist(self, mock_download_extract): self.assertEqual(dataset.class_to_idx[dataset.classes[0]], target) @mock.patch('torchvision.datasets.mnist.download_and_extract_archive') - def test_fashionmnist(self, mock_download_extract): + @mock.patch('torchvision.datasets.mnist.check_integrity', return_value=True) + def test_fashionmnist(self, mock_download_extract, mock_check_integrity): num_examples = 30 with mnist_root(num_examples, "FashionMNIST") as root: dataset = torchvision.datasets.FashionMNIST(root, download=True) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index e798894089b..4757605516b 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -11,7 +11,7 @@ import lzma from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union from .utils import download_url, download_and_extract_archive, extract_archive, \ - verify_str_arg + verify_str_arg, check_integrity class MNIST(VisionDataset): @@ -75,6 +75,10 @@ def __init__( target_transform=target_transform) self.train = train # training set or test set + if self._check_legacy_exist(): + self.data, self.targets = self._load_legacy_cache() + return + if download: self.download() @@ -82,11 +86,32 @@ def __init__( raise RuntimeError('Dataset not found.' + ' You can use download=True to download it') - if self.train: - data_file = self.training_file - else: - data_file = self.test_file - self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) + self.data, self.targets = self._load_data() + + def _check_legacy_exist(self): + processed_folder_exists = os.path.exists(self.processed_folder) + if not processed_folder_exists: + return False + + return all( + check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) + ) + + def _load_legacy_data(self): + # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data + # directly. + data_file = self.training_file if self.train else self.test_file + return torch.load(os.path.join(self.processed_folder, data_file)) + + def _load_data(self): + image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte" + data = read_image_file(os.path.join(self.raw_folder, image_file)) + + label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte" + targets = read_label_file(os.path.join(self.raw_folder, label_file)) + + return data, targets + def __getitem__(self, index: int) -> Tuple[Any, Any]: """ @@ -126,43 +151,23 @@ def class_to_idx(self) -> Dict[str, int]: return {_class: i for i, _class in enumerate(self.classes)} def _check_exists(self) -> bool: - return (os.path.exists(os.path.join(self.processed_folder, - self.training_file)) and - os.path.exists(os.path.join(self.processed_folder, - self.test_file))) + return all( + check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]), md5) + for url, md5 in self.resources + ) def download(self) -> None: - """Download the MNIST data if it doesn't exist in processed_folder already.""" + """Download the MNIST data if it doesn't exist already.""" if self._check_exists(): return os.makedirs(self.raw_folder, exist_ok=True) - os.makedirs(self.processed_folder, exist_ok=True) - # download files for url, md5 in self.resources: filename = url.rpartition('/')[2] download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) - # process and save as torch files - print('Processing...') - - training_set = ( - read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')), - read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')) - ) - test_set = ( - read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')), - read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')) - ) - with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f: - torch.save(training_set, f) - with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f: - torch.save(test_set, f) - - print('Done!') - def extra_repr(self) -> str: return "Split: {}".format("Train" if self.train is True else "Test") From 4db2e334f71d3c7d7d14c924381bc179bbecb3b0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 19 Feb 2021 10:59:07 +0100 Subject: [PATCH 2/8] remove unnecessary lazy import --- torchvision/datasets/mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 4757605516b..a36dd3a9028 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union from .utils import download_url, download_and_extract_archive, extract_archive, \ verify_str_arg, check_integrity +import shutil class MNIST(VisionDataset): @@ -282,7 +283,6 @@ def _test_file(split) -> str: def download(self) -> None: """Download the EMNIST data if it doesn't exist in processed_folder already.""" - import shutil if self._check_exists(): return From 53f31b62579b32f3834f74430d88787ad803eacd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 19 Feb 2021 11:06:26 +0100 Subject: [PATCH 3/8] remove false check of binaries against the md5 of archives --- torchvision/datasets/mnist.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index a36dd3a9028..bfeecd8a22c 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -153,8 +153,8 @@ def class_to_idx(self) -> Dict[str, int]: def _check_exists(self) -> bool: return all( - check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]), md5) - for url, md5 in self.resources + check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])) + for url, _ in self.resources ) def download(self) -> None: From 8265f2fe851a818d48850030931e947670a30b62 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 19 Feb 2021 11:42:02 +0100 Subject: [PATCH 4/8] remove caching from EMNIST --- torchvision/datasets/mnist.py | 46 ++++++++++++++++------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index bfeecd8a22c..714b2432092 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -281,43 +281,39 @@ def _training_file(split) -> str: def _test_file(split) -> str: return 'test_{}.pt'.format(split) + @property + def _file_prefix(self) -> str: + return f"emnist-{self.split}-{'train' if self.train else 'test'}" + + @property + def images_file(self) -> str: + return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte") + + @property + def labels_file(self) -> str: + return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte") + + def _load_data(self): + return read_image_file(self.images_file), read_label_file(self.labels_file) + + def _check_exists(self) -> bool: + return all(check_integrity(file) for file in (self.images_file, self.labels_file)) + def download(self) -> None: - """Download the EMNIST data if it doesn't exist in processed_folder already.""" + """Download the EMNIST data if it doesn't exist already.""" if self._check_exists(): return os.makedirs(self.raw_folder, exist_ok=True) - os.makedirs(self.processed_folder, exist_ok=True) - # download files - print('Downloading and extracting zip archive') - download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip", - remove_finished=True, md5=self.md5) + download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5) gzip_folder = os.path.join(self.raw_folder, 'gzip') for gzip_file in os.listdir(gzip_folder): if gzip_file.endswith('.gz'): - extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder) - - # process and save as torch files - for split in self.splits: - print('Processing ' + split) - training_set = ( - read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))), - read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split))) - ) - test_set = ( - read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))), - read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split))) - ) - with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f: - torch.save(training_set, f) - with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f: - torch.save(test_set, f) + extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder) shutil.rmtree(gzip_folder) - print('Done!') - class QMNIST(MNIST): """`QMNIST `_ Dataset. From 3a888896c8f26d027aa165df79bf6b32876c1320 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 19 Feb 2021 11:55:07 +0100 Subject: [PATCH 5/8] remove caching from QMNIST --- torchvision/datasets/mnist.py | 55 +++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 714b2432092..5612db1a10e 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -382,40 +382,51 @@ def __init__( self.test_file = self.data_file super(QMNIST, self).__init__(root, train, **kwargs) + @property + def images_file(self) -> str: + (url, _), _ = self.resources[self.subsets[self.what]] + return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) + + @property + def labels_file(self) -> str: + _, (url, _) = self.resources[self.subsets[self.what]] + return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) + + def _check_exists(self) -> bool: + return all(check_integrity(file) for file in (self.images_file, self.labels_file)) + + def _load_data(self): + data = read_sn3_pascalvincent_tensor(self.images_file) + assert (data.dtype == torch.uint8) + assert (data.ndimension() == 3) + + targets = read_sn3_pascalvincent_tensor(self.labels_file).long() + assert (targets.ndimension() == 2) + + if self.what == 'test10k': + data = data[0:10000, :, :].clone() + targets = targets[0:10000, :].clone() + elif self.what == 'test50k': + data = data[10000:, :, :].clone() + targets = targets[10000:, :].clone() + + return data, targets + def download(self) -> None: - """Download the QMNIST data if it doesn't exist in processed_folder already. + """Download the QMNIST data if it doesn't exist already. Note that we only download what has been asked for (argument 'what'). """ if self._check_exists(): return + os.makedirs(self.raw_folder, exist_ok=True) - os.makedirs(self.processed_folder, exist_ok=True) split = self.resources[self.subsets[self.what]] - files = [] - # download data files if not already there for url, md5 in split: filename = url.rpartition('/')[2] file_path = os.path.join(self.raw_folder, filename) if not os.path.isfile(file_path): - download_url(url, root=self.raw_folder, filename=filename, md5=md5) - files.append(file_path) - - # process and save as torch files - print('Processing...') - data = read_sn3_pascalvincent_tensor(files[0]) - assert(data.dtype == torch.uint8) - assert(data.ndimension() == 3) - targets = read_sn3_pascalvincent_tensor(files[1]).long() - assert(targets.ndimension() == 2) - if self.what == 'test10k': - data = data[0:10000, :, :].clone() - targets = targets[0:10000, :].clone() - if self.what == 'test50k': - data = data[10000:, :, :].clone() - targets = targets[10000:, :].clone() - with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f: - torch.save((data, targets), f) + download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5) def __getitem__(self, index: int) -> Tuple[Any, Any]: # redefined to handle the compat flag From 8cb1fb8b674da9b389a3764fe7fd068d716f0ef0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 19 Feb 2021 11:56:33 +0100 Subject: [PATCH 6/8] lint --- torchvision/datasets/mnist.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 5612db1a10e..bf571f4e2d7 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -77,7 +77,7 @@ def __init__( self.train = train # training set or test set if self._check_legacy_exist(): - self.data, self.targets = self._load_legacy_cache() + self.data, self.targets = self._load_legacy_data() return if download: @@ -113,7 +113,6 @@ def _load_data(self): return data, targets - def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: From b0475de88ac10261f11b4b4310f7abc2eec08c01 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Sun, 21 Feb 2021 16:26:29 +0100 Subject: [PATCH 7/8] fix EMNIST --- torchvision/datasets/mnist.py | 32 +++++++++++--------------------- 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index bf571f4e2d7..16c980ab276 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -9,7 +9,7 @@ import string import gzip import lzma -from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from .utils import download_url, download_and_extract_archive, extract_archive, \ verify_str_arg, check_integrity import shutil @@ -425,7 +425,12 @@ def download(self) -> None: filename = url.rpartition('/')[2] file_path = os.path.join(self.raw_folder, filename) if not os.path.isfile(file_path): - download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5) + download_url(url, self.raw_folder, filename=filename, md5=md5) + if filename.endswith(".xz"): + with lzma.open(file_path, "rb") as fh1, open(os.path.splitext(file_path)[0], "wb") as fh2: + fh2.write(fh1.read()) + else: + extract_archive(file_path, os.path.splitext(file_path)[0]) def __getitem__(self, index: int) -> Tuple[Any, Any]: # redefined to handle the compat flag @@ -447,19 +452,6 @@ def get_int(b: bytes) -> int: return int(codecs.encode(b, 'hex'), 16) -def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]: - """Return a file object that possibly decompresses 'path' on the fly. - Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. - """ - if not isinstance(path, torch._six.string_classes): - return path - if path.endswith('.gz'): - return gzip.open(path, 'rb') - if path.endswith('.xz'): - return lzma.open(path, 'rb') - return open(path, 'rb') - - SN3_PASCALVINCENT_TYPEMAP = { 8: (torch.uint8, np.uint8, np.uint8), 9: (torch.int8, np.int8, np.int8), @@ -470,12 +462,12 @@ def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile] } -def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor: +def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor: """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). Argument may be a filename, compressed filename, or file object. """ # read - with open_maybe_compressed_file(path) as f: + with open(path, "rb") as f: data = f.read() # parse magic = get_int(data[0:4]) @@ -491,16 +483,14 @@ def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> def read_label_file(path: str) -> torch.Tensor: - with open(path, 'rb') as f: - x = read_sn3_pascalvincent_tensor(f, strict=False) + x = read_sn3_pascalvincent_tensor(path, strict=False) assert(x.dtype == torch.uint8) assert(x.ndimension() == 1) return x.long() def read_image_file(path: str) -> torch.Tensor: - with open(path, 'rb') as f: - x = read_sn3_pascalvincent_tensor(f, strict=False) + x = read_sn3_pascalvincent_tensor(path, strict=False) assert(x.dtype == torch.uint8) assert(x.ndimension() == 3) return x From bb908d8a10273cb0cf2eaed4b8b749a5372f28d6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 16 Mar 2021 10:49:05 +0100 Subject: [PATCH 8/8] streamline QMNIST download --- torchvision/datasets/mnist.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index e4d49492369..e356f17dd1b 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -7,11 +7,9 @@ import torch import codecs import string -import lzma from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.error import URLError -from .utils import download_url, download_and_extract_archive, extract_archive, \ - verify_str_arg, check_integrity +from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity import shutil @@ -448,12 +446,7 @@ def download(self) -> None: filename = url.rpartition('/')[2] file_path = os.path.join(self.raw_folder, filename) if not os.path.isfile(file_path): - download_url(url, self.raw_folder, filename=filename, md5=md5) - if filename.endswith(".xz"): - with lzma.open(file_path, "rb") as fh1, open(os.path.splitext(file_path)[0], "wb") as fh2: - fh2.write(fh1.read()) - else: - extract_archive(file_path, os.path.splitext(file_path)[0]) + download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5) def __getitem__(self, index: int) -> Tuple[Any, Any]: # redefined to handle the compat flag