@@ -39,6 +39,175 @@ using namespace mlir::tosa;
3939// Operator Canonicalizers.
4040// ===----------------------------------------------------------------------===//
4141
42+ // ===----------------------------------------------------------------------===//
43+ // Tensor Data Engine Operators.
44+ // ===----------------------------------------------------------------------===//
45+
46+ namespace {
47+ template <typename OpTy>
48+ struct PoolPadFoldAdaptor ;
49+
50+ template <>
51+ struct PoolPadFoldAdaptor <tosa::AvgPool2dOp> {
52+ static void replaceOpWithNewPad (PatternRewriter &rewriter,
53+ tosa::AvgPool2dOp op, Value padInput,
54+ ArrayRef<int64_t > newPad) {
55+ rewriter.replaceOpWithNewOp <tosa::AvgPool2dOp>(
56+ op, op.getType (), padInput, op.getInputZp (), op.getOutputZp (),
57+ op.getKernel (), op.getStride (), rewriter.getDenseI64ArrayAttr (newPad),
58+ op.getAccType ());
59+ }
60+ };
61+
62+ template <>
63+ struct PoolPadFoldAdaptor <tosa::MaxPool2dOp> {
64+ static void replaceOpWithNewPad (PatternRewriter &rewriter,
65+ tosa::MaxPool2dOp op, Value padInput,
66+ ArrayRef<int64_t > newPad) {
67+ rewriter.replaceOpWithNewOp <tosa::MaxPool2dOp>(
68+ op, op.getType (), padInput, op.getKernel (), op.getStride (),
69+ rewriter.getDenseI64ArrayAttr (newPad), op.getNanMode ());
70+ }
71+ };
72+
73+ template <typename OpTy>
74+ struct ConvPadFoldAdaptor {
75+ static void replaceOpWithNewPad (PatternRewriter &rewriter, OpTy op,
76+ Value padInput, ArrayRef<int64_t > newPad) {
77+ rewriter.replaceOpWithNewOp <OpTy>(
78+ op, op.getResult ().getType (), padInput, op.getWeight (), op.getBias (),
79+ op.getInputZp (), op.getWeightZp (), newPad, op.getStrideAttr (),
80+ op.getDilationAttr (), op.getAccType (), op.getLocalBound ());
81+ }
82+ };
83+
84+ // Pattern attempts to fold a `tosa.pad` operator to a following tensor
85+ // operation like `tosa.conv2d` by merging the padding associated with the
86+ // pad operator directly to the implicit padding of the tensor operation.
87+ // This helps eliminate the explicit padding operator if unused.
88+ template <typename OpTy, typename AdaptorTy>
89+ struct FoldPadToTensorOp : public OpRewritePattern <OpTy> {
90+ using OpRewritePattern<OpTy>::OpRewritePattern;
91+
92+ LogicalResult matchAndRewrite (OpTy tensorOp,
93+ PatternRewriter &rewriter) const override {
94+ // Check producer is a tosa::PadOp
95+ auto padOp = tensorOp.getInput ().template getDefiningOp <tosa::PadOp>();
96+ if (!padOp)
97+ return rewriter.notifyMatchFailure (tensorOp,
98+ " Producer must be a tosa::PadOp." );
99+
100+ // Validate that tensor operation has sane padding
101+ const std::vector<int64_t > &tensorOpPad = tensorOp.getPad ().vec ();
102+ if (tensorOpPad.size () != 4 ) // pad_top, pad_bottom, pad_left, pad_right
103+ return rewriter.notifyMatchFailure (
104+ tensorOp, " Tensor operation padding shall have 4 elements." );
105+
106+ // Validate tosa::PadOp padding
107+ DenseIntElementsAttr padOpPadding;
108+ if (!matchPattern (padOp.getPadding (), m_Constant (&padOpPadding))) {
109+ return rewriter.notifyMatchFailure (
110+ tensorOp,
111+ " The `padding` input specified on the tosa::PadOp must be constant." );
112+ }
113+ // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
114+ // C_after
115+ if (padOpPadding.size () != 8 )
116+ return rewriter.notifyMatchFailure (tensorOp,
117+ " Pad padding should have 8 elements." );
118+ int64_t padNBefore = (*(padOpPadding.begin () + 0 )).getLimitedValue ();
119+ int64_t padNAfter = (*(padOpPadding.begin () + 1 )).getLimitedValue ();
120+ int64_t padHBefore = (*(padOpPadding.begin () + 2 )).getLimitedValue ();
121+ int64_t padHAfter = (*(padOpPadding.begin () + 3 )).getLimitedValue ();
122+ int64_t padWBefore = (*(padOpPadding.begin () + 4 )).getLimitedValue ();
123+ int64_t padWAfter = (*(padOpPadding.begin () + 5 )).getLimitedValue ();
124+ int64_t padCBefore = (*(padOpPadding.begin () + 6 )).getLimitedValue ();
125+ int64_t padCAfter = (*(padOpPadding.begin () + 7 )).getLimitedValue ();
126+
127+ if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0 )
128+ return rewriter.notifyMatchFailure (
129+ tensorOp, " Folding padding in N or C dimensions is not supported." );
130+
131+ // Fold padding from Pad into the tensor operation
132+ // 4 elements - pad_top, pad_bottom, pad_left, pad_right
133+ SmallVector<int64_t > foldedPad (tensorOpPad.size ());
134+ foldedPad[0 ] = padHBefore + tensorOpPad[0 ];
135+ foldedPad[1 ] = padHAfter + tensorOpPad[1 ];
136+ foldedPad[2 ] = padWBefore + tensorOpPad[2 ];
137+ foldedPad[3 ] = padWAfter + tensorOpPad[3 ];
138+
139+ // Replace operator
140+ AdaptorTy::replaceOpWithNewPad (rewriter, tensorOp, padOp.getInput1 (),
141+ foldedPad);
142+
143+ return success ();
144+ }
145+ };
146+ } // namespace
147+
148+ void AvgPool2dOp::getCanonicalizationPatterns (RewritePatternSet &results,
149+ MLIRContext *context) {
150+ results.add <FoldPadToTensorOp<tosa::AvgPool2dOp,
151+ PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
152+ context);
153+ }
154+
155+ void Conv2DOp::getCanonicalizationPatterns (RewritePatternSet &results,
156+ MLIRContext *context) {
157+ results.add <
158+ FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
159+ context);
160+ }
161+
162+ void DepthwiseConv2DOp::getCanonicalizationPatterns (RewritePatternSet &results,
163+ MLIRContext *context) {
164+ results.add <FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
165+ ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
166+ context);
167+ }
168+
169+ struct MaxPool2dIsNoOp : public OpRewritePattern <tosa::MaxPool2dOp> {
170+ using OpRewritePattern::OpRewritePattern;
171+
172+ LogicalResult matchAndRewrite (tosa::MaxPool2dOp op,
173+ PatternRewriter &rewriter) const override {
174+ Value input = op.getInput ();
175+ Value output = op.getOutput ();
176+ ShapedType inputType = llvm::cast<ShapedType>(input.getType ());
177+ ShapedType outputType = llvm::cast<ShapedType>(output.getType ());
178+
179+ if (!inputType.hasStaticShape () || !outputType.hasStaticShape ()) {
180+ return failure ();
181+ }
182+
183+ // If the output and input shapes are 1x1, then this is a no op.
184+ ArrayRef<int64_t > outputShape = outputType.getShape ();
185+ if (outputShape[1 ] != 1 || outputShape[2 ] != 1 ) {
186+ return failure ();
187+ }
188+
189+ ArrayRef<int64_t > inputShape = inputType.getShape ();
190+ if (inputShape[1 ] != 1 || inputShape[2 ] != 1 ) {
191+ return failure ();
192+ }
193+
194+ rewriter.replaceOp (op, input);
195+ return success ();
196+ }
197+ };
198+
199+ void MaxPool2dOp::getCanonicalizationPatterns (RewritePatternSet &results,
200+ MLIRContext *context) {
201+ results.add <MaxPool2dIsNoOp,
202+ FoldPadToTensorOp<tosa::MaxPool2dOp,
203+ PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
204+ context);
205+ }
206+
207+ // ===----------------------------------------------------------------------===//
208+ // Data Layout / Memory Reinterpretation.
209+ // ===----------------------------------------------------------------------===//
210+
42211struct ConcatOptimization : public OpRewritePattern <tosa::ConcatOp> {
43212 using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
44213
@@ -175,41 +344,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
175344 results.add <ConsolidateTransposeOptimization, TransposeIsReshape>(context);
176345}
177346
178- struct MaxPool2dIsNoOp : public OpRewritePattern <tosa::MaxPool2dOp> {
179- using OpRewritePattern::OpRewritePattern;
180-
181- LogicalResult matchAndRewrite (tosa::MaxPool2dOp op,
182- PatternRewriter &rewriter) const override {
183- Value input = op.getInput ();
184- Value output = op.getOutput ();
185- ShapedType inputType = llvm::cast<ShapedType>(input.getType ());
186- ShapedType outputType = llvm::cast<ShapedType>(output.getType ());
187-
188- if (!inputType.hasStaticShape () || !outputType.hasStaticShape ()) {
189- return failure ();
190- }
191-
192- // If the output and input shapes are 1x1, then this is a no op.
193- ArrayRef<int64_t > outputShape = outputType.getShape ();
194- if (outputShape[1 ] != 1 || outputShape[2 ] != 1 ) {
195- return failure ();
196- }
197-
198- ArrayRef<int64_t > inputShape = inputType.getShape ();
199- if (inputShape[1 ] != 1 || inputShape[2 ] != 1 ) {
200- return failure ();
201- }
202-
203- rewriter.replaceOp (op, input);
204- return success ();
205- }
206- };
207-
208- void MaxPool2dOp::getCanonicalizationPatterns (RewritePatternSet &results,
209- MLIRContext *context) {
210- results.add <MaxPool2dIsNoOp>(context);
211- }
212-
213347struct ClampIsNoOp : public OpRewritePattern <tosa::ClampOp> {
214348 using OpRewritePattern::OpRewritePattern;
215349
0 commit comments