Skip to content
19 changes: 14 additions & 5 deletions src/hotspot/share/opto/vectorIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,11 +686,20 @@ bool LibraryCallKit::inline_vector_frombits_coerced() {
int opc = bcast_mode == VectorSupport::MODE_BITS_COERCED_LONG_TO_MASK ? Op_VectorLongToMask : Op_Replicate;

if (!arch_supports_vector(opc, num_elem, elem_bt, checkFlags, true /*has_scalar_args*/)) {
log_if_needed(" ** not supported: arity=0 op=broadcast vlen=%d etype=%s ismask=%d bcast_mode=%d",
num_elem, type2name(elem_bt),
is_mask ? 1 : 0,
bcast_mode);
return false; // not supported
// If the input long sets or unsets all lanes and Replicate is supported,
// generate a MaskAll or Replicate instead.

// The "maskAll" API uses the corresponding integer types for floating-point data.
BasicType maskall_bt = elem_bt == T_DOUBLE ? T_LONG : (elem_bt == T_FLOAT ? T_INT: elem_bt);
if (!(opc == Op_VectorLongToMask &&
VectorNode::is_maskall_type(bits_type, num_elem) &&
arch_supports_vector(Op_Replicate, num_elem, maskall_bt, checkFlags, true /*has_scalar_args*/))) {
log_if_needed(" ** not supported: arity=0 op=broadcast vlen=%d etype=%s ismask=%d bcast_mode=%d",
num_elem, type2name(elem_bt),
is_mask ? 1 : 0,
bcast_mode);
return false; // not supported
}
}

Node* broadcast = nullptr;
Expand Down
113 changes: 107 additions & 6 deletions src/hotspot/share/opto/vectornode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,16 @@ bool VectorNode::implemented(int opc, uint vlen, BasicType bt) {
return false;
}

bool VectorNode::is_maskall_type(const TypeLong* type, int vlen) {
assert(type != nullptr, "type must not be null");
if (!type->is_con()) {
return false;
}
long mask = (-1ULL >> (64 - vlen));
long bit = type->get_con() & mask;
return bit == 0 || bit == mask;
}

bool VectorNode::is_muladds2i(const Node* n) {
return n->Opcode() == Op_MulAddS2I;
}
Expand Down Expand Up @@ -1503,6 +1513,45 @@ Node* ReductionNode::Ideal(PhaseGVN* phase, bool can_reshape) {
return nullptr;
}

// Convert fromLong to maskAll if the input sets or unsets all lanes.
Node* convertFromLongToMaskAll(PhaseGVN* phase, const TypeLong* bits_type, bool is_mask, const TypeVect* vt) {
uint vlen = vt->length();
BasicType bt = vt->element_basic_type();
// The "maskAll" API uses the corresponding integer types for floating-point data.
BasicType maskall_bt = (bt == T_FLOAT) ? T_INT : (bt == T_DOUBLE) ? T_LONG : bt;

if (VectorNode::is_maskall_type(bits_type, vlen) &&
Matcher::match_rule_supported_vector(Op_Replicate, vlen, maskall_bt)) {
Node* con = nullptr;
jlong con_value = bits_type->get_con() == 0L ? 0L : -1L;
if (maskall_bt == T_LONG) {
con = phase->longcon(con_value);
} else {
con = phase->intcon(con_value);
}
Node* res = VectorNode::scalar2vector(con, vlen, maskall_bt, is_mask);
// Convert back to the original floating-point data type.
if (is_floating_point_type(bt)) {
res = new VectorMaskCastNode(phase->transform(res), vt);
}
return res;
}
return nullptr;
}

Node* VectorLoadMaskNode::Ideal(PhaseGVN* phase, bool can_reshape) {
// VectorLoadMask(VectorLongToMask(-1/0)) => Replicate(-1/0)
if (in(1)->Opcode() == Op_VectorLongToMask) {
const TypeVect* vt = bottom_type()->is_vect();
Node* res = convertFromLongToMaskAll(phase, in(1)->in(1)->bottom_type()->isa_long(), false, vt);
if (res != nullptr) {
return res;
}
}

return VectorNode::Ideal(phase, can_reshape);
}

Node* VectorLoadMaskNode::Identity(PhaseGVN* phase) {
BasicType out_bt = type()->is_vect()->element_basic_type();
if (!Matcher::has_predicated_vectors() && out_bt == T_BOOLEAN) {
Expand Down Expand Up @@ -1918,6 +1967,45 @@ Node* VectorMaskOpNode::Ideal(PhaseGVN* phase, bool can_reshape) {
return nullptr;
}

Node* VectorMaskCastNode::Identity(PhaseGVN* phase) {
Node* in1 = in(1);
// VectorMaskCast (VectorMaskCast x) => x
if (in1->Opcode() == Op_VectorMaskCast &&
vect_type()->eq(in1->in(1)->bottom_type())) {
return in1->in(1);
}
return this;
}

// This function does the following optimization:
// VectorMaskToLong(MaskAll(l)) => (l & (-1ULL >> (64 - vlen)))
// VectorMaskToLong(VectorStoreMask(Replicate(l))) => (l & (-1ULL >> (64 - vlen)))
// l is -1 or 0.
Node* VectorMaskToLongNode::Ideal_MaskAll(PhaseGVN* phase) {
Node* in1 = in(1);
// VectorMaskToLong follows a VectorStoreMask if predicate is not supported.
if (in1->Opcode() == Op_VectorStoreMask) {
assert(!in1->in(1)->bottom_type()->isa_vectmask(), "sanity");
in1 = in1->in(1);
}
if (VectorNode::is_all_ones_vector(in1)) {
int vlen = in1->bottom_type()->is_vect()->length();
return new ConLNode(TypeLong::make(-1ULL >> (64 - vlen)));
}
if (VectorNode::is_all_zeros_vector(in1)) {
return new ConLNode(TypeLong::ZERO);
}
return nullptr;
}

Node* VectorMaskToLongNode::Ideal(PhaseGVN* phase, bool can_reshape) {
Node* res = Ideal_MaskAll(phase);
if (res != nullptr) {
return res;
}
return VectorMaskOpNode::Ideal(phase, can_reshape);
}

Node* VectorMaskToLongNode::Identity(PhaseGVN* phase) {
if (in(1)->Opcode() == Op_VectorLongToMask) {
return in(1)->in(1);
Expand All @@ -1927,28 +2015,41 @@ Node* VectorMaskToLongNode::Identity(PhaseGVN* phase) {

Node* VectorLongToMaskNode::Ideal(PhaseGVN* phase, bool can_reshape) {
const TypeVect* dst_type = bottom_type()->is_vect();
uint vlen = dst_type->length();
const TypeVectMask* is_mask = dst_type->isa_vectmask();

if (in(1)->Opcode() == Op_AndL &&
in(1)->in(1)->Opcode() == Op_VectorMaskToLong &&
in(1)->in(2)->bottom_type()->isa_long() &&
in(1)->in(2)->bottom_type()->is_long()->is_con() &&
in(1)->in(2)->bottom_type()->is_long()->get_con() == ((1L << dst_type->length()) - 1)) {
in(1)->in(2)->bottom_type()->is_long()->get_con() == ((1L << vlen) - 1)) {
// Different src/dst mask length represents a re-interpretation operation,
// we can however generate a mask casting operation if length matches.
Node* src = in(1)->in(1)->in(1);
if (dst_type->isa_vectmask() == nullptr) {
if (is_mask == nullptr) {
if (src->Opcode() != Op_VectorStoreMask) {
return nullptr;
}
src = src->in(1);
}
const TypeVect* src_type = src->bottom_type()->is_vect();
if (src_type->length() == dst_type->length() &&
((src_type->isa_vectmask() == nullptr && dst_type->isa_vectmask() == nullptr) ||
(src_type->isa_vectmask() && dst_type->isa_vectmask()))) {
if (src_type->length() == vlen &&
((src_type->isa_vectmask() == nullptr && is_mask == nullptr) ||
(src_type->isa_vectmask() && is_mask))) {
return new VectorMaskCastNode(src, dst_type);
}
}
return nullptr;

// VectorLongToMask(-1/0) => MaskAll(-1/0)
const TypeLong* bits_type = in(1)->bottom_type()->isa_long();
if (bits_type && is_mask) {
Node* res = convertFromLongToMaskAll(phase, bits_type, true, dst_type);
if (res != nullptr) {
return res;
}
}

return VectorNode::Ideal(phase, can_reshape);
}

Node* FmaVNode::Ideal(PhaseGVN* phase, bool can_reshape) {
Expand Down
5 changes: 5 additions & 0 deletions src/hotspot/share/opto/vectornode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class VectorNode : public TypeNode {
static bool implemented(int opc, uint vlen, BasicType bt);
static bool is_shift(Node* n);
static bool is_vshift_cnt(Node* n);
static bool is_maskall_type(const TypeLong* type, int vlen);
static bool is_muladds2i(const Node* n);
static bool is_roundopD(Node* n);
static bool is_scalar_rotate(Node* n);
Expand Down Expand Up @@ -1383,6 +1384,8 @@ class VectorMaskToLongNode : public VectorMaskOpNode {
VectorMaskToLongNode(Node* mask, const Type* ty):
VectorMaskOpNode(mask, ty, Op_VectorMaskToLong) {}
virtual int Opcode() const;
Node* Ideal(PhaseGVN* phase, bool can_reshape);
Node* Ideal_MaskAll(PhaseGVN* phase);
virtual uint ideal_reg() const { return Op_RegL; }
virtual Node* Identity(PhaseGVN* phase);
};
Expand Down Expand Up @@ -1776,6 +1779,7 @@ class VectorLoadMaskNode : public VectorNode {

virtual int Opcode() const;
virtual Node* Identity(PhaseGVN* phase);
Node* Ideal(PhaseGVN* phase, bool can_reshape);
};

class VectorStoreMaskNode : public VectorNode {
Expand All @@ -1795,6 +1799,7 @@ class VectorMaskCastNode : public VectorNode {
const TypeVect* in_vt = in->bottom_type()->is_vect();
assert(in_vt->length() == vt->length(), "vector length must match");
}
Node* Identity(PhaseGVN* phase);
virtual int Opcode() const;
};

Expand Down
15 changes: 15 additions & 0 deletions test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,21 @@ public class IRNode {
vectorNode(UMAX_VL, "UMaxV", TYPE_LONG);
}

public static final String MASK_ALL = PREFIX + "MASK_ALL" + POSTFIX;
static {
beforeMatchingNameRegex(MASK_ALL, "MaskAll");
}

public static final String VECTOR_LONG_TO_MASK = PREFIX + "VECTOR_LONG_TO_MASK" + POSTFIX;
static {
beforeMatchingNameRegex(VECTOR_LONG_TO_MASK, "VectorLongToMask");
}

public static final String VECTOR_MASK_TO_LONG = PREFIX + "VECTOR_MASK_TO_LONG" + POSTFIX;
static {
beforeMatchingNameRegex(VECTOR_MASK_TO_LONG, "VectorMaskToLong");
}

// Can only be used if avx512_vnni is available.
public static final String MUL_ADD_VS2VI_VNNI = PREFIX + "MUL_ADD_VS2VI_VNNI" + POSTFIX;
static {
Expand Down
124 changes: 124 additions & 0 deletions test/hotspot/jtreg/compiler/vectorapi/VectorMaskCastIdentityTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

/*
* @test
* @bug 8356760
* @library /test/lib /
* @summary Optimize VectorMask.fromLong for all-true/all-false cases
* @modules jdk.incubator.vector
*
* @run driver compiler.vectorapi.VectorMaskCastIdentityTest
*/

package compiler.vectorapi;

import compiler.lib.ir_framework.*;
import java.util.Random;
import jdk.incubator.vector.*;
import jdk.test.lib.Asserts;
import jdk.test.lib.Utils;

public class VectorMaskCastIdentityTest {
private static final boolean[] mr = new boolean[128]; // 128 is large enough
private static final Random rd = Utils.getRandomInstance();
static {
for (int i = 0; i < mr.length; i++) {
mr[i] = rd.nextBoolean();
}
}

@Test
@IR(counts = { IRNode.VECTOR_MASK_CAST, "= 2" }, applyIfCPUFeatureOr = {"asimd", "true"})
public static int testTwoCastToDifferentType() {
// The types before and after the two casts are not the same, so the cast cannot be eliminated.
VectorMask<Float> mFloat64 = VectorMask.fromArray(FloatVector.SPECIES_64, mr, 0);
VectorMask<Double> mDouble128 = mFloat64.cast(DoubleVector.SPECIES_128);
VectorMask<Integer> mInt64 = mDouble128.cast(IntVector.SPECIES_64);
return mInt64.trueCount();
}

@Run(test = "testTwoCastToDifferentType")
public static void testTwoCastToDifferentType_runner() {
int count = testTwoCastToDifferentType();
VectorMask<Float> mFloat64 = VectorMask.fromArray(FloatVector.SPECIES_64, mr, 0);
Asserts.assertEquals(count, mFloat64.trueCount());
}

@Test
@IR(counts = { IRNode.VECTOR_MASK_CAST, "= 2" }, applyIfCPUFeatureOr = {"avx2", "true"})
public static int testTwoCastToDifferentType2() {
// The types before and after the two casts are not the same, so the cast cannot be eliminated.
VectorMask<Integer> mInt128 = VectorMask.fromArray(IntVector.SPECIES_128, mr, 0);
VectorMask<Double> mDouble256 = mInt128.cast(DoubleVector.SPECIES_256);
VectorMask<Short> mShort64 = mDouble256.cast(ShortVector.SPECIES_64);
return mShort64.trueCount();
}

@Run(test = "testTwoCastToDifferentType2")
public static void testTwoCastToDifferentType2_runner() {
int count = testTwoCastToDifferentType2();
VectorMask<Integer> mInt128 = VectorMask.fromArray(IntVector.SPECIES_128, mr, 0);
Asserts.assertEquals(count, mInt128.trueCount());
}

@Test
@IR(counts = { IRNode.VECTOR_MASK_CAST, "= 0" }, applyIfCPUFeatureOr = {"avx2", "true", "asimd", "true"})
public static int testTwoCastToSameType() {
// The types before and after the two casts are the same, so the cast will be eliminated.
VectorMask<Integer> mInt128 = VectorMask.fromArray(IntVector.SPECIES_128, mr, 0);
VectorMask<Float> mFloat128 = mInt128.cast(FloatVector.SPECIES_128);
VectorMask<Integer> mInt128_2 = mFloat128.cast(IntVector.SPECIES_128);
return mInt128_2.trueCount();
}

@Run(test = "testTwoCastToSameType")
public static void testTwoCastToSameType_runner() {
int count = testTwoCastToSameType();
VectorMask<Integer> mInt128 = VectorMask.fromArray(IntVector.SPECIES_128, mr, 0);
Asserts.assertEquals(count, mInt128.trueCount());
}

@Test
@IR(counts = { IRNode.VECTOR_MASK_CAST, "= 1" }, applyIfCPUFeatureOr = {"avx2", "true", "asimd", "true"})
public static int testOneCastToDifferentType() {
// The types before and after the only cast are different, the cast will not be eliminated.
VectorMask<Float> mFloat128 = VectorMask.fromArray(FloatVector.SPECIES_128, mr, 0).not();
VectorMask<Integer> mInt128 = mFloat128.cast(IntVector.SPECIES_128);
return mInt128.trueCount();
}

@Run(test = "testOneCastToDifferentType")
public static void testOneCastToDifferentType_runner() {
int count = testOneCastToDifferentType();
VectorMask<Float> mInt128 = VectorMask.fromArray(FloatVector.SPECIES_128, mr, 0).not();
Asserts.assertEquals(count, mInt128.trueCount());
}

public static void main(String[] args) {
TestFramework testFramework = new TestFramework();
testFramework.setDefaultWarmup(10000)
.addFlags("--add-modules=jdk.incubator.vector")
.start();
}
}
Loading