-
Notifications
You must be signed in to change notification settings - Fork 15k
[MLIR][XeGPU] make offsets optional for create_nd_tdesc #148335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b9a6d98
2465050
1077871
42baa22
204d347
2793c81
f23ea03
0bb958b
6793689
4a96c71
689a8a5
02d3795
01718f4
5ef6ca9
26a222d
882313f
456534a
b6f016e
cd518d2
546a3f7
7846955
ded9552
97b6e39
ed1d48e
b3edff6
d3e935b
205fea7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
#include "mlir/Dialect/XeGPU/IR/XeGPU.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "mlir/Interfaces/ViewLikeInterface.h" | ||
|
||
#include "llvm/Support/Debug.h" | ||
|
||
|
@@ -112,6 +113,68 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, | |
//===----------------------------------------------------------------------===// | ||
// XeGPU_CreateNdDescOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||
Type tdesc, TypedValue<MemRefType> source) { | ||
[[maybe_unused]] auto ty = source.getType(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ty is only used in the assert statement which is unused in release binary. |
||
assert(ty.hasStaticShape() && "expecting a memref with static shape"); | ||
|
||
build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */, | ||
ValueRange({}) /* empty dynamic shape */, | ||
ValueRange({}) /* empty dynamic strides */, | ||
DenseI64ArrayAttr({}) /* const offsets */, | ||
DenseI64ArrayAttr({}) /* empty const shape*/, | ||
DenseI64ArrayAttr({}) /* empty const strides*/); | ||
} | ||
|
||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||
Type tdesc, TypedValue<MemRefType> source, | ||
llvm::ArrayRef<OpFoldResult> shape, | ||
llvm::ArrayRef<OpFoldResult> strides) { | ||
assert(shape.size() && strides.size() && shape.size() == strides.size() && | ||
"Shape and strides must be present and of equal size for ui64 " | ||
"initialization."); | ||
|
||
llvm::SmallVector<int64_t> staticShape; | ||
llvm::SmallVector<int64_t> staticStrides; | ||
llvm::SmallVector<Value> dynamicShape; | ||
llvm::SmallVector<Value> dynamicStrides; | ||
|
||
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); | ||
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); | ||
|
||
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); | ||
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); | ||
|
||
build(builder, state, tdesc, source, ValueRange({}), dynamicShape, | ||
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, | ||
staticStridesAttr); | ||
} | ||
|
||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||
Type tdesc, TypedValue<IntegerType> source, | ||
llvm::ArrayRef<OpFoldResult> shape, | ||
llvm::ArrayRef<OpFoldResult> strides) { | ||
assert(shape.size() && strides.size() && shape.size() == strides.size() && | ||
"Shape and strides must be present and of equal size for ui64 " | ||
"initialization."); | ||
|
||
llvm::SmallVector<int64_t> staticShape; | ||
llvm::SmallVector<int64_t> staticStrides; | ||
llvm::SmallVector<Value> dynamicShape; | ||
llvm::SmallVector<Value> dynamicStrides; | ||
|
||
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); | ||
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); | ||
|
||
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); | ||
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); | ||
|
||
build(builder, state, tdesc, source, ValueRange({}), dynamicShape, | ||
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, | ||
staticStridesAttr); | ||
} | ||
|
||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||
Type tdesc, TypedValue<MemRefType> source, | ||
llvm::ArrayRef<OpFoldResult> offsets) { | ||
|
@@ -125,8 +188,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | |
build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */, | ||
ValueRange({}) /* empty dynamic shape */, | ||
ValueRange({}) /* empty dynamic strides */, | ||
staticOffsets /* const offsets */, {} /* empty const shape*/, | ||
{} /* empty const strides*/); | ||
builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */, | ||
{} /* empty const shape*/, {} /* empty const strides*/); | ||
} | ||
|
||
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. where is the new build methods implemented?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||
|
@@ -197,6 +260,13 @@ LogicalResult CreateNdDescOp::verify() { | |
invalidElemTy |= memrefTy.getElementType() != getElementType(); | ||
} | ||
|
||
if (llvm::isa<IntegerType>(getSourceType())) { | ||
// strides and shape must present for integer source. | ||
if (getMixedStrides().empty() || getMixedSizes().empty()) | ||
return emitOpError("Expecting strides and shape to be present for " | ||
"integer source."); | ||
} | ||
|
||
// mismatches among shape, strides, and offsets are | ||
// already handeled by OffsetSizeAndStrideOpInterface. | ||
// So they are not check here. | ||
|
@@ -221,6 +291,53 @@ LogicalResult CreateNdDescOp::verify() { | |
return success(); | ||
} | ||
|
||
ParseResult parseOptionalDynamicIndexList( | ||
OpAsmParser &parser, | ||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, | ||
DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr, | ||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { | ||
|
||
SmallVector<int64_t, 4> integerVals; | ||
auto parseIntegerOrValue = [&]() { | ||
OpAsmParser::UnresolvedOperand operand; | ||
auto res = parser.parseOptionalOperand(operand); | ||
|
||
if (res.has_value() && succeeded(res.value())) { | ||
values.push_back(operand); | ||
integerVals.push_back(ShapedType::kDynamic); | ||
if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) | ||
return failure(); | ||
} else { | ||
int64_t integer; | ||
if (failed(parser.parseInteger(integer))) | ||
return failure(); | ||
integerVals.push_back(integer); | ||
} | ||
return success(); | ||
}; | ||
|
||
// If the optional values are given there must be left bracket | ||
if (parser.parseOptionalLSquare().succeeded()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assume that for no-offset case this check will fail?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add a comment here like "If the optional values are given there must be left bracket" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. Added |
||
if (parser.parseCommaSeparatedList(parseIntegerOrValue) || | ||
parser.parseRSquare()) | ||
return parser.emitError(parser.getNameLoc()) | ||
<< "expected a list of SSA values or integers"; | ||
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); | ||
return success(); | ||
} | ||
|
||
return success(); | ||
} | ||
|
||
void printOptionalDynamicIndexList( | ||
OpAsmPrinter &printer, Operation *op, OperandRange values, | ||
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(), | ||
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { | ||
|
||
return printDynamicIndexList(printer, op, values, integers, | ||
/*scalableFlags=*/{}, valueTypes, delimiter); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// XeGPU_PrefetchNdOp | ||
//===----------------------------------------------------------------------===// | ||
|
Uh oh!
There was an error while loading. Please reload this page.