1515#include " llvm/IR/Module.h"
1616#include " llvm/Support/DXILABI.h"
1717#include " llvm/Support/ErrorHandling.h"
18+ #include < optional>
1819
1920using namespace llvm ;
2021using namespace llvm ::dxil;
2122
2223constexpr StringLiteral DXILOpNamePrefix = " dx.op." ;
2324
2425namespace {
25-
2626enum OverloadKind : uint16_t {
27+ UNDEFINED = 0 ,
2728 VOID = 1 ,
2829 HALF = 1 << 1 ,
2930 FLOAT = 1 << 2 ,
@@ -36,9 +37,27 @@ enum OverloadKind : uint16_t {
3637 UserDefineType = 1 << 9 ,
3738 ObjectType = 1 << 10 ,
3839};
40+ struct Version {
41+ unsigned Major = 0 ;
42+ unsigned Minor = 0 ;
43+ };
3944
45+ struct OpOverload {
46+ Version DXILVersion;
47+ uint16_t ValidTys;
48+ };
4049} // namespace
4150
51+ struct OpStage {
52+ Version DXILVersion;
53+ uint32_t ValidStages;
54+ };
55+
56+ struct OpAttribute {
57+ Version DXILVersion;
58+ uint32_t ValidAttrs;
59+ };
60+
4261static const char *getOverloadTypeName (OverloadKind Kind) {
4362 switch (Kind) {
4463 case OverloadKind::HALF:
@@ -58,12 +77,13 @@ static const char *getOverloadTypeName(OverloadKind Kind) {
5877 case OverloadKind::I64:
5978 return " i64" ;
6079 case OverloadKind::VOID:
80+ case OverloadKind::UNDEFINED:
81+ return " void" ;
6182 case OverloadKind::ObjectType:
6283 case OverloadKind::UserDefineType:
6384 break ;
6485 }
6586 llvm_unreachable (" invalid overload type for name" );
66- return " void" ;
6787}
6888
6989static OverloadKind getOverloadKind (Type *Ty) {
@@ -131,8 +151,9 @@ struct OpCodeProperty {
131151 dxil::OpCodeClass OpCodeClass;
132152 // Offset in DXILOpCodeClassNameTable.
133153 unsigned OpCodeClassNameOffset;
134- uint16_t OverloadTys;
135- llvm::Attribute::AttrKind FuncAttr;
154+ llvm::SmallVector<OpOverload> Overloads;
155+ llvm::SmallVector<OpStage> Stages;
156+ llvm::SmallVector<OpAttribute> Attributes;
136157 int OverloadParamIndex; // parameter index which control the overload.
137158 // When < 0, should be only 1 overload type.
138159 unsigned NumOfParameters; // Number of parameters include return value.
@@ -221,6 +242,45 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
221242 return nullptr ;
222243}
223244
245+ static ShaderKind getShaderKindEnum (Triple::EnvironmentType EnvType) {
246+ switch (EnvType) {
247+ case Triple::Pixel:
248+ return ShaderKind::pixel;
249+ case Triple::Vertex:
250+ return ShaderKind::vertex;
251+ case Triple::Geometry:
252+ return ShaderKind::geometry;
253+ case Triple::Hull:
254+ return ShaderKind::hull;
255+ case Triple::Domain:
256+ return ShaderKind::domain;
257+ case Triple::Compute:
258+ return ShaderKind::compute;
259+ case Triple::Library:
260+ return ShaderKind::library;
261+ case Triple::RayGeneration:
262+ return ShaderKind::raygeneration;
263+ case Triple::Intersection:
264+ return ShaderKind::intersection;
265+ case Triple::AnyHit:
266+ return ShaderKind::anyhit;
267+ case Triple::ClosestHit:
268+ return ShaderKind::closesthit;
269+ case Triple::Miss:
270+ return ShaderKind::miss;
271+ case Triple::Callable:
272+ return ShaderKind::callable;
273+ case Triple::Mesh:
274+ return ShaderKind::mesh;
275+ case Triple::Amplification:
276+ return ShaderKind::amplification;
277+ default :
278+ break ;
279+ }
280+ llvm_unreachable (
281+ " Shader Kind Not Found - Invalid DXIL Environment Specified" );
282+ }
283+
224284// / Construct DXIL function type. This is the type of a function with
225285// / the following prototype
226286// / OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
@@ -232,7 +292,7 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
232292 Type *ReturnTy, Type *OverloadTy) {
233293 SmallVector<Type *> ArgTys;
234294
235- auto ParamKinds = getOpCodeParameterKind (*Prop);
295+ const ParameterKind * ParamKinds = getOpCodeParameterKind (*Prop);
236296
237297 // Add ReturnTy as return type of the function
238298 ArgTys.emplace_back (ReturnTy);
@@ -249,17 +309,103 @@ static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
249309 ArgTys[0 ], ArrayRef<Type *>(&ArgTys[1 ], ArgTys.size () - 1 ), false );
250310}
251311
312+ // / Get index of the property from PropList valid for the most recent
313+ // / DXIL version not greater than DXILVer.
314+ // / PropList is expected to be sorted in ascending order of DXIL version.
315+ template <typename T>
316+ static std::optional<size_t > getPropIndex (ArrayRef<T> PropList,
317+ const VersionTuple DXILVer) {
318+ size_t Index = PropList.size () - 1 ;
319+ for (auto Iter = PropList.rbegin (); Iter != PropList.rend ();
320+ Iter++, Index--) {
321+ const T &Prop = *Iter;
322+ if (VersionTuple (Prop.DXILVersion .Major , Prop.DXILVersion .Minor ) <=
323+ DXILVer) {
324+ return Index;
325+ }
326+ }
327+ return std::nullopt ;
328+ }
329+
252330namespace llvm {
253331namespace dxil {
254332
333+ // No extra checks on TargetTriple need be performed to verify that the
334+ // Triple is well-formed or that the target is supported since these checks
335+ // would have been done at the time the module M is constructed in the earlier
336+ // stages of compilation.
337+ DXILOpBuilder::DXILOpBuilder (Module &M, IRBuilderBase &B) : M(M), B(B) {
338+ Triple TT (Triple (M.getTargetTriple ()));
339+ DXILVersion = TT.getDXILVersion ();
340+ ShaderStage = TT.getEnvironment ();
341+ // Ensure Environment type is known
342+ if (ShaderStage == Triple::UnknownEnvironment) {
343+ report_fatal_error (
344+ Twine (DXILVersion.getAsString ()) +
345+ " : Unknown Compilation Target Shader Stage specified " ,
346+ /* gen_crash_diag*/ false );
347+ }
348+ }
349+
255350CallInst *DXILOpBuilder::createDXILOpCall (dxil::OpCode OpCode, Type *ReturnTy,
256351 Type *OverloadTy,
257352 SmallVector<Value *> Args) {
353+
258354 const OpCodeProperty *Prop = getOpCodeProperty (OpCode);
355+ std::optional<size_t > OlIndexOrErr =
356+ getPropIndex (ArrayRef (Prop->Overloads ), DXILVersion);
357+ if (!OlIndexOrErr.has_value ()) {
358+ report_fatal_error (Twine (getOpCodeName (OpCode)) +
359+ " : No valid overloads found for DXIL Version - " +
360+ DXILVersion.getAsString (),
361+ /* gen_crash_diag*/ false );
362+ }
363+ uint16_t ValidTyMask = Prop->Overloads [*OlIndexOrErr].ValidTys ;
259364
260365 OverloadKind Kind = getOverloadKind (OverloadTy);
261- if ((Prop->OverloadTys & (uint16_t )Kind) == 0 ) {
262- report_fatal_error (" Invalid Overload Type" , /* gen_crash_diag=*/ false );
366+
367+ // Check if the operation supports overload types and OverloadTy is valid
368+ // per the specified types for the operation
369+ if ((ValidTyMask != OverloadKind::UNDEFINED) &&
370+ (ValidTyMask & (uint16_t )Kind) == 0 ) {
371+ report_fatal_error (Twine (" Invalid Overload Type for DXIL operation - " ) +
372+ getOpCodeName (OpCode),
373+ /* gen_crash_diag=*/ false );
374+ }
375+
376+ // Perform necessary checks to ensure Opcode is valid in the targeted shader
377+ // kind
378+ std::optional<size_t > StIndexOrErr =
379+ getPropIndex (ArrayRef (Prop->Stages ), DXILVersion);
380+ if (!StIndexOrErr.has_value ()) {
381+ report_fatal_error (Twine (getOpCodeName (OpCode)) +
382+ " : No valid stages found for DXIL Version - " +
383+ DXILVersion.getAsString (),
384+ /* gen_crash_diag*/ false );
385+ }
386+ uint16_t ValidShaderKindMask = Prop->Stages [*StIndexOrErr].ValidStages ;
387+
388+ // Ensure valid shader stage properties are specified
389+ if (ValidShaderKindMask == ShaderKind::removed) {
390+ report_fatal_error (
391+ Twine (DXILVersion.getAsString ()) +
392+ " : Unsupported Target Shader Stage for DXIL operation - " +
393+ getOpCodeName (OpCode),
394+ /* gen_crash_diag*/ false );
395+ }
396+
397+ // Shader stage need not be validated since getShaderKindEnum() fails
398+ // for unknown shader stage.
399+
400+ // Verify the target shader stage is valid for the DXIL operation
401+ ShaderKind ModuleStagekind = getShaderKindEnum (ShaderStage);
402+ if (!(ValidShaderKindMask & ModuleStagekind)) {
403+ auto ShaderEnvStr = Triple::getEnvironmentTypeName (ShaderStage);
404+ report_fatal_error (Twine (ShaderEnvStr) +
405+ " : Invalid Shader Stage for DXIL operation - " +
406+ getOpCodeName (OpCode) + " for DXIL Version " +
407+ DXILVersion.getAsString (),
408+ /* gen_crash_diag*/ false );
263409 }
264410
265411 std::string DXILFnName = constructOverloadName (Kind, OverloadTy, *Prop);
@@ -282,40 +428,18 @@ Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
282428 // If DXIL Op has no overload parameter, just return the
283429 // precise return type specified.
284430 if (Prop->OverloadParamIndex < 0 ) {
285- auto &Ctx = FT->getContext ();
286- switch (Prop->OverloadTys ) {
287- case OverloadKind::VOID:
288- return Type::getVoidTy (Ctx);
289- case OverloadKind::HALF:
290- return Type::getHalfTy (Ctx);
291- case OverloadKind::FLOAT:
292- return Type::getFloatTy (Ctx);
293- case OverloadKind::DOUBLE:
294- return Type::getDoubleTy (Ctx);
295- case OverloadKind::I1:
296- return Type::getInt1Ty (Ctx);
297- case OverloadKind::I8:
298- return Type::getInt8Ty (Ctx);
299- case OverloadKind::I16:
300- return Type::getInt16Ty (Ctx);
301- case OverloadKind::I32:
302- return Type::getInt32Ty (Ctx);
303- case OverloadKind::I64:
304- return Type::getInt64Ty (Ctx);
305- default :
306- llvm_unreachable (" invalid overload type" );
307- return nullptr ;
308- }
431+ return FT->getReturnType ();
309432 }
310433
311- // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
434+ // Consider FT->getReturnType() as default overload type, unless
435+ // Prop->OverloadParamIndex != 0.
312436 Type *OverloadType = FT->getReturnType ();
313437 if (Prop->OverloadParamIndex != 0 ) {
314438 // Skip Return Type.
315439 OverloadType = FT->getParamType (Prop->OverloadParamIndex - 1 );
316440 }
317441
318- auto ParamKinds = getOpCodeParameterKind (*Prop);
442+ const ParameterKind * ParamKinds = getOpCodeParameterKind (*Prop);
319443 auto Kind = ParamKinds[Prop->OverloadParamIndex ];
320444 // For ResRet and CBufferRet, OverloadTy is in field of StructType.
321445 if (Kind == ParameterKind::CBufferRet ||
0 commit comments