@@ -261,6 +261,104 @@ class OuterProductFusion2Way
261261 }
262262};
263263
264+ // Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
265+ //
266+ // This transforms IR like:
267+ // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
268+ // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
269+ // Into:
270+ // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
271+ // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
272+ //
273+ // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
274+ // pass when the result is the input to an outer product.
275+ struct SwapVectorExtractOfArithExtend
276+ : public OpRewritePattern<vector::ExtractOp> {
277+ using OpRewritePattern::OpRewritePattern;
278+
279+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
280+ PatternRewriter &rewriter) const override {
281+ VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType ());
282+ if (!resultType)
283+ return rewriter.notifyMatchFailure (extractOp,
284+ " extracted type is not a vector type" );
285+
286+ auto numScalableDims = llvm::count (resultType.getScalableDims (), true );
287+ if (numScalableDims != 1 )
288+ return rewriter.notifyMatchFailure (
289+ extractOp, " extracted type is not a 1-D scalable vector type" );
290+
291+ auto *extendOp = extractOp.getVector ().getDefiningOp ();
292+ if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
293+ extendOp))
294+ return rewriter.notifyMatchFailure (extractOp,
295+ " extract not from extend op" );
296+
297+ auto loc = extractOp.getLoc ();
298+ StringAttr extendOpName = extendOp->getName ().getIdentifier ();
299+ Value extendSource = extendOp->getOperand (0 );
300+
301+ // Create new extract from source of extend.
302+ Value newExtract = rewriter.create <vector::ExtractOp>(
303+ loc, extendSource, extractOp.getMixedPosition ());
304+
305+ // Extend new extract to original result type.
306+ Operation *newExtend =
307+ rewriter.create (loc, extendOpName, Value (newExtract), resultType);
308+
309+ rewriter.replaceOp (extractOp, newExtend);
310+
311+ return success ();
312+ }
313+ };
314+
315+ // Same as above, but for vector.scalable.extract.
316+ //
317+ // This transforms IR like:
318+ // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
319+ // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
320+ // Into:
321+ // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
322+ // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
323+ //
324+ // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
325+ // pass when the result is the input to an outer product.
326+ struct SwapVectorScalableExtractOfArithExtend
327+ : public OpRewritePattern<vector::ScalableExtractOp> {
328+ using OpRewritePattern::OpRewritePattern;
329+
330+ LogicalResult matchAndRewrite (vector::ScalableExtractOp extractOp,
331+ PatternRewriter &rewriter) const override {
332+ auto *extendOp = extractOp.getSource ().getDefiningOp ();
333+ if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
334+ extendOp))
335+ return rewriter.notifyMatchFailure (extractOp,
336+ " extract not from extend op" );
337+
338+ auto loc = extractOp.getLoc ();
339+ VectorType resultType = extractOp.getResultVectorType ();
340+
341+ Value extendSource = extendOp->getOperand (0 );
342+ StringAttr extendOpName = extendOp->getName ().getIdentifier ();
343+ VectorType extendSourceVectorType =
344+ cast<VectorType>(extendSource.getType ());
345+
346+ // Create new extract from source of extend.
347+ VectorType extractResultVectorType =
348+ resultType.clone (extendSourceVectorType.getElementType ());
349+ Value newExtract = rewriter.create <vector::ScalableExtractOp>(
350+ loc, extractResultVectorType, extendSource, extractOp.getPos ());
351+
352+ // Extend new extract to original result type.
353+ Operation *newExtend =
354+ rewriter.create (loc, extendOpName, Value (newExtract), resultType);
355+
356+ rewriter.replaceOp (extractOp, newExtend);
357+
358+ return success ();
359+ }
360+ };
361+
264362struct OuterProductFusionPass
265363 : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
266364
@@ -278,7 +376,11 @@ struct OuterProductFusionPass
278376
279377void mlir::arm_sme::populateOuterProductFusionPatterns (
280378 RewritePatternSet &patterns) {
281- patterns.add <OuterProductFusion2Way>(patterns.getContext ());
379+ MLIRContext *context = patterns.getContext ();
380+ // Note: High benefit to ensure extract(extend) are swapped first.
381+ patterns.add <SwapVectorExtractOfArithExtend,
382+ SwapVectorScalableExtractOfArithExtend>(context, 1024 );
383+ patterns.add <OuterProductFusion2Way>(context);
282384}
283385
284386std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass () {
0 commit comments