@@ -83,15 +83,15 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
8383 // This is the BF16 {-136, -136} represented as an integer.
8484#if defined(USE_ROCM)
8585#if ROCM_VERSION >= 60200
86- auto BF16_BIAS = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0xC308 }));
87- auto BF16_ONE = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0x3F80 }));
86+ auto BF16_SCALE_FACTOR = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0xC308 }));
87+ auto BF16_UNIT_VALUE = __bfloat162bfloat162 (__hip_bfloat16 (__hip_bfloat16_raw{0x3F80 }));
8888#else
89- auto BF16_BIAS = __bfloat162bfloat162 (__hip_bfloat16{0xC308 });
90- auto BF16_ONE = __bfloat162bfloat162 (__hip_bfloat16{0x3F80 });
89+ auto BF16_SCALE_FACTOR = __bfloat162bfloat162 (__hip_bfloat16{0xC308 });
90+ auto BF16_UNIT_VALUE = __bfloat162bfloat162 (__hip_bfloat16{0x3F80 });
9191#endif
9292#else
93- static constexpr uint32_t BF16_BIAS = 0xC308C308 ;
94- static constexpr uint32_t BF16_ONE = 0x3F803F80 ;
93+ static constexpr uint32_t BF16_SCALE_FACTOR = 0xC308C308 ;
94+ static constexpr uint32_t BF16_UNIT_VALUE = 0x3F803F80 ;
9595#endif
9696
9797// Finally, we construct the output numbers.
@@ -100,11 +100,11 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
100100 // Since this section is for Ampere+, we use bf16 fma to do the bias
101101 // subtraction
102102#if defined(USE_ROCM)
103- result.vals [ii] = __hfma2 (result.vals [ii], BF16_ONE, BF16_BIAS );
103+ result.vals [ii] = __hfma2 (result.vals [ii], BF16_UNIT_VALUE, BF16_SCALE_FACTOR );
104104#else
105105 asm (" fma.rn.bf16x2 %0, %1, %2, %3;\n "
106106 : " =r" (h[ii])
107- : " r" (h[ii]), " r" (BF16_ONE ), " r" (BF16_BIAS ));
107+ : " r" (h[ii]), " r" (BF16_UNIT_VALUE ), " r" (BF16_SCALE_FACTOR ));
108108#endif
109109 }
110110
@@ -369,3 +369,4 @@ TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
369369}
370370
371371#endif
372+ git checkout main -- file.txt
0 commit comments