Skip to content
Merged
23 changes: 23 additions & 0 deletions test/torchaudio_unittest/backend/soundfile/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt):

def parameterize(*params):
return parameterized.expand(list(itertools.product(*params)), name_func=name_func)


def fetch_wav_subtype(dtype, encoding, bits_per_sample):
    """Look up the soundfile subtype a WAV save is expected to use.

    Args:
        dtype: Passed to ``dtype2subtype``; its result is the expected
            subtype when neither ``encoding`` nor ``bits_per_sample`` is given.
        encoding: Requested encoding name, or ``None``.
        bits_per_sample: Requested bit depth, or ``None``.

    Returns:
        The expected subtype string for the given combination.

    Raises:
        ValueError: If the ``(encoding, bits_per_sample)`` pair is not in
            the table below.
    """
    # Outer key: encoding. Inner key: bits_per_sample.
    expected_by_encoding = {
        None: {
            None: dtype2subtype(dtype),
            8: "PCM_U8",
        },
        'PCM_U': {
            None: "PCM_U8",
            8: "PCM_U8",
        },
        'PCM_S': {
            None: "PCM_32",
            16: "PCM_16",
            32: "PCM_32",
        },
        'PCM_F': {
            None: "FLOAT",
            32: "FLOAT",
            64: "DOUBLE",
        },
        'ULAW': {
            None: "ULAW",
            8: "ULAW",
        },
        'ALAW': {
            None: "ALAW",
            8: "ALAW",
        },
    }
    expected = expected_by_encoding.get(encoding, {}).get(bits_per_sample)
    if not expected:
        raise ValueError(
            f"wav does not support ({encoding}, {bits_per_sample}).")
    return expected
67 changes: 54 additions & 13 deletions test/torchaudio_unittest/backend/soundfile/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
get_wav_data,
load_wav,
)
from .common import parameterize, dtype2subtype, skipIfFormatNotSupported
from .common import (
fetch_wav_subtype,
parameterize,
skipIfFormatNotSupported,
)

if _mod_utils.is_module_available("soundfile"):
import soundfile
Expand All @@ -20,36 +24,56 @@
class MockedSaveTest(PytorchTestCase):
@parameterize(
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], [False, True],
[
(None, None),
('PCM_U', None),
('PCM_U', 8),
('PCM_S', None),
('PCM_S', 16),
('PCM_S', 32),
('PCM_F', None),
('PCM_F', 32),
('PCM_F', 64),
('ULAW', None),
('ULAW', 8),
('ALAW', None),
('ALAW', 8),
],
)
@patch("soundfile.write")
def test_wav(self, dtype, sample_rate, num_channels, channels_first, mocked_write):
def test_wav(self, dtype, sample_rate, num_channels, channels_first,
enc_params, mocked_write):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath = "foo.wav"
input_tensor = get_wav_data(
dtype,
num_channels,
num_frames=3 * sample_rate,
normalize=dtype == "flaot32",
normalize=dtype == "float32",
channels_first=channels_first,
).t()

encoding, bits_per_sample = enc_params
soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first=channels_first
filepath, input_tensor, sample_rate, channels_first=channels_first,
encoding=encoding, bits_per_sample=bits_per_sample
)

# on Py3.8+ call_args.kwargs is more descriptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
assert args["subtype"] == dtype2subtype(dtype)
assert args["subtype"] == fetch_wav_subtype(
dtype, encoding, bits_per_sample)
assert args["format"] is None
self.assertEqual(
args["data"], input_tensor.t() if channels_first else input_tensor
)

@patch("soundfile.write")
def assert_non_wav(
self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write
self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write,
encoding=None, bits_per_sample=None,
):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath = f"foo.{fmt}"
Expand All @@ -63,14 +87,14 @@ def assert_non_wav(
expected_data = input_tensor.t() if channels_first else input_tensor

soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first=channels_first
filepath, input_tensor, sample_rate, channels_first,
encoding=encoding, bits_per_sample=bits_per_sample,
)

# on Py3.8+ call_args.kwargs is more descriptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
assert args["subtype"] is None
if fmt in ["sph", "nist", "nis"]:
assert args["format"] == "NIST"
else:
Expand All @@ -83,19 +107,36 @@ def assert_non_wav(
[8000, 16000],
[1, 2],
[False, True],
[
('PCM_S', 8),
('PCM_S', 16),
('PCM_S', 24),
('PCM_S', 32),
('ULAW', 8),
('ALAW', 8),
('ALAW', 16),
('ALAW', 24),
('ALAW', 32),
],
)
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first):
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav(fmt, dtype, sample_rate, num_channels, channels_first)
encoding, bits_per_sample = enc_params
self.assert_non_wav(fmt, dtype, sample_rate, num_channels,
channels_first, encoding=encoding,
bits_per_sample=bits_per_sample)

@parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
[8, 16, 24],
)
def test_flac(self, dtype, sample_rate, num_channels, channels_first):
def test_flac(self, dtype, sample_rate, num_channels,
channels_first, bits_per_sample):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first)
self.assert_non_wav("flac", dtype, sample_rate, num_channels,
channels_first, bits_per_sample=bits_per_sample)

@parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
Expand Down Expand Up @@ -228,7 +269,7 @@ def _test_fileobj(self, ext):
found, sr = soundfile.read(fileobj, dtype='float32')

assert sr == sample_rate
self.assertEqual(expected, found)
self.assertEqual(expected, found, atol=1e-4, rtol=1e-8)

def test_fileobj_wav(self):
"""Saving audio via file-like object works"""
Expand Down
170 changes: 152 additions & 18 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,93 @@ def load(
return waveform, sample_rate


def _get_subtype_for_wav(
dtype: torch.dtype,
encoding: str,
bits_per_sample: int):
if not encoding:
if not bits_per_sample:
subtype = {
torch.uint8: "PCM_U8",
torch.int16: "PCM_16",
torch.int32: "PCM_32",
torch.float32: "FLOAT",
torch.float64: "DOUBLE",
}.get(dtype)
if not subtype:
raise ValueError(f"Unsupported dtype for wav: {dtype}")
return subtype
if bits_per_sample == 8:
return "PCM_U8"
return f"PCM_{bits_per_sample}"
if encoding == "PCM_S":
if not bits_per_sample:
return "PCM_32"
if bits_per_sample == 8:
raise ValueError("wav does not support 8-bit signed PCM encoding.")
return f"PCM_{bits_per_sample}"
if encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "PCM_U8"
raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
if encoding == "PCM_F":
if bits_per_sample in (None, 32):
return "FLOAT"
if bits_per_sample == 64:
return "DOUBLE"
raise ValueError("wav only supports 32/64-bit float PCM encoding.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "ULAW"
raise ValueError("wav only supports 8-bit mu-law encoding.")
if encoding == "ALAW":
if bits_per_sample in (None, 8):
return "ALAW"
raise ValueError("wav only supports 8-bit a-law encoding.")
raise ValueError(f"wav does not support {encoding}.")


def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
if encoding in (None, "PCM_S"):
return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
if encoding in ("PCM_U", "PCM_F"):
raise ValueError(f"sph does not support {encoding} encoding.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "ULAW"
raise ValueError("sph only supports 8-bit for mu-law encoding.")
if encoding == "ALAW":
return "ALAW"
raise ValueError(f"sph does not support {encoding}.")


def _get_subtype(
dtype: torch.dtype,
format: str,
encoding: str,
bits_per_sample: int):
if format == "wav":
return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
if format == "flac":
if encoding:
raise ValueError("flac does not support encoding.")
if not bits_per_sample:
return "PCM_24"
if bits_per_sample > 24:
raise ValueError("flac does not support bits_per_sample > 24.")
return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
if format in ("ogg", "vorbis"):
if encoding or bits_per_sample:
raise ValueError(
"ogg/vorbis does not support encoding/bits_per_sample.")
return "VORBIS"
if format == "sph":
return _get_subtype_for_sphere(encoding, bits_per_sample)
if format in ("nis", "nist"):
return "PCM_16"
raise ValueError(f"Unsupported format: {format}")


@_mod_utils.requires_module("soundfile")
def save(
filepath: str,
Expand All @@ -217,6 +304,8 @@ def save(
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
"""Save audio data to file.

Expand Down Expand Up @@ -246,9 +335,65 @@ def save(
otherwise ``[time, channel]``.
compression (Optional[float]): Not used.
It is here only for interface compatibility reasons with the "sox_io" backend.
format (str, optional): Output audio format.
This is required when the output audio format cannot be infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
format (str, optional): Override the audio format.
When ``filepath`` argument is path-like object, audio format is
inferred from file extension. If the file extension is missing or
different, you can specify the correct format with this argument.

When ``filepath`` argument is file-like object,
this argument is required.

Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for supported formats.
This argument is effective only for supported formats, such as
``"wav"``, ``"flac"`` and ``"sph"``. Valid values are;

- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)

bits_per_sample (int, optional): Changes the bit depth for the
supported formats.
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
you can change the bit depth.
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.

Supported formats/encodings/bit depth/compression are:

``"wav"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law

Note: Default encoding/bit depth is determined by the dtype of
the input Tensor.

``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)

``"ogg"``, ``"vorbis"``
- Doesn't accept changing configuration.

``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law

"""
if src.ndim != 2:
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
Expand All @@ -260,24 +405,13 @@ def save(
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
ext = format
ext = format.lower()
else:
ext = str(filepath).split(".")[-1].lower()

if ext != "wav":
subtype = None
elif src.dtype == torch.uint8:
subtype = "PCM_U8"
elif src.dtype == torch.int16:
subtype = "PCM_16"
elif src.dtype == torch.int32:
subtype = "PCM_32"
elif src.dtype == torch.float32:
subtype = "FLOAT"
elif src.dtype == torch.float64:
subtype = "DOUBLE"
else:
raise ValueError(f"Unsupported dtype for WAV: {src.dtype}")
if bits_per_sample not in (None, 8, 16, 24, 32, 64):
raise ValueError("Invalid bits_per_sample.")
subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)

# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def save(
``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.

encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"``
This argument is effective only for supported formats, such as ``"wav"``, ``"amb"``
and ``"sph"``. Valid values are;

- ``"PCM_S"`` (signed integer Linear PCM)
Expand Down