Skip to content

Remove istft wrapper #839

@mthrok

Description

@mthrok

`istft` has been moved to PyTorch core as `torch.istft`.

For backward compatibility, we kept the `torchaudio.functional.istft` interface, but it was deprecated in the 0.6.0 release.

We can now remove the interface itself.

  1. Delete the function definition
    def istft(
        stft_matrix: Tensor,
        n_fft: int,
        hop_length: Optional[int] = None,
        win_length: Optional[int] = None,
        window: Optional[Tensor] = None,
        center: bool = True,
        pad_mode: Optional[str] = None,
        normalized: bool = False,
        onesided: bool = True,
        length: Optional[int] = None,
    ) -> Tensor:
        r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft.

        Deprecated: this is a thin wrapper that emits a warning and forwards every argument
        except ``pad_mode`` to ``torch.istft``. Call ``torch.istft`` directly instead.

        It has the same parameters (+ additional optional parameter of ``length``) and it should return the
        least squares estimation of the original signal. The algorithm will check using the NOLA condition
        (nonzero overlap).

        The parameters ``window`` and ``center`` must be chosen so that the envelope created by
        summing all the windows is never zero at any point in time. Specifically,
        :math:`\sum_{t=-\infty}^{\infty} w^2[n-t\times hop\_length] \neq 0`.

        Since stft discards elements at the end of the signal if they do not fit in a frame, the
        istft may return a shorter signal than the original signal (can occur if ``center`` is False
        since the signal isn't padded).

        If ``center`` is True, then there will be padding e.g. 'constant', 'reflect', etc. Left padding
        can be trimmed off exactly because it can be calculated, but right padding cannot be calculated
        without additional information.

        Example: Suppose the last window is:
        [17, 18, 0, 0, 0] vs [18, 0, 0, 0, 0]

        The n_frame, hop_length, win_length are all the same which prevents the calculation of right padding.
        These additional values could be zeros or a reflection of the signal so providing ``length``
        could be useful. If ``length`` is ``None`` then padding will be aggressively removed
        (some loss of signal).

        [1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.

        Args:
            stft_matrix (Tensor): Output of stft where each row of a channel is a frequency and each
                column is a window. It has a size of (..., fft_size, n_frame, 2).
            n_fft (int): Size of Fourier transform
            hop_length (int or None, optional): The distance between neighboring sliding window frames.
                (Default: ``win_length // 4``)
            win_length (int or None, optional): The size of window frame and STFT filter. (Default: ``n_fft``)
            window (Tensor or None, optional): The optional window function.
                (Default: ``torch.ones(win_length)``)
            center (bool, optional): Whether ``input`` was padded on both sides so
                that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
                (Default: ``True``)
            pad_mode: Ignored and not forwarded to ``torch.istft``. Passing any value other than
                ``None`` only emits a deprecation warning.
            normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
            onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
            length (int or None, optional): The amount to trim the signal by (i.e. the
                original signal length). (Default: whole signal)

        Returns:
            Tensor: Least squares estimation of the original signal of size (..., signal_length)
        """
        warnings.warn(
            'istft has been moved to PyTorch and will be removed from torchaudio, '
            'please use torch.istft instead.')
        if pad_mode is not None:
            # pad_mode is deliberately not forwarded below; warn so callers can drop it.
            warnings.warn(
                'The parameter `pad_mode` was ignored in istft, and is thus being deprecated. '
                'Please set `pad_mode` to None to suppress this warning.')
        return torch.istft(
            input=stft_matrix, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
            center=center, normalized=normalized, onesided=onesided, length=length)
  2. Delete the tests
    def _compare_estimate(sound, estimate, atol=1e-6, rtol=1e-8):
    # trim sound for case when constructed signal is shorter than original
    sound = sound[..., :estimate.size(-1)]
    torch.testing.assert_allclose(estimate, sound, atol=atol, rtol=rtol)
    def _test_istft_is_inverse_of_stft(kwargs):
        """Check that istft reconstructs random signals passed through torch.stft.

        ``kwargs`` is forwarded unchanged to both ``torch.stft`` and
        ``torchaudio.functional.istft``; ``length`` is set to the original
        signal length so the reconstruction can be compared sample by sample.
        """
        # generates a random sound signal for each trial and then does the stft/istft
        # operation to check whether we can reconstruct signal
        for data_size in [(2, 20), (3, 15), (4, 10)]:
            for i in range(100):
                # ``i`` doubles as the seed so every trial is deterministic
                sound = random_float_tensor(i, data_size)
                stft = torch.stft(sound, **kwargs)
                estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
                _compare_estimate(sound, estimate)
    class TestIstft(common_utils.TorchaudioTestCase):
        """Test suite for correctness of istft with various input"""
        # Number of random trials used by ``_test_linearity_of_istft``.
        number_of_trials = 100

        def test_istft_is_inverse_of_stft1(self):
            # hann_window, centered, normalized, onesided
            kwargs1 = {
                'n_fft': 12,
                'hop_length': 4,
                'win_length': 12,
                'window': torch.hann_window(12),
                'center': True,
                'pad_mode': 'reflect',
                'normalized': True,
                'onesided': True,
            }
            _test_istft_is_inverse_of_stft(kwargs1)

        def test_istft_is_inverse_of_stft2(self):
            # hann_window, centered, not normalized, not onesided
            kwargs2 = {
                'n_fft': 12,
                'hop_length': 2,
                'win_length': 8,
                'window': torch.hann_window(8),
                'center': True,
                'pad_mode': 'reflect',
                'normalized': False,
                'onesided': False,
            }
            _test_istft_is_inverse_of_stft(kwargs2)

        def test_istft_is_inverse_of_stft3(self):
            # hamming_window, centered, normalized, not onesided
            kwargs3 = {
                'n_fft': 15,
                'hop_length': 3,
                'win_length': 11,
                'window': torch.hamming_window(11),
                'center': True,
                'pad_mode': 'constant',
                'normalized': True,
                'onesided': False,
            }
            _test_istft_is_inverse_of_stft(kwargs3)

        def test_istft_is_inverse_of_stft4(self):
            # hamming_window, not centered, not normalized, onesided
            # window same size as n_fft
            kwargs4 = {
                'n_fft': 5,
                'hop_length': 2,
                'win_length': 5,
                'window': torch.hamming_window(5),
                'center': False,
                'pad_mode': 'constant',
                'normalized': False,
                'onesided': True,
            }
            _test_istft_is_inverse_of_stft(kwargs4)

        def test_istft_is_inverse_of_stft5(self):
            # hamming_window, not centered, not normalized, not onesided
            # window same size as n_fft
            kwargs5 = {
                'n_fft': 3,
                'hop_length': 2,
                'win_length': 3,
                'window': torch.hamming_window(3),
                'center': False,
                'pad_mode': 'reflect',
                'normalized': False,
                'onesided': False,
            }
            _test_istft_is_inverse_of_stft(kwargs5)

        def test_istft_of_ones(self):
            # stft = torch.stft(torch.ones(4), 4)
            stft = torch.tensor([
                [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
                [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
                [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
            ])

            estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
            _compare_estimate(torch.ones(4), estimate)

        def test_istft_of_zeros(self):
            # stft = torch.stft(torch.zeros(4), 4)
            stft = torch.zeros((3, 5, 2))

            estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
            _compare_estimate(torch.zeros(4), estimate)

        def test_istft_requires_overlap_windows(self):
            # the window is size 1 but it hops 20 so there is a gap which throws an error
            stft = torch.zeros((3, 5, 2))
            self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, n_fft=4,
                              hop_length=20, win_length=1, window=torch.ones(1))

        def test_istft_requires_nola(self):
            stft = torch.zeros((3, 5, 2))
            kwargs_ok = {
                'n_fft': 4,
                'win_length': 4,
                'window': torch.ones(4),
            }

            kwargs_not_ok = {
                'n_fft': 4,
                'win_length': 4,
                'window': torch.zeros(4),
            }

            # A window of ones meets NOLA but a window of zeros does not. This should
            # throw an error.
            torchaudio.functional.istft(stft, **kwargs_ok)
            self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok)

        def test_istft_requires_non_empty(self):
            self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
            self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)

        def _test_istft_of_sine(self, amplitude, L, n):
            # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
            x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
            sound = amplitude * torch.sin(2 * math.pi / L * x * n)
            # stft = torch.stft(sound, L, hop_length=L, win_length=L,
            #                   window=torch.ones(L), center=False, normalized=False)
            stft = torch.zeros((L // 2 + 1, 2, 2))
            stft_largest_val = (amplitude * L) / 2.0
            if n < stft.size(0):
                stft[n, :, 1] = -stft_largest_val

            if 0 <= L - n < stft.size(0):
                # symmetric about L // 2
                stft[L - n, :, 1] = stft_largest_val

            estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L,
                                                   window=torch.ones(L), center=False, normalized=False)
            # There is a larger error due to the scaling of amplitude
            _compare_estimate(sound, estimate, atol=1e-3)

        def test_istft_of_sine(self):
            self._test_istft_of_sine(amplitude=123, L=5, n=1)
            self._test_istft_of_sine(amplitude=150, L=5, n=2)
            self._test_istft_of_sine(amplitude=111, L=5, n=3)
            self._test_istft_of_sine(amplitude=160, L=7, n=4)
            self._test_istft_of_sine(amplitude=145, L=8, n=5)
            self._test_istft_of_sine(amplitude=80, L=9, n=6)
            self._test_istft_of_sine(amplitude=99, L=10, n=7)

        def _test_linearity_of_istft(self, data_size, kwargs, atol=1e-6, rtol=1e-8):
            # istft(a * X + b * Y) must equal a * istft(X) + b * istft(Y)
            for i in range(self.number_of_trials):
                tensor1 = random_float_tensor(i, data_size)
                tensor2 = random_float_tensor(i * 2, data_size)
                a, b = torch.rand(2)
                istft1 = torchaudio.functional.istft(tensor1, **kwargs)
                istft2 = torchaudio.functional.istft(tensor2, **kwargs)
                istft = a * istft1 + b * istft2
                estimate = torchaudio.functional.istft(a * tensor1 + b * tensor2, **kwargs)
                _compare_estimate(istft, estimate, atol, rtol)

        def test_linearity_of_istft1(self):
            # hann_window, centered, normalized, onesided
            kwargs1 = {
                'n_fft': 12,
                'window': torch.hann_window(12),
                'center': True,
                'pad_mode': 'reflect',
                'normalized': True,
                'onesided': True,
            }
            data_size = (2, 7, 7, 2)
            self._test_linearity_of_istft(data_size, kwargs1)

        def test_linearity_of_istft2(self):
            # hann_window, centered, not normalized, not onesided
            kwargs2 = {
                'n_fft': 12,
                'window': torch.hann_window(12),
                'center': True,
                'pad_mode': 'reflect',
                'normalized': False,
                'onesided': False,
            }
            data_size = (2, 12, 7, 2)
            self._test_linearity_of_istft(data_size, kwargs2)

        def test_linearity_of_istft3(self):
            # hamming_window, centered, normalized, not onesided
            kwargs3 = {
                'n_fft': 12,
                'window': torch.hamming_window(12),
                'center': True,
                'pad_mode': 'constant',
                'normalized': True,
                'onesided': False,
            }
            data_size = (2, 12, 7, 2)
            self._test_linearity_of_istft(data_size, kwargs3)

        def test_linearity_of_istft4(self):
            # hamming_window, not centered, not normalized, onesided
            kwargs4 = {
                'n_fft': 12,
                'window': torch.hamming_window(12),
                'center': False,
                'pad_mode': 'constant',
                'normalized': False,
                'onesided': True,
            }
            data_size = (2, 7, 3, 2)
            self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
  3. Delete the helper function
    def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
        """Generate a deterministic pseudo-random float tensor from ``seed``.

        Uses a linear congruential generator
        (https://en.wikipedia.org/wiki/Linear_congruential_generator),
        by default with the Borland C/C++ constants:

            X_{n + 1} = (a * X_n + c) % m

        The first element is ``(a * seed + c) % m``. Each raw value is divided by
        ``m``, so values lie in [0, 1) mathematically.
        NOTE(review): the conversion to float32 can round raw values very close
        to ``m`` up to exactly 1.0.

        Args:
            seed (int): Seed of the generator.
            size (Tuple[int]): Shape of the output tensor. A shape containing a
                zero yields an empty tensor.
            a (int): The multiplier constant of the generator.
            c (int): The additive constant of the generator.
            m (int): The modulus constant of the generator.

        Returns:
            Tensor: float32 tensor of shape ``size``.
        """
        num_elements = 1
        for s in size:
            num_elements *= s

        # Generate exactly ``num_elements`` values; the previous version always
        # produced at least one, which broke ``view`` for empty shapes.
        values = []
        state = seed
        for _ in range(num_elements):
            state = (a * state + c) % m
            values.append(state)
        # Explicit int64 keeps the dtype stable when ``values`` is empty.
        return torch.tensor(values, dtype=torch.int64).float().view(size) / m

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions