@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
     const int* tgtLengths,
     const CAST_DTYPE* denominators,
     CAST_DTYPE* logProbs,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -49,22 +48,12 @@ __global__ void ComputeLogProbs(
   logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
       CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
 
-  if (!fusedLogSmax) {
-    logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
-        CAST_DTYPE(logits[idx * D + blank]);
-  }
-
   if (u < U - 1) {
     // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
     // u).
     int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
     logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
         CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
-
-    if (!fusedLogSmax) {
-      logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
-          CAST_DTYPE(logits[idx * D + target]);
-    }
   }
 }
 
@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int bTgt = blockIdx.z; // 0 <= b < B
   const int t = blockIdx.x * blockDim.x + threadIdx.x;
   const int u = blockIdx.y;
@@ -353,8 +341,7 @@ __global__ void ComputeGradients(
       alphas,
       betas,
       gradients,
-      H,
-      fusedLogSmax);
+      H);
 }
 
 // This is a __global__ wrapper around ComputeAlphas
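
Note on the surviving code path: with the !fusedLogSmax branches removed, ComputeLogProbs always stores logits minus denominators. That subtraction only yields a normalized log-probability if denominators[idx] holds the log-sum-exp of the D logits for that lattice cell, i.e. the log-softmax normalizer is computed separately and "fused" into this read; the deleted branches instead stored raw, unnormalized logits. The host-side C++ sketch below is a minimal illustration of that identity, assuming denominators is the log-sum-exp; LogSumExp and the sample values are hypothetical, not part of this file.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical helper: log(sum(exp(x))), using the usual max-subtraction
// trick for numerical stability.
float LogSumExp(const std::vector<float>& logits) {
  float m = *std::max_element(logits.begin(), logits.end());
  float sum = 0.0f;
  for (float x : logits) {
    sum += std::exp(x - m);
  }
  return m + std::log(sum);
}

int main() {
  // One (b, t, u) cell with D = 4 classes; class 0 plays the blank.
  std::vector<float> logits = {1.5f, -0.3f, 2.0f, 0.1f};
  float denominator = LogSumExp(logits);

  // Same form as the kernel's skip path:
  //   logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] = logits[blank] - denominator
  float logProbSkip = logits[0] - denominator;
  std::printf("log P(blank) = %f\n", logProbSkip);

  // Sanity check: exponentiating the normalized log-probs sums to 1,
  // which is exactly what the removed !fusedLogSmax path did not guarantee.
  float total = 0.0f;
  for (float x : logits) {
    total += std::exp(x - denominator);
  }
  std::printf("sum of probs = %f\n", total); // ~1.0
  return 0;
}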
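The (idx << 1) addressing in the first hunks also deserves a note: logProbs stores two values per lattice cell, a blank/skip score and an emit score, interleaved in one flat array. Below is a minimal sketch of that layout, assuming LOG_PROBS_SKIP_IDX == 0 and LOG_PROBS_EMIT_IDX == 1 and a row-major [B, maxT, maxU] cell order; the constants' actual values and the cell ordering are not visible in this diff.

#include <cstdio>

// Assumed values; the real constants are defined elsewhere in this file.
constexpr int LOG_PROBS_SKIP_IDX = 0;
constexpr int LOG_PROBS_EMIT_IDX = 1;

// Flat index of lattice cell (b, t, u) in an assumed row-major
// [B, maxT, maxU] layout.
int CellIndex(int b, int t, int u, int maxT, int maxU) {
  return (b * maxT + t) * maxU + u;
}

int main() {
  const int maxT = 4;
  const int maxU = 3;
  int idx = CellIndex(/*b=*/1, /*t=*/2, /*u=*/1, maxT, maxU);

  // (idx << 1) doubles the cell index because each cell owns two adjacent
  // slots: skip (stay at u, consume a frame) and emit (advance to u + 1
  // with the target label).
  std::printf(
      "cell %d -> skip slot %d, emit slot %d\n",
      idx,
      (idx << 1) + LOG_PROBS_SKIP_IDX,
      (idx << 1) + LOG_PROBS_EMIT_IDX);
  return 0;
}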