Skip to content

Commit c6dfbc5

Browse files
authored
[DirectX] Refactor RootSignature Backend to remove to_underlying from Root Parameter Header (#154249)
This patch is refactoring Root Parameter Header in DX Container backend to remove the usage of `to_underlying`. This requires some changes: first, MC Root Signature should not depend on Object/DXContainer.h; Second, we need to assume data to be valid in scenarios where it was originally not expected, this made some tests be removed.
1 parent 6a5cb5a commit c6dfbc5

File tree

9 files changed

+125
-196
lines changed

9 files changed

+125
-196
lines changed

llvm/include/llvm/MC/DXContainerRootSignature.h

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ class raw_ostream;
2020
namespace mcdxbc {
2121

2222
struct RootParameterInfo {
23-
dxbc::RTS0::v1::RootParameterHeader Header;
23+
dxbc::RootParameterType Type;
24+
dxbc::ShaderVisibility Visibility;
2425
size_t Location;
2526

26-
RootParameterInfo() = default;
27-
28-
RootParameterInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location)
29-
: Header(Header), Location(Location) {}
27+
RootParameterInfo(dxbc::RootParameterType Type,
28+
dxbc::ShaderVisibility Visibility, size_t Location)
29+
: Type(Type), Visibility(Visibility), Location(Location) {}
3030
};
3131

3232
struct DescriptorTable {
@@ -46,41 +46,34 @@ struct RootParametersContainer {
4646
SmallVector<dxbc::RTS0::v2::RootDescriptor> Descriptors;
4747
SmallVector<DescriptorTable> Tables;
4848

49-
void addInfo(dxbc::RTS0::v1::RootParameterHeader Header, size_t Location) {
50-
ParametersInfo.push_back(RootParameterInfo(Header, Location));
49+
void addInfo(dxbc::RootParameterType Type, dxbc::ShaderVisibility Visibility,
50+
size_t Location) {
51+
ParametersInfo.emplace_back(Type, Visibility, Location);
5152
}
5253

53-
void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
54+
void addParameter(dxbc::RootParameterType Type,
55+
dxbc::ShaderVisibility Visibility,
5456
dxbc::RTS0::v1::RootConstants Constant) {
55-
addInfo(Header, Constants.size());
57+
addInfo(Type, Visibility, Constants.size());
5658
Constants.push_back(Constant);
5759
}
5860

59-
void addInvalidParameter(dxbc::RTS0::v1::RootParameterHeader Header) {
60-
addInfo(Header, -1);
61-
}
62-
63-
void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
61+
void addParameter(dxbc::RootParameterType Type,
62+
dxbc::ShaderVisibility Visibility,
6463
dxbc::RTS0::v2::RootDescriptor Descriptor) {
65-
addInfo(Header, Descriptors.size());
64+
addInfo(Type, Visibility, Descriptors.size());
6665
Descriptors.push_back(Descriptor);
6766
}
6867

69-
void addParameter(dxbc::RTS0::v1::RootParameterHeader Header,
70-
DescriptorTable Table) {
71-
addInfo(Header, Tables.size());
68+
void addParameter(dxbc::RootParameterType Type,
69+
dxbc::ShaderVisibility Visibility, DescriptorTable Table) {
70+
addInfo(Type, Visibility, Tables.size());
7271
Tables.push_back(Table);
7372
}
7473

75-
std::pair<uint32_t, uint32_t>
76-
getTypeAndLocForParameter(uint32_t Location) const {
77-
const RootParameterInfo &Info = ParametersInfo[Location];
78-
return {Info.Header.ParameterType, Info.Location};
79-
}
80-
81-
const dxbc::RTS0::v1::RootParameterHeader &getHeader(size_t Location) const {
74+
const RootParameterInfo &getInfo(uint32_t Location) const {
8275
const RootParameterInfo &Info = ParametersInfo[Location];
83-
return Info.Header;
76+
return Info;
8477
}
8578

8679
const dxbc::RTS0::v1::RootConstants &getConstant(size_t Index) const {

llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
5252
return NodeText->getString();
5353
}
5454

55+
static Expected<dxbc::ShaderVisibility>
56+
extractShaderVisibility(MDNode *Node, unsigned int OpId) {
57+
if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
58+
if (!dxbc::isValidShaderVisibility(*Val))
59+
return make_error<RootSignatureValidationError<uint32_t>>(
60+
"ShaderVisibility", *Val);
61+
return dxbc::ShaderVisibility(*Val);
62+
}
63+
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
64+
}
65+
5566
namespace {
5667

5768
// We use the OverloadVisit with std::visit to ensure the compiler catches if a
@@ -221,15 +232,10 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
221232
if (RootConstantNode->getNumOperands() != 5)
222233
return make_error<InvalidRSMetadataFormat>("RootConstants Element");
223234

224-
dxbc::RTS0::v1::RootParameterHeader Header;
225-
// The parameter offset doesn't matter here - we recalculate it during
226-
// serialization Header.ParameterOffset = 0;
227-
Header.ParameterType = to_underlying(dxbc::RootParameterType::Constants32Bit);
228-
229-
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
230-
Header.ShaderVisibility = *Val;
231-
else
232-
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
235+
Expected<dxbc::ShaderVisibility> Visibility =
236+
extractShaderVisibility(RootConstantNode, 1);
237+
if (auto E = Visibility.takeError())
238+
return Error(std::move(E));
233239

234240
dxbc::RTS0::v1::RootConstants Constants;
235241
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
@@ -247,7 +253,8 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
247253
else
248254
return make_error<InvalidRSMetadataValue>("Num32BitValues");
249255

250-
RSD.ParametersContainer.addParameter(Header, Constants);
256+
RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit,
257+
*Visibility, Constants);
251258

252259
return Error::success();
253260
}
@@ -263,26 +270,26 @@ Error MetadataParser::parseRootDescriptors(
263270
if (RootDescriptorNode->getNumOperands() != 5)
264271
return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
265272

266-
dxbc::RTS0::v1::RootParameterHeader Header;
273+
dxbc::RootParameterType Type;
267274
switch (ElementKind) {
268275
case RootSignatureElementKind::SRV:
269-
Header.ParameterType = to_underlying(dxbc::RootParameterType::SRV);
276+
Type = dxbc::RootParameterType::SRV;
270277
break;
271278
case RootSignatureElementKind::UAV:
272-
Header.ParameterType = to_underlying(dxbc::RootParameterType::UAV);
279+
Type = dxbc::RootParameterType::UAV;
273280
break;
274281
case RootSignatureElementKind::CBV:
275-
Header.ParameterType = to_underlying(dxbc::RootParameterType::CBV);
282+
Type = dxbc::RootParameterType::CBV;
276283
break;
277284
default:
278285
llvm_unreachable("invalid Root Descriptor kind");
279286
break;
280287
}
281288

282-
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
283-
Header.ShaderVisibility = *Val;
284-
else
285-
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
289+
Expected<dxbc::ShaderVisibility> Visibility =
290+
extractShaderVisibility(RootDescriptorNode, 1);
291+
if (auto E = Visibility.takeError())
292+
return Error(std::move(E));
286293

287294
dxbc::RTS0::v2::RootDescriptor Descriptor;
288295
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
@@ -296,7 +303,7 @@ Error MetadataParser::parseRootDescriptors(
296303
return make_error<InvalidRSMetadataValue>("RegisterSpace");
297304

298305
if (RSD.Version == 1) {
299-
RSD.ParametersContainer.addParameter(Header, Descriptor);
306+
RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
300307
return Error::success();
301308
}
302309
assert(RSD.Version > 1);
@@ -306,7 +313,7 @@ Error MetadataParser::parseRootDescriptors(
306313
else
307314
return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
308315

309-
RSD.ParametersContainer.addParameter(Header, Descriptor);
316+
RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
310317
return Error::success();
311318
}
312319

@@ -372,15 +379,12 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
372379
if (NumOperands < 2)
373380
return make_error<InvalidRSMetadataFormat>("Descriptor Table");
374381

375-
dxbc::RTS0::v1::RootParameterHeader Header;
376-
if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
377-
Header.ShaderVisibility = *Val;
378-
else
379-
return make_error<InvalidRSMetadataValue>("ShaderVisibility");
382+
Expected<dxbc::ShaderVisibility> Visibility =
383+
extractShaderVisibility(DescriptorTableNode, 1);
384+
if (auto E = Visibility.takeError())
385+
return Error(std::move(E));
380386

381387
mcdxbc::DescriptorTable Table;
382-
Header.ParameterType =
383-
to_underlying(dxbc::RootParameterType::DescriptorTable);
384388

385389
for (unsigned int I = 2; I < NumOperands; I++) {
386390
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
@@ -392,7 +396,8 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
392396
return Err;
393397
}
394398

395-
RSD.ParametersContainer.addParameter(Header, Table);
399+
RSD.ParametersContainer.addParameter(dxbc::RootParameterType::DescriptorTable,
400+
*Visibility, Table);
396401
return Error::success();
397402
}
398403

@@ -528,20 +533,14 @@ Error MetadataParser::validateRootSignature(
528533
}
529534

530535
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
531-
if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
532-
DeferredErrs =
533-
joinErrors(std::move(DeferredErrs),
534-
make_error<RootSignatureValidationError<uint32_t>>(
535-
"ShaderVisibility", Info.Header.ShaderVisibility));
536-
537-
assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
538-
"Invalid value for ParameterType");
539536

540-
switch (Info.Header.ParameterType) {
537+
switch (Info.Type) {
538+
case dxbc::RootParameterType::Constants32Bit:
539+
break;
541540

542-
case to_underlying(dxbc::RootParameterType::CBV):
543-
case to_underlying(dxbc::RootParameterType::UAV):
544-
case to_underlying(dxbc::RootParameterType::SRV): {
541+
case dxbc::RootParameterType::CBV:
542+
case dxbc::RootParameterType::UAV:
543+
case dxbc::RootParameterType::SRV: {
545544
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
546545
RSD.ParametersContainer.getRootDescriptor(Info.Location);
547546
if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
@@ -566,7 +565,7 @@ Error MetadataParser::validateRootSignature(
566565
}
567566
break;
568567
}
569-
case to_underlying(dxbc::RootParameterType::DescriptorTable): {
568+
case dxbc::RootParameterType::DescriptorTable: {
570569
const mcdxbc::DescriptorTable &Table =
571570
RSD.ParametersContainer.getDescriptorTable(Info.Location);
572571
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {

llvm/lib/MC/DXContainerRootSignature.cpp

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,20 @@ size_t RootSignatureDesc::getSize() const {
3535
StaticSamplers.size() * sizeof(dxbc::RTS0::v1::StaticSampler);
3636

3737
for (const RootParameterInfo &I : ParametersContainer) {
38-
switch (I.Header.ParameterType) {
39-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit):
38+
switch (I.Type) {
39+
case dxbc::RootParameterType::Constants32Bit:
4040
Size += sizeof(dxbc::RTS0::v1::RootConstants);
4141
break;
42-
case llvm::to_underlying(dxbc::RootParameterType::CBV):
43-
case llvm::to_underlying(dxbc::RootParameterType::SRV):
44-
case llvm::to_underlying(dxbc::RootParameterType::UAV):
42+
case dxbc::RootParameterType::CBV:
43+
case dxbc::RootParameterType::SRV:
44+
case dxbc::RootParameterType::UAV:
4545
if (Version == 1)
4646
Size += sizeof(dxbc::RTS0::v1::RootDescriptor);
4747
else
4848
Size += sizeof(dxbc::RTS0::v2::RootDescriptor);
4949

5050
break;
51-
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable):
51+
case dxbc::RootParameterType::DescriptorTable:
5252
const DescriptorTable &Table =
5353
ParametersContainer.getDescriptorTable(I.Location);
5454

@@ -84,23 +84,21 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
8484
support::endian::write(BOS, Flags, llvm::endianness::little);
8585

8686
SmallVector<uint32_t> ParamsOffsets;
87-
for (const RootParameterInfo &P : ParametersContainer) {
88-
support::endian::write(BOS, P.Header.ParameterType,
89-
llvm::endianness::little);
90-
support::endian::write(BOS, P.Header.ShaderVisibility,
91-
llvm::endianness::little);
87+
for (const RootParameterInfo &I : ParametersContainer) {
88+
support::endian::write(BOS, I.Type, llvm::endianness::little);
89+
support::endian::write(BOS, I.Visibility, llvm::endianness::little);
9290

9391
ParamsOffsets.push_back(writePlaceholder(BOS));
9492
}
9593

9694
assert(NumParameters == ParamsOffsets.size());
9795
for (size_t I = 0; I < NumParameters; ++I) {
9896
rewriteOffsetToCurrentByte(BOS, ParamsOffsets[I]);
99-
const auto &[Type, Loc] = ParametersContainer.getTypeAndLocForParameter(I);
100-
switch (Type) {
101-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
97+
const RootParameterInfo &Info = ParametersContainer.getInfo(I);
98+
switch (Info.Type) {
99+
case dxbc::RootParameterType::Constants32Bit: {
102100
const dxbc::RTS0::v1::RootConstants &Constants =
103-
ParametersContainer.getConstant(Loc);
101+
ParametersContainer.getConstant(Info.Location);
104102
support::endian::write(BOS, Constants.ShaderRegister,
105103
llvm::endianness::little);
106104
support::endian::write(BOS, Constants.RegisterSpace,
@@ -109,11 +107,11 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
109107
llvm::endianness::little);
110108
break;
111109
}
112-
case llvm::to_underlying(dxbc::RootParameterType::CBV):
113-
case llvm::to_underlying(dxbc::RootParameterType::SRV):
114-
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
110+
case dxbc::RootParameterType::CBV:
111+
case dxbc::RootParameterType::SRV:
112+
case dxbc::RootParameterType::UAV: {
115113
const dxbc::RTS0::v2::RootDescriptor &Descriptor =
116-
ParametersContainer.getRootDescriptor(Loc);
114+
ParametersContainer.getRootDescriptor(Info.Location);
117115

118116
support::endian::write(BOS, Descriptor.ShaderRegister,
119117
llvm::endianness::little);
@@ -123,9 +121,9 @@ void RootSignatureDesc::write(raw_ostream &OS) const {
123121
support::endian::write(BOS, Descriptor.Flags, llvm::endianness::little);
124122
break;
125123
}
126-
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
124+
case dxbc::RootParameterType::DescriptorTable: {
127125
const DescriptorTable &Table =
128-
ParametersContainer.getDescriptorTable(Loc);
126+
ParametersContainer.getDescriptorTable(Info.Location);
129127
support::endian::write(BOS, (uint32_t)Table.Ranges.size(),
130128
llvm::endianness::little);
131129
rewriteOffsetToCurrentByte(BOS, writePlaceholder(BOS));

llvm/lib/ObjectYAML/DXContainerEmitter.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -275,23 +275,30 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
275275

276276
for (DXContainerYAML::RootParameterLocationYaml &L :
277277
P.RootSignature->Parameters.Locations) {
278-
dxbc::RTS0::v1::RootParameterHeader Header{L.Header.Type, L.Header.Visibility,
279-
L.Header.Offset};
280278

281-
switch (L.Header.Type) {
282-
case llvm::to_underlying(dxbc::RootParameterType::Constants32Bit): {
279+
assert(dxbc::isValidParameterType(L.Header.Type) &&
280+
"invalid DXContainer YAML");
281+
assert(dxbc::isValidShaderVisibility(L.Header.Visibility) &&
282+
"invalid DXContainer YAML");
283+
dxbc::RootParameterType Type = dxbc::RootParameterType(L.Header.Type);
284+
dxbc::ShaderVisibility Visibility =
285+
dxbc::ShaderVisibility(L.Header.Visibility);
286+
287+
switch (Type) {
288+
case dxbc::RootParameterType::Constants32Bit: {
283289
const DXContainerYAML::RootConstantsYaml &ConstantYaml =
284290
P.RootSignature->Parameters.getOrInsertConstants(L);
285291
dxbc::RTS0::v1::RootConstants Constants;
292+
286293
Constants.Num32BitValues = ConstantYaml.Num32BitValues;
287294
Constants.RegisterSpace = ConstantYaml.RegisterSpace;
288295
Constants.ShaderRegister = ConstantYaml.ShaderRegister;
289-
RS.ParametersContainer.addParameter(Header, Constants);
296+
RS.ParametersContainer.addParameter(Type, Visibility, Constants);
290297
break;
291298
}
292-
case llvm::to_underlying(dxbc::RootParameterType::CBV):
293-
case llvm::to_underlying(dxbc::RootParameterType::SRV):
294-
case llvm::to_underlying(dxbc::RootParameterType::UAV): {
299+
case dxbc::RootParameterType::CBV:
300+
case dxbc::RootParameterType::SRV:
301+
case dxbc::RootParameterType::UAV: {
295302
const DXContainerYAML::RootDescriptorYaml &DescriptorYaml =
296303
P.RootSignature->Parameters.getOrInsertDescriptor(L);
297304

@@ -300,10 +307,10 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
300307
Descriptor.ShaderRegister = DescriptorYaml.ShaderRegister;
301308
if (RS.Version > 1)
302309
Descriptor.Flags = DescriptorYaml.getEncodedFlags();
303-
RS.ParametersContainer.addParameter(Header, Descriptor);
310+
RS.ParametersContainer.addParameter(Type, Visibility, Descriptor);
304311
break;
305312
}
306-
case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
313+
case dxbc::RootParameterType::DescriptorTable: {
307314
const DXContainerYAML::DescriptorTableYaml &TableYaml =
308315
P.RootSignature->Parameters.getOrInsertTable(L);
309316
mcdxbc::DescriptorTable Table;
@@ -320,14 +327,9 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
320327
Range.Flags = R.getEncodedFlags();
321328
Table.Ranges.push_back(Range);
322329
}
323-
RS.ParametersContainer.addParameter(Header, Table);
330+
RS.ParametersContainer.addParameter(Type, Visibility, Table);
324331
break;
325332
}
326-
default:
327-
// Handling invalid parameter type edge case. We intentionally let
328-
// obj2yaml/yaml2obj parse and emit invalid dxcontainer data, in order
329-
// for that to be used as a testing tool more effectively.
330-
RS.ParametersContainer.addInvalidParameter(Header);
331333
}
332334
}
333335

0 commit comments

Comments
 (0)