diff --git a/test/torchaudio_unittest/backend/soundfile/common.py b/test/torchaudio_unittest/backend/soundfile/common.py index 8f991fb0f8..c6b014dd4c 100644 --- a/test/torchaudio_unittest/backend/soundfile/common.py +++ b/test/torchaudio_unittest/backend/soundfile/common.py @@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt): def parameterize(*params): return parameterized.expand(list(itertools.product(*params)), name_func=name_func) + + +def fetch_wav_subtype(dtype, encoding, bits_per_sample): + subtype = { + (None, None): dtype2subtype(dtype), + (None, 8): "PCM_U8", + ('PCM_U', None): "PCM_U8", + ('PCM_U', 8): "PCM_U8", + ('PCM_S', None): "PCM_32", + ('PCM_S', 16): "PCM_16", + ('PCM_S', 32): "PCM_32", + ('PCM_F', None): "FLOAT", + ('PCM_F', 32): "FLOAT", + ('PCM_F', 64): "DOUBLE", + ('ULAW', None): "ULAW", + ('ULAW', 8): "ULAW", + ('ALAW', None): "ALAW", + ('ALAW', 8): "ALAW", + }.get((encoding, bits_per_sample)) + if subtype: + return subtype + raise ValueError( + f"wav does not support ({encoding}, {bits_per_sample}).") diff --git a/test/torchaudio_unittest/backend/soundfile/save_test.py b/test/torchaudio_unittest/backend/soundfile/save_test.py index 2f2741c303..2c511ae3a1 100644 --- a/test/torchaudio_unittest/backend/soundfile/save_test.py +++ b/test/torchaudio_unittest/backend/soundfile/save_test.py @@ -11,7 +11,11 @@ get_wav_data, load_wav, ) -from .common import parameterize, dtype2subtype, skipIfFormatNotSupported +from .common import ( + fetch_wav_subtype, + parameterize, + skipIfFormatNotSupported, +) if _mod_utils.is_module_available("soundfile"): import soundfile @@ -20,28 +24,47 @@ class MockedSaveTest(PytorchTestCase): @parameterize( ["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], [False, True], + [ + (None, None), + ('PCM_U', None), + ('PCM_U', 8), + ('PCM_S', None), + ('PCM_S', 16), + ('PCM_S', 32), + ('PCM_F', None), + ('PCM_F', 32), + ('PCM_F', 64), + ('ULAW', None), + ('ULAW', 8), + ('ALAW', None), + ('ALAW', 8), + ], ) 
@patch("soundfile.write") - def test_wav(self, dtype, sample_rate, num_channels, channels_first, mocked_write): + def test_wav(self, dtype, sample_rate, num_channels, channels_first, + enc_params, mocked_write): """soundfile_backend.save passes correct subtype to soundfile.write when WAV""" filepath = "foo.wav" input_tensor = get_wav_data( dtype, num_channels, num_frames=3 * sample_rate, - normalize=dtype == "flaot32", + normalize=dtype == "float32", channels_first=channels_first, ).t() + encoding, bits_per_sample = enc_params soundfile_backend.save( - filepath, input_tensor, sample_rate, channels_first=channels_first + filepath, input_tensor, sample_rate, channels_first=channels_first, + encoding=encoding, bits_per_sample=bits_per_sample ) # on +Py3.8 call_args.kwargs is more descreptive args = mocked_write.call_args[1] assert args["file"] == filepath assert args["samplerate"] == sample_rate - assert args["subtype"] == dtype2subtype(dtype) + assert args["subtype"] == fetch_wav_subtype( + dtype, encoding, bits_per_sample) assert args["format"] is None self.assertEqual( args["data"], input_tensor.t() if channels_first else input_tensor @@ -49,7 +72,8 @@ def test_wav(self, dtype, sample_rate, num_channels, channels_first, mocked_writ @patch("soundfile.write") def assert_non_wav( - self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write + self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write, + encoding=None, bits_per_sample=None, ): """soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE""" filepath = f"foo.{fmt}" @@ -63,14 +87,14 @@ def assert_non_wav( expected_data = input_tensor.t() if channels_first else input_tensor soundfile_backend.save( - filepath, input_tensor, sample_rate, channels_first=channels_first + filepath, input_tensor, sample_rate, channels_first, + encoding=encoding, bits_per_sample=bits_per_sample, ) # on +Py3.8 call_args.kwargs is more descreptive args = 
mocked_write.call_args[1] assert args["file"] == filepath assert args["samplerate"] == sample_rate - assert args["subtype"] is None if fmt in ["sph", "nist", "nis"]: assert args["format"] == "NIST" else: @@ -83,19 +107,36 @@ def assert_non_wav( [8000, 16000], [1, 2], [False, True], + [ + ('PCM_S', 8), + ('PCM_S', 16), + ('PCM_S', 24), + ('PCM_S', 32), + ('ULAW', 8), + ('ALAW', 8), + ('ALAW', 16), + ('ALAW', 24), + ('ALAW', 32), + ], ) - def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first): + def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params): """soundfile_backend.save passes default format and subtype (None-s) to soundfile.write when not WAV""" - self.assert_non_wav(fmt, dtype, sample_rate, num_channels, channels_first) + encoding, bits_per_sample = enc_params + self.assert_non_wav(fmt, dtype, sample_rate, num_channels, + channels_first, encoding=encoding, + bits_per_sample=bits_per_sample) @parameterize( ["int32", "int16"], [8000, 16000], [1, 2], [False, True], + [8, 16, 24], ) - def test_flac(self, dtype, sample_rate, num_channels, channels_first): + def test_flac(self, dtype, sample_rate, num_channels, + channels_first, bits_per_sample): """soundfile_backend.save passes default format and subtype (None-s) to soundfile.write when not WAV""" - self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first) + self.assert_non_wav("flac", dtype, sample_rate, num_channels, + channels_first, bits_per_sample=bits_per_sample) @parameterize( ["int32", "int16"], [8000, 16000], [1, 2], [False, True], @@ -228,7 +269,7 @@ def _test_fileobj(self, ext): found, sr = soundfile.read(fileobj, dtype='float32') assert sr == sample_rate - self.assertEqual(expected, found) + self.assertEqual(expected, found, atol=1e-4, rtol=1e-8) def test_fileobj_wav(self): """Saving audio via file-like object works""" diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py index 
2b36ddd7b0..f939548413 100644 --- a/torchaudio/backend/_soundfile_backend.py +++ b/torchaudio/backend/_soundfile_backend.py @@ -209,6 +209,93 @@ def load( return waveform, sample_rate +def _get_subtype_for_wav( + dtype: torch.dtype, + encoding: str, + bits_per_sample: int): + if not encoding: + if not bits_per_sample: + subtype = { + torch.uint8: "PCM_U8", + torch.int16: "PCM_16", + torch.int32: "PCM_32", + torch.float32: "FLOAT", + torch.float64: "DOUBLE", + }.get(dtype) + if not subtype: + raise ValueError(f"Unsupported dtype for wav: {dtype}") + return subtype + if bits_per_sample == 8: + return "PCM_U8" + return f"PCM_{bits_per_sample}" + if encoding == "PCM_S": + if not bits_per_sample: + return "PCM_32" + if bits_per_sample == 8: + raise ValueError("wav does not support 8-bit signed PCM encoding.") + return f"PCM_{bits_per_sample}" + if encoding == "PCM_U": + if bits_per_sample in (None, 8): + return "PCM_U8" + raise ValueError("wav only supports 8-bit unsigned PCM encoding.") + if encoding == "PCM_F": + if bits_per_sample in (None, 32): + return "FLOAT" + if bits_per_sample == 64: + return "DOUBLE" + raise ValueError("wav only supports 32/64-bit float PCM encoding.") + if encoding == "ULAW": + if bits_per_sample in (None, 8): + return "ULAW" + raise ValueError("wav only supports 8-bit mu-law encoding.") + if encoding == "ALAW": + if bits_per_sample in (None, 8): + return "ALAW" + raise ValueError("wav only supports 8-bit a-law encoding.") + raise ValueError(f"wav does not support {encoding}.") + + +def _get_subtype_for_sphere(encoding: str, bits_per_sample: int): + if encoding in (None, "PCM_S"): + return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32" + if encoding in ("PCM_U", "PCM_F"): + raise ValueError(f"sph does not support {encoding} encoding.") + if encoding == "ULAW": + if bits_per_sample in (None, 8): + return "ULAW" + raise ValueError("sph only supports 8-bit for mu-law encoding.") + if encoding == "ALAW": + return "ALAW" + raise 
ValueError(f"sph does not support {encoding}.") + + +def _get_subtype( + dtype: torch.dtype, + format: str, + encoding: str, + bits_per_sample: int): + if format == "wav": + return _get_subtype_for_wav(dtype, encoding, bits_per_sample) + if format == "flac": + if encoding: + raise ValueError("flac does not support encoding.") + if not bits_per_sample: + return "PCM_24" + if bits_per_sample > 24: + raise ValueError("flac does not support bits_per_sample > 24.") + return "PCM_S8" if bits_per_sample == 8 else f"PCM_{bits_per_sample}" + if format in ("ogg", "vorbis"): + if encoding or bits_per_sample: + raise ValueError( + "ogg/vorbis does not support encoding/bits_per_sample.") + return "VORBIS" + if format == "sph": + return _get_subtype_for_sphere(encoding, bits_per_sample) + if format in ("nis", "nist"): + return "PCM_16" + raise ValueError(f"Unsupported format: {format}") + + @_mod_utils.requires_module("soundfile") def save( filepath: str, @@ -217,6 +304,8 @@ def save( channels_first: bool = True, compression: Optional[float] = None, format: Optional[str] = None, + encoding: Optional[str] = None, + bits_per_sample: Optional[int] = None, ): """Save audio data to file. @@ -246,9 +335,65 @@ def save( otherwise ``[time, channel]``. compression (Optional[float]): Not used. It is here only for interface compatibility reson with "sox_io" backend. - format (str, optional): Output audio format. - This is required when the output audio format cannot be infered from - ``filepath``, (such as file extension or ``name`` attribute of the given file object). + format (str, optional): Override the audio format. + When ``filepath`` argument is path-like object, audio format is + inferred from file extension. If the file extension is missing or + different, you can specify the correct format with this argument. + + When ``filepath`` argument is file-like object, + this argument is required. + + Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``, + ``"flac"`` and ``"sph"``. 
+ encoding (str, optional): Changes the encoding for supported formats. + This argument is effective only for supported formats, such as + ``"wav"``, ``"flac"`` and ``"sph"``. Valid values are; + + - ``"PCM_S"`` (signed integer Linear PCM) + - ``"PCM_U"`` (unsigned integer Linear PCM) + - ``"PCM_F"`` (floating point PCM) + - ``"ULAW"`` (mu-law) + - ``"ALAW"`` (a-law) + + bits_per_sample (int, optional): Changes the bit depth for the + supported formats. + When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``, + you can change the bit depth. + Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``. + + Supported formats/encodings/bit depth/compression are: + + ``"wav"`` + - 32-bit floating-point PCM + - 32-bit signed integer PCM + - 24-bit signed integer PCM + - 16-bit signed integer PCM + - 8-bit unsigned integer PCM + - 8-bit mu-law + - 8-bit a-law + + Note: Default encoding/bit depth is determined by the dtype of + the input Tensor. + + ``"flac"`` + - 8-bit + - 16-bit + - 24-bit (default) + + ``"ogg"``, ``"vorbis"`` + - Doesn't accept changing configuration. 
+ + ``"sph"`` + - 8-bit signed integer PCM + - 16-bit signed integer PCM + - 24-bit signed integer PCM + - 32-bit signed integer PCM (default) + - 8-bit mu-law + - 8-bit a-law + - 16-bit a-law + - 24-bit a-law + - 32-bit a-law + """ if src.ndim != 2: raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.") @@ -260,24 +405,13 @@ def save( if hasattr(filepath, 'write'): if format is None: raise RuntimeError('`format` is required when saving to file object.') - ext = format + ext = format.lower() else: ext = str(filepath).split(".")[-1].lower() - if ext != "wav": - subtype = None - elif src.dtype == torch.uint8: - subtype = "PCM_U8" - elif src.dtype == torch.int16: - subtype = "PCM_16" - elif src.dtype == torch.int32: - subtype = "PCM_32" - elif src.dtype == torch.float32: - subtype = "FLOAT" - elif src.dtype == torch.float64: - subtype = "DOUBLE" - else: - raise ValueError(f"Unsupported dtype for WAV: {src.dtype}") + if bits_per_sample not in (None, 8, 16, 24, 32, 64): + raise ValueError("Invalid bits_per_sample.") + subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample) # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format, # so we extend the extensions manually here diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 54bacd5e5f..ecfd5ebd42 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -198,7 +198,7 @@ def save( ``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``. encoding (str, optional): Changes the encoding for the supported formats. - This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"`` + This argument is effective only for supported formats, such as ``"wav"``, ``"amb"`` and ``"sph"``. Valid values are; - ``"PCM_S"`` (signed integer Linear PCM)