Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions test/torchaudio_unittest/backend/sox_io/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ def test_save_flac(self, test_mode, bits_per_sample, compression_level):
"flac", compression=compression_level,
bits_per_sample=bits_per_sample, test_mode=test_mode)

@nested_params(
["path", "fileobj", "bytesio"],
)
def test_save_gsm(self, test_mode):
self.assert_save_consistency("gsm", test_mode=test_mode)

@nested_params(
["path", "fileobj", "bytesio"],
[
Expand Down
5 changes: 4 additions & 1 deletion torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def save(
When ``filepath`` argument is file-like object, this argument is required.

Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"`` and ``"sph"``.
``"amb"``, ``"flac"``, ``"sph"``, and ``"gsm"``.
encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, such as ``"wav"``, ``"amb"``
and ``"sph"``. Valid values are:
Expand Down Expand Up @@ -291,6 +291,9 @@ def save(
``"amr-nb"``
Bitrate ranging from 4.75 kbit/s to 12.2 kbit/s. Default: 4.75 kbit/s

``"gsm"``
The GSM 06.10 Lossy Speech Compression only supports its default configuration.

Note:
To save into formats that ``libsox`` does not handle natively (such as ``"mp3"``,
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
Expand Down
33 changes: 19 additions & 14 deletions torchaudio/csrc/sox/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,25 @@ namespace torchaudio {
namespace sox_utils {

Format get_format_from_string(const std::string& format) {
if (format == "wav")
if (format == "wav") {
return Format::WAV;
if (format == "mp3")
} else if (format == "mp3") {
return Format::MP3;
if (format == "flac")
} else if (format == "flac") {
return Format::FLAC;
if (format == "ogg" || format == "vorbis")
} else if (format == "ogg" || format == "vorbis") {
return Format::VORBIS;
if (format == "amr-nb")
} else if (format == "amr-nb") {
return Format::AMR_NB;
if (format == "amr-wb")
} else if (format == "amr-wb") {
return Format::AMR_WB;
if (format == "amb")
} else if (format == "amb") {
return Format::AMB;
if (format == "sph")
} else if (format == "sph") {
return Format::SPHERE;
} else if (format == "gsm") {
return Format::GSM;
}
std::ostringstream stream;
stream << "Internal Error: unexpected format value: " << format;
throw std::runtime_error(stream.str());
Expand Down Expand Up @@ -57,19 +60,21 @@ std::string to_string(Encoding v) {
}

Encoding get_encoding_from_option(const c10::optional<std::string>& encoding) {
if (!encoding.has_value())
if (!encoding.has_value()) {
return Encoding::NOT_PROVIDED;
}
std::string v = encoding.value();
if (v == "PCM_S")
if (v == "PCM_S") {
return Encoding::PCM_SIGNED;
if (v == "PCM_U")
} else if (v == "PCM_U") {
return Encoding::PCM_UNSIGNED;
if (v == "PCM_F")
} else if (v == "PCM_F") {
return Encoding::PCM_FLOAT;
if (v == "ULAW")
} else if (v == "ULAW") {
return Encoding::ULAW;
if (v == "ALAW")
} else if (v == "ALAW") {
return Encoding::ALAW;
}
std::ostringstream stream;
stream << "Internal Error: unexpected encoding value: " << v;
throw std::runtime_error(stream.str());
Expand Down
1 change: 1 addition & 0 deletions torchaudio/csrc/sox/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ enum class Format {
AMR_WB,
AMB,
SPHERE,
GSM,
};

Format get_format_from_string(const std::string& format);
Expand Down
110 changes: 74 additions & 36 deletions torchaudio/csrc/sox/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
const Encoding& encoding,
const BitDepth& bits_per_sample) {
switch (encoding) {
case Encoding::NOT_PROVIDED:
case Encoding::NOT_PROVIDED: {
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::NOT_PROVIDED: {
switch (dtype.toScalarType()) {
case c10::ScalarType::Float:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
Expand All @@ -232,65 +232,79 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
default:
throw std::runtime_error("Internal Error: Unexpected dtype.");
}
}
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
default: {
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
}
case Encoding::PCM_SIGNED:
}
case Encoding::PCM_SIGNED: {
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
case BitDepth::B8:
throw std::runtime_error(
format + " does not support 8-bit signed PCM encoding.");
default:
default: {
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bits_per_sample));
}
}
case Encoding::PCM_UNSIGNED:
}
case Encoding::PCM_UNSIGNED: {
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
default:
default: {
throw std::runtime_error(
format + " only supports 8-bit for unsigned PCM encoding.");
}
}
case Encoding::PCM_FLOAT:
}
case Encoding::PCM_FLOAT: {
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B32:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
case BitDepth::B64:
return std::make_tuple<>(SOX_ENCODING_FLOAT, 64);
default:
default: {
throw std::runtime_error(
format +
" only supports 32-bit or 64-bit for floating-point PCM encoding.");
}
}
case Encoding::ULAW:
}
case Encoding::ULAW: {
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
default: {
throw std::runtime_error(
format + " only supports 8-bit for mu-law encoding.");
}
}
case Encoding::ALAW:
}
case Encoding::ALAW: {
switch (bits_per_sample) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
default: {
throw std::runtime_error(
format + " only supports 8-bit for a-law encoding.");
}
}
default:
}
default: {
throw std::runtime_error(
format + " does not support encoding: " + to_string(encoding));
}
}
}

Expand All @@ -307,90 +321,113 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
case Format::WAV:
case Format::AMB:
return get_save_encoding_for_wav(format, dtype, enc, bps);
case Format::MP3:
case Format::MP3: {
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("mp3 does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"mp3 does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_MP3, 16);
case Format::VORBIS:
}
case Format::GSM: {
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("gsm does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"gsm does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_GSM, 16);
}
case Format::VORBIS: {
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("vorbis does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"vorbis does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_VORBIS, 16);
case Format::AMR_NB:
}
case Format::AMR_NB: {
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("amr-nb does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"amr-nb does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_AMR_NB, 16);
case Format::FLAC:
}
case Format::FLAC: {
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("flac does not support `encoding` option.");
switch (bps) {
case BitDepth::B32:
case BitDepth::B64:
throw std::runtime_error(
"flac does not support `bits_per_sample` larger than 24.");
default:
default: {
return std::make_tuple<>(
SOX_ENCODING_FLAC, static_cast<unsigned>(bps));
}
}
case Format::SPHERE:
}
case Format::SPHERE: {
switch (enc) {
case Encoding::NOT_PROVIDED:
case Encoding::PCM_SIGNED:
switch (bps) {
case BitDepth::NOT_PROVIDED:
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
default:
default: {
return std::make_tuple<>(
SOX_ENCODING_SIGN2, static_cast<unsigned>(bps));
}
}
case Encoding::PCM_UNSIGNED:
case Encoding::PCM_UNSIGNED: {
throw std::runtime_error(
"sph does not support unsigned integer PCM.");
case Encoding::PCM_FLOAT:
}
case Encoding::PCM_FLOAT: {
throw std::runtime_error("sph does not support floating point PCM.");
case Encoding::ULAW:
}
case Encoding::ULAW: {
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ULAW, 8);
default:
default: {
throw std::runtime_error(
"sph only supports 8-bit for mu-law encoding.");
}
}
case Encoding::ALAW:
}
case Encoding::ALAW: {
switch (bps) {
case BitDepth::NOT_PROVIDED:
case BitDepth::B8:
return std::make_tuple<>(SOX_ENCODING_ALAW, 8);
default:
default: {
return std::make_tuple<>(
SOX_ENCODING_ALAW, static_cast<unsigned>(bps));
}
}
default:
}
default: {
throw std::runtime_error(
"sph does not support encoding: " + encoding.value());
}
}
default:
}
default: {
throw std::runtime_error("Unsupported format: " + format);
}
}
}

unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
if (filetype == "mp3")
if (filetype == "mp3") {
return SOX_UNSPEC;
if (filetype == "flac")
} else if (filetype == "flac") {
return 24;
if (filetype == "ogg" || filetype == "vorbis")
} else if (filetype == "ogg" || filetype == "vorbis") {
return SOX_UNSPEC;
if (filetype == "wav" || filetype == "amb") {
} else if (filetype == "wav" || filetype == "amb") {
switch (dtype.toScalarType()) {
case c10::ScalarType::Byte:
return 8;
Expand All @@ -403,10 +440,11 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
default:
throw std::runtime_error("Unsupported dtype.");
}
}
if (filetype == "sph")
} else if (filetype == "sph") {
return 32;
if (filetype == "amr-nb") {
} else if (filetype == "amr-nb") {
return 16;
} else if (filetype == "gsm") {
return 16;
}
throw std::runtime_error("Unsupported file type: " + filetype);
Expand Down