@@ -261,3 +261,47 @@ def test_channels_first(self, channels_first):
261261 found , _ = sox_io_backend .load (self .path , channels_first = channels_first )
262262 expected = self .original if channels_first else self .original .transpose (1 , 0 )
263263 self .assertEqual (found , expected )
264+
265+
266+ @skipIfNoExec ('sox' )
267+ @skipIfNoExtension
268+ class TestSampleRate (TempDirMixin , PytorchTestCase ):
269+ """Test the correctness of frame parameters of `sox_io_backend.load`"""
270+ path = None
271+
272+ def setUp (self ):
273+ super ().setUp ()
274+ sample_rate = 16000
275+ original = get_wav_data ('float32' , num_channels = 2 )
276+ self .path = self .get_temp_path ('original.wave' )
277+ save_wav (self .path , original , sample_rate )
278+
279+ @parameterized .expand ([(8000 , ), (44100 , )], name_func = name_func )
280+ def test_sample_rate (self , sample_rate ):
281+ """sample_rate changes sample rate"""
282+ found , rate = sox_io_backend .load (self .path , sample_rate = sample_rate )
283+ ref_path = self .get_temp_path ('reference.wav' )
284+ sox_utils .run_sox_effect (self .path , ref_path , ['rate' , f'{ sample_rate } ' ])
285+ expected , expected_rate = load_wav (ref_path )
286+
287+ assert rate == expected_rate
288+ self .assertEqual (found , expected )
289+
290+ @parameterized .expand (list (itertools .product (
291+ [8000 , 44100 ],
292+ [0 , 1 , 10 , 100 , 1000 ],
293+ [- 1 , 1 , 10 , 100 , 1000 ],
294+ )), name_func = name_func )
295+ def test_frame (self , sample_rate , frame_offset , num_frames ):
296+ """frame_offset and num_frames applied after sample_rate"""
297+ found , rate = sox_io_backend .load (
298+ self .path , frame_offset = frame_offset , num_frames = num_frames , sample_rate = sample_rate )
299+
300+ ref_path = self .get_temp_path ('reference.wav' )
301+ sox_utils .run_sox_effect (self .path , ref_path , ['rate' , f'{ sample_rate } ' ])
302+ reference , expected_rate = load_wav (ref_path )
303+ frame_end = None if num_frames == - 1 else frame_offset + num_frames
304+ expected = reference [:, frame_offset :frame_end ]
305+
306+ assert rate == expected_rate
307+ self .assertEqual (found , expected )
0 commit comments