-
Notifications
You must be signed in to change notification settings - Fork 6.4k
[core] support sage attention through kernels
#12439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
sayakpaul
wants to merge
5
commits into
main
Choose a base branch
from
sage-kernels
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
""" | ||
Copyright (c) 2024 by SageAttention, The HuggingFace team. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the | ||
License. You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an | ||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
""" | ||
|
||
""" | ||
Modified from | ||
https://github.com/thu-ml/SageAttention/blob/68de3797d163b89d28f9a38026c3b7313f6940d2/sageattention/core.py | ||
""" | ||
|
||
|
||
import torch # noqa | ||
|
||
|
||
SAGE_ATTENTION_DISPATCH = { | ||
"sm80": { | ||
"func": "sageattn_qk_int8_pv_fp16_cuda", | ||
"kwargs": { | ||
"tensor_layout": "NHD", | ||
"is_causal": False, | ||
"sm_scale": None, | ||
"return_lse": False, | ||
"pv_accum_dtype": "fp32", | ||
}, | ||
}, | ||
"sm89": { | ||
"func": "sageattn_qk_int8_pv_fp8_cuda", | ||
"kwargs": { | ||
"tensor_layout": "NHD", | ||
"is_causal": False, | ||
"sm_scale": None, | ||
"return_lse": False, | ||
"pv_accum_dtype": "fp32+fp16", | ||
}, | ||
}, | ||
"sm90": { | ||
"func": "sageattn_qk_int8_pv_fp8_cuda_sm90", | ||
"kwargs": { | ||
"tensor_layout": "NHD", | ||
"is_causal": False, | ||
"sm_scale": None, | ||
"return_lse": False, | ||
"pv_accum_dtype": "fp32+fp32", | ||
}, | ||
}, | ||
"sm120": { | ||
"func": "sageattn_qk_int8_pv_fp8_cuda", | ||
"kwargs": { | ||
"tensor_layout": "NHD", | ||
"is_causal": False, | ||
"qk_quant_gran": "per_warp", | ||
"sm_scale": None, | ||
"return_lse": False, | ||
"pv_accum_dtype": "fp32+fp16", | ||
}, | ||
}, | ||
} | ||
|
||
|
||
def get_cuda_version(): | ||
if torch.cuda.is_available(): | ||
major, minor = torch.cuda.get_device_capability() | ||
return major, minor | ||
else: | ||
raise EnvironmentError("CUDA not found.") | ||
|
||
|
||
def get_cuda_arch_versions(): | ||
if not torch.cuda.is_available(): | ||
EnvironmentError("CUDA not found.") | ||
cuda_archs = [] | ||
for i in range(torch.cuda.device_count()): | ||
major, minor = torch.cuda.get_device_capability(i) | ||
cuda_archs.append(f"sm{major}{minor}") | ||
return cuda_archs | ||
|
||
|
||
# Unlike the actual implementation, we just maintain function names rather than actual | ||
# implementations. | ||
def _get_sage_attn_fn_for_device(): | ||
""" | ||
Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute | ||
capability. | ||
|
||
Parameters ---------- q : torch.Tensor | ||
The query tensor. Shape: | ||
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. | ||
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. | ||
|
||
k : torch.Tensor | ||
The key tensor. Shape: | ||
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. | ||
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. | ||
|
||
v : torch.Tensor | ||
The value tensor. Shape: | ||
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. | ||
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. | ||
|
||
tensor_layout : str | ||
The tensor layout, either "HND" or "NHD". Default: "HND". | ||
|
||
is_causal : bool | ||
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. Default: False. | ||
|
||
sm_scale : Optional[float] | ||
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. | ||
|
||
return_lse : bool | ||
Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention. | ||
Default: False. | ||
|
||
Returns ------- torch.Tensor | ||
The output tensor. Shape: | ||
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. | ||
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. | ||
|
||
torch.Tensor | ||
The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). Shape: | ||
``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True. | ||
|
||
Note ---- | ||
- ``num_qo_heads`` must be divisible by ``num_kv_heads``. | ||
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` | ||
- All tensors must be on the same cuda device. | ||
""" | ||
device_index = torch.cuda.current_device() | ||
arch = get_cuda_arch_versions()[device_index] | ||
return SAGE_ATTENTION_DISPATCH[arch] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see their usage, hence removed.