11import functools
22import io
33import pickle
4+ from collections import deque
45from pathlib import Path
56
67import pytest
1112from torch .utils .data .graph import traverse_dps
1213from torch .utils .data .graph_settings import get_all_graph_pipes
1314from torchdata .datapipes .iter import ShardingFilter , Shuffler
15+ from torchdata .datapipes .utils import StreamWrapper
1416from torchvision ._utils import sequence_to_str
15- from torchvision .prototype import datasets , transforms
17+ from torchvision .prototype import datasets , features , transforms
1618from torchvision .prototype .datasets .utils ._internal import INFINITE_BUFFER_SIZE
17- from torchvision . prototype . features import Image , Label
19+
1820
1921assert_samples_equal = functools .partial (
2022 assert_equal , pair_types = (TensorLikePair , ObjectPair ), rtol = 0 , atol = 0 , equal_nan = True
@@ -25,6 +27,17 @@ def extract_datapipes(dp):
2527 return get_all_graph_pipes (traverse_dps (dp ))
2628
2729
def consume(iterator):
    """Exhaust *iterator*, discarding every item.

    Feeding the iterator into a zero-length deque drains it at C speed;
    this is the ``consume`` recipe from the official itertools docs:
    https://docs.python.org/3/library/itertools.html#itertools-recipes
    """
    deque(iterator, maxlen=0)
33+
34+
def next_consume(iterator):
    """Return the first item of *iterator* and drain whatever remains.

    Useful for tests that only need one sample but still want the pipeline
    to run to completion (e.g. so all streams get closed).
    """
    first = next(iterator)
    deque(iterator, maxlen=0)
    return first
39+
40+
2841@pytest .fixture (autouse = True )
2942def test_home (mocker , tmp_path ):
3043 mocker .patch ("torchvision.prototype.datasets._api.home" , return_value = str (tmp_path ))
@@ -66,7 +79,7 @@ def test_sample(self, dataset_mock, config):
6679 dataset , _ = dataset_mock .load (config )
6780
6881 try :
69- sample = next (iter (dataset ))
82+ sample = next_consume (iter (dataset ))
7083 except StopIteration :
7184 raise AssertionError ("Unable to draw any sample." ) from None
7285 except Exception as error :
@@ -84,22 +97,53 @@ def test_num_samples(self, dataset_mock, config):
8497
8598 assert len (list (dataset )) == mock_info ["num_samples" ]
8699
100+ @pytest .fixture
101+ def log_session_streams (self ):
102+ debug_unclosed_streams = StreamWrapper .debug_unclosed_streams
103+ try :
104+ StreamWrapper .debug_unclosed_streams = True
105+ yield
106+ finally :
107+ StreamWrapper .debug_unclosed_streams = debug_unclosed_streams
108+
87109 @parametrize_dataset_mocks (DATASET_MOCKS )
88- def test_no_vanilla_tensors (self , dataset_mock , config ):
110+ def test_stream_closing (self , log_session_streams , dataset_mock , config ):
111+ def make_msg_and_close (head ):
112+ unclosed_streams = []
113+ for stream in StreamWrapper .session_streams .keys ():
114+ unclosed_streams .append (repr (stream .file_obj ))
115+ stream .close ()
116+ unclosed_streams = "\n " .join (unclosed_streams )
117+ return f"{ head } \n \n { unclosed_streams } "
118+
119+ if StreamWrapper .session_streams :
120+ raise pytest .UsageError (make_msg_and_close ("A previous test did not close the following streams:" ))
121+
89122 dataset , _ = dataset_mock .load (config )
90123
91- vanilla_tensors = {key for key , value in next (iter (dataset )).items () if type (value ) is torch .Tensor }
92- if vanilla_tensors :
124+ consume (iter (dataset ))
125+
126+ if StreamWrapper .session_streams :
127+ raise AssertionError (make_msg_and_close ("The following streams were not closed after a full iteration:" ))
128+
129+ @parametrize_dataset_mocks (DATASET_MOCKS )
130+ def test_no_simple_tensors (self , dataset_mock , config ):
131+ dataset , _ = dataset_mock .load (config )
132+
133+ simple_tensors = {key for key , value in next_consume (iter (dataset )).items () if features .is_simple_tensor (value )}
134+ if simple_tensors :
93135 raise AssertionError (
94136 f"The values of key(s) "
95- f"{ sequence_to_str (sorted (vanilla_tensors ), separate_last = 'and ' )} contained vanilla tensors."
137+ f"{ sequence_to_str (sorted (simple_tensors ), separate_last = 'and ' )} contained simple tensors."
96138 )
97139
98140 @parametrize_dataset_mocks (DATASET_MOCKS )
99141 def test_transformable (self , dataset_mock , config ):
100142 dataset , _ = dataset_mock .load (config )
101143
102- next (iter (dataset .map (transforms .Identity ())))
144+ dataset = dataset .map (transforms .Identity ())
145+
146+ consume (iter (dataset ))
103147
104148 @parametrize_dataset_mocks (DATASET_MOCKS )
105149 def test_traversable (self , dataset_mock , config ):
@@ -131,7 +175,7 @@ def test_data_loader(self, dataset_mock, config, num_workers):
131175 collate_fn = self ._collate_fn ,
132176 )
133177
134- next ( iter ( dl ) )
178+ consume ( dl )
135179
136180 # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
137181 # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
@@ -148,7 +192,7 @@ def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
148192 def test_save_load (self , dataset_mock , config ):
149193 dataset , _ = dataset_mock .load (config )
150194
151- sample = next (iter (dataset ))
195+ sample = next_consume (iter (dataset ))
152196
153197 with io .BytesIO () as buffer :
154198 torch .save (sample , buffer )
@@ -177,7 +221,7 @@ class TestQMNIST:
177221 def test_extra_label (self , dataset_mock , config ):
178222 dataset , _ = dataset_mock .load (config )
179223
180- sample = next (iter (dataset ))
224+ sample = next_consume (iter (dataset ))
181225 for key , type in (
182226 ("nist_hsf_series" , int ),
183227 ("nist_writer_id" , int ),
@@ -214,7 +258,7 @@ def test_sample_content(self, dataset_mock, config):
214258 assert "image" in sample
215259 assert "label" in sample
216260
217- assert isinstance (sample ["image" ], Image )
218- assert isinstance (sample ["label" ], Label )
261+ assert isinstance (sample ["image" ], features . Image )
262+ assert isinstance (sample ["label" ], features . Label )
219263
220264 assert sample ["image" ].shape == (1 , 16 , 16 )
0 commit comments