
Request preemphasis and deemphasis modules #226

@AppleHolic


🚀 Feature

I'd like to request pre-emphasis / de-emphasis modules (wiki). In several sound-related deep-learning tasks, these modules are often used to improve sound quality.
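Pre-emphasis is the first-order FIR filter y[n] = x[n] - coef * x[n-1], and de-emphasis is its IIR inverse x[n] = y[n] + coef * x[n-1]. Below is a minimal NumPy sketch of both. It assumes a 1-D signal and a zero initial state, which is the default of `scipy.signal.lfilter`. The function names are illustrative only and are not an existing API.

```python
import numpy as np


def preemphasis(x, coef=0.97):
    """FIR pre-emphasis: y[n] = x[n] - coef * x[n - 1], assuming x[-1] = 0."""
    x = np.asarray(x, dtype=np.float64)
    y = x.copy()
    y[1:] -= coef * x[:-1]  # subtract the scaled previous input sample
    return y


def deemphasis(y, coef=0.97):
    """IIR de-emphasis: x[n] = y[n] + coef * x[n - 1], assuming x[-1] = 0."""
    y = np.asarray(y, dtype=np.float64)
    x = np.empty_like(y)
    prev = 0.0
    for n, sample in enumerate(y):
        prev = sample + coef * prev  # feed back the scaled previous output
        x[n] = prev
    return x


wave = np.random.default_rng(0).standard_normal(1000)
restored = deemphasis(preemphasis(wave))
print(np.max(np.abs(restored - wave)))  # on the order of float64 rounding error
```

Because both filters start from the same zero state, de-emphasis undoes pre-emphasis exactly, up to floating-point rounding.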

Motivation

In my speech enhancement case, the model usually generates high-frequency noise when trained without pre-emphasis. Applying pre-emphasis noticeably improves the enhancement quality. However, to use it while training models, it has to be implemented in PyTorch.
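Pre-emphasis acts on high frequencies because its transfer function H(z) = 1 - coef * z^-1 has magnitude response |H(e^jw)| = sqrt(1 + coef^2 - 2 * coef * cos w). That response rises from 1 - coef at DC to 1 + coef at Nyquist. The NumPy sketch below evaluates this tilt for coef = 0.97, the default used throughout this issue. The helper name is hypothetical. The sketch describes the filter itself and makes no claim about why a particular enhancement model benefits from it.

```python
import numpy as np


def preemphasis_gain(omega, coef=0.97):
    """Magnitude response of H(z) = 1 - coef * z^-1 at normalized angular
    frequency omega (radians/sample: 0 = DC, pi = Nyquist)."""
    return np.abs(1.0 - coef * np.exp(-1j * omega))


omega = np.array([0.0, np.pi / 2, np.pi])
gain = preemphasis_gain(omega)
boost_db = 20 * np.log10(gain[-1] / gain[0])  # Nyquist gain relative to DC gain

print(gain)      # DC gain is 1 - coef, Nyquist gain is 1 + coef
print(boost_db)  # roughly 36 dB for coef = 0.97
```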

Pitch and Alternatives

I have already written a first version of pre-emphasis and de-emphasis in PyTorch.
Once the code below has been checked for bugs, I suggest adding these modules to transforms.py.

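import torch
import torch.nn.functional as F


class PreEmphasis(torch.nn.Module):
    """Pre-emphasis: y[n] = x[n] - coef * x[n - 1]."""

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # make kernel
        # In PyTorch, conv1d computes cross-correlation, so the filter [1, -coef] is flipped.
        # Kernel shape is (out_channels, in_channels, kernel_size) = (1, 1, 2).
        self.register_buffer(
            'flipped_filter', torch.tensor([-self.coef, 1.0]).view(1, 1, 2)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # conv1d with a (1, 1, 2) kernel needs input of shape (batch, 1, time)
        assert input.dim() == 3 and input.size(1) == 1, 'Input must have shape (batch, 1, time)!'
        # reflect padding to match lengths of in/out
        input = F.pad(input, (1, 0), 'reflect')
        return F.conv1d(input, self.flipped_filter)


class InversePreEmphasis(torch.nn.Module):
    """
    Inverse pre-emphasis (de-emphasis): y[n] = x[n] + coef * y[n - 1].

    torch.nn.RNN always applies a nonlinearity (tanh or ReLU) to its hidden state,
    so it cannot reproduce this linear recursion; the recursion is run explicitly
    over time instead. For long signals this Python loop is slow, and a compiled
    IIR filter such as torchaudio.functional.lfilter (with clamp=False) is faster.
    """

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # expects (batch, 1, time), the same layout as PreEmphasis
        assert input.dim() == 3, 'The number of dimensions of input tensor must be 3!'
        prev = torch.zeros_like(input[..., 0])
        outputs = []
        for t in range(input.size(-1)):
            # current sample plus the scaled previous output
            prev = input[..., t] + self.coef * prev
            outputs.append(prev)
        return torch.stack(outputs, dim=-1)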
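PreEmphasis relies on two details. First, `F.conv1d` computes cross-correlation, so the filter [1, -coef] is stored flipped as [-coef, 1]. Second, reflect padding prepends x[1], which makes y[0] = x[0] - coef * x[1]. The sketch below mirrors both steps and checks the flip. It uses NumPy as a stand-in for PyTorch, an assumption made so the check runs without torch. It also shows a side effect of the padding: after zero-state de-emphasis, the round trip is off by coef^(n+1) * x[1], a transient that decays geometrically.

```python
import numpy as np

coef = 0.97
x = np.random.default_rng(0).standard_normal(2000)

# Mirror PreEmphasis: reflect-pad one sample on the left (prepends x[1]),
# then cross-correlate with the flipped kernel [-coef, 1].
padded = np.pad(x, (1, 0), mode='reflect')
y = np.correlate(padded, [-coef, 1.0], mode='valid')

# Cross-correlating with the flipped kernel equals convolving with [1, -coef].
assert np.allclose(y, np.convolve(padded, [1.0, -coef], mode='valid'))

# Zero-state de-emphasis: z[n] = y[n] + coef * z[n - 1].
z = np.empty_like(y)
prev = 0.0
for n, sample in enumerate(y):
    prev = sample + coef * prev
    z[n] = prev

# Round-trip error caused by the reflect-padded first sample.
error = x - z
expected = coef ** (np.arange(len(x)) + 1) * x[1]
print(np.allclose(error, expected))  # True
print(abs(error[-1]))                # negligible after 2000 samples
```

Padding with zeros instead (mode 'constant' in `F.pad`) gives y[0] = x[0]. That matches the zero initial state of `scipy.signal.lfilter` and makes the round trip exact. For coef = 0.97 the reflect-padding transient falls below 5% of x[1] after about 100 samples.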

Additional context

Sample Plots

  • SciPy code
import numpy as np
from scipy import signal


def preemphasis(x, coeff=0.97):
    # FIR filter: y[n] = x[n] - coeff * x[n - 1]
    return signal.lfilter([1, -coeff], [1], x).astype(np.float32)


def inv_preemphasis(x, coeff=0.97):
    # IIR filter: y[n] = x[n] + coeff * y[n - 1]
    return signal.lfilter([1], [1, -coeff], x).astype(np.float32)
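  • Sample wave

[figure: sample waveform]

  • Preemphasis / deemphasis sample

[figure: pre-emphasis / de-emphasis sample 1]

[figure: pre-emphasis / de-emphasis sample 2]

  • PyTorch samples

[figure: PyTorch samples]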
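An RNN-based inverse cannot reproduce the SciPy reference when PyTorch outputs are compared against it. torch.nn.RNN updates its hidden state as h_t = tanh(x_t W_ih^T + b_ih + h_(t-1) W_hh^T + b_hh), with tanh as the default nonlinearity. If weight_ih = 1, weight_hh = coef and there is no bias, it therefore computes h[n] = tanh(y[n] + coef * h[n-1]). That is not the linear recursion that `lfilter([1], [1, -coef], y)` evaluates. The NumPy sketch below compares the two recursions directly. The function names are hypothetical, and the random input is only a stand-in for a pre-emphasized signal.

```python
import numpy as np

coef = 0.97
y = 0.5 * np.random.default_rng(0).standard_normal(500)  # stand-in pre-emphasized signal


def linear_deemphasis(y, coef):
    """Linear recursion computed by lfilter([1], [1, -coef], y): x[n] = y[n] + coef * x[n-1]."""
    out, prev = np.empty_like(y), 0.0
    for n, sample in enumerate(y):
        prev = sample + coef * prev
        out[n] = prev
    return out


def rnn_deemphasis(y, coef):
    """torch.nn.RNN hidden-state update with weight_ih = 1, weight_hh = coef, no bias
    and the default tanh: h[n] = tanh(y[n] + coef * h[n-1])."""
    out, prev = np.empty_like(y), 0.0
    for n, sample in enumerate(y):
        prev = np.tanh(sample + coef * prev)
        out[n] = prev
    return out


linear = linear_deemphasis(y, coef)
rnn = rnn_deemphasis(y, coef)
print(np.max(np.abs(linear)))        # the linear output is not bounded by 1
print(np.max(np.abs(rnn)) <= 1.0)    # True: tanh keeps the RNN output in [-1, 1]
print(np.max(np.abs(linear - rnn)))  # large mismatch between the two recursions
```

De-emphasis has a DC gain of 1 / (1 - coef), about 33 for coef = 0.97, so its output routinely leaves [-1, 1]. Wherever it does, tanh saturates and the two results diverge.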
