diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py index 049ae25c90..dc1bf9547d 100644 --- a/test/torchaudio_unittest/functional/autograd_impl.py +++ b/test/torchaudio_unittest/functional/autograd_impl.py @@ -59,6 +59,24 @@ def test_lfilter_all_inputs(self): b = torch.tensor([0.4, 0.2, 0.9]) self.assert_grad(F.lfilter, (x, a, b)) + def test_batch_lfilter(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([[0.7, 0.2, 0.6], + [0.8, 0.2, 0.9]]) + b = torch.tensor([[0.4, 0.2, 0.9], + [0.7, 0.2, 0.6]]) + self.assert_grad(F.lfilter, (x, a, b)) + + def test_filter_banks_lfilter(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2).unsqueeze(1) + a = torch.tensor([[0.7, 0.2, 0.6], + [0.8, 0.2, 0.9]]) + b = torch.tensor([[0.4, 0.2, 0.9], + [0.7, 0.2, 0.6]]) + self.assert_grad(F.lfilter, (x, a, b)) + def test_biquad(self): torch.random.manual_seed(2434) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1) diff --git a/test/torchaudio_unittest/functional/batch_consistency_test.py b/test/torchaudio_unittest/functional/batch_consistency_test.py index bd91103b15..b9770b22af 100644 --- a/test/torchaudio_unittest/functional/batch_consistency_test.py +++ b/test/torchaudio_unittest/functional/batch_consistency_test.py @@ -217,3 +217,31 @@ def test_compute_kaldi_pitch(self): batch = waveform.view(self.batch_size, n_channels, waveform.size(-1)) self.assert_batch_consistency( F.compute_kaldi_pitch, batch, sample_rate=sample_rate) + + def test_lfilter_separated_filters(self): + signal_length = 2048 + torch.manual_seed(2434) + x = torch.randn(self.batch_size, signal_length) + a = torch.rand(self.batch_size, 3) + b = torch.rand(self.batch_size, 3) + + batchwise_output = F.lfilter(x, a, b) + itemwise_output = torch.stack([ + F.lfilter(x[i], a[i], b[i]) + for i in range(self.batch_size) + ]) + + self.assertEqual(batchwise_output, itemwise_output) + + def test_lfilter(self): + signal_length = 2048 + torch.manual_seed(2434) + x = torch.randn(signal_length) + a = torch.rand(4, 3) + b = torch.rand(4, 3) + + def filter_wrapper(ab_coeffs, waveform): + a, b = ab_coeffs[..., 0, :], ab_coeffs[..., 1, :] + return F.lfilter(waveform, a, b) + + self.assert_batch_consistency(filter_wrapper, torch.stack([a, b], 1), x) diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index 518a9fa6f6..48d83267d3 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -80,6 +80,20 @@ def test_lfilter_shape(self, shape): output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) assert shape == waveform.size() == output_waveform.size() + @parameterized.expand([ + ((44100,), (2, 3), (2, 44100)), + ((3, 44100), (1, 3), (3, 44100)), + ((3, 44100), (3, 3), (3, 44100)), + ((1, 2, 1, 44100), (3, 3), (1, 2, 3, 44100)) + ]) + def test_lfilter_broadcast_shape(self, input_shape, coeff_shape, target_shape): + torch.random.manual_seed(42) + waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device) + b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) + a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) + output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) + assert target_shape == output_waveform.size() + def test_lfilter_9th_order_filter_stability(self): """ Validate the precision of lfilter against reference scipy implementation when using high order filter. diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp index dcc5dbe442..4aeef27453 100644 --- a/torchaudio/csrc/lfilter.cpp +++ b/torchaudio/csrc/lfilter.cpp @@ -8,23 +8,28 @@ void host_lfilter_core_loop( const torch::Tensor& input_signal_windows, const torch::Tensor& a_coeff_flipped, torch::Tensor& padded_output_waveform) { - int64_t n_channel = input_signal_windows.size(0); - int64_t n_samples_input = input_signal_windows.size(1); - int64_t n_samples_output = padded_output_waveform.size(1); - int64_t n_order = a_coeff_flipped.size(0); + int64_t n_batch = input_signal_windows.size(0); + int64_t n_channel = input_signal_windows.size(1); + int64_t n_samples_input = input_signal_windows.size(2); + int64_t n_samples_output = padded_output_waveform.size(2); + int64_t n_order = a_coeff_flipped.size(1); scalar_t* output_data = padded_output_waveform.data_ptr(); const scalar_t* input_data = input_signal_windows.data_ptr(); const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr(); - for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) { - for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { - int64_t offset_input = i_channel * n_samples_input; - int64_t offset_output = i_channel * n_samples_output; - scalar_t a0 = input_data[offset_input + i_sample]; - for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { - a0 -= output_data[offset_output + i_sample + i_coeff] * - a_coeff_flipped_data[i_coeff]; + for (int64_t i_batch = 0; i_batch < n_batch; i_batch++) { + for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) { + for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { + int64_t offset_input = + ((i_batch * n_channel) + i_channel) * n_samples_input; + int64_t offset_output = + ((i_batch * n_channel) + i_channel) * n_samples_output; + scalar_t a0 = input_data[offset_input + i_sample]; + for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { + a0 -= output_data[offset_output + i_sample + i_coeff] * + a_coeff_flipped_data[i_coeff + i_channel * n_order]; + } + output_data[offset_output + i_sample + n_order - 1] = a0; } - output_data[offset_output + i_sample + n_order - 1] = a0; } } } @@ -51,10 +56,11 @@ void cpu_lfilter_core_loop( padded_output_waveform.dtype() == torch::kFloat64)); TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0)); + TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1)); TORCH_CHECK( - input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 == - padded_output_waveform.size(1)); + input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 == + padded_output_waveform.size(2)); AT_DISPATCH_FLOATING_TYPES( input_signal_windows.scalar_type(), "lfilter_core_loop", [&] { @@ -67,16 +73,26 @@ void lfilter_core_generic_loop( const torch::Tensor& input_signal_windows, const torch::Tensor& a_coeff_flipped, torch::Tensor& padded_output_waveform) { - int64_t n_samples_input = input_signal_windows.size(1); - int64_t n_order = a_coeff_flipped.size(0); + int64_t n_samples_input = input_signal_windows.size(2); + int64_t n_order = a_coeff_flipped.size(1); + auto coeff = a_coeff_flipped.unsqueeze(2); for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { - auto windowed_output_signal = padded_output_waveform.index( - {torch::indexing::Slice(), - torch::indexing::Slice(i_sample, i_sample + n_order)}); - auto o0 = input_signal_windows.index({torch::indexing::Slice(), i_sample}) - .addmv(windowed_output_signal, a_coeff_flipped, 1, -1); + auto windowed_output_signal = + padded_output_waveform + .index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(i_sample, i_sample + n_order)}) + .transpose(0, 1); + auto o0 = + input_signal_windows.index( + {torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) - + at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1); padded_output_waveform.index_put_( - {torch::indexing::Slice(), i_sample + n_order - 1}, o0); + {torch::indexing::Slice(), + torch::indexing::Slice(), + i_sample + n_order - 1}, + o0); } } @@ -88,16 +104,17 @@ class DifferentiableIIR : public torch::autograd::Function { const torch::Tensor& a_coeffs_normalized) { auto device = waveform.device(); auto dtype = waveform.dtype(); - int64_t n_channel = waveform.size(0); - int64_t n_sample = waveform.size(1); - int64_t n_order = a_coeffs_normalized.size(0); + int64_t n_batch = waveform.size(0); + int64_t n_channel = waveform.size(1); + int64_t n_sample = waveform.size(2); + int64_t n_order = a_coeffs_normalized.size(1); int64_t n_sample_padded = n_sample + n_order - 1; - auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous(); + auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous(); auto options = torch::TensorOptions().dtype(dtype).device(device); auto padded_output_waveform = - torch::zeros({n_channel, n_sample_padded}, options); + torch::zeros({n_batch, n_channel, n_sample_padded}, options); if (device.is_cpu()) { cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform); @@ -108,6 +125,7 @@ class DifferentiableIIR : public torch::autograd::Function { auto output = padded_output_waveform.index( {torch::indexing::Slice(), + torch::indexing::Slice(), torch::indexing::Slice(n_order - 1, torch::indexing::None)}); ctx->save_for_backward({waveform, a_coeffs_normalized, output}); @@ -122,8 +140,9 @@ class DifferentiableIIR : public torch::autograd::Function { auto a_coeffs_normalized = saved[1]; auto y = saved[2]; - int64_t n_channel = x.size(0); - int64_t n_order = a_coeffs_normalized.size(0); + int64_t n_batch = x.size(0); + int64_t n_channel = x.size(1); + int64_t n_order = a_coeffs_normalized.size(1); auto dx = torch::Tensor(); auto da = torch::Tensor(); @@ -137,16 +156,16 @@ class DifferentiableIIR : public torch::autograd::Function { F::PadFuncOptions({n_order - 1, 0})); da = F::conv1d( - dyda.unsqueeze(0), - dy.unsqueeze(1), - F::Conv1dFuncOptions().groups(n_channel)) - .sum(1) - .squeeze(0) - .flip(0); + dyda.view({1, n_batch * n_channel, -1}), + dy.view({n_batch * n_channel, 1, -1}), + F::Conv1dFuncOptions().groups(n_batch * n_channel)) + .view({n_batch, n_channel, -1}) + .sum(0) + .flip(1); } if (x.requires_grad()) { - dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1); + dx = DifferentiableIIR::apply(dy.flip(2), a_coeffs_normalized).flip(2); } return {dx, da}; @@ -159,17 +178,18 @@ class DifferentiableFIR : public torch::autograd::Function { torch::autograd::AutogradContext* ctx, const torch::Tensor& waveform, const torch::Tensor& b_coeffs) { - int64_t n_order = b_coeffs.size(0); + int64_t n_order = b_coeffs.size(1); + int64_t n_channel = b_coeffs.size(0); namespace F = torch::nn::functional; - auto b_coeff_flipped = b_coeffs.flip(0).contiguous(); + auto b_coeff_flipped = b_coeffs.flip(1).contiguous(); auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0})); - auto output = - F::conv1d( - padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order})) - .squeeze(1); + auto output = F::conv1d( + padded_waveform, + b_coeff_flipped.unsqueeze(1), + F::Conv1dFuncOptions().groups(n_channel)); ctx->save_for_backward({waveform, b_coeffs, output}); return output; @@ -183,8 +203,9 @@ class DifferentiableFIR : public torch::autograd::Function { auto b_coeffs = saved[1]; auto y = saved[2]; - int64_t n_channel = x.size(0); - int64_t n_order = b_coeffs.size(0); + int64_t n_batch = x.size(0); + int64_t n_channel = x.size(1); + int64_t n_order = b_coeffs.size(1); auto dx = torch::Tensor(); auto db = torch::Tensor(); @@ -194,19 +215,20 @@ class DifferentiableFIR : public torch::autograd::Function { if (b_coeffs.requires_grad()) { db = F::conv1d( - F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})), - dy.unsqueeze(1), - F::Conv1dFuncOptions().groups(n_channel)) - .sum(1) - .squeeze(0) - .flip(0); + F::pad(x, F::PadFuncOptions({n_order - 1, 0})) + .view({1, n_batch * n_channel, -1}), + dy.view({n_batch * n_channel, 1, -1}), + F::Conv1dFuncOptions().groups(n_batch * n_channel)) + .view({n_batch, n_channel, -1}) + .sum(0) + .flip(1); } if (x.requires_grad()) { dx = F::conv1d( - F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})), - b_coeffs.view({1, 1, n_order})) - .squeeze(1); + F::pad(dy, F::PadFuncOptions({0, n_order - 1})), + b_coeffs.unsqueeze(1), + F::Conv1dFuncOptions().groups(n_channel)); } return {dx, db}; @@ -219,19 +241,27 @@ torch::Tensor lfilter_core( const torch::Tensor& b_coeffs) { TORCH_CHECK(waveform.device() == a_coeffs.device()); TORCH_CHECK(b_coeffs.device() == a_coeffs.device()); - TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0)); + TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes()); - TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2); + TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3); + TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2); + TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1)); - int64_t n_order = b_coeffs.size(0); + int64_t n_order = b_coeffs.size(1); TORCH_INTERNAL_ASSERT(n_order > 0); - auto filtered_waveform = - DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]); - - auto output = - DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]); + auto filtered_waveform = DifferentiableFIR::apply( + waveform, + b_coeffs / + a_coeffs.index( + {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); + + auto output = DifferentiableIIR::apply( + filtered_waveform, + a_coeffs / + a_coeffs.index( + {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); return output; } diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index 68269d9b34..ee58ea4bc1 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -857,13 +857,14 @@ def highpass_biquad( def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor): - n_order = a_coeffs_flipped.size(0) - for i_sample, o0 in enumerate(input_signal_windows.t()): + n_order = a_coeffs_flipped.size(1) + a_coeffs_flipped = a_coeffs_flipped.unsqueeze(2) + for i_sample, o0 in enumerate(input_signal_windows.permute(2, 0, 1)): windowed_output_signal = padded_output_waveform[ - :, i_sample:i_sample + n_order + :, :, i_sample:i_sample + n_order ] - o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1) - padded_output_waveform[:, i_sample + n_order - 1] = o0 + o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t() + padded_output_waveform[:, :, i_sample + n_order - 1] = o0 try: @@ -879,13 +880,13 @@ def _lfilter_core( b_coeffs: Tensor, ) -> Tensor: - assert a_coeffs.size(0) == b_coeffs.size(0) - assert len(waveform.size()) == 2 + assert a_coeffs.size() == b_coeffs.size() + assert len(waveform.size()) == 3 assert waveform.device == a_coeffs.device assert b_coeffs.device == a_coeffs.device - n_channel, n_sample = waveform.size() - n_order = a_coeffs.size(0) + n_batch, n_channel, n_sample = waveform.size() + n_order = a_coeffs.size(1) assert n_order > 0 # Pad the input and create output @@ -895,17 +896,18 @@ def _lfilter_core( # Set up the coefficients matrix # Flip coefficients' order - a_coeffs_flipped = a_coeffs.flip(0) - b_coeffs_flipped = b_coeffs.flip(0) + a_coeffs_flipped = a_coeffs.flip(1) + b_coeffs_flipped = b_coeffs.flip(1) # calculate windowed_input_signal in parallel using convolution input_signal_windows = torch.nn.functional.conv1d( - padded_waveform.unsqueeze(1), - b_coeffs_flipped.view(1, 1, -1) - ).squeeze(1) + padded_waveform, + b_coeffs_flipped.unsqueeze(1), + groups=n_channel + ) - input_signal_windows.div_(a_coeffs[0]) - a_coeffs_flipped.div_(a_coeffs[0]) + input_signal_windows.div_(a_coeffs[:, :1]) + a_coeffs_flipped.div_(a_coeffs[:, :1]) if input_signal_windows.device == torch.device('cpu') and\ a_coeffs_flipped.device == torch.device('cpu') and\ @@ -914,7 +916,7 @@ def _lfilter_core( else: _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) - output = padded_output_waveform[:, n_order - 1:] + output = padded_output_waveform[:, :, n_order - 1:] return output try: @@ -938,20 +940,36 @@ def lfilter( Args: waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1. - a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``. + a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(*, n_order + 1)``. + Where * is the optional number of filter banks, + and must be broadcastable to ``waveform`` except time dimension. Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``. Must be same size as b_coeffs (pad with 0's as necessary). - b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``. + b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(*, n_order + 1)``. + Where * is the optional number of filter banks, + and must be broadcastable to ``waveform`` except time dimension. Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. Must be same size as a_coeffs (pad with 0's as necessary). clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) Returns: - Tensor: Waveform with dimension of ``(..., time)``. + Tensor: Waveform with dimension of ``(..., *, time)``. """ # pack batch shape = waveform.size() - waveform = waveform.reshape(-1, shape[-1]) + assert a_coeffs.size() == b_coeffs.size() + assert a_coeffs.ndim <= 2 + + a_coeffs, b_coeffs = a_coeffs.squeeze(), b_coeffs.squeeze() + + if a_coeffs.ndim > 1: + shape = shape[:-2] + (a_coeffs.shape[0], shape[-1]) + waveform = torch.broadcast_to(waveform, shape) + waveform = waveform.reshape(-1, shape[-2], shape[-1]) + else: + waveform = waveform.reshape(-1, 1, shape[-1]) + a_coeffs = a_coeffs.unsqueeze(0) + b_coeffs = b_coeffs.unsqueeze(0) output = _lfilter(waveform, a_coeffs, b_coeffs)