Skip to content

Commit 68b143d

Browse files
authored
[MLIR][XeGPU] Use operand layouts for store scatter (#161447)
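This PR changes the store-scatter distribution to read the layout from the op's operands, and to attach the updated layout to the new op's operands, because a store op has no result to carry a layout.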
1 parent 9133fc8 commit 68b143d

File tree

2 files changed

+26
-17
lines changed

2 files changed

+26
-17
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset
824824
return failure();
825825

826826
xegpu::DistributeLayoutAttr layout =
827-
xegpu::getDistributeLayoutAttr(op.getValue());
827+
xegpu::getDistributeLayoutAttr(op.getOperand(0));
828828
if (!layout || !layout.isForWorkgroup())
829829
return failure();
830830

@@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset
844844
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
845845
for (auto [val, offs, mask] : llvm::zip(
846846
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
847-
xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
848-
mask, chunkSizeAttr, op.getL1HintAttr(),
849-
op.getL2HintAttr(), op.getL3HintAttr());
847+
auto store = xegpu::StoreScatterOp::create(
848+
rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
849+
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
850850
// Update the layout attribute to drop sg_layout and sg_data.
851-
if (auto newLayout = layout.dropSgLayoutAndData())
852-
op->setAttr("layout", newLayout);
851+
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
852+
!layout.getEffectiveInstDataAsInt().empty()) {
853+
for (OpOperand &operand : store->getOpOperands()) {
854+
// Skip for operand one (memref)
855+
if (operand.getOperandNumber() == 1)
856+
continue;
857+
xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
858+
}
859+
}
853860
}
854861
rewriter.eraseOp(op);
855862
return success();
@@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
12471254

12481255
target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
12491256
[=](xegpu::StoreScatterOp op) -> bool {
1250-
// Check if the layout attribute is present on the result.
1251-
auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
1252-
if (!layout)
1253-
return true;
1257+
auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
12541258
return isLegal(layout);
12551259
});
12561260

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,15 +282,20 @@ gpu.module @test_distribution {
282282
// CHECK-LABEL: @store_scatter
283283
// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>
284284
gpu.func @store_scatter(%dest : memref<256xf16>) {
285-
// CHECK: %[[VAL:.*]] = arith.constant dense<2.550000e+01> : vector<8xf16>
286-
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xindex>
287-
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8xi1>
285+
// CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16>
286+
// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex>
287+
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1>
288288
// CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
289+
// CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>,
290+
// CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>}
289291
// CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
290-
%val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<25.5> : vector<256xf16>
291-
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<0> : vector<256xindex>
292-
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8]>} dense<1> : vector<256xi1>
293-
xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout = #xegpu.layout<sg_layout = [32], sg_data = [8]>, l1_hint = #xegpu.cache_hint<cached>}
292+
%val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
293+
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<0> : vector<256xindex>
294+
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<1> : vector<256xi1>
295+
xegpu.store %val, %dest[%offset], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
296+
layout_operand_2 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
297+
layout_operand_3 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>,
298+
l1_hint = #xegpu.cache_hint<cached>}
294299
: vector<256xf16>, memref<256xf16>, vector<256xindex>, vector<256xi1>
295300
gpu.return
296301
}

0 commit comments

Comments
 (0)