From 52018dccfb31a74b85aec67068e9a11545be09af Mon Sep 17 00:00:00 2001 From: Justin Bogner Date: Thu, 15 Aug 2024 16:51:32 +0200 Subject: [PATCH] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20change?= =?UTF-8?q?s=20to=20main=20this=20commit=20is=20based=20on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Created using spr 1.3.5-bogner [skip ci] --- llvm/docs/DirectX/DXILOpTableGenDesign.rst | 28 +- llvm/docs/DirectX/DXILResources.rst | 12 +- llvm/include/llvm/Analysis/DXILResource.h | 8 + llvm/include/llvm/IR/IntrinsicsDirectX.td | 8 + llvm/include/llvm/Support/DXILABI.h | 17 - llvm/lib/Target/DirectX/DXIL.td | 305 +++++++++------ llvm/lib/Target/DirectX/DXILConstants.h | 5 + llvm/lib/Target/DirectX/DXILMetadata.cpp | 5 + llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 165 +++++--- llvm/lib/Target/DirectX/DXILOpBuilder.h | 36 +- llvm/lib/Target/DirectX/DXILOpLowering.cpp | 364 +++++++++++++++--- llvm/lib/Target/DirectX/DXILOpLowering.h | 27 ++ llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp | 58 ++- llvm/lib/Target/DirectX/DXILPrettyPrinter.h | 33 ++ llvm/lib/Target/DirectX/DXILResource.cpp | 39 +- llvm/lib/Target/DirectX/DXILResource.h | 9 +- .../Target/DirectX/DXILResourceAnalysis.cpp | 11 - .../lib/Target/DirectX/DXILResourceAnalysis.h | 15 +- .../Target/DirectX/DXILTranslateMetadata.cpp | 96 +++-- .../Target/DirectX/DXILTranslateMetadata.h | 24 ++ llvm/lib/Target/DirectX/DirectX.h | 8 +- .../Target/DirectX/DirectXPassRegistry.def | 5 +- .../Target/DirectX/DirectXTargetMachine.cpp | 10 +- .../DXILResource/buffer-frombinding.ll | 6 +- llvm/test/CodeGen/DirectX/BufferLoad.ll | 102 +++++ .../CodeGen/DirectX/BufferStore-errors.ll | 34 ++ llvm/test/CodeGen/DirectX/BufferStore.ll | 92 +++++ llvm/test/CodeGen/DirectX/CreateHandle.ll | 61 +++ .../DirectX/CreateHandleFromBinding.ll | 65 ++++ .../CodeGen/DirectX/Metadata/dxilVer-1.0.ll | 2 +- .../CodeGen/DirectX/Metadata/dxilVer-1.8.ll | 2 +- llvm/test/CodeGen/DirectX/UAVMetadata.ll | 2 +- llvm/test/CodeGen/DirectX/any.ll | 2 +- llvm/test/CodeGen/DirectX/cbuf.ll | 2 +- llvm/test/CodeGen/DirectX/floor.ll | 2 +- llvm/utils/TableGen/DXILEmitter.cpp | 180 +++------ 36 files changed, 1304 insertions(+), 536 deletions(-) create mode 100644 llvm/lib/Target/DirectX/DXILOpLowering.h create mode 100644 llvm/lib/Target/DirectX/DXILPrettyPrinter.h create mode 100644 llvm/lib/Target/DirectX/DXILTranslateMetadata.h create mode 100644 llvm/test/CodeGen/DirectX/BufferLoad.ll create mode 100644 llvm/test/CodeGen/DirectX/BufferStore-errors.ll create mode 100644 llvm/test/CodeGen/DirectX/BufferStore.ll create mode 100644 llvm/test/CodeGen/DirectX/CreateHandle.ll create mode 100644 llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll diff --git a/llvm/docs/DirectX/DXILOpTableGenDesign.rst b/llvm/docs/DirectX/DXILOpTableGenDesign.rst index 50d801bd05efd..46106ee2a50f5 100644 --- a/llvm/docs/DirectX/DXILOpTableGenDesign.rst +++ b/llvm/docs/DirectX/DXILOpTableGenDesign.rst @@ -93,18 +93,14 @@ properties are specified as fields of the ``DXILOp`` class as described below. class DXILOpClass; Concrete operation records, such as ``unary`` are defined by inheriting from ``DXILOpClass``. -6. Return type of the operation is represented as ``LLVMType``. -7. Operation arguments are represented as a list of ``LLVMType`` with each type - corresponding to the argument position. An overload type, if supported by the operation, is - denoted as the positional type ``overloadTy`` in the argument or in the result, where - ``overloadTy`` is defined to be synonymous to ``llvm_any_ty``. - - .. code-block:: - - defvar overloadTy = llvm_any_ty - - Empty list, ``[]`` represents an operation with no arguments. - +6. A set of type names are defined that represent return and argument types, + which all inherit from ``DXILOpParamType``. These represent simple types + like ``int32Ty``, DXIL types like ``dx.types.Handle``, and a special + ``overloadTy`` which can be any type allowed by ``Overloads``, described + below. +7. Operation return type is represented as a ``DXILOpParamType``, and arguments + are represented as a list of the same. An operation with no return value + shall specify ``VoidTy`` as its return. 8. Valid operation overload types predicated on DXIL version are specified as a list of ``Overloads`` records. Representation of ``Overloads`` class is described in a later section. @@ -145,10 +141,10 @@ TableGen representations of its properties described above. Intrinsic LLVMIntrinsic = ?; // Result type of the op. - LLVMType result; + DXILOpParamType result; // List of argument types of the op. Default to 0 arguments. - list arguments = []; + list arguments = []; // List of valid overload types predicated by DXIL version list overloads; @@ -233,9 +229,9 @@ overloads predicated on DXIL version as list of records of the following class .. code-block:: - class Overloads ols> { + class Overloads ols> { Version dxil_version = minver; - list overload_types = ols; + list overload_types = ols; } Following is an example specification of valid overload types for ``DXIL1_0`` and diff --git a/llvm/docs/DirectX/DXILResources.rst b/llvm/docs/DirectX/DXILResources.rst index aef88bc43b224..0df2f51318977 100644 --- a/llvm/docs/DirectX/DXILResources.rst +++ b/llvm/docs/DirectX/DXILResources.rst @@ -162,6 +162,10 @@ the subsequent ``dx.op.annotateHandle`` operation in. Note that we don't have an analogue for `dx.op.createHandle`_, since ``dx.op.createHandleFromBinding`` subsumes it. +For simplicity of lowering, We match DXIL in using an index from the beginning +of the binding space rather than an index from the lower bound of the binding +itself. + .. _dx.op.createHandle: https://github.com/microsoft/DirectXShaderCompiler/blob/main/docs/DXIL.rst#resource-handles .. list-table:: ``@llvm.dx.handle.fromBinding`` @@ -190,7 +194,7 @@ subsumes it. * - ``%index`` - 4 - ``i32`` - - Index of the resource to access. + - Index from the beginning of the binding space to access. * - ``%non-uniform`` - 5 - i1 @@ -365,11 +369,11 @@ Examples: .. code-block:: llvm - call void @llvm.dx.bufferStore.tdx.Buffer_f32_1_0t( + call void @llvm.dx.typedBufferStore.tdx.Buffer_v4f32_1_0_0t( target("dx.TypedBuffer", f32, 1, 0) %buf, i32 %index, <4 x f32> %data) - call void @llvm.dx.bufferStore.tdx.Buffer_f16_1_0t( + call void @llvm.dx.typedBufferStore.tdx.Buffer_v4f16_1_0_0t( target("dx.TypedBuffer", f16, 1, 0) %buf, i32 %index, <4 x f16> %data) - call void @llvm.dx.bufferStore.tdx.Buffer_f64_1_0t( + call void @llvm.dx.typedBufferStore.tdx.Buffer_v2f64_1_0_0t( target("dx.TypedBuffer", f64, 1, 0) %buf, i32 %index, <2 x f64> %data) .. list-table:: ``@llvm.dx.rawBufferPtr`` diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h index 3ba0ae5de61d5..2ed508b28a908 100644 --- a/llvm/include/llvm/Analysis/DXILResource.h +++ b/llvm/include/llvm/Analysis/DXILResource.h @@ -23,6 +23,7 @@ class TargetExtType; namespace dxil { class ResourceInfo { +public: struct ResourceBinding { uint32_t RecordID; uint32_t Space; @@ -89,6 +90,7 @@ class ResourceInfo { bool operator!=(const FeedbackInfo &RHS) const { return !(*this == RHS); } }; +private: // Universal properties. Value *Symbol; StringRef Name; @@ -115,6 +117,10 @@ class ResourceInfo { MSInfo MultiSample; + // We need a default constructor if we want to insert this in a MapVector. + ResourceInfo() {} + friend class MapVector; + public: ResourceInfo(dxil::ResourceClass RC, dxil::ResourceKind Kind, Value *Symbol, StringRef Name) @@ -166,6 +172,8 @@ class ResourceInfo { MultiSample.Count = Count; } + dxil::ResourceClass getResourceClass() const { return RC; } + bool operator==(const ResourceInfo &RHS) const; static ResourceInfo SRV(Value *Symbol, StringRef Name, diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index c9102aa3dd972..67351ad8f9b91 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -30,6 +30,14 @@ def int_dx_handle_fromBinding [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty], [IntrNoMem]>; +def int_dx_typedBufferLoad + : DefaultAttrsIntrinsic<[llvm_anyvector_ty], [llvm_any_ty, llvm_i32_ty]>; +def int_dx_typedBufferStore + : DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_anyvector_ty]>; + +// Cast between target extension handle types and dxil-style opaque handles +def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>; + def int_dx_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; diff --git a/llvm/include/llvm/Support/DXILABI.h b/llvm/include/llvm/Support/DXILABI.h index a2222eec09ba8..cf2c42c689889 100644 --- a/llvm/include/llvm/Support/DXILABI.h +++ b/llvm/include/llvm/Support/DXILABI.h @@ -22,23 +22,6 @@ namespace llvm { namespace dxil { -enum class ParameterKind : uint8_t { - Invalid = 0, - Void, - Half, - Float, - Double, - I1, - I8, - I16, - I32, - I64, - Overload, - CBufferRet, - ResourceRet, - DXILHandle, -}; - enum class ResourceClass : uint8_t { SRV = 0, UAV, diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 67015cff78a79..50eff20e810d7 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -24,19 +24,29 @@ foreach i = 0...8 in { def DXIL1_ #i : Version<1, i>; } -// Overload type alias of llvm_any_ty -defvar overloadTy = llvm_any_ty; - -// Type aliases for DXIL Op types to LLVM Types. -// TODO: Define DXIL Types independent of LLVM types -defvar i1Ty = llvm_i1_ty; -defvar i8Ty = llvm_i8_ty; -defvar i16Ty = llvm_i16_ty; -defvar i32Ty = llvm_i32_ty; -defvar i64Ty = llvm_i64_ty; -defvar halfTy = llvm_half_ty; -defvar floatTy = llvm_float_ty; -defvar doubleTy = llvm_double_ty; +class DXILOpParamType { + int isOverload = 0; +} + +let isOverload = 1 in { + def OverloadTy : DXILOpParamType; +} +def VoidTy : DXILOpParamType; +def Int1Ty : DXILOpParamType; +def Int8Ty : DXILOpParamType; +def Int16Ty : DXILOpParamType; +def Int32Ty : DXILOpParamType; +def Int64Ty : DXILOpParamType; +def HalfTy : DXILOpParamType; +def FloatTy : DXILOpParamType; +def DoubleTy : DXILOpParamType; +def ResRetHalfTy : DXILOpParamType; +def ResRetFloatTy : DXILOpParamType; +def ResRetInt16Ty : DXILOpParamType; +def ResRetInt32Ty : DXILOpParamType; +def HandleTy : DXILOpParamType; +def ResBindTy : DXILOpParamType; +def ResPropsTy : DXILOpParamType; class DXILOpClass; @@ -268,9 +278,9 @@ def IsWave : DXILAttribute; def NeedsUniformInputs : DXILAttribute; def IsBarrier : DXILAttribute; -class Overloads ols> { +class Overloads ols> { Version dxil_version = ver; - list overload_types = ols; + list overload_types = ols; } class Stages st> { @@ -298,10 +308,10 @@ class DXILOp { Intrinsic LLVMIntrinsic = ?; // Result type of the op - LLVMType result; + DXILOpParamType result; // List of argument types of the op. Default to 0 arguments. - list arguments = []; + list arguments = []; // List of valid overload types predicated by DXIL version list overloads = []; @@ -318,9 +328,9 @@ class DXILOp { def Abs : DXILOp<6, unary> { let Doc = "Returns the absolute value of the input."; let LLVMIntrinsic = int_fabs; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -328,9 +338,9 @@ def Abs : DXILOp<6, unary> { def IsInf : DXILOp<9, isSpecialFloat> { let Doc = "Determines if the specified value is infinite."; let LLVMIntrinsic = int_dx_isinf; - let arguments = [overloadTy]; - let result = i1Ty; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = Int1Ty; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -338,9 +348,9 @@ def IsInf : DXILOp<9, isSpecialFloat> { def Cos : DXILOp<12, unary> { let Doc = "Returns cosine(theta) for theta in radians."; let LLVMIntrinsic = int_cos; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -348,9 +358,9 @@ def Cos : DXILOp<12, unary> { def Sin : DXILOp<13, unary> { let Doc = "Returns sine(theta) for theta in radians."; let LLVMIntrinsic = int_sin; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -358,9 +368,9 @@ def Sin : DXILOp<13, unary> { def Tan : DXILOp<14, unary> { let Doc = "Returns tangent(theta) for theta in radians."; let LLVMIntrinsic = int_tan; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -368,9 +378,9 @@ def Tan : DXILOp<14, unary> { def ACos : DXILOp<15, unary> { let Doc = "Returns the arccosine of the specified value."; let LLVMIntrinsic = int_acos; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -378,9 +388,9 @@ def ACos : DXILOp<15, unary> { def ASin : DXILOp<16, unary> { let Doc = "Returns the arcsine of the specified value."; let LLVMIntrinsic = int_asin; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -388,9 +398,9 @@ def ASin : DXILOp<16, unary> { def ATan : DXILOp<17, unary> { let Doc = "Returns the arctangent of the specified value."; let LLVMIntrinsic = int_atan; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -398,9 +408,9 @@ def ATan : DXILOp<17, unary> { def HCos : DXILOp<18, unary> { let Doc = "Returns the hyperbolic cosine of the specified value."; let LLVMIntrinsic = int_cosh; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -408,9 +418,9 @@ def HCos : DXILOp<18, unary> { def HSin : DXILOp<19, unary> { let Doc = "Returns the hyperbolic sine of the specified value."; let LLVMIntrinsic = int_sinh; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -418,9 +428,9 @@ def HSin : DXILOp<19, unary> { def HTan : DXILOp<20, unary> { let Doc = "Returns the hyperbolic tan of the specified value."; let LLVMIntrinsic = int_tanh; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -429,9 +439,9 @@ def Exp2 : DXILOp<21, unary> { let Doc = "Returns the base 2 exponential, or 2**x, of the specified value. " "exp2(x) = 2**x."; let LLVMIntrinsic = int_exp2; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -440,9 +450,9 @@ def Frac : DXILOp<22, unary> { let Doc = "Returns a fraction from 0 to 1 that represents the decimal part " "of the input."; let LLVMIntrinsic = int_dx_frac; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -450,9 +460,9 @@ def Frac : DXILOp<22, unary> { def Log2 : DXILOp<23, unary> { let Doc = "Returns the base-2 logarithm of the specified value."; let LLVMIntrinsic = int_log2; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -461,9 +471,9 @@ def Sqrt : DXILOp<24, unary> { let Doc = "Returns the square root of the specified floating-point value, " "per component."; let LLVMIntrinsic = int_sqrt; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -472,9 +482,9 @@ def RSqrt : DXILOp<25, unary> { let Doc = "Returns the reciprocal of the square root of the specified value. " "rsqrt(x) = 1 / sqrt(x)."; let LLVMIntrinsic = int_dx_rsqrt; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -483,9 +493,9 @@ def Round : DXILOp<26, unary> { let Doc = "Returns the input rounded to the nearest integer within a " "floating-point type."; let LLVMIntrinsic = int_roundeven; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -494,9 +504,9 @@ def Floor : DXILOp<27, unary> { let Doc = "Returns the largest integer that is less than or equal to the input."; let LLVMIntrinsic = int_floor; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -505,9 +515,9 @@ def Ceil : DXILOp<28, unary> { let Doc = "Returns the smallest integer that is greater than or equal to the " "input."; let LLVMIntrinsic = int_ceil; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -515,9 +525,9 @@ def Ceil : DXILOp<28, unary> { def Trunc : DXILOp<29, unary> { let Doc = "Returns the specified value truncated to the integer component."; let LLVMIntrinsic = int_trunc; - let arguments = [overloadTy]; - let result = overloadTy; - let overloads = [Overloads]; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -525,10 +535,10 @@ def Trunc : DXILOp<29, unary> { def Rbits : DXILOp<30, unary> { let Doc = "Returns the specified value with its bits reversed."; let LLVMIntrinsic = int_bitreverse; - let arguments = [overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -536,10 +546,10 @@ def Rbits : DXILOp<30, unary> { def FMax : DXILOp<35, binary> { let Doc = "Float maximum. FMax(a,b) = a > b ? a : b"; let LLVMIntrinsic = int_maxnum; - let arguments = [overloadTy, overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy, OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -547,10 +557,10 @@ def FMax : DXILOp<35, binary> { def FMin : DXILOp<36, binary> { let Doc = "Float minimum. FMin(a,b) = a < b ? a : b"; let LLVMIntrinsic = int_minnum; - let arguments = [overloadTy, overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy, OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -558,10 +568,10 @@ def FMin : DXILOp<36, binary> { def SMax : DXILOp<37, binary> { let Doc = "Signed integer maximum. SMax(a,b) = a > b ? a : b"; let LLVMIntrinsic = int_smax; - let arguments = [overloadTy, overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy, OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -569,10 +579,10 @@ def SMax : DXILOp<37, binary> { def SMin : DXILOp<38, binary> { let Doc = "Signed integer minimum. SMin(a,b) = a < b ? a : b"; let LLVMIntrinsic = int_smin; - let arguments = [overloadTy, overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy, OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -580,10 +590,10 @@ def SMin : DXILOp<38, binary> { def UMax : DXILOp<39, binary> { let Doc = "Unsigned integer maximum. UMax(a,b) = a > b ? a : b"; let LLVMIntrinsic = int_umax; - let arguments = [overloadTy, overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy, OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -591,10 +601,10 @@ def UMax : DXILOp<39, binary> { def UMin : DXILOp<40, binary> { let Doc = "Unsigned integer minimum. UMin(a,b) = a < b ? a : b"; let LLVMIntrinsic = int_umin; - let arguments = [overloadTy, overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy, OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -603,10 +613,10 @@ def FMad : DXILOp<46, tertiary> { let Doc = "Floating point arithmetic multiply/add operation. fmad(m,a,b) = m " "* a + b."; let LLVMIntrinsic = int_fmuladd; - let arguments = [overloadTy, overloadTy, overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy, OverloadTy, OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -615,10 +625,10 @@ def IMad : DXILOp<48, tertiary> { let Doc = "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m " "* a + b."; let LLVMIntrinsic = int_dx_imad; - let arguments = [overloadTy, overloadTy, overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy, OverloadTy, OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -627,10 +637,10 @@ def UMad : DXILOp<49, tertiary> { let Doc = "Unsigned integer arithmetic multiply/add operation. umad(m,a, = m " "* a + b."; let LLVMIntrinsic = int_dx_umad; - let arguments = [overloadTy, overloadTy, overloadTy]; - let result = overloadTy; + let arguments = [OverloadTy, OverloadTy, OverloadTy]; + let result = OverloadTy; let overloads = - [Overloads]; + [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -639,9 +649,9 @@ def Dot2 : DXILOp<54, dot2> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " "a[n]*b[n] where n is between 0 and 1"; let LLVMIntrinsic = int_dx_dot2; - let arguments = !listsplat(overloadTy, 4); - let result = overloadTy; - let overloads = [Overloads]; + let arguments = !listsplat(OverloadTy, 4); + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -650,9 +660,9 @@ def Dot3 : DXILOp<55, dot3> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " "a[n]*b[n] where n is between 0 and 2"; let LLVMIntrinsic = int_dx_dot3; - let arguments = !listsplat(overloadTy, 6); - let result = overloadTy; - let overloads = [Overloads]; + let arguments = !listsplat(OverloadTy, 6); + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -661,18 +671,50 @@ def Dot4 : DXILOp<56, dot4> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " "a[n]*b[n] where n is between 0 and 3"; let LLVMIntrinsic = int_dx_dot4; - let arguments = !listsplat(overloadTy, 8); - let result = overloadTy; - let overloads = [Overloads]; + let arguments = !listsplat(OverloadTy, 8); + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } +def CreateHandle : DXILOp<57, createHandle> { + let Doc = "creates the handle to a resource"; + // ResourceClass, RangeID, Index, NonUniform + let arguments = [Int8Ty, Int32Ty, Int32Ty, Int1Ty]; + let result = HandleTy; + let stages = [Stages]; +} + +def BufferLoad : DXILOp<68, bufferLoad> { + let Doc = "reads from a TypedBuffer"; + // Handle, Coord0, Coord1 + let arguments = [HandleTy, Int32Ty, Int32Ty]; + let result = OverloadTy; + let overloads = + [Overloads]; + let stages = [Stages]; +} + +def BufferStore : DXILOp<69, bufferStore> { + let Doc = "writes to an RWTypedBuffer"; + // Handle, Coord0, Coord1, Val0, Val1, Val2, Val3, Mask + let arguments = [ + HandleTy, Int32Ty, Int32Ty, OverloadTy, OverloadTy, OverloadTy, OverloadTy, + Int8Ty + ]; + let result = VoidTy; + let overloads = [Overloads]; + let stages = [Stages]; +} + def ThreadId : DXILOp<93, threadId> { let Doc = "Reads the thread ID"; let LLVMIntrinsic = int_dx_thread_id; - let arguments = [i32Ty]; - let result = i32Ty; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -680,8 +722,9 @@ def ThreadId : DXILOp<93, threadId> { def GroupId : DXILOp<94, groupId> { let Doc = "Reads the group ID (SV_GroupID)"; let LLVMIntrinsic = int_dx_group_id; - let arguments = [i32Ty]; - let result = i32Ty; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -689,8 +732,9 @@ def GroupId : DXILOp<94, groupId> { def ThreadIdInGroup : DXILOp<95, threadIdInGroup> { let Doc = "Reads the thread ID within the group (SV_GroupThreadID)"; let LLVMIntrinsic = int_dx_thread_id_in_group; - let arguments = [i32Ty]; - let result = i32Ty; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -699,7 +743,22 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> { let Doc = "Provides a flattened index for a given thread within a given " "group (SV_GroupIndex)"; let LLVMIntrinsic = int_dx_flattened_thread_id_in_group; - let result = i32Ty; + let result = OverloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } + +def AnnotateHandle : DXILOp<217, annotateHandle> { + let Doc = "annotate handle with resource properties"; + let arguments = [HandleTy, ResPropsTy]; + let result = HandleTy; + let stages = [Stages]; +} + +def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> { + let Doc = "create resource handle from binding"; + let arguments = [ResBindTy, Int32Ty, Int1Ty]; + let result = HandleTy; + let stages = [Stages]; +} diff --git a/llvm/lib/Target/DirectX/DXILConstants.h b/llvm/lib/Target/DirectX/DXILConstants.h index 0c9c1ac38fdbc..022cd57795a06 100644 --- a/llvm/lib/Target/DirectX/DXILConstants.h +++ b/llvm/lib/Target/DirectX/DXILConstants.h @@ -25,6 +25,11 @@ enum class OpCodeClass : unsigned { #include "DXILOperation.inc" }; +enum class OpParamType : unsigned { +#define DXIL_OP_PARAM_TYPE(Name) Name, +#include "DXILOperation.inc" +}; + } // namespace dxil } // namespace llvm diff --git a/llvm/lib/Target/DirectX/DXILMetadata.cpp b/llvm/lib/Target/DirectX/DXILMetadata.cpp index ed0434ac98a18..1f5759c363013 100644 --- a/llvm/lib/Target/DirectX/DXILMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILMetadata.cpp @@ -284,6 +284,11 @@ void dxil::createEntryMD(Module &M, const uint64_t ShaderFlags) { EntryList.emplace_back(&F); } + // If there are no entries, do nothing. This is mostly to allow for writing + // tests with no actual entry functions. + if (EntryList.empty()) + return; + auto &Ctx = M.getContext(); // FIXME: generate metadata for resource. // See https://github.com/llvm/llvm-project/issues/57926. diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 42df7c90cb337..1594fa533379b 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -11,7 +11,6 @@ #include "DXILOpBuilder.h" #include "DXILConstants.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Support/DXILABI.h" #include "llvm/Support/ErrorHandling.h" @@ -87,6 +86,9 @@ static const char *getOverloadTypeName(OverloadKind Kind) { } static OverloadKind getOverloadKind(Type *Ty) { + if (!Ty) + return OverloadKind::VOID; + Type::TypeID T = Ty->getTypeID(); switch (T) { case Type::VoidTyID: @@ -118,8 +120,15 @@ static OverloadKind getOverloadKind(Type *Ty) { } case Type::PointerTyID: return OverloadKind::UserDefineType; - case Type::StructTyID: + case Type::StructTyID: { + // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework + // how we're handling overloads and remove the `OverloadKind` proxy enum. + StructType *ST = cast(Ty); + if (ST->hasName() && ST->getName().starts_with("dx.types.ResRet")) + return getOverloadKind(ST->getElementType(0)); + return OverloadKind::ObjectType; + } default: llvm_unreachable("invalid overload type"); return OverloadKind::VOID; @@ -154,10 +163,8 @@ struct OpCodeProperty { llvm::SmallVector Overloads; llvm::SmallVector Stages; llvm::SmallVector Attributes; - int OverloadParamIndex; // parameter index which control the overload. - // When < 0, should be only 1 overload type. - unsigned NumOfParameters; // Number of parameters include return value. - unsigned ParameterTableOffset; // Offset in ParameterTable. + int OverloadParamIndex; // parameter index which control the overload. + // When < 0, should be only 1 overload type. }; // Include getOpCodeClassName getOpCodeProperty, getOpCodeName and @@ -195,10 +202,11 @@ static StructType *getOrCreateStructType(StringRef Name, return StructType::create(Ctx, EltTys, Name); } -static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { - OverloadKind Kind = getOverloadKind(OverloadTy); +static StructType *getResRetType(Type *ElementTy) { + LLVMContext &Ctx = ElementTy->getContext(); + OverloadKind Kind = getOverloadKind(ElementTy); std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); - Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, + Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy, Type::getInt32Ty(Ctx)}; return getOrCreateStructType(TypeName, FieldTypes, Ctx); } @@ -208,35 +216,60 @@ static StructType *getHandleType(LLVMContext &Ctx) { Ctx); } -static Type *getTypeFromParameterKind(ParameterKind Kind, LLVMContext &Ctx, - Type *OverloadTy) { +static StructType *getResBindType(LLVMContext &Context) { + if (auto *ST = StructType::getTypeByName(Context, "dx.types.ResBind")) + return ST; + Type *Int32Ty = Type::getInt32Ty(Context); + Type *Int8Ty = Type::getInt8Ty(Context); + return StructType::create({Int32Ty, Int32Ty, Int32Ty, Int8Ty}, + "dx.types.ResBind"); +} + +static StructType *getResPropsType(LLVMContext &Context) { + if (auto *ST = + StructType::getTypeByName(Context, "dx.types.ResourceProperties")) + return ST; + Type *Int32Ty = Type::getInt32Ty(Context); + return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties"); +} + +static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx, + Type *OverloadTy) { switch (Kind) { - case ParameterKind::Void: + case OpParamType::VoidTy: return Type::getVoidTy(Ctx); - case ParameterKind::Half: + case OpParamType::HalfTy: return Type::getHalfTy(Ctx); - case ParameterKind::Float: + case OpParamType::FloatTy: return Type::getFloatTy(Ctx); - case ParameterKind::Double: + case OpParamType::DoubleTy: return Type::getDoubleTy(Ctx); - case ParameterKind::I1: + case OpParamType::Int1Ty: return Type::getInt1Ty(Ctx); - case ParameterKind::I8: + case OpParamType::Int8Ty: return Type::getInt8Ty(Ctx); - case ParameterKind::I16: + case OpParamType::Int16Ty: return Type::getInt16Ty(Ctx); - case ParameterKind::I32: + case OpParamType::Int32Ty: return Type::getInt32Ty(Ctx); - case ParameterKind::I64: + case OpParamType::Int64Ty: return Type::getInt64Ty(Ctx); - case ParameterKind::Overload: + case OpParamType::OverloadTy: return OverloadTy; - case ParameterKind::ResourceRet: - return getResRetType(OverloadTy, Ctx); - case ParameterKind::DXILHandle: + case OpParamType::ResRetHalfTy: + return getResRetType(Type::getHalfTy(Ctx)); + case OpParamType::ResRetFloatTy: + return getResRetType(Type::getFloatTy(Ctx)); + case OpParamType::ResRetInt16Ty: + return getResRetType(Type::getInt16Ty(Ctx)); + case OpParamType::ResRetInt32Ty: + return getResRetType(Type::getInt32Ty(Ctx)); + case OpParamType::HandleTy: return getHandleType(Ctx); - default: - break; + case OpParamType::ResBindTy: + return getResBindType(Ctx); + case OpParamType::ResPropsTy: + return getResPropsType(Ctx); } llvm_unreachable("Invalid parameter kind"); return nullptr; @@ -281,30 +314,34 @@ static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType) { "Shader Kind Not Found - Invalid DXIL Environment Specified"); } +static SmallVector +getArgTypesFromOpParamTypes(ArrayRef Types, + LLVMContext &Context, Type *OverloadTy) { + SmallVector ArgTys; + ArgTys.emplace_back(Type::getInt32Ty(Context)); + for (dxil::OpParamType Ty : Types) + ArgTys.emplace_back(getTypeFromOpParamType(Ty, Context, OverloadTy)); + return ArgTys; +} + /// Construct DXIL function type. This is the type of a function with /// the following prototype /// OverloadType dx.op..(int opcode, ) /// are constructed from types in Prop. -static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, +static FunctionType *getDXILOpFunctionType(dxil::OpCode OpCode, LLVMContext &Context, Type *OverloadTy) { - SmallVector ArgTys; - - const ParameterKind *ParamKinds = getOpCodeParameterKind(*Prop); - assert(Prop->NumOfParameters && "No return type?"); - // Add return type of the function - Type *ReturnTy = getTypeFromParameterKind(ParamKinds[0], Context, OverloadTy); - - // Add DXIL Opcode value type viz., Int32 as first argument - ArgTys.emplace_back(Type::getInt32Ty(Context)); - - // Add DXIL Operation parameter types as specified in DXIL properties - for (unsigned I = 1; I < Prop->NumOfParameters; ++I) { - ParameterKind Kind = ParamKinds[I]; - ArgTys.emplace_back(getTypeFromParameterKind(Kind, Context, OverloadTy)); + switch (OpCode) { +#define DXIL_OP_FUNCTION_TYPE(OpCode, RetType, ...) \ + case OpCode: \ + return FunctionType::get( \ + getTypeFromOpParamType(RetType, Context, OverloadTy), \ + getArgTypesFromOpParamTypes({__VA_ARGS__}, Context, OverloadTy), \ + /*isVarArg=*/false); +#include "DXILOperation.inc" } - return FunctionType::get(ReturnTy, ArgTys, /*isVarArg=*/false); + llvm_unreachable("Invalid OpCode?"); } /// Get index of the property from PropList valid for the most recent @@ -332,7 +369,7 @@ namespace dxil { // Triple is well-formed or that the target is supported since these checks // would have been done at the time the module M is constructed in the earlier // stages of compilation. -DXILOpBuilder::DXILOpBuilder(Module &M, IRBuilderBase &B) : M(M), B(B) { +DXILOpBuilder::DXILOpBuilder(Module &M) : M(M), IRB(M.getContext()) { Triple TT(Triple(M.getTargetTriple())); DXILVersion = TT.getDXILVersion(); ShaderStage = TT.getEnvironment(); @@ -368,8 +405,9 @@ Expected DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode, return makeOpError(OpCode, "Wrong number of arguments"); OverloadTy = Args[ArgIndex]->getType(); } + FunctionType *DXILOpFT = - getDXILOpFunctionType(Prop, M.getContext(), OverloadTy); + getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy); std::optional OlIndexOrErr = getPropIndex(ArrayRef(Prop->Overloads), DXILVersion); @@ -379,11 +417,7 @@ Expected DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode, uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys; - // If we don't have an overload type, use the function's return type. This is - // a bit of a hack, but it's necessary to get the type suffix on unoverloaded - // DXIL ops correct, like `dx.op.threadId.i32`. - OverloadKind Kind = - getOverloadKind(OverloadTy ? OverloadTy : DXILOpFT->getReturnType()); + OverloadKind Kind = getOverloadKind(OverloadTy); // Check if the operation supports overload types and OverloadTy is valid // per the specified types for the operation @@ -418,13 +452,13 @@ Expected DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode, // We need to inject the opcode as the first argument. SmallVector OpArgs; - OpArgs.push_back(B.getInt32(llvm::to_underlying(OpCode))); + OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode))); OpArgs.append(Args.begin(), Args.end()); - return B.CreateCall(DXILFn, OpArgs); + return IRB.CreateCall(DXILFn, OpArgs); } -CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef &Args, +CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef Args, Type *RetTy) { Expected Result = tryCreateOp(OpCode, Args, RetTy); if (Error E = Result.takeError()) @@ -432,6 +466,33 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef &Args, return *Result; } +StructType *DXILOpBuilder::getResRetType(Type *ElementTy) { + return ::getResRetType(ElementTy); +} + +StructType *DXILOpBuilder::getHandleType() { + return ::getHandleType(IRB.getContext()); +} + +Constant *DXILOpBuilder::getResBind(uint32_t LowerBound, uint32_t UpperBound, + uint32_t SpaceID, dxil::ResourceClass RC) { + Type *Int32Ty = IRB.getInt32Ty(); + Type *Int8Ty = IRB.getInt8Ty(); + return ConstantStruct::get( + getResBindType(IRB.getContext()), + {ConstantInt::get(Int32Ty, LowerBound), + ConstantInt::get(Int32Ty, UpperBound), + ConstantInt::get(Int32Ty, SpaceID), + ConstantInt::get(Int8Ty, llvm::to_underlying(RC))}); +} + +Constant *DXILOpBuilder::getResProps(uint32_t Word0, uint32_t Word1) { + Type *Int32Ty = IRB.getInt32Ty(); + return ConstantStruct::get( + getResPropsType(IRB.getContext()), + {ConstantInt::get(Int32Ty, Word0), ConstantInt::get(Int32Ty, Word1)}); +} + const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) { return ::getOpCodeName(DXILOp); } diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h index ff66f39a3ceb3..a68f0c43f67af 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.h +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h @@ -14,13 +14,16 @@ #include "DXILConstants.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/TargetParser/Triple.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/DXILABI.h" #include "llvm/Support/Error.h" +#include "llvm/TargetParser/Triple.h" namespace llvm { class Module; class IRBuilderBase; class CallInst; +class Constant; class Value; class Type; class FunctionType; @@ -29,29 +32,30 @@ namespace dxil { class DXILOpBuilder { public: - DXILOpBuilder(Module &M, IRBuilderBase &B); + DXILOpBuilder(Module &M); + + IRBuilder<> &getIRB() { return IRB; } /// Create a call instruction for the given DXIL op. The arguments /// must be valid for an overload of the operation. - CallInst *createOp(dxil::OpCode Op, ArrayRef &Args, + CallInst *createOp(dxil::OpCode Op, ArrayRef Args, Type *RetTy = nullptr); -#define DXIL_OPCODE(Op, Name) \ - CallInst *create##Name##Op(ArrayRef &Args, Type *RetTy = nullptr) { \ - return createOp(dxil::OpCode(Op), Args, RetTy); \ - } -#include "DXILOperation.inc" - /// Try to create a call instruction for the given DXIL op. Fails if the /// overload is invalid. Expected tryCreateOp(dxil::OpCode Op, ArrayRef Args, Type *RetTy = nullptr); -#define DXIL_OPCODE(Op, Name) \ - Expected tryCreate##Name##Op(ArrayRef &Args, \ - Type *RetTy = nullptr) { \ - return tryCreateOp(dxil::OpCode(Op), Args, RetTy); \ - } -#include "DXILOperation.inc" + + /// Get a `%dx.types.ResRet` type with the given element type. + StructType *getResRetType(Type *ElementTy); + /// Get the `%dx.types.Handle` type. + StructType *getHandleType(); + + /// Get a constant `%dx.types.ResBind` value. + Constant *getResBind(uint32_t LowerBound, uint32_t UpperBound, + uint32_t SpaceID, dxil::ResourceClass RC); + /// Get a constant `%dx.types.ResourceProperties` value. + Constant *getResProps(uint32_t Word0, uint32_t Word1); /// Return the name of the given opcode. static const char *getOpCodeName(dxil::OpCode DXILOp); @@ -63,7 +67,7 @@ class DXILOpBuilder { Type *OverloadType = nullptr); Module &M; - IRBuilderBase &B; + IRBuilder<> IRB; VersionTuple DXILVersion; Triple::EnvironmentType ShaderStage; }; diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 5f84cdcfda6de..f34302cc95065 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -1,20 +1,18 @@ -//===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===// +//===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -/// -/// \file This file contains passes and utilities to lower llvm intrinsic call -/// to DXILOp function call. -//===----------------------------------------------------------------------===// +#include "DXILOpLowering.h" #include "DXILConstants.h" #include "DXILIntrinsicExpansion.h" #include "DXILOpBuilder.h" #include "DirectX.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/DXILResource.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" @@ -23,6 +21,7 @@ #include "llvm/IR/IntrinsicsDirectX.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/ErrorHandling.h" @@ -73,84 +72,330 @@ static SmallVector argVectorFlatten(CallInst *Orig, return NewOperands; } -static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { - IRBuilder<> B(M.getContext()); - DXILOpBuilder OpBuilder(M, B); - for (User *U : make_early_inc_range(F.users())) { - CallInst *CI = dyn_cast(U); - if (!CI) - continue; - - SmallVector Args; - B.SetInsertPoint(CI); - if (isVectorArgExpansion(F)) { - SmallVector NewArgs = argVectorFlatten(CI, B); - Args.append(NewArgs.begin(), NewArgs.end()); - } else - Args.append(CI->arg_begin(), CI->arg_end()); - - Expected OpCallOrErr = OpBuilder.tryCreateOp(DXILOp, Args, - F.getReturnType()); - if (Error E = OpCallOrErr.takeError()) { - std::string Message(toString(std::move(E))); - DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message, - CI->getDebugLoc()); - M.getContext().diagnose(Diag); - continue; +namespace { +class OpLowerer { + Module &M; + DXILOpBuilder OpBuilder; + DXILResourceMap &DRM; + SmallVector CleanupCasts; + bool HasErrors = false; + +public: + OpLowerer(Module &M, DXILResourceMap &DRM) : M(M), OpBuilder(M), DRM(DRM) {} + + void replaceFunction(Function &F, + llvm::function_ref ReplaceCall) { + for (User *U : make_early_inc_range(F.users())) { + CallInst *CI = dyn_cast(U); + if (!CI) + continue; + + if (Error E = ReplaceCall(CI)) { + std::string Message(toString(std::move(E))); + DiagnosticInfoUnsupported Diag(*CI->getFunction(), Message, + CI->getDebugLoc()); + M.getContext().diagnose(Diag); + HasErrors = true; + continue; + } } - CallInst *OpCall = *OpCallOrErr; + if (F.user_empty()) + F.eraseFromParent(); + } + + void replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp) { + bool IsVectorArgExpansion = isVectorArgExpansion(F); + replaceFunction(F, [&](CallInst *CI) -> Error { + SmallVector Args; + OpBuilder.getIRB().SetInsertPoint(CI); + if (IsVectorArgExpansion) { + SmallVector NewArgs = argVectorFlatten(CI, OpBuilder.getIRB()); + Args.append(NewArgs.begin(), NewArgs.end()); + } else + Args.append(CI->arg_begin(), CI->arg_end()); - CI->replaceAllUsesWith(OpCall); - CI->eraseFromParent(); + Expected OpCall = + OpBuilder.tryCreateOp(DXILOp, Args, F.getReturnType()); + if (Error E = OpCall.takeError()) + return E; + + CI->replaceAllUsesWith(*OpCall); + CI->eraseFromParent(); + return Error::success(); + }); } - if (F.user_empty()) - F.eraseFromParent(); -} -static bool lowerIntrinsics(Module &M) { - bool Updated = false; + Value *createTmpHandleCast(Value *V, Type *Ty) { + Function *CastFn = Intrinsic::getDeclaration(&M, Intrinsic::dx_cast_handle, + {Ty, V->getType()}); + CallInst *Cast = OpBuilder.getIRB().CreateCall(CastFn, {V}); + CleanupCasts.push_back(Cast); + return Cast; + } + + void cleanupHandleCasts() { + SmallVector ToRemove; + SmallVector CastFns; + + for (CallInst *Cast : CleanupCasts) { + CastFns.push_back(Cast->getCalledFunction()); + // All of the ops should be using `dx.types.Handle` at this point, so if + // we're not producing that we should be part of a pair. Track this so we + // can remove it at the end. + if (Cast->getType() != OpBuilder.getHandleType()) { + ToRemove.push_back(Cast); + continue; + } + // Otherwise, we're the second handle in a pair. Forward the arguments and + // remove the (second) cast. + CallInst *Def = cast(Cast->getOperand(0)); + assert(Def->getIntrinsicID() == Intrinsic::dx_cast_handle && + "Unbalanced pair of temporary handle casts"); + Cast->replaceAllUsesWith(Def->getOperand(0)); + Cast->eraseFromParent(); + } + for (CallInst *Cast : ToRemove) { + assert(Cast->user_empty() && "Temporary handle cast still has users"); + Cast->eraseFromParent(); + } + llvm::sort(CastFns); + CastFns.erase(llvm::unique(CastFns), CastFns.end()); + for (Function *F : CastFns) + F->eraseFromParent(); + + CleanupCasts.clear(); + } + + void lowerToCreateHandle(Function &F) { + IRBuilder<> &IRB = OpBuilder.getIRB(); + Type *Int8Ty = IRB.getInt8Ty(); + Type *Int32Ty = IRB.getInt32Ty(); + + replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + dxil::ResourceInfo &RI = DRM[CI]; + dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding(); + + std::array Args{ + ConstantInt::get(Int8Ty, llvm::to_underlying(RI.getResourceClass())), + ConstantInt::get(Int32Ty, Binding.RecordID), CI->getArgOperand(3), + CI->getArgOperand(4)}; + Expected OpCall = + OpBuilder.tryCreateOp(OpCode::CreateHandle, Args); + if (Error E = OpCall.takeError()) + return E; + + Value *Cast = createTmpHandleCast(*OpCall, CI->getType()); + + CI->replaceAllUsesWith(Cast); + CI->eraseFromParent(); + return Error::success(); + }); + } + + void lowerToBindAndAnnotateHandle(Function &F) { + IRBuilder<> &IRB = OpBuilder.getIRB(); + + replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + dxil::ResourceInfo &RI = DRM[CI]; + dxil::ResourceInfo::ResourceBinding Binding = RI.getBinding(); + std::pair Props = RI.getAnnotateProps(); + + Constant *ResBind = OpBuilder.getResBind( + Binding.LowerBound, Binding.LowerBound + Binding.Size - 1, + Binding.Space, RI.getResourceClass()); + std::array BindArgs{ResBind, CI->getArgOperand(3), + CI->getArgOperand(4)}; + Expected OpBind = + OpBuilder.tryCreateOp(OpCode::CreateHandleFromBinding, BindArgs); + if (Error E = OpBind.takeError()) + return E; + + std::array AnnotateArgs{ + *OpBind, OpBuilder.getResProps(Props.first, Props.second)}; + Expected OpAnnotate = + OpBuilder.tryCreateOp(OpCode::AnnotateHandle, AnnotateArgs); + if (Error E = OpAnnotate.takeError()) + return E; + + Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType()); + + CI->replaceAllUsesWith(Cast); + CI->eraseFromParent(); + + return Error::success(); + }); + } + + void lowerHandleFromBinding(Function &F) { + Triple TT(Triple(M.getTargetTriple())); + if (TT.getDXILVersion() < VersionTuple(1, 6)) + lowerToCreateHandle(F); + else + lowerToBindAndAnnotateHandle(F); + } + + void lowerTypedBufferLoad(Function &F) { + IRBuilder<> &IRB = OpBuilder.getIRB(); + Type *Int32Ty = IRB.getInt32Ty(); + + replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + Value *Handle = + createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); + Value *Index0 = CI->getArgOperand(1); + Value *Index1 = UndefValue::get(Int32Ty); + Type *RetTy = OpBuilder.getResRetType(CI->getType()->getScalarType()); + + std::array Args{Handle, Index0, Index1}; + Expected OpCall = + OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, RetTy); + if (Error E = OpCall.takeError()) + return E; + + std::array Extracts = {}; + + // We've switched the return type from a vector to a struct, but at this + // point most vectors have probably already been scalarized. Try to + // forward arguments directly rather than inserting into and immediately + // extracting from a vector. + for (Use &U : make_early_inc_range(CI->uses())) + if (auto *EEI = dyn_cast(U.getUser())) + if (auto *Index = dyn_cast(EEI->getIndexOperand())) { + size_t IndexVal = Index->getZExtValue(); + assert(IndexVal < 4 && "Index into buffer load out of range"); + if (!Extracts[IndexVal]) + Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal); + EEI->replaceAllUsesWith(Extracts[IndexVal]); + EEI->eraseFromParent(); + } + + // If there are still uses then we need to create a vector. + if (!CI->use_empty()) { + for (int I = 0, E = 4; I != E; ++I) + if (!Extracts[I]) + Extracts[I] = IRB.CreateExtractValue(*OpCall, I); + + Value *Vec = UndefValue::get(CI->getType()); + for (int I = 0, E = 4; I != E; ++I) + Vec = IRB.CreateInsertElement(Vec, Extracts[I], I); + CI->replaceAllUsesWith(Vec); + } + + CI->eraseFromParent(); + return Error::success(); + }); + } + + void lowerTypedBufferStore(Function &F) { + IRBuilder<> &IRB = OpBuilder.getIRB(); + Type *Int8Ty = IRB.getInt8Ty(); + Type *Int32Ty = IRB.getInt32Ty(); + + replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + Value *Handle = + createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); + Value *Index0 = CI->getArgOperand(1); + Value *Index1 = UndefValue::get(Int32Ty); + // For typed stores, the mask must always cover all four elements. + Constant *Mask = ConstantInt::get(Int8Ty, 0xF); - for (Function &F : make_early_inc_range(M.functions())) { - if (!F.isDeclaration()) - continue; - Intrinsic::ID ID = F.getIntrinsicID(); - switch (ID) { - default: - continue; + Value *Data = CI->getArgOperand(2); + auto *DataTy = dyn_cast(Data->getType()); + if (!DataTy || DataTy->getNumElements() != 4) + return make_error( + "typedBufferStore data must be a vector of 4 elements", + inconvertibleErrorCode()); + Value *Data0 = + IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 0)); + Value *Data1 = + IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 1)); + Value *Data2 = + IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 2)); + Value *Data3 = + IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, 3)); + + std::array Args{Handle, Index0, Index1, Data0, + Data1, Data2, Data3, Mask}; + Expected OpCall = + OpBuilder.tryCreateOp(OpCode::BufferStore, Args); + if (Error E = OpCall.takeError()) + return E; + + CI->eraseFromParent(); + return Error::success(); + }); + } + + bool lowerIntrinsics() { + bool Updated = false; + + for (Function &F : make_early_inc_range(M.functions())) { + if (!F.isDeclaration()) + continue; + Intrinsic::ID ID = F.getIntrinsicID(); + switch (ID) { + default: + continue; #define DXIL_OP_INTRINSIC(OpCode, Intrin) \ case Intrin: \ - lowerIntrinsic(OpCode, F, M); \ + replaceFunctionWithOp(F, OpCode); \ break; #include "DXILOperation.inc" + case Intrinsic::dx_handle_fromBinding: + lowerHandleFromBinding(F); + break; + case Intrinsic::dx_typedBufferLoad: + lowerTypedBufferLoad(F); + break; + case Intrinsic::dx_typedBufferStore: + lowerTypedBufferStore(F); + break; + } + Updated = true; } - Updated = true; - } - return Updated; -} + if (Updated && !HasErrors) + cleanupHandleCasts(); -namespace { -/// A pass that transforms external global definitions into declarations. -class DXILOpLowering : public PassInfoMixin { -public: - PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { - if (lowerIntrinsics(M)) - return PreservedAnalyses::none(); - return PreservedAnalyses::all(); + return Updated; } }; } // namespace +PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) { + DXILResourceMap &DRM = MAM.getResult(M); + + bool MadeChanges = OpLowerer(M, DRM).lowerIntrinsics(); + if (!MadeChanges) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve(); + return PA; +} + namespace { class DXILOpLoweringLegacy : public ModulePass { public: - bool runOnModule(Module &M) override { return lowerIntrinsics(M); } + bool runOnModule(Module &M) override { + DXILResourceMap &DRM = + getAnalysis().getResourceMap(); + + return OpLowerer(M, DRM).lowerIntrinsics(); + } StringRef getPassName() const override { return "DXIL Op Lowering"; } DXILOpLoweringLegacy() : ModulePass(ID) {} static char ID; // Pass identification. void getAnalysisUsage(llvm::AnalysisUsage &AU) const override { - // Specify the passes that your pass depends on AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); } }; char DXILOpLoweringLegacy::ID = 0; @@ -158,6 +403,7 @@ char DXILOpLoweringLegacy::ID = 0; INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, false) +INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, false) diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.h b/llvm/lib/Target/DirectX/DXILOpLowering.h new file mode 100644 index 0000000000000..fe357da7bb905 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILOpLowering.h @@ -0,0 +1,27 @@ +//===- DXILOpLowering.h - Lowering to DXIL operations -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file Pass for lowering llvm intrinsics into DXIL operations. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_DIRECTX_DXILOPLOWERING_H +#define LLVM_LIB_TARGET_DIRECTX_DXILOPLOWERING_H + +#include "llvm/IR/PassManager.h" + +namespace llvm { + +class DXILOpLowering : public PassInfoMixin { +public: + PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM); +}; + +} // namespace llvm + +#endif // LLVM_LIB_TARGET_DIRECTX_DXILOPLOWERING_H diff --git a/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp b/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp index 99cc4067b1d62..c57631cc4c8b6 100644 --- a/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp +++ b/llvm/lib/Target/DirectX/DXILPrettyPrinter.cpp @@ -1,37 +1,59 @@ -//===- DXILPrettyPrinter.cpp - DXIL Resource helper objects ---------------===// +//===- DXILPrettyPrinter.cpp - Print resources for textual DXIL -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -/// -/// \file This file contains a pass for pretty printing DXIL metadata into IR -/// comments when printing assembly output. -/// -//===----------------------------------------------------------------------===// +#include "DXILPrettyPrinter.h" #include "DXILResourceAnalysis.h" #include "DirectX.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; +static void prettyPrintResources(raw_ostream &OS, + const dxil::Resources &MDResources) { + // Column widths are arbitrary but match the widths DXC uses. + OS << ";\n; Resource Bindings:\n;\n"; + OS << formatv("; {0,-30} {1,10} {2,7} {3,11} {4,7} {5,14} {6,16}\n", + "Name", "Type", "Format", "Dim", "ID", "HLSL Bind", "Count"); + OS << formatv( + "; {0,-+30} {1,-+10} {2,-+7} {3,-+11} {4,-+7} {5,-+14} {6,-+16}\n", "", + "", "", "", "", "", ""); + + if (MDResources.hasCBuffers()) + MDResources.printCBuffers(OS); + if (MDResources.hasUAVs()) + MDResources.printUAVs(OS); + + OS << ";\n"; +} + +PreservedAnalyses DXILPrettyPrinterPass::run(Module &M, + ModuleAnalysisManager &MAM) { + const dxil::Resources &MDResources = MAM.getResult(M); + prettyPrintResources(OS, MDResources); + return PreservedAnalyses::all(); +} + namespace { -class DXILPrettyPrinter : public llvm::ModulePass { +class DXILPrettyPrinterLegacy : public llvm::ModulePass { raw_ostream &OS; // raw_ostream to print to. public: static char ID; - DXILPrettyPrinter() : ModulePass(ID), OS(dbgs()) { - initializeDXILPrettyPrinterPass(*PassRegistry::getPassRegistry()); + DXILPrettyPrinterLegacy() : ModulePass(ID), OS(dbgs()) { + initializeDXILPrettyPrinterLegacyPass(*PassRegistry::getPassRegistry()); } - explicit DXILPrettyPrinter(raw_ostream &O) : ModulePass(ID), OS(O) { - initializeDXILPrettyPrinterPass(*PassRegistry::getPassRegistry()); + explicit DXILPrettyPrinterLegacy(raw_ostream &O) : ModulePass(ID), OS(O) { + initializeDXILPrettyPrinterLegacyPass(*PassRegistry::getPassRegistry()); } StringRef getPassName() const override { @@ -46,19 +68,19 @@ class DXILPrettyPrinter : public llvm::ModulePass { }; } // namespace -char DXILPrettyPrinter::ID = 0; -INITIALIZE_PASS_BEGIN(DXILPrettyPrinter, "dxil-pretty-printer", +char DXILPrettyPrinterLegacy::ID = 0; +INITIALIZE_PASS_BEGIN(DXILPrettyPrinterLegacy, "dxil-pretty-printer", "DXIL Metadata Pretty Printer", true, true) INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper) -INITIALIZE_PASS_END(DXILPrettyPrinter, "dxil-pretty-printer", +INITIALIZE_PASS_END(DXILPrettyPrinterLegacy, "dxil-pretty-printer", "DXIL Metadata Pretty Printer", true, true) -bool DXILPrettyPrinter::runOnModule(Module &M) { +bool DXILPrettyPrinterLegacy::runOnModule(Module &M) { dxil::Resources &Res = getAnalysis().getDXILResource(); - Res.print(OS); + prettyPrintResources(OS, Res); return false; } -ModulePass *llvm::createDXILPrettyPrinterPass(raw_ostream &OS) { - return new DXILPrettyPrinter(OS); +ModulePass *llvm::createDXILPrettyPrinterLegacyPass(raw_ostream &OS) { + return new DXILPrettyPrinterLegacy(OS); } diff --git a/llvm/lib/Target/DirectX/DXILPrettyPrinter.h b/llvm/lib/Target/DirectX/DXILPrettyPrinter.h new file mode 100644 index 0000000000000..84e17ac0f2ec6 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILPrettyPrinter.h @@ -0,0 +1,33 @@ +//===- DXILPrettyPrinter.h - Print resources for textual DXIL ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file This file contains a pass for pretty printing DXIL metadata into IR +// comments when printing assembly output. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TARGET_DIRECTX_DXILPRETTYPRINTER_H +#define LLVM_TARGET_DIRECTX_DXILPRETTYPRINTER_H + +#include "llvm/IR/PassManager.h" + +namespace llvm { + +/// A pass that prints resources in a format suitable for textual DXIL. +class DXILPrettyPrinterPass : public PassInfoMixin { + raw_ostream &OS; + +public: + explicit DXILPrettyPrinterPass(raw_ostream &OS) : OS(OS) {} + + PreservedAnalyses run(Module &M, ModuleAnalysisManager &); +}; + +} // namespace llvm + +#endif // LLVM_TARGET_DIRECTX_DXILPRETTYPRINTER_H diff --git a/llvm/lib/Target/DirectX/DXILResource.cpp b/llvm/lib/Target/DirectX/DXILResource.cpp index 8e5b9867e6661..f027283b70521 100644 --- a/llvm/lib/Target/DirectX/DXILResource.cpp +++ b/llvm/lib/Target/DirectX/DXILResource.cpp @@ -333,37 +333,14 @@ template MDNode *ResourceTable::write(Module &M) const { return MDNode::get(M.getContext(), MDs); } -void Resources::write(Module &M) const { - Metadata *ResourceMDs[4] = {nullptr, nullptr, nullptr, nullptr}; - - ResourceMDs[1] = UAVs.write(M); - - ResourceMDs[2] = CBuffers.write(M); - - bool HasResource = ResourceMDs[0] != nullptr || ResourceMDs[1] != nullptr || - ResourceMDs[2] != nullptr || ResourceMDs[3] != nullptr; - - if (HasResource) { - NamedMDNode *DXResMD = M.getOrInsertNamedMetadata("dx.resources"); - DXResMD->addOperand(MDNode::get(M.getContext(), ResourceMDs)); - } - - NamedMDNode *Entry = M.getNamedMetadata("hlsl.uavs"); - if (Entry) - Entry->eraseFromParent(); +Metadata *Resources::writeUAVs(Module &M) const { return UAVs.write(M); } +void Resources::printUAVs(raw_ostream &OS) const { UAVs.print(OS); } +Metadata *Resources::writeCBuffers(Module &M) const { + return CBuffers.write(M); } +void Resources::printCBuffers(raw_ostream &OS) const { CBuffers.print(OS); } -void Resources::print(raw_ostream &O) const { - O << ";\n" - << "; Resource Bindings:\n" - << ";\n" - << "; Name Type Format Dim " - "ID HLSL Bind Count\n" - << "; ------------------------------ ---------- ------- ----------- " - "------- -------------- ------\n"; - - CBuffers.print(O); - UAVs.print(O); +void Resources::dump() const { + printCBuffers(dbgs()); + printUAVs(dbgs()); } - -void Resources::dump() const { print(dbgs()); } diff --git a/llvm/lib/Target/DirectX/DXILResource.h b/llvm/lib/Target/DirectX/DXILResource.h index 06902fe2b87b0..812729bc4dc57 100644 --- a/llvm/lib/Target/DirectX/DXILResource.h +++ b/llvm/lib/Target/DirectX/DXILResource.h @@ -103,6 +103,7 @@ template class ResourceTable { public: ResourceTable(StringRef Name) : MDName(Name) {} void collect(Module &M); + bool empty() const { return Data.empty(); } MDNode *write(Module &M) const; void print(raw_ostream &O) const; }; @@ -117,8 +118,12 @@ class Resources { public: void collect(Module &M); - void write(Module &M) const; - void print(raw_ostream &O) const; + bool hasUAVs() const { return !UAVs.empty(); } + Metadata *writeUAVs(Module &M) const; + void printUAVs(raw_ostream &OS) const; + bool hasCBuffers() const { return !CBuffers.empty(); } + Metadata *writeCBuffers(Module &M) const; + void printCBuffers(raw_ostream &OS) const; LLVM_DUMP_METHOD void dump() const; }; diff --git a/llvm/lib/Target/DirectX/DXILResourceAnalysis.cpp b/llvm/lib/Target/DirectX/DXILResourceAnalysis.cpp index 33e0119807bb8..d423220bb902e 100644 --- a/llvm/lib/Target/DirectX/DXILResourceAnalysis.cpp +++ b/llvm/lib/Target/DirectX/DXILResourceAnalysis.cpp @@ -27,13 +27,6 @@ dxil::Resources DXILResourceMDAnalysis::run(Module &M, AnalysisKey DXILResourceMDAnalysis::Key; -PreservedAnalyses DXILResourceMDPrinterPass::run(Module &M, - ModuleAnalysisManager &AM) { - dxil::Resources Res = AM.getResult(M); - Res.print(OS); - return PreservedAnalyses::all(); -} - char DXILResourceMDWrapper::ID = 0; INITIALIZE_PASS_BEGIN(DXILResourceMDWrapper, DEBUG_TYPE, "DXIL resource Information", true, true) @@ -46,7 +39,3 @@ bool DXILResourceMDWrapper::runOnModule(Module &M) { } DXILResourceMDWrapper::DXILResourceMDWrapper() : ModulePass(ID) {} - -void DXILResourceMDWrapper::print(raw_ostream &OS, const Module *) const { - Resources.print(OS); -} diff --git a/llvm/lib/Target/DirectX/DXILResourceAnalysis.h b/llvm/lib/Target/DirectX/DXILResourceAnalysis.h index 3a2b8a9fd39d5..0ad97dc1992f4 100644 --- a/llvm/lib/Target/DirectX/DXILResourceAnalysis.h +++ b/llvm/lib/Target/DirectX/DXILResourceAnalysis.h @@ -30,17 +30,6 @@ class DXILResourceMDAnalysis dxil::Resources run(Module &M, ModuleAnalysisManager &AM); }; -/// Printer pass for the \c DXILResourceMDAnalysis results. -class DXILResourceMDPrinterPass - : public PassInfoMixin { - raw_ostream &OS; - -public: - explicit DXILResourceMDPrinterPass(raw_ostream &OS) : OS(OS) {} - PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); - static bool isRequired() { return true; } -}; - /// The legacy pass manager's analysis pass to compute DXIL resource /// information. class DXILResourceMDWrapper : public ModulePass { @@ -57,7 +46,9 @@ class DXILResourceMDWrapper : public ModulePass { /// Calculate the DXILResource for the module. bool runOnModule(Module &M) override; - void print(raw_ostream &O, const Module *M = nullptr) const override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } }; } // namespace llvm diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index 21089a232783a..007af0b46b9f3 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -1,13 +1,12 @@ -//===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata ---*- C++ -*-===// +//===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -/// -//===----------------------------------------------------------------------===// +#include "DXILTranslateMetadata.h" #include "DXILMetadata.h" #include "DXILResource.h" #include "DXILResourceAnalysis.h" @@ -23,53 +22,90 @@ using namespace llvm; using namespace llvm::dxil; -namespace { -class DXILTranslateMetadata : public ModulePass { -public: - static char ID; // Pass identification, replacement for typeid - explicit DXILTranslateMetadata() : ModulePass(ID) {} +static void emitResourceMetadata(Module &M, + const dxil::Resources &MDResources) { + Metadata *SRVMD = nullptr, *UAVMD = nullptr, *CBufMD = nullptr, + *SmpMD = nullptr; + bool HasResources = false; - StringRef getPassName() const override { return "DXIL Translate Metadata"; } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - AU.addRequired(); - AU.addRequired(); + if (MDResources.hasUAVs()) { + UAVMD = MDResources.writeUAVs(M); + HasResources = true; } - bool runOnModule(Module &M) override; -}; + if (MDResources.hasCBuffers()) { + CBufMD = MDResources.writeCBuffers(M); + HasResources = true; + } -} // namespace + if (!HasResources) + return; -bool DXILTranslateMetadata::runOnModule(Module &M) { + NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources"); + ResourceMD->addOperand( + MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD})); +} +static void translateMetadata(Module &M, const dxil::Resources &MDResources, + const ComputedShaderFlags &ShaderFlags) { dxil::ValidatorVersionMD ValVerMD(M); if (ValVerMD.isEmpty()) ValVerMD.update(VersionTuple(1, 0)); dxil::createShaderModelMD(M); dxil::createDXILVersionMD(M); - const dxil::Resources &Res = - getAnalysis().getDXILResource(); - Res.write(M); + emitResourceMetadata(M, MDResources); - const uint64_t Flags = static_cast( - getAnalysis().getShaderFlags()); - dxil::createEntryMD(M, Flags); + dxil::createEntryMD(M, static_cast(ShaderFlags)); +} + +PreservedAnalyses DXILTranslateMetadata::run(Module &M, + ModuleAnalysisManager &MAM) { + const dxil::Resources &MDResources = MAM.getResult(M); + const ComputedShaderFlags &ShaderFlags = + MAM.getResult(M); - return false; + translateMetadata(M, MDResources, ShaderFlags); + + return PreservedAnalyses::all(); } -char DXILTranslateMetadata::ID = 0; +namespace { +class DXILTranslateMetadataLegacy : public ModulePass { +public: + static char ID; // Pass identification, replacement for typeid + explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {} + + StringRef getPassName() const override { return "DXIL Translate Metadata"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + AU.addRequired(); + AU.addRequired(); + } + + bool runOnModule(Module &M) override { + const dxil::Resources &MDResources = + getAnalysis().getDXILResource(); + const ComputedShaderFlags &ShaderFlags = + getAnalysis().getShaderFlags(); + + translateMetadata(M, MDResources, ShaderFlags); + return true; + } +}; + +} // namespace + +char DXILTranslateMetadataLegacy::ID = 0; -ModulePass *llvm::createDXILTranslateMetadataPass() { - return new DXILTranslateMetadata(); +ModulePass *llvm::createDXILTranslateMetadataLegacyPass() { + return new DXILTranslateMetadataLegacy(); } -INITIALIZE_PASS_BEGIN(DXILTranslateMetadata, "dxil-translate-metadata", +INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata", "DXIL Translate Metadata", false, false) INITIALIZE_PASS_DEPENDENCY(DXILResourceMDWrapper) INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) -INITIALIZE_PASS_END(DXILTranslateMetadata, "dxil-translate-metadata", +INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata", "DXIL Translate Metadata", false, false) diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.h b/llvm/lib/Target/DirectX/DXILTranslateMetadata.h new file mode 100644 index 0000000000000..f3f5eb1901406 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.h @@ -0,0 +1,24 @@ +//===- DXILTranslateMetadata.h - Pass to emit DXIL metadata -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H +#define LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H + +#include "llvm/IR/PassManager.h" + +namespace llvm { + +/// A pass that transforms DXIL Intrinsics that don't have DXIL opCodes +class DXILTranslateMetadata : public PassInfoMixin { +public: + PreservedAnalyses run(Module &M, ModuleAnalysisManager &); +}; + +} // namespace llvm + +#endif // LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h index d056ae2bc488e..963c39ace3af9 100644 --- a/llvm/lib/Target/DirectX/DirectX.h +++ b/llvm/lib/Target/DirectX/DirectX.h @@ -41,19 +41,19 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &); ModulePass *createDXILOpLoweringLegacyPass(); /// Initializer for DXILTranslateMetadata. -void initializeDXILTranslateMetadataPass(PassRegistry &); +void initializeDXILTranslateMetadataLegacyPass(PassRegistry &); /// Pass to emit metadata for DXIL. -ModulePass *createDXILTranslateMetadataPass(); +ModulePass *createDXILTranslateMetadataLegacyPass(); /// Initializer for DXILTranslateMetadata. void initializeDXILResourceMDWrapperPass(PassRegistry &); /// Pass to pretty print DXIL metadata. -ModulePass *createDXILPrettyPrinterPass(raw_ostream &OS); +ModulePass *createDXILPrettyPrinterLegacyPass(raw_ostream &OS); /// Initializer for DXILPrettyPrinter. -void initializeDXILPrettyPrinterPass(PassRegistry &); +void initializeDXILPrettyPrinterLegacyPass(PassRegistry &); /// Initializer for dxil::ShaderFlagsAnalysisWrapper pass. void initializeShaderFlagsAnalysisWrapperPass(PassRegistry &); diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def index 7544172ab94e4..a3e051b173d89 100644 --- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def +++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def @@ -23,7 +23,10 @@ MODULE_ANALYSIS("dxil-resource-md", DXILResourceMDAnalysis()) #ifndef MODULE_PASS #define MODULE_PASS(NAME, CREATE_PASS) #endif +MODULE_PASS("dxil-intrinsic-expansion", DXILIntrinsicExpansion()) +MODULE_PASS("dxil-op-lower", DXILOpLowering()) +MODULE_PASS("dxil-pretty-printer", DXILPrettyPrinterPass(dbgs())) +MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata()) // TODO: rename to print after NPM switch MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs())) -MODULE_PASS("print-dxil-resource-md", DXILResourceMDPrinterPass(dbgs())) #undef MODULE_PASS diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index 92bd69b69684f..a578ad1452560 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -12,8 +12,12 @@ //===----------------------------------------------------------------------===// #include "DirectXTargetMachine.h" +#include "DXILIntrinsicExpansion.h" +#include "DXILOpLowering.h" +#include "DXILPrettyPrinter.h" #include "DXILResourceAnalysis.h" #include "DXILShaderFlags.h" +#include "DXILTranslateMetadata.h" #include "DXILWriter/DXILWriterPass.h" #include "DirectX.h" #include "DirectXSubtarget.h" @@ -45,7 +49,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() { initializeWriteDXILPassPass(*PR); initializeDXContainerGlobalsPass(*PR); initializeDXILOpLoweringLegacyPass(*PR); - initializeDXILTranslateMetadataPass(*PR); + initializeDXILTranslateMetadataLegacyPass(*PR); initializeDXILResourceMDWrapperPass(*PR); initializeShaderFlagsAnalysisWrapperPass(*PR); } @@ -79,7 +83,7 @@ class DirectXPassConfig : public TargetPassConfig { void addCodeGenPrepare() override { addPass(createDXILIntrinsicExpansionLegacyPass()); addPass(createDXILOpLoweringLegacyPass()); - addPass(createDXILTranslateMetadataPass()); + addPass(createDXILTranslateMetadataLegacyPass()); addPass(createDXILPrepareModulePass()); } }; @@ -116,7 +120,7 @@ bool DirectXTargetMachine::addPassesToEmitFile( switch (FileType) { case CodeGenFileType::AssemblyFile: - PM.add(createDXILPrettyPrinterPass(Out)); + PM.add(createDXILPrettyPrinterLegacyPass(Out)); PM.add(createPrintModulePass(Out, "", true)); break; case CodeGenFileType::ObjectFile: diff --git a/llvm/test/Analysis/DXILResource/buffer-frombinding.ll b/llvm/test/Analysis/DXILResource/buffer-frombinding.ll index 4349adb8ef8eb..65802c6d1ff87 100644 --- a/llvm/test/Analysis/DXILResource/buffer-frombinding.ll +++ b/llvm/test/Analysis/DXILResource/buffer-frombinding.ll @@ -46,14 +46,14 @@ define void @test_typedbuffer() { ; Buffer Buf[24] : register(t3, space5) %typed2 = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_0_0t( - i32 2, i32 7, i32 24, i32 0, i1 false) + i32 5, i32 3, i32 24, i32 0, i1 false) ; CHECK: Binding for %typed2 ; CHECK: Symbol: ptr undef ; CHECK: Name: "" ; CHECK: Binding: ; CHECK: Record ID: 0 - ; CHECK: Space: 2 - ; CHECK: Lower Bound: 7 + ; CHECK: Space: 5 + ; CHECK: Lower Bound: 3 ; CHECK: Size: 24 ; CHECK: Class: SRV ; CHECK: Kind: TypedBuffer diff --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll new file mode 100644 index 0000000000000..c3bb96dbdf909 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll @@ -0,0 +1,102 @@ +; RUN: opt -S -dxil-op-lower %s | FileCheck %s + +target triple = "dxil-pc-shadermodel6.6-compute" + +declare void @scalar_user(float) +declare void @vector_user(<4 x float>) + +define void @loadfloats() { + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; The temporary casts should all have been cleaned up + ; CHECK-NOT: %dx.cast_handle + + ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) + %data0 = call <4 x float> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0) + + ; The extract order depends on the users, so don't enforce that here. + ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0 + %data0_0 = extractelement <4 x float> %data0, i32 0 + ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2 + %data0_2 = extractelement <4 x float> %data0, i32 2 + + ; If all of the uses are extracts, we skip creating a vector + ; CHECK-NOT: insertelement + call void @scalar_user(float %data0_0) + call void @scalar_user(float %data0_2) + + ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef) + %data4 = call <4 x float> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4) + + ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0 + ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1 + ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2 + ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3 + ; CHECK: insertelement <4 x float> undef + ; CHECK: insertelement <4 x float> + ; CHECK: insertelement <4 x float> + ; CHECK: insertelement <4 x float> + call void @vector_user(<4 x float> %data4) + + ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef) + %data12 = call <4 x float> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12) + + ; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3 + %data12_3 = extractelement <4 x float> %data12, i32 3 + + ; If there are a mix of users we need the vector, but extracts are direct + ; CHECK: call void @scalar_user(float [[DATA12_3]]) + call void @scalar_user(float %data12_3) + call void @vector_user(<4 x float> %data12) + + ret void +} + +define void @loadint() { + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) + %data0 = call <4 x i32> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0) + + ret void +} + +define void @loadhalf() { + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) + %data0 = call <4 x half> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0) + + ret void +} + +define void @loadi16() { + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) + %data0 = call <4 x i16> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0) + + ret void +} diff --git a/llvm/test/CodeGen/DirectX/BufferStore-errors.ll b/llvm/test/CodeGen/DirectX/BufferStore-errors.ll new file mode 100644 index 0000000000000..26a805101ed2e --- /dev/null +++ b/llvm/test/CodeGen/DirectX/BufferStore-errors.ll @@ -0,0 +1,34 @@ +; We use llc for this test so that we don't abort after the first error. +; RUN: not llc %s -o /dev/null 2>&1 | FileCheck %s + +target triple = "dxil-pc-shadermodel6.6-compute" + +; CHECK: error: +; CHECK-SAME: in function storetoofew +; CHECK-SAME: typedBufferStore data must be a vector of 4 elements +define void @storetoomany(<5 x float> %data, i32 %index) { + %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + call void @llvm.dx.typedBufferStore( + target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, + i32 %index, <5 x float> %data) + + ret void +} + +; CHECK: error: +; CHECK-SAME: in function storetoomany +; CHECK-SAME: typedBufferStore data must be a vector of 4 elements +define void @storetoofew(<3 x i32> %data, i32 %index) { + %buffer = call target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_1_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + call void @llvm.dx.typedBufferStore( + target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) %buffer, + i32 %index, <3 x i32> %data) + + ret void +} diff --git a/llvm/test/CodeGen/DirectX/BufferStore.ll b/llvm/test/CodeGen/DirectX/BufferStore.ll new file mode 100644 index 0000000000000..102084816a6f2 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/BufferStore.ll @@ -0,0 +1,92 @@ +; RUN: opt -S -dxil-op-lower %s | FileCheck %s + +target triple = "dxil-pc-shadermodel6.6-compute" + +define void @storefloat(<4 x float> %data, i32 %index) { + + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; The temporary casts should all have been cleaned up + ; CHECK-NOT: %dx.cast_handle + + ; CHECK: [[DATA0_0:%.*]] = extractelement <4 x float> %data, i32 0 + ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x float> %data, i32 1 + ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x float> %data, i32 2 + ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x float> %data, i32 3 + ; CHECK: call void @dx.op.bufferStore.f32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, float [[DATA0_0]], float [[DATA0_1]], float [[DATA0_2]], float [[DATA0_3]], i8 15) + call void @llvm.dx.typedBufferStore( + target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, + i32 %index, <4 x float> %data) + + ret void +} + +define void @storeint(<4 x i32> %data, i32 %index) { + + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_1_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; CHECK: [[DATA0_0:%.*]] = extractelement <4 x i32> %data, i32 0 + ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x i32> %data, i32 1 + ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x i32> %data, i32 2 + ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x i32> %data, i32 3 + ; CHECK: call void @dx.op.bufferStore.i32(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i32 [[DATA0_0]], i32 [[DATA0_1]], i32 [[DATA0_2]], i32 [[DATA0_3]], i8 15) + call void @llvm.dx.typedBufferStore( + target("dx.TypedBuffer", <4 x i32>, 1, 0, 0) %buffer, + i32 %index, <4 x i32> %data) + + ret void +} + +define void @storehalf(<4 x half> %data, i32 %index) { + + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x half>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_1_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; The temporary casts should all have been cleaned up + ; CHECK-NOT: %dx.cast_handle + + ; CHECK: [[DATA0_0:%.*]] = extractelement <4 x half> %data, i32 0 + ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x half> %data, i32 1 + ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x half> %data, i32 2 + ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x half> %data, i32 3 + ; CHECK: call void @dx.op.bufferStore.f16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, half [[DATA0_0]], half [[DATA0_1]], half [[DATA0_2]], half [[DATA0_3]], i8 15) + call void @llvm.dx.typedBufferStore( + target("dx.TypedBuffer", <4 x half>, 1, 0, 0) %buffer, + i32 %index, <4 x half> %data) + + ret void +} + +define void @storei16(<4 x i16> %data, i32 %index) { + + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x i16>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_1_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; The temporary casts should all have been cleaned up + ; CHECK-NOT: %dx.cast_handle + + ; CHECK: [[DATA0_0:%.*]] = extractelement <4 x i16> %data, i32 0 + ; CHECK: [[DATA0_1:%.*]] = extractelement <4 x i16> %data, i32 1 + ; CHECK: [[DATA0_2:%.*]] = extractelement <4 x i16> %data, i32 2 + ; CHECK: [[DATA0_3:%.*]] = extractelement <4 x i16> %data, i32 3 + ; CHECK: call void @dx.op.bufferStore.i16(i32 69, %dx.types.Handle [[HANDLE]], i32 %index, i32 undef, i16 [[DATA0_0]], i16 [[DATA0_1]], i16 [[DATA0_2]], i16 [[DATA0_3]], i8 15) + call void @llvm.dx.typedBufferStore( + target("dx.TypedBuffer", <4 x i16>, 1, 0, 0) %buffer, + i32 %index, <4 x i16> %data) + + ret void +} diff --git a/llvm/test/CodeGen/DirectX/CreateHandle.ll b/llvm/test/CodeGen/DirectX/CreateHandle.ll new file mode 100644 index 0000000000000..1fad869ab4305 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/CreateHandle.ll @@ -0,0 +1,61 @@ +; RUN: opt -S -dxil-op-lower %s | FileCheck %s + +target triple = "dxil-pc-shadermodel6.0-compute" + +define void @test_buffers() { + ; RWBuffer Buf : register(u5, space3) + %typed0 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0( + i32 3, i32 5, i32 1, i32 4, i1 false) + ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 0, i32 4, i1 false) + ; CHECK-NOT: @llvm.dx.cast.handle + + ; RWBuffer Buf : register(u7, space2) + %typed1 = call target("dx.TypedBuffer", i32, 1, 0, 1) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_1_0_1t( + i32 2, i32 7, i32 1, i32 6, i1 false) + ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 1, i32 1, i32 6, i1 false) + + ; Buffer Buf[24] : register(t3, space5) + ; Buffer typed2 = Buf[5] + ; Note that the index below is 3 + 4 = 7 + %typed2 = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_0_0_0t( + i32 5, i32 3, i32 24, i32 7, i1 false) + ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 0, i32 0, i32 7, i1 false) + + ; struct S { float4 a; uint4 b; }; + ; StructuredBuffer Buf : register(t2, space4) + %struct0 = call target("dx.RawBuffer", {<4 x float>, <4 x i32>}, 0, 0) + @llvm.dx.handle.fromBinding.tdx.RawBuffer_sl_v4f32v4i32s_0_0t( + i32 4, i32 2, i32 1, i32 10, i1 true) + ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 0, i32 1, i32 10, i1 true) + + ; ByteAddressBuffer Buf : register(t8, space1) + %byteaddr0 = call target("dx.RawBuffer", i8, 0, 0) + @llvm.dx.handle.fromBinding.tdx.RawBuffer_i8_0_0t( + i32 1, i32 8, i32 1, i32 12, i1 false) + ; CHECK: call %dx.types.Handle @dx.op.createHandle(i32 57, i8 0, i32 2, i32 12, i1 false) + + ret void +} + +; Note: We need declarations for each handle.fromBinding in the same order as +; they appear in source to force a deterministic ordering of record IDs. +declare target("dx.TypedBuffer", <4 x float>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0t( + i32, i32, i32, i32, i1) #0 +declare target("dx.TypedBuffer", i32, 1, 0, 1) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_1_0_1t( + i32, i32, i32, i32, i1) #0 +declare target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0t( + i32, i32, i32, i32, i1) #0 +declare target("dx.RawBuffer", { <4 x float>, <4 x i32> }, 0, 0) + @llvm.dx.handle.fromBinding.tdx.RawBuffer_sl_v4f32v4i32s_0_0t( + i32, i32, i32, i32, i1) #0 +declare target("dx.RawBuffer", i8, 0, 0) + @llvm.dx.handle.fromBinding.tdx.RawBuffer_i8_0_0t( + i32, i32, i32, i32, i1) #0 + +attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) } diff --git a/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll b/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll new file mode 100644 index 0000000000000..e8bd8fe89132d --- /dev/null +++ b/llvm/test/CodeGen/DirectX/CreateHandleFromBinding.ll @@ -0,0 +1,65 @@ +; RUN: opt -S -dxil-op-lower %s | FileCheck %s + +target triple = "dxil-pc-shadermodel6.6-compute" + +define void @test_bindings() { + ; RWBuffer Buf : register(u5, space3) + %typed0 = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0( + i32 3, i32 5, i32 1, i32 4, i1 false) + ; CHECK: [[BUF0:%[0-9]*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 218, %dx.types.ResBind { i32 5, i32 5, i32 3, i8 1 }, i32 4, i1 false) + ; CHECK: call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BUF0]], %dx.types.ResourceProperties { i32 4106, i32 1033 }) + + ; RWBuffer Buf : register(u7, space2) + %typed1 = call target("dx.TypedBuffer", i32, 1, 0, 1) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_1_0_0t( + i32 2, i32 7, i32 1, i32 6, i1 false) + ; CHECK: [[BUF1:%[0-9]*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 218, %dx.types.ResBind { i32 7, i32 7, i32 2, i8 1 }, i32 6, i1 false) + ; CHECK: call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BUF1]], %dx.types.ResourceProperties { i32 4106, i32 260 }) + + ; Buffer Buf[24] : register(t3, space5) + ; Buffer typed2 = Buf[4] + ; Note that the index below is 3 + 4 = 7 + %typed2 = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_0_0_0t( + i32 5, i32 3, i32 24, i32 7, i1 false) + ; CHECK: [[BUF2:%[0-9]*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 218, %dx.types.ResBind { i32 3, i32 26, i32 5, i8 0 }, i32 7, i1 false) + ; CHECK: call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BUF2]], %dx.types.ResourceProperties { i32 10, i32 1029 }) + + ; struct S { float4 a; uint4 b; }; + ; StructuredBuffer Buf : register(t2, space4) + %struct0 = call target("dx.RawBuffer", {<4 x float>, <4 x i32>}, 0, 0) + @llvm.dx.handle.fromBinding.tdx.RawBuffer_sl_v4f32v4i32s_0_0t( + i32 4, i32 2, i32 1, i32 10, i1 true) + ; CHECK: [[BUF3:%[0-9]*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 218, %dx.types.ResBind { i32 2, i32 2, i32 4, i8 0 }, i32 10, i1 true) + ; CHECK: = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BUF3]], %dx.types.ResourceProperties { i32 1036, i32 32 }) + + ; ByteAddressBuffer Buf : register(t8, space1) + %byteaddr0 = call target("dx.RawBuffer", i8, 0, 0) + @llvm.dx.handle.fromBinding.tdx.RawBuffer_i8_0_0t( + i32 1, i32 8, i32 1, i32 12, i1 false) + ; CHECK: [[BUF4:%[0-9]*]] = call %dx.types.Handle @dx.op.createHandleFromBinding(i32 218, %dx.types.ResBind { i32 8, i32 8, i32 1, i8 0 }, i32 12, i1 false) + ; CHECK: call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BUF4]], %dx.types.ResourceProperties { i32 11, i32 0 }) + + ret void +} + +; Note: We need declarations for each handle.fromBinding in the same order as +; they appear in source to force a deterministic ordering of record IDs. +declare target("dx.TypedBuffer", <4 x float>, 1, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0t( + i32, i32, i32, i32, i1) #0 +declare target("dx.TypedBuffer", i32, 1, 0, 1) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_i32_1_0_1t( + i32, i32, i32, i32, i1) #0 +declare target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0t( + i32, i32, i32, i32, i1) #0 +declare target("dx.RawBuffer", { <4 x float>, <4 x i32> }, 0, 0) + @llvm.dx.handle.fromBinding.tdx.RawBuffer_sl_v4f32v4i32s_0_0t( + i32, i32, i32, i32, i1) #0 +declare target("dx.RawBuffer", i8, 0, 0) + @llvm.dx.handle.fromBinding.tdx.RawBuffer_i8_0_0t( + i32, i32, i32, i32, i1) #0 + +attributes #0 = { nocallback nofree nosync nounwind willreturn memory(none) } diff --git a/llvm/test/CodeGen/DirectX/Metadata/dxilVer-1.0.ll b/llvm/test/CodeGen/DirectX/Metadata/dxilVer-1.0.ll index 318d5a6210eee..fb31833dd5139 100644 --- a/llvm/test/CodeGen/DirectX/Metadata/dxilVer-1.0.ll +++ b/llvm/test/CodeGen/DirectX/Metadata/dxilVer-1.0.ll @@ -1,4 +1,4 @@ -; RUN: opt -S -dxil-translate-metadata %s | FileCheck %s +; RUN: opt -S -passes=dxil-translate-metadata %s | FileCheck %s ; RUN: opt -S -passes="print" -disable-output %s 2>&1 | FileCheck %s --check-prefix=ANALYSIS target triple = "dxil-pc-shadermodel6.0-vertex" diff --git a/llvm/test/CodeGen/DirectX/Metadata/dxilVer-1.8.ll b/llvm/test/CodeGen/DirectX/Metadata/dxilVer-1.8.ll index fb54fa916f33f..5944acc3b5b4b 100644 --- a/llvm/test/CodeGen/DirectX/Metadata/dxilVer-1.8.ll +++ b/llvm/test/CodeGen/DirectX/Metadata/dxilVer-1.8.ll @@ -1,4 +1,4 @@ -; RUN: opt -S -dxil-translate-metadata %s | FileCheck %s +; RUN: opt -S -passes=dxil-translate-metadata %s | FileCheck %s ; RUN: opt -S -passes="print" -disable-output %s 2>&1 | FileCheck %s --check-prefix=ANALYSIS target triple = "dxil-pc-shadermodel6.8-compute" diff --git a/llvm/test/CodeGen/DirectX/UAVMetadata.ll b/llvm/test/CodeGen/DirectX/UAVMetadata.ll index b10112a044df5..2c242ec08eda5 100644 --- a/llvm/test/CodeGen/DirectX/UAVMetadata.ll +++ b/llvm/test/CodeGen/DirectX/UAVMetadata.ll @@ -1,5 +1,5 @@ ; RUN: opt -S -dxil-translate-metadata < %s | FileCheck %s -; RUN: opt -S --passes="print-dxil-resource-md" < %s 2>&1 | FileCheck %s --check-prefix=PRINT +; RUN: opt -S --passes="dxil-pretty-printer" < %s 2>&1 | FileCheck %s --check-prefix=PRINT ; RUN: llc %s --filetype=asm -o - < %s 2>&1 | FileCheck %s --check-prefixes=CHECK,PRINT target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" diff --git a/llvm/test/CodeGen/DirectX/any.ll b/llvm/test/CodeGen/DirectX/any.ll index ceb8077af311d..e32aa389a81a5 100644 --- a/llvm/test/CodeGen/DirectX/any.ll +++ b/llvm/test/CodeGen/DirectX/any.ll @@ -1,4 +1,4 @@ -; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s +; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s ; Make sure dxil operation function calls for any are generated for float and half. diff --git a/llvm/test/CodeGen/DirectX/cbuf.ll b/llvm/test/CodeGen/DirectX/cbuf.ll index e31a659728fcf..7589da5e15bd7 100644 --- a/llvm/test/CodeGen/DirectX/cbuf.ll +++ b/llvm/test/CodeGen/DirectX/cbuf.ll @@ -1,5 +1,5 @@ ; RUN: opt -S -dxil-translate-metadata < %s | FileCheck %s --check-prefix=DXILMD -; RUN: opt -S --passes="print-dxil-resource-md" < %s 2>&1 | FileCheck %s --check-prefix=PRINT +; RUN: opt -S --passes="dxil-pretty-printer" < %s 2>&1 | FileCheck %s --check-prefix=PRINT target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" target triple = "dxil-unknown-shadermodel6.7-library" diff --git a/llvm/test/CodeGen/DirectX/floor.ll b/llvm/test/CodeGen/DirectX/floor.ll index f667cab4aa249..f79f160e51e3b 100644 --- a/llvm/test/CodeGen/DirectX/floor.ll +++ b/llvm/test/CodeGen/DirectX/floor.ll @@ -1,4 +1,4 @@ -; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s +; RUN: opt -S -passes=dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s ; Make sure dxil operation function calls for floor are generated for float and half. diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp index 53791618e80fe..332706f7e3e57 100644 --- a/llvm/utils/TableGen/DXILEmitter.cpp +++ b/llvm/utils/TableGen/DXILEmitter.cpp @@ -18,7 +18,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" -#include "llvm/CodeGenTypes/MachineValueType.h" +#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/DXILABI.h" #include "llvm/Support/VersionTuple.h" #include "llvm/TableGen/Error.h" @@ -54,40 +54,6 @@ struct DXILOperationDesc { }; } // end anonymous namespace -/// Return dxil::ParameterKind corresponding to input LLVMType record -/// -/// \param R TableGen def record of class LLVMType -/// \return ParameterKind As defined in llvm/Support/DXILABI.h - -static ParameterKind getParameterKind(const Record *R) { - auto VTRec = R->getValueAsDef("VT"); - switch (getValueType(VTRec)) { - case MVT::isVoid: - return ParameterKind::Void; - case MVT::f16: - return ParameterKind::Half; - case MVT::f32: - return ParameterKind::Float; - case MVT::f64: - return ParameterKind::Double; - case MVT::i1: - return ParameterKind::I1; - case MVT::i8: - return ParameterKind::I8; - case MVT::i16: - return ParameterKind::I16; - case MVT::i32: - return ParameterKind::I32; - case MVT::fAny: - case MVT::iAny: - case MVT::Any: - return ParameterKind::Overload; - default: - llvm_unreachable( - "Support for specified parameter type not yet implemented"); - } -} - /// In-place sort TableGen records of class with a field /// Version dxil_version /// in the ascending version order. @@ -134,10 +100,9 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { // llvm/IR/Intrinsics.td OverloadParamIndex = -1; // A sigil meaning none. for (unsigned i = 0; i < ParamTypeRecsSize; i++) { - auto TR = ParamTypeRecs[i]; + Record *TR = ParamTypeRecs[i]; // Track operation parameter indices of any overload types - auto isAny = TR->getValueAsInt("isAny"); - if (isAny == 1) { + if (TR->getValueAsInt("isOverload")) { if (OverloadParamIndex != -1) { assert(TR == ParamTypeRecs[OverloadParamIndex] && "Specification of multiple differing overload parameter types " @@ -148,8 +113,6 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { if (OverloadParamIndex <= 0) OverloadParamIndex = i; } - if (TR->isAnonymous()) - PrintFatalError(TR, "Only concrete types are allowed here"); OpTypes.emplace_back(TR); } @@ -208,71 +171,27 @@ DXILOperationDesc::DXILOperationDesc(const Record *R) { } } -/// Return a string representation of ParameterKind enum -/// \param Kind Parameter Kind enum value -/// \return std::string string representation of input Kind -static std::string getParameterKindStr(ParameterKind Kind) { - switch (Kind) { - case ParameterKind::Invalid: - return "Invalid"; - case ParameterKind::Void: - return "Void"; - case ParameterKind::Half: - return "Half"; - case ParameterKind::Float: - return "Float"; - case ParameterKind::Double: - return "Double"; - case ParameterKind::I1: - return "I1"; - case ParameterKind::I8: - return "I8"; - case ParameterKind::I16: - return "I16"; - case ParameterKind::I32: - return "I32"; - case ParameterKind::I64: - return "I64"; - case ParameterKind::Overload: - return "Overload"; - case ParameterKind::CBufferRet: - return "CBufferRet"; - case ParameterKind::ResourceRet: - return "ResourceRet"; - case ParameterKind::DXILHandle: - return "DXILHandle"; - } - llvm_unreachable("Unknown llvm::dxil::ParameterKind enum"); -} - /// Return a string representation of OverloadKind enum that maps to /// input LLVMType record /// \param R TableGen def record of class LLVMType /// \return std::string string representation of OverloadKind -static std::string getOverloadKindStr(const Record *R) { - Record *VTRec = R->getValueAsDef("VT"); - switch (getValueType(VTRec)) { - case MVT::f16: - return "OverloadKind::HALF"; - case MVT::f32: - return "OverloadKind::FLOAT"; - case MVT::f64: - return "OverloadKind::DOUBLE"; - case MVT::i1: - return "OverloadKind::I1"; - case MVT::i8: - return "OverloadKind::I8"; - case MVT::i16: - return "OverloadKind::I16"; - case MVT::i32: - return "OverloadKind::I32"; - case MVT::i64: - return "OverloadKind::I64"; - default: - llvm_unreachable("Support for specified fixed type option for overload " - "type not supported"); - } +static StringRef getOverloadKindStr(const Record *R) { + // TODO: This is a hack. We need to rework how we're handling the set of + // overloads to avoid this business with the separate OverloadKind enum. + return StringSwitch(R->getName()) + .Case("HalfTy", "OverloadKind::HALF") + .Case("FloatTy", "OverloadKind::FLOAT") + .Case("DoubleTy", "OverloadKind::DOUBLE") + .Case("Int1Ty", "OverloadKind::I1") + .Case("Int8Ty", "OverloadKind::I8") + .Case("Int16Ty", "OverloadKind::I16") + .Case("Int32Ty", "OverloadKind::I32") + .Case("Int64Ty", "OverloadKind::I64") + .Case("ResRetHalfTy", "OverloadKind::HALF") + .Case("ResRetFloatTy", "OverloadKind::FLOAT") + .Case("ResRetInt16Ty", "OverloadKind::I16") + .Case("ResRetInt32Ty", "OverloadKind::I32"); } /// Return a string representation of valid overload information denoted @@ -417,8 +336,7 @@ static void emitDXILOpCodes(std::vector &Ops, } /// Emit a list of DXIL op classes -static void emitDXILOpClasses(RecordKeeper &Records, - raw_ostream &OS) { +static void emitDXILOpClasses(RecordKeeper &Records, raw_ostream &OS) { OS << "#ifdef DXIL_OPCLASS\n"; std::vector OpClasses = Records.getAllDerivedDefinitions("DXILOpClass"); @@ -428,6 +346,35 @@ static void emitDXILOpClasses(RecordKeeper &Records, OS << "#endif\n\n"; } +/// Emit a list of DXIL op parameter types +static void emitDXILOpParamTypes(RecordKeeper &Records, raw_ostream &OS) { + OS << "#ifdef DXIL_OP_PARAM_TYPE\n"; + std::vector OpClasses = + Records.getAllDerivedDefinitions("DXILOpParamType"); + for (Record *OpClass : OpClasses) + OS << "DXIL_OP_PARAM_TYPE(" << OpClass->getName() << ")\n"; + OS << "#undef DXIL_OP_PARAM_TYPE\n"; + OS << "#endif\n\n"; +} + +/// Emit a list of DXIL op function types +static void emitDXILOpFunctionTypes(ArrayRef Ops, + raw_ostream &OS) { + OS << "#ifndef DXIL_OP_FUNCTION_TYPE\n"; + OS << "#define DXIL_OP_FUNCTION_TYPE(OpCode, RetType, ...)\n"; + OS << "#endif\n"; + for (const DXILOperationDesc &Op : Ops) { + OS << "DXIL_OP_FUNCTION_TYPE(dxil::OpCode::" << Op.OpName; + for (const Record *Rec : Op.OpTypes) + OS << ", dxil::OpParamType::" << Rec->getName(); + // If there are no arguments, we need an empty comma for the varargs + if (Op.OpTypes.size() == 1) + OS << ", "; + OS << ")\n"; + } + OS << "#undef DXIL_OP_FUNCTION_TYPE\n"; +} + /// Emit map of DXIL operation to LLVM or DirectX intrinsic /// \param A vector of DXIL Ops /// \param Output stream @@ -454,9 +401,7 @@ static void emitDXILOperationTable(std::vector &Ops, // Collect Names. SequenceToOffsetTable OpClassStrings; SequenceToOffsetTable OpStrings; - SequenceToOffsetTable> Parameters; - StringMap> ParameterMap; StringSet<> ClassSet; for (auto &Op : Ops) { OpStrings.add(Op.OpName); @@ -465,18 +410,11 @@ static void emitDXILOperationTable(std::vector &Ops, continue; ClassSet.insert(Op.OpClass); OpClassStrings.add(Op.OpClass.data()); - SmallVector ParamKindVec; - for (unsigned i = 0; i < Op.OpTypes.size(); i++) { - ParamKindVec.emplace_back(getParameterKind(Op.OpTypes[i])); - } - ParameterMap[Op.OpClass] = ParamKindVec; - Parameters.add(ParamKindVec); } // Layout names. OpStrings.layout(); OpClassStrings.layout(); - Parameters.layout(); // Emit access function getOpcodeProperty() that embeds DXIL Operation table // with entries of type struct OpcodeProperty. @@ -492,8 +430,7 @@ static void emitDXILOperationTable(std::vector &Ops, << getOverloadMaskString(Op.OverloadRecs) << ", " << getStageMaskString(Op.StageRecs) << ", " << getAttributeMaskString(Op.AttrRecs) << ", " << Op.OverloadParamIndex - << ", " << Op.OpTypes.size() << ", " - << Parameters.get(ParameterMap[Op.OpClass]) << " }"; + << " }"; Prefix = ",\n"; } OS << " };\n"; @@ -531,21 +468,6 @@ static void emitDXILOperationTable(std::vector &Ops, OS << " unsigned Index = Prop.OpCodeClassNameOffset;\n"; OS << " return DXILOpCodeClassNameTable + Index;\n"; - OS << "}\n "; - - OS << "static const ParameterKind *getOpCodeParameterKind(const " - "OpCodeProperty &Prop) " - "{\n\n"; - OS << " static const ParameterKind DXILOpParameterKindTable[] = {\n"; - Parameters.emit( - OS, - [](raw_ostream &ParamOS, ParameterKind Kind) { - ParamOS << "ParameterKind::" << getParameterKindStr(Kind); - }, - "ParameterKind::Invalid"); - OS << " };\n\n"; - OS << " unsigned Index = Prop.ParameterTableOffset;\n"; - OS << " return DXILOpParameterKindTable + Index;\n"; OS << "}\n\n"; } @@ -611,6 +533,8 @@ static void EmitDXILOperation(RecordKeeper &Records, raw_ostream &OS) { emitDXILOpCodes(DXILOps, OS); emitDXILOpClasses(Records, OS); + emitDXILOpParamTypes(Records, OS); + emitDXILOpFunctionTypes(DXILOps, OS); emitDXILIntrinsicMap(DXILOps, OS); OS << "#ifdef DXIL_OP_OPERATION_TABLE\n\n"; emitDXILOperationTableDataStructs(Records, OS);