Skip to content

Commit c959b67

Browse files
CUDA: fix FA occupancy, optimize tile kernel (#15982)
1 parent cd08fc3 commit c959b67

File tree

4 files changed

+353
-245
lines changed

4 files changed

+353
-245
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@
 #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
 #define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
 #define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
+#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
 #define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)

 // Moore Threads
// Moore Threads
@@ -325,6 +327,20 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
 }

// Maximum number of bytes that can be copied in a single instruction.
// Used to size per-thread copy widths when staging data between memory spaces.
static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    // HIP targets and NVIDIA Volta or newer: 16 byte (128 bit) copies.
    return 16;
#else
    // Older NVIDIA architectures: limited to 8 byte copies.
    return 8;
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
}

 [[noreturn]]
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
 }

 template<int D> // D == head size
-#if !defined(GGML_USE_HIP)
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP)
 static __global__ void flash_attn_combine_results(
     const float  * __restrict__ VKQ_parts,
     const float2 * __restrict__ VKQ_meta,
@@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
     float VKQ_numerator   = 0.0f;
     float VKQ_denominator = 0.0f;
     for (int l = 0; l < parallel_blocks; ++l) {
-        const float diff = meta[l].x - kqmax;
-        float KQ_max_scale = expf(diff);
-        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
-        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+        const float KQ_max_scale = expf(meta[l].x - kqmax);

         VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
@@ -836,11 +831,10 @@ void launch_fattn(
         CUDA_CHECK(cudaGetLastError());
     }

-    int parallel_blocks = 1;
-
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+    int parallel_blocks = max_blocks_per_sm;

     dim3 blocks_num;
     if (stream_k) {
@@ -862,9 +856,6 @@ void launch_fattn(
     GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
     const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.

-    // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
-    parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
-
     // parallel_blocks must not be larger than what the tensor size allows:
     parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

0 commit comments

Comments
 (0)