From 8aa45d00eee787dd7456d7dbe6dbb5f413c76e3b Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Mon, 15 Feb 2021 17:07:01 +0000 Subject: [PATCH 1/8] Added encoding and bits_per_sample to soundfile's backend save() --- torchaudio/backend/_soundfile_backend.py | 127 ++++++++++++++++++++++- torchaudio/backend/sox_io_backend.py | 2 +- 2 files changed, 125 insertions(+), 4 deletions(-) diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index 2b36ddd7b0..a0a37da660 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -217,6 +217,8 @@ def save( channels_first: bool = True, compression: Optional[float] = None, format: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, ): """Save audio data to file. @@ -246,9 +248,106 @@ def save( otherwise ``[time, channel]``. compression (Optional[float]): Not used. It is here only for interface compatibility reson with "sox_io" backend. - format (str, optional): Output audio format. - This is required when the output audio format cannot be infered from - ``filepath``, (such as file extension or ``name`` attribute of the given file object). + format (str, optional): Override the audio format. + When ``filepath`` argument is path-like object, audio format is + inferred from file extension. If the file extension is missing or + different, you can specify the correct format with this argument. + + When ``filepath`` argument is file-like object, + this argument is required. + + Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``, + ``"flac"`` and ``"sph"``. + encoding (str, optional): Changes the encoding for supported formats. + This argument is effective only for supported formats, sush as + ``"wav"``, ``""flac"`` and ``"sph"``. Valid values are; + + - ``"PCM_S"`` (signed integer Linear PCM) + - ``"PCM_U"`` (unsigned integer Linear PCM) + - ``"PCM_F"`` (floating point PCM) + - ``"ULAW"`` (mu-law) + - ``"ALAW"`` (a-law) + + Default values: + If not provided, the default value is picked based on + ``format`` and ``bits_per_sample``. + + ``"wav"``: + - | If both ``encoding`` and ``bits_per_sample`` are not + | provided, the ``dtype`` of the Tensor is used to + | determine the default value. + - ``"PCM_U"`` if dtype is ``uint8`` + - ``"PCM_S"`` if dtype is ``int16`` or ``int32` + - ``"PCM_F"`` if dtype is ``float32`` + + - ``"PCM_U"`` if ``bits_per_sample=8`` + - ``"PCM_S"`` otherwise + + ``"sph"``: + - the default value is ``"PCM_S"`` + + bits_per_sample (int, optional): Changes the bit depth for the + supported formats. + When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``, + you can change the bit depth. + Valid values are ``8``, ``16``, ``32`` and ``64``. + + Default Value: + If not provided, the default values are picked based on + ``format`` and ``"encoding"``; + + ``"wav"``: + - | If both ``encoding`` and ``bits_per_sample`` are not + | provided, the ``dtype`` of the Tensor is used. + - ``8`` if dtype is ``uint8`` + - ``16`` if dtype is ``int16`` + - ``32`` if dtype is ``int32`` or ``float32`` + + - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or + ``"ALAW"`` + - ``16`` if ``encoding`` is ``"PCM_S"`` + - ``32`` if ``encoding`` is ``"PCM_F"`` + + ``"flac"``: + - the default value is ``24`` + + ``"sph"``: + - ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, + ``"PCM_F"`` or not provided. + - ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"`` + + Supported formats/encodings/bit depth/compression are: + + ``"wav"`` + - 32-bit floating-point PCM + - 32-bit signed integer PCM + - 24-bit signed integer PCM + - 16-bit signed integer PCM + - 8-bit unsigned integer PCM + - 8-bit mu-law + - 8-bit a-law + + Note: Default encoding/bit depth is determined by the dtype of + the input Tensor. + + ``"flac"`` + - 8-bit + - 16-bit + - 24-bit (default) + + ``"ogg"``, ``"vorbis"`` + - Different quality level. Default: approx. 112kbps + + ``"sph"`` + - 8-bit signed integer PCM + - 16-bit signed integer PCM + - 24-bit signed integer PCM + - 32-bit signed integer PCM (default) + - 8-bit mu-law + - 8-bit a-law + - 16-bit a-law + - 24-bit a-law + - 32-bit a-law """ if src.ndim != 2: raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.") @@ -264,6 +363,7 @@ def save( else: ext = str(filepath).split(".")[-1].lower() + """ if ext != "wav": subtype = None elif src.dtype == torch.uint8: @@ -278,6 +378,27 @@ def save( subtype = "DOUBLE" else: raise ValueError(f"Unsupported dtype for WAV: {src.dtype}") + """ + + subtype = None + if ext == "wav": + if not encoding and not bits_per_sample: + mapper = { + torch.uint8: "PCM_U8", + torch.int16: "PCM_16", + torch.int32: "PCM_32", + torch.float32: "FLOAT", + torch.float64: "DOUBLE", + } + subtype = mapper.get(src.dtype, None) + if not subtype: + raise ValueError(f"Unsupported dtype for WAV: {src.dtype}") + elif bits_per_sample == 8: + subtype = "PCM_U8" + else: + subtype = "PCM_S8" + elif ext == "sph": + subtype = "PCM_S8" # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format, # so we extend the extensions manually here diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 16abc70deb..a8524c60e4 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -197,7 +197,7 @@ def save( Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, ``"amb"``, ``"flac"`` and ``"sph"``. encoding (str, optional): Changes the encoding for the supported formats. - This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"`` + This argument is effective only for supported formats, sush as ``"wav"``, ``""amb"`` and ``"sph"``. Valid values are; - ``"PCM_S"`` (signed integer Linear PCM) From 9b1cd447560bd3064e3f0edee2da185489a32f0c Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 17 Feb 2021 13:20:51 +0000 Subject: [PATCH 2/8] Updated unit test --- .../backend/soundfile/save_test.py | 55 ++++-- torchaudio/backend/_soundfile_backend.py | 175 ++++++++++-------- 2 files changed, 136 insertions(+), 94 deletions(-) diff --git a/test/torchaudio_unittest/backend/soundfile/save_test.py b/test/torchaudio_unittest/backend/soundfile/save_test.py index 2f2741c303..a29f8b2a23 100644 --- a/test/torchaudio_unittest/backend/soundfile/save_test.py +++ b/test/torchaudio_unittest/backend/soundfile/save_test.py @@ -11,7 +11,7 @@ get_wav_data, load_wav, ) -from .common import parameterize, dtype2subtype, skipIfFormatNotSupported +from .common import parameterize, skipIfFormatNotSupported if _mod_utils.is_module_available("soundfile"): import soundfile @@ -20,28 +20,39 @@ class MockedSaveTest(PytorchTestCase): @parameterize( ["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], [False, True], + [ + ('PCM_U', 8), + ('PCM_S', 16), + ('PCM_S', 32), + ('PCM_F', 32), + ('PCM_F', 64), + ('ULAW', 8), + ('ALAW', 8), + ], ) @patch("soundfile.write") - def test_wav(self, dtype, sample_rate, num_channels, channels_first, mocked_write): + def test_wav(self, dtype, sample_rate, num_channels, channels_first, + enc_params, mocked_write): """soundfile_backend.save passes correct subtype to soundfile.write when WAV""" filepath = "foo.wav" input_tensor = get_wav_data( dtype, num_channels, num_frames=3 * sample_rate, - normalize=dtype == "flaot32", + normalize=dtype == "float32", channels_first=channels_first, ).t() + encoding, bits_per_sample = enc_params soundfile_backend.save( - filepath, input_tensor, sample_rate, channels_first=channels_first + filepath, input_tensor, sample_rate, channels_first=channels_first, + encoding=encoding, bits_per_sample=bits_per_sample ) # on +Py3.8 call_args.kwargs is more descreptive args = mocked_write.call_args[1] assert args["file"] == filepath assert args["samplerate"] == sample_rate - assert args["subtype"] == dtype2subtype(dtype) assert args["format"] is None self.assertEqual( args["data"], input_tensor.t() if channels_first else input_tensor @@ -49,7 +60,8 @@ def test_wav(self, dtype, sample_rate, num_channels, channels_first, mocked_writ @patch("soundfile.write") def assert_non_wav( - self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write + self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write, + encoding=None, bits_per_sample=None, ): """soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE""" filepath = f"foo.{fmt}" @@ -63,14 +75,14 @@ def assert_non_wav( expected_data = input_tensor.t() if channels_first else input_tensor soundfile_backend.save( - filepath, input_tensor, sample_rate, channels_first=channels_first + filepath, input_tensor, sample_rate, channels_first, + encoding=encoding, bits_per_sample=bits_per_sample, ) # on +Py3.8 call_args.kwargs is more descreptive args = mocked_write.call_args[1] assert args["file"] == filepath assert args["samplerate"] == sample_rate - assert args["subtype"] is None if fmt in ["sph", "nist", "nis"]: assert args["format"] == "NIST" else: @@ -83,19 +95,36 @@ def assert_non_wav( [8000, 16000], [1, 2], [False, True], + [ + ('PCM_S', 8, ), + ('PCM_S', 16, ), + ('PCM_S', 24, ), + ('PCM_S', 32, ), + ('ULAW', 8), + ('ALAW', 8), + ('ALAW', 16), + ('ALAW', 24), + ('ALAW', 32), + ], ) - def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first): + def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params): """soundfile_backend.save passes default format and subtype (None-s) to soundfile.write when not WAV""" - self.assert_non_wav(fmt, dtype, sample_rate, num_channels, channels_first) + encoding, bits_per_sample = enc_params + self.assert_non_wav(fmt, dtype, sample_rate, num_channels, + channels_first, encoding=encoding, + bits_per_sample=bits_per_sample) @parameterize( ["int32", "int16"], [8000, 16000], [1, 2], [False, True], + [8, 16, 24], ) - def test_flac(self, dtype, sample_rate, num_channels, channels_first): + def test_flac(self, dtype, sample_rate, num_channels, + channels_first, bits_per_sample): """soundfile_backend.save passes default format and subtype (None-s) to soundfile.write when not WAV""" - self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first) + self.assert_non_wav("flac", dtype, sample_rate, num_channels, + channels_first, bits_per_sample=bits_per_sample) @parameterize( ["int32", "int16"], [8000, 16000], [1, 2], [False, True], @@ -228,7 +257,7 @@ def _test_fileobj(self, ext): found, sr = soundfile.read(fileobj, dtype='float32') assert sr == sample_rate - self.assertEqual(expected, found) + self.assertEqual(expected, found, atol=1e-4, rtol=1e-8) def test_fileobj_wav(self): """Saving audio via file-like object works""" diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index a0a37da660..ebdce1b8ea 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -209,6 +209,95 @@ def load( return waveform, sample_rate +def _get_subtype_for_wav( + dtype: torch.dtype, + encoding: str, + bits_per_sample: int): + if not encoding: + if not bits_per_sample: + mapper = { + torch.uint8: "PCM_U8", + torch.int16: "PCM_16", + torch.int32: "PCM_32", + torch.float32: "FLOAT", + torch.float64: "DOUBLE", + } + subtype = mapper.get(dtype, None) + if not subtype: + raise ValueError(f"Unsupported dtype for wav: {dtype}") + return subtype + if bits_per_sample == 8: + return "PCM_U8" + return f"PCM_{bits_per_sample}" + if encoding == "PCM_S": + if not bits_per_sample: + return "PCM_32" + if bits_per_sample == 8: + raise ValueError("wav does not support 8-bit signed PCM encoding.") + return f"PCM_{bits_per_sample}" + if encoding == "PCM_U": + if bits_per_sample in (None, 8): + return "PCM_U8" + raise ValueError("wav only supports 8-bit unsigned PCM encoding.") + if encoding == "PCM_F": + if bits_per_sample in (None, 32): + return "FLOAT" + if bits_per_sample == 64: + return "DOUBLE" + raise ValueError("wav only supports 32/64-bit float PCM encoding.") + if encoding == "ULAW": + if bits_per_sample in (None, 8): + return "PCM_U8" + raise ValueError("wav only supports 8-bit mu-law encoding.") + if encoding == "ALAW": + if bits_per_sample in (None, 8): + return "PCM_U8" + raise ValueError("wav only supports 8-bit a-law encoding.") + raise ValueError(f"wav does not support {encoding}.") + + +def _get_subtype_for_sphere(encoding: str, bits_per_sample: int): + if encoding in (None, "PCM_S"): + return f"PCM{bits_per_sample}" if bits_per_sample else "PCM_32" + if encoding in ("PCM_U", "PCM_F"): + raise ValueError(f"sph does not support {encoding} encoding.") + if encoding == "ULAW": + if bits_per_sample in (None, 8): + return "PCM_U8" + raise ValueError("sph only supports 8-bit for mu-law encoding.") + if encoding == "ALAW": + return ("PCM_U8" if bits_per_sample in (None, 8) + else f"PCM_{bits_per_sample}") + raise ValueError(f"sph does not support {encoding}.") + + +def _get_subtype( + dtype: torch.dtype, + format: str, + encoding: str, + bits_per_sample: int): + if format == "wav": + return _get_subtype_for_wav(dtype, encoding, bits_per_sample) + if format == "flac": + if encoding: + raise ValueError("flac does not support encoding.") + if not bits_per_sample: + return "PCM_24" + if bits_per_sample > 24: + raise ValueError("flac does not support bits_per_sample > 24.") + return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}" + if format in ("ogg", "vorbis"): + if encoding or bits_per_sample: + raise ValueError( + "ogg/vorbis does not support encoding/bits_per_sample.") + return "VORBIS" + if format == "sph": + return _get_subtype_for_sphere(encoding, bits_per_sample) + if format in ("nis", "nist"): + return "PCM_16" + raise ValueError(f"Unsupported format: {format}") + + @_mod_utils.requires_module("soundfile") def save( filepath: str, @@ -267,54 +356,11 @@ def save( - ``"PCM_F"`` (floating point PCM) - ``"ULAW"`` (mu-law) - ``"ALAW"`` (a-law) - - Default values: - If not provided, the default value is picked based on - ``format`` and ``bits_per_sample``. - - ``"wav"``: - - | If both ``encoding`` and ``bits_per_sample`` are not - | provided, the ``dtype`` of the Tensor is used to - | determine the default value. - - ``"PCM_U"`` if dtype is ``uint8`` - - ``"PCM_S"`` if dtype is ``int16`` or ``int32` - - ``"PCM_F"`` if dtype is ``float32`` - - - ``"PCM_U"`` if ``bits_per_sample=8`` - - ``"PCM_S"`` otherwise - - ``"sph"``: - - the default value is ``"PCM_S"`` - bits_per_sample (int, optional): Changes the bit depth for the supported formats. When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``, you can change the bit depth. - Valid values are ``8``, ``16``, ``32`` and ``64``. - - Default Value: - If not provided, the default values are picked based on - ``format`` and ``"encoding"``; - - ``"wav"``: - - | If both ``encoding`` and ``bits_per_sample`` are not - | provided, the ``dtype`` of the Tensor is used. - - ``8`` if dtype is ``uint8`` - - ``16`` if dtype is ``int16`` - - ``32`` if dtype is ``int32`` or ``float32`` - - - ``8`` if ``encoding`` is ``"PCM_U"``, ``"ULAW"`` or - ``"ALAW"`` - - ``16`` if ``encoding`` is ``"PCM_S"`` - - ``32`` if ``encoding`` is ``"PCM_F"`` - - ``"flac"``: - - the default value is ``24`` - - ``"sph"``: - - ``16`` if ``encoding`` is ``"PCM_U"``, ``"PCM_S"``, - ``"PCM_F"`` or not provided. - - ``8`` if ``encoding`` is ``"ULAW"`` or ``"ALAW"`` + Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``. Supported formats/encodings/bit depth/compression are: @@ -359,46 +405,13 @@ def save( if hasattr(filepath, 'write'): if format is None: raise RuntimeError('`format` is required when saving to file object.') - ext = format + ext = format.lower() else: ext = str(filepath).split(".")[-1].lower() - """ - if ext != "wav": - subtype = None - elif src.dtype == torch.uint8: - subtype = "PCM_U8" - elif src.dtype == torch.int16: - subtype = "PCM_16" - elif src.dtype == torch.int32: - subtype = "PCM_32" - elif src.dtype == torch.float32: - subtype = "FLOAT" - elif src.dtype == torch.float64: - subtype = "DOUBLE" - else: - raise ValueError(f"Unsupported dtype for WAV: {src.dtype}") - """ - - subtype = None - if ext == "wav": - if not encoding and not bits_per_sample: - mapper = { - torch.uint8: "PCM_U8", - torch.int16: "PCM_16", - torch.int32: "PCM_32", - torch.float32: "FLOAT", - torch.float64: "DOUBLE", - } - subtype = mapper.get(src.dtype, None) - if not subtype: - raise ValueError(f"Unsupported dtype for WAV: {src.dtype}") - elif bits_per_sample == 8: - subtype = "PCM_U8" - else: - subtype = "PCM_S8" - elif ext == "sph": - subtype = "PCM_S8" + if bits_per_sample not in (None, 8, 16, 24, 32, 64): + raise ValueError("Invalid bits_per_sample.") + subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample) # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format, # so we extend the extensions manually here From c0440f5be7606981bb802aaaf805ff63b582ebaa Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 17 Feb 2021 13:23:12 +0000 Subject: [PATCH 3/8] Fixed typo --- torchaudio/backend/sox_io_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index a8524c60e4..b033eadd9d 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -197,7 +197,7 @@ def save( Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, ``"amb"``, ``"flac"`` and ``"sph"``. encoding (str, optional): Changes the encoding for the supported formats. - This argument is effective only for supported formats, sush as ``"wav"``, ``""amb"`` + This argument is effective only for supported formats, such as ``"wav"``, ``""amb"`` and ``"sph"``. Valid values are; - ``"PCM_S"`` (signed integer Linear PCM) From f60defb768da086a9ecb26c9975a7651e2bd18f3 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 17 Feb 2021 13:34:28 +0000 Subject: [PATCH 4/8] Added newline after docstring list --- torchaudio/backend/_soundfile_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index ebdce1b8ea..f13060e92f 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -356,6 +356,7 @@ def save( - ``"PCM_F"`` (floating point PCM) - ``"ULAW"`` (mu-law) - ``"ALAW"`` (a-law) + bits_per_sample (int, optional): Changes the bit depth for the supported formats. When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``, @@ -394,6 +395,7 @@ def save( - 16-bit a-law - 24-bit a-law - 32-bit a-law + """ if src.ndim != 2: raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.") From a875fcca117e8fe39bdc60bde901af134d50bb3a Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 18 Feb 2021 18:00:44 +0000 Subject: [PATCH 5/8] Addressed review comments --- .../backend/soundfile/common.py | 23 +++++++++++++++++++ .../backend/soundfile/save_test.py | 22 ++++++++++++++---- torchaudio/backend/_soundfile_backend.py | 16 ++++++------- 3 files changed, 47 insertions(+), 14 deletions(-) diff --git a/test/torchaudio_unittest/backend/soundfile/common.py b/test/torchaudio_unittest/backend/soundfile/common.py index 8f991fb0f8..c6b014dd4c 100644 --- a/test/torchaudio_unittest/backend/soundfile/common.py +++ b/test/torchaudio_unittest/backend/soundfile/common.py @@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt): def parameterize(*params): return parameterized.expand(list(itertools.product(*params)), name_func=name_func) + + +def fetch_wav_subtype(dtype, encoding, bits_per_sample): + subtype = { + (None, None): dtype2subtype(dtype), + (None, 8): "PCM_U8", + ('PCM_U', None): "PCM_U8", + ('PCM_U', 8): "PCM_U8", + ('PCM_S', None): "PCM_32", + ('PCM_S', 16): "PCM_16", + ('PCM_S', 32): "PCM_32", + ('PCM_F', None): "FLOAT", + ('PCM_F', 32): "FLOAT", + ('PCM_F', 64): "DOUBLE", + ('ULAW', None): "ULAW", + ('ULAW', 8): "ULAW", + ('ALAW', None): "ALAW", + ('ALAW', 8): "ALAW", + }.get((encoding, bits_per_sample)) + if subtype: + return subtype + raise ValueError( + f"wav does not support ({encoding}, {bits_per_sample}).") diff --git a/test/torchaudio_unittest/backend/soundfile/save_test.py b/test/torchaudio_unittest/backend/soundfile/save_test.py index a29f8b2a23..2c511ae3a1 100644 --- a/test/torchaudio_unittest/backend/soundfile/save_test.py +++ b/test/torchaudio_unittest/backend/soundfile/save_test.py @@ -11,7 +11,11 @@ get_wav_data, load_wav, ) -from .common import parameterize, skipIfFormatNotSupported +from .common import ( + fetch_wav_subtype, + parameterize, + skipIfFormatNotSupported, +) if _mod_utils.is_module_available("soundfile"): import soundfile @@ -21,12 +25,18 @@ class MockedSaveTest(PytorchTestCase): @parameterize( ["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], [False, True], [ + (None, None), + ('PCM_U', None), ('PCM_U', 8), + ('PCM_S', None), ('PCM_S', 16), ('PCM_S', 32), + ('PCM_F', None), ('PCM_F', 32), ('PCM_F', 64), + ('ULAW', None), ('ULAW', 8), + ('ALAW', None), ('ALAW', 8), ], ) @@ -53,6 +63,8 @@ def test_wav(self, dtype, sample_rate, num_channels, channels_first, args = mocked_write.call_args[1] assert args["file"] == filepath assert args["samplerate"] == sample_rate + assert args["subtype"] == fetch_wav_subtype( + dtype, encoding, bits_per_sample) assert args["format"] is None self.assertEqual( args["data"], input_tensor.t() if channels_first else input_tensor @@ -96,10 +108,10 @@ def assert_non_wav( [1, 2], [False, True], [ - ('PCM_S', 8, ), - ('PCM_S', 16, ), - ('PCM_S', 24, ), - ('PCM_S', 32, ), + ('PCM_S', 8), + ('PCM_S', 16), + ('PCM_S', 24), + ('PCM_S', 32), ('ULAW', 8), ('ALAW', 8), ('ALAW', 16), diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index f13060e92f..e0423b09df 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -215,14 +215,13 @@ def _get_subtype_for_wav( bits_per_sample: int): if not encoding: if not bits_per_sample: - mapper = { + subtype = { torch.uint8: "PCM_U8", torch.int16: "PCM_16", torch.int32: "PCM_32", torch.float32: "FLOAT", torch.float64: "DOUBLE", - } - subtype = mapper.get(dtype, None) + }.get(dtype) if not subtype: raise ValueError(f"Unsupported dtype for wav: {dtype}") return subtype @@ -247,11 +246,11 @@ def _get_subtype_for_wav( raise ValueError("wav only supports 32/64-bit float PCM encoding.") if encoding == "ULAW": if bits_per_sample in (None, 8): - return "PCM_U8" + return "ULAW" raise ValueError("wav only supports 8-bit mu-law encoding.") if encoding == "ALAW": if bits_per_sample in (None, 8): - return "PCM_U8" + return "ALAW" raise ValueError("wav only supports 8-bit a-law encoding.") raise ValueError(f"wav does not support {encoding}.") @@ -263,11 +262,10 @@ def _get_subtype_for_sphere(encoding: str, bits_per_sample: int): raise ValueError(f"sph does not support {encoding} encoding.") if encoding == "ULAW": if bits_per_sample in (None, 8): - return "PCM_U8" + return "ULAW" raise ValueError("sph only supports 8-bit for mu-law encoding.") if encoding == "ALAW": - return ("PCM_U8" if bits_per_sample in (None, 8) - else f"PCM_{bits_per_sample}") + return "ALAW" raise ValueError(f"sph does not support {encoding}.") @@ -383,7 +381,7 @@ def save( - 24-bit (default) ``"ogg"``, ``"vorbis"`` - - Different quality level. Default: approx. 112kbps + - Doesn't accept changing configuration. ``"sph"`` - 8-bit signed integer PCM From 9bccd7bf0157485b58a538d23f2e1d3d97f31382 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 19 Feb 2021 15:41:03 +0000 Subject: [PATCH 6/8] Fixed docstring indentation --- torchaudio/backend/_soundfile_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index e0423b09df..e4f910968f 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -356,7 +356,7 @@ def save( - ``"ALAW"`` (a-law) bits_per_sample (int, optional): Changes the bit depth for the - supported formats. + supported formats. When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``, you can change the bit depth. Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``. From 8629e4c9f85765fca8ba2db46d532eb603cf1214 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 19 Feb 2021 17:39:23 +0000 Subject: [PATCH 7/8] Correct flac default --- torchaudio/backend/_soundfile_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index e4f910968f..4f55db6e80 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -280,7 +280,7 @@ def _get_subtype( if encoding: raise ValueError("flac does not support encoding.") if not bits_per_sample: - return "PCM_24" + return "PCM_16" if bits_per_sample > 24: raise ValueError("flac does not support bits_per_sample > 24.") return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}" From 354328fb901d7006a5da43fb1091fc4d18672283 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Mon, 22 Feb 2021 15:30:12 +0000 Subject: [PATCH 8/8] Revert flac default to PCM_24 --- torchaudio/backend/_soundfile_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index 8d70a306b4..f939548413 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -280,7 +280,7 @@ def _get_subtype( if encoding: raise ValueError("flac does not support encoding.") if not bits_per_sample: - return "PCM_16" + return "PCM_24" if bits_per_sample > 24: raise ValueError("flac does not support bits_per_sample > 24.") return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}"