Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/torchaudio_unittest/functional/autograd_impl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Tuple
from functools import partial
import torch
from parameterized import parameterized
from torch import Tensor
Expand Down Expand Up @@ -62,6 +63,15 @@ def test_lfilter_all_inputs(self):
def test_lfilter_filterbanks(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
a = torch.tensor([[0.7, 0.2, 0.6],
[0.8, 0.2, 0.9]])
b = torch.tensor([[0.4, 0.2, 0.9],
[0.7, 0.2, 0.6]])
self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))

def test_lfilter_batching(self):
torch.random.manual_seed(2434)
x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
a = torch.tensor([[0.7, 0.2, 0.6],
[0.8, 0.2, 0.9]])
b = torch.tensor([[0.4, 0.2, 0.9],
Expand Down
15 changes: 15 additions & 0 deletions test/torchaudio_unittest/functional/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,18 @@ def test_compute_kaldi_pitch(self):
batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
self.assert_batch_consistency(
F.compute_kaldi_pitch, batch, sample_rate=sample_rate)

def test_lfilter(self):
signal_length = 2048
torch.manual_seed(2434)
x = torch.randn(self.batch_size, signal_length)
a = torch.rand(self.batch_size, 3)
b = torch.rand(self.batch_size, 3)

batchwise_output = F.lfilter(x, a, b, batching=True)
itemwise_output = torch.stack([
F.lfilter(x[i], a[i], b[i])
for i in range(self.batch_size)
])

self.assertEqual(batchwise_output, itemwise_output)
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/functional/functional_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_lfilter_shape(self, input_shape, coeff_shape, target_shape):
waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs, batching=False)
assert input_shape == waveform.size()
assert target_shape == output_waveform.size()

Expand Down
11 changes: 10 additions & 1 deletion torchaudio/functional/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ def lfilter(
a_coeffs: Tensor,
b_coeffs: Tensor,
clamp: bool = True,
batching: bool = True
) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.

Expand All @@ -948,6 +949,10 @@ def lfilter(
Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
batching (bool, optional): Activate when coefficients are in 2D. If ``True``, then waveform should be at least
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
batching (bool, optional): Activate when coefficients are in 2D. If ``True``, then waveform should be at least
batching (bool, optional): Effective only when coefficients are 2D. If ``True``, then waveform should be at least

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks,. Should I make another MR to fix the docstring?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have time, please!

2D, and the size of second axis from last should equals to ``num_filters``.
The output can be expressed as ``output[..., i, :] = lfilter(waveform[..., i, :],
a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)

Returns:
Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
Expand All @@ -957,7 +962,11 @@ def lfilter(
assert a_coeffs.ndim <= 2

if a_coeffs.ndim > 1:
waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
if batching:
assert waveform.ndim > 1
assert waveform.shape[-2] == a_coeffs.shape[0]
else:
waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shape manipulation looks very complicated, and I wonder if it is because we have pack/unpack batch in the following section. If so, is there a way to contain pack/unpack batch bellow into batching mechanism?

Copy link
Contributor Author

@yoyolicoris yoyolicoris Aug 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's not related to pack/unpack batch, and I haven't come up a way to integrate these two parts together.
Waveform and coefficients should have equal number of channels before pack batch, and this line is just replicating the waveform to match the desire shape.

else:
a_coeffs = a_coeffs.unsqueeze(0)
b_coeffs = b_coeffs.unsqueeze(0)
Expand Down