diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index dca9c3ad9c..1ce2335410 100644 --- a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -237,6 +237,12 @@ def test_save_flac(self, test_mode, bits_per_sample, compression_level): "flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode=test_mode) + @nested_params( + ["path", "fileobj", "bytesio"], + ) + def test_save_gsm(self, test_mode): + self.assert_save_consistency("gsm", test_mode=test_mode) + @nested_params( ["path", "fileobj", "bytesio"], [ diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 16abc70deb..33c797183e 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -195,7 +195,7 @@ def save( When ``filepath`` argument is file-like object, this argument is required. Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, - ``"amb"``, ``"flac"`` and ``"sph"``. + ``"amb"``, ``"flac"``, ``"sph"``, and ``"gsm"``. encoding (str, optional): Changes the encoding for the supported formats. This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"`` and ``"sph"``. Valid values are; @@ -291,6 +291,9 @@ def save( ``"amr-nb"`` Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s + ``"gsm"`` + The GSM 06.10 Lossy Speech Compression only supports its default configuration. + Note: To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``, ``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has diff --git a/torchaudio/csrc/sox/types.cpp b/torchaudio/csrc/sox/types.cpp index 51e8e720d6..497de52f08 100644 --- a/torchaudio/csrc/sox/types.cpp +++ b/torchaudio/csrc/sox/types.cpp @@ -4,22 +4,25 @@ namespace torchaudio { namespace sox_utils { Format get_format_from_string(const std::string& format) { - if (format == "wav") + if (format == "wav") { return Format::WAV; - if (format == "mp3") + } else if (format == "mp3") { return Format::MP3; - if (format == "flac") + } else if (format == "flac") { return Format::FLAC; - if (format == "ogg" || format == "vorbis") + } else if (format == "ogg" || format == "vorbis") { return Format::VORBIS; - if (format == "amr-nb") + } else if (format == "amr-nb") { return Format::AMR_NB; - if (format == "amr-wb") + } else if (format == "amr-wb") { return Format::AMR_WB; - if (format == "amb") + } else if (format == "amb") { return Format::AMB; - if (format == "sph") + } else if (format == "sph") { return Format::SPHERE; + } else if (format == "gsm") { + return Format::GSM; + } std::ostringstream stream; stream << "Internal Error: unexpected format value: " << format; throw std::runtime_error(stream.str()); @@ -57,19 +60,21 @@ std::string to_string(Encoding v) { } Encoding get_encoding_from_option(const c10::optional& encoding) { - if (!encoding.has_value()) + if (!encoding.has_value()) { return Encoding::NOT_PROVIDED; + } std::string v = encoding.value(); - if (v == "PCM_S") + if (v == "PCM_S") { return Encoding::PCM_SIGNED; - if (v == "PCM_U") + } else if (v == "PCM_U") { return Encoding::PCM_UNSIGNED; - if (v == "PCM_F") + } else if (v == "PCM_F") { return Encoding::PCM_FLOAT; - if (v == "ULAW") + } else if (v == "ULAW") { return Encoding::ULAW; - if (v == "ALAW") + } else if (v == "ALAW") { return Encoding::ALAW; + } std::ostringstream stream; stream << "Internal Error: unexpected encoding value: " << v; throw std::runtime_error(stream.str()); diff --git a/torchaudio/csrc/sox/types.h b/torchaudio/csrc/sox/types.h index f3ed637478..f3a337407c 100644 --- a/torchaudio/csrc/sox/types.h +++ b/torchaudio/csrc/sox/types.h @@ -15,6 +15,7 @@ enum class Format { AMR_WB, AMB, SPHERE, + GSM, }; Format get_format_from_string(const std::string& format); diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 99a264642f..e6b7183954 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -217,9 +217,9 @@ std::tuple get_save_encoding_for_wav( const Encoding& encoding, const BitDepth& bits_per_sample) { switch (encoding) { - case Encoding::NOT_PROVIDED: + case Encoding::NOT_PROVIDED: { switch (bits_per_sample) { - case BitDepth::NOT_PROVIDED: + case BitDepth::NOT_PROVIDED: { switch (dtype.toScalarType()) { case c10::ScalarType::Float: return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); @@ -232,65 +232,79 @@ std::tuple get_save_encoding_for_wav( default: throw std::runtime_error("Internal Error: Unexpected dtype."); } + } case BitDepth::B8: return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); - default: + default: { return std::make_tuple<>( SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); + } } - case Encoding::PCM_SIGNED: + } + case Encoding::PCM_SIGNED: { switch (bits_per_sample) { case BitDepth::NOT_PROVIDED: return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); case BitDepth::B8: throw std::runtime_error( format + " does not support 8-bit signed PCM encoding."); - default: + default: { return std::make_tuple<>( SOX_ENCODING_SIGN2, static_cast(bits_per_sample)); + } } - case Encoding::PCM_UNSIGNED: + } + case Encoding::PCM_UNSIGNED: { switch (bits_per_sample) { case BitDepth::NOT_PROVIDED: case BitDepth::B8: return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8); - default: + default: { throw std::runtime_error( format + " only supports 8-bit for unsigned PCM encoding."); + } } - case Encoding::PCM_FLOAT: + } + case Encoding::PCM_FLOAT: { switch (bits_per_sample) { case BitDepth::NOT_PROVIDED: case BitDepth::B32: return std::make_tuple<>(SOX_ENCODING_FLOAT, 32); case BitDepth::B64: return std::make_tuple<>(SOX_ENCODING_FLOAT, 64); - default: + default: { throw std::runtime_error( format + " only supports 32-bit or 64-bit for floating-point PCM encoding."); + } } - case Encoding::ULAW: + } + case Encoding::ULAW: { switch (bits_per_sample) { case BitDepth::NOT_PROVIDED: case BitDepth::B8: return std::make_tuple<>(SOX_ENCODING_ULAW, 8); - default: + default: { throw std::runtime_error( format + " only supports 8-bit for mu-law encoding."); + } } - case Encoding::ALAW: + } + case Encoding::ALAW: { switch (bits_per_sample) { case BitDepth::NOT_PROVIDED: case BitDepth::B8: return std::make_tuple<>(SOX_ENCODING_ALAW, 8); - default: + default: { throw std::runtime_error( format + " only supports 8-bit for a-law encoding."); + } } - default: + } + default: { throw std::runtime_error( format + " does not support encoding: " + to_string(encoding)); + } } } @@ -307,28 +321,39 @@ std::tuple get_save_encoding( case Format::WAV: case Format::AMB: return get_save_encoding_for_wav(format, dtype, enc, bps); - case Format::MP3: + case Format::MP3: { if (enc != Encoding::NOT_PROVIDED) throw std::runtime_error("mp3 does not support `encoding` option."); if (bps != BitDepth::NOT_PROVIDED) throw std::runtime_error( "mp3 does not support `bits_per_sample` option."); return std::make_tuple<>(SOX_ENCODING_MP3, 16); - case Format::VORBIS: + } + case Format::GSM: { + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("gsm does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "gsm does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_GSM, 16); + } + case Format::VORBIS: { if (enc != Encoding::NOT_PROVIDED) throw std::runtime_error("vorbis does not support `encoding` option."); if (bps != BitDepth::NOT_PROVIDED) throw std::runtime_error( "vorbis does not support `bits_per_sample` option."); return std::make_tuple<>(SOX_ENCODING_VORBIS, 16); - case Format::AMR_NB: + } + case Format::AMR_NB: { if (enc != Encoding::NOT_PROVIDED) throw std::runtime_error("amr-nb does not support `encoding` option."); if (bps != BitDepth::NOT_PROVIDED) throw std::runtime_error( "amr-nb does not support `bits_per_sample` option."); return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16); - case Format::FLAC: + } + case Format::FLAC: { if (enc != Encoding::NOT_PROVIDED) throw std::runtime_error("flac does not support `encoding` option."); switch (bps) { @@ -336,61 +361,73 @@ std::tuple get_save_encoding( case BitDepth::B64: throw std::runtime_error( "flac does not support `bits_per_sample` larger than 24."); - default: + default: { return std::make_tuple<>( SOX_ENCODING_FLAC, static_cast(bps)); + } } - case Format::SPHERE: + } + case Format::SPHERE: { switch (enc) { case Encoding::NOT_PROVIDED: case Encoding::PCM_SIGNED: switch (bps) { case BitDepth::NOT_PROVIDED: return std::make_tuple<>(SOX_ENCODING_SIGN2, 32); - default: + default: { return std::make_tuple<>( SOX_ENCODING_SIGN2, static_cast(bps)); + } } - case Encoding::PCM_UNSIGNED: + case Encoding::PCM_UNSIGNED: { throw std::runtime_error( "sph does not support unsigned integer PCM."); - case Encoding::PCM_FLOAT: + } + case Encoding::PCM_FLOAT: { throw std::runtime_error("sph does not support floating point PCM."); - case Encoding::ULAW: + } + case Encoding::ULAW: { switch (bps) { case BitDepth::NOT_PROVIDED: case BitDepth::B8: return std::make_tuple<>(SOX_ENCODING_ULAW, 8); - default: + default: { throw std::runtime_error( "sph only supports 8-bit for mu-law encoding."); + } } - case Encoding::ALAW: + } + case Encoding::ALAW: { switch (bps) { case BitDepth::NOT_PROVIDED: case BitDepth::B8: return std::make_tuple<>(SOX_ENCODING_ALAW, 8); - default: + default: { return std::make_tuple<>( SOX_ENCODING_ALAW, static_cast(bps)); + } } - default: + } + default: { throw std::runtime_error( "sph does not support encoding: " + encoding.value()); + } } - default: + } + default: { throw std::runtime_error("Unsupported format: " + format); + } } } unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) { - if (filetype == "mp3") + if (filetype == "mp3") { return SOX_UNSPEC; - if (filetype == "flac") + } else if (filetype == "flac") { return 24; - if (filetype == "ogg" || filetype == "vorbis") + } else if (filetype == "ogg" || filetype == "vorbis") { return SOX_UNSPEC; - if (filetype == "wav" || filetype == "amb") { + } else if (filetype == "wav" || filetype == "amb") { switch (dtype.toScalarType()) { case c10::ScalarType::Byte: return 8; @@ -403,10 +440,11 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) { default: throw std::runtime_error("Unsupported dtype."); } - } - if (filetype == "sph") + } else if (filetype == "sph") { return 32; - if (filetype == "amr-nb") { + } else if (filetype == "amr-nb") { + return 16; + } else if (filetype == "gsm") { return 16; } throw std::runtime_error("Unsupported file type: " + filetype);