@@ -47,14 +47,25 @@ def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
4747
4848
4949@common_utils .skipIfNoSoxBackend
50- class Test_Kaldi (common_utils .TorchaudioTestCase ):
50+ class Test_Kaldi (common_utils .TempDirMixin , common_utils . TorchaudioTestCase ):
5151 backend = 'sox'
5252
53- test_filepath = common_utils .get_asset_path ('kaldi_file.wav' )
54- test_8000_filepath = common_utils .get_asset_path ('kaldi_file_8000.wav' )
5553 kaldi_output_dir = common_utils .get_asset_path ('kaldi' )
54+ test_filepath = common_utils .get_asset_path ('kaldi_file.wav' )
5655 test_filepaths = {prefix : [] for prefix in compliance_utils .TEST_PREFIX }
5756
57+ def setUp (self ):
58+ super ().setUp ()
59+
60+ # 1. test signal for testing resampling
61+ self .test1_signal_sr = 16000
62+ self .test1_signal = common_utils .get_whitenoise (
63+ sample_rate = self .test1_signal_sr , duration = 0.5 ,
64+ )
65+
66+ # 2. test audio file corresponding to saved kaldi ark files
67+ self .test2_filepath = common_utils .get_asset_path ('kaldi_file_8000.wav' )
68+
5869 # separating test files by their types (e.g. 'spec', 'fbank', etc.)
5970 for f in os .listdir (kaldi_output_dir ):
6071 dash_idx = f .find ('-' )
@@ -94,7 +105,6 @@ def test_get_strided(self):
94105
95106 def _create_data_set (self ):
96107 # Used to generate the dataset to test on. This is an offline procedure and is not run as part of testing.
97- test_filepath = common_utils .get_asset_path ('kaldi_file.wav' )
98108 sr = 16000
99109 x = torch .arange (0 , 20 ).float ()
100110 # between [-6,6]
@@ -103,8 +113,8 @@ def _create_data_set(self):
103113 y = (y / 6 * (1 << 30 )).long ()
104114 # clear the lowest 16 bits because they aren't used anyway
105115 y = ((y >> 16 ) << 16 ).float ()
106- torchaudio .save (test_filepath , y , sr )
107- sound , sample_rate = torchaudio .load (test_filepath , normalization = False )
116+ torchaudio .save (self . test_filepath , y , sr )
117+ sound , sample_rate = torchaudio .load (self . test_filepath , normalization = False )
108118 print (y >> 16 )
109119 self .assertTrue (sample_rate == sr )
110120 torch .testing .assert_allclose (y , sound )
@@ -123,7 +133,7 @@ def _print_diagnostic(self, output, expect_output):
123133 print ('relative_mse:' , relative_mse .item (), 'relative_max_error:' , relative_max_error .item ())
124134
125135 def _compliance_test_helper (self , sound_filepath , filepath_key , expected_num_files ,
126- expected_num_args , get_output_fn , atol = 1e-5 , rtol = 1e-8 ):
136+ expected_num_args , get_output_fn , atol = 1e-5 , rtol = 1e-7 ):
127137 """
128138 Inputs:
129139 sound_filepath (str): The location of the sound file
@@ -135,7 +145,7 @@ def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_fil
135145 atol (float): absolute tolerance
136146 rtol (float): relative tolerance
137147 """
138- sound , sample_rate = torchaudio .load_wav (sound_filepath )
148+ sound , sr = torchaudio .load_wav (sound_filepath )
139149 files = self .test_filepaths [filepath_key ]
140150
141151 assert len (files ) == expected_num_files , ('number of kaldi %s file changed to %d' % (filepath_key , len (files )))
@@ -170,22 +180,19 @@ def get_output_fn(sound, args):
170180 output = kaldi .resample_waveform (sound , args [1 ], args [2 ])
171181 return output
172182
173- self ._compliance_test_helper (self .test_8000_filepath , 'resample' , 32 , 3 , get_output_fn , atol = 1e-2 , rtol = 1e-5 )
183+ self ._compliance_test_helper (self .test2_filepath , 'resample' , 32 , 3 , get_output_fn , atol = 1e-2 , rtol = 1e-5 )
174184
175185 def test_resample_waveform_upsample_size (self ):
176- sound , sample_rate = torchaudio .load_wav (self .test_8000_filepath )
177- upsample_sound = kaldi .resample_waveform (sound , sample_rate , sample_rate * 2 )
178- self .assertTrue (upsample_sound .size (- 1 ) == sound .size (- 1 ) * 2 )
186+ upsample_sound = kaldi .resample_waveform (self .test1_signal , self .test1_signal_sr , self .test1_signal_sr * 2 )
187+ self .assertTrue (upsample_sound .size (- 1 ) == self .test1_signal .size (- 1 ) * 2 )
179188
180189 def test_resample_waveform_downsample_size (self ):
181- sound , sample_rate = torchaudio .load_wav (self .test_8000_filepath )
182- downsample_sound = kaldi .resample_waveform (sound , sample_rate , sample_rate // 2 )
183- self .assertTrue (downsample_sound .size (- 1 ) == sound .size (- 1 ) // 2 )
190+ downsample_sound = kaldi .resample_waveform (self .test1_signal , self .test1_signal_sr , self .test1_signal_sr // 2 )
191+ self .assertTrue (downsample_sound .size (- 1 ) == self .test1_signal .size (- 1 ) // 2 )
184192
185193 def test_resample_waveform_identity_size (self ):
186- sound , sample_rate = torchaudio .load_wav (self .test_8000_filepath )
187- downsample_sound = kaldi .resample_waveform (sound , sample_rate , sample_rate )
188- self .assertTrue (downsample_sound .size (- 1 ) == sound .size (- 1 ))
194+ downsample_sound = kaldi .resample_waveform (self .test1_signal , self .test1_signal_sr , self .test1_signal_sr )
195+ self .assertTrue (downsample_sound .size (- 1 ) == self .test1_signal .size (- 1 ))
189196
190197 def _test_resample_waveform_accuracy (self , up_scale_factor = None , down_scale_factor = None ,
191198 atol = 1e-1 , rtol = 1e-4 ):
@@ -226,19 +233,19 @@ def test_resample_waveform_upsample_accuracy(self):
226233 def test_resample_waveform_multi_channel (self ):
227234 num_channels = 3
228235
229- sound , sample_rate = torchaudio .load_wav (self .test_8000_filepath ) # (1, 8000)
230- multi_sound = sound .repeat (num_channels , 1 ) # (num_channels, 8000)
236+ multi_sound = self .test1_signal .repeat (num_channels , 1 ) # (num_channels, 8000 samples)
231237
232238 for i in range (num_channels ):
233239 multi_sound [i , :] *= (i + 1 ) * 1.5
234240
235- multi_sound_sampled = kaldi .resample_waveform (multi_sound , sample_rate , sample_rate // 2 )
241+ multi_sound_sampled = kaldi .resample_waveform (multi_sound , self . test1_signal_sr , self . test1_signal_sr // 2 )
236242
237243 # check that resampling gives the same result whether each channel is resampled separately or all together in a tensor of size (c, n)
238244 for i in range (num_channels ):
239- single_channel = sound * (i + 1 ) * 1.5
240- single_channel_sampled = kaldi .resample_waveform (single_channel , sample_rate , sample_rate // 2 )
241- torch .testing .assert_allclose (multi_sound_sampled [i , :], single_channel_sampled [0 ], rtol = 1e-4 , atol = 1e-8 )
245+ single_channel = self .test1_signal * (i + 1 ) * 1.5
246+ single_channel_sampled = kaldi .resample_waveform (single_channel , self .test1_signal_sr ,
247+ self .test1_signal_sr // 2 )
248+ torch .testing .assert_allclose (multi_sound_sampled [i , :], single_channel_sampled [0 ], rtol = 1e-4 , atol = 1e-7 )
242249
243250
244251if __name__ == '__main__' :
0 commit comments