
Commit 2272a3f

Merge branch 'main' into rocm_enablement_staging
2 parents: 4148828 + 9bcd73b

File tree: 141 files changed, +1584 / -3788 lines


.github/workflows/dashboard_perf_test.yml

Lines changed: 5 additions & 5 deletions
@@ -42,19 +42,19 @@ jobs:
           mkdir -p ${{ runner.temp }}/benchmark-results

           # llama3 - compile baseline
-          ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
+          ${CONDA_RUN} python benchmarks/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json

           # llama3 - autoquant
-          ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --quantization autoquant --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
+          ${CONDA_RUN} python benchmarks/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --quantization autoquant --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json

           # skipping SAM because of https://hud.pytorch.org/pr/pytorch/ao/1407
           # # SAM
           # ${CONDA_RUN} pip install git+https://github.com/pytorch-labs/segment-anything-fast.git@main
           # # SAM compile baselilne
-          # ${CONDA_RUN} sh torchao/_models/sam/setup.sh
-          # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
+          # ${CONDA_RUN} sh benchmarks/_models/sam/setup.sh
+          # ${CONDA_RUN} python benchmarks/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json

-          # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --compression autoquant --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
+          # ${CONDA_RUN} python benchmarks/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --compression autoquant --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json

           # SAM 2.1
           # ${CONDA_RUN} sh scripts/download_sam2_ckpts.sh ${CHECKPOINT_PATH}/sam2

.github/workflows/torchao_experimental_test.yml

Lines changed: 55 additions & 3 deletions
@@ -11,7 +11,7 @@ on:
     - 'gh/**'

 jobs:
-  test:
+  test-cpu-ops:
     strategy:
       matrix:
         runner: [macos-14]
@@ -53,6 +53,58 @@ jobs:
         run: |
           conda activate venv
           pushd torchao/experimental/ops/tests
-          sh build_and_run_tests.sh
-          rm -rf /tmp/cmake-out
+          # sh build_and_run_tests.sh
+          # rm -rf /tmp/cmake-out
+          popd
+
+  test-mps-ops:
+    strategy:
+      matrix:
+        runner: [macos-m1-stable]
+    runs-on: ${{matrix.runner}}
+    steps:
+      - name: Print machine info
+        run: |
+          uname -a
+          if [ $(uname -s) == Darwin ]; then
+            sysctl machdep.cpu.brand_string
+            sysctl machdep.cpu.core_count
+          fi
+      - name: Checkout repo
+        uses: actions/checkout@v3
+        with:
+          submodules: true
+      - name: Create conda env
+        run: |
+          conda create -yn test-mps-ops-env python=3.11
+      - name: Activate conda env
+        run: |
+          source activate base
+          conda activate test-mps-ops-env
+      - name: Install torch
+        run: |
+          pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
+      - name: Print torch version
+        run: |
+          python -c "import torch; print(torch.__version__)"
+      - name: Install requirements
+        run: |
+          pip install cmake
+          pip install parameterized
+          pip install pyyaml
+          pip install numpy
+      - name: Print pip freeze
+        run: |
+          pip freeze
+      - name: Print current directory
+        run: |
+          python -c "import os; print(os.getcwd())"
+      - name: Build ao with experimental mps ops
+        run: |
+          USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install .
+      - name: Run mps tests
+        run: |
+          pushd torchao/experimental/ops/mps/test
+          python test_lowbit.py
+          python test_quantizer.py
           popd

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -375,3 +375,4 @@ checkpoints/

 # Experimental
 torchao/experimental/cmake-out
+torchao/experimental/deps

README.md

Lines changed: 5 additions & 5 deletions
@@ -19,7 +19,7 @@ torchao just works with `torch.compile()` and `FSDP2` over most PyTorch models o

 ### Post Training Quantization

-Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/), sparsity [here](/torchao/_models/sam/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py)
+Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/), sparsity [here](/benchmarks/_models/sam/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py)

 For inference, we have the option of
 1. Quantize only the weights: works best for memory bound models
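
For context, the "1 liner" mentioned in the hunk above refers to torchao's `quantize_` API. Below is a minimal sketch of that usage, assuming the `quantize_` and `int4_weight_only` entry points exported by `torchao.quantization`; the toy model is illustrative and none of this code is part of the commit's diff.

```python
# Minimal sketch of one-liner weight-only quantization with torchao.
# Assumes quantize_ and int4_weight_only are importable from torchao.quantization.
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int4_weight_only

# Any module containing nn.Linear layers works, including HuggingFace models.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
model = model.to(device="cuda", dtype=torch.bfloat16)

# Replace the Linear weights with int4 weight-only quantized tensors in place.
quantize_(model, int4_weight_only())

# The quantized model is used exactly like the original one.
out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
```

`int4_weight_only` is just one of the available configs; the quantization docs linked in the hunk above describe the full set of options.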
@@ -52,7 +52,7 @@ We also provide a developer facing API so you can implement your own quantizatio

 We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference.

-In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md)
+In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](benchmarks/_models/llama/README.md)

 ## Training

@@ -159,20 +159,20 @@ Things we're excited about but need more time to cook in the oven

 `torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch.

-Stable release from Pypi which will default to CUDA 12.1
+Stable release from Pypi which will default to CUDA 12.4

 ```Shell
 pip install torchao
 ```

 Stable Release from the PyTorch index
 ```Shell
-pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
+pip install torchao --extra-index-url https://download.pytorch.org/whl/cu124 # full options are cpu/cu118/cu124/cu126
 ```

 Nightly Release
 ```Shell
-pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
+pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 # full options are cpu/cu118/cu126/cu128
 ```

 For *most* developers you probably want to skip building custom C++/CUDA extensions for faster iteration
5 files renamed without changes.

torchao/_models/llama/README.md renamed to benchmarks/_models/llama/README.md

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ and follow the steps to gain access.
 Then from the torchao root directory use `huggingface-cli login` and follow the steps to login, then `sh ./scripts/prepare.sh` to
 download and convert the model weights

-once done you can execute benchmarks from the torchao/_models/llama dir with `sh benchmarks.sh`. You can perform and benchmarking or evaluation
+once done you can execute benchmarks from the benchmarks/_models/llama dir with `sh benchmarks.sh`. You can perform and benchmarking or evaluation
 directly using `generate.py` or `eval.py`.

 ## KV Cache Quantization - Memory Efficient Inference
