diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 05e16cd432ad3..6a4d52020317e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -62,6 +62,7 @@ #define WEBGPU_MUL_MAT_WG_SIZE 256 #define WEBGPU_NUM_PARAM_BUFS 32u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u +#define WEBGPU_WAIT_ANY_BATCH_SIZE 64 #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE @@ -251,16 +252,18 @@ struct webgpu_context_struct { webgpu_pipeline set_rows_pipeline; webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; - webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type - webgpu_pipeline add_pipeline[2][2]; // type, inplace - webgpu_pipeline sub_pipeline[2][2]; // type, inplace - webgpu_pipeline mul_pipeline[2][2]; // type, inplace - webgpu_pipeline div_pipeline[2][2]; // type, inplace - webgpu_pipeline rms_norm_pipeline[2]; // inplace - webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace - webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split - webgpu_pipeline scale_pipeline[2]; // inplace + webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type + webgpu_pipeline add_pipeline[2][2]; // type, inplace + webgpu_pipeline sub_pipeline[2][2]; // type, inplace + webgpu_pipeline mul_pipeline[2][2]; // type, inplace + webgpu_pipeline div_pipeline[2][2]; // type, inplace + webgpu_pipeline rms_norm_pipeline[2]; // inplace + webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace + webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split + webgpu_pipeline scale_pipeline[2]; // inplace webgpu_pipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace + webgpu_pipeline unary_pipeline[GGML_UNARY_OP_COUNT][2][2]; + size_t memset_bytes_per_thread; @@ -341,6 +344,8 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline_desc.compute.constants = constants.data(); pipeline_desc.compute.constantCount = constants.size(); } + + pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } @@ -716,8 +721,10 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; return ggml_backend_webgpu_build(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x); + } + static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, @@ -858,6 +865,55 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); } +static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * dst, + webgpu_pipeline & pipeline, + bool in_place, + bool additional_params=false) { + + + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Logical shapes + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] + }; + + if (additional_params) { + for (uint i = 1; i < 5; i++) { + params.push_back((uint32_t)(ggml_get_op_params_f32(dst, i))); // alpha_n, alpha_p, beta, eps + } + } + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + + }; + if (!in_place) { + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -1202,6 +1258,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * src1 = node->src[1]; ggml_tensor * src2 = node->src[2]; + + switch (node->op) { // no-ops case GGML_OP_NONE: @@ -1249,12 +1307,22 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_scale(ctx, src0, node); case GGML_OP_SOFT_MAX: return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); + case GGML_OP_UNARY: + { + const ggml_unary_op UNARY_OP = ggml_get_unary_op(node); + + int in_place = ggml_webgpu_tensor_equal(src0, node); + bool XIELU = (UNARY_OP == GGML_UNARY_OP_XIELU); + return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place, XIELU); + } + default: return std::nullopt; } } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); @@ -1473,6 +1541,8 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")"); ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); @@ -1484,6 +1554,8 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf); + + return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size); } @@ -1809,6 +1881,170 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } +static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); + + // ABS + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0], + wgsl_abs_f32, "abs_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0], + wgsl_abs_f16, "abs_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1], + wgsl_abs_in_place_f32, "abs_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1], + wgsl_abs_in_place_f16, "abs_in_place_f16", constants); + + // SGN + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0], + wgsl_sgn_f32, "sgn_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0], + wgsl_sgn_f16, "sgn_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1], + wgsl_sgn_in_place_f32, "sgn_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1], + wgsl_sgn_in_place_f16, "sgn_in_place_f16", constants); + + // NEG + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0], + wgsl_neg_f32, "neg_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0], + wgsl_neg_f16, "neg_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1], + wgsl_neg_in_place_f32, "neg_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1], + wgsl_neg_in_place_f16, "neg_in_place_f16", constants); + + // STEP + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0], + wgsl_step_f32, "step_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0], + wgsl_step_f16, "step_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1], + wgsl_step_in_place_f32, "step_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1], + wgsl_step_in_place_f16, "step_in_place_f16", constants); + + // TANH + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0], + wgsl_tanh_f32, "tanh_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0], + wgsl_tanh_f16, "tanh_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1], + wgsl_tanh_in_place_f32, "tanh_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1], + wgsl_tanh_in_place_f16, "tanh_in_place_f16", constants); + + // ELU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0], + wgsl_elu_f32, "elu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0], + wgsl_elu_f16, "elu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1], + wgsl_elu_in_place_f32, "elu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1], + wgsl_elu_in_place_f16, "elu_in_place_f16", constants); + + // RELU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0], + wgsl_relu_f32, "relu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0], + wgsl_relu_f16, "relu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1], + wgsl_relu_in_place_f32, "relu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1], + wgsl_relu_in_place_f16, "relu_in_place_f16", constants); + + // SIGMOID + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0], + wgsl_sigmoid_f32, "sigmoid_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0], + wgsl_sigmoid_f16, "sigmoid_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1], + wgsl_sigmoid_in_place_f32, "sigmoid_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1], + wgsl_sigmoid_in_place_f16, "sigmoid_in_place_f16", constants); + + // GELU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0], + wgsl_gelu_f32, "gelu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0], + wgsl_gelu_f16, "gelu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1], + wgsl_gelu_in_place_f32, "gelu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1], + wgsl_gelu_in_place_f16, "gelu_in_place_f16", constants); + + // GELU_QUICK + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0], + wgsl_gelu_quick_f32, "gelu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0], + wgsl_gelu_quick_f16, "gelu_quick_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1], + wgsl_gelu_quick_in_place_f32, "gelu_quick_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1], + wgsl_gelu_quick_in_place_f16, "gelu_quick_in_place_f16", constants); + + // SILU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0], + wgsl_silu_f32, "silu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0], + wgsl_silu_f16, "silu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1], + wgsl_silu_in_place_f32, "silu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1], + wgsl_silu_in_place_f16, "silu_in_place_f16", constants); + + // HARDSWISH + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0], + wgsl_hardswish_f32, "hardswish_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0], + wgsl_hardswish_f16, "hardswish_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1], + wgsl_hardswish_in_place_f32, "hardswish_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1], + wgsl_hardswish_in_place_f16, "hardswish_in_place_f16", constants); + + // HARDSIGMOID + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0], + wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0], + wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1], + wgsl_hardsigmoid_in_place_f32, "hardsigmoid_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1], + wgsl_hardsigmoid_in_place_f16, "hardsigmoid_in_place_f16", constants); + + // EXP + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0], + wgsl_exp_f32, "exp_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0], + wgsl_exp_f16, "exp_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1], + wgsl_exp_in_place_f32, "exp_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1], + wgsl_exp_in_place_f16, "exp_in_place_f16", constants); + + // GELU_ERF + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0], + wgsl_gelu_erf_f32, "gelu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0], + wgsl_gelu_erf_f16, "gelu_erf_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1], + wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1], + wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants); + + // XIELU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0], + wgsl_xielu_f32, "xielu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0], + wgsl_xielu_f16, "xielu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1], + wgsl_xielu_in_place_f32, "xielu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1], + wgsl_xielu_in_place_f16, "xielu_in_place_f16", constants); +} + static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", @@ -1848,6 +2084,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); @@ -1866,12 +2103,13 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co /* .device = */ dev, /* .context = */ &backend_ctx, }; - + //tried return &backend; } static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) { // See GGML Backend Buffer Type Interface section + static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, @@ -1922,6 +2160,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { } static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); webgpu_context webgpu_ctx = ctx->webgpu_ctx; @@ -2034,6 +2273,36 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SOFT_MAX: supports_op = op->type == GGML_TYPE_F32; break; + case GGML_OP_UNARY: + { + const ggml_unary_op UNARY_OP = ggml_get_unary_op(op); + + switch (UNARY_OP) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_XIELU: + supports_op = supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_UNARY_OP_COUNT: + default: + break; + } + } + break; + default: break; } @@ -2179,6 +2448,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + ggml_webgpu_init_memset_pipeline(ctx); ggml_webgpu_init_mul_mat_pipeline(ctx); ggml_webgpu_init_set_rows_pipeline(ctx); @@ -2193,6 +2463,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_glu_pipeline(ctx); ggml_webgpu_init_scale_pipeline(ctx); ggml_webgpu_init_soft_max_pipeline(ctx); + ggml_webgpu_init_unary_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers @@ -2234,6 +2505,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { /* End GGML Backend Registration Interface */ ggml_backend_reg_t ggml_backend_webgpu_reg() { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); webgpu_context webgpu_ctx = std::make_shared(); @@ -2259,6 +2531,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } ggml_backend_t ggml_backend_webgpu_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); return ggml_backend_webgpu_device_init(dev, nullptr); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl new file mode 100644 index 0000000000000..f9fe9cb3c4bec --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -0,0 +1,601 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "abs_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + + }, + { + "SHADER_NAME": "abs_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "abs_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "abs_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "sgn_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "sgn_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "sgn_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "sgn_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "neg_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "neg_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "neg_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "neg_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "step_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "step_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "step_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "step_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "tanh_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "tanh_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "tanh_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "tanh_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "elu_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "elu_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "elu_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "elu_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "relu_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "relu_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "relu_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "relu_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "sigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "sigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "sigmoid_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "sigmoid_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_quick_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_quick_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_quick_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_quick_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "silu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "silu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "silu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "silu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "hardswish_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "hardswish_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "hardswish_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "hardswish_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "hardsigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "hardsigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "hardsigmoid_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "hardsigmoid_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "exp_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "exp_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "exp_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "exp_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_erf_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_erf_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" + }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_erf_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "gelu_erf_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "xielu_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["NOT_INPLACE_EXT_PARAMS"] + }, + { + "SHADER_NAME": "xielu_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE_EXT_PARAMS"] + }, + { + "SHADER_NAME": "xielu_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE_EXT_PARAMS"] + }, + { + "SHADER_NAME": "xielu_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["INPLACE_EXT_PARAMS"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE_DFLT_PARAMS) + +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +#enddecl(NOT_INPLACE_DFLT_PARAMS) + +#decl(INPLACE_DFLT_PARAMS) + +@group(0) @binding(1) +var params: Params; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +#enddecl(INPLACE_DFLT_PARAMS) + +#decl(NOT_INPLACE_EXT_PARAMS) + +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32, + + // XIELU params + alpha_n: u32, + alpha_p: u32, + beta: u32, + eps: u32 +}; + +#enddecl(NOT_INPLACE_EXT_PARAMS) + +#decl(INPLACE_EXT_PARAMS) + +@group(0) @binding(1) +var params: Params; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32, + + // XIELU params + alpha_n: u32, + alpha_p: u32, + beta: u32, + eps: u32 +}; + +#enddecl(INPLACE_EXT_PARAMS) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +fn update(dst_i: u32, src_i: u32) { + {{FUNC}} +} + +@group(0) @binding(0) +var src: array<{{TYPE}}>; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + + update(params.offset_dst + dst_idx, params.offset_src + src_idx); +} + +#end(SHADER)