Skip to content

Commit b438ea3

Browse files
authored
use SPIR-V headers for SPV_INTEL_bfloat16_conversion (#2969)
Support for the SPV_INTEL_bfloat16_conversion extension is in the SPIR-V headers now, so we no longer need to include it in spirv_internal.
1 parent 0d3a1a2 commit b438ea3

File tree

7 files changed

+12
-20
lines changed

7 files changed

+12
-20
lines changed

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,7 +1893,7 @@ void OCLToSPIRVBase::visitCallConvertBFloat16AsUshort(CallInst *CI,
18931893
}
18941894
}
18951895

1896-
mutateCallInst(CI, internal::OpConvertFToBF16INTEL);
1896+
mutateCallInst(CI, OpConvertFToBF16INTEL);
18971897
}
18981898

18991899
void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
@@ -1936,7 +1936,7 @@ void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
19361936
}
19371937
}
19381938

1939-
mutateCallInst(CI, internal::OpConvertBF16ToFINTEL);
1939+
mutateCallInst(CI, OpConvertBF16ToFINTEL);
19401940
}
19411941
} // namespace SPIRV
19421942

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,7 @@ void SPIRVToOCLBase::visitCallInst(CallInst &CI) {
219219
visitCallSPIRVReadClockKHR(&CI);
220220
return;
221221
}
222-
if (OC == internal::OpConvertFToBF16INTEL ||
223-
OC == internal::OpConvertBF16ToFINTEL) {
222+
if (OC == OpConvertFToBF16INTEL || OC == OpConvertBF16ToFINTEL) {
224223
visitCallSPIRVBFloat16Conversions(&CI, OC);
225224
return;
226225
}
@@ -928,10 +927,10 @@ void SPIRVToOCLBase::visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC) {
928927
: "";
929928
std::string Name;
930929
switch (static_cast<uint32_t>(OC)) {
931-
case internal::OpConvertFToBF16INTEL:
930+
case OpConvertFToBF16INTEL:
932931
Name = "intel_convert_bfloat16" + N + "_as_ushort" + N;
933932
break;
934-
case internal::OpConvertBF16ToFINTEL:
933+
case OpConvertBF16ToFINTEL:
935934
Name = "intel_convert_as_bfloat16" + N + "_float" + N;
936935
break;
937936
default:

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3558,9 +3558,9 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
35583558
SPIRVCapVec getRequiredCapability() const override {
35593559
SPIRVType *ResCompTy = this->getType();
35603560
if (ResCompTy->isTypeCooperativeMatrixKHR())
3561-
return getVec(internal::CapabilityBfloat16ConversionINTEL,
3561+
return getVec(CapabilityBFloat16ConversionINTEL,
35623562
internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
3563-
return getVec(internal::CapabilityBfloat16ConversionINTEL);
3563+
return getVec(CapabilityBFloat16ConversionINTEL);
35643564
}
35653565

35663566
std::optional<ExtensionID> getRequiredExtension() const override {
@@ -3614,7 +3614,7 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
36143614
InCompTy =
36153615
static_cast<SPIRVTypeCooperativeMatrixKHR *>(InCompTy)->getCompType();
36163616
}
3617-
if (OC == internal::OpConvertFToBF16INTEL) {
3617+
if (OC == OpConvertFToBF16INTEL) {
36183618
SPVErrLog.checkError(
36193619
ResCompTy->isTypeInt(16), SPIRVEC_InvalidInstruction,
36203620
InstName + "\nResult value must be a scalar or vector of integer "
@@ -3642,7 +3642,7 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
36423642
};
36433643

36443644
#define _SPIRV_OP(x) \
3645-
typedef SPIRVBfloat16ConversionINTELInstBase<internal::Op##x> SPIRV##x;
3645+
typedef SPIRVBfloat16ConversionINTELInstBase<Op##x> SPIRV##x;
36463646
_SPIRV_OP(ConvertFToBF16INTEL)
36473647
_SPIRV_OP(ConvertBF16ToFINTEL)
36483648
#undef _SPIRV_OP

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
626626
add(CapabilityOptNoneEXT, "OptNoneEXT");
627627
add(CapabilityAtomicFloat16AddEXT, "AtomicFloat16AddEXT");
628628
add(CapabilityDebugInfoModuleINTEL, "DebugInfoModuleINTEL");
629+
add(CapabilityBFloat16ConversionINTEL, "Bfloat16ConversionINTEL");
629630
add(CapabilitySplitBarrierINTEL, "SplitBarrierINTEL");
630631
add(CapabilityGlobalVariableFPGADecorationsINTEL,
631632
"GlobalVariableFPGADecorationsINTEL");
@@ -644,7 +645,6 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
644645
// From spirv_internal.hpp
645646
add(internal::CapabilityFastCompositeINTEL, "FastCompositeINTEL");
646647
add(internal::CapabilityTokenTypeINTEL, "TokenTypeINTEL");
647-
add(internal::CapabilityBfloat16ConversionINTEL, "Bfloat16ConversionINTEL");
648648
add(internal::CapabilityJointMatrixINTEL, "JointMatrixINTEL");
649649
add(internal::CapabilityHWThreadQueryINTEL, "HWThreadQueryINTEL");
650650
add(internal::CapabilityGlobalVariableDecorationsINTEL,

lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,8 @@ _SPIRV_OP(TypeBufferSurfaceINTEL, 6086)
568568
_SPIRV_OP(TypeStructContinuedINTEL, 6090)
569569
_SPIRV_OP(ConstantCompositeContinuedINTEL, 6091)
570570
_SPIRV_OP(SpecConstantCompositeContinuedINTEL, 6092)
571+
_SPIRV_OP(ConvertFToBF16INTEL, 6116)
572+
_SPIRV_OP(ConvertBF16ToFINTEL, 6117)
571573
_SPIRV_OP(ControlBarrierArriveINTEL, 6142)
572574
_SPIRV_OP(ControlBarrierWaitINTEL, 6143)
573575
_SPIRV_OP(ArithmeticFenceEXT, 6145)

lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
_SPIRV_OP_INTERNAL(Forward, internal::OpForward)
44
_SPIRV_OP_INTERNAL(TypeTokenINTEL, internal::OpTypeTokenINTEL)
5-
_SPIRV_OP_INTERNAL(ConvertFToBF16INTEL, internal::OpConvertFToBF16INTEL)
6-
_SPIRV_OP_INTERNAL(ConvertBF16ToFINTEL, internal::OpConvertBF16ToFINTEL)
75
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
86
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
97
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ enum InternalLinkageType {
5959

6060
enum InternalOp {
6161
IOpTypeTokenINTEL = 6113,
62-
IOpConvertFToBF16INTEL = 6116,
63-
IOpConvertBF16ToFINTEL = 6117,
6462
IOpTypeJointMatrixINTEL = 6119,
6563
IOpJointMatrixLoadINTEL = 6120,
6664
IOpJointMatrixStoreINTEL = 6121,
@@ -107,7 +105,6 @@ enum InternalDecoration {
107105
enum InternalCapability {
108106
ICapFastCompositeINTEL = 6093,
109107
ICapTokenTypeINTEL = 6112,
110-
ICapBfloat16ConversionINTEL = 6115,
111108
ICapabilityJointMatrixINTEL = 6118,
112109
ICapabilityHWThreadQueryINTEL = 6134,
113110
ICapGlobalVariableDecorationsINTEL = 6146,
@@ -267,8 +264,6 @@ constexpr SourceLanguage SourceLanguageCPP20 =
267264

268265
constexpr Op OpForward = static_cast<Op>(IOpForward);
269266
constexpr Op OpTypeTokenINTEL = static_cast<Op>(IOpTypeTokenINTEL);
270-
constexpr Op OpConvertFToBF16INTEL = static_cast<Op>(IOpConvertFToBF16INTEL);
271-
constexpr Op OpConvertBF16ToFINTEL = static_cast<Op>(IOpConvertBF16ToFINTEL);
272267

273268
constexpr Decoration DecorationCallableFunctionINTEL =
274269
static_cast<Decoration>(IDecCallableFunctionINTEL);
@@ -287,8 +282,6 @@ constexpr Capability CapabilityFastCompositeINTEL =
287282
static_cast<Capability>(ICapFastCompositeINTEL);
288283
constexpr Capability CapabilityTokenTypeINTEL =
289284
static_cast<Capability>(ICapTokenTypeINTEL);
290-
constexpr Capability CapabilityBfloat16ConversionINTEL =
291-
static_cast<Capability>(ICapBfloat16ConversionINTEL);
292285
constexpr Capability CapabilityGlobalVariableDecorationsINTEL =
293286
static_cast<Capability>(ICapGlobalVariableDecorationsINTEL);
294287

0 commit comments

Comments
 (0)