typedef void (*set_rows_kernel_t)(const char * src, char * dst);

// Generic quantized set_rows kernel template
- template <typename block_type, int qk, void (*quantize_func)(const float *, block_type*)>
+ template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type*)>
static __global__ void k_set_rows_quant(
-         const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
+         const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
@@ -45,9 +45,9 @@ static __global__ void k_set_rows_quant(
}

// Template dispatch function for quantized set_rows
- template <typename block_type, int qk, void (*quantize_func)(const float *, block_type*)>
+ template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type*)>
static void set_rows_cuda_quant(
-         const float * src0_d, const int64_t * src1_d, block_type * dst_d,
+         const float * src0_d, const idx_t * src1_d, block_type * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
@@ -64,15 +64,15 @@ static void set_rows_cuda_quant(
    const int64_t s01 = nb01/sizeof(float);
    const int64_t s02 = nb02/sizeof(float);
    const int64_t s03 = nb03/sizeof(float);
-     const int64_t s10 = nb10/sizeof(int64_t);
-     const int64_t s11 = nb11/sizeof(int64_t);
-     const int64_t s12 = nb12/sizeof(int64_t);
+     const int64_t s10 = nb10/sizeof(idx_t);
+     const int64_t s11 = nb11/sizeof(idx_t);
+     const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1 = nb1;
    const int64_t s2 = nb2;
    const int64_t s3 = nb3;

    if (ne_total > 0) {
-         k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
+         k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -82,9 +82,9 @@ static void set_rows_cuda_quant(
    }
}

- template <typename src_t, typename dst_t>
+ template <typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(
-         const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
+         const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const int64_t s01, const int64_t s02, const int64_t s03,
@@ -118,9 +118,9 @@ static __global__ void k_set_rows(
    GGML_UNUSED(ne13);
}

- template <typename src_t, typename dst_t>
+ template <typename src_t, typename idx_t, typename dst_t>
static void set_rows_cuda(
-         const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
+         const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
        const size_t nb01, const size_t nb02, const size_t nb03,
@@ -137,9 +137,9 @@ static void set_rows_cuda(
    const int64_t s01 = nb01/sizeof(src_t);
    const int64_t s02 = nb02/sizeof(src_t);
    const int64_t s03 = nb03/sizeof(src_t);
-     const int64_t s10 = nb10/sizeof(int64_t);
-     const int64_t s11 = nb11/sizeof(int64_t);
-     const int64_t s12 = nb12/sizeof(int64_t);
+     const int64_t s10 = nb10/sizeof(idx_t);
+     const int64_t s11 = nb11/sizeof(idx_t);
+     const int64_t s12 = nb12/sizeof(idx_t);
    const int64_t s1 = nb1/sizeof(dst_t);
    const int64_t s2 = nb2/sizeof(dst_t);
    const int64_t s3 = nb3/sizeof(dst_t);
@@ -155,23 +155,16 @@ static void set_rows_cuda(
    }
}

-
- void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-     const ggml_tensor * src0 = dst->src[0];
-     const ggml_tensor * src1 = dst->src[1];
-
-     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-     GGML_ASSERT(src1->type == GGML_TYPE_I64);
+ template <typename src_t, typename idx_t>
+ static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+     const src_t * src0_d = (const src_t *)src0->data;
+     const idx_t * src1_d = (const idx_t *)src1->data;

    GGML_TENSOR_BINARY_OP_LOCALS

-     const float * src0_d = (const float *)src0->data;
-     const int64_t * src1_d = (const int64_t *)src1->data;
-
    cudaStream_t stream = ctx.stream();


-
    if (dst->type == GGML_TYPE_F32) {
        set_rows_cuda(
            src0_d, src1_d, (float *)dst->data,
@@ -203,7 +196,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_0) {
-         set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
+         set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
            src0_d, src1_d, (block_q4_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -213,7 +206,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q4_1) {
-         set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
+         set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
            src0_d, src1_d, (block_q4_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -223,7 +216,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_0) {
-         set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
+         set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
            src0_d, src1_d, (block_q5_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -233,7 +226,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q5_1) {
-         set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
+         set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
            src0_d, src1_d, (block_q5_1*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -243,7 +236,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_Q8_0) {
-         set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
+         set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
            src0_d, src1_d, (block_q8_0*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -253,7 +246,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
            stream
        );
    } else if (dst->type == GGML_TYPE_IQ4_NL) {
-         set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
+         set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
            src0_d, src1_d, (block_iq4_nl*)dst->data,
            ne00, ne01, ne02, ne03,
            ne10, ne11, ne12, ne13,
@@ -266,3 +259,18 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
        GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
    }
}
+
+
+ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+     const ggml_tensor * src0 = dst->src[0];
+     const ggml_tensor * src1 = dst->src[1];
+
+     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+     GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
+
+     if (src1->type == GGML_TYPE_I64) {
+         set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
+     } else {
+         set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
+     }
+ }
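
The change above follows a common CUDA pattern: the kernel and its launcher are templated on the row-index type, and a thin entry point picks the instantiation from the runtime tensor type (GGML_TYPE_I64 vs GGML_TYPE_I32). As a minimal standalone sketch of that pattern — the names k_scatter_rows/scatter_rows and the flat row-major layout are illustrative assumptions, not part of ggml:

#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Minimal analogue of k_set_rows: copy row r of src into row idx[r] of dst.
// Templating on idx_t lets one kernel body serve 32- and 64-bit indices.
template <typename idx_t>
static __global__ void k_scatter_rows(
        const float * __restrict__ src, const idx_t * __restrict__ idx,
        float * __restrict__ dst, const int64_t n_rows, const int64_t row_size) {
    const int64_t i = blockIdx.x*(int64_t)blockDim.x + threadIdx.x;
    if (i >= n_rows*row_size) {
        return;
    }
    const int64_t row = i / row_size;
    const int64_t col = i % row_size;
    dst[idx[row]*row_size + col] = src[i];
}

// Host-side launcher, templated the same way as set_rows_cuda above.
template <typename idx_t>
static void scatter_rows(const float * src_d, const idx_t * idx_d, float * dst_d,
                         const int64_t n_rows, const int64_t row_size, cudaStream_t stream) {
    const int64_t ne_total   = n_rows*row_size;
    const int     block_size = 256;
    const int     grid_size  = (int)((ne_total + block_size - 1)/block_size);
    if (ne_total > 0) {
        k_scatter_rows<idx_t><<<grid_size, block_size, 0, stream>>>(
            src_d, idx_d, dst_d, n_rows, row_size);
    }
}

int main() {
    const int64_t n_rows = 2, row_size = 4, dst_rows = 4;
    std::vector<float>   src = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<int32_t> idx = {2, 0};  // src row 0 -> dst row 2, src row 1 -> dst row 0

    float *src_d, *dst_d; int32_t *idx_d;
    cudaMalloc(&src_d, src.size()*sizeof(float));
    cudaMalloc(&dst_d, dst_rows*row_size*sizeof(float));
    cudaMalloc(&idx_d, idx.size()*sizeof(int32_t));
    cudaMemcpy(src_d, src.data(), src.size()*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(idx_d, idx.data(), idx.size()*sizeof(int32_t), cudaMemcpyHostToDevice);
    cudaMemset(dst_d, 0, dst_rows*row_size*sizeof(float));

    // Runtime dispatch on the index type, mirroring the I64/I32 branch above.
    scatter_rows<int32_t>(src_d, idx_d, dst_d, n_rows, row_size, 0);

    std::vector<float> out(dst_rows*row_size);
    cudaMemcpy(out.data(), dst_d, out.size()*sizeof(float), cudaMemcpyDeviceToHost);
    for (float v : out) printf("%.0f ", v);  // expected: 5 6 7 8 0 0 0 0 1 2 3 4 0 0 0 0
    printf("\n");
    return 0;
}

Compiled with nvcc, an I64 index tensor would go through scatter_rows<int64_t> instead; only the dispatch branch differs, which is exactly the structure of the new ggml_cuda_op_set_rows.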