Skip to content

Commit b5cd948

Browse files
committed
Replace save function with sox effects chain
1 parent c7520eb commit b5cd948

File tree

3 files changed

+66
-21
lines changed

3 files changed

+66
-21
lines changed

torchaudio/csrc/sox_effects_chain.cpp

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ struct TensorInputPriv {
4646
struct TensorOutputPriv {
4747
std::vector<sox_sample_t>* buffer;
4848
};
49+
/// Private data of the "output_file" effect.
/// Holds the sox_format_t handle that file_output_flow hands to sox_write.
/// NOTE(review): the pointer looks non-owning (nothing in this file closes
/// it) -- confirm against the caller that opened the file.
struct FileOutputPriv {
50+
sox_format_t* sf;
51+
};
4952

5053
/// Callback function to feed Tensor data to SoxEffectChain.
5154
int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
@@ -84,7 +87,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
8487

8588
/// Callback function to fetch data from SoxEffectChain.
8689
int tensor_output_flow(
87-
sox_effect_t* effp LSX_UNUSED,
90+
sox_effect_t* effp,
8891
sox_sample_t const* ibuf,
8992
sox_sample_t* obuf LSX_UNUSED,
9093
size_t* isamp,
@@ -97,6 +100,28 @@ int tensor_output_flow(
97100
return SOX_SUCCESS;
98101
}
99102

103+
/// Callback function to write data coming out of SoxEffectChain to a file.
///
/// Reads the destination sox_format_t from the effect's private data
/// (FileOutputPriv) and writes the `*isamp` samples in `ibuf` with sox_write.
/// `*osamp` is always set to 0 and `obuf` is never touched.
///
/// Returns SOX_SUCCESS when there was nothing to write or every sample was
/// written, and SOX_EOF when sox_write wrote fewer samples than requested
/// without setting `sox_errno`.
/// Throws std::runtime_error when the write is short and `sox_errno` is set;
/// the message is built from `sox_errstr`, sox_strerror(sox_errno) and the
/// file name.
/// NOTE(review): the exception is raised from inside a callback that is
/// registered in a sox_effect_handler_t -- confirm that unwinding through the
/// code that invokes the callback is safe on all supported platforms.
int file_output_flow(
104+
sox_effect_t* effp,
105+
sox_sample_t const* ibuf,
106+
sox_sample_t* obuf LSX_UNUSED,
107+
size_t* isamp,
108+
size_t* osamp) {
109+
*osamp = 0; // no samples are reported back through obuf
110+
if (*isamp) {
111+
auto sf = static_cast<FileOutputPriv*>(effp->priv)->sf;
112+
if (sox_write(sf, ibuf, *isamp) != *isamp) { // short write
113+
if (sf->sox_errno) {
114+
std::ostringstream stream;
115+
stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " "
116+
<< sf->filename;
117+
throw std::runtime_error(stream.str());
118+
}
119+
return SOX_EOF; // short write but no error code recorded on the handle
120+
}
121+
}
122+
return SOX_SUCCESS;
123+
}
124+
100125
sox_effect_handler_t* get_tensor_input_handler() {
101126
static sox_effect_handler_t handler{/*name=*/"input_tensor",
102127
/*usage=*/NULL,
@@ -125,6 +150,20 @@ sox_effect_handler_t* get_tensor_output_handler() {
125150
return &handler;
126151
}
127152

153+
/// Returns the handler describing the "output_file" effect.
///
/// The handler is a function-local static, so every call returns a pointer to
/// the same object. Only the `flow` callback (file_output_flow) is provided;
/// all other callbacks are NULL. `priv_size` is sizeof(FileOutputPriv), the
/// type file_output_flow casts the effect's private data to.
/// NOTE(review): SOX_EFF_MCHAN is presumably set because the effect handles
/// all channels itself -- confirm against sox.h.
sox_effect_handler_t* get_file_output_handler() {
154+
static sox_effect_handler_t handler{/*name=*/"output_file",
155+
/*usage=*/NULL,
156+
/*flags=*/SOX_EFF_MCHAN,
157+
/*getopts=*/NULL,
158+
/*start=*/NULL,
159+
/*flow=*/file_output_flow,
160+
/*drain=*/NULL,
161+
/*stop=*/NULL,
162+
/*kill=*/NULL,
163+
/*priv_size=*/sizeof(FileOutputPriv)};
164+
return &handler;
165+
}
166+
128167
} // namespace
129168

130169
SoxEffectsChain::SoxEffectsChain(
@@ -134,6 +173,7 @@ SoxEffectsChain::SoxEffectsChain(
134173
out_enc_(output_encoding),
135174
in_sig_(),
136175
interm_sig_(),
176+
out_sig_(),
137177
sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) {
138178
if (!sec_) {
139179
throw std::runtime_error("Failed to create effect chain.");
@@ -184,6 +224,17 @@ void SoxEffectsChain::addInputFile(sox_format_t* sf) {
184224
}
185225
}
186226

227+
/// Adds the "output_file" effect, which writes the chain's output to `sf`.
///
/// Copies `sf->signal` into `out_sig_`, creates an effect from
/// get_file_output_handler(), stores `sf` in the effect's FileOutputPriv and
/// adds the effect to the chain with `interm_sig_` as input signal info and
/// `out_sig_` as output signal info.
///
/// Throws std::runtime_error if sox_add_effect does not return SOX_SUCCESS.
/// NOTE(review): `sf` is stored as a raw pointer, so it presumably has to
/// stay open until the chain has finished running -- confirm with callers.
void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
228+
out_sig_ = sf->signal;
229+
SoxEffect e(sox_create_effect(get_file_output_handler()));
230+
static_cast<FileOutputPriv*>(e->priv)->sf = sf; // read back in file_output_flow
231+
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) {
232+
std::ostringstream stream;
233+
stream << "Failed to add effect: output " << sf->filename;
234+
throw std::runtime_error(stream.str());
235+
}
236+
}
237+
187238
void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
188239
const auto num_args = effect.size();
189240
if (num_args == 0) {

torchaudio/csrc/sox_effects_chain.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class SoxEffectsChain {
1414
const sox_encodinginfo_t out_enc_;
1515
sox_signalinfo_t in_sig_;
1616
sox_signalinfo_t interm_sig_;
17+
sox_signalinfo_t out_sig_;
1718
sox_effects_chain_t* sec_;
1819

1920
public:
@@ -29,6 +30,7 @@ class SoxEffectsChain {
2930
void addInputTensor(torchaudio::sox_utils::TensorSignal* signal);
3031
void addInputFile(sox_format_t* sf);
3132
void addOutputBuffer(std::vector<sox_sample_t>* output_buffer);
33+
void addOutputFile(sox_format_t* sf);
3234
void addEffect(const std::vector<std::string> effect);
3335
int64_t getOutputNumChannels();
3436
int64_t getOutputSampleRate();

torchaudio/csrc/sox_io.cpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#include <sox.h>
2+
#include <torchaudio/csrc/sox_effects.h>
3+
#include <torchaudio/csrc/sox_effects_chain.h>
24
#include <torchaudio/csrc/sox_io.h>
35
#include <torchaudio/csrc/sox_utils.h>
4-
#include <torchaudio/csrc/sox_effects.h>
56

67
using namespace torch::indexing;
78
using namespace torchaudio::sox_utils;
@@ -66,22 +67,23 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
6667
std::ostringstream offset, frames;
6768
offset << frame_offset << "s";
6869
frames << "+" << num_frames << "s";
69-
effects.emplace_back(std::vector<std::string>{"trim", offset.str(), frames.str()});
70+
effects.emplace_back(
71+
std::vector<std::string>{"trim", offset.str(), frames.str()});
7072
} else if (frame_offset != 0) {
7173
std::ostringstream offset;
7274
offset << frame_offset << "s";
7375
effects.emplace_back(std::vector<std::string>{"trim", offset.str()});
7476
}
7577

76-
return torchaudio::sox_effects::apply_effects_file(path, effects, normalize, channels_first);
78+
return torchaudio::sox_effects::apply_effects_file(
79+
path, effects, normalize, channels_first);
7780
}
7881

7982
void save_audio_file(
8083
const std::string& file_name,
8184
const c10::intrusive_ptr<TensorSignal>& signal,
8285
const double compression) {
8386
const auto tensor = signal->getTensor();
84-
const auto channels_first = signal->getChannelsFirst();
8587

8688
validate_input_tensor(tensor);
8789

@@ -102,22 +104,12 @@ void save_audio_file(
102104
throw std::runtime_error("Error saving audio file: failed to open file.");
103105
}
104106

105-
auto tensor_ = tensor;
106-
if (channels_first) {
107-
tensor_ = tensor_.t();
108-
}
109-
110-
const int64_t frames_per_chunk = 65536;
111-
for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) {
112-
auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
113-
chunk = unnormalize_wav(chunk).contiguous();
114-
115-
const size_t numel = chunk.numel();
116-
if (sox_write(sf, chunk.data_ptr<int32_t>(), numel) != numel) {
117-
throw std::runtime_error(
118-
"Error saving audio file: failed to write the entier buffer.");
119-
}
120-
}
107+
torchaudio::sox_effects_chain::SoxEffectsChain chain(
108+
/*input_encoding=*/get_encodinginfo("wav", tensor.dtype(), 0.),
109+
/*output_encoding=*/sf->encoding);
110+
chain.addInputTensor(signal.get());
111+
chain.addOutputFile(sf);
112+
chain.run();
121113
}
122114

123115
} // namespace sox_io

0 commit comments

Comments
 (0)