extend batch support #391
Diff of the first file:
```diff
@@ -103,6 +103,25 @@ def test_griffinlim(self):
        self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5))

    def test_batch_griffinlim(self):
        torch.random.manual_seed(42)
        tensor = torch.rand((1, 201, 6))

        n_fft = 400
        ws = 400
        hop = 200
        window = torch.hann_window(ws)
        power = 2
        normalize = False
        momentum = 0.99
        n_iter = 32
        length = 1000

        self._test_batch(
            F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
        )

    def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
        computed = F.compute_deltas(specgram, win_length=win_length)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
```
```diff
@@ -126,22 +145,17 @@ def test_compute_deltas_randn(self):
        win_length = 2 * 7 + 1
        specgram = torch.randn(channel, n_mfcc, time)
        computed = F.compute_deltas(specgram, win_length=win_length)

        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

        _test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)

    def test_batch_pitch(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)

        # Single then transform then batch
        expected = F.detect_pitch_frequency(waveform, sample_rate)
        expected = expected.unsqueeze(0).repeat(3, 1, 1)

        # Batch then transform
        waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = F.detect_pitch_frequency(waveform, sample_rate)

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_jit_pitch(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)
        _test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)

    def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
```
```diff
@@ -157,22 +171,13 @@ def _test_istft_is_inverse_of_stft(self, kwargs):
        for data_size in self.data_sizes:
            for i in range(self.number_of_trials):

                # Non-batch
                sound = common_utils.random_float_tensor(i, data_size)

                stft = torch.stft(sound, **kwargs)
                estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)

                self._compare_estimate(sound, estimate)

                # Batch
                stft = torch.stft(sound, **kwargs)
                stft = stft.repeat(3, 1, 1, 1, 1)
                sound = sound.repeat(3, 1, 1)

                estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
                self._compare_estimate(sound, estimate)

    def test_istft_is_inverse_of_stft1(self):
        # hann_window, centered, normalized, onesided
        kwargs1 = {
```
```diff
@@ -389,6 +394,16 @@ def test_linearity_of_istft4(self):
        data_size = (2, 7, 3, 2)
        self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)

    def test_batch_istft(self):
        stft = torch.tensor([
            [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
```
> **Contributor:** Why repeat all 0s instead of something more interesting (maybe just
>
> **Author:** I'm reusing the parameters from
```diff
            [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
        ])

        self._test_batch(F.istft, stft, n_fft=4, length=4)

    def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
        # Using a decorator here causes parametrize to fail on Python 2
        if not IMPORT_LIBROSA:
```
```diff
@@ -489,22 +504,63 @@ def test_pitch(self):
        self.assertFalse(s)

        # Convert to stereo and batch for testing purposes
        freq = freq.repeat(3, 2, 1, 1)
        waveform = waveform.repeat(3, 2, 1, 1)
        self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)

        freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)

    def _test_batch_shape(self, functional, tensor, *args, **kwargs):

        assert torch.allclose(freq, freq2, atol=1e-5)
        kwargs_compare = {}
        if 'atol' in kwargs:
            atol = kwargs['atol']
            del kwargs['atol']
            kwargs_compare['atol'] = atol

    def _test_batch(self, functional):
        waveform, sample_rate = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
        if 'rtol' in kwargs:
            rtol = kwargs['rtol']
            del kwargs['rtol']
            kwargs_compare['rtol'] = rtol

        # Single then transform then batch
        expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1)

        # Batch then transform
        waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
        computed = functional(waveform)
        torch.random.manual_seed(42)
        expected = functional(tensor.clone(), *args, **kwargs)
        expected = expected.unsqueeze(0).unsqueeze(0)

        # 1-Batch then transform
        tensors = tensor.unsqueeze(0).unsqueeze(0)

        torch.random.manual_seed(42)
        computed = functional(tensors.clone(), *args, **kwargs)

        self._compare_estimate(computed, expected, **kwargs_compare)

        return tensors, expected

    def _test_batch(self, functional, tensor, *args, **kwargs):

        tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)

        kwargs_compare = {}
        if 'atol' in kwargs:
            atol = kwargs['atol']
            del kwargs['atol']
            kwargs_compare['atol'] = atol

        if 'rtol' in kwargs:
            rtol = kwargs['rtol']
            del kwargs['rtol']
            kwargs_compare['rtol'] = rtol

        # 3-Batch then transform
        ind = [3] + [1] * (int(tensors.dim()) - 1)
        tensors = tensor.repeat(*ind)

        ind = [3] + [1] * (int(expected.dim()) - 1)
        expected = expected.repeat(*ind)

        torch.random.manual_seed(42)
        computed = functional(tensors.clone(), *args, **kwargs)

    def test_torchscript_create_fb_matrix(self):
```
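For orientation, the property all of these `test_batch_*` cases exercise can be written as a tiny standalone check. This is a sketch only, not part of the diff; `F.compute_deltas` and the shapes are picked here purely for illustration:

```python
import torch
import torchaudio.functional as F

# The invariant the batch tests exercise: applying a functional and then
# stacking copies along a new batch dimension should give the same result
# as stacking first and then applying the functional to the whole batch.
specgram = torch.randn(2, 31, 100)

expected = F.compute_deltas(specgram, win_length=5).unsqueeze(0).repeat(3, 1, 1, 1)
computed = F.compute_deltas(specgram.unsqueeze(0).repeat(3, 1, 1, 1), win_length=5)

assert computed.shape == expected.shape, (computed.shape, expected.shape)
assert torch.allclose(computed, expected)
```

Each batch test in the diff is this same comparison specialized to one functional or transform, with tolerances adjusted where needed.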
Diff of the second file:
```diff
@@ -374,6 +374,19 @@ def test_compute_deltas_twochannel(self):
        computed = transform(specgram)
        self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

    def test_batch_MelScale(self):
        specgram = torch.randn(2, 31, 2786)

        # Single then transform then batch
        expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1))

        # shape = (3, 2, 201, 1394)
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_compute_deltas(self):
        specgram = torch.randn(2, 31, 2786)
```
```diff
@@ -433,6 +446,30 @@ def test_batch_spectrogram(self):
        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_melspectrogram(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)
```
> **Contributor:** This is repetitive. I'm wondering if there is a way of creating a test generator that explicitly exercises this decorator. It's a function that accepts a function with some args and kwargs and then assumes that the first input is to be batchable (so it applies various reshapes to `args[0]` etc.) (not sure on the reshape behavior, could also be a lambda).
>
> **Author:** I followed the format that I used for testing jitability, and simply introduced common functions to test batching. Thoughts?
>
> **Contributor:** Yeah, I'm mostly just trying to iterate on top of that. Might be one of these non-problems I'm trying to fix.
```diff
        # Batch then transform
        computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected))

    def test_batch_mfcc(self):
        waveform, sample_rate = torchaudio.load(self.test_filepath)

        # Single then transform then batch
        expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)

        # Batch then transform
        computed = transforms.MFCC()(waveform.repeat(3, 1, 1))

        self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
        self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

    def test_scriptmodule_TimeStretch(self):
        n_freq = 400
        hop_length = 512
```
> Are these parameters special to the batch test?
>
> Using the same as `test_torchscript_spectrogram`, but we can change them or we could set them in the class directly.
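To make the reviewer's "test generator" suggestion from the melspectrogram thread concrete, here is a minimal sketch. The helper name `assert_batch_consistency`, the fixed batch size, and the reshape behavior are assumptions chosen for illustration, not anything that exists in this PR:

```python
import torch

def assert_batch_consistency(functional, tensor, *args, batch_size=3, atol=1e-8, rtol=1e-5, **kwargs):
    # Hypothetical helper: assume the first positional input is the batchable one,
    # and compare "transform then repeat" against "repeat then transform".
    torch.random.manual_seed(42)
    expected = functional(tensor.clone(), *args, **kwargs)
    expected = expected.unsqueeze(0).repeat(batch_size, *([1] * expected.dim()))

    batched = tensor.unsqueeze(0).repeat(batch_size, *([1] * tensor.dim()))
    torch.random.manual_seed(42)
    computed = functional(batched.clone(), *args, **kwargs)

    assert computed.shape == expected.shape, (computed.shape, expected.shape)
    assert torch.allclose(computed, expected, atol=atol, rtol=rtol)

# Example use, mirroring test_batch_compute_deltas above:
# import torchaudio.functional as F
# assert_batch_consistency(F.compute_deltas, torch.randn(2, 31, 100), win_length=5)
```

Seeding before both calls keeps any functional that draws random numbers comparable across the single and batched runs, which is the same reason the diff's `_test_batch_shape` calls `torch.random.manual_seed(42)` twice.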