Skip to content

Commit af6f83f

Browse files
committed
refactor
1 parent 2fb9ac7 commit af6f83f

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
236236
return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getType().getLayout());
237237
}
238238

239-
ArrayRef<int64_t> getDistributeShape() {
239+
ArrayRef<int64_t> getDataShape() {
240240
return getTensorDescShape();
241241
}
242242

@@ -283,7 +283,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
283283
return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
284284
}
285285

286-
ArrayRef<int64_t> getDistributeShape() {
286+
ArrayRef<int64_t> getDataShape() {
287287
return getTensorDescType().getShape();
288288
}
289289

@@ -381,7 +381,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
381381
return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
382382
}
383383

384-
ArrayRef<int64_t> getDistributeShape() {
384+
ArrayRef<int64_t> getDataShape() {
385385
return getTensorDescType().getShape();
386386
}
387387

@@ -473,7 +473,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
473473
return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
474474
}
475475

476-
ArrayRef<int64_t> getDistributeShape() {
476+
ArrayRef<int64_t> getDataShape() {
477477
return getTensorDescType().getShape();
478478
}
479479

@@ -1243,7 +1243,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
12431243
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
12441244
}
12451245

1246-
ArrayRef<int64_t> getDistributeShape() {
1246+
ArrayRef<int64_t> getDataShape() {
12471247
return getRes().getType().getShape();
12481248
}
12491249
}];
@@ -1285,7 +1285,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
12851285
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
12861286
}
12871287

1288-
ArrayRef<int64_t> getDistributeShape() {
1288+
ArrayRef<int64_t> getDataShape() {
12891289
return getData().getType().getShape();
12901290
}
12911291

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,12 @@ namespace {
3636

3737
// Retrieve the RangeAttr if it is specified.
3838
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
39-
Operation *parent = op->getParentOp();
39+
Operation *parent = op->getParentOfType<scf::IfOp>();
4040
while (parent) {
41-
if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
42-
if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>(
43-
ifOp->getAttr("sg_id_range"))) {
44-
return attr;
45-
}
46-
}
47-
parent = parent->getParentOp();
41+
if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
42+
parent->getAttr("sg_id_range")))
43+
return attr;
44+
parent = parent->getParentOfType<scf::IfOp>();
4845
}
4946
return {};
5047
}
@@ -115,14 +112,13 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
115112

116113
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
117114
// descriptors to be accessed, based on the layout information.
118-
ArrayRef<int64_t> wgShape = op.getDistributeShape();
115+
ArrayRef<int64_t> wgShape = op.getDataShape();
119116
auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
120117
if (failed(maybeDescOffsets))
121118
return failure();
122119

123120
// Compute the final global offsets for each accessed sub-tensor
124121
// or sub-memory descriptor.
125-
// SmallVector<SmallVector<OpFoldResult>> offsetsList;
126122
for (const auto &sgOffsets : *maybeDescOffsets) {
127123
SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
128124
rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
@@ -777,7 +773,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
777773
if (failed(genOffsetsList(rewriter, op, offsetsList)))
778774
return failure();
779775

780-
ArrayRef<int64_t> wgShape = op.getDistributeShape();
776+
ArrayRef<int64_t> wgShape = op.getDataShape();
781777
VectorType valueTy = op.getRes().getType();
782778
Type elemTy = valueTy.getElementType();
783779

0 commit comments

Comments
 (0)