Skip to content

Commit f33cff7

Browse files
authored
Fix out-of-bounds memory access in Galore dequant kernel (#1125)
1 parent a2faafe commit f33cff7

File tree

1 file changed

+8
-3
lines changed
  • torchao/prototype/galore/kernels

1 file changed

+8
-3
lines changed

torchao/prototype/galore/kernels/quant.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ def _dequant_kernel(
1111
dq_ptr,
1212
stride_qm,
1313
stride_qn,
14+
M,
15+
N,
1416
GROUP_SIZE: tl.constexpr,
1517
BLOCK_M: tl.constexpr,
1618
BLOCK_N: tl.constexpr,
@@ -22,17 +24,18 @@ def _dequant_kernel(
2224
# rm = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
2325
# rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
2426
offsets = rm[:, None] * stride_qm + rn[None, :] * stride_qn
27+
mask = (rm[:, None] < M) & (rn[None, :] < N)
2528
tl.static_print(offsets)
2629
group_offsets = offsets // GROUP_SIZE
2730
tl.static_print("group_offsets", group_offsets)
28-
q_idx = tl.load(q_idx_ptr + offsets)
31+
q_idx = tl.load(q_idx_ptr + offsets, mask=mask)
2932
tl.static_print(q_idx)
3033
# NOTE: Must upcast q_idx to int32 (q_idx is tl.uint8, which does not work for pointer indexing)
3134
q_vals = tl.load(qmap_ptr + q_idx.to(tl.int32))
32-
absmax = tl.load(absmax_ptr + group_offsets)
35+
absmax = tl.load(absmax_ptr + group_offsets, mask=group_offsets < (M * N // GROUP_SIZE))
3336

3437
dq = q_vals * absmax
35-
tl.store(dq_ptr + offsets, dq)
38+
tl.store(dq_ptr + offsets, dq, mask=mask)
3639

3740

3841
def triton_dequant_blockwise(
@@ -51,6 +54,8 @@ def triton_dequant_blockwise(
5154
dq,
5255
q.stride(0),
5356
q.stride(1),
57+
M,
58+
N,
5459
BLOCK_M=1,
5560
BLOCK_N=group_size,
5661
GROUP_SIZE=group_size,

0 commit comments

Comments
 (0)