@@ -11,6 +11,8 @@ def _dequant_kernel(
1111 dq_ptr ,
1212 stride_qm ,
1313 stride_qn ,
14+ M ,
15+ N ,
1416 GROUP_SIZE : tl .constexpr ,
1517 BLOCK_M : tl .constexpr ,
1618 BLOCK_N : tl .constexpr ,
@@ -22,17 +24,18 @@ def _dequant_kernel(
2224 # rm = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
2325 # rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
2426 offsets = rm [:, None ] * stride_qm + rn [None , :] * stride_qn
27+ mask = (rm [:, None ] < M ) & (rn [None , :] < N )
2528 tl .static_print (offsets )
2629 group_offsets = offsets // GROUP_SIZE
2730 tl .static_print ("group_offsets" , group_offsets )
28- q_idx = tl .load (q_idx_ptr + offsets )
31+ q_idx = tl .load (q_idx_ptr + offsets , mask = mask )
2932 tl .static_print (q_idx )
3033 # NOTE: Must upcast q_idx to int32 (q_idx is tl.uint8, which does not work for pointer indexing)
3134 q_vals = tl .load (qmap_ptr + q_idx .to (tl .int32 ))
32- absmax = tl .load (absmax_ptr + group_offsets )
35+ absmax = tl .load (absmax_ptr + group_offsets , mask = group_offsets < ( M * N // GROUP_SIZE ) )
3336
3437 dq = q_vals * absmax
35- tl .store (dq_ptr + offsets , dq )
38+ tl .store (dq_ptr + offsets , dq , mask = mask )
3639
3740
3841def triton_dequant_blockwise (
@@ -51,6 +54,8 @@ def triton_dequant_blockwise(
5154 dq ,
5255 q .stride (0 ),
5356 q .stride (1 ),
57+ M ,
58+ N ,
5459 BLOCK_M = 1 ,
5560 BLOCK_N = group_size ,
5661 GROUP_SIZE = group_size ,
0 commit comments