diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp index 765a3bcbed7e2..7d744781da04f 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp @@ -15,111 +15,46 @@ #include "llvm/ADT/bit.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Metadata.h" +#include "llvm/Support/ScopedPrinter.h" namespace llvm { namespace hlsl { namespace rootsig { -static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { - switch (Reg.ViewType) { - case RegisterType::BReg: - OS << "b"; - break; - case RegisterType::TReg: - OS << "t"; - break; - case RegisterType::UReg: - OS << "u"; - break; - case RegisterType::SReg: - OS << "s"; - break; - } - OS << Reg.Number; - return OS; +template +static std::optional getEnumName(const T Value, + ArrayRef> Enums) { + for (const auto &EnumItem : Enums) + if (EnumItem.Value == Value) + return EnumItem.Name; + return std::nullopt; } -static raw_ostream &operator<<(raw_ostream &OS, - const ShaderVisibility &Visibility) { - switch (Visibility) { - case ShaderVisibility::All: - OS << "All"; - break; - case ShaderVisibility::Vertex: - OS << "Vertex"; - break; - case ShaderVisibility::Hull: - OS << "Hull"; - break; - case ShaderVisibility::Domain: - OS << "Domain"; - break; - case ShaderVisibility::Geometry: - OS << "Geometry"; - break; - case ShaderVisibility::Pixel: - OS << "Pixel"; - break; - case ShaderVisibility::Amplification: - OS << "Amplification"; - break; - case ShaderVisibility::Mesh: - OS << "Mesh"; - break; - } - - return OS; -} - -static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { - switch (Type) { - case ClauseType::CBuffer: - OS << "CBV"; - break; - case ClauseType::SRV: - OS << "SRV"; - break; - case ClauseType::UAV: - OS << "UAV"; - break; - case ClauseType::Sampler: - OS << "Sampler"; - break; - } - +template +static raw_ostream &printEnum(raw_ostream &OS, const T Value, + ArrayRef> Enums) { + auto MaybeName = getEnumName(Value, Enums); + if (MaybeName) + OS << *MaybeName; return OS; } -static raw_ostream &operator<<(raw_ostream &OS, - const DescriptorRangeFlags &Flags) { +template +static raw_ostream &printFlags(raw_ostream &OS, const T Value, + ArrayRef> Flags) { bool FlagSet = false; - unsigned Remaining = llvm::to_underlying(Flags); + unsigned Remaining = llvm::to_underlying(Value); while (Remaining) { unsigned Bit = 1u << llvm::countr_zero(Remaining); if (Remaining & Bit) { if (FlagSet) OS << " | "; - switch (static_cast(Bit)) { - case DescriptorRangeFlags::DescriptorsVolatile: - OS << "DescriptorsVolatile"; - break; - case DescriptorRangeFlags::DataVolatile: - OS << "DataVolatile"; - break; - case DescriptorRangeFlags::DataStaticWhileSetAtExecute: - OS << "DataStaticWhileSetAtExecute"; - break; - case DescriptorRangeFlags::DataStatic: - OS << "DataStatic"; - break; - case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks: - OS << "DescriptorsStaticKeepingBufferBoundsChecks"; - break; - default: + auto MaybeFlag = getEnumName(T(Bit), Flags); + if (MaybeFlag) + OS << *MaybeFlag; + else OS << "invalid: " << Bit; - break; - } FlagSet = true; } @@ -128,6 +63,68 @@ static raw_ostream &operator<<(raw_ostream &OS, if (!FlagSet) OS << "None"; + return OS; +} + +static const EnumEntry RegisterNames[] = { + {"b", RegisterType::BReg}, + {"t", RegisterType::TReg}, + {"u", RegisterType::UReg}, + {"s", RegisterType::SReg}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { + printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames)); + OS << Reg.Number; + + return OS; +} + +static const EnumEntry VisibilityNames[] = { + {"All", ShaderVisibility::All}, + {"Vertex", ShaderVisibility::Vertex}, + {"Hull", ShaderVisibility::Hull}, + {"Domain", ShaderVisibility::Domain}, + {"Geometry", ShaderVisibility::Geometry}, + {"Pixel", ShaderVisibility::Pixel}, + {"Amplification", ShaderVisibility::Amplification}, + {"Mesh", ShaderVisibility::Mesh}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, + const ShaderVisibility &Visibility) { + printEnum(OS, Visibility, ArrayRef(VisibilityNames)); + + return OS; +} + +static const EnumEntry ResourceClassNames[] = { + {"CBV", dxil::ResourceClass::CBuffer}, + {"SRV", dxil::ResourceClass::SRV}, + {"UAV", dxil::ResourceClass::UAV}, + {"Sampler", dxil::ResourceClass::Sampler}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { + printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)), + ArrayRef(ResourceClassNames)); + + return OS; +} + +static const EnumEntry DescriptorRangeFlagNames[] = { + {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile}, + {"DataVolatile", DescriptorRangeFlags::DataVolatile}, + {"DataStaticWhileSetAtExecute", + DescriptorRangeFlags::DataStaticWhileSetAtExecute}, + {"DataStatic", DescriptorRangeFlags::DataStatic}, + {"DescriptorsStaticKeepingBufferBoundsChecks", + DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, + const DescriptorRangeFlags &Flags) { + printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames)); return OS; } @@ -236,12 +233,13 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) { IRBuilder<> Builder(Ctx); - llvm::SmallString<7> Name; - llvm::raw_svector_ostream OS(Name); - OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type)); - + std::optional TypeName = + getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)), + ArrayRef(ResourceClassNames)); + assert(TypeName && "Provided an invalid Resource Class"); + llvm::SmallString<7> Name({"Root", *TypeName}); Metadata *Operands[] = { - MDString::get(Ctx, OS.str()), + MDString::get(Ctx, Name), ConstantAsMetadata::get( Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)), @@ -277,19 +275,20 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { MDNode *MetadataBuilder::BuildDescriptorTableClause( const DescriptorTableClause &Clause) { IRBuilder<> Builder(Ctx); - std::string Name; - llvm::raw_string_ostream OS(Name); - OS << Clause.Type; - return MDNode::get( - Ctx, { - MDString::get(Ctx, OS.str()), - ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)), - ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), - ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), - ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Clause.Flags))), - }); + std::optional Name = + getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)), + ArrayRef(ResourceClassNames)); + assert(Name && "Provided an invalid Resource Class"); + Metadata *Operands[] = { + MDString::get(Ctx, *Name), + ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)), + ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), + ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), + ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Clause.Flags))), + }; + return MDNode::get(Ctx, Operands); } MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {