@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
     const int* tgtLengths,
     const CAST_DTYPE* denominators,
     CAST_DTYPE* logProbs,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -49,22 +48,12 @@ __global__ void ComputeLogProbs(
   logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
       CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];
 
-  if (!fusedLogSmax) {
-    logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
-        CAST_DTYPE(logits[idx * D + blank]);
-  }
-
   if (u < U - 1) {
     // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
     // u).
     int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
     logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
         CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
-
-    if (!fusedLogSmax) {
-      logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
-          CAST_DTYPE(logits[idx * D + target]);
-    }
   }
 }
 
@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int bTgt = blockIdx.z; // 0 <= b < B
   const int t = blockIdx.x * blockDim.x + threadIdx.x;
   const int u = blockIdx.y;
@@ -353,8 +341,7 @@ __global__ void ComputeGradients(
       alphas,
       betas,
       gradients,
-      H,
-      fusedLogSmax);
+      H);
 }
 
 // This is a __global__ wrapper around ComputeAlphas