Skip to content

Commit b67b2c3

Browse files
committed
Address most of the feedback
1 parent b3fbe41 commit b67b2c3

File tree

3 files changed

+11
-14
lines changed

3 files changed

+11
-14
lines changed
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import torch
21
from torchaudio_unittest.common_utils import PytorchTestCase
3-
from .autograd_test_impl import AutogradTestCase
2+
from .autograd_test_impl import AutogradTestMixin
43

54

6-
class AutogradCPUTest(AutogradTestCase, PytorchTestCase):
5+
class AutogradCPUTest(AutogradTestMixin, PytorchTestCase):
76
device = 'cpu'
8-
dtype = torch.float64
Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
import torch
21
from torchaudio_unittest.common_utils import (
32
PytorchTestCase,
43
skipIfNoCuda,
54
)
6-
from .autograd_test_impl import AutogradTestCase
5+
from .autograd_test_impl import AutogradTestMixin
76

87

98
@skipIfNoCuda
10-
class AutogradCUDATest(AutogradTestCase, PytorchTestCase):
9+
class AutogradCUDATest(AutogradTestMixin, PytorchTestCase):
1110
device = 'cuda'
12-
dtype = torch.float64

test/torchaudio_unittest/transforms/autograd_test_impl.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from parameterized import parameterized
2+
import torch
23
from torch.autograd import gradcheck, gradgradcheck
34
import torchaudio.transforms as T
45

@@ -8,16 +9,16 @@
89
)
910

1011

11-
class AutogradTestCase(TestBaseMixin):
12-
def assert_grad(self, transform, *inputs, eps=1e-06, atol=1e-05, rtol=0.001):
13-
transform = transform.to(self.device, self.dtype)
12+
class AutogradTestMixin(TestBaseMixin):
13+
def assert_grad(self, transform, *inputs):
14+
transform = transform.to(dtype=torch.float64, device=self.device)
1415

1516
inputs_ = []
1617
for i in inputs:
1718
i.requires_grad = True
18-
inputs_.append(i.to(dtype=self.dtype, device=self.device))
19-
assert gradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)
20-
assert gradgradcheck(transform, inputs_, eps=eps, atol=atol, rtol=rtol)
19+
inputs_.append(i.to(dtype=torch.float64, device=self.device))
20+
assert gradcheck(transform, inputs_)
21+
assert gradgradcheck(transform, inputs_)
2122

2223
@parameterized.expand([
2324
({'pad': 0, 'normalized': False, 'power': None}, ),

0 commit comments

Comments
 (0)