From c0ff6e3fac11bf22d043232fb1bb0b433ecbbc5f Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Fri, 14 May 2021 23:34:02 +0000 Subject: [PATCH 1/4] move autograd to c++, add torchscript tests --- .../rnnt/torchscript_consistency_cpu_test.py | 10 + .../rnnt/torchscript_consistency_cuda_test.py | 10 + .../rnnt/torchscript_consistency_impl.py | 73 ++++++ torchaudio/csrc/CMakeLists.txt | 1 + torchaudio/csrc/rnnt/autograd.cpp | 70 ++++++ torchaudio/csrc/rnnt/compute.cpp | 17 ++ torchaudio/csrc/rnnt/compute.h | 13 + torchaudio/prototype/rnnt_loss.py | 225 +++--------------- 8 files changed, 231 insertions(+), 188 deletions(-) create mode 100644 test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py create mode 100644 test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py create mode 100644 test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py create mode 100644 torchaudio/csrc/rnnt/autograd.cpp create mode 100644 torchaudio/csrc/rnnt/compute.h diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py new file mode 100644 index 0000000000..06b6baf5a1 --- /dev/null +++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py @@ -0,0 +1,10 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .utils import skipIfNoTransducer +from .torchscript_consistency_impl import RNNTLossTorchscript + + +@skipIfNoTransducer +class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase): + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py new file mode 100644 index 0000000000..366ec3e46f --- /dev/null +++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py @@ -0,0 +1,10 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .utils import skipIfNoTransducer +from .torchscript_consistency_impl import RNNTLossTorchscript + + +@skipIfNoTransducer +class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase): + device = torch.device('cuda') diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py new file mode 100644 index 0000000000..d2382d4d90 --- /dev/null +++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py @@ -0,0 +1,73 @@ +import torch +from torch import Tensor +from torchaudio_unittest import common_utils +from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin +from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss +from .utils import get_B1_T10_U3_D4_data, numpy_to_torch + +class RNNTLossTorchscript(TempDirMixin, TestBaseMixin): + """Implements test for RNNT Loss that are performed for different devices""" + def _assert_consistency(self, func, tensor, shape_only=False): + tensor = tensor.to(device=self.device, dtype=self.dtype) + + path = self.get_temp_path('func.zip') + torch.jit.script(func).save(path) + ts_func = torch.jit.load(path) + + torch.random.manual_seed(40) + input_tensor = tensor.clone().detach().requires_grad_(True) + output = func(input_tensor) + + torch.random.manual_seed(40) + input_tensor = tensor.clone().detach().requires_grad_(True) + ts_output = ts_func(input_tensor) + + self.assertEqual(ts_output, output) + + def test_rnnt_loss(self): + def func( + logits, + ): + targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32) + 
logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) + target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) + return rnnt_loss(logits, targets, logit_lengths, target_lengths) + + logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1]], + [[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1]]]]) + + self._assert_consistency(func, logits) + + def test_RNNTLoss(self): + func = RNNTLoss() + + logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.6, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.8, 0.1]], + [[0.1, 0.6, 0.1, 0.1, 0.1], + [0.1, 0.1, 0.2, 0.1, 0.1], + [0.7, 0.1, 0.2, 0.1, 0.1]]]]) + targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32) + logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) + target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) + + tensor = logits.to(device=self.device, dtype=self.dtype) + + path = self.get_temp_path('func.zip') + torch.jit.script(func).save(path) + ts_func = torch.jit.load(path) + + torch.random.manual_seed(40) + input_tensor = tensor.clone().detach().requires_grad_(True) + output = func(input_tensor, targets, logit_lengths, target_lengths) + + torch.random.manual_seed(40) + input_tensor = tensor.clone().detach().requires_grad_(True) + ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths) + + self.assertEqual(ts_output, output) + \ No newline at end of file diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt index ebf577eb64..64661f96a5 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -19,6 +19,7 @@ if(BUILD_TRANSDUCER) rnnt/compute_alphas.cpp rnnt/compute_betas.cpp rnnt/compute.cpp + rnnt/autograd.cpp ) if (USE_CUDA) diff --git a/torchaudio/csrc/rnnt/autograd.cpp b/torchaudio/csrc/rnnt/autograd.cpp new file mode 100644 index 0000000000..aadce48ddf --- /dev/null +++ b/torchaudio/csrc/rnnt/autograd.cpp @@ -0,0 +1,70 @@ +#include +#include + +namespace torchaudio { +namespace rnnt { + +class RNNTLossFunction : public torch::autograd::Function { + public: + static torch::autograd::tensor_list forward( + torch::autograd::AutogradContext* ctx, + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp, + bool fused_log_smax = true, + bool reuse_logits_for_grads = true) { + at::AutoNonVariableTypeMode g; + torch::Tensor undef; + auto result = rnnt_loss(logits, targets, src_lengths, tgt_lengths, blank, clamp, + fused_log_smax, reuse_logits_for_grads); + auto costs = std::get<0>(result); + auto grads = std::get<1>(result).value_or(undef); + std::cout << grads << std::endl; + ctx->save_for_backward({grads}); + return {costs, grads}; + } + + static torch::autograd::tensor_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + auto grad = saved[0];auto grad_out = grad_outputs[0].view({-1, 1, 1, 1}); + auto result = grad * grad_out; + torch::Tensor undef; + return { + result, + undef, + undef, + undef, + undef, + undef, + undef, + undef}; + } +}; + + +std::tuple> rnnt_loss_autograd( + // torch::Tensor rnnt_loss_autograd( + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp, + bool fused_log_smax 
= true, + bool reuse_logits_for_grads = true) { + auto results = RNNTLossFunction::apply(logits, targets, src_lengths, tgt_lengths, blank, clamp, + fused_log_smax, reuse_logits_for_grads); + return std::make_tuple(results[0], results[1]); +} + +TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) { + m.impl("rnnt_loss", rnnt_loss_autograd); +} + +} // namespace rnnt +} // namespace torchaudio diff --git a/torchaudio/csrc/rnnt/compute.cpp b/torchaudio/csrc/rnnt/compute.cpp index bce803fffa..af4a32291d 100644 --- a/torchaudio/csrc/rnnt/compute.cpp +++ b/torchaudio/csrc/rnnt/compute.cpp @@ -1,4 +1,21 @@ #include +#include + +std::tuple> rnnt_loss( + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp, + bool fused_log_smax = true, + bool reuse_logits_for_grads = true) { + static auto op = torch::Dispatcher::singleton() + .findSchemaOrThrow("torchaudio::rnnt_loss", "") + .typed(); + return op.call(logits, targets, src_lengths, tgt_lengths, blank, clamp, + fused_log_smax, reuse_logits_for_grads); +} TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( diff --git a/torchaudio/csrc/rnnt/compute.h b/torchaudio/csrc/rnnt/compute.h new file mode 100644 index 0000000000..9616d45fc3 --- /dev/null +++ b/torchaudio/csrc/rnnt/compute.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +std::tuple> rnnt_loss( + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp, + bool fused_log_smax, + bool reuse_logits_for_grads); diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py index 8246bbe22f..b392a34cad 100644 --- a/torchaudio/prototype/rnnt_loss.py +++ b/torchaudio/prototype/rnnt_loss.py @@ -1,4 +1,5 @@ import torch +from torch import Tensor __all__ = [ "RNNTLoss", @@ -70,89 +71,15 @@ def _rnnt_loss_betas( ) -class _RNNT(torch.autograd.Function): - @staticmethod - def forward( - ctx, - logits, - targets, - logit_lengths, - target_lengths, - blank=-1, - clamp=-1, - runtime_check=False, - fused_log_softmax=True, - reuse_logits_for_grads=True, - ): - """ - See documentation for RNNTLoss - """ - - # move everything to the same device. - targets = targets.to(device=logits.device) - logit_lengths = logit_lengths.to(device=logits.device) - target_lengths = target_lengths.to(device=logits.device) - - # make sure all int tensors are of type int32. - targets = targets.int() - logit_lengths = logit_lengths.int() - target_lengths = target_lengths.int() - - if blank < 0: # reinterpret blank index if blank < 0. 
- blank = logits.shape[-1] + blank - - if runtime_check: - check_inputs( - logits=logits, - targets=targets, - logit_lengths=logit_lengths, - target_lengths=target_lengths, - blank=blank, - ) - - costs, gradients = torch.ops.torchaudio.rnnt_loss( - logits=logits, - targets=targets, - src_lengths=logit_lengths, - tgt_lengths=target_lengths, - blank=blank, - clamp=clamp, - fused_log_smax=fused_log_softmax, - reuse_logits_for_grads=reuse_logits_for_grads, - ) - - ctx.grads = gradients - - return costs - - @staticmethod - def backward(ctx, output_gradients): - output_gradients = output_gradients.view(-1, 1, 1, 1).to(ctx.grads) - ctx.grads.mul_(output_gradients).to(ctx.grads) - - return ( - ctx.grads, # logits - None, # targets - None, # logit_lengths - None, # target_lengths - None, # blank - None, # clamp - None, # runtime_check - None, # fused_log_softmax - None, # reuse_logits_for_grads - ) - - def rnnt_loss( - logits, - targets, - logit_lengths, - target_lengths, - blank=-1, - clamp=-1, - runtime_check=False, - fused_log_softmax=True, - reuse_logits_for_grads=True, + logits: Tensor, + targets: Tensor, + logit_lengths: Tensor, + target_lengths: Tensor, + blank: int = -1, + clamp: float = -1, + fused_log_softmax: bool = True, + reuse_logits_for_grads: bool = True, ): """ Compute the RNN Transducer Loss. @@ -178,18 +105,31 @@ def rnnt_loss( False # softmax needs the original logits value ) - cost = _RNNT.apply( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - runtime_check, - fused_log_softmax, - reuse_logits_for_grads, + # move everything to the same device. + targets = targets.to(device=logits.device) + logit_lengths = logit_lengths.to(device=logits.device) + target_lengths = target_lengths.to(device=logits.device) + + # make sure all int tensors are of type int32. + targets = targets.int() + logit_lengths = logit_lengths.int() + target_lengths = target_lengths.int() + + if blank < 0: # reinterpret blank index if blank < 0. + blank = logits.shape[-1] + blank + + costs, gradients = torch.ops.torchaudio.rnnt_loss( + logits=logits, + targets=targets, + src_lengths=logit_lengths, + tgt_lengths=target_lengths, + blank=blank, + clamp=clamp, + fused_log_smax=fused_log_softmax, + reuse_logits_for_grads=reuse_logits_for_grads, ) - return cost + + return costs class RNNTLoss(torch.nn.Module): @@ -203,23 +143,20 @@ class RNNTLoss(torch.nn.Module): Args: blank (int, opt): blank label (Default: ``-1``) clamp (float): clamp for gradients (Default: ``-1``) - runtime_check (bool): whether to do sanity check during runtime. 
(Default: ``False``) fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``) reuse_logits_for_grads (bool): whether to save memory by reusing logits memory for grads (Default: ``True``) """ def __init__( self, - blank=-1, - clamp=-1, - runtime_check=False, - fused_log_softmax=True, - reuse_logits_for_grads=True, + blank: int = -1, + clamp: float = -1., + fused_log_softmax: bool = True, + reuse_logits_for_grads: bool = True, ): super().__init__() self.blank = blank self.clamp = clamp - self.runtime_check = runtime_check self.fused_log_softmax = fused_log_softmax self.reuse_logits_for_grads = reuse_logits_for_grads @@ -244,94 +181,6 @@ def forward( target_lengths, self.blank, self.clamp, - self.runtime_check, self.fused_log_softmax, self.reuse_logits_for_grads, ) - - -def check_type(var, t, name): - if var.dtype is not t: - raise TypeError("{} must be {}".format(name, t)) - - -def check_contiguous(var, name): - if not var.is_contiguous(): - raise ValueError("{} must be contiguous".format(name)) - - -def check_dim(var, dim, name): - if len(var.shape) != dim: - raise ValueError("{} must be {}D".format(name, dim)) - - -def check_equal(var1, name1, var2, name2): - if var1 != var2: - raise ValueError( - "`{}` ({}) must equal to ".format(name1, var1) - + "`{}` ({})".format(name2, var2) - ) - - -def check_device(var1, name1, var2, name2): - if var1.device != var2.device: - raise ValueError( - "`{}` ({}) must be on the same ".format(name1, var1.device.type) - + "device as `{}` ({})".format(name2, var2.device.type) - ) - - -def check_inputs(logits, targets, logit_lengths, target_lengths, blank): - check_device(logits, "logits", targets, "targets") - check_device(logits, "logits", targets, "logit_lengths") - check_device(logits, "logits", targets, "target_lengths") - - check_type(logits, torch.float32, "logits") - check_type(targets, torch.int32, "targets") - check_type(logit_lengths, torch.int32, "logit_lengths") - check_type(target_lengths, torch.int32, "target_lengths") - - check_contiguous(logits, "logits") - check_contiguous(targets, "targets") - check_contiguous(target_lengths, "target_lengths") - check_contiguous(logit_lengths, "logit_lengths") - - check_dim(logits, 4, "logits") - check_dim(targets, 2, "targets") - check_dim(logit_lengths, 1, "logit_lengths") - check_dim(target_lengths, 1, "target_lengths") - - check_equal( - logit_lengths.shape[0], "logit_lengths.shape[0]", logits.shape[0], "logits.shape[0]" - ) - check_equal( - target_lengths.shape[0], "target_lengths.shape[0]", logits.shape[0], "logits.shape[0]" - ) - check_equal( - targets.shape[0], "targets.shape[0]", logits.shape[0], "logits.shape[0]" - ) - check_equal( - targets.shape[1], - "targets.shape[1]", - torch.max(target_lengths), - "torch.max(target_lengths)", - ) - check_equal( - logits.shape[1], - "logits.shape[1]", - torch.max(logit_lengths), - "torch.max(logit_lengths)", - ) - check_equal( - logits.shape[2], - "logits.shape[2]", - torch.max(target_lengths) + 1, - "torch.max(target_lengths) + 1", - ) - - if blank < 0 or blank >= logits.shape[-1]: - raise ValueError( - "blank ({}) must be within [0, logits.shape[-1]={})".format( - blank, logits.shape[-1] - ) - ) From 86ed12aa2c587e2e80a8a12c561d8c212aa6ef04 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Mon, 17 May 2021 14:24:35 +0000 Subject: [PATCH 2/4] fixes --- .../rnnt/torchscript_consistency_cuda_test.py | 3 +- .../rnnt/torchscript_consistency_impl.py | 15 ++-- torchaudio/csrc/rnnt/autograd.cpp | 76 ++++++++++--------- 
torchaudio/csrc/rnnt/compute.cpp | 15 +++- torchaudio/prototype/rnnt_loss.py | 1 + 5 files changed, 60 insertions(+), 50 deletions(-) diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py index 366ec3e46f..22b1713582 100644 --- a/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py +++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py @@ -1,10 +1,11 @@ import torch -from torchaudio_unittest.common_utils import PytorchTestCase +from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda from .utils import skipIfNoTransducer from .torchscript_consistency_impl import RNNTLossTorchscript @skipIfNoTransducer +@skipIfNoCuda class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase): device = torch.device('cuda') diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py index d2382d4d90..770edfaa7a 100644 --- a/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py @@ -1,9 +1,7 @@ import torch -from torch import Tensor -from torchaudio_unittest import common_utils from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss -from .utils import get_B1_T10_U3_D4_data, numpy_to_torch + class RNNTLossTorchscript(TempDirMixin, TestBaseMixin): """Implements test for RNNT Loss that are performed for different devices""" @@ -24,7 +22,7 @@ def _assert_consistency(self, func, tensor, shape_only=False): self.assertEqual(ts_output, output) - def test_rnnt_loss(self): + def test_rnnt_loss(self): def func( logits, ): @@ -32,23 +30,23 @@ def func( logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) return rnnt_loss(logits, targets, logit_lengths, target_lengths) - + logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]], - [[0.1, 0.6, 0.1, 0.1, 0.1], + [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]]]]) self._assert_consistency(func, logits) - def test_RNNTLoss(self): + def test_RNNTLoss(self): func = RNNTLoss() logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1], [0.1, 0.1, 0.2, 0.8, 0.1]], - [[0.1, 0.6, 0.1, 0.1, 0.1], + [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]]]]) targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32) @@ -70,4 +68,3 @@ def test_RNNTLoss(self): ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths) self.assertEqual(ts_output, output) - \ No newline at end of file diff --git a/torchaudio/csrc/rnnt/autograd.cpp b/torchaudio/csrc/rnnt/autograd.cpp index aadce48ddf..73ad9f9b3c 100644 --- a/torchaudio/csrc/rnnt/autograd.cpp +++ b/torchaudio/csrc/rnnt/autograd.cpp @@ -5,50 +5,47 @@ namespace torchaudio { namespace rnnt { class RNNTLossFunction : public torch::autograd::Function { - public: - static torch::autograd::tensor_list forward( - torch::autograd::AutogradContext* ctx, - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& src_lengths, - const torch::Tensor& tgt_lengths, - int64_t blank, - double clamp, - bool fused_log_smax = true, - bool reuse_logits_for_grads = true) { - at::AutoNonVariableTypeMode g; - torch::Tensor 
undef; - auto result = rnnt_loss(logits, targets, src_lengths, tgt_lengths, blank, clamp, - fused_log_smax, reuse_logits_for_grads); - auto costs = std::get<0>(result); - auto grads = std::get<1>(result).value_or(undef); - std::cout << grads << std::endl; - ctx->save_for_backward({grads}); - return {costs, grads}; - } + public: + static torch::autograd::tensor_list forward( + torch::autograd::AutogradContext* ctx, + torch::Tensor& logits, + const torch::Tensor& targets, + const torch::Tensor& src_lengths, + const torch::Tensor& tgt_lengths, + int64_t blank, + double clamp, + bool fused_log_smax = true, + bool reuse_logits_for_grads = true) { + at::AutoNonVariableTypeMode g; + torch::Tensor undef; + auto result = rnnt_loss( + logits, + targets, + src_lengths, + tgt_lengths, + blank, + clamp, + fused_log_smax, + reuse_logits_for_grads); + auto costs = std::get<0>(result); + auto grads = std::get<1>(result).value_or(undef); + ctx->save_for_backward({grads}); + return {costs, grads}; + } static torch::autograd::tensor_list backward( torch::autograd::AutogradContext* ctx, torch::autograd::tensor_list grad_outputs) { auto saved = ctx->get_saved_variables(); - auto grad = saved[0];auto grad_out = grad_outputs[0].view({-1, 1, 1, 1}); + auto grad = saved[0]; + auto grad_out = grad_outputs[0].view({-1, 1, 1, 1}); auto result = grad * grad_out; torch::Tensor undef; - return { - result, - undef, - undef, - undef, - undef, - undef, - undef, - undef}; + return {result, undef, undef, undef, undef, undef, undef, undef}; } }; - std::tuple> rnnt_loss_autograd( - // torch::Tensor rnnt_loss_autograd( torch::Tensor& logits, const torch::Tensor& targets, const torch::Tensor& src_lengths, @@ -57,13 +54,20 @@ std::tuple> rnnt_loss_autograd( double clamp, bool fused_log_smax = true, bool reuse_logits_for_grads = true) { - auto results = RNNTLossFunction::apply(logits, targets, src_lengths, tgt_lengths, blank, clamp, - fused_log_smax, reuse_logits_for_grads); + auto results = RNNTLossFunction::apply( + logits, + targets, + src_lengths, + tgt_lengths, + blank, + clamp, + fused_log_smax, + reuse_logits_for_grads); return std::make_tuple(results[0], results[1]); } TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) { - m.impl("rnnt_loss", rnnt_loss_autograd); + m.impl("rnnt_loss", rnnt_loss_autograd); } } // namespace rnnt diff --git a/torchaudio/csrc/rnnt/compute.cpp b/torchaudio/csrc/rnnt/compute.cpp index af4a32291d..f47f0f505d 100644 --- a/torchaudio/csrc/rnnt/compute.cpp +++ b/torchaudio/csrc/rnnt/compute.cpp @@ -11,10 +11,17 @@ std::tuple> rnnt_loss( bool fused_log_smax = true, bool reuse_logits_for_grads = true) { static auto op = torch::Dispatcher::singleton() - .findSchemaOrThrow("torchaudio::rnnt_loss", "") - .typed(); - return op.call(logits, targets, src_lengths, tgt_lengths, blank, clamp, - fused_log_smax, reuse_logits_for_grads); + .findSchemaOrThrow("torchaudio::rnnt_loss", "") + .typed(); + return op.call( + logits, + targets, + src_lengths, + tgt_lengths, + blank, + clamp, + fused_log_smax, + reuse_logits_for_grads); } TORCH_LIBRARY_FRAGMENT(torchaudio, m) { diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py index 4403e7cbda..d27449cc0e 100644 --- a/torchaudio/prototype/rnnt_loss.py +++ b/torchaudio/prototype/rnnt_loss.py @@ -70,6 +70,7 @@ def _rnnt_loss_betas( clamp, ) + def rnnt_loss( logits: Tensor, targets: Tensor, From 148385486b1d8616a9bc558c8e38aa5397109f36 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Wed, 19 May 2021 17:23:22 +0000 Subject: [PATCH 3/4] remove 
tensor moves --- torchaudio/prototype/rnnt_loss.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py index d27449cc0e..40d255b854 100644 --- a/torchaudio/prototype/rnnt_loss.py +++ b/torchaudio/prototype/rnnt_loss.py @@ -20,15 +20,6 @@ def _rnnt_loss_alphas( See documentation for RNNTLoss """ - targets = targets.to(device=logits.device) - logit_lengths = logit_lengths.to(device=logits.device) - target_lengths = target_lengths.to(device=logits.device) - - # make sure all int tensors are of type int32. - targets = targets.int() - logit_lengths = logit_lengths.int() - target_lengths = target_lengths.int() - return torch.ops.torchaudio.rnnt_loss_alphas( logits, targets, @@ -52,15 +43,6 @@ def _rnnt_loss_betas( See documentation for RNNTLoss """ - targets = targets.to(device=logits.device) - logit_lengths = logit_lengths.to(device=logits.device) - target_lengths = target_lengths.to(device=logits.device) - - # make sure all int tensors are of type int32. - targets = targets.int() - logit_lengths = logit_lengths.int() - target_lengths = target_lengths.int() - return torch.ops.torchaudio.rnnt_loss_betas( logits, targets, @@ -105,16 +87,6 @@ def rnnt_loss( False # softmax needs the original logits value ) - # move everything to the same device. - targets = targets.to(device=logits.device) - logit_lengths = logit_lengths.to(device=logits.device) - target_lengths = target_lengths.to(device=logits.device) - - # make sure all int tensors are of type int32. - targets = targets.int() - logit_lengths = logit_lengths.int() - target_lengths = target_lengths.int() - if blank < 0: # reinterpret blank index if blank < 0. blank = logits.shape[-1] + blank From ccec70ef30aa2a1a37573df23168ec35d46a71c9 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Wed, 19 May 2021 18:59:07 +0000 Subject: [PATCH 4/4] fix test device --- .../rnnt/torchscript_consistency_impl.py | 6 +++--- test/torchaudio_unittest/rnnt/utils.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py index 770edfaa7a..aeba7e3ae4 100644 --- a/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py @@ -49,9 +49,9 @@ def test_RNNTLoss(self): [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]]]]) - targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32) - logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) - target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32) + targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32) + logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32) + target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32) tensor = logits.to(device=self.device, dtype=self.dtype) diff --git a/test/torchaudio_unittest/rnnt/utils.py b/test/torchaudio_unittest/rnnt/utils.py index 8e93d28032..5f4c40379e 100644 --- a/test/torchaudio_unittest/rnnt/utils.py +++ b/test/torchaudio_unittest/rnnt/utils.py @@ -405,10 +405,10 @@ def get_numpy_random_data( def numpy_to_torch(data, device, requires_grad=True): - logits = torch.from_numpy(data["logits"]) - targets = torch.from_numpy(data["targets"]) - logit_lengths = torch.from_numpy(data["logit_lengths"]) - target_lengths = 
torch.from_numpy(data["target_lengths"]) + logits = torch.from_numpy(data["logits"]).to(device=device) + targets = torch.from_numpy(data["targets"]).to(device=device) + logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device) + target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device) if "nbest_wers" in data: data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
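For reference, a minimal usage sketch of the prototype API as it stands after this series. This is not part of the patches above: the tensor shapes and values mirror the new TorchScript consistency tests, and it assumes a torchaudio build with the transducer extension (BUILD_TRANSDUCER) enabled. Note that after PATCH 3/4 the Python wrapper no longer moves or casts the auxiliary tensors, so targets and lengths must already be int32 and on the same device as logits.

    import torch
    from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss

    # (batch, max_logit_length, max_target_length + 1, num_classes), float32 logits.
    logits = torch.rand(1, 2, 3, 5, requires_grad=True)
    targets = torch.tensor([[1, 2]], dtype=torch.int32)
    logit_lengths = torch.tensor([2], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    # Functional form; the backward pass is now handled by the C++ RNNTLossFunction.
    costs = rnnt_loss(logits, targets, logit_lengths, target_lengths)
    costs.sum().backward()

    # Module form is TorchScript-able, as exercised by the new consistency tests.
    scripted = torch.jit.script(RNNTLoss())
    costs = scripted(logits, targets, logit_lengths, target_lengths)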