diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index db244d1d1cac8..0b7ffa40ec09d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -16,7 +16,9 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeRange.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -77,13 +79,23 @@ struct ConvertAlloca final : public OpConversionPattern { } }; +Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { + Type resultTy; + if (opTy.getRank() == 0) { + resultTy = typeConverter->convertType(mlir::getElementTypeOrSelf(opTy)); + } else { + resultTy = typeConverter->convertType(opTy); + } + return resultTy; +} + struct ConvertGlobal final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::GlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - + MemRefType opTy = op.getType(); if (!op.getType().hasStaticShape()) { return rewriter.notifyMatchFailure( op.getLoc(), "cannot transform global with dynamic shape"); @@ -95,7 +107,9 @@ struct ConvertGlobal final : public OpConversionPattern { op.getLoc(), "global variable with alignment requirement is " "currently not supported"); } - auto resultTy = getTypeConverter()->convertType(op.getType()); + + Type resultTy = convertMemRefType(opTy, getTypeConverter()); + if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); @@ -114,6 +128,10 @@ struct ConvertGlobal final : public OpConversionPattern { bool externSpecifier = !staticSpecifier; Attribute initialValue = operands.getInitialValueAttr(); + if (opTy.getRank() == 0) { + auto elementsAttr = llvm::cast(*op.getInitialValue()); + initialValue = elementsAttr.getSplatValue(); + } if (isa_and_present(initialValue)) initialValue = {}; @@ -132,11 +150,23 @@ struct ConvertGetGlobal final matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { - auto resultTy = getTypeConverter()->convertType(op.getType()); + MemRefType opTy = op.getType(); + Type resultTy = convertMemRefType(opTy, getTypeConverter()); + if (!resultTy) { return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert result type"); } + + if (opTy.getRank() == 0) { + emitc::LValueType lvalueType = emitc::LValueType::get(resultTy); + emitc::GetGlobalOp globalLValue = rewriter.create( + op.getLoc(), lvalueType, operands.getNameAttr()); + emitc::PointerType pointerType = emitc::PointerType::get(resultTy); + rewriter.replaceOpWithNewOp( + op, pointerType, rewriter.getStringAttr("&"), globalLValue); + return success(); + } rewriter.replaceOpWithNewOp(op, resultTy, operands.getNameAttr()); return success(); diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index d37fd1de90add..2b4eda37903d4 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 { module @globals { memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0> // CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00> + memref.global "private" constant @__constant_xi32 : memref = dense<-1> + // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1 memref.global @public_global : memref<3x7xf32> // CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32> memref.global @uninitialized_global : memref<3x7xf32> = uninitialized @@ -50,6 +52,9 @@ module @globals { func.func @use_global() { // CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32> %0 = memref.get_global @public_global : memref<3x7xf32> + // CHECK-NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue + // CHECK-NEXT: emitc.apply "&"(%1) : (!emitc.lvalue) -> !emitc.ptr + %1 = memref.get_global @__constant_xi32 : memref return } }