Skip to content

Commit bf0979c

Browse files
author
Caroline Chen
committed
restructure internals
1 parent 1d51496 commit bf0979c

File tree

2 files changed

+43
-43
lines changed

2 files changed

+43
-43
lines changed

torchaudio/functional/functional.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,9 +1302,7 @@ def _get_sinc_resample_kernel(
13021302
orig_freq: int,
13031303
new_freq: int,
13041304
lowpass_filter_width: int,
1305-
rolloff: float,
1306-
device: torch.device,
1307-
dtype: torch.dtype):
1305+
rolloff: float):
13081306
assert lowpass_filter_width > 0
13091307
kernels = []
13101308
base_freq = min(orig_freq, new_freq)
@@ -1336,7 +1334,7 @@ def _get_sinc_resample_kernel(
13361334
# they will have a lot of almost zero values to the left or to the right...
13371335
# There is probably a way to evaluate those filters more efficiently, but this is kept for
13381336
# future work.
1339-
idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
1337+
idx = torch.arange(-width, width + orig_freq)
13401338

13411339
for i in range(new_freq):
13421340
t = (-i / new_freq + idx / orig_freq) * base_freq
@@ -1353,13 +1351,36 @@ def _get_sinc_resample_kernel(
13531351
return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width
13541352

13551353

1354+
def _apply_sinc_resample_kernel(
1355+
waveform: Tensor,
1356+
orig_freq: int,
1357+
new_freq: int,
1358+
kernel: Tensor,
1359+
width: int,
1360+
):
1361+
# pack batch
1362+
shape = waveform.size()
1363+
waveform = waveform.view(-1, shape[-1])
1364+
kernel = kernel.to(device=waveform.device, dtype=waveform.dtype)
1365+
1366+
num_wavs, length = waveform.shape
1367+
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
1368+
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
1369+
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
1370+
target_length = int(math.ceil(new_freq * length / orig_freq))
1371+
resampled = resampled[..., :target_length]
1372+
1373+
# unpack batch
1374+
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
1375+
return resampled
1376+
1377+
13561378
def resample(
13571379
waveform: Tensor,
13581380
orig_freq: float,
13591381
new_freq: float,
13601382
lowpass_filter_width: int = 6,
13611383
rolloff: float = 0.99,
1362-
kernel: Tensor = None
13631384
) -> Tensor:
13641385
r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
13651386
which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
@@ -1378,20 +1399,13 @@ def resample(
13781399
but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)
13791400
rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
13801401
Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
1381-
kernel (Tensor, optional): Tensor of dimension (f, 1, w) representing the windowed sinc function that is
1382-
used in convolution to calculate the resampled waveform. ``f = new_freq_gcd`` and ``w = 2 *
1383-
math.ceil(lowpass_filter_width * (orig_freq_gcd) / (rolloff * min(orig_freq_gcd, new_freq_gcd)) + orig_freq_gcd``,
1384-
where ``new_freq_gcd`` and ``orig_freq_gcd`` are equal to ``new_freq // gcd`` and ``old_freq // gcd``
13851402
13861403
Returns:
13871404
Tensor: The waveform at the new frequency of dimension (..., time).
13881405
1389-
Note: transforms.Resample passes in a precomputed kernel, which will result in more efficient computation if reusing
1390-
the same set of resampling parameters to resample multiple waveforms.
1406+
Note: ``transforms.Resample` precomputes and reuses the resampling kernel, so using it will result in
1407+
more efficient computation if resampling multiple waveforms with the same resampling parameters.
13911408
"""
1392-
# pack batch
1393-
shape = waveform.size()
1394-
waveform = waveform.view(-1, shape[-1])
13951409

13961410
assert orig_freq > 0.0 and new_freq > 0.0
13971411

@@ -1414,22 +1428,7 @@ def resample(
14141428
orig_freq = orig_freq // gcd
14151429
new_freq = new_freq // gcd
14161430

1417-
if kernel == None:
1418-
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
1419-
rolloff, waveform.device, waveform.dtype)
1420-
else:
1421-
base_freq = min(orig_freq, new_freq) * rolloff
1422-
width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
1423-
assert kernel.shape[0] == new_freq
1424-
assert kernel.shape[2] == 2 * width + orig_freq
1425-
1426-
num_wavs, length = waveform.shape
1427-
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
1428-
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
1429-
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
1430-
target_length = int(math.ceil(new_freq * length / orig_freq))
1431-
resampled = resampled[..., :target_length]
1432-
1433-
# unpack batch
1434-
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
1431+
kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
1432+
rolloff, waveform.device, waveform.dtype)
1433+
resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, kernel, width)
14351434
return resampled

torchaudio/transforms.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from torch import Tensor
99
from torchaudio import functional as F
1010

11-
from .functional.functional import _get_sinc_resample_kernel
11+
from .functional.functional import (
12+
_get_sinc_resample_kernel,
13+
_apply_sinc_resample_kernel,
14+
)
1215

1316
__all__ = [
1417
'Spectrogram',
@@ -654,12 +657,15 @@ def __init__(self,
654657
lowpass_filter_width: int = 6,
655658
rolloff: float = 0.99) -> None:
656659
super(Resample, self).__init__()
657-
self.orig_freq = orig_freq
658-
self.new_freq = new_freq
660+
self.orig_freq = int(orig_freq)
661+
self.new_freq = int(new_freq)
662+
self.gcd = math.gcd(self.orig_freq, self.new_freq)
659663
self.resampling_method = resampling_method
660664
self.lowpass_filter_width = lowpass_filter_width
661665
self.rolloff = rolloff
662-
self.kernel = None
666+
667+
self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq // self.gcd, self.new_freq // self.gcd,
668+
self.lowpass_filter_width, self.rolloff)
663669

664670
def forward(self, waveform: Tensor) -> Tensor:
665671
r"""
@@ -670,13 +676,8 @@ def forward(self, waveform: Tensor) -> Tensor:
670676
Tensor: Output signal of dimension (..., time).
671677
"""
672678
if self.resampling_method == 'sinc_interpolation':
673-
if self.kernel == None:
674-
gcd = math.gcd(self.orig_freq, self.new_freq)
675-
orig_freq = self.orig_freq // gcd
676-
new_freq = self.new_freq // gcd
677-
self.kernel, _ = _get_sinc_resample_kernel(orig_freq, new_freq, self.lowpass_filter_width,
678-
self.rolloff, waveform.device, waveform.dtype)
679-
return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff, self.kernel)
679+
return _apply_sinc_resample_kernel(waveform, self.orig_freq // self.gcd, self.new_freq // self.gcd,
680+
self.kernel, self.width)
680681

681682
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
682683

0 commit comments

Comments
 (0)