diff --git a/test/torchaudio_unittest/rnnt/autograd_cpu_test.py b/test/torchaudio_unittest/rnnt/autograd_cpu_test.py new file mode 100644 index 0000000000..100922e569 --- /dev/null +++ b/test/torchaudio_unittest/rnnt/autograd_cpu_test.py @@ -0,0 +1,10 @@ +import torch +from .autograd_impl import Autograd +from torchaudio_unittest import common_utils +from .utils import skipIfNoTransducer + + +@skipIfNoTransducer +class TestAutograd(Autograd, common_utils.PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/rnnt/autograd_cuda_test.py b/test/torchaudio_unittest/rnnt/autograd_cuda_test.py new file mode 100644 index 0000000000..cf93e032d9 --- /dev/null +++ b/test/torchaudio_unittest/rnnt/autograd_cuda_test.py @@ -0,0 +1,11 @@ +import torch +from .autograd_impl import Autograd +from torchaudio_unittest import common_utils +from .utils import skipIfNoTransducer + + +@skipIfNoTransducer +@common_utils.skipIfNoCuda +class TestAutograd(Autograd, common_utils.PytorchTestCase): + dtype = torch.float32 + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/rnnt/autograd_impl.py b/test/torchaudio_unittest/rnnt/autograd_impl.py new file mode 100644 index 0000000000..72e0090665 --- /dev/null +++ b/test/torchaudio_unittest/rnnt/autograd_impl.py @@ -0,0 +1,99 @@ +from typing import Callable, Tuple +import torch +from torch import Tensor +from torch.autograd import gradcheck +from torchaudio_unittest.common_utils import ( + TestBaseMixin, +) +from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss +from parameterized import parameterized +from .utils import ( + numpy_to_torch, + get_B1_T10_U3_D4_data, + get_numpy_data_B2_T4_U3_D3, + get_numpy_data_B1_T2_U3_D5 +) +from .numpy_transducer import NumpyTransducerLoss + + +class Autograd(TestBaseMixin): + @staticmethod + def get_data(data_func, device): + data_np = data_func() + if type(data_np) == tuple: + data_np = data_np[0] + data = numpy_to_torch( + data=data_np, device=device, requires_grad=True + ) + return data + + def assert_grad( + self, + loss: Callable[..., Tensor], + inputs: Tuple[torch.Tensor], + *, + enable_all_grad: bool = True, + ): + inputs_ = [] + for i in inputs: + if torch.is_tensor(i): + i = i.to(dtype=self.dtype, device=self.device) + if enable_all_grad: + i.requires_grad = True + inputs_.append(i) + # gradcheck with float32 requires higher atol and epsilon + assert gradcheck(loss, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.) + + @parameterized.expand([ + (get_B1_T10_U3_D4_data, ), + (get_numpy_data_B2_T4_U3_D3, ), + (get_numpy_data_B1_T2_U3_D5, ), + ]) + def test_RNNTLoss_gradcheck(self, data_func): + data = self.get_data(data_func, self.device) + inputs = ( + data["logits"].to(self.dtype), + data["targets"], + data["logit_lengths"], + data["target_lengths"], + ) + loss = RNNTLoss(blank=data["blank"], reuse_logits_for_grads=False) + + self.assert_grad(loss, inputs, enable_all_grad=False) + + @parameterized.expand([ + (get_B1_T10_U3_D4_data, ), + (get_numpy_data_B2_T4_U3_D3, ), + (get_numpy_data_B1_T2_U3_D5, ), + ]) + def test_rnnt_loss_gradcheck(self, data_func): + data = self.get_data(data_func, self.device) + inputs = ( + data["logits"].to(self.dtype), # logits + data["targets"], # targets + data["logit_lengths"], # logit_lengths + data["target_lengths"], # target_lengths + data["blank"], # blank + -1, # clamp + True, # fused_log_softmax + False, # reuse_logits_for_grads + ) + + self.assert_grad(rnnt_loss, inputs, enable_all_grad=False) + + @parameterized.expand([ + (get_B1_T10_U3_D4_data, ), + (get_numpy_data_B2_T4_U3_D3, ), + (get_numpy_data_B1_T2_U3_D5, ), + ]) + def test_np_transducer_gradcheck(self, data_func): + data = self.get_data(data_func, self.device) + inputs = ( + data["logits"].to(self.dtype), + data["logit_lengths"], + data["target_lengths"], + data["targets"], + ) + loss = NumpyTransducerLoss(blank=data["blank"]) + + self.assert_grad(loss, inputs, enable_all_grad=False) diff --git a/test/torchaudio_unittest/rnnt/numpy_transducer.py b/test/torchaudio_unittest/rnnt/numpy_transducer.py index a284bc1a8d..1a907034b8 100644 --- a/test/torchaudio_unittest/rnnt/numpy_transducer.py +++ b/test/torchaudio_unittest/rnnt/numpy_transducer.py @@ -33,8 +33,9 @@ def forward( return costs @staticmethod - def backward(ctx, output_gradients): - return ctx.grads, None, None, None, None, None, None, None, None + def backward(ctx, grad_output): + grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads) + return ctx.grads.mul(grad_output), None, None, None, None, None, None, None, None @staticmethod def compute_alpha_one_sequence(log_probs, targets, blank=-1):