From b8f3b257a37461429b6bbb3268c120688532d3ae Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 2 Feb 2021 13:10:50 -0800 Subject: [PATCH 1/3] Distinguish get_encodinginfo for Tensor I/O and save output --- torchaudio/csrc/sox/effects.cpp | 8 ++++---- torchaudio/csrc/sox/io.cpp | 8 ++++---- torchaudio/csrc/sox/utils.cpp | 9 ++++----- torchaudio/csrc/sox/utils.h | 9 ++++----- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/torchaudio/csrc/sox/effects.cpp b/torchaudio/csrc/sox/effects.cpp index 59fce659f8..eb4bf24549 100644 --- a/torchaudio/csrc/sox/effects.cpp +++ b/torchaudio/csrc/sox/effects.cpp @@ -60,8 +60,8 @@ std::tuple apply_effects_tensor( // Create SoxEffectsChain const auto dtype = waveform.dtype(); torchaudio::sox_effects_chain::SoxEffectsChain chain( - /*input_encoding=*/get_encodinginfo("wav", dtype), - /*output_encoding=*/get_encodinginfo("wav", dtype)); + /*input_encoding=*/get_tensor_encodinginfo(dtype), + /*output_encoding=*/get_tensor_encodinginfo(dtype)); // Prepare output buffer std::vector out_buffer; @@ -112,7 +112,7 @@ std::tuple apply_effects_file( // Create and run SoxEffectsChain torchaudio::sox_effects_chain::SoxEffectsChain chain( /*input_encoding=*/sf->encoding, - /*output_encoding=*/get_encodinginfo("wav", dtype)); + /*output_encoding=*/get_tensor_encodinginfo(dtype)); chain.addInputFile(sf); for (const auto& effect : effects) { @@ -214,7 +214,7 @@ std::tuple apply_effects_fileobj( const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); torchaudio::sox_effects_chain::SoxEffectsChain chain( /*input_encoding=*/sf->encoding, - /*output_encoding=*/get_encodinginfo("wav", dtype)); + /*output_encoding=*/get_tensor_encodinginfo(dtype)); chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj); for (const auto& effect : effects) { chain.addEffect(effect); diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index daf1ff102e..ce7db11d5a 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -175,7 +175,7 @@ void save_audio_file( } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); - const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression); + const auto encoding_info = get_encodinginfo_for_save(filetype, tgt_dtype, compression); SoxFormat sf(sox_open_write( path.c_str(), @@ -190,7 +190,7 @@ void save_audio_file( } torchaudio::sox_effects_chain::SoxEffectsChain chain( - /*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), + /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()), /*output_encoding=*/sf->encoding); chain.addInputTensor(&tensor, sample_rate, channels_first); chain.addOutputFile(sf); @@ -313,7 +313,7 @@ void save_audio_fileobj( } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); - const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression); + const auto encoding_info = get_encodinginfo_for_save(filetype, tgt_dtype, compression); AutoReleaseBuffer buffer; @@ -331,7 +331,7 @@ void save_audio_fileobj( } torchaudio::sox_effects_chain::SoxEffectsChain chain( - /*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), + /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()), /*output_encoding=*/sf->encoding); chain.addInputTensor(&tensor, sample_rate, channels_first); chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj); diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index b36f027bf4..9960e9fe74 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -291,12 +291,11 @@ sox_signalinfo_t get_signalinfo( /*length=*/static_cast(waveform->numel())}; } -sox_encodinginfo_t get_encodinginfo( - const std::string filetype, +sox_encodinginfo_t get_tensor_encodinginfo( const caffe2::TypeMeta dtype) { return sox_encodinginfo_t{ - /*encoding=*/get_encoding(filetype, dtype), - /*bits_per_sample=*/get_precision(filetype, dtype), + /*encoding=*/get_encoding("wav", dtype), + /*bits_per_sample=*/get_precision("wav", dtype), /*compression=*/HUGE_VAL, /*reverse_bytes=*/sox_option_default, /*reverse_nibbles=*/sox_option_default, @@ -304,7 +303,7 @@ sox_encodinginfo_t get_encodinginfo( /*opposite_endian=*/sox_false}; } -sox_encodinginfo_t get_encodinginfo( +sox_encodinginfo_t get_encodinginfo_for_save( const std::string filetype, const caffe2::TypeMeta dtype, c10::optional& compression) { diff --git a/torchaudio/csrc/sox/utils.h b/torchaudio/csrc/sox/utils.h index 57dad7e6dd..ea2a6a2953 100644 --- a/torchaudio/csrc/sox/utils.h +++ b/torchaudio/csrc/sox/utils.h @@ -108,12 +108,11 @@ sox_signalinfo_t get_signalinfo( const std::string filetype, const bool channels_first); -/// Get sox_encofinginfo_t for saving audoi file -sox_encodinginfo_t get_encodinginfo( - const std::string filetype, - const caffe2::TypeMeta dtype); +/// Get sox_encodinginfo_t for Tensor I/O +sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype); -sox_encodinginfo_t get_encodinginfo( +/// Get sox_encodinginfo_t for saving to file/file object +sox_encodinginfo_t get_encodinginfo_for_save( const std::string filetype, const caffe2::TypeMeta dtype, c10::optional& compression); From e8d6daf4bdc86137d50f270af35299264b89b09e Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 2 Feb 2021 13:25:09 -0800 Subject: [PATCH 2/3] Isolate get_tensor_encodinginfo --- torchaudio/csrc/sox/utils.cpp | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 9960e9fe74..23f6388783 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -293,9 +293,31 @@ sox_signalinfo_t get_signalinfo( sox_encodinginfo_t get_tensor_encodinginfo( const caffe2::TypeMeta dtype) { + sox_encoding_t encoding = [&](){ + if (dtype == torch::kUInt8) + return SOX_ENCODING_UNSIGNED; + if (dtype == torch::kInt16) + return SOX_ENCODING_SIGN2; + if (dtype == torch::kInt32) + return SOX_ENCODING_SIGN2; + if (dtype == torch::kFloat32) + return SOX_ENCODING_FLOAT; + throw std::runtime_error("Unsupported dtype."); + }(); + unsigned bits_per_sample = [&](){ + if (dtype == torch::kUInt8) + return 8; + if (dtype == torch::kInt16) + return 16; + if (dtype == torch::kInt32) + return 32; + if (dtype == torch::kFloat32) + return 32; + throw std::runtime_error("Unsupported dtype."); + }(); return sox_encodinginfo_t{ - /*encoding=*/get_encoding("wav", dtype), - /*bits_per_sample=*/get_precision("wav", dtype), + /*encoding=*/encoding, + /*bits_per_sample=*/bits_per_sample, /*compression=*/HUGE_VAL, /*reverse_bytes=*/sox_option_default, /*reverse_nibbles=*/sox_option_default, From dfc09211ca6a92dcd588b3ddd753c50a59d0f4ee Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 3 Feb 2021 16:16:31 +0000 Subject: [PATCH 3/3] Fix style --- torchaudio/csrc/sox/io.cpp | 6 ++++-- torchaudio/csrc/sox/utils.cpp | 7 +++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp index ce7db11d5a..ee0bf3c9dd 100644 --- a/torchaudio/csrc/sox/io.cpp +++ b/torchaudio/csrc/sox/io.cpp @@ -175,7 +175,8 @@ void save_audio_file( } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); - const auto encoding_info = get_encodinginfo_for_save(filetype, tgt_dtype, compression); + const auto encoding_info = + get_encodinginfo_for_save(filetype, tgt_dtype, compression); SoxFormat sf(sox_open_write( path.c_str(), @@ -313,7 +314,8 @@ void save_audio_fileobj( } const auto signal_info = get_signalinfo(&tensor, sample_rate, filetype, channels_first); - const auto encoding_info = get_encodinginfo_for_save(filetype, tgt_dtype, compression); + const auto encoding_info = + get_encodinginfo_for_save(filetype, tgt_dtype, compression); AutoReleaseBuffer buffer; diff --git a/torchaudio/csrc/sox/utils.cpp b/torchaudio/csrc/sox/utils.cpp index 23f6388783..983b7829fe 100644 --- a/torchaudio/csrc/sox/utils.cpp +++ b/torchaudio/csrc/sox/utils.cpp @@ -291,9 +291,8 @@ sox_signalinfo_t get_signalinfo( /*length=*/static_cast(waveform->numel())}; } -sox_encodinginfo_t get_tensor_encodinginfo( - const caffe2::TypeMeta dtype) { - sox_encoding_t encoding = [&](){ +sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) { + sox_encoding_t encoding = [&]() { if (dtype == torch::kUInt8) return SOX_ENCODING_UNSIGNED; if (dtype == torch::kInt16) @@ -304,7 +303,7 @@ sox_encodinginfo_t get_tensor_encodinginfo( return SOX_ENCODING_FLOAT; throw std::runtime_error("Unsupported dtype."); }(); - unsigned bits_per_sample = [&](){ + unsigned bits_per_sample = [&]() { if (dtype == torch::kUInt8) return 8; if (dtype == torch::kInt16)