diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt
index dff7a8368a..1bab67be5a 100644
--- a/torchaudio/csrc/CMakeLists.txt
+++ b/torchaudio/csrc/CMakeLists.txt
@@ -10,6 +10,7 @@ set(
   sox/effects.cpp
   sox/effects_chain.cpp
   sox/types.cpp
+  lfilter.cpp
 )
 
 if(BUILD_TRANSDUCER)
diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp
new file mode 100644
index 0000000000..af9425fd15
--- /dev/null
+++ b/torchaudio/csrc/lfilter.cpp
@@ -0,0 +1,71 @@
+#include <torch/script.h>
+
+namespace {
+
+template <typename scalar_t>
+void host_lfilter_core_loop(
+    const torch::Tensor& input_signal_windows,
+    const torch::Tensor& a_coeff_flipped,
+    torch::Tensor& padded_output_waveform) {
+  int64_t n_channel = input_signal_windows.size(0);
+  int64_t n_samples_input = input_signal_windows.size(1);
+  int64_t n_samples_output = padded_output_waveform.size(1);
+  int64_t n_order = a_coeff_flipped.size(0);
+  scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
+  const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
+  const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
+  for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
+    for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
+      int64_t offset_input = i_channel * n_samples_input;
+      int64_t offset_output = i_channel * n_samples_output;
+      scalar_t a0 = input_data[offset_input + i_sample];
+      for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
+        a0 -= output_data[offset_output + i_sample + i_coeff] *
+            a_coeff_flipped_data[i_coeff];
+      }
+      output_data[offset_output + i_sample + n_order - 1] = a0;
+    }
+  }
+}
+
+void cpu_lfilter_core_loop(
+    const torch::Tensor& input_signal_windows,
+    const torch::Tensor& a_coeff_flipped,
+    torch::Tensor& padded_output_waveform) {
+  TORCH_CHECK(
+      input_signal_windows.device().is_cpu() &&
+      a_coeff_flipped.device().is_cpu() &&
+      padded_output_waveform.device().is_cpu());
+
+  TORCH_CHECK(
+      input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
+      padded_output_waveform.is_contiguous());
+
+  TORCH_CHECK(
+      (input_signal_windows.dtype() == torch::kFloat32 ||
+       input_signal_windows.dtype() == torch::kFloat64) &&
+      (a_coeff_flipped.dtype() == torch::kFloat32 ||
+       a_coeff_flipped.dtype() == torch::kFloat64) &&
+      (padded_output_waveform.dtype() == torch::kFloat32 ||
+       padded_output_waveform.dtype() == torch::kFloat64));
+
+  TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
+
+  TORCH_CHECK(
+      input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 ==
+      padded_output_waveform.size(1));
+
+  AT_DISPATCH_FLOATING_TYPES(
+      input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
+        host_lfilter_core_loop<scalar_t>(
+            input_signal_windows, a_coeff_flipped, padded_output_waveform);
+      });
+}
+
+} // namespace
+
+// Note: We want to avoid using a "catch-all" kernel.
+// The following registration should be replaced with a CPU-specific registration.
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
+}
diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py
index 61b8fa06b5..3c27a0d617 100644
--- a/torchaudio/functional/filtering.py
+++ b/torchaudio/functional/filtering.py
@@ -808,6 +808,23 @@ def highpass_biquad(
     return biquad(waveform, b0, b1, b2, a0, a1, a2)
 
 
+def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor):
+    n_order = a_coeffs_flipped.size(0)
+    for i_sample, o0 in enumerate(input_signal_windows.t()):
+        windowed_output_signal = padded_output_waveform[
+            :, i_sample:i_sample + n_order
+        ]
+        o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
+        padded_output_waveform[:, i_sample + n_order - 1] = o0
+
+
+try:
+    _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
+except RuntimeError as err:
+    assert str(err) == 'No such operator torchaudio::_lfilter_core_loop'
+    _lfilter_core_cpu_loop = _lfilter_core_generic_loop
+
+
 def lfilter(
     waveform: Tensor,
     a_coeffs: Tensor,
@@ -877,12 +894,13 @@ def lfilter(
     input_signal_windows.div_(a_coeffs[0])
     a_coeffs_flipped.div_(a_coeffs[0])
-    for i_sample, o0 in enumerate(input_signal_windows.t()):
-        windowed_output_signal = padded_output_waveform[
-            :, i_sample:i_sample + n_order
-        ]
-        o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
-        padded_output_waveform[:, i_sample + n_order - 1] = o0
+
+    if input_signal_windows.device == torch.device('cpu') and\
+       a_coeffs_flipped.device == torch.device('cpu') and\
+       padded_output_waveform.device == torch.device('cpu'):
+        _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
+    else:
+        _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
 
     output = padded_output_waveform[:, n_order - 1:]
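
A quick way to exercise the new kernel from Python (a minimal sketch, not part of the patch; it assumes this branch is built and installed so that `torch.ops.torchaudio._lfilter_core_loop` is registered, and it reaches into the private helpers `_lfilter_core_cpu_loop` / `_lfilter_core_generic_loop` added above):

```python
import torch
from torchaudio.functional.filtering import (
    _lfilter_core_cpu_loop,
    _lfilter_core_generic_loop,
)

n_channel, n_sample, n_order = 2, 1000, 3
x = torch.rand(n_channel, n_sample, dtype=torch.float64)  # stand-in for the b-filtered windows
a = torch.tensor([0.7, 0.2, 0.6], dtype=torch.float64)    # a stable denominator (poles inside the unit circle)
a_flipped = (a / a[0]).flip(0)                            # normalized and flipped, as lfilter prepares it

# Output buffers carry (n_order - 1) leading zeros per channel, matching the
# shape check in cpu_lfilter_core_loop.
out_cpp = torch.zeros(n_channel, n_sample + n_order - 1, dtype=torch.float64)
out_py = torch.zeros_like(out_cpp)

_lfilter_core_cpu_loop(x, a_flipped, out_cpp)             # registered C++ kernel
# The generic loop updates its input in place via addmv_, so pass a clone.
_lfilter_core_generic_loop(x.clone(), a_flipped, out_py)

assert torch.allclose(out_cpp, out_py)
```

Note that in a build without the extension, the `try`/`except` fallback aliases `_lfilter_core_cpu_loop` to `_lfilter_core_generic_loop`, so the comparison above passes trivially rather than exercising the C++ path.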