diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index a4bfdcabaa314..19e82099e87f0 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -113,6 +113,7 @@ class VectorCombine {
   bool foldInsExtFNeg(Instruction &I);
   bool foldInsExtBinop(Instruction &I);
   bool foldInsExtVectorToShuffle(Instruction &I);
+  bool foldBitOpOfBitcasts(Instruction &I);
   bool foldBitcastShuffle(Instruction &I);
   bool scalarizeOpOrCmp(Instruction &I);
   bool scalarizeVPIntrinsic(Instruction &I);
@@ -803,6 +804,66 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
+  // Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
+  Value *LHSSrc, *RHSSrc;
+  if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
+                                m_BitCast(m_Value(RHSSrc)))))
+    return false;
+
+  // Source types must match
+  if (LHSSrc->getType() != RHSSrc->getType())
+    return false;
+  if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
+    return false;
+
+  // Only handle vector types
+  auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
+  auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
+  if (!SrcVecTy || !DstVecTy)
+    return false;
+
+  // Same total bit width
+  assert(SrcVecTy->getPrimitiveSizeInBits() ==
+             DstVecTy->getPrimitiveSizeInBits() &&
+         "Bitcast should preserve total bit width");
+
+  // Cost Check :
+  // OldCost = bitlogic + 2*bitcasts
+  // NewCost = bitlogic + bitcast
+  auto *BinOp = cast<BinaryOperator>(&I);
+  InstructionCost OldCost =
+      TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
+      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
+                           TTI::CastContextHint::None) +
+      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
+                           TTI::CastContextHint::None);
+  InstructionCost NewCost =
+      TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
+      TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
+                           TTI::CastContextHint::None);
+
+  LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
+                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
+                    << "\n");
+
+  if (NewCost > OldCost)
+    return false;
+
+  // Create the operation on the source type
+  Value *NewOp = Builder.CreateBinOp(BinOp->getOpcode(), LHSSrc, RHSSrc,
+                                     BinOp->getName() + ".inner");
+  if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
+    NewBinOp->copyIRFlags(BinOp);
+
+  Worklist.pushValue(NewOp);
+
+  // Bitcast the result back
+  Value *Result = Builder.CreateBitCast(NewOp, I.getType());
+  replaceValue(I, *Result);
+  return true;
+}
+
 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the
 /// destination type followed by shuffle. This can enable further transforms by
 /// moving bitcasts or shuffles together.
@@ -3629,6 +3690,11 @@ bool VectorCombine::run() {
     case Instruction::BitCast:
       MadeChange |= foldBitcastShuffle(I);
       break;
+    case Instruction::And:
+    case Instruction::Or:
+    case Instruction::Xor:
+      MadeChange |= foldBitOpOfBitcasts(I);
+      break;
     default:
       MadeChange |= shrinkType(I);
       break;
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll b/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll
index 22e4239009dd2..daf4a7b799dd4 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll
@@ -477,30 +477,22 @@ define <2 x i64> @PR66513(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c, <2 x i64> %s
 ; CHECK-LABEL: @PR66513(
 ; CHECK-NEXT:    [[I:%.*]] = bitcast <2 x i64> [[A:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I23:%.*]] = icmp sgt <4 x i32> [[I]], zeroinitializer
-; CHECK-NEXT:    [[SEXT_I24:%.*]] = sext <4 x i1> [[CMP_I23]] to <4 x i32>
-; CHECK-NEXT:    [[I1:%.*]] = bitcast <4 x i32> [[SEXT_I24]] to <2 x i64>
 ; CHECK-NEXT:    [[I2:%.*]] = bitcast <2 x i64> [[B:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I21:%.*]] = icmp sgt <4 x i32> [[I2]], zeroinitializer
-; CHECK-NEXT:    [[SEXT_I22:%.*]] = sext <4 x i1> [[CMP_I21]] to <4 x i32>
-; CHECK-NEXT:    [[I3:%.*]] = bitcast <4 x i32> [[SEXT_I22]] to <2 x i64>
 ; CHECK-NEXT:    [[I4:%.*]] = bitcast <2 x i64> [[C:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I:%.*]] = icmp sgt <4 x i32> [[I4]], zeroinitializer
-; CHECK-NEXT:    [[SEXT_I:%.*]] = sext <4 x i1> [[CMP_I]] to <4 x i32>
+; CHECK-NEXT:    [[NARROW:%.*]] = select <4 x i1> [[CMP_I21]], <4 x i1> [[CMP_I23]], <4 x i1> zeroinitializer
+; CHECK-NEXT:    [[XOR_I_INNER1:%.*]] = xor <4 x i1> [[NARROW]], [[CMP_I]]
+; CHECK-NEXT:    [[NARROW3:%.*]] = select <4 x i1> [[CMP_I23]], <4 x i1> [[XOR_I_INNER1]], <4 x i1> zeroinitializer
+; CHECK-NEXT:    [[AND_I25_INNER2:%.*]] = and <4 x i1> [[XOR_I_INNER1]], [[CMP_I21]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i64> [[SRC:%.*]] to <4 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = select <4 x i1> [[NARROW]], <4 x i32> [[TMP1]], <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <2 x i64> [[A]] to <4 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <4 x i1> [[NARROW3]], <4 x i32> [[TMP3]], <4 x i32> [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <2 x i64> [[B]] to <4 x i32>
+; CHECK-NEXT:    [[SEXT_I:%.*]] = select <4 x i1> [[AND_I25_INNER2]], <4 x i32> [[TMP5]], <4 x i32> [[TMP4]]
 ; CHECK-NEXT:    [[I5:%.*]] = bitcast <4 x i32> [[SEXT_I]] to <2 x i64>
-; CHECK-NEXT:    [[AND_I27:%.*]] = and <2 x i64> [[I3]], [[I1]]
-; CHECK-NEXT:    [[XOR_I:%.*]] = xor <2 x i64> [[AND_I27]], [[I5]]
-; CHECK-NEXT:    [[AND_I26:%.*]] = and <2 x i64> [[XOR_I]], [[I1]]
-; CHECK-NEXT:    [[AND_I25:%.*]] = and <2 x i64> [[XOR_I]], [[I3]]
-; CHECK-NEXT:    [[AND_I:%.*]] = and <2 x i64> [[AND_I27]], [[SRC:%.*]]
-; CHECK-NEXT:    [[I6:%.*]] = bitcast <2 x i64> [[AND_I]] to <16 x i8>
-; CHECK-NEXT:    [[I7:%.*]] = bitcast <2 x i64> [[A]] to <16 x i8>
-; CHECK-NEXT:    [[I8:%.*]] = bitcast <2 x i64> [[AND_I26]] to <16 x i8>
-; CHECK-NEXT:    [[I9:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[I6]], <16 x i8> [[I7]], <16 x i8> [[I8]])
-; CHECK-NEXT:    [[I12:%.*]] = bitcast <2 x i64> [[B]] to <16 x i8>
-; CHECK-NEXT:    [[I13:%.*]] = bitcast <2 x i64> [[AND_I25]] to <16 x i8>
-; CHECK-NEXT:    [[I14:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[I9]], <16 x i8> [[I12]], <16 x i8> [[I13]])
-; CHECK-NEXT:    [[I15:%.*]] = bitcast <16 x i8> [[I14]] to <2 x i64>
-; CHECK-NEXT:    ret <2 x i64> [[I15]]
+; CHECK-NEXT:    ret <2 x i64> [[I5]]
 ;
   %i = bitcast <2 x i64> %a to <4 x i32>
   %cmp.i23 = icmp sgt <4 x i32> %i, zeroinitializer
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
index 3c672efbb5a07..761ad80d560e8 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -7,9 +7,8 @@ define i32 @test_and(<16 x i32> %a, ptr %b) {
 ; CHECK-LABEL: @test_and(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <16 x i32> [[TMP0]], [[A:%.*]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
 ; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
@@ -26,9 +25,8 @@ define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
 ; CHECK-NEXT:    [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], splat (i32 16)
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
-; CHECK-NEXT:    [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = or <16 x i32> [[TMP0]], [[A_MASKED]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
 ; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
@@ -47,15 +45,13 @@ define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
 ; CHECK-NEXT:    [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], splat (i32 255)
 ; CHECK-NEXT:    [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], splat (i32 255)
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
-; CHECK-NEXT:    [[TMP0:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], splat (i8 4)
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
-; CHECK-NEXT:    [[TMP2:%.*]] = or <16 x i8> [[TMP0]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = zext <16 x i8> [[TMP2]] to <16 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = and <16 x i8> [[WIDE_LOAD]], splat (i8 15)
-; CHECK-NEXT:    [[TMP5:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
-; CHECK-NEXT:    [[TMP6:%.*]] = or <16 x i8> [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], splat (i8 4)
 ; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
-; CHECK-NEXT:    [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP3]], [[TMP7]]
+; CHECK-NEXT:    [[TMP3:%.*]] = or <16 x i32> [[TMP7]], [[V_MASKED]]
+; CHECK-NEXT:    [[TMP4:%.*]] = and <16 x i32> [[TMP0]], splat (i32 15)
+; CHECK-NEXT:    [[TMP5:%.*]] = or <16 x i32> [[TMP4]], [[U_MASKED]]
+; CHECK-NEXT:    [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP3]], [[TMP5]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
 ; CHECK-NEXT:    ret i32 [[TMP9]]
 ;
@@ -81,9 +77,8 @@ define i32 @phi_bug(<16 x i32> %a, ptr %b) {
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[A_PHI:%.*]] = phi <16 x i32> [ [[A:%.*]], [[ENTRY:%.*]] ]
 ; CHECK-NEXT:    [[WIDE_LOAD_PHI:%.*]] = phi <16 x i8> [ [[WIDE_LOAD]], [[ENTRY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc <16 x i32> [[A_PHI]] to <16 x i8>
-; CHECK-NEXT:    [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD_PHI]], [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD_PHI]] to <16 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <16 x i32> [[TMP0]], [[A_PHI]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
 ; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
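
For context, a minimal sketch (not part of the diff, with hypothetical value names %x, %y, %r) of the rewrite foldBitOpOfBitcasts performs. Given IR such as

  %x.bc = bitcast <2 x i64> %x to <4 x i32>
  %y.bc = bitcast <2 x i64> %y to <4 x i32>
  %r = and <4 x i32> %x.bc, %y.bc

the fold is expected to create the logic op on the source type and a single bitcast of the result (mirroring the ".inner" name suffix used by the patch), provided the TTI cost check does not report the new sequence as more expensive:

  %r.inner = and <2 x i64> %x, %y
  %r = bitcast <2 x i64> %r.inner to <4 x i32>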