Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 36 additions & 24 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,23 @@ void main() {
#if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
A_TYPE32 aa = A_TYPE32(data_a[idx]);
buf_a[buf_idx ] = FLOAT_TYPE(aa[0].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(aa[0].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(aa[0].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(aa[0].w);
buf_a[buf_idx + 4] = FLOAT_TYPE(aa[1].x);
buf_a[buf_idx + 5] = FLOAT_TYPE(aa[1].y);
buf_a[buf_idx + 6] = FLOAT_TYPE(aa[1].z);
buf_a[buf_idx + 7] = FLOAT_TYPE(aa[1].w);
#elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
A_TYPE32 aa = A_TYPE32(data_a[idx]);
buf_a[buf_idx ] = FLOAT_TYPE(aa.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(aa.y);
buf_a[buf_idx + 2] = FLOAT_TYPE(aa.z);
buf_a[buf_idx + 3] = FLOAT_TYPE(aa.w);
#else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
Expand Down Expand Up @@ -808,14 +810,19 @@ void main() {
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
#if defined(DATA_B_BF16)
B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
#else
B_TYPE32 bb = B_TYPE32(data_b[idx]);
#endif
buf_b[buf_idx + 0] = FLOAT_TYPE(bb[0].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(bb[0].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(bb[0].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(bb[0].w);
buf_b[buf_idx + 4] = FLOAT_TYPE(bb[1].x);
buf_b[buf_idx + 5] = FLOAT_TYPE(bb[1].y);
buf_b[buf_idx + 6] = FLOAT_TYPE(bb[1].z);
buf_b[buf_idx + 7] = FLOAT_TYPE(bb[1].w);
#elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[loadc_b + l];
Expand All @@ -824,10 +831,15 @@ void main() {
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
#if defined(DATA_B_BF16)
B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
#else
B_TYPE32 bb = B_TYPE32(data_b[idx]);
#endif
buf_b[buf_idx + 0] = FLOAT_TYPE(bb.x);
buf_b[buf_idx + 1] = FLOAT_TYPE(bb.y);
buf_b[buf_idx + 2] = FLOAT_TYPE(bb.z);
buf_b[buf_idx + 3] = FLOAT_TYPE(bb.w);
#elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
Expand Down
11 changes: 11 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/types.comp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@

#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float
#define A_TYPE32 float
#elif LOAD_VEC_A == 4
#define A_TYPE vec4
#define A_TYPE32 vec4
#elif LOAD_VEC_A == 8
#define A_TYPE mat2x4
#define A_TYPE32 mat2x4
#endif
#endif

Expand All @@ -26,10 +29,13 @@

#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float16_t
#define A_TYPE32 float
#elif LOAD_VEC_A == 4
#define A_TYPE f16vec4
#define A_TYPE32 vec4
#elif LOAD_VEC_A == 8
#define A_TYPE f16mat2x4
#define A_TYPE32 mat2x4
#endif
#endif

Expand Down Expand Up @@ -1424,6 +1430,11 @@ float bf16_to_fp32(uint32_t u)
return uintBitsToFloat(u << 16);
}

vec4 bf16_to_fp32(uvec4 u)
{
return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w));
}

float e8m0_to_fp32(uint8_t x) {
uint32_t bits;

Expand Down
20 changes: 10 additions & 10 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
};

// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);

string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);

// bf16
{
Expand All @@ -384,8 +384,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
if (!(coopmat || coopmat2))
#endif
{
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE32", "vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}

Expand All @@ -408,13 +408,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c

// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}

if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
Expand Down
Loading