Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,37 @@ def EmitC_FileOp
let skipDefaultBuilders = 1;
}

def EmitC_AddressOfOp : EmitC_Op<"address_of", [
CExpressionInterface,
TypesMatchWith<"input and result reference the same type", "reference", "result",
"emitc::PointerType::get(::llvm::cast<emitc::LValueType>($_self).getValueType())">
]> {
let summary = "Address operation";
let description = [{
This operation models the C & (address of) operator for a single operand,
which must be an emitc.lvalue, and returns an emitc pointer to its location.

Example:

```mlir
// Custom form of applying the & operator.
%0 = emitc.address_of %arg0 : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
```
}];
let arguments = (ins EmitC_LValueType:$reference);
let results = (outs EmitC_PointerType:$result);
let assemblyFormat = [{
$reference `:` qualified(type($reference)) attr-dict
}];
let hasVerifier = 1;

let extraClassDeclaration = [{
bool hasSideEffects() {
return false;
}
}];
}

def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
let summary = "Addition operation";
let description = [{
Expand All @@ -140,7 +171,7 @@ def EmitC_AddOp : EmitC_BinaryOp<"add", []> {
}

def EmitC_ApplyOp : EmitC_Op<"apply", [CExpressionInterface]> {
let summary = "Apply operation";
let summary = "Deprecated (use address_of/dereference)";
let description = [{
With the `emitc.apply` operation the operators & (address of) and * (contents of)
can be applied to a single operand.
Expand Down Expand Up @@ -439,6 +470,31 @@ def EmitC_ConstantOp
}];
}

def EmitC_DereferenceOp : EmitC_Op<"dereference", [
TypesMatchWith<"input and result reference the same type", "pointer", "result",
"emitc::LValueType::get(::llvm::cast<emitc::PointerType>($_self).getPointee())">
]> {
let summary = "Dereference operation";
let description = [{
This operation models the C * (dereference) operator, which must be of
!emitc.ptr<> type, returning an !emitc.lvalue<> the value pointed to by the
pointer.

Example:

```mlir
// Custom form of the dereference operator.
%0 = emitc.dereference %arg0 : (!emitc.ptr<i32>) -> !emitc.lvalue<i32>
```
}];
let arguments = (ins EmitC_PointerType:$pointer);
let results = (outs EmitC_LValueType:$result);
let assemblyFormat = [{
$pointer `:` qualified(type($pointer)) attr-dict
}];
let hasVerifier = 1;
}

def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
let summary = "Division operation";
let description = [{
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,19 @@ class EmitC_LValueOf<list<Type> allowedTypes> :
"::mlir::emitc::LValueType"
>;

class EmitC_TypesMatchOnElementType<
string summary, string lhsArg, string rhsArg, string transformL, string transformR,
string comparator = "std::equal_to<>()"
> : PredOpTrait<summary, CPred<comparator # "(" #
!subst("$_self", "$" # lhsArg # ".getType()", transformL) #
", " #
!subst("$_self", "$" # rhsArg # ".getType()", transformL) #
")">
> {
string lhs = lhsArg;
string rhs = rhsArg;
string transformerL = transformL;
string transformerR = transformR;
}

#endif // MLIR_DIALECT_EMITC_IR_EMITCTYPES
14 changes: 7 additions & 7 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
return totalSizeBytes.getResult();
}

static emitc::ApplyOp
static emitc::AddressOfOp
createPointerFromEmitcArray(Location loc, OpBuilder &builder,
TypedValue<emitc::ArrayType> arrayValue) {

Expand All @@ -133,9 +133,9 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder,
llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
emitc::SubscriptOp subPtr =
emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
emitc::ApplyOp ptr = emitc::ApplyOp::create(
emitc::AddressOfOp ptr = emitc::AddressOfOp::create(
builder, loc, emitc::PointerType::get(arrayType.getElementType()),
builder.getStringAttr("&"), subPtr);
subPtr);

return ptr;
}
Expand Down Expand Up @@ -225,12 +225,12 @@ struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {

auto srcArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getSource());
emitc::ApplyOp srcPtr =
emitc::AddressOfOp srcPtr =
createPointerFromEmitcArray(loc, rewriter, srcArrayValue);

auto targetArrayValue =
cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
emitc::ApplyOp targetPtr =
emitc::AddressOfOp targetPtr =
createPointerFromEmitcArray(loc, rewriter, targetArrayValue);

emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
Expand Down Expand Up @@ -319,8 +319,8 @@ struct ConvertGetGlobal final
emitc::GetGlobalOp globalLValue = emitc::GetGlobalOp::create(
rewriter, op.getLoc(), lvalueType, operands.getNameAttr());
emitc::PointerType pointerType = emitc::PointerType::get(resultTy);
rewriter.replaceOpWithNewOp<emitc::ApplyOp>(
op, pointerType, rewriter.getStringAttr("&"), globalLValue);
rewriter.replaceOpWithNewOp<emitc::AddressOfOp>(op, pointerType,
globalLValue);
return success();
}
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
Expand Down
29 changes: 29 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,21 @@ FailureOr<SmallVector<ReplacementItem>> parseFormatString(
return items;
}

//===----------------------------------------------------------------------===//
// AddressOfOp
//===----------------------------------------------------------------------===//

LogicalResult AddressOfOp::verify() {
emitc::LValueType referenceType = getReference().getType();
emitc::PointerType resultType = getResult().getType();

if (referenceType.getValueType() != resultType.getPointee())
return emitOpError("requires result to be a pointer to the type "
"referenced by operand");

return success();
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -379,6 +394,20 @@ LogicalResult emitc::ConstantOp::verify() {

OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

//===----------------------------------------------------------------------===//
// DereferenceOp
//===----------------------------------------------------------------------===//

LogicalResult DereferenceOp::verify() {
emitc::PointerType pointerType = getPointer().getType();

if (pointerType.getPointee() != getResult().getType().getValueType())
return emitOpError("requires result to be an lvalue of the type "
"pointed to by operand");

return success();
}

//===----------------------------------------------------------------------===//
// ExpressionOp
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 25 additions & 4 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ static inline LogicalResult interleaveCommaWithError(const Container &c,
/// imply higher precedence.
static FailureOr<int> getOperatorPrecedence(Operation *operation) {
return llvm::TypeSwitch<Operation *, FailureOr<int>>(operation)
.Case<emitc::AddressOfOp>([&](auto op) { return 15; })
.Case<emitc::AddOp>([&](auto op) { return 12; })
.Case<emitc::ApplyOp>([&](auto op) { return 15; })
.Case<emitc::BitwiseAndOp>([&](auto op) { return 7; })
Expand Down Expand Up @@ -393,6 +394,15 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
return false;
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::DereferenceOp dereferenceOp) {
std::string out;
llvm::raw_string_ostream ss(out);
ss << "*" << emitter.getOrCreateName(dereferenceOp.getPointer());
emitter.cacheDeferredOpResult(dereferenceOp.getResult(), out);
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::GetFieldOp getFieldOp) {
emitter.cacheDeferredOpResult(getFieldOp.getResult(),
Expand Down Expand Up @@ -476,6 +486,17 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
return emitter.emitAttribute(operation->getLoc(), value);
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::AddressOfOp addressOfOp) {
raw_ostream &os = emitter.ostream();
Operation &op = *addressOfOp.getOperation();

if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << "&";
return emitter.emitOperand(addressOfOp.getReference());
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::ConstantOp constantOp) {
Operation *operation = constantOp.getOperation();
Expand Down Expand Up @@ -1769,14 +1790,14 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
.Case<cf::BranchOp, cf::CondBranchOp>(
[&](auto op) { return printOperation(*this, op); })
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
.Case<emitc::AddressOfOp, emitc::AddOp, emitc::ApplyOp,
emitc::AssignOp, emitc::BitwiseAndOp, emitc::BitwiseLeftShiftOp,
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::ClassOp,
emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
emitc::DeclareFuncOp, emitc::DivOp, emitc::DoOp,
emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp,
emitc::DeclareFuncOp, emitc::DereferenceOp, emitc::DivOp,
emitc::DoOp, emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp,
emitc::ForOp, emitc::FuncOp, emitc::GetFieldOp,
emitc::GetGlobalOp, emitc::GlobalOp, emitc::IfOp,
emitc::IncludeOp, emitc::LiteralOp, emitc::LoadOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ func.func @alloc_copy(%arg0: memref<999xi32>) {
// CHECK: %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[CAST_0]] : !emitc.ptr<i32> to !emitc.array<999xi32>
// CHECK: %[[VAL_1:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK: %[[SUBSCRIPT_0:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_1]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
// CHECK: %[[APPLY_0:.*]] = emitc.apply "&"(%[[SUBSCRIPT_0]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
// CHECK: %[[ADDRESS_OF_0:.*]] = emitc.address_of %[[SUBSCRIPT_0]] : !emitc.lvalue<i32>
// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK: %[[SUBSCRIPT_1:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_1]]{{\[}}%[[VAL_2]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
// CHECK: %[[APPLY_1:.*]] = emitc.apply "&"(%[[SUBSCRIPT_1]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
// CHECK: %[[ADDRESS_OF_1:.*]] = emitc.address_of %[[SUBSCRIPT_1]] : !emitc.lvalue<i32>
// CHECK: %[[CALL_OPAQUE_2:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CHECK: %[[VAL_3:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// CHECK: %[[MUL_1:.*]] = emitc.mul %[[CALL_OPAQUE_2]], %[[VAL_3]] : (!emitc.size_t, index) -> !emitc.size_t
// CHECK: emitc.call_opaque "memcpy"(%[[APPLY_1]], %[[APPLY_0]], %[[MUL_1]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_1]], %[[ADDRESS_OF_0]], %[[MUL_1]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
// CHECK: %[[CALL_OPAQUE_3:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CHECK: %[[VAL_4:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// CHECK: %[[MUL_2:.*]] = emitc.mul %[[CALL_OPAQUE_3]], %[[VAL_4]] : (!emitc.size_t, index) -> !emitc.size_t
Expand All @@ -42,13 +42,13 @@ func.func @alloc_copy(%arg0: memref<999xi32>) {
// CHECK: %[[UNREALIZED_CONVERSION_CAST_2:.*]] = builtin.unrealized_conversion_cast %[[CAST_1]] : !emitc.ptr<i32> to !emitc.array<999xi32>
// CHECK: %[[VAL_5:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK: %[[SUBSCRIPT_2:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_5]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
// CHECK: %[[APPLY_2:.*]] = emitc.apply "&"(%[[SUBSCRIPT_2]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
// CHECK: %[[ADDRESS_OF_2:.*]] = emitc.address_of %[[SUBSCRIPT_2]] : !emitc.lvalue<i32>
// CHECK: %[[VAL_6:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK: %[[SUBSCRIPT_3:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_2]]{{\[}}%[[VAL_6]]] : (!emitc.array<999xi32>, index) -> !emitc.lvalue<i32>
// CHECK: %[[APPLY_3:.*]] = emitc.apply "&"(%[[SUBSCRIPT_3]]) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
// CHECK: %[[ADDRESS_OF_3:.*]] = emitc.address_of %[[SUBSCRIPT_3]] : !emitc.lvalue<i32>
// CHECK: %[[CALL_OPAQUE_5:.*]] = emitc.call_opaque "sizeof"() {args = [i32]} : () -> !emitc.size_t
// CHECK: %[[VAL_7:.*]] = "emitc.constant"() <{value = 999 : index}> : () -> index
// CHECK: %[[MUL_3:.*]] = emitc.mul %[[CALL_OPAQUE_5]], %[[VAL_7]] : (!emitc.size_t, index) -> !emitc.size_t
// CHECK: emitc.call_opaque "memcpy"(%[[APPLY_3]], %[[APPLY_2]], %[[MUL_3]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_3]], %[[ADDRESS_OF_2]], %[[MUL_3]]) : (!emitc.ptr<i32>, !emitc.ptr<i32>, !emitc.size_t) -> ()
// CHECK: return
// CHECK: }
6 changes: 3 additions & 3 deletions mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-copy.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ func.func @copying(%arg0 : memref<9x4x5x7xf32>, %arg1 : memref<9x4x5x7xf32>) {
// CHECK: %[[UNREALIZED_CONVERSION_CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<9x4x5x7xf32> to !emitc.array<9x4x5x7xf32>
// CHECK: %[[VAL_0:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK: %[[SUBSCRIPT_0:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_1]]{{\[}}%[[VAL_0]], %[[VAL_0]], %[[VAL_0]], %[[VAL_0]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
// CHECK: %[[APPLY_0:.*]] = emitc.apply "&"(%[[SUBSCRIPT_0]]) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
// CHECK: %[[ADDRESS_OF_0:.*]] = emitc.address_of %[[SUBSCRIPT_0]] : !emitc.lvalue<f32>
// CHECK: %[[VAL_1:.*]] = "emitc.constant"() <{value = 0 : index}> : () -> index
// CHECK: %[[SUBSCRIPT_1:.*]] = emitc.subscript %[[UNREALIZED_CONVERSION_CAST_0]]{{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]]] : (!emitc.array<9x4x5x7xf32>, index, index, index, index) -> !emitc.lvalue<f32>
// CHECK: %[[APPLY_1:.*]] = emitc.apply "&"(%[[SUBSCRIPT_1]]) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
// CHECK: %[[ADDRESS_OF_1:.*]] = emitc.address_of %[[SUBSCRIPT_1]] : !emitc.lvalue<f32>
// CHECK: %[[CALL_OPAQUE_0:.*]] = emitc.call_opaque "sizeof"() {args = [f32]} : () -> !emitc.size_t
// CHECK: %[[VAL_2:.*]] = "emitc.constant"() <{value = 1260 : index}> : () -> index
// CHECK: %[[MUL_0:.*]] = emitc.mul %[[CALL_OPAQUE_0]], %[[VAL_2]] : (!emitc.size_t, index) -> !emitc.size_t
// CHECK: emitc.call_opaque "memcpy"(%[[APPLY_1]], %[[APPLY_0]], %[[MUL_0]]) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
// CHECK: emitc.call_opaque "memcpy"(%[[ADDRESS_OF_1]], %[[ADDRESS_OF_0]], %[[MUL_0]]) : (!emitc.ptr<f32>, !emitc.ptr<f32>, !emitc.size_t) -> ()
// CHECK: return
// CHECK: }

2 changes: 1 addition & 1 deletion mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ module @globals {
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
%0 = memref.get_global @public_global : memref<3x7xf32>
// CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
// CHECK-NEXT: emitc.apply "&"(%1) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
// CHECK-NEXT: emitc.address_of %1 : !emitc.lvalue<i32>
%1 = memref.get_global @__constant_xi32 : memref<i32>
return
}
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -914,3 +914,19 @@ func.func @test_for_unmatch_type(%arg0: index) {
) : (index, index, index) -> ()
return
}

// -----

func.func @address_of(%arg0: !emitc.lvalue<i32>) {
// expected-error @+1 {{failed to verify that input and result reference the same type}}
%1 = "emitc.address_of"(%arg0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i8>
return
}

// -----

func.func @dereference(%arg0: !emitc.ptr<i32>) {
// expected-error @+1 {{failed to verify that input and result reference the same type}}
%1 = "emitc.dereference"(%arg0) : (!emitc.ptr<i32>) -> !emitc.lvalue<i8>
return
}
10 changes: 10 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,13 @@ func.func @do(%arg0 : !emitc.ptr<i32>) {

return
}

func.func @address_of(%arg0: !emitc.lvalue<i32>) {
%1 = emitc.address_of %arg0 : !emitc.lvalue<i32>
return
}

func.func @dereference(%arg0: !emitc.ptr<i32>) {
%1 = emitc.dereference %arg0 : !emitc.ptr<i32>
return
}
19 changes: 19 additions & 0 deletions mlir/test/Target/Cpp/common-cpp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,25 @@ func.func @apply() -> !emitc.ptr<i32> {
return %1 : !emitc.ptr<i32>
}


// CHECK-LABEL: void address_of() {
func.func @address_of() {
// CHECK-NEXT: int32_t [[V1:[^ ]*]];
%0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.lvalue<i32>
// CHECK-NEXT: int32_t* [[V2:[^ ]*]] = &[[V1]];
%1 = emitc.address_of %0 : !emitc.lvalue<i32>
return
}

// CHECK-LABEL: void dereference
// CHECK-SAME: (int32_t* [[ARG0:[^ ]*]]) {
func.func @dereference(%arg0: !emitc.ptr<i32>) {
// CHECK: int32_t [[V1:[^ ]*]] = *[[ARG0]];
%2 = emitc.dereference %arg0 : !emitc.ptr<i32>
emitc.load %2 : !emitc.lvalue<i32>
return
}

// CHECK: void array_type(int32_t v1[3], float v2[10][20])
func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) {
return
Expand Down
Loading
Loading