@@ -52,6 +52,17 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
5252 return NodeText->getString ();
5353}
5454
55+ static Expected<dxbc::ShaderVisibility>
56+ extractShaderVisibility (MDNode *Node, unsigned int OpId) {
57+ if (std::optional<uint32_t > Val = extractMdIntValue (Node, OpId)) {
58+ if (!dxbc::isValidShaderVisibility (*Val))
59+ return make_error<RootSignatureValidationError<uint32_t >>(
60+ " ShaderVisibility" , *Val);
61+ return dxbc::ShaderVisibility (*Val);
62+ }
63+ return make_error<InvalidRSMetadataValue>(" ShaderVisibility" );
64+ }
65+
5566namespace {
5667
5768// We use the OverloadVisit with std::visit to ensure the compiler catches if a
@@ -221,15 +232,10 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
221232 if (RootConstantNode->getNumOperands () != 5 )
222233 return make_error<InvalidRSMetadataFormat>(" RootConstants Element" );
223234
224- dxbc::RTS0::v1::RootParameterHeader Header;
225- // The parameter offset doesn't matter here - we recalculate it during
226- // serialization Header.ParameterOffset = 0;
227- Header.ParameterType = to_underlying (dxbc::RootParameterType::Constants32Bit);
228-
229- if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 1 ))
230- Header.ShaderVisibility = *Val;
231- else
232- return make_error<InvalidRSMetadataValue>(" ShaderVisibility" );
235+ Expected<dxbc::ShaderVisibility> Visibility =
236+ extractShaderVisibility (RootConstantNode, 1 );
237+ if (auto E = Visibility.takeError ())
238+ return Error (std::move (E));
233239
234240 dxbc::RTS0::v1::RootConstants Constants;
235241 if (std::optional<uint32_t > Val = extractMdIntValue (RootConstantNode, 2 ))
@@ -247,7 +253,8 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
247253 else
248254 return make_error<InvalidRSMetadataValue>(" Num32BitValues" );
249255
250- RSD.ParametersContainer .addParameter (Header, Constants);
256+ RSD.ParametersContainer .addParameter (dxbc::RootParameterType::Constants32Bit,
257+ *Visibility, Constants);
251258
252259 return Error::success ();
253260}
@@ -263,26 +270,26 @@ Error MetadataParser::parseRootDescriptors(
263270 if (RootDescriptorNode->getNumOperands () != 5 )
264271 return make_error<InvalidRSMetadataFormat>(" Root Descriptor Element" );
265272
266- dxbc::RTS0::v1::RootParameterHeader Header ;
273+ dxbc::RootParameterType Type ;
267274 switch (ElementKind) {
268275 case RootSignatureElementKind::SRV:
269- Header. ParameterType = to_underlying ( dxbc::RootParameterType::SRV) ;
276+ Type = dxbc::RootParameterType::SRV;
270277 break ;
271278 case RootSignatureElementKind::UAV:
272- Header. ParameterType = to_underlying ( dxbc::RootParameterType::UAV) ;
279+ Type = dxbc::RootParameterType::UAV;
273280 break ;
274281 case RootSignatureElementKind::CBV:
275- Header. ParameterType = to_underlying ( dxbc::RootParameterType::CBV) ;
282+ Type = dxbc::RootParameterType::CBV;
276283 break ;
277284 default :
278285 llvm_unreachable (" invalid Root Descriptor kind" );
279286 break ;
280287 }
281288
282- if (std::optional< uint32_t > Val = extractMdIntValue (RootDescriptorNode, 1 ))
283- Header. ShaderVisibility = *Val ;
284- else
285- return make_error<InvalidRSMetadataValue>( " ShaderVisibility " );
289+ Expected<dxbc::ShaderVisibility> Visibility =
290+ extractShaderVisibility (RootDescriptorNode, 1 ) ;
291+ if ( auto E = Visibility. takeError ())
292+ return Error ( std::move (E) );
286293
287294 dxbc::RTS0::v2::RootDescriptor Descriptor;
288295 if (std::optional<uint32_t > Val = extractMdIntValue (RootDescriptorNode, 2 ))
@@ -296,7 +303,7 @@ Error MetadataParser::parseRootDescriptors(
296303 return make_error<InvalidRSMetadataValue>(" RegisterSpace" );
297304
298305 if (RSD.Version == 1 ) {
299- RSD.ParametersContainer .addParameter (Header , Descriptor);
306+ RSD.ParametersContainer .addParameter (Type, *Visibility , Descriptor);
300307 return Error::success ();
301308 }
302309 assert (RSD.Version > 1 );
@@ -306,7 +313,7 @@ Error MetadataParser::parseRootDescriptors(
306313 else
307314 return make_error<InvalidRSMetadataValue>(" Root Descriptor Flags" );
308315
309- RSD.ParametersContainer .addParameter (Header , Descriptor);
316+ RSD.ParametersContainer .addParameter (Type, *Visibility , Descriptor);
310317 return Error::success ();
311318}
312319
@@ -372,15 +379,12 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
372379 if (NumOperands < 2 )
373380 return make_error<InvalidRSMetadataFormat>(" Descriptor Table" );
374381
375- dxbc::RTS0::v1::RootParameterHeader Header;
376- if (std::optional<uint32_t > Val = extractMdIntValue (DescriptorTableNode, 1 ))
377- Header.ShaderVisibility = *Val;
378- else
379- return make_error<InvalidRSMetadataValue>(" ShaderVisibility" );
382+ Expected<dxbc::ShaderVisibility> Visibility =
383+ extractShaderVisibility (DescriptorTableNode, 1 );
384+ if (auto E = Visibility.takeError ())
385+ return Error (std::move (E));
380386
381387 mcdxbc::DescriptorTable Table;
382- Header.ParameterType =
383- to_underlying (dxbc::RootParameterType::DescriptorTable);
384388
385389 for (unsigned int I = 2 ; I < NumOperands; I++) {
386390 MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand (I));
@@ -392,7 +396,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
392396 return Err;
393397 }
394398
395- RSD.ParametersContainer .addParameter (Header, Table);
399+ RSD.ParametersContainer .addParameter (dxbc::RootParameterType::DescriptorTable,
400+ *Visibility, Table);
396401 return Error::success ();
397402}
398403
@@ -528,20 +533,14 @@ Error MetadataParser::validateRootSignature(
528533 }
529534
530535 for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer ) {
531- if (!dxbc::isValidShaderVisibility (Info.Header .ShaderVisibility ))
532- DeferredErrs =
533- joinErrors (std::move (DeferredErrs),
534- make_error<RootSignatureValidationError<uint32_t >>(
535- " ShaderVisibility" , Info.Header .ShaderVisibility ));
536-
537- assert (dxbc::isValidParameterType (Info.Header .ParameterType ) &&
538- " Invalid value for ParameterType" );
539536
540- switch (Info.Header .ParameterType ) {
537+ switch (Info.Type ) {
538+ case dxbc::RootParameterType::Constants32Bit:
539+ break ;
541540
542- case to_underlying ( dxbc::RootParameterType::CBV) :
543- case to_underlying ( dxbc::RootParameterType::UAV) :
544- case to_underlying ( dxbc::RootParameterType::SRV) : {
541+ case dxbc::RootParameterType::CBV:
542+ case dxbc::RootParameterType::UAV:
543+ case dxbc::RootParameterType::SRV: {
545544 const dxbc::RTS0::v2::RootDescriptor &Descriptor =
546545 RSD.ParametersContainer .getRootDescriptor (Info.Location );
547546 if (!hlsl::rootsig::verifyRegisterValue (Descriptor.ShaderRegister ))
@@ -566,7 +565,7 @@ Error MetadataParser::validateRootSignature(
566565 }
567566 break ;
568567 }
569- case to_underlying ( dxbc::RootParameterType::DescriptorTable) : {
568+ case dxbc::RootParameterType::DescriptorTable: {
570569 const mcdxbc::DescriptorTable &Table =
571570 RSD.ParametersContainer .getDescriptorTable (Info.Location );
572571 for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
0 commit comments