1818from torch .utils .data import TensorDataset
1919
2020from pytorch_lightning .trainer .supporters import (
21+ _nested_calc_num_data ,
2122 CombinedDataset ,
2223 CombinedLoader ,
2324 CombinedLoaderIterator ,
@@ -61,7 +62,7 @@ def test_cycle_iterator():
6162def test_none_length_cycle_iterator ():
6263 """Test the infinite cycling function of `CycleIterator`"""
6364 iterator = CycleIterator (range (100 ))
64- assert iterator .__len__ () == float (' inf' )
65+ assert iterator .__len__ () == float (" inf" )
6566
6667 # test infinite loop
6768 for idx , item in enumerate (iterator ):
@@ -70,12 +71,15 @@ def test_none_length_cycle_iterator():
7071 assert item == 0
7172
7273
73- @pytest .mark .parametrize (['dataset_1' , 'dataset_2' ], [
74- ([list (range (10 )), list (range (20 ))]),
75- ([range (10 ), range (20 )]),
76- ([torch .randn (10 , 3 , 2 ), torch .randn (20 , 5 , 6 )]),
77- ([TensorDataset (torch .randn (10 , 3 , 2 )), TensorDataset (torch .randn (20 , 5 , 6 ))])
78- ])
74+ @pytest .mark .parametrize (
75+ ["dataset_1" , "dataset_2" ],
76+ [
77+ ([list (range (10 )), list (range (20 ))]),
78+ ([range (10 ), range (20 )]),
79+ ([torch .randn (10 , 3 , 2 ), torch .randn (20 , 5 , 6 )]),
80+ ([TensorDataset (torch .randn (10 , 3 , 2 )), TensorDataset (torch .randn (20 , 5 , 6 ))]),
81+ ],
82+ )
7983def test_combined_dataset (dataset_1 , dataset_2 ):
8084 """Verify the length of the CombinedDataset"""
8185 datasets = [dataset_1 , dataset_2 ]
@@ -86,83 +90,91 @@ def test_combined_dataset(dataset_1, dataset_2):
8690
8791
8892def test_combined_dataset_length_mode_error ():
89- with pytest .raises (MisconfigurationException , match = ' Invalid Mode' ):
90- CombinedDataset ._calc_num_data ([range (10 )], ' test' )
93+ with pytest .raises (MisconfigurationException , match = " Invalid Mode" ):
94+ CombinedDataset ._calc_num_data ([range (10 )], " test" )
9195
9296
9397def test_combined_loader_iterator_dict_min_size ():
9498 """Test `CombinedLoaderIterator` given mapping loaders"""
95- loaders = {'a' : torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
96- 'b' : torch .utils .data .DataLoader (range (20 ), batch_size = 5 )}
99+ loaders = {
100+ "a" : torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
101+ "b" : torch .utils .data .DataLoader (range (20 ), batch_size = 5 ),
102+ }
97103
98104 combined_iter = CombinedLoaderIterator (loaders )
99105
100106 for idx , item in enumerate (combined_iter ):
101107 assert isinstance (item , dict )
102108 assert len (item ) == 2
103- assert 'a' in item and 'b' in item
109+ assert "a" in item and "b" in item
104110
105- assert idx == min (len (loaders ['a' ]), len (loaders ['b' ])) - 1
111+ assert idx == min (len (loaders ["a" ]), len (loaders ["b" ])) - 1
106112
107113
108114def test_combined_loader_init_mode_error ():
109115 """Test the ValueError when constructing `CombinedLoader`"""
110- with pytest .raises (MisconfigurationException , match = ' selected unsupported mode' ):
111- CombinedLoader ([range (10 )], ' testtt' )
116+ with pytest .raises (MisconfigurationException , match = " selected unsupported mode" ):
117+ CombinedLoader ([range (10 )], " testtt" )
112118
113119
114120def test_combined_loader_loader_type_error ():
115121 """Test the ValueError when wrapping the loaders"""
116- with pytest .raises (ValueError , match = ' Invalid Datatype' ):
117- CombinedLoader (None , ' max_size_cycle' )
122+ with pytest .raises (ValueError , match = " Invalid Datatype" ):
123+ CombinedLoader (None , " max_size_cycle" )
118124
119125
120126def test_combined_loader_calc_length_mode_error ():
121127 """Test the ValueError when calculating the number of batches"""
122- with pytest .raises (TypeError , match = 'Got Type NoneType, but expected one of Sequence, int or Mapping' ):
128+ with pytest .raises (TypeError , match = "Expected data to be int, Sequence or Mapping, but got NoneType" ):
123129 CombinedLoader ._calc_num_batches (None )
124130
125131
126132def test_combined_loader_dict_min_size ():
127133 """Test `CombinedLoader` of mode 'min_size' given mapping loaders"""
128- loaders = {'a' : torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
129- 'b' : torch .utils .data .DataLoader (range (20 ), batch_size = 5 )}
134+ loaders = {
135+ "a" : torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
136+ "b" : torch .utils .data .DataLoader (range (20 ), batch_size = 5 ),
137+ }
130138
131- combined_loader = CombinedLoader (loaders , ' min_size' )
139+ combined_loader = CombinedLoader (loaders , " min_size" )
132140
133141 assert len (combined_loader ) == min ([len (v ) for v in loaders .values ()])
134142
135143 for idx , item in enumerate (combined_loader ):
136144 assert isinstance (item , dict )
137145 assert len (item ) == 2
138- assert 'a' in item and 'b' in item
146+ assert "a" in item and "b" in item
139147
140148 assert idx == len (combined_loader ) - 1
141149
142150
143151def test_combined_loader_dict_max_size_cycle ():
144152 """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders"""
145- loaders = {'a' : torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
146- 'b' : torch .utils .data .DataLoader (range (20 ), batch_size = 5 )}
153+ loaders = {
154+ "a" : torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
155+ "b" : torch .utils .data .DataLoader (range (20 ), batch_size = 5 ),
156+ }
147157
148- combined_loader = CombinedLoader (loaders , ' max_size_cycle' )
158+ combined_loader = CombinedLoader (loaders , " max_size_cycle" )
149159
150160 assert len (combined_loader ) == max ([len (v ) for v in loaders .values ()])
151161
152162 for idx , item in enumerate (combined_loader ):
153163 assert isinstance (item , dict )
154164 assert len (item ) == 2
155- assert 'a' in item and 'b' in item
165+ assert "a" in item and "b" in item
156166
157167 assert idx == len (combined_loader ) - 1
158168
159169
160170def test_combined_loader_sequence_min_size ():
161171 """Test `CombinedLoader` of mode 'min_size' given sequence loaders"""
162- loaders = [torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
163- torch .utils .data .DataLoader (range (20 ), batch_size = 5 )]
172+ loaders = [
173+ torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
174+ torch .utils .data .DataLoader (range (20 ), batch_size = 5 ),
175+ ]
164176
165- combined_loader = CombinedLoader (loaders , ' min_size' )
177+ combined_loader = CombinedLoader (loaders , " min_size" )
166178
167179 assert len (combined_loader ) == min ([len (v ) for v in loaders ])
168180
@@ -175,10 +187,12 @@ def test_combined_loader_sequence_min_size():
175187
176188def test_combined_loader_sequence_max_size_cycle ():
177189 """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders"""
178- loaders = [torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
179- torch .utils .data .DataLoader (range (20 ), batch_size = 5 )]
190+ loaders = [
191+ torch .utils .data .DataLoader (range (10 ), batch_size = 4 ),
192+ torch .utils .data .DataLoader (range (20 ), batch_size = 5 ),
193+ ]
180194
181- combined_loader = CombinedLoader (loaders , ' max_size_cycle' )
195+ combined_loader = CombinedLoader (loaders , " max_size_cycle" )
182196
183197 assert len (combined_loader ) == max ([len (v ) for v in loaders ])
184198
@@ -187,3 +201,22 @@ def test_combined_loader_sequence_max_size_cycle():
187201 assert len (item ) == 2
188202
189203 assert idx == len (combined_loader ) - 1
204+
205+
206+ @pytest .mark .parametrize (
207+ ["input_data" , "compute_func" , "expected_length" ],
208+ [
209+ ([* range (10 ), list (range (1 , 20 ))], min , 0 ),
210+ ([* range (10 ), list (range (1 , 20 ))], max , 19 ),
211+ ([* range (10 ), {str (i ): i for i in range (1 , 20 )}], min , 0 ),
212+ ([* range (10 ), {str (i ): i for i in range (1 , 20 )}], max , 19 ),
213+ ({** {str (i ): i for i in range (10 )}, "nested" : {str (i ): i for i in range (1 , 20 )}}, min , 0 ),
214+ ({** {str (i ): i for i in range (10 )}, "nested" : {str (i ): i for i in range (1 , 20 )}}, max , 19 ),
215+ ({** {str (i ): i for i in range (10 )}, "nested" : list (range (20 ))}, min , 0 ),
216+ ({** {str (i ): i for i in range (10 )}, "nested" : list (range (20 ))}, max , 19 ),
217+ ],
218+ )
219+ def test_nested_calc_num_data (input_data , compute_func , expected_length ):
220+ calculated_length = _nested_calc_num_data (input_data , compute_func )
221+
222+ assert calculated_length == expected_length
0 commit comments