Skip to content

Commit d7ded21

Browse files
committed
SYCL: Add support for FLOOR,CEIL,ROUND and TRUNC unary operators
Clean up unrelated changes from previous commit
1 parent 7a50cf3 commit d7ded21

File tree

7 files changed

+327
-6
lines changed

7 files changed

+327
-6
lines changed

docs/ops.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Legend:
2222
| ARANGE ||||||||||
2323
| ARGMAX ||||||||||
2424
| ARGSORT ||||||||||
25-
| CEIL ||||||| |||
25+
| CEIL ||||||| |||
2626
| CLAMP ||||| 🟡 | 🟡 || 🟡 ||
2727
| CONCAT |||| 🟡 || 🟡 | 🟡 |||
2828
| CONT || 🟡 |||| 🟡 | 🟡 | 🟡 ||
@@ -42,7 +42,7 @@ Legend:
4242
| ELU |||| 🟡 | 🟡 || 🟡 |||
4343
| EXP |||| 🟡 | 🟡 || 🟡 |||
4444
| FLASH_ATTN_EXT || 🟡 || 🟡 | 🟡 ||| 🟡 ||
45-
| FLOOR ||||||| |||
45+
| FLOOR ||||||| |||
4646
| GATED_LINEAR_ATTN ||||||||||
4747
| GEGLU ||||| 🟡 ||| 🟡 ||
4848
| GEGLU_ERF ||||| 🟡 ||| 🟡 ||
@@ -84,7 +84,7 @@ Legend:
8484
| ROLL ||||||||||
8585
| ROPE || 🟡 ||||||||
8686
| ROPE_BACK ||||||||||
87-
| ROUND ||||||| |||
87+
| ROUND ||||||| |||
8888
| RWKV_WKV6 ||||||||||
8989
| RWKV_WKV7 ||||||||||
9090
| SCALE || 🟡 ||||||||
@@ -111,6 +111,6 @@ Legend:
111111
| TANH |||| 🟡 | 🟡 || 🟡 | 🟡 ||
112112
| TIMESTEP_EMBEDDING ||||||||||
113113
| TOPK_MOE ||||||||||
114-
| TRUNC ||||||| |||
114+
| TRUNC ||||||| |||
115115
| UPSCALE || 🟡 ||| 🟡 || 🟡 |||
116116
| XIELU ||||||||||

docs/ops/SYCL.csv

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@
3131
"SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
3232
"SYCL0","XIELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
3333
"SYCL0","XIELU","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
34+
"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
35+
"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
36+
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
37+
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
38+
"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
39+
"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
40+
"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
41+
"SYCL0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
3442
"SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
3543
"SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
3644
"SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
@@ -95,6 +103,14 @@
95103
"SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
96104
"SYCL0","XIELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
97105
"SYCL0","XIELU","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
106+
"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
107+
"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
108+
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
109+
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
110+
"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
111+
"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
112+
"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
113+
"SYCL0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
98114
"SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
99115
"SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
100116
"SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"

ggml/include/ggml.h

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ extern "C" {
581581
GGML_UNARY_OP_CEIL,
582582
GGML_UNARY_OP_ROUND,
583583
GGML_UNARY_OP_TRUNC,
584-
584+
585585
GGML_UNARY_OP_COUNT,
586586
};
587587

@@ -1103,7 +1103,40 @@ extern "C" {
11031103
GGML_API struct ggml_tensor * ggml_gelu_inplace(
11041104
struct ggml_context * ctx,
11051105
struct ggml_tensor * a);
1106+
1107+
GGML_API struct ggml_tensor * ggml_floor(
1108+
struct ggml_context * ctx,
1109+
struct ggml_tensor * a);
1110+
1111+
GGML_API struct ggml_tensor * ggml_floor_inplace(
1112+
struct ggml_context * ctx,
1113+
struct ggml_tensor * a);
1114+
1115+
GGML_API struct ggml_tensor * ggml_ceil(
1116+
struct ggml_context * ctx,
1117+
struct ggml_tensor * a);
1118+
1119+
GGML_API struct ggml_tensor * ggml_ceil_inplace(
1120+
struct ggml_context * ctx,
1121+
struct ggml_tensor * a);
11061122

1123+
GGML_API struct ggml_tensor * ggml_round(
1124+
struct ggml_context * ctx,
1125+
struct ggml_tensor * a);
1126+
1127+
GGML_API struct ggml_tensor * ggml_round_inplace(
1128+
struct ggml_context * ctx,
1129+
struct ggml_tensor * a);
1130+
1131+
GGML_API struct ggml_tensor * ggml_trunc(
1132+
struct ggml_context * ctx,
1133+
struct ggml_tensor * a);
1134+
1135+
GGML_API struct ggml_tensor * ggml_trunc_inplace(
1136+
struct ggml_context * ctx,
1137+
struct ggml_tensor * a);
1138+
1139+
11071140
// GELU using erf (error function) when possible
11081141
// some backends may fallback to approximation based on Abramowitz and Stegun formula
11091142
GGML_API struct ggml_tensor * ggml_gelu_erf(
@@ -2154,7 +2187,7 @@ extern "C" {
21542187
int p1,
21552188
int p2,
21562189
int p3);
2157-
2190+
21582191
GGML_API struct ggml_tensor * ggml_pad_ext(
21592192
struct ggml_context * ctx,
21602193
struct ggml_tensor * a,

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
150150
return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
151151
}
152152

153+
template<typename T>
154+
static __dpct_inline__ T op_floor(T x) {
155+
return sycl::floor(x);
156+
}
157+
158+
template<typename T>
159+
static __dpct_inline__ T op_ceil(T x) {
160+
return sycl::ceil(x);
161+
}
162+
163+
template<typename T>
164+
static __dpct_inline__ T op_round(T x) {
165+
return sycl::round(x);
166+
}
167+
168+
template<typename T>
169+
static __dpct_inline__ T op_trunc(T x) {
170+
return sycl::trunc(x);
171+
}
172+
153173
template<typename T>
154174
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
155175
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
@@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl:
304324
}
305325
}
306326

327+
template<typename T>
328+
static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
329+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
330+
dst[i] = op_floor(x[i]);
331+
}
332+
}
333+
334+
template<typename T>
335+
static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
336+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
337+
dst[i] = op_ceil(x[i]);
338+
}
339+
}
340+
341+
template<typename T>
342+
static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
343+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
344+
dst[i] = op_round(x[i]);
345+
}
346+
}
347+
348+
template<typename T>
349+
static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
350+
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
351+
dst[i] = op_trunc(x[i]);
352+
}
353+
}
354+
307355
template<typename T>
308356
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
309357
const int nb02, const int nb03, const int ne10, const int ne11,
@@ -870,6 +918,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
870918
}, min_val, max_val);
871919
}
872920

921+
static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
922+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
923+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
924+
const int num_blocks = ceil_div(k_elements, 256);
925+
stream->parallel_for(
926+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
927+
sycl::range<1>(256)),
928+
[=](sycl::nd_item<1> item_ct1) {
929+
unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
930+
});
931+
});
932+
}
933+
934+
static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
935+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
936+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
937+
const int num_blocks = ceil_div(k_elements, 256);
938+
stream->parallel_for(
939+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
940+
sycl::range<1>(256)),
941+
[=](sycl::nd_item<1> item_ct1) {
942+
unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
943+
});
944+
});
945+
}
946+
947+
static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
948+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
949+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
950+
const int num_blocks = ceil_div(k_elements, 256);
951+
stream->parallel_for(
952+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
953+
sycl::range<1>(256)),
954+
[=](sycl::nd_item<1> item_ct1) {
955+
unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
956+
});
957+
});
958+
}
959+
960+
static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
961+
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
962+
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
963+
const int num_blocks = ceil_div(k_elements, 256);
964+
stream->parallel_for(
965+
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
966+
sycl::range<1>(256)),
967+
[=](sycl::nd_item<1> item_ct1) {
968+
unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
969+
});
970+
});
971+
}
972+
873973
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
874974
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
875975
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
@@ -1090,3 +1190,23 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
10901190
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
10911191
ggml_sycl_op_geglu_quick(ctx, dst);
10921192
}
1193+
1194+
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1195+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1196+
ggml_sycl_op_floor(ctx, dst);
1197+
}
1198+
1199+
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1200+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1201+
ggml_sycl_op_ceil(ctx, dst);
1202+
}
1203+
1204+
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1205+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1206+
ggml_sycl_op_round(ctx, dst);
1207+
}
1208+
1209+
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1210+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1211+
ggml_sycl_op_trunc(ctx, dst);
1212+
}

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,9 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8080
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8181
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8282
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
83+
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
84+
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
85+
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
86+
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8387

8488
#endif // GGML_SYCL_ELEMENTWISE_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3694,6 +3694,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
36943694
case GGML_UNARY_OP_ELU:
36953695
ggml_sycl_elu(ctx, dst);
36963696
break;
3697+
case GGML_UNARY_OP_FLOOR:
3698+
ggml_sycl_floor(ctx, dst);
3699+
break;
3700+
case GGML_UNARY_OP_CEIL:
3701+
ggml_sycl_ceil(ctx, dst);
3702+
break;
3703+
case GGML_UNARY_OP_ROUND:
3704+
ggml_sycl_round(ctx, dst);
3705+
break;
3706+
case GGML_UNARY_OP_TRUNC:
3707+
ggml_sycl_trunc(ctx, dst);
3708+
break;
36973709
default:
36983710
return false;
36993711
}
@@ -4255,6 +4267,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
42554267
case GGML_UNARY_OP_SGN:
42564268
case GGML_UNARY_OP_ABS:
42574269
case GGML_UNARY_OP_ELU:
4270+
case GGML_UNARY_OP_FLOOR:
4271+
case GGML_UNARY_OP_CEIL:
4272+
case GGML_UNARY_OP_ROUND:
4273+
case GGML_UNARY_OP_TRUNC:
42584274
#if defined (GGML_SYCL_F16)
42594275
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
42604276
#else

0 commit comments

Comments
 (0)