diff --git a/test/torchaudio_unittest/rnnt/autograd_impl.py b/test/torchaudio_unittest/rnnt/autograd_impl.py
index 8d0ebf9cb1..41e33acf8d 100644
--- a/test/torchaudio_unittest/rnnt/autograd_impl.py
+++ b/test/torchaudio_unittest/rnnt/autograd_impl.py
@@ -71,7 +71,6 @@ def test_rnnt_loss_gradcheck(self, data_func):
             data["target_lengths"],  # target_lengths
             data["blank"],  # blank
             -1,  # clamp
-            True,  # fused_log_softmax
         )
 
         self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)
diff --git a/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py b/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
index a9ca72951c..7d7a53a113 100644
--- a/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
+++ b/test/torchaudio_unittest/rnnt/rnnt_loss_impl.py
@@ -5,7 +5,6 @@
     compute_with_numpy_transducer,
     compute_with_pytorch_transducer,
     get_basic_data,
-    get_B1_T10_U3_D4_data,
     get_B1_T2_U3_D5_data,
     get_B2_T4_U3_D3_data,
     get_random_data,
@@ -80,18 +79,3 @@ def test_costs_and_gradients_random_data_with_numpy_fp32(self):
         self._test_costs_and_gradients(
             data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
         )
-
-    def test_rnnt_nonfused_log_softmax(self):
-        for random in [False, True]:
-            data = get_B1_T10_U3_D4_data(
-                random=random,
-                dtype=torch.float32,
-                device=self.device,
-            )
-            data["fused_log_softmax"] = False
-            ref_costs, ref_gradients = compute_with_numpy_transducer(
-                data=data
-            )
-            self._test_costs_and_gradients(
-                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
-            )
diff --git a/test/torchaudio_unittest/rnnt/utils.py b/test/torchaudio_unittest/rnnt/utils.py
index ec2933bf1e..1bec65552b 100644
--- a/test/torchaudio_unittest/rnnt/utils.py
+++ b/test/torchaudio_unittest/rnnt/utils.py
@@ -26,7 +26,6 @@ def compute_with_numpy_transducer(data):
 def compute_with_pytorch_transducer(data):
     costs = RNNTLoss(
         blank=data["blank"],
-        fused_log_softmax=data.get("fused_log_softmax", True),
         reduction="none",
     )(
         logits=data["logits"],
diff --git a/torchaudio/csrc/rnnt/autograd.cpp b/torchaudio/csrc/rnnt/autograd.cpp
index 59f3e9ebeb..4fa6ba7996 100644
--- a/torchaudio/csrc/rnnt/autograd.cpp
+++ b/torchaudio/csrc/rnnt/autograd.cpp
@@ -13,17 +13,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       const torch::Tensor& logit_lengths,
       const torch::Tensor& target_lengths,
       int64_t blank,
-      double clamp,
-      bool fused_log_softmax = true) {
+      double clamp) {
     torch::Tensor undef;
-    auto result = rnnt_loss(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax);
+    auto result =
+        rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp);
     auto costs = std::get<0>(result);
     auto grads = std::get<1>(result).value_or(undef);
     ctx->save_for_backward({grads});
@@ -48,17 +41,10 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+      logits, targets, logit_lengths, target_lengths, blank, clamp);
   return std::make_tuple(results[0], results[1]);
 }
 
diff --git a/torchaudio/csrc/rnnt/compute.cpp b/torchaudio/csrc/rnnt/compute.cpp
index f21413e432..9c3bf84a05 100644
--- a/torchaudio/csrc/rnnt/compute.cpp
+++ b/torchaudio/csrc/rnnt/compute.cpp
@@ -7,19 +7,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                        .typed<decltype(rnnt_loss)>();
-  return op.call(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+  return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp);
 }
 
 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
@@ -29,6 +21,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "Tensor logit_lengths,"
      "Tensor target_lengths,"
       "int blank,"
-      "float clamp,"
-      "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
+      "float clamp) -> (Tensor, Tensor?)");
 }
diff --git a/torchaudio/csrc/rnnt/compute.h b/torchaudio/csrc/rnnt/compute.h
index 0508cec80e..eea16a5fee 100644
--- a/torchaudio/csrc/rnnt/compute.h
+++ b/torchaudio/csrc/rnnt/compute.h
@@ -8,5 +8,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax);
+    double clamp);
diff --git a/torchaudio/csrc/rnnt/cpu/compute.cpp b/torchaudio/csrc/rnnt/cpu/compute.cpp
index 5d6df5aeab..c1557db0c0 100644
--- a/torchaudio/csrc/rnnt/cpu/compute.cpp
+++ b/torchaudio/csrc/rnnt/cpu/compute.cpp
@@ -12,8 +12,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -81,7 +80,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;
 
   CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
   options.device_ = CPU;
diff --git a/torchaudio/csrc/rnnt/gpu/compute.cu b/torchaudio/csrc/rnnt/gpu/compute.cu
index 0e7badabe4..f463b5853c 100644
--- a/torchaudio/csrc/rnnt/gpu/compute.cu
+++ b/torchaudio/csrc/rnnt/gpu/compute.cu
@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -82,7 +81,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;
 
   CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
   options.stream_ = at::cuda::getCurrentCUDAStream();
diff --git a/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
index 4ba04b68fc..90b5ebfd7e 100644
--- a/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
+++ b/torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh
@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
     const int* tgtLengths,
     const CAST_DTYPE* denominators,
     CAST_DTYPE* logProbs,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -49,22 +48,12 @@ __global__ void ComputeLogProbs(
   logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
       CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
 
-  if (!fusedLogSmax) {
-    logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
-        CAST_DTYPE(logits[idx * D + blank]);
-  }
-
   if (u < U - 1) {
     // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
     // u).
     int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
     logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
         CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
-
-    if (!fusedLogSmax) {
-      logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
-          CAST_DTYPE(logits[idx * D + target]);
-    }
   }
 }
 
@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int bTgt = blockIdx.z; // 0 <= b < B
   const int t = blockIdx.x * blockDim.x + threadIdx.x;
   const int u = blockIdx.y;
@@ -353,8 +341,7 @@ __global__ void ComputeGradients(
       alphas,
       betas,
      gradients,
-      H,
-      fusedLogSmax);
+      H);
 }
 
 // This is a __global__ wrapper around ComputeAlphas
diff --git a/torchaudio/csrc/rnnt/gpu/gpu_transducer.h b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h
index 72759b39f4..54d16b9f21 100644
--- a/torchaudio/csrc/rnnt/gpu/gpu_transducer.h
+++ b/torchaudio/csrc/rnnt/gpu/gpu_transducer.h
@@ -102,8 +102,6 @@ status_t Compute(
   const int& blank = options.blank_;
   const CAST_DTYPE clamp = options.clamp_;
 
-  const bool& fusedLogSmax = options.fusedLogSmax_;
-
   { // compute denominators.
     status_t status = LogSumExp2D(
         /*stream=*/stream,
@@ -134,8 +132,7 @@ status_t Compute(
         /*tgtLengths=*/tgtLengths,
         /*denominators=*/workspace.GetPointerToDenominators(),
         /*log_probs=*/workspace.GetPointerToLogProbs(),
-        H,
-        fusedLogSmax);
+        H);
     if (cudaGetLastError() != cudaSuccess) {
       return COMPUTE_LOG_PROBS_FAILED;
     }
@@ -200,8 +197,7 @@ status_t Compute(
         /*alphas=*/workspace.GetPointerToAlphas(),
         /*betas=*/workspace.GetPointerToBetas(),
         /*gradients=*/gradients,
-        H,
-        fusedLogSmax);
+        H);
     if (cudaGetLastError() != cudaSuccess) {
       return COMPUTE_GRADIENTS_FAILED;
     }
diff --git a/torchaudio/csrc/rnnt/gpu/kernels.h b/torchaudio/csrc/rnnt/gpu/kernels.h
index db8bb5092b..b0627c2181 100644
--- a/torchaudio/csrc/rnnt/gpu/kernels.h
+++ b/torchaudio/csrc/rnnt/gpu/kernels.h
@@ -26,8 +26,7 @@ HOST_AND_DEVICE void ComputeGradientsElement(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -79,44 +78,22 @@ HOST_AND_DEVICE void ComputeGradientsElement(
       int b_t_u_d = idx_b_t_u * D + d;
       CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;
 
-      if (fusedLogSmax) {
-        if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
-        } else if (t < T - 1 && d == blank) {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-          if (idx_b_tp1_u != -1) {
-            gradients[b_t_u_d] =
-                gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
-          }
-        } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
-          if (idx_b_t_up1 != -1) {
-            gradients[b_t_u_d] =
-                gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
-          }
-        } else {
-          gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+      if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
+      } else if (t < T - 1 && d == blank) {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+        if (idx_b_tp1_u != -1) {
+          gradients[b_t_u_d] =
+              gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
         }
-      } else { // Non fused log softmax case
-        CAST_DTYPE g = cost + CAST_DTYPE(logits[b_t_u_d]);
-        if (d == blank && t == T - 1 && u == U - 1) {
-          gradients[b_t_u_d] = g + alphas[idx_b_t_u];
-        } else if (t < T - 1 && d == blank) {
-          if (idx_b_tp1_u != -1) {
-            gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u];
-          } else {
-            gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-          }
-        } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
-          if (idx_b_t_up1 != -1) {
-            gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1];
-          } else {
-            gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
-          }
-        } else {
-          gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
+      } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
+        if (idx_b_t_up1 != -1) {
+          gradients[b_t_u_d] =
+              gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
         }
-        gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]);
+      } else {
+        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
       }
 
       if (clamp > 0) {
diff --git a/torchaudio/csrc/rnnt/options.h b/torchaudio/csrc/rnnt/options.h
index f70a3c8c07..79109950fd 100644
--- a/torchaudio/csrc/rnnt/options.h
+++ b/torchaudio/csrc/rnnt/options.h
@@ -42,12 +42,6 @@ typedef struct Options {
   // num_targets = D.
   int numTargets_;
 
-  // if set to true, inputs are logits and gradients are
-  // fused with logsoftmax gradients.
-  // if set to false, log_softmax is computed outside of loss
-  // True by default
-  bool fusedLogSmax_;
-
   Options()
       : device_(UNDEFINED),
         numThreads_(0),
@@ -58,8 +52,7 @@ typedef struct Options {
         nHypos_(1),
         maxSrcLen_(0),
         maxTgtLen_(0),
-        numTargets_(0),
-        fusedLogSmax_(true) {}
+        numTargets_(0) {}
 
   int BU() const {
     return batchSize_ * maxTgtLen_ * nHypos_;
diff --git a/torchaudio/prototype/rnnt_loss.py b/torchaudio/prototype/rnnt_loss.py
index ffaee638c5..4019ffe3a4 100644
--- a/torchaudio/prototype/rnnt_loss.py
+++ b/torchaudio/prototype/rnnt_loss.py
@@ -14,7 +14,6 @@ def rnnt_loss(
     target_lengths: Tensor,
     blank: int = -1,
     clamp: float = -1,
-    fused_log_softmax: bool = True,
     reduction: str = "mean",
 ):
     """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
@@ -31,7 +30,6 @@ def rnnt_loss(
         target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
 
@@ -42,9 +40,6 @@ def rnnt_loss(
     if reduction not in ['none', 'mean', 'sum']:
         raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
 
-    if not fused_log_softmax:
-        logits = torch.nn.functional.log_softmax(logits, dim=-1)
-
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
 
@@ -55,7 +50,6 @@ def rnnt_loss(
         target_lengths=target_lengths,
         blank=blank,
         clamp=clamp,
-        fused_log_softmax=fused_log_softmax
     )
 
     if reduction == 'mean':
@@ -77,7 +71,6 @@ class RNNTLoss(torch.nn.Module):
     Args:
         blank (int, opt): blank label (Default: ``-1``)
         clamp (float): clamp for gradients (Default: ``-1``)
-        fused_log_softmax (bool): set to False if calling log_softmax outside loss (Default: ``True``)
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
     """
@@ -86,13 +79,11 @@ def __init__(
         self,
         blank: int = -1,
         clamp: float = -1.,
-        fused_log_softmax: bool = True,
         reduction: str = "mean",
     ):
         super().__init__()
         self.blank = blank
         self.clamp = clamp
-        self.fused_log_softmax = fused_log_softmax
         self.reduction = reduction
 
     def forward(
@@ -120,6 +111,5 @@ def forward(
         target_lengths,
         self.blank,
         self.clamp,
-        self.fused_log_softmax,
         self.reduction
     )
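
After this change the prototype Python API exposes only `blank`, `clamp`, and `reduction`, and the loss always applies the (fused) log-softmax internally, so callers pass raw logits. Below is a minimal usage sketch of the resulting interface; the keyword names mirror the test helper in `test/torchaudio_unittest/rnnt/utils.py`, while the sizes and the `(batch, max source length, max target length + 1, num classes)` logit layout are illustrative assumptions, not taken from this diff.

```python
import torch

from torchaudio.prototype.rnnt_loss import RNNTLoss

# Illustrative sizes: batch, source frames, target tokens, vocabulary size (incl. blank).
B, T, U, D = 2, 10, 3, 5

logits = torch.randn(B, T, U + 1, D, requires_grad=True)     # raw logits; no external log_softmax
targets = torch.randint(0, D - 1, (B, U), dtype=torch.int32)  # labels, avoiding the blank index
logit_lengths = torch.full((B,), T, dtype=torch.int32)
target_lengths = torch.full((B,), U, dtype=torch.int32)

# blank=-1 selects the last class as blank; clamp=-1 disables gradient clamping.
loss_fn = RNNTLoss(blank=-1, clamp=-1, reduction="mean")
loss = loss_fn(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
)
loss.backward()
```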