Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 125 additions & 62 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
LogicalResult
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

auto memRefType = op.getType();
if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with dynamic shape");
Expand All @@ -80,12 +80,48 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
op.getLoc(), "cannot transform alloca with alignment requirement");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
if (op.getType().getRank() == 0 ||
llvm::is_contained(memRefType.getShape(), 0)) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with rank 0 or zero-sized dim");
}

auto convertedTy = getTypeConverter()->convertType(memRefType);
if (!convertedTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert memref type");
}

auto arrayTy = emitc::ArrayType::get(memRefType.getShape(),
memRefType.getElementType());
auto elemTy = memRefType.getElementType();

auto noInit = emitc::OpaqueAttr::get(getContext(), "");
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
auto arrayVar =
rewriter.create<emitc::VariableOp>(op.getLoc(), arrayTy, noInit);

// Build zero indices for the base subscript.
SmallVector<Value> indices;
for (unsigned i = 0; i < memRefType.getRank(); ++i) {
auto zero = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(0));
indices.push_back(zero);
}

auto current = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), emitc::LValueType::get(elemTy), arrayVar.getResult(),
indices);

auto ptrElemTy = emitc::PointerType::get(elemTy);
auto addrOf = rewriter.create<emitc::ApplyOp>(op.getLoc(), ptrElemTy,
rewriter.getStringAttr("&"),
current.getResult());

auto ptrArrayTy = emitc::PointerType::get(arrayTy);
auto casted = rewriter.create<emitc::CastOp>(op.getLoc(), ptrArrayTy,
addrOf.getResult());

rewriter.replaceOp(op, casted.getResult());
return success();
}
};
Expand Down Expand Up @@ -122,24 +158,6 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
return totalSizeBytes.getResult();
}

static emitc::ApplyOp
createPointerFromEmitcArray(Location loc, OpBuilder &builder,
TypedValue<emitc::ArrayType> arrayValue) {

emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
builder, loc, builder.getIndexType(), builder.getIndexAttr(0));

emitc::ArrayType arrayType = arrayValue.getType();
llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
emitc::SubscriptOp subPtr =
emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
emitc::ApplyOp ptr = emitc::ApplyOp::create(
builder, loc, emitc::PointerType::get(arrayType.getElementType()),
builder.getStringAttr("&"), subPtr);

return ptr;
}

struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
Expand Down Expand Up @@ -194,8 +212,9 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
emitc::PointerType::get(
emitc::OpaqueType::get(rewriter.getContext(), "void")),
allocFunctionName, args);

emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
emitc::ArrayType arrayType =
emitc::ArrayType::get(memrefType.getShape(), elementType);
emitc::PointerType targetPointerType = emitc::PointerType::get(arrayType);
emitc::CastOp castOp = emitc::CastOp::create(
rewriter, loc, targetPointerType, allocCall.getResult(0));

Expand Down Expand Up @@ -223,20 +242,10 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
return rewriter.notifyMatchFailure(
loc, "incompatible target memref type for EmitC conversion");

auto srcArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getSource());
emitc::ApplyOp srcPtr =
createPointerFromEmitcArray(loc, rewriter, srcArrayValue);

auto targetArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
emitc::ApplyOp targetPtr =
createPointerFromEmitcArray(loc, rewriter, targetArrayValue);

emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
rewriter, loc, TypeRange{}, "memcpy",
ValueRange{
targetPtr.getResult(), srcPtr.getResult(),
operands.getTarget(), operands.getSource(),
calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});

rewriter.replaceOp(copyOp, memCpyCall.getResults());
Expand Down Expand Up @@ -264,11 +273,14 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
"currently not supported");
}

Type resultTy = convertMemRefType(opTy, getTypeConverter());

if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
Type elemTy = getTypeConverter()->convertType(opTy.getElementType());
Type globalType;
if (opTy.getRank() == 0) {
globalType = elemTy;
} else {
SmallVector<int64_t> shape(opTy.getShape().begin(),
opTy.getShape().end());
globalType = emitc::ArrayType::get(shape, elemTy);
}

SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
Expand All @@ -292,7 +304,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
initialValue = {};

rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
op, operands.getSymName(), resultTy, initialValue, externSpecifier,
op, operands.getSymName(), globalType, initialValue, externSpecifier,
staticSpecifier, operands.getConstant());
return success();
}
Expand All @@ -307,24 +319,64 @@ struct ConvertGetGlobal final
ConversionPatternRewriter &rewriter) const override {

MemRefType opTy = op.getType();
Location loc = op.getLoc();

Type elemTy = getTypeConverter()->convertType(opTy.getElementType());
if (!elemTy)
return rewriter.notifyMatchFailure(loc, "cannot convert element type");

Type resultTy = convertMemRefType(opTy, getTypeConverter());

if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}
Type globalType;
if (opTy.getRank() == 0) {
globalType = elemTy;
} else {
SmallVector<int64_t> shape(opTy.getShape().begin(),
opTy.getShape().end());
globalType = emitc::ArrayType::get(shape, elemTy);
}

if (opTy.getRank() == 0) {
emitc::LValueType lvalueType = emitc::LValueType::get(resultTy);
emitc::LValueType lvalueType = emitc::LValueType::get(globalType);
emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
op, pointerType, rewriter.getStringAttr("&"), globalLValue);
emitc::PointerType pointerType = emitc::PointerType::get(globalType);
auto addrOf = rewriter.create<emitc::ApplyOp>(
loc, pointerType, rewriter.getStringAttr("&"), globalLValue.getResult());

auto arrayTy = emitc::ArrayType::get({1}, globalType);
auto ptrArrayTy = emitc::PointerType::get(arrayTy);
auto casted =
rewriter.create<emitc::CastOp>(loc, ptrArrayTy, addrOf.getResult());
rewriter.replaceOp(op, casted.getResult());
return success();
}
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
operands.getNameAttr());

auto getGlobal = rewriter.create<emitc::GetGlobalOp>(
loc, globalType, operands.getNameAttr());

SmallVector<Value> indices;
for (unsigned i = 0; i < opTy.getRank(); ++i) {
auto zero = rewriter.create<emitc::ConstantOp>(
loc, rewriter.getIndexType(), rewriter.getIndexAttr(0));
indices.push_back(zero);
}

auto current = rewriter.create<emitc::SubscriptOp>(
loc, emitc::LValueType::get(elemTy), getGlobal.getResult(), indices);

auto ptrElemTy = emitc::PointerType::get(opTy.getElementType());
auto addrOf = rewriter.create<emitc::ApplyOp>(
loc, ptrElemTy, rewriter.getStringAttr("&"), current.getResult());

auto casted =
rewriter.create<emitc::CastOp>(loc, resultTy, addrOf.getResult());

rewriter.replaceOp(op, casted.getResult());
return success();
}
};
Expand All @@ -340,13 +392,17 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}

auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value memrefVal = operands.getMemref();
Value deref;
if (auto ptrVal = dyn_cast<TypedValue<emitc::PointerType>>(memrefVal)) {
auto arrayTy = dyn_cast<emitc::ArrayType>(ptrVal.getType().getPointee());
if (!arrayTy)
return failure();
deref = emitc::ApplyOp::create(b, arrayTy, b.getStringAttr("*"), ptrVal);
}

auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(deref);
auto subscript = emitc::SubscriptOp::create(
rewriter, op.getLoc(), arrayValue, operands.getIndices());

Expand All @@ -361,16 +417,21 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value memrefVal = operands.getMemref();
Value deref;
if (auto ptrVal = dyn_cast<TypedValue<emitc::PointerType>>(memrefVal)) {
auto arrayTy = dyn_cast<emitc::ArrayType>(ptrVal.getType().getPointee());
if (!arrayTy)
return failure();
deref = emitc::ApplyOp::create(b, arrayTy, b.getStringAttr("*"), ptrVal);
}

auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(deref);
auto subscript = emitc::SubscriptOp::create(
rewriter, op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
Value valueToStore = operands.getOperands()[0];

rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript, valueToStore);
return success();
}
};
Expand All @@ -386,8 +447,10 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.convertType(memRefType.getElementType());
if (!convertedElementType)
return {};
return emitc::ArrayType::get(memRefType.getShape(),
convertedElementType);
Type innerArrayType =
emitc::ArrayType::get(memRefType.getShape(), convertedElementType);
return emitc::PointerType::get(innerArrayType);

});

auto materializeAsUnrealizedCast = [](OpBuilder &builder, Type resultType,
Expand Down
Loading