@@ -308,16 +308,13 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
308308
309309 if (Q->ne [1 ] == 1 ) {
310310 constexpr int cols_per_block = 1 ;
311- const int total_blocks = (((Q->ne [1 ] + cols_per_block - 1 ) / cols_per_block)*Q->ne [2 ]*Q->ne [3 ]);
311+ const int num_blocks_base = (((Q->ne [1 ] + cols_per_block - 1 ) / cols_per_block)*Q->ne [2 ]*Q->ne [3 ]);
312312 const int nsm = ggml_cuda_info ().devices [ggml_cuda_get_device ()].nsm ;
313313 const int seqlen_tiles = (K->ne [1 ] + D - 1 ) / D;
314314
315315 if (logit_softcap == 0 .0f ) {
316316 constexpr bool use_logit_softcap = false ;
317317
318- // cudaOccupancyMaxActiveBlocksPerMultiprocessor is not supported on HIP platform
319- // so, skipping the occupancy check for HIP platform
320- #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
321318 // Determine the number of active blocks per SM
322319 // The parallel_blocks template parameter has no effect on the number of active blocks, so a constant 4 is used when determining active blocks
323320 int numActiveBlocks = 1 ;
@@ -327,7 +324,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
327324 // we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
328325 // This kernel processes the sequence length in tiles of `D`, so we need to consider how many `D` tiles can be processed in parallel.
329326 // If there are not enough tiles to process, we can reduce the number of blocks.
330- const int parallel_blocks = std::min ((nsm * numActiveBlocks) / total_blocks , seqlen_tiles);
327+ const int parallel_blocks = std::min ((nsm * numActiveBlocks) / num_blocks_base , seqlen_tiles);
331328
332329 if (parallel_blocks >= 24 ) {
333330 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24 , type_K, type_V, use_logit_softcap>(ctx, dst);
@@ -341,22 +338,19 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
341338 else if (parallel_blocks >= 8 ) {
342339 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8 , type_K, type_V, use_logit_softcap>(ctx, dst);
343340 }
344- else
345- #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
346- {
341+ else {
347342 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4 , type_K, type_V, use_logit_softcap>(ctx, dst);
348343 }
349344 }
350345 else
351346 {
352347 constexpr bool use_logit_softcap = true ;
353348
354- #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
355349 int numActiveBlocks = 1 ;
356350 CUDA_CHECK (cudaOccupancyMaxActiveBlocksPerMultiprocessor (&numActiveBlocks,
357351 flash_attn_vec_ext_f32<D, cols_per_block, 4 , type_K, type_V, use_logit_softcap>, D, 0 ));
358352
359- const int parallel_blocks = std::min ((nsm * numActiveBlocks) / total_blocks , seqlen_tiles);
353+ const int parallel_blocks = std::min ((nsm * numActiveBlocks) / num_blocks_base , seqlen_tiles);
360354
361355 if (parallel_blocks >= 24 ) {
362356 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24 , type_K, type_V, use_logit_softcap>(ctx, dst);
@@ -370,9 +364,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
370364 else if (parallel_blocks >= 8 ) {
371365 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8 , type_K, type_V, use_logit_softcap>(ctx, dst);
372366 }
373- else
374- #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
375- {
367+ else {
376368 ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4 , type_K, type_V, use_logit_softcap>(ctx, dst);
377369 }
378370 }
0 commit comments