99from torchaudio_unittest import common_utils
1010
1111
12+ def _name_from_args (func , _ , params ):
13+ """Return a parameterized test name, based on parameter values."""
14+ return "{}_{}" .format (
15+ func .__name__ ,
16+ "_" .join (str (arg ) for arg in params .args ))
17+
18+
1219@parameterized_class ([
1320 # Single-item batch isolates problems that come purely from adding a
1421 # dimension (rather than processing multiple items)
@@ -58,7 +65,7 @@ def test_griffinlim(self):
5865 @parameterized .expand (list (itertools .product (
5966 [8000 , 16000 , 44100 ],
6067 [1 , 2 ],
61- )), name_func = lambda f , _ , p : f' { f . __name__ } _ { "_" . join ( str ( arg ) for arg in p . args ) } ' )
68+ )), name_func = _name_from_args )
6269 def test_detect_pitch_frequency (self , sample_rate , n_channels ):
6370 # Use different frequencies to ensure each item in the batch returns a
6471 # different answer.
@@ -180,16 +187,16 @@ def test_flanger(self):
180187 sample_rate = 44100
181188 self .assert_batch_consistency (F .flanger , waveforms , sample_rate )
182189
183- def test_sliding_window_cmn (self ):
184- waveforms = torch .randn (self .batch_size , 2 , 1024 ) - 0.5
185- self .assert_batch_consistency (
186- F .sliding_window_cmn , waveforms , center = True , norm_vars = True )
187- self .assert_batch_consistency (
188- F .sliding_window_cmn , waveforms , center = True , norm_vars = False )
189- self .assert_batch_consistency (
190- F .sliding_window_cmn , waveforms , center = False , norm_vars = True )
190+ @parameterized .expand (list (itertools .product (
191+ [True , False ], # center
192+ [True , False ], # norm_vars
193+ )), name_func = _name_from_args )
194+ def test_sliding_window_cmn (self , center , norm_vars ):
195+ torch .manual_seed (0 )
196+ spectrogram = torch .rand (self .batch_size , 2 , 1024 , 1024 ) * 200
191197 self .assert_batch_consistency (
192- F .sliding_window_cmn , waveforms , center = False , norm_vars = False )
198+ F .sliding_window_cmn , spectrogram , center = center ,
199+ norm_vars = norm_vars )
193200
194201 def test_vad_from_file (self ):
195202 filepath = common_utils .get_asset_path ("vad-go-stereo-44100.wav" )
@@ -202,6 +209,7 @@ def test_vad_from_file(self):
202209 def test_vad_different_items (self ):
203210 """Separate test to ensure VAD consistency with differing items."""
204211 sample_rate = 44100
212+ torch .manual_seed (0 )
205213 waveforms = torch .rand (self .batch_size , 2 , 100 ) - 0.5
206214 self .assert_batch_consistency (
207215 F .vad , waveforms , sample_rate = sample_rate )
0 commit comments