@@ -39,6 +39,273 @@ using namespace mlir::tosa;
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+// Check that the zero point of the tensor operation and the constant of the
+// preceding pad operation are aligned.
+static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+  // Check that padConst is a constant value and a scalar tensor.
+  DenseElementsAttr padConstAttr;
+  if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+      (padConstAttr.size() != 1)) {
+    return false;
+  }
+
+  // Check that the floating-point pad constant is zero.
+  if (auto padConstFpAttr = mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+    float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+    return padConstVal == 0.0f;
+  }
+
+  // Check that the zp and padConst align for the integer (quantized) case.
+  if (auto padConstIntAttr =
+          mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+    DenseIntElementsAttr zpAttr;
+    // Check that zp is a constant value and a scalar tensor.
+    if (!matchPattern(zp, m_Constant(&zpAttr)) || (zpAttr.size() != 1)) {
+      return false;
+    }
+
+    // Check equality.
+    int64_t zpVal = (*zpAttr.begin()).getSExtValue();
+    int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+    return zpVal == padConstVal;
+  }
+
+  // Bail out on unsupported element types.
+  return false;
+}
+
+namespace {
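+// Adaptor providing the per-operator hooks used by FoldPadToTensorOp: the
+// kernel-size restrictions, the pad-constant/zero-point compatibility check,
+// and the construction of the replacement op with the folded padding.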
+template <typename OpTy>
+struct PoolPadFoldAdaptor;
+
+template <>
+struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> {
+  using OpTy = tosa::AvgPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>(
+        op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(),
+        op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad),
+        op.getAccType());
+  }
+};
+
+template <>
+struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> {
+  using OpTy = tosa::MaxPool2dOp;
+  static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) {
+    const llvm::ArrayRef<int64_t> kernel = op.getKernel();
+    if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] ||
+        newPad[0] >= kernel[0] || newPad[1] >= kernel[0])
+      return false;
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy, Value padConst) {
+    // Check that padConst is a constant value and a scalar tensor.
+    DenseElementsAttr padConstAttr;
+    if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
+        padConstAttr.size() != 1) {
+      return false;
+    }
+
+    // The pad constant must be the minimum value of the element type for the
+    // fold to preserve the max-pool result.
+    if (auto padConstFpAttr =
+            mlir::dyn_cast<DenseFPElementsAttr>(padConstAttr)) {
+      float padConstVal = (*padConstFpAttr.begin()).convertToFloat();
+      return padConstVal == std::numeric_limits<float>::lowest();
+    } else if (auto padConstIntAttr =
+                   mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) {
+      int64_t padConstVal = (*padConstIntAttr.begin()).getSExtValue();
+      return padConstVal == std::numeric_limits<int32_t>::lowest();
+    }
+
+    // Bail out on unsupported element types.
+    return false;
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<tosa::MaxPool2dOp>(
+        op, op.getType(), padInput, op.getKernel(), op.getStride(),
+        rewriter.getDenseI64ArrayAttr(newPad), op.getNanMode());
+  }
+};
+
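+// Convolution-style operators accept arbitrary explicit padding, so only the
+// pad-constant/zero-point compatibility needs to be checked.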
+template <typename OpTy>
+struct ConvPadFoldAdaptor {
+  static bool checkKernelCompliance(OpTy, const ArrayRef<int64_t>) {
+    return true;
+  }
+  static bool checkPadConstCompliance(OpTy op, Value padConst) {
+    return checkMatchingPadConstAndZp(padConst, op.getInputZp());
+  }
+  static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op,
+                                  Value padInput, ArrayRef<int64_t> newPad) {
+    rewriter.replaceOpWithNewOp<OpTy>(
+        op, op.getResult().getType(), padInput, op.getWeight(), op.getBias(),
+        op.getInputZp(), op.getWeightZp(), newPad, op.getStrideAttr(),
+        op.getDilationAttr(), op.getAccType(), op.getLocalBound());
+  }
+};
+
+// Pattern that attempts to fold a `tosa.pad` operator into a consuming tensor
+// operation such as `tosa.conv2d` by merging the explicit padding of the pad
+// operator into the implicit padding of the tensor operation. This lets the
+// explicit padding operator be eliminated once it has no other uses.
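+//
+// For example (illustrative IR; shapes and attribute values are only assumed
+// for the sake of the example):
+//   %0 = tosa.pad %arg0, %padding, %pad_const : ... -> tensor<1x34x34x8xf32>
+//   %1 = tosa.conv2d %0, %w, %b, ... {pad = array<i64: 0, 0, 0, 0>, ...}
+// folds to:
+//   %1 = tosa.conv2d %arg0, %w, %b, ... {pad = array<i64: 1, 1, 1, 1>, ...}
+// provided the N and C padding is zero, the pad constant matches the
+// operator's zero point, and the folded padding obeys the kernel and 8K
+// level limits checked below.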
+template <typename OpTy, typename AdaptorTy>
+struct FoldPadToTensorOp : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy tensorOp,
+                                PatternRewriter &rewriter) const override {
+    // Check that the producer is a tosa::PadOp.
+    auto padOp = tensorOp.getInput().template getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Producer must be a tosa::PadOp.");
+
+    // Validate that the tensor operation has sane padding.
+    const std::vector<int64_t> &tensorOpPad = tensorOp.getPad().vec();
+    if (tensorOpPad.size() != 4) // pad_top, pad_bottom, pad_left, pad_right
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Tensor operation padding shall have 4 elements.");
+
+    // Validate the tosa::PadOp padding.
+    DenseIntElementsAttr padOpPadding;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&padOpPadding))) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "The `padding` input specified on the tosa::PadOp must be constant.");
+    }
+    // N_before, N_after, H_before, H_after, W_before, W_after, C_before,
+    // C_after
+    if (padOpPadding.size() != 8)
+      return rewriter.notifyMatchFailure(tensorOp,
+                                         "Pad padding should have 8 elements.");
+    int64_t padNBefore = (*(padOpPadding.begin() + 0)).getLimitedValue();
+    int64_t padNAfter = (*(padOpPadding.begin() + 1)).getLimitedValue();
+    int64_t padHBefore = (*(padOpPadding.begin() + 2)).getLimitedValue();
+    int64_t padHAfter = (*(padOpPadding.begin() + 3)).getLimitedValue();
+    int64_t padWBefore = (*(padOpPadding.begin() + 4)).getLimitedValue();
+    int64_t padWAfter = (*(padOpPadding.begin() + 5)).getLimitedValue();
+    int64_t padCBefore = (*(padOpPadding.begin() + 6)).getLimitedValue();
+    int64_t padCAfter = (*(padOpPadding.begin() + 7)).getLimitedValue();
+
+    if (padNBefore != 0 || padNAfter != 0 || padCBefore != 0 || padCAfter != 0)
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Folding padding in N or C dimensions is not supported.");
+
+    // Fold the padding from the Pad operator into the tensor operation.
+    // 4 elements - pad_top, pad_bottom, pad_left, pad_right
+    SmallVector<int64_t> foldedPad(tensorOpPad.size());
+    foldedPad[0] = padHBefore + tensorOpPad[0];
+    foldedPad[1] = padHAfter + tensorOpPad[1];
+    foldedPad[2] = padWBefore + tensorOpPad[2];
+    foldedPad[3] = padWAfter + tensorOpPad[3];
+
+    // Check kernel-related restrictions.
+    if (!AdaptorTy::checkKernelCompliance(tensorOp, foldedPad)) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size not aligned with kernel restrictions.");
+    }
+
+    // Check padding-constant restrictions.
+    if (!AdaptorTy::checkPadConstCompliance(tensorOp, padOp.getPadConst())) {
+      return rewriter.notifyMatchFailure(
+          tensorOp,
+          "Padding constant is not aligned with operator zero-point.");
+    }
+
+    // For now, reject folded padding that exceeds the 8K level limit (8192).
+    if (llvm::any_of(foldedPad, [](int64_t padVal) { return padVal > 8192; })) {
+      return rewriter.notifyMatchFailure(
+          tensorOp, "Padding size more than the 8K level limit.");
+    }
+
+    // Create the replacement operator.
+    AdaptorTy::replaceOpWithNewPad(rewriter, tensorOp, padOp.getInput1(),
+                                   foldedPad);
+
+    return success();
+  }
+};
+} // namespace
+
+void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::AvgPool2dOp,
+                                PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>(
+      context);
+}
+
+void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<
+      FoldPadToTensorOp<tosa::Conv2DOp, ConvPadFoldAdaptor<tosa::Conv2DOp>>>(
+      context);
+}
+
+void DepthwiseConv2DOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                    MLIRContext *context) {
+  results.add<FoldPadToTensorOp<tosa::DepthwiseConv2DOp,
+                                ConvPadFoldAdaptor<tosa::DepthwiseConv2DOp>>>(
+      context);
+}
+
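+// max_pool2d is a no-op when the pooled spatial dimensions are already 1x1:
+// taking the maximum over a single element returns that element unchanged.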
+struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getInput();
+    Value output = op.getOutput();
+    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
+    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
+
+    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
+      return failure();
+    }
+
+    // This is a no-op if both the input and output spatial (H, W) dimensions
+    // are 1x1.
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    if (outputShape[1] != 1 || outputShape[2] != 1) {
+      return failure();
+    }
+
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    if (inputShape[1] != 1 || inputShape[2] != 1) {
+      return failure();
+    }
+
+    rewriter.replaceOp(op, input);
+    return success();
+  }
+};
+
+void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                              MLIRContext *context) {
+  results.add<MaxPool2dIsNoOp,
+              FoldPadToTensorOp<tosa::MaxPool2dOp,
+                                PoolPadFoldAdaptor<tosa::MaxPool2dOp>>>(
+      context);
+}
+
+//===----------------------------------------------------------------------===//
+// Data Layout / Memory Reinterpretation.
+//===----------------------------------------------------------------------===//
+
 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
 
@@ -175,41 +442,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
 }
 
-struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getInput();
-    Value output = op.getOutput();
-    ShapedType inputType = llvm::cast<ShapedType>(input.getType());
-    ShapedType outputType = llvm::cast<ShapedType>(output.getType());
-
-    if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
-      return failure();
-    }
-
-    // If the output and input shapes are 1x1, then this is a no op.
-    ArrayRef<int64_t> outputShape = outputType.getShape();
-    if (outputShape[1] != 1 || outputShape[2] != 1) {
-      return failure();
-    }
-
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-    if (inputShape[1] != 1 || inputShape[2] != 1) {
-      return failure();
-    }
-
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-
-void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<MaxPool2dIsNoOp>(context);
-}
-
 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   using OpRewritePattern::OpRewritePattern;
 