Skip to content

Commit 4b767ee

Browse files
committed
Add sample_rate option to sox_io_backend.load
1 parent 90896fc commit 4b767ee

File tree

5 files changed

+83
-2
lines changed

5 files changed

+83
-2
lines changed

test/torchaudio_unittest/sox_io_backend/load_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,47 @@ def test_channels_first(self, channels_first):
261261
found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
262262
expected = self.original if channels_first else self.original.transpose(1, 0)
263263
self.assertEqual(found, expected)
264+
265+
266+
@skipIfNoExec('sox')
267+
@skipIfNoExtension
268+
class TestSampleRate(TempDirMixin, PytorchTestCase):
269+
"""Test the correctness of frame parameters of `sox_io_backend.load`"""
270+
path = None
271+
272+
def setUp(self):
273+
super().setUp()
274+
sample_rate = 16000
275+
original = get_wav_data('float32', num_channels=2)
276+
self.path = self.get_temp_path('original.wave')
277+
save_wav(self.path, original, sample_rate)
278+
279+
@parameterized.expand([(8000, ), (44100, )], name_func=name_func)
280+
def test_sample_rate(self, sample_rate):
281+
"""sample_rate changes sample rate"""
282+
found, rate = sox_io_backend.load(self.path, sample_rate=sample_rate)
283+
ref_path = self.get_temp_path('reference.wav')
284+
sox_utils.run_sox_effect(self.path, ref_path, ['rate', f'{sample_rate}'])
285+
expected, expected_rate = load_wav(ref_path)
286+
287+
assert rate == expected_rate
288+
self.assertEqual(found, expected)
289+
290+
@parameterized.expand(list(itertools.product(
291+
[8000, 44100],
292+
[0, 1, 10, 100, 1000],
293+
[-1, 1, 10, 100, 1000],
294+
)), name_func=name_func)
295+
def test_frame(self, sample_rate, frame_offset, num_frames):
296+
"""frame_offset and num_frames applied after sample_rate"""
297+
found, rate = sox_io_backend.load(
298+
self.path, frame_offset=frame_offset, num_frames=num_frames, sample_rate=sample_rate)
299+
300+
ref_path = self.get_temp_path('reference.wav')
301+
sox_utils.run_sox_effect(self.path, ref_path, ['rate', f'{sample_rate}'])
302+
reference, expected_rate = load_wav(ref_path)
303+
frame_end = None if num_frames == -1 else frame_offset + num_frames
304+
expected = reference[:, frame_offset:frame_end]
305+
306+
assert rate == expected_rate
307+
self.assertEqual(found, expected)

torchaudio/backend/sox_io_backend.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def load(
4040
num_frames: int = -1,
4141
normalize: bool = True,
4242
channels_first: bool = True,
43+
sample_rate: Optional[int] = None,
4344
) -> Tuple[torch.Tensor, int]:
4445
"""Load audio data from file.
4546
@@ -83,11 +84,13 @@ def load(
8384
Path to audio file
8485
frame_offset (int):
8586
Number of frames to skip before start reading data.
87+
If ``sample_rate`` is given, frame counts start after the audio is resampled.
8688
num_frames (int):
8789
Maximum number of frames to read. ``-1`` reads all the remaining samples,
8890
starting from ``frame_offset``.
8991
This function may return the less number of frames if there is not enough
9092
frames in the given file.
93+
If ``sample_rate`` is given, frame counts start after the audio is resampled.
9194
normalize (bool):
9295
When ``True``, this function always return ``float32``, and sample values are
9396
normalized to ``[-1.0, 1.0]``.
@@ -97,15 +100,18 @@ def load(
97100
channels_first (bool):
98101
When True, the returned Tensor has dimension ``[channel, time]``.
99102
Otherwise, the returned Tensor's dimension is ``[time, channel]``.
103+
sample_rate (int, optional):
104+
Perform resampling.
100105
101106
Returns:
102107
torch.Tensor:
103108
If the input file has integer wav format and normalization is off, then it has
104109
integer type, else ``float32`` type. If ``channels_first=True``, it has
105110
``[channel, time]`` else ``[time, channel]``.
106111
"""
107-
signal = torch.ops.torchaudio.sox_io_load_audio_file(
108-
filepath, frame_offset, num_frames, normalize, channels_first)
112+
sample_rate = -1 if sample_rate is None else sample_rate
113+
signal = torch.ops.torchaudio.sox_io_load_audio_file_v1(
114+
filepath, frame_offset, num_frames, normalize, channels_first, sample_rate)
109115
return signal.get_tensor(), signal.get_sample_rate()
110116

111117

torchaudio/csrc/register.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ TORCH_LIBRARY(torchaudio, m) {
5151
m.def(
5252
"torchaudio::sox_io_load_audio_file",
5353
&torchaudio::sox_io::load_audio_file);
54+
m.def(
55+
"torchaudio::sox_io_load_audio_file_v1",
56+
&torchaudio::sox_io::load_audio_file_v1);
5457
m.def(
5558
"torchaudio::sox_io_save_audio_file",
5659
&torchaudio::sox_io::save_audio_file);

torchaudio/csrc/sox_io.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
5353
const int64_t num_frames,
5454
const bool normalize,
5555
const bool channels_first) {
56+
return load_audio_file_v1(path, frame_offset, num_frames, channels_first, -1);
57+
}
58+
59+
c10::intrusive_ptr<TensorSignal> load_audio_file_v1(
60+
const std::string& path,
61+
const int64_t frame_offset,
62+
const int64_t num_frames,
63+
const bool normalize,
64+
const bool channels_first,
65+
const int64_t sample_rate) {
5666
if (frame_offset < 0) {
5767
throw std::runtime_error(
5868
"Invalid argument: frame_offset must be non-negative.");
@@ -61,8 +71,16 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
6171
throw std::runtime_error(
6272
"Invalid argument: num_frames must be -1 or greater than 0.");
6373
}
74+
if (sample_rate == 0 || sample_rate < -1) {
75+
throw std::runtime_error(
76+
"Invalid argument: sample_rate must be -1 or greater than 0.");
77+
}
6478

6579
std::vector<std::vector<std::string>> effects;
80+
if (sample_rate != -1) {
81+
effects.emplace_back(
82+
std::vector<std::string>{"rate", std::to_string(sample_rate)});
83+
}
6684
if (num_frames != -1) {
6785
std::ostringstream offset, frames;
6886
offset << frame_offset << "s";

torchaudio/csrc/sox_io.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,23 @@ struct SignalInfo : torch::CustomClassHolder {
2323

2424
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);
2525

26+
// ver. 0
2627
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
2728
const std::string& path,
2829
const int64_t frame_offset = 0,
2930
const int64_t num_frames = -1,
3031
const bool normalize = true,
3132
const bool channels_first = true);
3233

34+
// ver. 1 sample_rate is added
35+
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file_v1(
36+
const std::string& path,
37+
const int64_t frame_offset = 0,
38+
const int64_t num_frames = -1,
39+
const bool normalize = true,
40+
const bool channels_first = true,
41+
const int64_t sample_rate = -1);
42+
3343
void save_audio_file(
3444
const std::string& file_name,
3545
const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,

0 commit comments

Comments
 (0)