-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[AMDGPU] Add support for safe bfloat16 fdiv on targets with bf16 trans instructions #154373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…s instructions Recent changes introduced custom lowering for bf16 fdiv on targets that support bf16 trans instructions, but only covered the unsafe version. This PR extends that support to the safe variant. For the safe version, the op is lowered by converting to float, performing the div in float, and converting the result back to bf16. This matches the behavior on targets that don't support bf16 trans instructions.
This stack of pull requests is managed by Graphite. Learn more about stacking. |
|
@llvm/pr-subscribers-backend-amdgpu Author: Shilei Tian (shiltian) ChangesRecent changes introduced custom lowering for bf16 fdiv on targets that support bf16 trans instructions, but only covered the unsafe version. This PR extends that support to the safe variant. For the safe version, the op is lowered by converting to float, performing the div in float, and converting the result back to bf16. This matches the behavior on targets that don't support bf16 trans instructions. Full diff: https://github.com/llvm/llvm-project/pull/154373.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index a2084074263da..561019bb65549 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -11540,9 +11540,22 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
return FastLowered;
SDLoc SL(Op);
+ EVT VT = Op.getValueType();
SDValue LHS = Op.getOperand(0);
SDValue RHS = Op.getOperand(1);
+ SDValue LHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, LHS);
+ SDValue RHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, RHS);
+
+ if (VT == MVT::bf16) {
+ SDValue ExtDiv =
+ DAG.getNode(ISD::FDIV, SL, MVT::f32, LHSExt, RHSExt, Op->getFlags());
+ return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ExtDiv,
+ DAG.getTargetConstant(0, SL, MVT::i32));
+ }
+
+ assert(VT == MVT::f16);
+
// a32.u = opx(V_CVT_F32_F16, a.u); // CVT to F32
// b32.u = opx(V_CVT_F32_F16, b.u); // CVT to F32
// r32.u = opx(V_RCP_F32, b32.u); // rcp = 1 / d
@@ -11559,9 +11572,6 @@ SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const {
// We will use ISD::FMA on targets that don't support ISD::FMAD.
unsigned FMADOpCode =
isOperationLegal(ISD::FMAD, MVT::f32) ? ISD::FMAD : ISD::FMA;
-
- SDValue LHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, LHS);
- SDValue RHSExt = DAG.getNode(ISD::FP_EXTEND, SL, MVT::f32, RHS);
SDValue NegRHSExt = DAG.getNode(ISD::FNEG, SL, MVT::f32, RHSExt);
SDValue Rcp =
DAG.getNode(AMDGPUISD::RCP, SL, MVT::f32, RHSExt, Op->getFlags());
diff --git a/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll b/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
index 91831a8d4fecb..00cde422a2297 100644
--- a/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll
@@ -2,12 +2,68 @@
; RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -mattr=+real-true16 -denormal-fp-math-f32=preserve-sign < %s | FileCheck -check-prefixes=GFX1250-TRUE16 %s
; RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -mattr=-real-true16 -denormal-fp-math-f32=preserve-sign < %s | FileCheck -check-prefixes=GFX1250-FAKE16 %s
-/* TODO: Support safe bf16 fdiv lowering.
define bfloat @v_fdiv_bf16(bfloat %x, bfloat %y) {
+; GFX1250-TRUE16-LABEL: v_fdiv_bf16:
+; GFX1250-TRUE16: ; %bb.0:
+; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
+; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
+; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v2.l, 0
+; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v2.h, v1.l
+; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v1.h, v0.l
+; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_3) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v1.l, v2.l
+; GFX1250-TRUE16-NEXT: v_div_scale_f32 v0, null, v2, v2, v1
+; GFX1250-TRUE16-NEXT: v_div_scale_f32 v4, vcc_lo, v1, v2, v1
+; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_2) | instskip(SKIP_2) | instid1(TRANS32_DEP_1)
+; GFX1250-TRUE16-NEXT: v_rcp_f32_e32 v3, v0
+; GFX1250-TRUE16-NEXT: s_denorm_mode 15
+; GFX1250-TRUE16-NEXT: v_nop
+; GFX1250-TRUE16-NEXT: v_fma_f32 v5, -v0, v3, 1.0
+; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT: v_fmac_f32_e32 v3, v5, v3
+; GFX1250-TRUE16-NEXT: v_mul_f32_e32 v5, v4, v3
+; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT: v_fma_f32 v6, -v0, v5, v4
+; GFX1250-TRUE16-NEXT: v_fmac_f32_e32 v5, v6, v3
+; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT: v_fma_f32 v0, -v0, v5, v4
+; GFX1250-TRUE16-NEXT: s_denorm_mode 12
+; GFX1250-TRUE16-NEXT: v_div_fmas_f32 v0, v0, v3, v5
+; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-TRUE16-NEXT: v_div_fixup_f32 v0, v0, v2, v1
+; GFX1250-TRUE16-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
+; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
+;
+; GFX1250-FAKE16-LABEL: v_fdiv_bf16:
+; GFX1250-FAKE16: ; %bb.0:
+; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
+; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
+; GFX1250-FAKE16-NEXT: v_dual_lshlrev_b32 v1, 16, v1 :: v_dual_lshlrev_b32 v0, 16, v0
+; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(VALU_DEP_2)
+; GFX1250-FAKE16-NEXT: v_div_scale_f32 v2, null, v1, v1, v0
+; GFX1250-FAKE16-NEXT: v_div_scale_f32 v4, vcc_lo, v0, v1, v0
+; GFX1250-FAKE16-NEXT: v_rcp_f32_e32 v3, v2
+; GFX1250-FAKE16-NEXT: s_denorm_mode 15
+; GFX1250-FAKE16-NEXT: v_nop
+; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT: v_fma_f32 v5, -v2, v3, 1.0
+; GFX1250-FAKE16-NEXT: v_fmac_f32_e32 v3, v5, v3
+; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT: v_mul_f32_e32 v5, v4, v3
+; GFX1250-FAKE16-NEXT: v_fma_f32 v6, -v2, v5, v4
+; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT: v_fmac_f32_e32 v5, v6, v3
+; GFX1250-FAKE16-NEXT: v_fma_f32 v2, -v2, v5, v4
+; GFX1250-FAKE16-NEXT: s_denorm_mode 12
+; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT: v_div_fmas_f32 v2, v2, v3, v5
+; GFX1250-FAKE16-NEXT: v_div_fixup_f32 v0, v2, v1, v0
+; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1)
+; GFX1250-FAKE16-NEXT: v_cvt_pk_bf16_f32 v0, v0, s0
+; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
%fdiv = fdiv bfloat %x, %y
ret bfloat %fdiv
}
-*/
define bfloat @v_rcp_bf16(bfloat %x) {
; GFX1250-TRUE16-LABEL: v_rcp_bf16:
|
rampitec
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
LGTM |

Recent changes introduced custom lowering for bf16 fdiv on targets that support bf16 trans instructions, but only covered the unsafe version. This PR extends that support to the safe variant.
For the safe version, the op is lowered by converting to float, performing the div in float, and converting the result back to bf16. This matches the behavior on targets that don't support bf16 trans instructions.
Fixes SWDEV-550381.