@@ -626,6 +626,20 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
626626// simply not true.
627627SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType (
628628 SmallVector<std::string, 8 > Postfixes) {
629+ auto ParseInteger = [this ](StringRef Postfix) -> ConstantInt * {
630+ unsigned long long N = 0 ;
631+ if (consumeUnsignedInteger (Postfix, 10 , N)) {
632+ BM->getErrorLog ().checkError (
633+ false , SPIRVEC_InvalidLlvmModule,
634+ " TypeJointMatrixINTEL expects integer parameters" );
635+ return 0 ;
636+ }
637+ return getUInt32 (M, N);
638+ };
639+ std::vector<SPIRVValue *> Args;
640+ for (size_t I = 1 ; I != Postfixes.size (); ++I)
641+ Args.emplace_back (transConstant (ParseInteger (Postfixes[I])));
642+
629643 Type *ElemTy = nullptr ;
630644 StringRef Ty{Postfixes[0 ]};
631645 auto NumBits = llvm::StringSwitch<unsigned >(Ty)
@@ -634,32 +648,27 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
634648 .Case (" int" , 32 )
635649 .Case (" long" , 64 )
636650 .Default (0 );
637- if (NumBits)
651+ if (NumBits) {
638652 ElemTy = IntegerType::get (M->getContext (), NumBits);
639- else if (Ty == " half" )
653+ } else if (Ty == " half" ) {
640654 ElemTy = Type::getHalfTy (M->getContext ());
641- else if (Ty == " float" )
655+ } else if (Ty == " float" ) {
642656 ElemTy = Type::getFloatTy (M->getContext ());
643- else if (Ty == " double" )
657+ } else if (Ty == " double" ) {
644658 ElemTy = Type::getDoubleTy (M->getContext ());
645- else if (Ty == " bfloat16" )
659+ } else if (Ty == " bfloat16" ) {
646660 ElemTy = Type::getInt16Ty (M->getContext ());
647- else
661+ auto *CTI = transConstant (getUInt32 (M, static_cast <uint64_t >(
662+ internal::InternalJointMatrixCTI::Bfloat16)));
663+ Args.push_back (CTI);
664+ } else if (Ty == " tf32" ) {
665+ ElemTy = Type::getFloatTy (M->getContext ());
666+ auto *CTI = transConstant (getUInt32 (M, static_cast <uint64_t >(
667+ internal::InternalJointMatrixCTI::TF32)));
668+ Args.push_back (CTI);
669+ } else {
648670 llvm_unreachable (" Unexpected type for matrix!" );
649-
650- auto ParseInteger = [this ](StringRef Postfix) -> ConstantInt * {
651- unsigned long long N = 0 ;
652- if (consumeUnsignedInteger (Postfix, 10 , N)) {
653- BM->getErrorLog ().checkError (
654- false , SPIRVEC_InvalidLlvmModule,
655- " TypeJointMatrixINTEL expects integer parameters" );
656- return 0 ;
657- }
658- return getUInt32 (M, N);
659- };
660- std::vector<SPIRVValue *> Args;
661- for (size_t I = 1 ; I != Postfixes.size (); ++I)
662- Args.emplace_back (transConstant (ParseInteger (Postfixes[I])));
671+ }
663672 return BM->addJointMatrixINTELType (transType (ElemTy), Args);
664673}
665674
0 commit comments