Skip to content
Merged
14 changes: 7 additions & 7 deletions torchaudio/csrc/lfilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ void host_lfilter_core_loop(
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
for (int64_t i_batch = 0; i_batch < n_batch; i_batch++) {
for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {

at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
for (auto i = begin; i < end; i++) {
int64_t offset_input = i * n_samples_input;
int64_t offset_output = i * n_samples_output;
int64_t i_channel = i % n_channel;
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
int64_t offset_input =
((i_batch * n_channel) + i_channel) * n_samples_input;
int64_t offset_output =
((i_batch * n_channel) + i_channel) * n_samples_output;
scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] *
Expand All @@ -31,7 +31,7 @@ void host_lfilter_core_loop(
output_data[offset_output + i_sample + n_order - 1] = a0;
}
}
}
});
}

void cpu_lfilter_core_loop(
Expand Down