@@ -315,66 +315,68 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
 
+        // cudaOccupancyMaxActiveBlocksPerMultiprocessor is not supported on the HIP platform,
+        // so skip the occupancy check for HIP.
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         // Determine the number of active blocks per SM.
         // The parallel_blocks template parameter has no effect on the number of active blocks, so we keep a constant 4 to determine the active blocks.
         int numActiveBlocks = 1;
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+            flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
 
         // We want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
         // This kernel operates on a `D`-sized tile of the sequence length, so we need to consider how many `D` tiles can be processed in parallel.
         // If there are not enough tiles to process, we can reduce the number of blocks.
         const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
 
-        if (parallel_blocks >= 24)
-        {
+        if (parallel_blocks >= 24) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 16)
-        {
+        else if (parallel_blocks >= 16) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 12)
-        {
+        else if (parallel_blocks >= 12) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 8)
-        {
+        else if (parallel_blocks >= 8) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else
+        else
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
     }
     else
     {
         constexpr bool use_logit_softcap = true;
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         int numActiveBlocks = 1;
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+            flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
 
         const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
 
-        if (parallel_blocks >= 24)
-        {
+        if (parallel_blocks >= 24) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 16)
-        {
+        else if (parallel_blocks >= 16) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 12)
-        {
+        else if (parallel_blocks >= 12) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else if (parallel_blocks >= 8)
-        {
+        else if (parallel_blocks >= 8) {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
-        else
+        else
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
         {
             ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
     }
+
     return;
 }
 
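The heuristic this diff introduces reduces to: query the occupancy of the `parallel_blocks == 4` instantiation, scale by the SM count, clamp to the number of sequence-length tiles, then snap down to the nearest instantiated value. Below is a minimal standalone sketch of that selection logic, not the committed code: `dummy_kernel` and `choose_parallel_blocks` are hypothetical stand-ins, the candidate list {24, 16, 12, 8, 4} is hard-coded rather than expressed as template instantiations, and only the CUDA runtime calls are real API.

```cpp
// Sketch of the occupancy-driven parallel_blocks selection. All names are
// illustrative; the real code dispatches to the matching
// ggml_cuda_flash_attn_ext_vec_f32_case_impl instantiation instead.
#include <algorithm>
#include <cstdio>
#include <initializer_list>
#include <cuda_runtime.h>

__global__ void dummy_kernel() {}  // stand-in for flash_attn_vec_ext_f32<...>

// nsm: number of SMs, total_blocks: blocks of work per pass,
// seqlen_tiles: how many D-sized tiles the sequence length splits into,
// block_dim: threads per block (the real code passes the head size D here).
static int choose_parallel_blocks(int nsm, int total_blocks, int seqlen_tiles, int block_dim) {
    // Ask the runtime how many blocks of this kernel can be resident per SM.
    // (The real code compiles this query out on HIP/MUSA, where it is unsupported.)
    int numActiveBlocks = 1;
    if (cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, dummy_kernel, block_dim, 0) != cudaSuccess) {
        return 4;
    }

    // Keep at least numActiveBlocks blocks per SM, but never request more
    // parallel blocks than there are sequence-length tiles to process.
    const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);

    // Snap down to the largest value that has a template instantiation.
    for (int candidate : {24, 16, 12, 8}) {
        if (parallel_blocks >= candidate) {
            return candidate;
        }
    }
    return 4;  // the unconditional fallback (and the only path on HIP/MUSA)
}

int main() {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) {
        return 1;
    }
    // Illustrative numbers: 8 blocks of work, 32 tiles, 128 threads per block.
    printf("parallel_blocks = %d\n", choose_parallel_blocks(prop.multiProcessorCount, 8, 32, 128));
    return 0;
}
```

Because the occupancy query and the whole `if`/`else if` cascade sit inside the `#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)` guard, HIP and MUSA builds fall straight through to the final `{ ... }` body, i.e. the unconditional `parallel_blocks == 4` path; that is why the `#endif` is placed just after the trailing `else` and before its block.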