@@ -39,6 +39,277 @@ using namespace mlir::tosa;
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+// Check that the zero point of the tensor operation and the pad constant of
+// the padding operation are aligned.
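+// For example (illustrative values): an i8 tosa.pad whose pad_const is 42 is
+// aligned with a consumer whose input_zp is 42, since padding with the zero
+// point is equivalent to the consumer's own implicit padding.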
+bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+  // Check that padConst is a constant value and a scalar tensor
+  DenseElementsAttr padConstAttr;
+  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+      (padConstAttr.size() != 1)) {
+    return false;
+  }
+
+  // Check that the floating-point pad constant is zero
+  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+    return padConstVal == 0.0f;
+  }
+
+  // Check that the zp and padConst align for the integer (quantized) case
+  if (auto padConstIntAttr =
+          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+    DenseIntElementsAttr zpAttr;
+    // Check that zp is a constant value and a scalar tensor
+    if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
+      return false;
+    }
+
+    // Check equality
+    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
+    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+    return zpVal == padConstVal;
+  }
+
+  // Bail out on unsupported element types
+  return false;
+}
+
+namespace {
+template <typename OpTy>
+struct PoolPadFoldAdaptor;
+
+template <>
+struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
+  using OpTy = tosa::AvgPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
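+    // TOSA restricts pooling padding to be strictly smaller than the kernel
+    // in each spatial dimension, so reject folds that would exceed it.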
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
+        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
+        op.getAccType());
+  }
+};
+
+template <>
+struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
+  using OpTy = tosa::MaxPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
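+    // Same TOSA restriction as for avg_pool2d: the folded padding must stay
+    // strictly below the kernel size in each spatial dimension.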
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy, Value padConst) {
+    // Check that padConst is a constant value and a scalar tensor
+    DenseElementsAttr padConstAttr;
+    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+        padConstAttr.size() != 1) {
+      return false;
+    }
+
+    // The pad constant must be the lowest representable value so that merging
+    // the explicit padding into the pooling operation cannot affect the
+    // computed maximum.
+    if (auto padConstFpAttr =
+            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+      const APFloat padConstVal = *padConstFpAttr.begin();
+      const APFloat lowestVal =
+          APFloat::getLargest(padConstVal.getSemantics(), true);
+      return padConstVal == lowestVal;
+    } else if (auto padConstIntAttr =
+                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+      const APInt padConstVal = *padConstIntAttr.begin();
+      const APInt lowestVal =
+          APInt::getSignedMinValue(padConstVal.getBitWidth());
+      return padConstVal == lowestVal;
+    }
+
+    // Bail out on unsupported element types
+    return false;
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+        op, op.getType(), padInput, op.getKernel(), op.getStride(),
+        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
+  }
+};
+
+template <typename OpTy>
+struct ConvPadFoldAdaptor {
+  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<OpTy>(
+        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
+        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
+        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
+  }
+};
+
+// This pattern attempts to fold a `tosa.pad` operator into a following tensor
+// operation such as `tosa.conv2d` by merging the padding of the pad operator
+// directly into the implicit padding of the tensor operation. If the explicit
+// pad operator then becomes unused, it can be eliminated.
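+//
+// Schematically (operands, attributes, and types elided; values illustrative
+// and assuming the pad constant matches the conv2d input zero point):
+//
+//   %0 = tosa.pad %input {padding = [0, 0, 1, 1, 2, 2, 0, 0]}
+//   %1 = tosa.conv2d %0, ... {pad = [0, 0, 0, 0]}
+//
+// becomes:
+//
+//   %1 = tosa.conv2d %input, ... {pad = [1, 1, 2, 2]}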
+template <typename OpTy, typename AdaptorTy>
+struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy tensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Check producer is a tosa::PadOp
+    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Producer must be a tosa::PadOp.");
+
+    // Validate that the tensor operation's pad attribute is well-formed
+    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
+    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Tensor operation padding shall have 4 elements.");
+
+    // Validate tosa::PadOp padding
+    DenseIntElementsAttr padOpPadding;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
+    // C_after
+    if (padOpPadding.size() != 8)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Pad padding should have 8 elements.");
+    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
+    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
+    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
+    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
+    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
+    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
+    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
+    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
+
+    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Folding padding in N or C dimensions is not supported.");
+
+    // Fold the explicit padding from the PadOp into the tensor operation's
+    // 4-element pad: pad_top, pad_bottom, pad_left, pad_right
+    SmallVector<int64_t> foldedPad(tensorOpPad.size());
+    foldedPad[0] = padHBefore + tensorOpPad[0];
+    foldedPad[1] = padHAfter + tensorOpPad[1];
+    foldedPad[2] = padWBefore + tensorOpPad[2];
+    foldedPad[3] = padWAfter + tensorOpPad[3];
+
+    // Check kernel-related restrictions
+    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size not aligned with kernel restrictions.");
+    }
+
+    // Check padding constant restrictions
+    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "Padding constant is not aligned with operator zero-point.");
+    }
+
+    // For now, reject folds where the padding would exceed the TOSA 8K level
+    // limit (8192)
+    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size more than the 8K level limit.");
+    }
+
+    // Create operator
+    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
+                                   foldedPad);
+
+    return success();
+  }
+};
+} // namespace
+
+void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
+                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
+      context);
+}
+
+void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<
+      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
+      context);
+}
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
+                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
+      context);
+}
+
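+// A max_pool2d is a no-op when both its input and output have 1x1 spatial
+// dimensions, since the pooling window then reduces over a single element.
+// For example (attributes elided, shapes illustrative):
+//
+//   %0 = tosa.max_pool2d %in : (tensor<2x1x1x8xf32>) -> tensor<2x1x1x8xf32>
+//
+// can be replaced by %in directly.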
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getInput();
+    Value output = op.getOutput();
+    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // If both the input and output spatial dimensions are 1x1, this is a
+    // no-op.
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    if (outputShape[1] != 1 || outputShape[2] != 1) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<MaxPool2dIsNoOp,
+              FoldPadToTensorOp<tosa::MaxPool2dOp,
+                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
+      context);
+}
+
+//===----------------------------------------------------------------------===//
+// Data Layout / Memory Reinterpretation.
+//===----------------------------------------------------------------------===//
+
 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
 
@@ -175,41 +446,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    Value output = op.getOutput();
-    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
-    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
-
-    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-      return failure();
-    }
-
-    // If the output and input shapes are 1x1, then this is a no op.
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    if (outputShape[1] != 1 || outputShape[2] != 1) {
-      return failure();
-    }
-
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    if (inputShape[1] != 1 || inputShape[2] != 1) {
-      return failure();
-    }
-
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<MaxPool2dIsNoOp>(context);
-}
-
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern::OpRewritePattern;
 