diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 01118beb9cf5e..f848bc855ee15 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -618,6 +618,7 @@ namespace {
     bool onlyUsesZeroFlag(SDValue Flags) const;
     bool hasNoSignFlagUses(SDValue Flags) const;
     bool hasNoCarryFlagUses(SDValue Flags) const;
+    bool checkTCRetRegUsage(SDNode *N, LoadSDNode *Load) const;
   };
 
   class X86DAGToDAGISelLegacy : public SelectionDAGISelLegacy {
@@ -890,6 +891,12 @@ static bool isCalleeLoad(SDValue Callee, SDValue &Chain, bool HasCallSeq) {
       LD->getExtensionType() != ISD::NON_EXTLOAD)
     return false;
 
+  // If the load's outgoing chain has more than one use, we can't (currently)
+  // move the load since we'd most likely create a loop. TODO: Maybe it could
+  // work if moveBelowOrigChain() updated *all* the chain users.
+  if (!Callee.getValue(1).hasOneUse())
+    return false;
+
   // Now let's find the callseq_start.
   while (HasCallSeq && Chain.getOpcode() != ISD::CALLSEQ_START) {
     if (!Chain.hasOneUse())
@@ -897,20 +904,39 @@ static bool isCalleeLoad(SDValue Callee, SDValue &Chain, bool HasCallSeq) {
     Chain = Chain.getOperand(0);
   }
 
-  if (!Chain.getNumOperands())
-    return false;
-  // Since we are not checking for AA here, conservatively abort if the chain
-  // writes to memory. It's not safe to move the callee (a load) across a store.
-  if (isa<MemSDNode>(Chain.getNode()) &&
-      cast<MemSDNode>(Chain.getNode())->writeMem())
+  while (true) {
+    if (!Chain.getNumOperands())
+      return false;
+
+    // It's not safe to move the callee (a load) across e.g. a store.
+    // Conservatively abort if the chain contains a node other than the ones
+    // below.
+    switch (Chain.getNode()->getOpcode()) {
+    case ISD::CALLSEQ_START:
+    case ISD::CopyToReg:
+    case ISD::LOAD:
+      break;
+    default:
+      return false;
+    }
+
+    if (Chain.getOperand(0).getNode() == Callee.getNode())
+      return true;
+    if (Chain.getOperand(0).getOpcode() == ISD::TokenFactor &&
+        Chain.getOperand(0).getValue(0).hasOneUse() &&
+        Callee.getValue(1).isOperandOf(Chain.getOperand(0).getNode()) &&
+        Callee.getValue(1).hasOneUse())
+      return true;
+
+    // Look past CopyToRegs. We only walk one path, so the chain mustn't branch.
+    if (Chain.getOperand(0).getOpcode() == ISD::CopyToReg &&
+        Chain.getOperand(0).getValue(0).hasOneUse()) {
+      Chain = Chain.getOperand(0);
+      continue;
+    }
+
     return false;
-  if (Chain.getOperand(0).getNode() == Callee.getNode())
-    return true;
-  if (Chain.getOperand(0).getOpcode() == ISD::TokenFactor &&
-      Callee.getValue(1).isOperandOf(Chain.getOperand(0).getNode()) &&
-      Callee.getValue(1).hasOneUse())
-    return true;
-  return false;
+  }
 }
 
 static bool isEndbrImm64(uint64_t Imm) {
@@ -1353,6 +1379,11 @@ void X86DAGToDAGISel::PreprocessISelDAG() {
         (N->getOpcode() == X86ISD::TC_RETURN &&
          (Subtarget->is64Bit() ||
           !getTargetMachine().isPositionIndependent())))) {
+
+      if (N->getOpcode() == X86ISD::TC_RETURN &&
+          !checkTCRetRegUsage(N, nullptr))
+        continue;
+
       /// Also try moving call address load from outside callseq_start to just
       /// before the call to allow it to be folded.
       ///
@@ -3489,6 +3520,47 @@ static bool mayUseCarryFlag(X86::CondCode CC) {
   return true;
 }
 
+bool X86DAGToDAGISel::checkTCRetRegUsage(SDNode *N, LoadSDNode *Load) const {
+  const X86RegisterInfo *RI = Subtarget->getRegisterInfo();
+  const TargetRegisterClass *TailCallGPRs = RI->getGPRsForTailCall(*MF);
+  unsigned MaxGPRs = TailCallGPRs->getNumRegs();
+  if (Subtarget->is64Bit()) {
+    assert(TailCallGPRs->contains(X86::RSP));
+    assert(TailCallGPRs->contains(X86::RIP));
+    MaxGPRs -= 2; // Can't use RSP or RIP for the address in general.
+  } else {
+    assert(TailCallGPRs->contains(X86::ESP));
+    MaxGPRs -= 1; // Can't use ESP for the address in general.
+  }
+
+  // The load's base and index potentially need two registers.
+  unsigned LoadGPRs = 2;
+
+  if (Load) {
+    // But not if it's loading from a frame slot or global.
+    // XXX: Couldn't we be indexing off of the global though?
+    const SDValue &BasePtr = Load->getBasePtr();
+    if (isa<FrameIndexSDNode>(BasePtr)) {
+      LoadGPRs = 0;
+    } else if (BasePtr->getNumOperands() &&
+               isa<GlobalAddressSDNode>(BasePtr->getOperand(0)))
+      LoadGPRs = 0;
+  }
+
+  unsigned TCGPRs = 0;
+  // X86tcret args: (*chain, ptr, imm, regs..., glue)
+  for (unsigned I = 3, E = N->getNumOperands(); I != E; ++I) {
+    if (const auto *RN = dyn_cast<RegisterSDNode>(N->getOperand(I))) {
+      if (!RI->isGeneralPurposeRegister(*MF, RN->getReg()))
+        continue;
+      if (++TCGPRs + LoadGPRs > MaxGPRs)
+        return false;
+    }
+  }
+
+  return true;
+}
+
 /// Check whether or not the chain ending in StoreNode is suitable for doing
 /// the {load; op; store} to modify transformation.
 static bool isFusableLoadOpStorePattern(StoreSDNode *StoreNode,
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index f9d70d1bb5d85..b8573662b1bcd 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -675,27 +675,12 @@ def X86lock_sub_nocf : PatFrag<(ops node:$lhs, node:$rhs),
 
 def X86tcret_6regs : PatFrag<(ops node:$ptr, node:$off),
                              (X86tcret node:$ptr, node:$off), [{
-  // X86tcret args: (*chain, ptr, imm, regs..., glue)
-  unsigned NumRegs = 0;
-  for (unsigned i = 3, e = N->getNumOperands(); i != e; ++i)
-    if (isa<RegisterSDNode>(N->getOperand(i)) && ++NumRegs > 6)
-      return false;
-  return true;
+  return checkTCRetRegUsage(N, nullptr);
 }]>;
 
 def X86tcret_1reg : PatFrag<(ops node:$ptr, node:$off),
                             (X86tcret node:$ptr, node:$off), [{
-  // X86tcret args: (*chain, ptr, imm, regs..., glue)
-  unsigned NumRegs = 1;
-  const SDValue& BasePtr = cast<LoadSDNode>(N->getOperand(1))->getBasePtr();
-  if (isa<FrameIndexSDNode>(BasePtr))
-    NumRegs = 3;
-  else if (BasePtr->getNumOperands() && isa<GlobalAddressSDNode>(BasePtr->getOperand(0)))
-    NumRegs = 3;
-  for (unsigned i = 3, e = N->getNumOperands(); i != e; ++i)
-    if (isa<RegisterSDNode>(N->getOperand(i)) && ( NumRegs-- == 0))
-      return false;
-  return true;
+  return checkTCRetRegUsage(N, cast<LoadSDNode>(N->getOperand(1)));
 }]>;
 
 // If this is an anyext of the remainder of an 8-bit sdivrem, use a MOVSX
diff --git a/llvm/test/CodeGen/X86/cfguard-checks.ll b/llvm/test/CodeGen/X86/cfguard-checks.ll
index a727bbbfdcbe3..db19efaf910a3 100644
--- a/llvm/test/CodeGen/X86/cfguard-checks.ll
+++ b/llvm/test/CodeGen/X86/cfguard-checks.ll
@@ -210,8 +210,7 @@ entry:
   ; X64-LABEL: vmptr_thunk:
   ; X64: movq (%rcx), %rax
   ; X64-NEXT: movq 8(%rax), %rax
-  ; X64-NEXT: movq __guard_dispatch_icall_fptr(%rip), %rdx
-  ; X64-NEXT: rex64 jmpq *%rdx # TAILCALL
+  ; X64-NEXT: rex64 jmpq *__guard_dispatch_icall_fptr(%rip) # TAILCALL
   ; X64-NOT: callq
 }
 
diff --git a/llvm/test/CodeGen/X86/fold-call-4.ll b/llvm/test/CodeGen/X86/fold-call-4.ll
new file mode 100644
index 0000000000000..2c99f2cb62641
--- /dev/null
+++ b/llvm/test/CodeGen/X86/fold-call-4.ll
@@ -0,0 +1,26 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu | FileCheck %s --check-prefixes=CHECK,LIN
+; RUN: llc < %s -mtriple=x86_64-pc-windows-msvc | FileCheck %s --check-prefixes=CHECK,WIN
+
+; The callee address computation should get folded into the call.
+; CHECK-LABEL: f:
+; CHECK-NOT: mov
+; LIN: jmpq *(%rdi,%rsi,8)
+; WIN: rex64 jmpq *(%rcx,%rdx,8)
+define void @f(ptr %table, i64 %idx, i64 %aux1, i64 %aux2, i64 %aux3) {
+entry:
+  %arrayidx = getelementptr inbounds ptr, ptr %table, i64 %idx
+  %funcptr = load ptr, ptr %arrayidx, align 8
+  tail call void %funcptr(ptr %table, i64 %idx, i64 %aux1, i64 %aux2, i64 %aux3)
+  ret void
+}
+
+; Check that we don't assert here. On Win64 this has a TokenFactor with
+; multiple uses, which we can't currently fold.
+define void @thunk(ptr %this, ...) {
+entry:
+  %vtable = load ptr, ptr %this, align 8
+  %vfn = getelementptr inbounds nuw i8, ptr %vtable, i64 8
+  %0 = load ptr, ptr %vfn, align 8
+  musttail call void (ptr, ...) %0(ptr %this, ...)
+  ret void
+}
diff --git a/llvm/test/CodeGen/X86/fold-call.ll b/llvm/test/CodeGen/X86/fold-call.ll
index 8be817618cd92..25b4df778768f 100644
--- a/llvm/test/CodeGen/X86/fold-call.ll
+++ b/llvm/test/CodeGen/X86/fold-call.ll
@@ -24,3 +24,15 @@ entry:
   tail call void %0()
   ret void
 }
+
+; Don't fold the load+call if there's inline asm in between.
+; CHECK: test3
+; CHECK: mov{{.*}}
+; CHECK: jmp{{.*}}
+define void @test3(ptr nocapture %x) {
+entry:
+  %0 = load ptr, ptr %x
+  call void asm sideeffect "", ""() ; It could do anything.
+  tail call void %0()
+  ret void
+}
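
Not part of the patch: below is a rough, self-contained sketch of the register-budget arithmetic that checkTCRetRegUsage performs, for following the counting without the SelectionDAG types. The function fitsTailCallRegBudget and its parameters are invented for illustration; the real code queries getGPRsForTailCall(), walks the X86tcret operands, and uses isGeneralPurposeRegister() as shown in the diff above.

#include <cassert>
#include <cstdio>

// Illustrative stand-in for checkTCRetRegUsage's arithmetic; the names here
// are made up for the sketch and are not LLVM API. numTailCallGPRs plays the
// role of the tail-call register class size, tcretGPRArgs is the number of
// GPR operands on the X86tcret node, and loadNeedsBaseIndex says whether a
// folded callee load needs base+index registers (frame-slot or global bases
// need none).
static bool fitsTailCallRegBudget(unsigned numTailCallGPRs, bool is64Bit,
                                  bool loadNeedsBaseIndex,
                                  unsigned tcretGPRArgs) {
  // RSP and RIP (or just ESP in 32-bit mode) can't generally carry the call
  // target address, so they drop out of the budget.
  assert(numTailCallGPRs > (is64Bit ? 2u : 1u));
  unsigned maxGPRs = numTailCallGPRs - (is64Bit ? 2 : 1);
  // A folded load may consume up to two extra registers (base and index).
  unsigned loadGPRs = loadNeedsBaseIndex ? 2 : 0;
  return tcretGPRArgs + loadGPRs <= maxGPRs;
}

int main() {
  // With a hypothetical budget of 9 tail-call GPRs in 64-bit mode, a call
  // passing 6 arguments in GPRs can't also fold a base+index load: 6+2 > 9-2.
  std::printf("%d\n", fitsTailCallRegBudget(9, true, true, 6));
}

The numbers above are only an example of the inequality; the actual budget comes from the target's tail-call register class at compile time.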