Skip to content

Commit 793eeab

Browse files
authored
Add load function (#731)
This is a part of PRs to add new "sox_io" backend. #726 and depends on #718 and #728 . This PR adds `load` function to "sox_io" backend, which is tested on the following audio formats; - `wav` - `mp3` - `flac` - `ogg/vorbis` * By default, "sox_io" backend returns Tensor with `float32` dtype and the shape of `[channel, time]`. The samples are normalized to fit in the range of `[-1.0, 1.0]`. Unlike existing "sox" backend, the new `load` function can handle WAV file natively, when the input format is WAV with integer type, (such as 32-bit signed integer, 16-bit signed integer and 8-bit unsigned integer) by providing `normalize=False`, this function can return integer Tensor, where the samples are expressed within the whole range of the corresponding dtype, that is, `int32` tensor for `32-bit PCM`, `int16` for `16-bit PCM` and `uint8` for `8-bit PCM`. This behavior follows [scipy.io.wavfile.read](https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html). `normalize` parameter has no effect for other formats and the load function always return normalized value with `float32` Tensor. __* Note__ The current binary distribution of torchaudio does not contain `ogg/vorbis` and `opus` codecs. To handle these files, one needs to build torchaudio from the source with proper codecs in the system. __Note 2__ Since this PR, `scipy` becomes required module for running test.
1 parent 0f0d0af commit 793eeab

File tree

11 files changed

+772
-55
lines changed

11 files changed

+772
-55
lines changed

test/sox_io_backend/common.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,90 @@
1+
from typing import Optional
2+
3+
import torch
4+
import scipy.io.wavfile
5+
6+
17
def get_test_name(func, _, params):
28
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}'
9+
10+
11+
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
12+
if tensor.dtype == torch.float32:
13+
pass
14+
elif tensor.dtype == torch.int32:
15+
tensor = tensor.to(torch.float32)
16+
tensor[tensor > 0] /= 2147483647.
17+
tensor[tensor < 0] /= 2147483648.
18+
elif tensor.dtype == torch.int16:
19+
tensor = tensor.to(torch.float32)
20+
tensor[tensor > 0] /= 32767.
21+
tensor[tensor < 0] /= 32768.
22+
elif tensor.dtype == torch.uint8:
23+
tensor = tensor.to(torch.float32) - 128
24+
tensor[tensor > 0] /= 127.
25+
tensor[tensor < 0] /= 128.
26+
return tensor
27+
28+
29+
def get_wav_data(
30+
dtype: str,
31+
num_channels: int,
32+
*,
33+
num_frames: Optional[int] = None,
34+
normalize: bool = True,
35+
channels_first: bool = True,
36+
):
37+
"""Generate linear signal of the given dtype and num_channels
38+
39+
Data range is
40+
[-1.0, 1.0] for float32,
41+
[-2147483648, 2147483647] for int32
42+
[-32768, 32767] for int16
43+
[0, 255] for uint8
44+
45+
num_frames allow to change the linear interpolation parameter.
46+
Default values are 256 for uint8, else 1 << 16.
47+
1 << 16 as default is so that int16 value range is completely covered.
48+
"""
49+
dtype_ = getattr(torch, dtype)
50+
51+
if num_frames is None:
52+
if dtype == 'uint8':
53+
num_frames = 256
54+
else:
55+
num_frames = 1 << 16
56+
57+
if dtype == 'uint8':
58+
base = torch.linspace(0, 255, num_frames, dtype=dtype_)
59+
if dtype == 'float32':
60+
base = torch.linspace(-1., 1., num_frames, dtype=dtype_)
61+
if dtype == 'int32':
62+
base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_)
63+
if dtype == 'int16':
64+
base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_)
65+
data = base.repeat([num_channels, 1])
66+
if not channels_first:
67+
data = data.transpose(1, 0)
68+
if normalize:
69+
data = normalize_wav(data)
70+
return data
71+
72+
73+
def load_wav(path: str, normalize=True, channels_first=True) -> torch.Tensor:
74+
"""Load wav file without torchaudio"""
75+
sample_rate, data = scipy.io.wavfile.read(path)
76+
data = torch.from_numpy(data.copy())
77+
if data.ndim == 1:
78+
data = data.unsqueeze(1)
79+
if normalize:
80+
data = normalize_wav(data)
81+
if channels_first:
82+
data = data.transpose(1, 0)
83+
return data, sample_rate
84+
85+
86+
def save_wav(path, data, sample_rate, channels_first=True):
87+
"""Save wav file without torchaudio"""
88+
if channels_first:
89+
data = data.transpose(1, 0)
90+
scipy.io.wavfile.write(path, sample_rate, data.numpy())

test/sox_io_backend/sox_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def gen_audio_file(
2626
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1,
2727
):
2828
"""Generate synthetic audio file with `sox` command."""
29+
if path.endswith('.wav'):
30+
raise RuntimeError(
31+
'Use get_wav_data and save_wav to generate wav file for accurate result.')
2932
command = [
3033
'sox',
3134
'-V', # verbose
@@ -51,4 +54,17 @@ def gen_audio_file(
5154
command += ['vol', f'-{attenuation}dB']
5255
print(' '.join(command))
5356
subprocess.run(command, check=True)
54-
subprocess.run(['soxi', path], check=True)
57+
58+
59+
def convert_audio_file(
60+
src_path, dst_path,
61+
*, bit_depth=None, compression=None):
62+
"""Convert audio file with `sox` command."""
63+
command = ['sox', '-V', str(src_path)]
64+
if bit_depth is not None:
65+
command += ['--bits', str(bit_depth)]
66+
if compression is not None:
67+
command += ['--compression', str(compression)]
68+
command += [dst_path]
69+
print(' '.join(command))
70+
subprocess.run(command, check=True)

test/sox_io_backend/test_info.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
skipIfNoExtension,
1111
)
1212
from .common import (
13-
get_test_name
13+
get_test_name,
14+
get_wav_data,
15+
save_wav,
1416
)
1517
from . import sox_utils
1618

@@ -27,12 +29,8 @@ def test_wav(self, dtype, sample_rate, num_channels):
2729
"""`sox_io_backend.info` can check wav file correctly"""
2830
duration = 1
2931
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
30-
sox_utils.gen_audio_file(
31-
path, sample_rate, num_channels,
32-
bit_depth=sox_utils.get_bit_depth(dtype),
33-
encoding=sox_utils.get_encoding(dtype),
34-
duration=duration,
35-
)
32+
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
33+
save_wav(path, data, sample_rate)
3634
info = sox_io_backend.info(path)
3735
assert info.get_sample_rate() == sample_rate
3836
assert info.get_num_frames() == sample_rate * duration
@@ -47,12 +45,8 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
4745
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
4846
duration = 1
4947
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
50-
sox_utils.gen_audio_file(
51-
path, sample_rate, num_channels,
52-
bit_depth=sox_utils.get_bit_depth(dtype),
53-
encoding=sox_utils.get_encoding(dtype),
54-
duration=duration,
55-
)
48+
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
49+
save_wav(path, data, sample_rate)
5650
info = sox_io_backend.info(path)
5751
assert info.get_sample_rate() == sample_rate
5852
assert info.get_num_frames() == sample_rate * duration

0 commit comments

Comments
 (0)