diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 8bf5829ffe..18646134a5 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -1,4 +1,5 @@ from typing import List +import unittest from parameterized import parameterized import torch @@ -35,10 +36,16 @@ def assert_grad( ): transform = transform.to(dtype=torch.float64, device=self.device) + # gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or + # `torch.cdouble`, when the default eps and tolerance values are used. inputs_ = [] for i in inputs: - i.requires_grad = True - inputs_.append(i.to(dtype=torch.float64, device=self.device)) + if torch.is_tensor(i): + i = i.to( + dtype=torch.cdouble if i.is_complex() else torch.double, + device=self.device) + i.requires_grad = True + inputs_.append(i) assert gradcheck(transform, inputs_) assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol) @@ -129,3 +136,48 @@ def test_amplitude_to_db(self): transform = T.AmplitudeToDB() waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) self.assert_grad(transform, [waveform]) + + @unittest.expectedFailure + def test_timestretch_zeros_fail(self): + """Test that ``T.TimeStretch`` fails gradcheck at 0 + + This is because ``F.phase_vocoder`` converts data from cartesian to polar coordinate, + which performs ``atan2(img, real)``, and gradient is not defined at 0. + """ + n_fft = 16 + transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99) + waveform = torch.zeros(2, 40) + spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) + self.assert_grad(transform, [spectrogram]) + + @nested_params( + [0.7, 0.8, 0.9, 1.0, 1.3], + [False, True], + ) + def test_timestretch_non_zero(self, rate, test_pseudo_complex): + """Verify that ``T.TimeStretch`` does not fail if it's not close to 0 + + ``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability + for cases where input is not zero. + + As tested above, when spectrogram contains values close to zero, the gradients are unstable + and gradcheck fails. + + In this test, we generate spectrogram from random signal, then we push the points around + zero away from the origin. + + This process does not reflect the real use-case, and it is not practical for users, but + this helps us understand to what degree the function is differentiable and when not. + """ + n_fft = 16 + transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate) + waveform = get_whitenoise(sample_rate=40, duration=1, n_channels=2) + spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) + + # 1e-3 is too small (on CPU) + epsilon = 1e-2 + too_close = spectrogram.abs() < epsilon + spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs() + if test_pseudo_complex: + spectrogram = torch.view_as_real(spectrogram) + self.assert_grad(transform, [spectrogram])