diff --git a/torchaudio/csrc/sox/effects.cpp b/torchaudio/csrc/sox/effects.cpp index 59fce659f8..eb4bf24549 100644 --- a/torchaudio/csrc/sox/effects.cpp +++ b/torchaudio/csrc/sox/effects.cpp @@ -60,8 +60,8 @@ std::tuple apply_effects_tensor( // Create SoxEffectsChain const auto dtype = waveform.dtype(); torchaudio::sox_effects_chain::SoxEffectsChain chain( - /*input_encoding=*/get_encodinginfo("wav", dtype), - /*output_encoding=*/get_encodinginfo("wav", dtype)); + /*input_encoding=*/get_tensor_encodinginfo(dtype), + /*output_encoding=*/get_tensor_encodinginfo(dtype)); // Prepare output buffer std::vector out_buffer; @@ -112,7 +112,7 @@ std::tuple apply_effects_file( // Create and run SoxEffectsChain torchaudio::sox_effects_chain::SoxEffectsChain chain( /*input_encoding=*/sf->encoding, - /*output_encoding=*/get_encodinginfo("wav", dtype)); + /*output_encoding=*/get_tensor_encodinginfo(dtype)); chain.addInputFile(sf); for (const auto& effect : effects) { @@ -214,7 +214,7 @@ std::tuple apply_effects_fileobj( const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); torchaudio::sox_effects_chain::SoxEffectsChain chain( /*input_encoding=*/sf->encoding, - /*output_encoding=*/get_encodinginfo("wav", dtype)); + /*output_encoding=*/get_tensor_encodinginfo(dtype)); chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj); for (const auto& effect : effects) { chain.addEffect(effect); diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index daf1ff102e..ee0bf3c9dd 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -175,7 +175,8 @@ void save_audio_file( } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); - const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression); + const auto encoding_info = + get_encodinginfo_for_save(filetype, tgt_dtype, compression); SoxFormat sf(sox_open_write( path.c_str(), @@ -190,7 +191,7 @@ void save_audio_file( } torchaudio::sox_effects_chain::SoxEffectsChain chain( - /*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), + /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()), /*output_encoding=*/sf->encoding); chain.addInputTensor(&tensor, sample_rate, channels_first); chain.addOutputFile(sf); @@ -313,7 +314,8 @@ void save_audio_fileobj( } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); - const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression); + const auto encoding_info = + get_encodinginfo_for_save(filetype, tgt_dtype, compression); AutoReleaseBuffer buffer; @@ -331,7 +333,7 @@ void save_audio_fileobj( } torchaudio::sox_effects_chain::SoxEffectsChain chain( - /*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), + /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()), /*output_encoding=*/sf->encoding); chain.addInputTensor(&tensor, sample_rate, channels_first); chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj); diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index b36f027bf4..983b7829fe 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -291,12 +291,32 @@ sox_signalinfo_t get_signalinfo( /*length=*/static_cast(waveform->numel())}; } -sox_encodinginfo_t get_encodinginfo( - const std::string filetype, - const caffe2::TypeMeta dtype) { +sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) { + sox_encoding_t encoding = [&]() { + if (dtype == torch::kUInt8) + return SOX_ENCODING_UNSIGNED; + if (dtype == torch::kInt16) + return SOX_ENCODING_SIGN2; + if (dtype == torch::kInt32) + return SOX_ENCODING_SIGN2; + if (dtype == torch::kFloat32) + return SOX_ENCODING_FLOAT; + throw std::runtime_error("Unsupported dtype."); + }(); + unsigned bits_per_sample = [&]() { + if (dtype == torch::kUInt8) + return 8; + if (dtype == torch::kInt16) + return 16; + if (dtype == torch::kInt32) + return 32; + if (dtype == torch::kFloat32) + return 32; + throw std::runtime_error("Unsupported dtype."); + }(); return sox_encodinginfo_t{ - /*encoding=*/get_encoding(filetype, dtype), - /*bits_per_sample=*/get_precision(filetype, dtype), + /*encoding=*/encoding, + /*bits_per_sample=*/bits_per_sample, /*compression=*/HUGE_VAL, /*reverse_bytes=*/sox_option_default, /*reverse_nibbles=*/sox_option_default, @@ -304,7 +324,7 @@ sox_encodinginfo_t get_encodinginfo( /*opposite_endian=*/sox_false}; } -sox_encodinginfo_t get_encodinginfo( +sox_encodinginfo_t get_encodinginfo_for_save( const std::string filetype, const caffe2::TypeMeta dtype, c10::optional& compression) { diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h index 57dad7e6dd..ea2a6a2953 100644 --- a/torchaudio/csrc/sox/utils.h +++ b/torchaudio/csrc/sox/utils.h @@ -108,12 +108,11 @@ sox_signalinfo_t get_signalinfo( const std::string filetype, const bool channels_first); -/// Get sox_encofinginfo_t for saving audoi file -sox_encodinginfo_t get_encodinginfo( - const std::string filetype, - const caffe2::TypeMeta dtype); +/// Get sox_encodinginfo_t for Tensor I/O +sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype); -sox_encodinginfo_t get_encodinginfo( +/// Get sox_encodinginfo_t for saving to file/file object +sox_encodinginfo_t get_encodinginfo_for_save( const std::string filetype, const caffe2::TypeMeta dtype, c10::optional& compression);