From b0740ab114af43d83fd02a9b99f608023a2fa02e Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Wed, 26 May 2021 12:15:52 -0700
Subject: [PATCH 1/5] select autograd test from carolineechen/audio#2

---
 .../rnnt/autograd_cpu_test.py                 | 10 +++
 .../rnnt/autograd_cuda_test.py                | 11 +++
 .../torchaudio_unittest/rnnt/autograd_impl.py | 78 +++++++++++++++++++
 3 files changed, 99 insertions(+)
 create mode 100644 test/torchaudio_unittest/rnnt/autograd_cpu_test.py
 create mode 100644 test/torchaudio_unittest/rnnt/autograd_cuda_test.py
 create mode 100644 test/torchaudio_unittest/rnnt/autograd_impl.py

diff --git a/test/torchaudio_unittest/rnnt/autograd_cpu_test.py b/test/torchaudio_unittest/rnnt/autograd_cpu_test.py
new file mode 100644
index 0000000000..100922e569
--- /dev/null
+++ b/test/torchaudio_unittest/rnnt/autograd_cpu_test.py
@@ -0,0 +1,10 @@
+import torch
+from .autograd_impl import Autograd
+from torchaudio_unittest import common_utils
+from .utils import skipIfNoTransducer
+
+
+@skipIfNoTransducer
+class TestAutograd(Autograd, common_utils.PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device('cpu')
diff --git a/test/torchaudio_unittest/rnnt/autograd_cuda_test.py b/test/torchaudio_unittest/rnnt/autograd_cuda_test.py
new file mode 100644
index 0000000000..cf93e032d9
--- /dev/null
+++ b/test/torchaudio_unittest/rnnt/autograd_cuda_test.py
@@ -0,0 +1,11 @@
+import torch
+from .autograd_impl import Autograd
+from torchaudio_unittest import common_utils
+from .utils import skipIfNoTransducer
+
+
+@skipIfNoTransducer
+@common_utils.skipIfNoCuda
+class TestAutograd(Autograd, common_utils.PytorchTestCase):
+    dtype = torch.float32
+    device = torch.device('cuda')
diff --git a/test/torchaudio_unittest/rnnt/autograd_impl.py b/test/torchaudio_unittest/rnnt/autograd_impl.py
new file mode 100644
index 0000000000..25f48e3bb5
--- /dev/null
+++ b/test/torchaudio_unittest/rnnt/autograd_impl.py
@@ -0,0 +1,78 @@
+from typing import Callable, Tuple
+import torch
+from torch import Tensor
+from torch.autograd import gradcheck
+from torchaudio_unittest.common_utils import (
+    TestBaseMixin,
+)
+from torchaudio.prototype.rnnt_loss import RNNTLoss
+from parameterized import parameterized
+from .utils import (
+    numpy_to_torch,
+    get_B1_T10_U3_D4_data,
+    get_numpy_data_B2_T4_U3_D3,
+    get_numpy_data_B1_T2_U3_D5
+)
+from .numpy_transducer import NumpyTransducerLoss
+
+
+class Autograd(TestBaseMixin):
+    @staticmethod
+    def get_data(data_func, device):
+        data_np = data_func()
+        if isinstance(data_np, tuple):
+            data_np = data_np[0]
+        data = numpy_to_torch(
+            data=data_np, device=device, requires_grad=True
+        )
+        return data
+
+    def assert_grad(
+        self,
+        loss: Callable[..., Tensor],
+        inputs: Tuple[torch.Tensor],
+        *,
+        enable_all_grad: bool = True,
+    ):
+        inputs_ = []
+        for i in inputs:
+            if torch.is_tensor(i):
+                i = i.to(dtype=self.dtype, device=self.device)
+                if enable_all_grad:
+                    i.requires_grad = True
+            inputs_.append(i)
+        assert gradcheck(loss, inputs, eps=1e-03, atol=1e-03, rtol=1e-03, nondet_tol=0.)
+
+    @parameterized.expand([
+        (get_B1_T10_U3_D4_data, ),
+        (get_numpy_data_B2_T4_U3_D3, ),
+        (get_numpy_data_B1_T2_U3_D5, ),
+    ])
+    def test_RNNTLoss_gradcheck(self, data_func):
+        data = self.get_data(data_func, self.device)
+        inputs = (
+            data["logits"].to(self.dtype),
+            data["targets"],
+            data["logit_lengths"],
+            data["target_lengths"],
+        )
+        loss = RNNTLoss(blank=data["blank"])
+
+        self.assert_grad(loss, inputs, enable_all_grad=False)
+
+    @parameterized.expand([
+        (get_B1_T10_U3_D4_data, ),
+        (get_numpy_data_B2_T4_U3_D3, ),
+        (get_numpy_data_B1_T2_U3_D5, ),
+    ])
+    def test_np_transducer_gradcheck(self, data_func):
+        data = self.get_data(data_func, self.device)
+        inputs = (
+            data["logits"].to(self.dtype),
+            data["logit_lengths"],
+            data["target_lengths"],
+            data["targets"],
+        )
+        loss = NumpyTransducerLoss(blank=data["blank"])
+
+        self.assert_grad(loss, inputs, enable_all_grad=False)
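Note on the pattern used above: assert_grad is a thin wrapper around
torch.autograd.gradcheck, which perturbs each input that requires grad and
compares the finite-difference Jacobian against the one produced by backward.
A minimal, self-contained sketch of the same pattern, using a hypothetical toy
op (_Square, not part of the patch) in place of the transducer loss:

    import torch
    from torch.autograd import gradcheck

    class _Square(torch.autograd.Function):
        # Toy op standing in for the loss: forward caches what backward
        # needs, backward returns one gradient per forward input.
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            # Chain rule: scale the local gradient 2x by grad_output,
            # out of place.
            return grad_output.mul(2.0 * x)

    x = torch.randn(4, dtype=torch.double, requires_grad=True)
    assert gradcheck(_Square.apply, (x,), eps=1e-6, atol=1e-5)

RNNTLoss and NumpyTransducerLoss are exercised the same way, with the
parameterized data tuples supplying logits, targets, and lengths.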
From 65c0d592409eebaa6a8ce78f20e268ff4e3c2c51 Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Wed, 26 May 2021 14:36:46 -0700
Subject: [PATCH 2/5] fix numpy backward: be careful not to modify in place.

---
 test/torchaudio_unittest/rnnt/numpy_transducer.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/test/torchaudio_unittest/rnnt/numpy_transducer.py b/test/torchaudio_unittest/rnnt/numpy_transducer.py
index a284bc1a8d..1a907034b8 100644
--- a/test/torchaudio_unittest/rnnt/numpy_transducer.py
+++ b/test/torchaudio_unittest/rnnt/numpy_transducer.py
@@ -33,8 +33,9 @@ def forward(
         return costs
 
     @staticmethod
-    def backward(ctx, output_gradients):
-        return ctx.grads, None, None, None, None, None, None, None, None
+    def backward(ctx, grad_output):
+        grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
+        return ctx.grads.mul(grad_output), None, None, None, None, None, None, None, None
 
     @staticmethod
     def compute_alpha_one_sequence(log_probs, targets, blank=-1):

From fa2956c844f5f67f005d0321528e437e2b382c36 Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Thu, 27 May 2021 09:20:14 -0700
Subject: [PATCH 3/5] gradcheck will fail if input is modified in place.

---
 test/torchaudio_unittest/rnnt/autograd_impl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/torchaudio_unittest/rnnt/autograd_impl.py b/test/torchaudio_unittest/rnnt/autograd_impl.py
index 25f48e3bb5..c2a91719cb 100644
--- a/test/torchaudio_unittest/rnnt/autograd_impl.py
+++ b/test/torchaudio_unittest/rnnt/autograd_impl.py
@@ -56,7 +56,7 @@ def test_RNNTLoss_gradcheck(self, data_func):
             data["logit_lengths"],
             data["target_lengths"],
         )
-        loss = RNNTLoss(blank=data["blank"])
+        loss = RNNTLoss(blank=data["blank"], reuse_logits_for_grads=False)
 
         self.assert_grad(loss, inputs, enable_all_grad=False)
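Both fixes above target the same failure mode: gradcheck evaluates forward and
backward repeatedly, so any buffer it sees, whether an input or a cached
gradient, must not be mutated between calls. Patch 2 makes the numpy reference
loss scale its cached gradients by grad_output out of place, instead of
returning ctx.grads directly, and patch 3 uses reuse_logits_for_grads=False so
the kernel does not write gradients into the logits buffer. A standalone
sketch of the out-of-place chain-rule scaling; the shapes are illustrative
only, not taken from the test data:

    import torch

    # `grads` plays the role of ctx.grads cached by forward: one (T, U, D)
    # gradient per batch element. `grad_output` carries one incoming
    # gradient per sequence cost.
    grads = torch.randn(2, 4, 3, 5)   # (B, T, U, D)
    grad_output = torch.randn(2)      # (B,)

    # Broadcast grad_output over (T, U, D) and scale out of place; the
    # cached buffer is left untouched, so repeated backward calls (as in
    # gradcheck) all see the same ctx.grads.
    scaled = grads.mul(grad_output.view(-1, 1, 1, 1).to(grads))
    assert scaled.shape == grads.shape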
From 4f6ec1e833009450cae6d199469c5fcc0f434460 Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Thu, 3 Jun 2021 08:47:26 -0700
Subject: [PATCH 4/5] add rnnt_loss autograd test too.

---
 .../torchaudio_unittest/rnnt/autograd_impl.py | 22 ++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/test/torchaudio_unittest/rnnt/autograd_impl.py b/test/torchaudio_unittest/rnnt/autograd_impl.py
index c2a91719cb..5cc811634d 100644
--- a/test/torchaudio_unittest/rnnt/autograd_impl.py
+++ b/test/torchaudio_unittest/rnnt/autograd_impl.py
@@ -5,7 +5,7 @@
 from torchaudio_unittest.common_utils import (
     TestBaseMixin,
 )
-from torchaudio.prototype.rnnt_loss import RNNTLoss
+from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
 from parameterized import parameterized
 from .utils import (
     numpy_to_torch,
@@ -60,6 +60,26 @@ def test_RNNTLoss_gradcheck(self, data_func):
 
         self.assert_grad(loss, inputs, enable_all_grad=False)
 
+    @parameterized.expand([
+        (get_B1_T10_U3_D4_data, ),
+        (get_numpy_data_B2_T4_U3_D3, ),
+        (get_numpy_data_B1_T2_U3_D5, ),
+    ])
+    def test_rnnt_loss_gradcheck(self, data_func):
+        data = self.get_data(data_func, self.device)
+        inputs = (
+            data["logits"].to(self.dtype),  # logits
+            data["targets"],  # targets
+            data["logit_lengths"],  # logit_lengths
+            data["target_lengths"],  # target_lengths
+            data["blank"],  # blank
+            -1,  # clamp
+            True,  # fused_log_softmax
+            False,  # reuse_logits_for_grads
+        )
+
+        self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
+
     @parameterized.expand([
         (get_B1_T10_U3_D4_data, ),
         (get_numpy_data_B2_T4_U3_D3, ),

From 3425bac6317613bc100eb7708830f83adcee2288 Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Thu, 3 Jun 2021 14:46:19 -0700
Subject: [PATCH 5/5] leave rtol at its default value.

---
 test/torchaudio_unittest/rnnt/autograd_impl.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/torchaudio_unittest/rnnt/autograd_impl.py b/test/torchaudio_unittest/rnnt/autograd_impl.py
index 5cc811634d..72e0090665 100644
--- a/test/torchaudio_unittest/rnnt/autograd_impl.py
+++ b/test/torchaudio_unittest/rnnt/autograd_impl.py
@@ -41,7 +41,8 @@ def assert_grad(
                 if enable_all_grad:
                     i.requires_grad = True
             inputs_.append(i)
-        assert gradcheck(loss, inputs, eps=1e-03, atol=1e-03, rtol=1e-03, nondet_tol=0.)
+        # gradcheck with float32 requires higher atol and epsilon
+        assert gradcheck(loss, inputs, eps=1e-3, atol=1e-3, nondet_tol=0.)
 
     @parameterized.expand([
         (get_B1_T10_U3_D4_data, ),
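For context on the final tolerance choice: gradcheck accepts a Jacobian entry
as correct when |analytic - numeric| <= atol + rtol * |numeric|, so rtol is
already relative to the gradient's magnitude and only eps and atol need
loosening for the noisier float32 finite differences. A toy comparison, with
torch.tanh standing in for the loss (not part of the patch):

    import torch
    from torch.autograd import gradcheck

    # With float64 inputs, gradcheck's default tolerances usually suffice.
    x64 = torch.randn(3, dtype=torch.float64, requires_grad=True)
    assert gradcheck(torch.tanh, (x64,))

    # With float32, as in these tests, eps and atol are relaxed while rtol
    # stays at its default. gradcheck warns about non-double inputs but
    # still runs the check.
    x32 = torch.randn(3, dtype=torch.float32, requires_grad=True)
    assert gradcheck(torch.tanh, (x32,), eps=1e-3, atol=1e-3, nondet_tol=0.)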