diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py index 0c663481b8..156627ba69 100644 --- a/test/torchaudio_unittest/functional/autograd_impl.py +++ b/test/torchaudio_unittest/functional/autograd_impl.py @@ -1,4 +1,5 @@ from typing import Callable, Tuple +from functools import partial import torch from parameterized import parameterized from torch import Tensor @@ -62,6 +63,15 @@ def test_lfilter_all_inputs(self): def test_lfilter_filterbanks(self): torch.random.manual_seed(2434) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3) + a = torch.tensor([[0.7, 0.2, 0.6], + [0.8, 0.2, 0.9]]) + b = torch.tensor([[0.4, 0.2, 0.9], + [0.7, 0.2, 0.6]]) + self.assert_grad(partial(F.lfilter, batching=False), (x, a, b)) + + def test_lfilter_batching(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]]) b = torch.tensor([[0.4, 0.2, 0.9], diff --git a/test/torchaudio_unittest/functional/batch_consistency_test.py b/test/torchaudio_unittest/functional/batch_consistency_test.py index bd91103b15..38b81d7090 100644 --- a/test/torchaudio_unittest/functional/batch_consistency_test.py +++ b/test/torchaudio_unittest/functional/batch_consistency_test.py @@ -217,3 +217,18 @@ def test_compute_kaldi_pitch(self): batch = waveform.view(self.batch_size, n_channels, waveform.size(-1)) self.assert_batch_consistency( F.compute_kaldi_pitch, batch, sample_rate=sample_rate) + + def test_lfilter(self): + signal_length = 2048 + torch.manual_seed(2434) + x = torch.randn(self.batch_size, signal_length) + a = torch.rand(self.batch_size, 3) + b = torch.rand(self.batch_size, 3) + + batchwise_output = F.lfilter(x, a, b, batching=True) + itemwise_output = torch.stack([ + F.lfilter(x[i], a[i], b[i]) + for i in range(self.batch_size) + ]) + + self.assertEqual(batchwise_output, itemwise_output) diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index 1b18bd255b..0c5dabe45a 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -80,7 +80,7 @@ def test_lfilter_shape(self, input_shape, coeff_shape, target_shape): waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device) b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) - output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) + output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs, batching=False) assert input_shape == waveform.size() assert target_shape == output_waveform.size() diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index 57d5b745cc..0dcd8f159c 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -930,6 +930,7 @@ def lfilter( a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, + batching: bool = True ) -> Tensor: r"""Perform an IIR filter by evaluating difference equation. @@ -948,6 +949,10 @@ def lfilter( Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. Must be same size as a_coeffs (pad with 0's as necessary). clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) + batching (bool, optional): Activate when coefficients are in 2D. If ``True``, then waveform should be at least + 2D, and the size of second axis from last should equals to ``num_filters``. + The output can be expressed as ``output[..., i, :] = lfilter(waveform[..., i, :], + a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``) Returns: Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs`` @@ -957,7 +962,11 @@ def lfilter( assert a_coeffs.ndim <= 2 if a_coeffs.ndim > 1: - waveform = torch.stack([waveform] * a_coeffs.shape[0], -2) + if batching: + assert waveform.ndim > 1 + assert waveform.shape[-2] == a_coeffs.shape[0] + else: + waveform = torch.stack([waveform] * a_coeffs.shape[0], -2) else: a_coeffs = a_coeffs.unsqueeze(0) b_coeffs = b_coeffs.unsqueeze(0)