@@ -942,6 +942,34 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> split_image_qkv(struct ggml_c
     return {q, k, v};
 }
 
+__STATIC_INLINE__ struct ggml_tensor* ggml_full(struct ggml_context* ctx,
+                                                float value,
+                                                int64_t ne0,
+                                                int64_t ne1,
+                                                int64_t ne2,
+                                                int64_t ne3) {
+    auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one");
+    auto t   = ggml_scale(ctx, one, value);                 // [1,]
+    t        = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3);  // [ne0, ne1, ne2, ne3]
+    return t;
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_zeros(struct ggml_context* ctx,
+                                                 int64_t ne0,
+                                                 int64_t ne1,
+                                                 int64_t ne2,
+                                                 int64_t ne3) {
+    return ggml_full(ctx, 0.f, ne0, ne1, ne2, ne3);
+}
+
+__STATIC_INLINE__ struct ggml_tensor* ggml_ones(struct ggml_context* ctx,
+                                                int64_t ne0,
+                                                int64_t ne1,
+                                                int64_t ne2,
+                                                int64_t ne3) {
+    return ggml_full(ctx, 1.f, ne0, ne1, ne2, ne3);
+}
+
 // q: [N * n_head, n_token, d_head]
 // k: [N * n_head, n_k, d_head]
 // v: [N * n_head, d_head, n_k]
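A minimal usage sketch for the new constant-tensor helpers above. It assumes the calling code runs inside a GGMLRunner graph-build callback, so that the "ggml_runner_build_in_tensor:one" tensor has already been registered in ctx by prepare_build_in_tensor_before() (added further down in this patch); the shapes are purely illustrative:

    // Hypothetical shapes; ctx must already contain the named "one" tensor.
    struct ggml_tensor* zero_mask = ggml_zeros(ctx, 77, 64, 1, 1);       // [64, 77] filled with 0.f
    struct ggml_tensor* half      = ggml_full(ctx, 0.5f, 77, 64, 1, 1);  // [64, 77] filled with 0.5f
    struct ggml_tensor* all_ones  = ggml_ones(ctx, 320, 1, 1, 1);        // [320] filled with 1.f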
@@ -969,6 +997,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
 // q: [N, L_q, C] or [N*n_head, L_q, d_head]
 // k: [N, L_k, C] or [N*n_head, L_k, d_head]
 // v: [N, L_k, C] or [N, L_k, n_head, d_head]
+// mask: [N, L_q, L_k]
 // return: [N, L_q, C]
 __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
                                                             struct ggml_tensor* q,
@@ -1019,7 +1048,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
-        can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
         can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
     }
 
@@ -1046,14 +1074,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
         if (mask != nullptr) {
             mask = ggml_transpose(ctx, mask);
-
-            if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) {
-                LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]);
-                LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD));
-                mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0);
+        } else {
+            if (kv_pad > 0) {
+                mask            = ggml_zeros(ctx, L_k, L_q, 1, 1);               // [L_q, L_k]
+                auto pad_tensor = ggml_full(ctx, -INFINITY, kv_pad, L_q, 1, 1);  // [L_q, kv_pad]
+                mask            = ggml_concat(ctx, mask, pad_tensor, 0);         // [L_q, L_k + kv_pad]
             }
+        }
 
+        // mask pad
+        if (mask != nullptr) {
+            int mask_pad = 0;
+            if (mask->ne[1] % GGML_KQ_MASK_PAD != 0) {
+                mask_pad = GGML_PAD(L_q, GGML_KQ_MASK_PAD) - mask->ne[1];
+            }
+            if (mask_pad > 0) {
+                mask = ggml_pad(ctx, mask, 0, mask_pad, 0, 0);  // [L_q + mask_pad, L_k + kv_pad]
+            }
             mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
+            // LOG_DEBUG("L_k: %ld, L_q: %ld, mask->ne[1]: %ld, mask_pad: %d, kv_pad: %d", L_k, L_q, mask->ne[1], mask_pad, kv_pad);
         }
 
         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
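When no mask was supplied but the key/value sequence had to be padded for flash attention, the new else-branch synthesizes an additive mask: real key positions get 0.f and the kv_pad positions get -INFINITY, so the padded keys contribute nothing after the softmax. A small worked layout, assuming L_q = 2, L_k = 3, kv_pad = 1 (rows are query positions; shown before the GGML_KQ_MASK_PAD row padding and the cast to F16):

    //          key0   key1   key2   pad0
    // q0    [   0.0,   0.0,   0.0,  -inf ]
    // q1    [   0.0,   0.0,   0.0,  -inf ]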
@@ -1271,6 +1310,9 @@ struct GGMLRunner {
     struct ggml_context* compute_ctx    = NULL;
     struct ggml_gallocr* compute_allocr = NULL;
 
+    std::vector<float> one_vec = {1.f};
+    ggml_tensor* one_tensor    = NULL;
+
     std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
 
     void alloc_params_ctx() {
@@ -1315,12 +1357,29 @@ struct GGMLRunner {
         }
     }
 
+    void prepare_build_in_tensor_before() {
+        one_tensor = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, 1);
+        ggml_set_name(one_tensor, "ggml_runner_build_in_tensor:one");
+        set_backend_tensor_data(one_tensor, one_vec.data());
+    }
+
+    void prepare_build_in_tensor_after(struct ggml_cgraph* gf) {
+        ggml_build_forward_expand(gf, one_tensor);
+    }
+
+    struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) {
+        prepare_build_in_tensor_before();
+        struct ggml_cgraph* gf = get_graph();
+        prepare_build_in_tensor_after(gf);
+        return gf;
+    }
+
     bool alloc_compute_buffer(get_graph_cb_t get_graph) {
         if (compute_allocr != NULL) {
             return true;
         }
         reset_compute_ctx();
-        struct ggml_cgraph* gf = get_graph();
+        struct ggml_cgraph* gf = get_compute_graph(get_graph);
         backend_tensor_data_map.clear();
         compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend));
 
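Rough call order for the built-in "one" tensor that backs ggml_full(), as wired up by this patch (names are taken from the diff; nothing here is new API):

    // 1. prepare_build_in_tensor_before(): creates a named 1-element F32 tensor in
    //    compute_ctx and maps it to the host value one_vec = {1.f} via
    //    set_backend_tensor_data(), so it is copied to the backend with the other inputs.
    // 2. get_graph(): user code may call ggml_full()/ggml_zeros()/ggml_ones(), which look
    //    the tensor up with ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one").
    // 3. prepare_build_in_tensor_after(gf): ggml_build_forward_expand() adds the tensor to
    //    the graph so the allocator reserves it even if no node ended up using it.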
@@ -1531,7 +1590,7 @@ struct GGMLRunner {
         }
         alloc_compute_buffer(get_graph);
         reset_compute_ctx();
-        struct ggml_cgraph* gf = get_graph();
+        struct ggml_cgraph* gf = get_compute_graph(get_graph);
         GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
         cpy_data_to_backend_tensor();
         if (ggml_backend_is_cpu(runtime_backend)) {