diff --git a/test/torchaudio_unittest/backend/sox_io/info_test.py b/test/torchaudio_unittest/backend/sox_io/info_test.py index a2a93648a1..8701414f6e 100644 --- a/test/torchaudio_unittest/backend/sox_io/info_test.py +++ b/test/torchaudio_unittest/backend/sox_io/info_test.py @@ -205,7 +205,7 @@ def test_ulaw(self): assert info.encoding == "ULAW" def test_alaw(self): - """`sox_io_backend.info` can check ulaw file correctly""" + """`sox_io_backend.info` can check alaw file correctly""" duration = 1 num_channels = 1 sample_rate = 8000 @@ -221,6 +221,22 @@ def test_alaw(self): assert info.bits_per_sample == 8 assert info.encoding == "ALAW" + def test_htk(self): + """`sox_io_backend.info` can check HTK file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path('data.htk') + sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, + bit_depth=16, duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 16 + assert info.encoding == "PCM_S" + @skipIfNoExtension class TestInfoOpus(PytorchTestCase): diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index 5d3fdb03ca..d971efc01e 100644 --- a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -237,6 +237,12 @@ def test_save_flac(self, test_mode, bits_per_sample, compression_level): "flac", compression=compression_level, bits_per_sample=bits_per_sample, test_mode=test_mode) + @nested_params( + ["path", "fileobj", "bytesio"], + ) + def test_save_htk(self, test_mode): + self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1) + @nested_params( ["path", "fileobj", "bytesio"], [ diff --git a/torchaudio/backend/sox_io_backend.py 
b/torchaudio/backend/sox_io_backend.py index 6f33de4e05..54bacd5e5f 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -195,7 +195,8 @@ def save( When ``filepath`` argument is file-like object, this argument is required. Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, - ``"amb"``, ``"flac"``, ``"sph"`` and ``"gsm"``. + ``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``. + encoding (str, optional): Changes the encoding for the supported formats. This argument is effective only for supported formats, such as ``"wav"``, ``"amb"`` and ``"sph"``. Valid values are; @@ -294,6 +295,9 @@ def save( ``"gsm"`` Lossy Speech Compression, CPU intensive. + ``"htk"`` + Uses a default single-channel 16-bit PCM format. + Note: To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``, ``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has diff --git a/torchaudio/csrc/sox/types.cpp b/torchaudio/csrc/sox/types.cpp index efffd14342..613058e171 100644 --- a/torchaudio/csrc/sox/types.cpp +++ b/torchaudio/csrc/sox/types.cpp @@ -20,6 +20,8 @@ Format get_format_from_string(const std::string& format) { return Format::AMB; if (format == "sph") return Format::SPHERE; + if (format == "htk") + return Format::HTK; if (format == "gsm") return Format::GSM; std::ostringstream stream; diff --git a/torchaudio/csrc/sox/types.h b/torchaudio/csrc/sox/types.h index 0b52aab905..577ec473f4 100644 --- a/torchaudio/csrc/sox/types.h +++ b/torchaudio/csrc/sox/types.h @@ -16,6 +16,7 @@ enum class Format { AMB, SPHERE, GSM, + HTK, }; Format get_format_from_string(const std::string& format); diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index d491a7c1ea..7d75863943 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -314,6 +314,13 @@ std::tuple get_save_encoding( throw std::runtime_error( "mp3 does not support 
`bits_per_sample` option."); return std::make_tuple<>(SOX_ENCODING_MP3, 16); + case Format::HTK: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("htk does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "htk does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); case Format::VORBIS: if (enc != Encoding::NOT_PROVIDED) throw std::runtime_error("vorbis does not support `encoding` option."); @@ -417,8 +424,12 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) { if (filetype == "amr-nb") { return 16; } - if (filetype == "gsm") + if (filetype == "gsm") { return 16; + } + if (filetype == "htk") { + return 16; + } throw std::runtime_error("Unsupported file type: " + filetype); }