Skip to content

Commit c3cb201

Browse files
authored
Add encoding and bits_per_sample option to save function (#1226)
1 parent 4f9b552 commit c3cb201

File tree

14 files changed

+883
-631
lines changed

14 files changed

+883
-631
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,14 @@
11
def name_func(func, _, params):
22
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
3+
4+
5+
def get_enc_params(dtype):
6+
if dtype == 'float32':
7+
return 'PCM_F', 32
8+
if dtype == 'int32':
9+
return 'PCM_S', 32
10+
if dtype == 'int16':
11+
return 'PCM_S', 16
12+
if dtype == 'uint8':
13+
return 'PCM_U', 8
14+
raise ValueError(f'Unexpected dtype: {dtype}')

test/torchaudio_unittest/backend/sox_io/roundtrip_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from .common import (
1414
name_func,
15+
get_enc_params,
1516
)
1617

1718

@@ -27,10 +28,11 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
2728
def test_wav(self, dtype, sample_rate, num_channels):
2829
"""save/load round trip should not degrade data for wav formats"""
2930
original = get_wav_data(dtype, num_channels, normalize=False)
31+
enc, bps = get_enc_params(dtype)
3032
data = original
3133
for i in range(10):
3234
path = self.get_temp_path(f'{i}.wav')
33-
sox_io_backend.save(path, data, sample_rate)
35+
sox_io_backend.save(path, data, sample_rate, encoding=enc, bits_per_sample=bps)
3436
data, sr = sox_io_backend.load(path, normalize=False)
3537
assert sr == sample_rate
3638
self.assertEqual(original, data)

test/torchaudio_unittest/backend/sox_io/save_test.py

Lines changed: 302 additions & 438 deletions
Large diffs are not rendered by default.

test/torchaudio_unittest/backend/sox_io/torchscript_test.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from .common import (
1919
name_func,
20+
get_enc_params,
2021
)
2122

2223

@@ -35,8 +36,12 @@ def py_save_func(
3536
sample_rate: int,
3637
channels_first: bool = True,
3738
compression: Optional[float] = None,
39+
encoding: Optional[str] = None,
40+
bits_per_sample: Optional[int] = None,
3841
):
39-
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression)
42+
torchaudio.save(
43+
filepath, tensor, sample_rate, channels_first,
44+
compression, None, encoding, bits_per_sample)
4045

4146

4247
@skipIfNoExec('sox')
@@ -102,15 +107,16 @@ def test_save_wav(self, dtype, sample_rate, num_channels):
102107
torch.jit.script(py_save_func).save(script_path)
103108
ts_save_func = torch.jit.load(script_path)
104109

105-
expected = get_wav_data(dtype, num_channels)
110+
expected = get_wav_data(dtype, num_channels, normalize=False)
106111
py_path = self.get_temp_path(f'test_save_py_{dtype}_{sample_rate}_{num_channels}.wav')
107112
ts_path = self.get_temp_path(f'test_save_ts_{dtype}_{sample_rate}_{num_channels}.wav')
113+
enc, bps = get_enc_params(dtype)
108114

109-
py_save_func(py_path, expected, sample_rate, True, None)
110-
ts_save_func(ts_path, expected, sample_rate, True, None)
115+
py_save_func(py_path, expected, sample_rate, True, None, enc, bps)
116+
ts_save_func(ts_path, expected, sample_rate, True, None, enc, bps)
111117

112-
py_data, py_sr = load_wav(py_path)
113-
ts_data, ts_sr = load_wav(ts_path)
118+
py_data, py_sr = load_wav(py_path, normalize=False)
119+
ts_data, ts_sr = load_wav(ts_path, normalize=False)
114120

115121
self.assertEqual(sample_rate, py_sr)
116122
self.assertEqual(sample_rate, ts_sr)
@@ -131,8 +137,8 @@ def test_save_flac(self, sample_rate, num_channels, compression_level):
131137
py_path = self.get_temp_path(f'test_save_py_{sample_rate}_{num_channels}_{compression_level}.flac')
132138
ts_path = self.get_temp_path(f'test_save_ts_{sample_rate}_{num_channels}_{compression_level}.flac')
133139

134-
py_save_func(py_path, expected, sample_rate, True, compression_level)
135-
ts_save_func(ts_path, expected, sample_rate, True, compression_level)
140+
py_save_func(py_path, expected, sample_rate, True, compression_level, None, None)
141+
ts_save_func(ts_path, expected, sample_rate, True, compression_level, None, None)
136142

137143
# converting to 32 bit because flac file has 24 bit depth which scipy cannot handle.
138144
py_path_wav = f'{py_path}.wav'

test/torchaudio_unittest/common_utils/sox_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import subprocess
23
import warnings
34

@@ -32,6 +33,7 @@ def gen_audio_file(
3233
command = [
3334
'sox',
3435
'-V3', # verbose
36+
'--no-dither', # disable automatic dithering
3537
'-R',
3638
# -R is supposed to be repeatable, though the implementation looks suspicious
3739
# and not setting the seed to a fixed value.
@@ -61,21 +63,23 @@ def gen_audio_file(
6163
]
6264
if attenuation is not None:
6365
command += ['vol', f'-{attenuation}dB']
64-
print(' '.join(command))
66+
print(' '.join(command), file=sys.stderr)
6567
subprocess.run(command, check=True)
6668

6769

6870
def convert_audio_file(
6971
src_path, dst_path,
70-
*, bit_depth=None, compression=None):
72+
*, encoding=None, bit_depth=None, compression=None):
7173
"""Convert audio file with `sox` command."""
72-
command = ['sox', '-V3', '-R', str(src_path)]
74+
command = ['sox', '-V3', '--no-dither', '-R', str(src_path)]
75+
if encoding is not None:
76+
command += ['--encoding', str(encoding)]
7377
if bit_depth is not None:
7478
command += ['--bits', str(bit_depth)]
7579
if compression is not None:
7680
command += ['--compression', str(compression)]
7781
command += [dst_path]
78-
print(' '.join(command))
82+
print(' '.join(command), file=sys.stderr)
7983
subprocess.run(command, check=True)
8084

8185

torchaudio/backend/sox_io_backend.py

Lines changed: 128 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import warnings
32
from typing import Tuple, Optional
43

54
import torch
@@ -152,26 +151,6 @@ def load(
152151
filepath, frame_offset, num_frames, normalize, channels_first, format)
153152

154153

155-
@torch.jit.unused
156-
def _save(
157-
filepath: str,
158-
src: torch.Tensor,
159-
sample_rate: int,
160-
channels_first: bool = True,
161-
compression: Optional[float] = None,
162-
format: Optional[str] = None,
163-
dtype: Optional[str] = None,
164-
):
165-
if hasattr(filepath, 'write'):
166-
if format is None:
167-
raise RuntimeError('`format` is required when saving to file object.')
168-
torchaudio._torchaudio.save_audio_fileobj(
169-
filepath, src, sample_rate, channels_first, compression, format, dtype)
170-
else:
171-
torch.ops.torchaudio.sox_io_save_audio_file(
172-
os.fspath(filepath), src, sample_rate, channels_first, compression, format, dtype)
173-
174-
175154
@_mod_utils.requires_module('torchaudio._torchaudio')
176155
def save(
177156
filepath: str,
@@ -180,30 +159,11 @@ def save(
180159
channels_first: bool = True,
181160
compression: Optional[float] = None,
182161
format: Optional[str] = None,
183-
dtype: Optional[str] = None,
162+
encoding: Optional[str] = None,
163+
bits_per_sample: Optional[int] = None,
184164
):
185165
"""Save audio data to file.
186166
187-
Note:
188-
Supported formats are;
189-
190-
* WAV, AMB
191-
192-
* 32-bit floating-point
193-
* 32-bit signed integer
194-
* 16-bit signed integer
195-
* 8-bit unsigned integer
196-
197-
* MP3
198-
* FLAC
199-
* OGG/VORBIS
200-
* SPHERE
201-
* AMR-NB
202-
203-
To save ``MP3``, ``FLAC``, ``OGG/VORBIS``, and other codecs ``libsox`` does not
204-
handle natively, your installation of ``torchaudio`` has to be linked to ``libsox``
205-
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
206-
207167
Args:
208168
filepath (str or pathlib.Path): Path to save file.
209169
This function also handles ``pathlib.Path`` objects, but is annotated
@@ -215,32 +175,137 @@ def save(
215175
compression (Optional[float]): Used for formats other than WAV.
216176
This corresponds to ``-C`` option of ``sox`` command.
217177
218-
* | ``MP3``: Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
219-
| VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
220-
* | ``FLAC``: compression level. Whole number from ``0`` to ``8``.
221-
| ``8`` is default and highest compression.
222-
* | ``OGG/VORBIS``: number from ``-1`` to ``10``; ``-1`` is the highest compression
223-
| and lowest quality. Default: ``3``.
178+
``"mp3"``
179+
Either bitrate (in ``kbps``) with quality factor, such as ``128.2``, or
180+
VBR encoding with quality factor such as ``-4.2``. Default: ``-4.5``.
181+
182+
``"flac"``
183+
Whole number from ``0`` to ``8``. ``8`` is default and highest compression.
184+
185+
``"ogg"``, ``"vorbis"``
186+
Number from ``-1`` to ``10``; ``-1`` is the highest compression
187+
and lowest quality. Default: ``3``.
224188
225189
See the detail at http://sox.sourceforge.net/soxformat.html.
226-
format (str, optional): Output audio format.
227-
This is required when the output audio format cannot be infered from
228-
``filepath``, (such as file extension or ``name`` attribute of the given file object).
229-
dtype (str, optional): Output tensor dtype.
230-
Valid values: ``"uint8", "int16", "int32", "float32", "float64", None``
231-
``dtype=None`` means no conversion is performed.
232-
``dtype`` parameter is only effective for ``float32`` Tensor.
190+
format (str, optional): Override the audio format.
191+
When ``filepath`` argument is path-like object, audio format is infered from
192+
file extension. If file extension is missing or different, you can specify the
193+
correct format with this argument.
194+
195+
When ``filepath`` argument is file-like object, this argument is required.
196+
197+
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
198+
``"amb"``, ``"flac"`` and ``"sph"``.
199+
encoding (str, optional): Changes the encoding for the supported formats.
200+
This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"``
201+
and ``"sph"``. Valid values are;
202+
203+
- ``"PCM_S"`` (signed integer Linear PCM)
204+
- ``"PCM_U"`` (unsigned integer Linear PCM)
205+
- ``"PCM_F"`` (floating point PCM)
206+
- ``"ULAW"`` (mu-law)
207+
- ``"ALAW"`` (a-law)
208+
209+
Default values
210+
If not provided, the default value is picked based on ``format`` and ``bits_per_sample``.
211+
212+
``"wav"``, ``"amb"``
213+
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
214+
| Tensor is used to determine the default value.
215+
- ``"PCM_U"`` if dtype is ``uint8``
216+
- ``"PCM_S"`` if dtype is ``int16`` or ``int32`
217+
- ``"PCM_F"`` if dtype is ``float32``
218+
219+
- ``"PCM_U"`` if ``bits_per_sample=8``
220+
- ``"PCM_S"`` otherwise
221+
222+
``"sph"`` format;
223+
- the default value is ``"PCM_S"``
224+
225+
bits_per_sample (int, optional): Changes the bit depth for the supported formats.
226+
When ``format`` is one of ``"wav"``, ``"flac"``, ``"sph"``, or ``"amb"``, you can change the
227+
bit depth. Valid values are ``8``, ``16``, ``32`` and ``64``.
228+
229+
Default Value;
230+
If not provided, the default values are picked based on ``format`` and ``"encoding"``;
231+
232+
``"wav"``, ``"amb"``;
233+
- | If both ``encoding`` and ``bits_per_sample`` are not provided, the ``dtype`` of the
234+
| Tensor is used.
235+
- ``8`` if dtype is ``uint8``
236+
- ``16`` if dtype is ``int16``
237+
- ``32`` if dtype is ``int32`` or ``float32``
238+
239+
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
240+
- ``16`` if ``encoding`` is ``"PCM_S"``
241+
- ``32`` if ``encoding`` is ``"PCM_F"``
242+
243+
``"flac"`` format;
244+
- the default value is ``24``
245+
246+
``"sph"`` format;
247+
- ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, ``"PCM_F"`` or not provided.
248+
- ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"``
249+
250+
``"amb"`` format;
251+
- ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or ``"ALAW"``
252+
- ``16`` if ``encoding`` is ``"PCM_S"`` or not provided.
253+
- ``32`` if ``encoding`` is ``"PCM_F"``
254+
255+
Supported formats/encodings/bit depth/compression are;
256+
257+
``"wav"``, ``"amb"``
258+
- 32-bit floating-point PCM
259+
- 32-bit signed integer PCM
260+
- 24-bit signed integer PCM
261+
- 16-bit signed integer PCM
262+
- 8-bit unsigned integer PCM
263+
- 8-bit mu-law
264+
- 8-bit a-law
265+
266+
Note: Default encoding/bit depth is determined by the dtype of the input Tensor.
267+
268+
``"mp3"``
269+
Fixed bit rate (such as 128kHz) and variable bit rate compression.
270+
Default: VBR with high quality.
271+
272+
``"flac"``
273+
- 8-bit
274+
- 16-bit
275+
- 24-bit (default)
276+
277+
``"ogg"``, ``"vorbis"``
278+
- Different quality level. Default: approx. 112kbps
279+
280+
``"sph"``
281+
- 8-bit signed integer PCM
282+
- 16-bit signed integer PCM
283+
- 24-bit signed integer PCM
284+
- 32-bit signed integer PCM (default)
285+
- 8-bit mu-law
286+
- 8-bit a-law
287+
- 16-bit a-law
288+
- 24-bit a-law
289+
- 32-bit a-law
290+
291+
``"amr-nb"``
292+
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s
293+
294+
Note:
295+
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
296+
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
297+
to be linked to ``libsox`` and corresponding codec libraries such as ``libmad``
298+
or ``libmp3lame`` etc.
233299
"""
234-
if src.dtype == torch.float32 and dtype is None:
235-
warnings.warn(
236-
'`dtype` default value will be changed to `int16` in 0.9 release.'
237-
'Specify `dtype` to suppress this warning.'
238-
)
239300
if not torch.jit.is_scripting():
240-
_save(filepath, src, sample_rate, channels_first, compression, format, dtype)
241-
return
301+
if hasattr(filepath, 'write'):
302+
torchaudio._torchaudio.save_audio_fileobj(
303+
filepath, src, sample_rate, channels_first, compression,
304+
format, encoding, bits_per_sample)
305+
return
306+
filepath = os.fspath(filepath)
242307
torch.ops.torchaudio.sox_io_save_audio_file(
243-
filepath, src, sample_rate, channels_first, compression, format, dtype)
308+
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample)
244309

245310

246311
@_mod_utils.requires_module('torchaudio._torchaudio')

torchaudio/csrc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ set(
99
sox/utils.cpp
1010
sox/effects.cpp
1111
sox/effects_chain.cpp
12+
sox/types.cpp
1213
)
1314

1415
if(BUILD_TRANSDUCER)

0 commit comments

Comments
 (0)