Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_istft_of_zeros(self):
def test_istft_requires_overlap_windows(self):
# the window is size 1 but it hops 20, so there is a gap, which throws an error
stft = torch.zeros((3, 5, 2))
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, n_fft=4,
self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, n_fft=4,
hop_length=20, win_length=1, window=torch.ones(1))

def test_istft_requires_nola(self):
Expand All @@ -192,11 +192,11 @@ def test_istft_requires_nola(self):
# A window of ones meets NOLA but a window of zeros does not. This should
# throw an error.
torchaudio.functional.istft(stft, **kwargs_ok)
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, **kwargs_not_ok)
self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok)

def test_istft_requires_non_empty(self):
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)
self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)

def _test_istft_of_sine(self, amplitude, L, n):
# stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
Expand Down
115 changes: 13 additions & 102 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math
from typing import Optional, Tuple
import warnings

import torch
from torch import Tensor
Expand Down Expand Up @@ -49,7 +50,7 @@ def istft(
win_length: Optional[int] = None,
window: Optional[Tensor] = None,
center: bool = True,
pad_mode: str = "reflect",
pad_mode: Optional[str] = None,
normalized: bool = False,
onesided: bool = True,
length: Optional[int] = None,
Expand Down Expand Up @@ -94,8 +95,7 @@ def istft(
center (bool, optional): Whether ``input`` was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
(Default: ``True``)
pad_mode (str, optional): Controls the padding method used when ``center`` is True. (Default:
``"reflect"``)
pad_mode (str or None, optional): This argument is ignored and is deprecated; it will be removed. (Default: ``None``)
normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
length (int or None, optional): The amount to trim the signal by (i.e. the
Expand All @@ -104,105 +104,16 @@ def istft(
Returns:
Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
assert stft_matrix.numel() > 0

if stft_matrix_dim == 3:
# add a channel dimension
stft_matrix = stft_matrix.unsqueeze(0)

# pack batch
shape = stft_matrix.size()
stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1])
Comment on lines -115 to -117
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what enabled the implementation to support batching.


dtype = stft_matrix.dtype
device = stft_matrix.device
fft_size = stft_matrix.size(1)
assert (onesided and n_fft // 2 + 1 == fft_size) or (
not onesided and n_fft == fft_size
), (
"one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
+ "Given values were onesided: %s, n_fft: %d, fft_size: %d"
% ("True" if onesided else False, n_fft, fft_size)
)

# use stft defaults for Optionals
if win_length is None:
win_length = n_fft

if hop_length is None:
hop_length = int(win_length // 4)

# There must be overlap
assert 0 < hop_length <= win_length
assert 0 < win_length <= n_fft

if window is None:
window = torch.ones(win_length, device=device, dtype=dtype)

assert window.dim() == 1 and window.size(0) == win_length

if win_length != n_fft:
# center window with pad left and right zeros
left = (n_fft - win_length) // 2
window = torch.nn.functional.pad(window, (left, n_fft - win_length - left))
assert window.size(0) == n_fft
# win_length and n_fft are synonymous from here on

stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frame, fft_size, 2)
stft_matrix = torch.irfft(
stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
) # size (channel, n_frame, n_fft)

assert stft_matrix.size(2) == n_fft
n_frame = stft_matrix.size(1)

ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frame, n_fft)
# each column of a channel is a frame which needs to be overlap added at the right place
ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frame)

# this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output
y = torch.nn.functional.fold(
ytmp, (1, (n_frame - 1) * hop_length + n_fft), (1, n_fft), stride=(1, hop_length)
).squeeze(2)

# do the same for the window function
window_sq = (
window.pow(2).view(n_fft, 1).repeat((1, n_frame)).unsqueeze(0)
) # size (1, n_fft, n_frame)
window_envelop = torch.nn.functional.fold(
window_sq, (1, (n_frame - 1) * hop_length + n_fft), (1, n_fft), stride=(1, hop_length)
).squeeze(2) # size (1, 1, expected_signal_len)

expected_signal_len = n_fft + hop_length * (n_frame - 1)
assert y.size(2) == expected_signal_len
assert window_envelop.size(2) == expected_signal_len

half_n_fft = n_fft // 2
# we need to trim the front padding away if center
start = half_n_fft if center else 0
end = -half_n_fft if length is None else start + length

y = y[:, :, start:end]
window_envelop = window_envelop[:, :, start:end]

# check NOLA non-zero overlap condition
window_envelop_lowest = window_envelop.abs().min()
assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
window_envelop_lowest
)

y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)

# unpack batch
y = y.reshape(shape[:-3] + y.shape[-1:])

if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0)

return y
warnings.warn(
'istft has been moved to PyTorch and will be removed from torchaudio, '
'please use torch.istft instead.')
if pad_mode is not None:
warnings.warn(
'The parameter `pad_mode` was ignored in istft, and is thus being deprecated. '
'Please set `pad_mode` to None to suppress this warning.')
return torch.istft(
input=stft_matrix, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
center=center, normalized=normalized, onesided=onesided, length=length)


def spectrogram(
Expand Down