Commit cc592e0

Author: Sam Anklesaria
Use stable ABI for compute

1 parent 7f11d1d

File tree

5 files changed: +369 −274 lines

src/libtorchaudio/rnnt/compute.cpp

Lines changed: 3 additions & 25 deletions
@@ -1,34 +1,12 @@
-#include <libtorchaudio/rnnt/compute.h>
+#include <torch/csrc/stable/library.h>
 
-std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss(
-    torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
-    int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
-  static auto op = torch::Dispatcher::singleton()
-                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
-                       .typed<decltype(rnnt_loss)>();
-  return op.call(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
-}
-
-TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
-      "rnnt_loss(Tensor logits,"
+      "torchaudio::rnnt_loss(Tensor logits,"
       "Tensor targets,"
       "Tensor logit_lengths,"
       "Tensor target_lengths,"
       "int blank,"
       "float clamp,"
       "bool fused_log_softmax) -> (Tensor, Tensor?)");
-  m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
 }
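A note on the pattern above: under the stable ABI the fragment declares only the operator schema; no typed C++ function crosses the library boundary, and each backend registers a boxed kernel for that schema separately (the CPU one appears below). The following is a minimal sketch of the same pattern for a toy operator, using only macros and helpers that appear in this commit; the names myns::identity and boxed_identity are hypothetical, for illustration only.

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/stable/library.h>

// Boxed kernel: arguments arrive as StableIValues on the stack and outputs
// are written back into the leading stack slots.
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  // Take ownership of the incoming tensor handle, then hand it straight back.
  torch::aot_inductor::RAIIAtenTensorHandle t(to<AtenTensorHandle>(stack[0]));
  stack[0] = from(t.release());
}

STABLE_TORCH_LIBRARY_FRAGMENT(myns, m) {
  m.def("myns::identity(Tensor x) -> Tensor");  // schema only, no kernel
}

STABLE_TORCH_LIBRARY_IMPL(myns, CPU, m) {
  m.impl("myns::identity", &boxed_identity);  // boxed kernel per backend
}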

src/libtorchaudio/rnnt/compute.h

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/libtorchaudio/rnnt/cpu/compute.cpp

Lines changed: 179 additions & 115 deletions
@@ -1,148 +1,212 @@
 #include <libtorchaudio/rnnt/cpu/cpu_transducer.h>
-#include <torch/script.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/inductor/aoti_runtime/utils.h>
+#include <torch/csrc/stable/library.h>
 
 namespace torchaudio {
 namespace rnnt {
 namespace cpu {
 
+using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
+
 // Entry point into RNNT Loss
-std::tuple<torch::Tensor, std::optional<torch::Tensor>> compute(
-    torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
+std::tuple<RAIIATH, RAIIATH> compute(
+    const RAIIATH logits,
+    const RAIIATH targets,
+    const RAIIATH logit_lengths,
+    const RAIIATH target_lengths,
     int64_t blank,
     double clamp,
     bool fused_log_softmax = true) {
-  TORCH_CHECK(
-      logits.device().type() == targets.device().type(),
-      "logits and targets must be on the same device");
-  TORCH_CHECK(
-      logits.device().type() == logit_lengths.device().type(),
-      "logits and logit_lengths must be on the same device");
-  TORCH_CHECK(
-      logits.device().type() == target_lengths.device().type(),
-      "logits and target_lengths must be on the same device");
-
-  TORCH_CHECK(
-      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
-      "logits must be float32 or float16 (half) type");
-  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
-  TORCH_CHECK(
-      logit_lengths.dtype() == torch::kInt32,
-      "logit_lengths must be int32 type");
-  TORCH_CHECK(
-      target_lengths.dtype() == torch::kInt32,
-      "target_lengths must be int32 type");
-
-  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
-  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(
-      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
-  TORCH_CHECK(
-      target_lengths.is_contiguous(), "target_lengths must be contiguous");
-
-  TORCH_CHECK(
-      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
-  TORCH_CHECK(
-      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
-  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
-  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
-
-  TORCH_CHECK(
-      logit_lengths.size(0) == logits.size(0),
-      "batch dimension mismatch between logits and logit_lengths");
-  TORCH_CHECK(
-      target_lengths.size(0) == logits.size(0),
-      "batch dimension mismatch between logits and target_lengths");
-  TORCH_CHECK(
-      targets.size(0) == logits.size(0),
-      "batch dimension mismatch between logits and targets");
-
-  TORCH_CHECK(
-      blank >= 0 && blank < logits.size(-1),
-      "blank must be within [0, logits.shape[-1])");
-
-  TORCH_CHECK(
-      logits.size(1) == at::max(logit_lengths).item().toInt(),
-      "input length mismatch");
-  TORCH_CHECK(
-      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
-      "output length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(target_lengths).item().toInt(),
-      "target length mismatch");
+
+  int32_t logits_device;
+  aoti_torch_get_device_type(logits.get(), &logits_device);
+  int32_t targets_device;
+  aoti_torch_get_device_type(targets.get(), &targets_device);
+  int32_t logit_lengths_device;
+  aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device);
+  int32_t target_lengths_device;
+  aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device);
+
+  AOTI_TORCH_CHECK(logits_device == targets_device);
+  AOTI_TORCH_CHECK(logits_device == logit_lengths_device);
+  AOTI_TORCH_CHECK(logits_device == target_lengths_device);
+
+  int32_t logits_dtype;
+  aoti_torch_get_dtype(logits.get(), &logits_dtype);
+  AOTI_TORCH_CHECK(logits_dtype == aoti_torch_dtype_float32() ||
+                   logits_dtype == aoti_torch_dtype_float16());
+
+  int32_t targets_dtype;
+  aoti_torch_get_dtype(targets.get(), &targets_dtype);
+  AOTI_TORCH_CHECK(targets_dtype == aoti_torch_dtype_int32() ||
+                   logits_dtype == aoti_torch_dtype_float16());
+
+  int32_t logit_lengths_dtype;
+  aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype);
+  AOTI_TORCH_CHECK(logit_lengths_dtype == aoti_torch_dtype_int32() ||
+                   logit_lengths_dtype == aoti_torch_dtype_float16());
+
+  int32_t target_lengths_dtype;
+  aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype);
+  AOTI_TORCH_CHECK(target_lengths_dtype == aoti_torch_dtype_int32() ||
+                   target_lengths_dtype == aoti_torch_dtype_float16());
+
+  bool bool_tmp;
+  aoti_torch_is_contiguous(logits.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+  aoti_torch_is_contiguous(targets.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+  aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+  aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp);
+
+  int64_t int_tmp;
+  aoti_torch_get_dim(logits.get(), &int_tmp);
+  AOTI_TORCH_CHECK(int_tmp == 4);
+  aoti_torch_get_dim(targets.get(), &int_tmp);
+  AOTI_TORCH_CHECK(int_tmp == 2);
+  aoti_torch_get_dim(logit_lengths.get(), &int_tmp);
+  AOTI_TORCH_CHECK(int_tmp == 1);
+  aoti_torch_get_dim(target_lengths.get(), &int_tmp);
+  AOTI_TORCH_CHECK(int_tmp == 1);
+
+  int64_t logit_lengths_size;
+  aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size);
+  int64_t logits_size;
+  aoti_torch_get_size(logits.get(), 0, &logits_size);
+  AOTI_TORCH_CHECK(logit_lengths_size == logits_size);
+  int64_t target_lengths_size;
+  aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size);
+  AOTI_TORCH_CHECK(target_lengths_size == logits_size);
+  int64_t targets_size;
+  aoti_torch_get_size(targets.get(), 0, &targets_size);
+  AOTI_TORCH_CHECK(targets_size == logits_size);
+
+  // TORCH_CHECK(
+  //     blank >= 0 && blank < logits.size(-1),
+  //     "blank must be within [0, logits.shape[-1])");
+
+  // TORCH_CHECK(
+  //     logits.size(1) == at::max(logit_lengths).item().toInt(),
+  //     "input length mismatch");
+  // TORCH_CHECK(
+  //     logits.size(2) == at::max(target_lengths).item().toInt() + 1,
+  //     "output length mismatch");
+  // TORCH_CHECK(
+  //     targets.size(1) == at::max(target_lengths).item().toInt(),
+  //     "target length mismatch");
 
   Options options;
-  options.batchSize_ = logit_lengths.size(0);
-  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
-  options.maxSrcLen_ = logits.size(1);
-  options.maxTgtLen_ = logits.size(2);
-  options.numTargets_ = logits.size(3);
+  options.batchSize_ = (int)logit_lengths_size;
+  options.nHypos_ = (int)target_lengths_size;
+  options.nHypos_ /= options.batchSize_;
+  aoti_torch_get_size(logits.get(), 1, &int_tmp);
+  options.maxSrcLen_ = (int)int_tmp;
+  aoti_torch_get_size(logits.get(), 2, &int_tmp);
+  options.maxTgtLen_ = (int)int_tmp;
+  aoti_torch_get_size(logits.get(), 3, &int_tmp);
+  options.numTargets_ = (int)int_tmp;
   options.blank_ = blank;
   options.clamp_ = clamp;
   options.fusedLogSmax_ = fused_log_softmax;
 
-  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
+  AOTI_TORCH_CHECK(logits_device == aoti_torch_device_type_cpu());
   options.device_ = CPU;
 
-  torch::Tensor costs = torch::empty(
-      options.batchSize_ * options.nHypos_,
-      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
-  std::optional<torch::Tensor> gradients = torch::zeros_like(logits);
-
-  torch::Tensor int_workspace = torch::empty(
-      IntWorkspace::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Int));
-
-  torch::Tensor float_workspace = torch::empty(
-      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Float));
+  int32_t logits_device_index;
+  aoti_torch_get_device_index(logits.get(), &logits_device_index);
+  int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_};
+  int64_t stride1[1] = {1};
+  AtenTensorHandle costs;
+  aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs);
+
+  AtenTensorHandle gradients;
+  aoti_torch_clone(logits.get(), &gradients);
+  aoti_torch_zero_(gradients);
+
+  AtenTensorHandle int_workspace;
+  int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)};
+  int64_t strides[1] = {1};
+  aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace);
+
+  AtenTensorHandle float_workspace;
+  aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace);
+
+  int64_t float_numel;
+  aoti_torch_get_numel(float_workspace, &float_numel);
+  void *int_workspace_ptr;
+  aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr);
+  void *float_workspace_ptr;
+  aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr);
+  int64_t int_numel;
+  aoti_torch_get_numel(int_workspace, &int_numel);
 
   Workspace<float> workspace(
       /*options=*/options,
-      /*dtype_data=*/float_workspace.data_ptr<float>(),
-      /*dtype_size=*/float_workspace.numel(),
-      /*int_data=*/int_workspace.data_ptr<int>(),
-      /*int_size=*/int_workspace.numel());
+      /*dtype_data=*/(float*)float_workspace_ptr,
+      /*dtype_size=*/float_numel,
+      /*int_data=*/(int*)int_workspace_ptr,
+      /*int_size=*/int_numel);
+
+  void *logit_ptr;
+  aoti_torch_get_data_ptr(logits.get(), &logit_ptr);
+
+  void *target_ptr;
+  aoti_torch_get_data_ptr(targets.get(), &target_ptr);
+
+  void *logit_len_ptr;
+  aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr);
 
-  switch (logits.scalar_type()) {
-    case torch::ScalarType::Float: {
+  void *target_len_ptr;
+  aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr);
+
+  void *costs_ptr;
+  aoti_torch_get_data_ptr(costs, &costs_ptr);
+
+  void *grads_ptr;
+  aoti_torch_get_data_ptr(gradients, &grads_ptr);
+
+  if (logits_dtype == aoti_torch_dtype_float32()) {
     Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
         /*workspace=*/workspace,
-          /*logits=*/logits.data_ptr<float>(),
-          /*targets=*/targets.data_ptr<int>(),
-          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
-          /*target_lengths=*/target_lengths.data_ptr<int>(),
-          /*costs=*/costs.data_ptr<float>(),
-          /*gradients=*/gradients->data_ptr<float>());
-      break;
-    }
-    case torch::ScalarType::Half: {
+        /*logits=*/(float*)logit_ptr,
+        /*targets=*/(int*)target_ptr,
+        /*logit_lengths=*/(int*)logit_len_ptr,
+        /*target_lengths=*/(int*)target_len_ptr,
+        /*costs=*/(float*)costs_ptr,
+        /*gradients=*/(float*)grads_ptr);
+  } else {
     Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
         /*workspace=*/workspace,
-          /*logits=*/logits.data_ptr<c10::Half>(),
-          /*targets=*/targets.data_ptr<int>(),
-          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
-          /*target_lengths=*/target_lengths.data_ptr<int>(),
-          /*costs=*/costs.data_ptr<c10::Half>(),
-          /*gradients=*/gradients->data_ptr<c10::Half>());
-      break;
+        /*logits=*/(c10::Half*)logit_ptr,
+        /*targets=*/(int*)target_ptr,
+        /*logit_lengths=*/(int*)logit_len_ptr,
+        /*target_lengths=*/(int*)target_len_ptr,
+        /*costs=*/(c10::Half*)costs_ptr,
+        /*gradients=*/(c10::Half*)grads_ptr);
   }
-    default: {
-      break;
-    }
-  };
 
-  return std::make_tuple(costs, gradients);
+  return std::make_tuple(RAIIATH(costs), RAIIATH(gradients));
+}
+
+void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  RAIIATH t1(to<AtenTensorHandle>(stack[0]));
+  RAIIATH t2(to<AtenTensorHandle>(stack[1]));
+  RAIIATH t3(to<AtenTensorHandle>(stack[2]));
+  RAIIATH t4(to<AtenTensorHandle>(stack[3]));
+  int64_t blank = to<int64_t>(stack[4]);
+  double clamp = to<double>(stack[5]);
+  bool fused_log_softmax = to<bool>(stack[6]);
+  auto result = compute(
+      std::move(t1), std::move(t2), std::move(t3), std::move(t4),
+      blank, clamp, fused_log_softmax);
+  stack[0] = from((std::get<0>(result)).release());
+  stack[1] = from((std::get<1>(result)).release());
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("rnnt_loss", &compute);
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("torchaudio::rnnt_loss", &boxed_compute);
 }
 
 } // namespace cpu
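Two conventions in the rewritten kernel are worth spelling out. Every aoti_torch_* shim call returns its result through an out-parameter (it also returns an error status, which the kernel above does not inspect), and tensor handles produced by the shim are owned by the caller, so the kernel wraps them in RAIIAtenTensorHandle to free them automatically; release() is used only where ownership is handed back across the ABI on the StableIValue stack. A minimal sketch under those assumptions, built only from calls that appear in this diff (the helper name make_zeros is hypothetical):

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>

using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;

// Allocate a zeroed 1-D float32 tensor on the given device and return it
// as an owning RAII handle, mirroring the costs/gradients setup above.
RAIIATH make_zeros(int64_t n, int32_t device_type, int32_t device_index) {
  int64_t sizes[1] = {n};
  int64_t strides[1] = {1};
  AtenTensorHandle raw;
  aoti_torch_empty_strided(
      1, sizes, strides, aoti_torch_dtype_float32(),
      device_type, device_index, &raw);
  aoti_torch_zero_(raw);  // in-place zero, as the kernel does for gradients
  return RAIIATH(raw);    // freed automatically unless release() is called
}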

0 commit comments