1515#include " llvm/IR/Module.h"
1616#include " llvm/Support/DXILABI.h"
1717#include " llvm/Support/ErrorHandling.h"
18+ #include " llvm/TargetParser/Triple.h"
1819
1920using namespace llvm ;
2021using namespace llvm ::dxil;
@@ -24,9 +25,7 @@ constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
2425// Include DXIL Operation data and corresponding access functions
2526// generated by the TableGen backend DXILEmitter.
2627#define DXIL_OP_OPERATION_TABLE
27- #define SHADER_KIND_ENUM
2828#include " DXILOperation.inc"
29- #undef SHADER_KIND_ENUM
3029#undef DXIL_OP_OPERATION_TABLE
3130
3231static OverloadKind getOverloadKind (Type *Ty) {
@@ -162,6 +161,45 @@ static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
162161 return nullptr ;
163162}
164163
164+ static ShaderKind getShaderKindEnum (Triple::EnvironmentType EnvType) {
165+ switch (EnvType) {
166+ case Triple::Pixel:
167+ return ShaderKind::pixel;
168+ case Triple::Vertex:
169+ return ShaderKind::vertex;
170+ case Triple::Geometry:
171+ return ShaderKind::geometry;
172+ case Triple::Hull:
173+ return ShaderKind::hull;
174+ case Triple::Domain:
175+ return ShaderKind::domain;
176+ case Triple::Compute:
177+ return ShaderKind::compute;
178+ case Triple::Library:
179+ return ShaderKind::library;
180+ case Triple::RayGeneration:
181+ return ShaderKind::raygeneration;
182+ case Triple::Intersection:
183+ return ShaderKind::intersection;
184+ case Triple::AnyHit:
185+ return ShaderKind::anyhit;
186+ case Triple::ClosestHit:
187+ return ShaderKind::closesthit;
188+ case Triple::Miss:
189+ return ShaderKind::miss;
190+ case Triple::Callable:
191+ return ShaderKind::callable;
192+ case Triple::Mesh:
193+ return ShaderKind::mesh;
194+ case Triple::Amplification:
195+ return ShaderKind::amplification;
196+ default :
197+ break ;
198+ }
199+ llvm_unreachable (
200+ " Shader Kind Not Found - Invalid DXIL Environment Specified" );
201+ }
202+
165203// / Construct DXIL function type. This is the type of a function with
166204// / the following prototype
167205// / OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
@@ -213,11 +251,23 @@ static int getValidConstraintIndex(const OpCodeProperty *Prop,
213251namespace llvm {
214252namespace dxil {
215253
216- CallInst *DXILOpBuilder::createDXILOpCall (dxil::OpCode OpCode,
217- VersionTuple SMVer,
218- StringRef StageKind, Type *ReturnTy,
254+ CallInst *DXILOpBuilder::createDXILOpCall (dxil::OpCode OpCode, Type *ReturnTy,
219255 Type *OverloadTy,
220256 SmallVector<Value *> Args) {
257+
258+ std::string TTStr = M.getTargetTriple ();
259+ // No extra checks need be performed to verify that the Triple is
260+ // well-formed or the target is supported since these checks would have
261+ // been done at the time the module M is constructed in the earlier stages of
262+ // compilation.
263+ auto Major = Triple (TTStr).getOSVersion ().getMajor ();
264+ auto MinorOrErr = Triple (TTStr).getOSVersion ().getMinor ();
265+ uint32_t Minor = MinorOrErr.has_value () ? *MinorOrErr : 0 ;
266+ VersionTuple SMVer (Major, Minor);
267+ // Get Shader Stage Kind
268+ Triple::EnvironmentType ShaderEnv = Triple (TTStr).getEnvironment ();
269+ auto ShaderEnvStr = Triple (TTStr).getEnvironmentName ();
270+
221271 const OpCodeProperty *Prop = getOpCodeProperty (OpCode);
222272 int Index = getValidConstraintIndex (Prop, SMVer);
223273 uint16_t ValidTyMask = Prop->Constraints [Index].ValidTys ;
@@ -234,10 +284,18 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode,
234284 /* gen_crash_diag=*/ false );
235285 }
236286
287+ // Ensure Environment type is known
288+ if (ShaderEnv == Triple::UnknownEnvironment) {
289+ report_fatal_error (
290+ StringRef (SMVer.getAsString ().append (
291+ " : Unknown Compilation Target Shader Stage specified " )),
292+ /* gen_crash_diag*/ false );
293+ }
294+
237295 // Perform necessary checks to ensure Opcode is valid in the targeted shader
238296 // kind
239297 uint16_t ValidShaderKindMask = Prop->Constraints [Index].ValidShaderKinds ;
240- ShaderKind ModuleStagekind = getShaderkKindEnum (StageKind );
298+ enum ShaderKind ModuleStagekind = getShaderKindEnum (ShaderEnv );
241299
242300 // Ensure valid shader stage constraints are specified
243301 if (ValidShaderKindMask == ShaderKind::Unknown) {
@@ -260,7 +318,7 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode,
260318 // Verify the target shader stage is valid for the DXIL operation
261319 if (!(ValidShaderKindMask & ModuleStagekind)) {
262320 report_fatal_error (
263- StringRef (std::string (StageKind )
321+ StringRef (std::string (ShaderEnvStr )
264322 .append (" : Invalid Shader Stage for DXIL operation - " )
265323 .append (getOpCodeName (OpCode))
266324 .append (" for Shader Model " )
@@ -282,8 +340,17 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode,
282340 return B.CreateCall (DXILFn, Args);
283341}
284342
285- Type *DXILOpBuilder::getOverloadType (dxil::OpCode OpCode, VersionTuple SMVer,
286- FunctionType *FT) {
343+ Type *DXILOpBuilder::getOverloadType (dxil::OpCode OpCode, FunctionType *FT) {
344+
345+ std::string TTStr = M.getTargetTriple ();
346+ // No extra checks need be performed to verify that the Triple is
347+ // well-formed or the target is supported since these checks would have
348+ // been done at the time the module M is constructed in the earlier stages of
349+ // compilation.
350+ auto Major = Triple (TTStr).getOSVersion ().getMajor ();
351+ auto MinorOrErr = Triple (TTStr).getOSVersion ().getMinor ();
352+ uint32_t Minor = MinorOrErr.has_value () ? *MinorOrErr : 0 ;
353+ VersionTuple SMVer (Major, Minor);
287354
288355 const OpCodeProperty *Prop = getOpCodeProperty (OpCode);
289356 // If DXIL Op has no overload parameter, just return the
0 commit comments