[MLIR][XeGPU] make offsets optional for create_nd_tdesc #148335
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

Changes

This PR allows xegpu to take optional offsets in create_nd_tdesc. This is the initial PR toward moving offsets from create_nd_tdesc to load_nd.

%2 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.create_nd_tdesc %src[0, 0] shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>

Full diff: https://github.com/llvm/llvm-project/pull/148335.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index bd5ea9fd83781..710fc62b032a9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -110,23 +110,27 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
Variadic<Index>: $offsets,
Variadic<Index>: $shape,
Variadic<Index>: $strides,
- DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_shape,
OptionalAttr<DenseI64ArrayAttr>: $const_strides
);
let results = (outs XeGPU_TensorDesc: $TensorDesc);
- let assemblyFormat = [{
- $source ``
- custom<DynamicIndexList>($offsets, $const_offsets)
- (`,` custom<DynamicIndexList>($shape, $const_shape)^
- `,` custom<DynamicIndexList>($strides, $const_strides))?
- attr-dict `:` type($source) `->` qualified(type($TensorDesc))
- }];
-
let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+
let builders = [
+ OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
+
+ OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
+ "llvm::ArrayRef<OpFoldResult>": $shape,
+ "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
+ OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+ "llvm::ArrayRef<OpFoldResult>": $shape,
+ "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
@@ -163,9 +167,30 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
}
ArrayRef<int64_t> getStaticOffsets(){
- return getConstOffsets();
+ auto attr = getConstOffsetsAttr();
+
+ if (attr)
+ return attr;
+
+ auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
+ int rank = 0;
+ if (memrefType) {
+ //use source memref's rank, as source memref rank may be higher
+ rank = memrefType.getRank();
+ } else {
+ //nd_tdesc created from ui64, use nd_tdesc's rank
+ rank = getTensorDescShape().size();
+ };
+
+ // The offsets are allowed to be empty. The Traits verification of OffsetSizeAndStrideOpInterface interface assumes offsets being present.
+ // It is set to be MAX to indicate user not passed any value, instead of kDynamic which means offsets passed as value.
+ setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max()));
+
+ attr = getConstOffsetsAttr();
+ return attr;
}
+
/// wrapper for matching with OffsetSizeAndStrideOpInterface
/// If source is IntegerType or `const_shape` is filled,
/// it will return `const_shape`, such that mixes of `shape`
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ef7cd1424e7a4..9f6090ad279f5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -125,8 +125,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
ValueRange({}) /* empty dynamic shape */,
ValueRange({}) /* empty dynamic strides */,
- staticOffsets /* const offsets */, {} /* empty const shape*/,
- {} /* empty const strides*/);
+ builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
+ {} /* empty const shape*/, {} /* empty const strides*/);
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
@@ -221,6 +221,246 @@ LogicalResult CreateNdDescOp::verify() {
return success();
}
+ParseResult parseOptionalDynamicIndexList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
+ SmallVectorImpl<Type> *valueTypes = nullptr,
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+ SmallVector<int64_t, 4> integerVals;
+ SmallVector<bool, 4> scalableVals;
+ auto parseIntegerOrValue = [&]() {
+ OpAsmParser::UnresolvedOperand operand;
+ auto res = parser.parseOptionalOperand(operand);
+
+ // When encountering `[`, assume that this is a scalable index.
+ scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
+
+ if (res.has_value() && succeeded(res.value())) {
+ values.push_back(operand);
+ integerVals.push_back(ShapedType::kDynamic);
+ if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+ return failure();
+ } else {
+ int64_t integer;
+ if (failed(parser.parseInteger(integer)))
+ return failure();
+ integerVals.push_back(integer);
+ }
+
+ // If this is assumed to be a scalable index, verify that there's a closing
+ // `]`.
+ if (scalableVals.back() && parser.parseOptionalRSquare().failed())
+ return failure();
+ return success();
+ };
+ if (parser.parseOptionalLSquare().succeeded()) {
+ if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
+ parser.parseRSquare())
+ return parser.emitError(parser.getNameLoc())
+ << "expected SSA value or integer";
+ integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+ scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
+ return success();
+ }
+ return success();
+}
+
+::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
+ ::mlir::OperationState &result) {
+ ::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
+ ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(
+ &sourceRawOperand, 1);
+ ::llvm::SMLoc sourceOperandsLoc;
+
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+ offsetsOperands;
+ ::llvm::SMLoc offsetsOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_offsetsAttr;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
+ ::llvm::SMLoc shapeOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_shapeAttr;
+ ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+ stridesOperands;
+ ::llvm::SMLoc stridesOperandsLoc;
+ ::mlir::DenseI64ArrayAttr const_stridesAttr;
+ ::mlir::Type sourceRawType{};
+ ::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
+ ::mlir::Type TensorDescRawType{};
+ ::llvm::ArrayRef<::mlir::Type> TensorDescTypes(&TensorDescRawType, 1);
+
+ sourceOperandsLoc = parser.getCurrentLocation();
+ if (parser.parseOperand(sourceRawOperand))
+ return ::mlir::failure();
+
+ offsetsOperandsLoc = parser.getCurrentLocation();
+
+ DenseBoolArrayAttr scalableFlags;
+ auto odsResult = parseOptionalDynamicIndexList(
+ parser, offsetsOperands, const_offsetsAttr, scalableFlags);
+
+ if (const_offsetsAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets =
+ const_offsetsAttr;
+ }
+
+ if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
+ if (parser.parseColon())
+ return ::mlir::failure();
+ {
+ shapeOperandsLoc = parser.getCurrentLocation();
+ auto odsResult =
+ parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
+ if (const_shapeAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape =
+ const_shapeAttr;
+ }
+ }
+
+ if (parser.parseKeyword("strides"))
+ return ::mlir::failure();
+ if (parser.parseColon())
+ return ::mlir::failure();
+ {
+ stridesOperandsLoc = parser.getCurrentLocation();
+ auto odsResult =
+ parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
+ if (const_stridesAttr) {
+ if (odsResult)
+ return ::mlir::failure();
+ result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides =
+ const_stridesAttr;
+ }
+ }
+ }
+ {
+ auto loc = parser.getCurrentLocation();
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return ::mlir::failure();
+ if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+ return parser.emitError(loc)
+ << "'" << result.name.getStringRef() << "' op ";
+ })))
+ return ::mlir::failure();
+ }
+ if (parser.parseColon())
+ return ::mlir::failure();
+
+ {
+ ::mlir::Type type;
+ if (parser.parseCustomTypeWithFallback(type))
+ return ::mlir::failure();
+ sourceRawType = type;
+ }
+ if (parser.parseArrow())
+ return ::mlir::failure();
+
+ if (parser.parseType(TensorDescRawType))
+ return ::mlir::failure();
+
+ ::llvm::copy(::llvm::ArrayRef<int32_t>(
+ {1, static_cast<int32_t>(offsetsOperands.size()),
+ static_cast<int32_t>(shapeOperands.size()),
+ static_cast<int32_t>(stridesOperands.size())}),
+ result.getOrAddProperties<CreateNdDescOp::Properties>()
+ .operandSegmentSizes.begin());
+
+ ::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
+ result.addTypes(TensorDescTypes);
+
+ if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
+ result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(offsetsOperands, odsBuildableType0,
+ offsetsOperandsLoc, result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc,
+ result.operands))
+ return ::mlir::failure();
+
+ if (parser.resolveOperands(stridesOperands, odsBuildableType0,
+ stridesOperandsLoc, result.operands))
+ return ::mlir::failure();
+ return ::mlir::success();
+}
+
+void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
+ _odsPrinter << ' ';
+ _odsPrinter << getSource();
+
+ auto constOffsetsAttr = getConstOffsetsAttr();
+ bool printOffsets = false;
+ if (constOffsetsAttr && constOffsetsAttr.size() > 0) {
+ auto firstVal = constOffsetsAttr.asArrayRef()[0];
+ if (firstVal != std::numeric_limits<int64_t>::max()) {
+ printOffsets = true;
+ }
+ }
+ if (printOffsets) {
+
+ printDynamicIndexList(_odsPrinter, *this, getOffsets(),
+ getConstOffsetsAttr());
+ }
+ if (((!getShape().empty()) || (getConstShapeAttr()))) {
+ _odsPrinter << ' ' << "shape";
+ _odsPrinter << ' ' << ":";
+ _odsPrinter << ' ';
+ printDynamicIndexList(_odsPrinter, *this, getShape(), getConstShapeAttr());
+ _odsPrinter << ' ' << "strides";
+ _odsPrinter << ' ' << ":";
+ _odsPrinter << ' ';
+ printDynamicIndexList(_odsPrinter, *this, getStrides(),
+ getConstStridesAttr());
+ }
+ ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
+ elidedAttrs.push_back("operandSegmentSizes");
+ elidedAttrs.push_back("const_offsets");
+ elidedAttrs.push_back("const_shape");
+ elidedAttrs.push_back("const_strides");
+ _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+ _odsPrinter << ' ' << ":";
+ _odsPrinter << ' ';
+ {
+ auto type = getSource().getType();
+ if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type))
+ _odsPrinter.printStrippedAttrOrType(validType);
+ else
+ _odsPrinter << type;
+ }
+ _odsPrinter << ' ' << "->";
+ _odsPrinter << ' ';
+ // _odsPrinter << getTensorDesc().getType();
+
+ _odsPrinter << "!xegpu.tensor_desc<";
+
+ auto tDesc = getTensorDesc().getType();
+ auto shape = tDesc.getShape();
+ for (int64_t dim : shape) {
+ if (mlir::ShapedType::isDynamic(dim))
+ _odsPrinter << '?';
+ else
+ _odsPrinter << dim;
+ _odsPrinter << 'x';
+ }
+
+ _odsPrinter << tDesc.getElementType();
+
+ if (auto encoding = tDesc.getEncoding())
+ _odsPrinter << ", " << encoding;
+
+ if (auto layout = tDesc.getLayout())
+ _odsPrinter << ", " << layout;
+
+ _odsPrinter << ">";
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3bfe1fa81aa6e..0d679e519ed60 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -17,8 +17,8 @@ gpu.func @create_nd_tdesc_1(%src: memref<24x32xf32>) {
gpu.func @create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
- // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
- %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides: [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
gpu.return
}
@@ -62,6 +62,47 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
}
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
+gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
+ //CHECK: %[[C:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ gpu.return
+}
+
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
+gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ %2 = xegpu.create_nd_tdesc %src shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+ gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})
+
+gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[%arg3, %arg4] shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+ gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+ %2 = xegpu.create_nd_tdesc %src shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+ gpu.return
+}
+
// CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..ba29d1ab13cae 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -150,16 +150,16 @@ gpu.module @test {
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
// CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
gpu.module @test {
gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
%c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%1 = xegpu.load_nd %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
- %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
cc also @akroviakov
  if (memrefType) {
    //use source memref's rank, as source memref rank may be higher
    rank = memrefType.getRank();
  } else {
    //nd_tdesc created from ui64, use nd_tdesc's rank
    rank = getTensorDescShape().size();
  };
No need for braces around the single-line if and else bodies.
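A minimal sketch of the brace-free form being suggested (names taken from the quoted diff above):

int rank = 0;
if (memrefType)
  // use the source memref's rank, as the source memref rank may be higher
  rank = memrefType.getRank();
else
  // nd_tdesc created from ui64, use the nd_tdesc's rank
  rank = getTensorDescShape().size();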
  // It is set to be MAX to indicate user not passed any value, instead of kDynamic which means offsets passed as value.
  setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max()));
We probably need to reuse this constant in the future. Better to define it somewhere:
static constexpr int64_t optionalValue = std::numeric_limits<int64_t>::max();
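A sketch of how such a constant could then be used at the call site quoted above (the naming follows the suggestion; illustrative only):

// Sentinel meaning "the user passed no offsets at all", as opposed to
// ShapedType::kDynamic, which means an offset was passed as an SSA value.
static constexpr int64_t optionalValue = std::numeric_limits<int64_t>::max();

// In getStaticOffsets():
setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, optionalValue));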
The parser and print code is supposed to be removed once we finish the transition that moves the offsets from the create_nd_tdesc definition to load_nd. So there is no plan to reuse it.
        staticOffsets /* const offsets */, {} /* empty const shape*/,
        {} /* empty const strides*/);
        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
        {} /* empty const shape*/, {} /* empty const strides*/);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Where are the new build methods implemented?
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>,
OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>,
added
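For illustration only (not the PR's actual code), a minimal sketch of what the no-offset memref overload could look like, assuming it forwards to the fully variadic generated build shown in the diff above:

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && "expecting a memref with static shape");
  // No offsets are recorded; shape and strides come from the static memref type,
  // so all variadic operands and optional const attributes stay empty.
  build(builder, state, tdesc, source,
        ValueRange({}) /* empty dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        DenseI64ArrayAttr() /* no const offsets */,
        DenseI64ArrayAttr() /* no const shape */,
        DenseI64ArrayAttr() /* no const strides */);
}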
ParseResult parseOptionalDynamicIndexList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
    SmallVectorImpl<Type> *valueTypes = nullptr,
    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
Why can we not reuse the parseDynamicIndexList method and avoid this? I see lots of logic replicated.
The offsets are provided as an optional bracketed list [], so we need to customize parseDynamicIndexList.
refactored and simplified the custom parser.
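For context, the custom parser has to accept both of the following forms (taken from the added tests); the bracketed offsets may be absent entirely:

%1 = xegpu.create_nd_tdesc %src[%x, %y] shape : [%h, %w] strides : [%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.create_nd_tdesc %src shape : [%h, %w] strides : [%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>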
  _odsPrinter << ' ';
  _odsPrinter << getSource();
I think we should avoid using TableGen-generated variable names like _odsPrinter and use something more readable like printer.
We don't expect this code to stay permanently. Keeping it the same as the TableGen-generated printer code helps debugging.
removed.
  if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
    if (parser.parseColon())
      return ::mlir::failure();
Wouldn't it be nicer to use a keyword for offsets as well? For the optional case it would be empty square brackets:
offsets : [], strides : [...], shapes: []
We are moving offsets to load_nd.
@@ -56,7 +56,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]] strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
Why not comma-separate them for readability?
shapes : [], strides : []
It could be a future enhancement.
added support.
LGTM, apart from a few points.
- Might be worth adapting the examples in the XeGPU_CreateNdDescOp description.
- The verifier never checks for ui64 that the shape/strides rank matches the tensor descriptor rank. The following IR is considered valid:
  %1 = xegpu.create_nd_tdesc %src[%x, %y, %c1] shape : [%h, %w, %x] strides : [%w, %c1, %y] : ui64 -> !xegpu.tensor_desc<8x16xf32>
  but if one removes the offsets:
  %1 = xegpu.create_nd_tdesc %src shape : [%h, %w, %x] strides : [%w, %c1, %y] : ui64 -> !xegpu.tensor_desc<8x16xf32>
  the "rank mismatch between sizes and offsets" error pops up. For optional offsets to work with ui64, if you decide to tie the offsets rank to the tdesc rank (rather than to sizes/strides), then the verifier should also check that the supplied shape/strides have a rank matching the tdesc (a sketch of such a check follows this list).
- (May be unrelated to the PR topic) For ui64, the description mandates that users supply sizes/strides, but the verifier never checks it, so the following code crashes instead of showing an error:
  %1 = xegpu.create_nd_tdesc %src : ui64 -> !xegpu.tensor_desc<8x16xf32>
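A hedged sketch of the kind of rank check requested in the second point above; this is illustrative only, not the PR's actual verifier code, with method names taken from the quoted diff and the OffsetSizeAndStrideOpInterface:

// In CreateNdDescOp::verify(): when the source is a plain address (ui64),
// shape/strides must be supplied, agree with each other, and cover the tdesc rank.
if (llvm::isa<IntegerType>(getSourceType())) {
  size_t shapeRank = getMixedSizes().size();
  size_t strideRank = getMixedStrides().size();
  if (shapeRank == 0 || strideRank == 0)
    return emitOpError("expecting shape and strides for an integer source");
  if (shapeRank != strideRank)
    return emitOpError("expecting shape and strides to have the same rank");
  if (getTensorDescShape().size() > shapeRank)
    return emitOpError("tensor descriptor rank exceeds the supplied shape rank");
}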
    rank = memrefType.getRank();
  } else {
    //nd_tdesc created from ui64, use nd_tdesc's rank
    rank = getTensorDescShape().size();
It is not safe to infer the rank from the TensorDesc, since the TensorDesc could have a lower rank than the offsets. You can simply use int rank = getStaticSizes().size() instead.
In this case, the user specifies neither const offsets nor dynamic offset values, so I assume we can only infer the rank from the TensorDesc. Not sure getStaticSizes() can give us the correct result.
fixed
✅ With the latest revision this PR passed the C/C++ code formatter.
It is not intended to tie the offsets rank to the tdesc rank. I just fixed it to tie to the shapes/strides of the input tensor.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  [[maybe_unused]] auto ty = source.getType();
I think [[maybe_unused]] is not needed.
ty is only used in the assert statement, which is compiled out of a release binary.
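A minimal sketch of the pattern under discussion (the assert message is illustrative):

// ty is only referenced inside the assert; [[maybe_unused]] avoids an
// unused-variable warning in NDEBUG (release) builds, where the assert compiles away.
[[maybe_unused]] auto ty = source.getType();
assert(ty.hasStaticShape() && "expecting a memref with static shape");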
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape());
Add some error text in the assert, like "expecting a memref with static shape".
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && strides.size() && shape.size() == strides.size());
Add some error text explaining why this invariant must be satisfied.
                           Type tdesc, TypedValue<IntegerType> source,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  assert(shape.size() && strides.size() && shape.size() == strides.size());
same here.
    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
        parser.parseRSquare())
      return parser.emitError(parser.getNameLoc())
             << "expected SSA value or integer";
<< "expected SSA value or integer"; | |
<< "expected a list of SSA values or integers"; |
|
||
    return success();
  };
  if (parser.parseOptionalLSquare().succeeded()) {
I assume that for the no-offset case this check will fail?
Example:
create_nd %src shape: [] strides: []
Please add a comment here, like "If the optional values are given, there must be a left bracket".
Yes. Added
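For reference, the guard in question with such a comment could read (a sketch based on the quoted parser code):

// The offsets are optional; if they are given, the list must start with '['.
if (parser.parseOptionalLSquare().succeeded()) {
  if (parser.parseCommaSeparatedList(parseIntegerOrValue) || parser.parseRSquare())
    return parser.emitError(parser.getNameLoc())
           << "expected a list of SSA values or integers";
  integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
  scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
  return success();
}
// No opening '[' at all: the offsets were simply omitted.
return success();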
    TypeRange valueTypes = TypeRange()) {

  if (values.empty() && llvm::all_of(integers, [](int64_t i) {
        return i == std::numeric_limits<int64_t>::max();
Please add a comment here to explain that we use placeholder values to carry the optional offsets.
nit: I would still prefer if this constant were defined somewhere and properly documented. It will make life easier for future changes.
It is temporary. I added a comment.
Maybe we can remove the valueTypes and related logic, since it is not used.
      printer << values[dynamicValIdx];
      if (!valueTypes.empty())
        printer << " : " << valueTypes[dynamicValIdx];
Do values and valueTypes always have the same size?
It seems correct from the test cases. This interface is automatically generated by the parser, which should guarantee it.
verifier and invalid test case added.
LGTM w/ some nit comments
    rank = memrefType.getRank();
  else
    //nd_tdesc created from ui64, use nd_tdesc's rank
    rank = getMixedSizes().size();
nit: it is not the nd_tdesc's rank, it is the rank of the shape. It is not necessary to differentiate between memrefType and the other case, since it is handled in getMixedSizes(), an abstraction interface that returns the shape of the memory regardless of whether it is specified by a MemRefType or via the shape/stride parameters.
Could you also update the PR description?
The second case's IR still uses offsets
fixed. thanks!
Thanks for all the extra iterations👍
Now effectively there should be no semantic changes to the op itself.
…etch_nd (#149424) This PR allows load_nd/store_nd/prefetch_nd to take an additional offset operand. It is based on this PR llvm/llvm-project#148335. Now user can create a nd_tdesc with no offset, and instead set the offset with the load_nd operation.
…49424) This PR allows load_nd/store_nd/prefetch_nd to take an additional offset operand. It is based on this PR llvm/llvm-project#148335. Now user can create a nd_tdesc with no offset, and instead set the offset with the load_nd operation.
…vm#149424) This PR allows load_nd/store_nd/prefetch_nd to take an additional offset operand. It is based on this PR llvm#148335. Now user can create a nd_tdesc with no offset, and instead set the offset with the load_nd operation.
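For context, the end state targeted by the follow-up PR is to create the descriptor without offsets and supply them at the access itself. A hypothetical sketch of that combined usage (the load_nd offset syntax here is an assumption, not taken from this PR):

%tdesc = xegpu.create_nd_tdesc %src : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
%v = xegpu.load_nd %tdesc[%x, %y] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>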