Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/torchaudio_unittest/backend/sox_io/common.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
def name_func(func, _, params):
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'


def get_enc_params(dtype):
if dtype == 'float32':
return 'PCM_F', 32
if dtype == 'int32':
return 'PCM_S', 32
if dtype == 'int16':
return 'PCM_S', 16
if dtype == 'uint8':
return 'PCM_U', 8
raise ValueError(f'Unexpected dtype: {dtype}')
4 changes: 3 additions & 1 deletion test/torchaudio_unittest/backend/sox_io/roundtrip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from .common import (
name_func,
get_enc_params,
)


Expand All @@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False)
enc, bps = get_enc_params(dtype)
data = original
for i in range(10):
path = self.get_temp_path(f'{i}.wav')
sox_io_backend.save(path, data, sample_rate)
sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
data, sr = sox_io_backend.load(path, normalize=False)
assert sr == sample_rate
self.assertEqual(original, data)
Expand Down
740 changes: 302 additions & 438 deletions test/torchaudio_unittest/backend/sox_io/save_test.py

Large diffs are not rendered by default.

22 changes: 14 additions & 8 deletions test/torchaudio_unittest/backend/sox_io/torchscript_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from .common import (
name_func,
get_enc_params,
)


Expand All @@ -35,8 +36,12 @@ def py_save_func(
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression)
torchaudio.save(
filepath, tensor, sample_rate, channels_first,
compression, None, encoding, bits_per_sample)


@skipIfNoExec('sox')
Expand Down Expand Up @@ -102,15 +107,16 @@ def test_save_wav(self, dtype, sample_rate, num_channels):
torch.jit.script(py_save_func).save(script_path)
ts_save_func = torch.jit.load(script_path)

expected = get_wav_data(dtype, num_channels)
expected = get_wav_data(dtype, num_channels, normalize=False)
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
enc, bps = get_enc_params(dtype)

py_save_func(py_path, expected, sample_rate, True, None)
ts_save_func(ts_path, expected, sample_rate, True, None)
py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
ts_save_func(ts_path, expected, sample_rate, True, None, enc, bps)

py_data, py_sr = load_wav(py_path)
ts_data, ts_sr = load_wav(ts_path)
py_data, py_sr = load_wav(py_path, normalize=False)
ts_data, ts_sr = load_wav(ts_path, normalize=False)

self.assertEqual(sample_rate, py_sr)
self.assertEqual(sample_rate, ts_sr)
Expand All @@ -131,8 +137,8 @@ def test_save_flac(self, sample_rate, num_channels, compression_level):
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')

py_save_func(py_path, expected, sample_rate, True, compression_level)
ts_save_func(ts_path, expected, sample_rate, True, compression_level)
py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None)

# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
py_path_wav = f'{py_path}.wav'
Expand Down
12 changes: 8 additions & 4 deletions test/torchaudio_unittest/common_utils/sox_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import subprocess
import warnings

Expand Down Expand Up @@ -32,6 +33,7 @@ def gen_audio_file(
command = [
'sox',
'-V3', # verbose
'--no-dither', # disable automatic dithering
'-R',
# -R is supposed to be repeatable, though the implementation looks suspicious
# and not setting the seed to a fixed value.
Expand Down Expand Up @@ -61,21 +63,23 @@ def gen_audio_file(
]
if attenuation is not None:
command += ['vol', f'-{attenuation}dB']
print(' '.join(command))
print(' '.join(command), file=sys.stderr)
subprocess.run(command, check=True)


def convert_audio_file(
src_path, dst_path,
*, bit_depth=None, compression=None):
*, encoding=None, bit_depth=None, compression=None):
"""Convert audio file with `sox` command."""
command = ['sox', '-V3', '-R', str(src_path)]
command = ['sox', '-V3', '--no-dither', '-R', str(src_path)]
if encoding is not None:
command += ['--encoding', str(encoding)]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
if compression is not None:
command += ['--compression', str(compression)]
command += [dst_path]
print(' '.join(command))
print(' '.join(command), file=sys.stderr)
subprocess.run(command, check=True)


Expand Down
191 changes: 128 additions & 63 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import warnings
from typing import Tuple, Optional

import torch
Expand Down Expand Up @@ -152,26 +151,6 @@ def load(
filepath, frame_offset, num_frames, normalize, channels_first, format)


@torch.jit.unused
def _save(
filepath: str,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
):
if hasattr(filepath, 'write'):
if format is None:
raise RuntimeError('`format` is required when saving to file object.')
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format, dtype)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype)


@_mod_utils.requires_module('torchaudio._torchaudio')
def save(
filepath: str,
Expand All @@ -180,30 +159,11 @@ def save(
channels_first: bool = True,
compression: Optional[float] = None,
format: Optional[str] = None,
dtype: Optional[str] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
):
"""Save audio data to file.

Note:
Supported formats are;

* WAV, AMB

* 32-bit floating-point
* 32-bit signed integer
* 16-bit signed integer
* 8-bit unsigned integer

* MP3
* FLAC
* OGG/VORBIS
* SPHERE
* AMR-NB

To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.

Args:
filepath (str or pathlib.Path): Path to save file.
This function also handles ``pathlib.Path`` objects, but is annotated
Expand All @@ -215,32 +175,137 @@ def save(
compression (Optional[float]): Used for formats other than WAV.
This corresponds to ``-C`` option of ``sox`` command.

* | ``MP3``: Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
| VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
* | ``FLAC``: compression level. Whole number from ``0`` to ``8``.
| ``8`` is default and highest compression.
* | ``OGG/VORBIS``: number from ``-1`` to ``10``; ``-1`` is the highest compression
| and lowest quality. Default: ``3``.
``"mp3"``
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.

``"flac"``
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.

``"ogg"``, ``"vorbis"``
Number from ``-1`` to ``10``; ``-1`` is the highest compression
and lowest quality. Default: ``3``.

See the detail at http://sox.sourceforge.net/soxformat.html.
format (str, optional): Output audio format.
This is required when the output audio format cannot be infered from
``filepath``, (such as file extension or ``name`` attribute of the given file object).
dtype (str, optional): Output tensor dtype.
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
``dtype=None`` means no conversion is performed.
``dtype`` parameter is only effective for ``float32`` Tensor.
format (str, optional): Override the audio format.
When ``filepath`` argument is path-like object, audio format is infered from
file extension. If file extension is missing or different, you can specify the
correct format with this argument.

When ``filepath`` argument is file-like object, this argument is required.

Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"`` and ``"sph"``.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"``
and ``"sph"``. Valid values are;

- ``"PCM_S"`` (signed integer Linear PCM)
- ``"PCM_U"`` (unsigned integer Linear PCM)
- ``"PCM_F"`` (floating point PCM)
- ``"ULAW"`` (mu-law)
- ``"ALAW"`` (a-law)

Default values
If not provided, the default value is picked based on ``format`` and ``bits_per_sample``.

``"wav"``, ``"amb"``
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used to determine the default value.
- ``"PCM_U"`` if dtype is ``uint8``
- ``"PCM_S"`` if dtype is ``int16`` or ``int32`
- ``"PCM_F"`` if dtype is ``float32``

- ``"PCM_U"`` if ``bits_per_sample=8``
- ``"PCM_S"`` otherwise

``"sph"`` format;
- the default value is ``"PCM_S"``

bits_per_sample (int, optional): Changes the bit depth for the supported formats.
When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the
bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``.

Default Value;
If not provided, the default values are picked based on ``format`` and ``"encoding"``;

``"wav"``, ``"amb"``;
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
| Tensor is used.
- ``8`` if dtype is ``uint8``
- ``16`` if dtype is ``int16``
- ``32`` if dtype is ``int32`` or ``float32``

- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"``
- ``32`` if ``encoding`` is ``"PCM_F"``

``"flac"`` format;
- the default value is ``24``

``"sph"`` format;
- ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided.
- ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"``

``"amb"`` format;
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
- ``16`` if ``encoding`` is ``"PCM_S"`` or not provided.
- ``32`` if ``encoding`` is ``"PCM_F"``

Supported formats/encodings/bit depth/compression are;

``"wav"``, ``"amb"``
- 32-bit floating-point PCM
- 32-bit signed integer PCM
- 24-bit signed integer PCM
- 16-bit signed integer PCM
- 8-bit unsigned integer PCM
- 8-bit mu-law
- 8-bit a-law

Note: Default encoding/bit depth is determined by the dtype of the input Tensor.

``"mp3"``
Fixed bit rate (such as 128kHz) and variable bit rate compression.
Default: VBR with high quality.

``"flac"``
- 8-bit
- 16-bit
- 24-bit (default)

``"ogg"``, ``"vorbis"``
- Different quality level. Default: approx. 112kbps

``"sph"``
- 8-bit signed integer PCM
- 16-bit signed integer PCM
- 24-bit signed integer PCM
- 32-bit signed integer PCM (default)
- 8-bit mu-law
- 8-bit a-law
- 16-bit a-law
- 24-bit a-law
- 32-bit a-law

``"amr-nb"``
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s

Note:
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
to be linked to ``libsox`` and corresponding codec libraries such as ``libmad``
or ``libmp3lame`` etc.
"""
if src.dtype == torch.float32 and dtype is None:
warnings.warn(
'`dtype` default value will be changed to `int16` in 0.9 release.'
'Specify `dtype` to suppress this warning.'
)
if not torch.jit.is_scripting():
_save(filepath, src, sample_rate, channels_first, compression, format, dtype)
return
if hasattr(filepath, 'write'):
torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression,
format, encoding, bits_per_sample)
return
filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format, dtype)
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
1 change: 1 addition & 0 deletions torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ set(
sox/utils.cpp
sox/effects.cpp
sox/effects_chain.cpp
sox/types.cpp
)

if(BUILD_TRANSDUCER)
Expand Down
Loading