Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 219 additions & 78 deletions polygeist/tools/cgeist/Lib/CGExpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//

#include "Lib/TypeUtils.h"
#include "clang-mlir.h"
#include "utils.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down Expand Up @@ -2185,6 +2186,16 @@ ValueCategory MLIRScanner::EmitPromoted(Expr *E, QualType PromotionType) {
return Res;
}

ValueCategory MLIRScanner::CastToVoidPtr(ValueCategory Ptr) {
assert(mlirclang::isPointerOrMemRefTy(Ptr.val.getType()) &&
"Expecting pointer or memref");

const auto DestType =
mlirclang::getPtrTyWithNewType(Ptr.val.getType(), builder.getI8Type());

return Ptr.BitCast(builder, loc, DestType);
}

ValueCategory MLIRScanner::EmitPromotedValue(Location Loc, ValueCategory Result,
QualType PromotionType) {
return Result.FPExt(builder, Loc, Glob.getTypes().getMLIRType(PromotionType));
Expand Down Expand Up @@ -2576,6 +2587,12 @@ BinOpInfo MLIRScanner::EmitBinOps(BinaryOperator *E, QualType PromotionType) {
return {LHS, RHS, Ty, Opcode, E};
}

static void informNoOverflowCheck(LangOptions::SignedOverflowBehaviorTy SOB,
llvm::StringRef OpName) {
if (SOB != clang::LangOptions::SOB_Defined)
llvm::errs() << "Not emitting overflow-checked " << OpName << "\n";
}

ValueCategory MLIRScanner::EmitBinMul(const BinOpInfo &Info) {
auto lhs_v = Info.getLHS().getValue(builder);
auto rhs_v = Info.getRHS().getValue(builder);
Expand Down Expand Up @@ -2620,91 +2637,215 @@ ValueCategory MLIRScanner::EmitBinRem(const BinOpInfo &Info) {
}
}

/// Casts index of subindex operation conditionally.
static Optional<Value> castSubIndexOpIndex(OpBuilder &Builder, Location Loc,
ValueCategory Pointer,
ValueRange IdxList, bool IsSigned) {
if (Pointer.val.getType().isa<MemRefType>()) {
assert(IdxList.size() == 1 && "SubIndexOp accepts just an index");
return ValueCategory(IdxList.front(), false)
.IntCast(Builder, Loc, Builder.getIndexType(), IsSigned)
.val;
}
return llvm::None;
}

ValueCategory MLIRScanner::EmitCheckedInBoundsPtrOffsetOp(mlir::Type ElemTy,
ValueCategory Pointer,
ValueRange IdxList,
bool IsSigned, bool) {
assert(mlirclang::isPointerOrMemRefTy(Pointer.val.getType()) &&
"Expecting pointer or MemRef");
assert(std::all_of(IdxList.begin(), IdxList.end(),
[](mlir::Value Val) {
return Val.getType().isa<IntegerType>();
}) &&
"Expecting indices list");

if (Optional<Value> NewValue =
castSubIndexOpIndex(builder, loc, Pointer, IdxList, IsSigned))
IdxList = *NewValue;

return Pointer.InBoundsGEPOrSubIndex(builder, loc, ElemTy, IdxList);
}

ValueCategory MLIRScanner::EmitPointerArithmetic(const BinOpInfo &Info) {
const auto *Expr = cast<BinaryOperator>(Info.getExpr());

ValueCategory Pointer = Info.getLHS();
auto *PointerOperand = Expr->getLHS();
ValueCategory Index = Info.getRHS();
auto *IndexOperand = Expr->getRHS();

const auto Opcode = Info.getOpcode();
const auto IsSubtraction =
Opcode == clang::BO_Sub || Opcode == clang::BO_SubAssign;

assert((!IsSubtraction ||
mlirclang::isPointerOrMemRefTy(Pointer.val.getType())) &&
"The LHS is always a pointer in a subtraction");

if (!mlirclang::isPointerOrMemRefTy(Pointer.val.getType())) {
std::swap(Pointer, Index);
std::swap(PointerOperand, IndexOperand);
}

assert(Index.val.getType().isa<IntegerType>() && "Expecting integer type");

auto PtrTy = Pointer.val.getType();

assert(mlirclang::isPointerOrMemRefTy(PtrTy) && "Expecting pointer type");

auto &CGM = Glob.getCGM();

// Some versions of glibc and gcc use idioms (particularly in their malloc
// routines) that add a pointer-sized integer (known to be a pointer
// value) to a null pointer in order to cast the value back to an integer
// or as part of a pointer alignment algorithm. This is undefined
// behavior, but we'd like to be able to compile programs that use it.
//
// Normally, we'd generate a GEP with a null-pointer base here in response
// to that code, but it's also UB to dereference a pointer created that
// way. Instead (as an acknowledged hack to tolerate the idiom) we will
// generate a direct cast of the integer value to a pointer.
//
// The idiom (p = nullptr + N) is not met if any of the following are
// true:
//
// The operation is subtraction.
// The index is not pointer-sized.
// The pointer type is not byte-sized.
//
if (BinaryOperator::isNullPointerArithmeticExtension(
CGM.getContext(), Opcode, PointerOperand, IndexOperand)) {
return Index.IntToPtr(builder, loc, PtrTy);
}

auto &DL = CGM.getDataLayout();
const unsigned IndexTypeSize = DL.getIndexTypeSizeInBits(
CGM.getTypes().ConvertType(PointerOperand->getType()));
const auto IsSigned =
IndexOperand->getType()->isSignedIntegerOrEnumerationType();
const unsigned Width = Index.val.getType().getIntOrFloatBitWidth();
if (Width != IndexTypeSize) {
// Zero-extend or sign-extend the pointer value according to
// whether the index is signed or not.
Index = Index.IntCast(builder, loc, builder.getIntegerType(IndexTypeSize),
IsSigned);
}

// If this is subtraction, negate the index.
if (IsSubtraction)
Index = Index.Neg(builder, loc);

const auto *PointerType =
PointerOperand->getType()->getAs<clang::PointerType>();

assert(PointerType && "Not pointer type");

QualType ElementType = PointerType->getPointeeType();
assert(!CGM.getContext().getAsVariableArrayType(ElementType) &&
"Not implemented yet");

// Explicitly handle GNU void* and function pointer arithmetic extensions.
// The GNU void* casts amount to no-ops since our void* type is i8*, but
// this is future proof.
if (ElementType->isVoidType() || ElementType->isFunctionType()) {
assert(PtrTy.isa<LLVM::LLVMPointerType>() && "Expecting pointer type");
auto Result = CastToVoidPtr(Pointer);
Result = Result.GEP(builder, loc, builder.getI8Type(), Index.val);
return Result.BitCast(builder, loc, Pointer.val.getType());
}

auto ElemTy = Glob.getTypes().getMLIRType(ElementType);
if (CGM.getLangOpts().isSignedOverflowDefined()) {
if (Optional<Value> NewIndex =
castSubIndexOpIndex(builder, loc, Pointer, Index.val, IsSigned))
Index.val = *NewIndex;
return Pointer.GEPOrSubIndex(builder, loc, ElemTy, Index.val);
}

return EmitCheckedInBoundsPtrOffsetOp(ElemTy, Pointer, Index.val, IsSigned,
IsSubtraction);
}

ValueCategory MLIRScanner::EmitBinAdd(const BinOpInfo &Info) {
auto lhs_v = Info.getLHS().getValue(builder);
auto rhs_v = Info.getRHS().getValue(builder);
if (lhs_v.getType().isa<mlir::FloatType>()) {
return ValueCategory(builder.create<AddFOp>(loc, lhs_v, rhs_v),
/*isReference*/ false);
} else if (auto mt = lhs_v.getType().dyn_cast<mlir::MemRefType>()) {
auto shape = std::vector<int64_t>(mt.getShape());
shape[0] = -1;
auto mt0 =
mlir::MemRefType::get(shape, mt.getElementType(),
MemRefLayoutAttrInterface(), mt.getMemorySpace());
auto ptradd = rhs_v;
ptradd = castToIndex(loc, ptradd);
return ValueCategory(
builder.create<polygeist::SubIndexOp>(loc, mt0, lhs_v, ptradd),
/*isReference*/ false);
} else if (auto pt =
lhs_v.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
return ValueCategory(builder.create<LLVM::GEPOp>(
loc, pt, lhs_v, std::vector<mlir::Value>({rhs_v})),
/*isReference*/ false);
} else {
if (auto lhs_c = lhs_v.getDefiningOp<ConstantIntOp>()) {
if (auto rhs_c = rhs_v.getDefiningOp<ConstantIntOp>()) {
return ValueCategory(
builder.create<arith::ConstantIntOp>(
loc, lhs_c.value() + rhs_c.value(), lhs_c.getType()),
false);
}
}
return ValueCategory(builder.create<AddIOp>(loc, lhs_v, rhs_v),
/*isReference*/ false);
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
const auto LHS = Info.getLHS();
const auto RHS = Info.getRHS().val;

if (mlirclang::isPointerOrMemRefTy(LHS.val.getType()) ||
mlirclang::isPointerOrMemRefTy(RHS.getType())) {
loc = Loc;
return EmitPointerArithmetic(Info);
}

if (Info.getType()->isSignedIntegerOrEnumerationType()) {
informNoOverflowCheck(
Glob.getCGM().getLangOpts().getSignedOverflowBehavior(), "add");
return LHS.Add(builder, Loc, RHS);
}

assert(!Info.getType()->isConstantMatrixType() && "Not yet implemented");

if (mlirclang::isFPOrFPVectorTy(LHS.val.getType()))
return LHS.FAdd(builder, Loc, RHS);
return LHS.Add(builder, Loc, RHS);
}

ValueCategory MLIRScanner::EmitBinSub(const BinOpInfo &Info) {
auto lhs_v = Info.getLHS().getValue(builder);
auto rhs_v = Info.getRHS().getValue(builder);
if (auto mt = lhs_v.getType().dyn_cast<mlir::MemRefType>()) {
lhs_v = builder.create<polygeist::Memref2PointerOp>(
loc,
LLVM::LLVMPointerType::get(mt.getElementType(),
mt.getMemorySpaceAsInt()),
lhs_v);
}
if (auto mt = rhs_v.getType().dyn_cast<mlir::MemRefType>()) {
rhs_v = builder.create<polygeist::Memref2PointerOp>(
loc,
LLVM::LLVMPointerType::get(mt.getElementType(),
mt.getMemorySpaceAsInt()),
rhs_v);
}
if (lhs_v.getType().isa<mlir::FloatType>()) {
assert(rhs_v.getType() == lhs_v.getType());
return ValueCategory(builder.create<SubFOp>(loc, lhs_v, rhs_v),
/*isReference*/ false);
} else if (auto pt =
lhs_v.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
if (auto IT = rhs_v.getType().dyn_cast<mlir::IntegerType>()) {
mlir::Value vals[1] = {builder.create<SubIOp>(
loc, builder.create<ConstantIntOp>(loc, 0, IT.getWidth()), rhs_v)};
return ValueCategory(
builder.create<LLVM::GEPOp>(loc, lhs_v.getType(), lhs_v,
ArrayRef<mlir::Value>(vals)),
false);
const auto Loc = getMLIRLocation(Info.getExpr()->getExprLoc());
auto LHS = Info.getLHS();
auto RHS = Info.getRHS();

// The LHS is always a pointer if either side is.
if (!mlirclang::isPointerOrMemRefTy(LHS.val.getType())) {
if (Info.getType()->isSignedIntegerOrEnumerationType()) {
informNoOverflowCheck(
Glob.getCGM().getLangOpts().getSignedOverflowBehavior(), "sub");
return LHS.Sub(builder, Loc, RHS.val);
}
mlir::Value val = builder.create<SubIOp>(
loc,
builder.create<LLVM::PtrToIntOp>(
loc, Glob.getTypes().getMLIRType(Info.getType()), lhs_v),
builder.create<LLVM::PtrToIntOp>(
loc, Glob.getTypes().getMLIRType(Info.getType()), rhs_v));
val = builder.create<DivSIOp>(
loc, val,
builder.create<IndexCastOp>(
loc, val.getType(),
builder.create<polygeist::TypeSizeOp>(
loc, builder.getIndexType(),
mlir::TypeAttr::get(pt.getElementType()))));
return ValueCategory(val, /*isReference*/ false);
} else {
return ValueCategory(builder.create<SubIOp>(loc, lhs_v, rhs_v),
/*isReference*/ false);
assert(!Info.getType()->isConstantMatrixType() && "Not yet implemented");
if (mlirclang::isFPOrFPVectorTy(LHS.val.getType()))
return LHS.FSub(builder, Loc, RHS.val);
return LHS.Sub(builder, Loc, RHS.val);
}

// If the RHS is not a pointer, then we have normal pointer
// arithmetic.
if (!mlirclang::isPointerOrMemRefTy(RHS.val.getType())) {
loc = Loc;
return EmitPointerArithmetic(Info);
}

// Otherwise, this is a pointer subtraction.

// Do the raw subtraction part.
const auto PtrDiffTy = builder.getIntegerType(
Glob.getCGM().getDataLayout().getPointerSizeInBits());
LHS = LHS.MemRef2Ptr(builder, Loc).PtrToInt(builder, Loc, PtrDiffTy);
RHS = RHS.MemRef2Ptr(builder, Loc).PtrToInt(builder, Loc, PtrDiffTy);
const auto DiffInChars = LHS.Sub(builder, Loc, RHS.val);

// Okay, figure out the element size.
const QualType ElementType =
Info.getExpr()->getLHS()->getType()->getPointeeType();

assert(!Glob.getCGM().getContext().getAsVariableArrayType(ElementType) &&
"Not implemented yet");

const CharUnits ElementSize =
(ElementType->isVoidType() || ElementType->isFunctionType())
? CharUnits::One()
: Glob.getCGM().getContext().getTypeSizeInChars(ElementType);

if (ElementSize.isOne())
return DiffInChars;

const auto Divisor = builder.createOrFold<arith::ConstantIntOp>(
Loc, ElementSize.getQuantity(), PtrDiffTy);

return DiffInChars.ExactSDiv(builder, Loc, Divisor);
}

ValueCategory MLIRScanner::EmitBinShl(const BinOpInfo &Info) {
Expand Down
Loading