Skip to content

Commit 1383982

Browse files
committed
tp support
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent b195f11 commit 1383982

File tree

10 files changed

+317
-214
lines changed

10 files changed

+317
-214
lines changed

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,21 @@ class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel)
138138
def test_tp(self, dtype):
139139
return self._test_tp(dtype)
140140

141+
class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel):
142+
from torchao.quantization import gemlite_uintx_weight_only
143+
QUANT_METHOD_FN = staticmethod(gemlite_uintx_weight_only)
144+
COMMON_DTYPES = [torch.float16]
145+
146+
@common_utils.parametrize("dtype", COMMON_DTYPES)
147+
@with_comms
148+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
149+
def test_tp_gemlite(self, dtype):
150+
return self._test_tp(dtype)
151+
141152

142153
common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
143154
common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
155+
common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)
144156

145157
# Run only on H100
146158
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):

test/integration/test_integration.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@
9696
)
9797
from torchao.dtypes.utils import is_device
9898

99+
try:
100+
import gemlite
101+
has_gemlite = True
102+
except ModuleNotFoundError:
103+
has_gemlite = False
104+
99105
logger = logging.getLogger("INFO")
100106

101107
torch.manual_seed(0)
@@ -870,6 +876,9 @@ def _test_lin_weight_subclass_api_impl(
870876
ref_f = mod(x)
871877
api(mod)
872878

879+
# test get_plain()
880+
mod[0].weight.tensor_impl.get_plain()
881+
873882
test = mod(x)
874883
self.assertGreater(
875884
SQNR(ref_f, test),
@@ -930,6 +939,31 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
930939
test_dtype=dtype
931940
)
932941

942+
@parameterized.expand(COMMON_DEVICE_DTYPE)
943+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater")
944+
@unittest.skipIf(not has_gemlite, "gemlite not available")
945+
def test_gemlite_layout(self, device, dtype):
946+
if dtype!= torch.float16:
947+
self.skipTest(f"gemlite only works for fp16 dtype")
948+
from torchao.quantization import gemlite_uintx_weight_only
949+
if device == "cpu":
950+
self.skipTest(f"gemlite is for cuda, not {device}")
951+
for packing_bitwidth in [32, 8]:
952+
953+
for bit_width in [4,8]:
954+
for group_size in [64, 32, None] if bit_width ==4 else [None]:
955+
api = lambda mod: quantize_(mod, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth))
956+
for test_shape in [[1, 1024, 512],[16, 256, 1024], [128, 256, 1024]]:
957+
print(packing_bitwidth, bit_width, group_size, test_shape, dtype)
958+
self._test_lin_weight_subclass_api_impl(
959+
api,
960+
device,
961+
15,
962+
test_shape=test_shape,
963+
test_dtype=dtype,
964+
)
965+
966+
933967
@parameterized.expand(COMMON_DEVICE_DTYPE)
934968
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
935969
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")

torchao/_models/llama/benchmark_results.txt

Lines changed: 92 additions & 0 deletions
Large diffs are not rendered by default.

torchao/_models/llama/benchmarks.sh

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
22

33
# README BENCHMARKS
4-
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
5-
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-4-None --write_result benchmark_results.txt
4+
# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
5+
6+
export MODEL_REPO=meta-llama/Meta-Llama-3-8B
67

7-
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-64 --write_result benchmark_results.txt
8-
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-64 --write_result benchmark_results.txt
9-
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-None --write_result benchmark_results.txt
10-
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-None --write_result benchmark_results.txt
11-
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-8-64 --write_result benchmark_results.txt
12-
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-8-64 --write_result benchmark_results.txt
138

149
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-64 --write_result benchmark_results.txt
1510
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-16-4-64 --write_result benchmark_results.txt
@@ -105,7 +100,7 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
105100
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --write_result benchmark_results.txt
106101
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-8-None --write_result benchmark_results.txt
107102

108-
# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
103+
export MODEL_REPO=meta-llama/Meta-Llama-3-8B
109104
# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
110105
# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --precision float16
111106
# # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --write_result benchmark_results.txt
@@ -148,16 +143,16 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
148143
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 32
149144
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 128
150145

151-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
152-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --precision float16
153-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --write_result benchmark_results.txt
154-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-64 --write_result benchmark_results.txt
155-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-64 --write_result benchmark_results.txt
156-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-None --write_result benchmark_results.txt
157-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-None --write_result benchmark_results.txt
158-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-8-None --write_result benchmark_results.txt
159-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-8-None --write_result benchmark_results.txt
160-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --write_result benchmark_results.txt
146+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
147+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --precision float16
148+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --write_result benchmark_results.txt
149+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt --num_samples 1
150+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt --num_samples 1
151+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt --num_samples 1 #not working
152+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt --num_samples 1
153+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt --num_samples 1
154+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt --num_samples 1
155+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --write_result benchmark_results.txt
161156
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-8-None --write_result benchmark_results.txt
162157

163158
# # 2:4 sparse model
@@ -169,24 +164,24 @@ python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/mode
169164
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 32
170165
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 128
171166

172-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --batch_size 8
173-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --precision float16 --batch_size 8
174-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --write_result benchmark_results.txt --batch_size 8
175-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-64 --write_result benchmark_results.txt --batch_size 8
176-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-64 --write_result benchmark_results.txt --batch_size 8
177-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-None --write_result benchmark_results.txt --batch_size 8
178-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-None --write_result benchmark_results.txt --batch_size 8
179-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-8-None --write_result benchmark_results.txt --batch_size 8
180-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-8-None --write_result benchmark_results.txt --batch_size 8
181-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --write_result benchmark_results.txt --batch_size 8
182-
183-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --batch_size 32
184-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --precision float16 --batch_size 32
185-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --write_result benchmark_results.txt --batch_size 32
186-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-64 --write_result benchmark_results.txt --batch_size 32
187-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-64 --write_result benchmark_results.txt --batch_size 32
188-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-None --write_result benchmark_results.txt --batch_size 32
189-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-None --write_result benchmark_results.txt --batch_size 32
190-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-8-None --write_result benchmark_results.txt --batch_size 32
191-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-8-None --write_result benchmark_results.txt --batch_size 32
192-
python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --write_result benchmark_results.txt --batch_size 32
167+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --batch_size 8
168+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --precision float16 --batch_size 8
169+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --write_result benchmark_results.txt --batch_size 8
170+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt --batch_size 8 --num_samples 1
171+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt --batch_size 8 --num_samples 1
172+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt --batch_size 8 --num_samples 1
173+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt --batch_size 8 --num_samples 1
174+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt --batch_size 8 --num_samples 1
175+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt --batch_size 8 --num_samples 1
176+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --write_result benchmark_results.txt --batch_size 8
177+
178+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --batch_size 32
179+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --precision float16 --batch_size 32
180+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64 --write_result benchmark_results.txt --batch_size 32
181+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-64 --write_result benchmark_results.txt --batch_size 32
182+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-64 --write_result benchmark_results.txt --batch_size 32
183+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-4-None --write_result benchmark_results.txt --batch_size 32
184+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-4-None --write_result benchmark_results.txt --batch_size 32
185+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-8-8-None --write_result benchmark_results.txt --batch_size 32
186+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision float16 --quantization gemsub-32-8-None --write_result benchmark_results.txt --batch_size 32
187+
# python generate.py --compile --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --write_result benchmark_results.txt --batch_size 32

0 commit comments

Comments
 (0)