@@ -22,7 +22,7 @@
     preprocess_data,
     preprocess_scale,
 )
-from torchao.quantization.granularity import PerRow
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.observer import get_block_size
 from torchao.quantization.quant_primitives import (
     _choose_scale_float8,
@@ -178,32 +178,61 @@ def from_hp( |
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)
 
-        # for per row quantization and kernel_preference default setting, we'll use triton kernel for best performance
+        kernel_choice = None
         if (
             kernel_preference == KernelPreference.AUTO
             and _is_fbgemm_genai_gpu_available()
-            and (
-                tuple(block_size)
-                == (1,) * (hp_tensor.ndim - 1) + (hp_tensor.shape[-1],)
-            )
+            and is_sm_at_least_90()
+            and isinstance(granularity, PerRow)
+            and float8_dtype == torch.float8_e4m3fn
+            and hp_value_lb is None
         ):
-            assert float8_dtype == torch.float8_e4m3fn, (
-                f"Only torch.float8_e4m3fn is supported, got: {float8_dtype}"
+            # if kernel_preference is AUTO and granularity is per row,
+            # we'll use the fbgemm quantize kernel for best performance
+            kernel_choice = "fbgemm"
+        elif kernel_preference == KernelPreference.FBGEMM:
+            # if the user explicitly chose the FBGEMM kernel preference, we'll also use the fbgemm kernel
+            assert _is_fbgemm_genai_gpu_available() and is_sm_at_least_90(), (
+                "Specified fbgemm but fbgemm_gpu_genai is not installed or hardware is not >= SM 9.0 (>= H100)"
+            )
+            assert hp_value_lb is None, (
+                "hp_value_lb should not be specified with KernelPreference.FBGEMM"
             )
+            kernel_choice = "fbgemm"
+        else:
+            # the fallback quantize kernel for everything else is torch
+            kernel_choice = "torch"
+
+        if kernel_choice == "fbgemm":
+            assert hp_value_lb is None, f"{hp_value_lb=} is not supported"
             if hp_value_ub is not None:
                 maybe_hp_value_ub_tensor = torch.tensor(
                     hp_value_ub, dtype=torch.float, device=hp_tensor.device
                 )
             else:
                 maybe_hp_value_ub_tensor = None
-            data, scale = torch.ops.triton.quantize_fp8_row(
-                hp_tensor, scale_ub=maybe_hp_value_ub_tensor
-            )
-            scale_shape = []
-            for i in range(hp_tensor.ndim):
-                scale_shape.append(hp_tensor.shape[i] // block_size[i])
-            scale = scale.reshape(*scale_shape)
+            if isinstance(granularity, PerRow):
+                data, scale = torch.ops.triton.quantize_fp8_row(
+                    hp_tensor, scale_ub=maybe_hp_value_ub_tensor
+                )
+                scale_shape = []
+                for i in range(hp_tensor.ndim):
+                    scale_shape.append(hp_tensor.shape[i] // block_size[i])
+                scale = scale.reshape(*scale_shape)
+            else:
+                assert isinstance(granularity, PerTensor), (
+                    f"Expected per tensor, got {granularity}"
+                )
+                # current error: torch.AcceleratorError: CUDA error: an illegal memory access was encountered
+                # TODO: enable after this is working
+                # data, scale = torch.ops.fbgemm.quantize_fp8_per_tensor(
+                #     hp_tensor, num_tokens, scale_ub=maybe_hp_value_ub_tensor
+                # )
+                raise NotImplementedError(
+                    "Currently KernelPreference.FBGEMM does not work for per tensor float8 quant"
+                )
         else:
+            assert kernel_choice == "torch", f"Expected torch, got {kernel_choice}"
             scale = _choose_scale_float8(
                 hp_tensor,
                 float8_dtype=float8_dtype,
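For context on the new gating conditions: the fbgemm path is only taken on SM 9.0+ hardware via is_sm_at_least_90(). A minimal sketch of what such a check does (an assumption for illustration; torchao's actual helper may differ, e.g. in device-index or ROCm handling):

import torch

def is_sm_at_least_90_sketch() -> bool:
    # Compute capability is reported as (major, minor); SM 9.0 corresponds to Hopper (H100).
    return (
        torch.cuda.is_available()
        and torch.version.cuda is not None  # assumption: exclude non-CUDA builds
        and torch.cuda.get_device_capability() >= (9, 0)
    )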
0 commit comments
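And a pure-PyTorch sketch of what the per-row branch computes, approximating torch.ops.triton.quantize_fp8_row under the usual amax-based scaling scheme (the function name and the exact handling of scale_ub are assumptions; the fbgemm kernel's numerics may differ):

import torch

def quantize_fp8_per_row_reference(hp_tensor, scale_ub=None):
    # Max representable magnitude of float8_e4m3fn (448.0).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per row, i.e. block_size = (1, ..., 1, K): amax over the last dim.
    amax = hp_tensor.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    if scale_ub is not None:
        # Assumption: scale_ub caps the row amax, mirroring the kernel's scale_ub arg.
        amax = torch.clamp(amax, max=float(scale_ub))
    amax = torch.clamp(amax, min=1e-12)  # avoid divide-by-zero on all-zero rows
    scale = amax / fp8_max
    data = (hp_tensor.to(torch.float32) / scale).clamp(-fp8_max, fp8_max)
    # scale ends up with shape hp_tensor.shape[:-1] + (1,), which matches the
    # scale_shape = hp_tensor.shape[i] // block_size[i] reshape in the diff.
    # (A per-tensor variant would use a single global amax instead.)
    return data.to(torch.float8_e4m3fn), scale

For example, a (16, 128) input yields data of shape (16, 128) in float8_e4m3fn and a scale of shape (16, 1).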