-
Notifications
You must be signed in to change notification settings - Fork 741
ISTFT #130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
ISTFT #130
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
95a810f
first
jamarshon 0a8386c
add tests
jamarshon 5f48c98
more test
jamarshon 1d2cbed
remove print
jamarshon 6107f65
abs min instead of min
jamarshon 050ae23
apply feedback
jamarshon 690fe92
apply feedback
jamarshon 1e2f949
flake8
jamarshon ee20335
apply feedback
jamarshon 7242e7b
apply feedback
jamarshon 50a6f3e
apply feedback
jamarshon 38c94b7
fix test_transforms.py. pytorch nightly must have changed from_numpy …
jamarshon 60ae1bb
apply feedback
jamarshon 1c56a06
apply feedback
jamarshon fe001e1
apply feedback
jamarshon 8427a89
apply feedback
jamarshon 80255af
flake8
jamarshon 1d79b54
apply feedback
jamarshon 2f36eb7
apply feedback
jamarshon 21c95a2
test
jamarshon dd2b838
test
jamarshon a0de40d
done
jamarshon d2a72e6
apply feedback
jamarshon 6a9ef42
apply feedback
jamarshon fc57968
apply feedback
jamarshon d8bbb8d
revert files
jamarshon 92801d5
apply feedback
jamarshon File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,202 @@ | ||
| import math | ||
|
|
||
| import torch | ||
| import torchaudio | ||
| import unittest | ||
| import test.common_utils | ||
|
|
||
|
|
||
class TestFunctional(unittest.TestCase):
    """Tests for ``torchaudio.functional.istft``.

    The round-trip tests push reproducible pseudo-random signals through
    ``torch.stft`` followed by ``istft`` and check that the signal is
    recovered. The remaining tests feed hand-built spectrograms (ones, zeros,
    a pure sine) to ``istft`` or check that invalid window configurations are
    rejected with an ``AssertionError``.
    """

    # Shapes of the random signals used by the round-trip tests.
    data_sizes = [(2, 20), (3, 15)]
    # Number of random signals drawn per shape in each round-trip test.
    number_of_trials = 100
    # Declared here for backward compatibility only: setUp rebinds a fresh
    # list on the instance, so this shared class-level object is never mutated.
    stored_rand_data = []
    # Random integers are drawn from [-fixed_precision, fixed_precision) and
    # later divided by this value to obtain floats.
    fixed_precision = int(1e10)

    def setUp(self):
        """Seed the RNG and draw the two starting tensors for every data size."""
        # we want to make sure that the random values are reproducible
        torch.manual_seed(0)
        # Rebind on the instance instead of clearing the class-level list so
        # that no mutable state is shared between test instances.
        self.stored_rand_data = []
        for data_size in self.data_sizes:
            rand_data1 = torch.randint(low=-self.fixed_precision, high=self.fixed_precision, size=data_size)
            rand_data2 = torch.randint(low=-self.fixed_precision, high=self.fixed_precision, size=data_size)
            self.stored_rand_data.append([rand_data1, rand_data2])

    def _get_random_tensor(self, i):
        """Return the next float tensor of shape ``data_sizes[i]``.

        The two stored integer tensors for index ``i`` are added and reduced
        modulo ``fixed_precision``; the oldest one is then dropped and the new
        one stored, so successive calls yield a deterministic sequence.
        """
        rand_data = self.stored_rand_data[i]
        rand_data3 = (rand_data[0] + rand_data[1]) % self.fixed_precision
        rand_data.pop(0)
        rand_data.append(rand_data3)
        return rand_data3.float() / self.fixed_precision

    def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
        """Assert that ``estimate`` matches ``sound`` within the tolerances.

        Args:
            sound: reference signal.
            estimate: reconstructed signal; may be shorter than ``sound``.
            atol: absolute tolerance passed to ``torch.allclose``.
            rtol: relative tolerance passed to ``torch.allclose``.
        """
        # trim sound for case when constructed signal is shorter than original
        sound = sound[..., :estimate.size(-1)]

        # assertEqual reports both shapes on failure without a hand-built message.
        self.assertEqual(sound.shape, estimate.shape)
        if not torch.allclose(sound, estimate, atol=atol, rtol=rtol):
            # Only computed on failure, so the passing path costs nothing extra.
            max_err = (sound - estimate).abs().max().item()
            self.fail('estimate differs from sound: max abs error {} (atol={}, rtol={})'.format(
                max_err, atol, rtol))

    def _test_istft_is_inverse_of_stft(self, kwargs):
        """Check that ``istft(stft(x))`` recovers ``x`` for random signals.

        Args:
            kwargs: keyword arguments passed unchanged to both ``torch.stft``
                and ``torchaudio.functional.istft``.
        """
        # generates a random sound signal for each trial and then does the stft/istft
        # operation to check whether we can reconstruct signal
        for i in range(len(self.data_sizes)):
            for _ in range(self.number_of_trials):
                sound = self._get_random_tensor(i)

                stft = torch.stft(sound, **kwargs)
                # size(-1) is the time axis, matching the trimming done in _compare_estimate
                estimate = torchaudio.functional.istft(stft, length=sound.size(-1), **kwargs)

                self._compare_estimate(sound, estimate)

    def test_istft_is_inverse_of_stft1(self):
        # hann_window, centered, normalized, onesided
        kwargs1 = {
            'n_fft': 12,
            'hop_length': 4,
            'win_length': 12,
            'window': torch.hann_window(12),
            'center': True,
            'pad_mode': 'reflect',
            'normalized': True,
            'onesided': True,
        }

        self._test_istft_is_inverse_of_stft(kwargs1)

    def test_istft_is_inverse_of_stft2(self):
        # hann_window, centered, not normalized, not onesided
        kwargs2 = {
            'n_fft': 12,
            'hop_length': 2,
            'win_length': 8,
            'window': torch.hann_window(8),
            'center': True,
            'pad_mode': 'reflect',
            'normalized': False,
            'onesided': False,
        }

        self._test_istft_is_inverse_of_stft(kwargs2)

    def test_istft_is_inverse_of_stft3(self):
        # hamming_window, centered, normalized, not onesided
        kwargs3 = {
            'n_fft': 15,
            'hop_length': 3,
            'win_length': 11,
            'window': torch.hamming_window(11),
            'center': True,
            'pad_mode': 'constant',
            'normalized': True,
            'onesided': False,
        }

        self._test_istft_is_inverse_of_stft(kwargs3)

    def test_istft_is_inverse_of_stft4(self):
        # hamming_window, not centered, not normalized, onesided
        # window same size as n_fft
        kwargs4 = {
            'n_fft': 5,
            'hop_length': 2,
            'win_length': 5,
            'window': torch.hamming_window(5),
            'center': False,
            'pad_mode': 'constant',
            'normalized': False,
            'onesided': True,
        }

        self._test_istft_is_inverse_of_stft(kwargs4)

    def test_istft_is_inverse_of_stft5(self):
        # hamming_window, not centered, not normalized, not onesided
        # window same size as n_fft
        kwargs5 = {
            'n_fft': 3,
            'hop_length': 2,
            'win_length': 3,
            'window': torch.hamming_window(3),
            'center': False,
            'pad_mode': 'reflect',
            'normalized': False,
            'onesided': False,
        }

        self._test_istft_is_inverse_of_stft(kwargs5)

    def test_istft_of_ones(self):
        # stft = torch.stft(torch.ones(4), 4)
        stft = torch.tensor([
            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
        ])

        estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
        self._compare_estimate(torch.ones(4), estimate)

    def test_istft_of_zeros(self):
        # stft = torch.stft(torch.zeros(4), 4)
        stft = torch.zeros((3, 5, 2))

        estimate = torchaudio.functional.istft(stft, n_fft=4, length=4)
        self._compare_estimate(torch.zeros(4), estimate)

    def test_istft_requires_overlap_windows(self):
        # the window has size 1 but the hop is 20, so there are gaps between
        # windows; istft is expected to raise an error
        stft = torch.rand((3, 5, 2))
        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(stft, n_fft=4, hop_length=20, win_length=1, window=torch.ones(1))

    def test_istft_requires_nola(self):
        stft = torch.zeros((3, 5, 2))
        kwargs_ok = {
            'n_fft': 4,
            'win_length': 4,
            'window': torch.ones(4),
        }

        kwargs_not_ok = {
            'n_fft': 4,
            'win_length': 4,
            'window': torch.zeros(4),
        }

        # A window of ones meets NOLA but a window of zeros does not. This should
        # throw an error.
        torchaudio.functional.istft(stft, **kwargs_ok)
        with self.assertRaises(AssertionError):
            torchaudio.functional.istft(stft, **kwargs_not_ok)

    def _test_istft_of_sine(self, amplitude, L, n):
        """Check ``istft`` against a hand-built spectrogram of a pure sine.

        Args:
            amplitude: peak amplitude of the sine wave.
            L: value used for ``n_fft``, hop length and window length.
            n: the signal is ``amplitude * sin(2 * pi / L * n * x)``.
        """
        # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
        x = torch.arange(2 * L + 1, dtype=torch.get_default_dtype())
        sound = amplitude * torch.sin(2 * math.pi / L * x * n)
        # The tensor built below stands in for:
        # stft = torch.stft(sound, L, hop_length=L, win_length=L,
        #                   window=torch.ones(L), center=False, normalized=False)
        stft = torch.zeros((L // 2 + 1, 2, 2))
        stft_largest_val = (amplitude * L) / 2.0
        if n < stft.size(0):
            stft[n, :, 1] = -stft_largest_val

        if 0 <= L - n < stft.size(0):
            # symmetric about L // 2
            stft[L - n, :, 1] = stft_largest_val

        estimate = torchaudio.functional.istft(stft, L, hop_length=L, win_length=L,
                                               window=torch.ones(L), center=False, normalized=False)
        # There is a larger error due to the scaling of amplitude
        self._compare_estimate(sound, estimate, atol=1e-3)

    def test_istft_of_sine(self):
        # (amplitude, L, n) triples; each runs as its own subTest so that a
        # failing case is identified and the remaining cases still run.
        cases = [
            (123, 5, 1),
            (150, 5, 2),
            (111, 5, 3),
            (160, 7, 4),
            (145, 8, 5),
            (80, 9, 6),
            (99, 10, 7),
        ]
        for amplitude, L, n in cases:
            with self.subTest(amplitude=amplitude, L=L, n=n):
                self._test_istft_of_sine(amplitude=amplitude, L=L, n=n)
jamarshon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.