From 5fb9a0999fee13b412f0c38ab116c2ab393f9039 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Wed, 13 Aug 2025 14:53:04 +0200 Subject: [PATCH] [CodeGen] Make OrigTy in CC lowering the non-aggregate type https://github.com/llvm/llvm-project/pull/152709 exposed the original IR argument type to the CC lowering logic. However, in SDAG, this used the raw type, prior to aggregate splitting. This PR changes it to use the non-aggregate type instead. (This matches what happened in the GlobalISel case already.) I've also added some more detailed documentation on the InputArg/OutputArg fields, to explain how they differ. In most cases ArgVT is going to be the EVT of OrigTy, so they encode very similar information (OrigTy just preserves some additional information lost in EVTs, like pointer types). One case where they do differ is in post-legalization lowering of libcalls, where ArgVT is going to be a legalized type, while OrigTy is going to be the original non-legalized type. --- llvm/include/llvm/CodeGen/Analysis.h | 7 +++ llvm/include/llvm/CodeGen/TargetCallingConv.h | 11 ++++ llvm/lib/CodeGen/Analysis.cpp | 52 +++++++++------- .../SelectionDAG/SelectionDAGBuilder.cpp | 60 ++++++++++--------- llvm/lib/CodeGen/TargetLoweringBase.cpp | 12 ++-- 5 files changed, 86 insertions(+), 56 deletions(-) diff --git a/llvm/include/llvm/CodeGen/Analysis.h b/llvm/include/llvm/CodeGen/Analysis.h index 362cc30bbd06a..98b52579d03b7 100644 --- a/llvm/include/llvm/CodeGen/Analysis.h +++ b/llvm/include/llvm/CodeGen/Analysis.h @@ -55,6 +55,13 @@ inline unsigned ComputeLinearIndex(Type *Ty, return ComputeLinearIndex(Ty, Indices.begin(), Indices.end(), CurIndex); } +/// Given an LLVM IR type, compute non-aggregate subtypes. Optionally also +/// compute their offsets. +void ComputeValueTypes(const DataLayout &DL, Type *Ty, + SmallVectorImpl &Types, + SmallVectorImpl *Offsets = nullptr, + TypeSize StartingOffset = TypeSize::getZero()); + /// ComputeValueVTs - Given an LLVM IR type, compute a sequence of /// EVTs that represent all the individual underlying /// non-aggregate types that comprise it. diff --git a/llvm/include/llvm/CodeGen/TargetCallingConv.h b/llvm/include/llvm/CodeGen/TargetCallingConv.h index aa8af696c6e62..f197c7f1645ec 100644 --- a/llvm/include/llvm/CodeGen/TargetCallingConv.h +++ b/llvm/include/llvm/CodeGen/TargetCallingConv.h @@ -203,8 +203,14 @@ namespace ISD { /// struct InputArg { ArgFlagsTy Flags; + /// Legalized type of this argument part. MVT VT = MVT::Other; + /// Usually the non-legalized type of the argument, which is the EVT + /// corresponding to the OrigTy IR type. However, for post-legalization + /// libcalls, this will be a legalized type. EVT ArgVT; + /// Original IR type of the argument. For aggregates, this is the type of + /// an individual aggregate element, not the whole aggregate. Type *OrigTy; bool Used; @@ -239,8 +245,13 @@ namespace ISD { /// struct OutputArg { ArgFlagsTy Flags; + // Legalized type of this argument part. MVT VT; + /// Non-legalized type of the argument. This is the EVT corresponding to + /// the OrigTy IR type. EVT ArgVT; + /// Original IR type of the argument. For aggregates, this is the type of + /// an individual aggregate element, not the whole aggregate. Type *OrigTy; /// Index original Function's argument. diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp index e7b9417de8c9f..2ef96cc4400f7 100644 --- a/llvm/lib/CodeGen/Analysis.cpp +++ b/llvm/lib/CodeGen/Analysis.cpp @@ -69,18 +69,10 @@ unsigned llvm::ComputeLinearIndex(Type *Ty, return CurIndex + 1; } -/// ComputeValueVTs - Given an LLVM IR type, compute a sequence of -/// EVTs that represent all the individual underlying -/// non-aggregate types that comprise it. -/// -/// If Offsets is non-null, it points to a vector to be filled in -/// with the in-memory offsets of each of the individual values. -/// -void llvm::ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, - Type *Ty, SmallVectorImpl &ValueVTs, - SmallVectorImpl *MemVTs, - SmallVectorImpl *Offsets, - TypeSize StartingOffset) { +void llvm::ComputeValueTypes(const DataLayout &DL, Type *Ty, + SmallVectorImpl &Types, + SmallVectorImpl *Offsets, + TypeSize StartingOffset) { assert((Ty->isScalableTy() == StartingOffset.isScalable() || StartingOffset.isZero()) && "Offset/TypeSize mismatch!"); @@ -90,15 +82,13 @@ void llvm::ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, // us to support structs with scalable vectors for operations that don't // need offsets. const StructLayout *SL = Offsets ? DL.getStructLayout(STy) : nullptr; - for (StructType::element_iterator EB = STy->element_begin(), - EI = EB, + for (StructType::element_iterator EB = STy->element_begin(), EI = EB, EE = STy->element_end(); EI != EE; ++EI) { // Don't compute the element offset if we didn't get a StructLayout above. TypeSize EltOffset = SL ? SL->getElementOffset(EI - EB) : TypeSize::getZero(); - ComputeValueVTs(TLI, DL, *EI, ValueVTs, MemVTs, Offsets, - StartingOffset + EltOffset); + ComputeValueTypes(DL, *EI, Types, Offsets, StartingOffset + EltOffset); } return; } @@ -107,21 +97,39 @@ void llvm::ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *EltTy = ATy->getElementType(); TypeSize EltSize = DL.getTypeAllocSize(EltTy); for (unsigned i = 0, e = ATy->getNumElements(); i != e; ++i) - ComputeValueVTs(TLI, DL, EltTy, ValueVTs, MemVTs, Offsets, - StartingOffset + i * EltSize); + ComputeValueTypes(DL, EltTy, Types, Offsets, + StartingOffset + i * EltSize); return; } // Interpret void as zero return values. if (Ty->isVoidTy()) return; - // Base case: we can get an EVT for this LLVM IR type. - ValueVTs.push_back(TLI.getValueType(DL, Ty)); - if (MemVTs) - MemVTs->push_back(TLI.getMemValueType(DL, Ty)); + Types.push_back(Ty); if (Offsets) Offsets->push_back(StartingOffset); } +/// ComputeValueVTs - Given an LLVM IR type, compute a sequence of +/// EVTs that represent all the individual underlying +/// non-aggregate types that comprise it. +/// +/// If Offsets is non-null, it points to a vector to be filled in +/// with the in-memory offsets of each of the individual values. +/// +void llvm::ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, + Type *Ty, SmallVectorImpl &ValueVTs, + SmallVectorImpl *MemVTs, + SmallVectorImpl *Offsets, + TypeSize StartingOffset) { + SmallVector Types; + ComputeValueTypes(DL, Ty, Types, Offsets, StartingOffset); + for (Type *Ty : Types) { + ValueVTs.push_back(TLI.getValueType(DL, Ty)); + if (MemVTs) + MemVTs->push_back(TLI.getMemValueType(DL, Ty)); + } +} + void llvm::ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, Type *Ty, SmallVectorImpl &ValueVTs, SmallVectorImpl *MemVTs, diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index a71b4409a6b21..366a230eef952 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -2211,9 +2211,9 @@ void SelectionDAGBuilder::visitRet(const ReturnInst &I) { Chain = DAG.getNode(ISD::TokenFactor, getCurSDLoc(), MVT::Other, Chains); } else if (I.getNumOperands() != 0) { - SmallVector ValueVTs; - ComputeValueVTs(TLI, DL, I.getOperand(0)->getType(), ValueVTs); - unsigned NumValues = ValueVTs.size(); + SmallVector Types; + ComputeValueTypes(DL, I.getOperand(0)->getType(), Types); + unsigned NumValues = Types.size(); if (NumValues) { SDValue RetOp = getValue(I.getOperand(0)); @@ -2233,7 +2233,7 @@ void SelectionDAGBuilder::visitRet(const ReturnInst &I) { bool RetInReg = F->getAttributes().hasRetAttr(Attribute::InReg); for (unsigned j = 0; j != NumValues; ++j) { - EVT VT = ValueVTs[j]; + EVT VT = TLI.getValueType(DL, Types[j]); if (ExtendKind != ISD::ANY_EXTEND && VT.isInteger()) VT = TLI.getTypeForExtReturn(Context, VT, ExtendKind); @@ -2275,7 +2275,7 @@ void SelectionDAGBuilder::visitRet(const ReturnInst &I) { for (unsigned i = 0; i < NumParts; ++i) { Outs.push_back(ISD::OutputArg(Flags, Parts[i].getValueType().getSimpleVT(), - VT, I.getOperand(0)->getType(), 0, 0)); + VT, Types[j], 0, 0)); OutVals.push_back(Parts[i]); } } @@ -10983,15 +10983,21 @@ std::pair TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const { // Handle the incoming return values from the call. CLI.Ins.clear(); - SmallVector RetTys; + SmallVector RetOrigTys; SmallVector Offsets; auto &DL = CLI.DAG.getDataLayout(); - ComputeValueVTs(*this, DL, CLI.RetTy, RetTys, &Offsets); + ComputeValueTypes(DL, CLI.RetTy, RetOrigTys, &Offsets); + + SmallVector RetTys; + for (Type *Ty : RetOrigTys) + RetTys.push_back(getValueType(DL, Ty)); if (CLI.IsPostTypeLegalization) { // If we are lowering a libcall after legalization, split the return type. + SmallVector OldRetOrigTys; SmallVector OldRetTys; SmallVector OldOffsets; + RetOrigTys.swap(OldRetOrigTys); RetTys.swap(OldRetTys); Offsets.swap(OldOffsets); @@ -11001,6 +11007,7 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const { MVT RegisterVT = getRegisterType(CLI.RetTy->getContext(), RetVT); unsigned NumRegs = getNumRegisters(CLI.RetTy->getContext(), RetVT); unsigned RegisterVTByteSZ = RegisterVT.getSizeInBits() / 8; + RetOrigTys.append(NumRegs, OldRetOrigTys[i]); RetTys.append(NumRegs, RegisterVT); for (unsigned j = 0; j != NumRegs; ++j) Offsets.push_back(TypeSize::getFixed(Offset + j * RegisterVTByteSZ)); @@ -11069,7 +11076,7 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const { unsigned NumRegs = getNumRegistersForCallingConv(CLI.RetTy->getContext(), CLI.CallConv, VT); for (unsigned i = 0; i != NumRegs; ++i) { - ISD::InputArg Ret(Flags, RegisterVT, VT, CLI.RetTy, + ISD::InputArg Ret(Flags, RegisterVT, VT, RetOrigTys[I], CLI.IsReturnValueUsed, ISD::InputArg::NoArgIndex, 0); if (CLI.RetTy->isPointerTy()) { Ret.Flags.setPointer(); @@ -11106,18 +11113,18 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const { CLI.Outs.clear(); CLI.OutVals.clear(); for (unsigned i = 0, e = Args.size(); i != e; ++i) { - SmallVector ValueVTs; - ComputeValueVTs(*this, DL, Args[i].Ty, ValueVTs); + SmallVector ArgTys; + ComputeValueTypes(DL, Args[i].Ty, ArgTys); // FIXME: Split arguments if CLI.IsPostTypeLegalization Type *FinalType = Args[i].Ty; if (Args[i].IsByVal) FinalType = Args[i].IndirectType; bool NeedsRegBlock = functionArgumentNeedsConsecutiveRegisters( FinalType, CLI.CallConv, CLI.IsVarArg, DL); - for (unsigned Value = 0, NumValues = ValueVTs.size(); Value != NumValues; + for (unsigned Value = 0, NumValues = ArgTys.size(); Value != NumValues; ++Value) { - EVT VT = ValueVTs[Value]; - Type *ArgTy = VT.getTypeForEVT(CLI.RetTy->getContext()); + Type *ArgTy = ArgTys[Value]; + EVT VT = getValueType(DL, ArgTy); SDValue Op = SDValue(Args[i].Node.getNode(), Args[i].Node.getResNo() + Value); ISD::ArgFlagsTy Flags; @@ -11130,10 +11137,9 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const { if (i >= CLI.NumFixedArgs) Flags.setVarArg(); - if (Args[i].Ty->isPointerTy()) { + if (ArgTy->isPointerTy()) { Flags.setPointer(); - Flags.setPointerAddrSpace( - cast(Args[i].Ty)->getAddressSpace()); + Flags.setPointerAddrSpace(cast(ArgTy)->getAddressSpace()); } if (Args[i].IsZExt) Flags.setZExt(); @@ -11252,7 +11258,7 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const { // For scalable vectors the scalable part is currently handled // by individual targets, so we just use the known minimum size here. ISD::OutputArg MyFlags( - Flags, Parts[j].getValueType().getSimpleVT(), VT, Args[i].Ty, i, + Flags, Parts[j].getValueType().getSimpleVT(), VT, ArgTy, i, j * Parts[j].getValueType().getStoreSize().getKnownMinValue()); if (NumParts > 1 && j == 0) MyFlags.Flags.setSplit(); @@ -11645,8 +11651,8 @@ void SelectionDAGISel::LowerArguments(const Function &F) { // Set up the incoming argument description vector. for (const Argument &Arg : F.args()) { unsigned ArgNo = Arg.getArgNo(); - SmallVector ValueVTs; - ComputeValueVTs(*TLI, DAG.getDataLayout(), Arg.getType(), ValueVTs); + SmallVector Types; + ComputeValueTypes(DAG.getDataLayout(), Arg.getType(), Types); bool isArgValueUsed = !Arg.use_empty(); unsigned PartBase = 0; Type *FinalType = Arg.getType(); @@ -11654,17 +11660,15 @@ void SelectionDAGISel::LowerArguments(const Function &F) { FinalType = Arg.getParamByValType(); bool NeedsRegBlock = TLI->functionArgumentNeedsConsecutiveRegisters( FinalType, F.getCallingConv(), F.isVarArg(), DL); - for (unsigned Value = 0, NumValues = ValueVTs.size(); - Value != NumValues; ++Value) { - EVT VT = ValueVTs[Value]; - Type *ArgTy = VT.getTypeForEVT(*DAG.getContext()); + for (unsigned Value = 0, NumValues = Types.size(); Value != NumValues; + ++Value) { + Type *ArgTy = Types[Value]; + EVT VT = TLI->getValueType(DL, ArgTy); ISD::ArgFlagsTy Flags; - - if (Arg.getType()->isPointerTy()) { + if (ArgTy->isPointerTy()) { Flags.setPointer(); - Flags.setPointerAddrSpace( - cast(Arg.getType())->getAddressSpace()); + Flags.setPointerAddrSpace(cast(ArgTy)->getAddressSpace()); } if (Arg.hasAttribute(Attribute::ZExt)) Flags.setZExt(); @@ -11768,7 +11772,7 @@ void SelectionDAGISel::LowerArguments(const Function &F) { // are responsible for handling scalable vector arguments and // return values. ISD::InputArg MyFlags( - Flags, RegisterVT, VT, Arg.getType(), isArgValueUsed, ArgNo, + Flags, RegisterVT, VT, ArgTy, isArgValueUsed, ArgNo, PartBase + i * RegisterVT.getStoreSize().getKnownMinValue()); if (NumRegs > 1 && i == 0) MyFlags.Flags.setSplit(); diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp index 61ff2dfe5be22..350948a92a3ae 100644 --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -1738,13 +1738,13 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType, AttributeList attr, SmallVectorImpl &Outs, const TargetLowering &TLI, const DataLayout &DL) { - SmallVector ValueVTs; - ComputeValueVTs(TLI, DL, ReturnType, ValueVTs); - unsigned NumValues = ValueVTs.size(); + SmallVector Types; + ComputeValueTypes(DL, ReturnType, Types); + unsigned NumValues = Types.size(); if (NumValues == 0) return; - for (unsigned j = 0, f = NumValues; j != f; ++j) { - EVT VT = ValueVTs[j]; + for (Type *Ty : Types) { + EVT VT = TLI.getValueType(DL, Ty); ISD::NodeType ExtendKind = ISD::ANY_EXTEND; if (attr.hasRetAttr(Attribute::SExt)) @@ -1772,7 +1772,7 @@ void llvm::GetReturnInfo(CallingConv::ID CC, Type *ReturnType, Flags.setZExt(); for (unsigned i = 0; i < NumParts; ++i) - Outs.push_back(ISD::OutputArg(Flags, PartVT, VT, ReturnType, 0, 0)); + Outs.push_back(ISD::OutputArg(Flags, PartVT, VT, Ty, 0, 0)); } }