1 change: 1 addition & 0 deletions torchaudio/csrc/CMakeLists.txt
@@ -10,6 +10,7 @@ set(
  sox/effects.cpp
  sox/effects_chain.cpp
  sox/types.cpp
  lfilter.cpp
)

if(BUILD_TRANSDUCER)
71 changes: 71 additions & 0 deletions torchaudio/csrc/lfilter.cpp
@@ -0,0 +1,71 @@
#include <torch/script.h>

namespace {

template <typename scalar_t>
void host_lfilter_core_loop(
    const torch::Tensor& input_signal_windows,
    const torch::Tensor& a_coeff_flipped,
    torch::Tensor& padded_output_waveform) {
  int64_t n_channel = input_signal_windows.size(0);
  int64_t n_samples_input = input_signal_windows.size(1);
  int64_t n_samples_output = padded_output_waveform.size(1);
  int64_t n_order = a_coeff_flipped.size(0);
  scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
  const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
  const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
  for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
    for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
      int64_t offset_input = i_channel * n_samples_input;
      int64_t offset_output = i_channel * n_samples_output;
      scalar_t a0 = input_data[offset_input + i_sample];
      for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
        a0 -= output_data[offset_output + i_sample + i_coeff] *
            a_coeff_flipped_data[i_coeff];
      }
      output_data[offset_output + i_sample + n_order - 1] = a0;
    }
  }
}

void cpu_lfilter_core_loop(
    const torch::Tensor& input_signal_windows,
    const torch::Tensor& a_coeff_flipped,
    torch::Tensor& padded_output_waveform) {
  TORCH_CHECK(
      input_signal_windows.device().is_cpu() &&
      a_coeff_flipped.device().is_cpu() &&
      padded_output_waveform.device().is_cpu());

  TORCH_CHECK(
      input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
      padded_output_waveform.is_contiguous());

  TORCH_CHECK(
      (input_signal_windows.dtype() == torch::kFloat32 ||
       input_signal_windows.dtype() == torch::kFloat64) &&
      (a_coeff_flipped.dtype() == torch::kFloat32 ||
       a_coeff_flipped.dtype() == torch::kFloat64) &&
      (padded_output_waveform.dtype() == torch::kFloat32 ||
       padded_output_waveform.dtype() == torch::kFloat64));

  TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));

  TORCH_CHECK(
      input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 ==
      padded_output_waveform.size(1));

  AT_DISPATCH_FLOATING_TYPES(
      input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
        host_lfilter_core_loop<scalar_t>(
            input_signal_windows, a_coeff_flipped, padded_output_waveform);
      });
}

} // namespace

// Note: We want to avoid using "catch-all" kernel.
// The following registration should be replaced with CPU specific registration.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
}
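
For reference, a minimal sketch of how the op registered above can be exercised directly from Python (assuming the extension has been built and loaded via importing torchaudio; the coefficient values are illustrative). Per the TORCH_CHECKs, all three tensors are contiguous CPU float32/float64 tensors, a_coeff_flipped holds the denominator coefficients flipped and normalized by a[0], and the zero-initialized output carries n_order - 1 extra samples of padding:

import torch
import torchaudio  # noqa: F401  # loads the extension that registers torchaudio::_lfilter_core_loop

n_channel, n_sample, n_order = 2, 1024, 3
x = torch.rand(n_channel, n_sample)                         # b-filtered, a[0]-normalized input windows
a_flipped = torch.tensor([0.25, -0.5, 1.0])                 # flipped denominator coeffs; last entry is a[0]/a[0]
out = torch.zeros(n_channel, n_sample + n_order - 1)        # zero padding supplies the initial conditions

torch.ops.torchaudio._lfilter_core_loop(x, a_flipped, out)  # mutates `out` in place
y = out[:, n_order - 1:]                                    # drop the padding, as lfilter() does below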
30 changes: 24 additions & 6 deletions torchaudio/functional/filtering.py
@@ -808,6 +808,23 @@ def highpass_biquad(
    return biquad(waveform, b0, b1, b2, a0, a1, a2)


def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor):
    n_order = a_coeffs_flipped.size(0)
    for i_sample, o0 in enumerate(input_signal_windows.t()):
        windowed_output_signal = padded_output_waveform[
            :, i_sample:i_sample + n_order
        ]
        o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
        padded_output_waveform[:, i_sample + n_order - 1] = o0


try:
    _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
except RuntimeError as err:
    assert str(err) == 'No such operator torchaudio::_lfilter_core_loop'
    _lfilter_core_cpu_loop = _lfilter_core_generic_loop
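
Both loop implementations above (the C++ CPU kernel and the generic fallback) realize the same feedback recursion. A minimal single-channel sketch of that recursion, with hypothetical names and the coefficients already normalized by a[0] (here in natural, unflipped order):

def scalar_feedback_loop(x, a):
    # x: b-filtered input samples; a: denominator coeffs with a[0] == 1 after normalization
    y = [0.0] * len(x)
    for n in range(len(x)):
        acc = x[n]
        for k in range(1, len(a)):
            if n - k >= 0:                # the zero padding in the tensor versions covers n - k < 0
                acc -= a[k] * y[n - k]
        y[n] = acc                        # y[n] = x[n] - sum_{k>=1} a[k] * y[n-k]
    return y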


def lfilter(
        waveform: Tensor,
        a_coeffs: Tensor,
@@ -877,12 +894,13 @@ def lfilter(

    input_signal_windows.div_(a_coeffs[0])
    a_coeffs_flipped.div_(a_coeffs[0])
    for i_sample, o0 in enumerate(input_signal_windows.t()):
        windowed_output_signal = padded_output_waveform[
            :, i_sample:i_sample + n_order
        ]
        o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
        padded_output_waveform[:, i_sample + n_order - 1] = o0

    if input_signal_windows.device == torch.device('cpu') and\
            a_coeffs_flipped.device == torch.device('cpu') and\
            padded_output_waveform.device == torch.device('cpu'):
        _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
    else:
        _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)

    output = padded_output_waveform[:, n_order - 1:]

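Nothing changes at the Python API level. A quick usage sketch (keyword names and defaults such as clamp=True may differ slightly by release):

import torch
import torchaudio.functional as F

waveform = torch.rand(2, 16000) * 2 - 1          # (n_channel, n_frames) in [-1, 1)
a_coeffs = torch.tensor([1.0, -0.75])            # one-pole low-pass: y[n] = 0.25*x[n] + 0.75*y[n-1]
b_coeffs = torch.tensor([0.25, 0.0])

filtered = F.lfilter(waveform, a_coeffs, b_coeffs)
print(filtered.shape)                            # torch.Size([2, 16000])
# On CPU tensors this uses _lfilter_core_cpu_loop (the C++ op when available);
# other devices use _lfilter_core_generic_loop.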