Skip to content

Commit 9af2439

Browse files
committed
select autograd test from carolineechen#2
1 parent c8239c6 commit 9af2439

File tree

3 files changed

+98
-0
lines changed

3 files changed

+98
-0
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
from .autograd_impl import Autograd
3+
from torchaudio_unittest import common_utils
4+
from .utils import skipIfNoTransducer
5+
6+
7+
@skipIfNoTransducer
8+
class TestAutograd(Autograd, common_utils.PytorchTestCase):
9+
dtype = torch.float32
10+
device = torch.device('cpu')
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import torch
2+
from .autograd_impl import Autograd
3+
from torchaudio_unittest import common_utils
4+
from .utils import skipIfNoTransducer
5+
6+
7+
@skipIfNoTransducer
8+
class TestAutograd(Autograd, common_utils.PytorchTestCase):
9+
dtype = torch.float32
10+
device = torch.device('cuda')
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import Callable, Tuple
2+
import torch
3+
from torch import Tensor
4+
from torch.autograd import gradcheck
5+
from torchaudio_unittest.common_utils import (
6+
TestBaseMixin,
7+
)
8+
from torchaudio.prototype.rnnt_loss import RNNTLoss
9+
from parameterized import parameterized
10+
from .utils import (
11+
numpy_to_torch,
12+
get_B1_T10_U3_D4_data,
13+
get_numpy_data_B2_T4_U3_D3,
14+
get_numpy_data_B1_T2_U3_D5
15+
)
16+
from .numpy_transducer import NumpyTransducerLoss
17+
18+
19+
class Autograd(TestBaseMixin):
20+
@staticmethod
21+
def get_data(data_func, device):
22+
data_np = data_func()
23+
if type(data_np) == tuple:
24+
data_np = data_np[0]
25+
data = numpy_to_torch(
26+
data=data_np, device=device, requires_grad=True
27+
)
28+
return data
29+
30+
def assert_grad(
31+
self,
32+
loss: Callable[..., Tensor],
33+
inputs: Tuple[torch.Tensor],
34+
*,
35+
enable_all_grad: bool = True,
36+
):
37+
# inputs_ = []
38+
# for i in inputs:
39+
# if torch.is_tensor(i):
40+
# i = i.to(dtype=self.dtype, device=self.device)
41+
# if enable_all_grad:
42+
# i.requires_grad = True
43+
# inputs_.append(i)
44+
assert gradcheck(loss, inputs, eps=1e-03, atol=1e-03, rtol=1e-03, nondet_tol=0.)
45+
46+
@parameterized.expand([
47+
(get_B1_T10_U3_D4_data, ),
48+
(get_numpy_data_B2_T4_U3_D3, ),
49+
(get_numpy_data_B1_T2_U3_D5, ),
50+
])
51+
def test_RNNTLoss_gradcheck(self, data_func):
52+
data = self.get_data(data_func, self.device)
53+
inputs = (
54+
data["logits"].to(self.dtype),
55+
data["targets"],
56+
data["logit_lengths"],
57+
data["target_lengths"],
58+
)
59+
loss = RNNTLoss(blank=data["blank"])
60+
61+
self.assert_grad(loss, inputs, enable_all_grad=False)
62+
63+
@parameterized.expand([
64+
(get_B1_T10_U3_D4_data, ),
65+
(get_numpy_data_B2_T4_U3_D3, ),
66+
(get_numpy_data_B1_T2_U3_D5, ),
67+
])
68+
def test_np_transducer_gradcheck(self, data_func):
69+
data = self.get_data(data_func, self.device)
70+
inputs = (
71+
data["logits"].to(self.dtype),
72+
data["logit_lengths"],
73+
data["target_lengths"],
74+
data["targets"],
75+
)
76+
loss = NumpyTransducerLoss(blank=data["blank"])
77+
78+
self.assert_grad(loss, inputs, enable_all_grad=False)

0 commit comments

Comments
 (0)