Skip to content

[NVPTX] Constant-folding for f2i, d2ui, f2ll etc. #118965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 0 additions & 39 deletions llvm/include/llvm/IR/NVVMIntrinsicFlags.h

This file was deleted.

176 changes: 176 additions & 0 deletions llvm/include/llvm/IR/NVVMIntrinsicUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
//===--- NVVMIntrinsicUtils.h -----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This file contains the definitions of the enumerations and flags
/// associated with NVVM Intrinsics, along with some helper functions.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_IR_NVVMINTRINSICUTILS_H
#define LLVM_IR_NVVMINTRINSICUTILS_H

#include <stdint.h>

#include "llvm/ADT/APFloat.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

namespace llvm {
namespace nvvm {

// Reduction operations supported by TMA copies from shared memory to global
// memory, i.e. by the "cp.reduce.async.bulk.tensor.*" family of PTX
// instructions.
//
// The underlying type is fixed to uint8_t and every enumerator has an explicit
// value, so the numeric encoding of each operation is part of this definition.
// NOTE(review): presumably these values are passed around as an immediate
// flag operand of the corresponding intrinsics -- confirm before renumbering.
enum class TMAReductionOp : uint8_t {
ADD = 0,
MIN = 1,
MAX = 2,
INC = 3,
DEC = 4,
AND = 5,
OR = 6,
XOR = 7,
};

/// Returns true if \p IntrinsicID is one of the NVVM float-to-integer
/// conversion intrinsics whose name carries the "_ftz" suffix, and false for
/// every other intrinsic ID.
///
/// Only the f32-source families (f2i, f2ui, f2ll, f2ull) are listed here, once
/// per rounding mode (rm, rn, rp, rz). The constant folder uses the result to
/// decide whether a denormal input is replaced by a same-signed zero before
/// the conversion is evaluated.
///
/// TODO: Consider recording the FTZ, signedness and rounding-mode properties
/// as flags on the intrinsic definitions in TableGen, so these switches do not
/// have to be updated by hand (whack-a-mole style) whenever an NVPTX intrinsic
/// is added or changed. If that is too cumbersome, a static intrinsic->flags
/// table that is populated once and then queried would be a less error-prone
/// alternative.
inline bool IntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Float to i32 / i64 conversion intrinsics:
  case Intrinsic::nvvm_f2i_rm_ftz:
  case Intrinsic::nvvm_f2i_rn_ftz:
  case Intrinsic::nvvm_f2i_rp_ftz:
  case Intrinsic::nvvm_f2i_rz_ftz:

  case Intrinsic::nvvm_f2ui_rm_ftz:
  case Intrinsic::nvvm_f2ui_rn_ftz:
  case Intrinsic::nvvm_f2ui_rp_ftz:
  case Intrinsic::nvvm_f2ui_rz_ftz:

  case Intrinsic::nvvm_f2ll_rm_ftz:
  case Intrinsic::nvvm_f2ll_rn_ftz:
  case Intrinsic::nvvm_f2ll_rp_ftz:
  case Intrinsic::nvvm_f2ll_rz_ftz:

  case Intrinsic::nvvm_f2ull_rm_ftz:
  case Intrinsic::nvvm_f2ull_rn_ftz:
  case Intrinsic::nvvm_f2ull_rp_ftz:
  case Intrinsic::nvvm_f2ull_rz_ftz:
    return true;
  }
  return false;
}

/// Tells whether \p IntrinsicID names an NVVM floating-point-to-integer
/// conversion that produces a signed result.
///
/// That is the case for the f2i, d2i, f2ll and d2ll families (all four
/// rounding modes, with and without "_ftz" where such a variant exists). Any
/// other intrinsic ID -- including the unsigned f2ui/d2ui/f2ull/d2ull
/// conversions -- yields false.
inline bool IntrinsicConvertsToSignedInteger(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // float -> signed 32-bit
  case Intrinsic::nvvm_f2i_rm:
  case Intrinsic::nvvm_f2i_rn:
  case Intrinsic::nvvm_f2i_rp:
  case Intrinsic::nvvm_f2i_rz:
  case Intrinsic::nvvm_f2i_rm_ftz:
  case Intrinsic::nvvm_f2i_rn_ftz:
  case Intrinsic::nvvm_f2i_rp_ftz:
  case Intrinsic::nvvm_f2i_rz_ftz:
  // float -> signed 64-bit
  case Intrinsic::nvvm_f2ll_rm:
  case Intrinsic::nvvm_f2ll_rn:
  case Intrinsic::nvvm_f2ll_rp:
  case Intrinsic::nvvm_f2ll_rz:
  case Intrinsic::nvvm_f2ll_rm_ftz:
  case Intrinsic::nvvm_f2ll_rn_ftz:
  case Intrinsic::nvvm_f2ll_rp_ftz:
  case Intrinsic::nvvm_f2ll_rz_ftz:
  // double -> signed 32-bit
  case Intrinsic::nvvm_d2i_rm:
  case Intrinsic::nvvm_d2i_rn:
  case Intrinsic::nvvm_d2i_rp:
  case Intrinsic::nvvm_d2i_rz:
  // double -> signed 64-bit
  case Intrinsic::nvvm_d2ll_rm:
  case Intrinsic::nvvm_d2ll_rn:
  case Intrinsic::nvvm_d2ll_rp:
  case Intrinsic::nvvm_d2ll_rz:
    return true;
  default:
    break;
  }
  return false;
}

/// Maps an NVVM float/double-to-integer conversion intrinsic to the APFloat
/// rounding mode encoded in its name:
///
///   *_rm -> APFloat::rmTowardNegative
///   *_rn -> APFloat::rmNearestTiesToEven
///   *_rp -> APFloat::rmTowardPositive
///   *_rz -> APFloat::rmTowardZero
///
/// \p IntrinsicID must be one of the f2i/f2ui/d2i/d2ui/f2ll/f2ull/d2ll/d2ull
/// conversions (with or without "_ftz"); any other ID hits llvm_unreachable.
inline APFloat::roundingMode
IntrinsicGetRoundingMode(Intrinsic::ID IntrinsicID) {
  switch (IntrinsicID) {
  // Round toward negative infinity.
  case Intrinsic::nvvm_f2i_rm:
  case Intrinsic::nvvm_f2i_rm_ftz:
  case Intrinsic::nvvm_f2ui_rm:
  case Intrinsic::nvvm_f2ui_rm_ftz:
  case Intrinsic::nvvm_f2ll_rm:
  case Intrinsic::nvvm_f2ll_rm_ftz:
  case Intrinsic::nvvm_f2ull_rm:
  case Intrinsic::nvvm_f2ull_rm_ftz:
  case Intrinsic::nvvm_d2i_rm:
  case Intrinsic::nvvm_d2ui_rm:
  case Intrinsic::nvvm_d2ll_rm:
  case Intrinsic::nvvm_d2ull_rm:
    return APFloat::rmTowardNegative;

  // Round to nearest, ties to even.
  case Intrinsic::nvvm_f2i_rn:
  case Intrinsic::nvvm_f2i_rn_ftz:
  case Intrinsic::nvvm_f2ui_rn:
  case Intrinsic::nvvm_f2ui_rn_ftz:
  case Intrinsic::nvvm_f2ll_rn:
  case Intrinsic::nvvm_f2ll_rn_ftz:
  case Intrinsic::nvvm_f2ull_rn:
  case Intrinsic::nvvm_f2ull_rn_ftz:
  case Intrinsic::nvvm_d2i_rn:
  case Intrinsic::nvvm_d2ui_rn:
  case Intrinsic::nvvm_d2ll_rn:
  case Intrinsic::nvvm_d2ull_rn:
    return APFloat::rmNearestTiesToEven;

  // Round toward positive infinity.
  case Intrinsic::nvvm_f2i_rp:
  case Intrinsic::nvvm_f2i_rp_ftz:
  case Intrinsic::nvvm_f2ui_rp:
  case Intrinsic::nvvm_f2ui_rp_ftz:
  case Intrinsic::nvvm_f2ll_rp:
  case Intrinsic::nvvm_f2ll_rp_ftz:
  case Intrinsic::nvvm_f2ull_rp:
  case Intrinsic::nvvm_f2ull_rp_ftz:
  case Intrinsic::nvvm_d2i_rp:
  case Intrinsic::nvvm_d2ui_rp:
  case Intrinsic::nvvm_d2ll_rp:
  case Intrinsic::nvvm_d2ull_rp:
    return APFloat::rmTowardPositive;

  // Round toward zero.
  case Intrinsic::nvvm_f2i_rz:
  case Intrinsic::nvvm_f2i_rz_ftz:
  case Intrinsic::nvvm_f2ui_rz:
  case Intrinsic::nvvm_f2ui_rz_ftz:
  case Intrinsic::nvvm_f2ll_rz:
  case Intrinsic::nvvm_f2ll_rz_ftz:
  case Intrinsic::nvvm_f2ull_rz:
  case Intrinsic::nvvm_f2ull_rz_ftz:
  case Intrinsic::nvvm_d2i_rz:
  case Intrinsic::nvvm_d2ui_rz:
  case Intrinsic::nvvm_d2ll_rz:
  case Intrinsic::nvvm_d2ull_rz:
    return APFloat::rmTowardZero;

  default:
    break;
  }
  // Callers are expected to pass only the conversion intrinsics listed above.
  llvm_unreachable("Invalid f2i/d2i rounding mode intrinsic");
  return APFloat::roundingMode::Invalid;
}

} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICUTILS_H
139 changes: 139 additions & 0 deletions llvm/lib/Analysis/ConstantFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsWebAssembly.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
Expand Down Expand Up @@ -1687,6 +1689,58 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::x86_avx512_cvttsd2usi64:
return !Call->isStrictFP();

// NVVM float/double to int32/uint32 conversion intrinsics
case Intrinsic::nvvm_f2i_rm:
case Intrinsic::nvvm_f2i_rn:
case Intrinsic::nvvm_f2i_rp:
case Intrinsic::nvvm_f2i_rz:
case Intrinsic::nvvm_f2i_rm_ftz:
case Intrinsic::nvvm_f2i_rn_ftz:
case Intrinsic::nvvm_f2i_rp_ftz:
case Intrinsic::nvvm_f2i_rz_ftz:
case Intrinsic::nvvm_f2ui_rm:
case Intrinsic::nvvm_f2ui_rn:
case Intrinsic::nvvm_f2ui_rp:
case Intrinsic::nvvm_f2ui_rz:
case Intrinsic::nvvm_f2ui_rm_ftz:
case Intrinsic::nvvm_f2ui_rn_ftz:
case Intrinsic::nvvm_f2ui_rp_ftz:
case Intrinsic::nvvm_f2ui_rz_ftz:
case Intrinsic::nvvm_d2i_rm:
case Intrinsic::nvvm_d2i_rn:
case Intrinsic::nvvm_d2i_rp:
case Intrinsic::nvvm_d2i_rz:
case Intrinsic::nvvm_d2ui_rm:
case Intrinsic::nvvm_d2ui_rn:
case Intrinsic::nvvm_d2ui_rp:
case Intrinsic::nvvm_d2ui_rz:

// NVVM float/double to int64/uint64 conversion intrinsics
case Intrinsic::nvvm_f2ll_rm:
case Intrinsic::nvvm_f2ll_rn:
case Intrinsic::nvvm_f2ll_rp:
case Intrinsic::nvvm_f2ll_rz:
case Intrinsic::nvvm_f2ll_rm_ftz:
case Intrinsic::nvvm_f2ll_rn_ftz:
case Intrinsic::nvvm_f2ll_rp_ftz:
case Intrinsic::nvvm_f2ll_rz_ftz:
case Intrinsic::nvvm_f2ull_rm:
case Intrinsic::nvvm_f2ull_rn:
case Intrinsic::nvvm_f2ull_rp:
case Intrinsic::nvvm_f2ull_rz:
case Intrinsic::nvvm_f2ull_rm_ftz:
case Intrinsic::nvvm_f2ull_rn_ftz:
case Intrinsic::nvvm_f2ull_rp_ftz:
case Intrinsic::nvvm_f2ull_rz_ftz:
case Intrinsic::nvvm_d2ll_rm:
case Intrinsic::nvvm_d2ll_rn:
case Intrinsic::nvvm_d2ll_rp:
case Intrinsic::nvvm_d2ll_rz:
case Intrinsic::nvvm_d2ull_rm:
case Intrinsic::nvvm_d2ull_rn:
case Intrinsic::nvvm_d2ull_rp:
case Intrinsic::nvvm_d2ull_rz:

// Sign operations are actually bitwise operations, they do not raise
// exceptions even for SNANs.
case Intrinsic::fabs:
Expand Down Expand Up @@ -1849,6 +1903,12 @@ inline bool llvm_fenv_testexcept() {
return false;
}

/// Applies sign-preserving flush-to-zero to \p V.
///
/// If \p V is denormal, returns a zero with the same floating-point semantics
/// and the same sign as \p V; otherwise returns a copy of \p V unchanged.
///
/// The result is returned as a plain (non-const) value: a top-level const on a
/// by-value return type adds no protection for the caller and only prevents
/// the returned temporary from being moved from.
static APFloat FTZPreserveSign(const APFloat &V) {
  if (V.isDenormal())
    return APFloat::getZero(V.getSemantics(), V.isNegative());
  return V;
}

Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V,
Type *Ty) {
llvm_fenv_clearexcept();
Expand Down Expand Up @@ -2309,6 +2369,85 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
return ConstantFP::get(Ty->getContext(), U);
}

// NVVM float/double to signed/unsigned int32/int64 conversions:
switch (IntrinsicID) {
// f2i
case Intrinsic::nvvm_f2i_rm:
case Intrinsic::nvvm_f2i_rn:
case Intrinsic::nvvm_f2i_rp:
case Intrinsic::nvvm_f2i_rz:
case Intrinsic::nvvm_f2i_rm_ftz:
case Intrinsic::nvvm_f2i_rn_ftz:
case Intrinsic::nvvm_f2i_rp_ftz:
case Intrinsic::nvvm_f2i_rz_ftz:
// f2ui
case Intrinsic::nvvm_f2ui_rm:
case Intrinsic::nvvm_f2ui_rn:
case Intrinsic::nvvm_f2ui_rp:
case Intrinsic::nvvm_f2ui_rz:
case Intrinsic::nvvm_f2ui_rm_ftz:
case Intrinsic::nvvm_f2ui_rn_ftz:
case Intrinsic::nvvm_f2ui_rp_ftz:
case Intrinsic::nvvm_f2ui_rz_ftz:
// d2i
case Intrinsic::nvvm_d2i_rm:
case Intrinsic::nvvm_d2i_rn:
case Intrinsic::nvvm_d2i_rp:
case Intrinsic::nvvm_d2i_rz:
// d2ui
case Intrinsic::nvvm_d2ui_rm:
case Intrinsic::nvvm_d2ui_rn:
case Intrinsic::nvvm_d2ui_rp:
case Intrinsic::nvvm_d2ui_rz:
// f2ll
case Intrinsic::nvvm_f2ll_rm:
case Intrinsic::nvvm_f2ll_rn:
case Intrinsic::nvvm_f2ll_rp:
case Intrinsic::nvvm_f2ll_rz:
case Intrinsic::nvvm_f2ll_rm_ftz:
case Intrinsic::nvvm_f2ll_rn_ftz:
case Intrinsic::nvvm_f2ll_rp_ftz:
case Intrinsic::nvvm_f2ll_rz_ftz:
// f2ull
case Intrinsic::nvvm_f2ull_rm:
case Intrinsic::nvvm_f2ull_rn:
case Intrinsic::nvvm_f2ull_rp:
case Intrinsic::nvvm_f2ull_rz:
case Intrinsic::nvvm_f2ull_rm_ftz:
case Intrinsic::nvvm_f2ull_rn_ftz:
case Intrinsic::nvvm_f2ull_rp_ftz:
case Intrinsic::nvvm_f2ull_rz_ftz:
// d2ll
case Intrinsic::nvvm_d2ll_rm:
case Intrinsic::nvvm_d2ll_rn:
case Intrinsic::nvvm_d2ll_rp:
case Intrinsic::nvvm_d2ll_rz:
// d2ull
case Intrinsic::nvvm_d2ull_rm:
case Intrinsic::nvvm_d2ull_rn:
case Intrinsic::nvvm_d2ull_rp:
case Intrinsic::nvvm_d2ull_rz: {
// In float-to-integer conversion, NaN inputs are converted to 0.
if (U.isNaN())
return ConstantInt::get(Ty, 0);

APFloat::roundingMode RMode = nvvm::IntrinsicGetRoundingMode(IntrinsicID);
bool IsFTZ = nvvm::IntrinsicShouldFTZ(IntrinsicID);
bool IsSigned = nvvm::IntrinsicConvertsToSignedInteger(IntrinsicID);

APSInt ResInt(Ty->getIntegerBitWidth(), !IsSigned);
auto FloatToRound = IsFTZ ? FTZPreserveSign(U) : U;

bool IsExact = false;
APFloat::opStatus Status =
FloatToRound.convertToInteger(ResInt, RMode, &IsExact);

if (Status != APFloat::opInvalidOp)
return ConstantInt::get(Ty, ResInt);
return nullptr;
}
}

/// We only fold functions with finite arguments. Folding NaN and inf is
/// likely to be aborted with an exception anyway, and some host libms
/// have known errors raising exceptions.
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/NVVMIntrinsicFlags.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCInst.h"
#include "llvm/MC/MCInstrInfo.h"
Expand Down
Loading
Loading