From 37d38a9e48a3f6c69b2d195e4e5796033d44cc6d Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Fri, 18 Jun 2021 16:06:54 +0800 Subject: [PATCH 1/9] let coeffs can be 2d tensor --- torchaudio/functional/filtering.py | 60 ++++++++++++++++++------------ 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index 68269d9b34..4e7f3d7bc8 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -857,13 +857,14 @@ def highpass_biquad( def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor): - n_order = a_coeffs_flipped.size(0) - for i_sample, o0 in enumerate(input_signal_windows.t()): + n_order = a_coeffs_flipped.size(1) + a_coeffs_flipped = a_coeffs_flipped.unsqueeze(2) + for i_sample, o0 in enumerate(input_signal_windows.permute(2, 0, 1)): windowed_output_signal = padded_output_waveform[ - :, i_sample:i_sample + n_order + :, :, i_sample:i_sample + n_order ] - o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1) - padded_output_waveform[:, i_sample + n_order - 1] = o0 + o0 -= (windowed_output_signal.transpose(0, 1) @ a_coeffs_flipped)[..., 0].t() + padded_output_waveform[:, :, i_sample + n_order - 1] = o0 try: @@ -879,13 +880,13 @@ def _lfilter_core( b_coeffs: Tensor, ) -> Tensor: - assert a_coeffs.size(0) == b_coeffs.size(0) - assert len(waveform.size()) == 2 + assert a_coeffs.size() == b_coeffs.size() + assert len(waveform.size()) == 3 assert waveform.device == a_coeffs.device assert b_coeffs.device == a_coeffs.device - n_channel, n_sample = waveform.size() - n_order = a_coeffs.size(0) + n_batch, n_channel, n_sample = waveform.size() + n_order = a_coeffs.size(1) assert n_order > 0 # Pad the input and create output @@ -895,17 +896,18 @@ def _lfilter_core( # Set up the coefficients matrix # Flip coefficients' order - a_coeffs_flipped = a_coeffs.flip(0) - b_coeffs_flipped = b_coeffs.flip(0) + a_coeffs_flipped = a_coeffs.flip(1) + b_coeffs_flipped = b_coeffs.flip(1) # calculate windowed_input_signal in parallel using convolution input_signal_windows = torch.nn.functional.conv1d( - padded_waveform.unsqueeze(1), - b_coeffs_flipped.view(1, 1, -1) - ).squeeze(1) + padded_waveform, + b_coeffs_flipped.unsqueeze(1), + groups=n_channel + ) - input_signal_windows.div_(a_coeffs[0]) - a_coeffs_flipped.div_(a_coeffs[0]) + input_signal_windows.div_(a_coeffs[:, :1]) + a_coeffs_flipped.div_(a_coeffs[:, :1]) if input_signal_windows.device == torch.device('cpu') and\ a_coeffs_flipped.device == torch.device('cpu') and\ @@ -914,9 +916,10 @@ def _lfilter_core( else: _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) - output = padded_output_waveform[:, n_order - 1:] + output = padded_output_waveform[:, :, n_order - 1:] return output + try: _lfilter = torch.ops.torchaudio._lfilter except RuntimeError as err: @@ -938,20 +941,31 @@ def lfilter( Args: waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1. - a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``. + a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(*, n_order + 1)``. + Where * is the optional number of filter banks. Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``. Must be same size as b_coeffs (pad with 0's as necessary). - b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``. - Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. - Must be same size as a_coeffs (pad with 0's as necessary). + b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(*, n_order + 1)``. + Where * is the optional number of filter banks. + Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. + Must be same size as a_coeffs (pad with 0's as necessary). clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) Returns: - Tensor: Waveform with dimension of ``(..., time)``. + Tensor: Waveform with dimension of ``(..., *, time)``. """ # pack batch shape = waveform.size() - waveform = waveform.reshape(-1, shape[-1]) + assert a_coeffs.size() == b_coeffs.size() + assert a_coeffs.ndim <= 2 + + waveform = waveform.reshape(-1, 1, shape[-1]) + if a_coeffs.ndim > 1: + shape = shape[:-1] + (a_coeffs.shape[0], shape[-1]) + waveform = torch.broadcast_to(waveform, a_coeffs) + else: + a_coeffs = a_coeffs.unsqueeze(0) + b_coeffs = b_coeffs.unsqueeze(0) output = _lfilter(waveform, a_coeffs, b_coeffs) From 48e402662511e0150988bbe8f81c7704a6877475 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Fri, 18 Jun 2021 16:08:55 +0800 Subject: [PATCH 2/9] update cpp implementation to support filter banks --- torchaudio/csrc/lfilter.cpp | 154 +++++++++++++++++++++--------------- 1 file changed, 92 insertions(+), 62 deletions(-) diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp index dcc5dbe442..4aeef27453 100644 --- a/torchaudio/csrc/lfilter.cpp +++ b/torchaudio/csrc/lfilter.cpp @@ -8,23 +8,28 @@ void host_lfilter_core_loop( const torch::Tensor& input_signal_windows, const torch::Tensor& a_coeff_flipped, torch::Tensor& padded_output_waveform) { - int64_t n_channel = input_signal_windows.size(0); - int64_t n_samples_input = input_signal_windows.size(1); - int64_t n_samples_output = padded_output_waveform.size(1); - int64_t n_order = a_coeff_flipped.size(0); + int64_t n_batch = input_signal_windows.size(0); + int64_t n_channel = input_signal_windows.size(1); + int64_t n_samples_input = input_signal_windows.size(2); + int64_t n_samples_output = padded_output_waveform.size(2); + int64_t n_order = a_coeff_flipped.size(1); scalar_t* output_data = padded_output_waveform.data_ptr(); const scalar_t* input_data = input_signal_windows.data_ptr(); const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr(); - for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) { - for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { - int64_t offset_input = i_channel * n_samples_input; - int64_t offset_output = i_channel * n_samples_output; - scalar_t a0 = input_data[offset_input + i_sample]; - for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { - a0 -= output_data[offset_output + i_sample + i_coeff] * - a_coeff_flipped_data[i_coeff]; + for (int64_t i_batch = 0; i_batch < n_batch; i_batch++) { + for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) { + for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { + int64_t offset_input = + ((i_batch * n_channel) + i_channel) * n_samples_input; + int64_t offset_output = + ((i_batch * n_channel) + i_channel) * n_samples_output; + scalar_t a0 = input_data[offset_input + i_sample]; + for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { + a0 -= output_data[offset_output + i_sample + i_coeff] * + a_coeff_flipped_data[i_coeff + i_channel * n_order]; + } + output_data[offset_output + i_sample + n_order - 1] = a0; } - output_data[offset_output + i_sample + n_order - 1] = a0; } } } @@ -51,10 +56,11 @@ void cpu_lfilter_core_loop( padded_output_waveform.dtype() == torch::kFloat64)); TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0)); + TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1)); TORCH_CHECK( - input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 == - padded_output_waveform.size(1)); + input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 == + padded_output_waveform.size(2)); AT_DISPATCH_FLOATING_TYPES( input_signal_windows.scalar_type(), "lfilter_core_loop", [&] { @@ -67,16 +73,26 @@ void lfilter_core_generic_loop( const torch::Tensor& input_signal_windows, const torch::Tensor& a_coeff_flipped, torch::Tensor& padded_output_waveform) { - int64_t n_samples_input = input_signal_windows.size(1); - int64_t n_order = a_coeff_flipped.size(0); + int64_t n_samples_input = input_signal_windows.size(2); + int64_t n_order = a_coeff_flipped.size(1); + auto coeff = a_coeff_flipped.unsqueeze(2); for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { - auto windowed_output_signal = padded_output_waveform.index( - {torch::indexing::Slice(), - torch::indexing::Slice(i_sample, i_sample + n_order)}); - auto o0 = input_signal_windows.index({torch::indexing::Slice(), i_sample}) - .addmv(windowed_output_signal, a_coeff_flipped, 1, -1); + auto windowed_output_signal = + padded_output_waveform + .index( + {torch::indexing::Slice(), + torch::indexing::Slice(), + torch::indexing::Slice(i_sample, i_sample + n_order)}) + .transpose(0, 1); + auto o0 = + input_signal_windows.index( + {torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) - + at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1); padded_output_waveform.index_put_( - {torch::indexing::Slice(), i_sample + n_order - 1}, o0); + {torch::indexing::Slice(), + torch::indexing::Slice(), + i_sample + n_order - 1}, + o0); } } @@ -88,16 +104,17 @@ class DifferentiableIIR : public torch::autograd::Function { const torch::Tensor& a_coeffs_normalized) { auto device = waveform.device(); auto dtype = waveform.dtype(); - int64_t n_channel = waveform.size(0); - int64_t n_sample = waveform.size(1); - int64_t n_order = a_coeffs_normalized.size(0); + int64_t n_batch = waveform.size(0); + int64_t n_channel = waveform.size(1); + int64_t n_sample = waveform.size(2); + int64_t n_order = a_coeffs_normalized.size(1); int64_t n_sample_padded = n_sample + n_order - 1; - auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous(); + auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous(); auto options = torch::TensorOptions().dtype(dtype).device(device); auto padded_output_waveform = - torch::zeros({n_channel, n_sample_padded}, options); + torch::zeros({n_batch, n_channel, n_sample_padded}, options); if (device.is_cpu()) { cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform); @@ -108,6 +125,7 @@ class DifferentiableIIR : public torch::autograd::Function { auto output = padded_output_waveform.index( {torch::indexing::Slice(), + torch::indexing::Slice(), torch::indexing::Slice(n_order - 1, torch::indexing::None)}); ctx->save_for_backward({waveform, a_coeffs_normalized, output}); @@ -122,8 +140,9 @@ class DifferentiableIIR : public torch::autograd::Function { auto a_coeffs_normalized = saved[1]; auto y = saved[2]; - int64_t n_channel = x.size(0); - int64_t n_order = a_coeffs_normalized.size(0); + int64_t n_batch = x.size(0); + int64_t n_channel = x.size(1); + int64_t n_order = a_coeffs_normalized.size(1); auto dx = torch::Tensor(); auto da = torch::Tensor(); @@ -137,16 +156,16 @@ class DifferentiableIIR : public torch::autograd::Function { F::PadFuncOptions({n_order - 1, 0})); da = F::conv1d( - dyda.unsqueeze(0), - dy.unsqueeze(1), - F::Conv1dFuncOptions().groups(n_channel)) - .sum(1) - .squeeze(0) - .flip(0); + dyda.view({1, n_batch * n_channel, -1}), + dy.view({n_batch * n_channel, 1, -1}), + F::Conv1dFuncOptions().groups(n_batch * n_channel)) + .view({n_batch, n_channel, -1}) + .sum(0) + .flip(1); } if (x.requires_grad()) { - dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1); + dx = DifferentiableIIR::apply(dy.flip(2), a_coeffs_normalized).flip(2); } return {dx, da}; @@ -159,17 +178,18 @@ class DifferentiableFIR : public torch::autograd::Function { torch::autograd::AutogradContext* ctx, const torch::Tensor& waveform, const torch::Tensor& b_coeffs) { - int64_t n_order = b_coeffs.size(0); + int64_t n_order = b_coeffs.size(1); + int64_t n_channel = b_coeffs.size(0); namespace F = torch::nn::functional; - auto b_coeff_flipped = b_coeffs.flip(0).contiguous(); + auto b_coeff_flipped = b_coeffs.flip(1).contiguous(); auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0})); - auto output = - F::conv1d( - padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order})) - .squeeze(1); + auto output = F::conv1d( + padded_waveform, + b_coeff_flipped.unsqueeze(1), + F::Conv1dFuncOptions().groups(n_channel)); ctx->save_for_backward({waveform, b_coeffs, output}); return output; @@ -183,8 +203,9 @@ class DifferentiableFIR : public torch::autograd::Function { auto b_coeffs = saved[1]; auto y = saved[2]; - int64_t n_channel = x.size(0); - int64_t n_order = b_coeffs.size(0); + int64_t n_batch = x.size(0); + int64_t n_channel = x.size(1); + int64_t n_order = b_coeffs.size(1); auto dx = torch::Tensor(); auto db = torch::Tensor(); @@ -194,19 +215,20 @@ class DifferentiableFIR : public torch::autograd::Function { if (b_coeffs.requires_grad()) { db = F::conv1d( - F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})), - dy.unsqueeze(1), - F::Conv1dFuncOptions().groups(n_channel)) - .sum(1) - .squeeze(0) - .flip(0); + F::pad(x, F::PadFuncOptions({n_order - 1, 0})) + .view({1, n_batch * n_channel, -1}), + dy.view({n_batch * n_channel, 1, -1}), + F::Conv1dFuncOptions().groups(n_batch * n_channel)) + .view({n_batch, n_channel, -1}) + .sum(0) + .flip(1); } if (x.requires_grad()) { dx = F::conv1d( - F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})), - b_coeffs.view({1, 1, n_order})) - .squeeze(1); + F::pad(dy, F::PadFuncOptions({0, n_order - 1})), + b_coeffs.unsqueeze(1), + F::Conv1dFuncOptions().groups(n_channel)); } return {dx, db}; @@ -219,19 +241,27 @@ torch::Tensor lfilter_core( const torch::Tensor& b_coeffs) { TORCH_CHECK(waveform.device() == a_coeffs.device()); TORCH_CHECK(b_coeffs.device() == a_coeffs.device()); - TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0)); + TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes()); - TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2); + TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3); + TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2); + TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1)); - int64_t n_order = b_coeffs.size(0); + int64_t n_order = b_coeffs.size(1); TORCH_INTERNAL_ASSERT(n_order > 0); - auto filtered_waveform = - DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]); - - auto output = - DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]); + auto filtered_waveform = DifferentiableFIR::apply( + waveform, + b_coeffs / + a_coeffs.index( + {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); + + auto output = DifferentiableIIR::apply( + filtered_waveform, + a_coeffs / + a_coeffs.index( + {torch::indexing::Slice(), torch::indexing::Slice(0, 1)})); return output; } From 5f05f4baa2aa24b3007ddbcd525db2375f16898d Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Fri, 18 Jun 2021 17:28:49 +0800 Subject: [PATCH 3/9] fix: wrong argument --- torchaudio/functional/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index 4e7f3d7bc8..acef296304 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -962,7 +962,7 @@ def lfilter( waveform = waveform.reshape(-1, 1, shape[-1]) if a_coeffs.ndim > 1: shape = shape[:-1] + (a_coeffs.shape[0], shape[-1]) - waveform = torch.broadcast_to(waveform, a_coeffs) + waveform = waveform.repeat(1, a_coeffs.shape[0], 1) else: a_coeffs = a_coeffs.unsqueeze(0) b_coeffs = b_coeffs.unsqueeze(0) From e88923257ed586f09e8245f8311be49800d7861e Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Fri, 18 Jun 2021 17:29:45 +0800 Subject: [PATCH 4/9] add tests --- .../torchaudio_unittest/functional/autograd_impl.py | 9 +++++++++ .../functional/functional_impl.py | 13 +++++++++++++ 2 files changed, 22 insertions(+) diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py index 049ae25c90..000a22159b 100644 --- a/test/torchaudio_unittest/functional/autograd_impl.py +++ b/test/torchaudio_unittest/functional/autograd_impl.py @@ -59,6 +59,15 @@ def test_lfilter_all_inputs(self): b = torch.tensor([0.4, 0.2, 0.9]) self.assert_grad(F.lfilter, (x, a, b)) + def test_lfilter_filterbanks(self): + torch.random.manual_seed(2434) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + a = torch.tensor([[0.7, 0.2, 0.6], + [0.8, 0.2, 0.9]]) + b = torch.tensor([[0.4, 0.2, 0.9], + [0.7, 0.2, 0.6]]) + self.assert_grad(F.lfilter, (x, a, b)) + def test_biquad(self): torch.random.manual_seed(2434) x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1) diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index 518a9fa6f6..4d20af0223 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -80,6 +80,19 @@ def test_lfilter_shape(self, shape): output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) assert shape == waveform.size() == output_waveform.size() + @parameterized.expand([ + ((44100,), (2, 3), (2, 44100)), + ((3, 44100), (1, 3), (3, 1, 44100)), + ((1, 2, 44100), (3, 3), (1, 2, 3, 44100)) + ]) + def test_lfilter_filterbanks_shape(self, input_shape, coeff_shape, target_shape): + torch.random.manual_seed(42) + waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device) + b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) + a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) + output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) + assert target_shape == output_waveform.size() + def test_lfilter_9th_order_filter_stability(self): """ Validate the precision of lfilter against reference scipy implementation when using high order filter. From db636e6d8c553f98e4b0fec67ff5e60b6d954de5 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Tue, 22 Jun 2021 10:14:51 +0800 Subject: [PATCH 5/9] test: merge shape tests --- .../functional/functional_impl.py | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index 4d20af0223..4bd1d9ace6 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -67,30 +67,21 @@ def test_lfilter_clamp(self): assert output_signal.max() > 1 @parameterized.expand([ - ((44100,),), - ((3, 44100),), - ((2, 3, 44100),), - ((1, 2, 3, 44100),) + ((44100,), (4,), (44100,)), + ((3, 44100), (4,), (3, 44100,)), + ((2, 3, 44100), (4,), (2, 3, 44100,)), + ((1, 2, 3, 44100), (4,), (1, 2, 3, 44100,)), + ((44100,), (2, 4), (2, 44100)), + ((3, 44100), (1, 4), (3, 1, 44100)), + ((1, 2, 44100), (3, 4), (1, 2, 3, 44100)) ]) - def test_lfilter_shape(self, shape): - torch.random.manual_seed(42) - waveform = torch.rand(*shape, dtype=self.dtype, device=self.device) - b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device) - a_coeffs = torch.tensor([1, 0, 0, 0], dtype=self.dtype, device=self.device) - output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) - assert shape == waveform.size() == output_waveform.size() - - @parameterized.expand([ - ((44100,), (2, 3), (2, 44100)), - ((3, 44100), (1, 3), (3, 1, 44100)), - ((1, 2, 44100), (3, 3), (1, 2, 3, 44100)) - ]) - def test_lfilter_filterbanks_shape(self, input_shape, coeff_shape, target_shape): + def test_lfilter_shape(self, input_shape, coeff_shape, target_shape): torch.random.manual_seed(42) waveform = torch.rand(*input_shape, dtype=self.dtype, device=self.device) b_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) a_coeffs = torch.rand(*coeff_shape, dtype=self.dtype, device=self.device) output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) + assert input_shape == waveform.size() assert target_shape == output_waveform.size() def test_lfilter_9th_order_filter_stability(self): From 111a43fedf0ab6ff0178e4a0648cf149be7fb901 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Tue, 22 Jun 2021 10:38:16 +0800 Subject: [PATCH 6/9] test: different number of channels --- test/torchaudio_unittest/functional/autograd_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py index 000a22159b..0c663481b8 100644 --- a/test/torchaudio_unittest/functional/autograd_impl.py +++ b/test/torchaudio_unittest/functional/autograd_impl.py @@ -61,7 +61,7 @@ def test_lfilter_all_inputs(self): def test_lfilter_filterbanks(self): torch.random.manual_seed(2434) - x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2) + x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=3) a = torch.tensor([[0.7, 0.2, 0.6], [0.8, 0.2, 0.9]]) b = torch.tensor([[0.4, 0.2, 0.9], From 2669e4a46a1eebc941d6c01cb19ab3f086195717 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Sun, 27 Jun 2021 22:29:13 +0800 Subject: [PATCH 7/9] doc: state 2D shapes explicitly --- torchaudio/functional/filtering.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index acef296304..b81b284540 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -941,18 +941,19 @@ def lfilter( Args: waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1. - a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(*, n_order + 1)``. - Where * is the optional number of filter banks. + a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either + 1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``. Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``. Must be same size as b_coeffs (pad with 0's as necessary). - b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(*, n_order + 1)``. - Where * is the optional number of filter banks. + b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either + 1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``. Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. Must be same size as a_coeffs (pad with 0's as necessary). clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) Returns: - Tensor: Waveform with dimension of ``(..., *, time)``. + Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs`` are 2D Tensors, + or ``(..., time)`` otherwise. """ # pack batch shape = waveform.size() From c87c16ec1dc6888381739467c44f405547542225 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Sun, 27 Jun 2021 22:45:38 +0800 Subject: [PATCH 8/9] refactor: shape calculation --- torchaudio/functional/filtering.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index b81b284540..b4d03c6109 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -955,19 +955,18 @@ def lfilter( Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs`` are 2D Tensors, or ``(..., time)`` otherwise. """ - # pack batch - shape = waveform.size() assert a_coeffs.size() == b_coeffs.size() assert a_coeffs.ndim <= 2 - waveform = waveform.reshape(-1, 1, shape[-1]) if a_coeffs.ndim > 1: - shape = shape[:-1] + (a_coeffs.shape[0], shape[-1]) - waveform = waveform.repeat(1, a_coeffs.shape[0], 1) + waveform = torch.stack([waveform] * a_coeffs.shape[0], -2) else: a_coeffs = a_coeffs.unsqueeze(0) b_coeffs = b_coeffs.unsqueeze(0) + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, a_coeffs.shape[0], shape[-1]) output = _lfilter(waveform, a_coeffs, b_coeffs) if clamp: From a42af04120292c9ce2e955709539e7d33b1a21b5 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Sun, 27 Jun 2021 22:52:38 +0800 Subject: [PATCH 9/9] doc: remove trailing whitespace --- torchaudio/functional/filtering.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index b4d03c6109..827b16fd4d 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -941,19 +941,19 @@ def lfilter( Args: waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1. - a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either + a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either 1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``. Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``. Must be same size as b_coeffs (pad with 0's as necessary). - b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either + b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either 1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``. Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. Must be same size as a_coeffs (pad with 0's as necessary). clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) Returns: - Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs`` are 2D Tensors, - or ``(..., time)`` otherwise. + Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs`` + are 2D Tensors, or ``(..., time)`` otherwise. """ assert a_coeffs.size() == b_coeffs.size() assert a_coeffs.ndim <= 2