Commit 8e94408

[CPU] Fix AWQ on CPU after refactoring
1 parent 7dbc816 commit 8e94408

3 files changed: +154 additions, -60 deletions


test/prototype/test_awq.py

Lines changed: 60 additions & 37 deletions
@@ -5,14 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import tempfile
-import unittest

 import torch
+from parameterized import parameterized
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
 )

+from torchao.dtypes import Int4CPULayout
 from torchao.prototype.awq import AWQConfig, AWQStep
 from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
 from torchao.utils import (
@@ -45,15 +46,15 @@ def forward(self, x):
         return x


-@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
-@unittest.skipIf(
-    not _is_fbgemm_genai_gpu_available(),
-    reason="need to install fbgemm_gpu_genai package",
-)
-@unittest.skipIf(
-    not TORCH_VERSION_AT_LEAST_2_6,
-    reason="torch.int4 needs torch 2.6+, can remove after we are not using FbgemmConfig",
-)
+devices = ["cpu"]
+if (
+    torch.cuda.is_available()
+    and _is_fbgemm_genai_gpu_available()
+    and TORCH_VERSION_AT_LEAST_2_6
+):
+    devices.append("cuda")
+
+
 class TestAWQ(TestCase):
     def test_awq_config(self):
         base_config = Int4WeightOnlyConfig()
@@ -68,8 +69,8 @@ def test_awq_config(self):
         with self.assertRaisesRegex(ValueError, "is not one of"):
             AWQConfig(base_config, step="not_supported")

-    def test_awq_functionality(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_functionality(self, device):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -80,13 +81,21 @@ def test_awq_functionality(self):
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

         # baseline quantization
-        base_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
+        if device == "cuda":
+            base_config = FbgemmConfig(
+                input_dtype=torch.bfloat16,
+                weight_dtype=torch.int4,
+                output_dtype=torch.bfloat16,
+                block_size=[1, group_size],
+                preshuffle=False,
+            )
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
+            )
+            torch.manual_seed(1234)
+        else:
+            assert False, "Unsupported device: {}".format(device)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)

@@ -117,8 +126,8 @@ def test_awq_functionality(self):
         loss_base = (ref_out - baseline_out).pow(2).mean().item()
         assert loss_awq < loss_base

-    def test_awq_loading(self):
-        device = "cuda"
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_loading(self, device):
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -136,13 +145,20 @@ def test_awq_loading(self):
         calibration_data = dataset[:n_calibration_examples]

         # calibrate
-        base_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
+        if device == "cuda":
+            base_config = FbgemmConfig(
+                input_dtype=torch.bfloat16,
+                weight_dtype=torch.int4,
+                output_dtype=torch.bfloat16,
+                block_size=[1, group_size],
+                preshuffle=False,
+            )
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)

@@ -171,14 +187,14 @@ def test_awq_loading(self):
         assert awq_save_load_out is not None
         assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

-    def test_awq_loading_vllm(self):
+    @parameterized.expand([(device,) for device in devices])
+    def test_awq_loading_vllm(self, device):
         """Simulate weight loading in vllm:
         * prepare model weight to the same format (awq weight)
         * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint

         There is also a slicing op that is omitted here; the overall e2e flow is tested in the vllm repo
         """
-        device = "cuda"
         dataset_size = 100
         l1, l2, l3 = 512, 256, 128
         original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
@@ -196,13 +212,20 @@ def test_awq_loading_vllm(self):
         calibration_data = dataset[:n_calibration_examples]

         # calibrate
-        base_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
+        if device == "cuda":
+            base_config = FbgemmConfig(
+                input_dtype=torch.bfloat16,
+                weight_dtype=torch.int4,
+                output_dtype=torch.bfloat16,
+                block_size=[1, group_size],
+                preshuffle=False,
+            )
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)

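Note: on the CPU path the tests above swap FbgemmConfig for the existing Int4WeightOnlyConfig with Int4CPULayout. A minimal sketch of the same prepare, calibrate, convert flow outside the test harness (the toy model, calibration loop, and group size are illustrative, and AWQStep.CONVERT is assumed to be the finalization step used elsewhere in this test file):

import torch

from torchao.dtypes import Int4CPULayout
from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Hypothetical toy model; any bfloat16 module with Linear layers would do.
model = torch.nn.Sequential(torch.nn.Linear(512, 256)).to(torch.bfloat16).eval()

# CPU baseline config used by the diff above (group_size value is illustrative).
base_config = Int4WeightOnlyConfig(
    group_size=128, layout=Int4CPULayout(), set_inductor_config=False
)

# Step 1: insert AWQ observers.
quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))

# Step 2: run a few calibration batches through the observed model.
for _ in range(10):
    model(torch.randn(1, 512, dtype=torch.bfloat16))

# Step 3: apply the AWQ scales and quantize to int4 (assumed finalization step).
quantize_(model, AWQConfig(base_config, step=AWQStep.CONVERT))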

torchao/dtypes/uintx/int4_cpu_layout.py

Lines changed: 23 additions & 0 deletions
@@ -30,6 +30,17 @@
 aten = torch.ops.aten


+def _same_metadata(self: "Int4CPUAQTTensorImpl", src: "Int4CPUAQTTensorImpl") -> bool:
+    return (
+        isinstance(self, Int4CPUAQTTensorImpl)
+        and isinstance(src, Int4CPUAQTTensorImpl)
+        and self.packed_weight.shape == src.packed_weight.shape
+        and self.scale_and_zero.shape == src.scale_and_zero.shape
+        and self.transposed == src.transposed
+        and type(self._layout) == type(src._layout)
+    )
+
+
 @dataclass(frozen=True)
 class Int4CPULayout(Layout):
     """Layout class for int4 CPU layout for affine quantized tensor, used by tinygemm kernels `_weight_int4pack_mm_for_cpu`.
@@ -208,6 +219,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                    f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported"
                )

+        if func is aten.copy_.default:
+            self = args[0]
+            src = args[1]
+            if _same_metadata(self, src):
+                self_tensors = self.__tensor_flatten__()[0]
+                for tensor_name in self_tensors:
+                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+                return
+            raise ValueError(
+                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+            )
+
         raise NotImplementedError(
             f"{cls.__name__} dispatch: attempting to run {func}, this is not supported"
         )
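
Note: the new aten.copy_ branch is what lets an already-quantized CPU module accept packed int4 weights from a checkpoint in place, which is the vllm-style loading flow exercised by test_awq_loading_vllm above. A rough sketch of the pattern it enables (the helper name and state_dict handling are illustrative, not part of this diff):

import torch

def load_quantized_weight(linear: torch.nn.Linear, state_dict: dict, key: str) -> None:
    # Both sides are assumed to be affine quantized tensors with the Int4CPU layout
    # and matching metadata; the copy_ handler above then copies packed_weight and
    # scale_and_zero tensor by tensor instead of raising NotImplementedError.
    with torch.no_grad():
        linear.weight.copy_(state_dict[key])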

torchao/prototype/awq/example.py

Lines changed: 71 additions & 23 deletions
@@ -93,7 +93,9 @@ def wiki2_eval(


 # adapted from Hicham Badri (@mobicham)
-def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
+def benchmark(
+    model, tokenizer, max_length, tasks=None, evaluation_limit=None, device="cuda"
+):
     import lm_eval
     import numpy as np

@@ -126,21 +128,33 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("truthfulqa_mc2", 0)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     if "winogrande" in tasks:
         for task in [("winogrande", 5)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     if "arc_challenge" in tasks:
         for task in [("arc_challenge", 25)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])

@@ -149,14 +163,22 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("hellaswag", 10)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
            )["results"]
             print(tag, results[tag])
     if "gsm8k" in tasks:
         for task in [("gsm8k", 5)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
     # ############################################
@@ -167,7 +189,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("mmlu", 5)]:
             tag, fewshot = task
             results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results_mmlu[tag])

@@ -188,7 +214,11 @@ def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
         for task in [("leaderboard_bbh", 3)]:
             tag, fewshot = task
             results[tag] = lm_eval.evaluator.simple_evaluate(
-                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+                model_eval,
+                tasks=[tag],
+                num_fewshot=fewshot,
+                batch_size=eval_batch_size,
+                limit=evaluation_limit,
             )["results"]
             print(tag, results[tag])
             results["bbh"] = results[tag]
@@ -202,7 +232,7 @@ def quantize_and_eval(
     tasks: list[str],
     max_seq_length: int,
     calibration_limit: int,
-    validation_size: int,
+    evaluation_limit: int,
     device: str,
     precision: torch.dtype,
     compile: bool,
@@ -223,18 +253,26 @@ def quantize_and_eval(
     if quant.startswith("awq-int4wo"):
         group_size = int(quant.split("-")[2])
         print(f"running {quant} quantization with group size {group_size}")
-        # TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
-        from torchao.quantization import FbgemmConfig
+        from torchao.dtypes import Int4CPULayout
+        from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig

         # use_hqq = True
         # base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
-        base_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, group_size],
-            preshuffle=False,
-        )
+        if device == "cuda":
+            # TODO: this is temporary, we'll be using Int4WeightOnlyConfig for CUDA soon
+            base_config = FbgemmConfig(
+                input_dtype=torch.bfloat16,
+                weight_dtype=torch.int4,
+                output_dtype=torch.bfloat16,
+                block_size=[1, group_size],
+                preshuffle=False,
+            )
+        elif device == "cpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, layout=Int4CPULayout(), set_inductor_config=False
+            )
+        else:
+            assert False, "Unsupported device: {}".format(device)
         print(f"running {quant} prepare and calibrate")
         t0 = time.time()
         quant_config = AWQConfig(base_config, step="prepare")
@@ -291,7 +329,14 @@ def quantize_and_eval(
     if compile:
         model = torch.compile(model)

-    return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device)
+    return benchmark(
+        model,
+        tokenizer,
+        max_seq_length,
+        tasks=tasks,
+        evaluation_limit=evaluation_limit,
+        device=device,
+    )


 if __name__ == "__main__":
@@ -310,8 +355,8 @@ def quantize_and_eval(
         "--tasks",
         nargs="+",
         type=str,
-        help="Task to benchmark model on. Either PPL or QA",
-        default=["PPL"],
+        help="Task to benchmark model on. Here is the list of tasks you can use: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md",
+        default=["hellaswag"],
     )
     parser.add_argument(
         "--calibration_limit",
@@ -320,7 +365,10 @@ def quantize_and_eval(
         help="Number of samples to use for calibration. Default is 10.",
     )
     parser.add_argument(
-        "--validation_size", type=int, default=1, help="Validation size. Default is 1."
+        "--evaluation_limit",
+        type=int,
+        default=None,
+        help="Number of samples to use for evaluation. Default is None (all).",
     )
     parser.add_argument(
         "--device",
@@ -368,7 +416,7 @@ def quantize_and_eval(
         args.tasks,
         args.max_seq_length,
         args.calibration_limit,
-        args.validation_size,
+        args.evaluation_limit,
         args.device,
         args.precision,
         args.compile,
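
Note: evaluation_limit is threaded through to lm_eval's simple_evaluate as limit, so benchmark runs can be capped for quick checks; on the CLI it replaces the old --validation_size flag with --evaluation_limit. A rough sketch of calling the updated benchmark helper directly (the model id and transformers loading code are illustrative assumptions, and the import path assumes example.py ships with the installed torchao source tree):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao.prototype.awq.example import benchmark

model_id = "facebook/opt-125m"  # hypothetical model id for illustration
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# limit=evaluation_limit is forwarded to lm_eval.evaluator.simple_evaluate, so only
# the first 50 examples of each task are scored instead of the full set.
results = benchmark(
    model.to("cpu"),
    tokenizer,
    2048,  # max_length
    tasks=["hellaswag"],
    evaluation_limit=50,
    device="cpu",
)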
