
Commit 0c2619c

Extend mxfp loading dtypes (#907)
1 parent 56f469e commit 0c2619c

File tree: 4 files changed, +87 −12 lines

auto_round/inference/backend.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -107,6 +107,11 @@ class BackendInfo:
     "act_dynamic",
 ]
 
+MX_TENSOR_DATA_TYPES = [
+    "mx_fp",
+    "mx_fp_rceil",
+]
+
 
 def feature_multiply_checker(in_feature, out_feature, config, in_feature_multiplier, out_feature_multiplier=None):
     if out_feature_multiplier is None:
@@ -230,13 +235,13 @@ def fp8_static_scheme_checker(
     packing_format=LLM_COMPRESSOR_FORMAT,
     sym=[True],
     compute_dtype=["float32", "float16", "bfloat16"],
-    data_type=["mx_fp", "max_fp_rceil"],
+    data_type=MX_TENSOR_DATA_TYPES,
     group_size=[32],
     bits=[8],
     act_bits=[8],
     act_group_size=[32],
     act_sym=[True],
-    act_data_type=["mx_fp_rceil"],
+    act_data_type=MX_TENSOR_DATA_TYPES,
     act_dynamic=[True],
     priority=0,
     checkers=[feature_multiply_checker_32],
@@ -250,13 +255,13 @@ def fp8_static_scheme_checker(
     packing_format=LLM_COMPRESSOR_FORMAT,
     sym=[True],
     compute_dtype=["float32", "float16", "bfloat16"],
-    data_type=["mx_fp"],
+    data_type=MX_TENSOR_DATA_TYPES,
     group_size=[32],
     bits=[4],
     act_bits=[4],
     act_group_size=[32],
     act_sym=[True],
-    act_data_type=["mx_fp_rceil"],
+    act_data_type=MX_TENSOR_DATA_TYPES,
     act_dynamic=[True],
     priority=0,
     checkers=[feature_multiply_checker_32],
```
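Two things happen in this hunk: the per-checker dtype literals are replaced by the shared MX_TENSOR_DATA_TYPES constant, and the apparent misspelling "max_fp_rceil" in the old 8-bit list (which could never match a real config) goes away. A minimal sketch of the membership check this enables; the helper name `supports_mx` is hypothetical, only MX_TENSOR_DATA_TYPES comes from the commit:

```python
# Shared constant from auto_round/inference/backend.py (added in this commit).
MX_TENSOR_DATA_TYPES = ["mx_fp", "mx_fp_rceil"]


def supports_mx(data_type: str, act_data_type: str) -> bool:
    """Hypothetical helper: accept a config only if both dtypes are MX tensor dtypes."""
    return data_type in MX_TENSOR_DATA_TYPES and act_data_type in MX_TENSOR_DATA_TYPES


assert supports_mx("mx_fp", "mx_fp_rceil")
assert not supports_mx("max_fp_rceil", "mx_fp")  # the old misspelled literal never matched
```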

auto_round/testing_utils.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -268,3 +268,11 @@ def decorator(test_func: Callable) -> Callable:
         return unittest.skipUnless(require_package_version(package, version_spec, on_fail="skip"), reason)(test_func)
 
     return decorator
+
+
+def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
+    """Check if the model contains a specific module type."""
+    for _, module in model.named_modules():
+        if isinstance(module, target_module_type):
+            return True
+    return False
```
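A quick usage sketch of the relocated helper, assuming an auto_round build that includes this commit; the toy Sequential model is illustrative only:

```python
import torch

from auto_round.testing_utils import has_module

# A toy model standing in for a real quantized checkpoint.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

assert has_module(model, torch.nn.Linear)      # a Linear submodule is present
assert not has_module(model, torch.nn.Conv2d)  # no Conv2d anywhere in the module tree
```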
Lines changed: 69 additions & 0 deletions

```diff
@@ -0,0 +1,69 @@
+import shutil
+import tempfile
+
+import pytest
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
+
+from auto_round import AutoRound
+from auto_round import schemes as ar_schemes
+from auto_round.experimental import qmodules as ar_qmodules
+from auto_round.export.export_to_autoround import AutoRoundFormat
+from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp
+from auto_round.inference.backend import MX_TENSOR_DATA_TYPES
+from auto_round.testing_utils import has_module
+
+testing_scheme_name_lst = [
+    AutoRoundFormat.MXFP8.value,
+    AutoRoundFormat.MXFP4.value,
+]
+QMODULE_MAPPING = {
+    AutoRoundFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear,
+    AutoRoundFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear,
+}
+SCHEMES_MAPPING = {
+    AutoRoundFormat.MXFP8.value: ar_schemes.MXFP8,
+    AutoRoundFormat.MXFP4.value: ar_schemes.MXFP4,
+}
+
+
+@pytest.mark.parametrize("scheme_name", testing_scheme_name_lst)
+@pytest.mark.parametrize("weight_data_type", MX_TENSOR_DATA_TYPES)
+@pytest.mark.parametrize("act_data_type", MX_TENSOR_DATA_TYPES)
+@torch.inference_mode()
+def test_e2e_quant_and_load(scheme_name, weight_data_type, act_data_type):
+    # Use a temporary directory for saving the quantized model
+    with tempfile.TemporaryDirectory() as temp_dir:
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2.5-0.5B-Instruct"
+        config = AutoConfig.from_pretrained(model_name)
+        config.num_hidden_layers = 2  # Use a smaller model for testing
+
+        # Load the tokenizer and model
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        model = Qwen2ForCausalLM(config)
+        scheme = SCHEMES_MAPPING[scheme_name]
+        scheme.data_type = weight_data_type
+        scheme.act_data_type = act_data_type
+        # Initialize AutoRound for quantization
+        autoround = AutoRound(
+            model,
+            tokenizer,
+            scheme=scheme,
+            iters=0,
+            nsamples=2,
+        )
+
+        # Quantize and save the model to the temporary directory
+        quantized_model_path = f"{temp_dir}/tmp_autoround"
+        autoround.quantize_and_save(format="auto_round", output_dir=quantized_model_path)
+
+        # Perform inference with the quantized model
+        model = AutoModelForCausalLM.from_pretrained(
+            quantized_model_path,
+            torch_dtype="auto",
+        )
+        model.eval()
+        assert has_module(
+            model, QMODULE_MAPPING[scheme_name]
+        ), f"Expected {QMODULE_MAPPING[scheme_name].__name__} in the model."
```

test/test_cuda/test_mxfp_and_nvfp_quant.py

Lines changed: 1 addition & 8 deletions
```diff
@@ -10,6 +10,7 @@
 from auto_round.experimental import qmodules as ar_qmodules
 from auto_round.export.export_to_autoround import AutoRoundFormat
 from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp
+from auto_round.testing_utils import has_module
 
 testing_schemes = [AutoRoundFormat.MXFP8.value, AutoRoundFormat.MXFP4.value, AutoRoundFormat.NVFP4.value]
 QMODULE_MAPPING = {
@@ -19,14 +20,6 @@
 }
 
 
-def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
-    """Check if the model contains a specific module type."""
-    for _, module in model.named_modules():
-        if isinstance(module, target_module_type):
-            return True
-    return False
-
-
 @pytest.mark.parametrize("scheme", testing_schemes)
 @torch.inference_mode()
 def test_e2e_quant_and_infer(scheme):
```
