
Add batch dimension inside the computation of lfilter #1476

@ChenPaulYu

🚀 Feature

Add a batch dimension inside the computation of lfilter to make it faster.

Motivation

It has recently become possible to integrate digital filters inside a neural network thanks to the backprop support for lfilter (#1310).

However, currently, the filter only supports a 1D waveform, 1D a_coeffs, and 1D b_coeffs. It is not possible to include a batch dimension in the computation. The only way to do it is a naive for-loop over the batch, which is very time-consuming.
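For reference, the naive for-loop baseline can be sketched as below. This is a minimal sketch, assuming waveform of shape (batch, channel, time) and per-example 1D coefficients; the names lfilter_single and lfilter_naive_batch are hypothetical, and the inner recursion is a plain direct-form loop rather than torchaudio's optimized kernel:

```python
import torch
import torch.nn.functional as F


def lfilter_single(waveform, a_coeffs, b_coeffs):
    # Direct-form IIR filtering of one (channel, time) waveform with 1D coeffs.
    n_order = a_coeffs.numel()
    padded = F.pad(waveform, [n_order - 1, 0])
    out = torch.zeros_like(padded)
    b_f = b_coeffs.flip(0) / a_coeffs[0]   # flipped feedforward coeffs, normalized by a_0
    a_f = a_coeffs.flip(0) / a_coeffs[0]   # flipped feedback coeffs, normalized by a_0
    for i in range(waveform.size(-1)):
        # feedforward term: dot product of the input window with b
        o0 = (padded[:, i:i + n_order] * b_f).sum(-1)
        # feedback term: dot product of the past outputs with a_1..a_{n-1}
        o0 = o0 - (out[:, i:i + n_order - 1] * a_f[:-1]).sum(-1)
        out[:, i + n_order - 1] = o0
    return out[:, n_order - 1:]


def lfilter_naive_batch(waveform, a_coeffs, b_coeffs):
    # The slow baseline: loop over the batch dimension one example at a time.
    return torch.stack([
        lfilter_single(waveform[i], a_coeffs[i], b_coeffs[i])
        for i in range(waveform.size(0))
    ])
```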

Therefore, I modified some code to let the computation include a batch dimension, but as issue #1408 notes, computing lfilter without the CPU-optimized loop is slow.

I am asking for help: does anyone have another idea for adding a batch dimension inside the computation of lfilter?

Pitch

I would love to see a more elegant way to include the batch dimension when running lfilter, or
maybe someone knows how to include the batch dimension inside the CPU-optimized loop used by lfilter.

Alternatives

I modified the original lfilter to include the batch dimension.
The major modifications are:

  1. calculate windowed_input_signal with F.conv1d using groups
  2. replace input_signal_windows.t() with input_signal_windows.permute(2, 0, 1) (inside _lfilter_core_generic_loop)
  3. replace addmv_ with einsum (inside _lfilter_core_generic_loop)
import torch
import torch.nn.functional as F


def lfilter_batch(waveform, a_coeffs, b_coeffs):
    # waveform: (batch, channel, time); a_coeffs, b_coeffs: (batch, n_order)
    batch   = waveform.size(0)
    n_order = a_coeffs.size(-1)

    # Pad the input and create the output buffer
    padded_waveform        = F.pad(waveform, [n_order - 1, 0])
    padded_output_waveform = torch.zeros_like(padded_waveform)

    # Set up the coefficients matrix: flip the coefficients' order
    a_coeffs_flipped = a_coeffs.flip(-1)
    b_coeffs_flipped = b_coeffs.flip(-1)

    # Calculate windowed_input_signal for the whole batch in parallel using a
    # grouped convolution: fold batch into the channel axis so each example
    # gets its own filter
    input_signal_windows = F.conv1d(padded_waveform.permute(1, 0, 2),
                                    b_coeffs_flipped.unsqueeze(1), groups=batch)
    input_signal_windows = input_signal_windows.permute(1, 0, 2)

    a_coeffs         = a_coeffs.unsqueeze(1)
    a_coeffs_flipped = a_coeffs_flipped.unsqueeze(1)

    # Normalize by each example's a_0
    input_signal_windows.div_(a_coeffs[:, :, :1])
    a_coeffs_flipped.div_(a_coeffs[:, :, :1])
    a_coeffs_flipped = a_coeffs_flipped.squeeze(1)

    padded_output_waveform = _lfilter_core_generic_loop_batch(
        input_signal_windows, a_coeffs_flipped, padded_output_waveform)
    return padded_output_waveform[:, :, n_order - 1:]


def _lfilter_core_execute(i_sample, n_order, o0, a_coeffs_flipped, out):
    # Subtract the feedback term for every example in the batch at once
    windowed_output_signal = out[:, :, i_sample:i_sample + n_order]
    out[:, :, i_sample + n_order - 1] = o0 - torch.einsum(
        'bcd,bd->bc', windowed_output_signal, a_coeffs_flipped)


def _lfilter_core_generic_loop_batch(input_signal_windows, a_coeffs_flipped, padded_output_waveform):
    n_order = a_coeffs_flipped.size(-1)
    out     = padded_output_waveform
    # Iterate over time; each step yields one (batch, channel) slice
    input_signal_windows_transpose = input_signal_windows.permute(2, 0, 1)
    for i_sample, o0 in enumerate(input_signal_windows_transpose):
        _lfilter_core_execute(i_sample, n_order, o0, a_coeffs_flipped, out)
    return out
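To illustrate modifications 1 and 3 in isolation, the two key operations can be checked against explicit per-window references. This is a sketch with arbitrary shapes, assuming the same layout as above: waveform (batch, channel, time), coefficients (batch, n_order):

```python
import torch
import torch.nn.functional as F

batch, channel, time, n_order = 2, 3, 8, 4
padded = torch.randn(batch, channel, time)
b_flipped = torch.randn(batch, n_order)

# Modification 1: a grouped conv1d computes every example's windowed input in
# one call; folding batch into the channel axis with groups=batch gives each
# example its own filter.
conv_out = F.conv1d(padded.permute(1, 0, 2), b_flipped.unsqueeze(1), groups=batch)
conv_out = conv_out.permute(1, 0, 2)

# Reference: explicit sliding windows, one dot product per time position.
windows = padded.unfold(-1, n_order, 1)          # (batch, channel, time', n_order)
ref = torch.einsum('bctd,bd->bct', windows, b_flipped)
assert torch.allclose(conv_out, ref, atol=1e-5)

# Modification 3: the einsum that replaces addmv_ is just a batched
# matrix-vector product over the output window.
windowed_out = torch.randn(batch, channel, n_order)
a_flipped = torch.randn(batch, n_order)
corr = torch.einsum('bcd,bd->bc', windowed_out, a_flipped)
assert torch.allclose(corr, torch.bmm(windowed_out, a_flipped.unsqueeze(-1)).squeeze(-1))
```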
