Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion torchaudio/csrc/sox_effects_chain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ struct TensorInputPriv {
struct TensorOutputPriv {
std::vector<sox_sample_t>* buffer;
};
struct FileOutputPriv {
sox_format_t* sf;
};

/// Callback function to feed Tensor data to SoxEffectChain.
int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
Expand Down Expand Up @@ -84,7 +87,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {

/// Callback function to fetch data from SoxEffectChain.
int tensor_output_flow(
sox_effect_t* effp LSX_UNUSED,
sox_effect_t* effp,
sox_sample_t const* ibuf,
sox_sample_t* obuf LSX_UNUSED,
size_t* isamp,
Expand All @@ -97,6 +100,28 @@ int tensor_output_flow(
return SOX_SUCCESS;
}

int file_output_flow(
sox_effect_t* effp,
sox_sample_t const* ibuf,
sox_sample_t* obuf LSX_UNUSED,
size_t* isamp,
size_t* osamp) {
*osamp = 0;
if (*isamp) {
auto sf = static_cast<FileOutputPriv*>(effp->priv)->sf;
if (sox_write(sf, ibuf, *isamp) != *isamp) {
if (sf->sox_errno) {
std::ostringstream stream;
stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " "
<< sf->filename;
throw std::runtime_error(stream.str());
}
return SOX_EOF;
}
}
return SOX_SUCCESS;
}

sox_effect_handler_t* get_tensor_input_handler() {
static sox_effect_handler_t handler{/*name=*/"input_tensor",
/*usage=*/NULL,
Expand Down Expand Up @@ -125,6 +150,20 @@ sox_effect_handler_t* get_tensor_output_handler() {
return &handler;
}

sox_effect_handler_t* get_file_output_handler() {
static sox_effect_handler_t handler{/*name=*/"output_file",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
/*start=*/NULL,
/*flow=*/file_output_flow,
/*drain=*/NULL,
/*stop=*/NULL,
/*kill=*/NULL,
/*priv_size=*/sizeof(FileOutputPriv)};
return &handler;
}

} // namespace

SoxEffectsChain::SoxEffectsChain(
Expand All @@ -134,6 +173,7 @@ SoxEffectsChain::SoxEffectsChain(
out_enc_(output_encoding),
in_sig_(),
interm_sig_(),
out_sig_(),
sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) {
if (!sec_) {
throw std::runtime_error("Failed to create effect chain.");
Expand Down Expand Up @@ -184,6 +224,17 @@ void SoxEffectsChain::addInputFile(sox_format_t* sf) {
}
}

void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
out_sig_ = sf->signal;
SoxEffect e(sox_create_effect(get_file_output_handler()));
static_cast<FileOutputPriv*>(e->priv)->sf = sf;
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Failed to add effect: output " << sf->filename;
throw std::runtime_error(stream.str());
}
}

void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
const auto num_args = effect.size();
if (num_args == 0) {
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/csrc/sox_effects_chain.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class SoxEffectsChain {
const sox_encodinginfo_t out_enc_;
sox_signalinfo_t in_sig_;
sox_signalinfo_t interm_sig_;
sox_signalinfo_t out_sig_;
sox_effects_chain_t* sec_;

public:
Expand All @@ -29,6 +30,7 @@ class SoxEffectsChain {
void addInputTensor(torchaudio::sox_utils::TensorSignal* signal);
void addInputFile(sox_format_t* sf);
void addOutputBuffer(std::vector<sox_sample_t>* output_buffer);
void addOutputFile(sox_format_t* sf);
void addEffect(const std::vector<std::string> effect);
int64_t getOutputNumChannels();
int64_t getOutputSampleRate();
Expand Down
94 changes: 21 additions & 73 deletions torchaudio/csrc/sox_io.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include <sox.h>
#include <torchaudio/csrc/sox_effects.h>
#include <torchaudio/csrc/sox_effects_chain.h>
#include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h>

Expand Down Expand Up @@ -60,72 +62,28 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
"Invalid argument: num_frames must be -1 or greater than 0.");
}

SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/nullptr));

validate_input_file(sf);

const int64_t num_channels = sf->signal.channels;
const int64_t num_total_samples = sf->signal.length;
const int64_t sample_start = sf->signal.channels * frame_offset;

if (sox_seek(sf, sample_start, 0) == SOX_EOF) {
throw std::runtime_error("Error reading audio file: offset past EOF.");
std::vector<std::vector<std::string>> effects;
if (num_frames != -1) {
std::ostringstream offset, frames;
offset << frame_offset << "s";
frames << "+" << num_frames << "s";
effects.emplace_back(
std::vector<std::string>{"trim", offset.str(), frames.str()});
} else if (frame_offset != 0) {
std::ostringstream offset;
offset << frame_offset << "s";
effects.emplace_back(std::vector<std::string>{"trim", offset.str()});
}

const int64_t sample_end = [&]() {
if (num_frames == -1)
return num_total_samples;
const int64_t sample_end_ = num_channels * num_frames + sample_start;
if (num_total_samples < sample_end_) {
// For lossy encoding, it is difficult to predict exact size of buffer for
// reading the number of samples required.
// So we allocate buffer size of given `num_frames` and ask sox to read as
// much as possible. For lossless format, sox reads exact number of
// samples, but for lossy encoding, sox can end up reading less. (i.e.
// mp3) For the consistent behavior specification between lossy/lossless
// format, we allow users to provide `num_frames` value that exceeds #of
// available samples, and we adjust it here.
return num_total_samples;
}
return sample_end_;
}();

const int64_t max_samples = sample_end - sample_start;

// Read samples into buffer
std::vector<sox_sample_t> buffer;
buffer.reserve(max_samples);
const int64_t num_samples = sox_read(sf, buffer.data(), max_samples);
if (num_samples == 0) {
throw std::runtime_error(
"Error reading audio file: empty file or read operation failed.");
}
// NOTE: num_samples may be smaller than max_samples if the input
// format is compressed (i.e. mp3).

// Convert to Tensor
auto tensor = convert_to_tensor(
buffer.data(),
num_samples,
num_channels,
get_dtype(sf->encoding.encoding, sf->signal.precision),
normalize,
channels_first);

return c10::make_intrusive<TensorSignal>(
tensor, static_cast<int64_t>(sf->signal.rate), channels_first);
return torchaudio::sox_effects::apply_effects_file(
path, effects, normalize, channels_first);
}

void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<TensorSignal>& signal,
const double compression) {
const auto tensor = signal->getTensor();
const auto channels_first = signal->getChannelsFirst();

validate_input_tensor(tensor);

Expand All @@ -146,22 +104,12 @@ void save_audio_file(
throw std::runtime_error("Error saving audio file: failed to open file.");
}

auto tensor_ = tensor;
if (channels_first) {
tensor_ = tensor_.t();
}

const int64_t frames_per_chunk = 65536;
for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) {
auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
chunk = unnormalize_wav(chunk).contiguous();

const size_t numel = chunk.numel();
if (sox_write(sf, chunk.data_ptr<int32_t>(), numel) != numel) {
throw std::runtime_error(
"Error saving audio file: failed to write the entier buffer.");
}
}
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", tensor.dtype(), 0.),
/*output_encoding=*/sf->encoding);
chain.addInputTensor(signal.get());
chain.addOutputFile(sf);
chain.run();
}

} // namespace sox_io
Expand Down