Commit ed710b3

Implement overlap binary operators
1 parent da5296e commit ed710b3

4 files changed: +436 -159 lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 115 additions & 64 deletions
@@ -252,10 +252,10 @@ struct webgpu_context_struct {
     webgpu_pipeline get_rows_pipeline[30];
     webgpu_pipeline get_rows_f32_no_vec_pipeline;
     webgpu_pipeline cpy_pipeline[2][2];  // src type, dst type
-    webgpu_pipeline add_pipeline[2][2];  // type, inplace
-    webgpu_pipeline sub_pipeline[2][2];  // type, inplace
-    webgpu_pipeline mul_pipeline[2][2];  // type, inplace
-    webgpu_pipeline div_pipeline[2][2];  // type, inplace
+    webgpu_pipeline add_pipeline[2][2][2];  // type, inplace, overlap
+    webgpu_pipeline sub_pipeline[2][2][2];  // type, inplace, overlap
+    webgpu_pipeline mul_pipeline[2][2][2];  // type, inplace, overlap
+    webgpu_pipeline div_pipeline[2][2][2];  // type, inplace, overlap
     webgpu_pipeline rms_norm_pipeline[2];  // inplace
     webgpu_pipeline rope_pipeline[2][2][2];  // type, ff, inplace
     webgpu_pipeline glu_pipeline[7][2][2];  // glu-op, type, split
@@ -677,9 +677,12 @@ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor
     return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
 }
 
+static size_t ggml_webgpu_tensor_align_binding_size(size_t size) {
+    return (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
+}
+
 static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
-    return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
-           ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
+    return ggml_webgpu_tensor_align_binding_size(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t));
 }
 
 // Used to determine if two tensors are the same for in-place operations
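The new ggml_webgpu_tensor_align_binding_size helper is the usual round-up-to-a-multiple bit trick, factored out so the overlap path later in this commit can reuse it. A minimal standalone sketch of the same rounding, assuming a 16-byte multiple purely for illustration (WEBGPU_STORAGE_BUF_BINDING_MULT is defined elsewhere in the backend and must be a power of two for this to work):

#include <cassert>
#include <cstddef>

// Hypothetical stand-in for WEBGPU_STORAGE_BUF_BINDING_MULT; must be a power of two.
constexpr size_t kBindingMult = 16;

// Round `size` up to the next multiple of kBindingMult.
static size_t align_binding_size(size_t size) {
    return (size + kBindingMult - 1) & ~(kBindingMult - 1);
}

int main() {
    assert(align_binding_size(0)  == 0);
    assert(align_binding_size(1)  == 16);
    assert(align_binding_size(16) == 16);
    assert(align_binding_size(17) == 32);
    return 0;
}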
@@ -688,6 +691,12 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
            (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
 }
 
+static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
+    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
+           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
+}
+
 static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
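ggml_webgpu_tensor_overlap is the standard half-open interval intersection test applied to two tensors in the same wgpu buffer: the byte ranges [off_a, off_a + nbytes_a) and [off_b, off_b + nbytes_b) intersect exactly when each start lies before the other range's end. A self-contained sketch of the same logic on plain offsets, independent of ggml types:

#include <cassert>
#include <cstdint>

// Half-open ranges [off_a, off_a + size_a) and [off_b, off_b + size_b) intersect
// iff each range starts before the other one ends.
static bool ranges_overlap(uint64_t off_a, uint64_t size_a, uint64_t off_b, uint64_t size_b) {
    return off_a < off_b + size_b && off_b < off_a + size_a;
}

int main() {
    assert(ranges_overlap(0, 64, 32, 64));    // partial overlap
    assert(ranges_overlap(0, 64, 0, 64));     // identical ranges
    assert(!ranges_overlap(0, 64, 64, 64));   // adjacent, no overlap
    assert(!ranges_overlap(128, 32, 0, 64));  // disjoint
    return 0;
}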

@@ -870,16 +879,27 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
-                                            ggml_tensor * src0,
-                                            ggml_tensor * src1,
-                                            ggml_tensor * dst,
-                                            webgpu_pipeline & pipeline,
-                                            bool inplace) {
+template <size_t a, size_t b, size_t c>
+static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
+                                            ggml_tensor * src0,
+                                            ggml_tensor * src1,
+                                            ggml_tensor * dst,
+                                            webgpu_pipeline (&pipelines)[a][b][c]) {
+    int inplace = ggml_webgpu_tensor_equal(src0, dst);
+    int overlap = ggml_webgpu_tensor_overlap(src0, src1);
+    webgpu_pipeline pipeline = pipelines[dst->type][inplace][overlap];
+
+    uint32_t src1_offset = ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type);
+    if (overlap) {
+        // when overlapped, bind a single buffer covering both src0 and src1
+        // TODO: Do other operations need this?
+        src1_offset = (uint32_t) ((ggml_webgpu_tensor_offset(src1) - ggml_webgpu_tensor_align_offset(ctx, src0)) /
+                                  ggml_type_size(src1->type));
+    }
     std::vector<uint32_t> params = {
         (uint32_t) ggml_nelements(dst),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        src1_offset,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
         (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
         (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
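The rewritten helper takes the whole pipeline table by reference to an array, so the extents [a][b][c] are deduced at the call site and the helper itself picks the [type][inplace][overlap] variant; the call sites in ggml_webgpu_encode_node shrink accordingly in a later hunk. A minimal sketch of that deduction pattern, with a hypothetical fake_pipeline type standing in for webgpu_pipeline:

#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for webgpu_pipeline, just to show the dispatch shape.
struct fake_pipeline { const char * name; };

// Array extents are deduced from the reference parameter, so callers pass the
// whole pipelines[type][inplace][overlap] table without spelling out its sizes.
template <size_t a, size_t b, size_t c>
static const fake_pipeline & pick(const fake_pipeline (&pipelines)[a][b][c],
                                  size_t type, bool inplace, bool overlap) {
    return pipelines[type][inplace][overlap];
}

int main() {
    fake_pipeline table[2][2][2] = {};
    table[0][1][1].name = "add_f32_inplace_overlap";
    std::printf("%s\n", pick(table, 0, true, true).name);  // prints the selected variant
    return 0;
}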
@@ -894,25 +914,36 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
         (uint32_t) src1->ne[3],
     };
 
+    size_t src0_binding_size = ggml_webgpu_tensor_binding_size(ctx, src0);
+    if (overlap) {
+        const uint64_t base_align = ggml_webgpu_tensor_align_offset(ctx, src0);
+        // assume end of src1 is >= end of src0
+        const uint64_t max_end = ggml_webgpu_tensor_offset(src1) + ggml_nbytes(src1);
+        src0_binding_size = ggml_webgpu_tensor_align_binding_size(max_end - base_align);
+    }
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
           .buffer = ggml_webgpu_tensor_buf(src0),
           .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
-          .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
-        { .binding = 1,
-          .buffer = ggml_webgpu_tensor_buf(src1),
-          .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
-          .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
+          .size = src0_binding_size }
     };
+    uint32_t binding_num = 1;
+    if (!overlap) {
+        entries.push_back({ .binding = binding_num,
+                            .buffer = ggml_webgpu_tensor_buf(src1),
+                            .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+                            .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
+        binding_num++;
+    }
     if (!inplace) {
-        entries.push_back({ .binding = 2,
+        entries.push_back({ .binding = binding_num,
                             .buffer = ggml_webgpu_tensor_buf(dst),
                             .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
                             .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    size_t   max_wg_size = ctx->max_wg_size_x;
-    uint32_t wg_x        = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
+    size_t   max_wg_size = ctx->max_wg_size_x;
+    uint32_t wg_x        = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
 }
 
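When the sources overlap, a single storage-buffer binding is created that starts at src0's aligned offset and runs to the end of src1 (which this code assumes ends at or after src0), and src1 is then addressed by an element offset measured from that shared base, as computed earlier. A standalone sketch of the arithmetic, with hypothetical alignment values (256-byte offset alignment, 16-byte binding-size multiple) chosen only for illustration:

#include <cassert>
#include <cstdint>

constexpr uint64_t kOffsetAlign = 256;  // hypothetical minStorageBufferOffsetAlignment
constexpr uint64_t kBindingMult = 16;   // hypothetical binding-size multiple

static uint64_t align_down(uint64_t off)  { return off & ~(kOffsetAlign - 1); }
static uint64_t align_size(uint64_t size) { return (size + kBindingMult - 1) & ~(kBindingMult - 1); }

int main() {
    // src0 and src1 share one buffer; src1 starts inside src0's byte range.
    const uint64_t src0_off = 1000, src0_bytes = 512;
    const uint64_t src1_off = 1200, src1_bytes = 512;
    const uint64_t elem_size = 4;  // f32

    // One binding covers both tensors: from src0's aligned base to the end of src1.
    const uint64_t base      = align_down(src0_off);
    const uint64_t max_end   = src1_off + src1_bytes;
    const uint64_t bind_size = align_size(max_end - base);

    // Inside the shader, src1 is indexed relative to the shared base.
    const uint64_t src1_elem_offset = (src1_off - base) / elem_size;

    assert(base == 768 && bind_size == 944 && src1_elem_offset == 108);
    assert(base + bind_size >= src0_off + src0_bytes);  // the binding also covers src0
    return 0;
}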

@@ -1232,25 +1263,13 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_MUL_MAT:
             return ggml_webgpu_mul_mat(ctx, src0, src1, node);
         case GGML_OP_ADD:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace);
-            }
+            return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline);
         case GGML_OP_SUB:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace);
-            }
+            return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline);
         case GGML_OP_MUL:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace);
-            }
+            return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline);
         case GGML_OP_DIV:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace);
-            }
+            return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline);
         case GGML_OP_RMS_NORM:
             return ggml_webgpu_rms_norm(ctx, src0, node);
         case GGML_OP_ROPE:
@@ -1700,50 +1719,82 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
 
 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace,
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0][0], wgsl_add_f32,
+                                "add_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0][0], wgsl_add_f16,
+                                "add_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1][0], wgsl_add_f32_inplace,
                                 "add_f32_inplace", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace,
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1][0], wgsl_add_f16_inplace,
                                 "add_f16_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0][1], wgsl_add_f32_overlap,
+                                "add_f32_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1][1],
+                                wgsl_add_f32_inplace_overlap, "add_f32_inplace_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0][1], wgsl_add_f16_overlap,
+                                "add_f16_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1][1],
+                                wgsl_add_f16_inplace_overlap, "add_f16_inplace_overlap", constants);
 }
 
 static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace,
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0][0], wgsl_sub_f32,
+                                "sub_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0][0], wgsl_sub_f16,
+                                "sub_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1][0], wgsl_sub_f32_inplace,
                                 "sub_f32_inplace", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace,
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1][0], wgsl_sub_f16_inplace,
                                 "sub_f16_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0][1], wgsl_sub_f32_overlap,
+                                "sub_f32_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1][1],
+                                wgsl_sub_f32_inplace_overlap, "sub_f32_inplace_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0][1], wgsl_sub_f16_overlap,
+                                "sub_f16_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1][1],
+                                wgsl_sub_f16_inplace_overlap, "sub_f16_inplace_overlap", constants);
 }
 
 static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace,
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0][0], wgsl_mul_f32,
+                                "mul_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0][0], wgsl_mul_f16,
+                                "mul_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1][0], wgsl_mul_f32_inplace,
                                 "mul_f32_inplace", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace,
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1][0], wgsl_mul_f16_inplace,
                                 "mul_f16_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0][1], wgsl_mul_f32_overlap,
+                                "mul_f32_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1][1],
+                                wgsl_mul_f32_inplace_overlap, "mul_f32_inplace_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0][1], wgsl_mul_f16_overlap,
+                                "mul_f16_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1][1],
+                                wgsl_mul_f16_inplace_overlap, "mul_f16_inplace_overlap", constants);
 }
 
 static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
-                                constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace,
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0][0], wgsl_div_f32,
+                                "div_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0][0], wgsl_div_f16,
+                                "div_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1][0], wgsl_div_f32_inplace,
                                 "div_f32_inplace", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace,
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1][0], wgsl_div_f16_inplace,
                                 "div_f16_inplace", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0][1], wgsl_div_f32_overlap,
+                                "div_f32_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1][1],
+                                wgsl_div_f32_inplace_overlap, "div_f32_inplace_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0][1], wgsl_div_f16_overlap,
+                                "div_f16_overlap", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1][1],
+                                wgsl_div_f16_inplace_overlap, "div_f16_inplace_overlap", constants);
 }
 
 static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
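Each binary operator now has eight shader variants, indexed as pipeline[type][inplace][overlap] and named with optional _inplace and _overlap suffixes. A small sketch, using only names visible in this diff, that prints the index-to-name mapping for the add pipelines:

#include <cstdio>
#include <string>

// Build the shader variant name for the given flags, mirroring the naming convention above.
static std::string variant_name(const char * op, const char * type, bool inplace, bool overlap) {
    std::string name = std::string(op) + "_" + type;
    if (inplace) { name += "_inplace"; }
    if (overlap) { name += "_overlap"; }
    return name;
}

int main() {
    const char * types[] = { "f32", "f16" };
    for (int t = 0; t < 2; t++) {
        for (int inplace = 0; inplace < 2; inplace++) {
            for (int overlap = 0; overlap < 2; overlap++) {
                std::printf("add_pipeline[%s][%d][%d] -> %s\n", types[t], inplace, overlap,
                            variant_name("add", types[t], inplace, overlap).c_str());
            }
        }
    }
    return 0;
}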
@@ -2152,9 +2203,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     // TODO: Don't enable for WASM builds, they won't have an effect anyways
     // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
     // only for native performance?
-    const char * const deviceEnabledToggles[]  = { "skip_validation", "disable_robustness", "disable_workgroup_init",
-                                                   "disable_polyfills_on_integer_div_and_mod" };
-    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
+    const char * const deviceEnabledToggles[]  = { "disable_robustness", "disable_workgroup_init",
+                                                   "disable_polyfills_on_integer_div_and_mod" };
+    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
     wgpu::DawnTogglesDescriptor deviceTogglesDesc;
     deviceTogglesDesc.enabledToggles     = deviceEnabledToggles;
     deviceTogglesDesc.enabledToggleCount = 4;
