diff --git a/test/torchaudio_unittest/backend/sox_io/info_test.py b/test/torchaudio_unittest/backend/sox_io/info_test.py index d0b9f24fcb..000acd4fa0 100644 --- a/test/torchaudio_unittest/backend/sox_io/info_test.py +++ b/test/torchaudio_unittest/backend/sox_io/info_test.py @@ -205,7 +205,7 @@ def test_ulaw(self): assert info.encoding == "ULAW" def test_alaw(self): - """`sox_io_backend.info` can check ulaw file correctly""" + """`sox_io_backend.info` can check alaw file correctly""" duration = 1 num_channels = 1 sample_rate = 8000 @@ -221,6 +221,22 @@ def test_alaw(self): assert info.bits_per_sample == 8 assert info.encoding == "ALAW" + def test_htk(self): + """`sox_io_backend.info` can check HTK file correctly""" + duration = 1 + num_channels = 1 + sample_rate = 8000 + path = self.get_temp_path('data.htk') + sox_utils.gen_audio_file( + path, sample_rate=sample_rate, num_channels=num_channels, + bit_depth=16, duration=duration) + info = sox_io_backend.info(path) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == 16 + assert info.encoding == "PCM_S" + def test_gsm(self): """`sox_io_backend.info` can check gsm file correctly""" duration = 1 diff --git a/test/torchaudio_unittest/backend/sox_io/save_test.py b/test/torchaudio_unittest/backend/sox_io/save_test.py index 5d3fdb03ca..da9365a24a 100644 --- a/test/torchaudio_unittest/backend/sox_io/save_test.py +++ b/test/torchaudio_unittest/backend/sox_io/save_test.py @@ -317,6 +317,12 @@ def test_save_gsm(self, test_mode): self.assert_save_consistency( "gsm", test_mode=test_mode) + @nested_params( + ["path", "fileobj", "bytesio"], + ) + def test_save_htk(self, test_mode): + self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1) + @parameterized.expand([ ("wav", "PCM_S", 16), ("mp3", ), diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 91b4474a9d..036aa5f4ac 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -195,7 +195,7 @@ def save( When ``filepath`` argument is file-like object, this argument is required. Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, - ``"amb"``, ``"flac"``, ``"sph"`` and ``"gsm"``. + ``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``. encoding (str, optional): Changes the encoding for the supported formats. This argument is effective only for supported formats, such as ``"wav"``, ``""amb"`` and ``"sph"``. Valid values are; @@ -294,6 +294,9 @@ def save( ``"gsm"`` Lossy Speech Compression, CPU intensive. + ``"htk"`` + Uses its default single-channel 16-bit PCM format. + Note: To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``, ``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index 729c64666c..42edf105e4 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -97,6 +97,10 @@ void save_audio_file( const auto num_channels = tensor.size(channels_first ? 0 : 1); TORCH_CHECK( num_channels == 1, "amr-nb format only supports single channel audio."); + } else if (filetype == "htk") { + const auto num_channels = tensor.size(channels_first ? 0 : 1); + TORCH_CHECK( + num_channels == 1, "htk format only supports single channel audio."); } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); @@ -233,6 +237,12 @@ void save_audio_fileobj( throw std::runtime_error( "amr-nb format only supports single channel audio."); } + } else if (filetype == "htk") { + const auto num_channels = tensor.size(channels_first ? 0 : 1); + if (num_channels != 1) { + throw std::runtime_error( + "htk format only supports single channel audio."); + } } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); diff --git a/torchaudio/csrc/sox/types.cpp b/torchaudio/csrc/sox/types.cpp index d70fe218f0..b73b34b1fc 100644 --- a/torchaudio/csrc/sox/types.cpp +++ b/torchaudio/csrc/sox/types.cpp @@ -20,6 +20,8 @@ Format get_format_from_string(const std::string& format) { return Format::AMB; if (format == "sph") return Format::SPHERE; + if (format == "htk") + return Format::HTK; if (format == "gsm") return Format::GSM; std::ostringstream stream; diff --git a/torchaudio/csrc/sox/types.h b/torchaudio/csrc/sox/types.h index b13510cde1..2f29bd1c00 100644 --- a/torchaudio/csrc/sox/types.h +++ b/torchaudio/csrc/sox/types.h @@ -17,6 +17,7 @@ enum class Format { AMB, SPHERE, GSM, + HTK, }; Format get_format_from_string(const std::string& format); diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 71bca54b7e..83bb31bc9c 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -314,6 +314,13 @@ std::tuple get_save_encoding( throw std::runtime_error( "mp3 does not support `bits_per_sample` option."); return std::make_tuple<>(SOX_ENCODING_MP3, 16); + case Format::HTK: + if (enc != Encoding::NOT_PROVIDED) + throw std::runtime_error("htk does not support `encoding` option."); + if (bps != BitDepth::NOT_PROVIDED) + throw std::runtime_error( + "htk does not support `bits_per_sample` option."); + return std::make_tuple<>(SOX_ENCODING_SIGN2, 16); case Format::VORBIS: if (enc != Encoding::NOT_PROVIDED) throw std::runtime_error("vorbis does not support `encoding` option."); @@ -417,8 +424,12 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) { if (filetype == "amr-nb") { return 16; } - if (filetype == "gsm") + if (filetype == "gsm") { return 16; + } + if (filetype == "htk") { + return 16; + } throw std::runtime_error("Unsupported file type: " + filetype); }