@@ -57,53 +57,9 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
5757 }
5858};
5959
60- struct ConcatFolding : public OpRewritePattern <tosa::ConcatOp> {
61- using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
62-
63- LogicalResult matchAndRewrite (tosa::ConcatOp op,
64- PatternRewriter &rewriter) const override {
65- // Fold consecutive concats on the same axis into a single op.
66- uint64_t axis = op.getAxis ();
67-
68- // Keep track of the operands so we are able to construct a new concat
69- // later. Conservatively assume that we double the number of operands when
70- // folding
71- SmallVector<Value, 8 > concatOperands;
72- concatOperands.reserve (2 * op->getNumOperands ());
73-
74- // Find all operands that are foldable concats
75- bool canFold = false ;
76- for (Value operand : op->getOperands ()) {
77- concatOperands.emplace_back (operand);
78-
79- auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp ());
80- if (!producer)
81- continue ;
82-
83- // Foldable if axis is the same
84- if (axis != producer.getAxis ())
85- continue ;
86-
87- // Replace the original operand with all incoming operands
88- canFold = true ;
89- concatOperands.pop_back ();
90- llvm::append_range (concatOperands, producer->getOperands ());
91- }
92-
93- if (!canFold)
94- return rewriter.notifyMatchFailure (op, " No foldable concats found" );
95-
96- // Replace the original concat with a new one that contains the original and
97- // folded operands
98- rewriter.replaceOpWithNewOp <tosa::ConcatOp>(op, op->getResultTypes (),
99- concatOperands, axis);
100- return success ();
101- }
102- };
103-
10460void ConcatOp::getCanonicalizationPatterns (RewritePatternSet &results,
10561 MLIRContext *context) {
106- results.add <ConcatOptimization, ConcatFolding >(context);
62+ results.add <ConcatOptimization>(context);
10763}
10864
10965struct ReshapeReshapeOptimization : public OpRewritePattern <tosa::ReshapeOp> {
@@ -1039,3 +995,37 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
1039995 return getInput1 ();
1040996 return {};
1041997}
998+
999+ OpFoldResult ConcatOp::fold (ArrayRef<Attribute> operands) {
1000+ // Fold consecutive concats on the same axis into a single op.
1001+ // Keep track of the operands so we are able to construct a new concat
1002+ // later. Conservatively assume that we double the number of operands when
1003+ // folding
1004+ SmallVector<Value, 8 > concatOperands;
1005+ concatOperands.reserve (2 * getNumOperands ());
1006+
1007+ // Find all operands that are foldable concats
1008+ bool canFold = false ;
1009+ for (Value operand : getOperands ()) {
1010+ concatOperands.emplace_back (operand);
1011+
1012+ auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp ());
1013+ if (!producer)
1014+ continue ;
1015+
1016+ // Foldable if axis is the same
1017+ if (getAxis () != producer.getAxis ())
1018+ continue ;
1019+
1020+ // Replace the original operand with all incoming operands
1021+ canFold = true ;
1022+ concatOperands.pop_back ();
1023+ llvm::append_range (concatOperands, producer->getOperands ());
1024+ }
1025+
1026+ if (!canFold)
1027+ return {};
1028+
1029+ getOperation ()->setOperands (concatOperands);
1030+ return getResult ();
1031+ }
0 commit comments