25 changes: 15 additions & 10 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -259,23 +259,23 @@ class UpdateNdOffsetToXeVMPattern
// Only 2D offsets are supported for now.
if (mixedOffsets.size() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto tdesc = adaptor.getTensorDesc();
auto payload = adaptor.getTensorDesc();
Contributor:
nit: spell the type name.

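A minimal sketch of how the nit could be applied, assuming adaptor.getTensorDesc() hands back a plain mlir::Value after conversion (hypothetical, not part of the patch):

  // Spell out the payload type instead of `auto`; the converted tensor
  // descriptor is assumed to lower to an mlir::Value here.
  Value payload = adaptor.getTensorDesc();
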
// Utility for updating payload offset values from op fold result.
auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
Value offset =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
offset = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offset);
Value oldOffset =
vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
return vector::InsertOp::create(rewriter, loc, newOffset, payload,
payloadPos);
};
// Update offsets in the payload.
auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
rewriter.replaceOp(op, val);
payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
rewriter.replaceOp(op, payload);
return success();
}
};
@@ -354,18 +354,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto tileH = tdescTy.getDimSize(0);
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
Value src = adaptor.getValue();
// If store value is a scalar, get value from op instead of adaptor.
// Adaptor might have optimized away single element vector
if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
Contributor:
in the prev code it fails here?

return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
Value src = adaptor.getValue();
// Get flat vector type of integer type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
offsetH, elemBitSize, tileW, tileH, src,
46 changes: 30 additions & 16 deletions mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
Contributor:
maybe separate test for update_nd? or rename the test.

Contributor:
+1

@@ -2,9 +2,9 @@

gpu.module @create_nd_tdesc {
// CHECK-LABEL: gpu.func @create_nd_tdesc
// CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
// CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
// CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,35 +23,35 @@ gpu.module @create_nd_tdesc {
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>

// CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
// CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
%srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>

// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
// CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
// CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
// CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
// CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
// CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
// CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
// CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
// CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
// CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
// CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
// CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
// CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
// CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
%src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
%src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// CHECK: %[[C8:.*]] = arith.constant 8 : index
%c8 = arith.constant 8 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
%c16 = arith.constant 16 : index
// CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
// CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
// CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
// CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
// CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
// CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
// CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
// CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
%updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}