Skip to content

Commit 7bb3783

Browse files
committed
fix
1 parent 2f239e3 commit 7bb3783

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

test/test_functional.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def test_istft_of_zeros(self):
188188
def test_istft_requires_overlap_windows(self):
189189
# the window is size 1 but it hops 20 so there is a gap which throw an error
190190
stft = torch.zeros((3, 5, 2))
191-
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, n_fft=4,
191+
self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, n_fft=4,
192192
hop_length=20, win_length=1, window=torch.ones(1))
193193

194194
def test_istft_requires_nola(self):
@@ -208,11 +208,11 @@ def test_istft_requires_nola(self):
208208
# A window of ones meets NOLA but a window of zeros does not. This should
209209
# throw an error.
210210
torchaudio.functional.istft(stft, **kwargs_ok)
211-
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, **kwargs_not_ok)
211+
self.assertRaises(RuntimeError, torchaudio.functional.istft, stft, **kwargs_not_ok)
212212

213213
def test_istft_requires_non_empty(self):
214-
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
215-
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)
214+
self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
215+
self.assertRaises(RuntimeError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)
216216

217217
def _test_istft_of_sine(self, amplitude, L, n):
218218
# stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L

torchaudio/functional.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,12 +248,13 @@ def griffinlim(
248248
tprev = rebuilt
249249

250250
# Invert with our current estimate of the phases
251-
inverse = istft(specgram * angles,
252-
n_fft=n_fft,
253-
hop_length=hop_length,
254-
win_length=win_length,
255-
window=window,
256-
length=length).float()
251+
inverse = torch.istft(
252+
specgram * angles,
253+
n_fft=n_fft,
254+
hop_length=hop_length,
255+
win_length=win_length,
256+
window=window,
257+
length=length).float()
257258

258259
# Rebuild the spectrogram
259260
rebuilt = torch.stft(inverse, n_fft, hop_length, win_length, window,
@@ -266,12 +267,13 @@ def griffinlim(
266267
angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles))
267268

268269
# Return the final phase estimates
269-
waveform = istft(specgram * angles,
270-
n_fft=n_fft,
271-
hop_length=hop_length,
272-
win_length=win_length,
273-
window=window,
274-
length=length)
270+
waveform = torch.istft(
271+
specgram * angles,
272+
n_fft=n_fft,
273+
hop_length=hop_length,
274+
win_length=win_length,
275+
window=window,
276+
length=length)
275277

276278
# unpack batch
277279
waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])

0 commit comments

Comments
 (0)