
Commit b80d864

vulkan: handle mat_mul with A matrix > 4GB
This change splits mat_mul operations with a huge A matrix into chunks along the M dimension. This works well for stable-diffusion use cases where the im2col matrix has a very large M. Also fix the order of setting the stride in mul_mm_cm2: setting the dimension clobbers the stride, so the stride should be set afterwards.
1 parent 3f81b4e commit b80d864
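
The core idea is a row-blocked matmul: pick a block of rows of A small enough to stay under the device's storage-buffer limit, multiply it, advance, and repeat. Below is a minimal CPU-side sketch of that idea (illustrative only, not the backend code; mul_mat_rows, mul_mat_chunked, and the dense float layout are hypothetical stand-ins, while the actual change operates on ggml tensor views):

#include <algorithm>
#include <cstddef>

// Reference matmul over rows [row0, row0 + rows) of A: C[i][j] = sum_k A[i][k] * B[k][j].
static void mul_mat_rows(const float * A, const float * B, float * C,
                         std::size_t row0, std::size_t rows, std::size_t K, std::size_t N) {
    for (std::size_t i = row0; i < row0 + rows; ++i) {
        for (std::size_t j = 0; j < N; ++j) {
            float acc = 0.0f;
            for (std::size_t k = 0; k < K; ++k) {
                acc += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = acc;
        }
    }
}

// Split C = A * B (A is M x K) into row blocks of A that each fit under a buffer-size limit.
static void mul_mat_chunked(const float * A, const float * B, float * C,
                            std::size_t M, std::size_t K, std::size_t N,
                            std::size_t max_buffer_range) {   // e.g. maxStorageBufferRange
    const std::size_t row_bytes = K * sizeof(float);
    // Rows per chunk, halved to leave headroom for per-chunk offsets, mirroring the commit's M_split.
    const std::size_t m_split = std::max<std::size_t>(1, max_buffer_range / (2 * row_bytes));
    for (std::size_t m = 0; m < M; m += m_split) {
        const std::size_t cur = std::min(m_split, M - m);
        mul_mat_rows(A, B, C, m, cur, K, N);
    }
}

Each block writes its own rows of C, so the chunks are independent and can be dispatched back to back, which is what the Vulkan path in the diff does with adjusted tensor views and split_k disabled.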

File tree

3 files changed, +54 -9 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 44 additions & 8 deletions
@@ -5653,8 +5653,12 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     ggml_vk_queue_command_pools_cleanup(dst->device);
 }
 
-static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, const vk_pipeline& pipeline) {
-    VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
+    VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");
+
+    if (disable_split_k) {
+        return 1;
+    }
 
     uint32_t split_k = 1;
     if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
@@ -5979,7 +5983,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_sync_buffers(ctx, subctx);
 }
 
-static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
@@ -5999,6 +6003,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 
     const uint64_t ne20 = dst->ne[0];
     const uint64_t ne21 = dst->ne[1];
+    const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
+    const uint32_t stride_batch_d = stride_d*ne21;
 
     const uint64_t r2 = ne12 / ne02;
     const uint64_t r3 = ne13 / ne03;
@@ -6067,7 +6073,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const int y_ne = padded_n * ne10;
     const int d_ne = ne11 * ne01;
 
-    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
+    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -6226,13 +6232,16 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
     }
 
+    // No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange.
+    VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
+
     // compute
     ggml_vk_matmul(
         ctx, subctx, pipeline,
         { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
-        { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
+        { d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
         ne01, ne11, ne10,
-        ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
+        ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
         split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
     ); // NOLINT
 

@@ -6712,7 +6721,34 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
 
 static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
-    if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
+
+    // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
+    // where the M dimension is very large.
+    // Split_k doesn't work with M splitting.
+    const size_t nbytes = ggml_nbytes(src0);
+    const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
+    if (needs_split) {
+        // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
+        const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
+        uint32_t m_offset = 0;
+        while (m_offset < dst->ne[0]) {
+            const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
+            ggml_tensor dst2 = *dst;
+            ggml_tensor src02 = *src0;
+
+            dst2.view_src = dst->view_src ? dst->view_src : (ggml_tensor *)dst;
+            src02.view_src = src0->view_src ? src0->view_src : (ggml_tensor *)src0;
+
+            dst2.view_offs += m_offset * dst->nb[0];
+            src02.view_offs += m_offset * src0->nb[1];
+            dst2.ne[0] = cur_M_size;
+            src02.ne[1] = cur_M_size;
+
+            ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true, dryrun);
+
+            m_offset += cur_M_size;
+        }
+    } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
         // detect 0213 permutation, and batch size of 1
         src0->nb[0] <= src0->nb[2] &&
         src0->nb[2] <= src0->nb[1] &&
@@ -6732,7 +6768,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
         (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     } else {
-        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
+        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false, dryrun);
     }
 }
 
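For a rough sense of the resulting chunk size (assuming a device that reports maxStorageBufferRange = 0xFFFFFFFF, i.e. just under 4 GiB, and taking K = 2592 from the disabled test cases below): one F16 row of A is nb[1] = 2592 * 2 = 5184 bytes, so M_split = 0xFFFFFFFF / (2 * 5184) is roughly 414,000 rows, about 2 GiB of A per chunk. An im2col matrix with M = 1,700,000 rows is then processed in five calls to ggml_vk_mul_mat_q_f16, each with split_k forced off.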

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 2 additions & 1 deletion
@@ -265,7 +265,6 @@ void main() {
     tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
-    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
 
 #if QUANT_K > 1
     tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
@@ -281,6 +280,8 @@ void main() {
     tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
     tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
 
+    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
+
     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
 
 #if !defined(MUL_MAT_ID)

tests/test-backend-ops.cpp

Lines changed: 8 additions & 0 deletions
@@ -6200,6 +6200,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
 
+#if 0
+    // > 4GB A matrix. Too slow to be enabled by default.
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 900000, 3, 2592, {1, 1}, {1, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 96, 2592, {1, 1}, {1, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 3, 2592, {1, 1}, {1, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 1, 2592, {1, 1}, {1, 1}));
+#endif
+
     for (ggml_type type_a : all_types) {
         for (int i = 1; i < 10; ++i) {
             test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
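
A quick size check (plain arithmetic, assuming the roughly 4 GiB maxStorageBufferRange typical of desktop GPUs): the 1,700,000 x 2592 F16 A matrices above occupy 1,700,000 * 2592 * 2 bytes, about 8.8 GB, and the 900,000 x 2592 case about 4.7 GB, so every disabled case exceeds the limit and exercises the new M-splitting path.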
