66 * LICENSE file in the root directory of this source tree.
77 */
88
9- #include < cstring>
10-
11- #include < xa_nnlib_kernels_api.h>
12-
9+ #include < executorch/backends/cadence/fusion_g3/operators/tensor_util.h>
1310#include < executorch/kernels/portable/cpu/util/copy_ops_util.h>
1411#include < executorch/runtime/kernel/kernel_includes.h>
12+ #include < xa_nnlib_kernels_api.h>
13+ #include < cstring>
1514
16- using ::executorch::aten::ScalarType;
17- using ::executorch::aten::Tensor;
18- using ::executorch::runtime::Error;
19- using ::executorch::runtime::KernelRuntimeContext;
15+ using exec_aten::Scalar;
16+ using exec_aten::ScalarType;
17+ using exec_aten::Tensor;
18+ using torch::executor::Error;
19+ using torch::executor::KernelRuntimeContext;
2020
2121/* ScalarType in Executorch do not have support for below data types.
2222 * So, creating a placeholder for these data types. Once, ScalarTypes is
@@ -39,13 +39,15 @@ Tensor& cat_out(
3939 dim += out.dim ();
4040 }
4141
42+ int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit ;
43+
44+ #ifdef OP_ARG_CHECK
4245 ET_KERNEL_CHECK (
4346 ctx,
4447 torch::executor::check_cat_args (tensors, dim, out),
4548 InvalidArgument,
4649 out);
4750
48- int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit ;
4951 Tensor::SizesType expected_out_size[kTensorDimensionLimit ];
5052 size_t expected_out_dim = 0 ;
5153 torch::executor::get_cat_out_target_size (
@@ -57,6 +59,20 @@ Tensor& cat_out(
5759 out, {expected_out_size, expected_out_dim}) == Error::Ok,
5860 InvalidArgument,
5961 out);
62+ #endif
63+ // Special handling when all inputs are 1D-empty tensors for aten
64+ // consistency In that case, just return an 1D-empty tensor without checking
65+ // dim
66+ bool all_1d_empty = true ;
67+ for (size_t i = 0 ; i < tensors.size (); ++i) {
68+ if (tensors[i].numel () != 0 || tensors[i].dim () != 1 ) {
69+ all_1d_empty = false ;
70+ break ;
71+ }
72+ }
73+ if (all_1d_empty) {
74+ return out;
75+ }
6076
6177 const signed char * inp_tensors[tensors.size ()];
6278 const int * inp_tensors_shapes[tensors.size ()];
@@ -87,7 +103,10 @@ Tensor& cat_out(
87103 }
88104
89105 if (out.scalar_type () == ScalarType::Int) {
90- xa_nn_cat (
106+ XT_KERNEL_CHECK (
107+ ctx,
108+ out,
109+ xa_nn_cat,
91110 out_data,
92111 out_shapes,
93112 inp_tensors,
@@ -97,7 +116,10 @@ Tensor& cat_out(
97116 (int )dim,
98117 sizeof (int ));
99118 } else if (out.scalar_type () == ScalarType::Short) {
100- xa_nn_cat (
119+ XT_KERNEL_CHECK (
120+ ctx,
121+ out,
122+ xa_nn_cat,
101123 out_data,
102124 out_shapes,
103125 inp_tensors,
@@ -107,7 +129,10 @@ Tensor& cat_out(
107129 (int )dim,
108130 sizeof (short ));
109131 } else if (out.scalar_type () == ScalarType::Char) {
110- xa_nn_cat (
132+ XT_KERNEL_CHECK (
133+ ctx,
134+ out,
135+ xa_nn_cat,
111136 out_data,
112137 out_shapes,
113138 inp_tensors,
@@ -117,7 +142,10 @@ Tensor& cat_out(
117142 (int )dim,
118143 sizeof (char ));
119144 } else if (out.scalar_type () == (ScalarType)Uint) {
120- xa_nn_cat (
145+ XT_KERNEL_CHECK (
146+ ctx,
147+ out,
148+ xa_nn_cat,
121149 out_data,
122150 out_shapes,
123151 inp_tensors,
@@ -127,7 +155,10 @@ Tensor& cat_out(
127155 (int )dim,
128156 sizeof (int ));
129157 } else if (out.scalar_type () == (ScalarType)Ushort) {
130- xa_nn_cat (
158+ XT_KERNEL_CHECK (
159+ ctx,
160+ out,
161+ xa_nn_cat,
131162 out_data,
132163 out_shapes,
133164 inp_tensors,
@@ -137,7 +168,10 @@ Tensor& cat_out(
137168 (int )dim,
138169 sizeof (short ));
139170 } else if (out.scalar_type () == ScalarType::Byte) {
140- xa_nn_cat (
171+ XT_KERNEL_CHECK (
172+ ctx,
173+ out,
174+ xa_nn_cat,
141175 out_data,
142176 out_shapes,
143177 inp_tensors,
@@ -148,19 +182,6 @@ Tensor& cat_out(
148182 sizeof (char ));
149183
150184 } else {
151- // Special handling when all inputs are 1D-empty tensors for aten
152- // consistency In that case, just return an 1D-empty tensor without checking
153- // dim
154- bool all_1d_empty = true ;
155- for (size_t i = 0 ; i < tensors.size (); ++i) {
156- if (tensors[i].numel () != 0 || tensors[i].dim () != 1 ) {
157- all_1d_empty = false ;
158- break ;
159- }
160- }
161- if (all_1d_empty) {
162- return out;
163- }
164185 const size_t outer = executorch::runtime::getLeadingDims (out, dim);
165186 const size_t dim_stride = executorch::runtime::getTrailingDims (out, dim);
166187 const size_t ninputs = tensors.size ();
0 commit comments