Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def test_lfilter_all_inputs(self):
b = torch.tensor([0.4, 0.2, 0.9])
self.assert_grad(F.lfilter, (x, a, b))

def test_lfilter_filterbanks(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
a = torch.tensor([[0.7, 0.2, 0.6],
[0.8, 0.2, 0.9]])
b = torch.tensor([[0.4, 0.2, 0.9],
[0.7, 0.2, 0.6]])
self.assert_grad(F.lfilter, (x, a, b))

def test_biquad(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
Expand Down
22 changes: 13 additions & 9 deletions test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,22 @@ def test_lfilter_clamp(self):
assert output_signal.max() > 1

@parameterized.expand([
((44100,),),
((3, 44100),),
((2, 3, 44100),),
((1, 2, 3, 44100),)
((44100,), (4,), (44100,)),
((3, 44100), (4,), (3, 44100,)),
((2, 3, 44100), (4,), (2, 3, 44100,)),
((1, 2, 3, 44100), (4,), (1, 2, 3, 44100,)),
((44100,), (2, 4), (2, 44100)),
((3, 44100), (1, 4), (3, 1, 44100)),
((1, 2, 44100), (3, 4), (1, 2, 3, 44100))
])
def test_lfilter_shape(self, shape):
def test_lfilter_shape(self, input_shape, coeff_shape, target_shape):
torch.random.manual_seed(42)
waveform = torch.rand(*shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device)
waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert shape == waveform.size() == output_waveform.size()
assert input_shape == waveform.size()
assert target_shape == output_waveform.size()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow-up: We should add tests to verify that the resulting signals are same whether multiple filters are applied individually or as a bank.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, will add batch consistency tests and also benchmarks of multiple filters in later commits.

def test_lfilter_9th_order_filter_stability(self):
"""
Expand Down
154 changes: 92 additions & 62 deletions torchaudio/csrc/lfilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,28 @@ void host_lfilter_core_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
int64_t n_channel = input_signal_windows.size(0);
int64_t n_samples_input = input_signal_windows.size(1);
int64_t n_samples_output = padded_output_waveform.size(1);
int64_t n_order = a_coeff_flipped.size(0);
int64_t n_batch = input_signal_windows.size(0);
int64_t n_channel = input_signal_windows.size(1);
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_samples_output = padded_output_waveform.size(2);
int64_t n_order = a_coeff_flipped.size(1);
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
int64_t offset_input = i_channel * n_samples_input;
int64_t offset_output = i_channel * n_samples_output;
scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] *
a_coeff_flipped_data[i_coeff];
for (int64_t i_batch = 0; i_batch < n_batch; i_batch++) {
for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
int64_t offset_input =
((i_batch * n_channel) + i_channel) * n_samples_input;
int64_t offset_output =
((i_batch * n_channel) + i_channel) * n_samples_output;
scalar_t a0 = input_data[offset_input + i_sample];
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
a0 -= output_data[offset_output + i_sample + i_coeff] *
a_coeff_flipped_data[i_coeff + i_channel * n_order];
}
output_data[offset_output + i_sample + n_order - 1] = a0;
}
output_data[offset_output + i_sample + n_order - 1] = a0;
}
}
}
Expand All @@ -51,10 +56,11 @@ void cpu_lfilter_core_loop(
padded_output_waveform.dtype() == torch::kFloat64));

TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));

TORCH_CHECK(
input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 ==
padded_output_waveform.size(1));
input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
padded_output_waveform.size(2));

AT_DISPATCH_FLOATING_TYPES(
input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
Expand All @@ -67,16 +73,26 @@ void lfilter_core_generic_loop(
const torch::Tensor& input_signal_windows,
const torch::Tensor& a_coeff_flipped,
torch::Tensor& padded_output_waveform) {
int64_t n_samples_input = input_signal_windows.size(1);
int64_t n_order = a_coeff_flipped.size(0);
int64_t n_samples_input = input_signal_windows.size(2);
int64_t n_order = a_coeff_flipped.size(1);
auto coeff = a_coeff_flipped.unsqueeze(2);
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
auto windowed_output_signal = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(i_sample, i_sample + n_order)});
auto o0 = input_signal_windows.index({torch::indexing::Slice(), i_sample})
.addmv(windowed_output_signal, a_coeff_flipped, 1, -1);
auto windowed_output_signal =
padded_output_waveform
.index(
{torch::indexing::Slice(),
torch::indexing::Slice(),
torch::indexing::Slice(i_sample, i_sample + n_order)})
.transpose(0, 1);
auto o0 =
input_signal_windows.index(
{torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) -
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
padded_output_waveform.index_put_(
{torch::indexing::Slice(), i_sample + n_order - 1}, o0);
{torch::indexing::Slice(),
torch::indexing::Slice(),
i_sample + n_order - 1},
o0);
}
}

Expand All @@ -88,16 +104,17 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
const torch::Tensor& a_coeffs_normalized) {
auto device = waveform.device();
auto dtype = waveform.dtype();
int64_t n_channel = waveform.size(0);
int64_t n_sample = waveform.size(1);
int64_t n_order = a_coeffs_normalized.size(0);
int64_t n_batch = waveform.size(0);
int64_t n_channel = waveform.size(1);
int64_t n_sample = waveform.size(2);
int64_t n_order = a_coeffs_normalized.size(1);
int64_t n_sample_padded = n_sample + n_order - 1;

auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous();
auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();

auto options = torch::TensorOptions().dtype(dtype).device(device);
auto padded_output_waveform =
torch::zeros({n_channel, n_sample_padded}, options);
torch::zeros({n_batch, n_channel, n_sample_padded}, options);

if (device.is_cpu()) {
cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
Expand All @@ -108,6 +125,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {

auto output = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)});

ctx->save_for_backward({waveform, a_coeffs_normalized, output});
Expand All @@ -122,8 +140,9 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
auto a_coeffs_normalized = saved[1];
auto y = saved[2];

int64_t n_channel = x.size(0);
int64_t n_order = a_coeffs_normalized.size(0);
int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = a_coeffs_normalized.size(1);

auto dx = torch::Tensor();
auto da = torch::Tensor();
Expand All @@ -137,16 +156,16 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
F::PadFuncOptions({n_order - 1, 0}));

da = F::conv1d(
dyda.unsqueeze(0),
dy.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
dyda.view({1, n_batch * n_channel, -1}),
dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}

if (x.requires_grad()) {
dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1);
dx = DifferentiableIIR::apply(dy.flip(2), a_coeffs_normalized).flip(2);
}

return {dx, da};
Expand All @@ -159,17 +178,18 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& b_coeffs) {
int64_t n_order = b_coeffs.size(0);
int64_t n_order = b_coeffs.size(1);
int64_t n_channel = b_coeffs.size(0);

namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));

auto output =
F::conv1d(
padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
.squeeze(1);
auto output = F::conv1d(
padded_waveform,
b_coeff_flipped.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));

ctx->save_for_backward({waveform, b_coeffs, output});
return output;
Expand All @@ -183,8 +203,9 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
auto b_coeffs = saved[1];
auto y = saved[2];

int64_t n_channel = x.size(0);
int64_t n_order = b_coeffs.size(0);
int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = b_coeffs.size(1);

auto dx = torch::Tensor();
auto db = torch::Tensor();
Expand All @@ -194,19 +215,20 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {

if (b_coeffs.requires_grad()) {
db = F::conv1d(
F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
dy.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel))
.sum(1)
.squeeze(0)
.flip(0);
F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
.view({1, n_batch * n_channel, -1}),
dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}

if (x.requires_grad()) {
dx = F::conv1d(
F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
b_coeffs.view({1, 1, n_order}))
.squeeze(1);
F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
b_coeffs.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));
}

return {dx, db};
Expand All @@ -219,19 +241,27 @@ torch::Tensor lfilter_core(
const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());

TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));

int64_t n_order = b_coeffs.size(0);
int64_t n_order = b_coeffs.size(1);

TORCH_INTERNAL_ASSERT(n_order > 0);

auto filtered_waveform =
DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]);

auto output =
DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
auto filtered_waveform = DifferentiableFIR::apply(
waveform,
b_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));

auto output = DifferentiableIIR::apply(
filtered_waveform,
a_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
return output;
}

Expand Down
Loading