 #include <libtorchaudio/rnnt/cpu/cpu_transducer.h>

-#include <torch/script.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/inductor/aoti_runtime/utils.h>
+#include <torch/csrc/stable/library.h>

 namespace torchaudio {
 namespace rnnt {
 namespace cpu {

+using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
+
 // Entry point into RNNT Loss
-std::tuple<torch::Tensor, std::optional<torch::Tensor>> compute(
-    torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
+std::tuple<RAIIATH, RAIIATH> compute(
+    const RAIIATH logits,
+    const RAIIATH targets,
+    const RAIIATH logit_lengths,
+    const RAIIATH target_lengths,
     int64_t blank,
     double clamp,
     bool fused_log_softmax = true) {
-  TORCH_CHECK(
-      logits.device().type() == targets.device().type(),
-      "logits and targets must be on the same device");
-  TORCH_CHECK(
-      logits.device().type() == logit_lengths.device().type(),
-      "logits and logit_lengths must be on the same device");
-  TORCH_CHECK(
-      logits.device().type() == target_lengths.device().type(),
-      "logits and target_lengths must be on the same device");
-
-  TORCH_CHECK(
-      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
-      "logits must be float32 or float16 (half) type");
-  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
-  TORCH_CHECK(
-      logit_lengths.dtype() == torch::kInt32,
-      "logit_lengths must be int32 type");
-  TORCH_CHECK(
-      target_lengths.dtype() == torch::kInt32,
-      "target_lengths must be int32 type");
-
-  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
-  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(
-      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
-  TORCH_CHECK(
-      target_lengths.is_contiguous(), "target_lengths must be contiguous");
-
-  TORCH_CHECK(
-      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
-  TORCH_CHECK(
-      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
-  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
-  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
-
-  TORCH_CHECK(
-      logit_lengths.size(0) == logits.size(0),
-      "batch dimension mismatch between logits and logit_lengths");
-  TORCH_CHECK(
-      target_lengths.size(0) == logits.size(0),
-      "batch dimension mismatch between logits and target_lengths");
-  TORCH_CHECK(
-      targets.size(0) == logits.size(0),
-      "batch dimension mismatch between logits and targets");
-
-  TORCH_CHECK(
-      blank >= 0 && blank < logits.size(-1),
-      "blank must be within [0, logits.shape[-1])");
-
-  TORCH_CHECK(
-      logits.size(1) == at::max(logit_lengths).item().toInt(),
-      "input length mismatch");
-  TORCH_CHECK(
-      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
-      "output length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(target_lengths).item().toInt(),
-      "target length mismatch");
+
+  // Query device types through the stable C shim.
+  int32_t logits_device;
+  aoti_torch_get_device_type(logits.get(), &logits_device);
+  int32_t targets_device;
+  aoti_torch_get_device_type(targets.get(), &targets_device);
+  int32_t logit_lengths_device;
+  aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device);
+  int32_t target_lengths_device;
+  aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device);
+
+  // All inputs must be on the same device.
+  AOTI_TORCH_CHECK(logits_device == targets_device);
+  AOTI_TORCH_CHECK(logits_device == logit_lengths_device);
+  AOTI_TORCH_CHECK(logits_device == target_lengths_device);
+
+  // logits must be float32 or float16 (half).
+  int32_t logits_dtype;
+  aoti_torch_get_dtype(logits.get(), &logits_dtype);
+  AOTI_TORCH_CHECK(logits_dtype == aoti_torch_dtype_float32() ||
+                   logits_dtype == aoti_torch_dtype_float16());
+
+  // targets and both length tensors must be int32.
+  int32_t targets_dtype;
+  aoti_torch_get_dtype(targets.get(), &targets_dtype);
+  AOTI_TORCH_CHECK(targets_dtype == aoti_torch_dtype_int32());
+
+  int32_t logit_lengths_dtype;
+  aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype);
+  AOTI_TORCH_CHECK(logit_lengths_dtype == aoti_torch_dtype_int32());
+
+  int32_t target_lengths_dtype;
+  aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype);
+  AOTI_TORCH_CHECK(target_lengths_dtype == aoti_torch_dtype_int32());
+
+  // All inputs must be contiguous.
+  bool bool_tmp;
+  aoti_torch_is_contiguous(logits.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+  aoti_torch_is_contiguous(targets.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+  aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+  aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp);
+  AOTI_TORCH_CHECK(bool_tmp);
+
+  // Expected ranks: logits 4-D (batch, time, target, class),
+  // targets 2-D (batch, max target length), lengths 1-D.
+  int64_t int_tmp;
+  aoti_torch_get_dim(logits.get(), &int_tmp);
+  AOTI_TORCH_CHECK(int_tmp == 4);
+  aoti_torch_get_dim(targets.get(), &int_tmp);
+  AOTI_TORCH_CHECK(int_tmp == 2);
+  aoti_torch_get_dim(logit_lengths.get(), &int_tmp);
+  AOTI_TORCH_CHECK(int_tmp == 1);
+  aoti_torch_get_dim(target_lengths.get(), &int_tmp);
+  AOTI_TORCH_CHECK(int_tmp == 1);
+
+  // Batch dimensions must match across all inputs.
+  int64_t logit_lengths_size;
+  aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size);
+  int64_t logits_size;
+  aoti_torch_get_size(logits.get(), 0, &logits_size);
+  AOTI_TORCH_CHECK(logit_lengths_size == logits_size);
+  int64_t target_lengths_size;
+  aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size);
+  AOTI_TORCH_CHECK(target_lengths_size == logits_size);
+  int64_t targets_size;
+  aoti_torch_get_size(targets.get(), 0, &targets_size);
+  AOTI_TORCH_CHECK(targets_size == logits_size);
+
+  // TORCH_CHECK(
+  //     blank >= 0 && blank < logits.size(-1),
+  //     "blank must be within [0, logits.shape[-1])");
+
+  // TORCH_CHECK(
+  //     logits.size(1) == at::max(logit_lengths).item().toInt(),
+  //     "input length mismatch");
+  // TORCH_CHECK(
+  //     logits.size(2) == at::max(target_lengths).item().toInt() + 1,
+  //     "output length mismatch");
+  // TORCH_CHECK(
+  //     targets.size(1) == at::max(target_lengths).item().toInt(),
+  //     "target length mismatch");

   Options options;
-  options.batchSize_ = logit_lengths.size(0);
-  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
-  options.maxSrcLen_ = logits.size(1);
-  options.maxTgtLen_ = logits.size(2);
-  options.numTargets_ = logits.size(3);
+  options.batchSize_ = (int)logit_lengths_size;
+  options.nHypos_ = (int)target_lengths_size;
+  options.nHypos_ /= options.batchSize_;
+  aoti_torch_get_size(logits.get(), 1, &int_tmp);
+  options.maxSrcLen_ = (int)int_tmp;
+  aoti_torch_get_size(logits.get(), 2, &int_tmp);
+  options.maxTgtLen_ = (int)int_tmp;
+  aoti_torch_get_size(logits.get(), 3, &int_tmp);
+  options.numTargets_ = (int)int_tmp;
   options.blank_ = blank;
   options.clamp_ = clamp;
   options.fusedLogSmax_ = fused_log_softmax;

-  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
+  AOTI_TORCH_CHECK(logits_device == aoti_torch_device_type_cpu());
   options.device_ = CPU;

-  torch::Tensor costs = torch::empty(
-      options.batchSize_ * options.nHypos_,
-      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
-  std::optional<torch::Tensor> gradients = torch::zeros_like(logits);
-
-  torch::Tensor int_workspace = torch::empty(
-      IntWorkspace::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Int));
-
-  torch::Tensor float_workspace = torch::empty(
-      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Float));
+  int32_t logits_device_index;
+  aoti_torch_get_device_index(logits.get(), &logits_device_index);
+
+  // costs: one entry per (batch, hypothesis) pair, same dtype as logits.
+  int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_};
+  int64_t stride1[1] = {1};
+  AtenTensorHandle costs;
+  aoti_torch_empty_strided(
+      1, cost_sizes, stride1, logits_dtype, logits_device,
+      logits_device_index, &costs);
+
+  // gradients: zero-filled tensor with the same shape/dtype as logits.
+  AtenTensorHandle gradients;
+  aoti_torch_clone(logits.get(), &gradients);
+  aoti_torch_zero_(gradients);
+
+  AtenTensorHandle int_workspace;
+  int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)};
+  int64_t strides[1] = {1};
+  aoti_torch_empty_strided(
+      1, int_sizes, strides, aoti_torch_dtype_int32(), logits_device,
+      logits_device_index, &int_workspace);
+
+  AtenTensorHandle float_workspace;
+  int64_t float_sizes[1] = {
+      DtypeWorkspace<float>::ComputeSizeFromOptions(options)};
+  aoti_torch_empty_strided(
+      1, float_sizes, strides, aoti_torch_dtype_float32(), logits_device,
+      logits_device_index, &float_workspace);
+
+  int64_t int_numel;
+  aoti_torch_get_numel(int_workspace, &int_numel);
+  int64_t float_numel;
+  aoti_torch_get_numel(float_workspace, &float_numel);
+  void* int_workspace_ptr;
+  aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr);
+  void* float_workspace_ptr;
+  aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr);

   Workspace<float> workspace(
       /*options=*/options,
-      /*dtype_data=*/float_workspace.data_ptr<float>(),
-      /*dtype_size=*/float_workspace.numel(),
-      /*int_data=*/int_workspace.data_ptr<int>(),
-      /*int_size=*/int_workspace.numel());
+      /*dtype_data=*/(float*)float_workspace_ptr,
+      /*dtype_size=*/float_numel,
+      /*int_data=*/(int*)int_workspace_ptr,
+      /*int_size=*/int_numel);
+
+  void* logit_ptr;
+  aoti_torch_get_data_ptr(logits.get(), &logit_ptr);
+
+  void* target_ptr;
+  aoti_torch_get_data_ptr(targets.get(), &target_ptr);
+
+  void* logit_len_ptr;
+  aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr);

-  switch (logits.scalar_type()) {
-    case torch::ScalarType::Float: {
+  void* target_len_ptr;
+  aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr);
+
+  void* costs_ptr;
+  aoti_torch_get_data_ptr(costs, &costs_ptr);
+
+  void* grads_ptr;
+  aoti_torch_get_data_ptr(gradients, &grads_ptr);
+
+  // Dispatch on logits dtype; the earlier check guarantees float32 or float16.
+  if (logits_dtype == aoti_torch_dtype_float32()) {
     Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
         /*workspace=*/workspace,
-        /*logits=*/logits.data_ptr<float>(),
-        /*targets=*/targets.data_ptr<int>(),
-        /*logit_lengths=*/logit_lengths.data_ptr<int>(),
-        /*target_lengths=*/target_lengths.data_ptr<int>(),
-        /*costs=*/costs.data_ptr<float>(),
-        /*gradients=*/gradients->data_ptr<float>());
-      break;
-    }
-    case torch::ScalarType::Half: {
+        /*logits=*/(float*)logit_ptr,
+        /*targets=*/(int*)target_ptr,
+        /*logit_lengths=*/(int*)logit_len_ptr,
+        /*target_lengths=*/(int*)target_len_ptr,
+        /*costs=*/(float*)costs_ptr,
+        /*gradients=*/(float*)grads_ptr);
+  } else {
     Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
         /*workspace=*/workspace,
-        /*logits=*/logits.data_ptr<c10::Half>(),
-        /*targets=*/targets.data_ptr<int>(),
-        /*logit_lengths=*/logit_lengths.data_ptr<int>(),
-        /*target_lengths=*/target_lengths.data_ptr<int>(),
-        /*costs=*/costs.data_ptr<c10::Half>(),
-        /*gradients=*/gradients->data_ptr<c10::Half>());
-      break;
+        /*logits=*/(c10::Half*)logit_ptr,
+        /*targets=*/(int*)target_ptr,
+        /*logit_lengths=*/(int*)logit_len_ptr,
+        /*target_lengths=*/(int*)target_len_ptr,
+        /*costs=*/(c10::Half*)costs_ptr,
+        /*gradients=*/(c10::Half*)grads_ptr);
   }
-    default: {
-      break;
-    }
-  };

-  return std::make_tuple(costs, gradients);
+  return std::make_tuple(RAIIATH(costs), RAIIATH(gradients));
+}
+
+// Boxed wrapper: unpack the StableIValue stack, run compute(), and
+// write the two result handles back onto the stack.
+void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  RAIIATH logits(to<AtenTensorHandle>(stack[0]));
+  RAIIATH targets(to<AtenTensorHandle>(stack[1]));
+  RAIIATH logit_lengths(to<AtenTensorHandle>(stack[2]));
+  RAIIATH target_lengths(to<AtenTensorHandle>(stack[3]));
+  int64_t blank = to<int64_t>(stack[4]);
+  double clamp = to<double>(stack[5]);
+  bool fused_log_softmax = to<bool>(stack[6]);
+  auto result = compute(
+      std::move(logits), std::move(targets),
+      std::move(logit_lengths), std::move(target_lengths),
+      blank, clamp, fused_log_softmax);
+  stack[0] = from(std::get<0>(result).release());
+  stack[1] = from(std::get<1>(result).release());
 }

-TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("rnnt_loss", &compute);
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("torchaudio::rnnt_loss", &boxed_compute);
 }

 } // namespace cpu
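
Note on the pattern: the stable-ABI entry point above is "boxed" — every argument crosses the library boundary as a StableIValue, and tensors are passed as opaque AtenTensorHandles wrapped in RAII guards. A minimal self-contained sketch of the same convention for a hypothetical unary op (the "myops::identity_copy" name, its schema, and its registration are illustrative assumptions, not part of this change):

#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/stable/library.h>

using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;

// Hypothetical kernel: returns a copy of its input tensor.
RAIIATH identity_copy(const RAIIATH input) {
  AtenTensorHandle out;
  aoti_torch_clone(input.get(), &out);  // allocates a new tensor handle
  return RAIIATH(out);                  // RAII guard frees it unless released
}

// Boxed wrapper: arguments arrive as StableIValues on the stack and the
// result is written back in place, exactly as boxed_compute does above.
void boxed_identity_copy(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  RAIIATH input(to<AtenTensorHandle>(stack[0]));  // take ownership
  RAIIATH result = identity_copy(std::move(input));
  stack[0] = from(result.release());              // hand ownership back
}

// Assumed registration; the op schema would have to be def()'d elsewhere
// for the "myops" namespace.
STABLE_TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("identity_copy", &boxed_identity_copy);
}

Because only C types and opaque handles cross the boundary, a kernel written this way does not have to be rebuilt against each libtorch's C++ ABI, which is the point of dropping torch/script.h in this diff.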