Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions torchaudio/csrc/lfilter.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <torch/script.h>
#include <torch/torch.h>

namespace {

Expand Down Expand Up @@ -62,10 +63,82 @@ void cpu_lfilter_core_loop(
});
}

// Device-generic fallback for the recursive (IIR) half of lfilter.
//
// For every time step t, the window of the previous `n_order` output samples
// is combined with the flipped denominator coefficients and subtracted from
// the corresponding FIR-filtered input sample, and the result is written back
// as the newest output sample. Runs entirely through tensor ops so it works
// on any device (CUDA included), at the cost of one kernel launch per sample.
void lfilter_core_generic_loop(
    const torch::Tensor& input_signal_windows,
    const torch::Tensor& a_coeff_flipped,
    torch::Tensor& padded_output_waveform) {
  using torch::indexing::Slice;
  const int64_t num_samples = input_signal_windows.size(1);
  const int64_t order = a_coeff_flipped.size(0);
  for (int64_t t = 0; t < num_samples; ++t) {
    // (batch, order) view of the most recent output history for step t.
    auto history = padded_output_waveform.index({Slice(), Slice(t, t + order)});
    // o[t] = x[t] - history @ a_flipped   (addmv with alpha=-1, beta=1).
    auto next_sample = input_signal_windows.index({Slice(), t})
                           .addmv(history, a_coeff_flipped, 1, -1);
    padded_output_waveform.index_put_({Slice(), t + order - 1}, next_sample);
  }
}

// Full lfilter operator: applies the FIR (numerator, b) stage with a single
// conv1d, then the recursive (denominator, a) stage with a backend-specific
// sample-by-sample loop.
//
// waveform : (batch, time) — already flattened by the Python wrapper.
// a_coeffs : (n_order,) denominator coefficients, lowest delay first.
// b_coeffs : (n_order,) numerator coefficients, lowest delay first.
// Returns a (batch, time) tensor; no clamping is done here.
torch::Tensor lfilter_core(
    const torch::Tensor& waveform,
    const torch::Tensor& a_coeffs,
    const torch::Tensor& b_coeffs) {
  // Everything must live on one device, and the two coefficient vectors must
  // be the same length (callers zero-pad the shorter one).
  TORCH_CHECK(waveform.device() == a_coeffs.device());
  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));

  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);

  const auto device = waveform.device();
  const int64_t n_order = a_coeffs.size(0);

  TORCH_INTERNAL_ASSERT(n_order > 0);

  namespace F = torch::nn::functional;

  // Left-pad with (n_order - 1) zeros so each output sample sees a full
  // window of history.
  auto padded_input = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
  auto padded_output = torch::zeros_like(padded_input);

  // Flip the coefficients once up front so the inner loops reduce to plain
  // dot products over contiguous memory.
  auto a_flipped = a_coeffs.flip(0).contiguous();
  auto b_flipped = b_coeffs.flip(0).contiguous();

  // FIR stage: cross-correlate the padded input with the b coefficients.
  auto input_windows =
      F::conv1d(padded_input.unsqueeze(1), b_flipped.view({1, 1, n_order}))
          .squeeze(1);

  // Normalize by a0 in place so the recursion can assume a0 == 1.
  input_windows.div_(a_coeffs[0]);
  a_flipped.div_(a_coeffs[0]);

  // Recursive stage: use the optimized CPU loop when possible, otherwise the
  // device-generic tensor-op loop.
  if (device.is_cpu()) {
    cpu_lfilter_core_loop(input_windows, a_flipped, padded_output);
  } else {
    lfilter_core_generic_loop(input_windows, a_flipped, padded_output);
  }

  // Trim the left padding off before returning.
  return padded_output.slice(/*dim=*/1, /*start=*/n_order - 1);
}

} // namespace

// Note: We want to avoid using "catch-all" kernel.
// The following registration should be replaced with CPU specific registration.
// Exposes the raw CPU recursion loop (defined earlier in this file) so the
// Python side can call it as torch.ops.torchaudio._lfilter_core_loop.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
}

// Declares only the schema of the full operator; implementations are bound
// separately via TORCH_LIBRARY_IMPL below.
TORCH_LIBRARY(torchaudio, m) {
  m.def(
      "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
}

// Registers lfilter_core under the Math key: a composite implementation made
// of other torch ops, usable as a fallback for any backend.
TORCH_LIBRARY_IMPL(torchaudio, Math, m) {
  m.impl("torchaudio::_lfilter", lfilter_core);
}
56 changes: 36 additions & 20 deletions torchaudio/functional/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,30 +825,11 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
_lfilter_core_cpu_loop = _lfilter_core_generic_loop


def lfilter(
def _lfilter_core(
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor,
clamp: bool = True,
) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.

Args:
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)

Returns:
Tensor: Waveform with dimension of ``(..., time)``.
"""
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])

assert a_coeffs.size(0) == b_coeffs.size(0)
assert len(waveform.size()) == 2
Expand Down Expand Up @@ -886,6 +867,41 @@ def lfilter(
_lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)

output = padded_output_waveform[:, n_order - 1:]
return output

# Prefer the compiled C++ operator when the torchaudio extension registered
# it; otherwise fall back to the pure-Python implementation above.
try:
    _lfilter = torch.ops.torchaudio._lfilter
except RuntimeError as err:
    # Only the "operator was never registered" failure is expected here; any
    # other RuntimeError is a genuine problem and should surface loudly.
    assert str(err) == 'No such operator torchaudio::_lfilter'
    _lfilter = _lfilter_core


def lfilter(
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor,
clamp: bool = True,
) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.

Args:
waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``.
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)

Returns:
Tensor: Waveform with dimension of ``(..., time)``.
"""
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])

output = _lfilter(waveform, a_coeffs, b_coeffs)

if clamp:
output = torch.clamp(output, min=-1.0, max=1.0)
Expand Down