
Conversation

@Jianhui-Li (Contributor) commented Jul 12, 2025

This PR allows xegpu.create_nd_tdesc to take optional offsets. This is the initial PR in moving the offsets from create_nd_tdesc to load_nd.

  1. When creating an nd_tdesc for a dynamic-shape tensor, the shape and strides attributes must be used to describe the base tensor:
 %2 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1]  : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
  %2 = xegpu.create_nd_tdesc %src[%x, %y] shape : [%h, %w] strides : [%w, %c1]  : ui64 -> !xegpu.tensor_desc<8x16xf32>  
  2. With this new assembly format, the offsets need not be supplied by the user:
 %2 = xegpu.create_nd_tdesc %src shape:[%h, %w] strides:[%w, %c1]  : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
  %2 = xegpu.create_nd_tdesc %src shape : [%h, %w] strides : [%w, %c1]  : ui64 -> !xegpu.tensor_desc<8x16xf32>  

@llvmbot (Member) commented Jul 12, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

Changes

This PR allows xegpu to take optional offsets when create_nd_tdesc. This is the initial PR to move offsets from create_nd_tdesc to load_nd.
When creating an nd_tdesc for a dynamic-shape tensor, the shape and strides attributes must be used to describe the base tensor.

 %2 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1]  : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
  %2 = xegpu.create_nd_tdesc %src[0, 0] shape : [%h, %w] strides : [%w, %c1]  : ui64 -> !xegpu.tensor_desc<8x16xf32>

Full diff: https://github.com/llvm/llvm-project/pull/148335.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+35-10)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+242-2)
  • (modified) mlir/test/Dialect/XeGPU/ops.mlir (+43-2)
  • (modified) mlir/test/Dialect/XeGPU/subgroup-distribute.mlir (+4-4)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index bd5ea9fd83781..710fc62b032a9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -110,23 +110,27 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     Variadic<Index>: $offsets,
     Variadic<Index>: $shape,
     Variadic<Index>: $strides,
-    DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
     OptionalAttr<DenseI64ArrayAttr>: $const_shape,
     OptionalAttr<DenseI64ArrayAttr>: $const_strides
   );
   let results = (outs XeGPU_TensorDesc: $TensorDesc);
 
-  let assemblyFormat = [{
-    $source ``
-    custom<DynamicIndexList>($offsets, $const_offsets)
-    (`,` custom<DynamicIndexList>($shape, $const_shape)^
-     `,` custom<DynamicIndexList>($strides, $const_strides))?
-    attr-dict `:` type($source) `->` qualified(type($TensorDesc))
-  }];
-
   let hasVerifier = 1;
 
+  let hasCustomAssemblyFormat = 1;
+
   let builders = [
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
+
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
+                   "llvm::ArrayRef<OpFoldResult>": $shape,
+                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
+                   "llvm::ArrayRef<OpFoldResult>": $shape,
+                   "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets)>,
 
@@ -163,9 +167,30 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     }
 
     ArrayRef<int64_t> getStaticOffsets(){
-      return getConstOffsets();
+      auto attr = getConstOffsetsAttr();
+
+      if (attr) 
+        return attr;
+
+      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
+      int rank = 0;
+      if (memrefType) {
+        // use the source memref's rank, as it may be higher than the tensor_desc's rank
+        rank = memrefType.getRank();
+      } else {
+        // nd_tdesc was created from a ui64, so use the nd_tdesc's rank
+        rank = getTensorDescShape().size();
+      };
+
+      // The offsets are allowed to be empty; the trait verification of OffsetSizeAndStrideOpInterface assumes offsets are present.
+      // They are set to INT64_MAX to indicate that the user passed no value at all, as opposed to kDynamic, which means offsets were passed as SSA values.
+      setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max()));
+
+      attr = getConstOffsetsAttr();
+      return attr;
     }
 
+
     /// wrapper for matching with OffsetSizeAndStrideOpInterface
     /// If source is IntegerType or `const_shape` is filled,
     /// it will return `const_shape`, such that mixes of `shape`
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ef7cd1424e7a4..9f6090ad279f5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -125,8 +125,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
         ValueRange({}) /* empty dynamic shape */,
         ValueRange({}) /* empty dynamic strides */,
-        staticOffsets /* const offsets */, {} /* empty const shape*/,
-        {} /* empty const strides*/);
+        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
+        {} /* empty const shape*/, {} /* empty const strides*/);
 }
 
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
@@ -221,6 +221,246 @@ LogicalResult CreateNdDescOp::verify() {
   return success();
 }
 
+ParseResult parseOptionalDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
+    SmallVectorImpl<Type> *valueTypes = nullptr,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+  SmallVector<int64_t, 4> integerVals;
+  SmallVector<bool, 4> scalableVals;
+  auto parseIntegerOrValue = [&]() {
+    OpAsmParser::UnresolvedOperand operand;
+    auto res = parser.parseOptionalOperand(operand);
+
+    // When encountering `[`, assume that this is a scalable index.
+    scalableVals.push_back(parser.parseOptionalLSquare().succeeded());
+
+    if (res.has_value() && succeeded(res.value())) {
+      values.push_back(operand);
+      integerVals.push_back(ShapedType::kDynamic);
+      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+        return failure();
+    } else {
+      int64_t integer;
+      if (failed(parser.parseInteger(integer)))
+        return failure();
+      integerVals.push_back(integer);
+    }
+
+    // If this is assumed to be a scalable index, verify that there's a closing
+    // `]`.
+    if (scalableVals.back() && parser.parseOptionalRSquare().failed())
+      return failure();
+    return success();
+  };
+  if (parser.parseOptionalLSquare().succeeded()) {
+    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
+        parser.parseRSquare())
+      return parser.emitError(parser.getNameLoc())
+             << "expected SSA value or integer";
+    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+    scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
+    return success();
+  }
+  return success();
+}
+
+::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
+                                          ::mlir::OperationState &result) {
+  ::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
+  ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(
+      &sourceRawOperand, 1);
+  ::llvm::SMLoc sourceOperandsLoc;
+
+  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+      offsetsOperands;
+  ::llvm::SMLoc offsetsOperandsLoc;
+  ::mlir::DenseI64ArrayAttr const_offsetsAttr;
+  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
+  ::llvm::SMLoc shapeOperandsLoc;
+  ::mlir::DenseI64ArrayAttr const_shapeAttr;
+  ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
+      stridesOperands;
+  ::llvm::SMLoc stridesOperandsLoc;
+  ::mlir::DenseI64ArrayAttr const_stridesAttr;
+  ::mlir::Type sourceRawType{};
+  ::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
+  ::mlir::Type TensorDescRawType{};
+  ::llvm::ArrayRef<::mlir::Type> TensorDescTypes(&TensorDescRawType, 1);
+
+  sourceOperandsLoc = parser.getCurrentLocation();
+  if (parser.parseOperand(sourceRawOperand))
+    return ::mlir::failure();
+
+  offsetsOperandsLoc = parser.getCurrentLocation();
+
+  DenseBoolArrayAttr scalableFlags;
+  auto odsResult = parseOptionalDynamicIndexList(
+      parser, offsetsOperands, const_offsetsAttr, scalableFlags);
+
+  if (const_offsetsAttr) {
+    if (odsResult)
+      return ::mlir::failure();
+    result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets =
+        const_offsetsAttr;
+  }
+
+  if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
+    if (parser.parseColon())
+      return ::mlir::failure();
+    {
+      shapeOperandsLoc = parser.getCurrentLocation();
+      auto odsResult =
+          parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
+      if (const_shapeAttr) {
+        if (odsResult)
+          return ::mlir::failure();
+        result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape =
+            const_shapeAttr;
+      }
+    }
+
+    if (parser.parseKeyword("strides"))
+      return ::mlir::failure();
+    if (parser.parseColon())
+      return ::mlir::failure();
+    {
+      stridesOperandsLoc = parser.getCurrentLocation();
+      auto odsResult =
+          parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
+      if (const_stridesAttr) {
+        if (odsResult)
+          return ::mlir::failure();
+        result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides =
+            const_stridesAttr;
+      }
+    }
+  }
+  {
+    auto loc = parser.getCurrentLocation();
+    if (parser.parseOptionalAttrDict(result.attributes))
+      return ::mlir::failure();
+    if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
+          return parser.emitError(loc)
+                 << "'" << result.name.getStringRef() << "' op ";
+        })))
+      return ::mlir::failure();
+  }
+  if (parser.parseColon())
+    return ::mlir::failure();
+
+  {
+    ::mlir::Type type;
+    if (parser.parseCustomTypeWithFallback(type))
+      return ::mlir::failure();
+    sourceRawType = type;
+  }
+  if (parser.parseArrow())
+    return ::mlir::failure();
+
+  if (parser.parseType(TensorDescRawType))
+    return ::mlir::failure();
+
+  ::llvm::copy(::llvm::ArrayRef<int32_t>(
+                   {1, static_cast<int32_t>(offsetsOperands.size()),
+                    static_cast<int32_t>(shapeOperands.size()),
+                    static_cast<int32_t>(stridesOperands.size())}),
+               result.getOrAddProperties<CreateNdDescOp::Properties>()
+                   .operandSegmentSizes.begin());
+
+  ::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
+  result.addTypes(TensorDescTypes);
+
+  if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
+                             result.operands))
+    return ::mlir::failure();
+
+  if (parser.resolveOperands(offsetsOperands, odsBuildableType0,
+                             offsetsOperandsLoc, result.operands))
+    return ::mlir::failure();
+
+  if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc,
+                             result.operands))
+    return ::mlir::failure();
+
+  if (parser.resolveOperands(stridesOperands, odsBuildableType0,
+                             stridesOperandsLoc, result.operands))
+    return ::mlir::failure();
+  return ::mlir::success();
+}
+
+void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
+  _odsPrinter << ' ';
+  _odsPrinter << getSource();
+
+  auto constOffsetsAttr = getConstOffsetsAttr();
+  bool printOffsets = false;
+  if (constOffsetsAttr && constOffsetsAttr.size() > 0) {
+    auto firstVal = constOffsetsAttr.asArrayRef()[0];
+    if (firstVal != std::numeric_limits<int64_t>::max()) {
+      printOffsets = true;
+    }
+  }
+  if (printOffsets) {
+
+    printDynamicIndexList(_odsPrinter, *this, getOffsets(),
+                          getConstOffsetsAttr());
+  }
+  if (((!getShape().empty()) || (getConstShapeAttr()))) {
+    _odsPrinter << ' ' << "shape";
+    _odsPrinter << ' ' << ":";
+    _odsPrinter << ' ';
+    printDynamicIndexList(_odsPrinter, *this, getShape(), getConstShapeAttr());
+    _odsPrinter << ' ' << "strides";
+    _odsPrinter << ' ' << ":";
+    _odsPrinter << ' ';
+    printDynamicIndexList(_odsPrinter, *this, getStrides(),
+                          getConstStridesAttr());
+  }
+  ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
+  elidedAttrs.push_back("operandSegmentSizes");
+  elidedAttrs.push_back("const_offsets");
+  elidedAttrs.push_back("const_shape");
+  elidedAttrs.push_back("const_strides");
+  _odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
+  _odsPrinter << ' ' << ":";
+  _odsPrinter << ' ';
+  {
+    auto type = getSource().getType();
+    if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type))
+      _odsPrinter.printStrippedAttrOrType(validType);
+    else
+      _odsPrinter << type;
+  }
+  _odsPrinter << ' ' << "->";
+  _odsPrinter << ' ';
+  // _odsPrinter << getTensorDesc().getType();
+
+  _odsPrinter << "!xegpu.tensor_desc<";
+
+  auto tDesc = getTensorDesc().getType();
+  auto shape = tDesc.getShape();
+  for (int64_t dim : shape) {
+    if (mlir::ShapedType::isDynamic(dim))
+      _odsPrinter << '?';
+    else
+      _odsPrinter << dim;
+    _odsPrinter << 'x';
+  }
+
+  _odsPrinter << tDesc.getElementType();
+
+  if (auto encoding = tDesc.getEncoding())
+    _odsPrinter << ", " << encoding;
+
+  if (auto layout = tDesc.getLayout())
+    _odsPrinter << ", " << layout;
+
+  _odsPrinter << ">";
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3bfe1fa81aa6e..0d679e519ed60 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -17,8 +17,8 @@ gpu.func @create_nd_tdesc_1(%src: memref<24x32xf32>) {
 gpu.func @create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
   //CHECK: %[[C:.*]] = arith.constant 1 : index
   %c1 = arith.constant 1 : index
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
-  %1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]],  %[[arg4]]] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides: [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
   gpu.return
 }
 
@@ -62,6 +62,47 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
 }
 
 
+// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>) 
+gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
+  //CHECK: %[[C:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+ 
+  gpu.return
+}
+
+// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index) 
+gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
+  
+  %c1 = arith.constant 1 : index   
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+  %2 = xegpu.create_nd_tdesc %src shape : [%h, %w] strides : [%w, %c1]  : ui64 -> !xegpu.tensor_desc<8x16xf32>
+ 
+  gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}}) 
+
+gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
+
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[%arg3, %arg4] shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1]  : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+  gpu.return
+}
+
+// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}}) 
+gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {  
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16> 
+  %2 = xegpu.create_nd_tdesc %src shape:[%h, %w] strides:[%w, %c1]  : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
+
+  gpu.return
+}
+
 // CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3d91b2269bc4b..ba29d1ab13cae 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -150,16 +150,16 @@ gpu.module @test {
 // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
 // CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
 // CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]]  : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
-// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
+// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] shape : [%[[ARG2]], %[[ARG3]]] strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
 // CHECK: xegpu.store_nd %[[T1]], %[[T2]]  : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
 gpu.module @test {
   gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
     %c0 = arith.constant 0 : index
-    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %1 = xegpu.load_nd %0  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
-    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] shape:[%arg2, %arg3] strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }

@Jianhui-Li (Contributor, Author):

@chencha3 @charithaintc

@Garra1980:

cc also @akroviakov

@charithaintc self-requested a review July 14, 2025 23:55
Comment on lines 177 to 183
if (memrefType) {
//use source memref's rank, as source memref rank may be higher
rank = memrefType.getRank();
} else {
//nd_tdesc created from ui64, use nd_tdesc's rank
rank = getTensorDescShape().size();
};
Contributor:

No need for braces around single-line if and else bodies.

Comment on lines 186 to 188
// They are set to INT64_MAX to indicate that the user passed no value at all, as opposed to kDynamic, which means offsets were passed as SSA values.
setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, std::numeric_limits<int64_t>::max()));

Contributor:

We will probably need to reuse this constant in the future. Better to define it somewhere:

static constexpr int64_t optionalValue = std::numeric_limits<int64_t>::max();

Contributor (Author):

The parser and printer code is expected to be removed once we finish the transition that moves the offsets from the create_nd_tdesc definition to load_nd, so there is no plan to reuse it.
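
For readers following the thread, a minimal sketch of the suggested named sentinel (the name optionalValue comes from the suggestion above; the usage shown in comments is illustrative, not the committed code):

  // Distinguishes "the user supplied no offsets at all" from
  // ShapedType::kDynamic, which marks offsets supplied as SSA values.
  static constexpr int64_t optionalValue = std::numeric_limits<int64_t>::max();

  // Illustrative use, mirroring the getStaticOffsets() fallback and the
  // printer check elsewhere in this PR:
  //   setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, optionalValue));
  //   bool printOffsets = constOffsetsAttr && constOffsetsAttr.size() > 0 &&
  //                       constOffsetsAttr.asArrayRef()[0] != optionalValue;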

staticOffsets /* const offsets */, {} /* empty const shape*/,
{} /* empty const strides*/);
builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
{} /* empty const shape*/, {} /* empty const strides*/);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Contributor:

Where are the new build methods implemented?

    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,

    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
                   "llvm::ArrayRef<OpFoldResult>": $shape,
                   "llvm::ArrayRef<OpFoldResult>": $strides)>,

    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
                   "llvm::ArrayRef<OpFoldResult>": $shape,
                   "llvm::ArrayRef<OpFoldResult>": $strides)>,

Contributor (Author):

added
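
For context, a sketch of the simplest of these new builders, pieced together from fragments quoted later in this review; the delegating call is an assumption modeled on the existing builder at the top of this diff, so the exact committed body may differ:

  void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                             Type tdesc, TypedValue<MemRefType> source) {
    // Only used in the assert; compiled out in release builds.
    [[maybe_unused]] auto ty = source.getType();
    assert(ty.hasStaticShape() && "expecting a memref with static shape");
    // No offsets; shape and strides are implied by the static memref type.
    build(builder, state, tdesc, source,
          ValueRange({}) /* empty dynamic offsets */,
          ValueRange({}) /* empty dynamic shape */,
          ValueRange({}) /* empty dynamic strides */,
          DenseI64ArrayAttr() /* no const offsets */,
          DenseI64ArrayAttr() /* empty const shape */,
          DenseI64ArrayAttr() /* empty const strides */);
  }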

Comment on lines 224 to 229
ParseResult parseOptionalDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
Contributor:

Why can we not reuse the parseDynamicIndexList method and avoid this? I see lots of logic replicated.

Contributor (Author):

The offsets are provided as an optional bracketed list [ ], so we need to customize parseDynamicIndexList.

Contributor (Author):

refactored and simplified the custom parser.
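
The essence of that customization, as a condensed sketch (the scalable-flag and valueTypes handling from the diff above are omitted; this is not the final refactored code):

  // Offsets are optional: only if a '[' is present do we parse a dynamic
  // index list; with no '[' at all, the offsets were simply omitted.
  ParseResult parseOptionalOffsets(
      OpAsmParser &parser,
      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
      DenseI64ArrayAttr &integers) {
    if (parser.parseOptionalLSquare().failed())
      return success(); // no offsets supplied
    SmallVector<int64_t, 4> integerVals;
    auto parseIntegerOrValue = [&]() -> ParseResult {
      OpAsmParser::UnresolvedOperand operand;
      auto res = parser.parseOptionalOperand(operand);
      if (res.has_value() && succeeded(res.value())) {
        values.push_back(operand);            // dynamic index (SSA value)
        integerVals.push_back(ShapedType::kDynamic);
        return success();
      }
      int64_t integer;
      if (parser.parseInteger(integer))       // static index (integer literal)
        return failure();
      integerVals.push_back(integer);
      return success();
    };
    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
        parser.parseRSquare())
      return failure();
    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
    return success();
  }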

Comment on lines 395 to 396
_odsPrinter << ' ';
_odsPrinter << getSource();
Contributor:

I think we should avoid using the TableGen-generated variable name _odsPrinter and use something more readable like printer.

Contributor (Author):

We don't expect this code to stay permanently. Keeping it the same as the TableGen-generated printer code helps debugging.

Contributor (Author):

removed.

Comment on lines 310 to 312
if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
if (parser.parseColon())
return ::mlir::failure();
Contributor:

Wouldn't it be nicer to use a keyword for offsets as well? For the optional case it would be empty square brackets:

offsets : [], strides : [...], shapes: []

Contributor (Author) @Jianhui-Li, Jul 15, 2025:

We are moving offsets to load_nd.

@@ -56,7 +56,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]] strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
Contributor:

Why not comma-separate them for readability?

shapes : [], strides : []

Contributor (Author):

It could be a future enhancement.

Contributor (Author):

added support.

Contributor @akroviakov left a comment:

LGTM, apart from a few points.

  • Might be worth adapting the examples in the XeGPU_CreateNdDescOp description.

  • The verifier never checks, for ui64, that the shape/strides rank matches the tensor descriptor rank. The following IR is considered valid:

  %1 = xegpu.create_nd_tdesc %src[%x, %y, %c1] shape : [%h, %w, %x] strides : [%w, %c1, %y]  : ui64 -> !xegpu.tensor_desc<8x16xf32>

but if one removes the offset:

  %1 = xegpu.create_nd_tdesc %src shape : [%h, %w, %x] strides : [%w, %c1, %y]  : ui64 -> !xegpu.tensor_desc<8x16xf32>

a rank-mismatch error between sizes and offsets pops up.

For optional offsets to work with ui64, if you decide to tie the offsets rank to the tdesc rank (rather than to sizes/strides), then the verifier should also check that the supplied shape/strides rank matches the tdesc rank.

  • (May be unrelated to the PR topic) For ui64, the description mandates that users supply sizes/strides, but the verifier never checks this, so the following code crashes instead of showing an error (see the verifier sketch after this list):
  %1 = xegpu.create_nd_tdesc %src : ui64 -> !xegpu.tensor_desc<8x16xf32>
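
A hedged sketch of the kind of checks being asked for here, using accessor names that appear elsewhere in this PR (getMixedSizes, getConstOffsetsAttr); this is not the committed verifier:

  LogicalResult CreateNdDescOp::verify() {
    if (llvm::isa<IntegerType>(getSourceType())) {
      // A ui64 source carries no shape of its own, so explicit
      // shape/strides are mandatory.
      size_t rank = getMixedSizes().size();
      if (rank == 0)
        return emitOpError("expecting shape and strides for an integer source");
      // If offsets are supplied, their rank must match the shape/strides rank.
      if (getConstOffsetsAttr() || !getOffsets().empty())
        if (getMixedOffsets().size() != rank)
          return emitOpError("offsets rank does not match shape/strides rank");
    }
    // ... remaining existing checks ...
    return success();
  }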

rank = memrefType.getRank();
} else {
// nd_tdesc was created from a ui64, so use the nd_tdesc's rank
rank = getTensorDescShape().size();
Contributor:

It is not safe to infer the rank from the TensorDesc, since the TensorDesc may have a lower rank than the offsets. You can simply use int rank = getStaticSizes().size(); instead.

Contributor (Author):

In this case, the user specifies neither const offsets nor dynamic offset values, so I assumed we could only infer the rank from the TensorDesc. I am not sure getStaticSizes() gives us the correct result.

Contributor (Author):

fixed

github-actions bot commented Jul 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Jianhui-Li (Contributor, Author):

you decide to tie the offsets rank to the tdesc rank (rather than to sizes/strides)

It is not intended to tie the offsets rank to the tdesc rank; I fixed it to tie to the shape/strides of the input tensor.


void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Type tdesc, TypedValue<MemRefType> source) {
[[maybe_unused]] auto ty = source.getType();
Contributor:

I think [[maybe_unused]] is not needed.

Contributor (Author):

ty is only used in the assert statement; with NDEBUG (release builds) the assert expands to nothing, so without [[maybe_unused]] the variable would trigger an unused-variable warning.

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Type tdesc, TypedValue<MemRefType> source) {
[[maybe_unused]] auto ty = source.getType();
assert(ty.hasStaticShape());
Contributor:

Add some error text in the assert, like "expecting a memref with static shape".

Type tdesc, TypedValue<MemRefType> source,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
assert(shape.size() && strides.size() && shape.size() == strides.size());
Contributor:

Add some error text explaining why this invariant must be satisfied.

Type tdesc, TypedValue<IntegerType> source,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
assert(shape.size() && strides.size() && shape.size() == strides.size());
Contributor:

same here.

if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
parser.parseRSquare())
return parser.emitError(parser.getNameLoc())
<< "expected SSA value or integer";
Contributor:

Suggested change
<< "expected SSA value or integer";
<< "expected a list of SSA values or integers";


return success();
};
if (parser.parseOptionalLSquare().succeeded()) {
Contributor:

I assume that for the no-offset case this check will fail?
Example:

create_nd %src shape: [] strides: []

Contributor:

Please add a comment here, like "If the optional values are given, there must be a left bracket".

Contributor (Author):

Yes. Added

TypeRange valueTypes = TypeRange()) {

if (values.empty() && llvm::all_of(integers, [](int64_t i) {
return i == std::numeric_limits<int64_t>::max();
Contributor:

Please add a comment here to explain that we use placeholder values to carry the optional offsets.

Nit: I would still prefer this constant to be defined somewhere and properly documented. It will make life easier for future changes.

Contributor (Author):

It is temporary. I added a comment.

Contributor:

Maybe we can remove the valueTypes and related logic, since it is not used.

Comment on lines 331 to 333
printer << values[dynamicValIdx];
if (!valueTypes.empty())
printer << " : " << valueTypes[dynamicValIdx];
Contributor:

Do values and valueTypes always have the same size?

Contributor (Author):

It seems correct from the test cases. This interface mirrors the automatically generated parser, which should guarantee it.

@Jianhui-Li (Contributor, Author):

  • (May be unrelated to the PR topic) For ui64, the description mandates that users supply sizes/strides, but the verifier never checks this, so the following code crashes instead of showing an error:

Verifier check and invalid test case added.

Contributor @chencha3 left a comment:

LGTM w/ some nit comments

rank = memrefType.getRank();
else
//nd_tdesc created from ui64, use nd_tdesc's rank
rank = getMixedSizes().size();
Contributor:

Nit: it is not the nd_tdesc's rank, it is the rank of the shape. It is also not necessary to differentiate between the memref type and the rest, since that is handled in getMixedSizes(): it is an abstraction interface that returns the shape of the memory, regardless of whether it is specified by a MemRefType or via the shape/strides parameters.
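
Put together, the suggested simplification reduces the fallback to something like this sketch (assuming the accessors quoted above; not the committed code):

  ArrayRef<int64_t> getStaticOffsets() {
    if (auto attr = getConstOffsetsAttr())
      return attr;
    // getMixedSizes() abstracts over memref-typed sources and explicit
    // shape operands, so no memref/ui64 distinction is needed.
    int64_t rank = getMixedSizes().size();
    // Placeholder offsets meaning "not supplied"; see the INT64_MAX
    // sentinel discussed earlier in this review.
    setConstOffsets(llvm::SmallVector<int64_t, 4>(
        rank, std::numeric_limits<int64_t>::max()));
    return getConstOffsetsAttr();
  }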

TypeRange valueTypes = TypeRange()) {

if (values.empty() && llvm::all_of(integers, [](int64_t i) {
return i == std::numeric_limits<int64_t>::max();
Contributor:

Maybe we can remove the valueTypes and related logic, since it is not used.

Contributor @adam-smnk left a comment:

Could you also update the PR description? The second case's IR still uses offsets.

@Jianhui-Li (Contributor, Author):

Could you also update the PR description? The second case's IR still uses offsets.

fixed. thanks!

Contributor @adam-smnk left a comment:

Thanks for all the extra iterations 👍
Now, effectively, there should be no semantic changes to the op itself.

@chencha3 merged commit aea2d53 into llvm:main Jul 17, 2025
9 checks passed
Jianhui-Li added a commit that referenced this pull request Jul 23, 2025
…49424)

This PR allows load_nd/store_nd/prefetch_nd to take an additional offset
operand.
It is based on this PR #148335.
Now users can create an nd_tdesc with no offsets, and instead set the offsets with the load_nd operation.