Skip to content

Commit acff049

Browse files
[NVPTX] Add max/minimumnum to ISel (#155804)
Add direct support for the LLVM `maximumnum` and `minimumnum` intrinsics for NVPTX, rather than lowering them to a sequence of compare + select instructions. The `maximumnum` and `minimumnum` intrinsics map directly to the PTX `max`/`min` instructions. In the future, the LLVM `maxnum`/`minnum` intrinsics might need additional fix-ups for sNaN handling, but currently both `llvm.maxnum` and `llvm.maximumnum` map directly to PTX `max` instructions (and likewise `llvm.minnum` and `llvm.minimumnum` to PTX `min`).
1 parent 12630ed commit acff049

File tree

3 files changed

+426
-4
lines changed

3 files changed

+426
-4
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
543543
case ISD::FMINNUM_IEEE:
544544
case ISD::FMAXIMUM:
545545
case ISD::FMINIMUM:
546+
case ISD::FMAXIMUMNUM:
547+
case ISD::FMINIMUMNUM:
546548
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
547549
break;
548550
case ISD::FEXP2:
@@ -989,7 +991,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
989991
if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
990992
AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);
991993

992-
for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
994+
for (const auto &Op :
995+
{ISD::FMINNUM, ISD::FMAXNUM, ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM}) {
993996
setOperationAction(Op, MVT::f32, Legal);
994997
setOperationAction(Op, MVT::f64, Legal);
995998
setFP16OperationAction(Op, MVT::f16, Legal, Promote);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
295295
//
296296
// Also defines ftz (flush subnormal inputs and results to sign-preserving
297297
// zero) variants for fp32 functions.
298-
multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
298+
multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDPatternOperator OpNode> {
299299
defvar nan_str = !if(NaN, ".NaN", "");
300300
if !not(NaN) then {
301301
def _f64_rr :
@@ -911,8 +911,15 @@ defm FADD : F3_fma_component<"add", fadd>;
911911
defm FSUB : F3_fma_component<"sub", fsub>;
912912
defm FMUL : F3_fma_component<"mul", fmul>;
913913

914-
defm MIN : FMINIMUMMAXIMUM<"min", /* NaN */ false, fminnum>;
915-
defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
914+
def fminnum_or_fminimumnum : PatFrags<(ops node:$a, node:$b),
915+
[(fminnum node:$a, node:$b),
916+
(fminimumnum node:$a, node:$b)]>;
917+
def fmaxnum_or_fmaximumnum : PatFrags<(ops node:$a, node:$b),
918+
[(fmaxnum node:$a, node:$b),
919+
(fmaximumnum node:$a, node:$b)]>;
920+
921+
defm MIN : FMINIMUMMAXIMUM<"min", /* NaN */ false, fminnum_or_fminimumnum>;
922+
defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum_or_fmaximumnum>;
916923
defm MIN_NAN : FMINIMUMMAXIMUM<"min", /* NaN */ true, fminimum>;
917924
defm MAX_NAN : FMINIMUMMAXIMUM<"max", /* NaN */ true, fmaximum>;
918925

0 commit comments

Comments
 (0)