Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions test/torchaudio_unittest/backend/soundfile/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt):

def parameterize(*params):
return parameterized.expand(list(itertools.product(*params)), name_func=name_func)


def fetch_wav_subtype(dtype, encoding, bits_per_sample):
subtype = {
(None, None): dtype2subtype(dtype),
(None, 8): "PCM_U8",
('PCM_U', None): "PCM_U8",
('PCM_U', 8): "PCM_U8",
('PCM_S', None): "PCM_32",
('PCM_S', 16): "PCM_16",
('PCM_S', 32): "PCM_32",
('PCM_F', None): "FLOAT",
('PCM_F', 32): "FLOAT",
('PCM_F', 64): "DOUBLE",
('ULAW', None): "ULAW",
('ULAW', 8): "ULAW",
('ALAW', None): "ALAW",
('ALAW', 8): "ALAW",
}.get((encoding, bits_per_sample))
if subtype:
return subtype
raise ValueError(
f"wav does not support ({encoding}, {bits_per_sample}).")
67 changes: 54 additions & 13 deletions test/torchaudio_unittest/backend/soundfile/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
get_wav_data,
load_wav,
)
from .common import parameterize, dtype2subtype, skipIfFormatNotSupported
from .common import (
fetch_wav_subtype,
parameterize,
skipIfFormatNotSupported,
)

if _mod_utils.is_module_available("soundfile"):
import soundfile
Expand All @@ -20,36 +24,56 @@
class MockedSaveTest(PytorchTestCase):
@parameterize(
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], [False, True],
[
(None, None),
('PCM_U', None),
('PCM_U', 8),
('PCM_S', None),
('PCM_S', 16),
('PCM_S', 32),
('PCM_F', None),
('PCM_F', 32),
('PCM_F', 64),
('ULAW', None),
('ULAW', 8),
('ALAW', None),
('ALAW', 8),
],
)
@patch("soundfile.write")
def test_wav(self, dtype, sample_rate, num_channels, channels_first, mocked_write):
def test_wav(self, dtype, sample_rate, num_channels, channels_first,
enc_params, mocked_write):
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
filepath = "foo.wav"
input_tensor = get_wav_data(
dtype,
num_channels,
num_frames=3 * sample_rate,
normalize=dtype == "flaot32",
normalize=dtype == "float32",
channels_first=channels_first,
).t()

encoding, bits_per_sample = enc_params
soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first=channels_first
filepath, input_tensor, sample_rate, channels_first=channels_first,
encoding=encoding, bits_per_sample=bits_per_sample
)

# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
assert args["subtype"] == dtype2subtype(dtype)
assert args["subtype"] == fetch_wav_subtype(
dtype, encoding, bits_per_sample)
assert args["format"] is None
self.assertEqual(
args["data"], input_tensor.t() if channels_first else input_tensor
)

@patch("soundfile.write")
def assert_non_wav(
self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write
self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write,
encoding=None, bits_per_sample=None,
):
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
filepath = f"foo.{fmt}"
Expand All @@ -63,14 +87,14 @@ def assert_non_wav(
expected_data = input_tensor.t() if channels_first else input_tensor

soundfile_backend.save(
filepath, input_tensor, sample_rate, channels_first=channels_first
filepath, input_tensor, sample_rate, channels_first,
encoding=encoding, bits_per_sample=bits_per_sample,
)

# on +Py3.8 call_args.kwargs is more descreptive
args = mocked_write.call_args[1]
assert args["file"] == filepath
assert args["samplerate"] == sample_rate
assert args["subtype"] is None
if fmt in ["sph", "nist", "nis"]:
assert args["format"] == "NIST"
else:
Expand All @@ -83,19 +107,36 @@ def assert_non_wav(
[8000, 16000],
[1, 2],
[False, True],
[
('PCM_S', 8),
('PCM_S', 16),
('PCM_S', 24),
('PCM_S', 32),
('ULAW', 8),
('ALAW', 8),
('ALAW', 16),
('ALAW', 24),
('ALAW', 32),
],
)
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first):
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav(fmt, dtype, sample_rate, num_channels, channels_first)
encoding, bits_per_sample = enc_params
self.assert_non_wav(fmt, dtype, sample_rate, num_channels,
channels_first, encoding=encoding,
bits_per_sample=bits_per_sample)

@parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
[8, 16, 24],
)
def test_flac(self, dtype, sample_rate, num_channels, channels_first):
def test_flac(self, dtype, sample_rate, num_channels,
channels_first, bits_per_sample):
"""soundfile_backend.save passes default format and subtype (None-s) to
soundfile.write when not WAV"""
self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first)
self.assert_non_wav("flac", dtype, sample_rate, num_channels,
channels_first, bits_per_sample=bits_per_sample)

@parameterize(
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
Expand Down Expand Up @@ -228,7 +269,7 @@ def _test_fileobj(self, ext):
found, sr = soundfile.read(fileobj, dtype='float32')

assert sr == sample_rate
self.assertEqual(expected, found)
self.assertEqual(expected, found, atol=1e-4, rtol=1e-8)

def test_fileobj_wav(self):
"""Saving audio via file-like object works"""
Expand Down
170 changes: 152 additions & 18 deletions torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,93 @@ def load(
return waveform, sample_rate


def _get_subtype_for_wav(
dtype: torch.dtype,
encoding: str,
bits_per_sample: int):
if not encoding:
if not bits_per_sample:
subtype = {
torch.uint8: "PCM_U8",
torch.int16: "PCM_16",
torch.int32: "PCM_32",
torch.float32: "FLOAT",
torch.float64: "DOUBLE",
}.get(dtype)
if not subtype:
raise ValueError(f"Unsupported dtype for wav: {dtype}")
return subtype
if bits_per_sample == 8:
return "PCM_U8"
return f"PCM_{bits_per_sample}"
if encoding == "PCM_S":
if not bits_per_sample:
return "PCM_32"
if bits_per_sample == 8:
raise ValueError("wav does not support 8-bit signed PCM encoding.")
return f"PCM_{bits_per_sample}"
if encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "PCM_U8"
raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
if encoding == "PCM_F":
if bits_per_sample in (None, 32):
return "FLOAT"
if bits_per_sample == 64:
return "DOUBLE"
raise ValueError("wav only supports 32/64-bit float PCM encoding.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "ULAW"
raise ValueError("wav only supports 8-bit mu-law encoding.")
if encoding == "ALAW":
if bits_per_sample in (None, 8):
return "ALAW"
raise ValueError("wav only supports 8-bit a-law encoding.")
raise ValueError(f"wav does not support {encoding}.")


def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
if encoding in (None, "PCM_S"):
return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
if encoding in ("PCM_U", "PCM_F"):
raise ValueError(f"sph does not support {encoding} encoding.")
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "ULAW"
raise ValueError("sph only supports 8-bit for mu-law encoding.")
if encoding == "ALAW":
return "ALAW"
raise ValueError(f"sph does not support {encoding}.")


def _get_subtype(
dtype: torch.dtype,
format: str,
encoding: str,
bits_per_sample: int):
if format == "wav":
return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
if format == "flac":
if encoding:
raise ValueError("flac does not support encoding.")
if not bits_per_sample:
return "PCM_24"
if bits_per_sample > 24:
raise ValueError("flac does not support bits_per_sample > 24.")
return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"
if format in ("ogg", "vorbis"):
if encoding or bits_per_sample:
raise ValueError(
"ogg/vorbis does not support encoding/bits_per_sample.")
return "VORBIS"
if format == "sph":
return _get_subtype_for_sphere(encoding, bits_per_sample)
if format in ("nis", "nist"):
return "PCM_16"
raise ValueError(f"Unsupported format: {format}")


@_mod_utils.requires_module("soundfile")
def save(
filepath: str,
Expand All @@ -217,6 +304,8 @@ def save(
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
"""Save audio data to file.

Expand Down Expand Up @@ -246,9 +335,65 @@ def save(
otherwise ``[time, channel]``.
compression (Optional[float]): Not used.
It is here only for interface compatibility reson with "sox_io" backend.
format (str, optional): Output audio format.
This is required when the output audio format cannot be infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
format (str, optional): Override the audio format.
When ``filepath`` argument is path-like object, audio format is
inferred from file extension. If the file extension is missing or
different, you can specify the correct format with this argument.

When ``filepath`` argument is file-like object,
this argument is required.

Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for supported formats.
This argument is effective only for supported formats, sush as
``"wav"``, ``""flac"`` and ``"sph"``. Valid values are;

- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)

bits_per_sample (int, optional): Changes the bit depth for the
supported formats.
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
you can change the bit depth.
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.

Supported formats/encodings/bit depth/compression are:

``"wav"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law

Note: Default encoding/bit depth is determined by the dtype of
the input Tensor.

``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)

``"ogg"``, ``"vorbis"``
- Doesn't accept changing configuration.

``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law

"""
if src.ndim != 2:
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
Expand All @@ -260,24 +405,13 @@ def save(
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
ext = format
ext = format.lower()
else:
ext = str(filepath).split(".")[-1].lower()

if ext != "wav":
subtype = None
elif src.dtype == torch.uint8:
subtype = "PCM_U8"
elif src.dtype == torch.int16:
subtype = "PCM_16"
elif src.dtype == torch.int32:
subtype = "PCM_32"
elif src.dtype == torch.float32:
subtype = "FLOAT"
elif src.dtype == torch.float64:
subtype = "DOUBLE"
else:
raise ValueError(f"Unsupported dtype for WAV: {src.dtype}")
if bits_per_sample not in (None, 8, 16, 24, 32, 64):
raise ValueError("Invalid bits_per_sample.")
subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)

# sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
# so we extend the extensions manually here
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def save(
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"``
This argument is effective only for supported formats, such as ``"wav"``, ``""amb"``
and ``"sph"``. Valid values are;

- ``"PCM_S"`` (signed integer Linear PCM)
Expand Down