From cb774b76c670bef9bd67359e906465622555d442 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 11 May 2022 15:05:28 +0200 Subject: [PATCH 1/9] simplify OnlineResource.load --- test/test_prototype_datasets_utils.py | 198 +++++++++++++++++- .../prototype/datasets/utils/_resource.py | 35 ++-- 2 files changed, 213 insertions(+), 20 deletions(-) diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index b1c95844574..2c63ffb354c 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -1,11 +1,14 @@ +import gzip +import pathlib import sys import numpy as np import pytest import torch -from datasets_utils import make_fake_flo_file +from datasets_utils import make_fake_flo_file, make_tar +from torchdata.datapipes.iter import FileOpener, TarArchiveLoader from torchvision.datasets._optical_flow import _read_flo as read_flo_ref -from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset +from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource from torchvision.prototype.datasets.utils._internal import read_flo, fromfile @@ -48,6 +51,197 @@ def test_read_flo(tmpdir): torch.testing.assert_close(actual, expected) +class TestOnlineResource: + class DummyResource(OnlineResource): + def __init__(self, download_fn=None, **kwargs): + super().__init__(**kwargs) + self._download_fn = download_fn + + def _download(self, root): + if self._download_fn is None: + raise pytest.UsageError( + "`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`." + ) + + return self._download_fn(self, root) + + def _make_file(self, root, *, content, name="file.txt"): + file = root / name + with open(file, "w") as fh: + fh.write(content) + + return file + + def _make_folder(self, root, *, name="folder"): + folder = root / name + subfolder = folder / "subfolder" + subfolder.mkdir(parents=True) + + files = {} + for idx, root in enumerate([folder, folder, subfolder]): + content = f"sentinel{idx}" + file = self._make_file(root, name=f"file{idx}.txt", content=content) + files[str(file)] = content + + return folder, files + + def _make_tar(self, root, *, name="archive.tar", remove=True): + folder, files = self._make_folder(root, name=name.split(".")[0]) + archive = make_tar(root, name, folder, remove=remove) + files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()} + return archive, files + + def test_load_file(self, tmp_path): + content = "sentinel" + file = self._make_file(tmp_path, content=content) + + resource = self.DummyResource(file_name=file.name) + + dp = resource.load(tmp_path) + assert isinstance(dp, FileOpener) + + data = list(dp) + assert len(data) == 1 + + path, buffer = data[0] + assert path == str(file) + assert buffer.read().decode() == content + + def test_load_folder(self, tmp_path): + folder, files = self._make_folder(tmp_path) + + resource = self.DummyResource(file_name=folder.name) + + dp = resource.load(tmp_path) + assert isinstance(dp, FileOpener) + assert {path: buffer.read().decode() for path, buffer in dp} == files + + def test_load_archive(self, tmp_path): + archive, files = self._make_tar(tmp_path) + + resource = self.DummyResource(file_name=archive.name) + + dp = resource.load(tmp_path) + assert isinstance(dp, TarArchiveLoader) + assert {path: buffer.read().decode() for path, buffer in dp} == files + + def test_priority_decompressed_gt_raw(self, tmp_path): + # We don't need to actually compress here. Adding the suffix is sufficient + self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz") + file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt") + + resource = self.DummyResource(file_name=file.name) + + dp = resource.load(tmp_path) + path, buffer = next(iter(dp)) + + assert path == str(file) + assert buffer.read().decode() == "decompressed_sentinel" + + def test_priority_extracted_gt_decopressed(self, tmp_path): + archive, _ = self._make_tar(tmp_path, remove=False) + + resource = self.DummyResource(file_name=archive.name) + + dp = resource.load(tmp_path) + # If the archive had been selected, this would be a `TarArchiveReader` + assert isinstance(dp, FileOpener) + + def test_download(self, tmp_path): + file_name = "file.txt" + content = "sentinel" + + resource = self.DummyResource( + file_name=file_name, + download_fn=lambda resource, root: self._make_file(root, content=content, name=resource.file_name), + ) + + dp = resource.load(tmp_path) + data = list(dp) + assert len(data) == 1 + + path, buffer = data[0] + assert path == str(tmp_path / file_name) + assert buffer.read().decode() == content + + def test_preprocess_decompress(self, tmp_path): + file_name = "file.txt.gz" + content = "sentinel" + + def download_fn(resource, root): + file = root / resource.file_name + with gzip.open(file, "wb") as fh: + fh.write(content.encode()) + return file + + resource = self.DummyResource(file_name=file_name, preprocess="decompress", download_fn=download_fn) + + dp = resource.load(tmp_path) + data = list(dp) + assert len(data) == 1 + + path, buffer = data[0] + assert path == str(tmp_path / file_name).replace(".gz", "") + assert buffer.read().decode() == content + + def test_preprocess_extract(self, tmp_path): + files = None + + def download_fn(resource, root): + nonlocal files + archive, files = self._make_tar(root, name=resource.file_name) + return archive + + resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn) + + dp = resource.load(tmp_path) + assert files is not None, "`download_fn()` was never called" + assert isinstance(dp, FileOpener) + + actual = {path: buffer.read().decode() for path, buffer in dp} + expected = { + path.replace(resource.file_name, resource.file_name.split(".")[0]): content + for path, content in files.items() + } + assert actual == expected + + def test_preprocess_fn(self, tmp_path): + file_name = "file.txt" + content = "sentinel" + + preprocessed_file_name = f"preprocessed_{file_name}" + preprocessed_content = f"preprocessed_{content}" + + resource = self.DummyResource( + file_name=file_name, + preprocess=lambda path: self._make_file( + path.parent, content=preprocessed_content, name=preprocessed_file_name + ), + download_fn=lambda resource, root: self._make_file(root, content=content, name=resource.file_name), + ) + + dp = resource.load(tmp_path) + data = list(dp) + assert len(data) == 1 + + path, buffer = data[0] + assert path == str(tmp_path / preprocessed_file_name) + assert buffer.read().decode() == preprocessed_content + + def test_preprocess_only_after_download(self, tmp_path): + file = self._make_file(tmp_path, content="_") + + def preprocess(path): + raise AssertionError("`preprocess` was called although the file was already present.") + + resource = self.DummyResource( + file_name=file.name, + preprocess=preprocess, + ) + + resource.load(tmp_path) + + class TestHttpResource: def test_resolve_to_http(self, mocker): file_name = "data.tar" diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 507428a98d3..4ca9b17a60f 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -91,31 +91,30 @@ def load( ) -> IterDataPipe[Tuple[str, IO]]: root = pathlib.Path(root) path = root / self.file_name + # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories # with no suffixes at all. stem = path.name.replace("".join(path.suffixes), "") - - # In a first step, we check for a folder with the same stem as the raw file. If it exists, we use it since - # extracted files give the best I/O performance. Note that OnlineResource._extract() makes sure that an archive - # is always extracted in a folder with the corresponding file name. + # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder + # candidate simultaneously, that would also pick up other files that share the same prefix. For example, the + # test split of the stanford-cars dataset uses the files + # - cars_test.tgz + # - cars_test_annos_withlabels.mat + # Globbing for `"cars_test*"` picks up both. + candidates = {file for file in path.parent.glob(f"{stem}.*")} folder_candidate = path.parent / stem - if folder_candidate.exists() and folder_candidate.is_dir(): - return self._loader(folder_candidate) - - # If there is no folder, we look for all files that share the same stem as the raw file, but might have a - # different suffix. - file_candidates = {file for file in path.parent.glob(stem + ".*")} - # If we don't find anything, we download the raw file. - if not file_candidates: - file_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)} - # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps. - if file_candidates == {path}: + if folder_candidate.exists(): + candidates.add(folder_candidate) + + if not candidates: + self.download(root, skip_integrity_check=skip_integrity_check) if self._preprocess is not None: path = self._preprocess(path) - # Otherwise, we use the path with the fewest suffixes. This gives us the decompressed > raw priority that we - # want for the best I/O performance. else: - path = min(file_candidates, key=lambda path: len(path.suffixes)) + # We use the path with the fewest suffixes. This gives us the + # extracted > decompressed > raw + # priority that we want for the best I/O performance. + path = min(candidates, key=lambda path: len(path.suffixes)) return self._loader(path) @abc.abstractmethod From d62747962f9ed6a7b0b80849e7c971efabb5d3da Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 May 2022 11:38:12 +0200 Subject: [PATCH 2/9] [PoC] merge mock data preparation and loading --- test/builtin_dataset_mocks.py | 52 ++++++++++++++++------- test/test_prototype_builtin_datasets.py | 55 +++++++------------------ 2 files changed, 52 insertions(+), 55 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index cea0f297be5..05e00d034e7 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -10,6 +10,7 @@ import pathlib import pickle import random +import shutil import unittest.mock import warnings import xml.etree.ElementTree as ET @@ -22,7 +23,6 @@ from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor -from torchvision._utils import sequence_to_str from torchvision.prototype import datasets make_tensor = functools.partial(_make_tensor, device="cpu") @@ -62,27 +62,47 @@ def _parse_mock_info(self, mock_info): return mock_info - def prepare(self, config): + def load(self, config): # `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in # test/test_prototype_builtin_datasets.py root = pathlib.Path(datasets.home()) / self.name - root.mkdir(exist_ok=True) + mock_data_folder = root / "__mock__" + mock_data_folder.mkdir(parents=True) - mock_info = self._parse_mock_info(self.mock_data_fn(root, config)) + mock_info = self._parse_mock_info(self.mock_data_fn(mock_data_folder, config)) - with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"): - required_file_names = { - resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources() - } - available_file_names = {path.name for path in root.glob("*")} - missing_file_names = required_file_names - available_file_names - if missing_file_names: - raise pytest.UsageError( - f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} " - f"for {config}, but they were not created by the mock data function." - ) + def mock_data_download(resource, root, **kwargs): + src = mock_data_folder / resource.file_name + if not src.exists(): + raise pytest.UsageError( + f"Dataset '{self.name}' requires the file {resource.file_name} for {config}" + f"but it was not created by the mock data function." + ) - return mock_info + dst = root / resource.file_name + shutil.move(str(src), str(root)) + + return dst + + with unittest.mock.patch( + "torchvision.prototype.datasets.utils._resource.OnlineResource.download", new=mock_data_download + ): + dataset = datasets.load(self.name, **config) + + extra_files = list(mock_data_folder.glob("**/*")) + if not extra_files: + mock_data_folder.rmdir() + else: + pass + # raise pytest.UsageError( + # ( + # f"Dataset '{self.name}' created the following files for {config} in the mock data function, " + # f"but they were not loaded:\n\n" + # ) + # + "\n".join(str(file.relative_to(mock_data_folder)) for file in extra_files) + # ) + + return dataset, mock_info def config_id(name, config): diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 23190b25ddc..5a8c9e7eff8 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -56,18 +56,14 @@ def test_info(self, name): @parametrize_dataset_mocks(DATASET_MOCKS) def test_smoke(self, dataset_mock, config): - dataset_mock.prepare(config) - - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) if not isinstance(dataset, datasets.utils.Dataset): raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") @parametrize_dataset_mocks(DATASET_MOCKS) def test_sample(self, dataset_mock, config): - dataset_mock.prepare(config) - - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) try: sample = next(iter(dataset)) @@ -84,17 +80,13 @@ def test_sample(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS) def test_num_samples(self, dataset_mock, config): - mock_info = dataset_mock.prepare(config) - - dataset = datasets.load(dataset_mock.name, **config) + dataset, mock_info = dataset_mock.load(config) assert len(list(dataset)) == mock_info["num_samples"] @parametrize_dataset_mocks(DATASET_MOCKS) def test_no_vanilla_tensors(self, dataset_mock, config): - dataset_mock.prepare(config) - - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor} if vanilla_tensors: @@ -105,24 +97,20 @@ def test_no_vanilla_tensors(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS) def test_transformable(self, dataset_mock, config): - dataset_mock.prepare(config) - - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) next(iter(dataset.map(transforms.Identity()))) @pytest.mark.parametrize("only_datapipe", [False, True]) @parametrize_dataset_mocks(DATASET_MOCKS) def test_traversable(self, dataset_mock, config, only_datapipe): - dataset_mock.prepare(config) - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) traverse(dataset, only_datapipe=only_datapipe) @parametrize_dataset_mocks(DATASET_MOCKS) def test_serializable(self, dataset_mock, config): - dataset_mock.prepare(config) - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) pickle.dumps(dataset) @@ -135,8 +123,7 @@ def _collate_fn(self, batch): @pytest.mark.parametrize("num_workers", [0, 1]) @parametrize_dataset_mocks(DATASET_MOCKS) def test_data_loader(self, dataset_mock, config, num_workers): - dataset_mock.prepare(config) - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) dl = DataLoader( dataset, @@ -153,17 +140,15 @@ def test_data_loader(self, dataset_mock, config, num_workers): @parametrize_dataset_mocks(DATASET_MOCKS) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) def test_has_annotations(self, dataset_mock, config, annotation_dp_type): - - dataset_mock.prepare(config) - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)): raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") @parametrize_dataset_mocks(DATASET_MOCKS) def test_save_load(self, dataset_mock, config): - dataset_mock.prepare(config) - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) + sample = next(iter(dataset)) with io.BytesIO() as buffer: @@ -173,8 +158,7 @@ def test_save_load(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS) def test_infinite_buffer_size(self, dataset_mock, config): - dataset_mock.prepare(config) - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) for dp in extract_datapipes(dataset): if hasattr(dp, "buffer_size"): @@ -184,8 +168,7 @@ def test_infinite_buffer_size(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS) def test_has_length(self, dataset_mock, config): - dataset_mock.prepare(config) - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) assert len(dataset) > 0 @@ -193,9 +176,7 @@ def test_has_length(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) class TestQMNIST: def test_extra_label(self, dataset_mock, config): - dataset_mock.prepare(config) - - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) sample = next(iter(dataset)) for key, type in ( @@ -218,9 +199,7 @@ def test_label_matches_path(self, dataset_mock, config): if config["split"] != "train": return - dataset_mock.prepare(config) - - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) for sample in dataset: label_from_path = int(Path(sample["path"]).parent.name) @@ -230,9 +209,7 @@ def test_label_matches_path(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS["usps"]) class TestUSPS: def test_sample_content(self, dataset_mock, config): - dataset_mock.prepare(config) - - dataset = datasets.load(dataset_mock.name, **config) + dataset, _ = dataset_mock.load(config) for sample in dataset: assert "image" in sample From 99c2daf71ed518b2223a4c306c20ea2336a596e1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 May 2022 15:10:18 +0200 Subject: [PATCH 3/9] Revert "cache mock data based on config" This reverts commit 5ed6eedef74865e0baa746a375d5ec1f0ab1bde7. From 65198d15419cb1c8b7d9add5e85fca8ec1e8d844 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 12 May 2022 11:38:12 +0200 Subject: [PATCH 4/9] Revert "[PoC] merge mock data preparation and loading" This reverts commit d62747962f9ed6a7b0b80849e7c971efabb5d3da. --- test/builtin_dataset_mocks.py | 52 +++++++---------------- test/test_prototype_builtin_datasets.py | 55 ++++++++++++++++++------- 2 files changed, 55 insertions(+), 52 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 05e00d034e7..cea0f297be5 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -10,7 +10,6 @@ import pathlib import pickle import random -import shutil import unittest.mock import warnings import xml.etree.ElementTree as ET @@ -23,6 +22,7 @@ from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor +from torchvision._utils import sequence_to_str from torchvision.prototype import datasets make_tensor = functools.partial(_make_tensor, device="cpu") @@ -62,47 +62,27 @@ def _parse_mock_info(self, mock_info): return mock_info - def load(self, config): + def prepare(self, config): # `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in # test/test_prototype_builtin_datasets.py root = pathlib.Path(datasets.home()) / self.name - mock_data_folder = root / "__mock__" - mock_data_folder.mkdir(parents=True) + root.mkdir(exist_ok=True) - mock_info = self._parse_mock_info(self.mock_data_fn(mock_data_folder, config)) + mock_info = self._parse_mock_info(self.mock_data_fn(root, config)) - def mock_data_download(resource, root, **kwargs): - src = mock_data_folder / resource.file_name - if not src.exists(): - raise pytest.UsageError( - f"Dataset '{self.name}' requires the file {resource.file_name} for {config}" - f"but it was not created by the mock data function." - ) - - dst = root / resource.file_name - shutil.move(str(src), str(root)) - - return dst - - with unittest.mock.patch( - "torchvision.prototype.datasets.utils._resource.OnlineResource.download", new=mock_data_download - ): - dataset = datasets.load(self.name, **config) + with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"): + required_file_names = { + resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources() + } + available_file_names = {path.name for path in root.glob("*")} + missing_file_names = required_file_names - available_file_names + if missing_file_names: + raise pytest.UsageError( + f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} " + f"for {config}, but they were not created by the mock data function." + ) - extra_files = list(mock_data_folder.glob("**/*")) - if not extra_files: - mock_data_folder.rmdir() - else: - pass - # raise pytest.UsageError( - # ( - # f"Dataset '{self.name}' created the following files for {config} in the mock data function, " - # f"but they were not loaded:\n\n" - # ) - # + "\n".join(str(file.relative_to(mock_data_folder)) for file in extra_files) - # ) - - return dataset, mock_info + return mock_info def config_id(name, config): diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 5a8c9e7eff8..23190b25ddc 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -56,14 +56,18 @@ def test_info(self, name): @parametrize_dataset_mocks(DATASET_MOCKS) def test_smoke(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + + dataset = datasets.load(dataset_mock.name, **config) if not isinstance(dataset, datasets.utils.Dataset): raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") @parametrize_dataset_mocks(DATASET_MOCKS) def test_sample(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + + dataset = datasets.load(dataset_mock.name, **config) try: sample = next(iter(dataset)) @@ -80,13 +84,17 @@ def test_sample(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS) def test_num_samples(self, dataset_mock, config): - dataset, mock_info = dataset_mock.load(config) + mock_info = dataset_mock.prepare(config) + + dataset = datasets.load(dataset_mock.name, **config) assert len(list(dataset)) == mock_info["num_samples"] @parametrize_dataset_mocks(DATASET_MOCKS) def test_no_vanilla_tensors(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + + dataset = datasets.load(dataset_mock.name, **config) vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor} if vanilla_tensors: @@ -97,20 +105,24 @@ def test_no_vanilla_tensors(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS) def test_transformable(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + + dataset = datasets.load(dataset_mock.name, **config) next(iter(dataset.map(transforms.Identity()))) @pytest.mark.parametrize("only_datapipe", [False, True]) @parametrize_dataset_mocks(DATASET_MOCKS) def test_traversable(self, dataset_mock, config, only_datapipe): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + dataset = datasets.load(dataset_mock.name, **config) traverse(dataset, only_datapipe=only_datapipe) @parametrize_dataset_mocks(DATASET_MOCKS) def test_serializable(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + dataset = datasets.load(dataset_mock.name, **config) pickle.dumps(dataset) @@ -123,7 +135,8 @@ def _collate_fn(self, batch): @pytest.mark.parametrize("num_workers", [0, 1]) @parametrize_dataset_mocks(DATASET_MOCKS) def test_data_loader(self, dataset_mock, config, num_workers): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + dataset = datasets.load(dataset_mock.name, **config) dl = DataLoader( dataset, @@ -140,15 +153,17 @@ def test_data_loader(self, dataset_mock, config, num_workers): @parametrize_dataset_mocks(DATASET_MOCKS) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) def test_has_annotations(self, dataset_mock, config, annotation_dp_type): - dataset, _ = dataset_mock.load(config) + + dataset_mock.prepare(config) + dataset = datasets.load(dataset_mock.name, **config) if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)): raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") @parametrize_dataset_mocks(DATASET_MOCKS) def test_save_load(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) - + dataset_mock.prepare(config) + dataset = datasets.load(dataset_mock.name, **config) sample = next(iter(dataset)) with io.BytesIO() as buffer: @@ -158,7 +173,8 @@ def test_save_load(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS) def test_infinite_buffer_size(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + dataset = datasets.load(dataset_mock.name, **config) for dp in extract_datapipes(dataset): if hasattr(dp, "buffer_size"): @@ -168,7 +184,8 @@ def test_infinite_buffer_size(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS) def test_has_length(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + dataset = datasets.load(dataset_mock.name, **config) assert len(dataset) > 0 @@ -176,7 +193,9 @@ def test_has_length(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) class TestQMNIST: def test_extra_label(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + + dataset = datasets.load(dataset_mock.name, **config) sample = next(iter(dataset)) for key, type in ( @@ -199,7 +218,9 @@ def test_label_matches_path(self, dataset_mock, config): if config["split"] != "train": return - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + + dataset = datasets.load(dataset_mock.name, **config) for sample in dataset: label_from_path = int(Path(sample["path"]).parent.name) @@ -209,7 +230,9 @@ def test_label_matches_path(self, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS["usps"]) class TestUSPS: def test_sample_content(self, dataset_mock, config): - dataset, _ = dataset_mock.load(config) + dataset_mock.prepare(config) + + dataset = datasets.load(dataset_mock.name, **config) for sample in dataset: assert "image" in sample From 232c6a9187ca6acfe35fb174090017673b947874 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 16 May 2022 14:03:18 +0200 Subject: [PATCH 5/9] remove preprocess returning a new path in favor of querying twice --- .../prototype/datasets/utils/_resource.py | 55 ++++++++++--------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 4ca9b17a60f..b450a94e38f 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -32,7 +32,7 @@ def __init__( *, file_name: str, sha256: Optional[str] = None, - preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], pathlib.Path]]] = None, + preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None, ) -> None: self.file_name = file_name self.sha256 = sha256 @@ -50,14 +50,12 @@ def __init__( self._preprocess = preprocess @staticmethod - def _extract(file: pathlib.Path) -> pathlib.Path: - return pathlib.Path( - extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False) - ) + def _extract(file: pathlib.Path) -> None: + extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False) @staticmethod - def _decompress(file: pathlib.Path) -> pathlib.Path: - return pathlib.Path(_decompress(str(file), remove_finished=True)) + def _decompress(file: pathlib.Path) -> None: + _decompress(str(file), remove_finished=True) def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]: if path.is_dir(): @@ -93,29 +91,36 @@ def load( path = root / self.file_name # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories - # with no suffixes at all. + # with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed, which + # is not sufficient for files with multiple suffixes, e.g. foo.tar.gz. stem = path.name.replace("".join(path.suffixes), "") - # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder - # candidate simultaneously, that would also pick up other files that share the same prefix. For example, the - # test split of the stanford-cars dataset uses the files - # - cars_test.tgz - # - cars_test_annos_withlabels.mat - # Globbing for `"cars_test*"` picks up both. - candidates = {file for file in path.parent.glob(f"{stem}.*")} - folder_candidate = path.parent / stem - if folder_candidate.exists(): - candidates.add(folder_candidate) + + def find_candidates(): + # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder + # candidate simultaneously, that would also pick up other files that share the same prefix. For example, the + # test split of the stanford-cars dataset uses the files + # - cars_test.tgz + # - cars_test_annos_withlabels.mat + # Globbing for `"cars_test*"` picks up both. + candidates = {file for file in path.parent.glob(f"{stem}.*")} + folder_candidate = path.parent / stem + if folder_candidate.exists(): + candidates.add(folder_candidate) + + return candidates + + candidates = find_candidates() if not candidates: self.download(root, skip_integrity_check=skip_integrity_check) if self._preprocess is not None: - path = self._preprocess(path) - else: - # We use the path with the fewest suffixes. This gives us the - # extracted > decompressed > raw - # priority that we want for the best I/O performance. - path = min(candidates, key=lambda path: len(path.suffixes)) - return self._loader(path) + self._preprocess(path) + candidates = find_candidates() + + # We use the path with the fewest suffixes. This gives us the + # extracted > decompressed > raw + # priority that we want for the best I/O performance. + return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes))) @abc.abstractmethod def _download(self, root: pathlib.Path) -> None: From 89df201925c75bf3fc3908819c8a2a1140adf7e2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 16 May 2022 14:03:54 +0200 Subject: [PATCH 6/9] address test comments --- test/test_prototype_datasets_utils.py | 61 +++++++++++---------------- 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index 2c63ffb354c..aba26c15032 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -51,6 +51,15 @@ def test_read_flo(tmpdir): torch.testing.assert_close(actual, expected) +# This cannot be defined on the TestOnlineResource class, since it is used in a `@pytest.mark.parametrize` decorator +def _decompress_gz(path): + file = path.with_name(path.name.replace(".gz", "")) + with gzip.open(path, "rb") as rfh, open(file, "wb") as wfh: + wfh.write(rfh.read()) + path.unlink() + return file + + class TestOnlineResource: class DummyResource(OnlineResource): def __init__(self, download_fn=None, **kwargs): @@ -138,7 +147,7 @@ def test_priority_decompressed_gt_raw(self, tmp_path): assert path == str(file) assert buffer.read().decode() == "decompressed_sentinel" - def test_priority_extracted_gt_decopressed(self, tmp_path): + def test_priority_extracted_gt_decompressed(self, tmp_path): archive, _ = self._make_tar(tmp_path, remove=False) resource = self.DummyResource(file_name=archive.name) @@ -148,23 +157,26 @@ def test_priority_extracted_gt_decopressed(self, tmp_path): assert isinstance(dp, FileOpener) def test_download(self, tmp_path): - file_name = "file.txt" - content = "sentinel" + download_fn_was_called = False + + def download_fn(resource, root): + nonlocal download_fn_was_called + download_fn_was_called = True + + return self._make_file(root, content="_", name=resource.file_name) resource = self.DummyResource( - file_name=file_name, - download_fn=lambda resource, root: self._make_file(root, content=content, name=resource.file_name), + file_name="file.txt", + download_fn=download_fn, ) - dp = resource.load(tmp_path) - data = list(dp) - assert len(data) == 1 + resource.load(tmp_path) - path, buffer = data[0] - assert path == str(tmp_path / file_name) - assert buffer.read().decode() == content + assert download_fn_was_called, "`download_fn()` was never called" - def test_preprocess_decompress(self, tmp_path): + # This tests the `"decompress"` literal as well as a custom callable + @pytest.mark.parametrize("preprocess", ["decompress", _decompress_gz]) + def test_preprocess_decompress(self, tmp_path, preprocess): file_name = "file.txt.gz" content = "sentinel" @@ -174,7 +186,7 @@ def download_fn(resource, root): fh.write(content.encode()) return file - resource = self.DummyResource(file_name=file_name, preprocess="decompress", download_fn=download_fn) + resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn) dp = resource.load(tmp_path) data = list(dp) @@ -205,29 +217,6 @@ def download_fn(resource, root): } assert actual == expected - def test_preprocess_fn(self, tmp_path): - file_name = "file.txt" - content = "sentinel" - - preprocessed_file_name = f"preprocessed_{file_name}" - preprocessed_content = f"preprocessed_{content}" - - resource = self.DummyResource( - file_name=file_name, - preprocess=lambda path: self._make_file( - path.parent, content=preprocessed_content, name=preprocessed_file_name - ), - download_fn=lambda resource, root: self._make_file(root, content=content, name=resource.file_name), - ) - - dp = resource.load(tmp_path) - data = list(dp) - assert len(data) == 1 - - path, buffer = data[0] - assert path == str(tmp_path / preprocessed_file_name) - assert buffer.read().decode() == preprocessed_content - def test_preprocess_only_after_download(self, tmp_path): file = self._make_file(tmp_path, content="_") From e8ca14672fb5d9478f09a7cc4bb9982ab2cb3014 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 16 May 2022 14:10:25 +0200 Subject: [PATCH 7/9] clarify comment --- test/test_prototype_datasets_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index aba26c15032..0bcc7bc4caa 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -51,7 +51,8 @@ def test_read_flo(tmpdir): torch.testing.assert_close(actual, expected) -# This cannot be defined on the TestOnlineResource class, since it is used in a `@pytest.mark.parametrize` decorator +# This cannot be defined on the TestOnlineResource class, since it is used in a `@pytest.mark.parametrize` decorator on +# a method on said class def _decompress_gz(path): file = path.with_name(path.name.replace(".gz", "")) with gzip.open(path, "rb") as rfh, open(file, "wb") as wfh: From c0ecb46daf342f20e4132f405381f02888c663a6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 17 May 2022 10:26:40 +0200 Subject: [PATCH 8/9] mypy --- torchvision/prototype/datasets/utils/_resource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index b450a94e38f..3c9b95cb498 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -2,7 +2,7 @@ import hashlib import itertools import pathlib -from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn +from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set from urllib.parse import urlparse from torchdata.datapipes.iter import ( @@ -95,7 +95,7 @@ def load( # is not sufficient for files with multiple suffixes, e.g. foo.tar.gz. stem = path.name.replace("".join(path.suffixes), "") - def find_candidates(): + def find_candidates() -> Set[pathlib.Path]: # Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder # candidate simultaneously, that would also pick up other files that share the same prefix. For example, the # test split of the stanford-cars dataset uses the files From 82af340d60293593130d0c59c0f780d61c60d77f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 17 May 2022 14:29:26 +0200 Subject: [PATCH 9/9] use builtin decompress utility --- test/test_prototype_datasets_utils.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index 0bcc7bc4caa..8790b1638f9 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -8,6 +8,7 @@ from datasets_utils import make_fake_flo_file, make_tar from torchdata.datapipes.iter import FileOpener, TarArchiveLoader from torchvision.datasets._optical_flow import _read_flo as read_flo_ref +from torchvision.datasets.utils import _decompress from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource from torchvision.prototype.datasets.utils._internal import read_flo, fromfile @@ -51,16 +52,6 @@ def test_read_flo(tmpdir): torch.testing.assert_close(actual, expected) -# This cannot be defined on the TestOnlineResource class, since it is used in a `@pytest.mark.parametrize` decorator on -# a method on said class -def _decompress_gz(path): - file = path.with_name(path.name.replace(".gz", "")) - with gzip.open(path, "rb") as rfh, open(file, "wb") as wfh: - wfh.write(rfh.read()) - path.unlink() - return file - - class TestOnlineResource: class DummyResource(OnlineResource): def __init__(self, download_fn=None, **kwargs): @@ -176,7 +167,13 @@ def download_fn(resource, root): assert download_fn_was_called, "`download_fn()` was never called" # This tests the `"decompress"` literal as well as a custom callable - @pytest.mark.parametrize("preprocess", ["decompress", _decompress_gz]) + @pytest.mark.parametrize( + "preprocess", + [ + "decompress", + lambda path: _decompress(str(path), remove_finished=True), + ], + ) def test_preprocess_decompress(self, tmp_path, preprocess): file_name = "file.txt.gz" content = "sentinel"