Closed
18 changes: 18 additions & 0 deletions test/torchaudio_unittest/functional/autograd_impl.py
@@ -59,6 +59,24 @@ def test_lfilter_all_inputs(self):
         b = torch.tensor([0.4, 0.2, 0.9])
         self.assert_grad(F.lfilter, (x, a, b))

+    def test_batch_lfilter(self):
+        torch.random.manual_seed(2434)
+        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
+        a = torch.tensor([[0.7, 0.2, 0.6],
+                          [0.8, 0.2, 0.9]])
+        b = torch.tensor([[0.4, 0.2, 0.9],
+                          [0.7, 0.2, 0.6]])
+        self.assert_grad(F.lfilter, (x, a, b))
+
+    def test_filter_banks_lfilter(self):
+        torch.random.manual_seed(2434)
+        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2).unsqueeze(1)
+        a = torch.tensor([[0.7, 0.2, 0.6],
+                          [0.8, 0.2, 0.9]])
+        b = torch.tensor([[0.4, 0.2, 0.9],
+                          [0.7, 0.2, 0.6]])
+        self.assert_grad(F.lfilter, (x, a, b))
+
     def test_biquad(self):
         torch.random.manual_seed(2434)
         x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
28 changes: 28 additions & 0 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
@@ -217,3 +217,31 @@ def test_compute_kaldi_pitch(self):
         batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
         self.assert_batch_consistency(
             F.compute_kaldi_pitch, batch, sample_rate=sample_rate)
+
+    def test_lfilter_separated_filters(self):
+        signal_length = 2048
+        torch.manual_seed(2434)
+        x = torch.randn(self.batch_size, signal_length)
+        a = torch.rand(self.batch_size, 3)
+        b = torch.rand(self.batch_size, 3)
+
+        batchwise_output = F.lfilter(x, a, b)
+        itemwise_output = torch.stack([
+            F.lfilter(x[i], a[i], b[i])
+            for i in range(self.batch_size)
+        ])
+
+        self.assertEqual(batchwise_output, itemwise_output)

Contributor

Can we use the self.assert_batch_consistency helper method here? It handles dtype/device as well.

Contributor Author @yoyolicoris, Jun 10, 2021

self.assert_batch_consistency seems to assume that only the first input is batched, but in our case a_coeffs and b_coeffs should be batched as well.

Contributor

Does that mean it applies different filters to the samples in a batch? I thought the same set of filters is applied to each sample in the batch, so one can change the batch size without changing a_coeffs and b_coeffs.

Contributor Author @yoyolicoris, Jun 10, 2021

Sounds like we need two types of tests. I will add another one that uses self.assert_batch_consistency.

Contributor

And can you clarify that in the above case, where you said "a_coeffs and b_coeffs should also be in batch as well", the number of filters in the bank happens to be the same as the batch size, but that this is not a requirement?

So my understanding/expectation is that when the input batch has shape [batch_size, sequence_length], a_coeffs and b_coeffs can take any shape [filter_dim, number_of_filters], without being constrained by the input shape.

And if I understand correctly, your test here checks that filter banks produce the same result regardless of whether they are applied separately or together. Is that correct?

Contributor Author @yoyolicoris, Jun 10, 2021

> So my understanding/expectation is that when the input batch has shape [batch_size, sequence_length], a_coeffs and b_coeffs can take any shape [filter_dim, number_of_filters], without being constrained by the input shape.

In this case, when the input is a 2D batch of signals, a_coeffs and b_coeffs should have shape [batch_size, filter_order + 1] or just [filter_order + 1]. The first means that the number of filters equals batch_size and each signal is filtered with a different filter; the second means that a single filter is applied to all signals.

The filter shape is unconstrained only when the input has shape [..., 1, sequence_length]. Then a_coeffs and b_coeffs can be any 2D matrix of shape [number_of_filters, filter_order + 1], and the output shape will be [..., number_of_filters, sequence_length]. This means each signal is filtered by a shared set of filters.

> And if I understand correctly, your test here checks that filter banks produce the same result regardless of whether they are applied separately or together. Is that correct?

Yes, that's correct.
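
The shape rules described in this comment can be sketched as a small pure-Python helper. This is only one reading of the comment and of the cases in test_lfilter_broadcast_shape; `lfilter_output_shape` is a hypothetical name for illustration and is not part of torchaudio.

```python
def lfilter_output_shape(input_shape, coeff_shape):
    """Hypothetical helper: predict the output shape of a batched lfilter call.

    input_shape: (..., channels, time) or (time,)
    coeff_shape: (n_filters, filter_order + 1) or (filter_order + 1,)
    """
    input_shape = tuple(input_shape)
    if len(coeff_shape) == 1:
        # A single filter is applied to every signal; the shape is unchanged.
        return input_shape

    n_filters = coeff_shape[0]
    *leading, time = input_shape
    if leading:
        *batch, channels = leading
    else:
        # Assumption: a 1D waveform is treated as a single channel.
        batch, channels = [], 1

    # The channel dimension and the filter dimension must match or be 1.
    if channels != n_filters and 1 not in (channels, n_filters):
        raise ValueError(
            f"cannot broadcast {channels} channels against {n_filters} filters"
        )
    return (*batch, max(channels, n_filters), time)


# Shapes taken from test_lfilter_broadcast_shape in this PR.
print(lfilter_output_shape((44100,), (2, 3)))           # (2, 44100)
print(lfilter_output_shape((3, 44100), (3, 3)))         # (3, 44100)
print(lfilter_output_shape((1, 2, 1, 44100), (3, 3)))   # (1, 2, 3, 44100)
```

In this reading, a mismatch such as 3 channels against 2 filters is rejected, which corresponds to the constraint that the number of filters has to match the number of channels unless one of them is 1.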

Contributor Author @yoyolicoris, Jun 10, 2021

Wait, I think the batch behavior we want to test is actually that of the coefficients, not the input. 😆
So we might need to change the test to pass a_coeffs and b_coeffs as the input batch and the waveform as the parameter.

Contributor

> > So my understanding/expectation is that when the input batch has shape [batch_size, sequence_length], a_coeffs and b_coeffs can take any shape [filter_dim, number_of_filters], without being constrained by the input shape.

> In this case, when the input is a 2D batch of signals, a_coeffs and b_coeffs should have shape [batch_size, filter_order + 1] or just [filter_order + 1]. The first means that the number of filters equals batch_size and each signal is filtered with a different filter; the second means that a single filter is applied to all signals.

> The filter shape is unconstrained only when the input has shape [..., 1, sequence_length]. Then a_coeffs and b_coeffs can be any 2D matrix of shape [number_of_filters, filter_order + 1], and the output shape will be [..., number_of_filters, sequence_length]. This means each signal is filtered by a shared set of filters.

> > And if I understand correctly, your test here checks that filter banks produce the same result regardless of whether they are applied separately or together. Is that correct?

> Yes, that's correct.

@yoyololicon

Can you help me clarify my understanding of the shape semantics here?

  • When the input batch is 2D, the first dimension is interpreted as the channel dimension, so the number of filters has to match the number of channels.
  • When an input signal has multiple channels, the filters have to have multiple channels as well.

Are these correct?

Contributor Author

@mthrok

> Can you help me clarify my understanding of the shape semantics here?
>
>   • When the input batch is 2D, the first dimension is interpreted as the channel dimension, so the number of filters has to match the number of channels.
>   • When an input signal has multiple channels, the filters have to have multiple channels as well.
>
> Are these correct?

If you want to apply multiple filters at once, these are correct; if there is only one filter, it falls back to the original behavior.
The shape semantics I proposed actually follow PyTorch conventions except for the last dimension, which is time or filter order.


+    def test_lfilter(self):
+        signal_length = 2048
+        torch.manual_seed(2434)
+        x = torch.randn(signal_length)
+        a = torch.rand(4, 3)
+        b = torch.rand(4, 3)
+
+        def filter_wrapper(ab_coeffs, waveform):
+            a, b = ab_coeffs[..., 0, :], ab_coeffs[..., 1, :]
+            return F.lfilter(waveform, a, b)
+
+        self.assert_batch_consistency(filter_wrapper, torch.stack([a, b], 1), x)
14 changes: 14 additions & 0 deletions test/torchaudio_unittest/functional/functional_impl.py
@@ -80,6 +80,20 @@ def test_lfilter_shape(self, shape):
         output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
         assert shape == waveform.size() == output_waveform.size()

+    @parameterized.expand([
+        ((44100,), (2, 3), (2, 44100)),
+        ((3, 44100), (1, 3), (3, 44100)),
+        ((3, 44100), (3, 3), (3, 44100)),
+        ((1, 2, 1, 44100), (3, 3), (1, 2, 3, 44100))
+    ])
+    def test_lfilter_broadcast_shape(self, input_shape, coeff_shape, target_shape):
Comment on lines +83 to +89
Contributor Author

@mthrok
This is the part that tests the broadcasting behavior.

+        torch.random.manual_seed(42)
+        waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
+        b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
+        a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
+        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
+        assert target_shape == output_waveform.size()
+
     def test_lfilter_9th_order_filter_stability(self):
         """
         Validate the precision of lfilter against reference scipy implementation when using high order filter.
154 changes: 92 additions & 62 deletions torchaudio/csrc/lfilter.cpp
@@ -8,23 +8,28 @@ void host_lfilter_core_loop(
     const torch::Tensor& input_signal_windows,
     const torch::Tensor& a_coeff_flipped,
     torch::Tensor& padded_output_waveform) {
-  int64_t n_channel = input_signal_windows.size(0);
-  int64_t n_samples_input = input_signal_windows.size(1);
-  int64_t n_samples_output = padded_output_waveform.size(1);
-  int64_t n_order = a_coeff_flipped.size(0);
+  int64_t n_batch = input_signal_windows.size(0);
+  int64_t n_channel = input_signal_windows.size(1);
+  int64_t n_samples_input = input_signal_windows.size(2);
+  int64_t n_samples_output = padded_output_waveform.size(2);
+  int64_t n_order = a_coeff_flipped.size(1);
   scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
   const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
   const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
-  for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
-    for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
-      int64_t offset_input = i_channel * n_samples_input;
-      int64_t offset_output = i_channel * n_samples_output;
-      scalar_t a0 = input_data[offset_input + i_sample];
-      for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
-        a0 -= output_data[offset_output + i_sample + i_coeff] *
-            a_coeff_flipped_data[i_coeff];
+  for (int64_t i_batch = 0; i_batch < n_batch; i_batch++) {
+    for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
+      for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
+        int64_t offset_input =
+            ((i_batch * n_channel) + i_channel) * n_samples_input;
+        int64_t offset_output =
+            ((i_batch * n_channel) + i_channel) * n_samples_output;
+        scalar_t a0 = input_data[offset_input + i_sample];
+        for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
+          a0 -= output_data[offset_output + i_sample + i_coeff] *
+              a_coeff_flipped_data[i_coeff + i_channel * n_order];
+        }
+        output_data[offset_output + i_sample + n_order - 1] = a0;
       }
-      output_data[offset_output + i_sample + n_order - 1] = a0;
     }
   }
 }
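
As an illustration of the new three-level loop, the following pure-Python sketch mirrors the offset arithmetic of host_lfilter_core_loop on flat, row-major lists. It is a reading of the hunk above rather than the library's code; the function name `iir_core_loop` and its explicit size arguments are assumptions made for this example.

```python
def iir_core_loop(x, a_flipped, n_batch, n_channel, n_samples):
    """Run the IIR recursion over flat, row-major buffers.

    x:         flat list, logical shape (n_batch, n_channel, n_samples)
    a_flipped: flat list, logical shape (n_channel, n_order); each row holds
               normalized feedback coefficients in reversed order, so the
               last entry of a row is a0 == 1
    Returns the flat padded output, logical shape
    (n_batch, n_channel, n_samples + n_order - 1).
    """
    n_order = len(a_flipped) // n_channel
    n_padded = n_samples + n_order - 1
    out = [0.0] * (n_batch * n_channel * n_padded)
    for i_batch in range(n_batch):
        for i_channel in range(n_channel):
            row = i_batch * n_channel + i_channel
            offset_in = row * n_samples
            offset_out = row * n_padded
            for i_sample in range(n_samples):
                acc = x[offset_in + i_sample]
                for i_coeff in range(n_order):
                    # Each channel reads its own row of coefficients.
                    acc -= (out[offset_out + i_sample + i_coeff]
                            * a_flipped[i_channel * n_order + i_coeff])
                out[offset_out + i_sample + n_order - 1] = acc
    return out


# One batch item, two channels fed the same impulse.
# Channel 0 uses a = [1, -0.5]; channel 1 uses a = [1, 0.5].
impulse = [1.0, 0.0, 0.0, 0.0]
out = iir_core_loop(impulse * 2, [-0.5, 1.0, 0.5, 1.0], 1, 2, 4)
print(out[1:5])   # channel 0 without the leading pad: [1.0, 0.5, 0.25, 0.125]
print(out[6:10])  # channel 1 without the leading pad: [1.0, -0.5, 0.25, -0.125]
```

The coefficient index depends only on the channel, not on the batch item, so every batch item shares the same bank of per-channel filters.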
@@ -51,10 +56,11 @@ void cpu_lfilter_core_loop(
        padded_output_waveform.dtype() == torch::kFloat64));

   TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
+  TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));

   TORCH_CHECK(
-      input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 ==
-      padded_output_waveform.size(1));
+      input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
+      padded_output_waveform.size(2));

   AT_DISPATCH_FLOATING_TYPES(
       input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
@@ -67,16 +73,26 @@ void lfilter_core_generic_loop(
     const torch::Tensor& input_signal_windows,
     const torch::Tensor& a_coeff_flipped,
     torch::Tensor& padded_output_waveform) {
-  int64_t n_samples_input = input_signal_windows.size(1);
-  int64_t n_order = a_coeff_flipped.size(0);
+  int64_t n_samples_input = input_signal_windows.size(2);
+  int64_t n_order = a_coeff_flipped.size(1);
+  auto coeff = a_coeff_flipped.unsqueeze(2);
   for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
-    auto windowed_output_signal = padded_output_waveform.index(
-        {torch::indexing::Slice(),
-         torch::indexing::Slice(i_sample, i_sample + n_order)});
-    auto o0 = input_signal_windows.index({torch::indexing::Slice(), i_sample})
-                  .addmv(windowed_output_signal, a_coeff_flipped, 1, -1);
+    auto windowed_output_signal =
+        padded_output_waveform
+            .index(
+                {torch::indexing::Slice(),
+                 torch::indexing::Slice(),
+                 torch::indexing::Slice(i_sample, i_sample + n_order)})
+            .transpose(0, 1);
+    auto o0 =
+        input_signal_windows.index(
+            {torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) -
+        at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
     padded_output_waveform.index_put_(
-        {torch::indexing::Slice(), i_sample + n_order - 1}, o0);
+        {torch::indexing::Slice(),
+         torch::indexing::Slice(),
+         i_sample + n_order - 1},
+        o0);
   }
 }

@@ -88,16 +104,17 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
       const torch::Tensor& a_coeffs_normalized) {
     auto device = waveform.device();
     auto dtype = waveform.dtype();
-    int64_t n_channel = waveform.size(0);
-    int64_t n_sample = waveform.size(1);
-    int64_t n_order = a_coeffs_normalized.size(0);
+    int64_t n_batch = waveform.size(0);
+    int64_t n_channel = waveform.size(1);
+    int64_t n_sample = waveform.size(2);
+    int64_t n_order = a_coeffs_normalized.size(1);
     int64_t n_sample_padded = n_sample + n_order - 1;

-    auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous();
+    auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();

     auto options = torch::TensorOptions().dtype(dtype).device(device);
     auto padded_output_waveform =
-        torch::zeros({n_channel, n_sample_padded}, options);
+        torch::zeros({n_batch, n_channel, n_sample_padded}, options);

     if (device.is_cpu()) {
       cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
@@ -108,6 +125,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {

     auto output = padded_output_waveform.index(
         {torch::indexing::Slice(),
+         torch::indexing::Slice(),
          torch::indexing::Slice(n_order - 1, torch::indexing::None)});

     ctx->save_for_backward({waveform, a_coeffs_normalized, output});
@@ -122,8 +140,9 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
     auto a_coeffs_normalized = saved[1];
     auto y = saved[2];

-    int64_t n_channel = x.size(0);
-    int64_t n_order = a_coeffs_normalized.size(0);
+    int64_t n_batch = x.size(0);
+    int64_t n_channel = x.size(1);
+    int64_t n_order = a_coeffs_normalized.size(1);

     auto dx = torch::Tensor();
     auto da = torch::Tensor();
@@ -137,16 +156,16 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
           F::PadFuncOptions({n_order - 1, 0}));

       da = F::conv1d(
-               dyda.unsqueeze(0),
-               dy.unsqueeze(1),
-               F::Conv1dFuncOptions().groups(n_channel))
-               .sum(1)
-               .squeeze(0)
-               .flip(0);
+               dyda.view({1, n_batch * n_channel, -1}),
+               dy.view({n_batch * n_channel, 1, -1}),
+               F::Conv1dFuncOptions().groups(n_batch * n_channel))
+               .view({n_batch, n_channel, -1})
+               .sum(0)
+               .flip(1);
     }

     if (x.requires_grad()) {
-      dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1);
+      dx = DifferentiableIIR::apply(dy.flip(2), a_coeffs_normalized).flip(2);
     }

     return {dx, da};
@@ -159,17 +178,18 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
       torch::autograd::AutogradContext* ctx,
       const torch::Tensor& waveform,
       const torch::Tensor& b_coeffs) {
-    int64_t n_order = b_coeffs.size(0);
+    int64_t n_order = b_coeffs.size(1);
+    int64_t n_channel = b_coeffs.size(0);

     namespace F = torch::nn::functional;
-    auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
+    auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
     auto padded_waveform =
         F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));

-    auto output =
-        F::conv1d(
-            padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
-            .squeeze(1);
+    auto output = F::conv1d(
+        padded_waveform,
+        b_coeff_flipped.unsqueeze(1),
+        F::Conv1dFuncOptions().groups(n_channel));

     ctx->save_for_backward({waveform, b_coeffs, output});
     return output;
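
The forward pass above relies on a grouped conv1d with groups equal to the number of channels, so each channel is filtered with its own row of b coefficients. Because conv1d computes a cross-correlation, the taps are flipped first and the signal is left-padded to keep the filter causal. A minimal pure-Python sketch of that idea follows, assuming nested lists in place of tensors; `fir_per_channel` is a hypothetical name used only here.

```python
def fir_per_channel(x, b):
    """Causal FIR filtering where channel c uses its own taps b[c].

    x: list of channels, each a list of samples
    b: list of tap rows, one per channel: [b0, b1, ...]
    """
    out = []
    for channel, taps in zip(x, b):
        n_order = len(taps)
        flipped = taps[::-1]
        # Left-pad so that the output has the same length as the input.
        padded = [0.0] * (n_order - 1) + list(channel)
        # Sliding dot product with the flipped taps (cross-correlation).
        out.append([
            sum(padded[n + k] * flipped[k] for k in range(n_order))
            for n in range(len(channel))
        ])
    return out


y = fir_per_channel([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]],
                    [[1.0, 2.0], [3.0, 4.0]])
print(y)  # [[1.0, 3.0, 3.0], [3.0, 4.0, 0.0]]
```

The second channel is fed an impulse, so its output reproduces its own taps, which shows that the channels do not share coefficients.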
@@ -183,8 +203,9 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
     auto b_coeffs = saved[1];
     auto y = saved[2];

-    int64_t n_channel = x.size(0);
-    int64_t n_order = b_coeffs.size(0);
+    int64_t n_batch = x.size(0);
+    int64_t n_channel = x.size(1);
+    int64_t n_order = b_coeffs.size(1);

     auto dx = torch::Tensor();
     auto db = torch::Tensor();
@@ -194,19 +215,20 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {

     if (b_coeffs.requires_grad()) {
       db = F::conv1d(
-               F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
-               dy.unsqueeze(1),
-               F::Conv1dFuncOptions().groups(n_channel))
-               .sum(1)
-               .squeeze(0)
-               .flip(0);
+               F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
+                   .view({1, n_batch * n_channel, -1}),
+               dy.view({n_batch * n_channel, 1, -1}),
+               F::Conv1dFuncOptions().groups(n_batch * n_channel))
+               .view({n_batch, n_channel, -1})
+               .sum(0)
+               .flip(1);
     }

     if (x.requires_grad()) {
       dx = F::conv1d(
-          F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
-          b_coeffs.view({1, 1, n_order}))
-               .squeeze(1);
+          F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
+          b_coeffs.unsqueeze(1),
+          F::Conv1dFuncOptions().groups(n_channel));
     }

     return {dx, db};
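
For a single signal, the db expression above (pad x on the left, correlate with dy, then flip) appears to reduce to db[k] = sum over n of dy[n] * x[n - k]; the batched code additionally sums this over the batch dimension. The pure-Python sketch below checks that formula against a unit-step difference, which is exact because the output is linear in b. The helper names `fir_forward`, `fir_grad_b`, and `loss` are assumptions made for this example.

```python
def fir_forward(x, b):
    """Causal FIR filter: y[n] = sum_k b[k] * x[n - k]."""
    return [
        sum(b[k] * x[n - k] for k in range(len(b)) if n - k >= 0)
        for n in range(len(x))
    ]


def fir_grad_b(x, dy, n_taps):
    """Gradient of sum_n dy[n] * y[n] with respect to each tap b[k]."""
    return [
        sum(dy[n] * x[n - k] for n in range(k, len(x)))
        for k in range(n_taps)
    ]


def loss(x, b, dy):
    """Scalar whose gradient with respect to y is dy."""
    return sum(g * y for g, y in zip(dy, fir_forward(x, b)))


x, dy, b = [1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [0.5, 0.25]
grad = fir_grad_b(x, dy, len(b))

# The loss is linear in b, so a unit step in one tap changes it by exactly
# the corresponding gradient entry.
for k in range(len(b)):
    bumped = list(b)
    bumped[k] += 1.0
    assert loss(x, bumped, dy) - loss(x, b, dy) == grad[k]

print(grad)  # [6.0, 3.0]
```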
@@ -219,19 +241,27 @@ torch::Tensor lfilter_core(
     const torch::Tensor& b_coeffs) {
   TORCH_CHECK(waveform.device() == a_coeffs.device());
   TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
-  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
+  TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());

-  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
+  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
+  TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
+  TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));

-  int64_t n_order = b_coeffs.size(0);
+  int64_t n_order = b_coeffs.size(1);

   TORCH_INTERNAL_ASSERT(n_order > 0);

-  auto filtered_waveform =
-      DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]);
-
-  auto output =
-      DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
+  auto filtered_waveform = DifferentiableFIR::apply(
+      waveform,
+      b_coeffs /
+          a_coeffs.index(
+              {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
+
+  auto output = DifferentiableIIR::apply(
+      filtered_waveform,
+      a_coeffs /
+          a_coeffs.index(
+              {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
   return output;
 }
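
With 2D coefficients, `a_coeffs[0]` would select the whole first filter rather than each filter's leading coefficient, which is presumably why the normalization now slices out a column of shape (n_filters, 1) and lets broadcasting divide every row by its own a0. A pure-Python sketch of that per-row normalization, using nested lists and a hypothetical `normalize_by_a0` helper:

```python
def normalize_by_a0(a_coeffs, b_coeffs):
    """Divide every filter's a and b coefficients by that filter's own a0.

    a_coeffs, b_coeffs: lists of rows, one row per filter.
    """
    a_norm = [[c / a_row[0] for c in a_row] for a_row in a_coeffs]
    b_norm = [[c / a_row[0] for c in b_row]
              for a_row, b_row in zip(a_coeffs, b_coeffs)]
    return a_norm, b_norm


a = [[2.0, 1.0], [4.0, 2.0]]
b = [[1.0, 0.0], [2.0, 4.0]]
a_norm, b_norm = normalize_by_a0(a, b)
print(a_norm)  # [[1.0, 0.5], [1.0, 0.5]]
print(b_norm)  # [[0.5, 0.0], [0.5, 1.0]]
```

After normalization every row of a starts with 1, which is what the IIR loop assumes when it treats the last flipped coefficient as a0.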
