@@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
647
647
}
648
648
649
649
template <int D> // D == head size
650
- #if !defined(GGML_USE_HIP)
651
650
__launch_bounds__ (D, 1 )
652
- #endif // !(defined(GGML_USE_HIP)
653
651
static __global__ void flash_attn_combine_results(
654
652
const float * __restrict__ VKQ_parts,
655
653
const float2 * __restrict__ VKQ_meta,
@@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
692
690
float VKQ_numerator = 0 .0f ;
693
691
float VKQ_denominator = 0 .0f ;
694
692
for (int l = 0 ; l < parallel_blocks; ++l) {
695
- const float diff = meta[l].x - kqmax;
696
- float KQ_max_scale = expf (diff);
697
- const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
698
- *((uint32_t *) &KQ_max_scale) &= ftz_mask;
693
+ const float KQ_max_scale = expf (meta[l].x - kqmax);
699
694
700
695
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
701
696
VKQ_denominator += KQ_max_scale * meta[l].y ;
@@ -836,11 +831,10 @@ void launch_fattn(
836
831
CUDA_CHECK (cudaGetLastError ());
837
832
}
838
833
839
- int parallel_blocks = 1 ;
840
-
841
834
const dim3 block_dim (warp_size, nwarps, 1 );
842
835
int max_blocks_per_sm = 1 ; // Max. number of active blocks limited by occupancy.
843
836
CUDA_CHECK (cudaOccupancyMaxActiveBlocksPerMultiprocessor (&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z , nbytes_shared));
837
+ int parallel_blocks = max_blocks_per_sm;
844
838
845
839
dim3 blocks_num;
846
840
if (stream_k) {
@@ -862,9 +856,6 @@ void launch_fattn(
862
856
GGML_ASSERT (K->ne [1 ] % KQ_row_granularity == 0 );
863
857
const int ntiles_KQ = K->ne [1 ] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
864
858
865
- // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
866
- parallel_blocks = std::max ((nsm * max_blocks_per_sm) / ntiles_total, 1 );
867
-
868
859
// parallel_blocks must not be larger than what the tensor size allows:
869
860
parallel_blocks = std::min (parallel_blocks, ntiles_KQ);
870
861
0 commit comments