
Conversation

@chencha3
Contributor

As described by the title.

@chencha3 chencha3 requested review from adam-smnk and charithaintc and removed request for charithaintc August 19, 2025 18:47
@chencha3 chencha3 marked this pull request as ready for review August 19, 2025 18:47
@llvmbot
Member

llvmbot commented Aug 19, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Chao Chen (chencha3)

Changes

As described by the title.


Patch is 33.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154403.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h (+2)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+41-4)
  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+8-4)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+9-6)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+180-76)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+6-2)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+74-2)
  • (modified) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+2-2)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 3592da4c46364..1d152f0c9ca9a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -23,6 +24,7 @@
 namespace mlir {
 namespace xegpu {
 class TensorDescType;
+class DistributeLayoutAttrInterface;
 class LayoutAttr;
 class SliceAttr;
 } // namespace xegpu
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index a94987885c9e0..de86141ad006a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,22 +175,31 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
-def LayoutTrait: AttrInterface<"LayoutTrait"> {
+def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"> {
   let cppNamespace = "::mlir::xegpu";
   let description = [{
     Common trait for all XeGPU layouts.
   }];
 
   let methods = [
+    InterfaceMethod<"Check the availability of workgroup level layouts",
+                    "bool",
+                    "isWgLayout">,
     InterfaceMethod<"Get the rank of attribute",
                     "int64_t",
                     "getRank">,
+    InterfaceMethod<"Get the num of effective subgroups",
+                    "int64_t",
+                    "getNumSubgroups">,
     InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgLayoutAsInt">,
     InterfaceMethod<"Get the SgData field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgDataAsInt">,
+    InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
+                    "xegpu::DistributeLayoutAttrInterface",
+                    "dropSgLayoutAndData">,
     InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
                       indices based on the effective subgroup layout.}],
                     "FailureOr<SmallVector<Value>>",
@@ -206,7 +215,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
   ];
 }
 
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterface]> {
   let summary = [{
     Describes the data distribution to subgroups and work-items for a tensor
     specified by the tensor descriptor.
@@ -346,6 +355,13 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
       return 0;
     }
 
+    int64_t getNumSubgroups() {
+      std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
+      if (sgLayout.has_value())
+        return computeProduct(*sgLayout);
+      return 0;
+    }
+
     LayoutAttr dropSgLayoutAndData() {
       // avoid every field of the attribute is nullptr, which may lead to segment fault
       if (!getInstData() && !getLaneLayout())
@@ -393,7 +409,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
 }
 
 
-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface]> {
   let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
 
   let description = [{
@@ -420,7 +436,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
   }];
 
   let parameters = (ins
-    "xegpu::LayoutTrait": $parent,
+    "xegpu::DistributeLayoutAttrInterface": $parent,
     "DenseI64ArrayAttr": $dims
   );
 
@@ -450,6 +466,13 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
       return parent.isSgLayout();
     }
 
+    int64_t getNumSubgroups() {
+      std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
+      if (sgLayout.has_value())
+        return computeProduct(*sgLayout);
+      return 0;
+    }
+
     /// Returns the SgLayout of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
     std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
@@ -474,6 +497,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
       return std::nullopt;
     }
 
+    SliceAttr dropSgLayoutAndData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropSgLayoutAndData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
+    SliceAttr dropInstData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropInstData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
     /// it will coalese two slice operations and return a simplified SliceAttr
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index eb54d6887681d..3ba9eaa4a66da 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -232,6 +232,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return static_cast<unsigned>(MemorySpace::Global);
     }
 
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getType().getLayout());
+    }
+
   }];
 }
 
@@ -1150,7 +1154,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<LayoutTrait>:$layout
+    OptionalAttr<DistributeLayoutAttrInterface>:$layout
   );
   let results = (outs XeGPU_ValueType:$res);
   let assemblyFormat = [{
@@ -1175,7 +1179,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 
   let builders = [
     OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
-                    "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+                    "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
@@ -1194,7 +1198,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<LayoutTrait>:$layout
+    OptionalAttr<DistributeLayoutAttrInterface>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,7 +1217,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   }];
   let builders = [
     OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+                   "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 8ea8cb1f45972..de118b7faea4d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -290,8 +290,9 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                        ArrayRef<int64_t> shape) {
@@ -322,7 +323,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
 //===----------------------------------------------------------------------===//
 LogicalResult
 SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
-                  xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+                  xegpu::DistributeLayoutAttrInterface parent,
+                  DenseI64ArrayAttr dims) {
   if (!parent || !dims)
     return emitError() << "expected parent layout and dims attribute";
 
@@ -340,7 +342,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 }
 
 SliceAttr SliceAttr::flatten() const {
-  xegpu::LayoutTrait parent = getParent();
+  xegpu::DistributeLayoutAttrInterface parent = getParent();
   SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
 
   while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
@@ -375,8 +377,9 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return parent.delinearizeSubgroupId(builder, loc, linearId);
 }
 
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by SliceAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                       ArrayRef<int64_t> shape) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 906c71d8b8dad..0e22af900daf1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                          TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
-                         LayoutTrait layout) {
+                         DistributeLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
 void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
                           TypedValue<MemDescType> memDesc,
                           llvm::ArrayRef<OpFoldResult> offsets,
-                          LayoutTrait layout) {
+                          DistributeLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8f1208e77ca5d..ca1209e776d0e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -55,17 +55,16 @@ static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
 }
 
 static std::pair<SmallVector<int64_t>, int>
-getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+getSgShapeAndCount(ArrayRef<int64_t> shape,
+                   xegpu::DistributeLayoutAttrInterface layout) {
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
-
   if (layout && layout.isWgLayout()) {
-    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
-    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    else
-      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+    SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt().value();
+    if (auto maybeSgData = layout.getSgDataAsInt())
+      sgShape = *maybeSgData;
+    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
+      sgShape = *maybeDerivedSgData;
     SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
     // Clamp distUnit to the original shape to handle cases where data is
     // shared among subgroups, which may cause distUnit to exceed the original
@@ -77,6 +76,72 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   return std::make_pair(sgShape, count);
 }
 
+// An util helper to generate elementwise addition ops for index computing.
+// lhs and rhs are vectors of Values. If the rank of lhs and rhs doesn't match.
+// left-alignment is performed.
+static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
+                                     Location loc, ArrayRef<OpFoldResult> lhs,
+                                     ArrayRef<OpFoldResult> rhs) {
+  SmallVector<OpFoldResult> reversedResult;
+  auto l = lhs.rbegin();
+  auto r = rhs.rbegin();
+  for (; l != lhs.rend() || r != rhs.rend(); ++l, ++r) {
+    if (l == lhs.rend()) {
+      reversedResult.push_back(*r);
+    } else if (r == rhs.rend()) {
+      reversedResult.push_back(*l);
+    } else {
+      auto lval = getValueOrCreateConstantIndexOp(rewriter, loc, *l);
+      auto rval = getValueOrCreateConstantIndexOp(rewriter, loc, *r);
+      auto add = rewriter.createOrFold<index::AddOp>(loc, lval, rval);
+      reversedResult.push_back(add);
+    }
+  }
+  return llvm::to_vector(llvm::reverse(reversedResult));
+}
+
+// A callback funtion type used to create new load/store_matrix ops
+using CreatorFuncType =
+    llvm::function_ref<void(ArrayRef<OpFoldResult> baseOffsets,
+                            SmallVector<SmallVector<Value>> &descOffsets)>;
+
+/// Utility helper for distributing logic shared by operations with offsets
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::CreateNdDescOp, xegpu::LoadMatrixOp,
+              xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+distributeOp(ConversionPatternRewriter &rewriter,
+             typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor,
+             OpType op, ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
+  Location loc = op.getLoc();
+  auto layout = op.getLayoutAttr();
+  if (!layout || !layout.isWgLayout())
+    return failure();
+
+  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+  // adjust the linearId if the range specifier is present
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+  if (sgIdRangeSpecified) {
+    if (layout.getNumSubgroups() != endOfRange - startOfRange)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
+  }
+
+  auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  if (failed(maybeMdescOffsets))
+    return failure();
+
+  SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+  callback(wgOffsets, *maybeMdescOffsets);
+  return success();
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -137,71 +202,35 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
-    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-    if (!layout)
-      return failure();
-    Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
-    // sgLayout must be present for workgroup-level distribution.
-    SmallVector<int64_t> sgLayout;
-    if (auto sgLayoutAttr = layout.getSgLayout())
-      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    else
-      return rewriter.notifyMatchFailure(
-          op, "sgLayout attribute is required in layout");
-
-    // Get the subgroup ID
-    Value linearSgId =
-        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-    int64_t startOfRange = -1, endOfRange = -1;
-    bool sgIdRangeSpecified =
-        isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
-    if (sgIdRangeSpecified) {
-      int64_t sgCount = endOfRange - startOfRange;
-      if (computeProduct(sgLayout) != sgCount)
-        return rewriter.notifyMatchFailure(
-            op, "sg_layout size must match the sg_id_range");
-      // Subtract startOfRange from the original subgroup id to get
-      // the adjusted sg id
-      Value startOfRangeVal =
-          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
-      linearSgId =
-          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
-    }
-
-    auto maybeTdescOffsets =
-        layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-    if (failed(maybeTdescOffsets))
-      return failure();
-
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-    xegpu::TensorDescType newTdescTy =
-        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
-                                   layout.dropSgLayoutAndData());
+    Type elemTy = tdescTy.getElementType();
 
-    SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
-
-    for (auto tdescOffsets : *maybeTdescOffsets) {
-      SmallVector<OpFoldResult> sgOffsets;
-      size_t rank = tdescOffsets.size();
-      for (size_t i = 0; i < rank; i++) {
-        size_t idx = origOffsets.size() - rank + i;
-        Value add = rewriter.createOrFold<index::AddOp>(
-            loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
-        sgOffsets.push_back(add);
+    // the call back function for creating new CreateNdOps,
+    // the baseOffsets is the origial offsets of the op, and
+    // descOffsets is the relative offsets to the mem_desc accessed
+    // by each subgroup op.
+    auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
+                        SmallVector<SmallVector<Value>> descOffsets) {
+      xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+      SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+      auto newTdescTy = xegpu::TensorDescType::get(
+          ctx, sgShape, elemTy, tdescTy.getEncoding(),
+          layout.dropSgLayoutAndData());
+
+      SmallVector<Value> newOps;
+      for (auto offsets : descOffsets) {
+        SmallVector<OpFoldResult> sgOffsets =
+            add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets));
+        auto newOp = xegpu::CreateNdDescOp::create(
+            rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
+            op.getMixedSizes(), op.getMixedStrides());
+
+        newOps.push_back(newOp);
       }
+      rewriter.replaceOpWithMultiple(op, {newOps});
+    };
 
-      auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
-          op.getMixedSizes(), op.getMixedStrides());
-      newCreateNdOps.push_back(newOp);
-    }
-    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
-    return success();
+    return distributeOp(rewriter, adaptor, op, wgShape, callback);
   }
 };
 
@@ -723,8 +752,8 @@...
[truncated]


// #b = #xegpu.layout<inst_data = [8, 16]>
// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>

Why matrix_desc here?

Contributor Author

Fixed. I missed this when addressing conflicts during merging.

@Garra1980

cc @akroviakov

let methods = [
InterfaceMethod<"Check the availability of workgroup level layouts",
"bool",
"isWgLayout">,
Contributor

We should rename this to isSgLayout, as it's confusing... the layout specifies how the subgroups are laid out; there is nothing like a WgLayout.

Contributor

+1

Contributor Author

Renamed it to isForWorkgroup. Sounds good to you?

typename = std::enable_if_t<llvm::is_one_of<
OpType, xegpu::CreateNdDescOp, xegpu::LoadMatrixOp,
xegpu::StoreMatrixOp>::value>>
static LogicalResult
Contributor

I think you can reuse the getSgOffsets utility for distributeOp... most of the logic is the same.

Contributor Author

I refactored and reused it.

// the baseOffsets is the origial offsets of the op, and
// descOffsets is the relative offsets to the mem_desc accessed
// by each subgroup op.
auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
Contributor

We would want to avoid callbacks and make the source more readable. I think Jian Hui did not like the idea of callbacks and was in favor of calling utility functions from the pattern... take a look at the scatter ops patterns.

Contributor Author

Where are the scatter ops patterns?

Contributor

Oh sorry, not scatter ops; look at the Load/StoreNdOpWithOffsets patterns: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp#L395

Contributor

Rename it to something more readable, like createNdDistributionCallBack, if we are using callbacks.

Contributor Author

I refactored and removed the callback implementation per the suggestion.

// An util helper to generate elementwise addition ops for index computing.
// lhs and rhs are vectors of Values. If the rank of lhs and rhs doesn't match.
// left-alignment is performed.
static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
Contributor

This util seems very verbose... why are we iterating in reverse order?

Contributor Author

In some cases, e.g., the old create_nd_desc, the rank of the offsets (or of the memref) could be higher than the rank of the created tensor_desc (I am not sure this situation still exists under the new semantics). In such cases, the local offsets from the tensor_desc have fewer dims than the offsets into the memref, and the addition is applied to the innermost dims only.

Contributor Author

Reimplemented it without the reverse order.

Contributor

I would rename this function to something more meaningful than add. Or you could call the function with .drop_front to make the lengths equal. Right now, the function does more than its name suggests.
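A minimal sketch of that suggestion (hypothetical name addInnermost; assumes lhs.size() >= rhs.size(), matching the left-alignment described above, and reuses the same MLIR helpers the patch already calls):

// Left-aligned elementwise add: the extra outermost dims of lhs pass
// through unchanged; only the innermost dims are added pairwise.
static SmallVector<OpFoldResult> addInnermost(OpBuilder &b, Location loc,
                                              ArrayRef<OpFoldResult> lhs,
                                              ArrayRef<OpFoldResult> rhs) {
  size_t extra = lhs.size() - rhs.size();
  // Copy the untouched leading dims of lhs directly into the result.
  SmallVector<OpFoldResult> result(lhs.begin(), lhs.begin() + extra);
  for (auto [l, r] : llvm::zip(lhs.drop_front(extra), rhs)) {
    Value lv = getValueOrCreateConstantIndexOp(b, loc, l);
    Value rv = getValueOrCreateConstantIndexOp(b, loc, r);
    result.push_back(b.createOrFold<index::AddOp>(loc, lv, rv));
  }
  return result;
}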


// CHECK-LABEL: distribute_load_matrix
// CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
Contributor

@nbpatel nbpatel Aug 19, 2025

Please move the tests to xegpu-wg-to-sg-unify-ops.mlir and xegpu-wg-to-sg-unify-ops-rr.mlir, as they have tests with the new versions of create_nd/load_nd/store_nd... this file will be deleted in the future once we move away from create_nd with offsets.

Contributor Author

Thanks, good to know.

sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
else
sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt().value();
Contributor

Why is sgLayout optional to begin with?

Contributor

If it is hidden by the interface, so that we cannot be sure whether the underlying attribute actually has it, shouldn't there be some explicit assertion that layout.getSgLayoutAsInt() has a value?

Contributor Author

By definition, every field of LayoutAttr is optional, because it is used to describe different levels of code (workgroup and subgroup). Specific to this util, the input may be subgroup code, where sgShape is the tensor size. getSgLayoutAsInt itself returns a std::optional; it returns nullopt if sgLayout is missing. Since we check for a workgroup-level layout first, which confirms the presence of sgLayout, it is guaranteed to return a valid value.
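To make the invariant explicit, a small standalone sketch (hypothetical helper, mirroring the guard in getSgShapeAndCount; the assertions spell out the concern raised above):

// Illustration only: unwrap sgLayout after confirming a workgroup-level
// layout, so the dereference below cannot fail on this path.
static SmallVector<int64_t>
getGuardedSgLayout(xegpu::DistributeLayoutAttrInterface layout) {
  assert(layout && layout.isWgLayout() && "expected a workgroup-level layout");
  std::optional<SmallVector<int64_t>> sgLayout = layout.getSgLayoutAsInt();
  assert(sgLayout.has_value() && "wg layout implies sg_layout is present");
  return *sgLayout;
}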

// adjust the linearId if the range specifier is present
int64_t startOfRange = -1, endOfRange = -1;
bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
if (sgIdRangeSpecified) {
Contributor

sgIdRangeSpecified should be used carefully. It is false when startOfRange is 0, so an sg_id_range of [0, 4] is considered valid for a [1, 2] sg_layout by this check. How about adding tests with an sg range?

Contributor Author

@chencha3 chencha3 Aug 20, 2025

I reused this code from upstream; it may already be verified. For your example, it should be filtered out by the second check, if I understand it correctly (line 129). Is this your concern?

Contributor

For your example, it should be filtered out by the second check

It won't, because sgIdRangeSpecified is false for the range [0,4], so none of the checks inside if (sgIdRangeSpecified) actually happen.
That is, startOfRange is 0 and endOfRange is 4, but sgIdRangeSpecified is false, so we never reach the if (layout.getNumSubgroups() != endOfRange - startOfRange) check that should trigger the failure for a [1,2] sg_layout.
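A short trace of the failure mode (hypothetical, assuming isSgIdRangeSpecified returns false whenever startOfRange is 0; the checked code is quoted from the patch above):

// sg_id_range = [0, 4], sg_layout = [1, 2]  =>  2 subgroups, range size 4.
int64_t startOfRange = 0, endOfRange = 4;
bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
// Assumed behavior under discussion: startOfRange == 0 yields false here,
// so the mismatch check below is skipped and the invalid IR slips through.
if (sgIdRangeSpecified) {                       // false -> body never runs
  if (layout.getNumSubgroups() != endOfRange - startOfRange) // 2 != 4
    return rewriter.notifyMatchFailure(
        op, "sg_layout size must match the sg_id_range");
}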

Contributor

I reused this code from upstream

The issue has been known for some time, correct me if I'm wrong @nbpatel

Contributor Author

Ah, I got your point now: the attribute itself is not valid against the sgLayout. Our attributes are loosely attached to Ops/Values (except TensorDesc), so I feel we may need a systematic way to verify the validity of the IR, instead of checking it in rewrite patterns, which may not be scalable. To me, it would be better for rewrite patterns to only check whether the id needs to be adjusted or not.

Contributor

I reused this code from upstream

The issue has been known for some time, correct me if I'm wrong @nbpatel

Yes, the check has to happen outside. I can clean it up in all the other places as well in a follow-up PR.

Contributor

instead of checking it in rewrite patterns

But only this particular pattern cares about the sg_id_range and sg_layout.

I feel we may need a systematic way to verify the validity of the IR

Totally agree; there are many entry points to XeGPU that serve specific usages, and without a structured description/verification, a new user may simply get confused as to how their desired IR is even supposed to look.

Contributor

Alright, then it is fine for this PR, but it should be fixed or refactored very soon to avoid high exposure of this issue (already 6 ops affected).

Contributor Author

Agree. I refactored the related code. But I still think we need a systematic approach to verify the validity of attributes and IR.

Contributor

@charithaintc charithaintc left a comment

LGTM.

let methods = [
InterfaceMethod<"Check the availability of workgroup level layouts",
"bool",
"isWgLayout">,
Contributor

+1

Comment on lines 469 to 473
int64_t getNumSubgroups() {
std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
if (sgLayout.has_value())
return computeProduct(*sgLayout);
return 0;
Contributor

This looks identical in both derived classes. Maybe move it to the base class?
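For instance, since both attributes expose getSgLayoutAsInt() through DistributeLayoutAttrInterface, the body only needs to exist once; a sketch as a free function (not necessarily how the final patch does it):

// Hypothetical shared form of the duplicated method; computeProduct comes
// from mlir/Dialect/Utils/IndexingUtils.h, which the patch already includes.
static int64_t getNumSubgroups(xegpu::DistributeLayoutAttrInterface layout) {
  if (std::optional<SmallVector<int64_t>> sgLayout = layout.getSgLayoutAsInt())
    return computeProduct(*sgLayout);
  return 0;
}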

Contributor Author

Fixed.

// An util helper to generate elementwise addition ops for index computing.
// lhs and rhs are vectors of Values. If the rank of lhs and rhs doesn't match.
// left-alignment is performed.
static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
Contributor

I would rename this function to something more meaningful than add. Or you could call the function with .drop_front to make the lengths equal. Right now, the function does more than its name suggests.

// the baseOffsets is the origial offsets of the op, and
// descOffsets is the relative offsets to the mem_desc accessed
// by each subgroup op.
auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
Contributor

Rename it to something more readable, like createNdDistributionCallBack, if we are using callbacks.

// the baseOffsets is the origial offsets of the op, and
// descOffsets is the relative offsets to the mem_desc accessed
// by each subgroup op.
auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
Contributor

Same as above; it could use a better name.

Contributor Author

@chencha3 chencha3 Aug 20, 2025

I removed the callback-style implementation and changed it to the regular style, and also renamed add to genIndexAdd.

@github-actions

github-actions bot commented Aug 20, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@nbpatel nbpatel left a comment

LGTM % some nits


// Compute the final global offsets for each accessed sub-tensor
// or sub-memory descriptor.
// SmallVector<SmallVector<OpFoldResult>> offsetsList;
Contributor

Please remove this commented-out code.

Contributor Author

Fixed, thanks

sgOffsets.push_back(add);
}

SmallVector<Value> newOps;
Contributor

I'd suggest we use the op name in the variable here and everywhere else:
newOps -> newCreateNdOps

return getMixedValues(getConstOffsets(), getOffsets(), getContext());
}

ArrayRef<int64_t> getDistributeShape() {
Contributor

Do we need to mention distribution in the op's method names? I guess this utility could be reused elsewhere as well, so getShape would be good enough.

Contributor Author

@chencha3 chencha3 Aug 21, 2025

Good idea, I thought about this name too. Actually, getShape is not available because CreateNdOp has a shape parameter; does getDataShape sound good to you?

@chencha3 chencha3 force-pushed the ld_st_matrix_wg_to_sg branch from cfa5193 to af6f83f Compare August 21, 2025 14:48
@chencha3 chencha3 merged commit 68d6866 into llvm:main Aug 21, 2025
9 checks passed