Skip to content

Commit 268ae6c

Browse files
committed
metal : support mul_mm_id with ne00 % 32 != 0
1 parent 06bc3dd commit 268ae6c

File tree

5 files changed

+191
-145
lines changed

5 files changed

+191
-145
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -670,19 +670,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
670670
return res;
671671
}
672672

673-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
673+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
674674
char base[256];
675675
char name[256];
676676

677+
const ggml_type tsrc0 = op->src[0]->type;
678+
const ggml_type tsrc1 = op->src[1]->type;
679+
680+
const bool bc = op->src[0]->ne[0] % 32 != 0;
681+
677682
snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
678-
snprintf(name, 256, "%s", base);
683+
snprintf(name, 256, "%s_bc=%d", base, bc);
679684

680685
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
681686
if (res) {
682687
return res;
683688
}
684689

685-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
690+
ggml_metal_cv_t cv = ggml_metal_cv_init();
691+
692+
ggml_metal_cv_set_bool(cv, bc, FC_MUL_MM + 0);
693+
694+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
695+
696+
ggml_metal_cv_free(cv);
686697

687698
ggml_metal_pipeline_set_smem(res, 8192);
688699

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -118,7 +118,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext (ggml_me
118118
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
119119
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
120120
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
121-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1);
121+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
122122
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
123123
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
124124
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 8 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1630,19 +1630,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
16301630
// ne21 = n_rows (batch size)
16311631
const int ne21_mm_id_min = 32;
16321632

1633-
if (props_dev->has_simdgroup_mm &&
1634-
ne00 % 32 == 0 && ne00 >= 64 &&
1635-
(ne21 >= ne21_mm_id_min)) {
1636-
GGML_ASSERT(ne00 % 4 == 0);
1637-
1633+
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
16381634
// some Metal matrix data types require aligned pointers
16391635
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1640-
switch (op->src[0]->type) {
1641-
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1642-
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1643-
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1644-
default: break;
1645-
}
1636+
//switch (op->src[0]->type) {
1637+
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1638+
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1639+
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1640+
// default: break;
1641+
//}
16461642

16471643
// extra buffers for intermediate id mapping
16481644
ggml_metal_buffer_id bid_tpe = bid_dst;
@@ -1686,7 +1682,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
16861682
ggml_metal_op_concurrency_reset(ctx);
16871683

16881684
{
1689-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, op->src[1]->type);
1685+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
16901686

16911687
ggml_metal_kargs_mul_mm_id args = {
16921688
/*.ne00 =*/ ne00,

0 commit comments

Comments
 (0)