Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
Expand All @@ -23,6 +24,7 @@
namespace mlir {
namespace xegpu {
class TensorDescType;
class DistributeLayoutAttrInterface;
class LayoutAttr;
class SliceAttr;
} // namespace xegpu
Expand Down
45 changes: 41 additions & 4 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -175,22 +175,31 @@ def XeGPU_FenceScopeAttr:
let assemblyFormat = "$value";
}

def LayoutTrait: AttrInterface<"LayoutTrait"> {
def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"> {
let cppNamespace = "::mlir::xegpu";
let description = [{
Common trait for all XeGPU layouts.
}];

let methods = [
InterfaceMethod<"Check the availability of workgroup level layouts",
"bool",
"isWgLayout">,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should rename this to isSgLayout as its confusing...the layout specifies how the subgroups are laid out, there is nothing like WgLayout

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed it to isForWorkgroup. sounds good to you?

InterfaceMethod<"Get the rank of attribute",
"int64_t",
"getRank">,
InterfaceMethod<"Get the num of effective subgroups",
"int64_t",
"getNumSubgroups">,
InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
"std::optional<SmallVector<int64_t>>",
"getSgLayoutAsInt">,
InterfaceMethod<"Get the SgData field of the attribute as integer array",
"std::optional<SmallVector<int64_t>>",
"getSgDataAsInt">,
InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
"xegpu::DistributeLayoutAttrInterface",
"dropSgLayoutAndData">,
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
indices based on the effective subgroup layout.}],
"FailureOr<SmallVector<Value>>",
Expand All @@ -206,7 +215,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
];
}

def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterface]> {
let summary = [{
Describes the data distribution to subgroups and work-items for a tensor
specified by the tensor descriptor.
Expand Down Expand Up @@ -346,6 +355,13 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
return 0;
}

int64_t getNumSubgroups() {
std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
if (sgLayout.has_value())
return computeProduct(*sgLayout);
return 0;
}

LayoutAttr dropSgLayoutAndData() {
// avoid every field of the attribute is nullptr, which may lead to segment fault
if (!getInstData() && !getLaneLayout())
Expand Down Expand Up @@ -393,7 +409,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
}


def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface]> {
let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];

let description = [{
Expand All @@ -420,7 +436,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
}];

let parameters = (ins
"xegpu::LayoutTrait": $parent,
"xegpu::DistributeLayoutAttrInterface": $parent,
"DenseI64ArrayAttr": $dims
);

Expand Down Expand Up @@ -450,6 +466,13 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
return parent.isSgLayout();
}

int64_t getNumSubgroups() {
std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
if (sgLayout.has_value())
return computeProduct(*sgLayout);
return 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks identical in both derived classes. maybe move to base class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

}

/// Returns the SgLayout of the attribute, computed by applying
/// the slice dimensions to the underlying LayoutAttr.
std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
Expand All @@ -474,6 +497,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
return std::nullopt;
}

SliceAttr dropSgLayoutAndData() {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
parent = parent.dropSgLayoutAndData();
return SliceAttr::get(getContext(), parent, attr.getDims());
}

SliceAttr dropInstData() {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
parent = parent.dropInstData();
return SliceAttr::get(getContext(), parent, attr.getDims());
}

/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
/// it will coalese two slice operations and return a simplified SliceAttr
Expand Down
12 changes: 8 additions & 4 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
return static_cast<unsigned>(MemorySpace::Global);
}

xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getType().getLayout());
}

}];
}

Expand Down Expand Up @@ -1150,7 +1154,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<LayoutTrait>:$layout
OptionalAttr<DistributeLayoutAttrInterface>:$layout
);
let results = (outs XeGPU_ValueType:$res);
let assemblyFormat = [{
Expand All @@ -1175,7 +1179,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,

let builders = [
OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
"llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
"llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
];
let extraClassDeclaration = [{
SmallVector<OpFoldResult> getMixedOffsets() {
Expand All @@ -1194,7 +1198,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<LayoutTrait>:$layout
OptionalAttr<DistributeLayoutAttrInterface>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands)}];
Expand All @@ -1213,7 +1217,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}];
let builders = [
OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
"llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
"llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
];
let extraClassDeclaration = [{
SmallVector<OpFoldResult> getMixedOffsets() {
Expand Down
15 changes: 9 additions & 6 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,9 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
return affine::delinearizeIndex(builder, loc, linearId, dims);
}

/// Implements LayoutTrait::getOffsets to generate instructions for
/// computing multi-dimensional offsets when distributed by LayoutAttr.
/// Implements DistributeLayoutAttrInterface::getOffsets to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
ArrayRef<int64_t> shape) {
Expand Down Expand Up @@ -322,7 +323,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
//===----------------------------------------------------------------------===//
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
xegpu::DistributeLayoutAttrInterface parent,
DenseI64ArrayAttr dims) {
if (!parent || !dims)
return emitError() << "expected parent layout and dims attribute";

Expand All @@ -340,7 +342,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
}

SliceAttr SliceAttr::flatten() const {
xegpu::LayoutTrait parent = getParent();
xegpu::DistributeLayoutAttrInterface parent = getParent();
SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});

while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
Expand Down Expand Up @@ -375,8 +377,9 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
return parent.delinearizeSubgroupId(builder, loc, linearId);
}

/// Implements LayoutTrait::getOffsets to generate instructions for
/// computing multi-dimensional offsets when distributed by SliceAttr.
/// Implements DistributeLayoutAttrInterface::getOffsets to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
ArrayRef<int64_t> shape) {
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
TypedValue<MemDescType> memDesc,
llvm::ArrayRef<OpFoldResult> offsets,
LayoutTrait layout) {
DistributeLayoutAttrInterface layout) {
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
Expand Down Expand Up @@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
TypedValue<MemDescType> memDesc,
llvm::ArrayRef<OpFoldResult> offsets,
LayoutTrait layout) {
DistributeLayoutAttrInterface layout) {
llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
Expand Down
Loading