Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 168 additions & 5 deletions llvm/include/llvm/IR/NVVMIntrinsicUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ enum class TMAReductionOp : uint8_t {
XOR = 7,
};

inline bool IntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
// Float to i32 / i64 conversion intrinsics:
case Intrinsic::nvvm_f2i_rm_ftz:
case Intrinsic::nvvm_f2i_rn_ftz:
case Intrinsic::nvvm_f2i_rp_ftz:
Expand All @@ -61,11 +60,53 @@ inline bool IntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
case Intrinsic::nvvm_f2ull_rp_ftz:
case Intrinsic::nvvm_f2ull_rz_ftz:
return true;

case Intrinsic::nvvm_f2i_rm:
case Intrinsic::nvvm_f2i_rn:
case Intrinsic::nvvm_f2i_rp:
case Intrinsic::nvvm_f2i_rz:

case Intrinsic::nvvm_f2ui_rm:
case Intrinsic::nvvm_f2ui_rn:
case Intrinsic::nvvm_f2ui_rp:
case Intrinsic::nvvm_f2ui_rz:

case Intrinsic::nvvm_d2i_rm:
case Intrinsic::nvvm_d2i_rn:
case Intrinsic::nvvm_d2i_rp:
case Intrinsic::nvvm_d2i_rz:

case Intrinsic::nvvm_d2ui_rm:
case Intrinsic::nvvm_d2ui_rn:
case Intrinsic::nvvm_d2ui_rp:
case Intrinsic::nvvm_d2ui_rz:

case Intrinsic::nvvm_f2ll_rm:
case Intrinsic::nvvm_f2ll_rn:
case Intrinsic::nvvm_f2ll_rp:
case Intrinsic::nvvm_f2ll_rz:

case Intrinsic::nvvm_f2ull_rm:
case Intrinsic::nvvm_f2ull_rn:
case Intrinsic::nvvm_f2ull_rp:
case Intrinsic::nvvm_f2ull_rz:

case Intrinsic::nvvm_d2ll_rm:
case Intrinsic::nvvm_d2ll_rn:
case Intrinsic::nvvm_d2ll_rp:
case Intrinsic::nvvm_d2ll_rz:

case Intrinsic::nvvm_d2ull_rm:
case Intrinsic::nvvm_d2ull_rn:
case Intrinsic::nvvm_d2ull_rp:
case Intrinsic::nvvm_d2ull_rz:
return false;
}
llvm_unreachable("Checking FTZ flag for invalid f2i/d2i intrinsic");
return false;
}

inline bool IntrinsicConvertsToSignedInteger(Intrinsic::ID IntrinsicID) {
inline bool FPToIntegerIntrinsicResultIsSigned(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
// f2i
case Intrinsic::nvvm_f2i_rm:
Expand Down Expand Up @@ -96,12 +137,44 @@ inline bool IntrinsicConvertsToSignedInteger(Intrinsic::ID IntrinsicID) {
case Intrinsic::nvvm_d2ll_rp:
case Intrinsic::nvvm_d2ll_rz:
return true;

// f2ui
case Intrinsic::nvvm_f2ui_rm:
case Intrinsic::nvvm_f2ui_rm_ftz:
case Intrinsic::nvvm_f2ui_rn:
case Intrinsic::nvvm_f2ui_rn_ftz:
case Intrinsic::nvvm_f2ui_rp:
case Intrinsic::nvvm_f2ui_rp_ftz:
case Intrinsic::nvvm_f2ui_rz:
case Intrinsic::nvvm_f2ui_rz_ftz:
// d2ui
case Intrinsic::nvvm_d2ui_rm:
case Intrinsic::nvvm_d2ui_rn:
case Intrinsic::nvvm_d2ui_rp:
case Intrinsic::nvvm_d2ui_rz:
// f2ull
case Intrinsic::nvvm_f2ull_rm:
case Intrinsic::nvvm_f2ull_rm_ftz:
case Intrinsic::nvvm_f2ull_rn:
case Intrinsic::nvvm_f2ull_rn_ftz:
case Intrinsic::nvvm_f2ull_rp:
case Intrinsic::nvvm_f2ull_rp_ftz:
case Intrinsic::nvvm_f2ull_rz:
case Intrinsic::nvvm_f2ull_rz_ftz:
// d2ull
case Intrinsic::nvvm_d2ull_rm:
case Intrinsic::nvvm_d2ull_rn:
case Intrinsic::nvvm_d2ull_rp:
case Intrinsic::nvvm_d2ull_rz:
return false;
}
llvm_unreachable(
"Checking invalid f2i/d2i intrinsic for signed int conversion");
return false;
}

inline APFloat::roundingMode
IntrinsicGetRoundingMode(Intrinsic::ID IntrinsicID) {
GetFPToIntegerRoundingMode(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
// RM:
case Intrinsic::nvvm_f2i_rm:
Expand Down Expand Up @@ -167,10 +240,100 @@ IntrinsicGetRoundingMode(Intrinsic::ID IntrinsicID) {
case Intrinsic::nvvm_d2ull_rz:
return APFloat::rmTowardZero;
}
llvm_unreachable("Invalid f2i/d2i rounding mode intrinsic");
llvm_unreachable("Checking rounding mode for invalid f2i/d2i intrinsic");
return APFloat::roundingMode::Invalid;
}

/// Classifies an NVVM fmin/fmax intrinsic by whether it is one of the `ftz`
/// (flush-to-zero) variants.
///
/// \param IntrinsicID one of the nvvm_fmin_* / nvvm_fmax_* intrinsic IDs
///        enumerated below.
/// \returns true for the variants whose name carries the `ftz` modifier,
///          false for every other listed fmin/fmax variant.
///
/// Any ID that is not listed is a caller error and hits llvm_unreachable.
inline bool FMinFMaxShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Variants without the ftz modifier (fmax/fmin siblings kept side by side).
  case Intrinsic::nvvm_fmax_d:
  case Intrinsic::nvvm_fmin_d:
  case Intrinsic::nvvm_fmax_f:
  case Intrinsic::nvvm_fmin_f:
  case Intrinsic::nvvm_fmax_nan_f:
  case Intrinsic::nvvm_fmin_nan_f:
  case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmax_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_xorsign_abs_f:
    return false;

  // Variants with the ftz modifier.
  case Intrinsic::nvvm_fmax_ftz_f:
  case Intrinsic::nvvm_fmin_ftz_f:
  case Intrinsic::nvvm_fmax_ftz_nan_f:
  case Intrinsic::nvvm_fmin_ftz_nan_f:
  case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
    return true;
  }
  llvm_unreachable("Checking FTZ flag for invalid fmin/fmax intrinsic");
  return false;
}

/// Classifies an NVVM fmin/fmax intrinsic by whether it is one of the `nan`
/// variants.
///
/// \param IntrinsicID one of the nvvm_fmin_* / nvvm_fmax_* intrinsic IDs
///        enumerated below.
/// \returns true for the variants whose name carries the `nan` modifier,
///          false for every other listed fmin/fmax variant.
///
/// Any ID that is not listed is a caller error and hits llvm_unreachable.
inline bool FMinFMaxPropagatesNaNs(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Variants without the nan modifier (fmax/fmin siblings kept side by side).
  case Intrinsic::nvvm_fmax_d:
  case Intrinsic::nvvm_fmin_d:
  case Intrinsic::nvvm_fmax_f:
  case Intrinsic::nvvm_fmin_f:
  case Intrinsic::nvvm_fmax_ftz_f:
  case Intrinsic::nvvm_fmin_ftz_f:
  case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
  case Intrinsic::nvvm_fmax_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_xorsign_abs_f:
    return false;

  // Variants with the nan modifier.
  case Intrinsic::nvvm_fmax_nan_f:
  case Intrinsic::nvvm_fmin_nan_f:
  case Intrinsic::nvvm_fmax_ftz_nan_f:
  case Intrinsic::nvvm_fmin_ftz_nan_f:
  case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
    return true;
  }
  llvm_unreachable("Checking NaN flag for invalid fmin/fmax intrinsic");
  return false;
}

/// Classifies an NVVM fmin/fmax intrinsic by whether it is one of the
/// `xorsign_abs` variants.
///
/// \param IntrinsicID one of the nvvm_fmin_* / nvvm_fmax_* intrinsic IDs
///        enumerated below.
/// \returns true for the variants whose name carries the `xorsign_abs`
///          modifier, false for every other listed fmin/fmax variant.
///
/// Any ID that is not listed is a caller error and hits llvm_unreachable.
inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Variants without the xorsign_abs modifier (fmax/fmin siblings kept side
  // by side).
  case Intrinsic::nvvm_fmax_d:
  case Intrinsic::nvvm_fmin_d:
  case Intrinsic::nvvm_fmax_f:
  case Intrinsic::nvvm_fmin_f:
  case Intrinsic::nvvm_fmax_ftz_f:
  case Intrinsic::nvvm_fmin_ftz_f:
  case Intrinsic::nvvm_fmax_nan_f:
  case Intrinsic::nvvm_fmin_nan_f:
  case Intrinsic::nvvm_fmax_ftz_nan_f:
  case Intrinsic::nvvm_fmin_ftz_nan_f:
    return false;

  // Variants with the xorsign_abs modifier.
  case Intrinsic::nvvm_fmax_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_xorsign_abs_f:
  case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
  case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
  case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
    return true;
  }
  llvm_unreachable("Checking XorSignAbs flag for invalid fmin/fmax intrinsic");
  return false;
}

} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICUTILS_H
139 changes: 136 additions & 3 deletions llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1689,6 +1689,28 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::x86_avx512_cvttsd2usi64:
return !Call->isStrictFP();

// NVVM FMax intrinsics
case Intrinsic::nvvm_fmax_d:
case Intrinsic::nvvm_fmax_f:
case Intrinsic::nvvm_fmax_ftz_f:
case Intrinsic::nvvm_fmax_ftz_nan_f:
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
case Intrinsic::nvvm_fmax_nan_f:
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmax_xorsign_abs_f:

// NVVM FMin intrinsics
case Intrinsic::nvvm_fmin_d:
case Intrinsic::nvvm_fmin_f:
case Intrinsic::nvvm_fmin_ftz_f:
case Intrinsic::nvvm_fmin_ftz_nan_f:
case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
case Intrinsic::nvvm_fmin_nan_f:
case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmin_xorsign_abs_f:

// NVVM float/double to int32/uint32 conversion intrinsics
case Intrinsic::nvvm_f2i_rm:
case Intrinsic::nvvm_f2i_rn:
Expand Down Expand Up @@ -2431,9 +2453,10 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
if (U.isNaN())
return ConstantInt::get(Ty, 0);

APFloat::roundingMode RMode = nvvm::IntrinsicGetRoundingMode(IntrinsicID);
bool IsFTZ = nvvm::IntrinsicShouldFTZ(IntrinsicID);
bool IsSigned = nvvm::IntrinsicConvertsToSignedInteger(IntrinsicID);
APFloat::roundingMode RMode =
nvvm::GetFPToIntegerRoundingMode(IntrinsicID);
bool IsFTZ = nvvm::FPToIntegerIntrinsicShouldFTZ(IntrinsicID);
bool IsSigned = nvvm::FPToIntegerIntrinsicResultIsSigned(IntrinsicID);

APSInt ResInt(Ty->getIntegerBitWidth(), !IsSigned);
auto FloatToRound = IsFTZ ? FTZPreserveSign(U) : U;
Expand Down Expand Up @@ -2892,12 +2915,49 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
case Intrinsic::minnum:
case Intrinsic::maximum:
case Intrinsic::minimum:
case Intrinsic::nvvm_fmax_d:
case Intrinsic::nvvm_fmin_d:
// If one argument is undef, return the other argument.
if (IsOp0Undef)
return Operands[1];
if (IsOp1Undef)
return Operands[0];
break;

case Intrinsic::nvvm_fmax_f:
case Intrinsic::nvvm_fmax_ftz_f:
case Intrinsic::nvvm_fmax_ftz_nan_f:
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
case Intrinsic::nvvm_fmax_nan_f:
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmax_xorsign_abs_f:

case Intrinsic::nvvm_fmin_f:
case Intrinsic::nvvm_fmin_ftz_f:
case Intrinsic::nvvm_fmin_ftz_nan_f:
case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
case Intrinsic::nvvm_fmin_nan_f:
case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmin_xorsign_abs_f:
// If one arg is undef, the other arg can be returned only if it is
// constant, as we may need to flush it to sign-preserving zero or
// canonicalize the NaN.
if (!IsOp0Undef && !IsOp1Undef)
break;
if (auto *Op = dyn_cast<ConstantFP>(Operands[IsOp0Undef ? 1 : 0])) {
if (Op->isNaN()) {
APInt NVCanonicalNaN(32, 0x7fffffff);
return ConstantFP::get(
Ty, APFloat(Ty->getFltSemantics(), NVCanonicalNaN));
}
if (nvvm::FMinFMaxShouldFTZ(IntrinsicID))
return ConstantFP::get(Ty, FTZPreserveSign(Op->getValueAPF()));
else
return Op;
}
break;
}
}

Expand Down Expand Up @@ -2955,6 +3015,79 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
return ConstantFP::get(Ty->getContext(), minimum(Op1V, Op2V));
case Intrinsic::maximum:
return ConstantFP::get(Ty->getContext(), maximum(Op1V, Op2V));

case Intrinsic::nvvm_fmax_d:
case Intrinsic::nvvm_fmax_f:
case Intrinsic::nvvm_fmax_ftz_f:
case Intrinsic::nvvm_fmax_ftz_nan_f:
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
case Intrinsic::nvvm_fmax_nan_f:
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmax_xorsign_abs_f:

case Intrinsic::nvvm_fmin_d:
case Intrinsic::nvvm_fmin_f:
case Intrinsic::nvvm_fmin_ftz_f:
case Intrinsic::nvvm_fmin_ftz_nan_f:
case Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmin_ftz_xorsign_abs_f:
case Intrinsic::nvvm_fmin_nan_f:
case Intrinsic::nvvm_fmin_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmin_xorsign_abs_f: {

bool ShouldCanonicalizeNaNs = !(IntrinsicID == Intrinsic::nvvm_fmax_d ||
IntrinsicID == Intrinsic::nvvm_fmin_d);
bool IsFTZ = nvvm::FMinFMaxShouldFTZ(IntrinsicID);
bool IsNaNPropagating = nvvm::FMinFMaxPropagatesNaNs(IntrinsicID);
bool IsXorSignAbs = nvvm::FMinFMaxIsXorSignAbs(IntrinsicID);

APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;

bool XorSign = false;
if (IsXorSignAbs) {
XorSign = A.isNegative() ^ B.isNegative();
A = abs(A);
B = abs(B);
}

bool IsFMax = false;
switch (IntrinsicID) {
case Intrinsic::nvvm_fmax_d:
case Intrinsic::nvvm_fmax_f:
case Intrinsic::nvvm_fmax_ftz_f:
case Intrinsic::nvvm_fmax_ftz_nan_f:
case Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmax_ftz_xorsign_abs_f:
case Intrinsic::nvvm_fmax_nan_f:
case Intrinsic::nvvm_fmax_nan_xorsign_abs_f:
case Intrinsic::nvvm_fmax_xorsign_abs_f:
IsFMax = true;
break;
}
APFloat Res = IsFMax ? maximum(A, B) : minimum(A, B);

if (ShouldCanonicalizeNaNs) {
APFloat NVCanonicalNaN(Res.getSemantics(), APInt(32, 0x7fffffff));
if (A.isNaN() && B.isNaN())
return ConstantFP::get(Ty, NVCanonicalNaN);
else if (IsNaNPropagating && (A.isNaN() || B.isNaN()))
return ConstantFP::get(Ty, NVCanonicalNaN);
}

if (A.isNaN() && B.isNaN())
return Operands[1];
else if (A.isNaN())
Res = B;
else if (B.isNaN())
Res = A;

if (IsXorSignAbs && XorSign != Res.isNegative())
Res.changeSign();

return ConstantFP::get(Ty->getContext(), Res);
}
}

if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
Expand Down
Loading
Loading