@@ -833,4 +833,265 @@ struct LinalgFoldUnitExtentDimsPass
833833 (void )applyPatternsAndFoldGreedily (op, std::move (patterns));
834834 }
835835};
836+
837+ } // namespace
838+
839+ namespace {
840+
841+ // / Returns reassociation indices for collapsing/expanding a
842+ // / tensor of rank `rank` at position `pos`.
843+ static SmallVector<ReassociationIndices>
844+ getReassociationForReshapeAtDim (int64_t rank, int64_t pos) {
845+ SmallVector<ReassociationIndices> reassociation (rank - 1 , {0 , 1 });
846+ bool lastDim = pos == rank - 1 ;
847+ if (rank > 2 ) {
848+ for (int64_t i = 0 ; i < rank - 1 ; i++) {
849+ if (i == pos || (lastDim && i == pos - 1 ))
850+ reassociation[i] = ReassociationIndices{i, i + 1 };
851+ else if (i < pos)
852+ reassociation[i] = ReassociationIndices{i};
853+ else
854+ reassociation[i] = ReassociationIndices{i + 1 };
855+ }
856+ }
857+ return reassociation;
858+ }
859+
860+ // / Returns a collapsed `val` where the collapsing occurs at dim `pos`.
861+ // / If `pos < 0`, then don't collapse.
862+ static Value collapseSingletonDimAt (PatternRewriter &rewriter, Value val,
863+ int64_t pos) {
864+ if (pos < 0 )
865+ return val;
866+ auto valType = cast<ShapedType>(val.getType ());
867+ SmallVector<int64_t > collapsedShape (valType.getShape ());
868+ collapsedShape.erase (collapsedShape.begin () + pos);
869+ return collapseValue (
870+ rewriter, val.getLoc (), val, collapsedShape,
871+ getReassociationForReshapeAtDim (valType.getRank (), pos),
872+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
873+ }
874+
875+ // / Base class for all rank reduction patterns for contraction ops
876+ // / with unit dimensions. All patterns should convert one named op
877+ // / to another named op. Intended to reduce only one iteration space dim
878+ // / at a time.
879+ // / Reducing multiple dims will happen with recusive application of
880+ // / pattern rewrites.
881+ template <typename FromOpTy, typename ToOpTy>
882+ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
883+ using OpRewritePattern<FromOpTy>::OpRewritePattern;
884+
885+ // / Collapse all collapsable operands.
886+ SmallVector<Value>
887+ collapseOperands (PatternRewriter &rewriter, ArrayRef<Value> operands,
888+ ArrayRef<int64_t > operandCollapseDims) const {
889+ assert (operandCollapseDims.size () == 3 && operands.size () == 3 &&
890+ " expected 3 operands and dims" );
891+ return llvm::map_to_vector (
892+ llvm::zip (operands, operandCollapseDims), [&](auto pair) {
893+ return collapseSingletonDimAt (rewriter, std::get<0 >(pair),
894+ std::get<1 >(pair));
895+ });
896+ }
897+
898+ // / Expand result tensor.
899+ Value expandResult (PatternRewriter &rewriter, Value result,
900+ RankedTensorType expandedType, int64_t dim) const {
901+ return rewriter.create <tensor::ExpandShapeOp>(
902+ result.getLoc (), expandedType, result,
903+ getReassociationForReshapeAtDim (expandedType.getRank (), dim));
904+ }
905+
906+ LogicalResult matchAndRewrite (FromOpTy contractionOp,
907+ PatternRewriter &rewriter) const override {
908+
909+ auto loc = contractionOp.getLoc ();
910+ auto inputs = contractionOp.getDpsInputs ();
911+ auto inits = contractionOp.getDpsInits ();
912+ if (inputs.size () != 2 || inits.size () != 1 )
913+ return rewriter.notifyMatchFailure (contractionOp,
914+ " expected 2 inputs and 1 init" );
915+ auto lhs = inputs[0 ];
916+ auto rhs = inputs[1 ];
917+ auto init = inits[0 ];
918+ SmallVector<Value> operands{lhs, rhs, init};
919+
920+ SmallVector<int64_t > operandUnitDims;
921+ if (failed (getOperandUnitDims (contractionOp, operandUnitDims)))
922+ return rewriter.notifyMatchFailure (contractionOp,
923+ " no reducable dims found" );
924+
925+ SmallVector<Value> collapsedOperands =
926+ collapseOperands (rewriter, operands, operandUnitDims);
927+ Value collapsedLhs = collapsedOperands[0 ];
928+ Value collapsedRhs = collapsedOperands[1 ];
929+ Value collapsedInit = collapsedOperands[2 ];
930+ SmallVector<Type, 1 > collapsedResultTy;
931+ if (isa<RankedTensorType>(collapsedInit.getType ()))
932+ collapsedResultTy.push_back (collapsedInit.getType ());
933+ auto collapsedOp = rewriter.create <ToOpTy>(
934+ loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
935+ ValueRange{collapsedInit});
936+ for (auto attr : contractionOp->getAttrs ()) {
937+ if (attr.getName () == LinalgDialect::kMemoizedIndexingMapsAttrName )
938+ continue ;
939+ collapsedOp->setAttr (attr.getName (), attr.getValue ());
940+ }
941+
942+ auto results = contractionOp.getResults ();
943+ assert (results.size () < 2 && " expected at most one result" );
944+ if (results.empty ()) {
945+ rewriter.replaceOp (contractionOp, collapsedOp);
946+ } else {
947+ rewriter.replaceOp (
948+ contractionOp,
949+ expandResult (rewriter, collapsedOp.getResultTensors ()[0 ],
950+ cast<RankedTensorType>(results[0 ].getType ()),
951+ operandUnitDims[2 ]));
952+ }
953+
954+ return success ();
955+ }
956+
957+ // / Populate `operandUnitDims` with 3 indices indicating the unit dim
958+ // / for each operand that should be collapsed in this pattern. If an
959+ // / operand shouldn't be collapsed, the index should be negative.
960+ virtual LogicalResult
961+ getOperandUnitDims (LinalgOp op,
962+ SmallVectorImpl<int64_t > &operandUnitDims) const = 0 ;
963+ };
964+
965+ // / Patterns for unbatching batched contraction ops
966+ template <typename FromOpTy, typename ToOpTy>
967+ struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
968+ using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
969+
970+ // / Look for unit batch dims to collapse.
971+ LogicalResult
972+ getOperandUnitDims (LinalgOp op,
973+ SmallVectorImpl<int64_t > &operandUnitDims) const override {
974+ FailureOr<ContractionDimensions> maybeContractionDims =
975+ inferContractionDims (op);
976+ if (failed (maybeContractionDims)) {
977+ LLVM_DEBUG (llvm::dbgs () << " could not infer contraction dims" );
978+ return failure ();
979+ }
980+ ContractionDimensions contractionDims = maybeContractionDims.value ();
981+
982+ if (contractionDims.batch .size () != 1 )
983+ return failure ();
984+ auto batchDim = contractionDims.batch [0 ];
985+ SmallVector<std::pair<Value, unsigned >, 3 > bOperands;
986+ op.mapIterationSpaceDimToAllOperandDims (batchDim, bOperands);
987+ if (bOperands.size () != 3 || llvm::any_of (bOperands, [](auto pair) {
988+ return cast<ShapedType>(std::get<0 >(pair).getType ())
989+ .getShape ()[std::get<1 >(pair)] != 1 ;
990+ })) {
991+ LLVM_DEBUG (llvm::dbgs () << " specified unit dims not found" );
992+ return failure ();
993+ }
994+
995+ operandUnitDims = SmallVector<int64_t >{std::get<1 >(bOperands[0 ]),
996+ std::get<1 >(bOperands[1 ]),
997+ std::get<1 >(bOperands[2 ])};
998+ return success ();
999+ }
1000+ };
1001+
1002+ // / Patterns for reducing non-batch dimensions
1003+ template <typename FromOpTy, typename ToOpTy>
1004+ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1005+ using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1006+
1007+ // / Helper for determining whether the lhs/init or rhs/init are reduced.
1008+ static bool constexpr reduceLeft =
1009+ (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1010+ std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1011+ (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1012+ std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1013+ (std::is_same_v<FromOpTy, MatmulOp> &&
1014+ std::is_same_v<ToOpTy, VecmatOp>) ||
1015+ (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1016+ std::is_same_v<ToOpTy, VecmatOp>) ||
1017+ (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1018+
1019+ // / Look for non-batch spatial dims to collapse.
1020+ LogicalResult
1021+ getOperandUnitDims (LinalgOp op,
1022+ SmallVectorImpl<int64_t > &operandUnitDims) const override {
1023+ FailureOr<ContractionDimensions> maybeContractionDims =
1024+ inferContractionDims (op);
1025+ if (failed (maybeContractionDims)) {
1026+ LLVM_DEBUG (llvm::dbgs () << " could not infer contraction dims" );
1027+ return failure ();
1028+ }
1029+ ContractionDimensions contractionDims = maybeContractionDims.value ();
1030+
1031+ if constexpr (reduceLeft) {
1032+ auto m = contractionDims.m [0 ];
1033+ SmallVector<std::pair<Value, unsigned >, 2 > mOperands ;
1034+ op.mapIterationSpaceDimToAllOperandDims (m, mOperands );
1035+ if (mOperands .size () != 2 )
1036+ return failure ();
1037+ if (llvm::all_of (mOperands , [](auto pair) {
1038+ return cast<ShapedType>(std::get<0 >(pair).getType ())
1039+ .getShape ()[std::get<1 >(pair)] == 1 ;
1040+ })) {
1041+ operandUnitDims = SmallVector<int64_t >{std::get<1 >(mOperands [0 ]), -1 ,
1042+ std::get<1 >(mOperands [1 ])};
1043+ return success ();
1044+ }
1045+ } else {
1046+ auto n = contractionDims.n [0 ];
1047+ SmallVector<std::pair<Value, unsigned >, 2 > nOperands;
1048+ op.mapIterationSpaceDimToAllOperandDims (n, nOperands);
1049+ if (nOperands.size () != 2 )
1050+ return failure ();
1051+ if (llvm::all_of (nOperands, [](auto pair) {
1052+ return cast<ShapedType>(std::get<0 >(pair).getType ())
1053+ .getShape ()[std::get<1 >(pair)] == 1 ;
1054+ })) {
1055+ operandUnitDims = SmallVector<int64_t >{-1 , std::get<1 >(nOperands[0 ]),
1056+ std::get<1 >(nOperands[1 ])};
1057+ return success ();
1058+ }
1059+ }
1060+ LLVM_DEBUG (llvm::dbgs () << " specified unit dims not found" );
1061+ return failure ();
1062+ }
1063+ };
1064+
8361065} // namespace
1066+
1067+ void mlir::linalg::populateContractionOpRankReducingPatterns (
1068+ RewritePatternSet &patterns) {
1069+ MLIRContext *context = patterns.getContext ();
1070+ // Unbatching patterns for unit batch size
1071+ patterns.add <RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1072+ patterns
1073+ .add <RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1074+ context);
1075+ patterns
1076+ .add <RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1077+ context);
1078+ patterns.add <RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1079+ patterns.add <RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1080+
1081+ // Non-batch rank 1 reducing patterns
1082+ patterns.add <RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1083+ patterns.add <RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1084+ patterns.add <RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1085+ patterns.add <RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
1086+ // Batch rank 1 reducing patterns
1087+ patterns.add <RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1088+ patterns.add <RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1089+ patterns.add <RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1090+ context);
1091+ patterns.add <RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1092+ context);
1093+
1094+ // Non-batch rank 0 reducing patterns
1095+ patterns.add <RankReduceMatmul<MatvecOp, DotOp>>(context);
1096+ patterns.add <RankReduceMatmul<VecmatOp, DotOp>>(context);
1097+ }
0 commit comments