@@ -1302,9 +1302,7 @@ def _get_sinc_resample_kernel(
13021302 orig_freq : int ,
13031303 new_freq : int ,
13041304 lowpass_filter_width : int ,
1305- rolloff : float ,
1306- device : torch .device ,
1307- dtype : torch .dtype ):
1305+ rolloff : float ):
13081306 assert lowpass_filter_width > 0
13091307 kernels = []
13101308 base_freq = min (orig_freq , new_freq )
@@ -1336,7 +1334,7 @@ def _get_sinc_resample_kernel(
13361334 # they will have a lot of almost zero values to the left or to the right...
13371335 # There is probably a way to evaluate those filters more efficiently, but this is kept for
13381336 # future work.
1339- idx = torch .arange (- width , width + orig_freq , device = device , dtype = dtype )
1337+ idx = torch .arange (- width , width + orig_freq )
13401338
13411339 for i in range (new_freq ):
13421340 t = (- i / new_freq + idx / orig_freq ) * base_freq
@@ -1353,13 +1351,36 @@ def _get_sinc_resample_kernel(
13531351 return torch .stack (kernels ).view (new_freq , 1 , - 1 ).mul_ (scale ), width
13541352
13551353
def _apply_sinc_resample_kernel(
        waveform: Tensor,
        orig_freq: int,
        new_freq: int,
        kernel: Tensor,
        width: int,
) -> Tensor:
    r"""Resample ``waveform`` by convolving it with a precomputed kernel.

    Args:
        waveform (Tensor): Input signal of dimension (..., time). All leading
            dimensions are flattened into one batch dimension for the
            convolution and restored on the output.
        orig_freq (int): Original frequency, used as the convolution stride
            and as extra right padding.
        new_freq (int): Target frequency; ``kernel`` is expected to provide
            one output channel per unit of ``new_freq`` so the channels can
            be interleaved along time.
        kernel (Tensor): Convolution weight passed to ``conv1d`` with a single
            input channel, i.e. of dimension (out_channels, 1, kernel_width).
            It is moved to the device and dtype of ``waveform`` before use.
        width (int): Number of zero samples padded on the left of the signal
            (the right side is padded with ``width + orig_freq`` zeros).

    Returns:
        Tensor: Resampled signal of dimension (..., new_time), where
        ``new_time = ceil(new_freq * time / orig_freq)``.
    """
    # pack batch
    # ``reshape`` (rather than ``view``) so that non-contiguous inputs, e.g.
    # transposed or permuted tensors, are accepted instead of raising.
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])
    kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)

    num_wavs, length = waveform.shape
    waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
    # Add a channel dimension; each output channel of the kernel produces one
    # of the interleaved output phases.
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
    # (batch, channels, frames) -> (batch, frames, channels) -> (batch, frames * channels)
    # interleaves the channels along the time axis.
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    # Drop the surplus samples produced by the padding.
    target_length = int(math.ceil(new_freq * length / orig_freq))
    resampled = resampled[..., :target_length]

    # unpack batch
    resampled = resampled.reshape(shape[:-1] + resampled.shape[-1:])
    return resampled
1376+
1377+
13561378def resample (
13571379 waveform : Tensor ,
13581380 orig_freq : float ,
13591381 new_freq : float ,
13601382 lowpass_filter_width : int = 6 ,
13611383 rolloff : float = 0.99 ,
1362- kernel : Tensor = None
13631384) -> Tensor :
13641385 r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
13651386 which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
@@ -1378,20 +1399,13 @@ def resample(
13781399 but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
13791400 rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
13801401 Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
1381- kernel (Tensor, optional): Tensor of dimension (f, 1, w) representing the windowed sinc function that is
1382- used in convolution to calculate the resampled waveform. ``f = new_freq_gcd`` and ``w = 2 *
1383- math.ceil(lowpass_filter_width * (orig_freq_gcd) / (rolloff * min(orig_freq_gcd, new_freq_gcd)) + orig_freq_gcd``,
1384- where ``new_freq_gcd`` and ``orig_freq_gcd`` are equal to ``new_freq // gcd`` and ``old_freq // gcd``
13851402
13861403 Returns:
13871404 Tensor: The waveform at the new frequency of dimension (..., time).
13881405
1389- Note: transforms.Resample passes in a precomputed kernel, which will result in more efficient computation if reusing
1390- the same set of resampling parameters to resample multiple waveforms .
1406+ Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
1407+ more efficient computation if resampling multiple waveforms with the same resampling parameters.
13911408 """
1392- # pack batch
1393- shape = waveform .size ()
1394- waveform = waveform .view (- 1 , shape [- 1 ])
13951409
13961410 assert orig_freq > 0.0 and new_freq > 0.0
13971411
@@ -1414,22 +1428,7 @@ def resample(
14141428 orig_freq = orig_freq // gcd
14151429 new_freq = new_freq // gcd
14161430
1417- if kernel == None :
1418- kernel , width = _get_sinc_resample_kernel (orig_freq , new_freq , lowpass_filter_width ,
1419- rolloff , waveform .device , waveform .dtype )
1420- else :
1421- base_freq = min (orig_freq , new_freq ) * rolloff
1422- width = math .ceil (lowpass_filter_width * orig_freq / base_freq )
1423- assert kernel .shape [0 ] == new_freq
1424- assert kernel .shape [2 ] == 2 * width + orig_freq
1425-
1426- num_wavs , length = waveform .shape
1427- waveform = torch .nn .functional .pad (waveform , (width , width + orig_freq ))
1428- resampled = torch .nn .functional .conv1d (waveform [:, None ], kernel , stride = orig_freq )
1429- resampled = resampled .transpose (1 , 2 ).reshape (num_wavs , - 1 )
1430- target_length = int (math .ceil (new_freq * length / orig_freq ))
1431- resampled = resampled [..., :target_length ]
1432-
1433- # unpack batch
1434- resampled = resampled .view (shape [:- 1 ] + resampled .shape [- 1 :])
1431+ kernel , width = _get_sinc_resample_kernel (orig_freq , new_freq , lowpass_filter_width ,
1432+ rolloff , waveform .device , waveform .dtype )
1433+ resampled = _apply_sinc_resample_kernel (waveform , orig_freq , new_freq , kernel , width )
14351434 return resampled
0 commit comments