diff --git a/setup.py b/setup.py
index 9ed91f2468..4822a6e9d0 100644
--- a/setup.py
+++ b/setup.py
@@ -113,6 +113,14 @@ def check_env_flag(name, default=''):
             extra_compile_args=eca,
             extra_objects=extra_objects,
             extra_link_args=ela),
+        CppExtension(
+            '_torch_filtering',
+            ['torchaudio/filtering.cpp'],
+            libraries=libraries,
+            include_dirs=include_dirs,
+            extra_compile_args=eca,
+            extra_objects=extra_objects,
+            extra_link_args=ela),
     ],
     cmdclass={'build_ext': BuildExtension},
     install_requires=[pytorch_package_dep]
diff --git a/test/test_functional_lfilter_perf.py b/test/test_functional_lfilter_perf.py
new file mode 100644
index 0000000000..ec9d9b2ee4
--- /dev/null
+++ b/test/test_functional_lfilter_perf.py
@@ -0,0 +1,200 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+import math
+import os
+import torch
+import torchaudio
+import unittest
+import common_utils
+
+from torchaudio.functional import lfilter, lfilter_cpp_impl
+
+
+class TestFunctionalLFilterPerformance(unittest.TestCase):
+    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
+
+    @staticmethod
+    def run_test(n_channels, n_frames, n_order_filter, assertClose=True):
+        waveform = torch.rand(n_channels, n_frames, device="cpu")
+
+        if n_order_filter == 8:
+            # Eighth Order Filter
+            # >>> import scipy.signal
+            # >>> wp = 0.3
+            # >>> ws = 0.5
+            # >>> gpass = 1
+            # >>> gstop = 100
+            # >>> b, a = scipy.signal.iirdesign(wp, ws, gpass, gstop)
+            b_coeffs = [
+                0.0006544487997063485,
+                0.001669274889397942,
+                0.003218714446315984,
+                0.004222562499298002,
+                0.004222562499298002,
+                0.0032187144463159834,
+                0.0016692748893979413,
+                0.0006544487997063485,
+            ]
+            a_coeffs = [
+                1.0,
+                -4.67403506662255,
+                10.516336803850786,
+                -14.399207825856776,
+                12.844181702707655,
+                -7.43604712843608,
+                2.5888616732696077,
+                -0.4205601576432048,
+            ]
+        elif n_order_filter == 5:
+            # Fifth Order Filter
+            # >>> import scipy.signal
+            # >>> wp = 0.3, ws = 0.5, gpass = 1, gstop = 40
+            # >>> b, a = scipy.signal.iirdesign(wp, ws, gpass, gstop)
+            b_coeffs = [
+                0.0353100066384039,
+                0.023370652985988206,
+                0.0560524973457262,
+                0.023370652985988193,
+                0.03531000663840389,
+            ]
+            a_coeffs = [
+                1.0,
+                -2.321010052951366,
+                2.677193357612127,
+                -1.5774235418173692,
+                0.4158137396065854,
+            ]
+        elif n_order_filter == 18:
+            # >>> import scipy.signal
+            # >>> wp = 0.48, ws = 0.5, gpass = 0.2, gstop = 120
+            # >>> b, a = scipy.signal.iirdesign(wp, ws, gpass, gstop)
+            b_coeffs = [
+                0.0006050813536446144,
+                0.002920916369302935,
+                0.010247568347759453,
+                0.02591236698507957,
+                0.05390501051935878,
+                0.09344581172781004,
+                0.13951533321139883,
+                0.1808658576803922,
+                0.2056643061895918,
+                0.2056643061895911,
+                0.1808658576803912,
+                0.13951533321139847,
+                0.09344581172781012,
+                0.053905010519358885,
+                0.02591236698507962,
+                0.010247568347759466,
+                0.0029209163693029367,
+                0.0006050813536446148,
+            ]
+            a_coeffs = [
+                1.0,
+                -4.3964136877356745,
+                14.650181359641305,
+                -34.45816395187684,
+                67.18247518997862,
+                -108.01956225077998,
+                149.4332056661277,
+                -178.07791467502364,
+                185.28267044557634,
+                -168.13382659655514,
+                133.22364764531704,
+                -91.59439958870928,
+                54.15835239046956,
+                -27.090521914173934,
+                11.163677645454127,
+                -3.627296054625132,
+                0.8471764313073272,
+                -0.11712354962357388,
+            ]
+        elif n_order_filter == 40:
+            # Create random set of 40 coefficients, will not be stable, test runtime rather than
+            # correctness
+            a_coeffs = torch.rand(40).numpy()
+            b_coeffs = torch.rand(40).numpy()
+            a_coeffs[0] = 1  # Set a0 to 1.0
+
+        # Cast into Tensors
+        a_coeffs = torch.tensor(a_coeffs, device="cpu", dtype=torch.float32)
+        b_coeffs = torch.tensor(b_coeffs, device="cpu", dtype=torch.float32)
+
+        def time_and_output(func):
+            import time
+
+            st = time.time()
+            output = func()
+            run_time = time.time() - st
+            return (output, run_time)
+
+        (output_waveform_1, run_time_1) = time_and_output(
+            lambda: lfilter(waveform, a_coeffs, b_coeffs)
+        )
+        (output_waveform_2, run_time_2) = time_and_output(
+            lambda: lfilter_cpp_impl(waveform, a_coeffs, b_coeffs, 'element_wise')
+        )
+        (output_waveform_3, run_time_3) = time_and_output(
+            lambda: lfilter_cpp_impl(waveform, a_coeffs, b_coeffs, 'matrix')
+        )
+
+        print("-" * 80)
+        print(
+            "lfilter perf - Data Size: [%d x %d], Filter Order: %d"
+            % (waveform.size(0), waveform.size(1), a_coeffs.size(0))
+        )
+        print("-" * 80)
+        print("Python Matrix Runtime [current]: %10.6f s" % run_time_1)
+        print("CPP Element Wise Runtime : %10.6f s" % run_time_2)
+        print("CPP Matrix Runtime : %10.6f s" % run_time_3)
+        print("-" * 80)
+        print("Ratio Python / CPP ElementWise : %10.2f x" % (run_time_1 / run_time_2))
+
+        if assertClose:
+            # maxDeviation = torch.kthvalue(torch.abs(output_waveform_2- output_waveform_3), output_waveform_1.size(1))
+            assert torch.allclose(output_waveform_1, output_waveform_2, atol=3e-4)
+            assert torch.allclose(output_waveform_2, output_waveform_3, atol=3e-4)
+            print("PASS - all outputs are identical")
+            print("-" * 80)
+
+    def test_cpp_lfilter_runs(self):
+        waveform = torch.rand(2, 1000, dtype=torch.float32)
+        b_coeffs = torch.rand(2, dtype=torch.float32)
+        a_coeffs = torch.rand(2, dtype=torch.float32)
+        a_coeffs[0] = 1
+
+        double_waveform = torch.rand(2, 1000, dtype=torch.float64)
+        double_b_coeffs = torch.rand(5, dtype=torch.float64)
+        double_a_coeffs = torch.rand(5, dtype=torch.float64)
+        double_a_coeffs[0] = 1
+
+        output_waveform = lfilter_cpp_impl(waveform, a_coeffs, b_coeffs, 'element_wise')
+        output_waveform = lfilter_cpp_impl(waveform, a_coeffs, b_coeffs, 'matrix')
+        output_waveform = lfilter_cpp_impl(double_waveform, double_a_coeffs, double_b_coeffs, 'element_wise')
+
+    def test_lfilter_cmp(self):
+        """
+        Runs comparison on CPU
+        """
+
+        torch.random.manual_seed(423)
+        self.run_test(2, 8000, 5)
+        self.run_test(2, 80000, 5)
+        self.run_test(2, 800000, 5)
+        self.run_test(2, 8000, 8)
+        self.run_test(2, 80000, 8)
+        self.run_test(2, 800000, 8)
+
+        # For higher order filters, due to floating point precision
+        # matrix method and element method can get different results depending on order of operations
+        # Also, for longer signals and higher filters, easier to create unstable filter
+        # https://dsp.stackexchange.com/questions/54386/relation-between-order-and-stability-in-iir-filter
+        self.run_test(2, 8000, 18, False)
+        self.run_test(2, 80000, 18, False)
+        self.run_test(2, 800000, 18, False)
+
+        self.run_test(2, 8000, 40, False)
+        self.run_test(2, 80000, 40, False)
+        self.run_test(2, 800000, 40, False)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchaudio/filtering.cpp b/torchaudio/filtering.cpp
new file mode 100644
index 0000000000..61f1ec8ca0
--- /dev/null
+++ b/torchaudio/filtering.cpp
@@ -0,0 +1,112 @@
+#include <torch/extension.h>
+// TBD - for CUDA support #include
+// TBD - Compile on CUDA
+
+namespace torch {
+namespace audio {
+
+
+// Runs one pass of the IIR difference equation using per-frame matrix
+// multiplies on [n_channels, n_order] windows of the padded waveform.
+void _lfilter_tensor_matrix(
+    at::Tensor const & padded_waveform,
+    at::Tensor & padded_output_waveform,
+    at::Tensor const & a_coeffs_filled,
+    at::Tensor const & b_coeffs_filled,
+    at::Tensor & o0,
+    at::Tensor const & normalization_a0
+    ) {
+  int64_t n_order = a_coeffs_filled.size(0);
+  int64_t n_frames = padded_waveform.size(1) - n_order + 1;
+  int64_t n_channels = padded_waveform.size(0);
+
+  for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
+    // reset all o0
+    o0.fill_(0.0);
+
+    // time window of input and output, size [n_channels, n_order]
+    at::Tensor const & input_window =
+        padded_waveform.slice(0, 0, n_channels)
+                       .slice(1, i_frame, i_frame + n_order);
+    at::Tensor const & output_window =
+        padded_output_waveform.slice(0, 0, n_channels)
+                              .slice(1, i_frame, i_frame + n_order);
+
+    // matrix multiply to get [n_channels x n_channels],
+    // extract diagonal and unsqueeze to get [n_channels, 1] result
+    at::Tensor inp_result =
+        torch::unsqueeze(torch::diag(torch::mm(input_window,
+                                               b_coeffs_filled)), 1);
+    at::Tensor out_result =
+        torch::unsqueeze(torch::diag(torch::mm(output_window,
+                                               a_coeffs_filled)), 1);
+
+    o0.add_(inp_result);
+    o0.sub_(out_result);
+
+    // normalize by a0
+    o0.div_(normalization_a0);
+
+    // Set the output
+    padded_output_waveform.slice(0, 0, n_channels)
+                          .slice(1,
+                                 i_frame + n_order - 1,
+                                 i_frame + n_order - 1 + 1) = o0;
+  }
+}
+
+// Scalar (per-element) evaluation of the difference equation; T is the
+// floating point type of the waveform (float or double).
+template <typename T>
+void _lfilter_element_wise(
+    at::Tensor const & padded_waveform,
+    at::Tensor & padded_output_waveform,
+    at::Tensor const & a_coeffs,
+    at::Tensor const & b_coeffs
+    ) {
+  int64_t n_order = a_coeffs.size(0);
+  int64_t n_frames = padded_waveform.size(1) - n_order + 1;
+  int64_t n_channels = padded_waveform.size(0);
+
+  // Create CPU accessors for fast access
+  // https://pytorch.org/cppdocs/notes/tensor_basics.html
+  auto input_accessor = padded_waveform.accessor<T, 2>();
+  auto output_accessor = padded_output_waveform.accessor<T, 2>();
+  auto a_coeffs_accessor = a_coeffs.accessor<T, 1>();
+  auto b_coeffs_accessor = b_coeffs.accessor<T, 1>();
+  T o0;
+
+  for (int64_t i_channel = 0; i_channel < n_channels; ++i_channel) {
+    for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
+      // execute the difference equation
+      o0 = 0;
+      for (int i_offset = 0; i_offset < n_order; ++i_offset) {
+        o0 += input_accessor[i_channel][i_frame + i_offset] *
+              b_coeffs_accessor[n_order - i_offset - 1];
+        o0 -= output_accessor[i_channel][i_frame + i_offset] *
+              a_coeffs_accessor[n_order - i_offset - 1];
+      }
+      o0 = o0 / a_coeffs_accessor[0];
+
+      // put back into the main data structure
+      output_accessor[i_channel][i_frame + n_order - 1] = o0;
+    }
+  }
+}
+
+}  // namespace audio
+}  // namespace torch
+
+
+PYBIND11_MODULE(_torch_filtering, m) {
+  py::options options;
+  options.disable_function_signatures();
+  m.def(
+      "_lfilter_tensor_matrix",
+      &torch::audio::_lfilter_tensor_matrix,
+      "Executes difference equation with tensor");
+  m.def(
+      "_lfilter_element_wise_float",
+      &torch::audio::_lfilter_element_wise<float>,
+      "Executes difference equation with tensor");
+  m.def(
+      "_lfilter_element_wise_double",
+      &torch::audio::_lfilter_element_wise<double>,
+      "Executes difference equation with tensor");
+}
diff --git a/torchaudio/functional.py b/torchaudio/functional.py
index dfabb431ba..70a2a4fd30 100644
--- a/torchaudio/functional.py
+++ b/torchaudio/functional.py
@@ -3,6 +3,7 @@
 import math
 
 import torch
+from _torch_filtering import _lfilter_tensor_matrix, _lfilter_element_wise_float, _lfilter_element_wise_double
 
 __all__ = [
     "istft",
@@ -17,6 +18,7 @@
     "magphase",
     "phase_vocoder",
     "lfilter",
+    "lfilter_cpp_impl",
     "lowpass_biquad",
     "highpass_biquad",
     "equalizer_biquad",
@@ -537,7 +539,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
 def lfilter(waveform, a_coeffs, b_coeffs):
     # type: (Tensor, Tensor, Tensor) -> Tensor
     r"""
-    Performs an IIR filter by evaluating difference equation.
+    Performs an IIR filter by evaluating difference equation. Coefficients should be designed to be stable.
 
     Args:
         waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`. Must be normalized to -1 to 1.
@@ -595,13 +597,96 @@ def lfilter(waveform, a_coeffs, b_coeffs):
     return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):]))
 
 
+def lfilter_cpp_impl(waveform, a_coeffs, b_coeffs, execution_method):
+    # type: (Tensor, Tensor, Tensor, str) -> Tensor
+    r"""
+    See lfilter documentation. Execution method can be either `element_wise` or `matrix`.
+    """
+
+    # Perform sanity checks, input check, and memory allocation in python
+
+    # Current limitations to be removed in future
+    assert(waveform.dtype == torch.float32 or waveform.dtype == torch.float64)
+    assert(waveform.device.type == 'cpu')
+
+    assert(waveform.device == a_coeffs.device)
+    assert(waveform.device == b_coeffs.device)
+    assert(waveform.dtype == a_coeffs.dtype)
+    assert(waveform.dtype == b_coeffs.dtype)
+
+    # Use these parameters for any calculations
+    device = waveform.device
+    dtype = waveform.dtype
+
+    # Check filter orders
+    n_order = a_coeffs.size(0)
+    assert(b_coeffs.size(0) == n_order)
+    assert(n_order > 0)
+    assert(a_coeffs[0] != 0)  # a0 coeff for y[n] can not be 0
+
+    # Check waveform size
+    assert(len(waveform.size()) == 2)
+    n_channels = waveform.size(0)
+    n_frames = waveform.size(1)
+    n_padded_frames = n_frames + n_order - 1
+
+    # Allocate temporary data structures
+    # First, the input waveform should be padded by the order of the filter
+    # N.B. Should we look to how we can avoid copying
+    padded_waveform = torch.zeros(n_channels, n_padded_frames, dtype=dtype, device=device)
+    padded_output_waveform = torch.zeros(n_channels, n_padded_frames, dtype=dtype, device=device)
+    padded_waveform[:, (n_order - 1):] = waveform
+
+    # More temporary data structures
+    o0 = torch.zeros(n_channels, 1, dtype=dtype, device=device)
+    normalization_a0 = a_coeffs[0].unsqueeze(0).repeat(n_channels, 1)
+    ones = torch.ones(n_channels, n_padded_frames, dtype=dtype, device=device)
+
+    # Run through assertion size checks
+    assert(normalization_a0.size(0) == n_channels)
+
+    if execution_method == 'element_wise':
+        if (dtype == torch.float32):
+            _lfilter_element_wise_float(padded_waveform,
+                                        padded_output_waveform,
+                                        a_coeffs,
+                                        b_coeffs,
+                                        )
+        elif (dtype == torch.float64):
+            _lfilter_element_wise_double(padded_waveform,
+                                         padded_output_waveform,
+                                         a_coeffs,
+                                         b_coeffs,
+                                         )
+        else:
+            raise Exception("lfilter not supported for type ", dtype)
+    elif execution_method == 'matrix':
+        # From [n_order] a_coeffs, create a [n_channel, n_order] tensor
+        # lowest order coefficients e.g. a0, b0 should be at the "bottom"
+        # used for matrix multiply later
+        a_coeffs_filled = a_coeffs.flip(0).repeat(n_channels, 1).t()
+        b_coeffs_filled = b_coeffs.flip(0).repeat(n_channels, 1).t()
+        _lfilter_tensor_matrix(padded_waveform,
+                               padded_output_waveform,
+                               a_coeffs_filled,
+                               b_coeffs_filled,
+                               o0,
+                               normalization_a0,
+                               )
+    else:
+        raise Exception("invalid lfilter execution method %s" % execution_method)
+
+    return torch.min(ones, torch.max(ones * -1, padded_output_waveform))[:, (n_order - 1):]
+
+
 def biquad(waveform, b0, b1, b2, a0, a1, a2):
     # type: (Tensor, float, float, float, float, float, float) -> Tensor
     r"""Performs a biquad filter of input tensor.  Initial conditions set to 0.
     https://en.wikipedia.org/wiki/Digital_biquad_filter
 
     Args:
-        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
+        waveform (torch.Tensor): Audio waveform of dimension of `(n_channel, time)`.
+            Currently only supports float32. Normalized [-1, 1]
         b0 (float): numerator coefficient of current input, x[n]
         b1 (float): numerator coefficient of input one time step ago x[n-1]
         b2 (float): numerator coefficient of input two time steps ago x[n-2]
@@ -610,7 +695,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
         a2 (float): denominator coefficient of current output y[n-2]
 
     Returns:
-        output_waveform (torch.Tensor): Dimension of `(channel, time)`
+        output_waveform (torch.Tensor): Dimension of `(channel, time)` in range [-1, 1]
     """
     device = waveform.device