diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 67015cff78a79..60185c20606b2 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -671,8 +671,9 @@ def Dot4 : DXILOp<56, dot4> { def ThreadId : DXILOp<93, threadId> { let Doc = "Reads the thread ID"; let LLVMIntrinsic = int_dx_thread_id; - let arguments = [i32Ty]; - let result = i32Ty; + let arguments = [overloadTy]; + let result = overloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -680,8 +681,9 @@ def ThreadId : DXILOp<93, threadId> { def GroupId : DXILOp<94, groupId> { let Doc = "Reads the group ID (SV_GroupID)"; let LLVMIntrinsic = int_dx_group_id; - let arguments = [i32Ty]; - let result = i32Ty; + let arguments = [overloadTy]; + let result = overloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -689,8 +691,9 @@ def GroupId : DXILOp<94, groupId> { def ThreadIdInGroup : DXILOp<95, threadIdInGroup> { let Doc = "Reads the thread ID within the group (SV_GroupThreadID)"; let LLVMIntrinsic = int_dx_thread_id_in_group; - let arguments = [i32Ty]; - let result = i32Ty; + let arguments = [overloadTy]; + let result = overloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } @@ -699,7 +702,8 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> { let Doc = "Provides a flattened index for a given thread within a given " "group (SV_GroupIndex)"; let LLVMIntrinsic = int_dx_flattened_thread_id_in_group; - let result = i32Ty; + let result = overloadTy; + let overloads = [Overloads]; let stages = [Stages]; let attributes = [Attributes]; } diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 91e6931b3f788..0e2b4601112b5 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -87,6 +87,9 @@ static const char *getOverloadTypeName(OverloadKind Kind) { } static OverloadKind getOverloadKind(Type *Ty) { + if (!Ty) + return OverloadKind::VOID; + Type::TypeID T = Ty->getTypeID(); switch (T) { case Type::VoidTyID: @@ -379,11 +382,7 @@ Expected DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode, uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys; - // If we don't have an overload type, use the function's return type. This is - // a bit of a hack, but it's necessary to get the type suffix on unoverloaded - // DXIL ops correct, like `dx.op.threadId.i32`. - OverloadKind Kind = - getOverloadKind(OverloadTy ? OverloadTy : DXILOpFT->getReturnType()); + OverloadKind Kind = getOverloadKind(OverloadTy); // Check if the operation supports overload types and OverloadTy is valid // per the specified types for the operation