@@ -610,9 +610,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
610610bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT (SDNode *N) {
611611 SDValue Vector = N->getOperand (0 );
612612
613- // We only care about f16x2 as it's the only real vector type we
613+ // We only care about v2x16 types as they're the only real vector types we
614614 // need to deal with.
615- if (Vector.getSimpleValueType () != MVT::v2f16)
615+ MVT VT = Vector.getSimpleValueType ();
616+ if (!Isv2x16VT (VT))
616617 return false ;
617618
618619 // Find and record all uses of this vector that extract element 0 or 1.
@@ -825,6 +826,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
825826 return Opcode_i16;
826827 case MVT::v2f16:
827828 case MVT::v2bf16:
829+ case MVT::v2i16:
828830 return Opcode_i32;
829831 case MVT::f32 :
830832 return Opcode_f32;
@@ -906,9 +908,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
906908 // Vector Setting
907909 unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
908910 if (SimpleVT.isVector ()) {
909- assert ((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
910- " Unexpected vector type" );
911- // v2f16/v2bf16 is loaded using ld.b32
911+ assert (Isv2x16VT (LoadedVT) && " Unexpected vector type" );
912+ // v2f16/v2bf16/v2i16 is loaded using ld.b32
912913 fromTypeWidth = 32 ;
913914 }
914915
@@ -1058,10 +1059,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
10581059
10591060 EVT EltVT = N->getValueType (0 );
10601061
1061- // v8f16 is a special case. PTX doesn't have ld.v8.f16
1062- // instruction. Instead, we split the vector into v2f16 chunks and
1062+ // v8x16 is a special case. PTX doesn't have ld.v8.x16
1063+ // instruction. Instead, we split the vector into v2x16 chunks and
10631064 // load them with ld.v4.b32.
1064- if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
1065+ if (Isv2x16VT ( EltVT) ) {
10651066 assert (N->getOpcode () == NVPTXISD::LoadV4 && " Unexpected load opcode." );
10661067 EltVT = MVT::i32 ;
10671068 FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1257,10 +1258,12 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12571258 if (EltVT.isVector ()) {
12581259 NumElts = EltVT.getVectorNumElements ();
12591260 EltVT = EltVT.getVectorElementType ();
1260- // vectors of f16 are loaded/stored as multiples of v2f16 elements.
1261- if (EltVT == MVT::f16 && N->getValueType (0 ) == MVT::v2f16) {
1261+ // vectors of 16-bit types are loaded/stored as multiples of v2x16 elements.
1262+ if ((EltVT == MVT::f16 && N->getValueType (0 ) == MVT::v2f16) ||
1263+ (EltVT == MVT::bf16 && N->getValueType (0 ) == MVT::v2bf16) ||
1264+ (EltVT == MVT::i16 && N->getValueType (0 ) == MVT::v2i16)) {
12621265 assert (NumElts % 2 == 0 && " Vector must have even number of elements" );
1263- EltVT = MVT::v2f16 ;
1266+ EltVT = N-> getValueType ( 0 ) ;
12641267 NumElts /= 2 ;
12651268 }
12661269 }
@@ -1674,9 +1677,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16741677 MVT ScalarVT = SimpleVT.getScalarType ();
16751678 unsigned toTypeWidth = ScalarVT.getSizeInBits ();
16761679 if (SimpleVT.isVector ()) {
1677- assert ((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
1678- " Unexpected vector type" );
1679- // v2f16 is stored using st.b32
1680+ assert (Isv2x16VT (StoreVT) && " Unexpected vector type" );
1681+ // v2x16 is stored using st.b32
16801682 toTypeWidth = 32 ;
16811683 }
16821684
@@ -1840,10 +1842,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
18401842 return false ;
18411843 }
18421844
1843- // v8f16 is a special case. PTX doesn't have st.v8.f16
1844- // instruction. Instead, we split the vector into v2f16 chunks and
1845+ // v8x16 is a special case. PTX doesn't have st.v8.x16
1846+ // instruction. Instead, we split the vector into v2x16 chunks and
18451847 // store them with st.v4.b32.
1846- if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
1848+ if (Isv2x16VT ( EltVT) ) {
18471849 assert (N->getOpcode () == NVPTXISD::StoreV4 && " Unexpected load opcode." );
18481850 EltVT = MVT::i32 ;
18491851 ToType = NVPTX::PTXLdStInstCode::Untyped;
0 commit comments