77
88from torchaudio_unittest import common_utils
99from .compliance import utils as compliance_utils
10+ from parameterized import parameterized
1011
1112
1213def extract_window (window , wave , f , frame_length , frame_shift , snip_edges ):
@@ -182,20 +183,26 @@ def get_output_fn(sound, args):
182183
183184 self ._compliance_test_helper (self .test2_filepath , 'resample' , 32 , 3 , get_output_fn , atol = 1e-2 , rtol = 1e-5 )
184185
185- def test_resample_waveform_upsample_size (self ):
186- upsample_sound = kaldi .resample_waveform (self .test1_signal , self .test1_signal_sr , self .test1_signal_sr * 2 )
186+ @parameterized .expand ([("sinc_interpolation" ), ("kaiser_window" )])
187+ def test_resample_waveform_upsample_size (self , resampling_method ):
188+ upsample_sound = kaldi .resample_waveform (self .test1_signal , self .test1_signal_sr , self .test1_signal_sr * 2 ,
189+ resampling_method = resampling_method )
187190 self .assertTrue (upsample_sound .size (- 1 ) == self .test1_signal .size (- 1 ) * 2 )
188191
189- def test_resample_waveform_downsample_size (self ):
190- downsample_sound = kaldi .resample_waveform (self .test1_signal , self .test1_signal_sr , self .test1_signal_sr // 2 )
192+ @parameterized .expand ([("sinc_interpolation" ), ("kaiser_window" )])
193+ def test_resample_waveform_downsample_size (self , resampling_method ):
194+ downsample_sound = kaldi .resample_waveform (self .test1_signal , self .test1_signal_sr , self .test1_signal_sr // 2 ,
195+ resampling_method = resampling_method )
191196 self .assertTrue (downsample_sound .size (- 1 ) == self .test1_signal .size (- 1 ) // 2 )
192197
193- def test_resample_waveform_identity_size (self ):
194- downsample_sound = kaldi .resample_waveform (self .test1_signal , self .test1_signal_sr , self .test1_signal_sr )
198+ @parameterized .expand ([("sinc_interpolation" ), ("kaiser_window" )])
199+ def test_resample_waveform_identity_size (self , resampling_method ):
200+ downsample_sound = kaldi .resample_waveform (self .test1_signal , self .test1_signal_sr , self .test1_signal_sr ,
201+ resampling_method = resampling_method )
195202 self .assertTrue (downsample_sound .size (- 1 ) == self .test1_signal .size (- 1 ))
196203
197204 def _test_resample_waveform_accuracy (self , up_scale_factor = None , down_scale_factor = None ,
198- atol = 1e-1 , rtol = 1e-4 ):
205+ resampling_method = "sinc_interpolation" , atol = 1e-1 , rtol = 1e-4 ):
199206 # resample the signal and compare it to the ground truth
200207 n_to_trim = 20
201208 sample_rate = 1000
@@ -211,7 +218,8 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
211218 original_timestamps = torch .arange (0 , duration , 1.0 / sample_rate )
212219
213220 sound = 123 * torch .cos (2 * math .pi * 3 * original_timestamps ).unsqueeze (0 )
214- estimate = kaldi .resample_waveform (sound , sample_rate , new_sample_rate ).squeeze ()
221+ estimate = kaldi .resample_waveform (sound , sample_rate , new_sample_rate ,
222+ resampling_method = resampling_method ).squeeze ()
215223
216224 new_timestamps = torch .arange (0 , duration , 1.0 / new_sample_rate )[:estimate .size (0 )]
217225 ground_truth = 123 * torch .cos (2 * math .pi * 3 * new_timestamps )
@@ -222,27 +230,32 @@ def _test_resample_waveform_accuracy(self, up_scale_factor=None, down_scale_fact
222230
223231 self .assertEqual (estimate , ground_truth , atol = atol , rtol = rtol )
224232
225- def test_resample_waveform_downsample_accuracy (self ):
233+ @parameterized .expand ([("sinc_interpolation" ), ("kaiser_window" )])
234+ def test_resample_waveform_downsample_accuracy (self , resampling_method ):
226235 for i in range (1 , 20 ):
227- self ._test_resample_waveform_accuracy (down_scale_factor = i * 2 )
236+ self ._test_resample_waveform_accuracy (down_scale_factor = i * 2 , resampling_method = resampling_method )
228237
229- def test_resample_waveform_upsample_accuracy (self ):
238+ @parameterized .expand ([("sinc_interpolation" ), ("kaiser_window" )])
239+ def test_resample_waveform_upsample_accuracy (self , resampling_method ):
230240 for i in range (1 , 20 ):
231- self ._test_resample_waveform_accuracy (up_scale_factor = 1.0 + i / 20.0 )
241+ self ._test_resample_waveform_accuracy (up_scale_factor = 1.0 + i / 20.0 , resampling_method = resampling_method )
232242
233- def test_resample_waveform_multi_channel (self ):
243+ @parameterized .expand ([("sinc_interpolation" ), ("kaiser_window" )])
244+ def test_resample_waveform_multi_channel (self , resampling_method ):
234245 num_channels = 3
235246
236247 multi_sound = self .test1_signal .repeat (num_channels , 1 ) # (num_channels, 8000 smp)
237248
238249 for i in range (num_channels ):
239250 multi_sound [i , :] *= (i + 1 ) * 1.5
240251
241- multi_sound_sampled = kaldi .resample_waveform (multi_sound , self .test1_signal_sr , self .test1_signal_sr // 2 )
252+ multi_sound_sampled = kaldi .resample_waveform (multi_sound , self .test1_signal_sr , self .test1_signal_sr // 2 ,
253+ resampling_method = resampling_method )
242254
243255 # check that sampling is same whether using separately or in a tensor of size (c, n)
244256 for i in range (num_channels ):
245257 single_channel = self .test1_signal * (i + 1 ) * 1.5
246258 single_channel_sampled = kaldi .resample_waveform (single_channel , self .test1_signal_sr ,
247- self .test1_signal_sr // 2 )
259+ self .test1_signal_sr // 2 ,
260+ resampling_method = resampling_method )
248261 self .assertEqual (multi_sound_sampled [i , :], single_channel_sampled [0 ], rtol = 1e-4 , atol = 1e-7 )
0 commit comments