diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h index da6903df65ffe..eff94e0d7f950 100644 --- a/clang/lib/Headers/hlsl/hlsl_basic_types.h +++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h @@ -23,52 +23,98 @@ namespace hlsl { // 16-bit integer. typedef unsigned short uint16_t; typedef short int16_t; + +// 16-bit floating point. +typedef half float16_t; #endif +// 32-bit integer. +typedef int int32_t; + // unsigned 32-bit integer. typedef unsigned int uint; +typedef unsigned int uint32_t; + +// 32-bit floating point. +typedef float float32_t; // 64-bit integer. typedef unsigned long uint64_t; typedef long int64_t; +// 64-bit floating point +typedef double float64_t; + // built-in vector data types: #ifdef __HLSL_ENABLE_16_BIT +typedef vector int16_t1; typedef vector int16_t2; typedef vector int16_t3; typedef vector int16_t4; +typedef vector uint16_t1; typedef vector uint16_t2; typedef vector uint16_t3; typedef vector uint16_t4; #endif +typedef vector bool1; typedef vector bool2; typedef vector bool3; typedef vector bool4; +typedef vector int1; typedef vector int2; typedef vector int3; typedef vector int4; +typedef vector uint1; typedef vector uint2; typedef vector uint3; typedef vector uint4; +typedef vector int32_t1; +typedef vector int32_t2; +typedef vector int32_t3; +typedef vector int32_t4; +typedef vector uint32_t1; +typedef vector uint32_t2; +typedef vector uint32_t3; +typedef vector uint32_t4; +typedef vector int64_t1; typedef vector int64_t2; typedef vector int64_t3; typedef vector int64_t4; +typedef vector uint64_t1; typedef vector uint64_t2; typedef vector uint64_t3; typedef vector uint64_t4; +typedef vector half1; typedef vector half2; typedef vector half3; typedef vector half4; - +typedef vector float1; typedef vector float2; typedef vector float3; typedef vector float4; +typedef vector double1; typedef vector double2; typedef vector double3; typedef vector double4; +#ifdef __HLSL_ENABLE_16_BIT +typedef vector float16_t1; +typedef vector float16_t2; +typedef vector float16_t3; +typedef vector float16_t4; +#endif + +typedef vector float32_t1; +typedef vector float32_t2; +typedef vector float32_t3; +typedef vector float32_t4; +typedef vector float64_t1; +typedef vector float64_t2; +typedef vector float64_t3; +typedef vector float64_t4; + } // namespace hlsl #endif //_HLSL_HLSL_BASIC_TYPES_H_ diff --git a/clang/test/SemaHLSL/Types/typedefs.hlsl b/clang/test/SemaHLSL/Types/typedefs.hlsl new file mode 100644 index 0000000000000..fd72b1ae8a47f --- /dev/null +++ b/clang/test/SemaHLSL/Types/typedefs.hlsl @@ -0,0 +1,34 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.4-library -finclude-default-header -verify -fnative-half-type %s +// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -finclude-default-header -verify -fnative-half-type %s + +// expected-no-diagnostics +#define SizeCheck(Ty, SizeInBits) \ + _Static_assert(sizeof(Ty) == SizeInBits / 8, #Ty " is " #SizeInBits "-bit"); \ + _Static_assert(sizeof(Ty##1) == (SizeInBits * 1) / 8, #Ty "1 is 1x" #SizeInBits "-bit"); \ + _Static_assert(__builtin_vectorelements(Ty##1) == 1, #Ty "1 is has 1 " #SizeInBits "-bit element"); \ + _Static_assert(sizeof(Ty##2) == (SizeInBits * 2) / 8, #Ty "2 is 2x" #SizeInBits "-bit"); \ + _Static_assert(__builtin_vectorelements(Ty##2) == 2, #Ty "2 is has 2 " #SizeInBits "-bit element"); \ + _Static_assert(__builtin_vectorelements(Ty##3) == 3, #Ty "3 is has 3 " #SizeInBits "-bit element"); \ + _Static_assert(sizeof(Ty##4) == (SizeInBits * 4) / 8, #Ty "4 is 4x" #SizeInBits "-bit"); \ + _Static_assert(__builtin_vectorelements(Ty##4) == 4, #Ty "4 is has 4 " #SizeInBits "-bit element"); + +// FIXME: https://github.com/llvm/llvm-project/issues/104503 - 3 element vectors +// should be the size of 3 elements not padded to 4. +// _Static_assert(sizeof(Ty##3) == (SizeInBits * 3) / 8, #Ty "3 is 3x" #SizeInBits "-bit"); + +SizeCheck(int16_t, 16); +SizeCheck(uint16_t, 16); +SizeCheck(half, 16); +SizeCheck(float16_t, 16); + +SizeCheck(int, 32); +SizeCheck(uint, 32); +SizeCheck(int32_t, 32); +SizeCheck(uint32_t, 32); +SizeCheck(float, 32); +SizeCheck(float32_t, 32); + +SizeCheck(int64_t, 64); +SizeCheck(uint64_t, 64); +SizeCheck(double, 64); +SizeCheck(float64_t, 64);