Skip to content

Commit 16d5469

Browse files
arnamoy10arnamoy.bhattacharyya
andauthored
[SYCL-MLIR] Allow i32 ptr(memref) to be argument of sycl.constructor. (#7455)
During lowering of `atomic` and `multi-ptr`, constructors with pointer type arguments are being called. Here is an example call generated by clang: ` call spir_func void @_ZN4sycl3_V19multi_ptrIjLNS0_6access13address_spaceE1ELNS2_9decoratedE1EEC2EPU3AS1j(%"class.sycl::_V1::multi_ptr" addrspace(4)* noundef align 8 dereferenceable_or_null(8) %agg.tmp2.ascast, i32 addrspace(1)* noundef %add.ptr) #8` As can be seen, the second argument is a `i32` pointer. Currently our `sycl.constructor` op (that will be replaced during lowering to this function call) will not allow pointers to be arguments. This PR removes that restriction. Co-authored-by: arnamoy.bhattacharyya <[email protected]>
1 parent 531fddb commit 16d5469

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def SYCLMemref : AnyTypeOf<[
134134
VecMemRef,
135135
]>;
136136
def IndexType : AnyTypeOf<[I32, I64, Index]>;
137+
def IntMemrefType : AnyTypeOf<[MemRefOf<[I32]>]>;
137138
def SYCLGetResult : AnyTypeOf<[I64, MemRefOf<[I64]>]>;
138139
def SYCLGetIDResult : AnyTypeOf<[I64, SYCL_IDType]>;
139140
def SYCLGetRangeResult : AnyTypeOf<[I64, SYCL_RangeType]>;
@@ -142,7 +143,7 @@ def SYCLGetRangeResult : AnyTypeOf<[I64, SYCL_RangeType]>;
142143
// CONSTRUCTOR OPERATION
143144
////////////////////////////////////////////////////////////////////////////////
144145

145-
def ConstructorArgs : AnyTypeOf<[SYCLMemref, IndexType, SYCL_IDType, SYCL_RangeType]>;
146+
def ConstructorArgs : AnyTypeOf<[SYCLMemref, IndexType, IntMemrefType, SYCL_IDType, SYCL_RangeType]>;
146147
def SYCLConstructorOp : SYCL_Op<"constructor", []> {
147148
let summary = "Generic constructor operation";
148149
let description = [{
@@ -156,6 +157,8 @@ def SYCLConstructorOp : SYCL_Op<"constructor", []> {
156157
);
157158
let results = (outs);
158159

160+
let hasVerifier = 1;
161+
159162
let assemblyFormat = [{
160163
`(` $Args `)` attr-dict `:` functional-type($Args, results)
161164
}];

mlir-sycl/lib/Dialect/IR/SYCLOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/SYCL/IR/SYCLOps.h"
1010

11+
#include "mlir/Dialect/SYCL/IR/SYCLOpsTypes.h"
1112
#include "mlir/IR/OpImplementation.h"
1213
#include "llvm/ADT/TypeSwitch.h"
1314

@@ -57,6 +58,15 @@ bool mlir::sycl::SYCLCastOp::areCastCompatible(::mlir::TypeRange Inputs,
5758
return false;
5859
}
5960

61+
mlir::LogicalResult mlir::sycl::SYCLConstructorOp::verify() {
62+
auto MT = getOperand(0).getType().dyn_cast<mlir::MemRefType>();
63+
if (MT && isSYCLType(MT.getElementType()))
64+
return success();
65+
66+
return emitOpError("The first argument of a sycl::constructor op has to be a "
67+
"MemRef to a SYCL type");
68+
}
69+
6070
mlir::LogicalResult mlir::sycl::SYCLAccessorSubscriptOp::verify() {
6171
// Available only when: (Dimensions > 0)
6272
// reference operator[](id<Dimensions> index) const;

mlir-sycl/test/Dialect/IR/SYCL/constructor.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,10 @@ func.func @AccessorImplDevice(%arg0: memref<?x!sycl.accessor_impl_device<[1], (!
66
sycl.constructor(%arg0, %arg1, %arg2, %arg2) {MangledFunctionName = @_ZN4sycl3_V16detail18AccessorImplDeviceILi1EEC1ENS0_2idILi1EEENS0_5rangeILi1EEES7_, TypeName = @AccessorImplDevice} : (memref<?x!sycl.accessor_impl_device<[1], (!sycl.id<1>, !sycl.range<1>, !sycl.range<1>)>>, !sycl.id<1>, !sycl.range<1>, !sycl.range<1>) -> ()
77
return
88
}
9+
10+
// Ensure integer pointer can be arguments of sycl.constructor.
11+
// CHECK-LABEL: func.func @TestConstructorII32Ptr
12+
func.func @TestConstructorII32Ptr(%arg0: memref<?x!sycl.id<1>, 4>, %arg1: memref<?xi32, 1>) {
13+
sycl.constructor(%arg0, %arg1) {MangledFunctionName = @_ZN4sycl3_V19multi_ptrIjLNS0_6access13address_spaceE1ELNS2_9decoratedE1EEC1EPU3AS1j, TypeName = @multi_ptr} : (memref<?x!sycl.id<1>, 4>, memref<?xi32, 1>) -> ()
14+
return
15+
}

0 commit comments

Comments
 (0)