diff --git a/mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td b/mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td index c3eea43508735..0b1cf83c4ba1a 100644 --- a/mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td +++ b/mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td @@ -134,6 +134,7 @@ def SYCLMemref : AnyTypeOf<[ VecMemRef, ]>; def IndexType : AnyTypeOf<[I32, I64, Index]>; +def IntMemrefType : AnyTypeOf<[MemRefOf<[I32]>]>; def SYCLGetResult : AnyTypeOf<[I64, MemRefOf<[I64]>]>; def SYCLGetIDResult : AnyTypeOf<[I64, SYCL_IDType]>; def SYCLGetRangeResult : AnyTypeOf<[I64, SYCL_RangeType]>; @@ -142,7 +143,7 @@ def SYCLGetRangeResult : AnyTypeOf<[I64, SYCL_RangeType]>; // CONSTRUCTOR OPERATION //////////////////////////////////////////////////////////////////////////////// -def ConstructorArgs : AnyTypeOf<[SYCLMemref, IndexType, SYCL_IDType, SYCL_RangeType]>; +def ConstructorArgs : AnyTypeOf<[SYCLMemref, IndexType, IntMemrefType, SYCL_IDType, SYCL_RangeType]>; def SYCLConstructorOp : SYCL_Op<"constructor", []> { let summary = "Generic constructor operation"; let description = [{ @@ -156,6 +157,8 @@ def SYCLConstructorOp : SYCL_Op<"constructor", []> { ); let results = (outs); + let hasVerifier = 1; + let assemblyFormat = [{ `(` $Args `)` attr-dict `:` functional-type($Args, results) }]; diff --git a/mlir-sycl/lib/Dialect/IR/SYCLOps.cpp b/mlir-sycl/lib/Dialect/IR/SYCLOps.cpp index fc541429d4ddd..e8b26c70bc3df 100644 --- a/mlir-sycl/lib/Dialect/IR/SYCLOps.cpp +++ b/mlir-sycl/lib/Dialect/IR/SYCLOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/SYCL/IR/SYCLOps.h" +#include "mlir/Dialect/SYCL/IR/SYCLOpsTypes.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/TypeSwitch.h" @@ -57,6 +58,15 @@ bool mlir::sycl::SYCLCastOp::areCastCompatible(::mlir::TypeRange Inputs, return false; } +mlir::LogicalResult mlir::sycl::SYCLConstructorOp::verify() { + auto MT = getOperand(0).getType().dyn_cast(); + if (MT && isSYCLType(MT.getElementType())) + return success(); + + return emitOpError("The first argument of a sycl::constructor op has to be a " + "MemRef to a SYCL type"); +} + mlir::LogicalResult mlir::sycl::SYCLAccessorSubscriptOp::verify() { // Available only when: (Dimensions > 0) // reference operator[](id index) const; diff --git a/mlir-sycl/test/Dialect/IR/SYCL/constructor.mlir b/mlir-sycl/test/Dialect/IR/SYCL/constructor.mlir index 071a858cd2bd9..10d1e893e02cd 100644 --- a/mlir-sycl/test/Dialect/IR/SYCL/constructor.mlir +++ b/mlir-sycl/test/Dialect/IR/SYCL/constructor.mlir @@ -6,3 +6,10 @@ func.func @AccessorImplDevice(%arg0: memref, !sycl.range<1>, !sycl.range<1>)>>, !sycl.id<1>, !sycl.range<1>, !sycl.range<1>) -> () return } + +// Ensure integer pointer can be arguments of sycl.constructor. +// CHECK-LABEL: func.func @TestConstructorII32Ptr +func.func @TestConstructorII32Ptr(%arg0: memref, 4>, %arg1: memref) { + sycl.constructor(%arg0, %arg1) {MangledFunctionName = @_ZN4sycl3_V19multi_ptrIjLNS0_6access13address_spaceE1ELNS2_9decoratedE1EEC1EPU3AS1j, TypeName = @multi_ptr} : (memref, 4>, memref) -> () + return +}