Description
🚀 Feature
I request Preemphasis / Deemphasis modules (wiki). In several sound-related deep-learning tasks, these modules are often used to improve sound quality.
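For reference, here are the standard first-order definitions, written with a coefficient $\alpha$ (0.97 in the code below). The notation $y$, $\hat{x}$, $H$, $G$ is introduced here for illustration:

$$
\begin{aligned}
\text{pre-emphasis:}\quad & y[n] = x[n] - \alpha\, x[n-1], & H(z) &= 1 - \alpha z^{-1} \\
\text{de-emphasis:}\quad & \hat{x}[n] = y[n] + \alpha\, \hat{x}[n-1], & G(z) &= \frac{1}{1 - \alpha z^{-1}}
\end{aligned}
$$

Because $G(z)H(z) = 1$, de-emphasis undoes pre-emphasis exactly when both start from the same (e.g. zero) initial state. For $0 < \alpha < 1$ the gain $|H(e^{j\omega})|$ rises from $1 - \alpha$ at DC to $1 + \alpha$ at Nyquist, so pre-emphasis tilts the spectrum toward high frequencies.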
Motivation
In my speech enhancement case, the model usually generates high-frequency noise when trained without preemphasis. Adding preemphasis helps a lot in improving enhancement quality. However, to use it while training models, it has to be implemented in PyTorch.
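To make "helps with high-frequency noise" concrete, the sketch below evaluates the magnitude response of the pre-emphasis filter $1 - \alpha z^{-1}$ at a few frequencies. The 16 kHz sample rate and the helper name `preemphasis_gain_db` are assumptions chosen for illustration only; the point is that low frequencies are strongly attenuated while frequencies near Nyquist are boosted, so a loss computed on pre-emphasized audio weights high-frequency errors more heavily.

```python
import math


def preemphasis_gain_db(freq_hz: float, sample_rate: float = 16000.0, coef: float = 0.97) -> float:
    """Magnitude response of H(z) = 1 - coef * z^-1 at freq_hz, in dB (hypothetical helper)."""
    # Normalized angular frequency in radians per sample.
    omega = 2.0 * math.pi * freq_hz / sample_rate
    # |1 - coef * e^{-j*omega}|^2 = 1 - 2*coef*cos(omega) + coef^2
    magnitude = math.sqrt(1.0 - 2.0 * coef * math.cos(omega) + coef * coef)
    return 20.0 * math.log10(magnitude)


if __name__ == "__main__":
    # Assumed 16 kHz audio, so 8000 Hz is the Nyquist frequency.
    for f in (100, 1000, 4000, 8000):
        print(f"{f:>5} Hz: {preemphasis_gain_db(f):+6.1f} dB")
```

The gain at Nyquist is $20\log_{10}(1.97) \approx 5.9$ dB, while DC is attenuated to $20\log_{10}(0.03)$.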
Pitch and Alternatives
I have already written a first version of preemphasis and deemphasis in PyTorch.
So, after the code below has been checked for bugs, I suggest adding these modules to transforms.py.
code snippet

```python
import torch
import torch.nn.functional as F


class PreEmphasis(torch.nn.Module):
    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # Make the kernel. In PyTorch, the convolution operation is actually
        # cross-correlation, so the filter [1, -coef] is stored flipped.
        self.register_buffer(
            'flipped_filter', torch.tensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Expected shape: (batch, 1, time)
        assert input.dim() == 3, 'The number of dimensions of input tensor must be 3!'
        # Reflect-pad one sample on the left so input and output lengths match.
        input = F.pad(input, (1, 0), 'reflect')
        return F.conv1d(input, self.flipped_filter)


class InversePreEmphasis(torch.nn.Module):
    """
    Inverse pre-emphasis (de-emphasis): y[t] = x[t] + coef * y[t - 1].

    Note: torch.nn.RNN always applies a tanh (or relu) nonlinearity, so a
    1-unit RNN with weight_ih = 1 and weight_hh = coef computes
    tanh(x[t] + coef * h[t - 1]) rather than this linear recurrence.
    The explicit loop below is slower but exact.
    """
    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Expected shape: (batch, 1, time)
        assert input.dim() == 3, 'The number of dimensions of input tensor must be 3!'
        outputs = [input[..., 0]]
        for t in range(1, input.size(-1)):
            # Add the scaled previous output to the current input sample.
            outputs.append(input[..., t] + self.coef * outputs[-1])
        return torch.stack(outputs, dim=-1)
```

Additional context
Sample Plots
- scipy codes

```python
import numpy as np
from scipy import signal


def preemphasis(x, coeff=0.97):
    # y[n] = x[n] - coeff * x[n-1]
    return signal.lfilter([1, -coeff], [1], x).astype(np.float32)


def inv_preemphasis(x, coeff=0.97):
    # y[n] = x[n] + coeff * y[n-1]
    return signal.lfilter([1], [1, -coeff], x).astype(np.float32)
```

- Sample wave (plot)
- Preemphasis / deemphasis sample (plot)
- Pytorch samples (plot)
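As a dependency-free cross-check of the scipy reference, here is a minimal plain-Python sketch of both filters with zero initial state (which is what `lfilter` uses by default), plus what a 1-unit, bias-free `torch.nn.RNN` with `weight_ih = 1` and `weight_hh = coef` computes, i.e. the RNN approach to inverse preemphasis. The function names and the test sine wave are hypothetical, chosen only for illustration.

```python
import math
from typing import List


def preemphasis(x: List[float], coef: float = 0.97) -> List[float]:
    """y[n] = x[n] - coef * x[n-1], with x[-1] = 0 (like lfilter([1, -coef], [1], x))."""
    return [x[n] - coef * (x[n - 1] if n > 0 else 0.0) for n in range(len(x))]


def deemphasis(x: List[float], coef: float = 0.97) -> List[float]:
    """y[n] = x[n] + coef * y[n-1], with y[-1] = 0 (like lfilter([1], [1, -coef], x))."""
    y: List[float] = []
    prev = 0.0
    for sample in x:
        prev = sample + coef * prev
        y.append(prev)
    return y


def rnn_tanh_deemphasis(x: List[float], coef: float = 0.97) -> List[float]:
    """Recurrence of a 1-unit, bias-free tanh RNN: h[n] = tanh(x[n] + coef * h[n-1])."""
    h: List[float] = []
    prev = 0.0
    for sample in x:
        prev = math.tanh(sample + coef * prev)
        h.append(prev)
    return h


if __name__ == "__main__":
    signal_in = [math.sin(0.05 * n) for n in range(200)]
    emphasized = preemphasis(signal_in)
    linear_err = max(abs(a - b) for a, b in zip(signal_in, deemphasis(emphasized)))
    tanh_err = max(abs(a - b) for a, b in zip(signal_in, rnn_tanh_deemphasis(emphasized)))
    print("max round-trip error (linear):", linear_err)
    print("max round-trip error (tanh RNN):", tanh_err)
```

The linear round trip recovers the input up to floating-point error, while the tanh recurrence drifts noticeably once the signal approaches full scale. Note that a reflect-padded PyTorch preemphasis produces x[0] - coef * x[1] as its first sample, so it will differ from these zero-initial-state versions at the very start of the signal.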



