@@ -612,10 +612,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
612612bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT (SDNode *N) {
613613 SDValue Vector = N->getOperand (0 );
614614
615- // We only care about f16x2 as it's the only real vector type we
615+ // We only care about 16x2 as it's the only real vector type we
616616 // need to deal with.
617617 MVT VT = Vector.getSimpleValueType ();
618- if (!(VT == MVT::v2f16 || VT == MVT::v2bf16 ))
618+ if (!Isv2x16VT (VT))
619619 return false ;
620620 // Find and record all uses of this vector that extract element 0 or 1.
621621 SmallVector<SDNode *, 4 > E0 , E1 ;
@@ -828,6 +828,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
828828 return Opcode_i16;
829829 case MVT::v2f16:
830830 case MVT::v2bf16:
831+ case MVT::v2i16:
831832 return Opcode_i32;
832833 case MVT::f32 :
833834 return Opcode_f32;
@@ -909,9 +910,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
909910 // Vector Setting
910911 unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
911912 if (SimpleVT.isVector ()) {
912- assert ((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
913- " Unexpected vector type" );
914- // v2f16/v2bf16 is loaded using ld.b32
913+ assert (Isv2x16VT (LoadedVT) && " Unexpected vector type" );
914+ // v2f16/v2bf16/v2i16 is loaded using ld.b32
915915 fromTypeWidth = 32 ;
916916 }
917917
@@ -1061,10 +1061,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
10611061
10621062 EVT EltVT = N->getValueType (0 );
10631063
1064- // v8f16 is a special case. PTX doesn't have ld.v8.f16
1065- // instruction. Instead, we split the vector into v2f16 chunks and
1064+ // v8x16 is a special case. PTX doesn't have ld.v8.16
1065+ // instruction. Instead, we split the vector into v2x16 chunks and
10661066 // load them with ld.v4.b32.
1067- if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
1067+ if (Isv2x16VT ( EltVT) ) {
10681068 assert (N->getOpcode () == NVPTXISD::LoadV4 && " Unexpected load opcode." );
10691069 EltVT = MVT::i32 ;
10701070 FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1260,12 +1260,13 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12601260 if (EltVT.isVector ()) {
12611261 NumElts = EltVT.getVectorNumElements ();
12621262 EltVT = EltVT.getVectorElementType ();
1263- // vectors of f16 are loaded/stored as multiples of v2f16 elements.
1263+ // vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
12641264 if ((EltVT == MVT::f16 && N->getValueType (0 ) == MVT::v2f16) ||
1265- (EltVT == MVT::bf16 && N->getValueType (0 ) == MVT::v2bf16)) {
1266- assert (NumElts % 2 == 0 && " Vector must have even number of elements" );
1267- EltVT = N->getValueType (0 );
1268- NumElts /= 2 ;
1265+ (EltVT == MVT::bf16 && N->getValueType (0 ) == MVT::v2bf16) ||
1266+ (EltVT == MVT::i16 && N->getValueType (0 ) == MVT::v2i16)) {
1267+ assert (NumElts % 2 == 0 && " Vector must have even number of elements" );
1268+ EltVT = N->getValueType (0 );
1269+ NumElts /= 2 ;
12691270 }
12701271 }
12711272
@@ -1678,9 +1679,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16781679 MVT ScalarVT = SimpleVT.getScalarType ();
16791680 unsigned toTypeWidth = ScalarVT.getSizeInBits ();
16801681 if (SimpleVT.isVector ()) {
1681- assert ((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
1682- " Unexpected vector type" );
1683- // v2f16 is stored using st.b32
1682+ assert (Isv2x16VT (StoreVT) && " Unexpected vector type" );
1683+ // v2x16 is stored using st.b32
16841684 toTypeWidth = 32 ;
16851685 }
16861686
@@ -1844,10 +1844,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
18441844 return false ;
18451845 }
18461846
1847- // v8f16 is a special case. PTX doesn't have st.v8.f16
1848- // instruction. Instead, we split the vector into v2f16 chunks and
1847+ // v8x16 is a special case. PTX doesn't have st.v8.x16
1848+ // instruction. Instead, we split the vector into v2x16 chunks and
18491849 // store them with st.v4.b32.
1850- if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
1850+ if (Isv2x16VT ( EltVT) ) {
18511851 assert (N->getOpcode () == NVPTXISD::StoreV4 && " Unexpected load opcode." );
18521852 EltVT = MVT::i32 ;
18531853 ToType = NVPTX::PTXLdStInstCode::Untyped;
0 commit comments