@nbpatel nbpatel commented Oct 14, 2025

This PR changes the pack/unpack method used for unrolling so that a lower-rank slice can be extracted from, and inserted back into, the source vector by adding reshapes. It also removes any leading unit dims from inst_data.
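
To illustrate the rank adjustment described above, here is a minimal standalone sketch using plain `std::vector` instead of MLIR types. The helper names `padWithLeadingOnes` and `countTiles` are hypothetical and do not appear in the PR; the sketch only mirrors the idea of left-padding a lower-rank target shape with unit dims so it can be compared against the source shape, and it assumes every tile dim evenly divides the matching source dim.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the PR): left-pad `shape` with unit dims
// until it has `rank` dimensions, e.g. [16] at rank 3 becomes [1, 1, 16].
std::vector<int64_t> padWithLeadingOnes(const std::vector<int64_t> &shape,
                                        std::size_t rank) {
  assert(shape.size() <= rank && "target rank must not exceed source rank");
  std::vector<int64_t> padded(rank - shape.size(), 1);
  padded.insert(padded.end(), shape.begin(), shape.end());
  return padded;
}

// Hypothetical helper: number of tiles of shape `tile` needed to cover `src`.
// Assumes equal ranks and that each tile dim evenly divides the source dim.
int64_t countTiles(const std::vector<int64_t> &src,
                   const std::vector<int64_t> &tile) {
  assert(src.size() == tile.size() && "ranks must match after padding");
  int64_t count = 1;
  for (std::size_t i = 0; i < src.size(); ++i) {
    assert(tile[i] > 0 && src[i] % tile[i] == 0 && "tile must divide source");
    count *= src[i] / tile[i];
  }
  return count;
}
```

Under these assumptions, a `vector<1x1x32>` source with a rank-1 target shape `[16]` is sliced as `[1, 1, 16]`, which gives two slices; each slice would then be reshaped down to `vector<16>`.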

@nbpatel nbpatel requested a review from Jianhui-Li October 15, 2025 14:01
@nbpatel nbpatel marked this pull request as ready for review October 15, 2025 14:01
llvmbot commented Oct 15, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR changes the pack/unpack method used for unrolling so that a lower-rank slice can be extracted from, and inserted back into, the source vector by adding reshapes. It also removes any leading unit dims from inst_data.


Full diff: https://github.com/llvm/llvm-project/pull/163459.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp (+8-3)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+1-5)
  • (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+23-4)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+43)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f77784abaf0b2..48831728ad624 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -145,8 +145,13 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   xegpu::DistributeLayoutAttr layout =
       xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
-    if (!layout.getEffectiveInstDataAsInt().empty())
-      return layout.getEffectiveInstDataAsInt();
+    if (!layout.getEffectiveInstDataAsInt().empty()) {
+      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+      // Remove all leading unit dimensions from inst_data
+      while (!instData.empty() && instData.front() == 1)
+        instData.erase(instData.begin());
+      return instData;
+    }
 
     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
@@ -363,7 +368,7 @@ void XeGPUBlockingPass::runOnOperation() {
           xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                      tdescTy.getLayoutAttr().dropInstData());
     } else {
-      newTy = type.clone(tileShape, elemTy);
+      newTy = VectorType::get(tileShape, elemTy);
     }
 
     if (returnSingleType)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0fe4b0b0..75b215c320e54 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -66,8 +66,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
   Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
                Location loc, PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(destTy)) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of destTy.");
       auto shape = vecTy.getShape();
       return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
     }
@@ -93,8 +91,6 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
                           ArrayRef<int64_t> blockSize, Location loc,
                           PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of src.");
       return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                      blockSize);
     }
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
     VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
     VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
     Type elemTy = valueTy.getElementType();
-    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    VectorType newValueTy = VectorType::get(*targetShape, elemTy);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a438ea62c..40013eb161678 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,30 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
   if (!computeShapeRatio(srcShape, shape))
     return {value};
 
+  int64_t srcShapeRank = srcShape.size();
+  int64_t targetShapeRank = shape.size();
+
+  SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+  int64_t rankDiff = srcShapeRank - targetShapeRank;
+  std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+            1);
+  std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
+  int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
+
   SmallVector<Value> result;
-  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
     SmallVector<int64_t> staticStrides(offsets.size(), 1);
-    result.push_back(vector::ExtractStridedSliceOp::create(
-        builder, loc, value, offsets, shape, staticStrides));
+    Value slice = vector::ExtractStridedSliceOp::create(
+        builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+    // Reshape to remove leading unit dims if needed
+    if (adjustedTargetShapeRank > targetShapeRank) {
+      auto targetTy = VectorType::get(shape, vecTy.getElementType());
+      slice = builder.create<vector::ShapeCastOp>(loc, targetTy, slice);
+    }
+    result.push_back(slice);
   }
 
   return result;
@@ -274,7 +293,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 
   for (auto [src, offsets] :
        llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
-    SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    SmallVector<int64_t> staticStrides(tileShape.size(), 1);
     result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
                                                   offsets, staticStrides);
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index fe4f44c0b02ab..6301533da640d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -682,3 +682,46 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_gather
+  // CHECK-COUNT-2: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  gpu.func @load_gather(%src: ui64) -> vector<1x1x32xf32> {
+      %cst = arith.constant dense<[[
+      [0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248]
+      ]]> : vector<1x1x32xindex>
+
+      %mask = arith.constant dense<true> : vector<1x1x32xi1>
+      %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+
+      gpu.return %ld : vector<1x1x32xf32>
+  }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: store_scatter
+  // CHECK-COUNT-2: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_scatter(%src: ui64) {
+      %cst = arith.constant dense<[[
+      [0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248]
+      ]]> : vector<1x1x32xindex>
+
+      %mask = arith.constant dense<true> : vector<1x1x32xi1>
+
+      %st_vec = arith.constant dense<1023.0>: vector<1x1x32xf32>
+      xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [1, 1, 16]>,
+                                              layout_operand_2 = #xegpu.layout<inst_data = [1, 1, 16]>,
+                                              layout_operand_3 = #xegpu.layout<inst_data = [1, 1, 16]>,
+                                              l1_hint = #xegpu.cache_hint<cached>} : vector<1x1x32xf32>, ui64, vector<1x1x32xindex>, vector<1x1x32xi1>
+
+      gpu.return
+  }
+}

%mask = arith.constant dense<true> : vector<1x1x32xi1>
%ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>

gpu.return %ld : vector<1x1x32xf32>
Contributor

Please check the generated code sequence: how is the value composed back to 1x1x32?
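
One way to reason about this question is to enumerate where each unrolled piece would land in the full-rank result. The sketch below is a standalone stand-in for tile-offset enumeration, not the MLIR utility itself: `tileOffsets` is a hypothetical name, and it assumes the tile is expressed at the same rank as the result (for example `[1, 1, 16]` for a `1x1x32` result) and evenly divides it. It does not show what IR the PR actually emits.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical stand-in: list the origin of every `tile` inside `shape` in
// row-major order. Assumes equal ranks and evenly dividing tile dims.
std::vector<Shape> tileOffsets(const Shape &shape, const Shape &tile) {
  assert(shape.size() == tile.size() && "ranks must match");
  int64_t total = 1;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    assert(tile[i] > 0 && shape[i] % tile[i] == 0 && "tile must divide shape");
    total *= shape[i] / tile[i];
  }

  std::vector<Shape> offsets;
  Shape current(shape.size(), 0);
  for (int64_t n = 0; n < total; ++n) {
    offsets.push_back(current);
    // Advance like an odometer, innermost dimension first.
    for (std::size_t d = shape.size(); d-- > 0;) {
      current[d] += tile[d];
      if (current[d] < shape[d])
        break;
      current[d] = 0;
    }
  }
  return offsets;
}
```

For a `1x1x32` result and a `[1, 1, 16]` tile, this yields the offsets `[0, 0, 0]` and `[0, 0, 16]`, so under these assumptions the two `vector<16xf32>` loads would fill the first and second halves of the innermost dimension.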

]]> : vector<1x1x32xindex>

%mask = arith.constant dense<true> : vector<1x1x32xi1>
%ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [1, 1, 16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
Contributor

Please check how the inputs are extracted from %mask and %cst (the additional shape cast).

Why not add a test with inst_data attached to %cst and %mask, and see how they are all unrolled together?

if (!layout.getEffectiveInstDataAsInt().empty()) {
SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
// Remove all leading unit dimensions from inst_data
while (!instData.empty() && instData.front() == 1)
Contributor

This doesn't work for all XeGPU operations. For example, load_nd would need 2-D inst_data since its output is 2-D. Should this logic happen per operation, after each one calls getTileShape?

Contributor Author

So do we just remove it for load_gather/store_scatter and elementwise ops?
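
As a rough illustration of the per-operation idea raised in this thread, the sketch below drops leading unit dims only down to a caller-supplied minimum rank. It uses plain `std::vector` rather than MLIR types, and both the function name `dropLeadingUnitDims` and the `minRank` parameter are hypothetical; which rank each op would pass is an assumption, with the 2-D requirement for load_nd taken from the comment above.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the PR): remove leading unit dims from
// `instData`, but never shrink it below `minRank` dimensions.
std::vector<int64_t> dropLeadingUnitDims(std::vector<int64_t> instData,
                                         std::size_t minRank) {
  std::size_t drop = 0;
  // Check the remaining rank first so instData[drop] is always in bounds.
  while (instData.size() - drop > minRank && instData[drop] == 1)
    ++drop;
  instData.erase(instData.begin(),
                 instData.begin() + static_cast<std::ptrdiff_t>(drop));
  return instData;
}
```

With this shape of helper, a gather/scatter or elementwise op could pass `minRank = 1` and get `[16]` from `[1, 1, 16]`, while an op that needs a 2-D tile could pass `minRank = 2` and keep `[1, 16]`.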
