Skip to content

Commit 3a88889

Browse files
committed
remove caching from QMNIST
1 parent 8265f2f commit 3a88889

File tree

1 file changed

+33
-22
lines changed

1 file changed

+33
-22
lines changed

torchvision/datasets/mnist.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -382,40 +382,51 @@ def __init__(
382382
self.test_file = self.data_file
383383
super(QMNIST, self).__init__(root, train, **kwargs)
384384

385+
@property
386+
def images_file(self) -> str:
387+
(url, _), _ = self.resources[self.subsets[self.what]]
388+
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
389+
390+
@property
391+
def labels_file(self) -> str:
392+
_, (url, _) = self.resources[self.subsets[self.what]]
393+
return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])
394+
395+
def _check_exists(self) -> bool:
396+
return all(check_integrity(file) for file in (self.images_file, self.labels_file))
397+
398+
def _load_data(self):
399+
data = read_sn3_pascalvincent_tensor(self.images_file)
400+
assert (data.dtype == torch.uint8)
401+
assert (data.ndimension() == 3)
402+
403+
targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
404+
assert (targets.ndimension() == 2)
405+
406+
if self.what == 'test10k':
407+
data = data[0:10000, :, :].clone()
408+
targets = targets[0:10000, :].clone()
409+
elif self.what == 'test50k':
410+
data = data[10000:, :, :].clone()
411+
targets = targets[10000:, :].clone()
412+
413+
return data, targets
414+
385415
def download(self) -> None:
386-
"""Download the QMNIST data if it doesn't exist in processed_folder already.
416+
"""Download the QMNIST data if it doesn't exist already.
387417
Note that we only download what has been asked for (argument 'what').
388418
"""
389419
if self._check_exists():
390420
return
421+
391422
os.makedirs(self.raw_folder, exist_ok=True)
392-
os.makedirs(self.processed_folder, exist_ok=True)
393423
split = self.resources[self.subsets[self.what]]
394-
files = []
395424

396-
# download data files if not already there
397425
for url, md5 in split:
398426
filename = url.rpartition('/')[2]
399427
file_path = os.path.join(self.raw_folder, filename)
400428
if not os.path.isfile(file_path):
401-
download_url(url, root=self.raw_folder, filename=filename, md5=md5)
402-
files.append(file_path)
403-
404-
# process and save as torch files
405-
print('Processing...')
406-
data = read_sn3_pascalvincent_tensor(files[0])
407-
assert(data.dtype == torch.uint8)
408-
assert(data.ndimension() == 3)
409-
targets = read_sn3_pascalvincent_tensor(files[1]).long()
410-
assert(targets.ndimension() == 2)
411-
if self.what == 'test10k':
412-
data = data[0:10000, :, :].clone()
413-
targets = targets[0:10000, :].clone()
414-
if self.what == 'test50k':
415-
data = data[10000:, :, :].clone()
416-
targets = targets[10000:, :].clone()
417-
with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
418-
torch.save((data, targets), f)
429+
download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5)
419430

420431
def __getitem__(self, index: int) -> Tuple[Any, Any]:
421432
# redefined to handle the compat flag

0 commit comments

Comments
 (0)