Skip to content

Commit c37c88c

Browse files
committed
Use istft from torch
1 parent 3a4f356 commit c37c88c

File tree

2 files changed

+17
-106
lines changed

2 files changed

+17
-106
lines changed

test/test_functional.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def test_istft_of_zeros(self):
172172
def test_istft_requires_overlap_windows(self):
173173
# the window is size 1 but it hops 20, so there is a gap between frames, which throws an error
174174
stft = torch.zeros((3, 5, 2))
175-
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, n_fft=4,
175+
self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, n_fft=4,
176176
hop_length=20, win_length=1, window=torch.ones(1))
177177

178178
def test_istft_requires_nola(self):
@@ -192,11 +192,11 @@ def test_istft_requires_nola(self):
192192
# A window of ones meets NOLA but a window of zeros does not. This should
193193
# throw an error.
194194
torchaudio.functional.istft(stft, **kwargs_ok)
195-
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, **kwargs_not_ok)
195+
self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok)
196196

197197
def test_istft_requires_non_empty(self):
198-
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
199-
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)
198+
self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
199+
self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)
200200

201201
def _test_istft_of_sine(self, amplitude, L, n):
202202
# stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L

torchaudio/functional.py

Lines changed: 13 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import math
44
from typing import Optional, Tuple
5+
import warnings
56

67
import torch
78
from torch import Tensor
@@ -49,7 +50,7 @@ def istft(
4950
win_length: Optional[int] = None,
5051
window: Optional[Tensor] = None,
5152
center: bool = True,
52-
pad_mode: str = "reflect",
53+
pad_mode: Optional[str] = None,
5354
normalized: bool = False,
5455
onesided: bool = True,
5556
length: Optional[int] = None,
@@ -94,8 +95,7 @@ def istft(
9495
center (bool, optional): Whether ``input`` was padded on both sides so
9596
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
9697
(Default: ``True``)
97-
pad_mode (str, optional): Controls the padding method used when ``center`` is True. (Default:
98-
``"reflect"``)
98+
pad_mode (str or None, optional): This argument is ignored and will be removed. (Default: ``None``)
9999
normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
100100
onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
101101
length (int or None, optional): The amount to trim the signal by (i.e. the
@@ -104,105 +104,16 @@ def istft(
104104
Returns:
105105
Tensor: Least squares estimation of the original signal of size (..., signal_length)
106106
"""
107-
stft_matrix_dim = stft_matrix.dim()
108-
assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
109-
assert stft_matrix.numel() > 0
110-
111-
if stft_matrix_dim == 3:
112-
# add a channel dimension
113-
stft_matrix = stft_matrix.unsqueeze(0)
114-
115-
# pack batch
116-
shape = stft_matrix.size()
117-
stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1])
118-
119-
dtype = stft_matrix.dtype
120-
device = stft_matrix.device
121-
fft_size = stft_matrix.size(1)
122-
assert (onesided and n_fft // 2 + 1 == fft_size) or (
123-
not onesided and n_fft == fft_size
124-
), (
125-
"one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
126-
+ "Given values were onesided: %s, n_fft: %d, fft_size: %d"
127-
% ("True" if onesided else False, n_fft, fft_size)
128-
)
129-
130-
# use stft defaults for Optionals
131-
if win_length is None:
132-
win_length = n_fft
133-
134-
if hop_length is None:
135-
hop_length = int(win_length // 4)
136-
137-
# There must be overlap
138-
assert 0 < hop_length <= win_length
139-
assert 0 < win_length <= n_fft
140-
141-
if window is None:
142-
window = torch.ones(win_length, device=device, dtype=dtype)
143-
144-
assert window.dim() == 1 and window.size(0) == win_length
145-
146-
if win_length != n_fft:
147-
# center window with pad left and right zeros
148-
left = (n_fft - win_length) // 2
149-
window = torch.nn.functional.pad(window, (left, n_fft - win_length - left))
150-
assert window.size(0) == n_fft
151-
# win_length and n_fft are synonymous from here on
152-
153-
stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frame, fft_size, 2)
154-
stft_matrix = torch.irfft(
155-
stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
156-
) # size (channel, n_frame, n_fft)
157-
158-
assert stft_matrix.size(2) == n_fft
159-
n_frame = stft_matrix.size(1)
160-
161-
ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frame, n_fft)
162-
# each column of a channel is a frame which needs to be overlap added at the right place
163-
ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frame)
164-
165-
# this does overlap add where the frames of ytmp are added such that the i'th frame of
166-
# ytmp is added starting at i*hop_length in the output
167-
y = torch.nn.functional.fold(
168-
ytmp, (1, (n_frame - 1) * hop_length + n_fft), (1, n_fft), stride=(1, hop_length)
169-
).squeeze(2)
170-
171-
# do the same for the window function
172-
window_sq = (
173-
window.pow(2).view(n_fft, 1).repeat((1, n_frame)).unsqueeze(0)
174-
) # size (1, n_fft, n_frame)
175-
window_envelop = torch.nn.functional.fold(
176-
window_sq, (1, (n_frame - 1) * hop_length + n_fft), (1, n_fft), stride=(1, hop_length)
177-
).squeeze(2) # size (1, 1, expected_signal_len)
178-
179-
expected_signal_len = n_fft + hop_length * (n_frame - 1)
180-
assert y.size(2) == expected_signal_len
181-
assert window_envelop.size(2) == expected_signal_len
182-
183-
half_n_fft = n_fft // 2
184-
# we need to trim the front padding away if center
185-
start = half_n_fft if center else 0
186-
end = -half_n_fft if length is None else start + length
187-
188-
y = y[:, :, start:end]
189-
window_envelop = window_envelop[:, :, start:end]
190-
191-
# check NOLA non-zero overlap condition
192-
window_envelop_lowest = window_envelop.abs().min()
193-
assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
194-
window_envelop_lowest
195-
)
196-
197-
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
198-
199-
# unpack batch
200-
y = y.reshape(shape[:-3] + y.shape[-1:])
201-
202-
if stft_matrix_dim == 3: # remove the channel dimension
203-
y = y.squeeze(0)
204-
205-
return y
107+
warnings.warn(
108+
'istft has been moved to PyTorch and will be removed from torchaudio, '
109+
'please use torch.istft instead.')
110+
if pad_mode is not None:
111+
warnings.warn(
112+
'The parameter `pad_mode` was ignored in isftft, and is thus being deprecated. '
113+
'Please set `pad_mode` to None to suppress this warning.')
114+
return torch.istft(
115+
input=stft_matrix, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window,
116+
center=center, normalized=normalized, onesided=onesided, length=length)
206117

207118

208119
def spectrogram(

0 commit comments

Comments
 (0)