
Commit 8094751

Add batch support to lfilter (#1638)
1 parent 15bc554 commit 8094751
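In short, `lfilter` previously treated 2D coefficients only as a filterbank, broadcasting one waveform across every filter; with the new `batching` flag (default ``True``), row ``i`` of the waveform is paired with filter ``i`` instead. A minimal sketch of the two modes, not part of the commit itself (shapes and coefficient values are illustrative):

import torch
import torchaudio.functional as F

x = torch.randn(2, 1024)                      # two signals, shape (num_filters, time)
a = torch.tensor([[1.0, -0.5, 0.2],
                  [1.0, -0.3, 0.1]])          # one IIR filter per signal
b = torch.tensor([[0.3, 0.2, 0.1],
                  [0.5, 0.1, 0.0]])

y_batched = F.lfilter(x, a, b, batching=True)   # (2, 1024): row i filtered by filter i
y_bank = F.lfilter(x, a, b, batching=False)     # (2, 2, 1024): every signal through every filter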

File tree: 4 files changed, +36 -2 lines changed


test/torchaudio_unittest/functional/autograd_impl.py

Lines changed: 10 additions & 0 deletions
@@ -1,4 +1,5 @@
 from typing import Callable, Tuple
+from functools import partial
 import torch
 from parameterized import parameterized
 from torch import Tensor
@@ -62,6 +63,15 @@ def test_lfilter_all_inputs(self):
     def test_lfilter_filterbanks(self):
         torch.random.manual_seed(2434)
         x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3)
+        a = torch.tensor([[0.7, 0.2, 0.6],
+                          [0.8, 0.2, 0.9]])
+        b = torch.tensor([[0.4, 0.2, 0.9],
+                          [0.7, 0.2, 0.6]])
+        self.assert_grad(partial(F.lfilter, batching=False), (x, a, b))
+
+    def test_lfilter_batching(self):
+        torch.random.manual_seed(2434)
+        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
         a = torch.tensor([[0.7, 0.2, 0.6],
                           [0.8, 0.2, 0.9]])
         b = torch.tensor([[0.4, 0.2, 0.9],
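For context, `assert_grad` in this test module boils down to a `torch.autograd.gradcheck` call on double-precision inputs. A rough standalone equivalent of the new `test_lfilter_batching` case might look as follows; `torch.randn` stands in for the `get_whitenoise` test utility, and `clamp` is disabled only to keep the sketch's mapping smooth (these are assumptions, not part of the commit):

import torch
import torchaudio.functional as F
from functools import partial

torch.random.manual_seed(2434)
# two channels paired with two filters, double precision for a numerical gradient check
x = torch.randn(2, 64, dtype=torch.float64, requires_grad=True)
a = torch.tensor([[0.7, 0.2, 0.6],
                  [0.8, 0.2, 0.9]], dtype=torch.float64, requires_grad=True)
b = torch.tensor([[0.4, 0.2, 0.9],
                  [0.7, 0.2, 0.6]], dtype=torch.float64, requires_grad=True)

# clamp=False avoids the non-differentiable clipping at +/-1 in this sketch
torch.autograd.gradcheck(partial(F.lfilter, clamp=False, batching=True), (x, a, b))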

test/torchaudio_unittest/functional/batch_consistency_test.py

Lines changed: 15 additions & 0 deletions
@@ -217,3 +217,18 @@ def test_compute_kaldi_pitch(self):
         batch = waveform.view(self.batch_size, n_channels, waveform.size(-1))
         self.assert_batch_consistency(
             F.compute_kaldi_pitch, batch, sample_rate=sample_rate)
+
+    def test_lfilter(self):
+        signal_length = 2048
+        torch.manual_seed(2434)
+        x = torch.randn(self.batch_size, signal_length)
+        a = torch.rand(self.batch_size, 3)
+        b = torch.rand(self.batch_size, 3)
+
+        batchwise_output = F.lfilter(x, a, b, batching=True)
+        itemwise_output = torch.stack([
+            F.lfilter(x[i], a[i], b[i])
+            for i in range(self.batch_size)
+        ])
+
+        self.assertEqual(batchwise_output, itemwise_output)
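Outside the unittest harness, the property this test encodes can be checked directly with `torch.allclose`; a small sketch with stable, hand-picked coefficients (batch size, length, and tolerance are illustrative):

import torch
import torchaudio.functional as F

torch.manual_seed(2434)
batch_size, signal_length = 3, 2048
x = torch.randn(batch_size, signal_length)
a = torch.tensor([[1.0, -0.5, 0.20],
                  [1.0, -0.3, 0.10],
                  [1.0,  0.2, 0.05]])
b = torch.tensor([[0.3, 0.2, 0.1],
                  [0.5, 0.1, 0.0],
                  [0.2, 0.2, 0.2]])

# the batched call matches filtering each item with its own coefficients
batchwise = F.lfilter(x, a, b, batching=True)
itemwise = torch.stack([F.lfilter(x[i], a[i], b[i]) for i in range(batch_size)])
assert torch.allclose(batchwise, itemwise, atol=1e-5)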

test/torchaudio_unittest/functional/functional_impl.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def test_lfilter_shape(self, input_shape, coeff_shape, target_shape):
         waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device)
         b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
         a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device)
-        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
+        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs, batching=False)
         assert input_shape == waveform.size()
         assert target_shape == output_waveform.size()

torchaudio/functional/filtering.py

Lines changed: 10 additions & 1 deletion
@@ -930,6 +930,7 @@ def lfilter(
     a_coeffs: Tensor,
     b_coeffs: Tensor,
     clamp: bool = True,
+    batching: bool = True
 ) -> Tensor:
     r"""Perform an IIR filter by evaluating difference equation.
@@ -948,6 +949,10 @@ def lfilter(
             Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
             Must be same size as a_coeffs (pad with 0's as necessary).
         clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
+        batching (bool, optional): Effective only when coefficients are 2D. If ``True``, the waveform should be at least
+            2D, and the size of the second-to-last axis should equal ``num_filters``.
+            The output can be expressed as ``output[..., i, :] = lfilter(waveform[..., i, :],
+            a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)
 
     Returns:
         Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
@@ -957,7 +962,11 @@
     assert a_coeffs.ndim <= 2
 
     if a_coeffs.ndim > 1:
-        waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
+        if batching:
+            assert waveform.ndim > 1
+            assert waveform.shape[-2] == a_coeffs.shape[0]
+        else:
+            waveform = torch.stack([waveform] * a_coeffs.shape[0], -2)
     else:
         a_coeffs = a_coeffs.unsqueeze(0)
         b_coeffs = b_coeffs.unsqueeze(0)
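The identity stated in the new docstring can also be exercised with a leading batch dimension, which the `waveform.shape[-2] == a_coeffs.shape[0]` assertion above allows. A short sketch, again with illustrative shapes and coefficients rather than anything taken from the commit:

import torch
import torchaudio.functional as F

waveform = torch.randn(5, 2, 256)               # (batch, num_filters, time)
a_coeffs = torch.tensor([[1.0, -0.5, 0.2],
                         [1.0, -0.3, 0.1]])
b_coeffs = torch.tensor([[0.3, 0.2, 0.1],
                         [0.5, 0.1, 0.0]])

out = F.lfilter(waveform, a_coeffs, b_coeffs, batching=True)    # (5, 2, 256)

# output[..., i, :] == lfilter(waveform[..., i, :], a_coeffs[i], b_coeffs[i])
for i in range(a_coeffs.shape[0]):
    ref = F.lfilter(waveform[..., i, :], a_coeffs[i], b_coeffs[i])
    assert torch.allclose(out[..., i, :], ref, atol=1e-5)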
