From 04c0e21989946fa2ba1706030dc1ca60c6e14f16 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Fri, 26 Feb 2021 11:57:14 +0800 Subject: [PATCH 1/7] add cpp implementation of lfilter --- torchaudio/csrc/lfilter.cpp | 81 +++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp index af9425fd15..fd0d26d3d7 100644 --- a/torchaudio/csrc/lfilter.cpp +++ b/torchaudio/csrc/lfilter.cpp @@ -1,4 +1,5 @@ #include +#include namespace { @@ -62,6 +63,77 @@ void cpu_lfilter_core_loop( }); } +void lfilter_core_generic_loop( + const torch::Tensor& input_signal_windows, + const torch::Tensor& a_coeff_flipped, + torch::Tensor& padded_output_waveform) { + int64_t n_samples_input = input_signal_windows.size(1); + int64_t n_order = a_coeff_flipped.size(0); + for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { + auto windowed_output_signal = padded_output_waveform.index( + {torch::indexing::Slice(), + torch::indexing::Slice(i_sample, i_sample + n_order)}); + auto o0 = input_signal_windows.index({torch::indexing::Slice(), i_sample}) + .addmv(windowed_output_signal, a_coeff_flipped, 1, -1); + padded_output_waveform.index_put_( + {torch::indexing::Slice(), i_sample + n_order - 1}, o0); + } +} + +torch::Tensor lfilter_core( + const torch::Tensor& raw_waveform, + const torch::Tensor& a_coeffs, + const torch::Tensor& b_coeffs) { + TORCH_CHECK(raw_waveform.device() == a_coeffs.device()); + TORCH_CHECK(b_coeffs.device() == a_coeffs.device()); + TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0)); + + torch::Tensor waveform = raw_waveform.contiguous(); + + auto shape = waveform.sizes(); + waveform = waveform.view({-1, waveform.size(-1)}); + + TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2); + + auto device = waveform.device(); + int64_t n_order = a_coeffs.size(0); + + TORCH_INTERNAL_ASSERT(n_order > 0); + + namespace F = torch::nn::functional; + + auto padded_waveform = F::pad(waveform, F::PadFuncOptions({n_order - 1, 0})); + auto padded_output_waveform = torch::zeros_like(padded_waveform); + + auto a_coeff_flipped = a_coeffs.flip(0).contiguous(); + auto b_coeff_flipped = b_coeffs.flip(0).contiguous(); + + auto input_signal_windows = + F::conv1d( + padded_waveform.unsqueeze(1), b_coeff_flipped.view({1, 1, n_order})) + .squeeze(1); + + input_signal_windows.div_(a_coeffs[0]); + a_coeff_flipped.div_(a_coeffs[0]); + + if (device.is_cpu()) { + cpu_lfilter_core_loop( + input_signal_windows, a_coeff_flipped, padded_output_waveform); + } else { + lfilter_core_generic_loop( + input_signal_windows, a_coeff_flipped, padded_output_waveform); + } + + auto output = + padded_output_waveform + .index( + {torch::indexing::Slice(), + torch::indexing::Slice(n_order - 1, torch::indexing::None)}) + .view(shape); + + return output; +} + } // namespace // Note: We want to avoid using "catch-all" kernel. @@ -69,3 +141,12 @@ void cpu_lfilter_core_loop( TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop); } + +TORCH_LIBRARY(torchaudio, m) { + m.def( + "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(torchaudio, Math, m) { + m.impl("torchaudio::_lfilter", lfilter_core); +} \ No newline at end of file From 9d3b0915117815ef5a59356cf887a623b947fc9e Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Fri, 26 Feb 2021 12:09:30 +0800 Subject: [PATCH 2/7] move lfilter forward part into cpp backend --- torchaudio/csrc/lfilter.cpp | 13 +++---- torchaudio/functional/filtering.py | 57 +++++++++++++++++++----------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp index fd0d26d3d7..813a69a146 100644 --- a/torchaudio/csrc/lfilter.cpp +++ b/torchaudio/csrc/lfilter.cpp @@ -124,14 +124,11 @@ torch::Tensor lfilter_core( input_signal_windows, a_coeff_flipped, padded_output_waveform); } - auto output = - padded_output_waveform - .index( - {torch::indexing::Slice(), - torch::indexing::Slice(n_order - 1, torch::indexing::None)}) - .view(shape); - - return output; + auto output = padded_output_waveform.index( + {torch::indexing::Slice(), + torch::indexing::Slice(n_order - 1, torch::indexing::None)}); + + return output.view(shape); } } // namespace diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index a9c359d6ed..ad2976ac13 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -825,28 +825,11 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T _lfilter_core_cpu_loop = _lfilter_core_generic_loop -def lfilter( +def _lfilter_core( waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, - clamp: bool = True, ) -> Tensor: - r"""Perform an IIR filter by evaluating difference equation. - - Args: - waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1. - a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``. - Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``. - Must be same size as b_coeffs (pad with 0's as necessary). - b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``. - Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. - Must be same size as a_coeffs (pad with 0's as necessary). - clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) - - Returns: - Tensor: Waveform with dimension of ``(..., time)``. - """ - # pack batch shape = waveform.size() waveform = waveform.reshape(-1, shape[-1]) @@ -887,11 +870,43 @@ def lfilter( output = padded_output_waveform[:, n_order - 1:] - if clamp: - output = torch.clamp(output, min=-1.0, max=1.0) - # unpack batch output = output.reshape(shape[:-1] + output.shape[-1:]) + return output + +try: + _lfilter = torch.ops.torchaudio._lfilter +except RuntimeError as err: + assert str(err) == 'No such operator torchaudio::_lfilter' + _lfilter = _lfilter_core + + +def lfilter( + waveform: Tensor, + a_coeffs: Tensor, + b_coeffs: Tensor, + clamp: bool = True, +) -> Tensor: + r"""Perform an IIR filter by evaluating difference equation. + + Args: + waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1. + a_coeffs (Tensor): denominator coefficients of difference equation of dimension of ``(n_order + 1)``. + Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``. + Must be same size as b_coeffs (pad with 0's as necessary). + b_coeffs (Tensor): numerator coefficients of difference equation of dimension of ``(n_order + 1)``. + Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``. + Must be same size as a_coeffs (pad with 0's as necessary). + clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``) + + Returns: + Tensor: Waveform with dimension of ``(..., time)``. + """ + # pack batch + output = _lfilter(waveform, a_coeffs, b_coeffs) + + if clamp: + output = torch.clamp(output, min=-1.0, max=1.0) return output From 4493961daf7e0ec73b02c658dd4eca4f594817d2 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Fri, 26 Feb 2021 12:12:12 +0800 Subject: [PATCH 3/7] add new line in the end --- torchaudio/csrc/lfilter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp index 813a69a146..c2c3e6e199 100644 --- a/torchaudio/csrc/lfilter.cpp +++ b/torchaudio/csrc/lfilter.cpp @@ -146,4 +146,4 @@ TORCH_LIBRARY(torchaudio, m) { TORCH_LIBRARY_IMPL(torchaudio, Math, m) { m.impl("torchaudio::_lfilter", lfilter_core); -} \ No newline at end of file +} From 8aced5c18a4729b9f006559d72e41204b5db892f Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Tue, 2 Mar 2021 11:56:13 +0800 Subject: [PATCH 4/7] use reshape instead of view --- torchaudio/csrc/lfilter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp index c2c3e6e199..2e17c7cb74 100644 --- a/torchaudio/csrc/lfilter.cpp +++ b/torchaudio/csrc/lfilter.cpp @@ -128,7 +128,7 @@ torch::Tensor lfilter_core( {torch::indexing::Slice(), torch::indexing::Slice(n_order - 1, torch::indexing::None)}); - return output.view(shape); + return output.reshape(shape); } } // namespace From fa460e06f4eea17e5c3a03224d0675482001a515 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Wed, 3 Mar 2021 10:49:36 +0800 Subject: [PATCH 5/7] move packing and unpacking batch to outer function --- torchaudio/csrc/lfilter.cpp | 11 +++-------- torchaudio/functional/filtering.py | 11 ++++++----- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/torchaudio/csrc/lfilter.cpp b/torchaudio/csrc/lfilter.cpp index 2e17c7cb74..c7ce6f9538 100644 --- a/torchaudio/csrc/lfilter.cpp +++ b/torchaudio/csrc/lfilter.cpp @@ -81,18 +81,13 @@ void lfilter_core_generic_loop( } torch::Tensor lfilter_core( - const torch::Tensor& raw_waveform, + const torch::Tensor& waveform, const torch::Tensor& a_coeffs, const torch::Tensor& b_coeffs) { - TORCH_CHECK(raw_waveform.device() == a_coeffs.device()); + TORCH_CHECK(waveform.device() == a_coeffs.device()); TORCH_CHECK(b_coeffs.device() == a_coeffs.device()); TORCH_CHECK(a_coeffs.size(0) == b_coeffs.size(0)); - torch::Tensor waveform = raw_waveform.contiguous(); - - auto shape = waveform.sizes(); - waveform = waveform.view({-1, waveform.size(-1)}); - TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 2); auto device = waveform.device(); @@ -128,7 +123,7 @@ torch::Tensor lfilter_core( {torch::indexing::Slice(), torch::indexing::Slice(n_order - 1, torch::indexing::None)}); - return output.reshape(shape); + return output; } } // namespace diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index ad2976ac13..a5c6c8d81c 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -830,8 +830,6 @@ def _lfilter_core( a_coeffs: Tensor, b_coeffs: Tensor, ) -> Tensor: - shape = waveform.size() - waveform = waveform.reshape(-1, shape[-1]) assert a_coeffs.size(0) == b_coeffs.size(0) assert len(waveform.size()) == 2 @@ -869,9 +867,6 @@ def _lfilter_core( _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform) output = padded_output_waveform[:, n_order - 1:] - - # unpack batch - output = output.reshape(shape[:-1] + output.shape[-1:]) return output try: @@ -903,11 +898,17 @@ def lfilter( Tensor: Waveform with dimension of ``(..., time)``. """ # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + output = _lfilter(waveform, a_coeffs, b_coeffs) if clamp: output = torch.clamp(output, min=-1.0, max=1.0) + # unpack batch + output = output.reshape(*shape) + return output From e6a80d7e3af33c11a75d3b5a9e3a1692d740e722 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Wed, 3 Mar 2021 11:07:34 +0800 Subject: [PATCH 6/7] explicitly declare shape to support jit --- torchaudio/functional/filtering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index a5c6c8d81c..358eaf1f57 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -905,9 +905,9 @@ def lfilter( if clamp: output = torch.clamp(output, min=-1.0, max=1.0) - + # unpack batch - output = output.reshape(*shape) + output = output.reshape(shape[:-1] + output.shape[-1:]) return output From e68baa8a8493a635ba4b0bdc9062dedd1bcda851 Mon Sep 17 00:00:00 2001 From: Chin Yun Yu Date: Wed, 3 Mar 2021 11:32:54 +0800 Subject: [PATCH 7/7] remove whitespace --- torchaudio/functional/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index 358eaf1f57..03efe7a0ca 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -905,7 +905,7 @@ def lfilter( if clamp: output = torch.clamp(output, min=-1.0, max=1.0) - + # unpack batch output = output.reshape(shape[:-1] + output.shape[-1:])