|  | 
| 1 | 1 | from typing import List | 
|  | 2 | +import unittest | 
| 2 | 3 | 
 | 
| 3 | 4 | from parameterized import parameterized | 
| 4 | 5 | import torch | 
| @@ -37,8 +38,12 @@ def assert_grad( | 
| 37 | 38 | 
 | 
| 38 | 39 |         inputs_ = [] | 
| 39 | 40 |         for i in inputs: | 
| 40 |  | -            i.requires_grad = True | 
| 41 |  | -            inputs_.append(i.to(dtype=torch.float64, device=self.device)) | 
|  | 41 | +            if torch.is_tensor(i): | 
|  | 42 | +                i = i.to( | 
|  | 43 | +                    dtype=torch.cdouble if i.is_complex() else torch.double, | 
|  | 44 | +                    device=self.device) | 
|  | 45 | +                i.requires_grad = True | 
|  | 46 | +            inputs_.append(i) | 
| 42 | 47 |         assert gradcheck(transform, inputs_) | 
| 43 | 48 |         assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol) | 
| 44 | 49 | 
 | 
| @@ -123,3 +128,52 @@ def test_spectral_centroid(self): | 
| 123 | 128 |         transform = T.SpectralCentroid(sample_rate=sample_rate) | 
| 124 | 129 |         waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2) | 
| 125 | 130 |         self.assert_grad(transform, [waveform], nondet_tol=1e-10) | 
|  | 131 | + | 
|  | 132 | +    @unittest.expectedFailure | 
|  | 133 | +    def test_timestretch_zeros_fail(self): | 
|  | 134 | +        """Test that ``T.TimeStretch`` fails gradcheck at 0 | 
|  | 135 | +
 | 
|  | 136 | +        This is because ``F.phase_vocoder`` converts data from cartesian to polar coordinate, | 
|  | 137 | +        which performs ``atan2(img, real)``, and gradient is not defined at 0. | 
|  | 138 | +        """ | 
|  | 139 | +        n_fft = 16 | 
|  | 140 | +        transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99) | 
|  | 141 | +        waveform = torch.zeros(2, 40) | 
|  | 142 | +        spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) | 
|  | 143 | +        self.assert_grad(transform, [spectrogram]) | 
|  | 144 | + | 
|  | 145 | +    @nested_params( | 
|  | 146 | +        [0.7, 0.8, 0.9, 1.0, 1.3], | 
|  | 147 | +        [False, True], | 
|  | 148 | +    ) | 
|  | 149 | +    def test_timestretch(self, rate, test_pseudo_complex): | 
|  | 150 | +        """Verify that ``T.TimeStretch`` does not fail if it's not too close to 0 | 
|  | 151 | +
 | 
|  | 152 | +        ``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability | 
|  | 153 | +        for cases where input is not zero, and different configurations of `TimeStretch`. | 
|  | 154 | +
 | 
|  | 155 | +        Ideally, we should be testing on Spectrogram of random waveform but it is hard to control | 
|  | 156 | +        the values around zeros. | 
|  | 157 | +        """ | 
|  | 158 | +        n_fft = 16 | 
|  | 159 | +        transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate) | 
|  | 160 | +        waveform = torch.zeros(2, 40) | 
|  | 161 | +        spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) | 
|  | 162 | + | 
|  | 163 | +        # Epsilon values tried | 
|  | 164 | +        # | 
|  | 165 | +        # Note: | 
|  | 166 | +        #   This is not experimental and comprehensive. | 
|  | 167 | +        #   The result also depends on ``n_fft``. | 
|  | 168 | +        # | 
|  | 169 | +        #                 CPU / CUDA | 
|  | 170 | +        # * 1e-1           ok / ok | 
|  | 171 | +        # * 1e-2           ok / ok | 
|  | 172 | +        # * 1e-3           ok / ok | 
|  | 173 | +        # * 1e-3 + 1e-3j   ok / ok | 
|  | 174 | +        # * 1e-4           ok / NG | 
|  | 175 | + | 
|  | 176 | +        spectrogram += 1e-3 | 
|  | 177 | +        if test_pseudo_complex: | 
|  | 178 | +            spectrogram = torch.view_as_real(spectrogram) | 
|  | 179 | +        self.assert_grad(transform, [spectrogram]) | 
0 commit comments