Skip to content

Commit 9ac759f

Browse files
authored
[SYCL-MLIR] Improve CodeGen for Add and Sub expressions (#7421)
Enable addition/subtraction of floating point vectors. Make addition of an integer and a pointer commutative. Instead of generating an incorrect `getelementptr` instruction with a `null` pointer argument, generate an `inttoptr` using the index as an argument. Use polygeist.subindex instead of polygeist.memref2pointer+llvm.getelementptr+polygeist.pointer2memref when subtracting an index from a pointer. Cast pointers to void and pointers to functions to i8 pointers before operating. Signed-off-by: Victor Perez <[email protected]>
1 parent 16d5469 commit 9ac759f

File tree

9 files changed

+991
-94
lines changed

9 files changed

+991
-94
lines changed

polygeist/tools/cgeist/Lib/CGExpr.cc

Lines changed: 219 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "Lib/TypeUtils.h"
910
#include "clang-mlir.h"
1011
#include "utils.h"
1112
#include "llvm/ADT/TypeSwitch.h"
@@ -2185,6 +2186,16 @@ ValueCategory MLIRScanner::EmitPromoted(Expr *E, QualType PromotionType) {
21852186
return Res;
21862187
}
21872188

2189+
/// Bit-casts \p Ptr to the type that mlirclang::getPtrTyWithNewType() returns
/// for Ptr's current type combined with i8.
///
/// \p Ptr must hold a pointer or memref value (asserted). The cast is emitted
/// with the scanner's `builder` at its current `loc`.
ValueCategory MLIRScanner::CastToVoidPtr(ValueCategory Ptr) {
2190+
assert(mlirclang::isPointerOrMemRefTy(Ptr.val.getType()) &&
2191+
"Expecting pointer or memref");
2192+
2193+
const auto DestType =
2194+
mlirclang::getPtrTyWithNewType(Ptr.val.getType(), builder.getI8Type());
2195+
2196+
return Ptr.BitCast(builder, loc, DestType);
2197+
}
2198+
21882199
ValueCategory MLIRScanner::EmitPromotedValue(Location Loc, ValueCategory Result,
21892200
QualType PromotionType) {
21902201
return Result.FPExt(builder, Loc, Glob.getTypes().getMLIRType(PromotionType));
@@ -2576,6 +2587,12 @@ BinOpInfo MLIRScanner::EmitBinOps(BinaryOperator *E, QualType PromotionType) {
25762587
return {LHS, RHS, Ty, Opcode, E};
25772588
}
25782589

2590+
/// Writes "Not emitting overflow-checked <OpName>" to llvm::errs() whenever
/// \p SOB is anything other than SOB_Defined; does nothing otherwise.
///
/// \param SOB    the signed-overflow behavior selected in the language options.
/// \param OpName name of the operation mentioned in the message (e.g. "add").
static void informNoOverflowCheck(LangOptions::SignedOverflowBehaviorTy SOB,
2591+
llvm::StringRef OpName) {
2592+
if (SOB != clang::LangOptions::SOB_Defined)
2593+
llvm::errs() << "Not emitting overflow-checked " << OpName << "\n";
2594+
}
2595+
25792596
ValueCategory MLIRScanner::EmitBinMul(const BinOpInfo &Info) {
25802597
auto lhs_v = Info.getLHS().getValue(builder);
25812598
auto rhs_v = Info.getRHS().getValue(builder);
@@ -2620,91 +2637,215 @@ ValueCategory MLIRScanner::EmitBinRem(const BinOpInfo &Info) {
26202637
}
26212638
}
26222639

2640+
/// Casts index of subindex operation conditionally.
///
/// When \p Pointer holds a value of MemRefType, \p IdxList must contain exactly
/// one value (asserted); that value is integer-cast to the builder's index
/// type, honouring \p IsSigned, and the cast result is returned.
/// For any other type of \p Pointer, llvm::None is returned and this function
/// emits nothing.
static Optional<Value> castSubIndexOpIndex(OpBuilder &Builder, Location Loc,
2641+
ValueCategory Pointer,
2642+
ValueRange IdxList, bool IsSigned) {
2643+
if (Pointer.val.getType().isa<MemRefType>()) {
2644+
assert(IdxList.size() == 1 && "SubIndexOp accepts just an index");
2645+
return ValueCategory(IdxList.front(), false)
2646+
.IntCast(Builder, Loc, Builder.getIndexType(), IsSigned)
2647+
.val;
2648+
}
2649+
return llvm::None;
2650+
}
2652+
2653+
/// Emits an in-bounds offset of \p Pointer by the indices in \p IdxList, using
/// \p ElemTy as the element type, via ValueCategory::InBoundsGEPOrSubIndex().
///
/// \p Pointer must hold a pointer or memref value and every index must have
/// integer type (both asserted). When \p Pointer is a memref, the single index
/// is first cast to `index` type by castSubIndexOpIndex(), honouring
/// \p IsSigned. The trailing unnamed `bool` parameter is accepted but not used
/// by this function.
ValueCategory MLIRScanner::EmitCheckedInBoundsPtrOffsetOp(mlir::Type ElemTy,
                                                          ValueCategory Pointer,
                                                          ValueRange IdxList,
                                                          bool IsSigned, bool) {
  assert(mlirclang::isPointerOrMemRefTy(Pointer.val.getType()) &&
         "Expecting pointer or MemRef");
  assert(std::all_of(IdxList.begin(), IdxList.end(),
                     [](mlir::Value Val) {
                       return Val.getType().isa<IntegerType>();
                     }) &&
         "Expecting indices list");

  // A ValueRange built from a single Value only stores a pointer to that
  // Value. The cast result therefore has to outlive the use of IdxList below:
  // keep it at function scope instead of declaring it in the `if` condition,
  // where it would be destroyed at the end of the `if` statement and leave
  // IdxList dangling.
  const Optional<Value> CastIndex =
      castSubIndexOpIndex(builder, loc, Pointer, IdxList, IsSigned);
  if (CastIndex)
    IdxList = *CastIndex;

  return Pointer.InBoundsGEPOrSubIndex(builder, loc, ElemTy, IdxList);
}
2671+
2672+
/// Emits `pointer + index`, `index + pointer` or `pointer - index` for the
/// binary operator described by \p Info.
///
/// In order, the function:
///  - swaps the operands when the LHS is not a pointer/memref, so that either
///    operand order is accepted for addition (a subtraction must already have
///    the pointer on the LHS; this is asserted);
///  - lowers the null-pointer arithmetic idiom to an integer-to-pointer cast
///    of the index;
///  - casts the index to the index width of the pointer type and negates it
///    for a subtraction;
///  - offsets pointers to void or to functions through an i8 pointer and
///    casts the result back to the original pointer type;
///  - otherwise emits a GEP/subindex, using the in-bounds variant unless the
///    language options define signed overflow.
/// Variable-length array element types are not implemented (asserted).
ValueCategory MLIRScanner::EmitPointerArithmetic(const BinOpInfo &Info) {
2673+
const auto *Expr = cast<BinaryOperator>(Info.getExpr());
2674+
2675+
ValueCategory Pointer = Info.getLHS();
2676+
auto *PointerOperand = Expr->getLHS();
2677+
ValueCategory Index = Info.getRHS();
2678+
auto *IndexOperand = Expr->getRHS();
2679+
2680+
const auto Opcode = Info.getOpcode();
2681+
const auto IsSubtraction =
2682+
Opcode == clang::BO_Sub || Opcode == clang::BO_SubAssign;
2683+
2684+
assert((!IsSubtraction ||
2685+
mlirclang::isPointerOrMemRefTy(Pointer.val.getType())) &&
2686+
"The LHS is always a pointer in a subtraction");
2687+
2688+
if (!mlirclang::isPointerOrMemRefTy(Pointer.val.getType())) {
2689+
std::swap(Pointer, Index);
2690+
std::swap(PointerOperand, IndexOperand);
2691+
}
2692+
2693+
assert(Index.val.getType().isa<IntegerType>() && "Expecting integer type");
2694+
2695+
auto PtrTy = Pointer.val.getType();
2696+
2697+
assert(mlirclang::isPointerOrMemRefTy(PtrTy) && "Expecting pointer type");
2698+
2699+
auto &CGM = Glob.getCGM();
2700+
2701+
// Some versions of glibc and gcc use idioms (particularly in their malloc
2702+
// routines) that add a pointer-sized integer (known to be a pointer
2703+
// value) to a null pointer in order to cast the value back to an integer
2704+
// or as part of a pointer alignment algorithm. This is undefined
2705+
// behavior, but we'd like to be able to compile programs that use it.
2706+
//
2707+
// Normally, we'd generate a GEP with a null-pointer base here in response
2708+
// to that code, but it's also UB to dereference a pointer created that
2709+
// way. Instead (as an acknowledged hack to tolerate the idiom) we will
2710+
// generate a direct cast of the integer value to a pointer.
2711+
//
2712+
// The idiom (p = nullptr + N) is not met if any of the following are
2713+
// true:
2714+
//
2715+
// The operation is subtraction.
2716+
// The index is not pointer-sized.
2717+
// The pointer type is not byte-sized.
2718+
//
2719+
if (BinaryOperator::isNullPointerArithmeticExtension(
2720+
CGM.getContext(), Opcode, PointerOperand, IndexOperand)) {
2721+
return Index.IntToPtr(builder, loc, PtrTy);
2722+
}
2723+
2724+
auto &DL = CGM.getDataLayout();
2725+
const unsigned IndexTypeSize = DL.getIndexTypeSizeInBits(
2726+
CGM.getTypes().ConvertType(PointerOperand->getType()));
2727+
const auto IsSigned =
2728+
IndexOperand->getType()->isSignedIntegerOrEnumerationType();
2729+
const unsigned Width = Index.val.getType().getIntOrFloatBitWidth();
2730+
if (Width != IndexTypeSize) {
2731+
// Cast the index to the pointer's index width, extending it according to
2732+
// whether the index is signed or not.
2733+
Index = Index.IntCast(builder, loc, builder.getIntegerType(IndexTypeSize),
2734+
IsSigned);
2735+
}
2736+
2737+
// If this is subtraction, negate the index.
2738+
if (IsSubtraction)
2739+
Index = Index.Neg(builder, loc);
2740+
2741+
const auto *PointerType =
2742+
PointerOperand->getType()->getAs<clang::PointerType>();
2743+
2744+
assert(PointerType && "Not pointer type");
2745+
2746+
QualType ElementType = PointerType->getPointeeType();
2747+
assert(!CGM.getContext().getAsVariableArrayType(ElementType) &&
2748+
"Not implemented yet");
2749+
2750+
// Explicitly handle GNU void* and function pointer arithmetic extensions.
2751+
// The GNU void* casts amount to no-ops since our void* type is i8*, but
2752+
// this is future proof.
2753+
if (ElementType->isVoidType() || ElementType->isFunctionType()) {
2754+
assert(PtrTy.isa<LLVM::LLVMPointerType>() && "Expecting pointer type");
2755+
auto Result = CastToVoidPtr(Pointer);
2756+
Result = Result.GEP(builder, loc, builder.getI8Type(), Index.val);
2757+
return Result.BitCast(builder, loc, Pointer.val.getType());
2758+
}
2759+
2760+
auto ElemTy = Glob.getTypes().getMLIRType(ElementType);
2761+
if (CGM.getLangOpts().isSignedOverflowDefined()) {
2762+
if (Optional<Value> NewIndex =
2763+
castSubIndexOpIndex(builder, loc, Pointer, Index.val, IsSigned))
2764+
Index.val = *NewIndex;
2765+
return Pointer.GEPOrSubIndex(builder, loc, ElemTy, Index.val);
2766+
}
2767+
2768+
return EmitCheckedInBoundsPtrOffsetOp(ElemTy, Pointer, Index.val, IsSigned,
2769+
IsSubtraction);
2770+
}
2771+
26232772
/// Emits the addition described by \p Info.
///
/// NOTE: this listing is a diff. Lines whose number marker ends in '-' belong
/// to the removed implementation, lines whose marker ends in '+' to its
/// replacement; the description below covers the replacement.
///
/// If either operand is a pointer or memref, the scanner's `loc` is set to the
/// expression's location and the work is delegated to EmitPointerArithmetic().
/// For a signed integer (or enumeration) result type, informNoOverflowCheck()
/// is called and the sum is emitted with Add. Otherwise FAdd is used when
/// mlirclang::isFPOrFPVectorTy() accepts the LHS type, and Add for anything
/// else. Constant matrix types are not implemented (asserted).
ValueCategory MLIRScanner::EmitBinAdd(const BinOpInfo &Info) {
2624-
auto lhs_v = Info.getLHS().getValue(builder);
2625-
auto rhs_v = Info.getRHS().getValue(builder);
2626-
if (lhs_v.getType().isa<mlir::FloatType>()) {
2627-
return ValueCategory(builder.create<AddFOp>(loc, lhs_v, rhs_v),
2628-
/*isReference*/ false);
2629-
} else if (auto mt = lhs_v.getType().dyn_cast<mlir::MemRefType>()) {
2630-
auto shape = std::vector<int64_t>(mt.getShape());
2631-
shape[0] = -1;
2632-
auto mt0 =
2633-
mlir::MemRefType::get(shape, mt.getElementType(),
2634-
MemRefLayoutAttrInterface(), mt.getMemorySpace());
2635-
auto ptradd = rhs_v;
2636-
ptradd = castToIndex(loc, ptradd);
2637-
return ValueCategory(
2638-
builder.create<polygeist::SubIndexOp>(loc, mt0, lhs_v, ptradd),
2639-
/*isReference*/ false);
2640-
} else if (auto pt =
2641-
lhs_v.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
2642-
return ValueCategory(builder.create<LLVM::GEPOp>(
2643-
loc, pt, lhs_v, std::vector<mlir::Value>({rhs_v})),
2644-
/*isReference*/ false);
2645-
} else {
2646-
if (auto lhs_c = lhs_v.getDefiningOp<ConstantIntOp>()) {
2647-
if (auto rhs_c = rhs_v.getDefiningOp<ConstantIntOp>()) {
2648-
return ValueCategory(
2649-
builder.create<arith::ConstantIntOp>(
2650-
loc, lhs_c.value() + rhs_c.value(), lhs_c.getType()),
2651-
false);
2652-
}
2653-
}
2654-
return ValueCategory(builder.create<AddIOp>(loc, lhs_v, rhs_v),
2655-
/*isReference*/ false);
2773+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2774+
const auto LHS = Info.getLHS();
2775+
const auto RHS = Info.getRHS().val;
2776+
2777+
if (mlirclang::isPointerOrMemRefTy(LHS.val.getType()) ||
2778+
mlirclang::isPointerOrMemRefTy(RHS.getType())) {
2779+
loc = Loc;
2780+
return EmitPointerArithmetic(Info);
2781+
}
2782+
2783+
if (Info.getType()->isSignedIntegerOrEnumerationType()) {
2784+
informNoOverflowCheck(
2785+
Glob.getCGM().getLangOpts().getSignedOverflowBehavior(), "add");
2786+
return LHS.Add(builder, Loc, RHS);
26562787
}
2788+
2789+
assert(!Info.getType()->isConstantMatrixType() && "Not yet implemented");
2790+
2791+
if (mlirclang::isFPOrFPVectorTy(LHS.val.getType()))
2792+
return LHS.FAdd(builder, Loc, RHS);
2793+
return LHS.Add(builder, Loc, RHS);
26572794
}
26582795

26592796
/// Emits the subtraction described by \p Info.
///
/// NOTE: this listing is a diff. Lines whose number marker ends in '-' belong
/// to the removed implementation, lines whose marker ends in '+' to its
/// replacement; the description below covers the replacement.
///
/// Three cases are handled:
///  - LHS is not a pointer/memref: plain arithmetic. A signed integer (or
///    enumeration) result type goes through informNoOverflowCheck() and Sub;
///    otherwise FSub is used when mlirclang::isFPOrFPVectorTy() accepts the
///    LHS type and Sub for anything else. Constant matrix types are not
///    implemented (asserted).
///  - LHS is a pointer/memref and RHS is not: the scanner's `loc` is set to
///    the expression's location and EmitPointerArithmetic() does the work.
///  - Both are pointers/memrefs: each side is converted to an integer of
///    pointer width, the two are subtracted, and the difference is divided
///    with ExactSDiv by the pointee size in chars. The division is skipped
///    when that size is one; void and function pointees count as size one.
///    Variable-length array pointees are not implemented (asserted).
ValueCategory MLIRScanner::EmitBinSub(const BinOpInfo &Info) {
2660-
auto lhs_v = Info.getLHS().getValue(builder);
2661-
auto rhs_v = Info.getRHS().getValue(builder);
2662-
if (auto mt = lhs_v.getType().dyn_cast<mlir::MemRefType>()) {
2663-
lhs_v = builder.create<polygeist::Memref2PointerOp>(
2664-
loc,
2665-
LLVM::LLVMPointerType::get(mt.getElementType(),
2666-
mt.getMemorySpaceAsInt()),
2667-
lhs_v);
2668-
}
2669-
if (auto mt = rhs_v.getType().dyn_cast<mlir::MemRefType>()) {
2670-
rhs_v = builder.create<polygeist::Memref2PointerOp>(
2671-
loc,
2672-
LLVM::LLVMPointerType::get(mt.getElementType(),
2673-
mt.getMemorySpaceAsInt()),
2674-
rhs_v);
2675-
}
2676-
if (lhs_v.getType().isa<mlir::FloatType>()) {
2677-
assert(rhs_v.getType() == lhs_v.getType());
2678-
return ValueCategory(builder.create<SubFOp>(loc, lhs_v, rhs_v),
2679-
/*isReference*/ false);
2680-
} else if (auto pt =
2681-
lhs_v.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
2682-
if (auto IT = rhs_v.getType().dyn_cast<mlir::IntegerType>()) {
2683-
mlir::Value vals[1] = {builder.create<SubIOp>(
2684-
loc, builder.create<ConstantIntOp>(loc, 0, IT.getWidth()), rhs_v)};
2685-
return ValueCategory(
2686-
builder.create<LLVM::GEPOp>(loc, lhs_v.getType(), lhs_v,
2687-
ArrayRef<mlir::Value>(vals)),
2688-
false);
2797+
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
2798+
auto LHS = Info.getLHS();
2799+
auto RHS = Info.getRHS();
2800+
2801+
// The LHS is always a pointer if either side is.
2802+
if (!mlirclang::isPointerOrMemRefTy(LHS.val.getType())) {
2803+
if (Info.getType()->isSignedIntegerOrEnumerationType()) {
2804+
informNoOverflowCheck(
2805+
Glob.getCGM().getLangOpts().getSignedOverflowBehavior(), "sub");
2806+
return LHS.Sub(builder, Loc, RHS.val);
26892807
}
2690-
mlir::Value val = builder.create<SubIOp>(
2691-
loc,
2692-
builder.create<LLVM::PtrToIntOp>(
2693-
loc, Glob.getTypes().getMLIRType(Info.getType()), lhs_v),
2694-
builder.create<LLVM::PtrToIntOp>(
2695-
loc, Glob.getTypes().getMLIRType(Info.getType()), rhs_v));
2696-
val = builder.create<DivSIOp>(
2697-
loc, val,
2698-
builder.create<IndexCastOp>(
2699-
loc, val.getType(),
2700-
builder.create<polygeist::TypeSizeOp>(
2701-
loc, builder.getIndexType(),
2702-
mlir::TypeAttr::get(pt.getElementType()))));
2703-
return ValueCategory(val, /*isReference*/ false);
2704-
} else {
2705-
return ValueCategory(builder.create<SubIOp>(loc, lhs_v, rhs_v),
2706-
/*isReference*/ false);
2808+
assert(!Info.getType()->isConstantMatrixType() && "Not yet implemented");
2809+
if (mlirclang::isFPOrFPVectorTy(LHS.val.getType()))
2810+
return LHS.FSub(builder, Loc, RHS.val);
2811+
return LHS.Sub(builder, Loc, RHS.val);
2812+
}
2813+
2814+
// If the RHS is not a pointer, then we have normal pointer
2815+
// arithmetic.
2816+
if (!mlirclang::isPointerOrMemRefTy(RHS.val.getType())) {
2817+
loc = Loc;
2818+
return EmitPointerArithmetic(Info);
27072819
}
2820+
2821+
// Otherwise, this is a pointer subtraction.
2822+
2823+
// Do the raw subtraction part.
2824+
const auto PtrDiffTy = builder.getIntegerType(
2825+
Glob.getCGM().getDataLayout().getPointerSizeInBits());
2826+
LHS = LHS.MemRef2Ptr(builder, Loc).PtrToInt(builder, Loc, PtrDiffTy);
2827+
RHS = RHS.MemRef2Ptr(builder, Loc).PtrToInt(builder, Loc, PtrDiffTy);
2828+
const auto DiffInChars = LHS.Sub(builder, Loc, RHS.val);
2829+
2830+
// Okay, figure out the element size.
2831+
const QualType ElementType =
2832+
Info.getExpr()->getLHS()->getType()->getPointeeType();
2833+
2834+
assert(!Glob.getCGM().getContext().getAsVariableArrayType(ElementType) &&
2835+
"Not implemented yet");
2836+
2837+
const CharUnits ElementSize =
2838+
(ElementType->isVoidType() || ElementType->isFunctionType())
2839+
? CharUnits::One()
2840+
: Glob.getCGM().getContext().getTypeSizeInChars(ElementType);
2841+
2842+
if (ElementSize.isOne())
2843+
return DiffInChars;
2844+
2845+
const auto Divisor = builder.createOrFold<arith::ConstantIntOp>(
2846+
Loc, ElementSize.getQuantity(), PtrDiffTy);
2847+
2848+
return DiffInChars.ExactSDiv(builder, Loc, Divisor);
27082849
}
27092850

27102851
ValueCategory MLIRScanner::EmitBinShl(const BinOpInfo &Info) {

0 commit comments

Comments
 (0)