@@ -821,6 +821,153 @@ Log1pApproximation::matchAndRewrite(math::Log1pOp op,
821821 return success ();
822822}
823823
824+ // ----------------------------------------------------------------------------//
825+ // Asin approximation.
826+ // ----------------------------------------------------------------------------//
827+
828+ // Approximates asin(x).
829+ // This approximation is based on the following stackoverflow post:
830+ // https://stackoverflow.com/a/42683455
831+ namespace {
832+ struct AsinPolynomialApproximation : public OpRewritePattern <math::AsinOp> {
833+ public:
834+ using OpRewritePattern::OpRewritePattern;
835+
836+ LogicalResult matchAndRewrite (math::AsinOp op,
837+ PatternRewriter &rewriter) const final ;
838+ };
839+ } // namespace
840+ LogicalResult
841+ AsinPolynomialApproximation::matchAndRewrite (math::AsinOp op,
842+ PatternRewriter &rewriter) const {
843+ Value operand = op.getOperand ();
844+ Type elementType = getElementTypeOrSelf (operand);
845+
846+ if (!(elementType.isF32 () || elementType.isF16 ()))
847+ return rewriter.notifyMatchFailure (op,
848+ " only f32 and f16 type is supported." );
849+ VectorShape shape = vectorShape (operand);
850+
851+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
852+ auto bcast = [&](Value value) -> Value {
853+ return broadcast (builder, value, shape);
854+ };
855+
856+ auto fma = [&](Value a, Value b, Value c) -> Value {
857+ return builder.create <math::FmaOp>(a, b, c);
858+ };
859+
860+ auto mul = [&](Value a, Value b) -> Value {
861+ return builder.create <arith::MulFOp>(a, b);
862+ };
863+
864+ Value s = mul (operand, operand);
865+ Value q = mul (s, s);
866+ Value r = bcast (floatCst (builder, 5.5579749017470502e-2 , elementType));
867+ Value t = bcast (floatCst (builder, -6.2027913464120114e-2 , elementType));
868+
869+ r = fma (r, q, bcast (floatCst (builder, 5.4224464349245036e-2 , elementType)));
870+ t = fma (t, q, bcast (floatCst (builder, -1.1326992890324464e-2 , elementType)));
871+ r = fma (r, q, bcast (floatCst (builder, 1.5268872539397656e-2 , elementType)));
872+ t = fma (t, q, bcast (floatCst (builder, 1.0493798473372081e-2 , elementType)));
873+ r = fma (r, q, bcast (floatCst (builder, 1.4106045900607047e-2 , elementType)));
874+ t = fma (t, q, bcast (floatCst (builder, 1.7339776384962050e-2 , elementType)));
875+ r = fma (r, q, bcast (floatCst (builder, 2.2372961589651054e-2 , elementType)));
876+ t = fma (t, q, bcast (floatCst (builder, 3.0381912707941005e-2 , elementType)));
877+ r = fma (r, q, bcast (floatCst (builder, 4.4642857881094775e-2 , elementType)));
878+ t = fma (t, q, bcast (floatCst (builder, 7.4999999991367292e-2 , elementType)));
879+ r = fma (r, s, t);
880+ r = fma (r, s, bcast (floatCst (builder, 1.6666666666670193e-1 , elementType)));
881+ t = mul (operand, s);
882+ r = fma (r, t, operand);
883+
884+ rewriter.replaceOp (op, r);
885+ return success ();
886+ }
887+
888+ // ----------------------------------------------------------------------------//
889+ // Acos approximation.
890+ // ----------------------------------------------------------------------------//
891+
892+ // Approximates acos(x).
893+ // This approximation is based on the following stackoverflow post:
894+ // https://stackoverflow.com/a/42683455
895+ namespace {
896+ struct AcosPolynomialApproximation : public OpRewritePattern <math::AcosOp> {
897+ public:
898+ using OpRewritePattern::OpRewritePattern;
899+
900+ LogicalResult matchAndRewrite (math::AcosOp op,
901+ PatternRewriter &rewriter) const final ;
902+ };
903+ } // namespace
904+ LogicalResult
905+ AcosPolynomialApproximation::matchAndRewrite (math::AcosOp op,
906+ PatternRewriter &rewriter) const {
907+ Value operand = op.getOperand ();
908+ Type elementType = getElementTypeOrSelf (operand);
909+
910+ if (!(elementType.isF32 () || elementType.isF16 ()))
911+ return rewriter.notifyMatchFailure (op,
912+ " only f32 and f16 type is supported." );
913+ VectorShape shape = vectorShape (operand);
914+
915+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
916+ auto bcast = [&](Value value) -> Value {
917+ return broadcast (builder, value, shape);
918+ };
919+
920+ auto fma = [&](Value a, Value b, Value c) -> Value {
921+ return builder.create <math::FmaOp>(a, b, c);
922+ };
923+
924+ auto mul = [&](Value a, Value b) -> Value {
925+ return builder.create <arith::MulFOp>(a, b);
926+ };
927+
928+ Value negOperand = builder.create <arith::NegFOp>(operand);
929+ Value zero = bcast (floatCst (builder, 0.0 , elementType));
930+ Value half = bcast (floatCst (builder, 0.5 , elementType));
931+ Value negOne = bcast (floatCst (builder, -1.0 , elementType));
932+ Value selR =
933+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
934+ Value r = builder.create <arith::SelectOp>(selR, negOperand, operand);
935+ Value chkConst = bcast (floatCst (builder, -0.5625 , elementType));
936+ Value firstPred =
937+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
938+
939+ Value trueVal =
940+ fma (bcast (floatCst (builder, 9.3282184640716537e-1 , elementType)),
941+ bcast (floatCst (builder, 1.6839188885261840e+0 , elementType)),
942+ builder.create <math::AsinOp>(r));
943+
944+ Value falseVal = builder.create <math::SqrtOp>(fma (half, r, half));
945+ falseVal = builder.create <math::AsinOp>(falseVal);
946+ falseVal = mul (bcast (floatCst (builder, 2.0 , elementType)), falseVal);
947+
948+ r = builder.create <arith::SelectOp>(firstPred, trueVal, falseVal);
949+
950+ // Check whether the operand lies in between [-1.0, 0.0).
951+ Value greaterThanNegOne =
952+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
953+
954+ Value lessThanZero =
955+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
956+
957+ Value betweenNegOneZero =
958+ builder.create <arith::AndIOp>(greaterThanNegOne, lessThanZero);
959+
960+ trueVal = fma (bcast (floatCst (builder, 1.8656436928143307e+0 , elementType)),
961+ bcast (floatCst (builder, 1.6839188885261840e+0 , elementType)),
962+ builder.create <arith::NegFOp>(r));
963+
964+ Value finalVal =
965+ builder.create <arith::SelectOp>(betweenNegOneZero, trueVal, r);
966+
967+ rewriter.replaceOp (op, finalVal);
968+ return success ();
969+ }
970+
824971// ----------------------------------------------------------------------------//
825972// Erf approximation.
826973// ----------------------------------------------------------------------------//
@@ -1505,12 +1652,13 @@ void mlir::populateMathPolynomialApproximationPatterns(
15051652 ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
15061653 patterns.getContext ());
15071654
1508- patterns.add <AtanApproximation, Atan2Approximation, TanhApproximation,
1509- LogApproximation, Log2Approximation, Log1pApproximation,
1510- ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1511- CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1512- SinAndCosApproximation<false , math::CosOp>>(
1513- patterns.getContext ());
1655+ patterns
1656+ .add <AtanApproximation, Atan2Approximation, TanhApproximation,
1657+ LogApproximation, Log2Approximation, Log1pApproximation,
1658+ ErfPolynomialApproximation, AsinPolynomialApproximation,
1659+ AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1660+ CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1661+ SinAndCosApproximation<false , math::CosOp>>(patterns.getContext ());
15141662 if (options.enableAvx2 ) {
15151663 patterns.add <RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
15161664 patterns.getContext ());
0 commit comments