diff --git a/test/test_functional.py b/test/test_functional.py index a8e28f523a..1943b813d1 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -172,7 +172,7 @@ def test_istft_of_zeros(self): def test_istft_requires_overlap_windows(self): # the window is size 1 but it hops 20 so there is a gap which throw an error stft = torch.zeros((3, 5, 2)) - self.assertRaises(AssertionError, torchaudio.functional.istft, stft, n_fft=4, + self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, n_fft=4, hop_length=20, win_length=1, window=torch.ones(1)) def test_istft_requires_nola(self): @@ -192,11 +192,11 @@ def test_istft_requires_nola(self): # A window of ones meets NOLA but a window of zeros does not. This should # throw an error. torchaudio.functional.istft(stft, **kwargs_ok) - self.assertRaises(AssertionError, torchaudio.functional.istft, stft, **kwargs_not_ok) + self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok) def test_istft_requires_non_empty(self): - self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2) - self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2) + self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2) + self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2) def _test_istft_of_sine(self, amplitude, L, n): # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 2245768b99..bf5665adfb 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -2,6 +2,7 @@ import math from typing import Optional, Tuple +import warnings import torch from torch import Tensor @@ -49,7 +50,7 @@ def istft( win_length: Optional[int] = None, window: Optional[Tensor] = None, center: bool = True, - pad_mode: str = "reflect", + pad_mode: Optional[str] = None, normalized: bool = False, onesided: bool = True, length: Optional[int] = None, @@ -94,8 +95,7 @@ def istft( center (bool, optional): Whether ``input`` was padded on both sides so that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. (Default: ``True``) - pad_mode (str, optional): Controls the padding method used when ``center`` is True. (Default: - ``"reflect"``) + pad_mode: This argument was ignored and to be removed. normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``) onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``) length (int or None, optional): The amount to trim the signal by (i.e. the @@ -104,105 +104,16 @@ def istft( Returns: Tensor: Least squares estimation of the original signal of size (..., signal_length) """ - stft_matrix_dim = stft_matrix.dim() - assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim) - assert stft_matrix.numel() > 0 - - if stft_matrix_dim == 3: - # add a channel dimension - stft_matrix = stft_matrix.unsqueeze(0) - - # pack batch - shape = stft_matrix.size() - stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1]) - - dtype = stft_matrix.dtype - device = stft_matrix.device - fft_size = stft_matrix.size(1) - assert (onesided and n_fft // 2 + 1 == fft_size) or ( - not onesided and n_fft == fft_size - ), ( - "one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. " - + "Given values were onesided: %s, n_fft: %d, fft_size: %d" - % ("True" if onesided else False, n_fft, fft_size) - ) - - # use stft defaults for Optionals - if win_length is None: - win_length = n_fft - - if hop_length is None: - hop_length = int(win_length // 4) - - # There must be overlap - assert 0 < hop_length <= win_length - assert 0 < win_length <= n_fft - - if window is None: - window = torch.ones(win_length, device=device, dtype=dtype) - - assert window.dim() == 1 and window.size(0) == win_length - - if win_length != n_fft: - # center window with pad left and right zeros - left = (n_fft - win_length) // 2 - window = torch.nn.functional.pad(window, (left, n_fft - win_length - left)) - assert window.size(0) == n_fft - # win_length and n_fft are synonymous from here on - - stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frame, fft_size, 2) - stft_matrix = torch.irfft( - stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,) - ) # size (channel, n_frame, n_fft) - - assert stft_matrix.size(2) == n_fft - n_frame = stft_matrix.size(1) - - ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frame, n_fft) - # each column of a channel is a frame which needs to be overlap added at the right place - ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frame) - - # this does overlap add where the frames of ytmp are added such that the i'th frame of - # ytmp is added starting at i*hop_length in the output - y = torch.nn.functional.fold( - ytmp, (1, (n_frame - 1) * hop_length + n_fft), (1, n_fft), stride=(1, hop_length) - ).squeeze(2) - - # do the same for the window function - window_sq = ( - window.pow(2).view(n_fft, 1).repeat((1, n_frame)).unsqueeze(0) - ) # size (1, n_fft, n_frame) - window_envelop = torch.nn.functional.fold( - window_sq, (1, (n_frame - 1) * hop_length + n_fft), (1, n_fft), stride=(1, hop_length) - ).squeeze(2) # size (1, 1, expected_signal_len) - - expected_signal_len = n_fft + hop_length * (n_frame - 1) - assert y.size(2) == expected_signal_len - assert window_envelop.size(2) == expected_signal_len - - half_n_fft = n_fft // 2 - # we need to trim the front padding away if center - start = half_n_fft if center else 0 - end = -half_n_fft if length is None else start + length - - y = y[:, :, start:end] - window_envelop = window_envelop[:, :, start:end] - - # check NOLA non-zero overlap condition - window_envelop_lowest = window_envelop.abs().min() - assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % ( - window_envelop_lowest - ) - - y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len) - - # unpack batch - y = y.reshape(shape[:-3] + y.shape[-1:]) - - if stft_matrix_dim == 3: # remove the channel dimension - y = y.squeeze(0) - - return y + warnings.warn( + 'istft has been moved to PyTorch and will be removed from torchaudio, ' + 'please use torch.istft instead.') + if pad_mode is not None: + warnings.warn( + 'The parameter `pad_mode` was ignored in isftft, and is thus being deprecated. ' + 'Please set `pad_mode` to None to suppress this warning.') + return torch.istft( + input=stft_matrix, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, + center=center, normalized=normalized, onesided=onesided, length=length) def spectrogram(