Commit 77b3082
Caroline Chen committed

    remove fused_log_softmax option

1 parent 27690ec · commit 77b3082

File tree: 13 files changed, +32 −135 lines
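Every file loses the same thing: the `fused_log_softmax` flag is removed from the RNN-T loss at every layer, from the Python tests down to the CUDA kernels, leaving the fused log-softmax path as the only behavior. Below is a minimal sketch of the resulting Python call site, mirroring the usage in `test/torchaudio_unittest/rnnt/utils.py` further down; the import path, shapes, and values are illustrative assumptions, not part of this diff (the shapes follow the `B1_T10_U3_D4` test helper).

    import torch
    from torchaudio.transforms import RNNTLoss  # assumed import path

    # Illustrative shapes: batch B, source frames T, target length U - 1, vocab D.
    B, T, U, D = 1, 10, 3, 4
    logits = torch.randn(B, T, U, D, requires_grad=True)
    targets = torch.randint(1, D, (B, U - 1), dtype=torch.int32)
    logit_lengths = torch.full((B,), T, dtype=torch.int32)
    target_lengths = torch.full((B,), U - 1, dtype=torch.int32)

    # After this commit there is no fused_log_softmax argument to pass;
    # the loss always applies log-softmax to the raw logits internally.
    loss_fn = RNNTLoss(blank=0, reduction="none")
    costs = loss_fn(
        logits=logits,
        targets=targets,
        logit_lengths=logit_lengths,
        target_lengths=target_lengths,
    )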

test/torchaudio_unittest/rnnt/autograd_impl.py (0 additions & 1 deletion)

@@ -75,7 +75,6 @@ def test_rnnt_loss_gradcheck(self, data_func):
             data["target_lengths"],  # target_lengths
             data["blank"],  # blank
             -1,  # clamp
-            True,  # fused_log_softmax
         )

         self.assert_grad(rnnt_loss, inputs, enable_all_grad=False)

test/torchaudio_unittest/rnnt/rnnt_loss_impl.py (0 additions & 15 deletions)

@@ -96,18 +96,3 @@ def test_costs_and_gradients_random_data_with_numpy_fp32(self):
             data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
         )

-    def test_rnnt_nonfused_log_softmax(self):
-        for random in [False, True]:
-            data = get_B1_T10_U3_D4_data(
-                random=random,
-            )
-            data = numpy_to_torch(
-                data=data, device=self.device, requires_grad=True
-            )
-            data["fused_log_softmax"] = False
-            ref_costs, ref_gradients = compute_with_numpy_transducer(
-                data=data
-            )
-            self._test_costs_and_gradients(
-                data=data, ref_costs=ref_costs, ref_gradients=ref_gradients
-            )

test/torchaudio_unittest/rnnt/utils.py (0 additions & 1 deletion)

@@ -29,7 +29,6 @@ def compute_with_numpy_transducer(data):
 def compute_with_pytorch_transducer(data):
     costs = RNNTLoss(
         blank=data["blank"],
-        fused_log_softmax=data.get("fused_log_softmax", True),
         reduction="none",
     )(
         logits=data["logits"],

torchaudio/csrc/rnnt/autograd.cpp (5 additions & 19 deletions)

@@ -13,17 +13,10 @@ class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
       const torch::Tensor& logit_lengths,
       const torch::Tensor& target_lengths,
       int64_t blank,
-      double clamp,
-      bool fused_log_softmax = true) {
+      double clamp) {
     torch::Tensor undef;
-    auto result = rnnt_loss(
-        logits,
-        targets,
-        logit_lengths,
-        target_lengths,
-        blank,
-        clamp,
-        fused_log_softmax);
+    auto result =
+        rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp);
     auto costs = std::get<0>(result);
     auto grads = std::get<1>(result).value_or(undef);
     ctx->save_for_backward({grads});
@@ -48,17 +41,10 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   at::AutoDispatchBelowADInplaceOrView guard;
   auto results = RNNTLossFunction::apply(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+      logits, targets, logit_lengths, target_lengths, blank, clamp);
   return std::make_tuple(results[0], results[1]);
 }
torchaudio/csrc/rnnt/compute.cpp (3 additions & 13 deletions)

@@ -7,20 +7,11 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                        .typed<decltype(rnnt_loss)>();
-  return op.call(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
+  return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp);
 }

 TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
@@ -29,6 +20,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
       "Tensor logit_lengths,"
       "Tensor target_lengths,"
       "int blank,"
-      "float clamp,"
-      "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
+      "float clamp) -> (Tensor, Tensor?)");
 }
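With the schema trimmed, the operator can also be exercised directly through the dispatcher. A sketch under the assumption that importing torchaudio loads the extension that registers `torchaudio::rnnt_loss`; the tensor contents are placeholders, with `0` as the blank index and `-1.0` as the clamp value, matching the defaults used in the tests above:

    import torch
    import torchaudio  # assumed to load the extension registering torchaudio::rnnt_loss

    logits = torch.randn(1, 10, 3, 4)  # (B, T, U, D), raw unnormalized logits
    targets = torch.randint(1, 4, (1, 2), dtype=torch.int32)
    logit_lengths = torch.tensor([10], dtype=torch.int32)
    target_lengths = torch.tensor([2], dtype=torch.int32)

    # New schema: rnnt_loss(Tensor logits, Tensor targets, Tensor logit_lengths,
    #                       Tensor target_lengths, int blank, float clamp)
    #             -> (Tensor, Tensor?)
    costs, gradients = torch.ops.torchaudio.rnnt_loss(
        logits, targets, logit_lengths, target_lengths, 0, -1.0
    )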

torchaudio/csrc/rnnt/compute.h (1 addition & 2 deletions)

@@ -8,5 +8,4 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax);
+    double clamp);

torchaudio/csrc/rnnt/cpu/compute.cpp (1 addition & 3 deletions)

@@ -12,8 +12,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -81,7 +80,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;

   CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
   options.device_ = CPU;

torchaudio/csrc/rnnt/gpu/compute.cu (1 addition & 3 deletions)

@@ -13,8 +13,7 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
     const torch::Tensor& logit_lengths,
     const torch::Tensor& target_lengths,
     int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
+    double clamp) {
   TORCH_CHECK(
       logits.device().type() == targets.device().type(),
       "logits and targets must be on the same device");
@@ -82,7 +81,6 @@ std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
   options.numTargets_ = logits.size(3);
   options.blank_ = blank;
   options.clamp_ = clamp;
-  options.fusedLogSmax_ = fused_log_softmax;

   CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
   options.stream_ = at::cuda::getCurrentCUDAStream();

torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh (3 additions & 16 deletions)

@@ -23,8 +23,7 @@ __global__ void ComputeLogProbs(
     const int* tgtLengths,
     const CAST_DTYPE* denominators,
     CAST_DTYPE* logProbs,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int& maxT = maxSrcLen;
   const int& maxU = maxTgtLen;
   const int& D = numTargets;
@@ -49,22 +48,12 @@ __global__ void ComputeLogProbs(
     logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
         CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];

-    if (!fusedLogSmax) {
-      logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
-          CAST_DTYPE(logits[idx * D + blank]);
-    }
-
    if (u < U - 1) {
       // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
       // u).
       int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
       logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
           CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
-
-      if (!fusedLogSmax) {
-        logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
-            CAST_DTYPE(logits[idx * D + target]);
-      }
     }
   }

@@ -330,8 +319,7 @@ __global__ void ComputeGradients(
     const CAST_DTYPE* alphas,
     const CAST_DTYPE* betas,
     DTYPE* gradients,
-    int H = 1,
-    bool fusedLogSmax = true) {
+    int H = 1) {
   const int bTgt = blockIdx.z; // 0 <= b < B
   const int t = blockIdx.x * blockDim.x + threadIdx.x;
   const int u = blockIdx.y;
@@ -353,8 +341,7 @@ __global__ void ComputeGradients(
       alphas,
       betas,
       gradients,
-      H,
-      fusedLogSmax);
+      H);
 }

 // This is a __global__ wrapper around ComputeAlphas
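The deleted branches were the only readers of `fusedLogSmax`; what survives is the fused path, in which `denominators` holds the log-sum-exp of the logits over the vocabulary dimension (computed by `LogSumExp2D` in gpu_transducer.h below), so the subtraction applies log-softmax on the fly. In the notation of the kernel's own comment:

    log_prob(b, t, u).skip() = logits(b, t, u, blank)  - log Σ_d exp(logits(b, t, u, d))
    log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - log Σ_d exp(logits(b, t, u, d))

In other words, the kernels now always expect raw logits and normalize them internally; the old non-fused mode, which skipped the subtraction for inputs that were already log-probabilities, is gone.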

torchaudio/csrc/rnnt/gpu/gpu_transducer.h (2 additions & 6 deletions)

@@ -102,8 +102,6 @@ status_t Compute(
   const int& blank = options.blank_;
   const CAST_DTYPE clamp = options.clamp_;

-  const bool& fusedLogSmax = options.fusedLogSmax_;
-
   { // compute denominators.
     status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
         /*stream=*/stream,
@@ -134,8 +132,7 @@ status_t Compute(
         /*tgtLengths=*/tgtLengths,
         /*denominators=*/workspace.GetPointerToDenominators(),
         /*log_probs=*/workspace.GetPointerToLogProbs(),
-        H,
-        fusedLogSmax);
+        H);

     if (cudaGetLastError() != cudaSuccess) {
       return COMPUTE_LOG_PROBS_FAILED;
@@ -200,8 +197,7 @@ status_t Compute(
         /*alphas=*/workspace.GetPointerToAlphas(),
         /*betas=*/workspace.GetPointerToBetas(),
         /*gradients=*/gradients,
-        H,
-        fusedLogSmax);
+        H);
     if (cudaGetLastError() != cudaSuccess) {
       return COMPUTE_GRADIENTS_FAILED;
     }
