diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d8dd09a6280c0..a7f2dc2d6a43e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -259,7 +259,7 @@ class UpdateNdOffsetToXeVMPattern
     // Only 2D offsets are supported for now.
     if (mixedOffsets.size() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto tdesc = adaptor.getTensorDesc();
+    auto payload = adaptor.getTensorDesc();
     // Utility for updating payload offset values from op fold result.
     auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
       Value offset =
@@ -267,15 +267,15 @@ class UpdateNdOffsetToXeVMPattern
       offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offset);
       Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
       Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
                                       payloadPos);
     };
     // Update offsets in the payload.
-    auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, val);
+    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+    rewriter.replaceOp(op, payload);
     return success();
   }
 };
@@ -354,18 +354,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     auto tileH = tdescTy.getDimSize(0);
     int32_t vblocks = tdescTy.getArrayLength();
     if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+      Value src = adaptor.getValue();
+      // If the store value is a scalar, get the value from the op instead of
+      // the adaptor; the adaptor might have optimized away a single-element vector.
+      if (src.getType().isIntOrFloat()) {
+        src = op.getValue();
+      }
+      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
       if (!srcVecTy)
         return rewriter.notifyMatchFailure(
             op, "Expected store value to be a vector type.");
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      Value src = adaptor.getValue();
       // Get flat vector type of integer type with matching element bit size.
       VectorType newSrcVecTy =
           encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
       if (srcVecTy != newSrcVecTy)
         src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
           rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
           offsetH, elemBitSize, tileW, tileH, src,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 4ff95b40fe68c..ed664a739d134 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -2,9 +2,9 @@
 
 gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
-  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+  // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
-  gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+  gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
                             %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
     // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,17 +23,17 @@ gpu.module @create_nd_tdesc {
     %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
         : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
-    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-    %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+    %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
 
     // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
     // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
     // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+    // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
     // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
-    // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
-    // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+    // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
     // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
@@ -41,17 +41,17 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
     // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
     // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
-    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+    // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 
     // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
     // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
     // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-    // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
-    // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
-    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+    // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+    // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
     // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
     // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
     // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
     // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    %c8 = arith.constant 8 : index
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    %c16 = arith.constant 16 : index
+    // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+    // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+    // CHECK: %[[OLD_OFFSET_W:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+    // CHECK: %[[NEW_OFFSET_W:.*]] = arith.addi %[[OLD_OFFSET_W]], %[[VAR37]] : i32
+    // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_W]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+    %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }