@@ -321,11 +321,11 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     // Determine the number of active blocks per SM
     // the parallel_blocks template parameter has no effect on the number of active blocks, so a constant 4 is used to query them
     int numActiveBlocks = 1;
-    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
         flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));

     // we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
-    // this kernel operates on a tile of `D` positions of the sequence length; we need to consider how many `D` tiles can be processed in parallel.
+    // this kernel operates on a tile of `D` positions of the sequence length; we need to consider how many `D` tiles can be processed in parallel.
     // If there are not enough tiles to process, we can reduce the number of blocks
     const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);

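For context, here is a minimal standalone sketch of the occupancy heuristic this hunk implements, assuming a hypothetical stand-in kernel `dummy_kernel` (in place of the real `flash_attn_vec_ext_f32` instantiation) and hypothetical parameters `nsm`, `total_blocks`, and `seqlen_tiles` that mirror the values already in scope in the real function:

```cpp
// Sketch only: dummy_kernel stands in for the real templated kernel;
// block_size plays the role of D in the patch above.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

__global__ void dummy_kernel() {}

int pick_parallel_blocks(int nsm, int total_blocks, int seqlen_tiles,
                         int block_size) {
    // Ask the runtime how many blocks of this kernel can be resident on
    // one SM at the given block size and zero dynamic shared memory.
    int numActiveBlocks = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &numActiveBlocks, dummy_kernel, block_size, 0);

    // Saturate the GPU (nsm * numActiveBlocks resident blocks), but never
    // request more parallel splits than there are sequence tiles.
    return std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
}

int main() {
    int nsm = 0;
    cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, 0);
    printf("parallel_blocks = %d\n",
           pick_parallel_blocks(nsm, /*total_blocks=*/16,
                                /*seqlen_tiles=*/32, /*block_size=*/128));
    return 0;
}
```

Note that the real code then clamps this value through the dispatch ladder in the next hunk, falling back to 4 parallel blocks when the computed value is small.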
@@ -341,7 +341,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     else if (parallel_blocks >= 8) {
         ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
-    else
+    else
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
     {
         ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
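The `else` left dangling ahead of the `#endif` is deliberate: on CUDA it binds to the brace block that follows, while on HIP/MUSA the whole occupancy ladder is compiled out and that same block runs unconditionally. A compilable sketch of the pattern, with hypothetical names `run_impl` and `dispatch` standing in for the real `_case_impl` calls (the branches above `>= 8` are elided in this hunk, so only that one is shown):

```cpp
// Hypothetical stand-in for ggml_cuda_flash_attn_ext_vec_f32_case_impl:
// the runtime value is funneled into a compile-time template parameter.
template <int parallel_blocks>
void run_impl() { /* launch the kernel specialization */ }

void dispatch(int parallel_blocks) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    if (parallel_blocks >= 8) {  // the real ladder has more branches above this
        run_impl<8>();
    }
    else  // intentionally dangling: binds to the block after the #endif
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
    {
        run_impl<4>();  // on HIP/MUSA this is the only path that compiles
    }
}
```

This keeps a single fallback body for both build configurations instead of duplicating the `run_impl<4>` call inside and outside the preprocessor guard.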
@@ -353,7 +353,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml

 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
     int numActiveBlocks = 1;
-    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks,
         flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));

     const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
@@ -370,7 +370,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     else if (parallel_blocks >= 8) {
         ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
-    else
+    else
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
     {
         ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);