🚀 Feature
Add a batch dimension inside the computation of `lfilter` to make things faster.
Motivation
It has recently become possible to integrate digital filters into a neural network, thanks to the backprop support for `lfilter` (#1310).
However, the filter currently only supports a 1-D waveform, 1-D `a_coeffs`, and 1-D `b_coeffs`; it is not possible to include a batch dimension in the computation. The only way to do it is a naive for-loop, which is very time-consuming.
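For reference, the naive approach loops over the batch in Python, calling the filter once per example. A minimal pure-torch sketch (here `lfilter_1d` is a hypothetical stand-in for `torchaudio.functional.lfilter`, implementing the same direct-form difference equation; the shapes and coefficients are made up):

```python
import torch

def lfilter_1d(waveform, a_coeffs, b_coeffs):
    """Direct-form IIR filter for a single 1-D signal:
    a[0]*y[n] = sum_k b[k]*x[n-k] - sum_{k>=1} a[k]*y[n-k]."""
    y = torch.zeros_like(waveform)
    for n in range(waveform.numel()):
        acc = torch.tensor(0.0)
        # Feed-forward part (b coefficients)
        for k in range(b_coeffs.numel()):
            if n - k >= 0:
                acc = acc + b_coeffs[k] * waveform[n - k]
        # Feedback part (a coefficients)
        for k in range(1, a_coeffs.numel()):
            if n - k >= 0:
                acc = acc - a_coeffs[k] * y[n - k]
        y[n] = acc / a_coeffs[0]
    return y

# Naive batching: one Python-level call per example -- correct but slow.
batch_wave = torch.randn(4, 64)
a = torch.tensor([1.0, -0.95])
b = torch.tensor([0.05, 0.0])
out = torch.stack([lfilter_1d(w, a, b) for w in batch_wave])
```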
Therefore, I modified some code to let the computation include a batch dimension, but as issue #1408 notes, computing `lfilter` without the CPU optimization is slow.
I am asking for help: does anyone have other ideas for adding a batch dimension inside the computation of `lfilter`?
Pitch
I would love to see another, more elegant way to include the batch dimension when running `lfilter`, or
maybe someone knows how to include a batch dimension inside the CPU optimization loop that `lfilter` uses.
Alternatives
I modified the original `lfilter` to include the batch dimension.
The major modifications are:
- calculate `windowed_input_signal` by `F.conv1d` with `groups`
- replace `input_signal_windows.t()` with `input_signal_windows.permute(2, 0, 1)` (inside `_lfilter_core_generic_loop`)
- replace `addmv_` with `einsum` (inside `_lfilter_core_generic_loop`)
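The grouped-convolution trick in the first bullet can be checked in isolation: moving the batch to the channel dimension and setting `groups=batch` applies a different FIR kernel to each example in a single `F.conv1d` call (a minimal sketch; shapes and kernels are made up):

```python
import torch
import torch.nn.functional as F

batch, n_sample, n_order = 3, 16, 4
x = torch.randn(batch, 1, n_sample)       # one channel per example
kernels = torch.randn(batch, 1, n_order)  # one FIR kernel per example

# Stack the batch along the channel dimension and use grouped convolution,
# so kernel i is applied only to signal i, all in one conv1d call.
grouped = F.conv1d(x.permute(1, 0, 2), kernels, groups=batch)  # (1, batch, L_out)
grouped = grouped.permute(1, 0, 2)                             # (batch, 1, L_out)

# Reference: per-example conv1d in a Python loop.
ref = torch.stack([F.conv1d(x[i:i + 1], kernels[i:i + 1]) for i in range(batch)]).squeeze(1)
```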
```python
import torch
import torch.nn.functional as F


def lfilter_batch(waveform, a_coeffs, b_coeffs):
    # waveform: (batch, channel, time); a_coeffs, b_coeffs: (batch, n_order)
    batch = waveform.size(0)
    n_sample = waveform.size(-1)
    n_order = a_coeffs.size(-1)

    # Pad the input and create the output
    padded_waveform = F.pad(waveform, [n_order - 1, 0])
    padded_output_waveform = torch.zeros_like(padded_waveform)

    # Set up the coefficients: flip the coefficients' order
    a_coeffs_flipped = a_coeffs.flip(-1)
    b_coeffs_flipped = b_coeffs.flip(-1)

    # Calculate windowed_input_signal in parallel using grouped convolution,
    # so each batch element is convolved with its own b kernel
    input_signal_windows = F.conv1d(
        padded_waveform.permute(1, 0, 2), b_coeffs_flipped.unsqueeze(1), groups=batch
    )
    input_signal_windows = input_signal_windows.permute(1, 0, 2)

    a_coeffs = a_coeffs.unsqueeze(1)
    a_coeffs_flipped = a_coeffs_flipped.unsqueeze(1)

    # Normalize by a[0] per batch element
    input_signal_windows.div_(a_coeffs[:, :, :1])
    a_coeffs_flipped.div_(a_coeffs[:, :, :1])
    a_coeffs_flipped = a_coeffs_flipped.squeeze(1)

    padded_output_waveform = _lfilter_core_generic_loop_batch(
        input_signal_windows, a_coeffs_flipped, padded_output_waveform
    )
    output = padded_output_waveform[:, :, n_order - 1:]
    return output


def _lfilter_core_execute(i_sample, n_order, o0, a_coeffs_flipped, out):
    windowed_output_signal = out[:, :, i_sample:i_sample + n_order]
    # einsum replaces addmv_: a batched matrix-vector product over the batch dim
    out[:, :, i_sample + n_order - 1] = o0 - torch.einsum(
        'bcd,bd->bc', windowed_output_signal, a_coeffs_flipped
    )


def _lfilter_core_generic_loop_batch(input_signal_windows, a_coeffs_flipped, padded_output_waveform):
    n_order = a_coeffs_flipped.size(-1)
    out = padded_output_waveform
    # (batch, channel, time) -> (time, batch, channel) to iterate over samples
    input_signal_windows_transpose = input_signal_windows.permute(2, 0, 1)
    for i_sample, o0 in enumerate(input_signal_windows_transpose):
        _lfilter_core_execute(i_sample, n_order, o0, a_coeffs_flipped, out)
    return out
```
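The `einsum` call that replaces `addmv_` in the feedback loop can also be verified against a per-example loop (a small self-contained check; the shapes are made up):

```python
import torch

batch, channel, n_order = 3, 2, 5
window = torch.randn(batch, channel, n_order)  # one output window per example
a_flip = torch.randn(batch, n_order)           # flipped a coefficients per example

# Batched matrix-vector product: addmv_ handles only a single matrix,
# so the batched loop uses einsum over the batch dimension instead.
out = torch.einsum('bcd,bd->bc', window, a_flip)

# Reference: one matrix-vector product per example.
ref = torch.stack([window[i] @ a_flip[i] for i in range(batch)])
```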