
Commit d51715a

Merge branch 'main' into fix_harware_check
2 parents: f292646 + aeb1944

File tree: 5 files changed (+72 -22 lines)

* README.md
* test/float8/test_base.py
* torchao/_models/llama/generate.py
* torchao/quantization/README.md
* torchao/quantization/quant_api.py


README.md

Lines changed: 2 additions & 1 deletion
@@ -178,7 +178,7 @@ We're also fortunate to be integrated into some of the leading open-source libra
 3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference)
 4. [TorchTune](https://github.com/pytorch/torchtune) for our QLoRA and QAT recipes
 5. [torchchat](https://github.com/pytorch/torchchat) for post training quantization
-6. [SGLang](https://github.com/sgl-project/sglang/pull/1341) for LLM inference quantization
+6. SGLang for LLM serving: [usage](https://github.com/sgl-project/sglang/blob/4f2ee48ed1c66ee0e189daa4120581de324ee814/docs/backend/backend.md?plain=1#L83) and the major [PR](https://github.com/sgl-project/sglang/pull/1341).

 ## Videos
 * [Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009)
@@ -205,4 +205,5 @@ If you find the torchao library useful, please cite it in your work as below.
   license = {BSD-3-Clause},
   month = oct,
   year = {2024}
+}
 ```

test/float8/test_base.py

Lines changed: 16 additions & 2 deletions
@@ -13,7 +13,6 @@
 import pytest
 import torch
 import torch.nn as nn
-
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
@@ -537,6 +536,21 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)

+    @unittest.skipIf(not is_sm_89(), "CUDA arch 8.9 not available")
+    def test_quantize(self):
+        x = torch.randn(32, 32, device="cuda")
+        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        m = convert_to_float8_training(m)
+        assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
+        from torchao.quantization.quant_api import float8_weight_only, quantize_
+
+        quantize_(m, float8_weight_only())
+        assert (
+            m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
+        ), "Post quantization dtype should be torch.float8_e4m3fn"
+        with torch.no_grad():
+            m(x)
+

 class TestScaledMM:
     @unittest.skipIf(
@@ -582,7 +596,7 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         if base_dtype in {torch.bfloat16, torch.float16}:
             atol, rtol = 7e-2, 7e-2
         else:
-            atol, rtol = 2e-3, 2e-3
+            atol, rtol = 3e-3, 3e-3
         torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

     @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
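
The new test_quantize exercises a two-step flow: convert a module to Float8Linear with convert_to_float8_training, then quantize its weights with quantize_(…, float8_weight_only()). Below is a minimal standalone sketch of that same flow, assuming torchao with float8 support and a GPU with compute capability 8.9 or newer; the convert_to_float8_training import path is torchao's public torchao.float8 module, everything else is taken from the diff above.

```python
# Minimal sketch of the flow exercised by test_quantize (float8 training conversion
# followed by float8 weight-only quantization for inference).
import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear
from torchao.quantization.quant_api import float8_weight_only, quantize_

m = nn.Sequential(nn.Linear(32, 32)).cuda()

# Step 1: swap nn.Linear for Float8Linear (the float8 training path).
m = convert_to_float8_training(m)
assert isinstance(m[0], Float8Linear)

# Step 2: quantize the module's weight to float8 for inference.
quantize_(m, float8_weight_only())
print(m[0].weight.tensor_impl.float8_data.dtype)  # expected: torch.float8_e4m3fn

with torch.no_grad():
    m(torch.randn(32, 32, device="cuda"))
```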

torchao/_models/llama/generate.py

Lines changed: 29 additions & 17 deletions
@@ -20,12 +20,14 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
         print(f"device={device} is not yet suppported")

-default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+default_device = 'cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu'

 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -440,10 +442,13 @@ def main(
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

     if memory_profile:
-        if device != "cuda":
-            print("Memory profiling only works on CUDA")
-        else:
+        if device == "cuda":
             torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
+        elif device == "xpu":
+            torch.xpu.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
+        else:
+            print("Memory profiling only works on CUDA or XPU devices")
+
     aggregate_metrics = {
         'tokens_per_sec': [],
     }
@@ -453,6 +458,8 @@ def main(
        if i==0:
            if device == "cuda":
                torch.cuda.reset_peak_memory_stats() # MKG
+           elif device == "xpu":
+               torch.xpu.reset_peak_memory_stats() # MKG
            device_sync(device=device) # MKG
        if i >= 0 and interactive:
            prompt = input("What is your prompt? ")
@@ -520,24 +527,29 @@ def callback(x):
        print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

        if memory_profile and i==0:
-           if device != "cuda":
-               print("Memory profiling only works on CUDA")
-           else:
+           if device == "cuda":
                snapshot = torch.cuda.memory._snapshot()
-               with open(f"{memory_profile}.pickle", 'wb') as f:
-                   from pickle import dump
-                   dump(snapshot, f)
-               print(
-                   f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
-                   "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
-               )
-               break
-
+           elif device == "xpu":
+               snapshot = torch.xpu.memory._snapshot()
+           else:
+               print("Memory profiling only works on CUDA or XPU devices")
+
+           with open(f"{memory_profile}.pickle", 'wb') as f:
+               from pickle import dump
+               dump(snapshot, f)
+           print(
+               f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
+               "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
+           )
+           break
        print("==========")

    tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
    bandwidth = model_size * tokpersec
-   mem = torch.cuda.max_memory_reserved() /1e9
+   if device == "cuda":
+       mem = torch.cuda.max_memory_reserved() /1e9
+   elif device == "xpu":
+       mem = torch.xpu.max_memory_reserved() /1e9
    print(f"Average tokens/sec: {tokpersec:.2f}")
    if batch_size > 1:
        print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")

torchao/quantization/README.md

Lines changed: 18 additions & 2 deletions
@@ -3,7 +3,7 @@ Typically quantization algorithms will have different schemes for how the activa

 ## Benchmarks
 Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. The models used were meta-llama/Llama-2-7b-chat-hf and meta-llama/Meta-Llama-3-8B.
-
+### CUDA backend
 | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
 | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
 | Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 |
@@ -20,9 +20,16 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GP
 | | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 |
 | | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 |
 | | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |
+### XPU backend
+| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
+| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
+| Llama-2-7B | Base (bfloat16) | NA | 42.20 | 557.71 | 13.89 | 13.21 |
+| | int8dq | NA | 9.87 | 65.35 | 14.60 | 6.62 |
+| | int8wo | NA | 66.24 | 438.61 | 14.60 | 6.62
+

-Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.

+### CUDA backend
 | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
 | ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
 | Llama-3.1-8B | Base (bfloat16) | 7.54 | 126.90 | 1904.75 | 16.75 | 15.01 |
@@ -31,6 +38,15 @@ Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a ma
 | | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 |
 | | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 |
 | | float8dq (Per Row) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 |
+### XPU backend
+| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
+| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
+| Llama-3-8.1B | Base (bfloat16) | 7.441 | 40.36 | 605.77 | 16.35 | 15.01 |
+| | int8dq | 7.581 | 13.60 | 102.28 | 18.69 | 7.52 |
+| | int8wo | 7.447 | 59.49 | 447.27 | 18.60 | 7.52
+
+
+Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU or Intel-Max1100 using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.

 note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
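
The int8wo and int8dq rows in the new XPU tables correspond to torchao's int8 weight-only and int8 dynamic-activation/int8-weight APIs; the benchmarks drive them through ../_models/llama/generate.py. Below is a minimal sketch of applying the same two techniques to a toy module; the "xpu" device string assumes a PyTorch build with XPU support, and everything other than the torchao function names is illustrative.

```python
# Minimal sketch of the two techniques benchmarked above: int8 weight-only ("int8wo")
# and int8 dynamic-activation / int8-weight ("int8dq").
import torch
import torch.nn as nn

from torchao.quantization import (
    int8_dynamic_activation_int8_weight,  # "int8dq"
    int8_weight_only,  # "int8wo"
    quantize_,
)

device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda"

model_wo = nn.Sequential(nn.Linear(1024, 1024)).to(device=device, dtype=torch.bfloat16)
quantize_(model_wo, int8_weight_only())

model_dq = nn.Sequential(nn.Linear(1024, 1024)).to(device=device, dtype=torch.bfloat16)
quantize_(model_dq, int8_dynamic_activation_int8_weight())

with torch.no_grad():
    x = torch.randn(1, 1024, device=device, dtype=torch.bfloat16)
    model_wo(x)
    model_dq(x)
```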

torchao/quantization/quant_api.py

Lines changed: 7 additions & 0 deletions
@@ -39,6 +39,7 @@
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
 )
+from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.inference import Float8MMConfig
 from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
@@ -222,6 +223,12 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
+    if isinstance(model, Float8Linear):
+        with torch.device("meta"):
+            new_module = nn.Linear(model.in_features, model.out_features)
+        new_module.weight = model.weight
+        new_module.bias = model.bias
+        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device) # move to device before quantization
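
The new isinstance branch lets quantize_ handle models whose nn.Linear layers were previously swapped to Float8Linear by the float8 training flow, which is exactly what the new test_quantize in test/float8/test_base.py checks: the Float8Linear is replaced by a plain nn.Linear shell built on the meta device, and the already-trained weight and bias are reattached before the filter/quantization logic runs. A standalone sketch of that unwrap step, using only the attributes referenced in the diff (the helper name is illustrative, not torchao API):

```python
# Standalone sketch of the Float8Linear -> nn.Linear unwrap added above.
import torch
import torch.nn as nn

from torchao.float8.float8_linear import Float8Linear


def unwrap_float8_linear(module: nn.Module) -> nn.Module:
    """If `module` is a Float8Linear, return an equivalent plain nn.Linear."""
    if not isinstance(module, Float8Linear):
        return module
    # Build an empty shell on the meta device so no real storage is allocated ...
    with torch.device("meta"):
        new_module = nn.Linear(module.in_features, module.out_features)
    # ... then reattach the already-materialized trained parameters.
    new_module.weight = module.weight
    new_module.bias = module.bias
    return new_module
```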

0 commit comments
