@@ -612,10 +612,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
612612bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT (SDNode *N) {
613613 SDValue Vector = N->getOperand (0 );
614614
615- // We only care about 16x2 as it's the only real vector type we
615+ // We only care about f16x2 as it's the only real vector type we
616616 // need to deal with.
617617 MVT VT = Vector.getSimpleValueType ();
618- if (!Isv2x16VT (VT))
618+ if (!(VT == MVT::v2f16 || VT == MVT::v2bf16 ))
619619 return false ;
620620 // Find and record all uses of this vector that extract element 0 or 1.
621621 SmallVector<SDNode *, 4 > E0 , E1 ;
@@ -828,7 +828,6 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
828828 return Opcode_i16;
829829 case MVT::v2f16:
830830 case MVT::v2bf16:
831- case MVT::v2i16:
832831 return Opcode_i32;
833832 case MVT::f32 :
834833 return Opcode_f32;
@@ -910,8 +909,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
910909 // Vector Setting
911910 unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
912911 if (SimpleVT.isVector ()) {
913- assert (Isv2x16VT (LoadedVT) && " Unexpected vector type" );
914- // v2f16/v2bf16/v2i16 is loaded using ld.b32
912+ assert ((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
913+ " Unexpected vector type" );
914+ // v2f16/v2bf16 is loaded using ld.b32
915915 fromTypeWidth = 32 ;
916916 }
917917
@@ -1061,10 +1061,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
10611061
10621062 EVT EltVT = N->getValueType (0 );
10631063
1064- // v8x16 is a special case. PTX doesn't have ld.v8.16
1065- // instruction. Instead, we split the vector into v2x16 chunks and
1064+ // v8f16 is a special case. PTX doesn't have ld.v8.f16
1065+ // instruction. Instead, we split the vector into v2f16 chunks and
10661066 // load them with ld.v4.b32.
1067- if (Isv2x16VT ( EltVT) ) {
1067+ if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
10681068 assert (N->getOpcode () == NVPTXISD::LoadV4 && " Unexpected load opcode." );
10691069 EltVT = MVT::i32 ;
10701070 FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1260,13 +1260,12 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12601260 if (EltVT.isVector ()) {
12611261 NumElts = EltVT.getVectorNumElements ();
12621262 EltVT = EltVT.getVectorElementType ();
1263- // vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
1263+ // vectors of f16 are loaded/stored as multiples of v2f16 elements.
12641264 if ((EltVT == MVT::f16 && N->getValueType (0 ) == MVT::v2f16) ||
1265- (EltVT == MVT::bf16 && N->getValueType (0 ) == MVT::v2bf16) ||
1266- (EltVT == MVT::i16 && N->getValueType (0 ) == MVT::v2i16)) {
1267- assert (NumElts % 2 == 0 && " Vector must have even number of elements" );
1268- EltVT = N->getValueType (0 );
1269- NumElts /= 2 ;
1265+ (EltVT == MVT::bf16 && N->getValueType (0 ) == MVT::v2bf16)) {
1266+ assert (NumElts % 2 == 0 && " Vector must have even number of elements" );
1267+ EltVT = N->getValueType (0 );
1268+ NumElts /= 2 ;
12701269 }
12711270 }
12721271
@@ -1679,8 +1678,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16791678 MVT ScalarVT = SimpleVT.getScalarType ();
16801679 unsigned toTypeWidth = ScalarVT.getSizeInBits ();
16811680 if (SimpleVT.isVector ()) {
1682- assert (Isv2x16VT (StoreVT) && " Unexpected vector type" );
1683- // v2x16 is stored using st.b32
1681+ assert ((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
1682+ " Unexpected vector type" );
1683+ // v2f16 is stored using st.b32
16841684 toTypeWidth = 32 ;
16851685 }
16861686
@@ -1844,10 +1844,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
18441844 return false ;
18451845 }
18461846
1847- // v8x16 is a special case. PTX doesn't have st.v8.x16
1848- // instruction. Instead, we split the vector into v2x16 chunks and
1847+ // v8f16 is a special case. PTX doesn't have st.v8.f16
1848+ // instruction. Instead, we split the vector into v2f16 chunks and
18491849 // store them with st.v4.b32.
1850- if (Isv2x16VT ( EltVT) ) {
1850+ if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
18511851 assert (N->getOpcode () == NVPTXISD::StoreV4 && " Unexpected load opcode." );
18521852 EltVT = MVT::i32 ;
18531853 ToType = NVPTX::PTXLdStInstCode::Untyped;
0 commit comments