@@ -902,6 +902,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
 };
 
 /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
+///
+/// Example:
 /// ```
 /// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
 /// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -987,6 +989,8 @@ struct ReorderElementwiseOpsOnBroadcast final
 /// This may result in cleaner code when extracting a single value
 /// from multi-element vector and also to help canonicalize 1-element vectors to
 /// scalars.
+///
+/// Example:
 /// ```
 ///  %0 = arith.addf %arg0, %arg1 : vector<4xf32>
 ///  %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1043,6 +1047,150 @@ class ExtractOpFromElementwise final
   }
 };
 
+/// Check if the element type is suitable for vector.load/store sinking.
+/// Element type must be index or byte-aligned integer or floating-point type.
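+/// For example, index, i8, i32, f16 and f64 are all accepted, while sub-byte
+/// types such as i1 or i4 are rejected.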
+static bool isSupportedMemSinkElementType(Type type) {
+  if (isa<IndexType>(type))
+    return true;
+
+  return type.isIntOrFloat() && type.getIntOrFloatBitWidth() % 8 == 0;
+}
+
+/// Pattern to rewrite `vector.extract(vector.load) -> vector/memref.load`.
+/// Only index and byte-aligned integer and floating-point element types are
+/// supported for now.
+///
+/// Example:
+/// ```
+///  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+///  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
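+/// When the extract produces a sub-vector instead of a scalar, a narrower
+/// vector.load is emitted rather than a memref.load.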
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "expected a load op");
+
+    // Checking for single use so we won't duplicate load ops.
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType loadVecType = loadOp.getVectorType();
+    if (loadVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+
+    // Non-byte-aligned types are tricky and may require special handling,
+    // ignore them for now.
+    if (!isSupportedMemSinkElementType(memType.getElementType()))
+      return rewriter.notifyMatchFailure(op, "unsupported element type");
+
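+    // `rankOffset` is the number of leading memref dimensions not covered by
+    // the loaded vector; the extract positions apply to the vector dims only.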
+    int64_t rankOffset = memType.getRank() - loadVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto extractVecType = dyn_cast<VectorType>(op.getResult().getType());
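+    // `finalRank` is the rank of the extract's result: 0 for a scalar,
+    // the sub-vector rank otherwise.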
+    int64_t finalRank = 0;
+    if (extractVecType)
+      finalRank = extractVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    // There may be memory stores between the load and the extract op, so we
+    // need to make sure that the new load op is inserted at the same place as
+    // the original load op.
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    ArithIndexingBuilder idxBuilderf(rewriter, loc);
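+    // Fold each non-zero extract position into the matching load index. Only
+    // indices in [rankOffset, rank - finalRank) are adjusted: leading memref
+    // indices are untouched and the trailing `finalRank` dims remain part of
+    // the loaded vector.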
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+      indices[i] = idxBuilderf.add(indices[i], offset);
+    }
+
+    Value base = loadOp.getBase();
+    if (extractVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, extractVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    // We checked for single use so we can safely erase the load op.
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+///
+/// Example:
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
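+/// A vector.broadcast of a scalar to a 1-element vector is matched the same
+/// way as a splat.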
+class StoreOpFromSplatOrBroadcast final
+    : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element vectors are supported");
+
+    Operation *splat = op.getValueToStore().getDefiningOp();
+    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+      return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+
+    // Checking for single use so we can remove splat.
+    if (!splat->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    Value source = splat->getOperand(0);
+    Value base = op.getBase();
+    ValueRange indices = op.getIndices();
+
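+    // A broadcast source may itself be a (1-element) vector; keep a
+    // vector.store in that case, otherwise store the scalar directly.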
+    if (isa<VectorType>(source.getType())) {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+    }
+    rewriter.eraseOp(splat);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2109,6 +2257,13 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
       patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  // TODO: Consider converting these patterns to canonicalizations.
+  patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
+      patterns.getContext(), benefit);
+}
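+// A minimal usage sketch (hypothetical driver code, not part of this patch),
+// applying the patterns with the greedy rewrite driver:
+//
+//   RewritePatternSet patterns(op->getContext());
+//   vector::populateSinkVectorMemOpsPatterns(patterns);
+//   (void)applyPatternsGreedily(op, std::move(patterns));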
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);