Skip to content

Commit 9436b72

Browse files
committed
Add autograd test to T.TimeStretch (and F.phase_vocoder)
1 parent 9a0e70e commit 9436b72

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

test/torchaudio_unittest/transforms/autograd_test_impl.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from torchaudio_unittest.common_utils import (
99
TestBaseMixin,
1010
get_whitenoise,
11+
get_spectrogram,
12+
nested_params,
1113
)
1214

1315

@@ -23,8 +25,12 @@ def assert_grad(
2325

2426
inputs_ = []
2527
for i in inputs:
26-
i.requires_grad = True
27-
inputs_.append(i.to(dtype=torch.float64, device=self.device))
28+
if torch.is_tensor(i):
29+
i = i.to(
30+
dtype=torch.cdouble if i.is_complex() else torch.double,
31+
device=self.device)
32+
i.requires_grad = True
33+
inputs_.append(i)
2834
assert gradcheck(transform, inputs_)
2935
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)
3036

@@ -103,3 +109,23 @@ def test_spectral_centroid(self):
103109
transform = T.SpectralCentroid(sample_rate=sample_rate)
104110
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
105111
self.assert_grad(transform, [waveform], nondet_tol=1e-10)
112+
113+
@nested_params(
114+
[0.7, 0.8, 0.9, 1.0, 1.3],
115+
[True, False],
116+
)
117+
def test_timestretch(self, rate, test_complex):
118+
transform = T.TimeStretch(fixed_rate=rate)
119+
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
120+
spectrogram = get_spectrogram(waveform, n_fft=400, power=1 if test_complex else None)
121+
self.assert_grad(transform, [spectrogram])
122+
123+
@nested_params(
124+
[0.7, 0.8, 0.9, 1.0, 1.3],
125+
[True, False],
126+
)
127+
def test_timestretch_override(self, rate, test_complex):
128+
transform = T.TimeStretch()
129+
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
130+
spectrogram = get_spectrogram(waveform, n_fft=400, power=1 if test_complex else None)
131+
self.assert_grad(transform, [spectrogram, rate])

0 commit comments

Comments
 (0)