 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -68,18 +70,14 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   if (!srcTy)
     return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
 
-  // Perform common data transfer checks.
-  VectorType vecTy = xferOp.getVectorType();
-  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
-    return failure();
-
   // Validate further transfer op semantics.
   SmallVector<int64_t> strides;
   int64_t offset;
   if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
     return rewriter.notifyMatchFailure(
         xferOp, "Buffer must be contiguous in the innermost dimension");
 
+  VectorType vecTy = xferOp.getVectorType();
   unsigned vecRank = vecTy.getRank();
   if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
     return rewriter.notifyMatchFailure(
@@ -155,6 +153,277 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
+// Adjusts the strides of a memref according to a given permutation map for
+// vector operations.
+//
+// This function updates the innermost strides in the `strides` array to
+// reflect the permutation specified by `permMap`. The permutation is computed
+// using the inverse, broadcasting-aware version of the permutation map and is
+// applied to the relevant strides. This ensures that memory accesses are
+// consistent with the logical permutation of vector elements.
+//
+// Example:
+//   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
+//   If the permutation map swaps the last two dimensions (e.g., [0, 1] ->
+//   [1, 0]), then after calling this function the last two strides are
+//   swapped:
+//     Original strides:  [s0, s1, s2, s3]
+//     After permutation: [s0, s1, s3, s2]
+static void adjustStridesForPermutation(AffineMap permMap,
+                                        SmallVectorImpl<Value> &strides) {
+  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
+  SmallVector<unsigned> perms;
+  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
+  SmallVector<int64_t> perms64(perms.begin(), perms.end());
+  strides = applyPermutation(strides, perms64);
+}
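+
+// Illustrative instance of the above, assuming a 2-D memref<8x16xf32> source
+// with identity layout: its strides are [16, 1]; under a transposing
+// permutation map (d0, d1) -> (d1, d0) the adjusted strides become [1, 16].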
+
+// Computes memory strides for vector transfer operations, handling both
+// static and dynamic memrefs while applying permutation transformations
+// for XeGPU lowering.
+static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+                                         PatternRewriter &rewriter) {
+  SmallVector<Value> strides;
+  Value baseMemref = xferOp.getBase();
+  AffineMap permMap = xferOp.getPermutationMap();
+  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+
+  Location loc = xferOp.getLoc();
+  if (memrefType.hasStaticShape()) {
+    int64_t offset;
+    SmallVector<int64_t> intStrides;
+    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
+      return {};
+    // Wrap static strides as MLIR values.
+    for (int64_t s : intStrides)
+      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+  } else {
+    // For dynamically shaped memrefs, use memref.extract_strided_metadata to
+    // get the stride values.
+    unsigned rank = memrefType.getRank();
+    Type indexType = rewriter.getIndexType();
+
+    // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+    //                size0, size1, ..., sizeN-1].
+    SmallVector<Type> resultTypes;
+    resultTypes.push_back(MemRefType::get(
+        {}, memrefType.getElementType())); // base memref (rank 0)
+    resultTypes.push_back(indexType);      // offset
+
+    for (unsigned i = 0; i < rank; ++i)
+      resultTypes.push_back(indexType); // strides
+
+    for (unsigned i = 0; i < rank; ++i)
+      resultTypes.push_back(indexType); // sizes
+
+    auto meta = memref::ExtractStridedMetadataOp::create(
+        rewriter, loc, resultTypes, baseMemref);
+    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+  }
+  // Adjust strides according to the permutation map (e.g., for transpose).
+  adjustStridesForPermutation(permMap, strides);
+  return strides;
+}
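+
+// For a dynamically shaped source, say memref<?x?xf16> (illustrative), the
+// dynamic branch above emits roughly:
+//   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
+// and only the stride results are appended to the returned vector.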
+
+// Computes the vector of local offsets for scattered loads/stores. It is used
+// when lowering vector.transfer_read/write to load_gather/store_scatter.
+//
+// Example:
+//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+//            %cst {in_bounds = [true, true, true, true]} :
+//            memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+// produces offsets computed as:
+//   %6 = vector.step : vector<4xindex>
+//   %7 = vector.step : vector<2xindex>
+//   %8 = vector.step : vector<6xindex>
+//   %9 = vector.step : vector<32xindex>
+//   %10 = arith.mul %6, 384
+//   %11 = arith.mul %7, 192
+//   %12 = arith.mul %8, 32
+//   %13 = arith.mul %9, 1
+//   %14 = vector.shape_cast %10 : vector<4xindex> to vector<4x1x1x1xindex>
+//   %15 = vector.shape_cast %11 : vector<2xindex> to vector<1x2x1x1xindex>
+//   %16 = vector.shape_cast %12 : vector<6xindex> to vector<1x1x6x1xindex>
+//   %17 = vector.shape_cast %13 : vector<32xindex> to vector<1x1x1x32xindex>
+//   %18 = vector.broadcast %14 : vector<4x1x1x1xindex> to vector<4x2x6x32xindex>
+//   %19 = vector.broadcast %15 : vector<1x2x1x1xindex> to vector<4x2x6x32xindex>
+//   %20 = vector.broadcast %16 : vector<1x1x6x1xindex> to vector<4x2x6x32xindex>
+//   %21 = vector.broadcast %17 : vector<1x1x1x32xindex> to vector<4x2x6x32xindex>
+//   %22 = arith.add %18, %19
+//   %23 = arith.add %20, %21
+//   %local_offsets = arith.add %22, %23
+//   %orig_offset = %block_id_y * (4 * 2 * 6 * 32)  // consider using an affine map
+//   %offsets = %orig_offset + %local_offsets
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+                            PatternRewriter &rewriter,
+                            ArrayRef<Value> strides) {
+  Location loc = xferOp.getLoc();
+  VectorType vectorType = xferOp.getVectorType();
+  SmallVector<Value> indices(xferOp.getIndices().begin(),
+                             xferOp.getIndices().end());
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+  // Create a vector.step operation for each dimension.
+  SmallVector<Value> stepVectors;
+  for (int64_t dim : vectorShape) {
+    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
+    stepVectors.push_back(vector::StepOp::create(rewriter, loc, stepType));
+  }
+
+  // Multiply each step vector by the corresponding stride.
+  size_t memrefRank = strides.size();
+  size_t vectorRank = vectorShape.size();
+  SmallVector<Value> strideMultiplied;
+  for (size_t i = 0; i < vectorRank; ++i) {
+    size_t memrefDim = memrefRank - vectorRank + i;
+    Value strideValue = strides[memrefDim];
+    auto mulType = cast<VectorType>(stepVectors[i].getType());
+    auto bcastOp =
+        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
+    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
+    strideMultiplied.push_back(mulOp);
+  }
+
+  // Shape-cast each multiplied vector to add singleton dimensions.
+  SmallVector<Value> shapeCasted;
+  for (size_t i = 0; i < vectorRank; ++i) {
+    SmallVector<int64_t> newShape(vectorRank, 1);
+    newShape[i] = vectorShape[i];
+    auto newType = VectorType::get(newShape, rewriter.getIndexType());
+    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
+                                              strideMultiplied[i]);
+    shapeCasted.push_back(castOp);
+  }
+
+  // Broadcast each shape-casted vector to the full vector shape.
+  SmallVector<Value> broadcasted;
+  auto fullIndexVectorType =
+      VectorType::get(vectorShape, rewriter.getIndexType());
+  for (Value shapeCastVal : shapeCasted) {
+    auto broadcastOp = vector::BroadcastOp::create(
+        rewriter, loc, fullIndexVectorType, shapeCastVal);
+    broadcasted.push_back(broadcastOp);
+  }
+
+  // Add all broadcasted vectors together to compute the local offsets.
+  Value localOffsets = broadcasted[0];
+  for (size_t i = 1; i < broadcasted.size(); ++i)
+    localOffsets =
+        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
+
+  // Compute the base offset from the transfer op indices.
+  Value baseOffset = nullptr;
+  if (!indices.empty()) {
+    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    for (size_t i = 0; i < indices.size(); ++i) {
+      Value strideVal = strides[i];
+      Value offsetContrib =
+          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+      baseOffset =
+          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+    }
+    // Broadcast the base offset to match the vector shape.
+    Value bcastBase = vector::BroadcastOp::create(
+        rewriter, loc, fullIndexVectorType, baseOffset);
+    localOffsets =
+        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+  }
+  return localOffsets;
+}
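+
+// Note: the returned offsets are element offsets into the flattened 1-D view
+// of the source, which is why each per-dimension step vector is scaled by the
+// corresponding memref stride before the terms are summed.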
+
+// Collapse the source memref to a 1-D memref.
+static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
+                                PatternRewriter &rewriter) {
+  Location loc = xferOp.getLoc();
+
+  Value baseMemref = xferOp.getBase();
+  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+  Type elementType = memrefType.getElementType();
+
+  // Compute the total number of elements in the memref.
+  MemRefType flatMemrefType;
+  if (memrefType.hasStaticShape()) {
+    auto totalElements = memrefType.getNumElements();
+    flatMemrefType = MemRefType::get({totalElements}, elementType);
+  } else {
+    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
+  }
+
+  SmallVector<ReassociationIndices> reassociation;
+  ReassociationIndices allDims =
+      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
+  reassociation.push_back(allDims);
+
+  auto collapseOp = memref::CollapseShapeOp::create(
+      rewriter, loc, flatMemrefType, baseMemref, reassociation);
+  return collapseOp;
+}
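+
+// For example (illustrative), a memref<8x4x2x6x32xbf16> source collapses to:
+//   memref.collapse_shape %src [[0, 1, 2, 3, 4]]
+//       : memref<8x4x2x6x32xbf16> into memref<12288xbf16>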
+
+static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
+                                            PatternRewriter &rewriter) {
+  Location loc = readOp.getLoc();
+  VectorType vectorType = readOp.getVectorType();
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
+  if (!memrefType)
+    return rewriter.notifyMatchFailure(readOp, "Expected memref source");
+
+  SmallVector<Value> strides = computeStrides(readOp, rewriter);
+  if (strides.empty())
+    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
+
+  Value localOffsets = computeOffsets(readOp, rewriter, strides);
+
+  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+
+  Value mask = vector::ConstantMaskOp::create(
+      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+      vectorShape);
+  auto gatherOp = xegpu::LoadGatherOp::create(
+      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
+      /*chunk_size=*/IntegerAttr{},
+      /*l1_hint=*/xegpu::CachePolicyAttr{},
+      /*l2_hint=*/xegpu::CachePolicyAttr{},
+      /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+  rewriter.replaceOp(readOp, gatherOp.getResult());
+  return success();
+}
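+
+// The resulting LoadGatherOp reads from the collapsed 1-D source using one
+// element offset and one mask bit per vector lane; the mask is all-true here
+// because out-of-bounds accesses are rejected before this lowering is used.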
+
+static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
+                                             PatternRewriter &rewriter) {
+  Location loc = writeOp.getLoc();
+  VectorType vectorType = writeOp.getVectorType();
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
+  if (!memrefType)
+    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
+
+  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+  if (strides.empty())
+    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
+
+  Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+
+  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+
+  Value mask = vector::ConstantMaskOp::create(
+      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+      vectorShape);
+  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
+                                localOffsets, mask,
+                                /*chunk_size=*/IntegerAttr{},
+                                /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                /*l3_hint=*/xegpu::CachePolicyAttr{});
+  rewriter.eraseOp(writeOp);
+  return success();
+}
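+
+// The store path mirrors the load path above: the same strides, offsets, and
+// all-true mask are computed, and a StoreScatterOp writes the vector into the
+// collapsed 1-D view of the destination.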
+
 struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
@@ -165,6 +434,22 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
 
+    // TODO: This check needs to be replaced with a proper uArch capability
+    // check.
+    auto chip = xegpu::getChipStr(readOp);
+    if (chip != "pvc" && chip != "bmg") {
+      // Lower to a scattered load op if the target HW doesn't have 2D block
+      // load support.
+      // TODO: add support for out-of-bounds accesses.
+      if (readOp.hasOutOfBoundsDim())
+        return failure();
+      return lowerToScatteredLoadOp(readOp, rewriter);
+    }
+
+    // Perform common data transfer checks.
+    VectorType vecTy = readOp.getVectorType();
+    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+      return failure();
+
     bool isOutOfBounds = readOp.hasOutOfBoundsDim();
     if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
       return rewriter.notifyMatchFailure(
@@ -173,7 +458,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     AffineMap readMap = readOp.getPermutationMap();
     bool isTransposeLoad = !readMap.isMinorIdentity();
 
-    VectorType vecTy = readOp.getVectorType();
     Type elementType = vecTy.getElementType();
     unsigned minTransposeBitWidth = 32;
     if (isTransposeLoad &&
@@ -221,11 +505,26 @@ struct TransferWriteLowering
     if (failed(transferPreconditions(rewriter, writeOp)))
       return failure();
 
+    // TODO: This check needs to be replaced with a proper uArch capability
+    // check.
+    auto chip = xegpu::getChipStr(writeOp);
+    if (chip != "pvc" && chip != "bmg") {
+      // Lower to a scattered store op if the target HW doesn't have 2D block
+      // store support.
+      // TODO: add support for out-of-bounds accesses.
+      if (writeOp.hasOutOfBoundsDim())
+        return failure();
+      return lowerToScatteredStoreOp(writeOp, rewriter);
+    }
+
+    // Perform common data transfer checks.
+    VectorType vecTy = writeOp.getVectorType();
+    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
+      return failure();
+
     AffineMap map = writeOp.getPermutationMap();
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
-    VectorType vecTy = writeOp.getVectorType();
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),