@@ -840,18 +840,34 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
 
     float scale = (1.0f / sqrt((float)d_head));
 
-    // if (flash_attn) {
-    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    // }
+    int kv_pad = 0;
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }
     // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
 
     bool can_use_flash_attn = true;
+    can_use_flash_attn = can_use_flash_attn && (
+        d_head == 64 ||
+        d_head == 80 ||
+        d_head == 96 ||
+        d_head == 112 ||
+        d_head == 128 ||
+        d_head == 256
+    );
+#if 0
     can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
-    can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0;  // double check
-
-    // cuda max d_head seems to be 256, cpu does seem to work with 512
-    can_use_flash_attn = can_use_flash_attn && d_head <= 256;  // double check
+#else
+    if (can_use_flash_attn && L_k % 256 != 0) {
+        // TODO(Green-Sky): might be worth just padding by default
+        if (L_k == 77 || L_k == 4208 || L_k == 3952) {
+            kv_pad = GGML_PAD(L_k, 256) - L_k;
+        } else {
+            can_use_flash_attn = false;
+        }
+    }
+#endif
 
     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
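Net effect of the first hunk: instead of requiring d_head % 64 == 0 and d_head <= 256, flash attention is now gated on an explicit whitelist of head sizes, and a K/V length that is not a multiple of 256 no longer disables it for the known lengths 77, 4208 and 3952; those get rounded up to the next multiple of 256, with kv_pad holding the difference. A standalone sketch of that padding arithmetic (not part of the patch; the pad_to helper mirrors what GGML_PAD yields for these power-of-two multiples):

```cpp
// Standalone sketch: the kv_pad amounts produced for the whitelisted K/V lengths.
// pad_to() rounds up to the next multiple of n, matching GGML_PAD for these values.
#include <cstdio>

static int pad_to(int x, int n) { return (x + n - 1) / n * n; }

int main() {
    const int lens[] = {77, 4208, 3952};  // the L_k values special-cased above
    for (int L_k : lens) {
        int kv_pad = pad_to(L_k, 256) - L_k;
        printf("L_k=%4d -> padded L_k=%4d, kv_pad=%3d\n", L_k, pad_to(L_k, 256), kv_pad);
    }
    // L_k=  77 -> padded L_k= 256, kv_pad=179
    // L_k=4208 -> padded L_k=4352, kv_pad=144
    // L_k=3952 -> padded L_k=4096, kv_pad=144
    return 0;
}
```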
@@ -864,11 +880,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        // LOG_DEBUG("using flash attention");
+        // LOG_DEBUG(" uses flash attention");
+        if (kv_pad != 0) {
+            // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
+            k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+        }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);
 
         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+        if (kv_pad != 0) {
+            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+        }
         v = ggml_cast(ctx, v, GGML_TYPE_F16);
 
         if (mask != nullptr) {
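The second hunk applies the padding right before the F16 casts: K is padded along dim 1, and V is padded along dim 1 after the permute/reshape (where dim 1 is again the L_k axis), so both carry the rounded-up length into the flash-attention path. Below is a self-contained, shape-only sketch of the same ggml_pad call; the concrete d_head/L_k/n_head values and the F32 source type are illustrative assumptions, not taken from the patch:

```cpp
// Shape-only sketch (assumed dimensions): zero-pad dim 1 (the L_k axis) of a
// K-like tensor up to a multiple of 256 with ggml_pad, as the patch does
// before casting K/V to F16.
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context* ctx = ggml_init(params);

    const int d_head = 64, L_k = 77, n_head = 8;  // illustrative, e.g. a CLIP-sized context
    struct ggml_tensor* k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head, L_k, n_head);

    const int kv_pad = GGML_PAD(L_k, 256) - L_k;                       // 179 when L_k == 77
    struct ggml_tensor* k_padded = ggml_pad(ctx, k, 0, kv_pad, 0, 0);  // pads dim 1 with zeros

    printf("k:        %lld x %lld x %lld\n",
           (long long)k->ne[0], (long long)k->ne[1], (long long)k->ne[2]);
    printf("k_padded: %lld x %lld x %lld\n",
           (long long)k_padded->ne[0], (long long)k_padded->ne[1], (long long)k_padded->ne[2]);

    ggml_free(ctx);
    return 0;
}
```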