
Commit 62c5c38

Align lowering with new utils behavior
Signed-off-by: dchigarev <[email protected]>
1 parent 4d2c284 commit 62c5c38

3 files changed, +107 -101 lines changed


mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 19 additions & 28 deletions
@@ -97,6 +97,9 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+// Common preconditions for the lowering of vector.gather and vector.scatter:
+// 1. Source is a memref.
+// 2. The innermost dimension of the memref is contiguous (stride == 1)
 static LogicalResult gatherScatterPreconditions(PatternRewriter &rewriter,
                                                 Operation *op, Type baseType) {
   auto srcTy = dyn_cast<MemRefType>(baseType);
@@ -259,7 +262,7 @@ computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
     adjustStridesForPermutation(permMap, strides);
   }
 
-  return strides;
+  return {strides, offsetVal};
 }
 
 // This function compute the vectors of localOffsets for scattered load/stores.
@@ -374,15 +377,14 @@ template <
     typename = std::enable_if_t<llvm::is_one_of<
         std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
 static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
-                            ArrayRef<Value> strides) {
+                            ArrayRef<Value> strides, Value baseOffset) {
   Location loc = gatScatOp.getLoc();
   SmallVector<Value> offsets = gatScatOp.getOffsets();
-  Value linearOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
   for (size_t i = 0; i < offsets.size(); ++i) {
     Value offsetContrib =
         arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
-    linearOffset =
-        arith::AddIOp::create(rewriter, loc, linearOffset, offsetContrib);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
   }
   Value indices = gatScatOp.getIndices();
   VectorType vecType = cast<VectorType>(indices.getType());
@@ -391,7 +393,7 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
       vector::BroadcastOp::create(
           rewriter, loc,
           VectorType::get(vecType.getShape(), rewriter.getIndexType()),
-          linearOffset)
+          baseOffset)
           .getResult();
   return arith::AddIOp::create(rewriter, loc, baseVector, indices).getResult();
 }
@@ -402,8 +404,7 @@ template <
         std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
         vector::GatherOp, vector::ScatterOp>::value>>
 // Convert memref to i64 base pointer
-static Value memrefToIndexPtr(OpType xferOp,
-                              PatternRewriter &rewriter) {
+static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
   auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                       rewriter, loc, xferOp.getBase())
@@ -613,18 +614,13 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
     Location loc = gatherOp.getLoc();
     VectorType vectorType = gatherOp.getVectorType();
 
-    SmallVector<Value> strides = computeStrides(gatherOp, rewriter);
-    if (strides.empty())
+    auto meta = computeMemrefMeta(gatherOp, rewriter);
+    if (meta.first.empty())
       return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
 
-    Value localOffsets = computeOffsets(rewriter, gatherOp, strides);
-    Value flatMemref = collapseMemrefTo1D(gatherOp, rewriter);
-
-    if (auto alignment = gatherOp.getAlignment()) {
-      flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
-                                                     alignment.value())
-                       .getResult();
-    }
+    Value localOffsets =
+        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
 
     auto xeGatherOp = xegpu::LoadGatherOp::create(
         rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
@@ -651,19 +647,14 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
       return failure();
 
     Location loc = scatterOp.getLoc();
-    SmallVector<Value> strides = computeStrides(scatterOp, rewriter);
-    if (strides.empty())
+    auto meta = computeMemrefMeta(scatterOp, rewriter);
+    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");
 
-    Value localOffsets = computeOffsets(rewriter, scatterOp, strides);
-    Value flatMemref = collapseMemrefTo1D(scatterOp, rewriter);
-
-    if (auto alignment = scatterOp.getAlignment()) {
-      flatMemref = memref::AssumeAlignmentOp::create(rewriter, loc, flatMemref,
-                                                     alignment.value())
-                       .getResult();
-    }
+    Value localOffsets =
+        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
 
     xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                   flatMemref, localOffsets, scatterOp.getMask(),
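
Condensed, the net effect on the gather path (the scatter path changes the same way) is that the lowering no longer flattens the source memref; it passes the base address as a raw i64 pointer and folds the memref's offset into the per-lane indices. A before/after sketch for a contiguous memref<8x16x32xf32> source, with illustrative SSA names (the authoritative patterns are the CHECK lines in the updated test below):

Before (collapse to a 1-D memref, optionally wrapped in memref.assume_alignment):

  %flat = memref.collapse_shape %src [[0, 1, 2]] : memref<8x16x32xf32> into memref<4096xf32>
  %vec  = xegpu.load %flat[%lin_idx], %mask : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>

After (extract the aligned base pointer and load through it):

  %ptr     = memref.extract_aligned_pointer_as_index %src : memref<8x16x32xf32> -> index
  %ptr_i64 = arith.index_cast %ptr : index to i64
  %vec     = xegpu.load %ptr_i64[%lin_idx], %mask : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>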

mlir/test/Conversion/VectorToXeGPU/gather-to-xegpu.mlir

Lines changed: 48 additions & 48 deletions
@@ -19,8 +19,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>,
 // CHECK-COUNT2: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
 // CHECK: gpu.return %[[RES]] : vector<8xf32>
 }
@@ -45,8 +46,9 @@ gpu.func @load_2D_memref(%source: memref<8x32xf32>,
 // CHECK-COUNT1: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<8x32xf32> into memref<256xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<256xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
 // CHECK: gpu.return %[[RES]] : vector<8xf32>
 }
@@ -71,8 +73,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-COUNT2: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
 // CHECK: gpu.return %[[RES]] : vector<8x16xf32>
 }
@@ -98,8 +101,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-COUNT2: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
 // CHECK: gpu.return %[[RES]] : vector<8x16xf32>
 }
@@ -125,8 +129,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // CHECK-COUNT2: arith.addi {{.*}} : index
 // CHECK: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // CHECK: %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x8x16xf32> -> index
+// CHECK: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// CHECK: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
 // CHECK: gpu.return %[[RES]] : vector<8x16xf32>
 }
@@ -146,42 +151,37 @@ gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
 
 // -----
 gpu.module @xevm_module {
-gpu.func @no_load_non_unit_inner_stride(
-    %source: memref<32xf32, strided<[?], offset: ?>>,
-    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
-    %pass_thru: vector<8xf32>) -> vector<8xf32> {
-  %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
-    : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
-  gpu.return %0 : vector<8xf32>
-}
-// CHECK-LABEL: @no_load_non_unit_inner_stride(
-// CHECK:       vector.gather
+gpu.func @gather_from_subview(%source: memref<4096x4096xf16>,
+                              %off1: index, %off2: index,
+                              %indices: vector<8xindex>,
+                              %mask: vector<8xi1>,
+                              %pass_thru: vector<8xf16>) -> vector<8xf16> {
+  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1]
+      : memref<4096x4096xf16>
+        to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  %0 = vector.gather %subview[%off1, %off2][%indices], %mask, %pass_thru
+      : memref<256x256xf16, strided<[4096, 1], offset: ?>>,
+        vector<8xindex>, vector<8xi1>, vector<8xf16>
+        into vector<8xf16>
+  gpu.return %0 : vector<8xf16>
+}
+// CHECK-LABEL: @gather_from_subview(
+// CHECK-SAME:  %[[SRC:.+]]: memref<4096x4096xf16>,
+// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
+// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>,
+// CHECK-SAME:  %[[MASK:.+]]: vector<8xi1>,
+// CHECK-SAME:  %[[PASS:.+]]: vector<8xf16>) -> vector<8xf16> {
+// CHECK:       %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
+// CHECK:       %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// CHECK:       arith.muli {{.*}} : index
+// CHECK:       arith.addi %[[OFFSET]]{{.*}} : index
+// CHECK:       %[[BASE_OFF:.+]] = arith.addi {{.*}} : index
+// CHECK:       %[[SPLAT:.+]] = vector.broadcast %[[BASE_OFF]] : index to vector<8xindex>
+// CHECK:       %[[LIN:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
+// CHECK:       %[[BASE_IDX:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// CHECK:       %[[BASE_I64:.+]] = arith.index_cast %[[BASE_IDX]] : index to i64
+// CHECK:       %[[VEC:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN]]{{\]}}, %[[MASK]]
+// CHECK-SAME:    : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+// CHECK:       %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS]] : vector<8xi1>, vector<8xf16>
+// CHECK:       gpu.return %[[RES]] : vector<8xf16>
 }
-
-// -----
-gpu.module @xevm_module {
-gpu.func @load_1D_aligned(%source: memref<8x16x32xf32>,
-    %off1: index, %off2: index, %off3: index,
-    %indices: vector<8xindex>, %mask: vector<8xi1>,
-    %pass_thru: vector<8xf32>) -> vector<8xf32> {
-  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
-    %pass_thru {alignment = 256} : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
-  gpu.return %0 : vector<8xf32>
-}
-// CHECK-LABEL: @load_1D_aligned(
-// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
-// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>
-// CHECK-SAME:  %[[MASK:.+]]: vector<8xi1>
-// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
-// CHECK-COUNT2: arith.muli {{.*}} : index
-// CHECK-COUNT2: arith.addi {{.*}} : index
-// CHECK:       %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
-// CHECK:       %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
-// CHECK:       %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK:       %[[COLLAPSE_ALIGN:.+]] = memref.assume_alignment %[[COLLAPSE]], 256 : memref<4096xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load %[[COLLAPSE_ALIGN]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
-// CHECK:       %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
-// CHECK:       gpu.return %[[RES]] : vector<8xf32>
-}
-
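A note on the new @gather_from_subview case above: memref.extract_aligned_pointer_as_index returns the address of the underlying 4096x4096 allocation rather than of the subview's first element, so the lowering seeds the linear offset with the dynamic offset reported by memref.extract_strided_metadata before adding the per-dimension contributions and the gather indices. A rough sketch of the emitted index arithmetic, with illustrative SSA names and the static stride shown as an explicit constant (an assumption; the CHECK lines above only pin down the op kinds and their order):

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview
      : memref<256x256xf16, strided<[4096, 1], offset: ?>>
      -> memref<f16>, index, index, index, index, index
  %c4096 = arith.constant 4096 : index
  %row   = arith.muli %off1, %c4096 : index     // contribution of the outer offset
  %acc   = arith.addi %offset, %row : index     // start from the subview's dynamic offset
  %lin   = arith.addi %acc, %off2 : index       // inner stride is 1, so %off2 adds directly
  %splat = vector.broadcast %lin : index to vector<8xindex>
  %idx   = arith.addi %splat, %indices : vector<8xindex>
  %ptr   = memref.extract_aligned_pointer_as_index %subview
      : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
  %ptr64 = arith.index_cast %ptr : index to i64
  %vec   = xegpu.load %ptr64[%idx], %mask : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>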