10 changes: 10 additions & 0 deletions test/torchaudio_unittest/rnnt/torchscript_consistency_cpu_test.py
@@ -0,0 +1,10 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cpu')
11 changes: 11 additions & 0 deletions test/torchaudio_unittest/rnnt/torchscript_consistency_cuda_test.py
@@ -0,0 +1,11 @@
import torch

from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .utils import skipIfNoTransducer
from .torchscript_consistency_impl import RNNTLossTorchscript


@skipIfNoTransducer
@skipIfNoCuda
class TestRNNTLoss(RNNTLossTorchscript, PytorchTestCase):
    device = torch.device('cuda')
70 changes: 70 additions & 0 deletions test/torchaudio_unittest/rnnt/torchscript_consistency_impl.py
@@ -0,0 +1,70 @@
import torch
from torchaudio_unittest.common_utils import TempDirMixin, TestBaseMixin
from torchaudio.prototype.rnnt_loss import RNNTLoss, rnnt_loss


class RNNTLossTorchscript(TempDirMixin, TestBaseMixin):
"""Implements test for RNNT Loss that are performed for different devices"""
    def _assert_consistency(self, func, tensor, shape_only=False):
        tensor = tensor.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor)

        self.assertEqual(ts_output, output)

    def test_rnnt_loss(self):
        def func(
            logits,
        ):
            targets = torch.tensor([[1, 2]], device=logits.device, dtype=torch.int32)
            logit_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            target_lengths = torch.tensor([2], device=logits.device, dtype=torch.int32)
            return rnnt_loss(logits, targets, logit_lengths, target_lengths)

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])

        self._assert_consistency(func, logits)

    def test_RNNTLoss(self):
        func = RNNTLoss()

        logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.6, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.8, 0.1]],
                                [[0.1, 0.6, 0.1, 0.1, 0.1],
                                 [0.1, 0.1, 0.2, 0.1, 0.1],
                                 [0.7, 0.1, 0.2, 0.1, 0.1]]]])
        targets = torch.tensor([[1, 2]], device=self.device, dtype=torch.int32)
        logit_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)
        target_lengths = torch.tensor([2], device=self.device, dtype=torch.int32)

        tensor = logits.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(func).save(path)
        ts_func = torch.jit.load(path)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        output = func(input_tensor, targets, logit_lengths, target_lengths)

        torch.random.manual_seed(40)
        input_tensor = tensor.clone().detach().requires_grad_(True)
        ts_output = ts_func(input_tensor, targets, logit_lengths, target_lengths)

        self.assertEqual(ts_output, output)
8 changes: 4 additions & 4 deletions test/torchaudio_unittest/rnnt/utils.py
@@ -405,10 +405,10 @@ def get_numpy_random_data(


def numpy_to_torch(data, device, requires_grad=True):
-    logits = torch.from_numpy(data["logits"])
-    targets = torch.from_numpy(data["targets"])
-    logit_lengths = torch.from_numpy(data["logit_lengths"])
-    target_lengths = torch.from_numpy(data["target_lengths"])
+    logits = torch.from_numpy(data["logits"]).to(device=device)
+    targets = torch.from_numpy(data["targets"]).to(device=device)
+    logit_lengths = torch.from_numpy(data["logit_lengths"]).to(device=device)
+    target_lengths = torch.from_numpy(data["target_lengths"]).to(device=device)

     if "nbest_wers" in data:
         data["nbest_wers"] = torch.from_numpy(data["nbest_wers"]).to(device=device)
1 change: 1 addition & 0 deletions torchaudio/csrc/CMakeLists.txt
@@ -19,6 +19,7 @@ if(BUILD_TRANSDUCER)
     rnnt/compute_alphas.cpp
     rnnt/compute_betas.cpp
     rnnt/compute.cpp
+    rnnt/autograd.cpp
 )

 if (USE_CUDA)
74 changes: 74 additions & 0 deletions torchaudio/csrc/rnnt/autograd.cpp
@@ -0,0 +1,74 @@
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

namespace torchaudio {
namespace rnnt {

class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor& logits,
      const torch::Tensor& targets,
      const torch::Tensor& src_lengths,
      const torch::Tensor& tgt_lengths,
      int64_t blank,
      double clamp,
      bool fused_log_smax = true,
      bool reuse_logits_for_grads = true) {
    // Keep the nested rnnt_loss call out of autograd dispatch; the kernel
    // itself returns the gradients that backward() will reuse.
    at::AutoNonVariableTypeMode g;
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits,
        targets,
        src_lengths,
        tgt_lengths,
        blank,
        clamp,
        fused_log_smax,
        reuse_logits_for_grads);
    auto costs = std::get<0>(result);
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
    return {costs, grads};
  }

  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    // Chain rule: broadcast grad_output over the (src, tgt, vocab) dims and
    // scale the gradients cached in forward().
    auto saved = ctx->get_saved_variables();
    auto grad = saved[0];
    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
    auto result = grad * grad_out;
    torch::Tensor undef;
    return {result, undef, undef, undef, undef, undef, undef, undef};
  }
};

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  auto results = RNNTLossFunction::apply(
      logits,
      targets,
      src_lengths,
      tgt_lengths,
      blank,
      clamp,
      fused_log_smax,
      reuse_logits_for_grads);
  return std::make_tuple(results[0], results[1]);
}

TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
  m.impl("rnnt_loss", rnnt_loss_autograd);
}

} // namespace rnnt
} // namespace torchaudio
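
Registering rnnt_loss_autograd under the Autograd dispatch key means any call that reaches torchaudio::rnnt_loss through the dispatcher picks up this torch::autograd::Function, which caches the gradients the forward kernel already computed and replays them in backward. A minimal sketch of what this enables from Python, assuming a build with BUILD_TRANSDUCER and the prototype API imported in the tests above; the shapes mirror the unit tests (logits are batch x max_src_len x (max_tgt_len + 1) x vocab):

    # Sketch only: gradients flow through the fused kernel via the Autograd wrapper.
    import torch
    from torchaudio.prototype.rnnt_loss import rnnt_loss

    logits = torch.rand(1, 2, 3, 5, requires_grad=True)
    targets = torch.tensor([[1, 2]], dtype=torch.int32)
    logit_lengths = torch.tensor([2], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    costs = rnnt_loss(logits, targets, logit_lengths, target_lengths)
    costs.sum().backward()    # backward() scales the cached grads by grad_output
    print(logits.grad.shape)  # expected: torch.Size([1, 2, 3, 5])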
24 changes: 24 additions & 0 deletions torchaudio/csrc/rnnt/compute.cpp
@@ -1,4 +1,28 @@
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  // Resolve the operator schema once, then redispatch so the dispatcher can
  // route the call to the CPU/CUDA kernel or the Autograd wrapper.
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                       .typed<decltype(rnnt_loss)>();
  return op.call(
      logits,
      targets,
      src_lengths,
      tgt_lengths,
      blank,
      clamp,
      fused_log_smax,
      reuse_logits_for_grads);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
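
The rnnt_loss wrapper above is a boxed redispatch: the same entry point serves TorchScript, eager CPU/CUDA kernels, and the Autograd wrapper from autograd.cpp. A hedged sketch of invoking the registered op directly from Python; the blank index (0) and clamp value (-1.0, read here as "disabled") are illustrative assumptions, not values taken from this diff:

    # Sketch only: call the dispatcher-registered op. Importing torchaudio loads
    # the extension whose TORCH_LIBRARY_FRAGMENT block registers the schema.
    import torch
    import torchaudio  # noqa: F401

    logits = torch.rand(1, 2, 3, 5)
    targets = torch.tensor([[1, 2]], dtype=torch.int32)
    logit_lengths = torch.tensor([2], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    # Positional args follow the C++ signature: logits, targets, src_lengths,
    # tgt_lengths, blank, clamp, fused_log_smax, reuse_logits_for_grads.
    costs, grads = torch.ops.torchaudio.rnnt_loss(
        logits, targets, logit_lengths, target_lengths, 0, -1.0, True, True)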
13 changes: 13 additions & 0 deletions torchaudio/csrc/rnnt/compute.h
@@ -0,0 +1,13 @@
#pragma once

#include <torch/script.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax,
    bool reuse_logits_for_grads);