Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
import unittest

from parameterized import parameterized
import torch
Expand Down Expand Up @@ -35,10 +36,16 @@ def assert_grad(
):
transform = transform.to(dtype=torch.float64, device=self.device)

# gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or
# `torch.cdouble`, when the default eps and tolerance values are used.
inputs_ = []
for i in inputs:
i.requires_grad = True
inputs_.append(i.to(dtype=torch.float64, device=self.device))
if torch.is_tensor(i):
i = i.to(
dtype=torch.cdouble if i.is_complex() else torch.double,
device=self.device)
i.requires_grad = True
inputs_.append(i)
assert gradcheck(transform, inputs_)
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)

Expand Down Expand Up @@ -129,3 +136,48 @@ def test_amplitude_to_db(self):
transform = T.AmplitudeToDB()
waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform])

@unittest.expectedFailure
def test_timestretch_zeros_fail(self):
"""Test that ``T.TimeStretch`` fails gradcheck at 0

This is because ``F.phase_vocoder`` converts data from cartesian to polar coordinate,
which performs ``atan2(img, real)``, and gradient is not defined at 0.
"""
n_fft = 16
transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=0.99)
waveform = torch.zeros(2, 40)
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
self.assert_grad(transform, [spectrogram])

@nested_params(
[0.7, 0.8, 0.9, 1.0, 1.3],
[False, True],
)
def test_timestretch_non_zero(self, rate, test_pseudo_complex):
"""Verify that ``T.TimeStretch`` does not fail if it's not close to 0

``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability
for cases where input is not zero.

As tested above, when spectrogram contains values close to zero, the gradients are unstable
and gradcheck fails.

In this test, we generate spectrogram from random signal, then we push the points around
zero away from the origin.

This process does not reflect the real use-case, and it is not practical for users, but
this helps us understand to what degree the function is differentiable and when not.
"""
n_fft = 16
transform = T.TimeStretch(n_freq=n_fft // 2 + 1, fixed_rate=rate)
waveform = get_whitenoise(sample_rate=40, duration=1, n_channels=2)
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)

# 1e-3 is too small (on CPU)
epsilon = 1e-2
too_close = spectrogram.abs() < epsilon
spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs()
if test_pseudo_complex:
spectrogram = torch.view_as_real(spectrogram)
self.assert_grad(transform, [spectrogram])