Skip to content

Commit 22cca4c

Browse files
committed
SYCL: Add support for FLOOR, CEIL, ROUND and TRUNC unary operators
1 parent 7a50cf3 commit 22cca4c

File tree

7 files changed

+334
-67
lines changed

7 files changed

+334
-67
lines changed

docs/ops.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Legend:
2222
| ARANGE ||||||||||
2323
| ARGMAX ||||||||||
2424
| ARGSORT ||||||||||
25-
| CEIL ||||||| |||
25+
| CEIL ||||||| |||
2626
| CLAMP ||||| 🟡 | 🟡 || 🟡 ||
2727
| CONCAT |||| 🟡 || 🟡 | 🟡 |||
2828
| CONT || 🟡 |||| 🟡 | 🟡 | 🟡 ||
@@ -42,7 +42,7 @@ Legend:
4242
| ELU |||| 🟡 | 🟡 || 🟡 |||
4343
| EXP |||| 🟡 | 🟡 || 🟡 |||
4444
| FLASH_ATTN_EXT || 🟡 || 🟡 | 🟡 ||| 🟡 ||
45-
| FLOOR ||||||| |||
45+
| FLOOR ||||||| |||
4646
| GATED_LINEAR_ATTN ||||||||||
4747
| GEGLU ||||| 🟡 ||| 🟡 ||
4848
| GEGLU_ERF ||||| 🟡 ||| 🟡 ||
@@ -84,7 +84,7 @@ Legend:
8484
| ROLL ||||||||||
8585
| ROPE || 🟡 ||||||||
8686
| ROPE_BACK ||||||||||
87-
| ROUND ||||||| |||
87+
| ROUND ||||||| |||
8888
| RWKV_WKV6 ||||||||||
8989
| RWKV_WKV7 ||||||||||
9090
| SCALE || 🟡 ||||||||
@@ -111,6 +111,6 @@ Legend:
111111
| TANH |||| 🟡 | 🟡 || 🟡 | 🟡 ||
112112
| TIMESTEP_EMBEDDING ||||||||||
113113
| TOPK_MOE ||||||||||
114-
| TRUNC ||||||| |||
114+
| TRUNC ||||||| |||
115115
| UPSCALE || 🟡 ||| 🟡 || 🟡 |||
116116
| XIELU ||||||||||

docs/ops/SYCL.csv

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@
3131
"SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
3232
"SYCL0","XIELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
3333
"SYCL0","XIELU","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
34+
"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
35+
"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
36+
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
37+
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
38+
"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
39+
"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
40+
"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
41+
"SYCL0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
3442
"SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
3543
"SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
3644
"SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
@@ -95,6 +103,14 @@
95103
"SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
96104
"SYCL0","XIELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
97105
"SYCL0","XIELU","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
106+
"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
107+
"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
108+
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
109+
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
110+
"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
111+
"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
112+
"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
113+
"SYCL0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
98114
"SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
99115
"SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
100116
"SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"

ggml/include/ggml.h

Lines changed: 42 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -286,19 +286,19 @@ __host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexc
286286
// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
287287
//
288288
#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
289-
const type prefix##0 = (pointer) ? (pointer)->array[0] : 0; \
289+
const type prefix##0 = (pointer)->array[0]; \
290290
GGML_UNUSED(prefix##0);
291291
#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
292292
GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
293-
const type prefix##1 = (pointer) ? (pointer)->array[1] : 0; \
293+
const type prefix##1 = (pointer)->array[1]; \
294294
GGML_UNUSED(prefix##1);
295295
#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
296296
GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
297-
const type prefix##2 = (pointer) ? (pointer)->array[2] : 0; \
297+
const type prefix##2 = (pointer)->array[2]; \
298298
GGML_UNUSED(prefix##2);
299299
#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
300300
GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
301-
const type prefix##3 = (pointer) ? (pointer)->array[3] : 0; \
301+
const type prefix##3 = (pointer)->array[3]; \
302302
GGML_UNUSED(prefix##3);
303303

304304
#define GGML_TENSOR_UNARY_OP_LOCALS \
@@ -513,7 +513,6 @@ extern "C" {
513513
GGML_OP_CONV_TRANSPOSE_1D,
514514
GGML_OP_IM2COL,
515515
GGML_OP_IM2COL_BACK,
516-
GGML_OP_IM2COL_3D,
517516
GGML_OP_CONV_2D,
518517
GGML_OP_CONV_3D,
519518
GGML_OP_CONV_2D_DW,
@@ -580,8 +579,9 @@ extern "C" {
580579
GGML_UNARY_OP_FLOOR,
581580
GGML_UNARY_OP_CEIL,
582581
GGML_UNARY_OP_ROUND,
583-
GGML_UNARY_OP_TRUNC,
584-
582+
GGML_UNARY_OP_TRUNC,
583+
584+
585585
GGML_UNARY_OP_COUNT,
586586
};
587587

@@ -1103,7 +1103,40 @@ extern "C" {
11031103
GGML_API struct ggml_tensor * ggml_gelu_inplace(
11041104
struct ggml_context * ctx,
11051105
struct ggml_tensor * a);
1106+
1107+
GGML_API struct ggml_tensor * ggml_floor(
1108+
struct ggml_context * ctx,
1109+
struct ggml_tensor * a);
1110+
1111+
GGML_API struct ggml_tensor * ggml_floor_inplace(
1112+
struct ggml_context * ctx,
1113+
struct ggml_tensor * a);
1114+
1115+
GGML_API struct ggml_tensor * ggml_ceil(
1116+
struct ggml_context * ctx,
1117+
struct ggml_tensor * a);
1118+
1119+
GGML_API struct ggml_tensor * ggml_ceil_inplace(
1120+
struct ggml_context * ctx,
1121+
struct ggml_tensor * a);
11061122

1123+
GGML_API struct ggml_tensor * ggml_round(
1124+
struct ggml_context * ctx,
1125+
struct ggml_tensor * a);
1126+
1127+
GGML_API struct ggml_tensor * ggml_round_inplace(
1128+
struct ggml_context * ctx,
1129+
struct ggml_tensor * a);
1130+
1131+
GGML_API struct ggml_tensor * ggml_trunc(
1132+
struct ggml_context * ctx,
1133+
struct ggml_tensor * a);
1134+
1135+
GGML_API struct ggml_tensor * ggml_trunc_inplace(
1136+
struct ggml_context * ctx,
1137+
struct ggml_tensor * a);
1138+
1139+
11071140
// GELU using erf (error function) when possible
11081141
// some backends may fallback to approximation based on Abramowitz and Stegun formula
11091142
GGML_API struct ggml_tensor * ggml_gelu_erf(
@@ -1463,7 +1496,6 @@ extern "C" {
14631496
struct ggml_tensor * a,
14641497
struct ggml_tensor * b);
14651498

1466-
// note: casting from f32 to i32 will discard the fractional part
14671499
GGML_API struct ggml_tensor * ggml_cast(
14681500
struct ggml_context * ctx,
14691501
struct ggml_tensor * a,
@@ -1588,11 +1620,7 @@ extern "C" {
15881620
struct ggml_context * ctx,
15891621
struct ggml_tensor * a);
15901622

1591-
// supports 4D a:
1592-
// a [n_embd, ne1, ne2, ne3]
1593-
// b I32 [n_rows, ne2, ne3, 1]
1594-
//
1595-
// return [n_embd, n_rows, ne2, ne3]
1623+
// supports 3D: a->ne[2] == b->ne[1]
15961624
GGML_API struct ggml_tensor * ggml_get_rows(
15971625
struct ggml_context * ctx,
15981626
struct ggml_tensor * a, // data
@@ -1942,41 +1970,6 @@ extern "C" {
19421970
int d0, // dilation dimension 0
19431971
int d1); // dilation dimension 1
19441972

1945-
GGML_API struct ggml_tensor * ggml_im2col_3d(
1946-
struct ggml_context * ctx,
1947-
struct ggml_tensor * a,
1948-
struct ggml_tensor * b,
1949-
int64_t IC,
1950-
int s0, // stride width
1951-
int s1, // stride height
1952-
int s2, // stride depth
1953-
int p0, // padding width
1954-
int p1, // padding height
1955-
int p2, // padding depth
1956-
int d0, // dilation width
1957-
int d1, // dilation height
1958-
int d2, // dilation depth
1959-
enum ggml_type dst_type);
1960-
1961-
// a: [OC*IC, KD, KH, KW]
1962-
// b: [N*IC, ID, IH, IW]
1963-
// result: [N*OC, OD, OH, OW]
1964-
GGML_API struct ggml_tensor * ggml_conv_3d(
1965-
struct ggml_context * ctx,
1966-
struct ggml_tensor * a,
1967-
struct ggml_tensor * b,
1968-
int64_t IC,
1969-
int s0, // stride width
1970-
int s1, // stride height
1971-
int s2, // stride depth
1972-
int p0, // padding width
1973-
int p1, // padding height
1974-
int p2, // padding depth
1975-
int d0, // dilation width
1976-
int d1, // dilation height
1977-
int d2 // dilation depth
1978-
);
1979-
19801973
// kernel size is a->ne[0] x a->ne[1]
19811974
// stride is equal to kernel size
19821975
// padding is zero
@@ -2048,7 +2041,7 @@ extern "C" {
20482041
int d0, // dilation dimension 0
20492042
int d1); // dilation dimension 1
20502043

2051-
GGML_API struct ggml_tensor * ggml_conv_3d_direct(
2044+
GGML_API struct ggml_tensor * ggml_conv_3d(
20522045
struct ggml_context * ctx,
20532046
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
20542047
struct ggml_tensor * b, // input [W, H, D, C * N]
@@ -2155,19 +2148,6 @@ extern "C" {
21552148
int p2,
21562149
int p3);
21572150

2158-
GGML_API struct ggml_tensor * ggml_pad_ext(
2159-
struct ggml_context * ctx,
2160-
struct ggml_tensor * a,
2161-
int lp0,
2162-
int rp0,
2163-
int lp1,
2164-
int rp1,
2165-
int lp2,
2166-
int rp2,
2167-
int lp3,
2168-
int rp3
2169-
);
2170-
21712151
// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
21722152
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
21732153
struct ggml_context * ctx,

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
150150
return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
151151
}
152152

153+
// Scalar helper for the FLOOR unary op: forwards x to sycl::floor and returns
// the result as the same type T that was passed in.
template<typename T>
154+
static __dpct_inline__ T op_floor(T x) {
155+
return sycl::floor(x);
156+
}
157+
158+
// Scalar helper for the CEIL unary op: forwards x to sycl::ceil and returns
// the result as the same type T that was passed in.
template<typename T>
159+
static __dpct_inline__ T op_ceil(T x) {
160+
return sycl::ceil(x);
161+
}
162+
163+
// Scalar helper for the ROUND unary op: forwards x to sycl::round and returns
// the result as the same type T that was passed in.
// NOTE(review): tie-breaking for halfway values is whatever sycl::round
// defines - confirm it matches the reference backend's rounding.
template<typename T>
164+
static __dpct_inline__ T op_round(T x) {
165+
return sycl::round(x);
166+
}
167+
168+
// Scalar helper for the TRUNC unary op: forwards x to sycl::trunc and returns
// the result as the same type T that was passed in.
template<typename T>
169+
static __dpct_inline__ T op_trunc(T x) {
170+
return sycl::trunc(x);
171+
}
172+
153173
template<typename T>
154174
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
155175
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl:
304324
}
305325
}
306326

327+
// Kernel body for FLOOR: for every index i produced by
// SYCL_GLOBAL_ID_LOOP(k, item_ct1), stores op_floor(x[i]) into dst[i].
//   x        - input buffer (read only)
//   dst      - output buffer, written at the same index that is read from x
//   k        - bound handed to the loop macro (presumably the element count;
//              the macro is defined outside this view - confirm)
//   item_ct1 - the 1-D nd_item of the calling work-item
template<typename T>
328+
static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
329+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
330+
dst[i] = op_floor(x[i]);
331+
}
332+
}
333+
334+
// Kernel body for CEIL: for every index i produced by
// SYCL_GLOBAL_ID_LOOP(k, item_ct1), stores op_ceil(x[i]) into dst[i].
//   x        - input buffer (read only)
//   dst      - output buffer, written at the same index that is read from x
//   k        - bound handed to the loop macro (presumably the element count;
//              the macro is defined outside this view - confirm)
//   item_ct1 - the 1-D nd_item of the calling work-item
template<typename T>
335+
static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
336+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
337+
dst[i] = op_ceil(x[i]);
338+
}
339+
}
340+
341+
// Kernel body for ROUND: for every index i produced by
// SYCL_GLOBAL_ID_LOOP(k, item_ct1), stores op_round(x[i]) into dst[i].
//   x        - input buffer (read only)
//   dst      - output buffer, written at the same index that is read from x
//   k        - bound handed to the loop macro (presumably the element count;
//              the macro is defined outside this view - confirm)
//   item_ct1 - the 1-D nd_item of the calling work-item
template<typename T>
342+
static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
343+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
344+
dst[i] = op_round(x[i]);
345+
}
346+
}
347+
348+
// Kernel body for TRUNC: for every index i produced by
// SYCL_GLOBAL_ID_LOOP(k, item_ct1), stores op_trunc(x[i]) into dst[i].
//   x        - input buffer (read only)
//   dst      - output buffer, written at the same index that is read from x
//   k        - bound handed to the loop macro (presumably the element count;
//              the macro is defined outside this view - confirm)
//   item_ct1 - the 1-D nd_item of the calling work-item
template<typename T>
349+
static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
350+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
351+
dst[i] = op_trunc(x[i]);
352+
}
353+
}
354+
307355
template<typename T>
308356
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
309357
const int nb02, const int nb03, const int ne10, const int ne11,
@@ -870,6 +918,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
870918
}, min_val, max_val);
871919
}
872920

921+
// Launches the FLOOR kernel for dst via the shared unary dispatcher.
// ggml_sycl_detail::dispatch_ggml_sycl_op_unary (defined outside this view)
// invokes the lambda with source/destination pointers, an element count and
// the queue; the lambda submits unary_op_floor_kernel on that queue.
static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
922+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
923+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
924+
// Work-groups of 256 work-items; num_blocks comes from ceil_div(k_elements, 256).
const int num_blocks = ceil_div(k_elements, 256);
925+
stream->parallel_for(
926+
// Global range = num_blocks * 256, local range = 256.
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
927+
sycl::range<1>(256)),
928+
[=](sycl::nd_item<1> item_ct1) {
929+
unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
930+
});
931+
});
932+
}
933+
934+
// Launches the CEIL kernel for dst via the shared unary dispatcher.
// ggml_sycl_detail::dispatch_ggml_sycl_op_unary (defined outside this view)
// invokes the lambda with source/destination pointers, an element count and
// the queue; the lambda submits unary_op_ceil_kernel on that queue.
static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
935+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
936+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
937+
// Work-groups of 256 work-items; num_blocks comes from ceil_div(k_elements, 256).
const int num_blocks = ceil_div(k_elements, 256);
938+
stream->parallel_for(
939+
// Global range = num_blocks * 256, local range = 256.
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
940+
sycl::range<1>(256)),
941+
[=](sycl::nd_item<1> item_ct1) {
942+
unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
943+
});
944+
});
945+
}
946+
947+
// Launches the ROUND kernel for dst via the shared unary dispatcher.
// ggml_sycl_detail::dispatch_ggml_sycl_op_unary (defined outside this view)
// invokes the lambda with source/destination pointers, an element count and
// the queue; the lambda submits unary_op_round_kernel on that queue.
static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
948+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
949+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
950+
// Work-groups of 256 work-items; num_blocks comes from ceil_div(k_elements, 256).
const int num_blocks = ceil_div(k_elements, 256);
951+
stream->parallel_for(
952+
// Global range = num_blocks * 256, local range = 256.
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
953+
sycl::range<1>(256)),
954+
[=](sycl::nd_item<1> item_ct1) {
955+
unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
956+
});
957+
});
958+
}
959+
960+
// Launches the TRUNC kernel for dst via the shared unary dispatcher.
// ggml_sycl_detail::dispatch_ggml_sycl_op_unary (defined outside this view)
// invokes the lambda with source/destination pointers, an element count and
// the queue; the lambda submits unary_op_trunc_kernel on that queue.
static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
961+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
962+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
963+
// Work-groups of 256 work-items; num_blocks comes from ceil_div(k_elements, 256).
const int num_blocks = ceil_div(k_elements, 256);
964+
stream->parallel_for(
965+
// Global range = num_blocks * 256, local range = 256.
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
966+
sycl::range<1>(256)),
967+
[=](sycl::nd_item<1> item_ct1) {
968+
unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
969+
});
970+
});
971+
}
972+
873973
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
874974
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
875975
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
@@ -944,7 +1044,6 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
9441044
});
9451045
}
9461046

947-
9481047
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
9491048
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
9501049
ggml_sycl_op_sqrt(ctx, dst);
@@ -1090,3 +1189,23 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
10901189
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
10911190
ggml_sycl_op_geglu_quick(ctx, dst);
10921191
}
1192+
1193+
// Externally visible entry point for FLOOR: constructs a scope_op_debug_print
// for this call (function name, dst, one source), then forwards to
// ggml_sycl_op_floor.
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1194+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1195+
ggml_sycl_op_floor(ctx, dst);
1196+
}
1197+
1198+
// Externally visible entry point for CEIL: constructs a scope_op_debug_print
// for this call (function name, dst, one source), then forwards to
// ggml_sycl_op_ceil.
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1199+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1200+
ggml_sycl_op_ceil(ctx, dst);
1201+
}
1202+
1203+
// Externally visible entry point for ROUND: constructs a scope_op_debug_print
// for this call (function name, dst, one source), then forwards to
// ggml_sycl_op_round.
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1204+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1205+
ggml_sycl_op_round(ctx, dst);
1206+
}
1207+
1208+
// Externally visible entry point for TRUNC: constructs a scope_op_debug_print
// for this call (function name, dst, one source), then forwards to
// ggml_sycl_op_trunc.
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1209+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1210+
ggml_sycl_op_trunc(ctx, dst);
1211+
}

0 commit comments

Comments
 (0)