@@ -672,9 +672,8 @@ static void on_no_fattn_vec_case(const int D) {
 
 template <int D, int ncols1, int ncols2, int KQ_stride>
 void launch_fattn(
-        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V, const bool stream_k,
-        const int warp_size = WARP_SIZE
+        ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+        const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -697,9 +696,6 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
-    GGML_ASSERT(stream_k || ncols2 == 1);
-    const bool use_parallel_blocks = !stream_k && (Q->ne[1] <= ncols1) ? true : false;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -772,19 +768,38 @@ void launch_fattn(
 
         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
-        if (use_parallel_blocks) {
-            const int num_blocks_base = ntiles_x*Q->ne[2]*Q->ne[3];
-            const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-            const int seqlen_tiles = (K->ne[1] + D - 1) / D;
-
-            // Determine the number of active blocks per SM
-            int numActiveBlocks = 1;
-            CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
-            // we want to keep at least `numActiveBlocks` blocks per SM to improve occupancy.
-            // this kernel operates on `D` tile of seq length. We need to consider how many `D` tiles can be processed in parallel.
-            // If there are not enough tiles to process, we can reduce the number of blocks
-            parallel_blocks = std::max(std::min((nsm * numActiveBlocks) / num_blocks_base, seqlen_tiles), 1);
+        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
+        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
+
+        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
+        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
+        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
+
+        // parallel_blocks must not be larger than what the tensor size allows:
+        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+
+        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+        // Test whether parallel_blocks can be set to a higher value for better efficiency.
+        const int blocks_per_wave = nsm * max_blocks_per_sm;
+        int nwaves_best = 0;
+        int efficiency_percent_best = 0;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_total * parallel_blocks_test;
+            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+            if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
+                break;
+            }
+
+            if (efficiency_percent > efficiency_percent_best) {
+                nwaves_best = nwaves;
+                efficiency_percent_best = efficiency_percent;
+                parallel_blocks = parallel_blocks_test;
+            }
         }
 
         blocks_num.x = ntiles_x;
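
As a side note, the wave-efficiency heuristic added above can be exercised in isolation with a small host-side sketch. The inputs below (nsm, max_blocks_per_sm, ntiles_total, ntiles_KQ) are assumed example values rather than numbers taken from this change; only the selection loop mirrors the logic now in launch_fattn.

```cpp
// Minimal host-side sketch of the parallel_blocks selection above.
// All numeric inputs are assumed example values (hypothetical GPU / tensor sizes).
#include <algorithm>
#include <cstdio>

int main() {
    const int nsm               = 108; // assumed SM count of the device
    const int max_blocks_per_sm =   2; // assumed occupancy limit (as reported by cudaOccupancyMaxActiveBlocksPerMultiprocessor)
    const int ntiles_total      =  64; // assumed number of Q tiles across heads/sequences
    const int ntiles_KQ         =  32; // assumed K->ne[1] / KQ_row_granularity

    const int blocks_per_wave = nsm * max_blocks_per_sm;

    // Start with enough parallel blocks to fill one wave, capped by the tensor size:
    int parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
    parallel_blocks     = std::min(parallel_blocks, ntiles_KQ);

    // Try larger values to reduce wave-quantization ("tail") losses:
    int nwaves_best             = 0;
    int efficiency_percent_best = 0;
    for (int pb = parallel_blocks; pb <= ntiles_KQ; ++pb) {
        const int nblocks_total      = ntiles_total * pb;
        const int nwaves             = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
        const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);

        // Once efficiency is good, don't accept configurations that add more waves:
        if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
            break;
        }
        if (efficiency_percent > efficiency_percent_best) {
            nwaves_best             = nwaves;
            efficiency_percent_best = efficiency_percent;
            parallel_blocks         = pb;
        }
    }

    printf("parallel_blocks = %d, waves = %d, efficiency = %d%%\n",
           parallel_blocks, nwaves_best, efficiency_percent_best);
    return 0;
}
```

With these example numbers the first occupancy-based guess (3 parallel blocks) fills only 88% of a wave, and the loop instead settles on 27 parallel blocks, which packs exactly 8 full waves.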