Skip to content

Conversation

@akroviakov
Copy link
Contributor

Currently, we have a circular dependency.
Blocking needs inst_data (which also attaches lane_layout, lane_data), but lane_layout, lane_data need blocking.
We decouple these layout fields in the propagation and make propagation multi-step.
As in

-xegpu-propagate-layout="layout-kind=inst" -xegpu-blocking -xegpu-propagate-layout="layout-kind=lane" -xegpu-subgroup-distribute

@llvmbot
Copy link
Member

llvmbot commented Nov 7, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

Currently, we have a circular dependency.
Blocking needs inst_data (which also attaches lane_layout, lane_data), but lane_layout, lane_data need blocking.
We decouple these layout fields in the propagation and make propagation multi-step.
As in

-xegpu-propagate-layout="layout-kind=inst" -xegpu-blocking -xegpu-propagate-layout="layout-kind=lane" -xegpu-subgroup-distribute

Patch is 69.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166941.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+11)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+120-68)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+21-21)
  • (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+1-579)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 9c35c07a7e587..3f27d690f949b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -379,6 +379,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
   );
 
   let builders = [
+    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto order = DenseI32ArrayAttr();
+        auto lane_layout = DenseI32ArrayAttr();
+        auto lane_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data,
+                     DenseI32ArrayAttr::get($_ctxt, inst_data),
+                     lane_layout, lane_data, order);
+      }]>,
     AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
                       "llvm::ArrayRef<int32_t>": $lane_layout,
                      "llvm::ArrayRef<int32_t>": $lane_data),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4e1a539771d2f..c00b08e0ca37f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -53,6 +53,8 @@ using namespace mlir::dataflow;
 
 namespace {
 
+enum class LayoutKind { Lane, InstData };
+
 //===----------------------------------------------------------------------===//
 // LayoutInfo
 //===----------------------------------------------------------------------===//
@@ -99,7 +101,8 @@ struct LayoutInfo {
 
   bool isAssigned() const { return storage != nullptr; }
 
-  LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
+  LayoutInfo transpose(ArrayRef<int64_t> permutation,
+                       LayoutKind layoutKind) const;
 
   SmallVector<int> getLaneLayout() const;
 
@@ -167,7 +170,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
 }
 
 /// Construct a new layout with the transposed lane layout and lane data.
-LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
+LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation,
+                                 LayoutKind layoutKind) const {
   if (!isAssigned())
     return {};
   // Check if the permutation is valid.
@@ -186,12 +190,20 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
   SmallVector<int32_t> laneData;
   SmallVector<int32_t> instData;
   for (int64_t idx : permutation) {
-    laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
-    laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
-    instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+    if (layoutKind == LayoutKind::Lane) {
+      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
+      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+    } else if (layoutKind == LayoutKind::InstData)
+      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+  }
+  xegpu::LayoutAttr layoutAttr;
+  if (layoutKind == LayoutKind::Lane) {
+    layoutAttr =
+        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
+  } else if (layoutKind == LayoutKind::InstData) {
+    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
   }
-  return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
-                                           laneLayout, laneData));
+  return LayoutInfo(layoutAttr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -213,15 +225,14 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
 static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                            unsigned rank,
-                                           const xegpu::uArch::uArch *uArch,
-                                           ArrayRef<int> instData) {
+                                           const xegpu::uArch::uArch *uArch) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1) {
     return LayoutInfo(
-        xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
+        xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
   }
-  return LayoutInfo(xegpu::LayoutAttr::get(
-      ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+  return LayoutInfo(
+      xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
 }
 
 static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
@@ -236,7 +247,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
 /// Helper to get the default layout for a vector type.
 static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                            const xegpu::uArch::uArch *uArch,
-                                           ArrayRef<int> instData,
                                            unsigned packingSize,
                                            bool isScattered = false) {
   // Expecting a 1D or 2D vector.
@@ -247,16 +257,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
   if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                              {uArch->getSubgroupSize(), 1},
                                              {1, packingFactor}));
   }
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                            {1, uArch->getSubgroupSize()},
                                            {1, packingFactor}));
 }
@@ -275,7 +285,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (tdescTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);
+    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
   int subgroupSize = uArch->getSubgroupSize();
@@ -298,7 +308,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
 static LayoutInfo
 getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                 const xegpu::uArch::uArch *uArch,
-                                ArrayRef<int> instData, unsigned packingSize) {
+                                unsigned packingSize) {
   Type elementTy = vectorTy.getElementType();
   assert(elementTy.isIntOrFloat() &&
          "Expected int or float type in DPAS operands");
@@ -310,10 +320,10 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
          1});
     return LayoutInfo(
-        xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));
+        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);
+  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -328,6 +338,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
 class LayoutInfoPropagation
     : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
+  LayoutKind layoutKind;
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
@@ -380,8 +391,10 @@ class LayoutInfoPropagation
 
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
-                        SymbolTableCollection &symbolTable)
-      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+                        SymbolTableCollection &symbolTable,
+                        LayoutKind layoutKind)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
+        layoutKind(layoutKind) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
   LogicalResult
@@ -627,14 +640,24 @@ void LayoutInfoPropagation::visitDpasOp(
   SmallVector<int> instDataA = {maxALen, subgroupSize};
   SmallVector<int> instDataB = {subgroupSize, maxBLen};
 
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
-                         aTy, 0, uArch, instDataA,
-                         uArchInstruction->getPackedFormatBitSizeA())));
-  propagateIfChanged(operands[1],
-                     operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
-                         bTy, 1, uArch, instDataB,
-                         uArchInstruction->getPackedFormatBitSizeB())));
+  LayoutInfo dpasALayout;
+  LayoutInfo dpasBLayout;
+  LayoutInfo dpasCLayout;
+
+  if (layoutKind == LayoutKind::InstData) {
+    dpasALayout =
+        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
+    dpasBLayout =
+        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+  } else {
+    dpasALayout = getSIMTLayoutInfoForDPASOperand(
+        aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
+    dpasBLayout = getSIMTLayoutInfoForDPASOperand(
+        bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+  }
+
+  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
+  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
   if (operands.size() > 2) {
     VectorType cTy = dpas.getAccType();
     const unsigned dataCLen = bTy.getShape().back();
@@ -645,10 +668,15 @@ void LayoutInfoPropagation::visitDpasOp(
       dpas.emitWarning(
           "No suitable instruction multiple found for the given shape.");
     SmallVector<int> instDataC = {maxALen, maxCLen};
-    propagateIfChanged(operands[2],
-                       operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
-                           cTy, 2, uArch, instDataC,
-                           uArchInstruction->getPackedFormatBitSizeB())));
+
+    if (layoutKind == LayoutKind::InstData)
+      dpasCLayout =
+          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
+    else
+      dpasCLayout = getSIMTLayoutInfoForDPASOperand(
+          cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+
+    propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout));
   }
 }
 
@@ -685,9 +713,15 @@ void LayoutInfoPropagation::visitStoreNdOp(
           "No suitable instruction multiple found for the given shape.");
     instData = {instHeight, instWidth};
   }
-  LayoutInfo storeLayout =
-      getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
-                               uArchInstruction->getPackedFormatBitSize());
+
+  LayoutInfo storeLayout;
+  if (layoutKind == LayoutKind::InstData)
+    storeLayout =
+        LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
+  else
+    storeLayout =
+        getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
+                                 uArchInstruction->getPackedFormatBitSize());
   // Both operands should have the same layout
   for (LayoutInfoLattice *operand : operands)
     propagateIfChanged(operand, operand->meet(storeLayout));
@@ -709,7 +743,7 @@ void LayoutInfoPropagation::visitLoadNdOp(
   if (auto transpose = load.getTranspose()) {
     load.emitWarning("Transpose effect is not expected for LoadNdOp at "
                      "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.transpose(transpose.value());
+    tensorDescLayout = valueLayout.transpose(transpose.value(), layoutKind);
   }
   // Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
@@ -724,7 +758,8 @@ void LayoutInfoPropagation::visitTransposeOp(
   LayoutInfo resultLayout = results[0]->getValue();
   if (!resultLayout.isAssigned())
     return;
-  LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
+  LayoutInfo newLayout =
+      resultLayout.transpose(transpose.getPermutation(), layoutKind);
   // Propagate the new layout to the vector operand.
   propagateIfChanged(operands[0], operands[0]->meet(newLayout));
 }
@@ -818,9 +853,13 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     if (srcTdescTy.getChunkSizeAsInt() > 1)
       instData.push_back(chunkSize);
   }
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(
-      payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
-      /*scattered*/ true);
+  LayoutInfo layout;
+  if (layoutKind == LayoutKind::InstData)
+    layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
+  else
+    layout = getDefaultSIMTLayoutInfo(payloadTy, uArch,
+                                      uArch->getGeneralPackedFormatBitSize(),
+                                      /*scattered*/ true);
 
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout =
@@ -864,33 +903,36 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
     return;
   }
+  LayoutInfo payloadLayout;
   auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
   const int subgroupSize = uArch->getSubgroupSize();
 
-  auto payloadShape = payloadTy.getShape();
-  if (payloadShape.size() > 1)
-    assert(
-        payloadShape[0] == subgroupSize &&
-        "Expected the first dimension of 2D tensor descriptor to be equal to "
-        "subgroup size.");
-
-  SmallVector<int> instData{subgroupSize};
-  if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
-    instData.push_back(chunkSize);
-  else if (auto dstTdescTy =
-               dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
-    if (dstTdescTy.getChunkSizeAsInt() > 1)
-      instData.push_back(chunkSize);
-  }
-
-  LayoutInfo payloadLayout;
-
   if (auto layout = storeScatter.getLayoutAttr()) {
     payloadLayout = LayoutInfo(layout);
   } else {
-    payloadLayout = getDefaultSIMTLayoutInfo(
-        payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
-        /*scattered=*/true);
+    if (layoutKind == LayoutKind::InstData) {
+      SmallVector<int> instData{subgroupSize};
+      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
+          chunkSize > 1)
+        instData.push_back(chunkSize);
+      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
+                   storeScatter.getDestType())) {
+        if (dstTdescTy.getChunkSizeAsInt() > 1)
+          instData.push_back(chunkSize);
+      }
+      payloadLayout = LayoutInfo(
+          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+    } else {
+      auto payloadShape = payloadTy.getShape();
+      if (payloadShape.size() > 1)
+        assert(payloadShape[0] == subgroupSize &&
+               "Expected the first dimension of 2D tensor descriptor to be "
+               "equal to "
+               "subgroup size.");
+      payloadLayout = getDefaultSIMTLayoutInfo(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+          /*scattered=*/true);
+    }
   }
 
   LayoutInfo maskLayout =
@@ -916,10 +958,10 @@ class RunLayoutInfoPropagation {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
 
-  RunLayoutInfoPropagation(Operation *op) : target(op) {
+  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
     SymbolTableCollection symbolTable;
     loadBaselineAnalyses(solver);
-    solver.load<LayoutInfoPropagation>(symbolTable);
+    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
     (void)solver.initializeAndRun(op);
   }
 
@@ -1159,7 +1201,19 @@ struct XeGPUPropagateLayoutPass final
 } // namespace
 
 void XeGPUPropagateLayoutPass::runOnOperation() {
-  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
+  LayoutKind layoutKind;
+  if (this->layoutKind == "lane")
+    layoutKind = LayoutKind::Lane;
+  else if (this->layoutKind == "inst")
+    layoutKind = LayoutKind::InstData;
+  else {
+    signalPassFailure();
+    getOperation()->emitError("Unsupported layout kind option: " +
+                              this->layoutKind);
+    return;
+  }
+  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
+  // auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
   // Print the analysis result and exit. (for debugging purposes)
   if (printOnly) {
     auto &os = llvm::outs();
@@ -1173,8 +1227,6 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
       return {};
     xegpu::DistributeLayoutAttr layoutAttr =
         cast<xegpu::DistributeLayoutAttr>(layout.get());
-    if (this->layoutKind == "lane")
-      layoutAttr = layoutAttr.dropInstData();
     if (layout.isSliceLayout())
       return cast<xegpu::SliceAttr>(layoutAttr);
     return cast<xegpu::LayoutAttr>(layoutAttr);
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 58461b8be52c4..c31ef323a94d2 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -2,17 +2,17 @@
 
 // CHECK-LABEL: func.func @dpas_f16(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
-// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]]  {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]]  {layout_result_0 = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]]  {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data...
[truncated]

@Garra1980
Copy link

cc @charithaintc @Jianhui-Li

@adam-smnk
Copy link
Contributor

Could this change be also added to Xe pipeline?
It's a nice central place to keep track of evolving lowering strategy.

@charithaintc charithaintc self-requested a review November 7, 2025 16:44
Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haven't looked into details but overall multi-step strategy seems good 👍
Otherwise, a few nits

Copy link
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@akroviakov
Copy link
Contributor Author

Could this change be also added to Xe pipeline?

Where is this pipeline located?

@Garra1980
Copy link

Could this change be also added to Xe pipeline?

Where is this pipeline located?

https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update 👍

@akroviakov akroviakov merged commit bba40ab into llvm:main Nov 10, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants