hip : substituted bpermute ops with swizzle ops (gfx906, maybe all AMD) #16291
base: master
Conversation
The CUDA vector FA kernels are being refactored in #16208. In particular, the KQ accumulation and softmax are always done in FP32 because FP16 can cause numerical issues. The speedup of that PR is well above 1.2x on my MI50, so please use that version as the base for further performance optimizations.
Yeah, I imagined that (just like what happened with the tile kernel). I just added the optimizations to common.cuh for the pull request, which is the main performance improvement and also slightly impacts the tile kernel. The F16 vector modifications are just experiments. Thanks.
bpermute indeed has higher latency than the swizzle operations and is used for the shfl operators, so this PR is in principle of value. However, before it can be properly reviewed it needs to be cleaned up. Please remove all changes except those to the reduction operators.
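For reference, a minimal sketch of why the swizzle path is cheaper, assuming the standard GCN builtins __builtin_amdgcn_ds_bpermute and __builtin_amdgcn_ds_swizzle; the xor_shfl_* helpers below are hypothetical names, not code from this PR. ds_bpermute needs a per-lane byte address computed at runtime, while ds_swizzle encodes the lane permutation as a compile-time immediate.

```cpp
// Hypothetical illustration only: an XOR "shuffle" expressed with each primitive.

// bpermute path: the source lane is computed per lane at runtime (extra VALU work,
// and the read goes through the LDS addressing path).
static __device__ __forceinline__ int xor_shfl_bpermute(int v, int mask) {
    const int lane = __lane_id();
    return __builtin_amdgcn_ds_bpermute((lane ^ mask) << 2, v); // byte address = target lane * 4
}

// swizzle path: the permutation is an immediate in the instruction, no address math.
// ds_swizzle operates within groups of 32 lanes, so this covers mask < 32 only.
template <int mask>
static __device__ __forceinline__ int xor_shfl_swizzle(int v) {
    return __builtin_amdgcn_ds_swizzle(v, (mask << 10) | 0x1f); // xor_mask in bits [14:10]
}
```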
There was an old PR with a similar change that used …
I'm testing it with build 6615h. Probably we get +0.2% in pp and +1.2% in tg. Performance is so similar that it is not possible to actually tell. The old F16 kernel was using thread reduction massively, so it was much more effective there (thousands of bpermutes compared to a few hundred here). The new kernel is just too good by itself. I will update if I find out something.
In the new kernel the number of threads per K row/V column is configurable. If the reduction within a warp becomes faster, the optimal values for the number of threads may change.
The working modification is in common.cuh. It provides a slight increase of speed in pp and tg (only with real cases or depth ≠ 0). I also played a bit with the thread counts, but they seem to already be optimal as they are.
@iacopPBK as I mentioned before, this PR is fundamentally not in a state that is ready for review, as it contains a bunch of unrelated changes.
Please create a branch in your repo with just the changes to the reduction operators and then open a PR with just those changes, which I will then test for regressions and review.
If you would rather not do this, please close the PR.
Additionally, I don't want to maintain the swizzle ops as they are currently used. Please write a function ggml_cuda_shfl_xor_sync that internally either uses __shfl_xor_sync for NVIDIA GPUs or whatever is equivalent but faster for AMD.
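A minimal sketch of what such a wrapper could look like (not the final implementation): ggml_cuda_hip_shfl_xor is a hypothetical stand-in for the swizzle/DPP dispatch in common.cuh, and the offset is taken as a template parameter because the AMD swizzle/DPP controls must be compile-time constants.

```cpp
// Sketch only, assuming WARP_SIZE and GGML_USE_HIP as defined in common.cuh.
template <int offset, typename T>
static __device__ __forceinline__ T ggml_cuda_shfl_xor_sync(T x) {
#if defined(GGML_USE_HIP)
    return ggml_cuda_hip_shfl_xor<offset>(x);                  // hypothetical ds_swizzle/DPP path for AMD
#else
    return __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);  // native shuffle on NVIDIA
#endif
}
```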
In terms of performance testing, let's please wait for a few days until I receive a riser cable which I intend to use for my newly purchased RX 9060 XT.
@iacopPBK, you can probably get even more PP/TG by using DPP16 instead of swizzle instructions for offsets < 16, like this:

```cpp
template <int dpp_ctrl, typename T, int row_mask = 0xf, int bank_mask = 0xf, bool bound_ctrl = true>
static __device__ __forceinline__ T hip_update_dpp(T old, T v) {
    return __builtin_bit_cast(
        T,
        __builtin_amdgcn_update_dpp(
            __builtin_bit_cast(int, old),
            __builtin_bit_cast(int, v),
            dpp_ctrl,
            row_mask,
            bank_mask,
            bound_ctrl
        )
    );
}

template <int mask, typename T>
static __device__ __forceinline__ T hip_ds_swizzle(T v) {
    return __builtin_bit_cast(T, __builtin_amdgcn_ds_swizzle(__builtin_bit_cast(int, v), mask));
}

template <int offset, typename T>
static __device__ __forceinline__ T hip_shfl(T val) {
    if constexpr (offset == 32) {
        return __shfl_xor_sync(0xffffffff, val, 32, 64);
    } else if constexpr (offset == 16) {
        return hip_ds_swizzle<0x401f>(val); // swap neighboring groups of 16
    } else if constexpr (offset && (offset & 0xf) == offset) {
        static T initial;
        return hip_update_dpp<0x160 + offset>(initial, val); // ROW_XMASK:offset
        // there are quad_perm and row_ror options for where xmask is not supported.
    } else {
        static_assert(false, "Unhandled offset");
    }
}

template <int width>
static __device__ __forceinline__ int hip_warp_reduce_sum(int x) {
    if constexpr (width < 2) {
        return x;
    } else {
        return hip_warp_reduce_sum<width/2>(x + hip_shfl<width/2>(x));
        // unlike shfl_xor, swizzle/dpp ctrl and masks have to be compile time constants,
        // hence can't use straight loop.
    }
}
```
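If something along these lines is adopted, the call site would collapse the usual unrolled shuffle loop into a single recursive call. An untested usage sketch, assuming a 64-wide wave and a 64-thread block; k_sum_example is just an illustrative kernel name:

```cpp
// Illustrative only: each warp sums its 64 lanes with the helpers above.
__global__ void k_sum_example(const int * x, int * out) {
    int v = x[threadIdx.x];
    v = hip_warp_reduce_sum<64>(v); // recurses through offsets 32, 16, 8, 4, 2, 1
    if (threadIdx.x == 0) {
        *out = v;
    }
}
```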
Replaced bpermute instructions with native swizzle operations in the HIP backend, specifically targeting the GFX906 architecture (MI50/MI60/Radeon VII). The primitive implementation and dispatch of the swizzles is contained in the common.cuh file for your review.
I verified that there is no degradation in model quality and benchmarked with llama-bench; see the README.md of the gfx906 fork for the performance improvements (+20% inference speed on average in both synthetic and real cases).
I only tested it on gfx906; I did not verify whether this is compatible with all GGML_USE_HIP hardware.
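One way to keep untested targets on the existing path would be a compile-time guard along these lines (a sketch only; GGML_HIP_USE_SWIZZLE is a made-up macro name, and the architecture list would have to match whatever actually gets verified):

```cpp
// Sketch: opt known-good AMD targets into the swizzle path, keep everything else
// on the existing __shfl_xor_sync-based reduction. GGML_HIP_USE_SWIZZLE is hypothetical.
#if defined(GGML_USE_HIP) && defined(__gfx906__)
#   define GGML_HIP_USE_SWIZZLE 1
#else
#   define GGML_HIP_USE_SWIZZLE 0
#endif
```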