33import math
44import torch
55from torch import Tensor
6- from torch .nn import functional as F
76
87import torchaudio
98import torchaudio ._internal .fft
@@ -753,71 +752,16 @@ def mfcc(
753752 return feature
754753
755754
756- def _get_sinc_resample_kernel (orig_freq : int , new_freq : int , lowpass_filter_width : int ,
757- device : torch .device , dtype : torch .dtype ):
758- assert lowpass_filter_width > 0
759- kernels = []
760- base_freq = min (orig_freq , new_freq )
761- # This will perform antialiasing filtering by removing the highest frequencies.
762- # At first I thought I only needed this when downsampling, but when upsampling
763- # you will get edge artifacts without this, as the edge is equivalent to zero padding,
764- # which will add high freq artifacts.
765- base_freq *= 0.99
766-
767- # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
768- # using the sinc interpolation formula:
769- # x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
770- # We can then sample the function x(t) with a different sample rate:
771- # y[j] = x(j / new_freq)
772- # or,
773- # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
774-
775- # We see here that y[j] is the convolution of x[i] with a specific filter, for which
776- # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
777- # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
778- # Indeed:
779- # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
780- # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
781- # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
782- # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
783- # This will explain the F.conv1d after, with a stride of orig_freq.
784- width = math .ceil (lowpass_filter_width * orig_freq / base_freq )
785- # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
786- # they will have a lot of almost zero values to the left or to the right...
787- # There is probably a way to evaluate those filters more efficiently, but this is kept for
788- # future work.
789- idx = torch .arange (- width , width + orig_freq , device = device , dtype = dtype )
790-
791- for i in range (new_freq ):
792- t = (- i / new_freq + idx / orig_freq ) * base_freq
793- t = t .clamp_ (- lowpass_filter_width , lowpass_filter_width )
794- t *= math .pi
795- # we do not use torch.hann_window here as we need to evaluate the window
796- # at specific positions, not over a regular grid.
797- window = torch .cos (t / lowpass_filter_width / 2 )** 2
798- kernel = torch .where (t == 0 , torch .tensor (1. ).to (t ), torch .sin (t ) / t )
799- kernel .mul_ (window )
800- kernels .append (kernel )
801-
802- scale = base_freq / orig_freq
803- return torch .stack (kernels ).view (new_freq , 1 , - 1 ).mul_ (scale ), width
804-
805-
806755def resample_waveform (waveform : Tensor ,
807756 orig_freq : float ,
808757 new_freq : float ,
809758 lowpass_filter_width : int = 6 ) -> Tensor :
810- r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
811- which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
812- a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
813- the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
814- upsample/downsample the signal.
759+ r"""Resamples the waveform at the new frequency.
815760
816- https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
817- https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56
761+ This is a wrapper around ``torchaudio.functional.resample``.
818762
819763 Args:
820- waveform (Tensor): The input signal of size (c, n )
764+ waveform (Tensor): The input signal of size (..., time )
821765 orig_freq (float): The original frequency of the signal
822766 new_freq (float): The desired frequency
823767 lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
@@ -826,21 +770,4 @@ def resample_waveform(waveform: Tensor,
826770 Returns:
827771 Tensor: The waveform at the new frequency
828772 """
829- assert waveform .dim () == 2
830- assert orig_freq > 0.0 and new_freq > 0.0
831-
832- orig_freq = int (orig_freq )
833- new_freq = int (new_freq )
834- gcd = math .gcd (orig_freq , new_freq )
835- orig_freq = orig_freq // gcd
836- new_freq = new_freq // gcd
837-
838- kernel , width = _get_sinc_resample_kernel (orig_freq , new_freq , lowpass_filter_width ,
839- waveform .device , waveform .dtype )
840-
841- num_wavs , length = waveform .shape
842- waveform = F .pad (waveform , (width , width + orig_freq ))
843- resampled = F .conv1d (waveform [:, None ], kernel , stride = orig_freq )
844- resampled = resampled .transpose (1 , 2 ).reshape (num_wavs , - 1 )
845- target_length = int (math .ceil (new_freq * length / orig_freq ))
846- return resampled [..., :target_length ]
773+ return torchaudio .functional .resample (waveform , orig_freq , new_freq , lowpass_filter_width )
0 commit comments