Skip to content

Commit e1bede5

Browse files
committed
Add autograd test to T.TimeStretch (and F.phase_vocoder)
1 parent e911e5e commit e1bede5

File tree

1 file changed

+56
-2
lines changed

1 file changed

+56
-2
lines changed

test/torchaudio_unittest/transforms/autograd_test_impl.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List
2+
import unittest
23

34
from parameterized import parameterized
45
import torch
@@ -37,8 +38,12 @@ def assert_grad(
3738

3839
inputs_ = []
3940
for i in inputs:
40-
i.requires_grad = True
41-
inputs_.append(i.to(dtype=torch.float64, device=self.device))
41+
if torch.is_tensor(i):
42+
i = i.to(
43+
dtype=torch.cdouble if i.is_complex() else torch.double,
44+
device=self.device)
45+
i.requires_grad = True
46+
inputs_.append(i)
4247
assert gradcheck(transform, inputs_)
4348
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
4449

@@ -123,3 +128,52 @@ def test_spectral_centroid(self):
123128
transform = T.SpectralCentroid(sample_rate=sample_rate)
124129
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
125130
self.assert_grad(transform, [waveform], nondet_tol=1e-10)
131+
132+
@unittest.expectedFailure
133+
def test_timestretch_zeros_fail(self):
134+
"""Test that ``T.TimeStretch`` fails gradcheck at 0
135+
136+
This is because ``F.phase_vocoder`` converts data from cartesian to polar coordinate,
137+
which performs ``atan2(img, real)``, and gradient is not defined at 0.
138+
"""
139+
n_fft = 16
140+
transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99)
141+
waveform = torch.zeros(2, 40)
142+
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
143+
self.assert_grad(transform, [spectrogram])
144+
145+
@nested_params(
146+
[0.7, 0.8, 0.9, 1.0, 1.3],
147+
[False, True],
148+
)
149+
def test_timestretch(self, rate, test_pseudo_complex):
150+
"""Verify that ``T.TimeStretch`` does not fail if it's not too close to 0
151+
152+
``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability
153+
for cases where input is not zero, and different configurations of `TimeStretch`.
154+
155+
Ideally, we should be testing on Spectrogram of random waveform but it is hard to control
156+
the values around zeros.
157+
"""
158+
n_fft = 16
159+
transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate)
160+
waveform = torch.zeros(2, 40)
161+
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
162+
163+
# Epsilon values tried
164+
#
165+
# Note:
166+
# This is not experimental and comprehensive.
167+
# The result also depends on ``n_fft``.
168+
#
169+
# CPU / CUDA
170+
# * 1e-1 ok / ok
171+
# * 1e-2 ok / ok
172+
# * 1e-3 ok / ok
173+
# * 1e-3 + 1e-3j ok / ok
174+
# * 1e-4 ok / NG
175+
176+
spectrogram += 1e-3
177+
if test_pseudo_complex:
178+
spectrogram = torch.view_as_real(spectrogram)
179+
self.assert_grad(transform, [spectrogram])

0 commit comments

Comments
 (0)