@@ -702,56 +702,57 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // intrinsics.
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
 
-  // FP extload/truncstore is not legal in PTX. We need to expand all these.
-  for (auto FloatVTs :
-       {MVT::fp_valuetypes(), MVT::fp_fixedlen_vector_valuetypes()}) {
-    for (MVT ValVT : FloatVTs) {
-      for (MVT MemVT : FloatVTs) {
-        setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Expand);
-        setTruncStoreAction(ValVT, MemVT, Expand);
-      }
-    }
-  }
-
-  // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
-  // how they'll be lowered in ISel anyway, and by doing this a little earlier
-  // we allow for more DAG combine opportunities.
-  for (auto IntVTs :
-       {MVT::integer_valuetypes(), MVT::integer_fixedlen_vector_valuetypes()})
-    for (MVT ValVT : IntVTs)
-      for (MVT MemVT : IntVTs)
-        if (isTypeLegal(ValVT))
-          setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Custom);
+  // Turn FP extload into load/fpextend
+  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
+  // Turn FP truncstore into trunc + store.
+  // FIXME: vector types should also be expanded
+  setTruncStoreAction(MVT::f32, MVT::f16, Expand);
+  setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+  setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
+  setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
+  setTruncStoreAction(MVT::f64, MVT::f32, Expand);
+  setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand);
+  setTruncStoreAction(MVT::v2f32, MVT::v2bf16, Expand);
 
   // PTX does not support load / store predicate registers
-  setOperationAction({ISD::LOAD, ISD::STORE}, MVT::i1, Custom);
+  setOperationAction(ISD::LOAD, MVT::i1, Custom);
+  setOperationAction(ISD::STORE, MVT::i1, Custom);
+
   for (MVT VT : MVT::integer_valuetypes()) {
-    setLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, MVT::i1,
-                     Promote);
+    setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Promote);
+    setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::i1, Promote);
+    setLoadExtAction(ISD::EXTLOAD, VT, MVT::i1, Promote);
     setTruncStoreAction(VT, MVT::i1, Expand);
   }
 
-  // Register custom handling for illegal type loads/stores. We'll try to custom
-  // lower almost all illegal types and logic in the lowering will discard cases
-  // we can't handle.
-  setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
-  for (MVT VT : MVT::fixedlen_vector_valuetypes())
-    if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
-      setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
-
-  // Custom legalization for LDU intrinsics.
-  // TODO: The logic to lower these is not very robust and we should rewrite it.
-  // Perhaps LDU should not be represented as an intrinsic at all.
-  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
-  for (MVT VT : MVT::fixedlen_vector_valuetypes())
-    if (IsPTXVectorType(VT))
-      setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
-
   setCondCodeAction({ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
                      ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
                      ISD::SETGE, ISD::SETLE},
                     MVT::i1, Expand);
 
+  // expand extload of vector of integers.
+  setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16,
+                   MVT::v2i8, Expand);
+  setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
+
   // This is legal in NVPTX
   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
   setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
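
The Expand entries above instruct the generic DAG legalizer to rewrite an FP
extending load as a plain load of the narrow type followed by FP_EXTEND. A
minimal sketch of that rewrite, using a hypothetical helper name (illustration
only, not code from this commit):

    // Sketch: what Expand means for an f32 <- f16 EXTLOAD. The node is
    // replaced by a plain f16 load plus an FP_EXTEND of its result.
    static SDValue expandFPExtLoad(LoadSDNode *LD, SelectionDAG &DAG) {
      SDLoc DL(LD);
      SDValue Plain = DAG.getLoad(LD->getMemoryVT(), DL, LD->getChain(),
                                  LD->getBasePtr(), LD->getPointerInfo());
      SDValue Ext = DAG.getNode(ISD::FP_EXTEND, DL, LD->getValueType(0), Plain);
      // A load has two results (value, chain), so return both.
      return DAG.getMergeValues({Ext, Plain.getValue(1)}, DL);
    }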
@@ -766,12 +767,24 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // DEBUGTRAP can be lowered to PTX brkpt
   setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
 
+  // Register custom handling for vector loads/stores
+  for (MVT VT : MVT::fixedlen_vector_valuetypes())
+    if (IsPTXVectorType(VT))
+      setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN}, VT,
+                         Custom);
+
+  setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN},
+                     {MVT::i128, MVT::f128}, Custom);
+
   // Support varargs.
   setOperationAction(ISD::VASTART, MVT::Other, Custom);
   setOperationAction(ISD::VAARG, MVT::Other, Custom);
   setOperationAction(ISD::VACOPY, MVT::Other, Expand);
   setOperationAction(ISD::VAEND, MVT::Other, Expand);
 
+  // Custom handling for i8 intrinsics
+  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
+
   setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
                      {MVT::i16, MVT::i32, MVT::i64}, Legal);
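
Marking LOAD, STORE, and INTRINSIC_W_CHAIN as Custom for these types routes the
nodes back into the target through LowerOperation. A sketch of the dispatch,
assuming the usual shape of NVPTXTargetLowering::LowerOperation and abridged to
the cases this hunk registers:

    SDValue NVPTXTargetLowering::LowerOperation(SDValue Op,
                                                SelectionDAG &DAG) const {
      switch (Op.getOpcode()) {
      case ISD::LOAD:
        return LowerLOAD(Op, DAG);  // defined in the hunks below
      case ISD::STORE:
        return LowerSTORE(Op, DAG);
      default:
        llvm_unreachable("Custom lowering not defined for operation");
      }
    }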
@@ -3079,14 +3092,39 @@ static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
                               SmallVectorImpl<SDValue> &Results,
                               const NVPTXSubtarget &STI);
 
+SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
+  if (Op.getValueType() == MVT::i1)
+    return LowerLOADi1(Op, DAG);
+
+  EVT VT = Op.getValueType();
+
+  if (NVPTX::isPackedVectorTy(VT)) {
+    // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
+    // handle unaligned loads and have to handle it here.
+    LoadSDNode *Load = cast<LoadSDNode>(Op);
+    EVT MemVT = Load->getMemoryVT();
+    if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
+                                        MemVT, *Load->getMemOperand())) {
+      SDValue Ops[2];
+      std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
+      return DAG.getMergeValues(Ops, SDLoc(Op));
+    }
+  }
+
+  return SDValue();
+}
+
 // v = ld i1* addr
 // =>
 // v1 = ld i8* addr (-> i16)
 // v = trunc i16 to i1
-static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
-  SDLoc dl(LD);
+SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
+  SDNode *Node = Op.getNode();
+  LoadSDNode *LD = cast<LoadSDNode>(Node);
+  SDLoc dl(Node);
   assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
-  assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
+  assert(Node->getValueType(0) == MVT::i1 &&
+         "Custom lowering for i1 load only");
   SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
                                  LD->getBasePtr(), LD->getPointerInfo(),
                                  MVT::i8, LD->getAlign(),
@@ -3095,27 +3133,8 @@ static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
   // The legalizer (the caller) is expecting two values from the legalized
   // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
   // in LegalizeDAG.cpp which also uses MergeValues.
-  return DAG.getMergeValues({result, LD->getChain()}, dl);
-}
-
-SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
-  LoadSDNode *LD = cast<LoadSDNode>(Op);
-
-  if (Op.getValueType() == MVT::i1)
-    return lowerLOADi1(LD, DAG);
-
-  // To improve CodeGen we'll legalize any-extend loads to zext loads. This is
-  // how they'll be lowered in ISel anyway, and by doing this a little earlier
-  // we allow for more DAG combine opportunities.
-  if (LD->getExtensionType() == ISD::EXTLOAD) {
-    assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
-           "Unexpected fpext-load");
-    return DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Op), Op.getValueType(),
-                          LD->getChain(), LD->getBasePtr(), LD->getMemoryVT(),
-                          LD->getMemOperand());
-  }
-
-  llvm_unreachable("Unexpected custom lowering for load");
+  SDValue Ops[] = { result, LD->getChain() };
+  return DAG.getMergeValues(Ops, dl);
 }
 
 SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
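
LowerLOADi1 widens the i1 load to a zero-extending i8 load and truncates the
result back to i1. Its store counterpart LowerSTOREi1, called at the top of the
next hunk, does the reverse; a sketch assuming it keeps the usual NVPTX shape
(zero-extend the bit, then emit a truncating byte store):

    // st i1 v, addr  =>  v1 = zext v to i16; st.u8 [addr], v1   (sketch)
    SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op,
                                              SelectionDAG &DAG) const {
      StoreSDNode *ST = cast<StoreSDNode>(Op);
      SDLoc dl(Op);
      SDValue Val = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, ST->getValue());
      return DAG.getTruncStore(ST->getChain(), dl, Val, ST->getBasePtr(),
                               ST->getPointerInfo(), MVT::i8, ST->getAlign(),
                               ST->getMemOperand()->getFlags());
    }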
@@ -3125,6 +3144,17 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
   if (VT == MVT::i1)
     return LowerSTOREi1(Op, DAG);
 
+  // v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
+  // handle unaligned stores and have to handle it here.
+  if (NVPTX::isPackedVectorTy(VT) &&
+      !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
+                                      VT, *Store->getMemOperand()))
+    return expandUnalignedStore(Store, DAG);
+
+  // v2f16/v2bf16/v2i16 don't need special handling.
+  if (NVPTX::isPackedVectorTy(VT) && VT.is32BitVector())
+    return SDValue();
+
   // Lower store of any other vector type, including v2f32 as we want to break
   // it apart since this is not a widely-supported type.
   return LowerSTOREVector(Op, DAG);
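
The load and store paths share one guard: ask allowsMemoryAccessForAlignment()
whether the alignment recorded in the MachineMemOperand is acceptable for this
type, and fall back to the generic unaligned expansion if it is not. The same
pattern, factored into a hypothetical standalone helper for illustration:

    static SDValue expandIfUnderAligned(StoreSDNode *Store, SelectionDAG &DAG,
                                        const TargetLowering &TLI) {
      EVT VT = Store->getValue().getValueType();
      // Can the target perform this access at the recorded alignment?
      if (!TLI.allowsMemoryAccessForAlignment(*DAG.getContext(),
                                              DAG.getDataLayout(), VT,
                                              *Store->getMemOperand()))
        return TLI.expandUnalignedStore(Store, DAG); // split into legal pieces
      return SDValue(); // aligned enough; let normal lowering proceed
    }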
@@ -3980,8 +4010,14 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_ldu_global_i:
   case Intrinsic::nvvm_ldu_global_f:
   case Intrinsic::nvvm_ldu_global_p: {
+    auto &DL = I.getDataLayout();
     Info.opc = ISD::INTRINSIC_W_CHAIN;
-    Info.memVT = getValueType(I.getDataLayout(), I.getType());
+    if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
+      Info.memVT = getValueType(DL, I.getType());
+    else if (Intrinsic == Intrinsic::nvvm_ldu_global_p)
+      Info.memVT = getPointerTy(DL);
+    else
+      Info.memVT = getValueType(DL, I.getType());
     Info.ptrVal = I.getArgOperand(0);
     Info.offset = 0;
     Info.flags = MachineMemOperand::MOLoad;
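
getTgtMemIntrinsic() is how ISel learns that an intrinsic touches memory: the
returned IntrinsicInfo seeds the MachineMemOperand attached to the
INTRINSIC_W_CHAIN node, so ldu.global.* is scheduled and alias-analyzed as a
load of Info.memVT (the pointer type for the _p variant, the result type
otherwise). A simplified sketch of the caller side, with TLI, I, and
IntrinsicID standing in for SelectionDAGBuilder state (assumed, not commit
code):

    TargetLowering::IntrinsicInfo Info;
    if (TLI.getTgtMemIntrinsic(Info, I, DAG.getMachineFunction(), IntrinsicID)) {
      // Info.memVT  - width the access is modeled with.
      // Info.ptrVal - IR pointer the access is based on; becomes the
      //               MachinePointerInfo used for alias analysis.
      // Info.flags  - MOLoad here, so the node counts as a memory read.
      // ISel then builds the INTRINSIC_W_CHAIN node with a MachineMemOperand
      // constructed from these fields.
    }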