diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py
new file mode 100644
index 0000000000..06b6baf5a1
--- /dev/null
+++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py
@@ -0,0 +1,10 @@
+import torch
+
+from torchaudio_unittest.common_utils import PytorchTestCase
+from .utils import skipIfNoTransducer
+from .torchscript_consistency_impl import RNNTLossTorchscript
+
+
+@skipIfNoTransducer
+class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
+    device = torch.device('cpu')
diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py
new file mode 100644
index 0000000000..22b1713582
--- /dev/null
+++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py
@@ -0,0 +1,11 @@
+import torch
+
+from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
+from .utils import skipIfNoTransducer
+from .torchscript_consistency_impl import RNNTLossTorchscript
+
+
+@skipIfNoTransducer
+@skipIfNoCuda
+class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
+    device = torch.device('cuda')
diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py
new file mode 100644
index 0000000000..aeba7e3ae4
--- /dev/null
+++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py
@@ -0,0 +1,70 @@
+import torch
+from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
+from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss
+
+
+class RNNTLossTorchscript(TempDirMixin, TestBaseMixin):
+    """Implements tests for RNNT Loss that are performed for different devices"""
+    def _assert_consistency(self, func, tensor, shape_only=False):
+        tensor = tensor.to(device=self.device, dtype=self.dtype)
+
+        path = self.get_temp_path('func.zip')
+        torch.jit.script(func).save(path)
+        ts_func = torch.jit.load(path)
+
+        torch.random.manual_seed(40)
+        input_tensor = tensor.clone().detach().requires_grad_(True)
+        output = func(input_tensor)
+
+        torch.random.manual_seed(40)
+        input_tensor = tensor.clone().detach().requires_grad_(True)
+        ts_output = ts_func(input_tensor)
+
+        self.assertEqual(ts_output, output)
+
+    def test_rnnt_loss(self):
+        def func(
+            logits,
+        ):
+            targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
+            logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
+            target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
+            return rnnt_loss(logits, targets, logit_lengths, target_lengths)
+
+        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
+                                 [0.1, 0.1, 0.6, 0.1, 0.1],
+                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
+                                [[0.1, 0.6, 0.1, 0.1, 0.1],
+                                 [0.1, 0.1, 0.2, 0.1, 0.1],
+                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
+
+        self._assert_consistency(func, logits)
+
+    def test_RNNTLoss(self):
+        func = RNNTLoss()
+
+        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
+                                 [0.1, 0.1, 0.6, 0.1, 0.1],
+                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
+                                [[0.1, 0.6, 0.1, 0.1, 0.1],
+                                 [0.1, 0.1, 0.2, 0.1, 0.1],
+                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
+        targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32)
+        logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
+        target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
+
+        tensor = logits.to(device=self.device, dtype=self.dtype)
+
+        path = self.get_temp_path('func.zip')
+        torch.jit.script(func).save(path)
+        ts_func = torch.jit.load(path)
+
+        torch.random.manual_seed(40)
+        input_tensor = tensor.clone().detach().requires_grad_(True)
+        output = func(input_tensor, targets, logit_lengths, target_lengths)
+
+        torch.random.manual_seed(40)
+        input_tensor = tensor.clone().detach().requires_grad_(True)
+        ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths)
+
+        self.assertEqual(ts_output, output)
diff --git a/test/torchaudio_unittest/rnnt/utils.py b/test/torchaudio_unittest/rnnt/utils.py
index 8e93d28032..5f4c40379e 100644
--- a/test/torchaudio_unittest/rnnt/utils.py
+++ b/test/torchaudio_unittest/rnnt/utils.py
@@ -405,10 +405,10 @@ def get_numpy_random_data(
 
 
 def numpy_to_torch(data, device, requires_grad=True):
-    logits = torch.from_numpy(data["logits"])
-    targets = torch.from_numpy(data["targets"])
-    logit_lengths = torch.from_numpy(data["logit_lengths"])
-    target_lengths = torch.from_numpy(data["target_lengths"])
+    logits = torch.from_numpy(data["logits"]).to(device=device)
+    targets = torch.from_numpy(data["targets"]).to(device=device)
+    logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
+    target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)
 
     if "nbest_wers" in data:
         data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt
index ebf577eb64..64661f96a5 100644
--- a/torchaudio/csrc/CMakeLists.txt
+++ b/torchaudio/csrc/CMakeLists.txt
@@ -19,6 +19,7 @@ if(BUILD_TRANSDUCER)
     rnnt/compute_alphas.cpp
     rnnt/compute_betas.cpp
     rnnt/compute.cpp
+    rnnt/autograd.cpp
     )
 
   if (USE_CUDA)
diff --git a/torchaudio/csrc/rnnt/autograd.cpp b/torchaudio/csrc/rnnt/autograd.cpp
new file mode 100644
index 0000000000..73ad9f9b3c
--- /dev/null
+++ b/torchaudio/csrc/rnnt/autograd.cpp
@@ -0,0 +1,74 @@
+#include <torch/script.h>
+#include <torchaudio/csrc/rnnt/compute.h>
+
+namespace torchaudio {
+namespace rnnt {
+
+class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
+ public:
+  static torch::autograd::tensor_list forward(
+      torch::autograd::AutogradContext* ctx,
+      torch::Tensor& logits,
+      const torch::Tensor& targets,
+      const torch::Tensor& src_lengths,
+      const torch::Tensor& tgt_lengths,
+      int64_t blank,
+      double clamp,
+      bool fused_log_smax = true,
+      bool reuse_logits_for_grads = true) {
+    at::AutoNonVariableTypeMode g;
+    torch::Tensor undef;
+    auto result = rnnt_loss(
+        logits,
+        targets,
+        src_lengths,
+        tgt_lengths,
+        blank,
+        clamp,
+        fused_log_smax,
+        reuse_logits_for_grads);
+    auto costs = std::get<0>(result);
+    auto grads = std::get<1>(result).value_or(undef);
+    ctx->save_for_backward({grads});
+    return {costs, grads};
+  }
+
+  static torch::autograd::tensor_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto grad = saved[0];
+    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
+    auto result = grad * grad_out;
+    torch::Tensor undef;
+    return {result, undef, undef, undef, undef, undef, undef, undef};
+  }
+};
+
+std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
+    torch::Tensor& logits,
+    const torch::Tensor& targets,
+    const torch::Tensor& src_lengths,
+    const torch::Tensor& tgt_lengths,
+    int64_t blank,
+    double clamp,
+    bool fused_log_smax = true,
+    bool reuse_logits_for_grads = true) {
+  auto results = RNNTLossFunction::apply(
+      logits,
+      targets,
+      src_lengths,
+      tgt_lengths,
+      blank,
+      clamp,
+      fused_log_smax,
+      reuse_logits_for_grads);
+  return std::make_tuple(results[0], results[1]);
+}
+
+TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
+  m.impl("rnnt_loss", rnnt_loss_autograd);
+}
+
+} // namespace rnnt
+} // namespace torchaudio
diff --git a/torchaudio/csrc/rnnt/compute.cpp b/torchaudio/csrc/rnnt/compute.cpp
index bce803fffa..f47f0f505d 100644
--- a/torchaudio/csrc/rnnt/compute.cpp
+++ b/torchaudio/csrc/rnnt/compute.cpp
@@ -1,4 +1,28 @@
 #include <torch/script.h>
+#include <torchaudio/csrc/rnnt/compute.h>
+
+std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
+    torch::Tensor& logits,
+    const torch::Tensor& targets,
+    const torch::Tensor& src_lengths,
+    const torch::Tensor& tgt_lengths,
+    int64_t blank,
+    double clamp,
+    bool fused_log_smax = true,
+    bool reuse_logits_for_grads = true) {
+  static auto op = torch::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
+                       .typed<decltype(rnnt_loss)>();
+  return op.call(
+      logits,
+      targets,
+      src_lengths,
+      tgt_lengths,
+      blank,
+      clamp,
+      fused_log_smax,
+      reuse_logits_for_grads);
+}
 
 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
diff --git a/torchaudio/csrc/rnnt/compute.h b/torchaudio/csrc/rnnt/compute.h
new file mode 100644
index 0000000000..9616d45fc3
--- /dev/null
+++ b/torchaudio/csrc/rnnt/compute.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <torch/script.h>
+
+std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
+    torch::Tensor& logits,
+    const torch::Tensor& targets,
+    const torch::Tensor& src_lengths,
+    const torch::Tensor& tgt_lengths,
+    int64_t blank,
+    double clamp,
+    bool fused_log_smax,
+    bool reuse_logits_for_grads);
diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py
index 08c96ab0f2..40d255b854 100644
--- a/torchaudio/prototype/rnnt_loss.py
+++ b/torchaudio/prototype/rnnt_loss.py
@@ -1,4 +1,5 @@
 import torch
+from torch import Tensor
 
 __all__ = [
     "RNNTLoss",
@@ -19,15 +20,6 @@ def _rnnt_loss_alphas(
 
     See documentation for RNNTLoss
     """
-    targets = targets.to(device=logits.device)
-    logit_lengths = logit_lengths.to(device=logits.device)
-    target_lengths = target_lengths.to(device=logits.device)
-
-    # make sure all int tensors are of type int32.
-    targets = targets.int()
-    logit_lengths = logit_lengths.int()
-    target_lengths = target_lengths.int()
-
     return torch.ops.torchaudio.rnnt_loss_alphas(
         logits,
         targets,
@@ -51,15 +43,6 @@ def _rnnt_loss_betas(
 
     See documentation for RNNTLoss
    """
-    targets = targets.to(device=logits.device)
-    logit_lengths = logit_lengths.to(device=logits.device)
-    target_lengths = target_lengths.to(device=logits.device)
-
-    # make sure all int tensors are of type int32.
-    targets = targets.int()
-    logit_lengths = logit_lengths.int()
-    target_lengths = target_lengths.int()
-
     return torch.ops.torchaudio.rnnt_loss_betas(
         logits,
         targets,
@@ -70,77 +53,15 @@ def _rnnt_loss_betas(
     )
 
 
-class _RNNT(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank=-1,
-        clamp=-1,
-        fused_log_softmax=True,
-        reuse_logits_for_grads=True,
-    ):
-        """
-        See documentation for RNNTLoss
-        """
-
-        # move everything to the same device.
-        targets = targets.to(device=logits.device)
-        logit_lengths = logit_lengths.to(device=logits.device)
-        target_lengths = target_lengths.to(device=logits.device)
-
-        # make sure all int tensors are of type int32.
-        targets = targets.int()
-        logit_lengths = logit_lengths.int()
-        target_lengths = target_lengths.int()
-
-        if blank < 0:  # reinterpret blank index if blank < 0.
-            blank = logits.shape[-1] + blank
-
-        costs, gradients = torch.ops.torchaudio.rnnt_loss(
-            logits=logits,
-            targets=targets,
-            src_lengths=logit_lengths,
-            tgt_lengths=target_lengths,
-            blank=blank,
-            clamp=clamp,
-            fused_log_smax=fused_log_softmax,
-            reuse_logits_for_grads=reuse_logits_for_grads,
-        )
-
-        ctx.grads = gradients
-
-        return costs
-
-    @staticmethod
-    def backward(ctx, output_gradients):
-        output_gradients = output_gradients.view(-1, 1, 1, 1).to(ctx.grads)
-        ctx.grads.mul_(output_gradients).to(ctx.grads)
-
-        return (
-            ctx.grads,  # logits
-            None,  # targets
-            None,  # logit_lengths
-            None,  # target_lengths
-            None,  # blank
-            None,  # clamp
-            None,  # fused_log_softmax
-            None,  # reuse_logits_for_grads
-        )
-
-
 def rnnt_loss(
-    logits,
-    targets,
-    logit_lengths,
-    target_lengths,
-    blank=-1,
-    clamp=-1,
-    fused_log_softmax=True,
-    reuse_logits_for_grads=True,
+    logits: Tensor,
+    targets: Tensor,
+    logit_lengths: Tensor,
+    target_lengths: Tensor,
+    blank: int = -1,
+    clamp: float = -1,
+    fused_log_softmax: bool = True,
+    reuse_logits_for_grads: bool = True,
 ):
     """
     Compute the RNN Transducer Loss.
@@ -166,17 +87,20 @@ def rnnt_loss(
             False  # softmax needs the original logits value
         )
 
-    cost = _RNNT.apply(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax,
-        reuse_logits_for_grads,
-    )
-    return cost
+    if blank < 0:  # reinterpret blank index if blank < 0.
+        blank = logits.shape[-1] + blank
+
+    costs, gradients = torch.ops.torchaudio.rnnt_loss(
+        logits=logits,
+        targets=targets,
+        src_lengths=logit_lengths,
+        tgt_lengths=target_lengths,
+        blank=blank,
+        clamp=clamp,
+        fused_log_smax=fused_log_softmax,
+        reuse_logits_for_grads=reuse_logits_for_grads,)
+
+    return costs
 
 
 class RNNTLoss(torch.nn.Module):
@@ -196,10 +120,10 @@ class RNNTLoss(torch.nn.Module):
 
     def __init__(
         self,
-        blank=-1,
-        clamp=-1,
-        fused_log_softmax=True,
-        reuse_logits_for_grads=True,
+        blank: int = -1,
+        clamp: float = -1.,
+        fused_log_softmax: bool = True,
+        reuse_logits_for_grads: bool = True,
    ):
        super().__init__()
        self.blank = blank
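
A minimal usage sketch (not part of the patch) of the prototype Python API that this change reroutes through the registered torchaudio::rnnt_loss operator and its new autograd kernel. The tensor values mirror the unit tests added above; it assumes torchaudio was built with BUILD_TRANSDUCER enabled.

    import torch
    from torchaudio.prototype.rnnt_loss import rnnt_loss

    # Toy (batch=1, time=2, target_len+1=3, vocab=5) logits, as in the tests above.
    logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                             [0.1, 0.1, 0.6, 0.1, 0.1],
                             [0.1, 0.1, 0.2, 0.8, 0.1]],
                            [[0.1, 0.6, 0.1, 0.1, 0.1],
                             [0.1, 0.1, 0.2, 0.1, 0.1],
                             [0.7, 0.1, 0.2, 0.1, 0.1]]]], requires_grad=True)
    # Targets and lengths are passed as int32 up front; the Python-side
    # device/dtype conversions were removed by this patch.
    targets = torch.tensor([[1, 2]], dtype=torch.int32)
    logit_lengths = torch.tensor([2], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    # rnnt_loss now calls torch.ops.torchaudio.rnnt_loss directly; the backward
    # pass comes from the RNNTLossFunction kernel in rnnt/autograd.cpp.
    costs = rnnt_loss(logits, targets, logit_lengths, target_lengths)
    costs.backward(torch.ones_like(costs))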