@@ -164,92 +164,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
164164 results.add <NoOpOptimization>(context);
165165}
166166
167- struct AddZeroOptimization : public OpRewritePattern <tosa::AddOp> {
168- using OpRewritePattern::OpRewritePattern;
169-
170- LogicalResult matchAndRewrite (tosa::AddOp op,
171- PatternRewriter &rewriter) const override {
172- auto input1 = op.getInput1 ();
173- auto input2 = op.getInput2 ();
174-
175- DenseElementsAttr input1Attr;
176- if (matchPattern (input1, m_Constant (&input1Attr)) && input1Attr.isSplat () &&
177- input2.getType () == op.getType ()) {
178- if (input1Attr.getType ().getElementType ().isa <IntegerType>() &&
179- input1Attr.getSplatValue <APInt>().isZero ()) {
180- rewriter.replaceOp (op, op.getInput2 ());
181- return success ();
182- }
183- }
184-
185- DenseElementsAttr input2Attr;
186- if (matchPattern (input2, m_Constant (&input2Attr)) && input2Attr.isSplat () &&
187- input1.getType () == op.getType ()) {
188- if (input2Attr.getType ().getElementType ().isa <IntegerType>() &&
189- input2Attr.getSplatValue <APInt>().isZero ()) {
190- rewriter.replaceOp (op, op.getInput1 ());
191- return success ();
192- }
193- }
194-
195- return failure ();
196- }
197- };
198-
199- void AddOp::getCanonicalizationPatterns (RewritePatternSet &results,
200- MLIRContext *context) {
201- results.add <AddZeroOptimization>(context);
202- }
203-
204- struct MulOneOptimization : public OpRewritePattern <tosa::MulOp> {
205- using OpRewritePattern::OpRewritePattern;
206-
207- LogicalResult matchAndRewrite (tosa::MulOp op,
208- PatternRewriter &rewriter) const override {
209- auto input1 = op.getInput1 ();
210- auto input2 = op.getInput2 ();
211-
212- DenseElementsAttr input1Attr;
213- if (matchPattern (input1, m_Constant (&input1Attr)) && input1Attr.isSplat () &&
214- input2.getType () == op.getType ()) {
215- if (input1Attr.getType ().getElementType ().isa <FloatType>() &&
216- input1Attr.getSplatValue <APFloat>().isExactlyValue (1 )) {
217- rewriter.replaceOp (op, op.getInput2 ());
218- return success ();
219- }
220-
221- if (input1Attr.getType ().getElementType ().isa <IntegerType>() &&
222- matchPattern (input1, m_One ())) {
223- rewriter.replaceOp (op, op.getInput2 ());
224- return success ();
225- }
226- }
227-
228- DenseElementsAttr input2Attr;
229- if (matchPattern (input2, m_Constant (&input2Attr)) && input2Attr.isSplat () &&
230- input1.getType () == op.getType ()) {
231- if (input2Attr.getType ().getElementType ().isa <FloatType>() &&
232- input2Attr.getSplatValue <APFloat>().isExactlyValue (1 )) {
233- rewriter.replaceOp (op, op.getInput1 ());
234- return success ();
235- }
236-
237- if (input2Attr.getType ().getElementType ().isa <IntegerType>() &&
238- matchPattern (input2, m_One ())) {
239- rewriter.replaceOp (op, op.getInput1 ());
240- return success ();
241- }
242- }
243-
244- return failure ();
245- }
246- };
247-
248- void MulOp::getCanonicalizationPatterns (RewritePatternSet &results,
249- MLIRContext *context) {
250- results.add <MulOneOptimization>(context);
251- }
252-
253167struct MaterializePadValue : public OpRewritePattern <tosa::PadOp> {
254168 using OpRewritePattern::OpRewritePattern;
255169
@@ -468,64 +382,47 @@ DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
468382 return {};
469383}
470384
385+ static bool isSplatZero (Type elemType, DenseElementsAttr val) {
386+ if (elemType.isa <FloatType>())
387+ return val && val.isSplat () && val.getSplatValue <APFloat>().isZero ();
388+ if (elemType.isa <IntegerType>())
389+ return val && val.isSplat () && val.getSplatValue <APInt>().isZero ();
390+ return false ;
391+ }
392+
393+ static bool isSplatOne (Type elemType, DenseElementsAttr val, int64_t shift) {
394+ if (elemType.isa <FloatType>())
395+ return val && val.isSplat () &&
396+ val.getSplatValue <APFloat>().isExactlyValue (1.0 );
397+ if (elemType.isa <IntegerType>()) {
398+ const int64_t shifted = 1LL << shift;
399+ return val && val.isSplat () &&
400+ val.getSplatValue <APInt>().getSExtValue () == shifted;
401+ }
402+ return false ;
403+ }
404+
471405OpFoldResult AddOp::fold (ArrayRef<Attribute> operands) {
472406 auto lhsTy = getInput1 ().getType ().dyn_cast <RankedTensorType>();
473407 auto rhsTy = getInput2 ().getType ().dyn_cast <RankedTensorType>();
474408 auto resultTy = getType ().dyn_cast <RankedTensorType>();
475409 if (!lhsTy || !rhsTy || !resultTy)
476410 return {};
477-
411+
478412 auto resultETy = resultTy.getElementType ();
479413 auto lhsAttr = operands[0 ].dyn_cast_or_null <DenseElementsAttr>();
480414 auto rhsAttr = operands[1 ].dyn_cast_or_null <DenseElementsAttr>();
481415
482- if (lhsTy == resultTy) {
483- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <FloatType>()) {
484- if (rhsAttr.getSplatValue <APFloat>().isZero ())
485- return getInput1 ();
486- }
487- }
488-
489- if (lhsTy != rhsTy) {
490- if (lhsAttr && rhsAttr) {
491- if (lhsTy == resultTy && rhsAttr.isSplat ()) {
492- APFloat r = rhsAttr.getSplatValue <APFloat>();
493- std::vector<APFloat> v;
494- v.resize (lhsAttr.size (), APFloat (0.0 ));
495- for (int i=0 ;i<lhsAttr.size (); ++i) {
496- v[i] = lhsAttr.getValues <APFloat>()[i] + r;
497- }
498- return DenseElementsAttr::get (resultTy, v);
499- }
500- }
501- }
502-
503-
504- if (lhsAttr && lhsAttr.isSplat () && resultETy.isa <FloatType>()) {
505- if (lhsAttr.getSplatValue <APFloat>().isZero ())
506- return getInput2 ();
507- }
508-
509- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <FloatType>()) {
510- if (rhsAttr.getSplatValue <APFloat>().isZero ())
511- return getInput1 ();
512- }
513-
514- if (lhsAttr && lhsAttr.isSplat () && resultETy.isa <IntegerType>()) {
515- if (lhsAttr.getSplatValue <APInt>().isZero ())
516- return getInput2 ();
517- }
518-
519- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <IntegerType>()) {
520- if (rhsAttr.getSplatValue <APInt>().isZero ())
521- return getInput1 ();
522- }
416+ if (lhsTy == resultTy && isSplatZero (resultETy, rhsAttr))
417+ return getInput1 ();
418+ if (rhsTy == resultTy && isSplatZero (resultETy, lhsAttr))
419+ return getInput2 ();
523420
524421 if (!lhsAttr || !rhsAttr)
525422 return {};
526423
527424 return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
528- lhsTy );
425+ resultTy );
529426}
530427
531428OpFoldResult DivOp::fold (ArrayRef<Attribute> operands) {
@@ -603,50 +500,26 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
603500 auto resultTy = getType ().dyn_cast <RankedTensorType>();
604501 if (!lhsTy || !rhsTy || !resultTy)
605502 return {};
606- if (lhsTy != rhsTy)
607- return {};
608503
609504 auto resultETy = resultTy.getElementType ();
610505 auto lhsAttr = operands[0 ].dyn_cast_or_null <DenseElementsAttr>();
611506 auto rhsAttr = operands[1 ].dyn_cast_or_null <DenseElementsAttr>();
612507
613- if (lhsAttr && lhsAttr. isSplat () && resultETy.isa <FloatType >()) {
614- auto val = lhsAttr. getSplatValue <APFloat>();
615- if (val. isZero ( ))
508+ const int64_t shift = resultETy.isa <IntegerType >() ? getShift () : 0 ;
509+ if (rhsTy == resultTy) {
510+ if (isSplatZero (resultETy, lhsAttr ))
616511 return lhsAttr;
617- if (val. isExactlyValue ( 1.0 ))
512+ if (isSplatOne (resultETy, lhsAttr, shift ))
618513 return rhs;
619514 }
620-
621- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <FloatType>()) {
622- auto val = rhsAttr.getSplatValue <APFloat>();
623- if (val.isZero ())
624- return rhsAttr;
625- if (val.isExactlyValue (1.0 ))
626- return lhs;
627- }
628-
629- if (lhsAttr && lhsAttr.isSplat () && resultETy.isa <IntegerType>()) {
630- auto val = lhsAttr.getSplatValue <APInt>();
631- if (val.isZero ())
632- return lhsAttr;
633- const int64_t shift = getShift ();
634- const int64_t shifted = 1LL << shift;
635- if (val.getSExtValue () == shifted)
636- return rhs;
637- }
638-
639- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <IntegerType>()) {
640- auto val = rhsAttr.getSplatValue <APInt>();
641- const int64_t shift = getShift ();
642- const int64_t shifted = 1LL << shift;
643- if (val.isZero ())
515+ if (lhsTy == resultTy) {
516+ if (isSplatZero (resultETy, rhsAttr))
644517 return rhsAttr;
645- if (val. getSExtValue () == shifted )
518+ if (isSplatOne (resultETy, rhsAttr, shift) )
646519 return lhs;
647520 }
648521
649- return mulBinaryFolder (lhsAttr, rhsAttr, lhsTy , getShift ());
522+ return mulBinaryFolder (lhsAttr, rhsAttr, resultTy , getShift ());
650523}
651524
652525OpFoldResult SubOp::fold (ArrayRef<Attribute> operands) {
@@ -655,28 +528,18 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
655528 auto resultTy = getType ().dyn_cast <RankedTensorType>();
656529 if (!lhsTy || !rhsTy || !resultTy)
657530 return {};
658- if (lhsTy != rhsTy)
659- return {};
660531
661532 auto resultETy = resultTy.getElementType ();
662533 auto lhsAttr = operands[0 ].dyn_cast_or_null <DenseElementsAttr>();
663534 auto rhsAttr = operands[1 ].dyn_cast_or_null <DenseElementsAttr>();
664-
665- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <FloatType>()) {
666- if (rhsAttr.getSplatValue <APFloat>().isZero ())
667- return getInput1 ();
668- }
669-
670- if (rhsAttr && rhsAttr.isSplat () && resultETy.isa <IntegerType>()) {
671- if (rhsAttr.getSplatValue <APInt>().isZero ())
672- return getInput1 ();
673- }
535+ if (lhsTy == resultTy && isSplatZero (resultETy, rhsAttr))
536+ return getInput1 ();
674537
675538 if (!lhsAttr || !rhsAttr)
676539 return {};
677540
678541 return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
679- lhsTy );
542+ resultTy );
680543}
681544
682545namespace {
@@ -917,7 +780,7 @@ OpFoldResult RsqrtOp::fold(FoldAdaptor adaptor) {
917780 auto operand = adaptor.getInput1().dyn_cast_or_null<ElementsAttr>();
918781 if (!operand)
919782 return {};
920-
783+
921784 if (!inputTy.getElementType().isF32())
922785 return {};
923786
@@ -947,7 +810,7 @@ OpFoldResult PowOp::fold(FoldAdaptor adaptor) {
947810 auto operand2 = adaptor.getInput2().dyn_cast_or_null<ElementsAttr>();
948811 if (!operand2)
949812 return {};
950-
813+
951814 if (!operand1.getElementType().isF32() || !operand2.getElementType().isF32())
952815 return {};
953816
@@ -961,7 +824,7 @@ OpFoldResult PowOp::fold(FoldAdaptor adaptor) {
961824
962825OpFoldResult ReciprocalOp::fold(FoldAdaptor adaptor) {
963826 auto src = adaptor.getInput1().dyn_cast_or_null<mlir::DenseElementsAttr>();
964-
827+
965828 if (!src)
966829 return nullptr;
967830
@@ -989,7 +852,6 @@ OpFoldResult ReverseOp::fold(ArrayRef<Attribute> operands) {
989852 return {};
990853}
991854
992-
993855OpFoldResult SliceOp::fold (ArrayRef<Attribute> operands) {
994856 auto inputTy = getInput ().getType ().dyn_cast <RankedTensorType>();
995857 auto outputTy = getType ().dyn_cast <RankedTensorType>();
0 commit comments