diff --git a/test/torchaudio_unittest/functional/autograd_impl.py b/test/torchaudio_unittest/functional/autograd_impl.py
index e273369ea6..049ae25c90 100644
--- a/test/torchaudio_unittest/functional/autograd_impl.py
+++ b/test/torchaudio_unittest/functional/autograd_impl.py
@@ -3,7 +3,7 @@
 from parameterized import parameterized
 from torch import Tensor
 import torchaudio.functional as F
-from torch.autograd import gradcheck
+from torch.autograd import gradcheck, gradgradcheck
 from torchaudio_unittest.common_utils import (
     TestBaseMixin,
     get_whitenoise,
@@ -26,6 +26,7 @@ def assert_grad(
                 i.requires_grad = True
             inputs_.append(i)
         assert gradcheck(transform, inputs_)
+        assert gradgradcheck(transform, inputs_)
 
     def test_lfilter_x(self):
         torch.random.manual_seed(2434)
diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp
index 56540c33f6..dcc5dbe442 100644
--- a/torchaudio/csrc/lfilter.cpp
+++ b/torchaudio/csrc/lfilter.cpp
@@ -80,170 +80,159 @@ void lfilter_core_generic_loop(
   }
 }
 
-std::tuple<torch::Tensor, torch::Tensor> lfilter_core(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs,
-    const torch::Tensor& b_coeffs) {
-  TORCH_CHECK(waveform.device() == a_coeffs.device());
-  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
-  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
+class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
+ public:
+  static torch::Tensor forward(
+      torch::autograd::AutogradContext* ctx,
+      const torch::Tensor& waveform,
+      const torch::Tensor& a_coeffs_normalized) {
+    auto device = waveform.device();
+    auto dtype = waveform.dtype();
+    int64_t n_channel = waveform.size(0);
+    int64_t n_sample = waveform.size(1);
+    int64_t n_order = a_coeffs_normalized.size(0);
+    int64_t n_sample_padded = n_sample + n_order - 1;
 
-  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
+    auto a_coeff_flipped = a_coeffs_normalized.flip(0).contiguous();
 
-  auto device = waveform.device();
-  int64_t n_order = a_coeffs.size(0);
+    auto options = torch::TensorOptions().dtype(dtype).device(device);
+    auto padded_output_waveform =
+        torch::zeros({n_channel, n_sample_padded}, options);
+
+    if (device.is_cpu()) {
+      cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
+    } else {
+      lfilter_core_generic_loop(
+          waveform, a_coeff_flipped, padded_output_waveform);
+    }
 
-  TORCH_INTERNAL_ASSERT(n_order > 0);
+    auto output = padded_output_waveform.index(
+        {torch::indexing::Slice(),
+         torch::indexing::Slice(n_order - 1, torch::indexing::None)});
 
-  namespace F = torch::nn::functional;
+    ctx->save_for_backward({waveform, a_coeffs_normalized, output});
+    return output;
+  }
 
-  auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
-  auto padded_output_waveform = torch::zeros_like(padded_waveform);
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto x = saved[0];
+    auto a_coeffs_normalized = saved[1];
+    auto y = saved[2];
 
-  auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
-  auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
+    int64_t n_channel = x.size(0);
+    int64_t n_order = a_coeffs_normalized.size(0);
 
-  auto input_signal_windows =
-      F::conv1d(
-          padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
-          .squeeze(1);
+    auto dx = torch::Tensor();
+    auto da = torch::Tensor();
+    auto dy = grad_outputs[0];
 
-  input_signal_windows.div_(a_coeffs[0]);
-  a_coeff_flipped.div_(a_coeffs[0]);
+    namespace F = torch::nn::functional;
 
-  if (device.is_cpu()) {
-    cpu_lfilter_core_loop(
-        input_signal_windows, a_coeff_flipped, padded_output_waveform);
-  } else {
-    lfilter_core_generic_loop(
-        input_signal_windows, a_coeff_flipped, padded_output_waveform);
-  }
+    if (a_coeffs_normalized.requires_grad()) {
+      auto dyda = F::pad(
+          DifferentiableIIR::apply(-y, a_coeffs_normalized),
+          F::PadFuncOptions({n_order - 1, 0}));
 
-  auto output = padded_output_waveform.index(
-      {torch::indexing::Slice(),
-       torch::indexing::Slice(n_order - 1, torch::indexing::None)});
+      da = F::conv1d(
+               dyda.unsqueeze(0),
+               dy.unsqueeze(1),
+               F::Conv1dFuncOptions().groups(n_channel))
+               .sum(1)
+               .squeeze(0)
+               .flip(0);
+    }
 
-  return {output, input_signal_windows};
-}
+    if (x.requires_grad()) {
+      dx = DifferentiableIIR::apply(dy.flip(1), a_coeffs_normalized).flip(1);
+    }
 
-torch::Tensor lfilter_simple(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs,
-    const torch::Tensor& b_coeffs) {
-  return std::get<0>(lfilter_core(waveform, a_coeffs, b_coeffs));
-}
+    return {dx, da};
+  }
+};
 
-class DifferentiableLfilter
-    : public torch::autograd::Function<DifferentiableLfilter> {
+class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
  public:
   static torch::Tensor forward(
       torch::autograd::AutogradContext* ctx,
       const torch::Tensor& waveform,
-      const torch::Tensor& a_coeffs,
       const torch::Tensor& b_coeffs) {
-    at::AutoNonVariableTypeMode g;
-    auto result = lfilter_core(waveform, a_coeffs, b_coeffs);
-    ctx->save_for_backward(
-        {waveform,
-         a_coeffs,
-         b_coeffs,
-         std::get<0>(result),
-         std::get<1>(result)});
-    return std::get<0>(result);
+    int64_t n_order = b_coeffs.size(0);
+
+    namespace F = torch::nn::functional;
+    auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
+    auto padded_waveform =
+        F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
+
+    auto output =
+        F::conv1d(
+            padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order}))
+            .squeeze(1);
+
+    ctx->save_for_backward({waveform, b_coeffs, output});
+    return output;
   }
 
   static torch::autograd::tensor_list backward(
       torch::autograd::AutogradContext* ctx,
       torch::autograd::tensor_list grad_outputs) {
     auto saved = ctx->get_saved_variables();
-    auto waveform = saved[0];
-    auto a_coeffs = saved[1];
-    auto b_coeffs = saved[2];
-    auto y = saved[3];
-    auto xh = saved[4];
-
-    auto device = waveform.device();
-    auto dtype = waveform.dtype();
-    int64_t n_channel = waveform.size(0);
-    int64_t n_sample = waveform.size(1);
-    int64_t n_order = a_coeffs.size(0);
-    int64_t n_sample_padded = n_sample + n_order - 1;
+    auto x = saved[0];
+    auto b_coeffs = saved[1];
+    auto y = saved[2];
 
-    auto a_coeff_flipped = a_coeffs.flip(0).contiguous();
-    auto b_coeff_flipped = b_coeffs.flip(0).contiguous();
-    b_coeff_flipped.div_(a_coeffs[0]);
-    a_coeff_flipped.div_(a_coeffs[0]);
+    int64_t n_channel = x.size(0);
+    int64_t n_order = b_coeffs.size(0);
 
     auto dx = torch::Tensor();
-    auto da = torch::Tensor();
     auto db = torch::Tensor();
     auto dy = grad_outputs[0];
 
-    at::AutoNonVariableTypeMode g;
     namespace F = torch::nn::functional;
-    auto options = torch::TensorOptions().dtype(dtype).device(device);
 
-    if (a_coeffs.requires_grad()) {
-      auto dyda = torch::zeros({n_channel, n_sample_padded}, options);
-      if (device.is_cpu()) {
-        cpu_lfilter_core_loop(-y, a_coeff_flipped, dyda);
-      } else {
-        lfilter_core_generic_loop(-y, a_coeff_flipped, dyda);
-      }
-
-      da = F::conv1d(
-               dyda.unsqueeze(0),
+    if (b_coeffs.requires_grad()) {
+      db = F::conv1d(
+               F::pad(x.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
                dy.unsqueeze(1),
                F::Conv1dFuncOptions().groups(n_channel))
                .sum(1)
                .squeeze(0)
                .flip(0);
-      da.div_(a_coeffs[0]);
     }
 
-    if (b_coeffs.requires_grad() || waveform.requires_grad()) {
-      auto dxh = torch::zeros({n_channel, n_sample_padded}, options);
-      if (device.is_cpu()) {
-        cpu_lfilter_core_loop(dy.flip(1), a_coeff_flipped, dxh);
-      } else {
-        lfilter_core_generic_loop(dy.flip(1), a_coeff_flipped, dxh);
-      }
-
-      dxh = dxh.index(
-                {torch::indexing::Slice(),
-                 torch::indexing::Slice(n_order - 1, torch::indexing::None)})
-                .flip(1);
-
-      if (waveform.requires_grad()) {
-        dx = F::conv1d(
-                 F::pad(dxh.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
-                 b_coeffs.view({1, 1, n_order}))
-                 .squeeze(1);
-        dx.div_(a_coeffs[0]);
-      }
-      if (b_coeffs.requires_grad()) {
-        db =
-            F::conv1d(
-                F::pad(
-                    waveform.unsqueeze(0), F::PadFuncOptions({n_order - 1, 0})),
-                dxh.unsqueeze(1),
-                F::Conv1dFuncOptions().groups(n_channel))
-                .sum(1)
-                .squeeze(0)
-                .flip(0);
-        db.div_(a_coeffs[0]);
-      }
+    if (x.requires_grad()) {
+      dx = F::conv1d(
+               F::pad(dy.unsqueeze(1), F::PadFuncOptions({0, n_order - 1})),
+               b_coeffs.view({1, 1, n_order}))
+               .squeeze(1);
     }
 
-    return {dx, da, db};
+    return {dx, db};
   }
 };
 
-torch::Tensor lfilter_autograd(
+torch::Tensor lfilter_core(
     const torch::Tensor& waveform,
     const torch::Tensor& a_coeffs,
     const torch::Tensor& b_coeffs) {
-  return DifferentiableLfilter::apply(waveform, a_coeffs, b_coeffs);
+  TORCH_CHECK(waveform.device() == a_coeffs.device());
+  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
+  TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0));
+
+  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2);
+
+  int64_t n_order = b_coeffs.size(0);
+
+  TORCH_INTERNAL_ASSERT(n_order > 0);
+
+  auto filtered_waveform =
+      DifferentiableFIR::apply(waveform, b_coeffs / a_coeffs[0]);
+
+  auto output =
+      DifferentiableIIR::apply(filtered_waveform, a_coeffs / a_coeffs[0]);
+  return output;
 }
 
 } // namespace
@@ -259,10 +248,6 @@ TORCH_LIBRARY(torchaudio, m) {
       "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, DefaultBackend, m) {
-  m.impl("torchaudio::_lfilter", lfilter_simple);
-}
-
-TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
-  m.impl("torchaudio::_lfilter", lfilter_autograd);
+TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
+  m.impl("torchaudio::_lfilter", lfilter_core);
 }