diff --git a/test/test_functional.py b/test/test_functional.py new file mode 100644 index 0000000000..674baf0b21 --- /dev/null +++ b/test/test_functional.py @@ -0,0 +1,202 @@ +import math + +import torch +import torchaudio +import unittest +import test.common_utils + + +class TestFunctional(unittest.TestCase): + data_sizes = [(2, 20), (3, 15)] + number_of_trials = 100 + stored_rand_data = [] + fixed_precision = int(1e10) + + def setUp(self): + # we want to make sure that the random values are reproducible + self.stored_rand_data.clear() + torch.manual_seed(0) + for data_size in self.data_sizes: + rand_data1 = torch.randint(low=-self.fixed_precision, high=self.fixed_precision, size=data_size) + rand_data2 = torch.randint(low=-self.fixed_precision, high=self.fixed_precision, size=data_size) + self.stored_rand_data.append([rand_data1, rand_data2]) + + def _get_random_tensor(self, i): + # gets a random tensor of size data_sizes[i]. adds to previous tensors and then mods it. + rand_data = self.stored_rand_data[i] + rand_data3 = (rand_data[0] + rand_data[1]) % self.fixed_precision + rand_data.pop(0) + rand_data.append(rand_data3) + return rand_data3.float() / self.fixed_precision + + def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): + # trim sound for case when constructed signal is shorter than original + sound = sound[..., :estimate.size(-1)] + + self.assertTrue(sound.shape == estimate.shape, (sound.shape, estimate.shape)) + self.assertTrue(torch.allclose(sound, estimate, atol=atol, rtol=rtol)) + + def _test_istft_is_inverse_of_stft(self, kwargs): + # generates a random sound signal for each tril and then does the stft/istft + # operation to check whether we can reconstruct signal + for i in range(len(self.data_sizes)): + for _ in range(self.number_of_trials): + sound = self._get_random_tensor(i) + + stft = torch.stft(sound, **kwargs) + estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs) + + self._compare_estimate(sound, estimate) + + def test_istft_is_inverse_of_stft1(self): + # hann_window, centered, normalized, onesided + kwargs1 = { + 'n_fft': 12, + 'hop_length': 4, + 'win_length': 12, + 'window': torch.hann_window(12), + 'center': True, + 'pad_mode': 'reflect', + 'normalized': True, + 'onesided': True, + } + + self._test_istft_is_inverse_of_stft(kwargs1) + + def test_istft_is_inverse_of_stft2(self): + # hann_window, centered, not normalized, not onesided + kwargs2 = { + 'n_fft': 12, + 'hop_length': 2, + 'win_length': 8, + 'window': torch.hann_window(8), + 'center': True, + 'pad_mode': 'reflect', + 'normalized': False, + 'onesided': False, + } + + self._test_istft_is_inverse_of_stft(kwargs2) + + def test_istft_is_inverse_of_stft3(self): + # hamming_window, centered, normalized, not onesided + kwargs3 = { + 'n_fft': 15, + 'hop_length': 3, + 'win_length': 11, + 'window': torch.hamming_window(11), + 'center': True, + 'pad_mode': 'constant', + 'normalized': True, + 'onesided': False, + } + + self._test_istft_is_inverse_of_stft(kwargs3) + + def test_istft_is_inverse_of_stft4(self): + # hamming_window, not centered, not normalized, onesided + # window same size as n_fft + kwargs4 = { + 'n_fft': 5, + 'hop_length': 2, + 'win_length': 5, + 'window': torch.hamming_window(5), + 'center': False, + 'pad_mode': 'constant', + 'normalized': False, + 'onesided': True, + } + + self._test_istft_is_inverse_of_stft(kwargs4) + + def test_istft_is_inverse_of_stft5(self): + # hamming_window, not centered, not normalized, not onesided + # window same size as n_fft + kwargs5 = { + 'n_fft': 3, + 'hop_length': 2, + 'win_length': 3, + 'window': torch.hamming_window(3), + 'center': False, + 'pad_mode': 'reflect', + 'normalized': False, + 'onesided': False, + } + + self._test_istft_is_inverse_of_stft(kwargs5) + + def test_istft_of_ones(self): + # stft = torch.stft(torch.ones(4), 4) + stft = torch.tensor([ + [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]], + [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], + [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]] + ]) + + estimate = torchaudio.functional.istft(stft, n_fft=4, length=4) + self._compare_estimate(torch.ones(4), estimate) + + def test_istft_of_zeros(self): + # stft = torch.stft(torch.zeros(4), 4) + stft = torch.zeros((3, 5, 2)) + + estimate = torchaudio.functional.istft(stft, n_fft=4, length=4) + self._compare_estimate(torch.zeros(4), estimate) + + def test_istft_requires_overlap_windows(self): + # the window is size 1 but it hops 20 so there is a gap which throw an error + stft = torch.rand((3, 5, 2)) + self.assertRaises(AssertionError, torchaudio.functional.istft, stft, n_fft=4, + hop_length=20, win_length=1, window=torch.ones(1)) + + def test_istft_requires_nola(self): + stft = torch.zeros((3, 5, 2)) + kwargs_ok = { + 'n_fft': 4, + 'win_length': 4, + 'window': torch.ones(4), + } + + kwargs_not_ok = { + 'n_fft': 4, + 'win_length': 4, + 'window': torch.zeros(4), + } + + # A window of ones meets NOLA but a window of zeros does not. This should + # throw an error. + torchaudio.functional.istft(stft, **kwargs_ok) + self.assertRaises(AssertionError, torchaudio.functional.istft, stft, **kwargs_not_ok) + + def _test_istft_of_sine(self, amplitude, L, n): + # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L + x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype()) + sound = amplitude * torch.sin(2 * math.pi / L * x * n) + # stft = torch.stft(sound, L, hop_length=L, win_length=L, + # window=torch.ones(L), center=False, normalized=False) + stft = torch.zeros((L // 2 + 1, 2, 2)) + stft_largest_val = (amplitude * L) / 2.0 + if n < stft.size(0): + stft[n, :, 1] = -stft_largest_val + + if 0 <= L - n < stft.size(0): + # symmetric about L // 2 + stft[L - n, :, 1] = stft_largest_val + + estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L, + window=torch.ones(L), center=False, normalized=False) + # There is a larger error due to the scaling of amplitude + self._compare_estimate(sound, estimate, atol=1e-3) + + def test_istft_of_sine(self): + self._test_istft_of_sine(amplitude=123, L=5, n=1) + self._test_istft_of_sine(amplitude=150, L=5, n=2) + self._test_istft_of_sine(amplitude=111, L=5, n=3) + self._test_istft_of_sine(amplitude=160, L=7, n=4) + self._test_istft_of_sine(amplitude=145, L=8, n=5) + self._test_istft_of_sine(amplitude=80, L=9, n=6) + self._test_istft_of_sine(amplitude=99, L=10, n=7) + + +if __name__ == '__main__': + unittest.main() diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 95f583b4a8..a7fb5298c8 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -7,6 +7,7 @@ 'pad_trim', 'downmix_mono', 'LC2CL', + 'istft', 'spectrogram', 'create_fb_matrix', 'spectrogram_to_DB', @@ -105,6 +106,138 @@ def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normal return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided) +def istft(stft_matrix, # type: Tensor + n_fft, # type: int + hop_length=None, # type: Optional[int] + win_length=None, # type: Optional[int] + window=None, # type: Optional[Tensor] + center=True, # type: bool + pad_mode='reflect', # type: str + normalized=False, # type: bool + onesided=True, # type: bool + length=None # type: Optional[int] + ): + # type: (...) -> Tensor + r""" Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft. + It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the + least squares estimation of the original signal. The algorithm will check using the NOLA condition ( + nonzero overlap). + Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelop + created by the summation of all the windows is never zero at certain point in time. Specifically, + :math:`\sum_{t=-\ infty}^{\ infty} w^2[n-t\times hop\_length] \neq 0`. + Since stft discards elements at the end of the signal if they do not fit in a frame, the + istft may return a shorter signal than the original signal (can occur if :attr:`center` is False + since the signal isn't padded). + If :attr:`center` is True, then there will be padding e.g. 'constant', 'reflect', etc. Left padding + can be trimmed off exactly because they can be calculated but right padding cannot be calculated + without additional information. + Example: Suppose the last window is: + [17, 18, 0, 0, 0] vs [18, 0, 0, 0, 0] + The n_frames, hop_length, win_length are all the same which prevents the calculation of right padding. + These additional values could be zeros or a reflection of the signal so providing :attr:`length` + could be useful. If :attr:`length` is None then padding will be aggressively removed (some loss of signal). + [1] D. W. Griffin and J. S. Lim, “Signal estimation from modified short-time Fourier transform,” + IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984. + Inputs: + stft_matrix (Tensor): output of stft where each row of a batch is a frequency and each column is + a window. it has a shape of either (batch, fft_size, n_frames, 2) or (fft_size, n_frames, 2) + n_fft (int): size of Fourier transform + hop_length (Optional[int]): the distance between neighboring sliding window frames. (Default: win_length // 4) + win_length (Optional[int]): the size of window frame and STFT filter. (Default: n_fft) + window (Optional[Tensor]): the optional window function. (Default: torch.ones(win_length)) + center (bool): whether :attr:`input` was padded on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}` + pad_mode (str): controls the padding method used when :attr:`center` is ``True`` + normalized (bool): whether the STFT was normalized + onesided (bool): whether the STFT is onesided + length (Optional[int]): the amount to trim the signal by (i.e. the + original signal length). (Default: whole signal) + Outputs: + Tensor: least squares estimation of the original signal of size (batch, signal_length) or (signal_length) + """ + stft_matrix_dim = stft_matrix.dim() + assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim)) + + if stft_matrix_dim == 3: + # add a batch dimension + stft_matrix = stft_matrix.unsqueeze(0) + + device = stft_matrix.device + fft_size = stft_matrix.size(1) + assert (onesided and n_fft // 2 + 1 == fft_size) or (not onesided and n_fft == fft_size), ( + 'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. ' + + 'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size)) + + # use stft defaults for Optionals + if win_length is None: + win_length = n_fft + + if hop_length is None: + hop_length = int(win_length // 4) + + # There must be overlap + assert 0 < hop_length <= win_length + assert 0 < win_length <= n_fft + + if window is None: + window = torch.ones(win_length) + + assert window.dim() == 1 and window.size(0) == win_length + + if win_length != n_fft: + # center window with pad left and right zeros + left = (n_fft - win_length) // 2 + window = torch.nn.functional.pad(window, (left, n_fft - win_length - left)) + assert window.size(0) == n_fft + # win_length and n_fft are synonymous from here on + + stft_matrix = stft_matrix.transpose(1, 2) # size (batch, n_frames, fft_size, 2) + stft_matrix = torch.irfft(stft_matrix, 1, normalized, + onesided, signal_sizes=(n_fft,)) # size (batch, n_frames, n_fft) + + assert stft_matrix.size(2) == n_fft + n_frames = stft_matrix.size(1) + + ytmp = stft_matrix * window.view(1, 1, n_fft) # size (batch, n_frames, n_fft) + # each column of a batch is a frame which needs to be overlap added at the right place + ytmp = ytmp.transpose(1, 2) # size (batch, n_fft, n_frames) + + eye = torch.eye(n_fft, requires_grad=False, + device=device).unsqueeze(1) # size (n_fft, 1, n_fft) + + # this does overlap add where the frames of ytmp are added such that the i'th frame of + # ytmp is added starting at i*hop_length in the output + y = torch.nn.functional.conv_transpose1d( + ytmp, eye, stride=hop_length, padding=0) # size (batch, 1, expected_signal_len) + + # do the same for the window function + window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0) # size (1, n_fft, n_frames) + window_envelop = torch.nn.functional.conv_transpose1d( + window_sq, eye, stride=hop_length, padding=0) # size (1, 1, expected_signal_len) + + expected_signal_len = n_fft + hop_length * (n_frames - 1) + assert y.size(2) == expected_signal_len + assert window_envelop.size(2) == expected_signal_len + + half_n_fft = n_fft // 2 + # we need to trim the front padding away if center + start = half_n_fft if center else 0 + end = -half_n_fft if length is None else start + length + + y = y[:, :, start:end] + window_envelop = window_envelop[:, :, start:end] + + # check NOLA non-zero overlap condition + window_envelop_lowest = window_envelop.abs().min() + assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest)) + + y = (y / window_envelop).squeeze(1) # size (batch, expected_signal_len) + + if stft_matrix_dim == 3: # remove the batch dimension + y = y.squeeze(0) + return y + + @torch.jit.script def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize): # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor