Skip to content

Commit 1f8a9cc

Browse files
committed
add test
1 parent c1dd209 commit 1f8a9cc

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

test/builtin_dataset_mocks.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def _download(self, _):
4343
class DatasetMock:
4444
def __init__(self, name, mock_data_fn, *, configs=None):
4545
self.dataset = find(name)
46-
self.root = TEST_HOME / self.dataset.name
46+
self.home = TEST_HOME
47+
self.root = self.home / self.dataset.name
4748
self.mock_data_fn = self._parse_mock_data(mock_data_fn)
4849
self.configs = configs or self.info._configs
4950
self._cache = {}
@@ -100,7 +101,7 @@ def wrapper(info, root, config):
100101

101102
return wrapper
102103

103-
def _load_mock(self, config):
104+
def make_mock_resources(self, config):
104105
with contextlib.suppress(KeyError):
105106
return self._cache[config]
106107

@@ -120,7 +121,7 @@ def load(self, config, *, decoder=DEFAULT_DECODER):
120121
except ModuleNotFoundError as error:
121122
pytest.skip(str(error))
122123

123-
mock_resources, mock_info = self._load_mock(config)
124+
mock_resources, mock_info = self.make_mock_resources(config)
124125
datapipe = self.dataset._make_datapipe(
125126
[resource.load(self.root) for resource in mock_resources],
126127
config=config,

test/test_prototype_builtin_datasets.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
77
from torch.utils.data.graph import traverse
88
from torchdata.datapipes.iter import IterDataPipe, Shuffler
9-
from torchvision.prototype import transforms
9+
from torchvision.prototype import transforms, datasets
1010
from torchvision.prototype.utils._internal import sequence_to_str
1111

1212

@@ -85,6 +85,13 @@ def scan(graph):
8585
else:
8686
raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
8787

88+
def test_loadable_through_api(self, mocker, dataset_mock, config):
89+
# Make all resources that are necessary for the given config
90+
dataset_mock.make_mock_resources(config)
91+
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(dataset_mock.home))
92+
93+
datasets.load(dataset_mock.name, **config)
94+
8895

8996
class TestQMNIST:
9097
@parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "qmnist"])

0 commit comments

Comments
 (0)