@@ -692,9 +692,11 @@ const Type *AndINode::mul_ring( const Type *t0, const Type *t1 ) const {
692692 return and_value<TypeInt>(r0, r1);
693693}
694694
// Forward declaration (defined later in this file): tests whether 'expr' contributes
// nothing under 'mask', i.e. (expr + addend) & mask == addend & mask for any addend.
static bool AndIL_is_zero_element_under_mask(const PhaseGVN* phase, const Node* expr, const Node* mask, BasicType bt);
696+
695697const Type* AndINode::Value (PhaseGVN* phase) const {
696- // patterns similar to (v << 2) & 3
697- if ( AndIL_shift_and_mask_is_always_zero ( phase, in (1 ), in (2 ), T_INT, true )) {
698+ if ( AndIL_is_zero_element_under_mask (phase, in ( 1 ), in ( 2 ), T_INT) ||
699+ AndIL_is_zero_element_under_mask ( phase, in (2 ), in (1 ), T_INT)) {
698700 return TypeInt::ZERO;
699701 }
700702
@@ -740,8 +742,8 @@ Node* AndINode::Identity(PhaseGVN* phase) {
740742
741743// ------------------------------Ideal------------------------------------------
742744Node *AndINode::Ideal (PhaseGVN *phase, bool can_reshape) {
743- // pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
744- Node* progress = AndIL_add_shift_and_mask (phase, T_INT);
745+ // Simplify (v1 + v2) & mask to v1 & mask or v2 & mask when possible.
746+ Node* progress = AndIL_sum_and_mask (phase, T_INT);
745747 if (progress != nullptr ) {
746748 return progress;
747749 }
@@ -824,8 +826,8 @@ const Type *AndLNode::mul_ring( const Type *t0, const Type *t1 ) const {
824826}
825827
826828const Type* AndLNode::Value (PhaseGVN* phase) const {
827- // patterns similar to (v << 2) & 3
828- if ( AndIL_shift_and_mask_is_always_zero ( phase, in (1 ), in (2 ), T_LONG, true )) {
829+ if ( AndIL_is_zero_element_under_mask (phase, in ( 1 ), in ( 2 ), T_LONG) ||
830+ AndIL_is_zero_element_under_mask ( phase, in (2 ), in (1 ), T_LONG)) {
829831 return TypeLong::ZERO;
830832 }
831833
@@ -872,8 +874,8 @@ Node* AndLNode::Identity(PhaseGVN* phase) {
872874
873875// ------------------------------Ideal------------------------------------------
874876Node *AndLNode::Ideal (PhaseGVN *phase, bool can_reshape) {
875- // pattern similar to (v1 + (v2 << 2)) & 3 transformed to v1 & 3
876- Node* progress = AndIL_add_shift_and_mask (phase, T_LONG);
877+ // Simplify (v1 + v2) & mask to v1 & mask or v2 & mask when possible.
878+ Node* progress = AndIL_sum_and_mask (phase, T_LONG);
877879 if (progress != nullptr ) {
878880 return progress;
879881 }
@@ -2096,99 +2098,109 @@ const Type* RotateRightNode::Value(PhaseGVN* phase) const {
20962098 }
20972099}
20982100
// ------------------------------ Sum & Mask ------------------------------

// Returns a lower bound on the number of trailing zeros in expr.
// Handles constants (including zero), left shifts by a constant amount,
// and (for T_LONG) a shift wrapped in a ConvI2L. For anything else we
// conservatively return 0 (i.e. "no trailing zeros guaranteed").
static jint AndIL_min_trailing_zeros(const PhaseGVN* phase, const Node* expr, BasicType bt) {
  expr = expr->uncast();
  const TypeInteger* type = phase->type(expr)->isa_integer(bt);
  if (type == nullptr) {
    return 0;
  }

  if (type->is_con()) {
    jlong con = type->get_con_as_long(bt);
    // A constant zero has all bits clear, so every bit position counts as a trailing zero.
    return con == 0L ? (type2aelembytes(bt) * BitsPerByte) : count_trailing_zeros(con);
  }

  // Look through ConvI2L: the int input's trailing zeros survive the widening.
  if (expr->Opcode() == Op_ConvI2L) {
    expr = expr->in(1)->uncast();
    bt = T_INT;
    type = phase->type(expr)->isa_int();
  }

  // Pattern: expr = (x << shift)
  if (expr->Opcode() == Op_LShift(bt)) {
    const TypeInt* shift_t = phase->type(expr->in(2))->isa_int();
    if (shift_t == nullptr || !shift_t->is_con()) {
      return 0;
    }
    // We need to truncate the shift, as it may not have been canonicalized yet.
    // T_INT:  0..31 -> shift_mask = 4 * 8 - 1 = 31
    // T_LONG: 0..63 -> shift_mask = 8 * 8 - 1 = 63
    // (JLS: "Shift Operators")
    jint shift_mask = type2aelembytes(bt) * BitsPerByte - 1;
    return shift_t->get_con() & shift_mask;
  }

  return 0;
}

// Checks whether expr is a neutral additive element (zero) under mask,
// i.e. whether an expression of the form:
//   (AndX (AddX expr addend) mask)
//   (expr + addend) & mask
// is equivalent to
//   (AndX addend mask)
//   addend & mask
// for any addend.
// (The X in AndX must be I or L, depending on bt).
//
// We check for the sufficient condition when the lowest set bit in expr is higher than
// the highest set bit in mask, i.e.:
//   expr: eeeeee0000000000000
//   mask: 000000mmmmmmmmmmmmm
//               <--w bits--->
// We do not test for other cases.
//
// Correctness:
//   Given "expr" with at least "w" trailing zeros,
//   let "mod = 2^w", "suffix_mask = mod - 1"
//
//   Since "mask" only has bits set where "suffix_mask" does, we have:
//     mask = suffix_mask & mask                                     (SUFFIX_MASK)
//
//   And since expr only has bits set above w, and suffix_mask only below:
//     expr & suffix_mask == 0                                       (NO_BIT_OVERLAP)
//
//   From unsigned modular arithmetic (with unsigned modulo %), and since mod is
//   a power of 2, and we are computing in a ring of powers of 2, we know that
//     (x + y) % mod         = (x % mod         + y) % mod
//     (x + y) & suffix_mask = (x & suffix_mask + y) & suffix_mask   (MOD_ARITH)
//
//   We can now prove the equality:
//     (expr                + addend) & mask
//   = (expr                + addend) & suffix_mask & mask           (SUFFIX_MASK)
//   = (expr & suffix_mask  + addend) & suffix_mask & mask           (MOD_ARITH)
//   = (0                   + addend) & suffix_mask & mask           (NO_BIT_OVERLAP)
//   =                        addend                & mask           (SUFFIX_MASK)
//
// Hence, an expr with at least w trailing zeros is a neutral additive element under any mask with bit width w.
static bool AndIL_is_zero_element_under_mask(const PhaseGVN* phase, const Node* expr, const Node* mask, BasicType bt) {
  // When the mask is negative, it has the most significant bit set.
  const TypeInteger* mask_t = phase->type(mask)->isa_integer(bt);
  if (mask_t == nullptr || mask_t->lo_as_long() < 0) {
    return false;
  }

  // When the mask is constant zero, we defer to MulNode::Value to eliminate the entire AndX operation.
  if (mask_t->hi_as_long() == 0) {
    assert(mask_t->lo_as_long() == 0, "checked earlier");
    return false;
  }

  // Bit width w of the mask's largest possible value; expr must have at
  // least that many trailing zeros for the argument above to apply.
  jint mask_bit_width = BitsPerLong - count_leading_zeros(mask_t->hi_as_long());
  jint expr_trailing_zeros = AndIL_min_trailing_zeros(phase, expr, bt);
  return expr_trailing_zeros >= mask_bit_width;
}
21712196
2172- // Given an expression (AndX (AddX v1 (LShiftX v2 #N)) #M)
2173- // determine if the AndX must always produce (AndX v1 #M),
2174- // because the shift (v2<<N) is bitwise disjoint from the mask #M.
2175- // The X in AndX will be I or L, depending on bt.
2176- // Specifically, the following cases fold,
2177- // when the shift value N is large enough to zero out
2178- // all the set positions of the and-mask M.
2179- // (AndI (AddI v1 (LShiftI _ #N)) #M) => (AndI v1 #M)
2180- // (AndL (AddI v1 (LShiftL _ #N)) #M) => (AndL v1 #M)
2181- // (AndL (AddL v1 (ConvI2L (LShiftI _ #N))) #M) => (AndL v1 #M)
2182- // The M and N values must satisfy ((-1 << N) & M) == 0.
2183- // Because the optimization might work for a non-constant
2184- // mask M, and because the AddX operands can come in either
2185- // order, we check for every operand order.
2186- Node* MulNode::AndIL_add_shift_and_mask (PhaseGVN* phase, BasicType bt) {
2197+ // Reduces the pattern:
2198+ // (AndX (AddX add1 add2) mask)
2199+ // to
2200+ // (AndX add1 mask), if add2 is neutral wrt mask (see above), and vice versa.
2201+ Node* MulNode::AndIL_sum_and_mask (PhaseGVN* phase, BasicType bt) {
21872202 Node* add = in (1 );
21882203 Node* mask = in (2 );
2189- if (add == nullptr || mask == nullptr ) {
2190- return nullptr ;
2191- }
21922204 int addidx = 0 ;
21932205 if (add->Opcode () == Op_Add (bt)) {
21942206 addidx = 1 ;
@@ -2200,14 +2212,12 @@ Node* MulNode::AndIL_add_shift_and_mask(PhaseGVN* phase, BasicType bt) {
22002212 if (addidx > 0 ) {
22012213 Node* add1 = add->in (1 );
22022214 Node* add2 = add->in (2 );
2203- if (add1 != nullptr && add2 != nullptr ) {
2204- if (AndIL_shift_and_mask_is_always_zero (phase, add1, mask, bt, false )) {
2205- set_req_X (addidx, add2, phase);
2206- return this ;
2207- } else if (AndIL_shift_and_mask_is_always_zero (phase, add2, mask, bt, false )) {
2208- set_req_X (addidx, add1, phase);
2209- return this ;
2210- }
2215+ if (AndIL_is_zero_element_under_mask (phase, add1, mask, bt)) {
2216+ set_req_X (addidx, add2, phase);
2217+ return this ;
2218+ } else if (AndIL_is_zero_element_under_mask (phase, add2, mask, bt)) {
2219+ set_req_X (addidx, add1, phase);
2220+ return this ;
22112221 }
22122222 }
22132223 return nullptr ;
0 commit comments