diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu index 3a673d3e3..d1b924a20 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu @@ -25,10 +25,13 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel::value; int tlength = params.timestep; - // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength); if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu index 7e20bdccc..827c5439e 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu @@ -25,10 +25,13 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + mmha::masked_groupedquery_attention_kernel