From e0d18a056f80b275377e05fdd4bbee250faa1427 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Tue, 6 Apr 2021 21:03:36 +0800
Subject: [PATCH 01/11] refactor lfilter cpp implementation

---
 .../functional/autograd_impl.py |   3 +-
 torchaudio/csrc/lfilter.cpp     | 201 +++++++-----
 2 files changed, 74 insertions(+), 130 deletions(-)

diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py
index caa444c988..7d267865b4 100644
--- a/test/torchaudio_unittest/functional/autograd_impl.py
+++ b/test/torchaudio_unittest/functional/autograd_impl.py
@@ -1,6 +1,6 @@
 import torch
 import torchaudio.functional as F
-from torch.autograd import gradcheck
+from torch.autograd import gradcheck, gradgradcheck
 from torchaudio_unittest import common_utils
 
@@ -20,6 +20,7 @@ def test_lfilter_a(self):
         b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
         a.requires_grad = True
         assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
+        assert gradgradcheck(F.lfilter, (x, a, b))
 
     def test_lfilter_b(self):
         torch.random.manual_seed(2434)
diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp
index 56540c33f6..54aecaafd1 100644
--- a/torchaudio/csrc/lfilter.cpp
+++ b/torchaudio/csrc/lfilter.cpp
@@ -80,88 +80,15 @@ void lfilter_core_generic_loop(
   }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> lfilter_core(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs,
-    const torch::Tensor& b_coeffs) {
-  TORCH_CHECK(waveform.device() == a_coeffs.device());
-  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
-  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
-
-  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
-
-  auto device = waveform.device();
-  int64_t n_order = a_coeffs.size(0);
-
-  TORCH_INTERNAL_ASSERT(n_order > 0);
-
-  namespace F = torch::nn::functional;
-
-  auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
-  auto padded_output_waveform = torch::zeros_like(padded_waveform);
-
-  auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
-  auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
-
-  auto input_signal_windows =
-      F::conv1d(
-          padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
-          .squeeze(1);
-
-  input_signal_windows.div_(a_coeffs[0]);
-  a_coeff_flipped.div_(a_coeffs[0]);
+torch::Tensor iir_autograd(const torch::Tensor&, const torch::Tensor&);
 
-  if (device.is_cpu()) {
-    cpu_lfilter_core_loop(
-        input_signal_windows, a_coeff_flipped, padded_output_waveform);
-  } else {
-    lfilter_core_generic_loop(
-        input_signal_windows, a_coeff_flipped, padded_output_waveform);
-  }
-
-  auto output = padded_output_waveform.index(
-      {torch::indexing::Slice(),
-       torch::indexing::Slice(n_order - 1, torch::indexing::None)});
-
-  return {output, input_signal_windows};
-}
-
-torch::Tensor lfilter_simple(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs,
-    const torch::Tensor& b_coeffs) {
-  return std::get<0>(lfilter_core(waveform, a_coeffs, b_coeffs));
-}
-
-class DifferentiableLfilter
-    : public torch::autograd::Function<DifferentiableLfilter> {
+class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
  public:
   static torch::Tensor forward(
       torch::autograd::AutogradContext* ctx,
       const torch::Tensor& waveform,
-      const torch::Tensor& a_coeffs,
-      const torch::Tensor& b_coeffs) {
+      const torch::Tensor& a_coeffs) {
     at::AutoNonVariableTypeMode g;
-    auto result = lfilter_core(waveform, a_coeffs, b_coeffs);
-    ctx->save_for_backward(
-        {waveform,
-         a_coeffs,
-         b_coeffs,
-         std::get<0>(result),
-         std::get<1>(result)});
-    return std::get<0>(result);
-  }
-
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto waveform = saved[0];
-    auto a_coeffs = saved[1];
-    auto b_coeffs = saved[2];
-    auto y = saved[3];
-    auto xh = saved[4];
-
     auto device = waveform.device();
     auto dtype = waveform.dtype();
     int64_t n_channel = waveform.size(0);
@@ -170,26 +97,48 @@ class DifferentiableLfilter
     int64_t n_sample_padded = n_sample + n_order - 1;
 
     auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
-    auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
-    b_coeff_flipped.div_(a_coeffs[0]);
     a_coeff_flipped.div_(a_coeffs[0]);
 
+    auto options = torch::TensorOptions().dtype(dtype).device(device);
+    auto padded_output_waveform =
+        torch::zeros({n_channel, n_sample_padded}, options);
+
+    if (device.is_cpu()) {
+      cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
+    } else {
+      lfilter_core_generic_loop(
+          waveform, a_coeff_flipped, padded_output_waveform);
+    }
+
+    auto output = padded_output_waveform.index(
+        {torch::indexing::Slice(),
+         torch::indexing::Slice(n_order - 1, torch::indexing::None)});
+
+    ctx->save_for_backward({waveform, a_coeffs, output});
+    return output;
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto x = saved[0];
+    auto a_coeffs = saved[1];
+    auto y = saved[2];
+
+    int64_t n_channel = x.size(0);
+    int64_t n_sample = x.size(1);
+    int64_t n_order = a_coeffs.size(0);
+
     auto dx = torch::Tensor();
     auto da = torch::Tensor();
-    auto db = torch::Tensor();
     auto dy = grad_outputs[0];
-    at::AutoNonVariableTypeMode g;
 
     namespace F = torch::nn::functional;
-    auto options = torch::TensorOptions().dtype(dtype).device(device);
 
     if (a_coeffs.requires_grad()) {
-      auto dyda = torch::zeros({n_channel, n_sample_padded}, options);
-      if (device.is_cpu()) {
-        cpu_lfilter_core_loop(-y, a_coeff_flipped, dyda);
-      } else {
-        lfilter_core_generic_loop(-y, a_coeff_flipped, dyda);
-      }
+      auto dyda = F::pad(
+          iir_autograd(-y, a_coeffs), F::PadFuncOptions({n_order - 1, 0}));
 
       da = F::conv1d(
                dyda.unsqueeze(0),
@@ -198,52 +147,50 @@ class DifferentiableLfilter
                .sum(1)
                .squeeze(0)
               .flip(0);
-      da.div_(a_coeffs[0]);
+      da.div_(a_coeffs[0].item());
     }
 
-    if (b_coeffs.requires_grad() || waveform.requires_grad()) {
-      auto dxh = torch::zeros({n_channel, n_sample_padded}, options);
-      if (device.is_cpu()) {
-        cpu_lfilter_core_loop(dy.flip(1), a_coeff_flipped, dxh);
-      } else {
-        lfilter_core_generic_loop(dy.flip(1), a_coeff_flipped, dxh);
-      }
-
-      dxh = dxh.index(
-                {torch::indexing::Slice(),
-                 torch::indexing::Slice(n_order - 1, torch::indexing::None)})
-                .flip(1);
-
-      if (waveform.requires_grad()) {
-        dx = F::conv1d(
-                 F::pad(dxh.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
-                 b_coeffs.view({1, 1, n_order}))
-                 .squeeze(1);
-        dx.div_(a_coeffs[0]);
-      }
-      if (b_coeffs.requires_grad()) {
-        db =
-            F::conv1d(
-                F::pad(
-                    waveform.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
-                dxh.unsqueeze(1),
-                F::Conv1dFuncOptions().groups(n_channel))
-                .sum(1)
-                .squeeze(0)
-                .flip(0);
-        db.div_(a_coeffs[0]);
-      }
+    if (x.requires_grad()) {
+      dx = iir_autograd(dy.flip(1), a_coeffs).flip(1);
     }
 
-    return {dx, da, db};
+    return {dx, da};
   }
 };
 
-torch::Tensor lfilter_autograd(
+torch::Tensor iir_autograd(
+    const torch::Tensor& waveform,
+    const torch::Tensor& a_coeffs) {
+  return DifferentiableIIR::apply(waveform, a_coeffs);
+}
+
+torch::Tensor lfilter_core(
     const torch::Tensor& waveform,
     const torch::Tensor& a_coeffs,
     const torch::Tensor& b_coeffs) {
-  return DifferentiableLfilter::apply(waveform, a_coeffs, b_coeffs);
+  TORCH_CHECK(waveform.device() == a_coeffs.device());
+  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
+  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
+
+  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
+
+  int64_t n_order = b_coeffs.size(0);
+
+  TORCH_INTERNAL_ASSERT(n_order > 0);
+
+  namespace F = torch::nn::functional;
+
+  auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
+  auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
+  b_coeff_flipped.div_(a_coeffs[0].item());
+
+  auto filtered_waveform =
+      F::conv1d(
+          padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
+          .squeeze(1);
+
+  auto output = iir_autograd(filtered_waveform, a_coeffs);
+  return output;
 }
 
 } // namespace
@@ -259,10 +206,6 @@ TORCH_LIBRARY(torchaudio, m) {
       "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, DefaultBackend, m) {
-  m.impl("torchaudio::_lfilter", lfilter_simple);
-}
-
-TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
-  m.impl("torchaudio::_lfilter", lfilter_autograd);
+TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
+  m.impl("torchaudio::_lfilter", lfilter_core);
 }

From 3b17e4000a9266dd345ca331ffe0f6b74a454af9 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Wed, 7 Apr 2021 09:54:27 +0800
Subject: [PATCH 02/11] add comments

---
 torchaudio/csrc/lfilter.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp
index 54aecaafd1..eadc143081 100644
--- a/torchaudio/csrc/lfilter.cpp
+++ b/torchaudio/csrc/lfilter.cpp
@@ -147,6 +147,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
                .sum(1)
                .squeeze(0)
                .flip(0);
+      // use .item() to detach from graph
       da.div_(a_coeffs[0].item());
     }
 
@@ -182,6 +183,8 @@ torch::Tensor lfilter_core(
 
   auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
   auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
+
+  // use .item() to detach a[0] from graph
   b_coeff_flipped.div_(a_coeffs[0].item());

From 2bad0fe8d228b7bc56a951ec033141e4d9ce75b6 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Wed, 7 Apr 2021 15:53:58 +0800
Subject: [PATCH 03/11] feat: normalize a_coeff outside custom function

---
 torchaudio/csrc/lfilter.cpp | 32 ++++++++++++++------------------
 1 file changed, 14 insertions(+), 18 deletions(-)

diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp
index eadc143081..7d2ca03743 100644
--- a/torchaudio/csrc/lfilter.cpp
+++ b/torchaudio/csrc/lfilter.cpp
@@ -87,17 +87,16 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
   static torch::Tensor forward(
       torch::autograd::AutogradContext* ctx,
       const torch::Tensor& waveform,
-      const torch::Tensor& a_coeffs) {
+      const torch::Tensor& a_coeffs_normalized) {
     at::AutoNonVariableTypeMode g;
     auto device = waveform.device();
     auto dtype = waveform.dtype();
     int64_t n_channel = waveform.size(0);
     int64_t n_sample = waveform.size(1);
-    int64_t n_order = a_coeffs.size(0);
+    int64_t n_order = a_coeffs_normalized.size(0);
     int64_t n_sample_padded = n_sample + n_order - 1;
 
-    auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
-    a_coeff_flipped.div_(a_coeffs[0]);
+    auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous();
 
     auto options = torch::TensorOptions().dtype(dtype).device(device);
     auto padded_output_waveform =
@@ -114,7 +113,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
       {torch::indexing::Slice(),
        torch::indexing::Slice(n_order - 1, torch::indexing::None)});
 
-    ctx->save_for_backward({waveform, a_coeffs, output});
+    ctx->save_for_backward({waveform, a_coeffs_normalized, output});
     return output;
   }
 
@@ -123,12 +122,11 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
       torch::autograd::tensor_list grad_outputs) {
     auto saved = ctx->get_saved_variables();
     auto x = saved[0];
-    auto a_coeffs = saved[1];
+    auto a_coeffs_normalized = saved[1];
     auto y = saved[2];
 
     int64_t n_channel = x.size(0);
-    int64_t n_sample = x.size(1);
-    int64_t n_order = a_coeffs.size(0);
+    int64_t n_order = a_coeffs_normalized.size(0);
 
     auto dx = torch::Tensor();
     auto da = torch::Tensor();
@@ -136,9 +134,10 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
 
     namespace F = torch::nn::functional;
 
-    if (a_coeffs.requires_grad()) {
+    if (a_coeffs_normalized.requires_grad()) {
       auto dyda = F::pad(
-          iir_autograd(-y, a_coeffs), F::PadFuncOptions({n_order - 1, 0}));
+          iir_autograd(-y, a_coeffs_normalized),
+          F::PadFuncOptions({n_order - 1, 0}));
 
       da = F::conv1d(
                dyda.unsqueeze(0),
@@ -147,12 +146,10 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
                .sum(1)
                .squeeze(0)
                .flip(0);
-      // use .item() to detach from graph
-      da.div_(a_coeffs[0].item());
     }
 
     if (x.requires_grad()) {
-      dx = iir_autograd(dy.flip(1), a_coeffs).flip(1);
+      dx = iir_autograd(dy.flip(1), a_coeffs_normalized).flip(1);
     }
 
     return {dx, da};
@@ -161,8 +158,8 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
 
 torch::Tensor iir_autograd(
     const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs) {
-  return DifferentiableIIR::apply(waveform, a_coeffs);
+    const torch::Tensor& a_coeffs_normalized) {
+  return DifferentiableIIR::apply(waveform, a_coeffs_normalized);
 }
 
 torch::Tensor lfilter_core(
@@ -184,15 +181,14 @@ torch::Tensor lfilter_core(
 
   auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
   auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
-
-  // use .item() to detach a[0] from graph
-  b_coeff_flipped.div_(a_coeffs[0].item());
+  b_coeff_flipped.div_(a_coeffs[0]);
 
   auto filtered_waveform =
       F::conv1d(
          padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
          .squeeze(1);
 
-  auto output = iir_autograd(filtered_waveform, a_coeffs);
+  auto output = iir_autograd(filtered_waveform, a_coeffs / a_coeffs[0]);
   return output;
 }

From 38d1b4ffba8c08afa2fdee3ace49a352341ee933 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Wed, 7 Apr 2021 22:25:16 +0800
Subject: [PATCH 04/11] test: add gradgradcheck

---
 test/torchaudio_unittest/functional/autograd_impl.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py
index 7d267865b4..5c3209216c 100644
--- a/test/torchaudio_unittest/functional/autograd_impl.py
+++ b/test/torchaudio_unittest/functional/autograd_impl.py
@@ -12,6 +12,7 @@ def test_lfilter_x(self):
         b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
         x.requires_grad = True
         assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
+        assert gradgradcheck(F.lfilter, (x, a, b))
 
     def test_lfilter_a(self):
         torch.random.manual_seed(2434)
@@ -29,6 +30,7 @@ def test_lfilter_b(self):
         b = torch.tensor([0.4, 0.2, 0.9], dtype=self.dtype, device=self.device)
         b.requires_grad = True
         assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
+        assert gradgradcheck(F.lfilter, (x, a, b))
 
     def test_lfilter_all_inputs(self):
         torch.random.manual_seed(2434)
@@ -39,6 +41,7 @@ def test_lfilter_all_inputs(self):
         a.requires_grad = True
         x.requires_grad = True
         assert gradcheck(F.lfilter, (x, a, b), eps=1e-10)
+        assert gradgradcheck(F.lfilter, (x, a, b))
 
     def test_biquad(self):
         torch.random.manual_seed(2434)

From 92fe68fb72f41a899bed529de8a1e1542403d344 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Thu, 8 Apr 2021 10:55:36 +0800
Subject: [PATCH 05/11] refactor: remove iir_autograd

---
 torchaudio/csrc/lfilter.cpp | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp
index 7d2ca03743..4ba41f751c 100644
--- a/torchaudio/csrc/lfilter.cpp
+++ b/torchaudio/csrc/lfilter.cpp
@@ -80,8 +80,6 @@ void lfilter_core_generic_loop(
   }
 }
 
-torch::Tensor iir_autograd(const torch::Tensor&, const torch::Tensor&);
-
 class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
  public:
   static torch::Tensor forward(
@@ -136,7 +134,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
 
     if (a_coeffs_normalized.requires_grad()) {
       auto dyda = F::pad(
-          iir_autograd(-y, a_coeffs_normalized),
+          DifferentiableIIR::apply(-y, a_coeffs_normalized),
           F::PadFuncOptions({n_order - 1, 0}));
 
       da = F::conv1d(
@@ -149,19 +147,13 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
     }
 
     if (x.requires_grad()) {
-      dx = iir_autograd(dy.flip(1), a_coeffs_normalized).flip(1);
+      dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1);
     }
 
     return {dx, da};
   }
 };
 
-torch::Tensor iir_autograd(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs_normalized) {
-  return DifferentiableIIR::apply(waveform, a_coeffs_normalized);
-}
-
 torch::Tensor lfilter_core(
     const torch::Tensor& waveform,
     const torch::Tensor& a_coeffs,
@@ -188,7 +180,8 @@ torch::Tensor lfilter_core(
           padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
           .squeeze(1);
 
-  auto output = iir_autograd(filtered_waveform, a_coeffs / a_coeffs[0]);
+  auto output =
+      DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
   return output;
 }

From 20cc31d575d4e9ac9028ea9d45503115de212eaa Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Thu, 8 Apr 2021 11:30:04 +0800
Subject: [PATCH 06/11] test: add 2nd order gradient check

---
 test/torchaudio_unittest/functional/autograd_impl.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py
index 2d9dca76d6..9f54d0680c 100644
--- a/test/torchaudio_unittest/functional/autograd_impl.py
+++ b/test/torchaudio_unittest/functional/autograd_impl.py
@@ -3,7 +3,7 @@
 from parameterized import parameterized
 from torch import Tensor
 import torchaudio.functional as F
-from torch.autograd import gradcheck
+from torch.autograd import gradcheck, gradgradcheck
 from torchaudio_unittest.common_utils import (
     TestBaseMixin,
     get_whitenoise,
@@ -26,6 +26,7 @@ def assert_grad(
                 i.requires_grad = True
             inputs_.append(i)
         assert gradcheck(transform, inputs_)
+        assert gradgradcheck(transform, inputs_)
 
     def test_lfilter_x(self):
         torch.random.manual_seed(2434)
@@ -43,14 +44,6 @@ def test_lfilter_a(self):
         a.requires_grad = True
         self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
 
-    def test_lfilter_b(self):
-        torch.random.manual_seed(2434)
-        x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
-        a = torch.tensor([0.7, 0.2, 0.6])
-        b = torch.tensor([0.4, 0.2, 0.9])
-        b.requires_grad = True
-        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
-
     def test_lfilter_all_inputs(self):
         torch.random.manual_seed(2434)
         x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)

From 1ebaa027df8995571723e836fffbca39bea55fa0 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Thu, 8 Apr 2021 12:00:01 +0800
Subject: [PATCH 07/11] test: use shorter sequence for faster runtime

---
 .../functional/autograd_impl.py | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py
index 9f54d0680c..4c85d6f3f5 100644
--- a/test/torchaudio_unittest/functional/autograd_impl.py
+++ b/test/torchaudio_unittest/functional/autograd_impl.py
@@ -38,7 +38,7 @@ def test_lfilter_x(self):
 
     def test_lfilter_a(self):
         torch.random.manual_seed(2434)
-        x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=2)
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
         a.requires_grad = True
@@ -46,14 +46,14 @@ def test_lfilter_a(self):
 
     def test_lfilter_all_inputs(self):
         torch.random.manual_seed(2434)
-        x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=2)
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
         self.assert_grad(F.lfilter, (x, a, b))
 
     def test_biquad(self):
         torch.random.manual_seed(2434)
-        x = get_whitenoise(sample_rate=22050, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=1)
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
         self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))
@@ -65,7 +65,7 @@ def test_biquad(self):
     def test_band_biquad(self, central_freq, Q, noise):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))
@@ -77,7 +77,7 @@ def test_band_biquad(self, central_freq, Q, noise):
     def test_bass_biquad(self, central_freq, Q, gain):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         gain = torch.tensor(gain)
@@ -91,7 +91,7 @@ def test_bass_biquad(self, central_freq, Q, gain):
     def test_treble_biquad(self, central_freq, Q, gain):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         gain = torch.tensor(gain)
@@ -103,7 +103,7 @@ def test_treble_biquad(self, central_freq, Q, gain):
     def test_allpass_biquad(self, central_freq, Q):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))
@@ -114,7 +114,7 @@ def test_allpass_biquad(self, central_freq, Q):
     def test_lowpass_biquad(self, cutoff_freq, Q):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
         cutoff_freq = torch.tensor(cutoff_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
@@ -125,7 +125,7 @@ def test_lowpass_biquad(self, cutoff_freq, Q):
     def test_highpass_biquad(self, cutoff_freq, Q):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
         cutoff_freq = torch.tensor(cutoff_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))
@@ -137,7 +137,7 @@ def test_highpass_biquad(self, cutoff_freq, Q):
     def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))
@@ -149,7 +149,7 @@ def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
     def test_equalizer_biquad(self, central_freq, Q, gain):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         gain = torch.tensor(gain)
@@ -161,7 +161,7 @@ def test_equalizer_biquad(self, central_freq, Q, gain):
     def test_bandreject_biquad(self, central_freq, Q):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
+        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))

From bf2df526377faa39e115f84ea9e4eb4c5525b5a0 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Thu, 8 Apr 2021 20:22:41 +0800
Subject: [PATCH 08/11] test: use even shorter inputs

---
 .../functional/autograd_impl.py | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py
index 4c85d6f3f5..d2c0684a64 100644
--- a/test/torchaudio_unittest/functional/autograd_impl.py
+++ b/test/torchaudio_unittest/functional/autograd_impl.py
@@ -30,7 +30,7 @@ def assert_grad(
 
     def test_lfilter_x(self):
         torch.random.manual_seed(2434)
-        x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=2)
+        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
         x.requires_grad = True
@@ -38,7 +38,7 @@ def test_lfilter_x(self):
 
     def test_lfilter_a(self):
         torch.random.manual_seed(2434)
-        x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=2)
+        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
         a.requires_grad = True
@@ -46,14 +46,14 @@ def test_lfilter_a(self):
 
     def test_lfilter_all_inputs(self):
         torch.random.manual_seed(2434)
-        x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=2)
+        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
         self.assert_grad(F.lfilter, (x, a, b))
 
     def test_biquad(self):
         torch.random.manual_seed(2434)
-        x = get_whitenoise(sample_rate=22050, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=1)
         a = torch.tensor([0.7, 0.2, 0.6])
         b = torch.tensor([0.4, 0.2, 0.9])
         self.assert_grad(F.biquad, (x, b[0], b[1], b[2], a[0], a[1], a[2]))
@@ -65,7 +65,7 @@ def test_biquad(self):
     def test_band_biquad(self, central_freq, Q, noise):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.band_biquad, (x, sr, central_freq, Q, noise))
@@ -77,7 +77,7 @@ def test_band_biquad(self, central_freq, Q, noise):
     def test_bass_biquad(self, central_freq, Q, gain):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         gain = torch.tensor(gain)
@@ -91,7 +91,7 @@ def test_bass_biquad(self, central_freq, Q, gain):
     def test_treble_biquad(self, central_freq, Q, gain):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         gain = torch.tensor(gain)
@@ -103,7 +103,7 @@ def test_treble_biquad(self, central_freq, Q, gain):
     def test_allpass_biquad(self, central_freq, Q):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.allpass_biquad, (x, sr, central_freq, Q))
@@ -114,7 +114,7 @@ def test_allpass_biquad(self, central_freq, Q):
     def test_lowpass_biquad(self, cutoff_freq, Q):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         cutoff_freq = torch.tensor(cutoff_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.lowpass_biquad, (x, sr, cutoff_freq, Q))
@@ -125,7 +125,7 @@ def test_lowpass_biquad(self, cutoff_freq, Q):
     def test_highpass_biquad(self, cutoff_freq, Q):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         cutoff_freq = torch.tensor(cutoff_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.highpass_biquad, (x, sr, cutoff_freq, Q))
@@ -137,7 +137,7 @@ def test_highpass_biquad(self, cutoff_freq, Q):
     def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.bandpass_biquad, (x, sr, central_freq, Q, const_skirt_gain))
@@ -149,7 +149,7 @@ def test_bandpass_biquad(self, central_freq, Q, const_skirt_gain):
     def test_equalizer_biquad(self, central_freq, Q, gain):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         gain = torch.tensor(gain)
@@ -161,7 +161,7 @@ def test_equalizer_biquad(self, central_freq, Q, gain):
     def test_bandreject_biquad(self, central_freq, Q):
         torch.random.manual_seed(2434)
         sr = 22050
-        x = get_whitenoise(sample_rate=sr, duration=0.025, n_channels=1)
+        x = get_whitenoise(sample_rate=sr, duration=0.01, n_channels=1)
         central_freq = torch.tensor(central_freq)
         Q = torch.tensor(Q)
         self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))

From bb1a7928f35e694dba102a0cd8c462bfeebcb194 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Mon, 3 May 2021 09:49:07 +0800
Subject: [PATCH 09/11] feat: add custom FIR for performance reasons

---
 torchaudio/csrc/lfilter.cpp | 72 +++++++++++++++++++++++++++++++------
 1 file changed, 62 insertions(+), 10 deletions(-)

diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp
index 4ba41f751c..740f9210df 100644
--- a/torchaudio/csrc/lfilter.cpp
+++ b/torchaudio/csrc/lfilter.cpp
@@ -154,6 +154,67 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
   }
 };
 
+class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
+ public:
+  static torch::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::Tensor& waveform,
+      const torch::Tensor& b_coeffs) {
+    at::AutoNonVariableTypeMode g;
+    int64_t n_order = b_coeffs.size(0);
+
+    namespace F = torch::nn::functional;
+    auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
+    auto padded_waveform =
+        F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
+
+    auto output =
+        F::conv1d(
+            padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
+            .squeeze(1);
+
+    ctx->save_for_backward({waveform, b_coeffs, output});
+    return output;
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto x = saved[0];
+    auto b_coeffs = saved[1];
+    auto y = saved[2];
+
+    int64_t n_channel = x.size(0);
+    int64_t n_order = b_coeffs.size(0);
+
+    auto dx = torch::Tensor();
+    auto db = torch::Tensor();
+    auto dy = grad_outputs[0];
+
+    namespace F = torch::nn::functional;
+
+    if (b_coeffs.requires_grad()) {
+      db = F::conv1d(
+               F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
+               dy.unsqueeze(1),
+               F::Conv1dFuncOptions().groups(n_channel))
+               .sum(1)
+               .squeeze(0)
+               .flip(0);
+    }
+
+    if (x.requires_grad()) {
+      dx = F::conv1d(
+               F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
+               b_coeffs.view({1, 1, n_order}))
+               .squeeze(1);
+    }
+
+    return {dx, db};
+  }
+};
+
 torch::Tensor lfilter_core(
     const torch::Tensor& waveform,
     const torch::Tensor& a_coeffs,
@@ -168,17 +229,8 @@ torch::Tensor lfilter_core(
 
   TORCH_INTERNAL_ASSERT(n_order > 0);
 
-  namespace F = torch::nn::functional;
-
-  auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
-  auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
-  b_coeff_flipped.div_(a_coeffs[0]);
-
-  auto filtered_waveform =
-      F::conv1d(
-          padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
-          .squeeze(1);
+  auto filtered_waveform =
+      DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]);
 
   auto output =
       DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
   return output;

From e490090d064279d92379f14c07b2494b7858be59 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Tue, 4 May 2021 09:43:48 +0800
Subject: [PATCH 10/11] remove unnecessary autograd guard

---
 torchaudio/csrc/lfilter.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp
index 740f9210df..dcc5dbe442 100644
--- a/torchaudio/csrc/lfilter.cpp
+++ b/torchaudio/csrc/lfilter.cpp
@@ -86,7 +86,6 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
       torch::autograd::AutogradContext* ctx,
       const torch::Tensor& waveform,
       const torch::Tensor& a_coeffs_normalized) {
-    at::AutoNonVariableTypeMode g;
     auto device = waveform.device();
     auto dtype = waveform.dtype();
     int64_t n_channel = waveform.size(0);
@@ -160,7 +159,6 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& waveform,
      const torch::Tensor& b_coeffs) {
-    at::AutoNonVariableTypeMode g;
     int64_t n_order = b_coeffs.size(0);
 
     namespace F = torch::nn::functional;

From 1bc3f45541a4a7632b5e70e3e98ab7591e923a19 Mon Sep 17 00:00:00 2001
From: Chin Yun Yu
Date: Wed, 5 May 2021 10:57:07 +0800
Subject: [PATCH 11/11] test: bring back b_coeff test

---
 test/torchaudio_unittest/functional/autograd_impl.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py
index d2c0684a64..049ae25c90 100644
--- a/test/torchaudio_unittest/functional/autograd_impl.py
+++ b/test/torchaudio_unittest/functional/autograd_impl.py
@@ -44,6 +44,14 @@ def test_lfilter_a(self):
         a.requires_grad = True
         self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
 
+    def test_lfilter_b(self):
+        torch.random.manual_seed(2434)
+        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
+        a = torch.tensor([0.7, 0.2, 0.6])
+        b = torch.tensor([0.4, 0.2, 0.9])
+        b.requires_grad = True
+        self.assert_grad(F.lfilter, (x, a, b), enable_all_grad=False)
+
    def test_lfilter_all_inputs(self):
        torch.random.manual_seed(2434)
        x = get_whitenoise(sample_rate=22050, duration=0.01, n_channels=2)
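
The scheme the series converges on can be summarized in a short pure-Python sketch. This is illustrative only, not torchaudio's shipped code or API: iir_loop below is a hypothetical stand-in for the cpu_lfilter_core_loop / lfilter_core_generic_loop kernels, and the class mirrors the C++ DifferentiableIIR as it stands after patches 03, 05, and 10. The essential design choice is that backward is itself built from differentiable operations, including recursive calls to DifferentiableIIR.apply, so autograd can differentiate the backward pass and gradgradcheck can pass.

import torch
import torch.nn.functional as F


def iir_loop(x, a_coeffs_normalized):
    # Hypothetical stand-in for the C++ recursion kernels. Computes
    # y[:, t] = x[:, t] - sum_{k >= 1} a[k] * y[:, t - k], assuming a[0] == 1.
    n_channel, n_sample = x.shape
    n_order = a_coeffs_normalized.size(0)
    y = torch.zeros(n_channel, n_sample + n_order - 1,
                    dtype=x.dtype, device=x.device)
    a_tail_flipped = a_coeffs_normalized.flip(0)[:-1]  # [a[K-1], ..., a[1]]
    for t in range(n_sample):
        y[:, t + n_order - 1] = x[:, t] - y[:, t:t + n_order - 1] @ a_tail_flipped
    return y[:, n_order - 1:]


class DifferentiableIIR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, waveform, a_coeffs_normalized):
        with torch.no_grad():
            output = iir_loop(waveform, a_coeffs_normalized)
        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
        return output

    @staticmethod
    def backward(ctx, dy):
        x, a_coeffs_normalized, y = ctx.saved_tensors
        n_order = a_coeffs_normalized.size(0)
        dx = da = None
        if ctx.needs_input_grad[1]:
            # Differentiating the recursion A(q) y = x w.r.t. a[k] gives
            # A(q) dy/da[k] = -y[t - k], so dy/da[k] is IIR(-y) delayed by k;
            # the grouped conv1d correlates IIR(-y) against the incoming grad.
            dyda = F.pad(DifferentiableIIR.apply(-y, a_coeffs_normalized),
                         (n_order - 1, 0))
            da = (F.conv1d(dyda.unsqueeze(0), dy.unsqueeze(1), groups=x.size(0))
                  .sum(1).squeeze(0).flip(0))
        if ctx.needs_input_grad[0]:
            # The filter is linear in x; its adjoint is the same IIR run
            # backwards in time.
            dx = DifferentiableIIR.apply(dy.flip(1), a_coeffs_normalized).flip(1)
        return dx, da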
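
Patch 09 splits the numerator out into its own custom function for speed. Continuing the same sketch (again an assumption-laden illustration, not the shipped code; the local lfilter function is named for symmetry with _lfilter), the FIR part and the final composition look like this. Note that the a[0] normalization stays outside both Functions, so the gradient with respect to a_coeffs[0] flows through ordinary autograd ops; that is why patch 03 could drop the ".item() to detach" workaround added in patch 02.

class DifferentiableFIR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, waveform, b_coeffs):
        n_order = b_coeffs.size(0)
        with torch.no_grad():
            # Causal convolution: left-pad by K - 1 and correlate with flipped b.
            output = F.conv1d(F.pad(waveform, (n_order - 1, 0)).unsqueeze(1),
                              b_coeffs.flip(0).view(1, 1, n_order)).squeeze(1)
        ctx.save_for_backward(waveform, b_coeffs)
        return output

    @staticmethod
    def backward(ctx, dy):
        x, b_coeffs = ctx.saved_tensors
        n_order = b_coeffs.size(0)
        dx = db = None
        if ctx.needs_input_grad[0]:
            # Adjoint of a causal convolution with b: an anti-causal
            # correlation with the same (unflipped) b.
            dx = F.conv1d(F.pad(dy, (0, n_order - 1)).unsqueeze(1),
                          b_coeffs.view(1, 1, n_order)).squeeze(1)
        if ctx.needs_input_grad[1]:
            # db[k] = sum_t dy[t] * x[t - k]: one grouped correlation per channel.
            db = (F.conv1d(F.pad(x, (n_order - 1, 0)).unsqueeze(0),
                           dy.unsqueeze(1), groups=x.size(0))
                  .sum(1).squeeze(0).flip(0))
        return dx, db


def lfilter(waveform, a_coeffs, b_coeffs):
    # Mirrors the final lfilter_core: FIR with b / a[0], then IIR with a / a[0].
    filtered = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[0])
    return DifferentiableIIR.apply(filtered, a_coeffs / a_coeffs[0])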
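
Under those assumptions, the sketch can be exercised with the same first- and second-order checks the series adds to the test suite; as in patches 07 and 08, double precision and short inputs keep gradgradcheck affordable, since each extra order of differentiation multiplies the number of filter passes.

x = torch.randn(2, 44, dtype=torch.double, requires_grad=True)
a = torch.tensor([0.7, 0.2, 0.6], dtype=torch.double, requires_grad=True)
b = torch.tensor([0.4, 0.2, 0.9], dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(lfilter, (x, a, b))
assert torch.autograd.gradgradcheck(lfilter, (x, a, b))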