66from torch .utils .data .datapipes .iter .grouping import ShardingFilterIterDataPipe as ShardingFilter
77from torch .utils .data .graph import traverse
88from torchdata .datapipes .iter import IterDataPipe , Shuffler
9- from torchvision .prototype import transforms
9+ from torchvision .prototype import transforms , datasets
1010from torchvision .prototype .utils ._internal import sequence_to_str
1111
1212
13- @parametrize_dataset_mocks (DATASET_MOCKS )
def test_coverage():
    """Check that every dataset reported by ``datasets.list()`` has a mock.

    Raises:
        AssertionError: If at least one listed dataset name is not a key of
            ``DATASET_MOCKS``. The message lists the offending names in
            sorted order.
    """
    missing = set(datasets.list()) - DATASET_MOCKS.keys()
    if not missing:
        return

    # Sort so the error message is stable regardless of set iteration order.
    missing_names = sequence_to_str(sorted(missing), separate_last="and ")
    raise AssertionError(
        f"The dataset(s) {missing_names} "
        "are exposed through `torchvision.prototype.datasets.load()`, but are not tested. "
        "Please add mock data to `test/builtin_dataset_mocks.py`."
    )
21+
22+
1423class TestCommon :
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, dataset_mock, config):
    """Load the mocked dataset and check that the result is an ``IterDataPipe``.

    Raises:
        AssertionError: If the first item returned by ``dataset_mock.load(config)``
            is not an ``IterDataPipe`` instance.
    """
    loaded, _mock_info = dataset_mock.load(config)
    if isinstance(loaded, IterDataPipe):
        return
    raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(loaded)} instead.")
1929
30+ @parametrize_dataset_mocks (DATASET_MOCKS )
2031 def test_sample (self , dataset_mock , config ):
2132 dataset , _ = dataset_mock .load (config )
2233
@@ -31,6 +42,7 @@ def test_sample(self, dataset_mock, config):
3142 if not sample :
3243 raise AssertionError ("Sample dictionary is empty." )
3344
45+ @parametrize_dataset_mocks (DATASET_MOCKS )
3446 def test_num_samples (self , dataset_mock , config ):
3547 dataset , mock_info = dataset_mock .load (config )
3648
@@ -40,6 +52,7 @@ def test_num_samples(self, dataset_mock, config):
4052
4153 assert num_samples == mock_info ["num_samples" ]
4254
55+ @parametrize_dataset_mocks (DATASET_MOCKS )
4356 def test_decoding (self , dataset_mock , config ):
4457 dataset , _ = dataset_mock .load (config )
4558
@@ -50,6 +63,7 @@ def test_decoding(self, dataset_mock, config):
5063 f"{ sequence_to_str (sorted (undecoded_features ), separate_last = 'and ' )} were not decoded."
5164 )
5265
66+ @parametrize_dataset_mocks (DATASET_MOCKS )
5367 def test_no_vanilla_tensors (self , dataset_mock , config ):
5468 dataset , _ = dataset_mock .load (config )
5569
@@ -60,16 +74,33 @@ def test_no_vanilla_tensors(self, dataset_mock, config):
6074 f"{ sequence_to_str (sorted (vanilla_tensors ), separate_last = 'and ' )} contained vanilla tensors."
6175 )
6276
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config):
    """Check that a sample can be drawn after mapping ``transforms.Identity()``.

    The test passes as long as mapping the transform over the loaded dataset
    and pulling the first item does not raise.
    """
    datapipe, _mock_info = dataset_mock.load(config)

    transformed = datapipe.map(transforms.Identity())
    samples = iter(transformed)
    # Only the first sample is consumed; its value is intentionally discarded.
    next(samples)
6782
@parametrize_dataset_mocks(
    DATASET_MOCKS,
    # The "cub200" mock is marked as an expected failure; see the linked discussion.
    marks=dict(
        cub200=pytest.mark.xfail(reason="See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"),
    ),
)
def test_traversable(self, dataset_mock, config):
    """Check that ``traverse`` can be called on the loaded dataset without raising."""
    datapipe, _mock_info = dataset_mock.load(config)

    traverse(datapipe)
7295
96+ @parametrize_dataset_mocks (
97+ DATASET_MOCKS ,
98+ marks = {
99+ "cub200" : pytest .mark .xfail (
100+ reason = "See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"
101+ )
102+ },
103+ )
73104 @pytest .mark .parametrize ("annotation_dp_type" , (Shuffler , ShardingFilter ), ids = lambda type : type .__name__ )
74105 def test_has_annotations (self , dataset_mock , config , annotation_dp_type ):
75106 def scan (graph ):
@@ -86,8 +117,8 @@ def scan(graph):
86117 raise AssertionError (f"The dataset doesn't comprise a { annotation_dp_type .__name__ } () datapipe." )
87118
88119
120+ @parametrize_dataset_mocks (DATASET_MOCKS ["qmnist" ])
89121class TestQMNIST :
90- @parametrize_dataset_mocks ([mock for mock in DATASET_MOCKS if mock .name == "qmnist" ])
91122 def test_extra_label (self , dataset_mock , config ):
92123 dataset , _ = dataset_mock .load (config )
93124
0 commit comments