diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td index 9c35c07a7e587..3f27d690f949b 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td @@ -379,6 +379,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> { ); let builders = [ + AttrBuilder<(ins "llvm::ArrayRef": $inst_data), + [{ + auto sg_layout = DenseI32ArrayAttr(); + auto sg_data = DenseI32ArrayAttr(); + auto order = DenseI32ArrayAttr(); + auto lane_layout = DenseI32ArrayAttr(); + auto lane_data = DenseI32ArrayAttr(); + return $_get($_ctxt, sg_layout, sg_data, + DenseI32ArrayAttr::get($_ctxt, inst_data), + lane_layout, lane_data, order); + }]>, AttrBuilder<(ins "llvm::ArrayRef": $inst_data, "llvm::ArrayRef": $lane_layout, "llvm::ArrayRef": $lane_data), diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td index e42799689e490..12270af870b3b 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td @@ -47,7 +47,7 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> { Option< "layoutKind", "layout-kind", "std::string", /*default=*/"\"lane\"", - "Propagate a `sg` / `inst` / `lane` level of xegpu layouts."> + "Propagate `inst` / `lane` level of xegpu layouts."> ]; } diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp index 1a1485ba2e02c..b097d3a0c9686 100644 --- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp +++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp @@ -63,13 +63,20 @@ void buildGPUPassPipeline(OpPassManager &pm, if (options.xegpuOpLevel == "workgroup") { pm.addNestedPass(xegpu::createXeGPUWgToSgDistribute()); pm.addNestedPass(createCSEPass()); + xegpu::XeGPUPropagateLayoutOptions layoutOptions; + layoutOptions.layoutKind = "inst"; + pm.addNestedPass( + xegpu::createXeGPUPropagateLayout(layoutOptions)); pm.addNestedPass(xegpu::createXeGPUBlocking()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); } if (options.xegpuOpLevel == "subgroup" || options.xegpuOpLevel == "workgroup") { - pm.addNestedPass(xegpu::createXeGPUPropagateLayout()); + xegpu::XeGPUPropagateLayoutOptions layoutOptions; + layoutOptions.layoutKind = "lane"; + pm.addNestedPass( + xegpu::createXeGPUPropagateLayout(layoutOptions)); pm.addNestedPass(xegpu::createXeGPUSubgroupDistribute()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createCSEPass()); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 4e1a539771d2f..b3a780abd3f12 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -53,6 +53,8 @@ using namespace mlir::dataflow; namespace { +enum class LayoutKind { Lane, InstData }; + //===----------------------------------------------------------------------===// // LayoutInfo //===----------------------------------------------------------------------===// @@ -166,7 +168,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) { llvm_unreachable("Join should not be triggered by layout propagation."); } -/// Construct a new layout with the transposed lane layout and lane data. +/// Construct a new layout with the transposed inst_data or lane_layout, +/// lane_data. LayoutInfo LayoutInfo::transpose(ArrayRef permutation) const { if (!isAssigned()) return {}; @@ -186,12 +189,20 @@ LayoutInfo LayoutInfo::transpose(ArrayRef permutation) const { SmallVector laneData; SmallVector instData; for (int64_t idx : permutation) { - laneLayout.push_back(static_cast(getLaneLayout()[idx])); - laneData.push_back(static_cast(getLaneData()[idx])); - instData.push_back(static_cast(getInstData()[idx])); + if (getLaneLayout().size()) { + laneLayout.push_back(static_cast(getLaneLayout()[idx])); + laneData.push_back(static_cast(getLaneData()[idx])); + } + if (getInstData().size()) + instData.push_back(static_cast(getInstData()[idx])); } - return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData, - laneLayout, laneData)); + xegpu::LayoutAttr layoutAttr; + if (getLaneLayout().size()) + layoutAttr = + xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData); + if (getInstData().size()) + layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData); + return LayoutInfo(layoutAttr); } //===----------------------------------------------------------------------===// @@ -213,15 +224,14 @@ struct LayoutInfoLattice : public Lattice { /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1]. static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, unsigned rank, - const xegpu::uArch::uArch *uArch, - ArrayRef instData) { + const xegpu::uArch::uArch *uArch) { assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); if (rank == 1) { return LayoutInfo( - xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1})); + xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1})); } - return LayoutInfo(xegpu::LayoutAttr::get( - ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1})); + return LayoutInfo( + xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1})); } static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, @@ -236,7 +246,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx, /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, const xegpu::uArch::uArch *uArch, - ArrayRef instData, unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. @@ -247,16 +256,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (vectorTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData); + return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch); // Packing factor is determined by the element type bitwidth. unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { - return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, + return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), {uArch->getSubgroupSize(), 1}, {1, packingFactor})); } - return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData, + return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor})); } @@ -264,7 +273,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy, /// Helper to get the default layout for a vector type. static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, const xegpu::uArch::uArch *uArch, - ArrayRef instData, unsigned packingSize, bool isScattered = false) { // Expecting a 1D or 2D vector. @@ -275,18 +283,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, "Expected int or float element type."); // If the rank is 1, then return default layout for 1D vector. if (tdescTy.getRank() == 1) - return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData); + return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch); // Packing factor is determined by the element type bitwidth. unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth(); int subgroupSize = uArch->getSubgroupSize(); int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1; if (isScattered) { return LayoutInfo(xegpu::LayoutAttr::get( - tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor})); + tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor})); } return LayoutInfo(xegpu::LayoutAttr::get( - tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor})); + tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor})); } /// Helper Function to get the expected layouts for DPAS operands. `lane_data` @@ -298,7 +306,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy, static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, const xegpu::uArch::uArch *uArch, - ArrayRef instData, unsigned packingSize) { + unsigned packingSize) { Type elementTy = vectorTy.getElementType(); assert(elementTy.isIntOrFloat() && "Expected int or float type in DPAS operands"); @@ -310,10 +318,10 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, {static_cast(packingSize / elementTy.getIntOrFloatBitWidth()), 1}); return LayoutInfo( - xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data)); + xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data)); } // Otherwise, return the default layout for the vector type. - return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize); + return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize); } //===----------------------------------------------------------------------===// @@ -328,6 +336,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum, class LayoutInfoPropagation : public SparseBackwardDataFlowAnalysis { private: + LayoutKind layoutKind; void visitDpasOp(xegpu::DpasOp dpas, ArrayRef operands, ArrayRef results); @@ -380,8 +389,10 @@ class LayoutInfoPropagation public: LayoutInfoPropagation(DataFlowSolver &solver, - SymbolTableCollection &symbolTable) - : SparseBackwardDataFlowAnalysis(solver, symbolTable) {} + SymbolTableCollection &symbolTable, + LayoutKind layoutKind) + : SparseBackwardDataFlowAnalysis(solver, symbolTable), + layoutKind(layoutKind) {} using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; LogicalResult @@ -499,8 +510,14 @@ void LayoutInfoPropagation::visitPrefetchNdOp( "No suitable instruction multiple found for the given shape."); instData = {instHeight, instWidth}; } - auto prefetchLayout = getDefaultSIMTLayoutInfo( - tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize()); + LayoutInfo prefetchLayout; + if (layoutKind == LayoutKind::InstData) + prefetchLayout = + LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData)); + else + prefetchLayout = getDefaultSIMTLayoutInfo( + tdescTy, uArch, uArchInstruction->getPackedFormatBitSize()); + // Propagate the layout to the source tensor descriptor. propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout)); } @@ -627,14 +644,24 @@ void LayoutInfoPropagation::visitDpasOp( SmallVector instDataA = {maxALen, subgroupSize}; SmallVector instDataB = {subgroupSize, maxBLen}; - propagateIfChanged(operands[0], - operands[0]->meet(getSIMTLayoutInfoForDPASOperand( - aTy, 0, uArch, instDataA, - uArchInstruction->getPackedFormatBitSizeA()))); - propagateIfChanged(operands[1], - operands[1]->meet(getSIMTLayoutInfoForDPASOperand( - bTy, 1, uArch, instDataB, - uArchInstruction->getPackedFormatBitSizeB()))); + LayoutInfo dpasALayout; + LayoutInfo dpasBLayout; + LayoutInfo dpasCLayout; + + if (layoutKind == LayoutKind::InstData) { + dpasALayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA)); + dpasBLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB)); + } else { + dpasALayout = getSIMTLayoutInfoForDPASOperand( + aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA()); + dpasBLayout = getSIMTLayoutInfoForDPASOperand( + bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB()); + } + + propagateIfChanged(operands[0], operands[0]->meet(dpasALayout)); + propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout)); if (operands.size() > 2) { VectorType cTy = dpas.getAccType(); const unsigned dataCLen = bTy.getShape().back(); @@ -645,10 +672,15 @@ void LayoutInfoPropagation::visitDpasOp( dpas.emitWarning( "No suitable instruction multiple found for the given shape."); SmallVector instDataC = {maxALen, maxCLen}; - propagateIfChanged(operands[2], - operands[2]->meet(getSIMTLayoutInfoForDPASOperand( - cTy, 2, uArch, instDataC, - uArchInstruction->getPackedFormatBitSizeB()))); + + if (layoutKind == LayoutKind::InstData) + dpasCLayout = + LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC)); + else + dpasCLayout = getSIMTLayoutInfoForDPASOperand( + cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB()); + + propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout)); } } @@ -685,9 +717,15 @@ void LayoutInfoPropagation::visitStoreNdOp( "No suitable instruction multiple found for the given shape."); instData = {instHeight, instWidth}; } - LayoutInfo storeLayout = - getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData, - uArchInstruction->getPackedFormatBitSize()); + + LayoutInfo storeLayout; + if (layoutKind == LayoutKind::InstData) + storeLayout = + LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData)); + else + storeLayout = + getDefaultSIMTLayoutInfo(store.getValueType(), uArch, + uArchInstruction->getPackedFormatBitSize()); // Both operands should have the same layout for (LayoutInfoLattice *operand : operands) propagateIfChanged(operand, operand->meet(storeLayout)); @@ -818,9 +856,13 @@ void LayoutInfoPropagation::visitLoadGatherOp( if (srcTdescTy.getChunkSizeAsInt() > 1) instData.push_back(chunkSize); } - LayoutInfo layout = getDefaultSIMTLayoutInfo( - payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(), - /*scattered*/ true); + LayoutInfo layout; + if (layoutKind == LayoutKind::InstData) + layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData)); + else + layout = getDefaultSIMTLayoutInfo(payloadTy, uArch, + uArch->getGeneralPackedFormatBitSize(), + /*scattered*/ true); // Mask operand should have 1D default layout. LayoutInfo maskLayout = @@ -864,33 +906,36 @@ void LayoutInfoPropagation::visitStoreScatterOp( storeScatter.emitWarning("Not propagating, non-vector payload supplied."); return; } + LayoutInfo payloadLayout; auto uArch = getUArch(getChipStr(storeScatter).value_or("")); const int subgroupSize = uArch->getSubgroupSize(); - auto payloadShape = payloadTy.getShape(); - if (payloadShape.size() > 1) - assert( - payloadShape[0] == subgroupSize && - "Expected the first dimension of 2D tensor descriptor to be equal to " - "subgroup size."); - - SmallVector instData{subgroupSize}; - if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1) - instData.push_back(chunkSize); - else if (auto dstTdescTy = - dyn_cast(storeScatter.getDestType())) { - if (dstTdescTy.getChunkSizeAsInt() > 1) - instData.push_back(chunkSize); - } - - LayoutInfo payloadLayout; - if (auto layout = storeScatter.getLayoutAttr()) { payloadLayout = LayoutInfo(layout); } else { - payloadLayout = getDefaultSIMTLayoutInfo( - payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(), - /*scattered=*/true); + if (layoutKind == LayoutKind::InstData) { + SmallVector instData{subgroupSize}; + if (auto chunkSize = storeScatter.getChunkSize().value_or(0); + chunkSize > 1) + instData.push_back(chunkSize); + else if (auto dstTdescTy = dyn_cast( + storeScatter.getDestType())) { + if (dstTdescTy.getChunkSizeAsInt() > 1) + instData.push_back(chunkSize); + } + payloadLayout = LayoutInfo( + xegpu::LayoutAttr::get(storeScatter.getContext(), instData)); + } else { + auto payloadShape = payloadTy.getShape(); + if (payloadShape.size() > 1) + assert(payloadShape[0] == subgroupSize && + "Expected the first dimension of 2D tensor descriptor to be " + "equal to " + "subgroup size."); + payloadLayout = getDefaultSIMTLayoutInfo( + payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(), + /*scattered=*/true); + } } LayoutInfo maskLayout = @@ -916,10 +961,10 @@ class RunLayoutInfoPropagation { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation) - RunLayoutInfoPropagation(Operation *op) : target(op) { + RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) { SymbolTableCollection symbolTable; loadBaselineAnalyses(solver); - solver.load(symbolTable); + solver.load(symbolTable, layoutKind); (void)solver.initializeAndRun(op); } @@ -1159,7 +1204,18 @@ struct XeGPUPropagateLayoutPass final } // namespace void XeGPUPropagateLayoutPass::runOnOperation() { - auto &analysis = getAnalysis(); + LayoutKind layoutKind; + if (this->layoutKind == "lane") { + layoutKind = LayoutKind::Lane; + } else if (this->layoutKind == "inst") { + layoutKind = LayoutKind::InstData; + } else { + getOperation()->emitError("Unsupported layout kind option: " + + this->layoutKind); + signalPassFailure(); + return; + } + RunLayoutInfoPropagation analysis(getOperation(), layoutKind); // Print the analysis result and exit. (for debugging purposes) if (printOnly) { auto &os = llvm::outs(); @@ -1173,8 +1229,6 @@ void XeGPUPropagateLayoutPass::runOnOperation() { return {}; xegpu::DistributeLayoutAttr layoutAttr = cast(layout.get()); - if (this->layoutKind == "lane") - layoutAttr = layoutAttr.dropInstData(); if (layout.isSliceLayout()) return cast(layoutAttr); return cast(layoutAttr); diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir index 58461b8be52c4..c31ef323a94d2 100644 --- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir +++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir @@ -2,17 +2,17 @@ // CHECK-LABEL: func.func @dpas_f16( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) { -// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<0.000000e+00> : vector<8x16xf32> -// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout> -// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout> -// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout} : -// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout> -> vector<8x16xf16> -// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] {layout_result_0 = #xegpu.layout} : -// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout> -> vector<16x16xf16> -// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout} : +// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<0.000000e+00> : vector<8x16xf32> +// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout +// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout> +// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] {layout_result_0 = #xegpu.layout} : +// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout> -> vector<8x16xf16> +// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] {layout_result_0 = #xegpu.layout} : +// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout> -> vector<16x16xf16> +// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout} : // CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> -// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout> -// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout> +// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout +// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout> gpu.module @test { func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { @@ -46,18 +46,18 @@ gpu.module @test_kernel { %out:3 = scf.for %k = %c0 to %c1024 step %c32 iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc) -> (!xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>) { - //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout} : - //CHECK-SAME: !xegpu.tensor_desc<16x32xf16, #xegpu.layout> -> vector<16x32xf16> + //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout} : + //CHECK-SAME: !xegpu.tensor_desc<16x32xf16, #xegpu.layout> -> vector<16x32xf16> %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16> %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16> - //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout} : vector<16x32xf16> + //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout} : vector<16x32xf16> %c = arith.addf %a, %b : vector<16x32xf16> - //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #xegpu.layout>> + //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #xegpu.layout> xegpu.store_nd %c, %arg2: vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16> - //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x32xf16, #xegpu.layout> + //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x32xf16, #xegpu.layout> %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16> %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16> %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16> @@ -85,18 +85,18 @@ gpu.module @test_kernel { %out:3 = scf.for %k = %c0 to %c1024 step %c32 iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc) -> (!xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>) { - //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout} : - //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout> -> vector<12x32xf16> + //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout} : + //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout> -> vector<12x32xf16> %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16> %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16> - //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x32xf16> + //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout} : vector<12x32xf16> %c = arith.addf %a, %b : vector<12x32xf16> - //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16, #xegpu.layout>> + //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16, #xegpu.layout> xegpu.store_nd %c, %arg2: vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16> - //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<12x32xf16, #xegpu.layout> + //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<12x32xf16, #xegpu.layout> %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16> %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16> %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16> @@ -114,7 +114,7 @@ gpu.module @test { // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<16xi1> // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout} dense<12> : vector<16xindex> // CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}> -// CHECK-SAME: {layout_result_0 = #xegpu.layout} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16> // CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1> func.func @scatter_ops_chunksize(%src: memref<256xf16>) { %1 = arith.constant dense<1>: vector<16xi1> diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir index 61e315d0d2080..eb004932af4be 100644 --- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir +++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout -split-input-file %s | FileCheck %s +// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=lane" -split-input-file %s | FileCheck %s gpu.module @test { // CHECK-LABEL: func.func @dpas_f16(