From 3ab8169167eb27cf05347b430171e9ed84f4297c Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 18 Mar 2020 09:32:56 -0700 Subject: [PATCH] Use istft from torch --- test/common_utils.py | 8 ++- torchaudio/functional.py | 107 +++------------------------------------ 2 files changed, 13 insertions(+), 102 deletions(-) diff --git a/test/common_utils.py b/test/common_utils.py index a79f413d2b..3d09cc4697 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -1,7 +1,13 @@ from __future__ import absolute_import, division, print_function, unicode_literals import os from shutil import copytree -import backports.tempfile as tempfile +import platform +_MAJOR, _MINOR, _PATCH = platform.python_version_tuple() +if _MAJOR == '2' or (_MAJOR == '3' and int(_MINOR) < 7): + import backports.tempfile as tempfile +else: + import tempfile + import torch TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index c8763ad3c2..08ed03cb3e 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -3,6 +3,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import math +import warnings import torch @@ -120,107 +121,11 @@ def istft( Returns: torch.Tensor: Least squares estimation of the original signal of size (..., signal_length) """ - stft_matrix_dim = stft_matrix.dim() - assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim) - assert stft_matrix.numel() > 0 - - if stft_matrix_dim == 3: - # add a channel dimension - stft_matrix = stft_matrix.unsqueeze(0) - - # pack batch - shape = stft_matrix.size() - stft_matrix = stft_matrix.view(-1, shape[-3], shape[-2], shape[-1]) - - dtype = stft_matrix.dtype - device = stft_matrix.device - fft_size = stft_matrix.size(1) - assert (onesided and n_fft // 2 + 1 == fft_size) or ( - not onesided and n_fft == fft_size - ), ( - "one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. " - + "Given values were onesided: %s, n_fft: %d, fft_size: %d" - % ("True" if onesided else False, n_fft, fft_size) - ) - - # use stft defaults for Optionals - if win_length is None: - win_length = n_fft - - if hop_length is None: - hop_length = int(win_length // 4) - - # There must be overlap - assert 0 < hop_length <= win_length - assert 0 < win_length <= n_fft - - if window is None: - window = torch.ones(win_length, device=device, dtype=dtype) - - assert window.dim() == 1 and window.size(0) == win_length - - if win_length != n_fft: - # center window with pad left and right zeros - left = (n_fft - win_length) // 2 - window = torch.nn.functional.pad(window, (left, n_fft - win_length - left)) - assert window.size(0) == n_fft - # win_length and n_fft are synonymous from here on - - stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frame, fft_size, 2) - stft_matrix = torch.irfft( - stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,) - ) # size (channel, n_frame, n_fft) - - assert stft_matrix.size(2) == n_fft - n_frame = stft_matrix.size(1) - - ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frame, n_fft) - # each column of a channel is a frame which needs to be overlap added at the right place - ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frame) - - eye = torch.eye(n_fft, device=device, dtype=dtype).unsqueeze(1) # size (n_fft, 1, n_fft) - - # this does overlap add where the frames of ytmp are added such that the i'th frame of - # ytmp is added starting at i*hop_length in the output - y = torch.nn.functional.conv_transpose1d( - ytmp, eye, stride=hop_length, padding=0 - ) # size (channel, 1, expected_signal_len) - - # do the same for the window function - window_sq = ( - window.pow(2).view(n_fft, 1).repeat((1, n_frame)).unsqueeze(0) - ) # size (1, n_fft, n_frame) - window_envelop = torch.nn.functional.conv_transpose1d( - window_sq, eye, stride=hop_length, padding=0 - ) # size (1, 1, expected_signal_len) - - expected_signal_len = n_fft + hop_length * (n_frame - 1) - assert y.size(2) == expected_signal_len - assert window_envelop.size(2) == expected_signal_len - - half_n_fft = n_fft // 2 - # we need to trim the front padding away if center - start = half_n_fft if center else 0 - end = -half_n_fft if length is None else start + length - - y = y[:, :, start:end] - window_envelop = window_envelop[:, :, start:end] - - # check NOLA non-zero overlap condition - window_envelop_lowest = window_envelop.abs().min() - assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % ( - window_envelop_lowest - ) - - y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len) - - # unpack batch - y = y.view(shape[:-3] + y.shape[-1:]) - - if stft_matrix_dim == 3: # remove the channel dimension - y = y.squeeze(0) - - return y + warnings.warn( + 'istft has been moved to PyTorch and will be deprecated, please use torch.istft instead.') + return torch.istft( + input=stft_matrix, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, + center=center, pad_mode=pad_mode, normalized=normalized, onesided=onesided, length=length) def spectrogram(