2 files changed: +7 −1

@@ -59,7 +59,7 @@ def get_extensions():
 
     if not torch.cuda.is_available():
         print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
-    if CUDA_HOME is None or not IS_ROCM and torch.cuda.is_available():
+    if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
         print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions")
         print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")
 
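Why the condition changed: Python parses the old expression as (CUDA_HOME is None) or ((not IS_ROCM) and torch.cuda.is_available()), because and binds tighter than or. On a ROCm machine CUDA_HOME is typically unset, so the first operand alone made the whole check true and the "toolkit is not available" branch was taken even with ROCm installed. Below is a minimal sketch of the difference, using standalone helper functions that are purely illustrative (in the real build script, CUDA_HOME and ROCM_HOME are the values exposed by torch.utils.cpp_extension):

# Illustrative stand-ins for the build-script condition, showing why the
# old mixed and/or expression misfired on ROCm.

def should_warn_old(cuda_home, is_rocm, cuda_available):
    # Parses as: (cuda_home is None) or ((not is_rocm) and cuda_available)
    return cuda_home is None or not is_rocm and cuda_available

def should_warn_new(cuda_home, rocm_home, cuda_available):
    # Warn only when neither toolkit root was found.
    return (cuda_home is None and rocm_home is None) and cuda_available

# ROCm machine: CUDA toolkit absent, ROCm present, GPU visible to PyTorch.
print(should_warn_old(None, True, True))         # True  -> spurious warning
print(should_warn_new(None, "/opt/rocm", True))  # False -> warning correctly skipped

With the parenthesized form, the warning fires only when neither the CUDA toolkit nor ROCm can be found, which matches the wording of the printed message.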
torchao/csrc/cuda/tensor_core_tiled_layout
@@ -167,6 +167,7 @@ __global__ void _dequantize_int4_kernel(
   // All b values within a 16x16 tile should fall within the same q group
   // Hence we load 1 scale and zero per loop
   int qgroup = ks[0] / groupSize;
+#if defined(USE_ROCM)
   __nv_bfloat162 scale2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
   __nv_bfloat162 zero2 = __bfloat162bfloat162(__hip_bfloat16(1.0f));
 
@@ -177,6 +178,11 @@ __global__ void _dequantize_int4_kernel(
     scale2 = __bfloat162bfloat162(pSZ[0]);
     zero2 = __bfloat162bfloat162(pSZ[1]);
   }
+#else
+  const __nv_bfloat16 *pSZ = reinterpret_cast<const __nv_bfloat16*>(&scales_and_zeros.value()[qgroup][n0][0]);
+  __nv_bfloat162 scale2 = __bfloat162bfloat162(pSZ[0]);
+  __nv_bfloat162 zero2 = __bfloat162bfloat162(pSZ[1]);
+#endif
 
   #pragma unroll
   for (int i = 0; i < 4; i++) {
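For context, both branches of the new USE_ROCM guard load one (scale, zero) pair per quantization group, indexed as scales_and_zeros[qgroup][n0][0..1] with qgroup = ks[0] / groupSize. A rough Python sketch of that per-group lookup follows (hypothetical sizes and tensor layout chosen only to mirror the kernel's indexing, not the actual packed format):

import torch

# Hypothetical sizes, for illustration only.
groupSize = 128
k, n = 4096, 8

# Mirrors the scales_and_zeros[qgroup][n0][{0,1}] access pattern:
# one (scale, zero) pair per (k-group, n) position.
scales_and_zeros = torch.randn(k // groupSize, n, 2, dtype=torch.bfloat16)

def scale_zero_for(k_index, n_index):
    qgroup = k_index // groupSize  # same grouping as `int qgroup = ks[0] / groupSize;`
    scale = scales_and_zeros[qgroup, n_index, 0]
    zero = scales_and_zeros[qgroup, n_index, 1]
    return scale, zero

# With groupSize = 128, every k in [256, 384) shares a single (scale, zero).
print(scale_zero_for(300, 3))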