diff --git a/src/hotspot/share/opto/countbitsnode.cpp b/src/hotspot/share/opto/countbitsnode.cpp index aac874e94b1fa..1601b33ea2b6c 100644 --- a/src/hotspot/share/opto/countbitsnode.cpp +++ b/src/hotspot/share/opto/countbitsnode.cpp @@ -26,97 +26,114 @@ #include "opto/opcodes.hpp" #include "opto/phaseX.hpp" #include "opto/type.hpp" +#include "utilities/count_leading_zeros.hpp" +#include "utilities/count_trailing_zeros.hpp" #include "utilities/population_count.hpp" +static int count_leading_zeros_int(jint i) { + return i == 0 ? BitsPerInt : count_leading_zeros(i); +} + +static int count_leading_zeros_long(jlong l) { + return l == 0 ? BitsPerLong : count_leading_zeros(l); +} + +static int count_trailing_zeros_int(jint i) { + return i == 0 ? BitsPerInt : count_trailing_zeros(i); +} + +static int count_trailing_zeros_long(jlong l) { + return l == 0 ? BitsPerLong : count_trailing_zeros(l); +} + //------------------------------Value------------------------------------------ const Type* CountLeadingZerosINode::Value(PhaseGVN* phase) const { const Type* t = phase->type(in(1)); - if (t == Type::TOP) return Type::TOP; - const TypeInt* ti = t->isa_int(); - if (ti && ti->is_con()) { - jint i = ti->get_con(); - // HD, Figure 5-6 - if (i == 0) - return TypeInt::make(BitsPerInt); - int n = 1; - unsigned int x = i; - if (x >> 16 == 0) { n += 16; x <<= 16; } - if (x >> 24 == 0) { n += 8; x <<= 8; } - if (x >> 28 == 0) { n += 4; x <<= 4; } - if (x >> 30 == 0) { n += 2; x <<= 2; } - n -= x >> 31; - return TypeInt::make(n); + if (t == Type::TOP) { + return Type::TOP; } - return TypeInt::INT; + + // To minimize `count_leading_zeros(x)`, we should make the highest 1 bit in x + // as far to the left as possible. A bit in x can be 1 iff this bit is not + // forced to be 0, i.e. the corresponding bit in `x._bits._zeros` is 0. Thus: + // min(clz(x)) = number of bits to the left of the highest 0 bit in x._bits._zeros + // = count_leading_ones(x._bits._zeros) = clz(~x._bits._zeros) + // + // To maximize `count_leading_zeros(x)`, we should make the leading zeros as + // many as possible. A bit in x can be 0 iff this bit is not forced to be 1, + // i.e. the corresponding bit in `x._bits._ones` is 0. Thus: + // max(clz(x)) = clz(x._bits._ones) + // + // Therefore, the range of `count_leading_zeros(x)` is: + // [clz(~x._bits._zeros), clz(x._bits._ones)] + // + // A more detailed proof using Z3 can be found at: + // https://github.com/openjdk/jdk/pull/25928#discussion_r2256750507 + const TypeInt* ti = t->is_int(); + return TypeInt::make(count_leading_zeros_int(~ti->_bits._zeros), + count_leading_zeros_int(ti->_bits._ones), + ti->_widen); } //------------------------------Value------------------------------------------ const Type* CountLeadingZerosLNode::Value(PhaseGVN* phase) const { const Type* t = phase->type(in(1)); - if (t == Type::TOP) return Type::TOP; - const TypeLong* tl = t->isa_long(); - if (tl && tl->is_con()) { - jlong l = tl->get_con(); - // HD, Figure 5-6 - if (l == 0) - return TypeInt::make(BitsPerLong); - int n = 1; - unsigned int x = (((julong) l) >> 32); - if (x == 0) { n += 32; x = (int) l; } - if (x >> 16 == 0) { n += 16; x <<= 16; } - if (x >> 24 == 0) { n += 8; x <<= 8; } - if (x >> 28 == 0) { n += 4; x <<= 4; } - if (x >> 30 == 0) { n += 2; x <<= 2; } - n -= x >> 31; - return TypeInt::make(n); + if (t == Type::TOP) { + return Type::TOP; } - return TypeInt::INT; + + // The proof of correctness is same as the above comments + // in `CountLeadingZerosINode::Value`. + const TypeLong* tl = t->is_long(); + return TypeInt::make(count_leading_zeros_long(~tl->_bits._zeros), + count_leading_zeros_long(tl->_bits._ones), + tl->_widen); } //------------------------------Value------------------------------------------ const Type* CountTrailingZerosINode::Value(PhaseGVN* phase) const { const Type* t = phase->type(in(1)); - if (t == Type::TOP) return Type::TOP; - const TypeInt* ti = t->isa_int(); - if (ti && ti->is_con()) { - jint i = ti->get_con(); - // HD, Figure 5-14 - int y; - if (i == 0) - return TypeInt::make(BitsPerInt); - int n = 31; - y = i << 16; if (y != 0) { n = n - 16; i = y; } - y = i << 8; if (y != 0) { n = n - 8; i = y; } - y = i << 4; if (y != 0) { n = n - 4; i = y; } - y = i << 2; if (y != 0) { n = n - 2; i = y; } - y = i << 1; if (y != 0) { n = n - 1; } - return TypeInt::make(n); + if (t == Type::TOP) { + return Type::TOP; } - return TypeInt::INT; + + // To minimize `count_trailing_zeros(x)`, we should make the lowest 1 bit in x + // as far to the right as possible. A bit in x can be 1 iff this bit is not + // forced to be 0, i.e. the corresponding bit in `x._bits._zeros` is 0. Thus: + // min(ctz(x)) = number of bits to the right of the lowest 0 bit in x._bits._zeros + // = count_trailing_ones(x._bits._zeros) = ctz(~x._bits._zeros) + // + // To maximize `count_trailing_zeros(x)`, we should make the trailing zeros as + // many as possible. A bit in x can be 0 iff this bit is not forced to be 1, + // i.e. the corresponding bit in `x._bits._ones` is 0. Thus: + // max(ctz(x)) = ctz(x._bits._ones) + // + // Therefore, the range of `count_trailing_zeros(x)` is: + // [ctz(~x._bits._zeros), ctz(x._bits._ones)] + // + // A more detailed proof using Z3 can be found at: + // https://github.com/openjdk/jdk/pull/25928#discussion_r2256750507 + const TypeInt* ti = t->is_int(); + return TypeInt::make(count_trailing_zeros_int(~ti->_bits._zeros), + count_trailing_zeros_int(ti->_bits._ones), + ti->_widen); } //------------------------------Value------------------------------------------ const Type* CountTrailingZerosLNode::Value(PhaseGVN* phase) const { const Type* t = phase->type(in(1)); - if (t == Type::TOP) return Type::TOP; - const TypeLong* tl = t->isa_long(); - if (tl && tl->is_con()) { - jlong l = tl->get_con(); - // HD, Figure 5-14 - int x, y; - if (l == 0) - return TypeInt::make(BitsPerLong); - int n = 63; - y = (int) l; if (y != 0) { n = n - 32; x = y; } else x = (((julong) l) >> 32); - y = x << 16; if (y != 0) { n = n - 16; x = y; } - y = x << 8; if (y != 0) { n = n - 8; x = y; } - y = x << 4; if (y != 0) { n = n - 4; x = y; } - y = x << 2; if (y != 0) { n = n - 2; x = y; } - y = x << 1; if (y != 0) { n = n - 1; } - return TypeInt::make(n); + if (t == Type::TOP) { + return Type::TOP; } - return TypeInt::INT; + + // The proof of correctness is same as the above comments + // in `CountTrailingZerosINode::Value`. + const TypeLong* tl = t->is_long(); + return TypeInt::make(count_trailing_zeros_long(~tl->_bits._zeros), + count_trailing_zeros_long(tl->_bits._ones), + tl->_widen); } + // We use the KnownBits information from the integer types to derive how many one bits // we have at least and at most. // From the definition of KnownBits, we know: diff --git a/test/hotspot/jtreg/compiler/c2/gvn/TestCountBitsRange.java b/test/hotspot/jtreg/compiler/c2/gvn/TestCountBitsRange.java new file mode 100644 index 0000000000000..00aa466e82262 --- /dev/null +++ b/test/hotspot/jtreg/compiler/c2/gvn/TestCountBitsRange.java @@ -0,0 +1,570 @@ +/* + * Copyright (c) 2025 Alibaba Group Holding Limited. All Rights Reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package compiler.c2.gvn; + +import compiler.lib.generators.Generator; +import compiler.lib.generators.Generators; +import compiler.lib.generators.RestrictableGenerator; +import compiler.lib.ir_framework.*; +import java.util.function.Function; +import jdk.test.lib.Asserts; + +/* + * @test + * @bug 8360192 + * @summary Tests that count bits nodes are handled correctly. + * @library /test/lib / + * @run driver compiler.c2.gvn.TestCountBitsRange + */ +public class TestCountBitsRange { + private static final Generator INTS = Generators.G.ints(); + private static final Generator LONGS = Generators.G.longs(); + + private static final RestrictableGenerator INTS_32 = Generators.G.ints().restricted(0, 32); + private static final RestrictableGenerator INTS_64 = Generators.G.ints().restricted(0, 64); + + private static final int LIMITS_32_0 = INTS_32.next(); + private static final int LIMITS_32_1 = INTS_32.next(); + private static final int LIMITS_32_2 = INTS_32.next(); + private static final int LIMITS_32_3 = INTS_32.next(); + private static final int LIMITS_32_4 = INTS_32.next(); + private static final int LIMITS_32_5 = INTS_32.next(); + private static final int LIMITS_32_6 = INTS_32.next(); + private static final int LIMITS_32_7 = INTS_32.next(); + + private static final int LIMITS_64_0 = INTS_64.next(); + private static final int LIMITS_64_1 = INTS_64.next(); + private static final int LIMITS_64_2 = INTS_64.next(); + private static final int LIMITS_64_3 = INTS_64.next(); + private static final int LIMITS_64_4 = INTS_64.next(); + private static final int LIMITS_64_5 = INTS_64.next(); + private static final int LIMITS_64_6 = INTS_64.next(); + private static final int LIMITS_64_7 = INTS_64.next(); + + private static final IntRange RANGE_INT = IntRange.generate(INTS); + private static final LongRange RANGE_LONG = LongRange.generate(LONGS); + + public static void main(String[] args) { + TestFramework.run(); + } + + @Run(test = { + "clzConstInts", "clzCompareInt", "clzDiv8Int", "clzRandLimitInt", + "clzConstLongs", "clzCompareLong", "clzDiv8Long", "clzRandLimitLong", + "ctzConstInts", "ctzCompareInt", "ctzDiv8Int", "ctzRandLimitInt", + "ctzConstLongs", "ctzCompareLong", "ctzDiv8Long", "ctzRandLimitLong", + }) + public void runTest() { + int randInt = INTS.next(); + long randLong = LONGS.next(); + assertResult(randInt, randLong); + } + + @DontCompile + public void assertResult(int randInt, long randLong) { + checkConstResults(clzConstInts(), x -> Integer.numberOfLeadingZeros(x.intValue())); + Asserts.assertEQ(Integer.numberOfLeadingZeros(randInt) < 0 + || Integer.numberOfLeadingZeros(randInt) > 32, + clzCompareInt(randInt)); + Asserts.assertEQ(Integer.numberOfLeadingZeros(randInt) / 8, + clzDiv8Int(randInt)); + Asserts.assertEQ(clzRandLimitInterpretedInt(randInt), clzRandLimitInt(randInt)); + + checkConstResults(clzConstLongs(), x -> Long.numberOfLeadingZeros(x.longValue())); + Asserts.assertEQ(Long.numberOfLeadingZeros(randLong) < 0 + || Long.numberOfLeadingZeros(randLong) > 64, + clzCompareLong(randLong)); + Asserts.assertEQ(Long.numberOfLeadingZeros(randLong) / 8, + clzDiv8Long(randLong)); + Asserts.assertEQ(clzRandLimitInterpretedLong(randLong), clzRandLimitLong(randLong)); + + checkConstResults(ctzConstInts(), x -> Integer.numberOfTrailingZeros(x.intValue())); + Asserts.assertEQ(Integer.numberOfTrailingZeros(randInt) < 0 + || Integer.numberOfTrailingZeros(randInt) > 32, + ctzCompareInt(randInt)); + Asserts.assertEQ(Integer.numberOfTrailingZeros(randInt) / 8, + ctzDiv8Int(randInt)); + Asserts.assertEQ(ctzRandLimitInterpretedInt(randInt), ctzRandLimitInt(randInt)); + + checkConstResults(ctzConstLongs(), x -> Long.numberOfTrailingZeros(x.longValue())); + Asserts.assertEQ(Long.numberOfTrailingZeros(randLong) < 0 + || Long.numberOfTrailingZeros(randLong) > 64, + ctzCompareLong(randLong)); + Asserts.assertEQ(Long.numberOfTrailingZeros(randLong) / 8, + ctzDiv8Long(randLong)); + Asserts.assertEQ(ctzRandLimitInterpretedLong(randLong), ctzRandLimitLong(randLong)); + } + + @DontCompile + public void checkConstResults(int[] results, Function op) { + Asserts.assertEQ(op.apply(Long.valueOf(0)), results[0]); + for (int i = 0; i < results.length - 1; ++i) { + Asserts.assertEQ(op.apply(Long.valueOf(1l << i)), results[i + 1]); + } + } + + // Test CLZ with constant integer inputs. + // All CLZs in this test are expected to be optimized away. + @Test + @IR(failOn = IRNode.COUNT_LEADING_ZEROS_I) + public int[] clzConstInts() { + return new int[] { + Integer.numberOfLeadingZeros(0), + Integer.numberOfLeadingZeros(1 << 0), + Integer.numberOfLeadingZeros(1 << 1), + Integer.numberOfLeadingZeros(1 << 2), + Integer.numberOfLeadingZeros(1 << 3), + Integer.numberOfLeadingZeros(1 << 4), + Integer.numberOfLeadingZeros(1 << 5), + Integer.numberOfLeadingZeros(1 << 6), + Integer.numberOfLeadingZeros(1 << 7), + Integer.numberOfLeadingZeros(1 << 8), + Integer.numberOfLeadingZeros(1 << 9), + Integer.numberOfLeadingZeros(1 << 10), + Integer.numberOfLeadingZeros(1 << 11), + Integer.numberOfLeadingZeros(1 << 12), + Integer.numberOfLeadingZeros(1 << 13), + Integer.numberOfLeadingZeros(1 << 14), + Integer.numberOfLeadingZeros(1 << 15), + Integer.numberOfLeadingZeros(1 << 16), + Integer.numberOfLeadingZeros(1 << 17), + Integer.numberOfLeadingZeros(1 << 18), + Integer.numberOfLeadingZeros(1 << 19), + Integer.numberOfLeadingZeros(1 << 20), + Integer.numberOfLeadingZeros(1 << 21), + Integer.numberOfLeadingZeros(1 << 22), + Integer.numberOfLeadingZeros(1 << 23), + Integer.numberOfLeadingZeros(1 << 24), + Integer.numberOfLeadingZeros(1 << 25), + Integer.numberOfLeadingZeros(1 << 26), + Integer.numberOfLeadingZeros(1 << 27), + Integer.numberOfLeadingZeros(1 << 28), + Integer.numberOfLeadingZeros(1 << 29), + Integer.numberOfLeadingZeros(1 << 30), + Integer.numberOfLeadingZeros(1 << 31), + }; + } + + // Test the range of CLZ with random integer input. + // The result of CLZ should be in range [0, 32], so CLZs in this test are + // expected to be optimized away, and the test should always return false. + @Test + @IR(failOn = IRNode.COUNT_LEADING_ZEROS_I) + public boolean clzCompareInt(int randInt) { + return Integer.numberOfLeadingZeros(randInt) < 0 + || Integer.numberOfLeadingZeros(randInt) > 32; + } + + // Test the combination of CLZ and division by 8. + // The result of CLZ should be positive, so the division by 8 should be + // optimized to a simple right shift without rounding. + @Test + @IR(counts = {IRNode.COUNT_LEADING_ZEROS_I, "1", + IRNode.RSHIFT_I, "1", + IRNode.URSHIFT_I, "0", + IRNode.ADD_I, "0"}) + public int clzDiv8Int(int randInt) { + return Integer.numberOfLeadingZeros(randInt) / 8; + } + + // Test the output range of CLZ with random input range. + @Test + public int clzRandLimitInt(int randInt) { + randInt = RANGE_INT.clamp(randInt); + int result = Integer.numberOfLeadingZeros(randInt); + return getResultChecksum32(result); + } + + @DontCompile + public int clzRandLimitInterpretedInt(int randInt) { + randInt = RANGE_INT.clamp(randInt); + int result = Integer.numberOfLeadingZeros(randInt); + return getResultChecksum32(result); + } + + // Test CLZ with constant long inputs. + // All CLZs in this test are expected to be optimized away. + @Test + @IR(failOn = IRNode.COUNT_LEADING_ZEROS_L) + public int[] clzConstLongs() { + return new int[] { + Long.numberOfLeadingZeros(0), + Long.numberOfLeadingZeros(1l << 0), + Long.numberOfLeadingZeros(1l << 1), + Long.numberOfLeadingZeros(1l << 2), + Long.numberOfLeadingZeros(1l << 3), + Long.numberOfLeadingZeros(1l << 4), + Long.numberOfLeadingZeros(1l << 5), + Long.numberOfLeadingZeros(1l << 6), + Long.numberOfLeadingZeros(1l << 7), + Long.numberOfLeadingZeros(1l << 8), + Long.numberOfLeadingZeros(1l << 9), + Long.numberOfLeadingZeros(1l << 10), + Long.numberOfLeadingZeros(1l << 11), + Long.numberOfLeadingZeros(1l << 12), + Long.numberOfLeadingZeros(1l << 13), + Long.numberOfLeadingZeros(1l << 14), + Long.numberOfLeadingZeros(1l << 15), + Long.numberOfLeadingZeros(1l << 16), + Long.numberOfLeadingZeros(1l << 17), + Long.numberOfLeadingZeros(1l << 18), + Long.numberOfLeadingZeros(1l << 19), + Long.numberOfLeadingZeros(1l << 20), + Long.numberOfLeadingZeros(1l << 21), + Long.numberOfLeadingZeros(1l << 22), + Long.numberOfLeadingZeros(1l << 23), + Long.numberOfLeadingZeros(1l << 24), + Long.numberOfLeadingZeros(1l << 25), + Long.numberOfLeadingZeros(1l << 26), + Long.numberOfLeadingZeros(1l << 27), + Long.numberOfLeadingZeros(1l << 28), + Long.numberOfLeadingZeros(1l << 29), + Long.numberOfLeadingZeros(1l << 30), + Long.numberOfLeadingZeros(1l << 31), + Long.numberOfLeadingZeros(1l << 32), + Long.numberOfLeadingZeros(1l << 33), + Long.numberOfLeadingZeros(1l << 34), + Long.numberOfLeadingZeros(1l << 35), + Long.numberOfLeadingZeros(1l << 36), + Long.numberOfLeadingZeros(1l << 37), + Long.numberOfLeadingZeros(1l << 38), + Long.numberOfLeadingZeros(1l << 39), + Long.numberOfLeadingZeros(1l << 40), + Long.numberOfLeadingZeros(1l << 41), + Long.numberOfLeadingZeros(1l << 42), + Long.numberOfLeadingZeros(1l << 43), + Long.numberOfLeadingZeros(1l << 44), + Long.numberOfLeadingZeros(1l << 45), + Long.numberOfLeadingZeros(1l << 46), + Long.numberOfLeadingZeros(1l << 47), + Long.numberOfLeadingZeros(1l << 48), + Long.numberOfLeadingZeros(1l << 49), + Long.numberOfLeadingZeros(1l << 50), + Long.numberOfLeadingZeros(1l << 51), + Long.numberOfLeadingZeros(1l << 52), + Long.numberOfLeadingZeros(1l << 53), + Long.numberOfLeadingZeros(1l << 54), + Long.numberOfLeadingZeros(1l << 55), + Long.numberOfLeadingZeros(1l << 56), + Long.numberOfLeadingZeros(1l << 57), + Long.numberOfLeadingZeros(1l << 58), + Long.numberOfLeadingZeros(1l << 59), + Long.numberOfLeadingZeros(1l << 60), + Long.numberOfLeadingZeros(1l << 61), + Long.numberOfLeadingZeros(1l << 62), + Long.numberOfLeadingZeros(1l << 63), + }; + } + + // Test the range of CLZ with random long input. + // The result of CLZ should be in range [0, 64], so CLZs in this test are + // expected to be optimized away, and the test should always return false. + @Test + @IR(failOn = IRNode.COUNT_LEADING_ZEROS_L) + public boolean clzCompareLong(long randLong) { + return Long.numberOfLeadingZeros(randLong) < 0 + || Long.numberOfLeadingZeros(randLong) > 64; + } + + // Test the combination of CLZ and division by 8. + // The result of CLZ should be positive, so the division by 8 should be + // optimized to a simple right shift without rounding. + @Test + @IR(counts = {IRNode.COUNT_LEADING_ZEROS_L, "1", + IRNode.RSHIFT_I, "1", + IRNode.URSHIFT_I, "0", + IRNode.ADD_I, "0"}) + public int clzDiv8Long(long randLong) { + return Long.numberOfLeadingZeros(randLong) / 8; + } + + // Test the output range of CLZ with random input range. + @Test + public int clzRandLimitLong(long randLong) { + randLong = RANGE_LONG.clamp(randLong); + int result = Long.numberOfLeadingZeros(randLong); + return getResultChecksum64(result); + } + + @DontCompile + public int clzRandLimitInterpretedLong(long randLong) { + randLong = RANGE_LONG.clamp(randLong); + int result = Long.numberOfLeadingZeros(randLong); + return getResultChecksum64(result); + } + + // Test CTZ with constant integer inputs. + // All CTZs in this test are expected to be optimized away. + @Test + @IR(failOn = IRNode.COUNT_TRAILING_ZEROS_I) + public int[] ctzConstInts() { + return new int[] { + Integer.numberOfTrailingZeros(0), + Integer.numberOfTrailingZeros(1 << 0), + Integer.numberOfTrailingZeros(1 << 1), + Integer.numberOfTrailingZeros(1 << 2), + Integer.numberOfTrailingZeros(1 << 3), + Integer.numberOfTrailingZeros(1 << 4), + Integer.numberOfTrailingZeros(1 << 5), + Integer.numberOfTrailingZeros(1 << 6), + Integer.numberOfTrailingZeros(1 << 7), + Integer.numberOfTrailingZeros(1 << 8), + Integer.numberOfTrailingZeros(1 << 9), + Integer.numberOfTrailingZeros(1 << 10), + Integer.numberOfTrailingZeros(1 << 11), + Integer.numberOfTrailingZeros(1 << 12), + Integer.numberOfTrailingZeros(1 << 13), + Integer.numberOfTrailingZeros(1 << 14), + Integer.numberOfTrailingZeros(1 << 15), + Integer.numberOfTrailingZeros(1 << 16), + Integer.numberOfTrailingZeros(1 << 17), + Integer.numberOfTrailingZeros(1 << 18), + Integer.numberOfTrailingZeros(1 << 19), + Integer.numberOfTrailingZeros(1 << 20), + Integer.numberOfTrailingZeros(1 << 21), + Integer.numberOfTrailingZeros(1 << 22), + Integer.numberOfTrailingZeros(1 << 23), + Integer.numberOfTrailingZeros(1 << 24), + Integer.numberOfTrailingZeros(1 << 25), + Integer.numberOfTrailingZeros(1 << 26), + Integer.numberOfTrailingZeros(1 << 27), + Integer.numberOfTrailingZeros(1 << 28), + Integer.numberOfTrailingZeros(1 << 29), + Integer.numberOfTrailingZeros(1 << 30), + Integer.numberOfTrailingZeros(1 << 31), + }; + } + + // Test the range of CTZ with random integer input. + // The result of CTZ should be in range [0, 32], so CTZs in this test are + // expected to be optimized away, and the test should always return false. + @Test + @IR(failOn = IRNode.COUNT_TRAILING_ZEROS_I) + public boolean ctzCompareInt(int randInt) { + return Integer.numberOfTrailingZeros(randInt) < 0 + || Integer.numberOfTrailingZeros(randInt) > 32; + } + + // Test the combination of CTZ and division by 8. + // The result of CTZ should be positive, so the division by 8 should be + // optimized to a simple right shift without rounding. + @Test + @IR(counts = {IRNode.COUNT_TRAILING_ZEROS_I, "1", + IRNode.RSHIFT_I, "1", + IRNode.URSHIFT_I, "0", + IRNode.ADD_I, "0"}) + public int ctzDiv8Int(int randInt) { + return Integer.numberOfTrailingZeros(randInt) / 8; + } + + // Test the output range of CTZ with random input range. + @Test + public int ctzRandLimitInt(int randInt) { + randInt = RANGE_INT.clamp(randInt); + int result = Integer.numberOfTrailingZeros(randInt); + return getResultChecksum32(result); + } + + @DontCompile + public int ctzRandLimitInterpretedInt(int randInt) { + randInt = RANGE_INT.clamp(randInt); + int result = Integer.numberOfTrailingZeros(randInt); + return getResultChecksum32(result); + } + + // Test CTZ with constant long inputs. + // All CTZs in this test are expected to be optimized away. + @Test + @IR(failOn = IRNode.COUNT_TRAILING_ZEROS_L) + public int[] ctzConstLongs() { + return new int[] { + Long.numberOfTrailingZeros(0), + Long.numberOfTrailingZeros(1l << 0), + Long.numberOfTrailingZeros(1l << 1), + Long.numberOfTrailingZeros(1l << 2), + Long.numberOfTrailingZeros(1l << 3), + Long.numberOfTrailingZeros(1l << 4), + Long.numberOfTrailingZeros(1l << 5), + Long.numberOfTrailingZeros(1l << 6), + Long.numberOfTrailingZeros(1l << 7), + Long.numberOfTrailingZeros(1l << 8), + Long.numberOfTrailingZeros(1l << 9), + Long.numberOfTrailingZeros(1l << 10), + Long.numberOfTrailingZeros(1l << 11), + Long.numberOfTrailingZeros(1l << 12), + Long.numberOfTrailingZeros(1l << 13), + Long.numberOfTrailingZeros(1l << 14), + Long.numberOfTrailingZeros(1l << 15), + Long.numberOfTrailingZeros(1l << 16), + Long.numberOfTrailingZeros(1l << 17), + Long.numberOfTrailingZeros(1l << 18), + Long.numberOfTrailingZeros(1l << 19), + Long.numberOfTrailingZeros(1l << 20), + Long.numberOfTrailingZeros(1l << 21), + Long.numberOfTrailingZeros(1l << 22), + Long.numberOfTrailingZeros(1l << 23), + Long.numberOfTrailingZeros(1l << 24), + Long.numberOfTrailingZeros(1l << 25), + Long.numberOfTrailingZeros(1l << 26), + Long.numberOfTrailingZeros(1l << 27), + Long.numberOfTrailingZeros(1l << 28), + Long.numberOfTrailingZeros(1l << 29), + Long.numberOfTrailingZeros(1l << 30), + Long.numberOfTrailingZeros(1l << 31), + Long.numberOfTrailingZeros(1l << 32), + Long.numberOfTrailingZeros(1l << 33), + Long.numberOfTrailingZeros(1l << 34), + Long.numberOfTrailingZeros(1l << 35), + Long.numberOfTrailingZeros(1l << 36), + Long.numberOfTrailingZeros(1l << 37), + Long.numberOfTrailingZeros(1l << 38), + Long.numberOfTrailingZeros(1l << 39), + Long.numberOfTrailingZeros(1l << 40), + Long.numberOfTrailingZeros(1l << 41), + Long.numberOfTrailingZeros(1l << 42), + Long.numberOfTrailingZeros(1l << 43), + Long.numberOfTrailingZeros(1l << 44), + Long.numberOfTrailingZeros(1l << 45), + Long.numberOfTrailingZeros(1l << 46), + Long.numberOfTrailingZeros(1l << 47), + Long.numberOfTrailingZeros(1l << 48), + Long.numberOfTrailingZeros(1l << 49), + Long.numberOfTrailingZeros(1l << 50), + Long.numberOfTrailingZeros(1l << 51), + Long.numberOfTrailingZeros(1l << 52), + Long.numberOfTrailingZeros(1l << 53), + Long.numberOfTrailingZeros(1l << 54), + Long.numberOfTrailingZeros(1l << 55), + Long.numberOfTrailingZeros(1l << 56), + Long.numberOfTrailingZeros(1l << 57), + Long.numberOfTrailingZeros(1l << 58), + Long.numberOfTrailingZeros(1l << 59), + Long.numberOfTrailingZeros(1l << 60), + Long.numberOfTrailingZeros(1l << 61), + Long.numberOfTrailingZeros(1l << 62), + Long.numberOfTrailingZeros(1l << 63), + }; + } + + // Test the range of CTZ with random long input. + // The result of CTZ should be in range [0, 64], so CTZs in this test are + // expected to be optimized away, and the test should always return false. + @Test + @IR(failOn = IRNode.COUNT_TRAILING_ZEROS_L) + public boolean ctzCompareLong(long randLong) { + return Long.numberOfTrailingZeros(randLong) < 0 + || Long.numberOfTrailingZeros(randLong) > 64; + } + + // Test the combination of CTZ and division by 8. + // The result of CTZ should be positive, so the division by 8 should be + // optimized to a simple right shift without rounding. + @Test + @IR(counts = {IRNode.COUNT_TRAILING_ZEROS_L, "1", + IRNode.RSHIFT_I, "1", + IRNode.URSHIFT_I, "0", + IRNode.ADD_I, "0"}) + public int ctzDiv8Long(long randLong) { + return Long.numberOfTrailingZeros(randLong) / 8; + } + + // Test the output range of CTZ with random input range. + @Test + public int ctzRandLimitLong(long randLong) { + randLong = RANGE_LONG.clamp(randLong); + int result = Long.numberOfLeadingZeros(randLong); + return getResultChecksum64(result); + } + + @DontCompile + public int ctzRandLimitInterpretedLong(long randLong) { + randLong = RANGE_LONG.clamp(randLong); + int result = Long.numberOfLeadingZeros(randLong); + return getResultChecksum64(result); + } + + record IntRange(int lo, int hi) { + IntRange { + if (lo > hi) { + throw new IllegalArgumentException("lo > hi"); + } + } + + @ForceInline + int clamp(int v) { + return v < lo ? lo : v > hi ? hi : v; + } + + static IntRange generate(Generator g) { + int a = g.next(), b = g.next(); + return a < b ? new IntRange(a, b) : new IntRange(b, a); + } + } + + record LongRange(long lo, long hi) { + LongRange { + if (lo > hi) { + throw new IllegalArgumentException("lo > hi"); + } + } + + @ForceInline + long clamp(long v) { + return v < lo ? lo : v > hi ? hi : v; + } + + static LongRange generate(Generator g) { + long a = g.next(), b = g.next(); + return a < b ? new LongRange(a, b) : new LongRange(b, a); + } + } + + @ForceInline + int getResultChecksum32(int result) { + int sum = 0; + if (result < LIMITS_32_0) sum += 1; + if (result < LIMITS_32_1) sum += 2; + if (result < LIMITS_32_2) sum += 4; + if (result < LIMITS_32_3) sum += 8; + if (result > LIMITS_32_4) sum += 16; + if (result > LIMITS_32_5) sum += 32; + if (result > LIMITS_32_6) sum += 64; + if (result > LIMITS_32_7) sum += 128; + return sum; + } + + @ForceInline + int getResultChecksum64(int result) { + int sum = 0; + if (result < LIMITS_64_0) sum += 1; + if (result < LIMITS_64_1) sum += 2; + if (result < LIMITS_64_2) sum += 4; + if (result < LIMITS_64_3) sum += 8; + if (result > LIMITS_64_4) sum += 16; + if (result > LIMITS_64_5) sum += 32; + if (result > LIMITS_64_6) sum += 64; + if (result > LIMITS_64_7) sum += 128; + return sum; + } +} diff --git a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java index 99a289476ec4d..f25184f7e3169 100644 --- a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java +++ b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java @@ -1655,6 +1655,16 @@ public class IRNode { vectorNode(POPCOUNT_VL, "PopCountVL", TYPE_LONG); } + public static final String COUNT_TRAILING_ZEROS_I = PREFIX + "COUNT_TRAILING_ZEROS_I" + POSTFIX; + static { + beforeMatchingNameRegex(COUNT_TRAILING_ZEROS_I, "CountTrailingZerosI"); + } + + public static final String COUNT_TRAILING_ZEROS_L = PREFIX + "COUNT_TRAILING_ZEROS_L" + POSTFIX; + static { + beforeMatchingNameRegex(COUNT_TRAILING_ZEROS_L, "CountTrailingZerosL"); + } + public static final String COUNT_TRAILING_ZEROS_VL = VECTOR_PREFIX + "COUNT_TRAILING_ZEROS_VL" + POSTFIX; static { vectorNode(COUNT_TRAILING_ZEROS_VL, "CountTrailingZerosV", TYPE_LONG); @@ -1665,6 +1675,16 @@ public class IRNode { vectorNode(COUNT_TRAILING_ZEROS_VI, "CountTrailingZerosV", TYPE_INT); } + public static final String COUNT_LEADING_ZEROS_I = PREFIX + "COUNT_LEADING_ZEROS_I" + POSTFIX; + static { + beforeMatchingNameRegex(COUNT_LEADING_ZEROS_I, "CountLeadingZerosI"); + } + + public static final String COUNT_LEADING_ZEROS_L = PREFIX + "COUNT_LEADING_ZEROS_L" + POSTFIX; + static { + beforeMatchingNameRegex(COUNT_LEADING_ZEROS_L, "CountLeadingZerosL"); + } + public static final String COUNT_LEADING_ZEROS_VL = VECTOR_PREFIX + "COUNT_LEADING_ZEROS_VL" + POSTFIX; static { vectorNode(COUNT_LEADING_ZEROS_VL, "CountLeadingZerosV", TYPE_LONG); diff --git a/test/micro/org/openjdk/bench/vm/compiler/CountLeadingZeros.java b/test/micro/org/openjdk/bench/vm/compiler/CountLeadingZeros.java new file mode 100644 index 0000000000000..dc06e3ae76411 --- /dev/null +++ b/test/micro/org/openjdk/bench/vm/compiler/CountLeadingZeros.java @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2025 Alibaba Group Holding Limited. All Rights Reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.bench.vm.compiler; + +import org.openjdk.jmh.annotations.*; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.ThreadLocalRandom; + +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Fork(value = 3) +@Warmup(iterations = 10, time = 1) +@Measurement(iterations = 5, time = 1) +@State(Scope.Thread) +public class CountLeadingZeros { + private long[] longArray = new long[1000]; + + @Setup + public void setup() { + for (int i = 0; i < longArray.length; i++) { + longArray[i] = ThreadLocalRandom.current().nextLong(); + } + } + + @Benchmark + public int benchNumberOfNibbles() { + int sum = 0; + for (long l : longArray) { + sum += numberOfNibbles((int) l); + } + return sum; + } + + public static int numberOfNibbles(int i) { + int mag = Integer.SIZE - Integer.numberOfLeadingZeros(i); + return Math.max((mag + 3) / 4, 1); + } + + @Benchmark + public int benchClzLongConstrained() { + int sum = 0; + for (long l : longArray) { + sum += clzLongConstrained(l); + } + return sum; + } + + public static int clzLongConstrained(long param) { + long constrainedParam = Math.min(175, Math.max(param, 160)); + return Long.numberOfLeadingZeros(constrainedParam); + } +}