1 change: 0 additions & 1 deletion test/torchaudio_unittest/rnnt/autograd_impl.py
@@ -71,7 +71,6 @@ def test_rnnt_loss_gradcheck(self, data_func):
data["target_lengths"], # target_lengths
data["blank"], # blank
-1, # clamp
True, # fused_log_softmax
)

self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
16 changes: 0 additions & 16 deletions test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
@@ -5,7 +5,6 @@
compute_with_numpy_transducer,
compute_with_pytorch_transducer,
get_basic_data,
get_B1_T10_U3_D4_data,
get_B1_T2_U3_D5_data,
get_B2_T4_U3_D3_data,
get_random_data,
@@ -80,18 +79,3 @@ def test_costs_and_gradients_random_data_with_numpy_fp32(self):
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)

def test_rnnt_nonfused_log_softmax(self):
for random in [False, True]:
data = get_B1_T10_U3_D4_data(
random=random,
dtype=torch.float32,
device=self.device,
)
data["fused_log_softmax"] = False
ref_costs, ref_gradients = compute_with_numpy_transducer(
data=data
)
self._test_costs_and_gradients(
data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
)
1 change: 0 additions & 1 deletion test/torchaudio_unittest/rnnt/utils.py
@@ -26,7 +26,6 @@ def compute_with_numpy_transducer(data):
def compute_with_pytorch_transducer(data):
costs = RNNTLoss(
blank=data["blank"],
fused_log_softmax=data.get("fused_log_softmax", True),
reduction="none",
)(
logits=data["logits"],
24 changes: 5 additions & 19 deletions torchaudio/csrc/rnnt/autograd.cpp
@@ -13,17 +13,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true) {
double clamp) {
torch::Tensor undef;
auto result = rnnt_loss(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_softmax);
auto result =
rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp);
auto costs = std::get<0>(result);
auto grads = std::get<1>(result).value_or(undef);
ctx->save_for_backward({grads});
@@ -48,17 +41,10 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true) {
double clamp) {
at::AutoDispatchBelowADInplaceOrView guard;
auto results = RNNTLossFunction::apply(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_softmax);
logits, targets, logit_lengths, target_lengths, blank, clamp);
return std::make_tuple(results[0], results[1]);
}

15 changes: 3 additions & 12 deletions torchaudio/csrc/rnnt/compute.cpp
@@ -7,19 +7,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true) {
double clamp) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("torchaudio::rnnt_loss", "")
.typed<decltype(rnnt_loss)>();
return op.call(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_softmax);
return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
@@ -29,6 +21,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
"Tensor logit_lengths,"
"Tensor target_lengths,"
"int blank,"
"float clamp,"
"bool fused_log_softmax=True) -> (Tensor, Tensor?)");
"float clamp) -> (Tensor, Tensor?)");
}
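
For reference, here is a minimal sketch of exercising the simplified schema through the dispatcher once a build containing this change is installed. The sizes, dtypes, and the (batch, max source length, max target length + 1, num targets) layout are illustrative assumptions, not values taken from this diff:

import torch
import torchaudio  # assumes the compiled extension registering torchaudio::rnnt_loss is available

logits = torch.rand(2, 4, 3, 5, dtype=torch.float32)      # (B, T, U + 1, D)
targets = torch.randint(1, 5, (2, 2), dtype=torch.int32)  # (B, U)
logit_lengths = torch.tensor([4, 4], dtype=torch.int32)
target_lengths = torch.tensor([2, 2], dtype=torch.int32)

# Positional arguments follow the new schema: no fused_log_softmax flag.
costs, gradients = torch.ops.torchaudio.rnnt_loss(
    logits, targets, logit_lengths, target_lengths, 0, -1.0)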
3 changes: 1 addition & 2 deletions torchaudio/csrc/rnnt/compute.h
@@ -8,5 +8,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax);
double clamp);
4 changes: 1 addition & 3 deletions torchaudio/csrc/rnnt/cpu/compute.cpp
@@ -12,8 +12,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true) {
double clamp) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
@@ -81,7 +80,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_softmax;

CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
options.device_ = CPU;
4 changes: 1 addition & 3 deletions torchaudio/csrc/rnnt/gpu/compute.cu
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true) {
double clamp) {
TORCH_CHECK(
logits.device().type() == targets.device().type(),
"logits and targets must be on the same device");
@@ -82,7 +81,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
options.numTargets_ = logits.size(3);
options.blank_ = blank;
options.clamp_ = clamp;
options.fusedLogSmax_ = fused_log_softmax;

CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
options.stream_ = at::cuda::getCurrentCUDAStream();
19 changes: 3 additions & 16 deletions torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
const int* tgtLengths,
const CAST_DTYPE* denominators,
CAST_DTYPE* logProbs,
int H = 1,
bool fusedLogSmax = true) {
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int& D = numTargets;
@@ -49,22 +48,12 @@
logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];

if (!fusedLogSmax) {
logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
CAST_DTYPE(logits[idx * D + blank]);
}

if (u < U - 1) {
// emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
// u).
int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
CAST_DTYPE(logits[idx * D + target]) - denominators[idx];

if (!fusedLogSmax) {
logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
CAST_DTYPE(logits[idx * D + target]);
}
}
}

@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients,
int H = 1,
bool fusedLogSmax = true) {
int H = 1) {
const int bTgt = blockIdx.z; // 0 <= b < B
const int t = blockIdx.x * blockDim.x + threadIdx.x;
const int u = blockIdx.y;
@@ -353,8 +341,7 @@
alphas,
betas,
gradients,
H,
fusedLogSmax);
H);
}

// This is a __global__ wrapper around ComputeAlphas
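As a reading aid for the kernel changes above: ComputeLogProbs now always produces the fused form, with denominators holding the per-(b, t, u) log-sum-exp over the D logits (computed by LogSumExp2D before this kernel runs). Writing z for the raw logits and y_{u+1} for the next target label, the two stored quantities are

\log p_{\mathrm{skip}}(b, t, u) = z_{b,t,u,\mathrm{blank}} - \log \sum_{d=0}^{D-1} \exp\left(z_{b,t,u,d}\right)

\log p_{\mathrm{emit}}(b, t, u) = z_{b,t,u,y_{u+1}} - \log \sum_{d=0}^{D-1} \exp\left(z_{b,t,u,d}\right)

which is exactly the pair written to the LOG_PROBS_SKIP_IDX and LOG_PROBS_EMIT_IDX slots.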
8 changes: 2 additions & 6 deletions torchaudio/csrc/rnnt/gpu/gpu_transducer.h
@@ -102,8 +102,6 @@ status_t Compute(
const int& blank = options.blank_;
const CAST_DTYPE clamp = options.clamp_;

const bool& fusedLogSmax = options.fusedLogSmax_;

{ // compute denominators.
status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
/*stream=*/stream,
@@ -134,8 +132,7 @@ status_t Compute(
/*tgtLengths=*/tgtLengths,
/*denominators=*/workspace.GetPointerToDenominators(),
/*log_probs=*/workspace.GetPointerToLogProbs(),
H,
fusedLogSmax);
H);

if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_LOG_PROBS_FAILED;
@@ -200,8 +197,7 @@ status_t Compute(
/*alphas=*/workspace.GetPointerToAlphas(),
/*betas=*/workspace.GetPointerToBetas(),
/*gradients=*/gradients,
H,
fusedLogSmax);
H);
if (cudaGetLastError() != cudaSuccess) {
return COMPUTE_GRADIENTS_FAILED;
}
53 changes: 15 additions & 38 deletions torchaudio/csrc/rnnt/gpu/kernels.h
@@ -26,8 +26,7 @@ HOST_AND_DEVICE void ComputeGradientsElement(
const CAST_DTYPE* alphas,
const CAST_DTYPE* betas,
DTYPE* gradients,
int H = 1,
bool fusedLogSmax = true) {
int H = 1) {
const int& maxT = maxSrcLen;
const int& maxU = maxTgtLen;
const int& D = numTargets;
@@ -79,44 +78,22 @@
int b_t_u_d = idx_b_t_u * D + d;
CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;

if (fusedLogSmax) {
if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
} else if (t < T - 1 && d == blank) {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
if (idx_b_tp1_u != -1) {
gradients[b_t_u_d] =
gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
}
} else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
if (idx_b_t_up1 != -1) {
gradients[b_t_u_d] =
gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
}
} else {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
} else if (t < T - 1 && d == blank) {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
if (idx_b_tp1_u != -1) {
gradients[b_t_u_d] =
gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
}
} else { // Non fused log softmax case
CAST_DTYPE g = cost + CAST_DTYPE(logits[b_t_u_d]);
if (d == blank && t == T - 1 && u == U - 1) {
gradients[b_t_u_d] = g + alphas[idx_b_t_u];
} else if (t < T - 1 && d == blank) {
if (idx_b_tp1_u != -1) {
gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u];
} else {
gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
}
} else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
if (idx_b_t_up1 != -1) {
gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1];
} else {
gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
}
} else {
gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
} else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
if (idx_b_t_up1 != -1) {
gradients[b_t_u_d] =
gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
}
gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]);
} else {
gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
}

if (clamp > 0) {
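To summarize the branch that survives in ComputeGradientsElement (a restatement of the retained code, not a new derivation): with g = z_{b,t,u,d} + c as used in the surrounding function (c is set up just above the shown hunk) and \beta the backward variables, the stored gradient is

\frac{\partial \ell}{\partial z_{b,t,u,d}} =
\begin{cases}
e^{g + \beta(t,u)} - e^{g} & d = \mathrm{blank},\ t = T-1,\ u = U-1 \\
e^{g + \beta(t,u)} - e^{g + \beta(t+1,u)} & d = \mathrm{blank},\ t < T-1 \\
e^{g + \beta(t,u)} - e^{g + \beta(t,u+1)} & d = y_{u+1},\ u < U-1 \\
e^{g + \beta(t,u)} & \text{otherwise,}
\end{cases}

where the subtracted term is dropped when the (t+1, u) or (t, u+1) index is out of range (the idx == -1 checks), and clamping is applied afterwards when clamp > 0.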
9 changes: 1 addition & 8 deletions torchaudio/csrc/rnnt/options.h
@@ -42,12 +42,6 @@ typedef struct Options {
// num_targets = D.
int numTargets_;

// if set to true, inputs are logits and gradients are
// fused with logsoftmax gradients.
// if set to false, log_softmax is computed outside of loss
// True by default
bool fusedLogSmax_;

Options()
: device_(UNDEFINED),
numThreads_(0),
@@ -58,8 +52,7 @@
nHypos_(1),
maxSrcLen_(0),
maxTgtLen_(0),
numTargets_(0),
fusedLogSmax_(true) {}
numTargets_(0) {}

int BU() const {
return batchSize_ * maxTgtLen_ * nHypos_;
10 changes: 0 additions & 10 deletions torchaudio/prototype/rnnt_loss.py
@@ -14,7 +14,6 @@ def rnnt_loss(
target_lengths: Tensor,
blank: int = -1,
clamp: float = -1,
fused_log_softmax: bool = True,
reduction: str = "mean",
):
"""Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
@@ -31,7 +30,6 @@
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)

@@ -42,9 +40,6 @@
if reduction not in ['none', 'mean', 'sum']:
raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")

if not fused_log_softmax:
logits = torch.nn.functional.log_softmax(logits, dim=-1)

if blank < 0: # reinterpret blank index if blank < 0.
blank = logits.shape[-1] + blank

@@ -55,7 +50,6 @@
target_lengths=target_lengths,
blank=blank,
clamp=clamp,
fused_log_softmax=fused_log_softmax
)

if reduction == 'mean':
Expand All @@ -77,7 +71,6 @@ class RNNTLoss(torch.nn.Module):
Args:
blank (int, opt): blank label (Default: ``-1``)
clamp (float): clamp for gradients (Default: ``-1``)
fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
"""
@@ -86,13 +79,11 @@ def __init__(
self,
blank: int = -1,
clamp: float = -1.,
fused_log_softmax: bool = True,
reduction: str = "mean",
):
super().__init__()
self.blank = blank
self.clamp = clamp
self.fused_log_softmax = fused_log_softmax
self.reduction = reduction

def forward(
@@ -120,6 +111,5 @@ def forward(
target_lengths,
self.blank,
self.clamp,
self.fused_log_softmax,
self.reduction
)
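
End to end, the prototype API after this change reduces to the sketch below. The import path mirrors the edited file; the (batch, max seq length, max target length + 1, num classes) logit layout and int32 targets/lengths are assumptions about the expected inputs rather than something shown in this diff:

import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

logits = torch.rand(2, 10, 5, 20, requires_grad=True)      # (B, T, U + 1, D)
targets = torch.randint(1, 20, (2, 4), dtype=torch.int32)  # (B, U)
logit_lengths = torch.full((2,), 10, dtype=torch.int32)
target_lengths = torch.full((2,), 4, dtype=torch.int32)

# No fused_log_softmax argument anymore: raw logits go in and the loss
# applies the log-softmax normalization internally.
loss_fn = RNNTLoss(blank=0, clamp=-1.0, reduction="mean")
loss = loss_fn(logits, targets, logit_lengths, target_lengths)
loss.backward()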