
Commit aa5566f

fix ggml_nn_attention_ext mask
1 parent: 48d4c1c

File tree: 3 files changed, +69 -10 lines

ggml_extend.hpp

Lines changed: 67 additions & 8 deletions
@@ -942,6 +942,34 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_image_qkv(struct ggml_c
     return {q, k, v};
 }
 
+__STATIC_INLINE__ struct ggml_tensor* ggml_full(struct ggml_context* ctx,
+                                                float value,
+                                                int64_t ne0,
+                                                int64_t ne1,
+                                                int64_t ne2,
+                                                int64_t ne3) {
+    auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
+    auto t   = ggml_scale(ctx, one, value);                 // [1,]
+    t        = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3);  // [ne0, ne1, ne2, ne3]
+    return t;
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_zeros(struct ggml_context* ctx,
+                                                 int64_t ne0,
+                                                 int64_t ne1,
+                                                 int64_t ne2,
+                                                 int64_t ne3) {
+    return ggml_full(ctx, 0.f, ne0, ne1, ne2, ne3);
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_ones(struct ggml_context* ctx,
+                                                int64_t ne0,
+                                                int64_t ne1,
+                                                int64_t ne2,
+                                                int64_t ne3) {
+    return ggml_full(ctx, 1.f, ne0, ne1, ne2, ne3);
+}
+
 // q: [N * n_head, n_token, d_head]
 // k: [N * n_head, n_k, d_head]
 // v: [N * n_head, d_head, n_k]
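Note on how the new helpers compose: ggml_full materializes a constant purely from graph ops by scaling the runner's registered one-element tensor and broadcasting it with ggml_repeat_4d. A minimal sketch (not part of the commit) of building the kind of additive attention bias these helpers are used for further down in ggml_nn_attention_ext; it assumes a context in which the "ggml_runner_build_in_tensor:one" tensor has already been registered, L_q/L_k/kv_pad are hypothetical sizes, and INFINITY comes from <cmath>. Shape comments follow ggml's ne order (ne0 innermost).

#include <cmath>  // INFINITY

// Sketch only: 0 over the real keys, -inf over the padded keys.
static struct ggml_tensor* example_attn_bias(struct ggml_context* ctx,
                                             int64_t L_q, int64_t L_k, int64_t kv_pad) {
    struct ggml_tensor* visible = ggml_zeros(ctx, L_k, L_q, 1, 1);              // [L_q, L_k], all 0
    struct ggml_tensor* blocked = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1); // [L_q, kv_pad], all -inf
    return ggml_concat(ctx, visible, blocked, 0);                               // [L_q, L_k + kv_pad]
}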
@@ -969,6 +997,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
 // q: [N, L_q, C] or [N*n_head, L_q, d_head]
 // k: [N, L_k, C] or [N*n_head, L_k, d_head]
 // v: [N, L_k, C] or [N, L_k, n_head, d_head]
+// mask: [N, L_q, L_k]
 // return: [N, L_q, C]
 __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
                                                             struct ggml_tensor* q,
@@ -1019,7 +1048,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
-        can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
        can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
    }
 
@@ -1046,14 +1074,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
         if (mask != nullptr) {
             mask = ggml_transpose(ctx, mask);
-
-            if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) {
-                LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]);
-                LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD));
-                mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0);
+        } else {
+            if (kv_pad > 0) {
+                mask            = ggml_zeros(ctx, L_k, L_q, 1, 1);              // [L_q, L_k]
+                auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1); // [L_q, kv_pad]
+                mask            = ggml_concat(ctx, mask, pad_tensor, 0);        // [L_q, L_k + kv_pad]
             }
+        }
 
+        // mask pad
+        if (mask != nullptr) {
+            int mask_pad = 0;
+            if (mask->ne[1] % GGML_KQ_MASK_PAD != 0) {
+                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask->ne[1];
+            }
+            if (mask_pad > 0) {
+                mask = ggml_pad(ctx, mask, 0, mask_pad, 0, 0); // [L_q + mask_pad, L_k + kv_pad]
+            }
             mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
+            // LOG_DEBUG("L_k: %ld, L_q: %ld, mask->ne[1]: %ld, mask_pad: %d, kv_pad: %d", L_k, L_q, mask->ne[1], mask_pad, kv_pad);
         }
 
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
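For reference, the padding arithmetic above is just rounding up: GGML_PAD(x, n) rounds x up to the next multiple of n, so mask_pad is the distance from L_q to that multiple, while kv_pad extends the key dimension. A standalone sketch of the same computation; the granularity value 32 and the sequence lengths are assumptions for illustration only (the real code uses GGML_KQ_MASK_PAD and the runtime shapes).

#include <cstdint>
#include <cstdio>

// Round x up to a multiple of n, mirroring ggml's GGML_PAD macro.
static int64_t pad_to(int64_t x, int64_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int64_t granularity = 32;    // stand-in for GGML_KQ_MASK_PAD
    const int64_t L_q = 77, L_k = 77;  // hypothetical sequence lengths
    const int64_t kv_pad = 3;          // hypothetical key/value padding

    // Rows (ne[1]) of the constructed mask start at L_q and are padded up to a
    // multiple of the granularity before being handed to ggml_flash_attn_ext.
    int64_t mask_pad = (L_q % granularity != 0) ? pad_to(L_q, granularity) - L_q : 0;

    printf("mask_pad = %lld, final mask ne = [%lld, %lld]\n",
           (long long)mask_pad,
           (long long)(L_k + kv_pad),    // ne[0]: keys, including kv padding
           (long long)(L_q + mask_pad)); // ne[1]: queries, including mask padding
    return 0;
}
// Prints: mask_pad = 19, final mask ne = [80, 96]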
@@ -1271,6 +1310,9 @@ struct GGMLRunner {
     struct ggml_context* compute_ctx    = NULL;
     struct ggml_gallocr* compute_allocr = NULL;
 
+    std::vector<float> one_vec = {1.f};
+    ggml_tensor* one_tensor    = NULL;
+
     std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
 
     void alloc_params_ctx() {
@@ -1315,12 +1357,29 @@ struct GGMLRunner {
         }
     }
 
+    void prepare_build_in_tensor_before() {
+        one_tensor = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, 1);
+        ggml_set_name(one_tensor, "ggml_runner_build_in_tensor:one");
+        set_backend_tensor_data(one_tensor, one_vec.data());
+    }
+
+    void prepare_build_in_tensor_after(struct ggml_cgraph* gf) {
+        ggml_build_forward_expand(gf, one_tensor);
+    }
+
+    struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
+        prepare_build_in_tensor_before();
+        struct ggml_cgraph* gf = get_graph();
+        prepare_build_in_tensor_after(gf);
+        return gf;
+    }
+
     bool alloc_compute_buffer(get_graph_cb_t get_graph) {
         if (compute_allocr != NULL) {
             return true;
         }
         reset_compute_ctx();
-        struct ggml_cgraph* gf = get_graph();
+        struct ggml_cgraph* gf = get_compute_graph(get_graph);
         backend_tensor_data_map.clear();
         compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend));
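get_compute_graph() fixes the order around every graph build: the one-element tensor is created and named before the user callback runs (so ggml_get_tensor can find it inside ggml_full), and it is expanded into the graph afterwards so the allocator reserves a buffer for it. A hedged caller-side sketch of what a graph callback can now rely on; the lambda, my_runner, and the direct member access are hypothetical and assume get_graph_cb_t is a callable returning a ggml_cgraph*.

// Sketch only: a graph callback that uses the runner-registered "one" tensor
// through ggml_full; my_runner stands in for a GGMLRunner subclass instance.
auto build_graph = [&]() -> struct ggml_cgraph* {
    struct ggml_cgraph* gf = ggml_new_graph(my_runner.compute_ctx);
    // Valid because prepare_build_in_tensor_before() has already created and
    // named the one-element tensor in compute_ctx before this callback runs.
    struct ggml_tensor* bias = ggml_full(my_runner.compute_ctx, -INFINITY, 16, 16, 1, 1);
    ggml_build_forward_expand(gf, bias);
    return gf;
};
// Handed to the runner, e.g. my_runner.alloc_compute_buffer(build_graph),
// which internally wraps it via get_compute_graph(build_graph).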

@@ -1531,7 +1590,7 @@ struct GGMLRunner {
         }
         alloc_compute_buffer(get_graph);
         reset_compute_ctx();
-        struct ggml_cgraph* gf = get_graph();
+        struct ggml_cgraph* gf = get_compute_graph(get_graph);
         GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
         cpy_data_to_backend_tensor();
         if (ggml_backend_is_cpu(runtime_backend)) {

stable-diffusion.h

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ enum sd_type_t {
     // SD_TYPE_IQ4_NL_4_4 = 36,
     // SD_TYPE_IQ4_NL_4_8 = 37,
     // SD_TYPE_IQ4_NL_8_8 = 38,
-    SD_TYPE_MXFP4 = 39, // MXFP4 (1 block)
+    SD_TYPE_MXFP4 = 39, // MXFP4 (1 block)
     SD_TYPE_COUNT = 40,
 };

vae.hpp

Lines changed: 1 addition & 1 deletion
@@ -529,7 +529,7 @@ struct VAE : public GGMLRunner {
                         struct ggml_tensor** output,
                         struct ggml_context* output_ctx) = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) = 0;
-    virtual void enable_conv2d_direct() {};
+    virtual void enable_conv2d_direct(){};
 };
 
 struct AutoEncoderKL : public VAE {
