@@ -35,15 +35,14 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
3535 // Whether this is assigning args for a return.
3636 bool IsRet;
3737
38- RVVArgDispatcher &RVVDispatcher;
38+ // true if assignArg has been called for a mask argument, false otherwise.
39+ bool AssignedFirstMaskArg = false ;
3940
4041public:
4142 RISCVOutgoingValueAssigner (
42- RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
43- RVVArgDispatcher &RVVDispatcher)
43+ RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
4444 : CallLowering::OutgoingValueAssigner(nullptr ),
45- RISCVAssignFn (RISCVAssignFn_), IsRet(IsRet),
46- RVVDispatcher(RVVDispatcher) {}
45+ RISCVAssignFn (RISCVAssignFn_), IsRet(IsRet) {}
4746
4847 bool assignArg (unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
4948 CCValAssign::LocInfo LocInfo,
@@ -53,9 +52,16 @@ struct RISCVOutgoingValueAssigner : public CallLowering::OutgoingValueAssigner {
5352 const DataLayout &DL = MF.getDataLayout ();
5453 const RISCVSubtarget &Subtarget = MF.getSubtarget <RISCVSubtarget>();
5554
55+ std::optional<unsigned > FirstMaskArgument;
56+ if (Subtarget.hasVInstructions () && !AssignedFirstMaskArg &&
57+ ValVT.isVector () && ValVT.getVectorElementType () == MVT::i1) {
58+ FirstMaskArgument = ValNo;
59+ AssignedFirstMaskArg = true ;
60+ }
61+
5662 if (RISCVAssignFn (DL, Subtarget.getTargetABI (), ValNo, ValVT, LocVT,
5763 LocInfo, Flags, State, Info.IsFixed , IsRet, Info.Ty ,
58- *Subtarget.getTargetLowering (), RVVDispatcher ))
64+ *Subtarget.getTargetLowering (), FirstMaskArgument ))
5965 return true ;
6066
6167 StackSize = State.getStackSize ();
@@ -181,15 +187,14 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
181187 // Whether this is assigning args from a return.
182188 bool IsRet;
183189
184- RVVArgDispatcher &RVVDispatcher;
190+ // true if assignArg has been called for a mask argument, false otherwise.
191+ bool AssignedFirstMaskArg = false ;
185192
186193public:
187194 RISCVIncomingValueAssigner (
188- RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet,
189- RVVArgDispatcher &RVVDispatcher)
195+ RISCVTargetLowering::RISCVCCAssignFn *RISCVAssignFn_, bool IsRet)
190196 : CallLowering::IncomingValueAssigner(nullptr ),
191- RISCVAssignFn (RISCVAssignFn_), IsRet(IsRet),
192- RVVDispatcher(RVVDispatcher) {}
197+ RISCVAssignFn (RISCVAssignFn_), IsRet(IsRet) {}
193198
194199 bool assignArg (unsigned ValNo, EVT OrigVT, MVT ValVT, MVT LocVT,
195200 CCValAssign::LocInfo LocInfo,
@@ -202,9 +207,16 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
202207 if (LocVT.isScalableVector ())
203208 MF.getInfo <RISCVMachineFunctionInfo>()->setIsVectorCall ();
204209
210+ std::optional<unsigned > FirstMaskArgument;
211+ if (Subtarget.hasVInstructions () && !AssignedFirstMaskArg &&
212+ ValVT.isVector () && ValVT.getVectorElementType () == MVT::i1) {
213+ FirstMaskArgument = ValNo;
214+ AssignedFirstMaskArg = true ;
215+ }
216+
205217 if (RISCVAssignFn (DL, Subtarget.getTargetABI (), ValNo, ValVT, LocVT,
206218 LocInfo, Flags, State, /* IsFixed=*/ true , IsRet, Info.Ty ,
207- *Subtarget.getTargetLowering (), RVVDispatcher ))
219+ *Subtarget.getTargetLowering (), FirstMaskArgument ))
208220 return true ;
209221
210222 StackSize = State.getStackSize ();
@@ -409,11 +421,9 @@ bool RISCVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
409421 SmallVector<ArgInfo, 4 > SplitRetInfos;
410422 splitToValueTypes (OrigRetInfo, SplitRetInfos, DL, CC);
411423
412- RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
413- ArrayRef (F.getReturnType ())};
414424 RISCVOutgoingValueAssigner Assigner (
415425 CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
416- /* IsRet=*/ true , Dispatcher );
426+ /* IsRet=*/ true );
417427 RISCVOutgoingValueHandler Handler (MIRBuilder, MF.getRegInfo (), Ret);
418428 if (!determineAndHandleAssignments (Handler, Assigner, SplitRetInfos,
419429 MIRBuilder, CC, F.isVarArg ()))
@@ -433,16 +443,24 @@ bool RISCVCallLowering::canLowerReturn(MachineFunction &MF,
433443 CCState CCInfo (CallConv, IsVarArg, MF, ArgLocs,
434444 MF.getFunction ().getContext ());
435445
436- RVVArgDispatcher Dispatcher{&MF, &TLI,
437- ArrayRef (MF.getFunction ().getReturnType ())};
438-
439446 RISCVABI::ABI ABI = MF.getSubtarget <RISCVSubtarget>().getTargetABI ();
447+ const RISCVSubtarget &Subtarget = MF.getSubtarget <RISCVSubtarget>();
448+
449+ std::optional<unsigned > FirstMaskArgument = std::nullopt ;
450+ // Preassign the first mask argument.
451+ if (Subtarget.hasVInstructions ()) {
452+ for (const auto &ArgIdx : enumerate(Outs)) {
453+ MVT ArgVT = MVT::getVT (ArgIdx.value ().Ty );
454+ if (ArgVT.isVector () && ArgVT.getVectorElementType () == MVT::i1)
455+ FirstMaskArgument = ArgIdx.index ();
456+ }
457+ }
440458
441459 for (unsigned I = 0 , E = Outs.size (); I < E; ++I) {
442460 MVT VT = MVT::getVT (Outs[I].Ty );
443461 if (RISCV::CC_RISCV (MF.getDataLayout (), ABI, I, VT, VT, CCValAssign::Full,
444462 Outs[I].Flags [0 ], CCInfo, /* IsFixed=*/ true ,
445- /* isRet=*/ true , nullptr , TLI, Dispatcher ))
463+ /* isRet=*/ true , nullptr , TLI, FirstMaskArgument ))
446464 return false ;
447465 }
448466 return true ;
@@ -552,16 +570,12 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
552570 // correspondingly and appended to SplitArgInfos.
553571 splitToValueTypes (AInfo, SplitArgInfos, DL, CC);
554572
555- TypeList.push_back (Arg.getType ());
556-
557573 ++Index;
558574 }
559575
560- RVVArgDispatcher Dispatcher{&MF, getTLI<RISCVTargetLowering>(),
561- ArrayRef (TypeList)};
562576 RISCVIncomingValueAssigner Assigner (
563577 CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
564- /* IsRet=*/ false , Dispatcher );
578+ /* IsRet=*/ false );
565579 RISCVFormalArgHandler Handler (MIRBuilder, MF.getRegInfo ());
566580
567581 SmallVector<CCValAssign, 16 > ArgLocs;
@@ -599,13 +613,11 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
599613
600614 SmallVector<ArgInfo, 32 > SplitArgInfos;
601615 SmallVector<ISD::OutputArg, 8 > Outs;
602- SmallVector<Type *, 4 > TypeList;
603616 for (auto &AInfo : Info.OrigArgs ) {
604617 // Handle any required unmerging of split value types from a given VReg into
605618 // physical registers. ArgInfo objects are constructed correspondingly and
606619 // appended to SplitArgInfos.
607620 splitToValueTypes (AInfo, SplitArgInfos, DL, CC);
608- TypeList.push_back (AInfo.Ty );
609621 }
610622
611623 // TODO: Support tail calls.
@@ -623,11 +635,9 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
623635 const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo ();
624636 Call.addRegMask (TRI->getCallPreservedMask (MF, Info.CallConv ));
625637
626- RVVArgDispatcher ArgDispatcher{&MF, getTLI<RISCVTargetLowering>(),
627- ArrayRef (TypeList)};
628638 RISCVOutgoingValueAssigner ArgAssigner (
629639 CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
630- /* IsRet=*/ false , ArgDispatcher );
640+ /* IsRet=*/ false );
631641 RISCVOutgoingValueHandler ArgHandler (MIRBuilder, MF.getRegInfo (), Call);
632642 if (!determineAndHandleAssignments (ArgHandler, ArgAssigner, SplitArgInfos,
633643 MIRBuilder, CC, Info.IsVarArg ))
@@ -653,11 +663,9 @@ bool RISCVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
653663 SmallVector<ArgInfo, 4 > SplitRetInfos;
654664 splitToValueTypes (Info.OrigRet , SplitRetInfos, DL, CC);
655665
656- RVVArgDispatcher RetDispatcher{&MF, getTLI<RISCVTargetLowering>(),
657- ArrayRef (F.getReturnType ())};
658666 RISCVIncomingValueAssigner RetAssigner (
659667 CC == CallingConv::Fast ? RISCV::CC_RISCV_FastCC : RISCV::CC_RISCV,
660- /* IsRet=*/ true , RetDispatcher );
668+ /* IsRet=*/ true );
661669 RISCVCallReturnHandler RetHandler (MIRBuilder, MF.getRegInfo (), Call);
662670 if (!determineAndHandleAssignments (RetHandler, RetAssigner, SplitRetInfos,
663671 MIRBuilder, CC, Info.IsVarArg ))
0 commit comments