@@ -291,3 +291,47 @@ def test_channels_first(self, channels_first):
291291 found , _ = sox_io_backend .load (self .path , channels_first = channels_first )
292292 expected = self .original if channels_first else self .original .transpose (1 , 0 )
293293 self .assertEqual (found , expected )
294+
295+
296+ @skipIfNoExec ('sox' )
297+ @skipIfNoExtension
298+ class TestSampleRate (TempDirMixin , PytorchTestCase ):
299+ """Test the correctness of frame parameters of `sox_io_backend.load`"""
300+ path = None
301+
302+ def setUp (self ):
303+ super ().setUp ()
304+ sample_rate = 16000
305+ original = get_wav_data ('float32' , num_channels = 2 )
306+ self .path = self .get_temp_path ('original.wave' )
307+ save_wav (self .path , original , sample_rate )
308+
309+ @parameterized .expand ([(8000 , ), (44100 , )], name_func = name_func )
310+ def test_sample_rate (self , sample_rate ):
311+ """sample_rate changes sample rate"""
312+ found , rate = sox_io_backend .load (self .path , sample_rate = sample_rate )
313+ ref_path = self .get_temp_path ('reference.wav' )
314+ sox_utils .run_sox_effect (self .path , ref_path , ['rate' , f'{ sample_rate } ' ])
315+ expected , expected_rate = load_wav (ref_path )
316+
317+ assert rate == expected_rate
318+ self .assertEqual (found , expected )
319+
320+ @parameterized .expand (list (itertools .product (
321+ [8000 , 44100 ],
322+ [0 , 1 , 10 , 100 , 1000 ],
323+ [- 1 , 1 , 10 , 100 , 1000 ],
324+ )), name_func = name_func )
325+ def test_frame (self , sample_rate , frame_offset , num_frames ):
326+ """frame_offset and num_frames applied after sample_rate"""
327+ found , rate = sox_io_backend .load (
328+ self .path , frame_offset = frame_offset , num_frames = num_frames , sample_rate = sample_rate )
329+
330+ ref_path = self .get_temp_path ('reference.wav' )
331+ sox_utils .run_sox_effect (self .path , ref_path , ['rate' , f'{ sample_rate } ' ])
332+ reference , expected_rate = load_wav (ref_path )
333+ frame_end = None if num_frames == - 1 else frame_offset + num_frames
334+ expected = reference [:, frame_offset :frame_end ]
335+
336+ assert rate == expected_rate
337+ self .assertEqual (found , expected )
0 commit comments