Skip to content
Merged

ISTFT #135

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
95a810f
first
jamarshon Jul 3, 2019
0a8386c
add tests
jamarshon Jul 3, 2019
5f48c98
more test
jamarshon Jul 3, 2019
1d2cbed
remove print
jamarshon Jul 3, 2019
6107f65
abs min instead of min
jamarshon Jul 3, 2019
050ae23
apply feedback
jamarshon Jul 5, 2019
690fe92
apply feedback
jamarshon Jul 5, 2019
1e2f949
flake8
jamarshon Jul 5, 2019
ee20335
apply feedback
jamarshon Jul 5, 2019
7242e7b
apply feedback
jamarshon Jul 5, 2019
50a6f3e
apply feedback
jamarshon Jul 8, 2019
38c94b7
fix test_transforms.py. pytorch nightly must have changed from_numpy …
jamarshon Jul 8, 2019
60ae1bb
apply feedback
jamarshon Jul 8, 2019
1c56a06
apply feedback
jamarshon Jul 8, 2019
fe001e1
apply feedback
jamarshon Jul 8, 2019
8427a89
apply feedback
jamarshon Jul 9, 2019
80255af
flake8
jamarshon Jul 9, 2019
1d79b54
apply feedback
jamarshon Jul 9, 2019
2f36eb7
apply feedback
jamarshon Jul 9, 2019
21c95a2
test
jamarshon Jul 9, 2019
dd2b838
test
jamarshon Jul 9, 2019
a0de40d
done
jamarshon Jul 9, 2019
d2a72e6
apply feedback
jamarshon Jul 9, 2019
6a9ef42
apply feedback
jamarshon Jul 9, 2019
fc57968
apply feedback
jamarshon Jul 9, 2019
d8bbb8d
revert files
jamarshon Jul 9, 2019
92801d5
apply feedback
jamarshon Jul 9, 2019
03c0fe2
jMerge branch 'master' into istft
jamarshon Jul 10, 2019
f8d97da
apply feedback
jamarshon Jul 10, 2019
9cd6ee6
apply feedback
jamarshon Jul 10, 2019
025f4ef
apply feedback
jamarshon Jul 10, 2019
335bf81
apply feedback
jamarshon Jul 10, 2019
39c93d2
apply feedback
jamarshon Jul 10, 2019
50b5743
apply feedback
jamarshon Jul 10, 2019
78ebf7a
apply feedback
jamarshon Jul 10, 2019
f7fddea
apply feedback
jamarshon Jul 10, 2019
a41a6a0
apply feedback
jamarshon Jul 10, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from shutil import copytree
import tempfile
import torch


TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -16,3 +17,34 @@ def create_temp_assets_dir():
copytree(os.path.join(TEST_DIR_PATH, "assets"),
os.path.join(tmp_dir.name, "assets"))
return tmp_dir.name, tmp_dir


def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
""" Generates random tensors given a seed and size
https://en.wikipedia.org/wiki/Linear_congruential_generator
X_{n + 1} = (a * X_n + c) % m
Using Borland C/C++ values

The tensor will have values between [0,1)
Inputs:
seed (int): an int
size (Tuple[int]): the size of the output tensor
a (int): the multiplier constant to the generator
c (int): the additive constant to the generator
m (int): the modulus constant to the generator
"""
num_elements = 1
for s in size:
num_elements *= s

arr = [(a * seed + c) % m]
for i in range(num_elements - 1):
arr.append((a * arr[i] + c) % m)

return torch.tensor(arr).float().view(size) / m


def random_int_tensor(seed, size, low=0, high=2 ** 32, a=22695477, c=1, m=2 ** 32):
""" Same as random_float_tensor but integers between [low, high)
"""
return torch.floor(random_float_tensor(seed, size, a, c, m) * (high - low)) + low
187 changes: 187 additions & 0 deletions test/test_functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import math

import torch
import torchaudio
import unittest
import test.common_utils


class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)]
number_of_trials = 100

def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
# trim sound for case when constructed signal is shorter than original
sound = sound[..., :estimate.size(-1)]

self.assertTrue(sound.shape == estimate.shape, (sound.shape, estimate.shape))
self.assertTrue(torch.allclose(sound, estimate, atol=atol, rtol=rtol))

def _test_istft_is_inverse_of_stft(self, kwargs):
# generates a random sound signal for each tril and then does the stft/istft
# operation to check whether we can reconstruct signal
for data_size in self.data_sizes:
for i in range(self.number_of_trials):
sound = test.common_utils.random_float_tensor(i, data_size)

stft = torch.stft(sound, **kwargs)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)

self._compare_estimate(sound, estimate)

def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided
kwargs1 = {
'n_fft': 12,
'hop_length': 4,
'win_length': 12,
'window': torch.hann_window(12),
'center': True,
'pad_mode': 'reflect',
'normalized': True,
'onesided': True,
}

self._test_istft_is_inverse_of_stft(kwargs1)

def test_istft_is_inverse_of_stft2(self):
# hann_window, centered, not normalized, not onesided
kwargs2 = {
'n_fft': 12,
'hop_length': 2,
'win_length': 8,
'window': torch.hann_window(8),
'center': True,
'pad_mode': 'reflect',
'normalized': False,
'onesided': False,
}

self._test_istft_is_inverse_of_stft(kwargs2)

def test_istft_is_inverse_of_stft3(self):
# hamming_window, centered, normalized, not onesided
kwargs3 = {
'n_fft': 15,
'hop_length': 3,
'win_length': 11,
'window': torch.hamming_window(11),
'center': True,
'pad_mode': 'constant',
'normalized': True,
'onesided': False,
}

self._test_istft_is_inverse_of_stft(kwargs3)

def test_istft_is_inverse_of_stft4(self):
# hamming_window, not centered, not normalized, onesided
# window same size as n_fft
kwargs4 = {
'n_fft': 5,
'hop_length': 2,
'win_length': 5,
'window': torch.hamming_window(5),
'center': False,
'pad_mode': 'constant',
'normalized': False,
'onesided': True,
}

self._test_istft_is_inverse_of_stft(kwargs4)

def test_istft_is_inverse_of_stft5(self):
# hamming_window, not centered, not normalized, not onesided
# window same size as n_fft
kwargs5 = {
'n_fft': 3,
'hop_length': 2,
'win_length': 3,
'window': torch.hamming_window(3),
'center': False,
'pad_mode': 'reflect',
'normalized': False,
'onesided': False,
}

self._test_istft_is_inverse_of_stft(kwargs5)

def test_istft_of_ones(self):
# stft = torch.stft(torch.ones(4), 4)
stft = torch.tensor([
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
])

estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
self._compare_estimate(torch.ones(4), estimate)

def test_istft_of_zeros(self):
# stft = torch.stft(torch.zeros(4), 4)
stft = torch.zeros((3, 5, 2))

estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
self._compare_estimate(torch.zeros(4), estimate)

def test_istft_requires_overlap_windows(self):
# the window is size 1 but it hops 20 so there is a gap which throw an error
stft = torch.zeros((3, 5, 2))
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, n_fft=4,
hop_length=20, win_length=1, window=torch.ones(1))

def test_istft_requires_nola(self):
stft = torch.zeros((3, 5, 2))
kwargs_ok = {
'n_fft': 4,
'win_length': 4,
'window': torch.ones(4),
}

kwargs_not_ok = {
'n_fft': 4,
'win_length': 4,
'window': torch.zeros(4),
}

# A window of ones meets NOLA but a window of zeros does not. This should
# throw an error.
torchaudio.functional.istft(stft, **kwargs_ok)
self.assertRaises(AssertionError, torchaudio.functional.istft, stft, **kwargs_not_ok)

def test_istft_requires_non_empty(self):
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((3, 0, 2)), 2)
self.assertRaises(AssertionError, torchaudio.functional.istft, torch.zeros((0, 3, 2)), 2)

def _test_istft_of_sine(self, amplitude, L, n):
# stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
sound = amplitude * torch.sin(2 * math.pi / L * x * n)
# stft = torch.stft(sound, L, hop_length=L, win_length=L,
# window=torch.ones(L), center=False, normalized=False)
stft = torch.zeros((L // 2 + 1, 2, 2))
stft_largest_val = (amplitude * L) / 2.0
if n < stft.size(0):
stft[n, :, 1] = -stft_largest_val

if 0 <= L - n < stft.size(0):
# symmetric about L // 2
stft[L - n, :, 1] = stft_largest_val

estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L,
window=torch.ones(L), center=False, normalized=False)
# There is a larger error due to the scaling of amplitude
self._compare_estimate(sound, estimate, atol=1e-3)

def test_istft_of_sine(self):
self._test_istft_of_sine(amplitude=123, L=5, n=1)
self._test_istft_of_sine(amplitude=150, L=5, n=2)
self._test_istft_of_sine(amplitude=111, L=5, n=3)
self._test_istft_of_sine(amplitude=160, L=7, n=4)
self._test_istft_of_sine(amplitude=145, L=8, n=5)
self._test_istft_of_sine(amplitude=80, L=9, n=6)
self._test_istft_of_sine(amplitude=99, L=10, n=7)


if __name__ == '__main__':
unittest.main()
133 changes: 133 additions & 0 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
'pad_trim',
'downmix_mono',
'LC2CL',
'istft',
'spectrogram',
'create_fb_matrix',
'spectrogram_to_DB',
Expand Down Expand Up @@ -105,6 +106,138 @@ def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normal
return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)


def istft(stft_matrix, # type: Tensor
n_fft, # type: int
hop_length=None, # type: Optional[int]
win_length=None, # type: Optional[int]
window=None, # type: Optional[Tensor]
center=True, # type: bool
pad_mode='reflect', # type: str
normalized=False, # type: bool
onesided=True, # type: bool
length=None # type: Optional[int]
):
# type: (...) -> Tensor
r""" Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft.
It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition (
nonzero overlap).
Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelop
created by the summation of all the windows is never zero at certain point in time. Specifically,
:math:`\sum_{t=-\ infty}^{\ infty} w^2[n-t\times hop\_length] \neq 0`.
Since stft discards elements at the end of the signal if they do not fit in a frame, the
istft may return a shorter signal than the original signal (can occur if :attr:`center` is False
since the signal isn't padded).
If :attr:`center` is True, then there will be padding e.g. 'constant', 'reflect', etc. Left padding
can be trimmed off exactly because they can be calculated but right padding cannot be calculated
without additional information.
Example: Suppose the last window is:
[17, 18, 0, 0, 0] vs [18, 0, 0, 0, 0]
The n_frames, hop_length, win_length are all the same which prevents the calculation of right padding.
These additional values could be zeros or a reflection of the signal so providing :attr:`length`
could be useful. If :attr:`length` is None then padding will be aggressively removed (some loss of signal).
[1] D. W. Griffin and J. S. Lim, “Signal estimation from modified short-time Fourier transform,”
IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
Inputs:
stft_matrix (Tensor): output of stft where each row of a batch is a frequency and each column is
a window. it has a shape of either (batch, fft_size, n_frames, 2) or (fft_size, n_frames, 2)
n_fft (int): size of Fourier transform
hop_length (Optional[int]): the distance between neighboring sliding window frames. (Default: win_length // 4)
win_length (Optional[int]): the size of window frame and STFT filter. (Default: n_fft)
window (Optional[Tensor]): the optional window function. (Default: torch.ones(win_length))
center (bool): whether :attr:`input` was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`
pad_mode (str): controls the padding method used when :attr:`center` is ``True``
normalized (bool): whether the STFT was normalized
onesided (bool): whether the STFT is onesided
length (Optional[int]): the amount to trim the signal by (i.e. the
original signal length). (Default: whole signal)
Outputs:
Tensor: least squares estimation of the original signal of size (batch, signal_length) or (signal_length)
"""
stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim))

if stft_matrix_dim == 3:
# add a batch dimension
stft_matrix = stft_matrix.unsqueeze(0)

device = stft_matrix.device
fft_size = stft_matrix.size(1)
assert (onesided and n_fft // 2 + 1 == fft_size) or (not onesided and n_fft == fft_size), (
'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. '
+ 'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size))

# use stft defaults for Optionals
if win_length is None:
win_length = n_fft

if hop_length is None:
hop_length = int(win_length // 4)

# There must be overlap
assert 0 < hop_length <= win_length
assert 0 < win_length <= n_fft

if window is None:
window = torch.ones(win_length)

assert window.dim() == 1 and window.size(0) == win_length

if win_length != n_fft:
# center window with pad left and right zeros
left = (n_fft - win_length) // 2
window = torch.nn.functional.pad(window, (left, n_fft - win_length - left))
assert window.size(0) == n_fft
# win_length and n_fft are synonymous from here on

stft_matrix = stft_matrix.transpose(1, 2) # size (batch, n_frames, fft_size, 2)
stft_matrix = torch.irfft(stft_matrix, 1, normalized,
onesided, signal_sizes=(n_fft,)) # size (batch, n_frames, n_fft)

assert stft_matrix.size(2) == n_fft
n_frames = stft_matrix.size(1)

ytmp = stft_matrix * window.view(1, 1, n_fft) # size (batch, n_frames, n_fft)
# each column of a batch is a frame which needs to be overlap added at the right place
ytmp = ytmp.transpose(1, 2) # size (batch, n_fft, n_frames)

eye = torch.eye(n_fft, requires_grad=False,
device=device).unsqueeze(1) # size (n_fft, 1, n_fft)

# this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output
y = torch.nn.functional.conv_transpose1d(
ytmp, eye, stride=hop_length, padding=0) # size (batch, 1, expected_signal_len)

# do the same for the window function
window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0) # size (1, n_fft, n_frames)
window_envelop = torch.nn.functional.conv_transpose1d(
window_sq, eye, stride=hop_length, padding=0) # size (1, 1, expected_signal_len)

expected_signal_len = n_fft + hop_length * (n_frames - 1)
assert y.size(2) == expected_signal_len
assert window_envelop.size(2) == expected_signal_len

half_n_fft = n_fft // 2
# we need to trim the front padding away if center
start = half_n_fft if center else 0
end = -half_n_fft if length is None else start + length

y = y[:, :, start:end]
window_envelop = window_envelop[:, :, start:end]

# check NOLA non-zero overlap condition
window_envelop_lowest = window_envelop.abs().min()
assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest))

y = (y / window_envelop).squeeze(1) # size (batch, expected_signal_len)

if stft_matrix_dim == 3: # remove the batch dimension
y = y.squeeze(0)
return y


@torch.jit.script
def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
Expand Down