
Commit b96196b

Merge branch 'main' into rocm_sparse_marlin
2 parents 08d1cfb + e1cb44a commit b96196b

62 files changed (+2231, -1283 lines)


.github/pytorch-probot.yml

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 mergebot: True
 ciflow_push_tags:
 - ciflow/benchmark
+- ciflow/tutorials

.github/workflows/dashboard_perf_test.yml

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@ jobs:
         torch-spec:
           - '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124'
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4

       - name: Setup miniconda
         uses: pytorch/test-infra/.github/actions/setup-miniconda@main
@@ -55,7 +55,7 @@ jobs:
           # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json

           # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --compression autoquant --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
-
+
           # SAM 2.1
           # ${CONDA_RUN} sh scripts/download_sam2_ckpts.sh ${CHECKPOINT_PATH}/sam2
           # cd examples/sam2_amg_server

.github/workflows/doc_build.yml

Lines changed: 6 additions & 6 deletions
@@ -28,7 +28,7 @@ jobs:
         python-version: ['3.11']
     steps:
       - name: Check out repo
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: Setup conda env
         uses: conda-incubator/setup-miniconda@v2
         with:
@@ -50,7 +50,7 @@ jobs:
         run: |
           cd docs
           make html
-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
         with:
           name: Doc-Build
          path: docs/build/html/
@@ -61,9 +61,9 @@ jobs:
     if: ${{ github.event_name == 'pull_request' }}
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: Download artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
         with:
           name: Doc-Build
           path: docs
@@ -86,12 +86,12 @@ jobs:
     if: github.repository == 'pytorch/ao' && github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/v') || github.event_name == 'workflow_dispatch')
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           ref: gh-pages
           persist-credentials: true
       - name: Download artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
         with:
           name: Doc-Build
           path: docs

.github/workflows/ruff_linter.yml

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ jobs:
           PR_NUMBER=$(echo $PR_URL | grep -oE '[0-9]+$')
           echo "PR_NUMBER=$PR_NUMBER" >> $GITHUB_ENV

-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
         if: github.event_name == 'workflow_dispatch'
         with:
           fetch-depth: 0
@@ -47,7 +47,7 @@ jobs:
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
         if: github.event_name != 'workflow_dispatch'
         with:
           fetch-depth: 0
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+name: Run tutorials
+
+on:
+  push:
+    tags:
+      - ciflow/tutorials/*
+  workflow_dispatch:
+
+jobs:
+  run_tutorials:
+    runs-on: linux.aws.a100
+    strategy:
+      matrix:
+        torch-spec:
+          - '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124'
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Setup miniconda
+        uses: pytorch/test-infra/.github/actions/setup-miniconda@main
+        with:
+          python-version: "3.9"
+
+      - name: Run tutorials
+        shell: bash
+        run: |
+          set -eux
+          ${CONDA_RUN} python -m pip install --upgrade pip
+          ${CONDA_RUN} pip install ${{ matrix.torch-spec }}
+          ${CONDA_RUN} pip install -r dev-requirements.txt
+          ${CONDA_RUN} pip install .
+          cd tutorials
+          ${CONDA_RUN} bash run_all.sh

.github/workflows/trymerge.yml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ jobs:
     steps:
       - name: Checkout repo
         id: checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           fetch-depth: 0
           token: ${{ secrets.PYTORCH_MERGEBOT_TOKEN }}

README.md

Lines changed: 23 additions & 12 deletions
@@ -54,27 +54,38 @@ We've added kv cache quantization and other features in order to enable long con

 In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md)

+## Training
+
 ### Quantization Aware Training

-Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
+Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/). For more details, please see the [QAT README](./torchao/quantization/qat/README.md).

 ```python
-from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
-
-qat_quantizer = Int8DynActInt4WeightQATQuantizer()
+from torchao.quantization import (
+    quantize_,
+    int8_dynamic_activation_int4_weight,
+)
+from torchao.quantization.qat import (
+    FakeQuantizeConfig,
+    from_intx_quantization_aware_training,
+    intx_quantization_aware_training,
+)

-# Insert "fake quantize" operations into linear layers.
-# These operations simulate quantization numerics
-model = qat_quantizer.prepare(model)
+# Insert fake quantization
+activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+quantize_(
+    my_model,
+    intx_quantization_aware_training(activation_config, weight_config),
+)

-# Run Training...
+# Run training... (not shown)

-# Convert fake quantize to actual quantize operations
-model = qat_quantizer.convert(model)
+# Convert fake quantization to actual quantized operations
+quantize_(my_model, from_intx_quantization_aware_training())
+quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
 ```

-## Training
-
 ### Float8

 [torchao.float8](torchao/float8) implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433.
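The int4 weight-only path referenced in the README text above (the ~55% peak-memory reduction alongside kv cache quantization) is mentioned but not shown in this commit. A minimal, hedged inference-side sketch follows; it assumes torchao's `int4_weight_only` config and a CUDA bfloat16 model, neither of which is part of this diff:

```python
import torch
from torchao.quantization import quantize_, int4_weight_only  # int4_weight_only assumed available

# Toy stand-in model; the README pairs int4 weight-only quantization with
# kv cache quantization for long-context inference.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).to("cuda", dtype=torch.bfloat16)

# Quantize linear weights to int4 (weight-only); group_size=32 is illustrative.
quantize_(model, int4_weight_only(group_size=32))

x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
y = model(x)
```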

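Likewise, the `torchao.float8` training recipe at the end of the README hunk is only referenced. A minimal sketch, assuming the `convert_to_float8_training` entry point and a CUDA bfloat16 model with float8-friendly (multiple-of-16) dimensions:

```python
import torch
from torchao.float8 import convert_to_float8_training  # assumed entry point

# Toy model; dimensions chosen as multiples of 16 for float8 matmuls.
m = torch.nn.Sequential(torch.nn.Linear(2048, 4096, bias=False)).to("cuda", dtype=torch.bfloat16)

# Swap nn.Linear modules for float8 training variants in place, then train as usual.
convert_to_float8_training(m)

optimizer = torch.optim.SGD(m.parameters(), lr=1e-3)
x = torch.randn(16, 2048, device="cuda", dtype=torch.bfloat16)
m(x).sum().backward()
optimizer.step()
```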
benchmarks/float8/profile_linear_float8.py

Lines changed: 1 addition & 1 deletion
@@ -355,7 +355,7 @@ def main(
             1, 2048, 4096, device=device, dtype=ref_dtype
         ).requires_grad_()
     else:
-        M, K, N = 4096, 4096, 4096
+        M, K, N = 2048, 4096, 8192
     m_ref = torch.nn.Sequential(
         torch.nn.Linear(K, N, bias=False),
     )

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ lm_eval
 diskcache
 pycocotools
 tqdm
+importlib_metadata

 # Custom CUDA Extensions
 ninja
