Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion test/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
from shutil import copytree
import backports.tempfile as tempfile
import platform
import sys

# Version components as strings; kept as module-level names so any existing
# reader of _MAJOR/_MINOR/_PATCH keeps working.
_MAJOR, _MINOR, _PATCH = platform.python_version_tuple()

# Select the ``tempfile`` implementation by interpreter version: the
# ``backports.tempfile`` package on Python 2 and on Python 3 before 3.7,
# the standard-library module otherwise.  ``sys.version_info`` compares as a
# tuple of ints, so no string parsing of the version is needed.
if sys.version_info < (3, 7):
    import backports.tempfile as tempfile
else:
    import tempfile

import torch

TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
Expand Down
107 changes: 6 additions & 101 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import math
import warnings

import torch

Expand Down Expand Up @@ -120,107 +121,11 @@ def istft(
Returns:
torch.Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
assert stft_matrix.numel() > 0

if stft_matrix_dim == 3:
# add a channel dimension
stft_matrix = stft_matrix.unsqueeze(0)

# pack batch
shape = stft_matrix.size()
stft_matrix = stft_matrix.view(-1, shape[-3], shape[-2], shape[-1])

dtype = stft_matrix.dtype
device = stft_matrix.device
fft_size = stft_matrix.size(1)
assert (onesided and n_fft // 2 + 1 == fft_size) or (
not onesided and n_fft == fft_size
), (
"one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
+ "Given values were onesided: %s, n_fft: %d, fft_size: %d"
% ("True" if onesided else False, n_fft, fft_size)
)

# use stft defaults for Optionals
if win_length is None:
win_length = n_fft

if hop_length is None:
hop_length = int(win_length // 4)

# There must be overlap
assert 0 < hop_length <= win_length
assert 0 < win_length <= n_fft

if window is None:
window = torch.ones(win_length, device=device, dtype=dtype)

assert window.dim() == 1 and window.size(0) == win_length

if win_length != n_fft:
# center window with pad left and right zeros
left = (n_fft - win_length) // 2
window = torch.nn.functional.pad(window, (left, n_fft - win_length - left))
assert window.size(0) == n_fft
# win_length and n_fft are synonymous from here on

stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frame, fft_size, 2)
stft_matrix = torch.irfft(
stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
) # size (channel, n_frame, n_fft)

assert stft_matrix.size(2) == n_fft
n_frame = stft_matrix.size(1)

ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frame, n_fft)
# each column of a channel is a frame which needs to be overlap added at the right place
ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frame)

eye = torch.eye(n_fft, device=device, dtype=dtype).unsqueeze(1) # size (n_fft, 1, n_fft)

# this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output
y = torch.nn.functional.conv_transpose1d(
ytmp, eye, stride=hop_length, padding=0
) # size (channel, 1, expected_signal_len)

# do the same for the window function
window_sq = (
window.pow(2).view(n_fft, 1).repeat((1, n_frame)).unsqueeze(0)
) # size (1, n_fft, n_frame)
window_envelop = torch.nn.functional.conv_transpose1d(
window_sq, eye, stride=hop_length, padding=0
) # size (1, 1, expected_signal_len)

expected_signal_len = n_fft + hop_length * (n_frame - 1)
assert y.size(2) == expected_signal_len
assert window_envelop.size(2) == expected_signal_len

half_n_fft = n_fft // 2
# we need to trim the front padding away if center
start = half_n_fft if center else 0
end = -half_n_fft if length is None else start + length

y = y[:, :, start:end]
window_envelop = window_envelop[:, :, start:end]

# check NOLA non-zero overlap condition
window_envelop_lowest = window_envelop.abs().min()
assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
window_envelop_lowest
)

y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)

# unpack batch
y = y.view(shape[:-3] + y.shape[-1:])

if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0)

return y
warnings.warn(
'istft has been moved to PyTorch and will be deprecated, please use torch.istft instead.')
return torch.istft(
input=stft_matrix, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
center=center, pad_mode=pad_mode, normalized=normalized, onesided=onesided, length=length)


def spectrogram(
Expand Down