From 5f20b415602edb4718edc9198ec878db3ebb2a7f Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Fri, 6 Jun 2025 21:33:15 +0000 Subject: [PATCH 1/5] nfc: use llvm::EnumEntry to convert Enum to Strings --- .../Frontend/HLSL/HLSLRootSignatureUtils.cpp | 172 +++++++++--------- 1 file changed, 85 insertions(+), 87 deletions(-) diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp index 765a3bcbed7e2..79eee0b12b304 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp @@ -15,112 +15,48 @@ #include "llvm/ADT/bit.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Metadata.h" +#include "llvm/Support/ScopedPrinter.h" namespace llvm { namespace hlsl { namespace rootsig { -static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { - switch (Reg.ViewType) { - case RegisterType::BReg: - OS << "b"; - break; - case RegisterType::TReg: - OS << "t"; - break; - case RegisterType::UReg: - OS << "u"; - break; - case RegisterType::SReg: - OS << "s"; - break; - } - OS << Reg.Number; - return OS; +template +static StringRef getEnumName(const T Value, ArrayRef> Enums) { + for (const auto &EnumItem : Enums) + if (EnumItem.Value == Value) + return EnumItem.Name; + return ""; } -static raw_ostream &operator<<(raw_ostream &OS, - const ShaderVisibility &Visibility) { - switch (Visibility) { - case ShaderVisibility::All: - OS << "All"; - break; - case ShaderVisibility::Vertex: - OS << "Vertex"; - break; - case ShaderVisibility::Hull: - OS << "Hull"; - break; - case ShaderVisibility::Domain: - OS << "Domain"; - break; - case ShaderVisibility::Geometry: - OS << "Geometry"; - break; - case ShaderVisibility::Pixel: - OS << "Pixel"; - break; - case ShaderVisibility::Amplification: - OS << "Amplification"; - break; - case ShaderVisibility::Mesh: - OS << "Mesh"; - break; - } +template +static raw_ostream &printEnum(raw_ostream &OS, const T Value, + ArrayRef> Enums) { + OS << getEnumName(Value, Enums); return OS; } -static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { - switch (Type) { - case ClauseType::CBuffer: - OS << "CBV"; - break; - case ClauseType::SRV: - OS << "SRV"; - break; - case ClauseType::UAV: - OS << "UAV"; - break; - case ClauseType::Sampler: - OS << "Sampler"; - break; - } - - return OS; -} - -static raw_ostream &operator<<(raw_ostream &OS, - const DescriptorRangeFlags &Flags) { +template +static raw_ostream &printFlags(raw_ostream &OS, const T Value, + ArrayRef> Flags) { bool FlagSet = false; - unsigned Remaining = llvm::to_underlying(Flags); + unsigned Remaining = llvm::to_underlying(Value); while (Remaining) { unsigned Bit = 1u << llvm::countr_zero(Remaining); if (Remaining & Bit) { if (FlagSet) OS << " | "; - switch (static_cast(Bit)) { - case DescriptorRangeFlags::DescriptorsVolatile: - OS << "DescriptorsVolatile"; - break; - case DescriptorRangeFlags::DataVolatile: - OS << "DataVolatile"; - break; - case DescriptorRangeFlags::DataStaticWhileSetAtExecute: - OS << "DataStaticWhileSetAtExecute"; - break; - case DescriptorRangeFlags::DataStatic: - OS << "DataStatic"; - break; - case DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks: - OS << "DescriptorsStaticKeepingBufferBoundsChecks"; - break; - default: + bool Found = false; + for (const auto &FlagItem : Flags) + if (FlagItem.Value == T(Bit)) { + OS << FlagItem.Name; + Found = true; + break; + } + if (!Found) OS << "invalid: " << Bit; - break; - } - FlagSet = true; } Remaining &= ~Bit; @@ -128,6 +64,68 @@ static raw_ostream &operator<<(raw_ostream &OS, if (!FlagSet) OS << "None"; + return OS; +} + +static const EnumEntry RegisterNames[] = { + {"b", RegisterType::BReg}, + {"t", RegisterType::TReg}, + {"u", RegisterType::UReg}, + {"s", RegisterType::SReg}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) { + printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames)); + OS << Reg.Number; + + return OS; +} + +static const EnumEntry VisibilityNames[] = { + {"All", ShaderVisibility::All}, + {"Vertex", ShaderVisibility::Vertex}, + {"Hull", ShaderVisibility::Hull}, + {"Domain", ShaderVisibility::Domain}, + {"Geometry", ShaderVisibility::Geometry}, + {"Pixel", ShaderVisibility::Pixel}, + {"Amplification", ShaderVisibility::Amplification}, + {"Mesh", ShaderVisibility::Mesh}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, + const ShaderVisibility &Visibility) { + printEnum(OS, Visibility, ArrayRef(VisibilityNames)); + + return OS; +} + +static const EnumEntry ResourceClassNames[] = { + {"CBV", dxil::ResourceClass::CBuffer}, + {"SRV", dxil::ResourceClass::SRV}, + {"UAV", dxil::ResourceClass::UAV}, + {"Sampler", dxil::ResourceClass::Sampler}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) { + printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)), + ArrayRef(ResourceClassNames)); + + return OS; +} + +static const EnumEntry DescriptorRangeFlagNames[] = { + {"DescriptorsVolatile", DescriptorRangeFlags::DescriptorsVolatile}, + {"DataVolatile", DescriptorRangeFlags::DataVolatile}, + {"DataStaticWhileSetAtExecute", + DescriptorRangeFlags::DataStaticWhileSetAtExecute}, + {"DataStatic", DescriptorRangeFlags::DataStatic}, + {"DescriptorsStaticKeepingBufferBoundsChecks", + DescriptorRangeFlags::DescriptorsStaticKeepingBufferBoundsChecks}, +}; + +static raw_ostream &operator<<(raw_ostream &OS, + const DescriptorRangeFlags &Flags) { + printFlags(OS, Flags, ArrayRef(DescriptorRangeFlagNames)); return OS; } From 02931201d5d7dd11d3e1a2ac7820781103f565b7 Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Fri, 6 Jun 2025 21:43:16 +0000 Subject: [PATCH 2/5] nfc: use getEnumName instead of operator<< --- .../Frontend/HLSL/HLSLRootSignatureUtils.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp index 79eee0b12b304..3d8f90399dfc5 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp @@ -234,12 +234,12 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) { IRBuilder<> Builder(Ctx); - llvm::SmallString<7> Name; - llvm::raw_svector_ostream OS(Name); - OS << "Root" << ClauseType(llvm::to_underlying(Descriptor.Type)); - + StringRef TypeName = + getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)), + ArrayRef(ResourceClassNames)); + llvm::SmallString<7> Name({"Root", TypeName}); Metadata *Operands[] = { - MDString::get(Ctx, OS.str()), + MDString::get(Ctx, Name), ConstantAsMetadata::get( Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))), ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)), @@ -275,12 +275,12 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { MDNode *MetadataBuilder::BuildDescriptorTableClause( const DescriptorTableClause &Clause) { IRBuilder<> Builder(Ctx); - std::string Name; - llvm::raw_string_ostream OS(Name); - OS << Clause.Type; + StringRef Name = + getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)), + ArrayRef(ResourceClassNames)); return MDNode::get( Ctx, { - MDString::get(Ctx, OS.str()), + MDString::get(Ctx, Name), ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)), ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), From d8cc1f9976aba530f882c4883aaa985d1e33f518 Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Fri, 6 Jun 2025 21:44:24 +0000 Subject: [PATCH 3/5] nfc: use operands to fix formatting --- .../Frontend/HLSL/HLSLRootSignatureUtils.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp index 3d8f90399dfc5..ab5ced523996a 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp @@ -278,16 +278,16 @@ MDNode *MetadataBuilder::BuildDescriptorTableClause( StringRef Name = getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)), ArrayRef(ResourceClassNames)); - return MDNode::get( - Ctx, { - MDString::get(Ctx, Name), - ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)), - ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), - ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), - ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), - ConstantAsMetadata::get( - Builder.getInt32(llvm::to_underlying(Clause.Flags))), - }); + Metadata *Operands[] = { + MDString::get(Ctx, Name), + ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)), + ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), + ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), + ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)), + ConstantAsMetadata::get( + Builder.getInt32(llvm::to_underlying(Clause.Flags))), + }; + return MDNode::get(Ctx, Operands); } MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) { From 7ae77bfe65464e9a6b94410cdb7b8d43817547dd Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Mon, 16 Jun 2025 17:03:03 +0000 Subject: [PATCH 4/5] review: use std::optional return value --- .../Frontend/HLSL/HLSLRootSignatureUtils.cpp | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp index ab5ced523996a..6ed495b29b3df 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp @@ -22,18 +22,20 @@ namespace hlsl { namespace rootsig { template -static StringRef getEnumName(const T Value, ArrayRef> Enums) { +static std::optional getEnumName(const T Value, + ArrayRef> Enums) { for (const auto &EnumItem : Enums) if (EnumItem.Value == Value) return EnumItem.Name; - return ""; + return std::nullopt; } template static raw_ostream &printEnum(raw_ostream &OS, const T Value, ArrayRef> Enums) { - OS << getEnumName(Value, Enums); - + auto MaybeName = getEnumName(Value, Enums); + if (MaybeName) + OS << *MaybeName; return OS; } @@ -234,10 +236,11 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) { MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) { IRBuilder<> Builder(Ctx); - StringRef TypeName = + std::optional TypeName = getEnumName(dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)), ArrayRef(ResourceClassNames)); - llvm::SmallString<7> Name({"Root", TypeName}); + assert(TypeName && "Provided an invalid Resource Class"); + llvm::SmallString<7> Name({"Root", *TypeName}); Metadata *Operands[] = { MDString::get(Ctx, Name), ConstantAsMetadata::get( @@ -275,11 +278,12 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) { MDNode *MetadataBuilder::BuildDescriptorTableClause( const DescriptorTableClause &Clause) { IRBuilder<> Builder(Ctx); - StringRef Name = + std::optional Name = getEnumName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)), ArrayRef(ResourceClassNames)); + assert(Name && "Provided an invalid Resource Class"); Metadata *Operands[] = { - MDString::get(Ctx, Name), + MDString::get(Ctx, *Name), ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)), ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)), ConstantAsMetadata::get(Builder.getInt32(Clause.Space)), From 36c3274c62059b07c55eadd889926aa8a4fa541e Mon Sep 17 00:00:00 2001 From: Finn Plummer Date: Mon, 16 Jun 2025 17:06:04 +0000 Subject: [PATCH 5/5] review: re-use getEnumName --- llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp index 6ed495b29b3df..7d744781da04f 100644 --- a/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp +++ b/llvm/lib/Frontend/HLSL/HLSLRootSignatureUtils.cpp @@ -50,15 +50,12 @@ static raw_ostream &printFlags(raw_ostream &OS, const T Value, if (FlagSet) OS << " | "; - bool Found = false; - for (const auto &FlagItem : Flags) - if (FlagItem.Value == T(Bit)) { - OS << FlagItem.Name; - Found = true; - break; - } - if (!Found) + auto MaybeFlag = getEnumName(T(Bit), Flags); + if (MaybeFlag) + OS << *MaybeFlag; + else OS << "invalid: " << Bit; + FlagSet = true; } Remaining &= ~Bit;