diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index a1309fb21a..ebab3d1e0b 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -1299,12 +1299,28 @@ def compute_kaldi_pitch( def _get_sinc_resample_kernel( - orig_freq: int, - new_freq: int, + orig_freq: float, + new_freq: float, + gcd: int, lowpass_filter_width: int, - rolloff: float, - device: torch.device, - dtype: torch.dtype): + rolloff: float): + + if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq): + warnings.warn( + "Non-integer frequencies are being cast to ints and may result in poor resampling quality " + "because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. " + "Using non-integer valued frequencies will throw an error in the next release. " + "To work around this issue, manually convert both frequencies to integer values " + "that maintain their resampling rate ratio before passing them into the function " + "Example: To downsample a 44100 hz waveform by a factor of 8, use " + "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` " + "For more information or to leave feedback about this change, please refer to " + "https://github.com/pytorch/audio/issues/1487." + ) + + orig_freq = int(orig_freq) // gcd + new_freq = int(new_freq) // gcd + assert lowpass_filter_width > 0 kernels = [] base_freq = min(orig_freq, new_freq) @@ -1336,7 +1352,7 @@ def _get_sinc_resample_kernel( # they will have a lot of almost zero values to the left or to the right... # There is probably a way to evaluate those filters more efficiently, but this is kept for # future work. - idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype) + idx = torch.arange(-width, width + orig_freq) for i in range(new_freq): t = (-i / new_freq + idx / orig_freq) * base_freq @@ -1353,6 +1369,34 @@ def _get_sinc_resample_kernel( return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width +def _apply_sinc_resample_kernel( + waveform: Tensor, + orig_freq: float, + new_freq: float, + gcd: int, + kernel: Tensor, + width: int, +): + orig_freq = int(orig_freq) // gcd + new_freq = int(new_freq) // gcd + + # pack batch + shape = waveform.size() + waveform = waveform.view(-1, shape[-1]) + kernel = kernel.to(device=waveform.device, dtype=waveform.dtype) + + num_wavs, length = waveform.shape + waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq)) + resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq) + resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) + target_length = int(math.ceil(new_freq * length / orig_freq)) + resampled = resampled[..., :target_length] + + # unpack batch + resampled = resampled.view(shape[:-1] + resampled.shape[-1:]) + return resampled + + def resample( waveform: Tensor, orig_freq: float, @@ -1380,42 +1424,15 @@ def resample( Returns: Tensor: The waveform at the new frequency of dimension (..., time). + + Note: ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in + more efficient computation if resampling multiple waveforms with the same resampling parameters. """ - # pack batch - shape = waveform.size() - waveform = waveform.view(-1, shape[-1]) assert orig_freq > 0.0 and new_freq > 0.0 - if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq): - warnings.warn( - "Non-integer frequencies are being cast to ints and may result in poor resampling quality " - "because the underlying algorithm requires an integer ratio between `orig_freq` and `new_freq`. " - "Using non-integer valued frequencies will throw an error in the next release. " - "To work around this issue, manually convert both frequencies to integer values " - "that maintain their resampling rate ratio before passing them into the function " - "Example: To downsample a 44100 hz waveform by a factor of 8, use " - "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5` " - "For more information or to leave feedback about this change, please refer to " - "https://github.com/pytorch/audio/issues/1487." - ) - - orig_freq = int(orig_freq) - new_freq = int(new_freq) - gcd = math.gcd(orig_freq, new_freq) - orig_freq = orig_freq // gcd - new_freq = new_freq // gcd - - kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width, - rolloff, waveform.device, waveform.dtype) - - num_wavs, length = waveform.shape - waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq)) - resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq) - resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) - target_length = int(math.ceil(new_freq * length / orig_freq)) - resampled = resampled[..., :target_length] + gcd = math.gcd(int(orig_freq), int(new_freq)) - # unpack batch - resampled = resampled.view(shape[:-1] + resampled.shape[-1:]) + kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, gcd, lowpass_filter_width, rolloff) + resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width) return resampled diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index c7a72c55a7..2a273e95e1 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -8,6 +8,10 @@ from torch import Tensor from torchaudio import functional as F +from .functional.functional import ( + _get_sinc_resample_kernel, + _apply_sinc_resample_kernel, +) __all__ = [ 'Spectrogram', @@ -647,18 +651,23 @@ class Resample(torch.nn.Module): """ def __init__(self, - orig_freq: int = 16000, - new_freq: int = 16000, + orig_freq: float = 16000, + new_freq: float = 16000, resampling_method: str = 'sinc_interpolation', lowpass_filter_width: int = 6, rolloff: float = 0.99) -> None: super(Resample, self).__init__() + self.orig_freq = orig_freq self.new_freq = new_freq + self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq)) self.resampling_method = resampling_method self.lowpass_filter_width = lowpass_filter_width self.rolloff = rolloff + self.kernel, self.width = _get_sinc_resample_kernel(self.orig_freq, self.new_freq, self.gcd, + self.lowpass_filter_width, self.rolloff) + def forward(self, waveform: Tensor) -> Tensor: r""" Args: @@ -668,7 +677,8 @@ def forward(self, waveform: Tensor) -> Tensor: Tensor: Output signal of dimension (..., time). """ if self.resampling_method == 'sinc_interpolation': - return F.resample(waveform, self.orig_freq, self.new_freq, self.lowpass_filter_width, self.rolloff) + return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd, + self.kernel, self.width) raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))