107 changes: 17 additions & 90 deletions gpt_oss/triton/attention.py
@@ -11,6 +11,8 @@
 
 import triton
 import triton.language as tl
+from triton.tools.tensor_descriptor import TensorDescriptor
+
 
 
 @triton.jit
@@ -23,22 +25,6 @@ def _attn_fwd(
     M,
     Out, #
     Start_q,
-    stride_qz,
-    stride_qh,
-    stride_qm,
-    stride_qk, #
-    stride_kz,
-    stride_kh,
-    stride_kn,
-    stride_kk, #
-    stride_vz,
-    stride_vh,
-    stride_vn,
-    stride_vk, #
-    stride_oz,
-    stride_oh,
-    stride_om,
-    stride_ok, #
     Z,
     H,
     N_Q_CTX,
@@ -54,44 +40,6 @@
     off_hz = tl.program_id(1)
     off_z = off_hz // H
     off_h = off_hz % H
-    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
-    k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
-    v_offset = off_z.to(tl.int64) * stride_vz + off_h.to(tl.int64) * stride_vh
-    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
-
-    # block pointers
-    Q_block_ptr = tl.make_block_ptr(
-        base=Q + q_offset,
-        shape=(N_Q_CTX, HEAD_DIM),
-        strides=(stride_qm, stride_qk),
-        offsets=(start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, HEAD_DIM),
-        order=(1, 0),
-    )
-    V_block_ptr = tl.make_block_ptr(
-        base=V + v_offset,
-        shape=(N_KV_CTX, HEAD_DIM),
-        strides=(stride_vn, stride_vk),
-        offsets=(0, 0),
-        block_shape=(BLOCK_N, HEAD_DIM),
-        order=(1, 0),
-    )
-    K_block_ptr = tl.make_block_ptr(
-        base=K + k_offset,
-        shape=(HEAD_DIM, N_KV_CTX),
-        strides=(stride_kk, stride_kn),
-        offsets=(0, 0),
-        block_shape=(HEAD_DIM, BLOCK_N),
-        order=(0, 1),
-    )
-    O_block_ptr = tl.make_block_ptr(
-        base=Out + o_offset,
-        shape=(N_Q_CTX, HEAD_DIM),
-        strides=(stride_om, stride_ok),
-        offsets=(start_m * BLOCK_M, 0),
-        block_shape=(BLOCK_M, HEAD_DIM),
-        order=(1, 0),
-    )
 
     # load attention sinks
     if Sinks is not None:
@@ -108,17 +56,13 @@
     acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
     # load scales
     qk_scale = sm_scale
-    q = tl.load(Q_block_ptr)
+    q = Q.load([off_z, off_h, start_m * BLOCK_M, 0]).reshape([BLOCK_M, HEAD_DIM])
 
     if BANDWIDTH:
         lo, hi = tl.maximum(start_q, start_q + start_m * BLOCK_M - BANDWIDTH), (start_q + start_m + 1) * BLOCK_M
     else:
         lo, hi = start_q, (start_q + start_m + 1) * BLOCK_M
 
-    # advance the KV block-pointers so they point at `lo`
-    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
-    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
-
     for start_n in range(lo, hi, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
 
@@ -128,7 +72,7 @@
             too_old = (start_n + offs_n[None, :]) < (start_q + offs_m[:, None] - BANDWIDTH + 1)
             mask = mask | too_old
 
-        k = tl.load(K_block_ptr)
+        k = K.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM]).T
         qk = tl.dot(q, k, allow_tf32=False)
 
         qk = qk * qk_scale + tl.where(mask, -1.0e6, 0.0)
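Note on the K load: the old block pointer materialized K already transposed (block shape (HEAD_DIM, BLOCK_N) via swapped strides and order), while the descriptor version loads the natural (BLOCK_N, HEAD_DIM) tile and transposes it in registers with .T, so the tl.dot that follows is unchanged.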
@@ -140,22 +84,21 @@
         l_ij = tl.sum(p, 1)
         acc = acc * alpha[:, None]
 
-        v = tl.load(V_block_ptr).to(tl.float32)
+        v = V.load([off_z, off_h, start_n, 0]).reshape([BLOCK_N, HEAD_DIM])
+        v = v.to(tl.float32)
         acc = tl.dot(p, v, acc, allow_tf32=False)
 
         l_i = l_i * alpha + l_ij
         m_i = m_ij
 
-        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
-        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
-
     sink = tl.math.exp(sink - m_i)
     z = l_i + sink
     acc = acc / z[:, None]
     m_i += tl.math.log(l_i)
     m_ptrs = M + off_hz * N_Q_CTX + offs_m
     tl.store(m_ptrs, m_i)
-    tl.store(O_block_ptr, acc.to(Out.type.element_ty))
+    acc = acc.to(Out.dtype)[None, None, :, :]
+    Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)
 
 
 class _attention(torch.autograd.Function):
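This is also why the sixteen stride arguments and all of the make_block_ptr/tl.advance bookkeeping disappear from the kernel: a tensor descriptor carries the tensor's shape, strides, and block shape with it, and every load/store names the absolute element offsets it wants (here [off_z, off_h, row, 0]), so there is no per-iteration pointer state left to maintain.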
@@ -189,35 +132,19 @@ def forward(ctx, q, k, v, sinks, sm_scale, bandwidth, start_q):
         M = torch.empty((bs, n_heads, n_ctx + m_pad_size), device=q.device, dtype=torch.float32)
         grid = (triton.cdiv(n_ctx, BLOCK_M), bs * n_heads, 1)
         _attn_fwd[grid](
-            q,
-            k,
-            v,
+            TensorDescriptor.from_tensor(q, [1, 1, BLOCK_M, HEAD_DIM_K]),
+            TensorDescriptor.from_tensor(k, [1, 1, BLOCK_N, HEAD_DIM_K]),
+            TensorDescriptor.from_tensor(v, [1, 1, BLOCK_N, HEAD_DIM_K]),
             sinks,
             sm_scale,
             M,
-            o, #
+            TensorDescriptor.from_tensor(o, [1, 1, BLOCK_M, HEAD_DIM_K]),
             start_q,
-            q.stride(0),
-            q.stride(1),
-            q.stride(2),
-            q.stride(3), #
-            k.stride(0),
-            k.stride(1),
-            k.stride(2),
-            k.stride(3), #
-            v.stride(0),
-            v.stride(1),
-            v.stride(2),
-            v.stride(3), #
-            o.stride(0),
-            o.stride(1),
-            o.stride(2),
-            o.stride(3), #
             q.shape[0],
-            q.shape[1], #
-            N_Q_CTX=n_ctx + m_pad_size, #
-            N_KV_CTX=n_kv_ctx, #
-            HEAD_DIM=HEAD_DIM_K, #
+            q.shape[1],
+            N_Q_CTX=n_ctx + m_pad_size,
+            N_KV_CTX=n_kv_ctx,
+            HEAD_DIM=HEAD_DIM_K,
             BANDWIDTH=bandwidth,
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
@@ -299,4 +226,4 @@ def test_eq(batch_size, num_queries, num_keys, num_key_value_heads, num_key_valu
     o1 = attention(q, k, v, sinks, sm_scale, sliding_window, start_q)
     o2 = attention_ref(q, k, v, sinks, sm_scale, sliding_window, start_q)
 
-    torch.testing.assert_close(o1, o2)
+    torch.testing.assert_close(o1, o2)
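For readers new to the API this PR adopts: Triton 3.4 lets the host wrap any tensor whose last dimension is contiguous in a TensorDescriptor with a fixed block shape, pass it to a kernel, and load/store whole tiles by element offset, exactly as forward() now does for q/k/v/o. Below is a minimal self-contained sketch of the same pattern, using a hypothetical copy kernel that is not part of this PR (it assumes a CUDA device and the usual descriptor constraint that the innermost block dimension spans a multiple of 16 bytes):

import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def _tile_copy(In, Out, BLOCK_M: tl.constexpr):
    pid = tl.program_id(0)
    # A descriptor load takes one absolute element offset per dimension
    # and returns a tile with the descriptor's block shape.
    tile = In.load([pid * BLOCK_M, 0])
    Out.store([pid * BLOCK_M, 0], tile)


def tile_copy(x: torch.Tensor, BLOCK_M: int = 64) -> torch.Tensor:
    out = torch.empty_like(x)
    # Block shapes are fixed host-side, as in forward() above.
    in_desc = TensorDescriptor.from_tensor(x, [BLOCK_M, x.shape[1]])
    out_desc = TensorDescriptor.from_tensor(out, [BLOCK_M, x.shape[1]])
    grid = (triton.cdiv(x.shape[0], BLOCK_M),)
    _tile_copy[grid](in_desc, out_desc, BLOCK_M=BLOCK_M)
    return out


x = torch.randn(1024, 64, device="cuda", dtype=torch.float16)
torch.testing.assert_close(tile_copy(x), x)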
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.12,<3.13"
 version = "0.0.3"
 
 [project.optional-dependencies]
-triton = ["triton", "safetensors>=0.5.3", "torch>=2.7.0"]
+triton = ["triton>=3.4", "safetensors>=0.5.3", "torch>=2.7.0"]
 torch = ["safetensors>=0.5.3", "torch>=2.7.0"]
 metal = ["numpy", "tqdm", "safetensors", "torch"]
 test = ["pytest>=8.4.1", "httpx>=0.28.1"]
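The version floor here matches the new import: triton.tools.tensor_descriptor first shipped in Triton 3.4, so the triton extra now requires it. From a local checkout this would presumably be installed with pip install -e ".[triton]".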