Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ def check_env_flag(name, default=''):
extra_compile_args=eca,
extra_objects=extra_objects,
extra_link_args=ela),
CppExtension(
'_torch_filtering',
['torchaudio/filtering.cpp'],
libraries=libraries,
include_dirs=include_dirs,
extra_compile_args=eca,
extra_objects=extra_objects,
extra_link_args=ela),
],
cmdclass={'build_ext': BuildExtension},
install_requires=[pytorch_package_dep]
Expand Down
200 changes: 200 additions & 0 deletions test/test_functional_lfilter_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import math
import os
import torch
import torchaudio
import unittest
import common_utils

from torchaudio.functional import lfilter, lfilter_cpp_impl


class TestFunctionalLFilterPerformance(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()

@staticmethod
def run_test(n_channels, n_frames, n_order_filter, assertClose=True):
waveform = torch.rand(n_channels, n_frames, device="cpu")

if n_order_filter == 8:
# Eighth Order Filter
# >>> import scipy.signal
# >>> wp = 0.3
# >>> ws = 0.5
# >>> gpass = 1
# >>> gstop = 100
# >>> b, a = scipy.signal.iirdesign(wp, ws, gpass, gstop)
b_coeffs = [
0.0006544487997063485,
0.001669274889397942,
0.003218714446315984,
0.004222562499298002,
0.004222562499298002,
0.0032187144463159834,
0.0016692748893979413,
0.0006544487997063485,
]
a_coeffs = [
1.0,
-4.67403506662255,
10.516336803850786,
-14.399207825856776,
12.844181702707655,
-7.43604712843608,
2.5888616732696077,
-0.4205601576432048,
]
elif n_order_filter == 5:
# Fifth Order Filter
# >>> import scipy.signal
# >>> wp = 0.3, ws = 0.5, gpass = 1, gstop = 40
# >>> b, a = scipy.signal.iirdesign(wp, ws, gpass, gstop)
b_coeffs = [
0.0353100066384039,
0.023370652985988206,
0.0560524973457262,
0.023370652985988193,
0.03531000663840389,
]
a_coeffs = [
1.0,
-2.321010052951366,
2.677193357612127,
-1.5774235418173692,
0.4158137396065854,
]
elif n_order_filter == 18:
# >>> import scipy.signal
# >>> wp = 0.48, ws = 0.5, gpass = 0.2, gstop = 120
# >>> b, a = scipy.signal.iirdesign(wp, ws, gpass, gstop)
b_coeffs = [
0.0006050813536446144,
0.002920916369302935,
0.010247568347759453,
0.02591236698507957,
0.05390501051935878,
0.09344581172781004,
0.13951533321139883,
0.1808658576803922,
0.2056643061895918,
0.2056643061895911,
0.1808658576803912,
0.13951533321139847,
0.09344581172781012,
0.053905010519358885,
0.02591236698507962,
0.010247568347759466,
0.0029209163693029367,
0.0006050813536446148,
]
a_coeffs = [
1.0,
-4.3964136877356745,
14.650181359641305,
-34.45816395187684,
67.18247518997862,
-108.01956225077998,
149.4332056661277,
-178.07791467502364,
185.28267044557634,
-168.13382659655514,
133.22364764531704,
-91.59439958870928,
54.15835239046956,
-27.090521914173934,
11.163677645454127,
-3.627296054625132,
0.8471764313073272,
-0.11712354962357388,
]
elif n_order_filter == 40:
# Create random set of 40 coefficients, will not be stable, test runtime rather than
# correctness
a_coeffs = torch.rand(40).numpy()
b_coeffs = torch.rand(40).numpy()
a_coeffs[0] = 1 # Set a0 to 1.0

# Cast into Tensors
a_coeffs = torch.tensor(a_coeffs, device="cpu", dtype=torch.float32)
b_coeffs = torch.tensor(b_coeffs, device="cpu", dtype=torch.float32)

def time_and_output(func):
import time

st = time.time()
output = func()
run_time = time.time() - st
return (output, run_time)

(output_waveform_1, run_time_1) = time_and_output(
lambda: lfilter(waveform, a_coeffs, b_coeffs)
)
(output_waveform_2, run_time_2) = time_and_output(
lambda: lfilter_cpp_impl(waveform, a_coeffs, b_coeffs, 'element_wise')
)
(output_waveform_3, run_time_3) = time_and_output(
lambda: lfilter_cpp_impl(waveform, a_coeffs, b_coeffs, 'matrix')
)

print("-" * 80)
print(
"lfilter perf - Data Size: [%d x %d], Filter Order: %d"
% (waveform.size(0), waveform.size(1), a_coeffs.size(0))
)
print("-" * 80)
print("Python Matrix Runtime [current]: %10.6f s" % run_time_1)
print("CPP Element Wise Runtime : %10.6f s" % run_time_2)
print("CPP Matrix Runtime : %10.6f s" % run_time_3)
print("-" * 80)
print("Ratio Python / CPP ElementWise : %10.2f x" % (run_time_1 / run_time_2))

if assertClose:
# maxDeviation = torch.kthvalue(torch.abs(output_waveform_2- output_waveform_3), output_waveform_1.size(1))
assert torch.allclose(output_waveform_1, output_waveform_2, atol=3e-4)
assert torch.allclose(output_waveform_2, output_waveform_3, atol=3e-4)
print("PASS - all outputs are identical")
print("-" * 80)

def test_cpp_lfilter_runs(self):
waveform = torch.rand(2, 1000, dtype=torch.float32)
b_coeffs = torch.rand(2, dtype=torch.float32)
a_coeffs = torch.rand(2, dtype=torch.float32)
a_coeffs[0] = 1

double_waveform = torch.rand(2, 1000, dtype=torch.float64)
double_b_coeffs = torch.rand(5, dtype=torch.float64)
double_a_coeffs = torch.rand(5, dtype=torch.float64)
double_a_coeffs[0] = 1

output_waveform = lfilter_cpp_impl(waveform, a_coeffs, b_coeffs, 'element_wise')
output_waveform = lfilter_cpp_impl(waveform, a_coeffs, b_coeffs, 'matrix')
output_waveform = lfilter_cpp_impl(double_waveform, double_a_coeffs, double_b_coeffs, 'element_wise')

def test_lfilter_cmp(self):
"""
Runs comparison on CPU
"""

torch.random.manual_seed(423)
self.run_test(2, 8000, 5)
self.run_test(2, 80000, 5)
self.run_test(2, 800000, 5)
self.run_test(2, 8000, 8)
self.run_test(2, 80000, 8)
self.run_test(2, 800000, 8)

# For higher order filters, due to floating point precision
# matrix method and element method can get different results depending on order of operations
# Also, for longer signals and higher filters, easier to create unstable filter
# https://dsp.stackexchange.com/questions/54386/relation-between-order-and-stability-in-iir-filter
self.run_test(2, 8000, 18, False)
self.run_test(2, 80000, 18, False)
self.run_test(2, 800000, 18, False)

self.run_test(2, 8000, 40, False)
self.run_test(2, 80000, 40, False)
self.run_test(2, 800000, 40, False)


if __name__ == "__main__":
unittest.main()
112 changes: 112 additions & 0 deletions torchaudio/filtering.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include <torch/extension.h>
// TBD - for CUDA support #include <ATen/cuda/CUDAContext.h>
// TBD - Compile on CUDA

namespace torch {
namespace audio {


void _lfilter_tensor_matrix(
at::Tensor const & padded_waveform,
at::Tensor & padded_output_waveform,
at::Tensor const & a_coeffs_filled,
at::Tensor const & b_coeffs_filled,
at::Tensor & o0,
at::Tensor const & normalization_a0
) {
int64_t n_order = a_coeffs_filled.size(0);
int64_t n_frames = padded_waveform.size(1) - n_order + 1;
int64_t n_channels = padded_waveform.size(0);

for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
// reset all o0
o0.fill_(0.0);

// time window of input and output, size [n_channels, n_order]
at::Tensor const & input_window =
padded_waveform.slice(0, 0, n_channels)
.slice(1, i_frame, i_frame + n_order);
at::Tensor const & output_window =
padded_output_waveform.slice(0, 0, n_channels)
.slice(1, i_frame, i_frame+ n_order);

// matrix multiply to get [n_channels x n_channels],
// extract diagonal and unsqueeze to get [n_channels, 1] result
at::Tensor inp_result =
torch::unsqueeze(torch::diag(torch::mm(input_window,
b_coeffs_filled)), 1);
at::Tensor out_result =
torch::unsqueeze(torch::diag(torch::mm(output_window,
a_coeffs_filled)), 1);

o0.add_(inp_result);
o0.sub_(out_result);

// normalize by a0
o0.div_(normalization_a0);

// Set the output
padded_output_waveform.slice(0, 0, n_channels)
.slice(1,
i_frame + n_order - 1,
i_frame + n_order - 1 + 1) = o0;
}
}

template <typename T>
void _lfilter_element_wise(
at::Tensor const & padded_waveform,
at::Tensor & padded_output_waveform,
at::Tensor const & a_coeffs,
at::Tensor const & b_coeffs
) {
int64_t n_order = a_coeffs.size(0);
int64_t n_frames = padded_waveform.size(1) - n_order + 1;
int64_t n_channels = padded_waveform.size(0);

// Create CPU accessors for fast access
// https://pytorch.org/cppdocs/notes/tensor_basics.html
auto input_accessor = padded_waveform.accessor<T, 2>();
auto output_accessor = padded_output_waveform.accessor<T, 2>();
auto a_coeffs_accessor = a_coeffs.accessor<T, 1>();
auto b_coeffs_accessor = b_coeffs.accessor<T, 1>();
T o0;

for (int64_t i_channel = 0; i_channel < n_channels; ++i_channel) {
for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
// execute the difference equation
o0 = 0;
for (int i_offset = 0; i_offset < n_order; ++i_offset) {
o0 += input_accessor[i_channel][i_frame + i_offset] *
b_coeffs_accessor[n_order - i_offset - 1];
o0 -= output_accessor[i_channel][i_frame + i_offset] *
a_coeffs_accessor[n_order - i_offset - 1];
}
o0 = o0 / a_coeffs_accessor[0];

// put back into the main data structure
output_accessor[i_channel][i_frame + n_order - 1] = o0;
}
}
}

} // namespace audio
} // namespace torch


PYBIND11_MODULE(_torch_filtering, m) {
py::options options;
options.disable_function_signatures();
m.def(
"_lfilter_tensor_matrix",
&torch::audio::_lfilter_tensor_matrix,
"Executes difference equation with tensor");
m.def(
"_lfilter_element_wise_float",
&torch::audio::_lfilter_element_wise<float>,
"Executes difference equation with tensor");
m.def(
"_lfilter_element_wise_double",
&torch::audio::_lfilter_element_wise<double>,
"Executes difference equation with tensor");
}
Loading