Commit 2ea6fbe
Check numerical equivalence / closeness between different kernel preferences
Summary:
This PR checks that the different kernel preferences for Float8Tensor (AUTO, TORCH, and FBGEMM) produce similar numerics.
The Triton implementation and the torchao implementation are currently a bit different; we need to decide whether or not to fix that. The differences:
1. Difference in the quantize op

The main difference is that the Triton implementation computes a reciprocal scale and multiplies:
```
a_scale = MAX_FP8 / max_abs
a_scale = 1.0 / a_scale
a_fp8 = a * a_scale
```
while torch computes the scale directly and divides:
```
a_scale = max_abs / MAX_FP8
a_fp8 = a / a_scale
```
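The two recipes are algebraically equivalent but round differently: the Triton path goes through an extra floating-point reciprocal, so individual elements can land in different float8 bins. A minimal sketch of the effect (the tensor values and the standalone `MAX_FP8` constant here are illustrative, not taken from either kernel):

```python
import torch

# Illustrative comparison of the two scale recipes (not the actual kernels).
MAX_FP8 = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

a = torch.randn(1024)
max_abs = a.abs().max()

# Triton-style: reciprocal scale, then invert and multiply.
scale_triton = MAX_FP8 / max_abs
a_scaled_triton = a * (1.0 / scale_triton)

# torch-style: direct scale, then divide.
scale_torch = max_abs / MAX_FP8
a_scaled_torch = a / scale_torch

# The extra reciprocal rounds differently, so the scaled values can differ
# by a few ULPs before the final cast to float8.
print((a_scaled_triton - a_scaled_torch).abs().max())
```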
The hp_value_lb and hp_value_ub settings also differ slightly (see the sketch after the links below).
Triton choose-scale and quantize code: https://github.com/pytorch/FBGEMM/blob/a4286c01ef01dad435b2ec8798605127d3032cd8/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py#L2382-L2392
torchao choose-scale and quantize code:
https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2183
https://github.com/pytorch/ao/blob/3c466f844684af0fb80014094f2ca8663881eb33/torchao/quantization/quant_primitives.py#L2283
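For reference, a sketch of how amax bounds like these typically enter the scale computation; the exact clamping semantics in torchao and FBGEMM may differ, so treat the helper below as an assumption, not their implementation:

```python
# Hypothetical helper showing the usual role of hp_value_lb / hp_value_ub:
# clamp the high-precision max-abs before it is turned into a scale.
def clamp_max_abs(max_abs: float, hp_value_lb=None, hp_value_ub=None) -> float:
    if hp_value_lb is not None:
        max_abs = max(max_abs, hp_value_lb)  # floor: avoids a huge scale when amax ~ 0
    if hp_value_ub is not None:
        max_abs = min(max_abs, hp_value_ub)  # ceiling: caps the effect of outliers
    return max_abs
```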
2. (Potential) difference in the matrix multiplication ops

TORCH and AUTO/FBGEMM use different quantized mm ops.
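For context, the per-tensor TORCH path boils down to a torch._scaled_mm call; a minimal sketch of that kind of op, with illustrative shapes and scale choices (PyTorch 2.4+ signature assumed):

```python
import torch

# Per-tensor-scaled fp8 matmul via torch._scaled_mm (needs a GPU with fp8
# support, e.g. H100).
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

x = torch.randn(64, 128, device="cuda")
w = torch.randn(256, 128, device="cuda")  # (out_features, in_features)

scale_x = (x.abs().max() / E4M3_MAX).float()
scale_w = (w.abs().max() / E4M3_MAX).float()

x_fp8 = (x / scale_x).to(torch.float8_e4m3fn)
w_fp8 = (w / scale_w).to(torch.float8_e4m3fn)

# _scaled_mm wants the second operand column-major; w_fp8.t() is exactly that.
out = torch._scaled_mm(
    x_fp8, w_fp8.t(), scale_a=scale_x, scale_b=scale_w, out_dtype=torch.bfloat16
)
```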
Added a reverse option to bring the SQNR values closer:
```
granularity: PerTensor() sizes: ((128,), 256, 128) kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor() sizes: ((128,), 256, 128) kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor() sizes: ((32, 128), 64, 256) kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerTensor() sizes: ((32, 128), 64, 256) kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow() sizes: ((128,), 256, 128) kp: KernelPreference.AUTO tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow() sizes: ((128,), 256, 128) kp: KernelPreference.FBGEMM tensor(inf, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow() sizes: ((32, 128), 64, 256) kp: KernelPreference.AUTO tensor(64.5000, device='cuda:0', dtype=torch.bfloat16)
granularity: PerRow() sizes: ((32, 128), 64, 256) kp: KernelPreference.FBGEMM tensor(68., device='cuda:0', dtype=torch.bfloat16)
```
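The values above are SQNR in dB between outputs under different kernel preferences (inf means the outputs match exactly). A generic way to compute such a metric, similar in spirit to torchao's compute_error helper but not necessarily the test's exact code:

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; returns inf when the two
    # outputs are bit-identical (zero noise), as in the PerTensor rows above.
    return 20 * torch.log10(ref.norm() / (ref - test).norm())
```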
Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_kernel_preference_numerical_equivalence
stack-info: PR: #2651, branch: jerryzh168/stack/151 parent b06dafd commit 2ea6fbe
File tree (4 files changed, +70 −6):
- test/quantization/quantize_/workflows/float8
- torchao/quantization
- torchao/quantization/quantize_/workflows/float8

[Diff bodies were not captured in this extract; only the line-number gutters survived. Recoverable per-file stats: 53 additions & 0 deletions in the test file; 4 additions & 2 deletions in the last file.]